teratail header banner
teratail header banner
質問するログイン新規登録

回答編集履歴

2

Numba追記

2021/12/16 03:37

投稿

bsdfan
bsdfan

スコア4921

answer CHANGED
@@ -21,4 +21,27 @@
21
21
 
22
22
 
23
23
  こういうのの高速化はnumbaのほうが良さそうなので、そちらも試してみましたが2倍程度しか速くなりませんでした。
24
+ 参考までにコードは下記の通りで、元のコードにjitをつけただけで試しています。(引数はndarrayである前提)
24
- もともと、それほど遅くなるようなforの使い方でもないように思います。
25
+ もともと、それほど遅くなるようなforの使い方でもないように思います。
26
+
27
+ ```python
28
+ from numba import jit
29
+
30
+ @jit
31
+ def _accuracy(preds, label, group):
32
+ BEST_LABEL = 2
33
+ i = 0
34
+ acc = 0
35
+ for n in group:
36
+ max = preds[i:i+n].argmax()
37
+ acc += (label[i+max] == BEST_LABEL)
38
+ i += n
39
+ return acc/len(group)
40
+
41
+
42
+ def accuracy(preds, data):
43
+ label = data.get_label()
44
+ group = data.get_group()
45
+
46
+ return "accuracy", _accuracy(preds, label, group), True
47
+ ```

1

追記

2021/12/16 03:37

投稿

bsdfan
bsdfan

スコア4921

answer CHANGED
@@ -17,6 +17,8 @@
17
17
  ```
18
18
 
19
19
  手元のテストコードで、元のより1.5倍ちょっと早くなる程度でした。
20
+ (`[0, *np.cumsum(group[:-1])]`が固定で、毎回計算する必要がないなら、この部分を外に出すともう少し速くできそうです)
20
21
 
22
+
21
23
  こういうのの高速化はnumbaのほうが良さそうなので、そちらも試してみましたが2倍程度しか速くなりませんでした。
22
24
  もともと、それほど遅くなるようなforの使い方でもないように思います。