diff --git a/hugr-py/src/hugr/dfg.py b/hugr-py/src/hugr/dfg.py index 99091b6c6..24e2b0647 100644 --- a/hugr-py/src/hugr/dfg.py +++ b/hugr-py/src/hugr/dfg.py @@ -215,7 +215,7 @@ def _fn_sig(self, func: ToNode) -> PolyFuncType: def _wire_up(self, node: Node, ports: Iterable[Wire]) -> TypeRow: tys = [self._wire_up_port(node, i, p) for i, p in enumerate(ports)] - if isinstance(op := self.hugr[node].op, ops.PartialOp): + if isinstance(op := self.hugr[node].op, ops._PartialOp): op.set_in_types(tys) return tys diff --git a/hugr-py/src/hugr/exceptions.py b/hugr-py/src/hugr/exceptions.py index f003fc6c7..e5c211d44 100644 --- a/hugr-py/src/hugr/exceptions.py +++ b/hugr-py/src/hugr/exceptions.py @@ -34,8 +34,3 @@ def msg(self): class ParentBeforeChild(Exception): msg: str = "Parent node must be added before child node." - - -@dataclass -class IncompleteOp(Exception): - msg: str = "Operation is incomplete, may require set_in_types to be called." diff --git a/hugr-py/src/hugr/hugr.py b/hugr-py/src/hugr/hugr.py index 7fb9bc70c..68bc8c91b 100644 --- a/hugr-py/src/hugr/hugr.py +++ b/hugr-py/src/hugr/hugr.py @@ -32,8 +32,8 @@ class NodeData: _num_outs: int = 0 children: list[Node] = field(default_factory=list) - def to_serial(self, node: Node, hugr: Hugr) -> SerialOp: - o = self.op.to_serial(node, self.parent if self.parent else node, hugr) + def to_serial(self, node: Node) -> SerialOp: + o = self.op.to_serial(self.parent if self.parent else node) return SerialOp(root=o) # type: ignore[arg-type] @@ -297,7 +297,7 @@ def _serialise_link( return SerialHugr( version="v1", # non contiguous indices will be erased - nodes=[node.to_serial(Node(idx), self) for idx, node in enumerate(node_it)], + nodes=[node.to_serial(Node(idx)) for idx, node in enumerate(node_it)], edges=[_serialise_link(link) for link in self._links.items()], ) diff --git a/hugr-py/src/hugr/ops.py b/hugr-py/src/hugr/ops.py index 07ab362dc..f5f9a50b5 100644 --- a/hugr-py/src/hugr/ops.py +++ b/hugr-py/src/hugr/ops.py @@ -1,38 +1,39 @@ from __future__ import annotations from dataclasses import dataclass, field -from typing import Protocol, TYPE_CHECKING, Sequence, runtime_checkable, TypeVar +from typing import Protocol, Sequence, runtime_checkable, TypeVar from hugr.serialization.ops import BaseOp import hugr.serialization.ops as sops from hugr.utils import ser_it import hugr.tys as tys from hugr.node_port import Node, InPort, OutPort, Wire import hugr.val as val -from .exceptions import IncompleteOp - -if TYPE_CHECKING: - from hugr.hugr import Hugr @dataclass class InvalidPort(Exception): + """Port is not valid for this operation.""" + port: InPort | OutPort + op: Op @property def msg(self) -> str: - return f"Invalid port {self.port}" + return f"Port {self.port} is invalid for operation {self.op}." @runtime_checkable class Op(Protocol): @property - def num_out(self) -> int | None: - return None + def num_out(self) -> int: ... - def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> BaseOp: ... + def to_serial(self, parent: Node) -> BaseOp: ... def port_kind(self, port: InPort | OutPort) -> tys.Kind: ... + def _invalid_port(self, port: InPort | OutPort) -> InvalidPort: + return InvalidPort(port, self) + def _sig_port_type(sig: tys.FunctionType, port: InPort | OutPort) -> tys.Type: from hugr.node_port import Direction @@ -59,10 +60,32 @@ def __call__(self, *args) -> Command: @runtime_checkable -class PartialOp(Protocol): +class _PartialOp(Protocol): def set_in_types(self, types: tys.TypeRow) -> None: ... +@dataclass +class IncompleteOp(Exception): + """Op types have not been set during building.""" + + op: Op + + @property + def msg(self) -> str: + return ( + f"Operation {self.op} is incomplete, may require set_in_types to be called." + ) + + +V = TypeVar("V") + + +def _check_complete(op, v: V | None) -> V: + if v is None: + raise IncompleteOp(op) + return v + + @dataclass(frozen=True) class Command: op: DataflowOp @@ -74,10 +97,10 @@ class Input(DataflowOp): types: tys.TypeRow @property - def num_out(self) -> int | None: + def num_out(self) -> int: return len(self.types) - def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.Input: + def to_serial(self, parent: Node) -> sops.Input: return sops.Input(parent=parent.idx, types=ser_it(self.types)) def outer_signature(self) -> tys.FunctionType: @@ -87,24 +110,16 @@ def __call__(self) -> Command: return super().__call__() -V = TypeVar("V") - - -def _check_complete(v: V | None) -> V: - if v is None: - raise IncompleteOp() - return v - - @dataclass() -class Output(DataflowOp, PartialOp): +class Output(DataflowOp, _PartialOp): _types: tys.TypeRow | None = None + num_out: int = 0 @property def types(self) -> tys.TypeRow: - return _check_complete(self._types) + return _check_complete(self, self._types) - def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.Output: + def to_serial(self, parent: Node) -> sops.Output: return sops.Output(parent=parent.idx, types=ser_it(self.types)) def outer_signature(self) -> tys.FunctionType: @@ -123,10 +138,10 @@ class Custom(DataflowOp): args: list[tys.TypeArg] = field(default_factory=list) @property - def num_out(self) -> int | None: + def num_out(self) -> int: return len(self.signature.output) - def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.CustomOp: + def to_serial(self, parent: Node) -> sops.CustomOp: return sops.CustomOp( parent=parent.idx, extension=self.extension, @@ -141,15 +156,15 @@ def outer_signature(self) -> tys.FunctionType: @dataclass() -class MakeTupleDef(DataflowOp, PartialOp): +class MakeTupleDef(DataflowOp, _PartialOp): _types: tys.TypeRow | None = None - num_out: int | None = 1 + num_out: int = 1 @property def types(self) -> tys.TypeRow: - return _check_complete(self._types) + return _check_complete(self, self._types) - def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.MakeTuple: + def to_serial(self, parent: Node) -> sops.MakeTuple: return sops.MakeTuple( parent=parent.idx, tys=ser_it(self.types), @@ -169,18 +184,18 @@ def set_in_types(self, types: tys.TypeRow) -> None: @dataclass() -class UnpackTupleDef(DataflowOp, PartialOp): +class UnpackTupleDef(DataflowOp, _PartialOp): _types: tys.TypeRow | None = None @property def types(self) -> tys.TypeRow: - return _check_complete(self._types) + return _check_complete(self, self._types) @property - def num_out(self) -> int | None: + def num_out(self) -> int: return len(self.types) - def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.UnpackTuple: + def to_serial(self, parent: Node) -> sops.UnpackTuple: return sops.UnpackTuple( parent=parent.idx, tys=ser_it(self.types), @@ -206,9 +221,9 @@ def set_in_types(self, types: tys.TypeRow) -> None: class Tag(DataflowOp): tag: int sum_ty: tys.Sum - num_out: int | None = 1 + num_out: int = 1 - def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.Tag: + def to_serial(self, parent: Node) -> sops.Tag: return sops.Tag( parent=parent.idx, tag=self.tag, @@ -237,17 +252,17 @@ class DFG(DfParentOp, DataflowOp): @property def outputs(self) -> tys.TypeRow: - return _check_complete(self._outputs) + return _check_complete(self, self._outputs) @property def signature(self) -> tys.FunctionType: return tys.FunctionType(self.inputs, self.outputs, self.extension_delta) @property - def num_out(self) -> int | None: + def num_out(self) -> int: return len(self.signature.output) - def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.DFG: + def to_serial(self, parent: Node) -> sops.DFG: return sops.DFG( parent=parent.idx, signature=self.signature.to_serial(), @@ -273,17 +288,17 @@ class CFG(DataflowOp): @property def outputs(self) -> tys.TypeRow: - return _check_complete(self._outputs) + return _check_complete(self, self._outputs) @property def signature(self) -> tys.FunctionType: return tys.FunctionType(self.inputs, self.outputs) @property - def num_out(self) -> int | None: + def num_out(self) -> int: return len(self.outputs) - def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.CFG: + def to_serial(self, parent: Node) -> sops.CFG: return sops.CFG( parent=parent.idx, signature=self.signature.to_serial(), @@ -302,17 +317,17 @@ class DataflowBlock(DfParentOp): @property def sum_ty(self) -> tys.Sum: - return _check_complete(self._sum) + return _check_complete(self, self._sum) @property def other_outputs(self) -> tys.TypeRow: - return _check_complete(self._other_outputs) + return _check_complete(self, self._other_outputs) @property - def num_out(self) -> int | None: + def num_out(self) -> int: return len(self.sum_ty.variant_rows) - def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.DataflowBlock: + def to_serial(self, parent: Node) -> sops.DataflowBlock: return sops.DataflowBlock( parent=parent.idx, inputs=ser_it(self.inputs), @@ -344,13 +359,13 @@ def nth_outputs(self, n: int) -> tys.TypeRow: @dataclass class ExitBlock(Op): _cfg_outputs: tys.TypeRow | None = None - num_out: int | None = 0 + num_out: int = 0 @property def cfg_outputs(self) -> tys.TypeRow: - return _check_complete(self._cfg_outputs) + return _check_complete(self, self._cfg_outputs) - def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.ExitBlock: + def to_serial(self, parent: Node) -> sops.ExitBlock: return sops.ExitBlock( parent=parent.idx, cfg_outputs=ser_it(self.cfg_outputs), @@ -363,9 +378,9 @@ def port_kind(self, port: InPort | OutPort) -> tys.Kind: @dataclass class Const(Op): val: val.Value - num_out: int | None = 1 + num_out: int = 1 - def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.Const: + def to_serial(self, parent: Node) -> sops.Const: return sops.Const( parent=parent.idx, v=self.val.to_serial_root(), @@ -376,18 +391,18 @@ def port_kind(self, port: InPort | OutPort) -> tys.Kind: case OutPort(_, 0): return tys.ConstKind(self.val.type_()) case _: - raise InvalidPort(port) + raise self._invalid_port(port) @dataclass class LoadConst(DataflowOp): typ: tys.Type | None = None - num_out: int | None = 1 + num_out: int = 1 def type_(self) -> tys.Type: - return _check_complete(self.typ) + return _check_complete(self, self.typ) - def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.LoadConstant: + def to_serial(self, parent: Node) -> sops.LoadConstant: return sops.LoadConstant( parent=parent.idx, datatype=self.type_().to_serial_root(), @@ -403,7 +418,7 @@ def port_kind(self, port: InPort | OutPort) -> tys.Kind: case OutPort(_, 0): return tys.ValueKind(self.type_()) case _: - raise InvalidPort(port) + raise self._invalid_port(port) @dataclass() @@ -414,7 +429,7 @@ class Conditional(DataflowOp): @property def outputs(self) -> tys.TypeRow: - return _check_complete(self._outputs) + return _check_complete(self, self._outputs) @property def signature(self) -> tys.FunctionType: @@ -422,10 +437,10 @@ def signature(self) -> tys.FunctionType: return tys.FunctionType(inputs, self.outputs) @property - def num_out(self) -> int | None: + def num_out(self) -> int: return len(self.outputs) - def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.Conditional: + def to_serial(self, parent: Node) -> sops.Conditional: return sops.Conditional( parent=parent.idx, sum_rows=[ser_it(r) for r in self.sum_ty.variant_rows], @@ -444,13 +459,13 @@ def nth_inputs(self, n: int) -> tys.TypeRow: class Case(DfParentOp): inputs: tys.TypeRow _outputs: tys.TypeRow | None = None - num_out: int | None = 0 + num_out: int = 0 @property def outputs(self) -> tys.TypeRow: - return _check_complete(self._outputs) + return _check_complete(self, self._outputs) - def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.Case: + def to_serial(self, parent: Node) -> sops.Case: return sops.Case( parent=parent.idx, signature=self.inner_signature().to_serial() ) @@ -459,7 +474,7 @@ def inner_signature(self) -> tys.FunctionType: return tys.FunctionType(self.inputs, self.outputs) def port_kind(self, port: InPort | OutPort) -> tys.Kind: - raise InvalidPort(port) + raise self._invalid_port(port) def _set_out_types(self, types: tys.TypeRow) -> None: self._outputs = types @@ -477,13 +492,13 @@ class TailLoop(DfParentOp, DataflowOp): @property def just_outputs(self) -> tys.TypeRow: - return _check_complete(self._just_outputs) + return _check_complete(self, self._just_outputs) @property - def num_out(self) -> int | None: + def num_out(self) -> int: return len(self.just_outputs) + len(self.rest) - def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.TailLoop: + def to_serial(self, parent: Node) -> sops.TailLoop: return sops.TailLoop( parent=parent.idx, just_inputs=ser_it(self.just_inputs), @@ -518,11 +533,11 @@ class FuncDefn(DfParentOp): inputs: tys.TypeRow params: list[tys.TypeParam] = field(default_factory=list) _outputs: tys.TypeRow | None = None - num_out: int | None = 1 + num_out: int = 1 @property def outputs(self) -> tys.TypeRow: - return _check_complete(self._outputs) + return _check_complete(self, self._outputs) @property def signature(self) -> tys.PolyFuncType: @@ -530,7 +545,7 @@ def signature(self) -> tys.PolyFuncType: self.params, tys.FunctionType(self.inputs, self.outputs) ) - def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.FuncDefn: + def to_serial(self, parent: Node) -> sops.FuncDefn: return sops.FuncDefn( parent=parent.idx, name=self.name, @@ -551,16 +566,16 @@ def port_kind(self, port: InPort | OutPort) -> tys.Kind: case OutPort(_, 0): return tys.FunctionKind(self.signature) case _: - raise InvalidPort(port) + raise self._invalid_port(port) @dataclass class FuncDecl(Op): name: str signature: tys.PolyFuncType - num_out: int | None = 0 + num_out: int = 0 - def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.FuncDecl: + def to_serial(self, parent: Node) -> sops.FuncDecl: return sops.FuncDecl( parent=parent.idx, name=self.name, @@ -572,18 +587,18 @@ def port_kind(self, port: InPort | OutPort) -> tys.Kind: case OutPort(_, 0): return tys.FunctionKind(self.signature) case _: - raise InvalidPort(port) + raise self._invalid_port(port) @dataclass class Module(Op): - num_out: int | None = 0 + num_out: int = 0 - def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.Module: + def to_serial(self, parent: Node) -> sops.Module: return sops.Module(parent=parent.idx) def port_kind(self, port: InPort | OutPort) -> tys.Kind: - raise InvalidPort(port) + raise self._invalid_port(port) class NoConcreteFunc(Exception): @@ -626,7 +641,7 @@ def __init__( signature, instantiation, type_args ) - def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.Call: + def to_serial(self, parent: Node) -> sops.Call: return sops.Call( parent=parent.idx, func_sig=self.signature.to_serial(), @@ -635,7 +650,7 @@ def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.Call: ) @property - def num_out(self) -> int | None: + def num_out(self) -> int: return len(self.signature.body.output) def function_port_offset(self) -> int: @@ -650,18 +665,18 @@ def port_kind(self, port: InPort | OutPort) -> tys.Kind: @dataclass() -class CallIndirectDef(DataflowOp, PartialOp): +class CallIndirectDef(DataflowOp, _PartialOp): _signature: tys.FunctionType | None = None @property - def num_out(self) -> int | None: + def num_out(self) -> int: return len(self.signature.output) @property def signature(self) -> tys.FunctionType: - return _check_complete(self._signature) + return _check_complete(self, self._signature) - def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.CallIndirect: + def to_serial(self, parent: Node) -> sops.CallIndirect: return sops.CallIndirect( parent=parent.idx, signature=self.signature.to_serial(), @@ -692,7 +707,7 @@ class LoadFunc(DataflowOp): signature: tys.PolyFuncType instantiation: tys.FunctionType type_args: list[tys.TypeArg] - num_out: int | None = 1 + num_out: int = 1 def __init__( self, @@ -705,7 +720,7 @@ def __init__( signature, instantiation, type_args ) - def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.LoadFunction: + def to_serial(self, parent: Node) -> sops.LoadFunction: return sops.LoadFunction( parent=parent.idx, func_sig=self.signature.to_serial(), @@ -723,19 +738,19 @@ def port_kind(self, port: InPort | OutPort) -> tys.Kind: case OutPort(_, 0): return tys.ValueKind(self.instantiation) case _: - raise InvalidPort(port) + raise self._invalid_port(port) @dataclass -class NoopDef(DataflowOp, PartialOp): +class NoopDef(DataflowOp, _PartialOp): _type: tys.Type | None = None - num_out: int | None = 1 + num_out: int = 1 @property def type_(self) -> tys.Type: - return _check_complete(self._type) + return _check_complete(self, self._type) - def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.Noop: + def to_serial(self, parent: Node) -> sops.Noop: return sops.Noop(parent=parent.idx, ty=self.type_.to_serial_root()) def outer_signature(self) -> tys.FunctionType: @@ -750,16 +765,16 @@ def set_in_types(self, types: tys.TypeRow) -> None: @dataclass -class Lift(DataflowOp, PartialOp): +class Lift(DataflowOp, _PartialOp): new_extension: tys.ExtensionId _type_row: tys.TypeRow | None = None - num_out: int | None = 1 + num_out: int = 1 @property def type_row(self) -> tys.TypeRow: - return _check_complete(self._type_row) + return _check_complete(self, self._type_row) - def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.Lift: + def to_serial(self, parent: Node) -> sops.Lift: return sops.Lift( parent=parent.idx, new_extension=self.new_extension, @@ -777,9 +792,9 @@ def set_in_types(self, types: tys.TypeRow) -> None: class AliasDecl(Op): name: str bound: tys.TypeBound - num_out: int | None = 0 + num_out: int = 0 - def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.AliasDecl: + def to_serial(self, parent: Node) -> sops.AliasDecl: return sops.AliasDecl( parent=parent.idx, name=self.name, @@ -787,16 +802,16 @@ def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.AliasDecl: ) def port_kind(self, port: InPort | OutPort) -> tys.Kind: - raise InvalidPort(port) + raise self._invalid_port(port) @dataclass class AliasDefn(Op): name: str definition: tys.Type - num_out: int | None = 0 + num_out: int = 0 - def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.AliasDefn: + def to_serial(self, parent: Node) -> sops.AliasDefn: return sops.AliasDefn( parent=parent.idx, name=self.name, @@ -804,4 +819,4 @@ def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.AliasDefn: ) def port_kind(self, port: InPort | OutPort) -> tys.Kind: - raise InvalidPort(port) + raise self._invalid_port(port) diff --git a/hugr-py/src/hugr/serialization/ops.py b/hugr-py/src/hugr/serialization/ops.py index 65706b047..380c0aa7d 100644 --- a/hugr-py/src/hugr/serialization/ops.py +++ b/hugr-py/src/hugr/serialization/ops.py @@ -131,7 +131,7 @@ class TupleValue(BaseValue): vs: list["Value"] def deserialize(self) -> val.Value: - return val.Tuple(deser_it((v.root for v in self.vs))) + return val.Tuple(*deser_it((v.root for v in self.vs))) class SumValue(BaseValue): diff --git a/hugr-py/src/hugr/val.py b/hugr-py/src/hugr/val.py index deb898c09..d69056068 100644 --- a/hugr-py/src/hugr/val.py +++ b/hugr-py/src/hugr/val.py @@ -50,13 +50,18 @@ def bool_value(b: bool) -> Sum: @dataclass -class Tuple(Value): +class Tuple(Sum): vals: list[Value] - def type_(self) -> tys.Tuple: - return tys.Tuple(*(v.type_() for v in self.vals)) + def __init__(self, *vals: Value): + val_list = list(vals) + super().__init__( + tag=0, typ=tys.Tuple(*(v.type_() for v in val_list)), vals=val_list + ) - def to_serial(self) -> sops.TupleValue: + # sops.TupleValue isn't an instance of sops.SumValue + # so mypy doesn't like the override of Sum.to_serial + def to_serial(self) -> sops.TupleValue: # type: ignore[override] return sops.TupleValue( vs=ser_it(self.vals), ) diff --git a/hugr-py/tests/conftest.py b/hugr-py/tests/conftest.py index 33f792d0b..220aa5037 100644 --- a/hugr-py/tests/conftest.py +++ b/hugr-py/tests/conftest.py @@ -42,7 +42,7 @@ class LogicOps(Custom): # TODO get from YAML @dataclass class NotDef(LogicOps): - num_out: int | None = 1 + num_out: int = 1 op_name: str = "Not" signature: tys.FunctionType = tys.FunctionType.endo([tys.Bool]) @@ -61,7 +61,7 @@ class QuantumOps(Custom): @dataclass class OneQbGate(QuantumOps): op_name: str - num_out: int | None = 1 + num_out: int = 1 signature: tys.FunctionType = tys.FunctionType.endo([tys.Qubit]) def __call__(self, q: Wire) -> Command: @@ -74,7 +74,7 @@ def __call__(self, q: Wire) -> Command: @dataclass class MeasureDef(QuantumOps): op_name: str = "Measure" - num_out: int | None = 2 + num_out: int = 2 signature: tys.FunctionType = tys.FunctionType([tys.Qubit], [tys.Qubit, tys.Bool]) def __call__(self, q: Wire) -> Command: @@ -94,7 +94,7 @@ class IntOps(Custom): @dataclass class DivModDef(IntOps): - num_out: int | None = 2 + num_out: int = 2 extension: tys.ExtensionId = "arithmetic.int" op_name: str = "idivmod_u" signature: tys.FunctionType = field( diff --git a/hugr-py/tests/test_hugr_build.py b/hugr-py/tests/test_hugr_build.py index a01b89d97..b6058a293 100644 --- a/hugr-py/tests/test_hugr_build.py +++ b/hugr-py/tests/test_hugr_build.py @@ -196,7 +196,7 @@ def test_ancestral_sibling(): [ val.Function(simple_id().hugr), val.Sum(1, tys.Sum([[INT_T], [tys.Bool, INT_T]]), [val.TRUE, IntVal(34)]), - val.Tuple([val.TRUE, IntVal(23)]), + val.Tuple(val.TRUE, IntVal(23)), ], ) def test_vals(val: val.Value):