質問をすることでしか得られない、回答やアドバイスがある。

15分調べてもわからないことは、質問しよう!

ただいまの
回答率

89.13%

最適化関数を用いずTensorflowで変数を更新したい

解決済

回答 2

投稿 編集

  • 評価
  • クリップ 0
  • VIEW 1,284

jo-jo-

score 7

 前提・実現したいこと

tensorflowで関数(defを用いるもの)の中に定義した変数(tf.Variable)を最適化関数を使用せずに更新して、保存(save)、再利用(restore)したい。

 発生している問題

バッチの正規化を行う関数(def)を作成してCNNを学習されているのだが、1枚のデータでテストを行う際にバッチの正規化を正常に行うことができない。
is_training=Falseにした際にpop_mean,pop_varが学習によって更新された値(平均のようなもの)を出力するのではなく、初期値を返してしまう。(下のコード参照)

 該当のソースコード

def batch_norm_wrapper(inputs, filter_shape, is_training, decay = 0.999):
    epsilon = 1e-3
    out_channels = filter_shape[3]
    scale = tf.Variable(tf.ones([out_channels]))
    beta = tf.Variable(tf.zeros([out_channels]))
    pop_mean = tf.Variable(tf.zeros([out_channels]), trainable=False)
    pop_var = tf.Variable(tf.ones([out_channels]), trainable=False)

    if is_training:
        batch_mean, batch_var = tf.nn.moments(inputs,axes=[0,1,2])
        train_mean = tf.assign(pop_mean,
                               pop_mean * decay + batch_mean * (1 - decay))
        train_var = tf.assign(pop_var,
                              pop_var * decay + batch_var * (1 - decay))
        with tf.control_dependencies([train_mean, train_var]):
            return tf.nn.batch_normalization(inputs,
                batch_mean, batch_var, beta, scale, epsilon)
    else:
        return tf.nn.batch_normalization(inputs,
            pop_mean, pop_var, beta, scale, epsilon)
    # ============================save============================
    # var_to_start = tf.trainable_variables()
    # saver = tf.train.Saver(var_to_start)
    # saver = tf.train.Saver(tf.global_variables())
    # saver = tf.train.Saver()
    var_all = tf.all_variables()
    saver = tf.train.Saver(var_all)
    save_path = saver.save(sess, CKPT_PATH + '/model.ckpt', global_step=100)
    print("Model saved in file: %s"%save_path)
    k += 1
# ============================restore->test============================
init = tf.global_variables_initializer()
sess = tf.InteractiveSession()
sess.run(init)
# check
# var_to_start = tf.trainable_variables()
# saver = tf.train.Saver(var_to_start)
# saver = tf.train.Saver(tf.global_variables())
var_all = tf.all_variables()
saver = tf.train.Saver(var_all)
# saver = tf.train.Saver()

ckpt = tf.train.get_checkpoint_state(CKPT_PATH + '/')
if ckpt:
    last_model = ckpt.model_checkpoint_path
    print("load " + last_model)
    saver.restore(sess, last_model)
    print("Model restored.")
    # print(sess.run(beta1))

    print("test accuracy %g" % accuracy.eval(feed_dict={
        x: data.test.images, y_: data.test.labels, keep_prob: 1.0}))
    # print(W_conv1.eval())
else:
    sess.run(init)

 補足情報(FW/ツールのバージョンなど)

参考URL
https://tyfkda.github.io/blog/2016/09/14/batch-norm-mnist.html
https://r2rt.com/implementing-batch-normalization-in-tensorflow.html

  • 気になる質問をクリップする

    クリップした質問は、後からいつでもマイページで確認できます。

    またクリップした質問に回答があった際、通知やメールを受け取ることができます。

    クリップを取り消します

  • 良い質問の評価を上げる

    以下のような質問は評価を上げましょう

    • 質問内容が明確
    • 自分も答えを知りたい
    • 質問者以外のユーザにも役立つ

    評価が高い質問は、TOPページの「注目」タブのフィードに表示されやすくなります。

    質問の評価を上げたことを取り消します

  • 評価を下げられる数の上限に達しました

    評価を下げることができません

    • 1日5回まで評価を下げられます
    • 1日に1ユーザに対して2回まで評価を下げられます

    質問の評価を下げる

    teratailでは下記のような質問を「具体的に困っていることがない質問」、「サイトポリシーに違反する質問」と定義し、推奨していません。

    • プログラミングに関係のない質問
    • やってほしいことだけを記載した丸投げの質問
    • 問題・課題が含まれていない質問
    • 意図的に内容が抹消された質問
    • 過去に投稿した質問と同じ内容の質問
    • 広告と受け取られるような投稿

    評価が下がると、TOPページの「アクティブ」「注目」タブのフィードに表示されにくくなります。

    質問の評価を下げたことを取り消します

    この機能は開放されていません

    評価を下げる条件を満たしてません

    評価を下げる理由を選択してください

    詳細な説明はこちら

    上記に当てはまらず、質問内容が明確になっていない質問には「情報の追加・修正依頼」機能からコメントをしてください。

    質問の評価を下げる機能の利用条件

    この機能を利用するためには、以下の事項を行う必要があります。

質問への追記・修正、ベストアンサー選択の依頼

  • mkgrei

    2018/03/27 17:06

    「is_training=Falseにした際」には「pop_mean,pop_var」を更新しないのでは?

    キャンセル

回答 2

check解決した方法

0

以下のようにdefを改良したらsaveできるようになりました。

# conv(relu):init
def conv_layer(inpt, filter_shape, stride):
    out_channels = filter_shape[3]
    filter_ = weight_variable(filter_shape)
    conv = tf.nn.conv2d(inpt, filter=filter_, strides=[1, stride, stride, 1], padding="SAME")
    decay= tf.constant(0.95, name="delay")

    beta = tf.Variable(tf.zeros([out_channels]), name="beta")
    gamma = weight_variable([out_channels], name="gamma")

    mean, var = tf.nn.moments(conv, axes=[0,1,2])

    pop_mean = tf.Variable(tf.zeros([out_channels]), trainable=False, name="pop_mean")
    pop_var = tf.Variable(tf.ones([out_channels]), trainable=False, name="pop_var")
    pop_mean_update = tf.assign(pop_mean, pop_mean * decay + mean * (1 - decay))
    pop_var_update = tf.assign(pop_var, pop_var * decay + var * (1 - decay))

    with tf.control_dependencies([pop_mean_update, pop_var_update]):
        batch_norm = tf.nn.batch_norm_with_global_normalization(conv, pop_mean_update, pop_var_update, beta, gamma, 0.001,
                                                                scale_after_normalization=True)

    out = leaky_relu(batch_norm, 1e-3)

    return out

投稿

  • 回答の評価を上げる

    以下のような回答は評価を上げましょう

    • 正しい回答
    • わかりやすい回答
    • ためになる回答

    評価が高い回答ほどページの上位に表示されます。

  • 回答の評価を下げる

    下記のような回答は推奨されていません。

    • 間違っている回答
    • 質問の回答になっていない投稿
    • スパムや攻撃的な表現を用いた投稿

    評価を下げる際はその理由を明確に伝え、適切な回答に修正してもらいましょう。

0

https://r2rt.com/implementing-batch-normalization-in-tensorflow.html

こちらのサイトにあるように学習後に重みを保存しておいて、判定の前に読み込んでいます。

#Build training graph, train and save the trained model
sess.close()
tf.reset_default_graph()
(x, y_), train_step, accuracy, _, saver = build_graph(is_training=True)

acc = []
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for i in tqdm.tqdm(range(10000)):
        batch = mnist.train.next_batch(60)
        train_step.run(feed_dict={x: batch[0], y_: batch[1]})
        if i % 50 is 0:
            res = sess.run([accuracy],feed_dict={x: mnist.test.images, y_: mnist.test.labels})
            acc.append(res[0])
    saved_model = saver.save(sess, './temp-bn-save')    #<---------------------------------------

print("Final accuracy:", acc[-1])
tf.reset_default_graph()
(x, y_), _, accuracy, y, saver = build_graph(is_training=False)

predictions = []
correct = 0
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver.restore(sess, './temp-bn-save')   #<---------------------------------------------
    for i in range(100):
        pred, corr = sess.run([tf.arg_max(y,1), accuracy],
                             feed_dict={x: [mnist.test.images[i]], y_: [mnist.test.labels[i]]})
        correct += corr
        predictions.append(pred[0])
print("PREDICTIONS:", predictions)
print("ACCURACY:", correct/100)

投稿

  • 回答の評価を上げる

    以下のような回答は評価を上げましょう

    • 正しい回答
    • わかりやすい回答
    • ためになる回答

    評価が高い回答ほどページの上位に表示されます。

  • 回答の評価を下げる

    下記のような回答は推奨されていません。

    • 間違っている回答
    • 質問の回答になっていない投稿
    • スパムや攻撃的な表現を用いた投稿

    評価を下げる際はその理由を明確に伝え、適切な回答に修正してもらいましょう。

  • 2018/03/27 19:12

    回答ありがとうございます。
    質問がわかりにくく、申し訳ございません。
    最適化関数により更新させた変数はsaveとrestoreによって保持されるのですがpop_varとpop_meanをほじすることができていないという現状です。どのようなことが原因なのでしょうか。

    キャンセル

  • 2018/03/27 19:17

    コードを追加しました。saverが間違っているからでしょうか?

    キャンセル

  • 2018/03/28 16:08

    https://stackoverflow.com/questions/36113090/how-to-get-the-global-step-when-restoring-checkpoints-in-tensorflow

    global_stepのせいだったりしませんか?
    断片的なコードなので確証はありませんが。

    キャンセル

15分調べてもわからないことは、teratailで質問しよう!

  • ただいまの回答率 89.13%
  • 質問をまとめることで、思考を整理して素早く解決
  • テンプレート機能で、簡単に質問をまとめられる