質問をすることでしか得られない、回答やアドバイスがある。

15分調べてもわからないことは、質問しよう!

ただいまの
回答率

88.09%

pytorchのデータセットでインデックスを取得する方法

解決済

回答 2

投稿 編集

  • 評価
  • クリップ 0
  • VIEW 995

score 101

やりたいこと

pytorchでCIFAR10を用いた学習をしているのですが、データセット内で何番目の画像を使ったかのインデックス番号を取得したいです。
CIFAR10クラスを継承して自作してみましたが、dataなんてメンバ変数は知らないとエラーが出てしまいます。
ご存知の方がいらっしゃいましたらご教授頂きたいです。

試したコード

class MyCIFAR10(torchvision.datasets.CIFAR10):
    def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
        super(MyCIFAR10, self).__init__(root, train=train, transform=transform, target_transform=target_transform, download=download)

    def __getitem__(self, index):
        img, target = self.data[index], self.targets[index]
        img = Image.fromarray(img)

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target, index ← このindexをローダに返したい!

 試した結果

上記クラスを使うと下のようなエラーが出ます。

Traceback (most recent call last):
  File "main.py", line 186, in <module>
    train_loss, train_acc = train(epoch)
  File "main.py", line 132, in train
    for batch_idx, (inputs, targets, index) in enumerate(trainloader):
  File "/home/xxx/anaconda3/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 560, in __next__
    batch = self.collate_fn([self.dataset[i] for i in indices])
  File "/home/xxx/anaconda3/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 560, in <listcomp>
    batch = self.collate_fn([self.dataset[i] for i in indices])
  File "main.py", line 46, in __getitem__
    img, target = self.data[index], self.targets[index]
AttributeError: 'MyCIFAR10' object has no attribute 'data'
  • 気になる質問をクリップする

    クリップした質問は、後からいつでもマイページで確認できます。

    またクリップした質問に回答があった際、通知やメールを受け取ることができます。

    クリップを取り消します

  • 良い質問の評価を上げる

    以下のような質問は評価を上げましょう

    • 質問内容が明確
    • 自分も答えを知りたい
    • 質問者以外のユーザにも役立つ

    評価が高い質問は、TOPページの「注目」タブのフィードに表示されやすくなります。

    質問の評価を上げたことを取り消します

  • 評価を下げられる数の上限に達しました

    評価を下げることができません

    • 1日5回まで評価を下げられます
    • 1日に1ユーザに対して2回まで評価を下げられます

    質問の評価を下げる

    teratailでは下記のような質問を「具体的に困っていることがない質問」、「サイトポリシーに違反する質問」と定義し、推奨していません。

    • プログラミングに関係のない質問
    • やってほしいことだけを記載した丸投げの質問
    • 問題・課題が含まれていない質問
    • 意図的に内容が抹消された質問
    • 過去に投稿した質問と同じ内容の質問
    • 広告と受け取られるような投稿

    評価が下がると、TOPページの「アクティブ」「注目」タブのフィードに表示されにくくなります。

    質問の評価を下げたことを取り消します

    この機能は開放されていません

    評価を下げる条件を満たしてません

    評価を下げる理由を選択してください

    詳細な説明はこちら

    上記に当てはまらず、質問内容が明確になっていない質問には「情報の追加・修正依頼」機能からコメントをしてください。

    質問の評価を下げる機能の利用条件

    この機能を利用するためには、以下の事項を行う必要があります。

質問への追記・修正、ベストアンサー選択の依頼

  • 0kcal

    2020/02/01 15:52 編集

    お疲れ様。
    答えは、わかってないのですが、
    まず、上記の構成で、dataというのは、ないような気がします。

    ```python
    C:\_temp_work>python -m pydoc torchvision.datasets.CIFAR10
    Help on class CIFAR10 in torchvision.datasets:

    torchvision.datasets.CIFAR10 = class CIFAR10(torchvision.datasets.vision.VisionDataset)
    | torchvision.datasets.CIFAR10(root, train=True, transform=None, target_transform=None, download=False)
    |
    | `CIFAR10 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset.
    |
    | Args:
    | root (string): Root directory of dataset where directory
    | ``cifar-10-batches-py`` exists or will be saved to if download is set to True.
    | train (bool, optional): If True, creates dataset from training set, otherwise
    | creates from test set.
    | transform (callable, optional): A function/transform that takes in an PIL image
    | and returns a transformed version. E.g, ``transforms.RandomCrop``
    | target_transform (callable, optional): A function/transform that takes in the
    | target and transforms it.
    | download (bool, optional): If true, downloads the dataset from the internet and
    | puts it in root directory. If dataset is already downloaded, it is not
    | downloaded again.


    C:\_temp_work>

    ```

    キャンセル

  • s-uchi

    2020/02/01 16:19

    回答有り難うございます。

    https://pytorch.org/docs/stable/_modules/torchvision/datasets/cifar.html#CIFAR10
    を見て、CIFAR10クラスで定義されてるself.dataを引っ張ってこれたらと考えてます。
    super()ってしたら親の変数達も見れるようになると理解してるのですが、、、継承がよくわかりませんw

    キャンセル

回答 2

checkベストアンサー

+1

Cifar10の場合ですが、例えば以下の要領でデータを取得して、

trainset = torchvision.datasets.CIFAR10(root='../data/raw', train=True,
                                        download=True, transform=transform)

testset = torchvision.datasets.CIFAR10(root='../data/raw', train=False,
                                       download=True, transform=transform)

# DataLoaderの作成
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                          shuffle=True, num_workers=2)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
                                         shuffle=False, num_workers=2)

次に、以下のようにすることでミニバッチ内のインデックスが取得できます。上記の場合はバッチサイズが4なので、forの中では0,1,2,3のインデックスとそのデータが取得できます。このような回答でよろしかったでしょうか。

for i, data in enumerate(trainloader, 0):
        inputs, labels = data
        for index, result in enumerate(inputs):
          print(index, result) # これでバッチ内のインデックスとそのデータが取得できる。

<質問に対する回答の追記>
例えばですが、以下のようにクラスを定義することで、お望みのインデックスありのデータセットが定義できます。

class Subset(Dataset):
    """
    Subset of a dataset at specified indices.

    Arguments:
        dataset (Dataset): The whole Dataset
        indices (sequence): Indices in the whole set selected for subset
    """
    def __init__(self, data, label, indices):
        self.data = data
        self.label = label
        self.indices = indices

    def __getitem__(self, idx):
        #out_data = self.transform(self.data)[idx]

        return self.data, self.label, self.data[self.indices[idx]]

    def __len__(self):
        return len(self.indices)

train_size = len(trainset) # n_samples is 60000
indices = list(range(0,train_size)) # [0,1,.....47999]
train_dataset = Subset(trainset.data, trainset.targets, indices)

# 二つ目のデータの入力とラベルとインデックスを表示
print(train_dataset.data[2])
print(train_dataset.label[2])
print(train_dataset.indices[2])

投稿

編集

  • 回答の評価を上げる

    以下のような回答は評価を上げましょう

    • 正しい回答
    • わかりやすい回答
    • ためになる回答

    評価が高い回答ほどページの上位に表示されます。

  • 回答の評価を下げる

    下記のような回答は推奨されていません。

    • 間違っている回答
    • 質問の回答になっていない投稿
    • スパムや攻撃的な表現を用いた投稿

    評価を下げる際はその理由を明確に伝え、適切な回答に修正してもらいましょう。

  • 2020/02/01 16:34

    質問が悪かったです。(本文を編集しました)

    バッチ内のインデックスではなく、データセット50000枚の何番目かを知りたかったです。
    同じ画像の特徴量が学習中にどのように変化しているか(データローダでshuffuleが入った状態で)
    トラッキングしたいのが本当のやりたいことです。

    キャンセル

  • 2020/02/01 18:30

    回答に追記させて頂きました。これでお望みのデータセットが作成できると思います。

    キャンセル

0

お疲れ様。
答えは、わかってないのですが、
まず、上記の構成で、dataというのは、存在しない気がします。

C:\_temp_work>python -m pydoc torchvision.datasets.CIFAR10
Help on class CIFAR10 in torchvision.datasets:

torchvision.datasets.CIFAR10 = class CIFAR10(torchvision.datasets.vision.VisionDataset)
| torchvision.datasets.CIFAR10(root, train=True, transform=None, target_transform=None, download=False)
|
| `CIFAR10 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset.
|
| Args:
| root (string): Root directory of dataset where directory
| ``cifar-10-batches-py`` exists or will be saved to if download is set to True.
| train (bool, optional): If True, creates dataset from training set, otherwise
| creates from test set.
| transform (callable, optional): A function/transform that takes in an PIL image
| and returns a transformed version. E.g, ``transforms.RandomCrop``
| target_transform (callable, optional): A function/transform that takes in the
| target and transforms it.
| download (bool, optional): If true, downloads the dataset from the internet and
| puts it in root directory. If dataset is already downloaded, it is not
| downloaded again.


C:\_temp_work>

投稿

編集

  • 回答の評価を上げる

    以下のような回答は評価を上げましょう

    • 正しい回答
    • わかりやすい回答
    • ためになる回答

    評価が高い回答ほどページの上位に表示されます。

  • 回答の評価を下げる

    下記のような回答は推奨されていません。

    • 間違っている回答
    • 質問の回答になっていない投稿
    • スパムや攻撃的な表現を用いた投稿

    評価を下げる際はその理由を明確に伝え、適切な回答に修正してもらいましょう。

15分調べてもわからないことは、teratailで質問しよう!

  • ただいまの回答率 88.09%
  • 質問をまとめることで、思考を整理して素早く解決
  • テンプレート機能で、簡単に質問をまとめられる

関連した質問

同じタグがついた質問を見る