回答編集履歴

4

実装コード追加

2018/01/04 07:57

投稿

退会済みユーザー
test CHANGED
@@ -39,3 +39,187 @@
39
39
  saver.restore(sess,model_path)
40
40
 
41
41
  ```
42
+
43
+
44
+
45
+ ---
46
+
47
+ **2018-01-04 16:55 追記**
48
+
49
+
50
+
51
+ 書き替えを適用するとこんな感じです。
52
+
53
+ 確認のため学習を10STEPくらいで止めたものを保存しました。
54
+
55
+ 結果、Accuracyは0.90弱でしたが、ロード後(以下のコードです)で学習を飛ばして判定させても結果は0.90弱になりました。
56
+
57
+
58
+
59
+ ロード後のイメージ写真を貼りますね。
60
+
61
+ ![イメージ説明](c8f6d9c207dbd27fd19bb82208f1ccd5.png)
62
+
63
+
64
+
65
+ ```Python
66
+
67
+ # coding: UTF-8
68
+
69
+ import tensorflow as tf
70
+
71
+ from tensorflow.examples.tutorials.mnist import input_data
72
+
73
+
74
+
75
+ # データ読み込み
76
+
77
+ mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
78
+
79
+
80
+
81
+ # placeholder用意 xは学習用画像
82
+
83
+ x = tf.placeholder(tf.float32, [None, 784])
84
+
85
+ # y_は学習用ラベル
86
+
87
+ y_ = tf.placeholder(tf.float32, [None, 10])
88
+
89
+
90
+
91
+ # weightとbias
92
+
93
+ # さっきの例ではw * xだったけど、今回はw * x + b
94
+
95
+ W = tf.Variable(tf.zeros([784, 10]),name='W')
96
+
97
+ b = tf.Variable(tf.zeros([10]),name='b')
98
+
99
+
100
+
101
+ # Softmax Regressionを使う yはモデル
102
+
103
+ y = tf.nn.softmax(tf.matmul(x, W) + b)
104
+
105
+
106
+
107
+ # 交差エントロピー
108
+
109
+ cross_entropy = -tf.reduce_sum(y_ * tf.log(y))
110
+
111
+
112
+
113
+ # 先ほど使ったGradientDescentOptimizerで、今回はcross_entropyを利用
114
+
115
+ train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)
116
+
117
+
118
+
119
+
120
+
121
+ # 初期化に関するコードの書き換え
122
+
123
+ # -------------------------------------------
124
+
125
+
126
+
127
+ # model_path = './model/model.ckpt'
128
+
129
+ # # 初期化
130
+
131
+ #init = tf.global_variables_initializer()
132
+
133
+ #sess = tf.Session()
134
+
135
+ #sess.run(init)
136
+
137
+
138
+
139
+ model_path = './model/model.ckpt'
140
+
141
+ # 初期化
142
+
143
+ init = tf.global_variables_initializer()
144
+
145
+ sess = tf.Session()
146
+
147
+ saver = tf.train.Saver()
148
+
149
+ saver.restore(sess,model_path)
150
+
151
+ sess.run(init)
152
+
153
+ # -------------------------------------------
154
+
155
+
156
+
157
+
158
+
159
+ # データ読み込み確認のため学習はコメントアウト
160
+
161
+ # -------------------------------------------
162
+
163
+ # 学習
164
+
165
+ #for i in range(1000):
166
+
167
+ # print(i)
168
+
169
+ # batch_xs, batch_ys = mnist.train.next_batch(100)
170
+
171
+ # sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
172
+
173
+ # if i == 99:
174
+
175
+ # saver = tf.train.Saver()
176
+
177
+ # saver.save(sess,model_path)
178
+
179
+ # print("model saved" + model_path)
180
+
181
+ # print(sess.run(W))
182
+
183
+ # print(sess.run(b))
184
+
185
+ # -------------------------------------------
186
+
187
+
188
+
189
+ # テストデータで予測
190
+
191
+ correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))
192
+
193
+ accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
194
+
195
+ acc = sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels})
196
+
197
+ print('accuracy : '+str(acc));
198
+
199
+ sess.close()
200
+
201
+
202
+
203
+ sess = tf.InteractiveSession()
204
+
205
+ init = tf.global_variables_initializer()
206
+
207
+ sess.run(init)
208
+
209
+
210
+
211
+ saver.restore(sess,model_path)
212
+
213
+ print(sess.run(W))
214
+
215
+ print(sess.run(b))
216
+
217
+ correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))
218
+
219
+ accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
220
+
221
+ acc = sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels})
222
+
223
+ print('accuracy : '+str(acc))
224
+
225
+ ```

3

2018/01/04 07:57

投稿

退会済みユーザー
test CHANGED
@@ -30,6 +30,12 @@
30
30
 
31
31
 
32
32
 
33
+ ckpt = tf.train.get_checkpoint_state('./model/')
34
+
35
+ print(ckpt)
36
+
37
+
38
+
33
39
  saver.restore(sess,model_path)
34
40
 
35
41
  ```

2

2018/01/04 03:26

投稿

退会済みユーザー
test CHANGED
@@ -16,13 +16,19 @@
16
16
 
17
17
 
18
18
 
19
- 問題はセッションの復旧だと思いますので、以下のようにしてはいかがでしょうか?
19
+ 問題はセッションの復旧だと思いますので、以下のようにして、メタグラフのコードをバッサリ削ってはいかがでしょうか?
20
20
 
21
21
 
22
22
 
23
23
  ```Python
24
24
 
25
+ # saver = tf.train.import_meta_graph('./model/model.ckpt.meta')
26
+
25
27
  saver = tf.train.Saver()
28
+
29
+ model_path = './model/model.ckpt'
30
+
31
+
26
32
 
27
33
  saver.restore(sess,model_path)
28
34
 

1

2018/01/04 03:25

投稿

退会済みユーザー
test CHANGED
@@ -1,8 +1,24 @@
1
- `saver = tf.train.import_meta_graph(...)`という書き方を初めて見ました。試していないのでいけませんが、ここが怪しい気がします。
1
+ `saver = tf.train.import_meta_graph(...)`という書き方を初めて見ました。
2
+
3
+ 試していないのでいけませんが、ここが怪しい気がします。
2
4
 
3
5
 
4
6
 
7
+ 気になって調べると、以下の関係がある事が分かります。
8
+
9
+ |操作対象|書き出し|読み出し|
10
+
11
+ |:--:|:--:|:--:|
12
+
13
+ |[**セッション**](https://www.tensorflow.org/api_docs/python/tf/train/Saver)|`saver.save(...)`|`saver.restore`|
14
+
15
+ |[メタグラフ](https://www.tensorflow.org/api_guides/python/meta_graph)|`export_meta_graph`|`import_meta_graph`|
16
+
17
+
18
+
5
- 多くサイト、以下のようにまとめられています。これで試してはいかがでしょうか?
19
+ 問題はセッション復旧だと思いますので、以下のようにしてはいかがでしょうか?
20
+
21
+
6
22
 
7
23
  ```Python
8
24