質問をすることでしか得られない、回答やアドバイスがある。

15分調べてもわからないことは、質問しよう!

新規登録して質問してみよう
ただいま回答率
85.50%
Python 3.x

Python 3はPythonプログラミング言語の最新バージョンであり、2008年12月3日にリリースされました。

機械学習

機械学習は、データからパターンを自動的に発見し、そこから知能的な判断を下すためのコンピューターアルゴリズムを指します。人工知能における課題のひとつです。

Q&A

解決済

1回答

2367閲覧

学習済みモデルで推定するプログラムが意図した動きをしない(sony neural network console)

meJ15

総合スコア55

Python 3.x

Python 3はPythonプログラミング言語の最新バージョンであり、2008年12月3日にリリースされました。

機械学習

機械学習は、データからパターンを自動的に発見し、そこから知能的な判断を下すためのコンピューターアルゴリズムを指します。人工知能における課題のひとつです。

0グッド

0クリップ

投稿2018/10/01 07:30

編集2018/10/02 05:08

前提・実現したいこと

このページを参考にして
参考URL
sony neural network console を使って学習したパラメータを使い
新しく入力したデータのラベルを判別するプログラムを作っています。

学習済みパラメータで新しい入力データのラベルを判別するプログラムを書いて実行すると毎回同じ結果しか出ません。(違う入力データで試しても)
今の場合[0.48...]しかでません。

データはこのようなかんじでラベル1には波上の信号が入っています。ラベル0には波上の信号は入っていません。
イメージ説明

CSVファイルはこのような感じで50行2列です。
イメージ説明

出力値

2018-10-01 16:12:58,584 [nnabla][INFO]: DataSource with shuffle(False) 2018-10-01 16:12:58,589 [nnabla][INFO]: Using DataSourceWithFileCache 2018-10-01 16:12:58,591 [nnabla][INFO]: DataSource with shuffle(False) 2018-10-01 16:12:58,592 [nnabla][INFO]: Cache Directory is None 2018-10-01 16:12:58,593 [nnabla][INFO]: Cache size is 100 2018-10-01 16:12:58,595 [nnabla][INFO]: Num of thread is 10 2018-10-01 16:12:58,596 [nnabla][INFO]: Cache file format is .npy 2018-10-01 16:12:58,597 [nnabla][INFO]: Tempdir for cache C:\Users\jump1268\AppData\Local\Temp\tmpphgu65ee created. 2018-10-01 16:12:58,612 [nnabla][INFO]: Creating cache file C:\Users\jump1268\AppData\Local\Temp\tmpphgu65ee\cache_00000000_00000048.npy 2018-10-01 16:12:58,620 [nnabla][INFO]: Using DataSourceWithMemoryCache 2018-10-01 16:12:58,622 [nnabla][INFO]: DataSource with shuffle(False) 2018-10-01 16:12:58,625 [nnabla][INFO]: On-memory 2018-10-01 16:12:58,627 [nnabla][INFO]: Using DataIterator 2018-10-01 16:12:58,668 [nnabla][INFO]: Parameter load (<built-in function format>): C:\Users\jump1268\sensorcapture\manycapture\test2.files\best\parameters.h5 [[0.4805599]]

該当のソースコード

python

1import nnabla as nn 2import nnabla.functions as F 3import nnabla.parametric_functions as PF 4import nnabla.solvers as S 5from nnabla.utils.data_iterator import data_iterator_csv_dataset 6 7 8 9def network(x, y, test=False): 10 # Input:x -> 50,2 11 # Dropout 12 if not test: 13 h = F.dropout(x, 0.6736167884328452) 14 else: 15 h = x 16 # AveragePooling -> 25,1 17 h = F.average_pooling(h, (2,2), (2,2)) 18 # Affine -> 66 19 h = PF.affine(h, (66,), name='Affine') 20 # BatchNormalization 21 h = PF.batch_normalization(h, (1,), 0.9, 0.0001, not test, name='BatchNormalization') 22 # Tanh 23 h = F.tanh(h) 24 # Affine_2 -> 1 25 h = PF.affine(h, (1,), name='Affine_2') 26 # Sigmoid 27 h = F.sigmoid(h) 28 # BinaryCrossEntropy 29 #h = F.binary_cross_entropy(h, y) 30 return h 31 32 33#テスト用データの読み込み("C:\Users\jump1268\sensorcapture\manycapture\label1-28.csv")でもできた 34test_data = data_iterator_csv_dataset(r"C:\Users\jump1268\sensorcapture\manycapture\label1-15.csv",1,shuffle=False) 35 36 37 38 39#ネットワークの構築 40nn.clear_parameters() 41# Prepare input variable 42x = nn.Variable((1,50,2)) 43t = nn.Variable((1,1)) 44# Build network for inference 45y = network(x, t) 46 47# load parameters 48nn.load_parameters(r'C:\Users\jump1268\sensorcapture\manycapture\test2.files\best\parameters.h5') 49 50 51# Let input data to x.d 52# x.d = ... 53#このxにセンサ入力を入れる。 54x.d, t.d = test_data.next() 55 56# Execute inference 57y.forward() 58print(y.d)

試したこと

入力するcsvファイルを
test_data = data_iterator_csv_dataset(r"C:\Users\jump1268\sensorcapture\manycapture\label1-15.csv",1,shuffle=False)
ここでいろいろ変えても出力値が全く同じでした。

x = nn.Variable((1,50,2))
t = nn.Variable((1,1))
この最初の引数を1ではなく5とかにすると
[[0.328569 ]
[0.30326033]
[0.26028016]
[0.52961046]
[0.89365166]]
このような出力になります。

イメージ説明

学習がこのようにうまくいっているので、
せめてその学習に利用したcsvファイルを入力データとして入れればラベルが1か0で判別できると考えましたが、そのCSVファイルを入力しても変わりませんでした。
なにかが根本的に間違ってる気がするのですが、それがとこかはわからないです。どこを変えれば良いのでしょうか?

補足情報(FW/ツールのバージョンなど)

評価は次のようにほぼ確実にラベル判別できています。
イメージ説明

ここにより詳細な情報を記載してください。

気になる質問をクリップする

クリップした質問は、後からいつでもMYページで確認できます。

またクリップした質問に回答があった際、通知やメールを受け取ることができます。

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

guest

回答1

0

ベストアンサー

こういうケースでよくあるのが、学習するときに行っていた正規化等の前処理を行っていないでそのままネットワークに流しているというケースがありますが、その点は大丈夫でしょうか?

投稿2018/10/01 09:07

tiitoi

総合スコア21956

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

meJ15

2018/10/02 05:11

正規化して学習したパラメータを読み込み、推論するcsvファイルは正規化していなかったのでうまくいっていなかったようです。 新しく正規化せずに学習したパラメータで csvファイルを読み込むとラベル0でも[[0.98550305]]のようになってしまいました。本来なら0.053のように限りなく0に近づくはず? 学習自体はうまくいっているはずなのに(補足の写真です) 推論結果がうまくいきません。 何がおかしいのでしょうか?
guest

あなたの回答

tips

太字

斜体

打ち消し線

見出し

引用テキストの挿入

コードの挿入

リンクの挿入

リストの挿入

番号リストの挿入

表の挿入

水平線の挿入

プレビュー

15分調べてもわからないことは
teratailで質問しよう!

ただいまの回答率
85.50%

質問をまとめることで
思考を整理して素早く解決

テンプレート機能で
簡単に質問をまとめる

質問する

関連した質問