質問編集履歴

1

情報の追加

2017/04/23 12:59

投稿

退会済みユーザー
test CHANGED
File without changes
test CHANGED
@@ -29,3 +29,151 @@
29
29
  グラフは上記の図のことで、
30
30
 
31
31
  なのでグラフの反映という意味が余計にわかりません。
32
+
33
+ コード全体は
34
+
35
+ ```python
36
+
37
+ # coding: utf-8
38
+
39
+
40
+
41
+ from __future__ import absolute_import
42
+
43
+ from __future__ import division
44
+
45
+ from __future__ import print_function
46
+
47
+
48
+
49
+ import glob
50
+
51
+ import tensorflow as tf
52
+
53
+ from reader import Cifar10Reader
54
+
55
+ import numpy as np
56
+
57
+
58
+
59
FLAGS = tf.app.flags.FLAGS

# Directory that is searched for the graph files to evaluate.
# Defaults to None: no directory unless one is passed on the command line.
tf.app.flags.DEFINE_string(
    'graph_dir',
    None,
    "処理するグラフファイルのあるパス")

# Path of the test-data file handed to the reader for evaluation.
tf.app.flags.DEFINE_string(
    'test_data',
    './data/test_batch.bin',
    "テストデータのパス")
64
+
65
+
66
+
67
def eval(graph_file, num_examples=10000):
    """Evaluate one serialized graph on the test data and print its accuracy.

    NOTE: this function shadows the builtin ``eval``; the name is kept so
    existing callers keep working.

    The GraphDef in *graph_file* is imported into a fresh default graph.
    Each test example is fed to the tensor ``input_image:0``, and the
    function counts how often the example's label is the top-1 prediction
    of ``output/logits:0``.

    Args:
        graph_file: path to a serialized ``tf.GraphDef`` file.
        num_examples: number of examples to read from the reader, using
            indices ``0 .. num_examples - 1``. Defaults to 10000, the value
            that was previously hard-coded. Must be positive.

    Returns:
        The fraction of examples whose label was the top-1 prediction.
        The same value is printed as ``"<graph_file>,<accuracy>"`` with
        two decimals.

    Raises:
        ValueError: if ``num_examples`` is not positive.
    """
    if num_examples <= 0:
        raise ValueError('num_examples must be positive, got %r' % (num_examples,))

    # Start from an empty graph so tensors from a previously evaluated file
    # do not clash with the ones imported below.
    tf.reset_default_graph()

    with tf.gfile.FastGFile(graph_file, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    # name='' imports without the default "import/" prefix, so the tensor
    # names 'output/logits:0' and 'input_image:0' resolve as written.
    tf.import_graph_def(graph_def, name='')

    # One label per run. This assumes the graph's logits have a batch size
    # of 1 -- TODO confirm against how the graph was exported.
    labels = tf.placeholder(tf.int32, shape=[1], name='label')
    logits = tf.get_default_graph().get_tensor_by_name('output/logits:0')
    # Boolean per example: True when the label is the top-1 prediction.
    top_k_op = tf.nn.in_top_k(logits, labels, 1)

    with tf.Session() as sess:
        image_reader = Cifar10Reader(FLAGS.test_data)
        # try/finally so the reader is closed even when read() or sess.run()
        # raises, for example for a graph without an 'input_image:0' tensor.
        try:
            true_count = 0
            for index in range(num_examples):
                image = image_reader.read(index)
                # presumably image.label is a length-1 int array matching the
                # [1]-shaped placeholder -- verify against Cifar10Reader.
                predictions = sess.run(
                    top_k_op,
                    feed_dict={
                        'input_image:0': image.image,
                        labels: image.label,
                    })
                # Summing the boolean result counts the correct predictions.
                true_count += np.sum(predictions)

            accuracy = true_count / float(num_examples)
            print('%s,%.2f' % (graph_file, accuracy))
        finally:
            image_reader.close()

    return accuracy
124
+
125
+
126
+
127
if __name__ == '__main__':
    # Local imports: only the script entry point needs them.
    import os
    import sys

    # --graph_dir defaults to None. Without this check, None + '/*.pb' would
    # fail with an opaque TypeError.
    if FLAGS.graph_dir is None:
        sys.exit('--graph_dir is required')

    # glob returns paths in arbitrary order, so sort them for a stable,
    # reproducible sequence of result lines.
    pattern = os.path.join(FLAGS.graph_dir, '*.pb')
    for graph_file in sorted(glob.glob(pattern)):
        eval(graph_file)
134
+
135
+ ```
136
+
137
+ で、
138
+
139
+ 疑問に思っている箇所は
140
+
141
+ ```python
142
+
143
with tf.Session() as sess:
    # NOTE(review): indentation reconstructed from a flattened paste. The
    # print/close at the end are assumed to run once, after the loop.

    # Reader over the file given by --test_data (default './data/test_batch.bin').
    image_reader = Cifar10Reader(FLAGS.test_data)

    # Number of examples whose label was the top-1 prediction.
    true_count = 0
    # One example per iteration. Assumes the test file holds 10000 examples
    # (the size of CIFAR-10's test batch) -- TODO confirm.
    for index in range(0,10000):
        # The returned object exposes .image and .label, as used below.
        image = image_reader.read(index)

        # Runs the in_top_k op built in eval(): the image goes to the imported
        # graph's 'input_image:0' tensor, and the label goes to the `labels`
        # placeholder. The result is a boolean per example: True when the
        # label is the top-1 prediction.
        predictions = sess.run(
            top_k_op,
            feed_dict={
                'input_image:0':image.image,
                labels:image.label,}
        )
        # Summing the boolean result adds 1 for each correct prediction.
        true_count +=np.sum(predictions)

    # Prints "<graph file>,<accuracy>", where accuracy = correct / 10000.
    print('%s,%.2f'%(graph_file,(true_count/10000.0)))
    image_reader.close()
176
+
177
+ ```
178
+
179
+ この部分です。