質問編集履歴
2
解決
test
CHANGED
File without changes
|
test
CHANGED
@@ -1,743 +1 @@
|
|
1
1
|
ChainerでRNNを使って自動文章生成したいのですが、TypeError: Can't broadcastという調べてもよくわからないエラーがでます。
|
2
|
-
|
3
|
-
|
4
|
-
|
5
|
-
ソースは「Chainerで作るコンテンツ自動生成AIプログラミング入門」という本のコピペです。
|
6
|
-
|
7
|
-
|
8
|
-
|
9
|
-
環境はgoogle colabです。
|
10
|
-
|
11
|
-
|
12
|
-
|
13
|
-
cudaとcupy、chainerのバージョンも合ってます
|
14
|
-
|
15
|
-
プログラムは2つです
|
16
|
-
|
17
|
-
以下が用意したテキストファイルをRNNで学習させるプログラムchapt07-2.pyです。
|
18
|
-
|
19
|
-
ちなみにテキストファイルは2万行の俳句です。
|
20
|
-
|
21
|
-
```
|
22
|
-
|
23
|
-
import chainer
|
24
|
-
|
25
|
-
import chainer.functions as F
|
26
|
-
|
27
|
-
import chainer.links as L
|
28
|
-
|
29
|
-
from chainer import training, datasets, iterators, optimizers
|
30
|
-
|
31
|
-
from chainer.training import extensions
|
32
|
-
|
33
|
-
import numpy as np
|
34
|
-
|
35
|
-
import codecs
|
36
|
-
|
37
|
-
|
38
|
-
|
39
|
-
batch_size = 10 # バッチサイズ10
|
40
|
-
|
41
|
-
uses_device = 0 # GPU#0を使用
|
42
|
-
|
43
|
-
|
44
|
-
|
45
|
-
# GPU使用時とCPU使用時でデータ形式が変わる
|
46
|
-
|
47
|
-
if uses_device >= 0:
|
48
|
-
|
49
|
-
import cupy as cp
|
50
|
-
|
51
|
-
else:
|
52
|
-
|
53
|
-
cp = np
|
54
|
-
|
55
|
-
|
56
|
-
|
57
|
-
# RNNの定義をするクラス
|
58
|
-
|
59
|
-
class Parses_Genarate_RNN(chainer.Chain):
|
60
|
-
|
61
|
-
|
62
|
-
|
63
|
-
def __init__(self, n_words, nodes):
|
64
|
-
|
65
|
-
super(Parses_Genarate_RNN, self).__init__()
|
66
|
-
|
67
|
-
with self.init_scope():
|
68
|
-
|
69
|
-
self.embed = L.EmbedID(n_words, n_words)
|
70
|
-
|
71
|
-
self.l1 = L.LSTM(n_words, nodes)
|
72
|
-
|
73
|
-
self.l2 = L.LSTM(nodes, nodes)
|
74
|
-
|
75
|
-
self.l3 = L.Linear(nodes, n_words)
|
76
|
-
|
77
|
-
|
78
|
-
|
79
|
-
def reset_state(self):
|
80
|
-
|
81
|
-
self.l1.reset_state()
|
82
|
-
|
83
|
-
self.l2.reset_state()
|
84
|
-
|
85
|
-
|
86
|
-
|
87
|
-
def __call__(self, x):
|
88
|
-
|
89
|
-
h0 = self.embed(x)
|
90
|
-
|
91
|
-
h1 = self.l1(h0)
|
92
|
-
|
93
|
-
h2 = self.l2(h1)
|
94
|
-
|
95
|
-
y = self.l3(h2)
|
96
|
-
|
97
|
-
return y
|
98
|
-
|
99
|
-
|
100
|
-
|
101
|
-
# カスタムUpdaterのクラス
|
102
|
-
|
103
|
-
class RNNUpdater(training.StandardUpdater):
|
104
|
-
|
105
|
-
|
106
|
-
|
107
|
-
def __init__(self, train_iter, optimizer, device):
|
108
|
-
|
109
|
-
super(RNNUpdater, self).__init__(
|
110
|
-
|
111
|
-
train_iter,
|
112
|
-
|
113
|
-
optimizer,
|
114
|
-
|
115
|
-
device=device
|
116
|
-
|
117
|
-
)
|
118
|
-
|
119
|
-
|
120
|
-
|
121
|
-
def update_core(self):
|
122
|
-
|
123
|
-
# 累積してゆく損失
|
124
|
-
|
125
|
-
loss = 0
|
126
|
-
|
127
|
-
|
128
|
-
|
129
|
-
# IteratorとOptimizerを取得
|
130
|
-
|
131
|
-
train_iter = self.get_iterator('main')
|
132
|
-
|
133
|
-
optimizer = self.get_optimizer('main')
|
134
|
-
|
135
|
-
# ニューラルネットワークを取得
|
136
|
-
|
137
|
-
model = optimizer.target
|
138
|
-
|
139
|
-
|
140
|
-
|
141
|
-
# 文を一バッチ取得
|
142
|
-
|
143
|
-
x = train_iter.__next__()
|
144
|
-
|
145
|
-
|
146
|
-
|
147
|
-
# RNNのステータスをリセットする
|
148
|
-
|
149
|
-
model.reset_state()
|
150
|
-
|
151
|
-
|
152
|
-
|
153
|
-
# 分の長さだけ繰り返しRNNに学習
|
154
|
-
|
155
|
-
for i in range(len(x[0])-1):
|
156
|
-
|
157
|
-
# バッチ処理用の配列に
|
158
|
-
|
159
|
-
batch = cp.array([s[i] for s in x], dtype=cp.int32)
|
160
|
-
|
161
|
-
# 正解データ(次の文字)の配列
|
162
|
-
|
163
|
-
t = cp.array([s[i+1] for s in x], dtype=cp.int32)
|
164
|
-
|
165
|
-
# 全部が終端文字ならそれ以上学習する必要は無い
|
166
|
-
|
167
|
-
if cp.min(batch) == 1 and cp.max(batch) == 1:
|
168
|
-
|
169
|
-
break
|
170
|
-
|
171
|
-
# 一つRNNを実行
|
172
|
-
|
173
|
-
y = model(batch)
|
174
|
-
|
175
|
-
# 結果との比較
|
176
|
-
|
177
|
-
loss += F.softmax_cross_entropy(y, t)
|
178
|
-
|
179
|
-
|
180
|
-
|
181
|
-
# 重みデータを一旦リセットする
|
182
|
-
|
183
|
-
optimizer.target.cleargrads()
|
184
|
-
|
185
|
-
# 誤差関数から逆伝播する
|
186
|
-
|
187
|
-
loss.backward()
|
188
|
-
|
189
|
-
# 新しい重みデータでアップデートする
|
190
|
-
|
191
|
-
optimizer.update()
|
192
|
-
|
193
|
-
|
194
|
-
|
195
|
-
# ファイルを読み込む
|
196
|
-
|
197
|
-
s = codecs.open('all-sentences-parses.txt', 'r', 'utf8')
|
198
|
-
|
199
|
-
|
200
|
-
|
201
|
-
# 全ての文
|
202
|
-
|
203
|
-
sentence = []
|
204
|
-
|
205
|
-
|
206
|
-
|
207
|
-
# 1行ずつ処理する
|
208
|
-
|
209
|
-
line = s.readline()
|
210
|
-
|
211
|
-
while line:
|
212
|
-
|
213
|
-
# 一つの文
|
214
|
-
|
215
|
-
one = [0] # 開始文字だけ
|
216
|
-
|
217
|
-
# 行の中の単語を数字のリストにして追加
|
218
|
-
|
219
|
-
one.extend(list(map(int,line.split(','))))
|
220
|
-
|
221
|
-
# 行が終わったところで終端文字を入れる
|
222
|
-
|
223
|
-
one.append(1)
|
224
|
-
|
225
|
-
# 新しい文を追加
|
226
|
-
|
227
|
-
sentence.append(one)
|
228
|
-
|
229
|
-
line = s.readline()
|
230
|
-
|
231
|
-
s.close()
|
232
|
-
|
233
|
-
|
234
|
-
|
235
|
-
# 単語の種類
|
236
|
-
|
237
|
-
n_word = max([max(l) for l in sentence]) + 1
|
238
|
-
|
239
|
-
|
240
|
-
|
241
|
-
# 最長の文の長さ
|
242
|
-
|
243
|
-
l_max = max([len(l) for l in sentence])
|
244
|
-
|
245
|
-
# バッチ処理の都合で全て同じ長さに揃える必要がある
|
246
|
-
|
247
|
-
for i in range(len(sentence)):
|
248
|
-
|
249
|
-
# 足りない長さは終端文字で埋める
|
250
|
-
|
251
|
-
sentence[i].extend([1]*(l_max-len(sentence[i])))
|
252
|
-
|
253
|
-
|
254
|
-
|
255
|
-
# ニューラルネットワークの作成
|
256
|
-
|
257
|
-
model = Parses_Genarate_RNN(n_word, 100)
|
258
|
-
|
259
|
-
|
260
|
-
|
261
|
-
if uses_device >= 0:
|
262
|
-
|
263
|
-
# GPUを使う
|
264
|
-
|
265
|
-
chainer.cuda.get_device_from_id(0).use()
|
266
|
-
|
267
|
-
chainer.cuda.check_cuda_available()
|
268
|
-
|
269
|
-
# GPU用データ形式に変換
|
270
|
-
|
271
|
-
model.to_gpu()
|
272
|
-
|
273
|
-
|
274
|
-
|
275
|
-
# 誤差逆伝播法アルゴリズムを選択
|
276
|
-
|
277
|
-
optimizer = optimizers.Adam()
|
278
|
-
|
279
|
-
optimizer.setup(model)
|
280
|
-
|
281
|
-
|
282
|
-
|
283
|
-
# Iteratorを作成
|
284
|
-
|
285
|
-
train_iter = iterators.SerialIterator(sentence, batch_size, shuffle=False)
|
286
|
-
|
287
|
-
|
288
|
-
|
289
|
-
# デバイスを選択してTrainerを作成する
|
290
|
-
|
291
|
-
updater = RNNUpdater(train_iter, optimizer, device=uses_device)
|
292
|
-
|
293
|
-
trainer = training.Trainer(updater, (100, 'epoch'), out="result")
|
294
|
-
|
295
|
-
# 学習の進展を表示するようにする
|
296
|
-
|
297
|
-
trainer.extend(extensions.ProgressBar(update_interval=1))
|
298
|
-
|
299
|
-
|
300
|
-
|
301
|
-
# 機械学習を実行する
|
302
|
-
|
303
|
-
trainer.run()
|
304
|
-
|
305
|
-
|
306
|
-
|
307
|
-
# 学習結果を保存する
|
308
|
-
|
309
|
-
chainer.serializers.save_hdf5( 'chapt07.hdf5', model )
|
310
|
-
|
311
|
-
```
|
312
|
-
|
313
|
-
以下が先ほど作成した学習結果を元に文章自動生成するプログラムchapt07-4.pyです。
|
314
|
-
|
315
|
-
word2vecのモデルを使用ますが、これが壊れているという事は考えにくいです
|
316
|
-
|
317
|
-
```
|
318
|
-
|
319
|
-
import torch
|
320
|
-
|
321
|
-
import torchvision
|
322
|
-
|
323
|
-
import torchvision.transforms as transforms
|
324
|
-
|
325
|
-
from torch import nn, optim
|
326
|
-
|
327
|
-
import torch.nn.functional as F
|
328
|
-
|
329
|
-
from torch.utils.data import Dataset, DataLoader, TensorDataset
|
330
|
-
|
331
|
-
import numpy as np
|
332
|
-
|
333
|
-
import sys
|
334
|
-
|
335
|
-
import codecs
|
336
|
-
|
337
|
-
from gensim.models import word2vec
|
338
|
-
|
339
|
-
|
340
|
-
|
341
|
-
trainset = torchvision.datasets.MNIST(root='./data',
|
342
|
-
|
343
|
-
train=True,
|
344
|
-
|
345
|
-
download=True,
|
346
|
-
|
347
|
-
transform=transforms.ToTensor())
|
348
|
-
|
349
|
-
trainloader = torch.utils.data.DataLoader(trainset,
|
350
|
-
|
351
|
-
batch_size=batch_size,
|
352
|
-
|
353
|
-
shuffle=True)
|
354
|
-
|
355
|
-
|
356
|
-
|
357
|
-
testset = torchvision.datasets.MNIST(root='./data',
|
358
|
-
|
359
|
-
train=False,
|
360
|
-
|
361
|
-
download=True,
|
362
|
-
|
363
|
-
transform=transforms.ToTensor())
|
364
|
-
|
365
|
-
testloader = torch.utils.data.DataLoader(testset,
|
366
|
-
|
367
|
-
batch_size=batch_size,
|
368
|
-
|
369
|
-
shuffle=False)
|
370
|
-
|
371
|
-
|
372
|
-
|
373
|
-
|
374
|
-
|
375
|
-
# GPU使用時とCPU使用時でデータ形式が変わる
|
376
|
-
|
377
|
-
if uses_device >= 0:
|
378
|
-
|
379
|
-
import cupy as cp
|
380
|
-
|
381
|
-
import chainer.cuda
|
382
|
-
|
383
|
-
else:
|
384
|
-
|
385
|
-
cp = np
|
386
|
-
|
387
|
-
|
388
|
-
|
389
|
-
sys.stdout = codecs.getwriter('utf_8')(sys.stdout)
|
390
|
-
|
391
|
-
|
392
|
-
|
393
|
-
# RNNの定義をするクラス
|
394
|
-
|
395
|
-
class Parses_Genarate_RNN(nn.Module):
|
396
|
-
|
397
|
-
|
398
|
-
|
399
|
-
def __init__(self, n_words, nodes):
|
400
|
-
|
401
|
-
super(Parses_Genarate_RNN, self).__init__()
|
402
|
-
|
403
|
-
with self.init_scope():
|
404
|
-
|
405
|
-
self.embed = L.EmbedID(n_words, n_words)
|
406
|
-
|
407
|
-
self.l1 = L.LSTM(n_words, nodes)
|
408
|
-
|
409
|
-
self.l2 = L.LSTM(nodes, nodes)
|
410
|
-
|
411
|
-
self.l3 = L.Linear(nodes, n_words)
|
412
|
-
|
413
|
-
|
414
|
-
|
415
|
-
def reset_state(self):
|
416
|
-
|
417
|
-
self.l1.reset_state()
|
418
|
-
|
419
|
-
self.l2.reset_state()
|
420
|
-
|
421
|
-
|
422
|
-
|
423
|
-
def __call__(self, x):
|
424
|
-
|
425
|
-
h0 = self.embed(x)
|
426
|
-
|
427
|
-
h1 = self.l1(h0)
|
428
|
-
|
429
|
-
h2 = self.l2(h1)
|
430
|
-
|
431
|
-
y = self.l3(h2)
|
432
|
-
|
433
|
-
return y
|
434
|
-
|
435
|
-
|
436
|
-
|
437
|
-
# ファイルを読み込む
|
438
|
-
|
439
|
-
w = codecs.open('all-words-parses.txt', 'r', 'utf8')
|
440
|
-
|
441
|
-
|
442
|
-
|
443
|
-
# 単語の一覧
|
444
|
-
|
445
|
-
words_parse = {}
|
446
|
-
|
447
|
-
|
448
|
-
|
449
|
-
# 1行ずつ処理する
|
450
|
-
|
451
|
-
line = w.readline()
|
452
|
-
|
453
|
-
while line:
|
454
|
-
|
455
|
-
# 行の中の単語をリストする
|
456
|
-
|
457
|
-
l = line.split(',')
|
458
|
-
|
459
|
-
if len(l) == 2:
|
460
|
-
|
461
|
-
r = int(l[0].strip())
|
462
|
-
|
463
|
-
if r in words_parse:
|
464
|
-
|
465
|
-
words_parse[r].append(l[1].strip())
|
466
|
-
|
467
|
-
else:
|
468
|
-
|
469
|
-
words_parse[r] = [l[1].strip()]
|
470
|
-
|
471
|
-
line = w.readline()
|
472
|
-
|
473
|
-
w.close()
|
474
|
-
|
475
|
-
|
476
|
-
|
477
|
-
# ニューラルネットワークの作成
|
478
|
-
|
479
|
-
model = Parses_Genarate_RNN(max(words_parse.keys())+1, 20)
|
480
|
-
|
481
|
-
|
482
|
-
|
483
|
-
# 学習結果を読み込む
|
484
|
-
|
485
|
-
chainer.serializers.load_hdf5( 'chapt07.hdf5', model )
|
486
|
-
|
487
|
-
|
488
|
-
|
489
|
-
if uses_device >= 0:
|
490
|
-
|
491
|
-
# GPUを使う
|
492
|
-
|
493
|
-
chainer.cuda.get_device_from_id(0).use()
|
494
|
-
|
495
|
-
chainer.cuda.check_cuda_available()
|
496
|
-
|
497
|
-
# GPU用データ形式に変換
|
498
|
-
|
499
|
-
model.to_gpu()
|
500
|
-
|
501
|
-
|
502
|
-
|
503
|
-
# 木探索で生成する最大の深さ
|
504
|
-
|
505
|
-
words_max = 50
|
506
|
-
|
507
|
-
# RNNの実行結果から検索する単語の数
|
508
|
-
|
509
|
-
beam_w = 3
|
510
|
-
|
511
|
-
# 生成した文のリスト
|
512
|
-
|
513
|
-
parses = []
|
514
|
-
|
515
|
-
# 木探索のスタック
|
516
|
-
|
517
|
-
model_history = [model]
|
518
|
-
|
519
|
-
# 現在生成中の文
|
520
|
-
|
521
|
-
cur_parses = [0] # 開始文字
|
522
|
-
|
523
|
-
# 現在生成中の文のスコア
|
524
|
-
|
525
|
-
cur_score = []
|
526
|
-
|
527
|
-
# 最大のスコア
|
528
|
-
|
529
|
-
max_score = 0
|
530
|
-
|
531
|
-
|
532
|
-
|
533
|
-
# 再帰関数の木探索
|
534
|
-
|
535
|
-
def Tree_Traverse():
|
536
|
-
|
537
|
-
global max_score
|
538
|
-
|
539
|
-
# 現在の品詞を取得する
|
540
|
-
|
541
|
-
cur_parse = cur_parses[-1]
|
542
|
-
|
543
|
-
# 文のスコア
|
544
|
-
|
545
|
-
score = np.prod(cur_score)
|
546
|
-
|
547
|
-
# 現在の文の長さ
|
548
|
-
|
549
|
-
deep = len(cur_parses)
|
550
|
-
|
551
|
-
# 枝刈り - 単語数が5以上で最大スコアの6割以下なら、終わる
|
552
|
-
|
553
|
-
if max_score > 0 and deep > 5 and max_score * 0.6 > score:
|
554
|
-
|
555
|
-
return
|
556
|
-
|
557
|
-
# 終了文字か、最大の文の長さ以上なら、品詞を追加して終わる
|
558
|
-
|
559
|
-
if cur_parse == 1 or deep > words_max:
|
560
|
-
|
561
|
-
# 文のデータをコピー
|
562
|
-
|
563
|
-
data = np.array(cur_parses)
|
564
|
-
|
565
|
-
# 文を追加
|
566
|
-
|
567
|
-
parses.append((score, data))
|
568
|
-
|
569
|
-
# 最大スコアを更新
|
570
|
-
|
571
|
-
if max_score < score:
|
572
|
-
|
573
|
-
max_score = score
|
574
|
-
|
575
|
-
return
|
576
|
-
|
577
|
-
# 現在のニューラルネットワークのステータスをコピーする
|
578
|
-
|
579
|
-
cur_model = model_history[-1].copy()
|
580
|
-
|
581
|
-
# 入力値を作る
|
582
|
-
|
583
|
-
x = cp.array([cur_parse], dtype=cp.int32)
|
584
|
-
|
585
|
-
# ニューラルネットワークに入力する
|
586
|
-
|
587
|
-
y = cur_model(x)
|
588
|
-
|
589
|
-
# 実行結果を正規化する
|
590
|
-
|
591
|
-
z = F.softmax(y)
|
592
|
-
|
593
|
-
# 結果のデータを取得
|
594
|
-
|
595
|
-
result = z.data[0]
|
596
|
-
|
597
|
-
if uses_device >= 0:
|
598
|
-
|
599
|
-
result = chainer.cuda.to_cpu(result)
|
600
|
-
|
601
|
-
# 結果を確立順に並べ替える
|
602
|
-
|
603
|
-
p = np.argsort(result)[::-1]
|
604
|
-
|
605
|
-
# 現在のニューラルネットワークのステータスを保存する
|
606
|
-
|
607
|
-
model_history.append(cur_model)
|
608
|
-
|
609
|
-
# 結果から上位のものを次の枝に回す
|
610
|
-
|
611
|
-
for i in range(beam_w):
|
612
|
-
|
613
|
-
# 現在生成中の文に一文字追加する
|
614
|
-
|
615
|
-
cur_parses.append(p[i])
|
616
|
-
|
617
|
-
# 現在生成中の文のスコアに一つ追加する
|
618
|
-
|
619
|
-
cur_score.append(result[p[i]])
|
620
|
-
|
621
|
-
# 再帰呼び出し
|
622
|
-
|
623
|
-
Tree_Traverse()
|
624
|
-
|
625
|
-
# 現在生成中の文を一つ戻す
|
626
|
-
|
627
|
-
cur_parses.pop()
|
628
|
-
|
629
|
-
# 現在生成中の文のスコアを一つ戻す
|
630
|
-
|
631
|
-
cur_score.pop()
|
632
|
-
|
633
|
-
# ニューラルネットワークのステータスを一つ戻す
|
634
|
-
|
635
|
-
model_history.pop()
|
636
|
-
|
637
|
-
|
638
|
-
|
639
|
-
# 木検索して文章を生成する
|
640
|
-
|
641
|
-
Tree_Traverse()
|
642
|
-
|
643
|
-
|
644
|
-
|
645
|
-
# Word2Vecのモデルを読み込む
|
646
|
-
|
647
|
-
word_vec = word2vec.Word2Vec.load('word2vec.gensim.model')
|
648
|
-
|
649
|
-
|
650
|
-
|
651
|
-
# 文章のターゲット
|
652
|
-
|
653
|
-
target_str = ['元日']
|
654
|
-
|
655
|
-
#target_str = ['神']
|
656
|
-
|
657
|
-
#target_str = ['キリスト']
|
658
|
-
|
659
|
-
#target_str = ['父','子','聖霊']
|
660
|
-
|
661
|
-
#target_str = ['不思議','の','国','の','アリス']
|
662
|
-
|
663
|
-
#target_str = ['三月','うさぎ','の','お茶','会']
|
664
|
-
|
665
|
-
#target_str = ['女王']
|
666
|
-
|
667
|
-
|
668
|
-
|
669
|
-
# 指定した品詞の単語を文章がターゲットに近づくように返す
|
670
|
-
|
671
|
-
def similarity_word( parse, history ):
|
672
|
-
|
673
|
-
scores = []
|
674
|
-
|
675
|
-
# 品詞から候補をリスト
|
676
|
-
|
677
|
-
for i in range(len(words_parse[parse])):
|
678
|
-
|
679
|
-
w = words_parse[parse][i]
|
680
|
-
|
681
|
-
if w in word_vec:
|
682
|
-
|
683
|
-
# 候補のベクトルを履歴ベクトルに足す
|
684
|
-
|
685
|
-
t = history[:]
|
686
|
-
|
687
|
-
t.append(w)
|
688
|
-
|
689
|
-
# ターゲットとの距離を計算
|
690
|
-
|
691
|
-
sim = word_vec.n_similarity(target_str, t)
|
692
|
-
|
693
|
-
scores.append((sim, w))
|
694
|
-
|
695
|
-
# 結果をスコア順に並べ替える
|
696
|
-
|
697
|
-
result = sorted(scores, key=lambda x: x[0])[::-1]
|
698
|
-
|
699
|
-
return result[0]
|
700
|
-
|
701
|
-
|
702
|
-
|
703
|
-
|
704
|
-
|
705
|
-
# スコアの高いものから順に表示する
|
706
|
-
|
707
|
-
result_set = sorted(parses, key=lambda x: x[0])[::-1]
|
708
|
-
|
709
|
-
# 10個または全部の少ない方の数だけ表示
|
710
|
-
|
711
|
-
for i in range(min([10,len(result_set)])):
|
712
|
-
|
713
|
-
# 結果を取得
|
714
|
-
|
715
|
-
s, l = result_set[i]
|
716
|
-
|
717
|
-
# これまで登場した単語
|
718
|
-
|
719
|
-
history = []
|
720
|
-
|
721
|
-
# 開始文字と終端文字を除いてループ
|
722
|
-
|
723
|
-
for j in range(1,len(l)-1):
|
724
|
-
|
725
|
-
score, cur_word = similarity_word(l[j], history)
|
726
|
-
|
727
|
-
history.append(cur_word)
|
728
|
-
|
729
|
-
sys.stdout.buffer.write(cur_word.encode('utf-8'))
|
730
|
-
|
731
|
-
|
732
|
-
|
733
|
-
sys.stdout.buffer.write("\n".encode('utf-8'))
|
734
|
-
|
735
|
-
sys.stdout.buffer.flush()
|
736
|
-
|
737
|
-
|
738
|
-
|
739
|
-
|
740
|
-
|
741
|
-
|
742
|
-
|
743
|
-
```
|
1
ソース
test
CHANGED
File without changes
|
test
CHANGED
@@ -11,3 +11,733 @@
|
|
11
11
|
|
12
12
|
|
13
13
|
cudaとcupy、chainerのバージョンも合ってます
|
14
|
+
|
15
|
+
プログラムは2つです
|
16
|
+
|
17
|
+
以下が用意したテキストファイルをRNNで学習させるプログラムchapt07-2.pyです。
|
18
|
+
|
19
|
+
ちなみにテキストファイルは2万行の俳句です。
|
20
|
+
|
21
|
+
```
|
22
|
+
|
23
|
+
import chainer
|
24
|
+
|
25
|
+
import chainer.functions as F
|
26
|
+
|
27
|
+
import chainer.links as L
|
28
|
+
|
29
|
+
from chainer import training, datasets, iterators, optimizers
|
30
|
+
|
31
|
+
from chainer.training import extensions
|
32
|
+
|
33
|
+
import numpy as np
|
34
|
+
|
35
|
+
import codecs
|
36
|
+
|
37
|
+
|
38
|
+
|
39
|
+
batch_size = 10 # バッチサイズ10
|
40
|
+
|
41
|
+
uses_device = 0 # GPU#0を使用
|
42
|
+
|
43
|
+
|
44
|
+
|
45
|
+
# GPU使用時とCPU使用時でデータ形式が変わる
|
46
|
+
|
47
|
+
if uses_device >= 0:
|
48
|
+
|
49
|
+
import cupy as cp
|
50
|
+
|
51
|
+
else:
|
52
|
+
|
53
|
+
cp = np
|
54
|
+
|
55
|
+
|
56
|
+
|
57
|
+
# RNNの定義をするクラス
|
58
|
+
|
59
|
+
class Parses_Genarate_RNN(chainer.Chain):
|
60
|
+
|
61
|
+
|
62
|
+
|
63
|
+
def __init__(self, n_words, nodes):
|
64
|
+
|
65
|
+
super(Parses_Genarate_RNN, self).__init__()
|
66
|
+
|
67
|
+
with self.init_scope():
|
68
|
+
|
69
|
+
self.embed = L.EmbedID(n_words, n_words)
|
70
|
+
|
71
|
+
self.l1 = L.LSTM(n_words, nodes)
|
72
|
+
|
73
|
+
self.l2 = L.LSTM(nodes, nodes)
|
74
|
+
|
75
|
+
self.l3 = L.Linear(nodes, n_words)
|
76
|
+
|
77
|
+
|
78
|
+
|
79
|
+
def reset_state(self):
|
80
|
+
|
81
|
+
self.l1.reset_state()
|
82
|
+
|
83
|
+
self.l2.reset_state()
|
84
|
+
|
85
|
+
|
86
|
+
|
87
|
+
def __call__(self, x):
|
88
|
+
|
89
|
+
h0 = self.embed(x)
|
90
|
+
|
91
|
+
h1 = self.l1(h0)
|
92
|
+
|
93
|
+
h2 = self.l2(h1)
|
94
|
+
|
95
|
+
y = self.l3(h2)
|
96
|
+
|
97
|
+
return y
|
98
|
+
|
99
|
+
|
100
|
+
|
101
|
+
# カスタムUpdaterのクラス
|
102
|
+
|
103
|
+
class RNNUpdater(training.StandardUpdater):
|
104
|
+
|
105
|
+
|
106
|
+
|
107
|
+
def __init__(self, train_iter, optimizer, device):
|
108
|
+
|
109
|
+
super(RNNUpdater, self).__init__(
|
110
|
+
|
111
|
+
train_iter,
|
112
|
+
|
113
|
+
optimizer,
|
114
|
+
|
115
|
+
device=device
|
116
|
+
|
117
|
+
)
|
118
|
+
|
119
|
+
|
120
|
+
|
121
|
+
def update_core(self):
|
122
|
+
|
123
|
+
# 累積してゆく損失
|
124
|
+
|
125
|
+
loss = 0
|
126
|
+
|
127
|
+
|
128
|
+
|
129
|
+
# IteratorとOptimizerを取得
|
130
|
+
|
131
|
+
train_iter = self.get_iterator('main')
|
132
|
+
|
133
|
+
optimizer = self.get_optimizer('main')
|
134
|
+
|
135
|
+
# ニューラルネットワークを取得
|
136
|
+
|
137
|
+
model = optimizer.target
|
138
|
+
|
139
|
+
|
140
|
+
|
141
|
+
# 文を一バッチ取得
|
142
|
+
|
143
|
+
x = train_iter.__next__()
|
144
|
+
|
145
|
+
|
146
|
+
|
147
|
+
# RNNのステータスをリセットする
|
148
|
+
|
149
|
+
model.reset_state()
|
150
|
+
|
151
|
+
|
152
|
+
|
153
|
+
# 分の長さだけ繰り返しRNNに学習
|
154
|
+
|
155
|
+
for i in range(len(x[0])-1):
|
156
|
+
|
157
|
+
# バッチ処理用の配列に
|
158
|
+
|
159
|
+
batch = cp.array([s[i] for s in x], dtype=cp.int32)
|
160
|
+
|
161
|
+
# 正解データ(次の文字)の配列
|
162
|
+
|
163
|
+
t = cp.array([s[i+1] for s in x], dtype=cp.int32)
|
164
|
+
|
165
|
+
# 全部が終端文字ならそれ以上学習する必要は無い
|
166
|
+
|
167
|
+
if cp.min(batch) == 1 and cp.max(batch) == 1:
|
168
|
+
|
169
|
+
break
|
170
|
+
|
171
|
+
# 一つRNNを実行
|
172
|
+
|
173
|
+
y = model(batch)
|
174
|
+
|
175
|
+
# 結果との比較
|
176
|
+
|
177
|
+
loss += F.softmax_cross_entropy(y, t)
|
178
|
+
|
179
|
+
|
180
|
+
|
181
|
+
# 重みデータを一旦リセットする
|
182
|
+
|
183
|
+
optimizer.target.cleargrads()
|
184
|
+
|
185
|
+
# 誤差関数から逆伝播する
|
186
|
+
|
187
|
+
loss.backward()
|
188
|
+
|
189
|
+
# 新しい重みデータでアップデートする
|
190
|
+
|
191
|
+
optimizer.update()
|
192
|
+
|
193
|
+
|
194
|
+
|
195
|
+
# ファイルを読み込む
|
196
|
+
|
197
|
+
s = codecs.open('all-sentences-parses.txt', 'r', 'utf8')
|
198
|
+
|
199
|
+
|
200
|
+
|
201
|
+
# 全ての文
|
202
|
+
|
203
|
+
sentence = []
|
204
|
+
|
205
|
+
|
206
|
+
|
207
|
+
# 1行ずつ処理する
|
208
|
+
|
209
|
+
line = s.readline()
|
210
|
+
|
211
|
+
while line:
|
212
|
+
|
213
|
+
# 一つの文
|
214
|
+
|
215
|
+
one = [0] # 開始文字だけ
|
216
|
+
|
217
|
+
# 行の中の単語を数字のリストにして追加
|
218
|
+
|
219
|
+
one.extend(list(map(int,line.split(','))))
|
220
|
+
|
221
|
+
# 行が終わったところで終端文字を入れる
|
222
|
+
|
223
|
+
one.append(1)
|
224
|
+
|
225
|
+
# 新しい文を追加
|
226
|
+
|
227
|
+
sentence.append(one)
|
228
|
+
|
229
|
+
line = s.readline()
|
230
|
+
|
231
|
+
s.close()
|
232
|
+
|
233
|
+
|
234
|
+
|
235
|
+
# 単語の種類
|
236
|
+
|
237
|
+
n_word = max([max(l) for l in sentence]) + 1
|
238
|
+
|
239
|
+
|
240
|
+
|
241
|
+
# 最長の文の長さ
|
242
|
+
|
243
|
+
l_max = max([len(l) for l in sentence])
|
244
|
+
|
245
|
+
# バッチ処理の都合で全て同じ長さに揃える必要がある
|
246
|
+
|
247
|
+
for i in range(len(sentence)):
|
248
|
+
|
249
|
+
# 足りない長さは終端文字で埋める
|
250
|
+
|
251
|
+
sentence[i].extend([1]*(l_max-len(sentence[i])))
|
252
|
+
|
253
|
+
|
254
|
+
|
255
|
+
# ニューラルネットワークの作成
|
256
|
+
|
257
|
+
model = Parses_Genarate_RNN(n_word, 100)
|
258
|
+
|
259
|
+
|
260
|
+
|
261
|
+
if uses_device >= 0:
|
262
|
+
|
263
|
+
# GPUを使う
|
264
|
+
|
265
|
+
chainer.cuda.get_device_from_id(0).use()
|
266
|
+
|
267
|
+
chainer.cuda.check_cuda_available()
|
268
|
+
|
269
|
+
# GPU用データ形式に変換
|
270
|
+
|
271
|
+
model.to_gpu()
|
272
|
+
|
273
|
+
|
274
|
+
|
275
|
+
# 誤差逆伝播法アルゴリズムを選択
|
276
|
+
|
277
|
+
optimizer = optimizers.Adam()
|
278
|
+
|
279
|
+
optimizer.setup(model)
|
280
|
+
|
281
|
+
|
282
|
+
|
283
|
+
# Iteratorを作成
|
284
|
+
|
285
|
+
train_iter = iterators.SerialIterator(sentence, batch_size, shuffle=False)
|
286
|
+
|
287
|
+
|
288
|
+
|
289
|
+
# デバイスを選択してTrainerを作成する
|
290
|
+
|
291
|
+
updater = RNNUpdater(train_iter, optimizer, device=uses_device)
|
292
|
+
|
293
|
+
trainer = training.Trainer(updater, (100, 'epoch'), out="result")
|
294
|
+
|
295
|
+
# 学習の進展を表示するようにする
|
296
|
+
|
297
|
+
trainer.extend(extensions.ProgressBar(update_interval=1))
|
298
|
+
|
299
|
+
|
300
|
+
|
301
|
+
# 機械学習を実行する
|
302
|
+
|
303
|
+
trainer.run()
|
304
|
+
|
305
|
+
|
306
|
+
|
307
|
+
# 学習結果を保存する
|
308
|
+
|
309
|
+
chainer.serializers.save_hdf5( 'chapt07.hdf5', model )
|
310
|
+
|
311
|
+
```
|
312
|
+
|
313
|
+
以下が先ほど作成した学習結果を元に文章自動生成するプログラムchapt07-4.pyです。
|
314
|
+
|
315
|
+
word2vecのモデルを使用ますが、これが壊れているという事は考えにくいです
|
316
|
+
|
317
|
+
```
|
318
|
+
|
319
|
+
import torch
|
320
|
+
|
321
|
+
import torchvision
|
322
|
+
|
323
|
+
import torchvision.transforms as transforms
|
324
|
+
|
325
|
+
from torch import nn, optim
|
326
|
+
|
327
|
+
import torch.nn.functional as F
|
328
|
+
|
329
|
+
from torch.utils.data import Dataset, DataLoader, TensorDataset
|
330
|
+
|
331
|
+
import numpy as np
|
332
|
+
|
333
|
+
import sys
|
334
|
+
|
335
|
+
import codecs
|
336
|
+
|
337
|
+
from gensim.models import word2vec
|
338
|
+
|
339
|
+
|
340
|
+
|
341
|
+
trainset = torchvision.datasets.MNIST(root='./data',
|
342
|
+
|
343
|
+
train=True,
|
344
|
+
|
345
|
+
download=True,
|
346
|
+
|
347
|
+
transform=transforms.ToTensor())
|
348
|
+
|
349
|
+
trainloader = torch.utils.data.DataLoader(trainset,
|
350
|
+
|
351
|
+
batch_size=batch_size,
|
352
|
+
|
353
|
+
shuffle=True)
|
354
|
+
|
355
|
+
|
356
|
+
|
357
|
+
testset = torchvision.datasets.MNIST(root='./data',
|
358
|
+
|
359
|
+
train=False,
|
360
|
+
|
361
|
+
download=True,
|
362
|
+
|
363
|
+
transform=transforms.ToTensor())
|
364
|
+
|
365
|
+
testloader = torch.utils.data.DataLoader(testset,
|
366
|
+
|
367
|
+
batch_size=batch_size,
|
368
|
+
|
369
|
+
shuffle=False)
|
370
|
+
|
371
|
+
|
372
|
+
|
373
|
+
|
374
|
+
|
375
|
+
# GPU使用時とCPU使用時でデータ形式が変わる
|
376
|
+
|
377
|
+
if uses_device >= 0:
|
378
|
+
|
379
|
+
import cupy as cp
|
380
|
+
|
381
|
+
import chainer.cuda
|
382
|
+
|
383
|
+
else:
|
384
|
+
|
385
|
+
cp = np
|
386
|
+
|
387
|
+
|
388
|
+
|
389
|
+
sys.stdout = codecs.getwriter('utf_8')(sys.stdout)
|
390
|
+
|
391
|
+
|
392
|
+
|
393
|
+
# RNNの定義をするクラス
|
394
|
+
|
395
|
+
class Parses_Genarate_RNN(nn.Module):
|
396
|
+
|
397
|
+
|
398
|
+
|
399
|
+
def __init__(self, n_words, nodes):
|
400
|
+
|
401
|
+
super(Parses_Genarate_RNN, self).__init__()
|
402
|
+
|
403
|
+
with self.init_scope():
|
404
|
+
|
405
|
+
self.embed = L.EmbedID(n_words, n_words)
|
406
|
+
|
407
|
+
self.l1 = L.LSTM(n_words, nodes)
|
408
|
+
|
409
|
+
self.l2 = L.LSTM(nodes, nodes)
|
410
|
+
|
411
|
+
self.l3 = L.Linear(nodes, n_words)
|
412
|
+
|
413
|
+
|
414
|
+
|
415
|
+
def reset_state(self):
|
416
|
+
|
417
|
+
self.l1.reset_state()
|
418
|
+
|
419
|
+
self.l2.reset_state()
|
420
|
+
|
421
|
+
|
422
|
+
|
423
|
+
def __call__(self, x):
|
424
|
+
|
425
|
+
h0 = self.embed(x)
|
426
|
+
|
427
|
+
h1 = self.l1(h0)
|
428
|
+
|
429
|
+
h2 = self.l2(h1)
|
430
|
+
|
431
|
+
y = self.l3(h2)
|
432
|
+
|
433
|
+
return y
|
434
|
+
|
435
|
+
|
436
|
+
|
437
|
+
# ファイルを読み込む
|
438
|
+
|
439
|
+
w = codecs.open('all-words-parses.txt', 'r', 'utf8')
|
440
|
+
|
441
|
+
|
442
|
+
|
443
|
+
# 単語の一覧
|
444
|
+
|
445
|
+
words_parse = {}
|
446
|
+
|
447
|
+
|
448
|
+
|
449
|
+
# 1行ずつ処理する
|
450
|
+
|
451
|
+
line = w.readline()
|
452
|
+
|
453
|
+
while line:
|
454
|
+
|
455
|
+
# 行の中の単語をリストする
|
456
|
+
|
457
|
+
l = line.split(',')
|
458
|
+
|
459
|
+
if len(l) == 2:
|
460
|
+
|
461
|
+
r = int(l[0].strip())
|
462
|
+
|
463
|
+
if r in words_parse:
|
464
|
+
|
465
|
+
words_parse[r].append(l[1].strip())
|
466
|
+
|
467
|
+
else:
|
468
|
+
|
469
|
+
words_parse[r] = [l[1].strip()]
|
470
|
+
|
471
|
+
line = w.readline()
|
472
|
+
|
473
|
+
w.close()
|
474
|
+
|
475
|
+
|
476
|
+
|
477
|
+
# ニューラルネットワークの作成
|
478
|
+
|
479
|
+
model = Parses_Genarate_RNN(max(words_parse.keys())+1, 20)
|
480
|
+
|
481
|
+
|
482
|
+
|
483
|
+
# 学習結果を読み込む
|
484
|
+
|
485
|
+
chainer.serializers.load_hdf5( 'chapt07.hdf5', model )
|
486
|
+
|
487
|
+
|
488
|
+
|
489
|
+
if uses_device >= 0:
|
490
|
+
|
491
|
+
# GPUを使う
|
492
|
+
|
493
|
+
chainer.cuda.get_device_from_id(0).use()
|
494
|
+
|
495
|
+
chainer.cuda.check_cuda_available()
|
496
|
+
|
497
|
+
# GPU用データ形式に変換
|
498
|
+
|
499
|
+
model.to_gpu()
|
500
|
+
|
501
|
+
|
502
|
+
|
503
|
+
# 木探索で生成する最大の深さ
|
504
|
+
|
505
|
+
words_max = 50
|
506
|
+
|
507
|
+
# RNNの実行結果から検索する単語の数
|
508
|
+
|
509
|
+
beam_w = 3
|
510
|
+
|
511
|
+
# 生成した文のリスト
|
512
|
+
|
513
|
+
parses = []
|
514
|
+
|
515
|
+
# 木探索のスタック
|
516
|
+
|
517
|
+
model_history = [model]
|
518
|
+
|
519
|
+
# 現在生成中の文
|
520
|
+
|
521
|
+
cur_parses = [0] # 開始文字
|
522
|
+
|
523
|
+
# 現在生成中の文のスコア
|
524
|
+
|
525
|
+
cur_score = []
|
526
|
+
|
527
|
+
# 最大のスコア
|
528
|
+
|
529
|
+
max_score = 0
|
530
|
+
|
531
|
+
|
532
|
+
|
533
|
+
# 再帰関数の木探索
|
534
|
+
|
535
|
+
def Tree_Traverse():
|
536
|
+
|
537
|
+
global max_score
|
538
|
+
|
539
|
+
# 現在の品詞を取得する
|
540
|
+
|
541
|
+
cur_parse = cur_parses[-1]
|
542
|
+
|
543
|
+
# 文のスコア
|
544
|
+
|
545
|
+
score = np.prod(cur_score)
|
546
|
+
|
547
|
+
# 現在の文の長さ
|
548
|
+
|
549
|
+
deep = len(cur_parses)
|
550
|
+
|
551
|
+
# 枝刈り - 単語数が5以上で最大スコアの6割以下なら、終わる
|
552
|
+
|
553
|
+
if max_score > 0 and deep > 5 and max_score * 0.6 > score:
|
554
|
+
|
555
|
+
return
|
556
|
+
|
557
|
+
# 終了文字か、最大の文の長さ以上なら、品詞を追加して終わる
|
558
|
+
|
559
|
+
if cur_parse == 1 or deep > words_max:
|
560
|
+
|
561
|
+
# 文のデータをコピー
|
562
|
+
|
563
|
+
data = np.array(cur_parses)
|
564
|
+
|
565
|
+
# 文を追加
|
566
|
+
|
567
|
+
parses.append((score, data))
|
568
|
+
|
569
|
+
# 最大スコアを更新
|
570
|
+
|
571
|
+
if max_score < score:
|
572
|
+
|
573
|
+
max_score = score
|
574
|
+
|
575
|
+
return
|
576
|
+
|
577
|
+
# 現在のニューラルネットワークのステータスをコピーする
|
578
|
+
|
579
|
+
cur_model = model_history[-1].copy()
|
580
|
+
|
581
|
+
# 入力値を作る
|
582
|
+
|
583
|
+
x = cp.array([cur_parse], dtype=cp.int32)
|
584
|
+
|
585
|
+
# ニューラルネットワークに入力する
|
586
|
+
|
587
|
+
y = cur_model(x)
|
588
|
+
|
589
|
+
# 実行結果を正規化する
|
590
|
+
|
591
|
+
z = F.softmax(y)
|
592
|
+
|
593
|
+
# 結果のデータを取得
|
594
|
+
|
595
|
+
result = z.data[0]
|
596
|
+
|
597
|
+
if uses_device >= 0:
|
598
|
+
|
599
|
+
result = chainer.cuda.to_cpu(result)
|
600
|
+
|
601
|
+
# 結果を確立順に並べ替える
|
602
|
+
|
603
|
+
p = np.argsort(result)[::-1]
|
604
|
+
|
605
|
+
# 現在のニューラルネットワークのステータスを保存する
|
606
|
+
|
607
|
+
model_history.append(cur_model)
|
608
|
+
|
609
|
+
# 結果から上位のものを次の枝に回す
|
610
|
+
|
611
|
+
for i in range(beam_w):
|
612
|
+
|
613
|
+
# 現在生成中の文に一文字追加する
|
614
|
+
|
615
|
+
cur_parses.append(p[i])
|
616
|
+
|
617
|
+
# 現在生成中の文のスコアに一つ追加する
|
618
|
+
|
619
|
+
cur_score.append(result[p[i]])
|
620
|
+
|
621
|
+
# 再帰呼び出し
|
622
|
+
|
623
|
+
Tree_Traverse()
|
624
|
+
|
625
|
+
# 現在生成中の文を一つ戻す
|
626
|
+
|
627
|
+
cur_parses.pop()
|
628
|
+
|
629
|
+
# 現在生成中の文のスコアを一つ戻す
|
630
|
+
|
631
|
+
cur_score.pop()
|
632
|
+
|
633
|
+
# ニューラルネットワークのステータスを一つ戻す
|
634
|
+
|
635
|
+
model_history.pop()
|
636
|
+
|
637
|
+
|
638
|
+
|
639
|
+
# 木検索して文章を生成する
|
640
|
+
|
641
|
+
Tree_Traverse()
|
642
|
+
|
643
|
+
|
644
|
+
|
645
|
+
# Word2Vecのモデルを読み込む
|
646
|
+
|
647
|
+
word_vec = word2vec.Word2Vec.load('word2vec.gensim.model')
|
648
|
+
|
649
|
+
|
650
|
+
|
651
|
+
# 文章のターゲット
|
652
|
+
|
653
|
+
target_str = ['元日']
|
654
|
+
|
655
|
+
#target_str = ['神']
|
656
|
+
|
657
|
+
#target_str = ['キリスト']
|
658
|
+
|
659
|
+
#target_str = ['父','子','聖霊']
|
660
|
+
|
661
|
+
#target_str = ['不思議','の','国','の','アリス']
|
662
|
+
|
663
|
+
#target_str = ['三月','うさぎ','の','お茶','会']
|
664
|
+
|
665
|
+
#target_str = ['女王']
|
666
|
+
|
667
|
+
|
668
|
+
|
669
|
+
# 指定した品詞の単語を文章がターゲットに近づくように返す
|
670
|
+
|
671
|
+
def similarity_word( parse, history ):
|
672
|
+
|
673
|
+
scores = []
|
674
|
+
|
675
|
+
# 品詞から候補をリスト
|
676
|
+
|
677
|
+
for i in range(len(words_parse[parse])):
|
678
|
+
|
679
|
+
w = words_parse[parse][i]
|
680
|
+
|
681
|
+
if w in word_vec:
|
682
|
+
|
683
|
+
# 候補のベクトルを履歴ベクトルに足す
|
684
|
+
|
685
|
+
t = history[:]
|
686
|
+
|
687
|
+
t.append(w)
|
688
|
+
|
689
|
+
# ターゲットとの距離を計算
|
690
|
+
|
691
|
+
sim = word_vec.n_similarity(target_str, t)
|
692
|
+
|
693
|
+
scores.append((sim, w))
|
694
|
+
|
695
|
+
# 結果をスコア順に並べ替える
|
696
|
+
|
697
|
+
result = sorted(scores, key=lambda x: x[0])[::-1]
|
698
|
+
|
699
|
+
return result[0]
|
700
|
+
|
701
|
+
|
702
|
+
|
703
|
+
|
704
|
+
|
705
|
+
# スコアの高いものから順に表示する
|
706
|
+
|
707
|
+
result_set = sorted(parses, key=lambda x: x[0])[::-1]
|
708
|
+
|
709
|
+
# 10個または全部の少ない方の数だけ表示
|
710
|
+
|
711
|
+
for i in range(min([10,len(result_set)])):
|
712
|
+
|
713
|
+
# 結果を取得
|
714
|
+
|
715
|
+
s, l = result_set[i]
|
716
|
+
|
717
|
+
# これまで登場した単語
|
718
|
+
|
719
|
+
history = []
|
720
|
+
|
721
|
+
# 開始文字と終端文字を除いてループ
|
722
|
+
|
723
|
+
for j in range(1,len(l)-1):
|
724
|
+
|
725
|
+
score, cur_word = similarity_word(l[j], history)
|
726
|
+
|
727
|
+
history.append(cur_word)
|
728
|
+
|
729
|
+
sys.stdout.buffer.write(cur_word.encode('utf-8'))
|
730
|
+
|
731
|
+
|
732
|
+
|
733
|
+
sys.stdout.buffer.write("\n".encode('utf-8'))
|
734
|
+
|
735
|
+
sys.stdout.buffer.flush()
|
736
|
+
|
737
|
+
|
738
|
+
|
739
|
+
|
740
|
+
|
741
|
+
|
742
|
+
|
743
|
+
```
|