質問をすることでしか得られない、回答やアドバイスがある。

15分調べてもわからないことは、質問しよう!

新規登録して質問してみよう
ただいま回答率
85.48%

Q&A

0回答

333閲覧

tensorflowの停止条件が分からない

kokawa2003

総合スコア217

0グッド

2クリップ

投稿2019/09/17 00:18

編集2019/09/17 02:45

tensorflowのビギナーです
適当な面白そうな奴を見つけて走らせればどうか?とおもったので
https://github.com/vanhuyz/CycleGAN-TensorFlow
を走らせてみました。
これトレーニングのソースを見ると

python

1import tensorflow as tf 2from model import CycleGAN 3from reader import Reader 4from datetime import datetime 5import os 6import logging 7from utils import ImagePool 8 9FLAGS = tf.flags.FLAGS 10 11tf.flags.DEFINE_integer('batch_size', 1, 'batch size, default: 1') 12tf.flags.DEFINE_integer('image_size', 256, 'image size, default: 256') 13tf.flags.DEFINE_bool('use_lsgan', True, 14 'use lsgan (mean squared error) or cross entropy loss, default: True') 15tf.flags.DEFINE_string('norm', 'instance', 16 '[instance, batch] use instance norm or batch norm, default: instance') 17tf.flags.DEFINE_integer('lambda1', 10, 18 'weight for forward cycle loss (X->Y->X), default: 10') 19tf.flags.DEFINE_integer('lambda2', 10, 20 'weight for backward cycle loss (Y->X->Y), default: 10') 21tf.flags.DEFINE_float('learning_rate', 2e-4, 22 'initial learning rate for Adam, default: 0.0002') 23tf.flags.DEFINE_float('beta1', 0.5, 24 'momentum term of Adam, default: 0.5') 25tf.flags.DEFINE_float('pool_size', 50, 26 'size of image buffer that stores previously generated images, default: 50') 27tf.flags.DEFINE_integer('ngf', 64, 28 'number of gen filters in first conv layer, default: 64') 29 30tf.flags.DEFINE_string('X', 'data/tfrecords/apple.tfrecords', 31 'X tfrecords file for training, default: data/tfrecords/apple.tfrecords') 32tf.flags.DEFINE_string('Y', 'data/tfrecords/orange.tfrecords', 33 'Y tfrecords file for training, default: data/tfrecords/orange.tfrecords') 34tf.flags.DEFINE_string('load_model', None, 35 'folder of saved model that you wish to continue training (e.g. 20170602-1936), default: None') 36 37 38def train(): 39 if FLAGS.load_model is not None: 40 checkpoints_dir = "checkpoints/" + FLAGS.load_model.lstrip("checkpoints/") 41 else: 42 current_time = datetime.now().strftime("%Y%m%d-%H%M") 43 checkpoints_dir = "checkpoints/{}".format(current_time) 44 try: 45 os.makedirs(checkpoints_dir) 46 except os.error: 47 pass 48 49 graph = tf.Graph() 50 with graph.as_default(): 51 cycle_gan = CycleGAN( 52 X_train_file=FLAGS.X, 53 Y_train_file=FLAGS.Y, 54 batch_size=FLAGS.batch_size, 55 image_size=FLAGS.image_size, 56 use_lsgan=FLAGS.use_lsgan, 57 norm=FLAGS.norm, 58 lambda1=FLAGS.lambda1, 59 lambda2=FLAGS.lambda2, 60 learning_rate=FLAGS.learning_rate, 61 beta1=FLAGS.beta1, 62 ngf=FLAGS.ngf 63 ) 64 G_loss, D_Y_loss, F_loss, D_X_loss, fake_y, fake_x = cycle_gan.model() 65 optimizers = cycle_gan.optimize(G_loss, D_Y_loss, F_loss, D_X_loss) 66 67 summary_op = tf.summary.merge_all() 68 train_writer = tf.summary.FileWriter(checkpoints_dir, graph) 69 saver = tf.train.Saver() 70 71 with tf.Session(graph=graph) as sess: 72 if FLAGS.load_model is not None: 73 checkpoint = tf.train.get_checkpoint_state(checkpoints_dir) 74 meta_graph_path = checkpoint.model_checkpoint_path + ".meta" 75 restore = tf.train.import_meta_graph(meta_graph_path) 76 restore.restore(sess, tf.train.latest_checkpoint(checkpoints_dir)) 77 step = int(meta_graph_path.split("-")[2].split(".")[0]) 78 else: 79 sess.run(tf.global_variables_initializer()) 80 step = 0 81 82 coord = tf.train.Coordinator() 83 threads = tf.train.start_queue_runners(sess=sess, coord=coord) 84 85 try: 86 fake_Y_pool = ImagePool(FLAGS.pool_size) 87 fake_X_pool = ImagePool(FLAGS.pool_size) 88 89 while not coord.should_stop(): 90 # get previously generated images 91 fake_y_val, fake_x_val = sess.run([fake_y, fake_x]) 92 93 # train 94 _, G_loss_val, D_Y_loss_val, F_loss_val, D_X_loss_val, summary = ( 95 sess.run( 96 [optimizers, G_loss, D_Y_loss, F_loss, D_X_loss, summary_op], 97 feed_dict={cycle_gan.fake_y: fake_Y_pool.query(fake_y_val), 98 cycle_gan.fake_x: fake_X_pool.query(fake_x_val)} 99 ) 100 ) 101 102 train_writer.add_summary(summary, step) 103 train_writer.flush() 104 105 if step % 100 == 0: 106 logging.info('-----------Step %d:-------------' % step) 107 logging.info(' G_loss : {}'.format(G_loss_val)) 108 logging.info(' D_Y_loss : {}'.format(D_Y_loss_val)) 109 logging.info(' F_loss : {}'.format(F_loss_val)) 110 logging.info(' D_X_loss : {}'.format(D_X_loss_val)) 111 112 if step % 10000 == 0: 113 save_path = saver.save(sess, checkpoints_dir + "/model.ckpt", global_step=step) 114 logging.info("Model saved in file: %s" % save_path) 115 116 step += 1 117 118 except KeyboardInterrupt: 119 logging.info('Interrupted') 120 coord.request_stop() 121 except Exception as e: 122 coord.request_stop(e) 123 finally: 124 save_path = saver.save(sess, checkpoints_dir + "/model.ckpt", global_step=step) 125 logging.info("Model saved in file: %s" % save_path) 126 # When done, ask the threads to stop. 127 coord.request_stop() 128 coord.join(threads) 129 130def main(unused_argv): 131 train() 132 133if __name__ == '__main__': 134 logging.basicConfig(level=logging.INFO) 135 tf.app.run()

こうなっています。
これをみるとKEYの割り込み(CTL+C)か処理が例外のときに終了するらしいのはわかったのですがそれ以外にも停止することがあるのでしょうか?ってのがまず一個目の分からない分からないことです。
(このソースで割り込み以外で処理を終了させることができるのだろうか?ってことです。それともTF独特の終わらせ方があるのか?)
次に
学習中に
logging.info('-----------Step %d:-------------' % step)
logging.info(' G_loss : {}'.format(G_loss_val))
logging.info(' D_Y_loss : {}'.format(D_Y_loss_val))
logging.info(' F_loss : {}'.format(F_loss_val))
logging.info(' D_X_loss : {}'.format(D_X_loss_val))
の値を出力しますが、
現在
INFO:root:-----------Step 22700:-------------
INFO:root: G_loss : 2.449359893798828
INFO:root: D_Y_loss : 0.11760234832763672
INFO:root: F_loss : 2.8169190883636475
INFO:root: D_X_loss : 0.05803598836064339
みたいな値ですがこれをどう評価すべきか分かりません。

次の分からないことは
これ学習に人の顔(256X256)をA アニメ少女の顔をB(256X256)
に突っ込んでいます。
進捗をtensorboardで見ると
現在
イメージ説明
このような感じですが
待っていたら改善しますか?
シュワルツネッガーが美少女になりますか?

気になる質問をクリップする

クリップした質問は、後からいつでもMYページで確認できます。

またクリップした質問に回答があった際、通知やメールを受け取ることができます。

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

guest

あなたの回答

tips

太字

斜体

打ち消し線

見出し

引用テキストの挿入

コードの挿入

リンクの挿入

リストの挿入

番号リストの挿入

表の挿入

水平線の挿入

プレビュー

まだ回答がついていません

会員登録して回答してみよう

アカウントをお持ちの方は

15分調べてもわからないことは
teratailで質問しよう!

ただいまの回答率
85.48%

質問をまとめることで
思考を整理して素早く解決

テンプレート機能で
簡単に質問をまとめる

質問する

関連した質問