質問をすることでしか得られない、回答やアドバイスがある。

15分調べてもわからないことは、質問しよう!

新規登録して質問してみよう
ただいま回答率
85.50%
Chainer

Chainerは、国産の深層学習フレームワークです。あらゆるニューラルネットワークをPythonで柔軟に書くことができ、学習させることが可能。GPUをサポートしており、複数のGPUを用いた学習も直感的に記述できます。

Python 3.x

Python 3はPythonプログラミング言語の最新バージョンであり、2008年12月3日にリリースされました。

機械学習

機械学習は、データからパターンを自動的に発見し、そこから知能的な判断を下すためのコンピューターアルゴリズムを指します。人工知能における課題のひとつです。

Q&A

解決済

1回答

2280閲覧

chainerのLSTMのpredictorの出力がよくわからない

sodiumplus3

総合スコア71

Chainer

Chainerは、国産の深層学習フレームワークです。あらゆるニューラルネットワークをPythonで柔軟に書くことができ、学習させることが可能。GPUをサポートしており、複数のGPUを用いた学習も直感的に記述できます。

Python 3.x

Python 3はPythonプログラミング言語の最新バージョンであり、2008年12月3日にリリースされました。

機械学習

機械学習は、データからパターンを自動的に発見し、そこから知能的な判断を下すためのコンピューターアルゴリズムを指します。人工知能における課題のひとつです。

0グッド

0クリップ

投稿2019/05/06 09:59

編集2019/05/06 10:00

###やったこと
chainerでLSTMを実装しsin波を学習させました。ネットの実装例がバラバラでやりづらかったのですが、最終的にchainer公式のptb実装例を参考に以下のようにコードを書きました。

python

1import numpy as np 2 3import chainer 4import chainer.functions as F 5import chainer.links as L 6from chainer import reporter, training, datasets, iterators, optimizers, serializers 7from chainer.training import extensions 8from chainer.datasets import TupleDataset 9import math 10 11 12import matplotlib.pyplot as plt 13 14 15class lstm(chainer.Chain): 16 def __init__(self,n_mid=5,n_out=1): 17 super(lstm,self).__init__() 18 with self.init_scope(): 19 self.l1 = L.Linear(None,n_mid) 20 self.l2 = L.LSTM(n_mid,n_mid) 21 self.l3 = L.Linear(n_mid,n_out) 22 23 def reset_state(self): 24 self.l2.reset_state() 25 26 def __call__(self,x): 27 h = self.l1(x) 28 h = self.l2(h) 29 h = self.l3(h) 30 return h 31 32 33net = L.Classifier(lstm(),lossfun=F.mean_squared_error) 34net.compute_accuracy = False 35optimizer = optimizers.Adam() 36optimizer.setup(net) 37 38class LSTM_Iterator(chainer.dataset.Iterator): 39 def __init__(self, dataset, batch_size=10, seq_len=5, repeat=True): 40 self.seq_length = seq_len 41 self.dataset = dataset 42 self.nsamples = len(dataset) 43 44 self.batch_size = batch_size 45 self.repeat = repeat 46 47 self.epoch = 0 48 self.iteration = 0 49 self.offsets = np.random.randint(0, len(dataset)-self.seq_length,size=batch_size) 50 51 self.is_new_epoch = False 52 53 def __next__(self): 54 if not self.repeat and self.iteration * self.batch_size >= self.nsamples: 55 raise StopIteration 56 57 x = self.get_data() 58 self.iteration += 1 59 t = self.get_data() 60 61 epoch = self.iteration * self.batch_size // self.nsamples 62 self.is_new_epoch = self.epoch < epoch 63 64 if self.is_new_epoch: 65 self.epoch = epoch 66 self.offsets = np.random.randint(0, self.nsamples-self.seq_length,size=self.batch_size) 67 68 return list(zip(x, t)) 69 70 @property 71 def epoch_detail(self): 72 return self.iteration * self.batch_size / len(self.dataset) 73 74 def get_data(self): 75 return [self.dataset[(offset + self.iteration) % len(self.dataset)] 76 for offset in self.offsets] 77 78 def serialize(self, serializer): 79 self.iteration = serializer('iteration', self.iteration) 80 self.epoch = serializer('epoch', self.epoch) 81 82 83 84class LSTM_updater(training.StandardUpdater): 85 def __init__(self, train_iter, optimizer, device): 86 super(LSTM_updater, self).__init__(train_iter, optimizer, device=device) 87 self.seq_length = train_iter.seq_length 88 89 def update_core(self): 90 loss = 0 91 92 train_iter = self.get_iterator('main') 93 optimizer = self.get_optimizer('main') 94 95 for i in range(self.seq_length): 96 batch = train_iter.__next__() 97 x, t = self.converter(batch, self.device) 98 loss += optimizer.target(chainer.Variable(x), chainer.Variable(t)) 99 100 optimizer.target.zerograds() 101 loss.backward() 102 loss.unchain_backward() 103 optimizer.update() 104 105# データ作成 106n_data = 500 107sin_data = [] 108for i in range(n_data): 109 sin_data.append(math.sin(i/50*math.pi)) 110 111# データセット 112n_train = int(n_data*0.8) 113n_test = int(n_data*0.2) 114 115sin_data = np.array(sin_data).astype(np.float32) 116 117x_train, x_test = sin_data[:n_train],sin_data[n_train:] 118 119train = TupleDataset(x_train) 120test = TupleDataset(x_test) 121 122train_iter = LSTM_Iterator(train, batch_size = 10, seq_len = 10) 123test_iter = LSTM_Iterator(test, batch_size = 10, seq_len = 10, repeat = False) 124 125updater = LSTM_updater(train_iter, optimizer, -1) 126trainer = training.Trainer(updater, (100, 'epoch'), out='results/lstm_result') 127 128eval_model = net.copy() 129eval_rnn = eval_model.predictor 130eval_rnn.train = False 131eval_rnn.reset_state() 132trainer.extend(extensions.Evaluator(test_iter, eval_model, device=-1), name='val') 133 134trainer.extend(extensions.LogReport(trigger=(1,'epoch'),log_name='log')) 135trainer.extend(extensions.PrintReport(['epoch', 'main/loss', 'val/main/loss'])) 136trainer.extend(training.extensions.PlotReport(['main/loss','val/main/loss'], x_key='epoch', file_name='loss.png')) 137# trainer.extend(extensions.ProgressBar()) 138 139trainer.run()

以上が学習コードです。さて、この学習済みモデルをhttps://qiita.com/chachay/items/052406176c55dd5b9a6a
の検証方法1を参考に

python

1presteps = 10 2net.predictor.reset_state() 3print(net.predictor(chainer.Variable(np.arange(10,dtype='float32').reshape((-1,1))))) 4for i in range(presteps): 5 y = net.predictor(chainer.Variable(np.roll(x_train,i).reshape((-1,1)))) 6 7 8plt.plot(range(n_train),np.roll(y.data,-presteps),label='lstm') 9plt.plot(range(n_train),x_train,label='train') 10plt.legend() 11plt.show()

を追記して検証したところ、画像のように、ちゃんと学習が進んでいる様子が見られました。
イメージ説明

###わからないこと
さて、ここからが質問なのですが、検証方法について疑問があります。

  1. presteps回のループの中でyに代入されるのは最後のpredictorの結果だけだと思うのですが、なぜfor文を回しているのか
  2. predictorが何をしているのか調べるためにprint(net.predictor(chainer.Variable(np.arange(10,dtype='float32').reshape((-1,1)))))を試してみました。すると、(10,1)の配列のVariableが返ってきました。自分の作ったLSTMモデルはlstmクラスで示したように、「任意の長さの配列を入力すると長さ5のLSTMを介して1つの出力が返ってくる」というモデルのつもりだったのですが、なぜこうなっているのでしょう?

確かに学習コードを書く際、x,tで学習を行うときにtにも10の長さの正解ラベルを与えたので変だなと思ってはいたのですが…

LSTMはまだ詳しい人がそう多くないせいか、回答をなかなかいただけないので、多少自信がなくても「こうなのでは」とご意見をいただけると嬉しいです。

気になる質問をクリップする

クリップした質問は、後からいつでもMYページで確認できます。

またクリップした質問に回答があった際、通知やメールを受け取ることができます。

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

guest

回答1

0

自己解決

1についてはいまだによくわかってないですが、2についてはわかりました。
(10,1)の10はバッチ数、1が出力だったので、想定したモデルで間違ってませんでした。

投稿2019/05/06 11:51

sodiumplus3

総合スコア71

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

あなたの回答

tips

太字

斜体

打ち消し線

見出し

引用テキストの挿入

コードの挿入

リンクの挿入

リストの挿入

番号リストの挿入

表の挿入

水平線の挿入

プレビュー

15分調べてもわからないことは
teratailで質問しよう!

ただいまの回答率
85.50%

質問をまとめることで
思考を整理して素早く解決

テンプレート機能で
簡単に質問をまとめる

質問する

関連した質問