python3.6
1#!/usr/bin/env python 2"""Example code of learning a large scale convnet from ILSVRC2012 dataset. 3Prerequisite: To run this example, crop the center of ILSVRC2012 training and 4validation images, scale them to 256x256 and convert them to RGB, and make 5two lists of space-separated CSV whose first column is full path to image and 6second column is zero-origin label (this format is same as that used by Caffe's 7ImageDataLayer). 8""" 9from __future__ import print_function 10import argparse 11import random 12 13import numpy as np 14 15import chainer 16from chainer import training 17from chainer.training import extensions 18 19import alex 20import googlenet 21import googlenetbn 22import nin 23import resnet50 24 25 26class PreprocessedDataset(chainer.dataset.DatasetMixin): 27 28 def __init__(self, path, root, mean, crop_size, random=True): 29 self.base = chainer.datasets.LabeledImageDataset(path, root) 30 self.mean = mean.astype('f') 31 self.crop_size = crop_size 32 self.random = random 33 34 def __len__(self): 35 return len(self.base) 36 37 def get_example(self, i): 38 # It reads the i-th image/label pair and return a preprocessed image. 
39 # It applies following preprocesses: 40 # - Cropping (random or center rectangular) 41 # - Random flip 42 # - Scaling to [0, 1] value 43 crop_size = self.crop_size 44 45 image, label = self.base[i] 46 _, h, w = image.shape 47 48 if self.random: 49 # Randomly crop a region and flip the image 50 top = random.randint(0, h - crop_size - 1) 51 left = random.randint(0, w - crop_size - 1) 52 if random.randint(0, 1): 53 image = image[:, :, ::-1] 54 else: 55 # Crop the center 56 top = (h - crop_size) // 2 57 left = (w - crop_size) // 2 58 bottom = top + crop_size 59 right = left + crop_size 60 61 image = image[:, top:bottom, left:right] 62 image -= self.mean[:, top:bottom, left:right] 63 image *= (1.0 / 255.0) # Scale to [0, 1] 64 return image, label 65 66 67def main(): 68 archs = { 69 'alex': alex.Alex, 70 'alex_fp16': alex.AlexFp16, 71 'googlenet': googlenet.GoogLeNet, 72 'googlenetbn': googlenetbn.GoogLeNetBN, 73 'googlenetbn_fp16': googlenetbn.GoogLeNetBNFp16, 74 'nin': nin.NIN, 75 'resnet50': resnet50.ResNet50 76 } 77 78 parser = argparse.ArgumentParser( 79 description='Learning convnet from ILSVRC2012 dataset') 80 parser.add_argument('train', help='Path to training image-label list file') 81 parser.add_argument('val', help='Path to validation image-label list file') 82 parser.add_argument('--arch', '-a', choices=archs.keys(), default='nin', 83 help='Convnet architecture') 84 parser.add_argument('--batchsize', '-B', type=int, default=32, 85 help='Learning minibatch size') 86 parser.add_argument('--epoch', '-E', type=int, default=50, 87 help='Number of epochs to train') 88 parser.add_argument('--gpu', '-g', type=int, default=-1, 89 help='GPU ID (negative value indicates CPU') 90 parser.add_argument('--initmodel', 91 help='Initialize the model from given file') 92 parser.add_argument('--loaderjob', '-j', type=int, 93 help='Number of parallel data loading processes') 94 parser.add_argument('--mean', '-m', default='mean.npy', 95 help='Mean file (computed by 
compute_mean.py)') 96 parser.add_argument('--resume', '-r', default='', 97 help='Initialize the trainer from given file') 98 parser.add_argument('--out', '-o', default='result', 99 help='Output directory') 100 parser.add_argument('--root', '-R', default='.', 101 help='Root directory path of image files') 102 parser.add_argument('--val_batchsize', '-b', type=int, default=250, 103 help='Validation minibatch size') 104 parser.add_argument('--test', action='store_true') 105 parser.set_defaults(test=False) 106 args = parser.parse_args() 107 108 # Initialize the model to train 109 model = archs[args.arch]() 110 if args.initmodel: 111 print('Load model from', args.initmodel) 112 chainer.serializers.load_npz(args.initmodel, model) 113 if args.gpu >= 0: 114 chainer.cuda.get_device_from_id(args.gpu).use() # Make the GPU current 115 model.to_gpu() 116 117 # Load the datasets and mean file 118 mean = np.load(args.mean) 119 train = PreprocessedDataset(args.train, args.root, mean, model.insize) 120 val = PreprocessedDataset(args.val, args.root, mean, model.insize, False) 121 # These iterators load the images with subprocesses running in parallel to 122 # the training/validation. 
123 train_iter = chainer.iterators.MultiprocessIterator( 124 train, args.batchsize, n_processes=args.loaderjob) 125 val_iter = chainer.iterators.MultiprocessIterator( 126 val, args.val_batchsize, repeat=False, n_processes=args.loaderjob) 127 128 # Set up an optimizer 129 optimizer = chainer.optimizers.MomentumSGD(lr=0.01, momentum=0.9) 130 optimizer.setup(model) 131 132 # Set up a trainer 133 updater = training.StandardUpdater(train_iter, optimizer, device=args.gpu) 134 trainer = training.Trainer(updater, (args.epoch, 'epoch'), args.out) 135 136 val_interval = (10 if args.test else 100000), 'iteration' 137 log_interval = (10 if args.test else 1000), 'iteration' 138 139 trainer.extend(extensions.Evaluator(val_iter, model, device=args.gpu), 140 trigger=val_interval) 141 trainer.extend(extensions.dump_graph('main/loss')) 142 trainer.extend(extensions.snapshot(), trigger=val_interval) 143 trainer.extend(extensions.snapshot_object( 144 model, 'model_iter_{.updater.iteration}'), trigger=val_interval) 145 # Be careful to pass the interval directly to LogReport 146 # (it determines when to emit log rather than when to read observations) 147 trainer.extend(extensions.LogReport(trigger=log_interval)) 148 trainer.extend(extensions.observe_lr(), trigger=log_interval) 149 trainer.extend(extensions.PrintReport([ 150 'epoch', 'iteration', 'main/loss', 151 'main/accuracy', 'lr' 152 ]), trigger=log_interval) 153 trainer.extend( 154 extensions.PlotReport( 155 ['main/accuracy'], 156 'iteration', trigger = (500, 'iteration'), file_name='accuracy.png')) 157 trainer.extend(extensions.ProgressBar(update_interval=10)) 158 159 if args.resume: 160 chainer.serializers.load_npz(args.resume, trainer) 161 162 trainer.run() 163 164 chainer.serializers.save_hdf5("gpu1out.h5", model) 165 166 167if __name__ == '__main__': 168 main()
質問を修正して、コードブロックは ```コード``` のように ``` で囲ってください。
失礼致しました。
>作業依頼のような投稿をして課題や仕事を無償でやってもらう場所ではありません。 https://teratail.com/help#about-teratail
コードはご自身で作られたものではなく、どこかから持ってきたものでしょうか? Chainer は使わないのでわからないですが、推論するだけなので、こちら https://qiita.com/mitmul/items/1e35fba085eb07a92560 の MNIST のチュートリアルの「保存したモデルを読み込んで推論しよう」というところでカバーされている内容だと思いますが、どうでしょうか?
URL添付ありがとうございます。
あなたの回答
tips
プレビュー