質問をすることでしか得られない、回答やアドバイスがある。

15分調べてもわからないことは、質問しよう!

ただいまの
回答率

91.05%

  • Python

    5154questions

    Pythonは、コードの読みやすさが特徴的なプログラミング言語の1つです。 強い型付け、動的型付けに対応しており、後方互換性がないバージョン2系とバージョン3系が使用されています。 商用製品の開発にも無料で使用でき、OSだけでなく仮想環境にも対応。Unicodeによる文字列操作をサポートしているため、日本語処理も標準で可能です。

  • TensorFlow

    450questions

tensorflow1.4.0での変数のrestoreについて

解決済

回答 1

投稿 編集

  • 評価
  • クリップ 1
  • VIEW 212

KEEL

score 1

Windows10, Anaconda上でtensorflow1.4.0を用いて学習済みモデルの読み込みをしたいと考えています。そこでMNISTのコードを改変しましたが、現状、変数がrestoreされず困っています。

変数をsaveするコード

このコードではrestoreができているようです。

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

# データ読み込み
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

# placeholder用意 xは学習用画像
x = tf.placeholder(tf.float32, [None, 784])
# y_は学習用ラベル
y_ = tf.placeholder(tf.float32, [None, 10])

# weightとbias
# さっきの例ではw * xだったけど、今回はw * x + b
W = tf.Variable(tf.zeros([784, 10]),name='W')
b = tf.Variable(tf.zeros([10]),name='b')

# Softmax Regressionを使う yはモデル
y = tf.nn.softmax(tf.matmul(x, W) + b)

# 交差エントロピー
cross_entropy = -tf.reduce_sum(y_ * tf.log(y))

# 先ほど使ったGradientDescentOptimizerで、今回はcross_entropyを利用
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)

model_path = './model/model.ckpt'
  # 初期化
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)

# 学習
for i in range(1000):
    batch_xs, batch_ys = mnist.train.next_batch(100)
    sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
    if i == 999:    
        saver = tf.train.Saver()
        saver.save(sess,model_path)
        print("model saved" + model_path)
        print(sess.run(W))
        print(sess.run(b))

# テストデータで予測
correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
acc = sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels})
print('accuracy : '+str(acc));
sess.close()

sess = tf.InteractiveSession()
init = tf.global_variables_initializer()
sess.run(init)

saver.restore(sess,model_path)
print(sess.run(W))
print(sess.run(b))
correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
acc = sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels})
print('accuracy : '+str(acc))

実行時コンソール出力
この時Wとbの値を表示させておりどちらも同じ値が確認できます。

Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
model saved./model/model.ckpt
[[ 0.  0.  0. ...,  0.  0.  0.]
 [ 0.  0.  0. ...,  0.  0.  0.]
 [ 0.  0.  0. ...,  0.  0.  0.]
 ..., 
 [ 0.  0.  0. ...,  0.  0.  0.]
 [ 0.  0.  0. ...,  0.  0.  0.]
 [ 0.  0.  0. ...,  0.  0.  0.]]
[-0.61506635  0.45124692  0.21486145 -0.35527256 -0.03915499  1.95576119
 -0.17007764  0.89766943 -1.96336353 -0.37660047]
accuracy : 0.9153
INFO:tensorflow:Restoring parameters from ./model/model.ckpt
[[ 0.  0.  0. ...,  0.  0.  0.]
 [ 0.  0.  0. ...,  0.  0.  0.]
 [ 0.  0.  0. ...,  0.  0.  0.]
 ..., 
 [ 0.  0.  0. ...,  0.  0.  0.]
 [ 0.  0.  0. ...,  0.  0.  0.]
 [ 0.  0.  0. ...,  0.  0.  0.]]
[-0.61506635  0.45124692  0.21486145 -0.35527256 -0.03915499  1.95576119
 -0.17007764  0.89766943 -1.96336353 -0.37660047]
accuracy : 0.9153

変数をrestoreするコード

このコードでは変数のrestoreが出来ていないようです。

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

x = tf.placeholder(tf.float32, [None, 784])
y_ = tf.placeholder(tf.float32, [None, 10])

W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))

y = tf.nn.softmax(tf.matmul(x, W) + b)

init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)

saver = tf.train.import_meta_graph('./model/model.ckpt.meta')
model_path = './model/model.ckpt'

ckpt = tf.train.get_checkpoint_state('./model/')
print(ckpt)

saver.restore(sess,model_path)
print(sess.run(W))
print(sess.run(b))
correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
acc = sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels})
print('accuracy : '+str(acc));

実行時コンソール出力
こちらでもWとbの値を表示させていますがbの値が正しく読み込まれていないようです。

Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
model_checkpoint_path: "./model/model.ckpt"
all_model_checkpoint_paths: "./model/model.ckpt"

INFO:tensorflow:Restoring parameters from ./model/model.ckpt
accuracy : 0.098
[[ 0.  0.  0. ...,  0.  0.  0.]
 [ 0.  0.  0. ...,  0.  0.  0.]
 [ 0.  0.  0. ...,  0.  0.  0.]
 ..., 
 [ 0.  0.  0. ...,  0.  0.  0.]
 [ 0.  0.  0. ...,  0.  0.  0.]
 [ 0.  0.  0. ...,  0.  0.  0.]]
[ 0.  0.  0.  0.  0.  0.  0.  0.  0.  0.]

学習済みモデルの読み込みをし、正解率を表示させたいのですがsaver.restore(sess,model_path)の命令で変数が読み込まれると解釈しましたが、コンソールの出力を見るとbの値が読み込めていないように見受けられます。
そこでなんらかの方法で学習済みモデルを復元する方法を教えていただきたいです。
まだ勉強を始めて日が浅く、分からないことが多いため初歩的な質問で申し訳ないですが回答をいただけると幸いです。

  • 気になる質問をクリップする

    クリップした質問は、後からいつでもマイページで確認できます。

    またクリップした質問に回答があった際、通知やメールを受け取ることができます。

    クリップを取り消します

  • 良い質問の評価を上げる

    以下のような質問は評価を上げましょう

    • 質問内容が明確
    • 自分も答えを知りたい
    • 質問者以外のユーザにも役立つ

    評価が高い質問は、TOPページの「注目」タブのフィードに表示されやすくなります。

    質問の評価を上げたことを取り消します

  • 評価を下げられる数の上限に達しました

    評価を下げることができません

    • 1日5回まで評価を下げられます
    • 1日に1ユーザに対して2回まで評価を下げられます

    質問の評価を下げる

    teratailでは下記のような質問を「具体的に困っていることがない質問」、「サイトポリシーに違反する質問」と定義し、推奨していません。

    • プログラミングに関係のない質問
    • やってほしいことだけを記載した丸投げの質問
    • 問題・課題が含まれていない質問
    • 意図的に内容が抹消された質問
    • 広告と受け取られるような投稿

    評価が下がると、TOPページの「アクティブ」「注目」タブのフィードに表示されにくくなります。

    質問の評価を下げたことを取り消します

    この機能は開放されていません

    評価を下げる条件を満たしてません

    評価を下げる理由を選択してください

    詳細な説明はこちら

    上記に当てはまらず、質問内容が明確になっていない質問には「情報の追加・修正依頼」機能からコメントをしてください。

    質問の評価を下げる機能の利用条件

    この機能を利用するためには、以下の事項を行う必要があります。

質問への追記・修正、ベストアンサー選択の依頼

  • 退会済みユーザー

    退会済みユーザー

    2018/01/04 03:11 編集

    変数をsaveするコード-->restore OK、変数をrestoreするコード-->restore NGのようですが、「変数をrestoreするコード」のように改変したところ、値がきちんとロードされなくなったということでしょうか?もう少し、○○をしたくて、××をしたけれども、どこの部分(どの行で)どうならなくて困っている、という書き方をした方が回答が得られやすいと思います。

    キャンセル

  • KEEL

    2018/01/04 10:56

    アドバイスありがとうございますご指摘の通りです

    キャンセル

回答 1

checkベストアンサー

+1

saver = tf.train.import_meta_graph(...)という書き方を初めて見ました。
試していないのでいけませんが、ここが怪しい気がします。

気になって調べると、以下の関係がある事が分かります。

操作対象 書き出し 読み出し
セッション saver.save(...) saver.restore
メタグラフ export_meta_graph import_meta_graph

問題はセッションの復旧だと思いますので、以下のようにして、メタグラフのコードをバッサリ削ってはいかがでしょうか?

# saver = tf.train.import_meta_graph('./model/model.ckpt.meta')
saver = tf.train.Saver()
model_path = './model/model.ckpt'

ckpt = tf.train.get_checkpoint_state('./model/')
print(ckpt)

saver.restore(sess,model_path)

2018-01-04 16:55 追記

書き替えを適用するとこんな感じです。
確認のため学習を10STEPくらいで止めたものを保存しました。
結果、Accuracyは0.90弱でしたが、ロード後(以下のコードです)で学習を飛ばして判定させても結果は0.90弱になりました。

ロード後のイメージ写真を貼りますね。
イメージ説明

# coding: UTF-8
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

# データ読み込み
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

# placeholder用意 xは学習用画像
x = tf.placeholder(tf.float32, [None, 784])
# y_は学習用ラベル
y_ = tf.placeholder(tf.float32, [None, 10])

# weightとbias
# さっきの例ではw * xだったけど、今回はw * x + b
W = tf.Variable(tf.zeros([784, 10]),name='W')
b = tf.Variable(tf.zeros([10]),name='b')

# Softmax Regressionを使う yはモデル
y = tf.nn.softmax(tf.matmul(x, W) + b)

# 交差エントロピー
cross_entropy = -tf.reduce_sum(y_ * tf.log(y))

# 先ほど使ったGradientDescentOptimizerで、今回はcross_entropyを利用
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)


# 初期化に関するコードの書き換え
# -------------------------------------------

# model_path = './model/model.ckpt'
#  # 初期化
#init = tf.global_variables_initializer()
#sess = tf.Session()
#sess.run(init)

model_path = './model/model.ckpt'
# 初期化
init = tf.global_variables_initializer()
sess = tf.Session()
saver = tf.train.Saver()
saver.restore(sess,model_path)
sess.run(init)
# -------------------------------------------


# データ読み込み確認のため学習はコメントアウト
# -------------------------------------------
# 学習
#for i in range(1000):
#    print(i)
#    batch_xs, batch_ys = mnist.train.next_batch(100)
#    sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
#    if i == 99:    
#        saver = tf.train.Saver()
#        saver.save(sess,model_path)
#        print("model saved" + model_path)
#        print(sess.run(W))
#        print(sess.run(b))
# -------------------------------------------

# テストデータで予測
correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
acc = sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels})
print('accuracy : '+str(acc));
sess.close()

sess = tf.InteractiveSession()
init = tf.global_variables_initializer()
sess.run(init)

saver.restore(sess,model_path)
print(sess.run(W))
print(sess.run(b))
correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
acc = sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels})
print('accuracy : '+str(acc))

投稿

編集

  • 回答の評価を上げる

    以下のような回答は評価を上げましょう

    • 正しい回答
    • わかりやすい回答
    • ためになる回答

    評価が高い回答ほどページの上位に表示されます。

  • 回答の評価を下げる

    下記のような回答は推奨されていません。

    • 間違っている回答
    • 質問の回答になっていない投稿
    • スパムや攻撃的な表現を用いた投稿

    評価を下げる際はその理由を明確に伝え、適切な回答に修正してもらいましょう。

  • 2018/01/04 14:49

    回答ありがとうございます
    ご指摘の通り'saver = tf.train.import_meta_graph('./model/model.ckpt.meta')'の部分を'saver = tf.train.Saver()'としたところ

    NotFoundError: Key Variable_4 not found in checkpoint
    [[Node: save_4/RestoreV2_4 = RestoreV2[dtypes=[DT_FLOAT], _device="/job:localhost/replica:0/task:0/device:CPU:0"](_arg_save_4/Const_0_0, save_4/RestoreV2_4/tensor_names, save_4/RestoreV2_4/shape_and_slices)]]

    といったエラーが出るようになりました。
    このエラーの解決法を調べるとモデルの構築前に'tf.reset_default_graph()'を挟むと書いてあったのでこれをモデルの構築前に追加し実行するとエラーは出なくなりましたが、実行結果は上述のようにWもbも全て0となってしまいました。

    キャンセル

  • 2018/01/04 16:50

    回答を書き換えますね

    キャンセル

  • 2018/01/04 17:00

    書き替え後のコードの「# テストデータで予測」のあたりが、先の読み込み処理と重複していますが、元のコードからあまり変えすぎると分かりにくくなると思って、そのまま直さずにおいてあります。実装する際には重複する部分は適当に修正ください。

    キャンセル

  • 2018/01/04 20:55

    実行できました
    ありがとうございました!

    キャンセル

15分調べてもわからないことは、teratailで質問しよう!

  • ただいまの回答率 91.05%
  • 質問をまとめることで、思考を整理して素早く解決
  • テンプレート機能で、簡単に質問をまとめられる

関連した質問

同じタグがついた質問を見る

  • Python

    5154questions

    Pythonは、コードの読みやすさが特徴的なプログラミング言語の1つです。 強い型付け、動的型付けに対応しており、後方互換性がないバージョン2系とバージョン3系が使用されています。 商用製品の開発にも無料で使用でき、OSだけでなく仮想環境にも対応。Unicodeによる文字列操作をサポートしているため、日本語処理も標準で可能です。

  • TensorFlow

    450questions