回答編集履歴
2
d
answer
CHANGED
@@ -25,6 +25,20 @@
|
|
25
25
|
# loss() で呼び出せる
|
26
26
|
```
|
27
27
|
|
28
|
+
dict を自動で作るなら
|
29
|
+
|
30
|
+
```
|
31
|
+
import torch
|
32
|
+
from pprint import pprint
|
33
|
+
|
34
|
+
def get_losses():
|
35
|
+
return {x: getattr(torch.nn, x) for x in dir(torch.nn) if x.endswith("Loss")}
|
36
|
+
|
37
|
+
losses = get_losses()
|
38
|
+
|
39
|
+
pprint(losses)
|
40
|
+
```
|
41
|
+
|
28
42
|
## 方法2
|
29
43
|
|
30
44
|
getattr() で torch.nn 以下にその名前のクラスがあれば取得する
|
1
d
answer
CHANGED
@@ -34,10 +34,13 @@
|
|
34
34
|
|
35
35
|
name = "MSE"
|
36
36
|
|
37
|
+
def get_losses():
|
38
|
+
return "\n".join([x for x in dir(torch.nn) if x.endswith("Loss")])
|
39
|
+
|
37
40
|
try:
|
38
41
|
loss = getattr(torch.nn, name + "Loss")
|
39
42
|
except:
|
40
|
-
raise ValueError("存在しません")
|
43
|
+
raise ValueError(f"{name}Loss は存在しません。利用可能な Loss 一覧\n{get_losses()}")
|
41
44
|
|
42
45
|
print(loss)
|
43
46
|
# <class 'torch.nn.modules.loss.MSELoss'>
|