🎄teratailクリスマスプレゼントキャンペーン2024🎄』開催中!

\teratail特別グッズやAmazonギフトカード最大2,000円分が当たる!/

詳細はこちら
PyTorch

PyTorchは、オープンソースのPython向けの機械学習ライブラリ。Facebookの人工知能研究グループが開発を主導しています。強力なGPUサポートを備えたテンソル計算、テープベースの自動微分による柔軟なニューラルネットワークの記述が可能です。

Python 3.x

Python 3はPythonプログラミング言語の最新バージョンであり、2008年12月3日にリリースされました。

Python

Pythonは、コードの読みやすさが特徴的なプログラミング言語の1つです。 強い型付け、動的型付けに対応しており、後方互換性がないバージョン2系とバージョン3系が使用されています。 商用製品の開発にも無料で使用でき、OSだけでなく仮想環境にも対応。Unicodeによる文字列操作をサポートしているため、日本語処理も標準で可能です。

Q&A

解決済

2回答

7995閲覧

Pytorchで保存したモデルを読み込みたい

SuzuAya

総合スコア71

PyTorch

PyTorchは、オープンソースのPython向けの機械学習ライブラリ。Facebookの人工知能研究グループが開発を主導しています。強力なGPUサポートを備えたテンソル計算、テープベースの自動微分による柔軟なニューラルネットワークの記述が可能です。

Python 3.x

Python 3はPythonプログラミング言語の最新バージョンであり、2008年12月3日にリリースされました。

Python

Pythonは、コードの読みやすさが特徴的なプログラミング言語の1つです。 強い型付け、動的型付けに対応しており、後方互換性がないバージョン2系とバージョン3系が使用されています。 商用製品の開発にも無料で使用でき、OSだけでなく仮想環境にも対応。Unicodeによる文字列操作をサポートしているため、日本語処理も標準で可能です。

0グッド

0クリップ

投稿2021/02/28 11:52

編集2021/03/01 12:33

発生している問題・エラーメッセージ

----> 5 model = classifier(flags.aug_kwargs) 6 model.load_state_dict(torch.load(model_path)) 725 result = self._slow_forward(*input, **kwargs) 726 else: --> 727 result = self.forward(*input, **kwargs) 728 for hook in itertools.chain( 729 _global_forward_hooks.values(), TypeError: forward() missing 1 required positional argument: 'targets'

該当のソースコード

Python

1@dataclass 2class Flags: 3 # General 4 debug: bool = True 5 outdir: str = "results/det" 6 device: str = "cuda:0" 7 8 # Data config 9 imgdir_name: str = "vinbigdata-chest-xray-resized-png-256x256" 10 # split_mode: str = "all_train" 11 seed: int = 111 12 target_fold: int = 0 # 0~4 13 label_smoothing: float = 0.0 14 # Model config 15 model_name: str = "tf_inception_v3" 16 model_mode: str = "normal" # normal, cnn_fixed supported 17 # Training config 18 epoch: int = 20 19 batchsize: int = 8 20 valid_batchsize: int = 16 21 num_workers: int = 4 22 snapshot_freq: int = 5 23 ema_decay: float = 0.999 # negative value is to inactivate ema. 24 scheduler_type: str = "" 25 scheduler_kwargs: Dict[str, Any] = field(default_factory=lambda: {}) 26 scheduler_trigger: List[Union[int, str]] = field(default_factory=lambda: [1, "iteration"]) 27 aug_kwargs: Dict[str, Dict[str, Any]] = field(default_factory=lambda: {}) 28 mixup_prob: float = -1.0 # Apply mixup augmentation when positive value is set. 29 30 def update(self, param_dict: Dict) -> "Flags": 31 # Overwrite by `param_dict` 32 for key, value in param_dict.items(): 33 if not hasattr(self, key): 34 raise ValueError(f"[ERROR] Unexpected key for flag = {key}") 35 setattr(self, key, value) 36 return self 37 38flags_dict = { 39 "debug": False, # Change to True for fast debug run! 40 "outdir": "results/tmp_debug", 41 # Data 42 "imgdir_name": "vinbigdata-chest-xray-resized-png-256x256", 43 # Model 44 "model_name": "tf_inception_v3",#"resnet18", 45 # Training 46 "num_workers": 4, 47 "epoch": 15, 48 "batchsize": 8, 49 "scheduler_type": "CosineAnnealingWarmRestarts", 50 "scheduler_kwargs": {"T_0": 28125}, # 15000 * 15 epoch // (batchsize=8) 51 "scheduler_trigger": [1, "iteration"], 52 "aug_kwargs": { 53 "HorizontalFlip": {"p": 0.5}, 54 "ShiftScaleRotate": {"scale_limit": 0.15, "rotate_limit": 10, "p": 0.5}, 55 "RandomBrightnessContrast": {"p": 0.5}, 56 "CoarseDropout": {"max_holes": 8, "max_height": 25, "max_width": 25, "p": 0.5}, 57 "Blur": {"blur_limit": [3, 7], "p": 0.5}, 58 "Downscale": {"scale_min": 0.25, "scale_max": 0.9, "p": 0.3}, 59 "RandomGamma": {"gamma_limit": [80, 120], "p": 0.6}, 60 } 61} 62 63# args = parse() 64print("torch", torch.__version__) 65flags = Flags().update(flags_dict) 66print("flags", flags) 67debug = flags.debug 68outdir = Path(flags.outdir) 69os.makedirs(str(outdir), exist_ok=True) 70flags_dict = dataclasses.asdict(flags) 71save_yaml(str(outdir / "flags.yaml"), flags_dict) 72 73# --- Read data --- 74inputdir = Path("/kaggle/input") 75datadir = inputdir / "vinbigdata-chest-xray-abnormalities-detection" 76imgdir = inputdir / flags.imgdir_name 77 78# Read in the data CSV files 79train = pd.read_csv(datadir / "train.csv") 80 81class CNNFixedPredictor(nn.Module): 82 def __init__(self, cnn: nn.Module, num_classes: int = 2): 83 super(CNNFixedPredictor, self).__init__() 84 self.cnn = cnn 85 self.lin = Linear(cnn.num_features, num_classes) 86 print("cnn.num_features", cnn.num_features) 87 88 # We do not learn CNN parameters. 89 # https://pytorch.org/tutorials/beginner/finetuning_torchvision_models_tutorial.html 90 for param in self.cnn.parameters(): 91 param.requires_grad = False 92 93 def forward(self, x): 94 feat = self.cnn(x) 95 return self.lin(feat) 96 97def build_predictor(model_name: str, model_mode: str = "normal"): 98 if model_mode == "normal": 99 # normal configuration. train all parameters. 100 return timm.create_model(model_name, pretrained=True, num_classes=2, in_chans=3) 101 elif model_mode == "cnn_fixed": 102 # normal configuration. train all parameters. 103 # https://rwightman.github.io/pytorch-image-models/feature_extraction/ 104 timm_model = timm.create_model(model_name, pretrained=True, num_classes=0, in_chans=3) 105 return CNNFixedPredictor(timm_model, num_classes=2) 106 else: 107 raise ValueError(f"[ERROR] Unexpected value model_mode={model_mode}") 108 109class Classifier(nn.Module): 110 """two class classfication""" 111 112 def __init__(self, predictor, lossfun=cross_entropy_with_logits): 113 super().__init__() 114 self.predictor = predictor 115 self.lossfun = lossfun 116 self.prefix = "" 117 118 def forward(self, image, targets): 119 outputs = self.predictor(image) 120 loss = self.lossfun(outputs, targets) 121 metrics = { 122 f"{self.prefix}loss": loss.item(), 123 f"{self.prefix}acc": accuracy_with_logits(outputs, targets).item() 124 } 125 ppe.reporting.report(metrics, self) 126 return loss, metrics 127 128 def predict(self, data_loader): 129 pred = self.predict_proba(data_loader) 130 label = torch.argmax(pred, dim=1) 131 return label 132 133 def predict_proba(self, data_loader): 134 device: torch.device = next(self.parameters()).device 135 y_list = [] 136 self.eval() 137 with torch.no_grad(): 138 for batch in data_loader: 139 if isinstance(batch, (tuple, list)): 140 # Assumes first argument is "image" 141 batch = batch[0].to(device) 142 else: 143 batch = batch.to(device) 144 y = self.predictor(batch) 145 y = torch.softmax(y, dim=-1) 146 y_list.append(y) 147 pred = torch.cat(y_list) 148 return pred 149 150def create_trainer(model, optimizer, device) -> Engine: 151 model.to(device) 152 153 def update_fn(engine, batch): 154 model.train() 155 optimizer.zero_grad() 156 loss, metrics = model(*[elem.to(device) for elem in batch]) 157 loss.backward() 158 optimizer.step() 159 return metrics 160 trainer = Engine(update_fn) 161 return trainer 162 163skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=flags.seed) 164# skf.get_n_splits(None, None) 165y = np.array([int(len(d["annotations"]) > 0) for d in dataset_dicts]) 166split_inds = list(skf.split(dataset_dicts, y)) 167train_inds, valid_inds = split_inds[flags.target_fold] # 0th fold 168train_dataset = VinbigdataTwoClassDataset( 169 [dataset_dicts[i] for i in train_inds], 170 image_transform=Transform(flags.aug_kwargs), 171 mixup_prob=flags.mixup_prob, 172 label_smoothing=flags.label_smoothing, 173) 174valid_dataset = VinbigdataTwoClassDataset([dataset_dicts[i] for i in valid_inds]) 175 176train_loader = DataLoader( 177 train_dataset, 178 batch_size=flags.batchsize, 179 num_workers=flags.num_workers, 180 shuffle=True, 181 pin_memory=True, 182) 183valid_loader = DataLoader( 184 valid_dataset, 185 batch_size=flags.valid_batchsize, 186 num_workers=flags.num_workers, 187 shuffle=False, 188 pin_memory=True, 189) 190 191device = torch.device(flags.device) 192 193predictor = build_predictor(model_name=flags.model_name, model_mode=flags.model_mode) 194classifier = Classifier(predictor) 195model = classifier 196optimizer = optim.Adam([param for param in model.parameters() if param.requires_grad], lr=1e-3) 197 198# Train setup 199trainer = create_trainer(model, optimizer, device) 200 201ema = EMA(predictor, decay=flags.ema_decay) 202 203def eval_func(*batch): 204 loss, metrics = model(*[elem.to(device) for elem in batch]) 205 # HACKING: report ema value with prefix. 206 if flags.ema_decay > 0: 207 classifier.prefix = "ema_" 208 ema.assign() 209 loss, metrics = model(*[elem.to(device) for elem in batch]) 210 ema.resume() 211 classifier.prefix = "" 212 213valid_evaluator = E.Evaluator( 214 valid_loader, model, progress_bar=False, eval_func=eval_func, device=device 215) 216 217# log_trigger = (10 if debug else 1000, "iteration") 218log_trigger = (1, "epoch") 219log_report = E.LogReport(trigger=log_trigger) 220extensions = [ 221 log_report, 222 E.ProgressBarNotebook(update_interval=10 if debug else 100), # Show progress bar during training 223 E.PrintReportNotebook(), 224 E.FailOnNonNumber(), 225] 226epoch = flags.epoch 227models = {"main": model} 228optimizers = {"main": optimizer} 229manager = IgniteExtensionsManager( 230 trainer, models, optimizers, epoch, extensions=extensions, out_dir=str(outdir), 231) 232# Run evaluation for valid dataset in each epoch. 233manager.extend(valid_evaluator) 234 235# Save predictor.pt every epoch 236manager.extend( 237 E.snapshot_object(predictor, "predictor.pt"), trigger=(flags.snapshot_freq, "epoch") 238) 239 240manager.extend(E.observe_lr(optimizer=optimizer), trigger=log_trigger) 241 242if flags.ema_decay > 0: 243 # Exponential moving average 244 manager.extend(lambda manager: ema(), trigger=(1, "iteration")) 245 246 def save_ema_model(manager): 247 ema.assign() 248 #torch.save(predictor.state_dict(), outdir / "predictor_ema.pt") 249 torch.save(predictor.state_dict(), "/kaggle/working/predictor_ema.pt") 250 ema.resume() 251 252 manager.extend(save_ema_model, trigger=(flags.snapshot_freq, "epoch")) 253 254_ = trainer.run(train_loader, max_epochs=epoch) 255 256torch.save(predictor.state_dict(), "/kaggle/working/predictor_tf_inception_v3.pt") 257df = log_report.to_dataframe() 258df.to_csv("/kaggle/working/log.csv", index=False) 259 260####################################追記終わり##################################################### 261 262#モデルの読み込み 263model_path = './predictor_last.pt' 264model = classifier(flags.aug_kwargs) 265model.load_state_dict(torch.load(model_path))

気になる質問をクリップする

クリップした質問は、後からいつでもMYページで確認できます。

またクリップした質問に回答があった際、通知やメールを受け取ることができます。

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

meg_

2021/02/28 12:03

質問のコードには変数modelの定義が抜けています。追記してください。
SuzuAya

2021/02/28 12:17

>meg_様 ありがとうございます。 確認の上、追記させていただきます。 Pytorchに不慣れであるため、少々お時間をいただくかもしれず、また、追記できたとしても内容が不十分なものとなるかもしれません。その際はお手数ですがご指摘いただけますと幸いです。
meg_

2021/02/28 13:06

> AttributeError: 'str' object has no attribute 'load_state_dict' 上記エラーが出ているということは、下記コードを実行する前に定義自体はされていることになります。 ----> 3 model = model.load_state_dict(torch.load(model_path)) 実際に実行したコードを掲載していただければ良いだけですよ。
jbpb0

2021/02/28 16:01

model = model.load_state_dict(torch.load(model_path)) ↓ model.load_state_dict(torch.load(model_path)) ここでは「model =」は要りません
SuzuAya

2021/03/01 06:44

>meg様 ありがとうございます。かなり長いのですが、学習コードを追記しました。 学習とモデルの保存まではうまくいくのですが、保存したモデルを別のNotebookで簡単に読み込めるようになれば便利だと感じ、試したところエラーが出た次第です。
SuzuAya

2021/03/01 06:44

>jbpb0様 ありがとうございます!初歩的なところで色々やらかしており恐れ入ります…。
meg_

2021/03/01 07:10

model*の定義元は下記ですかね? predictor = build_predictor(model_name=flags.model_name, model_mode=flags.model_mode) classifier = Classifier(predictor) model = classifier ただエラーによるとどこかでstr型に書き換わってしまったようですね。
SuzuAya

2021/03/01 08:03

>meg_様 はい、ご理解の通りです。 分かりにくくて申し訳ないのですが、以下のコードは、学習後、学習とは別のnotebookで試したものなんです(学習の際に保存したモデルを、別のnotebookからでも読み込めるのか試したかったため)。 それゆえにstr型と認識されてしまい、エラーが発生してしまっているのかもしれません。何か解決方法はあるでしょうか。 #モデルの読み込み model_path = '../input/detectron2-model/predictor_tf_inception_v3.pt' model = model.load_state_dict(torch.load(model_path))
jbpb0

2021/03/01 09:26 編集

1. 推論時のコード内にも、学習時のコードと同じモデルの定義をPythonで書く (いろいろ略) model =... 2. 上記で定義したmodelで、下記を実行 (ここでは「model =」を付けてはダメ) model_path = '../input/detectron2-model/predictor_tf_inception_v3.pt' model.load_state_dict(torch.load(model_path))
SuzuAya

2021/03/01 12:36

>jbpb0様 ご助言ありがとうございます!一旦、学習と同じnotebookでモデルの読み込みができるか試したところ、エラーメッセージが変わりました…。こちらはどのようにコードを変えれば良いかご存知でしょうか。
meg_

2021/03/01 12:38 編集

> 学習とは別のnotebookで試したものなんです > それゆえにstr型と認識されてしまい、エラーが発生してしまっているのかもしれません。 変数が未定義ですと質問のエラーとは別のエラーが出ますので、そのnotebook内で変数modelに何かしらの文字列を代入済ということになります。何をどう代入したかは質問者さんにしか分かりません。 解決策は皆さんがおっしゃている方法で良いかと思います。
SuzuAya

2021/03/01 12:40

>meg_様 お手数をお掛けしており申し訳ありません。一旦、学習と同じnotebookでモデルの読み込みができるか試したところ、エラーメッセージが変わったのですが、こちらはどのようにコードを変えれば良いかご存知でしょうか。
jbpb0

2021/03/01 12:53 編集

「該当のソースコード」の、どこからどこまでが実際に同じファイルなのか分かりません 「該当のソースコード」の全体が一つのファイルなのですか? 学習と推論でファイルを分けてるのですよね? 推論で使ってるファイルに実際書かれているのはどの部分なのでしょうか? また、推論で使ってるファイルから他のファイルをimportしてたりするなら、その関係も分かるようにしてください あと、 torch.save(predictor.state_dict()... で保存したファイルを読むのだから、 model.load_state_dict(... の「model」にも学習時の「predictor」と同じものを代入しないといけないような 【追記】 下記を読み飛ばしていました 失礼しました > 一旦、学習と同じnotebookでモデルの読み込みができるか試した 「該当のソースコード」全体が、実際でも一つのファイルなのですか?
jbpb0

2021/03/01 23:43 編集

https://betashort-lab.com/%E3%83%87%E3%83%BC%E3%82%BF%E3%82%B5%E3%82%A4%E3%82%A8%E3%83%B3%E3%82%B9/%E3%83%87%E3%82%A3%E3%83%BC%E3%83%97%E3%83%A9%E3%83%BC%E3%83%8B%E3%83%B3%E3%82%B0/pytorch%E3%81%AE%E3%83%A2%E3%83%87%E3%83%AB%E7%AE%A1%E7%90%86%E3%81%A8%E3%83%91%E3%83%A9%E3%83%A1%E3%83%BC%E3%82%BF%E4%BF%9D%E5%AD%98%E3%83%AD%E3%83%BC%E3%83%89/ を見てください 学習用コードも、推論用コードも、下記が共通してます from models.model import Net model = Net()... models.modelのNetでニューラルネットの構造が定義されていて、両方のコードで共通してそれを使っているので、両方のコードで「model」は共通の構造です 学習用コードでは、 optimizer = torch.optim.Adam(model.parameters(),... で「model」の学習の設定がされ、(省略されてますが)学習され、 torch.save(model.state_dict(),... で学習された「model」のパラメータをファイルに保存してます 推論用コードでは、上記で保存したファイルを、 model.load_state_dict(torch.load(... で「model」に読み込んでます
guest

回答2

0

modelの定義が間違っています。
modelは TheModelClass のインスタンスでなければなりません。

AttributeError: 'str' object has no attribute 'load_state_dict'

というエラーメッセージによれば、現在のmodelは文字列なのでエラーだと言っています。

投稿2021/02/28 12:17

ppaul

総合スコア24670

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

SuzuAya

2021/03/01 12:36

>ppaul様 ご回答ありがとうございます! 一旦、学習と同じnotebookでモデルの読み込みができるか試したところ、エラーメッセージが変わりました…。こちらはどのようにコードを変えればエラーが消えるかお分かりになるでしょうか。
guest

0

自己解決

以下の通り、modelではなくpredictorとして呼び出すことでうまくいきました!
たくさんアドバイスをくださり、ありがとうございました。

Python

1model_path = './predictor_last.pt' 2predictor.load_state_dict(torch.load(model_path))

投稿2021/03/05 15:48

SuzuAya

総合スコア71

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

あなたの回答

tips

太字

斜体

打ち消し線

見出し

引用テキストの挿入

コードの挿入

リンクの挿入

リストの挿入

番号リストの挿入

表の挿入

水平線の挿入

プレビュー

15分調べてもわからないことは
teratailで質問しよう!

ただいまの回答率
85.36%

質問をまとめることで
思考を整理して素早く解決

テンプレート機能で
簡単に質問をまとめる

質問する

関連した質問