質問をすることでしか得られない、回答やアドバイスがある。

15分調べてもわからないことは、質問しよう!

ただいまの
回答率

90.53%

  • Python

    7883questions

    Pythonは、コードの読みやすさが特徴的なプログラミング言語の1つです。 強い型付け、動的型付けに対応しており、後方互換性がないバージョン2系とバージョン3系が使用されています。 商用製品の開発にも無料で使用でき、OSだけでなく仮想環境にも対応。Unicodeによる文字列操作をサポートしているため、日本語処理も標準で可能です。

  • TensorFlow

    664questions

tf.layers denseを使ったNNの重みWの取り出し方がわからない

受付中

回答 0

投稿

  • 評価
  • クリップ 0
  • VIEW 292

Dyn.Mat.Mec

score 2

tf.layers denseを使用してNNを実装しました。学習済みの重みWを取り出して回帰線を引きたいのですが、重みWの取り出し方がわかりません。コードは以下の通りです。

from collections import namedtuple
def build_neural_network(hidden_units1=100, hidden_units2=50, hidden_units3=10):
    tf.reset_default_graph()
    inputs = tf.placeholder(tf.float32, shape=[None, 2])
    labels = tf.placeholder(tf.float32, shape=[None, 1])
    learning_rate = tf.placeholder(tf.float32)
    is_training = tf.Variable(True, dtype=tf.bool)

    initializer = tf.contrib.layers.xavier_initializer()
    fc = tf.layers.dense(inputs, hidden_units1, activation=None, kernel_initializer=initializer)

    fc=tf.nn.relu(fc)

    fc = tf.layers.dense(fc, hidden_units2, activation=None, kernel_initializer=initializer)

    fc = tf.nn.relu(fc)

    fc = tf.layers.dense(fc, hidden_units3, activation=None, kernel_initializer=initializer)
    fc = tf.layers.batch_normalization(fc, training=is_training)
    fc = tf.nn.relu(fc)

    logits = tf.layers.dense(fc, 1, activation=None)
    cross_entropy = tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits)
    cost = tf.reduce_mean(cross_entropy)

    with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
        optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cost)

        predicted = tf.nn.sigmoid(logits)
        correct_pred = tf.equal(tf.round(predicted), labels)
        accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

        export_nodes = ['inputs', 'labels', 'learning_rate', 'is_training', 'logits', 'cost', 'optimizer', 'predicted', 'accuracy']
        Graph = namedtuple('Graph', export_nodes)
        local_dict = locals()
        graph = Graph(*[local_dict[each] for each in export_nodes])

    return graph

model = build_neural_network()


def get_batch(data_x,data_y,batch_size=32):
    batch_n=len(data_x)//batch_size
    for i in range(batch_n):
        batch_x=data_x[i*batch_size:(i+1)*batch_size]
        batch_y=data_y[i*batch_size:(i+1)*batch_size]

        yield batch_x,batch_y


epochs = 50
train_collect = 50
train_print=train_collect*2

learning_rate_value = 0.001
batch_size=16

x_collect = []
train_loss_collect = []
train_acc_collect = []
valid_loss_collect = []
valid_acc_collect = []

saver = tf.train.Saver()
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    iteration=0
    for e in range(epochs):
        for batch_x,batch_y in get_batch(train_x,train_y,batch_size):
            iteration+=1
            feed = {model.inputs: train_x,
                    model.labels: train_y,
                    model.learning_rate: learning_rate_value,
                    model.is_training:True
                   }

            train_loss, _, train_acc = sess.run([model.cost, model.optimizer, model.accuracy], feed_dict=feed)

            if iteration % train_collect == 0:
                x_collect.append(e)
                train_loss_collect.append(train_loss)
                train_acc_collect.append(train_acc)

                if iteration % train_print==0:
                     print("Epoch: {}/{}".format(e + 1, epochs),
                      "Train Loss: {:.4f}".format(train_loss),
                      "Train Acc: {:.4f}".format(train_acc))

                feed = {model.inputs: valid_x,
                        model.labels: valid_y,
                        model.is_training:False
                       }
                val_loss, val_acc = sess.run([model.cost, model.accuracy], feed_dict=feed)
                valid_loss_collect.append(val_loss)
                valid_acc_collect.append(val_acc)

                if iteration % train_print==0:
                    print("Epoch: {}/{}".format(e + 1, epochs),
                      "Validation Loss: {:.4f}".format(val_loss),
                      "Validation Acc: {:.4f}".format(val_acc))
        saver.save(sess, "./ex_nn.ckpt")
  • 気になる質問をクリップする

    クリップした質問は、後からいつでもマイページで確認できます。

    またクリップした質問に回答があった際、通知やメールを受け取ることができます。

    クリップを取り消します

  • 良い質問の評価を上げる

    以下のような質問は評価を上げましょう

    • 質問内容が明確
    • 自分も答えを知りたい
    • 質問者以外のユーザにも役立つ

    評価が高い質問は、TOPページの「注目」タブのフィードに表示されやすくなります。

    質問の評価を上げたことを取り消します

  • 評価を下げられる数の上限に達しました

    評価を下げることができません

    • 1日5回まで評価を下げられます
    • 1日に1ユーザに対して2回まで評価を下げられます

    質問の評価を下げる

    teratailでは下記のような質問を「具体的に困っていることがない質問」、「サイトポリシーに違反する質問」と定義し、推奨していません。

    • プログラミングに関係のない質問
    • やってほしいことだけを記載した丸投げの質問
    • 問題・課題が含まれていない質問
    • 意図的に内容が抹消された質問
    • 広告と受け取られるような投稿

    評価が下がると、TOPページの「アクティブ」「注目」タブのフィードに表示されにくくなります。

    質問の評価を下げたことを取り消します

    この機能は開放されていません

    評価を下げる条件を満たしてません

    評価を下げる理由を選択してください

    詳細な説明はこちら

    上記に当てはまらず、質問内容が明確になっていない質問には「情報の追加・修正依頼」機能からコメントをしてください。

    質問の評価を下げる機能の利用条件

    この機能を利用するためには、以下の事項を行う必要があります。

まだ回答がついていません

15分調べてもわからないことは、teratailで質問しよう!

  • ただいまの回答率 90.53%
  • 質問をまとめることで、思考を整理して素早く解決
  • テンプレート機能で、簡単に質問をまとめられる

関連した質問

同じタグがついた質問を見る

  • Python

    7883questions

    Pythonは、コードの読みやすさが特徴的なプログラミング言語の1つです。 強い型付け、動的型付けに対応しており、後方互換性がないバージョン2系とバージョン3系が使用されています。 商用製品の開発にも無料で使用でき、OSだけでなく仮想環境にも対応。Unicodeによる文字列操作をサポートしているため、日本語処理も標準で可能です。

  • TensorFlow

    664questions