質問をすることでしか得られない、回答やアドバイスがある。

15分調べてもわからないことは、質問しよう!

新規登録して質問してみよう
ただいま回答率
85.50%
Python

Pythonは、コードの読みやすさが特徴的なプログラミング言語の1つです。 強い型付け、動的型付けに対応しており、後方互換性がないバージョン2系とバージョン3系が使用されています。 商用製品の開発にも無料で使用でき、OSだけでなく仮想環境にも対応。Unicodeによる文字列操作をサポートしているため、日本語処理も標準で可能です。

Q&A

解決済

4回答

5107閲覧

Scipyのcurve_fitを用いた正弦波のフィッティングがうまくできない

cream_puff

総合スコア6

Python

Pythonは、コードの読みやすさが特徴的なプログラミング言語の1つです。 強い型付け、動的型付けに対応しており、後方互換性がないバージョン2系とバージョン3系が使用されています。 商用製品の開発にも無料で使用でき、OSだけでなく仮想環境にも対応。Unicodeによる文字列操作をサポートしているため、日本語処理も標準で可能です。

1グッド

0クリップ

投稿2022/10/13 07:56

編集2022/10/17 19:14

実現したいこと

y = a sin(bx + c) + d という関数でフィッティングをしたいのですが、scipyのcurve_fitでうまくできませんでした。改善点、アドバイスいただけると幸いです。

発生している問題・エラーメッセージ

現在の出力
[-0.24459067 0.88349125 1.9100531 5.00003073]
イメージ説明

正しい出力
[2 3 4 5 ]

該当のソースコード

python

1import numpy as np 2import matplotlib.pyplot as plt 3from scipy.optimize import curve_fit 4 5# サンプルデータ 6x = np.arange(0,10, 0.1) 7y = 2 * np.sin(3 * x + 4) + 5 8 9# フィッティング関数 10def func(x, a, b, c, d): 11 return a * np.sin(b * x + c) + d 12 13# フィッティングを実行 14param, cov = curve_fit(func, x, y) 15print(param) 16y_fit = func(x, param[0], param[1], param[2], param[3]) 17 18# グラフの描画 19plt.scatter(x, y, c="r", s=5, label="data") 20plt.plot(x, y_fit, c='b', linewidth=1, label="fitting") 21plt.legend();
can110👍を押しています

気になる質問をクリップする

クリップした質問は、後からいつでもMYページで確認できます。

またクリップした質問に回答があった際、通知やメールを受け取ることができます。

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

guest

回答4

0

天下り的にはなりますが,初期値p0を角速度bに近い値で始めさせればうまくいきます.
適当にa,b,c,d全て3.5からスタートするようにするとフィットしました.

Python

1param, cov = curve_fit(func, x, y, p0 = [3.5] * 4)

初期値を乱数にして共分散の小さいものを選ぶのも良いかもしれませんね.

Python

1import numpy as np 2min_cov = 1e8 3retry, max_retry = 0, 100 4while min_cov > 1e-10 and retry < max_retry: 5 print("try init prams:", p0 := np.random.randint(1, 10, size = 4)) 6 param, cov = curve_fit(func, x, y, p0 = p0) 7 min_cov = min(np.abs(cov).sum(), min_cov) 8 retry += 1 9 10if retry == max_retry: 11 print("max retried error") 12 exit() 13 14print("retried", retry, "times") 15print("fitted params:", param) 16y_fit = func(x, *param)

最小二乗法でうまくいかない理由

元の式をy = 2 * np.sin(-3 * x + 4) + 5にして乱数[1, 10]の範囲による初期値選択を行うとその片鱗が見えてきます.このときの解は正しくは[2, -3, 4, 5]ですが,-3は乱数の範囲にないので代わりに3が選ばれ,他のパラメータに皺寄せがいく形になり[-2, 3, 2.28, 5]となりました.

今回うまくいかなかった理由は,三角関数自体の周期性などによって,「最小二乗法で得られる誤差関数が描く超平面の局所的最適解が多すぎる.」というのが考えられます.質問者はその局所的最適解に陥ったものと思われます.また,大域的最適解も複数存在しており,先述のような別解を得ることになっているものと思われます.

投稿2022/10/13 09:02

編集2022/10/14 06:40
PondVillege

総合スコア1579

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

cream_puff

2022/10/17 10:13

ご回答ありがとうございました。「最小二乗法でうまくいかない理由」納得がいきました。
guest

0

ベストアンサー

curve_fit failing on even a sine waveにも記載されているとおり、最小二乗法をもとにしたcurve_fitは(私は理論を含めた理由は理解できていませんが)正弦関数のフィッティングには向いていません。
よって代わりにFFTをベースにした方法でフィッティングするとよいでしょう。
以下はHow do I fit a sine curve to my data with pylab and numpy?の回答をもとにしたコード例です。

Python

1import numpy as np 2import matplotlib.pyplot as plt 3#from scipy.optimize import curve_fit 4import numpy, scipy.optimize 5 6 7# サンプルデータ 8x = np.arange(0,10, 0.1) 9y = 2 * np.sin(3 * x + 4) + 5 10 11 12def fit_sin(tt, yy): 13 '''Fit sin to the input time sequence, and return fitting parameters "amp", "omega", "phase", "offset", "freq", "period" and "fitfunc"''' 14 tt = numpy.array(tt) 15 yy = numpy.array(yy) 16 ff = numpy.fft.fftfreq(len(tt), (tt[1]-tt[0])) # assume uniform spacing 17 Fyy = abs(numpy.fft.fft(yy)) 18 guess_freq = abs(ff[numpy.argmax(Fyy[1:])+1]) # excluding the zero frequency "peak", which is related to offset 19 guess_amp = numpy.std(yy) * 2.**0.5 20 guess_offset = numpy.mean(yy) 21 guess = numpy.array([guess_amp, 2.*numpy.pi*guess_freq, 0., guess_offset]) 22 23 def sinfunc(t, A, w, p, c): return A * numpy.sin(w*t + p) + c 24 popt, pcov = scipy.optimize.curve_fit(sinfunc, tt, yy, p0=guess) 25 A, w, p, c = popt 26 f = w/(2.*numpy.pi) 27 fitfunc = lambda t: A * numpy.sin(w*t + p) + c 28 return {"amp": A, "omega": w, "phase": p, "offset": c, "freq": f, "period": 1./f, "fitfunc": fitfunc, "maxcov": numpy.max(pcov), "rawres": (guess,popt,pcov)} 29 30 31res = fit_sin(x, y) 32y_fit = res['fitfunc'](x) 33 34 35# グラフの描画 36plt.scatter(x, y, c="r", s=5, label="data") 37plt.plot(x, y_fit, c='b', linewidth=1, label="fitting") 38plt.legend() 39plt.show()

イメージ説明

投稿2022/10/13 08:31

can110

総合スコア38234

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

cream_puff

2022/10/17 10:14

ご回答ありがとうございました。初期値の与え方が非常に参考になりました。
guest

0

scipyのcurve_fitでは、パラメータをある程度推測して初期値として設定しておくことが重要です。
デフォルトの初期値はすべて1になっているので、そこから正解があまりにも離れすぎていると、うまくフィッティングできません。
使う関数の性質を考慮したり、そのパラメータを使って実際に描画したりして、大体合っている場合にうまくいきやすいです。

以下、初期値を正解±0.3で設定した例です。

python

1import numpy as np 2import matplotlib.pyplot as plt 3from scipy.optimize import curve_fit 4 5# サンプルデータ 6x = np.arange(0, 10, 0.1) 7y = 2 * np.sin(3 * x + 4) + 5 8 9 10# フィッティング関数 11def func(x, a, b, c, d): 12 return a * np.sin(b * x + c) + d 13 14 15# 初期値 16p0 = [2.3, 2.7, 4.3, 4.7] 17 18 19# フィッティングを実行 20param, cov = curve_fit(func, x, y, p0=p0) # 初期値を設定 21print(param) 22# y_fit = func(x, param[0], param[1], param[2], param[3]) 23y_fit = func(x, *param) # tips:アンパックが使える 24 25# グラフの描画 26plt.scatter(x, y, c="r", s=5, label="data") 27plt.plot(x, y_fit, c="b", linewidth=1, label="fitting") 28plt.legend()

出力: [2. 3. 4. 5.]

グラフ:
イメージ説明
(参考)初期値のプロットも重ねた場合:
イメージ説明
これくらいズレてしまっていても、うまくフィッティングできる場合もあるようです。

投稿2022/10/13 10:32

T_F

総合スコア74

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

0

どんな場合にでも使えるものではないですが、ある程度範囲を絞ってフィッティングして、それを初期値に全体をフィッティングするという方法もあります。

python

1p0 = [1, 1, 1, 1] 2param0, _ = curve_fit(func, x[:30], y[:30], p0=p0) 3param, cov = curve_fit(func, x, y, p0=param0)

(今回の例では param0 の時点で全体にもフィッティングできています。)

このやり方でも、ある程度は初期値(p0)が正解に近くないと、正しく収束できません。
たとえば、p0 = [1, 8, 1, 1] ではうまくいきません。

投稿2022/10/13 09:36

bsdfan

総合スコア4520

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

あなたの回答

tips

太字

斜体

打ち消し線

見出し

引用テキストの挿入

コードの挿入

リンクの挿入

リストの挿入

番号リストの挿入

表の挿入

水平線の挿入

プレビュー

15分調べてもわからないことは
teratailで質問しよう!

ただいまの回答率
85.50%

質問をまとめることで
思考を整理して素早く解決

テンプレート機能で
簡単に質問をまとめる

質問する

関連した質問