質問をすることでしか得られない、回答やアドバイスがある。

15分調べてもわからないことは、質問しよう!

ただいまの
回答率

90.86%

  • Python 3.x

    4806questions

    Python 3はPythonプログラミング言語の最新バージョンであり、2008年12月3日にリリースされました。

Autoencoderで中間層抽出

解決済

回答 1

投稿

  • 評価
  • クリップ 1
  • VIEW 133

atena

score 10

testの出力が1しか出ないのですがプログラムのどこを直せばよいかわかりません。
参考コード
http://blog.yusugomori.com/post/42244843471/python%E3%81%AB%E3%82%88%E3%82%8Bdeep-learning%E3%81%AE%E5%AE%9F%E8%A3%85denoising-autoencoders

import sys
import numpy
import csv


numpy.seterr(all='ignore')

def sigmoid(x):
    return 1 / (1 + numpy.exp(-x))


class dA(object):
    # パラメータの初期化
    def __init__(self, input=None, n_visible=5, n_hidden=1, \
        W=None, hbias=None, vbias=None, numpy_rng=None):

        self.n_visible = n_visible  # 入力層のユニット数
        self.n_hidden = n_hidden    # 中間層のユニット数

        if numpy_rng is None:
            numpy_rng = numpy.random.RandomState(1234)

        if W is None:
            a = 1. / n_visible
            initial_W = numpy.array(numpy_rng.uniform(  # 均一にW(重み)を初期化する
                low=-a,
                high=a,
                size=(n_visible, n_hidden)))

            W = initial_W
            print(W)

        if hbias is None:
            hbias = numpy.zeros(n_hidden)  # 中間層のバイアスを0で初期化する

        if vbias is None:
            vbias = numpy.zeros(n_visible)  # 中間層のバイアスを0で初期化する

        self.numpy_rng = numpy_rng
        self.x = input
        self.W = W
        self.W_prime = self.W.T
        self.hbias = hbias
        self.vbias = vbias

        # self.params = [self.W, self.hbias, self.vbias]


    def get_corrupted_input(self, input, corruption_level):
        assert corruption_level < 1

#        print (self.numpy_rng.binomial(size=input.shape,n=1,p=1-corruption_level))
        return self.numpy_rng.binomial(size=input.shape,
                                       n=1,
                                       p=1-corruption_level) * input

    # Encode
    def get_hidden_values(self, input):
#        print (sigmoid(numpy.dot(input, self.W) + self.hbias))
        return sigmoid(numpy.dot(input, self.W) + self.hbias)  #シグモイド関数(入力*重みの総和)

    # Decode
    def get_reconstructed_input(self, hidden):
        return sigmoid(numpy.dot(hidden, self.W_prime) + self.vbias)  #シグモイド関数(中間*重みの総和)


    def train(self, lr=0.1, corruption_level=0.3, input=None):
        if input is not None:
            self.x = input

        x = self.x
        tilde_x = self.get_corrupted_input(x, corruption_level)
        y = self.get_hidden_values(tilde_x)
        z = self.get_reconstructed_input(y)

        L_h2 = x - z
        L_h1 = numpy.dot(L_h2, self.W) * y * (1 - y)

        L_vbias = L_h2
        L_hbias = L_h1
        L_W =  numpy.dot(tilde_x.T, L_h1) + numpy.dot(L_h2.T, y)


        self.W += lr * L_W
        self.hbias += lr * numpy.mean(L_hbias, axis=0)
        self.vbias += lr * numpy.mean(L_vbias, axis=0)

        print(L_W)



    def negative_log_likelihood(self, corruption_level=0.3):
        tilde_x = self.get_corrupted_input(self.x, corruption_level)
        y = self.get_hidden_values(tilde_x)
        z = self.get_reconstructed_input(y)

        cross_entropy = - numpy.mean(
            numpy.sum(self.x * numpy.log(z) +
            (1 - self.x) * numpy.log(1 - z),
                      axis=1))

        return cross_entropy


    def reconstruct(self, x):
        y = self.get_hidden_values(x)
        z = self.get_reconstructed_input(y)
        return y



def test_dA(learning_rate=0.01, corruption_level=0.3, training_epochs=100):
#    data_c = numpy.array([])
#    with open('train.csv','r') as f:
#        reader = csv.reader(f)
#        for row in reader:
#            print(row)
#            rowf = numpy.array(row)
#            rowf = rowf.astype(numpy.float32)
#            print(rowf)
#            data_c = numpy.append(data_c, rowf)
    data = numpy.loadtxt('train.csv', delimiter=',', dtype='int')
    rng = numpy.random.RandomState(123)

    print(data)

    # construct dA
    da = dA(input=data, n_visible=5, n_hidden=3, numpy_rng=rng)

    # train
    for epoch in range(training_epochs):
        da.train(lr=learning_rate, corruption_level=corruption_level)
        # cost = da.negative_log_likelihood(corruption_level=corruption_level)
        # print >> sys.stderr, 'Training epoch %d, cost is ' % epoch, cost
        # learning_rate *= 0.95


    # test
#    x = numpy.array([])
#    with open('test.csv','r') as f:
#        reader = csv.reader(f)
#        for row in reader:
#            x = numpy.append(x, row)
#    x = x.astype(numpy.float32)
    x = numpy.loadtxt('test.csv', delimiter=',', dtype='int')

    print (da.reconstruct(x))



if __name__ == "__main__":
    test_dA()
[[ 753 1439 2191 3041  185]
 [ 692 2191 3205 4142  286]
 [ 330  633 2069 3205  167]
 ...,
 [ 523  633 2450 3743  275]
 [ 577  692 1101 3377  292]
 [ 633 1349 2450 3556  220]]
[[ 0.07858767 -0.08554427 -0.10925942]
 [ 0.02052591  0.08778759 -0.03075742]
 [ 0.19230568  0.0739319  -0.00762724]
 [-0.04315299 -0.06272879  0.09161988]
 [-0.0245711  -0.17612884 -0.0407823 ]]
########################
重み割愛
########################
[[ 1.  1.  1.]
 [ 1.  1.  1.]
 [ 1.  1.  1.]
 [ 1.  1.  1.]
 [ 1.  1.  1.]
 [ 1.  1.  1.]
 [ 1.  1.  1.]
 [ 1.  1.  1.]
 [ 1.  1.  1.]
 [ 1.  1.  1.]
 [ 1.  1.  1.]
 [ 1.  1.  1.]
 [ 1.  1.  1.]
 [ 1.  1.  1.]
 [ 1.  1.  1.]
 [ 1.  1.  1.]
 [ 1.  1.  1.]
 [ 1.  1.  1.]
 [ 1.  1.  1.]
 [ 1.  1.  1.]
 [ 1.  1.  1.]
 [ 1.  1.  1.]
 [ 1.  1.  1.]
 [ 1.  1.  1.]
 [ 1.  1.  1.]
 [ 1.  1.  1.]
 [ 1.  1.  1.]
 [ 1.  1.  1.]
 [ 1.  1.  1.]
 [ 1.  1.  1.]
 [ 1.  1.  1.]
 [ 1.  1.  1.]
 [ 1.  1.  1.]
 [ 1.  1.  1.]
 [ 1.  1.  1.]
 [ 1.  1.  1.]
 [ 1.  1.  1.]
 [ 1.  1.  1.]
 [ 1.  1.  1.]
 [ 1.  1.  1.]
 [ 1.  1.  1.]
 [ 1.  1.  1.]
 [ 1.  1.  1.]
 [ 1.  1.  1.]
 [ 1.  1.  1.]
 [ 1.  1.  1.]
 [ 1.  1.  1.]
 [ 1.  1.  1.]
 [ 1.  1.  1.]
 [ 1.  1.  1.]]


テストデータ

[[ 753 1734 2733 4142  280]
 [ 692 1181 2884 4354  156]
 [ 817 1439 3556 4354  140]
 [ 817 1026 1533 3743  134]
 [ 817 1734 2588 3938  286]
 [ 817 1631 3377 4354  253]
 [ 692 1734 2588 4576  227]
 [ 692 1631 2733 3938  128]
 [ 817 1631 2450 3938  180]
 [ 817 1841 2588 3556  136]
 [ 883 1181 2588 3938  220]
 [ 692 1181 2884 4354  279]
 [ 692  883 1263 4142  213]
 [ 953 1263 2884 4142  191]
 [ 883 1263 3041 3938  175]
 [ 883 1263 1631 3556  184]
 [ 817 1181 3041 3938  273]
 [ 633  883 1263 3938  106]
 [ 817 1181 3041 3743  252]
 [ 817 1263 3041 3743  189]
 [ 817 1263 3205 3743  200]
 [ 692 1439 2884 4354  247]
 [ 753 1263 3205 4142  129]
 [ 692 1181 2884 3938  107]
 [ 883 1181 2884 3743  288]
 [2069 2317 3041 3556  267]
 [ 692 1181 2733 4142  153]
 [ 692 1263 2733 3938  132]
 [ 817 1263 2884 3938  209]
 [ 817 1181 3205 3743  268]
 [ 472  883 1349 3938  249]
 [ 692 1439 2884 4354  247]
 [ 817 1263 2588 3938  204]
 [ 287 1263 2588 4142  294]
 [ 817 1349 2884 4142  212]
 [1439 1841 2884 3377  286]
 [ 692 1263 2884 3938  209]
 [ 753  953 1349 4142  168]
 [ 883 1263 2884 3556  270]
 [ 692 1631 3041 3938  290]
 [ 523 1734 2884 3556  204]
 [ 692 1439 2884 4354  247]
 [ 633 1263 1952 4142  198]
 [ 692 1263 2588 4142  232]
 [ 692 1631 2733 3743  127]
 [1631 2450 3041 3743  238]
 [ 692 1263 2733 4142  228]
 [ 633 1263 2588 4354  193]
 [ 817 1439 2884 3938  277]
 [ 692 1533 2733 3938  221]]
  • 気になる質問をクリップする

    クリップした質問は、後からいつでもマイページで確認できます。

    またクリップした質問に回答があった際、通知やメールを受け取ることができます。

    クリップを取り消します

  • 良い質問の評価を上げる

    以下のような質問は評価を上げましょう

    • 質問内容が明確
    • 自分も答えを知りたい
    • 質問者以外のユーザにも役立つ

    評価が高い質問は、TOPページの「注目」タブのフィードに表示されやすくなります。

    質問の評価を上げたことを取り消します

  • 評価を下げられる数の上限に達しました

    評価を下げることができません

    • 1日5回まで評価を下げられます
    • 1日に1ユーザに対して2回まで評価を下げられます

    質問の評価を下げる

    teratailでは下記のような質問を「具体的に困っていることがない質問」、「サイトポリシーに違反する質問」と定義し、推奨していません。

    • プログラミングに関係のない質問
    • やってほしいことだけを記載した丸投げの質問
    • 問題・課題が含まれていない質問
    • 意図的に内容が抹消された質問
    • 広告と受け取られるような投稿

    評価が下がると、TOPページの「アクティブ」「注目」タブのフィードに表示されにくくなります。

    質問の評価を下げたことを取り消します

    この機能は開放されていません

    評価を下げる条件を満たしてません

    評価を下げる理由を選択してください

    詳細な説明はこちら

    上記に当てはまらず、質問内容が明確になっていない質問には「情報の追加・修正依頼」機能からコメントをしてください。

    質問の評価を下げる機能の利用条件

    この機能を利用するためには、以下の事項を行う必要があります。

回答 1

checkベストアンサー

+1

うまく学習できていないだけです。
データの数字が大きすぎるので重みが変に大きくなった結果、シグモイド関数が1になります。

データの最大値などを決めてやって、それを割ってやるとうまくいくはずです。
正規化・標準化とかって言うやつです。

投稿

  • 回答の評価を上げる

    以下のような回答は評価を上げましょう

    • 正しい回答
    • わかりやすい回答
    • ためになる回答

    評価が高い回答ほどページの上位に表示されます。

  • 回答の評価を下げる

    下記のような回答は推奨されていません。

    • 間違っている回答
    • 質問の回答になっていない投稿
    • スパムや攻撃的な表現を用いた投稿

    評価を下げる際はその理由を明確に伝え、適切な回答に修正してもらいましょう。

15分調べてもわからないことは、teratailで質問しよう!

  • ただいまの回答率 90.86%
  • 質問をまとめることで、思考を整理して素早く解決
  • テンプレート機能で、簡単に質問をまとめられる

関連した質問

同じタグがついた質問を見る

  • Python 3.x

    4806questions

    Python 3はPythonプログラミング言語の最新バージョンであり、2008年12月3日にリリースされました。