質問をすることでしか得られない、回答やアドバイスがある。

15分調べてもわからないことは、質問しよう!

新規登録して質問してみよう
ただいま回答率
85.50%
PyTorch

PyTorchは、オープンソースのPython向けの機械学習ライブラリ。Facebookの人工知能研究グループが開発を主導しています。強力なGPUサポートを備えたテンソル計算、テープベースの自動微分による柔軟なニューラルネットワークの記述が可能です。

Python

Pythonは、コードの読みやすさが特徴的なプログラミング言語の1つです。 強い型付け、動的型付けに対応しており、後方互換性がないバージョン2系とバージョン3系が使用されています。 商用製品の開発にも無料で使用でき、OSだけでなく仮想環境にも対応。Unicodeによる文字列操作をサポートしているため、日本語処理も標準で可能です。

Q&A

0回答

394閲覧

cifar10を4枚つなげてマルチラベル分類をするコードが学習しない

datitana

総合スコア0

PyTorch

PyTorchは、オープンソースのPython向けの機械学習ライブラリ。Facebookの人工知能研究グループが開発を主導しています。強力なGPUサポートを備えたテンソル計算、テープベースの自動微分による柔軟なニューラルネットワークの記述が可能です。

Python

Pythonは、コードの読みやすさが特徴的なプログラミング言語の1つです。 強い型付け、動的型付けに対応しており、後方互換性がないバージョン2系とバージョン3系が使用されています。 商用製品の開発にも無料で使用でき、OSだけでなく仮想環境にも対応。Unicodeによる文字列操作をサポートしているため、日本語処理も標準で可能です。

0グッド

1クリップ

投稿2023/06/04 16:50

編集2023/06/05 03:08

実現したいこと

・作成したコードが学習できるようにしたい

前提

Pytorchでcifar10の画像を4枚つなげた1枚の画像のマルチラベル分類をするコードを作っています。

発生している問題・エラーメッセージ

実行すると、はじめから損失が低く出てしまっていて、学習している様子がありません。
実際に正解率も低いです。

試したこと

損失関数や評価関数、モデルを見直しましたがうまくいきません

補足情報(FW/ツールのバージョンなど)

python3.10.10

該当のソースコード

・net.py

python

1import torch 2import torch.nn as nn 3import torchvision 4import torch.nn.functional as F 5 6 7class MyResNet50(nn.Module): 8 def __init__(self, my_pretrained_model): 9 super(MyResNet50, self).__init__() 10 self.pretrained = my_pretrained_model 11 num_classes = 10 # マルチラベル分類のクラス数 12 13 # ドロップアウト率を設定 14 dropout_rate = 0.5 15 self.my_new_layers = nn.Sequential( 16 nn.Linear(1000, 4096), # 例として追加の全結合層を挿入 17 nn.Dropout(p=dropout_rate), 18 nn.ReLU(inplace=True), 19 nn.Dropout(p=dropout_rate), 20 nn.Linear(4096, num_classes), 21 nn.Sigmoid() # 最終層にSigmoid関数を追加 22 ) 23 24 def forward(self, x): 25 x = self.pretrained(x) 26 x = self.my_new_layers(x) 27 return x

・コード本体

python

1import torch 2import torch.nn as nn 3import os 4import torch.optim as optim 5import torchvision 6import numpy as np 7import matplotlib.pyplot as plt 8import torchvision.models as models 9from torch.utils.data import Dataset 10import torchvision.transforms as transforms 11import torch.optim.lr_scheduler as lr_scheduler 12import torch.nn.functional as F 13import pdb 14import wandb 15from net import MyResNet50 16 17 18# start a new wandb run to track this script 19wandb.init( 20 # set the wandb project where this run will be logged 21 project="multi_label_classification4", 22 23 # track hyperparameters and run metadata 24 config={ 25 "learning_rate": 0.001, 26 "architecture": "CNN", 27 "dataset": "CIFAR-10", 28 "epochs": 10, 29 } 30) 31def combine_single_image(X, y): 32 out_image = np.zeros((64, 64, 3), dtype=np.uint8) 33 out_label = np.zeros(10, dtype=np.float32) 34 category_idx = np.random.permutation(np.arange(10))[:4] 35 36 for i, category in enumerate(category_idx): 37 filters = np.where(y==category)[0] 38 item_idx = np.random.permutation(filters)[i] 39 if i == 0: 40 out_image[:32, :32, :] = X[item_idx, :, :, :] 41 elif i == 1: 42 out_image[:32, 32:,:] = X[item_idx, :, :, :] 43 elif i == 2: 44 out_image[32:, :32, :] = X[item_idx, :, :, :] 45 else: 46 out_image[32:, 32:, :] = X[item_idx, :, :, :] 47 out_label[category] = 1.0 48 return out_image, out_label 49 50def create_data(X, y, num): 51 images=[] 52 labels=[] 53 for i in range(num): 54 X_item, y_item = combine_single_image(X, y) 55 images.append(X_item) 56 labels.append(y_item) 57 return (images,labels) 58 59train = torchvision.datasets.CIFAR10(root='./data/cifar10', 60 train=True, 61 download=True, 62 ) 63test = torchvision.datasets.CIFAR10(root='./data/cifar10', 64 train=False, 65 download=True, 66 ) 67 68X_train = np.zeros((len(train),32,32,3),dtype=np.uint8) 69y_train = np.zeros(len(train),dtype=np.int32) 70for i in range(len(train)): 71 img, label = train[i] 72 X_train[i] = np.array(img) 73 y_train[i] = label 74train_data = create_data(X_train, y_train, 500) 75 76 77X_test = np.zeros((len(test),32,32,3),dtype=np.uint8) 78y_test = np.zeros(len(test),dtype=np.int32) 79for i in range(len(test)): 80 img, label = test[i] 81 X_test[i] = np.array(img) 82 y_test[i] = label 83test_data = create_data(X_test, y_test, 100) 84 85transform = transforms.Compose([ 86 transforms.ToPILImage(), 87 transforms.ToTensor(), # Tensorに変換 88 transforms.Resize((224, 224)), # 画像サイズの変更 89 transforms.RandomHorizontalFlip(), # ランダムな水平反転 90 transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), # 正規化 91 transforms.RandomErasing(p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0, inplace=False) # ランダム化2 92]) 93 94class MultiLabelTrainDataset(Dataset): 95 def __init__(self, train_data, transform=transform): 96 self.images = train_data[0] 97 self.labels = train_data[1] 98 self.transform = transform 99 100 def __getitem__(self, idx): 101 image = self.images[idx] 102 label = self.labels[idx] 103 104 if self.transform is not None: 105 image = self.transform(image) 106 107 return (image, label) 108 109 def __len__(self): 110 return len(self.images) 111 112class MultiLabelTestDataset(Dataset): 113 def __init__(self, test_data, transform=transform): 114 self.images = test_data[0] 115 self.labels = test_data[1] 116 self.transform = transform 117 118 def __getitem__(self, idx): 119 image = self.images[idx] 120 label = self.labels[idx] 121 122 if self.transform: 123 image = self.transform(image) 124 125 return (image, label) 126 127 def __len__(self): 128 return len(self.images) 129 130train_data=train_data 131test_data=test_data 132 133train_dataset = MultiLabelTrainDataset(train_data) 134test_dataset = MultiLabelTestDataset(test_data) 135 136train_dataloader =torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True) 137test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=32,shuffle=False) 138 139def binary_loss(y_true, y_pred): 140 bce = F.binary_cross_entropy(y_pred, y_true) 141 return torch.sum(bce, dim=-1) 142 143# 完全正解の評価 144def evaluate_exact(y_true, y_pred): 145 y_pred = torch.where(y_pred>0.5, 1, 0) 146 correct = (y_true == y_pred).all(dim=1).sum().item() 147 total = y_true.size(0) 148 accuracy = correct / total 149 return accuracy 150 151# 部分正解の評価 152def evaluate_partial(y_true, y_pred): 153 y_pred = torch.where(y_pred>0.5, 1, 0) 154 correct = (y_true == y_pred).sum().item() 155 total = y_true.size(0) * y_true.size(1) # データポイントの総数 × ラベルの数 156 accuracy = correct / total # 正解率をパーセンテージ(%)で表現 157 return accuracy 158 159 160pretrained = torchvision.models.resnet50(pretrained=True) 161device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 162model = MyResNet50(my_pretrained_model=pretrained) 163model = model.to(device) 164 165 166# オプティマイザの定義 167optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=0.0001) 168 169#スケジューラーの定義 170scheduler = lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1) 171 172# モデルのトレーニング 173num_epochs = 5 # トレーニングのエポック数 174 175total_step = len(train_dataloader) 176for epoch in range(num_epochs): 177 model.train() 178 exact_accuracy = 0 179 partial_correct = 0 180 total = 0 181 182 for ind, (images, labels) in enumerate(train_dataloader): 183 images = images.to(device) 184 labels = labels.to(device) 185 optimizer.zero_grad() 186 187 outputs = model(images) 188 189 loss = binary_loss(labels, outputs) 190 loss.backward() 191 optimizer.step() 192 scheduler.step() 193 loss.item() 194 #import pdb;pdb.set_trace() 195 # 完全正解の評価 196 exact_accuracy += evaluate_exact(labels, outputs) 197 # 部分正解の評価 198 partial_correct += evaluate_partial(labels, outputs) 199 total += labels.size(0) 200 201 epoch_loss = loss.item() / total 202 epoch_exact_accuracy = exact_accuracy / total 203 epoch_partial_accuracy = partial_correct / total 204 # log metrics to wandb 205 wandb.log({"exact_acc": epoch_exact_accuracy, "partial_acc": epoch_partial_accuracy, "loss": epoch_loss}) 206 207 print(f"Epoch {epoch+1}/{num_epochs} - Loss: {epoch_loss:.4f} - Exact Accuracy: {epoch_exact_accuracy*100:.2f}% - Partial Accuracy: {epoch_partial_accuracy*100:.2f}%") 208 209 # Save weights 210torch.save(model.state_dict(), './out/model.ckpt') 211 212model.eval() 213test_exact_accuracy = 0 214test_partial_correct = 0 215test_total = 0 216 217with torch.no_grad(): 218 for test_inputs, test_labels in test_dataloader: 219 test_inputs = test_inputs.to(device) 220 test_labels = test_labels.to(device) 221 test_outputs = model(test_inputs) 222 223 224 # 完全正解の評価 225 test_exact_accuracy += evaluate_exact(test_labels, test_outputs) 226 # 部分正解の評価 227 test_partial_correct += evaluate_partial(test_labels, test_outputs) 228 229 test_total = len(test_dataloader) 230 231 test_exact_accuracy = test_exact_accuracy / test_total 232 test_partial_accuracy = test_partial_correct / test_total 233 234 # log metrics to wandb 235 wandb.log({"exact_acc(test)": test_exact_accuracy, "partial_acc(test)": test_partial_accuracy}) 236 237 238 print(f"Exact Accuracy: {test_exact_accuracy*100:.2f}% - Partial Accuracy: {test_partial_accuracy*100:.2f}%") 239 240# [optional] finish the wandb run, necessary in notebooks 241wandb.finish()

気になる質問をクリップする

クリップした質問は、後からいつでもMYページで確認できます。

またクリップした質問に回答があった際、通知やメールを受け取ることができます。

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

PondVillege

2023/06/04 17:23

コードはコードブロック内に収めてください インデントが潰れてて再現のしようがありません
datitana

2023/06/04 18:08

修正しました。 申し訳ありません。
guest

あなたの回答

tips

太字

斜体

打ち消し線

見出し

引用テキストの挿入

コードの挿入

リンクの挿入

リストの挿入

番号リストの挿入

表の挿入

水平線の挿入

プレビュー

まだ回答がついていません

会員登録して回答してみよう

アカウントをお持ちの方は

15分調べてもわからないことは
teratailで質問しよう!

ただいまの回答率
85.50%

質問をまとめることで
思考を整理して素早く解決

テンプレート機能で
簡単に質問をまとめる

質問する

関連した質問