質問をすることでしか得られない、回答やアドバイスがある。

15分調べてもわからないことは、質問しよう!

新規登録して質問してみよう
ただいま回答率
85.48%
Python 3.x

Python 3はPythonプログラミング言語の最新バージョンであり、2008年12月3日にリリースされました。

Q&A

解決済

1回答

1371閲覧

Autoencoderで中間層抽出

atena

総合スコア20

Python 3.x

Python 3はPythonプログラミング言語の最新バージョンであり、2008年12月3日にリリースされました。

0グッド

1クリップ

投稿2017/12/29 02:04

testの出力が1しか出ないのですがプログラムのどこを直せばよいかわかりません。
参考コード
http://blog.yusugomori.com/post/42244843471/python%E3%81%AB%E3%82%88%E3%82%8Bdeep-learning%E3%81%AE%E5%AE%9F%E8%A3%85denoising-autoencoders

import sys import numpy import csv numpy.seterr(all='ignore') def sigmoid(x): return 1 / (1 + numpy.exp(-x)) class dA(object): # パラメータの初期化 def __init__(self, input=None, n_visible=5, n_hidden=1, \ W=None, hbias=None, vbias=None, numpy_rng=None): self.n_visible = n_visible # 入力層のユニット数 self.n_hidden = n_hidden # 中間層のユニット数 if numpy_rng is None: numpy_rng = numpy.random.RandomState(1234) if W is None: a = 1. / n_visible initial_W = numpy.array(numpy_rng.uniform( # 均一にW(重み)を初期化する low=-a, high=a, size=(n_visible, n_hidden))) W = initial_W print(W) if hbias is None: hbias = numpy.zeros(n_hidden) # 中間層のバイアスを0で初期化する if vbias is None: vbias = numpy.zeros(n_visible) # 中間層のバイアスを0で初期化する self.numpy_rng = numpy_rng self.x = input self.W = W self.W_prime = self.W.T self.hbias = hbias self.vbias = vbias # self.params = [self.W, self.hbias, self.vbias] def get_corrupted_input(self, input, corruption_level): assert corruption_level < 1 # print (self.numpy_rng.binomial(size=input.shape,n=1,p=1-corruption_level)) return self.numpy_rng.binomial(size=input.shape, n=1, p=1-corruption_level) * input # Encode def get_hidden_values(self, input): # print (sigmoid(numpy.dot(input, self.W) + self.hbias)) return sigmoid(numpy.dot(input, self.W) + self.hbias) #シグモイド関数(入力*重みの総和) # Decode def get_reconstructed_input(self, hidden): return sigmoid(numpy.dot(hidden, self.W_prime) + self.vbias) #シグモイド関数(中間*重みの総和) def train(self, lr=0.1, corruption_level=0.3, input=None): if input is not None: self.x = input x = self.x tilde_x = self.get_corrupted_input(x, corruption_level) y = self.get_hidden_values(tilde_x) z = self.get_reconstructed_input(y) L_h2 = x - z L_h1 = numpy.dot(L_h2, self.W) * y * (1 - y) L_vbias = L_h2 L_hbias = L_h1 L_W = numpy.dot(tilde_x.T, L_h1) + numpy.dot(L_h2.T, y) self.W += lr * L_W self.hbias += lr * numpy.mean(L_hbias, axis=0) self.vbias += lr * numpy.mean(L_vbias, axis=0) print(L_W) def negative_log_likelihood(self, corruption_level=0.3): tilde_x = self.get_corrupted_input(self.x, corruption_level) y = self.get_hidden_values(tilde_x) z = self.get_reconstructed_input(y) cross_entropy = - numpy.mean( numpy.sum(self.x * numpy.log(z) + (1 - self.x) * numpy.log(1 - z), axis=1)) return cross_entropy def reconstruct(self, x): y = self.get_hidden_values(x) z = self.get_reconstructed_input(y) return y def test_dA(learning_rate=0.01, corruption_level=0.3, training_epochs=100): # data_c = numpy.array([]) # with open('train.csv','r') as f: # reader = csv.reader(f) # for row in reader: # print(row) # rowf = numpy.array(row) # rowf = rowf.astype(numpy.float32) # print(rowf) # data_c = numpy.append(data_c, rowf) data = numpy.loadtxt('train.csv', delimiter=',', dtype='int') rng = numpy.random.RandomState(123) print(data) # construct dA da = dA(input=data, n_visible=5, n_hidden=3, numpy_rng=rng) # train for epoch in range(training_epochs): da.train(lr=learning_rate, corruption_level=corruption_level) # cost = da.negative_log_likelihood(corruption_level=corruption_level) # print >> sys.stderr, 'Training epoch %d, cost is ' % epoch, cost # learning_rate *= 0.95 # test # x = numpy.array([]) # with open('test.csv','r') as f: # reader = csv.reader(f) # for row in reader: # x = numpy.append(x, row) # x = x.astype(numpy.float32) x = numpy.loadtxt('test.csv', delimiter=',', dtype='int') print (da.reconstruct(x)) if __name__ == "__main__": test_dA()
[[ 753 1439 2191 3041 185] [ 692 2191 3205 4142 286] [ 330 633 2069 3205 167] ..., [ 523 633 2450 3743 275] [ 577 692 1101 3377 292] [ 633 1349 2450 3556 220]] [[ 0.07858767 -0.08554427 -0.10925942] [ 0.02052591 0.08778759 -0.03075742] [ 0.19230568 0.0739319 -0.00762724] [-0.04315299 -0.06272879 0.09161988] [-0.0245711 -0.17612884 -0.0407823 ]] ######################## 重み割愛 ######################## [[ 1. 1. 1.] [ 1. 1. 1.] [ 1. 1. 1.] [ 1. 1. 1.] [ 1. 1. 1.] [ 1. 1. 1.] [ 1. 1. 1.] [ 1. 1. 1.] [ 1. 1. 1.] [ 1. 1. 1.] [ 1. 1. 1.] [ 1. 1. 1.] [ 1. 1. 1.] [ 1. 1. 1.] [ 1. 1. 1.] [ 1. 1. 1.] [ 1. 1. 1.] [ 1. 1. 1.] [ 1. 1. 1.] [ 1. 1. 1.] [ 1. 1. 1.] [ 1. 1. 1.] [ 1. 1. 1.] [ 1. 1. 1.] [ 1. 1. 1.] [ 1. 1. 1.] [ 1. 1. 1.] [ 1. 1. 1.] [ 1. 1. 1.] [ 1. 1. 1.] [ 1. 1. 1.] [ 1. 1. 1.] [ 1. 1. 1.] [ 1. 1. 1.] [ 1. 1. 1.] [ 1. 1. 1.] [ 1. 1. 1.] [ 1. 1. 1.] [ 1. 1. 1.] [ 1. 1. 1.] [ 1. 1. 1.] [ 1. 1. 1.] [ 1. 1. 1.] [ 1. 1. 1.] [ 1. 1. 1.] [ 1. 1. 1.] [ 1. 1. 1.] [ 1. 1. 1.] [ 1. 1. 1.] [ 1. 1. 1.]]

テストデータ

[[ 753 1734 2733 4142 280] [ 692 1181 2884 4354 156] [ 817 1439 3556 4354 140] [ 817 1026 1533 3743 134] [ 817 1734 2588 3938 286] [ 817 1631 3377 4354 253] [ 692 1734 2588 4576 227] [ 692 1631 2733 3938 128] [ 817 1631 2450 3938 180] [ 817 1841 2588 3556 136] [ 883 1181 2588 3938 220] [ 692 1181 2884 4354 279] [ 692 883 1263 4142 213] [ 953 1263 2884 4142 191] [ 883 1263 3041 3938 175] [ 883 1263 1631 3556 184] [ 817 1181 3041 3938 273] [ 633 883 1263 3938 106] [ 817 1181 3041 3743 252] [ 817 1263 3041 3743 189] [ 817 1263 3205 3743 200] [ 692 1439 2884 4354 247] [ 753 1263 3205 4142 129] [ 692 1181 2884 3938 107] [ 883 1181 2884 3743 288] [2069 2317 3041 3556 267] [ 692 1181 2733 4142 153] [ 692 1263 2733 3938 132] [ 817 1263 2884 3938 209] [ 817 1181 3205 3743 268] [ 472 883 1349 3938 249] [ 692 1439 2884 4354 247] [ 817 1263 2588 3938 204] [ 287 1263 2588 4142 294] [ 817 1349 2884 4142 212] [1439 1841 2884 3377 286] [ 692 1263 2884 3938 209] [ 753 953 1349 4142 168] [ 883 1263 2884 3556 270] [ 692 1631 3041 3938 290] [ 523 1734 2884 3556 204] [ 692 1439 2884 4354 247] [ 633 1263 1952 4142 198] [ 692 1263 2588 4142 232] [ 692 1631 2733 3743 127] [1631 2450 3041 3743 238] [ 692 1263 2733 4142 228] [ 633 1263 2588 4354 193] [ 817 1439 2884 3938 277] [ 692 1533 2733 3938 221]]

気になる質問をクリップする

クリップした質問は、後からいつでもMYページで確認できます。

またクリップした質問に回答があった際、通知やメールを受け取ることができます。

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

guest

回答1

0

ベストアンサー

うまく学習できていないだけです。
データの数字が大きすぎるので重みが変に大きくなった結果、シグモイド関数が1になります。

データの最大値などを決めてやって、それを割ってやるとうまくいくはずです。
正規化・標準化とかって言うやつです。

投稿2017/12/31 13:55

mkgrei

総合スコア8560

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

あなたの回答

tips

太字

斜体

打ち消し線

見出し

引用テキストの挿入

コードの挿入

リンクの挿入

リストの挿入

番号リストの挿入

表の挿入

水平線の挿入

プレビュー

15分調べてもわからないことは
teratailで質問しよう!

ただいまの回答率
85.48%

質問をまとめることで
思考を整理して素早く解決

テンプレート機能で
簡単に質問をまとめる

質問する

関連した質問