diff --git a/hugr-py/src/hugr/_cfg.py b/hugr-py/src/hugr/_cfg.py index 2001b53a5..d26eec4bc 100644 --- a/hugr-py/src/hugr/_cfg.py +++ b/hugr-py/src/hugr/_cfg.py @@ -1,14 +1,13 @@ from __future__ import annotations -from dataclasses import dataclass -from typing import Iterable, Sequence +from dataclasses import dataclass, replace import hugr._ops as ops from ._dfg import _DfBase -from ._exceptions import NoSiblingAncestor, NotInSameCfg +from ._exceptions import NoSiblingAncestor, NotInSameCfg, MismatchedExit from ._hugr import Hugr, Node, ParentBuilder, ToNode, Wire -from ._tys import FunctionType, Sum, TypeRow +from ._tys import FunctionType, TypeRow, Type class Block(_DfBase[ops.DataflowBlock]): @@ -19,99 +18,112 @@ def set_single_successor_outputs(self, *outputs: Wire) -> None: # TODO requires constants raise NotImplementedError - def _wire_up(self, node: Node, ports: Iterable[Wire]): - for i, p in enumerate(ports): - src = p.out_port() - cfg_node = self.hugr[self.root].parent - assert cfg_node is not None - src_parent = self.hugr[src.node].parent - try: - self._wire_up_port(node, i, p) - except NoSiblingAncestor: - # note this just checks if there is a common CFG ancestor - # it does not check for valid dominance between basic blocks - # that is deferred to full HUGR validation. - while cfg_node != src_parent: - if src_parent is None or src_parent == self.hugr.root: - raise NotInSameCfg(src.node.idx, node.idx) - src_parent = self.hugr[src_parent].parent - - self.hugr.add_link(src, node.inp(i)) + def _wire_up_port(self, node: Node, offset: int, p: Wire) -> Type: + src = p.out_port() + cfg_node = self.hugr[self.parent_node].parent + assert cfg_node is not None + src_parent = self.hugr[src.node].parent + try: + super()._wire_up_port(node, offset, p) + except NoSiblingAncestor: + # note this just checks if there is a common CFG ancestor + # it does not check for valid dominance between basic blocks + # that is deferred to full HUGR validation. + while cfg_node != src_parent: + if src_parent is None or src_parent == self.hugr.root: + raise NotInSameCfg(src.node.idx, node.idx) + src_parent = self.hugr[src_parent].parent + + self.hugr.add_link(src, node.inp(offset)) + return self._get_dataflow_type(src) @dataclass -class Cfg(ParentBuilder): +class Cfg(ParentBuilder[ops.CFG]): hugr: Hugr - root: Node + parent_node: Node _entry_block: Block exit: Node - def __init__(self, input_types: TypeRow, output_types: TypeRow) -> None: - root_op = ops.CFG(FunctionType(input=input_types, output=output_types)) + def __init__(self, input_types: TypeRow) -> None: + root_op = ops.CFG(FunctionType(input=input_types, output=[])) hugr = Hugr(root_op) - self._init_impl(hugr, hugr.root, input_types, output_types) + self._init_impl(hugr, hugr.root, input_types) - def _init_impl( - self: Cfg, hugr: Hugr, root: Node, input_types: TypeRow, output_types: TypeRow - ) -> None: + def _init_impl(self: Cfg, hugr: Hugr, root: Node, input_types: TypeRow) -> None: self.hugr = hugr - self.root = root + self.parent_node = root # to ensure entry is first child, add a dummy entry at the start - self._entry_block = Block.new_nested( - ops.DataflowBlock(input_types, []), hugr, root - ) + self._entry_block = Block.new_nested(ops.DataflowBlock(input_types), hugr, root) - self.exit = self.hugr.add_node(ops.ExitBlock(output_types), self.root) + self.exit = self.hugr.add_node(ops.ExitBlock(), self.parent_node) @classmethod def new_nested( cls, input_types: TypeRow, - output_types: TypeRow, hugr: Hugr, parent: ToNode | None = None, ) -> Cfg: new = cls.__new__(cls) root = hugr.add_node( - ops.CFG(FunctionType(input=input_types, output=output_types)), + ops.CFG(FunctionType(input=input_types, output=[])), parent or hugr.root, ) - new._init_impl(hugr, root, input_types, output_types) + new._init_impl(hugr, root, input_types) return new @property def entry(self) -> Node: - return self._entry_block.root + return self._entry_block.parent_node + @property def _entry_op(self) -> ops.DataflowBlock: - dop = self.hugr[self.entry].op - assert isinstance(dop, ops.DataflowBlock) - return dop - - def add_entry(self, sum_rows: Sequence[TypeRow], other_outputs: TypeRow) -> Block: - # update entry block types - self._entry_op().sum_rows = list(sum_rows) - self._entry_op().other_outputs = other_outputs - self._entry_block._output_op().types = [Sum(list(sum_rows)), *other_outputs] - return self._entry_block + return self.hugr._get_typed_op(self.entry, ops.DataflowBlock) + + @property + def _exit_op(self) -> ops.ExitBlock: + return self.hugr._get_typed_op(self.exit, ops.ExitBlock) - def simple_entry(self, n_branches: int, other_outputs: TypeRow) -> Block: - return self.add_entry([[]] * n_branches, other_outputs) + def add_entry(self) -> Block: + return self._entry_block - def add_block( - self, input_types: TypeRow, sum_rows: Sequence[TypeRow], other_outputs: TypeRow - ) -> Block: + def add_block(self, input_types: TypeRow) -> Block: new_block = Block.new_nested( - ops.DataflowBlock(input_types, list(sum_rows), other_outputs), + ops.DataflowBlock(input_types), self.hugr, - self.root, + self.parent_node, ) return new_block - def simple_block( - self, input_types: TypeRow, n_branches: int, other_outputs: TypeRow - ) -> Block: - return self.add_block(input_types, [[]] * n_branches, other_outputs) + def add_successor(self, pred: Wire) -> Block: + b = self.add_block(self._nth_outputs(pred)) + + self.branch(pred, b) + return b + + def _nth_outputs(self, wire: Wire) -> TypeRow: + port = wire.out_port() + block = self.hugr._get_typed_op(port.node, ops.DataflowBlock) + return block.nth_outputs(port.offset) def branch(self, src: Wire, dst: ToNode) -> None: - self.hugr.add_link(src.out_port(), dst.inp(0)) + # TODO check for existing link/type compatibility + if dst.to_node() == self.exit: + return self.branch_exit(src) + src = src.out_port() + self.hugr.add_link(src, dst.inp(0)) + + def branch_exit(self, src: Wire) -> None: + src = src.out_port() + self.hugr.add_link(src, self.exit.inp(0)) + + out_types = self._nth_outputs(src) + if self._exit_op._cfg_outputs is not None: + if self._exit_op._cfg_outputs != out_types: + raise MismatchedExit(src.node.idx) + else: + self._exit_op._cfg_outputs = out_types + self.parent_op.signature = replace( + self.parent_op.signature, output=out_types + ) diff --git a/hugr-py/src/hugr/_dfg.py b/hugr-py/src/hugr/_dfg.py index 80b6ad550..7e89b91bc 100644 --- a/hugr-py/src/hugr/_dfg.py +++ b/hugr-py/src/hugr/_dfg.py @@ -1,13 +1,19 @@ from __future__ import annotations +from dataclasses import dataclass, replace +from typing import ( + Iterable, + TYPE_CHECKING, + TypeVar, +) +from ._hugr import Hugr, Node, Wire, OutPort, ParentBuilder -from dataclasses import dataclass -from typing import TYPE_CHECKING, Generic, Iterable, TypeVar, cast from typing_extensions import Self import hugr._ops as ops -from hugr._tys import FunctionType, TypeRow +from hugr._tys import TypeRow from ._exceptions import NoSiblingAncestor -from ._hugr import Hugr, Node, OutPort, ParentBuilder, Wire, ToNode +from ._hugr import ToNode +from hugr._tys import Type if TYPE_CHECKING: from ._cfg import Cfg @@ -17,105 +23,109 @@ @dataclass() -class _DfBase(ParentBuilder, Generic[DP]): +class _DfBase(ParentBuilder[DP]): hugr: Hugr - root: Node + parent_node: Node input_node: Node output_node: Node - def __init__(self, root_op: DP) -> None: - self.hugr = Hugr(root_op) - self.root = self.hugr.root - self._init_io_nodes(root_op) + def __init__(self, parent_op: DP) -> None: + self.hugr = Hugr(parent_op) + self.parent_node = self.hugr.root + self._init_io_nodes(parent_op) + + def _init_io_nodes(self, parent_op: DP): + inputs = parent_op._inputs() - def _init_io_nodes(self, root_op: DP): - input_types = root_op.input_types() - output_types = root_op.output_types() self.input_node = self.hugr.add_node( - ops.Input(input_types), self.root, len(input_types) + ops.Input(inputs), self.parent_node, len(inputs) ) - self.output_node = self.hugr.add_node(ops.Output(output_types), self.root) + self.output_node = self.hugr.add_node(ops.Output(), self.parent_node) @classmethod - def new_nested(cls, root_op: DP, hugr: Hugr, parent: ToNode | None = None) -> Self: + def new_nested( + cls, parent_op: DP, hugr: Hugr, parent: ToNode | None = None + ) -> Self: new = cls.__new__(cls) new.hugr = hugr - new.root = hugr.add_node(root_op, parent or hugr.root) - new._init_io_nodes(root_op) + new.parent_node = hugr.add_node(parent_op, parent or hugr.root) + new._init_io_nodes(parent_op) return new def _input_op(self) -> ops.Input: - dop = self.hugr[self.input_node].op - assert isinstance(dop, ops.Input) - return dop + return self.hugr._get_typed_op(self.input_node, ops.Input) def _output_op(self) -> ops.Output: - dop = self.hugr[self.output_node].op - assert isinstance(dop, ops.Output) - return dop - - def root_op(self) -> DP: - return cast(DP, self.hugr[self.root].op) + return self.hugr._get_typed_op(self.output_node, ops.Output) def inputs(self) -> list[OutPort]: return [self.input_node.out(i) for i in range(len(self._input_op().types))] - def add_op(self, op: ops.Op, /, *args: Wire, num_outs: int | None = None) -> Node: - new_n = self.hugr.add_node(op, self.root, num_outs=num_outs) + def add_op(self, op: ops.DataflowOp, /, *args: Wire) -> Node: + new_n = self.hugr.add_node(op, self.parent_node) self._wire_up(new_n, args) - return new_n + + return replace(new_n, _num_out_ports=op.num_out) def add(self, com: ops.Command) -> Node: - return self.add_op(com.op, *com.incoming, num_outs=com.op.num_out) + return self.add_op(com.op, *com.incoming) def insert_nested(self, dfg: Dfg, *args: Wire) -> Node: - mapping = self.hugr.insert_hugr(dfg.hugr, self.root) - self._wire_up(mapping[dfg.root], args) - return mapping[dfg.root] + mapping = self.hugr.insert_hugr(dfg.hugr, self.parent_node) + self._wire_up(mapping[dfg.parent_node], args) + return mapping[dfg.parent_node] def add_nested( self, - input_types: TypeRow, - output_types: TypeRow, *args: Wire, ) -> Dfg: from ._dfg import Dfg - root_op = ops.DFG(FunctionType(input=input_types, output=output_types)) - dfg = Dfg.new_nested(root_op, self.hugr, self.root) - self._wire_up(dfg.root, args) + input_types = [self._get_dataflow_type(w) for w in args] + + parent_op = ops.DFG(list(input_types)) + dfg = Dfg.new_nested(parent_op, self.hugr, self.parent_node) + self._wire_up(dfg.parent_node, args) return dfg def add_cfg( self, input_types: TypeRow, - output_types: TypeRow, *args: Wire, ) -> Cfg: from ._cfg import Cfg - cfg = Cfg.new_nested(input_types, output_types, self.hugr, self.root) - self._wire_up(cfg.root, args) + cfg = Cfg.new_nested(input_types, self.hugr, self.parent_node) + self._wire_up(cfg.parent_node, args) return cfg def insert_cfg(self, cfg: Cfg, *args: Wire) -> Node: - mapping = self.hugr.insert_hugr(cfg.hugr, self.root) - self._wire_up(mapping[cfg.root], args) - return mapping[cfg.root] + mapping = self.hugr.insert_hugr(cfg.hugr, self.parent_node) + self._wire_up(mapping[cfg.parent_node], args) + return mapping[cfg.parent_node] def set_outputs(self, *args: Wire) -> None: self._wire_up(self.output_node, args) + self.parent_op._set_out_types(self._output_op().types) def add_state_order(self, src: Node, dst: Node) -> None: # adds edge to the right of all existing edges self.hugr.add_link(src.out(-1), dst.inp(-1)) def _wire_up(self, node: Node, ports: Iterable[Wire]): - for i, p in enumerate(ports): - self._wire_up_port(node, i, p) - - def _wire_up_port(self, node: Node, offset: int, p: Wire): + tys = [self._wire_up_port(node, i, p) for i, p in enumerate(ports)] + if isinstance(op := self.hugr[node].op, ops.PartialOp): + op.set_in_types(tys) + + def _get_dataflow_type(self, wire: Wire) -> Type: + port = wire.out_port() + ty = self.hugr.port_type(port) + if ty is None: + raise ValueError(f"Port {port} is not a dataflow port.") + return ty + + def _wire_up_port(self, node: Node, offset: int, p: Wire) -> Type: src = p.out_port() node_ancestor = _ancestral_sibling(self.hugr, src.node, node) if node_ancestor is None: @@ -123,16 +133,13 @@ def _wire_up_port(self, node: Node, offset: int, p: Wire): if node_ancestor != node: self.add_state_order(src.node, node_ancestor) self.hugr.add_link(src, node.inp(offset)) + return self._get_dataflow_type(src) class Dfg(_DfBase[ops.DFG]): - def __init__(self, input_types: TypeRow, output_types: TypeRow) -> None: - root_op = ops.DFG(FunctionType(input=input_types, output=output_types)) - super().__init__(root_op) - - @classmethod - def endo(cls, types: TypeRow) -> Dfg: - return cls(types, types) + def __init__(self, *input_types: Type) -> None: + parent_op = ops.DFG(list(input_types)) + super().__init__(parent_op) def _ancestral_sibling(h: Hugr, src: Node, tgt: Node) -> Node | None: diff --git a/hugr-py/src/hugr/_exceptions.py b/hugr-py/src/hugr/_exceptions.py index 92ba7ceb0..f003fc6c7 100644 --- a/hugr-py/src/hugr/_exceptions.py +++ b/hugr-py/src/hugr/_exceptions.py @@ -21,5 +21,21 @@ def msg(self): return f"Source {self.src} is not in the same CFG as target {self.tgt}, so cannot wire up." +@dataclass +class MismatchedExit(Exception): + src: int + + @property + def msg(self): + return ( + f"Exit branch from node {self.src} does not match existing exit block type." + ) + + class ParentBeforeChild(Exception): msg: str = "Parent node must be added before child node." + + +@dataclass +class IncompleteOp(Exception): + msg: str = "Operation is incomplete, may require set_in_types to be called." diff --git a/hugr-py/src/hugr/_hugr.py b/hugr-py/src/hugr/_hugr.py index 416daac4f..a54d35e6a 100644 --- a/hugr-py/src/hugr/_hugr.py +++ b/hugr-py/src/hugr/_hugr.py @@ -12,11 +12,13 @@ TypeVar, cast, overload, + Type as PyType, ) from typing_extensions import Self -from hugr._ops import Op +from hugr._ops import Op, DataflowOp +from hugr._tys import Type, Kind from hugr.serialization.ops import OpType as SerialOp from hugr.serialization.serial_hugr import SerialHugr from hugr.utils import BiMap @@ -130,6 +132,8 @@ def to_serial(self, node: Node, hugr: Hugr) -> SerialOp: P = TypeVar("P", InPort, OutPort) K = TypeVar("K", InPort, OutPort) +OpVar = TypeVar("OpVar", bound=Op) +OpVar2 = TypeVar("OpVar2", bound=Op) @dataclass(frozen=True, eq=True, order=True) @@ -145,22 +149,26 @@ def next_sub_offset(self) -> Self: _SI = _SubPort[InPort] -class ParentBuilder(ToNode, Protocol): - hugr: Hugr - root: Node +class ParentBuilder(ToNode, Protocol[OpVar]): + hugr: Hugr[OpVar] + parent_node: Node def to_node(self) -> Node: - return self.root + return self.parent_node + + @property + def parent_op(self) -> OpVar: + return cast(OpVar, self.hugr[self.parent_node].op) @dataclass() -class Hugr(Mapping[Node, NodeData]): +class Hugr(Mapping[Node, NodeData], Generic[OpVar]): root: Node _nodes: list[NodeData | None] _links: BiMap[_SO, _SI] _free_nodes: list[Node] - def __init__(self, root_op: Op) -> None: + def __init__(self, root_op: OpVar) -> None: self._free_nodes = [] self._links = BiMap() self._nodes = [] @@ -182,6 +190,11 @@ def __iter__(self): def __len__(self) -> int: return self.num_nodes() + def _get_typed_op(self, node: ToNode, cl: PyType[OpVar2]) -> OpVar2: + op = self[node].op + assert isinstance(op, cl) + return op + def children(self, node: ToNode | None = None) -> list[Node]: node = node or self.root return self[node].children @@ -261,6 +274,9 @@ def delete_link(self, src: OutPort, dst: InPort) -> None: return # TODO make sure sub-offset is handled correctly + def root_op(self) -> OpVar: + return cast(OpVar, self[self.root].op) + def num_nodes(self) -> int: return len(self._nodes) - len(self._free_nodes) @@ -333,6 +349,15 @@ def num_outgoing(self, node: ToNode) -> int: # TODO: num_links and _linked_ports + def port_kind(self, port: InPort | OutPort) -> Kind: + return self[port.node].op.port_kind(port) + + def port_type(self, port: InPort | OutPort) -> Type | None: + op = self[port.node].op + if isinstance(op, DataflowOp): + return op.port_type(port) + return None + def insert_hugr(self, hugr: Hugr, parent: ToNode | None = None) -> dict[Node, Node]: mapping: dict[Node, Node] = {} diff --git a/hugr-py/src/hugr/_ops.py b/hugr-py/src/hugr/_ops.py index 6b6986081..026619890 100644 --- a/hugr-py/src/hugr/_ops.py +++ b/hugr-py/src/hugr/_ops.py @@ -1,16 +1,18 @@ from __future__ import annotations from dataclasses import dataclass, field -from typing import Generic, Protocol, TypeVar, TYPE_CHECKING +from typing import Generic, Protocol, TypeVar, TYPE_CHECKING, runtime_checkable from hugr.serialization.ops import BaseOp import hugr.serialization.ops as sops from hugr.utils import ser_it import hugr._tys as tys +from ._exceptions import IncompleteOp if TYPE_CHECKING: - from hugr._hugr import Hugr, Node, Wire + from hugr._hugr import Hugr, Node, Wire, InPort, OutPort +@runtime_checkable class Op(Protocol): @property def num_out(self) -> int | None: @@ -18,13 +20,38 @@ def num_out(self) -> int | None: def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> BaseOp: ... + def port_kind(self, port: InPort | OutPort) -> tys.Kind: ... + + +@runtime_checkable +class DataflowOp(Op, Protocol): + def outer_signature(self) -> tys.FunctionType: ... + + def port_kind(self, port: InPort | OutPort) -> tys.Kind: + if port.offset == -1: + return tys.OrderKind() + return tys.ValueKind(self.port_type(port)) + + def port_type(self, port: InPort | OutPort) -> tys.Type: + from hugr._hugr import Direction + + sig = self.outer_signature() + if port.direction == Direction.INCOMING: + return sig.input[port.offset] + return sig.output[port.offset] + def __call__(self, *args) -> Command: return Command(self, list(args)) +@runtime_checkable +class PartialOp(Protocol): + def set_in_types(self, types: tys.TypeRow) -> None: ... + + @dataclass(frozen=True) class Command: - op: Op + op: DataflowOp incoming: list[Wire] @@ -41,9 +68,12 @@ def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> T: root.parent = parent.idx return root + def port_kind(self, port: InPort | OutPort) -> tys.Kind: + raise NotImplementedError + @dataclass() -class Input(Op): +class Input(DataflowOp): types: tys.TypeRow @property @@ -53,20 +83,35 @@ def num_out(self) -> int | None: def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.Input: return sops.Input(parent=parent.idx, types=ser_it(self.types)) + def outer_signature(self) -> tys.FunctionType: + return tys.FunctionType(input=[], output=self.types) + def __call__(self) -> Command: return super().__call__() @dataclass() -class Output(Op): - types: tys.TypeRow +class Output(DataflowOp, PartialOp): + _types: tys.TypeRow | None = None + + @property + def types(self) -> tys.TypeRow: + if self._types is None: + raise IncompleteOp() + return self._types def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.Output: return sops.Output(parent=parent.idx, types=ser_it(self.types)) + def outer_signature(self) -> tys.FunctionType: + return tys.FunctionType(input=self.types, output=[]) + + def set_in_types(self, types: tys.TypeRow) -> None: + self._types = types + @dataclass() -class Custom(Op): +class Custom(DataflowOp): op_name: str signature: tys.FunctionType = field(default_factory=tys.FunctionType.empty) description: str = "" @@ -87,12 +132,21 @@ def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.CustomOp: args=ser_it(self.args), ) + def outer_signature(self) -> tys.FunctionType: + return self.signature + @dataclass() -class MakeTuple(Op): - types: tys.TypeRow +class MakeTupleDef(DataflowOp, PartialOp): + _types: tys.TypeRow | None = None num_out: int | None = 1 + @property + def types(self) -> tys.TypeRow: + if self._types is None: + raise IncompleteOp() + return self._types + def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.MakeTuple: return sops.MakeTuple( parent=parent.idx, @@ -102,10 +156,25 @@ def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.MakeTuple: def __call__(self, *elements: Wire) -> Command: return super().__call__(*elements) + def outer_signature(self) -> tys.FunctionType: + return tys.FunctionType(input=self.types, output=[tys.Tuple(*self.types)]) + + def set_in_types(self, types: tys.TypeRow) -> None: + self._types = types + + +MakeTuple = MakeTupleDef() + @dataclass() -class UnpackTuple(Op): - types: tys.TypeRow +class UnpackTupleDef(DataflowOp, PartialOp): + _types: tys.TypeRow | None = None + + @property + def types(self) -> tys.TypeRow: + if self._types is None: + raise IncompleteOp() + return self._types @property def num_out(self) -> int | None: @@ -120,15 +189,58 @@ def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.UnpackTuple: def __call__(self, tuple_: Wire) -> Command: return super().__call__(tuple_) + def outer_signature(self) -> tys.FunctionType: + return MakeTupleDef(self.types).outer_signature().flip() + + def set_in_types(self, types: tys.TypeRow) -> None: + (t,) = types + assert isinstance(t, tys.Sum), f"Expected unary Sum, got {t}" + (row,) = t.variant_rows + self._types = row + + +UnpackTuple = UnpackTupleDef() + + +@dataclass() +class Tag(DataflowOp): + tag: int + variants: list[tys.TypeRow] + num_out: int | None = 1 + + def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.Tag: + return sops.Tag( + parent=parent.idx, + tag=self.tag, + variants=[ser_it(r) for r in self.variants], + ) + + def outer_signature(self) -> tys.FunctionType: + return tys.FunctionType( + input=self.variants[self.tag], output=[tys.Sum(self.variants)] + ) + + def __call__(self, value: Wire) -> Command: + return super().__call__(value) + class DfParentOp(Op, Protocol): - def input_types(self) -> tys.TypeRow: ... - def output_types(self) -> tys.TypeRow: ... + def inner_signature(self) -> tys.FunctionType: ... + + def _set_out_types(self, types: tys.TypeRow) -> None: ... + + def _inputs(self) -> tys.TypeRow: ... @dataclass() -class DFG(DfParentOp): - signature: tys.FunctionType = field(default_factory=tys.FunctionType.empty) +class DFG(DfParentOp, DataflowOp): + _signature: tys.TypeRow | tys.FunctionType + + @property + def signature(self) -> tys.FunctionType: + if isinstance(self._signature, tys.FunctionType): + return self._signature + raise IncompleteOp() @property def num_out(self) -> int | None: @@ -140,15 +252,26 @@ def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.DFG: signature=self.signature.to_serial(), ) - def input_types(self) -> tys.TypeRow: - return self.signature.input + def inner_signature(self) -> tys.FunctionType: + return self.signature - def output_types(self) -> tys.TypeRow: - return self.signature.output + def outer_signature(self) -> tys.FunctionType: + return self.signature + + def _set_out_types(self, types: tys.TypeRow) -> None: + assert isinstance(self._signature, list), "Signature has already been set." + self._signature = tys.FunctionType(self._signature, types) + + def _inputs(self) -> tys.TypeRow: + match self._signature: + case tys.FunctionType(input, _): + return input + case list(_): + return self._signature @dataclass() -class CFG(Op): +class CFG(DataflowOp): signature: tys.FunctionType = field(default_factory=tys.FunctionType.empty) @property @@ -161,14 +284,29 @@ def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.CFG: signature=self.signature.to_serial(), ) + def outer_signature(self) -> tys.FunctionType: + return self.signature + @dataclass class DataflowBlock(DfParentOp): inputs: tys.TypeRow - sum_rows: list[tys.TypeRow] - other_outputs: tys.TypeRow = field(default_factory=list) + _sum_rows: list[tys.TypeRow] | None = None + _other_outputs: tys.TypeRow | None = None extension_delta: tys.ExtensionSet = field(default_factory=list) + @property + def sum_rows(self) -> list[tys.TypeRow]: + if self._sum_rows is None: + raise IncompleteOp() + return self._sum_rows + + @property + def other_outputs(self) -> tys.TypeRow: + if self._other_outputs is None: + raise IncompleteOp() + return self._other_outputs + @property def num_out(self) -> int | None: return len(self.sum_rows) @@ -182,20 +320,43 @@ def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.DataflowBlock: extension_delta=self.extension_delta, ) - def input_types(self) -> tys.TypeRow: + def inner_signature(self) -> tys.FunctionType: + return tys.FunctionType( + input=self.inputs, output=[tys.Sum(self.sum_rows), *self.other_outputs] + ) + + def port_kind(self, port: InPort | OutPort) -> tys.Kind: + return tys.CFKind() + + def _set_out_types(self, types: tys.TypeRow) -> None: + (sum_, *other) = types + assert isinstance(sum_, tys.Sum), f"Expected Sum, got {sum_}" + self._sum_rows = sum_.variant_rows + self._other_outputs = other + + def _inputs(self) -> tys.TypeRow: return self.inputs - def output_types(self) -> tys.TypeRow: - return [tys.Sum(self.sum_rows), *self.other_outputs] + def nth_outputs(self, n: int) -> tys.TypeRow: + return [*self.sum_rows[n], *self.other_outputs] @dataclass class ExitBlock(Op): - cfg_outputs: tys.TypeRow + _cfg_outputs: tys.TypeRow | None = None num_out: int | None = 0 + @property + def cfg_outputs(self) -> tys.TypeRow: + if self._cfg_outputs is None: + raise IncompleteOp() + return self._cfg_outputs + def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.ExitBlock: return sops.ExitBlock( parent=parent.idx, cfg_outputs=ser_it(self.cfg_outputs), ) + + def port_kind(self, port: InPort | OutPort) -> tys.Kind: + return tys.CFKind() diff --git a/hugr-py/src/hugr/_tys.py b/hugr-py/src/hugr/_tys.py index 6f199959c..7b0fa2335 100644 --- a/hugr-py/src/hugr/_tys.py +++ b/hugr-py/src/hugr/_tys.py @@ -159,14 +159,6 @@ def to_serial(self) -> stys.Array: return stys.Array(ty=self.ty.to_serial_root(), len=self.size) -@dataclass(frozen=True) -class UnitSum(Type): - size: int - - def to_serial(self) -> stys.UnitSum: - return stys.UnitSum(size=self.size) - - @dataclass() class Sum(Type): variant_rows: list[TypeRow] @@ -181,6 +173,18 @@ def as_tuple(self) -> Tuple: return Tuple(*self.variant_rows[0]) +@dataclass() +class UnitSum(Sum): + size: int + + def __init__(self, size: int): + self.size = size + super().__init__(variant_rows=[[]] * size) + + def to_serial(self) -> stys.UnitSum: # type: ignore[override] + return stys.UnitSum(size=self.size) + + @dataclass() class Tuple(Sum): def __init__(self, *tys: Type): @@ -222,8 +226,8 @@ def to_serial(self) -> stys.Alias: @dataclass(frozen=True) class FunctionType(Type): - input: list[Type] - output: list[Type] + input: TypeRow + output: TypeRow extension_reqs: ExtensionSet = field(default_factory=ExtensionSet) def to_serial(self) -> stys.FunctionType: @@ -233,6 +237,9 @@ def to_serial(self) -> stys.FunctionType: def empty(cls) -> FunctionType: return cls(input=[], output=[]) + def flip(self) -> FunctionType: + return FunctionType(input=list(self.output), output=list(self.input)) + @dataclass(frozen=True) class PolyFuncType(Type): @@ -270,3 +277,29 @@ def to_serial(self) -> stys.Qubit: Qubit = QubitDef() Bool = UnitSum(size=2) Unit = UnitSum(size=1) + + +@dataclass(frozen=True) +class ValueKind: + ty: Type + + +@dataclass(frozen=True) +class ConstKind: + ty: Type + + +@dataclass(frozen=True) +class FunctionKind: + ty: PolyFuncType + + +@dataclass(frozen=True) +class CFKind: ... + + +@dataclass(frozen=True) +class OrderKind: ... + + +Kind = ValueKind | ConstKind | FunctionKind | CFKind | OrderKind diff --git a/hugr-py/src/hugr/serialization/ops.py b/hugr-py/src/hugr/serialization/ops.py index fbeb3e8c2..049aa1d9f 100644 --- a/hugr-py/src/hugr/serialization/ops.py +++ b/hugr-py/src/hugr/serialization/ops.py @@ -229,7 +229,7 @@ def insert_port_types(self, in_types: TypeRow, out_types: TypeRow) -> None: self.types = list(in_types) def deserialize(self) -> _ops.Output: - return _ops.Output(types=deser_it(self.types)) + return _ops.Output(deser_it(self.types)) class Call(DataflowOp): @@ -447,8 +447,8 @@ def insert_port_types(self, in_types: TypeRow, out_types: TypeRow) -> None: in_types = [] self.tys = list(in_types) - def deserialize(self) -> _ops.MakeTuple: - return _ops.MakeTuple(deser_it(self.tys)) + def deserialize(self) -> _ops.MakeTupleDef: + return _ops.MakeTupleDef(deser_it(self.tys)) class UnpackTuple(DataflowOp): @@ -460,8 +460,8 @@ class UnpackTuple(DataflowOp): def insert_port_types(self, in_types: TypeRow, out_types: TypeRow) -> None: self.tys = list(out_types) - def deserialize(self) -> _ops.UnpackTuple: - return _ops.UnpackTuple(deser_it(self.tys)) + def deserialize(self) -> _ops.UnpackTupleDef: + return _ops.UnpackTupleDef(deser_it(self.tys)) class Tag(DataflowOp): diff --git a/hugr-py/tests/test_cfg.py b/hugr-py/tests/test_cfg.py index 8d41ef8c0..c8fc5f511 100644 --- a/hugr-py/tests/test_cfg.py +++ b/hugr-py/tests/test_cfg.py @@ -1,47 +1,45 @@ from hugr._cfg import Cfg import hugr._tys as tys from hugr._dfg import Dfg +import hugr._ops as ops from .test_hugr_build import _validate, INT_T, DivMod def build_basic_cfg(cfg: Cfg) -> None: - entry = cfg.simple_entry(1, [tys.Bool]) + entry = cfg.add_entry() entry.set_block_outputs(*entry.inputs()) cfg.branch(entry[0], cfg.exit) def test_basic_cfg() -> None: - cfg = Cfg([tys.Unit, tys.Bool], [tys.Bool]) + cfg = Cfg([tys.Unit, tys.Bool]) build_basic_cfg(cfg) _validate(cfg.hugr) def test_branch() -> None: - cfg = Cfg([tys.Bool, tys.Unit, INT_T], [INT_T]) - entry = cfg.simple_entry(2, [tys.Unit, INT_T]) + cfg = Cfg([tys.Bool, tys.Unit, INT_T]) + entry = cfg.add_entry() entry.set_block_outputs(*entry.inputs()) - middle_1 = cfg.simple_block([tys.Unit, INT_T], 1, [INT_T]) + middle_1 = cfg.add_successor(entry[0]) middle_1.set_block_outputs(*middle_1.inputs()) - middle_2 = cfg.simple_block([tys.Unit, INT_T], 1, [INT_T]) + middle_2 = cfg.add_successor(entry[1]) u, i = middle_2.inputs() n = middle_2.add(DivMod(i, i)) middle_2.set_block_outputs(u, n[0]) - cfg.branch(entry[0], middle_1) - cfg.branch(entry[1], middle_2) - - cfg.branch(middle_1[0], cfg.exit) - cfg.branch(middle_2[0], cfg.exit) + cfg.branch_exit(middle_1[0]) + cfg.branch_exit(middle_2[0]) _validate(cfg.hugr) def test_nested_cfg() -> None: - dfg = Dfg([tys.Unit, tys.Bool], [tys.Bool]) + dfg = Dfg(tys.Unit, tys.Bool) - cfg = dfg.add_cfg([tys.Unit, tys.Bool], [tys.Bool], *dfg.inputs()) + cfg = dfg.add_cfg([tys.Unit, tys.Bool], *dfg.inputs()) build_basic_cfg(cfg) dfg.set_outputs(cfg) @@ -50,22 +48,42 @@ def test_nested_cfg() -> None: def test_dom_edge() -> None: - cfg = Cfg([tys.Bool, tys.Unit, INT_T], [INT_T]) - entry = cfg.simple_entry(2, [INT_T]) + cfg = Cfg([tys.Bool, tys.Unit, INT_T]) + entry = cfg.add_entry() b, u, i = entry.inputs() entry.set_block_outputs(b, i) # entry dominates both middles so Unit type can be used as inter-graph # value between basic blocks - middle_1 = cfg.simple_block([INT_T], 1, [INT_T]) + middle_1 = cfg.add_successor(entry[0]) middle_1.set_block_outputs(u, *middle_1.inputs()) - middle_2 = cfg.simple_block([INT_T], 1, [INT_T]) + middle_2 = cfg.add_successor(entry[1]) middle_2.set_block_outputs(u, *middle_2.inputs()) - cfg.branch(entry[0], middle_1) - cfg.branch(entry[1], middle_2) + cfg.branch_exit(middle_1[0]) + cfg.branch_exit(middle_2[0]) + + _validate(cfg.hugr) + + +def test_asymm_types() -> None: + # test different types going to entry block's susccessors + cfg = Cfg([tys.Bool, tys.Unit, INT_T]) + entry = cfg.add_entry() + b, u, i = entry.inputs() + + tagged_int = entry.add(ops.Tag(0, [[INT_T], [tys.Bool]])(i)) + entry.set_block_outputs(tagged_int) + + middle = cfg.add_successor(entry[0]) + # discard the int and return the bool from entry + middle.set_block_outputs(u, b) - cfg.branch(middle_1[0], cfg.exit) - cfg.branch(middle_2[0], cfg.exit) + # middle expects an int and exit expects a bool + cfg.branch_exit(entry[1]) + cfg.branch_exit(middle[0]) _validate(cfg.hugr) + + +# TODO loop diff --git a/hugr-py/tests/test_hugr_build.py b/hugr-py/tests/test_hugr_build.py index 167f28d2f..ff7a2759b 100644 --- a/hugr-py/tests/test_hugr_build.py +++ b/hugr-py/tests/test_hugr_build.py @@ -85,7 +85,7 @@ def _validate(h: Hugr, mermaid: bool = False, roundtrip: bool = True): def test_stable_indices(): - h = Hugr(ops.DFG()) + h = Hugr(ops.DFG([])) nodes = [h.add_node(Not) for _ in range(3)] assert len(h) == 4 @@ -120,7 +120,7 @@ def test_stable_indices(): def test_simple_id(): - h = Dfg.endo([tys.Qubit] * 2) + h = Dfg(tys.Qubit, tys.Qubit) a, b = h.inputs() h.set_outputs(a, b) @@ -128,7 +128,7 @@ def test_simple_id(): def test_multiport(): - h = Dfg([tys.Bool], [tys.Bool] * 2) + h = Dfg(tys.Bool) (a,) = h.inputs() h.set_outputs(a, a) in_n, ou_n = h.input_node, h.output_node @@ -151,7 +151,7 @@ def test_multiport(): def test_add_op(): - h = Dfg.endo([tys.Bool]) + h = Dfg(tys.Bool) (a,) = h.inputs() nt = h.add_op(Not, a) h.set_outputs(nt) @@ -161,25 +161,25 @@ def test_add_op(): def test_tuple(): row = [tys.Bool, tys.Qubit] - h = Dfg.endo(row) + h = Dfg(*row) a, b = h.inputs() - t = h.add(ops.MakeTuple(row)(a, b)) - a, b = h.add(ops.UnpackTuple(row)(t)) + t = h.add(ops.MakeTuple(a, b)) + a, b = h.add(ops.UnpackTuple(t)) h.set_outputs(a, b) _validate(h.hugr) - h1 = Dfg.endo(row) + h1 = Dfg(*row) a, b = h1.inputs() - mt = h1.add_op(ops.MakeTuple(row), a, b) - a, b = h1.add_op(ops.UnpackTuple(row), mt)[0, 1] + mt = h1.add_op(ops.MakeTuple, a, b) + a, b = h1.add_op(ops.UnpackTuple, mt)[0, 1] h1.set_outputs(a, b) assert h.hugr.to_serial() == h1.hugr.to_serial() def test_multi_out(): - h = Dfg([INT_T] * 2, [INT_T] * 2) + h = Dfg(INT_T, INT_T) a, b = h.inputs() a, b = h.add(DivMod(a, b)) h.set_outputs(a, b) @@ -187,25 +187,25 @@ def test_multi_out(): def test_insert(): - h1 = Dfg.endo([tys.Bool]) + h1 = Dfg(tys.Bool) (a1,) = h1.inputs() nt = h1.add(Not(a1)) h1.set_outputs(nt) assert len(h1.hugr) == 4 - new_h = Hugr(ops.DFG()) + new_h = Hugr(ops.DFG([])) mapping = h1.hugr.insert_hugr(new_h, h1.hugr.root) assert mapping == {new_h.root: Node(4)} def test_insert_nested(): - h1 = Dfg.endo([tys.Bool]) + h1 = Dfg(tys.Bool) (a1,) = h1.inputs() nt = h1.add(Not(a1)) h1.set_outputs(nt) - h = Dfg.endo([tys.Bool]) + h = Dfg(tys.Bool) (a,) = h.inputs() nested = h.insert_nested(h1, a) h.set_outputs(nested) @@ -219,9 +219,9 @@ def _nested_nop(dfg: Dfg): nt = dfg.add(Not(a1)) dfg.set_outputs(nt) - h = Dfg.endo([tys.Bool]) + h = Dfg(tys.Bool) (a,) = h.inputs() - nested = h.add_nested([tys.Bool], [tys.Bool], a) + nested = h.add_nested(a) _nested_nop(nested) assert len(h.hugr.children(nested)) == 3 @@ -231,9 +231,9 @@ def _nested_nop(dfg: Dfg): def test_build_inter_graph(): - h = Dfg.endo([tys.Bool, tys.Bool]) + h = Dfg(tys.Bool, tys.Bool) (a, b) = h.inputs() - nested = h.add_nested([], [tys.Bool]) + nested = h.add_nested() nt = nested.add(Not(a)) nested.set_outputs(nt) @@ -250,10 +250,10 @@ def test_build_inter_graph(): def test_ancestral_sibling(): - h = Dfg.endo([tys.Bool]) + h = Dfg(tys.Bool) (a,) = h.inputs() - nested = h.add_nested([], [tys.Bool]) + nested = h.add_nested() nt = nested.add(Not(a)) - assert _ancestral_sibling(h.hugr, h.input_node, nt) == nested.root + assert _ancestral_sibling(h.hugr, h.input_node, nt) == nested.parent_node