質問編集履歴
1
内容追加
test
CHANGED
File without changes
|
test
CHANGED
@@ -129,3 +129,137 @@
|
|
129
129
|
model.add(Dense(10, activation='softmax'))
|
130
130
|
|
131
131
|
```
|
132
|
+
|
133
|
+
|
134
|
+
|
135
|
+
|
136
|
+
|
137
|
+
|
138
|
+
|
139
|
+
|
140
|
+
|
141
|
+
---
|
142
|
+
|
143
|
+
|
144
|
+
|
145
|
+
|
146
|
+
|
147
|
+
データ取得部分
|
148
|
+
|
149
|
+
```python
|
150
|
+
|
151
|
+
# Load every image and its one-hot label from the nested dataset directory.
# Class list returned for this dataset (as recorded in the original comment):
# ['agiri', 'botsu', 'others', 'sonya', 'yasuna', 'yasuna&agiri', 'yasuna&sonya']
(orig_image, orig_label, class_list) = load_imgs(
    'kill_me_images/kill_me_baby_datasets/kill_me_baby_datasets/'
)
|
154
|
+
|
155
|
+
```
|
156
|
+
|
157
|
+
---
|
158
|
+
|
159
|
+
関数部分(cv2.COLOR_BGR2RGBでBGRをRGBに変換)
|
160
|
+
|
161
|
+
```python
|
162
|
+
|
163
|
+
def load_imgs(root_dir, class_list=None):
    """Load every image under ``root_dir`` with a one-hot label per class folder.

    Each class is a sub-directory of ``root_dir``; every file inside it is
    read with OpenCV and converted from BGR to RGB.

    Args:
        root_dir: Directory whose sub-directories are the class folders.
        class_list: Optional explicit list of class (sub-directory) names.
            Defaults to the sorted names of the sub-directories of
            ``root_dir``. The list order defines the one-hot index.

    Returns:
        ``(images, labels, class_list)``:
        ``images`` is ``np.array`` of the decoded RGB images.
        ``labels`` is a float32 array of shape
        ``(number_of_images, len(class_list))`` holding one-hot rows.
        ``class_list`` is the class order that was used.

    Raises:
        ValueError: if a file inside a class folder cannot be decoded as an
            image.
    """
    if class_list is None:
        # The original referenced an undefined ``class_list`` here
        # (NameError). Sorting gives a deterministic order, which also matches
        # the alphanumeric order Keras' flow_from_directory uses by default.
        class_list = sorted(
            name for name in os.listdir(root_dir)
            if os.path.isdir(os.path.join(root_dir, name))
        )
    print(class_list)
    num_class = len(class_list)

    # One-hot rows are built with NumPy instead of tf.one_hot. In TF1 graph
    # mode tf.one_hot returns symbolic tensors, and np.array() of those is an
    # object array rather than numeric labels.
    one_hot = np.eye(num_class, dtype=np.float32)

    img_paths = []
    labels = []
    for class_index, cl_name in enumerate(class_list):
        class_dir = os.path.join(root_dir, cl_name)
        for img_name in sorted(os.listdir(class_dir)):
            img_paths.append(os.path.abspath(os.path.join(class_dir, img_name)))
            labels.append(one_hot[class_index])

    images = []
    for img_path in img_paths:
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread reports failure by returning None instead of raising.
            raise ValueError("could not read image: {}".format(img_path))
        # OpenCV decodes to BGR; convert to RGB.
        images.append(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))

    # The explicit reshape keeps the (n, num_class) shape even when no images
    # were found.
    label_array = np.array(labels, dtype=np.float32).reshape(len(labels), num_class)
    return np.array(images), label_array, class_list
|
206
|
+
|
207
|
+
```
|
208
|
+
|
209
|
+
---
|
210
|
+
|
211
|
+
one-hot変換部分
|
212
|
+
|
213
|
+
```python
|
214
|
+
|
215
|
+
def get_class_one_hot(class_str, class_list):
    """Return a one-hot tensor marking ``class_str``'s position in ``class_list``.

    The tensor has length ``len(class_list)``. A ValueError propagates from
    ``list.index`` if ``class_str`` is not in ``class_list``.
    """
    position = class_list.index(class_str)
    return tf.one_hot(position, depth=len(class_list))
|
224
|
+
|
225
|
+
```
|
226
|
+
|
227
|
+
実行部分
|
228
|
+
|
229
|
+
```python
|
230
|
+
|
231
|
+
import math
import os

batch_size = 32
# Use the same root that load_imgs reads from.
# NOTE(review): the original passed 'kill_me_images/kill_me_baby_datasets/',
# one level above the folders load_imgs lists. With an explicit classes= list,
# flow_from_directory looks for those class folders directly under the given
# directory, so it presumably found no images. Confirm the on-disk layout.
dataset_dir = 'kill_me_images/kill_me_baby_datasets/kill_me_baby_datasets/'

model.compile(loss='categorical_crossentropy', optimizer=tf.keras.optimizers.Adam(), metrics=['accuracy'])
# Write TensorBoard logs to ./log_dir during training.
tb_cb = tf.keras.callbacks.TensorBoard(log_dir="log_dir")
ckps = [tb_cb]

# Prepare the training data: scale pixels to [0, 1] and apply random horizontal flips.
train_datagen = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1./255, horizontal_flip=True)
# classes=class_list pins the label index order to the one returned by load_imgs.
train_generator = train_datagen.flow_from_directory(dataset_dir, classes=class_list, target_size=(128, 128), batch_size=batch_size, class_mode='categorical')
if train_generator.samples == 0:
    raise RuntimeError('no training images found under ' + dataset_dir)

# Start training!
# Use ceil so the final partial batch is used. It also keeps steps_per_epoch
# from being 0 when there are fewer images than batch_size.
steps_per_epoch = math.ceil(train_generator.samples / batch_size)
model.fit_generator(train_generator, steps_per_epoch=steps_per_epoch, epochs=100, callbacks=ckps)

# Saving fails if the target directory does not exist yet.
os.makedirs("models", exist_ok=True)
model.save("models/killme_vgg16.h5")

# Also write a TF1-style checkpoint of the Keras session's variables.
sess = tf.keras.backend.get_session()
saver = tf.train.Saver()
saver.save(sess, "models/killme_vgg16.ckpt")
|
264
|
+
|
265
|
+
```
|