質問をすることでしか得られない、回答やアドバイスがある。

15分調べてもわからないことは、質問しよう!

新規登録して質問してみよう
ただいま回答率
85.35%
Rust

Rustは、MoFoが支援するプログラミング言語。高速性を維持しつつも、メモリ管理を安全に行うことが可能な言語です。同じコンパイル言語であるC言語やC++では困難だったマルチスレッドを実装しやすく、並行性という点においても優れています。

Q&A

解決済

1回答

963閲覧

rust ndarrayでf32およびf64に対して乱数の行列を生成する

BoKuToTuZenU

総合スコア51

Rust

Rustは、MoFoが支援するプログラミング言語。高速性を維持しつつも、メモリ管理を安全に行うことが可能な言語です。同じコンパイル言語であるC言語やC++では困難だったマルチスレッドを実装しやすく、並行性という点においても優れています。

0グッド

1クリップ

投稿2021/01/12 17:31

前提・実現したいこと

ndarray_randを用いて、中身が乱数の行列をf32, f64に対して同じトレイと境界を用いて生成したいと考えています。

発生している問題・エラーメッセージ

error[E0277]: the trait bound `ndarray_rand::rand_distr::Normal<{float}>: Distribution<T>` is not satisfied --> src/linear/sgd.rs:46:69 | 46 | let mut weight = Array2::<T>::random((1, input.shape()[0]), Normal::new(1., 1.).unwrap()); | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ the trait `Distribution<T>` is not implemented for `ndarray_rand::rand_distr::Normal<{float}>` | = help: the following implementations were found: <ndarray_rand::rand_distr::Normal<N> as Distribution<N>> = note: required by `ndarray_rand::RandomExt::random`

このエラーは自分で実装したtraitがfloatに対応していないというのがエラーの原因だと考えられるのですが、
このトレイとはf32やf64に対応するように実装してあるのにこのようなエラーが出た理由がわかりません

該当のソースコード

rust

1use ndarray_linalg::lapack::Lapack; 2use ndarray_linalg::types::Scalar; 3use ndarray_rand::rand_distr::Float as Float; 4 5trait Type: : Lapack + Scalar + Float {} 6 7impl Type for f32 {} 8impl Type for f64 {} 9 10let weight = Array2::random((1, input.shape()[0]), Normal::new(1., 1.).unwrap());

多分、f64やf32に対して個別に実装を行えば実装することは可能だとは思いますが、
あまりいい方法とは考えられないためどのように実装すればよいかわかりません。
よろしくお願い致します。

気になる質問をクリップする

クリップした質問は、後からいつでもMYページで確認できます。

またクリップした質問に回答があった際、通知やメールを受け取ることができます。

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

guest

回答1

0

ベストアンサー

仰るように、いちいち Distribution<T> トレイトと関連トレイトを実装するのも大変(すでに ndarray-rand クレートで実装されているものと衝突してエラーが出ることもある)なので、例えば以下のように値をキャストする方法を用いて生成するのはいかがでしょうか。

rust

1use ndarray_linalg::lapack::Lapack; 2use ndarray_linalg::types::Scalar; 3use ndarray_rand::rand_distr::Float as Float; 4use ndarray_rand::rand_distr::{Distribution, Normal}; 5use num_traits::{Num, NumCast}; 6 7pub trait Type: Lapack + Scalar + Float {} 8 9impl Type for f32 {} 10 11impl Type for f64 {} 12 13/// cast a numeric value with type T to one with U 14fn cast_t2u<T, U>(x: T) -> U 15where 16 T: Num + NumCast + Copy, 17 U: Num + NumCast + Copy, 18{ 19 U::from(x).unwrap() 20} 21 22/// generate an array with random values following the normal distribution 23fn generate_normal_distr_array<T, D, Sh>(shape: Sh, mean: f64, std_dev: f64) -> Array<T, D> 24where 25 T: Type, 26 D: Dimension, 27 Sh: ShapeBuilder<Dim = D>, 28{ 29 let mut rng = rand::thread_rng(); 30 let gen = Normal::new(mean, std_dev).unwrap(); 31 Array::<T, D>::zeros(shape).map(|_| cast_t2u(gen.sample(&mut rng))) 32} 33 34fn main() { 35 let input: Array2<f64> = Array2::zeros((10, 10)); 36 let weight: Array2<f64> = generate_normal_distr_array(input.raw_dim(), 1.0, 1.0); 37 println!("{:?}", weight); 38}

ndarray-randndarrayrandクレートのバージョンの組によってはうまく行かないときがあるので、ひとまず私の方で動作しているバージョンの組を以下に示します。

ndarray = "0.13.0" ndarray-rand = "0.11.0" rand = "0.7.3"

投稿2021/01/13 04:45

編集2021/01/13 04:47
Surpris

総合スコア106

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

あなたの回答

tips

太字

斜体

打ち消し線

見出し

引用テキストの挿入

コードの挿入

リンクの挿入

リストの挿入

番号リストの挿入

表の挿入

水平線の挿入

プレビュー

15分調べてもわからないことは
teratailで質問しよう!

ただいまの回答率
85.35%

質問をまとめることで
思考を整理して素早く解決

テンプレート機能で
簡単に質問をまとめる

質問する

関連した質問