質問をすることでしか得られない、回答やアドバイスがある。

15分調べてもわからないことは、質問しよう!

ただいまの
回答率

87.78%

@jitで引数の型を指定した時に出るエラーについて

受付中

回答 1

投稿

  • 評価
  • クリップ 1
  • VIEW 82

score 1

pythonのjitを使用して、処理速度を高速化しようとしています。
いくつか試して、高速化に成功した例もあるのですが
arrayを引数にした時に失敗します。

失敗例

import numpy as np
@jit('f8(i4[:],f8, f4[:, :])', nopython=True)
def f0(u,theta, para_list):

    dot_list=[]
    for j in range(4):
        a,b,c=para_list[j]
        dot_list.append(P(theta,a,b,c)**u[j]*((1-P(theta,a,b,c))**(1-u[j])))

    return np.prod(dot_list)


para_list2=np.array([[ 0.42,  -1.584,  0], [ 0.326 ,-1.726  ,0],[ 0.547, -2.235 , 0],[ 0.561, -1.861 , 0.]])

u=np.array([0,0,0,0])
f0(u, 0, para_list2)


実行結果

--------------------------------------------------------------------------
TypingError                               Traceback (most recent call last)
<ipython-input-29-0ff85004516e> in <module>
      1 import time
      2 @jit('f8(f8[:, :], f8[:, :])', nopython=True)
----> 3 def pairwise_numba2(X, D):
      4     M = X.shape[0]
      5     N = X.shape[1]

~/opt/anaconda3/lib/python3.8/site-packages/numba/core/decorators.py in wrapper(func)
    219             with typeinfer.register_dispatcher(disp):
    220                 for sig in sigs:
--> 221                     disp.compile(sig)
    222                 disp.disable_compile()
    223         return disp

~/opt/anaconda3/lib/python3.8/site-packages/numba/core/dispatcher.py in compile(self, sig)
    907                 with ev.trigger_event("numba:compile", data=ev_details):
    908                     try:
--> 909                         cres = self._compiler.compile(args, return_type)
    910                     except errors.ForceLiteralArg as e:
    911                         def folded(args, kws):

~/opt/anaconda3/lib/python3.8/site-packages/numba/core/dispatcher.py in compile(self, args, return_type)
     81             return retval
     82         else:
---> 83             raise retval
     84 
     85     def _compile_cached(self, args, return_type):

~/opt/anaconda3/lib/python3.8/site-packages/numba/core/dispatcher.py in _compile_cached(self, args, return_type)
     91 
     92         try:
---> 93             retval = self._compile_core(args, return_type)
     94         except errors.TypingError as e:
     95             self._failed_cache[key] = e

~/opt/anaconda3/lib/python3.8/site-packages/numba/core/dispatcher.py in _compile_core(self, args, return_type)
    104 
    105         impl = self._get_implementation(args, {})
--> 106         cres = compiler.compile_extra(self.targetdescr.typing_context,
    107                                       self.targetdescr.target_context,
    108                                       impl,

~/opt/anaconda3/lib/python3.8/site-packages/numba/core/compiler.py in compile_extra(typingctx, targetctx, func, args, return_type, flags, locals, library, pipeline_class)
    604     pipeline = pipeline_class(typingctx, targetctx, library,
    605                               args, return_type, flags, locals)
--> 606     return pipeline.compile_extra(func)
    607 
    608 

~/opt/anaconda3/lib/python3.8/site-packages/numba/core/compiler.py in compile_extra(self, func)
    351         self.state.lifted = ()
    352         self.state.lifted_from = None
--> 353         return self._compile_bytecode()
    354 
    355     def compile_ir(self, func_ir, lifted=(), lifted_from=None):

~/opt/anaconda3/lib/python3.8/site-packages/numba/core/compiler.py in _compile_bytecode(self)
    413         """
    414         assert self.state.func_ir is None
--> 415         return self._compile_core()
    416 
    417     def _compile_ir(self):

~/opt/anaconda3/lib/python3.8/site-packages/numba/core/compiler.py in _compile_core(self)
    393                 self.state.status.fail_reason = e
    394                 if is_final_pipeline:
--> 395                     raise e
    396         else:
    397             raise CompilerError("All available pipelines exhausted")

~/opt/anaconda3/lib/python3.8/site-packages/numba/core/compiler.py in _compile_core(self)
    384             res = None
    385             try:
--> 386                 pm.run(self.state)
    387                 if self.state.cr is not None:
    388                     break

~/opt/anaconda3/lib/python3.8/site-packages/numba/core/compiler_machinery.py in run(self, state)
    337                     (self.pipeline_name, pass_desc)
    338                 patched_exception = self._patch_error(msg, e)
--> 339                 raise patched_exception
    340 
    341     def dependency_analysis(self):

~/opt/anaconda3/lib/python3.8/site-packages/numba/core/compiler_machinery.py in run(self, state)
    328                 pass_inst = _pass_registry.get(pss).pass_inst
    329                 if isinstance(pass_inst, CompilerPass):
--> 330                     self._runPass(idx, pass_inst, state)
    331                 else:
    332                     raise BaseException("Legacy pass in use")

~/opt/anaconda3/lib/python3.8/site-packages/numba/core/compiler_lock.py in _acquire_compile_lock(*args, **kwargs)
     33         def _acquire_compile_lock(*args, **kwargs):
     34             with self:
---> 35                 return func(*args, **kwargs)
     36         return _acquire_compile_lock
     37 

~/opt/anaconda3/lib/python3.8/site-packages/numba/core/compiler_machinery.py in _runPass(self, index, pss, internal_state)
    287             mutated |= check(pss.run_initialization, internal_state)
    288         with SimpleTimer() as pass_time:
--> 289             mutated |= check(pss.run_pass, internal_state)
    290         with SimpleTimer() as finalize_time:
    291             mutated |= check(pss.run_finalizer, internal_state)

~/opt/anaconda3/lib/python3.8/site-packages/numba/core/compiler_machinery.py in check(func, compiler_state)
    260 
    261         def check(func, compiler_state):
--> 262             mangled = func(compiler_state)
    263             if mangled not in (True, False):
    264                 msg = ("CompilerPass implementations should return True/False. "

~/opt/anaconda3/lib/python3.8/site-packages/numba/core/typed_passes.py in run_pass(self, state)
    102                               % (state.func_id.func_name,)):
    103             # Type inference
--> 104             typemap, return_type, calltypes, errs = type_inference_stage(
    105                 state.typingctx,
    106                 state.func_ir,

~/opt/anaconda3/lib/python3.8/site-packages/numba/core/typed_passes.py in type_inference_stage(typingctx, interp, args, return_type, locals, raise_errors)
     80         infer.build_constraint()
     81         # return errors in case of partial typing
---> 82         errs = infer.propagate(raise_errors=raise_errors)
     83         typemap, restype, calltypes = infer.unify(raise_errors=raise_errors)
     84 

~/opt/anaconda3/lib/python3.8/site-packages/numba/core/typeinfer.py in propagate(self, raise_errors)
   1069                                   if isinstance(e, ForceLiteralArg)]
   1070                 if not force_lit_args:
-> 1071                     raise errors[0]
   1072                 else:
   1073                     raise reduce(operator.or_, force_lit_args)

TypingError: Failed in nopython mode pipeline (step: nopython frontend)
No conversion from array(float64, 2d, A) to float64 for '$124return_value.1', defined at None

File "<ipython-input-29-0ff85004516e>", line 14:
def pairwise_numba2(X, D):
    <source elided>

    return D
    ^

During: typing of assignment at <ipython-input-29-0ff85004516e> (14)

File "<ipython-input-29-0ff85004516e>", line 14:
def pairwise_numba2(X, D):
    <source elided>

    return D
    ^

因みに成功例は

@jit('f8(f8, f8, f8, f8)')
def P(theta,a,b,c):
    return c+(1-c)*(1/(1+math.exp(-1.7*a*(theta-b))))

a,b,c=[ 0.42,  -1.584,  0]
P(0,a,b,c)


結果

0.7560189714593805

どなたか助けていただけると嬉しいです。

  • 気になる質問をクリップする

    クリップした質問は、後からいつでもマイページで確認できます。

    またクリップした質問に回答があった際、通知やメールを受け取ることができます。

    クリップを取り消します

  • 良い質問の評価を上げる

    以下のような質問は評価を上げましょう

    • 質問内容が明確
    • 自分も答えを知りたい
    • 質問者以外のユーザにも役立つ

    評価が高い質問は、TOPページの「注目」タブのフィードに表示されやすくなります。

    質問の評価を上げたことを取り消します

  • 評価を下げられる数の上限に達しました

    評価を下げることができません

    • 1日5回まで評価を下げられます
    • 1日に1ユーザに対して2回まで評価を下げられます

    質問の評価を下げる

    teratailでは下記のような質問を「具体的に困っていることがない質問」、「サイトポリシーに違反する質問」と定義し、推奨していません。

    • プログラミングに関係のない質問
    • やってほしいことだけを記載した丸投げの質問
    • 問題・課題が含まれていない質問
    • 意図的に内容が抹消された質問
    • 過去に投稿した質問と同じ内容の質問
    • 広告と受け取られるような投稿

    評価が下がると、TOPページの「アクティブ」「注目」タブのフィードに表示されにくくなります。

    質問の評価を下げたことを取り消します

    この機能は開放されていません

    評価を下げる条件を満たしてません

    評価を下げる理由を選択してください

    詳細な説明はこちら

    上記に当てはまらず、質問内容が明確になっていない質問には「情報の追加・修正依頼」機能からコメントをしてください。

    質問の評価を下げる機能の利用条件

    この機能を利用するためには、以下の事項を行う必要があります。

質問への追記・修正、ベストアンサー選択の依頼

  • bsdfan

    2021/07/23 13:00

    例のコードと結果が一致してないので、正確には分かりませんが、型を指定したらエラーが出るということは、型に矛盾があるのではないでしょうか。

    キャンセル

回答 1

0

最後の要素が0.になっているのが気になります・・・

[ 0.561, -1.861 , 0.]

投稿

  • 回答の評価を上げる

    以下のような回答は評価を上げましょう

    • 正しい回答
    • わかりやすい回答
    • ためになる回答

    評価が高い回答ほどページの上位に表示されます。

  • 回答の評価を下げる

    下記のような回答は推奨されていません。

    • 間違っている回答
    • 質問の回答になっていない投稿
    • スパムや攻撃的な表現を用いた投稿

    評価を下げる際はその理由を明確に伝え、適切な回答に修正してもらいましょう。

15分調べてもわからないことは、teratailで質問しよう!

  • ただいまの回答率 87.78%
  • 質問をまとめることで、思考を整理して素早く解決
  • テンプレート機能で、簡単に質問をまとめられる

関連した質問

同じタグがついた質問を見る