回答編集履歴
2
d
test
CHANGED
@@ -52,6 +52,34 @@
|
|
52
52
|
|
53
53
|
|
54
54
|
|
55
|
+
dict を自動で作るなら
|
56
|
+
|
57
|
+
|
58
|
+
|
59
|
+
```
|
60
|
+
|
61
|
+
import torch
|
62
|
+
|
63
|
+
from pprint import pprint
|
64
|
+
|
65
|
+
|
66
|
+
|
67
|
+
def get_losses():
|
68
|
+
|
69
|
+
return {x: getattr(torch.nn, x) for x in dir(torch.nn) if x.endswith("Loss")}
|
70
|
+
|
71
|
+
|
72
|
+
|
73
|
+
losses = get_losses()
|
74
|
+
|
75
|
+
|
76
|
+
|
77
|
+
pprint(losses)
|
78
|
+
|
79
|
+
```
|
80
|
+
|
81
|
+
|
82
|
+
|
55
83
|
## 方法2
|
56
84
|
|
57
85
|
|
1
d
test
CHANGED
@@ -70,13 +70,19 @@
|
|
70
70
|
|
71
71
|
|
72
72
|
|
73
|
+
def get_losses():
|
74
|
+
|
75
|
+
return "\n".join([x for x in dir(torch.nn) if x.endswith("Loss")])
|
76
|
+
|
77
|
+
|
78
|
+
|
73
79
|
try:
|
74
80
|
|
75
81
|
loss = getattr(torch.nn, name + "Loss")
|
76
82
|
|
77
83
|
except:
|
78
84
|
|
79
|
-
raise ValueError("存在しません")
|
85
|
+
raise ValueError(f"{name}Loss は存在しません。利用可能な Loss 一覧\n{get_losses()}")
|
80
86
|
|
81
87
|
|
82
88
|
|