前提・実現したいこと
SARSA法によるユーザシュミレータの作成
ここに質問の内容を詳しく書いてください。
(対話システム作成において、参考書に基づいてユーザシュミレータを作成しました。
その際、強化学習でQ学習を行いましたが、SARSA法も試したいと思っていますが、現在のコードにどのようにして組み込めば良いかわかりません
該当のソースコード
python
import random # システムの対話行為 sys_da_lis = [ "open-prompt", "ask-drink", "ask-size", "ask-tmp", "ask-num1", "close-prompt"] # システムの状態 states = ["0000","0001","0010","0011","0100","0101","0110","0111","1000","1001","1010","1011","1100","1101","1110","1111"] # Q値(行動状態価値)の初期化 Q = {} for state in states: Q[state] = {} for sys_da in sys_da_lis: Q[state][sys_da] = 0 # フレームを更新 def update_frame(frame, da, conceptdic): if da == "request-order": for k,v in conceptdic.items(): # コンセプトの情報でスロットを埋める frame[k] = v elif da == "initialize": frame = {"drink": "", "size": "", "tmp": "", "num1": ""} elif da == "correct-info": for k,v in conceptdic.items(): if frame[k] == v: frame[k] = "" return frame # フレームから状態を表す文字列に変換 # place, date, type の順に値が埋まっていたら1,埋まってなければ0 def frame2state(frame): state = "" for k in ["drink","size","tmp","num1"]: if frame[k] == "": state += "0" else: state += "1" return state # ユーザシミュレータ # ユーザは聞かれたスロットについて的確に答える. # open-promptには聞きたいことをいくつかランダムに伝える. # tell-info によるシステム回答の内容が合っていたらgoodbyeをする. # tell-infoの内容が間違っていたらinitializeをする. def next_user_da(sys_da, sys_conceptdic, intention): if sys_da == "ask-drink": return "request-order", {"drink": intention["drink"]} elif sys_da == "ask-size": return "request-order", {"size": intention["size"]} elif sys_da == "ask-tmp": return "request-order", {"tmp": intention["tmp"]} elif sys_da == "ask-num1": return "request-order", {"num1": intention["num1"]} elif sys_da == "open-prompt": while(True): dic = {} for k,v in intention.items(): if random.choice([0,1]) == 0: dic[k] = v if len(dic) > 0: return "request-order", dic elif sys_da == "close-prompt": is_ok = True for k,v in intention.items(): if sys_conceptdic[k] != v: is_ok = False break if is_ok: return "goodbye", {} else: return "initialize", {} # ランダムに行動 def next_system_da(frame): # 値がすべて埋まってないとtell-infoは発話できない cands = list(sys_da_lis) if frame["drink"] == "" or frame["size"] == "" or frame["tmp"] == "" or frame["num1"] == "": cands.remove("close-prompt") value = random.random() sys_da = random.choice(cands) sys_conceptdic = {} if sys_da == "close-prompt": sys_conceptdic = frame return sys_da, sys_conceptdic # 対話を成功するまで一回実行 # intentionはユーザの意図,alphaは学習係数,gammaは割引率を表す def run_dialogue(intention, alpha=0.1, gamma=0.9): frame = {"drink": "", "size": "", "tmp": "", "num1": ""} while(True): s1 = frame2state(frame) sys_da, sys_conceptdic = next_system_da(frame) da, conceptdic = next_user_da(sys_da, sys_conceptdic, intention) frame = update_frame(frame, da, conceptdic) s2 = frame2state(frame) # 遷移先の状態(s2)から得られる最大の価値を取得 da_lis = sorted(Q[s2].items(),key=lambda x:x[1], reverse=True) maxval = da_lis[0][1] if da == "goodbye": # 成功した対話の後の状態は存在しないのでmaxvalは0 maxval = 0 # Q値を更新して対話を終わる Q[s1][sys_da] = Q[s1][sys_da] + alpha * ((100 + gamma * maxval) - Q[s1][sys_da]) break else: # Q値を更新 Q[s1][sys_da] = Q[s1][sys_da] + alpha * ((0 + gamma * maxval) - Q[s1][sys_da]) if __name__ == "__main__": # 十万回対話をして学習 for i in range(100000): run_dialogue({"drink":"コーヒー","size":"エル","tmp":"アイス","num1":"4つ"}) # Q値を表示 print(Q) # 各状態で最適な行動をQ値とともに表示 for k,v in Q.items(): da_lis = sorted(Q[k].items(),key=lambda x:x[1], reverse=True) print(k, "=>", da_lis[0][0], da_lis[0][1]) # end of file
試したこと
このコードはQ学習を用いて作成したプログラムです
これをSARSA法に組み替えて実行したいと考えています
あなたの回答
tips
プレビュー