質問編集履歴
1
おそらくエラーの原因と思われるデータローダのコードを補足しました。
title
CHANGED
File without changes
|
body
CHANGED
@@ -176,4 +176,98 @@
|
|
176
176
|
データローダーに問題があると思われるが原因が分からなかった。
|
177
177
|
|
178
178
|
### 補足
|
179
|
+
class DataTransform():
|
180
|
+
"""
|
181
|
+
画像とアノテーションの前処理クラス。訓練時と検証時で異なる動作をする。
|
182
|
+
画像のサイズをinput_size x input_sizeにする。
|
183
|
+
訓練時はデータオーギュメンテーションする。
|
184
|
+
|
185
|
+
|
186
|
+
Attributes
|
187
|
+
----------
|
188
|
+
input_size : int
|
189
|
+
リサイズ先の画像の大きさ。
|
190
|
+
color_mean : (R, G, B)
|
191
|
+
各色チャネルの平均値。
|
192
|
+
color_std : (R, G, B)
|
193
|
+
各色チャネルの標準偏差。
|
194
|
+
"""
|
195
|
+
|
196
|
+
def __init__(self, input_size, color_mean, color_std):
|
197
|
+
self.data_transform = {
|
198
|
+
'train': Compose([
|
199
|
+
Scale(scale=[0.5, 1.5]), # 画像の拡大
|
200
|
+
RandomRotation(angle=[-10, 10]), # 回転
|
201
|
+
RandomMirror(), # ランダムミラー
|
202
|
+
Resize(input_size), # リサイズ(input_size)
|
203
|
+
Normalize_Tensor(color_mean, color_std) # 色情報の標準化とテンソル化
|
179
|
-
|
204
|
+
]),
|
205
|
+
'val': Compose([
|
206
|
+
Resize(input_size), # リサイズ(input_size)
|
207
|
+
Normalize_Tensor(color_mean, color_std) # 色情報の標準化とテンソル化
|
208
|
+
])
|
209
|
+
}
|
210
|
+
|
211
|
+
def __call__(self, phase, img, anno_class_img):
|
212
|
+
"""
|
213
|
+
Parameters
|
214
|
+
----------
|
215
|
+
phase : 'train' or 'val'
|
216
|
+
前処理のモードを指定。
|
217
|
+
"""
|
218
|
+
return self.data_transform[phase](img, anno_class_img)
|
219
|
+
|
220
|
+
|
221
|
+
class VOCDataset(data.Dataset):
|
222
|
+
"""
|
223
|
+
VOC2012のDatasetを作成するクラス。PyTorchのDatasetクラスを継承。
|
224
|
+
|
225
|
+
Attributes
|
226
|
+
----------
|
227
|
+
img_list : リスト
|
228
|
+
画像のパスを格納したリスト
|
229
|
+
anno_list : リスト
|
230
|
+
アノテーションへのパスを格納したリスト
|
231
|
+
phase : 'train' or 'test'
|
232
|
+
学習か訓練かを設定する。
|
233
|
+
transform : object
|
234
|
+
前処理クラスのインスタンス
|
235
|
+
"""
|
236
|
+
|
237
|
+
def __init__(self, img_list, anno_list, phase, transform):
|
238
|
+
self.img_list = img_list
|
239
|
+
self.anno_list = anno_list
|
240
|
+
self.phase = phase
|
241
|
+
self.transform = transform
|
242
|
+
|
243
|
+
def __len__(self):
|
244
|
+
'''画像の枚数を返す'''
|
245
|
+
return len(self.img_list)
|
246
|
+
|
247
|
+
def __getitem__(self, index):
|
248
|
+
'''
|
249
|
+
前処理をした画像のTensor形式のデータとアノテーションを取得
|
250
|
+
'''
|
251
|
+
img, anno_class_img = self.pull_item(index)
|
252
|
+
print("Before transformation - img shape:", img.shape)
|
253
|
+
print("Before transformation - anno_class_img shape:", anno_class_img.shape)
|
254
|
+
return img, anno_class_img
|
255
|
+
|
256
|
+
|
257
|
+
def pull_item(self, index):
|
258
|
+
'''画像のTensor形式のデータ、アノテーションを取得する'''
|
259
|
+
|
260
|
+
# 1. 画像読み込み
|
261
|
+
image_file_path = self.img_list[index]
|
262
|
+
img = Image.open(image_file_path) # [高さ][幅][色RGB]
|
263
|
+
|
264
|
+
# 2. アノテーション画像読み込み
|
265
|
+
anno_file_path = self.anno_list[index]
|
266
|
+
anno_class_img = Image.open(anno_file_path) # [高さ][幅]
|
267
|
+
|
268
|
+
# 3. 前処理を実施
|
269
|
+
img, anno_class_img = self.transform(self.phase, img, anno_class_img)
|
270
|
+
print("After transformation - img shape:", img.shape)
|
271
|
+
print("After transformation - anno_class_img shape:", anno_class_img.shape)
|
272
|
+
|
273
|
+
return img, anno_class_img
|