0
1
https://meditech-ai.com/pytorch-efficientnet/
このサイトのコードの変形です。
python
import glob
import os
(中略しました。)
print(f"Loss: {loss_sum.item() / len(valid_loader)}, Accuracy: {100*correct/len(valid_data)}% ({correct}/{len(valid_data)})")
20230813現時点のコードです(確かこれ)
使用しているPythonは3.10です。
画像はリンゴとオレンジです。
python
"""Binary image classification (apple vs. orange) fine-tuning a pretrained
EfficientNetV2-S from timm, adapted from
https://meditech-ai.com/pytorch-efficientnet/ .

Pipeline: seed all RNGs -> build ImageFolder datasets/loaders ->
fine-tune for `epochs` epochs on CPU -> plot accuracy/loss curves ->
evaluate the final model on the held-out test split.
"""

import glob
import os
import random

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from PIL import Image
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms

from pathlib import Path
import seaborn as sns
import timm
from pprint import pprint

import copy
# FIX: the original imported tqdm.notebook first and then shadowed it with
# the plain tqdm; keep only the import that was actually in effect.
from tqdm import tqdm

# Training settings (epochs reduced from 50 to 2 for a quick run)
epochs = 2
lr = 3e-5      # Adam learning rate
gamma = 0.7    # per-epoch LR decay factor for StepLR
seed = 42


def seed_everything(seed):
    """Seed every RNG (Python, NumPy, PyTorch CPU/CUDA) for reproducibility."""
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True


seed_everything(seed)

device = 'cpu'
train_dataset_dir = Path('./Gender01/train')
val_dataset_dir = Path('./Gender01/validation')
test_dataset_dir = Path('./Gender01/test')

# Show a 3x3 grid of random sample images.
# FIX: the original computed `random_idx` but then indexed `files` with the
# grid position, so it always showed the first 9 files; use the random
# indices (and draw them from 0, not 1 — randint's low bound is inclusive).
files = glob.glob('./Gender01/*/*/*.png')
random_idx = np.random.randint(0, len(files), size=9)
fig, axes = plt.subplots(3, 3, figsize=(8, 6))

for idx, ax in enumerate(axes.ravel()):
    img = Image.open(files[random_idx[idx]])
    ax.imshow(img)

# ImageNet normalization stats (required by the pretrained backbone).
_normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                  std=[0.229, 0.224, 0.225])

train_transforms = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),  # augmentation: train split only
        transforms.ToTensor(),
        _normalize,
    ]
)

val_transforms = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        _normalize,
    ]
)

test_transforms = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        _normalize,
    ]
)

train_data = datasets.ImageFolder(train_dataset_dir, train_transforms)
valid_data = datasets.ImageFolder(val_dataset_dir, val_transforms)
test_data = datasets.ImageFolder(test_dataset_dir, test_transforms)

train_loader = DataLoader(dataset=train_data, batch_size=16, shuffle=True)
valid_loader = DataLoader(dataset=valid_data, batch_size=16, shuffle=False)
test_loader = DataLoader(dataset=test_data, batch_size=16, shuffle=False)

model_names = timm.list_models(pretrained=True)
pprint(model_names)

# FIX: 'tf_efficientnetv2_s_in21ft1k' is a deprecated alias (timm emitted a
# UserWarning); use the current name for the same weights.
model = timm.create_model('tf_efficientnetv2_s.in21k_ft_in1k',
                          pretrained=True, num_classes=2)
model = model.to(device)

# loss function
criterion = nn.CrossEntropyLoss()
# optimizer
optimizer = optim.Adam(model.parameters(), lr=lr)
# scheduler: decay the LR by `gamma` after every epoch
scheduler = StepLR(optimizer, step_size=1, gamma=gamma)

best_loss = None


def calculate_accuracy(output, target):
    """Accuracy for a single-logit sigmoid head (threshold at 0.5).

    NOTE(review): unused by this script — the loops below use argmax
    accuracy, which matches the 2-logit head. Kept for parity with the
    source article.
    """
    output = (torch.sigmoid(output) >= 0.5)
    target = (target == 1.0)
    accuracy = torch.true_divide((target == output).sum(dim=0), output.size(0)).item()
    return accuracy


train_acc_list = []
val_acc_list = []
train_loss_list = []
val_loss_list = []

for epoch in range(epochs):
    # ---- training ----
    model.train()  # FIX: restore train mode after the eval pass each epoch
    epoch_loss = 0.0
    epoch_accuracy = 0.0

    for data, label in tqdm(train_loader):
        data = data.to(device)
        label = label.to(device)

        output = model(data)
        loss = criterion(output, label)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        acc = (output.argmax(dim=1) == label).float().mean()
        # FIX: accumulate plain floats (.item()) instead of live tensors so
        # the autograd graph is not kept alive for the whole epoch.
        epoch_accuracy += acc.item() / len(train_loader)
        epoch_loss += loss.item() / len(train_loader)

    # ---- validation ----
    print("bbb")  # debug marker (kept from the original)
    model.eval()  # FIX: disable dropout / use BN running stats during eval
    with torch.no_grad():
        epoch_val_accuracy = 0.0
        epoch_val_loss = 0.0
        for data, label in valid_loader:
            data = data.to(device)
            label = label.to(device)

            val_output = model(data)
            val_loss = criterion(val_output, label)

            acc = (val_output.argmax(dim=1) == label).float().mean()
            epoch_val_accuracy += acc.item() / len(valid_loader)
            epoch_val_loss += val_loss.item() / len(valid_loader)

    print(
        f"Epoch : {epoch+1} - loss : {epoch_loss:.4f} - acc: {epoch_accuracy:.4f} - val_loss : {epoch_val_loss:.4f} - val_acc: {epoch_val_accuracy:.4f}\n"
    )

    train_acc_list.append(epoch_accuracy)
    val_acc_list.append(epoch_val_accuracy)
    train_loss_list.append(epoch_loss)
    val_loss_list.append(epoch_val_loss)

    # FIX: the scheduler was created but never stepped, so the LR never
    # decayed; advance it once per epoch.
    scheduler.step()

    # FIX: checkpoint on the epoch-average validation loss; the original
    # compared against `val_loss`, i.e. only the LAST validation batch.
    if (best_loss is None) or (best_loss > epoch_val_loss):
        best_loss = epoch_val_loss
        os.makedirs('./Gender01/save', exist_ok=True)  # FIX: ensure dir exists
        model_path = './Gender01/save/bestViTmodel.pth'
        torch.save(model.state_dict(), model_path)

    print()

# The per-epoch metrics are already Python floats, so the original
# tensor -> CPU -> numpy conversion block is no longer needed.
train_acc = train_acc_list
train_loss = train_loss_list
val_acc = val_acc_list
val_loss = val_loss_list

# Plot the training curves.
sns.set()
num_epochs = epochs

# FIX: create both axes up front; calling plt.subplot() on top of an existing
# figure triggered Matplotlib's "auto-removal of overlapping axes"
# deprecation warning.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4), dpi=80)

ax1.plot(range(num_epochs), train_acc, c='b', label='train acc')
ax1.plot(range(num_epochs), val_acc, c='r', label='val acc')
ax1.set_xlabel('epoch', fontsize='12')
ax1.set_ylabel('accuracy', fontsize='12')
ax1.set_title('training and val acc', fontsize='14')
ax1.legend(fontsize='12')

ax2.plot(range(num_epochs), train_loss, c='b', label='train loss')
ax2.plot(range(num_epochs), val_loss, c='r', label='val loss')
ax2.set_xlabel('epoch', fontsize='12')
ax2.set_ylabel('loss', fontsize='12')
ax2.set_title('training and val loss', fontsize='14')
ax2.legend(fontsize='12')
plt.show()

# ---- test-set evaluation ----
print("aaa")  # debug marker (kept from the original)
model.eval()  # evaluation mode

loss_sum = 0.0
correct = 0

print("bbb")  # debug marker (kept from the original)
with torch.no_grad():
    for data, labels in test_loader:

        # Move the batch to the compute device (CPU here).
        data = data.to(device)
        labels = labels.to(device)

        # Forward pass.
        outputs = model(data)
        print(outputs)  # debug: raw logits per batch
        # Accumulate the loss as a float (no graph under no_grad anyway).
        loss_sum += criterion(outputs, labels).item()

        # Predicted class = index of the larger logit.
        pred = outputs.argmax(1)
        print(pred)  # debug: predicted class indices
        # Count correct predictions.
        correct += pred.eq(labels.view_as(pred)).sum().item()

print(f"Loss: {loss_sum / len(test_loader)}, Accuracy: {100*correct/len(test_data)}% ({correct}/{len(test_data)})")

print("ccc")  # debug marker (kept from the original)
実行結果がこちらです。
'xcit_tiny_24_p16_384.fb_dist_in1k']
C:\Users\user\AppData\Local\Programs\Python\Python310\lib\site-packages\timm\models\_factory.py:114: UserWarning: Mapping deprecated model name tf_efficientnetv2_s_in21ft1k to current tf_efficientnetv2_s.in21k_ft_in1k.
model = create_fn(
100%|████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:32<00:00, 10.81s/it]
bbb
Epoch : 1 - loss : 6.1771 - acc: 0.4754 - val_loss : 8.3888 - val_acc: 0.4688
100%|████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:29<00:00, 9.98s/it]
bbb
Epoch : 2 - loss : 1.1589 - acc: 0.8030 - val_loss : 7.3929 - val_acc: 0.4688
C:\Users\user\Desktop\0813nn.py:196: MatplotlibDeprecationWarning: Auto-removal of overlapping axes is deprecated since 3.6 and will be removed two minor releases later; explicitly call ax.remove() as needed.
ax1 = plt.subplot(1,2,1)
aaa
bbb
tensor([[ 24.5468, -7.7497],
[ 4.6458, 5.0631],
[ -0.3985, -10.5740],
[ 8.4914, -17.8499],
[ 12.2553, -2.8605],
[ 14.7144, -12.8341],
[ 26.6738, 5.5297],
[ 14.7019, -9.2587],
[ 13.5414, 0.3043],
[ 13.3540, -8.5363],
[ 7.1949, 6.6422],
[ 33.3601, -4.4459],
[ 16.8021, 5.1939],
[ -4.7226, -11.2997],
[ 4.4507, -4.2222],
[ 8.8262, -7.5904]])
tensor([0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
tensor([[ 20.1546, -3.7890],
[ 20.6143, 1.0150],
[ 4.4203, 21.4808],
[ 0.0692, 4.3055],
[ -9.2077, 15.0064],
[ -0.0940, 18.2106],
[ -7.7337, 18.9897],
[-11.1088, 2.6819],
[ -8.8662, 21.3578],
[ -3.6728, 23.6316],
[ -0.6814, 11.9310],
[ -2.1048, 25.2059],
[ -3.5031, 21.3227]])
tensor([0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])
Loss: 0.04366261884570122, Accuracy: 96.55172413793103% (28/29)
ccc
このコードの意味を理解したいのですが、適切なウェブページや本はありますか、教えて下さい、複数でも良いです、基本から分かっていません。
回答38件
バッドをするには、ログインかつ
こちらの条件を満たす必要があります。