teratail header banner
teratail header banner
質問するログイン新規登録

回答編集履歴

2

d

2020/08/03 08:02

投稿

tiitoi
tiitoi

スコア21960

answer CHANGED
@@ -25,6 +25,20 @@
25
25
  # loss() で呼び出せる
26
26
  ```
27
27
 
28
+ dict を自動で作るなら
29
+
30
+ ```
31
+ import torch
32
+ from pprint import pprint
33
+
34
+ def get_losses():
35
+ return {x: getattr(torch.nn, x) for x in dir(torch.nn) if x.endswith("Loss")}
36
+
37
+ losses = get_losses()
38
+
39
+ pprint(losses)
40
+ ```
41
+
28
42
  ## 方法2
29
43
 
30
44
  getattr() で torch.nn 以下にその名前のクラスがあれば取得する

1

d

2020/08/03 08:02

投稿

tiitoi
tiitoi

スコア21960

answer CHANGED
@@ -34,10 +34,13 @@
34
34
 
35
35
  name = "MSE"
36
36
 
37
+ def get_losses():
38
+ return "\n".join([x for x in dir(torch.nn) if x.endswith("Loss")])
39
+
37
40
  try:
38
41
  loss = getattr(torch.nn, name + "Loss")
39
42
  except:
40
- raise ValueError("存在しません")
43
+ raise ValueError(f"{name}Loss は存在しません。利用可能な Loss 一覧\n{get_losses()}")
41
44
 
42
45
  print(loss)
43
46
  # <class 'torch.nn.modules.loss.MSELoss'>