回答編集履歴
2
Numba追記
test
CHANGED
@@ -44,4 +44,50 @@
|
|
44
44
|
|
45
45
|
こういうのの高速化はnumbaのほうが良さそうなので、そちらも試してみましたが2倍程度しか速くなりませんでした。
|
46
46
|
|
47
|
+
参考までにコードは下記の通りで、元のコードにjitをつけただけで試しています。(引数はndarrayである前提)
|
48
|
+
|
47
49
|
もともと、それほど遅くなるようなforの使い方でもないように思います。
|
50
|
+
|
51
|
+
|
52
|
+
|
53
|
+
```python
|
54
|
+
|
55
|
+
from numba import jit
|
56
|
+
|
57
|
+
|
58
|
+
|
59
|
+
@jit
|
60
|
+
|
61
|
+
def _accuracy(preds, label, group):
|
62
|
+
|
63
|
+
BEST_LABEL = 2
|
64
|
+
|
65
|
+
i = 0
|
66
|
+
|
67
|
+
acc = 0
|
68
|
+
|
69
|
+
for n in group:
|
70
|
+
|
71
|
+
max = preds[i:i+n].argmax()
|
72
|
+
|
73
|
+
acc += (label[i+max] == BEST_LABEL)
|
74
|
+
|
75
|
+
i += n
|
76
|
+
|
77
|
+
return acc/len(group)
|
78
|
+
|
79
|
+
|
80
|
+
|
81
|
+
|
82
|
+
|
83
|
+
def accuracy(preds, data):
|
84
|
+
|
85
|
+
label = data.get_label()
|
86
|
+
|
87
|
+
group = data.get_group()
|
88
|
+
|
89
|
+
|
90
|
+
|
91
|
+
return "accuracy", _accuracy(preds, label, group), True
|
92
|
+
|
93
|
+
```
|
1
追記
test
CHANGED
@@ -36,6 +36,10 @@
|
|
36
36
|
|
37
37
|
手元のテストコードで、元のより1.5倍ちょっと早くなる程度でした。
|
38
38
|
|
39
|
+
(`[0, *np.cumsum(group[:-1])]`が固定で、毎回計算する必要がないなら、この部分を外に出すともう少し速くできそうです)
|
40
|
+
|
41
|
+
|
42
|
+
|
39
43
|
|
40
44
|
|
41
45
|
こういうのの高速化はnumbaのほうが良さそうなので、そちらも試してみましたが2倍程度しか速くなりませんでした。
|