質問をすることでしか得られない、回答やアドバイスがある。

15分調べてもわからないことは、質問しよう!

新規登録して質問してみよう
ただいま回答率
85.48%
PyTorch

PyTorchは、オープンソースのPython向けの機械学習ライブラリ。Facebookの人工知能研究グループが開発を主導しています。強力なGPUサポートを備えたテンソル計算、テープベースの自動微分による柔軟なニューラルネットワークの記述が可能です。

Python

Pythonは、コードの読みやすさが特徴的なプログラミング言語の1つです。 強い型付け、動的型付けに対応しており、後方互換性がないバージョン2系とバージョン3系が使用されています。 商用製品の開発にも無料で使用でき、OSだけでなく仮想環境にも対応。Unicodeによる文字列操作をサポートしているため、日本語処理も標準で可能です。

Q&A

解決済

1回答

3971閲覧

pytroch_lightningでValidation sanity checkが止まってしまう

cc_hk

総合スコア2

PyTorch

PyTorchは、オープンソースのPython向けの機械学習ライブラリ。Facebookの人工知能研究グループが開発を主導しています。強力なGPUサポートを備えたテンソル計算、テープベースの自動微分による柔軟なニューラルネットワークの記述が可能です。

Python

Pythonは、コードの読みやすさが特徴的なプログラミング言語の1つです。 強い型付け、動的型付けに対応しており、後方互換性がないバージョン2系とバージョン3系が使用されています。 商用製品の開発にも無料で使用でき、OSだけでなく仮想環境にも対応。Unicodeによる文字列操作をサポートしているため、日本語処理も標準で可能です。

0グッド

0クリップ

投稿2021/08/02 08:33

発生している問題・エラーメッセージ

下記のURLを参考に、データセットの自作⇒pytorch_lightningによる分類 を実施しています。
エラーは出ていないのですが、最後のモデルを学習させるところでValidation sanity checkの部分が止まってしまいます。

参考URL:https://free.kikagaku.ai/tutorial/basic_of_deep_learning/learn/pytorch_classification

どの部分が問題で動かないのでしょうか。もし知っている方がいらっしゃればお願いします。

該当のソースコード

python

1import pandas as pd 2import numpy 3import torch 4from torch import nn 5from torch.nn import functional as F 6from torch.utils.data import DataLoader 7from torch.utils.data import random_split 8from torchvision import transforms 9!pip install pytorch_lightning 10import pytorch_lightning as pl 11from pytorch_lightning import Trainer 12 13# colab上でデータアップロード 14from google.colab import files 15uploaded = files.upload() 16 17#データ読み込み 18df = pd.read_csv('/content/Dataset.csv',index_col=0) 19 20 21#データを説明変数と目的変数にわける 22data = df.drop('Tag',axis=1) 23target = df.iloc[:,-1] 24 25# PyTorch で学習に使用できる形式へ変換 26data = torch.tensor(data.values, dtype=torch.float32) 27target = torch.tensor(target.values, dtype=torch.int64) #今回は分類なのでint64 28 29# 目的変数と入力変数をまとめてdatasetに変換 30dataset = torch.utils.data.TensorDataset(data,target) 31 32# 各データセットのサンプル数を決定 33# train : val : test = 60% : 20% : 20% 34n_train = int(len(dataset) * 0.6) 35n_val = int((len(dataset) - n_train) * 0.5) 36n_test = len(dataset) - n_train - n_val 37 38# データセットの分割 39torch.manual_seed(0) #乱数を与えて固定 40train, val, test = torch.utils.data.random_split(dataset, [n_train, n_val,n_test]) 41 42 43# 学習データに対する処理 44class TrainNet(pl.LightningModule): 45 46 def train_dataloader(self): 47 return torch.utils.data.DataLoader(train, self.batch_size, shuffle=True, num_workers=self.num_workers) 48 49 def training_step(self, batch, batch_nb): 50 x, t = batch 51 y = self.forward(x) 52 loss = self.lossfun(y, t) 53 results = {'loss': loss} 54 return results 55 56# 検証データに対する処理 57class ValidationNet(pl.LightningModule): 58 59 def val_dataloader(self): 60 return torch.utils.data.DataLoader(val, self.batch_size) 61 62 def validation_step(self, batch, batch_nb): 63 x, t = batch 64 y = self.forward(x) 65 loss = self.lossfun(y, t) 66 y_label = torch.argmax(y, dim=1) 67 acc = torch.sum(t == y_label) * 1.0 / len(t) 68 results = {'val_loss': loss, 'val_acc': acc} 69 return results 70 71 def validation_end(self, outputs): 72 avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean() 73 avg_acc = torch.stack([x['val_acc'] for x in outputs]).mean() 74 results = {'val_loss': avg_loss, 'val_acc': avg_acc} 75 return results 76 77# テストデータに対する処理 78class TestNet(pl.LightningModule): 79 80 def test_dataloader(self): 81 return torch.utils.data.DataLoader(test, self.batch_size) 82 83 def test_step(self, batch, batch_nb): 84 x, t = batch 85 y = self.forward(x) 86 loss = self.lossfun(y, t) 87 y_label = torch.argmax(y, dim=1) 88 acc = torch.sum(t == y_label) * 1.0 / len(t) 89 results = {'test_loss': loss, 'test_acc': acc} 90 return results 91 92 def test_end(self, outputs): 93 avg_loss = torch.stack([x['test_loss'] for x in outputs]).mean() 94 avg_acc = torch.stack([x['test_acc'] for x in outputs]).mean() 95 results = {'test_loss': avg_loss, 'test_acc': avg_acc} 96 return results 97 98# 学習データ、検証データ、テストデータへの処理を継承したクラス 99class Net(TrainNet, ValidationNet, TestNet): 100 101 def __init__(self, batch_size=32, num_workers=0): 102 super(Net, self).__init__() 103 self.fc1 = nn.Linear(77, 5) 104 self.fc2 = nn.Linear(5, 2) 105 self.batch_size = batch_size 106 self.num_workers = num_workers 107 108 def forward(self, x): 109 x = self.fc1(x) 110 x = F.relu(x) 111 x = self.fc2(x) 112 return x 113 114 def lossfun(self, y, t): 115 return F.cross_entropy(y, t) 116 117 def configure_optimizers(self): 118 return torch.optim.SGD(self.parameters(), lr=0.1) 119 120net = Net() 121trainer = Trainer(max_epochs=10) 122 123trainer.fit(net) 124

最後の問題画面

Validation sanity check: 0%で止まったままになる。
イメージ説明

補足情報(バージョンなど)

colablatory下
pytorch = 1.9.0+cu102
pytorch_lightning = 1.4.0

気になる質問をクリップする

クリップした質問は、後からいつでもMYページで確認できます。

またクリップした質問に回答があった際、通知やメールを受け取ることができます。

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

guest

回答1

0

ベストアンサー

下記によるとtrainer = Trainer(num_sanity_val_steps=0)で克服します、とありました。
https://github.com/PyTorchLightning/pytorch-lightning/issues/2295

投稿2021/08/03 13:29

odataiki

総合スコア938

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

cc_hk

2021/08/06 03:10

実施してみたのですが、うまくいかず…です。 もしかしたらPC環境問題かもしれません。 回答ありがとうございます。
guest

あなたの回答

tips

太字

斜体

打ち消し線

見出し

引用テキストの挿入

コードの挿入

リンクの挿入

リストの挿入

番号リストの挿入

表の挿入

水平線の挿入

プレビュー

15分調べてもわからないことは
teratailで質問しよう!

ただいまの回答率
85.48%

質問をまとめることで
思考を整理して素早く解決

テンプレート機能で
簡単に質問をまとめる

質問する

関連した質問