回答編集履歴
2
fix answer
test
CHANGED
@@ -14,6 +14,16 @@
|
|
14
14
|
from keras.preprocessing.image import ImageDataGenerator
|
15
15
|
from matplotlib import pyplot as plt
|
16
16
|
|
17
|
+
def mix_up(X1, y1, X2, y2, mix_up_alpha):
|
18
|
+
assert X1.shape[0] == y1.shape[0] == X2.shape[0] == y2.shape[0]
|
19
|
+
batch_size = X1.shape[0]
|
20
|
+
l = np.random.beta(mix_up_alpha, mix_up_alpha, batch_size)
|
21
|
+
X_l = l.reshape(batch_size, 1, 1, 1)
|
22
|
+
y_l = l.reshape(batch_size, 1)
|
23
|
+
X = X1 * X_l + X2 * (1 - X_l)
|
24
|
+
y = y1 * y_l + y2 * (1 - y_l)
|
25
|
+
return X, y
|
26
|
+
|
17
27
|
class MyImageDataGenerator(ImageDataGenerator):
|
18
28
|
def __init__(self, *args, **kwargs):
|
19
29
|
self.random_eraser = get_random_eraser()
|
@@ -23,21 +33,36 @@
|
|
23
33
|
x = self.random_eraser(x) # 先に処理する
|
24
34
|
return super().apply_transform(x, transform_parameters)
|
25
35
|
|
36
|
+
from keras.datasets import mnist
|
37
|
+
(X, y), (_, _) = mnist.load_data()
|
26
|
-
|
38
|
+
X = X[:, :, :, np.newaxis]
|
27
39
|
|
40
|
+
mix_up_gen = ImageDataGenerator(
|
41
|
+
rescale = 1/255,
|
42
|
+
).flow(X, y)
|
28
|
-
train_gen = MyImageDataGenerator(
|
43
|
+
train_gen = MyImageDataGenerator(
|
44
|
+
rescale = 1/255,
|
45
|
+
rotation_range = 90
|
46
|
+
).flow(X, y)
|
47
|
+
|
48
|
+
def my_flow(mix_up_gen, gen):
|
49
|
+
for (X1, y1), (X2, y2) in zip(mix_up_gen, gen):
|
50
|
+
yield mix_up(X1, y1, X2, y2, 0.5)
|
51
|
+
|
29
52
|
fig = plt.figure(figsize = (10, 7))
|
30
53
|
rows, cols = 3, 3
|
31
|
-
one_batch = next(train_gen
|
54
|
+
one_batch = next(my_flow(mix_up_gen, train_gen))
|
32
|
-
for i, img in enumerate(one_batch):
|
55
|
+
for i, (img, label) in enumerate(zip(one_batch[0], one_batch[1])):
|
33
56
|
if i >= rows * cols:
|
34
57
|
break
|
35
58
|
fig.add_subplot(rows, cols, i + 1)
|
36
59
|
img = img.reshape(28, 28)
|
37
|
-
plt.imshow(img, vmin = 0, vmax =
|
60
|
+
plt.imshow(img, vmin = 0, vmax = 1)
|
38
61
|
plt.axis("off")
|
39
62
|
```
|
40
63
|
|
41
|
-
![イメージ説明](https://ddjkaamml8q8x.cloudfront.net/questions/202
|
64
|
+
![イメージ説明](https://ddjkaamml8q8x.cloudfront.net/questions/2023-01-11/203bab8d-f5f5-4fbd-96ce-43070fce8826.png)
|
42
65
|
|
43
66
|
`apply_transform()`に要求されている引数は,[GitHubのコード](https://github.com/keras-team/keras/blob/v2.11.0/keras/preprocessing/image.py#L1983)からも分かるとおり,単一画像`x`とその変換方法`transform_parameter`なので,これが処理されてしまう前に,単一画像`x`を独自の方法で変換します.
|
67
|
+
|
68
|
+
また,MixUpをImageDataGeneratorによる変換前に挿入するのは[`Iterator`](https://github.com/keras-team/keras/blob/v2.11.0/keras/preprocessing/image.py#L59)の書き換えが必要であることがわかり現実的ではないので,従来通りジェネレータを定義するしかありません.
|
1
fix answer
test
CHANGED
@@ -20,9 +20,8 @@
|
|
20
20
|
super().__init__(*args, **kwargs)
|
21
21
|
|
22
22
|
def apply_transform(self, x, transform_parameters):
|
23
|
-
x = self.random_eraser(x)
|
23
|
+
x = self.random_eraser(x) # 先に処理する
|
24
|
-
super().apply_transform(x, transform_parameters)
|
24
|
+
return super().apply_transform(x, transform_parameters)
|
25
|
-
return x
|
26
25
|
|
27
26
|
# imgs にMNISTの画像をロードする処理
|
28
27
|
|
@@ -39,6 +38,6 @@
|
|
39
38
|
plt.axis("off")
|
40
39
|
```
|
41
40
|
|
42
|
-
![イメージ説明](https://ddjkaamml8q8x.cloudfront.net/questions/2022-12-26/
|
41
|
+
![イメージ説明](https://ddjkaamml8q8x.cloudfront.net/questions/2022-12-26/f570931d-a7c5-4802-b7b1-90371c038507.png)
|
43
42
|
|
44
43
|
`apply_transform()`に要求されている引数は,[GitHubのコード](https://github.com/keras-team/keras/blob/v2.11.0/keras/preprocessing/image.py#L1983)からも分かるとおり,単一画像`x`とその変換方法`transform_parameter`なので,これが処理されてしまう前に,単一画像`x`を独自の方法で変換します.
|