質問編集履歴
1
コードをかなり編集しました。どうかお願いします。
test
CHANGED
@@ -1 +1 @@
|
|
1
|
-
pytorchで自作データセットを作成したい
|
1
|
+
pytorchで自作データセットを作成し画像の分類をしたい(init,len,getitemのながれ)
|
test
CHANGED
@@ -6,63 +6,35 @@
|
|
6
6
|
|
7
7
|
|
8
8
|
|
9
|
+
|
10
|
+
|
9
11
|
### やりたいこと
|
10
12
|
|
11
13
|
・cifer10の画像分類問題を自分で集めた画像だけで実装したい
|
12
14
|
|
13
15
|
・フォルダにあるいくつかの画像を分類し、ラベルを付与することです。
|
14
16
|
|
15
|
-
|
17
|
+
ディレクトリ構成(train:val=7:3)
|
16
|
-
|
17
|
-
|
18
|
-
|
19
|
-
|
18
|
+
|
20
|
-
|
21
|
-
|
22
|
-
|
23
|
-
________________以下構想________________________
|
24
|
-
|
25
|
-
|
26
|
-
|
27
|
-
|
19
|
+
```ここに言語を入力
|
28
|
-
|
29
|
-
|
30
|
-
|
31
|
-
|
20
|
+
|
32
|
-
|
33
|
-
|
34
|
-
|
35
|
-
|
21
|
+
├─animal_dataset
|
36
|
-
|
37
|
-
|
38
|
-
|
22
|
+
|
39
|
-
|
23
|
+
├─train
|
40
|
-
|
41
|
-
|
24
|
+
|
42
|
-
|
43
|
-
|
44
|
-
|
45
|
-
・しかしどのようにlabelを用意して,どのように画像と紐づけするのかがわかりません(cifer10の画像分類チュートリアルのようにclassesの中に入れる形式をするとよいのかそれともcsvなのかjsonなのか)
|
46
|
-
|
47
|
-
・はたまたそのときの書き方がわかりません。csvの書き方自体はググれば出てきそうですが、datasetになりうるcsvファイルの書き方だったり、jsonの書き方だったりはなかなか出てこないです
|
48
|
-
|
49
|
-
|
50
|
-
|
51
|
-
|
52
|
-
|
53
|
-
|
25
|
+
│ ├─cat(70枚くらい)
|
54
|
-
|
55
|
-
|
26
|
+
|
56
|
-
|
57
|
-
・画像をスクレイピングで集めて正規化するところまでは確認した(composeなどを利用した)
|
58
|
-
|
59
|
-
|
60
|
-
|
61
|
-
・いくつかのサイトを参考にひな型を作成した
|
62
|
-
|
63
|
-
|
64
|
-
|
65
|
-
|
27
|
+
│ └─dog(70枚くらい)
|
28
|
+
|
29
|
+
└─val
|
30
|
+
|
31
|
+
├─cat(30枚くらい)
|
32
|
+
|
33
|
+
└─dog(30枚くらい)
|
34
|
+
|
35
|
+
```
|
36
|
+
|
37
|
+
|
66
38
|
|
67
39
|
|
68
40
|
|
@@ -74,17 +46,31 @@
|
|
74
46
|
|
75
47
|
・スクレイピングの処理は省いています
|
76
48
|
|
49
|
+
・前処理としてtrainとvalそれぞれにかける
|
50
|
+
|
51
|
+
・データセットを作成する。
|
52
|
+
|
53
|
+
・initの中身にからのdataとlabelを用意する
|
54
|
+
|
55
|
+
・画像を呼び出す処理をする??
|
56
|
+
|
57
|
+
・for分とif文でディレクトリ名?が一致していたら0/1で場合分けしてlabelのlistに格納
|
58
|
+
|
59
|
+
・これで呼び出すときlabelとして機能する?
|
60
|
+
|
61
|
+
・lenでデータ数を返す
|
62
|
+
|
77
|
-
・
|
63
|
+
・getitemでindex番目の画像をロードしたい
|
78
|
-
|
64
|
+
|
79
|
-
・
|
65
|
+
・ただ本当にロードできてる?
|
80
|
-
|
66
|
+
|
81
|
-
・
|
67
|
+
・ちゃんとlabelとして認識できてる?
|
68
|
+
|
69
|
+
|
70
|
+
|
82
|
-
|
71
|
+
全体的に
|
83
|
-
|
84
|
-
|
72
|
+
|
85
|
-
|
73
|
+
正しい書き方というのがあいまいです。もう少しきれいなコードできれいにlabel分類したいです。
|
86
|
-
|
87
|
-
・dataトlabelがどう紐づけられるのか全然イメージがつかない...
|
88
74
|
|
89
75
|
|
90
76
|
|
@@ -98,17 +84,17 @@
|
|
98
84
|
|
99
85
|
from torchvision import transforms
|
100
86
|
|
101
|
-
from torchvision import dataets, transforms
|
87
|
+
from torchvision import datasets, transforms
|
102
88
|
|
103
89
|
import numpy as np
|
104
90
|
|
105
91
|
|
106
92
|
|
107
|
-
|
93
|
+
import os
|
108
94
|
|
109
95
|
import glob
|
110
96
|
|
111
|
-
|
97
|
+
|
112
98
|
|
113
99
|
|
114
100
|
|
@@ -122,43 +108,103 @@
|
|
122
108
|
|
123
109
|
# 前処理
|
124
110
|
|
111
|
+
|
112
|
+
|
113
|
+
class MyTransform():
|
114
|
+
|
115
|
+
def __init__(self, resize, mean, std):
|
116
|
+
|
117
|
+
self.resize = resize
|
118
|
+
|
119
|
+
self.mean = mean
|
120
|
+
|
121
|
+
self.std = std
|
122
|
+
|
123
|
+
|
124
|
+
|
125
|
+
def __call__(self,img, key ='train'):
|
126
|
+
|
127
|
+
data_transform = {
|
128
|
+
|
125
|
-
tran
|
129
|
+
'train': transforms.Compose(
|
126
|
-
|
130
|
+
|
127
|
-
[transforms.Resize((256,256),
|
131
|
+
[transforms.Resize((256,256)),
|
128
|
-
|
132
|
+
|
129
|
-
transforms.ToTensor(),
|
133
|
+
transforms.ToTensor(),
|
130
|
-
|
134
|
+
|
131
|
-
transforms.
|
135
|
+
transforms.Normalize(self.mean, self.std) #標準化
|
136
|
+
|
132
|
-
|
137
|
+
]),
|
138
|
+
|
133
|
-
|
139
|
+
'val': transforms.Compose(
|
140
|
+
|
134
|
-
|
141
|
+
[transforms.Resize((256,256)),
|
142
|
+
|
135
|
-
|
143
|
+
transforms.ToTensor(),
|
136
|
-
|
144
|
+
|
137
|
-
|
145
|
+
transforms.Normalize(self.mean, self.std)
|
146
|
+
|
138
|
-
|
147
|
+
])
|
148
|
+
|
139
|
-
|
149
|
+
}
|
150
|
+
|
151
|
+
|
152
|
+
|
140
|
-
|
153
|
+
return data_transform[key](img)
|
154
|
+
|
155
|
+
|
156
|
+
|
157
|
+
|
158
|
+
|
159
|
+
|
160
|
+
|
161
|
+
|
162
|
+
|
141
|
-
|
163
|
+
# データセット作成
|
142
|
-
|
143
|
-
|
144
|
-
|
145
|
-
|
146
|
-
|
164
|
+
|
147
|
-
class
|
165
|
+
class MyDatasets(data.Dataset):
|
148
|
-
|
149
|
-
|
150
|
-
|
166
|
+
|
151
|
-
def __init__(self,
|
167
|
+
def __init__(self, path=None, key='train', transform=None):
|
152
168
|
|
153
169
|
self.transform = transform
|
154
170
|
|
171
|
+
self.key = key
|
172
|
+
|
155
|
-
|
173
|
+
self.path = path
|
174
|
+
|
156
|
-
|
175
|
+
self.data = []
|
176
|
+
|
177
|
+
self.lables = []
|
178
|
+
|
179
|
+
|
180
|
+
|
157
|
-
|
181
|
+
target_path = os.path.join(self.path + self.key + '/**/*.jpg')
|
182
|
+
|
183
|
+
|
184
|
+
|
158
|
-
|
185
|
+
for i in glob(target_path):
|
186
|
+
|
187
|
+
# データリスト作成
|
188
|
+
|
189
|
+
self.data.append(i)
|
190
|
+
|
191
|
+
|
192
|
+
|
193
|
+
#ラベルリスト作成
|
194
|
+
|
159
|
-
|
195
|
+
label = os.path.basename(os.path.dirname(i))
|
196
|
+
|
160
|
-
|
197
|
+
if label == "cat":
|
198
|
+
|
161
|
-
|
199
|
+
label = 0
|
200
|
+
|
201
|
+
elif label == "dog":
|
202
|
+
|
203
|
+
label = 1
|
204
|
+
|
205
|
+
self.lables.append(label)
|
206
|
+
|
207
|
+
|
162
208
|
|
163
209
|
|
164
210
|
|
@@ -166,30 +212,72 @@
|
|
166
212
|
|
167
213
|
def __len__(self):
|
168
214
|
|
169
|
-
|
215
|
+
return len(self.data)
|
170
|
-
|
171
|
-
pass
|
172
216
|
|
173
217
|
|
174
218
|
|
175
|
-
# dataとlabel
|
219
|
+
# dataとlabelを返すはず
|
176
220
|
|
177
221
|
def __getitem__(self, index):
|
178
222
|
|
223
|
+
# index番目の画像をロード
|
224
|
+
|
225
|
+
img_path = self.data[index]
|
226
|
+
|
227
|
+
img = Image.open(img)
|
228
|
+
|
229
|
+
|
230
|
+
|
231
|
+
img_transformed = self.transform(img, self.key)
|
232
|
+
|
179
|
-
|
233
|
+
label = self.labels[index]
|
180
234
|
|
181
235
|
|
182
236
|
|
183
237
|
|
184
238
|
|
185
|
-
return ima
|
239
|
+
return img_transformed, label
|
240
|
+
|
241
|
+
|
242
|
+
|
243
|
+
|
244
|
+
|
186
|
-
|
245
|
+
train_dataset = MyDatasets()
|
246
|
+
|
187
|
-
|
247
|
+
print(train_dataset.label)
|
188
|
-
|
189
|
-
|
248
|
+
|
249
|
+
|
250
|
+
|
251
|
+
|
252
|
+
|
253
|
+
|
254
|
+
|
255
|
+
|
190
256
|
|
191
257
|
```
|
192
258
|
|
259
|
+
### error
|
260
|
+
|
261
|
+
こちらのコードを実行した結果です。画像の読み込みに失敗しているきがしますがどのように書き換えればいいのかわかりません。
|
262
|
+
|
263
|
+
|
264
|
+
|
265
|
+
```ここに言語を入力
|
266
|
+
|
267
|
+
Traceback (most recent call last):
|
268
|
+
|
269
|
+
File "dataset.py", line 90, in <module>
|
270
|
+
|
271
|
+
train_dataset = MyDatasets()
|
272
|
+
|
273
|
+
File "dataset.py", line 59, in __init__
|
274
|
+
|
275
|
+
target_path = os.path.join(self.path + self.key + '/**/*.jpg')
|
276
|
+
|
277
|
+
TypeError: unsupported operand type(s) for +: 'NoneType' and 'str'
|
278
|
+
|
279
|
+
```
|
280
|
+
|
193
281
|
|
194
282
|
|
195
283
|
### 参考にしたサイト
|
@@ -200,12 +288,6 @@
|
|
200
288
|
|
201
289
|
|
202
290
|
|
203
|
-
[上記のチュートリアルの解説にあたる記事だが自作でのlabelの紐づけ方がないので不明](https://qiita.com/kuto/items/0ff3ccb4e089d213871d#%E3%83%8B%E3%83%A5%E3%83%BC%E3%83%A9%E3%83%AB%E3%83%8D%E3%83%83%E3%83%88%E3%83%AF%E3%83%BC%E3%82%AF%E3%81%AE%E8%A8%93%E7%B7%B4)
|
204
|
-
|
205
|
-
[似たようなことをしていたがlabelの紐づけが乗っていなかった](https://qiita.com/ryryry/items/b1da4855504dcd3f9d98#%E3%83%87%E3%83%BC%E3%82%BF%E3%83%AD%E3%83%BC%E3%83%80)
|
206
|
-
|
207
|
-
|
208
|
-
|
209
291
|
### 最後に
|
210
292
|
|
211
293
|
初心者故、ごちゃごちゃな質問をしてしまったのは否めないです。ただ理解があいまいなためうまく質問もできない状況です。
|
@@ -214,13 +296,13 @@
|
|
214
296
|
|
215
297
|
以下が聞きたいことのまとめかと思いますが補足情報などありましたらお願いします。
|
216
298
|
|
217
|
-
・
|
299
|
+
・画像の読み込み方
|
218
|
-
|
300
|
+
|
219
|
-
|
301
|
+
おそらく画像の読み込みに失敗していると思いますので別の方法があれば教えていただきたいです。
|
220
|
-
|
221
|
-
|
222
|
-
|
302
|
+
|
303
|
+
|
304
|
+
|
223
|
-
・
|
305
|
+
・init,getitemの書き方が正しいのか確認してほしいです。
|
224
306
|
|
225
307
|
前処理としてこんな内容でいいのかわかりませんが、initの中身、getitemでどんなコードを書いたら紐づけができるのか想像ができません。
|
226
308
|
|