前提
参考書籍のpytorchのサンプルコードでDataloaderが生成したtensorデータの先頭のイメージを調べています。
pytorchの公式サイトから、参考になる個所を抜き出し下記のコードを作りました。
エラーの行で
TypeError: default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found <class 'PIL.Image.Image'>
となります。
調査用のこの数行が無い場合は、それなりに全epoch分実行されます(エラーとはならない)
#先頭イメージのデータを見るとした行は正しいイメージデータがプリントされます(8行目の狙った1行のみ)
なおFashionMNISTとありますが、自前のデータを、本来の場所に安直に上書きしたものです。(自前のdetaloader作成手引きもあるようですが今後の課題です)
エラー型の不一致のようですが、原因がわかりません。
別な質問ですが、datasetsやDataLoaderが生成するデータはpytorch用のtensorなる形式だそうですが、
numpyのようなshapeが使えませんが、それらに相当するメッソドなりの公式サイトがみつけらえませんでした。
以下の__len__などは断片的なサイトを参考にしましたので、datasetsクラス(?)のメソッド(?)の説明ページもわかりません。(どんなのがあるか?)
補足
エラーのあるnext(iter(train_dataloader))
の意味は、ここでは生成したデータの先頭のイメージを1個取り出すことと理解しています。
発生している問題・エラーメッセージ
TypeError: default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found <class 'PIL.Image.Image'>
f_mnist_train = datasets.FashionMNIST(
root=root, # データの保存先のディレクトリ
download=False, # ダウンロードを許可
train=True) # 訓練データを指定
#transform=transform) # トランスフォーマーオブジェクトを指定
print("train dataset",f_mnist_train.data[0,8]) #先頭イメージのデータを見る
len= f_mnist_train.len()
print("len",len)
train_dataloader = DataLoader(f_mnist_train, # 訓練データ
batch_size=batch_size, # ミニバッチのサイズ
shuffle=True) # シャッフルして抽出
train_features, train_labels = next(iter(train_dataloader)) #エラー
print(f"Feature batch shape: {train_features.size()}")
print(f"Labels batch shape: {train_labels.size()}")
img = f_mnist_train[0].squeeze()
label = train_labels[0]
plt.imshow(img, cmap="mono")
plt.show()
sys.exit()
以下略
python38 pytorch
試したこと
ここに問題に対して試したことを記載してください。
補足情報(FW/ツールのバージョンなど)
ここにより詳細な情報を記載してください。
すべてのトレースバック
Traceback (most recent call last):
File "C:\Users\qhtsi\anaconda3\envs\py38pytorch2\lib\site-packages\spyder_kernels\py3compat.py", line 356, in compat_exec
exec(code, globals, locals)
File "c:\book\shuwa\chap06\sec03\fashion-mnist_cnn_pytorch.py", line 62, in <module>
train_features, train_labels = next(iter(train_dataloader)) エラー行
File "C:\Users\xxx\anaconda3\envs\py38pytorch2\lib\site-packages\torch\utils\data\dataloader.py", line 681, in next
data = self._next_data()
File "C:\Users\xxx\anaconda3\envs\py38pytorch2\lib\site-packages\torch\utils\data\dataloader.py", line 721, in _next_data
data = self._dataset_fetcher.fetch(index) # may raise StopIteration
File "C:\Users\xxx\anaconda3\envs\py38pytorch2\lib\site-packages\torch\utils\data_utils\fetch.py", line 52, in fetch
return self.collate_fn(data)
File "C:\Users\xxx\anaconda3\envs\py38pytorch2\lib\site-packages\torch\utils\data_utils\collate.py", line 175, in default_collate
return [default_collate(samples) for samples in transposed] # Backwards compatibility.
File "C:\Users\xxx\anaconda3\envs\py38pytorch2\lib\site-packages\torch\utils\data_utils\collate.py", line 175, in <listcomp>
return [default_collate(samples) for samples in transposed] # Backwards compatibility.
File "C:\Users\xxx\anaconda3\envs\py38pytorch2\lib\site-packages\torch\utils\data_utils\collate.py", line 183, in default_collate
raise TypeError(default_collate_err_msg_format.format(elem_type))
TypeError: default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found <class 'PIL.Image.Image'>
回答2件
あなたの回答
tips
プレビュー