実現したいこと
・作成したコードが学習できるようにしたい
前提
Pytorchでcifar10の画像を4枚つなげた1枚の画像のマルチラベル分類をするコードを作っています。
発生している問題・エラーメッセージ
実行すると、はじめから損失が低く出てしまっていて、学習している様子がありません。
実際に正解率も低いです。
試したこと
損失関数や評価関数、モデルを見直しましたがうまくいきません
補足情報(FW/ツールのバージョンなど)
python3.10.10
該当のソースコード
・net.py
python
1import torch 2import torch.nn as nn 3import torchvision 4import torch.nn.functional as F 5 6 7class MyResNet50(nn.Module): 8 def __init__(self, my_pretrained_model): 9 super(MyResNet50, self).__init__() 10 self.pretrained = my_pretrained_model 11 num_classes = 10 # マルチラベル分類のクラス数 12 13 # ドロップアウト率を設定 14 dropout_rate = 0.5 15 self.my_new_layers = nn.Sequential( 16 nn.Linear(1000, 4096), # 例として追加の全結合層を挿入 17 nn.Dropout(p=dropout_rate), 18 nn.ReLU(inplace=True), 19 nn.Dropout(p=dropout_rate), 20 nn.Linear(4096, num_classes), 21 nn.Sigmoid() # 最終層にSigmoid関数を追加 22 ) 23 24 def forward(self, x): 25 x = self.pretrained(x) 26 x = self.my_new_layers(x) 27 return x
・コード本体
python
1import torch 2import torch.nn as nn 3import os 4import torch.optim as optim 5import torchvision 6import numpy as np 7import matplotlib.pyplot as plt 8import torchvision.models as models 9from torch.utils.data import Dataset 10import torchvision.transforms as transforms 11import torch.optim.lr_scheduler as lr_scheduler 12import torch.nn.functional as F 13import pdb 14import wandb 15from net import MyResNet50 16 17 18# start a new wandb run to track this script 19wandb.init( 20 # set the wandb project where this run will be logged 21 project="multi_label_classification4", 22 23 # track hyperparameters and run metadata 24 config={ 25 "learning_rate": 0.001, 26 "architecture": "CNN", 27 "dataset": "CIFAR-10", 28 "epochs": 10, 29 } 30) 31def combine_single_image(X, y): 32 out_image = np.zeros((64, 64, 3), dtype=np.uint8) 33 out_label = np.zeros(10, dtype=np.float32) 34 category_idx = np.random.permutation(np.arange(10))[:4] 35 36 for i, category in enumerate(category_idx): 37 filters = np.where(y==category)[0] 38 item_idx = np.random.permutation(filters)[i] 39 if i == 0: 40 out_image[:32, :32, :] = X[item_idx, :, :, :] 41 elif i == 1: 42 out_image[:32, 32:,:] = X[item_idx, :, :, :] 43 elif i == 2: 44 out_image[32:, :32, :] = X[item_idx, :, :, :] 45 else: 46 out_image[32:, 32:, :] = X[item_idx, :, :, :] 47 out_label[category] = 1.0 48 return out_image, out_label 49 50def create_data(X, y, num): 51 images=[] 52 labels=[] 53 for i in range(num): 54 X_item, y_item = combine_single_image(X, y) 55 images.append(X_item) 56 labels.append(y_item) 57 return (images,labels) 58 59train = torchvision.datasets.CIFAR10(root='./data/cifar10', 60 train=True, 61 download=True, 62 ) 63test = torchvision.datasets.CIFAR10(root='./data/cifar10', 64 train=False, 65 download=True, 66 ) 67 68X_train = np.zeros((len(train),32,32,3),dtype=np.uint8) 69y_train = np.zeros(len(train),dtype=np.int32) 70for i in range(len(train)): 71 img, label = train[i] 72 X_train[i] = np.array(img) 73 y_train[i] = label 74train_data = create_data(X_train, y_train, 500) 75 76 77X_test = np.zeros((len(test),32,32,3),dtype=np.uint8) 78y_test = np.zeros(len(test),dtype=np.int32) 79for i in range(len(test)): 80 img, label = test[i] 81 X_test[i] = np.array(img) 82 y_test[i] = label 83test_data = create_data(X_test, y_test, 100) 84 85transform = transforms.Compose([ 86 transforms.ToPILImage(), 87 transforms.ToTensor(), # Tensorに変換 88 transforms.Resize((224, 224)), # 画像サイズの変更 89 transforms.RandomHorizontalFlip(), # ランダムな水平反転 90 transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), # 正規化 91 transforms.RandomErasing(p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0, inplace=False) # ランダム化2 92]) 93 94class MultiLabelTrainDataset(Dataset): 95 def __init__(self, train_data, transform=transform): 96 self.images = train_data[0] 97 self.labels = train_data[1] 98 self.transform = transform 99 100 def __getitem__(self, idx): 101 image = self.images[idx] 102 label = self.labels[idx] 103 104 if self.transform is not None: 105 image = self.transform(image) 106 107 return (image, label) 108 109 def __len__(self): 110 return len(self.images) 111 112class MultiLabelTestDataset(Dataset): 113 def __init__(self, test_data, transform=transform): 114 self.images = test_data[0] 115 self.labels = test_data[1] 116 self.transform = transform 117 118 def __getitem__(self, idx): 119 image = self.images[idx] 120 label = self.labels[idx] 121 122 if self.transform: 123 image = self.transform(image) 124 125 return (image, label) 126 127 def __len__(self): 128 return len(self.images) 129 130train_data=train_data 131test_data=test_data 132 133train_dataset = MultiLabelTrainDataset(train_data) 134test_dataset = MultiLabelTestDataset(test_data) 135 136train_dataloader =torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True) 137test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=32,shuffle=False) 138 139def binary_loss(y_true, y_pred): 140 bce = F.binary_cross_entropy(y_pred, y_true) 141 return torch.sum(bce, dim=-1) 142 143# 完全正解の評価 144def evaluate_exact(y_true, y_pred): 145 y_pred = torch.where(y_pred>0.5, 1, 0) 146 correct = (y_true == y_pred).all(dim=1).sum().item() 147 total = y_true.size(0) 148 accuracy = correct / total 149 return accuracy 150 151# 部分正解の評価 152def evaluate_partial(y_true, y_pred): 153 y_pred = torch.where(y_pred>0.5, 1, 0) 154 correct = (y_true == y_pred).sum().item() 155 total = y_true.size(0) * y_true.size(1) # データポイントの総数 × ラベルの数 156 accuracy = correct / total # 正解率をパーセンテージ(%)で表現 157 return accuracy 158 159 160pretrained = torchvision.models.resnet50(pretrained=True) 161device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 162model = MyResNet50(my_pretrained_model=pretrained) 163model = model.to(device) 164 165 166# オプティマイザの定義 167optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=0.0001) 168 169#スケジューラーの定義 170scheduler = lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1) 171 172# モデルのトレーニング 173num_epochs = 5 # トレーニングのエポック数 174 175total_step = len(train_dataloader) 176for epoch in range(num_epochs): 177 model.train() 178 exact_accuracy = 0 179 partial_correct = 0 180 total = 0 181 182 for ind, (images, labels) in enumerate(train_dataloader): 183 images = images.to(device) 184 labels = labels.to(device) 185 optimizer.zero_grad() 186 187 outputs = model(images) 188 189 loss = binary_loss(labels, outputs) 190 loss.backward() 191 optimizer.step() 192 scheduler.step() 193 loss.item() 194 #import pdb;pdb.set_trace() 195 # 完全正解の評価 196 exact_accuracy += evaluate_exact(labels, outputs) 197 # 部分正解の評価 198 partial_correct += evaluate_partial(labels, outputs) 199 total += labels.size(0) 200 201 epoch_loss = loss.item() / total 202 epoch_exact_accuracy = exact_accuracy / total 203 epoch_partial_accuracy = partial_correct / total 204 # log metrics to wandb 205 wandb.log({"exact_acc": epoch_exact_accuracy, "partial_acc": epoch_partial_accuracy, "loss": epoch_loss}) 206 207 print(f"Epoch {epoch+1}/{num_epochs} - Loss: {epoch_loss:.4f} - Exact Accuracy: {epoch_exact_accuracy*100:.2f}% - Partial Accuracy: {epoch_partial_accuracy*100:.2f}%") 208 209 # Save weights 210torch.save(model.state_dict(), './out/model.ckpt') 211 212model.eval() 213test_exact_accuracy = 0 214test_partial_correct = 0 215test_total = 0 216 217with torch.no_grad(): 218 for test_inputs, test_labels in test_dataloader: 219 test_inputs = test_inputs.to(device) 220 test_labels = test_labels.to(device) 221 test_outputs = model(test_inputs) 222 223 224 # 完全正解の評価 225 test_exact_accuracy += evaluate_exact(test_labels, test_outputs) 226 # 部分正解の評価 227 test_partial_correct += evaluate_partial(test_labels, test_outputs) 228 229 test_total = len(test_dataloader) 230 231 test_exact_accuracy = test_exact_accuracy / test_total 232 test_partial_accuracy = test_partial_correct / test_total 233 234 # log metrics to wandb 235 wandb.log({"exact_acc(test)": test_exact_accuracy, "partial_acc(test)": test_partial_accuracy}) 236 237 238 print(f"Exact Accuracy: {test_exact_accuracy*100:.2f}% - Partial Accuracy: {test_partial_accuracy*100:.2f}%") 239 240# [optional] finish the wandb run, necessary in notebooks 241wandb.finish()
あなたの回答
tips
プレビュー