質問編集履歴

1

コード追加

2018/05/09 04:17

投稿

tanshoko
tanshoko

スコア9

test CHANGED
File without changes
test CHANGED
@@ -9,3 +9,211 @@
9
9
  プログラムの例などがあると、幸いです。
10
10
 
11
11
  よろしくお願いします。
12
+
13
+
14
+
15
+ ```python
16
+
17
+
18
+
19
+ from __future__ import print_function
20
+
21
+
22
+
23
+ import numpy as np
24
+
25
+ import matplotlib.pyplot as plt
26
+
27
+ from scipy.stats import norm
28
+
29
+ import time
30
+
31
+ from collections import Counter
32
+
33
+
34
+
35
+ from keras.layers import Input, Dense, Lambda, Concatenate
36
+
37
+ from keras.models import Model
38
+
39
+ from keras import backend as K
40
+
41
+ from keras import metrics
42
+
43
+ from keras.datasets import mnist
44
+
45
+ from keras import utils
46
+
47
+
48
+
49
# --- Training schedule -------------------------------------------------------
batch_size = 100
epochs = 25

# --- Network sizes -----------------------------------------------------------
original_dim = 784        # length of one flattened input image
intermediate_dim = 256    # width of the hidden ReLU layers
latent_dim = 10           # size of the Gaussian latent code z
cat_dim = 10              # number of classes, i.e. width of the one-hot label y

# --- Sampling ----------------------------------------------------------------
epsilon_std = 1.0         # std of the noise used by the reparameterisation trick
62
+
63
+
64
+
65
# Encoder: maps a flattened image x to the mean and log-variance of a
# diagonal Gaussian over the latent code z.
# NOTE(review): in the M2 model (Kingma et al., 2014) the encoder is usually
# q(z|x, y); here it sees only x — confirm this is intended.
x = Input(shape=(original_dim,))
h = Dense(intermediate_dim, activation='relu')(x)
z_mean = Dense(latent_dim)(h)
# Predict the log-variance (unconstrained); sampling() turns it into a std
# via exp(z_log_var / 2).
z_log_var = Dense(latent_dim)(h)
72
+
73
+
74
+
75
def sampling(args):
    """Draw z from N(z_mean, exp(z_log_var)) with the reparameterisation trick.

    Used as the function of a ``Lambda`` layer, so the sample remains
    differentiable with respect to ``z_mean`` and ``z_log_var``.

    Args:
        args: pair ``(z_mean, z_log_var)``; both are assumed to be 2-D
            tensors of identical shape ``(batch, dim)``.

    Returns:
        Tensor shaped like ``z_mean``:
        ``z_mean + exp(z_log_var / 2) * eps``, where ``eps`` is Gaussian
        noise with mean 0 and standard deviation ``epsilon_std``.
    """
    z_mean, z_log_var = args
    # Derive the noise shape from z_mean itself rather than the global
    # latent_dim, so the helper stays correct if reused for another latent size.
    batch = K.shape(z_mean)[0]      # dynamic: unknown while building the graph
    dim = K.int_shape(z_mean)[1]    # static feature dimension
    epsilon = K.random_normal(shape=(batch, dim), mean=0., stddev=epsilon_std)
    # exp(log_var / 2) is the standard deviation of q(z|x).
    return z_mean + K.exp(z_log_var / 2) * epsilon
84
+
85
+
86
+
87
def vae_loss(x, x_decoded_mean):
    """Return the VAE loss (negative ELBO) averaged over the batch.

    The loss is the sum of two terms:
    - Reconstruction: binary cross-entropy between the input and its
      reconstruction, scaled by ``original_dim``.
    - Regulariser: KL divergence between q(z|x) and a standard normal prior.

    Note: the KL term reads the module-level encoder tensors ``z_mean`` and
    ``z_log_var`` rather than receiving them as arguments, so this loss is only
    meaningful for models built on that encoder.

    Args:
        x: target images (what Keras passes as ``y_true``).
        x_decoded_mean: reconstructed images (``y_pred``).

    Returns:
        Scalar loss tensor.
    """
    # Both tensors are flattened across the whole batch before the
    # cross-entropy is taken.
    target_flat = K.flatten(x)
    recon_flat = K.flatten(x_decoded_mean)
    reconstruction_term = original_dim * metrics.binary_crossentropy(target_flat, recon_flat)

    # Closed-form KL(N(mu, sigma^2) || N(0, 1)), summed over latent dimensions.
    kl_term = - 0.5 * K.sum(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var), axis=-1)

    per_sample = reconstruction_term + kl_term
    return K.mean(per_sample)
98
+
99
+
100
+
101
# Draw the latent code z from the encoder outputs via the reparameterisation trick.
z = Lambda(sampling, output_shape=(latent_dim,))([z_mean, z_log_var])

# Input for the one-hot labels of the labeled data.
ly = Input(shape=(cat_dim,))
# The decoder is conditioned on both the latent code and the label.
merge = Concatenate()([z, ly])

# Decoder layers are created once and kept in variables so that the
# unlabeled model can call them again and share their weights.
decoder_h = Dense(intermediate_dim, activation='relu')
# Sigmoid keeps reconstructed pixels in (0, 1), matching the binary
# cross-entropy used in vae_loss.
decoder_mean = Dense(original_dim, activation='sigmoid')
h_decoded = decoder_h(merge)
x_decoded_mean = decoder_mean(h_decoded)

# Model for labeled data: inputs are the image and its one-hot label;
# the output is the reconstructed image.
# NOTE(review): M2 normally also trains a classifier q(y|x) on labeled data
# (an extra alpha * cross-entropy term); vae_loss has no such term — TODO confirm.
labeled_M2 = Model([x,ly], x_decoded_mean)
labeled_M2.compile(optimizer='rmsprop', loss=vae_loss)
128
+
129
+
130
+
131
+
132
+
133
# Data preparation: load MNIST, scale pixel values to [0, 1] and flatten
# each image into a vector of length original_dim.
(x_train, y_train), (x_test, y_test) = mnist.load_data()

x_train = x_train.astype('float32') / 255.
x_test = x_test.astype('float32') / 255.
x_train = x_train.reshape((len(x_train), np.prod(x_train.shape[1:])))
x_test = x_test.reshape((len(x_test), np.prod(x_test.shape[1:])))

# Split the MNIST *test* set into:
#   x_test / y_test         -> 100 randomly chosen labeled examples
#   x_realtest / y_realtest -> the remaining examples, kept for evaluation
# A single shared permutation keeps images and labels aligned. This avoids
# re-seeding the global RNG so that two independent shuffles happen to match,
# and it does not depend on the test set having exactly 10000 rows.
n_labeled = 100
perm = np.random.permutation(len(x_test))
x_test, y_test = x_test[perm], y_test[perm]

x_realtest = x_test[n_labeled:]  # evaluation images
y_realtest = y_test[n_labeled:]  # evaluation labels
x_test = x_test[:n_labeled]      # labeled training images
y_test = y_test[:n_labeled]      # labeled training labels
172
+
173
+
174
+
175
# One-hot encode the labels of the unsupervised set (y_train), the labeled
# set (y_test) and the evaluation set (y_realtest).
# num_classes is fixed to cat_dim: with only 100 labeled samples a digit may
# be missing, and to_categorical would otherwise infer fewer columns than the
# label Input(shape=(cat_dim,)) expects.
y_train_cat = utils.to_categorical(y_train, num_classes=cat_dim)
y_test_cat = utils.to_categorical(y_test, num_classes=cat_dim)
y_realtest_cat = utils.to_categorical(y_realtest, num_classes=cat_dim)

# Train the labeled model on the labeled set only. x_train is the
# unsupervised set, so its labels must not be used here.
# The model is compiled with a loss, so Keras needs an explicit target: the
# VAE reconstructs its own input, so the target is x_test itself.
labeled_M2.fit([x_test, y_test_cat], x_test,
               shuffle=True,
               epochs=epochs,
               batch_size=batch_size)
192
+
193
+
194
+
195
+
196
+
197
# Model for unlabeled data: the label y is not observed, so it is inferred
# from x by a classifier q(y|x) instead of being fed in as a model input.
# Keras only accepts tensors produced by Input() as Model inputs, so the
# model's sole input is x.
uy_h = Dense(intermediate_dim, activation='relu')(x)
# Softmax over cat_dim classes: a soft label with the same width as the
# one-hot `ly`. This keeps the concatenation the same size as in the labeled
# model (latent_dim + cat_dim), which the shared decoder_h was built for.
uy = Dense(cat_dim, activation='softmax')(uy_h)
merge = Concatenate()([z, uy])
# Reuse the labeled model's decoder layers (shared weights).
h_decoded = decoder_h(merge)
x_decoded_mean = decoder_mean(h_decoded)

unlabeled_M2 = Model(x, x_decoded_mean)
# NOTE(review): the full M2 unlabeled objective also marginalises over y
# under q(y|x) and adds the entropy H[q(y|x)]; vae_loss implements only the
# reconstruction + KL terms — TODO extend.
unlabeled_M2.compile(optimizer='rmsprop', loss=vae_loss)

labeled_M2.summary()
unlabeled_M2.summary()
218
+
219
+ ```