Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: Define arithmetic methods via Guppy source instead of custom compilers #703

Merged
merged 2 commits into from
Dec 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 30 additions & 6 deletions guppylang/checker/expr_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@
FunctionType,
InputFlags,
NoneType,
NumericType,
OpaqueType,
StructType,
TupleType,
Expand Down Expand Up @@ -207,7 +208,7 @@ def check(
# If we already have a type for the expression, we just have to match it against
# the target
if actual := get_type_opt(expr):
subst, inst = check_type_against(actual, ty, expr, kind)
expr, subst, inst = check_type_against(actual, ty, expr, self.ctx, kind)
if inst:
expr = with_loc(expr, TypeApply(value=expr, tys=inst))
return with_type(ty.substitute(subst), expr), subst
Expand Down Expand Up @@ -329,7 +330,7 @@ def visit_PyExpr(self, node: PyExpr, ty: Type) -> tuple[ast.expr, Subst]:
def generic_visit(self, node: ast.expr, ty: Type) -> tuple[ast.expr, Subst]:
# Try to synthesize and then check if we can unify it with the given type
node, synth = self._synthesize(node, allow_free_vars=False)
subst, inst = check_type_against(synth, ty, node, self._kind)
node, subst, inst = check_type_against(synth, ty, node, self.ctx, self._kind)

# Apply instantiation of quantified type variables
if inst:
Expand Down Expand Up @@ -759,8 +760,8 @@ def generic_visit(self, node: ast.expr) -> NoReturn:


def check_type_against(
act: Type, exp: Type, node: AstNode, kind: str = "expression"
) -> tuple[Subst, Inst]:
act: Type, exp: Type, node: ast.expr, ctx: Context, kind: str = "expression"
) -> tuple[ast.expr, Subst, Inst]:
"""Checks a type against another type.

Returns a substitution for the free variables the expected type and an instantiation
Expand Down Expand Up @@ -797,14 +798,37 @@ def check_type_against(
# Finally, check that the instantiation respects the linearity requirements
check_inst(act, inst, node)

return subst, inst
return node, subst, inst

# Otherwise, we know that `act` has no unsolved type vars, so unification is trivial
assert not act.unsolved_vars
subst = unify(exp, act, {})
if subst is None:
# Maybe we can implicitly coerce `act` to `exp`
if coerced := try_coerce_to(act, exp, node, ctx):
return coerced, {}, []
raise GuppyTypeError(TypeMismatchError(node, exp, act, kind))
return subst, []
return node, subst, []


def try_coerce_to(
act: Type, exp: Type, node: ast.expr, ctx: Context
) -> ast.expr | None:
"""Tries to implicitly coerce an expression to a different type.

Returns the coerced expression or `None` if the type cannot be implicitly coerced.
"""
# Currently, we only support implicit coercions of numeric types
if not isinstance(act, NumericType) or not isinstance(exp, NumericType):
return None
# Ordering on `NumericType.Kind` defines the coercion relation
if act.kind < exp.kind:
f = ctx.globals.get_instance_func(act, f"__{exp.kind.name.lower()}__")
assert f is not None
node, subst = f.check_call([node], exp, node, ctx)
assert len(subst) == 0, "Coercion methods are not generic"
return node
return None


def check_num_args(
Expand Down
2 changes: 1 addition & 1 deletion guppylang/definition/custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,7 @@ def check(self, args: list[ast.expr], ty: Type) -> tuple[ast.expr, Subst]:
from guppylang.checker.expr_checker import check_type_against

expr, res_ty = self.synthesize(args)
subst, _ = check_type_against(res_ty, ty, self.node)
expr, subst, _ = check_type_against(res_ty, ty, expr, self.ctx)
return expr, subst

@abstractmethod
Expand Down
59 changes: 23 additions & 36 deletions guppylang/std/_internal/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from typing_extensions import assert_never

from guppylang.ast_util import AstNode, with_loc, with_type
from guppylang.ast_util import with_loc, with_type
from guppylang.checker.core import Context
from guppylang.checker.errors.generic import ExpectedError, UnsupportedError
from guppylang.checker.errors.type_errors import (
Expand All @@ -22,8 +22,6 @@
)
from guppylang.definition.custom import (
CustomCallChecker,
CustomFunctionDef,
DefaultCallChecker,
)
from guppylang.definition.struct import CheckedStructDef, RawStructDef
from guppylang.diagnostic import Error, Note
Expand Down Expand Up @@ -60,42 +58,28 @@
)


class CoercingChecker(DefaultCallChecker):
"""Function call type checker that automatically coerces arguments to float."""

def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, Type]:
for i in range(len(args)):
args[i], ty = ExprSynthesizer(self.ctx).synthesize(args[i])
if isinstance(ty, NumericType) and ty.kind != NumericType.Kind.Float:
to_float = self.ctx.globals.get_instance_func(ty, "__float__")
assert to_float is not None
args[i], _ = to_float.synthesize_call([args[i]], self.node, self.ctx)
return super().synthesize(args)


class ReversingChecker(CustomCallChecker):
"""Call checker that reverses the arguments after checking."""

base_checker: CustomCallChecker
"""Call checker for reverse arithmetic methods.

def __init__(self, base_checker: CustomCallChecker | None = None):
self.base_checker = base_checker or DefaultCallChecker()
For examples, turns a call to `__radd__` into a call to `__add__` with reversed
arguments.
"""

def _setup(self, ctx: Context, node: AstNode, func: CustomFunctionDef) -> None:
super()._setup(ctx, node, func)
self.base_checker._setup(ctx, node, func)

def check(self, args: list[ast.expr], ty: Type) -> tuple[ast.expr, Subst]:
expr, subst = self.base_checker.check(args, ty)
if isinstance(expr, GlobalCall):
expr.args = list(reversed(expr.args))
return expr, subst
def parse_name(self) -> str:
# Must be a dunder method
assert self.func.name.startswith("__")
assert self.func.name.endswith("__")
name = self.func.name[2:-2]
# Remove the `r`
assert name.startswith("r")
return f"__{name[1:]}__"

def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, Type]:
expr, ty = self.base_checker.synthesize(args)
if isinstance(expr, GlobalCall):
expr.args = list(reversed(expr.args))
return expr, ty
[self_arg, other_arg] = args
self_arg, self_ty = ExprSynthesizer(self.ctx).synthesize(self_arg)
f = self.ctx.globals.get_instance_func(self_ty, self.parse_name())
assert f is not None
return f.synthesize_call([other_arg, self_arg], self.node, self.ctx)


class UnsupportedChecker(CustomCallChecker):
Expand Down Expand Up @@ -232,7 +216,10 @@ def check(self, args: list[ast.expr], ty: Type) -> tuple[ast.expr, Subst]:
# TODO: We could use the type information to infer some stuff
# in the comprehension
arr_compr, res_ty = self.synthesize_array_comprehension(compr)
subst, _ = check_type_against(res_ty, ty, self.node)
arr_compr = with_loc(self.node, arr_compr)
arr_compr, subst, _ = check_type_against(
res_ty, ty, arr_compr, self.ctx
)
return arr_compr, subst
# Or a list of array elements
case args:
Expand Down Expand Up @@ -359,7 +346,7 @@ def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, Type]:

def check(self, args: list[ast.expr], ty: Type) -> tuple[ast.expr, Subst]:
expr, res_ty = self.synthesize(args)
subst, _ = check_type_against(res_ty, ty, self.node)
expr, subst, _ = check_type_against(res_ty, ty, expr, self.ctx)
return expr, subst

@staticmethod
Expand Down
Loading
Loading