質問をすることでしか得られない、回答やアドバイスがある。

15分調べてもわからないことは、質問しよう!

ただいまの
回答率

89.86%

VAEの誤差関数にMSEを使うと学習がうまく進まない。

解決済

回答 1

投稿 編集

  • 評価
  • クリップ 0
  • VIEW 1,762

physics303

score 79

やったこと

Chainerユーザーです。Chainerを使ってVAEを実装しました。参考にしたURLは

Variational Autoencoder徹底解説
AutoEncoder, VAE, CVAEの比較
PyTorch+Google ColabでVariational Auto Encoderをやってみた

などです。実装したコードのコアになる部分は以下の通りです。

class VAE(chainer.Chain):

    def __init__(self, n_in, n_latent, n_h, act_func=F.tanh):
        super(VAE, self).__init__()
        self.act_func = act_func
        with self.init_scope():
            # encoder
            self.le1        = L.Linear(n_in, n_h)
            self.le2        = L.Linear(n_h,  n_h)
            self.le3_mu     = L.Linear(n_h,  n_latent)
            self.le3_ln_var = L.Linear(n_h,  n_latent)

            # decoder
            self.ld1 = L.Linear(n_latent, n_h)
            self.ld2 = L.Linear(n_h,      n_h)
            self.ld3 = L.Linear(n_h,      n_in)

    def __call__(self, x, sigmoid=True):
        return self.decode(self.encode(x)[0], sigmoid)

    def encode(self, x):
        h1 = self.act_func(self.le1(x))
        h2 = self.act_func(self.le2(h1))
        mu = self.le3_mu(h2)
        ln_var = self.le3_ln_var(h2) 
        return mu, ln_var

    def decode(self, z, sigmoid=True):
        h1 = self.act_func(self.ld1(z))
        h2 = self.act_func(self.ld2(h1))
        h3 = self.ld3(h2)
        if sigmoid:
            return F.sigmoid(h3)
        else:
            return h3

    def get_loss_func(self, C=1.0, k=1):
        def lf(x):
            mu, ln_var = self.encode(x)
            batchsize = len(mu.data)
            # reconstruction error
            rec_loss = 0
            for l in six.moves.range(k):
                z = F.gaussian(mu, ln_var)
                z.name = "z"
                rec_loss += F.bernoulli_nll(x, self.decode(z, sigmoid=False)) / (k * batchsize)
            self.rec_loss = rec_loss
            self.rec_loss.name = "reconstruction error"
            self.latent_loss = C * gaussian_kl_divergence(mu, ln_var) / batchsize
            self..name = "latent loss"
            self.loss = self.rec_loss + self.latent_loss
            self.loss.name = "loss"
            return self.loss
        return lf

rec_lossは再構成誤差、すなわち入力と出力がどの程度等しいかを表していて、latent_lossの方は特徴量空間における分布が正規分布からどれくらいことなるを表す誤差だと認識しています。

MNISTで実験してみた結果、
1.lossが減少していく
2.再構成がちゃんと行われる(input画像が3なら、output画像も3になっている)
3.特徴量空間でランダムサンプリングを行った結果、ちゃんと数字が出力される。
などが確かめられました。

疑問

ところで、疑問なのですが、rec_lossは再構成誤差なので、素朴には平均二乗誤差をつかうのが自然だと思われます。そこでrec_lossの部分を

rec_loss += F.mean_squared_error(x, self.decode(z)) / k

と書き換え、ほかの条件は全部そのままで(他の部分は一切書き換えずに)、実験すると

  1. rec_lossが2epoch目以降、全く減少しない。
  2. 再構成が行われない。適当なinputを与えても、意味のないoutput画像が得られる。(ちなみにinput画像の種類によらずoutput画像は一定のようです)
    などの結果が得られて、学習がうまくいっていっていないようです。

これはなぜでしょうか。

追記1

bernoulli_nllはデフォルトではすべて合計する一方で、mean_squared_errorは二乗誤差をバッチとピクセルの両方で平均するので、再構成項が過小評価されてしまっているかもしれないと考えました。

MNSITは28×28の画像なので、MSEを用いる際には28×28×F.mean_squared_error(x, decode(z))とすれば良いと思い、試してみましたが結果は変わりませんでした。

追記2

chainerでMSEを使ったVAEの実装を行っているコードを見つけました。
(https://github.com/maguro27/VAE-CIFAR10_chainer/blob/master/VAE_CIFAR10.ipynb)
なぜ、このコードでは動いて、私の上のコードでは学習がうまくいかないのでしょうか。

追記3

optimizerの問題ではないかとの指摘を受けて、Adam,AdaDelta,SGDで試しましたが結果は変わらず…

解決策?

passerbyさんに言われた通り、

rec_loss += F.mean_squared_error(x, self.decode(z)) / k 


rec_loss += F.mean(F.sum((x - self.decode(z)) ** 2, axis=1))


に変えるとうまくいきます。でもなぜだ…?

  • 気になる質問をクリップする

    クリップした質問は、後からいつでもマイページで確認できます。

    またクリップした質問に回答があった際、通知やメールを受け取ることができます。

    クリップを取り消します

  • 良い質問の評価を上げる

    以下のような質問は評価を上げましょう

    • 質問内容が明確
    • 自分も答えを知りたい
    • 質問者以外のユーザにも役立つ

    評価が高い質問は、TOPページの「注目」タブのフィードに表示されやすくなります。

    質問の評価を上げたことを取り消します

  • 評価を下げられる数の上限に達しました

    評価を下げることができません

    • 1日5回まで評価を下げられます
    • 1日に1ユーザに対して2回まで評価を下げられます

    質問の評価を下げる

    teratailでは下記のような質問を「具体的に困っていることがない質問」、「サイトポリシーに違反する質問」と定義し、推奨していません。

    • プログラミングに関係のない質問
    • やってほしいことだけを記載した丸投げの質問
    • 問題・課題が含まれていない質問
    • 意図的に内容が抹消された質問
    • 広告と受け取られるような投稿

    評価が下がると、TOPページの「アクティブ」「注目」タブのフィードに表示されにくくなります。

    質問の評価を下げたことを取り消します

    この機能は開放されていません

    評価を下げる条件を満たしてません

    評価を下げる理由を選択してください

    詳細な説明はこちら

    上記に当てはまらず、質問内容が明確になっていない質問には「情報の追加・修正依頼」機能からコメントをしてください。

    質問の評価を下げる機能の利用条件

    この機能を利用するためには、以下の事項を行う必要があります。

質問への追記・修正、ベストアンサー選択の依頼

  • passerby

    2019/06/06 22:14

    rec_loss += F.mean(F.sum((x - self.decode(z)) ** 2, axis=1))はどうですか?

    キャンセル

  • physics303

    2019/06/07 08:55

    コメントありがとうございます。やってみます。

    キャンセル

  • physics303

    2019/06/07 09:57

    MNISTの場合についてですが、うまくいきました!!

    rec_loss += F.mean_squared_error(x, self.decode(z)) / k

    がうまくいかずに、

    rec_loss += F.mean(F.sum((x - self.decode(z)) ** 2, axis=1))

    でうまくいくのはどういう理屈なのでしょう。どちらも同じことをやっている印象なのですが…

    キャンセル

回答 1

check解決した方法

0

passerbyさんに言われた通り、

rec_loss += F.mean_squared_error(x, self.decode(z)) / k 


rec_loss += F.mean(F.sum((x - self.decode(z)) ** 2, axis=1))


に変えるとうまくいきます。でもなぜだ…?

投稿

  • 回答の評価を上げる

    以下のような回答は評価を上げましょう

    • 正しい回答
    • わかりやすい回答
    • ためになる回答

    評価が高い回答ほどページの上位に表示されます。

  • 回答の評価を下げる

    下記のような回答は推奨されていません。

    • 間違っている回答
    • 質問の回答になっていない投稿
    • スパムや攻撃的な表現を用いた投稿

    評価を下げる際はその理由を明確に伝え、適切な回答に修正してもらいましょう。

15分調べてもわからないことは、teratailで質問しよう!

  • ただいまの回答率 89.86%
  • 質問をまとめることで、思考を整理して素早く解決
  • テンプレート機能で、簡単に質問をまとめられる

同じタグがついた質問を見る