質問をすることでしか得られない、回答やアドバイスがある。

15分調べてもわからないことは、質問しよう!

新規登録して質問してみよう
ただいま回答率
87.20%
Python

Pythonは、コードの読みやすさが特徴的なプログラミング言語の1つです。 強い型付け、動的型付けに対応しており、後方互換性がないバージョン2系とバージョン3系が使用されています。 商用製品の開発にも無料で使用でき、OSだけでなく仮想環境にも対応。Unicodeによる文字列操作をサポートしているため、日本語処理も標準で可能です。

受付中

Tensorflowでのモデルの復元と予測値の表示

leon30
leon30

総合スコア0

Python

Pythonは、コードの読みやすさが特徴的なプログラミング言語の1つです。 強い型付け、動的型付けに対応しており、後方互換性がないバージョン2系とバージョン3系が使用されています。 商用製品の開発にも無料で使用でき、OSだけでなく仮想環境にも対応。Unicodeによる文字列操作をサポートしているため、日本語処理も標準で可能です。

3回答

0評価

2クリップ

6611閲覧

投稿2017/06/26 04:40

編集2022/01/12 10:55

大変初歩的な質問になり恐縮です。
どなたかお分かりの方教えていただければ大変助かります。
どうかよろしくお願いいたします。

###概要

<概要>
Webアプリに実装するエンジンとして、tensorflowによる予測モデルを作っています。

<前提>
ある商品を購入する際に、意思決定に影響を与える要素12個を入力データとして
商品のスペックに関わる要素1個を予測させようとしています。
これらのデータをCSVで用意し、1列目に正解のデータ、2~13列目に入力データを配列しています。

<方法>
事前に1400件のデータを用いて学習を行い、学習済みのモデルを保存。
そのモデルをサーバー上に置き、UIから送られてくるデータ(12列×1行)に対して保存済みのモデルを復元し、値を予測させるという段取りで考えています。

<問題が発生した箇所>
・保存済みのモデルの復元
・予測結果の表示

###発生している問題・エラーメッセージ

保存済みのモデルが復元できません。
また予測した値そのものを表示させたいのですが、その方法がわかりません。

load C:/Users…\(モデルを保存したPATHが表示されている\) --------------------------------------------------------------------------- TypeError Traceback \(most recent call last\) <ipython-input-16-da567374871a> in <module>\(\) 8 last_model = ckpt\.model_checkpoint_path 9 print \("load " \+ last_model\) ---> 10 tf\.train\.Saver\.restore\(tf\.Session\(\), last_model\) 11 12 TypeError: restore\(\) missing 1 required positional argument: 'save_path'

###該当のソースコード

python

import tensorflow as tf ckpt = tf\.train\.get_checkpoint_state\("セーブしているフォルダPATH"\) last_model = ckpt\.model_checkpoint_path print \("load " \+ last_model\) tf\.train\.Saver\.restore\(tf\.Session\(\), last_model\) import numpy new_input = numpy\.loadtxt\(open\("新しい入力として保存されたCSV"\), delimiter=","\) new_input = numpy\.hsplit\(raw_input, \[1\]\) prediction = tf\.argmax\(tensor,1\) print\("result: %g"%prediction\.eval\(feed_dict={score: new_input, keep_prob: 1\.0}, session=sess\)\)

###保存したコード
念のため保存したファイル(復元させたい)も下記に記します。

python

import tensorflow as tf import numpy SCORE_SIZE = 12 HIDDEN_UNIT_SIZE = 70 TRAIN_DATA_SIZE = 1394 raw_input = numpy\.loadtxt\(open\("train\.csv"\), delimiter=","\) \[tensor, score\] = numpy\.hsplit\(raw_input, \[1\]\) \[tensor_train, tensor_test\] = numpy\.vsplit\(tensor, \[TRAIN_DATA_SIZE\]\) \[score_train, score_test\] = numpy\.vsplit\(score, \[TRAIN_DATA_SIZE\]\) def inference\(score_placeholder\): with tf\.name_scope\('hidden1'\) as scope: hidden1_weight = tf\.Variable\(tf\.truncated_normal\(\[SCORE_SIZE, HIDDEN_UNIT_SIZE\], stddev=0\.1\), name="hidden1_weight"\) hidden1_bias = tf\.Variable\(tf\.constant\(0\.1, shape=\[HIDDEN_UNIT_SIZE\]\), name="hidden1_bias"\) hidden1_output = tf\.nn\.relu\(tf\.matmul\(score_placeholder, hidden1_weight\) \+ hidden1_bias\) with tf\.name_scope\('output'\) as scope: output_weight = tf\.Variable\(tf\.truncated_normal\(\[HIDDEN_UNIT_SIZE, 1\], stddev=0\.1\), name="output_weight"\) output_bias = tf\.Variable\(tf\.constant\(0\.1, shape=\[1\]\), name="output_bias"\) output = tf\.matmul\(hidden1_output, output_weight\) \+ output_bias return tf\.nn\.l2_normalize\(output, 0\) def loss\(output, tensor_placeholder, loss_label_placeholder\): with tf\.name_scope\('loss'\) as scope: loss = tf\.nn\.l2_loss\(output - tf\.nn\.l2_normalize\(tensor_placeholder, 0\)\) tf\.summary\.scalar\('loss_label_placeholder', loss\) return loss def training\(loss\): with tf\.name_scope\('training'\) as scope: train_step = tf\.train\.GradientDescentOptimizer\(0\.01\)\.minimize\(loss\) return train_step with tf\.Graph\(\)\.as_default\(\): tensor_placeholder = tf\.placeholder\("float", \[None, 1\], name="tensor_placeholder"\) score_placeholder = tf\.placeholder\("float", \[None, SCORE_SIZE\], name="score_placeholder"\) loss_label_placeholder = tf\.placeholder\("string", name="loss_label_placeholder"\) feed_dict_train={ tensor_placeholder: tensor_train, score_placeholder: score_train, loss_label_placeholder: "loss_train" } feed_dict_test={ tensor_placeholder: tensor_test, score_placeholder: score_test, loss_label_placeholder: "loss_test" } output = inference\(score_placeholder\) loss = loss\(output, tensor_placeholder, loss_label_placeholder\) training_op = training\(loss\) summary_op = tf\.summary\.merge_all\(\) init = tf\.global_variables_initializer\(\) best_loss = float\("inf"\) with tf\.Session\(\) as sess: summary_writer = tf\.summary\.FileWriter\('data', graph_def=sess\.graph_def\) sess\.run\(init\) for step in range\(10000\): sess\.run\(training_op, feed_dict=feed_dict_train\) loss_test = sess\.run\(loss, feed_dict=feed_dict_test\) if loss_test < best_loss: best_loss = loss_test best_match = sess\.run\(output, feed_dict=feed_dict_test\) if step % 100 == 0: summary_str = sess\.run\(summary_op, feed_dict=feed_dict_test\) summary_str \+= sess\.run\(summary_op, feed_dict=feed_dict_train\) summary_writer\.add_summary\(summary_str, step\) print \(sess\.run\(tf\.nn\.l2_normalize\(tensor_placeholder, 0\), feed_dict=feed_dict_test\)\) print \(best_match\) saver=tf\.train\.Saver\(\) saver\.save\(sess,"保存先のPATH"\) print\('Saved a model\.'\) sess\.close\(\)

これに対する応答

WARNING:tensorflow:Passing a `GraphDef` to the SummaryWriter is deprecated\. Pass a `Graph` object instead, such as `sess\.graph`\. \[\[ 0\.44367826\] \[ 0\.35494262\] \[ 0\.5324139 \] \[ 0\.62114954\] \[ 0\.08873565\]\] \[\[ 0\.5225144 \] \[ 0\.37568542\] \[ 0\.38035378\] \[ 0\.52725464\] \[ 0\.40394634\]\] Saved a model\.

###補足情報(言語/FW/ツール等のバージョンなど)
使用しているバージョンは
Windows10
python3.5
tensorflow1.0
anaconda3
です。

良い質問の評価を上げる

以下のような質問は評価を上げましょう

  • 質問内容が明確
  • 自分も答えを知りたい
  • 質問者以外のユーザにも役立つ

評価が高い質問は、TOPページの「注目」タブのフィードに表示されやすくなります。

気になる質問をクリップする

クリップした質問は、後からいつでもマイページで確認できます。

またクリップした質問に回答があった際、通知やメールを受け取ることができます。

teratailでは下記のような質問を「具体的に困っていることがない質問」、「サイトポリシーに違反する質問」と定義し、推奨していません。

  • プログラミングに関係のない質問
  • やってほしいことだけを記載した丸投げの質問
  • 問題・課題が含まれていない質問
  • 意図的に内容が抹消された質問
  • 過去に投稿した質問と同じ内容の質問
  • 広告と受け取られるような投稿

評価を下げると、トップページの「アクティブ」「注目」タブのフィードに表示されにくくなります。

まだ回答がついていません

会員登録して回答してみよう

15分調べてもわからないことは
teratailで質問しよう!

ただいまの回答率
87.20%

質問をまとめることで
思考を整理して素早く解決

テンプレート機能で
簡単に質問をまとめる

質問する

関連した質問

同じタグがついた質問を見る

Python

Pythonは、コードの読みやすさが特徴的なプログラミング言語の1つです。 強い型付け、動的型付けに対応しており、後方互換性がないバージョン2系とバージョン3系が使用されています。 商用製品の開発にも無料で使用でき、OSだけでなく仮想環境にも対応。Unicodeによる文字列操作をサポートしているため、日本語処理も標準で可能です。