質問編集履歴

3

さらに追記

2019/02/17 11:29

投稿

Reach
Reach

スコア733

test CHANGED
File without changes
test CHANGED
@@ -20,7 +20,7 @@
20
20
 
21
21
 
22
22
 
23
- ```ここに言語を入力
23
+ ```python
24
24
 
25
25
  import os
26
26
 
@@ -537,3 +537,33 @@
537
537
  質問は
538
538
 
539
539
  一度保存した重みデータを 再度実行するときに 読み込むという意味です
540
+
541
+
542
+
543
+ 追記2:
544
+
545
+ 質問のコードは [こちら](https://github.com/tjwei/GANotebooks/blob/master/wgan2-keras.ipynb)を元に
546
+
547
+ WGAN-GPの実行いたしたく手直ししたものです
548
+
549
+
550
+
551
+ Colab 上での起動のため (時間制限があるので) 重みデータを (Google Driveに) 一旦 保存して
552
+
553
+ 次の実行の際 そのデータを 読み込むことを想定しております
554
+
555
+
556
+
557
+ 機械学習初心者で はじめから構築は難しく
558
+
559
+ 既存のプログラムの実行して 勉強しております
560
+
561
+
562
+
563
+ model 定義後か Optimizer 定義後か
564
+
565
+ どちらで 重みデータを 読み込むのが
566
+
567
+ 適切なのかが 勉強不足のため わかりませんので
568
+
569
+ 質問いたしました

2

追記

2019/02/17 11:29

投稿

Reach
Reach

スコア733

test CHANGED
File without changes
test CHANGED
@@ -530,4 +530,10 @@
530
530
 
531
531
 
532
532
 
533
- ```
533
+ `
534
+
535
+ 追記:
536
+
537
+ 質問は
538
+
539
+ 一度保存した重みデータを 再度実行するときに 読み込むという意味です

1

コードを修正

2019/02/17 10:30

投稿

Reach
Reach

スコア733

test CHANGED
File without changes
test CHANGED
@@ -348,6 +348,184 @@
348
348
 
349
349
  """
350
350
 
351
+
352
+
353
+
354
+
355
+
356
+
357
+ from PIL import Image
358
+
359
+ import numpy as np
360
+
361
+ import tarfile
362
+
363
+
364
+
365
+
366
+
367
+ from tensorflow.python.keras.preprocessing.image import ImageDataGenerator, load_img, img_to_array, array_to_img
368
+
369
+ import glob
370
+
371
+ import cv2
372
+
373
+ from keras.preprocessing import image
374
+
375
+ import os
376
+
377
+
378
+
379
+ train_X=[]
380
+
381
+ train_y=[]
382
+
383
+ files = glob.glob('/content/drive/train/*.*')
384
+
385
+
386
+
387
+ for img_file in files:
388
+
389
+ img = cv2.imread(img_file)
390
+
391
+ img = cv2.cvtColor(img,cv2.COLOR_BGR2RGB)
392
+
393
+ img = cv2.resize(img,(256,256),interpolation = cv2.INTER_AREA)
394
+
395
+ img = np.transpose(img,(2,0,1))
396
+
397
+
398
+
399
+ train_X.extend(img.reshape(-1,3,256,256)/255*2-1)
400
+
401
+
402
+
403
+
404
+
405
+
406
+
407
+ train_X = np.array(train_X)
408
+
409
+
410
+
411
+ train_X = np.concatenate([train_X[:,:,:,::-1], train_X])
412
+
413
+
414
+
415
+ print('画像データの読み込みが終了しました')
416
+
417
+
418
+
419
+
420
+
421
+ from IPython.display import display
422
+
423
+ def showX(X, rows=1):
424
+
425
+ assert X.shape[0]%rows == 0
426
+
427
+ int_X = ( (X+1)/2*255).clip(0,255).astype('uint8')
428
+
429
+
430
+
431
+ int_X = np.moveaxis(int_X.reshape(-1,3,256,256), 1, 3)
432
+
433
+ int_X = int_X.reshape(rows, -1, 256, 256,3).swapaxes(1,2).reshape(rows*256,-1, 3)
434
+
435
+
436
+
437
+ (Image.fromarray(int_X)).save('./WGAN-GP_result.png')
438
+
439
+
440
+
441
+
442
+
443
+
444
+
445
+ fixed_noise = np.random.normal(size=(batchSize, nz)).astype('float32')
446
+
447
+
448
+
449
+
450
+
451
+ import time
452
+
453
+ t0 = time.time()
454
+
455
+ niter = 1000000
456
+
457
+ gen_iterations = 0
458
+
459
+ errG = 0
460
+
461
+ targetD = np.float32([2]*batchSize+[-2]*batchSize)[:, None]
462
+
463
+ targetG = np.ones(batchSize, dtype=np.float32)[:, None]
464
+
465
+ for epoch in range(niter):
466
+
467
+ i = 0
468
+
469
+
470
+
471
+ np.random.shuffle(train_X)
472
+
473
+ batches = train_X.shape[0]//batchSize
474
+
475
+ while i < batches:
476
+
477
+ if gen_iterations < 25 or gen_iterations % 500 == 0:
478
+
479
+ _Diters = 100
480
+
481
+ else:
482
+
483
+ _Diters = Diters
484
+
485
+ j = 0
486
+
487
+ while j < _Diters and i < batches:
488
+
489
+ j+=1
490
+
491
+ real_data = train_X[i*batchSize:(i+1)*batchSize]
492
+
493
+ i+=1
494
+
495
+ noise = np.random.normal(size=(batchSize, nz))
496
+
497
+ ϵ = np.random.uniform(size=(batchSize, 1, 1 ,1))
498
+
499
+ errD_real, errD_fake = netD_train([real_data, noise, ϵ])
500
+
501
+ errD = errD_real - errD_fake
502
+
503
+
504
+
505
+ if gen_iterations%500==0:
506
+
507
+ print('[%d/%d][%d/%d][%d] Loss_D: %f Loss_G: %f Loss_D_real: %f Loss_D_fake %f'
508
+
509
+ % (epoch, niter, i, batches, gen_iterations,errD, errG, errD_real, errD_fake), time.time()-t0)
510
+
511
+ fake = netG.predict(fixed_noise)
512
+
513
+ showX(fake, 2)
514
+
515
+ if gen_iterations % 5000 == 0:
516
+
517
+ netG.save_weights('./netG.hdf5')
518
+
519
+ netD.save_weights('./netD.hdf5')
520
+
521
+
522
+
523
+ noise = np.random.normal(size=(batchSize, nz))
524
+
525
+ errG, = netG_train([noise])
526
+
527
+ gen_iterations+=1
528
+
351
529
  ```
352
530
 
353
531