pythonでnumpyのeinsumを使用するときに9個の配列を1つのデータとしてを一つずつ処理しているのですが、
これを複数入力データに対応できるようにバッチ処理を行いたいのですが、どのようにすればよいでしょうか。
バッチ処理していないeinsum計算が下記です。
python
1import numpy as np 2 3def einsum_test(x): 4 A1 = np.arange(4).reshape(2,2) 5 A2 = np.arange(8).reshape(2,2,2) 6 A3 = np.arange(4).reshape(2,2) 7 8 I0 = torch.tensor([1, 0]) 9 I1 = torch.tensor([0, 1]) 10 11 I = I0 if x[0] == 0 else I1 12 y = np.einsum("ia,i->a",A1,I) 13 I = I0 if x[1] == 0 else I1 14 y = np.einsum("a,aib,i->b",y,A2,I) 15 I = I0 if x[2] == 0 else I1 16 y = np.einsum("a,ai,i->...",y,A3,I) 17 18 return y 19 20x = np.array([1,1,1]) 21y = einsum_test(x) 22print(y)
output
1103
この処理を複数の入力で処理したいです。
xを多次元配列にして、これを入力としたいのですが、
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
とエラーが出てしまいます。
x
1x = np.array([[1,1,1],[0,1,0],[0,0,1]])
python
1import numpy as np 2 3def einsum_test_batch(x): 4 A1 = np.arange(4).reshape(2,2) 5 A2 = np.arange(8).reshape(2,2,2) 6 A3 = np.arange(4).reshape(2,2) 7 8 I0 = torch.tensor([1, 0]) 9 I1 = torch.tensor([0, 1]) 10 11 I = I0 if x[:,0] == 0 else I1 12 y = np.einsum("ia,i->a",A1,I) 13 I = I0 if x[:,1] == 0 else I1 14 y = np.einsum("a,aib,i->b",y,A2,I) 15 I = I0 if x[:,2] == 0 else I1 16 y = np.einsum("a,ai,i->...",y,A3,I) 17 18 return y 19 20x = np.array([[1,1,1],[0,1,0],[0,0,1]]) 21y = einsum_test_batch(x) 22print(y)
このプログラムを修正して下記の出力となるようにできますか?
expect
1103 214 319
回答1件
あなたの回答
tips
プレビュー