質問をすることでしか得られない、回答やアドバイスがある。

15分調べてもわからないことは、質問しよう!

新規登録して質問してみよう
ただいま回答率
85.48%
Keras

Kerasは、TheanoやTensorFlow/CNTK対応のラッパーライブラリです。DeepLearningの数学的部分を短いコードでネットワークとして表現することが可能。DeepLearningの最新手法を迅速に試すことができます。

機械学習

機械学習は、データからパターンを自動的に発見し、そこから知能的な判断を下すためのコンピューターアルゴリズムを指します。人工知能における課題のひとつです。

Python

Pythonは、コードの読みやすさが特徴的なプログラミング言語の1つです。 強い型付け、動的型付けに対応しており、後方互換性がないバージョン2系とバージョン3系が使用されています。 商用製品の開発にも無料で使用でき、OSだけでなく仮想環境にも対応。Unicodeによる文字列操作をサポートしているため、日本語処理も標準で可能です。

Q&A

解決済

1回答

3926閲覧

MNISTでのk交差検証ができない

退会済みユーザー

退会済みユーザー

総合スコア0

Keras

Kerasは、TheanoやTensorFlow/CNTK対応のラッパーライブラリです。DeepLearningの数学的部分を短いコードでネットワークとして表現することが可能。DeepLearningの最新手法を迅速に試すことができます。

機械学習

機械学習は、データからパターンを自動的に発見し、そこから知能的な判断を下すためのコンピューターアルゴリズムを指します。人工知能における課題のひとつです。

Python

Pythonは、コードの読みやすさが特徴的なプログラミング言語の1つです。 強い型付け、動的型付けに対応しており、後方互換性がないバージョン2系とバージョン3系が使用されています。 商用製品の開発にも無料で使用でき、OSだけでなく仮想環境にも対応。Unicodeによる文字列操作をサポートしているため、日本語処理も標準で可能です。

0グッド

0クリップ

投稿2019/10/09 06:12

編集2019/10/09 06:13

前提・実現したいこと

Kerasを用いてMNIST10種分類問題を行っており、出てきた結果に対してK交差検証を行いたいのですが、cross_val_scoreでエラーが発生してしまいます。

発生している問題・エラーメッセージ

File "C:/Users/user/PycharmProjects/MNIST/MNISTlkfoldxval.py", line 111, in <module> scores = cross_val_score(hist, X_train, Y_train, cv=kf) File "C:\Users\user\AppData\Local\Programs\Python\Python37\lib\site-packages\sklearn\model_selection\_validation.py", line 384, in cross_val_score scorer = check_scoring(estimator, scoring=scoring) File "C:\Users\user\AppData\Local\Programs\Python\Python37\lib\site-packages\sklearn\metrics\scorer.py", line 270, in check_scoring 6000/6000 [==============================] - 0s 51us/step - loss: 0.4041 - accuracy: 0.8767 - val_loss: 0.3896 - val_accuracy: 0.8945 "'fit' method, %r was passed" % estimator) TypeError: estimator should be an estimator implementing 'fit' method, <keras.callbacks.callbacks.History object at 0x000001BBCF15FEB8> was passed

該当のソースコード

python

1mnist = datasets.fetch_mldata('MNIST original', data_home='.') 2 3n = len(mnist.data) 4N = 10000 # MNISTの一部を使う 5N_train = 8000 6N_validation = 2000 7indices = np.random.permutation(range(n))[:N] # ランダムにN枚を選択 8 9X = mnist.data[indices] 10X = X / 255.0 11X = X - X.mean(axis=1).reshape(len(X), 1) 12y = mnist.target[indices] 13Y = np.eye(10)[y.astype(int)] 14 15X_train, X_test, Y_train, Y_test = \ 16 train_test_split(X, Y, train_size=N_train) 17X_train, X_validation, Y_train, Y_validation = \ 18 train_test_split(X_train, Y_train, test_size=N_validation) 19 20#モデル設定 21 22n_in = len(X[0]) # 784 23n_hidden = 200 24n_out = len(Y[0]) # 10 25p_keep = 0.5 26activation = 'relu' 27 28model = Sequential() 29model.add(Dense(n_hidden, input_dim=n_in)) 30model.add(Activation(activation)) 31model.add(Dropout(p_keep)) 32 33model.add(Dense(n_hidden)) 34model.add(Activation(activation)) 35model.add(Dropout(p_keep)) 36 37model.add(Dense(n_out)) 38model.add(Activation('softmax')) 39 40model.compile(loss='categorical_crossentropy', 41 optimizer=SGD(lr=0.01, momentum=0.9), 42 metrics=['accuracy']) 43 44#モデル学習 45batch_size = 200 46hist_acc_t = hist_acc_v = hist_loss_t = hist_loss_v = [] 47max_epochs = 10 48ep_learn_interval = 50 49 50project_name = 'MNIST' + activation + str(max_epochs) 51modellog = project_name + 'models' 52os.makedirs(modellog, exist_ok=True) 53 54model_checkpoint = ModelCheckpoint( 55 filepath=os.path.join(modellog, 'model_{epoch:02d}-{loss:.2f}-{acc:.2f}-{val_loss:.2f}-{val_acc:.2f}.hdf5'), 56 monitor='val_loss', 57 verbose=1, period=50) 58 59hist = model.fit(X_train, Y_train, batch_size=batch_size, 60 epochs = max_epochs, 61 validation_data = (X_validation, Y_validation), 62 callbacks=[model_checkpoint]) 63 64hist_acc_t = hist_acc_t + hist.history['accuracy'] 65hist_acc_v = hist_acc_v + hist.history['val_accuracy'] 66hist_loss_t = hist_loss_t + hist.history['loss'] 67hist_loss_v = hist_loss_v + hist.history['val_loss'] 68 69kf = KFold(n_splits=5, shuffle=True, random_state=42) 70scores = cross_val_score(hist, X_train, Y_train, cv=kf) 71 72#学習の進み具合を可視化 73# 各分割におけるスコア 74print('Cross-Validation scores: {}'.format(scores)) 75# スコアの平均値 76import numpy as np 77print('Average score: {}'.format(np.mean(scores)))

試したこと

ウェブ上をいろいろ参考にしましたが、cross_val_scoreの使われ方がコードによって違うのでどれを適用したらいいのかわからず困っております。よろしければお力をお貸しください。

補足情報(FW/ツールのバージョンなど)

気になる質問をクリップする

クリップした質問は、後からいつでもMYページで確認できます。

またクリップした質問に回答があった際、通知やメールを受け取ることができます。

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

guest

回答1

0

ベストアンサー

ウェブ上をいろいろ参考にしましたが、cross_val_scoreの使われ方がコードによって違うのでどれを適用したらいいのかわからず困っております。よろしければお力をお貸しください。

cross_val_score には、fit(X, y) と呼び出して学習が行えるメソッドを持つオブジェクトを渡す必要があります。基本的には sklearn の SGDClassifier のようなモデルが渡されるのを想定した作りになっているので、Keras のモデルでは利用できません。

sklearn.model_selection.KFold を使うことで以下のように交差検証は行なえます。

[Keras/TensorFlow] KerasでCV(クロスバリデーション) - Qiita


MNIST のような小規模なデータであれば交差検証を行ってもそんなに時間がかかりませんが、ある程度の規模のデータセットを学習する場合、ディープラーニングは時間がかなりかかるため、交差検証は基本的に行われません。
学習データ、テストデータに分けるホールドアウトで評価するのが一般的です。

追記

hayataka2049 さんに教えていただいたのですが、
sklearn が要求する API 仕様に合わせるためのラッパー関数が Keras に用意されているようです。

Scikit-learn API - Keras Documentation

使ってみたところ、以下のように交差検証ができました。

python

1from keras import backend as K 2from keras.datasets import mnist 3from keras.layers import Activation, BatchNormalization, Dense, Flatten 4from keras.models import Sequential 5from keras.wrappers.scikit_learn import KerasClassifier 6from sklearn.model_selection import cross_val_score 7 8def create_model(): 9 model = Sequential( 10 [ 11 Flatten(), 12 Dense(10), 13 BatchNormalization(), 14 Activation("relu"), 15 Dense(10), 16 BatchNormalization(), 17 Activation("relu"), 18 Dense(10), 19 BatchNormalization(), 20 Activation("softmax"), 21 ] 22 ) 23 24 model.compile( 25 optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"] 26 ) 27 28 return model 29 30# sklearn と互換性のあるモデルを作成する。 31clf = KerasClassifier(create_model, epochs=5, batch_size=256, validation_split=0.2, verbose=0) 32 33# MNIST データを取得する。 34(X_train, y_train), (X_test, y_test) = mnist.load_data() 35X = np.concatenate((X_train, X_test)) 36y = np.concatenate((y_train, y_test)) 37 38# 交差検証 39ret = cross_val_score(clf, X, y, cv=3) 40print(ret) # [0.9145453 0.90875584 0.91488451]

投稿2019/10/09 06:35

編集2019/10/09 07:29
tiitoi

総合スコア21956

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

tiitoi

2019/10/09 07:30

そのような機能が用意されていたのですね。使ってみたところ、そのまま sklearn の関数に渡せました。 教えていただきありがとうございます。
hayataka2049

2019/10/09 07:33

調べてみて見つけはしたんですが、私はDNNやらないのでfitパラメータの渡し方どうするんだろうとか思って回答控えてました。tiitoiさんにやっていただいて助かりました。
退会済みユーザー

退会済みユーザー

2019/10/14 04:25

お教えいただきありがとうございます。無事に解決できました。
guest

あなたの回答

tips

太字

斜体

打ち消し線

見出し

引用テキストの挿入

コードの挿入

リンクの挿入

リストの挿入

番号リストの挿入

表の挿入

水平線の挿入

プレビュー

15分調べてもわからないことは
teratailで質問しよう!

ただいまの回答率
85.48%

質問をまとめることで
思考を整理して素早く解決

テンプレート機能で
簡単に質問をまとめる

質問する

関連した質問