【Python】【Keras】【GAN】指定ディレクトリ下の画像の学習ができない
- 評価
- クリップ 0
- VIEW 1,128
お世話になります
機械(深層)学習の超初心者です
現在 WGAN-GPで カラー画像を (Colab上で) 学習させようとチャレンジしております
ここのコードを使っております
画像は読み込んでいるようなのですが 以下のエラーが出て 実行できません
Traceback (most recent call last):
File "wgan_gp_3.py", line 270, in <module>
wgan.train(epochs=30000, batch_size=8, sample_interval=100)
File "wgan_gp_3.py", line 234, in train
[valid, fake, dummy])
File "/usr/local/lib/python3.6/dist-packages/keras/engine/training.py", line 1217, in train_on_batch
outputs = self.train_function(ins)
File "/usr/local/lib/python3.6/dist-packages/keras/backend/tensorflow_backend.py", line 2715, in __call__
return self._call(inputs)
File "/usr/local/lib/python3.6/dist-packages/keras/backend/tensorflow_backend.py", line 2675, in _call
fetched = self._callable_fn(*array_vals)
File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/client/session.py", line 1439, in __call__
run_metadata_ptr)
File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/errors_impl.py", line 528, in __exit__
c_api.TF_GetCode(self.status.status))
tensorflow.python.framework.errors_impl.InvalidArgumentError: Incompatible shapes: [32,1,1,1] vs. [8,3,256,256]
[[{{node random_weighted_average_1/mul_1}}]]
[[{{node loss/model_2_loss/Mean_3}}]]
変更部分コード
from __future__ import print_function, division
from keras.datasets import mnist
from keras.layers.merge import _Merge
from keras.layers import Input, Dense, Reshape, Flatten, Dropout
from keras.layers import BatchNormalization, Activation, ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import RMSprop
from functools import partial
from keras.preprocessing.image import load_img, img_to_array,save_img , array_to_img
from tensorflow.python.keras.preprocessing.image import ImageDataGenerator
import keras.backend as K
import matplotlib.pyplot as plt
import glob
import sys
import numpy as np
class WGANGP():
def __init__(self):
self.img_rows = 256
self.img_cols = 256
self.channels = 3
self.img_shape = (self.img_rows, self.img_cols, self.channels)
self.latent_dim = 100
data_dir = '/content/drive/My Drive/images/'
batch_size = 8
gen = ImageDataGenerator(rescale=1/127.5, samplewise_center=True)
iters = gen.flow_from_directory(
directory=data_dir,
classes=['test'],
class_mode=None,
color_mode='rgb',
target_size=self.img_shape[:2],
batch_size=batch_size,
shuffle=True
)
self.x_train_batch=next(iters)
# Following parameter and optimizer set as recommended in paper
self.n_critic = 5
optimizer = RMSprop(lr=0.00005)
# Build the generator and critic
self.generator = self.build_generator()
self.critic = self.build_critic()
#-------------------------------
# Construct Computational Graph
# for the Critic
#-------------------------------
# Freeze generator's layers while training critic
self.generator.trainable = False
# Image input (real sample)
real_img = Input(shape=self.img_shape)
# Noise input
z_disc = Input(shape=(self.latent_dim,))
# Generate image based of noise (fake sample)
fake_img = self.generator(z_disc)
# Discriminator determines validity of the real and fake images
fake = self.critic(fake_img)
valid = self.critic(real_img)
# Construct weighted average between real and fake images
interpolated_img = RandomWeightedAverage()([real_img, fake_img])
# Determine validity of weighted sample
validity_interpolated = self.critic(interpolated_img)
# Use Python partial to provide loss function with additional
# 'averaged_samples' argument
partial_gp_loss = partial(self.gradient_penalty_loss,
averaged_samples=interpolated_img)
partial_gp_loss.__name__ = 'gradient_penalty' # Keras requires function names
self.critic_model = Model(inputs=[real_img, z_disc],
outputs=[valid, fake, validity_interpolated])
self.critic_model.compile(loss=[self.wasserstein_loss,
self.wasserstein_loss,
partial_gp_loss],
optimizer=optimizer,
loss_weights=[1, 1, 10])
#-------------------------------
# Construct Computational Graph
# for Generator
#-------------------------------
# For the generator we freeze the critic's layers
self.critic.trainable = False
self.generator.trainable = True
# Sampled noise for input to generator
z_gen = Input(shape=(100,))
# Generate images based of noise
img = self.generator(z_gen)
# Discriminator determines validity
valid = self.critic(img)
# Defines generator model
self.generator_model = Model(z_gen, valid)
self.generator_model.compile(loss=self.wasserstein_loss, optimizer=optimizer)
def build_generator(self):
model = Sequential()
model.add(Dense(128 * 64 * 64, activation="relu", input_dim=self.latent_dim))
model.add(Reshape((64, 64, 128)))
model.add(UpSampling2D())
model.add(Conv2D(128, kernel_size=4, padding="same"))
model.add(BatchNormalization(momentum=0.8))
model.add(Activation("relu"))
model.add(UpSampling2D())
model.add(Conv2D(64, kernel_size=4, padding="same"))
model.add(BatchNormalization(momentum=0.8))
model.add(Activation("relu"))
model.add(Conv2D(self.channels, kernel_size=4, padding="same"))
model.add(Activation("tanh"))
model.summary()
noise = Input(shape=(self.latent_dim,))
img = model(noise)
return Model(noise, img)
def train(self, epochs, batch_size, sample_interval=50):
# Adversarial ground truths
valid = -np.ones((batch_size, 1))
fake = np.ones((batch_size, 1))
dummy = np.zeros((batch_size, 1)) # Dummy gt for gradient penalty
for epoch in range(epochs):
for _ in range(self.n_critic):
# ---------------------
# Train Discriminator
# ---------------------
# Select a random batch of images
X_train = self.x_train_batch
X_train = ((np.array(X_train, dtype=np.float32) - 127.5) / 127.5)
imgs = X_train
# Sample generator input
noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
# Train the critic
d_loss = self.critic_model.train_on_batch([imgs, noise],
[valid, fake, dummy])
# ---------------------
# Train Generator
# ---------------------
g_loss = self.generator_model.train_on_batch(noise, valid)
# Plot the progress
print ("%d [D loss: %f] [G loss: %f]" % (epoch, d_loss[0], g_loss))
# If at save interval => save generated image samples
if epoch % sample_interval == 0:
self.sample_images(epoch)
試したこと
色々と 数値を変えて試してみましたが 解決に至りませんでした
どなたか わかる方 ご教示 よろしく お願い致します
追記:
画像は 256 x 256 のカラー画像です (538枚)
他 ^C が表示されたコード
class RandomWeightedAverage(_Merge):
"""Provides a (random) weighted average between real and generated image samples"""
def _merge_function(self, inputs):
alpha = K.random_uniform((256, 1, 1, 1))
return (alpha * inputs[0]) + ((1 - alpha) * inputs[1])
def build_critic(self):
model = Sequential()
model.add(Conv2D(256, kernel_size=3, strides=2, input_shape=self.img_shape, padding="same"))
model.add(LeakyReLU(alpha=0.2))
model.add(Dropout(0.25))
model.add(Conv2D(512, kernel_size=3, strides=2, padding="same"))
model.add(ZeroPadding2D(padding=((0,1),(0,1))))
model.add(BatchNormalization(momentum=0.8))
model.add(LeakyReLU(alpha=0.2))
model.add(Dropout(0.25))
model.add(Conv2D(1024, kernel_size=3, strides=2, padding="same"))
model.add(BatchNormalization(momentum=0.8))
model.add(LeakyReLU(alpha=0.2))
model.add(Dropout(0.25))
model.add(Conv2D(2048, kernel_size=3, strides=1, padding="same"))
model.add(BatchNormalization(momentum=0.8))
model.add(LeakyReLU(alpha=0.2))
model.add(Dropout(0.25))
model.add(Flatten())
model.add(Dense(1))
model.summary()
img = Input(shape=self.img_shape)
validity = model(img)
return Model(img, validity)
-
気になる質問をクリップする
クリップした質問は、後からいつでもマイページで確認できます。
またクリップした質問に回答があった際、通知やメールを受け取ることができます。
クリップを取り消します
-
良い質問の評価を上げる
以下のような質問は評価を上げましょう
- 質問内容が明確
- 自分も答えを知りたい
- 質問者以外のユーザにも役立つ
評価が高い質問は、TOPページの「注目」タブのフィードに表示されやすくなります。
質問の評価を上げたことを取り消します
-
評価を下げられる数の上限に達しました
評価を下げることができません
- 1日5回まで評価を下げられます
- 1日に1ユーザに対して2回まで評価を下げられます
質問の評価を下げる
teratailでは下記のような質問を「具体的に困っていることがない質問」、「サイトポリシーに違反する質問」と定義し、推奨していません。
- プログラミングに関係のない質問
- やってほしいことだけを記載した丸投げの質問
- 問題・課題が含まれていない質問
- 意図的に内容が抹消された質問
- 過去に投稿した質問と同じ内容の質問
- 広告と受け取られるような投稿
評価が下がると、TOPページの「アクティブ」「注目」タブのフィードに表示されにくくなります。
質問の評価を下げたことを取り消します
この機能は開放されていません
評価を下げる条件を満たしてません
質問の評価を下げる機能の利用条件
この機能を利用するためには、以下の事項を行う必要があります。
- 質問回答など一定の行動
-
メールアドレスの認証
メールアドレスの認証
-
質問評価に関するヘルプページの閲覧
質問評価に関するヘルプページの閲覧
15分調べてもわからないことは、teratailで質問しよう!
- ただいまの回答率 88.20%
- 質問をまとめることで、思考を整理して素早く解決
- テンプレート機能で、簡単に質問をまとめられる
質問への追記・修正、ベストアンサー選択の依頼
Q71
2019/02/10 21:22
読ませようとした画像はどんなものでしょうか。1×1のグレースケール?
色々いじったを、具体的に示してください。