質問編集履歴
1
情報の修正
test
CHANGED
File without changes
|
test
CHANGED
@@ -1,511 +1 @@
|
|
1
|
-
初心者で、ボットを作ろうとしています。
|
2
|
-
|
3
1
|
tensorflowで学習させたモデルをボットに組み込みたいです。
|
4
|
-
|
5
|
-
|
6
|
-
|
7
|
-
```python
|
8
|
-
|
9
|
-
import re
|
10
|
-
|
11
|
-
import tarfile
|
12
|
-
|
13
|
-
import numpy as np
|
14
|
-
|
15
|
-
import tensorflow as tf
|
16
|
-
|
17
|
-
from sklearn.utils import shuffle
|
18
|
-
|
19
|
-
from functools import reduce
|
20
|
-
|
21
|
-
from utils.data import get_file
|
22
|
-
|
23
|
-
|
24
|
-
|
25
|
-
# Fix both random number generators so data shuffling (NumPy) and weight
# initialisation (TensorFlow graph-level seed) are reproducible across runs.
np.random.seed(seed=0)
tf.set_random_seed(seed=1234)
|
28
|
-
|
29
|
-
|
30
|
-
|
31
|
-
|
32
|
-
|
33
|
-
def inference(x, q, n_batch,
              vocab_size=None,
              embedding_dim=None,
              story_maxlen=None,
              question_maxlen=None):
    """Build the memory-network + LSTM graph and return answer probabilities.

    Args:
        x: int32 tensor of story word ids. Assumed (batch, story_maxlen);
            that is the placeholder shape the __main__ section feeds.
        q: int32 tensor of question word ids, assumed (batch, question_maxlen).
        n_batch: scalar int tensor holding the batch size; used only to build
            the LSTM zero state.
        vocab_size: number of embedding rows and of output classes.
        embedding_dim: embedding width; the LSTM has embedding_dim // 2 units.
        story_maxlen: not used inside the graph; kept so existing keyword
            calls keep working.
        question_maxlen: number of LSTM time steps (one per question position)
            and the width of the C embedding.

    Returns:
        Tensor (batch, vocab_size) of softmax probabilities over the vocabulary.
    """
    def weight_variable(shape, stddev=0.08):
        initial = tf.truncated_normal(shape, stddev=stddev)
        return tf.Variable(initial)

    # NOTE: keep the variable-creation order below unchanged. With a
    # graph-level seed (tf.set_random_seed), TensorFlow derives each op's
    # seed from the order in which ops are created, so reordering would
    # change the initial weights.
    A = weight_variable([vocab_size, embedding_dim])    # story embedding
    B = weight_variable([vocab_size, embedding_dim])    # question embedding
    C = weight_variable([vocab_size, question_maxlen])  # story embedding for the additive term

    m = tf.nn.embedding_lookup(A, x)  # (batch, story_len, embedding_dim)
    u = tf.nn.embedding_lookup(B, q)  # (batch, question_len, embedding_dim)
    c = tf.nn.embedding_lookup(C, x)  # (batch, story_len, question_maxlen)

    # Dot every story position with every question position, then softmax
    # over the last (question) axis.
    p = tf.nn.softmax(tf.einsum('ijk,ilk->ijl', m, u))  # (batch, story_len, question_len)
    o = tf.add(p, c)
    o = tf.transpose(o, perm=[0, 2, 1])  # (batch, question_len, story_len)
    ou = tf.concat([o, u], axis=-1)      # (batch, question_len, story_len + embedding_dim)

    cell = tf.contrib.rnn.BasicLSTMCell(embedding_dim//2, forget_bias=1.0)
    initial_state = cell.zero_state(n_batch, tf.float32)
    state = initial_state
    outputs = []
    with tf.variable_scope('LSTM'):
        for t in range(question_maxlen):
            if t > 0:
                # Reuse the same LSTM weights at every time step.
                tf.get_variable_scope().reuse_variables()
            (cell_output, state) = cell(ou[:, t, :], state)
            outputs.append(cell_output)

    # Only the last time step's output feeds the classifier.
    output = outputs[-1]

    # Output projection (no bias term).
    W = weight_variable([embedding_dim//2, vocab_size], stddev=0.01)
    a = tf.nn.softmax(tf.matmul(output, W))

    return a
|
110
|
-
|
111
|
-
|
112
|
-
|
113
|
-
|
114
|
-
|
115
|
-
def loss(y, t):
    """Return the mean categorical cross-entropy between ``y`` and ``t``.

    Args:
        y: predicted probabilities, assumed (batch, n_classes) softmax output.
        t: one-hot targets with the same shape as ``y``.

    Returns:
        Scalar tensor: the per-example cross-entropy averaged over the batch.
    """
    # Clip before taking the log so that a probability of exactly 0 does not
    # produce -inf / NaN.
    clipped = tf.clip_by_value(y, 1e-10, 1.0)
    # `axis` replaces the deprecated `reduction_indices` argument.
    per_example = -tf.reduce_sum(t * tf.log(clipped), axis=1)
    return tf.reduce_mean(per_example)
|
126
|
-
|
127
|
-
|
128
|
-
|
129
|
-
|
130
|
-
|
131
|
-
def training(loss):
    """Return an Adam update op that minimises ``loss``.

    Uses learning_rate=0.001, beta1=0.9 and beta2=0.999.
    """
    adam = tf.train.AdamOptimizer(learning_rate=0.001, beta1=0.9, beta2=0.999)
    return adam.minimize(loss)
|
140
|
-
|
141
|
-
|
142
|
-
|
143
|
-
|
144
|
-
|
145
|
-
def accuracy(y, t):
    """Return the fraction of rows whose arg-max in ``y`` matches that in ``t``."""
    predicted = tf.argmax(y, 1)
    expected = tf.argmax(t, 1)
    hits = tf.cast(tf.equal(predicted, expected), tf.float32)
    return tf.reduce_mean(hits)
|
152
|
-
|
153
|
-
|
154
|
-
|
155
|
-
|
156
|
-
|
157
|
-
def tokenize(sent):
    """Split ``sent`` into word and punctuation tokens.

    Runs of non-word characters become their own token (surrounding
    whitespace stripped). Whitespace-only pieces are dropped. For already
    space-separated (wakati) Japanese text this yields the segmented words,
    with punctuation such as "。" as separate tokens.
    """
    # Raw string: '\W' in a normal literal is an invalid escape sequence
    # (a SyntaxWarning on Python 3.12+).
    return [x.strip() for x in re.split(r'(\W+)', sent) if x.strip()]
|
160
|
-
|
161
|
-
|
162
|
-
|
163
|
-
|
164
|
-
|
165
|
-
def parse_stories(lines):
    """Parse bAbI-format lines into (story_sentences, question, answer) tuples.

    Each non-blank line must look like ``"<id> <text>"``. An id of 1 starts a
    new story. A line whose text contains a tab is a question line of the form
    ``question\\tanswer[\\tsupporting_ids]``; any other line is a story
    sentence. Plain segmented text without leading ids (e.g.
    ``"こんにちは 。 今日 は 寒い です ね"``) is NOT in this format and is rejected.

    Args:
        lines: iterable of ``bytes`` (decoded as UTF-8) or ``str`` lines.

    Returns:
        List of ``(substory, question_tokens, answer)``. ``substory`` is the
        list of tokenised sentences seen so far in the current story.

    Raises:
        ValueError: if a non-blank line does not start with an integer id
            followed by a space.
    """
    data = []
    story = []
    for raw in lines:
        line = raw.decode('utf-8') if isinstance(raw, bytes) else raw
        line = line.strip()
        if not line:
            # Blank lines (e.g. a trailing newline at EOF) carry no data.
            continue

        nid_text, sep, text = line.partition(' ')
        try:
            nid = int(nid_text)
        except ValueError:
            nid = None
        if nid is None or not sep:
            raise ValueError(
                'parse_stories: expected "<id> <sentence>" (bAbI format), '
                'got %r' % line)

        if nid == 1:
            story = []
        if '\t' in text:
            # The supporting-fact ids (3rd field) are not used, so accept
            # question lines with or without them.
            fields = text.split('\t')
            q, a = fields[0], fields[1]
            q = tokenize(q)
            substory = [x for x in story if x]
            data.append((substory, q, a))
            # Placeholder for the question line; filtered out above. Kept
            # presumably so list positions keep tracking line ids.
            story.append('')
        else:
            story.append(tokenize(text))
    return data
|
202
|
-
|
203
|
-
|
204
|
-
|
205
|
-
|
206
|
-
|
207
|
-
def get_stories(f, max_length=None):
    """Read a bAbI file and return ``(story_tokens, question, answer)`` tuples.

    Args:
        f: file-like object whose ``readlines()`` yields the lines accepted by
            ``parse_stories``.
        max_length: if truthy, drop samples whose flattened story has
            ``max_length`` or more tokens.

    Returns:
        List of ``(flat_story, question_tokens, answer)``, where
        ``flat_story`` concatenates the story sentences into one token list.
    """
    def flatten(data):
        # One linear pass. reduce(+) on lists is quadratic and raises
        # TypeError on an empty story.
        return [token for sentence in data for token in sentence]

    stories = []
    for story, q, answer in parse_stories(f.readlines()):
        flat = flatten(story)  # computed once per sample
        if not max_length or len(flat) < max_length:
            stories.append((flat, q, answer))
    return stories
|
224
|
-
|
225
|
-
|
226
|
-
|
227
|
-
|
228
|
-
|
229
|
-
def vectorize_stories(data, word_indices, story_maxlen, question_maxlen):
    """Turn tokenised samples into padded id arrays and one-hot answers.

    Args:
        data: iterable of ``(story_tokens, question_tokens, answer)``.
        word_indices: mapping token -> integer id (id 0 is left for padding).
        story_maxlen: padded length of each story row.
        question_maxlen: padded length of each question row.

    Returns:
        ``(stories, questions, answers)``: the first two come from
        ``padding``; ``answers`` is an array of one-hot rows of length
        ``len(word_indices) + 1``.

    Raises:
        KeyError: if a token or answer is missing from ``word_indices``.
    """
    answer_dim = len(word_indices) + 1  # +1 for the padding id
    stories, questions, answers = [], [], []
    for story, question, answer in data:
        story_ids = [word_indices[token] for token in story]
        question_ids = [word_indices[token] for token in question]
        one_hot = np.zeros(answer_dim)
        one_hot[word_indices[answer]] = 1
        stories.append(story_ids)
        questions.append(question_ids)
        answers.append(one_hot)

    padded_stories = padding(stories, maxlen=story_maxlen)
    padded_questions = padding(questions, maxlen=question_maxlen)
    return (padded_stories, padded_questions, np.array(answers))
|
258
|
-
|
259
|
-
|
260
|
-
|
261
|
-
|
262
|
-
|
263
|
-
def padding(words, maxlen):
    """Left-pad with 0 (or left-truncate) each id sequence to length ``maxlen``.

    Args:
        words: iterable of integer-id sequences. It is not modified.
        maxlen: target length of every row.

    Returns:
        numpy array of shape ``(len(words), maxlen)`` for non-empty input.
        For empty input the result is ``np.array([])``, as before.
    """
    rows = []
    for seq in words:
        seq = list(seq)
        if len(seq) > maxlen:
            # Keep the last `maxlen` ids so every row has the same length;
            # otherwise np.array would receive ragged rows.
            seq = seq[len(seq) - maxlen:]
        rows.append([0] * (maxlen - len(seq)) + seq)
    return np.array(rows)
|
270
|
-
|
271
|
-
|
272
|
-
|
273
|
-
|
274
|
-
|
275
|
-
if __name__ == '__main__':
    import json
    import os

    # ------------------------------------------------------------------
    # Load data
    # ------------------------------------------------------------------
    print('Fetching data...')
    # Any error from get_file propagates unchanged (the former
    # try/except only re-raised it).
    path = get_file('wakati.tar.gz')

    # NOTE(review): this member path is the bAbI qa1 layout. A custom
    # wakati.tar.gz must contain files with these names AND bAbI-style
    # "<id> <sentence>" / "<id> question\tanswer\tids" lines; plain
    # segmented text will be rejected by parse_stories.
    challenge = 'tasks_1-20_v1-2/en-10k/qa1_single-supporting-fact_{}.txt'
    with tarfile.open(path) as tar:  # closes the archive when done
        train_stories = get_stories(tar.extractfile(challenge.format('train')))
        test_stories = get_stories(tar.extractfile(challenge.format('test')))

    all_stories = train_stories + test_stories

    vocab = set()
    for story, q, answer in all_stories:
        vocab |= set(story + q + [answer])
    vocab = sorted(vocab)
    vocab_size = len(vocab) + 1  # +1 for the padding id 0

    story_maxlen = max(len(s) for s, _, _ in all_stories)
    question_maxlen = max(len(qs) for _, qs, _ in all_stories)

    print('Vectorizing data...')
    word_indices = {word: i + 1 for i, word in enumerate(vocab)}
    inputs_train, questions_train, answers_train = \
        vectorize_stories(train_stories, word_indices,
                          story_maxlen, question_maxlen)
    inputs_test, questions_test, answers_test = \
        vectorize_stories(test_stories, word_indices,
                          story_maxlen, question_maxlen)

    # ------------------------------------------------------------------
    # Build the model
    # ------------------------------------------------------------------
    print('Building model...')
    x = tf.placeholder(tf.int32, shape=[None, story_maxlen])
    q = tf.placeholder(tf.int32, shape=[None, question_maxlen])
    a = tf.placeholder(tf.float32, shape=[None, vocab_size])
    n_batch = tf.placeholder(tf.int32, shape=[])

    y = inference(x, q, n_batch,
                  vocab_size=vocab_size,
                  embedding_dim=64,
                  story_maxlen=story_maxlen,
                  question_maxlen=question_maxlen)
    # Separate name so the loss() function is not rebound.
    loss_op = loss(y, a)
    train_step = training(loss_op)
    acc = accuracy(y, a)

    history = {
        'val_loss': [],
        'val_acc': []
    }

    # ------------------------------------------------------------------
    # Train the model
    # ------------------------------------------------------------------
    print('Training model...')
    epochs = 120
    batch_size = 100

    init = tf.global_variables_initializer()
    sess = tf.Session()
    sess.run(init)

    # The last partial batch of each epoch is skipped.
    n_batches = len(inputs_train) // batch_size

    for epoch in range(epochs):
        inputs_train_, questions_train_, answers_train_ = \
            shuffle(inputs_train, questions_train, answers_train)

        for i in range(n_batches):
            start = i * batch_size
            end = start + batch_size
            sess.run(train_step, feed_dict={
                x: inputs_train_[start:end],
                q: questions_train_[start:end],
                a: answers_train_[start:end],
                n_batch: batch_size
            })

        # Evaluate on the test data: loss and accuracy in a single graph run.
        val_loss, val_acc = sess.run([loss_op, acc], feed_dict={
            x: inputs_test,
            q: questions_test,
            a: answers_test,
            n_batch: len(inputs_test)
        })

        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)
        print('epoch:', epoch,
              ' validation loss:', val_loss,
              ' validation accuracy:', val_acc)

    # ------------------------------------------------------------------
    # Save what a chatbot needs in order to reuse this model
    # ------------------------------------------------------------------
    # Save both parts needed later:
    # - the trained weights, as TensorFlow checkpoint files (restore them with
    #   tf.train.Saver().restore after rebuilding the same graph);
    # - the vocabulary and sequence lengths, needed to turn new user text
    #   into the same integer inputs.
    model_dir = 'model'
    os.makedirs(model_dir, exist_ok=True)
    saver = tf.train.Saver()
    saver.save(sess, os.path.join(model_dir, 'model.ckpt'))
    with open(os.path.join(model_dir, 'vocab.json'), 'w',
              encoding='utf-8') as fp:
        json.dump({'word_indices': word_indices,
                   'story_maxlen': story_maxlen,
                   'question_maxlen': question_maxlen},
                  fp, ensure_ascii=False)
|
486
|
-
|
487
|
-
```
|
488
|
-
|
489
|
-
とコードを書いて実行させて得られたモデルを使って、文を打つと応答してくれるボットを作りたいです(参考サイト: https://qiita.com/Umemiya/items/027f8bac0650c28590b5 )。
|
490
|
-
|
491
|
-
質問は、
|
492
|
-
|
493
|
-
・上記のコードを実行しても「モデルファイル」らしきものは得られず、checkpoint ファイルが得られただけでした。これがモデルファイルにあたるのでしょうか?
|
494
|
-
|
495
|
-
・日本語を学習させたかったので、wakati.tar.gz には以下のような分かち書きファイルを入れましたが、
|
496
|
-
|
497
|
-
```text
|
498
|
-
|
499
|
-
こんにちは 。 今日 は 寒い です ね
|
500
|
-
|
501
|
-
```
|
502
|
-
|
503
|
-
このような分かち書きファイルをそのまま学習させても問題はなかったでしょうか?
|
504
|
-
|
505
|
-
・作成したモデルを使ってボットを完成させるための手順を知りたいです。
|
506
|
-
|
507
|
-
|
508
|
-
|
509
|
-
の3点です。
|
510
|
-
|
511
|
-
ご存知の方がいらしたらお願いします。
|