多項式回帰を行って、グラフをプロットしたいのですが、行列のサイズが異なると表示されてしまい、修正することが出来ません。
エラーメッセージでValueError: shapes (160,10) and (11,) not aligned: 10 (dim 1) != 11 (dim 0)と表示されたので、shape(10,-1)なども試したのですが、上手くいきませんでした。よろしくお願いします。
python
1from sklearn.preprocessing import PolynomialFeatures 2def gen_poly_features(data, d): 3 polynomial_features = PolynomialFeatures(degree = d) 4 X_poly = polynomial_features.fit_transform(data) 5 return X_poly 6 7degree = 10 8X_train = gen_poly_features(diabetes_X_train, degree) 9X_test = gen_poly_features(diabetes_X_test, degree) 10X_val = gen_poly_features(diabetes_X_val, degree) 11print(X_train) 12 13pregr = linear_model.LinearRegression() 14pregr.fit(X_train, diabetes_y_train) 15print(diabetes_y_train) 16 17sequence = (np.array( range(-80,80) )/1000)[:, np.newaxis] 18X_sequence = sequence 19for k in range(2,degree+1): 20 X_sequence = np.hstack( (X_sequence, sequence**k) ) 21print(X_sequence) 22predictions_sequence = pregr.predict(X_sequence)
---エラーメッセージです。
ValueError Traceback (most recent call last)
<ipython-input-254-aa5bf920823e> in <module>
4 X_sequence = np.hstack( (X_sequence, sequence**k) )
5 print(X_sequence)
----> 6 predictions_sequence = pregr.predict(X_sequence)
~\Anaconda3\lib\site-packages\sklearn\linear_model\base.py in predict(self, X)
211 Returns predicted values.
212 """
--> 213 return self._decision_function(X)
214
215 _preprocess_data = staticmethod(_preprocess_data)
~\Anaconda3\lib\site-packages\sklearn\linear_model\base.py in decision_function(self, X)
196 X = check_array(X, accept_sparse=['csr', 'csc', 'coo'])
197 return safe_sparse_dot(X, self.coef.T,
--> 198 dense_output=True) + self.intercept_
199
200 def predict(self, X):
~\Anaconda3\lib\site-packages\sklearn\utils\extmath.py in safe_sparse_dot(a, b, dense_output)
171 return ret
172 else:
--> 173 return np.dot(a, b)
174
175
ValueError: shapes (160,10) and (11,) not aligned: 10 (dim 1) != 11 (dim 0)
---X_trainのデータです。
[[ 1.00000000e+00 6.16962065e-02 3.80642190e-03 2.34841792e-04
1.44888477e-05 8.93906938e-07 5.51506671e-08 3.40258694e-09
2.09926707e-10 1.29516815e-11 7.99069614e-13]
[ 1.00000000e+00 -5.14740612e-02 2.64957898e-03 -1.36384591e-04
7.02026877e-06 -3.61361745e-07 1.86007566e-08 -9.57456483e-10
4.92841737e-11 -2.53685657e-12 1.30582311e-13]
...
[ 1.00000000e+00 -4.17737526e-02 1.74504640e-03 -7.28971367e-05
3.04518695e-06 -1.27208886e-07 5.31399254e-09 -2.21985410e-10
9.27316358e-12 -3.87374841e-13 1.61821008e-14]
[ 1.00000000e+00 1.42724753e-02 2.03703550e-04 2.90735388e-06
4.14951364e-08 5.92238308e-10 8.45270660e-12 1.20641046e-13
1.72184634e-15 2.45750094e-17 3.50746213e-19]]
---diabetes_y_trainのデータです。
[151. 75. 141. 206. 135. 97. 138. 63. 110. 310. 101. 69. 179. 185. 118. 171. 166. 144. 97. 168. 68. 49. 68. 245. 184. 202. 137. 85. 131. 283. 129. 59. 341. 87. 65. 102. 265. 276. 252. 90. 100. 55. 61. 92. 259. 53. 190. 142. 75. 142.]
回答1件
あなたの回答
tips
プレビュー
バッドをするには、ログインかつ
こちらの条件を満たす必要があります。
退会済みユーザー
2019/10/18 20:03