質問をすることでしか得られない、回答やアドバイスがある。

15分調べてもわからないことは、質問しよう!

新規登録して質問してみよう
ただいま回答率
85.48%
Python 3.x

Python 3はPythonプログラミング言語の最新バージョンであり、2008年12月3日にリリースされました。

Q&A

0回答

518閲覧

tensorflow dcganの際にprefetchをしてもデータ読み込みが前もって行われない

ulthar

総合スコア8

Python 3.x

Python 3はPythonプログラミング言語の最新バージョンであり、2008年12月3日にリリースされました。

0グッド

0クリップ

投稿2021/10/29 14:17

deeplearning初学者です。質問の情報が足りない場合はご教授ください。

tensorflow(keras)をつかってdeeplearningの練習をしています。基礎的な画像分類がある程度できるようになったのでdcganに挑戦しています。

使用するデータはTFrecordの形式にしており、prefetchでepoch間のデータ読み込みを短縮しようと思ったのですが、ターミナルの出力を確認するとepoch終了後に改めてデータのshuffleをしているようです。

画像分類の際はmodel.fitにdatasetを投げるだけで訓練できたため、dcganのように少し工夫をしてmodelの訓練を行う際にprefetchを上手に機能させる方法がわかりません。お教えいただけると幸いです。

python

1import tensorflow as tf 2from tensorflow.keras import models,layers 3from IPython import display 4import os 5import time 6import matplotlib.pyplot as plt 7 8noise_vector=100 9batch=12 10epochs=5000 11feature_dim=(448,448,1) 12num_data=48648/4 13num_genex=16 14 15AUTOTUNE=tf.data.experimental.AUTOTUNE 16 17def parse_example(example): 18 features=tf.io.parse_single_example(example,features={ 19 "image": tf.io.FixedLenFeature([],dtype=tf.string), 20 }) 21 image=tf.image.resize(tf.reshape(tf.io.decode_raw(features["image"],tf.float32),feature_dim),(224,224)) 22 return image 23 24dataset=tf.data.TFRecordDataset(["dcgan.tfrecords"]).map(parse_example).shuffle(int(num_data)).batch(batch).prefetch(buffer_size=AUTOTUNE) 25 26 27 28 29 30def generator(): 31 model=models.Sequential() 32 model.add(layers.Dense(7*7*128,use_bias=False,input_shape=(noise_vector,))) 33 model.add(layers.BatchNormalization()) 34 model.add(layers.LeakyReLU()) # 公式のみ 35 model.add(layers.Reshape((7,7,128))) 36 #(7,7,64) 37 38 model.add(layers.Conv2DTranspose(64,(5,5),strides=2,padding="same",use_bias=False)) 39 model.add(layers.BatchNormalization()) 40 model.add(layers.LeakyReLU()) 41 #(14,14,64) 42 43 model.add(layers.Conv2DTranspose(64,(5,5),strides=2,padding="same",use_bias=False)) 44 model.add(layers.BatchNormalization()) 45 model.add(layers.LeakyReLU()) 46 #(28,28,64) 47 48 model.add(layers.Conv2DTranspose(64,(5,5),strides=2,padding="same",use_bias=False)) 49 model.add(layers.BatchNormalization()) 50 model.add(layers.LeakyReLU()) 51 #(56,56,64) 52 53 model.add(layers.Conv2DTranspose(64,(5,5),strides=2,padding="same",use_bias=False)) 54 model.add(layers.BatchNormalization()) 55 model.add(layers.LeakyReLU()) 56 #(112,112,64) 57 58 model.add(layers.Conv2DTranspose(64,(5,5),strides=2,padding="same",use_bias=False)) 59 model.add(layers.BatchNormalization()) 60 model.add(layers.LeakyReLU()) 61 #(224,224,64) 62 63 64 model.add(layers.Conv2D(1,(5,5),strides=1,padding="same",activation="tanh")) 65 return model 66 67def discriminator(): 68 model=models.Sequential() 69 model.add(layers.Conv2D(64,(5,5),strides=2,padding="same",input_shape=(224,224,1))) 70 model.add(layers.LeakyReLU()) 71 model.add(layers.Conv2D(64,(5,5),strides=2,padding="same")) 72 model.add(layers.LeakyReLU()) 73 model.add(layers.Dropout(0.3)) 74 75 76 model.add(layers.Conv2D(128,(5,5),strides=1,padding="same")) 77 model.add(layers.LeakyReLU()) 78 model.add(layers.Dropout(0.4)) 79 80 model.add(layers.Flatten()) 81 model.add(layers.Dense(1,activation="tanh")) 82 return model 83 84gen=generator() 85dis=discriminator() 86 87cross_entropy=tf.keras.losses.BinaryCrossentropy(from_logits=True) 88def discriminator_loss(real_output,fake_output): 89 real_loss=cross_entropy(tf.ones_like(real_output),real_output) 90 fake_loss=cross_entropy(tf.zeros_like(fake_output),fake_output) 91 total_loss=real_loss+fake_loss 92 return total_loss 93 94def generator_loss(fake_output): 95 return cross_entropy(tf.ones_like(fake_output),fake_output) 96 97generator_optimizer=tf.keras.optimizers.Adam(1e-4) 98discriminator_optimizer=tf.keras.optimizers.Adam(1e-4) 99 100dcgan_dir="./dcgan1" 101checkpoint_dir=dcgan_dir+"/dcgan_checkpoints" 102checkpoint_prefix=os.path.join(checkpoint_dir,"ckpt") 103checkpoint=tf.train.Checkpoint(generator_optimizer=generator_optimizer,discriminator_optimizer=discriminator_optimizer,generator=gen,discriminator=dis) 104 105seed=tf.random.normal([num_genex,noise_vector]) 106 107@tf.function 108def train_step(images): 109 noise=tf.random.normal([batch,noise_vector]) 110 with tf.GradientTape() as gen_tape, tf.GradientTape() as dis_tape: 111 generated_images=gen(noise,training=True) 112 real_output=dis(images,training=True) 113 fake_output=dis(generated_images,training=True) 114 gen_loss=generator_loss(fake_output) 115 dis_loss=discriminator_loss(real_output,fake_output) 116 grad_of_gen=gen_tape.gradient(gen_loss,gen.trainable_variables) 117 grad_of_dis=dis_tape.gradient(dis_loss,dis.trainable_variables) 118 generator_optimizer.apply_gradients(zip(grad_of_gen,gen.trainable_variables)) 119 discriminator_optimizer.apply_gradients(zip(grad_of_dis,dis.trainable_variables)) 120 121def train(dataset,epochs): 122 for epoch in range(epochs): 123 start=time.time() 124 count=0 125 for image_batch in dataset: 126 count+=1 127 print(count) 128 train_step(image_batch) 129 display.clear_output(wait=True) 130 generate_and_save_images(gen,epoch+1,seed) 131 if (epoch+1)%15==0: 132 checkpoint.save(file_prefix=checkpoint_prefix) 133 print("Time for epoch {} is {} sec".format(epoch+1,time.time()-start)) 134 display.clear_output(wait=True) 135 generate_and_save_images(gen,epochs,seed) 136 137def generate_and_save_images(model,epoch,test_input): 138 predictions=model(test_input,training=False) 139 fig=plt.figure(figsize=(4,4)) 140 for i in range(predictions.shape[0]): 141 plt.subplot(4,4,i+1) 142 plt.imshow(predictions[i,:,:,0]*127.5+127.5,cmap="gray") 143 plt.axis("off") 144 if epoch%15==0: 145 plt.savefig(dcgan_dir+"image_at_epoch{:04d}.png".format(epoch)) 146 plt.show() 147train(dataset,epochs)

このコードで特にエラーは出ずに訓練ができるのですが、epoch間でshuffle処理が入る状態です。基本的な質問かもしれませんがどうぞよろしくお願いします。

気になる質問をクリップする

クリップした質問は、後からいつでもMYページで確認できます。

またクリップした質問に回答があった際、通知やメールを受け取ることができます。

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

guest

あなたの回答

tips

太字

斜体

打ち消し線

見出し

引用テキストの挿入

コードの挿入

リンクの挿入

リストの挿入

番号リストの挿入

表の挿入

水平線の挿入

プレビュー

まだ回答がついていません

会員登録して回答してみよう

アカウントをお持ちの方は

15分調べてもわからないことは
teratailで質問しよう!

ただいまの回答率
85.48%

質問をまとめることで
思考を整理して素早く解決

テンプレート機能で
簡単に質問をまとめる

質問する

関連した質問