質問編集履歴

1

作成途中

2018/10/25 05:55

投稿

FALLOT
FALLOT

スコア16

test CHANGED
File without changes
test CHANGED
@@ -24,6 +24,10 @@
24
24
 
25
25
  from keras.callbacks import EarlyStopping, ModelCheckpoint, CSVLogger
26
26
 
27
+ from sklearn.model_selection import train_test_split
28
+
29
+ from sklearn import datasets
30
+
27
31
  import numpy as np
28
32
 
29
33
  from PIL import Image
@@ -48,9 +52,9 @@
48
52
 
49
53
  #それぞれの画像の枚数を入力
50
54
 
51
- A = 3000
55
+ A = 50
52
-
56
+
53
- B = 3000
57
+ B = 50
54
58
 
55
59
  sum =A+B
56
60
 
@@ -72,7 +76,7 @@
72
76
 
73
77
  #エポック数
74
78
 
75
- E = 500
79
+ E = 5
76
80
 
77
81
  #バッチサイズ
78
82
 
@@ -142,10 +146,6 @@
142
146
 
143
147
 
144
148
 
145
-
146
-
147
-
148
-
149
149
  print("画像の読み込み 終了")
150
150
 
151
151
 
@@ -156,8 +156,6 @@
156
156
 
157
157
 
158
158
 
159
-
160
-
161
159
  #最大応力の位置_読み込み_表示
162
160
 
163
161
  location = np.loadtxt("data/value/max_stress_value_a.csv",delimiter=",",skiprows=0)
@@ -188,7 +186,7 @@
188
186
 
189
187
 
190
188
 
191
- model.add(Dense(8000, input_dim=Z,kernel_initializer='random_uniform',bias_initializer='zeros'))
189
+ model.add(Dense(5000, input_dim=Z,kernel_initializer='random_uniform',bias_initializer='zeros'))
192
190
 
193
191
  #model.add(Activation("LeakyReLU"))
194
192
 
@@ -198,6 +196,8 @@
198
196
 
199
197
 
200
198
 
199
+
200
+
201
201
  model.add(Dense(100,kernel_initializer='random_uniform',bias_initializer='zeros'))
202
202
 
203
203
  model.add(LeakyReLU())
@@ -206,15 +206,17 @@
206
206
 
207
207
 
208
208
 
209
+
210
+
209
- model.add(Dense(50,kernel_initializer='random_uniform',bias_initializer='zeros'))
211
+ model.add(Dense(10,kernel_initializer='random_uniform',bias_initializer='zeros'))
210
212
 
211
213
  model.add(LeakyReLU())
212
214
 
213
- model.add(Dropout(0.075))
215
+ model.add(Dropout(0.0))
214
-
215
-
216
-
216
+
217
+
218
+
217
- model.add(Dense(10,kernel_initializer='random_uniform',bias_initializer='zeros'))
219
+ model.add(Dense(5,kernel_initializer='random_uniform',bias_initializer='zeros'))
218
220
 
219
221
  model.add(LeakyReLU())
220
222
 
@@ -290,6 +292,102 @@
290
292
 
291
293
 
292
294
 
295
+ image_list = np.array(image_list)
296
+
297
+ location_list = np.array(location_list)
298
+
299
+ print(image_list.shape, image_list.dtype) # (300, 11250) float64
300
+
301
+ print(location_list.shape, location_list.dtype) # (300,) float64
302
+
303
+
304
+
305
+
306
+
307
+ def get_batch(image_list, location_list, batch_size, shuffle=False):
308
+
309
+ '''ミニバッチを生成するジェネレーター関数
310
+
311
+ '''
312
+
313
+ num_samples = location_list # サンプル数
314
+
315
+ if shuffle:# シャッフルする場合
316
+
317
+ indices = np.random.permutation(num_samples)
318
+
319
+ else: # シャッフルしない場合
320
+
321
+ indices = np.random.arange(num_samples)
322
+
323
+ num_steps = np.ceil(num_samples / batch_size).astype(int)
324
+
325
+ print(num_steps)
326
+
327
+ print(type(num_steps))
328
+
329
+
330
+
331
+ for itr in range(num_steps):
332
+
333
+ start = batch_size * itr
334
+
335
+ excerpt = indices[start:start + batch_size]
336
+
337
+ yield x[excerpt], y[excerpt]
338
+
339
+
340
+
341
+ # 保存用ディレクトリ
342
+
343
+ out_dirpath = 'prediction'
344
+
345
+ os.makedirs(out_dirpath, exist_ok=True)
346
+
347
+
348
+
349
+ x_train, x_test, y_train, y_test = train_test_split(image_list, location_list, test_size=0.2)
350
+
351
+
352
+
353
+ # 学習する。
354
+
355
+ epochs = E
356
+
357
+ for i in range(epochs):
358
+
359
+ for x_batch, y_batch in get_batch(x_train, y_train, batch_size=BATCH_SIZE, shuffle=True):
360
+
361
+ # x_batch, y_batch が生成されたミニバッチ
362
+
363
+
364
+
365
+ # 1バッチ分学習する
366
+
367
+ model.train_on_batch(x_batch, y_batch)
368
+
369
+
370
+
371
+ # エポックごとにテストデータで推論する。
372
+
373
+ y_pred = model.predict_classes(x_train)
374
+
375
+ result = np.c_[y_pred, y_train]
376
+
377
+
378
+
379
+ # 推論結果を保存する。
380
+
381
+ filepath = os.path.join(out_dirpath, 'prediction_{}.csv'.format(i))
382
+
383
+ np.savetxt(filepath, result, fmt='%.0f')
384
+
385
+
386
+
387
+
388
+
389
+
390
+
293
391
  end_time = time.time()
294
392
 
295
393
  print("\n終了時刻: ",end_time)
@@ -308,4 +406,8 @@
308
406
 
309
407
 
310
408
 
409
+
410
+
411
+
412
+
311
413
  ```