前提・実現したいこと
タイトル,質問の内容ともに説明が不十分だったので修正しました.
一度学習した物体検出モデルのクラス数変更のため,学習済みモデルを元に再度学習を行いたい.
pretrained_modelに'imagenet'を設定した場合には最終層の重みはコピーされないのに対し,
セーブしたnpzファイルをロードするとすべての重みをコピーしようとするため,エラーが生じるのだと思います.
コピーできる重みだけをコピーするにはどうすればよいでしょうか.
発生している問題・エラーメッセージ
Traceback (most recent call last): File "<input>", line 1, in <module> File "C:\Users\UserName\AppData\Local\Continuum\anaconda3\envs\chainer\lib\site-packages\chainercv\links\model\faster_rcnn\faster_rcnn_vgg.py", line 141, in __init__ chainer.serializers.load_npz(path, self) File "C:\Users\UserName\AppData\Local\Continuum\anaconda3\envs\chainer\lib\site-packages\chainer\serializers\npz.py", line 243, in load_npz d.load(obj) File "C:\Users\UserName\AppData\Local\Continuum\anaconda3\envs\chainer\lib\site-packages\chainer\serializer.py", line 83, in load obj.serialize(self) File "C:\Users\UserName\AppData\Local\Continuum\anaconda3\envs\chainer\lib\site-packages\chainer\link.py", line 1026, in serialize d[name].serialize(serializer[name]) File "C:\Users\UserName\AppData\Local\Continuum\anaconda3\envs\chainer\lib\site-packages\chainer\link.py", line 1026, in serialize d[name].serialize(serializer[name]) File "C:\Users\UserName\AppData\Local\Continuum\anaconda3\envs\chainer\lib\site-packages\chainer\link.py", line 657, in serialize data = serializer(name, param.data) # type: types.NdArray File "C:\Users\UserName\AppData\Local\Continuum\anaconda3\envs\chainer\lib\site-packages\chainer\serializers\npz.py", line 185, in __call__ numpy.copyto(value, dataset) File "<__array_function__ internals>", line 6, in copyto ValueError: could not broadcast input array from shape (101) into shape (102)
該当のソースコード
python
1import chainer 2import chainercv 3 4model = chainercv.links.FasterRCNNVGG16(n_fg_class=100) 5train_chain = chainercv.links.model.faster_rcnn.FasterRCNNTrainChain(model) 6 7'''' 8trainの処理をしたとする 9''' 10 11chainer.serializers.save_npz('tmp.npz', train_chain.faster_rcnn) 12 13# 検出したい物体が増えたので,上記のモデルを元に再度学習したい 14model = chainercv.links.FasterRCNNVGG16(n_fg_class=101, pretrained_model='tmp.npz') 15train_chain = chainercv.links.model.faster_rcnn.FasterRCNNTrainChain(model) 16 17'''' 18再度trainする 19''' 20 21
補足情報(FW/ツールのバージョンなど)
chainer.version
'7.2.0'
chainercv.version
'0.13.1'
回答3件
あなたの回答
tips
プレビュー
バッドをするには、ログインかつ
こちらの条件を満たす必要があります。