Background / What I want to achieve
I want to use numba to speed up a matrix computation.
I read that specifying types in numba.jit makes it even faster, so I am trying that.
Problem / Error message
Type inference seems to be failing (?).
```
Compilation is falling back to object mode WITH looplifting enabled because Function "compute" failed type inference due to: Invalid use of Function(<built-in function zeros>) with argument(s) of type(s): ((Literal[int](10), array(float64, 1d, A)), dtype=Literal[str](f8))
 * parameterized
In definition 0:
    All templates rejected with literals.
In definition 1:
    All templates rejected without literals.
This error is usually caused by passing an argument of a type that is unsupported by the named function.
[1] During: resolving callee type: Function(<built-in function zeros>)
[2] During: typing of call at <ipython-input-45-5c375a54c6a4> (6)

File "<ipython-input-45-5c375a54c6a4>", line 6:
def compute(T):
    <source elided>
    dt = 0.001
    X = np.zeros((N,T),dtype='f8')
    ^

  @numba.jit('f8(f8[:])')
<ipython-input-45-5c375a54c6a4>:3: NumbaWarning: Compilation is falling back to object mode WITHOUT looplifting enabled because Function "compute" failed type inference due to: cannot determine Numba type of <class 'numba.dispatcher.LiftedLoop'>

File "<ipython-input-45-5c375a54c6a4>", line 12:
def compute(T):
    <source elided>
    for t in range(0,T-1,1):
    ^

  @numba.jit('f8(f8[:])')
C:\Users\takuy\Anaconda3\envs\keras_env\lib\site-packages\numba\object_mode_passes.py:178: NumbaWarning: Function "compute" was compiled in object mode without forceobj=True, but has lifted loops.

File "<ipython-input-45-5c375a54c6a4>", line 4:
@numba.jit('f8(f8[:])')
def compute(T):
^

  state.func_ir.loc))
C:\Users\takuy\Anaconda3\envs\keras_env\lib\site-packages\numba\object_mode_passes.py:187: NumbaDeprecationWarning: 
Fall-back from the nopython compilation path to the object mode compilation path has been detected, this is deprecated behaviour.

For more information visit http://numba.pydata.org/numba-doc/latest/reference/deprecation.html#deprecation-of-object-mode-fall-back-behaviour-when-using-jit

File "<ipython-input-45-5c375a54c6a4>", line 4:
@numba.jit('f8(f8[:])')
def compute(T):
^

  warnings.warn(errors.NumbaDeprecationWarning(msg, state.func_ir.loc))
<ipython-input-45-5c375a54c6a4>:16: NumbaPerformanceWarning: '@' is faster on contiguous arrays, called on (array(float64, 2d, C), array(float64, 1d, A))
  X_2[:,t+1] = -M_inv@K@X[:,t+1]
```
Relevant source code
```python
import numpy as np
from scipy import linalg

m=1
g=9.8
k=100
N=10

# Build the stiffness matrix
K=np.eye(N,dtype=float)*2-np.eye(N,k=1)-np.eye(N,k=-1)
K[0,0]=2;K[N-1,N-1]=1
M=np.eye(N,dtype=float)
K_inv=np.linalg.inv(K)
M_inv=np.linalg.inv(M)


import numba

# The problematic part
@numba.jit('f8(f8[:])')
def compute(T):
    dt = 0.001
    X = np.zeros((N,T),dtype='f8')
    V = np.zeros((N,T),dtype='f8')
    V[-1,0] = 1.0
    X_2 = np.zeros((N,T),dtype='f8')
    print(V)

    for t in range(0,T-1,1):

        X[:,t+1] = X[:,t] + dt*V[:,t]
        V[:,t+1] = V[:,t] + dt*X_2[:,t]
        X_2[:,t+1] = -M_inv@K@X[:,t+1]


    return (X)

%time X_=compute(int(1e4))
```
What I tried
I specified float64 as the type of every array, but type inference still seems to fail.
Additional information (framework/tool versions, etc.)
Using Jupyter Notebook.