diff --git a/guppylang/definition/custom.py b/guppylang/definition/custom.py index 67789fe0..b2e86370 100644 --- a/guppylang/definition/custom.py +++ b/guppylang/definition/custom.py @@ -12,7 +12,7 @@ from guppylang.definition.common import ParsableDef from guppylang.definition.value import CompiledCallableDef from guppylang.error import GuppyError, InternalGuppyError -from guppylang.hugr_builder.hugr import Hugr, Node, OutPortV +from guppylang.hugr_builder.hugr import Hugr, OutPortV from guppylang.nodes import GlobalCall from guppylang.tys.subst import Inst, Subst from guppylang.tys.ty import FunctionType, NoneType, Type, type_to_row @@ -146,26 +146,22 @@ def load_with_args( ) assert len(self.ty.params) == len(type_args) - # Find the module node by walking up the hierarchy - module: Node = dfg.node - while not isinstance(module.op, ops.Module): - if module.parent is None: - raise InternalGuppyError( - "Encountered node that is not contained in a module." - ) - module = module.parent - # We create a `FunctionDef` that takes some inputs, compiles a call to the - # function, and returns the results. - def_node = graph.add_def(self.ty, module, self.name) - _, inp_ports = graph.add_input_with_ports(list(self.ty.inputs), def_node) - returns = self.compile_call( - inp_ports, type_args, DFContainer(def_node, {}), graph, globals, node - ) - graph.add_output(returns, parent=def_node) + # function, and returns the results. If the function signature is polymorphic, + # we explicitly monomorphise here and invoke the call compiler with the + # inferred type args. + fun_ty = self.ty.instantiate(type_args) + def_node = graph.add_def(fun_ty, dfg.node, self.name) + with graph.parent(def_node): + _, inp_ports = graph.add_input_with_ports(list(fun_ty.inputs)) + returns = self.compile_call( + inp_ports, type_args, DFContainer(def_node, {}), graph, globals, node + ) + graph.add_output(returns) - # Finally, load the function into the local DFG - return graph.add_load_constant(def_node.out_port(0), dfg.node).out_port(0) + # Finally, load the function into the local DFG. We already monomorphised, so we + # can load with empty type args + return graph.add_load_function(def_node.out_port(0), [], dfg.node).out_port(0) class CustomCallChecker(ABC): diff --git a/tests/integration/test_poly.py b/tests/integration/test_poly.py index 471f4851..5edb4d79 100644 --- a/tests/integration/test_poly.py +++ b/tests/integration/test_poly.py @@ -3,6 +3,8 @@ import pytest from guppylang.decorator import guppy +from guppylang.definition.custom import CustomCallCompiler +from guppylang.hugr_builder.hugr import OutPortV from guppylang.module import GuppyModule from guppylang.prelude.builtins import array from guppylang.prelude.quantum import qubit @@ -274,6 +276,23 @@ def main() -> None: validate(module.compile()) +def test_custom_higher_order(): + class CustomCompiler(CustomCallCompiler): + def compile(self, args: list[OutPortV]) -> list[OutPortV]: + return args + + module = GuppyModule("test") + T = guppy.type_var(module, "T") + + @guppy.custom(module, CustomCompiler()) + def foo(x: T) -> T: ... + + @guppy(module) + def main(x: int) -> int: + f: Callable[[int], int] = foo + return f(x) + + @pytest.mark.skip("Not yet supported") def test_higher_order_value(validate): module = GuppyModule("test")