前提・実現したいこと
現在Pytorchでのセグメンテーションを勉強しています。
学習モデルはFusionNetと呼ばれる下記のものを使用しています。
https://github.com/GunhoChoi/FusionNet-Pytorch
学習まではおそらくうまくいっていて、モデルも保存されているのですがそれを読み込んで推論(テスト)を行うことができずに困っています。
Github上には推論のプログラムまでは用意されていなかったので、見よう見まねで自作しようと思っていたのですが、モデルを読み込むところからつまずいてしまいました。
発生している問題・エラーメッセージ
Traceback (most recent call last):
File "visu.py", line 17, in <module>
model.load_state_dict(state_dict)
File "C:\Users\Anaconda3\envs\FusionNet\lib\site-packages\torch\nn\modules\module.py", line 803, in load_state_dict
state_dict = state_dict.copy()
File "C:\Users\Anaconda3\envs\FusionNet\lib\site-packages\torch\nn\modules\module.py", line 576, in getattr
type(self).name, name))
AttributeError: 'DataParallel' object has no attribute 'copy'
該当のソースコード
python
1model = nn.DataParallel(FusionGenerator(3,3,16)).cuda() 2state_dict = torch.load("./model/fusion.pkl", 3 map_location={'cuda:0': 'cpu'}) 4model.load_state_dict(state_dict) 5
試したこと
とりあえず一般的なやり方として、使用したモデルを定義してtorch.loadしようとしましたが、うまくいきませんでした。
DataParallelの構造について理解しようとしましたがなかなか難しく、その読み込み方について教えていただければ幸いです。
補足情報(FW/ツールのバージョンなど)
python 3.5.2
あなたの回答
tips
プレビュー