回答編集履歴

2

Numba追記

2021/12/16 03:37

投稿

bsdfan
bsdfan

スコア4794

test CHANGED
@@ -44,4 +44,50 @@
44
44
 
45
45
  こういうのの高速化はnumbaのほうが良さそうなので、そちらも試してみましたが2倍程度しか速くなりませんでした。
46
46
 
47
+ 参考までにコードは下記の通りで、元のコードにjitをつけただけで試しています。(引数はndarrayである前提)
48
+
47
49
  もともと、それほど遅くなるようなforの使い方でもないように思います。
50
+
51
+
52
+
53
+ ```python
54
+
55
+ from numba import jit
56
+
57
+
58
+
59
+ @jit
60
+
61
+ def _accuracy(preds, label, group):
62
+
63
+ BEST_LABEL = 2
64
+
65
+ i = 0
66
+
67
+ acc = 0
68
+
69
+ for n in group:
70
+
71
+ max = preds[i:i+n].argmax()
72
+
73
+ acc += (label[i+max] == BEST_LABEL)
74
+
75
+ i += n
76
+
77
+ return acc/len(group)
78
+
79
+
80
+
81
+
82
+
83
+ def accuracy(preds, data):
84
+
85
+ label = data.get_label()
86
+
87
+ group = data.get_group()
88
+
89
+
90
+
91
+ return "accuracy", _accuracy(preds, label, group), True
92
+
93
+ ```

1

追記

2021/12/16 03:37

投稿

bsdfan
bsdfan

スコア4794

test CHANGED
@@ -36,6 +36,10 @@
36
36
 
37
37
  手元のテストコードで、元のより1.5倍ちょっと早くなる程度でした。
38
38
 
39
+ (`[0, *np.cumsum(group[:-1])]`が固定で、毎回計算する必要がないなら、この部分を外に出すともう少し速くできそうです)
40
+
41
+
42
+
39
43
 
40
44
 
41
45
  こういうのの高速化はnumbaのほうが良さそうなので、そちらも試してみましたが2倍程度しか速くなりませんでした。