teratail header banner
teratail header banner
質問するログイン新規登録

回答編集履歴

2

d

2019/10/09 07:29

投稿

tiitoi
tiitoi

スコア21960

answer CHANGED
@@ -9,4 +9,56 @@
9
9
  ----
10
10
 
11
11
  MNIST のような小規模なデータであれば交差検証を行ってもそんなに時間がかかりませんが、ある程度の規模のデータセットを学習する場合、ディープラーニングは時間がかなりかかるため、交差検証は基本的に行われません。
12
- 学習データ、テストデータに分けるホールドアウトで評価するのが一般的です。
12
+ 学習データ、テストデータに分けるホールドアウトで評価するのが一般的です。
13
+
14
+ ## 追記
15
+
16
+ hayataka2049 さんに教えていただいたのですが、
17
+ sklearn が要求する API 仕様に合わせるためのラッパー関数が Keras に用意されているようです。
18
+
19
+ [Scikit-learn API - Keras Documentation](https://keras.io/ja/scikit-learn-api/)
20
+
21
+ 使ってみたところ、以下のように交差検証ができました。
22
+
23
+ ```python
24
+ from keras import backend as K
25
+ from keras.datasets import mnist
26
+ from keras.layers import Activation, BatchNormalization, Dense, Flatten
27
+ from keras.models import Sequential
28
+ from keras.wrappers.scikit_learn import KerasClassifier
29
+ from sklearn.model_selection import cross_val_score
30
+
31
+ def create_model():
32
+ model = Sequential(
33
+ [
34
+ Flatten(),
35
+ Dense(10),
36
+ BatchNormalization(),
37
+ Activation("relu"),
38
+ Dense(10),
39
+ BatchNormalization(),
40
+ Activation("relu"),
41
+ Dense(10),
42
+ BatchNormalization(),
43
+ Activation("softmax"),
44
+ ]
45
+ )
46
+
47
+ model.compile(
48
+ optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"]
49
+ )
50
+
51
+ return model
52
+
53
+ # sklearn と互換性のあるモデルを作成する。
54
+ clf = KerasClassifier(create_model, epochs=5, batch_size=256, validation_split=0.2, verbose=0)
55
+
56
+ # MNIST データを取得する。
57
+ (X_train, y_train), (X_test, y_test) = mnist.load_data()
58
+ X = np.concatenate((X_train, X_test))
59
+ y = np.concatenate((y_train, y_test))
60
+
61
+ # 交差検証
62
+ ret = cross_val_score(clf, X, y, cv=3)
63
+ print(ret) # [0.9145453 0.90875584 0.91488451]
64
+ ```

1

d

2019/10/09 07:29

投稿

tiitoi
tiitoi

スコア21960

answer CHANGED
@@ -2,7 +2,7 @@
2
2
 
3
3
  cross_val_score には、fit(X, y) と呼び出して学習が行えるメソッドを持つオブジェクトを渡す必要があります。基本的には sklearn の SGDClassifier のようなモデルが渡されるのを想定した作りになっているので、Keras のモデルでは利用できません。
4
4
 
5
- [sklearn.model_selection.KFold](https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.KFold.html) を使うことで交差検証を使えば、以下のように交差検証は行なえます。
5
+ [sklearn.model_selection.KFold](https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.KFold.html) を使うことで以下のように交差検証は行なえます。
6
6
 
7
7
  [[Keras/TensorFlow] KerasでCV(クロスバリデーション) - Qiita](https://qiita.com/agumon/items/0df9f008a255796b5a94)
8
8