質問をすることでしか得られない、回答やアドバイスがある。

15分調べてもわからないことは、質問しよう!

新規登録して質問してみよう
ただいま回答率
85.54%
Python

Pythonは、コードの読みやすさが特徴的なプログラミング言語の1つです。 強い型付け、動的型付けに対応しており、後方互換性がないバージョン2系とバージョン3系が使用されています。 商用製品の開発にも無料で使用でき、OSだけでなく仮想環境にも対応。Unicodeによる文字列操作をサポートしているため、日本語処理も標準で可能です。

Q&A

1回答

4537閲覧

TensorFlowでモデルを保存して復元できない

donafudo

総合スコア46

Python

Pythonは、コードの読みやすさが特徴的なプログラミング言語の1つです。 強い型付け、動的型付けに対応しており、後方互換性がないバージョン2系とバージョン3系が使用されています。 商用製品の開発にも無料で使用でき、OSだけでなく仮想環境にも対応。Unicodeによる文字列操作をサポートしているため、日本語処理も標準で可能です。

0グッド

0クリップ

投稿2018/08/05 08:01

###問題
本を参考にしながらRNNで言語モデルを作って、いざ文章を生成しようとすると、
トレーニングと生成を同時に実行するとそれらしい文章が出来るが、
トレーニングと生成を別々に実行すると、ちゃんとrestoreされていないのか、でたらめな文章が生成される。

saver.save()で保存されたファイルは以下の4つ
checkpoint
language-model.ckpt.data-00000-of-00001
language-model.ckpt.index
language-model.ckpt.meta

tensorflowのあるバージョンからはmetaファイルが生成されるので
self.saver = tf.train.import_meta_graph(os.path.join(ckpt_dir, 'language-model.ckpt.meta'))
が必要という情報があったので試してみても結果は変わらず、
エラーも特に出ていないので手詰まり状態です。
下記のコードに問題があるところはありますでしょうか?

#####環境
OS:macOS High Sierra 10.13.6
tensorflow 1.9.0
Pyton3.6

python

1import numpy as np 2 3# テキストを読み込んで処理 4with open('wordDataSet', 'r', encoding='utf-8')as f: 5 text = f.read() 6 7text = text[:] 8 9chars = set(text) 10char2int = {ch: i for i, ch in enumerate(chars)} 11int2char = dict(enumerate(chars)) 12text_ints = np.array([char2int[ch] for ch in text], dtype=np.int32) 13 14 15def reshape_data(sequence, batch_size, num_steps): 16 tot_batch_length = batch_size * num_steps 17 num_batches = int(len(sequence) / tot_batch_length) 18 print(len(sequence)) 19 if num_batches * tot_batch_length + 1 > len(sequence): 20 num_batches = num_batches - 1 21 22 # シーケンスの最後の部分から完全なバッチにならない半端な文字を削除 23 x = sequence[0:num_batches * tot_batch_length] 24 y = sequence[1:num_batches * tot_batch_length + 1] 25 26 27 # xとyをシーケンスのバッチのリストに分割 28 x_batch_splits = np.split(x, batch_size) 29 y_batch_splits = np.split(y, batch_size) 30 31 # それらのバッチを結合 32 # batch_size×tot_batch_length 33 x = np.stack(x_batch_splits) 34 y = np.stack(y_batch_splits) 35 36 return x, y 37 38 39def create_batch_generator(data_x, data_y, num_steps): 40 batch_size, tot_batch_length = data_x.shape 41 num_batches = int(tot_batch_length / num_steps) 42 for b in range(num_batches): 43 yield (data_x[:, b * num_steps:(b + 1) * num_steps], 44 data_y[:, b * num_steps:(b + 1) * num_steps]) 45 46 47def get_top_char(probas, char_size, top_n=2): 48 p = np.squeeze(probas) 49 p[np.argsort(p)[:-top_n]] = 0.0 50 p = p / np.sum(p) 51 ch_id = np.random.choice(char_size, 1, p=p)[0] 52 return ch_id 53 54 55import tensorflow as tf 56import os 57 58class CharRNN(object): 59 def __init__(self, num_classes, batch_size=64, num_steps=100, 60 lstm_size=128, num_layers=1, learning_rate=0.001, 61 keep_prob=0.5, grad_clip=5, sampling=False): 62 self.num_classes = num_classes 63 self.batch_size = batch_size 64 self.num_steps = num_steps 65 self.lstm_size = lstm_size 66 self.num_layers = num_layers 67 self.learning_rate = learning_rate 68 self.keep_prob = keep_prob 69 self.grad_clip = grad_clip 70 71 self.g = tf.Graph() 72 with self.g.as_default(): 73 tf.set_random_seed(123) 74 75 self.build(sampling=sampling) 76 self.saver = tf.train.Saver(max_to_keep=2) 77 self.init_op = tf.global_variables_initializer() 78 79 def build(self, sampling): 80 if sampling == True: 81 batch_size, num_steps = 1, 1 82 else: 83 batch_size = self.batch_size 84 num_steps = self.num_steps 85 86 tf_x = tf.placeholder(tf.int32, 87 shape=[batch_size, num_steps], 88 name='tf_x') 89 tf_y = tf.placeholder(tf.int32, 90 shape=[batch_size, num_steps], 91 name='tf_y') 92 tf_keepprob = tf.placeholder(tf.float32, 93 name='tf_keepprob') 94 95 # one-hotエンコーディングを適用 96 x_onehot = tf.one_hot(tf_x, depth=self.num_classes) 97 y_onehot = tf.one_hot(tf_y, depth=self.num_classes) 98 99 # 多層RNNのセルを構築 100 cells = tf.contrib.rnn.MultiRNNCell([tf.contrib.rnn.DropoutWrapper( 101 tf.contrib.rnn.BasicLSTMCell(self.lstm_size), 102 output_keep_prob=tf_keepprob) for _ in range(self.num_layers)]) 103 104 # 初期状態を定義 105 self.initial_state = cells.zero_state(batch_size, tf.float32) 106 107 # RNNで各シーケンスステップを実行 108 lstm_outputs, self.final_state = tf.nn.dynamic_rnn(cells, x_onehot, initial_state=self.initial_state) 109 110 print(' << lstm_outputs >>', lstm_outputs) 111 112 # 二次元テンソルに変形 113 seq_output_reshaped = tf.reshape(lstm_outputs, 114 shape=[-1, self.lstm_size], 115 name='seq_output_reshaped') 116 117 # そう入力を取得 118 logits = tf.layers.dense(inputs=seq_output_reshaped, 119 units=self.num_classes, 120 activation=None, 121 name='logits') 122 123 # 次の文字バッチの確率を計算 124 proba = tf.nn.softmax(logits, name='probabilities') 125 126 # コスト関数を定義 127 y_reshaped = tf.reshape(y_onehot, 128 shape=[-1, self.num_classes], 129 name='y_reshaped') 130 cost = tf.reduce_mean( 131 tf.nn.softmax_cross_entropy_with_logits(logits=logits, 132 labels=y_reshaped), 133 name='cost') 134 135 # 勾配発散問題を回避するための勾配刈り込み 136 tvars = tf.trainable_variables() 137 grads, _ = tf.clip_by_global_norm(tf.gradients(cost, tvars), 138 self.grad_clip) 139 140 # オプティマイザを定義 141 optimizer = tf.train.AdamOptimizer(self.learning_rate) 142 train_op = optimizer.apply_gradients(zip(grads, tvars), 143 name='train_op') 144 145 def train(self, train_x, train_y, num_epochs, ckpt_dir='./model/'): 146 147 # チェックポイントディレクトリがまだ存在しない場合は作成 148 if not os.path.exists(ckpt_dir): 149 os.mkdir(ckpt_dir) 150 151 with tf.Session(graph=self.g) as sess: 152 sess.run(self.init_op) 153 154 n_batches = int(train_x.shape[1] / self.num_steps) 155 iterations = n_batches * num_epochs 156 for epoch in range(num_epochs): 157 158 # ネットワークをトレーニング 159 new_state = sess.run(self.initial_state) 160 loss = 0 161 162 # ミニバッチジェネレータ 163 bgen = create_batch_generator(train_x, train_y, self.num_steps) 164 for b, (batch_x, batch_y) in enumerate(bgen, 1): 165 iteration = epoch * n_batches + b 166 167 feed = {'tf_x:0': batch_x, 168 'tf_y:0': batch_y, 169 'tf_keepprob:0': self.keep_prob, 170 self.initial_state: new_state} 171 batch_cost, _, new_state = sess.run( 172 ['cost:0', 'train_op', 173 self.final_state], 174 feed_dict=feed) 175 176 if iteration % 10 == 0: 177 print('Epoch %d/%d iteration %d| Training loss: %.4f' % 178 (epoch + 1, num_epochs, iteration, batch_cost)) 179 180 # トレーニング済みのモデルを保存 181 self.saver.save(sess, 182 os.path.join(ckpt_dir, 'language-model.ckpt')) 183 184 def sample(self, output_length, ckpt_dir, starter_seq="電子書籍"): 185 observed_seq = [ch for ch in starter_seq] 186 with tf.Session(graph=self.g)as sess: 187 self.saver = tf.train.import_meta_graph(os.path.join(ckpt_dir, 'language-model.ckpt.meta')) 188 self.saver.restore(sess, tf.train.latest_checkpoint(ckpt_dir)) 189 190 # 1: starter_seqを使ってモデルを実行 191 new_state = sess.run(self.initial_state) 192 for ch in starter_seq: 193 x = np.zeros((1, 1)) 194 x[0, 0] = char2int[ch] 195 feed = {'tf_x:0': x, 'tf_keepprob:0': 1.0, 196 self.initial_state: new_state} 197 proba, new_state = sess.run( 198 ['probabilities:0', self.final_state], 199 feed_dict=feed) 200 201 ch_id = get_top_char(proba, len(chars)) 202 observed_seq.append(int2char[ch_id]) 203 204 # 2: 更新されたobserved_seqを使ってモデルを実行 205 for i in range(output_length): 206 x[0, 0] = ch_id 207 feed = {'tf_x:0': x, 'tf_keepprob:0': 1.0, 208 self.initial_state: new_state} 209 proba, new_state = sess.run( 210 ['probabilities:0', self.final_state], 211 feed_dict=feed) 212 213 ch_id = get_top_char(proba, len(chars)) 214 observed_seq.append(int2char[ch_id]) 215 return ''.join(observed_seq) 216 217 218batch_size = 32 219num_steps = 30 220train_x, train_y = reshape_data(text_ints, batch_size, num_steps) 221 222# モデルのトレーニング 223rnn=CharRNN(num_classes=len(chars),batch_size=batch_size) 224rnn.train(train_x,train_y,num_epochs=1000,ckpt_dir='./model-100/') 225del rnn 226 227# 文章生成 228np.random.seed(123) 229rnn = CharRNN(len(chars), sampling=True) 230print(rnn.sample(ckpt_dir='./model-100/', output_length=500,starter_seq='明日は')) 231

気になる質問をクリップする

クリップした質問は、後からいつでもMYページで確認できます。

またクリップした質問に回答があった際、通知やメールを受け取ることができます。

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

guest

回答1

0

参考になるかわかりませんけど自分の環境ではsave,restoreは以下のように記述したらrestoreできてるような気がします。

python

1#self.saver 2self.saver = tf.train.Saver() 3 4#モデルの保存 5self.saver.save(self.sess,"./checkpoint/m.ckpt") 6 7#restore 8ckpt = tf.train.get_checkpoint_state('./checkpoint') 9if ckpt:#モデルファイルがあるとき 10 last_model = ckpt.model_checkpoint_path 11 self.saver.restore(self.sess, last_model)

投稿2018/08/21 09:33

bsk

総合スコア174

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

donafudo

2018/08/23 05:44

回答ありがとうございます。 同じ記述でやってみたのですが、結果は変わりませんでした。 一度、別の環境で試してみます。
guest

あなたの回答

tips

太字

斜体

打ち消し線

見出し

引用テキストの挿入

コードの挿入

リンクの挿入

リストの挿入

番号リストの挿入

表の挿入

水平線の挿入

プレビュー

まだベストアンサーが選ばれていません

会員登録して回答してみよう

アカウントをお持ちの方は

15分調べてもわからないことは
teratailで質問しよう!

ただいまの回答率
85.54%

質問をまとめることで
思考を整理して素早く解決

テンプレート機能で
簡単に質問をまとめる

質問する

同じタグがついた質問を見る

Python

Pythonは、コードの読みやすさが特徴的なプログラミング言語の1つです。 強い型付け、動的型付けに対応しており、後方互換性がないバージョン2系とバージョン3系が使用されています。 商用製品の開発にも無料で使用でき、OSだけでなく仮想環境にも対応。Unicodeによる文字列操作をサポートしているため、日本語処理も標準で可能です。