Question
I am currently training pix2pixHD on a dataset I created myself.
The training run saves the trained model as it goes, and I would like to load that saved model data and continue training from where it left off, but even after searching I could not figure out how to do this.
Environment
・Google Colaboratory
・Installed the dominate library, since the official repository says to install it
・For the GPU, tensorflow-gpu 2.4.1 is used
What I tried to solve the problem
I searched with keyword combinations such as "GAN trained model retraining" and "GAN continued training",
but could not find an answer.
Code used
https://github.com/NVIDIA/pix2pixHD
I cloned this repository as-is, without any modifications.
I am pasting train.py below.
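For reference, a rough sketch of the setup cells I ran in Colab (the clone URL and the dominate install are the steps mentioned above; the exact cells may have differed slightly):

Python
# Colab cells: clone the official repo and install the extra dependency
# that the official repository says is required.
!git clone https://github.com/NVIDIA/pix2pixHD.git
%cd pix2pixHD
!pip install dominate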
Python
import time
import os
import numpy as np
import torch
from torch.autograd import Variable
from collections import OrderedDict
from subprocess import call
import fractions
def lcm(a,b): return abs(a * b)/fractions.gcd(a,b) if a and b else 0

from options.train_options import TrainOptions
from data.data_loader import CreateDataLoader
from models.models import create_model
import util.util as util
from util.visualizer import Visualizer

opt = TrainOptions().parse()
iter_path = os.path.join(opt.checkpoints_dir, opt.name, 'iter.txt')
if opt.continue_train:
    try:
        start_epoch, epoch_iter = np.loadtxt(iter_path, delimiter=',', dtype=int)
    except:
        start_epoch, epoch_iter = 1, 0
    print('Resuming from epoch %d at iteration %d' % (start_epoch, epoch_iter))
else:
    start_epoch, epoch_iter = 1, 0

opt.print_freq = lcm(opt.print_freq, opt.batchSize)
if opt.debug:
    opt.display_freq = 1
    opt.print_freq = 1
    opt.niter = 1
    opt.niter_decay = 0
    opt.max_dataset_size = 10

data_loader = CreateDataLoader(opt)
dataset = data_loader.load_data()
dataset_size = len(data_loader)
print('#training images = %d' % dataset_size)

model = create_model(opt)
visualizer = Visualizer(opt)
if opt.fp16:
    from apex import amp
    model, [optimizer_G, optimizer_D] = amp.initialize(model, [model.optimizer_G, model.optimizer_D], opt_level='O1')
    model = torch.nn.DataParallel(model, device_ids=opt.gpu_ids)
else:
    optimizer_G, optimizer_D = model.module.optimizer_G, model.module.optimizer_D

total_steps = (start_epoch-1) * dataset_size + epoch_iter

display_delta = total_steps % opt.display_freq
print_delta = total_steps % opt.print_freq
save_delta = total_steps % opt.save_latest_freq

for epoch in range(start_epoch, opt.niter + opt.niter_decay + 1):
    epoch_start_time = time.time()
    if epoch != start_epoch:
        epoch_iter = epoch_iter % dataset_size
    for i, data in enumerate(dataset, start=epoch_iter):
        if total_steps % opt.print_freq == print_delta:
            iter_start_time = time.time()
        total_steps += opt.batchSize
        epoch_iter += opt.batchSize

        # whether to collect output images
        save_fake = total_steps % opt.display_freq == display_delta

        ############## Forward Pass ######################
        losses, generated = model(Variable(data['label']), Variable(data['inst']),
                                  Variable(data['image']), Variable(data['feat']), infer=save_fake)

        # sum per device losses
        losses = [ torch.mean(x) if not isinstance(x, int) else x for x in losses ]
        loss_dict = dict(zip(model.module.loss_names, losses))

        # calculate final loss scalar
        loss_D = (loss_dict['D_fake'] + loss_dict['D_real']) * 0.5
        loss_G = loss_dict['G_GAN'] + loss_dict.get('G_GAN_Feat',0) + loss_dict.get('G_VGG',0)

        ############### Backward Pass ####################
        # update generator weights
        optimizer_G.zero_grad()
        if opt.fp16:
            with amp.scale_loss(loss_G, optimizer_G) as scaled_loss: scaled_loss.backward()
        else:
            loss_G.backward()
        optimizer_G.step()

        # update discriminator weights
        optimizer_D.zero_grad()
        if opt.fp16:
            with amp.scale_loss(loss_D, optimizer_D) as scaled_loss: scaled_loss.backward()
        else:
            loss_D.backward()
        optimizer_D.step()

        ############## Display results and errors ##########
        ### print out errors
        if total_steps % opt.print_freq == print_delta:
            errors = {k: v.data.item() if not isinstance(v, int) else v for k, v in loss_dict.items()}
            t = (time.time() - iter_start_time) / opt.print_freq
            visualizer.print_current_errors(epoch, epoch_iter, errors, t)
            visualizer.plot_current_errors(errors, total_steps)
            #call(["nvidia-smi", "--format=csv", "--query-gpu=memory.used,memory.free"])

        ### display output images
        if save_fake:
            visuals = OrderedDict([('input_label', util.tensor2label(data['label'][0], opt.label_nc)),
                                   ('synthesized_image', util.tensor2im(generated.data[0])),
                                   ('real_image', util.tensor2im(data['image'][0]))])
            visualizer.display_current_results(visuals, epoch, total_steps)

        ### save latest model
        if total_steps % opt.save_latest_freq == save_delta:
            print('saving the latest model (epoch %d, total_steps %d)' % (epoch, total_steps))
            model.module.save('latest')
            np.savetxt(iter_path, (epoch, epoch_iter), delimiter=',', fmt='%d')

        if epoch_iter >= dataset_size:
            break

    # end of epoch
    iter_end_time = time.time()
    print('End of epoch %d / %d \t Time Taken: %d sec' %
          (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time))

    ### save model for this epoch
    if epoch % opt.save_epoch_freq == 0:
        print('saving the model at the end of epoch %d, iters %d' % (epoch, total_steps))
        model.module.save('latest')
        model.module.save(epoch)
        np.savetxt(iter_path, (epoch+1, 0), delimiter=',', fmt='%d')

    ### instead of only training the local enhancer, train the entire network after certain iterations
    if (opt.niter_fix_global != 0) and (epoch == opt.niter_fix_global):
        model.module.update_fixed_params()

    ### linearly decay learning rate after certain iterations
    if epoch > opt.niter:
        model.module.update_learning_rate()
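Looking at the top of train.py, there is an opt.continue_train branch that reads iter.txt from the checkpoints directory to recover start_epoch and epoch_iter, and the save calls write 'latest' model files plus iter.txt. So my guess is that resuming means re-running the script with the same --name as the original run plus the --continue_train flag, something like the cell below (the experiment name and dataset path are placeholders, not my real values). Is this the correct way to continue training?

Python
# Colab cell (my guess at resuming): reuse the same --name so the script can
# find the 'latest' checkpoint files and iter.txt under checkpoints/<name>/,
# and pass --continue_train so the resume branch at the top of train.py runs.
!python train.py --name my_experiment --dataroot ./datasets/my_dataset --continue_train

The repo's training options also seem to include a --load_pretrain argument for starting a new experiment from existing weights, but I have not verified how it behaves.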
Finally
If there is anything I should add to this question, please let me know in a comment.
I am close to a complete beginner, so detailed explanations or pointers to sites I can learn from would be much appreciated.