質問編集履歴

3 追記

kokawa2003

kokawa2003 score 153

2019/09/17 11:44  投稿

tensorflowの停止条件が分からない
tensorflowのビギナーです
適当な面白そうな奴を見つけて走らせればどうか?とおもったので
https://github.com/vanhuyz/CycleGAN-TensorFlow
を走らせてみました。
これトレーニングのソースを見ると
```python
import tensorflow as tf
from model import CycleGAN
from reader import Reader
from datetime import datetime
import os
import logging
from utils import ImagePool
FLAGS = tf.flags.FLAGS
tf.flags.DEFINE_integer('batch_size', 1, 'batch size, default: 1')
tf.flags.DEFINE_integer('image_size', 256, 'image size, default: 256')
tf.flags.DEFINE_bool('use_lsgan', True,
                    'use lsgan (mean squared error) or cross entropy loss, default: True')
tf.flags.DEFINE_string('norm', 'instance',
                      '[instance, batch] use instance norm or batch norm, default: instance')
tf.flags.DEFINE_integer('lambda1', 10,
                       'weight for forward cycle loss (X->Y->X), default: 10')
tf.flags.DEFINE_integer('lambda2', 10,
                       'weight for backward cycle loss (Y->X->Y), default: 10')
tf.flags.DEFINE_float('learning_rate', 2e-4,
                     'initial learning rate for Adam, default: 0.0002')
tf.flags.DEFINE_float('beta1', 0.5,
                     'momentum term of Adam, default: 0.5')
tf.flags.DEFINE_float('pool_size', 50,
                     'size of image buffer that stores previously generated images, default: 50')
tf.flags.DEFINE_integer('ngf', 64,
                       'number of gen filters in first conv layer, default: 64')
tf.flags.DEFINE_string('X', 'data/tfrecords/apple.tfrecords',
                      'X tfrecords file for training, default: data/tfrecords/apple.tfrecords')
tf.flags.DEFINE_string('Y', 'data/tfrecords/orange.tfrecords',
                      'Y tfrecords file for training, default: data/tfrecords/orange.tfrecords')
tf.flags.DEFINE_string('load_model', None,
                       'folder of saved model that you wish to continue training (e.g. 20170602-1936), default: None')
def train():
 if FLAGS.load_model is not None:
   checkpoints_dir = "checkpoints/" + FLAGS.load_model.lstrip("checkpoints/")
 else:
   current_time = datetime.now().strftime("%Y%m%d-%H%M")
   checkpoints_dir = "checkpoints/{}".format(current_time)
   try:
     os.makedirs(checkpoints_dir)
   except os.error:
     pass
 graph = tf.Graph()
 with graph.as_default():
   cycle_gan = CycleGAN(
       X_train_file=FLAGS.X,
       Y_train_file=FLAGS.Y,
       batch_size=FLAGS.batch_size,
       image_size=FLAGS.image_size,
       use_lsgan=FLAGS.use_lsgan,
       norm=FLAGS.norm,
       lambda1=FLAGS.lambda1,
       lambda2=FLAGS.lambda2,
       learning_rate=FLAGS.learning_rate,
       beta1=FLAGS.beta1,
       ngf=FLAGS.ngf
   )
   G_loss, D_Y_loss, F_loss, D_X_loss, fake_y, fake_x = cycle_gan.model()
   optimizers = cycle_gan.optimize(G_loss, D_Y_loss, F_loss, D_X_loss)
   summary_op = tf.summary.merge_all()
   train_writer = tf.summary.FileWriter(checkpoints_dir, graph)
   saver = tf.train.Saver()
 with tf.Session(graph=graph) as sess:
   if FLAGS.load_model is not None:
     checkpoint = tf.train.get_checkpoint_state(checkpoints_dir)
     meta_graph_path = checkpoint.model_checkpoint_path + ".meta"
     restore = tf.train.import_meta_graph(meta_graph_path)
     restore.restore(sess, tf.train.latest_checkpoint(checkpoints_dir))
     step = int(meta_graph_path.split("-")[2].split(".")[0])
   else:
     sess.run(tf.global_variables_initializer())
     step = 0
   coord = tf.train.Coordinator()
   threads = tf.train.start_queue_runners(sess=sess, coord=coord)
   try:
     fake_Y_pool = ImagePool(FLAGS.pool_size)
     fake_X_pool = ImagePool(FLAGS.pool_size)
     while not coord.should_stop():
       # get previously generated images
       fake_y_val, fake_x_val = sess.run([fake_y, fake_x])
       # train
       _, G_loss_val, D_Y_loss_val, F_loss_val, D_X_loss_val, summary = (
             sess.run(
                 [optimizers, G_loss, D_Y_loss, F_loss, D_X_loss, summary_op],
                 feed_dict={cycle_gan.fake_y: fake_Y_pool.query(fake_y_val),
                            cycle_gan.fake_x: fake_X_pool.query(fake_x_val)}
             )
       )
       train_writer.add_summary(summary, step)
       train_writer.flush()
       if step % 100 == 0:
         logging.info('-----------Step %d:-------------' % step)
         logging.info(' G_loss  : {}'.format(G_loss_val))
         logging.info(' D_Y_loss : {}'.format(D_Y_loss_val))
         logging.info(' F_loss  : {}'.format(F_loss_val))
         logging.info(' D_X_loss : {}'.format(D_X_loss_val))
       if step % 10000 == 0:
         save_path = saver.save(sess, checkpoints_dir + "/model.ckpt", global_step=step)
         logging.info("Model saved in file: %s" % save_path)
       step += 1
   except KeyboardInterrupt:
     logging.info('Interrupted')
     coord.request_stop()
   except Exception as e:
     coord.request_stop(e)
   finally:
     save_path = saver.save(sess, checkpoints_dir + "/model.ckpt", global_step=step)
     logging.info("Model saved in file: %s" % save_path)
     # When done, ask the threads to stop.
     coord.request_stop()
     coord.join(threads)
def main(unused_argv):
 train()
if __name__ == '__main__':
 logging.basicConfig(level=logging.INFO)
 tf.app.run()
```
こうなっています。
これをみるとKEYの割り込み(CTL+C)か処理が例外のときに終了するらしいのはわかったのですがそれ以外にも停止することがあるのでしょうか?ってのがまず一個目の分からない分からないことです。
(このソースで割り込み以外で処理を終了させることができるのだろうか?ってことです。それともTF独特の終わらせ方があるのか?)
次に
学習中に
         logging.info('-----------Step %d:-------------' % step)
         logging.info(' G_loss  : {}'.format(G_loss_val))
         logging.info(' D_Y_loss : {}'.format(D_Y_loss_val))
         logging.info(' F_loss  : {}'.format(F_loss_val))
         logging.info(' D_X_loss : {}'.format(D_X_loss_val))
の値を出力しますが、
現在
INFO:root:-----------Step 22700:-------------
INFO:root: G_loss  : 2.449359893798828
INFO:root: D_Y_loss : 0.11760234832763672
INFO:root: F_loss  : 2.8169190883636475
INFO:root: D_X_loss : 0.05803598836064339
みたいな値ですがこれをどう評価すべきか分かりません。
次の分からないことは
これ学習に人の顔(256X256)をA アニメ少女の顔をB(256X256)
に突っ込んでいます。
進捗をtensorboardで見ると  
現在
![イメージ説明](b01355bde400c4deba556f15b2410524.png)
このような感じですが
待っていたら改善しますか?
シュワルツネッガーが美少女になりますか?
2 追記

kokawa2003

kokawa2003 score 153

2019/09/17 09:30  投稿

tensorflowの停止条件が分からない
tensorflowのビギナーです
適当な面白そうな奴を見つけて走らせればどうか?とおもったので
https://github.com/vanhuyz/CycleGAN-TensorFlow
を走らせてみました。
これトレーニングのソースを見ると
```python
import tensorflow as tf
from model import CycleGAN
from reader import Reader
from datetime import datetime
import os
import logging
from utils import ImagePool
FLAGS = tf.flags.FLAGS
tf.flags.DEFINE_integer('batch_size', 1, 'batch size, default: 1')
tf.flags.DEFINE_integer('image_size', 256, 'image size, default: 256')
tf.flags.DEFINE_bool('use_lsgan', True,
                    'use lsgan (mean squared error) or cross entropy loss, default: True')
tf.flags.DEFINE_string('norm', 'instance',
                      '[instance, batch] use instance norm or batch norm, default: instance')
tf.flags.DEFINE_integer('lambda1', 10,
                       'weight for forward cycle loss (X->Y->X), default: 10')
tf.flags.DEFINE_integer('lambda2', 10,
                       'weight for backward cycle loss (Y->X->Y), default: 10')
tf.flags.DEFINE_float('learning_rate', 2e-4,
                     'initial learning rate for Adam, default: 0.0002')
tf.flags.DEFINE_float('beta1', 0.5,
                     'momentum term of Adam, default: 0.5')
tf.flags.DEFINE_float('pool_size', 50,
                     'size of image buffer that stores previously generated images, default: 50')
tf.flags.DEFINE_integer('ngf', 64,
                       'number of gen filters in first conv layer, default: 64')
tf.flags.DEFINE_string('X', 'data/tfrecords/apple.tfrecords',
                      'X tfrecords file for training, default: data/tfrecords/apple.tfrecords')
tf.flags.DEFINE_string('Y', 'data/tfrecords/orange.tfrecords',
                      'Y tfrecords file for training, default: data/tfrecords/orange.tfrecords')
tf.flags.DEFINE_string('load_model', None,
                       'folder of saved model that you wish to continue training (e.g. 20170602-1936), default: None')
def train():
 if FLAGS.load_model is not None:
   checkpoints_dir = "checkpoints/" + FLAGS.load_model.lstrip("checkpoints/")
 else:
   current_time = datetime.now().strftime("%Y%m%d-%H%M")
   checkpoints_dir = "checkpoints/{}".format(current_time)
   try:
     os.makedirs(checkpoints_dir)
   except os.error:
     pass
 graph = tf.Graph()
 with graph.as_default():
   cycle_gan = CycleGAN(
       X_train_file=FLAGS.X,
       Y_train_file=FLAGS.Y,
       batch_size=FLAGS.batch_size,
       image_size=FLAGS.image_size,
       use_lsgan=FLAGS.use_lsgan,
       norm=FLAGS.norm,
       lambda1=FLAGS.lambda1,
       lambda2=FLAGS.lambda2,
       learning_rate=FLAGS.learning_rate,
       beta1=FLAGS.beta1,
       ngf=FLAGS.ngf
   )
   G_loss, D_Y_loss, F_loss, D_X_loss, fake_y, fake_x = cycle_gan.model()
   optimizers = cycle_gan.optimize(G_loss, D_Y_loss, F_loss, D_X_loss)
   summary_op = tf.summary.merge_all()
   train_writer = tf.summary.FileWriter(checkpoints_dir, graph)
   saver = tf.train.Saver()
 with tf.Session(graph=graph) as sess:
   if FLAGS.load_model is not None:
     checkpoint = tf.train.get_checkpoint_state(checkpoints_dir)
     meta_graph_path = checkpoint.model_checkpoint_path + ".meta"
     restore = tf.train.import_meta_graph(meta_graph_path)
     restore.restore(sess, tf.train.latest_checkpoint(checkpoints_dir))
     step = int(meta_graph_path.split("-")[2].split(".")[0])
   else:
     sess.run(tf.global_variables_initializer())
     step = 0
   coord = tf.train.Coordinator()
   threads = tf.train.start_queue_runners(sess=sess, coord=coord)
   try:
     fake_Y_pool = ImagePool(FLAGS.pool_size)
     fake_X_pool = ImagePool(FLAGS.pool_size)
     while not coord.should_stop():
       # get previously generated images
       fake_y_val, fake_x_val = sess.run([fake_y, fake_x])
       # train
       _, G_loss_val, D_Y_loss_val, F_loss_val, D_X_loss_val, summary = (
             sess.run(
                 [optimizers, G_loss, D_Y_loss, F_loss, D_X_loss, summary_op],
                 feed_dict={cycle_gan.fake_y: fake_Y_pool.query(fake_y_val),
                            cycle_gan.fake_x: fake_X_pool.query(fake_x_val)}
             )
       )
       train_writer.add_summary(summary, step)
       train_writer.flush()
       if step % 100 == 0:
         logging.info('-----------Step %d:-------------' % step)
         logging.info(' G_loss  : {}'.format(G_loss_val))
         logging.info(' D_Y_loss : {}'.format(D_Y_loss_val))
         logging.info(' F_loss  : {}'.format(F_loss_val))
         logging.info(' D_X_loss : {}'.format(D_X_loss_val))
       if step % 10000 == 0:
         save_path = saver.save(sess, checkpoints_dir + "/model.ckpt", global_step=step)
         logging.info("Model saved in file: %s" % save_path)
       step += 1
   except KeyboardInterrupt:
     logging.info('Interrupted')
     coord.request_stop()
   except Exception as e:
     coord.request_stop(e)
   finally:
     save_path = saver.save(sess, checkpoints_dir + "/model.ckpt", global_step=step)
     logging.info("Model saved in file: %s" % save_path)
     # When done, ask the threads to stop.
     coord.request_stop()
     coord.join(threads)
def main(unused_argv):
 train()
if __name__ == '__main__':
 logging.basicConfig(level=logging.INFO)
 tf.app.run()
```
こうなっています。
これをみるとKEYの割り込み(CTL+C)か処理が例外のときに終了するらしいのはわかったのですがそれ以外にも停止することがあるのでしょうか?ってのがまず一個目の分からない分からないことです。
(このソースで割り込み以外で処理を終了させることができるのだろうか?ってことです。それともTF独特の終わらせ方があるのか?)  
次に
学習中に
         logging.info('-----------Step %d:-------------' % step)
         logging.info(' G_loss  : {}'.format(G_loss_val))
         logging.info(' D_Y_loss : {}'.format(D_Y_loss_val))
         logging.info(' F_loss  : {}'.format(F_loss_val))
         logging.info(' D_X_loss : {}'.format(D_X_loss_val))
の値を出力しますが、
現在
INFO:root:-----------Step 22700:-------------
INFO:root: G_loss  : 2.449359893798828
INFO:root: D_Y_loss : 0.11760234832763672
INFO:root: F_loss  : 2.8169190883636475
INFO:root: D_X_loss : 0.05803598836064339
みたいな値ですがこれをどう評価すべきか分かりません。
次の分からないことは
これ学習に人の顔(256X256)をA アニメ少女の顔をB(256X256)
に突っ込んでいます。
現在
![イメージ説明](b01355bde400c4deba556f15b2410524.png)
このような感じですが
待っていたら改善しますか?
シュワルツネッガーが美少女になりますか?
1 追記

kokawa2003

kokawa2003 score 153

2019/09/17 09:23  投稿

tensorflowの停止条件が分からない
tensorflowのビギナーです
適当な面白そうな奴を見つけて走らせればどうか?とおもったので
https://github.com/vanhuyz/CycleGAN-TensorFlow
を走らせてみました。
これトレーニングのソースを見ると
```python
import tensorflow as tf
from model import CycleGAN
from reader import Reader
from datetime import datetime
import os
import logging
from utils import ImagePool
FLAGS = tf.flags.FLAGS
tf.flags.DEFINE_integer('batch_size', 1, 'batch size, default: 1')
tf.flags.DEFINE_integer('image_size', 256, 'image size, default: 256')
tf.flags.DEFINE_bool('use_lsgan', True,
                    'use lsgan (mean squared error) or cross entropy loss, default: True')
tf.flags.DEFINE_string('norm', 'instance',
                      '[instance, batch] use instance norm or batch norm, default: instance')
tf.flags.DEFINE_integer('lambda1', 10,
                       'weight for forward cycle loss (X->Y->X), default: 10')
tf.flags.DEFINE_integer('lambda2', 10,
                       'weight for backward cycle loss (Y->X->Y), default: 10')
tf.flags.DEFINE_float('learning_rate', 2e-4,
                     'initial learning rate for Adam, default: 0.0002')
tf.flags.DEFINE_float('beta1', 0.5,
                     'momentum term of Adam, default: 0.5')
tf.flags.DEFINE_float('pool_size', 50,
                     'size of image buffer that stores previously generated images, default: 50')
tf.flags.DEFINE_integer('ngf', 64,
                       'number of gen filters in first conv layer, default: 64')
tf.flags.DEFINE_string('X', 'data/tfrecords/apple.tfrecords',
                      'X tfrecords file for training, default: data/tfrecords/apple.tfrecords')
tf.flags.DEFINE_string('Y', 'data/tfrecords/orange.tfrecords',
                      'Y tfrecords file for training, default: data/tfrecords/orange.tfrecords')
tf.flags.DEFINE_string('load_model', None,
                       'folder of saved model that you wish to continue training (e.g. 20170602-1936), default: None')
def train():
 if FLAGS.load_model is not None:
   checkpoints_dir = "checkpoints/" + FLAGS.load_model.lstrip("checkpoints/")
 else:
   current_time = datetime.now().strftime("%Y%m%d-%H%M")
   checkpoints_dir = "checkpoints/{}".format(current_time)
   try:
     os.makedirs(checkpoints_dir)
   except os.error:
     pass
 graph = tf.Graph()
 with graph.as_default():
   cycle_gan = CycleGAN(
       X_train_file=FLAGS.X,
       Y_train_file=FLAGS.Y,
       batch_size=FLAGS.batch_size,
       image_size=FLAGS.image_size,
       use_lsgan=FLAGS.use_lsgan,
       norm=FLAGS.norm,
       lambda1=FLAGS.lambda1,
       lambda2=FLAGS.lambda2,
       learning_rate=FLAGS.learning_rate,
       beta1=FLAGS.beta1,
       ngf=FLAGS.ngf
   )
   G_loss, D_Y_loss, F_loss, D_X_loss, fake_y, fake_x = cycle_gan.model()
   optimizers = cycle_gan.optimize(G_loss, D_Y_loss, F_loss, D_X_loss)
   summary_op = tf.summary.merge_all()
   train_writer = tf.summary.FileWriter(checkpoints_dir, graph)
   saver = tf.train.Saver()
 with tf.Session(graph=graph) as sess:
   if FLAGS.load_model is not None:
     checkpoint = tf.train.get_checkpoint_state(checkpoints_dir)
     meta_graph_path = checkpoint.model_checkpoint_path + ".meta"
     restore = tf.train.import_meta_graph(meta_graph_path)
     restore.restore(sess, tf.train.latest_checkpoint(checkpoints_dir))
     step = int(meta_graph_path.split("-")[2].split(".")[0])
   else:
     sess.run(tf.global_variables_initializer())
     step = 0
   coord = tf.train.Coordinator()
   threads = tf.train.start_queue_runners(sess=sess, coord=coord)
   try:
     fake_Y_pool = ImagePool(FLAGS.pool_size)
     fake_X_pool = ImagePool(FLAGS.pool_size)
     while not coord.should_stop():
       # get previously generated images
       fake_y_val, fake_x_val = sess.run([fake_y, fake_x])
       # train
       _, G_loss_val, D_Y_loss_val, F_loss_val, D_X_loss_val, summary = (
             sess.run(
                 [optimizers, G_loss, D_Y_loss, F_loss, D_X_loss, summary_op],
                 feed_dict={cycle_gan.fake_y: fake_Y_pool.query(fake_y_val),
                            cycle_gan.fake_x: fake_X_pool.query(fake_x_val)}
             )
       )
       train_writer.add_summary(summary, step)
       train_writer.flush()
       if step % 100 == 0:
         logging.info('-----------Step %d:-------------' % step)
         logging.info(' G_loss  : {}'.format(G_loss_val))
         logging.info(' D_Y_loss : {}'.format(D_Y_loss_val))
         logging.info(' F_loss  : {}'.format(F_loss_val))
         logging.info(' D_X_loss : {}'.format(D_X_loss_val))
       if step % 10000 == 0:
         save_path = saver.save(sess, checkpoints_dir + "/model.ckpt", global_step=step)
         logging.info("Model saved in file: %s" % save_path)
       step += 1
   except KeyboardInterrupt:
     logging.info('Interrupted')
     coord.request_stop()
   except Exception as e:
     coord.request_stop(e)
   finally:
     save_path = saver.save(sess, checkpoints_dir + "/model.ckpt", global_step=step)
     logging.info("Model saved in file: %s" % save_path)
     # When done, ask the threads to stop.
     coord.request_stop()
     coord.join(threads)
def main(unused_argv):
 train()
if __name__ == '__main__':
 logging.basicConfig(level=logging.INFO)
 tf.app.run()
```
こうなっています。
これをみるとKEYの割り込み(CTL+C)か処理が例外のときに終了するらしいのはわかったのですがそれ以外にも停止することがあるのでしょうか?ってのがまず一個目の分からない分からないことです。
次に  
学習中に
         logging.info('-----------Step %d:-------------' % step)
         logging.info(' G_loss  : {}'.format(G_loss_val))
         logging.info(' D_Y_loss : {}'.format(D_Y_loss_val))
         logging.info(' F_loss  : {}'.format(F_loss_val))
         logging.info(' D_X_loss : {}'.format(D_X_loss_val))
の値を出力しますが、
現在
INFO:root:-----------Step 22700:-------------
INFO:root: G_loss  : 2.449359893798828
INFO:root: D_Y_loss : 0.11760234832763672
INFO:root: F_loss  : 2.8169190883636475
INFO:root: D_X_loss : 0.05803598836064339
みたいな値ですがこれをどう評価すべきか分かりません。
次の分からないことは
これ学習に人の顔(256X256)をA アニメ少女の顔をB(256X256)
に突っ込んでいます。
に突っ込んでいます。
現在
![イメージ説明](b01355bde400c4deba556f15b2410524.png)
このような感じですが
待っていたら改善しますか?
シュワルツネッガーが美少女になりますか?

思考するエンジニアのためのQ&Aサイト「teratail」について詳しく知る