中学理科1626207 views
高校日本史189857 views
いろは2986023 views
高校物理158224 views
小学社会308636 views
英語607877 views
高校生物549842 views
りんご192546 views
小学理科717236 views
中学社会667106 views
Help
Tools

English

連立方程式をガウス消去法で解く|Python

ガウス消去法は連立一次方程式を解く最も基本的なアルゴリズムだ。行列を上三角行列に変形し、後退代入で解を求める。数値計算の基礎として理解しておくべき手法である。

連立方程式と行列表現

連立方程式 を考える。例として次の 3 元連立方程式を解く。

これを拡大係数行列として表現する。

# 拡大係数行列 [A | b]
Ab = [
    [2, 1, -1, 8],
    [-3, -1, 2, -11],
    [-2, 1, 2, -3]
]

前進消去

前進消去では、ピボット(対角成分)を使って下の行の要素を消していく。

def forward_elimination(Ab):
    """前進消去:上三角行列を作る"""
    n = len(Ab)
    
    for i in range(n):
        # ピボットが 0 の場合は行を交換(後述)
        if Ab[i][i] == 0:
            for k in range(i + 1, n):
                if Ab[k][i] != 0:
                    Ab[i], Ab[k] = Ab[k], Ab[i]
                    break
        
        # i 行目より下の行を消去
        for j in range(i + 1, n):
            if Ab[i][i] == 0:
                continue
            factor = Ab[j][i] / Ab[i][i]
            for k in range(i, len(Ab[0])):
                Ab[j][k] -= factor * Ab[i][k]
    
    return Ab

Ab = [
    [2, 1, -1, 8],
    [-3, -1, 2, -11],
    [-2, 1, 2, -3]
]

Ab = forward_elimination(Ab)
for row in Ab:
    print([f'{x:.2f}' for x in row])
# ['2.00', '1.00', '-1.00', '8.00']
# ['0.00', '0.50', '0.50', '1.00']
# ['0.00', '0.00', '-1.00', '1.00']

上三角行列ができた。対角線より下がすべて 0 になっている。

後退代入

上三角行列ができたら、下から順に解を求めていく。

def back_substitution(Ab):
    """後退代入:上三角行列から解を求める"""
    n = len(Ab)
    x = [0] * n
    
    for i in range(n - 1, -1, -1):
        x[i] = Ab[i][n]  # 右辺
        for j in range(i + 1, n):
            x[i] -= Ab[i][j] * x[j]
        x[i] /= Ab[i][i]
    
    return x

x = back_substitution(Ab)
print(x)  # [2.0, 3.0, -1.0]

解は となる。

まとめた実装

前進消去と後退代入をまとめて、連立方程式を解く関数にする。

def gauss_elimination(A, b):
    """ガウス消去法で Ax = b を解く"""
    n = len(A)
    
    # 拡大係数行列を作成(元の行列を変更しないようコピー)
    Ab = [A[i][:] + [b[i]] for i in range(n)]
    
    # 前進消去
    for i in range(n):
        # ピボット選択(部分選択)
        max_row = i
        for k in range(i + 1, n):
            if abs(Ab[k][i]) > abs(Ab[max_row][i]):
                max_row = k
        Ab[i], Ab[max_row] = Ab[max_row], Ab[i]
        
        if Ab[i][i] == 0:
            raise ValueError("解が一意に定まらない")
        
        # 消去
        for j in range(i + 1, n):
            factor = Ab[j][i] / Ab[i][i]
            for k in range(i, n + 1):
                Ab[j][k] -= factor * Ab[i][k]
    
    # 後退代入
    x = [0] * n
    for i in range(n - 1, -1, -1):
        x[i] = Ab[i][n]
        for j in range(i + 1, n):
            x[i] -= Ab[i][j] * x[j]
        x[i] /= Ab[i][i]
    
    return x

A = [[2, 1, -1], [-3, -1, 2], [-2, 1, 2]]
b = [8, -11, -3]

x = gauss_elimination(A, b)
print(x)  # [2.0, 3.0, -1.0]

ピボット選択の重要性

単純なガウス消去法では、ピボットが 0 または非常に小さいと計算が不安定になる。部分ピボット選択では、各ステップで絶対値が最大の要素をピボットに選ぶ。

ピボット選択なし

ピボットが 0 だと破綻。小さいピボットで桁落ちが発生。

部分ピボット選択

絶対値最大の要素をピボットにする。数値的に安定。

上の実装では部分ピボット選択を行っている。

計算量

ガウス消去法の計算量は だ。前進消去で約 回、後退代入で約 回の演算を行う。

import time

def measure_time(n):
    """n×n の連立方程式を解く時間を測定"""
    import random
    A = [[random.random() for _ in range(n)] for _ in range(n)]
    b = [random.random() for _ in range(n)]
    
    start = time.time()
    gauss_elimination(A, b)
    return time.time() - start

for n in [100, 200, 400]:
    t = measure_time(n)
    print(f'n={n}: {t:.3f} 秒')

を 2 倍にすると時間は約 8 倍になる()。

LU 分解との関係

ガウス消去法は実質的に LU 分解を行っている。 と分解すると、 は消去に使った係数を記録した下三角行列、 は上三角行列になる。

def lu_decomposition(A):
    """LU 分解(ピボット選択なし)"""
    n = len(A)
    L = [[0] * n for _ in range(n)]
    U = [row[:] for row in A]  # A のコピー
    
    for i in range(n):
        L[i][i] = 1
        for j in range(i + 1, n):
            factor = U[j][i] / U[i][i]
            L[j][i] = factor
            for k in range(i, n):
                U[j][k] -= factor * U[i][k]
    
    return L, U

A = [[2, 1, -1], [-3, -1, 2], [-2, 1, 2]]
L, U = lu_decomposition(A)

print("L:")
for row in L:
    print([f'{x:6.2f}' for x in row])

print("U:")
for row in U:
    print([f'{x:6.2f}' for x in row])

同じ に対して複数の を解く場合、LU 分解を一度行えば、各 に対して で解ける。

実務では NumPy / SciPy

実務では numpy.linalg.solvescipy.linalg.solve を使う。高度に最適化されており、LAPACK を内部で呼び出す。

import numpy as np

A = np.array([[2, 1, -1], [-3, -1, 2], [-2, 1, 2]], dtype=float)
b = np.array([8, -11, -3], dtype=float)

x = np.linalg.solve(A, b)
print(x)  # [ 2.  3. -1.]

純粋な Python 実装はアルゴリズムの理解に役立ち、NumPy が何をしているかを知る基礎になる。