Skip to content

Commit

Permalink
misc: add a DisjointSet data structure (#3621)
Browse files Browse the repository at this point in the history
This is useful in a few places, notably in bufferization.
  • Loading branch information
superlopuh authored Dec 16, 2024
1 parent d43ac2f commit fd6296d
Show file tree
Hide file tree
Showing 2 changed files with 308 additions and 0 deletions.
133 changes: 133 additions & 0 deletions tests/utils/test_disjoint_set.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
import pytest

from xdsl.utils.disjoint_set import DisjointSet, IntDisjointSet


def test_disjoint_set_init():
    """A freshly constructed IntDisjointSet holds `size` singleton sets."""
    size = 5
    ds = IntDisjointSet(size=size)
    assert ds.value_count() == size
    # Every element must start out as its own representative.
    for element in range(size):
        representative = ds[element]
        assert representative == element


def test_disjoint_set_add():
    """`add` appends a new singleton whose value equals the previous element count."""
    ds = IntDisjointSet(size=2)
    assert ds.value_count() == 2

    added = ds.add()

    # The returned value is the old size, and the structure grew by one element.
    assert added == 2
    assert ds.value_count() == 3
    # The new element is its own representative.
    assert ds[added] == added


def test_disjoint_set_find_invalid():
    """Looking up a value outside [0, size) raises KeyError."""
    ds = IntDisjointSet(size=3)
    # One past the end, then a negative index.
    for out_of_range in (3, -1):
        with pytest.raises(KeyError):
            ds[out_of_range]


def test_disjoint_set_union():
    """`union` merges sets, reports whether a merge happened, and updates connectivity."""
    ds = IntDisjointSet(size=4)

    # Merge {0} with {1}: both now share a representative, 2 is still separate.
    merged_first_pair = ds.union(0, 1)
    assert merged_first_pair
    first_root = ds[0]
    assert ds[1] == first_root
    assert ds.connected(0, 1)
    assert not ds.connected(0, 2)

    # Merge {2} with {3}: a second, independent set.
    merged_second_pair = ds.union(2, 3)
    assert merged_second_pair
    second_root = ds[2]
    assert ds[3] == second_root
    assert ds.connected(2, 3)
    assert not ds.connected(1, 2)

    # Re-merging members of one set is a no-op and is reported as such.
    merged_again = ds.union(0, 1)
    assert not merged_again
    assert ds.connected(0, 1)

    # Join the two pairs into a single set.
    merged_everything = ds.union(1, 2)
    assert merged_everything
    shared_root = ds[0]
    for element in (1, 2, 3):
        assert ds[element] == shared_root
    # With everything merged, every pair of elements is connected.
    for lhs, rhs in ((0, 1), (1, 2), (2, 3), (0, 3)):
        assert ds.connected(lhs, rhs)


def test_disjoint_set_path_compression():
    """A lookup re-parents every node on the traversed path directly to the root."""
    ds = IntDisjointSet(size=4)

    # Hand-build the degenerate chain 3 -> 2 -> 1 -> 0, where 0 is the root.
    ds._parent = [0, 0, 1, 2]  # pyright: ignore[reportPrivateUsage]
    ds._count = [4, 3, 2, 1]  # pyright: ignore[reportPrivateUsage]

    # Looking up the deepest node walks the whole chain and compresses it.
    root = ds[3]

    # Every node, the root included, must now have the root as its parent.
    parents = ds._parent  # pyright: ignore[reportPrivateUsage]
    for node in (3, 2, 1, 0):
        assert parents[node] == root


def test_generic_disjoint_set():
    """DisjointSet over strings behaves like IntDisjointSet over their indices."""
    ds = DisjointSet(["a", "b", "c", "d"])

    # Merge {"a"} with {"b"}: they share a representative, "c" is still separate.
    merged_first_pair = ds.union("a", "b")
    assert merged_first_pair
    first_root = ds.find("a")
    assert ds.find("b") == first_root
    assert ds.connected("a", "b")
    assert not ds.connected("a", "c")

    # Merge {"c"} with {"d"}: a second, independent set.
    merged_second_pair = ds.union("c", "d")
    assert merged_second_pair
    second_root = ds.find("c")
    assert ds.find("d") == second_root
    assert ds.connected("c", "d")
    assert not ds.connected("b", "c")

    # Re-merging members of one set is a no-op and is reported as such.
    merged_again = ds.union("a", "b")
    assert not merged_again
    assert ds.connected("a", "b")

    # Join the two pairs into a single set.
    merged_everything = ds.union("b", "c")
    assert merged_everything
    shared_root = ds.find("a")
    for value in ("b", "c", "d"):
        assert ds.find(value) == shared_root
    # With everything merged, every pair of values is connected.
    for lhs, rhs in (("a", "b"), ("b", "c"), ("c", "d"), ("a", "d")):
        assert ds.connected(lhs, rhs)


def test_generic_disjoint_set_add():
    """Values added after construction can be merged with the initial ones."""
    ds = DisjointSet(["a", "b"])
    for late_value in ("c", "d"):
        ds.add(late_value)

    # Each initial value is merged with one of the later additions.
    merged_a_c = ds.union("a", "c")
    assert merged_a_c
    representative_a = ds.find("a")
    assert ds.find("c") == representative_a

    merged_b_d = ds.union("b", "d")
    assert merged_b_d
    representative_b = ds.find("b")
    assert ds.find("d") == representative_b


def test_generic_disjoint_set_find_invalid():
    """`find` on a value that was never added raises KeyError."""
    known_values = ["a", "b", "c"]
    ds = DisjointSet(known_values)
    unknown_value = "d"
    with pytest.raises(KeyError):
        ds.find(unknown_value)
175 changes: 175 additions & 0 deletions xdsl/utils/disjoint_set.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
"""
Generic implementation of a disjoint set data structure.
https://en.wikipedia.org/wiki/Disjoint-set_data_structure
"""

from collections.abc import Hashable, Sequence
from typing import Generic, TypeVar


class IntDisjointSet:
    """
    Represents a collection of disjoint sets of integers.

    The integers stored are always in the range [0,n), where n is the number of elements
    in this structure.

    This implementation uses path compression and union by size for efficiency.
    The amortized time complexity for operations is nearly constant.
    """

    _parent: list[int]
    """
    Index of the parent node. If the node is its own parent then it is a root node.
    """
    _count: list[int]
    """
    If the node is a root node, the corresponding value is the count of elements in the
    set. For non-root nodes, these counts may be stale and should not be used.
    """

    def __init__(self, *, size: int) -> None:
        """
        Initialize disjoint sets with elements [0,size).
        Each element starts in its own singleton set.
        """
        self._parent = list(range(size))
        self._count = [1] * size

    def value_count(self) -> int:
        """Number of nodes in this structure."""
        return len(self._parent)

    def add(self) -> int:
        """
        Add a new element to this set as a singleton.
        Returns the added value, which will be equal to the previous size.
        """
        res = len(self._parent)
        self._parent.append(res)
        self._count.append(1)
        return res

    def __getitem__(self, value: int) -> int:
        """
        Returns the root/representative value of the set containing `value`.

        Uses path compression - updates parent pointers to point directly to the root
        as we traverse up the tree, improving amortized performance.

        Raises:
            KeyError: If `value` is not in the range [0, value_count()).
        """
        # Reject negative values explicitly: list indexing would otherwise silently
        # wrap around to the end of `_parent`.
        if value < 0 or len(self._parent) <= value:
            raise KeyError(f"Index {value} not found")

        # Find the root
        root = value
        while self._parent[root] != root:
            root = self._parent[root]

        # Path compression - point all nodes on path to root
        current = value
        while current != root:
            next_parent = self._parent[current]
            self._parent[current] = root
            current = next_parent

        return root

    def union(self, lhs: int, rhs: int) -> bool:
        """
        Merges the sets containing lhs and rhs if they are different.
        Returns True if the sets were merged, False if they were already the same set.

        Uses union by size - the smaller tree is attached to the larger tree's root
        to maintain balance. This ensures the maximum tree height is O(log n).
        When both trees have the same size, the root of lhs becomes the new root.

        Raises:
            KeyError: If either value is not in the range [0, value_count()).
        """
        lhs_root = self[lhs]
        rhs_root = self[rhs]
        if lhs_root == rhs_root:
            return False

        lhs_count = self._count[lhs_root]
        rhs_count = self._count[rhs_root]
        # Choose the root of the larger tree as the new parent. The comparison must be
        # `>=`: with `<=` the larger tree would be hung below the smaller tree's root,
        # which defeats union by size.
        new_parent, new_child = (
            (lhs_root, rhs_root) if lhs_count >= rhs_count else (rhs_root, lhs_root)
        )
        self._parent[new_child] = new_parent
        self._count[new_parent] = lhs_count + rhs_count
        # Note: We don't need to update _count[new_child] since it's no longer a root
        return True

    def connected(self, lhs: int, rhs: int) -> bool:
        """
        Returns True if lhs and rhs are in the same set.

        Raises:
            KeyError: If either value is not in the range [0, value_count()).
        """
        return self[lhs] == self[rhs]


# Type of the values stored in a DisjointSet. The bound is Hashable because the values
# are used as dictionary keys.
_T = TypeVar("_T", bound=Hashable)


class DisjointSet(Generic[_T]):
    """
    A disjoint-set data structure that works with arbitrary hashable values.

    Internally uses IntDisjointSet by mapping values to integer indices.
    Every distinct value is stored exactly once: repeated values passed to the
    constructor or to `add` are ignored.
    """

    # Disjoint sets over the integer indices of the stored values.
    _base: IntDisjointSet
    # Stored values; the position of a value in this list is its index in `_base`.
    _values: list[_T]
    # Reverse mapping from a value to its position in `_values`.
    _index_by_value: dict[_T, int]

    def __init__(self, values: Sequence[_T] = ()):
        """
        Initialize a DisjointSet with the given sequence of values.
        Each value starts in its own singleton set.

        Args:
            values: Initial sequence of values to add to the disjoint set. Repeated
                values are stored once, at the position of their first occurrence.
        """
        # dict.fromkeys drops repeated values while keeping first-occurrence order, so
        # `_values`, `_index_by_value` and `_base` always agree on the element count.
        self._values = list(dict.fromkeys(values))
        self._index_by_value = {v: i for i, v in enumerate(self._values)}
        self._base = IntDisjointSet(size=len(self._values))

    def __len__(self) -> int:
        """Number of distinct values stored in this structure."""
        return len(self._values)

    def add(self, value: _T) -> None:
        """
        Add a new value to the disjoint set in its own singleton set.

        If the value is already present this is a no-op and the value stays in its
        current set. Re-registering it under a fresh index would orphan the old node
        and silently disconnect the value from everything it was merged with.

        Args:
            value: The value to add
        """
        if value in self._index_by_value:
            return
        index = self._base.add()
        self._values.append(value)
        self._index_by_value[value] = index

    def find(self, value: _T) -> _T:
        """
        Find the representative value for the set containing the given value.
        Returns the representative value for the set.

        Raises:
            KeyError: If the value is not in the disjoint set
        """
        index = self._base[self._index_by_value[value]]
        return self._values[index]

    def union(self, lhs: _T, rhs: _T) -> bool:
        """
        Merge the sets containing the two given values if they are different.
        Returns `True` if the sets were merged, `False` if they were already the same set.

        Raises:
            KeyError: If either value is not in the disjoint set
        """
        return self._base.union(self._index_by_value[lhs], self._index_by_value[rhs])

    def connected(self, lhs: _T, rhs: _T) -> bool:
        """
        Returns `True` if the values are in the same set.

        Raises:
            KeyError: If either value is not in the disjoint set
        """
        return self._base.connected(
            self._index_by_value[lhs], self._index_by_value[rhs]
        )

0 comments on commit fd6296d

Please sign in to comment.