質問をすることでしか得られない、回答やアドバイスがある。

15分調べてもわからないことは、質問しよう!

新規登録して質問してみよう
ただいま回答率
85.46%

Q&A

0回答

551閲覧

DNNの層を増やしたら精度が落ちてしまう

mothi5656

総合スコア27

0グッド

0クリップ

投稿2020/11/23 06:07

下のコードはmnistデータセットを使って書かれている数字を判別するネットワークを作成したものなのです。2層のNNまでは精度が97%くらいだったのですが、3層以上にすると精度が大幅に落ちるのは何故でしょうか。単純に層を増やす以外に気をつけなければならないことがあるのでしょうか。

コード from tensorflow.examples.tutorials.mnist import input_data import tensorflow as tf import numpy as np import matplotlib.pyplot as plt #####ネットワークの定義##### x=tf.placeholder(tf.float32,[None,784]) y=tf.placeholder(tf.float32,[None,10]) W1=tf.Variable(tf.random_normal([784,1024],mean=0,stddev=0.1,dtype=tf.float32)) W2=tf.Variable(tf.random_normal([1024, 1024],mean=0,stddev=0.1,dtype=tf.float32)) W3=tf.Variable(tf.random_normal([1024,1024],mean=0,stddev=0.1,dtype=tf.float32)) W4=tf.Variable(tf.random_normal([1024,1024],mean=0,stddev=0.1,dtype=tf.float32)) W5=tf.Variable(tf.random_normal([1024,10],mean=0,stddev=0.1,dtype=tf.float32)) b1=tf.Variable(tf.zeros([1024])) b2=tf.Variable(tf.zeros([1024])) b3=tf.Variable(tf.zeros([1024])) b4=tf.Variable(tf.zeros([1024])) b5=tf.Variable(tf.zeros([10])) #ネットワーク定義 h1=tf.nn.relu(tf.matmul(x,W1)+b1) h2=tf.nn.relu(tf.matmul(h1,W2)+b2) h3=tf.nn.relu(tf.matmul(h2,W3)+b3) h4=tf.nn.relu(tf.matmul(h3,W4)+b4) y_=tf.nn.softmax(tf.matmul(h4,W5)+b5) #誤差関数定義(交差エントロピー) ε=1e-7 cross_entropy= -tf.reduce_sum(y*tf.log(y_+ε)) #訓練方法の定義 train_step=tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy) #精度計算の定義 correct_prediction = tf.equal(tf.argmax(y, 1),tf.argmax(y_,1)) accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float")) #セッションの開始 sess=tf.Session() sess.run(tf.global_variables_initializer()) #データ読み込み mnist = input_data.read_data_sets("MNIST_data/", one_hot=True) for i in range(25): plt.subplot(5,5,i+1) plt.xticks([]) plt.yticks([]) plt.grid(False) plt.imshow(np.reshape( mnist.train.images[i, :].astype(np.float32) , [28,28]), cmap=plt.get_cmap("gray")) plt.show() #トレーニング部分 for i in range(1000): batch_xs, batch_ys =mnist.train.next_batch(100) sess.run(train_step, feed_dict={x:batch_xs, y:batch_ys}) #精度の表示 test_acc = sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels}) print(test_acc) sess.close()

気になる質問をクリップする

クリップした質問は、後からいつでもMYページで確認できます。

またクリップした質問に回答があった際、通知やメールを受け取ることができます。

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

guest

あなたの回答

tips

太字

斜体

打ち消し線

見出し

引用テキストの挿入

コードの挿入

リンクの挿入

リストの挿入

番号リストの挿入

表の挿入

水平線の挿入

プレビュー

まだ回答がついていません

会員登録して回答してみよう

アカウントをお持ちの方は

15分調べてもわからないことは
teratailで質問しよう!

ただいまの回答率
85.46%

質問をまとめることで
思考を整理して素早く解決

テンプレート機能で
簡単に質問をまとめる

質問する

関連した質問