質問編集履歴

3

コード追加

2021/07/07 08:30

投稿

mmmw
mmmw

スコア23

test CHANGED
File without changes
test CHANGED
@@ -241,3 +241,103 @@
241
241
  df_test['pred'] = pred
242
242
 
243
243
  ```
244
+
245
+
246
+
247
+
248
+
249
+ ### 追加コード
250
+
251
+ ```python
252
+
253
+ #train_Dataloader
254
+
255
+ import torch
256
+
257
+ image_dataloaders = {
258
+
259
+ 'train': torch.utils.data.DataLoader(image_datasets['train'], batch_size=4,shuffle=True, num_workers=0, drop_last=True),
260
+
261
+ 'val': torch.utils.data.DataLoader(image_datasets['val'], batch_size=4,shuffle=False, num_workers=0, drop_last=True),
262
+
263
+ }
264
+
265
+ train_dataloader = torch.utils.data.DataLoader(image_datasets['train'],
266
+
267
+ batch_size=4,
268
+
269
+ shuffle=True,
270
+
271
+ num_workers=0,
272
+
273
+ drop_last=True)
274
+
275
+ val_dataloader = torch.utils.data.DataLoader(image_datasets['val'],
276
+
277
+ batch_size=4,
278
+
279
+ shuffle=False,
280
+
281
+ num_workers=0,
282
+
283
+ drop_last=True)
284
+
285
+
286
+
287
+
288
+
289
+ #train_model呼び出し
290
+
291
+ dataset_sizes = {'train':train_dataloader, 'val':val_dataloader}
292
+
293
+ train_model(model_ft,dataset_sizes,criterion,optimizer,40,False)
294
+
295
+
296
+
297
+
298
+
299
+ #学習済みmodelのロード,loaded_modelの定義
300
+
301
+ import pickle
302
+
303
+ import torch.nn as nn
304
+
305
+ import torch
306
+
307
+ from torchvision import datasets, models, transforms
308
+
309
+
310
+
311
+ DEVICE= "cpu"
312
+
313
+ def get_model(target_num,isPretrained=False):
314
+
315
+ model_ft = models.resnet18(pretrained=isPretrained)
316
+
317
+ model_ft.fc = nn.Linear(512, target_num)
318
+
319
+ model_ft = model_ft.to(DEVICE)
320
+
321
+ return model_ft
322
+
323
+ best_model = get_model(target_num=2)
324
+
325
+ # モデルを保存する
326
+
327
+ modelname = './original_model_39.pth'
328
+
329
+ pickle.dump(best_model, open(modelname, 'wb'))
330
+
331
+ # 保存したモデルをロードする
332
+
333
+ loaded_model = pickle.load(open(modelname, 'rb'))
334
+
335
+
336
+
337
+
338
+
339
+
340
+
341
+
342
+
343
+ ```

2

質問内容の修正

2021/07/07 08:30

投稿

mmmw
mmmw

スコア23

test CHANGED
File without changes
test CHANGED
@@ -10,6 +10,10 @@
10
10
 
11
11
  testデータ推論時に、正常品1,欠陥品0と出力されるようにしたい。
12
12
 
13
+ ハイパラメータの設定が上手くいっていないため出力がおかしなことになっていると考えている。
14
+
15
+ (学習率、モーメンタムなど、)
16
+
13
17
 
14
18
 
15
19
 

1

質問内容の修正

2021/07/06 13:05

投稿

mmmw
mmmw

スコア23

test CHANGED
File without changes
test CHANGED
@@ -8,6 +8,12 @@
8
8
 
9
9
 
10
10
 
11
+ testデータ推論時に、正常品1,欠陥品0と出力されるようにしたい。
12
+
13
+
14
+
15
+
16
+
11
17
  ### 発生している問題・エラーメッセージ
12
18
 
13
19