質問編集履歴

1

コードをかなり編集しました。どうかお願いします。

2020/08/07 07:05

投稿

oinari03
oinari03

スコア59

test CHANGED
@@ -1 +1 @@
1
- pytorchで自作データセットを作成したいgetitemや前処理、lable付与の仕方わからない
1
+ pytorchで自作データセットを作成し画像の分類をしたい(init,len,getitemのれ)
test CHANGED
@@ -6,63 +6,35 @@
6
6
 
7
7
 
8
8
 
9
+
10
+
9
11
  ### やりたいこと
10
12
 
11
13
  ・cifer10の画像分類問題を自分で集めた画像だけで実装したい
12
14
 
13
15
  ・フォルダにあるいくつかの画像を分類し、ラベルを付与することです。
14
16
 
15
- 思っている想があってほかの方法があったら教えてください!
17
+ ディレクトリtrain:val=7:3
16
-
17
-
18
-
19
- 今回聞きたいのはひとまずですが、どのようにlabelになる部分を作成すればいいのかということです。まったくわかりません。
18
+
20
-
21
-
22
-
23
- ________________以下構想________________________
24
-
25
-
26
-
27
- ・スクレイピングで画像収集(とりあえずは100枚ほどで考えています)
19
+ ```ここに言語を入力
28
-
29
-
30
-
31
- ・収集したデータを手動で分類(dataというディレクトリの中にcatというディレクトリをいれ、さらにtrainというディレクトリ、testというディレクトリに画像を分類。比率は7:3くらい)する予定です。catだけでなくdog,human,carなどのパターンも増やす予定です。
20
+
32
-
33
-
34
-
35
- trainに前処理を施して画像をきれいにしていきます。
21
+ ├─animal_dataset
36
-
37
-
38
-
22
+
39
- ここからが全然想像できないので困っています。流れが見えてきません。
23
+ ├─train
40
-
41
- ・具体的にはdatasetのclassを継承してinit,len,getitemという関数に分けてという処理をよくめにするので、説明のためにもそのように分けたいと考えています
24
+
42
-
43
-
44
-
45
- ・しかしどのようにlabelを用意して,どのように画像と紐づけするのかがわかりません(cifer10の画像分類チュートリアルのようにclassesの中に入れる形式をするとよいのかそれともcsvなのかjsonなのか)
46
-
47
- ・はたまたそのときの書き方がわかりません。csvの書き方自体はググれば出てきそうですが、datasetになりうるcsvファイルの書き方だったり、jsonの書き方だったりはなかなか出てこないです
48
-
49
-
50
-
51
-
52
-
53
- ### やったこと理解しない可能性あり
25
+ │ ├─cat70枚くらい)
54
-
55
- ・cifer10での画像分類チュートリアルを通したりして中身を確認した(ただnetworkの部分特にlabelを付与するところが理解できませんでした)
26
+
56
-
57
- ・画像をスクレイピングで集めて正規化するところまでは確認した(composeなどを利用した)
58
-
59
-
60
-
61
- ・いくつかのサイトを参考にひな型を作成した
62
-
63
-
64
-
65
- 結果的にまだまだ流れがつかめてなです。。。。
27
+ │ └─dog(70枚くら)
28
+
29
+ └─val
30
+
31
+ ├─cat(30枚くらい)
32
+
33
+ └─dog(30枚くらい)
34
+
35
+ ```
36
+
37
+
66
38
 
67
39
 
68
40
 
@@ -74,17 +46,31 @@
74
46
 
75
47
  ・スクレイピングの処理は省いています
76
48
 
49
+ ・前処理としてtrainとvalそれぞれにかける
50
+
51
+ ・データセットを作成する。
52
+
53
+ ・initの中身にからのdataとlabelを用意する
54
+
55
+ ・画像を呼び出す処理をする??
56
+
57
+ ・for分とif文でディレクトリ名?が一致していたら0/1で場合分けしてlabelのlistに格納
58
+
59
+ ・これで呼び出すときlabelとして機能する?
60
+
61
+ ・lenでデータ数を返す
62
+
77
- data/data_resizeに猫の画像が100枚入って
63
+ getitemでindex番目の画像をロードした
78
-
64
+
79
- ここは書いてないがdata/train data/testと手作業70:30くらいの比率に分け
65
+ ただ本当ロードてる
80
-
66
+
81
- 使い方がわからなかったの分けただけ
67
+ ちゃんとlabelとして認識きてる?
68
+
69
+
70
+
82
-
71
+ 全体的に
83
-
84
-
72
+
85
- ・コメントアウトにてcsvにつ書いてはいるがcsvデータセットの書き方わからなくて書いない
73
+ しい書き方とのがです。もう少しきれいコードできれにlabel分類したいです。
86
-
87
- ・dataトlabelがどう紐づけられるのか全然イメージがつかない...
88
74
 
89
75
 
90
76
 
@@ -98,17 +84,17 @@
98
84
 
99
85
  from torchvision import transforms
100
86
 
101
- from torchvision import dataets, transforms
87
+ from torchvision import datasets, transforms
102
88
 
103
89
  import numpy as np
104
90
 
105
91
 
106
92
 
107
-
93
+ import os
108
94
 
109
95
  import glob
110
96
 
111
- import cv2
97
+
112
98
 
113
99
 
114
100
 
@@ -122,43 +108,103 @@
122
108
 
123
109
  # 前処理
124
110
 
111
+
112
+
113
+ class MyTransform():
114
+
115
+ def __init__(self, resize, mean, std):
116
+
117
+ self.resize = resize
118
+
119
+ self.mean = mean
120
+
121
+ self.std = std
122
+
123
+
124
+
125
+ def __call__(self,img, key ='train'):
126
+
127
+ data_transform = {
128
+
125
- transform = transforms.Compose(
129
+ 'train': transforms.Compose(
126
-
130
+
127
- [transforms.Resize((256,256),
131
+ [transforms.Resize((256,256)),
128
-
132
+
129
- transforms.ToTensor(),
133
+ transforms.ToTensor(),
130
-
134
+
131
- transforms.RandomResizedCrop(32, scale=(1.0, 1.0), ratio=(1.0, 1.0))])
135
+ transforms.Normalize(self.mean, self.std) #標準化
136
+
132
-
137
+ ]),
138
+
133
-
139
+ 'val': transforms.Compose(
140
+
134
-
141
+ [transforms.Resize((256,256)),
142
+
135
- # rootの中に猫の画像が100マイ入っている
143
+ transforms.ToTensor(),
136
-
144
+
137
- data = torchvision.datasets.ImageFolder(root='../data/data_resize', transform=transform)
145
+ transforms.Normalize(self.mean, self.std)
146
+
138
-
147
+ ])
148
+
139
-
149
+ }
150
+
151
+
152
+
140
-
153
+ return data_transform[key](img)
154
+
155
+
156
+
157
+
158
+
159
+
160
+
161
+
162
+
141
-
163
+ # データセット作成
142
-
143
-
144
-
145
-
146
-
164
+
147
- class Creat_Datasets(Dataset):
165
+ class MyDatasets(data.Dataset):
148
-
149
-
150
-
166
+
151
- def __init__(self, imgpath='./data', csvpath='./csv', transform=transform):
167
+ def __init__(self, path=None, key='train', transform=None):
152
168
 
153
169
  self.transform = transform
154
170
 
171
+ self.key = key
172
+
155
- # 以下のコードは使えるかどうか不明
173
+ self.path = path
174
+
156
-
175
+ self.data = []
176
+
177
+ self.lables = []
178
+
179
+
180
+
157
- # self.imgfiles = sorted(glob('%s/*.png' % imgpath))
181
+ target_path = os.path.join(self.path + self.key + '/**/*.jpg')
182
+
183
+
184
+
158
-
185
+ for i in glob(target_path):
186
+
187
+ # データリスト作成
188
+
189
+ self.data.append(i)
190
+
191
+
192
+
193
+ #ラベルリスト作成
194
+
159
- # self.csvfiles = sorted(glob('%s/*.csv' % csvpath))
195
+ label = os.path.basename(os.path.dirname(i))
196
+
160
-
197
+ if label == "cat":
198
+
161
- pass
199
+ label = 0
200
+
201
+ elif label == "dog":
202
+
203
+ label = 1
204
+
205
+ self.lables.append(label)
206
+
207
+
162
208
 
163
209
 
164
210
 
@@ -166,30 +212,72 @@
166
212
 
167
213
  def __len__(self):
168
214
 
169
- # return len(self.csvfiles)
215
+ return len(self.data)
170
-
171
- pass
172
216
 
173
217
 
174
218
 
175
- # dataとlabelのタプルを返してほしい
219
+ # dataとlabelを返すはず
176
220
 
177
221
  def __getitem__(self, index):
178
222
 
223
+ # index番目の画像をロード
224
+
225
+ img_path = self.data[index]
226
+
227
+ img = Image.open(img)
228
+
229
+
230
+
231
+ img_transformed = self.transform(img, self.key)
232
+
179
- pass
233
+ label = self.labels[index]
180
234
 
181
235
 
182
236
 
183
237
 
184
238
 
185
- return image, label
239
+ return img_transformed, label
240
+
241
+
242
+
243
+
244
+
186
-
245
+ train_dataset = MyDatasets()
246
+
187
-
247
+ print(train_dataset.label)
188
-
189
-
248
+
249
+
250
+
251
+
252
+
253
+
254
+
255
+
190
256
 
191
257
  ```
192
258
 
259
+ ### error
260
+
261
+ こちらのコードを実行した結果です。画像の読み込みに失敗しているきがしますがどのように書き換えればいいのかわかりません。
262
+
263
+
264
+
265
+ ```ここに言語を入力
266
+
267
+ Traceback (most recent call last):
268
+
269
+ File "dataset.py", line 90, in <module>
270
+
271
+ train_dataset = MyDatasets()
272
+
273
+ File "dataset.py", line 59, in __init__
274
+
275
+ target_path = os.path.join(self.path + self.key + '/**/*.jpg')
276
+
277
+ TypeError: unsupported operand type(s) for +: 'NoneType' and 'str'
278
+
279
+ ```
280
+
193
281
 
194
282
 
195
283
  ### 参考にしたサイト
@@ -200,12 +288,6 @@
200
288
 
201
289
 
202
290
 
203
- [上記のチュートリアルの解説にあたる記事だが自作でのlabelの紐づけ方がないので不明](https://qiita.com/kuto/items/0ff3ccb4e089d213871d#%E3%83%8B%E3%83%A5%E3%83%BC%E3%83%A9%E3%83%AB%E3%83%8D%E3%83%83%E3%83%88%E3%83%AF%E3%83%BC%E3%82%AF%E3%81%AE%E8%A8%93%E7%B7%B4)
204
-
205
- [似たようなことをしていたがlabelの紐づけが乗っていなかった](https://qiita.com/ryryry/items/b1da4855504dcd3f9d98#%E3%83%87%E3%83%BC%E3%82%BF%E3%83%AD%E3%83%BC%E3%83%80)
206
-
207
-
208
-
209
291
  ### 最後に
210
292
 
211
293
  初心者故、ごちゃごちゃな質問をしてしまったのは否めないです。ただ理解があいまいなためうまく質問もできない状況です。
@@ -214,13 +296,13 @@
214
296
 
215
297
  以下が聞きたいことのまとめかと思いますが補足情報などありましたらお願いします。
216
298
 
217
- label作り
299
+ 画像読み込み
218
-
300
+
219
- dataset独自書き方なのか?、どこ配置たらいいのか?どように取り込んいのか?
301
+ おそらく画像読み込み失敗ると思ますで別方法があれば教えていたきたです。
220
-
221
-
222
-
302
+
303
+
304
+
223
- dataとlabel紐づけ
305
+ init,getitem書きが正しいのか確認してほしいです。
224
306
 
225
307
  前処理としてこんな内容でいいのかわかりませんが、initの中身、getitemでどんなコードを書いたら紐づけができるのか想像ができません。
226
308