MSE を見る限り、それなりに近似できているように見えますが、もっとパラメータを真値に近い値にしたいということでしょうか?あくまで近似なので、完全に一致することはないと思います。
python
1import matplotlib.pyplot as plt
2import numpy as np
3from scipy import stats
4from sklearn.metrics import mean_squared_error
5from scipy.optimize import curve_fit
6
7x = np.array([0, 4, 9, 14, 19, 24, 29, 34, 39, 44, 49, 54, 59,
8 64, 69, 74, 79, 84, 89, 94, 99, 104, 109, 114, 119, 124,
9 129, 134, 139, 144, 149, 154, 159, 164])
10
11y = np.array([0. , 0.17304493, 0.28618968, 0.50083195, 0.55407654,
12 0.65058236, 0.73044925, 0.83527454, 0.87687188, 0.92845258,
13 0.93510815, 0.95174709, 0.96006656, 0.9750416 , 0.97670549,
14 0.98169717, 0.98169717, 0.9843594 , 0.98868552, 0.9906822 ,
15 0.9906822 , 0.99234609, 0.99234609, 0.99567388, 0.99567388,
16 0.99567388, 0.99567388, 0.99567388, 0.99567388, 0.99567388,
17 0.99733777, 1. , 1. , 1.])
18
19
20def cdf(x, a, b, c):
21 return stats.lognorm.cdf(x, a, b, c)
22param, cov = curve_fit(cdf, x, y)
23y_pred = cdf(x, *param)
24
25# 描画する。
26plt.plot(x, y, linestyle='--', marker='o', color='b', ms=2, label='data')
27plt.plot(x, y_pred, linestyle='--', marker='o', color='g', ms=2, label='prediction')
28plt.legend()
29plt.show()
30
31mse = mean_squared_error(y, y_pred)
32print(mse) # 0.00030439892191327277
追記1
対数正規分布の累積分布関数としているのに,xが0の時にyが0とならないのでどうにかしたいのです.
対数正規分布の定義域は 0 < x < ∞ なのに、その累積分布関数で cdf(0) = 0 とならないのはおかしいということですね。
cdf(0) = 0 とならない理由は scipy.stats.lognorm 関数に loc というシフトするパラメータを含んでいるからです。リファレンス を参考にしてください。
なので、このまま fit() すると、loc も推定対象なので、近似したものでは cdf(0) = 0 とはなりません。scipy.stats.lognorm にも fit() という関数があるので、こちらを使うと、loc=0 と固定した状態で残りのパラメータを推定できます。
python
1import seaborn as sns
2
3from scipy import stats
4from sklearn.metrics import mean_squared_error
5from scipy.optimize import curve_fit
6
7# 対数正規分布に従うサンプルを生成する。
8shape, scale = 0.5, 1.
9sample = stats.lognorm(s=0.5, loc=0, scale=1.).rvs(size=2000)
10sns.distplot(sample, norm_hist=True, kde=False)
11plt.show()
12
13# loc は固定して、パラメータを推定する。
14shape_pred, loc_pred, scale_pred = stats.lognorm.fit(sample, floc=0)
15print('shape={}, loc={}, scale={}'.format(shape, loc, scale))
16# shape=0.5094946936562328, loc=0, scale=1.008179589277641
17
18x = np.linspace(0, 10, 100)
19y = stats.lognorm.cdf(x, s=shape, loc=0, scale=scale)
20y_pred = stats.lognorm.cdf(x, s=shape_pred, loc=loc_pred, scale=scale_pred)
21
22# 描画する。
23plt.plot(x, y, linestyle='--', marker='o', color='b', ms=2, label='true')
24plt.plot(x, y_pred, linestyle='--', marker='o', color='g', ms=2, label='prediction')
25plt.legend()
26plt.show()
27
28mse = mean_squared_error(y, y_pred)
29print(mse) # 7.72357454555545e-06
30print(y_pred[0]) # 0.00000000e+00 cdf(0) = 0 となっている。
追記2
curve_fit() でも次のようにすることでパラメータを固定できました。
stats.lognorm.fit()
のほうでは、fitting のアルゴリズムが異なるのか、データ数が少ない影響等により精度よく近似できませんでした。
python
1import matplotlib.pyplot as plt
2import numpy as np
3from scipy import stats
4from sklearn.metrics import mean_squared_error
5from scipy.optimize import curve_fit
6
7x = np.array([0, 4, 9, 14, 19, 24, 29, 34, 39, 44, 49, 54, 59,
8 64, 69, 74, 79, 84, 89, 94, 99, 104, 109, 114, 119, 124,
9 129, 134, 139, 144, 149, 154, 159, 164])
10
11y = np.array([0., 0.17304493, 0.28618968, 0.50083195, 0.55407654,
12 0.65058236, 0.73044925, 0.83527454, 0.87687188, 0.92845258,
13 0.93510815, 0.95174709, 0.96006656, 0.9750416, 0.97670549,
14 0.98169717, 0.98169717, 0.9843594, 0.98868552, 0.9906822,
15 0.9906822, 0.99234609, 0.99234609, 0.99567388, 0.99567388,
16 0.99567388, 0.99567388, 0.99567388, 0.99567388, 0.99567388,
17 0.99733777, 1., 1., 1.])
18
19# scipy.optimize.curve_fit を使うやり方
20######################################################
21def cdf(x, a, b):
22 return stats.lognorm.cdf(x, s=a, loc=0, scale=b)
23[s, scale], cov = curve_fit(cdf, x, y)
24print('s={}, scale={}'.format(s, scale)) # s=0.8933602211719341, scale=14.750787612138023
25# 近似した関数の結果
26y_pred = cdf(x, s, scale)
27
28# 描画する。
29plt.plot(x, y, linestyle='--', marker='o', color='b', ms=2, label='data')
30plt.plot(x, y_pred, linestyle='--', marker='o', color='g', ms=2, label='prediction')
31plt.legend()
32plt.show()
33
34print(mean_squared_error(y, y_pred)) # 0.000730706490266028
35print(y_pred[0]) # 0.0
バッドをするには、ログインかつ
こちらの条件を満たす必要があります。
2018/10/04 08:22
2018/10/04 09:14
2018/10/04 10:32
2018/10/04 11:06 編集
2018/10/04 11:23