回答編集履歴

2

2023/03/21 12:55

投稿

melian
melian

スコア19749

test CHANGED
@@ -6,12 +6,12 @@
6
6
  import torch
7
7
 
8
8
  index = torch.tensor([0, 3, 5, 10])
9
- data = torch.tensor([1,2,3,4,5,6,7,8,9,10], dtype=torch.float)
9
+ data= torch.tensor([[1,2,3,4,5,6,7,8,9,10], [10,9,8,7,6,5,4,3,2,1]], dtype=torch.float)
10
- out = torch.tensor([[torch.sum(i)] for i in torch.tensor_split(data, index) if len(i)])\
10
+ out = torch.cat([torch.sum(i, -1).reshape(-1, i.size(0))
11
+ for i in torch.tensor_split(data[None, :] if data.dim() == 1 else data, index, -1)
11
- .expand(-1, data.size(0))
12
+ if len(torch.flatten(i))]).T
12
13
  print(out)
13
14
 
14
- # tensor([[ 6., 6., 6., 6., 6., 6., 6., 6., 6., 6.],
15
+ # tensor([[ 6., 9., 40.],
15
- # [ 9., 9., 9., 9., 9., 9., 9., 9., 9., 9.],
16
- # [40., 40., 40., 40., 40., 40., 40., 40., 40., 40.]])
16
+ # [27., 13., 15.]])
17
17
  ```

1

2023/03/21 11:11

投稿

melian
melian

スコア19749

test CHANGED
@@ -7,7 +7,7 @@
7
7
 
8
8
  index = torch.tensor([0, 3, 5, 10])
9
9
  data = torch.tensor([1,2,3,4,5,6,7,8,9,10], dtype=torch.float)
10
- out = torch.tensor([[torch.sum(i)] for i in torch.tensor_split(data, index)[1:-1]])\
10
+ out = torch.tensor([[torch.sum(i)] for i in torch.tensor_split(data, index) if len(i)])\
11
11
  .expand(-1, data.size(0))
12
12
  print(out)
13
13