質問編集履歴
3
誤字の修正
test
CHANGED
File without changes
|
test
CHANGED
@@ -88,8 +88,6 @@
|
|
88
88
|
|
89
89
|
dim = K.int_shape(z_mean)[1]
|
90
90
|
|
91
|
-
# by default, random_normal has mean=0 and std=1.0
|
92
|
-
|
93
91
|
epsilon = K.random_normal(shape=(batch, dim))
|
94
92
|
|
95
93
|
return z_mean + K.exp(0.5 * z_log_var) * epsilon
|
2
誤字の修正
test
CHANGED
File without changes
|
test
CHANGED
@@ -112,8 +112,6 @@
|
|
112
112
|
|
113
113
|
encoder, decoder = models
|
114
114
|
|
115
|
-
#x_test, y_test = data
|
116
|
-
|
117
115
|
os.makedirs(model_name, exist_ok=True)
|
118
116
|
|
119
117
|
|
1
エラー内容の変更
test
CHANGED
@@ -1 +1 @@
|
|
1
|
-
Python3 ValueError
|
1
|
+
Python3 InvalidArgumentError (see above for traceback): Incompatible shapesのエラー解消方法について
|
test
CHANGED
@@ -2,9 +2,9 @@
|
|
2
2
|
|
3
3
|
|
4
4
|
|
5
|
-
KerasのVAEを自作データセットで実装中、以下のようなエラーが出ました。該当の部分を見ても、どう修正したら良いか分からず、
|
5
|
+
KerasのVAEを自作データセットで実装中、以下のようなエラーが出ました。該当の部分を見ても、どう修正したら良いか分からず、質問させて頂きました。
|
6
|
-
|
7
|
-
|
6
|
+
|
7
|
+
|
8
8
|
|
9
9
|
|
10
10
|
|
@@ -12,13 +12,233 @@
|
|
12
12
|
|
13
13
|
|
14
14
|
|
15
|
+
InvalidArgumentError Traceback (most recent call last)
|
16
|
+
|
17
|
+
|
18
|
+
|
15
|
-
|
19
|
+
InvalidArgumentError: Incompatible shapes: [12800] vs. [450]
|
16
|
-
|
17
|
-
|
20
|
+
|
21
|
+
|
22
|
+
|
23
|
+
|
18
24
|
|
19
25
|
### 該当のソースコード
|
20
26
|
|
27
|
+
from __future__ import absolute_import
|
28
|
+
|
29
|
+
from __future__ import division
|
30
|
+
|
31
|
+
from __future__ import print_function
|
32
|
+
|
33
|
+
|
34
|
+
|
35
|
+
from keras.layers import Dense, Input
|
36
|
+
|
37
|
+
from keras.layers import Conv2D, Flatten, Lambda
|
38
|
+
|
39
|
+
from keras.layers import Reshape, Conv2DTranspose
|
40
|
+
|
41
|
+
from keras.models import Model
|
42
|
+
|
43
|
+
from keras.losses import mse, binary_crossentropy
|
44
|
+
|
45
|
+
from keras.utils import plot_model
|
46
|
+
|
47
|
+
from keras import backend as K
|
48
|
+
|
49
|
+
from keras import optimizers
|
50
|
+
|
51
|
+
|
52
|
+
|
53
|
+
|
54
|
+
|
55
|
+
import numpy as np
|
56
|
+
|
57
|
+
import matplotlib.pyplot as plt
|
58
|
+
|
59
|
+
import argparse
|
60
|
+
|
61
|
+
import tensorflow as tf
|
62
|
+
|
63
|
+
import random as rn
|
64
|
+
|
65
|
+
import os
|
66
|
+
|
67
|
+
import easydict
|
68
|
+
|
69
|
+
|
70
|
+
|
71
|
+
import warnings
|
72
|
+
|
73
|
+
warnings.filterwarnings('ignore')
|
74
|
+
|
75
|
+
|
76
|
+
|
77
|
+
%matplotlib inline
|
78
|
+
|
79
|
+
|
80
|
+
|
81
|
+
def sampling(args):
|
82
|
+
|
83
|
+
|
84
|
+
|
85
|
+
z_mean, z_log_var = args
|
86
|
+
|
87
|
+
batch = K.shape(z_mean)[0]
|
88
|
+
|
89
|
+
dim = K.int_shape(z_mean)[1]
|
90
|
+
|
91
|
+
# by default, random_normal has mean=0 and std=1.0
|
92
|
+
|
93
|
+
epsilon = K.random_normal(shape=(batch, dim))
|
94
|
+
|
95
|
+
return z_mean + K.exp(0.5 * z_log_var) * epsilon
|
96
|
+
|
97
|
+
|
98
|
+
|
99
|
+
|
100
|
+
|
101
|
+
def plot_results(models,
|
102
|
+
|
103
|
+
data,
|
104
|
+
|
105
|
+
batch_size=50,
|
106
|
+
|
107
|
+
model_name="vae_mnist"):
|
108
|
+
|
109
|
+
|
110
|
+
|
111
|
+
|
112
|
+
|
21
|
-
|
113
|
+
encoder, decoder = models
|
114
|
+
|
115
|
+
#x_test, y_test = data
|
116
|
+
|
117
|
+
os.makedirs(model_name, exist_ok=True)
|
118
|
+
|
119
|
+
|
120
|
+
|
121
|
+
filename = os.path.join(model_name, "vae_mean.png")
|
122
|
+
|
123
|
+
|
124
|
+
|
125
|
+
z_mean, _, _ = encoder.predict(x_test,
|
126
|
+
|
127
|
+
batch_size=batch_size)
|
128
|
+
|
129
|
+
plt.figure(figsize=(12, 10))
|
130
|
+
|
131
|
+
plt.scatter(z_mean[:, 0], z_mean[:, 1])
|
132
|
+
|
133
|
+
plt.colorbar()
|
134
|
+
|
135
|
+
plt.xlabel("z[0]")
|
136
|
+
|
137
|
+
plt.ylabel("z[1]")
|
138
|
+
|
139
|
+
plt.savefig(filename)
|
140
|
+
|
141
|
+
plt.show()
|
142
|
+
|
143
|
+
|
144
|
+
|
145
|
+
filename = os.path.join(model_name, "digits_over_latent.png")
|
146
|
+
|
147
|
+
|
148
|
+
|
149
|
+
n = 30
|
150
|
+
|
151
|
+
digit_size = 224
|
152
|
+
|
153
|
+
figure = np.zeros((digit_size * n, digit_size * n))
|
154
|
+
|
155
|
+
|
156
|
+
|
157
|
+
grid_x = np.linspace(-4, 4, n)
|
158
|
+
|
159
|
+
grid_y = np.linspace(-4, 4, n)[::-1]
|
160
|
+
|
161
|
+
|
162
|
+
|
163
|
+
for i, yi in enumerate(grid_y):
|
164
|
+
|
165
|
+
for j, xi in enumerate(grid_x):
|
166
|
+
|
167
|
+
z_sample = np.array([[xi, yi]])
|
168
|
+
|
169
|
+
x_decoded = decoder.predict(z_sample)
|
170
|
+
|
171
|
+
digit = x_decoded[0].reshape(digit_size, digit_size)
|
172
|
+
|
173
|
+
figure[i * digit_size: (i + 1) * digit_size,
|
174
|
+
|
175
|
+
j * digit_size: (j + 1) * digit_size] = digit
|
176
|
+
|
177
|
+
|
178
|
+
|
179
|
+
plt.figure(figsize=(10, 10))
|
180
|
+
|
181
|
+
start_range = digit_size // 2
|
182
|
+
|
183
|
+
end_range = n * digit_size + start_range + 1
|
184
|
+
|
185
|
+
pixel_range = np.arange(start_range, end_range, digit_size)
|
186
|
+
|
187
|
+
sample_range_x = np.round(grid_x, 1)
|
188
|
+
|
189
|
+
sample_range_y = np.round(grid_y, 1)
|
190
|
+
|
191
|
+
plt.xticks(pixel_range, sample_range_x)
|
192
|
+
|
193
|
+
plt.yticks(pixel_range, sample_range_y)
|
194
|
+
|
195
|
+
plt.xlabel("z[0]")
|
196
|
+
|
197
|
+
plt.ylabel("z[1]")
|
198
|
+
|
199
|
+
plt.imshow(figure, cmap='Greys_r')
|
200
|
+
|
201
|
+
plt.savefig(filename)
|
202
|
+
|
203
|
+
plt.show()
|
204
|
+
|
205
|
+
|
206
|
+
|
207
|
+
|
208
|
+
|
209
|
+
x_train_tom = np.load('./folder_a.npy')
|
210
|
+
|
211
|
+
x_test_tom = np.load('./folder_b.npy')
|
212
|
+
|
213
|
+
|
214
|
+
|
215
|
+
image_size = x_train_tom.shape[1]
|
216
|
+
|
217
|
+
x_train = np.reshape(x_train_tom, [-1, image_size, image_size, 1])
|
218
|
+
|
219
|
+
x_test = np.reshape(x_test_tom, [-1, image_size, image_size, 1])
|
220
|
+
|
221
|
+
x_train = x_train.astype('float32') / 255
|
222
|
+
|
223
|
+
x_test = x_test.astype('float32') / 255
|
224
|
+
|
225
|
+
print(x_train.shape,x_test.shape)
|
226
|
+
|
227
|
+
|
228
|
+
|
229
|
+
input_shape = (image_size,image_size, 1)
|
230
|
+
|
231
|
+
batch_size = 50
|
232
|
+
|
233
|
+
kernel_size = 3
|
234
|
+
|
235
|
+
filters = 16
|
236
|
+
|
237
|
+
latent_dim = 2
|
238
|
+
|
239
|
+
epochs = 50
|
240
|
+
|
241
|
+
|
22
242
|
|
23
243
|
inputs = Input(shape=input_shape, name='encoder_input')
|
24
244
|
|
@@ -28,4 +248,248 @@
|
|
28
248
|
|
29
249
|
filters *= 2
|
30
250
|
|
251
|
+
x = Conv2D(filters=filters,
|
252
|
+
|
253
|
+
kernel_size=kernel_size,
|
254
|
+
|
255
|
+
activation='relu',
|
256
|
+
|
257
|
+
strides=2,
|
258
|
+
|
259
|
+
padding='same')(x)
|
260
|
+
|
261
|
+
|
262
|
+
|
263
|
+
shape = K.int_shape(x)
|
264
|
+
|
265
|
+
|
266
|
+
|
267
|
+
x = Flatten()(x)
|
268
|
+
|
269
|
+
x = Dense(64, activation='relu')(x)
|
270
|
+
|
271
|
+
z_mean = Dense(latent_dim, name='z_mean')(x)
|
272
|
+
|
273
|
+
z_log_var = Dense(latent_dim, name='z_log_var')(x)
|
274
|
+
|
275
|
+
|
276
|
+
|
277
|
+
z = Lambda(sampling, output_shape=(latent_dim,), name='z')([z_mean, z_log_var])
|
278
|
+
|
279
|
+
|
280
|
+
|
281
|
+
encoder = Model(inputs, [z_mean, z_log_var, z], name='encoder')
|
282
|
+
|
283
|
+
encoder.summary()
|
284
|
+
|
285
|
+
plot_model(encoder, to_file='vae_cnn_encoder.png', show_shapes=True)
|
286
|
+
|
287
|
+
|
288
|
+
|
289
|
+
latent_inputs = Input(shape=(latent_dim,), name='z_sampling')
|
290
|
+
|
31
|
-
|
291
|
+
x = Dense(shape[1] * shape[2] * shape[3], activation='relu')(latent_inputs)
|
292
|
+
|
293
|
+
x = Reshape((shape[1], shape[2], shape[3]))(x)
|
294
|
+
|
295
|
+
|
296
|
+
|
297
|
+
for i in range(4):
|
298
|
+
|
299
|
+
x = Conv2DTranspose(filters=filters,
|
300
|
+
|
301
|
+
kernel_size=kernel_size,
|
302
|
+
|
303
|
+
activation='relu',
|
304
|
+
|
305
|
+
strides=2,
|
306
|
+
|
307
|
+
padding='same')(x)
|
308
|
+
|
309
|
+
filters //= 2
|
310
|
+
|
311
|
+
|
312
|
+
|
313
|
+
outputs = Conv2DTranspose(filters=1,
|
314
|
+
|
315
|
+
kernel_size=kernel_size,
|
316
|
+
|
317
|
+
activation='sigmoid',
|
318
|
+
|
319
|
+
padding='same',
|
320
|
+
|
321
|
+
name='decoder_output')(x)
|
322
|
+
|
323
|
+
|
324
|
+
|
325
|
+
decoder = Model(latent_inputs, outputs, name='decoder')
|
326
|
+
|
327
|
+
decoder.summary()
|
328
|
+
|
329
|
+
plot_model(decoder, to_file='vae_cnn_decoder.png', show_shapes=True)
|
330
|
+
|
331
|
+
|
332
|
+
|
333
|
+
outputs = decoder(encoder(inputs)[2])
|
334
|
+
|
335
|
+
vae = Model(inputs, outputs, name='vae')
|
336
|
+
|
337
|
+
|
338
|
+
|
339
|
+
|
340
|
+
|
341
|
+
|
342
|
+
|
343
|
+
def plot_history(history):
|
344
|
+
|
345
|
+
|
346
|
+
|
347
|
+
plt.plot(history.history['loss'])
|
348
|
+
|
349
|
+
plt.plot(history.history['val_loss'])
|
350
|
+
|
351
|
+
plt.title('model accuracy')
|
352
|
+
|
353
|
+
plt.xlabel('epoch')
|
354
|
+
|
355
|
+
plt.ylabel('accuracy')
|
356
|
+
|
357
|
+
plt.legend(['acc', 'val_acc'], loc='lower right')
|
358
|
+
|
359
|
+
plt.show()
|
360
|
+
|
361
|
+
|
362
|
+
|
363
|
+
plt.plot(history.history['loss'])
|
364
|
+
|
365
|
+
plt.plot(history.history['val_loss'])
|
366
|
+
|
367
|
+
plt.title('model loss')
|
368
|
+
|
369
|
+
plt.xlabel('epoch')
|
370
|
+
|
371
|
+
plt.ylabel('loss')
|
372
|
+
|
373
|
+
plt.legend(['loss', 'val_loss'], loc='lower right')
|
374
|
+
|
375
|
+
plt.savefig('loss.png') # -----(2)
|
376
|
+
|
377
|
+
plt.show()
|
378
|
+
|
379
|
+
|
380
|
+
|
381
|
+
|
382
|
+
|
383
|
+
|
384
|
+
|
385
|
+
if __name__ == '__main__':
|
386
|
+
|
387
|
+
args = easydict.EasyDict({
|
388
|
+
|
389
|
+
"batchsize": 50,
|
390
|
+
|
391
|
+
"epoch": 50,
|
392
|
+
|
393
|
+
"gpu": 0,
|
394
|
+
|
395
|
+
"out": "result",
|
396
|
+
|
397
|
+
"resume": False,
|
398
|
+
|
399
|
+
"unit": 1000
|
400
|
+
|
401
|
+
})
|
402
|
+
|
403
|
+
models = (encoder, decoder)
|
404
|
+
|
405
|
+
|
406
|
+
|
407
|
+
|
408
|
+
|
409
|
+
os.environ['PYTHONHASHSEED'] = '0'
|
410
|
+
|
411
|
+
np.random.seed(5)
|
412
|
+
|
413
|
+
rn.seed(5)
|
414
|
+
|
415
|
+
|
416
|
+
|
417
|
+
config = tf.ConfigProto(
|
418
|
+
|
419
|
+
gpu_options=tf.GPUOptions(
|
420
|
+
|
421
|
+
visible_device_list="0,1",
|
422
|
+
|
423
|
+
allow_growth=True
|
424
|
+
|
425
|
+
)
|
426
|
+
|
427
|
+
)
|
428
|
+
|
429
|
+
|
430
|
+
|
431
|
+
tf.set_random_seed(5)
|
432
|
+
|
433
|
+
sess = tf.Session(graph=tf.get_default_graph(), config=config)
|
434
|
+
|
435
|
+
K.set_session(sess)
|
436
|
+
|
437
|
+
|
438
|
+
|
439
|
+
|
440
|
+
|
441
|
+
reconstruction_loss = binary_crossentropy(K.flatten(inputs),
|
442
|
+
|
443
|
+
K.flatten(outputs))
|
444
|
+
|
445
|
+
|
446
|
+
|
447
|
+
reconstruction_loss *= image_size * image_size
|
448
|
+
|
449
|
+
kl_loss = 1 + z_log_var - K.square(z_mean) - K.exp(z_log_var)
|
450
|
+
|
451
|
+
kl_loss = K.sum(kl_loss, axis=-1)
|
452
|
+
|
453
|
+
kl_loss *= -0.5
|
454
|
+
|
455
|
+
vae_loss = K.mean(reconstruction_loss + kl_loss)
|
456
|
+
|
457
|
+
vae.add_loss(vae_loss)
|
458
|
+
|
459
|
+
Adam = optimizers.Adam(lr=0.0005)
|
460
|
+
|
461
|
+
vae.compile(optimizer=Adam)
|
462
|
+
|
463
|
+
vae.summary()
|
464
|
+
|
465
|
+
plot_model(vae, to_file='vae_cnn.png', show_shapes=True)
|
466
|
+
|
467
|
+
|
468
|
+
|
469
|
+
history = vae.fit(x_train,
|
470
|
+
|
471
|
+
epochs=epochs,
|
472
|
+
|
473
|
+
batch_size=batch_size,
|
474
|
+
|
475
|
+
validation_data=(x_test, None))
|
476
|
+
|
477
|
+
|
478
|
+
|
479
|
+
open('vae_cnn.json', "w").write(vae.to_json())
|
480
|
+
|
481
|
+
vae.save('vae_cnn.h5')
|
482
|
+
|
483
|
+
|
484
|
+
|
485
|
+
plot_results(models, data, batch_size=batch_size, model_name="vae_cnn")
|
486
|
+
|
487
|
+
plot_history(history)
|
488
|
+
|
489
|
+
|
490
|
+
|
491
|
+
### 補足
|
492
|
+
|
493
|
+
kerasやtensorflowのバージョンが問題になっている場合もあるようですが、上手くいった人のバージョンに変更しても変化がありませんでした。
|
494
|
+
|
495
|
+
コードが長くてお手数をおかけしますが、アドバイス頂けると幸いです。どうぞよろしくお願い致します。
|