Skip to content

Commit

Permalink
feat(hugr-py): add lift
Browse files Browse the repository at this point in the history
keeping extensions to just DFG for now since we expect lift/extension annotations on containers to go away soon
  • Loading branch information
ss2165 committed Jun 21, 2024
1 parent 69431ba commit 5153df6
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 5 deletions.
7 changes: 5 additions & 2 deletions hugr-py/src/hugr/_dfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
TypeArg,
FunctionKind,
PolyFuncType,
ExtensionSet,
)

from ._exceptions import NoSiblingAncestor
Expand Down Expand Up @@ -237,8 +238,10 @@ def _wire_up_port(self, node: Node, offset: int, p: Wire) -> Type:


class Dfg(_DfBase[ops.DFG]):
def __init__(self, *input_types: Type) -> None:
parent_op = ops.DFG(list(input_types))
def __init__(
self, *input_types: Type, extension_delta: ExtensionSet | None = None
) -> None:
parent_op = ops.DFG(list(input_types), None, extension_delta or [])
super().__init__(parent_op)


Expand Down
27 changes: 26 additions & 1 deletion hugr-py/src/hugr/_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,14 +233,15 @@ def _inputs(self) -> tys.TypeRow: ...
class DFG(DfParentOp, DataflowOp):
inputs: tys.TypeRow
_outputs: tys.TypeRow | None = None
extension_delta: tys.ExtensionSet = field(default_factory=list)

@property
def outputs(self) -> tys.TypeRow:
return _check_complete(self._outputs)

@property
def signature(self) -> tys.FunctionType:
return tys.FunctionType(self.inputs, self.outputs)
return tys.FunctionType(self.inputs, self.outputs, self.extension_delta)

@property
def num_out(self) -> int | None:
Expand Down Expand Up @@ -746,3 +747,27 @@ def set_in_types(self, types: tys.TypeRow) -> None:


Noop = NoopDef()


@dataclass
class Lift(DataflowOp, PartialOp):
new_extension: tys.ExtensionId
_type_row: tys.TypeRow | None = None
num_out: int | None = 1

@property
def type_row(self) -> tys.TypeRow:
return _check_complete(self._type_row)

def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.Lift:
return sops.Lift(
parent=parent.idx,
new_extension=self.new_extension,
type_row=ser_it(self.type_row),
)

def outer_signature(self) -> tys.FunctionType:
return tys.FunctionType.endo(self.type_row)

def set_in_types(self, types: tys.TypeRow) -> None:
self._type_row = types
6 changes: 5 additions & 1 deletion hugr-py/src/hugr/_tys.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,11 @@ class FunctionType(Type):
extension_reqs: ExtensionSet = field(default_factory=ExtensionSet)

def to_serial(self) -> stys.FunctionType:
return stys.FunctionType(input=ser_it(self.input), output=ser_it(self.output))
return stys.FunctionType(
input=ser_it(self.input),
output=ser_it(self.output),
extension_reqs=self.extension_reqs,
)

@classmethod
def empty(cls) -> FunctionType:
Expand Down
8 changes: 7 additions & 1 deletion hugr-py/src/hugr/serialization/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,7 @@ def insert_child_dfg_signature(self, inputs: TypeRow, outputs: TypeRow) -> None:

def deserialize(self) -> _ops.DFG:
sig = self.signature.deserialize()
return _ops.DFG(sig.input, sig.output)
return _ops.DFG(sig.input, sig.output, sig.extension_reqs)


# ------------------------------------------------
Expand Down Expand Up @@ -590,6 +590,12 @@ class Lift(DataflowOp):
type_row: TypeRow
new_extension: ExtensionId

def deserialize(self) -> _ops.Lift:
return _ops.Lift(
_type_row=deser_it(self.type_row),
new_extension=self.new_extension,
)


class AliasDecl(BaseOp):
op: Literal["AliasDecl"] = "AliasDecl"
Expand Down
8 changes: 8 additions & 0 deletions hugr-py/tests/test_hugr_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,3 +381,11 @@ def test_higher_order() -> None:
d.set_outputs(call)

_validate(d.hugr)


def test_lift() -> None:
d = Dfg(tys.Qubit, extension_delta=["X"])
(q,) = d.inputs()
lift = d.add(ops.Lift("X")(q))
d.set_outputs(lift)
_validate(d.hugr)

0 comments on commit 5153df6

Please sign in to comment.