質問編集履歴

3

ご回答に対して修正した部分をのせました。

2019/01/11 03:01

投稿

pyn
pyn

スコア13

test CHANGED
File without changes
test CHANGED
@@ -207,3 +207,59 @@
207
207
  https://github.com/owruby/shake-shake_pytorch
208
208
 
209
209
  から借りています。
210
+
211
+
212
+
213
+
214
+
215
+ ご指摘の点について
216
+
217
+ ShakeBlock内のforwardを
218
+
219
+ ```
220
+
221
+ def forward(self, x, y):
222
+
223
+ h1 = self.branch1(x)
224
+
225
+ h2 = self.branch2(x)
226
+
227
+ h = ShakeShake.apply(h1, h2, self.training)
228
+
229
+ h0 = x if self.equal_io else self.shortcut(x)
230
+
231
+
232
+
233
+ return h + h0, h + h0
234
+
235
+ ```
236
+
237
+ ShakeResNet内のforwardを
238
+
239
+ ```
240
+
241
+ def forward(self, x):
242
+
243
+ h = self.c_in(x)
244
+
245
+ h, h = self.layer1(h, h)
246
+
247
+ h, h = self.layer2(h, h)
248
+
249
+ h, h = self.layer3(h, h)
250
+
251
+ h = F.relu(h)
252
+
253
+ h = F.avg_pool2d(h, 8)
254
+
255
+ h = h.view(-1, self.in_chs[3])
256
+
257
+ h = self.fc_out(h)
258
+
259
+
260
+
261
+ return h
262
+
263
+ ```
264
+
265
+ とそれぞれ変更しましたが、同じエラーが出ます。

2

コード元のアドレスを追加しました。

2019/01/11 03:01

投稿

pyn
pyn

スコア13

test CHANGED
File without changes
test CHANGED
@@ -199,3 +199,11 @@
199
199
 
200
200
 
201
201
  というエラーが出ます。
202
+
203
+
204
+
205
+ コードは
206
+
207
+ https://github.com/owruby/shake-shake_pytorch
208
+
209
+ から借りています。

1

コードを追加しました。

2019/01/10 09:29

投稿

pyn
pyn

スコア13

test CHANGED
File without changes
test CHANGED
@@ -31,3 +31,171 @@
31
31
  というエラーが出ます。
32
32
 
33
33
  forwardの引数は、必ず(self, x)のままでないといけないのでしょうか?
34
+
35
+
36
+
37
+ ```
38
+
39
+ class ShakeBlock(nn.Module):
40
+
41
+ def __init__(self, in_ch, out_ch, stride=1):
42
+
43
+ super(ShakeBlock, self).__init__()
44
+
45
+ self.equal_io = in_ch == out_ch
46
+
47
+ self.shortcut = self.equal_io and None or Shortcut(in_ch, out_ch, stride=stride)
48
+
49
+
50
+
51
+ self.branch1 = self._make_branch(in_ch, out_ch, stride)
52
+
53
+ self.branch2 = self._make_branch(in_ch, out_ch, stride)
54
+
55
+
56
+
57
+ def forward(self, x, y):
58
+
59
+ h1 = self.branch1(x)
60
+
61
+ h2 = self.branch2(x)
62
+
63
+ h = ShakeShake.apply(h1, h2, self.training)
64
+
65
+ h0 = x if self.equal_io else self.shortcut(x)
66
+
67
+
68
+
69
+ return h + h0
70
+
71
+
72
+
73
+ def _make_branch(self, in_ch, out_ch, stride=1):
74
+
75
+ return nn.Sequential(
76
+
77
+ nn.ReLU(inplace=False),
78
+
79
+ nn.Conv2d(in_ch, out_ch, 3, padding=1, stride=stride, bias=False),
80
+
81
+ nn.BatchNorm2d(out_ch),
82
+
83
+ nn.ReLU(inplace=False),
84
+
85
+ nn.Conv2d(out_ch, out_ch, 3, padding=1, stride=1, bias=False),
86
+
87
+ nn.BatchNorm2d(out_ch))
88
+
89
+
90
+
91
+
92
+
93
+ class ShakeResNet(nn.Module):
94
+
95
+ def __init__(self, depth, num_classes):
96
+
97
+ super(ShakeResNet, self).__init__()
98
+
99
+ n_units = (depth - 2) / 6
100
+
101
+ w_base = 32
102
+
103
+ in_chs = [16, w_base, w_base * 2, w_base * 4]
104
+
105
+
106
+
107
+ self.in_chs = in_chs
108
+
109
+
110
+
111
+ self.c_in = nn.Conv2d(3, in_chs[0], 3, padding=1)
112
+
113
+ self.layer1 = self._make_layer(n_units, in_chs[0], in_chs[1])
114
+
115
+ self.layer2 = self._make_layer(n_units, in_chs[1], in_chs[2], 2)
116
+
117
+ self.layer3 = self._make_layer(n_units, in_chs[2], in_chs[3], 2)
118
+
119
+ self.fc_out = nn.Linear(in_chs[3], num_classes)
120
+
121
+
122
+
123
+ # Initialize paramters
124
+
125
+ for m in self.modules():
126
+
127
+ if isinstance(m, nn.Conv2d):
128
+
129
+ n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
130
+
131
+ m.weight.data.normal_(0, math.sqrt(2. / n))
132
+
133
+ elif isinstance(m, nn.BatchNorm2d):
134
+
135
+ m.weight.data.fill_(1)
136
+
137
+ m.bias.data.zero_()
138
+
139
+ elif isinstance(m, nn.Linear):
140
+
141
+ m.bias.data.zero_()
142
+
143
+
144
+
145
+ def forward(self, x):
146
+
147
+ h = self.c_in(x)
148
+
149
+ h = self.layer1(h, h)
150
+
151
+ h = self.layer2(h)
152
+
153
+ h = self.layer3(h)
154
+
155
+ h = F.relu(h)
156
+
157
+ h = F.avg_pool2d(h, 8)
158
+
159
+ h = h.view(-1, self.in_chs[3])
160
+
161
+ h = self.fc_out(h)
162
+
163
+
164
+
165
+ return h
166
+
167
+
168
+
169
+ def _make_layer(self, n_units, in_ch, out_ch, stride=1):
170
+
171
+ layers = []
172
+
173
+ for i in range(int(n_units)):
174
+
175
+ layers.append(ShakeBlock(in_ch, out_ch, stride=stride))
176
+
177
+ in_ch, stride = out_ch, 1
178
+
179
+
180
+
181
+ return nn.Sequential(*layers)
182
+
183
+ ```
184
+
185
+
186
+
187
+ 仮にh = self.layer1(h, h)として
188
+
189
+ ShakeBlock内のforward(self, x, y)を呼び出していますが、
190
+
191
+ 実行すると、
192
+
193
+
194
+
195
+ result = self.forward(*input, **kwargs)
196
+
197
+ TypeError: forward() takes 2 positional arguments but 3 were given
198
+
199
+
200
+
201
+ というエラーが出ます。