質問をすることでしか得られない、回答やアドバイスがある。

15分調べてもわからないことは、質問しよう!

新規登録して質問してみよう
ただいま回答率
85.48%
Python

Pythonは、コードの読みやすさが特徴的なプログラミング言語の1つです。 強い型付け、動的型付けに対応しており、後方互換性がないバージョン2系とバージョン3系が使用されています。 商用製品の開発にも無料で使用でき、OSだけでなく仮想環境にも対応。Unicodeによる文字列操作をサポートしているため、日本語処理も標準で可能です。

Q&A

解決済

1回答

3839閲覧

tensorflowで保存したモデルの復元が失敗します

pacifinapacific

総合スコア14

Python

Pythonは、コードの読みやすさが特徴的なプログラミング言語の1つです。 強い型付け、動的型付けに対応しており、後方互換性がないバージョン2系とバージョン3系が使用されています。 商用製品の開発にも無料で使用でき、OSだけでなく仮想環境にも対応。Unicodeによる文字列操作をサポートしているため、日本語処理も標準で可能です。

0グッド

0クリップ

投稿2018/03/06 14:32

Tensorflow で保存したモデルの読み込みをしたいです

tensorflowで画像よみこんでVGG-16で画像認識を試してみたのですが、保存するモデルを復元する際にエラーが出てしまっています。tf.train.Saverのタイミングがおかしかったりするのでしょうか。モデル保存のタイミングがつかめていません

発生している問題・エラーメッセージ

--------------------------------------------------------------------------- FailedPreconditionError Traceback (most recent call last) /usr/local/lib/python3.6/dist-packages/tensorflow/python/client/session.py in _do_call(self, fn, *args) 1360 try: -> 1361 return fn(*args) 1362 except errors.OpError as e: /usr/local/lib/python3.6/dist-packages/tensorflow/python/client/session.py in _run_fn(session, feed_dict, fetch_list, target_list, options, run_metadata) 1339 return tf_session.TF_Run(session, options, feed_dict, fetch_list, -> 1340 target_list, status, run_metadata) 1341 /usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/errors_impl.py in __exit__(self, type_arg, value_arg, traceback_arg) 515 compat.as_text(c_api.TF_Message(self.status.status)), --> 516 c_api.TF_GetCode(self.status.status)) 517 # Delete the underlying status object from memory otherwise it stays alive FailedPreconditionError: Attempting to use uninitialized value VGG-16/conv1/w [[Node: VGG-16/conv1/w/read = Identity[T=DT_FLOAT, _class=["loc:@VGG-16/conv1/w"], _device="/job:localhost/replica:0/task:0/device:GPU:0"](VGG-16/conv1/w)]] During handling of the above exception, another exception occurred: FailedPreconditionError Traceback (most recent call last) <ipython-input-3-a297b268c7ce> in <module>() 153 154 if __name__=='__main__': --> 155 main() 156 <ipython-input-3-a297b268c7ce> in main() 117 for j in range(3937): 118 print(j) --> 119 logit=sess.run(test_logits) 120 logit=logit.astype(np.float64) 121 pred=np.argmax(logit,1) /usr/local/lib/python3.6/dist-packages/tensorflow/python/client/session.py in run(self, fetches, feed_dict, options, run_metadata) 903 try: 904 result = self._run(None, fetches, feed_dict, options_ptr, --> 905 run_metadata_ptr) 906 if run_metadata: 907 proto_data = tf_session.TF_GetBuffer(run_metadata_ptr) /usr/local/lib/python3.6/dist-packages/tensorflow/python/client/session.py in _run(self, handle, fetches, feed_dict, options, run_metadata) 1135 if final_fetches or final_targets or (handle and feed_dict_tensor): 1136 results = self._do_run(handle, final_targets, final_fetches, -> 1137 feed_dict_tensor, options, run_metadata) 1138 else: 1139 results = [] /usr/local/lib/python3.6/dist-packages/tensorflow/python/client/session.py in _do_run(self, handle, target_list, fetch_list, feed_dict, options, run_metadata) 1353 if handle is None: 1354 return self._do_call(_run_fn, self._session, feeds, fetches, targets, -> 1355 options, run_metadata) 1356 else: 1357 return self._do_call(_prun_fn, self._session, handle, feeds, fetches) /usr/local/lib/python3.6/dist-packages/tensorflow/python/client/session.py in _do_call(self, fn, *args) 1372 except KeyError: 1373 pass -> 1374 raise type(e)(node_def, op, message) 1375 1376 def _extend_graph(self): FailedPreconditionError: Attempting to use uninitialized value VGG-16/conv1/w [[Node: VGG-16/conv1/w/read = Identity[T=DT_FLOAT, _class=["loc:@VGG-16/conv1/w"], _device="/job:localhost/replica:0/task:0/device:GPU:0"](VGG-16/conv1/w)]]

該当のソースコード

python

1def conv2d(input,output_dim,name,k_h=3,k_w=3,s_h=1,s_w=1,stddev=0.02): 2 3 with tf.variable_scope(name): 4 w=tf.get_variable('w',[k_h,k_w,input.get_shape()[-1],output_dim], 5 initializer=tf.truncated_normal_initializer(stddev=stddev)) 6 biases=tf.get_variable('b',[output_dim],initializer=tf.zeros_initializer()) 7 conv=tf.nn.conv2d(input,w,strides=[1,s_h,s_w,1],padding='SAME') 8 conv=tf.reshape(tf.nn.bias_add(conv,biases),conv.get_shape()) 9 print("ok") 10 11 return conv 12 13def maxpool(input,k_h=2,k_w=2,s_h=2,s_w=2): 14 15 maxpool=tf.nn.max_pool(input,ksize=[1,k_h,k_w,1],strides=[1,s_h,s_w,1],padding='SAME') 16 17 return maxpool 18 19def relu(input): 20 return tf.nn.relu(input) 21 22def dense(input,input_dim,output_dim,name,stddev=0.02): 23 with tf.variable_scope(name): 24 reshape=tf.reshape(input,[input.get_shape()[0].value,-1]) 25 w=tf.get_variable('w',[input_dim,output_dim], 26 initializer=tf.truncated_normal_initializer(stddev=stddev)) 27 biases=tf.get_variable('b',[output_dim], 28 initializer=tf.zeros_initializer()) 29 30 out=tf.add(tf.matmul(reshape,w),biases) 31 32 return out 33 34 35 36def VGG_16(inputs,reuse=False): 37 with tf.variable_scope('VGG-16',reuse=reuse): 38 conv1=relu(conv2d(inputs,64,'conv1')) 39 conv2=relu(conv2d(conv1,64,'conv2')) 40 pool2=maxpool(conv2) 41 42 conv3=relu(conv2d(pool2,128,'conv3')) 43 conv4=relu(conv2d(conv3,128,'conv4')) 44 pool4=maxpool(conv4) 45 46 47 conv5=relu(conv2d(pool4,256,'conv5')) 48 conv6=relu(conv2d(conv5,256,'conv6')) 49 conv7=relu(conv2d(conv6,256,'conv7')) 50 conv8=relu(conv2d(conv7,256,'conv8')) 51 pool8=maxpool(conv8) 52 53 conv9=relu(conv2d(pool8,512,'conv9')) 54 conv10=relu(conv2d(conv9,512,'conv10')) 55 conv11=relu(conv2d(conv10,512,'conv11')) 56 conv12=relu(conv2d(conv11,512,'conv12')) 57 pool12=maxpool(conv12) 58 59 conv13=relu(conv2d(pool12,512,'conv13')) 60 conv14=relu(conv2d(conv13,512,'conv14')) 61 conv15=relu(conv2d(conv14,512,'conv15')) 62 conv16=relu(conv2d(conv15,512,'conv16')) 63 pool16=maxpool(conv16) 64 fc17=dense(pool16,25088,4096,'fc17') 65 fc18=dense(fc17,4096,4096,'fc18') 66 fc19=dense(fc18,4096,55,'fc19') 67 68 return fc19 69 70 71def loss(labels,logits): 72 73 74 75 cross_entropy=tf.nn.softmax_cross_entropy_with_logits( 76 labels=labels, 77 logits=logits) 78 79 return tf.reduce_mean(cross_entropy) 80 81def training(loss): 82 optimizer=tf.train.AdamOptimizer(0.01) 83 84 return optimizer.minimize(loss) 85 86 87 88def main(): 89 90 train_images,train_labels=read_tfrecord_train("drive/dataset.tfrecords") 91 test_images,test_labels=read_tfrecord_test("drive/test2.tfrecords") 92# test_images,test_labels=inputs.test_batch(test_label,test_path) 93 94 train_logits=VGG_16(train_images) 95 test_logits=VGG_16(test_images,reuse=True) 96 losses=loss(train_labels,train_logits) 97 train_op=training(losses) 98 99# predict_op=pred(test_l) 100 101 102 with tf.Session() as sess: 103 104 saver = tf.train.import_meta_graph("./model_pred100000.ckpt-100000.meta") 105 saver.restore(sess, "./model_pred100000.ckpt-100000") 106 flag=0 107 cwd = os.getcwd() #windowsのみ 108 109 ckpt_state=tf.train.get_checkpoint_state(cwd+"//") 110 if ckpt_state: 111 112 flag=1 113 else: 114 sess.run(tf.global_variables_initializer()) 115 116 coord=tf.train.Coordinator() 117 threads=tf.train.start_queue_runners(coord=coord) 118 119 120 121 with tf.name_scope("summary"): 122 tf.summary.scalar('loss',losses) 123 merged=tf.summary.merge_all() 124 writer=tf.summary.FileWriter('./logs',sess.graph) 125 126 127 if flag==1: 128 pred_list=[] 129 for j in range(3937): 130 print(j) 131 logit=sess.run(test_logits) 132 logit=logit.astype(np.float64) 133 pred=np.argmax(logit,1) 134 pred_list.extend(logit.tolist()) 135 pred_df=pd.DataFrame(pred_list) 136 pred_df.to_csv('predict_test.csv') 137 138 139# cwd = os.getcwd() windowsのみ 140 141 if flag==0: 142 for i in range(1000000): 143 144 _, loss_value=sess.run([train_op,losses]) 145 print("step:{:3d}: {:5f}".format(i+1,loss_value)) 146 147 if i%10000==0: 148 saver.save(sess,"model_pred{}.ckpt".format(i),global_step=i) 149 150 coord.request_stop() 151 coord.join(threads) 152 153 154if __name__=='__main__': 155 main() 156

試したこと

補足情報(FW/ツールのバージョンなど)

googleのコラボラトリ上で試しています

気になる質問をクリップする

クリップした質問は、後からいつでもMYページで確認できます。

またクリップした質問に回答があった際、通知やメールを受け取ることができます。

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

guest

回答1

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

pacifinapacific

2018/03/06 16:35

成功しました!モデルを読み込むときはinitializerは必要ないと勘違いしていました。ありがとうございます
guest

あなたの回答

tips

太字

斜体

打ち消し線

見出し

引用テキストの挿入

コードの挿入

リンクの挿入

リストの挿入

番号リストの挿入

表の挿入

水平線の挿入

プレビュー

15分調べてもわからないことは
teratailで質問しよう!

ただいまの回答率
85.48%

質問をまとめることで
思考を整理して素早く解決

テンプレート機能で
簡単に質問をまとめる

質問する

関連した質問