質問をすることでしか得られない、回答やアドバイスがある。

15分調べてもわからないことは、質問しよう!

新規登録して質問してみよう
ただいま回答率
85.50%
深層学習

深層学習は、多数のレイヤのニューラルネットワークによる機械学習手法。人工知能研究の一つでディープラーニングとも呼ばれています。コンピューター自体がデータの潜在的な特徴を汲み取り、効率的で的確な判断を実現することができます。

Python 3.x

Python 3はPythonプログラミング言語の最新バージョンであり、2008年12月3日にリリースされました。

機械学習

機械学習は、データからパターンを自動的に発見し、そこから知能的な判断を下すためのコンピューターアルゴリズムを指します。人工知能における課題のひとつです。

Python

Pythonは、コードの読みやすさが特徴的なプログラミング言語の1つです。 強い型付け、動的型付けに対応しており、後方互換性がないバージョン2系とバージョン3系が使用されています。 商用製品の開発にも無料で使用でき、OSだけでなく仮想環境にも対応。Unicodeによる文字列操作をサポートしているため、日本語処理も標準で可能です。

Q&A

解決済

2回答

300閲覧

最適化関数を用いずTensorflowで変数を更新したい

jo-jo-

総合スコア7

深層学習

深層学習は、多数のレイヤのニューラルネットワークによる機械学習手法。人工知能研究の一つでディープラーニングとも呼ばれています。コンピューター自体がデータの潜在的な特徴を汲み取り、効率的で的確な判断を実現することができます。

Python 3.x

Python 3はPythonプログラミング言語の最新バージョンであり、2008年12月3日にリリースされました。

機械学習

機械学習は、データからパターンを自動的に発見し、そこから知能的な判断を下すためのコンピューターアルゴリズムを指します。人工知能における課題のひとつです。

Python

Pythonは、コードの読みやすさが特徴的なプログラミング言語の1つです。 強い型付け、動的型付けに対応しており、後方互換性がないバージョン2系とバージョン3系が使用されています。 商用製品の開発にも無料で使用でき、OSだけでなく仮想環境にも対応。Unicodeによる文字列操作をサポートしているため、日本語処理も標準で可能です。

1グッド

0クリップ

投稿2018/03/26 18:32

編集2018/03/27 10:16

前提・実現したいこと

tensorflowで関数(defを用いるもの)の中に定義した変数(tf.Variable)を最適化関数を使用せずに更新して、保存(save)、再利用(restore)したい。

発生している問題

バッチの正規化を行う関数(def)を作成してCNNを学習されているのだが、1枚のデータでテストを行う際にバッチの正規化を正常に行うことができない。
is_training=Falseにした際にpop_mean,pop_varが学習によって更新された値(平均のようなもの)を出力するのではなく、初期値を返してしまう。(下のコード参照)

該当のソースコード

Python

1def batch_norm_wrapper(inputs, filter_shape, is_training, decay = 0.999): 2 epsilon = 1e-3 3 out_channels = filter_shape[3] 4 scale = tf.Variable(tf.ones([out_channels])) 5 beta = tf.Variable(tf.zeros([out_channels])) 6 pop_mean = tf.Variable(tf.zeros([out_channels]), trainable=False) 7 pop_var = tf.Variable(tf.ones([out_channels]), trainable=False) 8 9 if is_training: 10 batch_mean, batch_var = tf.nn.moments(inputs,axes=[0,1,2]) 11 train_mean = tf.assign(pop_mean, 12 pop_mean * decay + batch_mean * (1 - decay)) 13 train_var = tf.assign(pop_var, 14 pop_var * decay + batch_var * (1 - decay)) 15 with tf.control_dependencies([train_mean, train_var]): 16 return tf.nn.batch_normalization(inputs, 17 batch_mean, batch_var, beta, scale, epsilon) 18 else: 19 return tf.nn.batch_normalization(inputs, 20 pop_mean, pop_var, beta, scale, epsilon)

Python

1 # ============================save============================ 2 # var_to_start = tf.trainable_variables() 3 # saver = tf.train.Saver(var_to_start) 4 # saver = tf.train.Saver(tf.global_variables()) 5 # saver = tf.train.Saver() 6 var_all = tf.all_variables() 7 saver = tf.train.Saver(var_all) 8 save_path = saver.save(sess, CKPT_PATH + '/model.ckpt', global_step=100) 9 print("Model saved in file: %s"%save_path) 10 k += 1

Python

1# ============================restore->test============================ 2init = tf.global_variables_initializer() 3sess = tf.InteractiveSession() 4sess.run(init) 5# check 6# var_to_start = tf.trainable_variables() 7# saver = tf.train.Saver(var_to_start) 8# saver = tf.train.Saver(tf.global_variables()) 9var_all = tf.all_variables() 10saver = tf.train.Saver(var_all) 11# saver = tf.train.Saver() 12 13ckpt = tf.train.get_checkpoint_state(CKPT_PATH + '/') 14if ckpt: 15 last_model = ckpt.model_checkpoint_path 16 print("load " + last_model) 17 saver.restore(sess, last_model) 18 print("Model restored.") 19 # print(sess.run(beta1)) 20 21 print("test accuracy %g" % accuracy.eval(feed_dict={ 22 x: data.test.images, y_: data.test.labels, keep_prob: 1.0})) 23 # print(W_conv1.eval()) 24else: 25 sess.run(init)

補足情報(FW/ツールのバージョンなど)

参考URL
https://tyfkda.github.io/blog/2016/09/14/batch-norm-mnist.html
https://r2rt.com/implementing-batch-normalization-in-tensorflow.html

tachikoma👍を押しています

気になる質問をクリップする

クリップした質問は、後からいつでもMYページで確認できます。

またクリップした質問に回答があった際、通知やメールを受け取ることができます。

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

mkgrei

2018/03/27 08:06

「is_training=Falseにした際」には「pop_mean,pop_var」を更新しないのでは?
guest

回答2

0

自己解決

以下のようにdefを改良したらsaveできるようになりました。

Python

1# conv(relu):init 2def conv_layer(inpt, filter_shape, stride): 3 out_channels = filter_shape[3] 4 filter_ = weight_variable(filter_shape) 5 conv = tf.nn.conv2d(inpt, filter=filter_, strides=[1, stride, stride, 1], padding="SAME") 6 decay= tf.constant(0.95, name="delay") 7 8 beta = tf.Variable(tf.zeros([out_channels]), name="beta") 9 gamma = weight_variable([out_channels], name="gamma") 10 11 mean, var = tf.nn.moments(conv, axes=[0,1,2]) 12 13 pop_mean = tf.Variable(tf.zeros([out_channels]), trainable=False, name="pop_mean") 14 pop_var = tf.Variable(tf.ones([out_channels]), trainable=False, name="pop_var") 15 pop_mean_update = tf.assign(pop_mean, pop_mean * decay + mean * (1 - decay)) 16 pop_var_update = tf.assign(pop_var, pop_var * decay + var * (1 - decay)) 17 18 with tf.control_dependencies([pop_mean_update, pop_var_update]): 19 batch_norm = tf.nn.batch_norm_with_global_normalization(conv, pop_mean_update, pop_var_update, beta, gamma, 0.001, 20 scale_after_normalization=True) 21 22 out = leaky_relu(batch_norm, 1e-3) 23 24 return out

投稿2018/04/06 02:22

jo-jo-

総合スコア7

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

0

https://r2rt.com/implementing-batch-normalization-in-tensorflow.html

こちらのサイトにあるように学習後に重みを保存しておいて、判定の前に読み込んでいます。

python

1#Build training graph, train and save the trained model 2sess.close() 3tf.reset_default_graph() 4(x, y_), train_step, accuracy, _, saver = build_graph(is_training=True) 5 6acc = [] 7with tf.Session() as sess: 8 sess.run(tf.global_variables_initializer()) 9 for i in tqdm.tqdm(range(10000)): 10 batch = mnist.train.next_batch(60) 11 train_step.run(feed_dict={x: batch[0], y_: batch[1]}) 12 if i % 50 is 0: 13 res = sess.run([accuracy],feed_dict={x: mnist.test.images, y_: mnist.test.labels}) 14 acc.append(res[0]) 15 saved_model = saver.save(sess, './temp-bn-save') #<--------------------------------------- 16 17print("Final accuracy:", acc[-1])

python

1tf.reset_default_graph() 2(x, y_), _, accuracy, y, saver = build_graph(is_training=False) 3 4predictions = [] 5correct = 0 6with tf.Session() as sess: 7 sess.run(tf.global_variables_initializer()) 8 saver.restore(sess, './temp-bn-save') #<--------------------------------------------- 9 for i in range(100): 10 pred, corr = sess.run([tf.arg_max(y,1), accuracy], 11 feed_dict={x: [mnist.test.images[i]], y_: [mnist.test.labels[i]]}) 12 correct += corr 13 predictions.append(pred[0]) 14print("PREDICTIONS:", predictions) 15print("ACCURACY:", correct/100)

投稿2018/03/27 08:50

mkgrei

総合スコア8560

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

jo-jo-

2018/03/27 10:12

回答ありがとうございます。 質問がわかりにくく、申し訳ございません。 最適化関数により更新させた変数はsaveとrestoreによって保持されるのですがpop_varとpop_meanをほじすることができていないという現状です。どのようなことが原因なのでしょうか。
jo-jo-

2018/03/27 10:17

コードを追加しました。saverが間違っているからでしょうか?
guest

あなたの回答

tips

太字

斜体

打ち消し線

見出し

引用テキストの挿入

コードの挿入

リンクの挿入

リストの挿入

番号リストの挿入

表の挿入

水平線の挿入

プレビュー

15分調べてもわからないことは
teratailで質問しよう!

ただいまの回答率
85.50%

質問をまとめることで
思考を整理して素早く解決

テンプレート機能で
簡単に質問をまとめる

質問する

関連した質問