質問をすることでしか得られない、回答やアドバイスがある。

15分調べてもわからないことは、質問しよう!

新規登録して質問してみよう
ただいま回答率
85.40%
CUDA

CUDAは並列計算プラットフォームであり、Nvidia GPU(Graphics Processing Units)向けのプログラミングモデルです。CUDAは様々なプログラミング言語、ライブラリ、APIを通してNvidiaにインターフェイスを提供します。

Julia

Juliaとは、科学技術計算に特化した、高水準・高性能な動的プログラミング言語です。オープンソースとして公表されており、書き易く動きが早いことが特徴です。

Q&A

解決済

1回答

421閲覧

CUDA.jlにおけるdot演算子の挙動に関して

NULNUL

総合スコア2

CUDA

CUDAは並列計算プラットフォームであり、Nvidia GPU(Graphics Processing Units)向けのプログラミングモデルです。CUDAは様々なプログラミング言語、ライブラリ、APIを通してNvidiaにインターフェイスを提供します。

Julia

Juliaとは、科学技術計算に特化した、高水準・高性能な動的プログラミング言語です。オープンソースとして公表されており、書き易く動きが早いことが特徴です。

0グッド

0クリップ

投稿2023/10/30 10:31

実現したいこと

GPUを用いた行列演算において、dot演算子を用いることで少しでもメモリ消費を抑えて高速化したい。
後述する通りdot演算子を使わなければ問題ないが、dot演算子を使わないとメモリ消費量が増えてメモリ不足となるため。

前提

Julia(v1.8.5), CUDA.jl(v4.3.2)を用いています。
また実際のスクリプトでは数万*数万の行列を複数使っており、nn_predict内の処理ももう少し複雑ですが、エラーが発生する最低限のスクリプトとして記載しました。

発生している問題・エラーメッセージ

後述するコードにおいて、p_buf.z2が処理ごとに変化します。
本来は、0.0 → 0.0 → 何らかの値(*) → (*)と同じ値 になってほしいですが
例えば 0.0 → 1.2 → 22 → -8.7 となります。

該当のソースコード

julia

1using CSV, DataFrames 2using CUDA 3using Random, Distributions 4 5Random.seed!(48) 6 7mutable struct TrainParams 8 w::CuArray{Float32, 2} 9 b::CuArray{Float32, 2} 10end 11mutable struct BufParams 12 z1::CuArray{Float32, 2} 13 z2::CuArray{Float32, 2} 14 z3::CuArray{Float32, 2} 15end 16 17function nn_predict(p_buf::BufParams, p1::TrainParams, p2::TrainParams, p3::TrainParams, x_train::CuArray{Float32}) 18 println(sum(p_buf.z2)) 19 p_buf.z1 .= p1.w * x_train .+ p1.b; 20 println(sum(p_buf.z2)) 21 p_buf.z2 .= p2.w* p_buf.z1 .+ p2.b; 22 println(sum(p_buf.z2)) 23 p_buf.z3 .= p3.w * p_buf.z2 .+ p3.b; 24 println(sum(p_buf.z2)) 25 26 return nothing 27end 28 29 30dist = Normal(0, sqrt(2.0f0/1000)) 31x_train = cu(rand(dist, 1000, 10)) 32 33p1 = TrainParams( 34 cu(rand(dist, 1000, 1000)), 35 cu(rand(dist, 1000, 1)) 36 ) 37p2 = TrainParams( 38 cu(rand(dist, 1000, 1000)), 39 cu(rand(dist, 1000, 1)) 40 ) 41p3 = TrainParams( 42 cu(rand(dist, 1000, 1000)), 43 cu(rand(dist, 1000, 1)) 44 ) 45 46buf = CUDA.zeros(1000, 10) 47p_buf = BufParams( 48 buf, 49 buf, 50 buf 51 ) 52 53nn_predict(p_buf, p1, p2, p3, x_train)

試したこと

dot演算子ではなく代入にすれば、きちんと
0.0 → 0.0 → 何らかの値(*) → (*)と同じ値
となります。
また、CUDAを使わず全てArray{}で行うとdot演算子でもきちんと
0.0 → 0.0 → 何らかの値(*) → (*)と同じ値
となります。

質問

前述したようにdot演算子を使わないとメモリ不足になるため出来るだけdot演算子を使いたいのですが

  1. なぜdot演算子を使うと値が処理ごとに変化するのでしょうか?
  2. どのように修正すればdot演算子を使っても想定通りの挙動になるのでしょうか?
  3. もしくはdot演算子を使わなくてもメモリ消費量を抑える方法があれば教えてください

以上です。

よろしくお願いします。

気になる質問をクリップする

クリップした質問は、後からいつでもMYページで確認できます。

またクリップした質問に回答があった際、通知やメールを受け取ることができます。

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

guest

回答1

0

ベストアンサー

.=で値がそれぞれ変化するのは,例示されているプログラムからだけで言うと,

julia

1buf = CUDA.zeros(1000, 10) 2p_buf = BufParams( 3 buf, 4 buf, 5 buf 6 )

の部分で,z1, z2, z3が同じバッファに割り当てられているからだと思います。
この部分をそれぞれ領域を確保する

julia

1p_buf = BufParams( 2 CUDA.zeros(1000, 10), 3 CUDA.zeros(1000, 10), 4 CUDA.zeros(1000, 10) 5 )

にすると改善するのではないでしょうか?

=なら,その度に新規でメモリが割り合てられるので最初z1,z2,z3が同じバッファでも
代入時にそれぞれ別のメモリとなるので問題が露見しないことになります。

.=だと最初に確保したメモリをそのまま使い続けるので,z1,z2,z3が同じバッファを利用している場合,
それぞれの値の更新時に,z1,z2,z3の値が変化することになります。

蛇足

この後に書く事は余計なお世話なのですが,外部コンストラクタを使って

julia

1BufParams(r::Int, c::Int) = BufParams( 2 CUDA.zeros(r, c), CUDA.zeros(r, c), CUDA.zeros(r, c) 3)

のように定義しておくと,

julia

1p_buf = BufParams(1000, 10)

のような形で利用できて便利かもしれません。

投稿2023/10/30 13:46

編集2023/10/30 14:30
ujimushi_sradjp

総合スコア2133

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

NULNUL

2023/10/30 23:54

ありがとうございます。 `BufParams()`の初期化の時に代入としてメモリが確保されるのだと勝手に思っていました。 ご教示いただいた通りに修正したところ、想定通りの挙動となりました。 また、外部コンストラクタの書き方を知らなかったのでご教示くださりありがとうございました。 こちらに変更しようと思います。 その上で恐縮なのですが、差し支えなければ以下についてご教示いただけますと幸いです。 + 原因が「z1,z2,z3が同じバッファであること」でしたが、なぜCPUの時にはdot演算子でも問題なく動いたのでしょうか?どのようなキーワードで調べればいいかも分からないため、参考になるサイトを教えていただけるだけでもありがたいです。
ujimushi_sradjp

2023/10/31 00:46

基本的には返答しないのですが,所用で休暇をとっているので一言だけ。 Base.zeros(1000, 10)は型指定なしだとFloat64型の行列になる(CUDA.zerosはデフォルトがFoat32型)ので, コンストラクタの引数に渡した時にFloat64→Float32の型変換が発生し,Float32型の行列が新規に作成されているので, 見かけ上うまくいっていたち見えたのではないかと。 CPUでもbuf =zeros(Float32, 1000, 10)とかだと同じ問題が発生すると想像できます。
NULNUL

2023/10/31 02:41

理解できました。ありがとうございました。
guest

あなたの回答

tips

太字

斜体

打ち消し線

見出し

引用テキストの挿入

コードの挿入

リンクの挿入

リストの挿入

番号リストの挿入

表の挿入

水平線の挿入

プレビュー

15分調べてもわからないことは
teratailで質問しよう!

ただいまの回答率
85.40%

質問をまとめることで
思考を整理して素早く解決

テンプレート機能で
簡単に質問をまとめる

質問する

関連した質問