teratail header banner
teratail header banner
質問するログイン新規登録

質問編集履歴

1

データセットの処理および定義を追加しました.

2020/07/09 14:22

投稿

nigo1973
nigo1973

スコア14

title CHANGED
File without changes
body CHANGED
@@ -128,4 +128,68 @@
128
128
  41 _, preds = torch.max(outputs, 1)
129
129
 
130
130
  AttributeError: 'tuple' object has no attribute 'size'
131
+ ```
132
+ データセットの情報は以下となります.
133
+
134
+ ```python
135
+ # 0と1のラベルした画像のDatasetを作成する
136
+
137
+
138
+ class Dataset(data.Dataset):
139
+
140
+
141
+ def __init__(self, file_list, transform=None, phase='train'):
142
+ self.file_list = file_list
143
+ self.transform = transform
144
+ self.phase = phase
145
+
146
+ def __len__(self):
147
+ '''画像の枚数を返す'''
148
+ return len(self.file_list)
149
+
150
+ def __getitem__(self, index):
151
+ '''
152
+ 前処理をした画像のTensor形式のデータとラベルを取得
153
+ '''
154
+ # index番目の画像をロード
155
+ img_path = self.file_list[index]
156
+ img = Image.open(img_path)
157
+
158
+ # 画像の前処理を実施
159
+ img_transformed = self.transform(
160
+ img, self.phase) # torch.Size([3, 224, 224])
161
+
162
+ # 画像のラベルをファイル名から抜き出す
163
+ if self.phase == "train":
164
+ label = img_path[14:16]
165
+ #print(label)
166
+ elif self.phase == "val":
167
+ label = img_path[14:16]
168
+ #print(label)
169
+
170
+ # ラベルを数値に変更する
171
+ if label == "00":
172
+ label = 0
173
+ elif label == "01":
174
+ label = 1
175
+
176
+ #print(type(label))
177
+ return img_transformed, label
178
+
179
+ # 実行
180
+ train_dataset = Dataset(
181
+ file_list=train_list, transform=ImageTransform(size, mean, std), phase='train')
182
+
183
+ val_dataset = Dataset(
184
+ file_list=val_list, transform=ImageTransform(size, mean, std), phase='val')
185
+
186
+ # 動作確認
187
+ index = 0
188
+ print(train_dataset.__getitem__(index)[0].size())
189
+ print(train_dataset.__getitem__(index)[1])
190
+ ```
191
+ 出力
192
+ ```output
193
+ torch.Size([3, 96, 96])
194
+ 0
131
195
  ```