質問をすることでしか得られない、回答やアドバイスがある。

15分調べてもわからないことは、質問しよう!

新規登録して質問してみよう
ただいま回答率
85.48%
Python

Pythonは、コードの読みやすさが特徴的なプログラミング言語の1つです。 強い型付け、動的型付けに対応しており、後方互換性がないバージョン2系とバージョン3系が使用されています。 商用製品の開発にも無料で使用でき、OSだけでなく仮想環境にも対応。Unicodeによる文字列操作をサポートしているため、日本語処理も標準で可能です。

Q&A

解決済

1回答

282閲覧

Tensorflowで予測を出す

Yhaya

総合スコア439

Python

Pythonは、コードの読みやすさが特徴的なプログラミング言語の1つです。 強い型付け、動的型付けに対応しており、後方互換性がないバージョン2系とバージョン3系が使用されています。 商用製品の開発にも無料で使用でき、OSだけでなく仮想環境にも対応。Unicodeによる文字列操作をサポートしているため、日本語処理も標準で可能です。

0グッド

0クリップ

投稿2019/03/04 07:26

編集2019/03/04 08:26

環境

  • Tensorflow 1.12.0

やりたいこと

MNISTを学習させたモデルを使って、予測結果を出力したい。

学習に使ったモデル

訓練データはKaggleのMNISTのデータを使っています。

python

1# 使う関数を定義 2def shuffle_batch(X, y, batch_size): 3 rnd_idx = np.random.permutation(len(X)) 4 n_batches = len(X) // batch_size 5 for batch_idx in np.array_split(rnd_idx, n_batches): 6 X_batch, y_batch = X[batch_idx], y[batch_idx] 7 yield X_batch, y_batch 8 9def leaky_relu(z, name=None): 10 return tf.maximum(0.01 * z, z, name=name) 11 12# モデル構築 13n_inputs = 28 * 28 # MNIST 14n_hidden1 = 300 15n_hidden2 = 100 16n_outputs = 10 17 18X = tf.placeholder(tf.float32, shape=(None, n_inputs), name="X") 19y = tf.placeholder(tf.int32, shape=(None), name="y") 20 21with tf.name_scope("dnn"): 22 hidden1 = tf.layers.dense(X, n_hidden1, activation=leaky_relu, name="hidden1") 23 hidden2 = tf.layers.dense(hidden1, n_hidden2, activation=leaky_relu, name="hidden2") 24 logits = tf.layers.dense(hidden2, n_outputs, name="outputs") 25with tf.name_scope("loss"): 26 xentropy = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y, logits=logits) 27 loss = tf.reduce_mean(xentropy, name="loss") 28 29learning_rate = 0.01 30with tf.name_scope("train"): 31 optimizer = tf.train.GradientDescentOptimizer(learning_rate) 32 training_op = optimizer.minimize(loss) 33with tf.name_scope("eval"): 34 correct = tf.nn.in_top_k(logits, y, 1) 35 accuracy = tf.reduce_mean(tf.cast(correct, tf.float32)) 36 37init = tf.global_variables_initializer() 38saver = tf.train.Saver() 39 40# 訓練 41n_epochs = 40 42batch_size = 50 43 44with tf.Session() as sess: 45 init.run() 46 for epoch in range(n_epochs): 47 for X_batch, y_batch in shuffle_batch(X_train, y_train, batch_size): 48 sess.run(training_op, feed_dict={X: X_batch, y: y_batch}) 49 if epoch % 5 == 0: 50 acc_batch = accuracy.eval(feed_dict={X: X_batch, y: y_batch}) 51 acc_valid = accuracy.eval(feed_dict={X: X_test, y: y_test}) 52 print(epoch, "Batch accuracy:", acc_batch, "Validation accuracy:", acc_valid) 53 54 save_path = saver.save(sess, "./my_model_final.ckpt")

訓練で得られたパラメータは、チェックポイントとして保存しました。

予測

python

1tf.reset_default_graph() 2 3X = tf.placeholder(tf.float32, shape=(None, n_inputs), name="X") 4y = tf.placeholder(tf.int32, shape=(None), name="y") 5 6with tf.name_scope("dnn"): 7 hidden1 = tf.layers.dense(X, n_hidden1, activation=leaky_relu, name="hidden1") 8 hidden2 = tf.layers.dense(hidden1, n_hidden2, activation=leaky_relu, name="hidden2") 9 logits = tf.layers.dense(hidden2, n_outputs, name="outputs") 10 11saver = tf.train.Saver() 12 13with tf.Session() as sess: 14 saver.restore(sess, './my_model_final.ckpt') 15 pred = logits.eval(feed_dict={X: X_train}) 16 17result = np.argmax(pred, axis=1)

トレーニングデータでは最終的に、精度がほとんど1になるのですが、上のコードを使って、トレーニングデータの予測を行って正解ラベルと比べるとほとんどあっていません。

どのようにすれば、モデルを予測に使えるのでしょうか?

気になる質問をクリップする

クリップした質問は、後からいつでもMYページで確認できます。

またクリップした質問に回答があった際、通知やメールを受け取ることができます。

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

tiitoi

2019/03/04 07:57

変数 pred はどこで定義されているんでしょうか?コードの一部ではなく、コピペしたら動くコード全体を貼ってください
Yhaya

2019/03/04 08:27

予測のSession部分を追加しました。データはKaggleにあるMNISTのデータを使っています。 よろしくお願いいたします。
tiitoi

2019/03/04 08:43 編集

質問欄のコードの見る限りでは問題があるようには見えません。 "コピペで完全に動くコード"を貼っていただけないと、こちらで動かせないので原因はわからないです。
Yhaya

2019/03/04 09:01

すみません。Kaggleのリンクを張るために最初から全実行したら、問題が発生しなかったので、いろいろやっているうちにカーネルが変なことになっていただけでした。 ご迷惑をおかけしました
guest

回答1

0

自己解決

Kaggleのリンクを張るために最初から全実行したら、問題が発生しなかったので、いろいろやっているうちにカーネルが変なことになっていただけでした。

投稿2019/03/04 09:01

Yhaya

総合スコア439

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

あなたの回答

tips

太字

斜体

打ち消し線

見出し

引用テキストの挿入

コードの挿入

リンクの挿入

リストの挿入

番号リストの挿入

表の挿入

水平線の挿入

プレビュー

15分調べてもわからないことは
teratailで質問しよう!

ただいまの回答率
85.48%

質問をまとめることで
思考を整理して素早く解決

テンプレート機能で
簡単に質問をまとめる

質問する

関連した質問