重みの可視化
今回keras,mnistを使い簡単に動かしました。
"model.get_layer('dense_1').get_weights()"で重みを取得したのですが、この取得した重みをどうにか画像として出力したいです。
理想はtensorboardへの出力ですが、ありあえずmatplotでの出力を試行錯誤しています。
ここで得た重みをどのように扱えばいいのか分からず苦戦しております。
この重みを画像として出力できれば、手書き数字のどの部分が反応しているか、目で見て分かるのではと考えています。
>>> model.get_layer('dense_1').get_weights() [array([[ 0.16572747, -0.10772504, 0.1211575 , ..., -0.10732491, 0.14570805, 0.01378143], [-0.17990576, 0.17675169, -0.206588 , ..., -0.13718925, -0.04183103, -0.27726373], [-0.05705023, 0.0617761 , 0.05799635, ..., -0.36197215, 0.11944304, -0.21759842], ..., [-0.1403506 , -0.17172877, -0.44814473, ..., 0.02769719, 0.12637761, 0.01180686], [ 0.046968 , -0.17755437, -0.07650904, ..., 0.03794758, -0.08226985, 0.0308189 ], [-0.2886293 , 0.0838939 , -0.01235085, ..., 0.12330087, -0.01136581, -0.08974177]], dtype=float32), array([-0.05677079, -0.11204855, -0.02062975, -0.08041384, 0.01099737, 0.00711802, -0.0169198 , -0.08662082, 0.25113648, 0.01298474], dtype=float32)]
import tensorflow as tf import datetime import numpy as np import matplotlib.pyplot as plt mnist = tf.keras.datasets.mnist (x_train, y_train),(x_test, y_test) = mnist.load_data() x_train, x_test = x_train / 255.0, x_test / 255.0 model = tf.keras.models.Sequential([ tf.keras.layers.Flatten(input_shape=(28, 28)), tf.keras.layers.Dense(512, activation=tf.nn.relu), tf.keras.layers.Dropout(0.2), tf.keras.layers.Dense(10, activation=tf.nn.softmax) ]) model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy']) log_dir="logs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M") file_writer = tf.summary.create_file_writer(log_dir) with file_writer.as_default(): images = np.reshape(x_train[0:10], (-1, 28, 28, 1)) tf.summary.image("train", images, max_outputs=10, step=1) tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1) model.fit(x=x_train, y=y_train, epochs=5, validation_data=(x_test, y_test), callbacks=[tensorboard_callback]) model.evaluate(x_test, y_test) model.get_layer('dense_1').get_weights()
環境
Anaconda3
python3.7.7
tensorflow2.3.0
keras2.4.0
あなたの回答
tips
プレビュー