###前提・実現したいこと
オライリー・ジャパンの”ゼロから始めるdeeplearning”を使って機械学習を勉強しています。(現在3.6.2章です)
下記のソースコードを実行したところ, エラーメッセージが出てしまったのですが,
このエラーメッセージの意味と回避方法を教えていただきたいです。
###発生している問題・エラーメッセージ
Traceback (most recent call last): File "<stdin>", line 2, in <module> File "<stdin>", line 3, in predict KeyError: 'B1'
###該当のソースコード
python
1import numpy as np 2import pickle 3 4import sys, os 5sys.path.append(os.pardir) 6from dataset.mnist import load_mnist 7 8 9def sigmoid(x): 10 return 1/(1+np.exp(-x)) 11 12def getdata(): 13 (xtrain,ttrain), (xtest,ttest) = load_mnist(normalize = True, flatten = True, one_hot_label = False) 14 return xtest, ttest 15 16def initnetwork(): 17 with open("sample_weight.pkl", 'rb') as f: 18 network = pickle.load(f) 19 return network 20 21def predict(network, x): 22 W1, W2, W3 = network['W1'], network['W2'], network['W3'] 23 B1, B2, B3 = network['B1'], network['B2'], network['B3'] 24 a = np.dot(x, W1) + B1 25 A = sigmoid(a) 26 b = np.dot(A, W2) + B2 27 B = sigmoid(b) 28 y = np.dot(B. W3) + B3 29 Y = softmax(y) 30 return Y 31 32 33 34x, t = getdata() 35network = initnetwork() 36 37accuracycnt = 0 38 39for i in range(len(x)): 40 y = predict(network, x[i]) 41 p = np.argmax(y) 42 if p == t[i]: 43 accuracycnt += 1 44 45print("Accuracy:" + str(float(accuracycnt) / len(x)))
回答3件
あなたの回答
tips
プレビュー
バッドをするには、ログインかつ
こちらの条件を満たす必要があります。
2017/06/02 01:38
2017/06/02 02:08
2017/06/02 02:17