みーのぺーじ

みーが趣味でやっているPCやソフトウェアについて.Python, Javascript, Processing, Unityなど.

PyTorch でモデル構築の比較

PyTorch でモデルを構築する方法として,

  1. torch.nn.Sequential を使用する
  2. torch.nn.Module のサブクラスを作る
  3. torch.nn.Module のサブクラスで torch.nn.functional を使用する

の3個の方法があります*1

1.torch.nn.Sequential を使用する

nn.Sequential() に順番に記載していけばよいので簡単です*2

from torch import nn

n = 4

ModelA = nn.Sequential(
    nn.Linear(1, n),
    nn.ReLU(),
    nn.Linear(n, 1),
)

print(ModelA)

実行すると以下の出力が得られました.

Sequential(
  (0): Linear(in_features=1, out_features=4, bias=True)
  (1): ReLU()
  (2): Linear(in_features=4, out_features=1, bias=True)
)

2. torch.nn.Module のサブクラスを作る

ReLU は,torch.nn.Module のサブクラスである torch.nn.ReLU() を使用します*3

class ModelB(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(1, n)
        self.act1 = nn.ReLU()
        self.linear2 = nn.Linear(n, 1)

    def forward(self, input):
        output = self.linear1(input)
        output = self.act1(output)
        output = self.linear2(output)
        return output


print(ModelB())

実行すると以下の出力が得られました.

ModelB(
  (linear1): Linear(in_features=1, out_features=4, bias=True)
  (act1): ReLU()
  (linear2): Linear(in_features=4, out_features=1, bias=True)
)

3. torch.nn.Module のサブクラスで torch.nn.functional を使用する

ReLU は,torch.nn.functional.relu() 関数を使用します*4

class ModelC(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(1, n)
        self.linear2 = nn.Linear(n, 1)

    def forward(self, input):
        output = self.linear1(input)
        output = self.linear2(nn.functional.relu(output))
        return output


print(ModelC())

実行すると以下の出力が得られました.

ModelC(
  (linear1): Linear(in_features=1, out_features=4, bias=True)
  (linear2): Linear(in_features=4, out_features=1, bias=True)
)

まとめ

いずれの方法でも同じモデルを構築できますが,書きやすさや汎用性が異なるので,使いやすい方法を適宜選択すればよさそうです.