pytorchで自動微分を行いたいときに、A1からA5の微分可能なテンソルを用意して、func内で縮約した結果に対してA1からA5で勾配を取得したいのですが、
辞書形式で、下のように関数名を取得すると、requires_gradが無効になってしまいます。
python
1Ai = locals().get(f'A{i+1}', 0)
下記のプログラムのように、縮約計算した後に勾配を得たいときに、動的に配列名を得る方法はありますか?
またA1~A5はグローバル変数として扱いたいです。
python
1import torch 2import numpy as np 3A1 = torch.arange(4).reshape(2,2).to(torch.float32) 4A2 = torch.arange(8).reshape(2,2,2).to(torch.float32) 5A3 = torch.arange(8).reshape(2,2,2).to(torch.float32) 6A4 = torch.arange(8).reshape(2,2,2).to(torch.float32) 7A5 = torch.arange(4).reshape(2,2).to(torch.float32) 8 9A1.requires_grad = True 10A2.requires_grad = True 11A3.requires_grad = True 12A4.requires_grad = True 13A5.requires_grad = True 14 15def func(A1,A2,A3,A4,A5): 16 I = torch.tensor([0, 1], dtype=torch.float32) 17 A = torch.einsum("ia,i->a",A1,I) 18 for i in range(1,4): 19 Ai = locals().get(f'A{i+1}', 0) 20 A = torch.einsum("a,aib,i->b",A, Ai,I) 21 y = torch.einsum("a,ai,i->...",A,A5,I) 22 return y 23 24y=func(A1,A2,A3,A4,A5) 25y.backward() 26print(A2.grad)

回答1件
あなたの回答
tips
プレビュー