Prerequisites / what I want to achieve
When I tried to train an object detection model in PyTorch, the error below occurred during the training step. The tensor sizes reported in the error are different on every run, and the same data trains without problems in a Kaggle notebook, so I suspect the real cause is something other than a genuine tensor-size mismatch.
If anyone has an idea what might be causing this, I would be grateful for an answer.
Problem / error message
RuntimeError                              Traceback (most recent call last)
<ipython-input-31-5bb99a4def7d> in <module>
     95     manager.extend(save_ema_model, trigger=(flags.snapshot_freq, "epoch"))
     96
---> 97 _ = trainer.run(train_loader, max_epochs=epoch)

~/home/user/libraries/Miniconda3/envs/pytorch/lib/python3.9/site-packages/ignite/engine/engine.py in run(self, data, max_epochs, epoch_length, seed)
    700
    701         self.state.dataloader = data
--> 702         return self._internal_run()
    703
    704     @staticmethod

~/home/user/libraries/Miniconda3/envs/pytorch/lib/python3.9/site-packages/ignite/engine/engine.py in _internal_run(self)
    773             self._dataloader_iter = None
    774             self.logger.error(f"Engine run is terminating due to exception: {e}")
--> 775             self._handle_exception(e)
    776
    777         self._dataloader_iter = None

~/home/user/libraries/Miniconda3/envs/pytorch/lib/python3.9/site-packages/ignite/engine/engine.py in _handle_exception(self, e)
    467             self._fire_event(Events.EXCEPTION_RAISED, e)
    468         else:
--> 469             raise e
    470
    471     @property

~/home/user/libraries/Miniconda3/envs/pytorch/lib/python3.9/site-packages/ignite/engine/engine.py in _internal_run(self)
    743                 self._setup_engine()
    744
--> 745                 time_taken = self._run_once_on_dataset()
    746                 # time is available for handlers but must be update after fire
    747                 self.state.times[Events.EPOCH_COMPLETED.name] = time_taken

~/home/user/libraries/Miniconda3/envs/pytorch/lib/python3.9/site-packages/ignite/engine/engine.py in _run_once_on_dataset(self)
    848         except Exception as e:
    849             self.logger.error(f"Current run is terminating due to exception: {e}")
--> 850             self._handle_exception(e)
    851
    852         return time.time() - start_time

~/home/user/libraries/Miniconda3/envs/pytorch/lib/python3.9/site-packages/ignite/engine/engine.py in _handle_exception(self, e)
    467             self._fire_event(Events.EXCEPTION_RAISED, e)
    468         else:
--> 469             raise e
    470
    471     @property

~/home/user/libraries/Miniconda3/envs/pytorch/lib/python3.9/site-packages/ignite/engine/engine.py in _run_once_on_dataset(self)
    799                 if self.last_event_name != Events.DATALOADER_STOP_ITERATION:
    800                     self._fire_event(Events.GET_BATCH_STARTED)
--> 801                 self.state.batch = next(self._dataloader_iter)
    802                 self._fire_event(Events.GET_BATCH_COMPLETED)
    803                 iter_counter += 1

~/home/user/libraries/Miniconda3/envs/pytorch/lib/python3.9/site-packages/torch/utils/data/dataloader.py in __next__(self)
    433         if self._sampler_iter is None:
    434             self._reset()
--> 435         data = self._next_data()
    436         self._num_yielded += 1
    437         if self._dataset_kind == _DatasetKind.Iterable and \

~/home/user/libraries/Miniconda3/envs/pytorch/lib/python3.9/site-packages/torch/utils/data/dataloader.py in _next_data(self)
   1083                 else:
   1084                     del self._task_info[idx]
->  1085                     return self._process_data(data)
   1086
   1087     def _try_put_index(self):

~/home/user/libraries/Miniconda3/envs/pytorch/lib/python3.9/site-packages/torch/utils/data/dataloader.py in _process_data(self, data)
   1109         self._try_put_index()
   1110         if isinstance(data, ExceptionWrapper):
->  1111             data.reraise()
   1112         return data
   1113

~/home/user/libraries/Miniconda3/envs/pytorch/lib/python3.9/site-packages/torch/_utils.py in reraise(self)
    426             # have message field
    427             raise self.exc_type(message=msg)
--> 428         raise self.exc_type(msg)
    429
    430

RuntimeError: Caught RuntimeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/user/home/user/libraries/Miniconda3/envs/pytorch/lib/python3.9/site-packages/torch/utils/data/_utils/worker.py", line 198, in _worker_loop
    data = fetcher.fetch(index)
  File "/home/user/home/user/libraries/Miniconda3/envs/pytorch/lib/python3.9/site-packages/torch/utils/data/_utils/fetch.py", line 47, in fetch
    return self.collate_fn(data)
  File "/home/user/home/user/libraries/Miniconda3/envs/pytorch/lib/python3.9/site-packages/torch/utils/data/_utils/collate.py", line 83, in default_collate
    return [default_collate(samples) for samples in transposed]
  File "/home/user/home/user/libraries/Miniconda3/envs/pytorch/lib/python3.9/site-packages/torch/utils/data/_utils/collate.py", line 83, in <listcomp>
    return [default_collate(samples) for samples in transposed]
  File "/home/user/home/user/libraries/Miniconda3/envs/pytorch/lib/python3.9/site-packages/torch/utils/data/_utils/collate.py", line 55, in default_collate
    return torch.stack(batch, 0, out=out)
RuntimeError: stack expects each tensor to be equal size, but got [3, 512, 499] at entry 0 and [3, 512, 458] at entry 1  # The sizes shown in this line differ on every run.
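The traceback shows the failure happens inside PyTorch's default_collate, where torch.stack requires every image tensor in a batch to share one shape. As a first check, a small sketch like the one below (it assumes each train_dataset sample is a tuple whose first element is a [C, H, W] image tensor, which may not match the actual dataset) would confirm whether the transform pipeline really emits fixed-size tensors:

Python
# Debug sketch (assumption: each dataset sample is a tuple whose
# first element is the image tensor, shaped [C, H, W]).
for i in range(8):
    image = train_dataset[i][0]
    # Unequal (H, W) values here confirm a resize/crop issue
    # in the transform pipeline rather than in the DataLoader.
    print(i, tuple(image.shape))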
Relevant source code (the full code is long, so only the part where the error occurs is excerpted)
Python
train_loader = DataLoader(
    train_dataset,
    batch_size=flags.batchsize,
    num_workers=flags.num_workers,
    shuffle=True,
    pin_memory=True,
)
valid_loader = DataLoader(
    valid_dataset,
    batch_size=flags.valid_batchsize,
    num_workers=flags.num_workers,
    shuffle=False,
    pin_memory=True,
)

device = torch.device(flags.device)

predictor = build_predictor(model_name=flags.model_name, model_mode=flags.model_mode)
classifier = Classifier(predictor)
model = classifier
# optimizer = optim.Adam(model.parameters(), lr=1e-3)
optimizer = optim.Adam([param for param in model.parameters() if param.requires_grad], lr=1e-3)

# Train setup
trainer = create_trainer(model, optimizer, device)

ema = EMA(predictor, decay=flags.ema_decay)

def eval_func(*batch):
    loss, metrics = model(*[elem.to(device) for elem in batch])
    # HACKING: report ema value with prefix.
    if flags.ema_decay > 0:
        classifier.prefix = "ema_"
        ema.assign()
        loss, metrics = model(*[elem.to(device) for elem in batch])
        ema.resume()
        classifier.prefix = ""

valid_evaluator = E.Evaluator(
    valid_loader, model, progress_bar=False, eval_func=eval_func, device=device
)

# log_trigger = (10 if debug else 1000, "iteration")
log_trigger = (1, "epoch")
log_report = E.LogReport(trigger=log_trigger)
extensions = [
    log_report,
    E.ProgressBarNotebook(update_interval=10 if debug else 100),  # Show progress bar during training
    E.PrintReportNotebook(),  # Show "log" on jupyter notebook
    # E.ProgressBar(update_interval=10 if debug else 100),  # Show progress bar during training
    # E.PrintReport(),  # Print "log" to terminal
    E.FailOnNonNumber(),  # Stop training when nan is detected.
]
epoch = flags.epoch
models = {"main": model}
optimizers = {"main": optimizer}
manager = IgniteExtensionsManager(
    trainer, models, optimizers, epoch, extensions=extensions, out_dir=str(outdir),
)
# Run evaluation for valid dataset in each epoch.
manager.extend(valid_evaluator)

# Save predictor.pt every epoch
manager.extend(
    E.snapshot_object(predictor, "predictor.pt"), trigger=(flags.snapshot_freq, "epoch")
)
# Check & Save best validation predictor.pt every epoch
# manager.extend(E.snapshot_object(predictor, "best_predictor.pt"),
#                trigger=MinValueTrigger("validation/module/nll",
#                                        trigger=(flags.snapshot_freq, "iteration")))

# --- lr scheduler ---
if flags.scheduler_type != "":
    scheduler_type = flags.scheduler_type
    print(f"using {scheduler_type} scheduler with kwargs {flags.scheduler_kwargs}")
    manager.extend(
        LRScheduler(optimizer, scheduler_type, flags.scheduler_kwargs),
        trigger=flags.scheduler_trigger,
    )

manager.extend(E.observe_lr(optimizer=optimizer), trigger=log_trigger)

if flags.ema_decay > 0:
    # Exponential moving average
    manager.extend(lambda manager: ema(), trigger=(1, "iteration"))

    def save_ema_model(manager):
        ema.assign()
        # torch.save(predictor.state_dict(), outdir / "predictor_ema.pt")
        torch.save(predictor.state_dict(), "/result/predictor_ema.pt")
        ema.resume()

    manager.extend(save_ema_model, trigger=(flags.snapshot_freq, "epoch"))

_ = trainer.run(train_loader, max_epochs=epoch)
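If variable-size images are in fact expected (common in detection pipelines), a standard workaround is to replace default_collate with a custom collate_fn that keeps the samples as lists rather than stacking them, as torchvision's detection references do. A minimal sketch, assuming the trainer and model can consume list-typed batches (detection_collate_fn is an illustrative name, not part of the original code):

Python
def detection_collate_fn(batch):
    # Group the samples into per-field tuples instead of stacking,
    # so images with different (H, W) can coexist in one batch.
    return tuple(zip(*batch))

train_loader = DataLoader(
    train_dataset,
    batch_size=flags.batchsize,
    num_workers=flags.num_workers,
    shuffle=True,
    pin_memory=True,
    collate_fn=detection_collate_fn,  # replaces default_collate
)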
What I tried
Reinstalling PyTorch and CUDA
Supplementary information (framework/tool versions, etc.)
PyTorch: 1.7.1
CUDA: 10.2
Ubuntu 16.04
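Since the same data works in a Kaggle notebook, a library-version mismatch in the image/augmentation stack is one plausible suspect. Printing the versions in both environments makes the comparison easy; torchvision and albumentations below are only guesses at what the pipeline might use, so drop any that are not installed:

Python
# Run this in both environments (local vs. Kaggle) and compare.
import torch
print("torch:", torch.__version__)
import torchvision  # assumption: the predictor is torchvision-based
print("torchvision:", torchvision.__version__)
import albumentations  # assumption: used for augmentation/resizing
print("albumentations:", albumentations.__version__)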