質問をすることでしか得られない、回答やアドバイスがある。

15分調べてもわからないことは、質問しよう!

ただいまの
回答率

88.04%

AttentionモデルのDecoder Inputについて

受付中

回答 0

投稿 編集

  • 評価
  • クリップ 0
  • VIEW 1,028

score 13

前提・実現したいこと

現在、Tensorflowを利用して深層学習(系列変換モデル)について学んでおります。
https://www.tensorflow.org/beta/tutorials/text/image_captioning
Tensorflowチュートリアル(上記URLページ)にある、画像のキャプション生成を行なっているのですが、
Decoderのモデルのcall関数について、疑問をもったので質問させてください。

発生している問題・エラーメッセージ

上記ページ内ではshow, attend and tellという論文に基づいて、Attention機構を導入したCNNEncoder+RNNDecoderでキャプションを生成するというモデルになっています。
私の理解では、時刻tにおけるDecoderRNNへの入力となるのはx(t)と前時刻のRNNの出力(隠れ層の出力)であるh(t-1)であり、これにAttentionを導入した場合はh(t-1)はAttentionとの重み付き平均で表されるものだと思っていたのですが、以下のような実装でも同じ出力となるのでしょうか。

以下、チュートリアルページに記載のあったDecoderクラスです。

class RNN_Decoder(tf.keras.Model):
  def __init__(self, embedding_dim, units, vocab_size):
    super(RNN_Decoder, self).__init__()
    self.units = units

    self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
    self.gru = tf.keras.layers.GRU(self.units,
                                   return_sequences=True,
                                   return_state=True,
                                   recurrent_initializer='glorot_uniform')
    self.fc1 = tf.keras.layers.Dense(self.units)
    self.fc2 = tf.keras.layers.Dense(vocab_size)

    self.attention = BahdanauAttention(self.units)

  def call(self, x, features, hidden):
    # defining attention as a separate model
    context_vector, attention_weights = self.attention(features, hidden)

    # x shape after passing through embedding == (batch_size, 1, embedding_dim)
    x = self.embedding(x)

    # x shape after concatenation == (batch_size, 1, embedding_dim + hidden_size)
    x = tf.concat([tf.expand_dims(context_vector, 1), x], axis=-1)

    # passing the concatenated vector to the GRU
    output, state = self.gru(x)

    # shape == (batch_size, max_length, hidden_size)
    x = self.fc1(output)

    # x shape == (batch_size * max_length, hidden_size)
    x = tf.reshape(x, (-1, x.shape[2]))

    # output shape == (batch_size * max_length, vocab)
    x = self.fc2(x)

    return x, state, attention_weights

  def reset_state(self, batch_size):
    return tf.zeros((batch_size, self.units))


上記のcall関数内で、attentionと埋め込みを得るところ(x=self.embedding(x)のところ)まではわかるのですが、その後attentionとxをconcatしてself.gruへ入力するというところに疑問をもっています。
context_vectorはencoder_output(CNNが出力した特徴量)と(t-1)のdecoder_hiddenから得たattentionを表しています。また、self.gru(x)のとき、つまり、initial_stateがNoneの時、これは零ベクトルで計算されるようになっているようです。
私自身の理解ではself.gru(x, initial_state=context_vector)となるのではないかなと思っていたのですが、上記のような記述でも同じ出力が得られるのでしょうか?
Encoder-Decoderモデルについて詳しい方がいらっしゃったら、どうかご教授のほどお願いいたします。

補足情報(FW/ツールのバージョンなど)

python 3.5.2
tensorflow 2.0.0-alpha0

  • 気になる質問をクリップする

    クリップした質問は、後からいつでもマイページで確認できます。

    またクリップした質問に回答があった際、通知やメールを受け取ることができます。

    クリップを取り消します

  • 良い質問の評価を上げる

    以下のような質問は評価を上げましょう

    • 質問内容が明確
    • 自分も答えを知りたい
    • 質問者以外のユーザにも役立つ

    評価が高い質問は、TOPページの「注目」タブのフィードに表示されやすくなります。

    質問の評価を上げたことを取り消します

  • 評価を下げられる数の上限に達しました

    評価を下げることができません

    • 1日5回まで評価を下げられます
    • 1日に1ユーザに対して2回まで評価を下げられます

    質問の評価を下げる

    teratailでは下記のような質問を「具体的に困っていることがない質問」、「サイトポリシーに違反する質問」と定義し、推奨していません。

    • プログラミングに関係のない質問
    • やってほしいことだけを記載した丸投げの質問
    • 問題・課題が含まれていない質問
    • 意図的に内容が抹消された質問
    • 過去に投稿した質問と同じ内容の質問
    • 広告と受け取られるような投稿

    評価が下がると、TOPページの「アクティブ」「注目」タブのフィードに表示されにくくなります。

    質問の評価を下げたことを取り消します

    この機能は開放されていません

    評価を下げる条件を満たしてません

    評価を下げる理由を選択してください

    詳細な説明はこちら

    上記に当てはまらず、質問内容が明確になっていない質問には「情報の追加・修正依頼」機能からコメントをしてください。

    質問の評価を下げる機能の利用条件

    この機能を利用するためには、以下の事項を行う必要があります。

まだ回答がついていません

15分調べてもわからないことは、teratailで質問しよう!

  • ただいまの回答率 88.04%
  • 質問をまとめることで、思考を整理して素早く解決
  • テンプレート機能で、簡単に質問をまとめられる

関連した質問

同じタグがついた質問を見る