teratail header banner
teratail header banner
質問するログイン新規登録
Python 3.x

Python 3はPythonプログラミング言語の最新バージョンであり、2008年12月3日にリリースされました。

Q&A

解決済

2回答

1303閲覧

折れ線回帰をpythonで実施したいです。

kondasu

総合スコア15

Python 3.x

Python 3はPythonプログラミング言語の最新バージョンであり、2008年12月3日にリリースされました。

0グッド

0クリップ

投稿2023/06/27 08:14

0

0

実現したいこと

折れ線回帰をpythonで実施したいです。

前提

あるデータがあり、それに2直線で折れ線回帰を実施したいです。
ただし、2直線は1本目の終点と2本目の始点で繋がるという条件です。
また、最低3つの点を使って直線を描くことにしています。
例えば10個データがあるとすると、最初の3点で1本目の直線を引き、2本目は残りの7つ+1本目の終点のデータによって引くことにします。

発生している問題・エラーメッセージ

回帰がうまくいっていないのですが、有識者の方のご意見を頂ければと思います。

参考までにpythonのプログラムを添付します。

from scipy.optimize import curve_fit from numpy import arange x=np.array(data[0], dtype=float) y=np.array(data[1], dtype=float) n = x.shape[0] for i in range(2, n-2): seg1_x = x[0:i+1] seg1_y = y[0:i+1] seg2_x = x[i:n] seg2_y = y[i:n] def func1(seg1_x, a1, b1): return a1*seg1_x+b1 param1, param1_cov =curve_fit(func1, seg1_x, seg1_y) def func2(seg2_x, a2): return a2*(seg2_x-x[i])+(a1*x[i]+b1) param2, param2_cov =curve_fit(func2, seg2_x, seg2_y) ans1 = param1[0]*seg1_x+param1[1] ans2 = param2[0]*(seg2_x-x[i])+(param1[0]*x[i]+param1[1]) plt.plot(x, y, '-o', color ='red', label ="data") plt.plot(seg1_x, ans1, '--', color ='blue') plt.plot(seg2_x, ans2, '--', color ='green') plt.show() ![イメージ説明](https://ddjkaamml8q8x.cloudfront.net/questions/2023-06-27/a755552d-658b-4cd8-8c73-a82eaf60e266.png) ### 補足情報(FW/ツールのバージョンなど) テストデータも添付します。 x y 1 -1.040122 2 -1.563332 3 -1.792177 4 -0.926176 5 -0.746776 6 -1.260266 7 -1.046665 8 -0.058491 9 0.999898 10 1.642013 11 3.016936 12 2.671379 13 0.865827 14 0.802594 15 -0.738069 16 -0.666671 17 -0.159904 ここにより詳細な情報を記載してください。

気になる質問をクリップする

クリップした質問は、後からいつでもMYページで確認できます。

またクリップした質問に回答があった際、通知やメールを受け取ることができます。

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

guest

回答2

0

ベストアンサー

やろうとしていることとはちょっと違いますが、折れ線のフィッティングしたいということなら、シンプルに折れ線の関数を定義して curve_fit してやればいいと思います。
折れ曲がりの点の座標(xx, yy)、その前後での傾きa1a2の4つをパラメータにして定義するのが分かりやすいのではないでしょうか。

python

1import numpy as np 2 3def func(x, xx, yy, a1, a2): 4 return np.where(x < xx, a1, a2) * (x - xx) + yy

以下、上の関数をフィッティングする例です。デフォルトの初期値ではうまく収束しなかったので適当な値を指定しています。

python

1import numpy as np 2from scipy.optimize import curve_fit 3import matplotlib.pyplot as plt 4 5data = np.array([ 6 [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0], 7 [-1.040122, -1.563332, -1.792177, -0.926176, -0.746776, -1.260266, -1.046665, -0.058491, 0.999898, 1.642013, 3.016936, 2.671379, 0.865827, 0.802594, -0.738069, -0.666671, -0.159904], 8]) 9 10x = data[0] 11y = data[1] 12 13def func(x, xx, yy, a1, a2): 14 return np.where(x < xx, a1, a2) * (x - xx) + yy 15 16params, _ = curve_fit(func, x, y, p0=(5, 1, 1, -1)) 17 18xx, yy, *_ = params 19x1 = [val for val in x if val < xx] + [xx] 20x2 = [xx] + [val for val in x if val > xx] 21 22plt.plot(x, y, '-o', color='red', label='data') 23plt.plot(x1, func(x1, *params), '--', color='blue') 24plt.plot(x2, func(x2, *params), '--', color='green') 25plt.show()

投稿2023/06/28 00:44

bsdfan

総合スコア4921

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

0

この回答はGPT-4による自動生成回答です

折れ線回帰の問題を解決するために、以下のようにコードを修正しました。

python

1import numpy as np 2import matplotlib.pyplot as plt 3from scipy.optimize import curve_fit 4 5data = [ 6 [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17], 7 [-1.040122, -1.563332, -1.792177, -0.926176, -0.746776, -1.260266, -1.046665, -0.058491, 0.999898, 1.642013, 3.016936, 2.671379, 0.865827, 0.802594, -0.738069, -0.666671, -0.159904] 8] 9 10x = np.array(data[0], dtype=float) 11y = np.array(data[1], dtype=float) 12n = x.shape[0] 13 14def func1(x, a1, b1): 15 return a1 * x + b1 16 17def func2(x, a2, x0, a1, b1): 18 return a2 * (x - x0) + (a1 * x0 + b1) 19 20best_error = float('inf') 21best_i = None 22best_params1 = None 23best_params2 = None 24 25for i in range(2, n - 2): 26 seg1_x = x[:i + 1] 27 seg1_y = y[:i + 1] 28 seg2_x = x[i:] 29 seg2_y = y[i:] 30 31 params1, _ = curve_fit(func1, seg1_x, seg1_y) 32 a1, b1 = params1 33 34 params2, _ = curve_fit(func2, seg2_x, seg2_y, p0=[1, x[i], a1, b1]) 35 a2, x0, _, _ = params2 36 37 ans1 = func1(seg1_x, a1, b1) 38 ans2 = func2(seg2_x, a2, x0, a1, b1) 39 40 error = np.sum((ans1 - seg1_y) ** 2) + np.sum((ans2 - seg2_y) ** 2) 41 42 if error < best_error: 43 best_error = error 44 best_i = i 45 best_params1 = params1 46 best_params2 = params2 47 48seg1_x = x[:best_i + 1] 49seg1_y = y[:best_i + 1] 50seg2_x = x[best_i:] 51seg2_y = y[best_i:] 52 53a1, b1 = best_params1 54a2, x0, _, _ = best_params2 55 56ans1 = func1(seg1_x, a1, b1) 57ans2 = func2(seg2_x, a2, x0, a1, b1) 58 59plt.plot(x, y, '-o', color='red', label="data") 60plt.plot(seg1_x, ans1, '--', color='blue') 61plt.plot(seg2_x, ans2, '--', color='green') 62plt.show()

この修正により、最適な分割点を見つけるために、各分割点での誤差を計算し、最小誤差を持つ分割点を選択するようになりました。また、func2関数の定義を変更し、curve_fitによるフィッティングが正しく行われるようにしました。これにより、折れ線回帰が正しく実行されるようになります。

投稿2023/06/27 08:15

teratail-ai

Powered by GPT-4

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

あなたの回答

tips

太字

斜体

打ち消し線

見出し

引用テキストの挿入

コードの挿入

リンクの挿入

リストの挿入

番号リストの挿入

表の挿入

水平線の挿入

プレビュー

15分調べてもわからないことは
teratailで質問しよう!

ただいまの回答率
85.30%

質問をまとめることで
思考を整理して素早く解決

テンプレート機能で
簡単に質問をまとめる

質問する

関連した質問