What I want to achieve
I want to run PyTorch computations on the GPU (MPS) on a Mac with an M2 chip.
Background
To run PyTorch computations on the GPU (MPS), I am executing the code below, but it raises an error.
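As a quick sanity check (my addition, not part of the original post), the following confirms that the installed PyTorch build supports the MPS backend at all:

```python
import torch

# Both checks are standard PyTorch APIs: is_built() reports whether this
# build was compiled with MPS support, is_available() whether the MPS
# device can actually be used on this machine.
print(torch.backends.mps.is_built())
print(torch.backends.mps.is_available())
```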
The problem / error message
RuntimeError Traceback (most recent call last)
Cell In[42], line 45
42 print(inputs.device)
43 device
---> 45 outputs = net(inputs)
46 print(outputs)
File ~/opt/anaconda3/lib/python3.9/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
1496 # If we don't have any hooks, we want to skip the rest of the logic in
1497 # this function, and just call forward.
1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1499 or _global_backward_pre_hooks or _global_backward_hooks
1500 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501 return forward_call(*args, **kwargs)
1502 # Do not call functions when jit is used
1503 full_backward_hooks, non_full_backward_hooks = [], []
Cell In[42], line 9, in Net.forward(self, x)
8 def forward(self, x):
----> 9 x1 = self.l1(x)
10 x2 = self.relu(x1)
11 x3 = self.l2(x2)
File ~/opt/anaconda3/lib/python3.9/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
1496 # If we don't have any hooks, we want to skip the rest of the logic in
1497 # this function, and just call forward.
1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1499 or _global_backward_pre_hooks or _global_backward_hooks
1500 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501 return forward_call(*args, **kwargs)
1502 # Do not call functions when jit is used
1503 full_backward_hooks, non_full_backward_hooks = [], []
File ~/opt/anaconda3/lib/python3.9/site-packages/torch/nn/modules/linear.py:114, in Linear.forward(self, input)
113 def forward(self, input: Tensor) -> Tensor:
--> 114 return F.linear(input, self.weight, self.bias)
RuntimeError: Placeholder storage has not been allocated on MPS device!
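For reference, PyTorch raises this particular RuntimeError under the MPS backend when the input tensor and the model's parameters sit on different devices. The snippet below is my own minimal illustration of that mismatch, not code from the post:

```python
import torch
import torch.nn as nn

layer = nn.Linear(4, 2)              # weights stay on the CPU
x = torch.randn(1, 4, device="mps")  # input lives on the MPS device
layer(x)  # RuntimeError: Placeholder storage has not been allocated on MPS device!
```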
Relevant source code
"""python
!pip install japanize_matplotlib | tail -n 1
!pip install torchviz | tail -n 1
!pip install torchinfo | tail -n 1
import numpy as np
import matplotlib.pyplot as plt
import japanize_matplotlib
from IPython.display import display
import torch
import torch.nn as nn
import torch.optim as optim
from torchinfo import summary
from torchviz import make_dot
device = torch.device("mps")
import torchvision.datasets as datasets
data_root = "./data"
train_set0 = datasets.MNIST(
    root=data_root,
    train=True, download=True)
class Net(nn.Module):
    def __init__(self, n_input, n_output, n_hidden):
        super().__init__()
        self.l1 = nn.Linear(n_input, n_hidden)
        self.l2 = nn.Linear(n_hidden, n_output)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x1 = self.l1(x)
        x2 = self.relu(x1)
        x3 = self.l2(x2)
        return x3
device = torch.device("mps")
torch.manual_seed(123)
torch.mps.manual_seed(123)
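# n_input / n_output / n_hidden are not defined anywhere in the posted
# snippet; for MNIST images flattened to 784 features with 10 classes,
# plausible values (my assumption, not from the post) would be:
n_input, n_output, n_hidden = 784, 10, 128  # 128 is an assumed hidden size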
net = Net(n_input, n_output, n_hidden)
net = net.to(device)
lr = 0.01
optimizer = torch.optim.SGD(net.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss()
for parameter in net.named_parameters():
    print(parameter)
print(net)
summary(net, (784,))
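# train_loader is also not defined in the posted snippet; a plausible
# construction (my assumption) that flattens each 28x28 image into the
# 784-dim float vector this fully connected net expects:
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

transform = transforms.Compose([
    transforms.ToTensor(),                    # PIL image -> float tensor in [0, 1]
    transforms.Lambda(lambda x: x.view(-1)),  # (1, 28, 28) -> (784,)
])
train_set = datasets.MNIST(root=data_root, train=True,
                           download=True, transform=transform)
train_loader = DataLoader(train_set, batch_size=500, shuffle=True)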
for images, labels in train_loader:
    break  # grab a single batch for a forward-pass check
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
inputs = images.to(device)
labels = labels.to(device)
print(inputs.device)
device
outputs = net(inputs)
print(outputs)
"""
What I tried
Set the GPU device to mps instead of cuda.
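One diagnostic worth adding here (my suggestion, not something from the post): print the device of the model's parameters next to the input's device, since this error usually means the two disagree, for example after re-running notebook cells in a different order.

```python
# Diagnostic sketch (my addition): if these two lines print different
# devices, calling net.to(device) again before the forward pass should fix it.
print(next(net.parameters()).device)  # device holding the model's weights
print(inputs.device)                  # device holding the input batch
```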
Additional information (framework/tool versions, etc.)
Python 3.9 (Anaconda) on a Mac with an Apple M2 chip.
