-
Notifications
You must be signed in to change notification settings - Fork 77
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
misc: add a DisjointSet data structure (#3621)
This is useful in a few places, notably in bufferization.
- Loading branch information
1 parent
d43ac2f
commit fd6296d
Showing
2 changed files
with
308 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,133 @@ | ||
import pytest | ||
|
||
from xdsl.utils.disjoint_set import DisjointSet, IntDisjointSet | ||
|
||
|
||
def test_disjoint_set_init():
    """A freshly built IntDisjointSet holds `size` singleton sets."""
    size = 5
    sets = IntDisjointSet(size=size)
    assert sets.value_count() == size
    # Every element starts out as the representative of its own set
    assert [sets[i] for i in range(size)] == list(range(size))
|
||
|
||
def test_disjoint_set_add():
    """`add` appends a new singleton whose value equals the previous size."""
    sets = IntDisjointSet(size=2)
    assert sets.value_count() == 2

    added = sets.add()
    assert added == 2
    assert sets.value_count() == 3
    # The new element is its own representative
    assert sets[added] == added
|
||
|
||
def test_disjoint_set_find_invalid():
    """Looking up an index outside [0, size) raises KeyError."""
    sets = IntDisjointSet(size=3)
    # One index just past the end, then one negative index
    for out_of_range in (3, -1):
        with pytest.raises(KeyError):
            sets[out_of_range]
|
||
|
||
def test_disjoint_set_union():
    """Exercise union/lookup/connected on an IntDisjointSet of four elements."""
    sets = IntDisjointSet(size=4)

    # Merge {0} and {1}
    assert sets.union(0, 1)
    rep_01 = sets[0]
    assert sets[1] == rep_01
    assert sets.connected(0, 1)
    assert not sets.connected(0, 2)

    # Merge {2} and {3}
    assert sets.union(2, 3)
    rep_23 = sets[2]
    assert sets[3] == rep_23
    assert sets.connected(2, 3)
    assert not sets.connected(1, 2)

    # Re-merging members of one set reports False and leaves them connected
    assert not sets.union(0, 1)
    assert sets.connected(0, 1)

    # Merge the two pairs into a single set
    assert sets.union(1, 2)
    rep_all = sets[0]
    for member in (1, 2, 3):
        assert sets[member] == rep_all
    # With everything merged, every pair must be connected
    for left, right in ((0, 1), (1, 2), (2, 3), (0, 3)):
        assert sets.connected(left, right)
|
||
|
||
def test_disjoint_set_path_compression():
    """A lookup re-points every node on the traversed path at the root."""
    sets = IntDisjointSet(size=4)

    # Hand-build the chain 3 -> 2 -> 1 -> 0 by overwriting the internals
    sets._parent = [0, 0, 1, 2]  # pyright: ignore[reportPrivateUsage]
    sets._count = [4, 3, 2, 1]  # pyright: ignore[reportPrivateUsage]

    # Looking up the deepest node walks (and compresses) the whole chain
    rep = sets[3]
    # Afterwards each node's parent is the root itself
    for node in (3, 2, 1, 0):
        assert sets._parent[node] == rep  # pyright: ignore[reportPrivateUsage]
|
||
|
||
def test_generic_disjoint_set():
    """Exercise union/find/connected on a DisjointSet of strings."""
    sets = DisjointSet(["a", "b", "c", "d"])

    # Merge {"a"} and {"b"}
    assert sets.union("a", "b")
    rep_ab = sets.find("a")
    assert sets.find("b") == rep_ab
    assert sets.connected("a", "b")
    assert not sets.connected("a", "c")

    # Merge {"c"} and {"d"}
    assert sets.union("c", "d")
    rep_cd = sets.find("c")
    assert sets.find("d") == rep_cd
    assert sets.connected("c", "d")
    assert not sets.connected("b", "c")

    # Re-merging members of one set reports False and leaves them connected
    assert not sets.union("a", "b")
    assert sets.connected("a", "b")

    # Merge the two pairs into a single set
    assert sets.union("b", "c")
    rep_all = sets.find("a")
    for member in ("b", "c", "d"):
        assert sets.find(member) == rep_all
    # With everything merged, every pair must be connected
    for left, right in (("a", "b"), ("b", "c"), ("c", "d"), ("a", "d")):
        assert sets.connected(left, right)
|
||
|
||
def test_generic_disjoint_set_add():
    """Values added after construction can be merged with the initial ones."""
    sets = DisjointSet(["a", "b"])
    for late_value in ("c", "d"):
        sets.add(late_value)

    # Pair each initial value with one of the late additions
    assert sets.union("a", "c")
    rep_ac = sets.find("a")
    assert sets.find("c") == rep_ac

    assert sets.union("b", "d")
    rep_bd = sets.find("b")
    assert sets.find("d") == rep_bd
|
||
|
||
def test_generic_disjoint_set_find_invalid():
    """`find` on a value that was never added raises KeyError."""
    sets = DisjointSet(["a", "b", "c"])
    missing = "d"
    with pytest.raises(KeyError):
        sets.find(missing)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,175 @@ | ||
""" | ||
Generic implementation of a disjoint set data structure. | ||
https://en.wikipedia.org/wiki/Disjoint-set_data_structure | ||
""" | ||
|
||
from collections.abc import Hashable, Sequence | ||
from typing import Generic, TypeVar | ||
|
||
|
||
class IntDisjointSet:
    """
    Represents a collection of disjoint sets of integers.

    The integers stored are always in the range [0,n), where n is the number of
    elements in this structure.

    This implementation uses path compression and union by size for efficiency.
    The amortized time complexity for operations is nearly constant.
    """

    _parent: list[int]
    """
    Index of the parent node. If the node is its own parent then it is a root node.
    """
    _count: list[int]
    """
    If the node is a root node, the corresponding value is the count of elements in the
    set. For non-root nodes, these counts may be stale and should not be used.
    """

    def __init__(self, *, size: int) -> None:
        """
        Initialize disjoint sets with elements [0,size).
        Each element starts in its own singleton set.
        """
        self._parent = list(range(size))
        self._count = [1] * size

    def value_count(self) -> int:
        """Number of nodes in this structure."""
        return len(self._parent)

    def add(self) -> int:
        """
        Add a new element to this set as a singleton.
        Returns the added value, which will be equal to the previous size.
        """
        res = len(self._parent)
        self._parent.append(res)
        self._count.append(1)
        return res

    def __getitem__(self, value: int) -> int:
        """
        Returns the root/representative value of the set containing `value`.

        Uses path compression - updates parent pointers to point directly to the
        root as we traverse up the tree, improving amortized performance.

        Raises:
            KeyError: If `value` is not in the range [0, value_count()).
        """
        # Explicit range check: a negative index would otherwise silently wrap
        # around via Python's list indexing instead of being rejected.
        if not 0 <= value < len(self._parent):
            raise KeyError(f"Index {value} not found")

        # Find the root
        root = value
        while self._parent[root] != root:
            root = self._parent[root]

        # Path compression - point all nodes on path to root
        current = value
        while current != root:
            next_parent = self._parent[current]
            self._parent[current] = root
            current = next_parent

        return root

    def union(self, lhs: int, rhs: int) -> bool:
        """
        Merges the sets containing lhs and rhs if they are different.
        Returns True if the sets were merged, False if they were already the same set.

        Uses union by size - the smaller tree is attached to the larger tree's root
        to maintain balance. This ensures the maximum tree height is O(log n).
        When both trees have the same size, the root of lhs becomes the new root.

        Raises:
            KeyError: If either value is not in the range [0, value_count()).
        """
        lhs_root = self[lhs]
        rhs_root = self[rhs]
        if lhs_root == rhs_root:
            return False

        lhs_count = self._count[lhs_root]
        rhs_count = self._count[rhs_root]
        # Choose the root of the larger tree as the new parent, so that the
        # smaller tree is the one whose nodes end up one level deeper.
        new_parent, new_child = (
            (lhs_root, rhs_root) if lhs_count >= rhs_count else (rhs_root, lhs_root)
        )
        self._parent[new_child] = new_parent
        self._count[new_parent] = lhs_count + rhs_count
        # Note: We don't need to update _count[new_child] since it's no longer a root
        return True

    def connected(self, lhs: int, rhs: int) -> bool:
        """
        Returns True if lhs and rhs are in the same set.

        Raises:
            KeyError: If either value is not in the range [0, value_count()).
        """
        return self[lhs] == self[rhs]
|
||
|
||
_T = TypeVar("_T", bound=Hashable)


class DisjointSet(Generic[_T]):
    """
    Disjoint-set structure over arbitrary hashable values.

    Each value is assigned an integer index; all set bookkeeping is delegated to
    an IntDisjointSet operating on those indices.
    """

    # Integer-indexed structure that does the actual union/find work
    _base: IntDisjointSet
    # Values in insertion order; position i holds the value with index i
    _values: list[_T]
    # Reverse lookup from a value to its index in `_values`
    _index_by_value: dict[_T, int]

    def __init__(self, values: Sequence[_T] = ()):
        """
        Build a DisjointSet in which every given value is its own singleton set.

        Args:
            values: Initial sequence of values to add to the disjoint set
        """
        stored = list(values)
        lookup: dict[_T, int] = {}
        for position, item in enumerate(stored):
            lookup[item] = position
        self._values = stored
        self._index_by_value = lookup
        self._base = IntDisjointSet(size=len(stored))

    def __len__(self):
        """Number of values stored in this structure."""
        return len(self._values)

    def add(self, value: _T):
        """
        Insert a new value as its own singleton set.

        NOTE: if `value` is already present, its lookup entry is rebound to the
        newly created singleton; the earlier entry stays in the value list.

        Args:
            value: The value to add
        """
        slot = self._base.add()
        self._values.append(value)
        self._index_by_value[value] = slot

    def find(self, value: _T) -> _T:
        """
        Return the representative value of the set that contains `value`.

        Raises:
            KeyError: If the value is not in the disjoint set
        """
        start = self._index_by_value[value]
        representative = self._base[start]
        return self._values[representative]

    def union(self, lhs: _T, rhs: _T) -> bool:
        """
        Merge the sets containing `lhs` and `rhs` when they differ.

        Returns `True` if the sets were merged, `False` if they were already the
        same set.

        Raises:
            KeyError: If either value is not in the disjoint set
        """
        lhs_index = self._index_by_value[lhs]
        rhs_index = self._index_by_value[rhs]
        return self._base.union(lhs_index, rhs_index)

    def connected(self, lhs: _T, rhs: _T) -> bool:
        """
        Report whether `lhs` and `rhs` belong to the same set.

        Raises:
            KeyError: If either value is not in the disjoint set
        """
        lhs_index = self._index_by_value[lhs]
        rhs_index = self._index_by_value[rhs]
        return self._base.connected(lhs_index, rhs_index)