Python 2.7
Chainer 3.3.0
ChainerCV 0.8.0
Ubuntu 14.04 (64bit)
現在、ChainerのTrainer機能を使用した学習を考えており、下記の図のように新しいEpochが始まるたびにTrainingデータからRandom cropで画像を切り出し、それをそのEpochのデータセットにしようと考えています。(Validationデータにはこの処理を適用しません)
Epoch 1 Random cropしてデータセットを作成
・
・
・
Epoch 2 Random cropしてデータセットを作成
Chainerのサンプルコードを元に下記のプログラムを作成してみたのですが、上記のような処理が出来ず試行錯誤しています。
もし上記の処理が可能であるならば、アドバイスをよろしくお願いします。
import argparse

import chainer
import numpy as np
from PIL import Image
from chainer.datasets import TransformDataset
from chainer import iterators
from chainer import optimizers
from chainer import training
from chainer.training import extensions
from chainercv import transforms
from chainercv.datasets import camvid_label_names
from chainercv.datasets import CamVidDataset
from chainercv.extensions import SemanticSegmentationEvaluator
from chainercv.links import PixelwiseSoftmaxClassifier

from model import Integrating_Multiple_Deep_Networks


def _fit_crop_size(img, crop_size):
    """Clamp ``crop_size`` so it never exceeds the spatial size of ``img``.

    ``img`` is read as (channel, height, width), i.e. ``img.shape[1]`` is the
    height and ``img.shape[2]`` the width. Each axis is clamped on its own, so
    an image that is only too short (or only too narrow) keeps the requested
    size on the other axis.

    Returns:
        tuple: ``(crop_height, crop_width)``.
    """
    height, width = img.shape[1], img.shape[2]
    return (min(height, crop_size[0]), min(width, crop_size[1]))


def train_transform(in_data, crop_size=(240, 320)):
    """Randomly crop and horizontally flip one (image, label) training pair.

    ``TransformDataset`` applies its transform lazily, each time an example is
    fetched. A new crop window and flip are therefore drawn for every image in
    every epoch. The per-epoch "random crop the training set" behaviour needs
    no extra code.

    Args:
        in_data: ``(img, label)``. ``img`` is indexed (C, H, W) and ``label``
            (H, W), as the slicing below assumes.
        crop_size: ``(height, width)`` of the crop; clamped to the image size.

    Returns:
        tuple: the cropped/flipped ``(img, label)``.
    """
    img, label = in_data
    img, param = transforms.random_crop(
        img, _fit_crop_size(img, crop_size), return_param=True)
    # Cut the identical window out of the label so pixels stay aligned.
    label = label[param['y_slice'], param['x_slice']]
    # One width-axis flip with probability 0.5. A second independent
    # ``[:, :, ::-1]`` flip would act on the same axis and only re-randomize
    # this one, so it is not applied.
    if np.random.rand() > 0.5:
        img = transforms.flip(img, x_flip=True)
        label = transforms.flip(label[None, ...], x_flip=True)[0]
    return img, label


def val_transform(in_data, crop_size=(240, 320)):
    """Deterministically center-crop one (image, label) validation pair.

    A fixed (center) window keeps validation loss / mIoU comparable between
    epochs. A random window would add crop-position noise to the metrics.

    Args:
        in_data: ``(img, label)`` with ``img`` indexed (C, H, W) and ``label``
            (H, W).
        crop_size: ``(height, width)`` of the crop; clamped to the image size.

    Returns:
        tuple: the cropped ``(img, label)``.
    """
    img, label = in_data
    img, param = transforms.center_crop(
        img, _fit_crop_size(img, crop_size), return_param=True)
    label = label[param['y_slice'], param['x_slice']]
    return img, label


def main():
    """Parse CLI options, build the CamVid pipeline and run the Trainer."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu', type=int, default=0)
    parser.add_argument('--batchsize', type=int, default=2)
    parser.add_argument('--class_weight', type=str, default='class_weight.npy')
    parser.add_argument('--out', type=str, default='result')
    args = parser.parse_args()

    # Triggers
    log_trigger = (1, 'epoch')
    validation_trigger = (1, 'epoch')
    end_trigger = (100, 'epoch')

    # Dataset. TransformDataset runs the transform on every access, so each
    # epoch sees freshly drawn random crops of the training images.
    train = TransformDataset(CamVidDataset(split='train'), train_transform)
    val = TransformDataset(CamVidDataset(split='val'), val_transform)

    # Iterators
    # NOTE(review): the random transforms draw from NumPy's global RNG inside
    # the iterator's worker processes; confirm this Chainer version seeds the
    # workers differently, otherwise workers may produce correlated crops.
    train_iter = iterators.MultiprocessIterator(train, args.batchsize)
    # Each evaluator gets its own iterator so their passes over the
    # validation set do not share iterator state.
    val_loss_iter = iterators.MultiprocessIterator(
        val, args.batchsize, shuffle=False, repeat=False)
    val_miou_iter = iterators.MultiprocessIterator(
        val, args.batchsize, shuffle=False, repeat=False)

    # Model
    class_weight = np.load(args.class_weight)
    model = Integrating_Multiple_Deep_Networks(n_ch=2, n_class=11, n_expt=3)
    model = PixelwiseSoftmaxClassifier(model, class_weight=class_weight)
    if args.gpu >= 0:
        # Make a specified GPU current
        chainer.cuda.get_device_from_id(args.gpu).use()
        model.to_gpu()  # Copy the model to the GPU

    # Optimizer
    optimizer = optimizers.Adam()
    optimizer.setup(model)

    # Updater
    updater = training.StandardUpdater(train_iter, optimizer, device=args.gpu)

    # Trainer
    trainer = training.Trainer(updater, end_trigger, out=args.out)

    trainer.extend(extensions.LogReport(trigger=log_trigger))
    if extensions.PlotReport.available():
        trainer.extend(extensions.PlotReport(
            ['main/loss', 'val/main/loss'], x_key='epoch',
            file_name='loss.png'))
        trainer.extend(extensions.PlotReport(
            ['validation/main/miou'], x_key='epoch', file_name='miou.png'))
    trainer.extend(extensions.snapshot_object(
        model.predictor, filename='model_epoch-{.updater.epoch}'),
        trigger=end_trigger)
    trainer.extend(extensions.PrintReport(
        ['epoch', 'main/loss', 'val/main/loss', 'validation/main/miou',
         'elapsed_time']),
        trigger=log_trigger)
    trainer.extend(extensions.ProgressBar(update_interval=10))
    # Reports 'val/main/loss'.
    trainer.extend(extensions.Evaluator(val_loss_iter, model, device=args.gpu),
                   name='val')
    # Reports 'validation/main/miou' (and other segmentation metrics).
    trainer.extend(
        SemanticSegmentationEvaluator(
            val_miou_iter, model.predictor, camvid_label_names),
        trigger=validation_trigger)

    trainer.run()


if __name__ == '__main__':
    main()
あなたの回答
tips
プレビュー