When I train the model, it runs out of GPU memory on the second iteration of epoch 1.
If there is anything wrong with how the model is built or run, I would appreciate your pointing it out.
Environment:
- Python: 3.8.8
- PyTorch: 1.6.0
- OS: Ubuntu 16.04
- GPU: TITAN X (Pascal)
- CUDA: 10.0
Below is the main code.
```python
# ~~~~~ (omitted) ~~~~~
def main(config):
    torch.backends.cudnn.enabled = False
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = networks.UNet(
        in_channels=config.in_c,
        out_channels=config.out_c,
        depth=config.depth,
        conv_num=config.conv_num,
        wf=config.wf,
        padding=True,
        batch_norm=True,
        up_mode="upsample",
        with_tanh=False,
        sync_bn=True,
        antialiasing=True,
    )

    ## load model
    if not config.load_dir == " ":
        checkpoint_path = os.path.join(os.path.dirname(__file__), "checkpoints/detection/FT_Epoch_latest.pt")
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        model.load_state_dict(checkpoint["model_state"])
        print("model weights loaded")

    model.to(device)
    model.train()

    ## dataloader and transformation
    ds = mask_seg_Dataset()
    dataloader = torch.utils.data.DataLoader(ds, batch_size=config.bs,
                                             shuffle=True, num_workers=0, drop_last=True)

    # ~~~~~ (snip) ~~~~~

    # mkdir_if_not(blend_output_dir)
    optimizer = torch.optim.Adam(params=model.parameters(), lr=0.00002)
    focal_error = []
    ce_error = []
    results = []
    epochs = config.epochs
    Sigmoid = torch.nn.Sigmoid()
    ce = nn.BCEWithLogitsLoss()

    for epoch in range(epochs):
        for i, images in enumerate(dataloader):
            optimizer.zero_grad()
            print("---------------------")
            label = images['label'].to(device)
            image = images['image'].to(device)
            # a = torch.reshape(label, (2, 256, -1))

            fake_mask = Sigmoid(model(image))
            print("after model")
            # focal_loss = loss_func(fake_mask, a)
            BCE_loss = ce(fake_mask, label)
            # focal_error.append(focal_loss)
            ce_error.append(BCE_loss)
            # loss = focal_loss + BCE_loss
            print("after error")
            BCE_loss.backward()
            optimizer.step()

    # ~~~~~ (snip) ~~~~~


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # ~~~~~ (rest omitted) ~~~~~
```
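To pin down whether allocation really grows from one iteration to the next (rather than a single batch being too large from the start), GPU memory can be printed inside the training loop. Below is a minimal sketch using `torch.cuda.memory_allocated` and `torch.cuda.max_memory_allocated`, both available in PyTorch 1.6; the helper name `log_gpu_memory` is made up for illustration.

```python
import torch

def log_gpu_memory(tag):
    # Bytes currently held by live tensors vs. the peak so far.
    # "allocated" climbing steadily across iterations suggests that
    # something keeps references to earlier iterations' tensors.
    allocated = torch.cuda.memory_allocated() / 1024 ** 2
    peak = torch.cuda.max_memory_allocated() / 1024 ** 2
    print(f"[{tag}] allocated: {allocated:.1f} MiB / peak: {peak:.1f} MiB")
```

Calling `log_gpu_memory(f"epoch {epoch} iter {i}")` at the top of the inner loop would show whether the ~11 GiB in the traceback below builds up gradually or appears all at once.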
Below is the network code.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from detection_models.sync_batchnorm import DataParallelWithCallback
from detection_models.antialiasing import Downsample
from torch.autograd import Variable


class UNet(nn.Module):
    def __init__(
        self,
        in_channels=3,
        out_channels=3,
        depth=5,
        conv_num=2,
        wf=6,
        padding=True,
        batch_norm=True,
        up_mode="upsample",
        with_tanh=False,
        sync_bn=True,
        antialiasing=True,
    ):
        super().__init__()
        assert up_mode in ("upconv", "upsample")
        self.padding = padding
        self.depth = depth - 1
        prev_channels = in_channels

        self.first = nn.Sequential(
            *[nn.ReflectionPad2d(3), nn.Conv2d(in_channels, 2 ** wf, kernel_size=7), nn.LeakyReLU(0.2, True)]
        )
        prev_channels = 2 ** wf

        self.down_path = nn.ModuleList()
        self.down_sample = nn.ModuleList()
        for i in range(depth):
            if antialiasing and depth > 0:
                self.down_sample.append(
                    nn.Sequential(
                        *[
                            nn.ReflectionPad2d(1),
                            nn.Conv2d(prev_channels, prev_channels, kernel_size=3, stride=1, padding=0),
                            nn.BatchNorm2d(prev_channels),
                            nn.LeakyReLU(0.2, True),
                            Downsample(channels=prev_channels, stride=2),
                        ]
                    )
                )
            else:
                self.down_sample.append(
                    nn.Sequential(
                        *[
                            nn.ReflectionPad2d(1),
                            nn.Conv2d(prev_channels, prev_channels, kernel_size=4, stride=2, padding=0),
                            nn.BatchNorm2d(prev_channels),
                            nn.LeakyReLU(0.2, True),
                        ]
                    )
                )
            self.down_path.append(
                UNetConvBlock(conv_num, prev_channels, 2 ** (wf + i + 1), padding, batch_norm)
            )
            prev_channels = 2 ** (wf + i + 1)

        self.up_path = nn.ModuleList()
        for i in reversed(range(depth)):
            self.up_path.append(
                UNetUpBlock(conv_num, prev_channels, 2 ** (wf + i), up_mode, padding, batch_norm)
            )
            prev_channels = 2 ** (wf + i)

        if with_tanh:
            self.last = nn.Sequential(
                *[nn.ReflectionPad2d(1), nn.Conv2d(prev_channels, out_channels, kernel_size=3), nn.Tanh()]
            )
        else:
            self.last = nn.Sequential(
                *[nn.ReflectionPad2d(1), nn.Conv2d(prev_channels, out_channels, kernel_size=3)]
            )

        if sync_bn:
            self = DataParallelWithCallback(self)

    def forward(self, x):
        x = self.first(x)
        blocks = []
        for i, down_block in enumerate(self.down_path):
            blocks.append(x)
            x = self.down_sample[i](x)
            x = down_block(x)
        for i, up in enumerate(self.up_path):
            x = up(x, blocks[-i - 1])
        return self.last(x)


class UNetConvBlock(nn.Module):
    def __init__(self, conv_num, in_size, out_size, padding, batch_norm):
        super(UNetConvBlock, self).__init__()
        block = []

        for _ in range(conv_num):
            block.append(nn.ReflectionPad2d(padding=int(padding)))
            block.append(nn.Conv2d(in_size, out_size, kernel_size=3, padding=0))
            if batch_norm:
                block.append(nn.BatchNorm2d(out_size))
            block.append(nn.LeakyReLU(0.2, True))
            in_size = out_size

        self.block = nn.Sequential(*block)

    def forward(self, x):
        out = self.block(x)
        return out


class UNetUpBlock(nn.Module):
    def __init__(self, conv_num, in_size, out_size, up_mode, padding, batch_norm):
        super(UNetUpBlock, self).__init__()
        if up_mode == "upconv":
            self.up = nn.ConvTranspose2d(in_size, out_size, kernel_size=2, stride=2)
        elif up_mode == "upsample":
            self.up = nn.Sequential(
                nn.Upsample(mode="bilinear", scale_factor=2, align_corners=False),
                nn.ReflectionPad2d(1),
                nn.Conv2d(in_size, out_size, kernel_size=3, padding=0),
            )

        self.conv_block = UNetConvBlock(conv_num, in_size, out_size, padding, batch_norm)

    def center_crop(self, layer, target_size):
        _, _, layer_height, layer_width = layer.size()
        diff_y = (layer_height - target_size[0]) // 2
        diff_x = (layer_width - target_size[1]) // 2
        return layer[:, :, diff_y : (diff_y + target_size[0]), diff_x : (diff_x + target_size[1])]

    def forward(self, x, bridge):
        up = self.up(x)
        crop1 = self.center_crop(bridge, up.shape[2:])
        out = torch.cat([up, crop1], 1)
        out = self.conv_block(out)

        return out


if __name__ == "__main__":
    from torchsummary import summary

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = UNet_two_decoders(
        in_channels=3,
        out_channels1=3,
        out_channels2=1,
        depth=4,
        conv_num=1,
        wf=6,
        padding=True,
        batch_norm=True,
        up_mode="upsample",
        with_tanh=False,
    )
    model.to(device)

    model_pix2pix = UnetGenerator(3, 3, 5, ngf=64, norm_type="BN", use_dropout=False)
    model_pix2pix.to(device)

    print("customized unet:")
    summary(model, (3, 256, 256))

    print("cyclegan unet:")
    summary(model_pix2pix, (3, 256, 256))

    x = torch.zeros(1, 3, 256, 256).requires_grad_(True).cuda()
    g = make_dot(model(x))
    g.render("models/Digraph.gv", view=False)
```
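One point worth double-checking in `UNet.__init__`: the line `self = DataParallelWithCallback(self)` only rebinds the local variable `self` inside the constructor, so the caller still receives the unwrapped module and `sync_bn=True` effectively does nothing. If synchronized BatchNorm is actually wanted, the wrapping is normally done at the call site, as in this sketch (assuming `DataParallelWithCallback` wraps a module the way `nn.DataParallel` does; the surrounding names follow the main code above):

```python
from detection_models.sync_batchnorm import DataParallelWithCallback

model = networks.UNet(in_channels=config.in_c, out_channels=config.out_c, sync_bn=False)
# Wrap after construction; rebinding `self` inside __init__ never
# replaces the object the caller receives.
model = DataParallelWithCallback(model)
model.to(device)
```

This is unrelated to the out-of-memory error itself, but it does fall under "anything wrong with the model construction".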
Below is the error.
```
Traceback (most recent call last):
  File "./main.py", line 199, in <module>
    main(config)
  File "./main.py", line 140, in main
    fake_mask = Sigmoid(model(image))
  File "/home/anaconda3/envs/ref_rem/lib/python3.8/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/UNet_segment/detection_models/networks.py", line 120, in forward
    x = up(x, blocks[-i-1])
  ~~~~~ (snip) ~~~~~
RuntimeError: CUDA out of memory. Tried to allocate 128.00 MiB (GPU 0; 11.91 GiB total capacity; 11.20 GiB already allocated; 92.56 MiB free; 11.23 GiB reserved in total by PyTorch)
```
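The message shows 11.20 GiB already allocated when only another 128 MiB is requested, i.e. most of the memory is held by something that survived the first iteration. One pattern in the training loop that is known to cause exactly this in PyTorch is `ce_error.append(BCE_loss)`: storing the loss tensor itself keeps its entire autograd graph (and all intermediate activations) alive, so every iteration adds a full graph that is never freed. Below is a minimal sketch of the same loop logging only the scalar value instead; as a separate fix, `nn.BCEWithLogitsLoss` already applies the sigmoid internally, so it is given the raw model output here.

```python
for epoch in range(epochs):
    for i, images in enumerate(dataloader):
        optimizer.zero_grad()
        label = images['label'].to(device)
        image = images['image'].to(device)

        # BCEWithLogitsLoss combines sigmoid + BCE, so pass raw logits;
        # Sigmoid(model(image)) would apply the sigmoid twice.
        logits = model(image)
        BCE_loss = ce(logits, label)

        # .item() returns a plain Python float detached from the graph,
        # so each iteration's graph and activations can be freed.
        ce_error.append(BCE_loss.item())

        BCE_loss.backward()
        optimizer.step()
```

Whether the stored losses alone account for the full 11 GiB here is not certain, but this accumulation pattern is the usual first suspect when memory runs out on the second iteration rather than the first.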