diff --git a/hugr-py/src/hugr/hugr.py b/hugr-py/src/hugr/hugr.py index ea393a373..c98cdb6ad 100644 --- a/hugr-py/src/hugr/hugr.py +++ b/hugr-py/src/hugr/hugr.py @@ -138,7 +138,7 @@ def __iter__(self): return iter(self._nodes) def __len__(self) -> int: - return len(self._nodes) - len(self._free_nodes) + return self.num_nodes() def add_node( self, @@ -155,15 +155,17 @@ def add_node( self._nodes.append(node_data) return node - def delete_node(self, node: Node) -> None: + def delete_node(self, node: Node) -> NodeData | None: for offset in range(self.num_in_ports(node)): self._links.delete_right(node.inp(offset)) for offset in range(self.num_out_ports(node)): self._links.delete_left(node.out(offset)) + weight = self._nodes[node.idx] self._nodes[node.idx] = None self._free_nodes.append(node) + return weight - def add_link(self, src: OutPort, dst: InPort, ty: Type | None = None) -> None: + def add_link(self, src: OutPort, dst: InPort) -> None: src = _unused_sub_offset(src, self._links) dst = _unused_sub_offset(dst, self._links) if self._links.get_left(dst) is not None: @@ -173,6 +175,9 @@ def add_link(self, src: OutPort, dst: InPort, ty: Type | None = None) -> None: def delete_link(self, src: OutPort, dst: InPort) -> None: self._links.delete_left(src) + def num_nodes(self) -> int: + return len(self._nodes) - len(self._free_nodes) + def num_in_ports(self, node: Node) -> int: return len(self.in_ports(node)) diff --git a/hugr-py/tests/test_hugr_build.py b/hugr-py/tests/test_hugr_build.py index 03092b072..fe703d405 100644 --- a/hugr-py/tests/test_hugr_build.py +++ b/hugr-py/tests/test_hugr_build.py @@ -3,6 +3,7 @@ from hugr.hugr import Dfg, Hugr, DummyOp, Node import hugr.serialization.tys as stys import hugr.serialization.ops as sops +import pytest BOOL_T = stys.Type(stys.SumType(stys.UnitSum(size=2))) QB_T = stys.Type(stys.Qubit()) @@ -38,7 +39,6 @@ def _validate(h: Hugr, mermaid: bool = False): - # TODO point to built hugr binary # cmd = ["cargo", "run", "--features", "cli", "--"] cmd = ["./target/debug/hugr"] @@ -47,6 +47,38 @@ def _validate(h: Hugr, mermaid: bool = False): subprocess.run(cmd + ["-"], check=True, input=h.to_serial().to_json().encode()) +def test_stable_indices(): + h = Hugr(DummyOp(sops.DFG(parent=-1))) + + nodes = [h.add_node(NOT_OP) for _ in range(3)] + assert len(h) == 4 + + h.add_link(nodes[0].out(0), nodes[1].inp(0)) + + assert h.num_out_ports(nodes[0]) == 1 + assert h.num_in_ports(nodes[1]) == 1 + + assert h.delete_node(nodes[1]) is not None + + assert len(h) == 3 + assert len(h._nodes) == 4 + assert h._free_nodes == [nodes[1]] + + assert h.num_out_ports(nodes[0]) == 0 + assert h.num_in_ports(nodes[1]) == 0 + + with pytest.raises(KeyError): + _ = h[nodes[1]] + with pytest.raises(KeyError): + _ = h[Node(46)] + + new_n = h.add_node(NOT_OP) + assert new_n == nodes[1] + + assert len(h) == 4 + assert h._free_nodes == [] + + def test_simple_id(): h = Dfg.endo([QB_T] * 2) a, b = h.inputs()