現在、tensorflowを使って靴のサイズの分類をしています。
靴のサイズは全部で9分類になりますが、入力データは6種類しかありません。
従いまして、入力ユニット数が9でと出力ユニット数が6になります。
この条件でtensorflowの分類器を作成してデータを与えて実行したところ、出力が次のようになってしまいます。
Step: 0, Accuracy: 0.222222, Loss: nan
Step: 10, Accuracy: 0.000000, Loss: nan
Step: 20, Accuracy: 0.000000, Loss: nan
Step: 30, Accuracy: 0.000000, Loss: nan
Step: 40, Accuracy: 0.000000, Loss: nan
Step: 50, Accuracy: 0.000000, Loss: nan
Step: 60, Accuracy: 0.000000, Loss: nan
Step: 70, Accuracy: 0.000000, Loss: nan
Step: 80, Accuracy: 0.000000, Loss: nan
Step: 90, Accuracy: 0.000000, Loss: nan
Step: 100, Accuracy: 0.000000, Loss: nan
他に問題らしいところが見当たりません。
よろしくお願いします。
環境は次のとおりです。
ubuntu16.04
anaconda3
python3.6
tensorflow1.2.1
コードは以下の通りです。
ただし、いろいろいじっているので他のエラーを起こすかもしれません。
データのインポート
import tensorflow as tf
import numpy as np
import pandas as pd
def train_read():
df1 = pd.read_csv("x_train.csv")
df2 = pd.read_csv("y_train.csv")
x_ = df1.iloc[:, 0:6].as_matrix()
y_ = df2.iloc[:, 0:9].as_matrix()
return x_,y_
train_x, train_y = train_read()
print(np.shape(train_x))
print(np.shape(train_y))
def test_read():
df1 = pd.read_csv("x_test.csv")
df2 = pd.read_csv("y_test.csv")
x_ = df1.iloc[:, 0:6].as_matrix()
y_ = df2.iloc[:, 0:9].as_matrix()
return x_,y_
test_x,test_y = test_read()
print(np.shape(test_x))
print(np.shape(test_y))
モデルの作成
x = tf.placeholder(tf.float32, [None, 6])
w = tf.Variable(tf.zeros([6, 9]))
b = tf.Variable(tf.zeros([9]))
y = tf.nn.softmax(tf.matmul(x, w) + b)
損失とオプティマイザーを定義
y_ = tf.placeholder(tf.float32, [None, 9])
cross_entropy = -tf.reduce_sum(y_ * tf.log(y))
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)
精度
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
訓練
sess = tf.Session()
sess.run(tf.global_variables_initializer())
saver = tf.train.Saver()
for i in range(101):
sess.run(train_step, feed_dict={x: train_x, y_: train_y})
if i % 10 == 0:
acc, cost = sess.run([accuracy, cross_entropy], feed_dict={x: test_x, y_: test_y})
print('Step: %d, Accuracy: %f, Loss: %f' % (i, acc, cost))
回答3件
あなたの回答
tips
プレビュー