Skip to content

Commit

Permalink
feat(hugr-py): automatically add state order edges for inter-graph ed…
Browse files Browse the repository at this point in the history
…ges (#1165)

uses a slightly dodgy -1 offset hack

hack currently means if you query the hugr for outgoing edges the state
order edges are ignored.
  • Loading branch information
ss2165 authored Jun 6, 2024
1 parent 6eb6d56 commit 5da06e1
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 19 deletions.
22 changes: 18 additions & 4 deletions hugr-py/src/hugr/_dfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from ._hugr import Hugr, Node, Wire, OutPort

from ._ops import Op, Command, Input, Output, DFG
from ._exceptions import NoSiblingAncestor
from hugr.serialization.tys import FunctionType, Type


Expand Down Expand Up @@ -67,12 +68,25 @@ def set_outputs(self, *args: Wire) -> None:

def add_state_order(self, src: Node, dst: Node) -> None:
# adds edge to the right of all existing edges
# breaks if further edges are added
self.hugr.add_link(
src.out(self.hugr.num_outgoing(src)), dst.inp(self.hugr.num_incoming(dst))
)
self.hugr.add_link(src.out(-1), dst.inp(-1))

def _wire_up(self, node: Node, ports: Iterable[Wire]):
for i, p in enumerate(ports):
src = p.out_port()
node_ancestor = _ancestral_sibling(self.hugr, src.node, node)
if node_ancestor is None:
raise NoSiblingAncestor(src.node.idx, node.idx)
if node_ancestor != node:
self.add_state_order(src.node, node_ancestor)
self.hugr.add_link(src, node.inp(i))


def _ancestral_sibling(h: Hugr, src: Node, tgt: Node) -> Node | None:
src_parent = h[src].parent

while (tgt_parent := h[tgt].parent) is not None:
if tgt_parent == src_parent:
return tgt
tgt = tgt_parent

return None
11 changes: 11 additions & 0 deletions hugr-py/src/hugr/_exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from dataclasses import dataclass


@dataclass
class NoSiblingAncestor(Exception):
src: int
tgt: int

@property
def msg(self):
return f"Source {self.src} has no sibling ancestor of target {self.tgt}, so cannot wire up."
36 changes: 29 additions & 7 deletions hugr-py/src/hugr/_hugr.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,12 @@ def linked_ports(self, port: OutPort | InPort):

# TODO: single linked port

def outgoing_order_links(self, node: Node) -> Iterable[Node]:
return (p.node for p in self.linked_ports(node.out(-1)))

def incoming_order_links(self, node: Node) -> Iterable[Node]:
return (p.node for p in self.linked_ports(node.inp(-1)))

def _node_links(
self, node: Node, links: dict[_SubPort[P], _SubPort[K]]
) -> Iterable[tuple[P, list[K]]]:
Expand Down Expand Up @@ -320,19 +326,35 @@ def add_dfg(self, input_types: Sequence[Type], output_types: Sequence[Type]) ->

def to_serial(self) -> SerialHugr:
node_it = (node for node in self._nodes if node is not None)

def _serialise_link(
link: tuple[_SO, _SI],
) -> tuple[tuple[int, int], tuple[int, int]]:
src, dst = link
s, d = self._constrain_offset(src.port), self._constrain_offset(dst.port)
return (src.port.node.idx, s), (dst.port.node.idx, d)

return SerialHugr(
version="v1",
# non contiguous indices will be erased
nodes=[node.to_serial(Node(idx), self) for idx, node in enumerate(node_it)],
edges=[
(
(src.port.node.idx, src.port.offset),
(dst.port.node.idx, dst.port.offset),
)
for src, dst in self._links.items()
],
edges=[_serialise_link(link) for link in self._links.items()],
)

def _constrain_offset(self, p: P) -> int:
# negative offsets are used to refer to the last port
if p.offset < 0:
match p.direction:
case Direction.INCOMING:
current = self.num_incoming(p.node)
case Direction.OUTGOING:
current = self.num_outgoing(p.node)
offset = current + p.offset + 1
else:
offset = p.offset

return offset

@classmethod
def from_serial(cls, serial: SerialHugr) -> Hugr:
assert serial.nodes, "Empty Hugr is invalid"
Expand Down
30 changes: 22 additions & 8 deletions hugr-py/tests/test_hugr_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
import subprocess
import os
import pathlib
from hugr._hugr import Hugr, Node, Wire
from hugr._dfg import Dfg
from hugr._hugr import Hugr, Node, Wire, _SubPort
from hugr._dfg import Dfg, _ancestral_sibling
from hugr._ops import Custom, Command
import hugr._ops as ops
from hugr.serialization import SerialHugr
Expand Down Expand Up @@ -226,15 +226,29 @@ def _nested_nop(dfg: Dfg):


def test_build_inter_graph():
h = Dfg.endo([BOOL_T, BOOL_T])
(a, b) = h.inputs()
nested = h.add_nested([], [BOOL_T])

nt = nested.add(Not(a))
nested.set_outputs(nt)

h.set_outputs(nested.root, b)

_validate(h.hugr, True)

assert _SubPort(h.input_node.out(-1)) in h.hugr._links
assert h.hugr.num_outgoing(h.input_node) == 2 # doesn't count state order
assert len(list(h.hugr.outgoing_order_links(h.input_node))) == 1
assert len(list(h.hugr.incoming_order_links(nested.root))) == 1
assert len(list(h.hugr.incoming_order_links(h.output_node))) == 0


def test_ancestral_sibling():
h = Dfg.endo([BOOL_T])
(a,) = h.inputs()
nested = h.add_nested([], [BOOL_T])

nt = nested.add(Not(a))
nested.set_outputs(nt)
# TODO a context manager could add this state order edge on
# exit by tracking parents of source nodes
h.add_state_order(h.input_node, nested.root)
h.set_outputs(nested.root)

_validate(h.hugr)
assert _ancestral_sibling(h.hugr, h.input_node, nt) == nested.root

0 comments on commit 5da06e1

Please sign in to comment.