質問をすることでしか得られない、回答やアドバイスがある。

15分調べてもわからないことは、質問しよう!

新規登録して質問してみよう
ただいま回答率
85.48%

Q&A

0回答

1031閲覧

TensorFlowのBasicDecoderのstepの中でbatch_scatter_updateを使った際の問題について

hukuda222

総合スコア13

0グッド

0クリップ

投稿2019/05/29 13:14

編集2019/05/30 06:20

実現したいこと

BasicDecoderを継承したクラスのstep関数の内側でbatch_scatter_updateを使おうとしています。

ソースコード(簡単のためにあまり意味がないコードになっています)

python

1import tensorflow as tf 2from tensorflow.contrib.seq2seq.python.ops import basic_decoder 3from tensorflow.python.framework import ops 4from tensorflow.python.layers import core as layers_core 5 6 7class CopyDecoder(basic_decoder.BasicDecoder): 8 def __init__(self, cell, helper, initial_state, 9 batch_size, output_size, 10 output_layer=None, internal_size=100): 11 super().__init__(cell, helper, initial_state, output_layer) 12 self.internal_size = internal_size 13 self.batch_size_ = batch_size 14 self.W = layers_core.Dense( 15 self.internal_size, name="W", use_bias=False) 16 self.index = tf.tile(tf.expand_dims(tf.range(self.internal_size), axis=0), [batch_size, 1]) 17 18 19 def step(self, time, inputs, state, name=None): 20 21 with ops.name_scope(name, "BasicDecoderStep", (time, inputs, state)): 22 cell_outputs, cell_state =\ 23 self._cell(inputs, state) 24 weight = self.W(cell_output) 25 26 def init_value(): 27 return tf.tile(tf.zeros([1, self.target_vocab_size], 28 dtype=tf.float32), 29 [self.batch_size, 1]) 30 ref = tf.Variable( 31 initial_value=init_value, validate_shape=False) 32 scatter = ref.batch_scatter_update( 33 tf.IndexedSlices(weight, self.index)) 34 cell_outputs = scatter 35 36 sample_ids = self._helper.sample( 37 time=time, outputs=cell_outputs, state=cell_state) 38 (finished, next_inputs, next_state) = self._helper.next_inputs( 39 time=time, 40 outputs=cell_outputs, 41 state=cell_state, 42 sample_ids=sample_ids) 43 outputs = basic_decoder.BasicDecoderOutput(cell_outputs, sample_ids) 44 return (outputs, next_state, next_inputs, finished)

エラーの内容

File "/home/.pyenv/versions/anaconda3-5.3.1/envs/tensorflow_gpuenv/lib/python3.6/site-packages/tensorflow/contrib/seq2seq/python/ops/decoder.py", line 337, in dynam$c_decode final_outputs = nest.map_structure(_transpose_batch_time, final_outputs) File "/home/.pyenv/versions/anaconda3-5.3.1/envs/tensorflow_gpuenv/lib/python3.6/site-packages/tensorflow/python/util/nest.py", line 381, in map_structure structure[0], [func(*x) for x in entries]) File "/home/.pyenv/versions/anaconda3-5.3.1/envs/tensorflow_gpuenv/lib/python3.6/site-packages/tensorflow/python/util/nest.py", line 381, in <listcomp> structure[0], [func(*x) for x in entries]) File "/home/.pyenv/versions/anaconda3-5.3.1/envs/tensorflow_gpuenv/lib/python3.6/site-packages/tensorflow/python/ops/rnn.py", line 67, in _transpose_batch_time x_static_shape.dims[1].value, x_static_shape.dims[0].value TypeError: 'NoneType' object is not subscriptable

問題が発生するまでの流れ

通常のBasicDecoderの代わりにそれを継承した以下のようなクラスを使おうとしました。

RNNの内側でVariableを定義する際のinitial_valueにTensorをそのまま指定するとValueError: Initializer for variable model/decoder/while/BasicDecoderStep/Variable/ is from inside a control-flow construct, such as a loop or conditional. When creating a variab le inside a loop or conditional, use a lambda as the initializer.
というエラーが発生するためTensorではなく関数をinitial_valueに指定しました。

それによってbatch_scatter_updateの出力のshapeが<unknown>になり、そのせいでその後の処理のshapeを参照する箇所で上記のエラーが出てしまっているように思われます。

環境

OSはUbuntu 16.04.5
Pythonのバージョンは3.6.8
Tensorflowのバージョンは1.13.1
です。

解決策をご存知の方は教えていただけると幸いです。

追記

python

1ref.set_shape(init_value().shape)

をref.batch_scatter_updateの前の行に追記することでshapeのエラーはとりあえず出なくなりましたが、以下のエラーが出るようになりました。

File "/home/.pyenv/versions/anaconda3-5.3.1/envs/tensorflow_gpuenv/lib/python3.6/site-packages/numpy/core/fromnumeric.py", line 2772, in prod initial=initial) File "/home/.pyenv/versions/anaconda3-5.3.1/envs/tensorflow_gpuenv/lib/python3.6/site-packages/numpy/core/fromnumeric.py", line 86, in _wrapreduction return ufunc.reduce(obj, axis, dtype, out, **passkwargs) TypeError: unsupported operand type(s) for *: 'NoneType' and 'int' ``` おそらくこれもinitial_valueに関数を指定しているせいで、typeがうまく参照されてないように思うのですが、直し方がわからないのでわかる方は教えていただけると幸いです。

気になる質問をクリップする

クリップした質問は、後からいつでもMYページで確認できます。

またクリップした質問に回答があった際、通知やメールを受け取ることができます。

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

guest

あなたの回答

tips

太字

斜体

打ち消し線

見出し

引用テキストの挿入

コードの挿入

リンクの挿入

リストの挿入

番号リストの挿入

表の挿入

水平線の挿入

プレビュー

まだ回答がついていません

会員登録して回答してみよう

アカウントをお持ちの方は

15分調べてもわからないことは
teratailで質問しよう!

ただいまの回答率
85.48%

質問をまとめることで
思考を整理して素早く解決

テンプレート機能で
簡単に質問をまとめる

質問する

関連した質問