Skip to content

Commit

Permalink
fix: Loading custom polymorphic function defs as values (#260)
Browse files Browse the repository at this point in the history
Fixes #259
  • Loading branch information
mark-koch authored Jun 25, 2024
1 parent b68224b commit d15b2f5
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 19 deletions.
34 changes: 15 additions & 19 deletions guppylang/definition/custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from guppylang.definition.common import ParsableDef
from guppylang.definition.value import CompiledCallableDef
from guppylang.error import GuppyError, InternalGuppyError
from guppylang.hugr_builder.hugr import Hugr, Node, OutPortV
from guppylang.hugr_builder.hugr import Hugr, OutPortV
from guppylang.nodes import GlobalCall
from guppylang.tys.subst import Inst, Subst
from guppylang.tys.ty import FunctionType, NoneType, Type, type_to_row
Expand Down Expand Up @@ -146,26 +146,22 @@ def load_with_args(
)
assert len(self.ty.params) == len(type_args)

# Find the module node by walking up the hierarchy
module: Node = dfg.node
while not isinstance(module.op, ops.Module):
if module.parent is None:
raise InternalGuppyError(
"Encountered node that is not contained in a module."
)
module = module.parent

# We create a `FunctionDef` that takes some inputs, compiles a call to the
# function, and returns the results.
def_node = graph.add_def(self.ty, module, self.name)
_, inp_ports = graph.add_input_with_ports(list(self.ty.inputs), def_node)
returns = self.compile_call(
inp_ports, type_args, DFContainer(def_node, {}), graph, globals, node
)
graph.add_output(returns, parent=def_node)
# function, and returns the results. If the function signature is polymorphic,
# we explicitly monomorphise here and invoke the call compiler with the
# inferred type args.
fun_ty = self.ty.instantiate(type_args)
def_node = graph.add_def(fun_ty, dfg.node, self.name)
with graph.parent(def_node):
_, inp_ports = graph.add_input_with_ports(list(fun_ty.inputs))
returns = self.compile_call(
inp_ports, type_args, DFContainer(def_node, {}), graph, globals, node
)
graph.add_output(returns)

# Finally, load the function into the local DFG
return graph.add_load_constant(def_node.out_port(0), dfg.node).out_port(0)
# Finally, load the function into the local DFG. We already monomorphised, so we
# can load with empty type args
return graph.add_load_function(def_node.out_port(0), [], dfg.node).out_port(0)


class CustomCallChecker(ABC):
Expand Down
19 changes: 19 additions & 0 deletions tests/integration/test_poly.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import pytest

from guppylang.decorator import guppy
from guppylang.definition.custom import CustomCallCompiler
from guppylang.hugr_builder.hugr import OutPortV
from guppylang.module import GuppyModule
from guppylang.prelude.builtins import array
from guppylang.prelude.quantum import qubit
Expand Down Expand Up @@ -274,6 +276,23 @@ def main() -> None:
validate(module.compile())


def test_custom_higher_order():
class CustomCompiler(CustomCallCompiler):
def compile(self, args: list[OutPortV]) -> list[OutPortV]:
return args

module = GuppyModule("test")
T = guppy.type_var(module, "T")

@guppy.custom(module, CustomCompiler())
def foo(x: T) -> T: ...

@guppy(module)
def main(x: int) -> int:
f: Callable[[int], int] = foo
return f(x)


@pytest.mark.skip("Not yet supported")
def test_higher_order_value(validate):
module = GuppyModule("test")
Expand Down

0 comments on commit d15b2f5

Please sign in to comment.