質問編集履歴

4

添付コードの不具合の修正

2018/05/06 15:08

投稿

hukuda222
hukuda222

スコア13

test CHANGED
File without changes
test CHANGED
@@ -144,7 +144,7 @@
144
144
 
145
145
  def __len__(self):
146
146
 
147
- return 2
147
+ return 1
148
148
 
149
149
 
150
150
 

3

記載したコードの抜けを追記しました

2018/05/06 15:08

投稿

hukuda222
hukuda222

スコア13

test CHANGED
File without changes
test CHANGED
@@ -108,6 +108,32 @@
108
108
 
109
109
 
110
110
 
111
+ class VGG(Chain):
112
+
113
+ def __init__(self):
114
+
115
+ super(VGG, self).__init__()
116
+
117
+
118
+
119
+ with self.init_scope():
120
+
121
+ self.base = L.VGG16Layers()
122
+
123
+ self.classify = L.Linear(None, 20)
124
+
125
+
126
+
127
+ def __call__(self, x):
128
+
129
+ h = self.base(x, layers=['fc7'])['fc7']
130
+
131
+ return self.classify(h)
132
+
133
+
134
+
135
+
136
+
111
137
  class DataSet(dataset.DatasetMixin):
112
138
 
113
139
  def __init__(self):
@@ -118,7 +144,7 @@
118
144
 
119
145
  def __len__(self):
120
146
 
121
- return 10
147
+ return 2
122
148
 
123
149
 
124
150
 

2

エラーが出る具体的なコードを記載しました

2018/05/06 13:53

投稿

hukuda222
hukuda222

スコア13

test CHANGED
File without changes
test CHANGED
@@ -80,18 +80,172 @@
80
80
 
81
81
 
82
82
 
83
- ## 追記
83
+ ## エラーが再現できたコード
84
-
84
+
85
- snapshotの書き出しと、読み込みは以下のコードで行なってい
85
+ 以下のコードで同様のエラーが発生しした
86
86
 
87
87
  ```python
88
88
 
89
+ import numpy as np
90
+
91
+ import chainer.links as L
92
+
93
+ import chainer.functions as F
94
+
95
+ from chainer import dataset, Chain, training, optimizers, \
96
+
97
+ iterators, reporter, cuda,serializers
98
+
99
+ import argparse
100
+
101
+ if cuda.available:
102
+
103
+ xp = cuda.cupy
104
+
105
+ else:
106
+
107
+ xp = np
108
+
109
+
110
+
111
+ class DataSet(dataset.DatasetMixin):
112
+
113
+ def __init__(self):
114
+
115
+ pass
116
+
117
+
118
+
119
+ def __len__(self):
120
+
121
+ return 10
122
+
123
+
124
+
125
+ def get_example(self, _):
126
+
127
+ return xp.ones((3, 224, 224)).astype('float32'), xp.zeros((1,)).astype('int32')[0]
128
+
129
+
130
+
131
+
132
+
133
+ def main():
134
+
135
+ parser = argparse.ArgumentParser()
136
+
137
+ parser.add_argument('--epoch', '-e', type=int, default=2,
138
+
139
+ help='Number of examples in epoch')
140
+
141
+ parser.add_argument('--batchsize', '-b', type=int, default=1,
142
+
143
+ help='Number of examples in each mini-batch')
144
+
145
+ parser.add_argument('--gpu', '-g', type=int, default=-1,
146
+
147
+ help='GPU ID (negative value indicates CPU)')
148
+
149
+ parser.add_argument('--out', '-o', default='result2',
150
+
151
+ help='Directory to output the result')
152
+
153
+ parser.add_argument('--resume', '-r', default='',
154
+
155
+ help='Resume the training from snapshot')
156
+
157
+
158
+
159
+ args = parser.parse_args()
160
+
161
+
162
+
163
+ train_dataset = DataSet()
164
+
165
+
166
+
167
+ model = L.Classifier(VGG())
168
+
169
+
170
+
171
+
172
+
173
+ if args.gpu >= 0:
174
+
175
+ cuda.get_device_from_id(args.gpu).use()
176
+
177
+ model.to_gpu()
178
+
179
+
180
+
181
+ optimizer = optimizers.Adam()
182
+
183
+ optimizer.setup(model)
184
+
185
+ model.predictor.base.disable_update()
186
+
187
+
188
+
189
+ train_iter = iterators.SerialIterator(
190
+
191
+ train_dataset, batch_size=args.batchsize)
192
+
193
+
194
+
195
+ updater = training.StandardUpdater(train_iter, optimizer)
196
+
197
+ trainer = training.Trainer(
198
+
199
+ updater, (args.epoch, 'epoch'), out=args.out)
200
+
201
+
202
+
203
+ trainer.extend(training.extensions.LogReport(
204
+
205
+ trigger=(1, 'epoch')))
206
+
207
+ trainer.extend(training.extensions.PrintReport(
208
+
209
+ entries=['iteration', 'main/loss',
210
+
211
+ 'main/accuracy', 'elapsed_time']),
212
+
213
+ trigger=(1, 'epoch'))
214
+
215
+ # ここでsnapshotを取っています。
216
+
89
- trainer.extend(training.extensions.snapshot(),
217
+ trainer.extend(training.extensions.snapshot(),
90
-
218
+
91
- trigger=((1, 'epoch')))
219
+ trigger=(1, 'epoch'))
92
-
220
+
221
+
222
+
93
- if args.resume:
223
+ if args.resume:
224
+
94
-
225
+      # ここで読み込んでいます
226
+
95
- chainer.serializers.load_npz(args.resume, trainer)
227
+ serializers.load_npz(args.resume, trainer)
228
+
96
-
229
+ trainer.run()
230
+
231
+
232
+
233
+
234
+
235
+ if __name__ == "__main__":
236
+
237
+ main()
238
+
97
- ````
239
+ ```
240
+
241
+
242
+
243
+ 以下のように実行すると上記のエラーが生じました。
244
+
245
+ ```
246
+
247
+ python test.py --gpu 0
248
+
249
+ python test.py --gpu 0 --resume ./result2/snapshot_iter_1
250
+
251
+ ```

1

追記の依頼があったらため、エラーの原因と考えられるコードを加筆しました。

2018/05/06 13:50

投稿

hukuda222
hukuda222

スコア13

test CHANGED
File without changes
test CHANGED
@@ -77,3 +77,21 @@
77
77
 
78
78
 
79
79
  このエラーの対策をご存知の方がいらっしゃれば、ご教授お願いします。
80
+
81
+
82
+
83
+ ## 追記
84
+
85
+ snapshotの書き出しと、読み込みは以下のコードで行なっています。
86
+
87
+ ```python
88
+
89
+ trainer.extend(training.extensions.snapshot(),
90
+
91
+ trigger=((1, 'epoch')))
92
+
93
+ if args.resume:
94
+
95
+ chainer.serializers.load_npz(args.resume, trainer)
96
+
97
+ ````