プログラマーの徒然ブログ

プログラミングに関することをはじめ、興味がでたものを雑多に!

【numba】サクッとPythonを高速化

www.python.jp

Deep Learningも流行ってきて、Pythonでプログラムを作成する人も多いんじゃないでしょうか?

Pythonは簡単に書けるので、私もちょっとしたデータ整形など色んな所で使っています。

そんなPythonですが、時々気になってしまうのが処理の遅さです。 C/C++などのコンパイル言語と比べると、処理時間が長くなりがちです。

そんなPythonの処理を簡単に高速化できるかもしれないnumbaというモジュールを紹介します。

最初に結論:何をすればいいの?

1: numbaのインストール

pip install numba

2: from numba import jit@jitの追加

from numba import jit

@jit
def func(x, y):

以上、これで高速化できない場合は、numbaが扱えるようにコードを修正していく必要があるかも。。

numbaとは?

Python仮想マシンコードを取得し、LLVMコンパイラが扱えるようにLLVM-IRを生成、LLVMコンパイラで動作マシン用のネイティブコードにするようです。

そのため、numbaのデコレータ@jitを付与したPythonコードを実行すると、初回にコンパイルが行われます。 コンパイル済みの関数が実行されるようになるので、重い処理や何回も呼ばれる処理なのでは、高速化の恩恵を受けやすいです。

メリット

  • デコレータの追加だけで手軽に高速化できる
  • 事前コンパイル不要で、これまで通り実行可能

デメリット

サンプルコードは、配列の総和を計算するコードです。

import time
import numpy as np

def sum_reduction(arr):
    sum = 0.
    for i in range(len(arr)):
        sum += arr[i]
    return sum

N = 10000000
arr = np.ones(N)

t1 = time.time()
sum = sum_reduction(arr)
t2 = time.time() - t1

print('Result: ', sum)
print('Time: ', t2, ' sec')

結果

Result:  10000000.0
Time:  1.4054384231567383  sec

numbaを適用する。

import time
import numpy as np
from numba import jit

@jit
def sum_reduction(arr):
    sum = 0.
    for i in range(len(arr)):
        sum += arr[i]
    return sum

N = 10000000
arr = np.ones(N)

t1 = time.time()
sum = sum_reduction(arr)
t2 = time.time() - t1

print('Result: ', sum)
print('Time: ', t2, ' sec')

結果

Result:  10000000.0
Time:  0.08794617652893066  sec

numbaのデコレータを付与するだけで、サクッと16倍も高速になりました。

速くならなかったら?

Objectモードが適用されているかもしれない。

今のnumbaでは、No pythonモードでコンパイルし、コンパイルに失敗するとObjectモードでコンパイルされます。

Objectモードは、No pythonモードと異なり型推定に失敗した部分等はPythonで処理されるため、かえって遅くなる場合もある。

将来の仕様では、Objectモードはオプションとなるようである。

現状、No pythonモードを強制するためには@jit(nopython=True)とする必要がある。 このように設定すると、型推定が失敗するとコンパイルエラーが出力されるようになる。

エラーが出てきた場合は、関数の中で重い処理のみを別の関数として切り出し、numbaを適用するなどの工夫をする必要がある。

まとめ

numbaを利用すると、比較的簡単にPythonのプログラムを高速化できる。

ただし、高速化できない場合も往々にある。 そのときは、コンパイルエラーに従ってコードを粛々と修正する。

たまに、遅いなと感じたら、numbaを使ってみてはいかがでしょうか?

より詳細な情報は以下の公式ページをご参照ください。

numba.pydata.org