tensorflowのビギナーです
適当な面白そうな奴を見つけて走らせればどうか?とおもったので
https://github.com/vanhuyz/CycleGAN-TensorFlow
を走らせてみました。
これトレーニングのソースを見ると
python
1import tensorflow as tf 2from model import CycleGAN 3from reader import Reader 4from datetime import datetime 5import os 6import logging 7from utils import ImagePool 8 9FLAGS = tf.flags.FLAGS 10 11tf.flags.DEFINE_integer('batch_size', 1, 'batch size, default: 1') 12tf.flags.DEFINE_integer('image_size', 256, 'image size, default: 256') 13tf.flags.DEFINE_bool('use_lsgan', True, 14 'use lsgan (mean squared error) or cross entropy loss, default: True') 15tf.flags.DEFINE_string('norm', 'instance', 16 '[instance, batch] use instance norm or batch norm, default: instance') 17tf.flags.DEFINE_integer('lambda1', 10, 18 'weight for forward cycle loss (X->Y->X), default: 10') 19tf.flags.DEFINE_integer('lambda2', 10, 20 'weight for backward cycle loss (Y->X->Y), default: 10') 21tf.flags.DEFINE_float('learning_rate', 2e-4, 22 'initial learning rate for Adam, default: 0.0002') 23tf.flags.DEFINE_float('beta1', 0.5, 24 'momentum term of Adam, default: 0.5') 25tf.flags.DEFINE_float('pool_size', 50, 26 'size of image buffer that stores previously generated images, default: 50') 27tf.flags.DEFINE_integer('ngf', 64, 28 'number of gen filters in first conv layer, default: 64') 29 30tf.flags.DEFINE_string('X', 'data/tfrecords/apple.tfrecords', 31 'X tfrecords file for training, default: data/tfrecords/apple.tfrecords') 32tf.flags.DEFINE_string('Y', 'data/tfrecords/orange.tfrecords', 33 'Y tfrecords file for training, default: data/tfrecords/orange.tfrecords') 34tf.flags.DEFINE_string('load_model', None, 35 'folder of saved model that you wish to continue training (e.g. 20170602-1936), default: None') 36 37 38def train(): 39 if FLAGS.load_model is not None: 40 checkpoints_dir = "checkpoints/" + FLAGS.load_model.lstrip("checkpoints/") 41 else: 42 current_time = datetime.now().strftime("%Y%m%d-%H%M") 43 checkpoints_dir = "checkpoints/{}".format(current_time) 44 try: 45 os.makedirs(checkpoints_dir) 46 except os.error: 47 pass 48 49 graph = tf.Graph() 50 with graph.as_default(): 51 cycle_gan = CycleGAN( 52 X_train_file=FLAGS.X, 53 Y_train_file=FLAGS.Y, 54 batch_size=FLAGS.batch_size, 55 image_size=FLAGS.image_size, 56 use_lsgan=FLAGS.use_lsgan, 57 norm=FLAGS.norm, 58 lambda1=FLAGS.lambda1, 59 lambda2=FLAGS.lambda2, 60 learning_rate=FLAGS.learning_rate, 61 beta1=FLAGS.beta1, 62 ngf=FLAGS.ngf 63 ) 64 G_loss, D_Y_loss, F_loss, D_X_loss, fake_y, fake_x = cycle_gan.model() 65 optimizers = cycle_gan.optimize(G_loss, D_Y_loss, F_loss, D_X_loss) 66 67 summary_op = tf.summary.merge_all() 68 train_writer = tf.summary.FileWriter(checkpoints_dir, graph) 69 saver = tf.train.Saver() 70 71 with tf.Session(graph=graph) as sess: 72 if FLAGS.load_model is not None: 73 checkpoint = tf.train.get_checkpoint_state(checkpoints_dir) 74 meta_graph_path = checkpoint.model_checkpoint_path + ".meta" 75 restore = tf.train.import_meta_graph(meta_graph_path) 76 restore.restore(sess, tf.train.latest_checkpoint(checkpoints_dir)) 77 step = int(meta_graph_path.split("-")[2].split(".")[0]) 78 else: 79 sess.run(tf.global_variables_initializer()) 80 step = 0 81 82 coord = tf.train.Coordinator() 83 threads = tf.train.start_queue_runners(sess=sess, coord=coord) 84 85 try: 86 fake_Y_pool = ImagePool(FLAGS.pool_size) 87 fake_X_pool = ImagePool(FLAGS.pool_size) 88 89 while not coord.should_stop(): 90 # get previously generated images 91 fake_y_val, fake_x_val = sess.run([fake_y, fake_x]) 92 93 # train 94 _, G_loss_val, D_Y_loss_val, F_loss_val, D_X_loss_val, summary = ( 95 sess.run( 96 [optimizers, G_loss, D_Y_loss, F_loss, D_X_loss, summary_op], 97 feed_dict={cycle_gan.fake_y: fake_Y_pool.query(fake_y_val), 98 cycle_gan.fake_x: fake_X_pool.query(fake_x_val)} 99 ) 100 ) 101 102 train_writer.add_summary(summary, step) 103 train_writer.flush() 104 105 if step % 100 == 0: 106 logging.info('-----------Step %d:-------------' % step) 107 logging.info(' G_loss : {}'.format(G_loss_val)) 108 logging.info(' D_Y_loss : {}'.format(D_Y_loss_val)) 109 logging.info(' F_loss : {}'.format(F_loss_val)) 110 logging.info(' D_X_loss : {}'.format(D_X_loss_val)) 111 112 if step % 10000 == 0: 113 save_path = saver.save(sess, checkpoints_dir + "/model.ckpt", global_step=step) 114 logging.info("Model saved in file: %s" % save_path) 115 116 step += 1 117 118 except KeyboardInterrupt: 119 logging.info('Interrupted') 120 coord.request_stop() 121 except Exception as e: 122 coord.request_stop(e) 123 finally: 124 save_path = saver.save(sess, checkpoints_dir + "/model.ckpt", global_step=step) 125 logging.info("Model saved in file: %s" % save_path) 126 # When done, ask the threads to stop. 127 coord.request_stop() 128 coord.join(threads) 129 130def main(unused_argv): 131 train() 132 133if __name__ == '__main__': 134 logging.basicConfig(level=logging.INFO) 135 tf.app.run()
こうなっています。
これをみるとKEYの割り込み(CTL+C)か処理が例外のときに終了するらしいのはわかったのですがそれ以外にも停止することがあるのでしょうか?ってのがまず一個目の分からない分からないことです。
(このソースで割り込み以外で処理を終了させることができるのだろうか?ってことです。それともTF独特の終わらせ方があるのか?)
次に
学習中に
logging.info('-----------Step %d:-------------' % step)
logging.info(' G_loss : {}'.format(G_loss_val))
logging.info(' D_Y_loss : {}'.format(D_Y_loss_val))
logging.info(' F_loss : {}'.format(F_loss_val))
logging.info(' D_X_loss : {}'.format(D_X_loss_val))
の値を出力しますが、
現在
INFO:root:-----------Step 22700:-------------
INFO:root: G_loss : 2.449359893798828
INFO:root: D_Y_loss : 0.11760234832763672
INFO:root: F_loss : 2.8169190883636475
INFO:root: D_X_loss : 0.05803598836064339
みたいな値ですがこれをどう評価すべきか分かりません。
次の分からないことは
これ学習に人の顔(256X256)をA アニメ少女の顔をB(256X256)
に突っ込んでいます。
進捗をtensorboardで見ると
現在
このような感じですが
待っていたら改善しますか?
シュワルツネッガーが美少女になりますか?
あなたの回答
tips
プレビュー