KerasとTensorBoardを使ってMNIST解析をしているのですが、TensorBoardのEmbedding Visualizationを上手く表示できない状態です。
いくつか質問があります。
1.まずKerasにはEmbedding layerというのがありますが、これは具体的にどんな役割を果たすレイヤーなのでしょうか。
これがそもそもわかっていません。オライリーの「ゼロから作るDeep Learning」を使ってNNの勉強を進めていたのですが、そこにも出てこず、「Embedding layer」で検索してもKeras関連のものしか出てこず、Kerasのオリジナルなのでしょうか、その役割がわかりません。
2.次に、KerasのEmbeddingレイヤーのパラメータの設定の仕方が、公式ドキュメントを読んでもいまいち理解できず、Embedding層を挿入するとエラーが出続けるので、上手く組み込めません。
affineレイヤーと同じように、入力数(MNISTなら784)と、その層の出力数(次の隠れ層の入力数)という感じで、設定するのとは違うのでしょうか。
それかそもそもEmbeddingレイヤー単体で使うものではないのでしょうか。
3.最後に、一番したいと思っていたことはこれなのですが、TensorBoardでMNIST画像の分類のEmbedding VisualizationをKerasを用いて実装したいのですが、上手く行きません。
調べてみるとこのサイトにKerasでEmbeddingVisualizationを実装するコードが載っているのでできないわけではないと思うのですが、上手く組み込めません。
そもそもTensorBoardとKerasのEmbedding layerは別物なのでしょうか。
現在はこのような感じです。
KerasへのTensorBoard実行のためのコードの実装と、多分間違えてますが、imageの表示もできています。
ですがこのページの下の方にあるようなEmbeddingVisualizationができない状態です。
python
1import numpy as np 2import matplotlib.pyplot as plt 3from sklearn import datasets 4from sklearn.model_selection import train_test_split 5from keras.models import Sequential 6from keras.layers import Dense,Activation 7from keras.layers.wrappers import Bidirectional 8from keras.layers.recurrent import LSTM 9from keras.optimizers import Adam 10from keras.callbacks import EarlyStopping 11 12# tensorboard 13import keras.callbacks 14import keras.backend.tensorflow_backend as KTF 15import tensorflow as tf 16 17# lossの履歴をプロット 18def plot_history(history): 19 plt.plot(history.history['loss'],label="MNIST LSTM",) 20 plt.title('LSTM') 21 plt.xlabel('epoch') 22 plt.ylabel('loss') 23 plt.legend(loc='lower right') 24 plt.show() 25 26log_filepath = './log' # tensorboard 27############################### 28# データの生成 # 29############################### 30np.random.seed(0) 31mnist = datasets.fetch_mldata('MNIST original', data_home='.') 32 33n = len(mnist.data) 34N = 10000 35indices = np.random.permutation(range(n))[:N] # ランダムにN枚を選択 36 37X = mnist.data[indices] 38y = mnist.target[indices] 39Y = np.eye(10)[y.astype(int)] # 1-of-K 表現に変換 40 41# 正規化 42X = X / 255.0 43X = X - X.mean(axis=1).reshape(len(X), 1) 44X = X.reshape(len(X), 28, 28) 45 46X_train, X_test, Y_train, Y_test = train_test_split(X, Y, train_size=0.8) 47 48############################### 49# モデルの設定 # 50############################### 51n_in = 28 52n_time = 28 53n_hidden = 128 54n_out = 10 55 56def weight_variable(shape, name=None): 57 return np.random.normal(scale=.01, size=shape) 58 59early_stopping = EarlyStopping(monitor='val_loss', patience=10, verbose=1) 60 61old_session = KTF.get_session() # tensorboard 62 63with tf.Graph().as_default(): 64 # tensorboard 65 session = tf.Session('') 66 KTF.set_session(session) 67 KTF.set_learning_phase(1) 68 69 # image処理 70 # 読み込むべきはX_trainではない気がする 71 X_train_img = tf.reshape(X_train, [-1, 28, 28, 1]) 72 X_train_img = tf.cast(X_train_img, tf.float32) 73 tf.summary.image('train',X_train_img , 10) 74 75 model = Sequential() 76 model.add(Bidirectional(LSTM(n_hidden),input_shape=(n_time, n_in))) 77 model.add(Dense(n_out, kernel_initializer=weight_variable)) 78 model.add(Activation('softmax')) 79 model.summary() 80 81 # AdaDelta良い説!!!!!!!!!!!!1 82 optimizer = Adam(lr=0.001, beta_1=0.9, beta_2=0.999) 83 model.compile(optimizer=optimizer, 84 loss='mean_squared_error', 85 metrics=['accuracy']) 86 87 tb_cb = keras.callbacks.TensorBoard( 88 log_dir=log_filepath, 89 histogram_freq=1, 90 embeddings_freq=100) 91 # tensorboard 92 93 ############################### 94 # モデルの学習 # 95 ############################### 96 # epochs = 60 97 epochs = 3 98 batch_size = 200 99 100 his = model.fit(X_train, Y_train, 101 batch_size=batch_size, 102 epochs=epochs, 103 validation_data=(X_test,Y_test), 104 callbacks=[early_stopping,tb_cb]) # tensorboard 105 106 ############################### 107 # モデルの予測 # 108 ############################### 109 score = model.evaluate(X_test, Y_test, verbose=0) 110 print('loss:{}'.format(score[0])) 111 print('acc:{}'.format(score[1])) 112 plot_history(his) 113 114# tensorboard 115KTF.set_session(old_session)
質問が多くて恐縮です。
一部でもご存じの方がいらっしゃれば教えていただけないでしょうか。
よろしくお願いします。
あなたの回答
tips
プレビュー