質問をすることでしか得られない、回答やアドバイスがある。

15分調べてもわからないことは、質問しよう!

新規登録して質問してみよう
ただいま回答率
85.50%
Python 3.x

Python 3はPythonプログラミング言語の最新バージョンであり、2008年12月3日にリリースされました。

Q&A

0回答

1095閲覧

python、tensorflowにてDCGAN動かしたい

makotoinukai

総合スコア6

Python 3.x

Python 3はPythonプログラミング言語の最新バージョンであり、2008年12月3日にリリースされました。

0グッド

0クリップ

投稿2020/04/20 09:21

前提・実現したいこと

tensorflowでDCGANを組む。
データセットにはfashion_mnistを用いる。

発生している問題・エラーメッセージ

----------------------------------------------------------- ValueError Traceback (most recent call last) <ipython-input-27-49aad1ceb5b1> in <module> 1 dataset = Dataset(trainset, testset, scale_func=scale) 2 ----> 3 losses, samples = train(net, dataset, epochs, batch_size, figsize=(8,8)) <ipython-input-25-2f229ab850d5> in train(net, dataset, epochs, batch_size, print_every, show_every, figsize) 15 batch_z = np.random.uniform(-1, 1, size=(batch_size, z_size)) 16 ---> 17 _ = sess.run(net.d_opt, feed_dict={net.input_real: x, net.input_z: batch_z}) 18 _ = sess.run(net.g_opt, feed_dict={net.input_z: batch_z, net.input_real: x}) 19 ------------------略------------------ ValueError: Cannot feed value of shape (96, 28, 28) for Tensor 'input_real:0', which has shape '(?, 28, 28, 3)'

該当のソースコード

ソースコードが長いため、該当箇所と思われる部分を抜粋しています。

python

1 2#トレーニング実行関数の定義 3 4from __future__ import absolute_import, division, print_function, unicode_literals 5 6# TensorFlow and tf.keras 7import tensorflow as tf 8from tensorflow import keras 9 10import numpy as np 11import matplotlib.pyplot as plt 12 13print(tf.__version__) 14 15fashion_mnist = keras.datasets.fashion_mnist 16 17(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data() 18 19class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 20 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot'] 21 22trainset = train_images 23testset = test_images 24 25def train(net, dataset, epochs, batch_size, print_every=10, show_every=100, figsize=(8,8)): 26 saver = tf.train.Saver() 27 sample_z = np.random.uniform(-1, 1, size=(72, z_size)) 28 29 samples, losses = [], [] 30 steps = 0 31 32 with tf.Session() as sess: 33 sess.run(tf.global_variables_initializer()) 34 for e in range(epochs): 35 for x,y in dataset.batches(batch_size): 36 steps += 1 37 batch_z = np.random.uniform(-1, 1, size=(batch_size, z_size)) 38 39 _ = sess.run(net.d_opt, feed_dict={net.input_real: x, net.input_z: batch_z}) 40 _ = sess.run(net.g_opt, feed_dict={net.input_z: batch_z, net.input_real: x}) 41 42 43#net.input_real, net.input_z : 入力次元(ノイズの次元)をプレイスホルダーに入れたもの 44 45 if(steps % print_every == 0): 46 train_loss_d = net.d_loss.eval({net.input_z: batch_z, net_input_real: x}) 47 train_loss_g = net.g_loss.eval({net.input_z: batch_z}) 48 49 print("epoch {}/{}:".format(e+1, epochs), 50 "D_loss: {:.4f}".format(train_loss_d), 51 "G_loss: {:.4f}".format(train_loss_g)) 52 53 losses.append(train_loss_d, train_loss_g) 54 55 if(steps % show_every == 0): 56 gen_samples = sess.run(generator(net.input_z, 1, reuse=True, training=False), 57 feed_dict={net.input_z: sample_z}) 58 59 samples.append(gen_samples) 60 _ = view_samples(-1, samples, 6, 12, figsize=figsize) 61 62 plt.show() 63 64 saver.save(sess, 'generator.ckpt') 65 66 with open('samples.pkl', 'wb') as f: 67 pkl.dump(samples, f) 68 69 return losses, samples 70 71def model_inputs(real_dim, z_dim):#real_dim=(28,28,3), z_dim=100(z_size) 72 inputs_real = tf.placeholder(tf.float32, (None, *real_dim), name='input_real') 73 inputs_z = tf.placeholder(tf.float32, (None, z_dim), name='input_z') 74 75 return inputs_real, inputs_z 76 77#損失関数の定義 78 79def model_loss(input_real, input_z, output_dim, alpha=0.2): 80 g_model = generator(input_z, output_dim, alpha=alpha) 81 d_model_real, d_logits_real = discriminator(input_real, alpha=alpha) 82 d_model_fake, d_logits_fake = discriminator(g_model, reuse=True, alpha=alpha) 83 84 d_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_real, labels=tf.ones_like(d_model_real))) 85 d_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_fake, labels=tf.zeros_like(d_model_fake))) 86 g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_fake, labels=tf.ones_like(d_model_fake))) 87 88 d_loss = d_loss_real + d_loss_fake 89 90 return d_loss, g_loss 91 92 93#最適化関数の定義 94 95def model_opt(d_loss, g_loss, learning_rate, beta1): 96 t_vars = tf.trainable_variables() 97 d_vars = [var for var in t_vars if var.name.startswith('discriminator')] 98 g_vars = [var for var in t_vars if var.name.startswith('generator')] 99 100 with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)): 101 d_train_opt = tf.train.AdamOptimizer(learning_rate, beta1=beta1).minimize(d_loss, var_list=d_vars) 102 g_train_opt = tf.train.AdamOptimizer(learning_rate, beta1=beta1).minimize(g_loss, var_list=g_vars) 103 104 return d_train_opt, g_train_opt 105 106class GAN: 107 def __init__(self, real_size, z_size, learning_rate, alpha=0.2, beta1=0.5): 108 tf.reset_default_graph() 109 110 self.input_real, self.input_z = model_inputs(real_size, z_size) 111 self.d_loss, self.g_loss = model_loss(self.input_real, self.input_z, real_size[2], alpha=alpha)#リアルサイズ[2]ここでoutput_dimを定義している 112 self.d_opt, self.g_opt = model_opt(self.d_loss, self.g_loss, learning_rate, beta1) 113 114#ハイパーパラメータの初期化とトレーニングの実行 115 116 117real_size = (28, 28, 3) 118# real_size = (32, 32, 3) 119z_size = 100 120learning_rate = 0.0002 121batch_size = 128 122epochs = 25 123alpha = 0.2 124beta1 = 0.5 125 126net = GAN(real_size, z_size, learning_rate, alpha=alpha, beta1=beta1) 127dataset = Dataset(trainset, testset, scale_func=scale) 128 129losses, samples = train(net, dataset, epochs, batch_size, figsize=(8,8)) 130 131

試したこと

placeholderがおかしいのかと思い、いろいろ調整してみましたが解決できませんでした。
また、エラーにあるshape '(?, 28, 28, 3)'は定義しているreal_sizeを指していると思うのですが、
shape (96, 28, 28)が何を指しているかわかりませんでした。

解決策でなくても、情報をいただけるだけでも幸いです。
よろしくお願いします。

気になる質問をクリップする

クリップした質問は、後からいつでもMYページで確認できます。

またクリップした質問に回答があった際、通知やメールを受け取ることができます。

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

guest

あなたの回答

tips

太字

斜体

打ち消し線

見出し

引用テキストの挿入

コードの挿入

リンクの挿入

リストの挿入

番号リストの挿入

表の挿入

水平線の挿入

プレビュー

まだ回答がついていません

会員登録して回答してみよう

アカウントをお持ちの方は

15分調べてもわからないことは
teratailで質問しよう!

ただいまの回答率
85.50%

質問をまとめることで
思考を整理して素早く解決

テンプレート機能で
簡単に質問をまとめる

質問する

関連した質問