質問をすることでしか得られない、回答やアドバイスがある。

15分調べてもわからないことは、質問しよう!

新規登録して質問してみよう
ただいま回答率
85.50%
Chainer

Chainerは、国産の深層学習フレームワークです。あらゆるニューラルネットワークをPythonで柔軟に書くことができ、学習させることが可能。GPUをサポートしており、複数のGPUを用いた学習も直感的に記述できます。

Python 3.x

Python 3はPythonプログラミング言語の最新バージョンであり、2008年12月3日にリリースされました。

Q&A

解決済

1回答

7127閲覧

chainerでfine tuningした際にsnapshotが読み込めない

hukuda222

総合スコア13

Chainer

Chainerは、国産の深層学習フレームワークです。あらゆるニューラルネットワークをPythonで柔軟に書くことができ、学習させることが可能。GPUをサポートしており、複数のGPUを用いた学習も直感的に記述できます。

Python 3.x

Python 3はPythonプログラミング言語の最新バージョンであり、2008年12月3日にリリースされました。

0グッド

0クリップ

投稿2018/05/06 11:15

編集2018/05/06 15:08

python

1class VGG(Chain): 2 def __init__(self): 3 super(VGG, self).__init__() 4 5 with self.init_scope(): 6 self.base = L.VGG16Layers() 7 self.classify = L.Linear(None, 20) 8 9 def __call__(self, x): 10 h = self.base(x, layers=['fc7'])['fc7'] 11 return self.classify(h)

上記のようなネットワークで学習させると、学習自体はうまくいくのですが、trainerのsnapshotを読み込む際に、下記のようなエラーがでます。見た所、snapshotのデータにあるべきデータが存在しないようです。

File "train.py", line 100, in main serializers.load_npz(args.resume, trainer) File "/.local/lib/python3.6/site-packages/chainer/serializers/npz.py", line 179, in load_npz d.load(obj) File "/home/.local/lib/python3.6/site-packages/chainer/serializer.py", line 83, in load obj.serialize(self) File "/.local/lib/python3.6/site-packages/chainer/training/trainer.py", line 332, in serialize self.updater.serialize(serializer['updater']) File "/.local/lib/python3.6/site-packages/chainer/training/updaters/standard_updater.py", line 172, in serialize optimizer.serialize(serializer['optimizer:' + name]) File "/.local/lib/python3.6/site-packages/chainer/optimizer.py", line 549, in serialize rule.serialize(serializer[name]) File "/.local/lib/python3.6/site-packages/chainer/optimizer.py", line 295, in serialize self._state[key] = serializer(key, None) File "/.local/lib/python3.6/site-packages/chainer/serializers/npz.py", line 142, in __call__ dataset = self.npz[key] File "/.local/lib/python3.6/site-packages/numpy/lib/npyio.py", line 239, in __getitem__ raise KeyError("%s is not a file in the archive" % key) KeyError: 'updater/optimizer:main/predictor/base/conv2_2/b/m is not a file in the archive'

chainerのバージョンは4.0.0
pythonのバージョンは3.6.5
です。

このエラーの対策をご存知の方がいらっしゃれば、ご教授お願いします。

エラーが再現できたコード

以下のコードで同様のエラーが発生しました。

python

1import numpy as np 2import chainer.links as L 3import chainer.functions as F 4from chainer import dataset, Chain, training, optimizers, \ 5 iterators, reporter, cuda,serializers 6import argparse 7if cuda.available: 8 xp = cuda.cupy 9else: 10 xp = np 11 12class VGG(Chain): 13 def __init__(self): 14 super(VGG, self).__init__() 15 16 with self.init_scope(): 17 self.base = L.VGG16Layers() 18 self.classify = L.Linear(None, 20) 19 20 def __call__(self, x): 21 h = self.base(x, layers=['fc7'])['fc7'] 22 return self.classify(h) 23 24 25class DataSet(dataset.DatasetMixin): 26 def __init__(self): 27 pass 28 29 def __len__(self): 30 return 1 31 32 def get_example(self, _): 33 return xp.ones((3, 224, 224)).astype('float32'), xp.zeros((1,)).astype('int32')[0] 34 35 36def main(): 37 parser = argparse.ArgumentParser() 38 parser.add_argument('--epoch', '-e', type=int, default=2, 39 help='Number of examples in epoch') 40 parser.add_argument('--batchsize', '-b', type=int, default=1, 41 help='Number of examples in each mini-batch') 42 parser.add_argument('--gpu', '-g', type=int, default=-1, 43 help='GPU ID (negative value indicates CPU)') 44 parser.add_argument('--out', '-o', default='result2', 45 help='Directory to output the result') 46 parser.add_argument('--resume', '-r', default='', 47 help='Resume the training from snapshot') 48 49 args = parser.parse_args() 50 51 train_dataset = DataSet() 52 53 model = L.Classifier(VGG()) 54 55 56 if args.gpu >= 0: 57 cuda.get_device_from_id(args.gpu).use() 58 model.to_gpu() 59 60 optimizer = optimizers.Adam() 61 optimizer.setup(model) 62 model.predictor.base.disable_update() 63 64 train_iter = iterators.SerialIterator( 65 train_dataset, batch_size=args.batchsize) 66 67 updater = training.StandardUpdater(train_iter, optimizer) 68 trainer = training.Trainer( 69 updater, (args.epoch, 'epoch'), out=args.out) 70 71 trainer.extend(training.extensions.LogReport( 72 trigger=(1, 'epoch'))) 73 trainer.extend(training.extensions.PrintReport( 74 entries=['iteration', 'main/loss', 75 'main/accuracy', 'elapsed_time']), 76 trigger=(1, 'epoch')) 77 # ここでsnapshotを取っています。 78 trainer.extend(training.extensions.snapshot(), 79 trigger=(1, 'epoch')) 80 81 if args.resume: 82     # ここで読み込んでいます 83 serializers.load_npz(args.resume, trainer) 84 trainer.run() 85 86 87if __name__ == "__main__": 88 main()

以下のように実行すると上記のエラーが生じました。

python test.py --gpu 0 python test.py --gpu 0 --resume ./result2/snapshot_iter_1

気になる質問をクリップする

クリップした質問は、後からいつでもMYページで確認できます。

またクリップした質問に回答があった際、通知やメールを受け取ることができます。

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

mkgrei

2018/05/06 12:45 編集

コードが正しく書かれているのであれば、特にエラーが出ることなく実行されるのが普通です。という当たり前のことを申し上げておきます。モデルの保存・読込に該当する部分を記述してください。
hukuda222

2018/05/06 12:52

追記しました。
guest

回答1

0

ベストアンサー

serializers.load_npz(args.resume, trainer, strict=False)

どうしてこのようになってしまうのかは定かではありませんが、snapshotからだとstrict=Falseしないとエラーが出るようです。

serializers.save_npzを使ってmodelやtrainerを保存すると、必要はありません。

バイナリなので、見づらいこと極まりないのですが、snapshotを作る際に何かのpathを壊しているのかもしれません。


ざっと調べた限り、同様のトラブルは見つかりませんでした。
なぜでしょう…


それに加えて、
snapshot_iter_1
デフォルトだとiterの番号になっているので、2,4,...と番号が変化するかと思います。
epochにするためにはfilenameを正しく与えてやる必要があります。
お気をつけください。

投稿2018/05/07 13:51

編集2018/05/07 13:53
mkgrei

総合スコア8560

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

hukuda222

2018/05/07 13:56

無事動きました。ありがとうございました。
kyng

2018/11/28 01:25

https://docs.chainer.org/en/stable/reference/generated/chainer.serializers.load_npz.html?highlight=load_npz strict (bool) – If True, the deserializer raises an error when an expected value is not found in the given NPZ file. Otherwise, it ignores the value and skip deserialization. とありますので、strict=false にすると、NPZファイル内に見つからなかったパラメータがロードされず、期待した振る舞いにならないかもしれないことにはご注意ください。
guest

あなたの回答

tips

太字

斜体

打ち消し線

見出し

引用テキストの挿入

コードの挿入

リンクの挿入

リストの挿入

番号リストの挿入

表の挿入

水平線の挿入

プレビュー

15分調べてもわからないことは
teratailで質問しよう!

ただいまの回答率
85.50%

質問をまとめることで
思考を整理して素早く解決

テンプレート機能で
簡単に質問をまとめる

質問する

関連した質問