Skip to content

Commit

Permalink
feat: Generate constructor methods for structs (#262)
Browse files Browse the repository at this point in the history
Closes #261
  • Loading branch information
mark-koch authored Jun 25, 2024
1 parent d15b2f5 commit f68d0af
Show file tree
Hide file tree
Showing 12 changed files with 207 additions and 9 deletions.
21 changes: 20 additions & 1 deletion guppylang/definition/custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,10 +97,16 @@ def compile_call(


@dataclass(frozen=True)
class CustomFunctionDef(RawCustomFunctionDef, CompiledCallableDef):
class CustomFunctionDef(CompiledCallableDef):
"""A custom function with parsed and checked signature."""

defined_at: AstNode
call_checker: "CustomCallChecker"
call_compiler: "CustomCallCompiler"
ty: FunctionType
higher_order_value: bool

description: str = field(default="function", init=False)

def check_call(
self, args: list[ast.expr], ty: Type, node: AstNode, ctx: Context
Expand Down Expand Up @@ -163,6 +169,19 @@ def load_with_args(
# can load with empty type args
return graph.add_load_function(def_node.out_port(0), [], dfg.node).out_port(0)

def compile_call(
self,
args: list[OutPortV],
type_args: Inst,
dfg: DFContainer,
graph: Hugr,
globals: CompiledGlobals,
node: AstNode,
) -> list[OutPortV]:
"""Compiles a call to the function."""
self.call_compiler._setup(type_args, dfg, graph, globals, node)
return self.call_compiler.compile(args)


class CustomCallChecker(ABC):
"""Abstract base class for custom function call type checkers."""
Expand Down
38 changes: 37 additions & 1 deletion guppylang/definition/struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import textwrap
from collections.abc import Sequence
from dataclasses import dataclass
from functools import cached_property
from typing import Any

from guppylang.ast_util import AstNode, annotate_location
Expand All @@ -14,13 +15,19 @@
Definition,
ParsableDef,
)
from guppylang.definition.custom import (
CustomCallCompiler,
CustomFunctionDef,
DefaultCallChecker,
)
from guppylang.definition.parameter import ParamDef
from guppylang.definition.ty import TypeDef
from guppylang.error import GuppyError, InternalGuppyError
from guppylang.hugr_builder.hugr import OutPortV
from guppylang.tys.arg import Argument
from guppylang.tys.param import Parameter, check_all_args
from guppylang.tys.parsing import type_from_ast
from guppylang.tys.ty import StructType, Type
from guppylang.tys.ty import FunctionType, StructType, Type


@dataclass(frozen=True)
Expand Down Expand Up @@ -186,6 +193,35 @@ def check_instantiate(
check_all_args(self.params, args, self.name, loc)
return StructType(args, self)

@cached_property
def generated_methods(self) -> list[CustomFunctionDef]:
"""Auto-generated methods for this struct."""

class ConstructorCompiler(CustomCallCompiler):
"""Compiler for the `__new__` constructor method of a struct."""

def compile(self, args: list[OutPortV]) -> list[OutPortV]:
return [self.graph.add_make_tuple(args).out_port(0)]

constructor_sig = FunctionType(
inputs=[f.ty for f in self.fields],
output=StructType(
defn=self, args=[p.to_bound(i) for i, p in enumerate(self.params)]
),
input_names=[f.name for f in self.fields],
params=self.params,
)
constructor_def = CustomFunctionDef(
id=DefId.fresh(self.id.module),
name="__new__",
defined_at=self.defined_at,
ty=constructor_sig,
call_checker=DefaultCallChecker(),
call_compiler=ConstructorCompiler(),
higher_order_value=True,
)
return [constructor_def]


def parse_py_class(cls: type) -> ast.ClassDef:
"""Parses a Python class object into an AST."""
Expand Down
12 changes: 11 additions & 1 deletion guppylang/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from guppylang.definition.declaration import RawFunctionDecl
from guppylang.definition.function import RawFunctionDef
from guppylang.definition.parameter import ParamDef
from guppylang.definition.struct import CheckedStructDef
from guppylang.definition.ty import TypeDef
from guppylang.error import GuppyError, pretty_errors
from guppylang.hugr_builder.hugr import Hugr
Expand Down Expand Up @@ -180,9 +181,18 @@ def compile(self) -> Hugr:
)
self._globals = self._globals.update_defs(type_defs)

# Collect auto-generated methods
generated: dict[DefId, RawDef] = {}
for defn in type_defs.values():
if isinstance(defn, CheckedStructDef):
self._globals.impls.setdefault(defn.id, {})
for method_def in defn.generated_methods:
generated[method_def.id] = method_def
self._globals.impls[defn.id][method_def.name] = method_def.id

# Now, we can check all other definitions
other_defs = self._check_defs(
self._raw_defs, self._imported_globals | self._globals
self._raw_defs | generated, self._imported_globals | self._globals
)
self._globals = self._globals.update_defs(other_defs)

Expand Down
7 changes: 7 additions & 0 deletions tests/error/struct_errors/constructor_arg_mismatch.err
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
Guppy compilation failed. Error in file $FILE:15

13: @guppy(module)
14: def main() -> None:
15: MyStruct(0)
^
GuppyTypeError: Expected argument of type `(int, int)`, got `int`
18 changes: 18 additions & 0 deletions tests/error/struct_errors/constructor_arg_mismatch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from guppylang.decorator import guppy
from guppylang.module import GuppyModule


module = GuppyModule("test")


@guppy.struct(module)
class MyStruct:
x: tuple[int, int]


@guppy(module)
def main() -> None:
MyStruct(0)


module.compile()
7 changes: 7 additions & 0 deletions tests/error/struct_errors/constructor_arg_mismatch_poly.err
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
Guppy compilation failed. Error in file $FILE:19

17: @guppy(module)
18: def main() -> None:
19: MyStruct(0, False)
^^^^^
GuppyTypeError: Expected argument of type `int`, got `bool`
22 changes: 22 additions & 0 deletions tests/error/struct_errors/constructor_arg_mismatch_poly.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from typing import Generic

from guppylang.decorator import guppy
from guppylang.module import GuppyModule


module = GuppyModule("test")
T = guppy.type_var(module, "T")


@guppy.struct(module)
class MyStruct(Generic[T]):
x: T
y: T


@guppy(module)
def main() -> None:
MyStruct(0, False)


module.compile()
7 changes: 7 additions & 0 deletions tests/error/struct_errors/constructor_missing_arg.err
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
Guppy compilation failed. Error in file $FILE:15

13: @guppy(module)
14: def main() -> None:
15: MyStruct()
^^^^^^^^^^
GuppyTypeError: Not enough arguments passed (expected 1, got 0)
18 changes: 18 additions & 0 deletions tests/error/struct_errors/constructor_missing_arg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from guppylang.decorator import guppy
from guppylang.module import GuppyModule


module = GuppyModule("test")


@guppy.struct(module)
class MyStruct:
x: int


@guppy(module)
def main() -> None:
MyStruct()


module.compile()
7 changes: 7 additions & 0 deletions tests/error/struct_errors/constructor_too_many_args.err
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
Guppy compilation failed. Error in file $FILE:15

13: @guppy(module)
14: def main() -> None:
15: MyStruct(1, 2, 3)
^
GuppyTypeError: Unexpected argument
18 changes: 18 additions & 0 deletions tests/error/struct_errors/constructor_too_many_args.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from guppylang.decorator import guppy
from guppylang.module import GuppyModule


module = GuppyModule("test")


@guppy.struct(module)
class MyStruct:
x: int


@guppy(module)
def main() -> None:
MyStruct(1, 2, 3)


module.compile()
41 changes: 35 additions & 6 deletions tests/integration/test_struct.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from typing import Generic
from typing import Generic, TYPE_CHECKING

from guppylang.decorator import guppy
from guppylang.module import GuppyModule

if TYPE_CHECKING:
from collections.abc import Callable


def test_basic_defs(validate):
module = GuppyModule("module")
Expand Down Expand Up @@ -30,7 +33,10 @@ class DocstringStruct:
def main(
a: EmptyStruct, b: OneMemberStruct, c: TwoMemberStruct, d: DocstringStruct
) -> None:
pass
EmptyStruct()
OneMemberStruct(42)
TwoMemberStruct((True, 0), 1.0)
DocstringStruct(-1)

validate(module.compile())

Expand All @@ -48,7 +54,7 @@ class StructB:

@guppy(module)
def main(a: StructA, b: StructB) -> None:
pass
StructB(a)

validate(module.compile())

Expand All @@ -66,7 +72,7 @@ class StructB:

@guppy(module)
def main(a: StructA, b: StructB) -> None:
pass
StructA(b)

validate(module.compile())

Expand All @@ -92,7 +98,30 @@ class StructB(Generic[S, T]):
y: StructA[T]

@guppy(module)
def main(a: StructA[StructA[float]], b: StructB[int, bool], c: StructC) -> None:
pass
def main(a: StructA[StructA[float]], b: StructB[bool, int], c: StructC) -> None:
x = StructA((0, False))
y = StructA((0, -5))
StructA((0, x))
StructB(x, a)
StructC(y, StructA((0, [])), StructB(42.0, StructA((4, b))))

validate(module.compile())


def test_higher_order(validate):
module = GuppyModule("module")
T = guppy.type_var(module, "T")

@guppy.struct(module)
class Struct(Generic[T]):
x: T

@guppy(module)
def factory(mk_struct: "Callable[[int], Struct[int]]", x: int) -> Struct[int]:
return mk_struct(x)

@guppy(module)
def main() -> None:
factory(Struct, 42)

validate(module.compile())

0 comments on commit f68d0af

Please sign in to comment.