質問編集履歴
1
情報の追加
test
CHANGED
File without changes
|
test
CHANGED
@@ -29,3 +29,151 @@
|
|
29
29
|
グラフは上記の図のことで、
|
30
30
|
|
31
31
|
なのでグラフの反映という意味が余計にわかりません。
|
32
|
+
|
33
|
+
コード全体は
|
34
|
+
|
35
|
+
```python
|
36
|
+
|
37
|
+
# coding: utf-8
|
38
|
+
|
39
|
+
|
40
|
+
|
41
|
+
from __future__ import absolute_import
|
42
|
+
|
43
|
+
from __future__ import division
|
44
|
+
|
45
|
+
from __future__ import print_function
|
46
|
+
|
47
|
+
|
48
|
+
|
49
|
+
import glob
|
50
|
+
|
51
|
+
import tensorflow as tf
|
52
|
+
|
53
|
+
from reader import Cifar10Reader
|
54
|
+
|
55
|
+
import numpy as np
|
56
|
+
|
57
|
+
|
58
|
+
|
59
|
+
FLAGS = tf.app.flags.FLAGS
|
60
|
+
|
61
|
+
tf.app.flags.DEFINE_string('graph_dir',None,"処理するグラフファイルのあるパス")
|
62
|
+
|
63
|
+
tf.app.flags.DEFINE_string('test_data','./data/test_batch.bin',"テストデータのパス")
|
64
|
+
|
65
|
+
|
66
|
+
|
67
|
+
def eval(graph_file):
|
68
|
+
|
69
|
+
tf.reset_default_graph()
|
70
|
+
|
71
|
+
|
72
|
+
|
73
|
+
with tf.gfile.FastGFile(graph_file,'rb') as f:
|
74
|
+
|
75
|
+
graph_def = tf.GraphDef()
|
76
|
+
|
77
|
+
graph_def.ParseFromString(f.read())
|
78
|
+
|
79
|
+
_=tf.import_graph_def(graph_def,name='')
|
80
|
+
|
81
|
+
|
82
|
+
|
83
|
+
labels = tf.placeholder(tf.int32,shape=[1],name='label')
|
84
|
+
|
85
|
+
logits = tf.get_default_graph().get_tensor_by_name('output/logits:0')
|
86
|
+
|
87
|
+
top_k_op = tf.nn.in_top_k(logits,labels,1)
|
88
|
+
|
89
|
+
|
90
|
+
|
91
|
+
with tf.Session() as sess:
|
92
|
+
|
93
|
+
image_reader = Cifar10Reader(FLAGS.test_data)
|
94
|
+
|
95
|
+
|
96
|
+
|
97
|
+
true_count = 0
|
98
|
+
|
99
|
+
for index in range(0,10000):
|
100
|
+
|
101
|
+
image = image_reader.read(index)
|
102
|
+
|
103
|
+
|
104
|
+
|
105
|
+
predictions = sess.run(
|
106
|
+
|
107
|
+
top_k_op,
|
108
|
+
|
109
|
+
feed_dict={
|
110
|
+
|
111
|
+
'input_image:0':image.image,
|
112
|
+
|
113
|
+
labels:image.label,}
|
114
|
+
|
115
|
+
)
|
116
|
+
|
117
|
+
true_count +=np.sum(predictions)
|
118
|
+
|
119
|
+
|
120
|
+
|
121
|
+
print('%s,%.2f'%(graph_file,(true_count/10000.0)))
|
122
|
+
|
123
|
+
image_reader.close()
|
124
|
+
|
125
|
+
|
126
|
+
|
127
|
+
if __name__ == '__main__':
|
128
|
+
|
129
|
+
file_list = glob.glob(FLAGS.graph_dir+'/*.pb')
|
130
|
+
|
131
|
+
for file in file_list:
|
132
|
+
|
133
|
+
eval(file)
|
134
|
+
|
135
|
+
```
|
136
|
+
|
137
|
+
で、
|
138
|
+
|
139
|
+
疑問に思っている箇所は
|
140
|
+
|
141
|
+
```python
|
142
|
+
|
143
|
+
with tf.Session() as sess:
|
144
|
+
|
145
|
+
image_reader = Cifar10Reader(FLAGS.test_data)
|
146
|
+
|
147
|
+
|
148
|
+
|
149
|
+
true_count = 0
|
150
|
+
|
151
|
+
for index in range(0,10000):
|
152
|
+
|
153
|
+
image = image_reader.read(index)
|
154
|
+
|
155
|
+
|
156
|
+
|
157
|
+
predictions = sess.run(
|
158
|
+
|
159
|
+
top_k_op,
|
160
|
+
|
161
|
+
feed_dict={
|
162
|
+
|
163
|
+
'input_image:0':image.image,
|
164
|
+
|
165
|
+
labels:image.label,}
|
166
|
+
|
167
|
+
)
|
168
|
+
|
169
|
+
true_count +=np.sum(predictions)
|
170
|
+
|
171
|
+
|
172
|
+
|
173
|
+
print('%s,%.2f'%(graph_file,(true_count/10000.0)))
|
174
|
+
|
175
|
+
image_reader.close()
|
176
|
+
|
177
|
+
```
|
178
|
+
|
179
|
+
この部分です。
|