Python の map と starmap でプロセスに処理を分散する

Python の multiprocessing.Pool には、イテラブルの各要素に関数を並列適用する mapstarmap メソッドがあります。これらを使うと、データを複数のプロセスに効率的に分散処理できます。

map の基本

map() はイテラブルの各要素に関数を適用し、結果をリストで返します。

from multiprocessing import Pool

def square(x):
    return x ** 2

if __name__ == "__main__":
    with Pool(4) as pool:
        results = pool.map(square, range(10))
    print(results)  # [0, 1, 4, 9, 16, 25, 36, 49, 64, 81]

組み込みの map() と同じ感覚で使えますが、処理が並列に実行されます。

複数の引数を渡す場合の問題

map() は 1 つの引数しか受け取れません。複数の引数を渡すには工夫が必要です。

from multiprocessing import Pool

def add(x, y):
    return x + y

if __name__ == "__main__":
    with Pool() as pool:
        # これはエラーになる
        # results = pool.map(add, [1, 2, 3], [10, 20, 30])
        pass

starmap で複数引数を渡す

starmap() は引数のタプルを展開して関数に渡します。

from multiprocessing import Pool

def add(x, y):
    return x + y

if __name__ == "__main__":
    with Pool(4) as pool:
        args = [(1, 10), (2, 20), (3, 30)]
        results = pool.starmap(add, args)
    print(results)  # [11, 22, 33]

各タプルがアンパック(展開)されて、add(1, 10)add(2, 20) のように呼び出されます。

map(func, iterable)

func(item) の形で呼び出す。引数は 1 つのみ。

starmap(func, iterable)

func(*item) の形で呼び出す。タプルを展開して複数引数を渡せる。

chunksize でパフォーマンスを調整する

大量のデータを処理する場合、chunksize を指定するとオーバーヘッドを減らせます。

from multiprocessing import Pool

def process(x):
    return x ** 2

if __name__ == "__main__":
    data = range(100000)
    
    with Pool(4) as pool:
        # chunksize を指定
        results = pool.map(process, data, chunksize=1000)

chunksize は、一度にワーカーに送るタスクの数です。データ量が多い場合は、大きめの値を指定すると効率的です。

imap と imap_unordered

imap()map() の遅延評価版で、イテレータを返します。

from multiprocessing import Pool

def slow_square(x):
    return x ** 2

if __name__ == "__main__":
    with Pool(4) as pool:
        # 結果を順番に取得
        for result in pool.imap(slow_square, range(10)):
            print(result)

imap_unordered() は、完了した順に結果を返します。順序を気にしない場合は高速です。

from multiprocessing import Pool
import time

def variable_task(x):
    time.sleep(x % 3)  # 処理時間がバラバラ
    return x

if __name__ == "__main__":
    with Pool(4) as pool:
        # 完了した順に取得(順序不定)
        for result in pool.imap_unordered(variable_task, range(10)):
            print(result)

starmap の遅延評価版はない

starmap() には imap 相当の遅延評価版がありません。代わりに imap() と組み合わせて使います。

from multiprocessing import Pool
from itertools import starmap

def add(x, y):
    return x + y

if __name__ == "__main__":
    with Pool(4) as pool:
        args = [(1, 10), (2, 20), (3, 30)]
        
        # ラッパー関数を使う
        results = pool.imap(lambda a: add(*a), args)
        print(list(results))

ただし、ラムダ式は pickle できないため、この方法は動作しません。代わりにモジュールレベルの関数を定義します。

from multiprocessing import Pool

def add(x, y):
    return x + y

def add_wrapper(args):
    return add(*args)

if __name__ == "__main__":
    with Pool(4) as pool:
        args = [(1, 10), (2, 20), (3, 30)]
        results = list(pool.imap(add_wrapper, args))
    print(results)  # [11, 22, 33]

この方法なら遅延評価で複数引数の関数を並列実行できます。