ビット列のカウント、比較について、以下の問題の演算回数を削減する方法がもしあれば教えていただきたいです。
ビットカウントの分割統治法 のようにビット演算を活用してなにか上手くできないかと考えているのですが、なかなか思いつきません。
プログラムは C++ で書いており、下記問題を計算する部分の呼び出し回数が多く、ボトルネックとなっているため、少しでも演算回数を減らしたいというのが質問の背景です。
追記 9/17
沢山のアイデア、サンプルコードありがとうございます。
実行時間を検証した結果、以下のようになりました。
-
C++17
-
Visual Studio 2019
-
テストケース: 今回解きたい全パターン 40万件弱
-
CPU: Core™ i9-9900K 3.6 GHz
-
40万件*100回=4000万回の計算時間を算出したら以下のようになりました。
合計のカウント | 1以上の要素数 | 2以上の要素数 | 2の要素数 | |
---|---|---|---|---|
質問記載コード | 1358ms | 1294ms | 1452ms | 1458ms |
SHOMI さんのコード | 1074ms (-284) | 1215ms (-79) | 1280ms (-172) | 1256ms (-202) |
kazuma-sさんのコード | 1068ms (-290) | 1140ms (-154) | 1137ms (-315) | 1144ms (-314) |
kazuma-sさんのコード2 | 978ms (-380) | 949ms (-345) | 1026ms (-426) | 996ms (-462) |
kichirb3 さんのコード | 957ms (-401) | 854ms (-440) | 786ms (-666) | 986ms (-472) |
皆さんにご回答いただいた内容を理解していこうと思いますが、時間がかかると思うので一旦質問はクローズします。
全員 BA としたいのですが仕様上できないので、最初にビット演算のコードを使った回答をしていただいた SHOMI さんを BA に選択します。
ありがとうございました。d
問題設定
各値が 0 ~ 4 である長さが 9 の配列があるとします。
各値の最大値は4 (0b100) で 3bit あれば表せるため、配列全体を 3 * 9 = 27bit 表現します。
- 1 ~ 3bit: a[0]
- 3 ~ 6bit: a[1]
- ...
- 21 ~ 27bit: a[9]
例えば、[0, 2, 0, 2, 2, 1, 1, 1, 4]
の場合、ビット列での表現は 100|001|001|001|010|010|000|010|000
になります。(|
は区切りがわかりやすいように入れています)
言語・環境について
- 動作環境は C++17 です。
- コンパイラは Visual Studio 2019 Update6
- 今は int (32bit) にビット値を格納しています。28~32bitは未使用で0になっています。
- 計算量削減以外に C++ に特化した最適化方法のアイデア等でも大丈夫です。
サンプルコードが C++/Python ですが、ビット演算はどの言語にもあると思うので、ビット演算を使用した計算法であれば、特に言語を限定した質問ではないです。
質問内容
このときにこのビット列から以下を求めたいのですが、サンプルコードのように各ブロックごとに取り出して加算していくやり方しか思いつきません。ビット演算などを活用し、より少ない演算回数で計算できる方法がもしあれば教えていただきたいです。
- 元の配列の各値の合計
- 元の配列の各値が1以上の要素数
- 元の配列の各値が2以上の要素数
- 元の配列の各値が2の要素数
- どれか1つだけでもよいです。
- コードでなくとも、アイデアだけでも助かります。
例: [0, 2, 0, 2, 2, 1, 1, 1, 4]
(69510160
)の場合
- 元の配列の各値の合計: 13
- 元の配列の各値が1以上の要素数: 7
- 元の配列の各値が2以上の要素数: 4
- 元の配列の各値が2の要素数: 3
C++
cpp
1#include <iostream> 2 3int main() 4{ 5 int x = 69510160; 6 7 // 各値の合計 8 { 9 int cnt = 0; 10 for (int i = 0; i < 9; ++i) 11 cnt += (x >> i * 3) & 0b111; 12 13 std::cout << cnt << std::endl; 14 } 15 16 // 各値が1以上の要素数 17 { 18 int cnt = 0; 19 for (int i = 0; i < 9; ++i) { 20 if ((x >> i * 3) & 0b111) 21 cnt++; 22 } 23 std::cout << cnt << std::endl; 24 } 25 26 // 各値が2以上の要素数 27 { 28 int cnt = 0; 29 for (int i = 0; i < 9; ++i) { 30 if (((x >> i * 3) & 0b111) >= 2) 31 cnt++; 32 } 33 std::cout << cnt << std::endl; 34 } 35 36 // 各値が2の要素数 37 { 38 int cnt = 0; 39 for (int i = 0; i < 9; ++i) { 40 if (((x >> i * 3) & 0b111) == 2) 41 cnt++; 42 } 43 std::cout << cnt << std::endl; 44 } 45}
少し演算回数を減らしたコード
cpp
1#include <iostream> 2#include <vector> 3 4/** 5 * @brief マスク 6 */ 7const std::vector<int> mask = { 8 7, 7 << 3, 7 << 6, 7 << 9, 7 << 12, 7 << 15, 7 << 18, 7 << 21, 7 << 24, 9}; 10 11/** 12 * @brief 2以上かどうか調べるときのマスク 13 */ 14const std::vector<int> ge2 = { 15 6, 6 << 3, 6 << 6, 6 << 9, 6 << 12, 6 << 15, 6 << 18, 6 << 21, 6 << 24, 16}; 17 18int main() 19{ 20 int x = 69510160; 21 22 // 各値の合計 23 { 24 int cnt = 0; 25 for (int i = 0; i < 9; ++i) 26 cnt += (x >> i * 3) & 0b111; 27 28 std::cout << cnt << std::endl; 29 } 30 31 // 各値が1以上の要素数 32 { 33 int cnt = 0; 34 for (int i = 0; i < 9; ++i) { 35 if (x & mask[i]) 36 cnt++; 37 } 38 std::cout << cnt << std::endl; 39 } 40 41 // 各値が2以上の要素数 42 { 43 int cnt = 0; 44 for (int i = 0; i < 9; ++i) { 45 if (x & ge2[i]) 46 cnt++; 47 } 48 std::cout << cnt << std::endl; 49 } 50 51 // 各値が2の要素数 52 { 53 int cnt = 0; 54 for (int i = 0; i < 9; ++i) { 55 if (((x >> i * 3) & 0b111) == 2) 56 cnt++; 57 } 58 std::cout << cnt << std::endl; 59 } 60}
Python
python
1from functools import reduce 2 3 4def to_bit(key): 5 """ビット表現に変換する。 6 """ 7 return reduce(lambda x, y: x * 8 + y, key[::-1]) 8 9 10def to_str(x): 11 """ビット表現を表す文字列に変換する。(デバッグ用) 12 """ 13 s = f"{x:027b}" 14 s = "|".join([s[i : i + 3] for i in range(0, len(s), 3)]) 15 return s 16 17########### 以下が本題 18 19key = to_bit([0, 2, 0, 2, 2, 1, 1, 1, 4]) # 実際は元の配列はなく、ビット列だけ与えられます 20print(key, to_str(key)) 21# 69510160 100|001|001|001|010|010|000|010|000 22 23 24# 各値の合計 25cnt = 0 26for i in range(9): 27 cnt += (key >> i * 3) & 0b111 28print(cnt) # 13 29 30 31# 各値が1以上の要素数 32cnt = 0 33for i in range(9): 34 val = (key >> i * 3) & 0b111 35 if val: 36 cnt += 1 37print(cnt) # 7 38 39 40# 各値が2以上の要素数 41cnt = 0 42for i in range(9): 43 val = (key >> i * 3) & 0b111 44 if val >= 2: 45 cnt += 1 46print(cnt) # 4 47 48 49# 各値が2の要素数 50cnt = 0 51for i in range(9): 52 val = (key >> i * 3) & 0b111 53 if val == 2: 54 cnt += 1 55print(cnt) # 3
回答8件
あなたの回答
tips
プレビュー