質問をすることでしか得られない、回答やアドバイスがある。

15分調べてもわからないことは、質問しよう!

ただいまの
回答率

89.10%

mnistでimagedatageneratorを使う

解決済

回答 1

投稿 編集

  • 評価
  • クリップ 1
  • VIEW 600

shunx2_1003

score 17

以下のコードを改良してMNISTの画像を水増しして学習するプログラムを作ろうと考えていますが、どのように実装すればいいのかが分かりません。
imagedatageneratorの入力は四次元配列ですが、自分のNNのinput_dimは784なので、imagedatageneratorで得られた画像をどう整形すればいいのでしょうか?

# encoding: utf-8
import keras
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers.core import Dense, Activation, Dropout
from keras.callbacks import Callback, CSVLogger
from keras.utils import np_utils
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from PIL import Image
import datetime

class PlotLosses(Callback):
    '''
    学習中のlossについてlive plotする
    '''

    def on_train_begin(self, logs={}):
        '''
        訓練開始時に実施
        '''
        self.epoch_cnt = 0      # epochの回数を初期化
        plt.axis([0, self.epochs, 0, 0.25])
        plt.ion()               # pyplotをinteractive modeにする

    def on_train_end(self, logs={}):
        '''
        訓練修了時に実施
        '''
        plt.ioff()              # pyplotのinteractive modeをoffにする
        plt.legend(['loss', 'val_loss'], loc='best')
        plt.show()

    def on_epoch_end(self, epoch, logs={}):
        '''
        epochごとに実行する処理
        '''
        loss = logs.get('loss')
        val_loss = logs.get('val_loss')
        x = self.epoch_cnt
        # epochごとのlossとval_lossをplotする
        plt.scatter(x, loss, c='b', label='loss')
        plt.scatter(x, val_loss, c='r', label='val_loss')
        plt.pause(0.05)
        # epoch回数をcount up
        self.epoch_cnt += 1

# kerasのMNISTデータの取得
def main(epochs=6,batch_size=128, activation='relu', optimizer='rmsprop',
          NGphoto=False, realtimeplot=False, csvlog=True,
          plotmodel=True, modelsave=False):
  global model_name
  (X_train, y_train), (X_test, y_test) = mnist.load_data()
  #validationセットをtrainingセットから分ける
  X_train, X_valid, y_train, y_valid = train_test_split(X_train, y_train, test_size=0.175)


  X_train = X_train.reshape(X_train.shape[0], 784)
  X_test = X_test.reshape(X_test.shape[0], 784)
  X_valid = X_valid.reshape(X_valid.shape[0], 784)
  X_train = X_train.astype('float32')
  X_test = X_test.astype('float32')
  X_valid = X_valid.astype('float32')
  X_train /= 255
  X_test /= 255
  X_valid /= 255
  y_train = keras.utils.to_categorical(y_train, 10)
  y_test = keras.utils.to_categorical(y_test, 10)
  y_valid = keras.utils.to_categorical(y_valid, 10)

  model = Sequential()
  model.add(Dense(512, activation=activation, input_dim=784))
  model.add(Dropout(0.25))
  model.add(Dense(512, activation=activation))
  #model.add(Dropout(0.25))
  #model.add(Dense(512, activation=activation))
  model.add(Dense(10, activation='softmax'))

  model.compile(optimizer=optimizer,
                loss='categorical_crossentropy',
                metrics=['accuracy'])

  # callback function
  plot_losses = PlotLosses()      # グラフ表示(live plot)
  plot_losses.epochs = epochs
  csv_logger = CSVLogger('trainlog.csv')

  callbacks=[]
  if realtimeplot:
    callbacks.append(plot_losses)
  if csvlog:
    callbacks.append(csv_logger)



  hist = model.fit(X_train, y_train,
                    batch_size=batch_size,
                    epochs=epochs,
                    verbose=1,
                    validation_data=(X_valid, y_valid),
                    callbacks=callbacks)

  score = model.evaluate(X_test, y_test, verbose=0)
  print('Test loss: {0}'.format(score[0]))
  print('Test accuracy: {0}'.format(score[1]))

  loss = hist.history['loss']
  val_loss = hist.history['val_loss']

  # lossのグラフ
  plt.plot(range(epochs), loss, marker='.', label='loss')
  plt.plot(range(epochs), val_loss, marker='.', label='val_loss')
  plt.legend(loc='best', fontsize=10)
  plt.grid()
  plt.xlabel('epoch')
  plt.ylabel('loss')
  plt.show()




  acc = hist.history['acc']
  val_acc = hist.history['val_acc']

  # accuracyのグラフ
  plt.plot(range(epochs), acc, marker='.', label='acc')
  plt.plot(range(epochs), val_acc, marker='.', label='val_acc')
  plt.legend(loc='best', fontsize=10)
  plt.grid()
  plt.xlabel('epoch')
  plt.ylabel('acc')
  plt.show()



  if NGphoto:
    pre = model.predict(X_test)
    for i,v in enumerate(pre):
        pre_ans = v.argmax()
        ans = y_test[i].argmax()
        dat = X_test[i]
        categories = [ "0","1", "2" , "3", "4" , "5", "6" , "7", "8" , "9"]
        if ans == pre_ans: continue
        fname = "NG_photo/" + str(i) + "-" + categories[pre_ans] + \
            "-ne-" + categories[ans] + ".png"
        dat *= 255
        img = Image.fromarray(dat.reshape((28,28))).convert("RGB")
        img.save(fname)

  if plotmodel:
    keras.utils.plot_model(model, to_file='keras_MNIST_powerup_model.svg', show_shapes=True)

  if modelsave:
    now = datetime.datetime.now()
    model_name = 'keras_MNIST_powerup'+now.strftime('%Y%m%d_%H%M%S')
    model.save('../models/'+model_name+'.h5', include_optimizer=False)

if __name__ == '__main__':
  epochs = 10
  activation = 'relu'
  optimizer = 'adam'
  batch_size = 100
  main(epochs, batch_size, activation, optimizer,
        NGphoto=False, realtimeplot=False,
        csvlog=True, plotmodel=True, modelsave=True)

  print('今回のモデル')
  print("model_name : " + model_name)
  print("batch_size : " + str(batch_size))
  print("epochs : " + str(epochs))
  print("activation : " + activation)
  print("optimizer : " + optimizer)

初心者で汚いコードかと思いますが、ご指導よろしくお願いします。

  • 気になる質問をクリップする

    クリップした質問は、後からいつでもマイページで確認できます。

    またクリップした質問に回答があった際、通知やメールを受け取ることができます。

    クリップを取り消します

  • 良い質問の評価を上げる

    以下のような質問は評価を上げましょう

    • 質問内容が明確
    • 自分も答えを知りたい
    • 質問者以外のユーザにも役立つ

    評価が高い質問は、TOPページの「注目」タブのフィードに表示されやすくなります。

    質問の評価を上げたことを取り消します

  • 評価を下げられる数の上限に達しました

    評価を下げることができません

    • 1日5回まで評価を下げられます
    • 1日に1ユーザに対して2回まで評価を下げられます

    質問の評価を下げる

    teratailでは下記のような質問を「具体的に困っていることがない質問」、「サイトポリシーに違反する質問」と定義し、推奨していません。

    • プログラミングに関係のない質問
    • やってほしいことだけを記載した丸投げの質問
    • 問題・課題が含まれていない質問
    • 意図的に内容が抹消された質問
    • 過去に投稿した質問と同じ内容の質問
    • 広告と受け取られるような投稿

    評価が下がると、TOPページの「アクティブ」「注目」タブのフィードに表示されにくくなります。

    質問の評価を下げたことを取り消します

    この機能は開放されていません

    評価を下げる条件を満たしてません

    評価を下げる理由を選択してください

    詳細な説明はこちら

    上記に当てはまらず、質問内容が明確になっていない質問には「情報の追加・修正依頼」機能からコメントをしてください。

    質問の評価を下げる機能の利用条件

    この機能を利用するためには、以下の事項を行う必要があります。

質問への追記・修正、ベストアンサー選択の依頼

  • quiqui

    2019/08/24 17:38 編集

    貼り付けるコードを間違っていませんか。うまくいってないコードを貼って、なにがうまくいってないのかも書きましょう。

    キャンセル

  • shunx2_1003

    2019/08/24 22:21

    このコードに、データ拡張の処理を追加しようと思っているのですが、ImageDataGeneratorクラスの使い方がわからず、何も書いていない状態です。

    キャンセル

  • quiqui

    2019/08/24 22:29

    「うまくいきません」と書いたら、そうとは伝わらないと思います。

    キャンセル

  • shunx2_1003

    2019/08/25 00:00

    そうですね、ありがとうございます。修正しておきます。

    キャンセル

回答 1

checkベストアンサー

0

ImageGenerator がわからないのですが、ここを見ながら。

サンプル x (形状が (NumSamples, Height, Width, Channels) である4次元データ) 

ということです。MNISTのデータは、例えばこのサイトなどに解説されています。

ヘッダ領域が 16 バイトで、その後に 28 * 28 バイトのピクセルデータが 60,000 画像分だけ続きます。

Channels は、入力層では色のためのバイト数と思っていいです。MNIST は256階調のグレースケールなので、1です。
NumSamples は、サンプルの数ですが、入力、出力のどっちなんだろう?

ということで、できますよね?

投稿

  • 回答の評価を上げる

    以下のような回答は評価を上げましょう

    • 正しい回答
    • わかりやすい回答
    • ためになる回答

    評価が高い回答ほどページの上位に表示されます。

  • 回答の評価を下げる

    下記のような回答は推奨されていません。

    • 間違っている回答
    • 質問の回答になっていない投稿
    • スパムや攻撃的な表現を用いた投稿

    評価を下げる際はその理由を明確に伝え、適切な回答に修正してもらいましょう。

15分調べてもわからないことは、teratailで質問しよう!

  • ただいまの回答率 89.10%
  • 質問をまとめることで、思考を整理して素早く解決
  • テンプレート機能で、簡単に質問をまとめられる

同じタグがついた質問を見る