質問をすることでしか得られない、回答やアドバイスがある。

15分調べてもわからないことは、質問しよう!

新規登録して質問してみよう
ただいま回答率
85.46%
CSV

CSV(Comma-Separated Values)はコンマで区切られた明白なテキスト値のリストです。もしくは、そのフォーマットでひとつ以上のリストを含むファイルを指します。

深層学習

深層学習は、多数のレイヤのニューラルネットワークによる機械学習手法。人工知能研究の一つでディープラーニングとも呼ばれています。コンピューター自体がデータの潜在的な特徴を汲み取り、効率的で的確な判断を実現することができます。

機械学習

機械学習は、データからパターンを自動的に発見し、そこから知能的な判断を下すためのコンピューターアルゴリズムを指します。人工知能における課題のひとつです。

Python

Pythonは、コードの読みやすさが特徴的なプログラミング言語の1つです。 強い型付け、動的型付けに対応しており、後方互換性がないバージョン2系とバージョン3系が使用されています。 商用製品の開発にも無料で使用でき、OSだけでなく仮想環境にも対応。Unicodeによる文字列操作をサポートしているため、日本語処理も標準で可能です。

Q&A

0回答

947閲覧

ニューラルネットの精度を上げたい

hate_tometo

総合スコア3

CSV

CSV(Comma-Separated Values)はコンマで区切られた明白なテキスト値のリストです。もしくは、そのフォーマットでひとつ以上のリストを含むファイルを指します。

深層学習

深層学習は、多数のレイヤのニューラルネットワークによる機械学習手法。人工知能研究の一つでディープラーニングとも呼ばれています。コンピューター自体がデータの潜在的な特徴を汲み取り、効率的で的確な判断を実現することができます。

機械学習

機械学習は、データからパターンを自動的に発見し、そこから知能的な判断を下すためのコンピューターアルゴリズムを指します。人工知能における課題のひとつです。

Python

Pythonは、コードの読みやすさが特徴的なプログラミング言語の1つです。 強い型付け、動的型付けに対応しており、後方互換性がないバージョン2系とバージョン3系が使用されています。 商用製品の開発にも無料で使用でき、OSだけでなく仮想環境にも対応。Unicodeによる文字列操作をサポートしているため、日本語処理も標準で可能です。

0グッド

0クリップ

投稿2021/12/08 05:30

編集2021/12/09 07:31

前提・実現したいこと

python、ニューラルネット初心者です。

python3を用いてニューラルネットを構築しています。
CSVデータに保存したデータを使って、学習を行い未知のデータに対して10段階の分類を行うことが最終目標です。
最終的な精度が30~40%と低いままで向上しないため、精度が上がるようなアドバイスをいただきたいです。

回答にあたって、必要な情報が欠けていましたら補足いたしますので、教えていただけますと幸いです。

発生している問題・エラーメッセージ

Epoch 1/200 237/237 [==============================] - 1s 2ms/step - loss: 0.1512 - accuracy: 0.3215 - val_loss: 0.1505 - val_accuracy: 0.3241 ... Epoch 200/200 237/237 [==============================] - 0s 1ms/step - loss: 0.1440 - accuracy: 0.3781 - val_loss: 0.1507 - val_accuracy: 0.3141 13/13 [==============================] - 0s 795us/step - loss: 0.1507 - accuracy: 0.3141 Test loss: 0.15066926181316376 Test accuracy: 0.31407034397125244

該当のソースコード

python3

1 2from __future__ import print_function 3import csv 4 5import pandas as pd 6from pandas import Series,DataFrame 7 8from sklearn import svm 9from sklearn.model_selection import train_test_split 10from sklearn.metrics import accuracy_score 11 12import numpy as np 13import matplotlib.pyplot as plt 14 15import tensorflow as tf 16 17from tensorflow import keras 18from tensorflow.keras import models 19from tensorflow.keras import optimizers 20from tensorflow.keras.datasets import fashion_mnist 21from tensorflow.keras import Sequential 22from tensorflow.keras.layers import Dense, Dropout 23from tensorflow.keras.models import Model 24from tensorflow.keras.layers import Input 25 26 27#CSVファイルの読み込み 28data_set = pd.read_csv("dataset_new2.csv",sep=",",header=0) 29#説明変数(データセットの入力) 30x = DataFrame(data_set.drop("result",axis=1)) 31 32#目的変数(得点の10段階評価) 33y = DataFrame(data_set["result"]) 34 35 36#説明変数・目的変数をそれぞれ訓練データ・テストデータに分割 37x_train,x_test,y_train,y_test = train_test_split(x,y,test_size=0.05) 38 39#データの整形 40x_train = x_train.astype(np.float) 41x_test = x_test.astype(np.float) 42 43y_train = tf.keras.utils.to_categorical(y_train,5) 44y_test = tf.keras.utils.to_categorical(y_test,5) 45 46#ニューラルネットワークの実装① 47model = tf.keras.Sequential() 48 49model.add(Dense(50, activation='relu', input_shape=(14,))) 50model.add(Dropout(0.2)) 51 52model.add(Dense(50, activation='relu', input_shape=(14,))) 53model.add(Dropout(0.2)) 54 55model.add(Dense(50, activation='relu', input_shape=(14,))) 56model.add(Dropout(0.2)) 57 58model.add(Dense(5, activation='softmax')) 59 60model.summary() 61print("\n") 62 63#ニューラルネットワークの実装② 64model.compile(loss='mean_squared_error',optimizer=tf.keras.optimizers.Adam( 65 learning_rate=0.001, beta_1=0.9, beta_2=0.999, epsilon=1e-07, amsgrad=False, 66 name='Adam'),metrics=['accuracy']) 67 68# optimizer=tf.keras.optimizers.RMSprop( 69# learning_rate=0.001, rho=0.9, momentum=0.0, epsilon=1e-07, centered=False, 70# name='RMSprop'),metrics=['accuracy']) 71 72 73 74#ニューラルネットワークの学習 75history = model.fit(x_train, y_train,batch_size=32,epochs=200,verbose=1,validation_data=(x_test, y_test)) 76 77#ニューラルネットワークの推論 78score = model.evaluate(x_test,y_test,verbose=1) 79print("\n") 80print("Test loss:",score[0]) 81print("Test accuracy:",score[1]) 82

該当のデータセット

csv

1year,day,time,num,sup_x,sup_y,sup_s,sup_t,em_x,em_y,em_s,em_t,hm_aw,capacty,result 20.994,0.603,0.924,0.506,0.196,0.248,0.194,0.36,0.192,0.126,0.164,0.316,0,0.691,0 30.994,0.603,0.924,0.271,0.092,0.128,0.171,0.124,0.368,0.841,0.486,0.558,0,0.235,1 40.994,0.603,0.949,0.233,0.279,0.168,0.343,0.363,0.124,0.129,0.166,0.165,0,0.29,2 50.994,0.603,0.949,0.333,0.155,0.161,0.16,0.168,0.263,0.13,0.26,0.343,0,0.553,1 60.994,0.603,0.95,0.228,0.203,0.198,0.244,0.349,0.187,0.109,0.171,0.257,0,0.563,2 70.994,0.603,0.95,0.398,0.195,0.193,0.225,0.254,0.313,0.187,0.239,0.417,0,0.704,2 80.994,0.604,0.949,0.271,0.212,0.188,0.258,0.321,0.363,0.201,0.182,0.517,0,0.273,2 90.994,0.604,0.949,0.384,0.406,0.275,0.215,0.494,0.155,0.33,0.17,0.279,0,1,2 100.994,0.672,0.896,0.169,0.192,0.126,0.164,0.316,0.195,0.193,0.225,0.254,0,0.51,1 110.994,0.672,0.946,0.662,0.155,0.33,0.17,0.279,0.212,0.188,0.258,0.321,0,0.585,1 120.994,0.672,0.946,0.201,0.187,0.109,0.171,0.257,0.237,0.084,0.163,0.278,0,0.661,0 130.994,0.672,0.947,0.199,0.363,0.201,0.182,0.517,0.092,0.128,0.171,0.124,0,0.373,4 140.994,0.672,0.949,0.801,0.368,0.841,0.486,0.558,0.196,0.248,0.194,0.36,0,0.881,4 150.994,0.672,0.949,0.322,0.263,0.13,0.26,0.343,0.406,0.275,0.215,0.494,0,0.363,1 160.994,0.672,0.949,0.485,0.313,0.187,0.239,0.417,0.203,0.198,0.244,0.349,0,0.704,3 170.994,0.672,0.949,0.257,0.233,0.111,0.177,0.239,0.279,0.168,0.343,0.363,0,0.285,1 180.994,0.672,0.949,0.291,0.124,0.129,0.166,0.165,0.155,0.161,0.16,0.168,0,0.298,0 190.994,0.678,0.946,0.23,0.363,0.201,0.182,0.517,0.313,0.187,0.239,0.417,0,0.373,3 200.994,0.678,0.946,0.166,0.124,0.129,0.166,0.165,0.233,0.111,0.177,0.239,0,0.298,4 210.994,0.678,0.948,0.297,0.187,0.109,0.171,0.257,0.406,0.275,0.215,0.494,0,0.661,2 220.994,0.678,0.949,0.547,0.203,0.198,0.244,0.349,0.368,0.841,0.486,0.558,0,0.563,2 230.994,0.678,0.949,0.181,0.237,0.084,0.163,0.278,0.279,0.168,0.343,0.363,0,0.297,0 240.994,0.678,0.949,0.616,0.155,0.33,0.17,0.279,0.263,0.13,0.26,0.343,0,0.585,2 250.994,0.678,0.949,0.246,0.195,0.193,0.225,0.254,0.092,0.128,0.171,0.124,0,0.28,4 260.994,0.679,0.9,0.162,0.192,0.126,0.164,0.316,0.155,0.161,0.16,0.168,0,0.51,0 270.994,0.679,0.949,0.276,0.212,0.188,0.258,0.321,0.196,0.248,0.194,0.36,0,0.273,3 280.994,0.681,0.947,0.181,0.263,0.13,0.26,0.343,0.363,0.201,0.182,0.517,0,0.363,4 290.994,0.681,0.948,0.191,0.313,0.187,0.239,0.417,0.212,0.188,0.258,0.321,0,0.21,3 300.994,0.681,0.949,0.55,0.368,0.841,0.486,0.558,0.155,0.33,0.17,0.279,0,0.881,3 310.994,0.681,0.949,0.296,0.196,0.248,0.194,0.36,0.124,0.129,0.166,0.165,0,0.664,4 320.994,0.681,0.949,0.193,0.092,0.128,0.171,0.124,0.203,0.198,0.244,0.349,0,0.235,2 330.994,0.681,0.949,0.18,0.233,0.111,0.177,0.239,0.187,0.109,0.171,0.257,0,0.285,4 340.994,0.681,0.949,0.285,0.279,0.168,0.343,0.363,0.192,0.126,0.164,0.316,0,0.29,3 350.994,0.681,0.949,0.301,0.155,0.161,0.16,0.168,0.195,0.193,0.225,0.254,0,0.553,3 360.994,0.681,0.95,0.186,0.406,0.275,0.215,0.494,0.237,0.084,0.163,0.278,0,0.214,1 370.994,0.684,0.899,0.145,0.237,0.084,0.163,0.278,0.263,0.13,0.26,0.343,0,0.213,1 380.994,0.684,0.921,0.224,0.203,0.198,0.244,0.349,0.192,0.126,0.164,0.316,0,0.563,0 390.994,0.684,0.924,0.227,0.092,0.128,0.171,0.124,0.313,0.187,0.239,0.417,0,0.235,1 400.994,0.684,0.946,0.283,0.124,0.129,0.166,0.165,0.212,0.188,0.258,0.321,0,0.298,1 410.994,0.684,0.949,0.646,0.155,0.33,0.17,0.279,0.155,0.161,0.16,0.168,0,0.585,3 420.994,0.684,0.949,0.258,0.195,0.193,0.225,0.254,0.196,0.248,0.194,0.36,0,0.28,2 430.994,0.684,0.95,0.373,0.187,0.109,0.171,0.257,0.368,0.841,0.486,0.558,0,0.661,1 440.994,0.685,0.949,0.142,0.233,0.111,0.177,0.239,0.406,0.275,0.215,0.494,0,0.285,0 450.994,0.685,0.95,0.28,0.363,0.201,0.182,0.517,0.279,0.168,0.343,0.363,0,0.373,3 460.994,0.687,0.921,0.183,0.212,0.188,0.258,0.321,0.263,0.13,0.26,0.343,0,0.273,1 470.994,0.687,0.945,0.151,0.203,0.198,0.244,0.349,0.363,0.201,0.182,0.517,0,0.563,2 480.994,0.687,0.946,0.148,0.195,0.193,0.225,0.254,0.233,0.111,0.177,0.239,0,0.28,1 490.994,0.687,0.949,0.223,0.279,0.168,0.343,0.363,0.155,0.33,0.17,0.279,0,0.29,3 500.994,0.687,0.949,0.078,0.237,0.084,0.163,0.278,0.124,0.129,0.166,0.165,0,0.297,2 510.994,0.687,0.949,0.336,0.196,0.248,0.194,0.36,0.187,0.109,0.171,0.257,0,0.664,2 520.994,0.687,0.949,0.152,0.192,0.126,0.164,0.316,0.313,0.187,0.239,0.417,0,0.51,2 530.994,0.687,0.949,0.445,0.155,0.161,0.16,0.168,0.368,0.841,0.486,0.558,0,0.553,2 540.994,0.687,0.95,0.198,0.406,0.275,0.215,0.494,0.092,0.128,0.171,0.124,0,0.214,3 550.994,0.752,0.8,0.364,0.406,0.275,0.215,0.494,0.263,0.13,0.26,0.343,0,1,1 56 57 58 59

試したこと

バッチサイズの変更、学習率の変更、最適化アルゴリズムの変更、学習回数の変更、ユニット数の変更を行いましたが、精度が大きく向上することはありませんでした。

データセットが不均衡なデータになっていたため、サンプリングを行い、少数派のデータが占める割合を増やしましたが、精度が向上することはありませんでした。
サンプリングを行った場合、10段階の分類→5段階の分類に変更しました。

データセットの正規化を行いましたが、精度が向上することはありませんでした。

公開されているライブラリのデータセットを用いて学習を行うと90%以上の成果が得られました。

補足情報(FW/ツールのバージョンなど)

python3
tensorflow 2.6.2
keras 2.6.0

使用しているデータセットはこちらです。
戯画ファイル便を使用しての共有になります。
https://13.gigafile.nu/1216-c23c9e68e2185c05ada3cb58098943c01

気になる質問をクリップする

クリップした質問は、後からいつでもMYページで確認できます。

またクリップした質問に回答があった際、通知やメールを受け取ることができます。

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

jbpb0

2021/12/08 11:05

> 公開されているライブラリのデータセットを用いて学習を行うと90%以上の成果が得られました。 それならデータ依存の現象だから、データをもっとたくさん開示してくれないと、他人には何が起きてるのか分かりません データ全部とは言いませんが、他人が学習させて現象を再現できる程度のデータ量は要ります
jbpb0

2021/12/08 11:11 編集

「input_shape=(14,)」が3ヶ所にありますけど、一番最初の以外は要らないのでは? (未確認ですが)
hate_tometo

2021/12/08 16:48

ご回答ありがとうございます。 データの開示についてですが、字数制限以内(10000字)でできる限りのデータを載せました。 必要な情報の不足申し訳ありません。 ご指摘いただいた、input_shapeについてですが、ご指摘通りの修正を行ったところ動作が確認できましたので、不必要な入力でした。 ありがとうございます。
hate_tometo

2021/12/09 12:38

ギガファイル便での共有でしたら、可能でしたのでURLを添付いたしました。
guest

あなたの回答

tips

太字

斜体

打ち消し線

見出し

引用テキストの挿入

コードの挿入

リンクの挿入

リストの挿入

番号リストの挿入

表の挿入

水平線の挿入

プレビュー

まだ回答がついていません

会員登録して回答してみよう

アカウントをお持ちの方は

15分調べてもわからないことは
teratailで質問しよう!

ただいまの回答率
85.46%

質問をまとめることで
思考を整理して素早く解決

テンプレート機能で
簡単に質問をまとめる

質問する

関連した質問