質問編集履歴
3
コード追加
test
CHANGED
File without changes
|
test
CHANGED
@@ -241,3 +241,103 @@
|
|
241
241
|
df_test['pred'] = pred
|
242
242
|
|
243
243
|
```
|
244
|
+
|
245
|
+
|
246
|
+
|
247
|
+
|
248
|
+
|
249
|
+
### 追加コード
|
250
|
+
|
251
|
+
```python
|
252
|
+
|
253
|
+
#train_Dataloader
|
254
|
+
|
255
|
+
import torch
|
256
|
+
|
257
|
+
image_dataloaders = {
|
258
|
+
|
259
|
+
'train': torch.utils.data.DataLoader(image_datasets['train'], batch_size=4,shuffle=True, num_workers=0, drop_last=True),
|
260
|
+
|
261
|
+
'val': torch.utils.data.DataLoader(image_datasets['val'], batch_size=4,shuffle=False, num_workers=0, drop_last=True),
|
262
|
+
|
263
|
+
}
|
264
|
+
|
265
|
+
train_dataloader = torch.utils.data.DataLoader(image_datasets['train'],
|
266
|
+
|
267
|
+
batch_size=4,
|
268
|
+
|
269
|
+
shuffle=True,
|
270
|
+
|
271
|
+
num_workers=0,
|
272
|
+
|
273
|
+
drop_last=True)
|
274
|
+
|
275
|
+
val_dataloader = torch.utils.data.DataLoader(image_datasets['val'],
|
276
|
+
|
277
|
+
batch_size=4,
|
278
|
+
|
279
|
+
shuffle=False,
|
280
|
+
|
281
|
+
num_workers=0,
|
282
|
+
|
283
|
+
drop_last=True)
|
284
|
+
|
285
|
+
|
286
|
+
|
287
|
+
|
288
|
+
|
289
|
+
#train_model呼び出し
|
290
|
+
|
291
|
+
dataset_sizes = {'train':train_dataloader, 'val':val_dataloader}
|
292
|
+
|
293
|
+
train_model(model_ft,dataset_sizes,criterion,optimizer,40,False)
|
294
|
+
|
295
|
+
|
296
|
+
|
297
|
+
|
298
|
+
|
299
|
+
#学習済みmodelのロード,loaded_modelの定義
|
300
|
+
|
301
|
+
import pickle
|
302
|
+
|
303
|
+
import torch.nn as nn
|
304
|
+
|
305
|
+
import torch
|
306
|
+
|
307
|
+
from torchvision import datasets, models, transforms
|
308
|
+
|
309
|
+
|
310
|
+
|
311
|
+
DEVICE= "cpu"
|
312
|
+
|
313
|
+
def get_model(target_num,isPretrained=False):
|
314
|
+
|
315
|
+
model_ft = models.resnet18(pretrained=isPretrained)
|
316
|
+
|
317
|
+
model_ft.fc = nn.Linear(512, target_num)
|
318
|
+
|
319
|
+
model_ft = model_ft.to(DEVICE)
|
320
|
+
|
321
|
+
return model_ft
|
322
|
+
|
323
|
+
best_model = get_model(target_num=2)
|
324
|
+
|
325
|
+
# モデルを保存する
|
326
|
+
|
327
|
+
modelname = './original_model_39.pth'
|
328
|
+
|
329
|
+
pickle.dump(best_model, open(modelname, 'wb'))
|
330
|
+
|
331
|
+
# 保存したモデルをロードする
|
332
|
+
|
333
|
+
loaded_model = pickle.load(open(modelname, 'rb'))
|
334
|
+
|
335
|
+
|
336
|
+
|
337
|
+
|
338
|
+
|
339
|
+
|
340
|
+
|
341
|
+
|
342
|
+
|
343
|
+
```
|
2
質問内容の修正
test
CHANGED
File without changes
|
test
CHANGED
@@ -10,6 +10,10 @@
|
|
10
10
|
|
11
11
|
testデータ推論時に、正常品1,欠陥品0と出力されるようにしたい。
|
12
12
|
|
13
|
+
ハイパラメータの設定が上手くいっていないため出力がおかしなことになっていると考えている。
|
14
|
+
|
15
|
+
(学習率、モーメンタムなど、)
|
16
|
+
|
13
17
|
|
14
18
|
|
15
19
|
|
1
質問内容の修正
test
CHANGED
File without changes
|
test
CHANGED
@@ -8,6 +8,12 @@
|
|
8
8
|
|
9
9
|
|
10
10
|
|
11
|
+
testデータ推論時に、正常品1,欠陥品0と出力されるようにしたい。
|
12
|
+
|
13
|
+
|
14
|
+
|
15
|
+
|
16
|
+
|
11
17
|
### 発生している問題・エラーメッセージ
|
12
18
|
|
13
19
|
|