状況
PytorchのCIFAR-10チュートリアルを参考に,画像を分類するモデルを作成しました.
CIFAR-10チュートリアルはこちらです:
https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html
このモデルに対して,任意の画像を入力し分類させたいのですが,RuntimeError: shape '[-1, 400]' is invalid for input of size 293904
というエラーによって失敗してしまいます.解決方法が分かる方がおりましたら,ご教授いただきたいです.
実行コード
チュートリアルのコード
jupyter形式のソースコードを繋げたものを貼っています.上記URLから直接チュートリアルを除いた方が理解しやすいかもしれないです.
python
1import torch 2import torchvision 3import torchvision.transforms as transforms 4import matplotlib.pyplot as plt 5import numpy as np 6import torch.nn as nn 7import torch.nn.functional as F 8import torch.optim as optim 9 10 11transform = transforms.Compose( 12 [transforms.ToTensor(), 13 transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) 14 15trainset = torchvision.datasets.CIFAR10(root='./data', train=True, 16 download=True, transform=transform) 17trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, 18 shuffle=True, num_workers=2) 19 20testset = torchvision.datasets.CIFAR10(root='./data', train=False, 21 download=True, transform=transform) 22testloader = torch.utils.data.DataLoader(testset, batch_size=4, 23 shuffle=False, num_workers=2) 24 25classes = ('plane', 'car', 'bird', 'cat', 26 'deer', 'dog', 'frog', 'horse', 'ship', 'truck') 27 28# functions to show an image 29 30def imshow(img): 31 img = img / 2 + 0.5 # unnormalize 32 npimg = img.numpy() 33 plt.imshow(np.transpose(npimg, (1, 2, 0))) 34 plt.show() 35 36 37# get some random training images 38dataiter = iter(trainloader) 39images, labels = dataiter.next() 40 41# show images 42imshow(torchvision.utils.make_grid(images)) 43# print labels 44print(' '.join('%5s' % classes[labels[j]] for j in range(4))) 45 46class Net(nn.Module): 47 def __init__(self): 48 super(Net, self).__init__() 49 self.conv1 = nn.Conv2d(3, 6, 5) 50 self.pool = nn.MaxPool2d(2, 2) 51 self.conv2 = nn.Conv2d(6, 16, 5) 52 self.fc1 = nn.Linear(16 * 5 * 5, 120) 53 self.fc2 = nn.Linear(120, 84) 54 self.fc3 = nn.Linear(84, 10) 55 56 def forward(self, x): 57 x = self.pool(F.relu(self.conv1(x))) 58 x = self.pool(F.relu(self.conv2(x))) 59 x = x.view(-1, 16 * 5 * 5) 60 # x = x.view(x.size(0), 16*37*37) 61 x = F.relu(self.fc1(x)) 62 x = F.relu(self.fc2(x)) 63 x = self.fc3(x) 64 return x 65 66net = Net() 67 68criterion = nn.CrossEntropyLoss() 69optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9) 70 71for epoch in range(2): # loop over the dataset multiple times 72 73 running_loss = 0.0 74 for i, data in enumerate(trainloader, 0): 75 # get the inputs; data is a list of [inputs, labels] 76 inputs, labels = data 77 78 # zero the parameter gradients 79 optimizer.zero_grad() 80 81 # forward + backward + optimize 82 outputs = net(inputs) 83 loss = criterion(outputs, labels) 84 loss.backward() 85 optimizer.step() 86 87 # print statistics 88 running_loss += loss.item() 89 if i % 2000 == 1999: # print every 2000 mini-batches 90 print('[%d, %5d] loss: %.3f' % 91 (epoch + 1, i + 1, running_loss / 2000)) 92 running_loss = 0.0 93 94print('Finished Training') 95 96PATH = './cifar_net.pth' 97torch.save(net.state_dict(), PATH) 98 99dataiter = iter(testloader) 100images, labels = dataiter.next() 101 102# print images 103imshow(torchvision.utils.make_grid(images)) 104print('GroundTruth: ', ' '.join('%5s' % classes[labels[j]] for j in range(4))) 105 106net = Net() 107net.load_state_dict(torch.load(PATH)) 108 109_, predicted = torch.max(outputs, 1) 110 111print('Predicted: ', ' '.join('%5s' % classes[predicted[j]] 112 for j in range(4)))
実行したいコード
調べながら書きました.jpgファイルをTensorに変換して入力させようとしています.
python
1from PIL import Image 2import torchvision.transforms as transforms 3 4device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 5 6transform = transforms.Compose( 7 [transforms.ToTensor(), 8 transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) 9 10def image_loader(image_name): 11 image = Image.open(image_name) 12 image = transform(image).unsqueeze(0) 13 return image.to(device, torch.float) 14 15img = image_loader("./photo.jpg") 16 17outputs = net(img)
エラー
RuntimeError: shape '[-1, 400]' is invalid for input of size 293904
あなたの回答
tips
プレビュー