実現したいこと
エラーの解消
発生している問題・分からないこと
Pytorch Lightningを使用中にforwardに意図したように因数が渡されないエラーが発生しています.
TypeError: forward() missing 1 required positional argument: 'masks'
エラーメッセージ
error
1TypeError Traceback (most recent call last) 2Cell In[32], line 16 3 8 logger = TensorBoardLogger( 4 9 save_dir="lightning_logs", 5 10 name=name, 6 11 version=f"Fold_{fold+1}" 7 12 ) 8 14 trainer = make_trainer(max_epochs, logger, name, patience) 9---> 16 trainer.fit(model, train_loader, val_loader) 10 18 val_results = trainer.validate(model, val_loader) 11 20 val_losses.append(val_results[0]['val_loss']) 12 13File ~/anaconda3/envs/choi_venv/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py:532, in Trainer.fit(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path) 14 530 self.strategy._lightning_module = model 15 531 _verify_strategy_supports_compile(model, self.strategy) 16--> 532 call._call_and_handle_interrupt( 17 533 self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path 18 534 ) 19 20File ~/anaconda3/envs/choi_venv/lib/python3.8/site-packages/pytorch_lightning/trainer/call.py:43, in _call_and_handle_interrupt(trainer, trainer_fn, *args, **kwargs) 21 41 if trainer.strategy.launcher is not None: 22 42 return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs) 23---> 43 return trainer_fn(*args, **kwargs) 24 45 except _TunerExitException: 25... 26 return self._call_impl(*args, **kwargs) 27 File "/home/foo/anaconda3/envs/choi_venv/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl 28 return forward_call(*args, **kwargs) 29TypeError: forward() missing 1 required positional argument: 'masks'
該当のソースコード
関係しそうな部分を抜粋しています.以下のリンクにほとんどの部分を掲載しました:
https://note.com/rafo/n/ne0b558d5818f
Python
1############### 2### Dataset ### 3############### 4class VideoDataset(Dataset): 5 def __init__(self, df_list, df, tokenizer, bert_model, max_length, comment_batch_size, num_layers, lstm_dropout, lstm_batch_size, lstm_hidden_size, j, frame_batch_size, num_heads, video_batch_size, d=768): 6 self.comment_processor = CommentProcessor(d, num_layers, lstm_dropout, lstm_batch_size, lstm_hidden_size) 7 self.title_desc_processor = TitleDescProcessor(d, df, tokenizer, bert_model, max_length, batch_size=32) 8 self.get_j_frames = GetJFrames(j, frame_batch_size, video_batch_size=1, d=768) 9 self.video_processor = VideoProcessor(d, num_heads, video_batch_size) 10 11 self.df_list = df_list 12 self.tokenizer = tokenizer 13 self.bert_model = bert_model 14 self.max_length = max_length 15 self.comment_batch_size = comment_batch_size 16 17 def __len__(self): 18 return len(self.df_list) 19 20 def __getitem__(self, idx): 21 # 各データをstackして返す 22 df = self.df_list[idx] 23 # 動画によってコメント数が違う→バッチサイズが異なる→スタックできない→padding 24 comment_embeddings = get_comment_embedding(df, self.tokenizer, self.bert_model, self.max_length, self.comment_batch_size) 25 comment_output_avg = self.comment_processor(comment_embeddings) 26 hit_likes = torch.tensor(df['like_count'].values, dtype=torch.float16) 27 title_desc_output_avg = self.title_desc_processor() 28 top_j_sim_video_embeddings_list = self.get_j_frames() 29 video_output_avg = self.video_processor(top_j_sim_video_embeddings_list) 30 31 label = df['label'].values 32 label = torch.tensor(label, dtype=torch.float16) 33 # 自動的にlabelは(batch_size,)の形状にして渡される 34 35 return comment_output_avg, hit_likes, title_desc_output_avg, video_output_avg, label 36 37 38def collate_fn(batch): 39 # バッチ内の全ての要素からコメントテンソルを取得し、最大のコメント数を計算 40 max_comments = max([comments.size(0) for comments, _, _, _, _ in batch]) 41 padded_comments = [] 42 masks = [] 43 44 for comments, hit_likes, title_desc_embeddings, video_output, label in batch: 45 pad_size = max_comments - comments.size(0) 46 mask = torch.ones(comments.size(0), dtype=torch.bool) 47 if pad_size > 0: 48 pad_tensor = torch.zeros(pad_size, comments.size(1), comments.size(2), dtype=comments.dtype) 49 comments = torch.cat([comments, pad_tensor], dim=0) 50 pad_mask = torch.zeros(pad_size, dtype=torch.bool) 51 mask = torch.cat([mask, pad_mask], dim=0) 52 padded_comments.append(comments) 53 masks.append(mask) 54 55 # リストをTensorに変換 56 padded_comments_stack = torch.stack(padded_comments, dim=0) 57 masks_stack = torch.stack(masks, dim=0) 58 hit_likes = torch.stack([hit_likes for _, hit_likes, _, _, _ in batch], dim=0) 59 title_desc_embeddings = torch.stack([title_desc_embeddings for _, _, title_desc_embeddings, _, _ in batch], dim=0) 60 video_output_stack = torch.stack([video_output for _, _, _, video_output, _ in batch], dim=0) 61 labels = torch.stack([label for _, _, _, _, label in batch], dim=0) 62 63 return padded_comments_stack, masks_stack, hit_likes, title_desc_embeddings, video_output_stack, labels
Python
1############# 2### Class ### 3############# 4class FakeNewsDetector(pl.LightningModule): 5 def __init__(self, tokenizer, bert_model, random_state, max_length, batch_size, num_workers, lr, n_split, dropout_rate, lstm_dropout, input_size, lstm_hidden_size, hidden_dim, num_layers, bidirectional, num_heads, max_epochs, patience, fig_save_name, name, weight_decay, d=768): 6 super().__init__() 7 self.save_hyperparameters(ignore=['tokenizer', 'bert_model']) 8 9 self.validation_step_outputs = [] 10 self.d = d 11 12 self.video_fc = nn.Linear(2*d, 2*d) 13 14 self.flatten = nn.Flatten() 15 self.fc1 = nn.Linear(max_length * 2*d, 1024) 16 self.bn1 = nn.BatchNorm1d(1024) 17 self.dropout = nn.Dropout(dropout_rate) 18 self.fc2 = nn.Linear(1024, 512) 19 self.bn2 = nn.BatchNorm1d(512) 20 self.fc3 = nn.Linear(512, 128) 21 self.bn3 = nn.BatchNorm1d(128) 22 self.fc4 = nn.Linear(128, 1) 23 24 self.comment_weight = nn.Parameter(torch.randn(1)) 25 self.title_desc_weight = nn.Parameter(torch.randn(1)) 26 self.video_weight = nn.Parameter(torch.randn(1)) 27 28 self.bilstm_model = BiLSTM(input_size=int(input_size), hidden_size=int(lstm_hidden_size), 29 num_layers=int(num_layers), dropout=float(lstm_dropout)) 30 self.bilstm_model = self.bilstm_model.to('cuda') 31 32 self.comment_processor = CommentProcessor(d, num_layers, lstm_dropout, lstm_batch_size, lstm_hidden_size=768//2) 33 self.title_desc_processor = TitleDescProcessor(d) 34 self.get_j_frames = GetJFrames() 35 self.video_processor = VideoProcessor() 36 37 def forward(self, comment_embeddings, masks_stack, hit_likes, title_desc_embedding, video_output_stack): 38 39 comment_output_avg = self.comment_processor(comment_embeddings) 40 # shape: (batch_size, 2*d) 41 42 title_desc_output_avg = self.title_desc_processor(title_desc_embedding) 43 # shape: (batch_size, 2*d) 44 45 top_j_sim_video_embeddings_list = self.get_j_frames(common_ids_list) 46 video_output_avg = self.video_processor(top_j_sim_video_embeddings_list) 47 # shape: (1, 2*d) 48 49 weights = F.softmax(torch.stack([self.comment_weight, self.title_desc_weight, self.video_weight]), dim=0) 50 51 52 combined_output = weights[0] * comment_output_avg + weights[1] * title_desc_output_avg + weights[2] * video_output_avg 53 54 55 x = self.flatten(combined_output) 56 x = self.fc1(x) 57 x = F.relu(x) 58 x = self.bn1(x) 59 x = self.dropout(x) 60 61 x = self.fc2(x) 62 x = F.relu(x) 63 x = self.bn2(x) 64 x = self.dropout(x) 65 66 x = self.fc3(x) 67 x = F.relu(x) 68 x = self.bn3(x) 69 x = self.dropout(x) 70 71 x = self.fc4(x) 72 x = torch.sigmoid(x) 73 x = x.squeeze() # 不要な次元を削除して形状を(batch_size,)にする 74 return x 75 76 77 def training_step(self, batch, batch_idx): 78 comment_embeddings, masks, hit_likes, title_desc_embedding, weighted_avg_video_embedding, label = batch 79 output = self(comment_embeddings, masks, hit_likes, title_desc_embedding, weighted_avg_video_embedding) 80 loss = F.binary_cross_entropy(output, label) 81 self.log('train_loss', loss) 82 83 return loss 84 85 def validation_step(self, batch, batch_idx): 86 comment_embeddings, masks, hit_likes, title_desc_embedding, weighted_avg_video_embedding, label = batch 87 output = self(comment_embeddings, masks, hit_likes, title_desc_embedding, weighted_avg_video_embedding) 88 89 loss = F.binary_cross_entropy(output, label) 90 self.log('val_loss', loss) 91 92 return loss
補足
環境は以下の通りです:
PyTorch==2.1.2
pytorch-lightning==2.0.8
エラー発生行?の「trainer.fit(model, train_loader, val_loader)」が該当のソースコードに見当たらないのですが何故でしょうか?
返信が遅れてすみません。
@meg_様
最後のトレーニング実行の部分だったので関係ないと判断し省略いたしました。後ほど修正いたします。
@can110様
リンクをありがとうございます。なぜエラーなのか全く見当がつかず丸投げ質問になってしまっている自覚はあるのですが、どう改善すれば良いかわからないというのが正直なところです。
ご指摘ありがとうございます。
> 最後のトレーニング実行の部分だったので関係ないと判断し省略いたしました。後ほど修正いたします。
該当のソースコードにあるのはクラス定義と関数定義のみに見えます。実際に実行したコードがないと回答は付かないのではないでしょうか?
文字数的にこれ以上多くの情報を載せることができないこと、pytorch-lightningの使用上実行部分は誰が書いても同じになることから、省略した次第です。文字数が現状でいっぱいでして、追加情報の掲載は厳しい状況です。
本文にコードの大半を乗せたリンクを乗せました.ぜひご確認ください.
あなたの回答
tips
プレビュー