前提・実現したいこと
Python初心者です。
AlexNetで3次元データを画像分類したいのですが、RuntimeErrorになってしまいます。
間違っている箇所あればご指摘していただけると幸いです。
発生している問題・エラーメッセージ
$ python Alexnet.py Traceback (most recent call last): File "test.py", line 127, in <module> outputs = net(images) File "/home/selen/.pyenv/versions/3.7.3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 493, in __call__ result = self.forward(*input, **kwargs) File "test.py", line 100, in forward x = self.features(x) File "/home/selen/.pyenv/versions/3.7.3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 493, in __call__ result = self.forward(*input, **kwargs) File "/home/selen/.pyenv/versions/3.7.3/lib/python3.7/site-packages/torch/nn/modules/container.py", line 92, in forward input = module(input) File "/home/selen/.pyenv/versions/3.7.3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 493, in __call__ result = self.forward(*input, **kwargs) File "/home/selen/.pyenv/versions/3.7.3/lib/python3.7/site-packages/torch/nn/modules/conv.py", line 476, in forward self.padding, self.dilation, self.groups) RuntimeError: Expected 5-dimensional input for 5-dimensional weight 64 3 3 3, but got 4-dimensional input of size [64, 80, 96, 80] instead
該当のソースコード
Python
1import matplotlib.pyplot as plt 2from mpl_toolkits.mplot3d import Axes3D 3 4import torch 5import torchvision 6import torch.nn as nn 7import torch.nn.init as init 8import torch.optim as optim 9import torch.nn.functional as F 10import torchvision.transforms as transforms 11import numpy as np 12from matplotlib import pyplot as plt 13from skimage import io, transform 14 15from torch.utils.data import Dataset 16from dataset import CLASS_MAP 17import dataset 18 19CLASS_MAP = {"CN": 0, "AD": 1, "LMCI": 2, "SMC": 3, "EMCI": 4} 20 21class BrainData(Dataset): 22 def __init__(self, data, transform=None, class_map=CLASS_MAP): 23 self.data = data 24 self.class_map = class_map 25 self.transform = transform 26 27 def __len__(self): 28 return len(self.data) 29 30 def __getitem__(self, idx): 31 if torch.is_tensor(idx): 32 idx = idx.tolist() 33 34 voxel = self.data[idx]["voxel"] 35 label = self.class_map[self.data[idx]["label"]] 36 37 #return sample 38 return (voxel, label) 39 40data = dataset.load_data(["ADNI2"]) 41data_set = BrainData(data, CLASS_MAP) 42 43n_train = int(len(data_set) * 0.8) 44n_val = int(len(data_set) - n_train) 45 46torch.manual_seed(0) 47 48train_dataset, val_dataset = torch.utils.data.random_split(data_set, [n_train, n_val]) 49 50train_loader = torch.utils.data.DataLoader( 51 train_dataset, batch_size=64, shuffle=True, num_workers=5) 52 53val_loader = torch.utils.data.DataLoader( 54 val_dataset, batch_size=64, shuffle=False, num_workers=5) 55 56# set data loader 57train_loader = torch.utils.data.DataLoader( 58 dataset=train_dataset, 59 batch_size=64, 60 shuffle=True, 61 num_workers=5) 62 63val_loader = torch.utils.data.DataLoader( 64 dataset=val_dataset, 65 batch_size=64, 66 shuffle=False, 67 num_workers=5) 68 69#class AlexNet 70class AlexNet(nn.Module): 71 72 def __init__(self, num_classes): 73 super(AlexNet, self).__init__() 74 self.features = nn.Sequential( 75 nn.Conv3d(3, 64, kernel_size=3, padding=1), 76 nn.ReLU(inplace=True), 77 nn.MaxPool2d(kernel_size=2, stride=2), 78 nn.Conv3d(64, 192, kernel_size=5, padding=2), 79 nn.ReLU(inplace=True), 80 nn.MaxPool2d(kernel_size=2, stride=2), 81 nn.Conv3d(192, 384, kernel_size=3, padding=1), 82 nn.ReLU(inplace=True), 83 nn.Conv3d(384, 256, kernel_size=3, padding=1), 84 nn.ReLU(inplace=True), 85 nn.Conv3d(256, 256, kernel_size=3, padding=1), 86 nn.ReLU(inplace=True), 87 nn.MaxPool2d(kernel_size=2, stride=2), 88 ) 89 self.classifier = nn.Sequential( 90 nn.Dropout(), 91 nn.Linear(256 * 4 * 4, 4096), 92 nn.ReLU(inplace=True), 93 nn.Dropout(), 94 nn.Linear(4096, 4096), 95 nn.ReLU(inplace=True), 96 nn.Linear(4096, num_classes), 97 ) 98 99 def forward(self, x): 100 x = self.features(x) 101 x = x.view(x.size(0), 256 * 4 * 4) 102 x = self.classifier(x) 103 return x 104 105# select device 106num_classes = 5 107device = 'cuda' if torch.cuda.is_available() else 'cpu' 108net = AlexNet(num_classes).to(device) 109 110# optimizing 111criterion = nn.CrossEntropyLoss() 112optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4) 113 114# training 115num_epochs = 20 116train_loss_list, train_acc_list, val_loss_list, val_acc_list = [], [], [], [] 117 118### training 119for epoch in range(num_epochs): 120 train_loss, train_acc, val_loss, val_acc = 0, 0, 0, 0 121 122 # ====== train_mode ====== 123 net.train() 124 for i, (images, labels) in enumerate(train_loader): 125 images, labels = images.to(device), labels.to(device) 126 optimizer.zero_grad() 127 outputs = net(images) 128 loss = criterion(outputs, labels) 129 train_loss += loss.item() 130 train_acc += (outputs.max(1)[1] == labels).sum().item() 131 loss.backward() 132 optimizer.step() 133 134 avg_train_loss = train_loss / len(train_loader.dataset) 135 avg_train_acc = train_acc / len(train_loader.dataset) 136 137 # ====== val_mode ====== 138 net.eval() 139 with torch.no_grad(): 140 for images, labels in val_loader: 141 images = images.to(device) 142 labels = labels.to(device) 143 outputs = net(images) 144 loss = criterion(outputs, labels) 145 val_loss += loss.item() 146 val_acc += (outputs.max(1)[1] == labels).sum().item() 147 avg_val_loss = val_loss / len(val_loader.dataset) 148 avg_val_acc = val_acc / len(val_loader.dataset) 149 150 print ('Epoch [{}/{}], Loss: {loss:.4f}, val_loss: {val_loss:.4f}, val_acc: {val_acc:.4f}' 151 .format(epoch+1, num_epochs, i+1, loss=avg_train_loss, val_loss=avg_val_loss, val_acc=avg_val_acc)) 152 train_loss_list.append(avg_train_loss) 153 train_acc_list.append(avg_train_acc) 154 val_loss_list.append(avg_val_loss) 155 val_acc_list.append(avg_val_acc) 156 157 158# plot graph 159plt.figure() 160plt.plot(range(num_epochs), train_loss_list, color='blue', linestyle='-', label='train_loss') 161plt.plot(range(num_epochs), val_loss_list, color='green', linestyle='--', label='val_loss') 162plt.legend() 163plt.xlabel('epoch') 164plt.ylabel('loss') 165plt.title('Training and validation loss') 166plt.grid() 167plt.show() 168 169plt.figure() 170plt.plot(range(num_epochs), train_acc_list, color='blue', linestyle='-', label='train_acc') 171plt.plot(range(num_epochs), val_acc_list, color='green', linestyle='--', label='val_acc') 172plt.legend() 173plt.xlabel('epoch') 174plt.ylabel('acc') 175plt.title('Training and validation accuracy') 176plt.grid() 177plt.show()
あなたの回答
tips
プレビュー