PyTorch でニューラルネットワークのモデルを定義するとき、中心となるのが nn.Module クラスだ。すべてのモデルはこのクラスを継承して作られる。nn.Module はパラメータ管理、デバイス移動、学習・推論モードの切り替えなど、モデルに必要な機能を一手に引き受けている。
最小限のモデル定義
nn.Module を継承したクラスでは、init でレイヤーを定義し、forward でデータの流れを記述する。この 2 つを書くだけで、PyTorch のモデルとして機能する。
import torch import torch.nn as nn class SimpleModel(nn.Module): def __init__(self): super().__init__() self.linear = nn.Linear(10, 1) def forward(self, x): return self.linear(x) model = SimpleModel() print(model)
super().init() の呼び出しは必須で、これを忘れると PyTorch がパラメータを正しく追跡できなくなる。forward メソッドはモデルにデータを渡したとき自動的に呼ばれるため、model(x) と書くだけで forward(x) が実行される仕組みになっている。
パラメータの自動管理
nn.Module の大きな利点は、内部に定義したレイヤーのパラメータを自動で収集してくれることにある。parameters() メソッドを使えば、すべての学習可能パラメータにアクセスできる。
model = SimpleModel() for name, param in model.named_parameters(): print(f"{name}: shape={param.shape}")
重みやバイアスを自分で変数に保持し、オプティマイザに一つずつ渡す必要がある
レイヤーを属性として定義するだけで、parameters() が全パラメータを自動収集する
この自動収集はネストしたモデルでも機能する。nn.Module の中に別の nn.Module を持たせた場合、親モデルの parameters() で子モデルのパラメータもまとめて取得される。
学習モードと推論モード
nn.Module には train() と eval() という 2 つのモード切り替えメソッドがある。Dropout や BatchNorm のように、学習時と推論時で振る舞いが変わるレイヤーを正しく動作させるために使う。
model.train() # 学習モード(Dropout が有効) # ... 学習処理 ... model.eval() # 推論モード(Dropout が無効) with torch.no_grad(): output = model(test_data)
eval() への切り替えを忘れると、推論時にも Dropout が適用されてしまい、結果が毎回変わるという不具合が起きる。推論時には torch.no_grad() と組み合わせて使うのが定番のパターンだ。
デバイス間の移動
GPU を使う場合、モデルとデータを同じデバイスに配置する必要がある。nn.Module の to() メソッドで、モデル内の全パラメータを一括で移動できる。
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model = SimpleModel().to(device) # 入力データも同じデバイスに移動 x = torch.randn(5, 10).to(device) output = model(x)
to() は破壊的操作で、呼び出したモデル自体が書き換わる。テンソルの to() がコピーを返すのとは異なる点に注意が必要になる。
モデルの保存と読み込み
学習済みモデルの保存には state_dict() を使うのが推奨されている。state_dict() はパラメータ名と値の辞書を返すため、モデル構造とパラメータを分離して管理できる。
# 保存 torch.save(model.state_dict(), "model.pth") # 読み込み model = SimpleModel() model.load_state_dict(torch.load("model.pth"))
torch.save(model) でモデル全体を保存することもできるが、この方法はpickleに依存するため、異なる環境やコードの変更に弱い。state_dict() による保存が公式に推奨されている。
Python のオブジェクト直列化ライブラリ。クラス定義が変わると読み込みに失敗する。
nn.Module はこうした基本機能を透過的に提供しており、開発者はモデルのアーキテクチャ設計に集中できる。PyTorch のモデル構築はすべてこのクラスから始まるため、その仕組みを理解しておくことが重要だ。