teratail header banner
teratail header banner
質問するログイン新規登録

質問編集履歴

1

内容追加

2020/12/22 11:41

投稿

kane_study
kane_study

スコア4

title CHANGED
File without changes
body CHANGED
@@ -63,4 +63,71 @@
63
63
  model.add(Dense(512, activation='relu'))
64
64
  model.add(Dropout(0.5))
65
65
  model.add(Dense(10, activation='softmax'))
66
+ ```
67
+
68
+
69
+
70
+
71
+ ---
72
+
73
+
74
+ データ取得部分
75
+ ```python
76
+ orig_image, orig_label, class_list = load_imgs(root_dir='kill_me_images/kill_me_baby_datasets/kill_me_baby_datasets/')
77
+ #['agiri', 'botsu', 'others', 'sonya', 'yasuna', 'yasuna&agiri', 'yasuna&sonya']
78
+ ```
79
+ ---
80
+ 関数部分(cv2.COLOR_BGR2RGBでBGRをRGBに調整)
81
+ ```python
82
+ def load_imgs(root_dir):
83
+
84
+ print(class_list)
85
+ num_class = len(class_list)
86
+ img_paths = []
87
+ labels = []
88
+ images = []
89
+ for cl_name in class_list:
90
+ img_names = os.listdir(os.path.join(root_dir, cl_name))
91
+ for img_name in img_names:
92
+ img_paths.append(os.path.abspath(os.path.join(root_dir, cl_name, img_name)))
93
+ hot_cl_name = get_class_one_hot(cl_name, class_list)
94
+ labels.append(hot_cl_name)
95
+
96
+ for img_path in img_paths:
97
+ img = cv2.imread(img_path)
98
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
99
+ images.append(img)
100
+
101
+ images = np.array(images)
102
+
103
+ return np.array(images), np.array(labels), class_list
104
+ ```
105
+ ---
106
+ onehot
107
+ ```python
108
+ def get_class_one_hot(class_str, class_list):
109
+ label = class_list.index(class_str)
110
+ label_hot = tf.one_hot(label, len(class_list))
111
+
112
+ return label_hot
113
+ ```
114
+ 実行部分
115
+ ```python
116
+ batch_size = 32
117
+
118
+ model.compile(loss='categorical_crossentropy', optimizer=tf.keras.optimizers.Adam(), metrics=['accuracy'])
119
+ tb_cb = tf.keras.callbacks.TensorBoard(log_dir="log_dir")
120
+ ckps = [tb_cb]
121
+
122
+ # 学習用データを用意する
123
+ train_datagen = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1./255, horizontal_flip=True)
124
+ train_generator = train_datagen.flow_from_directory('kill_me_images/kill_me_baby_datasets/', classes=class_list, target_size=(128, 128), batch_size=batch_size, class_mode='categorical')
125
+
126
+ # 学習開始!
127
+ model.fit_generator(train_generator, steps_per_epoch=train_generator.samples//batch_size, epochs=100, callbacks=ckps)
128
+ model.save("models/killme_vgg16.h5")
129
+
130
+ sess = tf.keras.backend.get_session()
131
+ saver = tf.train.Saver()
132
+ saver.save(sess, "models/killme_vgg16.ckpt")
66
133
  ```