質問をすることでしか得られない、回答やアドバイスがある。

15分調べてもわからないことは、質問しよう!

ただいまの
回答率

87.49%

CPUによる計算環境において,JAXの行列計算速度がnumpyに劣ってしまう問題

解決済

回答 2

投稿 編集

  • 評価
  • クリップ 0
  • VIEW 147

score 1

前提・実現したいこと

私は機械学習のフレームであるjaxを学び始めております.
そこで,以下のkaggleにあるjaxの紹介を基に勉強しておりました.
https://www.kaggle.com/aakashnain/tf-jax-tutorials-part-4-jax-and-devicearray/notebook

jaxの強みは高速な自動微分であると思っていたのですが, numpyよりも遅くなってしまう事例が発生しました. 皆様には,なぜjaxの計算が遅くなってしまったのかについて, ご意見をお聞きしたいと思っております.
(追記 質問内容が意図と異なっていたため修正いたします)
google colab上でCPUによる計算を行わせたところ,jaxの行列計算速度がnumpyよりも遅くなってしまう事例が発生しました.
jaxの強みは高速な自動微分であると思っていたため,CPU上では計算速度が低くなってしまうことを不思議に思い,
また,jaxの知識が浅くよくわかっていないため,皆様へご意見をお聞きしたいと思います.
jaxはCPUの計算環境においては,計算速度がnumpyよりも劣ってしまうのでしょうか?

発生している問題・エラーメッセージ

https://www.kaggle.com/aakashnain/tf-jax-tutorials-part-4-jax-and-devicearray/notebook
の[12]における

# Dot product on ndarray
start_time = time.time()
res = np.dot(array1, array1)
print(f"Time taken by dot product op on ndarrays: {time.time()-start_time:.2f} seconds")

# Dot product on DeviceArray
start_time = time.time()
res = jnp.dot(array2, array2)
print(f"Time taken by dot product op on DeviceArrays: {time.time()-start_time:.2f} seconds")


は,ページ内では
Time taken by dot product op on ndarrays: 7.94 seconds
Time taken by dot product op on DeviceArrays: 0.02 seconds
とDeviceArrayの高速性が発揮していますが,

私の環境(Google Colab/ CPU)で同等のコードで実行したところ,
Time taken by dot product op on ndarrays: 14.60 seconds
Time taken by dot product op on DeviceArrays: 16.78 seconds
と,numpyよりも計算が遅い結果となってしまいました.

試したこと

私のコードは,全てからコピペしたものであり,全く同じコードです. それなのに,計算速度に違いが生まれてしまうことが不思議でなりません.
唯一出たエラーコードは [1]を実行したとき:/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:8: UserWarning: Config option use_jedi not recognized by IPCompleter. [2]を実行したとき:WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
jaxはCPUの計算環境においては,計算速度がnumpyよりも劣ってしまうのでしょうか?
お詳しい方がいらっしゃいましたら,ご意見をいただけますと幸いです.
よろしくお願いいたします.

  • 気になる質問をクリップする

    クリップした質問は、後からいつでもマイページで確認できます。

    またクリップした質問に回答があった際、通知やメールを受け取ることができます。

    クリップを取り消します

  • 良い質問の評価を上げる

    以下のような質問は評価を上げましょう

    • 質問内容が明確
    • 自分も答えを知りたい
    • 質問者以外のユーザにも役立つ

    評価が高い質問は、TOPページの「注目」タブのフィードに表示されやすくなります。

    質問の評価を上げたことを取り消します

  • 評価を下げられる数の上限に達しました

    評価を下げることができません

    • 1日5回まで評価を下げられます
    • 1日に1ユーザに対して2回まで評価を下げられます

    質問の評価を下げる

    teratailでは下記のような質問を「具体的に困っていることがない質問」、「サイトポリシーに違反する質問」と定義し、推奨していません。

    • プログラミングに関係のない質問
    • やってほしいことだけを記載した丸投げの質問
    • 問題・課題が含まれていない質問
    • 意図的に内容が抹消された質問
    • 過去に投稿した質問と同じ内容の質問
    • 広告と受け取られるような投稿

    評価が下がると、TOPページの「アクティブ」「注目」タブのフィードに表示されにくくなります。

    質問の評価を下げたことを取り消します

    この機能は開放されていません

    評価を下げる条件を満たしてません

    評価を下げる理由を選択してください

    詳細な説明はこちら

    上記に当てはまらず、質問内容が明確になっていない質問には「情報の追加・修正依頼」機能からコメントをしてください。

    質問の評価を下げる機能の利用条件

    この機能を利用するためには、以下の事項を行う必要があります。

質問への追記・修正、ベストアンサー選択の依頼

  • キャンセル

  • quickquip

    2021/10/14 11:42 編集

    > CPUで計算を行ったとしても,Jaxの計算速度が高速である,という考えだったため,
    質問の「jaxは計算の違いで」のあたりがそういう意図なのですね。
    その部分を(ppaulさんの回答と食い違わない範囲で)書き換えた方がいいかもと思います。今からでも。

    キャンセル

  • guratan

    2021/10/14 14:22

    >その部分を(ppaulさんの回答と食い違わない範囲で)書き換えた方がいいかもと思います。今からでも。
    おっしゃる通りです.私の質問文の説明が不足しておりました...
    そのように修正しようと思います.ご指摘ありがとうございます.

    キャンセル

回答 2

checkベストアンサー

+1

GPUもTPUも付いていないハードで実行するなら、JAXがインテルのMKLを使ったnumpyよりも遅いことはいくらでもあるでしょう。

JAX入門~高速なNumPyとして使いこなすためのチュートリアル~などには明確には書かれていませんが、結局jaxはGoogleが開発したのために開発されたライブラリで、GPUにも使えますという位置づけのように見えます。
インテルは自社のハード向けにMKLに資金を投入していますから、JAXでCPU向けにJITしたのではMKLとの勝負ではそれほど強いとは思えませんね。

JAXを使いたければればゲーミングPCに買い換えるか、Google Colabを使うべきでしょう。

投稿

  • 回答の評価を上げる

    以下のような回答は評価を上げましょう

    • 正しい回答
    • わかりやすい回答
    • ためになる回答

    評価が高い回答ほどページの上位に表示されます。

  • 回答の評価を下げる

    下記のような回答は推奨されていません。

    • 間違っている回答
    • 質問の回答になっていない投稿
    • スパムや攻撃的な表現を用いた投稿

    評価を下げる際はその理由を明確に伝え、適切な回答に修正してもらいましょう。

  • 2021/10/14 11:38

    ご返信ありがとうございます.
    やはり,CPUではjaxの性能を発揮できないようですね...
    大変参考になりました.

    キャンセル

  • 2021/10/14 17:34

    > CPUではjaxの性能を発揮できない

    https://qiita.com/koshian2/items/44a871386576b4f80aff
    の「パフォーマンス比較~CPU~」では、大きなサイズの配列の要素積では、jitありJAXの方がNumPyよりも高速だったと書かれてます
    CPUでも、JAXが優位になるような条件の場合は、使う価値はあると思いますよ

    キャンセル

0

google colab pro + GPU使用に変えたところ,
Time taken by dot product op on ndarrays: 14.93 seconds
Time taken by dot product op on DeviceArrays: 3.53 seconds
とjaxの計算速度が高速になりました.
jaxを使用するときは,機械学習なら当然ですがGPU/TPUを使用することが前提のようですね.

ご返信いただいた皆様,ありがとうございました.

投稿

  • 回答の評価を上げる

    以下のような回答は評価を上げましょう

    • 正しい回答
    • わかりやすい回答
    • ためになる回答

    評価が高い回答ほどページの上位に表示されます。

  • 回答の評価を下げる

    下記のような回答は推奨されていません。

    • 間違っている回答
    • 質問の回答になっていない投稿
    • スパムや攻撃的な表現を用いた投稿

    評価を下げる際はその理由を明確に伝え、適切な回答に修正してもらいましょう。

15分調べてもわからないことは、teratailで質問しよう!

  • ただいまの回答率 87.49%
  • 質問をまとめることで、思考を整理して素早く解決
  • テンプレート機能で、簡単に質問をまとめられる

関連した質問

同じタグがついた質問を見る