From 5a59da359ee7a098ce069db5cdebd5eb98ec9781 Mon Sep 17 00:00:00 2001 From: Mark Koch <48097969+mark-koch@users.noreply.github.com> Date: Mon, 2 Sep 2024 14:15:28 +0100 Subject: [PATCH] feat: Allow calling of methods (#440) Closes #439. * Add a new `PartialApply` node to create a closure capturing the `self` argument of methods * Calls of partial applies are simplified to direct calls * Higher-order usage of methods is only allowed if `self` is not linear --- guppylang/checker/expr_checker.py | 26 +++++++++++++++++++ guppylang/checker/linearity_checker.py | 14 ++++++++++ guppylang/compiler/expr_compiler.py | 11 ++++++++ guppylang/nodes.py | 16 ++++++++++++ tests/error/linear_errors/method_capture.err | 7 +++++ tests/error/linear_errors/method_capture.py | 26 +++++++++++++++++++ tests/integration/test_call.py | 11 ++++++++ tests/integration/test_higher_order.py | 11 ++++++++ tests/integration/test_struct.py | 27 ++++++++++++++++++++ 9 files changed, 149 insertions(+) create mode 100644 tests/error/linear_errors/method_capture.err create mode 100644 tests/error/linear_errors/method_capture.py diff --git a/guppylang/checker/expr_checker.py b/guppylang/checker/expr_checker.py index cef50fb4..98d31daa 100644 --- a/guppylang/checker/expr_checker.py +++ b/guppylang/checker/expr_checker.py @@ -61,6 +61,7 @@ IterNext, LocalCall, MakeIter, + PartialApply, PlaceNode, PyExpr, TensorCall, @@ -239,6 +240,12 @@ def visit_Call(self, node: ast.Call, ty: Type) -> tuple[ast.expr, Subst]: if isinstance(defn, CallableDef): return defn.check_call(node.args, ty, node, self.ctx) + # When calling a `PartialApply` node, we just move the args into this call + if isinstance(node.func, PartialApply): + node.args = [*node.func.args, *node.args] + node.func = node.func.func + return self.visit_Call(node, ty) + # Otherwise, it must be a function as a higher-order value - something # whose type is either a FunctionType or a Tuple of FunctionTypes if isinstance(func_ty, FunctionType): @@ -371,6 +378,19 @@ def visit_Attribute(self, node: ast.Attribute) -> tuple[ast.expr, Type]: # you loose access to all fields besides `a`). expr = FieldAccessAndDrop(value=node.value, struct_ty=ty, field=field) return with_loc(node, expr), field.ty + elif func := self.ctx.globals.get_instance_func(ty, node.attr): + name = with_type( + func.ty, with_loc(node, GlobalName(id=func.name, def_id=func.id)) + ) + # Make a closure by partially applying the `self` argument + # TODO: Try to infer some type args based on `self` + result_ty = FunctionType( + func.ty.inputs[1:], + func.ty.output, + func.ty.input_names[1:] if func.ty.input_names else None, + func.ty.params, + ) + return with_loc(node, PartialApply(func=name, args=[node.value])), result_ty raise GuppyTypeError( f"Expression of type `{ty}` has no attribute `{node.attr}`", # Unfortunately, `node.attr` doesn't contain source annotations, so we have @@ -517,6 +537,12 @@ def visit_Call(self, node: ast.Call) -> tuple[ast.expr, Type]: if isinstance(defn, CallableDef): return defn.synthesize_call(node.args, node, self.ctx) + # When calling a `PartialApply` node, we just move the args into this call + if isinstance(node.func, PartialApply): + node.args = [*node.func.args, *node.args] + node.func = node.func.func + return self.visit_Call(node) + # Otherwise, it must be a function as a higher-order value, or a tensor if isinstance(ty, FunctionType): args, return_ty, inst = synthesize_call(ty, node.args, node, self.ctx) diff --git a/guppylang/checker/linearity_checker.py b/guppylang/checker/linearity_checker.py index 534246c3..d25550d6 100644 --- a/guppylang/checker/linearity_checker.py +++ b/guppylang/checker/linearity_checker.py @@ -30,6 +30,7 @@ GlobalCall, InoutReturnSentinel, LocalCall, + PartialApply, PlaceNode, TensorCall, ) @@ -199,6 +200,19 @@ def visit_TensorCall(self, node: TensorCall) -> None: self.visit(arg) self._reassign_inout_args(node.tensor_ty, node.args) + def visit_PartialApply(self, node: PartialApply) -> None: + self.visit(node.func) + for arg in node.args: + ty = get_type(arg) + if ty.linear: + raise GuppyError( + f"Capturing a value with linear type `{ty}` in a closure is not " + "allowed. Try calling the function directly instead of using it as " + "a higher-order value.", + node, + ) + self.visit(arg) + def visit_FieldAccessAndDrop(self, node: FieldAccessAndDrop) -> None: # A field access on a value that is not a place. This means the value can no # longer be accessed after the field has been projected out. Thus, this is only diff --git a/guppylang/compiler/expr_compiler.py b/guppylang/compiler/expr_compiler.py index 207c61d8..58d701d7 100644 --- a/guppylang/compiler/expr_compiler.py +++ b/guppylang/compiler/expr_compiler.py @@ -28,6 +28,7 @@ GlobalCall, GlobalName, LocalCall, + PartialApply, PlaceNode, ResultExpr, TensorCall, @@ -324,6 +325,16 @@ def visit_GlobalCall(self, node: GlobalCall) -> Wire: def visit_Call(self, node: ast.Call) -> Wire: raise InternalGuppyError("Node should have been removed during type checking.") + def visit_PartialApply(self, node: PartialApply) -> Wire: + from guppylang.compiler.func_compiler import make_partial_op + + func_ty = get_type(node.func) + assert isinstance(func_ty, FunctionType) + op = make_partial_op(func_ty, [get_type(arg) for arg in node.args]) + return self.builder.add_op( + op, self.visit(node.func), *(self.visit(arg) for arg in node.args) + ) + def visit_TypeApply(self, node: TypeApply) -> Wire: # For now, we can only TypeApply global FunctionDefs/Decls. if not isinstance(node.value, GlobalName): diff --git a/guppylang/nodes.py b/guppylang/nodes.py index a79b41d4..197738a8 100644 --- a/guppylang/nodes.py +++ b/guppylang/nodes.py @@ -80,6 +80,22 @@ class TypeApply(ast.expr): ) +class PartialApply(ast.expr): + """A partial function application. + + This node is emitted when methods are loaded as values, since this requires + partially applying the `self` argument. + """ + + func: ast.expr + args: list[ast.expr] + + _fields = ( + "func", + "args", + ) + + class FieldAccessAndDrop(ast.expr): """A field access on a struct, dropping all the remaining other fields.""" diff --git a/tests/error/linear_errors/method_capture.err b/tests/error/linear_errors/method_capture.err new file mode 100644 index 00000000..6a1d5416 --- /dev/null +++ b/tests/error/linear_errors/method_capture.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:22 + +20: @guppy(module) +21: def foo(s: Struct) -> Struct: +22: f = s.foo + ^^^^^ +GuppyError: Capturing a value with linear type `Struct` in a closure is not allowed. Try calling the function directly instead of using it as a higher-order value. diff --git a/tests/error/linear_errors/method_capture.py b/tests/error/linear_errors/method_capture.py new file mode 100644 index 00000000..191777ea --- /dev/null +++ b/tests/error/linear_errors/method_capture.py @@ -0,0 +1,26 @@ +import guppylang.prelude.quantum as quantum +from guppylang.decorator import guppy +from guppylang.module import GuppyModule +from guppylang.prelude.quantum import qubit + + +module = GuppyModule("test") +module.load(quantum) + + +@guppy.struct(module) +class Struct: + q: qubit + + @guppy(module) + def foo(self: "Struct") -> "Struct": + return self + + +@guppy(module) +def foo(s: Struct) -> Struct: + f = s.foo + return f() + + +module.compile() diff --git a/tests/integration/test_call.py b/tests/integration/test_call.py index 82fe8c5d..79664bfa 100644 --- a/tests/integration/test_call.py +++ b/tests/integration/test_call.py @@ -66,3 +66,14 @@ def bar(x: int) -> int: return y validate(module.compile()) + + +def test_method_call(validate): + module = GuppyModule("module") + + @guppy(module) + def foo(x: int) -> int: + return x.__add__(2) + + validate(module.compile()) + diff --git a/tests/integration/test_higher_order.py b/tests/integration/test_higher_order.py index 7df447fe..cfd78b8d 100644 --- a/tests/integration/test_higher_order.py +++ b/tests/integration/test_higher_order.py @@ -55,6 +55,17 @@ def baz(y: int) -> None: validate(module.compile()) +def test_method(validate): + module = GuppyModule("module") + + @guppy(module) + def foo(x: int) -> tuple[int, Callable[[int], int]]: + f = x.__add__ + return f(1), f + + validate(module.compile()) + + def test_nested(validate): @compile_guppy def foo(x: int) -> Callable[[int], bool]: diff --git a/tests/integration/test_struct.py b/tests/integration/test_struct.py index a048a274..b1e64bab 100644 --- a/tests/integration/test_struct.py +++ b/tests/integration/test_struct.py @@ -108,6 +108,33 @@ def main(a: StructA[StructA[float]], b: StructB[bool, int], c: StructC) -> None: validate(module.compile()) +def test_methods(validate): + module = GuppyModule("module") + + @guppy.struct(module) + class StructA: + x: int + + @guppy(module) + def foo(self: "StructA", y: int) -> int: + return 2 * self.x + y + + @guppy.struct(module) + class StructB: + x: int + y: float + + @guppy(module) + def bar(self: "StructB", a: StructA) -> float: + return a.foo(self.x) + self.y + + @guppy(module) + def main(a: StructA, b: StructB) -> tuple[int, float]: + return a.foo(1), b.bar(a) + + validate(module.compile()) + + def test_higher_order(validate): module = GuppyModule("module") T = guppy.type_var(module, "T")