質問をすることでしか得られない、回答やアドバイスがある。

15分調べてもわからないことは、質問しよう!

ただいまの
回答率

88.23%

tensorflowのsessionの保存について

受付中

回答 0

投稿

  • 評価
  • クリップ 0
  • VIEW 460

lstmにおけるtensorflowのsessionの保存方法が分からないので質問させていただきました。

バージョン:tensorflow 1.14

以下のstepが10000になったところでlstmの重みなどを保存したいのですが、保存されず、そこでプログラムが終了してしまいます。

with tf.device('/gpu:%d' % gpu):
    lstm = DET_LSTM(batch_size, input_size, layers, seen_step, fut_step,
                    keep_prob, logs_dir, learning_rate)

・
・
・

if step >= 10000 and step % 10000 == 0:
   lstm.save(sess, models_dir, lstm.global_step)

step = step + 1

DET_LSTMの中身全部は以下の通りです。
def init()の一番最後でtf.train.Saver()を呼び出しています。

class DET_LSTM(object):
  def __init__(self,
               batch_size,
               input_size,
               layers,
               seen_step,
               fut_step,
               keep_prob,
               logs_dir,
               learning_rate,
               mode='train'):

    self.input_size = input_size
    self.point_size = input_size / 2
    self.batch_size = batch_size
    self.seen_step = seen_step
    self.fut_step = fut_step
    if mode == 'train':
      self.seq_len = seen_step + fut_step
    else:
      self.seq_len = seen_step
    self.enc_units = layers[0]
    self.keep_prob = keep_prob
    self.learning_rate = learning_rate

    self.seq_ = tf.placeholder(
        tf.float32, shape=[batch_size, self.seq_len, input_size], name='seq')
    self.mask_ = tf.placeholder(
        tf.float32,
        shape=[batch_size, self.seq_len, self.point_size],
        name='mask')

    stacked_lstm = self.lstm_model(layers)

    mask = tf.concat([self.mask_, self.mask_], 2)
    masked_seq = mask * self.seq_

    act_emb = None
    input_list = []
    input_list_enc = []
    reuse_enc = False
    for t in range(self.seen_step):
      input_list.append(masked_seq[:, t, :])
      input_list_enc.append(relu(linear(
          masked_seq[:, t, :], 32, name='lm_enc', reuse=reuse_enc)))
      reuse_enc = True

    with tf.variable_scope('GEN'):
      with tf.variable_scope('G_LSTM'):
        enc_out, states = tf.contrib.rnn.static_rnn(
            stacked_lstm, input_list_enc, dtype=dtypes.float32)

    reuse_lstm = True
    reuse_output = False
    output_list = input_list
    empty_input = tf.zeros_like(input_list_enc[-1])

    with tf.variable_scope('GEN'):
      for t in range(fut_step):
        with tf.variable_scope('G_LSTM', reuse=reuse_lstm):
          enc_out, states = tf.contrib.rnn.static_rnn(
              stacked_lstm, [empty_input],
              initial_state=states,
              dtype=dtypes.float32)
        with tf.variable_scope('G_LSTM', reuse=reuse_output):
          output = self.decoder(enc_out[-1])
        output_list.append(output)
        reuse_output = True
      self.output = tf.stack(output_list, 1)

    if mode == 'train':
      self.recons_loss = tf.reduce_mean(
          mask[:, seen_step:, :] *
          (self.output[:, seen_step:, :] - self.seq_[:, seen_step:, :])**2)

      loss_sum = tf.summary.scalar("loss", self.recons_loss)
      self.g_sum = tf.summary.merge([loss_sum])
      self.writer = tf.summary.FileWriter(logs_dir, tf.get_default_graph())

      self.g_vars = tf.trainable_variables()

      self.global_step = tf.Variable(0, trainable=False)
      optimizer = tf.train.RMSPropOptimizer(
          self.learning_rate, name='optimizer')
      gradients, g = zip(
          *optimizer.compute_gradients(self.recons_loss, var_list=self.g_vars))

      gradients, _ = tf.clip_by_global_norm(gradients, 25)

      self.optimizer = optimizer.apply_gradients(
          zip(gradients, g), global_step=self.global_step)

      num_param = 0
      for var in self.g_vars:
        num_param += int(np.prod(var.get_shape()))
      print('NUMBER OF PARAMETERS: ' + str(num_param))
    self.saver = tf.train.Saver()

  def decoder(self, input_, reuse=False, name='decoder'):
    out = linear(input_, self.point_size * 2, name='dec_fc2')
    return tanh(out)

  def lstm_model(self, layers):
    lstm_cells = [
        tf.nn.rnn_cell.BasicLSTMCell(units, state_is_tuple=True)
        for units in layers
    ]
    lstm_cells = [
        tf.nn.rnn_cell.DropoutWrapper(cell, input_keep_prob=self.keep_prob)
        for cell in lstm_cells
    ]
    stacked_lstm = tf.nn.rnn_cell.MultiRNNCell(lstm_cells, state_is_tuple=True)
    return stacked_lstm

  def train(self, sess, batches, mask, step, save_logs=False):
    feed_dict = dict()
    feed_dict[self.seq_] = batches
    feed_dict[self.mask_] = mask
    if save_logs:
      _, summary = sess.run([self.optimizer, self.g_sum], feed_dict=feed_dict)
      self.writer.add_summary(summary, step)
    else:
      _ = sess.run([self.optimizer, self.recons_loss], feed_dict=feed_dict)

    errG = self.recons_loss.eval(feed_dict=feed_dict)

    self.global_step = self.global_step + 1
    return errG

  def predict(self, sess, seq_, mask_):
    feed_dict = dict()
    feed_dict[self.seq_] = seq_
    feed_dict[self.mask_] = mask_
    output = self.output.eval(feed_dict=feed_dict)
    return output

  def save(self, sess, checkpoint_dir, step):
    model_name = "DET_LSTM.model"

    if not os.path.exists(checkpoint_dir):
      os.makedirs(checkpoint_dir)

    self.saver.save(
sess, os.path.join(checkpoint_dir, model_name), global_step=step)

  def load(self, sess, checkpoint_dir, model_name=None):
    print("[*] Reading checkpoints...")
    ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
    if ckpt and ckpt.model_checkpoint_path:
      ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
      if model_name is None: model_name = ckpt_name
      self.saver.restore(sess, os.path.join(checkpoint_dir, model_name))
      print("     Loaded model: "+str(model_name))
      return True, model_name
    else:
      return False, None
  • 気になる質問をクリップする

    クリップした質問は、後からいつでもマイページで確認できます。

    またクリップした質問に回答があった際、通知やメールを受け取ることができます。

    クリップを取り消します

  • 良い質問の評価を上げる

    以下のような質問は評価を上げましょう

    • 質問内容が明確
    • 自分も答えを知りたい
    • 質問者以外のユーザにも役立つ

    評価が高い質問は、TOPページの「注目」タブのフィードに表示されやすくなります。

    質問の評価を上げたことを取り消します

  • 評価を下げられる数の上限に達しました

    評価を下げることができません

    • 1日5回まで評価を下げられます
    • 1日に1ユーザに対して2回まで評価を下げられます

    質問の評価を下げる

    teratailでは下記のような質問を「具体的に困っていることがない質問」、「サイトポリシーに違反する質問」と定義し、推奨していません。

    • プログラミングに関係のない質問
    • やってほしいことだけを記載した丸投げの質問
    • 問題・課題が含まれていない質問
    • 意図的に内容が抹消された質問
    • 過去に投稿した質問と同じ内容の質問
    • 広告と受け取られるような投稿

    評価が下がると、TOPページの「アクティブ」「注目」タブのフィードに表示されにくくなります。

    質問の評価を下げたことを取り消します

    この機能は開放されていません

    評価を下げる条件を満たしてません

    評価を下げる理由を選択してください

    詳細な説明はこちら

    上記に当てはまらず、質問内容が明確になっていない質問には「情報の追加・修正依頼」機能からコメントをしてください。

    質問の評価を下げる機能の利用条件

    この機能を利用するためには、以下の事項を行う必要があります。

まだ回答がついていません

15分調べてもわからないことは、teratailで質問しよう!

  • ただいまの回答率 88.23%
  • 質問をまとめることで、思考を整理して素早く解決
  • テンプレート機能で、簡単に質問をまとめられる

関連した質問

同じタグがついた質問を見る