質問をすることでしか得られない、回答やアドバイスがある。

15分調べてもわからないことは、質問しよう!

新規登録して質問してみよう
ただいま回答率
85.47%
Chainer

Chainerは、国産の深層学習フレームワークです。あらゆるニューラルネットワークをPythonで柔軟に書くことができ、学習させることが可能。GPUをサポートしており、複数のGPUを用いた学習も直感的に記述できます。

Python 3.x

Python 3はPythonプログラミング言語の最新バージョンであり、2008年12月3日にリリースされました。

Python

Pythonは、コードの読みやすさが特徴的なプログラミング言語の1つです。 強い型付け、動的型付けに対応しており、後方互換性がないバージョン2系とバージョン3系が使用されています。 商用製品の開発にも無料で使用でき、OSだけでなく仮想環境にも対応。Unicodeによる文字列操作をサポートしているため、日本語処理も標準で可能です。

Q&A

0回答

1860閲覧

chainerのL.Linerに値を渡すときのエラー

Horin

総合スコア9

Chainer

Chainerは、国産の深層学習フレームワークです。あらゆるニューラルネットワークをPythonで柔軟に書くことができ、学習させることが可能。GPUをサポートしており、複数のGPUを用いた学習も直感的に記述できます。

Python 3.x

Python 3はPythonプログラミング言語の最新バージョンであり、2008年12月3日にリリースされました。

Python

Pythonは、コードの読みやすさが特徴的なプログラミング言語の1つです。 強い型付け、動的型付けに対応しており、後方互換性がないバージョン2系とバージョン3系が使用されています。 商用製品の開発にも無料で使用でき、OSだけでなく仮想環境にも対応。Unicodeによる文字列操作をサポートしているため、日本語処理も標準で可能です。

0グッド

0クリップ

投稿2018/10/18 13:16

編集2018/10/18 13:19

chainer,機械学習ともに初心者でお恥ずかしいコードになっていると思いますが、エラーの解決ご協力お願いします。
将来的にはattentionの可視化まで行いたいと思っています。
おかしいのは最後から2番目の行の
y = self.l4(h3)
でh3に問題があると思います。

python

1import chainer 2import chainer.functions as F 3import chainer.links as L 4from chainer import Variable 5import numpy as np 6 7 8class AttentionDRNN(chainer.Chain): 9 10 def __init__(self, in_s, n_units,out_s, train=True): 11 super(AttentionDRNN, self).__init__( 12 l1=L.LSTM(in_s, n_units), 13 l2=L.LSTM(n_units, n_units), 14 l3=L.LSTM(n_units, n_units), 15 l4=L.Linear(n_units, out_s), 16 ) 17 self.train = train 18 self.i = 0 19 self.gl =[] 20 self.wei = [] 21 def reset_state(self): 22 self.l1.reset_state() 23 self.l2.reset_state() 24 self.l3.reset_state() 25 26 def __call__(self, x, train=True): 27 weigh = [] #ウェイトを記録するためのリストの初期化 28 batchsize = 5 # ミニバッチのサイズを記憶 29 sum = 0 30 h1 = self.l1(x) 31 h2 = self.l2(h1) 32 h3 = self.l3(h2) 33 self.gl.append(h3) 34 self.i += 1 35 if self.i==1 : 36 y = self.l4(h3) 37 return y 38 else: 39 for num, ght in enumerate(h3): 40 we = F.matmul(self.gl[self.i-1][0], ght) 41 we = F.exp(we) # softmax関数を使って正規化する 42 weigh.append(we) 43 sum += we 44 if len(ght)==2: 45 batchsize = 2 46 weigh = np.array(weigh) 47 sum = np.array(sum) 48 weigh = weigh / sum 49 weigh = weigh.reshape(5, 1).repeat(60, axis=1) 50 att_f = np.array(h3) * np.array(weigh) 51 att_f = F.reshape(att_f, (5, 60)) 52 print(att_f.shape) 53 h3 = att_f 54 y = self.l4(h3) 55 return y

packages\chainer\utils\type_check.py", line 482, in expect
'{0} {1} {2}'.format(left, self.inv, right))
chainer.utils.type_check.InvalidType:
Invalid operation is performed in: LinearFunction (Forward)

Expect: in_types[0].dtype.kind == f
Actual: O != f

Process finished with exit code 1

このエラーが解決できないのですが調査してみた結果
h2を出力したとき以下のような値が出るのに対して
variable([[ 0.01771316 -0.01998963 0.00072347 -0.00977764 -0.02070782
0.01098927 0.01607024 -0.01494737 -0.04490452 0.04066181

h3を出力したとき以下のようになってるからではないかと考えたのですが直し方がわかりません。
variable([[variable(-0.00482196) variable(-0.00196443)
variable(0.00170961) variable(-0.00334108) variable(0.0049839)

気になる質問をクリップする

クリップした質問は、後からいつでもMYページで確認できます。

またクリップした質問に回答があった際、通知やメールを受け取ることができます。

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

guest

あなたの回答

tips

太字

斜体

打ち消し線

見出し

引用テキストの挿入

コードの挿入

リンクの挿入

リストの挿入

番号リストの挿入

表の挿入

水平線の挿入

プレビュー

まだ回答がついていません

会員登録して回答してみよう

アカウントをお持ちの方は

15分調べてもわからないことは
teratailで質問しよう!

ただいまの回答率
85.47%

質問をまとめることで
思考を整理して素早く解決

テンプレート機能で
簡単に質問をまとめる

質問する

関連した質問