質問をすることでしか得られない、回答やアドバイスがある。

15分調べてもわからないことは、質問しよう!

ただいまの
回答率

90.45%

  • Python 3.x

    10333questions

    Python 3はPythonプログラミング言語の最新バージョンであり、2008年12月3日にリリースされました。

  • Keras

    513questions

seq2seqでKeyErrorがでます

解決済

回答 1

投稿 編集

  • 評価
  • クリップ 0
  • VIEW 207

yep

score 39

kerasのexample、lstm_seq2seq.py(オリジナル)を基に
一つの情報から複数の情報をseq2seqで推定したい(アレンジ)と考えています。
しかしながら、KeyErrorが出力されてしまいます。

Using TensorFlow backend.
Traceback (most recent call last):
  File "./Desktop/simple.py", line 87, in <module>
    encoder_input_data[i, t, input_token_index[char]] = 1.
KeyError: '6'

オリジナル

from __future__ import print_function

from keras.models import Model
from keras.layers import Input, LSTM, Dense
import numpy as np

batch_size = 64  # Batch size for training.
epochs = 100  # Number of epochs to train for.
latent_dim = 256  # Latent dimensionality of the encoding space.
num_samples = 10000  # Number of samples to train on.
# Path to the data txt file on disk.
data_path = 'fra-eng/fra.txt'

# Vectorize the data.
input_texts = []
target_texts = []
input_characters = set()
target_characters = set()
with open(data_path, 'r', encoding='utf-8') as f:
    lines = f.read().split('\n')
for line in lines[: min(num_samples, len(lines) - 1)]:
    input_text, target_text = line.split('\t')
    # We use "tab" as the "start sequence" character
    # for the targets, and "\n" as "end sequence" character.
    target_text = '\t' + target_text + '\n'
    input_texts.append(input_text)
    target_texts.append(target_text)
    for char in input_text:
        if char not in input_characters:
            input_characters.add(char)
    for char in target_text:
        if char not in target_characters:
            target_characters.add(char)

input_characters = sorted(list(input_characters))
target_characters = sorted(list(target_characters))
num_encoder_tokens = len(input_characters)
num_decoder_tokens = len(target_characters)
max_encoder_seq_length = max([len(txt) for txt in input_texts])
max_decoder_seq_length = max([len(txt) for txt in target_texts])

print('Number of samples:', len(input_texts))
print('Number of unique input tokens:', num_encoder_tokens)
print('Number of unique output tokens:', num_decoder_tokens)
print('Max sequence length for inputs:', max_encoder_seq_length)
print('Max sequence length for outputs:', max_decoder_seq_length)

input_token_index = dict(
    [(char, i) for i, char in enumerate(input_characters)])
target_token_index = dict(
    [(char, i) for i, char in enumerate(target_characters)])

encoder_input_data = np.zeros(
    (len(input_texts), max_encoder_seq_length, num_encoder_tokens),
    dtype='float32')
decoder_input_data = np.zeros(
    (len(input_texts), max_decoder_seq_length, num_decoder_tokens),
    dtype='float32')
decoder_target_data = np.zeros(
    (len(input_texts), max_decoder_seq_length, num_decoder_tokens),
    dtype='float32')

for i, (input_text, target_text) in enumerate(zip(input_texts, target_texts)):
    for t, char in enumerate(input_text):
        encoder_input_data[i, t, input_token_index[char]] = 1.
    for t, char in enumerate(target_text):
        # decoder_target_data is ahead of decoder_input_data by one timestep
        decoder_input_data[i, t, target_token_index[char]] = 1.
        if t > 0:
            # decoder_target_data will be ahead by one timestep
            # and will not include the start character.
            decoder_target_data[i, t - 1, target_token_index[char]] = 1.

アレンジ

from keras.layers import Input, LSTM, Dense, Concatenate
from keras.models import Model
import numpy as np

batch_size = 64
epochs = 100
latent_dim = 256

input_image = [None] * 2
for i in range(2):
    with open('./Desktop/input/{}.txt'.format(i + 1), mode='r', encoding='utf-8')as f:
        input_image[i] = f.read()
target_integer = [None] * 2
for i in range(2):
    with open('./Desktop/target_integer/{}.txt'.format(i + 1), mode='r', encoding='utf-8')as f:
        target_integer[i] = f.read()
target_exponent = [None] * 2
for i in range(2):
    with open('./Desktop/target_exponent/{}.txt'.format(i + 1), mode='r', encoding='utf-8')as f:
        target_exponent[i] = f.read()

input_characters = [None] * 2
for i in range(2):
    with open('./Desktop/input/{}.txt'.format(i + 1), mode='r', encoding='utf-8')as f:
        input_characters[i] = f.read()
target_integer_characters = [None] * 2
for i in range(2):
    with open('./Desktop/target_integer/{}.txt'.format(i + 1), mode='r', encoding='utf-8')as f:
        target_integer_characters[i] = f.read()
target_exponent_characters = [None] * 2
for i in range(2):
    with open('./Desktop/target_exponent/{}.txt'.format(i + 1), mode='r', encoding='utf-8')as f:
        target_exponent_characters[i] = f.read()

input_images = map(str, input_image)
target_integers = map(str, target_integer)
target_exponents = map(str, target_exponent)
input_characters = set(input_characters)
target_integer_characters = set(target_integer_characters)
target_exponent_characters = set(target_exponent_characters)

for char in input_images:
    if char not in input_characters:
        input_characters.add(char)
for char in target_integers:
    if char not in target_integer_characters:
        target_integer_characters.add(char)
for char in target_exponents:
    if char not in target_exponent_characters:
        target_exponent_characters.add(char)

input_char = sorted(list(input_characters))
target_integer_char = sorted(list(target_integer_characters))
target_exponent_char = sorted(list(target_exponent_characters))
num_encoder_tokens = len(input_char)
num_decoder_integer_tokens = len(target_integer_char)
num_decoder_exponent_tokens = len(target_exponent_char)
max_encoder_seq_length = max([len(number) for number in input_image])
max_decoder_integer_seq_length = max([len(integer) for integer in target_integer])
max_decoder_exponent_seq_length = max([len(exponent) for exponent in target_exponent])

input_token_index = dict(
    [(char, i) for i, char in enumerate(input_char)])
target_integer_token_index = dict(
    [(char, i) for i, char in enumerate(target_integer_char)])
target_exponent_token_index = dict(
    [(char, i) for i, char in enumerate(target_exponent_char)])

encoder_input_data = np.zeros(
    (len(input_image), max_encoder_seq_length, num_encoder_tokens),
    dtype='float32')
decoder_integer_input_data = np.zeros(
    (len(input_image), max_decoder_integer_seq_length, num_decoder_integer_tokens),
    dtype='float32')
decoder_exponent_input_data = np.zeros(
    (len(input_image), max_decoder_exponent_seq_length, num_decoder_exponent_tokens),
    dtype='float32')
decoder_integer_target_data = np.zeros(
    (len(input_image), max_decoder_integer_seq_length, num_decoder_integer_tokens),
    dtype='float32')
decoder_exponent_target_data = np.zeros(
    (len(input_image), max_decoder_exponent_seq_length, num_decoder_exponent_tokens),
    dtype='float32')

for i, (input_images, target_integers, target_exponents) in enumerate(zip(input_image, target_integer, target_exponent)):
    for t, char in enumerate(input_images):
        encoder_input_data[i, t, input_token_index[char]] = 1.
    for t, char in enumerate(taregt_integers):
        decoder_input_data[i, t, target_integer_token_index[char]] = 1.
    for t, char in enumerate(taregt_exponents):
        decoder_input_data[i, t, target_exponent_token_index[char]] = 1.
        if t > 0:
            decoder_integer_target_data[i, t -1, target_integer_token_index[char]] = 1.
            decoder_exponent_target_data[i, t -1, target_integer_token_index[char]] = 1.
  • 気になる質問をクリップする

    クリップした質問は、後からいつでもマイページで確認できます。

    またクリップした質問に回答があった際、通知やメールを受け取ることができます。

    クリップを取り消します

  • 良い質問の評価を上げる

    以下のような質問は評価を上げましょう

    • 質問内容が明確
    • 自分も答えを知りたい
    • 質問者以外のユーザにも役立つ

    評価が高い質問は、TOPページの「注目」タブのフィードに表示されやすくなります。

    質問の評価を上げたことを取り消します

  • 評価を下げられる数の上限に達しました

    評価を下げることができません

    • 1日5回まで評価を下げられます
    • 1日に1ユーザに対して2回まで評価を下げられます

    質問の評価を下げる

    teratailでは下記のような質問を「具体的に困っていることがない質問」、「サイトポリシーに違反する質問」と定義し、推奨していません。

    • プログラミングに関係のない質問
    • やってほしいことだけを記載した丸投げの質問
    • 問題・課題が含まれていない質問
    • 意図的に内容が抹消された質問
    • 広告と受け取られるような投稿

    評価が下がると、TOPページの「アクティブ」「注目」タブのフィードに表示されにくくなります。

    質問の評価を下げたことを取り消します

    この機能は開放されていません

    評価を下げる条件を満たしてません

    評価を下げる理由を選択してください

    詳細な説明はこちら

    上記に当てはまらず、質問内容が明確になっていない質問には「情報の追加・修正依頼」機能からコメントをしてください。

    質問の評価を下げる機能の利用条件

    この機能を利用するためには、以下の事項を行う必要があります。

質問への追記・修正、ベストアンサー選択の依頼

  • firedfly

    2019/01/30 11:35

    まずはそのエラーがどういうものだと思うのか書きましょう。
    次に、その原因を調べるため試したことを書きましょう。

    キャンセル

回答 1

check解決した方法

0

85行目の

for i, (input_images, target_integers, target_exponents) in enumerate(zip(input_image, target_integer, target_exponent)):


では、

for i, (input_image, target_integer, target_exponent) in enumerate(zip(input_images, target_integers, target_exponents)):


で反対となっていました。

お騒がせいたしました。

投稿

  • 回答の評価を上げる

    以下のような回答は評価を上げましょう

    • 正しい回答
    • わかりやすい回答
    • ためになる回答

    評価が高い回答ほどページの上位に表示されます。

  • 回答の評価を下げる

    下記のような回答は推奨されていません。

    • 間違っている回答
    • 質問の回答になっていない投稿
    • スパムや攻撃的な表現を用いた投稿

    評価を下げる際はその理由を明確に伝え、適切な回答に修正してもらいましょう。

15分調べてもわからないことは、teratailで質問しよう!

  • ただいまの回答率 90.45%
  • 質問をまとめることで、思考を整理して素早く解決
  • テンプレート機能で、簡単に質問をまとめられる

同じタグがついた質問を見る

  • Python 3.x

    10333questions

    Python 3はPythonプログラミング言語の最新バージョンであり、2008年12月3日にリリースされました。

  • Keras

    513questions