python3
#!/usr/bin/env python

import six

import chainer
import chainer.functions as F
from chainer.functions.loss.vae import gaussian_kl_divergence
import chainer.links as L


class VAE(chainer.Chain):
    """Convolutional Variational AutoEncoder.

    The encoder applies three stride-2 convolutions (kernel 4, pad 1), each
    halving the spatial size, so input height and width must be divisible
    by 8.  The decoder mirrors it with a linear layer followed by three
    stride-2 deconvolutions, so reconstructions have the input's shape.

    Args:
        n_ch (int): Number of channels of the input (and reconstructed)
            images.
        n_latent (int): Dimensionality of the latent vector ``z``.
        n_first (int): Channels produced by the first conv layer; deeper
            layers use ``2 * n_first`` and ``4 * n_first``.
        img_size (int or tuple of int): Input image size, either one int for
            square images or ``(height, width)``.  Both sides must be
            divisible by 8.  Defaults to 32.

    Raises:
        ValueError: If a side of ``img_size`` is not divisible by 8.
    """

    def __init__(self, n_ch, n_latent, n_first, img_size=32):
        super(VAE, self).__init__()
        try:
            img_h, img_w = img_size
        except TypeError:  # a single number means a square image
            img_h = img_w = img_size
        if img_h % 8 or img_w % 8:
            # Otherwise the conv stack floors the size and the deconv stack
            # cannot restore it, so the reconstruction shape would not match.
            raise ValueError(
                'img_size must be divisible by 8, got {}x{}'.format(
                    img_h, img_w))
        # (channels, height, width) of the smallest feature map: three
        # stride-2 layers divide each spatial side by 8.
        self._bottom_shape = (n_first * 4, img_h // 8, img_w // 8)
        n_bottom = n_first * 4 * (img_h // 8) * (img_w // 8)

        with self.init_scope():
            # encoder: map the input image to the parameters of q(z|x)
            self.le1 = L.Convolution2D(n_ch, n_first, 4, 2, 1)
            self.le2 = L.Convolution2D(n_first, n_first * 2, 4, 2, 1)
            self.le3 = L.Convolution2D(n_first * 2, n_first * 4, 4, 2, 1)
            # Linear flattens the 4-D feature map, so its input size is the
            # flattened bottleneck size.
            self.le4_mu = L.Linear(n_bottom, n_latent)
            self.le4_ln_var = L.Linear(n_bottom, n_latent)
            # decoder: map z back to the bottleneck feature map, then upsample
            self.ld1 = L.Linear(n_latent, n_bottom)
            self.ld2 = L.Deconvolution2D(n_first * 4, n_first * 2, 4, 2, 1)
            self.ld3 = L.Deconvolution2D(n_first * 2, n_first, 4, 2, 1)
            # Emit n_ch channels so the reconstruction matches the input.
            self.ld4 = L.Deconvolution2D(n_first, n_ch, 4, 2, 1)

    def forward(self, x, sigmoid=False):
        """Reconstruct ``x`` deterministically through the posterior mean.

        Args:
            x: Input image batch.
            sigmoid (bool): If True, apply a sigmoid to the decoder output.

        Returns:
            The reconstruction of ``x``.
        """
        # encode() returns (mu, ln_var); decode the mean, not the tuple.
        mu, _ = self.encode(x)
        return self.decode(mu, sigmoid)

    def encode(self, x):
        """Compute the parameters of the approximate posterior q(z|x).

        Args:
            x: Input image batch.

        Returns:
            tuple: ``(mu, ln_var)``, the mean and log-variance
            ``log(sigma ** 2)`` of the diagonal Gaussian over ``z``.
        """
        h1 = F.leaky_relu(self.le1(x), slope=0.2)
        h2 = F.leaky_relu(self.le2(h1), slope=0.2)
        h3 = F.leaky_relu(self.le3(h2), slope=0.2)
        mu = self.le4_mu(h3)
        ln_var = self.le4_ln_var(h3)  # log(sigma**2)
        return mu, ln_var

    def decode(self, z, sigmoid=False):
        """Map latent vectors back to image space.

        Args:
            z: Latent batch of shape ``(batch, n_latent)``.
            sigmoid (bool): If True, squash the output with a sigmoid.

        Returns:
            Reconstructed images with ``n_ch`` channels and the spatial size
            given by ``img_size``.
        """
        h1 = F.relu(self.ld1(z))
        # ld1's output is flat; restore the bottleneck feature-map layout so
        # the deconvolutions can upsample it.
        h1 = F.reshape(h1, (-1,) + self._bottom_shape)
        h2 = F.relu(self.ld2(h1))
        h3 = F.relu(self.ld3(h2))
        h4 = self.ld4(h3)
        if sigmoid:
            return F.sigmoid(h4)
        return h4

    def get_loss_func(self, beta=1.0, k=1):
        """Build a function that computes the (beta-)VAE loss.

        Args:
            beta (float): Weight of the KL regularization term; 1.0 gives
                the standard VAE objective.
            k (int): Number of ``z`` samples drawn per input when estimating
                the reconstruction error.

        Returns:
            callable: ``lf(x)`` returning the scalar loss for batch ``x``.
            It also stores ``self.rec_loss`` / ``self.loss`` and reports
            them via ``chainer.report``.
        """

        def lf(x):
            mu, ln_var = self.encode(x)
            batchsize = len(mu)
            # Reconstruction error: per-sample SUM of squared errors,
            # averaged over the batch and the k samples.  This keeps it on
            # the same per-sample scale as the KL term below.
            # F.mean_squared_error already averages over every element, so
            # dividing it by batchsize again would shrink this term with
            # batch size and image size.  For binary images,
            # F.bernoulli_nll(x, decoded) is the usual per-sample-sum
            # alternative.
            rec_loss = 0
            for _ in six.moves.range(k):
                z = F.gaussian(mu, ln_var)
                rec_loss += F.sum(F.squared_difference(
                    x, self.decode(z, sigmoid=False))) / (k * batchsize)
            self.rec_loss = rec_loss
            self.loss = self.rec_loss + \
                beta * gaussian_kl_divergence(mu, ln_var) / batchsize
            chainer.report(
                {'rec_loss': rec_loss, 'loss': self.loss}, observer=self)
            return self.loss
        return lf
あなたの回答
tips
プレビュー