質問編集履歴

1

必要なプログラムを記載できていませんでした。申し訳ございません。

2020/09/05 00:56

投稿

takuya324232506
takuya324232506

スコア7

test CHANGED
File without changes
test CHANGED
@@ -115,3 +115,119 @@
115
115
  plt.show()
116
116
 
117
117
  ```
118
+
119
+ makeGaussianData.py
120
+
121
+ ```python
122
+
123
+ import numpy as np
124
+
125
+
126
+
127
+
128
+
129
+ def getData(nclass, seed = None):
130
+
131
+
132
+
133
+ assert nclass == 2 or nclass == 3
134
+
135
+
136
+
137
+ if seed != None:
138
+
139
+ np.random.seed(seed)
140
+
141
+
142
+
143
+ # 2次元の spherical な正規分布3つからデータを生成
144
+
145
+ X0 = 0.10 * np.random.randn(200, 2) + [ 0.3, 0.3 ]
146
+
147
+ X1 = 0.10 * np.random.randn(200, 2) + [ 0.7, 0.6 ]
148
+
149
+ X2 = 0.05 * np.random.randn(200, 2) + [ 0.3, 0.7 ]
150
+
151
+
152
+
153
+ # それらのラベル用のarray
154
+
155
+ lab0 = np.zeros(X0.shape[0], dtype = int)
156
+
157
+ lab1 = np.zeros(X1.shape[0], dtype = int) + 1
158
+
159
+ lab2 = np.zeros(X2.shape[0], dtype = int) + 2
160
+
161
+
162
+
163
+ # X (入力データ), label (クラスラベル), t(教師信号) をつくる
164
+
165
+ if nclass == 2:
166
+
167
+ X = np.vstack((X0, X1))
168
+
169
+ label = np.hstack((lab0, lab1))
170
+
171
+ t = np.zeros(X.shape[0])
172
+
173
+ t[label == 1] = 1.0
174
+
175
+ else:
176
+
177
+ X = np.vstack((X0, X1, X2))
178
+
179
+ label = np.hstack((lab0, lab1, lab2))
180
+
181
+ t = np.zeros((X.shape[0], nclass))
182
+
183
+ for ik in range(nclass):
184
+
185
+ t[label == ik, ik] = 1.0
186
+
187
+
188
+
189
+ return X, label, t
190
+
191
+
192
+
193
+
194
+
195
+ if __name__ == '__main__':
196
+
197
+
198
+
199
+ import matplotlib
200
+
201
+ import matplotlib.pyplot as plt
202
+
203
+
204
+
205
+ K = 3
206
+
207
+
208
+
209
+ X, lab, t = getData(K)
210
+
211
+
212
+
213
+ fig = plt.figure()
214
+
215
+ plt.xlim(-0.2, 1.2)
216
+
217
+ plt.ylim(-0.2, 1.2)
218
+
219
+ ax = fig.add_subplot(1, 1, 1)
220
+
221
+ ax.set_aspect(1)
222
+
223
+ ax.scatter(X[lab == 0, 0], X[lab == 0, 1], color = 'red')
224
+
225
+ ax.scatter(X[lab == 1, 0], X[lab == 1, 1], color = 'green')
226
+
227
+ if K == 3:
228
+
229
+ ax.scatter(X[lab == 2, 0], X[lab == 2, 1], color = 'blue')
230
+
231
+ plt.show()
232
+
233
+ ```