Python の集合にカスタムオブジェクトを入れる(__hash__ と __eq__)
カスタムクラスのインスタンスを集合に入れるには、__hash__() と __eq__() を正しく実装する必要がある。この 2 つのメソッドがハッシュテーブルの動作を決定する。
デフォルトの挙動
カスタムクラスは、何も定義しなければデフォルトでハッシュ可能だ。id() に基づくハッシュ値が使われる。
class Point:
def __init__(self, x, y):
self.x = x
self.y = y
p1 = Point(1, 2)
p2 = Point(1, 2)
s = {p1, p2}
print(len(s)) # 2(同じ座標でも別オブジェクト)
同じ座標を持つ 2 つの Point が、別々の要素として扱われている。これは id() が異なるためだ。
eq を定義すると hash が消える
__eq__() を定義すると、__hash__ は自動的に None になる。
class Point:
def __init__(self, x, y):
self.x = x
self.y = y
def __eq__(self, other):
return self.x == other.x and self.y == other.y
p = Point(1, 2)
print(hash(p)) # TypeError: unhashable type: 'Point'
これは安全のための仕様だ。__eq__() で等しいと判定されるオブジェクトは、同じハッシュ値を持たなければならない。デフォルトの id() ベースのハッシュはこの条件を満たさないため、自動的に無効化される。
hash と eq の両方を実装
集合に入れるには、両方のメソッドを整合性を持って実装する必要がある。
class Point:
def __init__(self, x, y):
self.x = x
self.y = y
def __eq__(self, other):
if not isinstance(other, Point):
return NotImplemented
return self.x == other.x and self.y == other.y
def __hash__(self):
return hash((self.x, self.y))
p1 = Point(1, 2)
p2 = Point(1, 2)
p3 = Point(3, 4)
s = {p1, p2, p3}
print(len(s)) # 2(p1 と p2 は同じ要素)
hash((self.x, self.y)) のようにタプルを使うのが簡単で安全だ。タプルのハッシュ関数が複数の値を適切に混合してくれる。
整合性のルール
__hash__() と __eq__() には守るべきルールがある。
a == b なら hash(a) == hash(b) でなければならない。逆は成り立たなくてよい。
オブジェクトの生存期間中、ハッシュ値が変わってはならない。
違反すると、集合が正しく動作しなくなる。
# 悪い例:ハッシュ値が変わる
class BadPoint:
def __init__(self, x, y):
self.x = x
self.y = y
def __eq__(self, other):
return self.x == other.x and self.y == other.y
def __hash__(self):
return hash((self.x, self.y))
s = set()
p = BadPoint(1, 2)
s.add(p)
print(p in s) # True
p.x = 10 # ハッシュ値が変わる!
print(p in s) # False(見つからない)
print(len(s)) # 1(要素は存在する)
ハッシュ値が変わると、正しいスロットを探索できなくなる。
イミュータブルなクラスの設計
安全のため、ハッシュ可能なクラスはイミュータブルにすべきだ。
class ImmutablePoint:
__slots__ = ('_x', '_y')
def __init__(self, x, y):
object.__setattr__(self, '_x', x)
object.__setattr__(self, '_y', y)
@property
def x(self):
return self._x
@property
def y(self):
return self._y
def __setattr__(self, name, value):
raise AttributeError("Cannot modify immutable object")
def __eq__(self, other):
if not isinstance(other, ImmutablePoint):
return NotImplemented
return self._x == other._x and self._y == other._y
def __hash__(self):
return hash((self._x, self._y))
__slots__ と __setattr__ のオーバーライドで変更を防いでいる。
dataclass を使う方法
Python 3.7 以降では dataclass を使うと簡潔に書ける。
from dataclasses import dataclass
@dataclass(frozen=True)
class Point:
x: int
y: int
p1 = Point(1, 2)
p2 = Point(1, 2)
s = {p1, p2}
print(len(s)) # 1
print(hash(p1) == hash(p2)) # True
frozen=True を指定すると、イミュータブルなクラスになり、__hash__() も自動生成される。
NamedTuple を使う方法
NamedTuple も選択肢だ。
from typing import NamedTuple
class Point(NamedTuple):
x: int
y: int
p1 = Point(1, 2)
p2 = Point(1, 2)
s = {p1, p2}
print(len(s)) # 1
NamedTuple はタプルのサブクラスなので、自動的にハッシュ可能でイミュータブルだ。
柔軟性が高い。メソッドを追加しやすい。
軽量でシンプル。タプルとの互換性がある。
継承時の注意
サブクラスで属性を追加する場合、ハッシュ関数も更新が必要だ。
@dataclass(frozen=True)
class Point:
x: int
y: int
@dataclass(frozen=True)
class Point3D(Point):
z: int
p1 = Point3D(1, 2, 3)
p2 = Point3D(1, 2, 4)
print(p1 == p2) # False
print(hash(p1) == hash(p2)) # False(z も含まれる)
dataclass を使えば、継承時も自動的にすべてのフィールドがハッシュに含まれる。
一部の属性だけをハッシュに使う
識別に使わない属性は、ハッシュから除外できる。
from dataclasses import dataclass, field
@dataclass(frozen=True)
class User:
id: int
name: str = field(compare=False, hash=False)
u1 = User(1, "Alice")
u2 = User(1, "Bob")
print(u1 == u2) # True(id のみで比較)
print(hash(u1) == hash(u2)) # True
s = {u1, u2}
print(len(s)) # 1
field(compare=False, hash=False) で、その属性を等価判定とハッシュ計算から除外している。キャッシュや一時的な状態を持つオブジェクトで有用だ。