diff --git a/guppy/checker/core.py b/guppy/checker/core.py index 25dd256d..600c3669 100644 --- a/guppy/checker/core.py +++ b/guppy/checker/core.py @@ -9,6 +9,7 @@ FunctionType, GuppyType, NoneType, + Subst, SumType, TupleType, ) @@ -33,7 +34,7 @@ class CallableVariable(ABC, Variable): @abstractmethod def check_call( self, args: list[ast.expr], ty: GuppyType, node: AstNode, ctx: "Context" - ) -> ast.expr: + ) -> tuple[ast.expr, Subst]: """Checks the return type of a function call against a given type.""" @abstractmethod @@ -43,6 +44,14 @@ def synthesize_call( """Synthesizes the return type of a function call.""" +@dataclass +class TypeVarDecl: + """A declared type variable.""" + + name: str + linear: bool + + class Globals(NamedTuple): """Collection of names that are available on module-level. @@ -52,6 +61,7 @@ class Globals(NamedTuple): values: dict[str, Variable] types: dict[str, type[GuppyType]] + type_vars: dict[str, TypeVarDecl] @staticmethod def default() -> "Globals": @@ -63,7 +73,7 @@ def default() -> "Globals": NoneType.name: NoneType, BoolType.name: BoolType, } - return Globals({}, tys) + return Globals({}, tys, {}) def get_instance_func(self, ty: GuppyType, name: str) -> CallableVariable | None: """Looks up an instance function with a given name for a type. @@ -81,11 +91,13 @@ def __or__(self, other: "Globals") -> "Globals": return Globals( self.values | other.values, self.types | other.types, + self.type_vars | other.type_vars, ) def __ior__(self, other: "Globals") -> "Globals": # noqa: PYI034 self.values.update(other.values) self.types.update(other.types) + self.type_vars.update(other.type_vars) return self diff --git a/guppy/checker/expr_checker.py b/guppy/checker/expr_checker.py index 731d8f7c..9e799399 100644 --- a/guppy/checker/expr_checker.py +++ b/guppy/checker/expr_checker.py @@ -26,9 +26,23 @@ from guppy.ast_util import AstNode, AstVisitor, get_type_opt, with_loc, with_type from guppy.checker.core import CallableVariable, Context, Globals -from guppy.error import GuppyError, GuppyTypeError, InternalGuppyError -from guppy.gtypes import BoolType, FunctionType, GuppyType, TupleType -from guppy.nodes import GlobalName, LocalCall, LocalName +from guppy.error import ( + GuppyError, + GuppyTypeError, + GuppyTypeInferenceError, + InternalGuppyError, +) +from guppy.gtypes import ( + BoolType, + ExistentialTypeVar, + FunctionType, + GuppyType, + Inst, + Subst, + TupleType, + unify, +) +from guppy.nodes import GlobalName, LocalCall, LocalName, TypeApply # Mapping from unary AST op to dunder method and display name unary_table: dict[type[ast.unaryop], tuple[str, str]] = { @@ -62,8 +76,12 @@ } # fmt: skip -class ExprChecker(AstVisitor[ast.expr]): - """Checks an expression against a type and produces a new type-annotated AST""" +class ExprChecker(AstVisitor[tuple[ast.expr, Subst]]): + """Checks an expression against a type and produces a new type-annotated AST. + + The type may contain free variables that the checker will try to solve. Note that + the checker will fail, if some free variables cannot be inferred. + """ ctx: Context @@ -84,7 +102,7 @@ def _fail( """Raises a type error indicating that the type doesn't match.""" if not isinstance(actual, GuppyType): loc = loc or actual - _, actual = self._synthesize(actual) + _, actual = self._synthesize(actual, allow_free_vars=True) if loc is None: raise InternalGuppyError("Failure location is required") raise GuppyTypeError( @@ -93,34 +111,76 @@ def _fail( def check( self, expr: ast.expr, ty: GuppyType, kind: str = "expression" - ) -> ast.expr: + ) -> tuple[ast.expr, Subst]: """Checks an expression against a type. - Returns a new desugared expression with type annotations. + The type may have free type variables which will try to be resolved. Returns + a new desugared expression with type annotations and a substitution with the + resolved type variables. """ + # When checking against a variable, we have to synthesize + if isinstance(ty, ExistentialTypeVar): + expr, syn_ty = self._synthesize(expr, allow_free_vars=False) + return with_type(syn_ty, expr), {ty: syn_ty} + + # Otherwise, invoke the visitor old_kind = self._kind self._kind = kind or self._kind - expr = self.visit(expr, ty) + expr, subst = self.visit(expr, ty) self._kind = old_kind - return with_type(ty, expr) + return with_type(ty.substitute(subst), expr), subst - def _synthesize(self, node: ast.expr) -> tuple[ast.expr, GuppyType]: + def _synthesize( + self, node: ast.expr, allow_free_vars: bool + ) -> tuple[ast.expr, GuppyType]: """Invokes the type synthesiser""" - return ExprSynthesizer(self.ctx).synthesize(node) + return ExprSynthesizer(self.ctx).synthesize(node, allow_free_vars) - def visit_Tuple(self, node: ast.Tuple, ty: GuppyType) -> ast.expr: + def visit_Tuple(self, node: ast.Tuple, ty: GuppyType) -> tuple[ast.expr, Subst]: if not isinstance(ty, TupleType) or len(ty.element_types) != len(node.elts): return self._fail(ty, node) + subst: Subst = {} for i, el in enumerate(node.elts): - node.elts[i] = self.check(el, ty.element_types[i]) - return node + node.elts[i], s = self.check(el, ty.element_types[i].substitute(subst)) + subst |= s + return node, subst + + def visit_Call(self, node: ast.Call, ty: GuppyType) -> tuple[ast.expr, Subst]: + if len(node.keywords) > 0: + raise GuppyError( + "Argument passing by keyword is not supported", node.keywords[0] + ) + node.func, func_ty = self._synthesize(node.func, allow_free_vars=False) + + # First handle direct calls of user-defined functions and extension functions + if isinstance(node.func, GlobalName) and isinstance( + node.func.value, CallableVariable + ): + return node.func.value.check_call(node.args, ty, node, self.ctx) + + # Otherwise, it must be a function as a higher-order value + if isinstance(func_ty, FunctionType): + args, return_ty, inst = check_call(func_ty, node.args, ty, node, self.ctx) + check_inst(func_ty, inst, node) + node.func = instantiate_poly(node.func, func_ty, inst) + return with_loc(node, LocalCall(func=node.func, args=args)), return_ty + elif f := self.ctx.globals.get_instance_func(func_ty, "__call__"): + return f.check_call(node.args, ty, node, self.ctx) + else: + raise GuppyTypeError(f"Expected function type, got `{func_ty}`", node.func) + + def generic_visit( # type: ignore[override] + self, node: ast.expr, ty: GuppyType + ) -> tuple[ast.expr, Subst]: + # Try to synthesize and then check if we can unify it with the given type + node, synth = self._synthesize(node, allow_free_vars=False) + subst, inst = check_type_against(synth, ty, node, self._kind) - def generic_visit(self, node: ast.expr, ty: GuppyType) -> ast.expr: # type: ignore[override] - # Try to synthesize and then check if it matches the given type - node, synth = self._synthesize(node) - if synth != ty: - self._fail(ty, synth, node) - return node + # Apply instantiation of quantified type variables + if inst: + node = with_loc(node, TypeApply(value=node, tys=inst)) + + return node, subst class ExprSynthesizer(AstVisitor[tuple[ast.expr, GuppyType]]): @@ -129,7 +189,9 @@ class ExprSynthesizer(AstVisitor[tuple[ast.expr, GuppyType]]): def __init__(self, ctx: Context) -> None: self.ctx = ctx - def synthesize(self, node: ast.expr) -> tuple[ast.expr, GuppyType]: + def synthesize( + self, node: ast.expr, allow_free_vars: bool = False + ) -> tuple[ast.expr, GuppyType]: """Tries to synthesise a type for the given expression. Also returns a new desugared expression with type annotations. @@ -137,11 +199,15 @@ def synthesize(self, node: ast.expr) -> tuple[ast.expr, GuppyType]: if ty := get_type_opt(node): return node, ty node, ty = self.visit(node) + if ty.unsolved_vars and not allow_free_vars: + raise GuppyTypeError( + f"Cannot infer type variable in expression of type `{ty}`", node + ) return with_type(ty, node), ty def _check( self, expr: ast.expr, ty: GuppyType, kind: str = "expression" - ) -> ast.expr: + ) -> tuple[ast.expr, Subst]: """Checks an expression against a given type""" return ExprChecker(self.ctx).check(expr, ty, kind) @@ -252,7 +318,8 @@ def visit_Call(self, node: ast.Call) -> tuple[ast.expr, GuppyType]: # Otherwise, it must be a function as a higher-order value if isinstance(ty, FunctionType): - args, return_ty = synthesize_call(ty, node.args, node, self.ctx) + args, return_ty, inst = synthesize_call(ty, node.args, node, self.ctx) + node.func = instantiate_poly(node.func, ty, inst) return with_loc(node, LocalCall(func=node.func, args=args)), return_ty elif f := self.ctx.globals.get_instance_func(ty, "__call__"): return f.synthesize_call(node.args, node, self.ctx) @@ -278,6 +345,58 @@ def visit_IfExp(self, node: ast.IfExp) -> tuple[ast.expr, GuppyType]: ) +def check_type_against( + act: GuppyType, exp: GuppyType, node: AstNode, kind: str = "expression" +) -> tuple[Subst, Inst]: + """Checks a type against another type. + + Returns a substitution for the free variables the expected type and an instantiation + for the quantified variables in the actual type. Note that the expected type may not + be quantified and the actual type may not contain free unification variables. + """ + assert not isinstance(exp, FunctionType) or not exp.quantified + assert not act.unsolved_vars + + # The actual type may be quantified. In that case, we have to find an instantiation + # to avoid higher-rank types. + subst: Subst | None + if isinstance(act, FunctionType) and act.quantified: + unquantified, free_vars = act.unquantified() + subst = unify(exp, unquantified, {}) + if subst is None: + raise GuppyTypeError(f"Expected {kind} of type `{exp}`, got `{act}`", node) + # Check that we have found a valid instantiation for all quantified vars + for i, v in enumerate(free_vars): + if v not in subst: + raise GuppyTypeInferenceError( + f"Expected {kind} of type `{exp}`, got `{act}`. Couldn't infer an " + f"instantiation for type variable `{act.quantified[i]}` (higher-" + "rank polymorphic types are not supported)", + node, + ) + if subst[v].unsolved_vars: + raise GuppyTypeError( + f"Expected {kind} of type `{exp}`, got `{act}`. Can't instantiate " + f"type variable `{act.quantified[i]}` with type `{subst[v]}` " + "containing free variables", + node, + ) + inst = [subst[v] for v in free_vars] + subst = {v: t for v, t in subst.items() if v in exp.unsolved_vars} + + # Finally, check that the instantiation respects the linearity requirements + check_inst(act, inst, node) + + return subst, inst + + # Otherwise, we know that `act` has no unsolved type vars, so unification is trivial + assert not act.unsolved_vars + subst = unify(exp, act, {}) + if subst is None: + raise GuppyTypeError(f"Expected {kind} of type `{exp}`, got `{act}`", node) + return subst, [] + + def check_num_args(exp: int, act: int, node: AstNode) -> None: """Checks that the correct number of arguments have been passed to a function.""" if act < exp: @@ -292,17 +411,66 @@ def check_num_args(exp: int, act: int, node: AstNode) -> None: ) +def type_check_args( + args: list[ast.expr], + func_ty: FunctionType, + subst: Subst, + ctx: Context, + node: AstNode, +) -> tuple[list[ast.expr], Subst]: + """Checks the arguments of a function call and infers free type variables. + + We expect that quantified variables have been replaced with free unification + variables. Checks that all unification variables can be inferred. + """ + assert not func_ty.quantified + check_num_args(len(func_ty.args), len(args), node) + + new_args: list[ast.expr] = [] + for arg, ty in zip(args, func_ty.args): + a, s = ExprChecker(ctx).check(arg, ty.substitute(subst), "argument") + new_args.append(a) + subst |= s + + # If the argument check succeeded, this means that we must have found instantiations + # for all unification variables occurring in the argument types + assert all(set.issubset(arg.unsolved_vars, subst.keys()) for arg in func_ty.args) + + # We also have to check that we found instantiations for all vars in the return type + if not set.issubset(func_ty.returns.unsolved_vars, subst.keys()): + raise GuppyTypeInferenceError( + f"Cannot infer type variable in expression of type " + f"`{func_ty.returns.substitute(subst)}`", + node, + ) + + return new_args, subst + + def synthesize_call( func_ty: FunctionType, args: list[ast.expr], node: AstNode, ctx: Context -) -> tuple[list[ast.expr], GuppyType]: +) -> tuple[list[ast.expr], GuppyType, Inst]: """Synthesizes the return type of a function call. - Also returns desugared versions of the arguments with type annotations. + Returns an annotated argument list, the synthesized return type, and an + instantiation for the quantifiers in the function type. """ + assert not func_ty.unsolved_vars check_num_args(len(func_ty.args), len(args), node) - for i, arg in enumerate(args): - args[i] = ExprChecker(ctx).check(arg, func_ty.args[i], "argument") - return args, func_ty.returns + + # Replace quantified variables with free unification variables and try to infer an + # instantiation by checking the arguments + unquantified, free_vars = func_ty.unquantified() + args, subst = type_check_args(args, unquantified, {}, ctx, node) + + # Success implies that the substitution is closed + assert all(not t.unsolved_vars for t in subst.values()) + inst = [subst[v] for v in free_vars] + + # Finally, check that the instantiation respects the linearity requirements + check_inst(func_ty, inst, node) + + return args, unquantified.returns.substitute(subst), inst def check_call( @@ -311,14 +479,102 @@ def check_call( ty: GuppyType, node: AstNode, ctx: Context, -) -> list[ast.expr]: - """Checks the return type of a function call against a given type""" - args, return_ty = synthesize_call(func_ty, args, node, ctx) - if return_ty != ty: + kind: str = "expression", +) -> tuple[list[ast.expr], Subst, Inst]: + """Checks the return type of a function call against a given type. + + Returns an annotated argument list, a substitution for the free variables in the + expected type, and an instantiation for the quantifiers in the function type. + """ + assert not func_ty.unsolved_vars + check_num_args(len(func_ty.args), len(args), node) + + # When checking, we can use the information from the expected return type to infer + # some type arguments. However, this pushes errors inwards. For example, given a + # function `foo: forall T. T -> T`, the following type mismatch would be reported: + # + # x: int = foo(None) + # ^^^^ Expected argument of type `int`, got `None` + # + # But the following error location would be more intuitive for users: + # + # x: int = foo(None) + # ^^^^^^^^^ Expected expression of type `int`, got `None` + # + # In other words, if we can get away with synthesising the call without the extra + # information from the expected type, we should do that to improve the error. + + # TODO: The approach below can result in exponential runtime in the worst case. + # However the bad case, e.g. `x: int = foo(foo(...foo(?)...))`, shouldn't be common + # in practice. Can we do better than that? + + # First, try to synthesize + res: tuple[GuppyType, Inst] | None = None + try: + args, synth, inst = synthesize_call(func_ty, args, node, ctx) + res = synth, inst + except GuppyTypeInferenceError: + pass + if res is not None: + synth, inst = res + subst = unify(ty, synth, {}) + if subst is None: + raise GuppyTypeError(f"Expected {kind} of type `{ty}`, got `{synth}`", node) + return args, subst, inst + + # If synthesis fails, we try again, this time also using information from the + # expected return type + unquantified, free_vars = func_ty.unquantified() + subst = unify(ty, unquantified.returns, {}) + if subst is None: raise GuppyTypeError( - f"Expected expression of type `{ty}`, got `{return_ty}`", node + f"Expected {kind} of type `{ty}`, got `{unquantified.returns}`", node ) - return args + + # Try to infer more by checking against the arguments + args, subst = type_check_args(args, unquantified, subst, ctx, node) + + # Also make sure we found an instantiation for all free vars in the type we're + # checking against + if not set.issubset(ty.unsolved_vars, subst.keys()): + raise GuppyTypeInferenceError( + f"Expected expression of type `{ty}`, got " + f"`{func_ty.returns.substitute(subst)}`. Couldn't infer type variables", + node, + ) + + # Success implies that the substitution is closed + assert all(not t.unsolved_vars for t in subst.values()) + inst = [subst[v] for v in free_vars] + subst = {v: t for v, t in subst.items() if v in ty.unsolved_vars} + + # Finally, check that the instantiation respects the linearity requirements + check_inst(func_ty, inst, node) + + return args, subst, inst + + +def check_inst(func_ty: FunctionType, inst: Inst, node: AstNode) -> None: + """Checks if an instantiation is valid. + + Makes sure that the linearity requirements are satisfied. + """ + for var, ty in zip(func_ty.quantified, inst): + if not var.linear and ty.linear: + raise GuppyTypeError( + f"Cannot instantiate non-linear type variable `{var}` in type " + f"`{func_ty}` with linear type `{ty}`", + node, + ) + + +def instantiate_poly(node: ast.expr, ty: FunctionType, inst: Inst) -> ast.expr: + """Instantiates quantified type arguments in a function.""" + assert len(ty.quantified) == len(inst) + if len(inst) > 0: + node = with_loc(node, TypeApply(value=with_type(ty, node), tys=inst)) + return with_type(ty.instantiate(inst), node) + return node def to_bool( diff --git a/guppy/checker/func_checker.py b/guppy/checker/func_checker.py index 20113685..dcef53bd 100644 --- a/guppy/checker/func_checker.py +++ b/guppy/checker/func_checker.py @@ -15,7 +15,14 @@ from guppy.checker.core import CallableVariable, Context, Globals, Variable from guppy.checker.expr_checker import check_call, synthesize_call from guppy.error import GuppyError -from guppy.gtypes import FunctionType, GuppyType, NoneType, type_from_ast +from guppy.gtypes import ( + BoundTypeVar, + FunctionType, + GuppyType, + NoneType, + Subst, + type_from_ast, +) from guppy.nodes import CheckedNestedFunctionDef, GlobalCall, NestedFunctionDef @@ -31,21 +38,25 @@ def from_ast( func_def: ast.FunctionDef, name: str, globals: Globals ) -> "DefinedFunction": ty = check_signature(func_def, globals) + if ty.quantified: + raise GuppyError( + "Generic function definitions are not supported yet", func_def + ) return DefinedFunction(name, ty, func_def, None) def check_call( self, args: list[ast.expr], ty: GuppyType, node: AstNode, ctx: Context - ) -> GlobalCall: + ) -> tuple[ast.expr, Subst]: # Use default implementation from the expression checker - args = check_call(self.ty, args, ty, node, ctx) - return GlobalCall(func=self, args=args) + args, subst, inst = check_call(self.ty, args, ty, node, ctx) + return GlobalCall(func=self, args=args, type_args=inst), subst def synthesize_call( self, args: list[ast.expr], node: AstNode, ctx: Context ) -> tuple[GlobalCall, GuppyType]: # Use default implementation from the expression checker - args, ty = synthesize_call(self.ty, args, node, ctx) - return GlobalCall(func=self, args=args), ty + args, ty, inst = synthesize_call(self.ty, args, node, ctx) + return GlobalCall(func=self, args=args, type_args=inst), ty @dataclass @@ -131,7 +142,7 @@ def check_nested_func_def( if not captured: # If there are no captured vars, we treat the function like a global name func = DefinedFunction(func_def.name, func_ty, func_def, None) - globals = ctx.globals | Globals({func_def.name: func}, {}) + globals = ctx.globals | Globals({func_def.name: func}, {}, {}) else: # Otherwise, we treat it like a local name @@ -176,14 +187,21 @@ def check_signature(func_def: ast.FunctionDef, globals: Globals) -> FunctionType ) raise GuppyError("Return type must be annotated", func_def) + # TODO: Prepopulate mapping when using Python 3.12 style generic functions + type_var_mapping: dict[str, BoundTypeVar] = {} arg_tys = [] arg_names = [] for _i, arg in enumerate(func_def.args.args): if arg.annotation is None: raise GuppyError("Argument type must be annotated", arg) - ty = type_from_ast(arg.annotation, globals) + ty = type_from_ast(arg.annotation, globals, type_var_mapping) arg_tys.append(ty) arg_names.append(arg.arg) - ret_type = type_from_ast(func_def.returns, globals) - return FunctionType(arg_tys, ret_type, arg_names) + ret_type = type_from_ast(func_def.returns, globals, type_var_mapping) + return FunctionType( + arg_tys, + ret_type, + arg_names, + sorted(type_var_mapping.values(), key=lambda v: v.idx), + ) diff --git a/guppy/checker/stmt_checker.py b/guppy/checker/stmt_checker.py index 61df3120..0fa46608 100644 --- a/guppy/checker/stmt_checker.py +++ b/guppy/checker/stmt_checker.py @@ -16,7 +16,7 @@ from guppy.checker.core import Context, Variable from guppy.checker.expr_checker import ExprChecker, ExprSynthesizer from guppy.error import GuppyError, GuppyTypeError, InternalGuppyError -from guppy.gtypes import GuppyType, NoneType, TupleType, type_from_ast +from guppy.gtypes import GuppyType, NoneType, Subst, TupleType, type_from_ast from guppy.nodes import NestedFunctionDef @@ -26,6 +26,7 @@ class StmtChecker(AstVisitor[BBStatement]): return_ty: GuppyType def __init__(self, ctx: Context, bb: BB, return_ty: GuppyType) -> None: + assert not return_ty.unsolved_vars self.ctx = ctx self.bb = bb self.return_ty = return_ty @@ -38,7 +39,7 @@ def _synth_expr(self, node: ast.expr) -> tuple[ast.expr, GuppyType]: def _check_expr( self, node: ast.expr, ty: GuppyType, kind: str = "expression" - ) -> ast.expr: + ) -> tuple[ast.expr, Subst]: return ExprChecker(self.ctx).check(node, ty, kind) def _check_assign(self, lhs: ast.expr, ty: GuppyType, node: ast.stmt) -> None: @@ -91,7 +92,9 @@ def visit_AnnAssign(self, node: ast.AnnAssign) -> ast.stmt: "Variable declaration is not supported. Assignment is required", node ) ty = type_from_ast(node.annotation, self.ctx.globals) - node.value = self._check_expr(node.value, ty) + node.value, subst = self._check_expr(node.value, ty) + assert not ty.unsolved_vars # `ty` must be closed! + assert len(subst) == 0 self._check_assign(node.target, ty, node) return node @@ -111,7 +114,10 @@ def visit_Expr(self, node: ast.Expr) -> ast.stmt: def visit_Return(self, node: ast.Return) -> ast.stmt: if node.value is not None: - node.value = self._check_expr(node.value, self.return_ty, "return value") + node.value, subst = self._check_expr( + node.value, self.return_ty, "return value" + ) + assert len(subst) == 0 # `self.return_ty` is closed! elif not isinstance(self.return_ty, NoneType): raise GuppyTypeError( f"Expected return value of type `{self.return_ty}`", None diff --git a/guppy/compiler/core.py b/guppy/compiler/core.py index d3b16e59..5f00f715 100644 --- a/guppy/compiler/core.py +++ b/guppy/compiler/core.py @@ -4,7 +4,7 @@ from guppy.ast_util import AstNode from guppy.checker.core import CallableVariable, Variable -from guppy.gtypes import FunctionType +from guppy.gtypes import FunctionType, Inst from guppy.hugr.hugr import DFContainingNode, Hugr, OutPortV @@ -47,6 +47,7 @@ class CompiledFunction(CompiledVariable, CallableVariable): def compile_call( self, args: list[OutPortV], + type_args: Inst, dfg: "DFContainer", graph: Hugr, globals: "CompiledGlobals", diff --git a/guppy/compiler/expr_compiler.py b/guppy/compiler/expr_compiler.py index f74c18d2..97b08945 100644 --- a/guppy/compiler/expr_compiler.py +++ b/guppy/compiler/expr_compiler.py @@ -3,11 +3,19 @@ from guppy.ast_util import AstVisitor, get_type from guppy.compiler.core import CompiledFunction, CompilerBase, DFContainer -from guppy.error import InternalGuppyError -from guppy.gtypes import BoolType, FunctionType, type_to_row +from guppy.error import GuppyError, InternalGuppyError +from guppy.gtypes import ( + BoolType, + BoundTypeVar, + FunctionType, + Inst, + NoneType, + TupleType, + type_to_row, +) from guppy.hugr import ops, val from guppy.hugr.hugr import OutPortV -from guppy.nodes import GlobalCall, GlobalName, LocalCall, LocalName +from guppy.nodes import GlobalCall, GlobalName, LocalCall, LocalName, TypeApply class ExprCompiler(CompilerBase, AstVisitor[OutPortV]): @@ -71,12 +79,34 @@ def visit_GlobalCall(self, node: GlobalCall) -> OutPortV: assert isinstance(func, CompiledFunction) args = [self.visit(arg) for arg in node.args] - rets = func.compile_call(args, self.dfg, self.graph, self.globals, node) + rets = func.compile_call( + args, list(node.type_args), self.dfg, self.graph, self.globals, node + ) return self._pack_returns(rets) def visit_Call(self, node: ast.Call) -> OutPortV: raise InternalGuppyError("Node should have been removed during type checking.") + def visit_TypeApply(self, node: TypeApply) -> OutPortV: + func = self.visit(node.value) + assert isinstance(func.ty, FunctionType) + ta = self.graph.add_type_apply(func, node.tys, self.dfg.node).out_port(0) + + # We have to be very careful here: If we instantiate `foo: forall T. T -> T` + # with a tuple type `tuple[A, B]`, we get the type `tuple[A, B] -> tuple[A, B]`. + # Normally, this would be represented in Hugr as a function with two output + # ports types A and B. However, when TypeApplying `foo`, we actually get a + # function with a single output port typed `tuple[A, B]`. + # TODO: We would need to do manual monomorphisation in that case to obtain a + # function that returns two ports as expected + if instantiation_needs_unpacking(func.ty, node.tys): + raise GuppyError( + "Generic function instantiations returning rows are not supported yet", + node, + ) + + return ta + def visit_UnaryOp(self, node: ast.UnaryOp) -> OutPortV: # The only case that is not desugared by the type checker is the `not` operation # since it is not implemented via a dunder method @@ -100,6 +130,14 @@ def expr_to_row(expr: ast.expr) -> list[ast.expr]: return expr.elts if isinstance(expr, ast.Tuple) else [expr] +def instantiation_needs_unpacking(func_ty: FunctionType, inst: Inst) -> bool: + """Checks if instantiating a polymorphic makes it return a row.""" + if isinstance(func_ty.returns, BoundTypeVar): + return_ty = inst[func_ty.returns.idx] + return isinstance(return_ty, TupleType | NoneType) + return False + + def python_value_to_hugr(v: Any) -> val.Value | None: """Turns a Python value into a Hugr value. diff --git a/guppy/compiler/func_compiler.py b/guppy/compiler/func_compiler.py index f4376b17..1869c8ac 100644 --- a/guppy/compiler/func_compiler.py +++ b/guppy/compiler/func_compiler.py @@ -9,7 +9,7 @@ DFContainer, PortVariable, ) -from guppy.gtypes import FunctionType, type_to_row +from guppy.gtypes import FunctionType, Inst, type_to_row from guppy.hugr.hugr import DFContainingVNode, Hugr, OutPortV from guppy.nodes import CheckedNestedFunctionDef @@ -26,12 +26,20 @@ def load( def compile_call( self, args: list[OutPortV], + type_args: Inst, dfg: DFContainer, graph: Hugr, globals: CompiledGlobals, node: AstNode, ) -> list[OutPortV]: - call = graph.add_call(self.node.out_port(0), args, dfg.node) + # TODO: Hugr should probably allow us to pass type args to `Call`, so we can + # avoid loading the function to manually add a `TypeApply` + if type_args: + func = graph.add_load_constant(self.node.out_port(0), dfg.node) + func = graph.add_type_apply(func.out_port(0), type_args, dfg.node) + call = graph.add_indirect_call(func.out_port(0), args, dfg.node) + else: + call = graph.add_call(self.node.out_port(0), args, dfg.node) return [call.out_port(i) for i in range(len(type_to_row(self.ty.returns)))] diff --git a/guppy/custom.py b/guppy/custom.py index 4c20ea09..8c5a4fc1 100644 --- a/guppy/custom.py +++ b/guppy/custom.py @@ -11,7 +11,7 @@ InternalGuppyError, UnknownFunctionType, ) -from guppy.gtypes import FunctionType, GuppyType, type_to_row +from guppy.gtypes import FunctionType, GuppyType, Inst, Subst, type_to_row from guppy.hugr import ops from guppy.hugr.hugr import DFContainingVNode, Hugr, Node, OutPortV from guppy.nodes import GlobalCall @@ -84,9 +84,10 @@ def check_type(self, globals: Globals) -> None: def check_call( self, args: list[ast.expr], ty: GuppyType, node: AstNode, ctx: Context - ) -> ast.expr: + ) -> tuple[ast.expr, Subst]: self.call_checker._setup(ctx, node, self) - return with_type(ty, with_loc(node, self.call_checker.check(args, ty))) + new_node, subst = self.call_checker.check(args, ty) + return with_type(ty, with_loc(node, new_node)), subst def synthesize_call( self, args: list[ast.expr], node: AstNode, ctx: "Context" @@ -98,12 +99,13 @@ def synthesize_call( def compile_call( self, args: list[OutPortV], + type_args: Inst, dfg: DFContainer, graph: Hugr, globals: CompiledGlobals, node: AstNode, ) -> list[OutPortV]: - self.call_compiler._setup(dfg, graph, globals, node) + self.call_compiler._setup(type_args, dfg, graph, globals, node) return self.call_compiler.compile(args) def load( @@ -121,6 +123,12 @@ def load( node, ) + if self._ty.quantified: + raise InternalGuppyError( + "Can't yet generate higher-order versions of custom functions. This " + "requires generic function *definitions*" + ) + # Find the module node by walking up the hierarchy module: Node = dfg.node while not isinstance(module.op, ops.Module): @@ -137,7 +145,7 @@ def load( def_node = graph.add_def(self.ty, module, self.name) _, inp_ports = graph.add_input_with_ports(list(self.ty.args), def_node) returns = self.compile_call( - inp_ports, DFContainer(def_node, {}), graph, globals, node + inp_ports, [], DFContainer(def_node, {}), graph, globals, node ) graph.add_output(returns, parent=def_node) self._defined[module] = def_node @@ -161,7 +169,7 @@ def _setup(self, ctx: Context, node: AstNode, func: CustomFunction) -> None: self.func = func @abstractmethod - def check(self, args: list[ast.expr], ty: GuppyType) -> ast.expr: + def check(self, args: list[ast.expr], ty: GuppyType) -> tuple[ast.expr, Subst]: """Checks the return value against a given type. Returns a (possibly) transformed and annotated AST node for the call. @@ -178,14 +186,21 @@ def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, GuppyType]: class CustomCallCompiler(ABC): """Protocol for custom function call compilers.""" + type_args: Inst dfg: DFContainer graph: Hugr globals: CompiledGlobals node: AstNode def _setup( - self, dfg: DFContainer, graph: Hugr, globals: CompiledGlobals, node: AstNode + self, + type_args: Inst, + dfg: DFContainer, + graph: Hugr, + globals: CompiledGlobals, + node: AstNode, ) -> None: + self.type_args = type_args self.dfg = dfg self.graph = graph self.globals = globals @@ -199,15 +214,15 @@ def compile(self, args: list[OutPortV]) -> list[OutPortV]: class DefaultCallChecker(CustomCallChecker): """Checks function calls by comparing to a type signature.""" - def check(self, args: list[ast.expr], ty: GuppyType) -> ast.expr: + def check(self, args: list[ast.expr], ty: GuppyType) -> tuple[ast.expr, Subst]: # Use default implementation from the expression checker - args = check_call(self.func.ty, args, ty, self.node, self.ctx) - return GlobalCall(func=self.func, args=args) + args, subst, inst = check_call(self.func.ty, args, ty, self.node, self.ctx) + return GlobalCall(func=self.func, args=args, type_args=inst), subst def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, GuppyType]: # Use default implementation from the expression checker - args, ty = synthesize_call(self.func.ty, args, self.node, self.ctx) - return GlobalCall(func=self.func, args=args), ty + args, ty, inst = synthesize_call(self.func.ty, args, self.node, self.ctx) + return GlobalCall(func=self.func, args=args, type_args=inst), ty class DefaultCallCompiler(CustomCallCompiler): diff --git a/guppy/declared.py b/guppy/declared.py index 265ef1a6..b868a374 100644 --- a/guppy/declared.py +++ b/guppy/declared.py @@ -7,7 +7,7 @@ from guppy.checker.func_checker import check_signature from guppy.compiler.core import CompiledFunction, CompiledGlobals, DFContainer from guppy.error import GuppyError -from guppy.gtypes import GuppyType, type_to_row +from guppy.gtypes import GuppyType, Inst, Subst, type_to_row from guppy.hugr.hugr import Hugr, Node, OutPortV, VNode from guppy.nodes import GlobalCall @@ -31,17 +31,17 @@ def from_ast( def check_call( self, args: list[ast.expr], ty: GuppyType, node: AstNode, ctx: Context - ) -> GlobalCall: + ) -> tuple[ast.expr, Subst]: # Use default implementation from the expression checker - args = check_call(self.ty, args, ty, node, ctx) - return GlobalCall(func=self, args=args) + args, subst, inst = check_call(self.ty, args, ty, node, ctx) + return GlobalCall(func=self, args=args, type_args=inst), subst def synthesize_call( self, args: list[ast.expr], node: AstNode, ctx: Context ) -> tuple[GlobalCall, GuppyType]: # Use default implementation from the expression checker - args, ty = synthesize_call(self.ty, args, node, ctx) - return GlobalCall(func=self, args=args), ty + args, ty, inst = synthesize_call(self.ty, args, node, ctx) + return GlobalCall(func=self, args=args, type_args=inst), ty def add_to_graph(self, graph: Hugr, parent: Node) -> None: self.node = graph.add_declare(self.ty, parent, self.name) @@ -55,11 +55,19 @@ def load( def compile_call( self, args: list[OutPortV], + type_args: Inst, dfg: DFContainer, graph: Hugr, globals: CompiledGlobals, node: AstNode, ) -> list[OutPortV]: assert self.node is not None - call = graph.add_call(self.node.out_port(0), args, dfg.node) + # TODO: Hugr should probably allow us to pass type args to `Call`, so we can + # avoid loading the function to manually add a `TypeApply` + if type_args: + func = graph.add_load_constant(self.node.out_port(0), dfg.node) + func = graph.add_type_apply(func.out_port(0), type_args, dfg.node) + call = graph.add_indirect_call(func.out_port(0), args, dfg.node) + else: + call = graph.add_call(self.node.out_port(0), args, dfg.node) return [call.out_port(i) for i in range(len(type_to_row(self.ty.returns)))] diff --git a/guppy/decorator.py b/guppy/decorator.py index 32f93ab1..7585d4fd 100644 --- a/guppy/decorator.py +++ b/guppy/decorator.py @@ -1,7 +1,7 @@ import functools -from collections.abc import Callable +from collections.abc import Callable, Iterator, Sequence from dataclasses import dataclass -from typing import Any +from typing import Any, ClassVar, TypeVar from guppy.ast_util import AstNode, has_empty_body from guppy.custom import ( @@ -13,7 +13,7 @@ OpCompiler, ) from guppy.error import GuppyError, pretty_errors -from guppy.gtypes import GuppyType +from guppy.gtypes import GuppyType, TypeTransformer from guppy.hugr import ops, tys from guppy.hugr.hugr import Hugr from guppy.module import GuppyModule, PyFunc, parse_py_func @@ -94,7 +94,8 @@ def dec(c: type) -> type: @dataclass(frozen=True) class NewType(GuppyType): - name = _name + args: Sequence[GuppyType] + name: ClassVar[str] = _name @staticmethod def build(*args: GuppyType, node: AstNode | None = None) -> "GuppyType": @@ -103,7 +104,11 @@ def build(*args: GuppyType, node: AstNode | None = None) -> "GuppyType": raise GuppyError( f"Type `{_name}` does not accept type parameters.", node ) - return NewType() + return NewType([]) + + @property + def type_args(self) -> Iterator[GuppyType]: + return iter(self.args) @property def linear(self) -> bool: @@ -112,6 +117,11 @@ def linear(self) -> bool: def to_hugr(self) -> tys.SimpleType: return hugr_ty + def transform(self, transformer: TypeTransformer) -> GuppyType: + return transformer.transform(self) or NewType( + [ty.transform(transformer) for ty in self.args] + ) + def __str__(self) -> str: return _name @@ -124,6 +134,12 @@ def __str__(self) -> str: return dec + @pretty_errors + def type_var(self, module: GuppyModule, name: str, linear: bool = False) -> TypeVar: + """Creates a new type variable in a module.""" + module.register_type_var(name, linear) + return TypeVar(name) + @pretty_errors def custom( self, diff --git a/guppy/error.py b/guppy/error.py index bc6d38e7..14bb5813 100644 --- a/guppy/error.py +++ b/guppy/error.py @@ -7,7 +7,7 @@ from typing import Any, TypeVar, cast from guppy.ast_util import AstNode, get_file, get_line_offset, get_source -from guppy.gtypes import FunctionType, GuppyType +from guppy.gtypes import BoundTypeVar, ExistentialTypeVar, FunctionType, GuppyType from guppy.hugr.hugr import Node, OutPortV # Whether the interpreter should exit when a Guppy error occurs @@ -72,6 +72,10 @@ class GuppyTypeError(GuppyError): """Special Guppy exception for type errors.""" +class GuppyTypeInferenceError(GuppyError): + """Special Guppy exception for type inference errors.""" + + class InternalGuppyError(Exception): """Exception for internal problems during compilation.""" @@ -119,6 +123,14 @@ def returns(self) -> GuppyType: def args_names(self) -> Sequence[str] | None: raise InternalGuppyError("Tried to access unknown function type") + @property + def quantified(self) -> Sequence[BoundTypeVar]: + raise InternalGuppyError("Tried to access unknown function type") + + @property + def unsolved_vars(self) -> set[ExistentialTypeVar]: + return set() + def format_source_location( loc: ast.AST, @@ -130,7 +142,9 @@ def format_source_location( assert source is not None assert line_offset is not None source_lines = source.splitlines(keepends=True) - end_col_offset = loc.end_col_offset or len(source_lines[loc.lineno]) + end_col_offset = loc.end_col_offset + if end_col_offset is None or (loc.end_lineno and loc.end_lineno > loc.lineno): + end_col_offset = len(source_lines[loc.lineno - 1]) - 1 s = "".join(source_lines[max(loc.lineno - num_lines, 0) : loc.lineno]).rstrip() s += "\n" + loc.col_offset * " " + (end_col_offset - loc.col_offset) * "^" s = textwrap.dedent(s).splitlines() diff --git a/guppy/gtypes.py b/guppy/gtypes.py index a6a15c20..9046e2cb 100644 --- a/guppy/gtypes.py +++ b/guppy/gtypes.py @@ -1,29 +1,66 @@ import ast +import itertools from abc import ABC, abstractmethod -from collections.abc import Sequence +from collections.abc import Iterator, Sequence from dataclasses import dataclass, field -from typing import TYPE_CHECKING +from typing import ( + TYPE_CHECKING, + ClassVar, + Literal, +) import guppy.hugr.tys as tys -from guppy.ast_util import AstNode, set_location_from +from guppy.ast_util import AstNode if TYPE_CHECKING: from guppy.checker.core import Globals +Subst = dict["ExistentialTypeVar", "GuppyType"] +Inst = Sequence["GuppyType"] + + +@dataclass(frozen=True) class GuppyType(ABC): """Base class for all Guppy types. Note that all instances of `GuppyType` subclasses are expected to be immutable. """ - name: str = "" + name: ClassVar[str] + + # Cache for free variables + _unsolved_vars: set["ExistentialTypeVar"] = field(init=False, repr=False) + + def __post_init__(self) -> None: + # Make sure that we don't have higher-rank polymorphic types + for arg in self.type_args: + if isinstance(arg, FunctionType) and arg.quantified: + from guppy.error import InternalGuppyError + + raise InternalGuppyError( + "Tried to construct a higher-rank polymorphic type!" + ) + + # Compute free variables + if isinstance(self, ExistentialTypeVar): + vs = {self} + else: + vs = set() + for arg in self.type_args: + vs |= arg.unsolved_vars + object.__setattr__(self, "_unsolved_vars", vs) @staticmethod @abstractmethod def build(*args: "GuppyType", node: AstNode | None = None) -> "GuppyType": pass + @property + @abstractmethod + def type_args(self) -> Iterator["GuppyType"]: + pass + @property @abstractmethod def linear(self) -> bool: @@ -33,6 +70,87 @@ def linear(self) -> bool: def to_hugr(self) -> tys.SimpleType: pass + @abstractmethod + def transform(self, transformer: "TypeTransformer") -> "GuppyType": + pass + + @property + def unsolved_vars(self) -> set["ExistentialTypeVar"]: + return self._unsolved_vars + + def substitute(self, s: Subst) -> "GuppyType": + return self.transform(Substituter(s)) + + +@dataclass(frozen=True) +class BoundTypeVar(GuppyType): + """Bound type variable, identified with a de Bruijn index.""" + + idx: int + display_name: str + linear: bool = False + + name: ClassVar[Literal["BoundTypeVar"]] = "BoundTypeVar" + + @staticmethod + def build(*rgs: GuppyType, node: AstNode | None = None) -> GuppyType: + raise NotImplementedError + + @property + def type_args(self) -> Iterator["GuppyType"]: + return iter(()) + + def transform(self, transformer: "TypeTransformer") -> GuppyType: + return transformer.transform(self) or self + + def __str__(self) -> str: + return self.display_name + + def to_hugr(self) -> tys.SimpleType: + return tys.Variable(i=self.idx, b=tys.TypeBound.from_linear(self.linear)) + + +@dataclass(frozen=True) +class ExistentialTypeVar(GuppyType): + """Existential type variable, identified with a globally unique id. + + Is solved during type checking. + """ + + id: int + display_name: str + linear: bool = False + + name: ClassVar[Literal["ExistentialTypeVar"]] = "ExistentialTypeVar" + + _id_generator: ClassVar[Iterator[int]] = itertools.count() + + @classmethod + def new(cls, display_name: str, linear: bool) -> "ExistentialTypeVar": + return ExistentialTypeVar(next(cls._id_generator), display_name, linear) + + @staticmethod + def build(*rgs: GuppyType, node: AstNode | None = None) -> GuppyType: + raise NotImplementedError + + @property + def type_args(self) -> Iterator["GuppyType"]: + return iter(()) + + def transform(self, transformer: "TypeTransformer") -> GuppyType: + return transformer.transform(self) or self + + def __str__(self) -> str: + return "?" + self.display_name + + def __hash__(self) -> int: + return self.id + + def to_hugr(self) -> tys.SimpleType: + from guppy.error import InternalGuppyError + + raise InternalGuppyError("Tried to convert free type variable to Hugr") + @dataclass(frozen=True) class FunctionType(GuppyType): @@ -42,16 +160,24 @@ class FunctionType(GuppyType): default=None, compare=False, # Argument names are not taken into account for type equality ) + quantified: Sequence[BoundTypeVar] = field(default_factory=list) - name: str = "->" + name: ClassVar[Literal["%function"]] = "%function" linear = False def __str__(self) -> str: + prefix = ( + "forall " + ", ".join(str(v) for v in self.quantified) + ". " + if self.quantified + else "" + ) if len(self.args) == 1: [arg] = self.args - return f"{arg} -> {self.returns}" + return prefix + f"{arg} -> {self.returns}" else: - return f"({', '.join(str(a) for a in self.args)}) -> {self.returns}" + return ( + prefix + f"({', '.join(str(a) for a in self.args)}) -> {self.returns}" + ) @staticmethod def build(*args: GuppyType, node: AstNode | None = None) -> GuppyType: @@ -59,18 +185,67 @@ def build(*args: GuppyType, node: AstNode | None = None) -> GuppyType: # has a special case for function types. raise NotImplementedError - def to_hugr(self) -> tys.SimpleType: + @property + def type_args(self) -> Iterator[GuppyType]: + return itertools.chain(iter(self.args), iter((self.returns,))) + + def to_hugr(self) -> tys.PolyFuncType: ins = [t.to_hugr() for t in self.args] outs = [t.to_hugr() for t in type_to_row(self.returns)] func_ty = tys.FunctionType(input=ins, output=outs, extension_reqs=[]) - return tys.PolyFuncType(params=[], body=func_ty) + return tys.PolyFuncType( + params=[ + tys.TypeParam(b=tys.TypeBound.from_linear(v.linear)) + for v in self.quantified + ], + body=func_ty, + ) + + def transform(self, transformer: "TypeTransformer") -> GuppyType: + return transformer.transform(self) or FunctionType( + [ty.transform(transformer) for ty in self.args], + self.returns.transform(transformer), + self.arg_names, + ) + + def instantiate(self, tys: Sequence[GuppyType]) -> "FunctionType": + """Instantiates quantified type variables.""" + assert len(tys) == len(self.quantified) + + # Set the `preserve` flag for instantiated tuples and None + preserved_tys: list[GuppyType] = [] + for ty in tys: + if isinstance(ty, TupleType): + ty = TupleType(ty.element_types, preserve=True) + elif isinstance(ty, NoneType): + ty = NoneType(preserve=True) + preserved_tys.append(ty) + + inst = Instantiator(preserved_tys) + return FunctionType( + [ty.transform(inst) for ty in self.args], + self.returns.transform(inst), + self.arg_names, + ) + + def unquantified(self) -> tuple["FunctionType", Sequence[ExistentialTypeVar]]: + """Replaces all quantified variables with free type variables.""" + inst = [ + ExistentialTypeVar.new(v.display_name, v.linear) for v in self.quantified + ] + return self.instantiate(inst), inst @dataclass(frozen=True) class TupleType(GuppyType): element_types: Sequence[GuppyType] - name: str = "tuple" + # Flag to avoid turning the tuple into row when calling `type_to_row()`. This is + # used to make sure that type vars instantiated to tuples are not broken up into + # rows when generating a Hugr + preserve: bool = field(default=False, compare=False) + + name: ClassVar[Literal["tuple"]] = "tuple" @staticmethod def build(*args: GuppyType, node: AstNode | None = None) -> GuppyType: @@ -88,15 +263,26 @@ def __str__(self) -> str: def linear(self) -> bool: return any(t.linear for t in self.element_types) + @property + def type_args(self) -> Iterator[GuppyType]: + return iter(self.element_types) + def to_hugr(self) -> tys.SimpleType: ts = [t.to_hugr() for t in self.element_types] return tys.Tuple(inner=ts) + def transform(self, transformer: "TypeTransformer") -> GuppyType: + return transformer.transform(self) or TupleType( + [ty.transform(transformer) for ty in self.element_types] + ) + @dataclass(frozen=True) class SumType(GuppyType): element_types: Sequence[GuppyType] + name: ClassVar[str] = "%tuple" + @staticmethod def build(*args: GuppyType, node: AstNode | None = None) -> GuppyType: # Sum types cannot be parsed and constructed using `build` since they cannot be @@ -110,6 +296,10 @@ def __str__(self) -> str: def linear(self) -> bool: return any(t.linear for t in self.element_types) + @property + def type_args(self) -> Iterator[GuppyType]: + return iter(self.element_types) + def to_hugr(self) -> tys.SimpleType: if all( isinstance(e, TupleType) and len(e.element_types) == 0 @@ -118,12 +308,22 @@ def to_hugr(self) -> tys.SimpleType: return tys.UnitSum(size=len(self.element_types)) return tys.GeneralSum(row=[t.to_hugr() for t in self.element_types]) + def transform(self, transformer: "TypeTransformer") -> GuppyType: + return transformer.transform(self) or SumType( + [ty.transform(transformer) for ty in self.element_types] + ) + @dataclass(frozen=True) class NoneType(GuppyType): - name: str = "None" + name: ClassVar[Literal["None"]] = "None" linear: bool = False + # Flag to avoid turning the type into a row when calling `type_to_row()`. This is + # used to make sure that type vars instantiated to Nones are not broken up into + # empty rows when generating a Hugr + preserve: bool = field(default=False, compare=False) + @staticmethod def build(*args: GuppyType, node: AstNode | None = None) -> GuppyType: if len(args) > 0: @@ -132,19 +332,29 @@ def build(*args: GuppyType, node: AstNode | None = None) -> GuppyType: raise GuppyError("Type `None` is not generic", node) return NoneType() + @property + def type_args(self) -> Iterator[GuppyType]: + return iter(()) + + def substitute(self, s: Subst) -> GuppyType: + return self + def __str__(self) -> str: return "None" def to_hugr(self) -> tys.SimpleType: return tys.Tuple(inner=[]) + def transform(self, transformer: "TypeTransformer") -> GuppyType: + return transformer.transform(self) or self + @dataclass(frozen=True) class BoolType(SumType): """The type of booleans.""" - linear = False - name = "bool" + linear: bool = False + name: ClassVar[Literal["bool"]] = "bool" def __init__(self) -> None: # Hugr bools are encoded as Sum((), ()) @@ -161,34 +371,129 @@ def build(*args: GuppyType, node: AstNode | None = None) -> GuppyType: def __str__(self) -> str: return "bool" + def transform(self, transformer: "TypeTransformer") -> GuppyType: + return transformer.transform(self) or self -def _lookup_type(node: AstNode, globals: "Globals") -> type[GuppyType] | None: - if isinstance(node, ast.Name) and node.id in globals.types: - return globals.types[node.id] - if isinstance(node, ast.Constant) and node.value is None: - return NoneType - if ( - isinstance(node, ast.Constant) - and isinstance(node.value, str) - and node.value in globals.types - ): - return globals.types[node.value] + +class TypeTransformer(ABC): + """Abstract base class for a type visitor that transforms types.""" + + @abstractmethod + def transform(self, ty: GuppyType) -> GuppyType | None: + """This method is called for each visited type. + + Return a transformed type or `None` to continue the recursive visit. + """ + + +class Substituter(TypeTransformer): + """Type transformer that substitutes free type variables.""" + + subst: Subst + + def __init__(self, subst: Subst) -> None: + self.subst = subst + + def transform(self, ty: GuppyType) -> GuppyType | None: + if isinstance(ty, ExistentialTypeVar): + return self.subst.get(ty, None) + return None + + +class Instantiator(TypeTransformer): + """Type transformer that instantiates bound type variables.""" + + tys: Sequence[GuppyType] + + def __init__(self, tys: Sequence[GuppyType]) -> None: + self.tys = tys + + def transform(self, ty: GuppyType) -> GuppyType | None: + if isinstance(ty, BoundTypeVar): + # Instantiate if type for the index is available + if ty.idx < len(self.tys): + return self.tys[ty.idx] + + # Otherwise, lower the de Bruijn index + return BoundTypeVar(ty.idx - len(self.tys), ty.display_name, ty.linear) + return None + + +def unify(s: GuppyType, t: GuppyType, subst: Subst | None) -> Subst | None: + """Computes a most general unifier for two types. + + Return a substitutions `subst` such that `s[subst] == t[subst]` or `None` if this + not possible. + """ + if subst is None: + return None + if s == t: + return subst + if isinstance(s, ExistentialTypeVar): + return _unify_var(s, t, subst) + if isinstance(t, ExistentialTypeVar): + return _unify_var(t, s, subst) + if type(s) == type(t): + sargs, targs = list(s.type_args), list(t.type_args) + if len(sargs) == len(targs): + for sa, ta in zip(sargs, targs): + subst = unify(sa, ta, subst) + return subst return None -def type_from_ast(node: AstNode, globals: "Globals") -> GuppyType: +def _unify_var(var: ExistentialTypeVar, t: GuppyType, subst: Subst) -> Subst | None: + """Helper function for unification of type variables.""" + if var in subst: + return unify(subst[var], t, subst) + if isinstance(t, ExistentialTypeVar) and t in subst: + return unify(var, subst[t], subst) + if var in t.unsolved_vars: + return None + return {var: t, **subst} + + +def type_from_ast( + node: AstNode, + globals: "Globals", + type_var_mapping: dict[str, BoundTypeVar] | None = None, +) -> GuppyType: """Turns an AST expression into a Guppy type.""" - if isinstance(node, ast.Name) and (ty := _lookup_type(node, globals)): - return ty.build(node=node) - if isinstance(node, ast.Constant) and (ty := _lookup_type(node, globals)): - name = ast.Name(id=node.value) - set_location_from(name, node) - return ty.build(node=name) - if isinstance(node, ast.Subscript) and (ty := _lookup_type(node.value, globals)): - args = node.slice.elts if isinstance(node.slice, ast.Tuple) else [node.slice] - return ty.build(*(type_from_ast(a, globals) for a in args), node=node) + from guppy.error import GuppyError + + if isinstance(node, ast.Name): + x = node.id + if x in globals.types: + return globals.types[x].build(node=node) + if x in globals.type_vars: + if type_var_mapping is None: + raise GuppyError( + "Free type variable. Only function types can be generic", node + ) + var_decl = globals.type_vars[x] + if var_decl.name not in type_var_mapping: + type_var_mapping[var_decl.name] = BoundTypeVar( + len(type_var_mapping), var_decl.name, var_decl.linear + ) + return type_var_mapping[var_decl.name] + raise GuppyError("Unknown type", node) + + if isinstance(node, ast.Constant): + v = node.value + if v is None: + return NoneType() + if isinstance(v, str): + try: + return type_from_ast(ast.parse(v), globals, type_var_mapping) + except SyntaxError: + raise GuppyError("Invalid Guppy type", node) from None + raise GuppyError(f"Constant `{v}` is not a valid type", node) + if isinstance(node, ast.Tuple): - return TupleType([type_from_ast(el, globals) for el in node.elts]) + return TupleType( + [type_from_ast(el, globals, type_var_mapping) for el in node.elts] + ) + if ( isinstance(node, ast.Subscript) and isinstance(node.value, ast.Name) @@ -196,13 +501,23 @@ def type_from_ast(node: AstNode, globals: "Globals") -> GuppyType: and isinstance(node.slice, ast.Tuple) and len(node.slice.elts) == 2 ): + # TODO: Do we want to allow polymorphic Callable types? [func_args, ret] = node.slice.elts if isinstance(func_args, ast.List): return FunctionType( - [type_from_ast(a, globals) for a in func_args.elts], - type_from_ast(ret, globals), + [type_from_ast(a, globals, type_var_mapping) for a in func_args.elts], + type_from_ast(ret, globals, type_var_mapping), + ) + + if isinstance(node, ast.Subscript) and isinstance(node.value, ast.Name): + x = node.value.id + if x in globals.types: + args = ( + node.slice.elts if isinstance(node.slice, ast.Tuple) else [node.slice] + ) + return globals.types[x].build( + *(type_from_ast(a, globals, type_var_mapping) for a in args), node=node ) - from guppy.error import GuppyError raise GuppyError("Not a valid Guppy type", node) @@ -234,8 +549,8 @@ def row_to_type(row: Sequence[GuppyType]) -> GuppyType: def type_to_row(ty: GuppyType) -> Sequence[GuppyType]: """Turns a type into a row of types by unpacking top-level tuples.""" - if isinstance(ty, NoneType): + if isinstance(ty, NoneType) and not ty.preserve: return [] - if isinstance(ty, TupleType): + if isinstance(ty, TupleType) and not ty.preserve: return ty.element_types return [ty] diff --git a/guppy/hugr/hugr.py b/guppy/hugr/hugr.py index 89532fe7..1e63d680 100644 --- a/guppy/hugr/hugr.py +++ b/guppy/hugr/hugr.py @@ -12,12 +12,13 @@ from guppy.gtypes import ( FunctionType, GuppyType, + Inst, SumType, TupleType, row_to_type, type_to_row, ) -from guppy.hugr import val +from guppy.hugr import tys, val NodeIdx = int PortOffset = int @@ -527,6 +528,25 @@ def add_partial( ops.DummyOp(name="partial"), None, [new_ty], parent, [*args, def_port] ) + def add_type_apply( + self, func_port: OutPortV, args: Inst, parent: Node | None = None + ) -> VNode: + """Adds a `TypeApply` node to the graph.""" + assert isinstance(func_port.ty, FunctionType) + assert len(func_port.ty.quantified) == len(args) + result_ty = func_port.ty.instantiate(args) + ta = ops.TypeApplication( + input=func_port.ty.to_hugr(), + args=[tys.TypeArg(ty=ty.to_hugr()) for ty in args], + output=result_ty.to_hugr(), + ) + return self.add_node( + ops.TypeApply(ta=ta), + inputs=[func_port], + output_types=[result_ty], + parent=parent, + ) + def add_def( self, fun_ty: FunctionType, parent: Node | None, name: str ) -> DFContainingVNode: diff --git a/guppy/hugr/ops.py b/guppy/hugr/ops.py index 01cbc255..02b393e3 100644 --- a/guppy/hugr/ops.py +++ b/guppy/hugr/ops.py @@ -476,6 +476,26 @@ class Tag(LeafOp): variants: TypeRow # The variants of the sum type. +class TypeApply(LeafOp): + """Fixes some TypeParams of a polymorphic type by providing TypeArgs""" + + lop: Literal["TypeApply"] = "TypeApply" + ta: "TypeApplication" + + +class TypeApplication(BaseModel): + """Records details of an application of a PolyFuncType to some TypeArgs and the + result (a less-, but still potentially-, polymorphic type). + + Note that Guppy only generates full type applications, where the result is a + monomorphic type. Partial type applications are not used by Guppy. + """ + + input: PolyFuncType + args: list[tys.TypeArg] + output: PolyFuncType + + LeafOpUnion = Annotated[ ( CustomOp @@ -498,6 +518,7 @@ class Tag(LeafOp): | UnpackTuple | MakeNewType | Tag + | TypeApply ), Field(discriminator="lop"), ] diff --git a/guppy/hugr/tys.py b/guppy/hugr/tys.py index fa01db41..9098e552 100644 --- a/guppy/hugr/tys.py +++ b/guppy/hugr/tys.py @@ -144,10 +144,11 @@ class GeneralSum(Sum): class Variable(BaseModel): - """A type variable identified by a name.""" + """A type variable identified by a de Bruijn index.""" - t: Literal["Var"] = "Var" - name: str + t: Literal["V"] = "V" + i: int + b: "TypeBound" class Int(BaseModel): @@ -208,6 +209,10 @@ class TypeBound(Enum): Copyable = "C" Any = "A" + @staticmethod + def from_linear(linear: bool) -> "TypeBound": + return TypeBound.Any if linear else TypeBound.Copyable + class Opaque(BaseModel): """An opaque operation that can be downcasted by the extensions that define it.""" diff --git a/guppy/module.py b/guppy/module.py index 3061bafe..09fe2c42 100644 --- a/guppy/module.py +++ b/guppy/module.py @@ -6,7 +6,7 @@ from typing import Any, Union from guppy.ast_util import AstNode, annotate_location -from guppy.checker.core import Globals, qualified_name +from guppy.checker.core import Globals, TypeVarDecl, qualified_name from guppy.checker.func_checker import DefinedFunction, check_global_func_def from guppy.compiler.core import CompiledGlobals from guppy.compiler.func_compiler import CompiledFunctionDef, compile_global_func_def @@ -48,7 +48,7 @@ class GuppyModule: def __init__(self, name: str, import_builtins: bool = True): self.name = name - self._globals = Globals({}, {}) + self._globals = Globals({}, {}, {}) self._compiled_globals = {} self._imported_globals = Globals.default() self._imported_compiled_globals = {} @@ -124,8 +124,16 @@ def register_custom_func( def register_type(self, name: str, ty: type[GuppyType]) -> None: """Registers an existing Guppy type as belonging to this Guppy module.""" + self._check_not_yet_compiled() + self._check_type_name_available(name, None) self._globals.types[name] = ty + def register_type_var(self, name: str, linear: bool) -> None: + """Registers a new type variable""" + self._check_not_yet_compiled() + self._check_type_name_available(name, None) + self._globals.type_vars[name] = TypeVarDecl(name, linear) + def _register_buffered_instance_funcs(self, instance: type[GuppyType]) -> None: assert self._instance_func_buffer is not None buffer = self._instance_func_buffer @@ -207,6 +215,19 @@ def _check_name_available(self, name: str, node: AstNode | None) -> None: node, ) + def _check_type_name_available(self, name: str, node: AstNode | None) -> None: + if name in self._globals.types: + raise GuppyError( + f"Module `{self.name}` already contains a type `{name}`", + node, + ) + + if name in self._globals.type_vars: + raise GuppyError( + f"Module `{self.name}` already contains a type variable `{name}`", + node, + ) + def parse_py_func(f: PyFunc) -> ast.FunctionDef: source_lines, line_offset = inspect.getsourcelines(f) diff --git a/guppy/nodes.py b/guppy/nodes.py index 22730db5..a08d30da 100644 --- a/guppy/nodes.py +++ b/guppy/nodes.py @@ -1,10 +1,10 @@ """Custom AST nodes used by Guppy""" import ast -from collections.abc import Mapping +from collections.abc import Mapping, Sequence from typing import TYPE_CHECKING, Any -from guppy.gtypes import FunctionType +from guppy.gtypes import FunctionType, GuppyType, Inst if TYPE_CHECKING: from guppy.cfg.cfg import CFG @@ -41,12 +41,22 @@ class LocalCall(ast.expr): class GlobalCall(ast.expr): func: "CallableVariable" args: list[ast.expr] - - # Later: Inferred type args + type_args: Inst # Inferred type arguments _fields = ( "func", "args", + "type_args", + ) + + +class TypeApply(ast.expr): + value: ast.expr + tys: Sequence[GuppyType] + + _fields = ( + "value", + "tys", ) diff --git a/guppy/prelude/_internal.py b/guppy/prelude/_internal.py index 6262b1e8..9f52593c 100644 --- a/guppy/prelude/_internal.py +++ b/guppy/prelude/_internal.py @@ -13,7 +13,7 @@ DefaultCallChecker, ) from guppy.error import GuppyError, GuppyTypeError -from guppy.gtypes import BoolType, FunctionType, GuppyType +from guppy.gtypes import BoolType, FunctionType, GuppyType, Subst, unify from guppy.hugr import ops, tys, val from guppy.hugr.hugr import OutPortV from guppy.nodes import GlobalCall @@ -98,7 +98,8 @@ def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, GuppyType]: args[i], ty = ExprSynthesizer(self.ctx).synthesize(args[i]) if isinstance(ty, self.ctx.globals.types["int"]): call = with_loc( - self.node, GlobalCall(func=Int.__float__, args=[args[i]]) + self.node, + GlobalCall(func=Int.__float__, args=[args[i]], type_args=[]), ) args[i] = with_type(self.ctx.globals.types["float"].build(), call) return super().synthesize(args) @@ -116,11 +117,11 @@ def _setup(self, ctx: Context, node: AstNode, func: CustomFunction) -> None: super()._setup(ctx, node, func) self.base_checker._setup(ctx, node, func) - def check(self, args: list[ast.expr], ty: GuppyType) -> ast.expr: - expr = self.base_checker.check(args, ty) + def check(self, args: list[ast.expr], ty: GuppyType) -> tuple[ast.expr, Subst]: + expr, subst = self.base_checker.check(args, ty) if isinstance(expr, GlobalCall): expr.args = list(reversed(args)) - return expr + return expr, subst def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, GuppyType]: expr, ty = self.base_checker.synthesize(args) @@ -140,7 +141,7 @@ def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, GuppyType]: f"Builtin method `{self.func.name}` is not supported by Guppy", self.node ) - def check(self, args: list[ast.expr], ty: GuppyType) -> ast.expr: + def check(self, args: list[ast.expr], ty: GuppyType) -> tuple[ast.expr, Subst]: raise GuppyError( f"Builtin method `{self.func.name}` is not supported by Guppy", self.node ) @@ -176,7 +177,7 @@ def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, GuppyType]: args, func = self._get_func(args) return func.synthesize_call(args, self.node, self.ctx) - def check(self, args: list[ast.expr], ty: GuppyType) -> ast.expr: + def check(self, args: list[ast.expr], ty: GuppyType) -> tuple[ast.expr, Subst]: args, func = self._get_func(args) return func.check_call(args, ty, self.node, self.ctx) @@ -195,13 +196,14 @@ def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, GuppyType]: const = with_loc(self.node, ast.Constant(value=is_callable)) return const, BoolType() - def check(self, args: list[ast.expr], ty: GuppyType) -> ast.expr: + def check(self, args: list[ast.expr], ty: GuppyType) -> tuple[ast.expr, Subst]: args, _ = self.synthesize(args) - if not isinstance(ty, BoolType): + subst = unify(ty, BoolType(), {}) + if subst is None: raise GuppyTypeError( f"Expected expression of type `{ty}`, got `bool`", self.node ) - return args + return args, subst class IntTruedivCompiler(CustomCallCompiler): @@ -213,13 +215,13 @@ def compile(self, args: list[OutPortV]) -> list[OutPortV]: # Compile `truediv` using float arithmetic [left, right] = args [left] = Int.__float__.compile_call( - [left], self.dfg, self.graph, self.globals, self.node + [left], [], self.dfg, self.graph, self.globals, self.node ) [right] = Int.__float__.compile_call( - [right], self.dfg, self.graph, self.globals, self.node + [right], [], self.dfg, self.graph, self.globals, self.node ) return Float.__truediv__.compile_call( - [left, right], self.dfg, self.graph, self.globals, self.node + [left, right], [], self.dfg, self.graph, self.globals, self.node ) @@ -235,7 +237,12 @@ def compile(self, args: list[OutPortV]) -> list[OutPortV]: ) zero = self.graph.add_load_constant(zero_const.out_port(0), self.dfg.node) return Float.__ne__.compile_call( - [args[0], zero.out_port(0)], self.dfg, self.graph, self.globals, self.node + [args[0], zero.out_port(0)], + [], + self.dfg, + self.graph, + self.globals, + self.node, ) @@ -247,10 +254,10 @@ def compile(self, args: list[OutPortV]) -> list[OutPortV]: # We have: floordiv(x, y) = floor(truediv(x, y)) [div] = Float.__truediv__.compile_call( - args, self.dfg, self.graph, self.globals, self.node + args, [], self.dfg, self.graph, self.globals, self.node ) [floor] = Float.__floor__.compile_call( - [div], self.dfg, self.graph, self.globals, self.node + [div], [], self.dfg, self.graph, self.globals, self.node ) return [floor] @@ -263,13 +270,13 @@ def compile(self, args: list[OutPortV]) -> list[OutPortV]: # We have: mod(x, y) = x - (x // y) * y [div] = Float.__floordiv__.compile_call( - args, self.dfg, self.graph, self.globals, self.node + args, [], self.dfg, self.graph, self.globals, self.node ) [mul] = Float.__mul__.compile_call( - [div, args[1]], self.dfg, self.graph, self.globals, self.node + [div, args[1]], [], self.dfg, self.graph, self.globals, self.node ) [sub] = Float.__sub__.compile_call( - [args[0], mul], self.dfg, self.graph, self.globals, self.node + [args[0], mul], [], self.dfg, self.graph, self.globals, self.node ) return [sub] @@ -282,9 +289,9 @@ def compile(self, args: list[OutPortV]) -> list[OutPortV]: # We have: divmod(x, y) = (div(x, y), mod(x, y)) [div] = Float.__truediv__.compile_call( - args, self.dfg, self.graph, self.globals, self.node + args, [], self.dfg, self.graph, self.globals, self.node ) [mod] = Float.__mod__.compile_call( - args, self.dfg, self.graph, self.globals, self.node + args, [], self.dfg, self.graph, self.globals, self.node ) return [self.graph.add_make_tuple([div, mod], self.dfg.node).out_port(0)] diff --git a/tests/error/misc_errors/return_not_annotated.err b/tests/error/misc_errors/return_not_annotated.err index 0c2eb656..dcce68ed 100644 --- a/tests/error/misc_errors/return_not_annotated.err +++ b/tests/error/misc_errors/return_not_annotated.err @@ -2,5 +2,5 @@ Guppy compilation failed. Error in file $FILE:5 3: @guppy 4: def foo(x: bool): - ^^^^^^^^^^^^ + ^^^^^^^^^^^^^^^^^ GuppyError: Return type must be annotated diff --git a/tests/error/misc_errors/return_not_annotated_none2.err b/tests/error/misc_errors/return_not_annotated_none2.err index cebd15a9..58d79a43 100644 --- a/tests/error/misc_errors/return_not_annotated_none2.err +++ b/tests/error/misc_errors/return_not_annotated_none2.err @@ -2,5 +2,5 @@ Guppy compilation failed. Error in file $FILE:5 3: @guppy 4: def foo(): - ^^^^^^^^^^^^^^^^ + ^^^^^^^^^^ GuppyError: Return type must be annotated. Try adding a `-> None` annotation. diff --git a/tests/error/poly_errors/__init__.py b/tests/error/poly_errors/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/error/poly_errors/arg_mismatch1.err b/tests/error/poly_errors/arg_mismatch1.err new file mode 100644 index 00000000..211407eb --- /dev/null +++ b/tests/error/poly_errors/arg_mismatch1.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:17 + +15: @guppy(module) +16: def main(x: bool, y: tuple[bool]) -> None: +17: foo(x, y) + ^ +GuppyTypeError: Expected argument of type `bool`, got `(bool)` diff --git a/tests/error/poly_errors/arg_mismatch1.py b/tests/error/poly_errors/arg_mismatch1.py new file mode 100644 index 00000000..0e631f22 --- /dev/null +++ b/tests/error/poly_errors/arg_mismatch1.py @@ -0,0 +1,20 @@ +from guppy.decorator import guppy +from guppy.module import GuppyModule + + +module = GuppyModule("test") + +T = guppy.type_var(module, "T") + + +@guppy.declare(module) +def foo(x: T, y: T) -> None: + ... + + +@guppy(module) +def main(x: bool, y: tuple[bool]) -> None: + foo(x, y) + + +module.compile() diff --git a/tests/error/poly_errors/arg_mismatch2.err b/tests/error/poly_errors/arg_mismatch2.err new file mode 100644 index 00000000..13b54199 --- /dev/null +++ b/tests/error/poly_errors/arg_mismatch2.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:17 + +15: @guppy(module) +16: def main() -> None: +17: foo(False) + ^^^^^ +GuppyTypeError: Expected argument of type `(?T, ?T)`, got `bool` diff --git a/tests/error/poly_errors/arg_mismatch2.py b/tests/error/poly_errors/arg_mismatch2.py new file mode 100644 index 00000000..78f82b0c --- /dev/null +++ b/tests/error/poly_errors/arg_mismatch2.py @@ -0,0 +1,20 @@ +from guppy.decorator import guppy +from guppy.module import GuppyModule + + +module = GuppyModule("test") + +T = guppy.type_var(module, "T") + + +@guppy.declare(module) +def foo(x: tuple[T, T]) -> None: + ... + + +@guppy(module) +def main() -> None: + foo(False) + + +module.compile() diff --git a/tests/error/poly_errors/define.err b/tests/error/poly_errors/define.err new file mode 100644 index 00000000..d2500379 --- /dev/null +++ b/tests/error/poly_errors/define.err @@ -0,0 +1,6 @@ +Guppy compilation failed. Error in file $FILE:11 + +9: @guppy(module) +10: def main(x: T) -> T: + ^^^^^^^^^^^^^^^^^^^^ +GuppyError: Generic function definitions are not supported yet diff --git a/tests/error/poly_errors/define.py b/tests/error/poly_errors/define.py new file mode 100644 index 00000000..433c71d9 --- /dev/null +++ b/tests/error/poly_errors/define.py @@ -0,0 +1,15 @@ +from guppy.decorator import guppy +from guppy.module import GuppyModule + + +module = GuppyModule("test") + +T = guppy.type_var(module, "T") + + +@guppy(module) +def main(x: T) -> T: + return x + + +module.compile() diff --git a/tests/error/poly_errors/free_return_var.err b/tests/error/poly_errors/free_return_var.err new file mode 100644 index 00000000..bd5e522c --- /dev/null +++ b/tests/error/poly_errors/free_return_var.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:17 + +15: @guppy(module) +16: def main() -> None: +17: x = foo() + ^^^^^ +GuppyTypeInferenceError: Cannot infer type variable in expression of type `?T` diff --git a/tests/error/poly_errors/free_return_var.py b/tests/error/poly_errors/free_return_var.py new file mode 100644 index 00000000..9ac9299e --- /dev/null +++ b/tests/error/poly_errors/free_return_var.py @@ -0,0 +1,20 @@ +from guppy.decorator import guppy +from guppy.module import GuppyModule + + +module = GuppyModule("test") + +T = guppy.type_var(module, "T") + + +@guppy.declare(module) +def foo() -> T: + ... + + +@guppy(module) +def main() -> None: + x = foo() + + +module.compile() diff --git a/tests/error/poly_errors/inst_return_mismatch.err b/tests/error/poly_errors/inst_return_mismatch.err new file mode 100644 index 00000000..90683da3 --- /dev/null +++ b/tests/error/poly_errors/inst_return_mismatch.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:17 + +15: @guppy(module) +16: def main(x: bool) -> None: +17: y: None = foo(x) + ^^^^^^ +GuppyTypeError: Expected expression of type `None`, got `bool` diff --git a/tests/error/poly_errors/inst_return_mismatch.py b/tests/error/poly_errors/inst_return_mismatch.py new file mode 100644 index 00000000..885de5fa --- /dev/null +++ b/tests/error/poly_errors/inst_return_mismatch.py @@ -0,0 +1,20 @@ +from guppy.decorator import guppy +from guppy.module import GuppyModule + + +module = GuppyModule("test") + +T = guppy.type_var(module, "T") + + +@guppy.declare(module) +def foo(x: T) -> T: + ... + + +@guppy(module) +def main(x: bool) -> None: + y: None = foo(x) + + +module.compile() diff --git a/tests/error/poly_errors/inst_return_mismatch_nested.err b/tests/error/poly_errors/inst_return_mismatch_nested.err new file mode 100644 index 00000000..1c14fb58 --- /dev/null +++ b/tests/error/poly_errors/inst_return_mismatch_nested.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:17 + +15: @guppy(module) +16: def main(x: bool) -> None: +17: y: None = foo(foo(foo(x))) + ^^^^^^^^^^^^^^^^ +GuppyTypeError: Expected expression of type `None`, got `bool` diff --git a/tests/error/poly_errors/inst_return_mismatch_nested.py b/tests/error/poly_errors/inst_return_mismatch_nested.py new file mode 100644 index 00000000..b308f31f --- /dev/null +++ b/tests/error/poly_errors/inst_return_mismatch_nested.py @@ -0,0 +1,20 @@ +from guppy.decorator import guppy +from guppy.module import GuppyModule + + +module = GuppyModule("test") + +T = guppy.type_var(module, "T") + + +@guppy.declare(module) +def foo(x: T) -> T: + ... + + +@guppy(module) +def main(x: bool) -> None: + y: None = foo(foo(foo(x))) + + +module.compile() diff --git a/tests/error/poly_errors/non_linear1.err b/tests/error/poly_errors/non_linear1.err new file mode 100644 index 00000000..a6a2ad1d --- /dev/null +++ b/tests/error/poly_errors/non_linear1.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:21 + +19: @guppy(module) +20: def main(q: Qubit) -> None: +21: foo(q) + ^^^^^^ +GuppyTypeError: Cannot instantiate non-linear type variable `T` in type `forall T. T -> None` with linear type `Qubit` diff --git a/tests/error/poly_errors/non_linear1.py b/tests/error/poly_errors/non_linear1.py new file mode 100644 index 00000000..728dccd7 --- /dev/null +++ b/tests/error/poly_errors/non_linear1.py @@ -0,0 +1,24 @@ +from guppy.decorator import guppy +from guppy.module import GuppyModule +from guppy.prelude.quantum import Qubit + +import guppy.prelude.quantum as quantum + +module = GuppyModule("test") +module.load(quantum) + + +T = guppy.type_var(module, "T") + + +@guppy.declare(module) +def foo(x: T) -> None: + ... + + +@guppy(module) +def main(q: Qubit) -> None: + foo(q) + + +module.compile() diff --git a/tests/error/poly_errors/non_linear2.err b/tests/error/poly_errors/non_linear2.err new file mode 100644 index 00000000..63c18f16 --- /dev/null +++ b/tests/error/poly_errors/non_linear2.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:23 + +21: @guppy(module) +22: def main() -> None: +23: foo(h) + ^^^^^^ +GuppyTypeError: Cannot instantiate non-linear type variable `T` in type `forall T. T -> T -> None` with linear type `Qubit` diff --git a/tests/error/poly_errors/non_linear2.py b/tests/error/poly_errors/non_linear2.py new file mode 100644 index 00000000..b46238c3 --- /dev/null +++ b/tests/error/poly_errors/non_linear2.py @@ -0,0 +1,26 @@ +from typing import Callable + +from guppy.decorator import guppy +from guppy.module import GuppyModule +from guppy.prelude.quantum import h + +import guppy.prelude.quantum as quantum + +module = GuppyModule("test") +module.load(quantum) + + +T = guppy.type_var(module, "T") + + +@guppy.declare(module) +def foo(x: Callable[[T], T]) -> None: + ... + + +@guppy(module) +def main() -> None: + foo(h) + + +module.compile() diff --git a/tests/error/poly_errors/pass_poly_free.err b/tests/error/poly_errors/pass_poly_free.err new file mode 100644 index 00000000..5154e2df --- /dev/null +++ b/tests/error/poly_errors/pass_poly_free.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:24 + +22: @guppy(module) +23: def main() -> None: +24: foo(bar) + ^^^ +GuppyTypeInferenceError: Expected argument of type `?T -> ?T`, got `forall T. T -> T`. Couldn't infer an instantiation for type variable `T` (higher-rank polymorphic types are not supported) diff --git a/tests/error/poly_errors/pass_poly_free.py b/tests/error/poly_errors/pass_poly_free.py new file mode 100644 index 00000000..d88a6b83 --- /dev/null +++ b/tests/error/poly_errors/pass_poly_free.py @@ -0,0 +1,27 @@ +from typing import Callable + +from guppy.decorator import guppy +from guppy.module import GuppyModule + + +module = GuppyModule("test") + +T = guppy.type_var(module, "T") + + +@guppy.declare(module) +def foo(f: Callable[[T], T]) -> None: + ... + + +@guppy.declare(module) +def bar(x: T) -> T: + ... + + +@guppy(module) +def main() -> None: + foo(bar) + + +module.compile() diff --git a/tests/error/poly_errors/return_mismatch.err b/tests/error/poly_errors/return_mismatch.err new file mode 100644 index 00000000..4d3029b2 --- /dev/null +++ b/tests/error/poly_errors/return_mismatch.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:17 + +15: @guppy(module) +16: def main() -> None: +17: x: bool = foo() + ^^^^^ +GuppyTypeError: Expected expression of type `bool`, got `(?T, ?T)` diff --git a/tests/error/poly_errors/return_mismatch.py b/tests/error/poly_errors/return_mismatch.py new file mode 100644 index 00000000..7be60db6 --- /dev/null +++ b/tests/error/poly_errors/return_mismatch.py @@ -0,0 +1,20 @@ +from guppy.decorator import guppy +from guppy.module import GuppyModule + + +module = GuppyModule("test") + +T = guppy.type_var(module, "T") + + +@guppy.declare(module) +def foo() -> tuple[T, T]: + ... + + +@guppy(module) +def main() -> None: + x: bool = foo() + + +module.compile() diff --git a/tests/error/poly_errors/right_to_left.err b/tests/error/poly_errors/right_to_left.err new file mode 100644 index 00000000..b488be7d --- /dev/null +++ b/tests/error/poly_errors/right_to_left.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:22 + +20: @guppy(module) +21: def main() -> None: +22: bar(foo(), 42) + ^^^^^ +GuppyTypeInferenceError: Cannot infer type variable in expression of type `?T` diff --git a/tests/error/poly_errors/right_to_left.py b/tests/error/poly_errors/right_to_left.py new file mode 100644 index 00000000..1fec068e --- /dev/null +++ b/tests/error/poly_errors/right_to_left.py @@ -0,0 +1,25 @@ +from guppy.decorator import guppy +from guppy.module import GuppyModule + + +module = GuppyModule("test") + +T = guppy.type_var(module, "T") + + +@guppy.declare(module) +def foo() -> T: + ... + + +@guppy.declare(module) +def bar(x: T, y: T) -> None: + ... + + +@guppy(module) +def main() -> None: + bar(foo(), 42) + + +module.compile() diff --git a/tests/error/test_poly_errors.py b/tests/error/test_poly_errors.py new file mode 100644 index 00000000..a140769f --- /dev/null +++ b/tests/error/test_poly_errors.py @@ -0,0 +1,15 @@ +import pathlib +import pytest + +from tests.error.util import run_error_test + +path = pathlib.Path(__file__).parent.resolve() / "poly_errors" +files = [x for x in path.iterdir() if x.is_file() if x.suffix == ".py" and x.name != "__init__.py"] + +# Turn paths into strings, otherwise pytest doesn't display the names +files = [str(f) for f in files] + + +@pytest.mark.parametrize("file", files) +def test_type_errors(file, capsys): + run_error_test(file, capsys) diff --git a/tests/hugr/test_dummy_nodes.py b/tests/hugr/test_dummy_nodes.py index e33537cd..d2729326 100644 --- a/tests/hugr/test_dummy_nodes.py +++ b/tests/hugr/test_dummy_nodes.py @@ -5,7 +5,7 @@ def test_single_dummy(): g = Hugr() - defn = g.add_def(FunctionType([BoolType()], [BoolType()]), g.root, "test") + defn = g.add_def(FunctionType([BoolType()], BoolType()), g.root, "test") dfg = g.add_dfg(defn) inp = g.add_input([BoolType()], dfg).out_port(0) dummy = g.add_node( diff --git a/tests/integration/test_poly.py b/tests/integration/test_poly.py new file mode 100644 index 00000000..2446e308 --- /dev/null +++ b/tests/integration/test_poly.py @@ -0,0 +1,284 @@ +from collections.abc import Callable + +from guppy.decorator import guppy +from guppy.module import GuppyModule +from guppy.prelude.quantum import Qubit + +import guppy.prelude.quantum as quantum + + +def test_id(validate): + module = GuppyModule("test") + T = guppy.type_var(module, "T") + + @guppy.declare(module) + def foo(x: T) -> T: + ... + + @guppy(module) + def main(x: int) -> int: + return foo(x) + + validate(module.compile()) + + +def test_id_nested(validate): + module = GuppyModule("test") + T = guppy.type_var(module, "T") + + @guppy.declare(module) + def foo(x: T) -> T: + ... + + @guppy(module) + def main(x: int) -> int: + return foo(foo(foo(x))) + + validate(module.compile()) + + +def test_use_twice(validate): + module = GuppyModule("test") + T = guppy.type_var(module, "T") + + @guppy.declare(module) + def foo(x: T) -> T: + ... + + @guppy(module) + def main(x: int, y: bool) -> None: + foo(x) + foo(y) + + validate(module.compile()) + + +def test_define_twice(validate): + module = GuppyModule("test") + T = guppy.type_var(module, "T") + + @guppy.declare(module) + def foo(x: T) -> T: + ... + + @guppy.declare(module) + def bar(x: T) -> T: # Reuse same type var! + ... + + @guppy(module) + def main(x: bool, y: float) -> None: + foo(x) + foo(y) + + validate(module.compile()) + + +def test_return_tuple_implicit(validate): + module = GuppyModule("test") + T = guppy.type_var(module, "T") + + @guppy.declare(module) + def foo(x: T) -> T: + ... + + @guppy(module) + def main(x: int) -> tuple[int, int]: + return foo((x, 0)) + + validate(module.compile()) + + +def test_same_args(validate): + module = GuppyModule("test") + T = guppy.type_var(module, "T") + + @guppy.declare(module) + def foo(x: T, y: T) -> None: + ... + + @guppy(module) + def main(x: int) -> None: + foo(x, 42) + + validate(module.compile()) + + +def test_different_args(validate): + module = GuppyModule("test") + S = guppy.type_var(module, "S") + T = guppy.type_var(module, "T") + + @guppy.declare(module) + def foo(x: S, y: T, z: tuple[S, T]) -> T: + ... + + @guppy(module) + def main(x: int, y: float) -> float: + return foo(x, y, (x, y)) + foo(y, 42.0, (0.0, y)) + + validate(module.compile()) + + +def test_infer_basic(validate): + module = GuppyModule("test") + T = guppy.type_var(module, "T") + + @guppy.declare(module) + def foo() -> T: + ... + + @guppy(module) + def main() -> None: + x: int = foo() + + validate(module.compile()) + + +def test_infer_nested(validate): + module = GuppyModule("test") + T = guppy.type_var(module, "T") + + @guppy.declare(module) + def foo() -> T: + ... + + @guppy.declare(module) + def bar(x: T) -> T: + ... + + @guppy(module) + def main() -> None: + x: int = bar(foo()) + + validate(module.compile()) + + +def test_infer_left_to_right(validate): + module = GuppyModule("test") + S = guppy.type_var(module, "S") + T = guppy.type_var(module, "T") + + @guppy.declare(module) + def foo() -> T: + ... + + @guppy.declare(module) + def bar(x: T, y: T, z: S, a: tuple[T, S]) -> None: + ... + + @guppy(module) + def main() -> None: + bar(42, foo(), False, foo()) + + validate(module.compile()) + + +def test_pass_poly_basic(validate): + module = GuppyModule("test") + T = guppy.type_var(module, "T") + + @guppy.declare(module) + def foo(f: Callable[[T], T]) -> None: + ... + + @guppy.declare(module) + def bar(x: int) -> int: + ... + + @guppy(module) + def main() -> None: + foo(bar) + + validate(module.compile()) + + +def test_pass_poly_cross(validate): + module = GuppyModule("test") + S = guppy.type_var(module, "S") + T = guppy.type_var(module, "T") + + @guppy.declare(module) + def foo(f: Callable[[S], int]) -> None: + ... + + @guppy.declare(module) + def bar(x: bool) -> T: + ... + + @guppy(module) + def main() -> None: + foo(bar) + + validate(module.compile()) + + +def test_linear(validate): + module = GuppyModule("test") + module.load(quantum) + T = guppy.type_var(module, "T", linear=True) + + @guppy.declare(module) + def foo(x: T) -> T: + ... + + @guppy(module) + def main(q: Qubit) -> Qubit: + return foo(q) + + validate(module.compile()) + + +def test_pass_nonlinear(validate): + module = GuppyModule("test") + module.load(quantum) + T = guppy.type_var(module, "T", linear=True) + + @guppy.declare(module) + def foo(x: T) -> T: + ... + + @guppy(module) + def main(x: int) -> None: + foo(x) + + validate(module.compile()) + + +def test_pass_linear(validate): + module = GuppyModule("test") + module.load(quantum) + T = guppy.type_var(module, "T", linear=True) + + @guppy.declare(module) + def foo(f: Callable[[T], T]) -> None: + ... + + @guppy.declare(module) + def bar(q: Qubit) -> Qubit: + ... + + @guppy(module) + def main() -> None: + foo(bar) + + validate(module.compile()) + + +def test_higher_order_value(validate): + module = GuppyModule("test") + T = guppy.type_var(module, "T") + + @guppy.declare(module) + def foo(x: T) -> T: + ... + + @guppy.declare(module) + def bar(x: T) -> T: + ... + + @guppy(module) + def main(b: bool) -> int: + f = foo if b else bar + return f(42) + + validate(module.compile())