kerasでディープラーニングの計算をする際に、一日放置していた時、モデルが過学習を起こしていたので、指定したepochごとに学習したモデルとlossなどの計算結果を保存するプログラムが欲しいです。
fitで学習を実行したときに引数のcallbacksでログを保存することは知っているのですが、それだと最終結果しか保存されないので困っています。
よろしくお願いします。
気になる質問をクリップする
クリップした質問は、後からいつでもMYページで確認できます。
またクリップした質問に回答があった際、通知やメールを受け取ることができます。
バッドをするには、ログインかつ
こちらの条件を満たす必要があります。
回答1件
0
ベストアンサー
ModelCheckpoint を fit() 関数に渡してあげるとできます。
サンプルコード
python
1import os 2 3import numpy as np 4from keras.callbacks import ModelCheckpoint 5from keras.datasets import mnist 6from keras.layers import Activation, BatchNormalization, Dense 7from keras.models import Sequential 8from keras.utils.np_utils import to_categorical 9 10# MNIST データを取得する。 11(x_train, y_train), (x_test, y_test) = mnist.load_data() 12print('x_train.shape', x_train.shape) # x_train.shape (60000, 28, 28) 13print('y_train.shape', y_train.shape) # y_train.shape (60000,) 14print('x_test.shape', x_test.shape) # x_test.shape (10000, 28, 28) 15print('y_test.shape', y_test.shape) # y_test.shape (10000,) 16 17# 1次元配列にする。 (28, 28) -> (784,) にする 18x_train = x_train.reshape(len(x_train), -1) 19x_test = x_test.reshape(len(x_test), -1) 20 21# one-hot 表現に変換する。 22y_train = to_categorical(y_train) 23y_test = to_categorical(y_test) 24 25# モデルを作成する。 26model = Sequential() 27model.add(Dense(10, input_dim=784)) 28model.add(BatchNormalization()) 29model.add(Activation('relu')) 30model.add(Dense(10)) 31model.add(BatchNormalization()) 32model.add(Activation('relu')) 33model.add(Dense(10)) 34model.add(BatchNormalization()) 35model.add(Activation('softmax')) 36model.compile(optimizer='adam', 37 loss='categorical_crossentropy', 38 metrics=['accuracy']) 39 40# 学習する。 41os.makedirs('models', exist_ok=True) 42model_checkpoint = ModelCheckpoint( 43 filepath=os.path.join('models', 'model_{epoch:02d}_{val_loss:.2f}.h5'), 44 monitor='val_loss', 45 verbose=1) 46 47history = model.fit(x_train, y_train, epochs=10, batch_size=128, 48 validation_data=(x_test, y_test), 49 callbacks=[model_checkpoint])
bash
1$ tree models 2models 3|-- model_01_0.52.h5 4|-- model_02_0.35.h5 5|-- model_03_0.29.h5 6|-- model_04_0.26.h5 7|-- model_05_0.25.h5 8|-- model_06_0.23.h5 9|-- model_07_0.23.h5 10|-- model_08_0.22.h5 11|-- model_09_0.22.h5 12`-- model_10_0.22.h5 13 140 directories, 10 files
投稿2018/10/04 02:18
編集2018/10/04 02:20総合スコア21956
あなたの回答
tips
太字
斜体
打ち消し線
見出し
引用テキストの挿入
コードの挿入
リンクの挿入
リストの挿入
番号リストの挿入
表の挿入
水平線の挿入
プレビュー
質問の解決につながる回答をしましょう。 サンプルコードなど、より具体的な説明があると質問者の理解の助けになります。 また、読む側のことを考えた、分かりやすい文章を心がけましょう。
バッドをするには、ログインかつ
こちらの条件を満たす必要があります。
2018/10/04 02:51
2018/10/04 02:54
2018/10/04 04:18