Question edit history

2

Because the code was hard to read

2021/05/10 12:46

Posted

shujiu

Score 3

test CHANGED
File without changes
test CHANGED
File without changes

1

Because the code was hard to read

2021/05/10 12:46

Posted

shujiu

Score 3

test CHANGED
File without changes
test CHANGED
@@ -18,8 +18,6 @@

 ```

- ### Relevant source code
-
 ```python

 import torch
@@ -158,8 +156,6 @@



-
-
 model.class_to_idx = all_data.class_to_idx

 model.idx_to_class = {
@@ -192,6 +188,8 @@



+ #Initializing some variables
+
 valid_loss_min = np.Inf

 stop_count = 0
@@ -242,7 +240,7 @@



-
+ # Track train loss by multiplying average loss by number of examples in batch

 train_loss += loss.item() * data.size(0)

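The comment added in the hunk above refers to a standard pattern: since the criterion returns the mean loss over a batch, multiplying by data.size(0) accumulates the sum of per-sample losses, which is then divided by the dataset size at the end of the epoch. A minimal, self-contained sketch of that pattern (synthetic data and a toy model stand in for the question's actual loaders and network):

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

# Synthetic stand-ins for the question's dataset, model, and optimizer
dataset = TensorDataset(torch.randn(64, 10), torch.randint(0, 2, (64,)))
train_loader = DataLoader(dataset, batch_size=16)
model = nn.Linear(10, 2)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

train_loss = 0.0
for data, target in train_loader:
    optimizer.zero_grad()
    output = model(data)
    loss = criterion(output, target)          # mean loss over this batch
    loss.backward()
    optimizer.step()
    # scale the batch-average loss by the batch size to accumulate a per-sample sum
    train_loss += loss.item() * data.size(0)

# dividing by the dataset size gives the epoch-level average loss
epoch_loss = train_loss / len(train_loader.dataset)
print(f"epoch loss: {epoch_loss:.4f}")
```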
@@ -334,50 +332,220 @@

  stop_count += 1

-
-
- if stop_count >= early_stop:
-
-     print(f'\nEarly Stopping Total epochs: {epoch}. Best epoch: {best_epoch} with loss: {valid_loss_min:.2f} and acc: {100 * valid_acc:.2f}%')
-
-     model.load_state_dict(torch.load(save_location))
-
-     model.optimizer = optimizer
-
-     history = pd.DataFrame(history, columns=['train_loss', 'valid_loss', 'train_acc','valid_acc'])
-
-     return model, history
-
-
-
- model.optimizer = optimizer
-
- print(f'\nBest epoch: {best_epoch} with loss: {valid_loss_min:.2f} and acc: {100 * valid_acc:.2f}%')
-
-
-
- history = pd.DataFrame(history, columns=['train_loss', 'valid_loss', 'train_acc', 'valid_acc'])
-
- return model, history
-
-
-
- model, history = train(
-
-     model,
-
-     criterion,
-
-     optimizer,
-
-     train_loader,
-
-     val_loader,
-
-     save_location='./natural_images_resnet.pt',
-
-     early_stop=5,
-
-     n_epochs=10,
-
-     print_every=2)
+
+
+ class GradCAM:
+
+     def __init__(self, model, feature_layer):
+
+         self.model = model
+
+         self.feature_layer = feature_layer
+
+         self.model.eval()
+
+         self.feature_grad = None
+
+         self.feature_map = None
+
+         self.hooks = []
+
+
+
+         # Record the gradient at the final layer during backpropagation
+
+         def save_feature_grad(module, in_grad, out_grad):
+
+             self.feature_grad = out_grad[0]
+
+         self.hooks.append(self.feature_layer.register_backward_hook(save_feature_grad))
+
+         # Record the feature map output by the final layer
+
+         def save_feature_map(module, inp, outp):
+
+             self.feature_map = outp[0]
+
+         self.hooks.append(self.feature_layer.register_forward_hook(save_feature_map))
+
+
+
+     def forward(self, x):
+
+         return self.model(x)
+
+
+
+     def backward_on_target(self, output, target):
+
+         self.model.zero_grad()
+
+         one_hot_output = torch.zeros([1, output.size()[-1]])
+
+         one_hot_output[0][target] = 1
+
+         output.backward(gradient=one_hot_output, retain_graph=True)
+
+
+
+     def clear_hook(self):
+
+         for hook in self.hooks:
+
+             hook.remove()
+
+
+
+ path = "/Users/ora/Desktop/2019_33 ①.jpg"
+
+ VISUALIZE_SIZE = (300, 200)
+
+ image = Image.open(path)
+
+ image.thumbnail(VISUALIZE_SIZE, Image.ANTIALIAS)
+
+
+
+ plt.imshow(image)
+
+
+
+ image_orig_size = image.size # (W, H)
+
+
+
+ test_image_tensor = image_transforms['test'](image)
+
+ test_image_tensor = test_image_tensor.unsqueeze(0)
+
+ device = torch.device("cuda")
+
+ model = models.resnet50(pretrained=False)
+
+ n_classes = 2
+
+ num_ftrs = model.fc.in_features
+
+ model.fc = nn.Linear(num_ftrs, n_classes)
+
+ model.to(device)
+
+ model.eval()
+
+
+
+ from collections import OrderedDict
+
+ import torch
+
+ checkpoint=torch.load('./natural_images_resnet.pt')
+
+
+
+ state_dict=checkpoint
+
+ new_state_dict=OrderedDict()
+
+
+
+ model.load_state_dict(state_dict)
+
+
+
+ ############# Visualization with Grad-CAM
+
+ grad_cam = GradCAM(model, feature_layer=list(model.layer4.modules())[26])
+
+ # Feed the image into Grad-CAM
+
+ model_output = grad_cam.forward(test_image_tensor)
+
+
+
+ if len(model_output) == 1:
+
+     target = model_output.argmax(1).item()
+
+     grad_cam.backward_on_target(model_output, target)
+
+
+
+ # Get feature gradient
+
+ feature_grad = grad_cam.feature_grad.data.numpy()[0]
+
+
+
+ # Get weights from gradient
+
+ weights = np.mean(feature_grad, axis=(1, 2)) # Take averages for each gradient
+
+ # Get features outputs
+
+ feature_map = grad_cam.feature_map.data.numpy()
+
+ grad_cam.clear_hook()
+
+
+
+ cam = np.sum((weights * feature_map.T), axis=2).T
+
+ cam = np.maximum(cam, 0) # apply ReLU to cam
+
+
+
+ cam = cv2.resize(cam, VISUALIZE_SIZE)
+
+ cam = (cam - np.min(cam)) / (np.max(cam) - np.min(cam)) # Normalize between 0-1
+
+ cam = np.uint8(cam * 255) # Scale between 0-255 to visualize
+
+ activation_heatmap = np.uint8(cv2.applyColorMap(cam, cv2.COLORMAP_JET))
+
+ activation_heatmap = cv2.cvtColor(activation_heatmap, cv2.COLOR_BGR2RGB) # BGR to RGB
+
+
+
+ plt.imshow(activation_heatmap)
+
+
+
+ org_img = np.asarray(image.resize(VISUALIZE_SIZE))
+
+ intensity = 0.4
+
+ img_with_heatmap = cv2.addWeighted(activation_heatmap, intensity, org_img, 1, 0)
+
+ org_img = cv2.resize(org_img, image_orig_size)
+
+ img_with_heatmap = cv2.resize(img_with_heatmap, image_orig_size)
+
+
+
+ plt.figure(figsize=(10,5))
+
+
+
+ plt.subplot(1,2,1)
+
+ plt.imshow(org_img)
+
+ plt.xticks(color="None")
+
+ plt.yticks(color="None")
+
+ plt.tick_params(length=0)
+
+ plt.subplot(1,2,2)
+
+ plt.imshow(img_with_heatmap)
+
+ plt.xticks(color="None")
+
+ plt.yticks(color="None")
+
+ plt.tick_params(length=0)
+
+ plt.show()
+
+ ```
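The GradCAM class added in this revision works by registering a forward hook and a backward hook on the chosen feature layer, then reading back the stored feature map and gradient after a targeted backward pass. Below is a minimal, self-contained sketch of just that hook mechanism, not the question's full pipeline; random weights and a dummy input stand in for the trained model and image, and register_full_backward_hook is used as the non-deprecated counterpart of the register_backward_hook call in the code above (assuming PyTorch 1.8+ and torchvision).

```python
import torch
import torchvision.models as models

# Untrained ResNet-50 as a stand-in for the question's fine-tuned model
model = models.resnet50(pretrained=False)
model.eval()
target_layer = model.layer4[-1]  # last bottleneck block of layer4

features, grads = {}, {}

def save_feature_map(module, inputs, output):
    # forward hook: keep the layer's output feature map
    features["value"] = output.detach()

def save_feature_grad(module, grad_input, grad_output):
    # backward hook: keep the gradient flowing back through the layer's output
    grads["value"] = grad_output[0].detach()

hooks = [
    target_layer.register_forward_hook(save_feature_map),
    target_layer.register_full_backward_hook(save_feature_grad),
]

x = torch.randn(1, 3, 224, 224)                    # dummy input image tensor
output = model(x)
score = output[0, output.argmax(dim=1).item()]     # score of the predicted class
score.backward()                                   # triggers the backward hook

print(features["value"].shape)  # torch.Size([1, 2048, 7, 7])
print(grads["value"].shape)     # torch.Size([1, 2048, 7, 7])

for h in hooks:
    h.remove()
```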