前提・実現したいこと
tensorflowで関数(defを用いるもの)の中に定義した変数(tf.Variable)を最適化関数を使用せずに更新して、保存(save)、再利用(restore)したい。
発生している問題
バッチの正規化を行う関数(def)を作成してCNNを学習されているのだが、1枚のデータでテストを行う際にバッチの正規化を正常に行うことができない。
is_training=Falseにした際にpop_mean,pop_varが学習によって更新された値(平均のようなもの)を出力するのではなく、初期値を返してしまう。(下のコード参照)
該当のソースコード
Python
1def batch_norm_wrapper(inputs, filter_shape, is_training, decay = 0.999): 2 epsilon = 1e-3 3 out_channels = filter_shape[3] 4 scale = tf.Variable(tf.ones([out_channels])) 5 beta = tf.Variable(tf.zeros([out_channels])) 6 pop_mean = tf.Variable(tf.zeros([out_channels]), trainable=False) 7 pop_var = tf.Variable(tf.ones([out_channels]), trainable=False) 8 9 if is_training: 10 batch_mean, batch_var = tf.nn.moments(inputs,axes=[0,1,2]) 11 train_mean = tf.assign(pop_mean, 12 pop_mean * decay + batch_mean * (1 - decay)) 13 train_var = tf.assign(pop_var, 14 pop_var * decay + batch_var * (1 - decay)) 15 with tf.control_dependencies([train_mean, train_var]): 16 return tf.nn.batch_normalization(inputs, 17 batch_mean, batch_var, beta, scale, epsilon) 18 else: 19 return tf.nn.batch_normalization(inputs, 20 pop_mean, pop_var, beta, scale, epsilon)
Python
1 # ============================save============================ 2 # var_to_start = tf.trainable_variables() 3 # saver = tf.train.Saver(var_to_start) 4 # saver = tf.train.Saver(tf.global_variables()) 5 # saver = tf.train.Saver() 6 var_all = tf.all_variables() 7 saver = tf.train.Saver(var_all) 8 save_path = saver.save(sess, CKPT_PATH + '/model.ckpt', global_step=100) 9 print("Model saved in file: %s"%save_path) 10 k += 1
Python
1# ============================restore->test============================ 2init = tf.global_variables_initializer() 3sess = tf.InteractiveSession() 4sess.run(init) 5# check 6# var_to_start = tf.trainable_variables() 7# saver = tf.train.Saver(var_to_start) 8# saver = tf.train.Saver(tf.global_variables()) 9var_all = tf.all_variables() 10saver = tf.train.Saver(var_all) 11# saver = tf.train.Saver() 12 13ckpt = tf.train.get_checkpoint_state(CKPT_PATH + '/') 14if ckpt: 15 last_model = ckpt.model_checkpoint_path 16 print("load " + last_model) 17 saver.restore(sess, last_model) 18 print("Model restored.") 19 # print(sess.run(beta1)) 20 21 print("test accuracy %g" % accuracy.eval(feed_dict={ 22 x: data.test.images, y_: data.test.labels, keep_prob: 1.0})) 23 # print(W_conv1.eval()) 24else: 25 sess.run(init)
補足情報(FW/ツールのバージョンなど)
参考URL
https://tyfkda.github.io/blog/2016/09/14/batch-norm-mnist.html
https://r2rt.com/implementing-batch-normalization-in-tensorflow.html
回答2件
あなたの回答
tips
プレビュー