前提・実現したいこと
Python初心者です。
独自のデータセットを用いてAlexNetを実装したいと思っています。
以下のコードでCIFAR-10で実装をすることができました。
次はPath('/data')にある独自のデータセットで実装を行いたいのですが、
データの読み込み部分の書き方がわかりません。
教えていただけますと幸いです。
変更したいデータ読み込み部分
# load CIFAR-10 data
train_dataset = torchvision.datasets.CIFAR10(root='data', train=True, download=True, transform=transform)
test_dataset = torchvision.datasets.CIFAR10(root='data', train=False, download=True, transform=transform)
全体のソースコード
Python
import torch
import torchvision
import torch.nn as nn
import torch.nn.init as init
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as transforms
import numpy as np
from matplotlib import pyplot as plt

# Convert PIL images to float tensors in [0, 1]; add Normalize() here if desired.
transform = transforms.Compose([
    transforms.ToTensor()
])

# Load CIFAR-10 data (typo fixed: "CIFA-10" -> "CIFAR-10").
# NOTE(review): to train on a custom dataset instead, lay the images out as
#   data/train/<class_name>/*.png  and  data/test/<class_name>/*.png
# and replace these two lines with torchvision.datasets.ImageFolder(
#   root='data/train', transform=transform) / ImageFolder(root='data/test', ...).
train_dataset = torchvision.datasets.CIFAR10(root='data', train=True, download=True, transform=transform)
test_dataset = torchvision.datasets.CIFAR10(root='data', train=False, download=True, transform=transform)

print ('train_dataset = ', len(train_dataset))
print ('test_dataset = ', len(test_dataset))
image, label = train_dataset[0]
print (image.size())

# Set up data loaders (typo fixed: "loadser" -> "loaders").
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=64,
    shuffle=True,
    num_workers=2)

test_loader = torch.utils.data.DataLoader(
    dataset=test_dataset,
    batch_size=64,
    shuffle=False,
    num_workers=2)


# AlexNet
class AlexNet(nn.Module):
    """AlexNet-style CNN scaled down for 32x32 inputs such as CIFAR-10.

    The three 2x2 max-pools reduce a 32x32 input to 4x4 spatially, which is
    why the classifier's first Linear layer expects 256 * 4 * 4 features.

    Args:
        num_classes: number of output classes (size of the final Linear layer).
    """

    def __init__(self, num_classes):
        super().__init__()  # zero-arg super() — idiomatic on Python 3
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),   # 32x32 -> 16x16
            nn.Conv2d(64, 192, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),   # 16x16 -> 8x8
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),   # 8x8 -> 4x4
        )
        self.classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(256 * 4 * 4, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, num_classes),
        )

    def forward(self, x):
        """Run the network on a (batch, 3, 32, 32) tensor; returns raw logits."""
        x = self.features(x)
        # flatten(x, 1) keeps the batch dim and collapses C*H*W = 256*4*4;
        # equivalent to the original x.view(x.size(0), 256 * 4 * 4).
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x
# Select device.
num_classes = 10
device = 'cuda' if torch.cuda.is_available() else 'cpu'
net = AlexNet(num_classes).to(device)

# Loss and optimizer.
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)

# Training settings.
num_epochs = 20
train_loss_list, train_acc_list, val_loss_list, val_acc_list = [], [], [], []

### training
for epoch in range(num_epochs):
    train_loss, train_acc, val_loss, val_acc = 0, 0, 0, 0

    # ====== train_mode ======
    net.train()
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = net(images)
        loss = criterion(outputs, labels)
        train_loss += loss.item()
        train_acc += (outputs.max(1)[1] == labels).sum().item()
        loss.backward()
        optimizer.step()

    # BUG FIX: loss.item() is already the *mean* loss of a batch
    # (CrossEntropyLoss defaults to reduction='mean'), so the sum over
    # batches must be divided by the number of batches, len(train_loader),
    # not by the number of samples.  Accuracy counts individual samples,
    # so it correctly keeps the dataset-size divisor.
    avg_train_loss = train_loss / len(train_loader)
    avg_train_acc = train_acc / len(train_loader.dataset)

    # ====== val_mode ======
    net.eval()
    with torch.no_grad():  # no gradients needed for evaluation
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = net(images)
            loss = criterion(outputs, labels)
            val_loss += loss.item()
            val_acc += (outputs.max(1)[1] == labels).sum().item()
    avg_val_loss = val_loss / len(test_loader)  # same per-batch-mean fix as above
    avg_val_acc = val_acc / len(test_loader.dataset)

    # BUG FIX: the original passed an unused extra positional argument (i+1)
    # to .format(); the format string only consumes two positional fields.
    print ('Epoch [{}/{}], Loss: {loss:.4f}, val_loss: {val_loss:.4f}, val_acc: {val_acc:.4f}'
           .format(epoch+1, num_epochs, loss=avg_train_loss, val_loss=avg_val_loss, val_acc=avg_val_acc))
    train_loss_list.append(avg_train_loss)
    train_acc_list.append(avg_train_acc)
    val_loss_list.append(avg_val_loss)
    val_acc_list.append(avg_val_acc)


# Plot training/validation loss per epoch.
plt.figure()
plt.plot(range(num_epochs), train_loss_list, color='blue', linestyle='-', label='train_loss')
plt.plot(range(num_epochs), val_loss_list, color='green', linestyle='--', label='val_loss')
plt.legend()
plt.xlabel('epoch')
plt.ylabel('loss')
plt.title('Training and validation loss')
plt.grid()
plt.show()

# Plot training/validation accuracy per epoch.
plt.figure()
plt.plot(range(num_epochs), train_acc_list, color='blue', linestyle='-', label='train_acc')
plt.plot(range(num_epochs), val_acc_list, color='green', linestyle='--', label='val_acc')
plt.legend()
plt.xlabel('epoch')
plt.ylabel('acc')
plt.title('Training and validation accuracy')
plt.grid()
plt.show()
回答1件
あなたの回答
tips
プレビュー