質問をすることでしか得られない、回答やアドバイスがある。

15分調べてもわからないことは、質問しよう!

新規登録して質問してみよう
ただいま回答率
85.48%
CNN (Convolutional Neural Network)

CNN (Convolutional Neural Network)は、全結合層のみではなく畳み込み層とプーリング層で構成されるニューラルネットワークです。画像認識において優れた性能を持ち、畳み込みニューラルネットワークとも呼ばれています。

Q&A

1回答

418閲覧

VGG16 cifar10 エラー

linpoti

総合スコア1

CNN (Convolutional Neural Network)

CNN (Convolutional Neural Network)は、全結合層のみではなく畳み込み層とプーリング層で構成されるニューラルネットワークです。画像認識において優れた性能を持ち、畳み込みニューラルネットワークとも呼ばれています。

0グッド

0クリップ

投稿2022/10/25 10:05

前提

VGG16でcifar10を実装したい

実現したいこと

エラーなく実装したい

発生している問題・エラーメッセージ

Traceback (most recent call last):
File "/home/lab/workspace/username/VGG2.py", line 112, in <module>
for data in trainloader:
File "/home/lab/anaconda3/envs/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 681, in next
data = self._next_data()
File "/home/lab/anaconda3/envs/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 721, in _next_data
data = self._dataset_fetcher.fetch(index) # may raise StopIteration
File "/home/lab/anaconda3/envs/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 49, in fetch
data = [self.dataset[idx] for idx in possibly_batched_index]
File "/home/lab/anaconda3/envs/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 49, in <listcomp>
data = [self.dataset[idx] for idx in possibly_batched_index]
File "/home/lab/anaconda3/envs//lib/python3.10/site-packages/torchvision/datasets/cifar.py", line 118, in getitem
img = self.transform(img)
TypeError: 'module' object is not callable

該当のソースコード

python

1ソースコード 2import os 3import numpy as np 4import matplotlib.pyplot as plt 5import torch 6import torch.nn as nn 7import torch.nn.functional as F 8import torch.optim as optim 9import torchvision 10from torchvision import models 11from torch.utils.data import DataLoader 12from torchvision import datasets, transforms 13 14 15class myVGG(nn.Module): 16 17 def __init__(self): 18 super(myVGG, self).__init__() 19 20 self.conv01 = nn.Conv2d(3, 64, 3) 21 self.conv02 = nn.Conv2d(64, 64, 3) 22 self.pool1 = nn.MaxPool2d(2, 2) 23 24 self.conv03 = nn.Conv2d(64, 128, 3) 25 self.conv04 = nn.Conv2d(128, 128, 3) 26 self.pool2 = nn.MaxPool2d(2, 2) 27 28 self.conv05 = nn.Conv2d(128, 256, 3) 29 self.conv06 = nn.Conv2d(256, 256, 3) 30 self.conv07 = nn.Conv2d(256, 256, 3) 31 self.pool3 = nn.MaxPool2d(2, 2) 32 33 self.conv08 = nn.Conv2d(256, 512, 3) 34 self.conv09 = nn.Conv2d(512, 512, 3) 35 self.conv10 = nn.Conv2d(512, 512, 3) 36 self.pool4 = nn.MaxPool2d(2, 2) 37 38 self.conv11 = nn.Conv2d(512, 512, 3) 39 self.conv12 = nn.Conv2d(512, 512, 3) 40 self.conv13 = nn.Conv2d(512, 512, 3) 41 self.pool5 = nn.MaxPool2d(2, 2) 42 43 self.avepool1 = nn.AdaptiveAvgPool2d((7, 7)) 44 45 self.fc1 = nn.Linear(512 * 7 * 7, 4096) 46 self.fc2 = nn.Linear(4096, 4096) 47 self.fc3 = nn.Linear(4096, 5) 48 49 self.dropout1 = nn.Dropout(0.5) 50 self.dropout2 = nn.Dropout(0.5) 51 52 53 54 def forward(self, x): 55 x = F.relu(self.conv01(x)) 56 x = F.relu(self.conv02(x)) 57 x = self.pool1(x) 58 59 x = F.relu(self.conv03(x)) 60 x = F.relu(self.conv04(x)) 61 x = self.pool2(x) 62 63 x = F.relu(self.conv05(x)) 64 x = F.relu(self.conv06(x)) 65 x = F.relu(self.conv07(x)) 66 x = self.pool3(x) 67 68 x = F.relu(self.conv08(x)) 69 x = F.relu(self.conv09(x)) 70 x = F.relu(self.conv10(x)) 71 x = self.pool4(x) 72 73 x = F.relu(self.conv11(x)) 74 x = F.relu(self.conv12(x)) 75 x = F.relu(self.conv13(x)) 76 x = self.pool5(x) 77 78 x = self.avepool1(x) 79 80 # 行列をベクトルに変換 81 x = x.view(-1, 512 * 7 * 7) 82 83 x = F.relu(self.fc1(x)) 84 x = self.dropout1(x) 85 x = F.relu(self.fc2(x)) 86 x = self.dropout2(x) 87 x = self.fc3(x) 88 89 return x 90 91net = myVGG() 92 93trainset = datasets.CIFAR10(root="./data", train=True, download=True, transform=transforms) 94validset = datasets.CIFAR10(root="./data", train=False, download=True, transform=transforms) 95 96trainloader = DataLoader(trainset, batch_size=32, shuffle=True) 97validloader = DataLoader(validset, batch_size=32, shuffle=False) 98 99 100device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 101net.to(device) 102net.train() 103 104criterion = nn.CrossEntropyLoss() 105optimizer = optim.Adam(net.parameters(), lr=0.00001) 106 107# 同じデータを 50 回学習します 108for epoch in range(50): 109 110 # 今回の学習効果を保存するための変数 111 running_loss = 0.0 112 113 for data in trainloader: 114 # データ整理 115 inputs, labels = data 116 inputs = inputs.to(device) 117 labels = labels.to(device) 118 119 # 前回の勾配情報をリセット 120 optimizer.zero_grad() 121 122 # 予測 123 outputs = net(inputs) 124 125 # 予測結果と教師ラベルを比べて損失を計算 126 loss = criterion(outputs, labels) 127 running_loss += loss.item() 128 129 # 損失に基づいてネットワークのパラメーターを更新 130 loss.backward() 131 optimizer.step() 132 133 # このエポックの学習効果 134 print(running_loss) 135 136 137 optimizer = optim.Adam(net.parameters(), lr=0.000001) 138 139for epoch in range(50): 140 running_loss = 0.0 141 for data in trainloader: 142 inputs, labels = data 143 inputs = inputs.to(device) 144 labels = labels.to(device) 145 optimizer.zero_grad() 146 outputs = net(inputs) 147 loss = criterion(outputs, labels) 148 running_loss += loss.item() 149 loss.backward() 150 optimizer.step() 151 print(running_loss) 152 153 154 plt.figure() 155plt.style.use('ggplot') 156plt.plot(loss, label='train loss') 157plt.plot(running_loss, label='validation loss') 158plt.legend() 159plt.savefig( 'test_1.png' ) 160 161### 試したこと 1622つのサイトを組み合わせて作ったのですが、うまくできていないと思います。 163### 補足情報(FW/ツールのバージョンなど) 164 165ここにより詳細な情報を記載してください。

気になる質問をクリップする

クリップした質問は、後からいつでもMYページで確認できます。

またクリップした質問に回答があった際、通知やメールを受け取ることができます。

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

jbpb0

2022/10/26 00:43 編集

> trainset = datasets.CIFAR10(root="./data", train=True, download=True, transform=transforms) validset = datasets.CIFAR10(root="./data", train=False, download=True, transform=transforms) のところを、 https://teratail.com/questions/krf6b5esfx845l のコードの下記の部分のように変えたら、この質問のエラーは出なくなると思います (下記そのままではなく、この質問のコードに合わせて変数名とかを変えてください) transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.5, ), (0.5,)) ]) train_dataset = datasets.CIFAR10(root="./data", train=True, download=True, transform=transform) validation_dataset = datasets.CIFAR10(root="./data", train=False, download=True, transform=transform) 参考 https://qiita.com/kazetof/items/6a72926b9f8cd44c218e の「3. データのロード,下処理」 上記を変更したら、それよりも後(下)のところで別のエラーが出ますが、それはこの質問のエラーとは別の内容なので、別の質問にしてください
jbpb0

2022/10/26 00:29

何で「for epoch in range(50):」のループを2回やってるのですか?
linpoti

2022/10/27 14:49

すいません。これは打ち間違いです。 改善して新しく質問出したので教えていただきたいです。
guest

回答1

0

trainset = datasets.CIFAR10(root="./data", train=True, download=True, transform=transforms)
validset = datasets.CIFAR10(root="./data", train=False, download=True, transform=transforms)

のところを、
CNN cifar10 エラー
のコードの下記の部分のように変えたら、この質問のエラーは出なくなると思います
(下記そのままではなく、この質問のコードに合わせて変数名とかを変えてください)

python

1transform = transforms.Compose([ 2 transforms.ToTensor(), 3 transforms.Normalize((0.5, ), (0.5,)) 4]) 5 6train_dataset = datasets.CIFAR10(root="./data", train=True, download=True, transform=transform) 7validation_dataset = datasets.CIFAR10(root="./data", train=False, download=True, transform=transform)

参考
【詳細(?)】pytorch入門 〜CIFAR10をCNNする〜
の「3. データのロード,下処理」

投稿2022/11/03 02:29

jbpb0

総合スコア7651

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

あなたの回答

tips

太字

斜体

打ち消し線

見出し

引用テキストの挿入

コードの挿入

リンクの挿入

リストの挿入

番号リストの挿入

表の挿入

水平線の挿入

プレビュー

まだベストアンサーが選ばれていません

会員登録して回答してみよう

アカウントをお持ちの方は

15分調べてもわからないことは
teratailで質問しよう!

ただいまの回答率
85.48%

質問をまとめることで
思考を整理して素早く解決

テンプレート機能で
簡単に質問をまとめる

質問する

関連した質問