Skip to content

Commit

Permalink
feat: Implicit coercion of numeric types
Browse files Browse the repository at this point in the history
  • Loading branch information
mark-koch committed Dec 10, 2024
1 parent 45ea6b7 commit eca9dc6
Show file tree
Hide file tree
Showing 6 changed files with 56 additions and 14 deletions.
36 changes: 30 additions & 6 deletions guppylang/checker/expr_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@
FunctionType,
InputFlags,
NoneType,
NumericType,
OpaqueType,
StructType,
TupleType,
Expand Down Expand Up @@ -207,7 +208,7 @@ def check(
# If we already have a type for the expression, we just have to match it against
# the target
if actual := get_type_opt(expr):
subst, inst = check_type_against(actual, ty, expr, kind)
expr, subst, inst = check_type_against(actual, ty, expr, self.ctx, kind)
if inst:
expr = with_loc(expr, TypeApply(value=expr, tys=inst))
return with_type(ty.substitute(subst), expr), subst
Expand Down Expand Up @@ -329,7 +330,7 @@ def visit_PyExpr(self, node: PyExpr, ty: Type) -> tuple[ast.expr, Subst]:
def generic_visit(self, node: ast.expr, ty: Type) -> tuple[ast.expr, Subst]:
# Try to synthesize and then check if we can unify it with the given type
node, synth = self._synthesize(node, allow_free_vars=False)
subst, inst = check_type_against(synth, ty, node, self._kind)
node, subst, inst = check_type_against(synth, ty, node, self.ctx, self._kind)

# Apply instantiation of quantified type variables
if inst:
Expand Down Expand Up @@ -759,8 +760,8 @@ def generic_visit(self, node: ast.expr) -> NoReturn:


def check_type_against(
act: Type, exp: Type, node: AstNode, kind: str = "expression"
) -> tuple[Subst, Inst]:
act: Type, exp: Type, node: ast.expr, ctx: Context, kind: str = "expression"
) -> tuple[ast.expr, Subst, Inst]:
"""Checks a type against another type.
Returns a substitution for the free variables the expected type and an instantiation
Expand Down Expand Up @@ -797,14 +798,37 @@ def check_type_against(
# Finally, check that the instantiation respects the linearity requirements
check_inst(act, inst, node)

return subst, inst
return node, subst, inst

# Otherwise, we know that `act` has no unsolved type vars, so unification is trivial
assert not act.unsolved_vars
subst = unify(exp, act, {})
if subst is None:
# Maybe we can implicitly coerce `act` to `exp`
if coerced := try_coerce_to(act, exp, node, ctx):
return coerced, {}, []
raise GuppyTypeError(TypeMismatchError(node, exp, act, kind))
return subst, []
return node, subst, []


def try_coerce_to(
act: Type, exp: Type, node: ast.expr, ctx: Context
) -> ast.expr | None:
"""Tries to implicitly coerce an expression to a different type.
Returns the coerced expression or `None` if the type cannot be implicitly coerced.
"""
# Currently, we only support implicit coercions of numeric types
if not isinstance(act, NumericType) or not isinstance(exp, NumericType):
return None
# Ordering on `NumericType.Kind` defines the coercion relation
if act.kind < exp.kind:
f = ctx.globals.get_instance_func(act, f"__{exp.kind.name.lower()}__")
assert f is not None
node, subst = f.check_call([node], exp, node, ctx)
assert len(subst) == 0, "Coercion methods are not generic"
return node
return None


def check_num_args(
Expand Down
2 changes: 1 addition & 1 deletion guppylang/definition/custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,7 @@ def check(self, args: list[ast.expr], ty: Type) -> tuple[ast.expr, Subst]:
from guppylang.checker.expr_checker import check_type_against

expr, res_ty = self.synthesize(args)
subst, _ = check_type_against(res_ty, ty, self.node)
expr, subst, _ = check_type_against(res_ty, ty, expr, self.ctx)
return expr, subst

@abstractmethod
Expand Down
7 changes: 5 additions & 2 deletions guppylang/std/_internal/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,10 @@ def check(self, args: list[ast.expr], ty: Type) -> tuple[ast.expr, Subst]:
# TODO: We could use the type information to infer some stuff
# in the comprehension
arr_compr, res_ty = self.synthesize_array_comprehension(compr)
subst, _ = check_type_against(res_ty, ty, self.node)
arr_compr = with_loc(self.node, arr_compr)
arr_compr, subst, _ = check_type_against(
res_ty, ty, arr_compr, self.ctx
)
return arr_compr, subst
# Or a list of array elements
case args:
Expand Down Expand Up @@ -359,7 +362,7 @@ def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, Type]:

def check(self, args: list[ast.expr], ty: Type) -> tuple[ast.expr, Subst]:
expr, res_ty = self.synthesize(args)
subst, _ = check_type_against(res_ty, ty, self.node)
expr, subst, _ = check_type_against(res_ty, ty, expr, self.ctx)
return expr, subst

@staticmethod
Expand Down
2 changes: 1 addition & 1 deletion guppylang/tys/printing.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def _visit_NoneType(self, ty: NoneType, inside_row: bool) -> str:

@_visit.register
def _visit_NumericType(self, ty: NumericType, inside_row: bool) -> str:
return ty.kind.value
return ty.kind.name.lower()

@_visit.register
def _visit_TypeParam(self, param: TypeParam, inside_row: bool) -> str:
Expand Down
12 changes: 8 additions & 4 deletions guppylang/tys/ty.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from collections.abc import Sequence
from dataclasses import dataclass, field
from enum import Enum, Flag, auto
from functools import cached_property
from functools import cached_property, total_ordering
from typing import TYPE_CHECKING, ClassVar, TypeAlias, cast

import hugr.std.float
Expand Down Expand Up @@ -259,12 +259,16 @@ class NumericType(TypeBase):

kind: "Kind"

@total_ordering
class Kind(Enum):
"""The different kinds of numeric types."""

Nat = "nat"
Int = "int"
Float = "float"
Nat = auto()
Int = auto()
Float = auto()

def __lt__(self, other: "NumericType.Kind") -> bool:
return self.value < other.value

INT_WIDTH: ClassVar[int] = 6

Expand Down
11 changes: 11 additions & 0 deletions tests/integration/test_arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,17 @@ def main(a1: angle, a2: angle) -> bool:
validate(module.compile())


def test_implicit_coercion(validate):
@compile_guppy
def coerce(x: nat) -> float:
y: int = x
z: float = y
a: float = 1
return z + a

validate(coerce)


def test_angle_float_coercion(validate):
module = GuppyModule("test")
module.load(angle)
Expand Down

0 comments on commit eca9dc6

Please sign in to comment.