Python 2.7.12, mnistデータをk-nnで識別しエラー率を算出する方法

解決済

回答 1

投稿

  • 評価
  • クリップ 0
  • VIEW 1,689

3naoki

score 10

下記URLのサイトにある4つのmnistデータをk-nnで識別しエラー率を算出したいと考えています。他の方の方法を参考にコードを書いてみたのですがエラーが出てしまいました。アドバイス頂けると幸いです。

http://yann.lecun.com/exdb/mnist/

以下mnistデータの読み込み及び識別プログラム

import struct
import numpy as np


class MNIST:
    def __init__(self, LT):

        if LT == 'L':
            self.fnLabel = 'train-labels.idx1-ubyte'
            self.fnImage = 'train-images.idx3-ubyte'
        else:
            self.fnLabel = 't10k-labels.idx1-ubyte'
            self.fnImage = 't10k-images.idx3-ubyte'

    def getLabel(self):

        return readLabel ( self.fnLabel )

    def getImage(self):
        return readImage(self.fnImage)


##### reading the label file
#
def readLabel(fnLabel):
    f = open ( fnLabel, 'r' )

    ### header (two 4B integers, magic number(2049) & number of items)
    #
    header = f.read ( 8 )
    mn, num = struct.unpack ( '>2i', header )  # MSB first (bigendian)
    assert mn == 2049
    # print mn, num

    ### labels (unsigned byte)
    #
    label = np.array ( struct.unpack ( '>%dB' % num, f.read ( ) ), dtype=int )

    f.close ( )

    return label


##### reading the image file
#
def readImage(fnImage):
    f = open ( fnImage, 'r' )

    ### header (four 4B integers, magic number(2051), #images, #rows, and #cols
    #
    header = f.read ( 16 )
    mn, num, nrow, ncol = struct.unpack ( '>4i', header )  # MSB first (bigendian)
    assert mn == 2051
    # print mn, num, nrow, ncol

    ### pixels (unsigned byte)
    #
    pixel = np.empty ( (num, nrow, ncol) )
    npixel = nrow * ncol
    for i in range ( num ):
        buf=struct.unpack('>%dB'%npixel,f.read(npixel))
        pixel[ i, :, : ] = np.asarray ( buf ).reshape ( (nrow, ncol) )

    f.close ( )

    return pixel


if __name__ == '__main__':
    print '# MNIST training data'
    mnist = MNIST ( 'L' )
    lab = mnist.getLabel ( )
    dat = mnist.getImage ( )
    print lab.shape, dat.shape

    print '# MNIST test data'
    mnist = MNIST ( 'T' )
    lab = mnist.getLabel ( )
    dat = mnist.getImage ( )
    print lab.shape, dat.shape

以下エラー率の算出

import numpy as np
import mnist1 as mnist

##### training data
#
mn = mnist.MNIST('L')
labL = mn.getLabel()
nL = labL.shape[0]
xL = mn.getImage().reshape((nL, -1))
print '# xL: ', xL.shape

##### test data
#
mn = mnist.MNIST('T')
labT = mn.getLabel()
nT = labT.shape[0]
xT = mn.getImage().reshape((nT, -1))
print '# xT: ', xT.shape

##### nearest neighbor classification
#
xLsq = np.sum(xL ** 2, axis=1)
out = np.empty(nT, dtype=int)
for i in range(nT):
    if i % 1000 == 0:
        print i
    d = -2 * np.dot(xL, xT[i, :]) + xLsq
    out[i] = labL[np.argmin(d)]

er = np.sum(out != labT) / float(nT)

print '# test error rate = ', er * 100, '%'

下記プログラムのエラー

Traceback (most recent call last):
  File "C:/.../mnisterror.py", line 9, in <module>
    xL = mn.getImage().reshape((nL, -1))
  File "C:\...\mnist1.py", line 20, in getImage
    return readImage(self.fnImage)
  File "C:\...\mnist1.py", line 61, in readImage
    buf=struct.unpack('>%dB'%npixel,f.read(npixel))
struct.error: unpack requires a string argument of length 784

  • 気になる質問をクリップする

    クリップした質問は、後からいつでもマイページで確認できます。

    またクリップした質問に回答があった際、通知やメールを受け取ることができます。

    クリップを取り消します

  • 良い質問の評価を上げる

    以下のような質問は評価を上げましょう

    • 質問内容が明確
    • 自分も答えを知りたい
    • 質問者以外のユーザにも役立つ

    評価が高い質問は、TOPページの「注目」タブのフィードに表示されやすくなります。

    質問の評価を上げたことを取り消します

  • 評価を下げられる数の上限に達しました

    評価を下げることができません

    • 1日5回まで評価を下げられます
    • 1日に1ユーザに対して2回まで評価を下げられます

    質問の評価を下げる

    teratailでは下記のような質問を「具体的に困っていることがない質問」、「サイトポリシーに違反する質問」と定義し、推奨していません。

    • プログラミングに関係のない質問
    • やってほしいことだけを記載した丸投げの質問
    • 問題・課題が含まれていない質問
    • 意図的に内容が抹消された質問
    • 広告と受け取られるような投稿

    評価が下がると、TOPページの「アクティブ」「注目」タブのフィードに表示されにくくなります。

    質問の評価を下げたことを取り消します

    この機能は開放されていません

    評価を下げる条件を満たしてません

    評価を下げる理由を選択してください

    詳細な説明はこちら

    上記に当てはまらず、質問内容が明確になっていない質問には「情報の追加・修正依頼」機能からコメントをしてください。

    質問の評価を下げる機能の利用条件

    この機能を利用するためには、以下の事項を行う必要があります。

回答 1

checkベストアンサー

0

unpack requires a string argument of length 784

エラーメッセージは、unpack するデータの長さが足りていないようです。

原因は恐らくファイルの open の mode
ファイルの open の第二引数を 'rb' として、バイナリーモードで開いてみてください。
デフォルトだとテキストモードで読み込まれ、windows では改行コードが変換されてしまいます。

struct.unpack('>%dB'%npixel, f.read(npixel))


ここで unpackのサイズを動的に指定していますが、ファイルがテキストモードで開かれている場合
len(f.read(npixel)) == npixel が真となるとは限りません。
struct.unpack が要求するデータ長を f.read が返さなかった時、上記のエラーとなります。

また、該当箇所のコードでは、unpackでPythonの数値のリストを一時的に生成し
numpy の配列を再構築といったことをしていますが、numpy で直接読み込むほうが効率良いので
より規模の大きなデータを扱う場合や、速度を求められる場合等は、
np.fromfile や np.memmap の利用を検討してみてください。

参考までに, np.memmap を使う場合

pixel = np.memmap(filename, ">B", mode="r", offset=16, shape=(num, nrow, ncol))

投稿

  • 回答の評価を上げる

    以下のような回答は評価を上げましょう

    • 正しい回答
    • わかりやすい回答
    • ためになる回答

    評価が高い回答ほどページの上位に表示されます。

  • 回答の評価を下げる

    下記のような回答は推奨されていません。

    • 間違っている回答
    • 質問の回答になっていない投稿
    • スパムや攻撃的な表現を用いた投稿

    評価を下げる際はその理由を明確に伝え、適切な回答に修正してもらいましょう。

  • 2016/11/01 08:08

    ご丁寧にご回答頂きありがとうございます。
    ご指摘頂いた通り、ファイルの open の mode に問題がありrbにすることで解決しました。
    ただ計算に若干時間かかるのでnp.memmap を試してみたのですがエラー率がおかしな値になってしまいました。
    恐らく自分の理解が足りておらず直し方が違ったのだと思いますが、何かアドバス頂けると幸いです
    下記修正箇所です

    #pixel = np.empty ( (num, nrow, ncol) )
    #npixel = nrow * ncol
    #for i in range ( num ):
    #buf=struct.unpack('>%dB'%npixel,f.read(npixel))
    #pixel[ i, :, : ] = np.asarray ( buf ).reshape ( (nrow, ncol) )
    #f.close ( )

    をご教授頂いたものに変更しました
    pixel = np.memmap(f, ">B", mode="r", offset=16, shape=(num, nrow, ncol))

    キャンセル

  • 2016/11/01 14:17

    すいません、勝手に持ってきた変数の説明が不足してました。ファイル名の文字列を意図してましたが、
    変数 f からヘッダを読み込んだ後に続けて読むなら offset指定(ヘッダの16バイト読み飛ばし)は不要になります

    キャンセル

  • 2016/11/01 18:20

    補足です。エラー箇所しか見ていなかったので、前後の処理まで考慮してませんでした。

    f は close されるので、memmap側にファイルのopenを任せたほうがいいかもしれません。
    np(f.name, dtype="uint8", offset=16, shape=(num, nrow, ncol))

    もう一点懸念があるとすれば、データ型の違いで
    元のコードは uint8 で読み込んだものを dtype=float の配列に代入してます。
    .astype(float) で型変換出来ます。

    ちなみに、byteorder ですが、私も ">B" と書いてしまいましたが、
    byte型の場合は読み込み順序には影響しません。

    キャンセル

15分調べてもわからないことは、teratailで質問しよう!

  • ただいまの回答率 90.22%
  • 質問をまとめることで、思考を整理して素早く解決
  • テンプレート機能で、簡単に質問をまとめられる

同じタグがついた質問を見る