Question edit history

Revision 2: Partial code fix
@@ -1,20 +1,26 @@

The introductory text ("### Background / What I want to achieve": "I would like to load a model saved with PyTorch, but an error occurred. If anyone knows the cause, I would appreciate your advice.") was removed in this revision, and the error message under "### Problem / Error message" was replaced with the following traceback:

```
----> 5 model = classifier(flags.aug_kwargs)
      6 model.load_state_dict(torch.load(model_path))

    725             result = self._slow_forward(*input, **kwargs)
    726         else:
--> 727             result = self.forward(*input, **kwargs)
    728         for hook in itertools.chain(
    729                 _global_forward_hooks.values(),

TypeError: forward() missing 1 required positional argument: 'targets'
```
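For context on the new traceback: in the question's code, `Classifier.forward(self, image, targets)` takes two required arguments, and `nn.Module.__call__` dispatches to `forward`, so `classifier(flags.aug_kwargs)` supplies only one of them. A minimal, self-contained sketch (with a stand-in predictor) that reproduces the same error:

```Python
import torch
import torch.nn as nn


class Classifier(nn.Module):
    """Stand-in for the question's Classifier: forward() needs image AND targets."""

    def __init__(self, predictor):
        super().__init__()
        self.predictor = predictor

    def forward(self, image, targets):
        return self.predictor(image)


clf = Classifier(nn.Linear(4, 2))
try:
    clf(torch.zeros(1, 4))  # __call__ forwards to forward(); 'targets' is not supplied
except TypeError as e:
    print(e)  # forward() missing 1 required positional argument: 'targets'
```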
@@ -548,8 +554,10 @@

At the end of the question, the model-loading code was changed from

```Python
# Load the model
model_path = '../input/detectron2-model/predictor_tf_inception_v3.pt'

model = model.load_state_dict(torch.load(model_path))
```

to

```Python
# Load the model
model_path = './predictor_last.pt'

model = classifier(flags.aug_kwargs)

model.load_state_dict(torch.load(model_path))
```
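The training code in revision 1 below saves weights with `torch.save(predictor.state_dict(), ...)`, so restoring them means rebuilding a module of the same architecture and loading the dict into it. A self-contained sketch of that round trip with a stand-in `nn.Linear` (the question's code would rebuild via `build_predictor(...)`); note that `load_state_dict` returns a key-match report rather than the module, so assigning its result back to `model`, as the question's original line did, discards the model:

```Python
import torch
import torch.nn as nn

predictor = nn.Linear(4, 2)  # stand-in for the question's predictor
torch.save(predictor.state_dict(), "./predictor_last.pt")  # mirrors the training code's save

restored = nn.Linear(4, 2)  # rebuild the same architecture first
result = restored.load_state_dict(torch.load("./predictor_last.pt"))
print(result)  # <All keys matched successfully>: a report, not the model itself
```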
Revision 1: Code added
@@ -1,10 +1,6 @@

Under "### Background / What I want to achieve", the opening text

"I would like to load a model saved with PyTorch. Referring to the GPU loading procedure described [here](https://tzmi.hatenablog.com/entry/2020/03/05/222813), I wrote some code, but an error occurred."

was condensed to

"I would like to load a model saved with PyTorch, but an error occurred."

The closing line, "If anyone knows the cause, I would appreciate your advice.", was kept.
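The article linked above concerns loading weights saved on a GPU; the relevant PyTorch mechanism is presumably `torch.load`'s `map_location` argument, which remaps tensor storages to a chosen device at load time. A minimal sketch with placeholder names:

```Python
import torch
import torch.nn as nn

model = nn.Linear(4, 2)  # placeholder architecture
torch.save(model.state_dict(), "model.pt")  # "model.pt" is a hypothetical file name

# Load onto CPU regardless of the device the checkpoint was saved from,
# then move the module to GPU if one is available.
state_dict = torch.load("model.pt", map_location=torch.device("cpu"))
model.load_state_dict(state_dict)
if torch.cuda.is_available():
    model.to(torch.device("cuda:0"))
```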
@@ -12,18 +8,8 @@

Under "### Problem / Error message", the header of the old traceback was removed:

```
AttributeError                            Traceback (most recent call last)
<ipython-input-15-64589b233dc6> in <module>
      1 # How to load the model
      2 model_path = '../input/detectron2-model/predictor_tf_inception_v3.pt'
```

leaving only the failing line:

```
----> 3 model = model.load_state_dict(torch.load(model_path))
```
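The tail of that AttributeError is not preserved in the history, but one common way for this exact line to fail is when `model` is itself the object returned by `torch.load` (an `OrderedDict` of tensors), which has no `load_state_dict` method. A sketch that reproduces such an error, with a hypothetical file name:

```Python
import torch
import torch.nn as nn

torch.save(nn.Linear(4, 2).state_dict(), "weights.pt")  # "weights.pt" is hypothetical

model = torch.load("weights.pt")  # an OrderedDict of tensors, not an nn.Module
try:
    model = model.load_state_dict(torch.load("weights.pt"))
except AttributeError as e:
    print(e)  # 'collections.OrderedDict' object has no attribute 'load_state_dict'
```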
@@ -36,32 +22,534 @@

The complete training code was added under "### Relevant source code". Its final lines are the model-loading code that revision 2 later modified (see above).

```Python
# Imports restored for completeness; the original excerpt omitted them.
import dataclasses
import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Union

import numpy as np
import pandas as pd
import pytorch_pfn_extras as ppe
import timm
import torch
from ignite.engine import Engine
from pytorch_pfn_extras.training import IgniteExtensionsManager
from pytorch_pfn_extras.training import extensions as E
from sklearn.model_selection import StratifiedKFold
from torch import nn, optim
from torch.nn import Linear
from torch.utils.data import DataLoader

# Helpers defined in earlier notebook cells and not shown in the question:
# save_yaml, dataset_dicts, VinbigdataTwoClassDataset, Transform,
# cross_entropy_with_logits, accuracy_with_logits, EMA.


@dataclass
class Flags:
    # General
    debug: bool = True
    outdir: str = "results/det"
    device: str = "cuda:0"

    # Data config
    imgdir_name: str = "vinbigdata-chest-xray-resized-png-256x256"
    # split_mode: str = "all_train"
    seed: int = 111
    target_fold: int = 0  # 0~4
    label_smoothing: float = 0.0
    # Model config
    model_name: str = "tf_inception_v3"
    model_mode: str = "normal"  # normal, cnn_fixed supported
    # Training config
    epoch: int = 20
    batchsize: int = 8
    valid_batchsize: int = 16
    num_workers: int = 4
    snapshot_freq: int = 5
    ema_decay: float = 0.999  # negative value is to inactivate ema.
    scheduler_type: str = ""
    scheduler_kwargs: Dict[str, Any] = field(default_factory=lambda: {})
    scheduler_trigger: List[Union[int, str]] = field(default_factory=lambda: [1, "iteration"])
    aug_kwargs: Dict[str, Dict[str, Any]] = field(default_factory=lambda: {})
    mixup_prob: float = -1.0  # Apply mixup augmentation when positive value is set.

    def update(self, param_dict: Dict) -> "Flags":
        # Overwrite by `param_dict`
        for key, value in param_dict.items():
            if not hasattr(self, key):
                raise ValueError(f"[ERROR] Unexpected key for flag = {key}")
            setattr(self, key, value)
        return self


flags_dict = {
    "debug": False,  # Change to True for fast debug run!
    "outdir": "results/tmp_debug",
    # Data
    "imgdir_name": "vinbigdata-chest-xray-resized-png-256x256",
    # Model
    "model_name": "tf_inception_v3",  # "resnet18",
    # Training
    "num_workers": 4,
    "epoch": 15,
    "batchsize": 8,
    "scheduler_type": "CosineAnnealingWarmRestarts",
    "scheduler_kwargs": {"T_0": 28125},  # 15000 * 15 epoch // (batchsize=8)
    "scheduler_trigger": [1, "iteration"],
    "aug_kwargs": {
        "HorizontalFlip": {"p": 0.5},
        "ShiftScaleRotate": {"scale_limit": 0.15, "rotate_limit": 10, "p": 0.5},
        "RandomBrightnessContrast": {"p": 0.5},
        "CoarseDropout": {"max_holes": 8, "max_height": 25, "max_width": 25, "p": 0.5},
        "Blur": {"blur_limit": [3, 7], "p": 0.5},
        "Downscale": {"scale_min": 0.25, "scale_max": 0.9, "p": 0.3},
        "RandomGamma": {"gamma_limit": [80, 120], "p": 0.6},
    },
}

# args = parse()
print("torch", torch.__version__)
flags = Flags().update(flags_dict)
print("flags", flags)
debug = flags.debug
outdir = Path(flags.outdir)
os.makedirs(str(outdir), exist_ok=True)
flags_dict = dataclasses.asdict(flags)
save_yaml(str(outdir / "flags.yaml"), flags_dict)

# --- Read data ---
inputdir = Path("/kaggle/input")
datadir = inputdir / "vinbigdata-chest-xray-abnormalities-detection"
imgdir = inputdir / flags.imgdir_name

# Read in the data CSV files
train = pd.read_csv(datadir / "train.csv")


class CNNFixedPredictor(nn.Module):
    def __init__(self, cnn: nn.Module, num_classes: int = 2):
        super(CNNFixedPredictor, self).__init__()
        self.cnn = cnn
        self.lin = Linear(cnn.num_features, num_classes)
        print("cnn.num_features", cnn.num_features)

        # We do not learn CNN parameters.
        # https://pytorch.org/tutorials/beginner/finetuning_torchvision_models_tutorial.html
        for param in self.cnn.parameters():
            param.requires_grad = False

    def forward(self, x):
        feat = self.cnn(x)
        return self.lin(feat)


def build_predictor(model_name: str, model_mode: str = "normal"):
    if model_mode == "normal":
        # normal configuration. train all parameters.
        return timm.create_model(model_name, pretrained=True, num_classes=2, in_chans=3)
    elif model_mode == "cnn_fixed":
        # fixed-CNN configuration. train only the final linear layer.
        # https://rwightman.github.io/pytorch-image-models/feature_extraction/
        timm_model = timm.create_model(model_name, pretrained=True, num_classes=0, in_chans=3)
        return CNNFixedPredictor(timm_model, num_classes=2)
    else:
        raise ValueError(f"[ERROR] Unexpected value model_mode={model_mode}")


class Classifier(nn.Module):
    """Two-class classification."""

    def __init__(self, predictor, lossfun=cross_entropy_with_logits):
        super().__init__()
        self.predictor = predictor
        self.lossfun = lossfun
        self.prefix = ""

    def forward(self, image, targets):
        outputs = self.predictor(image)
        loss = self.lossfun(outputs, targets)
        metrics = {
            f"{self.prefix}loss": loss.item(),
            f"{self.prefix}acc": accuracy_with_logits(outputs, targets).item(),
        }
        ppe.reporting.report(metrics, self)
        return loss, metrics

    def predict(self, data_loader):
        pred = self.predict_proba(data_loader)
        label = torch.argmax(pred, dim=1)
        return label

    def predict_proba(self, data_loader):
        device: torch.device = next(self.parameters()).device
        y_list = []
        self.eval()
        with torch.no_grad():
            for batch in data_loader:
                if isinstance(batch, (tuple, list)):
                    # Assumes first argument is "image"
                    batch = batch[0].to(device)
                else:
                    batch = batch.to(device)
                y = self.predictor(batch)
                y = torch.softmax(y, dim=-1)
                y_list.append(y)
        pred = torch.cat(y_list)
        return pred


def create_trainer(model, optimizer, device) -> Engine:
    model.to(device)

    def update_fn(engine, batch):
        model.train()
        optimizer.zero_grad()
        loss, metrics = model(*[elem.to(device) for elem in batch])
        loss.backward()
        optimizer.step()
        return metrics

    trainer = Engine(update_fn)
    return trainer


skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=flags.seed)
# skf.get_n_splits(None, None)
y = np.array([int(len(d["annotations"]) > 0) for d in dataset_dicts])
split_inds = list(skf.split(dataset_dicts, y))
train_inds, valid_inds = split_inds[flags.target_fold]  # 0th fold
train_dataset = VinbigdataTwoClassDataset(
    [dataset_dicts[i] for i in train_inds],
    image_transform=Transform(flags.aug_kwargs),
    mixup_prob=flags.mixup_prob,
    label_smoothing=flags.label_smoothing,
)
valid_dataset = VinbigdataTwoClassDataset([dataset_dicts[i] for i in valid_inds])

train_loader = DataLoader(
    train_dataset,
    batch_size=flags.batchsize,
    num_workers=flags.num_workers,
    shuffle=True,
    pin_memory=True,
)
valid_loader = DataLoader(
    valid_dataset,
    batch_size=flags.valid_batchsize,
    num_workers=flags.num_workers,
    shuffle=False,
    pin_memory=True,
)

device = torch.device(flags.device)

predictor = build_predictor(model_name=flags.model_name, model_mode=flags.model_mode)
classifier = Classifier(predictor)
model = classifier
optimizer = optim.Adam([param for param in model.parameters() if param.requires_grad], lr=1e-3)

# Train setup
trainer = create_trainer(model, optimizer, device)

ema = EMA(predictor, decay=flags.ema_decay)


def eval_func(*batch):
    loss, metrics = model(*[elem.to(device) for elem in batch])
    # HACKING: report ema value with prefix.
    if flags.ema_decay > 0:
        classifier.prefix = "ema_"
        ema.assign()
        loss, metrics = model(*[elem.to(device) for elem in batch])
        ema.resume()
        classifier.prefix = ""


valid_evaluator = E.Evaluator(
    valid_loader, model, progress_bar=False, eval_func=eval_func, device=device
)

# log_trigger = (10 if debug else 1000, "iteration")
log_trigger = (1, "epoch")
log_report = E.LogReport(trigger=log_trigger)
extensions = [
    log_report,
    E.ProgressBarNotebook(update_interval=10 if debug else 100),  # Show progress bar during training
    E.PrintReportNotebook(),
    E.FailOnNonNumber(),
]
epoch = flags.epoch
models = {"main": model}
optimizers = {"main": optimizer}
manager = IgniteExtensionsManager(
    trainer, models, optimizers, epoch, extensions=extensions, out_dir=str(outdir),
)
# Run evaluation for valid dataset in each epoch.
manager.extend(valid_evaluator)

# Save predictor.pt every epoch
manager.extend(
    E.snapshot_object(predictor, "predictor.pt"), trigger=(flags.snapshot_freq, "epoch")
)

manager.extend(E.observe_lr(optimizer=optimizer), trigger=log_trigger)

if flags.ema_decay > 0:
    # Exponential moving average
    manager.extend(lambda manager: ema(), trigger=(1, "iteration"))

    def save_ema_model(manager):
        ema.assign()
        # torch.save(predictor.state_dict(), outdir / "predictor_ema.pt")
        torch.save(predictor.state_dict(), "/kaggle/working/predictor_ema.pt")
        ema.resume()

    manager.extend(save_ema_model, trigger=(flags.snapshot_freq, "epoch"))

_ = trainer.run(train_loader, max_epochs=epoch)

torch.save(predictor.state_dict(), "/kaggle/working/predictor_tf_inception_v3.pt")
df = log_report.to_dataframe()
df.to_csv("/kaggle/working/log.csv", index=False)

#################################### End of added code ####################################

# Load the model
model_path = '../input/detectron2-model/predictor_tf_inception_v3.pt'

model = model.load_state_dict(torch.load(model_path))
```
This revision also removed the supplementary-information section ("### Supplementary information (FW/tool versions, etc.)"), which listed:

PyTorch: 1.7.1
CUDA: 10.1
I am using a Kaggle Notebook.