下記URLのサイトにある4つのmnistデータをk-nnで識別しエラー率を算出したいと考えています。他の方の方法を参考にコードを書いてみたのですがエラーが出てしまいました。アドバイス頂けると幸いです。
以下mnistデータの読み込み及び識別プログラム
Python
1import struct 2import numpy as np 3 4 5class MNIST: 6 def __init__(self, LT): 7 8 if LT == 'L': 9 self.fnLabel = 'train-labels.idx1-ubyte' 10 self.fnImage = 'train-images.idx3-ubyte' 11 else: 12 self.fnLabel = 't10k-labels.idx1-ubyte' 13 self.fnImage = 't10k-images.idx3-ubyte' 14 15 def getLabel(self): 16 17 return readLabel ( self.fnLabel ) 18 19 def getImage(self): 20 return readImage(self.fnImage) 21 22 23##### reading the label file 24# 25def readLabel(fnLabel): 26 f = open ( fnLabel, 'r' ) 27 28 ### header (two 4B integers, magic number(2049) & number of items) 29 # 30 header = f.read ( 8 ) 31 mn, num = struct.unpack ( '>2i', header ) # MSB first (bigendian) 32 assert mn == 2049 33 # print mn, num 34 35 ### labels (unsigned byte) 36 # 37 label = np.array ( struct.unpack ( '>%dB' % num, f.read ( ) ), dtype=int ) 38 39 f.close ( ) 40 41 return label 42 43 44##### reading the image file 45# 46def readImage(fnImage): 47 f = open ( fnImage, 'r' ) 48 49 ### header (four 4B integers, magic number(2051), #images, #rows, and #cols 50 # 51 header = f.read ( 16 ) 52 mn, num, nrow, ncol = struct.unpack ( '>4i', header ) # MSB first (bigendian) 53 assert mn == 2051 54 # print mn, num, nrow, ncol 55 56 ### pixels (unsigned byte) 57 # 58 pixel = np.empty ( (num, nrow, ncol) ) 59 npixel = nrow * ncol 60 for i in range ( num ): 61 buf=struct.unpack('>%dB'%npixel,f.read(npixel)) 62 pixel[ i, :, : ] = np.asarray ( buf ).reshape ( (nrow, ncol) ) 63 64 f.close ( ) 65 66 return pixel 67 68 69if __name__ == '__main__': 70 print '# MNIST training data' 71 mnist = MNIST ( 'L' ) 72 lab = mnist.getLabel ( ) 73 dat = mnist.getImage ( ) 74 print lab.shape, dat.shape 75 76 print '# MNIST test data' 77 mnist = MNIST ( 'T' ) 78 lab = mnist.getLabel ( ) 79 dat = mnist.getImage ( ) 80 print lab.shape, dat.shape
以下エラー率の算出
Python
1import numpy as np 2import mnist1 as mnist 3 4##### training data 5# 6mn = mnist.MNIST('L') 7labL = mn.getLabel() 8nL = labL.shape[0] 9xL = mn.getImage().reshape((nL, -1)) 10print '# xL: ', xL.shape 11 12##### test data 13# 14mn = mnist.MNIST('T') 15labT = mn.getLabel() 16nT = labT.shape[0] 17xT = mn.getImage().reshape((nT, -1)) 18print '# xT: ', xT.shape 19 20##### nearest neighbor classification 21# 22xLsq = np.sum(xL ** 2, axis=1) 23out = np.empty(nT, dtype=int) 24for i in range(nT): 25 if i % 1000 == 0: 26 print i 27 d = -2 * np.dot(xL, xT[i, :]) + xLsq 28 out[i] = labL[np.argmin(d)] 29 30er = np.sum(out != labT) / float(nT) 31 32print '# test error rate = ', er * 100, '%'
下記プログラムのエラー
Traceback (most recent call last):
File "C:/.../mnisterror.py", line 9, in <module>
xL = mn.getImage().reshape((nL, -1))
File "C:...\mnist1.py", line 20, in getImage
return readImage(self.fnImage)
File "C:...\mnist1.py", line 61, in readImage
buf=struct.unpack('>%dB'%npixel,f.read(npixel))
struct.error: unpack requires a string argument of length 784
回答1件
あなたの回答
tips
プレビュー
バッドをするには、ログインかつ
こちらの条件を満たす必要があります。
2016/10/31 23:08
2016/11/01 05:17
2016/11/01 09:20