質問をすることでしか得られない、回答やアドバイスがある。

15分調べてもわからないことは、質問しよう!

新規登録して質問してみよう
ただいま回答率
85.50%
Keras

Kerasは、TheanoやTensorFlow/CNTK対応のラッパーライブラリです。DeepLearningの数学的部分を短いコードでネットワークとして表現することが可能。DeepLearningの最新手法を迅速に試すことができます。

深層学習

深層学習は、多数のレイヤのニューラルネットワークによる機械学習手法。人工知能研究の一つでディープラーニングとも呼ばれています。コンピューター自体がデータの潜在的な特徴を汲み取り、効率的で的確な判断を実現することができます。

Python

Pythonは、コードの読みやすさが特徴的なプログラミング言語の1つです。 強い型付け、動的型付けに対応しており、後方互換性がないバージョン2系とバージョン3系が使用されています。 商用製品の開発にも無料で使用でき、OSだけでなく仮想環境にも対応。Unicodeによる文字列操作をサポートしているため、日本語処理も標準で可能です。

配列

配列は、各データの要素(値または変数)が連続的に並べられたデータ構造です。各配列は添え字(INDEX)で識別されています。

Q&A

解決済

1回答

1005閲覧

mnistでimagedatageneratorを使う

shunx2_1003

総合スコア18

Keras

Kerasは、TheanoやTensorFlow/CNTK対応のラッパーライブラリです。DeepLearningの数学的部分を短いコードでネットワークとして表現することが可能。DeepLearningの最新手法を迅速に試すことができます。

深層学習

深層学習は、多数のレイヤのニューラルネットワークによる機械学習手法。人工知能研究の一つでディープラーニングとも呼ばれています。コンピューター自体がデータの潜在的な特徴を汲み取り、効率的で的確な判断を実現することができます。

Python

Pythonは、コードの読みやすさが特徴的なプログラミング言語の1つです。 強い型付け、動的型付けに対応しており、後方互換性がないバージョン2系とバージョン3系が使用されています。 商用製品の開発にも無料で使用でき、OSだけでなく仮想環境にも対応。Unicodeによる文字列操作をサポートしているため、日本語処理も標準で可能です。

配列

配列は、各データの要素(値または変数)が連続的に並べられたデータ構造です。各配列は添え字(INDEX)で識別されています。

0グッド

1クリップ

投稿2019/08/24 02:45

編集2019/08/24 15:02

以下のコードを改良してMNISTの画像を水増しして学習するプログラムを作ろうと考えていますが、どのように実装すればいいのかが分かりません。
imagedatageneratorの入力は四次元配列ですが、自分のNNのinput_dimは784なので、imagedatageneratorで得られた画像をどう整形すればいいのでしょうか?

python

1# encoding: utf-8 2import keras 3from keras.datasets import mnist 4from keras.models import Sequential 5from keras.layers.core import Dense, Activation, Dropout 6from keras.callbacks import Callback, CSVLogger 7from keras.utils import np_utils 8import matplotlib.pyplot as plt 9from sklearn.model_selection import train_test_split 10from PIL import Image 11import datetime 12 13class PlotLosses(Callback): 14 ''' 15 学習中のlossについてlive plotする 16 ''' 17 18 def on_train_begin(self, logs={}): 19 ''' 20 訓練開始時に実施 21 ''' 22 self.epoch_cnt = 0 # epochの回数を初期化 23 plt.axis([0, self.epochs, 0, 0.25]) 24 plt.ion() # pyplotをinteractive modeにする 25 26 def on_train_end(self, logs={}): 27 ''' 28 訓練修了時に実施 29 ''' 30 plt.ioff() # pyplotのinteractive modeをoffにする 31 plt.legend(['loss', 'val_loss'], loc='best') 32 plt.show() 33 34 def on_epoch_end(self, epoch, logs={}): 35 ''' 36 epochごとに実行する処理 37 ''' 38 loss = logs.get('loss') 39 val_loss = logs.get('val_loss') 40 x = self.epoch_cnt 41 # epochごとのlossとval_lossをplotする 42 plt.scatter(x, loss, c='b', label='loss') 43 plt.scatter(x, val_loss, c='r', label='val_loss') 44 plt.pause(0.05) 45 # epoch回数をcount up 46 self.epoch_cnt += 1 47 48# kerasのMNISTデータの取得 49def main(epochs=6,batch_size=128, activation='relu', optimizer='rmsprop', 50 NGphoto=False, realtimeplot=False, csvlog=True, 51 plotmodel=True, modelsave=False): 52 global model_name 53 (X_train, y_train), (X_test, y_test) = mnist.load_data() 54 #validationセットをtrainingセットから分ける 55 X_train, X_valid, y_train, y_valid = train_test_split(X_train, y_train, test_size=0.175) 56 57 58 X_train = X_train.reshape(X_train.shape[0], 784) 59 X_test = X_test.reshape(X_test.shape[0], 784) 60 X_valid = X_valid.reshape(X_valid.shape[0], 784) 61 X_train = X_train.astype('float32') 62 X_test = X_test.astype('float32') 63 X_valid = X_valid.astype('float32') 64 X_train /= 255 65 X_test /= 255 66 X_valid /= 255 67 y_train = keras.utils.to_categorical(y_train, 10) 68 y_test = keras.utils.to_categorical(y_test, 10) 69 y_valid = keras.utils.to_categorical(y_valid, 10) 70 71 model = Sequential() 72 model.add(Dense(512, activation=activation, input_dim=784)) 73 model.add(Dropout(0.25)) 74 model.add(Dense(512, activation=activation)) 75 #model.add(Dropout(0.25)) 76 #model.add(Dense(512, activation=activation)) 77 model.add(Dense(10, activation='softmax')) 78 79 model.compile(optimizer=optimizer, 80 loss='categorical_crossentropy', 81 metrics=['accuracy']) 82 83 # callback function 84 plot_losses = PlotLosses() # グラフ表示(live plot) 85 plot_losses.epochs = epochs 86 csv_logger = CSVLogger('trainlog.csv') 87 88 callbacks=[] 89 if realtimeplot: 90 callbacks.append(plot_losses) 91 if csvlog: 92 callbacks.append(csv_logger) 93 94 95 96 hist = model.fit(X_train, y_train, 97 batch_size=batch_size, 98 epochs=epochs, 99 verbose=1, 100 validation_data=(X_valid, y_valid), 101 callbacks=callbacks) 102 103 score = model.evaluate(X_test, y_test, verbose=0) 104 print('Test loss: {0}'.format(score[0])) 105 print('Test accuracy: {0}'.format(score[1])) 106 107 loss = hist.history['loss'] 108 val_loss = hist.history['val_loss'] 109 110 # lossのグラフ 111 plt.plot(range(epochs), loss, marker='.', label='loss') 112 plt.plot(range(epochs), val_loss, marker='.', label='val_loss') 113 plt.legend(loc='best', fontsize=10) 114 plt.grid() 115 plt.xlabel('epoch') 116 plt.ylabel('loss') 117 plt.show() 118 119 120 121 122 acc = hist.history['acc'] 123 val_acc = hist.history['val_acc'] 124 125 # accuracyのグラフ 126 plt.plot(range(epochs), acc, marker='.', label='acc') 127 plt.plot(range(epochs), val_acc, marker='.', label='val_acc') 128 plt.legend(loc='best', fontsize=10) 129 plt.grid() 130 plt.xlabel('epoch') 131 plt.ylabel('acc') 132 plt.show() 133 134 135 136 if NGphoto: 137 pre = model.predict(X_test) 138 for i,v in enumerate(pre): 139 pre_ans = v.argmax() 140 ans = y_test[i].argmax() 141 dat = X_test[i] 142 categories = [ "0","1", "2" , "3", "4" , "5", "6" , "7", "8" , "9"] 143 if ans == pre_ans: continue 144 fname = "NG_photo/" + str(i) + "-" + categories[pre_ans] + \ 145 "-ne-" + categories[ans] + ".png" 146 dat *= 255 147 img = Image.fromarray(dat.reshape((28,28))).convert("RGB") 148 img.save(fname) 149 150 if plotmodel: 151 keras.utils.plot_model(model, to_file='keras_MNIST_powerup_model.svg', show_shapes=True) 152 153 if modelsave: 154 now = datetime.datetime.now() 155 model_name = 'keras_MNIST_powerup'+now.strftime('%Y%m%d_%H%M%S') 156 model.save('../models/'+model_name+'.h5', include_optimizer=False) 157 158if __name__ == '__main__': 159 epochs = 10 160 activation = 'relu' 161 optimizer = 'adam' 162 batch_size = 100 163 main(epochs, batch_size, activation, optimizer, 164 NGphoto=False, realtimeplot=False, 165 csvlog=True, plotmodel=True, modelsave=True) 166 167 print('今回のモデル') 168 print("model_name : " + model_name) 169 print("batch_size : " + str(batch_size)) 170 print("epochs : " + str(epochs)) 171 print("activation : " + activation) 172 print("optimizer : " + optimizer)

初心者で汚いコードかと思いますが、ご指導よろしくお願いします。

気になる質問をクリップする

クリップした質問は、後からいつでもMYページで確認できます。

またクリップした質問に回答があった際、通知やメールを受け取ることができます。

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

quickquip

2019/08/24 08:39 編集

貼り付けるコードを間違っていませんか。うまくいってないコードを貼って、なにがうまくいってないのかも書きましょう。
shunx2_1003

2019/08/24 13:21

このコードに、データ拡張の処理を追加しようと思っているのですが、ImageDataGeneratorクラスの使い方がわからず、何も書いていない状態です。
quickquip

2019/08/24 13:29

「うまくいきません」と書いたら、そうとは伝わらないと思います。
shunx2_1003

2019/08/24 15:00

そうですね、ありがとうございます。修正しておきます。
guest

回答1

0

ベストアンサー

ImageGenerator がわからないのですが、ここを見ながら。

サンプル x (形状が (NumSamples, Height, Width, Channels) である4次元データ)

ということです。MNISTのデータは、例えばこのサイトなどに解説されています。

ヘッダ領域が 16 バイトで、その後に 28 * 28 バイトのピクセルデータが 60,000 画像分だけ続きます。

Channels は、入力層では色のためのバイト数と思っていいです。MNIST は256階調のグレースケールなので、1です。
NumSamples は、サンプルの数ですが、入力、出力のどっちなんだろう?

ということで、できますよね?

投稿2019/09/01 12:03

Q71

総合スコア995

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

あなたの回答

tips

太字

斜体

打ち消し線

見出し

引用テキストの挿入

コードの挿入

リンクの挿入

リストの挿入

番号リストの挿入

表の挿入

水平線の挿入

プレビュー

15分調べてもわからないことは
teratailで質問しよう!

ただいまの回答率
85.50%

質問をまとめることで
思考を整理して素早く解決

テンプレート機能で
簡単に質問をまとめる

質問する

関連した質問