前提・実現したいこと
ResNet18による画像データの回帰ラベル予測を行っています。
データセットは256×256×3の画像が10万枚です。
回帰ラベルは4次元,10次元,15次元の3種類があり,
それぞれマルチタスクとして学習させています。
(データの詳細は都合上省略させてください。)
発生している問題
画像のように,MSEが学習中ほとんど変わっていません。
そもそも10万程度のデータ数でresnetを使用するのが間違っているのでしょうか?(パラメータ数が多すぎる?)
当方,ほとんどCNNを扱ったことがなく,何をトライすべきかご教示いただけると幸いです。
該当のソースコード
Python3
"""Multi-task regression with a ResNet-style CNN.

One convolutional backbone feeds three fully connected heads that regress
4-, 10- and 15-dimensional targets.  Reconstructed and cleaned up from a
collapsed paste; the dataset-construction bodies were omitted from the
original post and are stubbed out explicitly below.
"""
import os
import csv
import math
import multiprocessing
import datetime
from pathlib import Path
import time

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optimizers
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.distributed import DistributedSampler

# torchvision / PIL are only used by the dataset code whose body was omitted
# from the post; import them defensively so the model definition alone stays
# importable in environments without them.
try:
    import torchvision
    import torchvision.transforms as transforms
except ImportError:
    torchvision = transforms = None
try:
    from PIL import Image
except ImportError:
    Image = None


class ResNet18(nn.Module):
    """Shared convolutional backbone with three regression heads.

    NOTE(review): the residual-stage depths used here (3, 4, 6, 3) are the
    ResNet-34 layout, not ResNet-18's (2, 2, 2, 2), so the class name is
    misleading.  The 1x1 stride-2 downsampling convolutions between stages
    also sit outside the residual blocks and carry no BN/ReLU, which
    deviates from the standard ResNet design -- worth revisiting if the
    loss refuses to move.
    """

    def __init__(self, output_dim1, output_dim2, output_dim3):
        """output_dim1/2/3: dimensionality of the three regression outputs."""
        super().__init__()
        # Stem: 7x7 stride-2 conv + 3x3 stride-2 max-pool (standard ResNet).
        self.conv1 = nn.Conv2d(3, 64,
                               kernel_size=(7, 7),
                               stride=(2, 2),
                               padding=3)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(kernel_size=(3, 3),
                                  stride=(2, 2),
                                  padding=1)
        # Stage 1: 3 residual blocks @ 64 channels.
        self.block0 = self._building_block(64)
        self.block1 = nn.ModuleList([
            self._building_block(64) for _ in range(2)
        ])
        self.conv2 = nn.Conv2d(64, 128,
                               kernel_size=(1, 1),
                               stride=(2, 2))
        # Stage 2: 4 residual blocks @ 128 channels.
        self.block2 = nn.ModuleList([
            self._building_block(128) for _ in range(4)
        ])
        self.conv3 = nn.Conv2d(128, 256,
                               kernel_size=(1, 1),
                               stride=(2, 2))
        # Stage 3: 6 residual blocks @ 256 channels.
        self.block3 = nn.ModuleList([
            self._building_block(256) for _ in range(6)
        ])
        self.conv4 = nn.Conv2d(256, 512,
                               kernel_size=(1, 1),
                               stride=(2, 2))
        # Stage 4: 3 residual blocks @ 512 channels.
        self.block4 = nn.ModuleList([
            self._building_block(512) for _ in range(3)
        ])
        self.avg_pool = GlobalAvgPool2d()
        # Per-task feature layers on top of the pooled 512-d vector ...
        self.fc1 = nn.Linear(512, 200)
        self.fc2 = nn.Linear(512, 500)
        self.fc3 = nn.Linear(512, 500)
        # ... and the three regression heads.
        self.out1 = nn.Linear(200, output_dim1)
        nn.init.xavier_normal_(self.out1.weight)
        self.out2 = nn.Linear(500, output_dim2)
        nn.init.xavier_normal_(self.out2.weight)
        self.out3 = nn.Linear(500, output_dim3)
        nn.init.xavier_normal_(self.out3.weight)

    def forward(self, x):
        """Return the (dim1, dim2, dim3) regression outputs for image batch x."""
        h = self.conv1(x)
        h = self.bn1(h)
        h = self.relu1(h)
        h = self.pool1(h)
        h = self.block0(h)
        for block in self.block1:
            h = block(h)
        h = self.conv2(h)
        for block in self.block2:
            h = block(h)
        h = self.conv3(h)
        for block in self.block3:
            h = block(h)
        h = self.conv4(h)
        for block in self.block4:
            h = block(h)
        h = self.avg_pool(h)  # (B, 512)
        h1 = torch.relu(self.fc1(h))
        h2 = torch.relu(self.fc2(h))
        h3 = torch.relu(self.fc3(h))
        y1 = self.out1(h1)
        y2 = self.out2(h2)
        y3 = self.out3(h3)
        return y1, y2, y3

    def _building_block(self, channel_out, channel_in=None):
        # Same-channel residual block unless channel_in is given explicitly.
        if channel_in is None:
            channel_in = channel_out
        return Block(channel_in, channel_out)


class Block(nn.Module):
    """Basic residual block: conv-BN-ReLU-conv-BN plus identity/1x1 shortcut."""

    def __init__(self, channel_in, channel_out):
        super().__init__()
        self.conv1 = nn.Conv2d(channel_in, channel_out,
                               kernel_size=(3, 3),
                               padding=1)
        self.bn1 = nn.BatchNorm2d(channel_out)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(channel_out, channel_out,
                               kernel_size=(3, 3),
                               padding=1)
        self.bn2 = nn.BatchNorm2d(channel_out)
        self.shortcut = self._shortcut(channel_in, channel_out)
        self.relu3 = nn.ReLU()

    def forward(self, x):
        h = self.conv1(x)
        h = self.bn1(h)
        h = self.relu1(h)
        h = self.conv2(h)
        h = self.bn2(h)
        shortcut = self.shortcut(x)
        return self.relu3(h + shortcut)  # residual (skip) connection

    def _shortcut(self, channel_in, channel_out):
        # FIX: use nn.Identity() instead of a bare lambda -- a lambda is not
        # an nn.Module, so it broke torch.save(model)/copy.deepcopy and was
        # invisible to module introspection.  No state-dict keys change
        # (Identity has no parameters).
        if channel_in != channel_out:
            return self._projection(channel_in, channel_out)
        return nn.Identity()

    def _projection(self, channel_in, channel_out):
        # 1x1 projection to match channel counts on the shortcut path.
        return nn.Conv2d(channel_in, channel_out,
                         kernel_size=(1, 1),
                         padding=0)


class GlobalAvgPool2d(nn.Module):
    """Average each feature map to a single value: (B, C, H, W) -> (B, C)."""

    def __init__(self, device='cpu'):
        # `device` is unused; kept only for backward compatibility.
        super().__init__()

    def forward(self, x):
        return F.avg_pool2d(x, kernel_size=x.size()[2:]).view(-1, x.size(1))


class ImageDataset(Dataset):
    """Image + three-label dataset.

    The constructor body was omitted from the original post; it is expected
    to set self.paths (image file paths), self.labels (a sequence of three
    DataFrames) and self.transform (a torchvision transform pipeline) --
    TODO confirm against the real implementation.
    """

    def __init__(self):
        # Omitted in the original post -- fail loudly instead of silently
        # referencing attributes that were never set.
        raise NotImplementedError('dataset construction was omitted from the post')

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        path = self.paths[index]
        # FIX: force 3-channel RGB -- grayscale/RGBA/palette files would
        # otherwise crash (or silently corrupt) the 3-channel stem conv.
        img = Image.open(path).convert('RGB')
        label1 = self.labels[0].iloc[index].values
        label2 = self.labels[1].iloc[index].values
        label3 = self.labels[2].iloc[index].values
        return self.transform(img), label1, label2, label3


def set_data_src(NUM_CORES, batch_size, world_size, rank):
    """Build (train_dataloader, test_dataloader).

    Body omitted from the original post; the posted stub returned undefined
    names, so raise explicitly instead.
    """
    raise NotImplementedError('set_data_src body was omitted from the post')


if __name__ == '__main__':
    np.random.seed(1234)
    torch.manual_seed(1234)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    NUM_CORES = multiprocessing.cpu_count()
    world_size = torch.cuda.device_count()
    is_ddp = world_size > 1
    rank = 0
    batch_size = 32

    base_dir = './'
    results_dir = 'results'
    name = 'model_name'

    # Run directory tagged with a YYMMDDHHMM timestamp.
    dt_now = datetime.datetime.now().strftime('%Y%m%d%H%M')[2:]
    name = dt_now + '_' + name
    base_dir = Path(base_dir)
    (base_dir / results_dir / name).mkdir(parents=True, exist_ok=True)

    def model_name(num):
        return str(base_dir / results_dir / name / f'model_{num}.pt')

    def save_model(model, num):
        torch.save(model, model_name(num))

    def save_result(result):
        with open(str(base_dir / results_dir / name / 'result.csv'), 'w', encoding='Shift_jis') as f:
            writer = csv.writer(f, lineterminator='\n')
            writer.writerows(result)

    train_loader, test_loader = set_data_src(NUM_CORES, batch_size, world_size, rank)
    model = ResNet18(4, 10, 15).to(device)

    def compute_loss(label, pred):
        # Cast both sides to float32 (labels arrive as float64 numpy arrays).
        return criterion(pred.float(), label.float())

    def train_step(x, t1, t2, t3):
        model.train()
        preds = model(x)
        loss1 = compute_loss(t1, preds[0])
        loss2 = compute_loss(t2, preds[1])
        loss3 = compute_loss(t3, preds[2])
        optimizer.zero_grad()
        loss = loss1 + loss2 + loss3  # equal-weighted multi-task sum
        loss.backward()
        optimizer.step()
        return (loss1, loss2, loss3), preds

    def test_step(x, t1, t2, t3):
        model.eval()
        # FIX: no autograd graph during evaluation -- saves memory and time.
        with torch.no_grad():
            preds = model(x)
            loss1 = compute_loss(t1, preds[0])
            loss2 = compute_loss(t2, preds[1])
            loss3 = compute_loss(t3, preds[2])
        return (loss1, loss2, loss3), preds

    criterion = nn.MSELoss()
    # NOTE(review): weight_decay=0.01 is very strong for Adam and can hold
    # the loss flat (the symptom described in the question); try 1e-4 or 0
    # (or AdamW), and consider setting lr explicitly (default is 1e-3).
    optimizer = optimizers.Adam(model.parameters(), weight_decay=0.01)

    epochs = 100
    save_every = 10
    results = []
    time_start = time.time()
    for epoch in range(epochs):
        train_loss1 = 0.
        train_loss2 = 0.
        train_loss3 = 0.
        test_loss1 = 0.
        test_loss2 = 0.
        test_loss3 = 0.
        for (x, t1, t2, t3) in train_loader:
            x, t1, t2, t3 = x.to(device), t1.to(device), t2.to(device), t3.to(device)
            loss, _ = train_step(x, t1, t2, t3)
            train_loss1 += loss[0].item()
            train_loss2 += loss[1].item()
            train_loss3 += loss[2].item()
        train_loss1 /= len(train_loader)
        train_loss2 /= len(train_loader)
        train_loss3 /= len(train_loader)
        for (x, t1, t2, t3) in test_loader:
            x, t1, t2, t3 = x.to(device), t1.to(device), t2.to(device), t3.to(device)
            loss, _ = test_step(x, t1, t2, t3)
            test_loss1 += loss[0].item()
            test_loss2 += loss[1].item()
            test_loss3 += loss[2].item()
        test_loss1 /= len(test_loader)
        test_loss2 /= len(test_loader)
        test_loss3 /= len(test_loader)
        elapsed_time = time.time() - time_start
        # FIX: the printed values are mean-squared errors, not RMSE --
        # label them honestly.
        print('Epoch: {}, Train mse: {}, Test mse: {}, Elapsed time: {:.1f}sec'.format(
            epoch + 1,
            (train_loss1, train_loss2, train_loss3),
            (test_loss1, test_loss2, test_loss3),
            elapsed_time
        ))
        results.append([
            epoch + 1,
            train_loss1,
            train_loss2,
            train_loss3,
            test_loss1,
            test_loss2,
            test_loss3,
            elapsed_time
        ])
        if (epoch + 1) % save_every == 0:
            save_model(model.state_dict(), epoch + 1)
        save_result(results)
試したこと
batch_sizeの変更: 32,128
Adamのweight_decayの変更: 0.01, 0.0001
全結合層のノード数の変更: 1000, 1000, 1000⇒200, 500, 500
補足情報(FW/ツールのバージョンなど)
Python 3.8.1 64-bit
torch 1.6.0+cu101
torchvision 0.7.0+cu101
回答1件
あなたの回答
tips
プレビュー
バッドをするには、ログインかつ
こちらの条件を満たす必要があります。