質問編集履歴
1
データセットの処理および定義を追加しました.
test
CHANGED
File without changes
|
test
CHANGED
@@ -259,3 +259,131 @@
|
|
259
259
|
AttributeError: 'tuple' object has no attribute 'size'
|
260
260
|
|
261
261
|
```
|
262
|
+
|
263
|
+
データセットの情報は以下となります.
|
264
|
+
|
265
|
+
|
266
|
+
|
267
|
+
```python
|
268
|
+
|
269
|
+
# 0と1のラベルした画像のDatasetを作成する
|
270
|
+
|
271
|
+
|
272
|
+
|
273
|
+
|
274
|
+
|
275
|
+
class Dataset(data.Dataset):
|
276
|
+
|
277
|
+
|
278
|
+
|
279
|
+
|
280
|
+
|
281
|
+
def __init__(self, file_list, transform=None, phase='train'):
|
282
|
+
|
283
|
+
self.file_list = file_list
|
284
|
+
|
285
|
+
self.transform = transform
|
286
|
+
|
287
|
+
self.phase = phase
|
288
|
+
|
289
|
+
|
290
|
+
|
291
|
+
def __len__(self):
|
292
|
+
|
293
|
+
'''画像の枚数を返す'''
|
294
|
+
|
295
|
+
return len(self.file_list)
|
296
|
+
|
297
|
+
|
298
|
+
|
299
|
+
def __getitem__(self, index):
|
300
|
+
|
301
|
+
'''
|
302
|
+
|
303
|
+
前処理をした画像のTensor形式のデータとラベルを取得
|
304
|
+
|
305
|
+
'''
|
306
|
+
|
307
|
+
# index番目の画像をロード
|
308
|
+
|
309
|
+
img_path = self.file_list[index]
|
310
|
+
|
311
|
+
img = Image.open(img_path)
|
312
|
+
|
313
|
+
|
314
|
+
|
315
|
+
# 画像の前処理を実施
|
316
|
+
|
317
|
+
img_transformed = self.transform(
|
318
|
+
|
319
|
+
img, self.phase) # torch.Size([3, 224, 224])
|
320
|
+
|
321
|
+
|
322
|
+
|
323
|
+
# 画像のラベルをファイル名から抜き出す
|
324
|
+
|
325
|
+
if self.phase == "train":
|
326
|
+
|
327
|
+
label = img_path[14:16]
|
328
|
+
|
329
|
+
#print(label)
|
330
|
+
|
331
|
+
elif self.phase == "val":
|
332
|
+
|
333
|
+
label = img_path[14:16]
|
334
|
+
|
335
|
+
#print(label)
|
336
|
+
|
337
|
+
|
338
|
+
|
339
|
+
# ラベルを数値に変更する
|
340
|
+
|
341
|
+
if label == "00":
|
342
|
+
|
343
|
+
label = 0
|
344
|
+
|
345
|
+
elif label == "01":
|
346
|
+
|
347
|
+
label = 1
|
348
|
+
|
349
|
+
|
350
|
+
|
351
|
+
#print(type(label))
|
352
|
+
|
353
|
+
return img_transformed, label
|
354
|
+
|
355
|
+
|
356
|
+
|
357
|
+
# 実行
|
358
|
+
|
359
|
+
train_dataset = Dataset(
|
360
|
+
|
361
|
+
file_list=train_list, transform=ImageTransform(size, mean, std), phase='train')
|
362
|
+
|
363
|
+
|
364
|
+
|
365
|
+
val_dataset = Dataset(
|
366
|
+
|
367
|
+
file_list=val_list, transform=ImageTransform(size, mean, std), phase='val')
|
368
|
+
|
369
|
+
|
370
|
+
|
371
|
+
# 動作確認
|
372
|
+
|
373
|
+
index = 0
|
374
|
+
|
375
|
+
print(train_dataset.__getitem__(index)[0].size())
|
376
|
+
|
377
|
+
print(train_dataset.__getitem__(index)[1])
|
378
|
+
|
379
|
+
```
|
380
|
+
|
381
|
+
出力
|
382
|
+
|
383
|
+
```output
|
384
|
+
|
385
|
+
torch.Size([3, 96, 96])
|
386
|
+
|
387
|
+
0
|
388
|
+
|
389
|
+
```
|