Python のトランポリン(末尾再帰の手動最適化)

トランポリンは末尾再帰をループに変換し、スタックオーバーフローを防ぐ手法です。Python は末尾再帰最適化をしないため、深い再帰にはこの技法が有効です。

問題の再確認

def factorial(n, acc=1):
    if n <= 1:
        return acc
    return factorial(n - 1, acc * n)

factorial(10000)  # RecursionError

末尾再帰の形でも、Python ではスタックを消費します。

トランポリンの仕組み

関数が直接再帰呼び出しする代わりに、「次に呼び出すべき関数」を返します。トランポリン関数がこれをループで実行します。

class Thunk:
    def __init__(self, func, *args, **kwargs):
        self.func = func
        self.args = args
        self.kwargs = kwargs

def trampoline(func, *args, **kwargs):
    result = func(*args, **kwargs)
    while isinstance(result, Thunk):
        result = result.func(*result.args, **result.kwargs)
    return result

トランポリン対応の階乗

def factorial_t(n, acc=1):
    if n <= 1:
        return acc
    return Thunk(factorial_t, n - 1, acc * n)

result = trampoline(factorial_t, 10000)
print(len(str(result)))  # 35660 桁

RecursionError なしで計算できます。

デコレータ化

def trampolined(func):
    def wrapper(*args, **kwargs):
        result = func(*args, **kwargs)
        while isinstance(result, Thunk):
            result = result.func(*result.args, **result.kwargs)
        return result
    return wrapper

@trampolined
def fib_t(n, a=0, b=1):
    if n == 0:
        return a
    return Thunk(fib_t, n - 1, b, a + b)

print(fib_t(10000))  # 巨大なフィボナッチ数

ジェネレータを使った方法

def trampoline_gen(gen):
    stack = [gen]
    result = None
    while stack:
        try:
            item = stack[-1].send(result)
            if hasattr(item, "__next__"):
                stack.append(item)
                result = None
            else:
                result = item
        except StopIteration as e:
            stack.pop()
            result = e.value
    return result

トランポリンは再帰的アルゴリズムを Python で安全に実行する実用的なパターンです。