Question edit history

Revision 1 — edit summary: "Lack of information" (情報不足)

test (CHANGED)
File without changes

test (CHANGED)
@@ -62,7 +62,7 @@

Context (question lines 62–64 and 66–68) is unchanged: blank lines, the heading "### 該当のソースコード" ("Relevant source code"), and the opening ```Python fence. The only change is line 65, previously blank, which now carries the file label:

+ Alexnet.py
@@ -136,232 +136,550 @@

Removed (roughly old lines 139–358, plus a few trailing blank lines): everything that previously followed the DataLoader setup (`num_workers=4)`) inside the Alexnet.py code block — the model definition, the training/validation loop, and the plotting code — so that the block now ends right after the DataLoader setup. The removed code:

```Python
class AlexNet(nn.Module):

    def __init__(self, num_classes):
        super(AlexNet, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(64, 192, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(256 * 4 * 4, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, num_classes),
        )

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), 256 * 4 * 4)
        x = self.classifier(x)
        return x


# select device
num_classes = 4
device = 'cuda' if torch.cuda.is_available() else 'cpu'
net = AlexNet(num_classes).to(device)

# optimizing
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)

# training
num_epochs = 20
train_loss_list, train_acc_list, val_loss_list, val_acc_list = [], [], [], []

### training
for epoch in range(num_epochs):
    train_loss, train_acc, val_loss, val_acc = 0, 0, 0, 0

    # ====== train_mode ======
    net.train()
    for i, (images, labels) in enumerate(train_loader):
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = net(images)
        loss = criterion(outputs, labels)
        train_loss += loss.item()
        train_acc += (outputs.max(1)[1] == labels).sum().item()
        loss.backward()
        optimizer.step()

    avg_train_loss = train_loss / len(train_loader.dataset)
    avg_train_acc = train_acc / len(train_loader.dataset)

    # ====== val_mode ======
    net.eval()
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = net(images)
            loss = criterion(outputs, labels)
            val_loss += loss.item()
            val_acc += (outputs.max(1)[1] == labels).sum().item()
    avg_val_loss = val_loss / len(test_loader.dataset)
    avg_val_acc = val_acc / len(test_loader.dataset)

    print('Epoch [{}/{}], Loss: {loss:.4f}, val_loss: {val_loss:.4f}, val_acc: {val_acc:.4f}'
          .format(epoch+1, num_epochs, i+1, loss=avg_train_loss, val_loss=avg_val_loss, val_acc=avg_val_acc))
    train_loss_list.append(avg_train_loss)
    train_acc_list.append(avg_train_acc)
    val_loss_list.append(avg_val_loss)
    val_acc_list.append(avg_val_acc)


# plot graph
plt.figure()
plt.plot(range(num_epochs), train_loss_list, color='blue', linestyle='-', label='train_loss')
plt.plot(range(num_epochs), val_loss_list, color='green', linestyle='--', label='val_loss')
plt.legend()
plt.xlabel('epoch')
plt.ylabel('loss')
plt.title('Training and validation loss')
plt.grid()
plt.show()

plt.figure()
plt.plot(range(num_epochs), train_acc_list, color='blue', linestyle='-', label='train_acc')
plt.plot(range(num_epochs), val_acc_list, color='green', linestyle='--', label='val_acc')
plt.legend()
plt.xlabel('epoch')
plt.ylabel('acc')
plt.title('Training and validation accuracy')
plt.grid()
plt.show()
```

The closing code fence and the surrounding blank lines are unchanged context.
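As an aside on the removed code: the classifier's input size `256 * 4 * 4` only matches the `features` stack for 32×32 inputs (three stride-2 max-pools: 32 → 16 → 8 → 4). A minimal shape check, assuming that input size (the 32×32 assumption is mine, not stated in the diff):

```Python
import torch
import torch.nn as nn

# Reproduce just the convolutional stack from the removed AlexNet definition
# to confirm the flattened feature size expected by the classifier.
features = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=3, padding=1), nn.ReLU(inplace=True),
    nn.MaxPool2d(kernel_size=2, stride=2),
    nn.Conv2d(64, 192, kernel_size=5, padding=2), nn.ReLU(inplace=True),
    nn.MaxPool2d(kernel_size=2, stride=2),
    nn.Conv2d(192, 384, kernel_size=3, padding=1), nn.ReLU(inplace=True),
    nn.Conv2d(384, 256, kernel_size=3, padding=1), nn.ReLU(inplace=True),
    nn.Conv2d(256, 256, kernel_size=3, padding=1), nn.ReLU(inplace=True),
    nn.MaxPool2d(kernel_size=2, stride=2),
)

x = torch.zeros(1, 3, 32, 32)  # assumed 32x32 RGB input (CIFAR-sized)
print(features(x).shape)       # torch.Size([1, 256, 4, 4]) -> 256 * 4 * 4 when flattened
```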
Added (new lines 141–683): a second labelled listing, dataset.py, inserted after the closing fence of the Alexnet.py block. Its Japanese docstrings and comments are translated to English below:

dataset.py

```Python
import pickle

from tqdm import tqdm

import numpy as np
import pandas as pd

import pathlib

# == Base ==
DATA_DIR = pathlib.Path('/home') / 'radiology_datas'

# == Dataset ==
ADNI1 = DATA_DIR / 'ADNI1'
ADNI2 = DATA_DIR / 'JHU-radiology' / '20170509'
ADNI2_2 = DATA_DIR / 'JHU-radiology' / 'MNI_skull_stripped' / 'output'
PPMI = DATA_DIR / 'JHU-radiology' / 'PPMI'
FourRTNI = DATA_DIR / 'JHU-radiology' / '4RTNI'

BLACKLIST_DIR = DATA_DIR / 'util' / 'lists'

DATA_CSV = {
    'ADNI': DATA_DIR / 'JHU-radiology' / 'ADNIMERGE.csv',
    'PPMI': DATA_DIR / 'JHU-radiology' / 'PPMI.csv',
    '4RTNI': FourRTNI / 'csv' / '4RTNI_DATA.csv',
}

DATA_DIRS_DICT = {
    'ADNI1': ADNI1,
    'ADNI2': ADNI2,
    'ADNI2-2': ADNI2_2,
    'PPMI': PPMI,
    '4RTNI': FourRTNI / 'SkullStripped',
}

DATA_PREFIX_DICT = {
    'fullsize': 'fullsize',
    'half': 'half_',
}

# == Label Encoder ==
CLASS_MAP = {
    'CN': 0,
    'AD': 1,
    'EMCI': 2,
    'LMCI': 3,
    'MCI': 4,
    'SMC': 5,
    'Control': 6,
    'FControl': 6,
    'PD': 7,
    'SWEDD': 8,
    'Prodromal': 9,
    'CBD': 10,
    'PSP': 11,
    'Oth': 12,
}


def read_voxel(path):
    '''
    Just takes a path and returns the voxel.

    Args
    ----------
    path : pathlib
        Path to the pkl file.

    Return
    ----------
    voxel : numpy.array
        Contents of the pkl file.
    '''
    with open(path, 'rb') as rf:
        voxel = pickle.load(rf)
    return np.array(voxel).astype('f')


def get_uid(path):
    '''
    Just takes a path and returns the uid.

    Args
    ----------
    path : pathlib
        Path to the pkl file.

    Return
    ----------
    uid : int
        uid
    '''
    uid = path.name
    for key, value in DATA_DIRS_DICT.items():
        if str(value) in str(path):

            if key == 'ADNI2':
                uid = path.name.split('_')[-2]
                uid = int(uid[1:])

            elif key == 'ADNI2-2':
                uid = path.name.split('_')[-4]
                uid = int(uid[1:])

            elif key == 'PPMI':
                uid = path.name.split('_')[-4]
                uid = int(uid)

            elif key == '4RTNI':
                uid = path.name.split('_')[-4]
                uid = int(uid)

    return uid


def collect_pids(dirs):
    '''
    Collects the patients found in the given directories.

    Args
    ----------
    path : pathlib
        Path to the pkl file.

    Return
    ----------
    pid : list of str
        pids
    '''
    patiants = []
    for dir_path in dirs:
        [patiants.append(f.name) for f in dir_path.iterdir()]
    return patiants


def get_blacklist():
    '''
    Returns the blacklists under brain/util/lists as a list of uids.

    Args
    ----------
    Return
    ----------
    uid : list of int
        uids
    '''
    key = '**/uids.txt'
    excluded_uid_paths = BLACKLIST_DIR.glob(key)
    excluded_uids = []
    for path in excluded_uid_paths:
        with open(path, 'r') as rf:
            [excluded_uids.append(int(uid.rstrip('\n'))) for uid in rf]
    return excluded_uids


def load_csv_data(pids):
    df = pd.read_csv(DATA_CSV['ADNI'])
    adni = df[['PTID', 'AGE', 'PTGENDER']]
    adni.columns = ['PID', 'AGE', 'SEX']

    df = pd.read_csv(DATA_CSV['PPMI'])
    ppmi = df[['Subject', 'Age', 'Sex']]
    ppmi.columns = ['PID', 'AGE', 'SEX']

    df = pd.read_csv(DATA_CSV['4RTNI'])
    fourrtni = df[['SUBID', 'AGE_AT_TP0', 'SEX']]
    fourrtni.columns = ['PID', 'AGE', 'SEX']

    df = adni.append(ppmi).append(fourrtni)
    df.iloc[:, 2] = df['SEX'].apply(lambda x: x[0] if x in ['Male', 'Female'] else x)
    df.iloc[:, 1] = df['AGE'].apply(lambda x: int(x))
    df.iloc[:, 0] = df['PID'].apply(lambda x: str(x))

    return df


def load_data(
    kinds=['ADNI2', 'ADNI2-2', 'PPMI', '4RTNI'],
    classes=['CN', 'AD', 'MCI', 'EMCI', 'LMCI', 'SMC', 'Control', 'PD', 'SWEDD', 'Prodromal', 'PSP', 'CBD', 'Oth', 'FControl'],
    size='half',
    csv=False,
    pids=[],
    uids=[],
    unique=False,
    blacklist=False,
    dryrun=False,
):
    '''
    Args
    ----------
    kind : list
        Specify ADNI2, ADNI2-2, PPMI as a list.
    classes : list
        Specify CN, AD, MCI, EMCI, LMCI, SMC,
        Control, PD, SWEDD, Prodromal,
        PSP, CBD, Oth
        as a list.
    size : str
        fullsize, half
    pids : list of str
        pids of the patients to retrieve, as a list.
    uids : list of str
        uids of the patients to retrieve, as a list.
    unique : bool
    blacklist : bool
    dryrun : bool
        If true, return only the metadata without loading the voxels.

    Return
    ----------
    dataset: list
        A list packed with information.
    '''
    dirs = []
    for key in kinds:
        for c in classes:
            dirname = DATA_DIRS_DICT[key].resolve() / c
            if dirname.exists():
                dirs.append(DATA_DIRS_DICT[key].resolve() / c)

    dataset = []
    key = '**/*' + DATA_PREFIX_DICT[size] + '*.pkl'
    if dryrun:
        print(f'[--DRYRUN--]')
        print(f'[SIZE] {size}')
        print(f'[KINDS] {kinds}')
        print(f'[CLASSES] {classes}')
        print(f'[PATIANT] {len(pids)} of patiants')
        print(f'[TARGET] {uids}')
        print(f'[UNIQUE] {unique}')
        print(f'[BLACKLIST] {blacklist}')

    for dir_path in dirs:
        for file_path in dir_path.glob(key):
            data = {}
            data['uid'] = get_uid(file_path)
            data['pid'] = file_path.parent.name
            data['label'] = dir_path.name
            data['nu_label'] = CLASS_MAP[dir_path.name]
            data['path'] = file_path
            dataset.append(data)

    if uids:
        dataset = [data for data in dataset if data['uid'] in uids]

    if unique:
        dataset_unique = []
        for pid in collect_pids(dirs):
            # For each pid, collect its entries, sort them by uid, and keep the latest one
            dataset_unique.append(
                sorted([data for data in dataset if data['pid'] == pid], key=lambda data: data['uid'])[-1])
        dataset = dataset_unique

    if pids:
        dataset = [data for data in dataset if data['pid'] in pids]

    if blacklist:
        exclude_uids = get_blacklist()
        dataset = [data for data in dataset if data['uid'] not in exclude_uids]

    if dryrun:
        return np.array(dataset)

    if csv:
        df = load_csv_data([data['pid'] for data in dataset])
        [data.update(
            AGE=df[df.PID == data['pid']].AGE.values[0],
            SEX=df[df.PID == data['pid']].SEX.values[0],
        ) if data['pid'] in df.PID.values else data.update(
            AGE=None,
            SEX=None,
        ) for data in dataset]

    [data.update(voxel=read_voxel(data['path'])) for data in tqdm(dataset, leave=False)]

    return np.array(dataset)
```
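For reference, a minimal usage sketch of the `load_data` function added in dataset.py; the argument values and the import path are illustrative assumptions, not taken from the question:

```Python
# Hypothetical usage of the dataset.py helpers shown above; values are illustrative.
from dataset import load_data  # assumes the file is importable as a module named `dataset`

# Load half-size ADNI2 volumes for the CN/AD classes, keeping one (latest) scan
# per patient, skipping blacklisted uids and attaching AGE/SEX from the CSVs.
dataset = load_data(
    kinds=['ADNI2'],
    classes=['CN', 'AD'],
    size='half',
    csv=True,
    unique=True,
    blacklist=True,
)

# Each entry is a dict with 'uid', 'pid', 'label', 'nu_label', 'path' and 'voxel'
# (plus 'AGE'/'SEX' when csv=True).
print(len(dataset), dataset[0]['label'], dataset[0]['voxel'].shape)
```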