質問をすることでしか得られない、回答やアドバイスがある。

15分調べてもわからないことは、質問しよう!

新規登録して質問してみよう
ただいま回答率
86.12%
Python 3.x

Python 3はPythonプログラミング言語の最新バージョンであり、2008年12月3日にリリースされました。

解決済

MNISTの0〰9数字を分類するコードで学習が進みません。

kobahot
kobahot

総合スコア10

Python 3.x

Python 3はPythonプログラミング言語の最新バージョンであり、2008年12月3日にリリースされました。

1回答

0リアクション

0クリップ

271閲覧

投稿2022/09/17 08:58

前提

初心者です。
MNISTの0〰9数字を分類するコードをスクラッチで書いています。

実現したいこと

softmax関数の計算がおかしいようで学習が進ません。
class SoftmaxCrossEntropyLoss()の中で使っています。
softmaxの書き方が悪いのか、層の計算が間違っているのかも分かりません。
何んとなく怪しい箇所は、#怪しい と記載しました。

発生している問題・エラーメッセージ

RuntimeWarning: invalid value encountered in subtract x = x-np.max(x,axis=0)

該当のソースコード

Python3.9

import csv import os import pickle import numpy as np import matplotlib.pyplot as plt from mpl_toolkits.mplot3d import Axes3D import cv2 from sklearn import datasets from sklearn.model_selection import train_test_split %matplotlib inline np.random.seed(seed=0) #怪しい def softmax(x): x = x.T x = x-np.max(x,axis=0) x = np.exp(x)/np.sum(np.exp(x)) return x.T # mnistデータセット mnist = datasets.fetch_openml('mnist_784', as_frame=False) # 画像とラベルを取得 X, T = mnist.data, mnist.target # 訓練データとテストデータに分割 X_train, X_test, T_train, T_test = train_test_split(X, T, test_size=0.2) T_train = np.eye(10)[T_train.astype("int")] T_test = np.eye(10)[T_test.astype("int")] def cross_entropy_error(t, y): delta = 1e-8 error = -np.mean(t * np.log(y + delta)) return error class SoftmaxCrossEntropyLoss(): def __init__(self): self.y = None self.t = None self.loss = None def __call__(self, t, y): self.y = softmax(y) self.t = t.copy() self.loss = cross_entropy_error(self.t, self.y) return self.loss def backward(self): batch_size = self.t.shape[0] dy = self.y - self.t dy /= batch_size return dy class FullyConnectedLayer(): def __init__(self, input_shape, output_shape): self.w = np.random.randn(input_shape, output_shape) * 0.01 self.b = np.zeros(output_shape, dtype=np.float) self.x = None self.dw = None self.db = None #怪しい def __call__(self, x): self.x = x out = np.dot(x,self.w)+self.b return out #怪しい def backward(self, dout): dx = np.dot(dout,np.transpose(self.w)) batch_size = dx.shape[0] self.dw = np.dot(np.transpose(self.x),dout) self.db = np.sum(dout,axis=0) return dx class ReLU(): def __init__(self): self.mask = None #怪しい def __call__(self, x): self.mask = (x <= 0) out = x.copy() out[self.mask]=0 return out #怪しい def backward(self, dout): dout[self.mask]=0 dx = dout return dx class MLP_classifier(): def __init__(self): ''' x -> fc(784, 256) -> relu -> fc(256, 256) -> relu -> fc(256, 10) -> out ''' # 層 self.fc1 = FullyConnectedLayer(784, 256) self.relu1 = ReLU() self.fc2 = FullyConnectedLayer(256, 256) self.relu2 = ReLU() self.fc3 = FullyConnectedLayer(256, 10) self.out = None # 損失関数の定義 self.criterion = SoftmaxCrossEntropyLoss() def forward(self, x): ''' 順伝播 ''' x = self.relu1(self.fc1(x)) x = self.relu2(self.fc2(x)) self.out = self.fc3(x) return self.out def backward(self, t): ''' 逆伝播 ''' # 誤差を計算 loss = self.criterion(t, self.out) # 勾配を逆伝播 d = self.criterion.backward() d = self.fc3.backward(d) d = self.relu2.backward(d) d = self.fc2.backward(d) d = self.relu1.backward(d) d = self.fc1.backward(d) return loss def optimize_GradientDecent(self, lr): ''' 勾配降下法による全層のパラメータの更新 ''' for fc in [self.fc1, self.fc2, self.fc3]: fc.w -= lr * fc.dw fc.b -= lr * fc.db # モデルの宣言 model = MLP_classifier() # 学習率 lr = 0.005 # 学習エポック数 n_epoch = 20 for n in range(n_epoch): # 訓練 y = model.forward(X_train) loss = model.backward(T_train) model.optimize_GradientDecent(lr) # テスト y = model.forward(X_test) test_loss = model.backward(T_test) pred = softmax(y) accuracy = np.mean(np.equal(np.argmax(y, axis=1), np.argmax(T_test, axis=1))) print(f'EPOCH {n + 1} | TRAIN LOSS {loss:.5f} | TEST LOSS {test_loss:.5f} | ACCURACY {accuracy:.2%}') classification_accuracy = accuracy

試したこと

エラーがdef softmax()の中なので、その中はいくつか試しました。

補足情報(FW/ツールのバージョンなど)

初心者ですので、質問の仕方が不適切かと思いますが、
助けてください。

以下のような質問にはリアクションをつけましょう

  • 質問内容が明確
  • 自分も答えを知りたい
  • 質問者以外のユーザにも役立つ

リアクションが多い質問は、TOPページの「注目」タブのフィードに表示されやすくなります。

気になる質問をクリップする

クリップした質問は、後からいつでもマイページで確認できます。

またクリップした質問に回答があった際、通知やメールを受け取ることができます。

下記のような質問は推奨されていません。

  • 間違っている
  • 質問になっていない投稿
  • スパムや攻撃的な表現を用いた投稿

適切な質問に修正を依頼しましょう。

meg_

2022/09/17 09:38 編集

まずは順伝播の計算が合っていることを単純なデータで確認するのが先かと思います。あと、エラーではなくウォーニングかと思いますが。
kobahot

2022/09/22 14:32

ご指導ありがとうございます。ルールを把握しておらず申し訳ありません。 正直、解決したか自信ありませんが、学習は完了できたので、その内容を投稿しておきます。

まだ回答がついていません

会員登録して回答してみよう

アカウントをお持ちの方は

15分調べてもわからないことは
teratailで質問しよう!

ただいまの回答率
86.12%

質問をまとめることで
思考を整理して素早く解決

テンプレート機能で
簡単に質問をまとめる

質問する

関連した質問

同じタグがついた質問を見る

Python 3.x

Python 3はPythonプログラミング言語の最新バージョンであり、2008年12月3日にリリースされました。