teratail header banner
teratail header banner
質問するログイン新規登録

回答編集履歴

3

詳しく書いた

2021/04/09 05:19

投稿

BoKuToTuZenU
BoKuToTuZenU

スコア51

answer CHANGED
@@ -1,11 +1,67 @@
1
+ エラーメッセージは入力が4次元を想定しているのに、3次元の入力が入ってきた
2
+ という意味です。
3
+
1
- pytorchの場合、入力は(Batch_Size, Channel, Height, Width)です。
4
+ pytorchの場合、モデルに入力するテンソルの次元は(Batch_Size, Channel, Height, Width)の4次元です。
2
5
  (https://discuss.pytorch.org/t/dimensions-of-an-input-image/19439)
3
- ですが、入力が(Channel, Height, Width)となっているため、エラーが起きていると考えられます。
6
+ ですが、ご提示いただいているコードでは入力が(Channel, Height, Width)の3次元となっているため、エラーが起きていると考えられます。
7
+ つまり、入力のtensorに新しい次元を足せばよいということになります。
8
+
9
+ また、画像の訓練時のサイズとデプロイ時のサイズは必ずしも一致している必要なないかと思います。
10
+ SRGANなどの学習でも,(64x64)のパッチで学習を行います。
11
+ 推論時では、元の画像の大きさをそのまま入力するやり方を取っていたはずです。
12
+
13
+
14
+ そのため、ご提示いただいていたコードであるような、
15
+ ```python
16
+ #画像の読み込みと名前,拡張子の取得
17
+ os.chdir(input_dir)
18
+ apply_img = Image.open(n).convert("RGB")
19
+ img_name, img_ext = os.path.splitext(n)
20
+ print(img_name)
21
+
22
+ #画像サイズとクロップ数の計算部分
23
+ numX = apply_img.width // sample_img.width
24
+ numY = apply_img.height // sample_img.height
25
+ crop_imgs = []
26
+ out_imgs = []
27
+
28
+ #画像を分割
29
+ for i in range(numY):
30
+ for j in range(numX):
31
+ input_img = apply_img.crop((sample_img.width * j, sample_img.height*i,
32
+ sample_img.width * j + sample_img.width, sample_img.height * i + sample_img.height))
33
+ input_img_tensor = tv.transforms.ToTensor()(input_img)
34
+ crop_imgs.append(input_img_tensor)
35
+
36
+ #分割をモデルに適用
37
+ for m in crop_imgs:
38
+ prediction = model(m)
39
+ out_imgs.append(prediction)
40
+
4
- ため、入前に、
41
+ #モデル画像を繋げ
42
+ append_imgs = out_imgs
43
+ v_img = []
44
+ for y in range(numY):
45
+ u_img = []
46
+ for x in range(numX):
47
+ num = x + y * numX
48
+ u_img.append(append_imgs[num])
49
+
50
+ imgU = cv2.hconcat(u_img)
51
+ v_img.append(imgU)
52
+
53
+ append_img = cv2.vconcat(v_img)
54
+ append_img = Image.fromarray(np.unit8(append_img))
55
+ save_name = str(n)
56
+ save_name_dir = os.path.join(save_dir, save_name)
57
+ append_img.save(save_name_dir)
5
58
  ```
6
- C, W, H = input_img_tensor.size()
59
+ のような部分は
7
- input_img = input_img_tensor.reshape(1, C, W, H)
8
60
  ```
9
- などとしてみてはいかがでしょうか?
10
- また、画像の訓練時のサイズとデプロイ時のサイズは必ずしも一致している必要なないかと思います。
61
+ apply_img = Image.open(n).convert("RGB")
62
+ input_image_tensor = tv.transforms.ToTensor()(apply_img)
63
+ input_image_tensor = torch.unsqueeze(input_image_tensor, 0)
64
+ prediction = model(m)
65
+ ```
66
+ としても良いアウトプットを得られるかと思います。
11
67
  SRGANなどの学習でも,(64x64)のパッチで学習を行い、デプロイ時はパッチを利用しないで推論していたと思います。

2

ミスを修正しました。

2021/04/09 05:19

投稿

BoKuToTuZenU
BoKuToTuZenU

スコア51

answer CHANGED
@@ -3,8 +3,8 @@
3
3
  ですが、入力が(Channel, Height, Width)となっているため、エラーが起きていると考えられます。
4
4
  そのため、入力する前に、
5
5
  ```
6
- C, W, H = input_img.size()
6
+ C, W, H = input_img_tensor.size()
7
- input_img = input_img.reshape(1, C, W, H)
7
+ input_img = input_img_tensor.reshape(1, C, W, H)
8
8
  ```
9
9
  などとしてみてはいかがでしょうか?
10
10
  また、画像の訓練時のサイズとデプロイ時のサイズは必ずしも一致している必要なないかと思います。

1

ミスを修正しました。

2021/04/09 05:08

投稿

BoKuToTuZenU
BoKuToTuZenU

スコア51

answer CHANGED
@@ -1,5 +1,6 @@
1
- pytorchの場合、入力は(Batch_Size, Channel, Width, Height)です。
1
+ pytorchの場合、入力は(Batch_Size, Channel, Height, Width)です。
2
+ (https://discuss.pytorch.org/t/dimensions-of-an-input-image/19439)
2
- ですが、入力が(Channel, Width, Height)となっているため、エラーが起きていると考えられます。
3
+ ですが、入力が(Channel, Height, Width)となっているため、エラーが起きていると考えられます。
3
4
  そのため、入力する前に、
4
5
  ```
5
6
  C, W, H = input_img.size()