質問をすることでしか得られない、回答やアドバイスがある。

15分調べてもわからないことは、質問しよう!

新規登録して質問してみよう
ただいま回答率
85.40%
Python

Pythonは、コードの読みやすさが特徴的なプログラミング言語の1つです。 強い型付け、動的型付けに対応しており、後方互換性がないバージョン2系とバージョン3系が使用されています。 商用製品の開発にも無料で使用でき、OSだけでなく仮想環境にも対応。Unicodeによる文字列操作をサポートしているため、日本語処理も標準で可能です。

VSCodeDevContainer

VSCode Dev Containerは、VSCodeの拡張機能の一つ。Dockerコンテナ上でVSCodeの機能が使える開発環境を構築できます。開発環境の可搬性や再現性が高く、ローカル環境への影響が低い点などが特徴です。

Q&A

解決済

1回答

266閲覧

python(tensorflow)のニューラルネットワークにてテストデータの出力が0になってします

tottatoato

総合スコア1

Python

Pythonは、コードの読みやすさが特徴的なプログラミング言語の1つです。 強い型付け、動的型付けに対応しており、後方互換性がないバージョン2系とバージョン3系が使用されています。 商用製品の開発にも無料で使用でき、OSだけでなく仮想環境にも対応。Unicodeによる文字列操作をサポートしているため、日本語処理も標準で可能です。

VSCodeDevContainer

VSCode Dev Containerは、VSCodeの拡張機能の一つ。Dockerコンテナ上でVSCodeの機能が使える開発環境を構築できます。開発環境の可搬性や再現性が高く、ローカル環境への影響が低い点などが特徴です。

0グッド

0クリップ

投稿2024/06/04 04:50

実現したいこと

テストデータの値を正確に出力したい

発生している問題・分からないこと

tensorflowにてニューラルネットワークに取り組んでいるのですが、テストデータの値が0と出力されてしまいます。
この理由が、例えば0.3531など少数以下に値はあるが、出力方法が間違っているために整数で0と出力されてしまうのか、
また、そもそもコードが間違っており、学習ができていないために0と出力されるのかが分かりません

該当のソースコード

python

1import keras 2from keras.layers import Activation, Dense, Dropout 3from keras.models import Sequential 4from sklearn.preprocessing import MinMaxScaler 5from tensorflow.keras.callbacks import EarlyStopping 6from tensorflow.keras.optimizers import Adam 7import pandas as pd 8import numpy as np 9import matplotlib.pyplot as plt 10 11#学習データ 12data = pd.DataFrame( 13 [ 14 #No.1 15 [9.0, 2701, 1477, 2000, 1475, 0.4018], 16 17 #No.2 18 [8.4,1675,773,2000,1364,0.2888], 19 #No.3 20 [6.6,7777,54,2000,1176,0.1017], 21 22 #No.4 23 [7.3,9686,178,1521,2000,0.1556], 24 25 #No.5 26 [7.1,3089,1840,2000,1188,0.0491], 27 28 #No.6 29 [7.8,9441,205,1322,2000,0.2298], 30 31 #No.7 32 [6.5,4763,1242,1588,2000,0.0790], 33 34 #No.8 35 [8.9,6498,1865,1943,2000,0.3776], 36 37 #No.9 38 [8.0,2489,568,1658,2000,0.2597], 39 40 #No.10 41 [6.0,7839,823,1310,2000,0.0766], 42 43 #No.11 44 [9.0,7941,1004,1720,2000,0.4532], 45 46 #No.12 47 [7.5,1452,698,2000,219,0.1003], 48 49 #No.13 50 [6.9,6007,1246,2000,257,0.0373], 51 52 #No.14 53 [8.7,8671,1744,1999,2000,0.3207], 54 55 #No.15 56 [8.1,86554,951,2000,26,0.1521], 57 58 #No.16 59 [7.0,81375,985,1400,2000,0.0912], 60 61 #No.17 62 [8.6,32889,1274,2000,623,0.2294], 63 64 #No.18 65 [9.0,48790,1446,2000,1987,0.4439], 66 67 #No.19 68 [7.3,59028,817,832,2000,0.1194], 69 70 #No.20 71 [7.9,68181,1512,1608,2000,0.1954], 72 73 #No.21 74 [6.1,93928,401,1395,2000,0.0890], 75 76 #No.22 77 [6.2,85892,6,81,2000,0.0626], 78 79 #No.23 80 [7.9,72570,1095,2000,81,0.1240], 81 82 #No.24 83 [9.0,86521,1392,1488,2000,0.4348], 84 85 #No.25 86 [6.9,34140,1716,2000,779,0.0370], 87 88 #No.26 89 [8.7,42481,875,2000,1923,0.4056], 90 91 #No.27 92 [6.5,12369,1715,1897,2000,0.0730], 93 94 #No.28 95 [8.7,22112,264,1876,2000,0.4341], 96 97 #No.29 98 [8.4,15363,58,2000,39,0.2479], 99 100 #No.30 101 [6.3,15165,195,955,2000,0.0884], 102 103 #No.31 104 [7.3,68657,1864,1978,2000,0.0797], 105 106 #No.32 107 [9.0,96352,15,1835,2000,0.4341], 108 109 #No.33 110 [6.7,735959,331,387,2000,0.0619], 111 112 #No.34 113 [9.0,531601,555,2000,1733,0.2679], 114 115 #No.35 116 [7.9,852996,612,2000,589,0.1109], 117 118 #No.36 119 [8.9,415339,1507,1586,2000,0.2668], 120 121 #No.37 122 [8.4,499514,1756,1820,2000,0.1504], 123 124 #No.38 125 [6.4,973579,122,2000,1941,0.0892], 126 127 #No.39 128 [6.1,436622,1704,2000,1388,0.0436], 129 130 #No.40 131 [9.0,360412,864,2000,594,0.2271], 132 133 #No.41 134 [8.5,646469,73,828,2000,0.1500], 135 136 #No.42 137 [7.2,908067,1790,1820,2000,0.0688], 138 139 #No.43 140 [7.3,977526,1662,2000,1013,0.0427], 141 142 #No.44 143 [6.5,130665,904,1112,2000,0.0719], 144 145 #No.45 146 [9.0,144155,1088,2000,1069,0.3495], 147 148 #No.46 149 [7.6,229788,150,2000,1730,0.1563], 150 151 #No.47 152 [7.8,645105,84,2000,870,0.1276], 153 154 #No.48 155 [6.9,500021,995,1109,2000,0.0735], 156 157 #No.49 158 [8.5,294460,690,2000,1748,0.2675], 159 160 ], 161 columns = ['M','t','X1','X2','Y2','uSv/y'] 162) 163df = pd.DataFrame(data) 164 165#説明変数 166in_data = df.iloc[:, 0:5] 167 168#目的変数 169out_data = df['uSv/y'] 170 171#正規化 172scaler = MinMaxScaler() 173in_data = pd.DataFrame(scaler.fit_transform(in_data)) 174print(in_data) 175 176#ndarrayに変換 177in_data=np.array(in_data) 178out_data=np.array(out_data) 179 180#モデルの生成 181model = keras.models.Sequential() 182model.add(Dense(units = 64,input_dim=(5))) 183model.add(Activation('relu')) 184model.add(Dense(units = 64 )) 185model.add(Activation('relu')) 186model.add(Dense(1)) 187 188 189#モデルのコンパイル 190model.compile( 191 loss = 'mse', 192 optimizer = 'adam', 193 metrics = ['mae']) 194 195#Earlystopping 196ealy_stop = EarlyStopping(monitor='val_loss',patience=30) 197 198#学習 199history = model.fit(in_data, out_data, epochs=1000,validation_split=0.2, 200 callbacks=[ealy_stop]) 201 202#テストデータ 203test_data = np.array( 204 [ 205 #No.50 206 [7.0,5647,1000,1266,2000], 207 [6.3,75894,539,1111,2000], 208 [8.5,142747,1788,2000,242], 209 [6.8,16474,539,1266,1600], 210 [9.0,174947,539,1266,2000], 211 [7.7,75947,539,1266,2000] 212 ] 213) 214 215#正規化 216scaler = MinMaxScaler() 217test_data = pd.DataFrame(scaler.fit_transform(test_data)) 218print(test_data) 219 220#予測 221result = np.argmax(model.predict(test_data) ,axis=1) 222 223#予測結果の出力 224print(result) 225 226#正解の出力 227print(out_data) 228

試したこと・調べたこと

  • teratailやGoogle等で検索した
  • ソースコードを自分なりに変更した
  • 知人に聞いた
  • その他
上記の詳細・結果

np.set_printoptions(precision=3)
print(out_data)
を試してみましたが変わらず0でした

補足

特になし

気になる質問をクリップする

クリップした質問は、後からいつでもMYページで確認できます。

またクリップした質問に回答があった際、通知やメールを受け取ることができます。

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

tottatoato

2024/06/04 04:50

先週はじめたレベルなのでまだ、右も左も分からない状態ですが、よろしくお願いいたします。
guest

回答1

0

ベストアンサー

np.argmaxは「何番目の要素が最大か?」を返す関数です。

>>> import numpy as np >>> np.argmax([1,3,5,3,2]) 2 >>> np.argmax([1,2,3,5,4]) 3

2次元構造のデータにaxis=1にすると各要素の中で何番目の要素が最大か? が返ってきます。

>>> np.argmax([[0,1,2],[3,4,1],[5,1,2],[4,2,1],[1,2,5]], axis=1) array([2, 1, 0, 0, 2])

今のコードで、欲しい結果がスカラーで最終層がDense(1)なのですべての要素数は1個ですね。
argmaxしたらすべて0になります。(最大の要素は0番目に決まってますから)

>>> np.argmax([[1],[2],[3],[5],[4]], axis=1) array([0, 0, 0, 0, 0])

まずは

#予測 result = model.predict(test_data) #予測結果の出力 print(result)

を眺めてみるといいのではないでしょうか。

投稿2024/06/04 06:27

編集2024/06/04 09:23
quickquip

総合スコア11165

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

tottatoato

2024/06/05 03:54

回答ありがとうございます。 試したところ問題が解決しました! ベストアンサーに選ばせていただきました。 助かりました、ありがとうございます!!
guest

あなたの回答

tips

太字

斜体

打ち消し線

見出し

引用テキストの挿入

コードの挿入

リンクの挿入

リストの挿入

番号リストの挿入

表の挿入

水平線の挿入

プレビュー

15分調べてもわからないことは
teratailで質問しよう!

ただいまの回答率
85.40%

質問をまとめることで
思考を整理して素早く解決

テンプレート機能で
簡単に質問をまとめる

質問する

関連した質問