以下でできます。
方法1
名前 → クラス の dict を作っておく
python
1import torch
2
3name = "MSE"
4
5losses = {
6 "MSE": torch.nn.MSELoss,
7 "BCE": torch.nn.BCELoss,
8 "CrossEntropy": torch.nn.CrossEntropyLoss,
9}
10
11try:
12 loss = losses[name]
13except:
14 raise ValueError("存在しません")
15
16print(loss)
17# <class 'torch.nn.modules.loss.MSELoss'>
18# loss() で呼び出せる
dict を自動で作るなら
import torch
from pprint import pprint
def get_losses():
return {x: getattr(torch.nn, x) for x in dir(torch.nn) if x.endswith("Loss")}
losses = get_losses()
pprint(losses)
方法2
getattr() で torch.nn 以下にその名前のクラスがあれば取得する
python
1import torch
2
3name = "MSE"
4
5def get_losses():
6 return "\n".join([x for x in dir(torch.nn) if x.endswith("Loss")])
7
8try:
9 loss = getattr(torch.nn, name + "Loss")
10except:
11 raise ValueError(f"{name}Loss は存在しません。利用可能な Loss 一覧\n{get_losses()}")
12
13print(loss)
14# <class 'torch.nn.modules.loss.MSELoss'>
15# loss() で呼び出せる
バッドをするには、ログインかつ
こちらの条件を満たす必要があります。
2020/08/09 09:48