質問をすることでしか得られない、回答やアドバイスがある。

15分調べてもわからないことは、質問しよう!

新規登録して質問してみよう
ただいま回答率
85.48%
Keras

Kerasは、TheanoやTensorFlow/CNTK対応のラッパーライブラリです。DeepLearningの数学的部分を短いコードでネットワークとして表現することが可能。DeepLearningの最新手法を迅速に試すことができます。

Python 3.x

Python 3はPythonプログラミング言語の最新バージョンであり、2008年12月3日にリリースされました。

Q&A

0回答

587閲覧

Keras - Seq2Seqで文章校正に挑戦したが同じ文字しか出力されない

ryoryoohiya

総合スコア12

Keras

Kerasは、TheanoやTensorFlow/CNTK対応のラッパーライブラリです。DeepLearningの数学的部分を短いコードでネットワークとして表現することが可能。DeepLearningの最新手法を迅速に試すことができます。

Python 3.x

Python 3はPythonプログラミング言語の最新バージョンであり、2008年12月3日にリリースされました。

0グッド

0クリップ

投稿2019/05/31 03:07

編集2019/05/31 03:19

前提・実現したいこと

pythonのkerasを使って文章校正のためのseq2seqを組もうとしています。
実装中に以下の問題が発生しました。

発生している問題・エラーメッセージ

以下のコードを実行しても予測結果は全て同じ文字しか出力されません。
storyとnew_storyには校正前文章と校正後文章が500(文書数)×250(1文章内の単語数)で格納されています。

該当のソースコード

python

1import pickle 2import numpy as np 3 4from keras.models import Sequential 5from keras.layers.core import Dense, Activation, RepeatVector 6from keras.layers.recurrent import LSTM 7from keras.layers.wrappers import TimeDistributed 8from keras.optimizers import Adam 9from keras.callbacks import EarlyStopping 10 11with open('word_indices.pickle', 'rb') as f : 12 word_indices = pickle.load(f) 13 14with open('indices_word.pickle', 'rb') as f : 15 indices_word = pickle.load(f) 16 17with open('story.pickle', 'rb') as ff : 18 story = pickle.load(ff) 19 20with open('new_story.pickle', 'rb') as ff : 21 new_story = pickle.load(ff) 22 23X = [] 24Y = [] 25T = [] 26maxlen = 250 27 28print("#1") 29 30enc = [] 31dec = [] 32N = 499 33for i in range(N): 34 if len(story[i]) >= 250: 35 story[i] = story[i] 36 story[i] = story[i][:250] 37 enc.append(story[i]) 38 new_story[i] = new_story[i] 39 new_story[i] = new_story[i][:250] 40 dec.append(new_story[i]) 41 42print("#2") 43 44X = np.zeros((len(enc), maxlen, len(word_indices))) 45Y = np.zeros((len(dec), maxlen, len(word_indices))) 46 47print("#3") 48 49for i in range(len(enc)): 50 for t, char in enumerate(enc[i]): 51 X[i, t, word_indices[char]] = 1 52 for t, char in enumerate(dec[i]): 53 Y[i, t, word_indices[char]] = 1 54 55del(story) 56del(new_story) 57 58X_train = X[0:int(len(enc)*0.8)][:][:] 59X_test = X[int(len(enc)*0.8):int(len(enc)*0.9)][:][:] 60X_validation = X[int(len(enc)*0.9):][:][:] 61Y_train = X[0:int(len(enc)*0.8)][:][:] 62Y_test = X[int(len(enc)*0.8):int(len(enc)*0.9)][:][:] 63Y_validation = X[int(len(enc)*0.9):][:][:] 64 65del(enc) 66del(dec) 67N_validation = len(X_validation) 68print(X_train.shape) 69print(X_test.shape) 70print(Y_train.shape) 71print(Y_test.shape) 72 73n_in = len(word_indices) 74n_hidden = 128 75n_out = len(word_indices) 76 77def weight_variable(shape, name=None): 78 return np.random.normal(scale=.01, size=shape) 79 80print("#4") 81 82early_stopping = EarlyStopping(monitor='val_loss', patience=10, verbose=1) 83 84model = Sequential() 85model.add(LSTM(n_hidden, input_shape=(maxlen, n_in))) 86model.add(RepeatVector(maxlen)) 87model.add(LSTM(n_hidden, return_sequences=True)) 88model.add(TimeDistributed(Dense(n_out, kernel_initializer=weight_variable))) 89model.add(Activation('softmax')) 90model.compile(loss='categorical_crossentropy', 91 optimizer=Adam(lr=0.01, beta_1=0.9, beta_2=0.999), 92 metrics=['accuracy']) 93 94print("#5") 95 96epochs = 20 97batch_size = 10 98 99for epoch in range(epochs): 100 print("#") 101 model.fit(X_train, Y_train, batch_size=batch_size, epochs=1, 102 validation_data=(X_validation, Y_validation), 103 callbacks=[early_stopping]) 104 for i in range(2): 105 index = np.random.randint(0, N_validation) 106 input_text = X_validation[np.array([index])][:][:] 107 output_text = Y_validation[np.array([index])][:][:] 108 prediction = model.predict_classes(input_text, verbose=0) 109 110 input_text = input_text.argmax(axis=-1) 111 output_text = output_text.argmax(axis=-1) 112 113 q = ''.join(indices_word[i] for i in input_text[0]) 114 a = ''.join(indices_word[i] for i in output_text[0]) 115 p = ''.join(indices_word[i] for i in prediction[0]) 116 print('Q: ', q) 117 print('A: ', p)

###出力

Using TensorFlow backend. #1 #2 #3 #4 WARNING:tensorflow:From /Users/PPP/anaconda3/envs/tensorflow/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py:263: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version. Instructions for updating: Colocations handled automatically by placer. WARNING:tensorflow:From /Users/PPP/.pyenv/versions/3.6.8/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py:1190: calling reduce_sum_v1 (from tensorflow.python.ops.math_ops) with keep_dims is deprecated and will be removed in a future version. Instructions for updating: keep_dims is deprecated, use keepdims instead WARNING:tensorflow:From /Users/PPP/.pyenv/versions/3.6.8/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py:1154: calling reduce_max_v1 (from tensorflow.python.ops.math_ops) with keep_dims is deprecated and will be removed in a future version. Instructions for updating: keep_dims is deprecated, use keepdims instead #5 # WARNING:tensorflow:From /Users/PPP/anaconda3/envs/tensorflow/lib/python3.6/site-packages/tensorflow/python/ops/math_ops.py:3066: to_int32 (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version. Instructions for updating: Use tf.cast instead. WARNING:tensorflow:From /Users/PPP/anaconda3/envs/tensorflow/lib/python3.6/site-packages/tensorflow/python/ops/math_grad.py:102: div (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version. Instructions for updating: Deprecated in favor of operator or tf.math.divide. Train on 147 samples, validate on 19 samples Epoch 1/1 2019-05-31 12:37:58.383911: I tensorflow/core/platform/cpu_feature_guard.cc:141] Your CPU supports instructions that this TensorFlow binary was not compiled to use: SSE4.1 SSE4.2 AVX AVX2 FMA 2019-05-31 12:37:58.384368: I tensorflow/core/common_runtime/process_util.cc:71] Creating new thread pool with default inter op setting: 4. Tune using inter_op_parallelism_threads for best performance. 147/147 [==============================] - 332s - loss: 8.1267 - acc: 0.0393 - val_loss: 7.7482 - val_acc: 2.1053e-04 Q: 文章(著作権の問題から掲載できません) A: …………………………………………………………………………………………………………………………………………………………………………………………………………………………………………………………………………………………………………………………………………………………………………………………………………………………………………………………………………………………………………………………………………………………………………………………………………………………………………………………………………………………

試したこと

X_trainに正しく数字が格納されているか確かめましたがそこは大丈夫でした。

補足情報(FW/ツールのバージョンなど)

keras 2.0.5

気になる質問をクリップする

クリップした質問は、後からいつでもMYページで確認できます。

またクリップした質問に回答があった際、通知やメールを受け取ることができます。

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

guest

あなたの回答

tips

太字

斜体

打ち消し線

見出し

引用テキストの挿入

コードの挿入

リンクの挿入

リストの挿入

番号リストの挿入

表の挿入

水平線の挿入

プレビュー

まだ回答がついていません

会員登録して回答してみよう

アカウントをお持ちの方は

15分調べてもわからないことは
teratailで質問しよう!

ただいまの回答率
85.48%

質問をまとめることで
思考を整理して素早く解決

テンプレート機能で
簡単に質問をまとめる

質問する

関連した質問