質問編集履歴
2
該当するソースコードの追加
title
CHANGED
File without changes
|
body
CHANGED
@@ -12,16 +12,14 @@
|
|
12
12
|
```
|
13
13
|
この確率分布クラスは毎回サンプリング値が異なりますが、具体的に、どのような計算を行っているのしょうか。参考ページ[PROBABILITY DISTRIBUTIONS - TORCH.DISTRIBUTIONS](https://pytorch.org/docs/stable/distributions.html)
|
14
14
|
|
15
|
-
|
15
|
+
multinomialを使っているので基本的には並べ替えを変えているようですが、それだけでしょうか。
|
16
|
-
```
|
16
|
+
```PyTorch
|
17
|
-
|
17
|
+
def sample(self, sample_shape=torch.Size()):
|
18
|
+
sample_shape = self._extended_shape(sample_shape)
|
19
|
+
param_shape = sample_shape + torch.Size((self._num_events,))
|
18
|
-
|
20
|
+
probs = self.probs.expand(param_shape)
|
19
|
-
action tensor(1)
|
20
|
-
|
21
|
+
probs_2d = probs.reshape(-1, self._num_events)
|
21
|
-
|
22
|
+
sample_2d = torch.multinomial(probs_2d, 1, True)
|
22
|
-
|
23
|
+
return sample_2d.reshape(sample_shape)
|
23
|
-
|
24
|
+
|
24
|
-
log_prob tensor(-0.5108)
|
25
|
-
action tensor(0)
|
26
|
-
log_prob tensor(-0.9163)
|
27
25
|
```
|
1
リンク追加
title
CHANGED
File without changes
|
body
CHANGED
@@ -10,7 +10,7 @@
|
|
10
10
|
action = m.sample()
|
11
11
|
log_prob = m.log_prob(action)
|
12
12
|
```
|
13
|
-
この確率分布クラスは毎回サンプリング値が異なりますが、具体的に、どのような計算を行っているのしょうか。
|
13
|
+
この確率分布クラスは毎回サンプリング値が異なりますが、具体的に、どのような計算を行っているのしょうか。参考ページ[PROBABILITY DISTRIBUTIONS - TORCH.DISTRIBUTIONS](https://pytorch.org/docs/stable/distributions.html)
|
14
14
|
|
15
15
|
⬇️5回試行した結果
|
16
16
|
```Python
|