From a4c93af5e5b81ef4500229d337ae5481adcfc443 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Wed, 5 Jun 2024 14:24:17 +0100 Subject: [PATCH] feat(hugr-py): store children in node weight Could instead bit the bullet and iterate through every node for calculating the children of a node, but this implementation seems like a decent code complexity/runtime balance Closes #1159 --- hugr-py/src/hugr/_exceptions.py | 6 +++++ hugr-py/src/hugr/_hugr.py | 41 +++++++++++++++++++++++++------- hugr-py/tests/test_hugr_build.py | 6 +++-- 3 files changed, 42 insertions(+), 11 deletions(-) create mode 100644 hugr-py/src/hugr/_exceptions.py diff --git a/hugr-py/src/hugr/_exceptions.py b/hugr-py/src/hugr/_exceptions.py new file mode 100644 index 000000000..72ca735c8 --- /dev/null +++ b/hugr-py/src/hugr/_exceptions.py @@ -0,0 +1,6 @@ +from dataclasses import dataclass + + +@dataclass +class ParentBeforeChild(Exception): + msg: str = "Parent node must be added before child node." diff --git a/hugr-py/src/hugr/_hugr.py b/hugr-py/src/hugr/_hugr.py index ac8766aea..f26dfdd03 100644 --- a/hugr-py/src/hugr/_hugr.py +++ b/hugr-py/src/hugr/_hugr.py @@ -20,6 +20,7 @@ from hugr.serialization.ops import OpType as SerialOp from hugr.serialization.tys import Type, FunctionType from hugr._ops import Op, Input, Output, DFG, Command +from ._exceptions import ParentBeforeChild from hugr.utils import BiMap @@ -108,6 +109,7 @@ class NodeData: _num_inps: int = 0 _num_outs: int = 0 # TODO children field? + children: list[Node] = field(default_factory=list) def to_serial(self, node: Node, hugr: Hugr) -> SerialOp: o = self.op.to_serial(node, self.parent if self.parent else node, hugr) @@ -143,7 +145,7 @@ def __init__(self, root_op: Op) -> None: self._free_nodes = [] self._links = BiMap() self._nodes = [] - self.root = self.add_node(root_op) + self.root = self._add_node(root_op, None, 0) def __getitem__(self, key: Node) -> NodeData: try: @@ -160,7 +162,11 @@ def __iter__(self): def __len__(self) -> int: return self.num_nodes() - def add_node( + def children(self, node: Node | None = None) -> list[Node]: + node = node or self.root + return self[node].children + + def _add_node( self, op: Op, parent: Node | None = None, @@ -174,9 +180,24 @@ def add_node( else: node = Node(len(self._nodes)) self._nodes.append(node_data) - return replace(node, _num_out_ports=num_outs) + node = replace(node, _num_out_ports=num_outs) + if parent: + self[parent].children.append(node) + return node + + def add_node( + self, + op: Op, + parent: Node | None = None, + num_outs: int | None = None, + ) -> Node: + parent = parent or self.root + return self._add_node(op, parent, num_outs) def delete_node(self, node: Node) -> NodeData | None: + parent = self[node].parent + if parent: + self[parent].children.remove(node) for offset in range(self.num_in_ports(node)): self._links.delete_right(_SubPort(node.inp(offset))) for offset in range(self.num_out_ports(node)): @@ -289,12 +310,14 @@ def insert_hugr(self, hugr: Hugr, parent: Node | None = None) -> dict[Node, Node for idx, node_data in enumerate(hugr._nodes): if node_data is not None: - mapping[Node(idx)] = self.add_node(node_data.op, node_data.parent) - - for new_node in mapping.values(): - # update mapped parent - node_data = self[new_node] - node_data.parent = mapping[node_data.parent] if node_data.parent else parent + # relies on parents being inserted before any children + try: + node_parent = ( + mapping[node_data.parent] if node_data.parent else parent + ) + except KeyError as e: + raise ParentBeforeChild() from e + mapping[Node(idx)] = self.add_node(node_data.op, node_parent) for src, dst in hugr._links.items(): self.add_link( diff --git a/hugr-py/tests/test_hugr_build.py b/hugr-py/tests/test_hugr_build.py index 5c78a145b..e84d58bc5 100644 --- a/hugr-py/tests/test_hugr_build.py +++ b/hugr-py/tests/test_hugr_build.py @@ -87,12 +87,14 @@ def test_stable_indices(): assert len(h) == 4 h.add_link(nodes[0].out(0), nodes[1].inp(0)) + assert h.children() == nodes assert h.num_outgoing(nodes[0]) == 1 assert h.num_incoming(nodes[1]) == 1 assert h.delete_node(nodes[1]) is not None assert h._nodes[nodes[1].idx] is None + assert nodes[1] not in h.children(h.root) assert len(h) == 3 assert len(h._nodes) == 4 @@ -203,7 +205,7 @@ def test_insert_nested(): (a,) = h.inputs() nested = h.insert_nested(h1, a) h.set_outputs(nested) - + assert len(h.hugr.children(nested)) == 3 _validate(h.hugr) @@ -218,7 +220,7 @@ def _nested_nop(dfg: Dfg): nested = h.add_nested([BOOL_T], [BOOL_T], a) _nested_nop(nested) - + assert len(h.hugr.children(nested.root)) == 3 h.set_outputs(nested.root) _validate(h.hugr)