前提・実現したいこと
私は機械学習のフレームであるjaxを学び始めております.
そこで,以下のkaggleにあるjaxの紹介を基に勉強しておりました.
https://www.kaggle.com/aakashnain/tf-jax-tutorials-part-4-jax-and-devicearray/notebook
jaxの強みは高速な自動微分であると思っていたのですが,
numpyよりも遅くなってしまう事例が発生しました.
皆様には,なぜjaxの計算が遅くなってしまったのかについて,
ご意見をお聞きしたいと思っております.
(追記 質問内容が意図と異なっていたため修正いたします)
google colab上でCPUによる計算を行わせたところ,jaxの行列計算速度がnumpyよりも遅くなってしまう事例が発生しました.
jaxの強みは高速な自動微分であると思っていたため,CPU上では計算速度が低くなってしまうことを不思議に思い,
また,jaxの知識が浅くよくわかっていないため,皆様へご意見をお聞きしたいと思います.
jaxはCPUの計算環境においては,計算速度がnumpyよりも劣ってしまうのでしょうか?
発生している問題・エラーメッセージ
https://www.kaggle.com/aakashnain/tf-jax-tutorials-part-4-jax-and-devicearray/notebook
の[12]における
# Dot product on ndarray start_time = time.time() res = np.dot(array1, array1) print(f"Time taken by dot product op on ndarrays: {time.time()-start_time:.2f} seconds") # Dot product on DeviceArray start_time = time.time() res = jnp.dot(array2, array2) print(f"Time taken by dot product op on DeviceArrays: {time.time()-start_time:.2f} seconds")
は,ページ内では
Time taken by dot product op on ndarrays: 7.94 seconds
Time taken by dot product op on DeviceArrays: 0.02 seconds
とDeviceArrayの高速性が発揮していますが,
私の環境(Google Colab/ CPU)で同等のコードで実行したところ,
Time taken by dot product op on ndarrays: 14.60 seconds
Time taken by dot product op on DeviceArrays: 16.78 seconds
と,numpyよりも計算が遅い結果となってしまいました.
試したこと
私のコードは,全てからコピペしたものであり,全く同じコードです.
それなのに,計算速度に違いが生まれてしまうことが不思議でなりません.
唯一出たエラーコードは
[1]を実行したとき:/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:8: UserWarning: Config option use_jedi
not recognized by IPCompleter
.
[2]を実行したとき:WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
jaxはCPUの計算環境においては,計算速度がnumpyよりも劣ってしまうのでしょうか?
お詳しい方がいらっしゃいましたら,ご意見をいただけますと幸いです.
よろしくお願いいたします.
回答2件
あなたの回答
tips
プレビュー