
Question edit history

1

Insufficient information

2020/05/25 04:06

Posted

Hiro051

Score 9

title CHANGED
File without changes
body CHANGED
@@ -30,7 +30,7 @@
  ```
 
  ### Relevant source code
-
+ Alexnet.py
  ```Python
  import torch
  import torchvision
@@ -67,118 +67,277 @@
  batch_size=64,
  shuffle=False,
  num_workers=4)
+ ```
+ dataset.py
+ ```Python
+ import pickle
+ from tqdm import tqdm
 
- class AlexNet(nn.Module):
+ import numpy as np
+ import pandas as pd
 
-     def __init__(self, num_classes):
-         super(AlexNet, self).__init__()
-         self.features = nn.Sequential(
-             nn.Conv2d(3, 64, kernel_size=3, padding=1),
-             nn.ReLU(inplace=True),
-             nn.MaxPool2d(kernel_size=2, stride=2),
-             nn.Conv2d(64, 192, kernel_size=5, padding=2),
-             nn.ReLU(inplace=True),
-             nn.MaxPool2d(kernel_size=2, stride=2),
-             nn.Conv2d(192, 384, kernel_size=3, padding=1),
-             nn.ReLU(inplace=True),
-             nn.Conv2d(384, 256, kernel_size=3, padding=1),
-             nn.ReLU(inplace=True),
-             nn.Conv2d(256, 256, kernel_size=3, padding=1),
-             nn.ReLU(inplace=True),
-             nn.MaxPool2d(kernel_size=2, stride=2),
-         )
-         self.classifier = nn.Sequential(
-             nn.Dropout(),
+ import pathlib
-             nn.Linear(256 * 4 * 4, 4096),
-             nn.ReLU(inplace=True),
-             nn.Dropout(),
-             nn.Linear(4096, 4096),
-             nn.ReLU(inplace=True),
-             nn.Linear(4096, num_classes),
-         )
 
-     def forward(self, x):
+ # == Base ==
-         x = self.features(x)
-         x = x.view(x.size(0), 256 * 4 * 4)
+ DATA_DIR = pathlib.Path('/home') / 'radiology_datas'
-         x = self.classifier(x)
-         return x
 
- # select device
+ # == Dataset ==
- num_classes = 4
+ ADNI1 = DATA_DIR / 'ADNI1'
+ ADNI2 = DATA_DIR / 'JHU-radiology' / '20170509'
+ ADNI2_2 = DATA_DIR / 'JHU-radiology' / 'MNI_skull_stripped' / 'output'
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
+ PPMI = DATA_DIR / 'JHU-radiology' / 'PPMI'
- net = AlexNet(num_classes).to(device)
+ FourRTNI = DATA_DIR / 'JHU-radiology' / '4RTNI'
 
- # optimizing
- criterion = nn.CrossEntropyLoss()
- optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
+ BLACKLIST_DIR = DATA_DIR / 'util' / 'lists'
 
- # training
- num_epochs = 20
+ DATA_CSV = {
+     'ADNI': DATA_DIR / 'JHU-radiology' / 'ADNIMERGE.csv',
+     'PPMI': DATA_DIR / 'JHU-radiology' / 'PPMI.csv',
- train_loss_list, train_acc_list, val_loss_list, val_acc_list = [], [], [], []
+     '4RTNI': FourRTNI / 'csv' / '4RTNI_DATA.csv',
+ }
 
- ### training
+ DATA_DIRS_DICT = {
+     'ADNI1': ADNI1,
+     'ADNI2': ADNI2,
- for epoch in range(num_epochs):
+     'ADNI2-2': ADNI2_2,
+     'PPMI': PPMI,
-     train_loss, train_acc, val_loss, val_acc = 0, 0, 0, 0
+     '4RTNI': FourRTNI / 'SkullStripped',
+ }
 
+ DATA_PREFIX_DICT = {
+     'fullsize': 'fullsize',
+     'half': 'half_',
+ }
-     # ====== train_mode ======
+ # == Label Encoder ==
-     net.train()
+ CLASS_MAP = {
-     for i, (images, labels) in enumerate(train_loader):
-         images, labels = images.to(device), labels.to(device)
+     'CN': 0,
+     'AD': 1,
-         optimizer.zero_grad()
+     'EMCI': 2,
-         outputs = net(images)
+     'LMCI': 3,
-         loss = criterion(outputs, labels)
+     'MCI': 4,
+     'SMC': 5,
-         train_loss += loss.item()
+     'Control': 6,
-         train_acc += (outputs.max(1)[1] == labels).sum().item()
-         loss.backward()
+     'FControl': 6,
-         optimizer.step()
+     'PD': 7,
+     'SWEDD': 8,
+     'Prodromal': 9,
+     'CBD': 10,
+     'PSP': 11,
+     'Oth': 12,
 
-     avg_train_loss = train_loss / len(train_loader.dataset)
-     avg_train_acc = train_acc / len(train_loader.dataset)
+ }
 
-     # ====== val_mode ======
-     net.eval()
-     with torch.no_grad():
-         for images, labels in test_loader:
-             images = images.to(device)
-             labels = labels.to(device)
-             outputs = net(images)
-             loss = criterion(outputs, labels)
-             val_loss += loss.item()
-             val_acc += (outputs.max(1)[1] == labels).sum().item()
-     avg_val_loss = val_loss / len(test_loader.dataset)
-     avg_val_acc = val_acc / len(test_loader.dataset)
 
+ def read_voxel(path):
-     print ('Epoch [{}/{}], Loss: {loss:.4f}, val_loss: {val_loss:.4f}, val_acc: {val_acc:.4f}'
+     '''
+     Just takes a path and returns the voxel
+     Args
+     ----------
+     path : pathlib
+         path to the pkl file
+     Return
+     ----------
-           .format(epoch+1, num_epochs, i+1, loss=avg_train_loss, val_loss=avg_val_loss, val_acc=avg_val_acc))
+     voxel : numpy.array
+         contents of the pkl file
+     '''
-     train_loss_list.append(avg_train_loss)
+     with open(path, 'rb') as rf:
+         voxel = pickle.load(rf)
-     train_acc_list.append(avg_train_acc)
+     return np.array(voxel).astype('f')
-     val_loss_list.append(avg_val_loss)
-     val_acc_list.append(avg_val_acc)
 
 
- # plot graph
- plt.figure()
- plt.plot(range(num_epochs), train_loss_list, color='blue', linestyle='-', label='train_loss')
- plt.plot(range(num_epochs), val_loss_list, color='green', linestyle='--', label='val_loss')
- plt.legend()
+ def get_uid(path):
- plt.xlabel('epoch')
+     '''
+     Just takes a path and returns the uid
+     Args
+     ----------
+     path : pathlib
+         path to the pkl file
+     Return
+     ----------
+     uid : int
+         uid
- plt.ylabel('loss')
+     '''
+     uid = path.name
- plt.title('Training and validation loss')
+     for key, value in DATA_DIRS_DICT.items():
- plt.grid()
+         if str(value) in str(path):
- plt.show()
 
- plt.figure()
- plt.plot(range(num_epochs), train_acc_list, color='blue', linestyle='-', label='train_acc')
- plt.plot(range(num_epochs), val_acc_list, color='green', linestyle='--', label='val_acc')
- plt.legend()
- plt.xlabel('epoch')
+             if key == 'ADNI2':
- plt.ylabel('acc')
+                 uid = path.name.split('_')[-2]
- plt.title('Training and validation accuracy')
- plt.grid()
+                 uid = int(uid[1:])
- plt.show()
 
+             elif key == 'ADNI2-2':
+                 uid = path.name.split('_')[-4]
+                 uid = int(uid[1:])
+
+             elif key == 'PPMI':
+                 uid = path.name.split('_')[-4]
+                 uid = int(uid)
+
+             elif key == '4RTNI':
+                 uid = path.name.split('_')[-4]
+                 uid = int(uid)
+
+     return uid
+
+
+ def collect_pids(dirs):
+     '''
+     Collect the patients present in the given directories
- ```
+     Args
+     ----------
+     path : pathlib
+         path to the pkl file
- ```type
+     Return
+     ----------
+     pid : list of str
+         pids
+     '''
+     patiants = []
+     for dir_path in dirs:
+         [patiants.append(f.name) for f in dir_path.iterdir()]
+     return patiants
+
+
+ def get_blacklist():
+     '''
+     Return the blacklists under brain/util/lists as a list of uids
+     Args
+     ----------
+     Return
+     ----------
+     uid : list of int
+         uids
+     '''
+     key = '**/uids.txt'
+     excluded_uid_paths = BLACKLIST_DIR.glob(key)
+     excluded_uids = []
+     for path in excluded_uid_paths:
+         with open(path, 'r') as rf:
+             [excluded_uids.append(int(uid.rstrip('\n'))) for uid in rf]
+     return excluded_uids
+
+
+ def load_csv_data(pids):
+
+     df = pd.read_csv(DATA_CSV['ADNI'])
+     adni = df[['PTID', 'AGE', 'PTGENDER']]
+     adni.columns = ['PID', 'AGE', 'SEX']
+
+     df = pd.read_csv(DATA_CSV['PPMI'])
+     ppmi = df[['Subject', 'Age', 'Sex']]
+     ppmi.columns = ['PID', 'AGE', 'SEX']
+
+     df = pd.read_csv(DATA_CSV['4RTNI'])
+     fourrtni = df[['SUBID', 'AGE_AT_TP0', 'SEX']]
+     fourrtni.columns = ['PID', 'AGE', 'SEX']
+
+     df = adni.append(ppmi).append(fourrtni)
+     df.iloc[:, 2] = df['SEX'].apply(lambda x: x[0] if x in ['Male', 'Female'] else x)
+     df.iloc[:, 1] = df['AGE'].apply(lambda x: int(x))
+     df.iloc[:, 0] = df['PID'].apply(lambda x: str(x))
+
+     return df
+
+
+ def load_data(
+         kinds=['ADNI2', 'ADNI2-2', 'PPMI', '4RTNI'],
+         classes=['CN', 'AD', 'MCI', 'EMCI', 'LMCI', 'SMC', 'Control', 'PD', 'SWEDD', 'Prodromal', 'PSP', 'CBD', 'Oth', 'FControl'],
+         size='half',
+         csv=False,
+         pids=[],
+         uids=[],
+         unique=False,
+         blacklist=False,
+         dryrun=False,
+ ):
+     '''
+     Args
+     ----------
+     kind : list
+         specify ADNI2, ADNI2-2, PPMI as a list
+     classes : list
+         specify CN, AD, MCI, EMCI, LMCI, SMC,
+         Control, PD, SWEDD, Prodromal,
+         PSP, CBD, Oth
+         as a list
+     size : str
+         fullsize, half
+     pids : list of str
+         pids of the patients to load, as a list
+     uids : list of str
+         uids of the patients to load, as a list
+     unique : bool
+     blacklist : bool
+     dryrun : bool
+         if True, skip loading the voxels and return only the other information
+     Return
+     ----------
+     dataset: list
+         a list packed with all the information
+     '''
+     dirs = []
+     for key in kinds:
+         for c in classes:
+             dirname = DATA_DIRS_DICT[key].resolve() / c
+             if dirname.exists():
+                 dirs.append(DATA_DIRS_DICT[key].resolve() / c)
+
+     dataset = []
+     key = '**/*' + DATA_PREFIX_DICT[size] + '*.pkl'
+     if dryrun:
+         print(f'[--DRYRUN--]')
+         print(f'[SIZE] {size}')
+         print(f'[KINDS] {kinds}')
+         print(f'[CLASSES] {classes}')
+         print(f'[PATIANT] {len(pids)} of patiants')
+         print(f'[TARGET] {uids}')
+         print(f'[UNIQUE] {unique}')
+         print(f'[BLACKLIST] {blacklist}')
+
+     for dir_path in dirs:
+         for file_path in dir_path.glob(key):
+             data = {}
+             data['uid'] = get_uid(file_path)
+             data['pid'] = file_path.parent.name
+             data['label'] = dir_path.name
+             data['nu_label'] = CLASS_MAP[dir_path.name]
+             data['path'] = file_path
+             dataset.append(data)
+
+     if uids:
+         dataset = [data for data in dataset if data['uid'] in uids]
+
+     if unique:
+         dataset_unique = []
+         for pid in collect_pids(dirs):
+             # for each pid, sort its entries by uid and keep the most recent one
+             dataset_unique.append(
+                 sorted([data for data in dataset if data['pid'] == pid], key=lambda data: data['uid'])[-1])
+         dataset = dataset_unique
+
+     if pids:
+         dataset = [data for data in dataset if data['pid'] in pids]
+
+     if blacklist:
+         exclude_uids = get_blacklist()
+         dataset = [data for data in dataset if data['uid'] not in exclude_uids]
+
+     if dryrun:
- $ print(type(dataset))
+         return np.array(dataset)
+
+     if csv:
+         df = load_csv_data([data['pid'] for data in dataset])
+         [data.update(
+             AGE=df[df.PID == data['pid']].AGE.values[0],
+             SEX=df[df.PID == data['pid']].SEX.values[0],
+         ) if data['pid'] in df.PID.values else data.update(
+             AGE=None,
+             SEX=None,
+         ) for data in dataset]
+
+     [data.update(voxel=read_voxel(data['path'])) for data in tqdm(dataset, leave=False)]
+
- <class 'numpy.ndarray'>
+     return np.array(dataset)
+
+
  ```
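
For reference, a minimal sketch of how the records returned by `load_data()` above could be fed to the `DataLoader`-based training loop in Alexnet.py. Only the `voxel` and `nu_label` keys come from dataset.py; the `BrainDataset` class, the tensor conversion, and any reshaping needed to match the network's expected input are illustrative assumptions, not part of the original post.

```Python
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader


class BrainDataset(Dataset):
    """Hypothetical wrapper around the array of dicts returned by load_data()."""

    def __init__(self, records, transform=None):
        self.records = list(records)      # each record is a dict with 'voxel' and 'nu_label'
        self.transform = transform

    def __len__(self):
        return len(self.records)

    def __getitem__(self, idx):
        rec = self.records[idx]
        x = torch.from_numpy(np.asarray(rec['voxel'], dtype=np.float32))
        if self.transform is not None:
            x = self.transform(x)         # e.g. slice / reshape to fit the model input
        y = int(rec['nu_label'])          # integer class id from CLASS_MAP
        return x, y


# Hypothetical usage, assuming dataset.py is importable and the data paths exist:
# records = load_data(kinds=['ADNI2'], classes=['CN', 'AD'], size='half')
# train_loader = DataLoader(BrainDataset(records), batch_size=64, shuffle=True, num_workers=4)
```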