質問をすることでしか得られない、回答やアドバイスがある。

15分調べてもわからないことは、質問しよう!

新規登録して質問してみよう
ただいま回答率
85.35%
機械学習

機械学習は、データからパターンを自動的に発見し、そこから知能的な判断を下すためのコンピューターアルゴリズムを指します。人工知能における課題のひとつです。

Q&A

解決済

1回答

2319閲覧

Pix2PixHDで保存した学習モデルを使って継続して学習させたい

asuka_wataki

総合スコア6

機械学習

機械学習は、データからパターンを自動的に発見し、そこから知能的な判断を下すためのコンピューターアルゴリズムを指します。人工知能における課題のひとつです。

0グッド

0クリップ

投稿2021/02/18 08:42

編集2021/02/18 08:43

質問内容

現在,pix2pixHDを自分のデータセットを作成し,学習させています.
その際に,学習したモデルの保存を行っているのですが,その学習されたモデルデータを再度使って学習を続きから行いたいのですが検索してもやり方がわからなくて困っています

環境

・Google Colaboratory
・公式がdominateというライブラリをインストールしろと言っていたのでinstall
・GPUはtensorflow-gpu 2.4.1を使用

問題解決のために行ったこと

GAN 学習モデル 再学習,継続学習
などで検索を行いました.

使用したコード

https://github.com/NVIDIA/pix2pixHD
をそのままcloneしてきました.改良などは行っていません.

train.pyを貼っておきます.

Python

1import time 2import os 3import numpy as np 4import torch 5from torch.autograd import Variable 6from collections import OrderedDict 7from subprocess import call 8import fractions 9def lcm(a,b): return abs(a * b)/fractions.gcd(a,b) if a and b else 0 10 11from options.train_options import TrainOptions 12from data.data_loader import CreateDataLoader 13from models.models import create_model 14import util.util as util 15from util.visualizer import Visualizer 16 17opt = TrainOptions().parse() 18iter_path = os.path.join(opt.checkpoints_dir, opt.name, 'iter.txt') 19if opt.continue_train: 20 try: 21 start_epoch, epoch_iter = np.loadtxt(iter_path , delimiter=',', dtype=int) 22 except: 23 start_epoch, epoch_iter = 1, 0 24 print('Resuming from epoch %d at iteration %d' % (start_epoch, epoch_iter)) 25else: 26 start_epoch, epoch_iter = 1, 0 27 28opt.print_freq = lcm(opt.print_freq, opt.batchSize) 29if opt.debug: 30 opt.display_freq = 1 31 opt.print_freq = 1 32 opt.niter = 1 33 opt.niter_decay = 0 34 opt.max_dataset_size = 10 35 36data_loader = CreateDataLoader(opt) 37dataset = data_loader.load_data() 38dataset_size = len(data_loader) 39print('#training images = %d' % dataset_size) 40 41model = create_model(opt) 42visualizer = Visualizer(opt) 43if opt.fp16: 44 from apex import amp 45 model, [optimizer_G, optimizer_D] = amp.initialize(model, [model.optimizer_G, model.optimizer_D], opt_level='O1') 46 model = torch.nn.DataParallel(model, device_ids=opt.gpu_ids) 47else: 48 optimizer_G, optimizer_D = model.module.optimizer_G, model.module.optimizer_D 49 50total_steps = (start_epoch-1) * dataset_size + epoch_iter 51 52display_delta = total_steps % opt.display_freq 53print_delta = total_steps % opt.print_freq 54save_delta = total_steps % opt.save_latest_freq 55 56for epoch in range(start_epoch, opt.niter + opt.niter_decay + 1): 57 epoch_start_time = time.time() 58 if epoch != start_epoch: 59 epoch_iter = epoch_iter % dataset_size 60 for i, data in enumerate(dataset, start=epoch_iter): 61 if total_steps % opt.print_freq == print_delta: 62 iter_start_time = time.time() 63 total_steps += opt.batchSize 64 epoch_iter += opt.batchSize 65 66 # whether to collect output images 67 save_fake = total_steps % opt.display_freq == display_delta 68 69 ############## Forward Pass ###################### 70 losses, generated = model(Variable(data['label']), Variable(data['inst']), 71 Variable(data['image']), Variable(data['feat']), infer=save_fake) 72 73 # sum per device losses 74 losses = [ torch.mean(x) if not isinstance(x, int) else x for x in losses ] 75 loss_dict = dict(zip(model.module.loss_names, losses)) 76 77 # calculate final loss scalar 78 loss_D = (loss_dict['D_fake'] + loss_dict['D_real']) * 0.5 79 loss_G = loss_dict['G_GAN'] + loss_dict.get('G_GAN_Feat',0) + loss_dict.get('G_VGG',0) 80 81 ############### Backward Pass #################### 82 # update generator weights 83 optimizer_G.zero_grad() 84 if opt.fp16: 85 with amp.scale_loss(loss_G, optimizer_G) as scaled_loss: scaled_loss.backward() 86 else: 87 loss_G.backward() 88 optimizer_G.step() 89 90 # update discriminator weights 91 optimizer_D.zero_grad() 92 if opt.fp16: 93 with amp.scale_loss(loss_D, optimizer_D) as scaled_loss: scaled_loss.backward() 94 else: 95 loss_D.backward() 96 optimizer_D.step() 97 98 ############## Display results and errors ########## 99 ### print out errors 100 if total_steps % opt.print_freq == print_delta: 101 errors = {k: v.data.item() if not isinstance(v, int) else v for k, v in loss_dict.items()} 102 t = (time.time() - iter_start_time) / opt.print_freq 103 visualizer.print_current_errors(epoch, epoch_iter, errors, t) 104 visualizer.plot_current_errors(errors, total_steps) 105 #call(["nvidia-smi", "--format=csv", "--query-gpu=memory.used,memory.free"]) 106 107 ### display output images 108 if save_fake: 109 visuals = OrderedDict([('input_label', util.tensor2label(data['label'][0], opt.label_nc)), 110 ('synthesized_image', util.tensor2im(generated.data[0])), 111 ('real_image', util.tensor2im(data['image'][0]))]) 112 visualizer.display_current_results(visuals, epoch, total_steps) 113 114 ### save latest model 115 if total_steps % opt.save_latest_freq == save_delta: 116 print('saving the latest model (epoch %d, total_steps %d)' % (epoch, total_steps)) 117 model.module.save('latest') 118 np.savetxt(iter_path, (epoch, epoch_iter), delimiter=',', fmt='%d') 119 120 if epoch_iter >= dataset_size: 121 break 122 123 # end of epoch 124 iter_end_time = time.time() 125 print('End of epoch %d / %d \t Time Taken: %d sec' % 126 (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time)) 127 128 ### save model for this epoch 129 if epoch % opt.save_epoch_freq == 0: 130 print('saving the model at the end of epoch %d, iters %d' % (epoch, total_steps)) 131 model.module.save('latest') 132 model.module.save(epoch) 133 np.savetxt(iter_path, (epoch+1, 0), delimiter=',', fmt='%d') 134 135 ### instead of only training the local enhancer, train the entire network after certain iterations 136 if (opt.niter_fix_global != 0) and (epoch == opt.niter_fix_global): 137 model.module.update_fixed_params() 138 139 ### linearly decay learning rate after certain iterations 140 if epoch > opt.niter: 141 model.module.update_learning_rate() 142

最後に

追記したほうがいいことがあったらコメントお願いします.
ほぼ初心者に近い者なので詳しく書いたりこのサイト勉強になるとか載せていただけると嬉しいです

気になる質問をクリップする

クリップした質問は、後からいつでもMYページで確認できます。

またクリップした質問に回答があった際、通知やメールを受け取ることができます。

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

guest

回答1

0

自己解決

追加モジュールがありました

投稿2021/03/29 07:00

編集2021/03/29 07:06
asuka_wataki

総合スコア6

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

あなたの回答

tips

太字

斜体

打ち消し線

見出し

引用テキストの挿入

コードの挿入

リンクの挿入

リストの挿入

番号リストの挿入

表の挿入

水平線の挿入

プレビュー

15分調べてもわからないことは
teratailで質問しよう!

ただいまの回答率
85.35%

質問をまとめることで
思考を整理して素早く解決

テンプレート機能で
簡単に質問をまとめる

質問する

関連した質問