diff --git a/.github/workflows/pull-request.yaml b/.github/workflows/pull-request.yaml index 52f55df8..4ae52902 100644 --- a/.github/workflows/pull-request.yaml +++ b/.github/workflows/pull-request.yaml @@ -9,7 +9,7 @@ jobs: strategy: matrix: - python-version: [3.9] + python-version: ['3.10'] steps: - uses: actions/checkout@v3 diff --git a/guppy/cfg/bb.py b/guppy/cfg/bb.py index f5831dbf..95c9f63a 100644 --- a/guppy/cfg/bb.py +++ b/guppy/cfg/bb.py @@ -46,7 +46,7 @@ class BB(ABC): idx: int # Pointer to the CFG that contains this node - cfg: "BaseCFG[Self]" + containing_cfg: "BaseCFG[Self]" # AST statements contained in this BB statements: list[BBStatement] = field(default_factory=list) diff --git a/guppy/checker/cfg_checker.py b/guppy/checker/cfg_checker.py index 9cd0f41e..787b2ccb 100644 --- a/guppy/checker/cfg_checker.py +++ b/guppy/checker/cfg_checker.py @@ -1,11 +1,22 @@ +"""Type checking code for control-flow graphs + +Operates on CFGs produced by the `CFGBuilder`. Produces a `CheckedCFG` consisting of +`CheckedBB`s with inferred type signatures. +""" + +import collections from dataclasses import dataclass from typing import Sequence +from guppy.ast_util import line_col from guppy.cfg.bb import BB from guppy.cfg.cfg import CFG, BaseCFG -from guppy.checker.core import Globals +from guppy.checker.core import Globals, Context from guppy.checker.core import Variable +from guppy.checker.expr_checker import ExprSynthesizer, to_bool +from guppy.checker.stmt_checker import StmtChecker +from guppy.error import GuppyError from guppy.gtypes import GuppyType @@ -52,7 +63,53 @@ def check_cfg( Annotates the basic blocks with input and output type signatures and removes unreachable blocks. 
""" - raise NotImplementedError + # First, we need to run program analysis + ass_before = set(v.name for v in inputs) + cfg.analyze(ass_before, ass_before) + + # We start by compiling the entry BB + checked_cfg = CheckedCFG([v.ty for v in inputs], return_ty) + checked_cfg.entry_bb = check_bb( + cfg.entry_bb, checked_cfg, inputs, return_ty, globals + ) + compiled = {cfg.entry_bb: checked_cfg.entry_bb} + + # Visit all control-flow edges in BFS order. We can't just do a normal loop over + # all BBs since the input types for a BB are computed by checking a predecessor. + # We do BFS instead of DFS to get a better error ordering. + queue = collections.deque( + (checked_cfg.entry_bb, i, succ) + for i, succ in enumerate(cfg.entry_bb.successors) + ) + while len(queue) > 0: + pred, num_output, bb = queue.popleft() + input_row = [ + Variable(v.name, v.ty, v.defined_at, None) + for v in pred.sig.output_rows[num_output] + ] + + if bb in compiled: + # If the BB was already compiled, we just have to check that the signatures + # match. 
+ check_rows_match(input_row, compiled[bb].sig.input_row, bb) + else: + # Otherwise, check the BB and enqueue its successors + checked_bb = check_bb(bb, checked_cfg, input_row, return_ty, globals) + queue += [(checked_bb, i, succ) for i, succ in enumerate(bb.successors)] + compiled[bb] = checked_bb + + # Link up BBs in the checked CFG + compiled[bb].predecessors.append(pred) + pred.successors[num_output] = compiled[bb] + + checked_cfg.bbs = list(compiled.values()) + checked_cfg.exit_bb = compiled[cfg.exit_bb] # TODO: Fails if exit is unreachable + checked_cfg.live_before = {compiled[bb]: cfg.live_before[bb] for bb in cfg.bbs} + checked_cfg.ass_before = {compiled[bb]: cfg.ass_before[bb] for bb in cfg.bbs} + checked_cfg.maybe_ass_before = { + compiled[bb]: cfg.maybe_ass_before[bb] for bb in cfg.bbs + } + return checked_cfg def check_bb( @@ -62,4 +119,105 @@ def check_bb( return_ty: GuppyType, globals: Globals, ) -> CheckedBB: - raise NotImplementedError + cfg = bb.containing_cfg + + # For the entry BB we have to separately check that all used variables are + # defined. For all other BBs, this will be checked when compiling a predecessor. 
+ if bb == cfg.entry_bb: + assert len(bb.predecessors) == 0 + for x, use in bb.vars.used.items(): + if x not in cfg.ass_before[bb] and x not in globals.values: + raise GuppyError(f"Variable `{x}` is not defined", use) + + # Check the basic block + ctx = Context(globals, {v.name: v for v in inputs}) + checked_stmts = StmtChecker(ctx, bb, return_ty).check_stmts(bb.statements) + + # If we branch, we also have to check the branch predicate + if len(bb.successors) > 1: + assert bb.branch_pred is not None + bb.branch_pred, ty = ExprSynthesizer(ctx).synthesize(bb.branch_pred) + bb.branch_pred, _ = to_bool(bb.branch_pred, ty, ctx) + + for succ in bb.successors: + for x, use_bb in cfg.live_before[succ].items(): + # Check that the variables requested by the successor are defined + if x not in ctx.locals and x not in ctx.globals.values: + # If the variable is defined on *some* paths, we can give a more + # informative error message + if x in cfg.maybe_ass_before[use_bb]: + # TODO: This should be "Variable x is not defined when coming + # from {bb}". But for this we need a way to associate BBs with + # source locations. + raise GuppyError( + f"Variable `{x}` is not defined on all control-flow paths.", + use_bb.vars.used[x], + ) + raise GuppyError(f"Variable `{x}` is not defined", use_bb.vars.used[x]) + + # We have to check that used linear variables are not being outputted + if x in ctx.locals: + var = ctx.locals[x] + if var.ty.linear and var.used: + raise GuppyError( + f"Variable `{x}` with linear type `{var.ty}` was " + "already used (at {0})", + cfg.live_before[succ][x].vars.used[x], + [var.used], + ) + + # On the other hand, unused linear variables *must* be outputted + for x, var in ctx.locals.items(): + if var.ty.linear and not var.used and x not in cfg.live_before[succ]: + # TODO: This should be "Variable x with linear type ty is not + # used in {bb}". But for this we need a way to associate BBs with + # source locations. 
+ raise GuppyError( + f"Variable `{x}` with linear type `{var.ty}` is " + "not used on all control-flow paths", + var.defined_at, + ) + + # Finally, we need to compute the signature of the basic block + outputs = [ + [ctx.locals[x] for x in cfg.live_before[succ] if x in ctx.locals] + for succ in bb.successors + ] + + # Also prepare the successor list so we can fill it in later + checked_bb = CheckedBB( + bb.idx, checked_cfg, checked_stmts, sig=Signature(inputs, outputs) + ) + checked_bb.successors = [None] * len(bb.successors) # type: ignore[list-item] + checked_bb.branch_pred = bb.branch_pred + return checked_bb + + +def check_rows_match(row1: VarRow, row2: VarRow, bb: BB) -> None: + """Checks that the types of two rows match up. + + Otherwise, an error is thrown, alerting the user that a variable has different + types on different control-flow paths. + """ + map1, map2 = {v.name: v for v in row1}, {v.name: v for v in row2} + assert map1.keys() == map2.keys() + for x in map1: + v1, v2 = map1[x], map2[x] + if v1.ty != v2.ty: + # In the error message, we want to mention the variable that was first + # defined at the start. + if ( + v1.defined_at + and v2.defined_at + and line_col(v2.defined_at) < line_col(v1.defined_at) + ): + v1, v2 = v2, v1 + # We shouldn't mention temporary variables (starting with `%`) + # in error messages: + ident = "Expression" if v1.name.startswith("%") else f"Variable `{v1.name}`" + raise GuppyError( + f"{ident} can refer to different types: " + f"`{v1.ty}` (at {{}}) vs `{v2.ty}` (at {{}})", + bb.containing_cfg.live_before[bb][v1.name].vars.used[v1.name], + [v1.defined_at, v2.defined_at], + ) diff --git a/guppy/checker/expr_checker.py b/guppy/checker/expr_checker.py new file mode 100644 index 00000000..09c12774 --- /dev/null +++ b/guppy/checker/expr_checker.py @@ -0,0 +1,364 @@ +"""Type checking and synthesizing code for expressions. + +Operates on expressions in a basic block after CFG construction. 
In particular, we +assume that expressions that involve control flow (i.e. short-circuiting and ternary +expressions) have been removed during CFG construction. + +Furthermore, we assume that assignment expressions with the walrus operator := have +been turned into regular assignments and are no longer present. As a result, expressions +are assumed to be side effect free, in the sense that they do not modify the variables +available in the type checking context. + +We may alter/desugar AST nodes during type checking. In particular, we turn `ast.Name` +nodes into either `LocalName` or `GlobalName` nodes and `ast.Call` nodes are turned into +`LocalCall` or `GlobalCall` nodes. Furthermore, all nodes in the resulting AST are +annotated with their type. + +Expressions can be checked against a given type by the `ExprChecker`, raising a type +error if the expression doesn't have the expected type. Checking is used for annotated +assignments, return values, and function arguments. Alternatively, the `ExprSynthesizer` +can be used to infer a type for an expression. 
+""" + +import ast +from contextlib import suppress +from typing import Optional, Union, NoReturn, Any + +from guppy.ast_util import AstVisitor, with_loc, AstNode, with_type, get_type_opt +from guppy.checker.core import Context, CallableVariable, Globals +from guppy.error import GuppyError, GuppyTypeError, InternalGuppyError +from guppy.gtypes import GuppyType, TupleType, FunctionType, BoolType +from guppy.nodes import LocalName, GlobalName, LocalCall + +# Mapping from unary AST op to dunder method and display name +unary_table: dict[type[ast.unaryop], tuple[str, str]] = { + ast.UAdd: ("__pos__", "+"), + ast.USub: ("__neg__", "-"), + ast.Invert: ("__invert__", "~"), +} # fmt: skip + +# Mapping from binary AST op to left dunder method, right dunder method and display name +AstOp = Union[ast.operator, ast.cmpop] +binary_table: dict[type[AstOp], tuple[str, str, str]] = { + ast.Add: ("__add__", "__radd__", "+"), + ast.Sub: ("__sub__", "__rsub__", "-"), + ast.Mult: ("__mul__", "__rmul__", "*"), + ast.Div: ("__truediv__", "__rtruediv__", "/"), + ast.FloorDiv: ("__floordiv__", "__rfloordiv__", "//"), + ast.Mod: ("__mod__", "__rmod__", "%"), + ast.Pow: ("__pow__", "__rpow__", "**"), + ast.LShift: ("__lshift__", "__rlshift__", "<<"), + ast.RShift: ("__rshift__", "__rrshift__", ">>"), + ast.BitOr: ("__or__", "__ror__", "|"), + ast.BitXor: ("__xor__", "__rxor__", "^"), + ast.BitAnd: ("__and__", "__rand__", "&"), + ast.MatMult: ("__matmul__", "__rmatmul__", "@"), + ast.Eq: ("__eq__", "__eq__", "=="), + ast.NotEq: ("__neq__", "__neq__", "!="), + ast.Lt: ("__lt__", "__gt__", "<"), + ast.LtE: ("__le__", "__ge__", "<="), + ast.Gt: ("__gt__", "__lt__", ">"), + ast.GtE: ("__ge__", "__le__", ">="), +} # fmt: skip + + +class ExprChecker(AstVisitor[ast.expr]): + """Checks an expression against a type and produces a new type-annotated AST""" + + ctx: Context + + # Name for the kind of term we are currently checking against (used in errors). 
+ # For example, "argument", "return value", or in general "expression". + _kind: str + + def __init__(self, ctx: Context) -> None: + self.ctx = ctx + self._kind = "expression" + + def _fail( + self, + expected: GuppyType, + actual: Union[ast.expr, GuppyType], + loc: Optional[AstNode] = None, + ) -> NoReturn: + """Raises a type error indicating that the type doesn't match.""" + if not isinstance(actual, GuppyType): + loc = loc or actual + _, actual = self._synthesize(actual) + if loc is None: + raise InternalGuppyError("Failure location is required") + raise GuppyTypeError( + f"Expected {self._kind} of type `{expected}`, got `{actual}`", loc + ) + + def check( + self, expr: ast.expr, ty: GuppyType, kind: str = "expression" + ) -> ast.expr: + """Checks an expression against a type. + + Returns a new desugared expression with type annotations. + """ + old_kind = self._kind + self._kind = kind or self._kind + expr = self.visit(expr, ty) + self._kind = old_kind + return with_type(ty, expr) + + def _synthesize(self, node: ast.expr) -> tuple[ast.expr, GuppyType]: + """Invokes the type synthesiser""" + return ExprSynthesizer(self.ctx).synthesize(node) + + def visit_Tuple(self, node: ast.Tuple, ty: GuppyType) -> ast.expr: + if not isinstance(ty, TupleType) or len(ty.element_types) != len(node.elts): + return self._fail(ty, node) + for i, el in enumerate(node.elts): + node.elts[i] = self.check(el, ty.element_types[i]) + return node + + def generic_visit(self, node: ast.expr, ty: GuppyType) -> ast.expr: # type: ignore[override] + # Try to synthesize and then check if it matches the given type + node, synth = self._synthesize(node) + if synth != ty: + self._fail(ty, synth, node) + return node + + +class ExprSynthesizer(AstVisitor[tuple[ast.expr, GuppyType]]): + ctx: Context + + def __init__(self, ctx: Context) -> None: + self.ctx = ctx + + def synthesize(self, node: ast.expr) -> tuple[ast.expr, GuppyType]: + """Tries to synthesise a type for the given expression. 
+ + Also returns a new desugared expression with type annotations. + """ + if ty := get_type_opt(node): + return node, ty + node, ty = self.visit(node) + return with_type(ty, node), ty + + def _check( + self, expr: ast.expr, ty: GuppyType, kind: str = "expression" + ) -> ast.expr: + """Checks an expression against a given type""" + return ExprChecker(self.ctx).check(expr, ty, kind) + + def visit_Constant(self, node: ast.Constant) -> tuple[ast.expr, GuppyType]: + ty = python_value_to_guppy_type(node.value, node, self.ctx.globals) + if ty is None: + raise GuppyError("Unsupported constant", node) + return node, ty + + def visit_Name(self, node: ast.Name) -> tuple[ast.expr, GuppyType]: + x = node.id + if x in self.ctx.locals: + var = self.ctx.locals[x] + if var.ty.linear and var.used is not None: + raise GuppyError( + f"Variable `{x}` with linear type `{var.ty}` was " + "already used (at {0})", + node, + [var.used], + ) + var.used = node + return with_loc(node, LocalName(id=x)), var.ty + elif x in self.ctx.globals.values: + # Cache value in the AST + value = self.ctx.globals.values[x] + return with_loc(node, GlobalName(id=x, value=value)), value.ty + raise InternalGuppyError( + f"Variable `{x}` is not defined in `TypeSynthesiser`. This should have " + f"been caught by program analysis!" 
+ ) + + def visit_Tuple(self, node: ast.Tuple) -> tuple[ast.expr, GuppyType]: + elems = [self.synthesize(elem) for elem in node.elts] + node.elts = [n for n, _ in elems] + return node, TupleType([ty for _, ty in elems]) + + def visit_UnaryOp(self, node: ast.UnaryOp) -> tuple[ast.expr, GuppyType]: + # We need to synthesise the argument type, so we can look up dunder methods + node.operand, op_ty = self.synthesize(node.operand) + + # Special case for the `not` operation since it is not implemented via a dunder + # method or control-flow + if isinstance(node.op, ast.Not): + node.operand, bool_ty = to_bool(node.operand, op_ty, self.ctx) + return node, bool_ty + + # Check all other unary expressions by calling out to instance dunder methods + op, display_name = unary_table[node.op.__class__] + func = self.ctx.globals.get_instance_func(op_ty, op) + if func is None: + raise GuppyTypeError( + f"Unary operator `{display_name}` not defined for argument of type " + f" `{op_ty}`", + node.operand, + ) + return func.synthesize_call([node.operand], node, self.ctx) + + def _synthesize_binary( + self, left_expr: ast.expr, right_expr: ast.expr, op: AstOp, node: ast.expr + ) -> tuple[ast.expr, GuppyType]: + """Helper method to compile binary operators by calling out to dunder methods. + + For example, first try calling `__add__` on the left operand. If that fails, try + `__radd__` on the right operand. 
+ """ + if op.__class__ not in binary_table: + raise GuppyError("This binary operation is not supported by Guppy.", op) + lop, rop, display_name = binary_table[op.__class__] + left_expr, left_ty = self.synthesize(left_expr) + right_expr, right_ty = self.synthesize(right_expr) + + if func := self.ctx.globals.get_instance_func(left_ty, lop): + with suppress(GuppyError): + return func.synthesize_call([left_expr, right_expr], node, self.ctx) + + if func := self.ctx.globals.get_instance_func(right_ty, rop): + with suppress(GuppyError): + return func.synthesize_call([right_expr, left_expr], node, self.ctx) + + raise GuppyTypeError( + f"Binary operator `{display_name}` not defined for arguments of type " + f"`{left_ty}` and `{right_ty}`", + node, + ) + + def visit_BinOp(self, node: ast.BinOp) -> tuple[ast.expr, GuppyType]: + return self._synthesize_binary(node.left, node.right, node.op, node) + + def visit_Compare(self, node: ast.Compare) -> tuple[ast.expr, GuppyType]: + if len(node.comparators) != 1 or len(node.ops) != 1: + raise InternalGuppyError( + "BB contains chained comparison. Should have been removed during CFG " + "construction." 
+ ) + left_expr, [op], [right_expr] = node.left, node.ops, node.comparators + return self._synthesize_binary(left_expr, right_expr, op, node) + + def visit_Call(self, node: ast.Call) -> tuple[ast.expr, GuppyType]: + if len(node.keywords) > 0: + raise GuppyError("Keyword arguments are not supported", node.keywords[0]) + node.func, ty = self.synthesize(node.func) + + # First handle direct calls of user-defined functions and extension functions + if isinstance(node.func, GlobalName) and isinstance( + node.func.value, CallableVariable + ): + return node.func.value.synthesize_call(node.args, node, self.ctx) + + # Otherwise, it must be a function as a higher-order value + if isinstance(ty, FunctionType): + args, return_ty = synthesize_call(ty, node.args, node, self.ctx) + return with_loc(node, LocalCall(func=node.func, args=args)), return_ty + elif f := self.ctx.globals.get_instance_func(ty, "__call__"): + return f.synthesize_call(node.args, node, self.ctx) + else: + raise GuppyTypeError(f"Expected function type, got `{ty}`", node.func) + + def visit_NamedExpr(self, node: ast.NamedExpr) -> tuple[ast.expr, GuppyType]: + raise InternalGuppyError( + "BB contains `NamedExpr`. Should have been removed during CFG" + f"construction: `{ast.unparse(node)}`" + ) + + def visit_BoolOp(self, node: ast.BoolOp) -> tuple[ast.expr, GuppyType]: + raise InternalGuppyError( + "BB contains `BoolOp`. Should have been removed during CFG construction: " + f"`{ast.unparse(node)}`" + ) + + def visit_IfExp(self, node: ast.IfExp) -> tuple[ast.expr, GuppyType]: + raise InternalGuppyError( + "BB contains `IfExp`. 
Should have been removed during CFG construction: " + f"`{ast.unparse(node)}`" + ) + + +def check_num_args(exp: int, act: int, node: AstNode) -> None: + """Checks that the correct number of arguments have been passed to a function.""" + if act < exp: + raise GuppyTypeError( + f"Not enough arguments passed (expected {exp}, got {act})", node + ) + if exp < act: + if isinstance(node, ast.Call): + raise GuppyTypeError("Unexpected argument", node.args[exp]) + raise GuppyTypeError( + f"Too many arguments passed (expected {exp}, got {act})", node + ) + + +def synthesize_call( + func_ty: FunctionType, args: list[ast.expr], node: AstNode, ctx: Context +) -> tuple[list[ast.expr], GuppyType]: + """Synthesizes the return type of a function call. + + Also returns desugared versions of the arguments with type annotations. + """ + check_num_args(len(func_ty.args), len(args), node) + for i, arg in enumerate(args): + args[i] = ExprChecker(ctx).check(arg, func_ty.args[i], "argument") + return args, func_ty.returns + + +def check_call( + func_ty: FunctionType, + args: list[ast.expr], + ty: GuppyType, + node: AstNode, + ctx: Context, +) -> list[ast.expr]: + """Checks the return type of a function call against a given type""" + args, return_ty = synthesize_call(func_ty, args, node, ctx) + if return_ty != ty: + raise GuppyTypeError( + f"Expected expression of type `{ty}`, got `{return_ty}`", node + ) + return args + + +def to_bool( + node: ast.expr, node_ty: GuppyType, ctx: Context +) -> tuple[ast.expr, GuppyType]: + """Tries to turn a node into a bool""" + if isinstance(node_ty, BoolType): + return node, node_ty + + func = ctx.globals.get_instance_func(node_ty, "__bool__") + if func is None: + raise GuppyTypeError( + f"Expression of type `{node_ty}` cannot be interpreted as a `bool`", + node, + ) + + # We could check the return type against bool, but we can give a better error + # message if we synthesise and compare to bool by hand + call, return_ty = func.synthesize_call([node], 
node, ctx) + if not isinstance(return_ty, BoolType): + raise GuppyTypeError( + f"`__bool__` on type `{node_ty}` returns `{return_ty}` instead of `bool`", + node, + ) + return call, return_ty + + +def python_value_to_guppy_type( + v: Any, node: ast.expr, globals: Globals +) -> Optional[GuppyType]: + """Turns a primitive Python value into a Guppy type. + + Returns `None` if the Python value cannot be represented in Guppy. + """ + match v: + case bool(): + return globals.types["bool"].build(node=node) + case int(): + return globals.types["int"].build(node=node) + case float(): + return globals.types["float"].build(node=node) + case _: + return None diff --git a/guppy/checker/func_checker.py b/guppy/checker/func_checker.py index cb2d3234..e0695824 100644 --- a/guppy/checker/func_checker.py +++ b/guppy/checker/func_checker.py @@ -1,11 +1,21 @@ +"""Type checking code for top-level and nested function definitions. + +For top-level functions, we take the `DefinedFunction` containing the `ast.FunctionDef` +node straight from the Python AST. We build a CFG, check it, and return a +`CheckedFunction` containing a `CheckedCFG` with type annotations. 
+""" + import ast from dataclasses import dataclass -from guppy.ast_util import AstNode +from guppy.ast_util import return_nodes_in_ast, AstNode, with_loc from guppy.cfg.bb import BB -from guppy.checker.core import Globals, Context, CallableVariable -from guppy.checker.cfg_checker import CheckedCFG -from guppy.gtypes import FunctionType, GuppyType +from guppy.cfg.builder import CFGBuilder +from guppy.checker.core import Variable, Globals, Context, CallableVariable +from guppy.checker.cfg_checker import check_cfg, CheckedCFG +from guppy.checker.expr_checker import synthesize_call, check_call +from guppy.error import GuppyError +from guppy.gtypes import FunctionType, type_from_ast, NoneType, GuppyType from guppy.nodes import GlobalCall, CheckedNestedFunctionDef, NestedFunctionDef @@ -26,12 +36,16 @@ def from_ast( def check_call( self, args: list[ast.expr], ty: GuppyType, node: AstNode, ctx: Context ) -> GlobalCall: - raise NotImplementedError + # Use default implementation from the expression checker + args = check_call(self.ty, args, ty, node, ctx) + return GlobalCall(func=self, args=args) def synthesize_call( self, args: list[ast.expr], node: AstNode, ctx: Context ) -> tuple[GlobalCall, GuppyType]: - raise NotImplementedError + # Use default implementation from the expression checker + args, ty = synthesize_call(self.ty, args, node, ctx) + return GlobalCall(func=self, args=args), ty @dataclass @@ -43,17 +57,133 @@ class CheckedFunction(DefinedFunction): def check_global_func_def(func: DefinedFunction, globals: Globals) -> CheckedFunction: """Type checks a top-level function definition.""" - raise NotImplementedError + func_def = func.defined_at + args = func_def.args.args + returns_none = isinstance(func.ty.returns, NoneType) + assert func.ty.arg_names is not None + + cfg = CFGBuilder().build(func_def.body, returns_none, globals) + inputs = [ + Variable(x, ty, loc, None) + for x, ty, loc in zip(func.ty.arg_names, func.ty.args, args) + ] + cfg = check_cfg(cfg, 
inputs, func.ty.returns, globals) + return CheckedFunction(func_def.name, func.ty, func_def, None, cfg) def check_nested_func_def( func_def: NestedFunctionDef, bb: BB, ctx: Context ) -> CheckedNestedFunctionDef: """Type checks a local (nested) function definition.""" - raise NotImplementedError + func_ty = check_signature(func_def, ctx.globals) + assert func_ty.arg_names is not None + + # We've already built the CFG for this function while building the CFG of the + # enclosing function + cfg = func_def.cfg + + # Find captured variables + parent_cfg = bb.containing_cfg + def_ass_before = set(func_ty.arg_names) | ctx.locals.keys() + maybe_ass_before = def_ass_before | parent_cfg.maybe_ass_before[bb] + cfg.analyze(def_ass_before, maybe_ass_before) + captured = { + x: ctx.locals[x] + for x in cfg.live_before[cfg.entry_bb] + if x not in func_ty.arg_names and x in ctx.locals + } + + # Captured variables may not be linear + for v in captured.values(): + if v.ty.linear: + x = v.name + using_bb = cfg.live_before[cfg.entry_bb][x] + raise GuppyError( + f"Variable `{x}` with linear type `{v.ty}` may not be used here " + f"because it was defined in an outer scope (at {{0}})", + using_bb.vars.used[x], + [v.defined_at], + ) + + # Captured variables may never be assigned to + for bb in cfg.bbs: + for v in captured.values(): + x = v.name + if x in bb.vars.assigned: + raise GuppyError( + f"Variable `{x}` defined in an outer scope (at {{0}}) may not " + f"be assigned to", + bb.vars.assigned[x], + [v.defined_at], + ) + + # Construct inputs for checking the body CFG + inputs = list(captured.values()) + [ + Variable(x, ty, func_def.args.args[i], None) + for i, (x, ty) in enumerate(zip(func_ty.arg_names, func_ty.args)) + ] + globals = ctx.globals + + # Check if the body contains a free (recursive) occurrence of the function name. 
+ # By checking if the name is free at the entry BB, we avoid false positives when + # a user shadows the name with a local variable + if func_def.name in cfg.live_before[cfg.entry_bb]: + if not captured: + # If there are no captured vars, we treat the function like a global name + func = DefinedFunction(func_def.name, func_ty, func_def, None) + globals = ctx.globals | Globals({func_def.name: func}, {}) + + else: + # Otherwise, we treat it like a local name + inputs.append(Variable(func_def.name, func_def.ty, func_def, None)) + + checked_cfg = check_cfg(cfg, inputs, func_ty.returns, globals) + checked_def = CheckedNestedFunctionDef( + checked_cfg, + func_ty, + captured, + name=func_def.name, + args=func_def.args, + body=func_def.body, + decorator_list=func_def.decorator_list, + returns=func_def.returns, + type_comment=func_def.type_comment, + ) + return with_loc(func_def, checked_def) def check_signature(func_def: ast.FunctionDef, globals: Globals) -> FunctionType: """Checks the signature of a function definition and returns the corresponding Guppy type.""" - raise NotImplementedError + if len(func_def.args.posonlyargs) != 0: + raise GuppyError( + "Positional-only parameters not supported", func_def.args.posonlyargs[0] + ) + if len(func_def.args.kwonlyargs) != 0: + raise GuppyError( + "Keyword-only parameters not supported", func_def.args.kwonlyargs[0] + ) + if func_def.args.vararg is not None: + raise GuppyError("*args not supported", func_def.args.vararg) + if func_def.args.kwarg is not None: + raise GuppyError("**kwargs not supported", func_def.args.kwarg) + if func_def.returns is None: + # TODO: Error location is incorrect + if all(r.value is None for r in return_nodes_in_ast(func_def)): + raise GuppyError( + "Return type must be annotated. 
Try adding a `-> None` annotation.", + func_def, + ) + raise GuppyError("Return type must be annotated", func_def) + + arg_tys = [] + arg_names = [] + for i, arg in enumerate(func_def.args.args): + if arg.annotation is None: + raise GuppyError("Argument type must be annotated", arg) + ty = type_from_ast(arg.annotation, globals) + arg_tys.append(ty) + arg_names.append(arg.arg) + + ret_type = type_from_ast(func_def.returns, globals) + return FunctionType(arg_tys, ret_type, arg_names) diff --git a/guppy/checker/stmt_checker.py b/guppy/checker/stmt_checker.py new file mode 100644 index 00000000..5622ee55 --- /dev/null +++ b/guppy/checker/stmt_checker.py @@ -0,0 +1,140 @@ +"""Type checking code for statements. + +Operates on statements in a basic block after CFG construction. In particular, we +assume that statements involving control flow (i.e. if, while, break, and return +statements) have been removed during CFG construction. + +After checking, we return a desugared statement where all sub-expressions have been type +annotated. 
+""" + +import ast +from typing import Sequence + +from guppy.ast_util import with_loc, AstVisitor +from guppy.cfg.bb import BB, BBStatement +from guppy.checker.core import Variable, Context +from guppy.checker.expr_checker import ExprSynthesizer, ExprChecker +from guppy.error import GuppyError, GuppyTypeError, InternalGuppyError +from guppy.gtypes import GuppyType, TupleType, type_from_ast, NoneType +from guppy.nodes import NestedFunctionDef + + +class StmtChecker(AstVisitor[BBStatement]): + ctx: Context + bb: BB + return_ty: GuppyType + + def __init__(self, ctx: Context, bb: BB, return_ty: GuppyType) -> None: + self.ctx = ctx + self.bb = bb + self.return_ty = return_ty + + def check_stmts(self, stmts: Sequence[BBStatement]) -> list[BBStatement]: + return [self.visit(s) for s in stmts] + + def _synth_expr(self, node: ast.expr) -> tuple[ast.expr, GuppyType]: + return ExprSynthesizer(self.ctx).synthesize(node) + + def _check_expr( + self, node: ast.expr, ty: GuppyType, kind: str = "expression" + ) -> ast.expr: + return ExprChecker(self.ctx).check(node, ty, kind) + + def _check_assign(self, lhs: ast.expr, ty: GuppyType, node: ast.stmt) -> None: + """Helper function to check assignments with patterns.""" + match lhs: + # Easiest case is if the LHS pattern is a single variable. 
+ case ast.Name(id=x): + # Check if we override an unused linear variable + if x in self.ctx.locals: + var = self.ctx.locals[x] + if var.ty.linear and var.used is None: + raise GuppyError( + f"Variable `{x}` with linear type `{var.ty}` is not used", + var.defined_at, + ) + self.ctx.locals[x] = Variable(x, ty, node, None) + + # The only other thing we support right now are tuples + case ast.Tuple(elts=elts): + tys = ty.element_types if isinstance(ty, TupleType) else [ty] + n, m = len(elts), len(tys) + if n != m: + raise GuppyTypeError( + f"{'Too many' if n < m else 'Not enough'} values to unpack " + f"(expected {n}, got {m})", + node, + ) + for pat, el_ty in zip(elts, tys): + self._check_assign(pat, el_ty, node) + + # TODO: Python also supports assignments like `[a, b] = [1, 2]` or + # `a, *b = ...`. The former would require some runtime checks but + # the latter should be easier to do (unpack and repack the rest). + case _: + raise GuppyError("Assignment pattern not supported", lhs) + + def visit_Assign(self, node: ast.Assign) -> ast.stmt: + if len(node.targets) > 1: + # This is the case for assignments like `a = b = 1` + raise GuppyError("Multi assignment not supported", node) + + [target] = node.targets + node.value, ty = self._synth_expr(node.value) + self._check_assign(target, ty, node) + return node + + def visit_AnnAssign(self, node: ast.AnnAssign) -> ast.stmt: + if node.value is None: + raise GuppyError( + "Variable declaration is not supported. 
Assignment is required", node + ) + ty = type_from_ast(node.annotation, self.ctx.globals) + node.value = self._check_expr(node.value, ty) + self._check_assign(node.target, ty, node) + return node + + def visit_AugAssign(self, node: ast.AugAssign) -> ast.stmt: + bin_op = with_loc( + node, ast.BinOp(left=node.target, op=node.op, right=node.value) + ) + assign = with_loc(node, ast.Assign(targets=[node.target], value=bin_op)) + return self.visit_Assign(assign) + + def visit_Expr(self, node: ast.Expr) -> ast.stmt: + # An expression statement where the return value is discarded + node.value, ty = self._synth_expr(node.value) + if ty.linear: + raise GuppyTypeError(f"Value with linear type `{ty}` is not used", node) + return node + + def visit_Return(self, node: ast.Return) -> ast.stmt: + if node.value is not None: + node.value = self._check_expr(node.value, self.return_ty, "return value") + elif not isinstance(self.return_ty, NoneType): + raise GuppyTypeError( + f"Expected return value of type `{self.return_ty}`", None + ) + return node + + def visit_NestedFunctionDef(self, node: NestedFunctionDef) -> ast.stmt: + from guppy.checker.func_checker import check_nested_func_def + + func_def = check_nested_func_def(node, self.bb, self.ctx) + self.ctx.locals[func_def.name] = Variable( + func_def.name, func_def.ty, func_def, None + ) + return func_def + + def visit_If(self, node: ast.If) -> None: + raise InternalGuppyError("Control-flow statement should not be present here.") + + def visit_While(self, node: ast.While) -> None: + raise InternalGuppyError("Control-flow statement should not be present here.") + + def visit_Break(self, node: ast.Break) -> None: + raise InternalGuppyError("Control-flow statement should not be present here.") + + def visit_Continue(self, node: ast.Continue) -> None: + raise InternalGuppyError("Control-flow statement should not be present here.") diff --git a/guppy/custom.py b/guppy/custom.py index 61800095..a80fd134 100644 --- a/guppy/custom.py +++ 
b/guppy/custom.py @@ -4,17 +4,18 @@ from guppy.ast_util import AstNode, with_type, with_loc, get_type from guppy.checker.core import Context, Globals +from guppy.checker.expr_checker import check_call, synthesize_call from guppy.checker.func_checker import check_signature from guppy.compiler.core import CompiledFunction, DFContainer, CompiledGlobals from guppy.error import ( GuppyError, InternalGuppyError, UnknownFunctionType, - GuppyTypeError, ) from guppy.gtypes import GuppyType, FunctionType, type_to_row from guppy.hugr import ops from guppy.hugr.hugr import OutPortV, Hugr, Node, DFContainingVNode +from guppy.nodes import GlobalCall class CustomFunction(CompiledFunction): @@ -160,15 +161,12 @@ def _setup(self, ctx: Context, node: AstNode, func: CustomFunction) -> None: self.node = node self.func = func + @abstractmethod def check(self, args: list[ast.expr], ty: GuppyType) -> ast.expr: """Checks the return value against a given type. Returns a (possibly) transformed and annotated AST node for the call. 
""" - args, actual = self.synthesize(args) - raise GuppyTypeError( - f"Expected expression of type `{ty}`, got `{actual}`", self.node - ) @abstractmethod def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, GuppyType]: @@ -203,10 +201,14 @@ class DefaultCallChecker(CustomCallChecker): """Checks function calls by comparing to a type signature.""" def check(self, args: list[ast.expr], ty: GuppyType) -> ast.expr: - raise NotImplementedError + # Use default implementation from the expression checker + args = check_call(self.func.ty, args, ty, self.node, self.ctx) + return GlobalCall(func=self.func, args=args) def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, GuppyType]: - raise NotImplementedError + # Use default implementation from the expression checker + args, ty = synthesize_call(self.func.ty, args, self.node, self.ctx) + return GlobalCall(func=self.func, args=args), ty class DefaultCallCompiler(CustomCallCompiler): diff --git a/guppy/declared.py b/guppy/declared.py index 0106ed21..91b5283e 100644 --- a/guppy/declared.py +++ b/guppy/declared.py @@ -4,6 +4,7 @@ from guppy.ast_util import AstNode, has_empty_body from guppy.checker.core import Globals, Context +from guppy.checker.expr_checker import check_call, synthesize_call from guppy.checker.func_checker import check_signature from guppy.compiler.core import CompiledFunction, DFContainer, CompiledGlobals from guppy.error import GuppyError @@ -32,12 +33,16 @@ def from_ast( def check_call( self, args: list[ast.expr], ty: GuppyType, node: AstNode, ctx: Context ) -> GlobalCall: - raise NotImplementedError + # Use default implementation from the expression checker + args = check_call(self.ty, args, ty, node, ctx) + return GlobalCall(func=self, args=args) def synthesize_call( self, args: list[ast.expr], node: AstNode, ctx: Context ) -> tuple[GlobalCall, GuppyType]: - raise NotImplementedError + # Use default implementation from the expression checker + args, ty = synthesize_call(self.ty, args, node, 
ctx) + return GlobalCall(func=self, args=args), ty def add_to_graph(self, graph: Hugr, parent: Node) -> None: self.node = graph.add_declare(self.ty, parent, self.name) diff --git a/tests/error/linear_errors/unused_expr.err b/tests/error/linear_errors/unused_expr.err new file mode 100644 index 00000000..6a0ae95e --- /dev/null +++ b/tests/error/linear_errors/unused_expr.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:13 + +11: @guppy(module) +12: def foo(q: Qubit) -> None: +13: h(q) + ^^^^ +GuppyTypeError: Value with linear type `Qubit` is not used diff --git a/tests/error/linear_errors/unused_expr.py b/tests/error/linear_errors/unused_expr.py new file mode 100644 index 00000000..6ad60a82 --- /dev/null +++ b/tests/error/linear_errors/unused_expr.py @@ -0,0 +1,16 @@ +import guppy.prelude.quantum as quantum +from guppy.decorator import guppy +from guppy.module import GuppyModule +from guppy.hugr.tys import Qubit + + +module = GuppyModule("test") +module.load(quantum) + + +@guppy(module) +def foo(q: Qubit) -> None: + h(q) + + +module.compile()