質問をすることでしか得られない、回答やアドバイスがある。

15分調べてもわからないことは、質問しよう!

ただいまの
回答率

88.63%

chainerのLSTMのpredictorの出力がよくわからない

解決済

回答 1

投稿 編集

  • 評価
  • クリップ 0
  • VIEW 1,105

sodiumplus3

score 31

やったこと

chainerでLSTMを実装しsin波を学習させました。ネットの実装例がバラバラでやりづらかったのですが、最終的にchainer公式のptb実装例を参考に以下のようにコードを書きました。

import numpy as np

import chainer
import chainer.functions as F
import chainer.links as L
from chainer import reporter, training, datasets, iterators, optimizers, serializers
from chainer.training import extensions
from chainer.datasets import TupleDataset
import math


import matplotlib.pyplot as plt


class lstm(chainer.Chain):
    def __init__(self,n_mid=5,n_out=1):
        super(lstm,self).__init__()
        with self.init_scope():
            self.l1 = L.Linear(None,n_mid)
            self.l2 = L.LSTM(n_mid,n_mid)
            self.l3 = L.Linear(n_mid,n_out)

    def reset_state(self):
        self.l2.reset_state()

    def __call__(self,x):
        h = self.l1(x)
        h = self.l2(h)
        h = self.l3(h)
        return h


net = L.Classifier(lstm(),lossfun=F.mean_squared_error)
net.compute_accuracy = False
optimizer = optimizers.Adam()
optimizer.setup(net)

class LSTM_Iterator(chainer.dataset.Iterator):
    def __init__(self, dataset, batch_size=10, seq_len=5, repeat=True):
        self.seq_length = seq_len
        self.dataset = dataset
        self.nsamples = len(dataset)

        self.batch_size = batch_size
        self.repeat = repeat

        self.epoch = 0
        self.iteration = 0
        self.offsets = np.random.randint(0, len(dataset)-self.seq_length,size=batch_size)

        self.is_new_epoch = False

    def __next__(self):
        if not self.repeat and self.iteration * self.batch_size >= self.nsamples:
            raise StopIteration

        x = self.get_data()
        self.iteration += 1
        t = self.get_data()

        epoch = self.iteration * self.batch_size // self.nsamples
        self.is_new_epoch = self.epoch < epoch

        if self.is_new_epoch:
            self.epoch = epoch
            self.offsets = np.random.randint(0, self.nsamples-self.seq_length,size=self.batch_size)

        return list(zip(x, t))

    @property
    def epoch_detail(self):
        return self.iteration * self.batch_size / len(self.dataset)

    def get_data(self):
        return [self.dataset[(offset + self.iteration) % len(self.dataset)]
                for offset in self.offsets]

    def serialize(self, serializer):
        self.iteration = serializer('iteration', self.iteration)
        self.epoch     = serializer('epoch', self.epoch)



class LSTM_updater(training.StandardUpdater):
    def __init__(self, train_iter, optimizer, device):
        super(LSTM_updater, self).__init__(train_iter, optimizer, device=device)
        self.seq_length = train_iter.seq_length

    def update_core(self):
        loss = 0

        train_iter = self.get_iterator('main')
        optimizer = self.get_optimizer('main')

        for i in range(self.seq_length):
            batch = train_iter.__next__()
            x, t  = self.converter(batch, self.device)
            loss += optimizer.target(chainer.Variable(x), chainer.Variable(t))

        optimizer.target.zerograds()
        loss.backward()
        loss.unchain_backward()
        optimizer.update()

# データ作成
n_data = 500
sin_data = []
for i in range(n_data):
    sin_data.append(math.sin(i/50*math.pi))

# データセット
n_train = int(n_data*0.8)
n_test  = int(n_data*0.2)

sin_data = np.array(sin_data).astype(np.float32)

x_train, x_test = sin_data[:n_train],sin_data[n_train:]

train = TupleDataset(x_train)
test  = TupleDataset(x_test)

train_iter = LSTM_Iterator(train, batch_size = 10, seq_len = 10)
test_iter  = LSTM_Iterator(test,  batch_size = 10, seq_len = 10, repeat = False)

updater = LSTM_updater(train_iter, optimizer, -1)
trainer = training.Trainer(updater, (100, 'epoch'), out='results/lstm_result')

eval_model = net.copy()
eval_rnn = eval_model.predictor
eval_rnn.train = False
eval_rnn.reset_state()
trainer.extend(extensions.Evaluator(test_iter, eval_model, device=-1), name='val')

trainer.extend(extensions.LogReport(trigger=(1,'epoch'),log_name='log'))
trainer.extend(extensions.PrintReport(['epoch', 'main/loss', 'val/main/loss']))
trainer.extend(training.extensions.PlotReport(['main/loss','val/main/loss'], x_key='epoch', file_name='loss.png'))
# trainer.extend(extensions.ProgressBar())

trainer.run()


以上が学習コードです。さて、この学習済みモデルをhttps://qiita.com/chachay/items/052406176c55dd5b9a6a
の検証方法1を参考に

presteps = 10
net.predictor.reset_state()
print(net.predictor(chainer.Variable(np.arange(10,dtype='float32').reshape((-1,1)))))
for i in range(presteps):
    y = net.predictor(chainer.Variable(np.roll(x_train,i).reshape((-1,1))))


plt.plot(range(n_train),np.roll(y.data,-presteps),label='lstm')
plt.plot(range(n_train),x_train,label='train')
plt.legend()
plt.show()


を追記して検証したところ、画像のように、ちゃんと学習が進んでいる様子が見られました。
イメージ説明

わからないこと

さて、ここからが質問なのですが、検証方法について疑問があります。

  1. presteps回のループの中でyに代入されるのは最後のpredictorの結果だけだと思うのですが、なぜfor文を回しているのか
  2. predictorが何をしているのか調べるためにprint(net.predictor(chainer.Variable(np.arange(10,dtype='float32').reshape((-1,1)))))を試してみました。すると、(10,1)の配列のVariableが返ってきました。自分の作ったLSTMモデルはlstmクラスで示したように、「任意の長さの配列を入力すると長さ5のLSTMを介して1つの出力が返ってくる」というモデルのつもりだったのですが、なぜこうなっているのでしょう?
    確かに学習コードを書く際、x,tで学習を行うときにtにも10の長さの正解ラベルを与えたので変だなと思ってはいたのですが…

LSTMはまだ詳しい人がそう多くないせいか、回答をなかなかいただけないので、多少自信がなくても「こうなのでは」とご意見をいただけると嬉しいです。

  • 気になる質問をクリップする

    クリップした質問は、後からいつでもマイページで確認できます。

    またクリップした質問に回答があった際、通知やメールを受け取ることができます。

    クリップを取り消します

  • 良い質問の評価を上げる

    以下のような質問は評価を上げましょう

    • 質問内容が明確
    • 自分も答えを知りたい
    • 質問者以外のユーザにも役立つ

    評価が高い質問は、TOPページの「注目」タブのフィードに表示されやすくなります。

    質問の評価を上げたことを取り消します

  • 評価を下げられる数の上限に達しました

    評価を下げることができません

    • 1日5回まで評価を下げられます
    • 1日に1ユーザに対して2回まで評価を下げられます

    質問の評価を下げる

    teratailでは下記のような質問を「具体的に困っていることがない質問」、「サイトポリシーに違反する質問」と定義し、推奨していません。

    • プログラミングに関係のない質問
    • やってほしいことだけを記載した丸投げの質問
    • 問題・課題が含まれていない質問
    • 意図的に内容が抹消された質問
    • 過去に投稿した質問と同じ内容の質問
    • 広告と受け取られるような投稿

    評価が下がると、TOPページの「アクティブ」「注目」タブのフィードに表示されにくくなります。

    質問の評価を下げたことを取り消します

    この機能は開放されていません

    評価を下げる条件を満たしてません

    評価を下げる理由を選択してください

    詳細な説明はこちら

    上記に当てはまらず、質問内容が明確になっていない質問には「情報の追加・修正依頼」機能からコメントをしてください。

    質問の評価を下げる機能の利用条件

    この機能を利用するためには、以下の事項を行う必要があります。

回答 1

check解決した方法

0

1についてはいまだによくわかってないですが、2についてはわかりました。
(10,1)の10はバッチ数、1が出力だったので、想定したモデルで間違ってませんでした。

投稿

  • 回答の評価を上げる

    以下のような回答は評価を上げましょう

    • 正しい回答
    • わかりやすい回答
    • ためになる回答

    評価が高い回答ほどページの上位に表示されます。

  • 回答の評価を下げる

    下記のような回答は推奨されていません。

    • 間違っている回答
    • 質問の回答になっていない投稿
    • スパムや攻撃的な表現を用いた投稿

    評価を下げる際はその理由を明確に伝え、適切な回答に修正してもらいましょう。

15分調べてもわからないことは、teratailで質問しよう!

  • ただいまの回答率 88.63%
  • 質問をまとめることで、思考を整理して素早く解決
  • テンプレート機能で、簡単に質問をまとめられる

関連した質問

同じタグがついた質問を見る