質問編集履歴

1

内容の修正

2019/12/29 11:43

投稿

masaX
masaX

スコア30

test CHANGED
File without changes
test CHANGED
@@ -1,8 +1,8 @@
1
- # pytorch:次元の異なるtensorをtorch.catしたい
1
+ # pytorch:tensorをtorch.catしたい
2
2
 
3
3
  pytorchでプログラムを書いているのですが、torch.catができません。
4
4
 
5
- 以下のようなエラーが発生しているのですが、おそらくlistに含れてるテンソルの次元がべて一致していないことが原因だと思います
5
+ 以下のようなエラーが発生しています。
6
6
 
7
7
 
8
8
 
@@ -10,9 +10,7 @@
10
10
 
11
11
  ```
12
12
 
13
- RuntimeError: invalid argument 0: Tensors must have same number of dimensions: got 4 and 2 at /pytorch/aten/src/THC/generic/THCTensorMath.cu:102
13
+ TypeError: expected Variable as element 0 in argument 0, but got list
14
-
15
-
16
14
 
17
15
  ```
18
16
 
@@ -22,15 +20,25 @@
22
20
 
23
21
  ```
24
22
 
25
- p = []
23
+ p = []
26
24
 
27
25
  for i in range(0, N, self.gpu_batch):
28
26
 
29
- #p.append(self.model(stack[i:min(i + self.gpu_batch, N)]))
27
+ p.append(self.model(stack[i:min(i + self.gpu_batch, N)]))
30
-
31
- p += self.model(stack[i:min(i + self.gpu_batch, N)])
32
28
 
33
29
  p = torch.cat(p)
30
+
31
+ # Number of classes
32
+
33
+ CL = p.size(1)
34
+
35
+ sal = torch.matmul(p.data.transpose(0, 1), self.masks.view(N, H * W))
36
+
37
+ sal = sal.view((CL, H, W))
38
+
39
+ sal = sal / N / self.p1
40
+
41
+ return sal
34
42
 
35
43
  ```
36
44