
Question edit history

1

Added code

2020/05/24 05:33 Posted

tomoki_fab

Score 25

title CHANGED
File without changes
body CHANGED
@@ -46,4 +46,192 @@
  - I can build things like image classifiers without any problems
 
  I have searched the internet exhaustively, looked through books for anything that might help, and done everything I can think of, but I cannot figure out why this does not work at all. I have no idea whether the problem is how the layers are stacked, whether I am skipping some decisively important processing step, or whether the way I evaluate things, such as the loss function, is wrong.
- Is there anything I might be overlooking?
+ Is there anything I might be overlooking?
+
+
+ Code
+ ```Python
+ import numpy as np
+ from PIL import Image
+
+ from keras.datasets import mnist
+ from keras.layers import *
+ from keras.models import *
+ from keras.optimizers import *
+
+
+ # --- Class that displays a progress bar (not related to the training itself) ---------
+ class ProgressBar:
+     def __init__(self, entireJob):
+         self.job = entireJob
+         self.width = 40
+     def draw(self, progress):
+         print( ("\r["+"#"*int((progress+1)*self.width/self.job)+" "*(self.width-int((progress+1)*self.width/self.job) ) +"] %d/%d")%(progress+1,self.job), end="")
+
+
+ # --- Generator model definition -----------
+ class Generator:
+     def __init__(self):
+         layer0 = Input(shape=(1,1,100))
+
+         layer1 = UpSampling2D(size=(3,3))(layer0)
+         layer1 = Conv2D(
+             filters=100,
+             kernel_size=(2,2),
+             strides=(1,1),
+             padding='same',
+             activation='relu' )(layer1)
+         layer1 = BatchNormalization()(layer1)
+
+         layer2 = UpSampling2D(size=(3,3))(layer1)
+         layer2 = Conv2D(
+             filters=100,
+             kernel_size=(2,2),
+             strides=(1,1),
+             padding='same',
+             activation='relu' )(layer2)
+         layer2 = BatchNormalization()(layer2)
+
+         layer3 = UpSampling2D(size=(2,2))(layer2)
+         layer3 = Conv2D(
+             filters=80,
+             kernel_size=(3,3),
+             strides=(1,1),
+             padding='valid',
+             activation='elu' )(layer3)
+         layer3 = BatchNormalization()(layer3)
+
+         layer4 = UpSampling2D(size=(2,2))(layer3)
+         layer4 = Conv2D(
+             filters=50,
+             kernel_size=(3,3),
+             strides=(1,1),
+             padding='same',
+             activation='elu' )(layer4)
+         layer4 = BatchNormalization()(layer4)
+
+         layer5 = UpSampling2D(size=(2,2))(layer4)
+         layer5 = Conv2D(
+             filters=20,
+             kernel_size=(4,4),
+             strides=(2,2),
+             padding='valid',
+             activation='elu' )(layer5)
+         layer5 = BatchNormalization()(layer5)
+
+         layer6 = Conv2D(
+             filters=1,
+             kernel_size=(4,4),
+             strides=(1,1),
+             padding='valid',
+             activation='tanh' )(layer5)  # output: 28x28x1, tanh keeps values in [-1, 1]
+
+         self.model = Model(layer0, layer6)
+         self.model.summary()
+
+ # --- Discriminator model definition -------
+ class Discriminator:
+     def __init__(self):
+         layer0 = Input(shape=(28,28,1))
+         layer1 = Conv2D(
+             filters=5,
+             kernel_size=(3,3),
+             strides=(2,2),
+             padding='valid',
+             activation='elu' )(layer0)
+         layer1 = BatchNormalization()(layer1)
+
+         layer2 = Conv2D(
+             filters=10,
+             kernel_size=(3,3),
+             strides=(2,2),
+             padding='valid',
+             activation='elu' )(layer1)
+         layer2 = BatchNormalization()(layer2)
+
+         layer3 = Conv2D(
+             filters=5,
+             kernel_size=(3,3),
+             strides=(1,1),
+             padding='valid',
+             activation='relu' )(layer2)
+         layer3 = BatchNormalization()(layer3)
+
+         layer4 = Flatten()(layer3)
+         layer4 = Dense(units=30, activation='tanh')(layer4)
+         layer4 = BatchNormalization()(layer4)
+
+         layer5 = Dense(units=1, activation='sigmoid' )(layer4)  # probability that the input is a real image
+
+         self.model = Model(layer0, layer5)
+         self.model.summary()
+
+
+
+ class Main:
+     def __init__(self):
+         # --- Set up the Discriminator -----------------
+         self.discriminator = Discriminator().model
+         self.discriminator.compile(
+             optimizer=SGD(learning_rate=1e-4),
+             loss='binary_crossentropy',
+             metrics=['accuracy'] )
+
+         # --- Define the combined model that chains the Generator and the Discriminator ---
+         self.generator = Generator().model
+         z = Input(shape=(1,1,100))
+         img = self.generator(z)
+         self.discriminator.trainable = False  # keep the Discriminator frozen inside the combined model
+         valid = self.discriminator(img)
+         self.combined = Model(z, valid)
+         self.combined.compile(
+             optimizer=Adam(learning_rate=1e-6),
+             loss='binary_crossentropy',
+             metrics=['accuracy'] )
+
+         # --- Prepare the MNIST dataset ---------------
+         (x_train, t_train), (x_test, t_test) = mnist.load_data()
+         x_train = x_train.reshape(60000, 28, 28, 1)
+         x_test = x_test.reshape(10000, 28, 28, 1)
+         self.x_train = x_train.astype('float32')
+         self.x_test = x_test.astype('float32')
+
+     # --- Training -------------------------------------
+     def _train(self, iteration, batch_size):
+         progress = ProgressBar(iteration)  # set up the progress bar
+         for i in range(iteration):
+             z = np.random.uniform(-1,1,(batch_size//2,1,1,100))  # generate noise vectors
+             f_img = self.generator.predict(z)  # generate f_img (fake images)
+             r_img = self.x_train[np.random.randint(0, 60000, batch_size//2)]  # draw r_img (real images)
+             loss_d, acc_d = self.discriminator.train_on_batch(f_img, np.zeros((batch_size//2,1)))  # train the Discriminator on fakes
+             loss_d_, acc_d_ = self.discriminator.train_on_batch(r_img, np.ones( (batch_size//2,1)))  # acc_d = the Discriminator's accuracy
+             acc_d += acc_d_
+
+             z = np.random.uniform(-1,1,(batch_size,1,1,100))  # generate noise vectors
+             loss_g, acc_g = self.combined.train_on_batch(z, np.ones((batch_size,1)))  # train the Generator
+             progress.draw(i)  # draw the progress bar
+             print(" Accuracy=(%f,%f)"%(acc_g, acc_d/2), end="")
+
+     def train(self, iteration, batch_size, epoch):
+         for i in range(epoch):
+             print("Epoch %d/%d\n"%(i+1, epoch))
+             self._train(iteration, batch_size)  # repeat _train() epoch times
+
+     # --- Generate a single image to check the result once training has finished -------
+     def create_image(self):
+         z = np.random.uniform(-1,1,(1,1,1,100))
+         img = self.generator.predict(z)
+         return img.reshape(1,28,28)
+
+
+ if __name__ == "__main__":
+     main = Main()
+     main.train(iteration=1875, batch_size=32, epoch=1)
+
+     # --- Display the image -----------------------
+     img = main.create_image()
+     img = Image.fromarray(np.uint8(img.reshape(28,28) * 255))
+     img.show()
+     img.save("gan_generated_img.png")
+
+ ```
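
For reference, here is a minimal, self-contained sketch of the alternating training scheme that the code above follows (fake batches labelled 0, real batches labelled 1, and the generator trained through a frozen discriminator toward the label 1), shrunk to toy dense models so the data flow is easy to see. The helper names and layer sizes below are illustrative only and are not taken from the question's code.

```Python
# Minimal sketch of the alternating GAN update, assuming the same Keras
# functional API as above. Sizes and names here are toy values for illustration.
import numpy as np
from keras.layers import Input, Dense, Flatten, Reshape
from keras.models import Model
from keras.optimizers import Adam

def build_toy_generator(latent_dim=100):
    z = Input(shape=(latent_dim,))
    x = Dense(128, activation='relu')(z)
    x = Dense(28 * 28, activation='tanh')(x)   # tanh -> pixel values in [-1, 1]
    return Model(z, Reshape((28, 28, 1))(x))

def build_toy_discriminator():
    img = Input(shape=(28, 28, 1))
    x = Flatten()(img)
    x = Dense(128, activation='relu')(x)
    return Model(img, Dense(1, activation='sigmoid')(x))

generator = build_toy_generator()
discriminator = build_toy_discriminator()
discriminator.compile(optimizer=Adam(1e-4), loss='binary_crossentropy',
                      metrics=['accuracy'])

discriminator.trainable = False                 # frozen inside the combined model only
z = Input(shape=(100,))
combined = Model(z, discriminator(generator(z)))
combined.compile(optimizer=Adam(1e-4), loss='binary_crossentropy',
                 metrics=['accuracy'])

def train_step(real_images, batch_size=32):
    """One alternating update: discriminator on half fake / half real, then generator."""
    half = batch_size // 2
    noise = np.random.uniform(-1, 1, (half, 100))
    fake_images = generator.predict(noise)
    discriminator.train_on_batch(fake_images, np.zeros((half, 1)))        # fakes -> label 0
    discriminator.train_on_batch(real_images[:half], np.ones((half, 1)))  # reals -> label 1
    noise = np.random.uniform(-1, 1, (batch_size, 100))
    # Generator update: try to make the frozen discriminator output "real" (1)
    return combined.train_on_batch(noise, np.ones((batch_size, 1)))

# Example call with random stand-in data in the same [-1, 1] range as the tanh output
dummy_real = np.random.uniform(-1, 1, (32, 28, 28, 1)).astype('float32')
train_step(dummy_real)
```

In this sketch the real images are assumed to already be floats in the same [-1, 1] range that the generator's tanh output produces; that scaling choice is part of the sketch, not of the code above.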