質問をすることでしか得られない、回答やアドバイスがある。

15分調べてもわからないことは、質問しよう!

ただいまの
回答率

88.21%

【Pytorch】RuntimeError: stack expects each tensor to be equal size, but got〜

受付中

回答 0

投稿

  • 評価
  • クリップ 0
  • VIEW 127

SuzuAya

score 52

前提・実現したいこと

pytorchで物体検出の学習を行おうとしたところ、学習部分で以下のようなエラーが発生しました。エラーとして表示されるテンソルのサイズが毎回異なることと、kaggleのnotebookを利用すると同じデータでもうまくいくことなどから、テンソルのサイズが異なっているというわけではなく、何か別の原因があるように感じています。
原因について何かお心当たりのある方がいらっしゃいましたらぜひご回答いただけますと幸いです。

発生している問題・エラーメッセージ

RuntimeError                              Traceback (most recent call last)
<ipython-input-31-5bb99a4def7d> in <module>
     95     manager.extend(save_ema_model, trigger=(flags.snapshot_freq, "epoch"))
     96 
---> 97 _ = trainer.run(train_loader, max_epochs=epoch)

~/home/user/libraries/Miniconda3/envs/pytorch/lib/python3.9/site-packages/ignite/engine/engine.py in run(self, data, max_epochs, epoch_length, seed)
    700 
    701         self.state.dataloader = data
--> 702         return self._internal_run()
    703 
    704     @staticmethod

~/home/user/libraries/Miniconda3/envs/pytorch/lib/python3.9/site-packages/ignite/engine/engine.py in _internal_run(self)
    773             self._dataloader_iter = None
    774             self.logger.error(f"Engine run is terminating due to exception: {e}")
--> 775             self._handle_exception(e)
    776 
    777         self._dataloader_iter = None

~/home/user/libraries/Miniconda3/envs/pytorch/lib/python3.9/site-packages/ignite/engine/engine.py in _handle_exception(self, e)
    467             self._fire_event(Events.EXCEPTION_RAISED, e)
    468         else:
--> 469             raise e
    470 
    471     @property

~/home/user/libraries/Miniconda3/envs/pytorch/lib/python3.9/site-packages/ignite/engine/engine.py in _internal_run(self)
    743                     self._setup_engine()
    744 
--> 745                 time_taken = self._run_once_on_dataset()
    746                 # time is available for handlers but must be update after fire
    747                 self.state.times[Events.EPOCH_COMPLETED.name] = time_taken

~/home/user/libraries/Miniconda3/envs/pytorch/lib/python3.9/site-packages/ignite/engine/engine.py in _run_once_on_dataset(self)
    848         except Exception as e:
    849             self.logger.error(f"Current run is terminating due to exception: {e}")
--> 850             self._handle_exception(e)
    851 
    852         return time.time() - start_time

~/home/user/libraries/Miniconda3/envs/pytorch/lib/python3.9/site-packages/ignite/engine/engine.py in _handle_exception(self, e)
    467             self._fire_event(Events.EXCEPTION_RAISED, e)
    468         else:
--> 469             raise e
    470 
    471     @property

~/home/user/libraries/Miniconda3/envs/pytorch/lib/python3.9/site-packages/ignite/engine/engine.py in _run_once_on_dataset(self)
    799                     if self.last_event_name != Events.DATALOADER_STOP_ITERATION:
    800                         self._fire_event(Events.GET_BATCH_STARTED)
--> 801                     self.state.batch = next(self._dataloader_iter)
    802                     self._fire_event(Events.GET_BATCH_COMPLETED)
    803                     iter_counter += 1

~/home/user/libraries/Miniconda3/envs/pytorch/lib/python3.9/site-packages/torch/utils/data/dataloader.py in __next__(self)
    433         if self._sampler_iter is None:
    434             self._reset()
--> 435         data = self._next_data()
    436         self._num_yielded += 1
    437         if self._dataset_kind == _DatasetKind.Iterable and \

~/home/user/libraries/Miniconda3/envs/pytorch/lib/python3.9/site-packages/torch/utils/data/dataloader.py in _next_data(self)
   1083             else:
   1084                 del self._task_info[idx]
-> 1085                 return self._process_data(data)
   1086 
   1087     def _try_put_index(self):

~/home/user/libraries/Miniconda3/envs/pytorch/lib/python3.9/site-packages/torch/utils/data/dataloader.py in _process_data(self, data)
   1109         self._try_put_index()
   1110         if isinstance(data, ExceptionWrapper):
-> 1111             data.reraise()
   1112         return data
   1113 

~/home/user/libraries/Miniconda3/envs/pytorch/lib/python3.9/site-packages/torch/_utils.py in reraise(self)
    426             # have message field
    427             raise self.exc_type(message=msg)
--> 428         raise self.exc_type(msg)
    429 
    430 

RuntimeError: Caught RuntimeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/user/home/user/libraries/Miniconda3/envs/pytorch/lib/python3.9/site-packages/torch/utils/data/_utils/worker.py", line 198, in _worker_loop
    data = fetcher.fetch(index)
  File "/home/user/home/user/libraries/Miniconda3/envs/pytorch/lib/python3.9/site-packages/torch/utils/data/_utils/fetch.py", line 47, in fetch
    return self.collate_fn(data)
  File "/home/user/home/user/libraries/Miniconda3/envs/pytorch/lib/python3.9/site-packages/torch/utils/data/_utils/collate.py", line 83, in default_collate
    return [default_collate(samples) for samples in transposed]
  File "/home/user/home/user/libraries/Miniconda3/envs/pytorch/lib/python3.9/site-packages/torch/utils/data/_utils/collate.py", line 83, in <listcomp>
    return [default_collate(samples) for samples in transposed]
  File "/home/user/home/user/libraries/Miniconda3/envs/pytorch/lib/python3.9/site-packages/torch/utils/data/_utils/collate.py", line 55, in default_collate
    return torch.stack(batch, 0, out=out)
RuntimeError: stack expects each tensor to be equal size, but got [3, 512, 499] at entry 0 and [3, 512, 458] at entry 1 #この部分の数字は実行のたびに異なるものが表示されます

該当のソースコード(コードが長いため、エラーが発生した部分のみ抜粋しております)

train_loader = DataLoader(
    train_dataset,
    batch_size=flags.batchsize,
    num_workers=flags.num_workers,
    shuffle=True,
    pin_memory=True,
)
valid_loader = DataLoader(
    valid_dataset,
    batch_size=flags.valid_batchsize,
    num_workers=flags.num_workers,
    shuffle=False,
    pin_memory=True,
)

device = torch.device(flags.device)

predictor = build_predictor(model_name=flags.model_name, model_mode=flags.model_mode)
classifier = Classifier(predictor)
model = classifier
# optimizer = optim.Adam(model.parameters(), lr=1e-3)
optimizer = optim.Adam([param for param in model.parameters() if param.requires_grad], lr=1e-3)

# Train setup
trainer = create_trainer(model, optimizer, device)

ema = EMA(predictor, decay=flags.ema_decay)

def eval_func(*batch):
    loss, metrics = model(*[elem.to(device) for elem in batch])
    # HACKING: report ema value with prefix.
    if flags.ema_decay > 0:
        classifier.prefix = "ema_"
        ema.assign()
        loss, metrics = model(*[elem.to(device) for elem in batch])
        ema.resume()
        classifier.prefix = ""

valid_evaluator = E.Evaluator(
    valid_loader, model, progress_bar=False, eval_func=eval_func, device=device
)

# log_trigger = (10 if debug else 1000, "iteration")
log_trigger = (1, "epoch")
log_report = E.LogReport(trigger=log_trigger)
extensions = [
    log_report,
    E.ProgressBarNotebook(update_interval=10 if debug else 100),  # Show progress bar during training
    E.PrintReportNotebook(),  # Show "log" on jupyter notebook  
    # E.ProgressBar(update_interval=10 if debug else 100),  # Show progress bar during training
    # E.PrintReport(),  # Print "log" to terminal
    E.FailOnNonNumber(),  # Stop training when nan is detected.
]
epoch = flags.epoch
models = {"main": model}
optimizers = {"main": optimizer}
manager = IgniteExtensionsManager(
    trainer, models, optimizers, epoch, extensions=extensions, out_dir=str(outdir),
)
# Run evaluation for valid dataset in each epoch.
manager.extend(valid_evaluator)

# Save predictor.pt every epoch
manager.extend(
    E.snapshot_object(predictor, "predictor.pt"), trigger=(flags.snapshot_freq, "epoch")
)
# Check & Save best validation predictor.pt every epoch
# manager.extend(E.snapshot_object(predictor, "best_predictor.pt"),
#                trigger=MinValueTrigger("validation/module/nll",
#                trigger=(flags.snapshot_freq, "iteration")))

# --- lr scheduler ---
if flags.scheduler_type != "":
    scheduler_type = flags.scheduler_type
    print(f"using {scheduler_type} scheduler with kwargs {flags.scheduler_kwargs}")
    manager.extend(
        LRScheduler(optimizer, scheduler_type, flags.scheduler_kwargs),
        trigger=flags.scheduler_trigger,
    )

manager.extend(E.observe_lr(optimizer=optimizer), trigger=log_trigger)

if flags.ema_decay > 0:
    # Exponential moving average
    manager.extend(lambda manager: ema(), trigger=(1, "iteration"))

    def save_ema_model(manager):
        ema.assign()
        #torch.save(predictor.state_dict(), outdir / "predictor_ema.pt")
        torch.save(predictor.state_dict(), "/result/predictor_ema.pt")
        ema.resume()

    manager.extend(save_ema_model, trigger=(flags.snapshot_freq, "epoch"))

_ = trainer.run(train_loader, max_epochs=epoch)

試したこと

pytorchやcudaの再インストール

補足情報(FW/ツールのバージョンなど)

pytorch: 1.7.1
cuda: 10.2
ubuntu 16.04

  • 気になる質問をクリップする

    クリップした質問は、後からいつでもマイページで確認できます。

    またクリップした質問に回答があった際、通知やメールを受け取ることができます。

    クリップを取り消します

  • 良い質問の評価を上げる

    以下のような質問は評価を上げましょう

    • 質問内容が明確
    • 自分も答えを知りたい
    • 質問者以外のユーザにも役立つ

    評価が高い質問は、TOPページの「注目」タブのフィードに表示されやすくなります。

    質問の評価を上げたことを取り消します

  • 評価を下げられる数の上限に達しました

    評価を下げることができません

    • 1日5回まで評価を下げられます
    • 1日に1ユーザに対して2回まで評価を下げられます

    質問の評価を下げる

    teratailでは下記のような質問を「具体的に困っていることがない質問」、「サイトポリシーに違反する質問」と定義し、推奨していません。

    • プログラミングに関係のない質問
    • やってほしいことだけを記載した丸投げの質問
    • 問題・課題が含まれていない質問
    • 意図的に内容が抹消された質問
    • 過去に投稿した質問と同じ内容の質問
    • 広告と受け取られるような投稿

    評価が下がると、TOPページの「アクティブ」「注目」タブのフィードに表示されにくくなります。

    質問の評価を下げたことを取り消します

    この機能は開放されていません

    評価を下げる条件を満たしてません

    評価を下げる理由を選択してください

    詳細な説明はこちら

    上記に当てはまらず、質問内容が明確になっていない質問には「情報の追加・修正依頼」機能からコメントをしてください。

    質問の評価を下げる機能の利用条件

    この機能を利用するためには、以下の事項を行う必要があります。

まだ回答がついていません

15分調べてもわからないことは、teratailで質問しよう!

  • ただいまの回答率 88.21%
  • 質問をまとめることで、思考を整理して素早く解決
  • テンプレート機能で、簡単に質問をまとめられる

関連した質問

同じタグがついた質問を見る