質問編集履歴

2

コードの一部修正

2021/03/01 12:33

投稿

SuzuAya
SuzuAya

スコア71

test CHANGED
File without changes
test CHANGED
@@ -1,20 +1,26 @@
1
- ### 前提・実現したいこと
2
-
3
- Pytorchで保存したモデルを読み込みたいのですが、エラーが発生してしまいました。
4
-
5
- 原因についてお分かりの方がいらっしゃいましたら、ご教示いただけますと幸いです。
6
-
7
-
8
-
9
1
  ### 発生している問題・エラーメッセージ
10
2
 
11
3
  ```
12
4
 
5
+ ----> 5 model = classifier(flags.aug_kwargs)
6
+
13
- ----> 3 model = model.load_state_dict(torch.load(model_path))
7
+ 6 model.load_state_dict(torch.load(model_path))
8
+
9
+
10
+
14
-
11
+ 725 result = self._slow_forward(*input, **kwargs)
12
+
15
-
13
+ 726 else:
14
+
16
-
15
+ --> 727 result = self.forward(*input, **kwargs)
16
+
17
+ 728 for hook in itertools.chain(
18
+
19
+ 729 _global_forward_hooks.values(),
20
+
21
+
22
+
17
- AttributeError: 'str' object has no attribute 'load_state_dict'
23
+ TypeError: forward() missing 1 required positional argument: 'targets'
18
24
 
19
25
  ```
20
26
 
@@ -548,8 +554,10 @@
548
554
 
549
555
  #モデルの読み込み
550
556
 
551
- model_path = '../input/detectron2-model/predictor_tf_inception_v3.pt'
557
+ model_path = './predictor_last.pt'
558
+
552
-
559
+ model = classifier(flags.aug_kwargs)
560
+
553
- model = model.load_state_dict(torch.load(model_path))
561
+ model.load_state_dict(torch.load(model_path))
554
562
 
555
563
  ```

1

コードの追記

2021/03/01 12:33

投稿

SuzuAya
SuzuAya

スコア71

test CHANGED
File without changes
test CHANGED
@@ -1,10 +1,6 @@
1
1
  ### 前提・実現したいこと
2
2
 
3
- Pytorchで保存したモデルを読み込みたいと思っています
3
+ Pytorchで保存したモデルを読み込みたいのですが、エラーが発生してしまいました
4
-
5
- [こちら](https://tzmi.hatenablog.com/entry/2020/03/05/222813)のGPUでの読み出し方法を参考に
6
-
7
- コードを書いてみたのですが、エラーが発生してしまいました。
8
4
 
9
5
  原因についてお分かりの方がいらっしゃいましたら、ご教示いただけますと幸いです。
10
6
 
@@ -12,18 +8,8 @@
12
8
 
13
9
  ### 発生している問題・エラーメッセージ
14
10
 
15
-
16
-
17
11
  ```
18
12
 
19
- AttributeError Traceback (most recent call last)
20
-
21
- <ipython-input-15-64589b233dc6> in <module>
22
-
23
- 1 #モデルの読み込み方法
24
-
25
- 2 model_path = '../input/detectron2-model/predictor_tf_inception_v3.pt'
26
-
27
13
  ----> 3 model = model.load_state_dict(torch.load(model_path))
28
14
 
29
15
 
@@ -36,32 +22,534 @@
36
22
 
37
23
  ### 該当のソースコード
38
24
 
39
-
40
-
41
25
  ```Python
42
26
 
43
- import torch
44
-
45
- from torchvision import models
46
-
47
-
48
-
49
- #モデルの読み込み方法
27
+ @dataclass
28
+
29
+ class Flags:
30
+
31
+ # General
32
+
33
+ debug: bool = True
34
+
35
+ outdir: str = "results/det"
36
+
37
+ device: str = "cuda:0"
38
+
39
+
40
+
41
+ # Data config
42
+
43
+ imgdir_name: str = "vinbigdata-chest-xray-resized-png-256x256"
44
+
45
+ # split_mode: str = "all_train"
46
+
47
+ seed: int = 111
48
+
49
+ target_fold: int = 0 # 0~4
50
+
51
+ label_smoothing: float = 0.0
52
+
53
+ # Model config
54
+
55
+ model_name: str = "tf_inception_v3"
56
+
57
+ model_mode: str = "normal" # normal, cnn_fixed supported
58
+
59
+ # Training config
60
+
61
+ epoch: int = 20
62
+
63
+ batchsize: int = 8
64
+
65
+ valid_batchsize: int = 16
66
+
67
+ num_workers: int = 4
68
+
69
+ snapshot_freq: int = 5
70
+
71
+ ema_decay: float = 0.999 # negative value is to inactivate ema.
72
+
73
+ scheduler_type: str = ""
74
+
75
+ scheduler_kwargs: Dict[str, Any] = field(default_factory=lambda: {})
76
+
77
+ scheduler_trigger: List[Union[int, str]] = field(default_factory=lambda: [1, "iteration"])
78
+
79
+ aug_kwargs: Dict[str, Dict[str, Any]] = field(default_factory=lambda: {})
80
+
81
+ mixup_prob: float = -1.0 # Apply mixup augmentation when positive value is set.
82
+
83
+
84
+
85
+ def update(self, param_dict: Dict) -> "Flags":
86
+
87
+ # Overwrite by `param_dict`
88
+
89
+ for key, value in param_dict.items():
90
+
91
+ if not hasattr(self, key):
92
+
93
+ raise ValueError(f"[ERROR] Unexpected key for flag = {key}")
94
+
95
+ setattr(self, key, value)
96
+
97
+ return self
98
+
99
+
100
+
101
+ flags_dict = {
102
+
103
+ "debug": False, # Change to True for fast debug run!
104
+
105
+ "outdir": "results/tmp_debug",
106
+
107
+ # Data
108
+
109
+ "imgdir_name": "vinbigdata-chest-xray-resized-png-256x256",
110
+
111
+ # Model
112
+
113
+ "model_name": "tf_inception_v3",#"resnet18",
114
+
115
+ # Training
116
+
117
+ "num_workers": 4,
118
+
119
+ "epoch": 15,
120
+
121
+ "batchsize": 8,
122
+
123
+ "scheduler_type": "CosineAnnealingWarmRestarts",
124
+
125
+ "scheduler_kwargs": {"T_0": 28125}, # 15000 * 15 epoch // (batchsize=8)
126
+
127
+ "scheduler_trigger": [1, "iteration"],
128
+
129
+ "aug_kwargs": {
130
+
131
+ "HorizontalFlip": {"p": 0.5},
132
+
133
+ "ShiftScaleRotate": {"scale_limit": 0.15, "rotate_limit": 10, "p": 0.5},
134
+
135
+ "RandomBrightnessContrast": {"p": 0.5},
136
+
137
+ "CoarseDropout": {"max_holes": 8, "max_height": 25, "max_width": 25, "p": 0.5},
138
+
139
+ "Blur": {"blur_limit": [3, 7], "p": 0.5},
140
+
141
+ "Downscale": {"scale_min": 0.25, "scale_max": 0.9, "p": 0.3},
142
+
143
+ "RandomGamma": {"gamma_limit": [80, 120], "p": 0.6},
144
+
145
+ }
146
+
147
+ }
148
+
149
+
150
+
151
+ # args = parse()
152
+
153
+ print("torch", torch.__version__)
154
+
155
+ flags = Flags().update(flags_dict)
156
+
157
+ print("flags", flags)
158
+
159
+ debug = flags.debug
160
+
161
+ outdir = Path(flags.outdir)
162
+
163
+ os.makedirs(str(outdir), exist_ok=True)
164
+
165
+ flags_dict = dataclasses.asdict(flags)
166
+
167
+ save_yaml(str(outdir / "flags.yaml"), flags_dict)
168
+
169
+
170
+
171
+ # --- Read data ---
172
+
173
+ inputdir = Path("/kaggle/input")
174
+
175
+ datadir = inputdir / "vinbigdata-chest-xray-abnormalities-detection"
176
+
177
+ imgdir = inputdir / flags.imgdir_name
178
+
179
+
180
+
181
+ # Read in the data CSV files
182
+
183
+ train = pd.read_csv(datadir / "train.csv")
184
+
185
+
186
+
187
+ class CNNFixedPredictor(nn.Module):
188
+
189
+ def __init__(self, cnn: nn.Module, num_classes: int = 2):
190
+
191
+ super(CNNFixedPredictor, self).__init__()
192
+
193
+ self.cnn = cnn
194
+
195
+ self.lin = Linear(cnn.num_features, num_classes)
196
+
197
+ print("cnn.num_features", cnn.num_features)
198
+
199
+
200
+
201
+ # We do not learn CNN parameters.
202
+
203
+ # https://pytorch.org/tutorials/beginner/finetuning_torchvision_models_tutorial.html
204
+
205
+ for param in self.cnn.parameters():
206
+
207
+ param.requires_grad = False
208
+
209
+
210
+
211
+ def forward(self, x):
212
+
213
+ feat = self.cnn(x)
214
+
215
+ return self.lin(feat)
216
+
217
+
218
+
219
+ def build_predictor(model_name: str, model_mode: str = "normal"):
220
+
221
+ if model_mode == "normal":
222
+
223
+ # normal configuration. train all parameters.
224
+
225
+ return timm.create_model(model_name, pretrained=True, num_classes=2, in_chans=3)
226
+
227
+ elif model_mode == "cnn_fixed":
228
+
229
+ # normal configuration. train all parameters.
230
+
231
+ # https://rwightman.github.io/pytorch-image-models/feature_extraction/
232
+
233
+ timm_model = timm.create_model(model_name, pretrained=True, num_classes=0, in_chans=3)
234
+
235
+ return CNNFixedPredictor(timm_model, num_classes=2)
236
+
237
+ else:
238
+
239
+ raise ValueError(f"[ERROR] Unexpected value model_mode={model_mode}")
240
+
241
+
242
+
243
+ class Classifier(nn.Module):
244
+
245
+ """two class classfication"""
246
+
247
+
248
+
249
+ def __init__(self, predictor, lossfun=cross_entropy_with_logits):
250
+
251
+ super().__init__()
252
+
253
+ self.predictor = predictor
254
+
255
+ self.lossfun = lossfun
256
+
257
+ self.prefix = ""
258
+
259
+
260
+
261
+ def forward(self, image, targets):
262
+
263
+ outputs = self.predictor(image)
264
+
265
+ loss = self.lossfun(outputs, targets)
266
+
267
+ metrics = {
268
+
269
+ f"{self.prefix}loss": loss.item(),
270
+
271
+ f"{self.prefix}acc": accuracy_with_logits(outputs, targets).item()
272
+
273
+ }
274
+
275
+ ppe.reporting.report(metrics, self)
276
+
277
+ return loss, metrics
278
+
279
+
280
+
281
+ def predict(self, data_loader):
282
+
283
+ pred = self.predict_proba(data_loader)
284
+
285
+ label = torch.argmax(pred, dim=1)
286
+
287
+ return label
288
+
289
+
290
+
291
+ def predict_proba(self, data_loader):
292
+
293
+ device: torch.device = next(self.parameters()).device
294
+
295
+ y_list = []
296
+
297
+ self.eval()
298
+
299
+ with torch.no_grad():
300
+
301
+ for batch in data_loader:
302
+
303
+ if isinstance(batch, (tuple, list)):
304
+
305
+ # Assumes first argument is "image"
306
+
307
+ batch = batch[0].to(device)
308
+
309
+ else:
310
+
311
+ batch = batch.to(device)
312
+
313
+ y = self.predictor(batch)
314
+
315
+ y = torch.softmax(y, dim=-1)
316
+
317
+ y_list.append(y)
318
+
319
+ pred = torch.cat(y_list)
320
+
321
+ return pred
322
+
323
+
324
+
325
+ def create_trainer(model, optimizer, device) -> Engine:
326
+
327
+ model.to(device)
328
+
329
+
330
+
331
+ def update_fn(engine, batch):
332
+
333
+ model.train()
334
+
335
+ optimizer.zero_grad()
336
+
337
+ loss, metrics = model(*[elem.to(device) for elem in batch])
338
+
339
+ loss.backward()
340
+
341
+ optimizer.step()
342
+
343
+ return metrics
344
+
345
+ trainer = Engine(update_fn)
346
+
347
+ return trainer
348
+
349
+
350
+
351
+ skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=flags.seed)
352
+
353
+ # skf.get_n_splits(None, None)
354
+
355
+ y = np.array([int(len(d["annotations"]) > 0) for d in dataset_dicts])
356
+
357
+ split_inds = list(skf.split(dataset_dicts, y))
358
+
359
+ train_inds, valid_inds = split_inds[flags.target_fold] # 0th fold
360
+
361
+ train_dataset = VinbigdataTwoClassDataset(
362
+
363
+ [dataset_dicts[i] for i in train_inds],
364
+
365
+ image_transform=Transform(flags.aug_kwargs),
366
+
367
+ mixup_prob=flags.mixup_prob,
368
+
369
+ label_smoothing=flags.label_smoothing,
370
+
371
+ )
372
+
373
+ valid_dataset = VinbigdataTwoClassDataset([dataset_dicts[i] for i in valid_inds])
374
+
375
+
376
+
377
+ train_loader = DataLoader(
378
+
379
+ train_dataset,
380
+
381
+ batch_size=flags.batchsize,
382
+
383
+ num_workers=flags.num_workers,
384
+
385
+ shuffle=True,
386
+
387
+ pin_memory=True,
388
+
389
+ )
390
+
391
+ valid_loader = DataLoader(
392
+
393
+ valid_dataset,
394
+
395
+ batch_size=flags.valid_batchsize,
396
+
397
+ num_workers=flags.num_workers,
398
+
399
+ shuffle=False,
400
+
401
+ pin_memory=True,
402
+
403
+ )
404
+
405
+
406
+
407
+ device = torch.device(flags.device)
408
+
409
+
410
+
411
+ predictor = build_predictor(model_name=flags.model_name, model_mode=flags.model_mode)
412
+
413
+ classifier = Classifier(predictor)
414
+
415
+ model = classifier
416
+
417
+ optimizer = optim.Adam([param for param in model.parameters() if param.requires_grad], lr=1e-3)
418
+
419
+
420
+
421
+ # Train setup
422
+
423
+ trainer = create_trainer(model, optimizer, device)
424
+
425
+
426
+
427
+ ema = EMA(predictor, decay=flags.ema_decay)
428
+
429
+
430
+
431
+ def eval_func(*batch):
432
+
433
+ loss, metrics = model(*[elem.to(device) for elem in batch])
434
+
435
+ # HACKING: report ema value with prefix.
436
+
437
+ if flags.ema_decay > 0:
438
+
439
+ classifier.prefix = "ema_"
440
+
441
+ ema.assign()
442
+
443
+ loss, metrics = model(*[elem.to(device) for elem in batch])
444
+
445
+ ema.resume()
446
+
447
+ classifier.prefix = ""
448
+
449
+
450
+
451
+ valid_evaluator = E.Evaluator(
452
+
453
+ valid_loader, model, progress_bar=False, eval_func=eval_func, device=device
454
+
455
+ )
456
+
457
+
458
+
459
+ # log_trigger = (10 if debug else 1000, "iteration")
460
+
461
+ log_trigger = (1, "epoch")
462
+
463
+ log_report = E.LogReport(trigger=log_trigger)
464
+
465
+ extensions = [
466
+
467
+ log_report,
468
+
469
+ E.ProgressBarNotebook(update_interval=10 if debug else 100), # Show progress bar during training
470
+
471
+ E.PrintReportNotebook(),
472
+
473
+ E.FailOnNonNumber(),
474
+
475
+ ]
476
+
477
+ epoch = flags.epoch
478
+
479
+ models = {"main": model}
480
+
481
+ optimizers = {"main": optimizer}
482
+
483
+ manager = IgniteExtensionsManager(
484
+
485
+ trainer, models, optimizers, epoch, extensions=extensions, out_dir=str(outdir),
486
+
487
+ )
488
+
489
+ # Run evaluation for valid dataset in each epoch.
490
+
491
+ manager.extend(valid_evaluator)
492
+
493
+
494
+
495
+ # Save predictor.pt every epoch
496
+
497
+ manager.extend(
498
+
499
+ E.snapshot_object(predictor, "predictor.pt"), trigger=(flags.snapshot_freq, "epoch")
500
+
501
+ )
502
+
503
+
504
+
505
+ manager.extend(E.observe_lr(optimizer=optimizer), trigger=log_trigger)
506
+
507
+
508
+
509
+ if flags.ema_decay > 0:
510
+
511
+ # Exponential moving average
512
+
513
+ manager.extend(lambda manager: ema(), trigger=(1, "iteration"))
514
+
515
+
516
+
517
+ def save_ema_model(manager):
518
+
519
+ ema.assign()
520
+
521
+ #torch.save(predictor.state_dict(), outdir / "predictor_ema.pt")
522
+
523
+ torch.save(predictor.state_dict(), "/kaggle/working/predictor_ema.pt")
524
+
525
+ ema.resume()
526
+
527
+
528
+
529
+ manager.extend(save_ema_model, trigger=(flags.snapshot_freq, "epoch"))
530
+
531
+
532
+
533
+ _ = trainer.run(train_loader, max_epochs=epoch)
534
+
535
+
536
+
537
+ torch.save(predictor.state_dict(), "/kaggle/working/predictor_tf_inception_v3.pt")
538
+
539
+ df = log_report.to_dataframe()
540
+
541
+ df.to_csv("/kaggle/working/log.csv", index=False)
542
+
543
+
544
+
545
+ ####################################追記終わり#####################################################
546
+
547
+
548
+
549
+ #モデルの読み込み
50
550
 
51
551
  model_path = '../input/detectron2-model/predictor_tf_inception_v3.pt'
52
552
 
53
553
  model = model.load_state_dict(torch.load(model_path))
54
554
 
55
555
  ```
56
-
57
-
58
-
59
- ### 補足情報(FW/ツールのバージョンなど)
60
-
61
-
62
-
63
- Pytorch: 1.7.1
64
-
65
- cuda: 10.1
66
-
67
- kaggleのNotebookを使用しております