google colaboratory上でシカとカバとウマ画像を用いて画像分類の学習をするプログラムを作っているのですが、プログラムを実行するとクラッシュします。エラーも出ないので困っています。どうしたら直せるのでしょうか。
ファイルの配列は
net.py(下のコード)
train(学習用画像ファイル)
|----Deers_train(シカの学習用画像224枚)
|
|----Horse_train(ウマの学習用画像326枚)
|
|----Hippo_train(カバの学習用画像429枚)
としています
python
1import chainer 2import os 3import glob 4from itertools import chain 5from chainer.datasets import LabeledImageDataset 6from chainer import iterators,training,optimizers,datasets,serializers 7from chainer.training import extensions,triggers 8from chainer.dataset import concat_examples 9from chainercv.transforms import resize 10from chainer.datasets import TransformDataset 11import chainer.functions as F 12import chainer.links as L 13 14chainer.config.train = True 15 16class MyChain(chainer.Chain): 17 18 def __init__(self): 19 super(MyChain,self).__init__() 20 with self.init_scope(): 21 self.conv1 = L.Convolution2D(None,16,3,pad=2) 22 self.conv2 = L.Convolution2D(None,32,3,pad=2) 23 self.l3 = L.Linear(None,256) 24 self.l4 = L.Linear(None,3) 25 26 def __call__(self,x): 27 h = F.max_pooling_2d(F.relu(self.conv1(x)),ksize=5,stride=2,pad=2) 28 h = F.max_pooling_2d(F.relu(self.conv2(x)),ksize=5,stride=2,pad=2) 29 h = F.dropout(F.relu(self.l3(h))) 30 y = self.l4(h) 31 return y 32 33#img----------- 34 35IMG_TRA = 'train' 36 37dnames = glob.glob('{}/*'.format(IMG_TRA)) 38 39fnames = [glob.glob('{}/*.jpg'.format(d)) for d in dnames] 40fnames = list(chain.from_iterable(fnames)) 41 42labels = [os.path.basename(os.path.dirname(fn)) for fn in fnames] 43dname = [os.path.basename(d) for d in dnames] 44labels = [dname.index(l) for l in labels] 45d = LabeledImageDataset(list(zip(fnames,labels))) 46 47def transform(data): 48 img,label = data 49 img = resize(img,(500,500)) 50 return img,label 51 52train = chainer.datasets.TransformDataset(d,transform) 53 54 55#train--------- 56 57epoch = 10 58batch = 5 59 60model = L.Classifier(MyChain()) 61optimizer = optimizers.Adam() 62optimizer.setup(model) 63 64train_iter = iterators.SerialIterator(train,batch) 65updater = training.StandardUpdater(train_iter,optimizer) 66trainer = training.Trainer(updater,(epoch,'epoch'),out='result') 67 68trainer.extend(extensions.dump_graph('main/loss')) 69trainer.extend(extensions.snapshot(),trigger=(epoch,'epoch')) 70trainer.extend(extensions.LogReport()) 71trainer.extend(extensions.PrintReport(['epoch','main/loss','main/accuracy'])) 72trainer.extend(extensions.ProgressBar()) 73trainer.extend(extensions.PlotReport(['main/loss'],'epoch',file_name='loss.png')) 74trainer.extend(extensions.PlotReport(['main/accuracy'],'epoch',file_name='accuracy.png')) 75 76trainer.run() 77 78serializer.save_npz("mymodel.npz",model) 79
回答2件
あなたの回答
tips
プレビュー