質問編集履歴

1

一部修正

2020/02/06 15:16

投稿

gorigori123
gorigori123

スコア12

test CHANGED
File without changes
test CHANGED
@@ -25,467 +25,3 @@
25
25
  total [..................................................] 0.33%
26
26
 
27
27
  this epoch [#############################.....................] 59.17%
28
-
29
-
30
-
31
-
32
-
33
-
34
-
35
-
36
-
37
- ```Python
38
-
39
- # -*- coding: utf-8 -*-
40
-
41
- """
42
-
43
-
44
-
45
-
46
-
47
- (ChainerCV: MIT License)
48
-
49
- URL: https://github.com/chainer/chainercv/blob/master/examples/ssd/train.py
50
-
51
- """
52
-
53
- #import sys
54
-
55
- #sys.setdefaultencoding('utf-8')
56
-
57
- from __future__ import division, print_function
58
-
59
-
60
-
61
- import argparse
62
-
63
- import copy
64
-
65
-
66
-
67
- import chainer
68
-
69
- import numpy as np
70
-
71
- from chainer import serializers, training
72
-
73
- from chainer.datasets import TransformDataset
74
-
75
- from chainer.optimizer import WeightDecay
76
-
77
- from chainer.training import extensions, triggers
78
-
79
- from chainercv import transforms
80
-
81
- from chainercv.datasets import VOCBboxDataset, voc_bbox_label_names
82
-
83
- from chainercv.extensions import DetectionVOCEvaluator
84
-
85
- from chainercv.links import SSD512
86
-
87
- from chainercv.links.model.ssd import (GradientScaling, multibox_loss,
88
-
89
- random_crop_with_bbox_constraints,
90
-
91
- random_distort,
92
-
93
- resize_with_random_interpolation)
94
-
95
-
96
-
97
-
98
-
99
-
100
-
101
- from image_pyramid_detection_src.bbox_dataset_from_csv import \
102
-
103
- BboxDatasetFromCsv
104
-
105
-
106
-
107
- class ConcatenatedDataset(chainer.dataset.DatasetMixin):
108
-
109
- def __init__(self, *datasets):
110
-
111
- self._datasets = datasets
112
-
113
-
114
-
115
- def __len__(self):
116
-
117
- return sum(len(dataset) for dataset in self._datasets)
118
-
119
-
120
-
121
- def get_example(self, i):
122
-
123
- if i < 0:
124
-
125
- raise IndexError
126
-
127
- for dataset in self._datasets:
128
-
129
- if i < len(dataset):
130
-
131
- return dataset[i]
132
-
133
- i -= len(dataset)
134
-
135
- raise IndexError
136
-
137
-
138
-
139
-
140
-
141
- class MultiboxTrainChain(chainer.Chain):
142
-
143
-
144
-
145
- def __init__(self, model, alpha=1, k=3):
146
-
147
- super(MultiboxTrainChain, self).__init__()
148
-
149
- with self.init_scope():
150
-
151
- self.model = model
152
-
153
- self.alpha = alpha
154
-
155
- self.k = k
156
-
157
-
158
-
159
- def __call__(self, imgs, gt_mb_locs, gt_mb_labels):
160
-
161
- mb_locs, mb_confs = self.model(imgs)
162
-
163
- loc_loss, conf_loss = multibox_loss(
164
-
165
- mb_locs, mb_confs, gt_mb_locs, gt_mb_labels, self.k)
166
-
167
- loss = loc_loss * self.alpha + conf_loss
168
-
169
- chainer.reporter.report(
170
-
171
- {"loss": loss, "loss/loc": loc_loss, "loss/conf": conf_loss},self)
172
-
173
-
174
-
175
- return loss
176
-
177
-
178
-
179
-
180
-
181
- def get_ft_model(n_class):
182
-
183
-
184
-
185
- pretrained_model = SSD512(20,"ssd512_voc0712_trained.npz")
186
-
187
- return pretrained_model
188
-
189
-
190
-
191
- model = SSD512(n_class, "ssd_vgg16_imagenet.npz")
192
-
193
- model.extractor.copyparams(pretrained_model.extractor)
194
-
195
- model.multibox.loc.copyparams(pretrained_model.multibox.loc)
196
-
197
- return model
198
-
199
-
200
-
201
- class Transform(object):
202
-
203
-
204
-
205
- def __init__(self, coder, size, mean):
206
-
207
- # to send cpu, make a copy
208
-
209
- self.coder = copy.copy(coder)
210
-
211
- self.coder.to_cpu()
212
-
213
-
214
-
215
- self.size = size
216
-
217
- self.mean = mean
218
-
219
-
220
-
221
- def __call__(self, in_data):
222
-
223
- # 1. Color augmentation
224
-
225
- # 2. Random expansion
226
-
227
- # 3. Random cropping
228
-
229
- # 4. Resizing with random interpolation
230
-
231
- # 5. Random horizontal flipping
232
-
233
- img, bbox, label = in_data
234
-
235
-
236
-
237
- # 1. Color augmentation
238
-
239
- img = random_distort(
240
-
241
- img,
242
-
243
- brightness_delta=32,
244
-
245
- contrast_low=0.2, contrast_high=0.9,
246
-
247
- saturation_low=0.2, saturation_high=0.9,
248
-
249
- hue_delta=18)
250
-
251
-
252
-
253
- # 2. Random expansion
254
-
255
- img, param = transforms.random_expand(img, fill=self.mean, return_param=True)
256
-
257
- bbox = transforms.translate_bbox(
258
-
259
- bbox, y_offset=param["y_offset"], x_offset=param["x_offset"])
260
-
261
-
262
-
263
- # 3. Random cropping
264
-
265
- img, param = random_crop_with_bbox_constraints(
266
-
267
- img, bbox, return_param=True)
268
-
269
- bbox, param = transforms.crop_bbox(
270
-
271
- bbox, y_slice=param["y_slice"], x_slice=param["x_slice"],
272
-
273
- allow_outside_center=False, return_param=True)
274
-
275
- label = label[param["index"]]
276
-
277
-
278
-
279
- # 4. Resizing with random interpolation
280
-
281
- _, h_size, w_size = img.shape
282
-
283
- img = resize_with_random_interpolation(img, (self.size, self.size))
284
-
285
- bbox = transforms.resize_bbox(
286
-
287
- bbox, (h_size, w_size), (self.size, self.size))
288
-
289
-
290
-
291
- # 5. Random horizontal flipping
292
-
293
- img, params = transforms.random_flip(
294
-
295
- img, x_random=True, return_param=True)
296
-
297
- bbox = transforms.flip_bbox(
298
-
299
- bbox, (self.size, self.size), x_flip=params["x_flip"])
300
-
301
-
302
-
303
- img -= self.mean
304
-
305
- mb_loc, mb_label = self.coder.encode(bbox, label)
306
-
307
-
308
-
309
- return img, mb_loc, mb_label
310
-
311
-
312
-
313
-
314
-
315
-
316
-
317
-
318
-
319
- def run(args):
320
-
321
- model = get_ft_model(args.n_class)
322
-
323
- model.use_preset("evaluate")
324
-
325
-
326
-
327
- train_chain = MultiboxTrainChain(model)
328
-
329
-
330
-
331
- if args.gpu >= 0:
332
-
333
- chainer.cuda.get_device_from_id(args.gpu).use()
334
-
335
- model.to_gpu()
336
-
337
-
338
-
339
- train = TransformDataset(
340
-
341
- # VOCBboxDataset(year="2007", split="trainval")
342
-
343
- BboxDatasetFromCsv(args.input_csv, args.rootdir),
344
-
345
- Transform(model.coder, model.insize, model.mean))
346
-
347
- train_iter = chainer.iterators.SerialIterator(train, args.batchsize)
348
-
349
-
350
-
351
- if args.validation_csv:
352
-
353
- test = BboxDatasetFromCsv(args.validation_csv, args.eval_rootdir)
354
-
355
-
356
-
357
- test_iter = chainer.iterators.SerialIterator(
358
-
359
- test, args.batchsize, repeat=False, shuffle=False)
360
-
361
-
362
-
363
- optimizer = chainer.optimizers.MomentumSGD()
364
-
365
- optimizer.setup(train_chain)
366
-
367
- for param in train_chain.params():
368
-
369
- if param.name == "b":
370
-
371
- param.update_rule.add_hook(GradientScaling(2))
372
-
373
- else:
374
-
375
- param.update_rule.add_hook(WeightDecay(0.0005))
376
-
377
-
378
-
379
- updater = training.StandardUpdater(train_iter, optimizer, device=args.gpu)
380
-
381
- trainer = training.Trainer(
382
-
383
- updater, (args.max_iteration, "iteration"), args.out)
384
-
385
- trainer.extend(
386
-
387
- extensions.ExponentialShift("lr", 0.1, init=1e-4),
388
-
389
- trigger=triggers.ManualScheduleTrigger([], "iteration"))
390
-
391
-
392
-
393
- if args.validation_csv:
394
-
395
- trainer.extend(
396
-
397
- DetectionVOCEvaluator(
398
-
399
- test_iter, model, use_07_metric=True,
400
-
401
- label_names=["Class #" + str(i) for i in range(args.n_class)]),
402
-
403
- trigger=(args.val_interval, "iteration"))
404
-
405
-
406
-
407
- log_interval = 20, "iteration"
408
-
409
- trainer.extend(extensions.LogReport(trigger=log_interval))
410
-
411
- trainer.extend(extensions.observe_lr(), trigger=log_interval)
412
-
413
- trainer.extend(extensions.PrintReport(
414
-
415
- ["epoch", "iteration", "lr",
416
-
417
- "main/loss", "main/loss/loc", "main/loss/conf", "validation/main/accuracy"]), trigger=log_interval)
418
-
419
- trainer.extend(extensions.ProgressBar(update_interval=10))
420
-
421
-
422
-
423
- trainer.extend(extensions.snapshot(), trigger=(
424
-
425
- args.save_interval, "iteration"))
426
-
427
- trainer.extend(
428
-
429
- extensions.snapshot_object(model, "model_iter_{.updater.iteration}"),
430
-
431
- trigger=(args.save_interval, "iteration"))
432
-
433
-
434
-
435
- if args.resume:
436
-
437
- serializers.load_npz(args.resume, trainer)
438
-
439
-
440
-
441
- trainer.run()
442
-
443
-
444
-
445
-
446
-
447
- def main():
448
-
449
- parser = argparse.ArgumentParser()
450
-
451
- parser.add_argument("-i", "--input-csv", required=True)
452
-
453
- parser.add_argument("-r", "--rootdir", type=str, default=".")
454
-
455
- parser.add_argument("-o", "--out", dest="out",
456
-
457
- metavar="OUTDIR", default="output_ssd")
458
-
459
- parser.add_argument("-c", "--n-class", type=int, default=1)
460
-
461
- parser.add_argument("-S", "--save-interval", type=int, default=2000)
462
-
463
- parser.add_argument("-M", "--max-iteration", type=int, default=10000)
464
-
465
- parser.add_argument("-T", "--validation-csv", type=str, default="")
466
-
467
- parser.add_argument("-e", "--eval-rootdir", type=str, default=".")
468
-
469
- parser.add_argument("-v","--val-interval", type=int, default=10000)
470
-
471
- parser.add_argument("-b", "--batchsize",metavar="BATCH_SIZE", type=int, default=32)
472
-
473
- parser.add_argument("--gpu", type=int, default=-1)
474
-
475
- parser.add_argument("--resume", metavar="SNAPSHOT")
476
-
477
- args = parser.parse_args()
478
-
479
- run(args)
480
-
481
-
482
-
483
-
484
-
485
- if __name__ == "__main__":
486
-
487
- main()
488
-
489
-
490
-
491
- ```