質問をすることでしか得られない、回答やアドバイスがある。

15分調べてもわからないことは、質問しよう!

新規登録して質問してみよう
ただいま回答率
85.48%
Python

Pythonは、コードの読みやすさが特徴的なプログラミング言語の1つです。 強い型付け、動的型付けに対応しており、後方互換性がないバージョン2系とバージョン3系が使用されています。 商用製品の開発にも無料で使用でき、OSだけでなく仮想環境にも対応。Unicodeによる文字列操作をサポートしているため、日本語処理も標準で可能です。

Q&A

1回答

4074閲覧

tf.layers denseを使ったNNの重みWの取り出し方がわからない

Dyn.Mat.Mec

総合スコア8

Python

Pythonは、コードの読みやすさが特徴的なプログラミング言語の1つです。 強い型付け、動的型付けに対応しており、後方互換性がないバージョン2系とバージョン3系が使用されています。 商用製品の開発にも無料で使用でき、OSだけでなく仮想環境にも対応。Unicodeによる文字列操作をサポートしているため、日本語処理も標準で可能です。

0グッド

0クリップ

投稿2018/05/29 12:02

tf.layers denseを使用してNNを実装しました。学習済みの重みWを取り出して回帰線を引きたいのですが、重みWの取り出し方がわかりません。コードは以下の通りです。

Python

1 2from collections import namedtuple 3def build_neural_network(hidden_units1=100, hidden_units2=50, hidden_units3=10): 4 tf.reset_default_graph() 5 inputs = tf.placeholder(tf.float32, shape=[None, 2]) 6 labels = tf.placeholder(tf.float32, shape=[None, 1]) 7 learning_rate = tf.placeholder(tf.float32) 8 is_training = tf.Variable(True, dtype=tf.bool) 9 10 initializer = tf.contrib.layers.xavier_initializer() 11 fc = tf.layers.dense(inputs, hidden_units1, activation=None, kernel_initializer=initializer) 12 13 fc=tf.nn.relu(fc) 14 15 fc = tf.layers.dense(fc, hidden_units2, activation=None, kernel_initializer=initializer) 16 17 fc = tf.nn.relu(fc) 18 19 fc = tf.layers.dense(fc, hidden_units3, activation=None, kernel_initializer=initializer) 20 fc = tf.layers.batch_normalization(fc, training=is_training) 21 fc = tf.nn.relu(fc) 22 23 logits = tf.layers.dense(fc, 1, activation=None) 24 cross_entropy = tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits) 25 cost = tf.reduce_mean(cross_entropy) 26 27 with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)): 28 optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cost) 29 30 predicted = tf.nn.sigmoid(logits) 31 correct_pred = tf.equal(tf.round(predicted), labels) 32 accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32)) 33 34 export_nodes = ['inputs', 'labels', 'learning_rate', 'is_training', 'logits', 'cost', 'optimizer', 'predicted', 'accuracy'] 35 Graph = namedtuple('Graph', export_nodes) 36 local_dict = locals() 37 graph = Graph(*[local_dict[each] for each in export_nodes]) 38 39 return graph 40 41model = build_neural_network() 42 43 44def get_batch(data_x,data_y,batch_size=32): 45 batch_n=len(data_x)//batch_size 46 for i in range(batch_n): 47 batch_x=data_x[i*batch_size:(i+1)*batch_size] 48 batch_y=data_y[i*batch_size:(i+1)*batch_size] 49 50 yield batch_x,batch_y 51 52 53epochs = 50 54train_collect = 50 55train_print=train_collect*2 56 57learning_rate_value = 0.001 58batch_size=16 59 60x_collect = [] 61train_loss_collect = [] 62train_acc_collect = [] 63valid_loss_collect = [] 64valid_acc_collect = [] 65 66saver = tf.train.Saver() 67with tf.Session() as sess: 68 sess.run(tf.global_variables_initializer()) 69 iteration=0 70 for e in range(epochs): 71 for batch_x,batch_y in get_batch(train_x,train_y,batch_size): 72 iteration+=1 73 feed = {model.inputs: train_x, 74 model.labels: train_y, 75 model.learning_rate: learning_rate_value, 76 model.is_training:True 77 } 78 79 train_loss, _, train_acc = sess.run([model.cost, model.optimizer, model.accuracy], feed_dict=feed) 80 81 if iteration % train_collect == 0: 82 x_collect.append(e) 83 train_loss_collect.append(train_loss) 84 train_acc_collect.append(train_acc) 85 86 if iteration % train_print==0: 87 print("Epoch: {}/{}".format(e + 1, epochs), 88 "Train Loss: {:.4f}".format(train_loss), 89 "Train Acc: {:.4f}".format(train_acc)) 90 91 feed = {model.inputs: valid_x, 92 model.labels: valid_y, 93 model.is_training:False 94 } 95 val_loss, val_acc = sess.run([model.cost, model.accuracy], feed_dict=feed) 96 valid_loss_collect.append(val_loss) 97 valid_acc_collect.append(val_acc) 98 99 if iteration % train_print==0: 100 print("Epoch: {}/{}".format(e + 1, epochs), 101 "Validation Loss: {:.4f}".format(val_loss), 102 "Validation Acc: {:.4f}".format(val_acc)) 103 saver.save(sess, "./ex_nn.ckpt")

気になる質問をクリップする

クリップした質問は、後からいつでもMYページで確認できます。

またクリップした質問に回答があった際、通知やメールを受け取ることができます。

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

guest

回答1

0

tf.layers.Denseで重みを取り出せるのですが、
今は上記の通り小文字の方なのでしょうか?

投稿2018/11/21 08:41

tak__tak

総合スコア78

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

あなたの回答

tips

太字

斜体

打ち消し線

見出し

引用テキストの挿入

コードの挿入

リンクの挿入

リストの挿入

番号リストの挿入

表の挿入

水平線の挿入

プレビュー

まだベストアンサーが選ばれていません

会員登録して回答してみよう

アカウントをお持ちの方は

15分調べてもわからないことは
teratailで質問しよう!

ただいまの回答率
85.48%

質問をまとめることで
思考を整理して素早く解決

テンプレート機能で
簡単に質問をまとめる

質問する

関連した質問