Tensorflowで、一般的なバッチでテストデータの精度を算出しているので以下のようにして実行できています。
Python
1test_acc = [] 2with tf.Session(graph=graph) as sess: 3 saver.restore(sess, tf.train.latest_checkpoint('checkpoint')) 4 test_state = sess.run(cell.zero_state(batch_size, tf.float32)) 5 for ii, (x,y) in enumerate(get_batches(test_x, test_y, batch_size), 1): 6 feed = {inputs_: x, 7 labels_: y[:,None], 8 keep_prob: 1, 9 initial_state: test_state} 10 batch_acc, test_state = sess.run([accuracy, final_state], feed_dict=feed) 11 test_acc.append(batch_acc) 12 print("Test Accuracy: {:.3f}".format(np.mean(test_acc)))
test_x.shape (2500, 200) test_y.shape (2500,)
今回は2クラス分類を行っています。
しかし、test_xの中の2、3個抽出して、抽出したデータがどのクラスに属するかを知りたい場合どのようにしたら求めることができるでしょうか。
よろしくお願い致します。
回答1件
あなたの回答
tips
プレビュー
バッドをするには、ログインかつ
こちらの条件を満たす必要があります。