【Counting 1-bits】高速にビット列の1を数える!
外出自粛中に、急に気になった"ビット列の1を数える"アルゴリズムの忘却録!
ビットカウント:ビット列の1を数えるとは?
population countなんかともいわれる処理で、やりたいことは簡単!
整数を2進数としてみたときに、1の出現回数を数えること!
(本当のことを書くと、レジスタに格納された64-bitだったり、32-bitだったりするビット列の中に1が出現する回数を得る処理かな、、)
例:8までの整数のビットカウント
10進数表現 | 2進数表現 | population count結果 |
---|---|---|
0 | 0000 | 0 |
1 | 0001 | 1 |
2 | 0010 | 1 |
3 | 0011 | 2 |
4 | 0100 | 1 |
5 | 0101 | 2 |
6 | 0110 | 2 |
7 | 0111 | 3 |
8 | 1000 | 1 |
最初に、結論
手書きするなら、以下のコードが速い!
(下記コードはPythonで書いたもの)
def bit_count( number ): number = (number & 0x55555555) + (number >> 1 & 0x55555555) number = (number & 0x33333333) + (number >> 2 & 0x33333333) number = (number & 0x0f0f0f0f) + (number >> 4 & 0x0f0f0f0f) number = (number & 0x00ff00ff) + (number >> 8 & 0x00ff00ff) return (number & 0x0000ffff) + (number >>16 & 0x0000ffff)
ビット演算で、分割統治法的な処理で、固定回数で32-bitの整数のビットカウントができ、高速!
複数アルゴリズムの速度比較
簡単に思いつくアルゴリズム~結論のアルゴリズムまで、3種類を比較
Straightforward method
まず、思いつくのは以下のコードでしょう。
単純に
- 最下位ビットが1である場合、カウントアップ
- 入力値を1ビット右シフト
を入力値が0になるまでループする。
count = 0 while bit_string != 0: if bit_string&1 : count += 1 bit_string = bit_string >> 1
Improved straightforward method
2つ目は、1つ目の処理を改善した方法である。
1つ目の方法では、1000000000011
のように間に0が続く場合でも、
"1"なのか"0"なのか1つずつ確認してしまう。
2つ目の下記コードで示す方法では、連続で"0"が続く場合、一気にスキップできるため、1つ目の方法より高速になる。
しかし、1111111
のように、"1"が続く入力だとあまり速度が変わらない。
count = 0 while bit_string != 0: count += 1 bit_string = bit_string & (bit_string - 1)
Optimized method
ビット演算でカウントするアルゴリズム!
ループを使わず、どんな入力でも固定回数の操作でカウントできる!
画期的なアルゴリズムです。
工学部の方とかで、ビット演算を知っている方は、図示してみるとアハ体験になると思います。
def bit_count_numba( number ): number = (number & 0x55555555) + (number >> 1 & 0x55555555) number = (number & 0x33333333) + (number >> 2 & 0x33333333) number = (number & 0x0f0f0f0f) + (number >> 4 & 0x0f0f0f0f) number = (number & 0x00ff00ff) + (number >> 8 & 0x00ff00ff) return (number & 0x0000ffff) + (number >>16 & 0x0000ffff)
速度比較
ついでに、3つ目の方法を関数にしてnumba化してさらに高速化します! (最初のコンパイル時間を省略するために計測外で一回実行するチートをしています。。。)
numbaについては、以下の記事をどうぞ!
import time from numba import njit def bit_count( number ): number = (number & 0x55555555) + (number >> 1 & 0x55555555) number = (number & 0x33333333) + (number >> 2 & 0x33333333) number = (number & 0x0f0f0f0f) + (number >> 4 & 0x0f0f0f0f) number = (number & 0x00ff00ff) + (number >> 8 & 0x00ff00ff) return (number & 0x0000ffff) + (number >>16 & 0x0000ffff) @njit def bit_count_numba( number ): number = (number & 0x55555555) + (number >> 1 & 0x55555555) number = (number & 0x33333333) + (number >> 2 & 0x33333333) number = (number & 0x0f0f0f0f) + (number >> 4 & 0x0f0f0f0f) number = (number & 0x00ff00ff) + (number >> 8 & 0x00ff00ff) return (number & 0x0000ffff) + (number >>16 & 0x0000ffff) N = 1048576 st = time.time() count = 0 for bit_string in range(N): while bit_string != 0: if bit_string&1 : count += 1 bit_string = bit_string >> 1 elapsed_time = time.time() - st print('Count: ', count) print('Elapsed time[sec] (Straghtforward) : ', elapsed_time) st = time.time() count = 0 for bit_string in range(N): while bit_string != 0: count += 1 bit_string = bit_string & (bit_string-1) elapsed_time = time.time() - st print('Count: ', count) print('Elapsed time[sec] (Improved straghtforward) : ', elapsed_time) st = time.time() count = 0 for bit_string in range(N): count += bit_count( bit_string ) elapsed_time = time.time() - st print('Count: ', count) print('Elapsed time[sec] (Optimized) : ', elapsed_time) st = time.time() count = 0 for bit_string in range(N): while bit_string != 0: if bit_string&1 : count += 1 bit_string = bit_string >> 1 elapsed_time = time.time() - st print('Elapsed time[sec] (Straghtforward) : ', elapsed_time) st = time.time() count = 0 for bit_string in range(N): while bit_string != 0: count += 1 bit_string = bit_string & (bit_string-1) elapsed_time = time.time() - st print('Elapsed time[sec] (Improved straghtforward) : ', elapsed_time) st = time.time() count = 0 for bit_string in range(N): count += bit_count( bit_string ) elapsed_time = time.time() - st print('Elapsed time[sec] (Optimized) : ', elapsed_time) bit_count_numba( 0 ) st = time.time() count = 0 for bit_string in range(N): count += bit_count_numba( bit_string ) elapsed_time = time.time() - st print('Elapsed time[sec] (Optimized + numba) : ', elapsed_time)
結果
アルゴリズム | 処理時間[sec] |
---|---|
Straightforward method | 4.581 |
Improved straightforward method | 2.794 |
Optimized method method | 1.030 |
Optimized method + numba | 0.341 |
やはり、固定回数のビット演算で行えるだけあって、3つ目の方法が一番速いですね!
そして、Pythonは遅いですね。
numbaを使うことで3倍速くなりました。 つまり、ほかの方法もnumbaを使うことで1秒を切れそうですね。
今回は、Pythonのループ処理とかを速くする話ではないのでこの辺で!