Question edit history

1

Insufficient information

2020/05/25 04:06

Posted

Hiro051

Score 9

test CHANGED
File without changes
test CHANGED
@@ -62,7 +62,7 @@
 
 ### Relevant source code
 
-
+ Alexnet.py
 
 ```Python
 
@@ -136,232 +136,550 @@
 
  num_workers=4)
 
- class AlexNet(nn.Module):
-
-     def __init__(self, num_classes):
-         super(AlexNet, self).__init__()
-         self.features = nn.Sequential(
-             nn.Conv2d(3, 64, kernel_size=3, padding=1),
-             nn.ReLU(inplace=True),
-             nn.MaxPool2d(kernel_size=2, stride=2),
-             nn.Conv2d(64, 192, kernel_size=5, padding=2),
-             nn.ReLU(inplace=True),
-             nn.MaxPool2d(kernel_size=2, stride=2),
-             nn.Conv2d(192, 384, kernel_size=3, padding=1),
-             nn.ReLU(inplace=True),
-             nn.Conv2d(384, 256, kernel_size=3, padding=1),
-             nn.ReLU(inplace=True),
-             nn.Conv2d(256, 256, kernel_size=3, padding=1),
-             nn.ReLU(inplace=True),
-             nn.MaxPool2d(kernel_size=2, stride=2),
-         )
-         self.classifier = nn.Sequential(
-             nn.Dropout(),
-             nn.Linear(256 * 4 * 4, 4096),
-             nn.ReLU(inplace=True),
-             nn.Dropout(),
-             nn.Linear(4096, 4096),
-             nn.ReLU(inplace=True),
-             nn.Linear(4096, num_classes),
-         )
-
-     def forward(self, x):
-         x = self.features(x)
-         x = x.view(x.size(0), 256 * 4 * 4)
-         x = self.classifier(x)
-         return x
-
- # select device
- num_classes = 4
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
- net = AlexNet(num_classes).to(device)
-
- # optimizing
- criterion = nn.CrossEntropyLoss()
- optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
-
- # training
- num_epochs = 20
- train_loss_list, train_acc_list, val_loss_list, val_acc_list = [], [], [], []
-
- ### training
- for epoch in range(num_epochs):
-     train_loss, train_acc, val_loss, val_acc = 0, 0, 0, 0
-
-     # ====== train_mode ======
-     net.train()
-     for i, (images, labels) in enumerate(train_loader):
-         images, labels = images.to(device), labels.to(device)
-         optimizer.zero_grad()
-         outputs = net(images)
-         loss = criterion(outputs, labels)
-         train_loss += loss.item()
-         train_acc += (outputs.max(1)[1] == labels).sum().item()
-         loss.backward()
-         optimizer.step()
-
-     avg_train_loss = train_loss / len(train_loader.dataset)
-     avg_train_acc = train_acc / len(train_loader.dataset)
-
-     # ====== val_mode ======
-     net.eval()
-     with torch.no_grad():
-         for images, labels in test_loader:
-             images = images.to(device)
-             labels = labels.to(device)
-             outputs = net(images)
-             loss = criterion(outputs, labels)
-             val_loss += loss.item()
-             val_acc += (outputs.max(1)[1] == labels).sum().item()
-     avg_val_loss = val_loss / len(test_loader.dataset)
-     avg_val_acc = val_acc / len(test_loader.dataset)
-
-     print ('Epoch [{}/{}], Loss: {loss:.4f}, val_loss: {val_loss:.4f}, val_acc: {val_acc:.4f}'
-            .format(epoch+1, num_epochs, i+1, loss=avg_train_loss, val_loss=avg_val_loss, val_acc=avg_val_acc))
-     train_loss_list.append(avg_train_loss)
-     train_acc_list.append(avg_train_acc)
-     val_loss_list.append(avg_val_loss)
-     val_acc_list.append(avg_val_acc)
-
- # plot graph
- plt.figure()
- plt.plot(range(num_epochs), train_loss_list, color='blue', linestyle='-', label='train_loss')
- plt.plot(range(num_epochs), val_loss_list, color='green', linestyle='--', label='val_loss')
- plt.legend()
- plt.xlabel('epoch')
- plt.ylabel('loss')
- plt.title('Training and validation loss')
- plt.grid()
- plt.show()
-
- plt.figure()
- plt.plot(range(num_epochs), train_acc_list, color='blue', linestyle='-', label='train_acc')
- plt.plot(range(num_epochs), val_acc_list, color='green', linestyle='--', label='val_acc')
- plt.legend()
- plt.xlabel('epoch')
- plt.ylabel('acc')
- plt.title('Training and validation accuracy')
- plt.grid()
- plt.show()
 ```
 
- ```type
- $ print(type(dataset))
- <class 'numpy.ndarray'>
+ dataset.py
+
+ ```Python
+ import pickle
+ from tqdm import tqdm
+
+ import numpy as np
+ import pandas as pd
+
+ import pathlib
+
+ # == Base ==
+ DATA_DIR = pathlib.Path('/home') / 'radiology_datas'
+
+ # == Dataset ==
+ ADNI1 = DATA_DIR / 'ADNI1'
+ ADNI2 = DATA_DIR / 'JHU-radiology' / '20170509'
+ ADNI2_2 = DATA_DIR / 'JHU-radiology' / 'MNI_skull_stripped' / 'output'
+ PPMI = DATA_DIR / 'JHU-radiology' / 'PPMI'
+ FourRTNI = DATA_DIR / 'JHU-radiology' / '4RTNI'
+
+ BLACKLIST_DIR = DATA_DIR / 'util' / 'lists'
+
+ DATA_CSV = {
+     'ADNI': DATA_DIR / 'JHU-radiology' / 'ADNIMERGE.csv',
+     'PPMI': DATA_DIR / 'JHU-radiology' / 'PPMI.csv',
+     '4RTNI': FourRTNI / 'csv' / '4RTNI_DATA.csv',
+ }
+
+ DATA_DIRS_DICT = {
+     'ADNI1': ADNI1,
+     'ADNI2': ADNI2,
+     'ADNI2-2': ADNI2_2,
+     'PPMI': PPMI,
+     '4RTNI': FourRTNI / 'SkullStripped',
+ }
+
+ DATA_PREFIX_DICT = {
+     'fullsize': 'fullsize',
+     'half': 'half_',
+ }
+
+ # == Label Encoder ==
+ CLASS_MAP = {
+     'CN': 0,
+     'AD': 1,
+     'EMCI': 2,
+     'LMCI': 3,
+     'MCI': 4,
+     'SMC': 5,
+     'Control': 6,
+     'FControl': 6,
+     'PD': 7,
+     'SWEDD': 8,
+     'Prodromal': 9,
+     'CBD': 10,
+     'PSP': 11,
+     'Oth': 12,
+ }
+
+
+ def read_voxel(path):
+     '''
+     Simply takes a path and returns the voxel.
+
+     Args
+     ----------
+     path : pathlib
+         Path to the pkl file.
+
+     Return
+     ----------
+     voxel : numpy.array
+         Contents of the pkl file.
+     '''
+     with open(path, 'rb') as rf:
+         voxel = pickle.load(rf)
+     return np.array(voxel).astype('f')
+
+
+ def get_uid(path):
+     '''
+     Simply takes a path and returns the uid.
+
+     Args
+     ----------
+     path : pathlib
+         Path to the pkl file.
+
+     Return
+     ----------
+     uid : int
+         uid
+     '''
+     uid = path.name
+     for key, value in DATA_DIRS_DICT.items():
+         if str(value) in str(path):
+
+             if key == 'ADNI2':
+                 uid = path.name.split('_')[-2]
+                 uid = int(uid[1:])
+
+             elif key == 'ADNI2-2':
+                 uid = path.name.split('_')[-4]
+                 uid = int(uid[1:])
+
+             elif key == 'PPMI':
+                 uid = path.name.split('_')[-4]
+                 uid = int(uid)
+
+             elif key == '4RTNI':
+                 uid = path.name.split('_')[-4]
+                 uid = int(uid)
+
+     return uid
+
+
+ def collect_pids(dirs):
+     '''
+     Collects the patients present in the given directories.
+
+     Args
+     ----------
+     dirs : list of pathlib
+         Directories to scan.
+
+     Return
+     ----------
+     pid : list of str
+         pids
+     '''
+     patiants = []
+     for dir_path in dirs:
+         [patiants.append(f.name) for f in dir_path.iterdir()]
+     return patiants
+
+
+ def get_blacklist():
+     '''
+     Returns the blacklists kept in brain/util/lists as a list of uids.
+
+     Args
+     ----------
+     Return
+     ----------
+     uid : list of int
+         uids
+     '''
+     key = '**/uids.txt'
+     excluded_uid_paths = BLACKLIST_DIR.glob(key)
+     excluded_uids = []
+     for path in excluded_uid_paths:
+         with open(path, 'r') as rf:
+             [excluded_uids.append(int(uid.rstrip('\n'))) for uid in rf]
+     return excluded_uids
+
+
+ def load_csv_data(pids):
+
+     df = pd.read_csv(DATA_CSV['ADNI'])
+     adni = df[['PTID', 'AGE', 'PTGENDER']]
+     adni.columns = ['PID', 'AGE', 'SEX']
+
+     df = pd.read_csv(DATA_CSV['PPMI'])
+     ppmi = df[['Subject', 'Age', 'Sex']]
+     ppmi.columns = ['PID', 'AGE', 'SEX']
+
+     df = pd.read_csv(DATA_CSV['4RTNI'])
+     fourrtni = df[['SUBID', 'AGE_AT_TP0', 'SEX']]
+     fourrtni.columns = ['PID', 'AGE', 'SEX']
+
+     df = adni.append(ppmi).append(fourrtni)
+     df.iloc[:, 2] = df['SEX'].apply(lambda x: x[0] if x in ['Male', 'Female'] else x)
+     df.iloc[:, 1] = df['AGE'].apply(lambda x: int(x))
+     df.iloc[:, 0] = df['PID'].apply(lambda x: str(x))
+
+     return df
+
+
+ def load_data(
+     kinds=['ADNI2', 'ADNI2-2', 'PPMI', '4RTNI'],
+     classes=['CN', 'AD', 'MCI', 'EMCI', 'LMCI', 'SMC', 'Control', 'PD', 'SWEDD', 'Prodromal', 'PSP', 'CBD', 'Oth', 'FControl'],
+     size='half',
+     csv=False,
+     pids=[],
+     uids=[],
+     unique=False,
+     blacklist=False,
+     dryrun=False,
+ ):
+     '''
+     Args
+     ----------
+     kind : list
+         Any of ADNI2, ADNI2-2, PPMI, given as a list.
+     classes : list
+         Any of CN, AD, MCI, EMCI, LMCI, SMC,
+         Control, PD, SWEDD, Prodromal,
+         PSP, CBD, Oth,
+         given as a list.
+     size : str
+         fullsize or half
+     pids : list of str
+         pids of the patients to load, given as a list.
+     uids : list of str
+         uids of the patients to load, given as a list.
+     unique : bool
+     blacklist : bool
+     dryrun : bool
+         If true, skip loading voxels and return only the other information.
+
+     Return
+     ----------
+     dataset: list
+         A list packed with all the information.
+     '''
+     dirs = []
+     for key in kinds:
+         for c in classes:
+             dirname = DATA_DIRS_DICT[key].resolve() / c
+             if dirname.exists():
+                 dirs.append(DATA_DIRS_DICT[key].resolve() / c)
+
+     dataset = []
+     key = '**/*' + DATA_PREFIX_DICT[size] + '*.pkl'
+     if dryrun:
+         print(f'[--DRYRUN--]')
+         print(f'[SIZE] {size}')
+         print(f'[KINDS] {kinds}')
+         print(f'[CLASSES] {classes}')
+         print(f'[PATIANT] {len(pids)} of patiants')
+         print(f'[TARGET] {uids}')
+         print(f'[UNIQUE] {unique}')
+         print(f'[BLACKLIST] {blacklist}')
+
+     for dir_path in dirs:
+         for file_path in dir_path.glob(key):
+             data = {}
+             data['uid'] = get_uid(file_path)
+             data['pid'] = file_path.parent.name
+             data['label'] = dir_path.name
+             data['nu_label'] = CLASS_MAP[dir_path.name]
+             data['path'] = file_path
+             dataset.append(data)
+
+     if uids:
+         dataset = [data for data in dataset if data['uid'] in uids]
+
+     if unique:
+         dataset_unique = []
+         for pid in collect_pids(dirs):
+             # For each pid, take its records, sort by uid, and keep the latest one
+             dataset_unique.append(
+                 sorted([data for data in dataset if data['pid'] == pid], key=lambda data: data['uid'])[-1])
+         dataset = dataset_unique
+
+     if pids:
+         dataset = [data for data in dataset if data['pid'] in pids]
+
+     if blacklist:
+         exclude_uids = get_blacklist()
+         dataset = [data for data in dataset if data['uid'] not in exclude_uids]
+
+     if dryrun:
+         return np.array(dataset)
+
+     if csv:
+         df = load_csv_data([data['pid'] for data in dataset])
+         [data.update(
+             AGE=df[df.PID == data['pid']].AGE.values[0],
+             SEX=df[df.PID == data['pid']].SEX.values[0],
+         ) if data['pid'] in df.PID.values else data.update(
+             AGE=None,
+             SEX=None,
+         ) for data in dataset]
+
+     [data.update(voxel=read_voxel(data['path'])) for data in tqdm(dataset, leave=False)]
+
+     return np.array(dataset)
 
 ```
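
For reference, here is a minimal sketch (not part of the original post) of how the array of dicts returned by `load_data()` could be wrapped for PyTorch's `DataLoader` and fed to a network like the AlexNet above. The class name `VoxelDataset` and the `load_data` arguments are illustrative assumptions, and the sketch assumes each `data['voxel']` already has a shape the network accepts.

```Python
# Hedged sketch: wrap the records returned by dataset.load_data() in a
# torch.utils.data.Dataset so a DataLoader can batch them.
# Assumes dataset.py is importable and that each record carries 'voxel' and
# 'nu_label' keys, as in the code above; input shapes are not checked here.
import torch
from torch.utils.data import Dataset, DataLoader

from dataset import load_data


class VoxelDataset(Dataset):
    def __init__(self, records):
        # records: iterable of dicts produced by load_data()
        self.records = list(records)

    def __len__(self):
        return len(self.records)

    def __getitem__(self, idx):
        rec = self.records[idx]
        x = torch.from_numpy(rec['voxel']).float()            # voxel -> float tensor
        y = torch.tensor(rec['nu_label'], dtype=torch.long)   # encoded class label
        return x, y


# Example usage (arguments are hypothetical):
records = load_data(kinds=['ADNI2'], classes=['CN', 'AD'], size='half')
train_loader = DataLoader(VoxelDataset(records), batch_size=16,
                          shuffle=True, num_workers=4)
```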