質問をすることでしか得られない、回答やアドバイスがある。

15分調べてもわからないことは、質問しよう!

新規登録して質問してみよう
ただいま回答率
85.35%
Python

Pythonは、コードの読みやすさが特徴的なプログラミング言語の1つです。 強い型付け、動的型付けに対応しており、後方互換性がないバージョン2系とバージョン3系が使用されています。 商用製品の開発にも無料で使用でき、OSだけでなく仮想環境にも対応。Unicodeによる文字列操作をサポートしているため、日本語処理も標準で可能です。

Q&A

解決済

1回答

1229閲覧

pythonの回帰の勾配法を使って二次元面モデルの作成がわかりません

wa-ya

総合スコア1

Python

Pythonは、コードの読みやすさが特徴的なプログラミング言語の1つです。 強い型付け、動的型付けに対応しており、後方互換性がないバージョン2系とバージョン3系が使用されています。 商用製品の開発にも無料で使用でき、OSだけでなく仮想環境にも対応。Unicodeによる文字列操作をサポートしているため、日本語処理も標準で可能です。

0グッド

0クリップ

投稿2021/01/28 06:58

前提・実現したいこと

大学の授業でpythonで回帰を勉強しております
現在解析解を使わずに勾配法を使って二次元入力の面モデルを作成しています
(年齢と体重の2つのパラメータから身長を推定する)
エラーが出てわからず止まっています
どなたかご教授下さい

発生している問題・エラーメッセージ

cannot copy sequence with size 2 to array axis with dimension 3

該当のソースコード

import numpy as np import matplotlib.pyplot as plt %matplotlib inline #データ生成 X0_n = 16 X0_min = 5 X0_max = 30 T_min = 40 T_max = 75 X0 = np.array([15.43, 23.01, 5.00, 12.56, 8.67, 7.31, 9.66, 13.64, 14.92, 18.47, 15.48, 22.13, 10.11, 26.95, 5.68, 21.76]) #x(年齢) T = np.array([70.43, 58.15, 37.22, 56.51, 57.32, 40.84, 57.79, 56.94, 63.03, 65.69, 62.33, 64.95, 57.73, 66.89, 46.68, 61.08]) #t(体重) #身長を抽出 import math np.random.seed(seed=1) #乱数を固定 h = 10000 * (T - 2 * np.random.randn(X0_n)) / 23 #hは(身長)**2のこと X1 = np.sqrt(h) #ルートで身長を出す X1は身長 X1_min = 129 X1_max = 175 print(np.round(X1,2)) #小数点第2位にする #二次元データを表示 def show_data2(ax, x0, x1, t): for i in range(len(x0)): ax.plot([x0[i], x0[i]], [x1[i], x1[i]], [120, t[i]], color='gray') ax.plot(x0, x1, t, 'o', color='cornflowerblue', markeredgecolor='black', markersize=6, markeredgewidth=0.5) ax.view_init(elev=35, azim=-75) #メイン plt.figure(figsize=(6,5)) ax = plt.subplot(1,1,1,projection='3d') show_data2(ax, X0, T, X1) plt.show() #面の表示 def show_plane(ax,w): px0 = np.linspace(X0_min, X0_max, 5) px1 = np.linspace(X1_min, X1_max, 5) px0, px1 = np.meshgrid(px0, px1) y = w[0]*px0 + w[1] * px1 + w[2] ax.plot_surface(px0, px1, y, rstride=1, cstride=1, alpha=0.3, color='blue', edgecolor='black') #面のMSE def mse_plane(x0, x1, t, w): y = w[0] * x0 + w[1] *x1 + w[2] mse = np.mean((y - t)**2) return mse #メイン plt.figure(figsize=(6,5)) ax = plt.subplot(1,1,1,projection='3d') W = [1.5, 1, 90] show_plane(ax, W) show_data2(ax, X0, X1, T) mse = mse_plane(X0 ,X1, T, W) print("SD={0:.2f} cm".format(np.sqrt(mse))) plt.show() #平均誤差関数 def mse_line(x0, x1, t, w): y = w[0] * x0 + w[1] * x1 + w[2] mse = np.mean((y-t)**2) return mse #平均二乗誤差の勾配法 def dmse_line(x0,x1,t,w): y = w[0] * x0 + w[1] * x1 + w[2] d_w0 = 2 * np.mean((y - t)* x0) d_w1 = 2 * np.mean((y - t)* x1) d_w2 = 2 * np.mean(y - t) return d_w0, d_w1, d_w2 d_w = dmse_line(X0, X1, T, W) print(np.round(d_w, 1)) #この下から止まっています #勾配法 def fit_line_num(x0,t,x1): w_init = [10.0, 165.0] #初期パラメータ alpha = 0.001 #学習率 tau_max = 100000 #繰り返しの最大級 eps = 0.1 #繰り返しをやめる勾配の絶対値のしきい値 w_hist = np.zeros([tau_max,3]) w_hist[0,:] = w_init for tau in range(1, tau_max): dmse = dmse_line(x0,t,x1,w_hist[tau - 1]) w_hist[tau,0] = w_hist[tau - 1, 0] - alpha * dmse[0] w_hist[tau,1] = w_hist[tau - 1, 1] - alpha * dmse[1] w_hist[tau,2] = w_hist[tau - 1, 2] - alpha * dmse[2] if max(np.absolute(dmse)) < eps: #終了 break w0 = w_hist[tau,0] w1 = w_hist[tau,1] w2 = w_hist[tau,2] w_hist = w_hist[:tau,:] return w0, w1,w2, dmse, w_hist #メイン plt.figure(figsize=(4,4)) #勾配法呼び出し w0, w1,w2, dMSE, w_history = fit_line_num(X0,T,X1) #結果表示 print('繰り返し回数{0}'.format(w_history.shape[0])) print('W = [{0:.6f}, {1:.6f}]'.format(W0,W1,W2)) print('dMSE=[{0:.6f}, {1:.6f}]'.format(dMSE[0], dMSE[1], dMSE[2])) print('MSE={0:6f}'.format(mse_line(X0,T,X1,[w0,w1,w2]))) plt.plot(w_history[:,0], w_history[:,1], w_history[:,2], '.-', color='gray', markersize=10, markeredgecolor='cornflowerblue') plt.show()

試したこと

同じようなエラーを調べたがよくわからなかったです

補足情報(FW/ツールのバージョンなど)

もしかしてそもそも最初の方から間違っているのでしょうか?

気になる質問をクリップする

クリップした質問は、後からいつでもMYページで確認できます。

またクリップした質問に回答があった際、通知やメールを受け取ることができます。

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

guest

回答1

0

ベストアンサー

w_hist[0,:] はarray([0., 0., 0.])で、そこにw_init = [10.0, 165.0]を入れようとして形が違うというエラーメッセージです。

中身は分かりませんが、そこを修正してください。

投稿2021/01/28 07:32

ppaul

総合スコア24670

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

wa-ya

2021/01/29 04:36

ありがとうございます! 次に進めました!
guest

あなたの回答

tips

太字

斜体

打ち消し線

見出し

引用テキストの挿入

コードの挿入

リンクの挿入

リストの挿入

番号リストの挿入

表の挿入

水平線の挿入

プレビュー

15分調べてもわからないことは
teratailで質問しよう!

ただいまの回答率
85.35%

質問をまとめることで
思考を整理して素早く解決

テンプレート機能で
簡単に質問をまとめる

質問する

関連した質問