質問をすることでしか得られない、回答やアドバイスがある。

15分調べてもわからないことは、質問しよう!

新規登録して質問してみよう
ただいま回答率
85.48%
PyTorch

PyTorchは、オープンソースのPython向けの機械学習ライブラリ。Facebookの人工知能研究グループが開発を主導しています。強力なGPUサポートを備えたテンソル計算、テープベースの自動微分による柔軟なニューラルネットワークの記述が可能です。

OpenCV

OpenCV(オープンソースコンピュータービジョン)は、1999年にインテルが開発・公開したオープンソースのコンピュータビジョン向けのクロスプラットフォームライブラリです。

Python

Pythonは、コードの読みやすさが特徴的なプログラミング言語の1つです。 強い型付け、動的型付けに対応しており、後方互換性がないバージョン2系とバージョン3系が使用されています。 商用製品の開発にも無料で使用でき、OSだけでなく仮想環境にも対応。Unicodeによる文字列操作をサポートしているため、日本語処理も標準で可能です。

Q&A

解決済

1回答

1718閲覧

PyTorchでFloat型で計算するようにエラーが出ましたが、修正箇所がわかりません。

Hiro051

総合スコア9

PyTorch

PyTorchは、オープンソースのPython向けの機械学習ライブラリ。Facebookの人工知能研究グループが開発を主導しています。強力なGPUサポートを備えたテンソル計算、テープベースの自動微分による柔軟なニューラルネットワークの記述が可能です。

OpenCV

OpenCV(オープンソースコンピュータービジョン)は、1999年にインテルが開発・公開したオープンソースのコンピュータビジョン向けのクロスプラットフォームライブラリです。

Python

Pythonは、コードの読みやすさが特徴的なプログラミング言語の1つです。 強い型付け、動的型付けに対応しており、後方互換性がないバージョン2系とバージョン3系が使用されています。 商用製品の開発にも無料で使用でき、OSだけでなく仮想環境にも対応。Unicodeによる文字列操作をサポートしているため、日本語処理も標準で可能です。

0グッド

0クリップ

投稿2020/05/16 01:49

前提・実現したいこと

Python初心者です。
独自のデータセット'/home/selen/downloads/'(嵐5人の顔画像)を用いて
PyTorchでファインチューニングを行いたいと思っています。
入力データはGoogleでスクレイピングし、OpenCVで顔だけ切り取った同じサイズのものです。
'Arashi/arashi.py'で実行しています。

発生している問題・エラーメッセージ

/home/selen/.pyenv/versions/3.7.3/lib/python3.7/site-packages/pytorch_lightning/utilities/distributed.py:23: UserWarning: The dataloader, val dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` in the `DataLoader` init to improve performance. warnings.warn(*args, **kwargs) Validation sanity check: 0it [00:00, ?it/s]Traceback (most recent call last): File "Arashi/arashi.py", line 107, in <module> trainer.fit(fine_net) File "/home/selen/.pyenv/versions/3.7.3/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 765, in fit self.single_gpu_train(model) File "/home/selen/.pyenv/versions/3.7.3/lib/python3.7/site-packages/pytorch_lightning/trainer/distrib_parts.py", line 492, in single_gpu_train self.run_pretrain_routine(model) File "/home/selen/.pyenv/versions/3.7.3/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 899, in run_pretrain_routine False) File "/home/selen/.pyenv/versions/3.7.3/lib/python3.7/site-packages/pytorch_lightning/trainer/evaluation_loop.py", line 322, in _evaluate eval_results = model.validation_end(outputs) File "Arashi/arashi.py", line 74, in validation_end avg_acc = torch.stack([x['val_acc'] for x in outputs]).mean() RuntimeError: Can only calculate the mean of floating types. Got Long instead.

該当のソースコード

Python

1import torch, torchvision 2import torch.nn as nn 3import torch.nn.functional as F 4from torchvision import transforms 5import pytorch_lightning as pl 6from pytorch_lightning import Trainer 7 8from PIL import Image 9import glob 10 11fold_path = '/home/selen/downloads/' 12imgs = [] 13for imgs_path in glob.glob(fold_path + '*'): 14 imgs.append(glob.glob(imgs_path + '/*')) 15 16from torchvision.models import resnet18 17resnet = resnet18(pretrained=True) 18 19transform = transforms.Compose([ 20 transforms.Resize((224, 224)), 21 transforms.ToTensor(), 22 transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 23]) 24 25labels = [] 26img_datas = torch.tensor([]) 27 28for i,imgs_arr in enumerate(imgs): 29 30 for img_path in imgs_arr: 31 labels.append(i) 32 img = Image.open(img_path) 33 tensor_img = transform(img) 34 tensor_img = tensor_img.unsqueeze(0) 35 img_datas = torch.cat([img_datas, tensor_img],dim=0) 36 37datasets = torch.utils.data.TensorDataset(img_datas, torch.tensor(labels)) 38 39 40n_train = int(len(datasets) * 0.85) 41n_val = len(datasets) - n_train 42torch.manual_seed(0) 43train,val = torch.utils.data.random_split(datasets,[n_train,n_val]) 44 45class TrainNet(pl.LightningModule): 46 @pl.data_loader 47 def train_dataloader(self): 48 return torch.utils.data.DataLoader(train, self.batch_size,shuffle=True) 49 50 def training_step(self, batch, batch_nb): 51 x, t = batch 52 y = self.forward(x) 53 loss = self.lossfun(y, t) 54 results = {'loss': loss} 55 return results 56 57class ValidationNet(pl.LightningModule): 58 59 @pl.data_loader 60 def val_dataloader(self): 61 return torch.utils.data.DataLoader(val, self.batch_size) 62 63 def validation_step(self, batch, batch_nb): 64 x, t = batch 65 y = self.forward(x) 66 loss = self.lossfun(y, t) 67 y_label = torch.argmax(y, dim=1) 68 acc = torch.sum(t == y_label) * 1.0 / len(t) 69 results = {'val_loss': loss, 'val_acc': acc} 70 return results 71 72 def validation_end(self, outputs): 73 avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean() 74 avg_acc = torch.stack([x['val_acc'] for x in outputs]).mean() 75 results = {'val_loss': avg_loss, 'val_acc': avg_acc} 76 return results 77 78class FineTuningNet(TrainNet, ValidationNet): 79 80 def __init__(self, batch_size=256): 81 super().__init__() 82 self.batch_size = batch_size 83 self.conv = resnet18(pretrained=True) 84 self.fc1 = nn.Linear(1000, 100) 85 self.fc2 = nn.Linear(100, 5) 86 for param in self.conv.parameters(): 87 param.requires_grad = False 88 89 def lossfun(self, y, t): 90 return F.cross_entropy(y, t) 91 92 def configure_optimizers(self): 93 return torch.optim.SGD(self.parameters(), lr=0.01) 94 95 def forward(self, x): 96 x = self.conv(x) 97 x = self.fc1(x) 98 x = F.relu(x) 99 x = self.fc2(x) 100 return x 101 102torch.backends.cudnn.deterministic = True 103torch.backends.cudnn.benchmark = False 104 105fine_net = FineTuningNet() 106trainer = Trainer(gpus=1, max_epochs=300) 107trainer.fit(fine_net)

間違っている箇所あればご指摘していただけると幸いです。
よろしくお願いいたします。

気になる質問をクリップする

クリップした質問は、後からいつでもMYページで確認できます。

またクリップした質問に回答があった際、通知やメールを受け取ることができます。

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

guest

回答1

0

ベストアンサー

avg_acc = torch.stack([x['val_acc'] for x in outputs]).mean()avg_acc = torch.stack([x['val_acc'] for x in outputs]).float().mean()にしてはどうでしょうか?

投稿2020/05/16 04:08

meg_

総合スコア10579

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

Hiro051

2020/05/16 04:41

できました。ありがとうございます! 学習ができたので学習済みモデルの精度を確認したいのですが 最後に "trainer.callback_metrics" を追記しても出力されません。 これはなぜでしょうか?
meg_

2020/05/16 06:39

何も表示されないということですか? "精度"が表示されないということですか?
Hiro051

2020/05/16 07:02

warnings.warn(*args, **kwargs) Epoch 300: 100%|███████████████████████████████████████| 2/2 [00:00<00:00, 8.00it/s, loss=0.081, v_num=16] 上記のように精度が表示されません。
meg_

2020/05/16 07:35

上記は学習中の進捗表示ですね。 下記のようなものが表示されませんか? {'epoch': 300, 'loss': 0.015820687227787348, 'val_acc': 0.9022806587490845, 'val_loss': 0.25298880732891025}
Hiro051

2020/05/16 07:39

はい。表示されていません。。
meg_

2020/05/16 08:02

学習は完了しているのですよね? スクリプトの実行環境は何ですか?
Hiro051

2020/05/16 09:07

MacターミナルからSSH接続したリモートマシンで実行しています。
meg_

2020/05/16 10:28

何故表示されないのか、ちょっと分からないですね。 その件については新たに質問を立てると有識者の方から回答いただけるかもしれません。
Hiro051

2020/05/16 11:04

そうしてみます。 親切に対応していただきありがとうございました。
guest

あなたの回答

tips

太字

斜体

打ち消し線

見出し

引用テキストの挿入

コードの挿入

リンクの挿入

リストの挿入

番号リストの挿入

表の挿入

水平線の挿入

プレビュー

15分調べてもわからないことは
teratailで質問しよう!

ただいまの回答率
85.48%

質問をまとめることで
思考を整理して素早く解決

テンプレート機能で
簡単に質問をまとめる

質問する

関連した質問