回答編集履歴

2

fix answer

2023/01/10 22:25

投稿

ps_aux_grep
ps_aux_grep

スコア1579

test CHANGED
@@ -14,6 +14,16 @@
14
14
  from keras.preprocessing.image import ImageDataGenerator
15
15
  from matplotlib import pyplot as plt
16
16
 
17
+ def mix_up(X1, y1, X2, y2, mix_up_alpha):
18
+ assert X1.shape[0] == y1.shape[0] == X2.shape[0] == y2.shape[0]
19
+ batch_size = X1.shape[0]
20
+ l = np.random.beta(mix_up_alpha, mix_up_alpha, batch_size)
21
+ X_l = l.reshape(batch_size, 1, 1, 1)
22
+ y_l = l.reshape(batch_size, 1)
23
+ X = X1 * X_l + X2 * (1 - X_l)
24
+ y = y1 * y_l + y2 * (1 - y_l)
25
+ return X, y
26
+
17
27
  class MyImageDataGenerator(ImageDataGenerator):
18
28
  def __init__(self, *args, **kwargs):
19
29
  self.random_eraser = get_random_eraser()
@@ -23,21 +33,36 @@
23
33
  x = self.random_eraser(x) # 先に処理する
24
34
  return super().apply_transform(x, transform_parameters)
25
35
 
36
+ from keras.datasets import mnist
37
+ (X, y), (_, _) = mnist.load_data()
26
- # imgs にMNISTの画像をロードする処理
38
+ X = X[:, :, :, np.newaxis]
27
39
 
40
+ mix_up_gen = ImageDataGenerator(
41
+ rescale = 1/255,
42
+ ).flow(X, y)
28
- train_gen = MyImageDataGenerator(rotation_range = 90)
43
+ train_gen = MyImageDataGenerator(
44
+ rescale = 1/255,
45
+ rotation_range = 90
46
+ ).flow(X, y)
47
+
48
+ def my_flow(mix_up_gen, gen):
49
+ for (X1, y1), (X2, y2) in zip(mix_up_gen, gen):
50
+ yield mix_up(X1, y1, X2, y2, 0.5)
51
+
29
52
  fig = plt.figure(figsize = (10, 7))
30
53
  rows, cols = 3, 3
31
- one_batch = next(train_gen.flow(imgs))
54
+ one_batch = next(my_flow(mix_up_gen, train_gen))
32
- for i, img in enumerate(one_batch):
55
+ for i, (img, label) in enumerate(zip(one_batch[0], one_batch[1])):
33
56
  if i >= rows * cols:
34
57
  break
35
58
  fig.add_subplot(rows, cols, i + 1)
36
59
  img = img.reshape(28, 28)
37
- plt.imshow(img, vmin = 0, vmax = 255)
60
+ plt.imshow(img, vmin = 0, vmax = 1)
38
61
  plt.axis("off")
39
62
  ```
40
63
 
41
- ![イメージ説明](https://ddjkaamml8q8x.cloudfront.net/questions/2022-12-26/f570931d-a7c5-4802-b7b1-90371c038507.png)
64
+ ![イメージ説明](https://ddjkaamml8q8x.cloudfront.net/questions/2023-01-11/203bab8d-f5f5-4fbd-96ce-43070fce8826.png)
42
65
 
43
66
  `apply_transform()`に要求されている引数は,[GitHubのコード](https://github.com/keras-team/keras/blob/v2.11.0/keras/preprocessing/image.py#L1983)からも分かるとおり,単一画像`x`とその変換方法`transform_parameter`なので,これが処理されてしまう前に,単一画像`x`を独自の方法で変換します.
67
+
68
+ また,MixUpをImageDataGeneratorによる変換前に挿入するのは[`Iterator`](https://github.com/keras-team/keras/blob/v2.11.0/keras/preprocessing/image.py#L59)の書き換えが必要であることがわかり現実的ではないので,従来通りジェネレータを定義するしかありません.

1

fix answer

2022/12/26 11:14

投稿

ps_aux_grep
ps_aux_grep

スコア1579

test CHANGED
@@ -20,9 +20,8 @@
20
20
  super().__init__(*args, **kwargs)
21
21
 
22
22
  def apply_transform(self, x, transform_parameters):
23
- x = self.random_eraser(x)
23
+ x = self.random_eraser(x) # 先に処理する
24
- super().apply_transform(x, transform_parameters)
24
+ return super().apply_transform(x, transform_parameters)
25
- return x
26
25
 
27
26
  # imgs にMNISTの画像をロードする処理
28
27
 
@@ -39,6 +38,6 @@
39
38
  plt.axis("off")
40
39
  ```
41
40
 
42
- ![イメージ説明](https://ddjkaamml8q8x.cloudfront.net/questions/2022-12-26/6cd7d49e-9c18-4890-96e3-0a89ff3f0dd0.png)
41
+ ![イメージ説明](https://ddjkaamml8q8x.cloudfront.net/questions/2022-12-26/f570931d-a7c5-4802-b7b1-90371c038507.png)
43
42
 
44
43
  `apply_transform()`に要求されている引数は,[GitHubのコード](https://github.com/keras-team/keras/blob/v2.11.0/keras/preprocessing/image.py#L1983)からも分かるとおり,単一画像`x`とその変換方法`transform_parameter`なので,これが処理されてしまう前に,単一画像`x`を独自の方法で変換します.