質問をすることでしか得られない、回答やアドバイスがある。

15分調べてもわからないことは、質問しよう!

新規登録して質問してみよう
ただいま回答率
87.20%
PyTorch

PyTorchは、オープンソースのPython向けの機械学習ライブラリ。Facebookの人工知能研究グループが開発を主導しています。強力なGPUサポートを備えたテンソル計算、テープベースの自動微分による柔軟なニューラルネットワークの記述が可能です。

Python 3.x

Python 3はPythonプログラミング言語の最新バージョンであり、2008年12月3日にリリースされました。

Python

Pythonは、コードの読みやすさが特徴的なプログラミング言語の1つです。 強い型付け、動的型付けに対応しており、後方互換性がないバージョン2系とバージョン3系が使用されています。 商用製品の開発にも無料で使用でき、OSだけでなく仮想環境にも対応。Unicodeによる文字列操作をサポートしているため、日本語処理も標準で可能です。

解決済

Pytorchで保存したモデルを読み込みたい

SuzuAya
SuzuAya

総合スコア0

PyTorch

PyTorchは、オープンソースのPython向けの機械学習ライブラリ。Facebookの人工知能研究グループが開発を主導しています。強力なGPUサポートを備えたテンソル計算、テープベースの自動微分による柔軟なニューラルネットワークの記述が可能です。

Python 3.x

Python 3はPythonプログラミング言語の最新バージョンであり、2008年12月3日にリリースされました。

Python

Pythonは、コードの読みやすさが特徴的なプログラミング言語の1つです。 強い型付け、動的型付けに対応しており、後方互換性がないバージョン2系とバージョン3系が使用されています。 商用製品の開発にも無料で使用でき、OSだけでなく仮想環境にも対応。Unicodeによる文字列操作をサポートしているため、日本語処理も標準で可能です。

2回答

0評価

0クリップ

2417閲覧

投稿2021/02/28 11:52

編集2022/01/12 10:58

前提・実現したいこと

Pytorchで保存したモデルを読み込みたいのですが、エラーが発生してしまいました。
原因についてお分かりの方がいらっしゃいましたら、ご教示いただけますと幸いです。

発生している問題・エラーメッセージ

----> 3 model = model.load_state_dict(torch.load(model_path)) AttributeError: 'str' object has no attribute 'load_state_dict'

該当のソースコード

Python

@dataclass class Flags: # General debug: bool = True outdir: str = "results/det" device: str = "cuda:0" # Data config imgdir_name: str = "vinbigdata-chest-xray-resized-png-256x256" # split_mode: str = "all_train" seed: int = 111 target_fold: int = 0 # 0~4 label_smoothing: float = 0.0 # Model config model_name: str = "tf_inception_v3" model_mode: str = "normal" # normal, cnn_fixed supported # Training config epoch: int = 20 batchsize: int = 8 valid_batchsize: int = 16 num_workers: int = 4 snapshot_freq: int = 5 ema_decay: float = 0.999 # negative value is to inactivate ema. scheduler_type: str = "" scheduler_kwargs: Dict[str, Any] = field(default_factory=lambda: {}) scheduler_trigger: List[Union[int, str]] = field(default_factory=lambda: [1, "iteration"]) aug_kwargs: Dict[str, Dict[str, Any]] = field(default_factory=lambda: {}) mixup_prob: float = -1.0 # Apply mixup augmentation when positive value is set. def update(self, param_dict: Dict) -> "Flags": # Overwrite by `param_dict` for key, value in param_dict.items(): if not hasattr(self, key): raise ValueError(f"[ERROR] Unexpected key for flag = {key}") setattr(self, key, value) return self flags_dict = { "debug": False, # Change to True for fast debug run! "outdir": "results/tmp_debug", # Data "imgdir_name": "vinbigdata-chest-xray-resized-png-256x256", # Model "model_name": "tf_inception_v3",#"resnet18", # Training "num_workers": 4, "epoch": 15, "batchsize": 8, "scheduler_type": "CosineAnnealingWarmRestarts", "scheduler_kwargs": {"T_0": 28125}, # 15000 * 15 epoch // (batchsize=8) "scheduler_trigger": [1, "iteration"], "aug_kwargs": { "HorizontalFlip": {"p": 0.5}, "ShiftScaleRotate": {"scale_limit": 0.15, "rotate_limit": 10, "p": 0.5}, "RandomBrightnessContrast": {"p": 0.5}, "CoarseDropout": {"max_holes": 8, "max_height": 25, "max_width": 25, "p": 0.5}, "Blur": {"blur_limit": [3, 7], "p": 0.5}, "Downscale": {"scale_min": 0.25, "scale_max": 0.9, "p": 0.3}, "RandomGamma": {"gamma_limit": [80, 120], "p": 0.6}, } } # args = parse() print("torch", torch.__version__) flags = Flags().update(flags_dict) print("flags", flags) debug = flags.debug outdir = Path(flags.outdir) os.makedirs(str(outdir), exist_ok=True) flags_dict = dataclasses.asdict(flags) save_yaml(str(outdir / "flags.yaml"), flags_dict) # --- Read data --- inputdir = Path("/kaggle/input") datadir = inputdir / "vinbigdata-chest-xray-abnormalities-detection" imgdir = inputdir / flags.imgdir_name # Read in the data CSV files train = pd.read_csv(datadir / "train.csv") class CNNFixedPredictor(nn.Module): def __init__(self, cnn: nn.Module, num_classes: int = 2): super(CNNFixedPredictor, self).__init__() self.cnn = cnn self.lin = Linear(cnn.num_features, num_classes) print("cnn.num_features", cnn.num_features) # We do not learn CNN parameters. # https://pytorch.org/tutorials/beginner/finetuning_torchvision_models_tutorial.html for param in self.cnn.parameters(): param.requires_grad = False def forward(self, x): feat = self.cnn(x) return self.lin(feat) def build_predictor(model_name: str, model_mode: str = "normal"): if model_mode == "normal": # normal configuration. train all parameters. return timm.create_model(model_name, pretrained=True, num_classes=2, in_chans=3) elif model_mode == "cnn_fixed": # normal configuration. train all parameters. # https://rwightman.github.io/pytorch-image-models/feature_extraction/ timm_model = timm.create_model(model_name, pretrained=True, num_classes=0, in_chans=3) return CNNFixedPredictor(timm_model, num_classes=2) else: raise ValueError(f"[ERROR] Unexpected value model_mode={model_mode}") class Classifier(nn.Module): """two class classfication""" def __init__(self, predictor, lossfun=cross_entropy_with_logits): super().__init__() self.predictor = predictor self.lossfun = lossfun self.prefix = "" def forward(self, image, targets): outputs = self.predictor(image) loss = self.lossfun(outputs, targets) metrics = { f"{self.prefix}loss": loss.item(), f"{self.prefix}acc": accuracy_with_logits(outputs, targets).item() } ppe.reporting.report(metrics, self) return loss, metrics def predict(self, data_loader): pred = self.predict_proba(data_loader) label = torch.argmax(pred, dim=1) return label def predict_proba(self, data_loader): device: torch.device = next(self.parameters()).device y_list = [] self.eval() with torch.no_grad(): for batch in data_loader: if isinstance(batch, (tuple, list)): # Assumes first argument is "image" batch = batch[0].to(device) else: batch = batch.to(device) y = self.predictor(batch) y = torch.softmax(y, dim=-1) y_list.append(y) pred = torch.cat(y_list) return pred def create_trainer(model, optimizer, device) -> Engine: model.to(device) def update_fn(engine, batch): model.train() optimizer.zero_grad() loss, metrics = model(*[elem.to(device) for elem in batch]) loss.backward() optimizer.step() return metrics trainer = Engine(update_fn) return trainer skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=flags.seed) # skf.get_n_splits(None, None) y = np.array([int(len(d["annotations"]) > 0) for d in dataset_dicts]) split_inds = list(skf.split(dataset_dicts, y)) train_inds, valid_inds = split_inds[flags.target_fold] # 0th fold train_dataset = VinbigdataTwoClassDataset( [dataset_dicts[i] for i in train_inds], image_transform=Transform(flags.aug_kwargs), mixup_prob=flags.mixup_prob, label_smoothing=flags.label_smoothing, ) valid_dataset = VinbigdataTwoClassDataset([dataset_dicts[i] for i in valid_inds]) train_loader = DataLoader( train_dataset, batch_size=flags.batchsize, num_workers=flags.num_workers, shuffle=True, pin_memory=True, ) valid_loader = DataLoader( valid_dataset, batch_size=flags.valid_batchsize, num_workers=flags.num_workers, shuffle=False, pin_memory=True, ) device = torch.device(flags.device) predictor = build_predictor(model_name=flags.model_name, model_mode=flags.model_mode) classifier = Classifier(predictor) model = classifier optimizer = optim.Adam([param for param in model.parameters() if param.requires_grad], lr=1e-3) # Train setup trainer = create_trainer(model, optimizer, device) ema = EMA(predictor, decay=flags.ema_decay) def eval_func(*batch): loss, metrics = model(*[elem.to(device) for elem in batch]) # HACKING: report ema value with prefix. if flags.ema_decay > 0: classifier.prefix = "ema_" ema.assign() loss, metrics = model(*[elem.to(device) for elem in batch]) ema.resume() classifier.prefix = "" valid_evaluator = E.Evaluator( valid_loader, model, progress_bar=False, eval_func=eval_func, device=device ) # log_trigger = (10 if debug else 1000, "iteration") log_trigger = (1, "epoch") log_report = E.LogReport(trigger=log_trigger) extensions = [ log_report, E.ProgressBarNotebook(update_interval=10 if debug else 100), # Show progress bar during training E.PrintReportNotebook(), E.FailOnNonNumber(), ] epoch = flags.epoch models = {"main": model} optimizers = {"main": optimizer} manager = IgniteExtensionsManager( trainer, models, optimizers, epoch, extensions=extensions, out_dir=str(outdir), ) # Run evaluation for valid dataset in each epoch. manager.extend(valid_evaluator) # Save predictor.pt every epoch manager.extend( E.snapshot_object(predictor, "predictor.pt"), trigger=(flags.snapshot_freq, "epoch") ) manager.extend(E.observe_lr(optimizer=optimizer), trigger=log_trigger) if flags.ema_decay > 0: # Exponential moving average manager.extend(lambda manager: ema(), trigger=(1, "iteration")) def save_ema_model(manager): ema.assign() #torch.save(predictor.state_dict(), outdir / "predictor_ema.pt") torch.save(predictor.state_dict(), "/kaggle/working/predictor_ema.pt") ema.resume() manager.extend(save_ema_model, trigger=(flags.snapshot_freq, "epoch")) _ = trainer.run(train_loader, max_epochs=epoch) torch.save(predictor.state_dict(), "/kaggle/working/predictor_tf_inception_v3.pt") df = log_report.to_dataframe() df.to_csv("/kaggle/working/log.csv", index=False) ####################################追記終わり##################################################### #モデルの読み込み model_path = '../input/detectron2-model/predictor_tf_inception_v3.pt' model = model.load_state_dict(torch.load(model_path))

良い質問の評価を上げる

以下のような質問は評価を上げましょう

  • 質問内容が明確
  • 自分も答えを知りたい
  • 質問者以外のユーザにも役立つ

評価が高い質問は、TOPページの「注目」タブのフィードに表示されやすくなります。

気になる質問をクリップする

クリップした質問は、後からいつでもマイページで確認できます。

またクリップした質問に回答があった際、通知やメールを受け取ることができます。

teratailでは下記のような質問を「具体的に困っていることがない質問」、「サイトポリシーに違反する質問」と定義し、推奨していません。

  • プログラミングに関係のない質問
  • やってほしいことだけを記載した丸投げの質問
  • 問題・課題が含まれていない質問
  • 意図的に内容が抹消された質問
  • 過去に投稿した質問と同じ内容の質問
  • 広告と受け取られるような投稿

評価を下げると、トップページの「アクティブ」「注目」タブのフィードに表示されにくくなります。

meg_
meg_

2021/02/28 12:03

質問のコードには変数modelの定義が抜けています。追記してください。
SuzuAya
SuzuAya

2021/02/28 12:17

>meg_様 ありがとうございます。 確認の上、追記させていただきます。 Pytorchに不慣れであるため、少々お時間をいただくかもしれず、また、追記できたとしても内容が不十分なものとなるかもしれません。その際はお手数ですがご指摘いただけますと幸いです。
meg_
meg_

2021/02/28 13:06

> AttributeError: 'str' object has no attribute 'load_state_dict' 上記エラーが出ているということは、下記コードを実行する前に定義自体はされていることになります。 ----> 3 model = model\.load_state_dict\(torch\.load\(model_path\)\) 実際に実行したコードを掲載していただければ良いだけですよ。
jbpb0
jbpb0

2021/02/28 16:01

model = model\.load_state_dict\(torch\.load\(model_path\)\) ↓ model\.load_state_dict\(torch\.load\(model_path\)\) ここでは「model =」は要りません
SuzuAya
SuzuAya

2021/03/01 06:44

>meg様 ありがとうございます。かなり長いのですが、学習コードを追記しました。 学習とモデルの保存まではうまくいくのですが、保存したモデルを別のNotebookで簡単に読み込めるようになれば便利だと感じ、試したところエラーが出た次第です。
SuzuAya
SuzuAya

2021/03/01 06:44

>jbpb0様 ありがとうございます!初歩的なところで色々やらかしており恐れ入ります…。
meg_
meg_

2021/03/01 07:10

model\*の定義元は下記ですかね? predictor = build_predictor\(model_name=flags\.model_name, model_mode=flags\.model_mode\) classifier = Classifier\(predictor\) model = classifier ただエラーによるとどこかでstr型に書き換わってしまったようですね。
SuzuAya
SuzuAya

2021/03/01 08:03

>meg_様 はい、ご理解の通りです。 分かりにくくて申し訳ないのですが、以下のコードは、学習後、学習とは別のnotebookで試したものなんです(学習の際に保存したモデルを、別のnotebookからでも読み込めるのか試したかったため)。 それゆえにstr型と認識されてしまい、エラーが発生してしまっているのかもしれません。何か解決方法はあるでしょうか。 #モデルの読み込み model_path = '\.\./input/detectron2-model/predictor_tf_inception_v3\.pt' model = model\.load_state_dict\(torch\.load\(model_path\)\)
jbpb0
jbpb0

2021/03/01 09:26 編集

1\. 推論時のコード内にも、学習時のコードと同じモデルの定義をPythonで書く \(いろいろ略\) model =\.\.\. 2\. 上記で定義したmodelで、下記を実行 \(ここでは「model =」を付けてはダメ\) model_path = '\.\./input/detectron2-model/predictor_tf_inception_v3\.pt' model\.load_state_dict\(torch\.load\(model_path\)\)
jbpb0
jbpb0

2021/03/01 09:25

「state_dict\(\)」を使う場合と使わない場合で、保存・読み込みのやり方が違います 下記を見てください https://www\.javaer101\.com/en/article/919582\.html
SuzuAya
SuzuAya

2021/03/01 12:36

>jbpb0様 ご助言ありがとうございます!一旦、学習と同じnotebookでモデルの読み込みができるか試したところ、エラーメッセージが変わりました…。こちらはどのようにコードを変えれば良いかご存知でしょうか。
meg_
meg_

2021/03/01 12:38 編集

> 学習とは別のnotebookで試したものなんです > それゆえにstr型と認識されてしまい、エラーが発生してしまっているのかもしれません。 変数が未定義ですと質問のエラーとは別のエラーが出ますので、そのnotebook内で変数modelに何かしらの文字列を代入済ということになります。何をどう代入したかは質問者さんにしか分かりません。 解決策は皆さんがおっしゃている方法で良いかと思います。
SuzuAya
SuzuAya

2021/03/01 12:40

>meg_様 お手数をお掛けしており申し訳ありません。一旦、学習と同じnotebookでモデルの読み込みができるか試したところ、エラーメッセージが変わったのですが、こちらはどのようにコードを変えれば良いかご存知でしょうか。
jbpb0
jbpb0

2021/03/01 12:53 編集

「該当のソースコード」の、どこからどこまでが実際に同じファイルなのか分かりません 「該当のソースコード」の全体が一つのファイルなのですか? 学習と推論でファイルを分けてるのですよね? 推論で使ってるファイルに実際書かれているのはどの部分なのでしょうか? また、推論で使ってるファイルから他のファイルをimportしてたりするなら、その関係も分かるようにしてください あと、 torch\.save\(predictor\.state_dict\(\)\.\.\. で保存したファイルを読むのだから、 model\.load_state_dict\(\.\.\. の「model」にも学習時の「predictor」と同じものを代入しないといけないような 【追記】 下記を読み飛ばしていました 失礼しました > 一旦、学習と同じnotebookでモデルの読み込みができるか試した 「該当のソースコード」全体が、実際でも一つのファイルなのですか?
jbpb0
jbpb0

2021/03/01 23:43 編集

https://betashort-lab\.com/%E3%83%87%E3%83%BC%E3%82%BF%E3%82%B5%E3%82%A4%E3%82%A8%E3%83%B3%E3%82%B9/%E3%83%87%E3%82%A3%E3%83%BC%E3%83%97%E3%83%A9%E3%83%BC%E3%83%8B%E3%83%B3%E3%82%B0/pytorch%E3%81%AE%E3%83%A2%E3%83%87%E3%83%AB%E7%AE%A1%E7%90%86%E3%81%A8%E3%83%91%E3%83%A9%E3%83%A1%E3%83%BC%E3%82%BF%E4%BF%9D%E5%AD%98%E3%83%AD%E3%83%BC%E3%83%89/ を見てください 学習用コードも、推論用コードも、下記が共通してます from models\.model import Net model = Net\(\)\.\.\. models\.modelのNetでニューラルネットの構造が定義されていて、両方のコードで共通してそれを使っているので、両方のコードで「model」は共通の構造です 学習用コードでは、 optimizer = torch\.optim\.Adam\(model\.parameters\(\),\.\.\. で「model」の学習の設定がされ、\(省略されてますが\)学習され、 torch\.save\(model\.state_dict\(\),\.\.\. で学習された「model」のパラメータをファイルに保存してます 推論用コードでは、上記で保存したファイルを、 model\.load_state_dict\(torch\.load\(\.\.\. で「model」に読み込んでます

まだ回答がついていません

会員登録して回答してみよう

15分調べてもわからないことは
teratailで質問しよう!

ただいまの回答率
87.20%

質問をまとめることで
思考を整理して素早く解決

テンプレート機能で
簡単に質問をまとめる

質問する

関連した質問

同じタグがついた質問を見る

PyTorch

PyTorchは、オープンソースのPython向けの機械学習ライブラリ。Facebookの人工知能研究グループが開発を主導しています。強力なGPUサポートを備えたテンソル計算、テープベースの自動微分による柔軟なニューラルネットワークの記述が可能です。

Python 3.x

Python 3はPythonプログラミング言語の最新バージョンであり、2008年12月3日にリリースされました。

Python

Pythonは、コードの読みやすさが特徴的なプログラミング言語の1つです。 強い型付け、動的型付けに対応しており、後方互換性がないバージョン2系とバージョン3系が使用されています。 商用製品の開発にも無料で使用でき、OSだけでなく仮想環境にも対応。Unicodeによる文字列操作をサポートしているため、日本語処理も標準で可能です。