質問をすることでしか得られない、回答やアドバイスがある。

15分調べてもわからないことは、質問しよう!

新規登録して質問してみよう
ただいま回答率
85.48%
Keras

Kerasは、TheanoやTensorFlow/CNTK対応のラッパーライブラリです。DeepLearningの数学的部分を短いコードでネットワークとして表現することが可能。DeepLearningの最新手法を迅速に試すことができます。

深層学習

深層学習は、多数のレイヤのニューラルネットワークによる機械学習手法。人工知能研究の一つでディープラーニングとも呼ばれています。コンピューター自体がデータの潜在的な特徴を汲み取り、効率的で的確な判断を実現することができます。

Python 3.x

Python 3はPythonプログラミング言語の最新バージョンであり、2008年12月3日にリリースされました。

Q&A

解決済

1回答

671閲覧

【Python】【Keras】【GAN】指定ディレクトリ下の画像の学習ができない

Reach

総合スコア733

Keras

Kerasは、TheanoやTensorFlow/CNTK対応のラッパーライブラリです。DeepLearningの数学的部分を短いコードでネットワークとして表現することが可能。DeepLearningの最新手法を迅速に試すことができます。

深層学習

深層学習は、多数のレイヤのニューラルネットワークによる機械学習手法。人工知能研究の一つでディープラーニングとも呼ばれています。コンピューター自体がデータの潜在的な特徴を汲み取り、効率的で的確な判断を実現することができます。

Python 3.x

Python 3はPythonプログラミング言語の最新バージョンであり、2008年12月3日にリリースされました。

0グッド

0クリップ

投稿2019/02/10 06:59

編集2019/02/10 15:01

お世話になります
機械(深層)学習の超初心者です

現在 WGAN-GPで カラー画像を (Colab上で) 学習させようとチャレンジしております
ここのコードを使っております

画像は読み込んでいるようなのですが 以下のエラーが出て 実行できません

Traceback (most recent call last): File "wgan_gp_3.py", line 270, in <module> wgan.train(epochs=30000, batch_size=8, sample_interval=100) File "wgan_gp_3.py", line 234, in train [valid, fake, dummy]) File "/usr/local/lib/python3.6/dist-packages/keras/engine/training.py", line 1217, in train_on_batch outputs = self.train_function(ins) File "/usr/local/lib/python3.6/dist-packages/keras/backend/tensorflow_backend.py", line 2715, in __call__ return self._call(inputs) File "/usr/local/lib/python3.6/dist-packages/keras/backend/tensorflow_backend.py", line 2675, in _call fetched = self._callable_fn(*array_vals) File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/client/session.py", line 1439, in __call__ run_metadata_ptr) File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/errors_impl.py", line 528, in __exit__ c_api.TF_GetCode(self.status.status)) tensorflow.python.framework.errors_impl.InvalidArgumentError: Incompatible shapes: [32,1,1,1] vs. [8,3,256,256] [[{{node random_weighted_average_1/mul_1}}]] [[{{node loss/model_2_loss/Mean_3}}]]

変更部分コード

Python

1 2 3from __future__ import print_function, division 4 5from keras.datasets import mnist 6from keras.layers.merge import _Merge 7from keras.layers import Input, Dense, Reshape, Flatten, Dropout 8from keras.layers import BatchNormalization, Activation, ZeroPadding2D 9from keras.layers.advanced_activations import LeakyReLU 10from keras.layers.convolutional import UpSampling2D, Conv2D 11from keras.models import Sequential, Model 12from keras.optimizers import RMSprop 13from functools import partial 14from keras.preprocessing.image import load_img, img_to_array,save_img , array_to_img 15 16 17from tensorflow.python.keras.preprocessing.image import ImageDataGenerator 18 19import keras.backend as K 20 21import matplotlib.pyplot as plt 22 23import glob 24 25import sys 26 27import numpy as np 28 29 30class WGANGP(): 31 def __init__(self): 32 self.img_rows = 256 33 self.img_cols = 256 34 self.channels = 3 35 self.img_shape = (self.img_rows, self.img_cols, self.channels) 36 self.latent_dim = 100 37 38 data_dir = '/content/drive/My Drive/images/' 39 40 batch_size = 8 41 42 43 gen = ImageDataGenerator(rescale=1/127.5, samplewise_center=True) 44 iters = gen.flow_from_directory( 45 directory=data_dir, 46 classes=['test'], 47 class_mode=None, 48 color_mode='rgb', 49 target_size=self.img_shape[:2], 50 batch_size=batch_size, 51 shuffle=True 52 ) 53 self.x_train_batch=next(iters) 54 55 56 # Following parameter and optimizer set as recommended in paper 57 self.n_critic = 5 58 optimizer = RMSprop(lr=0.00005) 59 60 # Build the generator and critic 61 self.generator = self.build_generator() 62 self.critic = self.build_critic() 63 64 #------------------------------- 65 # Construct Computational Graph 66 # for the Critic 67 #------------------------------- 68 69 # Freeze generator's layers while training critic 70 self.generator.trainable = False 71 72 # Image input (real sample) 73 real_img = Input(shape=self.img_shape) 74 75 # Noise input 76 z_disc = Input(shape=(self.latent_dim,)) 77 # Generate image based of noise (fake sample) 78 fake_img = self.generator(z_disc) 79 80 # Discriminator determines validity of the real and fake images 81 fake = self.critic(fake_img) 82 valid = self.critic(real_img) 83 84 # Construct weighted average between real and fake images 85 interpolated_img = RandomWeightedAverage()([real_img, fake_img]) 86 # Determine validity of weighted sample 87 validity_interpolated = self.critic(interpolated_img) 88 89 # Use Python partial to provide loss function with additional 90 # 'averaged_samples' argument 91 partial_gp_loss = partial(self.gradient_penalty_loss, 92 averaged_samples=interpolated_img) 93 partial_gp_loss.__name__ = 'gradient_penalty' # Keras requires function names 94 95 self.critic_model = Model(inputs=[real_img, z_disc], 96 outputs=[valid, fake, validity_interpolated]) 97 self.critic_model.compile(loss=[self.wasserstein_loss, 98 self.wasserstein_loss, 99 partial_gp_loss], 100 optimizer=optimizer, 101 loss_weights=[1, 1, 10]) 102 #------------------------------- 103 # Construct Computational Graph 104 # for Generator 105 #------------------------------- 106 107 # For the generator we freeze the critic's layers 108 self.critic.trainable = False 109 self.generator.trainable = True 110 111 # Sampled noise for input to generator 112 z_gen = Input(shape=(100,)) 113 # Generate images based of noise 114 img = self.generator(z_gen) 115 # Discriminator determines validity 116 valid = self.critic(img) 117 # Defines generator model 118 self.generator_model = Model(z_gen, valid) 119 self.generator_model.compile(loss=self.wasserstein_loss, optimizer=optimizer) 120 121 122 def build_generator(self): 123 124 model = Sequential() 125 126 model.add(Dense(128 * 64 * 64, activation="relu", input_dim=self.latent_dim)) 127 model.add(Reshape((64, 64, 128))) 128 model.add(UpSampling2D()) 129 model.add(Conv2D(128, kernel_size=4, padding="same")) 130 model.add(BatchNormalization(momentum=0.8)) 131 model.add(Activation("relu")) 132 model.add(UpSampling2D()) 133 model.add(Conv2D(64, kernel_size=4, padding="same")) 134 model.add(BatchNormalization(momentum=0.8)) 135 model.add(Activation("relu")) 136 model.add(Conv2D(self.channels, kernel_size=4, padding="same")) 137 model.add(Activation("tanh")) 138 139 model.summary() 140 141 noise = Input(shape=(self.latent_dim,)) 142 img = model(noise) 143 144 return Model(noise, img) 145 146 147 148 def train(self, epochs, batch_size, sample_interval=50): 149 150 151 # Adversarial ground truths 152 valid = -np.ones((batch_size, 1)) 153 fake = np.ones((batch_size, 1)) 154 dummy = np.zeros((batch_size, 1)) # Dummy gt for gradient penalty 155 for epoch in range(epochs): 156 157 for _ in range(self.n_critic): 158 159 # --------------------- 160 # Train Discriminator 161 # --------------------- 162 163 # Select a random batch of images 164 X_train = self.x_train_batch 165 X_train = ((np.array(X_train, dtype=np.float32) - 127.5) / 127.5) 166 imgs = X_train 167 # Sample generator input 168 noise = np.random.normal(0, 1, (batch_size, self.latent_dim)) 169 # Train the critic 170 d_loss = self.critic_model.train_on_batch([imgs, noise], 171 [valid, fake, dummy]) 172 173 # --------------------- 174 # Train Generator 175 # --------------------- 176 177 g_loss = self.generator_model.train_on_batch(noise, valid) 178 179 # Plot the progress 180 print ("%d [D loss: %f] [G loss: %f]" % (epoch, d_loss[0], g_loss)) 181 182 # If at save interval => save generated image samples 183 if epoch % sample_interval == 0: 184 self.sample_images(epoch) 185 186 187

試したこと
色々と 数値を変えて試してみましたが 解決に至りませんでした

どなたか わかる方 ご教示 よろしく お願い致します

追記:

画像は 256 x 256 のカラー画像です (538枚)

他 ^C が表示されたコード

Python

1class RandomWeightedAverage(_Merge): 2 """Provides a (random) weighted average between real and generated image samples""" 3 def _merge_function(self, inputs): 4 alpha = K.random_uniform((256, 1, 1, 1)) 5 return (alpha * inputs[0]) + ((1 - alpha) * inputs[1])

Python

1def build_critic(self): 2 3 model = Sequential() 4 5 model.add(Conv2D(256, kernel_size=3, strides=2, input_shape=self.img_shape, padding="same")) 6 model.add(LeakyReLU(alpha=0.2)) 7 model.add(Dropout(0.25)) 8 model.add(Conv2D(512, kernel_size=3, strides=2, padding="same")) 9 model.add(ZeroPadding2D(padding=((0,1),(0,1)))) 10 model.add(BatchNormalization(momentum=0.8)) 11 model.add(LeakyReLU(alpha=0.2)) 12 model.add(Dropout(0.25)) 13 model.add(Conv2D(1024, kernel_size=3, strides=2, padding="same")) 14 model.add(BatchNormalization(momentum=0.8)) 15 model.add(LeakyReLU(alpha=0.2)) 16 model.add(Dropout(0.25)) 17 model.add(Conv2D(2048, kernel_size=3, strides=1, padding="same")) 18 model.add(BatchNormalization(momentum=0.8)) 19 model.add(LeakyReLU(alpha=0.2)) 20 model.add(Dropout(0.25)) 21 model.add(Flatten()) 22 model.add(Dense(1)) 23 24 model.summary() 25 26 img = Input(shape=self.img_shape) 27 validity = model(img) 28 29 return Model(img, validity) 30

気になる質問をクリップする

クリップした質問は、後からいつでもMYページで確認できます。

またクリップした質問に回答があった際、通知やメールを受け取ることができます。

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

Q71

2019/02/10 12:22

読ませようとした画像はどんなものでしょうか。1×1のグレースケール? 色々いじったを、具体的に示してください。
guest

回答1

0

自己解決

こちら
コードを 手直しして 目的が達成できました

学ぶことは 多いです

投稿2019/02/15 15:02

Reach

総合スコア733

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

あなたの回答

tips

太字

斜体

打ち消し線

見出し

引用テキストの挿入

コードの挿入

リンクの挿入

リストの挿入

番号リストの挿入

表の挿入

水平線の挿入

プレビュー

15分調べてもわからないことは
teratailで質問しよう!

ただいまの回答率
85.48%

質問をまとめることで
思考を整理して素早く解決

テンプレート機能で
簡単に質問をまとめる

質問する

関連した質問