teratail header banner
teratail header banner
質問するログイン新規登録

質問編集履歴

2

コードの一部修正

2021/03/01 12:33

投稿

SuzuAya
SuzuAya

スコア71

title CHANGED
File without changes
body CHANGED
@@ -1,12 +1,15 @@
1
- ### 前提・実現したいこと
2
- Pytorchで保存したモデルを読み込みたいのですが、エラーが発生してしまいました。
3
- 原因についてお分かりの方がいらっしゃいましたら、ご教示いただけますと幸いです。
4
-
5
1
  ### 発生している問題・エラーメッセージ
6
2
  ```
3
+ ----> 5 model = classifier(flags.aug_kwargs)
7
- ----> 3 model = model.load_state_dict(torch.load(model_path))
4
+ 6 model.load_state_dict(torch.load(model_path))
8
5
 
6
+ 725 result = self._slow_forward(*input, **kwargs)
7
+ 726 else:
8
+ --> 727 result = self.forward(*input, **kwargs)
9
+ 728 for hook in itertools.chain(
10
+ 729 _global_forward_hooks.values(),
11
+
9
- AttributeError: 'str' object has no attribute 'load_state_dict'
12
+ TypeError: forward() missing 1 required positional argument: 'targets'
10
13
  ```
11
14
 
12
15
  ### 該当のソースコード
@@ -273,6 +276,7 @@
273
276
  ####################################追記終わり#####################################################
274
277
 
275
278
  #モデルの読み込み
276
- model_path = '../input/detectron2-model/predictor_tf_inception_v3.pt'
279
+ model_path = './predictor_last.pt'
280
+ model = classifier(flags.aug_kwargs)
277
- model = model.load_state_dict(torch.load(model_path))
281
+ model.load_state_dict(torch.load(model_path))
278
282
  ```

1

コードの追記

2021/03/01 12:33

投稿

SuzuAya
SuzuAya

スコア71

title CHANGED
File without changes
body CHANGED
@@ -1,34 +1,278 @@
1
1
  ### 前提・実現したいこと
2
- Pytorchで保存したモデルを読み込みたいと思っていま
2
+ Pytorchで保存したモデルを読み込みたいのですが、エラーが発生してしまいました。
3
- [こちら](https://tzmi.hatenablog.com/entry/2020/03/05/222813)のGPUでの読み出し方法を参考に
4
- コードを書いてみたのですが、エラーが発生してしまいました。
5
3
  原因についてお分かりの方がいらっしゃいましたら、ご教示いただけますと幸いです。
6
4
 
7
5
  ### 発生している問題・エラーメッセージ
8
-
9
6
  ```
10
- AttributeError Traceback (most recent call last)
11
- <ipython-input-15-64589b233dc6> in <module>
12
- 1 #モデルの読み込み方法
13
- 2 model_path = '../input/detectron2-model/predictor_tf_inception_v3.pt'
14
7
  ----> 3 model = model.load_state_dict(torch.load(model_path))
15
8
 
16
9
  AttributeError: 'str' object has no attribute 'load_state_dict'
17
10
  ```
18
11
 
19
12
  ### 該当のソースコード
20
-
21
13
  ```Python
22
- import torch
14
+ @dataclass
15
+ class Flags:
16
+ # General
17
+ debug: bool = True
18
+ outdir: str = "results/det"
23
- from torchvision import models
19
+ device: str = "cuda:0"
24
20
 
21
+ # Data config
22
+ imgdir_name: str = "vinbigdata-chest-xray-resized-png-256x256"
23
+ # split_mode: str = "all_train"
24
+ seed: int = 111
25
+ target_fold: int = 0 # 0~4
26
+ label_smoothing: float = 0.0
27
+ # Model config
28
+ model_name: str = "tf_inception_v3"
29
+ model_mode: str = "normal" # normal, cnn_fixed supported
30
+ # Training config
31
+ epoch: int = 20
32
+ batchsize: int = 8
33
+ valid_batchsize: int = 16
34
+ num_workers: int = 4
35
+ snapshot_freq: int = 5
36
+ ema_decay: float = 0.999 # negative value is to inactivate ema.
37
+ scheduler_type: str = ""
38
+ scheduler_kwargs: Dict[str, Any] = field(default_factory=lambda: {})
39
+ scheduler_trigger: List[Union[int, str]] = field(default_factory=lambda: [1, "iteration"])
40
+ aug_kwargs: Dict[str, Dict[str, Any]] = field(default_factory=lambda: {})
41
+ mixup_prob: float = -1.0 # Apply mixup augmentation when positive value is set.
42
+
43
+ def update(self, param_dict: Dict) -> "Flags":
44
+ # Overwrite by `param_dict`
45
+ for key, value in param_dict.items():
46
+ if not hasattr(self, key):
47
+ raise ValueError(f"[ERROR] Unexpected key for flag = {key}")
48
+ setattr(self, key, value)
49
+ return self
50
+
51
+ flags_dict = {
52
+ "debug": False, # Change to True for fast debug run!
53
+ "outdir": "results/tmp_debug",
54
+ # Data
55
+ "imgdir_name": "vinbigdata-chest-xray-resized-png-256x256",
56
+ # Model
57
+ "model_name": "tf_inception_v3",#"resnet18",
58
+ # Training
59
+ "num_workers": 4,
60
+ "epoch": 15,
61
+ "batchsize": 8,
62
+ "scheduler_type": "CosineAnnealingWarmRestarts",
63
+ "scheduler_kwargs": {"T_0": 28125}, # 15000 * 15 epoch // (batchsize=8)
64
+ "scheduler_trigger": [1, "iteration"],
65
+ "aug_kwargs": {
66
+ "HorizontalFlip": {"p": 0.5},
67
+ "ShiftScaleRotate": {"scale_limit": 0.15, "rotate_limit": 10, "p": 0.5},
68
+ "RandomBrightnessContrast": {"p": 0.5},
69
+ "CoarseDropout": {"max_holes": 8, "max_height": 25, "max_width": 25, "p": 0.5},
70
+ "Blur": {"blur_limit": [3, 7], "p": 0.5},
71
+ "Downscale": {"scale_min": 0.25, "scale_max": 0.9, "p": 0.3},
72
+ "RandomGamma": {"gamma_limit": [80, 120], "p": 0.6},
73
+ }
74
+ }
75
+
76
+ # args = parse()
77
+ print("torch", torch.__version__)
78
+ flags = Flags().update(flags_dict)
79
+ print("flags", flags)
80
+ debug = flags.debug
81
+ outdir = Path(flags.outdir)
82
+ os.makedirs(str(outdir), exist_ok=True)
83
+ flags_dict = dataclasses.asdict(flags)
84
+ save_yaml(str(outdir / "flags.yaml"), flags_dict)
85
+
86
+ # --- Read data ---
87
+ inputdir = Path("/kaggle/input")
88
+ datadir = inputdir / "vinbigdata-chest-xray-abnormalities-detection"
89
+ imgdir = inputdir / flags.imgdir_name
90
+
91
+ # Read in the data CSV files
92
+ train = pd.read_csv(datadir / "train.csv")
93
+
94
+ class CNNFixedPredictor(nn.Module):
95
+ def __init__(self, cnn: nn.Module, num_classes: int = 2):
96
+ super(CNNFixedPredictor, self).__init__()
97
+ self.cnn = cnn
98
+ self.lin = Linear(cnn.num_features, num_classes)
99
+ print("cnn.num_features", cnn.num_features)
100
+
101
+ # We do not learn CNN parameters.
102
+ # https://pytorch.org/tutorials/beginner/finetuning_torchvision_models_tutorial.html
103
+ for param in self.cnn.parameters():
104
+ param.requires_grad = False
105
+
106
+ def forward(self, x):
107
+ feat = self.cnn(x)
108
+ return self.lin(feat)
109
+
110
+ def build_predictor(model_name: str, model_mode: str = "normal"):
111
+ if model_mode == "normal":
112
+ # normal configuration. train all parameters.
113
+ return timm.create_model(model_name, pretrained=True, num_classes=2, in_chans=3)
114
+ elif model_mode == "cnn_fixed":
115
+ # normal configuration. train all parameters.
116
+ # https://rwightman.github.io/pytorch-image-models/feature_extraction/
117
+ timm_model = timm.create_model(model_name, pretrained=True, num_classes=0, in_chans=3)
118
+ return CNNFixedPredictor(timm_model, num_classes=2)
119
+ else:
120
+ raise ValueError(f"[ERROR] Unexpected value model_mode={model_mode}")
121
+
122
+ class Classifier(nn.Module):
123
+ """two class classfication"""
124
+
125
+ def __init__(self, predictor, lossfun=cross_entropy_with_logits):
126
+ super().__init__()
127
+ self.predictor = predictor
128
+ self.lossfun = lossfun
129
+ self.prefix = ""
130
+
131
+ def forward(self, image, targets):
132
+ outputs = self.predictor(image)
133
+ loss = self.lossfun(outputs, targets)
134
+ metrics = {
135
+ f"{self.prefix}loss": loss.item(),
136
+ f"{self.prefix}acc": accuracy_with_logits(outputs, targets).item()
137
+ }
138
+ ppe.reporting.report(metrics, self)
139
+ return loss, metrics
140
+
141
+ def predict(self, data_loader):
142
+ pred = self.predict_proba(data_loader)
143
+ label = torch.argmax(pred, dim=1)
144
+ return label
145
+
146
+ def predict_proba(self, data_loader):
147
+ device: torch.device = next(self.parameters()).device
148
+ y_list = []
149
+ self.eval()
150
+ with torch.no_grad():
151
+ for batch in data_loader:
152
+ if isinstance(batch, (tuple, list)):
153
+ # Assumes first argument is "image"
154
+ batch = batch[0].to(device)
155
+ else:
156
+ batch = batch.to(device)
157
+ y = self.predictor(batch)
158
+ y = torch.softmax(y, dim=-1)
159
+ y_list.append(y)
160
+ pred = torch.cat(y_list)
161
+ return pred
162
+
163
+ def create_trainer(model, optimizer, device) -> Engine:
164
+ model.to(device)
165
+
166
+ def update_fn(engine, batch):
167
+ model.train()
168
+ optimizer.zero_grad()
169
+ loss, metrics = model(*[elem.to(device) for elem in batch])
170
+ loss.backward()
171
+ optimizer.step()
172
+ return metrics
173
+ trainer = Engine(update_fn)
174
+ return trainer
175
+
176
+ skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=flags.seed)
177
+ # skf.get_n_splits(None, None)
178
+ y = np.array([int(len(d["annotations"]) > 0) for d in dataset_dicts])
179
+ split_inds = list(skf.split(dataset_dicts, y))
180
+ train_inds, valid_inds = split_inds[flags.target_fold] # 0th fold
181
+ train_dataset = VinbigdataTwoClassDataset(
182
+ [dataset_dicts[i] for i in train_inds],
183
+ image_transform=Transform(flags.aug_kwargs),
184
+ mixup_prob=flags.mixup_prob,
185
+ label_smoothing=flags.label_smoothing,
186
+ )
187
+ valid_dataset = VinbigdataTwoClassDataset([dataset_dicts[i] for i in valid_inds])
188
+
189
+ train_loader = DataLoader(
190
+ train_dataset,
191
+ batch_size=flags.batchsize,
192
+ num_workers=flags.num_workers,
193
+ shuffle=True,
194
+ pin_memory=True,
195
+ )
196
+ valid_loader = DataLoader(
197
+ valid_dataset,
198
+ batch_size=flags.valid_batchsize,
199
+ num_workers=flags.num_workers,
200
+ shuffle=False,
201
+ pin_memory=True,
202
+ )
203
+
204
+ device = torch.device(flags.device)
205
+
206
+ predictor = build_predictor(model_name=flags.model_name, model_mode=flags.model_mode)
207
+ classifier = Classifier(predictor)
208
+ model = classifier
209
+ optimizer = optim.Adam([param for param in model.parameters() if param.requires_grad], lr=1e-3)
210
+
211
+ # Train setup
212
+ trainer = create_trainer(model, optimizer, device)
213
+
214
+ ema = EMA(predictor, decay=flags.ema_decay)
215
+
216
+ def eval_func(*batch):
217
+ loss, metrics = model(*[elem.to(device) for elem in batch])
218
+ # HACKING: report ema value with prefix.
219
+ if flags.ema_decay > 0:
220
+ classifier.prefix = "ema_"
221
+ ema.assign()
222
+ loss, metrics = model(*[elem.to(device) for elem in batch])
223
+ ema.resume()
224
+ classifier.prefix = ""
225
+
226
+ valid_evaluator = E.Evaluator(
227
+ valid_loader, model, progress_bar=False, eval_func=eval_func, device=device
228
+ )
229
+
230
+ # log_trigger = (10 if debug else 1000, "iteration")
231
+ log_trigger = (1, "epoch")
232
+ log_report = E.LogReport(trigger=log_trigger)
233
+ extensions = [
234
+ log_report,
235
+ E.ProgressBarNotebook(update_interval=10 if debug else 100), # Show progress bar during training
236
+ E.PrintReportNotebook(),
237
+ E.FailOnNonNumber(),
238
+ ]
239
+ epoch = flags.epoch
240
+ models = {"main": model}
241
+ optimizers = {"main": optimizer}
242
+ manager = IgniteExtensionsManager(
243
+ trainer, models, optimizers, epoch, extensions=extensions, out_dir=str(outdir),
244
+ )
245
+ # Run evaluation for valid dataset in each epoch.
246
+ manager.extend(valid_evaluator)
247
+
248
+ # Save predictor.pt every epoch
249
+ manager.extend(
250
+ E.snapshot_object(predictor, "predictor.pt"), trigger=(flags.snapshot_freq, "epoch")
251
+ )
252
+
253
+ manager.extend(E.observe_lr(optimizer=optimizer), trigger=log_trigger)
254
+
255
+ if flags.ema_decay > 0:
256
+ # Exponential moving average
257
+ manager.extend(lambda manager: ema(), trigger=(1, "iteration"))
258
+
259
+ def save_ema_model(manager):
260
+ ema.assign()
261
+ #torch.save(predictor.state_dict(), outdir / "predictor_ema.pt")
262
+ torch.save(predictor.state_dict(), "/kaggle/working/predictor_ema.pt")
263
+ ema.resume()
264
+
265
+ manager.extend(save_ema_model, trigger=(flags.snapshot_freq, "epoch"))
266
+
267
+ _ = trainer.run(train_loader, max_epochs=epoch)
268
+
269
+ torch.save(predictor.state_dict(), "/kaggle/working/predictor_tf_inception_v3.pt")
270
+ df = log_report.to_dataframe()
271
+ df.to_csv("/kaggle/working/log.csv", index=False)
272
+
273
+ ####################################追記終わり#####################################################
274
+
25
- #モデルの読み込み方法
275
+ #モデルの読み込み
26
276
  model_path = '../input/detectron2-model/predictor_tf_inception_v3.pt'
27
277
  model = model.load_state_dict(torch.load(model_path))
28
- ```
278
+ ```
29
-
30
- ### 補足情報(FW/ツールのバージョンなど)
31
-
32
- Pytorch: 1.7.1
33
- cuda: 10.1
34
- kaggleのNotebookを使用しております