回答編集履歴
2
test
CHANGED
@@ -6,12 +6,12 @@
|
|
6
6
|
import torch
|
7
7
|
|
8
8
|
index = torch.tensor([0, 3, 5, 10])
|
9
|
-
data
|
9
|
+
data= torch.tensor([[1,2,3,4,5,6,7,8,9,10], [10,9,8,7,6,5,4,3,2,1]], dtype=torch.float)
|
10
|
-
out = torch.t
|
10
|
+
out = torch.cat([torch.sum(i, -1).reshape(-1, i.size(0))
|
11
|
+
for i in torch.tensor_split(data[None, :] if data.dim() == 1 else data, index, -1)
|
11
|
-
|
12
|
+
if len(torch.flatten(i))]).T
|
12
13
|
print(out)
|
13
14
|
|
14
|
-
# tensor([[ 6.,
|
15
|
+
# tensor([[ 6., 9., 40.],
|
15
|
-
# [ 9., 9., 9., 9., 9., 9., 9., 9., 9., 9.],
|
16
|
-
# [
|
16
|
+
# [27., 13., 15.]])
|
17
17
|
```
|
1
test
CHANGED
@@ -7,7 +7,7 @@
|
|
7
7
|
|
8
8
|
index = torch.tensor([0, 3, 5, 10])
|
9
9
|
data = torch.tensor([1,2,3,4,5,6,7,8,9,10], dtype=torch.float)
|
10
|
-
out = torch.tensor([[torch.sum(i)] for i in torch.tensor_split(data, index)
|
10
|
+
out = torch.tensor([[torch.sum(i)] for i in torch.tensor_split(data, index) if len(i)])\
|
11
11
|
.expand(-1, data.size(0))
|
12
12
|
print(out)
|
13
13
|
|