pytorchのDataloaderを使ってTTAを実装しようとしているのですが、うまくいきません。
試したい流れは以下の通りです。
0. TTA用のtransoformsを3つ用意
- Datasetでデータを取得する際に3種類分の画像をリストで取得するように設定
取得するデータのsize->(3, channel, image_height, image_width)
2. Dataloaderでデータを読み込む際にbatch_size分のデータを取得
取得するデータのsize->(batch_size, 3, channel, image_height, image_width)
4. バッチごとモデルに渡し、推論
<問題点>
ステップ3の部分で、バッチサイズ分のデータを取り出すことができません。
具体的には、batch_size=32,channel=3として(32,3,3,512,512)というサイズのデータを取り出したいのですが、動かしてみると(1,3,3,512,512)というサイズになってしまいます。
(32,3,3,512,512)のサイズで取り出す方法を教えていただけないでしょうか。
# transforms test_transform = transforms.Compose([ transforms.ToTensor(), transforms.Resize((Conf.IM_SIZE, Conf.IM_SIZE)), transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), ]) TTA1 = transforms.Compose([ transforms.ToTensor(), transforms.CenterCrop(Conf.IM_SIZE), transforms.RandomHorizontalFlip(p=1), transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), ]) TTA2 = transforms.Compose([ transforms.ToTensor(), transforms.CenterCrop(Conf.IM_SIZE), transforms.RandomVerticalFlip(p=1), transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), ])
# Dataset class TestData(Dataset): def __init__(self, Dir, FNames, Transform): self.dir = Dir self.fnames = FNames self.transform = Transform self.ttas = [self.transform, TTA1, TTA2] def __len__(self): return len(self.fnames) def __getitem__(self, index): x = Image.open(os.path.join(self.dir, self.fnames[index])) imglist=[tta(x) for tta in self.ttas] image=torch.stack(imglist) return image, self.fnames[index]
# DataLoader testset = TestData(TEST_DIR, X_test, test_transform) testloader = DataLoader(testset, batch_size=Conf.BATCH, shuffle=False, num_workers=4)
# Inference for images,image_names in testloader: with torch.no_grad(): batch_size, n_crops, c, h, w = images.size() <---ここでバッチサイズが1になってしまっています。 images = images.view(-1, c, h, w) output= F.softmax(model(images),dim=1) output = output.view(batch_size, n_crops,-1).mean(1) pred = output.argmax(1).cpu().numpy()
あなたの回答
tips
プレビュー