MNIST で平均値を予測する回帰問題としてサンプルコードをかきました。
画像の位置関係は必要ないと思ったので、CNN は使っていません。
一応 mean() 関数を学習できました。
今回の平均値を計算するというタスクにおいては直接計算して終わる話なので実用性はないですが。
学習する。
python
1from keras.datasets import mnist
2from keras.layers import Activation, BatchNormalization, Dense
3from keras.models import Sequential
4import numpy as np
5
6# MNIST データを取得する。
7(x_train, _), (x_test, _) = mnist.load_data()
8
9# 各画像の平均値を計算する。
10y_train = np.mean(x_train, (1, 2))
11y_test = np.mean(x_test, (1, 2))
12
13print('x_train.shape', x_train.shape) # x_train.shape (60000, 28, 28)
14print('y_train.shape', y_train.shape) # y_train.shape (60000,)
15print('x_test.shape', x_test.shape) # x_test.shape (10000, 28, 28)
16print('y_test.shape', y_test.shape) # y_test.shape (10000,)
17
18# 1次元配列にする。 (28, 28) -> (784,) にする
19x_train = x_train.reshape(len(x_train), -1)
20x_test = x_test.reshape(len(x_test), -1)
21
22# モデルを作成する。
23model = Sequential()
24model.add(Dense(10, input_dim=784))
25model.add(BatchNormalization())
26model.add(Activation('relu'))
27model.add(Dense(10))
28model.add(BatchNormalization())
29model.add(Activation('relu'))
30model.add(Dense(1))
31model.add(BatchNormalization())
32model.compile(optimizer='adam', loss='mse')
33
34# 学習する。
35history = model.fit(x_train, y_train, epochs=100, batch_size=128,
36 validation_data=(x_test, y_test))
学習過程を可視化する。
python
1import matplotlib.pyplot as plt
2fig, axes = plt.subplots(figsize=(5, 5))
3
4epochs = np.arange(1, len(history.history['loss']) + 1)
5
6# 各エポックの誤差の推移
7axes.set_title('loss')
8axes.plot(epochs, history.history['loss'], label='train')
9axes.plot(epochs, history.history['val_loss'], label='validation')
10axes.set_xticks(epochs)
11axes.legend()
12
13plt.show()
テストデータでいくつか確認
だいたい合っている。
python
1y_pred = model.predict(x_test)
2
3for i, (pred, true) in enumerate(zip(y_pred[:10], y_test[:10])):
4 print('{}: prediction: {}, mean: {}'.format(i, pred, true))
# 0: prediction: [23.70105], mean: 23.538265306122447
# 1: prediction: [36.888763], mean: 36.798469387755105
# 2: prediction: [12.814066], mean: 12.590561224489797
# 3: prediction: [47.20095], mean: 47.21173469387755
# 4: prediction: [24.676441], mean: 24.536989795918366
# 5: prediction: [17.86291], mean: 17.67219387755102
# 6: prediction: [27.157784], mean: 27.020408163265305
# 7: prediction: [26.990688], mean: 26.864795918367346
# 8: prediction: [39.258457], mean: 39.201530612244895
# 9: prediction: [40.019882], mean: 39.98724489795919
バッドをするには、ログインかつ
こちらの条件を満たす必要があります。
2018/10/01 01:17
2018/10/01 16:31
2018/10/03 05:17
2018/10/03 05:22
2018/10/03 05:43