###前提・実現したいこと
現在ChainerのLSTMで予測モデルを作り、それを使って語義曖昧性解消の予測を行うということをやっております。
その中でpickle.loadを行う際に引数に関するエラーが発生し、どうしても解決できなかったため、どなたか助けていただけると幸いです。
###発生している問題・エラーメッセージ
$ python lstm_prediction.py Traceback (most recent call last): File "lstm_prediction.py", line 139, in <module> main() File "lstm_prediction.py", line 112, in main lstm = pickle.load(fl) TypeError: __init__() takes from 2 to 4 positional arguments but 5 were given
###予測モデルを作成するソースコード
lstm.py
1import argparse 2import os 3import sys 4import numpy as np 5import chainer 6from chainer import optimizers 7import chainer.functions as F 8import chainer.links as L 9import pickle 10# import cupy 11 12global vocab 13global n_vocab 14global inv_vocab 15 16vocab = {'<$>':0, '<s>':1, '<eos>':2} 17n_vocab = len(vocab) 18inv_vocab = {0:'<$>', 1:'<s>', 2:'<eos>'} #逆引き辞書 19 20n_units = 512 # 隠れ層のユニット数 21 22 23# LSTMのネットワーク定義 24class LSTM(chainer.Chain): 25 state = {} 26 27 def __init__(self, n_vocab, n_units): 28 print(n_vocab, n_units) 29 super(LSTM, self).__init__( 30 l1_embed = L.EmbedID(n_vocab, n_units), 31 l1_x = L.Linear(n_units, 4 * n_units), 32 l1_h = L.Linear(n_units, 4 * n_units), 33 l2_embed = L.EmbedID(n_vocab, n_units), 34 l2_x = L.Linear(n_units, 4 * n_units), 35 l2_h = L.Linear(n_units, 4 * n_units), 36 l3_embed = L.EmbedID(n_vocab, n_units), 37 l3_x = L.Linear(n_units, 4 * n_units), 38 l3_h = L.Linear(n_units, 4 * n_units), 39 l4_embed = L.EmbedID(n_vocab, n_units), 40 l4_x = L.Linear(n_units, 4 * n_units), 41 l4_h = L.Linear(n_units, 4 * n_units), 42 l5_embed = L.EmbedID(n_vocab, n_units), 43 l5_x = L.Linear(n_units, 4 * n_units), 44 l5_h = L.Linear(n_units, 4 * n_units), 45 l_umembed = L.Linear(n_units, n_vocab) 46 ) 47 48 def forward(self, x1, x2, x3, x4, x5, t, train=True, dropout_ratio=0.5): 49 h1 = self.l1_embed(chainer.Variable(np.asarray([x1], dtype=np.int32))) 50 c1, y1 = F.lstm(chainer.Variable(np.zeros((1, n_units), dtype=np.float32)), F.dropout(self.l1_x(h1), ratio=dropout_ratio, train=train ) + self.l1_h(self.state['y1'])) 51 h2 = self.l2_embed(chainer.Variable(np.asarray([x2]))) 52 c2, y2 = F.lstm(self.state['c1'], F.dropout(self.l2_x(h2), ratio=dropout_ratio, train=train) + self.l2_h(self.state['y2'])) 53 h3 = self.l3_embed(chainer.Variable(np.asarray([x3]))) 54 c3, y3 = F.lstm(self.state['c2'], F.dropout(self.l3_x(h3), ratio=dropout_ratio, train=train) + self.l3_h(self.state['y3'])) 55 h4 = self.l4_embed(chainer.Variable(np.asarray([x4]))) 56 c4, y4 = F.lstm(self.state['c3'], F.dropout(self.l4_x(h4), ratio=dropout_ratio, train=train) + self.l4_h(self.state['y4'])) 57 h5 = self.l5_embed(chainer.Variable(np.asarray([x5]))) 58 c5, y5 = F.lstm(self.state['c4'], F.dropout(self.l5_x(h5), ratio=dropout_ratio, train=train) + self.l5_h(self.state['y5'])) 59 self.state = {'c1': c1, 'y1': y1, 'h1': h1, 'c2': c2, 'y2': y2, 'h2': h2, 'c3': c3, 'y3': y3, 'h3': h3, 'c4': c4, 'y4': y4, 'h4': h4, 'c5': c5, 'y5': y5, 'h5': h5} 60 y = self.l_umembed(y5) 61 #print('y:',vars(y)) 62 #print('ans:', np.asarray([t]), 'pred:', np.argmax(y.data)) 63 if train: 64 return F.softmax_cross_entropy(y, np.asarray([t])) 65 else: 66 return F.softmax(y5), y5.data 67 68 def initialize_state(self, n_units, batchsize=1, train=True): 69 for name in ('c1', 'y1', 'h1', 'c2', 'y2', 'h2', 'c3', 'y3', 'h3', 'c4', 'y4', 'h4', 'c5', 'y5', 'h5'): 70 self.state[name] = chainer.Variable(np.zeros((batchsize, n_units), dtype=np.float32), volatile=not train) 71 72~~~~~~ 73def main(): 74 p = 5 # 文字列長 75 w = 2 # 前後の単語の数 76 epoch = 2 #繰り返し回数 77 total_loss = 0 # 誤差関数の値を入れる変数 78 79 # 引数の処理 80 parser = argparse.ArgumentParser() 81 parser.add_argument('--gpu', '-g', default=-1, type=int, 82 help='GPU ID (negative value indicates CPU)') 83 args = parser.parse_args() 84 85 86 # 訓練データ、評価データ、テストデータの読み込み 87 os.chdir('/Users/suguruoki/practice/chainer/examples/ptb') 88 train_data = load_data('last4.dat') 89 pickle.dump(vocab, open('vocab.bin', 'wb')) 90 pickle.dump(inv_vocab, open('inv_vocab.bin', 'wb')) 91 92 n_vocab = len(vocab) 93 #print(train_data[0:1000]) 94 95 # モデルの準備 96 # 入力は単語数、中間層はmain関数冒頭で定義 97 lstm = LSTM(n_vocab, n_units) 98 lstm.initialize_state(n_units) 99 model = L.Classifier(lstm) 100 model.compute_accuracy = False 101 102 if args.gpu >= 0: 103 xp = cupy 104 else: np 105 106 if args.gpu >= 0: 107 cupy.get_device(args.gpu).use() 108 model.to_gpu() 109 110 # optimizerの設定 111 optimizer = optimizers.Adam() 112 optimizer.setup(model) 113 seq = [] 114 seq_append = seq.append 115 116 # 5単語毎に分ける 117 for i in range(epoch): # 同じデータをepoch数繰り返す 118 print('Epoch =', i+1, file=sys.stderr) 119 length = len(train_data) 120 121 # 単語を順番に走査 122 for t in range(length): 123 124 print("{}/{}".format(t, length)) 125 # 文頭、文末を考慮する 126 127 for k in range(t-w, t+w+1): 128 if k >= 0: 129 if k == t: 130 seq.append(vocab['<$>']) 131 elif k > len(train_data)-1: 132 seq.append(vocab['<s>']) 133 else: 134 seq.append(train_data[k]) 135 else: 136 seq.append(vocab['<s>']) 137 seq.append(train_data[t]) 138 # print('t =', t,', seq :', seq) 139 tmp = np.array(seq, dtype='i') 140 seq = [] # seq: 周辺単語のリスト 141 142 loss = lstm.forward(tmp[0], tmp[1], tmp[2], tmp[3], tmp[4], tmp[5]) 143 tmp = [] 144 145 # 出力する時はlossを記憶 146 if i%epoch==0: 147 total_loss += loss.data 148 # 最適化の実行 149 model.cleargrads() 150 loss.backward() 151 optimizer.update() 152 # 学習結果を1epochごとにファイルに保存する 153 # model.to_cpu() 154 pickle.dump(model, open('LSTMmodel.pkl', 'wb')) 155 pickle.dump(lstm, open('LSTMlstm.pkl', 'wb')) 156 157 158~~~~省略 159
↑の全ソースはこちら
予測を行うソースコード(エラーが出ているのはこちらのファイルです)
lstm_prediction.py
1 2~~~~省略 3 4n_units = 512 # 隠れ層のユニット数 5 6# LSTMのネットワーク定義 7class LSTM(chainer.Chain): 8 state = {} 9 10 def __init__(self, n_vocab, n_units): 11 #print(n_vocab, n_units) 12 super(LSTM, self).__init__( 13 l1_embed = L.EmbedID(n_vocab, n_units), 14 l1_x = L.Linear(n_units, 4 * n_units), 15 l1_h = L.Linear(n_units, 4 * n_units), 16 l2_embed = L.EmbedID(n_vocab, n_units), 17 l2_x = L.Linear(n_units, 4 * n_units), 18 l2_h = L.Linear(n_units, 4 * n_units), 19 l3_embed = L.EmbedID(n_vocab, n_units), 20 l3_x = L.Linear(n_units, 4 * n_units), 21 l3_h = L.Linear(n_units, 4 * n_units), 22 l4_embed = L.EmbedID(n_vocab, n_units), 23 l4_x = L.Linear(n_units, 4 * n_units), 24 l4_h = L.Linear(n_units, 4 * n_units), 25 l5_embed = L.EmbedID(n_vocab, n_units), 26 l5_x = L.Linear(n_units, 4 * n_units), 27 l5_h = L.Linear(n_units, 4 * n_units), 28 l_umembed = L.Linear(n_units, n_vocab) 29 ) 30 31 32~~~~省略 33 34 def initialize_state(self, n_units, batchsize=1, train=True): 35 for name in ('c1', 'y1', 'h1', 'c2', 'y2', 'h2', 'c3', 'y3', 'h3', 'c4', 'y4', 'h4', 'c5', 'y5', 'h5'): 36 self.state[name] = chainer.Variable(np.zeros((batchsize, n_units), dtype=np.float32), volatile=not train) 37 38 39~~~~省略 40 41def main(): 42 ''' main関数 ''' 43 p = 5 # 文字列長 44 w = 2 # 前後の単語の数 45 total_loss = 0 # 誤差関数の値を入れる変数 46 vocab = {} 47 n_vocab = len(vocab) 48 inv_vocab={} #逆引き辞書 49 50 # 引数の処理 51 parser = argparse.ArgumentParser() 52 parser.add_argument('--gpu', '-g', default=-1, type=int, 53 help='GPU ID (negative value indicates CPU)') 54 args = parser.parse_args() 55 # cuda環境では以下のようにすればよい 56 xp = cuda.cupy if args.gpu >= 0 else np 57 if args.gpu >= 0: 58 cuda.get_device(args.gpu).use() 59 model.to_gpu() 60 61 with open('vocab.bin', 'rb') as fv: 62 vocab = pickle.load(fv) 63 with open('inv_vocab.bin', 'rb') as fi: 64 inv_vocab = pickle.load(fi) 65 # 訓練データ、評価データ、テストデータの読み込み 66 os.chdir('/Users/suguruoki/practice/chainer/examples/ptb/') 67 test_data = load_data('117-test-data.dat', vocab, inv_vocab) 68 69 n_vocab = len(vocab) 70 71 # モデルの準備 72 # 入力は単語数、中間層はmain関数冒頭で定義 73 lstm = LSTM(n_vocab , n_units) 74 lstm.initialize_state(n_units) 75 model = L.Classifier(lstm) 76 model.compute_accuracy = False 77 with open('LSTMlstm.pkl', 'rb') as fl: 78 lstm = pickle.load(fl) # =>エラーが出ているのはこの行です。 79 with open('LSTMmodel.pkl', 'rb') as fm: 80 model = pickle.load(fm) 81 82~~~~省略 83 84
↑の全ソースコードはこちら
###補足情報(言語/FW/ツール等のバージョンなど)
- Python => Python 3.5.2 :: Anaconda custom (x86_64)
- chainer => 1.9.0
また何か必要な情報があれば、教えていただけると幸いです。
よろしくお願いいたします。
回答1件
あなたの回答
tips
プレビュー
バッドをするには、ログインかつ
こちらの条件を満たす必要があります。
2017/07/18 09:21
2017/07/18 09:40
2017/07/18 09:51
2017/07/18 10:10 編集
2017/07/20 10:27
2017/07/21 07:07
2017/07/24 01:21