for
ループ内の処理をマルチスレッドで並列化するにはRayon クレートを使うのが便利です。これはfor
ループを並列イテレーター(parallel iterator)を使った形式(例:c.par_iter().map(|item| ... ).collect()
に書き換えるだけで、map()
に書いた処理をマルチスレッドで並列処理してくれるものです。ここではRayonを使った方法を紹介します。
ndarrayはCargo.tomlでrayon
フィーチャーを指定すると、Array3
などにpar_iter()
メソッドが追加され、直接Rayonの並列イテレーターを作れるようになります。ところが、ご質問のコードではこのフィーチャーは使えません。なぜならfor
の1回のループ内でアクセスする要素(ndarrayのスライス)が1つではなく、time - 1
とtime
で指定した2つだからです。Rayonにはこういう複数の要素ずつイテレートするために、par_windows()
が用意されていますが、ndarrayのrayon
フィーチャーはpar_windows()
をサポートしていません。
そこで、今回はc
のArray3<f32>
をまずVec<ArrayView2<&f32>>
に変換し、そこからRayonのpar_windows()
を使って2要素ずつの並列イテレーターを作ることにします。
まずは変換部分のコードです。
toml
1 # Cargo.tomlにrayonを追加する
2
3 [ dependencies ]
4 ndarray = "0.14.0"
5 rayon = "1.5.0"
rust
1 use ndarray :: { arr3 , s , Array , Array2 , Array3 } ;
2 use rayon :: prelude :: * ;
3
4 fn cal_v ( c : & Array3 < f32 > , dt : f32 ) -> Array3 < f32 > {
5 // Rayonのpar_windowsを使うため、一番外側の次元(1次元目)でイテレートし、
6 // Vec<ArrayView2<&f32>>に収集する
7 let c_view2_vec : Vec < _ > = c . outer_iter ( ) . collect ( ) ;
8
9 ...
10 }
次にpar_windows()
による並列計算です。
rust
1 fn cal_v ( c : & Array3 < f32 > , dt : f32 ) -> Array3 < f32 > {
2 ...
3
4 // Rayonのpar_windowsで並列計算する。結果をVec<Array2<f32>>に収集する
5 let v_array2_vec : Vec < Array2 < _ >> = c_view2_vec
6 . par_windows ( 2 ) // 2要素ずつイテレートする
7 . map ( | w | ( & w [ 0 ] - & w [ 1 ] ) / dt )
8 . collect ( ) ;
9
10 ...
11 }
これによりmap
内に書いた処理が複数のスレッドで並列計算されます。collect()
を使って計算結果をVec<Array2<f32>>
に収集します。
最後にVec<Array2<f32>>
からArray3<f32>
に変換します。が、私はndarrayを使うのが初めてなことがあり、簡単に書く方法がわかりませんでした。しかたがないのでご質問のコードを真似てfor
とassign
で変換しています。
rust
1 fn cal_v ( c : & Array3 < f32 > , dt : f32 ) -> Array3 < f32 > {
2 ...
3
4 // Vec<Array2<f32>>をArray3<f32>に変換する。もっと簡単に書く方法があるかも
5 let mut v = Array :: zeros ( c . raw_dim ( ) ) ;
6 for ( time , now_v ) in v_array2_vec . into_iter ( ) . enumerate ( ) {
7 v . slice_mut ( s! [ time + 1 , .. , .. , ] ) . assign ( & now_v ) ;
8 }
9 v
10 }
これで完成です。mainを含めた全体のプログラムは以下のようになりました。
rust
1 // Cargo.toml
2 //
3 // [dependencies]
4 // ndarray = "0.14.0"
5 // rayon = "1.5.0"
6
7 use ndarray :: { arr3 , s , Array , Array2 , Array3 } ;
8 use rayon :: prelude :: * ;
9
10 fn cal_v ( c : & Array3 < f32 > , dt : f32 ) -> Array3 < f32 > {
11 // Rayonのpar_windowsを使うため、一番外側の次元(1次元目)でイテレートし、
12 // Vec<ArrayView2<&f32>>に収集する
13 let c_view2_vec : Vec < _ > = c . outer_iter ( ) . collect ( ) ;
14
15 // Rayonのpar_windowsで並列計算する。結果をVec<Array2<f32>>に収集する
16 let v_array2_vec : Vec < Array2 < _ >> = c_view2_vec
17 . par_windows ( 2 )
18 . map ( | w | ( & w [ 0 ] - & w [ 1 ] ) / dt )
19 . collect ( ) ;
20
21 // Vec<Array2<f32>>をArray3<f32>に変換する。もっと簡単に書く方法があるかも
22 let mut v = Array :: zeros ( c . raw_dim ( ) ) ;
23 for ( time , now_v ) in v_array2_vec . into_iter ( ) . enumerate ( ) {
24 v . slice_mut ( s! [ time + 1 , .. , .. , ] ) . assign ( & now_v ) ;
25 }
26 v
27 }
28
29 fn main ( ) {
30 let c = arr3 ( & [
31 [ [ 1.0 , - 2.0 , 3.0 ] , [ - 1.0 , 2.0 , - 3.0 ] ] ,
32 [ [ 6.0 , - 5.0 , 4.0 ] , [ - 9.0 , 8.0 , - 7.0 ] ] ,
33 [ [ - 12.0 , 11.0 , - 10.0 ] , [ 15.0 , - 14.0 , 13.0 ] ] ,
34 ] ) ;
35 let v = cal_v ( & c , 0.25 ) ;
36 println! ( "{}" , v ) ;
37
38 let expected = arr3 ( & [
39 [ [ 0.0 , 0.0 , 0.0 ] , [ 0.0 , 0.0 , 0.0 ] ] ,
40 [ [ - 20.0 , 12.0 , - 4.0 ] , [ 32.0 , - 24.0 , 16.0 ] ] ,
41 [ [ 72.0 , - 64.0 , 56.0 ] , [ - 96.0 , 88.0 , - 80.0 ] ] ,
42 ] ) ;
43
44 assert_eq! ( v , expected ) ;
45 }
追記(2021年4月10日)
上で紹介した方法は、collect()
のたびに新しいVec<_>
を作るため、ご質問の元のコードに比べて、その分が余分な処理になっています。
ndarrayのドキュメントをもう少し読んでみたところ、もっと効率の良いやり方があることに気づきました。以下の2点を組み合わせます。
par_windows()
ではなく、c
のndarrayスライスを2つ作ってZip
で対にする
Zip
のpar_map_collect()
を使用する。これは、Rayonの並列イテレーター作成 → map
→ collect
をまとめてやってくれる。
この方法ならコードが簡単になるうえ、Vec<_>
を作る必要がないので効率よく処理できるはずです。
まず、par_map_collect()
を使うためにndarrayを最新の0.15.xにします。またrayon
フィーチャーを指定します。
toml
1 # Cargo.tomlを修正する
2
3 [ dependencies ]
4 ndarray = { version = "0.15.1" , features = [ "rayon" ] }
use
文とcal_v()
関数を修正します。
rust
1 use ndarray :: { arr3 , s , Array3 , Zip } ;
2
3 fn cal_v ( c : & Array3 < f32 > , dt : f32 ) -> Array3 < f32 > {
4 let c_len0 = c . shape ( ) [ 0 ] ;
5
6 // 1つ目のスライスは1次元目(time)の最初の要素から、最後の1つ前の要素まで
7 let c_before = c . slice ( s! [ .. ( c_len0 - 1 ) , .. , .. ] ) ;
8 // 2つ目のスライスは1次元目(time)の2番目の要素から、最後の要素まで
9 let c_now = c . slice ( s! [ 1 .. , .. , .. ] ) ;
10
11 // 2つのスライスをzipして、par_map_collectで並列処理する
12 Zip :: from ( c_before )
13 . and ( c_now )
14 . par_map_collect ( | before , now | ( before - now ) / dt )
15 }
ただし、この方法を使うと、ご質問のコードとは異なり、計算結果v
の1次元目の最初の要素(値が全て0
)が作られなくなります。つまり、以下のmain()
関数のexpected
のコメントアウトした行のデータがなくなります。
rust
1 fn main ( ) {
2 let c = arr3 ( & [
3 [ [ 1.0 , - 2.0 , 3.0 ] , [ - 1.0 , 2.0 , - 3.0 ] ] ,
4 [ [ 6.0 , - 5.0 , 4.0 ] , [ - 9.0 , 8.0 , - 7.0 ] ] ,
5 [ [ - 12.0 , 11.0 , - 10.0 ] , [ 15.0 , - 14.0 , 13.0 ] ] ,
6 ] ) ;
7 let v = cal_v ( & c , 0.25 ) ;
8 println! ( "{}" , v ) ;
9
10 let expected = arr3 ( & [
11 // [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], // この要素は作られなくなる
12 [ [ - 20.0 , 12.0 , - 4.0 ] , [ 32.0 , - 24.0 , 16.0 ] ] ,
13 [ [ 72.0 , - 64.0 , 56.0 ] , [ - 96.0 , 88.0 , - 80.0 ] ] ,
14 ] ) ;
15
16 assert_eq! ( v , expected ) ;
17 }