diff --git a/tests/utils/test_disjoint_set.py b/tests/utils/test_disjoint_set.py new file mode 100644 index 0000000000..3422e1acb4 --- /dev/null +++ b/tests/utils/test_disjoint_set.py @@ -0,0 +1,133 @@ +import pytest + +from xdsl.utils.disjoint_set import DisjointSet, IntDisjointSet + + +def test_disjoint_set_init(): + ds = IntDisjointSet(size=5) + assert ds.value_count() == 5 + # Each element should start in its own set + for i in range(5): + assert ds[i] == i + + +def test_disjoint_set_add(): + ds = IntDisjointSet(size=2) + assert ds.value_count() == 2 + + new_val = ds.add() + assert new_val == 2 + assert ds.value_count() == 3 + assert ds[new_val] == new_val + + +def test_disjoint_set_find_invalid(): + ds = IntDisjointSet(size=3) + with pytest.raises(KeyError): + ds[3] + with pytest.raises(KeyError): + ds[-1] + + +def test_disjoint_set_union(): + ds = IntDisjointSet(size=4) + + # Union 0 and 1 + assert ds.union(0, 1) + root = ds[0] + assert ds[1] == root + assert ds.connected(0, 1) + assert not ds.connected(0, 2) + + # Union 2 and 3 + assert ds.union(2, 3) + root2 = ds[2] + assert ds[3] == root2 + assert ds.connected(2, 3) + assert not ds.connected(1, 2) + + # Union already connected elements + assert not ds.union(0, 1) + assert ds.connected(0, 1) + + # Union two sets + assert ds.union(1, 2) + final_root = ds[0] + assert ds[1] == final_root + assert ds[2] == final_root + assert ds[3] == final_root + # After unioning all elements, they should all be connected + assert ds.connected(0, 1) + assert ds.connected(1, 2) + assert ds.connected(2, 3) + assert ds.connected(0, 3) + + +def test_disjoint_set_path_compression(): + ds = IntDisjointSet(size=4) + + # Create a chain: 3->2->1->0 + ds._parent = [0, 0, 1, 2] # pyright: ignore[reportPrivateUsage] + ds._count = [4, 3, 2, 1] # pyright: ignore[reportPrivateUsage] + + # Find should compress the path + root = ds[3] + # After compression, all nodes should point directly to root + assert ds._parent[3] == root # pyright: ignore[reportPrivateUsage] + assert ds._parent[2] == root # pyright: ignore[reportPrivateUsage] + assert ds._parent[1] == root # pyright: ignore[reportPrivateUsage] + assert ds._parent[0] == root # pyright: ignore[reportPrivateUsage] + + +def test_generic_disjoint_set(): + ds = DisjointSet(["a", "b", "c", "d"]) + + # Union a and b + assert ds.union("a", "b") + root = ds.find("a") + assert ds.find("b") == root + assert ds.connected("a", "b") + assert not ds.connected("a", "c") + + # Union c and d + assert ds.union("c", "d") + root2 = ds.find("c") + assert ds.find("d") == root2 + assert ds.connected("c", "d") + assert not ds.connected("b", "c") + + # Union already connected elements + assert not ds.union("a", "b") + assert ds.connected("a", "b") + + # Union two sets + assert ds.union("b", "c") + final_root = ds.find("a") + assert ds.find("b") == final_root + assert ds.find("c") == final_root + assert ds.find("d") == final_root + # After unioning all elements, they should all be connected + assert ds.connected("a", "b") + assert ds.connected("b", "c") + assert ds.connected("c", "d") + assert ds.connected("a", "d") + + +def test_generic_disjoint_set_add(): + ds = DisjointSet(["a", "b"]) + ds.add("c") + ds.add("d") + + assert ds.union("a", "c") + root = ds.find("a") + assert ds.find("c") == root + + assert ds.union("b", "d") + root2 = ds.find("b") + assert ds.find("d") == root2 + + +def test_generic_disjoint_set_find_invalid(): + ds = DisjointSet(["a", "b", "c"]) + with pytest.raises(KeyError): + ds.find("d") diff --git a/xdsl/utils/disjoint_set.py b/xdsl/utils/disjoint_set.py new file mode 100644 index 0000000000..1cf7e0857e --- /dev/null +++ b/xdsl/utils/disjoint_set.py @@ -0,0 +1,175 @@ +""" +Generic implementation of a disjoint set data structure. + +https://en.wikipedia.org/wiki/Disjoint-set_data_structure +""" + +from collections.abc import Hashable, Sequence +from typing import Generic, TypeVar + + +class IntDisjointSet: + """ + Represents a collection of disjoint sets of integers. + The integers stored are always in the range [0,n), where n is the number of elements + in this structure. + + This implementation uses path compression and union by size for efficiency. + The amortized time complexity for operations is nearly constant. + """ + + _parent: list[int] + """ + Index of the parent node. If the node is its own parent then it is a root node. + """ + _count: list[int] + """ + If the node is a root node, the corresponding value is the count of elements in the + set. For non-root nodes, these counts may be stale and should not be used. + """ + + def __init__(self, *, size: int) -> None: + """ + Initialize disjoint sets with elements [0,size). + Each element starts in its own singleton set. + """ + self._parent = list(range(size)) + self._count = [1] * size + + def value_count(self) -> int: + """Number of nodes in this structure.""" + return len(self._parent) + + def add(self) -> int: + """ + Add a new element to this set as a singleton. + Returns the added value, which will be equal to the previous size. + """ + res = len(self._parent) + self._parent.append(res) + self._count.append(1) + return res + + def __getitem__(self, value: int) -> int: + """ + Returns the root/representative value of this set. + Uses path compression - updates parent pointers to point directly to the root + as we traverse up the tree, improving amortized performance. + """ + if value < 0 or len(self._parent) <= value: + raise KeyError(f"Index {value} not found") + + # Find the root + root = value + while self._parent[root] != root: + root = self._parent[root] + + # Path compression - point all nodes on path to root + current = value + while current != root: + next_parent = self._parent[current] + self._parent[current] = root + current = next_parent + + return root + + def union(self, lhs: int, rhs: int) -> bool: + """ + Merges the sets containing lhs and rhs if they are different. + Returns True if the sets were merged, False if they were already the same set. + + Uses union by size - the smaller tree is attached to the larger tree's root + to maintain balance. This ensures the maximum tree height is O(log n). + """ + lhs_root = self[lhs] + rhs_root = self[rhs] + if lhs_root == rhs_root: + return False + + lhs_count = self._count[lhs_root] + rhs_count = self._count[rhs_root] + # Choose the root of the larger tree as the new parent + new_parent, new_child = ( + (lhs_root, rhs_root) if lhs_count <= rhs_count else (rhs_root, lhs_root) + ) + self._parent[new_child] = new_parent + self._count[new_parent] = lhs_count + rhs_count + # Note: We don't need to update _count[new_child] since it's no longer a root + return True + + def connected(self, lhs: int, rhs: int) -> bool: + return self[lhs] == self[rhs] + + +_T = TypeVar("_T", bound=Hashable) + + +class DisjointSet(Generic[_T]): + """ + A disjoint-set data structure that works with arbitrary hashable values. + Internally uses IntDisjointSet by mapping values to integer indices. + """ + + _base: IntDisjointSet + _values: list[_T] + _index_by_value: dict[_T, int] + + def __init__(self, values: Sequence[_T] = ()): + """ + Initialize a DisjointSet with the given sequence of values. + Each value starts in its own singleton set. + + Args: + values: Initial sequence of values to add to the disjoint set + """ + self._values = list(values) + self._index_by_value = {v: i for i, v in enumerate(self._values)} + self._base = IntDisjointSet(size=len(self._values)) + + def __len__(self): + return len(self._values) + + def add(self, value: _T): + """ + Add a new value to the disjoint set in its own singleton set. + + Args: + value: The value to add + """ + index = self._base.add() + self._values.append(value) + self._index_by_value[value] = index + + def find(self, value: _T) -> _T: + """ + Find the representative value for the set containing the given value. + + Returns the representative value for the set. + + Raises: + KeyError: If the value is not in the disjoint set + """ + index = self._base[self._index_by_value[value]] + return self._values[index] + + def union(self, lhs: _T, rhs: _T) -> bool: + """ + Merge the sets containing the two given values if they are different. + + Returns `True` if the sets were merged, `False` if they were already the same set. + + Raises: + KeyError: If either value is not in the disjoint set + """ + return self._base.union(self._index_by_value[lhs], self._index_by_value[rhs]) + + def connected(self, lhs: _T, rhs: _T) -> bool: + """ + Returns `True` if the values are in the same set. + + Raises: + KeyError: If either value is not in the disjoint set + """ + return self._base.connected( + self._index_by_value[lhs], self._index_by_value[rhs] + )