Question edit history

Edit 4: Mistake correction
````diff
@@ -114,8 +114,8 @@
 
 vaem2 = VAEM2()
 
-
+training = vaem2.training_model()
 
-
+training.compile(optimizer='adam', loss=vaem2.cost)
 
 ```
````
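
Note on the two corrected lines: they pass a bound method, `vaem2.cost`, as a custom loss. Keras hands any callable given to `compile` the two symbolic tensors `(y_true, y_pred)` and expects a tensor back, so a method works as well as a plain function. A minimal, self-contained sketch of the same pattern; the `Wrapper` class and its squared-error body are illustrative only, not the question's code:

```Python
import keras.backend as K
from keras.layers import Input, Dense
from keras.models import Model

class Wrapper(object):
    def __init__(self):
        x = Input((4, ))
        self.model = Model(inputs=x, outputs=Dense(1)(x))

    def cost(self, y_true, y_pred):
        # Called once at compile time with symbolic tensors; the body must
        # be built from backend ops and return a tensor, not a Python value.
        return K.mean(K.square(y_true - y_pred), axis=-1)

w = Wrapper()
w.model.compile(optimizer='adam', loss=w.cost)  # same pattern as vaem2.cost
```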

Edit 3: Edit
````diff
@@ -2,131 +2,47 @@
 
 
 
-Epoch 1/10
-
-Traceback (most recent call last):
-
-File "/usr/local/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1361, in _do_call
-
-return fn(*args)
-
-File "/usr/local/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1340, in _run_fn
-
-target_list, status, run_metadata)
-
-File "/usr/local/lib/python3.6/site-packages/tensorflow/python/framework/errors_impl.py", line 516, in __exit__
-
-c_api.TF_GetCode(self.status.status))
-
-tensorflow.python.framework.errors_impl.InvalidArgumentError: You must feed a value for placeholder tensor 'input_4' with dtype float and shape [?,10]
-
-[[Node: input_4 = Placeholder[dtype=DT_FLOAT, shape=[?,10], _device="/job:localhost/replica:0/task:0/device:CPU:0"]()]]
 
 
 
-
+
+------------------ Addendum ------------------
+
+Sorry, you were right: it seems the cause was that I had declared input_4 even though it is never used!
+
+However, the following error came up next, and I do not understand this one either.
 
 
 
 Traceback (most recent call last):
 
-File "training_m2.py", line
+File "training_m2.py", line 83, in <module>
 
-
+training.compile(optimizer='adam', loss=vaem2.cost)
 
-File "/usr/local/lib/python3.6/site-packages/keras/engine/training.py", line
+File "/usr/local/lib/python3.6/site-packages/keras/engine/training.py", line 830, in compile
 
-
+sample_weight, mask)
 
-File "/usr/local/lib/python3.6/site-packages/keras/engine/training.py", line
+File "/usr/local/lib/python3.6/site-packages/keras/engine/training.py", line 429, in weighted
 
-o
+score_array = fn(y_true, y_pred)
 
-File "/
+File "/Users/tmsmac/Google ドライブ/Python/SemiSupervised/keras-VAE-master/vae_m2.py", line 185, in cost
 
-
+if np.any(y_true > 0):
 
-File "/usr/local/lib/python3.6/site-packages/tensorflow/python/
+File "/usr/local/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 614, in __bool__
 
-r
+raise TypeError("Using a `tf.Tensor` as a Python `bool` is not allowed. "
 
-
-
-feed_dict_tensor, options, run_metadata)
-
-File "/usr/local/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1355, in _do_run
-
-options, run_metadata)
-
-File "/usr/local/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1374, in _do_call
-
-raise type(e)(node_def, op, message)
-
-tensorflow.python.framework.errors_impl.InvalidArgumentError: You must feed a value for placeholder tensor 'input_4' with dtype float and shape [?,10]
-
-[[Node: input_4 = Placeholder[dtype=DT_FLOAT, shape=[?,10], _device="/job:localhost/replica:0/task:0/device:CPU:0"]()]]
+TypeError: Using a `tf.Tensor` as a Python `bool` is not allowed. Use `if t is not None:` instead of `if t:` to test if a tensor is defined, and use TensorFlow ops such as tf.cond to execute subgraphs conditioned on the value of a tensor.
 
 
 
-Caused by op 'input_4', defined at:
-
-File "training_m2.py", line 81, in <module>
-
-vaem2 = VAEM2()
-
-File "/Users/tmsmac/Google ドライブ/Python/SemiSupervised/keras-VAE-master/vae_m2.py", line 19, in __init__
-
-self.y_u = Input((self.cat_dim, ))
-
-File "/usr/local/lib/python3.6/site-packages/keras/engine/topology.py", line 1455, in Input
-
-input_tensor=tensor)
-
-File "/usr/local/lib/python3.6/site-packages/keras/legacy/interfaces.py", line 91, in wrapper
-
-return func(*args, **kwargs)
-
-File "/usr/local/lib/python3.6/site-packages/keras/engine/topology.py", line 1364, in __init__
-
-name=self.name)
-
-File "/usr/local/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py", line 504, in placeholder
-
-x = tf.placeholder(dtype, shape=shape, name=name)
-
-File "/usr/local/lib/python3.6/site-packages/tensorflow/python/ops/array_ops.py", line 1746, in placeholder
-
-return gen_array_ops._placeholder(dtype=dtype, shape=shape, name=name)
-
-File "/usr/local/lib/python3.6/site-packages/tensorflow/python/ops/gen_array_ops.py", line 3051, in _placeholder
-
-"Placeholder", dtype=dtype, shape=shape, name=name)
-
-File "/usr/local/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py", line 787, in _apply_op_helper
-
-op_def=op_def)
-
-File "/usr/local/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 3271, in create_op
-
-op_def=op_def)
-
-File "/usr/local/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 1650, in __init__
-
-self._traceback = self._graph._extract_stack()  # pylint: disable=protected-access
 
 
-
-InvalidArgumentError (see above for traceback): You must feed a value for placeholder tensor 'input_4' with dtype float and shape [?,10]
-
-[[Node: input_4 = Placeholder[dtype=DT_FLOAT, shape=[?,10], _device="/job:localhost/replica:0/task:0/device:CPU:0"]()]]
-
-
-
------------------- Addendum ------------------
-
-
-
-Below
+Below is the part of the code that is probably related to the error. It is basically based on https://github.com/rarilurelo/keras-VAE
 
 I am editing this code.
 
````
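
Note on the TypeError added in this revision: inside a loss function, `y_true` is a symbolic tensor with no concrete value yet, so `if np.any(y_true > 0):` forces Python to convert a tensor to `bool` while the graph is still being built, which TensorFlow forbids. As the error message itself suggests, the branch has to be expressed with graph ops such as `tf.cond`, or the Keras backend wrapper `K.switch`. A minimal sketch; the two branch expressions here are placeholders, not the question's actual labeled/unlabeled terms:

```Python
import keras.backend as K

def cost(y_true, y_pred):
    # A boolean tensor, evaluated when the graph runs, not when it is built.
    has_label = K.any(K.greater(y_true, 0.0))
    labeled_loss = K.mean(K.square(y_true - y_pred))  # illustrative branch
    unlabeled_loss = K.mean(K.square(y_pred))         # illustrative branch
    # K.switch builds both subgraphs and selects one per evaluation,
    # similar in spirit to tf.cond.
    return K.switch(has_label, labeled_loss, unlabeled_loss)
```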
````diff
@@ -134,227 +50,7 @@
 
 ```Python
 
-class VAEM2(object):
-
-    def __init__(self, in_dim=50, cat_dim=10, hid_dim=300, z_dim=50, alpha=0):
-
-        self.in_dim = in_dim
-
-        self.cat_dim = cat_dim
-
-        self.hid_dim = hid_dim
-
-        self.z_dim = z_dim
-
-        self.alpha = alpha
-
-        self.x_l = Input((self.in_dim, ))
-
-        self.x_u = Input((self.in_dim, ))
-
-        self.y_l = Input((self.cat_dim, ))
-
-        self.y_u = Input((self.cat_dim, ))
-
-
-
-        self.z = Input((self.z_dim, ))
-
-
-
-        ###############
-
-        # q(z | x, y) #
-
-        ###############
-
-
-
-        Inputx= Input(shape=(self.in_dim,))
-
-        x_branch = Dense(self.hid_dim)(Inputx)
-
-        x_branch = BatchNormalization()(x_branch)
-
-        x_branch = Activation('softplus')(x_branch)
-
-
-
-        Inputy = Input(shape=(self.cat_dim,))
-
-        y_branch = Dense(self.hid_dim)(Inputy)
-
-        y_branch = BatchNormalization()(y_branch)
-
-        y_branch = Activation('softplus')(y_branch)
-
-
-
-        merged = Concatenate()([x_branch, y_branch])
-
-        merged = Dense(self.hid_dim)(merged)
-
-        merged = BatchNormalization()(merged)
-
-        merged = Activation('softplus')(merged)
-
-
-
-        mean = Dense(self.hid_dim)(merged)
-
-        mean = BatchNormalization()(mean)
-
-        mean = Activation('softplus')(mean)
-
-        mean = Dense(self.z_dim)(mean)
-
-
-
-        mean = Model(inputs=[Inputx,Inputy], outputs=mean)
-
-
-
-        var = Dense(self.hid_dim)(merged)
-
-        var = BatchNormalization()(var)
-
-        var = Activation('softplus')(var)
-
-        var = Dense(self.z_dim)(var)
-
-        var = Activation('softplus')(var)
-
-
-
-        var = Model(inputs=[Inputx,Inputy], outputs=var)
-
-
-
-        self.q_z_xy = GaussianDistribution(self.z, givens=[self.x_l, self.y_l], mean_model=mean, var_model=var)
-
-
-
-        ###############
-
-        # p(x | y, z) #
-
-        ###############
-
-        Inputy = Input(shape=(self.cat_dim,))
-
-        y_branch = Dense(self.hid_dim)(Inputy)
-
-        y_branch = BatchNormalization()(y_branch)
-
-        y_branch = Activation('softplus')(y_branch)
-
-
-
-        Inputz= Input(shape=(self.z_dim,))
-
-        z_branch = Dense(self.hid_dim)(Inputz)
-
-        z_branch = BatchNormalization()(z_branch)
-
-        z_branch = Activation('softplus')(z_branch)
-
-
-
-        merged = Concatenate()([y_branch, z_branch])
-
-        merged = Dense(self.hid_dim)(merged)
-
-        merged = BatchNormalization()(merged)
-
-        merged = Activation('softplus')(merged)
-
-
-
-        mean = Dense(self.hid_dim)(merged)
-
-        mean = BatchNormalization()(mean)
-
-        mean = Activation('softplus')(mean)
-
-        mean = Dense(self.z_dim)(mean)
-
-
-
-        mean = Model(inputs=[Inputy,Inputz], outputs=mean)
-
-
-
-        var = Dense(self.hid_dim)(merged)
-
-        var = BatchNormalization()(var)
-
-        var = Activation('softplus')(var)
-
-        var = Dense(self.in_dim)(var)
-
-        var = Activation('softplus')(var)
-
-
-
-        var = Model(inputs=[Inputy,Inputz], outputs=var)
-
-
-
-
-
-        self.p_x_yz = GaussianDistribution(self.x_l, givens=[self.y_l, self.z], mean_model=mean, var_model=var)
-
-
-
-        ########
-
-        # p(y) #
-
-        ########
-
-        self.p_y = CategoricalDistribution(self.y_l)
-
-
-
-        ############
-
-        # q(y | x) #
-
-        ############
-
-        inference = Sequential()
-
-        inference.add(Dense(self.hid_dim, input_dim=self.in_dim))
-
-        inference.add(BatchNormalization())
-
-        inference.add(Activation('softplus'))
-
-        inference.add(Dense(self.hid_dim))
-
-        inference.add(BatchNormalization())
-
-        inference.add(Activation('softplus'))
-
-        inference.add(Dense(self.cat_dim, activation='softmax'))
-
-        self.q_y_x = CategoricalDistribution(self.y_l, givens=[self.x_l], model=inference)
-
-
-
-        ##########################
-
-        # sample and reconstruct #
-
-        ##########################
-
-        self.sampling_z = self.q_z_xy.sampling(givens=[self.x_l, self.y_l])
-
-        self.reconstruct_x_l = self.p_x_yz.sampling(givens=[self.y_l, self.sampling_z])
-
-
-
-
+    def cost(self, y_true, y_false):
 
         ###########
 
````
````diff
@@ -362,15 +58,19 @@
 
         ###########
 
-
+        L = 0
 
-
+        if np.any(y_true > 0):
 
-l
+            self.mean, self.var = self.q_z_xy.get_params(givens=[self.x_l, self.y_l])
 
-L = KL
+            KL = self._KL(self.mean, self.var)
 
+            logliklihood = -self.p_x_yz.logliklihood(self.x_l, givens=[self.y_l, self.sampling_z])-self.p_y.logliklihood(self.y_l)
+
+            L = KL+logliklihood
+
-L = L+self.alpha*self.q_y_x.logliklihood(self.y_l, givens=[self.x_l])
+            L = L+self.alpha*self.q_y_x.logliklihood(self.y_l, givens=[self.x_l])
 
 
 
````
````diff
@@ -384,32 +84,38 @@
 
         # marginalization
 
-
+        if not np.any(y_true > 0):
 
-
+            y = y_false
 
-
+            mean, var = self.q_z_xy.get_params(givens=[self.x_u, y])
 
-
+            sampling_z = self.q_z_xy.sampling(givens=[self.x_u, y])
 
-
+            U += self.q_y_x.prob(y, givens=[self.x_u])*(-self.p_x_yz.logliklihood(self.x_u, givens=[y, sampling_z])
 
-
+                -self.p_y.logliklihood(y)
 
-+self.
+                +self._KL(mean, var)
 
+                +self.q_y_x.logliklihood(y, givens=[self.x_u])
+
-)
+                )
 
         return U+L
 
-    # This is the model.
 
-    def training_model(self):
-
-        model = Model(input=[self.x_l, self.y_l], output=self.reconstruct_x_l)
-
-        return model
 
 
 
 ```
+
+```Python
+
+vaem2 = VAEM2()
+
+training = vaem2.training_model()
+
+training.compile(optimizer='adam', loss=vaem2.cost)
+
+```
````
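
Note on the `# marginalization` block above: in the M2 model of Kingma et al. that keras-VAE follows, the unlabeled bound sums over every candidate label, weighting each term by q(y|x); a single `y = y_false` term covers only one candidate. A rough NumPy sketch of that sum, where `prob` and `labeled_bound` are hypothetical stand-ins for `self.q_y_x.prob(...)` and the `-log p + KL` expression, not functions from the question's code:

```Python
import numpy as np

def unlabeled_bound(x_u, cat_dim, prob, labeled_bound):
    # Marginalize over all cat_dim classes: sum the labeled bound for each
    # candidate one-hot label, weighted by the classifier probability q(y|x).
    U = 0.0
    for k in range(cat_dim):
        y = np.zeros((1, cat_dim), dtype='float32')
        y[0, k] = 1.0  # candidate one-hot label
        U += prob(y, x_u) * labeled_bound(x_u, y)
    return U
```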

Edit 2: miss
````diff
@@ -409,3 +409,7 @@
         model = Model(input=[self.x_l, self.y_l], output=self.reconstruct_x_l)
 
         return model
+
+
+
+```
````

Edit 1: Code addition
````diff
@@ -119,3 +119,293 @@
 InvalidArgumentError (see above for traceback): You must feed a value for placeholder tensor 'input_4' with dtype float and shape [?,10]
 
 [[Node: input_4 = Placeholder[dtype=DT_FLOAT, shape=[?,10], _device="/job:localhost/replica:0/task:0/device:CPU:0"]()]]
+
+
+
+------------------ Addendum ------------------
+
+
+
+Below I show the model and related code. It would get too long, so I omitted parts here and there, but it is basically based on https://github.com/rarilurelo/keras-VAE
+
+I am editing this code.
+
+
+
+```Python
+
+class VAEM2(object):
+
+    def __init__(self, in_dim=50, cat_dim=10, hid_dim=300, z_dim=50, alpha=0):
+
+        self.in_dim = in_dim
+
+        self.cat_dim = cat_dim
+
+        self.hid_dim = hid_dim
+
+        self.z_dim = z_dim
+
+        self.alpha = alpha
+
+        self.x_l = Input((self.in_dim, ))
+
+        self.x_u = Input((self.in_dim, ))
+
+        self.y_l = Input((self.cat_dim, ))
+
+        self.y_u = Input((self.cat_dim, ))
+
+
+
+        self.z = Input((self.z_dim, ))
+
+
+
+        ###############
+
+        # q(z | x, y) #
+
+        ###############
+
+
+
+        Inputx= Input(shape=(self.in_dim,))
+
+        x_branch = Dense(self.hid_dim)(Inputx)
+
+        x_branch = BatchNormalization()(x_branch)
+
+        x_branch = Activation('softplus')(x_branch)
+
+
+
+        Inputy = Input(shape=(self.cat_dim,))
+
+        y_branch = Dense(self.hid_dim)(Inputy)
+
+        y_branch = BatchNormalization()(y_branch)
+
+        y_branch = Activation('softplus')(y_branch)
+
+
+
+        merged = Concatenate()([x_branch, y_branch])
+
+        merged = Dense(self.hid_dim)(merged)
+
+        merged = BatchNormalization()(merged)
+
+        merged = Activation('softplus')(merged)
+
+
+
+        mean = Dense(self.hid_dim)(merged)
+
+        mean = BatchNormalization()(mean)
+
+        mean = Activation('softplus')(mean)
+
+        mean = Dense(self.z_dim)(mean)
+
+
+
+        mean = Model(inputs=[Inputx,Inputy], outputs=mean)
+
+
+
+        var = Dense(self.hid_dim)(merged)
+
+        var = BatchNormalization()(var)
+
+        var = Activation('softplus')(var)
+
+        var = Dense(self.z_dim)(var)
+
+        var = Activation('softplus')(var)
+
+
+
+        var = Model(inputs=[Inputx,Inputy], outputs=var)
+
+
+
+        self.q_z_xy = GaussianDistribution(self.z, givens=[self.x_l, self.y_l], mean_model=mean, var_model=var)
+
+
+
+        ###############
+
+        # p(x | y, z) #
+
+        ###############
+
+        Inputy = Input(shape=(self.cat_dim,))
+
+        y_branch = Dense(self.hid_dim)(Inputy)
+
+        y_branch = BatchNormalization()(y_branch)
+
+        y_branch = Activation('softplus')(y_branch)
+
+
+
+        Inputz= Input(shape=(self.z_dim,))
+
+        z_branch = Dense(self.hid_dim)(Inputz)
+
+        z_branch = BatchNormalization()(z_branch)
+
+        z_branch = Activation('softplus')(z_branch)
+
+
+
+        merged = Concatenate()([y_branch, z_branch])
+
+        merged = Dense(self.hid_dim)(merged)
+
+        merged = BatchNormalization()(merged)
+
+        merged = Activation('softplus')(merged)
+
+
+
+        mean = Dense(self.hid_dim)(merged)
+
+        mean = BatchNormalization()(mean)
+
+        mean = Activation('softplus')(mean)
+
+        mean = Dense(self.z_dim)(mean)
+
+
+
+        mean = Model(inputs=[Inputy,Inputz], outputs=mean)
+
+
+
+        var = Dense(self.hid_dim)(merged)
+
+        var = BatchNormalization()(var)
+
+        var = Activation('softplus')(var)
+
+        var = Dense(self.in_dim)(var)
+
+        var = Activation('softplus')(var)
+
+
+
+        var = Model(inputs=[Inputy,Inputz], outputs=var)
+
+
+
+
+
+        self.p_x_yz = GaussianDistribution(self.x_l, givens=[self.y_l, self.z], mean_model=mean, var_model=var)
+
+
+
+        ########
+
+        # p(y) #
+
+        ########
+
+        self.p_y = CategoricalDistribution(self.y_l)
+
+
+
+        ############
+
+        # q(y | x) #
+
+        ############
+
+        inference = Sequential()
+
+        inference.add(Dense(self.hid_dim, input_dim=self.in_dim))
+
+        inference.add(BatchNormalization())
+
+        inference.add(Activation('softplus'))
+
+        inference.add(Dense(self.hid_dim))
+
+        inference.add(BatchNormalization())
+
+        inference.add(Activation('softplus'))
+
+        inference.add(Dense(self.cat_dim, activation='softmax'))
+
+        self.q_y_x = CategoricalDistribution(self.y_l, givens=[self.x_l], model=inference)
+
+
+
+        ##########################
+
+        # sample and reconstruct #
+
+        ##########################
+
+        self.sampling_z = self.q_z_xy.sampling(givens=[self.x_l, self.y_l])
+
+        self.reconstruct_x_l = self.p_x_yz.sampling(givens=[self.y_l, self.sampling_z])
+
+
+
+    def cost(self, y_true, y_false):
+
+        ###########
+
+        # Labeled #
+
+        ###########
+
+        self.mean, self.var = self.q_z_xy.get_params(givens=[self.x_l, self.y_l])
+
+        KL = self._KL(self.mean, self.var)
+
+        logliklihood = -self.p_x_yz.logliklihood(self.x_l, givens=[self.y_l, self.sampling_z])-self.p_y.logliklihood(self.y_l)
+
+        L = KL+logliklihood
+
+        L = L+self.alpha*self.q_y_x.logliklihood(self.y_l, givens=[self.x_l])
+
+
+
+        #############
+
+        # UnLabeled #
+
+        #############
+
+        U = 0
+
+        # marginalization
+
+        y = self.y_u
+
+        mean, var = self.q_z_xy.get_params(givens=[self.x_u, y])
+
+        sampling_z = self.q_z_xy.sampling(givens=[self.x_u, y])
+
+        U += self.q_y_x.prob(y, givens=[self.x_u])*(-self.p_x_yz.logliklihood(self.x_u, givens=[y, sampling_z])
+
+            -self.p_y.logliklihood(y)
+
+            +self._KL(mean, var)
+
+            +self.q_y_x.logliklihood(y, givens=[self.x_u])
+
+            )
+
+        return U+L
+
+    # This is the model.
+
+    def training_model(self):
+
+        model = Model(input=[self.x_l, self.y_l], output=self.reconstruct_x_l)
+
+        return model
````