プログラマーの徒然ブログ

プログラミングに関することをはじめ、興味がでたものを雑多に!

【Counting 1-bits】高速にビット列の1を数える!

外出自粛中に、急に気になった"ビット列の1を数える"アルゴリズムの忘却録!

ビットカウント:ビット列の1を数えるとは?

population countなんかともいわれる処理で、やりたいことは簡単!

整数を2進数としてみたときに、1の出現回数を数えること!

(本当のことを書くと、レジスタに格納された64-bitだったり、32-bitだったりするビット列の中に1が出現する回数を得る処理かな、、)

例:8までの整数のビットカウント

10進数表現 2進数表現 population count結果
0 0000 0
1 0001 1
2 0010 1
3 0011 2
4 0100 1
5 0101 2
6 0110 2
7 0111 3
8 1000 1

最初に、結論

手書きするなら、以下のコードが速い!

(下記コードはPythonで書いたもの)

def bit_count( number ):
    number = (number & 0x55555555) + (number >> 1 & 0x55555555)
    number = (number & 0x33333333) + (number >> 2 & 0x33333333)
    number = (number & 0x0f0f0f0f) + (number >> 4 & 0x0f0f0f0f)
    number = (number & 0x00ff00ff) + (number >> 8 & 0x00ff00ff)
    return (number & 0x0000ffff) + (number >>16 & 0x0000ffff)

ビット演算で、分割統治法的な処理で、固定回数で32-bitの整数のビットカウントができ、高速!

複数アルゴリズムの速度比較

簡単に思いつくアルゴリズム~結論のアルゴリズムまで、3種類を比較

Straightforward method

まず、思いつくのは以下のコードでしょう。

単純に

  1. 最下位ビットが1である場合、カウントアップ
  2. 入力値を1ビット右シフト

を入力値が0になるまでループする。

count = 0
while bit_string != 0:
    if bit_string&1 :
        count += 1
    bit_string = bit_string >> 1

Improved straightforward method

2つ目は、1つ目の処理を改善した方法である。

1つ目の方法では、1000000000011のように間に0が続く場合でも、 "1"なのか"0"なのか1つずつ確認してしまう。

2つ目の下記コードで示す方法では、連続で"0"が続く場合、一気にスキップできるため、1つ目の方法より高速になる。

しかし、1111111のように、"1"が続く入力だとあまり速度が変わらない。

count = 0
while bit_string != 0:
    count += 1
    bit_string = bit_string & (bit_string - 1)

Optimized method

ビット演算でカウントするアルゴリズム

ループを使わず、どんな入力でも固定回数の操作でカウントできる!

画期的なアルゴリズムです。

工学部の方とかで、ビット演算を知っている方は、図示してみるとアハ体験になると思います。

def bit_count_numba( number ):
    number = (number & 0x55555555) + (number >> 1 & 0x55555555)
    number = (number & 0x33333333) + (number >> 2 & 0x33333333)
    number = (number & 0x0f0f0f0f) + (number >> 4 & 0x0f0f0f0f)
    number = (number & 0x00ff00ff) + (number >> 8 & 0x00ff00ff)
    return (number & 0x0000ffff) + (number >>16 & 0x0000ffff)

速度比較

それぞれのアルゴリズムPythonで速度比較します。

ついでに、3つ目の方法を関数にしてnumba化してさらに高速化します! (最初のコンパイル時間を省略するために計測外で一回実行するチートをしています。。。)

numbaについては、以下の記事をどうぞ!

t49m1.hatenablog.com

  • 計測環境
    • ASUS ZenBook UX390U (ノートPC)
      • Intel(R) Core(TM) i5-7200U CPU @ 2.50GHz
  • 計測プログラム
import time
from numba import njit

def bit_count( number ):
    number = (number & 0x55555555) + (number >> 1 & 0x55555555)
    number = (number & 0x33333333) + (number >> 2 & 0x33333333)
    number = (number & 0x0f0f0f0f) + (number >> 4 & 0x0f0f0f0f)
    number = (number & 0x00ff00ff) + (number >> 8 & 0x00ff00ff)
    return (number & 0x0000ffff) + (number >>16 & 0x0000ffff)

@njit
def bit_count_numba( number ):
    number = (number & 0x55555555) + (number >> 1 & 0x55555555)
    number = (number & 0x33333333) + (number >> 2 & 0x33333333)
    number = (number & 0x0f0f0f0f) + (number >> 4 & 0x0f0f0f0f)
    number = (number & 0x00ff00ff) + (number >> 8 & 0x00ff00ff)
    return (number & 0x0000ffff) + (number >>16 & 0x0000ffff)


N = 1048576

st = time.time()
count = 0
for bit_string in range(N):
    while bit_string != 0:
        if bit_string&1 :
            count += 1
        bit_string = bit_string >> 1

elapsed_time = time.time() - st

print('Count: ', count)
print('Elapsed time[sec] (Straghtforward) : ', elapsed_time)


st = time.time()
count = 0
for bit_string in range(N):
    while bit_string != 0:
        count += 1
        bit_string = bit_string & (bit_string-1)

elapsed_time = time.time() - st

print('Count: ', count)
print('Elapsed time[sec] (Improved straghtforward) : ', elapsed_time)

st = time.time()
count = 0
for bit_string in range(N):
    count += bit_count( bit_string )

elapsed_time = time.time() - st

print('Count: ', count)
print('Elapsed time[sec] (Optimized) : ', elapsed_time)

st = time.time()
count = 0
for bit_string in range(N):
    while bit_string != 0:
        if bit_string&1 :
            count += 1
        bit_string = bit_string >> 1

elapsed_time = time.time() - st

print('Elapsed time[sec] (Straghtforward) : ', elapsed_time)

st = time.time()
count = 0
for bit_string in range(N):
    while bit_string != 0:
        count += 1
        bit_string = bit_string & (bit_string-1)

elapsed_time = time.time() - st

print('Elapsed time[sec] (Improved straghtforward) : ', elapsed_time)

st = time.time()
count = 0
for bit_string in range(N):
    count += bit_count( bit_string )

elapsed_time = time.time() - st

print('Elapsed time[sec] (Optimized) : ', elapsed_time)

bit_count_numba( 0 )
st = time.time()
count = 0
for bit_string in range(N):
    count += bit_count_numba( bit_string )

elapsed_time = time.time() - st

print('Elapsed time[sec] (Optimized + numba) : ', elapsed_time)
結果
アルゴリズム 処理時間[sec]
Straightforward method 4.581
Improved straightforward method 2.794
Optimized method method 1.030
Optimized method + numba 0.341

やはり、固定回数のビット演算で行えるだけあって、3つ目の方法が一番速いですね!

そして、Pythonは遅いですね。

numbaを使うことで3倍速くなりました。 つまり、ほかの方法もnumbaを使うことで1秒を切れそうですね。

今回は、Pythonのループ処理とかを速くする話ではないのでこの辺で!