From 5153df67bad951f916a028a87ff5492921bde48b Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Fri, 21 Jun 2024 16:22:25 +0100 Subject: [PATCH] feat(hugr-py): add lift keeping extensions to just DFG for now since we expect lift/extension annotations on containers to go away soon --- hugr-py/src/hugr/_dfg.py | 7 +++++-- hugr-py/src/hugr/_ops.py | 27 ++++++++++++++++++++++++++- hugr-py/src/hugr/_tys.py | 6 +++++- hugr-py/src/hugr/serialization/ops.py | 8 +++++++- hugr-py/tests/test_hugr_build.py | 8 ++++++++ 5 files changed, 51 insertions(+), 5 deletions(-) diff --git a/hugr-py/src/hugr/_dfg.py b/hugr-py/src/hugr/_dfg.py index f9ab05137..8fd42f846 100644 --- a/hugr-py/src/hugr/_dfg.py +++ b/hugr-py/src/hugr/_dfg.py @@ -20,6 +20,7 @@ TypeArg, FunctionKind, PolyFuncType, + ExtensionSet, ) from ._exceptions import NoSiblingAncestor @@ -237,8 +238,10 @@ def _wire_up_port(self, node: Node, offset: int, p: Wire) -> Type: class Dfg(_DfBase[ops.DFG]): - def __init__(self, *input_types: Type) -> None: - parent_op = ops.DFG(list(input_types)) + def __init__( + self, *input_types: Type, extension_delta: ExtensionSet | None = None + ) -> None: + parent_op = ops.DFG(list(input_types), None, extension_delta or []) super().__init__(parent_op) diff --git a/hugr-py/src/hugr/_ops.py b/hugr-py/src/hugr/_ops.py index 0a2ffc087..807525a51 100644 --- a/hugr-py/src/hugr/_ops.py +++ b/hugr-py/src/hugr/_ops.py @@ -233,6 +233,7 @@ def _inputs(self) -> tys.TypeRow: ... class DFG(DfParentOp, DataflowOp): inputs: tys.TypeRow _outputs: tys.TypeRow | None = None + extension_delta: tys.ExtensionSet = field(default_factory=list) @property def outputs(self) -> tys.TypeRow: @@ -240,7 +241,7 @@ def outputs(self) -> tys.TypeRow: @property def signature(self) -> tys.FunctionType: - return tys.FunctionType(self.inputs, self.outputs) + return tys.FunctionType(self.inputs, self.outputs, self.extension_delta) @property def num_out(self) -> int | None: @@ -746,3 +747,27 @@ def set_in_types(self, types: tys.TypeRow) -> None: Noop = NoopDef() + + +@dataclass +class Lift(DataflowOp, PartialOp): + new_extension: tys.ExtensionId + _type_row: tys.TypeRow | None = None + num_out: int | None = 1 + + @property + def type_row(self) -> tys.TypeRow: + return _check_complete(self._type_row) + + def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.Lift: + return sops.Lift( + parent=parent.idx, + new_extension=self.new_extension, + type_row=ser_it(self.type_row), + ) + + def outer_signature(self) -> tys.FunctionType: + return tys.FunctionType.endo(self.type_row) + + def set_in_types(self, types: tys.TypeRow) -> None: + self._type_row = types diff --git a/hugr-py/src/hugr/_tys.py b/hugr-py/src/hugr/_tys.py index 6e2e3584e..cfbb7c294 100644 --- a/hugr-py/src/hugr/_tys.py +++ b/hugr-py/src/hugr/_tys.py @@ -234,7 +234,11 @@ class FunctionType(Type): extension_reqs: ExtensionSet = field(default_factory=ExtensionSet) def to_serial(self) -> stys.FunctionType: - return stys.FunctionType(input=ser_it(self.input), output=ser_it(self.output)) + return stys.FunctionType( + input=ser_it(self.input), + output=ser_it(self.output), + extension_reqs=self.extension_reqs, + ) @classmethod def empty(cls) -> FunctionType: diff --git a/hugr-py/src/hugr/serialization/ops.py b/hugr-py/src/hugr/serialization/ops.py index dc867ddf9..c43acd543 100644 --- a/hugr-py/src/hugr/serialization/ops.py +++ b/hugr-py/src/hugr/serialization/ops.py @@ -383,7 +383,7 @@ def insert_child_dfg_signature(self, inputs: TypeRow, outputs: TypeRow) -> None: def deserialize(self) -> _ops.DFG: sig = self.signature.deserialize() - return _ops.DFG(sig.input, sig.output) + return _ops.DFG(sig.input, sig.output, sig.extension_reqs) # ------------------------------------------------ @@ -590,6 +590,12 @@ class Lift(DataflowOp): type_row: TypeRow new_extension: ExtensionId + def deserialize(self) -> _ops.Lift: + return _ops.Lift( + _type_row=deser_it(self.type_row), + new_extension=self.new_extension, + ) + class AliasDecl(BaseOp): op: Literal["AliasDecl"] = "AliasDecl" diff --git a/hugr-py/tests/test_hugr_build.py b/hugr-py/tests/test_hugr_build.py index 099b1183e..6dcd6de86 100644 --- a/hugr-py/tests/test_hugr_build.py +++ b/hugr-py/tests/test_hugr_build.py @@ -381,3 +381,11 @@ def test_higher_order() -> None: d.set_outputs(call) _validate(d.hugr) + + +def test_lift() -> None: + d = Dfg(tys.Qubit, extension_delta=["X"]) + (q,) = d.inputs() + lift = d.add(ops.Lift("X")(q)) + d.set_outputs(lift) + _validate(d.hugr)