As a premise: if the input is torch.Size([32, 3, 400, 400]), am I right that this amounts to running a model whose per-sample input is (3, 400, 400) with a batch size of 32, i.e. roughly the same as looping 32 times over (forward pass + loss computation + backward pass)?
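For reference, that mental model can be checked directly: as long as the network has no batch-dependent layers (e.g. BatchNorm), a batched forward gives the same per-sample outputs as a Python loop over single samples, and a sum-reduced loss yields matching gradients. A minimal sketch with a toy nn.Linear standing in for the real network (all names here are illustrative, not taken from the code below):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(10, 4)    # toy stand-in for the real network
x = torch.randn(32, 10)     # a batch of 32 samples

out_batched = model(x)                                       # one batched pass
out_looped = torch.stack([model(x[i]) for i in range(32)])   # 32 single-sample passes

# Same outputs, so a sum-reduced loss (and its gradients) would match too
print(torch.allclose(out_batched, out_looped, atol=1e-6))    # True
```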
With torch.Size([1, 3, 400, 400]) it runs (although the same error occasionally shows up at random), but with torch.Size([32, 3, 400, 400]) it fails every single time. I have no idea why and I am stuck; if anyone knows, please help. I checked the replay memory, and every stored state was torch.Size([1, 3, 400, 400]).
```python
mainQN.train()
optimizer.zero_grad()
print(inputs[0].shape)
output = mainQN.forward(inputs, "net_q")
if self.IQN == True:
    self.loss_IQN(target, output, weights)
else:
    loss = criterion(output, targets)
```
```python
mainQN = QNetwork(state.shape, action_size).to('cuda:0')
optimizer = optim.Adam(mainQN.parameters(), lr=learning_rate)
mainQN, optimizer = amp.initialize(mainQN, optimizer, opt_level="O1")  # --------------
if IQN == False:
    criterion = nn.MSELoss()
targetQN = mainQN
targetQN.eval()
mainQN.eval()
```

Model definition (swish is a custom activation function):

```python
class QNetwork(nn.Module):
    def __init__(self, num_inputs, num_outputs):
        super(QNetwork, self).__init__()
        self.LSTMs = []
        self.net_type = "noisy"
        self.hidden_size = 25*25    # for NAF; the final layer is nn.ZeroPad2d(1)
        self.hidden_size1 = 25*25
        self.hidden_size2 = None
        self.num_inputs = num_inputs
        self.cnn1 = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=(3, 3), padding=1),
            swish(0.7),
            ....
            nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2)))
        self.cnn1.apply(init_weights)
        self.free_net = nn.Sequential(
            nn.Linear(self.free_input(), self.hidden_size1),
            swish(0.7),
            nn.Linear(self.hidden_size1, self.hidden_size1),
            swish(0.7),
            nn.Linear(self.hidden_size1, self.hidden_size1),
            swish(0.7),
        )
        self.free_net.apply(init_weights)

    def forward(self, inputs, net):
        if net == "net_q":
            x, u = inputs
            x = x.to('cuda:0')
            u = u.to('cuda:0')
        else:
            x = inputs.to('cuda:0')
        # ------------------------------------
        x = self.cnn1(x)
        x = x.contiguous().view(-1, 1).T
        x = self.free_net(x)
        ....
```
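To see what that flatten in forward actually does to the shapes: `x.contiguous().view(-1, 1).T` folds the batch dimension into the features and always yields a single row of shape [1, B*C*H*W]. A small shape check, assuming from the error message below that cnn1 emits [B, 16, 100, 100] (16 * 100 * 100 = 160000 features per sample; this shape is an inference, not taken from the code):

```python
import torch

feat = torch.randn(32, 16, 100, 100)        # assumed cnn1 output for a batch of 32

flat = feat.contiguous().view(-1, 1).T      # the flatten used in forward()
print(flat.shape)                           # torch.Size([1, 5120000]) -> one giant row

flat_per_sample = feat.view(feat.size(0), -1)  # batch-preserving alternative
print(flat_per_sample.shape)                   # torch.Size([32, 160000])
```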
```
torch.Size([32, 3, 400, 400])
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-1-be2832c14163> in <module>
    436             trin.pioritized_experience_replay(batch_size, gamma, step=episode,
    437                                               state_size=state_, action_size=acthon,
--> 438                                               multireward_steps=multireward_steps)
    439         trin.Done(episode)
    440         mainQN.Done()

<ipython-input-1-be2832c14163> in pioritized_experience_replay(self, batch_size, gamma, step, state_size, action_size, multireward_steps)
    298         optimizer.zero_grad()
    299         print(inputs[0].shape)
--> 300         output = mainQN.forward(inputs,"net_q")
    301         if self.IQN==True:
    302             self.loss_IQN(target,output,weights)

<ipython-input-1-be2832c14163> in forward(self, inputs, net)
    204         x=self.cnn1(x)
    205         x=x.contiguous().view(-1, 1).T
--> 206         x=self.free_net(x)
    207
    208

~\Anaconda3\envs\pyflan\lib\site-packages\torch\nn\modules\module.py in __call__(self, *input, **kwargs)
    548             result = self._slow_forward(*input, **kwargs)
    549         else:
--> 550             result = self.forward(*input, **kwargs)
    551         for hook in self._forward_hooks.values():
    552             hook_result = hook(self, input, result)

~\Anaconda3\envs\pyflan\lib\site-packages\torch\nn\modules\container.py in forward(self, input)
     98     def forward(self, input):
     99         for module in self:
--> 100             input = module(input)
    101         return input
    102

~\Anaconda3\envs\pyflan\lib\site-packages\torch\nn\modules\module.py in __call__(self, *input, **kwargs)
    548             result = self._slow_forward(*input, **kwargs)
    549         else:
--> 550             result = self.forward(*input, **kwargs)
    551         for hook in self._forward_hooks.values():
    552             hook_result = hook(self, input, result)

~\Anaconda3\envs\pyflan\lib\site-packages\torch\nn\modules\linear.py in forward(self, input)
     85
     86     def forward(self, input):
---> 87         return F.linear(input, self.weight, self.bias)
     88
     89     def extra_repr(self):

~\Anaconda3\envs\pyflan\lib\site-packages\apex\amp\wrap.py in wrapper(*args, **kwargs)
     26                         args,
     27                         kwargs)
---> 28         return orig_fn(*new_args, **kwargs)
     29     return wrapper
     30

~\Anaconda3\envs\pyflan\lib\site-packages\torch\nn\functional.py in linear(input, weight, bias)
   1608     if input.dim() == 2 and bias is not None:
   1609         # fused op is marginally faster
-> 1610         ret = torch.addmm(bias, input, weight.t())
   1611     else:
   1612         output = input.matmul(weight.t())

~\Anaconda3\envs\pyflan\lib\site-packages\apex\amp\wrap.py in wrapper(*args, **kwargs)
     26                         args,
     27                         kwargs)
---> 28         return orig_fn(*new_args, **kwargs)
     29     return wrapper
     30

RuntimeError: size mismatch, m1: [1 x 5120000], m2: [160000 x 625] at C:/cb/pytorch_1000000000000/work/aten/src\THC/generic/THCTensorMathBlas.cu:283
```
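Decoding the mismatch: m1 [1 x 5120000] is the flattened activation (5120000 = 32 × 160000, i.e. the whole batch folded into one row), and m2 [160000 x 625] is the weight of the first nn.Linear(160000, 625) in free_net (625 = 25*25 = hidden_size1). With batch size 1 the single row happens to be exactly 160000 wide, which is why the single-sample case runs at all. A sketch of forward with a batch-preserving flatten instead, assuming everything else stays unchanged:

```python
def forward(self, inputs, net):
    if net == "net_q":
        x, u = inputs
        x = x.to('cuda:0')
        u = u.to('cuda:0')
    else:
        x = inputs.to('cuda:0')
    x = self.cnn1(x)
    x = x.contiguous().view(x.size(0), -1)  # [B, 160000]: one row per sample
    x = self.free_net(x)                    # now consistent for any batch size
    ....
```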