質問をすることでしか得られない、回答やアドバイスがある。

15分調べてもわからないことは、質問しよう!

新規登録して質問してみよう
ただいま回答率
85.48%
Keras

Kerasは、TheanoやTensorFlow/CNTK対応のラッパーライブラリです。DeepLearningの数学的部分を短いコードでネットワークとして表現することが可能。DeepLearningの最新手法を迅速に試すことができます。

Python 3.x

Python 3はPythonプログラミング言語の最新バージョンであり、2008年12月3日にリリースされました。

Q&A

1回答

1313閲覧

DCGANの実装について

naoki_9936

総合スコア12

Keras

Kerasは、TheanoやTensorFlow/CNTK対応のラッパーライブラリです。DeepLearningの数学的部分を短いコードでネットワークとして表現することが可能。DeepLearningの最新手法を迅速に試すことができます。

Python 3.x

Python 3はPythonプログラミング言語の最新バージョンであり、2008年12月3日にリリースされました。

0グッド

0クリップ

投稿2019/05/31 13:43

前提・実現したいこと

とある雑誌を参考にkerasでDCGANを実装しようとしています。
が実行しても、以下のメッセージがでるのみで終了してしまいます。
どなたかご教授の方、お願いいたします。

発生している問題・エラーメッセージ

Using TensorFlow backend.

該当のソースコード

python

1# -*- coding: shift_jis -*- 2 3# (a)必要なライブラリのインポート 4from keras.models import Sequential, Model 5from keras.layers.convolutional import Conv2D 6from keras.layers import BatchNormalization 7from keras.layers.advanced_activations import LeakyReLU 8from keras.layers import Flatten, Dense, Input, Reshape, Input 9from keras.optimizers import Adam 10from keras.layers.convolutional import Conv2DTranspose 11from keras.preprocessing.image import ImageDataGenerator 12import numpy as np 13import glob, os, sys 14import cv2 15import datetime as dt 16import pickle as pkl 17from keras.utils.vis_utils import plot_model 18import matplotlib.pyplot as plt 19 20REAL_IMAGE_PATH = '' # 本物画像のフォルダー 21FAKE_IMAGE_PATH = '' # 生成画像の保存先 22BATCH_SIZE = 100 23ITERATION_MAX = 30000 # イテレーションの上限 24LR = 0.0002 25ALPHA = 0.2 26BETA_1 = 0.5 27 28# (b)Generator(生成 モデル)を作る関数 の定義 29# 生成画像は32×32ピクセルのサイズ 30def Generator(): 31# SequentialモデルでGeneratorを定義 32 model = Sequential() 33 34 #1層目:入力層 35 model.add(Dense(4*4*512, input_shape=(100, ))) 36 model.add(Reshape((4, 4, 512))) 37 model.add(BatchNormalization()) 38 model.add(LeakyReLU(alpha=ALPHA)) 39 40 #2層目:逆畳み込み層1 41 model.add( 42 Conv2DTranspose( 43 filters = 256, 44 kernel_size = 5, 45 strides = 2, 46 padding = 'same' 47 ) 48 ) 49 model.add(BatchNormalization()) 50 model.add(LeakyReLU(alpha = ALPHA)) 51 52 # 3層目:逆畳み込み層2 53 model.add( 54 Conv2DTranspose( 55 filters = 128, 56 kernel_size = 5, 57 strides = 2, 58 padding = 'same' 59 ) 60 ) 61 model.add(BatchNormalization()) 62 model.add(LeakyReLU(alpha = ALPHA)) 63 64 # 4層目:逆畳み込み層3 65 model.add( 66 Conv2DTranspose( 67 filters = 3, 68 kernel_size = 5, 69 strides = 2, 70 padding = 'same', 71 activation = 'tanh' 72 ) 73 ) 74 return model 75 76# (c)Discriminator(識別モデルを作る関数の定義 77# 論文に従い、プーリング層は排除 78def Discriminator(): 79 # SequentialモデルでDiscriminatorを定義 80 model = Sequential() 81 82 # 1層目 83 model.add( 84 Conv2D( 85 filters = 32, 86 kernel_size = 5, 87 strides = 2, 88 input_shape = (32, 32, 3), 89 padding = 'same', 90 ) 91 ) 92 model.add(BatchNormalization()) 93 model.add(LeakyReLU(alpha = ALPHA)) 94 95 # 2層目 96 model.add( 97 Conv2D( 98 filters = 64, 99 kernel_size = 5, 100 strides = 2, 101 padding = 'same', 102 ) 103 ) 104 model.add(BatchNormalization()) 105 model.add(LeakyReLU(alpha = ALPHA)) 106 107 # 3層目 108 model.add( 109 Conv2D( 110 filters = 128, 111 kernel_size = 5, 112 strides = 2, 113 padding = 'same', 114 ) 115 ) 116 model.add(BatchNormalization()) 117 model.add(LeakyReLU(alpha = ALPHA)) 118 model.add(Flatten()) 119 model.add(Dense(1, activation = 'sigmoid')) 120 121 return model 122 123# GeneratorとDiscriminatorを連結して 124# Generatorトレーニング用のモデルを定義 125def Combined(generator, discriminator): 126 # Generatorの学習のみ行いたいので 127 # Disctiminatorのパラメータ更新は行わない 128 discriminator.trainable = False 129 130 # GeneratorとDiscriminatorを結合 131 model = Sequential([generator, discriminator]) 132 133 return model 134 135#(d)本物画像を読み込む関数 の定義 136def loadRealImages(): 137 img_paths = [] 138 images = [] 139 imageFiles = glob.glob(os.path.join(REAL_IMAGE_PATH, '*.jpg')) 140 141 for file in imageFiles: 142 img = cv2.imread(file) 143 img = cv2.resize(img, (32, 32)) 144 img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 145 images.append(img) 146 147 images = np.array(images) 148 149 return np.array(images) 150 151# (e)生成画像を保存する関数の定義 152def saveGeneratImage(iteration, images): 153 #Generatorで1度に25枚の画像を生成するため、表示エリアに5×5の 154 #パネルを準備する 155 fig, axes = plt.subplots(5, 5) 156 images = generator.predict(noise) 157 158 # 0~1にスケールを揃える 159 images = 0.5 * images + 0.5 160 161 #zipで表示する画像(save_gimage)と表示位置(axes.flatten)を 162 #対で取得し、順に表示(imshow)する 163 for img, ax in zip(images, axes.flatten()): 164 ax.imshow(img) 165 ax.axis('off') 166 167 if os.path.exists(FAKE_IMAGE_PATH) == False: 168 os.mkdir(FAKE_IMAGE_PATH) 169 fname = FAKE_IMAGE_PATH+'/generate_%05d.png' % iteration 170 fig.savefig(fname) 171 print() 172 orint('output: {}'.format(fname), end='',flush=True) 173 save_gimage.append(images) 174 175 plt.close() 176 177if __name__ == '__main=__': 178 args = sys.argv 179 180 if len(args) == 1: 181 print('生成画像のキーワードを指定してください。') 182 sys.exit() 183 184 REAL_IMAGE_PATH = args[1] 185 # 本物画像のフォルダーをチェック。無ければ終了 186 # 存在する場合 はモデル構築、トレーニングを行う 187 if os.path.exists(REAL_IMAGE_PATH) == False: 188 print('本物画像のフォルダーがありません。') 189 sys.exit() 190 191 # (f)モデル構築と学習の実行 192 # Generator単独 でのCompileはしない。Combinedモデルに 193 # 入れてからCompileを行う 194 generator = Generator() 195 196 # Discriminator 197 discriminator = Discriminator() 198 discriminator.compile( 199 # 損失関数は2値交差エントロピー 200 loss = 'binary_crossentropy', 201 # 最適化アルゴリズムとしてAdamを指定 202 optimizer = Adam(lr=LR, beta_1=BETA_1), 203 # 評価関数 204 metrics = ['accuracy'] 205 ) 206 207 # GeneratorとDisciminatorを連結したモデル 208 combined = Combined(generator, discriminator) 209 combined.compile( 210 # 損失関数は2値交差エントロピー 211 loss = 'binary_crossentropy', 212 # 最適化アルゴリズムとしてAdamを指定 213 optimizer = Adam(lr=LR, beta_1=BETA_1) 214 ) 215 216 # 処理開始時刻の取得 217 tstamp_s = dt.datetime.now().strftime("%H:%M:%S") 218 219 checkPointNoise = np.random.uniform(-1, 1, (25, 100)) 220 checkPoint = 1000 221 222 # 本物画像 223 realImages = loadRealImages() 224 225 # 本物画像を正規化(0 to 255 >> -1 to 1) 226 realImages = (realImages.astype(np.float32) - 127.5) / 127.5 227 228 # 途中経過を保存する変数の定義 229 save_gimage = [] 230 save_loss = [] 231 try: 232 # 1回のイテレーションで100個の画像を取ってくることにする 233 for iteration in range(ITERATION_MAX): 234 # Discriminatorトレーニング 235 idx = np.random.randint(0, len(realImages), BATCH_SIZE) 236 # 本物画像の取得 237 rimage = realImages[idx] 238 noise = np.random.uniform(-1, 1, size=(BATCH_SIZE, 100)) 239 gimage = generator.predict(noise) 240 rimage = rimage.reshape(BATCH_SIZE, 32, 32, 3) 241 d_loss_real = discriminator.train_on_batch(rimage, np.ones((BATCH_SIZE, 1))) 242 d_loss_fake = discriminator.train_on_batch(gimage, np.zeros((BATCH_SIZE, 1))) 243 d_loss = 0.5 * np.add(d_loss_real, d_loss_fake) 244 noise = np.random.uniform(-1, 1, (BATCH_SIZE, 100)) 245 g_loss = combined.train_on_batch(noise, np.ones((BATCH_SIZE, 1))) 246 247 # 1000イテレーションごとに生成画像を出力 248 if iteration % checkPoint == 0: 249 saveGeneratImage(iteration, checkPointNoise) 250 251 # 10イテレーションごとに損失関数の値を表示 252 if iteration % 10 == 0: 253 print('') 254 print('{0} Iteration={1}/{2}, DLoss={3:.4F}, GLoss={4:.4F}'.format( 255 dt.datetime.now().strftime("%H:%M:%S"), iteration, 256 ITERATION_MAX, d_loss[0], g_loss), end='', flush=True) 257 save_loss.append((d_loss[0], g_loss)) 258 else: 259 print('.', end='', flush=True) 260 261 # 学習完了後のモデルで画像を生成する 262 saveGeneratImage(iteration, checkPointNoise) 263 # pkl形式 でも生成画像を保存 264 with open('save_dcgan_image.pkl', 'wb') as f: 265 pkl.dump(save_gimage, f) 266 # pkl形式 で損失関数の値を保存 267 with open('save_dcgan_loss.pkl', 'wb') as f: 268 pkl.dump(save_loss, f) 269 270 # 処理終了時刻の取得 271 tstamp_e = dt.datetime.now().strftime("%H:%M:%S") 272 time1 = dt.datetime.strptime(tstamp_s, "%H:%M:%S") 273 time2 = dt.datetime.strptime(tstamp_e, "%H:%M:%S") 274 275 # 処理時間を表示 276 print('') 277 print("開始: {0}、終了:{1}、処理時間:{2}".format(tstamp_s, tstamp_e, (time2 - time1))) 278 279 except KeyboardInterrupt: 280 print('') 281 print('強制終了しました。') 282 283 with open('save_dcgan_image.pkl', 'wb') as f: 284 pkl.dump(save_gimage, f) 285 with open('save_dcgan_loss.pkl', 'wb') as f: 286 pkl.dump(save_loss, f)

気になる質問をクリップする

クリップした質問は、後からいつでもMYページで確認できます。

またクリップした質問に回答があった際、通知やメールを受け取ることができます。

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

guest

回答1

0

以下のコードが誤字になっています。

if __name__ == '__main=__':#正しくは'__main__'

誤字のために、main部分が実行されずに飛ばされているんだと思います。

投稿2019/06/03 02:27

amahara_waya

総合スコア1029

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

あなたの回答

tips

太字

斜体

打ち消し線

見出し

引用テキストの挿入

コードの挿入

リンクの挿入

リストの挿入

番号リストの挿入

表の挿入

水平線の挿入

プレビュー

まだベストアンサーが選ばれていません

会員登録して回答してみよう

アカウントをお持ちの方は

15分調べてもわからないことは
teratailで質問しよう!

ただいまの回答率
85.48%

質問をまとめることで
思考を整理して素早く解決

テンプレート機能で
簡単に質問をまとめる

質問する

関連した質問