【Pytorch】RuntimeError: stack expects each tensor to be equal size, but got〜
受付中
回答 0
投稿
- 評価
- クリップ 0
- VIEW 127
前提・実現したいこと
pytorchで物体検出の学習を行おうとしたところ、学習部分で以下のようなエラーが発生しました。エラーとして表示されるテンソルのサイズが毎回異なることと、kaggleのnotebookを利用すると同じデータでもうまくいくことなどから、テンソルのサイズが異なっているというわけではなく、何か別の原因があるように感じています。
原因について何かお心当たりのある方がいらっしゃいましたらぜひご回答いただけますと幸いです。
発生している問題・エラーメッセージ
RuntimeError Traceback (most recent call last)
<ipython-input-31-5bb99a4def7d> in <module>
95 manager.extend(save_ema_model, trigger=(flags.snapshot_freq, "epoch"))
96
---> 97 _ = trainer.run(train_loader, max_epochs=epoch)
~/home/user/libraries/Miniconda3/envs/pytorch/lib/python3.9/site-packages/ignite/engine/engine.py in run(self, data, max_epochs, epoch_length, seed)
700
701 self.state.dataloader = data
--> 702 return self._internal_run()
703
704 @staticmethod
~/home/user/libraries/Miniconda3/envs/pytorch/lib/python3.9/site-packages/ignite/engine/engine.py in _internal_run(self)
773 self._dataloader_iter = None
774 self.logger.error(f"Engine run is terminating due to exception: {e}")
--> 775 self._handle_exception(e)
776
777 self._dataloader_iter = None
~/home/user/libraries/Miniconda3/envs/pytorch/lib/python3.9/site-packages/ignite/engine/engine.py in _handle_exception(self, e)
467 self._fire_event(Events.EXCEPTION_RAISED, e)
468 else:
--> 469 raise e
470
471 @property
~/home/user/libraries/Miniconda3/envs/pytorch/lib/python3.9/site-packages/ignite/engine/engine.py in _internal_run(self)
743 self._setup_engine()
744
--> 745 time_taken = self._run_once_on_dataset()
746 # time is available for handlers but must be update after fire
747 self.state.times[Events.EPOCH_COMPLETED.name] = time_taken
~/home/user/libraries/Miniconda3/envs/pytorch/lib/python3.9/site-packages/ignite/engine/engine.py in _run_once_on_dataset(self)
848 except Exception as e:
849 self.logger.error(f"Current run is terminating due to exception: {e}")
--> 850 self._handle_exception(e)
851
852 return time.time() - start_time
~/home/user/libraries/Miniconda3/envs/pytorch/lib/python3.9/site-packages/ignite/engine/engine.py in _handle_exception(self, e)
467 self._fire_event(Events.EXCEPTION_RAISED, e)
468 else:
--> 469 raise e
470
471 @property
~/home/user/libraries/Miniconda3/envs/pytorch/lib/python3.9/site-packages/ignite/engine/engine.py in _run_once_on_dataset(self)
799 if self.last_event_name != Events.DATALOADER_STOP_ITERATION:
800 self._fire_event(Events.GET_BATCH_STARTED)
--> 801 self.state.batch = next(self._dataloader_iter)
802 self._fire_event(Events.GET_BATCH_COMPLETED)
803 iter_counter += 1
~/home/user/libraries/Miniconda3/envs/pytorch/lib/python3.9/site-packages/torch/utils/data/dataloader.py in __next__(self)
433 if self._sampler_iter is None:
434 self._reset()
--> 435 data = self._next_data()
436 self._num_yielded += 1
437 if self._dataset_kind == _DatasetKind.Iterable and \
~/home/user/libraries/Miniconda3/envs/pytorch/lib/python3.9/site-packages/torch/utils/data/dataloader.py in _next_data(self)
1083 else:
1084 del self._task_info[idx]
-> 1085 return self._process_data(data)
1086
1087 def _try_put_index(self):
~/home/user/libraries/Miniconda3/envs/pytorch/lib/python3.9/site-packages/torch/utils/data/dataloader.py in _process_data(self, data)
1109 self._try_put_index()
1110 if isinstance(data, ExceptionWrapper):
-> 1111 data.reraise()
1112 return data
1113
~/home/user/libraries/Miniconda3/envs/pytorch/lib/python3.9/site-packages/torch/_utils.py in reraise(self)
426 # have message field
427 raise self.exc_type(message=msg)
--> 428 raise self.exc_type(msg)
429
430
RuntimeError: Caught RuntimeError in DataLoader worker process 0.
Original Traceback (most recent call last):
File "/home/user/home/user/libraries/Miniconda3/envs/pytorch/lib/python3.9/site-packages/torch/utils/data/_utils/worker.py", line 198, in _worker_loop
data = fetcher.fetch(index)
File "/home/user/home/user/libraries/Miniconda3/envs/pytorch/lib/python3.9/site-packages/torch/utils/data/_utils/fetch.py", line 47, in fetch
return self.collate_fn(data)
File "/home/user/home/user/libraries/Miniconda3/envs/pytorch/lib/python3.9/site-packages/torch/utils/data/_utils/collate.py", line 83, in default_collate
return [default_collate(samples) for samples in transposed]
File "/home/user/home/user/libraries/Miniconda3/envs/pytorch/lib/python3.9/site-packages/torch/utils/data/_utils/collate.py", line 83, in <listcomp>
return [default_collate(samples) for samples in transposed]
File "/home/user/home/user/libraries/Miniconda3/envs/pytorch/lib/python3.9/site-packages/torch/utils/data/_utils/collate.py", line 55, in default_collate
return torch.stack(batch, 0, out=out)
RuntimeError: stack expects each tensor to be equal size, but got [3, 512, 499] at entry 0 and [3, 512, 458] at entry 1 #この部分の数字は実行のたびに異なるものが表示されます
該当のソースコード(コードが長いため、エラーが発生した部分のみ抜粋しております)
train_loader = DataLoader(
train_dataset,
batch_size=flags.batchsize,
num_workers=flags.num_workers,
shuffle=True,
pin_memory=True,
)
valid_loader = DataLoader(
valid_dataset,
batch_size=flags.valid_batchsize,
num_workers=flags.num_workers,
shuffle=False,
pin_memory=True,
)
device = torch.device(flags.device)
predictor = build_predictor(model_name=flags.model_name, model_mode=flags.model_mode)
classifier = Classifier(predictor)
model = classifier
# optimizer = optim.Adam(model.parameters(), lr=1e-3)
optimizer = optim.Adam([param for param in model.parameters() if param.requires_grad], lr=1e-3)
# Train setup
trainer = create_trainer(model, optimizer, device)
ema = EMA(predictor, decay=flags.ema_decay)
def eval_func(*batch):
loss, metrics = model(*[elem.to(device) for elem in batch])
# HACKING: report ema value with prefix.
if flags.ema_decay > 0:
classifier.prefix = "ema_"
ema.assign()
loss, metrics = model(*[elem.to(device) for elem in batch])
ema.resume()
classifier.prefix = ""
valid_evaluator = E.Evaluator(
valid_loader, model, progress_bar=False, eval_func=eval_func, device=device
)
# log_trigger = (10 if debug else 1000, "iteration")
log_trigger = (1, "epoch")
log_report = E.LogReport(trigger=log_trigger)
extensions = [
log_report,
E.ProgressBarNotebook(update_interval=10 if debug else 100), # Show progress bar during training
E.PrintReportNotebook(), # Show "log" on jupyter notebook
# E.ProgressBar(update_interval=10 if debug else 100), # Show progress bar during training
# E.PrintReport(), # Print "log" to terminal
E.FailOnNonNumber(), # Stop training when nan is detected.
]
epoch = flags.epoch
models = {"main": model}
optimizers = {"main": optimizer}
manager = IgniteExtensionsManager(
trainer, models, optimizers, epoch, extensions=extensions, out_dir=str(outdir),
)
# Run evaluation for valid dataset in each epoch.
manager.extend(valid_evaluator)
# Save predictor.pt every epoch
manager.extend(
E.snapshot_object(predictor, "predictor.pt"), trigger=(flags.snapshot_freq, "epoch")
)
# Check & Save best validation predictor.pt every epoch
# manager.extend(E.snapshot_object(predictor, "best_predictor.pt"),
# trigger=MinValueTrigger("validation/module/nll",
# trigger=(flags.snapshot_freq, "iteration")))
# --- lr scheduler ---
if flags.scheduler_type != "":
scheduler_type = flags.scheduler_type
print(f"using {scheduler_type} scheduler with kwargs {flags.scheduler_kwargs}")
manager.extend(
LRScheduler(optimizer, scheduler_type, flags.scheduler_kwargs),
trigger=flags.scheduler_trigger,
)
manager.extend(E.observe_lr(optimizer=optimizer), trigger=log_trigger)
if flags.ema_decay > 0:
# Exponential moving average
manager.extend(lambda manager: ema(), trigger=(1, "iteration"))
def save_ema_model(manager):
ema.assign()
#torch.save(predictor.state_dict(), outdir / "predictor_ema.pt")
torch.save(predictor.state_dict(), "/result/predictor_ema.pt")
ema.resume()
manager.extend(save_ema_model, trigger=(flags.snapshot_freq, "epoch"))
_ = trainer.run(train_loader, max_epochs=epoch)
試したこと
pytorchやcudaの再インストール
補足情報(FW/ツールのバージョンなど)
pytorch: 1.7.1
cuda: 10.2
ubuntu 16.04
-
気になる質問をクリップする
クリップした質問は、後からいつでもマイページで確認できます。
またクリップした質問に回答があった際、通知やメールを受け取ることができます。
クリップを取り消します
-
良い質問の評価を上げる
以下のような質問は評価を上げましょう
- 質問内容が明確
- 自分も答えを知りたい
- 質問者以外のユーザにも役立つ
評価が高い質問は、TOPページの「注目」タブのフィードに表示されやすくなります。
質問の評価を上げたことを取り消します
-
評価を下げられる数の上限に達しました
評価を下げることができません
- 1日5回まで評価を下げられます
- 1日に1ユーザに対して2回まで評価を下げられます
質問の評価を下げる
teratailでは下記のような質問を「具体的に困っていることがない質問」、「サイトポリシーに違反する質問」と定義し、推奨していません。
- プログラミングに関係のない質問
- やってほしいことだけを記載した丸投げの質問
- 問題・課題が含まれていない質問
- 意図的に内容が抹消された質問
- 過去に投稿した質問と同じ内容の質問
- 広告と受け取られるような投稿
評価が下がると、TOPページの「アクティブ」「注目」タブのフィードに表示されにくくなります。
質問の評価を下げたことを取り消します
この機能は開放されていません
評価を下げる条件を満たしてません
質問の評価を下げる機能の利用条件
この機能を利用するためには、以下の事項を行う必要があります。
- 質問回答など一定の行動
-
メールアドレスの認証
メールアドレスの認証
-
質問評価に関するヘルプページの閲覧
質問評価に関するヘルプページの閲覧
まだ回答がついていません
15分調べてもわからないことは、teratailで質問しよう!
- ただいまの回答率 88.21%
- 質問をまとめることで、思考を整理して素早く解決
- テンプレート機能で、簡単に質問をまとめられる