下記サイトを参考に、kerasでpix2pixが動くようにしました。
https://qiita.com/mine820/items/36ffc3c0aea0b98027fd
タイトルの通り、学習済みのgeneratorモデルを保存し、トレーニング画像以外のテスト画像で変換を行いたいと考えております。
def train() の末尾（最下行）に generator_model.save を追記しました。
保存したモデル（.h5）をロードしてテスト画像を変換しようとしたところ、うまくいきませんでした。学習初期のような（ほとんど学習が進んでいないように見える）画像が生成されます。
model.saveはどこに追記すれば学習済みのgeneratorモデルを保存することができるでしょうか。
python
def train(model_path='C:/フォルダーの場所/model.h5'):
    """Train the pix2pix generator/discriminator pair, then save the generator.

    Depends on module-level settings defined elsewhere in this file
    (``datasetpath``, ``patch_size``, ``epoch``, ``batch_size``) and on the
    project helpers ``load_data``, ``load_generator``,
    ``load_DCGAN_discriminator``, ``load_DCGAN``, ``l1_loss``,
    ``get_disc_batch`` and ``plot_generated_batch``.

    Args:
        model_path: Where the trained generator is written (HDF5). The parent
            directory is created if it does not exist. Defaults to the path
            the original code hard-coded.

    The generator is saved exactly once, after the last epoch. Because
    ``generator_model`` is passed into ``load_DCGAN`` (presumably sharing its
    layers with ``DCGAN_model``) and is what ``plot_generated_batch`` uses to
    show progress, saving it after the loop stores the trained weights.
    """
    import os  # local import: the training script's imports are not visible here

    # load data
    rawImage, procImage, rawImage_val, procImage_val = load_data(datasetpath)

    img_shape = rawImage.shape[-3:]
    patch_num = (img_shape[0] // patch_size) * (img_shape[1] // patch_size)
    disc_img_shape = (patch_size, patch_size, procImage.shape[-1])

    # train
    opt_dcgan = Adam(lr=1E-3, beta_1=0.9, beta_2=0.999, epsilon=1e-08)
    opt_discriminator = Adam(lr=1E-3, beta_1=0.9, beta_2=0.999, epsilon=1e-08)

    # load generator model
    generator_model = load_generator(img_shape, disc_img_shape)
    # load discriminator model
    discriminator_model = load_DCGAN_discriminator(img_shape, disc_img_shape, patch_num)

    generator_model.compile(loss='mae', optimizer=opt_discriminator)
    # Freeze the discriminator before compiling the combined model so that
    # DCGAN_model.train_on_batch only updates generator weights (Keras 2.x
    # captures `trainable` at compile() time).
    discriminator_model.trainable = False

    DCGAN_model = load_DCGAN(generator_model, discriminator_model, img_shape, patch_size)

    loss = [l1_loss, 'binary_crossentropy']
    loss_weights = [1E1, 1]
    DCGAN_model.compile(loss=loss, loss_weights=loss_weights, optimizer=opt_dcgan)

    # Re-enable and compile the standalone discriminator so it trains on its own.
    discriminator_model.trainable = True
    discriminator_model.compile(loss='binary_crossentropy', optimizer=opt_discriminator)

    # Plot roughly twice per epoch. max(1, ...) avoids a ZeroDivisionError
    # when the training set holds fewer than 2 * batch_size images.
    plot_every = max(1, procImage.shape[0] // batch_size // 2)

    # start training
    print('start training')
    for e in range(epoch):

        starttime = time.time()
        perm = np.random.permutation(rawImage.shape[0])
        X_procImage = procImage[perm]
        X_rawImage = rawImage[perm]
        X_procImageIter = [X_procImage[i:i+batch_size] for i in range(0, rawImage.shape[0], batch_size)]
        X_rawImageIter = [X_rawImage[i:i+batch_size] for i in range(0, rawImage.shape[0], batch_size)]
        b_it = 0
        progbar = generic_utils.Progbar(len(X_procImageIter)*batch_size)
        for (X_proc_batch, X_raw_batch) in zip(X_procImageIter, X_rawImageIter):
            b_it += 1
            X_disc, y_disc = get_disc_batch(X_proc_batch, X_raw_batch, generator_model, b_it, patch_size)
            raw_disc, _ = get_disc_batch(X_raw_batch, X_raw_batch, generator_model, 1, patch_size)
            # NOTE(review): `+` concatenates if get_disc_batch returns lists of
            # patches, but adds element-wise if it returns arrays -- confirm.
            x_disc = X_disc + raw_disc
            # update the discriminator
            disc_loss = discriminator_model.train_on_batch(x_disc, y_disc)

            # create a batch to feed the generator model
            idx = np.random.choice(procImage.shape[0], batch_size)
            X_gen_target, X_gen = procImage[idx], rawImage[idx]
            # Two-column labels with column 1 set for every sample
            # (presumably the "real" class the generator tries to achieve).
            y_gen = np.zeros((X_gen.shape[0], 2), dtype=np.uint8)
            y_gen[:, 1] = 1

            # Freeze the discriminator (kept from the original; the compiled
            # models already fixed trainability at compile() time).
            discriminator_model.trainable = False
            gen_loss = DCGAN_model.train_on_batch(X_gen, [X_gen_target, y_gen])
            # Unfreeze the discriminator
            discriminator_model.trainable = True

            progbar.add(batch_size, values=[
                ("D logloss", disc_loss),
                ("G tot", gen_loss[0]),
                ("G L1", gen_loss[1]),
                ("G logloss", gen_loss[2])
            ])

            # save images for visualization
            if b_it % plot_every == 0:
                plot_generated_batch(X_proc_batch, X_raw_batch, generator_model, batch_size, "training")
                idx = np.random.choice(procImage_val.shape[0], batch_size)
                X_gen_target, X_gen = procImage_val[idx], rawImage_val[idx]
                plot_generated_batch(X_gen_target, X_gen, generator_model, batch_size, "validation")

        print("")
        print('Epoch %s/%s, Time: %s' % (e + 1, epoch, time.time() - starttime))

    # Save the trained generator once, after all epochs (function level,
    # outside both loops). Create the target directory first so save() does
    # not fail on a missing folder.
    model_dir = os.path.dirname(model_path)
    if model_dir:
        os.makedirs(model_dir, exist_ok=True)
    generator_model.save(model_path)
以下、追記
「保存したモデル（.h5）をロードして画像変換する」コードは以下のとおりです。
python
import keras
from keras import optimizers
from matplotlib import pyplot as plt
from keras.preprocessing.image import ImageDataGenerator
from keras.models import Sequential, Model
from keras.layers import Dense, Dropout, Activation, Flatten
from keras.layers import Conv2D, MaxPooling2D, Add, Input, Multiply
from keras.callbacks import ModelCheckpoint, CSVLogger
from keras.layers import GlobalAveragePooling2D
from keras.models import load_model
import keras.backend as K
import cv2
import os, glob, random
import tkinter, tkinter.filedialog, tkinter.messagebox
from PIL import Image
import pickle
import sys
import numpy as np
import time
import datetime


def psnr(y_true, y_pred):
    """PSNR in dB, -10*log10(MSE), i.e. assuming a peak value of 1.

    Passed via custom_objects so load_model() can resolve a metric named
    'psnr' / 'val_psnr' if the saved model references one.
    """
    return -10*K.log(K.mean(K.flatten((y_true - y_pred))**2)) / np.log(10)


def normalize(img):
    """Map uint8 pixels [0, 255] to [-1, 1] (the assumed generator input range)."""
    return img.astype('float32') / 127.5 - 1.0


def denormalize(img):
    """Map generator output in [-1, 1] back to uint8 pixels in [0, 255]."""
    return np.clip((img + 1.0) * 127.5, 0, 255).astype(np.uint8)


def imread_unicode(path):
    """Read a color image as BGR uint8; works with non-ASCII paths on Windows.

    cv2.imread silently returns None for paths containing e.g. Japanese
    characters on Windows, so the bytes are read with NumPy and decoded.
    Returns None if the file is empty or cannot be decoded.
    """
    data = np.fromfile(path, dtype=np.uint8)
    if data.size == 0:
        return None
    return cv2.imdecode(data, cv2.IMREAD_COLOR)


def imwrite_unicode(path, img):
    """Write a BGR uint8 image; works with non-ASCII paths on Windows.

    The encoder is chosen from the file extension. Returns True on success.
    """
    ext = os.path.splitext(path)[1] or '.jpg'
    ok, buf = cv2.imencode(ext, img)
    if ok:
        buf.tofile(path)
    return ok


random.seed(0)

img_width = 256
img_height = 256

# NOTE(review): load_data() from the training script is not visible here.
# This assumes, as in the Keras pix2pix code the Qiita article is based on,
# that training images are RGB, scaled to [-1, 1] (X / 127.5 - 1), and that
# the generator ends in tanh. The previous /255 input scaling did not match
# that, which yields "untrained-looking" output. Confirm against load_data()
# and adjust INPUT_IS_RGB / normalize() / denormalize() if it differs.
INPUT_IS_RGB = True

# One hidden Tk root serves both file dialogs.
root = tkinter.Tk()
root.withdraw()
iDir = os.path.abspath(os.path.dirname(__file__))

print('学習済みモデルを指定')
fTyp = [("","*")]
file_path = tkinter.filedialog.askopenfilename(filetypes = fTyp,initialdir = iDir)
print(file_path)
if not file_path:
    root.destroy()
    sys.exit('モデルが選択されませんでした')

model = load_model(file_path, custom_objects={'psnr': psnr, 'val_psnr': psnr})
model.summary()


print('データセットのフォルダを指定')
file_path = tkinter.filedialog.askdirectory(initialdir = iDir)
root.destroy()
print(file_path+'/*.jpg')
if not file_path:
    sys.exit('フォルダが選択されませんでした')

# Keep output names and inputs in lockstep so skipped files cannot shift
# predictions onto the wrong file name.
names = []
x_val = []
for nm in sorted(glob.glob(file_path+'/*.jpg')):
    img = imread_unicode(nm)
    if img is None:
        print('読み込めませんでした: ' + nm)
        continue
    if img.shape[:2] != (img_height, img_width):
        # The generator is built for a fixed input size.
        img = cv2.resize(img, (img_width, img_height))
    if INPUT_IS_RGB:
        # OpenCV decodes to BGR; convert to the channel order used in training.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    names.append(os.path.basename(nm))
    x_val.append(normalize(img))

if not x_val:
    sys.exit('変換する画像がありません')

x_val = np.asarray(x_val, dtype='float32')

decoded_imgs = model.predict(x_val)

out_dir = 'C:/フォルダの場所'
os.makedirs(out_dir, exist_ok=True)

for name, pred in zip(names, decoded_imgs):
    out = denormalize(pred.reshape(img_height, img_width, 3))
    if INPUT_IS_RGB:
        out = cv2.cvtColor(out, cv2.COLOR_RGB2BGR)
    if not imwrite_unicode(os.path.join(out_dir, name), out):
        print('保存できませんでした: ' + name)

print('結果が保存')
あなたの回答
tips
プレビュー