質問するログイン新規登録

質問編集履歴

1

おそらくエラーの原因と思われるデータローダのコードを補足しました。

2024/01/17 05:25

投稿

hataaaaa
hataaaaa

スコア0

title CHANGED
File without changes
body CHANGED
@@ -176,4 +176,98 @@
176
176
  データローダーに問題があると思われるが原因が分からなかった。
177
177
 
178
178
  ### 補足
179
+ class DataTransform():
180
+ """
181
+ 画像とアノテーションの前処理クラス。訓練時と検証時で異なる動作をする。
182
+ 画像のサイズをinput_size x input_sizeにする。
183
+ 訓練時はデータオーギュメンテーションする。
184
+
185
+
186
+ Attributes
187
+ ----------
188
+ input_size : int
189
+ リサイズ先の画像の大きさ。
190
+ color_mean : (R, G, B)
191
+ 各色チャネルの平均値。
192
+ color_std : (R, G, B)
193
+ 各色チャネルの標準偏差。
194
+ """
195
+
196
+ def __init__(self, input_size, color_mean, color_std):
197
+ self.data_transform = {
198
+ 'train': Compose([
199
+ Scale(scale=[0.5, 1.5]), # 画像の拡大
200
+ RandomRotation(angle=[-10, 10]), # 回転
201
+ RandomMirror(), # ランダムミラー
202
+ Resize(input_size), # リサイズ(input_size)
203
+ Normalize_Tensor(color_mean, color_std) # 色情報の標準化とテンソル化
179
- 特になし
204
+ ]),
205
+ 'val': Compose([
206
+ Resize(input_size), # リサイズ(input_size)
207
+ Normalize_Tensor(color_mean, color_std) # 色情報の標準化とテンソル化
208
+ ])
209
+ }
210
+
211
+ def __call__(self, phase, img, anno_class_img):
212
+ """
213
+ Parameters
214
+ ----------
215
+ phase : 'train' or 'val'
216
+ 前処理のモードを指定。
217
+ """
218
+ return self.data_transform[phase](img, anno_class_img)
219
+
220
+
221
+ class VOCDataset(data.Dataset):
222
+ """
223
+ VOC2012のDatasetを作成するクラス。PyTorchのDatasetクラスを継承。
224
+
225
+ Attributes
226
+ ----------
227
+ img_list : リスト
228
+ 画像のパスを格納したリスト
229
+ anno_list : リスト
230
+ アノテーションへのパスを格納したリスト
231
+ phase : 'train' or 'test'
232
+ 学習か訓練かを設定する。
233
+ transform : object
234
+ 前処理クラスのインスタンス
235
+ """
236
+
237
+ def __init__(self, img_list, anno_list, phase, transform):
238
+ self.img_list = img_list
239
+ self.anno_list = anno_list
240
+ self.phase = phase
241
+ self.transform = transform
242
+
243
+ def __len__(self):
244
+ '''画像の枚数を返す'''
245
+ return len(self.img_list)
246
+
247
+ def __getitem__(self, index):
248
+ '''
249
+ 前処理をした画像のTensor形式のデータとアノテーションを取得
250
+ '''
251
+ img, anno_class_img = self.pull_item(index)
252
+ print("Before transformation - img shape:", img.shape)
253
+ print("Before transformation - anno_class_img shape:", anno_class_img.shape)
254
+ return img, anno_class_img
255
+
256
+
257
+ def pull_item(self, index):
258
+ '''画像のTensor形式のデータ、アノテーションを取得する'''
259
+
260
+ # 1. 画像読み込み
261
+ image_file_path = self.img_list[index]
262
+ img = Image.open(image_file_path) # [高さ][幅][色RGB]
263
+
264
+ # 2. アノテーション画像読み込み
265
+ anno_file_path = self.anno_list[index]
266
+ anno_class_img = Image.open(anno_file_path) # [高さ][幅]
267
+
268
+ # 3. 前処理を実施
269
+ img, anno_class_img = self.transform(self.phase, img, anno_class_img)
270
+ print("After transformation - img shape:", img.shape)
271
+ print("After transformation - anno_class_img shape:", anno_class_img.shape)
272
+
273
+ return img, anno_class_img