質問をすることでしか得られない、回答やアドバイスがある。

15分調べてもわからないことは、質問しよう!

新規登録して質問してみよう
ただいま回答率
87.20%
Keras

Kerasは、TheanoやTensorFlow/CNTK対応のラッパーライブラリです。DeepLearningの数学的部分を短いコードでネットワークとして表現することが可能。DeepLearningの最新手法を迅速に試すことができます。

深層学習

深層学習は、多数のレイヤのニューラルネットワークによる機械学習手法。人工知能研究の一つでディープラーニングとも呼ばれています。コンピューター自体がデータの潜在的な特徴を汲み取り、効率的で的確な判断を実現することができます。

Python

Pythonは、コードの読みやすさが特徴的なプログラミング言語の1つです。 強い型付け、動的型付けに対応しており、後方互換性がないバージョン2系とバージョン3系が使用されています。 商用製品の開発にも無料で使用でき、OSだけでなく仮想環境にも対応。Unicodeによる文字列操作をサポートしているため、日本語処理も標準で可能です。

配列

配列は、各データの要素(値または変数)が連続的に並べられたデータ構造です。各配列は添え字(INDEX)で識別されています。

解決済

mnistでimagedatageneratorを使う

shunx2_1003
shunx2_1003

総合スコア0

Keras

Kerasは、TheanoやTensorFlow/CNTK対応のラッパーライブラリです。DeepLearningの数学的部分を短いコードでネットワークとして表現することが可能。DeepLearningの最新手法を迅速に試すことができます。

深層学習

深層学習は、多数のレイヤのニューラルネットワークによる機械学習手法。人工知能研究の一つでディープラーニングとも呼ばれています。コンピューター自体がデータの潜在的な特徴を汲み取り、効率的で的確な判断を実現することができます。

Python

Pythonは、コードの読みやすさが特徴的なプログラミング言語の1つです。 強い型付け、動的型付けに対応しており、後方互換性がないバージョン2系とバージョン3系が使用されています。 商用製品の開発にも無料で使用でき、OSだけでなく仮想環境にも対応。Unicodeによる文字列操作をサポートしているため、日本語処理も標準で可能です。

配列

配列は、各データの要素(値または変数)が連続的に並べられたデータ構造です。各配列は添え字(INDEX)で識別されています。

1回答

0評価

1クリップ

15閲覧

投稿2019/08/24 02:45

編集2022/01/12 10:58

以下のコードを改良してMNISTの画像を水増しして学習するプログラムを作ろうと考えていますが、うまくいきません。
imagedatageneratorの入力は四次元配列ですが、自分のNNのinput_dimは784なので、imagedatageneratorで得られた画像をどう整形すればいいのでしょうか?

python

# encoding: utf-8 import keras from keras.datasets import mnist from keras.models import Sequential from keras.layers.core import Dense, Activation, Dropout from keras.callbacks import Callback, CSVLogger from keras.utils import np_utils import matplotlib.pyplot as plt from sklearn.model_selection import train_test_split from PIL import Image import datetime class PlotLosses(Callback): ''' 学習中のlossについてlive plotする ''' def on_train_begin(self, logs={}): ''' 訓練開始時に実施 ''' self.epoch_cnt = 0 # epochの回数を初期化 plt.axis([0, self.epochs, 0, 0.25]) plt.ion() # pyplotをinteractive modeにする def on_train_end(self, logs={}): ''' 訓練修了時に実施 ''' plt.ioff() # pyplotのinteractive modeをoffにする plt.legend(['loss', 'val_loss'], loc='best') plt.show() def on_epoch_end(self, epoch, logs={}): ''' epochごとに実行する処理 ''' loss = logs.get('loss') val_loss = logs.get('val_loss') x = self.epoch_cnt # epochごとのlossとval_lossをplotする plt.scatter(x, loss, c='b', label='loss') plt.scatter(x, val_loss, c='r', label='val_loss') plt.pause(0.05) # epoch回数をcount up self.epoch_cnt += 1 # kerasのMNISTデータの取得 def main(epochs=6,batch_size=128, activation='relu', optimizer='rmsprop', NGphoto=False, realtimeplot=False, csvlog=True, plotmodel=True, modelsave=False): global model_name (X_train, y_train), (X_test, y_test) = mnist.load_data() #validationセットをtrainingセットから分ける X_train, X_valid, y_train, y_valid = train_test_split(X_train, y_train, test_size=0.175) X_train = X_train.reshape(X_train.shape[0], 784) X_test = X_test.reshape(X_test.shape[0], 784) X_valid = X_valid.reshape(X_valid.shape[0], 784) X_train = X_train.astype('float32') X_test = X_test.astype('float32') X_valid = X_valid.astype('float32') X_train /= 255 X_test /= 255 X_valid /= 255 y_train = keras.utils.to_categorical(y_train, 10) y_test = keras.utils.to_categorical(y_test, 10) y_valid = keras.utils.to_categorical(y_valid, 10) model = Sequential() model.add(Dense(512, activation=activation, input_dim=784)) model.add(Dropout(0.25)) model.add(Dense(512, activation=activation)) #model.add(Dropout(0.25)) #model.add(Dense(512, activation=activation)) model.add(Dense(10, activation='softmax')) model.compile(optimizer=optimizer, loss='categorical_crossentropy', metrics=['accuracy']) # callback function plot_losses = PlotLosses() # グラフ表示(live plot) plot_losses.epochs = epochs csv_logger = CSVLogger('trainlog.csv') callbacks=[] if realtimeplot: callbacks.append(plot_losses) if csvlog: callbacks.append(csv_logger) hist = model.fit(X_train, y_train, batch_size=batch_size, epochs=epochs, verbose=1, validation_data=(X_valid, y_valid), callbacks=callbacks) score = model.evaluate(X_test, y_test, verbose=0) print('Test loss: {0}'.format(score[0])) print('Test accuracy: {0}'.format(score[1])) loss = hist.history['loss'] val_loss = hist.history['val_loss'] # lossのグラフ plt.plot(range(epochs), loss, marker='.', label='loss') plt.plot(range(epochs), val_loss, marker='.', label='val_loss') plt.legend(loc='best', fontsize=10) plt.grid() plt.xlabel('epoch') plt.ylabel('loss') plt.show() acc = hist.history['acc'] val_acc = hist.history['val_acc'] # accuracyのグラフ plt.plot(range(epochs), acc, marker='.', label='acc') plt.plot(range(epochs), val_acc, marker='.', label='val_acc') plt.legend(loc='best', fontsize=10) plt.grid() plt.xlabel('epoch') plt.ylabel('acc') plt.show() if NGphoto: pre = model.predict(X_test) for i,v in enumerate(pre): pre_ans = v.argmax() ans = y_test[i].argmax() dat = X_test[i] categories = [ "0","1", "2" , "3", "4" , "5", "6" , "7", "8" , "9"] if ans == pre_ans: continue fname = "NG_photo/" + str(i) + "-" + categories[pre_ans] + \ "-ne-" + categories[ans] + ".png" dat *= 255 img = Image.fromarray(dat.reshape((28,28))).convert("RGB") img.save(fname) if plotmodel: keras.utils.plot_model(model, to_file='keras_MNIST_powerup_model.svg', show_shapes=True) if modelsave: now = datetime.datetime.now() model_name = 'keras_MNIST_powerup'+now.strftime('%Y%m%d_%H%M%S') model.save('../models/'+model_name+'.h5', include_optimizer=False) if __name__ == '__main__': epochs = 10 activation = 'relu' optimizer = 'adam' batch_size = 100 main(epochs, batch_size, activation, optimizer, NGphoto=False, realtimeplot=False, csvlog=True, plotmodel=True, modelsave=True) print('今回のモデル') print("model_name : " + model_name) print("batch_size : " + str(batch_size)) print("epochs : " + str(epochs)) print("activation : " + activation) print("optimizer : " + optimizer)

初心者で汚いコードかと思いますが、ご指導よろしくお願いします。

良い質問の評価を上げる

以下のような質問は評価を上げましょう

  • 質問内容が明確
  • 自分も答えを知りたい
  • 質問者以外のユーザにも役立つ

評価が高い質問は、TOPページの「注目」タブのフィードに表示されやすくなります。

気になる質問をクリップする

クリップした質問は、後からいつでもマイページで確認できます。

またクリップした質問に回答があった際、通知やメールを受け取ることができます。

teratailでは下記のような質問を「具体的に困っていることがない質問」、「サイトポリシーに違反する質問」と定義し、推奨していません。

  • プログラミングに関係のない質問
  • やってほしいことだけを記載した丸投げの質問
  • 問題・課題が含まれていない質問
  • 意図的に内容が抹消された質問
  • 過去に投稿した質問と同じ内容の質問
  • 広告と受け取られるような投稿

評価を下げると、トップページの「アクティブ」「注目」タブのフィードに表示されにくくなります。

quickquip
quickquip

2019/08/24 08:39 編集

貼り付けるコードを間違っていませんか。うまくいってないコードを貼って、なにがうまくいってないのかも書きましょう。
shunx2_1003
shunx2_1003

2019/08/24 13:21

このコードに、データ拡張の処理を追加しようと思っているのですが、ImageDataGeneratorクラスの使い方がわからず、何も書いていない状態です。
quickquip
quickquip

2019/08/24 13:29

「うまくいきません」と書いたら、そうとは伝わらないと思います。
shunx2_1003
shunx2_1003

2019/08/24 15:00

そうですね、ありがとうございます。修正しておきます。

まだ回答がついていません

会員登録して回答してみよう

15分調べてもわからないことは
teratailで質問しよう!

ただいまの回答率
87.20%

質問をまとめることで
思考を整理して素早く解決

テンプレート機能で
簡単に質問をまとめる

質問する

関連した質問

同じタグがついた質問を見る

Keras

Kerasは、TheanoやTensorFlow/CNTK対応のラッパーライブラリです。DeepLearningの数学的部分を短いコードでネットワークとして表現することが可能。DeepLearningの最新手法を迅速に試すことができます。

深層学習

深層学習は、多数のレイヤのニューラルネットワークによる機械学習手法。人工知能研究の一つでディープラーニングとも呼ばれています。コンピューター自体がデータの潜在的な特徴を汲み取り、効率的で的確な判断を実現することができます。

Python

Pythonは、コードの読みやすさが特徴的なプログラミング言語の1つです。 強い型付け、動的型付けに対応しており、後方互換性がないバージョン2系とバージョン3系が使用されています。 商用製品の開発にも無料で使用でき、OSだけでなく仮想環境にも対応。Unicodeによる文字列操作をサポートしているため、日本語処理も標準で可能です。

配列

配列は、各データの要素(値または変数)が連続的に並べられたデータ構造です。各配列は添え字(INDEX)で識別されています。