前提・実現したいこと(+試したこと)
https://www.kaggle.com/sumantindurkhya/bert-for-regression/
のコードを実際に動かしてみたいのですが、下記のエラー(Can't pickle local object)が出力されどうすればよいか分かりません。
上記問題の解決法についてご知見を頂きたいと考えております。
コピー元のTrain Function(In[18]の6行目~)
for i, (input_ids, attention_mask, target) in enumerate(iterable=train_loader):
の部分でエラーが出ているようです。
作業環境
windows10, Anacondaを用いております。
試したこと
Kaggleのリンク先のコードをコピペして1つのpyファイルにして実行。
Windows環境のため、そのまま実行するとBroken pipeのエラーが起きたため、回避のため、
if name == 'main':
main()
と書き、コードはmain()の中にコピー&ペーストしております。
発生している問題・エラーメッセージ
Traceback (most recent call last): File "C:\Users\xxx\Desktop\bert_regression.py", line 285, in <module> main() File "C:\Users\xxx\Desktop\bert_regression.py", line 247, in main for i, (input_ids, attention_mask, target) in enumerate(iterable=train_loader): File "C:\Users\xxx\Anaconda3\envs\text\lib\site-packages\torch\utils\data\dataloader.py", line 355, in __iter__ return self._get_iterator() File "C:\Users\xxx\Anaconda3\envs\text\lib\site-packages\torch\utils\data\dataloader.py", line 301, in _get_iterator return _MultiProcessingDataLoaderIter(self) File "C:\Users\xxx\Anaconda3\envs\text\lib\site-packages\torch\utils\data\dataloader.py", line 914, in __init__ w.start() File "C:\Users\xxx\Anaconda3\envs\text\lib\multiprocessing\process.py", line 105, in start self._popen = self._Popen(self) File "C:\Users\xxx\Anaconda3\envs\text\lib\multiprocessing\context.py", line 223, in _Popen return _default_context.get_context().Process._Popen(process_obj) File "C:\Users\xxx\Anaconda3\envs\text\lib\multiprocessing\context.py", line 322, in _Popen return Popen(process_obj) File "C:\Users\xxx\Anaconda3\envs\text\lib\multiprocessing\popen_spawn_win32.py", line 65, in __init__ reduction.dump(process_obj, to_child) File "C:\Users\xxx\Anaconda3\envs\text\lib\multiprocessing\reduction.py", line 60, in dump ForkingPickler(file, protocol).dump(obj) AttributeError: Can't pickle local object 'main.<locals>.Excerpt_Dataset'
該当のソースコード
python
1#文字数の関係で一部だけ抽出いたします。全文についてはhttps://www.kaggle.com/sumantindurkhya/bert-for-regression/をご参照頂けると幸いです。 2 3 #Train function 4 def train(model, criterion, optimizer, train_loader, val_loader, epochs, device): 5 best_acc = 0 6 for epoch in trange(epochs, desc="Epoch"): 7 model.train() 8 train_loss = 0 9 for i, (input_ids, attention_mask, target) in enumerate(iterable=train_loader): 10 optimizer.zero_grad() 11 12 input_ids, attention_mask, target = input_ids.to(device), attention_mask.to(device), target.to(device) 13 14 output = model(input_ids=input_ids, attention_mask=attention_mask) 15 16 loss = criterion(output, target.type_as(output)) 17 loss.backward() 18 optimizer.step() 19 20 train_loss += loss.item() 21 22 print(f"Training loss is {train_loss/len(train_loader)}") 23 24 val_loss = evaluate(model=model, criterion=criterion, dataloader=val_loader, device=device) 25 26 print("Epoch {} complete! Validation Loss : {}".format(epoch, val_loss)) 27 28 29#エラーとなるのは下記のtrain関数を実行した際にepoch=0が表示されて処理が中断されてしまいます。 30 train(model=model, 31 criterion=criterion, 32 optimizer=optimizer, 33 train_loader=train_loader, 34 val_loader=valid_loader, 35 epochs = 10, 36 device = device) 37
補足情報(FW/ツールのバージョンなど)
上記の通りです
バッドをするには、ログインかつ
こちらの条件を満たす必要があります。
2021/09/02 08:16