Skip to content

Commit

Permalink
fix(hugr-py): ops require their own extensions
Browse files Browse the repository at this point in the history
Fixes #1301
  • Loading branch information
ss2165 committed Jul 15, 2024
1 parent cfb0674 commit c3339fb
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 8 deletions.
4 changes: 3 additions & 1 deletion hugr-py/src/hugr/std/int.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,9 @@ def to_custom(self) -> Custom:
return Custom(
"idivmod_u",
tys.FunctionType(
input=[int_t(self.arg1)] * 2, output=[int_t(self.arg2)] * 2
input=[int_t(self.arg1)] * 2,
output=[int_t(self.arg2)] * 2,
extension_reqs=[OPS_EXTENSION],
),
extension=OPS_EXTENSION,
args=[tys.BoundedNatArg(n=self.arg1), tys.BoundedNatArg(n=self.arg2)],
Expand Down
3 changes: 2 additions & 1 deletion hugr-py/src/hugr/std/logic.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ class _NotDef(AsCustomOp):
"""Not operation."""

def to_custom(self) -> Custom:
return Custom("Not", tys.FunctionType.endo([tys.Bool]), extension=EXTENSION_ID)
sig = tys.FunctionType.endo([tys.Bool], [EXTENSION_ID])
return Custom("Not", sig, extension=EXTENSION_ID)

def __call__(self, a: ComWire) -> Command:
return DataflowOp.__call__(self, a)
Expand Down
8 changes: 6 additions & 2 deletions hugr-py/src/hugr/tys.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,14 +330,18 @@ def empty(cls) -> FunctionType:
return cls(input=[], output=[])

@classmethod
def endo(cls, tys: TypeRow) -> FunctionType:
def endo(
cls, tys: TypeRow, extension_reqs: ExtensionSet | None = None
) -> FunctionType:
"""Function type with the same input and output types.
Example:
>>> FunctionType.endo([Qubit])
FunctionType([Qubit], [Qubit])
"""
return cls(input=tys, output=tys)
return cls(
input=tys, output=tys, extension_reqs=extension_reqs or ExtensionSet()
)

def flip(self) -> FunctionType:
"""Return a new function type with input and output types swapped.
Expand Down
16 changes: 12 additions & 4 deletions hugr-py/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def __call__(self, q: ComWire) -> Command:
def to_custom(self) -> Custom:
return Custom(
self._enum.value,
tys.FunctionType.endo([tys.Qubit]),
tys.FunctionType.endo([tys.Qubit], extension_reqs=[QUANTUM_EXTENSION_ID]),
extension=QUANTUM_EXTENSION_ID,
)

Expand All @@ -70,7 +70,9 @@ class _Enum(Enum):
def to_custom(self) -> Custom:
return Custom(
self._enum.value,
tys.FunctionType.endo([tys.Qubit] * 2),
tys.FunctionType.endo(
[tys.Qubit] * 2, extension_reqs=[QUANTUM_EXTENSION_ID]
),
extension=QUANTUM_EXTENSION_ID,
)

Expand All @@ -90,7 +92,11 @@ class MeasureDef(AsCustomOp):
def to_custom(self) -> Custom:
return Custom(
"Measure",
tys.FunctionType([tys.Qubit], [tys.Qubit, tys.Bool]),
tys.FunctionType(
[tys.Qubit],
[tys.Qubit, tys.Bool],
extension_reqs=[QUANTUM_EXTENSION_ID],
),
extension=QUANTUM_EXTENSION_ID,
)

Expand All @@ -106,7 +112,9 @@ class RzDef(AsCustomOp):
def to_custom(self) -> Custom:
return Custom(
"Rz",
tys.FunctionType([tys.Qubit, FLOAT_T], [tys.Qubit]),
tys.FunctionType(
[tys.Qubit, FLOAT_T], [tys.Qubit], extension_reqs=[QUANTUM_EXTENSION_ID]
),
extension=QUANTUM_EXTENSION_ID,
)

Expand Down

0 comments on commit c3339fb

Please sign in to comment.