ここに実現したいことを箇条書きで書いてください。
タイトルの通り、学習モデルの重みを取り出したいです。
前提
複数の学習モデルの重みの平均を取り、1つのモデルにしたいと考えています。
なお、言語はpythonを使用しています。
発生している問題・エラーメッセージ
モデルを保存する際以下の形式のファイルが保存されます。
・checkpoint
・data-00000-of-oooo1
・index
・meta
しかし、この中のどこに重みが保存されているのか、またどのように呼び出せば良いのかが分からない状態です。
該当のソースコード
def train(self): total_step = self.train_inputs.shape[0] * FLAGS.num_epoch // FLAGS.batch_size print(self.train_inputs.shape[0]) print('total step is %d' % total_step) config = tf.compat.v1.ConfigProto() config.gpu_options.allow_growth = True #メモリを必要分確保 min_validation_loss = 10 best_epoch = 0 with tf.compat.v1.Session(config=config) as sess: sess.run(tf.global_variables_initializer()) #変数の初期化 if FLAGS.restore == True: self.saver.restore(sess, self.save_path) for step in range(total_step): print(step) sample = random.sample(range(self.train_inputs.shape[0]), FLAGS.batch_size) #入力から、バッチサイズ分の要素をランダムに取り出す train_batch = self.train_inputs[sample] train_label = self.train_labels[sample] sess.run(self.optimizer, feed_dict={self.signal_input: train_batch, self.signal_label: train_label, #学習の実行 self.rnn_keep_prob: FLAGS.rnn_keep_prob, self.dense_drop_rate: FLAGS.drop_rate}) if step % 100 == 0: # print('-----------------------------------------------------------') train_loss = sess.run(self.loss, feed_dict={self.signal_input: train_batch, self.signal_label: train_label, self.rnn_keep_prob: FLAGS.rnn_keep_prob, self.dense_drop_rate: 0}) fft_train_loss = sess.run(self.fft_loss, feed_dict={self.signal_input: train_batch, self.signal_label: train_label, self.rnn_keep_prob: FLAGS.rnn_keep_prob, self.dense_drop_rate: 0}) valid_loss = sess.run(self.loss, feed_dict={self.signal_input: self.valid_inputs, self.signal_label: self.valid_labels, self.rnn_keep_prob: 1, self.dense_drop_rate: 0}) fft_valid_loss = sess.run(self.fft_loss, feed_dict={self.signal_input: self.valid_inputs, self.signal_label: self.valid_labels, self.rnn_keep_prob: 1, self.dense_drop_rate: 0}) print('current step is %d' % step) num_epoch = step * FLAGS.batch_size // self.train_inputs.shape[0] print('current epoch is %d' % (num_epoch)) print('') print('train loss is: %f' % train_loss) print('fft train loss is: %f' % fft_train_loss) print('sum_train loss is: %f' % (train_loss + fft_train_loss)) print('') print('valid loss is: %f' % valid_loss) print('fft valid loss is: %f' % fft_valid_loss) print('sum_valid real loss is: %f' % (valid_loss + fft_valid_loss)) print('minimum valid loss is: {0:0.4f} in epoch {1}'.format(min_validation_loss, best_epoch)) if valid_loss < min_validation_loss: best_epoch = num_epoch min_validation_loss = valid_loss self.saver.save(sess, self.save_path) print(self.save_path) print('best model saved!!')
試したこと
色々調べてみたのですが、他の方のやり方では、コードを一から書き直す必要があるように見えたので、少し自分には難易度が高く、打つ手がありません…
補足情報(FW/ツールのバージョンなど)
私自身機械学習に関して初心者なのと、今回初めてteratailを使用したため、かなり分かりにくい質問内容になっているかと思いますが、ご回答いただけますと幸いです。

回答2件
あなたの回答
tips
プレビュー