Skip to content

Commit

Permalink
fix: hugr-py not adding extension-reqs on custom ops (#1759)
Browse files Browse the repository at this point in the history
Fixes #1758 

Without these changes, the validation tests fail if we remove the
workaround in `custom.rs`
  • Loading branch information
aborgna-q authored Dec 10, 2024
1 parent 7a49f05 commit 97ba7f4
Show file tree
Hide file tree
Showing 7 changed files with 59 additions and 19 deletions.
9 changes: 1 addition & 8 deletions hugr-core/src/ops/custom.rs
Original file line number Diff line number Diff line change
Expand Up @@ -257,14 +257,7 @@ impl DataflowOpTrait for OpaqueOp {
}

fn signature(&self) -> Cow<'_, Signature> {
// TODO: Return a borrowed cow once
// https://github.com/CQCL/hugr/issues/1758
// gets fixed
Cow::Owned(
self.signature
.clone()
.with_extension_delta(self.extension.clone()),
)
Cow::Borrowed(&self.signature)
}
}

Expand Down
12 changes: 8 additions & 4 deletions hugr-py/src/hugr/_serialization/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,15 +101,19 @@ class OpDef(ConfiguredBaseModel, populate_by_name=True):
lower_funcs: list[FixedHugr] = pd.Field(default_factory=list)

def deserialize(self, extension: ext.Extension) -> ext.OpDef:
signature = ext.OpDefSig(
self.signature.deserialize().with_extension_reqs([extension.name])
if self.signature
else None,
self.binary,
)

return extension.add_op_def(
ext.OpDef(
name=self.name,
description=self.description,
misc=self.misc or {},
signature=ext.OpDefSig(
self.signature.deserialize() if self.signature else None,
self.binary,
),
signature=signature,
lower_funcs=[f.deserialize() for f in self.lower_funcs],
)
)
Expand Down
14 changes: 13 additions & 1 deletion hugr-py/src/hugr/ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,13 @@ def _to_serial(self) -> ext_s.FixedHugr:

@dataclass
class OpDefSig:
"""Type signature of an :class:`OpDef`."""
"""Type signature of an :class:`OpDef`.
Args:
poly_func: The polymorphic function type of the operation.
binary: If no static type scheme known, flag indicates a computation of the
signature
"""

#: The polymorphic function type of the operation (type scheme).
poly_func: tys.PolyFuncType | None
Expand Down Expand Up @@ -311,6 +317,12 @@ def add_op_def(self, op_def: OpDef) -> OpDef:
Returns:
The added operation definition, now associated with the extension.
"""
if op_def.signature.poly_func is not None:
# Ensure the op def signature has the extension as a requirement
op_def.signature.poly_func = op_def.signature.poly_func.with_extension_reqs(
[self.name]
)

op_def._extension = self
self.operations[op_def.name] = op_def
return self.operations[op_def.name]
Expand Down
22 changes: 18 additions & 4 deletions hugr-py/src/hugr/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,7 +453,11 @@ def op_def(self) -> ext.OpDef:
return std.PRELUDE.get_op("MakeTuple")

def cached_signature(self) -> tys.FunctionType | None:
return tys.FunctionType(input=self.types, output=[tys.Tuple(*self.types)])
return tys.FunctionType(
input=self.types,
output=[tys.Tuple(*self.types)],
extension_reqs=["prelude"],
)

def type_args(self) -> list[tys.TypeArg]:
return [tys.SequenceArg([t.type_arg() for t in self.types])]
Expand Down Expand Up @@ -492,7 +496,11 @@ def op_def(self) -> ext.OpDef:
return std.PRELUDE.get_op("UnpackTuple")

def cached_signature(self) -> tys.FunctionType | None:
return tys.FunctionType(input=[tys.Tuple(*self.types)], output=self.types)
return tys.FunctionType(
input=[tys.Tuple(*self.types)],
output=self.types,
extension_reqs=["prelude"],
)

def type_args(self) -> list[tys.TypeArg]:
return [tys.SequenceArg([t.type_arg() for t in self.types])]
Expand Down Expand Up @@ -1266,10 +1274,16 @@ def op_def(self) -> ext.OpDef:
return std.PRELUDE.get_op("Noop")

def cached_signature(self) -> tys.FunctionType | None:
return tys.FunctionType.endo([self.type_])
return tys.FunctionType.endo(
[self.type_],
extension_reqs=["prelude"],
)

def outer_signature(self) -> tys.FunctionType:
return tys.FunctionType.endo([self.type_])
return tys.FunctionType.endo(
[self.type_],
extension_reqs=["prelude"],
)

def _set_in_types(self, types: tys.TypeRow) -> None:
(t,) = types
Expand Down
2 changes: 1 addition & 1 deletion hugr-py/src/hugr/std/int.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def type_args(self) -> list[tys.TypeArg]:

def cached_signature(self) -> tys.FunctionType | None:
row: list[tys.Type] = [int_t(self.width)] * 2
return tys.FunctionType.endo(row)
return tys.FunctionType.endo(row, extension_reqs=[INT_OPS_EXTENSION.name])

@classmethod
def from_ext(cls, custom: ExtOp) -> Self | None:
Expand Down
17 changes: 17 additions & 0 deletions hugr-py/src/hugr/tys.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,6 +514,14 @@ def resolve(self, registry: ext.ExtensionRegistry) -> FunctionType:
extension_reqs=self.extension_reqs,
)

def with_extension_reqs(self, extension_reqs: ExtensionSet) -> FunctionType:
"""Adds a list of extension requirements to the function type, and
returns the new signature.
"""
exts = set(self.extension_reqs)
exts = exts.union(extension_reqs)
return FunctionType(self.input, self.output, [*exts])

def __str__(self) -> str:
return f"{comma_sep_str(self.input)} -> {comma_sep_str(self.output)}"

Expand Down Expand Up @@ -543,6 +551,15 @@ def resolve(self, registry: ext.ExtensionRegistry) -> PolyFuncType:
body=self.body.resolve(registry),
)

def with_extension_reqs(self, extension_reqs: ExtensionSet) -> PolyFuncType:
"""Adds a list of extension requirements to the function type, and
returns the new signature.
"""
return PolyFuncType(
params=self.params,
body=self.body.with_extension_reqs(extension_reqs),
)

def __str__(self) -> str:
return f"∀ {comma_sep_str(self.params)}. {self.body!s}"

Expand Down
2 changes: 1 addition & 1 deletion hugr-py/tests/test_custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def type_args(self) -> list[tys.TypeArg]:
return [tys.StringArg(self.tag)]

def cached_signature(self) -> tys.FunctionType | None:
return tys.FunctionType.endo([])
return tys.FunctionType.endo([], extension_reqs=[STRINGLY_EXT.name])

@classmethod
def from_ext(cls, custom: ops.ExtOp) -> "StringlyOp":
Expand Down

0 comments on commit 97ba7f4

Please sign in to comment.