###前提・実現したいこと
Octaveで最急降下法を実装してデータを分析しようと考えている。
具体的にはデータに一番フィットする関数Θ0+Θ1xを求めようとしている。
その上でいくつかの問題が発生したため、ソースコードのアルゴリズムが間違っているのか、どこが間違っているのかを教えていただきたいです。
###発生している問題・エラーメッセージ
data = [[0,4],[1,7],[2,7],[3,8],[4,10]]の時には、a = 0.1の時だけ上手く近似できるが、aの値をいじると近似が壊れる。(本来だったらあってはならないこと)
data.txtのデータを読み込んで近似させようとすると話にならない結果(とても近似しているとは思えない結果)が変える。
###ソースコード
main.m
clf hold on data = load("./data.txt") alpha = input("Alpha : ") theta0 = -5 theta1 = -5 prev_theta0 = 0 prev_theta1 = 0 allow_cost = 1.0 * 10 ** -6 plot(data(:,1),data(:,2),"*") num = length(data) while(1) sigma_theta0 = 0 sigma_theta1 = 0 for i = 1 : num sigma_theta0 += h(theta0,theta1,data(i,1)) - data(i,2) sigma_theta1 += (h(theta0,theta1,data(i,1)) - data(i,2)) * data(i,1) end prev_theta0 = theta0 prev_theta1 = theta1 theta0 = theta0 - alpha * sigma_theta0 / num theta1 = theta1 - alpha * sigma_theta1 / num alpha = alpha / 2 if(abs(prev_theta0 - theta0) < allow_cost && abs(prev_theta1 - theta1) < allow_cost) break end end X = [min(data(:,1))-5,max(data(:,1))+5] plot(X,h(theta0,theta1,X)) theta0 theta1
h.m
function y = h(theta0,theta1,x) y = theta0 + theta1 * x end
data.txt
7 1850 26 990 10 1080 9 2230 10 1870 72 530 54 610 24 965 11 1170
###補足情報(言語/FW/ツール等のバージョンなど)
Octaveを利用しています。
データの意味と関数と数式の関係がわかりません。この式を適用しようとして、このように書いたというのが分かればお力になれると思いうのですが・・・
回答1件
あなたの回答
tips
プレビュー