Question edit history

Revision 3: Further addendum
@@ -20,7 +20,7 @@
-```
+```python
 import os
@@ -537,3 +537,33 @@
 The question means loading the previously saved weight data when the program is run again.
+
+Addendum 2:
+
+The code in the question is based on [this notebook](https://github.com/tjwei/GANotebooks/blob/master/wgan2-keras.ipynb), which I modified in order to run WGAN-GP.
+
+Because it runs on Colab (which has a session time limit), the plan is to save the weight data to Google Drive once and then load that data on the next run.
+
+I am a machine-learning beginner and building everything from scratch is difficult, so I am studying by running existing programs.
+
+I asked this question because, due to my lack of knowledge, I do not know whether it is more appropriate to load the weight data after the model is defined or after the optimizer is defined.
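A minimal sketch of the intended load order (the stand-in models and the Drive path below are placeholders, not taken from the question's code): in Keras, `load_weights` restores only the layer weights, so it can be called right after the models are built, before or after the optimizer is defined.

```python
import os
from keras.models import Sequential
from keras.layers import Dense

# Stand-in models; in the question these are the WGAN-GP generator/critic (netG, netD).
netG = Sequential([Dense(10, input_dim=100)])
netD = Sequential([Dense(1, input_dim=10)])

# Assumed Google Drive folder holding the weights saved by the previous run.
weights_dir = '/content/drive/My Drive/wgan_gp'
g_path = os.path.join(weights_dir, 'netG.hdf5')
d_path = os.path.join(weights_dir, 'netD.hdf5')

# load_weights restores layer weights only (no optimizer state), so calling it
# here, right after the models are defined and before netD_train/netG_train are
# built, should be enough to resume from the previously saved weights.
if os.path.exists(g_path):
    netG.load_weights(g_path)
if os.path.exists(d_path):
    netD.load_weights(d_path)
```

If the optimizer state (for example the Adam moment estimates) should also survive between sessions, `save_weights` alone is not enough; saving the full compiled model or a separate checkpoint would likely be needed.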
Revision 2: Addendum

@@ -530,4 +530,10 @@
-`
+`
+
+Addendum:
+
+The question means loading the previously saved weight data when the program is run again.
Revision 1: Fixed the code

@@ -348,6 +348,184 @@
```python
from PIL import Image
import numpy as np
import tarfile
from tensorflow.python.keras.preprocessing.image import ImageDataGenerator, load_img, img_to_array, array_to_img
import glob
import cv2
from keras.preprocessing import image
import os

# Load the training images from Google Drive, resize them to 256x256,
# convert to channels-first layout and scale the pixel values to [-1, 1].
train_X = []
train_y = []
files = glob.glob('/content/drive/train/*.*')

for img_file in files:
    img = cv2.imread(img_file)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = cv2.resize(img, (256, 256), interpolation=cv2.INTER_AREA)
    img = np.transpose(img, (2, 0, 1))
    train_X.extend(img.reshape(-1, 3, 256, 256) / 255 * 2 - 1)

train_X = np.array(train_X)
# Double the data set by adding horizontally flipped copies.
train_X = np.concatenate([train_X[:, :, :, ::-1], train_X])
print('Finished loading the image data')

from IPython.display import display

def showX(X, rows=1):
    # Tile the generated images into a grid and save them as a PNG.
    assert X.shape[0] % rows == 0
    int_X = ((X + 1) / 2 * 255).clip(0, 255).astype('uint8')
    int_X = np.moveaxis(int_X.reshape(-1, 3, 256, 256), 1, 3)
    int_X = int_X.reshape(rows, -1, 256, 256, 3).swapaxes(1, 2).reshape(rows * 256, -1, 3)
    (Image.fromarray(int_X)).save('./WGAN-GP_result.png')

fixed_noise = np.random.normal(size=(batchSize, nz)).astype('float32')

import time
t0 = time.time()
niter = 1000000
gen_iterations = 0
errG = 0
targetD = np.float32([2] * batchSize + [-2] * batchSize)[:, None]
targetG = np.ones(batchSize, dtype=np.float32)[:, None]

# WGAN-GP training loop: train the critic several times per generator update.
for epoch in range(niter):
    i = 0
    np.random.shuffle(train_X)
    batches = train_X.shape[0] // batchSize
    while i < batches:
        if gen_iterations < 25 or gen_iterations % 500 == 0:
            _Diters = 100
        else:
            _Diters = Diters
        j = 0
        while j < _Diters and i < batches:
            j += 1
            real_data = train_X[i * batchSize:(i + 1) * batchSize]
            i += 1
            noise = np.random.normal(size=(batchSize, nz))
            ϵ = np.random.uniform(size=(batchSize, 1, 1, 1))
            errD_real, errD_fake = netD_train([real_data, noise, ϵ])
            errD = errD_real - errD_fake

        if gen_iterations % 500 == 0:
            print('[%d/%d][%d/%d][%d] Loss_D: %f Loss_G: %f Loss_D_real: %f Loss_D_fake %f'
                  % (epoch, niter, i, batches, gen_iterations, errD, errG, errD_real, errD_fake), time.time() - t0)
            fake = netG.predict(fixed_noise)
            showX(fake, 2)
        if gen_iterations % 5000 == 0:
            # Save the current weights so that training can be resumed later.
            netG.save_weights('./netG.hdf5')
            netD.save_weights('./netD.hdf5')

        noise = np.random.normal(size=(batchSize, nz))
        errG, = netG_train([noise])
        gen_iterations += 1
```
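One caveat about the `./netG.hdf5` paths above: the Colab VM's local disk is cleared when the session ends, so to survive the time limit the weights presumably have to be written to the mounted Drive instead. A minimal sketch of that idea, assuming the placeholder folder `/content/drive/My Drive/wgan_gp` (not a path from the question):

```python
import os
from google.colab import drive

# Mount Google Drive once per Colab session.
drive.mount('/content/drive')

# Placeholder folder on Drive for the weight checkpoints.
weights_dir = '/content/drive/My Drive/wgan_gp'
os.makedirs(weights_dir, exist_ok=True)

# Inside the training loop, save to Drive instead of the local ./ directory,
# so the files persist after the Colab runtime is recycled:
#     netG.save_weights(os.path.join(weights_dir, 'netG.hdf5'))
#     netD.save_weights(os.path.join(weights_dir, 'netD.hdf5'))
```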