以下のコードを実行したところ「KeyError:0」のエラーが出てしまいます。
どなたかおわかりになりましたら解決策についてご教示いただきたく思います。
Python
import csv,os,pickle import matplotlib.pyplot as plt from mpl_toolkits.mplot3d import Axes3D import numpy as np from sklearn import datasets from sklearn.model_selection import train_test_split %matplotlib inline # 乱数シードを指定 np.random.seed(seed=0) # mnistデータセットのロード(ネットワーク接続が必要・少し時間がかかります) if os.path.exists('mnist_784'): with open('mnist_784','rb') as f: mnist = pickle.load(f) else: #mnist = datasets.fetch_openml('mnist_784') mnist = fetch_openml('mnist_784', version=1, as_frame=False) with open('mnist_784', 'wb') as f: pickle.dump(mnist, f) # 画像とラベルを取得 X, T = mnist.data, mnist.target # 訓練データとテストデータに分割 X_train, X_test, T_train, T_test = train_test_split(X, T, test_size=0.2) # ラベルデータをint型にし、one-hot-vectorに変換します T_train = np.eye(10)[T_train.astype("int")] T_test = np.eye(10)[T_test.astype("int")] # データを5つ表示 for i in range(5): plt.gray() plt.imshow(X_train[i].reshape((28,28))) plt.show() print("label: ", T_train[i])
以下のコードでエラーが出ます。
Python
plt.imshow(X_train[i].reshape((28,28)))
エラー文↓
当方のMacのPythonで、質問のコードの
mnist = fetch_openml('mnist_784', version=1, as_frame=False)
↓ 変更
mnist = datasets.fetch_openml('mnist_784', version=1, as_frame=False)
だけ変更して、他はそのままで実行したらエラーは出ませんでしたが、上記を
mnist = datasets.fetch_openml('mnist_784')
に変更して実行した場合は、質問に書かれてるエラーが出ました
なお、一旦エラーが出る状態になったら、作成された「mnist_784」というファイルを削除しないと、コードを変更して実行しても、エラーが出続けます
「mnist_784」を削除して、変更したコードでネットからのダウンロードからやり直す必要があります
ありがとうございます。ご指摘通りに試したらうまくいきました。
最初「mnist = datasets.fetch_openml('mnist_784')」の状態でダウンロードしていたため、その後ずっとエラーが出ていたのだと思います。
ダウンロードする時に「mnist = datasets.fetch_openml('mnist_784', version=1, as_frame=False)」の状態にするべきでした。
まだ回答がついていません
会員登録して回答してみよう