回答編集履歴

2

修正

2020/03/19 08:00

投稿

tiitoi
tiitoi

スコア21956

test CHANGED
@@ -96,11 +96,19 @@
96
96
 
97
97
  import timeit
98
98
 
99
-
99
+ import numpy as np
100
100
 
101
101
  from numba import jit
102
102
 
103
103
 
104
+
105
+
106
+
107
+ np.random.seed(0)
108
+
109
+ X = np.random.randint(0, 9, (20, 15, 13, 10, 10))
110
+
111
+ Y = np.random.randint(0, 9, (9, 10, 10))
104
112
 
105
113
 
106
114
 

1

修正

2020/03/19 08:00

投稿

tiitoi
tiitoi

スコア21956

test CHANGED
@@ -75,3 +75,99 @@
75
75
 
76
76
 
77
77
  要素数は変わらないので、メモリ使用量は増えないと思います。
78
+
79
+
80
+
81
+
82
+
83
+ # 追記
84
+
85
+
86
+
87
+ 質問のループバージョンのコードを numba で最適化したところ、
88
+
89
+ 218ms が 8.45 ms と25倍程度高速化できました。
90
+
91
+ ブロードキャストで計算したバージョンが 8.01 ms なのでほぼ同じぐらいの速度が出るようになりました。
92
+
93
+
94
+
95
+ ```python
96
+
97
+ import timeit
98
+
99
+
100
+
101
+ from numba import jit
102
+
103
+
104
+
105
+
106
+
107
+ def calc1(X, Y):
108
+
109
+ # broadcast バージョン
110
+
111
+ X = np.expand_dims(X, axis=3)
112
+
113
+ Z = np.max(X * Y, axis=(-2, -1))
114
+
115
+
116
+
117
+
118
+
119
+ def calc2(X, Y):
120
+
121
+ # for-loop バージョン
122
+
123
+ Z = np.zeros((X.shape[0], X.shape[1], X.shape[2], Y.shape[0]), dtype=X.dtype)
124
+
125
+ for a in range(X.shape[0]):
126
+
127
+ for b in range(X.shape[1]):
128
+
129
+ for c in range(X.shape[2]):
130
+
131
+ for f in range(Y.shape[0]):
132
+
133
+ Z[a, b, c, f] = np.max(X[a, b, c] * Y[f])
134
+
135
+
136
+
137
+
138
+
139
+ @jit(nopython=True)
140
+
141
+ def calc3(X, Y):
142
+
143
+ # for-loop バージョンに numba でコンパイルしたバージョン
144
+
145
+ Z = np.zeros((X.shape[0], X.shape[1], X.shape[2], Y.shape[0]), dtype=X.dtype)
146
+
147
+ for a in range(X.shape[0]):
148
+
149
+ for b in range(X.shape[1]):
150
+
151
+ for c in range(X.shape[2]):
152
+
153
+ for f in range(Y.shape[0]):
154
+
155
+ Z[a, b, c, f] = np.max(X[a, b, c] * Y[f])
156
+
157
+
158
+
159
+ # Jupyter Notebook 上で計測
160
+
161
+ %timeit calc1(X, Y)
162
+
163
+ %timeit calc2(X, Y)
164
+
165
+ %timeit calc3(X, Y)
166
+
167
+ # 8.01 ms ± 239 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
168
+
169
+ # 218 ms ± 32.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
170
+
171
+ # 8.45 ms ± 25.3 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
172
+
173
+ ```