世界の国560595 views
中学理科1626207 views
いろは2986023 views
Computer365120 views
MathPython491378 views
高校日本史189857 views
高校生物549842 views
教育148875 views
中学英語808712 views
高校国語785655 views
Help
Tools

English

PyTorch の nn.Module - モデル定義の基本

PyTorch でニューラルネットワークのモデルを定義するとき、中心となるのが nn.Module クラスだ。すべてのモデルはこのクラスを継承して作られる。nn.Module はパラメータ管理、デバイス移動、学習・推論モードの切り替えなど、モデルに必要な機能を一手に引き受けている。

最小限のモデル定義

nn.Module を継承したクラスでは、init でレイヤーを定義し、forward でデータの流れを記述する。この 2 つを書くだけで、PyTorch のモデルとして機能する。

import torch
import torch.nn as nn

class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 1)

    def forward(self, x):
        return self.linear(x)

model = SimpleModel()
print(model)

super().init() の呼び出しは必須で、これを忘れると PyTorch がパラメータを正しく追跡できなくなる。forward メソッドはモデルにデータを渡したとき自動的に呼ばれるため、model(x) と書くだけで forward(x) が実行される仕組みになっている。

パラメータの自動管理

nn.Module の大きな利点は、内部に定義したレイヤーのパラメータを自動で収集してくれることにある。parameters() メソッドを使えば、すべての学習可能パラメータにアクセスできる。

model = SimpleModel()

for name, param in model.named_parameters():
    print(f"{name}: shape={param.shape}")
手動でのパラメータ管理

重みやバイアスを自分で変数に保持し、オプティマイザに一つずつ渡す必要がある

nn.Module のパラメータ管理

レイヤーを属性として定義するだけで、parameters() が全パラメータを自動収集する

この自動収集はネストしたモデルでも機能する。nn.Module の中に別の nn.Module を持たせた場合、親モデルの parameters() で子モデルのパラメータもまとめて取得される。

学習モードと推論モード

nn.Module には train() と eval() という 2 つのモード切り替えメソッドがある。Dropout や BatchNorm のように、学習時と推論時で振る舞いが変わるレイヤーを正しく動作させるために使う。

model.train()   # 学習モード(Dropout が有効)
# ... 学習処理 ...

model.eval()    # 推論モード(Dropout が無効)
with torch.no_grad():
    output = model(test_data)

eval() への切り替えを忘れると、推論時にも Dropout が適用されてしまい、結果が毎回変わるという不具合が起きる。推論時には torch.no_grad() と組み合わせて使うのが定番のパターンだ。

デバイス間の移動

GPU を使う場合、モデルとデータを同じデバイスに配置する必要がある。nn.Module の to() メソッドで、モデル内の全パラメータを一括で移動できる。

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SimpleModel().to(device)

# 入力データも同じデバイスに移動
x = torch.randn(5, 10).to(device)
output = model(x)

to() は破壊的操作で、呼び出したモデル自体が書き換わる。テンソルの to() がコピーを返すのとは異なる点に注意が必要になる。

モデルの保存と読み込み

学習済みモデルの保存には state_dict() を使うのが推奨されている。state_dict() はパラメータ名と値の辞書を返すため、モデル構造とパラメータを分離して管理できる。

# 保存
torch.save(model.state_dict(), "model.pth")

# 読み込み
model = SimpleModel()
model.load_state_dict(torch.load("model.pth"))

torch.save(model) でモデル全体を保存することもできるが、この方法はpickleに依存するため、異なる環境やコードの変更に弱い。state_dict() による保存が公式に推奨されている。

Python のオブジェクト直列化ライブラリ。クラス定義が変わると読み込みに失敗する。

nn.Module はこうした基本機能を透過的に提供しており、開発者はモデルのアーキテクチャ設計に集中できる。PyTorch のモデル構築はすべてこのクラスから始まるため、その仕組みを理解しておくことが重要だ。