質問編集履歴
1
内容追加
title
CHANGED
File without changes
|
body
CHANGED
@@ -63,4 +63,71 @@
|
|
63
63
|
model.add(Dense(512, activation='relu'))
|
64
64
|
model.add(Dropout(0.5))
|
65
65
|
model.add(Dense(10, activation='softmax'))
|
66
|
+
```
|
67
|
+
|
68
|
+
|
69
|
+
|
70
|
+
|
71
|
+
---
|
72
|
+
|
73
|
+
|
74
|
+
データ取得部分
|
75
|
+
```python
|
76
|
+
# load_imgs returns (images, one-hot labels, class names).
# Class names reported by the original author for this dataset:
# ['agiri', 'botsu', 'others', 'sonya', 'yasuna', 'yasuna&agiri', 'yasuna&sonya']
orig_image, orig_label, class_list = load_imgs(
    root_dir='kill_me_images/kill_me_baby_datasets/kill_me_baby_datasets/',
)
|
78
|
+
```
|
79
|
+
---
|
80
|
+
関数部分(cv2.COLOR_BGR2RGBでBGRをRGBに変換)
|
81
|
+
```python
|
82
|
+
def load_imgs(root_dir, target_size=None):
    """Load every image under ``root_dir`` together with its one-hot label.

    ``root_dir`` must contain one sub-directory per class; the sub-directory
    name is the class name and every file inside it is read as an image of
    that class.

    Args:
        root_dir: Dataset root directory that directly contains the class
            folders.
        target_size: Optional ``(height, width)`` to resize every image to.
            When omitted, images keep their original size, so they must all
            share one shape for ``images`` to be a regular array.

    Returns:
        A tuple ``(images, labels, class_list)``: ``images`` is a NumPy array
        of RGB images, ``labels`` is a NumPy array built from the values
        returned by ``get_class_one_hot``, and ``class_list`` is the sorted
        list of class-directory names.

    Raises:
        ValueError: If a file cannot be decoded as an image by OpenCV.
    """
    # Derive the classes from the directory layout (previously `class_list`
    # was used here without ever being defined). Sorting makes the
    # class -> index mapping deterministic, since os.listdir order is
    # arbitrary, and matches the alphabetical order Keras'
    # flow_from_directory uses for its class indices.
    class_list = sorted(
        name for name in os.listdir(root_dir)
        if os.path.isdir(os.path.join(root_dir, name))
    )
    print(class_list)

    images = []
    labels = []
    for cl_name in class_list:
        class_dir = os.path.join(root_dir, cl_name)
        # Same label for every image in this folder; compute it once.
        hot_cl_name = get_class_one_hot(cl_name, class_list)
        for img_name in sorted(os.listdir(class_dir)):
            img_path = os.path.abspath(os.path.join(class_dir, img_name))
            img = cv2.imread(img_path)
            # cv2.imread returns None (it does not raise) for files it
            # cannot decode; fail here with the offending path instead of
            # an obscure error inside cvtColor.
            if img is None:
                raise ValueError("could not read image: {}".format(img_path))
            # OpenCV decodes to BGR channel order; convert to RGB.
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            if target_size is not None:
                height, width = target_size
                # cv2.resize expects the size as (width, height).
                img = cv2.resize(img, (width, height))
            images.append(img)
            labels.append(hot_cl_name)

    return np.array(images), np.array(labels), class_list
|
104
|
+
```
|
105
|
+
---
|
106
|
+
one-hot エンコード部分
|
107
|
+
```python
|
108
|
+
def get_class_one_hot(class_str, class_list):
    """Return the one-hot encoding of ``class_str`` within ``class_list``.

    Args:
        class_str: Class name to encode; must be an element of ``class_list``.
        class_list: Ordered list of all class names; its order defines the
            position of each class in the encoding.

    Returns:
        A float32 NumPy vector of length ``len(class_list)`` holding a single
        1.0 at the index of ``class_str`` (same dtype and length as the
        previous ``tf.one_hot`` result).

    Raises:
        ValueError: If ``class_str`` is not in ``class_list``.
    """
    import numpy as np

    # A plain NumPy vector instead of tf.one_hot: under TF1 graph mode (this
    # script uses tf.keras.backend.get_session / tf.train.Saver) tf.one_hot
    # returns a symbolic Tensor and adds a new op to the graph on every call,
    # so np.array() over the collected labels would not contain numbers.
    label = class_list.index(class_str)
    label_hot = np.zeros(len(class_list), dtype=np.float32)
    label_hot[label] = 1.0
    return label_hot
|
113
|
+
```
|
114
|
+
実行部分
|
115
|
+
```python
|
116
|
+
# Compile `model`, train it on augmented images streamed from the dataset
# directory, then save it as an HDF5 file and as a TF1 checkpoint.
import math
import os

batch_size = 32

model.compile(loss='categorical_crossentropy', optimizer=tf.keras.optimizers.Adam(), metrics=['accuracy'])
tb_cb = tf.keras.callbacks.TensorBoard(log_dir="log_dir")
ckps = [tb_cb]

# Prepare the training data.
# load_imgs() reads the class folders from the nested
# '.../kill_me_baby_datasets/kill_me_baby_datasets/' directory, so the
# generator must use that same root: with `classes=class_list` Keras looks
# for each class folder directly under the given directory, and the parent
# directory presumably holds only the nested folder (=> 0 images found).
# NOTE(review): confirm against the on-disk layout.
train_datagen = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1./255, horizontal_flip=True)
train_generator = train_datagen.flow_from_directory(
    'kill_me_images/kill_me_baby_datasets/kill_me_baby_datasets/',
    classes=class_list,
    target_size=(128, 128),
    batch_size=batch_size,
    class_mode='categorical',
)

if train_generator.samples == 0:
    raise RuntimeError("no training images found; check the dataset directory")

# ceil rather than floor: the last partial batch is still used, and fewer
# than batch_size images no longer yields steps_per_epoch == 0.
steps_per_epoch = math.ceil(train_generator.samples / batch_size)

# Start training.
model.fit_generator(train_generator, steps_per_epoch=steps_per_epoch, epochs=100, callbacks=ckps)

# Neither model.save nor Saver.save creates a missing parent directory.
os.makedirs("models", exist_ok=True)
model.save("models/killme_vgg16.h5")

sess = tf.keras.backend.get_session()
saver = tf.train.Saver()
saver.save(sess, "models/killme_vgg16.ckpt")
|
66
133
|
```
|