Python の Barrier で複数スレッドを同期する

Barrier(バリア)は、複数のスレッドが特定のポイントに全員到達するまで待機させるための同期機構です。全員が揃ったら一斉に次の処理に進みます。

Barrier の基本

Barrier は指定した数のスレッドが wait() を呼ぶまで、すべてのスレッドをブロックします。

import threading
import time
import random

barrier = threading.Barrier(3)  # 3つのスレッドを同期

def worker(name):
    print(f"{name}: 準備中...")
    time.sleep(random.uniform(1, 3))  # 準備時間はバラバラ
    print(f"{name}: 準備完了、他を待機")
    barrier.wait()  # 全員が揃うまで待機
    print(f"{name}: 開始!")

threads = [threading.Thread(target=worker, args=(f"Worker-{i}",)) for i in range(3)]
for t in threads:
    t.start()
for t in threads:
    t.join()

準備完了の順番はバラバラでも、「開始!」は全員揃ってから同時に表示されます。

Barrier の動作

各スレッドが barrier.wait() を呼ぶ

指定数のスレッドが wait() するまで全員ブロック

全員揃ったら一斉に解放

バリアは自動的にリセットされ再利用可能

wait() の戻り値

wait() は0からN-1までの整数を返します。そのうち1つのスレッドだけが特別な処理を行う場合に使えます。

import threading

barrier = threading.Barrier(3)

def worker(name):
    index = barrier.wait()
    if index == 0:
        print(f"{name}: リーダーとして処理を実行")
    else:
        print(f"{name}: 通常の処理")

threads = [threading.Thread(target=worker, args=(f"Worker-{i}",)) for i in range(3)]
for t in threads:
    t.start()
for t in threads:
    t.join()

action パラメータ

バリアが解放される直前に実行される関数を指定できます。

import threading

def on_barrier_release():
    print("=== 全員集合!処理開始 ===")

barrier = threading.Barrier(3, action=on_barrier_release)

def worker(name):
    print(f"{name}: 待機中")
    barrier.wait()
    print(f"{name}: 処理実行")

threads = [threading.Thread(target=worker, args=(f"Worker-{i}",)) for i in range(3)]
for t in threads:
    t.start()
for t in threads:
    t.join()

action は最後にバリアに到達したスレッドによって実行されます。

タイムアウト

wait() にタイムアウトを指定できます。タイムアウトすると BrokenBarrierError が発生します。

import threading
import time

barrier = threading.Barrier(3)

def fast_worker():
    barrier.wait(timeout=2)
    print("Fast worker: 完了")

def slow_worker():
    time.sleep(5)  # 遅い
    barrier.wait()

try:
    t1 = threading.Thread(target=fast_worker)
    t2 = threading.Thread(target=fast_worker)
    t3 = threading.Thread(target=slow_worker)
    
    for t in [t1, t2, t3]:
        t.start()
    for t in [t1, t2, t3]:
        t.join()
except threading.BrokenBarrierError:
    print("バリアがタイムアウトしました")

abort() でバリアを破壊

abort() を呼ぶと、待機中のすべてのスレッドで BrokenBarrierError が発生します。

import threading
import time

barrier = threading.Barrier(3)

def worker(name):
    try:
        print(f"{name}: 待機中")
        barrier.wait()
        print(f"{name}: 完了")
    except threading.BrokenBarrierError:
        print(f"{name}: バリアが破壊されました")

def aborter():
    time.sleep(1)
    print("バリアを破壊")
    barrier.abort()

threads = [threading.Thread(target=worker, args=(f"Worker-{i}",)) for i in range(2)]
threads.append(threading.Thread(target=aborter))

for t in threads:
    t.start()
for t in threads:
    t.join()

実用例:並列計算のフェーズ同期

複数のフェーズがある並列計算で、各フェーズの終了を同期する例です。

import threading
import time
import random

barrier = threading.Barrier(3)

def compute_phase(name):
    for phase in range(3):
        # 計算フェーズ
        work_time = random.uniform(0.5, 1.5)
        print(f"{name}: Phase {phase} 計算中 ({work_time:.1f}秒)")
        time.sleep(work_time)
        
        # 全員の計算完了を待つ
        barrier.wait()
        print(f"{name}: Phase {phase} 完了")

threads = [threading.Thread(target=compute_phase, args=(f"Node-{i}",)) for i in range(3)]
for t in threads:
    t.start()
for t in threads:
    t.join()

Barrier は、並列処理で「全員揃ってから次へ」という同期が必要な場面で役立ちます。