質問をすることでしか得られない、回答やアドバイスがある。

15分調べてもわからないことは、質問しよう!

新規登録して質問してみよう
ただいま回答率
85.35%
PyTorch

PyTorchは、オープンソースのPython向けの機械学習ライブラリ。Facebookの人工知能研究グループが開発を主導しています。強力なGPUサポートを備えたテンソル計算、テープベースの自動微分による柔軟なニューラルネットワークの記述が可能です。

Python

Pythonは、コードの読みやすさが特徴的なプログラミング言語の1つです。 強い型付け、動的型付けに対応しており、後方互換性がないバージョン2系とバージョン3系が使用されています。 商用製品の開発にも無料で使用でき、OSだけでなく仮想環境にも対応。Unicodeによる文字列操作をサポートしているため、日本語処理も標準で可能です。

Q&A

解決済

1回答

470閲覧

【Pythorch】AlexNetでクラス分類したいが、データを指定し読み込めない。

Hiro051

総合スコア9

PyTorch

PyTorchは、オープンソースのPython向けの機械学習ライブラリ。Facebookの人工知能研究グループが開発を主導しています。強力なGPUサポートを備えたテンソル計算、テープベースの自動微分による柔軟なニューラルネットワークの記述が可能です。

Python

Pythonは、コードの読みやすさが特徴的なプログラミング言語の1つです。 強い型付け、動的型付けに対応しており、後方互換性がないバージョン2系とバージョン3系が使用されています。 商用製品の開発にも無料で使用でき、OSだけでなく仮想環境にも対応。Unicodeによる文字列操作をサポートしているため、日本語処理も標準で可能です。

0グッド

0クリップ

投稿2020/05/22 07:32

データセットを読み込みたい

Python初心者です。
"dataset.py"のload_data()でADNI2というフォルダを指定しデータを読み込みたいのですが、
エラーの意味がよく理解できません。

データを読み込んだのち、テストと訓練を8:2に分けようと考えています
"AlexNet.py"で実行しています。(データ読み込み部分のコードしか書いていません)
クラスは4つに分けたいのでnum_workers=4としています。

欠けている情報ございましたら追加いたします。
どうぞよろしくお願いいたします。

発生している問題・エラーメッセージ

$ python AlexNet.py Traceback (most recent call last): File "AlexNet.py", line 4, in <module> dataset = dataset.load_data(['ADNI2']) NameError: name 'dataset' is not defined

該当のソースコード

dataset.py

Python

1import pickle 2from tqdm import tqdm 3 4import numpy as np 5import pandas as pd 6 7import pathlib 8 9# == Base == 10DATA_DIR = pathlib.Path('/data') / 'radiology_datas' 11 12# == Dataset == 13ADNI1 = DATA_DIR / 'ADNI1' 14ADNI2 = DATA_DIR / 'JHU-radiology' / '20170509' 15ADNI2_2 = DATA_DIR / 'JHU-radiology' / 'MNI_skull_stripped' / 'output' 16PPMI = DATA_DIR / 'JHU-radiology' / 'PPMI' 17FourRTNI = DATA_DIR / 'JHU-radiology' / '4RTNI' 18 19BLACKLIST_DIR = DATA_DIR / 'util' / 'lists' 20 21DATA_CSV = { 22 'ADNI': DATA_DIR / 'JHU-radiology' / 'ADNIMERGE.csv', 23 'PPMI': DATA_DIR / 'JHU-radiology' / 'PPMI.csv', 24 '4RTNI': FourRTNI / 'csv' / '4RTNI_DATA.csv', 25} 26 27DATA_DIRS_DICT = { 28 'ADNI1': ADNI1, 29 'ADNI2': ADNI2, 30 'ADNI2-2': ADNI2_2, 31 'PPMI': PPMI, 32 '4RTNI': FourRTNI / 'SkullStripped', 33} 34 35DATA_PREFIX_DICT = { 36 'fullsize': 'fullsize', 37 'half': 'half_', 38} 39# == Label Encoder == 40CLASS_MAP = { 41 'CN': 0, 42 'AD': 1, 43 'EMCI': 2, 44 'LMCI': 3, 45 'MCI': 4, 46 'SMC': 5, 47 'Control': 6, 48 'FControl': 6, 49 'PD': 7, 50 'SWEDD': 8, 51 'Prodromal': 9, 52 'CBD': 10, 53 'PSP': 11, 54 'Oth': 12, 55 56} 57 58 59def read_voxel(path): 60 ''' 61 pathを受け取ってvoxelを返すだけ 62 Args 63 ---------- 64 path : pathlib 65 pklファイルへのパス 66 Return 67 ---------- 68 voxel : numpy.array 69 pklファイルの中身 70 ''' 71 with open(path, 'rb')as rf: 72 voxel = pickle.load(rf) 73 return np.array(voxel).astype('f') 74 75 76def get_uid(path): 77 ''' 78 pathを受け取ってuidを返すだけ 79 Args 80 ---------- 81 path : pathlib 82 pklファイルへのパス 83 Return 84 ---------- 85 uid : int 86 uid 87 ''' 88 uid = path.name 89 for key, value in DATA_DIRS_DICT.items(): 90 if str(value) in str(path): 91 92 if key == 'ADNI2': 93 uid = path.name.split('_')[-2] 94 uid = int(uid[1:]) 95 96 elif key == 'ADNI2-2': 97 uid = path.name.split('_')[-4] 98 uid = int(uid[1:]) 99 100 elif key == 'PPMI': 101 uid = path.name.split('_')[-4] 102 uid = int(uid) 103 104 elif key == '4RTNI': 105 uid = path.name.split('_')[-4] 106 uid = int(uid) 107 108 return uid 109 110 111def collect_pids(dirs): 112 ''' 113 ディレクトリ内に存在するpatiantを集める 114 Args 115 ---------- 116 path : pathlib 117 pklファイルへのパス 118 Return 119 ---------- 120 pid : list of str 121 pids 122 ''' 123 patiants = [] 124 for dir_path in dirs: 125 [patiants.append(f.name) for f in dir_path.iterdir()] 126 return patiants 127 128 129def get_blacklist(): 130 ''' 131 brain/util/listsの中にいるblacklistたちをuidのリストで返す 132 Args 133 ---------- 134 Return 135 ---------- 136 uid : list of int 137 uids 138 ''' 139 key = '**/uids.txt' 140 excluded_uid_paths = BLACKLIST_DIR.glob(key) 141 excluded_uids = [] 142 for path in excluded_uid_paths: 143 with open(path, 'r') as rf: 144 [excluded_uids.append(int(uid.rstrip('\n'))) for uid in rf] 145 return excluded_uids 146 147 148def load_csv_data(pids): 149 150 df = pd.read_csv(DATA_CSV['ADNI']) 151 adni = df[['PTID', 'AGE', 'PTGENDER']] 152 adni.columns = ['PID', 'AGE', 'SEX'] 153 154 df = pd.read_csv(DATA_CSV['PPMI']) 155 ppmi = df[['Subject', 'Age', 'Sex']] 156 ppmi.columns = ['PID', 'AGE', 'SEX'] 157 158 df = pd.read_csv(DATA_CSV['4RTNI']) 159 fourrtni = df[['SUBID', 'AGE_AT_TP0', 'SEX']] 160 fourrtni.columns = ['PID', 'AGE', 'SEX'] 161 162 df = adni.append(ppmi).append(fourrtni) 163 df.iloc[:, 2] = df['SEX'].apply(lambda x: x[0] if x in ['Male', 'Female'] else x) 164 df.iloc[:, 1] = df['AGE'].apply(lambda x: int(x)) 165 df.iloc[:, 0] = df['PID'].apply(lambda x: str(x)) 166 167 return df 168 169 170def load_data( 171 kinds=['ADNI2'], 172 classes=['CN', 'AD', 'LMCI', 'MCI'], 173 size='half', 174 csv=False, 175 pids=[], 176 uids=[], 177 unique=False, 178 blacklist=False, 179 dryrun=False, 180): 181 ''' 182 Args 183 ---------- 184 kind : list 185 ADNI2, ADNI2-2, PPMI をリストで指定 186 classes : list 187 CN, AD, MCI, EMCI, LMCI, SMC, 188 Control, PD, SWEDD, Prodromal, 189 PSP, CBD, Oth, 190 をリストで指定 191 size : str 192 fullsize, half 193 pids : list of str 194 取得したい患者のpidをリストで指定 195 uids : list of str 196 取得したい患者のuidをリストで指定 197 unique : bool 198 blacklist : bool 199 dryrun : bool 200 trueの場合にvoxelを読み込まないでその他の情報だけ返す 201 Return 202 ---------- 203 dataset: list 204 情報がいっぱい詰まったリストだよ 205 ''' 206 dirs = [] 207 for key in kinds: 208 for c in classes: 209 dirname = DATA_DIRS_DICT[key].resolve() / c 210 if dirname.exists(): 211 dirs.append(DATA_DIRS_DICT[key].resolve() / c) 212 213 dataset = [] 214 key = '**/*' + DATA_PREFIX_DICT[size] + '*.pkl' 215 if dryrun: 216 print(f'[--DRYRUN--]') 217 print(f'[SIZE] {size}') 218 print(f'[KINDS] {kinds}') 219 print(f'[CLASSES] {classes}') 220 print(f'[PATIANT] {len(pids)} of patiants') 221 print(f'[TARGET] {uids}') 222 print(f'[UNIQUE] {unique}') 223 print(f'[BLACKLIST] {blacklist}') 224 225 for dir_path in dirs: 226 for file_path in dir_path.glob(key): 227 data = {} 228 data['uid'] = get_uid(file_path) 229 data['pid'] = file_path.parent.name 230 data['label'] = dir_path.name 231 data['nu_label'] = CLASS_MAP[dir_path.name] 232 data['path'] = file_path 233 dataset.append(data) 234 235 if uids: 236 dataset = [data for data in dataset if data['uid'] in uids] 237 238 if unique: 239 dataset_unique = [] 240 for pid in collect_pids(dirs): 241 # pidごとにdataを取り出しそれらのuidをソートして最新のものを選択 242 dataset_unique.append( 243 sorted([data for data in dataset if data['pid'] == pid], key=lambda data: data['uid'])[-1]) 244 dataset = dataset_unique 245 246 if pids: 247 dataset = [data for data in dataset if data['pid'] in pids] 248 249 if blacklist: 250 exclude_uids = get_blacklist() 251 dataset = [data for data in dataset if data['uid'] not in exclude_uids] 252 253 if dryrun: 254 return np.array(dataset) 255 256 if csv: 257 df = load_csv_data([data['pid'] for data in dataset]) 258 [data.update( 259 AGE=df[df.PID == data['pid']].AGE.values[0], 260 SEX=df[df.PID == data['pid']].SEX.values[0], 261 ) if data['pid'] in df.PID.values else data.update( 262 AGE=None, 263 SEX=None, 264 ) for data in dataset] 265 266 [data.update(voxel=read_voxel(data['path'])) for data in tqdm(dataset, leave=False)] 267 268 return np.array(dataset)

AlexNet.py

Python

1import torch 2from torchvision import transforms 3 4dataset = dataset.load_data(['ADNI2']) 5transform = transforms.Compose([transforms.ToTensor()]) 6 7n_train = int(len(dataset) * 0.8) 8n_test = len(dataset) - n_train 9 10train_set, test_set = torch.utils.data.random_split( 11 dataset, [n_train, n_test] 12) 13 14train_loader = torch.utils.data.DataLoader(train_set, self.batch_size, shuffle=True, num_workers=4) 15test_loader = torch.utils.data.DataLoader(test_set, self.batch_size, shuffle=False, num_workers=4)

気になる質問をクリップする

クリップした質問は、後からいつでもMYページで確認できます。

またクリップした質問に回答があった際、通知やメールを受け取ることができます。

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

guest

回答1

0

ベストアンサー

AlexNet.pyの方で、datasetをimportして上げていないからではないでしょうか。

投稿2020/05/22 07:36

jeanbiego

総合スコア3966

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

Hiro051

2020/05/22 07:43

昨日今日勉強を始めたばかりなので恐縮なのですが、 具体的にどうコードを欠けばいいのでしょうか?
Hiro051

2020/05/22 07:47

失礼いたしました。 解決できました。 ありがとうございます!
guest

あなたの回答

tips

太字

斜体

打ち消し線

見出し

引用テキストの挿入

コードの挿入

リンクの挿入

リストの挿入

番号リストの挿入

表の挿入

水平線の挿入

プレビュー

15分調べてもわからないことは
teratailで質問しよう!

ただいまの回答率
85.35%

質問をまとめることで
思考を整理して素早く解決

テンプレート機能で
簡単に質問をまとめる

質問する

関連した質問