質問をすることでしか得られない、回答やアドバイスがある。

15分調べてもわからないことは、質問しよう!

新規登録して質問してみよう
ただいま回答率
86.12%
Python 3.x

Python 3はPythonプログラミング言語の最新バージョンであり、2008年12月3日にリリースされました。

解決済

pytorch.datasets/dataloader生成のデータを調べるときにエラーの再提示

sigefuji
sigefuji

総合スコア114

Python 3.x

Python 3はPythonプログラミング言語の最新バージョンであり、2008年12月3日にリリースされました。

1回答

0リアクション

0クリップ

216閲覧

投稿2022/08/31 06:32

質問の内容

参考書籍のコード(pytorch)で自前のデータの動作を調べています。
先の質問での環境を再構築して、現象が再現できましたので、再度補足して質問します。

調べている直近の課題
train_features, train_labels = next(iter(train_dataloader))

により、dataloaderで生成されたデータの内、先頭のデータ(一個の画像データ)の内容を調べること。
それで、上記命令のように、先頭の1個のデータを取り出し、featuerとlabelの内容をprintしようとした。

この位置で検出されたエラーメッセージ

TypeError: default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found <class 'PIL.Image.Image'>

このエラーを出力したcollate.pyのソースも補足で提示。しかしこの部分を理解するまでもなく恐らく、
next(iter(train_dataloader))の個所が何かおかしいと予想されます。

発生している問題・エラーメッセージ

TypeError: default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found <class 'PIL.Image.Image'>

該当のソースコード

python

# トランスフォーマーオブジェクトを生成 transform = transforms.Compose( [transforms.ToTensor(), # Tensorオブジェクトに変換 #transforms.Normalize((0.5), (0.5)) # 平均0.5、標準偏差0.5で正規化 ]) # 訓練用データの読み込み(60000セット) f_mnist_train = datasets.FashionMNIST( root=root, # データの保存先のディレクトリ download=False, # ダウンロードを許可 train=True) # 訓練データを指定 #transform=transform) # トランスフォーマーオブジェクトを指定 # Display image and label. #print("size",f_mnist_train.size()) print("train dataset",f_mnist_train.data[0,8]) #print("train dataset",f_mnist_train.label[0]) len= f_mnist_train.__len__() print("len",len) # テスト用データの読み込み(10000セット) f_mnist_test = datasets.FashionMNIST( root=root, # データの保存先のディレクトリ download=False, # ダウンロードを許可 train=False) # テストデータを指定 #transform=transform) # トランスフォーマーオブジェクトを指定 #print("test.len",f_mnist_test.size()) # 訓練用のデータローダー train_dataloader = DataLoader(f_mnist_train, # 訓練データ batch_size=batch_size, # ミニバッチのサイズ shuffle=True) # シャッフルして抽出 # テスト用のデータローダー test_dataloader = DataLoader(f_mnist_test, # テストデータ batch_size=1, # ミニバッチのサイズ shuffle=False) # シャッフル無しに抽出 #ここまではオリジナルのソース #以下が調査用に追加した部分 train_features, train_labels = next(iter(train_dataloader)) print("size", train_features.size()) sys.exit() #以下は色々調べてみようとした形跡(参考まで) print(f"Feature batch shape: {train_features.size()}") print(f"Labels batch shape: {train_labels.size()}") #img = f_mnist_train[0].squeeze() #label = train_labels[0] #plt.imshow(img, cmap="mono") #plt.show() #print(f"Label: {label}") sys.exit() print("train_dataloader",train_dataloader[0]) print("train.len",train_dataloader.__len__(),train_dataloader.__size__()) print("test.len",train_dataloader.__len__()) #print("f_mnist_train.shape",f_mnist_train.shape[0],"test.shape",f_mnist_test.shape)[0]; #print("train_dataloader",type(train_dataloader)) #print("test_dataloader",test_dataloader).dim() # データローダーが返すミニバッチの先頭データの形状を出力 for (x, t) in train_dataloader: # 訓練データ print(x.shape) print(t.shape) break for (x, t) in test_dataloader: # テストデータ #print(x.shape) #print(t.shape) print(type(t)) print(type(t)) break

補足情報

最後にエラー検出した関数(collate.py)のリスト

r""""Contains definitions of the methods used by the _BaseDataLoaderIter workers to collate samples fetched from dataset into Tensor(s). These **needs** to be in global scope since Py2 doesn't support serializing static methods. `default_collate` and `default_convert` are exposed to users via 'dataloader.py'. """ import torch import re import collections from torch._six import string_classes np_str_obj_array_pattern = re.compile(r'[SaUO]') def default_convert(data): r""" Function that converts each NumPy array element into a :class:`torch.Tensor`. If the input is a `Sequence`, `Collection`, or `Mapping`, it tries to convert each element inside to a :class:`torch.Tensor`. If the input is not an NumPy array, it is left unchanged. This is used as the default function for collation when both `batch_sampler` and `batch_size` are NOT defined in :class:`~torch.utils.data.DataLoader`. The general input type to output type mapping is similar to that of :func:`~torch.utils.data.default_collate`. See the description there for more details. Args: data: a single data point to be converted Examples: >>> # Example with `int` >>> default_convert(0) 0 >>> # Example with NumPy array >>> default_convert(np.array([0, 1])) tensor([0, 1]) >>> # Example with NamedTuple >>> Point = namedtuple('Point', ['x', 'y']) >>> default_convert(Point(0, 0)) Point(x=0, y=0) >>> default_convert(Point(np.array(0), np.array(0))) Point(x=tensor(0), y=tensor(0)) >>> # Example with List >>> default_convert([np.array([0, 1]), np.array([2, 3])]) [tensor([0, 1]), tensor([2, 3])] """ elem_type = type(data) if isinstance(data, torch.Tensor): return data elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \ and elem_type.__name__ != 'string_': # array of string classes and object if elem_type.__name__ == 'ndarray' \ and np_str_obj_array_pattern.search(data.dtype.str) is not None: return data return torch.as_tensor(data) elif isinstance(data, collections.abc.Mapping): try: return elem_type({key: default_convert(data[key]) for key in data}) except TypeError: # The mapping type may not support `__init__(iterable)`. return {key: default_convert(data[key]) for key in data} elif isinstance(data, tuple) and hasattr(data, '_fields'): # namedtuple return elem_type(*(default_convert(d) for d in data)) elif isinstance(data, tuple): return [default_convert(d) for d in data] # Backwards compatibility. elif isinstance(data, collections.abc.Sequence) and not isinstance(data, string_classes): try: return elem_type([default_convert(d) for d in data]) except TypeError: # The sequence type may not support `__init__(iterable)` (e.g., `range`). return [default_convert(d) for d in data] else: return data default_collate_err_msg_format = ( "default_collate: batch must contain tensors, numpy arrays, numbers, " "dicts or lists; found {}") def default_collate(batch): 本文が長くなるので省略

以下のような質問にはリアクションをつけましょう

  • 質問内容が明確
  • 自分も答えを知りたい
  • 質問者以外のユーザにも役立つ

リアクションが多い質問は、TOPページの「注目」タブのフィードに表示されやすくなります。

気になる質問をクリップする

クリップした質問は、後からいつでもマイページで確認できます。

またクリップした質問に回答があった際、通知やメールを受け取ることができます。

下記のような質問は推奨されていません。

  • 間違っている
  • 質問になっていない投稿
  • スパムや攻撃的な表現を用いた投稿

適切な質問に修正を依頼しましょう。

sigefuji

2022/08/31 11:13

仰せの事ですが、元の投稿を編集する(どうするのかわからないので単に追加するだけになると)、かえって煩雑になると考えます。もし思うとおりに編集できたとしても、それは新規投稿と何ら変わらないと考えます。

まだ回答がついていません

会員登録して回答してみよう

アカウントをお持ちの方は

15分調べてもわからないことは
teratailで質問しよう!

ただいまの回答率
86.12%

質問をまとめることで
思考を整理して素早く解決

テンプレート機能で
簡単に質問をまとめる

質問する

関連した質問

同じタグがついた質問を見る

Python 3.x

Python 3はPythonプログラミング言語の最新バージョンであり、2008年12月3日にリリースされました。