Python でシグネチャを保持するデコレータ(inspect.signature の活用)

functools.wraps を使うだけでは、デコレートされた関数のシグネチャが失われます。inspect.signature を活用すると、元のシグネチャを完全に保持できます。

問題の確認

from functools import wraps
import inspect

def decorator(func):
    @wraps(func)
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs)
    return wrapper

@decorator
def greet(name: str, greeting: str = "Hello") -> str:
    return f"{greeting}, {name}!"

print(inspect.signature(greet))
# (*args, **kwargs) ← 元のシグネチャが失われている

wraps はメタデータをコピーしますが、シグネチャは wrapper のものになります。

signature を設定する

from functools import wraps
import inspect

def decorator(func):
    @wraps(func)
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs)
    wrapper.__signature__ = inspect.signature(func)
    return wrapper

@decorator
def greet(name: str, greeting: str = "Hello") -> str:
    return f"{greeting}, {name}!"

print(inspect.signature(greet))
# (name: str, greeting: str = 'Hello') -> str

引数を追加するデコレータ

新しい引数を追加する場合、シグネチャも更新する必要があります。

from functools import wraps
import inspect

def with_debug(func):
    sig = inspect.signature(func)
    new_param = inspect.Parameter(
        "debug",
        inspect.Parameter.KEYWORD_ONLY,
        default=False
    )
    new_params = list(sig.parameters.values()) + [new_param]
    new_sig = sig.replace(parameters=new_params)
    
    @wraps(func)
    def wrapper(*args, debug=False, **kwargs):
        if debug:
            print(f"Calling {func.__name__} with args={args}, kwargs={kwargs}")
        return func(*args, **kwargs)
    
    wrapper.__signature__ = new_sig
    return wrapper

@with_debug
def add(a: int, b: int) -> int:
    return a + b

print(inspect.signature(add))
# (a: int, b: int, *, debug=False) -> int

引数の検証にシグネチャを活用

def validated(func):
    sig = inspect.signature(func)
    
    @wraps(func)
    def wrapper(*args, **kwargs):
        bound = sig.bind(*args, **kwargs)
        bound.apply_defaults()
        return func(*bound.args, **bound.kwargs)
    
    wrapper.__signature__ = sig
    return wrapper

sig.bind() は引数をパラメータにマッピングし、不足や過剰を検出します。シグネチャを保持することで、IDE の補完やドキュメント生成も正しく動作します。