やりたいこと
pytorchでCIFAR10を用いた学習をしているのですが、データセット内で何番目の画像を使ったかのインデックス番号を取得したいです。
CIFAR10クラスを継承して自作してみましたが、dataなんてメンバ変数は知らないとエラーが出てしまいます。
ご存知の方がいらっしゃいましたらご教授頂きたいです。
試したコード
class MyCIFAR10(torchvision.datasets.CIFAR10): def __init__(self, root, train=True, transform=None, target_transform=None, download=False): super(MyCIFAR10, self).__init__(root, train=train, transform=transform, target_transform=target_transform, download=download) def __getitem__(self, index): img, target = self.data[index], self.targets[index] img = Image.fromarray(img) if self.transform is not None: img = self.transform(img) if self.target_transform is not None: target = self.target_transform(target) return img, target, index ← このindexをローダに返したい!
## 試した結果
上記クラスを使うと下のようなエラーが出ます。
Traceback (most recent call last): File "main.py", line 186, in <module> train_loss, train_acc = train(epoch) File "main.py", line 132, in train for batch_idx, (inputs, targets, index) in enumerate(trainloader): File "/home/xxx/anaconda3/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 560, in __next__ batch = self.collate_fn([self.dataset[i] for i in indices]) File "/home/xxx/anaconda3/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 560, in <listcomp> batch = self.collate_fn([self.dataset[i] for i in indices]) File "main.py", line 46, in __getitem__ img, target = self.data[index], self.targets[index] AttributeError: 'MyCIFAR10' object has no attribute 'data'
回答2件
あなたの回答
tips
プレビュー