質問編集履歴
1
コードを全文載せました
test
CHANGED
File without changes
|
test
CHANGED
@@ -30,9 +30,37 @@
|
|
30
30
|
|
31
31
|
### 該当のソースコード
|
32
32
|
|
33
|
+
network.py
|
34
|
+
|
35
|
+
```
|
33
36
|
|
34
37
|
|
38
|
+
|
39
|
+
import numpy as np
|
40
|
+
|
35
|
-
|
41
|
+
import chainer
|
42
|
+
|
43
|
+
from chainer import cuda, Function, gradient_check, report, training, utils, Variable
|
44
|
+
|
45
|
+
from chainer import datasets, iterators, optimizers, serializers
|
46
|
+
|
47
|
+
from chainer import Link, Chain, ChainList
|
48
|
+
|
49
|
+
import chainer.functions as F
|
50
|
+
|
51
|
+
import chainer.links as L
|
52
|
+
|
53
|
+
|
54
|
+
|
55
|
+
|
56
|
+
|
57
|
+
FILTERS_NUM = 50
|
58
|
+
|
59
|
+
HIDDEN_LAYER_NUM = 10
|
60
|
+
|
61
|
+
|
62
|
+
|
63
|
+
|
36
64
|
|
37
65
|
class AgentNet(Chain):
|
38
66
|
|
@@ -68,7 +96,31 @@
|
|
68
96
|
|
69
97
|
|
70
98
|
|
99
|
+
def __call__(self, x):
|
100
|
+
|
101
|
+
size = x.data.shape[0]
|
102
|
+
|
103
|
+
for n, f in self.forward:
|
104
|
+
|
105
|
+
if not n.startswith('_'):
|
106
|
+
|
107
|
+
x = getattr(self, n)(x)
|
108
|
+
|
109
|
+
else:
|
110
|
+
|
111
|
+
x = f(x)
|
112
|
+
|
113
|
+
x = F.reshape(x, (size, 64))
|
114
|
+
|
115
|
+
if chainer.config.train:
|
116
|
+
|
117
|
+
return x
|
118
|
+
|
119
|
+
return F.softmax(x)
|
120
|
+
|
71
121
|
```
|
122
|
+
|
123
|
+
|
72
124
|
|
73
125
|
|
74
126
|
|