AttributeError: 'Cifar10Reader' object has no attribute 'bytestream' のエラーが出ました。
reader.py
# -*- coding: utf-8 -*- from __future__ import absolute_import from __future__ import division from __future__ import print_function import os import numpy as np class Cifar10Record(object): width = 32 height = 32 depth = 3 def set_label(self,label_byte): #self.label = np.frombuffer(label_byte,dtype=np.unit8) self.label = np.frombuffer(label_byte,dtype=np.uint8) # unit8 -> uint8に修正 def set_image(self,image_bytes): byte_buffer = np.frombuffer(image_bytes,dtype=np.int8) reshaped_array = np.reshape(byte_buffer,[self.depth,self.height,self.width]) self.byte_array = np.transpose(reshaped_array,[1,2,0]) self.byte_array = self.byte_array.astype(np.float32) class Cifar10Reader(object): def __init__(self,filename): if not os.path.exists(filename): print(filename + ' is not exist') return self.bytestream = open(filename,mode="rb") def close(self): if not self.bytestream: self.bytestream.close() def read(self,index): result = Cifar10Record() label_bytes = 1 image_bytes = result.height * result.width * result.depth record_bytes = label_bytes + image_bytes self.bytestream.seek(record_bytes * index,0) result.set_label(self.bytestream.read(label_bytes)) result.set_image(self.bytestream.read(image_bytes)) return result print(self.bytestream) # 追加 reader = Cifar10Reader("lena_std.tif") ret = reader.read(0) print(ret)
というCIFAR-10形式のデータセットを読み込むプログラムを書きました。
参考url:
http://www.buildinsider.net/small/booktensorflow/0201
他に使用しているスクリプトは
png10.py
# coding: utf-8 from __future__ import absolute_import from __future__ import division from __future__ import print_function import os import numpy as np import tensorflow as tf from PIL import Image from reader import Cifar10Reader FLAGS = tf.app.flags.FLAGS tf.app.flags.DEFINE_string('file',None,"処理するファイルのパス") tf.app.flags.DEFINE_integer('offset',0,"読み飛ばすレコード数") tf.app.flags.DEFINE_integer('length',16,"読み込んで変換するレコード数") basename = os.path.basename(FLAGS.file) path = os.path.dirname(FLAGS.file) reader = Cifar10Reader(FLAGS.file) stop = FLAGS.offset + FLAGS.length for index in range(FLAGS.offset,stop): image = reader.read(index) print('label: %d' % image.label) imageshow = Image.fromarray(image.byte_array.astype(np.unit8)) file_name = '%s-%02d-%d.png' % (basename,index,image.label) file = os.path.join(path,file_name) with open(file,mode='wb') as out: imageshow.save(out,format='png') reader.close()
model.py
# coding: utf-8
"""Small CIFAR-10 style CNN: conv1 -> pool1 -> conv2 -> pool2 -> fc3 -> fc4 -> logits.

Variable names (e.g. ``conv1/weights``, ``fc3/biases``, ``output/biases``)
come from the variable-scope names below; keep them stable so existing
checkpoints still load.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

NUM_CLASSES = 10


def _get_weights(shape, stddev=1.0):
    """Create a 'weights' variable in the current scope, truncated-normal initialized."""
    return tf.get_variable(
        'weights', shape,
        initializer=tf.truncated_normal_initializer(stddev=stddev))


def _get_biases(shape, value=0.0):
    """Create a 'biases' variable in the current scope, filled with ``value``."""
    return tf.get_variable(
        'biases', shape, initializer=tf.constant_initializer(value))


def _conv_relu(inputs, in_channels, name):
    """Apply a 5x5 stride-1 SAME conv to 64 channels, then bias and ReLU, in scope ``name``."""
    with tf.variable_scope(name) as scope:
        weights = _get_weights(shape=[5, 5, in_channels, 64], stddev=1e-4)
        conv = tf.nn.conv2d(inputs, weights, [1, 1, 1, 1], padding='SAME')
        biases = _get_biases([64], value=0.1)
        bias = tf.nn.bias_add(conv, biases)
        return tf.nn.relu(bias, name=scope.name)


def _max_pool(inputs, name):
    """Apply a 3x3 max pool with stride 2 and SAME padding."""
    return tf.nn.max_pool(inputs, ksize=[1, 3, 3, 1], strides=[1, 2, 2, 1],
                          padding='SAME', name=name)


def _fc_relu(inputs, in_dim, out_dim, name):
    """Apply a fully connected layer with bias and ReLU in scope ``name``."""
    with tf.variable_scope(name) as scope:
        weights = _get_weights(shape=[in_dim, out_dim], stddev=0.04)
        biases = _get_biases([out_dim], value=0.1)
        return tf.nn.relu(tf.matmul(inputs, weights) + biases, name=scope.name)


def inference(image_node):
    """Build the network and return the logits tensor.

    Args:
        image_node: 4-D float tensor laid out as [batch, height, width, 3]
            (NHWC, the conv2d default). Height and width must be statically
            known so the flattened feature size can be computed.

    Returns:
        A [batch, NUM_CLASSES] tensor named ``output/logits`` (unscaled scores).
    """
    conv1 = _conv_relu(image_node, 3, 'conv1')
    pool1 = _max_pool(conv1, 'pool1')
    conv2 = _conv_relu(pool1, 64, 'conv2')
    pool2 = _max_pool(conv2, 'pool2')

    # Flatten per example. Use -1 for the batch axis instead of hard-coding 1,
    # so batches larger than one work; the result is identical for batch 1.
    dim = pool2.get_shape()[1:].num_elements()
    reshape = tf.reshape(pool2, [-1, dim])

    fc3 = _fc_relu(reshape, dim, 384, 'fc3')
    fc4 = _fc_relu(fc3, 384, 192, 'fc4')

    # Final linear layer: no ReLU, so the logits can be negative.
    with tf.variable_scope('output'):
        weights = _get_weights(shape=[192, NUM_CLASSES], stddev=1 / 192.0)
        biases = _get_biases([NUM_CLASSES], value=0.0)
        logits = tf.add(tf.matmul(fc4, weights), biases, name='logits')
    return logits
inference.py
# coding: utf-8 from __future__ import absolute_import from __future__ import division from __future__ import print_function import os import time import tensorflow as tf import model as model from reader import Cifar10Reader FLAGS = tf.app.flags.FLAGS tf.app.flags.DEFINE_integer('epoch',30,"訓練するEpoch数") tf.app.flags.DEFINE_string('data_dir','./data/',"訓練データのディレクトリ") tf.app.flags.DEFINE_string('checkpoint_dir','./checkpoints/',"チェックポイントを保存するディレクトリ") filenames = [ os.path.join(FLAGS.data_dir,'data_batch_%d.bin' % i) for i in range(1,6) ] def main(argv=None): train_placeholder = tf.placeholder(tf.float32,shape=[32,32,3],name='input_image') image_node = tf.expand_dims(train_placeholder,0) logits = model.inference(image_node) with tf.Session() as sess: sess.run(tf.initialize_all_variables()) total_duration = 0 for epoch in range(1,FLAGS.epoch+1): start_time = time.time() for file_index in range(5): print('Epoch %d: %s' % (epoch,filenames[file_index])) reader = Cifar10Reader(filenames[file_index]) for index in range(10000): image = reader.read(index) logits_value = sess.run([logits],feed_dict={ train_placeholder:image.byte_array, }) if index % 1000 ==0: print('[%d]: %r'% (image.label,logits_value)) reader.close() duration = time.time() - start_time total_duration += duration print('epoch %d duration = %d sec'%(epoch,duration)) tf.train.SummaryWriter(FLAGS.checkpoint_dir,sess.graph) print('Total duration = %d sec'% total_duration) if __name__ == '__main__': tf.app.run()
しかし、inference.py を実行すると
  File "inference.py", line 44, in main
    image = reader.read(index)
  File "/Users/XXX/Desktop/cifar/reader.py", line 43, in read
    self.bytestream.seek(record_bytes * index,0)
AttributeError: 'Cifar10Reader' object has no attribute 'bytestream'
というエラーが出てしまいました。
Cifar10Reader にbytestream を引数として持たせなければならない、という意味ですよね?でも参考urlの書き方ではCifar10Readerクラスの引数にbytestream を指定していません。
どう直せば良いのでしょうか?
回答1件
あなたの回答
tips
プレビュー
バッドをするには、ログインかつ
こちらの条件を満たす必要があります。
退会済みユーザー
2017/04/09 06:53
2017/04/10 02:25
退会済みユーザー
2017/04/11 11:41
2017/04/11 11:51
退会済みユーザー
2017/04/11 12:49