From d6312596323b7969deb339b72cde89efcca05174 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Thu, 13 Jun 2024 16:46:09 +0100 Subject: [PATCH 01/17] feat: require ops to require port kinds --- hugr-py/src/hugr/_dfg.py | 12 +++--- hugr-py/src/hugr/_hugr.py | 12 +++++- hugr-py/src/hugr/_ops.py | 86 ++++++++++++++++++++++++++++++--------- hugr-py/src/hugr/_tys.py | 29 +++++++++++++ 4 files changed, 113 insertions(+), 26 deletions(-) diff --git a/hugr-py/src/hugr/_dfg.py b/hugr-py/src/hugr/_dfg.py index 80b6ad550..5d31b8776 100644 --- a/hugr-py/src/hugr/_dfg.py +++ b/hugr-py/src/hugr/_dfg.py @@ -29,12 +29,12 @@ def __init__(self, root_op: DP) -> None: self._init_io_nodes(root_op) def _init_io_nodes(self, root_op: DP): - input_types = root_op.input_types() - output_types = root_op.output_types() + inner_sig = root_op.inner_signature() + self.input_node = self.hugr.add_node( - ops.Input(input_types), self.root, len(input_types) + ops.Input(inner_sig.input), self.root, len(inner_sig.input) ) - self.output_node = self.hugr.add_node(ops.Output(output_types), self.root) + self.output_node = self.hugr.add_node(ops.Output(inner_sig.output), self.root) @classmethod def new_nested(cls, root_op: DP, hugr: Hugr, parent: ToNode | None = None) -> Self: @@ -61,7 +61,9 @@ def root_op(self) -> DP: def inputs(self) -> list[OutPort]: return [self.input_node.out(i) for i in range(len(self._input_op().types))] - def add_op(self, op: ops.Op, /, *args: Wire, num_outs: int | None = None) -> Node: + def add_op( + self, op: ops.DataflowOp, /, *args: Wire, num_outs: int | None = None + ) -> Node: new_n = self.hugr.add_node(op, self.root, num_outs=num_outs) self._wire_up(new_n, args) return new_n diff --git a/hugr-py/src/hugr/_hugr.py b/hugr-py/src/hugr/_hugr.py index 416daac4f..3b072953c 100644 --- a/hugr-py/src/hugr/_hugr.py +++ b/hugr-py/src/hugr/_hugr.py @@ -16,7 +16,8 @@ from typing_extensions import Self -from hugr._ops import Op +from hugr._ops import Op, DataflowOp +from hugr._tys import Type, Kind from hugr.serialization.ops import OpType as SerialOp from hugr.serialization.serial_hugr import SerialHugr from hugr.utils import BiMap @@ -333,6 +334,15 @@ def num_outgoing(self, node: ToNode) -> int: # TODO: num_links and _linked_ports + def port_kind(self, port: InPort | OutPort) -> Kind: + return self[port.node].op.port_kind(port) + + def port_type(self, port: InPort | OutPort) -> Type | None: + op = self[port.node].op + if isinstance(op, DataflowOp): + return op.port_type(port) + return None + def insert_hugr(self, hugr: Hugr, parent: ToNode | None = None) -> dict[Node, Node]: mapping: dict[Node, Node] = {} diff --git a/hugr-py/src/hugr/_ops.py b/hugr-py/src/hugr/_ops.py index 6b6986081..1fdff4feb 100644 --- a/hugr-py/src/hugr/_ops.py +++ b/hugr-py/src/hugr/_ops.py @@ -1,16 +1,17 @@ from __future__ import annotations from dataclasses import dataclass, field -from typing import Generic, Protocol, TypeVar, TYPE_CHECKING +from typing import Generic, Protocol, TypeVar, TYPE_CHECKING, runtime_checkable from hugr.serialization.ops import BaseOp import hugr.serialization.ops as sops from hugr.utils import ser_it import hugr._tys as tys if TYPE_CHECKING: - from hugr._hugr import Hugr, Node, Wire + from hugr._hugr import Hugr, Node, Wire, _Port +@runtime_checkable class Op(Protocol): @property def num_out(self) -> int | None: @@ -18,13 +19,33 @@ def num_out(self) -> int | None: def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> BaseOp: ... + def port_kind(self, port: _Port) -> tys.Kind: ... + + +@runtime_checkable +class DataflowOp(Op, Protocol): + def outer_signature(self) -> tys.FunctionType: ... + + def port_kind(self, port: _Port) -> tys.Kind: + if port.offset == -1: + return tys.OrderKind() + return tys.ValueKind(self.port_type(port)) + + def port_type(self, port: _Port) -> tys.Type: + from hugr._hugr import Direction + + sig = self.outer_signature() + if port.direction == Direction.INCOMING: + return sig.input[port.offset] + return sig.output[port.offset] + def __call__(self, *args) -> Command: return Command(self, list(args)) @dataclass(frozen=True) class Command: - op: Op + op: DataflowOp incoming: list[Wire] @@ -41,9 +62,12 @@ def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> T: root.parent = parent.idx return root + def port_kind(self, port: _Port) -> tys.Kind: + raise NotImplementedError + @dataclass() -class Input(Op): +class Input(DataflowOp): types: tys.TypeRow @property @@ -53,20 +77,26 @@ def num_out(self) -> int | None: def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.Input: return sops.Input(parent=parent.idx, types=ser_it(self.types)) + def outer_signature(self) -> tys.FunctionType: + return tys.FunctionType(input=[], output=self.types) + def __call__(self) -> Command: return super().__call__() @dataclass() -class Output(Op): +class Output(DataflowOp): types: tys.TypeRow def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.Output: return sops.Output(parent=parent.idx, types=ser_it(self.types)) + def outer_signature(self) -> tys.FunctionType: + return tys.FunctionType(input=self.types, output=[]) + @dataclass() -class Custom(Op): +class Custom(DataflowOp): op_name: str signature: tys.FunctionType = field(default_factory=tys.FunctionType.empty) description: str = "" @@ -87,9 +117,12 @@ def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.CustomOp: args=ser_it(self.args), ) + def outer_signature(self) -> tys.FunctionType: + return self.signature + @dataclass() -class MakeTuple(Op): +class MakeTuple(DataflowOp): types: tys.TypeRow num_out: int | None = 1 @@ -102,9 +135,12 @@ def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.MakeTuple: def __call__(self, *elements: Wire) -> Command: return super().__call__(*elements) + def outer_signature(self) -> tys.FunctionType: + return tys.FunctionType(input=self.types, output=[tys.Tuple(*self.types)]) + @dataclass() -class UnpackTuple(Op): +class UnpackTuple(DataflowOp): types: tys.TypeRow @property @@ -120,14 +156,16 @@ def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.UnpackTuple: def __call__(self, tuple_: Wire) -> Command: return super().__call__(tuple_) + def outer_signature(self) -> tys.FunctionType: + return MakeTuple(self.types).outer_signature().flip() + class DfParentOp(Op, Protocol): - def input_types(self) -> tys.TypeRow: ... - def output_types(self) -> tys.TypeRow: ... + def inner_signature(self) -> tys.FunctionType: ... @dataclass() -class DFG(DfParentOp): +class DFG(DfParentOp, DataflowOp): signature: tys.FunctionType = field(default_factory=tys.FunctionType.empty) @property @@ -140,15 +178,15 @@ def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.DFG: signature=self.signature.to_serial(), ) - def input_types(self) -> tys.TypeRow: - return self.signature.input + def inner_signature(self) -> tys.FunctionType: + return self.signature - def output_types(self) -> tys.TypeRow: - return self.signature.output + def outer_signature(self) -> tys.FunctionType: + return self.signature @dataclass() -class CFG(Op): +class CFG(DataflowOp): signature: tys.FunctionType = field(default_factory=tys.FunctionType.empty) @property @@ -161,6 +199,9 @@ def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.CFG: signature=self.signature.to_serial(), ) + def outer_signature(self) -> tys.FunctionType: + return self.signature + @dataclass class DataflowBlock(DfParentOp): @@ -182,11 +223,13 @@ def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.DataflowBlock: extension_delta=self.extension_delta, ) - def input_types(self) -> tys.TypeRow: - return self.inputs + def inner_signature(self) -> tys.FunctionType: + return tys.FunctionType( + input=self.inputs, output=[tys.Sum(self.sum_rows), *self.other_outputs] + ) - def output_types(self) -> tys.TypeRow: - return [tys.Sum(self.sum_rows), *self.other_outputs] + def port_kind(self, port: _Port) -> tys.Kind: + return tys.CFKind() @dataclass @@ -199,3 +242,6 @@ def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.ExitBlock: parent=parent.idx, cfg_outputs=ser_it(self.cfg_outputs), ) + + def port_kind(self, port: _Port) -> tys.Kind: + return tys.CFKind() diff --git a/hugr-py/src/hugr/_tys.py b/hugr-py/src/hugr/_tys.py index 6f199959c..707af9d43 100644 --- a/hugr-py/src/hugr/_tys.py +++ b/hugr-py/src/hugr/_tys.py @@ -233,6 +233,9 @@ def to_serial(self) -> stys.FunctionType: def empty(cls) -> FunctionType: return cls(input=[], output=[]) + def flip(self) -> FunctionType: + return FunctionType(input=self.output, output=self.input) + @dataclass(frozen=True) class PolyFuncType(Type): @@ -270,3 +273,29 @@ def to_serial(self) -> stys.Qubit: Qubit = QubitDef() Bool = UnitSum(size=2) Unit = UnitSum(size=1) + + +@dataclass(frozen=True) +class ValueKind: + ty: Type + + +@dataclass(frozen=True) +class ConstKind: + ty: Type + + +@dataclass(frozen=True) +class FunctionKind: + ty: PolyFuncType + + +@dataclass(frozen=True) +class CFKind: ... + + +@dataclass(frozen=True) +class OrderKind: ... + + +Kind = ValueKind | ConstKind | FunctionKind | CFKind | OrderKind From e7b65296bd6cabc33f78f4868a9cad5065804df6 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Thu, 13 Jun 2024 23:18:25 +0100 Subject: [PATCH 02/17] feat(hugr-py): flow type from input to output while building --- hugr-py/src/hugr/_dfg.py | 55 ++++++++++++++++++--------- hugr-py/src/hugr/_ops.py | 42 ++++++++++++++++---- hugr-py/src/hugr/serialization/ops.py | 8 ++-- hugr-py/tests/test_cfg.py | 2 +- hugr-py/tests/test_hugr_build.py | 38 +++++++++--------- 5 files changed, 95 insertions(+), 50 deletions(-) diff --git a/hugr-py/src/hugr/_dfg.py b/hugr-py/src/hugr/_dfg.py index 5d31b8776..bd25c5a9f 100644 --- a/hugr-py/src/hugr/_dfg.py +++ b/hugr-py/src/hugr/_dfg.py @@ -1,13 +1,22 @@ from __future__ import annotations +from dataclasses import dataclass, replace +from typing import ( + Iterator, + Iterable, + TYPE_CHECKING, + Generic, + TypeVar, + cast, +) +from ._hugr import Hugr, Node, Wire, OutPort, ParentBuilder -from dataclasses import dataclass -from typing import TYPE_CHECKING, Generic, Iterable, TypeVar, cast from typing_extensions import Self import hugr._ops as ops from hugr._tys import FunctionType, TypeRow from ._exceptions import NoSiblingAncestor -from ._hugr import Hugr, Node, OutPort, ParentBuilder, Wire, ToNode +from ._hugr import ToNode +from hugr._tys import Type if TYPE_CHECKING: from ._cfg import Cfg @@ -61,15 +70,14 @@ def root_op(self) -> DP: def inputs(self) -> list[OutPort]: return [self.input_node.out(i) for i in range(len(self._input_op().types))] - def add_op( - self, op: ops.DataflowOp, /, *args: Wire, num_outs: int | None = None - ) -> Node: - new_n = self.hugr.add_node(op, self.root, num_outs=num_outs) + def add_op(self, op: ops.DataflowOp, /, *args: Wire) -> Node: + new_n = self.hugr.add_node(op, self.root) self._wire_up(new_n, args) - return new_n + + return replace(new_n, _num_out_ports=op.num_out) def add(self, com: ops.Command) -> Node: - return self.add_op(com.op, *com.incoming, num_outs=com.op.num_out) + return self.add_op(com.op, *com.incoming) def insert_nested(self, dfg: Dfg, *args: Wire) -> Node: mapping = self.hugr.insert_hugr(dfg.hugr, self.root) @@ -78,13 +86,13 @@ def insert_nested(self, dfg: Dfg, *args: Wire) -> Node: def add_nested( self, - input_types: TypeRow, - output_types: TypeRow, *args: Wire, ) -> Dfg: from ._dfg import Dfg - root_op = ops.DFG(FunctionType(input=input_types, output=output_types)) + _, input_types = zip(*self._get_dataflow_types(args)) if args else ([], []) + + root_op = ops.DFG(FunctionType(input=list(input_types), output=[])) dfg = Dfg.new_nested(root_op, self.hugr, self.root) self._wire_up(dfg.root, args) return dfg @@ -108,14 +116,27 @@ def insert_cfg(self, cfg: Cfg, *args: Wire) -> Node: def set_outputs(self, *args: Wire) -> None: self._wire_up(self.output_node, args) + self.root_op()._set_out_types(self._output_op().types) def add_state_order(self, src: Node, dst: Node) -> None: # adds edge to the right of all existing edges self.hugr.add_link(src.out(-1), dst.inp(-1)) def _wire_up(self, node: Node, ports: Iterable[Wire]): - for i, p in enumerate(ports): + tys = [] + for i, (p, ty) in enumerate(self._get_dataflow_types(ports)): + tys.append(ty) self._wire_up_port(node, i, p) + if isinstance(op := self.hugr[node].op, ops.DataflowOp): + op._set_in_types(tys) + + def _get_dataflow_types(self, wires: Iterable[Wire]) -> Iterator[tuple[Wire, Type]]: + for w in wires: + port = w.out_port() + ty = self.hugr.port_type(port) + if ty is None: + raise ValueError(f"Port {port} is not a dataflow port.") + yield w, ty def _wire_up_port(self, node: Node, offset: int, p: Wire): src = p.out_port() @@ -128,14 +149,10 @@ def _wire_up_port(self, node: Node, offset: int, p: Wire): class Dfg(_DfBase[ops.DFG]): - def __init__(self, input_types: TypeRow, output_types: TypeRow) -> None: - root_op = ops.DFG(FunctionType(input=input_types, output=output_types)) + def __init__(self, *input_types: Type) -> None: + root_op = ops.DFG(FunctionType(input=list(input_types), output=[])) super().__init__(root_op) - @classmethod - def endo(cls, types: TypeRow) -> Dfg: - return cls(types, types) - def _ancestral_sibling(h: Hugr, src: Node, tgt: Node) -> Node | None: src_parent = h[src].parent diff --git a/hugr-py/src/hugr/_ops.py b/hugr-py/src/hugr/_ops.py index 1fdff4feb..f07c8154c 100644 --- a/hugr-py/src/hugr/_ops.py +++ b/hugr-py/src/hugr/_ops.py @@ -1,6 +1,6 @@ from __future__ import annotations -from dataclasses import dataclass, field +from dataclasses import dataclass, field, replace from typing import Generic, Protocol, TypeVar, TYPE_CHECKING, runtime_checkable from hugr.serialization.ops import BaseOp import hugr.serialization.ops as sops @@ -39,6 +39,9 @@ def port_type(self, port: _Port) -> tys.Type: return sig.input[port.offset] return sig.output[port.offset] + def _set_in_types(self, types: tys.TypeRow) -> None: + return + def __call__(self, *args) -> Command: return Command(self, list(args)) @@ -86,7 +89,7 @@ def __call__(self) -> Command: @dataclass() class Output(DataflowOp): - types: tys.TypeRow + types: tys.TypeRow = field(default_factory=list) def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.Output: return sops.Output(parent=parent.idx, types=ser_it(self.types)) @@ -94,6 +97,9 @@ def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.Output: def outer_signature(self) -> tys.FunctionType: return tys.FunctionType(input=self.types, output=[]) + def _set_in_types(self, types: tys.TypeRow) -> None: + self.types = types + @dataclass() class Custom(DataflowOp): @@ -122,8 +128,8 @@ def outer_signature(self) -> tys.FunctionType: @dataclass() -class MakeTuple(DataflowOp): - types: tys.TypeRow +class MakeTupleDef(DataflowOp): + types: tys.TypeRow = field(default_factory=list) num_out: int | None = 1 def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.MakeTuple: @@ -138,10 +144,16 @@ def __call__(self, *elements: Wire) -> Command: def outer_signature(self) -> tys.FunctionType: return tys.FunctionType(input=self.types, output=[tys.Tuple(*self.types)]) + def _set_in_types(self, types: tys.TypeRow) -> None: + self.types = types + + +MakeTuple = MakeTupleDef() + @dataclass() -class UnpackTuple(DataflowOp): - types: tys.TypeRow +class UnpackTupleDef(DataflowOp): + types: tys.TypeRow = field(default_factory=list) @property def num_out(self) -> int | None: @@ -157,12 +169,25 @@ def __call__(self, tuple_: Wire) -> Command: return super().__call__(tuple_) def outer_signature(self) -> tys.FunctionType: - return MakeTuple(self.types).outer_signature().flip() + return MakeTupleDef(self.types).outer_signature().flip() + + def _set_in_types(self, types: tys.TypeRow) -> None: + (t,) = types + assert isinstance(t, tys.Sum), f"Expected unary Sum, got {t}" + (row,) = t.variant_rows + self.types = row + print(row) + + +UnpackTuple = UnpackTupleDef() class DfParentOp(Op, Protocol): def inner_signature(self) -> tys.FunctionType: ... + def _set_out_types(self, types: tys.TypeRow) -> None: + return + @dataclass() class DFG(DfParentOp, DataflowOp): @@ -184,6 +209,9 @@ def inner_signature(self) -> tys.FunctionType: def outer_signature(self) -> tys.FunctionType: return self.signature + def _set_out_types(self, types: tys.TypeRow) -> None: + self.signature = replace(self.signature, output=types) + @dataclass() class CFG(DataflowOp): diff --git a/hugr-py/src/hugr/serialization/ops.py b/hugr-py/src/hugr/serialization/ops.py index fbeb3e8c2..389a7c2de 100644 --- a/hugr-py/src/hugr/serialization/ops.py +++ b/hugr-py/src/hugr/serialization/ops.py @@ -447,8 +447,8 @@ def insert_port_types(self, in_types: TypeRow, out_types: TypeRow) -> None: in_types = [] self.tys = list(in_types) - def deserialize(self) -> _ops.MakeTuple: - return _ops.MakeTuple(deser_it(self.tys)) + def deserialize(self) -> _ops.MakeTupleDef: + return _ops.MakeTupleDef(deser_it(self.tys)) class UnpackTuple(DataflowOp): @@ -460,8 +460,8 @@ class UnpackTuple(DataflowOp): def insert_port_types(self, in_types: TypeRow, out_types: TypeRow) -> None: self.tys = list(out_types) - def deserialize(self) -> _ops.UnpackTuple: - return _ops.UnpackTuple(deser_it(self.tys)) + def deserialize(self) -> _ops.UnpackTupleDef: + return _ops.UnpackTupleDef(deser_it(self.tys)) class Tag(DataflowOp): diff --git a/hugr-py/tests/test_cfg.py b/hugr-py/tests/test_cfg.py index 8d41ef8c0..dd6d3f849 100644 --- a/hugr-py/tests/test_cfg.py +++ b/hugr-py/tests/test_cfg.py @@ -39,7 +39,7 @@ def test_branch() -> None: def test_nested_cfg() -> None: - dfg = Dfg([tys.Unit, tys.Bool], [tys.Bool]) + dfg = Dfg(tys.Unit, tys.Bool) cfg = dfg.add_cfg([tys.Unit, tys.Bool], [tys.Bool], *dfg.inputs()) diff --git a/hugr-py/tests/test_hugr_build.py b/hugr-py/tests/test_hugr_build.py index 167f28d2f..54d92ea99 100644 --- a/hugr-py/tests/test_hugr_build.py +++ b/hugr-py/tests/test_hugr_build.py @@ -120,7 +120,7 @@ def test_stable_indices(): def test_simple_id(): - h = Dfg.endo([tys.Qubit] * 2) + h = Dfg(tys.Qubit, tys.Qubit) a, b = h.inputs() h.set_outputs(a, b) @@ -128,7 +128,7 @@ def test_simple_id(): def test_multiport(): - h = Dfg([tys.Bool], [tys.Bool] * 2) + h = Dfg(tys.Bool) (a,) = h.inputs() h.set_outputs(a, a) in_n, ou_n = h.input_node, h.output_node @@ -151,7 +151,7 @@ def test_multiport(): def test_add_op(): - h = Dfg.endo([tys.Bool]) + h = Dfg(tys.Bool) (a,) = h.inputs() nt = h.add_op(Not, a) h.set_outputs(nt) @@ -161,25 +161,25 @@ def test_add_op(): def test_tuple(): row = [tys.Bool, tys.Qubit] - h = Dfg.endo(row) + h = Dfg(*row) a, b = h.inputs() - t = h.add(ops.MakeTuple(row)(a, b)) - a, b = h.add(ops.UnpackTuple(row)(t)) + t = h.add(ops.MakeTuple(a, b)) + a, b = h.add(ops.UnpackTuple(t)) h.set_outputs(a, b) _validate(h.hugr) - h1 = Dfg.endo(row) + h1 = Dfg(*row) a, b = h1.inputs() - mt = h1.add_op(ops.MakeTuple(row), a, b) - a, b = h1.add_op(ops.UnpackTuple(row), mt)[0, 1] + mt = h1.add_op(ops.MakeTuple, a, b) + a, b = h1.add_op(ops.UnpackTuple, mt)[0, 1] h1.set_outputs(a, b) assert h.hugr.to_serial() == h1.hugr.to_serial() def test_multi_out(): - h = Dfg([INT_T] * 2, [INT_T] * 2) + h = Dfg(INT_T, INT_T) a, b = h.inputs() a, b = h.add(DivMod(a, b)) h.set_outputs(a, b) @@ -187,7 +187,7 @@ def test_multi_out(): def test_insert(): - h1 = Dfg.endo([tys.Bool]) + h1 = Dfg(tys.Bool) (a1,) = h1.inputs() nt = h1.add(Not(a1)) h1.set_outputs(nt) @@ -200,12 +200,12 @@ def test_insert(): def test_insert_nested(): - h1 = Dfg.endo([tys.Bool]) + h1 = Dfg(tys.Bool) (a1,) = h1.inputs() nt = h1.add(Not(a1)) h1.set_outputs(nt) - h = Dfg.endo([tys.Bool]) + h = Dfg(tys.Bool) (a,) = h.inputs() nested = h.insert_nested(h1, a) h.set_outputs(nested) @@ -219,9 +219,9 @@ def _nested_nop(dfg: Dfg): nt = dfg.add(Not(a1)) dfg.set_outputs(nt) - h = Dfg.endo([tys.Bool]) + h = Dfg(tys.Bool) (a,) = h.inputs() - nested = h.add_nested([tys.Bool], [tys.Bool], a) + nested = h.add_nested(a) _nested_nop(nested) assert len(h.hugr.children(nested)) == 3 @@ -231,9 +231,9 @@ def _nested_nop(dfg: Dfg): def test_build_inter_graph(): - h = Dfg.endo([tys.Bool, tys.Bool]) + h = Dfg(tys.Bool, tys.Bool) (a, b) = h.inputs() - nested = h.add_nested([], [tys.Bool]) + nested = h.add_nested() nt = nested.add(Not(a)) nested.set_outputs(nt) @@ -250,9 +250,9 @@ def test_build_inter_graph(): def test_ancestral_sibling(): - h = Dfg.endo([tys.Bool]) + h = Dfg(tys.Bool) (a,) = h.inputs() - nested = h.add_nested([], [tys.Bool]) + nested = h.add_nested() nt = nested.add(Not(a)) From 9f72d6125dc19bf58e6c6230969a53f80530dd7c Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Thu, 13 Jun 2024 23:25:51 +0100 Subject: [PATCH 03/17] refactorL common up op casting --- hugr-py/src/hugr/_cfg.py | 7 ++++--- hugr-py/src/hugr/_dfg.py | 8 ++------ hugr-py/src/hugr/_hugr.py | 7 +++++++ 3 files changed, 13 insertions(+), 9 deletions(-) diff --git a/hugr-py/src/hugr/_cfg.py b/hugr-py/src/hugr/_cfg.py index 2001b53a5..76102bf66 100644 --- a/hugr-py/src/hugr/_cfg.py +++ b/hugr-py/src/hugr/_cfg.py @@ -84,9 +84,10 @@ def entry(self) -> Node: return self._entry_block.root def _entry_op(self) -> ops.DataflowBlock: - dop = self.hugr[self.entry].op - assert isinstance(dop, ops.DataflowBlock) - return dop + return self.hugr._get_typed_op(self.entry, ops.DataflowBlock) + + def _exit_op(self) -> ops.ExitBlock: + return self.hugr._get_typed_op(self.exit, ops.ExitBlock) def add_entry(self, sum_rows: Sequence[TypeRow], other_outputs: TypeRow) -> Block: # update entry block types diff --git a/hugr-py/src/hugr/_dfg.py b/hugr-py/src/hugr/_dfg.py index bd25c5a9f..b9e5356cc 100644 --- a/hugr-py/src/hugr/_dfg.py +++ b/hugr-py/src/hugr/_dfg.py @@ -55,14 +55,10 @@ def new_nested(cls, root_op: DP, hugr: Hugr, parent: ToNode | None = None) -> Se return new def _input_op(self) -> ops.Input: - dop = self.hugr[self.input_node].op - assert isinstance(dop, ops.Input) - return dop + return self.hugr._get_typed_op(self.input_node, ops.Input) def _output_op(self) -> ops.Output: - dop = self.hugr[self.output_node].op - assert isinstance(dop, ops.Output) - return dop + return self.hugr._get_typed_op(self.output_node, ops.Output) def root_op(self) -> DP: return cast(DP, self.hugr[self.root].op) diff --git a/hugr-py/src/hugr/_hugr.py b/hugr-py/src/hugr/_hugr.py index 3b072953c..a1c48f883 100644 --- a/hugr-py/src/hugr/_hugr.py +++ b/hugr-py/src/hugr/_hugr.py @@ -12,6 +12,7 @@ TypeVar, cast, overload, + Type as PyType, ) from typing_extensions import Self @@ -131,6 +132,7 @@ def to_serial(self, node: Node, hugr: Hugr) -> SerialOp: P = TypeVar("P", InPort, OutPort) K = TypeVar("K", InPort, OutPort) +OpVar = TypeVar("OpVar", bound=Op) @dataclass(frozen=True, eq=True, order=True) @@ -183,6 +185,11 @@ def __iter__(self): def __len__(self) -> int: return self.num_nodes() + def _get_typed_op(self, node: ToNode, cl: PyType[OpVar]) -> OpVar: + op = self[node].op + assert isinstance(op, cl) + return op + def children(self, node: ToNode | None = None) -> list[Node]: node = node or self.root return self[node].children From fcdbd5caf778e869b560157892575cf8fb8d9cf1 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Mon, 17 Jun 2024 16:24:52 +0100 Subject: [PATCH 04/17] clean up wiring up inheritance --- hugr-py/src/hugr/_cfg.py | 40 ++++++++++++++++++++-------------------- hugr-py/src/hugr/_dfg.py | 24 ++++++++++-------------- 2 files changed, 30 insertions(+), 34 deletions(-) diff --git a/hugr-py/src/hugr/_cfg.py b/hugr-py/src/hugr/_cfg.py index 76102bf66..e87c56b0f 100644 --- a/hugr-py/src/hugr/_cfg.py +++ b/hugr-py/src/hugr/_cfg.py @@ -1,14 +1,14 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Iterable, Sequence +from typing import Sequence import hugr._ops as ops from ._dfg import _DfBase from ._exceptions import NoSiblingAncestor, NotInSameCfg from ._hugr import Hugr, Node, ParentBuilder, ToNode, Wire -from ._tys import FunctionType, Sum, TypeRow +from ._tys import FunctionType, Sum, TypeRow, Type class Block(_DfBase[ops.DataflowBlock]): @@ -19,24 +19,24 @@ def set_single_successor_outputs(self, *outputs: Wire) -> None: # TODO requires constants raise NotImplementedError - def _wire_up(self, node: Node, ports: Iterable[Wire]): - for i, p in enumerate(ports): - src = p.out_port() - cfg_node = self.hugr[self.root].parent - assert cfg_node is not None - src_parent = self.hugr[src.node].parent - try: - self._wire_up_port(node, i, p) - except NoSiblingAncestor: - # note this just checks if there is a common CFG ancestor - # it does not check for valid dominance between basic blocks - # that is deferred to full HUGR validation. - while cfg_node != src_parent: - if src_parent is None or src_parent == self.hugr.root: - raise NotInSameCfg(src.node.idx, node.idx) - src_parent = self.hugr[src_parent].parent - - self.hugr.add_link(src, node.inp(i)) + def _wire_up_port(self, node: Node, offset: int, p: Wire) -> Type: + src = p.out_port() + cfg_node = self.hugr[self.root].parent + assert cfg_node is not None + src_parent = self.hugr[src.node].parent + try: + super()._wire_up_port(node, offset, p) + except NoSiblingAncestor: + # note this just checks if there is a common CFG ancestor + # it does not check for valid dominance between basic blocks + # that is deferred to full HUGR validation. + while cfg_node != src_parent: + if src_parent is None or src_parent == self.hugr.root: + raise NotInSameCfg(src.node.idx, node.idx) + src_parent = self.hugr[src_parent].parent + + self.hugr.add_link(src, node.inp(offset)) + return self._get_dataflow_type(src) @dataclass diff --git a/hugr-py/src/hugr/_dfg.py b/hugr-py/src/hugr/_dfg.py index b9e5356cc..a780d46f4 100644 --- a/hugr-py/src/hugr/_dfg.py +++ b/hugr-py/src/hugr/_dfg.py @@ -1,7 +1,6 @@ from __future__ import annotations from dataclasses import dataclass, replace from typing import ( - Iterator, Iterable, TYPE_CHECKING, Generic, @@ -86,7 +85,7 @@ def add_nested( ) -> Dfg: from ._dfg import Dfg - _, input_types = zip(*self._get_dataflow_types(args)) if args else ([], []) + input_types = [self._get_dataflow_type(w) for w in args] root_op = ops.DFG(FunctionType(input=list(input_types), output=[])) dfg = Dfg.new_nested(root_op, self.hugr, self.root) @@ -119,22 +118,18 @@ def add_state_order(self, src: Node, dst: Node) -> None: self.hugr.add_link(src.out(-1), dst.inp(-1)) def _wire_up(self, node: Node, ports: Iterable[Wire]): - tys = [] - for i, (p, ty) in enumerate(self._get_dataflow_types(ports)): - tys.append(ty) - self._wire_up_port(node, i, p) + tys = [self._wire_up_port(node, i, p) for i, p in enumerate(ports)] if isinstance(op := self.hugr[node].op, ops.DataflowOp): op._set_in_types(tys) - def _get_dataflow_types(self, wires: Iterable[Wire]) -> Iterator[tuple[Wire, Type]]: - for w in wires: - port = w.out_port() - ty = self.hugr.port_type(port) - if ty is None: - raise ValueError(f"Port {port} is not a dataflow port.") - yield w, ty + def _get_dataflow_type(self, wire: Wire) -> Type: + port = wire.out_port() + ty = self.hugr.port_type(port) + if ty is None: + raise ValueError(f"Port {port} is not a dataflow port.") + return ty - def _wire_up_port(self, node: Node, offset: int, p: Wire): + def _wire_up_port(self, node: Node, offset: int, p: Wire) -> Type: src = p.out_port() node_ancestor = _ancestral_sibling(self.hugr, src.node, node) if node_ancestor is None: @@ -142,6 +137,7 @@ def _wire_up_port(self, node: Node, offset: int, p: Wire): if node_ancestor != node: self.add_state_order(src.node, node_ancestor) self.hugr.add_link(src, node.inp(offset)) + return self._get_dataflow_type(src) class Dfg(_DfBase[ops.DFG]): From dcfebc47ba91499b0df101cc9874cb95d340cfe6 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Mon, 17 Jun 2024 16:42:50 +0100 Subject: [PATCH 05/17] make builders/hugr generic over root op --- hugr-py/src/hugr/_cfg.py | 14 +++++------ hugr-py/src/hugr/_dfg.py | 43 +++++++++++++++----------------- hugr-py/src/hugr/_hugr.py | 18 ++++++++----- hugr-py/tests/test_hugr_build.py | 2 +- 4 files changed, 40 insertions(+), 37 deletions(-) diff --git a/hugr-py/src/hugr/_cfg.py b/hugr-py/src/hugr/_cfg.py index e87c56b0f..5a248713e 100644 --- a/hugr-py/src/hugr/_cfg.py +++ b/hugr-py/src/hugr/_cfg.py @@ -21,7 +21,7 @@ def set_single_successor_outputs(self, *outputs: Wire) -> None: def _wire_up_port(self, node: Node, offset: int, p: Wire) -> Type: src = p.out_port() - cfg_node = self.hugr[self.root].parent + cfg_node = self.hugr[self.parent_node].parent assert cfg_node is not None src_parent = self.hugr[src.node].parent try: @@ -40,9 +40,9 @@ def _wire_up_port(self, node: Node, offset: int, p: Wire) -> Type: @dataclass -class Cfg(ParentBuilder): +class Cfg(ParentBuilder[ops.CFG]): hugr: Hugr - root: Node + parent_node: Node _entry_block: Block exit: Node @@ -55,13 +55,13 @@ def _init_impl( self: Cfg, hugr: Hugr, root: Node, input_types: TypeRow, output_types: TypeRow ) -> None: self.hugr = hugr - self.root = root + self.parent_node = root # to ensure entry is first child, add a dummy entry at the start self._entry_block = Block.new_nested( ops.DataflowBlock(input_types, []), hugr, root ) - self.exit = self.hugr.add_node(ops.ExitBlock(output_types), self.root) + self.exit = self.hugr.add_node(ops.ExitBlock(output_types), self.parent_node) @classmethod def new_nested( @@ -81,7 +81,7 @@ def new_nested( @property def entry(self) -> Node: - return self._entry_block.root + return self._entry_block.parent_node def _entry_op(self) -> ops.DataflowBlock: return self.hugr._get_typed_op(self.entry, ops.DataflowBlock) @@ -105,7 +105,7 @@ def add_block( new_block = Block.new_nested( ops.DataflowBlock(input_types, list(sum_rows), other_outputs), self.hugr, - self.root, + self.parent_node, ) return new_block diff --git a/hugr-py/src/hugr/_dfg.py b/hugr-py/src/hugr/_dfg.py index a780d46f4..5a9fff66e 100644 --- a/hugr-py/src/hugr/_dfg.py +++ b/hugr-py/src/hugr/_dfg.py @@ -3,9 +3,7 @@ from typing import ( Iterable, TYPE_CHECKING, - Generic, TypeVar, - cast, ) from ._hugr import Hugr, Node, Wire, OutPort, ParentBuilder @@ -25,31 +23,33 @@ @dataclass() -class _DfBase(ParentBuilder, Generic[DP]): +class _DfBase(ParentBuilder[DP]): hugr: Hugr - root: Node + parent_node: Node input_node: Node output_node: Node def __init__(self, root_op: DP) -> None: self.hugr = Hugr(root_op) - self.root = self.hugr.root + self.parent_node = self.hugr.root self._init_io_nodes(root_op) def _init_io_nodes(self, root_op: DP): inner_sig = root_op.inner_signature() self.input_node = self.hugr.add_node( - ops.Input(inner_sig.input), self.root, len(inner_sig.input) + ops.Input(inner_sig.input), self.parent_node, len(inner_sig.input) + ) + self.output_node = self.hugr.add_node( + ops.Output(inner_sig.output), self.parent_node ) - self.output_node = self.hugr.add_node(ops.Output(inner_sig.output), self.root) @classmethod def new_nested(cls, root_op: DP, hugr: Hugr, parent: ToNode | None = None) -> Self: new = cls.__new__(cls) new.hugr = hugr - new.root = hugr.add_node(root_op, parent or hugr.root) + new.parent_node = hugr.add_node(root_op, parent or hugr.root) new._init_io_nodes(root_op) return new @@ -59,14 +59,11 @@ def _input_op(self) -> ops.Input: def _output_op(self) -> ops.Output: return self.hugr._get_typed_op(self.output_node, ops.Output) - def root_op(self) -> DP: - return cast(DP, self.hugr[self.root].op) - def inputs(self) -> list[OutPort]: return [self.input_node.out(i) for i in range(len(self._input_op().types))] def add_op(self, op: ops.DataflowOp, /, *args: Wire) -> Node: - new_n = self.hugr.add_node(op, self.root) + new_n = self.hugr.add_node(op, self.parent_node) self._wire_up(new_n, args) return replace(new_n, _num_out_ports=op.num_out) @@ -75,9 +72,9 @@ def add(self, com: ops.Command) -> Node: return self.add_op(com.op, *com.incoming) def insert_nested(self, dfg: Dfg, *args: Wire) -> Node: - mapping = self.hugr.insert_hugr(dfg.hugr, self.root) - self._wire_up(mapping[dfg.root], args) - return mapping[dfg.root] + mapping = self.hugr.insert_hugr(dfg.hugr, self.parent_node) + self._wire_up(mapping[dfg.parent_node], args) + return mapping[dfg.parent_node] def add_nested( self, @@ -88,8 +85,8 @@ def add_nested( input_types = [self._get_dataflow_type(w) for w in args] root_op = ops.DFG(FunctionType(input=list(input_types), output=[])) - dfg = Dfg.new_nested(root_op, self.hugr, self.root) - self._wire_up(dfg.root, args) + dfg = Dfg.new_nested(root_op, self.hugr, self.parent_node) + self._wire_up(dfg.parent_node, args) return dfg def add_cfg( @@ -100,18 +97,18 @@ def add_cfg( ) -> Cfg: from ._cfg import Cfg - cfg = Cfg.new_nested(input_types, output_types, self.hugr, self.root) - self._wire_up(cfg.root, args) + cfg = Cfg.new_nested(input_types, output_types, self.hugr, self.parent_node) + self._wire_up(cfg.parent_node, args) return cfg def insert_cfg(self, cfg: Cfg, *args: Wire) -> Node: - mapping = self.hugr.insert_hugr(cfg.hugr, self.root) - self._wire_up(mapping[cfg.root], args) - return mapping[cfg.root] + mapping = self.hugr.insert_hugr(cfg.hugr, self.parent_node) + self._wire_up(mapping[cfg.parent_node], args) + return mapping[cfg.parent_node] def set_outputs(self, *args: Wire) -> None: self._wire_up(self.output_node, args) - self.root_op()._set_out_types(self._output_op().types) + self.parent_op()._set_out_types(self._output_op().types) def add_state_order(self, src: Node, dst: Node) -> None: # adds edge to the right of all existing edges diff --git a/hugr-py/src/hugr/_hugr.py b/hugr-py/src/hugr/_hugr.py index a1c48f883..cb1d85f66 100644 --- a/hugr-py/src/hugr/_hugr.py +++ b/hugr-py/src/hugr/_hugr.py @@ -148,22 +148,25 @@ def next_sub_offset(self) -> Self: _SI = _SubPort[InPort] -class ParentBuilder(ToNode, Protocol): - hugr: Hugr - root: Node +class ParentBuilder(ToNode, Protocol[OpVar]): + hugr: Hugr[OpVar] + parent_node: Node def to_node(self) -> Node: - return self.root + return self.parent_node + + def parent_op(self) -> OpVar: + return cast(OpVar, self.hugr[self.parent_node].op) @dataclass() -class Hugr(Mapping[Node, NodeData]): +class Hugr(Mapping[Node, NodeData], Generic[OpVar]): root: Node _nodes: list[NodeData | None] _links: BiMap[_SO, _SI] _free_nodes: list[Node] - def __init__(self, root_op: Op) -> None: + def __init__(self, root_op: OpVar) -> None: self._free_nodes = [] self._links = BiMap() self._nodes = [] @@ -269,6 +272,9 @@ def delete_link(self, src: OutPort, dst: InPort) -> None: return # TODO make sure sub-offset is handled correctly + def root_op(self) -> OpVar: + return cast(OpVar, self[self.root].op) + def num_nodes(self) -> int: return len(self._nodes) - len(self._free_nodes) diff --git a/hugr-py/tests/test_hugr_build.py b/hugr-py/tests/test_hugr_build.py index 54d92ea99..1f0e5acd5 100644 --- a/hugr-py/tests/test_hugr_build.py +++ b/hugr-py/tests/test_hugr_build.py @@ -256,4 +256,4 @@ def test_ancestral_sibling(): nt = nested.add(Not(a)) - assert _ancestral_sibling(h.hugr, h.input_node, nt) == nested.root + assert _ancestral_sibling(h.hugr, h.input_node, nt) == nested.parent_node From e8b3a176f6f896447c3692e45a0f7e89a66d1f96 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Mon, 17 Jun 2024 18:26:51 +0100 Subject: [PATCH 06/17] feat: flow input types to output in cfg building --- hugr-py/src/hugr/_cfg.py | 59 +++++++++++++++------------------ hugr-py/src/hugr/_dfg.py | 2 +- hugr-py/src/hugr/_exceptions.py | 11 ++++++ hugr-py/src/hugr/_ops.py | 26 ++++++++++++++- hugr-py/src/hugr/_tys.py | 20 ++++++----- hugr-py/tests/test_cfg.py | 42 +++++++++++++++++------ 6 files changed, 108 insertions(+), 52 deletions(-) diff --git a/hugr-py/src/hugr/_cfg.py b/hugr-py/src/hugr/_cfg.py index 5a248713e..0141c4419 100644 --- a/hugr-py/src/hugr/_cfg.py +++ b/hugr-py/src/hugr/_cfg.py @@ -1,14 +1,13 @@ from __future__ import annotations -from dataclasses import dataclass -from typing import Sequence +from dataclasses import dataclass, replace import hugr._ops as ops from ._dfg import _DfBase -from ._exceptions import NoSiblingAncestor, NotInSameCfg +from ._exceptions import NoSiblingAncestor, NotInSameCfg, MismatchedExit from ._hugr import Hugr, Node, ParentBuilder, ToNode, Wire -from ._tys import FunctionType, Sum, TypeRow, Type +from ._tys import FunctionType, TypeRow, Type class Block(_DfBase[ops.DataflowBlock]): @@ -46,14 +45,12 @@ class Cfg(ParentBuilder[ops.CFG]): _entry_block: Block exit: Node - def __init__(self, input_types: TypeRow, output_types: TypeRow) -> None: - root_op = ops.CFG(FunctionType(input=input_types, output=output_types)) + def __init__(self, input_types: TypeRow) -> None: + root_op = ops.CFG(FunctionType(input=input_types, output=[])) hugr = Hugr(root_op) - self._init_impl(hugr, hugr.root, input_types, output_types) + self._init_impl(hugr, hugr.root, input_types) - def _init_impl( - self: Cfg, hugr: Hugr, root: Node, input_types: TypeRow, output_types: TypeRow - ) -> None: + def _init_impl(self: Cfg, hugr: Hugr, root: Node, input_types: TypeRow) -> None: self.hugr = hugr self.parent_node = root # to ensure entry is first child, add a dummy entry at the start @@ -61,22 +58,21 @@ def _init_impl( ops.DataflowBlock(input_types, []), hugr, root ) - self.exit = self.hugr.add_node(ops.ExitBlock(output_types), self.parent_node) + self.exit = self.hugr.add_node(ops.ExitBlock([]), self.parent_node) @classmethod def new_nested( cls, input_types: TypeRow, - output_types: TypeRow, hugr: Hugr, parent: ToNode | None = None, ) -> Cfg: new = cls.__new__(cls) root = hugr.add_node( - ops.CFG(FunctionType(input=input_types, output=output_types)), + ops.CFG(FunctionType(input=input_types, output=[])), parent or hugr.root, ) - new._init_impl(hugr, root, input_types, output_types) + new._init_impl(hugr, root, input_types) return new @property @@ -89,30 +85,29 @@ def _entry_op(self) -> ops.DataflowBlock: def _exit_op(self) -> ops.ExitBlock: return self.hugr._get_typed_op(self.exit, ops.ExitBlock) - def add_entry(self, sum_rows: Sequence[TypeRow], other_outputs: TypeRow) -> Block: - # update entry block types - self._entry_op().sum_rows = list(sum_rows) - self._entry_op().other_outputs = other_outputs - self._entry_block._output_op().types = [Sum(list(sum_rows)), *other_outputs] + def add_entry(self) -> Block: return self._entry_block - def simple_entry(self, n_branches: int, other_outputs: TypeRow) -> Block: - return self.add_entry([[]] * n_branches, other_outputs) - - def add_block( - self, input_types: TypeRow, sum_rows: Sequence[TypeRow], other_outputs: TypeRow - ) -> Block: + def add_block(self, input_types: TypeRow) -> Block: new_block = Block.new_nested( - ops.DataflowBlock(input_types, list(sum_rows), other_outputs), + ops.DataflowBlock(input_types, [], []), self.hugr, self.parent_node, ) return new_block - def simple_block( - self, input_types: TypeRow, n_branches: int, other_outputs: TypeRow - ) -> Block: - return self.add_block(input_types, [[]] * n_branches, other_outputs) - def branch(self, src: Wire, dst: ToNode) -> None: - self.hugr.add_link(src.out_port(), dst.inp(0)) + src = src.out_port() + self.hugr.add_link(src, dst.inp(0)) + + if dst == self.exit: + src_block = self.hugr._get_typed_op(src.node, ops.DataflowBlock) + out_types = [*src_block.sum_rows[src.offset], *src_block.other_outputs] + if self._exit_op().cfg_outputs: + if self._exit_op().cfg_outputs != out_types: + raise MismatchedExit(src.node.idx) + else: + self._exit_op().cfg_outputs = out_types + self.parent_op().signature = replace( + self.parent_op().signature, output=out_types + ) diff --git a/hugr-py/src/hugr/_dfg.py b/hugr-py/src/hugr/_dfg.py index 5a9fff66e..f0658f397 100644 --- a/hugr-py/src/hugr/_dfg.py +++ b/hugr-py/src/hugr/_dfg.py @@ -97,7 +97,7 @@ def add_cfg( ) -> Cfg: from ._cfg import Cfg - cfg = Cfg.new_nested(input_types, output_types, self.hugr, self.parent_node) + cfg = Cfg.new_nested(input_types, self.hugr, self.parent_node) self._wire_up(cfg.parent_node, args) return cfg diff --git a/hugr-py/src/hugr/_exceptions.py b/hugr-py/src/hugr/_exceptions.py index 92ba7ceb0..e5c211d44 100644 --- a/hugr-py/src/hugr/_exceptions.py +++ b/hugr-py/src/hugr/_exceptions.py @@ -21,5 +21,16 @@ def msg(self): return f"Source {self.src} is not in the same CFG as target {self.tgt}, so cannot wire up." +@dataclass +class MismatchedExit(Exception): + src: int + + @property + def msg(self): + return ( + f"Exit branch from node {self.src} does not match existing exit block type." + ) + + class ParentBeforeChild(Exception): msg: str = "Parent node must be added before child node." diff --git a/hugr-py/src/hugr/_ops.py b/hugr-py/src/hugr/_ops.py index f07c8154c..149db8680 100644 --- a/hugr-py/src/hugr/_ops.py +++ b/hugr-py/src/hugr/_ops.py @@ -176,12 +176,30 @@ def _set_in_types(self, types: tys.TypeRow) -> None: assert isinstance(t, tys.Sum), f"Expected unary Sum, got {t}" (row,) = t.variant_rows self.types = row - print(row) UnpackTuple = UnpackTupleDef() +@dataclass() +class Tag(DataflowOp): + tag: int + variants: list[tys.TypeRow] + num_out: int | None = 1 + + def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.Tag: + return sops.Tag( + parent=parent.idx, + tag=self.tag, + variants=[ser_it(r) for r in self.variants], + ) + + def outer_signature(self) -> tys.FunctionType: + return tys.FunctionType( + input=self.variants[self.tag], output=[tys.Sum(self.variants)] + ) + + class DfParentOp(Op, Protocol): def inner_signature(self) -> tys.FunctionType: ... @@ -259,6 +277,12 @@ def inner_signature(self) -> tys.FunctionType: def port_kind(self, port: _Port) -> tys.Kind: return tys.CFKind() + def _set_out_types(self, types: tys.TypeRow) -> None: + (sum_, *other) = types + assert isinstance(sum_, tys.Sum), f"Expected Sum, got {sum_}" + self.sum_rows = sum_.variant_rows + self.other_outputs = other + @dataclass class ExitBlock(Op): diff --git a/hugr-py/src/hugr/_tys.py b/hugr-py/src/hugr/_tys.py index 707af9d43..bcf59128c 100644 --- a/hugr-py/src/hugr/_tys.py +++ b/hugr-py/src/hugr/_tys.py @@ -159,14 +159,6 @@ def to_serial(self) -> stys.Array: return stys.Array(ty=self.ty.to_serial_root(), len=self.size) -@dataclass(frozen=True) -class UnitSum(Type): - size: int - - def to_serial(self) -> stys.UnitSum: - return stys.UnitSum(size=self.size) - - @dataclass() class Sum(Type): variant_rows: list[TypeRow] @@ -181,6 +173,18 @@ def as_tuple(self) -> Tuple: return Tuple(*self.variant_rows[0]) +@dataclass() +class UnitSum(Sum): + size: int + + def __init__(self, size: int): + self.size = size + super().__init__(variant_rows=[[]] * size) + + def to_serial(self) -> stys.UnitSum: # type: ignore[override] + return stys.UnitSum(size=self.size) + + @dataclass() class Tuple(Sum): def __init__(self, *tys: Type): diff --git a/hugr-py/tests/test_cfg.py b/hugr-py/tests/test_cfg.py index dd6d3f849..d4378903b 100644 --- a/hugr-py/tests/test_cfg.py +++ b/hugr-py/tests/test_cfg.py @@ -1,30 +1,31 @@ from hugr._cfg import Cfg import hugr._tys as tys from hugr._dfg import Dfg +import hugr._ops as ops from .test_hugr_build import _validate, INT_T, DivMod def build_basic_cfg(cfg: Cfg) -> None: - entry = cfg.simple_entry(1, [tys.Bool]) + entry = cfg.add_entry() entry.set_block_outputs(*entry.inputs()) cfg.branch(entry[0], cfg.exit) def test_basic_cfg() -> None: - cfg = Cfg([tys.Unit, tys.Bool], [tys.Bool]) + cfg = Cfg([tys.Unit, tys.Bool]) build_basic_cfg(cfg) _validate(cfg.hugr) def test_branch() -> None: - cfg = Cfg([tys.Bool, tys.Unit, INT_T], [INT_T]) - entry = cfg.simple_entry(2, [tys.Unit, INT_T]) + cfg = Cfg([tys.Bool, tys.Unit, INT_T]) + entry = cfg.add_entry() entry.set_block_outputs(*entry.inputs()) - middle_1 = cfg.simple_block([tys.Unit, INT_T], 1, [INT_T]) + middle_1 = cfg.add_block([tys.Unit, INT_T]) middle_1.set_block_outputs(*middle_1.inputs()) - middle_2 = cfg.simple_block([tys.Unit, INT_T], 1, [INT_T]) + middle_2 = cfg.add_block([tys.Unit, INT_T]) u, i = middle_2.inputs() n = middle_2.add(DivMod(i, i)) middle_2.set_block_outputs(u, n[0]) @@ -50,16 +51,16 @@ def test_nested_cfg() -> None: def test_dom_edge() -> None: - cfg = Cfg([tys.Bool, tys.Unit, INT_T], [INT_T]) - entry = cfg.simple_entry(2, [INT_T]) + cfg = Cfg([tys.Bool, tys.Unit, INT_T]) + entry = cfg.add_entry() b, u, i = entry.inputs() entry.set_block_outputs(b, i) # entry dominates both middles so Unit type can be used as inter-graph # value between basic blocks - middle_1 = cfg.simple_block([INT_T], 1, [INT_T]) + middle_1 = cfg.add_block([INT_T]) middle_1.set_block_outputs(u, *middle_1.inputs()) - middle_2 = cfg.simple_block([INT_T], 1, [INT_T]) + middle_2 = cfg.add_block([INT_T]) middle_2.set_block_outputs(u, *middle_2.inputs()) cfg.branch(entry[0], middle_1) @@ -69,3 +70,24 @@ def test_dom_edge() -> None: cfg.branch(middle_2[0], cfg.exit) _validate(cfg.hugr) + + +def test_asymm_types() -> None: + # test different types going to entry block's susccessors + cfg = Cfg([tys.Bool, tys.Unit, INT_T]) + entry = cfg.add_entry() + b, u, i = entry.inputs() + + tagged_int = entry.add(ops.Tag(0, [[INT_T], [tys.Bool]])(i)) + entry.set_block_outputs(tagged_int) + + middle = cfg.add_block([INT_T]) + # discard the int and return the bool from entry + middle.set_block_outputs(u, b) + + # middle expects an int and exit expects a bool + cfg.branch(entry[0], middle) + cfg.branch(entry[1], cfg.exit) + cfg.branch(middle[0], cfg.exit) + + _validate(cfg.hugr) From cf135eb458f4a057d0a07dfa5c5c254c84f863de Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Tue, 18 Jun 2024 09:45:07 +0100 Subject: [PATCH 07/17] feat: `PartialOp` protocol for cleaner op mutation --- hugr-py/src/hugr/_dfg.py | 6 ++-- hugr-py/src/hugr/_exceptions.py | 5 ++++ hugr-py/src/hugr/_ops.py | 51 ++++++++++++++++++++++----------- 3 files changed, 42 insertions(+), 20 deletions(-) diff --git a/hugr-py/src/hugr/_dfg.py b/hugr-py/src/hugr/_dfg.py index f0658f397..54fd20888 100644 --- a/hugr-py/src/hugr/_dfg.py +++ b/hugr-py/src/hugr/_dfg.py @@ -108,7 +108,7 @@ def insert_cfg(self, cfg: Cfg, *args: Wire) -> Node: def set_outputs(self, *args: Wire) -> None: self._wire_up(self.output_node, args) - self.parent_op()._set_out_types(self._output_op().types) + self.parent_op()._set_out_types(self._output_op()._types()) def add_state_order(self, src: Node, dst: Node) -> None: # adds edge to the right of all existing edges @@ -116,8 +116,8 @@ def add_state_order(self, src: Node, dst: Node) -> None: def _wire_up(self, node: Node, ports: Iterable[Wire]): tys = [self._wire_up_port(node, i, p) for i, p in enumerate(ports)] - if isinstance(op := self.hugr[node].op, ops.DataflowOp): - op._set_in_types(tys) + if isinstance(op := self.hugr[node].op, ops.PartialOp): + op.set_in_types(tys) def _get_dataflow_type(self, wire: Wire) -> Type: port = wire.out_port() diff --git a/hugr-py/src/hugr/_exceptions.py b/hugr-py/src/hugr/_exceptions.py index e5c211d44..f003fc6c7 100644 --- a/hugr-py/src/hugr/_exceptions.py +++ b/hugr-py/src/hugr/_exceptions.py @@ -34,3 +34,8 @@ def msg(self): class ParentBeforeChild(Exception): msg: str = "Parent node must be added before child node." + + +@dataclass +class IncompleteOp(Exception): + msg: str = "Operation is incomplete, may require set_in_types to be called." diff --git a/hugr-py/src/hugr/_ops.py b/hugr-py/src/hugr/_ops.py index 149db8680..268394223 100644 --- a/hugr-py/src/hugr/_ops.py +++ b/hugr-py/src/hugr/_ops.py @@ -6,6 +6,7 @@ import hugr.serialization.ops as sops from hugr.utils import ser_it import hugr._tys as tys +from ._exceptions import IncompleteOp if TYPE_CHECKING: from hugr._hugr import Hugr, Node, Wire, _Port @@ -39,13 +40,15 @@ def port_type(self, port: _Port) -> tys.Type: return sig.input[port.offset] return sig.output[port.offset] - def _set_in_types(self, types: tys.TypeRow) -> None: - return - def __call__(self, *args) -> Command: return Command(self, list(args)) +@runtime_checkable +class PartialOp(DataflowOp, Protocol): + def set_in_types(self, types: tys.TypeRow) -> None: ... + + @dataclass(frozen=True) class Command: op: DataflowOp @@ -88,16 +91,21 @@ def __call__(self) -> Command: @dataclass() -class Output(DataflowOp): - types: tys.TypeRow = field(default_factory=list) +class Output(PartialOp): + types: tys.TypeRow | None = None + + def _types(self) -> tys.TypeRow: + if self.types is None: + raise IncompleteOp() + return self.types def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.Output: - return sops.Output(parent=parent.idx, types=ser_it(self.types)) + return sops.Output(parent=parent.idx, types=ser_it(self._types())) def outer_signature(self) -> tys.FunctionType: - return tys.FunctionType(input=self.types, output=[]) + return tys.FunctionType(input=self._types(), output=[]) - def _set_in_types(self, types: tys.TypeRow) -> None: + def set_in_types(self, types: tys.TypeRow) -> None: self.types = types @@ -128,23 +136,28 @@ def outer_signature(self) -> tys.FunctionType: @dataclass() -class MakeTupleDef(DataflowOp): - types: tys.TypeRow = field(default_factory=list) +class MakeTupleDef(PartialOp): + types: tys.TypeRow | None = None num_out: int | None = 1 + def _types(self) -> tys.TypeRow: + if self.types is None: + raise IncompleteOp() + return self.types + def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.MakeTuple: return sops.MakeTuple( parent=parent.idx, - tys=ser_it(self.types), + tys=ser_it(self._types()), ) def __call__(self, *elements: Wire) -> Command: return super().__call__(*elements) def outer_signature(self) -> tys.FunctionType: - return tys.FunctionType(input=self.types, output=[tys.Tuple(*self.types)]) + return tys.FunctionType(input=self._types(), output=[tys.Tuple(*self._types())]) - def _set_in_types(self, types: tys.TypeRow) -> None: + def set_in_types(self, types: tys.TypeRow) -> None: self.types = types @@ -152,9 +165,14 @@ def _set_in_types(self, types: tys.TypeRow) -> None: @dataclass() -class UnpackTupleDef(DataflowOp): +class UnpackTupleDef(PartialOp): types: tys.TypeRow = field(default_factory=list) + def _types(self) -> tys.TypeRow: + if self.types is None: + raise IncompleteOp() + return self.types + @property def num_out(self) -> int | None: return len(self.types) @@ -171,7 +189,7 @@ def __call__(self, tuple_: Wire) -> Command: def outer_signature(self) -> tys.FunctionType: return MakeTupleDef(self.types).outer_signature().flip() - def _set_in_types(self, types: tys.TypeRow) -> None: + def set_in_types(self, types: tys.TypeRow) -> None: (t,) = types assert isinstance(t, tys.Sum), f"Expected unary Sum, got {t}" (row,) = t.variant_rows @@ -203,8 +221,7 @@ def outer_signature(self) -> tys.FunctionType: class DfParentOp(Op, Protocol): def inner_signature(self) -> tys.FunctionType: ... - def _set_out_types(self, types: tys.TypeRow) -> None: - return + def _set_out_types(self, types: tys.TypeRow) -> None: ... @dataclass() From 22f58dc437906a05b039ba80c865338a4b3ddc24 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Tue, 18 Jun 2024 15:06:29 +0100 Subject: [PATCH 08/17] differentiate between empty fields and unset fields for incomplete ops --- hugr-py/src/hugr/_cfg.py | 14 ++-- hugr-py/src/hugr/_dfg.py | 17 ++--- hugr-py/src/hugr/_ops.py | 94 +++++++++++++++++++-------- hugr-py/src/hugr/serialization/ops.py | 2 +- hugr-py/tests/test_cfg.py | 2 +- hugr-py/tests/test_hugr_build.py | 4 +- 6 files changed, 84 insertions(+), 49 deletions(-) diff --git a/hugr-py/src/hugr/_cfg.py b/hugr-py/src/hugr/_cfg.py index 0141c4419..ef1bfcbc5 100644 --- a/hugr-py/src/hugr/_cfg.py +++ b/hugr-py/src/hugr/_cfg.py @@ -54,11 +54,9 @@ def _init_impl(self: Cfg, hugr: Hugr, root: Node, input_types: TypeRow) -> None: self.hugr = hugr self.parent_node = root # to ensure entry is first child, add a dummy entry at the start - self._entry_block = Block.new_nested( - ops.DataflowBlock(input_types, []), hugr, root - ) + self._entry_block = Block.new_nested(ops.DataflowBlock(input_types), hugr, root) - self.exit = self.hugr.add_node(ops.ExitBlock([]), self.parent_node) + self.exit = self.hugr.add_node(ops.ExitBlock(), self.parent_node) @classmethod def new_nested( @@ -90,7 +88,7 @@ def add_entry(self) -> Block: def add_block(self, input_types: TypeRow) -> Block: new_block = Block.new_nested( - ops.DataflowBlock(input_types, [], []), + ops.DataflowBlock(input_types), self.hugr, self.parent_node, ) @@ -103,11 +101,11 @@ def branch(self, src: Wire, dst: ToNode) -> None: if dst == self.exit: src_block = self.hugr._get_typed_op(src.node, ops.DataflowBlock) out_types = [*src_block.sum_rows[src.offset], *src_block.other_outputs] - if self._exit_op().cfg_outputs: - if self._exit_op().cfg_outputs != out_types: + if self._exit_op()._cfg_outputs is not None: + if self._exit_op()._cfg_outputs != out_types: raise MismatchedExit(src.node.idx) else: - self._exit_op().cfg_outputs = out_types + self._exit_op()._cfg_outputs = out_types self.parent_op().signature = replace( self.parent_op().signature, output=out_types ) diff --git a/hugr-py/src/hugr/_dfg.py b/hugr-py/src/hugr/_dfg.py index 54fd20888..d9fe6177b 100644 --- a/hugr-py/src/hugr/_dfg.py +++ b/hugr-py/src/hugr/_dfg.py @@ -9,7 +9,7 @@ from typing_extensions import Self import hugr._ops as ops -from hugr._tys import FunctionType, TypeRow +from hugr._tys import TypeRow from ._exceptions import NoSiblingAncestor from ._hugr import ToNode @@ -35,14 +35,12 @@ def __init__(self, root_op: DP) -> None: self._init_io_nodes(root_op) def _init_io_nodes(self, root_op: DP): - inner_sig = root_op.inner_signature() + inputs = root_op._inputs() self.input_node = self.hugr.add_node( - ops.Input(inner_sig.input), self.parent_node, len(inner_sig.input) - ) - self.output_node = self.hugr.add_node( - ops.Output(inner_sig.output), self.parent_node + ops.Input(inputs), self.parent_node, len(inputs) ) + self.output_node = self.hugr.add_node(ops.Output(), self.parent_node) @classmethod def new_nested(cls, root_op: DP, hugr: Hugr, parent: ToNode | None = None) -> Self: @@ -84,7 +82,7 @@ def add_nested( input_types = [self._get_dataflow_type(w) for w in args] - root_op = ops.DFG(FunctionType(input=list(input_types), output=[])) + root_op = ops.DFG(list(input_types)) dfg = Dfg.new_nested(root_op, self.hugr, self.parent_node) self._wire_up(dfg.parent_node, args) return dfg @@ -92,7 +90,6 @@ def add_nested( def add_cfg( self, input_types: TypeRow, - output_types: TypeRow, *args: Wire, ) -> Cfg: from ._cfg import Cfg @@ -108,7 +105,7 @@ def insert_cfg(self, cfg: Cfg, *args: Wire) -> Node: def set_outputs(self, *args: Wire) -> None: self._wire_up(self.output_node, args) - self.parent_op()._set_out_types(self._output_op()._types()) + self.parent_op()._set_out_types(self._output_op().types) def add_state_order(self, src: Node, dst: Node) -> None: # adds edge to the right of all existing edges @@ -139,7 +136,7 @@ def _wire_up_port(self, node: Node, offset: int, p: Wire) -> Type: class Dfg(_DfBase[ops.DFG]): def __init__(self, *input_types: Type) -> None: - root_op = ops.DFG(FunctionType(input=list(input_types), output=[])) + root_op = ops.DFG(list(input_types)) super().__init__(root_op) diff --git a/hugr-py/src/hugr/_ops.py b/hugr-py/src/hugr/_ops.py index 268394223..380a795ce 100644 --- a/hugr-py/src/hugr/_ops.py +++ b/hugr-py/src/hugr/_ops.py @@ -1,6 +1,6 @@ from __future__ import annotations -from dataclasses import dataclass, field, replace +from dataclasses import dataclass, field from typing import Generic, Protocol, TypeVar, TYPE_CHECKING, runtime_checkable from hugr.serialization.ops import BaseOp import hugr.serialization.ops as sops @@ -92,21 +92,22 @@ def __call__(self) -> Command: @dataclass() class Output(PartialOp): - types: tys.TypeRow | None = None + _types: tys.TypeRow | None = None - def _types(self) -> tys.TypeRow: - if self.types is None: + @property + def types(self) -> tys.TypeRow: + if self._types is None: raise IncompleteOp() - return self.types + return self._types def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.Output: - return sops.Output(parent=parent.idx, types=ser_it(self._types())) + return sops.Output(parent=parent.idx, types=ser_it(self.types)) def outer_signature(self) -> tys.FunctionType: - return tys.FunctionType(input=self._types(), output=[]) + return tys.FunctionType(input=self.types, output=[]) def set_in_types(self, types: tys.TypeRow) -> None: - self.types = types + self._types = types @dataclass() @@ -137,28 +138,29 @@ def outer_signature(self) -> tys.FunctionType: @dataclass() class MakeTupleDef(PartialOp): - types: tys.TypeRow | None = None + _types: tys.TypeRow | None = None num_out: int | None = 1 - def _types(self) -> tys.TypeRow: - if self.types is None: + @property + def types(self) -> tys.TypeRow: + if self._types is None: raise IncompleteOp() - return self.types + return self._types def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.MakeTuple: return sops.MakeTuple( parent=parent.idx, - tys=ser_it(self._types()), + tys=ser_it(self.types), ) def __call__(self, *elements: Wire) -> Command: return super().__call__(*elements) def outer_signature(self) -> tys.FunctionType: - return tys.FunctionType(input=self._types(), output=[tys.Tuple(*self._types())]) + return tys.FunctionType(input=self.types, output=[tys.Tuple(*self.types)]) def set_in_types(self, types: tys.TypeRow) -> None: - self.types = types + self._types = types MakeTuple = MakeTupleDef() @@ -166,12 +168,13 @@ def set_in_types(self, types: tys.TypeRow) -> None: @dataclass() class UnpackTupleDef(PartialOp): - types: tys.TypeRow = field(default_factory=list) + _types: tys.TypeRow = field(default_factory=list) - def _types(self) -> tys.TypeRow: - if self.types is None: + @property + def types(self) -> tys.TypeRow: + if self._types is None: raise IncompleteOp() - return self.types + return self._types @property def num_out(self) -> int | None: @@ -193,7 +196,7 @@ def set_in_types(self, types: tys.TypeRow) -> None: (t,) = types assert isinstance(t, tys.Sum), f"Expected unary Sum, got {t}" (row,) = t.variant_rows - self.types = row + self._types = row UnpackTuple = UnpackTupleDef() @@ -223,10 +226,18 @@ def inner_signature(self) -> tys.FunctionType: ... def _set_out_types(self, types: tys.TypeRow) -> None: ... + def _inputs(self) -> tys.TypeRow: ... + @dataclass() class DFG(DfParentOp, DataflowOp): - signature: tys.FunctionType = field(default_factory=tys.FunctionType.empty) + _signature: tys.TypeRow | tys.FunctionType + + @property + def signature(self) -> tys.FunctionType: + if isinstance(self._signature, tys.FunctionType): + return self._signature + raise IncompleteOp() @property def num_out(self) -> int | None: @@ -245,7 +256,15 @@ def outer_signature(self) -> tys.FunctionType: return self.signature def _set_out_types(self, types: tys.TypeRow) -> None: - self.signature = replace(self.signature, output=types) + assert isinstance(self._signature, list), "Signature has already been set." + self._signature = tys.FunctionType(self._signature, types) + + def _inputs(self) -> tys.TypeRow: + match self._signature: + case tys.FunctionType(input, _): + return input + case list(_): + return self._signature @dataclass() @@ -269,10 +288,22 @@ def outer_signature(self) -> tys.FunctionType: @dataclass class DataflowBlock(DfParentOp): inputs: tys.TypeRow - sum_rows: list[tys.TypeRow] - other_outputs: tys.TypeRow = field(default_factory=list) + _sum_rows: list[tys.TypeRow] | None = None + _other_outputs: tys.TypeRow | None = None extension_delta: tys.ExtensionSet = field(default_factory=list) + @property + def sum_rows(self) -> list[tys.TypeRow]: + if self._sum_rows is None: + raise IncompleteOp() + return self._sum_rows + + @property + def other_outputs(self) -> tys.TypeRow: + if self._other_outputs is None: + raise IncompleteOp() + return self._other_outputs + @property def num_out(self) -> int | None: return len(self.sum_rows) @@ -297,15 +328,24 @@ def port_kind(self, port: _Port) -> tys.Kind: def _set_out_types(self, types: tys.TypeRow) -> None: (sum_, *other) = types assert isinstance(sum_, tys.Sum), f"Expected Sum, got {sum_}" - self.sum_rows = sum_.variant_rows - self.other_outputs = other + self._sum_rows = sum_.variant_rows + self._other_outputs = other + + def _inputs(self) -> tys.TypeRow: + return self.inputs @dataclass class ExitBlock(Op): - cfg_outputs: tys.TypeRow + _cfg_outputs: tys.TypeRow | None = None num_out: int | None = 0 + @property + def cfg_outputs(self) -> tys.TypeRow: + if self._cfg_outputs is None: + raise IncompleteOp() + return self._cfg_outputs + def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.ExitBlock: return sops.ExitBlock( parent=parent.idx, diff --git a/hugr-py/src/hugr/serialization/ops.py b/hugr-py/src/hugr/serialization/ops.py index 389a7c2de..049aa1d9f 100644 --- a/hugr-py/src/hugr/serialization/ops.py +++ b/hugr-py/src/hugr/serialization/ops.py @@ -229,7 +229,7 @@ def insert_port_types(self, in_types: TypeRow, out_types: TypeRow) -> None: self.types = list(in_types) def deserialize(self) -> _ops.Output: - return _ops.Output(types=deser_it(self.types)) + return _ops.Output(deser_it(self.types)) class Call(DataflowOp): diff --git a/hugr-py/tests/test_cfg.py b/hugr-py/tests/test_cfg.py index d4378903b..e2c28038a 100644 --- a/hugr-py/tests/test_cfg.py +++ b/hugr-py/tests/test_cfg.py @@ -42,7 +42,7 @@ def test_branch() -> None: def test_nested_cfg() -> None: dfg = Dfg(tys.Unit, tys.Bool) - cfg = dfg.add_cfg([tys.Unit, tys.Bool], [tys.Bool], *dfg.inputs()) + cfg = dfg.add_cfg([tys.Unit, tys.Bool], *dfg.inputs()) build_basic_cfg(cfg) dfg.set_outputs(cfg) diff --git a/hugr-py/tests/test_hugr_build.py b/hugr-py/tests/test_hugr_build.py index 1f0e5acd5..ff7a2759b 100644 --- a/hugr-py/tests/test_hugr_build.py +++ b/hugr-py/tests/test_hugr_build.py @@ -85,7 +85,7 @@ def _validate(h: Hugr, mermaid: bool = False, roundtrip: bool = True): def test_stable_indices(): - h = Hugr(ops.DFG()) + h = Hugr(ops.DFG([])) nodes = [h.add_node(Not) for _ in range(3)] assert len(h) == 4 @@ -194,7 +194,7 @@ def test_insert(): assert len(h1.hugr) == 4 - new_h = Hugr(ops.DFG()) + new_h = Hugr(ops.DFG([])) mapping = h1.hugr.insert_hugr(new_h, h1.hugr.root) assert mapping == {new_h.root: Node(4)} From a9a4ab998d7b67fe1c138b66480ffae5eb1c6a7e Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Tue, 18 Jun 2024 15:17:37 +0100 Subject: [PATCH 09/17] make more properties PartialOp needs to be independent of DataflowOp, because when doing isinstance checks on the protocl python calls every property. Some of these properties now raise exceptions. --- hugr-py/src/hugr/_cfg.py | 12 +++++++----- hugr-py/src/hugr/_dfg.py | 2 +- hugr-py/src/hugr/_hugr.py | 1 + hugr-py/src/hugr/_ops.py | 8 ++++---- 4 files changed, 13 insertions(+), 10 deletions(-) diff --git a/hugr-py/src/hugr/_cfg.py b/hugr-py/src/hugr/_cfg.py index ef1bfcbc5..7a2d3e6d4 100644 --- a/hugr-py/src/hugr/_cfg.py +++ b/hugr-py/src/hugr/_cfg.py @@ -77,9 +77,11 @@ def new_nested( def entry(self) -> Node: return self._entry_block.parent_node + @property def _entry_op(self) -> ops.DataflowBlock: return self.hugr._get_typed_op(self.entry, ops.DataflowBlock) + @property def _exit_op(self) -> ops.ExitBlock: return self.hugr._get_typed_op(self.exit, ops.ExitBlock) @@ -101,11 +103,11 @@ def branch(self, src: Wire, dst: ToNode) -> None: if dst == self.exit: src_block = self.hugr._get_typed_op(src.node, ops.DataflowBlock) out_types = [*src_block.sum_rows[src.offset], *src_block.other_outputs] - if self._exit_op()._cfg_outputs is not None: - if self._exit_op()._cfg_outputs != out_types: + if self._exit_op._cfg_outputs is not None: + if self._exit_op._cfg_outputs != out_types: raise MismatchedExit(src.node.idx) else: - self._exit_op()._cfg_outputs = out_types - self.parent_op().signature = replace( - self.parent_op().signature, output=out_types + self._exit_op._cfg_outputs = out_types + self.parent_op.signature = replace( + self.parent_op.signature, output=out_types ) diff --git a/hugr-py/src/hugr/_dfg.py b/hugr-py/src/hugr/_dfg.py index d9fe6177b..d9b828f37 100644 --- a/hugr-py/src/hugr/_dfg.py +++ b/hugr-py/src/hugr/_dfg.py @@ -105,7 +105,7 @@ def insert_cfg(self, cfg: Cfg, *args: Wire) -> Node: def set_outputs(self, *args: Wire) -> None: self._wire_up(self.output_node, args) - self.parent_op()._set_out_types(self._output_op().types) + self.parent_op._set_out_types(self._output_op().types) def add_state_order(self, src: Node, dst: Node) -> None: # adds edge to the right of all existing edges diff --git a/hugr-py/src/hugr/_hugr.py b/hugr-py/src/hugr/_hugr.py index cb1d85f66..1dfede794 100644 --- a/hugr-py/src/hugr/_hugr.py +++ b/hugr-py/src/hugr/_hugr.py @@ -155,6 +155,7 @@ class ParentBuilder(ToNode, Protocol[OpVar]): def to_node(self) -> Node: return self.parent_node + @property def parent_op(self) -> OpVar: return cast(OpVar, self.hugr[self.parent_node].op) diff --git a/hugr-py/src/hugr/_ops.py b/hugr-py/src/hugr/_ops.py index 380a795ce..4b8a4cb2f 100644 --- a/hugr-py/src/hugr/_ops.py +++ b/hugr-py/src/hugr/_ops.py @@ -45,7 +45,7 @@ def __call__(self, *args) -> Command: @runtime_checkable -class PartialOp(DataflowOp, Protocol): +class PartialOp(Protocol): def set_in_types(self, types: tys.TypeRow) -> None: ... @@ -91,7 +91,7 @@ def __call__(self) -> Command: @dataclass() -class Output(PartialOp): +class Output(DataflowOp, PartialOp): _types: tys.TypeRow | None = None @property @@ -137,7 +137,7 @@ def outer_signature(self) -> tys.FunctionType: @dataclass() -class MakeTupleDef(PartialOp): +class MakeTupleDef(DataflowOp, PartialOp): _types: tys.TypeRow | None = None num_out: int | None = 1 @@ -167,7 +167,7 @@ def set_in_types(self, types: tys.TypeRow) -> None: @dataclass() -class UnpackTupleDef(PartialOp): +class UnpackTupleDef(DataflowOp, PartialOp): _types: tys.TypeRow = field(default_factory=list) @property From 73e1fd960343e903f1a51233868dcfdefa9cbc62 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Tue, 18 Jun 2024 15:18:49 +0100 Subject: [PATCH 10/17] rename root_op -> parent_op --- hugr-py/src/hugr/_dfg.py | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/hugr-py/src/hugr/_dfg.py b/hugr-py/src/hugr/_dfg.py index d9b828f37..7e89b91bc 100644 --- a/hugr-py/src/hugr/_dfg.py +++ b/hugr-py/src/hugr/_dfg.py @@ -29,13 +29,13 @@ class _DfBase(ParentBuilder[DP]): input_node: Node output_node: Node - def __init__(self, root_op: DP) -> None: - self.hugr = Hugr(root_op) + def __init__(self, parent_op: DP) -> None: + self.hugr = Hugr(parent_op) self.parent_node = self.hugr.root - self._init_io_nodes(root_op) + self._init_io_nodes(parent_op) - def _init_io_nodes(self, root_op: DP): - inputs = root_op._inputs() + def _init_io_nodes(self, parent_op: DP): + inputs = parent_op._inputs() self.input_node = self.hugr.add_node( ops.Input(inputs), self.parent_node, len(inputs) @@ -43,12 +43,14 @@ def _init_io_nodes(self, root_op: DP): self.output_node = self.hugr.add_node(ops.Output(), self.parent_node) @classmethod - def new_nested(cls, root_op: DP, hugr: Hugr, parent: ToNode | None = None) -> Self: + def new_nested( + cls, parent_op: DP, hugr: Hugr, parent: ToNode | None = None + ) -> Self: new = cls.__new__(cls) new.hugr = hugr - new.parent_node = hugr.add_node(root_op, parent or hugr.root) - new._init_io_nodes(root_op) + new.parent_node = hugr.add_node(parent_op, parent or hugr.root) + new._init_io_nodes(parent_op) return new def _input_op(self) -> ops.Input: @@ -82,8 +84,8 @@ def add_nested( input_types = [self._get_dataflow_type(w) for w in args] - root_op = ops.DFG(list(input_types)) - dfg = Dfg.new_nested(root_op, self.hugr, self.parent_node) + parent_op = ops.DFG(list(input_types)) + dfg = Dfg.new_nested(parent_op, self.hugr, self.parent_node) self._wire_up(dfg.parent_node, args) return dfg @@ -136,8 +138,8 @@ def _wire_up_port(self, node: Node, offset: int, p: Wire) -> Type: class Dfg(_DfBase[ops.DFG]): def __init__(self, *input_types: Type) -> None: - root_op = ops.DFG(list(input_types)) - super().__init__(root_op) + parent_op = ops.DFG(list(input_types)) + super().__init__(parent_op) def _ancestral_sibling(h: Hugr, src: Node, tgt: Node) -> Node | None: From 7d604837fb49e4ac9a0a17da587a5195208b13d7 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Tue, 18 Jun 2024 15:20:35 +0100 Subject: [PATCH 11/17] _Port -> InPort | OutPort --- hugr-py/src/hugr/_ops.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/hugr-py/src/hugr/_ops.py b/hugr-py/src/hugr/_ops.py index 4b8a4cb2f..941f7428e 100644 --- a/hugr-py/src/hugr/_ops.py +++ b/hugr-py/src/hugr/_ops.py @@ -9,7 +9,7 @@ from ._exceptions import IncompleteOp if TYPE_CHECKING: - from hugr._hugr import Hugr, Node, Wire, _Port + from hugr._hugr import Hugr, Node, Wire, InPort, OutPort @runtime_checkable @@ -20,19 +20,19 @@ def num_out(self) -> int | None: def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> BaseOp: ... - def port_kind(self, port: _Port) -> tys.Kind: ... + def port_kind(self, port: InPort | OutPort) -> tys.Kind: ... @runtime_checkable class DataflowOp(Op, Protocol): def outer_signature(self) -> tys.FunctionType: ... - def port_kind(self, port: _Port) -> tys.Kind: + def port_kind(self, port: InPort | OutPort) -> tys.Kind: if port.offset == -1: return tys.OrderKind() return tys.ValueKind(self.port_type(port)) - def port_type(self, port: _Port) -> tys.Type: + def port_type(self, port: InPort | OutPort) -> tys.Type: from hugr._hugr import Direction sig = self.outer_signature() @@ -68,7 +68,7 @@ def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> T: root.parent = parent.idx return root - def port_kind(self, port: _Port) -> tys.Kind: + def port_kind(self, port: InPort | OutPort) -> tys.Kind: raise NotImplementedError @@ -322,7 +322,7 @@ def inner_signature(self) -> tys.FunctionType: input=self.inputs, output=[tys.Sum(self.sum_rows), *self.other_outputs] ) - def port_kind(self, port: _Port) -> tys.Kind: + def port_kind(self, port: InPort | OutPort) -> tys.Kind: return tys.CFKind() def _set_out_types(self, types: tys.TypeRow) -> None: @@ -352,5 +352,5 @@ def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.ExitBlock: cfg_outputs=ser_it(self.cfg_outputs), ) - def port_kind(self, port: _Port) -> tys.Kind: + def port_kind(self, port: InPort | OutPort) -> tys.Kind: return tys.CFKind() From c11fdc53411c9685106673ca97e7b398c364e646 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Tue, 18 Jun 2024 15:21:12 +0100 Subject: [PATCH 12/17] `ops.Tag.__call__` impl --- hugr-py/src/hugr/_ops.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/hugr-py/src/hugr/_ops.py b/hugr-py/src/hugr/_ops.py index 941f7428e..bc818a769 100644 --- a/hugr-py/src/hugr/_ops.py +++ b/hugr-py/src/hugr/_ops.py @@ -220,6 +220,9 @@ def outer_signature(self) -> tys.FunctionType: input=self.variants[self.tag], output=[tys.Sum(self.variants)] ) + def __call__(self, value: Wire) -> Command: + return super().__call__(value) + class DfParentOp(Op, Protocol): def inner_signature(self) -> tys.FunctionType: ... From 06645454f8ed7c9d8b6611e2e4d07ea9c158210a Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Tue, 18 Jun 2024 15:22:40 +0100 Subject: [PATCH 13/17] TypeRow in FunctionType --- hugr-py/src/hugr/_tys.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/hugr-py/src/hugr/_tys.py b/hugr-py/src/hugr/_tys.py index bcf59128c..fa550e446 100644 --- a/hugr-py/src/hugr/_tys.py +++ b/hugr-py/src/hugr/_tys.py @@ -226,8 +226,8 @@ def to_serial(self) -> stys.Alias: @dataclass(frozen=True) class FunctionType(Type): - input: list[Type] - output: list[Type] + input: TypeRow + output: TypeRow extension_reqs: ExtensionSet = field(default_factory=ExtensionSet) def to_serial(self) -> stys.FunctionType: From 85855de47343eca8995e21a647c2d733144beb27 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Tue, 18 Jun 2024 15:36:39 +0100 Subject: [PATCH 14/17] feat: add_successor for much easier directed cfg building --- hugr-py/src/hugr/_cfg.py | 39 ++++++++++++++++++++++++++++----------- hugr-py/src/hugr/_ops.py | 3 +++ hugr-py/tests/test_cfg.py | 32 ++++++++++++++------------------ 3 files changed, 45 insertions(+), 29 deletions(-) diff --git a/hugr-py/src/hugr/_cfg.py b/hugr-py/src/hugr/_cfg.py index 7a2d3e6d4..97847967a 100644 --- a/hugr-py/src/hugr/_cfg.py +++ b/hugr-py/src/hugr/_cfg.py @@ -96,18 +96,35 @@ def add_block(self, input_types: TypeRow) -> Block: ) return new_block + def add_successor(self, pred: Wire) -> Block: + pred = pred.out_port() + block = self.hugr._get_typed_op(pred.node, ops.DataflowBlock) + inputs = block.nth_outputs(pred.offset) + b = self.add_block(inputs) + + self.branch(pred, b) + return b + def branch(self, src: Wire, dst: ToNode) -> None: + # TODO check for existing link/type compatibility + if dst.to_node() == self.exit: + return self.branch_exit(src) src = src.out_port() self.hugr.add_link(src, dst.inp(0)) - if dst == self.exit: - src_block = self.hugr._get_typed_op(src.node, ops.DataflowBlock) - out_types = [*src_block.sum_rows[src.offset], *src_block.other_outputs] - if self._exit_op._cfg_outputs is not None: - if self._exit_op._cfg_outputs != out_types: - raise MismatchedExit(src.node.idx) - else: - self._exit_op._cfg_outputs = out_types - self.parent_op.signature = replace( - self.parent_op.signature, output=out_types - ) + def branch_exit(self, src: Wire) -> None: + src = src.out_port() + self.hugr.add_link(src, self.exit.inp(0)) + + src_block: ops.DataflowBlock = self.hugr._get_typed_op( + src.node, ops.DataflowBlock + ) + out_types = src_block.nth_outputs(src.offset) + if self._exit_op._cfg_outputs is not None: + if self._exit_op._cfg_outputs != out_types: + raise MismatchedExit(src.node.idx) + else: + self._exit_op._cfg_outputs = out_types + self.parent_op.signature = replace( + self.parent_op.signature, output=out_types + ) diff --git a/hugr-py/src/hugr/_ops.py b/hugr-py/src/hugr/_ops.py index bc818a769..303726442 100644 --- a/hugr-py/src/hugr/_ops.py +++ b/hugr-py/src/hugr/_ops.py @@ -337,6 +337,9 @@ def _set_out_types(self, types: tys.TypeRow) -> None: def _inputs(self) -> tys.TypeRow: return self.inputs + def nth_outputs(self, n: int) -> tys.TypeRow: + return [*self.sum_rows[n], *self.other_outputs] + @dataclass class ExitBlock(Op): diff --git a/hugr-py/tests/test_cfg.py b/hugr-py/tests/test_cfg.py index e2c28038a..c8fc5f511 100644 --- a/hugr-py/tests/test_cfg.py +++ b/hugr-py/tests/test_cfg.py @@ -23,18 +23,15 @@ def test_branch() -> None: entry = cfg.add_entry() entry.set_block_outputs(*entry.inputs()) - middle_1 = cfg.add_block([tys.Unit, INT_T]) + middle_1 = cfg.add_successor(entry[0]) middle_1.set_block_outputs(*middle_1.inputs()) - middle_2 = cfg.add_block([tys.Unit, INT_T]) + middle_2 = cfg.add_successor(entry[1]) u, i = middle_2.inputs() n = middle_2.add(DivMod(i, i)) middle_2.set_block_outputs(u, n[0]) - cfg.branch(entry[0], middle_1) - cfg.branch(entry[1], middle_2) - - cfg.branch(middle_1[0], cfg.exit) - cfg.branch(middle_2[0], cfg.exit) + cfg.branch_exit(middle_1[0]) + cfg.branch_exit(middle_2[0]) _validate(cfg.hugr) @@ -58,16 +55,13 @@ def test_dom_edge() -> None: # entry dominates both middles so Unit type can be used as inter-graph # value between basic blocks - middle_1 = cfg.add_block([INT_T]) + middle_1 = cfg.add_successor(entry[0]) middle_1.set_block_outputs(u, *middle_1.inputs()) - middle_2 = cfg.add_block([INT_T]) + middle_2 = cfg.add_successor(entry[1]) middle_2.set_block_outputs(u, *middle_2.inputs()) - cfg.branch(entry[0], middle_1) - cfg.branch(entry[1], middle_2) - - cfg.branch(middle_1[0], cfg.exit) - cfg.branch(middle_2[0], cfg.exit) + cfg.branch_exit(middle_1[0]) + cfg.branch_exit(middle_2[0]) _validate(cfg.hugr) @@ -81,13 +75,15 @@ def test_asymm_types() -> None: tagged_int = entry.add(ops.Tag(0, [[INT_T], [tys.Bool]])(i)) entry.set_block_outputs(tagged_int) - middle = cfg.add_block([INT_T]) + middle = cfg.add_successor(entry[0]) # discard the int and return the bool from entry middle.set_block_outputs(u, b) # middle expects an int and exit expects a bool - cfg.branch(entry[0], middle) - cfg.branch(entry[1], cfg.exit) - cfg.branch(middle[0], cfg.exit) + cfg.branch_exit(entry[1]) + cfg.branch_exit(middle[0]) _validate(cfg.hugr) + + +# TODO loop From 80e68324c71e8d65725e8ae788773decd47002de Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Tue, 18 Jun 2024 15:37:56 +0100 Subject: [PATCH 15/17] fix: incorrect typevar use in _get_typed_op --- hugr-py/src/hugr/_cfg.py | 4 +--- hugr-py/src/hugr/_hugr.py | 3 ++- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/hugr-py/src/hugr/_cfg.py b/hugr-py/src/hugr/_cfg.py index 97847967a..429c93de0 100644 --- a/hugr-py/src/hugr/_cfg.py +++ b/hugr-py/src/hugr/_cfg.py @@ -116,9 +116,7 @@ def branch_exit(self, src: Wire) -> None: src = src.out_port() self.hugr.add_link(src, self.exit.inp(0)) - src_block: ops.DataflowBlock = self.hugr._get_typed_op( - src.node, ops.DataflowBlock - ) + src_block = self.hugr._get_typed_op(src.node, ops.DataflowBlock) out_types = src_block.nth_outputs(src.offset) if self._exit_op._cfg_outputs is not None: if self._exit_op._cfg_outputs != out_types: diff --git a/hugr-py/src/hugr/_hugr.py b/hugr-py/src/hugr/_hugr.py index 1dfede794..a54d35e6a 100644 --- a/hugr-py/src/hugr/_hugr.py +++ b/hugr-py/src/hugr/_hugr.py @@ -133,6 +133,7 @@ def to_serial(self, node: Node, hugr: Hugr) -> SerialOp: P = TypeVar("P", InPort, OutPort) K = TypeVar("K", InPort, OutPort) OpVar = TypeVar("OpVar", bound=Op) +OpVar2 = TypeVar("OpVar2", bound=Op) @dataclass(frozen=True, eq=True, order=True) @@ -189,7 +190,7 @@ def __iter__(self): def __len__(self) -> int: return self.num_nodes() - def _get_typed_op(self, node: ToNode, cl: PyType[OpVar]) -> OpVar: + def _get_typed_op(self, node: ToNode, cl: PyType[OpVar2]) -> OpVar2: op = self[node].op assert isinstance(op, cl) return op From 4c616c040a8eebf08a5dfb555c5521477b670031 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Tue, 18 Jun 2024 15:39:06 +0100 Subject: [PATCH 16/17] copy when flipping function type --- hugr-py/src/hugr/_cfg.py | 13 +++++++------ hugr-py/src/hugr/_tys.py | 2 +- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/hugr-py/src/hugr/_cfg.py b/hugr-py/src/hugr/_cfg.py index 429c93de0..d26eec4bc 100644 --- a/hugr-py/src/hugr/_cfg.py +++ b/hugr-py/src/hugr/_cfg.py @@ -97,14 +97,16 @@ def add_block(self, input_types: TypeRow) -> Block: return new_block def add_successor(self, pred: Wire) -> Block: - pred = pred.out_port() - block = self.hugr._get_typed_op(pred.node, ops.DataflowBlock) - inputs = block.nth_outputs(pred.offset) - b = self.add_block(inputs) + b = self.add_block(self._nth_outputs(pred)) self.branch(pred, b) return b + def _nth_outputs(self, wire: Wire) -> TypeRow: + port = wire.out_port() + block = self.hugr._get_typed_op(port.node, ops.DataflowBlock) + return block.nth_outputs(port.offset) + def branch(self, src: Wire, dst: ToNode) -> None: # TODO check for existing link/type compatibility if dst.to_node() == self.exit: @@ -116,8 +118,7 @@ def branch_exit(self, src: Wire) -> None: src = src.out_port() self.hugr.add_link(src, self.exit.inp(0)) - src_block = self.hugr._get_typed_op(src.node, ops.DataflowBlock) - out_types = src_block.nth_outputs(src.offset) + out_types = self._nth_outputs(src) if self._exit_op._cfg_outputs is not None: if self._exit_op._cfg_outputs != out_types: raise MismatchedExit(src.node.idx) diff --git a/hugr-py/src/hugr/_tys.py b/hugr-py/src/hugr/_tys.py index fa550e446..7b0fa2335 100644 --- a/hugr-py/src/hugr/_tys.py +++ b/hugr-py/src/hugr/_tys.py @@ -238,7 +238,7 @@ def empty(cls) -> FunctionType: return cls(input=[], output=[]) def flip(self) -> FunctionType: - return FunctionType(input=self.output, output=self.input) + return FunctionType(input=list(self.output), output=list(self.input)) @dataclass(frozen=True) From 28dc387db77d97576c90e18a4809a920d07508e4 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Tue, 18 Jun 2024 16:48:44 +0100 Subject: [PATCH 17/17] fix unpack type annotation Co-authored-by: Mark Koch <48097969+mark-koch@users.noreply.github.com> --- hugr-py/src/hugr/_ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hugr-py/src/hugr/_ops.py b/hugr-py/src/hugr/_ops.py index 303726442..026619890 100644 --- a/hugr-py/src/hugr/_ops.py +++ b/hugr-py/src/hugr/_ops.py @@ -168,7 +168,7 @@ def set_in_types(self, types: tys.TypeRow) -> None: @dataclass() class UnpackTupleDef(DataflowOp, PartialOp): - _types: tys.TypeRow = field(default_factory=list) + _types: tys.TypeRow | None = None @property def types(self) -> tys.TypeRow: