Question edit history

Revision 1 (edit reason: insufficient information)

title
CHANGED
File without changes

body
CHANGED
````diff
@@ -30,7 +30,7 @@
 ```
 
 ### Relevant source code
-
+Alexnet.py
 ```Python
 import torch
 import torchvision
````
````diff
@@ -67,118 +67,277 @@
     batch_size=64,
     shuffle=False,
     num_workers=4)
+```
+dataset.py
+```Python
+import pickle
+from tqdm import tqdm
 
-
+import numpy as np
+import pandas as pd
 
-    def __init__(self, num_classes):
-        super(AlexNet, self).__init__()
-        self.features = nn.Sequential(
-            nn.Conv2d(3, 64, kernel_size=3, padding=1),
-            nn.ReLU(inplace=True),
-            nn.MaxPool2d(kernel_size=2, stride=2),
-            nn.Conv2d(64, 192, kernel_size=5, padding=2),
-            nn.ReLU(inplace=True),
-            nn.MaxPool2d(kernel_size=2, stride=2),
-            nn.Conv2d(192, 384, kernel_size=3, padding=1),
-            nn.ReLU(inplace=True),
-            nn.Conv2d(384, 256, kernel_size=3, padding=1),
-            nn.ReLU(inplace=True),
-            nn.Conv2d(256, 256, kernel_size=3, padding=1),
-            nn.ReLU(inplace=True),
-            nn.MaxPool2d(kernel_size=2, stride=2),
-        )
-        self.classifier = nn.Sequential(
-
+import pathlib
-            nn.Linear(256 * 4 * 4, 4096),
-            nn.ReLU(inplace=True),
-            nn.Dropout(),
-            nn.Linear(4096, 4096),
-            nn.ReLU(inplace=True),
-            nn.Linear(4096, num_classes),
-        )
 
-
+# == Base ==
-        x = self.features(x)
-
+DATA_DIR = pathlib.Path('/home') / 'radiology_datas'
-        x = self.classifier(x)
-        return x
 
-#
+# == Dataset ==
-
+ADNI1 = DATA_DIR / 'ADNI1'
+ADNI2 = DATA_DIR / 'JHU-radiology' / '20170509'
+ADNI2_2 = DATA_DIR / 'JHU-radiology' / 'MNI_skull_stripped' / 'output'
-
+PPMI = DATA_DIR / 'JHU-radiology' / 'PPMI'
-
+FourRTNI = DATA_DIR / 'JHU-radiology' / '4RTNI'
 
-# optimizing
-criterion = nn.CrossEntropyLoss()
-
+BLACKLIST_DIR = DATA_DIR / 'util' / 'lists'
 
-# training
-
+DATA_CSV = {
+    'ADNI': DATA_DIR / 'JHU-radiology' / 'ADNIMERGE.csv',
+    'PPMI': DATA_DIR / 'JHU-radiology' / 'PPMI.csv',
-
+    '4RTNI': FourRTNI / 'csv' / '4RTNI_DATA.csv',
+}
 
-
+DATA_DIRS_DICT = {
+    'ADNI1': ADNI1,
+    'ADNI2': ADNI2,
-
+    'ADNI2-2': ADNI2_2,
+    'PPMI': PPMI,
-
+    '4RTNI': FourRTNI / 'SkullStripped',
+}
 
+DATA_PREFIX_DICT = {
+    'fullsize': 'fullsize',
+    'half': 'half_',
+}
-
+# == Label Encoder ==
-
+CLASS_MAP = {
-    for i, (images, labels) in enumerate(train_loader):
-
+    'CN': 0,
+    'AD': 1,
-
+    'EMCI': 2,
-
+    'LMCI': 3,
-
+    'MCI': 4,
+    'SMC': 5,
-
+    'Control': 6,
-        train_acc += (outputs.max(1)[1] == labels).sum().item()
-
+    'FControl': 6,
-
+    'PD': 7,
+    'SWEDD': 8,
+    'Prodromal': 9,
+    'CBD': 10,
+    'PSP': 11,
+    'Oth': 12,
 
-    avg_train_loss = train_loss / len(train_loader.dataset)
-
+}
 
-    # ====== val_mode ======
-    net.eval()
-    with torch.no_grad():
-        for images, labels in test_loader:
-            images = images.to(device)
-            labels = labels.to(device)
-            outputs = net(images)
-            loss = criterion(outputs, labels)
-            val_loss += loss.item()
-            val_acc += (outputs.max(1)[1] == labels).sum().item()
-    avg_val_loss = val_loss / len(test_loader.dataset)
-    avg_val_acc = val_acc / len(test_loader.dataset)
 
+def read_voxel(path):
-
+    '''
+    Just takes a path and returns the voxel stored there.
+    Args
+    ----------
+    path : pathlib
+        Path to the pkl file
+    Return
+    ----------
-
+    voxel : numpy.array
+        Contents of the pkl file
+    '''
-
+    with open(path, 'rb') as rf:
+        voxel = pickle.load(rf)
-
+    return np.array(voxel).astype('f')
-    val_loss_list.append(avg_val_loss)
-    val_acc_list.append(avg_val_acc)
 
 
-# plot graph
-plt.figure()
-plt.plot(range(num_epochs), train_loss_list, color='blue', linestyle='-', label='train_loss')
-plt.plot(range(num_epochs), val_loss_list, color='green', linestyle='--', label='val_loss')
-
+def get_uid(path):
-
+    '''
+    Just takes a path and returns the uid.
+    Args
+    ----------
+    path : pathlib
+        Path to the pkl file
+    Return
+    ----------
+    uid : int
+        uid
-
+    '''
+    uid = path.name
-
+    for key, value in DATA_DIRS_DICT.items():
-
+        if str(value) in str(path):
-plt.show()
 
-plt.figure()
-plt.plot(range(num_epochs), train_acc_list, color='blue', linestyle='-', label='train_acc')
-plt.plot(range(num_epochs), val_acc_list, color='green', linestyle='--', label='val_acc')
-plt.legend()
-
+            if key == 'ADNI2':
-
+                uid = path.name.split('_')[-2]
-plt.title('Training and validation accuracy')
-
+                uid = int(uid[1:])
-plt.show()
 
+            elif key == 'ADNI2-2':
+                uid = path.name.split('_')[-4]
+                uid = int(uid[1:])
+
+            elif key == 'PPMI':
+                uid = path.name.split('_')[-4]
+                uid = int(uid)
+
+            elif key == '4RTNI':
+                uid = path.name.split('_')[-4]
+                uid = int(uid)
+
+    return uid
+
+
+def collect_pids(dirs):
+    '''
+    Collect the patients present in the given directories.
-
+    Args
+    ----------
+    path : pathlib
+        Path to the pkl file
-
+    Return
+    ----------
+    pid : list of str
+        pids
+    '''
+    patiants = []
+    for dir_path in dirs:
+        [patiants.append(f.name) for f in dir_path.iterdir()]
+    return patiants
+
+
+def get_blacklist():
+    '''
+    Return the blacklists under brain/util/lists as a list of uids.
+    Args
+    ----------
+    Return
+    ----------
+    uid : list of int
+        uids
+    '''
+    key = '**/uids.txt'
+    excluded_uid_paths = BLACKLIST_DIR.glob(key)
+    excluded_uids = []
+    for path in excluded_uid_paths:
+        with open(path, 'r') as rf:
+            [excluded_uids.append(int(uid.rstrip('\n'))) for uid in rf]
+    return excluded_uids
+
+
+def load_csv_data(pids):
+
+    df = pd.read_csv(DATA_CSV['ADNI'])
+    adni = df[['PTID', 'AGE', 'PTGENDER']]
+    adni.columns = ['PID', 'AGE', 'SEX']
+
+    df = pd.read_csv(DATA_CSV['PPMI'])
+    ppmi = df[['Subject', 'Age', 'Sex']]
+    ppmi.columns = ['PID', 'AGE', 'SEX']
+
+    df = pd.read_csv(DATA_CSV['4RTNI'])
+    fourrtni = df[['SUBID', 'AGE_AT_TP0', 'SEX']]
+    fourrtni.columns = ['PID', 'AGE', 'SEX']
+
+    df = adni.append(ppmi).append(fourrtni)
+    df.iloc[:, 2] = df['SEX'].apply(lambda x: x[0] if x in ['Male', 'Female'] else x)
+    df.iloc[:, 1] = df['AGE'].apply(lambda x: int(x))
+    df.iloc[:, 0] = df['PID'].apply(lambda x: str(x))
+
+    return df
+
+
+def load_data(
+    kinds=['ADNI2', 'ADNI2-2', 'PPMI', '4RTNI'],
+    classes=['CN', 'AD', 'MCI', 'EMCI', 'LMCI', 'SMC', 'Control', 'PD', 'SWEDD', 'Prodromal', 'PSP', 'CBD', 'Oth', 'FControl'],
+    size='half',
+    csv=False,
+    pids=[],
+    uids=[],
+    unique=False,
+    blacklist=False,
+    dryrun=False,
+):
+    '''
+    Args
+    ----------
+    kind : list
+        Specify ADNI2, ADNI2-2, PPMI as a list
+    classes : list
+        Specify CN, AD, MCI, EMCI, LMCI, SMC,
+        Control, PD, SWEDD, Prodromal,
+        PSP, CBD, Oth
+        as a list
+    size : str
+        fullsize, half
+    pids : list of str
+        Specify the pids of the patients to fetch as a list
+    uids : list of str
+        Specify the uids of the patients to fetch as a list
+    unique : bool
+    blacklist : bool
+    dryrun : bool
+        If true, do not load the voxels and return only the other information
+    Return
+    ----------
+    dataset: list
+        A list packed full of information
+    '''
+    dirs = []
+    for key in kinds:
+        for c in classes:
+            dirname = DATA_DIRS_DICT[key].resolve() / c
+            if dirname.exists():
+                dirs.append(DATA_DIRS_DICT[key].resolve() / c)
+
+    dataset = []
+    key = '**/*' + DATA_PREFIX_DICT[size] + '*.pkl'
+    if dryrun:
+        print(f'[--DRYRUN--]')
+        print(f'[SIZE] {size}')
+        print(f'[KINDS] {kinds}')
+        print(f'[CLASSES] {classes}')
+        print(f'[PATIANT] {len(pids)} of patiants')
+        print(f'[TARGET] {uids}')
+        print(f'[UNIQUE] {unique}')
+        print(f'[BLACKLIST] {blacklist}')
+
+    for dir_path in dirs:
+        for file_path in dir_path.glob(key):
+            data = {}
+            data['uid'] = get_uid(file_path)
+            data['pid'] = file_path.parent.name
+            data['label'] = dir_path.name
+            data['nu_label'] = CLASS_MAP[dir_path.name]
+            data['path'] = file_path
+            dataset.append(data)
+
+    if uids:
+        dataset = [data for data in dataset if data['uid'] in uids]
+
+    if unique:
+        dataset_unique = []
+        for pid in collect_pids(dirs):
+            # For each pid, gather its records, sort them by uid, and pick the latest one
+            dataset_unique.append(
+                sorted([data for data in dataset if data['pid'] == pid], key=lambda data: data['uid'])[-1])
+        dataset = dataset_unique
+
+    if pids:
+        dataset = [data for data in dataset if data['pid'] in pids]
+
+    if blacklist:
+        exclude_uids = get_blacklist()
+        dataset = [data for data in dataset if data['uid'] not in exclude_uids]
+
+    if dryrun:
-
+        return np.array(dataset)
+
+    if csv:
+        df = load_csv_data([data['pid'] for data in dataset])
+        [data.update(
+            AGE=df[df.PID == data['pid']].AGE.values[0],
+            SEX=df[df.PID == data['pid']].SEX.values[0],
+        ) if data['pid'] in df.PID.values else data.update(
+            AGE=None,
+            SEX=None,
+        ) for data in dataset]
+
+    [data.update(voxel=read_voxel(data['path'])) for data in tqdm(dataset, leave=False)]
+
-
+    return np.array(dataset)
+
+
 ```
````
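For reference, a minimal usage sketch of the `load_data` API that this edit adds in dataset.py. It only exercises calls that appear in the diff above, and it assumes dataset.py is importable as a module named `dataset` and that the directory tree under /home/radiology_datas exists with the layout encoded in DATA_DIRS_DICT, so treat it as an illustration rather than a verified run.

```Python
# Usage sketch for dataset.load_data (assumptions: dataset.py is on the
# import path and /home/radiology_datas is laid out as in DATA_DIRS_DICT).
import dataset

# Dry run: returns metadata only (uid, pid, label, nu_label, path);
# no voxel pkl files are read.
meta = dataset.load_data(kinds=['ADNI2'], classes=['CN', 'AD'], dryrun=True)
print(len(meta), 'records found')

# Full load: keep the latest scan per patient, drop blacklisted uids,
# and attach the voxel array read from each pkl file.
records = dataset.load_data(kinds=['ADNI2'], classes=['CN', 'AD'],
                            size='half', unique=True, blacklist=True)
X = [r['voxel'] for r in records]     # float32 numpy arrays
y = [r['nu_label'] for r in records]  # integer labels from CLASS_MAP
```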