質問編集履歴
1
修正・追加依頼に基づく改変
title
CHANGED
File without changes
|
body
CHANGED
@@ -6,13 +6,17 @@
|
|
6
6
|
|
7
7
|
```
|
8
8
|
Traceback (most recent call last):
|
9
|
-
File "/Users/○○○
|
9
|
+
File "/Users/○○○/Desktop/△△△_data/cnn.py", line 125, in <module>
|
10
|
+
train_loss_list, test_loss_list = run(30, optimizer, criterion, device)
|
11
|
+
File "/Users/○○○/Desktop/△△△_data/cnn.py", line 73, in run
|
10
|
-
|
12
|
+
train_loss = train_epoch(model, optimizer, criterion, train_loader, device)
|
11
|
-
File "/
|
13
|
+
File "/Users/○○○/Desktop/△△△_data/cnn.py", line 46, in train_epoch
|
12
|
-
|
14
|
+
outputs = model(images)
|
13
|
-
File "/usr/local/lib/python3.9/site-packages/torch/
|
15
|
+
File "/usr/local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
|
16
|
+
return forward_call(*input, **kwargs)
|
14
|
-
|
17
|
+
File "/Users/○○○/Desktop/△△△_data/cnn.py", line 33, in forward
|
18
|
+
x = x.view(-1, 16*5*5)
|
15
|
-
|
19
|
+
RuntimeError: shape '[-1, 400]' is invalid for input of size 2080832
|
16
20
|
```
|
17
21
|
|
18
22
|
### 該当のソースコード
|
@@ -28,49 +32,12 @@
|
|
28
32
|
import torch.nn as nn
|
29
33
|
import torch.nn.functional as F
|
30
34
|
import torch.optim as optim
|
31
|
-
import torchvision
|
32
35
|
import torchvision.transforms as transforms
|
33
36
|
|
34
|
-
import numpy as np
|
35
37
|
import matplotlib.pyplot as plt
|
36
38
|
|
39
|
+
from torchvision.datasets import ImageFolder
|
37
40
|
|
38
|
-
class ImageFolder(Dataset):
|
39
|
-
IMG_EXTENSIONS = [".jpg", ".jpeg", ".png", ".bmp"]
|
40
|
-
|
41
|
-
def __init__(self, img_dir, transform=None):
|
42
|
-
# 画像ファイルのパス一覧を取得する。
|
43
|
-
self.img_paths = self._get_img_paths(img_dir)
|
44
|
-
self.transform = transform
|
45
|
-
|
46
|
-
def __getitem__(self, index):
|
47
|
-
path = self.img_paths[index]
|
48
|
-
|
49
|
-
# 画像を読み込む。
|
50
|
-
img = Image.open(path)
|
51
|
-
|
52
|
-
if self.transform is not None:
|
53
|
-
# 前処理がある場合は行う。
|
54
|
-
img = self.transform(img)
|
55
|
-
|
56
|
-
return img
|
57
|
-
|
58
|
-
def _get_img_paths(self, img_dir):
|
59
|
-
"""指定したディレクトリ内の画像ファイルのパス一覧を取得する。
|
60
|
-
"""
|
61
|
-
img_dir = Path(img_dir)
|
62
|
-
img_paths = [
|
63
|
-
p for p in img_dir.iterdir() if p.suffix in ImageFolder.IMG_EXTENSIONS
|
64
|
-
]
|
65
|
-
|
66
|
-
return img_paths
|
67
|
-
|
68
|
-
def __len__(self):
|
69
|
-
"""ディレクトリ内の画像ファイルの数を返す。
|
70
|
-
"""
|
71
|
-
return len(self.img_paths)
|
72
|
-
|
73
|
-
|
74
41
|
class Net(nn.Module):
|
75
42
|
def __init__(self):
|
76
43
|
super().__init__()
|
@@ -140,14 +107,13 @@
|
|
140
107
|
# Transform を作成する。
|
141
108
|
transform = transforms.Compose([transforms.Resize(256), transforms.ToTensor()])
|
142
109
|
# Dataset を作成する。
|
143
|
-
dataset = ImageFolder('/Users/○○○
|
110
|
+
dataset = ImageFolder('/Users/○○○/Desktop/△△△_data/□□□_dataset', transform)
|
144
111
|
# DataLoader を作成する。
|
145
112
|
dataloader = DataLoader(dataset, batch_size=3)
|
146
113
|
|
147
|
-
for batch in dataloader:
|
114
|
+
#for batch in dataloader:
|
148
|
-
|
115
|
+
# print(batch.shape)
|
149
116
|
|
150
|
-
|
151
117
|
# グラフのスタイルを指定
|
152
118
|
plt.style.use('seaborn-darkgrid')
|
153
119
|
|
@@ -158,7 +124,6 @@
|
|
158
124
|
|
159
125
|
train_ratio = 0.8
|
160
126
|
train_size = int(train_ratio * len(dataset))
|
161
|
-
# int()で整数に。
|
162
127
|
val_size = len(dataset) - train_size
|
163
128
|
data_size = {"train":train_size, "val":val_size}
|
164
129
|
train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])
|
@@ -171,7 +136,6 @@
|
|
171
136
|
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
|
172
137
|
dataloaders = {"train":train_loader, "val":test_loader}
|
173
138
|
|
174
|
-
|
175
139
|
model = Net()
|
176
140
|
|
177
141
|
device = torch.device("cpu")
|
@@ -205,5 +169,25 @@
|
|
205
169
|
total += len(images)
|
206
170
|
print(f"正解率: {(correct/total)*100:.3f} %")
|
207
171
|
|
172
|
+
```
|
208
173
|
|
174
|
+
### 該当のデータセット
|
175
|
+
```
|
176
|
+
./□□□_dataset
|
177
|
+
├── 0
|
178
|
+
│ ├── 0_0.png
|
179
|
+
│ ├── 0_1.png
|
180
|
+
│ ├── ・・・
|
181
|
+
│ ├── 0_10.png
|
182
|
+
├── 1
|
183
|
+
│ ├── 1_0.png
|
184
|
+
│ ├── 1_1.png
|
185
|
+
│ ├── ・・・
|
186
|
+
│ ├── 1_10.png
|
187
|
+
├── 2
|
188
|
+
│ ├── 2_0.png
|
189
|
+
│ ├── 2_1.png
|
190
|
+
│ ├── ・・・
|
191
|
+
│ ├── 2_10.png
|
192
|
+
|
209
193
|
```
|