Mnistのデータを利用していますが、途中でバッチ数が足りなくなってしまいます。当然端数は出ますが、普通は自動でそれが補われると思うのですが。
そういう仕様です。
Pytorch の DataLoader でも Keras の ImageDataGenerator でもサンプル数がバッチサイズで割り切れない場合は、最後は端数になります。
訓練データの各サンプルを1回ずつ学習に回したのを1エポックと定義するので、このような仕様になっていることは自然です。
Keras の ImageDataGenerator の場合
python
1import numpy as np
2from tensorflow.keras.preprocessing import image
3
4x = np.zeros((50, 100, 100, 3), dtype=np.float32)
5
6datagen = image.ImageDataGenerator()
7gen = datagen.flow(x, batch_size=16)
8
9for i in range(4):
10 print(next(gen).shape)
11# (16, 100, 100, 3)
12# (16, 100, 100, 3)
13# (16, 100, 100, 3)
14# (2, 100, 100, 3)
Pytorch の DataLoader の場合
python
1import numpy as np
2import torch
3
4
5class MyDataset(torch.utils.data.Dataset):
6 def __init__(self, x):
7 self.x = x
8
9 def __len__(self):
10 return len(self.x)
11
12 def __getitem__(self, i):
13 return self.x[i]
14
15
16x = np.zeros((50, 100, 100, 3), dtype=np.float32)
17
18dataset = MyDataset(x)
19dataloader = torch.utils.data.DataLoader(dataset, batch_size=16)
20
21
22for x in dataloader:
23 print(x.shape)
24# torch.Size([16, 100, 100, 3])
25# torch.Size([16, 100, 100, 3])
26# torch.Size([16, 100, 100, 3])
27# torch.Size([2, 100, 100, 3])
468イテレーションで、バッチサイズが96になるので、後続の処理でエラーとなります。このようにならない方法を教えてください。
Pytorch のモデルでは、任意のバッチサイズでも流せるようになっているので、問題にならないはずです。
質問欄に記載がないので、エラーの起こっている部分はわかりませんが、自分で書いたコードでバッチサイズが固定であることが前提の部分があるのであれば、その部分は任意のバッチサイズで対応できるように修正するとよいと思います。
バッドをするには、ログインかつ
こちらの条件を満たす必要があります。
2020/05/09 15:09