質問をすることでしか得られない、回答やアドバイスがある。

15分調べてもわからないことは、質問しよう!

新規登録して質問してみよう
ただいま回答率
85.48%
Python 3.x

Python 3はPythonプログラミング言語の最新バージョンであり、2008年12月3日にリリースされました。

Q&A

解決済

1回答

482閲覧

訓練データとテストデータの分別に失敗しました。

cunwe

総合スコア65

Python 3.x

Python 3はPythonプログラミング言語の最新バージョンであり、2008年12月3日にリリースされました。

0グッド

0クリップ

投稿2019/08/15 16:42

前回の自分の質問の続きなのですがあの後0であることを確認した後

n_train=len(digits.data)*2//3 X_train=digits.data[:n_train] y_train=digits.data[:n_train] X_test=digits.data[n_train:] y_test=digits.data[n_train:]

と、訓練データとテストデータを用意した後

print([d.shape for d in [X_train,y_train,X_test,y_test]])

と、構造を確認した後

[(1198, 64), (1198, 64), (599, 64), (599, 64)]

という結果を得

from sklearn import svm clf=svm.SVC(gamma=0.001) clf.fit(X_train,y_train)

と、SVMで学習を行った結果以下のようなエラーが発生しました。

ValueError Traceback (most recent call last) <ipython-input-22-86937c1966f0> in <module> ----> 1 clf.fit(X_train,y_train) ~\Anaconda3\lib\site-packages\sklearn\svm\base.py in fit(self, X, y, sample_weight) 147 X, y = check_X_y(X, y, dtype=np.float64, 148 order='C', accept_sparse='csr', --> 149 accept_large_sparse=False) 150 y = self._validate_targets(y) 151 ~\Anaconda3\lib\site-packages\sklearn\utils\validation.py in check_X_y(X, y, accept_sparse, accept_large_sparse, dtype, order, copy, force_all_finite, ensure_2d, allow_nd, multi_output, ensure_min_samples, ensure_min_features, y_numeric, warn_on_dtype, estimator) 759 dtype=None) 760 else: --> 761 y = column_or_1d(y, warn=True) 762 _assert_all_finite(y) 763 if y_numeric and y.dtype.kind == 'O': ~\Anaconda3\lib\site-packages\sklearn\utils\validation.py in column_or_1d(y, warn) 795 return np.ravel(y) 796 --> 797 raise ValueError("bad input shape {0}".format(shape)) 798 799 ValueError: bad input shape (1198, 64)

今までのエラーより長く、エラーの読み方がよくわかりませんでした。。「bad input shape」でググってもよくわからず。。Anacondaの中がどーたらこーたらとあるのでバージョンが合わないのかと思いましたが原因がわかる方、ご教示いただけると嬉しいです。

気になる質問をクリップする

クリップした質問は、後からいつでもMYページで確認できます。

またクリップした質問に回答があった際、通知やメールを受け取ることができます。

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

guest

回答1

0

ベストアンサー

y_train, y_test は正解ラベルなので、代入するのは、画像である digits.data ではなく、ラベルの digits.taget ではないでしょうか。

つまり、以下のように変更すればよいです。

python

1X_train = digits.data[:n_train] 2y_train = digits.target[:n_train] 3X_test = digits.data[n_train:] 4y_test = digits.target[n_train:]

投稿2019/08/15 16:49

tiitoi

総合スコア21956

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

cunwe

2019/08/16 05:25

ご指摘ありがとうございます!おっしゃるとおりに実行したらうまくいきました!ありがとうございました(^^)
guest

あなたの回答

tips

太字

斜体

打ち消し線

見出し

引用テキストの挿入

コードの挿入

リンクの挿入

リストの挿入

番号リストの挿入

表の挿入

水平線の挿入

プレビュー

15分調べてもわからないことは
teratailで質問しよう!

ただいまの回答率
85.48%

質問をまとめることで
思考を整理して素早く解決

テンプレート機能で
簡単に質問をまとめる

質問する

関連した質問