Q&A
実現したいこと
ある配列data(mxd)とインデックス(n)が存在します.
indexはcumsum方式になっています. これはcsr_matrixのdataとindptrと同じです.
例えばdataはm=10, d=1とし[1,2,3,4,5,6,7,8,9,10]とします.
indexはi番目の要素がdataの何番目までかを示します.
例えばn=4のとき[0, 3, 5, 10]のようになっています. これは 出力の要素が3つあり, 1番目の要素が0から2まで, 2番目の要素が3から4まで, 3番目の要素が5から9までの要素を足し合わせることを指します.
よって出力は[6(1+2+3), 9(4+5), 40(6+7+8+9+10)]となります.
これをpytorchで計算したいのですが良い計算方法はないでしょうか. よろしくお願いします.
該当のソースコード
普通に実装すると以下のようになるのですがfor文を使わず関数で綺麗に書けないでしょうか
python
1out = torch.zeros((index.size(0)-1, data.size(1)) 2for i in range(index.size(0)-1): 3 out[i] = torch.mean(data[index[i]:index[i+1], dim=0)
python
1out = torch.zeros((index.size(0)-1, data.size(1)) 2target = torch.zeros_like(index) 3for i in range(index.size(0)-1): 4 target[index[i]:index[i+1]]=i 5out = torch_scatter.scatter(data, target, out=out, dim=0, reduce='mean')
補足情報(FW/ツールのバージョンなど)
python 3.10
pytorch 2.0.9
回答1件
あなたの回答
tips
プレビュー
下記のような回答は推奨されていません。
このような回答には修正を依頼しましょう。
2023/03/21 12:03
2023/03/21 12:57
2023/04/02 08:15