🎄teratailクリスマスプレゼントキャンペーン2024🎄』開催中!

\teratail特別グッズやAmazonギフトカード最大2,000円分が当たる!/

詳細はこちら
CSV

CSV(Comma-Separated Values)はコンマで区切られた明白なテキスト値のリストです。もしくは、そのフォーマットでひとつ以上のリストを含むファイルを指します。

PyTorch

PyTorchは、オープンソースのPython向けの機械学習ライブラリ。Facebookの人工知能研究グループが開発を主導しています。強力なGPUサポートを備えたテンソル計算、テープベースの自動微分による柔軟なニューラルネットワークの記述が可能です。

CNN (Convolutional Neural Network)

CNN (Convolutional Neural Network)は、全結合層のみではなく畳み込み層とプーリング層で構成されるニューラルネットワークです。画像認識において優れた性能を持ち、畳み込みニューラルネットワークとも呼ばれています。

機械学習

機械学習は、データからパターンを自動的に発見し、そこから知能的な判断を下すためのコンピューターアルゴリズムを指します。人工知能における課題のひとつです。

Python

Pythonは、コードの読みやすさが特徴的なプログラミング言語の1つです。 強い型付け、動的型付けに対応しており、後方互換性がないバージョン2系とバージョン3系が使用されています。 商用製品の開発にも無料で使用でき、OSだけでなく仮想環境にも対応。Unicodeによる文字列操作をサポートしているため、日本語処理も標準で可能です。

Q&A

解決済

1回答

5387閲覧

Pytorch CNNで学習が進まない。

tomo193

総合スコア6

CSV

CSV(Comma-Separated Values)はコンマで区切られた明白なテキスト値のリストです。もしくは、そのフォーマットでひとつ以上のリストを含むファイルを指します。

PyTorch

PyTorchは、オープンソースのPython向けの機械学習ライブラリ。Facebookの人工知能研究グループが開発を主導しています。強力なGPUサポートを備えたテンソル計算、テープベースの自動微分による柔軟なニューラルネットワークの記述が可能です。

CNN (Convolutional Neural Network)

CNN (Convolutional Neural Network)は、全結合層のみではなく畳み込み層とプーリング層で構成されるニューラルネットワークです。画像認識において優れた性能を持ち、畳み込みニューラルネットワークとも呼ばれています。

機械学習

機械学習は、データからパターンを自動的に発見し、そこから知能的な判断を下すためのコンピューターアルゴリズムを指します。人工知能における課題のひとつです。

Python

Pythonは、コードの読みやすさが特徴的なプログラミング言語の1つです。 強い型付け、動的型付けに対応しており、後方互換性がないバージョン2系とバージョン3系が使用されています。 商用製品の開発にも無料で使用でき、OSだけでなく仮想環境にも対応。Unicodeによる文字列操作をサポートしているため、日本語処理も標準で可能です。

0グッド

0クリップ

投稿2021/03/20 07:56

こちらのサイトを参考にしつつ、cnnを実装したところエラーは起きないものの学習が全くされず、困っています。
Colaboratoryで実行しました。
参照先のサイトと異なるのはcsvファイルを用いた点とネットワークの書き方です。
なにかアドバイスがあれば教えていただきたいです。よろしくお願いします。

python

1import time 2import numpy as np 3import pandas as pd 4import torch 5import torchvision 6from torch import nn,optim 7import torch.nn.functional as F 8from torch.utils.data import Dataset,DataLoader,TensorDataset 9 10from sklearn.model_selection import train_test_split 11import matplotlib.pyplot as plt 12%matplotlib inline 13 14import numpy as np 15#kaggleのfashion-mnistのcsvファイルをtrain.csv,test.csvとしています 16 17train =pd.read_csv('train.csv') 18 19y = train['label']#正解ラベルの取り出し 20x = train.drop('label', axis=1)#正解ラベルの消去 21 22x=x.values 23y=y.values 24 25#画像チェック 26x=x.reshape(-1,1,28,28) 27plt.imshow(x[3][0], cmap = 'gray', vmin = 0, vmax = 255, interpolation = 'none') 28plt.show() 29 30x=x/255 31 32train_x,val_x,train_y,val_y=train_test_split(x,y,test_size=0.2,random_state=0) 33train_x = torch.tensor(train_x, dtype=torch.float32) 34train_y = torch.tensor(train_y, dtype=torch.int64) 35 36val_x= torch.tensor(val_x, dtype=torch.float32) 37val_y = torch.tensor(val_y, dtype=torch.int64) 38 39train_set = torch.utils.data.TensorDataset(train_x,train_y) 40test_set = torch.utils.data.TensorDataset(val_x, val_y) 41 42batch_sizes=128 43train_loader = torch.utils.data.DataLoader(train_set, batch_size = batch_sizes, shuffle = False) 44test_loader = torch.utils.data.DataLoader(test_set, batch_size = batch_sizes, shuffle = False) 45 46class net(nn.Module): 47 def __init__(self): 48 super(net, self).__init__() 49 self.conv1 = nn.Conv2d(in_channels = 1, out_channels = 32, kernel_size = 5, stride=1, padding=0) 50 self.conv2 = nn.Conv2d(in_channels = 32, out_channels = 32, kernel_size = 5, stride=1, padding=0) 51 self.conv3 = nn.Conv2d(in_channels = 32, out_channels = 64, kernel_size = 3, stride=1, padding=0) 52 self.conv4 = nn.Conv2d(in_channels = 64, out_channels = 64, kernel_size = 3, stride=1, padding=0) 53 54 self.pool = nn.MaxPool2d(2, 2) 55 self.dropout1=nn.Dropout(p=0.25) 56 self.dropout2=nn.Dropout(p=0.5) 57 self.fc1 = nn.Linear(64*3*3, 256) 58 self.fc2 = nn.Linear(256, 10) 59 60 def forward(self, x): 61 x = F.relu(self.conv1(x)) 62 x = F.relu(self.conv2(x)) 63 x=self.pool(x) 64 x=self.dropout1(x) 65 x = F.relu(self.conv3(x)) 66 x = F.relu(self.conv4(x)) 67 x=self.pool(x) 68 x=self.dropout1(x) 69 70 x = x.view(x.size(0),-1) 71 x = self.fc1(x) 72 x = self.dropout2(x) 73 x = self.fc2(x) 74 return x 75 76device = 'cuda' if torch.cuda.is_available() else 'cpu' 77net = net().to(device) 78 79#損失関数 80criterion = nn.CrossEntropyLoss() 81#最適化 82optimizer = optim.RMSprop(net.parameters(), lr=0.01, alpha=0.99, eps=1e-08, weight_decay=0, momentum=0, centered=False) 83 84num_epochs =30 85 86train_loss_list = [] 87train_acc_list = [] 88val_loss_list = [] 89val_acc_list = [] 90start = time.time() 91for epoch in range(num_epochs): 92 train_loss = 0 93 train_acc = 0 94 val_loss = 0 95 val_acc = 0 96 97 #train 98 net.train() 99 for i, (images, labels) in enumerate(train_loader): 100 images, labels = images.to(device), labels.to(device) 101 optimizer.zero_grad() 102 outputs = net.forward(images) 103 loss = criterion(outputs, labels) 104 train_loss += loss.item() 105 train_acc += (outputs.max(1)[1] == labels).sum().item() 106 loss.backward() 107 optimizer.step() 108 109 avg_train_loss = train_loss / len(train_loader.dataset) 110 avg_train_acc = train_acc / len(train_loader.dataset) 111 112 #val 113 net.eval() 114 with torch.no_grad(): 115 for images, labels in test_loader: 116 images = images.to(device) 117 labels = labels.to(device) 118 outputs = net.forward(images) 119 loss = criterion(outputs, labels) 120 val_loss += loss.item() 121 val_acc += (outputs.max(1)[1] == labels).sum().item() 122 avg_val_loss = val_loss / len(test_loader.dataset) 123 avg_val_acc = val_acc / len(test_loader.dataset) 124 125 print ('Epoch [{}/{}], Loss: {loss:.4f}, val_loss: {val_loss:.4f}, val_acc: {val_acc:.4f}' 126 .format(epoch+1, num_epochs, i+1, loss=avg_train_loss, val_loss=avg_val_loss, val_acc=avg_val_acc)) 127 train_loss_list.append(avg_train_loss) 128 train_acc_list.append(avg_train_acc) 129 val_loss_list.append(avg_val_loss) 130 val_acc_list.append(avg_val_acc) 131 132print(time.time() - start) 133 134plt.plot(train_acc_list, label='Training loss') 135plt.plot(val_acc_list, label='Validation loss') 136plt.legend(); 137val_acc_list[-1]

気になる質問をクリップする

クリップした質問は、後からいつでもMYページで確認できます。

またクリップした質問に回答があった際、通知やメールを受け取ることができます。

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

guest

回答1

0

自己解決

csvファイルの読み込みが終わる前に実行してしまい、NaNが大量発生していたことが原因でした。初歩的なミスで申し訳ないです。回答を考えてくださった皆様ありがとうございました。

投稿2021/03/20 08:52

tomo193

総合スコア6

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

あなたの回答

tips

太字

斜体

打ち消し線

見出し

引用テキストの挿入

コードの挿入

リンクの挿入

リストの挿入

番号リストの挿入

表の挿入

水平線の挿入

プレビュー

15分調べてもわからないことは
teratailで質問しよう!

ただいまの回答率
85.36%

質問をまとめることで
思考を整理して素早く解決

テンプレート機能で
簡単に質問をまとめる

質問する

関連した質問