https://qiita.com/takubb/items/fd972f0ac3dba909c293
をGoogleColabを利用して学習していましたが、順番に処理を実行したところエラーが発生して先に進めません。(同様の質問がshikinamiさまから出ていますが、参照しても解決しませんでした)
レベルが低いのは承知の上で、エラー解消方法をどなたかご教示ください。
発生している問題・エラーメッセージ
AttributeError Traceback (most recent call last) <ipython-input-21-71601c2017dc> in <module>() 5 6 for epoch in range(max_epoch): ----> 7 train_ = train(model) 8 test_ = train(model) 9 train_loss_.append(train_) <ipython-input-20-c19e951568be> in train(model) 12 optimizer.zero_grad() 13 loss, logits = model(b_input_ids, token_type_ids=None, attention_mask=b_input_mask, labels=b_labels) ---> 14 loss.backward 15 torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) 16 optimizer.step() AttributeError: 'str' object has no attribute 'backward'
該当のソースコード
Python
1# 最適化手法の設定 2optimizer = AdamW(model.parameters(), lr=2e-5) 3 4# 訓練パートの定義 5def train(model): 6 model.train() # 訓練モードで実行 7 train_loss = 0 8 for batch in train_dataloader:# train_dataloaderはword_id, mask, labelを出力する点に注意 9 b_input_ids = batch[0].to(device) 10 b_input_mask = batch[1].to(device) 11 b_labels = batch[2].to(device) 12 optimizer.zero_grad() 13 loss, logits = model(b_input_ids, token_type_ids=None, attention_mask=b_input_mask, labels=b_labels) 14 loss.backward() 15 torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) 16 optimizer.step() 17 train_loss += loss.item() 18 return train_loss 19 20# テストパートの定義 21def validation(model): 22 model.eval()# 訓練モードをオフ 23 val_loss = 0 24 with torch.no_grad(): # 勾配を計算しない 25 for batch in validation_dataloader: 26 b_input_ids = batch[0].to(device) 27 b_input_mask = batch[1].to(device) 28 b_labels = batch[2].to(device) 29 with torch.no_grad(): 30 (loss, logits) = model(b_input_ids, 31 token_type_ids=None, 32 attention_mask=b_input_mask, 33 labels=b_labels) 34 val_loss += loss.item() 35 return val_loss 36 37# 学習の実行 38max_epoch = 4 39train_loss_ = [] 40test_loss_ = [] 41 42for epoch in range(max_epoch): 43 train_ = train(model) 44 test_ = train(model) 45 train_loss_.append(train_) 46 test_loss_.append(test_) 47
試したこと
ここに問題に対して試したことを記載してください。
補足情報(FW/ツールのバージョンなど)
ここにより詳細な情報を記載してください。
バッドをするには、ログインかつ
こちらの条件を満たす必要があります。