Tensorflowで正規分布からランダムな値のテンソルを出力する関数、truncated_normalについての質問です。
Pythonのバージョンは1.5.0です。
この関数の引数shapeに何を設定するのかをお聞きしたいです。
shapeの引数を[3,2]にすると、3行2列のテンソルが出力されます。
Python
1import tensorflow as tf 2 3ran = tf.truncated_normal(shape=[3, 2], stddev=0.1) 4 5sess = tf.Session() 6print(sess.run(ran))
出力
[[ 0.07918037 0.11460604]
[-0.07523992 0.03573691]
[ 0.17116651 0.01668654]]
問題なのはここからで、shapeの引数を[3,2,1]にすると出力されるテンソルが3行2列の形状ではなくなってしまいます。
Python
1import tensorflow as tf 2 3ran = tf.truncated_normal(shape=[3, 2, 1], stddev=0.1) 4 5sess = tf.Session() 6print(sess.run(ran))
出力
[[[ 0.05785213]
[-0.05744959]]
[[ 0.13790402]
[-0.02802232]]
[[ 0.01510559]
[ 0.1260141 ]]]
更に、shapeの引数を[3,2,1,1]にすると[3,2,1]のときと出力されるテンソルの形状が変わります。
Python
1import tensorflow as tf 2 3ran = tf.truncated_normal(shape=[3, 2, 1, 1], stddev=0.1) 4 5sess = tf.Session() 6print(sess.run(ran))
出力
[[[[ 0.02945449]]
[[-0.06374347]]]
[[[-0.06561454]]
[[-0.18820563]]]
[[[-0.00074543]]
[[ 0.13984145]]]]
shapeの配列の各要素は何を設定する値か、分かる方がおられましたら宜しくお願い致します。
回答1件
あなたの回答
tips
プレビュー
バッドをするには、ログインかつ
こちらの条件を満たす必要があります。
2020/01/12 08:47