teratail header banner
teratail header banner
質問するログイン新規登録

質問編集履歴

2

目的を明確化

2017/11/14 09:26

投稿

LaLaLand
LaLaLand

スコア107

title CHANGED
File without changes
body CHANGED
@@ -1,6 +1,6 @@
1
1
  Python3、Chainerでニューラルネットワークの勉強をしています。
2
2
  適当なサンプルコードを書いて、学習結果の保存・読み込みを行おうと思っています。
3
- 学習自体は目的ではなく、あくまで保存・読み込みが目的です。
3
+ 学習自体は目的ではなく、あくまで保存・読み込みの勉強が目的です。
4
4
 
5
5
  ```python
6
6
  # モデルの読み込み(存在する場合)

1

コード全文追加

2017/11/14 09:26

投稿

LaLaLand
LaLaLand

スコア107

title CHANGED
File without changes
body CHANGED
@@ -38,4 +38,92 @@
38
38
 
39
39
  どのようにしたら読み込めますでしょうか?
40
40
 
41
- 教えてください。
41
+ 教えてください。
42
+
43
+ 以下、コードの全文です。
44
+
45
+ ```python
46
+
47
+ from chainer import Chain
48
+ from chainer import cuda, Function, gradient_check, Variable, optimizers, serializers, utils
49
+ import chainer.links as L
50
+ import chainer.functions as F
51
+ import numpy as np
52
+ import os
53
+
54
+
55
+
56
+ class MyChain(Chain):
57
+ def __init__(self):
58
+ super(MyChain, self).__init__(
59
+ l1 = L.Linear(2, 10),
60
+ l2 = L.Linear(10, 10),
61
+ l3 = L.Linear(10, 1))
62
+
63
+ def __call__(self, x, y):
64
+ xv = Variable(x)
65
+ yv = Variable(y)
66
+ pr = self.predict(xv)
67
+ return F.mean_squared_error(pr, yv)
68
+
69
+ def predict(self, x):
70
+ h1 = self.l1(x)
71
+ h2 = self.l2(h1)
72
+ h3 = self.l3(h2)
73
+ return h3
74
+
75
+ def createData(max, N):
76
+ # 適当にデータを作って掛け算を学習してみる。
77
+ xs1 = np.linspace(-max, max, N).astype(np.float32)
78
+ xs2 = np.random.normal(0, max, N).astype(np.float32)
79
+ return np.c_[xs1, xs2], np.c_[xs1*xs2]
80
+
81
+
82
+ rangeMax = 13.5
83
+
84
+ N = 1200
85
+ batchSize = 10
86
+ xtrain, ytrain = createData(rangeMax, N)
87
+ xtest, ytest = createData(rangeMax, 1000)
88
+
89
+
90
+ # モデルの読み込み(存在する場合)
91
+ modelFileName = "kakezan-model.npz"
92
+ if os.path.exists(modelFileName):
93
+ model = MyChain()
94
+ serializers.load_npz(modelFileName, model)
95
+ else:
96
+ model = MyChain()
97
+
98
+ # optimizerの読み込み(存在する場合)
99
+ optimizerFileName = "kakezan-optimizer.npz"
100
+ if os.path.exists(optimizerFileName):
101
+ optimizer = optimizers.SMORMS3()
102
+ serializers.load_npz(optimizerFileName, optimizer)
103
+ else:
104
+ optimizer = optimizers.SMORMS3()
105
+ optimizer.setup(model)
106
+
107
+ for epoch in range(120):
108
+ perm = np.random.permutation(N)
109
+
110
+ for i in range(0, N, batchSize):
111
+ # ランダムにbatchSize個のデータを取得
112
+ x_batch = xtrain[perm[i:i + batchSize]]
113
+ y_batch = ytrain[perm[i:i + batchSize]]
114
+
115
+ model.zerograds()
116
+ loss = model(x_batch, y_batch)
117
+ loss.backward()
118
+ optimizer.update()
119
+
120
+ if 0 < epoch and epoch % 30 == 0:
121
+ loss = model(xtest, ytest)
122
+ print(loss.data)
123
+ serializers.save_npz(modelFileName, model)
124
+ serializers.save_npz(optimizerFileName, optimizer)
125
+
126
+
127
+
128
+
129
+ ```