diff --git a/hugr-py/src/hugr/_dfg.py b/hugr-py/src/hugr/_dfg.py index f0658f397..54fd20888 100644 --- a/hugr-py/src/hugr/_dfg.py +++ b/hugr-py/src/hugr/_dfg.py @@ -108,7 +108,7 @@ def insert_cfg(self, cfg: Cfg, *args: Wire) -> Node: def set_outputs(self, *args: Wire) -> None: self._wire_up(self.output_node, args) - self.parent_op()._set_out_types(self._output_op().types) + self.parent_op()._set_out_types(self._output_op()._types()) def add_state_order(self, src: Node, dst: Node) -> None: # adds edge to the right of all existing edges @@ -116,8 +116,8 @@ def add_state_order(self, src: Node, dst: Node) -> None: def _wire_up(self, node: Node, ports: Iterable[Wire]): tys = [self._wire_up_port(node, i, p) for i, p in enumerate(ports)] - if isinstance(op := self.hugr[node].op, ops.DataflowOp): - op._set_in_types(tys) + if isinstance(op := self.hugr[node].op, ops.PartialOp): + op.set_in_types(tys) def _get_dataflow_type(self, wire: Wire) -> Type: port = wire.out_port() diff --git a/hugr-py/src/hugr/_exceptions.py b/hugr-py/src/hugr/_exceptions.py index e5c211d44..f003fc6c7 100644 --- a/hugr-py/src/hugr/_exceptions.py +++ b/hugr-py/src/hugr/_exceptions.py @@ -34,3 +34,8 @@ def msg(self): class ParentBeforeChild(Exception): msg: str = "Parent node must be added before child node." + + +@dataclass +class IncompleteOp(Exception): + msg: str = "Operation is incomplete, may require set_in_types to be called." diff --git a/hugr-py/src/hugr/_ops.py b/hugr-py/src/hugr/_ops.py index 149db8680..268394223 100644 --- a/hugr-py/src/hugr/_ops.py +++ b/hugr-py/src/hugr/_ops.py @@ -6,6 +6,7 @@ import hugr.serialization.ops as sops from hugr.utils import ser_it import hugr._tys as tys +from ._exceptions import IncompleteOp if TYPE_CHECKING: from hugr._hugr import Hugr, Node, Wire, _Port @@ -39,13 +40,15 @@ def port_type(self, port: _Port) -> tys.Type: return sig.input[port.offset] return sig.output[port.offset] - def _set_in_types(self, types: tys.TypeRow) -> None: - return - def __call__(self, *args) -> Command: return Command(self, list(args)) +@runtime_checkable +class PartialOp(DataflowOp, Protocol): + def set_in_types(self, types: tys.TypeRow) -> None: ... + + @dataclass(frozen=True) class Command: op: DataflowOp @@ -88,16 +91,21 @@ def __call__(self) -> Command: @dataclass() -class Output(DataflowOp): - types: tys.TypeRow = field(default_factory=list) +class Output(PartialOp): + types: tys.TypeRow | None = None + + def _types(self) -> tys.TypeRow: + if self.types is None: + raise IncompleteOp() + return self.types def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.Output: - return sops.Output(parent=parent.idx, types=ser_it(self.types)) + return sops.Output(parent=parent.idx, types=ser_it(self._types())) def outer_signature(self) -> tys.FunctionType: - return tys.FunctionType(input=self.types, output=[]) + return tys.FunctionType(input=self._types(), output=[]) - def _set_in_types(self, types: tys.TypeRow) -> None: + def set_in_types(self, types: tys.TypeRow) -> None: self.types = types @@ -128,23 +136,28 @@ def outer_signature(self) -> tys.FunctionType: @dataclass() -class MakeTupleDef(DataflowOp): - types: tys.TypeRow = field(default_factory=list) +class MakeTupleDef(PartialOp): + types: tys.TypeRow | None = None num_out: int | None = 1 + def _types(self) -> tys.TypeRow: + if self.types is None: + raise IncompleteOp() + return self.types + def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.MakeTuple: return sops.MakeTuple( parent=parent.idx, - tys=ser_it(self.types), + tys=ser_it(self._types()), ) def __call__(self, *elements: Wire) -> Command: return super().__call__(*elements) def outer_signature(self) -> tys.FunctionType: - return tys.FunctionType(input=self.types, output=[tys.Tuple(*self.types)]) + return tys.FunctionType(input=self._types(), output=[tys.Tuple(*self._types())]) - def _set_in_types(self, types: tys.TypeRow) -> None: + def set_in_types(self, types: tys.TypeRow) -> None: self.types = types @@ -152,9 +165,14 @@ def _set_in_types(self, types: tys.TypeRow) -> None: @dataclass() -class UnpackTupleDef(DataflowOp): +class UnpackTupleDef(PartialOp): types: tys.TypeRow = field(default_factory=list) + def _types(self) -> tys.TypeRow: + if self.types is None: + raise IncompleteOp() + return self.types + @property def num_out(self) -> int | None: return len(self.types) @@ -171,7 +189,7 @@ def __call__(self, tuple_: Wire) -> Command: def outer_signature(self) -> tys.FunctionType: return MakeTupleDef(self.types).outer_signature().flip() - def _set_in_types(self, types: tys.TypeRow) -> None: + def set_in_types(self, types: tys.TypeRow) -> None: (t,) = types assert isinstance(t, tys.Sum), f"Expected unary Sum, got {t}" (row,) = t.variant_rows @@ -203,8 +221,7 @@ def outer_signature(self) -> tys.FunctionType: class DfParentOp(Op, Protocol): def inner_signature(self) -> tys.FunctionType: ... - def _set_out_types(self, types: tys.TypeRow) -> None: - return + def _set_out_types(self, types: tys.TypeRow) -> None: ... @dataclass()