teratail header banner
teratail header banner
質問するログイン新規登録

質問編集履歴

3

ご回答に対して修正した部分をのせました。

2019/01/11 03:01

投稿

pyn
pyn

スコア13

title CHANGED
File without changes
body CHANGED
@@ -102,4 +102,32 @@
102
102
 
103
103
  コードは
104
104
  https://github.com/owruby/shake-shake_pytorch
105
- から借りています。
105
+ から借りています。
106
+
107
+
108
+ ご指摘の点について
109
+ ShakeBlock内のforwardを
110
+ ```
111
+ def forward(self, x, y):
112
+ h1 = self.branch1(x)
113
+ h2 = self.branch2(x)
114
+ h = ShakeShake.apply(h1, h2, self.training)
115
+ h0 = x if self.equal_io else self.shortcut(x)
116
+
117
+ return h + h0, h + h0
118
+ ```
119
+ ShakeResNet内のforwardを
120
+ ```
121
+ def forward(self, x):
122
+ h = self.c_in(x)
123
+ h, h = self.layer1(h, h)
124
+ h, h = self.layer2(h, h)
125
+ h, h = self.layer3(h, h)
126
+ h = F.relu(h)
127
+ h = F.avg_pool2d(h, 8)
128
+ h = h.view(-1, self.in_chs[3])
129
+ h = self.fc_out(h)
130
+
131
+ return h
132
+ ```
133
+ とそれぞれ変更しましたが、同じエラーが出ます。

2

コード元のアドレスを追加しました。

2019/01/11 03:01

投稿

pyn
pyn

スコア13

title CHANGED
File without changes
body CHANGED
@@ -98,4 +98,8 @@
98
98
  result = self.forward(*input, **kwargs)
99
99
  TypeError: forward() takes 2 positional arguments but 3 were given
100
100
 
101
- というエラーが出ます。
101
+ というエラーが出ます。
102
+
103
+ コードは
104
+ https://github.com/owruby/shake-shake_pytorch
105
+ から借りています。

1

コードを追加しました。

2019/01/10 09:29

投稿

pyn
pyn

スコア13

title CHANGED
File without changes
body CHANGED
@@ -14,4 +14,88 @@
14
14
  TypeError: forward() takes 2 positional arguments but 4 were given
15
15
 
16
16
  というエラーが出ます。
17
- forwardの引数は、必ず(self, x)のままでないといけないのでしょうか?
17
+ forwardの引数は、必ず(self, x)のままでないといけないのでしょうか?
18
+
19
+ ```
20
+ class ShakeBlock(nn.Module):
21
+ def __init__(self, in_ch, out_ch, stride=1):
22
+ super(ShakeBlock, self).__init__()
23
+ self.equal_io = in_ch == out_ch
24
+ self.shortcut = self.equal_io and None or Shortcut(in_ch, out_ch, stride=stride)
25
+
26
+ self.branch1 = self._make_branch(in_ch, out_ch, stride)
27
+ self.branch2 = self._make_branch(in_ch, out_ch, stride)
28
+
29
+ def forward(self, x, y):
30
+ h1 = self.branch1(x)
31
+ h2 = self.branch2(x)
32
+ h = ShakeShake.apply(h1, h2, self.training)
33
+ h0 = x if self.equal_io else self.shortcut(x)
34
+
35
+ return h + h0
36
+
37
+ def _make_branch(self, in_ch, out_ch, stride=1):
38
+ return nn.Sequential(
39
+ nn.ReLU(inplace=False),
40
+ nn.Conv2d(in_ch, out_ch, 3, padding=1, stride=stride, bias=False),
41
+ nn.BatchNorm2d(out_ch),
42
+ nn.ReLU(inplace=False),
43
+ nn.Conv2d(out_ch, out_ch, 3, padding=1, stride=1, bias=False),
44
+ nn.BatchNorm2d(out_ch))
45
+
46
+
47
+ class ShakeResNet(nn.Module):
48
+ def __init__(self, depth, num_classes):
49
+ super(ShakeResNet, self).__init__()
50
+ n_units = (depth - 2) / 6
51
+ w_base = 32
52
+ in_chs = [16, w_base, w_base * 2, w_base * 4]
53
+
54
+ self.in_chs = in_chs
55
+
56
+ self.c_in = nn.Conv2d(3, in_chs[0], 3, padding=1)
57
+ self.layer1 = self._make_layer(n_units, in_chs[0], in_chs[1])
58
+ self.layer2 = self._make_layer(n_units, in_chs[1], in_chs[2], 2)
59
+ self.layer3 = self._make_layer(n_units, in_chs[2], in_chs[3], 2)
60
+ self.fc_out = nn.Linear(in_chs[3], num_classes)
61
+
62
+ # Initialize paramters
63
+ for m in self.modules():
64
+ if isinstance(m, nn.Conv2d):
65
+ n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
66
+ m.weight.data.normal_(0, math.sqrt(2. / n))
67
+ elif isinstance(m, nn.BatchNorm2d):
68
+ m.weight.data.fill_(1)
69
+ m.bias.data.zero_()
70
+ elif isinstance(m, nn.Linear):
71
+ m.bias.data.zero_()
72
+
73
+ def forward(self, x):
74
+ h = self.c_in(x)
75
+ h = self.layer1(h, h)
76
+ h = self.layer2(h)
77
+ h = self.layer3(h)
78
+ h = F.relu(h)
79
+ h = F.avg_pool2d(h, 8)
80
+ h = h.view(-1, self.in_chs[3])
81
+ h = self.fc_out(h)
82
+
83
+ return h
84
+
85
+ def _make_layer(self, n_units, in_ch, out_ch, stride=1):
86
+ layers = []
87
+ for i in range(int(n_units)):
88
+ layers.append(ShakeBlock(in_ch, out_ch, stride=stride))
89
+ in_ch, stride = out_ch, 1
90
+
91
+ return nn.Sequential(*layers)
92
+ ```
93
+
94
+ 仮にh = self.layer1(h, h)として
95
+ ShakeBlock内のforward(self, x, y)を呼び出していますが、
96
+ 実行すると、
97
+
98
+ result = self.forward(*input, **kwargs)
99
+ TypeError: forward() takes 2 positional arguments but 3 were given
100
+
101
+ というエラーが出ます。