回答編集履歴
6
d
test
CHANGED
@@ -112,15 +112,27 @@
|
|
112
112
|
|
113
113
|
|
114
114
|
|
115
|
+
# 近似関数 a1 x^4 + a2 x^2
|
116
|
+
|
117
|
+
def f1(a1, a2):
|
118
|
+
|
119
|
+
return a1 * xs**4 + a2 * xs ** 2
|
120
|
+
|
121
|
+
|
122
|
+
|
123
|
+
# 近似関数 sin(a1) x^4 + a2 x^2
|
124
|
+
|
125
|
+
def f2(a1, a2):
|
126
|
+
|
127
|
+
return np.sin(a1) * xs**4 + a2 * xs ** 2
|
128
|
+
|
129
|
+
|
130
|
+
|
115
131
|
# 2乗誤差関数
|
116
132
|
|
117
|
-
def loss(a1, a2):
|
133
|
+
def loss(a1, a2, f):
|
118
134
|
|
119
|
-
approx = a1 * xs**4 + a2 * xs ** 2 # 近似関数
|
120
|
-
|
121
|
-
|
135
|
+
return ((ys - f(a1, a2)) ** 2).sum()
|
122
|
-
|
123
|
-
return loss
|
124
136
|
|
125
137
|
|
126
138
|
|
@@ -128,7 +140,7 @@
|
|
128
140
|
|
129
141
|
A1, A2 = np.mgrid[-10:11, -10:11]
|
130
142
|
|
131
|
-
L = np.array([loss(a1, a2) for a1, a2 in zip(A1.ravel(), A2.ravel())])
|
143
|
+
L = np.array([loss(a1, a2, f2) for a1, a2 in zip(A1.ravel(), A2.ravel())])
|
132
144
|
|
133
145
|
L = L.reshape(A1.shape)
|
134
146
|
|
@@ -157,3 +169,15 @@
|
|
157
169
|
|
158
170
|
|
159
171
|
![イメージ説明](6862cc8f3a6b345bbeafc66b319c4695.png)
|
172
|
+
|
173
|
+
f(x) = a1 x^4 + a2 x^2 の2乗誤差関数
|
174
|
+
|
175
|
+
→ 凸なので局所解 = 大域解
|
176
|
+
|
177
|
+
|
178
|
+
|
179
|
+
![イメージ説明](2a169cdc73a0b92a3e57584083901753.png)
|
180
|
+
|
181
|
+
f(x) = sin(a1) x^4 + a2 x^2 の2乗誤差関数
|
182
|
+
|
183
|
+
→ 局所解あり
|
5
d
test
CHANGED
@@ -82,7 +82,7 @@
|
|
82
82
|
|
83
83
|
|
84
84
|
|
85
|
-
ここでいう線形、線形でないというのは、
|
85
|
+
ここでいう線形、線形でないというのは、近似関数を推定するパラメータ a の関数として見たときの話です。
|
86
86
|
|
87
87
|
|
88
88
|
|
4
d
test
CHANGED
@@ -75,3 +75,85 @@
|
|
75
75
|
* なので、勾配法で解いても必ず大域解に収束する。
|
76
76
|
|
77
77
|
* 勾配法以外の方法として、正規方程式を解いても求められる。
|
78
|
+
|
79
|
+
|
80
|
+
|
81
|
+
## 追記
|
82
|
+
|
83
|
+
|
84
|
+
|
85
|
+
ここでいう線形、線形でないというのは、損失関数(2乗誤差)を推定するパラメータ a の関数として見たときの話です。
|
86
|
+
|
87
|
+
|
88
|
+
|
89
|
+
例えば、f(x) = a_1 x^4 + a_2 x^2 という関数で近似する場合、
|
90
|
+
|
91
|
+
x の関数としてみた場合非線形関数ですが、a の関数として見た場合線形関数です。
|
92
|
+
|
93
|
+
以下、損失関数を描画した例です。
|
94
|
+
|
95
|
+
|
96
|
+
|
97
|
+
```python
|
98
|
+
|
99
|
+
import matplotlib.pyplot as plt
|
100
|
+
|
101
|
+
import numpy as np
|
102
|
+
|
103
|
+
from matplotlib.patches import FancyArrowPatch
|
104
|
+
|
105
|
+
from mpl_toolkits.mplot3d import Axes3D, proj3d
|
106
|
+
|
107
|
+
|
108
|
+
|
109
|
+
xs = np.array([2.8, 2.9, 3.0, 3.1, 3.2, 3.2, 3.2, 3.3, 3.4])
|
110
|
+
|
111
|
+
ys = np.array([30, 26, 33, 31, 33, 35, 37, 36, 33])
|
112
|
+
|
113
|
+
|
114
|
+
|
115
|
+
# 2乗誤差関数
|
116
|
+
|
117
|
+
def loss(a1, a2):
|
118
|
+
|
119
|
+
approx = a1 * xs**4 + a2 * xs ** 2 # 近似関数
|
120
|
+
|
121
|
+
loss = ((ys - approx) ** 2).sum() # 2乗誤差
|
122
|
+
|
123
|
+
return loss
|
124
|
+
|
125
|
+
|
126
|
+
|
127
|
+
# 各点での関数 loss の値を計算する。
|
128
|
+
|
129
|
+
A1, A2 = np.mgrid[-10:11, -10:11]
|
130
|
+
|
131
|
+
L = np.array([loss(a1, a2) for a1, a2 in zip(A1.ravel(), A2.ravel())])
|
132
|
+
|
133
|
+
L = L.reshape(A1.shape)
|
134
|
+
|
135
|
+
|
136
|
+
|
137
|
+
fig = plt.figure(figsize=(7, 7))
|
138
|
+
|
139
|
+
ax = fig.add_subplot(111, projection='3d')
|
140
|
+
|
141
|
+
# 各軸のラベルを設定する。
|
142
|
+
|
143
|
+
ax.set_xlabel('$a_1$', fontsize=15)
|
144
|
+
|
145
|
+
ax.set_ylabel('$a_2$', fontsize=15)
|
146
|
+
|
147
|
+
ax.set_zlabel('$loss$', fontsize=15)
|
148
|
+
|
149
|
+
# グラフを作成する。
|
150
|
+
|
151
|
+
ax.plot_surface(A1, A2, L, alpha=0.4, antialiased=False)
|
152
|
+
|
153
|
+
ax.view_init(elev=50, azim=120)
|
154
|
+
|
155
|
+
```
|
156
|
+
|
157
|
+
|
158
|
+
|
159
|
+
![イメージ説明](6862cc8f3a6b345bbeafc66b319c4695.png)
|
3
d
test
CHANGED
@@ -28,7 +28,7 @@
|
|
28
28
|
|
29
29
|
|
30
30
|
|
31
|
-
J(a) が凸関数であることの証明は以下の不等式を示
|
31
|
+
J(a) が凸関数であることの証明は以下の不等式を示す必要があります。
|
32
32
|
|
33
33
|
|
34
34
|
|
2
d
test
CHANGED
@@ -56,7 +56,7 @@
|
|
56
56
|
|
57
57
|
|
58
58
|
|
59
|
-
hayataka2049 さんがコメントしてくださった正規方程式と
|
59
|
+
hayataka2049 さんがコメントしてくださった正規方程式との関係は以下のようになります。
|
60
60
|
|
61
61
|
|
62
62
|
|
1
d
test
CHANGED
@@ -6,11 +6,25 @@
|
|
6
6
|
|
7
7
|
----
|
8
8
|
|
9
|
+
|
10
|
+
|
11
|
+
## 最小二乗法の定式化
|
12
|
+
|
13
|
+
|
14
|
+
|
9
15
|
以下でいう最小二乗法の目的関数 J(a) は凸関数なので、局所解が大域解となるので、勾配法を使っても求まる解は大域解となるのではないでしょうか。
|
10
16
|
|
11
17
|
|
12
18
|
|
13
19
|
![イメージ説明](4c95b74fbfa5d5b7028ca50c91e3fadf.png)
|
20
|
+
|
21
|
+
|
22
|
+
|
23
|
+
## 最小二乗法の目的関数が凸関数であることの証明
|
24
|
+
|
25
|
+
|
26
|
+
|
27
|
+
### 示すべきこと
|
14
28
|
|
15
29
|
|
16
30
|
|
@@ -23,3 +37,41 @@
|
|
23
37
|
|
24
38
|
|
25
39
|
式展開してゴリゴリ計算すれば導出できます。(わからなければ補足します。)
|
40
|
+
|
41
|
+
|
42
|
+
|
43
|
+
### 証明
|
44
|
+
|
45
|
+
|
46
|
+
|
47
|
+
![イメージ説明](01a2bd70a954831e96b3fc69768e9b52.png)
|
48
|
+
|
49
|
+
|
50
|
+
|
51
|
+
画像はクリックすると拡大できます。
|
52
|
+
|
53
|
+
|
54
|
+
|
55
|
+
## 正規方程式
|
56
|
+
|
57
|
+
|
58
|
+
|
59
|
+
hayataka2049 さんがコメントしてくださった正規方程式というのは、以下です。
|
60
|
+
|
61
|
+
|
62
|
+
|
63
|
+
![イメージ説明](f7d11a3d533d93ed4a9525f3bce4447a.png)
|
64
|
+
|
65
|
+
|
66
|
+
|
67
|
+
## Summary
|
68
|
+
|
69
|
+
|
70
|
+
|
71
|
+
* 最小二乗法の目的関数は凸関数
|
72
|
+
|
73
|
+
* 凸関数なので、局所解は大域解になる。
|
74
|
+
|
75
|
+
* なので、勾配法で解いても必ず大域解に収束する。
|
76
|
+
|
77
|
+
* 勾配法以外の方法として、正規方程式を解いても求められる。
|