train = tf.group(*updates) の*の役割がわかりません。
tf.reset_default_graph() x = tf.placeholder(tf.float32, name='x') t = tf.placeholder(tf.float32, name='t') W = tf.Variable(rng.uniform(low=-0.08, high=0.08, size=(2, 1)).astype('float32'), name='W') b = tf.Variable(np.zeros(1).astype('float32'), name='b') y = tf.nn.sigmoid(tf.matmul(x,W)+b) cost = -tf.reduce_mean(t*tf.log(tf.clip_by_value(y, 1e-10, 1.0)) + (1 - t)*tf.log(tf.clip_by_value(1 - y, 1e-10, 1.0))) gW, gb = tf.gradients(cost, [W, b]) updates = [ b.assign_add(-0.01*gb), W.assign_add(-0.01*gW) ] train = tf.group(*updates) train_X = np.array([[0, 1], [1, 0], [0, 0], [1, 1]]) train_y = np.array([[1], [1], [0], [1]]) with tf.Session() as sess: sess = tf.Session() sess.run(tf.global_variables_initializer()) for i in range(10000): _cost, _ = sess.run([cost, train], feed_dict={x: train_X, t: train_y}) if (i+1)%1000==0: print(_cost)
というコードが出て来て、
gW, gb = tf.gradients(cost, [W, b]) updates = [ b.assign_add(-0.01*gb), W.assign_add(-0.01*gW) ] train = tf.group(*updates)
の
train = tf.group(*updates)
の*はどういう役割を持つのでしょうか?
通用の*は掛け算を表すものですが、
このアスタリスクはリストを引き渡せるようにするもので合っていますか?
回答1件
あなたの回答
tips
プレビュー
バッドをするには、ログインかつ
こちらの条件を満たす必要があります。