# Hyper-parameter search for two PCA + classifier pipelines on the
# scikit-learn digits dataset:
#   1. PCA -> LogisticRegression, tuned exhaustively with GridSearchCV.
#   2. PCA -> SVC, tuned with RandomizedSearchCV over a linear-kernel and an
#      RBF-kernel sub-space.
import numpy as np
from sklearn import datasets
from sklearn.decomposition import PCA
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV
from sklearn.model_selection import RandomizedSearchCV
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline
from sklearn.svm import SVC

# Load the data and hold out 20% of it for the final evaluation.
digits = datasets.load_digits()
X, y = digits.data, digits.target
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=0
)

clf1 = LogisticRegression()
clf2 = SVC()

# Regularisation strengths shared by both searches.
C_GRID = [1e-5, 1e-3, 1e-2, 1, 1e2, 1e5, 1e10]

# --- Search 1: PCA + logistic regression, exhaustive grid ------------------
estimators = [('pca', PCA()), ('clf', clf1)]
pipe1 = Pipeline(estimators)

# Pipeline parameters are addressed as "<step name>__<parameter name>".
param1 = {
    'clf__C': C_GRID,
    'pca__whiten': [True, False],
}

gs = GridSearchCV(pipe1, param1)
gs.fit(X_train, y_train)
gs.score(X_test, y_test)

# --- Search 2: PCA + SVC, randomized search --------------------------------
estimators = [('pca', PCA()), ('clf', SVC())]
pipe2 = Pipeline(estimators)

# Candidate gamma values: 1e-10, 1e-7, 1e-4, 1e-1.
gamma_range_exp = np.arange(-10.0, 0.0, 3)
gamma_range = 10 ** gamma_range_exp

# Two sub-spaces so that gamma is only explored together with the RBF kernel.
# Every key needs the step prefix: a bare 'gamma' is not a parameter of the
# Pipeline itself and is rejected as an invalid parameter, so it must be
# spelled 'clf__gamma'.
param2 = [
    {
        'clf__C': C_GRID,
        'clf__kernel': ['linear'],
        'pca__whiten': [True, False],
        'pca__n_components': [30, 20, 10],
    },
    {
        'clf__C': C_GRID,
        'clf__kernel': ['rbf'],
        'clf__gamma': gamma_range,
        'pca__whiten': [True, False],
        'pca__n_components': [30, 20, 10],
    },
]

# NOTE: RandomizedSearchCV accepts a *list* of dicts only from scikit-learn
# 0.22 onwards. Older releases call `.values()` on the argument and fail with
# "AttributeError: 'list' object has no attribute 'values'". On such a
# version either upgrade scikit-learn or use GridSearchCV(pipe2, param2),
# which has always accepted a list of grids.
gs = RandomizedSearchCV(pipe2, param2, n_jobs=-1, verbose=2)
gs.fit(X_train, y_train)
エラー内容
AttributeError Traceback (most recent call last)
<ipython-input-11-21305e2006cc> in <module>()
12
13 gs = RandomizedSearchCV(pipe2, param2, n_jobs=-1, verbose=2)
---> 14 gs.fit(X_train, y_train)
~/anaconda3/lib/python3.6/site-packages/sklearn/model_selection/_search.py in fit(self, X, y, groups, **fit_params)
616 n_splits = cv.get_n_splits(X, y, groups)
617 # Regenerate parameter iterable for each fit
--> 618 candidate_params = list(self._get_param_iterator())
619 n_candidates = len(candidate_params)
620 if self.verbose > 0:
~/anaconda3/lib/python3.6/site-packages/sklearn/model_selection/_search.py in __iter__(self)
236 # in this case we want to sample without replacement
237 all_lists = np.all([not hasattr(v, "rvs")
--> 238 for v in self.param_distributions.values()])
239 rnd = check_random_state(self.random_state)
240
AttributeError: 'list' object has no attribute 'values'
1のようにグリッドサーチをしたかったのですが、エラーが出てしまいました。
何がちがうのでしょうか?
回答2件
あなたの回答
tips
プレビュー