PyTorchのチュートリアルのコードを実行したところエラーが発生したのですが、解決方法がわからず困っています。
ソースコード
python
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

# Mini-batch size shared by the training and the test loader.
BATCH_SIZE = 4

# Number of DataLoader worker subprocesses.
# NOTE(review): "DataLoader worker (pid ...) is killed by signal" is raised
# when one of these subprocesses dies (presumably killed by the OS, e.g. when
# memory is short -- confirm on the affected machine). Setting this to 0 loads
# the data in the main process and avoids worker subprocesses altogether.
NUM_WORKERS = 2

# Convert images to tensors, then shift/scale every channel with mean 0.5 and
# std 0.5 (imshow() below undoes this before plotting).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)

testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')


def imshow(img):
    """Display a normalized image tensor with matplotlib.

    Args:
        img: Image tensor laid out as (channels, height, width) that was
            normalized with mean 0.5 and std 0.5, as done by ``transform``.
    """
    img = img / 2 + 0.5  # undo the normalization
    npimg = img.numpy()  # convert to a NumPy array
    # Reorder the axes to (height, width, channels), as matplotlib expects.
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()


# Get one batch of random training images.
# Use the built-in next(): the iterator's .next() method is a Python 2 idiom
# and is not available in newer PyTorch releases.
dataiter = iter(trainloader)
images, labels = next(dataiter)

imshow(torchvision.utils.make_grid(images))
# Print the labels of the displayed images.
print(' '.join('%5s' % classes[label] for label in labels))


class Net(nn.Module):
    """Small CNN: two conv + max-pool stages followed by three linear layers.

    The first linear layer takes 16 * 5 * 5 features, which is what the two
    conv/pool stages produce for 3-channel 32x32 inputs. The network outputs
    10 raw scores (logits), one per class.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Return the class logits for a batch of images ``x``."""
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)  # flatten everything but the batch axis
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


net = Net()

criterion = nn.CrossEntropyLoss()  # cross-entropy loss
# SGD with momentum as the optimizer.
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

for epoch in range(2):  # loop over the dataset multiple times
    running_loss = 0.0
    for i, (inputs, labels) in enumerate(trainloader):
        # Zero the parameter gradients accumulated by the previous step.
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Print the average loss every 2000 mini-batches.
        running_loss += loss.item()
        if i % 2000 == 1999:
            print('[%d, %5d] loss: %.3f'
                  % (epoch + 1, i + 1, running_loss / 2000))
            running_loss = 0.0

print('Finished Training')

PATH = './cifar_net.pth'
torch.save(net.state_dict(), PATH)

dataiter = iter(testloader)
images, labels = next(dataiter)

# Show one batch of test images together with their true labels.
imshow(torchvision.utils.make_grid(images))
print('GroundTruth: ', ' '.join('%5s' % classes[label] for label in labels))

# Reload the saved weights into a fresh network.
net = Net()
net.load_state_dict(torch.load(PATH))

outputs = net(images)

print("outputs:\n{}".format(outputs))

# torch.max over dim 1 returns (largest score, index of that score) per
# image; the index is the predicted class.
max_values, predicted = torch.max(outputs, 1)
print("max value:{}".format(max_values))
print("Predicted:{}".format(predicted))

print('Predicted: ', ' '.join('%5s' % classes[p] for p in predicted))

# Overall accuracy on the test set.
correct = 0
total = 0
with torch.no_grad():
    for images, labels in testloader:
        outputs = net(images)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print('Accuracy of the network on the 10000 test images: %d %%'
      % (100 * correct / total))

# Per-class accuracy on the test set.
class_correct = [0.0] * len(classes)
class_total = [0.0] * len(classes)
with torch.no_grad():
    for images, labels in testloader:
        outputs = net(images)
        _, predicted = torch.max(outputs, 1)
        hits = predicted == labels
        # Walk over the actual batch contents instead of a hard-coded batch
        # size, so a smaller final batch (or a batch of one) cannot break it.
        for label, hit in zip(labels.tolist(), hits.tolist()):
            class_correct[label] += hit
            class_total[label] += 1

for name, n_correct, n_total in zip(classes, class_correct, class_total):
    print('Accuracy of %5s : %2d %%' % (name, 100 * n_correct / n_total))
エラー内容
RuntimeError Traceback (most recent call last)
<ipython-input-9-99ffcbefe663> in <module>
26
27 # get some random training images
---> 28 dataiter = iter(trainloader)
29 images, labels = dataiter.next()
30
~/opt/anaconda3/lib/python3.7/site-packages/torch/utils/data/dataloader.py in __iter__(self)
277 return _SingleProcessDataLoaderIter(self)
278 else:
--> 279 return _MultiProcessingDataLoaderIter(self)
280
281 @property
~/opt/anaconda3/lib/python3.7/site-packages/torch/utils/data/dataloader.py in __init__(self, loader)
744 # prime the prefetch loop
745 for _ in range(2 * self._num_workers):
--> 746 self._try_put_index()
747
748 def _try_get_data(self, timeout=_utils.MP_STATUS_CHECK_INTERVAL):
~/opt/anaconda3/lib/python3.7/site-packages/torch/utils/data/dataloader.py in _try_put_index(self)
859 assert self._tasks_outstanding < 2 * self._num_workers
860 try:
--> 861 index = self._next_index()
862 except StopIteration:
863 return
~/opt/anaconda3/lib/python3.7/site-packages/torch/utils/data/dataloader.py in _next_index(self)
337
338 def _next_index(self):
--> 339 return next(self._sampler_iter) # may raise StopIteration
340
341 def _next_data(self):
~/opt/anaconda3/lib/python3.7/site-packages/torch/utils/data/sampler.py in __iter__(self)
198 def __iter__(self):
199 batch = []
--> 200 for idx in self.sampler:
201 batch.append(idx)
202 if len(batch) == self.batch_size:
~/opt/anaconda3/lib/python3.7/site-packages/torch/utils/data/sampler.py in __iter__(self)
105 if self.replacement:
106 return iter(torch.randint(high=n, size=(self.num_samples,), dtype=torch.int64).tolist())
--> 107 return iter(torch.randperm(n).tolist())
108
109 def __len__(self):
~/opt/anaconda3/lib/python3.7/site-packages/torch/utils/data/_utils/signal_handling.py in handler(signum, frame)
64 # This following call uses `waitid` with WNOHANG from C side. Therefore,
65 # Python can still get and update the process status successfully.
---> 66 _error_if_any_worker_fails()
67 if previous_handler is not None:
68 previous_handler(signum, frame)
RuntimeError: DataLoader worker (pid 80830) is killed by signal: Unknown signal: 0.
同じコードを昨日動かしたときには問題なく動作していたので、コードの中身の問題ではないと思うのですが、何かわかる方がいましたらご教示をお願いします。
バッドをするには、ログインかつ
こちらの条件を満たす必要があります。