回答編集履歴
2
Numba追記
answer
CHANGED
@@ -21,4 +21,27 @@
|
|
21
21
|
|
22
22
|
|
23
23
|
こういうのの高速化はnumbaのほうが良さそうなので、そちらも試してみましたが2倍程度しか速くなりませんでした。
|
24
|
+
参考までにコードは下記の通りで、元のコードにjitをつけただけで試しています。(引数はndarrayである前提)
|
24
|
-
もともと、それほど遅くなるようなforの使い方でもないように思います。
|
25
|
+
もともと、それほど遅くなるようなforの使い方でもないように思います。
|
26
|
+
|
27
|
+
```python
|
28
|
+
from numba import jit
|
29
|
+
|
30
|
+
@jit
|
31
|
+
def _accuracy(preds, label, group):
|
32
|
+
BEST_LABEL = 2
|
33
|
+
i = 0
|
34
|
+
acc = 0
|
35
|
+
for n in group:
|
36
|
+
max = preds[i:i+n].argmax()
|
37
|
+
acc += (label[i+max] == BEST_LABEL)
|
38
|
+
i += n
|
39
|
+
return acc/len(group)
|
40
|
+
|
41
|
+
|
42
|
+
def accuracy(preds, data):
|
43
|
+
label = data.get_label()
|
44
|
+
group = data.get_group()
|
45
|
+
|
46
|
+
return "accuracy", _accuracy(preds, label, group), True
|
47
|
+
```
|
1
追記
answer
CHANGED
@@ -17,6 +17,8 @@
|
|
17
17
|
```
|
18
18
|
|
19
19
|
手元のテストコードで、元のより1.5倍ちょっと早くなる程度でした。
|
20
|
+
(`[0, *np.cumsum(group[:-1])]`が固定で、毎回計算する必要がないなら、この部分を外に出すともう少し速くできそうです)
|
20
21
|
|
22
|
+
|
21
23
|
こういうのの高速化はnumbaのほうが良さそうなので、そちらも試してみましたが2倍程度しか速くなりませんでした。
|
22
24
|
もともと、それほど遅くなるようなforの使い方でもないように思います。
|