質問をすることでしか得られない、回答やアドバイスがある。

15分調べてもわからないことは、質問しよう!

新規登録して質問してみよう
ただいま回答率
85.48%
Python 3.x

Python 3はPythonプログラミング言語の最新バージョンであり、2008年12月3日にリリースされました。

Q&A

解決済

1回答

494閲覧

Python3 エポック数2回につき1回重みを保存したい

SuzuAya

総合スコア71

Python 3.x

Python 3はPythonプログラミング言語の最新バージョンであり、2008年12月3日にリリースされました。

0グッド

0クリップ

投稿2019/05/16 11:29

前提・実現したいこと

以下のようなコードを試したいのですが、現在のままだとエポックが終わるたびに重みが保存されることになるため、2回につき1回保存できるようにしたいと思っているのですが、その場合はどのように書き換えればいいのでしょうか。

該当のソースコード

from __future__ import absolute_import from __future__ import division from __future__ import print_function from keras.layers import Lambda, Input, Dense from keras.models import Model from keras.models import Sequential, model_from_json from keras.losses import mse, binary_crossentropy from keras.layers import Conv2D, Flatten, Lambda from keras.layers import Reshape, Conv2DTranspose from keras.utils import plot_model, np_utils from keras.utils import plot_model from keras.callbacks import Callback, EarlyStopping, TensorBoard, ModelCheckpoint, LearningRateScheduler, CSVLogger from keras import optimizers from keras import backend as K from keras.preprocessing.image import array_to_img, img_to_array,load_img from keras.preprocessing.image import ImageDataGenerator import numpy as np import matplotlib.pyplot as plt import argparse import os import re import glob import random as rn import tensorflow as tf import cv2 import easydict from PIL import Image from google.colab.patches import cv2_imshow import warnings warnings.filterwarnings('ignore') %matplotlib inline get_ipython().run_line_magic('matplotlib', 'inline') def sampling(args): """Reparameterization trick by sampling fr an isotropic unit Gaussian. # Arguments args (tensor): mean and log of variance of Q(z|X) # Returns z (tensor): sampled latent vector """ z_mean, z_log_var = args batch = K.shape(z_mean)[0] dim = K.int_shape(z_mean)[1] # by default, random_normal has mean=0 and std=1.0 epsilon = K.random_normal(shape=(batch, dim)) return z_mean + K.exp(0.5 * z_log_var) * epsilon def plot_results(models, data, batch_size=32, model_name="vae_OCT"): """Plots labels and MNIST digits as function of 2-dim latent vector # Arguments models (tuple): encoder and decoder models data (tuple): test data and label batch_size (int): prediction batch size model_name (string): which model is using this function """ encoder, decoder = models x_test = data os.makedirs(model_name, exist_ok=True) filename = os.path.join(model_name, "vae_mean.png") #original dataset #train filenames = glob.glob("./NORMAL_resize_100_1_0506/*.jpeg") X = [] for filename in filenames: img = img_to_array(load_img( filename, color_mode = "grayscale" , target_size=(512,496))) X.append(img) #Y.append(0) X = np.asarray(X) #test img_size = (512,496) dir_name = './NORMAL_resize_0506' file_type = 'jpeg' img_list = glob.glob('./' + dir_name + '/*.' + file_type) test_img_array_list = [] for img in img_list: test_img = load_img(img,grayscale=True, target_size=(img_size))#元々はgrayscale=False test_img_array = img_to_array(test_img) /255 test_img_array_list.append(test_img_array) test_img_array_list = np.array(test_img_array_list) np.save(dir_name+'.npy',test_img_array_list) del test_img_array_list # test_img_array_listをメモリから解放 x_test = np.load('./NORMAL_resize_0506.npy') print(x_test.shape) image_size = (512, 496)#X.shape[1] original_dim = 512 * 496 #3削除 x_train = np.reshape(X, [-1, original_dim])# x_train = np.reshape(X, [-1, original_dim, 1]) x_test = np.reshape(x_test, [-1, original_dim])# x_test = np.reshape(X, [-1, original_dim, 1]) x_train = x_train.astype('float32') / 255 print(x_train.shape) print(x_test.shape) # network parameters input_shape = (original_dim, ) intermediate_dim = 8#512 batch_size = 50#128 latent_dim = 2 epochs = 20#50 # VAE model = encoder + decoder # build encoder model inputs = Input(shape=input_shape, name='encoder_input') x = Dense(intermediate_dim, activation='relu')(inputs) z_mean = Dense(latent_dim, name='z_mean')(x) z_log_var = Dense(latent_dim, name='z_log_var')(x) # use reparameterization trick to push the sampling out as input # note that "output_shape" isn't necessary with the TensorFlow backend z = Lambda(sampling, output_shape=(latent_dim,), name='z')([z_mean, z_log_var]) # instantiate encoder model encoder = Model(inputs, [z_mean, z_log_var, z], name='encoder') encoder.summary() plot_model(encoder, to_file='vae_mlp_encoder.png', show_shapes=True) # build decoder model latent_inputs = Input(shape=(latent_dim,), name='z_sampling') x = Dense(intermediate_dim, activation='relu')(latent_inputs) outputs = Dense(original_dim, activation='sigmoid')(x) # instantiate decoder model decoder = Model(latent_inputs, outputs, name='decoder') decoder.summary() plot_model(decoder, to_file='vae_mlp_decoder.png', show_shapes=True) # instantiate VAE model outputs = decoder(encoder(inputs)[2]) vae = Model(inputs, outputs, name='vae_mlp') if __name__ == '__main__': parser = argparse.ArgumentParser() help_ = "Load h5 model trained weights" parser.add_argument("-w", "--weights", help=help_) help_ = "Use mse loss instead of binary cross entropy (default)" parser.add_argument("-m", "--mse", help=help_, action='store_true') args = parser.parse_args([]) models = (encoder, decoder) data = (x_test) # VAE loss = mse_loss or xent_loss + kl_loss if args.mse: reconstruction_loss = mse(inputs, outputs) else: reconstruction_loss = binary_crossentropy(inputs, outputs) reconstruction_loss *= original_dim kl_loss = 1 + z_log_var - K.square(z_mean) - K.exp(z_log_var) kl_loss = K.sum(kl_loss, axis=-1) kl_loss *= -0.5 vae_loss = K.mean(reconstruction_loss + kl_loss) vae.add_loss(vae_loss) vae.compile(optimizer='adam') vae.summary() plot_model(vae, to_file='vae_mlp.png', show_shapes=True) callbacks = [] callbacks.append(ModelCheckpoint(filepath="model.ep{epoch:02d}.h5")) callbacks.append(EarlyStopping(monitor='val_loss', patience=0, verbose=1)) #callbacks.append(LearningRateScheduler(lambda ep: float(1e-3 / 3 ** (ep * 4 // 5))))# MAX_EPOCH削除, 5追記 callbacks.append(CSVLogger("history.csv")) if args.weights: vae.load_weights(args.weights) else: # train the autoencoder history=vae.fit(x_train, epochs=epochs, batch_size=batch_size, validation_data=(x_test, None), callbacks = callbacks) #vae.save_weights('vae_mlp_mnist.h5') plot_results(models, data, batch_size=batch_size, model_name="vae_mlp") loss = history.history['loss'] val_loss = history.history['val_loss'] plt.plot(range(1,epochs), loss[1:], marker='.', label='loss') plt.plot(range(1,epochs), val_loss[1:], marker='.', label='val_loss') plt.legend(loc='best', fontsize=10) plt.grid() plt.xlabel('epoch') plt.ylabel('loss') plt.show()

試したこと

Kerasのドキュメントを読み、periodで指定するのかと思い試してみましたがうまくいきませんでした。

補足情報(FW/ツールのバージョンなど)

Google colabを使っています。

気になる質問をクリップする

クリップした質問は、後からいつでもMYページで確認できます。

またクリップした質問に回答があった際、通知やメールを受け取ることができます。

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

guest

回答1

0

ベストアンサー

動作未確認ですみませんが。

ModelCheckpoint()の引数にチェックポイント間の間隔(エポック数)を設定できるperiod (default-1) があるようですが、これを修正してみるとよいのではないでしょうか。

https://keras.io/ja/callbacks/#modelcheckpoint

投稿2019/05/16 12:28

magichan

総合スコア15898

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

SuzuAya

2019/05/16 12:33

ありがとうございます!!periodを指定すると、確かに重みは思った通りに保存できるようになるのですが、以下のようにplot部分にエラーが出てしまいます…。 ValueError Traceback (most recent call last) <ipython-input-14-d4636a228aac> in <module>() 224 val_loss = history.history['val_loss'] 225 --> 226 plt.plot(range(1,epochs), loss[1:], marker='.', label='loss') 227 plt.plot(range(1,epochs), val_loss[1:], marker='.', label='val_loss') 228 plt.legend(loc='best', fontsize=10) 5 frames /usr/local/lib/python3.6/dist-packages/matplotlib/axes/_base.py in _xy_from_xy(self, x, y) 229 if x.shape[0] != y.shape[0]: 230 raise ValueError("x and y must have same first dimension, but " --> 231 "have shapes {} and {}".format(x.shape, y.shape)) 232 if x.ndim > 2 or y.ndim > 2: 233 raise ValueError("x and y can be no greater than 2-D, but have " ValueError: x and y must have same first dimension, but have shapes (19,) and (6,)
SuzuAya

2019/05/16 12:53

上記のエラーは質問とは別問題(earlystopが実行された時に発生するエラーっぽいです)のようなので、質問については解決済みとさせていただきます。ご回答ありがとうございました。
guest

あなたの回答

tips

太字

斜体

打ち消し線

見出し

引用テキストの挿入

コードの挿入

リンクの挿入

リストの挿入

番号リストの挿入

表の挿入

水平線の挿入

プレビュー

15分調べてもわからないことは
teratailで質問しよう!

ただいまの回答率
85.48%

質問をまとめることで
思考を整理して素早く解決

テンプレート機能で
簡単に質問をまとめる

質問する

関連した質問