teratail header banner
teratail header banner
質問するログイン新規登録

回答編集履歴

4

実装コード追加

2018/01/04 07:57

投稿

退会済みユーザー
answer CHANGED
@@ -18,4 +18,96 @@
18
18
  print(ckpt)
19
19
 
20
20
  saver.restore(sess,model_path)
21
+ ```
22
+
23
+ ---
24
+ **2018-01-04 16:55 追記**
25
+
26
+ 書き替えを適用するとこんな感じです。
27
+ 確認のため学習を10STEPくらいで止めたものを保存しました。
28
+ 結果、Accuracyは0.90弱でしたが、ロード後(以下のコードです)で学習を飛ばして判定させても結果は0.90弱になりました。
29
+
30
+ ロード後のイメージ写真を貼りますね。
31
+ ![イメージ説明](c8f6d9c207dbd27fd19bb82208f1ccd5.png)
32
+
33
+ ```Python
34
+ # coding: UTF-8
35
+ import tensorflow as tf
36
+ from tensorflow.examples.tutorials.mnist import input_data
37
+
38
+ # データ読み込み
39
+ mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
40
+
41
+ # placeholder用意 xは学習用画像
42
+ x = tf.placeholder(tf.float32, [None, 784])
43
+ # y_は学習用ラベル
44
+ y_ = tf.placeholder(tf.float32, [None, 10])
45
+
46
+ # weightとbias
47
+ # さっきの例ではw * xだったけど、今回はw * x + b
48
+ W = tf.Variable(tf.zeros([784, 10]),name='W')
49
+ b = tf.Variable(tf.zeros([10]),name='b')
50
+
51
+ # Softmax Regressionを使う yはモデル
52
+ y = tf.nn.softmax(tf.matmul(x, W) + b)
53
+
54
+ # 交差エントロピー
55
+ cross_entropy = -tf.reduce_sum(y_ * tf.log(y))
56
+
57
+ # 先ほど使ったGradientDescentOptimizerで、今回はcross_entropyを利用
58
+ train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)
59
+
60
+
61
+ # 初期化に関するコードの書き換え
62
+ # -------------------------------------------
63
+
64
+ # model_path = './model/model.ckpt'
65
+ # # 初期化
66
+ #init = tf.global_variables_initializer()
67
+ #sess = tf.Session()
68
+ #sess.run(init)
69
+
70
+ model_path = './model/model.ckpt'
71
+ # 初期化
72
+ init = tf.global_variables_initializer()
73
+ sess = tf.Session()
74
+ saver = tf.train.Saver()
75
+ saver.restore(sess,model_path)
76
+ sess.run(init)
77
+ # -------------------------------------------
78
+
79
+
80
+ # データ読み込み確認のため学習はコメントアウト
81
+ # -------------------------------------------
82
+ # 学習
83
+ #for i in range(1000):
84
+ # print(i)
85
+ # batch_xs, batch_ys = mnist.train.next_batch(100)
86
+ # sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
87
+ # if i == 99:
88
+ # saver = tf.train.Saver()
89
+ # saver.save(sess,model_path)
90
+ # print("model saved" + model_path)
91
+ # print(sess.run(W))
92
+ # print(sess.run(b))
93
+ # -------------------------------------------
94
+
95
+ # テストデータで予測
96
+ correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))
97
+ accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
98
+ acc = sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels})
99
+ print('accuracy : '+str(acc));
100
+ sess.close()
101
+
102
+ sess = tf.InteractiveSession()
103
+ init = tf.global_variables_initializer()
104
+ sess.run(init)
105
+
106
+ saver.restore(sess,model_path)
107
+ print(sess.run(W))
108
+ print(sess.run(b))
109
+ correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))
110
+ accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
111
+ acc = sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels})
112
+ print('accuracy : '+str(acc))
21
113
  ```

3

2018/01/04 07:57

投稿

退会済みユーザー
answer CHANGED
@@ -14,5 +14,8 @@
14
14
  saver = tf.train.Saver()
15
15
  model_path = './model/model.ckpt'
16
16
 
17
+ ckpt = tf.train.get_checkpoint_state('./model/')
18
+ print(ckpt)
19
+
17
20
  saver.restore(sess,model_path)
18
21
  ```

2

2018/01/04 03:26

投稿

退会済みユーザー
answer CHANGED
@@ -7,9 +7,12 @@
7
7
  |[**セッション**](https://www.tensorflow.org/api_docs/python/tf/train/Saver)|`saver.save(...)`|`saver.restore`|
8
8
  |[メタグラフ](https://www.tensorflow.org/api_guides/python/meta_graph)|`export_meta_graph`|`import_meta_graph`|
9
9
 
10
- 問題はセッションの復旧だと思いますので、以下のようにしてはいかがでしょうか?
10
+ 問題はセッションの復旧だと思いますので、以下のようにして、メタグラフのコードをバッサリ削ってはいかがでしょうか?
11
11
 
12
12
  ```Python
13
+ # saver = tf.train.import_meta_graph('./model/model.ckpt.meta')
13
14
  saver = tf.train.Saver()
15
+ model_path = './model/model.ckpt'
16
+
14
17
  saver.restore(sess,model_path)
15
18
  ```

1

2018/01/04 03:25

投稿

退会済みユーザー
answer CHANGED
@@ -1,6 +1,14 @@
1
- `saver = tf.train.import_meta_graph(...)`という書き方を初めて見ました。試していないのでいけませんが、ここが怪しい気がします。
1
+ `saver = tf.train.import_meta_graph(...)`という書き方を初めて見ました。
2
+ 試していないのでいけませんが、ここが怪しい気がします。
2
3
 
4
+ 気になって調べると、以下の関係がある事が分かります。
5
+ |操作対象|書き出し|読み出し|
6
+ |:--:|:--:|:--:|
7
+ |[**セッション**](https://www.tensorflow.org/api_docs/python/tf/train/Saver)|`saver.save(...)`|`saver.restore`|
8
+ |[メタグラフ](https://www.tensorflow.org/api_guides/python/meta_graph)|`export_meta_graph`|`import_meta_graph`|
9
+
3
- 多くサイト、以下のようにまとめられています。これで試してはいかがでしょうか?
10
+ 問題はセッション復旧だと思いますので、以下のようにしてはいかがでしょうか?
11
+
4
12
  ```Python
5
13
  saver = tf.train.Saver()
6
14
  saver.restore(sess,model_path)