add_module でなくても良い場合——というより add_module はモジュールの後方(末尾)にしか結合できないので——次の選択肢をとりましょう.
Python:code
1from torch import nn
2
class Net(nn.Module):
    """Small MLP classifier head: 1000 -> 100 -> 10 with ReLU and dropout.

    Submodules are registered in __init__ in a fixed order; that order is
    what print(module) displays, so it is kept as-is.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(1000, 100)
        self.fc2 = nn.Linear(100, 10)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.2)

    def forward(self, x):
        """Run fc1 -> ReLU -> dropout -> fc2 and return the logits."""
        return self.fc2(self.dropout(self.relu(self.fc1(x))))
17
class IOWrappedNet(nn.Module):
    """Wrap an existing module with an input layer in front and softmax behind.

    This is the pattern the text recommends instead of add_module: composition
    lets you prepend layers, which add_module (append-only) cannot do.
    """

    def __init__(self, net):
        super().__init__()
        self.fc0 = nn.Linear(10000, 1000)
        self.net = net
        # NOTE(review): dim is unspecified here (prints as Softmax(dim=None),
        # see the stdout below); PyTorch warns and picks a dim implicitly —
        # kept as-is to match the shown output, but consider Softmax(dim=1).
        self.softmax = nn.Softmax()

    def forward(self, x):
        """Run fc0 -> wrapped net -> softmax and return the probabilities."""
        return self.softmax(self.net(self.fc0(x)))
30
# Build the wrapped model and print its module tree (shown in stdout below).
net = IOWrappedNet(Net())
print(net)
Python:stdout
1IOWrappedNet(
2 (fc0): Linear(in_features=10000, out_features=1000, bias=True)
3 (net): Net(
4 (fc1): Linear(in_features=1000, out_features=100, bias=True)
5 (fc2): Linear(in_features=100, out_features=10, bias=True)
6 (relu): ReLU()
7 (dropout): Dropout(p=0.2, inplace=False)
8 )
9 (softmax): Softmax(dim=None)
10)