回答編集履歴
2
d
test
CHANGED
@@ -12,23 +12,69 @@
|
|
12
12
|
|
13
13
|
|
14
14
|
|
15
|
-
##
|
15
|
+
## モデル構造及び重みの保存方法
|
16
|
+
|
17
|
+
|
18
|
+
|
16
|
-
|
19
|
+
Keras にはモデル構造及び重みの保存方法は以下があります。
|
20
|
+
|
21
|
+
|
22
|
+
|
23
|
+
|
24
|
+
|
17
|
-
|
25
|
+
* 重み + モデル構造 (.h5)
|
26
|
+
|
18
|
-
|
27
|
+
model.save(): 保存
|
28
|
+
|
19
|
-
|
29
|
+
model = load_model(): 読み込み
|
30
|
+
|
31
|
+
|
32
|
+
|
20
|
-
|
33
|
+
* 重み (.h5)
|
34
|
+
|
21
|
-
|
35
|
+
model.save_weights(): 保存
|
36
|
+
|
22
|
-
|
37
|
+
model.load_weights(): 読み込み
|
38
|
+
|
39
|
+
|
40
|
+
|
41
|
+
* モデル構造 (.json)
|
42
|
+
|
43
|
+
json_string = model.to_json(): 保存
|
44
|
+
|
45
|
+
model = model_from_json(json_string): 読み込み
|
46
|
+
|
47
|
+
|
48
|
+
|
49
|
+
* モデル構造 (.yaml)
|
50
|
+
|
51
|
+
yaml_string = model.to_yaml()
|
52
|
+
|
53
|
+
model = model_from_yaml(yaml_string)
|
54
|
+
|
55
|
+
|
56
|
+
|
57
|
+
## モデルの保存 (ModelCheckpoint で最良のモデルのみ保存する。)
|
58
|
+
|
59
|
+
|
60
|
+
|
23
|
-
```
|
61
|
+
```
|
62
|
+
|
63
|
+
import os
|
64
|
+
|
65
|
+
import time
|
66
|
+
|
67
|
+
|
24
68
|
|
25
69
|
import numpy as np
|
26
70
|
|
71
|
+
from keras.callbacks import ModelCheckpoint
|
72
|
+
|
27
73
|
from keras.datasets import mnist
|
28
74
|
|
29
75
|
from keras.layers import Activation, BatchNormalization, Dense
|
30
76
|
|
31
|
-
from keras.models import Sequential
|
77
|
+
from keras.models import Sequential
|
32
78
|
|
33
79
|
from keras.utils.np_utils import to_categorical
|
34
80
|
|
@@ -40,6 +86,22 @@
|
|
40
86
|
|
41
87
|
|
42
88
|
|
89
|
+
# 1次元配列にする。 (28, 28) -> (784,) にする
|
90
|
+
|
91
|
+
x_train = x_train.reshape(len(x_train), -1)
|
92
|
+
|
93
|
+
x_test = x_test.reshape(len(x_test), -1)
|
94
|
+
|
95
|
+
|
96
|
+
|
97
|
+
# one-hot 表現に変換する。
|
98
|
+
|
99
|
+
y_train = to_categorical(y_train)
|
100
|
+
|
101
|
+
y_test = to_categorical(y_test)
|
102
|
+
|
103
|
+
|
104
|
+
|
43
105
|
# モデルを作成する。
|
44
106
|
|
45
107
|
model = Sequential()
|
@@ -70,36 +132,136 @@
|
|
70
132
|
|
71
133
|
|
72
134
|
|
73
|
-
#
|
74
|
-
|
75
|
-
|
76
|
-
|
77
|
-
|
78
|
-
|
79
|
-
|
80
|
-
|
81
|
-
|
82
|
-
|
83
|
-
|
84
|
-
|
85
|
-
|
86
|
-
|
87
|
-
|
88
|
-
|
89
|
-
|
90
|
-
|
91
|
-
|
92
|
-
|
93
|
-
validation_s
|
94
|
-
|
95
|
-
|
96
|
-
|
97
|
-
|
98
|
-
|
99
|
-
|
100
|
-
|
101
|
-
model
|
102
|
-
|
103
|
-
|
104
|
-
|
105
|
-
```
|
135
|
+
# 学習する。
|
136
|
+
|
137
|
+
save_path = os.path.join('models', time.strftime("%Y%m%d-%H%M%S")) # 保存するディレクトリ
|
138
|
+
|
139
|
+
os.makedirs(save_path, exist_ok=True)
|
140
|
+
|
141
|
+
model_checkpoint = ModelCheckpoint(
|
142
|
+
|
143
|
+
filepath=os.path.join(save_path, 'model_{epoch:02d}_{val_loss:.2f}.h5'),
|
144
|
+
|
145
|
+
monitor='val_loss',
|
146
|
+
|
147
|
+
save_best_only=True,
|
148
|
+
|
149
|
+
verbose=1)
|
150
|
+
|
151
|
+
|
152
|
+
|
153
|
+
model.fit(x_train, y_train, epochs=10, batch_size=128,
|
154
|
+
|
155
|
+
validation_data=(x_test, y_test),
|
156
|
+
|
157
|
+
callbacks=[model_checkpoint])
|
158
|
+
|
159
|
+
```
|
160
|
+
|
161
|
+
|
162
|
+
|
163
|
+
以下のように models ディレクトリ以下に現在の時刻でフォルダが作成され、その下に重みファイルが保存されます。
|
164
|
+
|
165
|
+
|
166
|
+
|
167
|
+
```
|
168
|
+
|
169
|
+
models
|
170
|
+
|
171
|
+
`-- 20181025-022510
|
172
|
+
|
173
|
+
|-- model_01_0.55.h5
|
174
|
+
|
175
|
+
|-- model_02_0.35.h5
|
176
|
+
|
177
|
+
|-- model_03_0.30.h5
|
178
|
+
|
179
|
+
|-- model_04_0.27.h5
|
180
|
+
|
181
|
+
|-- model_05_0.25.h5
|
182
|
+
|
183
|
+
|-- model_06_0.24.h5
|
184
|
+
|
185
|
+
|-- model_07_0.22.h5
|
186
|
+
|
187
|
+
|-- model_09_0.22.h5
|
188
|
+
|
189
|
+
`-- model_10_0.22.h5
|
190
|
+
|
191
|
+
```
|
192
|
+
|
193
|
+
|
194
|
+
|
195
|
+
### 読み込むとき
|
196
|
+
|
197
|
+
|
198
|
+
|
199
|
+
モデル構造は、save_json(), load_from_json() を使ってもいいですが、モデルがわかっているのであれば、それを構築して、model.load_weights() で重みだけ読み込めばよいです。
|
200
|
+
|
201
|
+
|
202
|
+
|
203
|
+
```
|
204
|
+
|
205
|
+
# 一番新しいモデルのパスを取得する。
|
206
|
+
|
207
|
+
import glob
|
208
|
+
|
209
|
+
model_paths = glob.glob(os.path.join(save_path, 'model_*.h5'))
|
210
|
+
|
211
|
+
last_model_path = sorted(model_paths, reverse=True)[0]
|
212
|
+
|
213
|
+
print('loading weights...', last_model_path)
|
214
|
+
|
215
|
+
|
216
|
+
|
217
|
+
# モデルを作成する。
|
218
|
+
|
219
|
+
model = Sequential()
|
220
|
+
|
221
|
+
model.add(Dense(10, input_dim=784))
|
222
|
+
|
223
|
+
model.add(BatchNormalization())
|
224
|
+
|
225
|
+
model.add(Activation('relu'))
|
226
|
+
|
227
|
+
model.add(Dense(10))
|
228
|
+
|
229
|
+
model.add(BatchNormalization())
|
230
|
+
|
231
|
+
model.add(Activation('relu'))
|
232
|
+
|
233
|
+
model.add(Dense(10))
|
234
|
+
|
235
|
+
model.add(BatchNormalization())
|
236
|
+
|
237
|
+
model.add(Activation('softmax'))
|
238
|
+
|
239
|
+
model.compile(optimizer='adam',
|
240
|
+
|
241
|
+
loss='categorical_crossentropy',
|
242
|
+
|
243
|
+
metrics=['accuracy'])
|
244
|
+
|
245
|
+
|
246
|
+
|
247
|
+
# 重みを読み込む。
|
248
|
+
|
249
|
+
model.load_weights(last_model_path)
|
250
|
+
|
251
|
+
loss, acc = model.evaluate(x_train, y_train)
|
252
|
+
|
253
|
+
print('Loss: {}, Accuracy: {}'.format(loss, acc))
|
254
|
+
|
255
|
+
```
|
256
|
+
|
257
|
+
|
258
|
+
|
259
|
+
```
|
260
|
+
|
261
|
+
loading weights... models/20181025-022510/model_10_0.22.h5
|
262
|
+
|
263
|
+
60000/60000 [==============================] - 1s 23us/step
|
264
|
+
|
265
|
+
Loss: 0.18432119323263566, Accuracy: 94.75%
|
266
|
+
|
267
|
+
```
|
1
d
test
CHANGED
@@ -9,3 +9,97 @@
|
|
9
9
|
os.path.abspath(PATH + "/model.json")
|
10
10
|
|
11
11
|
```
|
12
|
+
|
13
|
+
|
14
|
+
|
15
|
+
## 追記
|
16
|
+
|
17
|
+
|
18
|
+
|
19
|
+
モデルの保存、読み込みであれば、model.save() 及び load_model() を使うとよいと思います。
|
20
|
+
|
21
|
+
|
22
|
+
|
23
|
+
```
|
24
|
+
|
25
|
+
import numpy as np
|
26
|
+
|
27
|
+
from keras.datasets import mnist
|
28
|
+
|
29
|
+
from keras.layers import Activation, BatchNormalization, Dense
|
30
|
+
|
31
|
+
from keras.models import Sequential, load_model
|
32
|
+
|
33
|
+
from keras.utils.np_utils import to_categorical
|
34
|
+
|
35
|
+
|
36
|
+
|
37
|
+
# MNIST データを取得する。
|
38
|
+
|
39
|
+
(x_train, y_train), (x_test, y_test) = mnist.load_data()
|
40
|
+
|
41
|
+
|
42
|
+
|
43
|
+
# モデルを作成する。
|
44
|
+
|
45
|
+
model = Sequential()
|
46
|
+
|
47
|
+
model.add(Dense(10, input_dim=784))
|
48
|
+
|
49
|
+
model.add(BatchNormalization())
|
50
|
+
|
51
|
+
model.add(Activation('relu'))
|
52
|
+
|
53
|
+
model.add(Dense(10))
|
54
|
+
|
55
|
+
model.add(BatchNormalization())
|
56
|
+
|
57
|
+
model.add(Activation('relu'))
|
58
|
+
|
59
|
+
model.add(Dense(10))
|
60
|
+
|
61
|
+
model.add(BatchNormalization())
|
62
|
+
|
63
|
+
model.add(Activation('softmax'))
|
64
|
+
|
65
|
+
model.compile(optimizer='adam',
|
66
|
+
|
67
|
+
loss='categorical_crossentropy',
|
68
|
+
|
69
|
+
metrics=['accuracy'])
|
70
|
+
|
71
|
+
|
72
|
+
|
73
|
+
# モデルの入力に合わせて1次元配列にする。 (28, 28) -> (784,) にする
|
74
|
+
|
75
|
+
x_train = x_train.reshape(len(x_train), -1)
|
76
|
+
|
77
|
+
x_test = x_test.reshape(len(x_test), -1)
|
78
|
+
|
79
|
+
|
80
|
+
|
81
|
+
# one-hot 表現に変換する。
|
82
|
+
|
83
|
+
y_train_onehot = to_categorical(y_train)
|
84
|
+
|
85
|
+
|
86
|
+
|
87
|
+
model.fit(x_train, y_train_onehot,
|
88
|
+
|
89
|
+
batch_size=128,
|
90
|
+
|
91
|
+
epochs=10,
|
92
|
+
|
93
|
+
validation_split=0.1)
|
94
|
+
|
95
|
+
|
96
|
+
|
97
|
+
model.save('model.h5')
|
98
|
+
|
99
|
+
|
100
|
+
|
101
|
+
model = load_model('model.h5')
|
102
|
+
|
103
|
+
model.summary()
|
104
|
+
|
105
|
+
```
|