質問をすることでしか得られない、回答やアドバイスがある。

15分調べてもわからないことは、質問しよう!

新規登録して質問してみよう
ただいま回答率
85.48%
Python

Pythonは、コードの読みやすさが特徴的なプログラミング言語の1つです。 強い型付け、動的型付けに対応しており、後方互換性がないバージョン2系とバージョン3系が使用されています。 商用製品の開発にも無料で使用でき、OSだけでなく仮想環境にも対応。Unicodeによる文字列操作をサポートしているため、日本語処理も標準で可能です。

Q&A

解決済

1回答

387閲覧

python skorchモジュールfitの正しい引数の与え方がわからない

GrinTea

総合スコア10

Python

Pythonは、コードの読みやすさが特徴的なプログラミング言語の1つです。 強い型付け、動的型付けに対応しており、後方互換性がないバージョン2系とバージョン3系が使用されています。 商用製品の開発にも無料で使用でき、OSだけでなく仮想環境にも対応。Unicodeによる文字列操作をサポートしているため、日本語処理も標準で可能です。

0グッド

0クリップ

投稿2022/06/11 09:28

編集2022/06/11 14:24

前提

「braindecode」というモジュールを使用して、学習モデルの作成を行いたい。
しかし、以下のエラーによって実行ができない状況です。

実現したいこと

fit を行った際、skorch内でエラーが出てしまします。
fitに与える引数に問題があると考えていますが、その原因がわかりません。
正しい、引数の与え方を教えていただけると幸いです。

発生している問題・エラーメッセージ

Traceback (most recent call last): File "C:\◆◆◆\◆◆◆\◆◆◆\◆◆◆\◆◆◆\◆◆◆", line 126, in <module> test_eeg_classifier(TT_data, TT_ans) File "C:\◆◆◆\◆◆◆\◆◆◆\◆◆◆\◆◆◆\◆◆◆", line 76, in test_eeg_classifier clf.fit(X = X[:480], y = Y[:480], epochs=100) File "C:\ProgramData\Anaconda3\lib\site-packages\skorch\classifier.py", line 141, in fit return super(NeuralNetClassifier, self).fit(X, y, **fit_params) File "C:\ProgramData\Anaconda3\lib\site-packages\skorch\net.py", line 1215, in fit self.partial_fit(X, y, **fit_params) File "C:\ProgramData\Anaconda3\lib\site-packages\skorch\net.py", line 1174, in partial_fit self.fit_loop(X, y, **fit_params) File "C:\ProgramData\Anaconda3\lib\site-packages\skorch\net.py", line 1087, in fit_loop self.run_single_epoch(dataset_train, training=True, prefix="train", File "C:\ProgramData\Anaconda3\lib\site-packages\skorch\net.py", line 1127, in run_single_epoch self.notify("on_batch_end", batch=batch, training=training, **step) File "C:\ProgramData\Anaconda3\lib\site-packages\skorch\net.py", line 357, in notify getattr(self, method_name)(self, **cb_kwargs) TypeError: on_batch_end() missing 2 required positional arguments: 'X' and 'y'

該当のソースコード

python

1def test_eeg_classifier(X, Y): 2 3 n_classes = 3 4 in_chans = X.shape[1] 5 6 device = 'cuda' 7 8 9 # Set if you want to use GPU 10 # You can also use torch.cuda.is_available() to determine if cuda is available on your machine. 11 cuda = True 12 set_random_seeds(seed=20220109, cuda=cuda) 13 14 # This will determine how many crops are processed in parallel 15 input_window_samples = X.shape[2] 16 n_classes = 3 17 in_chans = X.shape[1] 18 # final_conv_length determines the size of the receptive field of the ConvNet 19 model = Deep4Net(in_chans=in_chans, n_classes=n_classes, 20 input_window_samples=input_window_samples, 21 final_conv_length='auto').cuda() 22 to_dense_prediction_model(model) 23 24 25 # determine output size 26 test_input = np_to_th( 27 np.ones((2, in_chans, input_window_samples, 1), dtype=np.float32) 28 ).cuda() 29 out = model(test_input) 30 n_preds_per_input = out.cpu().data.numpy().shape[2] 31 #print(out) 32 train_set = create_from_X_y(X[:480], Y[:480], 33 drop_last_window=False, 34 sfreq=250, 35 window_size_samples=input_window_samples, 36 window_stride_samples=n_preds_per_input) 37 38 valid_set = create_from_X_y(X[480:600], Y[480:600], 39 drop_last_window=False, 40 sfreq=250, 41 window_size_samples=input_window_samples, 42 window_stride_samples=n_preds_per_input) 43 44 cropped_cb_train = CroppedTrialEpochScoring( 45 "accuracy", 46 name="train_trial_accuracy", 47 lower_is_better=False, 48 on_train=True, 49 ) 50 51 cropped_cb_valid = CroppedTrialEpochScoring( 52 "accuracy", 53 on_train=False, 54 name="valid_trial_accuracy", 55 lower_is_better=False, 56 ) 57 58 # valid_set = Dataset(X[480:600], Y[480:600]) 59 # train_set = Dataset(X[:480], Y[:480]) 60 61 clf = EEGClassifier( 62 model, 63 cropped=True, 64 criterion=CroppedLoss, 65 criterion__loss_function=nll_loss, 66 optimizer=optim.AdamW, 67 optimizer__lr = 0.005, 68 optimizer__weight_decay = 0.5*0.00001, 69 train_split=predefined_split(valid_set), 70 batch_size=10, 71 callbacks=[ 72 ("train_trial_accuracy", cropped_cb_train), 73 ("valid_trial_accuracy", cropped_cb_valid), 74 ], 75 ) 76 clf.fit(train_set, y=None, epochs=100) 77 78 79if __name__ == "__main__": 80 import h5py 81 import numpy as np 82 import mne 83 from mne.io import concatenate_raws 84 from skorch.helper import predefined_split 85 from torch import optim 86 from torch.nn.functional import nll_loss 87 from braindecode.classifier import EEGClassifier 88 from braindecode.datasets.xy import create_from_X_y 89 from braindecode.training.losses import CroppedLoss 90 from braindecode.models.deep4 import Deep4Net 91 from braindecode.models.util import to_dense_prediction_model 92 from braindecode.training.scoring import CroppedTrialEpochScoring 93 from braindecode.util import set_random_seeds, np_to_th 94 95 96 f = h5py.File("../data/ML_data1.h5", "r") 97 98 dataDL_R = np.empty([200, 8, 500]) 99 dataDL_L = np.empty([200, 8, 500]) 100 dataDL_N = np.empty([200, 8, 500]) 101 for i in range(200): 102 data = f["A/R/%s" % (i)][:] 103 dataDL_R[i] = data.T 104 105 for j in range(200): 106 data = f["A/L/%s" % (i)][:] 107 dataDL_L[i] = data.T 108 109 for k in range(200): 110 data = f["A/N/%s" % (i)][:] 111 dataDL_N[i] = data.T 112 113 train_data = np.vstack((dataDL_R[0:160], dataDL_L[0:160], dataDL_N[0:160])) 114 test_data = np.vstack((dataDL_R[160:200], dataDL_L[160:200], dataDL_N[160:200])) 115 TT_data = np.vstack((train_data, test_data)) 116 117 ans_R = np.array([1] * 200) 118 ans_L = np.array([2] * 200) 119 ans_N = np.array([0] * 200) 120 121 train_ans = np.hstack((ans_R[0:160], ans_L[0:160], ans_N[0:160])) 122 test_ans = np.hstack((ans_R[160:200], ans_L[160:200], ans_N[160:200])) 123 TT_ans = np.hstack((train_ans, test_ans)) 124 125 #print(TT_ans.shape) 126 test_eeg_classifier(TT_data, TT_ans) 127

気になる質問をクリップする

クリップした質問は、後からいつでもMYページで確認できます。

またクリップした質問に回答があった際、通知やメールを受け取ることができます。

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

jbpb0

2022/06/11 14:08

エラーメッセージは、「Traceback」と書かれてるところ以降を全部省略しないで書いてください (ユーザー名等の個人情報は伏せ字にして) ここに書くのではなく、質問を編集してください
GrinTea

2022/06/11 14:27

情報が不十分で申し訳ございません。 よろしくお願いします。
GrinTea

2022/06/11 15:03

ありがとうございます! 無事解決しました。
guest

回答1

0

自己解決

jbpb0教えていただきました。

https://githublab.com/repository/issues/braindecode/braindecode/362
に書いてる通り「braindecode0.5.1」から「braindecode0.6」に更新したら無事に動きました。

投稿2022/06/13 08:30

GrinTea

総合スコア10

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

あなたの回答

tips

太字

斜体

打ち消し線

見出し

引用テキストの挿入

コードの挿入

リンクの挿入

リストの挿入

番号リストの挿入

表の挿入

水平線の挿入

プレビュー

15分調べてもわからないことは
teratailで質問しよう!

ただいまの回答率
85.48%

質問をまとめることで
思考を整理して素早く解決

テンプレート機能で
簡単に質問をまとめる

質問する

関連した質問