質問編集履歴
2
情報不足
test
CHANGED
File without changes
|
test
CHANGED
@@ -20,31 +20,15 @@
|
|
20
20
|
|
21
21
|
Traceback (most recent call last):
|
22
22
|
|
23
|
-
File "FujiNet.py", line
|
23
|
+
File "FujiNet.py", line 100, in <module>
|
24
|
-
|
24
|
+
|
25
|
-
net = fjn.forward(
|
25
|
+
net = fjn.forward(data).to(device)
|
26
26
|
|
27
27
|
File "FujiNet.py", line 78, in forward
|
28
28
|
|
29
|
-
x = self.pool(x)
|
30
|
-
|
31
|
-
File "/home/selen/.pyenv/versions/3.7.3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 493, in __call__
|
32
|
-
|
33
|
-
result = self.forward(*input, **kwargs)
|
34
|
-
|
35
|
-
File "/home/selen/.pyenv/versions/3.7.3/lib/python3.7/site-packages/torch/nn/modules/pooling.py", line 217, in forward
|
36
|
-
|
37
|
-
self.return_indices)
|
38
|
-
|
39
|
-
File "/home/selen/.pyenv/versions/3.7.3/lib/python3.7/site-packages/torch/_jit_internal.py", line 133, in fn
|
40
|
-
|
41
|
-
return if_false(*args, **kwargs)
|
42
|
-
|
43
|
-
File "/home/selen/.pyenv/versions/3.7.3/lib/python3.7/site-packages/torch/nn/functional.py", line 528, in _max_pool3d
|
44
|
-
|
45
|
-
|
29
|
+
x = torch.from_numpy(x.astype(np.float32)).clone
|
46
|
-
|
30
|
+
|
47
|
-
TypeError:
|
31
|
+
TypeError: float() argument must be a string or a number, not 'dict'
|
48
32
|
|
49
33
|
```
|
50
34
|
|
@@ -56,6 +40,94 @@
|
|
56
40
|
|
57
41
|
```Python
|
58
42
|
|
43
|
+
class BrainData(Dataset):
|
44
|
+
|
45
|
+
def __init__(self, data, transform=None, class_map=CLASS_MAP):
|
46
|
+
|
47
|
+
self.data = data
|
48
|
+
|
49
|
+
self.class_map = class_map
|
50
|
+
|
51
|
+
self.transform = transform
|
52
|
+
|
53
|
+
|
54
|
+
|
55
|
+
def __len__(self):
|
56
|
+
|
57
|
+
return len(self.data)
|
58
|
+
|
59
|
+
|
60
|
+
|
61
|
+
def __getitem__(self, idx):
|
62
|
+
|
63
|
+
if torch.is_tensor(idx):
|
64
|
+
|
65
|
+
idx = idx.tolist()
|
66
|
+
|
67
|
+
|
68
|
+
|
69
|
+
voxel = self.data[idx]["voxel"]
|
70
|
+
|
71
|
+
voxel = voxel.reshape((1,voxel.shape[0],voxel.shape[1],voxel.shape[2]))
|
72
|
+
|
73
|
+
label = self.class_map[self.data[idx]["label"]]
|
74
|
+
|
75
|
+
|
76
|
+
|
77
|
+
#return sample
|
78
|
+
|
79
|
+
return (voxel, label)
|
80
|
+
|
81
|
+
|
82
|
+
|
83
|
+
data = dataset.load_data(["ADNI2"])
|
84
|
+
|
85
|
+
data_set = BrainData(data, CLASS_MAP)
|
86
|
+
|
87
|
+
|
88
|
+
|
89
|
+
n_train = int(len(data_set) * 0.8)
|
90
|
+
|
91
|
+
n_val = int(len(data_set) - n_train)
|
92
|
+
|
93
|
+
|
94
|
+
|
95
|
+
torch.manual_seed(0)
|
96
|
+
|
97
|
+
|
98
|
+
|
99
|
+
train_dataset, val_dataset = torch.utils.data.random_split(data_set, [n_train, n_val])
|
100
|
+
|
101
|
+
|
102
|
+
|
103
|
+
# set data loader
|
104
|
+
|
105
|
+
train_loader = torch.utils.data.DataLoader(
|
106
|
+
|
107
|
+
dataset=train_dataset,
|
108
|
+
|
109
|
+
batch_size=2,
|
110
|
+
|
111
|
+
shuffle=True,
|
112
|
+
|
113
|
+
num_workers=5)
|
114
|
+
|
115
|
+
|
116
|
+
|
117
|
+
val_loader = torch.utils.data.DataLoader(
|
118
|
+
|
119
|
+
dataset=val_dataset,
|
120
|
+
|
121
|
+
batch_size=2,
|
122
|
+
|
123
|
+
shuffle=False,
|
124
|
+
|
125
|
+
num_workers=5)
|
126
|
+
|
127
|
+
|
128
|
+
|
129
|
+
#class FujiNet
|
130
|
+
|
59
131
|
class FujiNet(nn.Module):
|
60
132
|
|
61
133
|
def __init__(self):
|
@@ -82,6 +154,8 @@
|
|
82
154
|
|
83
155
|
def forward(self, x):
|
84
156
|
|
157
|
+
x = torch.from_numpy(x.astype(np.float32)).clone
|
158
|
+
|
85
159
|
x = self.pool(x)
|
86
160
|
|
87
161
|
x = F.relu(self.conv1(x))
|
@@ -116,12 +190,14 @@
|
|
116
190
|
|
117
191
|
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
118
192
|
|
193
|
+
|
194
|
+
|
119
195
|
# インスタンス初期化
|
120
196
|
|
121
197
|
fjn = FujiNet()
|
122
198
|
|
123
199
|
# forwardに引数 num_classes を渡す
|
124
200
|
|
125
|
-
net = fjn.forward(
|
201
|
+
net = fjn.forward(data).to(device)
|
126
202
|
|
127
203
|
```
|
1
情報不足
test
CHANGED
@@ -1 +1 @@
|
|
1
|
-
|
1
|
+
変数x を tensor型に変換する書き方がわからない
|
test
CHANGED
@@ -8,7 +8,7 @@
|
|
8
8
|
|
9
9
|
どうやって解決したらいいでしょうか
|
10
10
|
|
11
|
-
|
11
|
+
変数x を tensor型に変換するのだと思いますが書き方がわかりません。
|
12
12
|
|
13
13
|
### 発生している問題・エラーメッセージ
|
14
14
|
|