下記コードでどんな算法を実現しているのでしょうか。
(t : Pytorch tensor )
Python
1while len(t.shape) > 1: 2 t = t.sum(-1)
👆はt.sum()と等価では❓
説明できる方是非宜しくお願いいたします。
気になる質問をクリップする
クリップした質問は、後からいつでもMYページで確認できます。
またクリップした質問に回答があった際、通知やメールを受け取ることができます。
バッドをするには、ログインかつ
こちらの条件を満たす必要があります。
回答1件
0
ベストアンサー
まあ,teratailで質問するよりは自分で例を作ってやってみることですよね。
3次元の4×3×2の行列を例に実行させてみます。
比較としてjulia言語のソースです。
juliaとpythonではreshapeの仕様が違うのでpythonに合わせます。
julia
1using PyCall 2 3py""" 4import torch 5 6t = torch.reshape(torch.arange(4.0*3.0*2.0), (4, 3, 2)) 7 8print("t=\n", t) 9print("t.sum()=", t.sum()) 10 11while len(t.shape) > 1: 12 t = t.sum(-1) 13 14print("while ...t.sum(-1) = ", t,"\n") 15""" 16 17a = reshape((1:(4*3*2)) .- 1.0, (2, 3, 4)) 18t = PermutedDimsArray(a, (3,2,1)) 19 20for k=1:4 21 println("t[$k,:,:]=$(t[k,:,:])") 22end 23 24println("(julia) sum 2,3次元方向の和=$(sum(t, dims=2:ndims(t)))") 25
結果は次の感じ。
julia> include("test20221029.jl") t= tensor([[[ 0., 1.], [ 2., 3.], [ 4., 5.]], [[ 6., 7.], [ 8., 9.], [10., 11.]], [[12., 13.], [14., 15.], [16., 17.]], [[18., 19.], [20., 21.], [22., 23.]]]) t.sum()= tensor(276.) while ...t.sum(-1) = tensor([ 15., 51., 87., 123.]) t[1,:,:]=[0.0 1.0; 2.0 3.0; 4.0 5.0] t[2,:,:]=[6.0 7.0; 8.0 9.0; 10.0 11.0] t[3,:,:]=[12.0 13.0; 14.0 15.0; 16.0 17.0] t[4,:,:]=[18.0 19.0; 20.0 21.0; 22.0 23.0] (julia) sum 2,3次元方向の和=[15.0; 51.0; 87.0; 123.0;;;]
処理としては「2次元目以降の和をとる」ということになるかと思います。
次は順番にt.sum(-1)
した例
python
1>>> t = torch.reshape(torch.arange(4.0*3.0*2.0), (4, 3, 2)) 2>>> t 3tensor([[[ 0., 1.], 4 [ 2., 3.], 5 [ 4., 5.]], 6 7 [[ 6., 7.], 8 [ 8., 9.], 9 [10., 11.]], 10 11 [[12., 13.], 12 [14., 15.], 13 [16., 17.]], 14 15 [[18., 19.], 16 [20., 21.], 17 [22., 23.]]]) 18>>> t.sum(-1) 19tensor([[ 1., 5., 9.], 20 [13., 17., 21.], 21 [25., 29., 33.], 22 [37., 41., 45.]]) 23>>> t.sum(-1).sum(-1) 24tensor([ 15., 51., 87., 123.])
投稿2022/10/29 12:39
総合スコア2091
あなたの回答
tips
太字
斜体
打ち消し線
見出し
引用テキストの挿入
コードの挿入
リンクの挿入
リストの挿入
番号リストの挿入
表の挿入
水平線の挿入
プレビュー
質問の解決につながる回答をしましょう。 サンプルコードなど、より具体的な説明があると質問者の理解の助けになります。 また、読む側のことを考えた、分かりやすい文章を心がけましょう。
バッドをするには、ログインかつ
こちらの条件を満たす必要があります。
2022/10/30 02:49