質問編集履歴
3
コードの修正
test
CHANGED
File without changes
|
test
CHANGED
@@ -258,7 +258,7 @@
|
|
258
258
|
|
259
259
|
filenames = glob.glob("./train/NORMAL_train_dataset/*.jpeg")
|
260
260
|
|
261
|
-
|
261
|
+
X = []
|
262
262
|
|
263
263
|
|
264
264
|
|
@@ -268,7 +268,7 @@
|
|
268
268
|
|
269
269
|
filename, color_mode = "grayscale", target_size=(512,496))
|
270
270
|
|
271
|
-
|
271
|
+
X.append(img)
|
272
272
|
|
273
273
|
|
274
274
|
|
2
全コード掲載
test
CHANGED
File without changes
|
test
CHANGED
@@ -98,6 +98,158 @@
|
|
98
98
|
|
99
99
|
from PIL import Image
|
100
100
|
|
101
|
+
#from google.colab.patches import cv2_imshow
|
102
|
+
|
103
|
+
|
104
|
+
|
105
|
+
import warnings
|
106
|
+
|
107
|
+
warnings.filterwarnings('ignore')
|
108
|
+
|
109
|
+
|
110
|
+
|
111
|
+
%matplotlib inline
|
112
|
+
|
113
|
+
|
114
|
+
|
115
|
+
def sampling(args):
|
116
|
+
|
117
|
+
z_mean, z_log_var = args
|
118
|
+
|
119
|
+
batch = K.shape(z_mean)[0]
|
120
|
+
|
121
|
+
dim = K.int_shape(z_mean)[1]
|
122
|
+
|
123
|
+
# by default, random_normal has mean=0 and std=1.0
|
124
|
+
|
125
|
+
epsilon = K.random_normal(shape=(batch, dim))
|
126
|
+
|
127
|
+
return z_mean + K.exp(0.5 * z_log_var) * epsilon
|
128
|
+
|
129
|
+
|
130
|
+
|
131
|
+
#def plot_results(models,
|
132
|
+
|
133
|
+
#data,
|
134
|
+
|
135
|
+
#batch_size=500,#128,
|
136
|
+
|
137
|
+
#model_name="vae_OCT"):
|
138
|
+
|
139
|
+
"""Plots labels and MNIST digits as function of 2-dim latent vector
|
140
|
+
|
141
|
+
|
142
|
+
|
143
|
+
# Arguments
|
144
|
+
|
145
|
+
models (tuple): encoder and decoder models
|
146
|
+
|
147
|
+
data (tuple): test data and label
|
148
|
+
|
149
|
+
batch_size (int): prediction batch size
|
150
|
+
|
151
|
+
encoder, decoder = models
|
152
|
+
|
153
|
+
x_test = data #, y_test削除
|
154
|
+
|
155
|
+
os.makedirs(model_name, exist_ok=True)
|
156
|
+
|
157
|
+
model_name (string): which model is using this function
|
158
|
+
|
159
|
+
"""
|
160
|
+
|
161
|
+
|
162
|
+
|
163
|
+
filename = os.path.join('.', "vae_mean.png")
|
164
|
+
|
165
|
+
# display a 2D plot of the digit classes in the latent space
|
166
|
+
|
167
|
+
z_mean, _, _ = encoder.predict(x_test,
|
168
|
+
|
169
|
+
batch_size=batch_size)
|
170
|
+
|
171
|
+
#plt.figure(figsize=(12, 10))
|
172
|
+
|
173
|
+
#plt.scatter(z_mean[:, 0], z_mean[:, 1])#c=y_test削除
|
174
|
+
|
175
|
+
#plt.colorbar()
|
176
|
+
|
177
|
+
#plt.xlabel("z[0]")
|
178
|
+
|
179
|
+
#plt.ylabel("z[1]")
|
180
|
+
|
181
|
+
#plt.savefig(filename)
|
182
|
+
|
183
|
+
#plt.show()
|
184
|
+
|
185
|
+
|
186
|
+
|
187
|
+
#filename = os.path.join(model_name, "digits_over_latent.png")
|
188
|
+
|
189
|
+
# display a 30x30 2D manifold of digits
|
190
|
+
|
191
|
+
#n = 30
|
192
|
+
|
193
|
+
#digit_size = 496 うまく動かないので一旦プロット部分はすべてコメントアウト
|
194
|
+
|
195
|
+
#digit_size_width = 512 #width追加
|
196
|
+
|
197
|
+
#digit_size_height = 496 #height追加
|
198
|
+
|
199
|
+
#figure = np.zeros((digit_size_width * n, digit_size_height * n))#width, height追加
|
200
|
+
|
201
|
+
# linearly spaced coordinates corresponding to the 2D plot
|
202
|
+
|
203
|
+
# of digit classes in the latent space
|
204
|
+
|
205
|
+
#grid_x = np.linspace(-4, 4, n)
|
206
|
+
|
207
|
+
#grid_y = np.linspace(-4, 4, n)[::-1]
|
208
|
+
|
209
|
+
|
210
|
+
|
211
|
+
#for i, yi in enumerate(grid_y):
|
212
|
+
|
213
|
+
#for j, xi in enumerate(grid_x):
|
214
|
+
|
215
|
+
#z_sample = np.array([[xi, yi]])
|
216
|
+
|
217
|
+
#x_decoded = decoder.predict(z_sample)
|
218
|
+
|
219
|
+
#digit = x_decoded[0].reshape(digit_size_width, digit_size_height)#width, height追加
|
220
|
+
|
221
|
+
#figure[i * digit_size_width: (i + 1) * digit_size_height,#width, height追加
|
222
|
+
|
223
|
+
#j * digit_size_width: (j + 1) * digit_size_height] = digit#width, height追加
|
224
|
+
|
225
|
+
|
226
|
+
|
227
|
+
#plt.figure(figsize=(10, 10))
|
228
|
+
|
229
|
+
#start_range = digit_size // 2
|
230
|
+
|
231
|
+
#end_range = n * digit_size + start_range + 1
|
232
|
+
|
233
|
+
#pixel_range = np.arange(start_range, end_range, digit_size)
|
234
|
+
|
235
|
+
#sample_range_x = np.round(grid_x, 1)
|
236
|
+
|
237
|
+
#sample_range_y = np.round(grid_y, 1)
|
238
|
+
|
239
|
+
#plt.xticks(pixel_range, sample_range_x)
|
240
|
+
|
241
|
+
#plt.yticks(pixel_range, sample_range_y)
|
242
|
+
|
243
|
+
#plt.xlabel("z[0]")
|
244
|
+
|
245
|
+
#plt.ylabel("z[1]")
|
246
|
+
|
247
|
+
#plt.imshow(figure, cmap='Greys_r')
|
248
|
+
|
249
|
+
#plt.savefig(filename)
|
250
|
+
|
251
|
+
#plt.show()
|
252
|
+
|
101
253
|
|
102
254
|
|
103
255
|
#original dataset
|
@@ -106,27 +258,311 @@
|
|
106
258
|
|
107
259
|
filenames = glob.glob("./train/NORMAL_train_dataset/*.jpeg")
|
108
260
|
|
261
|
+
x_train = []
|
262
|
+
|
109
263
|
|
110
264
|
|
111
|
-
X = []
|
112
|
-
|
113
|
-
|
114
|
-
|
115
265
|
for filename in filenames:
|
116
266
|
|
117
267
|
img = img_to_array(load_img(
|
118
268
|
|
119
|
-
filename, color_mode = "grayscale"
|
269
|
+
filename, color_mode = "grayscale", target_size=(512,496))
|
270
|
+
|
120
|
-
|
271
|
+
x_train.append(img)
|
272
|
+
|
273
|
+
|
274
|
+
|
275
|
+
x_train = np.asarray(X)
|
276
|
+
|
277
|
+
|
278
|
+
|
279
|
+
#test
|
280
|
+
|
281
|
+
filenames = glob.glob("./validation/*.jpeg")
|
282
|
+
|
283
|
+
x_test = []
|
284
|
+
|
285
|
+
|
286
|
+
|
287
|
+
for filename in filenames:
|
288
|
+
|
121
|
-
|
289
|
+
img = img_to_array(load_img(
|
290
|
+
|
122
|
-
|
291
|
+
filename, color_mode = "grayscale", target_size=(512,496))
|
292
|
+
|
123
|
-
|
293
|
+
x_test.append(img)
|
294
|
+
|
295
|
+
|
296
|
+
|
297
|
+
x_test = np.asarray(x_test)
|
298
|
+
|
299
|
+
|
300
|
+
|
301
|
+
image_size = x_train.shape[1]
|
302
|
+
|
303
|
+
original_dim = 512 * 496 *1 #3削除
|
304
|
+
|
305
|
+
x_train = np.reshape(x_train, [-1, original_dim,1])# x_train = np.reshape(x_train, [-1, original_dim])
|
306
|
+
|
307
|
+
x_test = np.reshape(x_test, [-1, original_dim,1])# x_test = np.reshape(x_test, [-1, original_dim])
|
308
|
+
|
309
|
+
x_train = x_train.astype('float32') / 255
|
310
|
+
|
311
|
+
x_test = x_test.astype('float32') / 255
|
312
|
+
|
313
|
+
|
314
|
+
|
315
|
+
#train_generator = train_datagen.flow(x_train)#追記。generator作成
|
316
|
+
|
317
|
+
#test_generator = test_datagen.flow(x_test)#追記。generator作成
|
318
|
+
|
319
|
+
|
320
|
+
|
321
|
+
print(x_train.shape)
|
322
|
+
|
323
|
+
print(x_test.shape)
|
324
|
+
|
325
|
+
|
326
|
+
|
327
|
+
|
328
|
+
|
329
|
+
# network parameters
|
330
|
+
|
331
|
+
input_shape = (512, 496, 1)# (original_dim,)
|
332
|
+
|
333
|
+
kernel_size = 3
|
334
|
+
|
335
|
+
filters = 16
|
336
|
+
|
337
|
+
#intermediate_dim = 512
|
338
|
+
|
339
|
+
batch_size = 500#128
|
340
|
+
|
341
|
+
latent_dim = 2# Dimensionality of the latent space: a plane 潜在空間の次元数:平面 https://fisproject.jp/2018/09/vae-with-python-keras/#vae-with-keras
|
342
|
+
|
343
|
+
epochs = 5#1#50
|
344
|
+
|
345
|
+
|
346
|
+
|
347
|
+
|
348
|
+
|
349
|
+
# build encoder model
|
350
|
+
|
351
|
+
inputs = Input(shape=input_shape, name='encoder_input')
|
352
|
+
|
353
|
+
x = inputs
|
354
|
+
|
355
|
+
for i in range(4):
|
356
|
+
|
357
|
+
filters *= 2
|
358
|
+
|
359
|
+
x = Conv2D(filters=filters,kernel_size=kernel_size,activation='relu',strides=2,padding='same')(x)
|
360
|
+
|
361
|
+
|
362
|
+
|
363
|
+
# shape info needed to build decoder model これは画像のshapeとって割る時とかに結構使う. ちなみにtensorflowでのみ動作します. https://www.mathgram.xyz/entry/keras/backend
|
364
|
+
|
365
|
+
shape = K.int_shape(x)
|
366
|
+
|
367
|
+
|
368
|
+
|
369
|
+
# use reparameterization trick to push the sampling out as input
|
370
|
+
|
371
|
+
# note that "output_shape" isn't necessary with the TensorFlow backend
|
372
|
+
|
373
|
+
z = Lambda(sampling, output_shape=(latent_dim,), name='z')([z_mean, z_log_var])
|
374
|
+
|
375
|
+
|
376
|
+
|
377
|
+
# instantiate encoder model
|
378
|
+
|
379
|
+
encoder = Model(inputs, [z_mean, z_log_var, z], name='encoder')
|
380
|
+
|
381
|
+
encoder.summary()
|
382
|
+
|
383
|
+
plot_model(encoder, to_file='vae_mlp_encoder.png', show_shapes=True)
|
384
|
+
|
385
|
+
|
386
|
+
|
387
|
+
# build decoder model
|
388
|
+
|
389
|
+
latent_inputs = Input(shape=(latent_dim,), name='z_sampling')
|
390
|
+
|
391
|
+
x = Dense(shape[1] * shape[2] * shape[3], activation='relu')(latent_inputs)
|
392
|
+
|
393
|
+
x = Reshape((shape[1], shape[2], shape[3]))(x)
|
394
|
+
|
395
|
+
|
396
|
+
|
397
|
+
for i in range(4):
|
398
|
+
|
399
|
+
x = Conv2DTranspose(filters=filters, kernel_size=kernel_size, activation='relu', strides=2, padding='same')(x)
|
400
|
+
|
401
|
+
filters //= 2
|
402
|
+
|
403
|
+
|
404
|
+
|
405
|
+
outputs = Conv2DTranspose(filters=1, kernel_size=kernel_size, activation='sigmoid', padding='same', name='decoder_output')(x)
|
406
|
+
|
407
|
+
|
408
|
+
|
409
|
+
# instantiate decoder model
|
410
|
+
|
411
|
+
decoder = Model(latent_inputs, outputs, name='decoder')
|
412
|
+
|
413
|
+
decoder.summary()
|
414
|
+
|
415
|
+
plot_model(decoder, to_file='vae_mlp_decoder.png', show_shapes=True)
|
416
|
+
|
417
|
+
|
418
|
+
|
419
|
+
# instantiate VAE model
|
420
|
+
|
421
|
+
outputs = decoder(encoder(inputs)[2])
|
422
|
+
|
423
|
+
vae = Model(inputs, outputs, name='vae_mlp')
|
424
|
+
|
425
|
+
|
426
|
+
|
427
|
+
|
428
|
+
|
429
|
+
|
430
|
+
|
431
|
+
if __name__ == '__main__':
|
432
|
+
|
433
|
+
args = easydict.EasyDict({
|
434
|
+
|
435
|
+
"batchsize": 500,#40,
|
436
|
+
|
437
|
+
"epoch": 1,#50,
|
438
|
+
|
439
|
+
#"gpu": 0,
|
440
|
+
|
441
|
+
"out": "result",
|
442
|
+
|
443
|
+
"resume": False,
|
444
|
+
|
445
|
+
#"unit": 1000
|
446
|
+
|
447
|
+
})
|
448
|
+
|
449
|
+
#parser = argparse.ArgumentParser() parserがうまくうごかないため削除
|
450
|
+
|
451
|
+
#help_ = "Load h5 model trained weights"
|
452
|
+
|
453
|
+
#parser.add_argument("-w", "--weights", help=help_)
|
454
|
+
|
455
|
+
#help_ = "Use mse loss instead of binary cross entropy (default)"
|
456
|
+
|
457
|
+
#parser.add_argument("-m",
|
458
|
+
|
459
|
+
#"--mse",
|
460
|
+
|
461
|
+
#help=help_, action='store_true')
|
462
|
+
|
463
|
+
#args = parser.parse_args()
|
464
|
+
|
465
|
+
models = (encoder, decoder)
|
466
|
+
|
467
|
+
data = (x_test)#, y_test削除
|
468
|
+
|
469
|
+
|
470
|
+
|
471
|
+
# VAE loss = mse_loss or xent_loss + kl_loss
|
472
|
+
|
473
|
+
#if args.mse:
|
474
|
+
|
475
|
+
#reconstruction_loss = mse(inputs, outputs)
|
476
|
+
|
477
|
+
#else:
|
478
|
+
|
479
|
+
reconstruction_loss = binary_crossentropy(inputs,
|
480
|
+
|
481
|
+
outputs)
|
482
|
+
|
483
|
+
|
484
|
+
|
485
|
+
reconstruction_loss *= original_dim
|
486
|
+
|
487
|
+
kl_loss = 1 + z_log_var - K.square(z_mean) - K.exp(z_log_var)
|
488
|
+
|
489
|
+
kl_loss = K.sum(kl_loss, axis=-1)
|
490
|
+
|
491
|
+
kl_loss *= -0.5
|
492
|
+
|
493
|
+
vae_loss = K.mean(reconstruction_loss + kl_loss)
|
494
|
+
|
495
|
+
vae.add_loss(vae_loss)
|
496
|
+
|
497
|
+
vae.compile(optimizer='adam')
|
498
|
+
|
499
|
+
vae.summary()
|
500
|
+
|
501
|
+
plot_model(vae,
|
502
|
+
|
503
|
+
to_file='vae_mlp.png',
|
504
|
+
|
505
|
+
show_shapes=True)
|
506
|
+
|
507
|
+
|
508
|
+
|
509
|
+
|
510
|
+
|
511
|
+
|
512
|
+
|
513
|
+
callbacks = []
|
514
|
+
|
515
|
+
callbacks.append(ModelCheckpoint(filepath="model.ep{epoch:02d}.h5"))# 各epochでのモデルの保存
|
516
|
+
|
517
|
+
callbacks.append(EarlyStopping(monitor='val_loss', patience=0, verbose=1))
|
518
|
+
|
519
|
+
callbacks.append(LearningRateScheduler(lambda ep: float(1e-3 / 3 ** (ep * 4 // MAX_EPOCH))))
|
520
|
+
|
521
|
+
callbacks.append(CSVLogger("history.csv"))
|
522
|
+
|
523
|
+
|
124
524
|
|
125
525
|
|
126
526
|
|
127
|
-
|
527
|
+
#if args.weights:
|
528
|
+
|
128
|
-
|
529
|
+
#vae.load_weights(args.weights)
|
530
|
+
|
531
|
+
#else:
|
532
|
+
|
129
|
-
|
533
|
+
# train the autoencoder
|
534
|
+
|
535
|
+
history = vae.fit(x_train,
|
536
|
+
|
537
|
+
epochs=epochs,
|
538
|
+
|
539
|
+
batch_size=batch_size,
|
540
|
+
|
541
|
+
validation_data=(x_test, None),
|
542
|
+
|
543
|
+
callbacks=callbacks)
|
544
|
+
|
545
|
+
|
546
|
+
|
547
|
+
score = model.evaluate(x_test, verbose=0)#y_test削除
|
548
|
+
|
549
|
+
print('Test loss:', score[0])
|
550
|
+
|
551
|
+
print('Test accuracy:', score[1])
|
552
|
+
|
553
|
+
|
554
|
+
|
555
|
+
plt.plot(history.history["acc"], label="acc", ls="-", marker="o")
|
556
|
+
|
557
|
+
plt.plot(history.history["val_acc"], label="val_acc", ls="-", marker="x")
|
558
|
+
|
559
|
+
plt.ylabel("accuracy")
|
560
|
+
|
561
|
+
plt.xlabel("epoch")
|
562
|
+
|
563
|
+
plt.legend(loc="best")
|
564
|
+
|
565
|
+
plt.show()
|
130
566
|
|
131
567
|
```
|
132
568
|
|
1
コードの修正
test
CHANGED
File without changes
|
test
CHANGED
@@ -108,7 +108,7 @@
|
|
108
108
|
|
109
109
|
|
110
110
|
|
111
|
-
|
111
|
+
X = []
|
112
112
|
|
113
113
|
|
114
114
|
|
@@ -120,13 +120,13 @@
|
|
120
120
|
|
121
121
|
, target_size=(512,496)))
|
122
122
|
|
123
|
-
|
123
|
+
X.append(img)
|
124
124
|
|
125
125
|
|
126
126
|
|
127
127
|
|
128
128
|
|
129
|
-
x_train = np.asarray(
|
129
|
+
x_train = np.asarray(X)
|
130
130
|
|
131
131
|
```
|
132
132
|
|