diff --git a/guppylang/definition/custom.py b/guppylang/definition/custom.py index b2e86370..46fdb915 100644 --- a/guppylang/definition/custom.py +++ b/guppylang/definition/custom.py @@ -97,10 +97,16 @@ def compile_call( @dataclass(frozen=True) -class CustomFunctionDef(RawCustomFunctionDef, CompiledCallableDef): +class CustomFunctionDef(CompiledCallableDef): """A custom function with parsed and checked signature.""" + defined_at: AstNode + call_checker: "CustomCallChecker" + call_compiler: "CustomCallCompiler" ty: FunctionType + higher_order_value: bool + + description: str = field(default="function", init=False) def check_call( self, args: list[ast.expr], ty: Type, node: AstNode, ctx: Context @@ -163,6 +169,19 @@ def load_with_args( # can load with empty type args return graph.add_load_function(def_node.out_port(0), [], dfg.node).out_port(0) + def compile_call( + self, + args: list[OutPortV], + type_args: Inst, + dfg: DFContainer, + graph: Hugr, + globals: CompiledGlobals, + node: AstNode, + ) -> list[OutPortV]: + """Compiles a call to the function.""" + self.call_compiler._setup(type_args, dfg, graph, globals, node) + return self.call_compiler.compile(args) + class CustomCallChecker(ABC): """Abstract base class for custom function call type checkers.""" diff --git a/guppylang/definition/struct.py b/guppylang/definition/struct.py index 0d5a722a..65996f29 100644 --- a/guppylang/definition/struct.py +++ b/guppylang/definition/struct.py @@ -3,6 +3,7 @@ import textwrap from collections.abc import Sequence from dataclasses import dataclass +from functools import cached_property from typing import Any from guppylang.ast_util import AstNode, annotate_location @@ -14,13 +15,19 @@ Definition, ParsableDef, ) +from guppylang.definition.custom import ( + CustomCallCompiler, + CustomFunctionDef, + DefaultCallChecker, +) from guppylang.definition.parameter import ParamDef from guppylang.definition.ty import TypeDef from guppylang.error import GuppyError, InternalGuppyError +from guppylang.hugr_builder.hugr import OutPortV from guppylang.tys.arg import Argument from guppylang.tys.param import Parameter, check_all_args from guppylang.tys.parsing import type_from_ast -from guppylang.tys.ty import StructType, Type +from guppylang.tys.ty import FunctionType, StructType, Type @dataclass(frozen=True) @@ -186,6 +193,35 @@ def check_instantiate( check_all_args(self.params, args, self.name, loc) return StructType(args, self) + @cached_property + def generated_methods(self) -> list[CustomFunctionDef]: + """Auto-generated methods for this struct.""" + + class ConstructorCompiler(CustomCallCompiler): + """Compiler for the `__new__` constructor method of a struct.""" + + def compile(self, args: list[OutPortV]) -> list[OutPortV]: + return [self.graph.add_make_tuple(args).out_port(0)] + + constructor_sig = FunctionType( + inputs=[f.ty for f in self.fields], + output=StructType( + defn=self, args=[p.to_bound(i) for i, p in enumerate(self.params)] + ), + input_names=[f.name for f in self.fields], + params=self.params, + ) + constructor_def = CustomFunctionDef( + id=DefId.fresh(self.id.module), + name="__new__", + defined_at=self.defined_at, + ty=constructor_sig, + call_checker=DefaultCallChecker(), + call_compiler=ConstructorCompiler(), + higher_order_value=True, + ) + return [constructor_def] + def parse_py_class(cls: type) -> ast.ClassDef: """Parses a Python class object into an AST.""" diff --git a/guppylang/module.py b/guppylang/module.py index 423299ca..28f16cd0 100644 --- a/guppylang/module.py +++ b/guppylang/module.py @@ -20,6 +20,7 @@ from guppylang.definition.declaration import RawFunctionDecl from guppylang.definition.function import RawFunctionDef from guppylang.definition.parameter import ParamDef +from guppylang.definition.struct import CheckedStructDef from guppylang.definition.ty import TypeDef from guppylang.error import GuppyError, pretty_errors from guppylang.hugr_builder.hugr import Hugr @@ -180,9 +181,18 @@ def compile(self) -> Hugr: ) self._globals = self._globals.update_defs(type_defs) + # Collect auto-generated methods + generated: dict[DefId, RawDef] = {} + for defn in type_defs.values(): + if isinstance(defn, CheckedStructDef): + self._globals.impls.setdefault(defn.id, {}) + for method_def in defn.generated_methods: + generated[method_def.id] = method_def + self._globals.impls[defn.id][method_def.name] = method_def.id + # Now, we can check all other definitions other_defs = self._check_defs( - self._raw_defs, self._imported_globals | self._globals + self._raw_defs | generated, self._imported_globals | self._globals ) self._globals = self._globals.update_defs(other_defs) diff --git a/tests/error/struct_errors/constructor_arg_mismatch.err b/tests/error/struct_errors/constructor_arg_mismatch.err new file mode 100644 index 00000000..fffbb46f --- /dev/null +++ b/tests/error/struct_errors/constructor_arg_mismatch.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:15 + +13: @guppy(module) +14: def main() -> None: +15: MyStruct(0) + ^ +GuppyTypeError: Expected argument of type `(int, int)`, got `int` diff --git a/tests/error/struct_errors/constructor_arg_mismatch.py b/tests/error/struct_errors/constructor_arg_mismatch.py new file mode 100644 index 00000000..45aec0fd --- /dev/null +++ b/tests/error/struct_errors/constructor_arg_mismatch.py @@ -0,0 +1,18 @@ +from guppylang.decorator import guppy +from guppylang.module import GuppyModule + + +module = GuppyModule("test") + + +@guppy.struct(module) +class MyStruct: + x: tuple[int, int] + + +@guppy(module) +def main() -> None: + MyStruct(0) + + +module.compile() diff --git a/tests/error/struct_errors/constructor_arg_mismatch_poly.err b/tests/error/struct_errors/constructor_arg_mismatch_poly.err new file mode 100644 index 00000000..15308959 --- /dev/null +++ b/tests/error/struct_errors/constructor_arg_mismatch_poly.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:19 + +17: @guppy(module) +18: def main() -> None: +19: MyStruct(0, False) + ^^^^^ +GuppyTypeError: Expected argument of type `int`, got `bool` diff --git a/tests/error/struct_errors/constructor_arg_mismatch_poly.py b/tests/error/struct_errors/constructor_arg_mismatch_poly.py new file mode 100644 index 00000000..8e2a609c --- /dev/null +++ b/tests/error/struct_errors/constructor_arg_mismatch_poly.py @@ -0,0 +1,22 @@ +from typing import Generic + +from guppylang.decorator import guppy +from guppylang.module import GuppyModule + + +module = GuppyModule("test") +T = guppy.type_var(module, "T") + + +@guppy.struct(module) +class MyStruct(Generic[T]): + x: T + y: T + + +@guppy(module) +def main() -> None: + MyStruct(0, False) + + +module.compile() diff --git a/tests/error/struct_errors/constructor_missing_arg.err b/tests/error/struct_errors/constructor_missing_arg.err new file mode 100644 index 00000000..cae3f3b7 --- /dev/null +++ b/tests/error/struct_errors/constructor_missing_arg.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:15 + +13: @guppy(module) +14: def main() -> None: +15: MyStruct() + ^^^^^^^^^^ +GuppyTypeError: Not enough arguments passed (expected 1, got 0) diff --git a/tests/error/struct_errors/constructor_missing_arg.py b/tests/error/struct_errors/constructor_missing_arg.py new file mode 100644 index 00000000..9297305a --- /dev/null +++ b/tests/error/struct_errors/constructor_missing_arg.py @@ -0,0 +1,18 @@ +from guppylang.decorator import guppy +from guppylang.module import GuppyModule + + +module = GuppyModule("test") + + +@guppy.struct(module) +class MyStruct: + x: int + + +@guppy(module) +def main() -> None: + MyStruct() + + +module.compile() diff --git a/tests/error/struct_errors/constructor_too_many_args.err b/tests/error/struct_errors/constructor_too_many_args.err new file mode 100644 index 00000000..4483e46d --- /dev/null +++ b/tests/error/struct_errors/constructor_too_many_args.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:15 + +13: @guppy(module) +14: def main() -> None: +15: MyStruct(1, 2, 3) + ^ +GuppyTypeError: Unexpected argument diff --git a/tests/error/struct_errors/constructor_too_many_args.py b/tests/error/struct_errors/constructor_too_many_args.py new file mode 100644 index 00000000..32e24a86 --- /dev/null +++ b/tests/error/struct_errors/constructor_too_many_args.py @@ -0,0 +1,18 @@ +from guppylang.decorator import guppy +from guppylang.module import GuppyModule + + +module = GuppyModule("test") + + +@guppy.struct(module) +class MyStruct: + x: int + + +@guppy(module) +def main() -> None: + MyStruct(1, 2, 3) + + +module.compile() diff --git a/tests/integration/test_struct.py b/tests/integration/test_struct.py index a071ca6f..04b8d9ea 100644 --- a/tests/integration/test_struct.py +++ b/tests/integration/test_struct.py @@ -1,8 +1,11 @@ -from typing import Generic +from typing import Generic, TYPE_CHECKING from guppylang.decorator import guppy from guppylang.module import GuppyModule +if TYPE_CHECKING: + from collections.abc import Callable + def test_basic_defs(validate): module = GuppyModule("module") @@ -30,7 +33,10 @@ class DocstringStruct: def main( a: EmptyStruct, b: OneMemberStruct, c: TwoMemberStruct, d: DocstringStruct ) -> None: - pass + EmptyStruct() + OneMemberStruct(42) + TwoMemberStruct((True, 0), 1.0) + DocstringStruct(-1) validate(module.compile()) @@ -48,7 +54,7 @@ class StructB: @guppy(module) def main(a: StructA, b: StructB) -> None: - pass + StructB(a) validate(module.compile()) @@ -66,7 +72,7 @@ class StructB: @guppy(module) def main(a: StructA, b: StructB) -> None: - pass + StructA(b) validate(module.compile()) @@ -92,7 +98,30 @@ class StructB(Generic[S, T]): y: StructA[T] @guppy(module) - def main(a: StructA[StructA[float]], b: StructB[int, bool], c: StructC) -> None: - pass + def main(a: StructA[StructA[float]], b: StructB[bool, int], c: StructC) -> None: + x = StructA((0, False)) + y = StructA((0, -5)) + StructA((0, x)) + StructB(x, a) + StructC(y, StructA((0, [])), StructB(42.0, StructA((4, b)))) + + validate(module.compile()) + + +def test_higher_order(validate): + module = GuppyModule("module") + T = guppy.type_var(module, "T") + + @guppy.struct(module) + class Struct(Generic[T]): + x: T + + @guppy(module) + def factory(mk_struct: "Callable[[int], Struct[int]]", x: int) -> Struct[int]: + return mk_struct(x) + + @guppy(module) + def main() -> None: + factory(Struct, 42) validate(module.compile())