コード

Hi, I got this error and tried to fix it, but I couldn't. This is my code:

```python
def forward(self, i):
    print(i.shape)
    c1_mem = c1_spike = torch.zeros(batch_size, 64, 34, 34, device=device)
    c2_mem = c2_spike = torch.zeros(batch_size, 64, 17, 17, device=device)
    c3_mem = c3_spike = torch.zeros(batch_size, 128, 8, 8, device=device)
    h1_mem = h1_spike = h1_sumspike = torch.zeros(batch_size, 1024, device=device)
    # print(h1_mem.shape)
    h2_mem = h2_spike = h2_sumspike = torch.zeros(batch_size, 10, device=device)
    for idx in range(i.shape[-1]):
        x = i[:, :, :, :, idx].cuda()
        #print(x.shape)
        x = x > torch.rand(x.size(), device=device) # prob. firing
        print(x.shape)
        # print(x.shape)
        c1_mem, c1_spike = mem_update(self.conv1, x.float(), c1_mem, c1_spike)
        # print(c1_mem.shape)
        x = F.avg_pool2d(c1_spike, 2)
        c2_mem, c2_spike = mem_update(self.conv2, x, c2_mem, c2_spike)
        x = F.avg_pool2d(c2_spike, 2)
        c3_mem, c3_spike = mem_update(self.conv3, x, c3_mem, c3_spike)
        print(c3_mem.shape)
        x = F.avg_pool2d(c3_spike, 4)
        print(x.shape)
        x = x.view(x.size(0),-1)  # <-- line highlighted in the original post
        print(x.shape)
        h1_mem, h1_spike = mem_update(self.fc1, x, h1_mem, h1_spike)
        h1_sumspike += h1_spike
        h2_mem, h2_spike = mem_update(self.fc2, h1_spike, h2_mem, h2_spike)
        h2_sumspike += h2_spike
    outputs = h2_sumspike / i.shape[-1]
    return outputs
```

Thanks
あなたの回答
tips
プレビュー