質問をすることでしか得られない、回答やアドバイスがある。

15分調べてもわからないことは、質問しよう!

ただいまの
回答率

88.64%

chainerで実装したLSTMの学習がうまくいかない

解決済

回答 1

投稿 編集

  • 評価
  • クリップ 0
  • VIEW 1,669

sodiumplus3

score 31

chainerでLSTMを実装してsin波を学習したがうまくいかない

コードは以下です。分類でないのにclassifierでラップして実装しようとしているのでそこらへんで実装にミスがあると思うのですが…

import chainer
import chainer.functions as F
import chainer.links as L
from chainer import training,optimizers,Variable
import math
import numpy as np
from chainer.datasets import TupleDataset,split_dataset_random
from chainer import iterators,serializers

# データの生成
data = []
for i in range(10000):
    data.append(math.sin(i/100))

#データの整形・分割
x = []
t = []
for i in range(len(data)-5):
    x.append(data[i:i+1])
    t.append(data[i+1])
x = np.array(x).astype('float32')
t = np.array(t).astype('float32')
t = np.array(t).reshape(len(t), 1)
print(x.shape,t.shape)

dataset = TupleDataset(x,t)
n_train = int(len(dataset)*0.7)
train,test = split_dataset_random(dataset,n_train,seed=1)

# for i in range(len(test)):
#     print(test[i])


batch_size = 10
train_iter = iterators.SerialIterator(train,batch_size)
test_iter = iterators.SerialIterator(test,batch_size,shuffle=False,repeat=False)



class lstm(chainer.Chain):
    def __init__(self,n_mid=4,n_out=1):
        super(lstm,self).__init__()
        with self.init_scope():
            self.l1 = L.Linear(None,n_mid)
            self.l2 = L.LSTM(n_mid,n_mid)
            self.l3 = L.Linear(n_mid,n_out)

    def reset_state(self):
        self.l2.reset_state()

    def __call__(self,x):
        h = self.l1(x)
        h = self.l2(h)
        h = self.l3(h)
        return h

class LSTMUpdater(training.StandardUpdater):
    def __init__(self, train_iter, optimizer, device=-1):
        super(LSTMUpdater, self).__init__(train_iter, optimizer, device=device)

    # overrided
    def update_core(self):
        train_iter = self.get_iterator('main')
        optimizer = self.get_optimizer('main')

        batch = train_iter.__next__()
        X_STF, y_STF = chainer.dataset.concat_examples(batch, self.device)

        optimizer.target.zerograds()
        optimizer.target.predictor.reset_state()
        loss = optimizer.target(Variable(X_STF), Variable(y_STF))

        loss.backward()
        optimizer.update()

class LossSumMSEOverTime(L.Classifier):
    def __init__(self, predictor):
        super(LossSumMSEOverTime, self).__init__(predictor, lossfun=F.mean_squared_error)

    def __call__(self, X_STF, y_STF):
        """
        # Param
        - X_STF (Variable: (S, T, F))
        - y_STF (Variable: (S, T, F))
        S: samples
        T: time_steps
        F: features

        # Return
        - loss (Variable: (1, ))
        """
        # 時間 T で loop させるため、Tを先頭の軸にする
        X_TSF = X_STF.transpose(1,0,2)
        y_TSF = y_STF.transpose(1,0,2)
        seq_len  = X_TSF.shape[0]

        # 各時刻についてlossをとり、最終的なlossに足していく
        loss = 0
        for t in range(seq_len):
            pred = self.predictor(X_TSF[t])
            obs  = y_TSF[t]
            loss += self.lossfun(pred, obs)
        # loss の大きさが時系列長に依存してしまうので、時系列長で割る
        loss /= seq_len

        # reporter に loss の値を渡す
        reporter.report({'loss': loss}, self)

        return loss

predictor = lstm()
net = L.Classifier(predictor,lossfun=F.mean_squared_error)
net.compute_accuracy = False
optimizer = optimizers.Adam().setup(net)
updater = LSTMUpdater(train_iter,optimizer,device=-1)

from chainer.training.triggers import EarlyStoppingTrigger
trigger = EarlyStoppingTrigger(monitor='test/main/loss',patients=3)
trainer = training.Trainer(updater,trigger,out='results/lstm_result')

trainer.extend(training.extensions.LogReport(trigger=(1,'epoch'),log_name='log'))
trainer.extend(training.extensions.Evaluator(test_iter, net, device=-1), name='test')
trainer.extend(training.extensions.PrintReport(['epoch','iteration','main/loss','test/main/loss']))
trainer.extend(training.extensions.PlotReport(['main/loss','test/main/loss'], x_key='epoch', file_name='loss.png'))

trainer.run()


serializers.save_npz('lstm.npz',net)

chainerに詳しい方、回答いただけると嬉しいです。よろしくお願いします。

追記

うまくいかない、というのは良い予測ができないという意味です。明らかに簡単なデータなので、コードのモデルに何か問題があるのだろうと思っています。上記コードで学習したのち、以下のコードでテストしています。

import chainer
import chainer.functions as F
import chainer.links as L
from chainer import training,optimizers
import math
import numpy as np
from chainer.datasets import TupleDataset,split_dataset_random
from chainer import iterators,serializers
import matplotlib.pyplot as plt


class lstm(chainer.Chain):
    def __init__(self,n_mid=4,n_out=1):
        super(lstm,self).__init__()
        with self.init_scope():
            self.l1 = L.Linear(None,n_mid)
            self.l2 = L.LSTM(n_mid,n_mid)
            self.l3 = L.Linear(n_mid,n_out)

    def reset_state(self):
        self.l2.reset_state()

    def __call__(self,x):
        h = self.l1(x)
        h = self.l2(h)
        h = self.l3(h)
        return h

loaded_net = L.Classifier(lstm())
chainer.serializers.load_npz('lstm.npz',loaded_net)

data = []
x = 10
for i in range(x):
    data.append(math.sin(i/10))
data = np.array(data).astype('float32')

true_data = []
for i in range(100):
    true_data.append(math.sin(i/10))
true_data = np.array(true_data).astype('float32')

def make_data(data):
    l = np.zeros([1,1]).astype('float32')
    data = data[-10:]
    for i in range(len(l)):
        l[i] = data[i:i+1]
    return l

# print(make_data(data))

for i in range(100-x):
    with chainer.using_config('train', False), chainer.using_config('enable_backprop', False):
        y = loaded_net.predictor.__call__(make_data(data))
        # print(np.array(y.data).reshape(-1))
        data = np.append(data,y.data.reshape(-1)[-1])



plt.plot(range(len(data)),data,label='lstm')
plt.plot(range(len(true_data)),true_data,label='train_data')
plt.legend()

plt.show()

# print(true_data)


結果は例えばこうなります。
イメージ説明

  • 気になる質問をクリップする

    クリップした質問は、後からいつでもマイページで確認できます。

    またクリップした質問に回答があった際、通知やメールを受け取ることができます。

    クリップを取り消します

  • 良い質問の評価を上げる

    以下のような質問は評価を上げましょう

    • 質問内容が明確
    • 自分も答えを知りたい
    • 質問者以外のユーザにも役立つ

    評価が高い質問は、TOPページの「注目」タブのフィードに表示されやすくなります。

    質問の評価を上げたことを取り消します

  • 評価を下げられる数の上限に達しました

    評価を下げることができません

    • 1日5回まで評価を下げられます
    • 1日に1ユーザに対して2回まで評価を下げられます

    質問の評価を下げる

    teratailでは下記のような質問を「具体的に困っていることがない質問」、「サイトポリシーに違反する質問」と定義し、推奨していません。

    • プログラミングに関係のない質問
    • やってほしいことだけを記載した丸投げの質問
    • 問題・課題が含まれていない質問
    • 意図的に内容が抹消された質問
    • 過去に投稿した質問と同じ内容の質問
    • 広告と受け取られるような投稿

    評価が下がると、TOPページの「アクティブ」「注目」タブのフィードに表示されにくくなります。

    質問の評価を下げたことを取り消します

    この機能は開放されていません

    評価を下げる条件を満たしてません

    評価を下げる理由を選択してください

    詳細な説明はこちら

    上記に当てはまらず、質問内容が明確になっていない質問には「情報の追加・修正依頼」機能からコメントをしてください。

    質問の評価を下げる機能の利用条件

    この機能を利用するためには、以下の事項を行う必要があります。

質問への追記・修正、ベストアンサー選択の依頼

  • quickquip

    2019/05/02 08:55

    "うまくいかない"の内容を追記していただけるといいかと。(エラー? 性能が出ない?)

    キャンセル

  • sodiumplus3

    2019/05/03 01:16

    コメントありがとうございます。追記します。

    キャンセル

回答 1

check解決した方法

0

自己解決しました。
これ、我ながら間違いが多すぎて指摘する気にもならないですね…
もし同じように困っている方がいれば
https://github.com/chainer/chainer/blob/master/examples/ptb/train_ptb.py
を参考にしてみてください。

投稿

  • 回答の評価を上げる

    以下のような回答は評価を上げましょう

    • 正しい回答
    • わかりやすい回答
    • ためになる回答

    評価が高い回答ほどページの上位に表示されます。

  • 回答の評価を下げる

    下記のような回答は推奨されていません。

    • 間違っている回答
    • 質問の回答になっていない投稿
    • スパムや攻撃的な表現を用いた投稿

    評価を下げる際はその理由を明確に伝え、適切な回答に修正してもらいましょう。

15分調べてもわからないことは、teratailで質問しよう!

  • ただいまの回答率 88.64%
  • 質問をまとめることで、思考を整理して素早く解決
  • テンプレート機能で、簡単に質問をまとめられる

関連した質問

同じタグがついた質問を見る