質問編集履歴
1
該当ソースコードのところに全体のコードを追加しました。
test
CHANGED
File without changes
|
test
CHANGED
@@ -2,6 +2,20 @@
|
|
2
2
|
|
3
3
|
convLSTMを用いた画像予測を行っています。
|
4
4
|
|
5
|
+
|
6
|
+
|
7
|
+
%matplotlib inline
|
8
|
+
|
9
|
+
i=0
|
10
|
+
|
11
|
+
fig, axes = plt.subplots(1, 2, figsize=(12,6))
|
12
|
+
|
13
|
+
axes[0].imshow((y_test[i]+1)/2)
|
14
|
+
|
15
|
+
axes[1].imshow((model.predict(x_test[[i]]).reshape(100,180,3)+1)/2)
|
16
|
+
|
17
|
+
|
18
|
+
|
5
19
|
i=0のところを0以外の数値にするとエラーが出ます。
|
6
20
|
|
7
21
|
|
@@ -26,11 +40,145 @@
|
|
26
40
|
|
27
41
|
```python
|
28
42
|
|
43
|
+
|
44
|
+
|
45
|
+
import numpy as np
|
46
|
+
|
47
|
+
import matplotlib.pyplot as plt
|
48
|
+
|
49
|
+
import pandas as pd
|
50
|
+
|
51
|
+
import seaborn as sns
|
52
|
+
|
53
|
+
from sklearn.model_selection import train_test_split
|
54
|
+
|
55
|
+
import glob
|
56
|
+
|
57
|
+
from PIL import Image
|
58
|
+
|
59
|
+
from tqdm import tqdm
|
60
|
+
|
61
|
+
import zipfile
|
62
|
+
|
63
|
+
import io
|
64
|
+
|
65
|
+
|
66
|
+
|
67
|
+
|
68
|
+
|
69
|
+
|
70
|
+
|
71
|
+
# 縮小後の画像サイズ
|
72
|
+
|
73
|
+
height = 100
|
74
|
+
|
75
|
+
width = 180
|
76
|
+
|
77
|
+
|
78
|
+
|
79
|
+
# 読み込んだ画像を入れる配列
|
80
|
+
|
81
|
+
imgs=np.empty((0, height, width, 3))
|
82
|
+
|
83
|
+
|
84
|
+
|
85
|
+
# zipファイルを読み込んでnumpy配列にする
|
86
|
+
|
87
|
+
zip_f = zipfile.ZipFile('drive/My Drive/Colab Notebooks/convLSTM/wide.zip')
|
88
|
+
|
89
|
+
for name in tqdm(zip_f.namelist()):
|
90
|
+
|
91
|
+
with zip_f.open(name) as file:
|
92
|
+
|
93
|
+
path = io.BytesIO(file.read()) #解凍
|
94
|
+
|
95
|
+
img = Image.open(path)
|
96
|
+
|
97
|
+
img = img.resize((width, height))
|
98
|
+
|
99
|
+
img_np = np.array(img).reshape(1, height, width, 3)
|
100
|
+
|
101
|
+
imgs = np.append(imgs, img_np, axis=0)
|
102
|
+
|
103
|
+
|
104
|
+
|
105
|
+
|
106
|
+
|
107
|
+
# 時系列で学習できる形式に整える
|
108
|
+
|
109
|
+
n_seq = 5
|
110
|
+
|
111
|
+
n_sample = imgs.shape[0] - n_seq
|
112
|
+
|
113
|
+
|
114
|
+
|
115
|
+
x = np.zeros((n_sample, n_seq, height, width, 3))
|
116
|
+
|
117
|
+
y = np.zeros((n_sample, height, width, 3))
|
118
|
+
|
119
|
+
for i in range(n_sample):
|
120
|
+
|
121
|
+
x[i] = imgs[i:i+n_seq]
|
122
|
+
|
123
|
+
y[i] = imgs[i+n_seq]
|
124
|
+
|
125
|
+
x, y = (x-128)/128, (y-128)/128
|
126
|
+
|
127
|
+
|
128
|
+
|
129
|
+
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size = 0.1, shuffle = False)
|
130
|
+
|
131
|
+
|
132
|
+
|
133
|
+
|
134
|
+
|
135
|
+
from keras import layers
|
136
|
+
|
137
|
+
from keras.layers.core import Activation
|
138
|
+
|
139
|
+
from tensorflow.keras.models import Model
|
140
|
+
|
141
|
+
|
142
|
+
|
143
|
+
inputs = layers.Input(shape=(5, height, width, 3))
|
144
|
+
|
145
|
+
x0 = layers.ConvLSTM2D(filters=16, kernel_size=(3,3), padding="same", return_sequences=True, data_format="channels_last")(inputs)
|
146
|
+
|
147
|
+
x0 = layers.BatchNormalization(momentum=0.6)(x0)
|
148
|
+
|
149
|
+
x0 = layers.ConvLSTM2D(filters=16, kernel_size=(3,3), padding="same", return_sequences=True, data_format="channels_last")(x0)
|
150
|
+
|
151
|
+
x0 = layers.BatchNormalization(momentum=0.8)(x0)
|
152
|
+
|
153
|
+
|
154
|
+
|
155
|
+
x0 = layers.ConvLSTM2D(filters=3, kernel_size=(3,3), padding="same", return_sequences=False, data_format="channels_last")(x0)
|
156
|
+
|
157
|
+
out = Activation('tanh')(x0)
|
158
|
+
|
159
|
+
model = Model(inputs=inputs, outputs=out)
|
160
|
+
|
161
|
+
model.summary()
|
162
|
+
|
163
|
+
|
164
|
+
|
165
|
+
|
166
|
+
|
167
|
+
model.compile(optimizer='rmsprop',
|
168
|
+
|
169
|
+
loss='mae', metrics=['mse'])
|
170
|
+
|
171
|
+
call_backs=[EarlyStopping(monitor="val_loss",patience=5)]
|
172
|
+
|
173
|
+
model.fit(x_train, y_train, batch_size=16, epochs=100, verbose=2, validation_split=0.2, shuffle=True, callbacks=call_backs)
|
174
|
+
|
175
|
+
|
176
|
+
|
29
177
|
# 描画
|
30
178
|
|
31
179
|
%matplotlib inline
|
32
180
|
|
33
|
-
i=0
|
181
|
+
i=10
|
34
182
|
|
35
183
|
fig, axes = plt.subplots(1, 2, figsize=(12,6))
|
36
184
|
|
@@ -38,6 +186,10 @@
|
|
38
186
|
|
39
187
|
axes[1].imshow((model.predict(x_test[[i]]).reshape(100,180,3)+1)/2)
|
40
188
|
|
189
|
+
|
190
|
+
|
191
|
+
|
192
|
+
|
41
193
|
```
|
42
194
|
|
43
195
|
|