質問をすることでしか得られない、回答やアドバイスがある。

15分調べてもわからないことは、質問しよう!

新規登録して質問してみよう
ただいま回答率
87.20%
Python

Pythonは、コードの読みやすさが特徴的なプログラミング言語の1つです。 強い型付け、動的型付けに対応しており、後方互換性がないバージョン2系とバージョン3系が使用されています。 商用製品の開発にも無料で使用でき、OSだけでなく仮想環境にも対応。Unicodeによる文字列操作をサポートしているため、日本語処理も標準で可能です。

受付中

tf.layers denseを使ったNNの重みWの取り出し方がわからない

Dyn.Mat.Mec
Dyn.Mat.Mec

総合スコア8

Python

Pythonは、コードの読みやすさが特徴的なプログラミング言語の1つです。 強い型付け、動的型付けに対応しており、後方互換性がないバージョン2系とバージョン3系が使用されています。 商用製品の開発にも無料で使用でき、OSだけでなく仮想環境にも対応。Unicodeによる文字列操作をサポートしているため、日本語処理も標準で可能です。

1回答

0評価

0クリップ

3541閲覧

投稿2018/05/29 12:02

tf.layers denseを使用してNNを実装しました。学習済みの重みWを取り出して回帰線を引きたいのですが、重みWの取り出し方がわかりません。コードは以下の通りです。

Python

from collections import namedtuple def build_neural_network(hidden_units1=100, hidden_units2=50, hidden_units3=10): tf.reset_default_graph() inputs = tf.placeholder(tf.float32, shape=[None, 2]) labels = tf.placeholder(tf.float32, shape=[None, 1]) learning_rate = tf.placeholder(tf.float32) is_training = tf.Variable(True, dtype=tf.bool) initializer = tf.contrib.layers.xavier_initializer() fc = tf.layers.dense(inputs, hidden_units1, activation=None, kernel_initializer=initializer) fc=tf.nn.relu(fc) fc = tf.layers.dense(fc, hidden_units2, activation=None, kernel_initializer=initializer) fc = tf.nn.relu(fc) fc = tf.layers.dense(fc, hidden_units3, activation=None, kernel_initializer=initializer) fc = tf.layers.batch_normalization(fc, training=is_training) fc = tf.nn.relu(fc) logits = tf.layers.dense(fc, 1, activation=None) cross_entropy = tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits) cost = tf.reduce_mean(cross_entropy) with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)): optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cost) predicted = tf.nn.sigmoid(logits) correct_pred = tf.equal(tf.round(predicted), labels) accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32)) export_nodes = ['inputs', 'labels', 'learning_rate', 'is_training', 'logits', 'cost', 'optimizer', 'predicted', 'accuracy'] Graph = namedtuple('Graph', export_nodes) local_dict = locals() graph = Graph(*[local_dict[each] for each in export_nodes]) return graph model = build_neural_network() def get_batch(data_x,data_y,batch_size=32): batch_n=len(data_x)//batch_size for i in range(batch_n): batch_x=data_x[i*batch_size:(i+1)*batch_size] batch_y=data_y[i*batch_size:(i+1)*batch_size] yield batch_x,batch_y epochs = 50 train_collect = 50 train_print=train_collect*2 learning_rate_value = 0.001 batch_size=16 x_collect = [] train_loss_collect = [] train_acc_collect = [] valid_loss_collect = [] valid_acc_collect = [] saver = tf.train.Saver() with tf.Session() as sess: sess.run(tf.global_variables_initializer()) iteration=0 for e in range(epochs): for batch_x,batch_y in get_batch(train_x,train_y,batch_size): iteration+=1 feed = {model.inputs: train_x, model.labels: train_y, model.learning_rate: learning_rate_value, model.is_training:True } train_loss, _, train_acc = sess.run([model.cost, model.optimizer, model.accuracy], feed_dict=feed) if iteration % train_collect == 0: x_collect.append(e) train_loss_collect.append(train_loss) train_acc_collect.append(train_acc) if iteration % train_print==0: print("Epoch: {}/{}".format(e + 1, epochs), "Train Loss: {:.4f}".format(train_loss), "Train Acc: {:.4f}".format(train_acc)) feed = {model.inputs: valid_x, model.labels: valid_y, model.is_training:False } val_loss, val_acc = sess.run([model.cost, model.accuracy], feed_dict=feed) valid_loss_collect.append(val_loss) valid_acc_collect.append(val_acc) if iteration % train_print==0: print("Epoch: {}/{}".format(e + 1, epochs), "Validation Loss: {:.4f}".format(val_loss), "Validation Acc: {:.4f}".format(val_acc)) saver.save(sess, "./ex_nn.ckpt")

良い質問の評価を上げる

以下のような質問は評価を上げましょう

  • 質問内容が明確
  • 自分も答えを知りたい
  • 質問者以外のユーザにも役立つ

評価が高い質問は、TOPページの「注目」タブのフィードに表示されやすくなります。

気になる質問をクリップする

クリップした質問は、後からいつでもマイページで確認できます。

またクリップした質問に回答があった際、通知やメールを受け取ることができます。

teratailでは下記のような質問を「具体的に困っていることがない質問」、「サイトポリシーに違反する質問」と定義し、推奨していません。

  • プログラミングに関係のない質問
  • やってほしいことだけを記載した丸投げの質問
  • 問題・課題が含まれていない質問
  • 意図的に内容が抹消された質問
  • 過去に投稿した質問と同じ内容の質問
  • 広告と受け取られるような投稿

評価を下げると、トップページの「アクティブ」「注目」タブのフィードに表示されにくくなります。

まだ回答がついていません

会員登録して回答してみよう

15分調べてもわからないことは
teratailで質問しよう!

ただいまの回答率
87.20%

質問をまとめることで
思考を整理して素早く解決

テンプレート機能で
簡単に質問をまとめる

質問する

関連した質問

同じタグがついた質問を見る

Python

Pythonは、コードの読みやすさが特徴的なプログラミング言語の1つです。 強い型付け、動的型付けに対応しており、後方互換性がないバージョン2系とバージョン3系が使用されています。 商用製品の開発にも無料で使用でき、OSだけでなく仮想環境にも対応。Unicodeによる文字列操作をサポートしているため、日本語処理も標準で可能です。