回答編集履歴
6
d
answer
CHANGED
@@ -55,15 +55,21 @@
|
|
55
55
|
xs = np.array([2.8, 2.9, 3.0, 3.1, 3.2, 3.2, 3.2, 3.3, 3.4])
|
56
56
|
ys = np.array([30, 26, 33, 31, 33, 35, 37, 36, 33])
|
57
57
|
|
58
|
+
# 近似関数 a1 x^4 + a2 x^2
|
59
|
+
def f1(a1, a2):
|
60
|
+
return a1 * xs**4 + a2 * xs ** 2
|
61
|
+
|
62
|
+
# 近似関数 sin(a1) x^4 + a2 x^2
|
63
|
+
def f2(a1, a2):
|
64
|
+
return np.sin(a1) * xs**4 + a2 * xs ** 2
|
65
|
+
|
58
66
|
# 2乗誤差関数
|
59
|
-
def loss(a1, a2):
|
67
|
+
def loss(a1, a2, f):
|
60
|
-
approx = a1 * xs**4 + a2 * xs ** 2 # 近似関数
|
61
|
-
|
68
|
+
return ((ys - f(a1, a2)) ** 2).sum()
|
62
|
-
return loss
|
63
69
|
|
64
70
|
# 各点での関数 loss の値を計算する。
|
65
71
|
A1, A2 = np.mgrid[-10:11, -10:11]
|
66
|
-
L = np.array([loss(a1, a2) for a1, a2 in zip(A1.ravel(), A2.ravel())])
|
72
|
+
L = np.array([loss(a1, a2, f2) for a1, a2 in zip(A1.ravel(), A2.ravel())])
|
67
73
|
L = L.reshape(A1.shape)
|
68
74
|
|
69
75
|
fig = plt.figure(figsize=(7, 7))
|
@@ -77,4 +83,10 @@
|
|
77
83
|
ax.view_init(elev=50, azim=120)
|
78
84
|
```
|
79
85
|
|
80
|
-

|
86
|
+

|
87
|
+
f(x) = a1 x^4 + a2 x^2 の2乗誤差関数
|
88
|
+
→ 凸なので局所解 = 大域解
|
89
|
+
|
90
|
+

|
91
|
+
f(x) = sin(a1) x^4 + a2 x^2 の2乗誤差関数
|
92
|
+
→ 局所解あり
|
5
d
answer
CHANGED
@@ -40,7 +40,7 @@
|
|
40
40
|
|
41
41
|
## 追記
|
42
42
|
|
43
|
-
ここでいう線形、線形でないというのは、
|
43
|
+
ここでいう線形、線形でないというのは、近似関数を推定するパラメータ a の関数として見たときの話です。
|
44
44
|
|
45
45
|
例えば、f(x) = a_1 x^4 + a_2 x^2 という関数で近似する場合、
|
46
46
|
x の関数としてみた場合非線形関数ですが、a の関数として見た場合線形関数です。
|
4
d
answer
CHANGED
@@ -36,4 +36,45 @@
|
|
36
36
|
* 最小二乗法の目的関数は凸関数
|
37
37
|
* 凸関数なので、局所解は大域解になる。
|
38
38
|
* なので、勾配法で解いても必ず大域解に収束する。
|
39
|
-
* 勾配法以外の方法として、正規方程式を解いても求められる。
|
39
|
+
* 勾配法以外の方法として、正規方程式を解いても求められる。
|
40
|
+
|
41
|
+
## 追記
|
42
|
+
|
43
|
+
ここでいう線形、線形でないというのは、損失関数(2乗誤差)を推定するパラメータ a の関数として見たときの話です。
|
44
|
+
|
45
|
+
例えば、f(x) = a_1 x^4 + a_2 x^2 という関数で近似する場合、
|
46
|
+
x の関数としてみた場合非線形関数ですが、a の関数として見た場合線形関数です。
|
47
|
+
以下、損失関数を描画した例です。
|
48
|
+
|
49
|
+
```python
|
50
|
+
import matplotlib.pyplot as plt
|
51
|
+
import numpy as np
|
52
|
+
from matplotlib.patches import FancyArrowPatch
|
53
|
+
from mpl_toolkits.mplot3d import Axes3D, proj3d
|
54
|
+
|
55
|
+
xs = np.array([2.8, 2.9, 3.0, 3.1, 3.2, 3.2, 3.2, 3.3, 3.4])
|
56
|
+
ys = np.array([30, 26, 33, 31, 33, 35, 37, 36, 33])
|
57
|
+
|
58
|
+
# 2乗誤差関数
|
59
|
+
def loss(a1, a2):
|
60
|
+
approx = a1 * xs**4 + a2 * xs ** 2 # 近似関数
|
61
|
+
loss = ((ys - approx) ** 2).sum() # 2乗誤差
|
62
|
+
return loss
|
63
|
+
|
64
|
+
# 各点での関数 loss の値を計算する。
|
65
|
+
A1, A2 = np.mgrid[-10:11, -10:11]
|
66
|
+
L = np.array([loss(a1, a2) for a1, a2 in zip(A1.ravel(), A2.ravel())])
|
67
|
+
L = L.reshape(A1.shape)
|
68
|
+
|
69
|
+
fig = plt.figure(figsize=(7, 7))
|
70
|
+
ax = fig.add_subplot(111, projection='3d')
|
71
|
+
# 各軸のラベルを設定する。
|
72
|
+
ax.set_xlabel('$a_1$', fontsize=15)
|
73
|
+
ax.set_ylabel('$a_2$', fontsize=15)
|
74
|
+
ax.set_zlabel('$loss$', fontsize=15)
|
75
|
+
# グラフを作成する。
|
76
|
+
ax.plot_surface(A1, A2, L, alpha=0.4, antialiased=False)
|
77
|
+
ax.view_init(elev=50, azim=120)
|
78
|
+
```
|
79
|
+
|
80
|
+

|
3
d
answer
CHANGED
@@ -13,7 +13,7 @@
|
|
13
13
|
|
14
14
|
### 示すべきこと
|
15
15
|
|
16
|
-
J(a) が凸関数であることの証明は以下の不等式を示
|
16
|
+
J(a) が凸関数であることの証明は以下の不等式を示す必要があります。
|
17
17
|
|
18
18
|

|
19
19
|
|
2
d
answer
CHANGED
@@ -27,7 +27,7 @@
|
|
27
27
|
|
28
28
|
## 正規方程式
|
29
29
|
|
30
|
-
hayataka2049 さんがコメントしてくださった正規方程式と
|
30
|
+
hayataka2049 さんがコメントしてくださった正規方程式との関係は以下のようになります。
|
31
31
|
|
32
32
|

|
33
33
|
|
1
d
answer
CHANGED
@@ -2,12 +2,38 @@
|
|
2
2
|
teratail は数式入力できないので画像で失礼します。
|
3
3
|
|
4
4
|
----
|
5
|
+
|
6
|
+
## 最小二乗法の定式化
|
7
|
+
|
5
8
|
以下でいう最小二乗法の目的関数 J(a) は凸関数なので、局所解が大域解となるので、勾配法を使っても求まる解は大域解となるのではないでしょうか。
|
6
9
|
|
7
10
|

|
8
11
|
|
12
|
+
## 最小二乗法の目的関数が凸関数であることの証明
|
13
|
+
|
14
|
+
### 示すべきこと
|
15
|
+
|
9
16
|
J(a) が凸関数であることの証明は以下の不等式を示せます。
|
10
17
|
|
11
18
|

|
12
19
|
|
13
|
-
式展開してゴリゴリ計算すれば導出できます。(わからなければ補足します。)
|
20
|
+
式展開してゴリゴリ計算すれば導出できます。(わからなければ補足します。)
|
21
|
+
|
22
|
+
### 証明
|
23
|
+
|
24
|
+

|
25
|
+
|
26
|
+
画像はクリックすると拡大できます。
|
27
|
+
|
28
|
+
## 正規方程式
|
29
|
+
|
30
|
+
hayataka2049 さんがコメントしてくださった正規方程式というのは、以下です。
|
31
|
+
|
32
|
+

|
33
|
+
|
34
|
+
## Summary
|
35
|
+
|
36
|
+
* 最小二乗法の目的関数は凸関数
|
37
|
+
* 凸関数なので、局所解は大域解になる。
|
38
|
+
* なので、勾配法で解いても必ず大域解に収束する。
|
39
|
+
* 勾配法以外の方法として、正規方程式を解いても求められる。
|