挙動は numpy.where と同じです。
tf.where | TensorFlow Core 1.13 | TensorFlow
If both x and y are None, then this operation returns the coordinates of true elements of condition. The coordinates are returned in a 2-D tensor where the first dimension (rows) represents the number of true elements, and the second dimension (columns) represents the coordinates of the true elements. Keep in mind, the shape of the output tensor can vary depending on how many true values there are in input. Indices are output in row-major order.
conditions の True の要素のインデックスを返します。
テンソルの形状は (True の数、conditions の次元数) の2次元テンソルになります。
以下の例では、conditions の True の要素のインデックスは (0, 0), (0, 2), (1, 0), (1, 2) の4つなので、(4, 2) のテンソルが返り値となります。
python
1import tensorflow as tf
2
3conditions = tf.constant(
4 [[True, False, True],
5 [True, True, True],
6 [False, False, False]]
7)
8
9b = tf.where(conditions)
10# [[0 0]
11# [0 2]
12# [1 0]
13# [1 2]]
If both non-None, x and y must have the same shape. The condition tensor must be a scalar if x and y are scalar. If x and y are vectors of higher rank, then condition must be either a vector with size matching the first dimension of x, or must have the same shape as x.
The condition tensor acts as a mask that chooses, based on the value at each element, whether the corresponding element / row in the output should be taken from x (if true) or y (if false).
x, y を指定した場合、conditions が True の要素はそれと同じインデックスの x の要素、conditions が False の要素はそれと同じインデックスの y の要素からなるテンソルを返します。
プログラミングで出てくる三項間演算子と同じ働きになります。
python
1import tensorflow as tf
2
3conditions = tf.constant(
4 [[True, False, True],
5 [True, True, True],
6 [False, False, False]]
7)
8
9x = tf.constant([[1, 2, 3],
10 [4, 5, 6],
11 [7, 8, 9]])
12y = tf.constant([[-1, -2, -3],
13 [-4, -5, -6],
14 [-7, -8, -9]])
15
16b = tf.where(conditions, x, y)
17# [[ 1 -2 3]
18# [ 4 5 6]
19# [-7 -8 -9]]
使われていた例について
Leaky ReLu 関数は微分すると、x >= 0 のときは1、そうでないときは α となるので、grad() で渡ってきたデルタδに対して、x>= 0
のときは δ1、そうでないときは δα して返す関数を tf.where() で実現している。
Deep Learning - 活性化関数
python
1def grad(dy):
2 dx = tf.where(y >= 0, dy, dy * alpha)
3 return dx, lambda ddx: tf.where(y >= 0, ddx, ddx * alpha)
バッドをするには、ログインかつ
こちらの条件を満たす必要があります。
2019/06/03 09:31