From 5da06e10581cbfed583bd466b27706241341ff14 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Thu, 6 Jun 2024 14:44:28 +0100 Subject: [PATCH] feat(hugr-py): automatically add state order edges for inter-graph edges (#1165) uses a slightly dodgy -1 offset hack hack currently means if you query the hugr for outgoing edges the state order edges are ignored. --- hugr-py/src/hugr/_dfg.py | 22 +++++++++++++++---- hugr-py/src/hugr/_exceptions.py | 11 ++++++++++ hugr-py/src/hugr/_hugr.py | 36 +++++++++++++++++++++++++------- hugr-py/tests/test_hugr_build.py | 30 +++++++++++++++++++------- 4 files changed, 80 insertions(+), 19 deletions(-) create mode 100644 hugr-py/src/hugr/_exceptions.py diff --git a/hugr-py/src/hugr/_dfg.py b/hugr-py/src/hugr/_dfg.py index f4eda9826..f083e8ae0 100644 --- a/hugr-py/src/hugr/_dfg.py +++ b/hugr-py/src/hugr/_dfg.py @@ -4,6 +4,7 @@ from ._hugr import Hugr, Node, Wire, OutPort from ._ops import Op, Command, Input, Output, DFG +from ._exceptions import NoSiblingAncestor from hugr.serialization.tys import FunctionType, Type @@ -67,12 +68,25 @@ def set_outputs(self, *args: Wire) -> None: def add_state_order(self, src: Node, dst: Node) -> None: # adds edge to the right of all existing edges - # breaks if further edges are added - self.hugr.add_link( - src.out(self.hugr.num_outgoing(src)), dst.inp(self.hugr.num_incoming(dst)) - ) + self.hugr.add_link(src.out(-1), dst.inp(-1)) def _wire_up(self, node: Node, ports: Iterable[Wire]): for i, p in enumerate(ports): src = p.out_port() + node_ancestor = _ancestral_sibling(self.hugr, src.node, node) + if node_ancestor is None: + raise NoSiblingAncestor(src.node.idx, node.idx) + if node_ancestor != node: + self.add_state_order(src.node, node_ancestor) self.hugr.add_link(src, node.inp(i)) + + +def _ancestral_sibling(h: Hugr, src: Node, tgt: Node) -> Node | None: + src_parent = h[src].parent + + while (tgt_parent := h[tgt].parent) is not None: + if tgt_parent == src_parent: + return tgt + tgt = tgt_parent + + return None diff --git a/hugr-py/src/hugr/_exceptions.py b/hugr-py/src/hugr/_exceptions.py new file mode 100644 index 000000000..3245af0cc --- /dev/null +++ b/hugr-py/src/hugr/_exceptions.py @@ -0,0 +1,11 @@ +from dataclasses import dataclass + + +@dataclass +class NoSiblingAncestor(Exception): + src: int + tgt: int + + @property + def msg(self): + return f"Source {self.src} has no sibling ancestor of target {self.tgt}, so cannot wire up." diff --git a/hugr-py/src/hugr/_hugr.py b/hugr-py/src/hugr/_hugr.py index 21fc1e281..d42f2edf1 100644 --- a/hugr-py/src/hugr/_hugr.py +++ b/hugr-py/src/hugr/_hugr.py @@ -260,6 +260,12 @@ def linked_ports(self, port: OutPort | InPort): # TODO: single linked port + def outgoing_order_links(self, node: Node) -> Iterable[Node]: + return (p.node for p in self.linked_ports(node.out(-1))) + + def incoming_order_links(self, node: Node) -> Iterable[Node]: + return (p.node for p in self.linked_ports(node.inp(-1))) + def _node_links( self, node: Node, links: dict[_SubPort[P], _SubPort[K]] ) -> Iterable[tuple[P, list[K]]]: @@ -320,19 +326,35 @@ def add_dfg(self, input_types: Sequence[Type], output_types: Sequence[Type]) -> def to_serial(self) -> SerialHugr: node_it = (node for node in self._nodes if node is not None) + + def _serialise_link( + link: tuple[_SO, _SI], + ) -> tuple[tuple[int, int], tuple[int, int]]: + src, dst = link + s, d = self._constrain_offset(src.port), self._constrain_offset(dst.port) + return (src.port.node.idx, s), (dst.port.node.idx, d) + return SerialHugr( version="v1", # non contiguous indices will be erased nodes=[node.to_serial(Node(idx), self) for idx, node in enumerate(node_it)], - edges=[ - ( - (src.port.node.idx, src.port.offset), - (dst.port.node.idx, dst.port.offset), - ) - for src, dst in self._links.items() - ], + edges=[_serialise_link(link) for link in self._links.items()], ) + def _constrain_offset(self, p: P) -> int: + # negative offsets are used to refer to the last port + if p.offset < 0: + match p.direction: + case Direction.INCOMING: + current = self.num_incoming(p.node) + case Direction.OUTGOING: + current = self.num_outgoing(p.node) + offset = current + p.offset + 1 + else: + offset = p.offset + + return offset + @classmethod def from_serial(cls, serial: SerialHugr) -> Hugr: assert serial.nodes, "Empty Hugr is invalid" diff --git a/hugr-py/tests/test_hugr_build.py b/hugr-py/tests/test_hugr_build.py index 8691e8e08..52f7a2b07 100644 --- a/hugr-py/tests/test_hugr_build.py +++ b/hugr-py/tests/test_hugr_build.py @@ -3,8 +3,8 @@ import subprocess import os import pathlib -from hugr._hugr import Hugr, Node, Wire -from hugr._dfg import Dfg +from hugr._hugr import Hugr, Node, Wire, _SubPort +from hugr._dfg import Dfg, _ancestral_sibling from hugr._ops import Custom, Command import hugr._ops as ops from hugr.serialization import SerialHugr @@ -226,15 +226,29 @@ def _nested_nop(dfg: Dfg): def test_build_inter_graph(): + h = Dfg.endo([BOOL_T, BOOL_T]) + (a, b) = h.inputs() + nested = h.add_nested([], [BOOL_T]) + + nt = nested.add(Not(a)) + nested.set_outputs(nt) + + h.set_outputs(nested.root, b) + + _validate(h.hugr, True) + + assert _SubPort(h.input_node.out(-1)) in h.hugr._links + assert h.hugr.num_outgoing(h.input_node) == 2 # doesn't count state order + assert len(list(h.hugr.outgoing_order_links(h.input_node))) == 1 + assert len(list(h.hugr.incoming_order_links(nested.root))) == 1 + assert len(list(h.hugr.incoming_order_links(h.output_node))) == 0 + + +def test_ancestral_sibling(): h = Dfg.endo([BOOL_T]) (a,) = h.inputs() nested = h.add_nested([], [BOOL_T]) nt = nested.add(Not(a)) - nested.set_outputs(nt) - # TODO a context manager could add this state order edge on - # exit by tracking parents of source nodes - h.add_state_order(h.input_node, nested.root) - h.set_outputs(nested.root) - _validate(h.hugr) + assert _ancestral_sibling(h.hugr, h.input_node, nt) == nested.root