Python でオブジェクトを比較するとき、==、<、> などの演算子は内部で「リッチ比較メソッド」を呼び出す。これらのメソッドをカスタマイズすることで、自作クラスのインスタンス同士を比較可能にできる。この記事では、リッチ比較メソッドの実装方法と、よくある落とし穴を解説する。
リッチ比較メソッド一覧
Python には 6 つのリッチ比較メソッドがある。
| 演算子 | メソッド | 意味 |
|---|---|---|
| == | __eq__ | 等しい |
| != | __ne__ | 等しくない |
| < | __lt__ | より小さい |
その他として __le__(以下)、__gt__(より大きい)、__ge__(以上)がある。
基本的な実装
バージョン番号を表すクラスを例に、リッチ比較メソッドを実装してみよう。
class Version:
def __init__(self, major, minor, patch):
self.major = major
self.minor = minor
self.patch = patch
def __repr__(self):
return f"Version({self.major}, {self.minor}, {self.patch})"
def _as_tuple(self):
return (self.major, self.minor, self.patch)
def __eq__(self, other):
if not isinstance(other, Version):
return NotImplemented
return self._as_tuple() == other._as_tuple()
def __lt__(self, other):
if not isinstance(other, Version):
return NotImplemented
return self._as_tuple() < other._as_tuple()
def __le__(self, other):
if not isinstance(other, Version):
return NotImplemented
return self._as_tuple() <= other._as_tuple()
def __gt__(self, other):
if not isinstance(other, Version):
return NotImplemented
return self._as_tuple() > other._as_tuple()
def __ge__(self, other):
if not isinstance(other, Version):
return NotImplemented
return self._as_tuple() >= other._as_tuple()
# 使用例
v1 = Version(1, 0, 0)
v2 = Version(2, 0, 0)
v3 = Version(1, 0, 0)
print(v1 < v2) # True
print(v1 == v3) # True
print(v2 >= v1) # True
NotImplemented の重要性
比較できない型が渡されたとき、False ではなく NotImplemented を返すことが重要だ。
class Money:
def __init__(self, amount):
self.amount = amount
def __eq__(self, other):
if not isinstance(other, Money):
return NotImplemented # False ではない!
return self.amount == other.amount
m = Money(100)
# NotImplemented を返すと、Python は other.__eq__(self) を試す
print(m == "100") # False(str.__eq__ が呼ばれて False)
print(m == 100) # False(int.__eq__ が呼ばれて False)
NotImplemented を返すと、Python は反対側のオブジェクトの比較メソッドを試す。これにより、異なる型同士の比較を柔軟に処理できる。
functools.total_ordering の活用
6 つのメソッドをすべて書くのは面倒だ。functools.total_ordering デコレータを使えば、__eq__ と順序比較メソッド 1 つ(__lt__ など)を定義するだけで残りが自動生成される。
from functools import total_ordering
@total_ordering
class Version:
def __init__(self, major, minor, patch):
self.major = major
self.minor = minor
self.patch = patch
def _as_tuple(self):
return (self.major, self.minor, self.patch)
def __eq__(self, other):
if not isinstance(other, Version):
return NotImplemented
return self._as_tuple() == other._as_tuple()
def __lt__(self, other):
if not isinstance(other, Version):
return NotImplemented
return self._as_tuple() < other._as_tuple()
# __le__, __gt__, __ge__ は自動生成される
v1 = Version(1, 0, 0)
v2 = Version(2, 0, 0)
print(v1 <= v2) # True(自動生成された __le__)
print(v1 > v2) # False(自動生成された __gt__)
ただし、total_ordering で生成されるメソッドは、手動実装より若干遅い。パフォーマンスが重要な場合は全メソッドを手動で実装する。
落とし穴 1:eq を定義すると hash が None になる
__eq__ を定義すると、デフォルトでは __hash__ が None に設定され、オブジェクトがハッシュ不可能になる。
class Point:
def __init__(self, x, y):
self.x = x
self.y = y
def __eq__(self, other):
if not isinstance(other, Point):
return NotImplemented
return self.x == other.x and self.y == other.y
p = Point(1, 2)
# hash(p) # TypeError: unhashable type: 'Point'
# set や dict のキーに使えない
# {p} # TypeError
ハッシュ可能にするには、__hash__ も定義する必要がある。
class Point:
def __init__(self, x, y):
self.x = x
self.y = y
def __eq__(self, other):
if not isinstance(other, Point):
return NotImplemented
return self.x == other.x and self.y == other.y
def __hash__(self):
return hash((self.x, self.y))
p1 = Point(1, 2)
p2 = Point(1, 2)
print(hash(p1) == hash(p2)) # True
print(len({p1, p2})) # 1(重複排除される)
ミュータブルなオブジェクトに __hash__ を定義するのは危険だ。属性が変更されるとハッシュ値が変わり、辞書や集合が壊れる可能性がある。
落とし穴 2:比較の非対称性
左辺と右辺で異なる型を比較する場合、呼ばれるメソッドが変わる。
class A:
def __eq__(self, other):
print("A.__eq__ called")
return True
class B:
def __eq__(self, other):
print("B.__eq__ called")
return False
a = A()
b = B()
print(a == b) # A.__eq__ called → True
print(b == a) # B.__eq__ called → False
サブクラスがある場合は、サブクラスのメソッドが優先される。
class Parent:
def __eq__(self, other):
print("Parent.__eq__")
return True
class Child(Parent):
def __eq__(self, other):
print("Child.__eq__")
return False
p = Parent()
c = Child()
print(p == c) # Child.__eq__(サブクラスが優先)
print(c == p) # Child.__eq__
落とし穴 3:循環参照での無限ループ
比較メソッド内で他のメソッドを呼ぶ場合、無限ループに注意。
# 悪い例:無限ループ
class BadCompare:
def __init__(self, value):
self.value = value
def __eq__(self, other):
return not (self != other) # __ne__ を呼ぶ
def __ne__(self, other):
return not (self == other) # __eq__ を呼ぶ → 無限ループ!
# a = BadCompare(1)
# b = BadCompare(1)
# a == b # RecursionError
Python 3 では __ne__ は __eq__ の否定として自動生成されるので、通常は __ne__ を手動定義する必要はない。
実践的な設計指針
リッチ比較メソッドを実装する際の指針をまとめる。
NotImplemented を適切に返す。__eq__ を定義したら __hash__ も検討する。ミュータブルなオブジェクトはハッシュ不可にする。
functools.total_ordering で実装を簡略化。タプル化して組み込み型の比較を活用。型チェックは isinstance で行う。
比較可能なクラスを設計する際は、これらの落とし穴を意識して実装しよう。