質問編集履歴
1
追記しました。
test
CHANGED
File without changes
|
test
CHANGED
@@ -50,7 +50,7 @@
|
|
50
50
|
|
51
51
|
実行ファイルです。
|
52
52
|
|
53
|
-
|
53
|
+
推論をするためのファイルです。
|
54
54
|
|
55
55
|
|
56
56
|
|
@@ -128,7 +128,7 @@
|
|
128
128
|
|
129
129
|
train_loader = torch.utils.data.DataLoader(
|
130
130
|
|
131
|
-
train_dataset,
|
131
|
+
dataset=train_dataset,
|
132
132
|
|
133
133
|
batch_size=4,
|
134
134
|
|
@@ -152,7 +152,7 @@
|
|
152
152
|
|
153
153
|
valid_loader = torch.utils.data.DataLoader(
|
154
154
|
|
155
|
-
valid_dataset,
|
155
|
+
dataset=valid_dataset,
|
156
156
|
|
157
157
|
batch_size=4,
|
158
158
|
|
@@ -164,127 +164,267 @@
|
|
164
164
|
|
165
165
|
|
166
166
|
|
167
|
-
|
167
|
+
|
168
|
-
|
169
|
-
total = 0
|
170
168
|
|
171
169
|
|
172
170
|
|
173
171
|
acc = []
|
174
172
|
|
173
|
+
num_epoch = 100
|
174
|
+
|
175
|
+
for epoch in range(num_epoch):
|
176
|
+
|
177
|
+
correct = 0 #val_acc
|
178
|
+
|
179
|
+
total = 0
|
180
|
+
|
181
|
+
with torch.no_grad():
|
182
|
+
|
183
|
+
for data in valid_loader:
|
184
|
+
|
185
|
+
inputs, labels = data[0].to(device), data[1].to(device)
|
186
|
+
|
187
|
+
# print("label={}".format(labels))
|
188
|
+
|
189
|
+
# print("inputs={}".format(inputs))
|
190
|
+
|
191
|
+
|
192
|
+
|
193
|
+
outputs = model(inputs).to(device)
|
194
|
+
|
195
|
+
|
196
|
+
|
197
|
+
# 確率
|
198
|
+
|
199
|
+
_, predicted = torch.max(outputs.data, 1)
|
200
|
+
|
201
|
+
total += labels.size(0)
|
202
|
+
|
203
|
+
correct += (predicted == labels).sum().item()
|
204
|
+
|
205
|
+
acc.append(float(correct/total))
|
206
|
+
|
207
|
+
|
208
|
+
|
209
|
+
print("Accuracy of the network on the 100 test images: %d/%d = %.1f" % (correct, total, 100*correct/len(valid_loader.dataset)))
|
210
|
+
|
211
|
+
|
212
|
+
|
213
|
+
|
214
|
+
|
215
|
+
|
216
|
+
|
217
|
+
|
218
|
+
|
219
|
+
plt.figure()
|
220
|
+
|
221
|
+
plt.plot(range(num_epoch), acc, color='blue',linestyle='-',label='acc')
|
222
|
+
|
223
|
+
plt.legend()
|
224
|
+
|
225
|
+
plt.xlabel('epoch')
|
226
|
+
|
227
|
+
plt.ylabel('acc')
|
228
|
+
|
229
|
+
plt.grid()
|
230
|
+
|
231
|
+
plt.savefig("acc.png")
|
232
|
+
|
175
233
|
|
176
234
|
|
235
|
+
|
236
|
+
|
237
|
+
```
|
238
|
+
|
239
|
+
|
240
|
+
|
241
|
+
### 実行結果
|
242
|
+
|
243
|
+
|
244
|
+
|
245
|
+
一応、グラフを作ってみました。(追記)
|
246
|
+
|
247
|
+
![イメージ説明](6f476a632313c5c68b65625788bba398.png)
|
248
|
+
|
249
|
+
以下の部分をvalid_loader(評価用)でループを回したとき
|
250
|
+
|
251
|
+
```
|
252
|
+
|
177
253
|
with torch.no_grad():
|
178
254
|
|
179
255
|
for data in valid_loader:
|
180
256
|
|
181
|
-
|
182
|
-
|
183
|
-
|
184
|
-
|
185
|
-
|
186
|
-
|
187
|
-
|
188
|
-
|
189
|
-
|
190
|
-
|
191
|
-
|
192
|
-
|
193
|
-
|
194
|
-
|
195
|
-
|
196
|
-
|
197
|
-
|
198
|
-
|
199
|
-
|
200
|
-
|
201
|
-
|
202
|
-
|
203
|
-
|
204
|
-
|
205
|
-
|
206
|
-
|
207
|
-
|
208
|
-
|
209
|
-
|
210
|
-
|
211
|
-
|
212
|
-
|
213
|
-
|
214
|
-
|
215
|
-
|
216
|
-
|
217
|
-
|
218
|
-
|
219
|
-
|
220
|
-
|
221
|
-
|
222
|
-
|
223
|
-
|
224
|
-
|
225
|
-
|
226
|
-
|
227
|
-
|
228
|
-
|
229
|
-
|
230
|
-
|
231
|
-
|
232
|
-
|
233
|
-
# plt.savefig("acc.png")
|
234
|
-
|
235
|
-
|
236
|
-
|
237
|
-
|
238
|
-
|
239
|
-
# plot_acc(acc)
|
257
|
+
```
|
258
|
+
|
259
|
+
かなり精度はあれですが、こんな感じです。
|
260
|
+
|
261
|
+
```
|
262
|
+
|
263
|
+
Accuracy of the network on the 100 test images: 30/54 = 55.6
|
264
|
+
|
265
|
+
```
|
266
|
+
|
267
|
+
以下の部分をtrain_loader(学習用)でループを回したとき
|
268
|
+
|
269
|
+
```
|
270
|
+
|
271
|
+
with torch.no_grad():
|
272
|
+
|
273
|
+
for data in train_loader:
|
274
|
+
|
275
|
+
```
|
276
|
+
|
277
|
+
|
278
|
+
|
279
|
+
100%っておかしくないか....
|
280
|
+
|
281
|
+
```
|
282
|
+
|
283
|
+
Accuracy of the network on the 100 test images: 129/129 = 100.0
|
284
|
+
|
285
|
+
```
|
286
|
+
|
287
|
+
|
288
|
+
|
289
|
+
### 追記
|
290
|
+
|
291
|
+
前処理を行っているdataset.pyの添付ファイルをのせます。
|
292
|
+
|
293
|
+
```
|
294
|
+
|
295
|
+
import glob
|
296
|
+
|
297
|
+
import os
|
298
|
+
|
299
|
+
|
300
|
+
|
301
|
+
from PIL import Image
|
302
|
+
|
303
|
+
from torch.utils import data as data
|
304
|
+
|
305
|
+
from torchvision import transforms as transforms
|
306
|
+
|
307
|
+
|
240
308
|
|
241
309
|
|
242
310
|
|
243
|
-
|
244
|
-
|
245
|
-
|
246
|
-
|
247
|
-
|
248
|
-
|
249
|
-
|
250
|
-
|
251
|
-
|
252
|
-
|
253
|
-
|
254
|
-
|
255
|
-
|
256
|
-
|
257
|
-
|
258
|
-
|
259
|
-
|
260
|
-
|
261
|
-
|
262
|
-
|
263
|
-
|
264
|
-
|
265
|
-
|
266
|
-
|
267
|
-
|
268
|
-
|
269
|
-
|
270
|
-
|
271
|
-
|
272
|
-
|
273
|
-
|
274
|
-
|
275
|
-
f
|
276
|
-
|
277
|
-
|
278
|
-
|
279
|
-
|
280
|
-
|
281
|
-
|
282
|
-
|
283
|
-
|
284
|
-
|
285
|
-
|
286
|
-
|
287
|
-
|
311
|
+
|
312
|
+
|
313
|
+
transform = transforms.Compose([
|
314
|
+
|
315
|
+
transforms.Resize((256, 256)),
|
316
|
+
|
317
|
+
transforms.ToTensor(),
|
318
|
+
|
319
|
+
transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5)),
|
320
|
+
|
321
|
+
# transforms.RandomHorizontalFlip(),
|
322
|
+
|
323
|
+
# transforms.RandomVerticalFlip(),
|
324
|
+
|
325
|
+
# transforms.RandomRotation(degrees=30)
|
326
|
+
|
327
|
+
])
|
328
|
+
|
329
|
+
|
330
|
+
|
331
|
+
|
332
|
+
|
333
|
+
|
334
|
+
|
335
|
+
class MyDatasets(data.Dataset):
|
336
|
+
|
337
|
+
def __init__(self, root_dir, key):
|
338
|
+
|
339
|
+
self.transform = transform
|
340
|
+
|
341
|
+
self.data = []
|
342
|
+
|
343
|
+
self.labels = []
|
344
|
+
|
345
|
+
name_to_label = {"cat": 0, "dog": 1}
|
346
|
+
|
347
|
+
|
348
|
+
|
349
|
+
target_dir = os.path.join(root_dir, key, "**/*")
|
350
|
+
|
351
|
+
|
352
|
+
|
353
|
+
for path in glob.glob(target_dir):
|
354
|
+
|
355
|
+
name = os.path.basename(os.path.dirname(path))
|
356
|
+
|
357
|
+
label = name_to_label[name]
|
358
|
+
|
359
|
+
|
360
|
+
|
361
|
+
self.data.append(path)
|
362
|
+
|
363
|
+
self.labels.append(label)
|
364
|
+
|
365
|
+
|
366
|
+
|
367
|
+
def __len__(self):
|
368
|
+
|
369
|
+
return len(self.data)
|
370
|
+
|
371
|
+
|
372
|
+
|
373
|
+
def __getitem__(self, index):
|
374
|
+
|
375
|
+
img_path = self.data[index]
|
376
|
+
|
377
|
+
label = self.labels[index]
|
378
|
+
|
379
|
+
|
380
|
+
|
381
|
+
img = Image.open(img_path).convert("RGB")
|
382
|
+
|
383
|
+
|
384
|
+
|
385
|
+
img = self.transform(img)
|
386
|
+
|
387
|
+
|
388
|
+
|
389
|
+
return img, label
|
390
|
+
|
391
|
+
|
392
|
+
|
393
|
+
if __name__ == "__main__":
|
394
|
+
|
395
|
+
train_dataset = MyDatasets("./animal_dataset", "train")
|
396
|
+
|
397
|
+
train_dataloader = data.DataLoader(train_dataset, batch_size=4, shuffle= True)
|
398
|
+
|
399
|
+
|
400
|
+
|
401
|
+
for data, labels in train_dataloader:
|
402
|
+
|
403
|
+
print(data.shape, labels.shape)
|
404
|
+
|
405
|
+
datas, labels = iter(train_dataloader).next()
|
406
|
+
|
407
|
+
# print(datas[0].shape)
|
408
|
+
|
409
|
+
s=10
|
410
|
+
|
411
|
+
pic = transforms.ToPILImage(mode='RGB')(datas[s])
|
412
|
+
|
413
|
+
pic.save('./result.jpg')
|
414
|
+
|
415
|
+
if labels[s].numpy() == 0:
|
416
|
+
|
417
|
+
print("label: cat")
|
418
|
+
|
419
|
+
else:
|
420
|
+
|
421
|
+
print("label: dog")
|
422
|
+
|
423
|
+
|
424
|
+
|
425
|
+
```
|
426
|
+
|
427
|
+
なんかしっかり前処理できているか不安になりました。
|
288
428
|
|
289
429
|
|
290
430
|
|