質問をすることでしか得られない、回答やアドバイスがある。

15分調べてもわからないことは、質問しよう!

新規登録して質問してみよう
ただいま回答率
85.35%
PyTorch

PyTorchは、オープンソースのPython向けの機械学習ライブラリ。Facebookの人工知能研究グループが開発を主導しています。強力なGPUサポートを備えたテンソル計算、テープベースの自動微分による柔軟なニューラルネットワークの記述が可能です。

Python

Pythonは、コードの読みやすさが特徴的なプログラミング言語の1つです。 強い型付け、動的型付けに対応しており、後方互換性がないバージョン2系とバージョン3系が使用されています。 商用製品の開発にも無料で使用でき、OSだけでなく仮想環境にも対応。Unicodeによる文字列操作をサポートしているため、日本語処理も標準で可能です。

Q&A

解決済

1回答

5376閲覧

【Pytorch】numpy.ndarray型データの使おうとするとTypeErrorになってしまう。

Hiro051

総合スコア9

PyTorch

PyTorchは、オープンソースのPython向けの機械学習ライブラリ。Facebookの人工知能研究グループが開発を主導しています。強力なGPUサポートを備えたテンソル計算、テープベースの自動微分による柔軟なニューラルネットワークの記述が可能です。

Python

Pythonは、コードの読みやすさが特徴的なプログラミング言語の1つです。 強い型付け、動的型付けに対応しており、後方互換性がないバージョン2系とバージョン3系が使用されています。 商用製品の開発にも無料で使用でき、OSだけでなく仮想環境にも対応。Unicodeによる文字列操作をサポートしているため、日本語処理も標準で可能です。

0グッド

1クリップ

投稿2020/05/25 03:17

編集2020/05/25 04:06

前提・実現したいこと

Python初心者です。
独自のnumpy.ndarray型のデータをAlexNetで学習させたいのですが、
以下のエラーが生じてしまい、修正箇所がわかりません。
ご指摘していただけますと幸いです。

AlexNet.pyで実行しています。

発生している問題・エラーメッセージ

$ python AlexNet.py Traceback (most recent call last): File "AlexNet.py", line 91, in <module> for i, (images, labels) in enumerate(train_loader): File "/home/selen/.pyenv/versions/3.7.3/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 582, in __next__ return self._process_next_batch(batch) File "/home/selen/.pyenv/versions/3.7.3/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 608, in _process_next_batch raise batch.exc_type(batch.exc_msg) TypeError: Traceback (most recent call last): File "/home/selen/.pyenv/versions/3.7.3/lib/python3.7/site-packages/torch/utils/data/_utils/worker.py", line 99, in _worker_loop samples = collate_fn([dataset[i] for i in batch_indices]) File "/home/selen/.pyenv/versions/3.7.3/lib/python3.7/site-packages/torch/utils/data/_utils/collate.py", line 63, in default_collate return {key: default_collate([d[key] for d in batch]) for key in batch[0]} File "/home/selen/.pyenv/versions/3.7.3/lib/python3.7/site-packages/torch/utils/data/_utils/collate.py", line 63, in <dictcomp> return {key: default_collate([d[key] for d in batch]) for key in batch[0]} File "/home/selen/.pyenv/versions/3.7.3/lib/python3.7/site-packages/torch/utils/data/_utils/collate.py", line 70, in default_collate raise TypeError((error_msg_fmt.format(type(batch[0])))) TypeError: batch must contain tensors, numbers, dicts or lists; found <class 'pathlib.PosixPath'>

該当のソースコード

Alexnet.py

Python

1import torch 2import torchvision 3import torch.nn as nn 4import torch.nn.init as init 5import torch.optim as optim 6import torch.nn.functional as F 7import torchvision.transforms as transforms 8import numpy as np 9from matplotlib import pyplot as plt 10 11import dataset 12from torchvision import transforms 13 14dataset = dataset.load_data(['ADNI2']) 15transform = transforms.Compose([transforms.ToTensor()]) 16 17n_train = int(len(dataset) * 0.8) 18n_test = int(len(dataset) - n_train) 19 20train_dataset, test_dataset = torch.utils.data.random_split( 21 dataset, [n_train, n_test] 22) 23 24# set data loader 25train_loader = torch.utils.data.DataLoader( 26 dataset=train_dataset, 27 batch_size=64, 28 shuffle=True, 29 num_workers=4) 30 31test_loader = torch.utils.data.DataLoader( 32 dataset=test_dataset, 33 batch_size=64, 34 shuffle=False, 35 num_workers=4)

dataset.py

Python

1import pickle 2from tqdm import tqdm 3 4import numpy as np 5import pandas as pd 6 7import pathlib 8 9# == Base == 10DATA_DIR = pathlib.Path('/home') / 'radiology_datas' 11 12# == Dataset == 13ADNI1 = DATA_DIR / 'ADNI1' 14ADNI2 = DATA_DIR / 'JHU-radiology' / '20170509' 15ADNI2_2 = DATA_DIR / 'JHU-radiology' / 'MNI_skull_stripped' / 'output' 16PPMI = DATA_DIR / 'JHU-radiology' / 'PPMI' 17FourRTNI = DATA_DIR / 'JHU-radiology' / '4RTNI' 18 19BLACKLIST_DIR = DATA_DIR / 'util' / 'lists' 20 21DATA_CSV = { 22 'ADNI': DATA_DIR / 'JHU-radiology' / 'ADNIMERGE.csv', 23 'PPMI': DATA_DIR / 'JHU-radiology' / 'PPMI.csv', 24 '4RTNI': FourRTNI / 'csv' / '4RTNI_DATA.csv', 25} 26 27DATA_DIRS_DICT = { 28 'ADNI1': ADNI1, 29 'ADNI2': ADNI2, 30 'ADNI2-2': ADNI2_2, 31 'PPMI': PPMI, 32 '4RTNI': FourRTNI / 'SkullStripped', 33} 34 35DATA_PREFIX_DICT = { 36 'fullsize': 'fullsize', 37 'half': 'half_', 38} 39# == Label Encoder == 40CLASS_MAP = { 41 'CN': 0, 42 'AD': 1, 43 'EMCI': 2, 44 'LMCI': 3, 45 'MCI': 4, 46 'SMC': 5, 47 'Control': 6, 48 'FControl': 6, 49 'PD': 7, 50 'SWEDD': 8, 51 'Prodromal': 9, 52 'CBD': 10, 53 'PSP': 11, 54 'Oth': 12, 55 56} 57 58 59def read_voxel(path): 60 ''' 61 pathを受け取ってvoxelを返すだけ 62 Args 63 ---------- 64 path : pathlib 65 pklファイルへのパス 66 Return 67 ---------- 68 voxel : numpy.array 69 pklファイルの中身 70 ''' 71 with open(path, 'rb')as rf: 72 voxel = pickle.load(rf) 73 return np.array(voxel).astype('f') 74 75 76def get_uid(path): 77 ''' 78 pathを受け取ってuidを返すだけ 79 Args 80 ---------- 81 path : pathlib 82 pklファイルへのパス 83 Return 84 ---------- 85 uid : int 86 uid 87 ''' 88 uid = path.name 89 for key, value in DATA_DIRS_DICT.items(): 90 if str(value) in str(path): 91 92 if key == 'ADNI2': 93 uid = path.name.split('_')[-2] 94 uid = int(uid[1:]) 95 96 elif key == 'ADNI2-2': 97 uid = path.name.split('_')[-4] 98 uid = int(uid[1:]) 99 100 elif key == 'PPMI': 101 uid = path.name.split('_')[-4] 102 uid = int(uid) 103 104 elif key == '4RTNI': 105 uid = path.name.split('_')[-4] 106 uid = int(uid) 107 108 return uid 109 110 111def collect_pids(dirs): 112 ''' 113 ディレクトリ内に存在するpatiantを集める 114 Args 115 ---------- 116 path : pathlib 117 pklファイルへのパス 118 Return 119 ---------- 120 pid : list of str 121 pids 122 ''' 123 patiants = [] 124 for dir_path in dirs: 125 [patiants.append(f.name) for f in dir_path.iterdir()] 126 return patiants 127 128 129def get_blacklist(): 130 ''' 131 brain/util/listsの中にいるblacklistたちをuidのリストで返す 132 Args 133 ---------- 134 Return 135 ---------- 136 uid : list of int 137 uids 138 ''' 139 key = '**/uids.txt' 140 excluded_uid_paths = BLACKLIST_DIR.glob(key) 141 excluded_uids = [] 142 for path in excluded_uid_paths: 143 with open(path, 'r') as rf: 144 [excluded_uids.append(int(uid.rstrip('\n'))) for uid in rf] 145 return excluded_uids 146 147 148def load_csv_data(pids): 149 150 df = pd.read_csv(DATA_CSV['ADNI']) 151 adni = df[['PTID', 'AGE', 'PTGENDER']] 152 adni.columns = ['PID', 'AGE', 'SEX'] 153 154 df = pd.read_csv(DATA_CSV['PPMI']) 155 ppmi = df[['Subject', 'Age', 'Sex']] 156 ppmi.columns = ['PID', 'AGE', 'SEX'] 157 158 df = pd.read_csv(DATA_CSV['4RTNI']) 159 fourrtni = df[['SUBID', 'AGE_AT_TP0', 'SEX']] 160 fourrtni.columns = ['PID', 'AGE', 'SEX'] 161 162 df = adni.append(ppmi).append(fourrtni) 163 df.iloc[:, 2] = df['SEX'].apply(lambda x: x[0] if x in ['Male', 'Female'] else x) 164 df.iloc[:, 1] = df['AGE'].apply(lambda x: int(x)) 165 df.iloc[:, 0] = df['PID'].apply(lambda x: str(x)) 166 167 return df 168 169 170def load_data( 171 kinds=['ADNI2', 'ADNI2-2', 'PPMI', '4RTNI'], 172 classes=['CN', 'AD', 'MCI', 'EMCI', 'LMCI', 'SMC', 'Control', 'PD', 'SWEDD', 'Prodromal', 'PSP', 'CBD', 'Oth', 'FControl'], 173 size='half', 174 csv=False, 175 pids=[], 176 uids=[], 177 unique=False, 178 blacklist=False, 179 dryrun=False, 180): 181 ''' 182 Args 183 ---------- 184 kind : list 185 ADNI2, ADNI2-2, PPMI をリストで指定 186 classes : list 187 CN, AD, MCI, EMCI, LMCI, SMC, 188 Control, PD, SWEDD, Prodromal, 189 PSP, CBD, Oth, 190 をリストで指定 191 size : str 192 fullsize, half 193 pids : list of str 194 取得したい患者のpidをリストで指定 195 uids : list of str 196 取得したい患者のuidをリストで指定 197 unique : bool 198 blacklist : bool 199 dryrun : bool 200 trueの場合にvoxelを読み込まないでその他の情報だけ返す 201 Return 202 ---------- 203 dataset: list 204 情報がいっぱい詰まったリストだよ 205 ''' 206 dirs = [] 207 for key in kinds: 208 for c in classes: 209 dirname = DATA_DIRS_DICT[key].resolve() / c 210 if dirname.exists(): 211 dirs.append(DATA_DIRS_DICT[key].resolve() / c) 212 213 dataset = [] 214 key = '**/*' + DATA_PREFIX_DICT[size] + '*.pkl' 215 if dryrun: 216 print(f'[--DRYRUN--]') 217 print(f'[SIZE] {size}') 218 print(f'[KINDS] {kinds}') 219 print(f'[CLASSES] {classes}') 220 print(f'[PATIANT] {len(pids)} of patiants') 221 print(f'[TARGET] {uids}') 222 print(f'[UNIQUE] {unique}') 223 print(f'[BLACKLIST] {blacklist}') 224 225 for dir_path in dirs: 226 for file_path in dir_path.glob(key): 227 data = {} 228 data['uid'] = get_uid(file_path) 229 data['pid'] = file_path.parent.name 230 data['label'] = dir_path.name 231 data['nu_label'] = CLASS_MAP[dir_path.name] 232 data['path'] = file_path 233 dataset.append(data) 234 235 if uids: 236 dataset = [data for data in dataset if data['uid'] in uids] 237 238 if unique: 239 dataset_unique = [] 240 for pid in collect_pids(dirs): 241 # pidごとにdataを取り出しそれらのuidをソートして最新のものを選択 242 dataset_unique.append( 243 sorted([data for data in dataset if data['pid'] == pid], key=lambda data: data['uid'])[-1]) 244 dataset = dataset_unique 245 246 if pids: 247 dataset = [data for data in dataset if data['pid'] in pids] 248 249 if blacklist: 250 exclude_uids = get_blacklist() 251 dataset = [data for data in dataset if data['uid'] not in exclude_uids] 252 253 if dryrun: 254 return np.array(dataset) 255 256 if csv: 257 df = load_csv_data([data['pid'] for data in dataset]) 258 [data.update( 259 AGE=df[df.PID == data['pid']].AGE.values[0], 260 SEX=df[df.PID == data['pid']].SEX.values[0], 261 ) if data['pid'] in df.PID.values else data.update( 262 AGE=None, 263 SEX=None, 264 ) for data in dataset] 265 266 [data.update(voxel=read_voxel(data['path'])) for data in tqdm(dataset, leave=False)] 267 268 return np.array(dataset) 269 270

気になる質問をクリップする

クリップした質問は、後からいつでもMYページで確認できます。

またクリップした質問に回答があった際、通知やメールを受け取ることができます。

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

guest

回答1

0

ベストアンサー

import dataset は自作モジュールでしょうか。

TypeError: batch must contain tensors, numbers, dicts or lists; found <class 'pathlib.PosixPath'>

推測ですが、自作した Dataset の __getitem__() 内で pathlib.Path を返すような処理が入っていないでしょうか。その場合は、str() で Pathlib を str に変換してから返すようにしてください。

追記

Alexnet.pyを修正しカスタムのデータセットで読み込ませなきゃいけないと

言われたのですが、これはどういうことでしょうか?

まず、Pytorch で自作のデータセットを使う場合、自作のデータセットを表すクラスを定義する必要があります。
torch.utils.data.Dataset を継承したクラスを作成し、サンプルが要求されたときに返す __getitem__(index) という関数を実装する必要があります。

以下参照
Pytorch - Transforms、Dataset、DataLoader について解説

質問のコード dataset.py を見ると、そのようになっておらず、そのまま ndarray を返しているように見えます。

サンプルコードなどを参考にして、自作のデータセットクラスを作るところから始めましょう

GitHub - utkuozbulak/pytorch-custom-dataset-examples: Some custom dataset examples for PyTorch

投稿2020/05/25 03:46

編集2020/05/25 05:54
tiitoi

総合スコア21956

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

Hiro051

2020/05/25 04:09

ご回答ありがとうございます。 import dataset は自作モジュールです。 該当箇所がわからず、dataset.pyのモジュールを記載せさせていただきました。 def get_uid(path):の箇所でしょうか?
tiitoi

2020/05/25 04:27

data['path'] = file_path を data['path'] = str(file_path) としてみてはどうでしょうか
Hiro051

2020/05/25 04:52

エラーが出てしまいました。 Alexnet.pyを修正しカスタムのデータセットで読み込ませなきゃいけないと 言われたのですが、これはどういうことでしょうか?
tiitoi

2020/05/25 05:55

コードを見ると、Pytorch で自作のデータセットを使う際のやり方に則ってないので、追記の内容のように自作のデータセットクラスを定義するところから始めるといいと思います
Hiro051

2020/05/25 08:53

ご丁寧にありがとうございました。 ご指摘の通りやってみます!
guest

あなたの回答

tips

太字

斜体

打ち消し線

見出し

引用テキストの挿入

コードの挿入

リンクの挿入

リストの挿入

番号リストの挿入

表の挿入

水平線の挿入

プレビュー

15分調べてもわからないことは
teratailで質問しよう!

ただいまの回答率
85.35%

質問をまとめることで
思考を整理して素早く解決

テンプレート機能で
簡単に質問をまとめる

質問する

関連した質問