回答編集履歴

2

d

2019/10/09 07:29

投稿

tiitoi
tiitoi

スコア21956

test CHANGED
@@ -21,3 +21,107 @@
21
21
  MNIST のような小規模なデータであれば交差検証を行ってもそんなに時間がかかりませんが、ある程度の規模のデータセットを学習する場合、ディープラーニングは時間がかなりかかるため、交差検証は基本的に行われません。
22
22
 
23
23
  学習データ、テストデータに分けるホールドアウトで評価するのが一般的です。
24
+
25
+
26
+
27
+ ## 追記
28
+
29
+
30
+
31
+ hayataka2049 さんに教えていただいたのですが、
32
+
33
+ sklearn が要求する API 仕様に合わせるためのラッパー関数が Keras に用意されているようです。
34
+
35
+
36
+
37
+ [Scikit-learn API - Keras Documentation](https://keras.io/ja/scikit-learn-api/)
38
+
39
+
40
+
41
+ 使ってみたところ、以下のように交差検証ができました。
42
+
43
+
44
+
45
+ ```python
46
+
47
+ from keras import backend as K
48
+
49
+ from keras.datasets import mnist
50
+
51
+ from keras.layers import Activation, BatchNormalization, Dense, Flatten
52
+
53
+ from keras.models import Sequential
54
+
55
+ from keras.wrappers.scikit_learn import KerasClassifier
56
+
57
+ from sklearn.model_selection import cross_val_score
58
+
59
+
60
+
61
+ def create_model():
62
+
63
+ model = Sequential(
64
+
65
+ [
66
+
67
+ Flatten(),
68
+
69
+ Dense(10),
70
+
71
+ BatchNormalization(),
72
+
73
+ Activation("relu"),
74
+
75
+ Dense(10),
76
+
77
+ BatchNormalization(),
78
+
79
+ Activation("relu"),
80
+
81
+ Dense(10),
82
+
83
+ BatchNormalization(),
84
+
85
+ Activation("softmax"),
86
+
87
+ ]
88
+
89
+ )
90
+
91
+
92
+
93
+ model.compile(
94
+
95
+ optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"]
96
+
97
+ )
98
+
99
+
100
+
101
+ return model
102
+
103
+
104
+
105
+ # sklearn と互換性のあるモデルを作成する。
106
+
107
+ clf = KerasClassifier(create_model, epochs=5, batch_size=256, validation_split=0.2, verbose=0)
108
+
109
+
110
+
111
+ # MNIST データを取得する。
112
+
113
+ (X_train, y_train), (X_test, y_test) = mnist.load_data()
114
+
115
+ X = np.concatenate((X_train, X_test))
116
+
117
+ y = np.concatenate((y_train, y_test))
118
+
119
+
120
+
121
+ # 交差検証
122
+
123
+ ret = cross_val_score(clf, X, y, cv=3)
124
+
125
+ print(ret) # [0.9145453 0.90875584 0.91488451]
126
+
127
+ ```

1

d

2019/10/09 07:29

投稿

tiitoi
tiitoi

スコア21956

test CHANGED
@@ -6,7 +6,7 @@
6
6
 
7
7
 
8
8
 
9
- [sklearn.model_selection.KFold](https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.KFold.html) を使うことで交差検証を使えば、以下のように交差検証は行なえます。
9
+ [sklearn.model_selection.KFold](https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.KFold.html) を使うことで以下のように交差検証は行なえます。
10
10
 
11
11
 
12
12