質問をすることでしか得られない、回答やアドバイスがある。

15分調べてもわからないことは、質問しよう!

新規登録して質問してみよう
ただいま回答率
85.48%
Python

Pythonは、コードの読みやすさが特徴的なプログラミング言語の1つです。 強い型付け、動的型付けに対応しており、後方互換性がないバージョン2系とバージョン3系が使用されています。 商用製品の開発にも無料で使用でき、OSだけでなく仮想環境にも対応。Unicodeによる文字列操作をサポートしているため、日本語処理も標準で可能です。

Q&A

解決済

1回答

3758閲覧

TensorFlowのsaver.restoreについて

TakakiKuwabara

総合スコア38

Python

Pythonは、コードの読みやすさが特徴的なプログラミング言語の1つです。 強い型付け、動的型付けに対応しており、後方互換性がないバージョン2系とバージョン3系が使用されています。 商用製品の開発にも無料で使用でき、OSだけでなく仮想環境にも対応。Unicodeによる文字列操作をサポートしているため、日本語処理も標準で可能です。

0グッド

0クリップ

投稿2017/09/12 11:04

Python3のTensorFlowにて、下記のコードによりエラーが発生しました。
エラーの原因はtf.train.Saverにあるようなのですが、調べたところTensorFlowのバージョンが悪いのやら、
cifar10.pyの一部を書き換えが必要なのやら、何が正しいのかわからない状態です。

どなたかわかる方がいらっしゃれば、お教えいただけないでしょうか。

Python3

1import sys 2import tensorflow as tf 3import cifar10 4 5FLAGS = tf.app.flags.FLAGS 6cifar10.NUM_CLASSES = 7 7tf.app.flags.DEFINE_string('checkpoint_dir', '/tmp/cifar10_train', 8 """Directory where to read model checkpoints.""") 9 10def evaluate(filename): 11 with tf.Graph().as_default() as g: 12 file = tf.read_file(filename) 13 image = tf.image.decode_jpeg(file, channels = 3) 14 image = tf.image.resize_images(image, [32, 32]) 15 # cifar10は内部処理で32×32を24×24に切り出して利用している 16 image = tf.image.resize_image_with_crop_or_pad(image, 24, 24) 17 logits = cifar10.inference([image]) 18 19 # ここのkの値もクラス数と一致させる 20 top_k_op = tf.nn.top_k(logits, k=7) 21 22 variable_averages = tf.train.ExponentialMovingAverage(cifar10.MOVING_AVERAGE_DECAY) 23 variables_to_restore = variable_averages.variables_to_restore() 24 saver = tf.train.Saver(variables_to_restore) 25 26 sess = tf.InteractiveSession() 27 sess.run(tf.initialize_all_variables()) 28 ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir) 29 30 if ckpt and ckpt.model_checkpoint_path: 31 saver.restore(sess, ckpt.model_checkpoint_path) 32 33 else: 34 print('No checkpoint file found') 35 return 36 37 tf.train.start_queue_runners(sess=sess) 38 values, indices = sess.run(top_k_op) 39 ratio = sess.run(tf.nn.softmax(values[0])) 40 41 # 予想したラベルとそれぞれに対する確信度 42 print(indices[0]) 43 print(ratio) 44 45def main(argv=None): 46 evaluate(sys.argv[1]) 47 48if __name__ == '__main__': 49 tf.app.run()

エラー内容は下記の通りです。

Python3

1Traceback (most recent call last): 2 File "cifar10_sample.py", line 60, in <module> 3 tf.app.run() 4 File "/Users/k_aki86/.pyenv/versions/3.5.1/lib/python3.5/site-packages/tensorflow/python/platform/app.py", line 44, in run 5 _sys.exit(main(_sys.argv[:1] + flags_passthrough)) 6 File "cifar10_sample.py", line 56, in main 7 evaluate(sys.argv[1]) 8 File "cifar10_sample.py", line 40, in evaluate 9 saver.restore(sess, ckpt.model_checkpoint_path) 10 File "/Users/k_aki86/.pyenv/versions/3.5.1/lib/python3.5/site-packages/tensorflow/python/training/saver.py", line 1439, in restore 11 {self.saver_def.filename_tensor_name: save_path}) 12 File "/Users/k_aki86/.pyenv/versions/3.5.1/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 767, in run 13 run_metadata_ptr) 14 File "/Users/k_aki86/.pyenv/versions/3.5.1/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 965, in _run 15 feed_dict_string, options, run_metadata) 16 File "/Users/k_aki86/.pyenv/versions/3.5.1/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1015, in _do_run 17 target_list, options, run_metadata) 18 File "/Users/k_aki86/.pyenv/versions/3.5.1/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1035, in _do_call 19 raise type(e)(node_def, op, message) 20tensorflow.python.framework.errors_impl.InvalidArgumentError: Assign requires shapes of both tensors to match. lhs shape= [192,7] rhs shape= [192,10] 21 [[Node: save/Assign_9 = Assign[T=DT_FLOAT, _class=["loc:@softmax_linear/weights"], use_locking=true, validate_shape=true, _device="/job:localhost/replica:0/task:0/cpu:0"](softmax_linear/weights, save/RestoreV2_9)]] 22 23Caused by op 'save/Assign_9', defined at: 24 File "cifar10_sample.py", line 60, in <module> 25 tf.app.run() 26 File "/Users/k_aki86/.pyenv/versions/3.5.1/lib/python3.5/site-packages/tensorflow/python/platform/app.py", line 44, in run 27 _sys.exit(main(_sys.argv[:1] + flags_passthrough)) 28 File "cifar10_sample.py", line 56, in main 29 evaluate(sys.argv[1]) 30 File "cifar10_sample.py", line 33, in evaluate 31 saver = tf.train.Saver(variables_to_restore) 32 File "/Users/k_aki86/.pyenv/versions/3.5.1/lib/python3.5/site-packages/tensorflow/python/training/saver.py", line 1051, in __init__ 33 self.build() 34 File "/Users/k_aki86/.pyenv/versions/3.5.1/lib/python3.5/site-packages/tensorflow/python/training/saver.py", line 1081, in build 35 restore_sequentially=self._restore_sequentially) 36 File "/Users/k_aki86/.pyenv/versions/3.5.1/lib/python3.5/site-packages/tensorflow/python/training/saver.py", line 675, in build 37 restore_sequentially, reshape) 38 File "/Users/k_aki86/.pyenv/versions/3.5.1/lib/python3.5/site-packages/tensorflow/python/training/saver.py", line 414, in _AddRestoreOps 39 assign_ops.append(saveable.restore(tensors, shapes)) 40 File "/Users/k_aki86/.pyenv/versions/3.5.1/lib/python3.5/site-packages/tensorflow/python/training/saver.py", line 155, in restore 41 self.op.get_shape().is_fully_defined()) 42 File "/Users/k_aki86/.pyenv/versions/3.5.1/lib/python3.5/site-packages/tensorflow/python/ops/gen_state_ops.py", line 47, in assign 43 use_locking=use_locking, name=name) 44 File "/Users/k_aki86/.pyenv/versions/3.5.1/lib/python3.5/site-packages/tensorflow/python/framework/op_def_library.py", line 763, in apply_op 45 op_def=op_def) 46 File "/Users/k_aki86/.pyenv/versions/3.5.1/lib/python3.5/site-packages/tensorflow/python/framework/ops.py", line 2395, in create_op 47 original_op=self._default_original_op, op_def=op_def) 48 File "/Users/k_aki86/.pyenv/versions/3.5.1/lib/python3.5/site-packages/tensorflow/python/framework/ops.py", line 1264, in __init__ 49 self._traceback = _extract_stack() 50 51InvalidArgumentError (see above for traceback): Assign requires shapes of both tensors to match. lhs shape= [192,7] rhs shape= [192,10] 52 [[Node: save/Assign_9 = Assign[T=DT_FLOAT, _class=["loc:@softmax_linear/weights"], use_locking=true, validate_shape=true, _device="/job:localhost/replica:0/task:0/cpu:0"](softmax_linear/weights, save/RestoreV2_9)]]

気になる質問をクリップする

クリップした質問は、後からいつでもMYページで確認できます。

またクリップした質問に回答があった際、通知やメールを受け取ることができます。

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

guest

回答1

0

自己解決

投稿2017/09/13 09:24

TakakiKuwabara

総合スコア38

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

あなたの回答

tips

太字

斜体

打ち消し線

見出し

引用テキストの挿入

コードの挿入

リンクの挿入

リストの挿入

番号リストの挿入

表の挿入

水平線の挿入

プレビュー

15分調べてもわからないことは
teratailで質問しよう!

ただいまの回答率
85.48%

質問をまとめることで
思考を整理して素早く解決

テンプレート機能で
簡単に質問をまとめる

質問する

関連した質問