質問をすることでしか得られない、回答やアドバイスがある。

15分調べてもわからないことは、質問しよう!

新規登録して質問してみよう
ただいま回答率
87.20%
Python

Pythonは、コードの読みやすさが特徴的なプログラミング言語の1つです。 強い型付け、動的型付けに対応しており、後方互換性がないバージョン2系とバージョン3系が使用されています。 商用製品の開発にも無料で使用でき、OSだけでなく仮想環境にも対応。Unicodeによる文字列操作をサポートしているため、日本語処理も標準で可能です。

解決済

python skorchモジュールfitの正しい引数の与え方がわからない

GrinTea
GrinTea

総合スコア10

Python

Pythonは、コードの読みやすさが特徴的なプログラミング言語の1つです。 強い型付け、動的型付けに対応しており、後方互換性がないバージョン2系とバージョン3系が使用されています。 商用製品の開発にも無料で使用でき、OSだけでなく仮想環境にも対応。Unicodeによる文字列操作をサポートしているため、日本語処理も標準で可能です。

1回答

0評価

0クリップ

140閲覧

投稿2022/06/11 09:28

編集2022/06/13 17:30

前提

「braindecode」というモジュールを使用して、学習モデルの作成を行いたい。
しかし、以下のエラーによって実行ができない状況です。

実現したいこと

fit を行った際、skorch内でエラーが出てしまします。
fitに与える引数に問題があると考えていますが、その原因がわかりません。
正しい、引数の与え方を教えていただけると幸いです。

発生している問題・エラーメッセージ

Traceback (most recent call last): File "C:\◆◆◆\◆◆◆\◆◆◆\◆◆◆\◆◆◆\◆◆◆", line 126, in <module> test_eeg_classifier(TT_data, TT_ans) File "C:\◆◆◆\◆◆◆\◆◆◆\◆◆◆\◆◆◆\◆◆◆", line 76, in test_eeg_classifier clf.fit(X = X[:480], y = Y[:480], epochs=100) File "C:\ProgramData\Anaconda3\lib\site-packages\skorch\classifier.py", line 141, in fit return super(NeuralNetClassifier, self).fit(X, y, **fit_params) File "C:\ProgramData\Anaconda3\lib\site-packages\skorch\net.py", line 1215, in fit self.partial_fit(X, y, **fit_params) File "C:\ProgramData\Anaconda3\lib\site-packages\skorch\net.py", line 1174, in partial_fit self.fit_loop(X, y, **fit_params) File "C:\ProgramData\Anaconda3\lib\site-packages\skorch\net.py", line 1087, in fit_loop self.run_single_epoch(dataset_train, training=True, prefix="train", File "C:\ProgramData\Anaconda3\lib\site-packages\skorch\net.py", line 1127, in run_single_epoch self.notify("on_batch_end", batch=batch, training=training, **step) File "C:\ProgramData\Anaconda3\lib\site-packages\skorch\net.py", line 357, in notify getattr(self, method_name)(self, **cb_kwargs) TypeError: on_batch_end() missing 2 required positional arguments: 'X' and 'y'

該当のソースコード

python

def test_eeg_classifier(X, Y): n_classes = 3 in_chans = X.shape[1] device = 'cuda' # Set if you want to use GPU # You can also use torch.cuda.is_available() to determine if cuda is available on your machine. cuda = True set_random_seeds(seed=20220109, cuda=cuda) # This will determine how many crops are processed in parallel input_window_samples = X.shape[2] n_classes = 3 in_chans = X.shape[1] # final_conv_length determines the size of the receptive field of the ConvNet model = Deep4Net(in_chans=in_chans, n_classes=n_classes, input_window_samples=input_window_samples, final_conv_length='auto').cuda() to_dense_prediction_model(model) # determine output size test_input = np_to_th( np.ones((2, in_chans, input_window_samples, 1), dtype=np.float32) ).cuda() out = model(test_input) n_preds_per_input = out.cpu().data.numpy().shape[2] #print(out) train_set = create_from_X_y(X[:480], Y[:480], drop_last_window=False, sfreq=250, window_size_samples=input_window_samples, window_stride_samples=n_preds_per_input) valid_set = create_from_X_y(X[480:600], Y[480:600], drop_last_window=False, sfreq=250, window_size_samples=input_window_samples, window_stride_samples=n_preds_per_input) cropped_cb_train = CroppedTrialEpochScoring( "accuracy", name="train_trial_accuracy", lower_is_better=False, on_train=True, ) cropped_cb_valid = CroppedTrialEpochScoring( "accuracy", on_train=False, name="valid_trial_accuracy", lower_is_better=False, ) # valid_set = Dataset(X[480:600], Y[480:600]) # train_set = Dataset(X[:480], Y[:480]) clf = EEGClassifier( model, cropped=True, criterion=CroppedLoss, criterion__loss_function=nll_loss, optimizer=optim.AdamW, optimizer__lr = 0.005, optimizer__weight_decay = 0.5*0.00001, train_split=predefined_split(valid_set), batch_size=10, callbacks=[ ("train_trial_accuracy", cropped_cb_train), ("valid_trial_accuracy", cropped_cb_valid), ], ) clf.fit(train_set, y=None, epochs=100) if __name__ == "__main__": import h5py import numpy as np import mne from mne.io import concatenate_raws from skorch.helper import predefined_split from torch import optim from torch.nn.functional import nll_loss from braindecode.classifier import EEGClassifier from braindecode.datasets.xy import create_from_X_y from braindecode.training.losses import CroppedLoss from braindecode.models.deep4 import Deep4Net from braindecode.models.util import to_dense_prediction_model from braindecode.training.scoring import CroppedTrialEpochScoring from braindecode.util import set_random_seeds, np_to_th f = h5py.File("../data/ML_data1.h5", "r") dataDL_R = np.empty([200, 8, 500]) dataDL_L = np.empty([200, 8, 500]) dataDL_N = np.empty([200, 8, 500]) for i in range(200): data = f["A/R/%s" % (i)][:] dataDL_R[i] = data.T for j in range(200): data = f["A/L/%s" % (i)][:] dataDL_L[i] = data.T for k in range(200): data = f["A/N/%s" % (i)][:] dataDL_N[i] = data.T train_data = np.vstack((dataDL_R[0:160], dataDL_L[0:160], dataDL_N[0:160])) test_data = np.vstack((dataDL_R[160:200], dataDL_L[160:200], dataDL_N[160:200])) TT_data = np.vstack((train_data, test_data)) ans_R = np.array([1] * 200) ans_L = np.array([2] * 200) ans_N = np.array([0] * 200) train_ans = np.hstack((ans_R[0:160], ans_L[0:160], ans_N[0:160])) test_ans = np.hstack((ans_R[160:200], ans_L[160:200], ans_N[160:200])) TT_ans = np.hstack((train_ans, test_ans)) #print(TT_ans.shape) test_eeg_classifier(TT_data, TT_ans)

良い質問の評価を上げる

以下のような質問は評価を上げましょう

  • 質問内容が明確
  • 自分も答えを知りたい
  • 質問者以外のユーザにも役立つ

評価が高い質問は、TOPページの「注目」タブのフィードに表示されやすくなります。

気になる質問をクリップする

クリップした質問は、後からいつでもマイページで確認できます。

またクリップした質問に回答があった際、通知やメールを受け取ることができます。

teratailでは下記のような質問を「具体的に困っていることがない質問」、「サイトポリシーに違反する質問」と定義し、推奨していません。

  • プログラミングに関係のない質問
  • やってほしいことだけを記載した丸投げの質問
  • 問題・課題が含まれていない質問
  • 意図的に内容が抹消された質問
  • 過去に投稿した質問と同じ内容の質問
  • 広告と受け取られるような投稿

評価を下げると、トップページの「アクティブ」「注目」タブのフィードに表示されにくくなります。

jbpb0

2022/06/11 14:08

エラーメッセージは、「Traceback」と書かれてるところ以降を全部省略しないで書いてください (ユーザー名等の個人情報は伏せ字にして) ここに書くのではなく、質問を編集してください
GrinTea

2022/06/11 14:27

情報が不十分で申し訳ございません。 よろしくお願いします。
GrinTea

2022/06/11 15:03

ありがとうございます! 無事解決しました。

まだ回答がついていません

会員登録して回答してみよう

15分調べてもわからないことは
teratailで質問しよう!

ただいまの回答率
87.20%

質問をまとめることで
思考を整理して素早く解決

テンプレート機能で
簡単に質問をまとめる

質問する

関連した質問

同じタグがついた質問を見る

Python

Pythonは、コードの読みやすさが特徴的なプログラミング言語の1つです。 強い型付け、動的型付けに対応しており、後方互換性がないバージョン2系とバージョン3系が使用されています。 商用製品の開発にも無料で使用でき、OSだけでなく仮想環境にも対応。Unicodeによる文字列操作をサポートしているため、日本語処理も標準で可能です。