回答編集履歴

2

d

2018/10/25 02:30

投稿

tiitoi
tiitoi

スコア21956

test CHANGED
@@ -12,23 +12,69 @@
12
12
 
13
13
 
14
14
 
15
- ## 追記
15
+ ## モデル構造及び重みの保存方法
16
+
17
+
18
+
16
-
19
+ Keras にはモデル構造及び重みの保存方法は以下があります。
20
+
21
+
22
+
23
+
24
+
17
-
25
+ * 重み + モデル構造 (.h5)
26
+
18
-
27
+ model.save(): 保存
28
+
19
- モデルの保存、読み込みであれば、model.save() 及び load_model() を使うとよいと思います。
29
+ model = load_model(): 読み込み
30
+
31
+
32
+
20
-
33
+ * 重み (.h5)
34
+
21
-
35
+ model.save_weights(): 保存
36
+
22
-
37
+ model.load_weights(): 読み込み
38
+
39
+
40
+
41
+ * モデル構造 (.json)
42
+
43
+ json_string = model.to_json(): 保存
44
+
45
+ model = model_from_json(json_string): 読み込み
46
+
47
+
48
+
49
+ * モデル構造 (.yaml)
50
+
51
+ yaml_string = model.to_yaml()
52
+
53
+ model = model_from_yaml(yaml_string)
54
+
55
+
56
+
57
+ ## モデルの保存 (ModelCheckpoint で最良のモデルのみ保存する。)
58
+
59
+
60
+
23
- ```
61
+ ```
62
+
63
+ import os
64
+
65
+ import time
66
+
67
+
24
68
 
25
69
  import numpy as np
26
70
 
71
+ from keras.callbacks import ModelCheckpoint
72
+
27
73
  from keras.datasets import mnist
28
74
 
29
75
  from keras.layers import Activation, BatchNormalization, Dense
30
76
 
31
- from keras.models import Sequential, load_model
77
+ from keras.models import Sequential
32
78
 
33
79
  from keras.utils.np_utils import to_categorical
34
80
 
@@ -40,6 +86,22 @@
40
86
 
41
87
 
42
88
 
89
+ # 1次元配列にする。 (28, 28) -> (784,) にする
90
+
91
+ x_train = x_train.reshape(len(x_train), -1)
92
+
93
+ x_test = x_test.reshape(len(x_test), -1)
94
+
95
+
96
+
97
+ # one-hot 表現に変換する。
98
+
99
+ y_train = to_categorical(y_train)
100
+
101
+ y_test = to_categorical(y_test)
102
+
103
+
104
+
43
105
  # モデルを作成する。
44
106
 
45
107
  model = Sequential()
@@ -70,36 +132,136 @@
70
132
 
71
133
 
72
134
 
73
- # モデルの入力に合わせて1次元配列にする。 (28, 28) -> (784,) にする
74
-
75
- x_train = x_train.reshape(len(x_train), -1)
76
-
77
- x_test = x_test.reshape(len(x_test), -1)
78
-
79
-
80
-
81
- # one-hot 表現に変換する。
82
-
83
- y_train_onehot = to_categorical(y_train)
84
-
85
-
86
-
87
- model.fit(x_train, y_train_onehot,
88
-
89
- batch_size=128,
90
-
91
- epochs=10,
92
-
93
- validation_split=0.1)
94
-
95
-
96
-
97
- model.save('model.h5')
98
-
99
-
100
-
101
- model = load_model('model.h5')
102
-
103
- model.summary()
104
-
105
- ```
135
+ # 学習する。
136
+
137
+ save_path = os.path.join('models', time.strftime("%Y%m%d-%H%M%S")) # 保存するディレクトリ
138
+
139
+ os.makedirs(save_path, exist_ok=True)
140
+
141
+ model_checkpoint = ModelCheckpoint(
142
+
143
+ filepath=os.path.join(save_path, 'model_{epoch:02d}_{val_loss:.2f}.h5'),
144
+
145
+ monitor='val_loss',
146
+
147
+ save_best_only=True,
148
+
149
+ verbose=1)
150
+
151
+
152
+
153
+ model.fit(x_train, y_train, epochs=10, batch_size=128,
154
+
155
+ validation_data=(x_test, y_test),
156
+
157
+ callbacks=[model_checkpoint])
158
+
159
+ ```
160
+
161
+
162
+
163
+ 以下のように models ディレクトリ以下に現在の時刻でフォルダが作成され、その下に重みファイルが保存されます。
164
+
165
+
166
+
167
+ ```
168
+
169
+ models
170
+
171
+ `-- 20181025-022510
172
+
173
+ |-- model_01_0.55.h5
174
+
175
+ |-- model_02_0.35.h5
176
+
177
+ |-- model_03_0.30.h5
178
+
179
+ |-- model_04_0.27.h5
180
+
181
+ |-- model_05_0.25.h5
182
+
183
+ |-- model_06_0.24.h5
184
+
185
+ |-- model_07_0.22.h5
186
+
187
+ |-- model_09_0.22.h5
188
+
189
+ `-- model_10_0.22.h5
190
+
191
+ ```
192
+
193
+
194
+
195
+ ### 読み込むとき
196
+
197
+
198
+
199
+ モデル構造は、save_json(), load_from_json() を使ってもいいですが、モデルがわかっているのであれば、それを構築して、model.load_weights() で重みだけ読み込めばよいです。
200
+
201
+
202
+
203
+ ```
204
+
205
+ # 一番新しいモデルのパスを取得する。
206
+
207
+ import glob
208
+
209
+ model_paths = glob.glob(os.path.join(save_path, 'model_*.h5'))
210
+
211
+ last_model_path = sorted(model_paths, reverse=True)[0]
212
+
213
+ print('loading weights...', last_model_path)
214
+
215
+
216
+
217
+ # モデルを作成する。
218
+
219
+ model = Sequential()
220
+
221
+ model.add(Dense(10, input_dim=784))
222
+
223
+ model.add(BatchNormalization())
224
+
225
+ model.add(Activation('relu'))
226
+
227
+ model.add(Dense(10))
228
+
229
+ model.add(BatchNormalization())
230
+
231
+ model.add(Activation('relu'))
232
+
233
+ model.add(Dense(10))
234
+
235
+ model.add(BatchNormalization())
236
+
237
+ model.add(Activation('softmax'))
238
+
239
+ model.compile(optimizer='adam',
240
+
241
+ loss='categorical_crossentropy',
242
+
243
+ metrics=['accuracy'])
244
+
245
+
246
+
247
+ # 重みを読み込む。
248
+
249
+ model.load_weights(last_model_path)
250
+
251
+ loss, acc = model.evaluate(x_train, y_train)
252
+
253
+ print('Loss: {}, Accuracy: {}'.format(loss, acc))
254
+
255
+ ```
256
+
257
+
258
+
259
+ ```
260
+
261
+ loading weights... models/20181025-022510/model_10_0.22.h5
262
+
263
+ 60000/60000 [==============================] - 1s 23us/step
264
+
265
+ Loss: 0.18432119323263566, Accuracy: 94.75%
266
+
267
+ ```

1

d

2018/10/25 02:30

投稿

tiitoi
tiitoi

スコア21956

test CHANGED
@@ -9,3 +9,97 @@
9
9
  os.path.abspath(PATH + "/model.json")
10
10
 
11
11
  ```
12
+
13
+
14
+
15
+ ## 追記
16
+
17
+
18
+
19
+ モデルの保存、読み込みであれば、model.save() 及び load_model() を使うとよいと思います。
20
+
21
+
22
+
23
+ ```
24
+
25
+ import numpy as np
26
+
27
+ from keras.datasets import mnist
28
+
29
+ from keras.layers import Activation, BatchNormalization, Dense
30
+
31
+ from keras.models import Sequential, load_model
32
+
33
+ from keras.utils.np_utils import to_categorical
34
+
35
+
36
+
37
+ # MNIST データを取得する。
38
+
39
+ (x_train, y_train), (x_test, y_test) = mnist.load_data()
40
+
41
+
42
+
43
+ # モデルを作成する。
44
+
45
+ model = Sequential()
46
+
47
+ model.add(Dense(10, input_dim=784))
48
+
49
+ model.add(BatchNormalization())
50
+
51
+ model.add(Activation('relu'))
52
+
53
+ model.add(Dense(10))
54
+
55
+ model.add(BatchNormalization())
56
+
57
+ model.add(Activation('relu'))
58
+
59
+ model.add(Dense(10))
60
+
61
+ model.add(BatchNormalization())
62
+
63
+ model.add(Activation('softmax'))
64
+
65
+ model.compile(optimizer='adam',
66
+
67
+ loss='categorical_crossentropy',
68
+
69
+ metrics=['accuracy'])
70
+
71
+
72
+
73
+ # モデルの入力に合わせて1次元配列にする。 (28, 28) -> (784,) にする
74
+
75
+ x_train = x_train.reshape(len(x_train), -1)
76
+
77
+ x_test = x_test.reshape(len(x_test), -1)
78
+
79
+
80
+
81
+ # one-hot 表現に変換する。
82
+
83
+ y_train_onehot = to_categorical(y_train)
84
+
85
+
86
+
87
+ model.fit(x_train, y_train_onehot,
88
+
89
+ batch_size=128,
90
+
91
+ epochs=10,
92
+
93
+ validation_split=0.1)
94
+
95
+
96
+
97
+ model.save('model.h5')
98
+
99
+
100
+
101
+ model = load_model('model.h5')
102
+
103
+ model.summary()
104
+
105
+ ```