質問をすることでしか得られない、回答やアドバイスがある。

15分調べてもわからないことは、質問しよう!

新規登録して質問してみよう
ただいま回答率
85.50%
深層学習

深層学習は、多数のレイヤのニューラルネットワークによる機械学習手法。人工知能研究の一つでディープラーニングとも呼ばれています。コンピューター自体がデータの潜在的な特徴を汲み取り、効率的で的確な判断を実現することができます。

Python

Pythonは、コードの読みやすさが特徴的なプログラミング言語の1つです。 強い型付け、動的型付けに対応しており、後方互換性がないバージョン2系とバージョン3系が使用されています。 商用製品の開発にも無料で使用でき、OSだけでなく仮想環境にも対応。Unicodeによる文字列操作をサポートしているため、日本語処理も標準で可能です。

Q&A

解決済

1回答

2852閲覧

[TensorFlow]学習したmodelを保存したい

Amanokawa

総合スコア41

深層学習

深層学習は、多数のレイヤのニューラルネットワークによる機械学習手法。人工知能研究の一つでディープラーニングとも呼ばれています。コンピューター自体がデータの潜在的な特徴を汲み取り、効率的で的確な判断を実現することができます。

Python

Pythonは、コードの読みやすさが特徴的なプログラミング言語の1つです。 強い型付け、動的型付けに対応しており、後方互換性がないバージョン2系とバージョン3系が使用されています。 商用製品の開発にも無料で使用でき、OSだけでなく仮想環境にも対応。Unicodeによる文字列操作をサポートしているため、日本語処理も標準で可能です。

0グッド

0クリップ

投稿2018/12/20 11:35

編集2018/12/22 00:53

tensorflowを使ってGANを学習させているのですが、generatorとdiscriminatorを別々に保存・復元できるようにしたいです。
graphはそれぞれ"gen"と"dis"のscopeに属すよう作成しました。
そのうえで、saverをそれぞれ以下のように定義し、

python

1gen_var = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, "gen") 2dis_var = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, "dis") 3saver_gen = tf.train.Saver(gen_var) 4saver_dis = tf.train.Saver(dis_var)

学習後それぞれの保存と

python

1saver_gen.save(sess, MODEL_GEN_DIR + "model.ckpt") 2saver_dis.save(sess, MODEL_DIS_DIR + "model.ckpt")

復元をこのように記述しています。

python

1sess.tf.global_variables_initializer() 2saver_gen.restore(sess, MODEL_GEN_DIR + "model.ckpt") 3saver_dis.restore(sess, MODEL_DIS_DIR + "model.ckpt")

saveとrestoreを行う際はエラーも出ず一見うまくいっているようなのですが、出力が保存時と復元時で異なっているようなのです。
variableを詳しく調べてみると、tf.contrib.layers.batch_norm()で使用される'(スコープ)/moving_mean'と'(スコープ)/moving_variance'のvariableが保存されていない事がわかりました。
どうやら、batch_normのvariableの内gammaとbetaはtf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)で取得できるが、moving_meanとmoving_varianceは取得できていないようです。

この2変数の取得はどのように行えばよいでしょうか?御教示お願いします。

気になる質問をクリップする

クリップした質問は、後からいつでもMYページで確認できます。

またクリップした質問に回答があった際、通知やメールを受け取ることができます。

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

guest

回答1

0

自己解決

optimizerにはtf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)で取得したリストを使い、
tf.train.Saverにはtf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)で取得したリストを入れる事で解決いたしました。

GLOBAL_VARIABLESはoptimizer定義前に取得しないと、optimizerのvariableも入ってしまう。
optimizerにGLOBAL_VARIABLESの変数を渡すと、trainable=Falseのvariableも勾配計算してしまう。
等、graphの背景に対しての理解が甘かったようです。

投稿2018/12/27 05:53

Amanokawa

総合スコア41

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

あなたの回答

tips

太字

斜体

打ち消し線

見出し

引用テキストの挿入

コードの挿入

リンクの挿入

リストの挿入

番号リストの挿入

表の挿入

水平線の挿入

プレビュー

15分調べてもわからないことは
teratailで質問しよう!

ただいまの回答率
85.50%

質問をまとめることで
思考を整理して素早く解決

テンプレート機能で
簡単に質問をまとめる

質問する

関連した質問