From 359c2da21201adb41181295e7d9a55e70a4c66da Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Wed, 20 Dec 2023 13:25:08 +0000 Subject: [PATCH 01/31] feat: Build for loops and comprehensions --- guppy/ast_util.py | 190 ++++++++++++++++++++++++++++++++++--------- guppy/cfg/builder.py | 122 +++++++++++++++++++++++---- 2 files changed, 260 insertions(+), 52 deletions(-) diff --git a/guppy/ast_util.py b/guppy/ast_util.py index 5a9e66f4..31bb940f 100644 --- a/guppy/ast_util.py +++ b/guppy/ast_util.py @@ -1,5 +1,7 @@ import ast -from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar +import textwrap +from collections.abc import Callable, Mapping, Sequence +from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, cast if TYPE_CHECKING: from guppy.gtypes import GuppyType @@ -54,51 +56,165 @@ def generic_visit(self, node: Any, *args: Any, **kwargs: Any) -> T: raise NotImplementedError(f"visit_{node.__class__.__name__} is not implemented") -class NameVisitor(ast.NodeVisitor): - """Visitor to collect all `Name` nodes occurring in an AST.""" - - names: list[ast.Name] - - def __init__(self) -> None: - self.names = [] - - def visit_Name(self, node: ast.Name) -> None: - self.names.append(node) +class AstSearcher(ast.NodeVisitor): + """Visitor that searches for occurrences of specific nodes in an AST.""" + + matcher: Callable[[ast.AST], bool] + dont_recurse_into: set[type[ast.AST]] + found: list[ast.AST] + is_first_node: bool + + def __init__( + self, + matcher: Callable[[ast.AST], bool], + dont_recurse_into: set[type[ast.AST]] | None = None, + ) -> None: + self.matcher = matcher + self.dont_recurse_into = dont_recurse_into or set() + self.found = [] + self.is_first_node = True + + def generic_visit(self, node: ast.AST) -> None: + if self.matcher(node): + self.found.append(node) + if self.is_first_node or type(node) not in self.dont_recurse_into: + self.is_first_node = False + super().generic_visit(node) + + +def find_nodes( + matcher: Callable[[ast.AST], bool], + node: ast.AST, + dont_recurse_into: set[type[ast.AST]] | None = None, +) -> list[ast.AST]: + """Returns all nodes in the AST that satisfy the matcher.""" + v = AstSearcher(matcher, dont_recurse_into) + v.visit(node) + return v.found def name_nodes_in_ast(node: Any) -> list[ast.Name]: """Returns all `Name` nodes occurring in an AST.""" - v = NameVisitor() - v.visit(node) - return v.names - - -class ReturnVisitor(ast.NodeVisitor): - """Visitor to collect all `Return` nodes occurring in an AST.""" + found = find_nodes(lambda n: isinstance(n, ast.Name), node) + return cast(list[ast.Name], found) - returns: list[ast.Return] - inside_func_def: bool - def __init__(self) -> None: - self.returns = [] - self.inside_func_def = False - - def visit_Return(self, node: ast.Return) -> None: - self.returns.append(node) +def return_nodes_in_ast(node: Any) -> list[ast.Return]: + """Returns all `Return` nodes occurring in an AST.""" + found = find_nodes(lambda n: isinstance(n, ast.Return), node, {ast.FunctionDef}) + return cast(list[ast.Return], found) - def visit_FunctionDef(self, node: ast.FunctionDef) -> None: - # Don't descend into nested function definitions - if not self.inside_func_def: - self.inside_func_def = True - for n in node.body: - self.visit(n) +def breaks_in_loop(node: Any) -> list[ast.Break]: + """Returns all `Break` nodes occurring in a loop. -def return_nodes_in_ast(node: Any) -> list[ast.Return]: - """Returns all `Return` nodes occurring in an AST.""" - v = ReturnVisitor() - v.visit(node) - return v.returns + Note that breaks in nested loops are excluded. + """ + found = find_nodes( + lambda n: isinstance(n, ast.Break), node, {ast.For, ast.While, ast.FunctionDef} + ) + return cast(list[ast.Break], found) + + +class ContextAdjuster(ast.NodeTransformer): + """Updates the `ast.Context` indicating if expressions occur on the LHS or RHS.""" + + ctx: ast.expr_context + + def __init__(self, ctx: ast.expr_context) -> None: + self.ctx = ctx + + def visit(self, node: ast.AST) -> ast.AST: + return cast(ast.AST, super().visit(node)) + + def visit_Name(self, node: ast.Name) -> ast.Name: + return with_loc(node, ast.Name(id=node.id, ctx=self.ctx)) + + def visit_Starred(self, node: ast.Starred) -> ast.Starred: + return with_loc(node, ast.Starred(value=self.visit(node.value), ctx=self.ctx)) + + def visit_Tuple(self, node: ast.Tuple) -> ast.Tuple: + return with_loc( + node, ast.Tuple(elts=[self.visit(elt) for elt in node.elts], ctx=self.ctx) + ) + + def visit_List(self, node: ast.List) -> ast.List: + return with_loc( + node, ast.List(elts=[self.visit(elt) for elt in node.elts], ctx=self.ctx) + ) + + def visit_Subscript(self, node: ast.Subscript) -> ast.Subscript: + # Don't adjust the slice! + return with_loc( + node, + ast.Subscript(value=self.visit(node.value), slice=node.slice, ctx=self.ctx), + ) + + def visit_Attribute(self, node: ast.Attribute) -> ast.Attribute: + return ast.Attribute(value=self.visit(node.value), attr=node.attr, ctx=self.ctx) + + +class TemplateReplacer(ast.NodeTransformer): + """Replaces nodes in a template.""" + + replacements: Mapping[str, ast.AST | Sequence[ast.AST]] + default_loc: ast.AST + + def __init__( + self, + replacements: Mapping[str, ast.AST | Sequence[ast.AST]], + default_loc: ast.AST, + ) -> None: + self.replacements = replacements + self.default_loc = default_loc + + def _get_replacement(self, x: str) -> ast.AST | Sequence[ast.AST]: + if x not in self.replacements: + msg = f"No replacement for `{x}` is given" + raise ValueError(msg) + return self.replacements[x] + + def visit_Name(self, node: ast.Name) -> ast.AST: + repl = self._get_replacement(node.id) + if not isinstance(repl, ast.expr): + msg = f"Replacement for `{node.id}` must be an expression" + raise TypeError(msg) + + # Update the context + adjuster = ContextAdjuster(node.ctx) + return with_loc(repl, adjuster.visit(repl)) + + def visit_Expr(self, node: ast.Expr) -> ast.AST | Sequence[ast.AST]: + if isinstance(node.value, ast.Name): + repl = self._get_replacement(node.value.id) + repls = [repl] if not isinstance(repl, Sequence) else repl + # Wrap expressions to turn them into statements + return [ + with_loc(r, ast.Expr(value=r)) if isinstance(r, ast.expr) else r + for r in repls + ] + return self.generic_visit(node) + + def generic_visit(self, node: ast.AST) -> ast.AST: + # Insert the default location + node = super().generic_visit(node) + return with_loc(self.default_loc, node) + + +def template_replace( + template: str, default_loc: ast.AST, **kwargs: ast.AST | Sequence[ast.AST] +) -> list[ast.stmt]: + """Turns a template into a proper AST by substituting all placeholders.""" + nodes = ast.parse(textwrap.dedent(template)).body + replacer = TemplateReplacer(kwargs, default_loc) + new_nodes = [] + for n in nodes: + new = replacer.visit(n) + if isinstance(new, list): + new_nodes.extend(new) + else: + new_nodes.append(new) + return new_nodes def line_col(node: ast.AST) -> tuple[int, int]: diff --git a/guppy/cfg/builder.py b/guppy/cfg/builder.py index 878437bf..a664394b 100644 --- a/guppy/cfg/builder.py +++ b/guppy/cfg/builder.py @@ -1,15 +1,29 @@ import ast import itertools from collections.abc import Iterator -from typing import NamedTuple - -from guppy.ast_util import AstVisitor, set_location_from +from typing import NamedTuple, cast + +from guppy.ast_util import ( + AstVisitor, + find_nodes, + set_location_from, + template_replace, + with_loc, +) from guppy.cfg.bb import BB, BBStatement from guppy.cfg.cfg import CFG from guppy.checker.core import Globals from guppy.error import GuppyError, InternalGuppyError from guppy.gtypes import NoneType -from guppy.nodes import NestedFunctionDef +from guppy.nodes import ( + DesugaredGenerator, + DesugaredListComp, + IterEnd, + IterHasNext, + IterNext, + MakeIter, + NestedFunctionDef, +) # In order to build expressions, need an endless stream of unique temporary variables # to store intermediate results @@ -142,6 +156,35 @@ def visit_While(self, node: ast.While, bb: BB, jumps: Jumps) -> BB | None: # its own jumps since the body is not guaranteed to execute return tail_bb + def visit_For(self, node: ast.For, bb: BB, jumps: Jumps) -> BB | None: + template = """ + it = make_iter + while True: + b, it = has_next + if b: + x, it = get_next + body + else: + break + end_iter # Consume iterator one last time + """ + + it = make_var(next(tmp_vars), node.iter) + b = make_var(next(tmp_vars), node.iter) + new_nodes = template_replace( + template, + node, + it=it, + b=b, + x=node.target, + make_iter=with_loc(node.iter, MakeIter(value=node.iter, origin_node=node)), + has_next=with_loc(node.iter, IterHasNext(value=it)), + get_next=with_loc(node.iter, IterNext(value=it)), + end_iter=with_loc(node.iter, IterEnd(value=it)), + body=node.body, + ) + return self.visit_stmts(new_nodes, bb, jumps) + def visit_Continue(self, node: ast.Continue, bb: BB, jumps: Jumps) -> BB | None: if not jumps.continue_bb: raise InternalGuppyError("Continue BB not defined") @@ -211,18 +254,10 @@ def build(node: ast.expr, cfg: CFG, bb: BB) -> tuple[ast.expr, BB]: builder = ExprBuilder(cfg, bb) return builder.visit(node), builder.bb - @classmethod - def _make_var(cls, name: str, loc: ast.expr | None = None) -> ast.Name: - """Creates an `ast.Name` node.""" - node = ast.Name(id=name, ctx=ast.Load) - if loc is not None: - set_location_from(node, loc) - return node - @classmethod def _tmp_assign(cls, tmp_name: str, value: ast.expr, bb: BB) -> None: """Adds a temporary variable assignment to a basic block.""" - node = ast.Assign(targets=[cls._make_var(tmp_name, value)], value=value) + node = ast.Assign(targets=[make_var(tmp_name, value)], value=value) set_location_from(node, value) bb.statements.append(node) @@ -256,7 +291,51 @@ def visit_IfExp(self, node: ast.IfExp) -> ast.Name: self.bb = merge_bb # The final value is stored in the temporary variable - return self._make_var(tmp, node) + return make_var(tmp, node) + + def visit_ListComp(self, node: ast.ListComp) -> ast.AST: + # Check for illegal expressions + illegals = find_nodes(is_illegal_in_list_comp, node) + if illegals: + raise GuppyError( + "Expression is not supported inside a list comprehension", illegals[0] + ) + + # Desugar into statements that create the iterator, check for a next element, + # get the next element, and finalise the iterator. + gens = [] + for g in node.generators: + if g.is_async: + raise GuppyError("Async generators are not supported", g) + g.iter = self.visit(g.iter) + gen = DesugaredGenerator() + + template = """ + it = make_iter + b, it = has_next + x, it = get_next + """ + it = make_var(next(tmp_vars), g.iter) + b = make_var(next(tmp_vars), g.iter) + [gen.iter_assign, gen.hasnext_assign, gen.next_assign] = cast( + list[ast.Assign], + template_replace( + template, + g.iter, + it=it, + b=b, + x=g.target, + make_iter=with_loc(it, MakeIter(value=g.iter, origin_node=node)), + has_next=with_loc(it, IterHasNext(value=it)), + get_next=with_loc(it, IterNext(value=it)), + ), + ) + gen.iterend = with_loc(it, IterEnd(value=it)) + gen.iter, gen.hasnext, gen.ifs = it, b, g.ifs + gens.append(gen) + + node.elt = self.visit(node.elt) + return with_loc(node, DesugaredListComp(elt=node.elt, generators=gens)) def generic_visit(self, node: ast.AST) -> ast.AST: # Short-circuit expressions must be built using the `BranchBuilder`. However, we @@ -275,7 +354,7 @@ def generic_visit(self, node: ast.AST) -> ast.AST: self._tmp_assign(tmp, false_const, false_bb) merge_bb = self.cfg.new_bb(true_bb, false_bb) self.bb = merge_bb - return self._make_var(tmp, node) + return make_var(tmp, node) # For all other expressions, just recurse deeper with the node transformer return super().generic_visit(node) @@ -398,3 +477,16 @@ def is_short_circuit_expr(node: ast.AST) -> bool: return isinstance(node, ast.BoolOp) or ( isinstance(node, ast.Compare) and len(node.comparators) > 1 ) + + +def is_illegal_in_list_comp(node: ast.AST) -> bool: + """Checks if an expression is illegal to use in a list comprehension.""" + return isinstance(node, ast.IfExp | ast.NamedExpr) or is_short_circuit_expr(node) + + +def make_var(name: str, loc: ast.AST | None = None) -> ast.Name: + """Creates an `ast.Name` node.""" + node = ast.Name(id=name, ctx=ast.Load) + if loc is not None: + set_location_from(node, loc) + return node From 36396cccfc03eb30277a8b37980722a95f0c0f06 Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Wed, 20 Dec 2023 13:35:17 +0000 Subject: [PATCH 02/31] feat: Add list and comprehension type checking --- guppy/cfg/bb.py | 33 ++++- guppy/checker/cfg_checker.py | 4 +- guppy/checker/core.py | 43 +++++- guppy/checker/expr_checker.py | 258 +++++++++++++++++++++++++++++++++- guppy/checker/func_checker.py | 4 +- guppy/checker/stmt_checker.py | 20 ++- guppy/declared.py | 6 +- 7 files changed, 340 insertions(+), 28 deletions(-) diff --git a/guppy/cfg/bb.py b/guppy/cfg/bb.py index 99fd61a2..84bb6cb1 100644 --- a/guppy/cfg/bb.py +++ b/guppy/cfg/bb.py @@ -6,7 +6,7 @@ from typing_extensions import Self from guppy.ast_util import AstNode, name_nodes_in_ast -from guppy.nodes import NestedFunctionDef +from guppy.nodes import DesugaredListComp, NestedFunctionDef if TYPE_CHECKING: from guppy.cfg.cfg import BaseCFG @@ -99,24 +99,46 @@ def __init__(self, bb: BB): self.bb = bb self.stats = VariableStats() + def visit_Name(self, node: ast.Name) -> None: + self.stats.update_used(node) + def visit_Assign(self, node: ast.Assign) -> None: - self.stats.update_used(node.value) + self.visit(node.value) for t in node.targets: for name in name_nodes_in_ast(t): self.stats.assigned[name.id] = node def visit_AugAssign(self, node: ast.AugAssign) -> None: - self.stats.update_used(node.value) + self.visit(node.value) self.stats.update_used(node.target) # The target is also used for name in name_nodes_in_ast(node.target): self.stats.assigned[name.id] = node def visit_AnnAssign(self, node: ast.AnnAssign) -> None: if node.value: - self.stats.update_used(node.value) + self.visit(node.value) for name in name_nodes_in_ast(node.target): self.stats.assigned[name.id] = node + def visit_DesugaredListComp(self, node: DesugaredListComp) -> None: + # Names bound in the comprehension are only available inside, so we shouldn't + # update `self.stats` with assignments + inner_visitor = VariableVisitor(self.bb) + inner_stats = inner_visitor.stats + + # The generators are evaluated left to right + for gen in node.generators: + inner_visitor.visit(gen.iter_assign) + inner_visitor.visit(gen.hasnext_assign) + inner_visitor.visit(gen.next_assign) + for cond in gen.ifs: + inner_visitor.visit(cond) + inner_visitor.visit(node.elt) + + self.stats.used = { + x: n for x, n in inner_stats.used.items() if x not in self.stats.assigned + } + def visit_NestedFunctionDef(self, node: NestedFunctionDef) -> None: # In order to compute the used external variables in a nested function # definition, we have to run live variable analysis first @@ -139,6 +161,3 @@ def visit_NestedFunctionDef(self, node: NestedFunctionDef) -> None: # The name of the function is now assigned self.stats.assigned[node.name] = node - - def generic_visit(self, node: ast.AST) -> None: - self.stats.update_used(node) diff --git a/guppy/checker/cfg_checker.py b/guppy/checker/cfg_checker.py index c5cf8a64..ada6fd02 100644 --- a/guppy/checker/cfg_checker.py +++ b/guppy/checker/cfg_checker.py @@ -11,7 +11,7 @@ from guppy.ast_util import line_col from guppy.cfg.bb import BB from guppy.cfg.cfg import CFG, BaseCFG -from guppy.checker.core import Context, Globals, Variable +from guppy.checker.core import Context, Globals, Locals, Variable from guppy.checker.expr_checker import ExprSynthesizer, to_bool from guppy.checker.stmt_checker import StmtChecker from guppy.error import GuppyError @@ -127,7 +127,7 @@ def check_bb( raise GuppyError(f"Variable `{x}` is not defined", use) # Check the basic block - ctx = Context(globals, {v.name: v for v in inputs}) + ctx = Context(globals, Locals({v.name: v for v in inputs})) checked_stmts = StmtChecker(ctx, bb, return_ty).check_stmts(bb.statements) # If we branch, we also have to check the branch predicate diff --git a/guppy/checker/core.py b/guppy/checker/core.py index 9abc695b..9827c7b5 100644 --- a/guppy/checker/core.py +++ b/guppy/checker/core.py @@ -1,5 +1,8 @@ import ast +import copy +import itertools from abc import ABC, abstractmethod +from collections.abc import Iterable, Iterator from dataclasses import dataclass from typing import NamedTuple @@ -105,8 +108,44 @@ def __ior__(self, other: "Globals") -> "Globals": # noqa: PYI034 return self -# Local variable mapping -Locals = dict[str, Variable] +@dataclass +class Locals: + """Scoped mapping from names to variables""" + + vars: dict[str, Variable] + parent_scope: "Locals | None" = None + + def __getitem__(self, item: str) -> Variable: + if item not in self.vars and self.parent_scope: + return self.parent_scope[item] + + return self.vars[item] + + def __setitem__(self, key: str, value: Variable) -> None: + self.vars[key] = value + + def __iter__(self) -> Iterator[str]: + return iter(self.keys()) + + def __contains__(self, item: str) -> bool: + return (item in self.vars) or ( + self.parent_scope is not None and item in self.parent_scope + ) + + def __copy__(self) -> "Locals": + # Make a copy of the var map so that mutating the copy doesn't + # mutate our variable mapping + return Locals(self.vars.copy(), copy.copy(self.parent_scope)) + + def keys(self) -> set[str]: + parent_keys = self.parent_scope.keys() if self.parent_scope else set() + return parent_keys | self.vars.keys() + + def items(self) -> Iterable[tuple[str, Variable]]: + parent_items = ( + iter(self.parent_scope.items()) if self.parent_scope else iter(()) + ) + return itertools.chain(self.vars.items(), parent_items) class Context(NamedTuple): diff --git a/guppy/checker/expr_checker.py b/guppy/checker/expr_checker.py index c54f5e68..2b1d9f27 100644 --- a/guppy/checker/expr_checker.py +++ b/guppy/checker/expr_checker.py @@ -22,10 +22,19 @@ import ast from contextlib import suppress -from typing import Any, NoReturn - -from guppy.ast_util import AstNode, AstVisitor, get_type_opt, with_loc, with_type -from guppy.checker.core import CallableVariable, Context, Globals +from typing import Any, NoReturn, cast + +from guppy.ast_util import ( + AstNode, + AstVisitor, + breaks_in_loop, + get_type_opt, + name_nodes_in_ast, + return_nodes_in_ast, + with_loc, + with_type, +) +from guppy.checker.core import CallableVariable, Context, Globals, Locals from guppy.error import ( GuppyError, GuppyTypeError, @@ -38,11 +47,25 @@ FunctionType, GuppyType, Inst, + LinstType, + ListType, + NoneType, Subst, TupleType, unify, ) -from guppy.nodes import GlobalName, LocalCall, LocalName, TypeApply +from guppy.nodes import ( + DesugaredGenerator, + DesugaredListComp, + GlobalName, + IterEnd, + IterHasNext, + IterNext, + LocalCall, + LocalName, + MakeIter, + TypeApply, +) # Mapping from unary AST op to dunder method and display name unary_table: dict[type[ast.unaryop], tuple[str, str]] = { @@ -114,6 +137,14 @@ def check( a new desugared expression with type annotations and a substitution with the resolved type variables. """ + # If we already have a type for the expression, we just have to match it against + # the target + if actual := get_type_opt(expr): + subst, inst = check_type_against(actual, ty, expr, kind) + if inst: + expr = with_loc(expr, TypeApply(value=expr, tys=inst)) + return with_type(ty.substitute(subst), expr), subst + # When checking against a variable, we have to synthesize if isinstance(ty, FreeTypeVar): expr, syn_ty = self._synthesize(expr, allow_free_vars=False) @@ -141,6 +172,27 @@ def visit_Tuple(self, node: ast.Tuple, ty: GuppyType) -> tuple[ast.expr, Subst]: subst |= s return node, subst + def visit_List(self, node: ast.List, ty: GuppyType) -> tuple[ast.expr, Subst]: + if not isinstance(ty, ListType | LinstType): + return self._fail(ty, node) + subst: Subst = {} + for i, el in enumerate(node.elts): + node.elts[i], s = self.check(el, ty.element_type.substitute(subst)) + subst |= s + return node, subst + + def visit_DesugaredListComp( + self, node: DesugaredListComp, ty: GuppyType + ) -> tuple[ast.expr, Subst]: + if not isinstance(ty, ListType | LinstType): + return self._fail(ty, node) + node, elt_ty = synthesize_comprehension(node, node.generators, self.ctx) + subst = unify(ty.element_type, elt_ty, {}) + if subst is None: + actual = LinstType(elt_ty) if elt_ty.linear else ListType(elt_ty) + return self._fail(ty, actual, node) + return node, subst + def visit_Call(self, node: ast.Call, ty: GuppyType) -> tuple[ast.expr, Subst]: if len(node.keywords) > 0: raise GuppyError( @@ -213,7 +265,7 @@ def visit_Constant(self, node: ast.Constant) -> tuple[ast.expr, GuppyType]: raise GuppyError("Unsupported constant", node) return node, ty - def visit_Name(self, node: ast.Name) -> tuple[ast.expr, GuppyType]: + def visit_Name(self, node: ast.Name) -> tuple[ast.Name, GuppyType]: x = node.id if x in self.ctx.locals: var = self.ctx.locals[x] @@ -240,6 +292,22 @@ def visit_Tuple(self, node: ast.Tuple) -> tuple[ast.expr, GuppyType]: node.elts = [n for n, _ in elems] return node, TupleType([ty for _, ty in elems]) + def visit_List(self, node: ast.List) -> tuple[ast.expr, GuppyType]: + if len(node.elts) == 0: + raise GuppyTypeInferenceError( + "Cannot infer type variable in expression of type `list[?T]`", node + ) + node.elts[0], el_ty = self.synthesize(node.elts[0]) + node.elts[1:] = [self._check(el, el_ty)[0] for el in node.elts[1:]] + return node, LinstType(el_ty) if el_ty.linear else ListType(el_ty) + + def visit_DesugaredListComp( + self, node: DesugaredListComp + ) -> tuple[ast.expr, GuppyType]: + node, elt_ty = synthesize_comprehension(node, node.generators, self.ctx) + result_ty = LinstType(elt_ty) if elt_ty.linear else ListType(elt_ty) + return node, result_ty + def visit_UnaryOp(self, node: ast.UnaryOp) -> tuple[ast.expr, GuppyType]: # We need to synthesise the argument type, so we can look up dunder methods node.operand, op_ty = self.synthesize(node.operand) @@ -289,6 +357,36 @@ def _synthesize_binary( node, ) + def _synthesize_instance_func( + self, + node: ast.expr, + args: list[ast.expr], + func_name: str, + err: str, + exp_ty: FunctionType | None = None, + var: FreeTypeVar | None = None, + give_reason: bool = False, + ) -> tuple[ast.expr, GuppyType]: + """Helper method for expressions that are implemented via instance methods.""" + node, ty = self.synthesize(node) + func = self.ctx.globals.get_instance_func(ty, func_name) + if func is None: + reason = f" since it does not implement the `{func_name}` method" + raise GuppyTypeError( + f"Expression of type `{ty}` is {err}{reason if give_reason else ''}", + node, + ) + if exp_ty: + assert var is not None + exp_ty = cast(FunctionType, exp_ty.substitute({var: ty})) + if unify(exp_ty, func.ty.unquantified()[0], {}) is None: + raise GuppyError( + f"Method `{ty.name}.{func_name}` has signature `{func.ty}`, but " + f"expected `{exp_ty}`", + node, + ) + return func.synthesize_call([node, *args], node, self.ctx) + def visit_BinOp(self, node: ast.BinOp) -> tuple[ast.expr, GuppyType]: return self._synthesize_binary(node.left, node.right, node.op, node) @@ -301,6 +399,16 @@ def visit_Compare(self, node: ast.Compare) -> tuple[ast.expr, GuppyType]: left_expr, [op], [right_expr] = node.left, node.ops, node.comparators return self._synthesize_binary(left_expr, right_expr, op, node) + def visit_Subscript(self, node: ast.Subscript) -> tuple[ast.expr, GuppyType]: + var = FreeTypeVar.new("T", False) + exp_ty = FunctionType( + [var, FreeTypeVar.new("Key", False)], + FreeTypeVar.new("Val", False), + ) + return self._synthesize_instance_func( + node.value, [node.slice], "__getitem__", "not subscriptable", exp_ty, var + ) + def visit_Call(self, node: ast.Call) -> tuple[ast.expr, GuppyType]: if len(node.keywords) > 0: raise GuppyError("Keyword arguments are not supported", node.keywords[0]) @@ -322,6 +430,55 @@ def visit_Call(self, node: ast.Call) -> tuple[ast.expr, GuppyType]: else: raise GuppyTypeError(f"Expected function type, got `{ty}`", node.func) + def visit_MakeIter(self, node: MakeIter) -> tuple[ast.expr, GuppyType]: + var = FreeTypeVar.new("T", False) + exp_ty = FunctionType([var], FreeTypeVar.new("Iter", False)) + expr, ty = self._synthesize_instance_func( + node.value, [], "__iter__", "not iterable", exp_ty, var + ) + + # If the iterator was created by a `for` loop, we can add some extra checks to + # produce nicer errors for linearity violations. Namely, `break` and `return` + # are not allowed when looping over a linear iterator (`continue` is allowed) + if ty.linear and isinstance(node.origin_node, ast.For): + breaks = breaks_in_loop(node.origin_node) or return_nodes_in_ast( + node.origin_node + ) + if breaks: + raise GuppyTypeError( + f"Loop over iterator with linear type `{ty}` cannot be terminated " + f"(cannot ensure that all values have been used)", + breaks[0], + ) + return expr, ty + + def visit_IterHasNext(self, node: IterHasNext) -> tuple[ast.expr, GuppyType]: + var = FreeTypeVar.new("Iter", False) + exp_ty = FunctionType([var], TupleType([BoolType(), var])) + return self._synthesize_instance_func( + node.value, [], "__hasnext__", "not an iterator", exp_ty, var, True + ) + + def visit_IterNext(self, node: IterNext) -> tuple[ast.expr, GuppyType]: + var = FreeTypeVar.new("Iter", False) + exp_ty = FunctionType([var], TupleType([FreeTypeVar.new("T", True), var])) + return self._synthesize_instance_func( + node.value, [], "__next__", "not an iterator", exp_ty, var, True + ) + + def visit_IterEnd(self, node: IterEnd) -> tuple[ast.expr, GuppyType]: + var = FreeTypeVar.new("Iter", False) + exp_ty = FunctionType([var], NoneType()) + return self._synthesize_instance_func( + node.value, [], "__end__", "not an iterator", exp_ty, var, True + ) + + def visit_ListComp(self, node: ast.ListComp) -> tuple[ast.expr, GuppyType]: + raise InternalGuppyError( + "BB contains `ListComp`. Should have been removed during CFG" + f"construction: `{ast.unparse(node)}`" + ) + def visit_NamedExpr(self, node: ast.NamedExpr) -> tuple[ast.expr, GuppyType]: raise InternalGuppyError( "BB contains `NamedExpr`. Should have been removed during CFG" @@ -598,6 +755,95 @@ def to_bool( return call, return_ty +def synthesize_comprehension( + node: DesugaredListComp, gens: list[DesugaredGenerator], ctx: Context +) -> tuple[DesugaredListComp, GuppyType]: + """Helper function to synthesise the element type of a list comprehension.""" + from guppy.checker.stmt_checker import StmtChecker + + def check_linear_use_from_outer_scope(expr: ast.expr, locals: Locals) -> None: + """Checks if an expression uses a linear variable from an outer scope. + + Since the expression is executed multiple times in the inner scope, this would + mean that the outer linear variable is used multiple times, which is not + allowed. + """ + for name in name_nodes_in_ast(expr): + x = name.id + if x in locals and x not in locals.vars: + var = locals[x] + if var.ty.linear: + raise GuppyTypeError( + f"Variable `{x}` with linear type `{var.ty}` would be used " + "multiple times when evaluating this comprehension", + name, + ) + + # If there are no more generators left, we can check the list element + if not gens: + node.elt, elt_ty = ExprSynthesizer(ctx).synthesize(node.elt) + check_linear_use_from_outer_scope(node.elt, ctx.locals) + return node, elt_ty + + # Check the iterator in the outer context + gen, *gens = gens + gen.iter_assign = StmtChecker(ctx).visit_Assign(gen.iter_assign) + check_linear_use_from_outer_scope(gen.iter_assign.value, ctx.locals) + + # The rest is checked in a new nested context to ensure that variables don't escape + # their scope + inner_locals = Locals({}, parent_scope=ctx.locals) + inner_ctx = Context(ctx.globals, inner_locals) + expr_sth, stmt_chk = ExprSynthesizer(inner_ctx), StmtChecker(inner_ctx) + gen.hasnext_assign = stmt_chk.visit_Assign(gen.hasnext_assign) + gen.next_assign = stmt_chk.visit_Assign(gen.next_assign) + gen.hasnext, hasnext_ty = expr_sth.visit_Name(gen.hasnext) + gen.hasnext = with_type(hasnext_ty, gen.hasnext) + gen.iter, iter_ty = expr_sth.visit_Name(gen.iter) + gen.iter = with_type(iter_ty, gen.iter) + + # `if` guards are generally not allowed when we're iterating over linear variables. + # The only exception is if all linear variables are already consumed by the first + # guard + if gen.ifs: + gen.ifs[0], _ = expr_sth.synthesize(gen.ifs[0]) + + # Now, check if there are linear iteration variables that have not been used by + # the first guard + for target in name_nodes_in_ast(gen.next_assign.targets[0]): + var = inner_ctx.locals[target.id] + if var.ty.linear and not var.used and gen.ifs: + raise GuppyTypeError( + f"Variable `{var.name}` with linear type `{var.ty}` is not used on " + "all control-flow paths of the list comprehension", + target, + ) + + # Now, we can properly check all guards + for i in range(len(gen.ifs)): + gen.ifs[i], if_ty = expr_sth.synthesize(gen.ifs[i]) + gen.ifs[i], _ = to_bool(gen.ifs[i], if_ty, inner_ctx) + check_linear_use_from_outer_scope(gen.ifs[i], inner_locals) + + # Check remaining generators + node, elt_ty = synthesize_comprehension(node, gens, inner_ctx) + + # We have to make sure that all linear variables that were introduced in this scope + # have been used + for x, var in inner_ctx.locals.vars.items(): + if var.ty.linear and not var.used: + raise GuppyTypeError( + f"Variable `{x}` with linear type `{var.ty}` is not used", + var.defined_at, + ) + + # The iter finalizer is again checked in the outer context + ctx.locals[gen.iter.id].used = None + gen.iterend, iterend_ty = ExprSynthesizer(ctx).synthesize(gen.iterend) + gen.iterend = with_type(iterend_ty, gen.iterend) + return node, elt_ty + + def python_value_to_guppy_type( v: Any, node: ast.expr, globals: Globals ) -> GuppyType | None: diff --git a/guppy/checker/func_checker.py b/guppy/checker/func_checker.py index dcef53bd..022446ee 100644 --- a/guppy/checker/func_checker.py +++ b/guppy/checker/func_checker.py @@ -49,14 +49,14 @@ def check_call( ) -> tuple[ast.expr, Subst]: # Use default implementation from the expression checker args, subst, inst = check_call(self.ty, args, ty, node, ctx) - return GlobalCall(func=self, args=args, type_args=inst), subst + return with_loc(node, GlobalCall(func=self, args=args, type_args=inst)), subst def synthesize_call( self, args: list[ast.expr], node: AstNode, ctx: Context ) -> tuple[GlobalCall, GuppyType]: # Use default implementation from the expression checker args, ty, inst = synthesize_call(self.ty, args, node, ctx) - return GlobalCall(func=self, args=args, type_args=inst), ty + return with_loc(node, GlobalCall(func=self, args=args, type_args=inst)), ty @dataclass diff --git a/guppy/checker/stmt_checker.py b/guppy/checker/stmt_checker.py index 3959801e..00559cf4 100644 --- a/guppy/checker/stmt_checker.py +++ b/guppy/checker/stmt_checker.py @@ -22,11 +22,13 @@ class StmtChecker(AstVisitor[BBStatement]): ctx: Context - bb: BB - return_ty: GuppyType + bb: BB | None + return_ty: GuppyType | None - def __init__(self, ctx: Context, bb: BB, return_ty: GuppyType) -> None: - assert not return_ty.free_vars + def __init__( + self, ctx: Context, bb: BB | None = None, return_ty: GuppyType | None = None + ) -> None: + assert not return_ty or not return_ty.free_vars self.ctx = ctx self.bb = bb self.return_ty = return_ty @@ -55,7 +57,7 @@ def _check_assign(self, lhs: ast.expr, ty: GuppyType, node: ast.stmt) -> None: f"Variable `{x}` with linear type `{var.ty}` is not used", var.defined_at, ) - self.ctx.locals[x] = Variable(x, ty, node, None) + self.ctx.locals[x] = Variable(x, ty, lhs, None) # The only other thing we support right now are tuples case ast.Tuple(elts=elts): @@ -76,7 +78,7 @@ def _check_assign(self, lhs: ast.expr, ty: GuppyType, node: ast.stmt) -> None: case _: raise GuppyError("Assignment pattern not supported", lhs) - def visit_Assign(self, node: ast.Assign) -> ast.stmt: + def visit_Assign(self, node: ast.Assign) -> ast.Assign: if len(node.targets) > 1: # This is the case for assignments like `a = b = 1` raise GuppyError("Multi assignment not supported", node) @@ -113,6 +115,9 @@ def visit_Expr(self, node: ast.Expr) -> ast.stmt: return node def visit_Return(self, node: ast.Return) -> ast.stmt: + if not self.return_ty: + raise InternalGuppyError("return_ty required to check return stmt!") + if node.value is not None: node.value, subst = self._check_expr( node.value, self.return_ty, "return value" @@ -127,6 +132,9 @@ def visit_Return(self, node: ast.Return) -> ast.stmt: def visit_NestedFunctionDef(self, node: NestedFunctionDef) -> ast.stmt: from guppy.checker.func_checker import check_nested_func_def + if not self.bb: + raise InternalGuppyError("BB required to check nested function def!") + func_def = check_nested_func_def(node, self.bb, self.ctx) self.ctx.locals[func_def.name] = Variable( func_def.name, func_def.ty, func_def, None diff --git a/guppy/declared.py b/guppy/declared.py index b868a374..df1c0f29 100644 --- a/guppy/declared.py +++ b/guppy/declared.py @@ -1,7 +1,7 @@ import ast from dataclasses import dataclass -from guppy.ast_util import AstNode, has_empty_body +from guppy.ast_util import AstNode, has_empty_body, with_loc from guppy.checker.core import Context, Globals from guppy.checker.expr_checker import check_call, synthesize_call from guppy.checker.func_checker import check_signature @@ -34,14 +34,14 @@ def check_call( ) -> tuple[ast.expr, Subst]: # Use default implementation from the expression checker args, subst, inst = check_call(self.ty, args, ty, node, ctx) - return GlobalCall(func=self, args=args, type_args=inst), subst + return with_loc(node, GlobalCall(func=self, args=args, type_args=inst)), subst def synthesize_call( self, args: list[ast.expr], node: AstNode, ctx: Context ) -> tuple[GlobalCall, GuppyType]: # Use default implementation from the expression checker args, ty, inst = synthesize_call(self.ty, args, node, ctx) - return GlobalCall(func=self, args=args, type_args=inst), ty + return with_loc(node, GlobalCall(func=self, args=args, type_args=inst)), ty def add_to_graph(self, graph: Hugr, parent: Node) -> None: self.node = graph.add_declare(self.ty, parent, self.name) From 000e1466d0442ccd78c3292283473cdfb6228847 Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Wed, 20 Dec 2023 13:37:45 +0000 Subject: [PATCH 03/31] Fix error golden files --- tests/error/linear_errors/branch_use.err | 2 +- tests/error/linear_errors/break_unused.err | 2 +- tests/error/linear_errors/continue_unused.err | 2 +- tests/error/linear_errors/if_both_unused.err | 2 +- tests/error/linear_errors/if_both_unused_reassign.err | 2 +- tests/error/linear_errors/unused.err | 2 +- tests/error/linear_errors/unused_same_block.err | 2 +- 7 files changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/error/linear_errors/branch_use.err b/tests/error/linear_errors/branch_use.err index 742a0ca7..c096aa16 100644 --- a/tests/error/linear_errors/branch_use.err +++ b/tests/error/linear_errors/branch_use.err @@ -3,5 +3,5 @@ Guppy compilation failed. Error in file $FILE:23 21: @guppy(module) 22: def foo(b: bool) -> bool: 23: q = new_qubit() - ^^^^^^^^^^^^^^^ + ^ GuppyError: Variable `q` with linear type `Qubit` is not used on all control-flow paths diff --git a/tests/error/linear_errors/break_unused.err b/tests/error/linear_errors/break_unused.err index a91d243e..8ee9039f 100644 --- a/tests/error/linear_errors/break_unused.err +++ b/tests/error/linear_errors/break_unused.err @@ -3,5 +3,5 @@ Guppy compilation failed. Error in file $FILE:25 23: b = False 24: while True: 25: q = new_qubit() - ^^^^^^^^^^^^^^^ + ^ GuppyError: Variable `q` with linear type `Qubit` is not used on all control-flow paths diff --git a/tests/error/linear_errors/continue_unused.err b/tests/error/linear_errors/continue_unused.err index 6b91f2be..7ecd4c5a 100644 --- a/tests/error/linear_errors/continue_unused.err +++ b/tests/error/linear_errors/continue_unused.err @@ -3,5 +3,5 @@ Guppy compilation failed. Error in file $FILE:25 23: b = False 24: while i > 0: 25: q = new_qubit() - ^^^^^^^^^^^^^^^ + ^ GuppyError: Variable `q` with linear type `Qubit` is not used on all control-flow paths diff --git a/tests/error/linear_errors/if_both_unused.err b/tests/error/linear_errors/if_both_unused.err index db518100..f942875c 100644 --- a/tests/error/linear_errors/if_both_unused.err +++ b/tests/error/linear_errors/if_both_unused.err @@ -3,5 +3,5 @@ Guppy compilation failed. Error in file $FILE:19 17: def foo(b: bool) -> int: 18: if b: 19: q = new_qubit() - ^^^^^^^^^^^^^^^ + ^ GuppyError: Variable `q` with linear type `Qubit` is not used on all control-flow paths diff --git a/tests/error/linear_errors/if_both_unused_reassign.err b/tests/error/linear_errors/if_both_unused_reassign.err index cacd42ca..0b0d5dbe 100644 --- a/tests/error/linear_errors/if_both_unused_reassign.err +++ b/tests/error/linear_errors/if_both_unused_reassign.err @@ -3,5 +3,5 @@ Guppy compilation failed. Error in file $FILE:19 17: def foo(b: bool) -> Qubit: 18: if b: 19: q = new_qubit() - ^^^^^^^^^^^^^^^ + ^ GuppyError: Variable `q` with linear type `Qubit` is not used on all control-flow paths diff --git a/tests/error/linear_errors/unused.err b/tests/error/linear_errors/unused.err index 7d62fee3..764d3d95 100644 --- a/tests/error/linear_errors/unused.err +++ b/tests/error/linear_errors/unused.err @@ -3,5 +3,5 @@ Guppy compilation failed. Error in file $FILE:13 11: @guppy(module) 12: def foo(q: Qubit) -> int: 13: x = q - ^^^^^ + ^ GuppyError: Variable `x` with linear type `Qubit` is not used on all control-flow paths diff --git a/tests/error/linear_errors/unused_same_block.err b/tests/error/linear_errors/unused_same_block.err index 960f266c..835fff05 100644 --- a/tests/error/linear_errors/unused_same_block.err +++ b/tests/error/linear_errors/unused_same_block.err @@ -3,5 +3,5 @@ Guppy compilation failed. Error in file $FILE:13 11: @guppy(module) 12: def foo(q: Qubit) -> int: 13: x = q - ^^^^^ + ^ GuppyError: Variable `x` with linear type `Qubit` is not used From 34d5415c697cde163aa97f43f490b6f45c54057b Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Wed, 20 Dec 2023 13:46:24 +0000 Subject: [PATCH 04/31] feat: Compile lists and comprehensions --- guppy/compiler/cfg_compiler.py | 2 +- guppy/compiler/expr_compiler.py | 157 +++++++++++++++++++++++++++++++- guppy/compiler/stmt_compiler.py | 4 - guppy/hugr/hugr.py | 20 ++++ 4 files changed, 174 insertions(+), 9 deletions(-) diff --git a/guppy/compiler/cfg_compiler.py b/guppy/compiler/cfg_compiler.py index d2f1abb9..0b9c9a9a 100644 --- a/guppy/compiler/cfg_compiler.py +++ b/guppy/compiler/cfg_compiler.py @@ -56,7 +56,7 @@ def compile_bb( for (i, v) in enumerate(inputs) }, ) - dfg = StmtCompiler(graph, globals).compile_stmts(bb.statements, bb, dfg) + dfg = StmtCompiler(graph, globals).compile_stmts(bb.statements, dfg) # If we branch, we also have to compile the branch predicate if len(bb.successors) > 1: diff --git a/guppy/compiler/expr_compiler.py b/guppy/compiler/expr_compiler.py index 97b08945..dff60963 100644 --- a/guppy/compiler/expr_compiler.py +++ b/guppy/compiler/expr_compiler.py @@ -1,8 +1,16 @@ import ast +from collections.abc import Iterator +from contextlib import contextmanager from typing import Any -from guppy.ast_util import AstVisitor, get_type -from guppy.compiler.core import CompiledFunction, CompilerBase, DFContainer +from guppy.ast_util import AstVisitor, get_type, with_loc, with_type +from guppy.cfg.builder import tmp_vars +from guppy.compiler.core import ( + CompiledFunction, + CompilerBase, + DFContainer, + PortVariable, +) from guppy.error import GuppyError, InternalGuppyError from guppy.gtypes import ( BoolType, @@ -14,8 +22,16 @@ type_to_row, ) from guppy.hugr import ops, val -from guppy.hugr.hugr import OutPortV -from guppy.nodes import GlobalCall, GlobalName, LocalCall, LocalName, TypeApply +from guppy.hugr.hugr import DFContainingNode, OutPortV, VNode +from guppy.nodes import ( + DesugaredGenerator, + DesugaredListComp, + GlobalCall, + GlobalName, + LocalCall, + LocalName, + TypeApply, +) class ExprCompiler(CompilerBase, AstVisitor[OutPortV]): @@ -39,6 +55,64 @@ def compile_row(self, expr: ast.expr, dfg: DFContainer) -> list[OutPortV]: """ return [self.compile(e, dfg) for e in expr_to_row(expr)] + @contextmanager + def _new_dfcontainer( + self, inputs: list[ast.Name], node: DFContainingNode + ) -> Iterator[None]: + """Context manager to build a graph inside a new `DFContainer`. + + Automatically updates `self.dfg` and makes the inputs available. + """ + old = self.dfg + inp = self.graph.add_input(parent=node) + new_locals = { + name.id: PortVariable(name.id, inp.add_out_port(get_type(name)), name, None) + for name in inputs + } + self.dfg = DFContainer(node, self.dfg.locals | new_locals) + with self.graph.parent(node): + yield + self.dfg = old + + @contextmanager + def _new_loop( + self, + inputs: list[ast.Name], + branch: ast.Name, + parent: DFContainingNode | None = None, + ) -> Iterator[None]: + """Context manager to build a graph inside a new `TailLoop` node. + + Automatically adds the `Output` node once the context manager exists. + """ + loop = self.graph.add_tail_loop([self.visit(name) for name in inputs], parent) + with self._new_dfcontainer(inputs, loop): + yield + # Output the branch predicate and the inputs for the next iteration + self.graph.add_output( + [self.visit(branch), *(self.visit(name) for name in inputs)] + ) + # Update the DFG with the outputs from the loop + for name in inputs: + self.dfg[name.id].port = loop.add_out_port(get_type(name)) + + @contextmanager + def _new_case( + self, inputs: list[ast.Name], outputs: list[ast.Name], cond_node: VNode + ) -> Iterator[None]: + """Context manager to build a graph inside a new `Case` node. + + Automatically adds the `Output` node once the context manager exists. + """ + with self._new_dfcontainer(inputs, self.graph.add_case(cond_node)): + yield + self.graph.add_output([self.visit(name) for name in outputs]) + # Update the DFG with the outputs from the Conditional node, but only we haven't + # already added some + if cond_node.num_out_ports == 0: + for name in inputs: + self.dfg[name.id].port = cond_node.add_out_port(get_type(name)) + def visit_Constant(self, node: ast.Constant) -> OutPortV: if value := python_value_to_hugr(node.value): const = self.graph.add_constant(value, get_type(node)).out_port(0) @@ -59,6 +133,11 @@ def visit_Tuple(self, node: ast.Tuple) -> OutPortV: inputs=[self.visit(e) for e in node.elts] ).out_port(0) + def visit_List(self, node: ast.List) -> OutPortV: + return self.graph.add_node( + ops.DummyOp(name="MakeList"), inputs=[self.visit(e) for e in node.elts] + ).add_out_port(get_type(node)) + def _pack_returns(self, returns: list[OutPortV]) -> OutPortV: """Groups function return values into a tuple""" if len(returns) != 1: @@ -118,6 +197,76 @@ def visit_UnaryOp(self, node: ast.UnaryOp) -> OutPortV: raise InternalGuppyError("Node should have been removed during type checking.") + def visit_DesugaredListComp(self, node: DesugaredListComp) -> OutPortV: + from guppy.compiler.stmt_compiler import StmtCompiler + + compiler = StmtCompiler(self.graph, self.globals) + + # Make up a name for the list under construction and bind it to an empty list + list_ty = get_type(node) + list_name = with_type(list_ty, with_loc(node, LocalName(id=next(tmp_vars)))) + empty_list = self.graph.add_node(ops.DummyOp(name="MakeList")) + self.dfg[list_name.id] = PortVariable( + list_name.id, empty_list.add_out_port(list_ty), node, None + ) + + def compile_generators(elt: ast.expr, gens: list[DesugaredGenerator]) -> None: + """Helper function to generate nested TailLoop nodes for generators""" + # If there are no more generators left, just append the element to the list + if not gens: + list_port, elt_port = self.visit(list_name), self.visit(elt) + push = self.graph.add_node( + ops.DummyOp(name="Push"), inputs=[list_port, elt_port] + ) + self.dfg[list_name.id].port = push.add_out_port(list_port.ty) + return + + # Otherwise, compile the first iterator and construct a TailLoop + gen, *gens = gens + compiler.compile_stmts([gen.iter_assign], self.dfg) + inputs = [gen.iter, list_name] + with self._new_loop(inputs, gen.hasnext): + # Compile the `hasnext` check and plug it into a conditional + compiler.compile_stmts([gen.hasnext_assign], self.dfg) + cond = self.graph.add_conditional( + self.visit(gen.hasnext), + [self.visit(gen.iter), self.visit(list_name)], + ) + + # If the iterator is finished, output the iterator and list as is + with self._new_case(inputs, inputs, cond): + pass + + # If there is a next element, compile it and continue with the next + # generator + with self._new_case(inputs, inputs, cond): + + def compile_ifs(ifs: list[ast.expr]) -> None: + if not ifs: + # If there are no guards left, compile the next generator + compile_generators(elt, gens) + return + if_expr, *ifs = ifs + cond = self.graph.add_conditional( + self.visit(if_expr), + [self.visit(gen.iter), self.visit(list_name)], + ) + # If the condition is false, output the iterator and list as is + with self._new_case(inputs, inputs, cond): + pass + # If the condition is true, continue with the next one + with self._new_case(inputs, inputs, cond): + compile_ifs(ifs) + + compiler.compile_stmts([gen.next_assign], self.dfg) + compile_ifs(gen.ifs) + + # After the loop is done, we have to finalize the iterator + self.visit(gen.iterend) + + compile_generators(node.elt, node.generators) + return self.visit(list_name) + def visit_BinOp(self, node: ast.BinOp) -> OutPortV: raise InternalGuppyError("Node should have been removed during type checking.") diff --git a/guppy/compiler/stmt_compiler.py b/guppy/compiler/stmt_compiler.py index 18147b22..a2ab4f6a 100644 --- a/guppy/compiler/stmt_compiler.py +++ b/guppy/compiler/stmt_compiler.py @@ -2,7 +2,6 @@ from collections.abc import Sequence from guppy.ast_util import AstVisitor -from guppy.checker.cfg_checker import CheckedBB from guppy.compiler.core import ( CompiledGlobals, CompilerBase, @@ -22,7 +21,6 @@ class StmtCompiler(CompilerBase, AstVisitor[None]): expr_compiler: ExprCompiler - bb: CheckedBB dfg: DFContainer def __init__(self, graph: Hugr, globals: CompiledGlobals): @@ -32,7 +30,6 @@ def __init__(self, graph: Hugr, globals: CompiledGlobals): def compile_stmts( self, stmts: Sequence[ast.stmt], - bb: CheckedBB, dfg: DFContainer, ) -> DFContainer: """Compiles a list of basic statements into a dataflow node. @@ -40,7 +37,6 @@ def compile_stmts( Note that the `dfg` is mutated in-place. After compilation, the DFG will also contain all variables that are assigned in the given list of statements. """ - self.bb = bb self.dfg = dfg for s in stmts: self.visit(s) diff --git a/guppy/hugr/hugr.py b/guppy/hugr/hugr.py index 71846e6f..a643d065 100644 --- a/guppy/hugr/hugr.py +++ b/guppy/hugr/hugr.py @@ -384,6 +384,7 @@ def add_output( parent: Node | None = None, ) -> VNode: """Adds an `Output` node to the graph.""" + parent = parent or self._default_parent node = self.add_node(ops.Output(), input_tys, [], parent, inputs) if isinstance(parent, DFContainingNode): parent.output_child = node @@ -697,6 +698,25 @@ def insert_order_edges(self) -> "Hugr": elif isinstance(n.op, ops.LoadConstant): assert n.parent.input_child is not None self.add_order_edge(n.parent.input_child, n) + + # Also add order edges for non-local edges + for src, tgt in list(self.edges()): + # Exclude CF and constant edges + if isinstance(src, OutPortCF) or isinstance( + src.node.op, ops.FuncDecl | ops.FuncDefn | ops.Const + ): + continue + + if src.node.parent != tgt.node.parent: + # Walk up the hierarchy from the src until we hit a node at the same + # level as tgt + node = tgt.node + while node.parent != src.node.parent: + if node.parent is None: + raise ValueError("Invalid non-local edge!") + node = node.parent + # Edge order edge to make sure that the src is executed first + self.add_order_edge(src.node, node) return self def to_raw(self) -> raw.RawHugr: From a55c14700057aaaed30d7fdc13288cdf3785ccb2 Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Wed, 20 Dec 2023 13:57:17 +0000 Subject: [PATCH 05/31] test: Add list, loop, and comprehension tests --- guppy/module.py | 25 +- ruff.toml | 2 +- tests/error/comprehension_errors/__init__.py | 0 tests/error/comprehension_errors/capture1.err | 7 + tests/error/comprehension_errors/capture1.py | 16 ++ tests/error/comprehension_errors/capture2.err | 7 + tests/error/comprehension_errors/capture2.py | 21 ++ tests/error/comprehension_errors/guarded1.err | 7 + tests/error/comprehension_errors/guarded1.py | 16 ++ tests/error/comprehension_errors/guarded2.err | 7 + tests/error/comprehension_errors/guarded2.py | 21 ++ .../illegal_short_circuit.err | 7 + .../illegal_short_circuit.py | 6 + .../comprehension_errors/illegal_ternary.err | 7 + .../comprehension_errors/illegal_ternary.py | 6 + .../comprehension_errors/illegal_walrus.err | 7 + .../comprehension_errors/illegal_walrus.py | 6 + .../error/comprehension_errors/multi_use1.err | 7 + .../error/comprehension_errors/multi_use1.py | 16 ++ .../error/comprehension_errors/multi_use2.err | 7 + .../error/comprehension_errors/multi_use2.py | 16 ++ .../error/comprehension_errors/multi_use3.err | 7 + .../error/comprehension_errors/multi_use3.py | 21 ++ tests/error/comprehension_errors/not_used.err | 7 + tests/error/comprehension_errors/not_used.py | 16 ++ .../pattern_override1.err | 7 + .../comprehension_errors/pattern_override1.py | 16 ++ .../pattern_override2.err | 7 + .../comprehension_errors/pattern_override2.py | 16 ++ .../comprehension_errors/used_twice1.err | 7 + .../error/comprehension_errors/used_twice1.py | 16 ++ .../comprehension_errors/used_twice2.err | 7 + .../error/comprehension_errors/used_twice2.py | 21 ++ .../comprehension_errors/used_twice3.err | 7 + .../error/comprehension_errors/used_twice3.py | 21 ++ tests/error/errors_on_usage/for_new_var.err | 7 + tests/error/errors_on_usage/for_new_var.py | 8 + tests/error/errors_on_usage/for_target.err | 7 + tests/error/errors_on_usage/for_target.py | 8 + .../for_target_type_change.err | 7 + .../errors_on_usage/for_target_type_change.py | 9 + .../error/errors_on_usage/for_type_change.err | 7 + .../error/errors_on_usage/for_type_change.py | 9 + tests/error/iter_errors/__init__.py | 0 tests/error/iter_errors/end_missing.err | 7 + tests/error/iter_errors/end_missing.py | 37 +++ tests/error/iter_errors/end_wrong_type.err | 7 + tests/error/iter_errors/end_wrong_type.py | 41 +++ tests/error/iter_errors/hasnext_missing.err | 7 + tests/error/iter_errors/hasnext_missing.py | 37 +++ .../error/iter_errors/hasnext_wrong_type.err | 7 + tests/error/iter_errors/hasnext_wrong_type.py | 41 +++ tests/error/iter_errors/iter_missing.err | 7 + tests/error/iter_errors/iter_missing.py | 20 ++ tests/error/iter_errors/iter_wrong_type.err | 7 + tests/error/iter_errors/iter_wrong_type.py | 24 ++ tests/error/iter_errors/next_missing.err | 7 + tests/error/iter_errors/next_missing.py | 37 +++ tests/error/iter_errors/next_wrong_type.err | 7 + tests/error/iter_errors/next_wrong_type.py | 41 +++ tests/error/linear_errors/for_break.err | 7 + tests/error/linear_errors/for_break.py | 21 ++ tests/error/linear_errors/for_return.err | 7 + tests/error/linear_errors/for_return.py | 21 ++ tests/error/misc_errors/list_linear.err | 6 + tests/error/misc_errors/list_linear.py | 17 ++ tests/error/test_comprehension_errors.py | 20 ++ tests/error/test_iter_errors.py | 20 ++ tests/integration/test_comprehension.py | 260 ++++++++++++++++++ tests/integration/test_for.py | 131 +++++++++ tests/integration/test_linear.py | 97 +++++++ tests/integration/test_linst.py | 66 +++++ tests/integration/test_list.py | 45 +++ tests/integration/test_poly.py | 16 ++ validator/src/lib.rs | 5 +- 75 files changed, 1517 insertions(+), 8 deletions(-) create mode 100644 tests/error/comprehension_errors/__init__.py create mode 100644 tests/error/comprehension_errors/capture1.err create mode 100644 tests/error/comprehension_errors/capture1.py create mode 100644 tests/error/comprehension_errors/capture2.err create mode 100644 tests/error/comprehension_errors/capture2.py create mode 100644 tests/error/comprehension_errors/guarded1.err create mode 100644 tests/error/comprehension_errors/guarded1.py create mode 100644 tests/error/comprehension_errors/guarded2.err create mode 100644 tests/error/comprehension_errors/guarded2.py create mode 100644 tests/error/comprehension_errors/illegal_short_circuit.err create mode 100644 tests/error/comprehension_errors/illegal_short_circuit.py create mode 100644 tests/error/comprehension_errors/illegal_ternary.err create mode 100644 tests/error/comprehension_errors/illegal_ternary.py create mode 100644 tests/error/comprehension_errors/illegal_walrus.err create mode 100644 tests/error/comprehension_errors/illegal_walrus.py create mode 100644 tests/error/comprehension_errors/multi_use1.err create mode 100644 tests/error/comprehension_errors/multi_use1.py create mode 100644 tests/error/comprehension_errors/multi_use2.err create mode 100644 tests/error/comprehension_errors/multi_use2.py create mode 100644 tests/error/comprehension_errors/multi_use3.err create mode 100644 tests/error/comprehension_errors/multi_use3.py create mode 100644 tests/error/comprehension_errors/not_used.err create mode 100644 tests/error/comprehension_errors/not_used.py create mode 100644 tests/error/comprehension_errors/pattern_override1.err create mode 100644 tests/error/comprehension_errors/pattern_override1.py create mode 100644 tests/error/comprehension_errors/pattern_override2.err create mode 100644 tests/error/comprehension_errors/pattern_override2.py create mode 100644 tests/error/comprehension_errors/used_twice1.err create mode 100644 tests/error/comprehension_errors/used_twice1.py create mode 100644 tests/error/comprehension_errors/used_twice2.err create mode 100644 tests/error/comprehension_errors/used_twice2.py create mode 100644 tests/error/comprehension_errors/used_twice3.err create mode 100644 tests/error/comprehension_errors/used_twice3.py create mode 100644 tests/error/errors_on_usage/for_new_var.err create mode 100644 tests/error/errors_on_usage/for_new_var.py create mode 100644 tests/error/errors_on_usage/for_target.err create mode 100644 tests/error/errors_on_usage/for_target.py create mode 100644 tests/error/errors_on_usage/for_target_type_change.err create mode 100644 tests/error/errors_on_usage/for_target_type_change.py create mode 100644 tests/error/errors_on_usage/for_type_change.err create mode 100644 tests/error/errors_on_usage/for_type_change.py create mode 100644 tests/error/iter_errors/__init__.py create mode 100644 tests/error/iter_errors/end_missing.err create mode 100644 tests/error/iter_errors/end_missing.py create mode 100644 tests/error/iter_errors/end_wrong_type.err create mode 100644 tests/error/iter_errors/end_wrong_type.py create mode 100644 tests/error/iter_errors/hasnext_missing.err create mode 100644 tests/error/iter_errors/hasnext_missing.py create mode 100644 tests/error/iter_errors/hasnext_wrong_type.err create mode 100644 tests/error/iter_errors/hasnext_wrong_type.py create mode 100644 tests/error/iter_errors/iter_missing.err create mode 100644 tests/error/iter_errors/iter_missing.py create mode 100644 tests/error/iter_errors/iter_wrong_type.err create mode 100644 tests/error/iter_errors/iter_wrong_type.py create mode 100644 tests/error/iter_errors/next_missing.err create mode 100644 tests/error/iter_errors/next_missing.py create mode 100644 tests/error/iter_errors/next_wrong_type.err create mode 100644 tests/error/iter_errors/next_wrong_type.py create mode 100644 tests/error/linear_errors/for_break.err create mode 100644 tests/error/linear_errors/for_break.py create mode 100644 tests/error/linear_errors/for_return.err create mode 100644 tests/error/linear_errors/for_return.py create mode 100644 tests/error/misc_errors/list_linear.err create mode 100644 tests/error/misc_errors/list_linear.py create mode 100644 tests/error/test_comprehension_errors.py create mode 100644 tests/error/test_iter_errors.py create mode 100644 tests/integration/test_comprehension.py create mode 100644 tests/integration/test_for.py create mode 100644 tests/integration/test_linst.py create mode 100644 tests/integration/test_list.py diff --git a/guppy/module.py b/guppy/module.py index 09fe2c42..9e806d10 100644 --- a/guppy/module.py +++ b/guppy/module.py @@ -17,6 +17,7 @@ from guppy.hugr.hugr import Hugr PyFunc = Callable[..., Any] +PyFuncDefOrDecl = tuple[bool, PyFunc] class GuppyModule: @@ -44,7 +45,7 @@ class GuppyModule: # When `_instance_buffer` is not `None`, then all registered functions will be # buffered in this list. They only get properly registered, once # `_register_buffered_instance_funcs` is called. This way, we can associate - _instance_func_buffer: dict[str, PyFunc | CustomFunction] | None + _instance_func_buffer: dict[str, PyFuncDefOrDecl | CustomFunction] | None def __init__(self, name: str, import_builtins: bool = True): self.name = name @@ -94,7 +95,7 @@ def register_func_def( self._check_not_yet_compiled() func_ast = parse_py_func(f) if self._instance_func_buffer is not None: - self._instance_func_buffer[func_ast.name] = f + self._instance_func_buffer[func_ast.name] = (True, f) else: name = ( qualified_name(instance, func_ast.name) if instance else func_ast.name @@ -102,12 +103,20 @@ def register_func_def( self._check_name_available(name, func_ast) self._func_defs[name] = func_ast - def register_func_decl(self, f: PyFunc) -> None: + def register_func_decl( + self, f: PyFunc, instance: type[GuppyType] | None = None + ) -> None: """Registers a Python function declaration as belonging to this Guppy module.""" self._check_not_yet_compiled() func_ast = parse_py_func(f) - self._check_name_available(func_ast.name, func_ast) - self._func_decls[func_ast.name] = func_ast + if self._instance_func_buffer is not None: + self._instance_func_buffer[func_ast.name] = (False, f) + else: + name = ( + qualified_name(instance, func_ast.name) if instance else func_ast.name + ) + self._check_name_available(name, func_ast) + self._func_decls[name] = func_ast def register_custom_func( self, func: CustomFunction, instance: type[GuppyType] | None = None @@ -142,7 +151,11 @@ def _register_buffered_instance_funcs(self, instance: type[GuppyType]) -> None: if isinstance(f, CustomFunction): self.register_custom_func(f, instance) else: - self.register_func_def(f, instance) + is_def, pyfunc = f + if is_def: + self.register_func_def(pyfunc, instance) + else: + self.register_func_decl(pyfunc, instance) @property def compiled(self) -> bool: diff --git a/ruff.toml b/ruff.toml index 8c9e618a..cfe3fd0a 100644 --- a/ruff.toml +++ b/ruff.toml @@ -73,7 +73,7 @@ ignore = [ [per-file-ignores] "guppy/ast_util.py" = ["B009", "B010"] "guppy/decorator.py" = ["B010"] -"tests/integration/*" = ["F841"] +"tests/integration/*" = ["F841", "C416", "RUF005"] "tests/{hugr,integration}/*" = ["B", "FBT", "SIM", "I"] # [pydocstyle] diff --git a/tests/error/comprehension_errors/__init__.py b/tests/error/comprehension_errors/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/error/comprehension_errors/capture1.err b/tests/error/comprehension_errors/capture1.err new file mode 100644 index 00000000..078f9eaf --- /dev/null +++ b/tests/error/comprehension_errors/capture1.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:13 + +11: @guppy(module) +12: def foo(xs: list[int], q: Qubit) -> linst[Qubit]: +13: return [q for x in xs] + ^ +GuppyTypeError: Variable `q` with linear type `Qubit` would be used multiple times when evaluating this comprehension diff --git a/tests/error/comprehension_errors/capture1.py b/tests/error/comprehension_errors/capture1.py new file mode 100644 index 00000000..9af5781a --- /dev/null +++ b/tests/error/comprehension_errors/capture1.py @@ -0,0 +1,16 @@ +import guppy.prelude.quantum as quantum +from guppy.decorator import guppy +from guppy.module import GuppyModule +from guppy.hugr.tys import Qubit +from guppy.prelude.builtins import linst + +module = GuppyModule("test") +module.load(quantum) + + +@guppy(module) +def foo(xs: list[int], q: Qubit) -> linst[Qubit]: + return [q for x in xs] + + +module.compile() diff --git a/tests/error/comprehension_errors/capture2.err b/tests/error/comprehension_errors/capture2.err new file mode 100644 index 00000000..5e16e12a --- /dev/null +++ b/tests/error/comprehension_errors/capture2.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:18 + +16: @guppy(module) +17: def foo(xs: list[int], q: Qubit) -> list[int]: +18: return [x for x in xs if bar(q)] + ^ +GuppyTypeError: Variable `q` with linear type `Qubit` would be used multiple times when evaluating this comprehension diff --git a/tests/error/comprehension_errors/capture2.py b/tests/error/comprehension_errors/capture2.py new file mode 100644 index 00000000..27970631 --- /dev/null +++ b/tests/error/comprehension_errors/capture2.py @@ -0,0 +1,21 @@ +import guppy.prelude.quantum as quantum +from guppy.decorator import guppy +from guppy.module import GuppyModule +from guppy.hugr.tys import Qubit +from guppy.prelude.builtins import linst + +module = GuppyModule("test") +module.load(quantum) + + +@guppy.declare(module) +def bar(q: Qubit) -> bool: + ... + + +@guppy(module) +def foo(xs: list[int], q: Qubit) -> list[int]: + return [x for x in xs if bar(q)] + + +module.compile() diff --git a/tests/error/comprehension_errors/guarded1.err b/tests/error/comprehension_errors/guarded1.err new file mode 100644 index 00000000..02498bc7 --- /dev/null +++ b/tests/error/comprehension_errors/guarded1.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:13 + +11: @guppy(module) +12: def foo(qs: linst[tuple[bool, Qubit]]) -> linst[Qubit]: +13: return [q for b, q in qs if b] + ^ +GuppyTypeError: Variable `q` with linear type `Qubit` is not used on all control-flow paths of the list comprehension diff --git a/tests/error/comprehension_errors/guarded1.py b/tests/error/comprehension_errors/guarded1.py new file mode 100644 index 00000000..ecb264a9 --- /dev/null +++ b/tests/error/comprehension_errors/guarded1.py @@ -0,0 +1,16 @@ +import guppy.prelude.quantum as quantum +from guppy.decorator import guppy +from guppy.module import GuppyModule +from guppy.hugr.tys import Qubit +from guppy.prelude.builtins import linst + +module = GuppyModule("test") +module.load(quantum) + + +@guppy(module) +def foo(qs: linst[tuple[bool, Qubit]]) -> linst[Qubit]: + return [q for b, q in qs if b] + + +module.compile() diff --git a/tests/error/comprehension_errors/guarded2.err b/tests/error/comprehension_errors/guarded2.err new file mode 100644 index 00000000..673103ac --- /dev/null +++ b/tests/error/comprehension_errors/guarded2.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:18 + +16: @guppy(module) +17: def foo(qs: linst[tuple[bool, Qubit]]) -> list[int]: +18: return [42 for b, q in qs if b if q] + ^ +GuppyTypeError: Variable `q` with linear type `Qubit` is not used on all control-flow paths of the list comprehension diff --git a/tests/error/comprehension_errors/guarded2.py b/tests/error/comprehension_errors/guarded2.py new file mode 100644 index 00000000..5413b8c0 --- /dev/null +++ b/tests/error/comprehension_errors/guarded2.py @@ -0,0 +1,21 @@ +import guppy.prelude.quantum as quantum +from guppy.decorator import guppy +from guppy.module import GuppyModule +from guppy.hugr.tys import Qubit +from guppy.prelude.builtins import linst + +module = GuppyModule("test") +module.load(quantum) + + +@guppy.declare(module) +def bar(q: Qubit) -> bool: + ... + + +@guppy(module) +def foo(qs: linst[tuple[bool, Qubit]]) -> list[int]: + return [42 for b, q in qs if b if q] + + +module.compile() diff --git a/tests/error/comprehension_errors/illegal_short_circuit.err b/tests/error/comprehension_errors/illegal_short_circuit.err new file mode 100644 index 00000000..6c8bfb78 --- /dev/null +++ b/tests/error/comprehension_errors/illegal_short_circuit.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:6 + +4: @guppy +5: def foo(xs: list[int]) -> None: +6: [x for x in xs if x < 5 and x != 6] + ^^^^^^^^^^^^^^^^ +GuppyError: Expression is not supported inside a list comprehension diff --git a/tests/error/comprehension_errors/illegal_short_circuit.py b/tests/error/comprehension_errors/illegal_short_circuit.py new file mode 100644 index 00000000..dfd0b0a2 --- /dev/null +++ b/tests/error/comprehension_errors/illegal_short_circuit.py @@ -0,0 +1,6 @@ +from guppy.decorator import guppy + + +@guppy +def foo(xs: list[int]) -> None: + [x for x in xs if x < 5 and x != 6] diff --git a/tests/error/comprehension_errors/illegal_ternary.err b/tests/error/comprehension_errors/illegal_ternary.err new file mode 100644 index 00000000..b7e9e7a9 --- /dev/null +++ b/tests/error/comprehension_errors/illegal_ternary.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:6 + +4: @guppy +5: def foo(xs: list[int], ys: list[int], b: bool) -> None: +6: [x for x in (xs if b else ys)] + ^^^^^^^^^^^^^^^ +GuppyError: Expression is not supported inside a list comprehension diff --git a/tests/error/comprehension_errors/illegal_ternary.py b/tests/error/comprehension_errors/illegal_ternary.py new file mode 100644 index 00000000..3b587938 --- /dev/null +++ b/tests/error/comprehension_errors/illegal_ternary.py @@ -0,0 +1,6 @@ +from guppy.decorator import guppy + + +@guppy +def foo(xs: list[int], ys: list[int], b: bool) -> None: + [x for x in (xs if b else ys)] diff --git a/tests/error/comprehension_errors/illegal_walrus.err b/tests/error/comprehension_errors/illegal_walrus.err new file mode 100644 index 00000000..2fbb77d7 --- /dev/null +++ b/tests/error/comprehension_errors/illegal_walrus.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:6 + +4: @guppy +5: def foo(xs: list[int]) -> None: +6: [y := x for x in xs] + ^^^^^^ +GuppyError: Expression is not supported inside a list comprehension diff --git a/tests/error/comprehension_errors/illegal_walrus.py b/tests/error/comprehension_errors/illegal_walrus.py new file mode 100644 index 00000000..d95a345b --- /dev/null +++ b/tests/error/comprehension_errors/illegal_walrus.py @@ -0,0 +1,6 @@ +from guppy.decorator import guppy + + +@guppy +def foo(xs: list[int]) -> None: + [y := x for x in xs] diff --git a/tests/error/comprehension_errors/multi_use1.err b/tests/error/comprehension_errors/multi_use1.err new file mode 100644 index 00000000..8fcb48f9 --- /dev/null +++ b/tests/error/comprehension_errors/multi_use1.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:13 + +11: @guppy(module) +12: def foo(qs: linst[Qubit], xs: list[int]) -> linst[Qubit]: +13: return [q for q in qs for x in xs] + ^ +GuppyTypeError: Variable `q` with linear type `Qubit` would be used multiple times when evaluating this comprehension diff --git a/tests/error/comprehension_errors/multi_use1.py b/tests/error/comprehension_errors/multi_use1.py new file mode 100644 index 00000000..d80f6a38 --- /dev/null +++ b/tests/error/comprehension_errors/multi_use1.py @@ -0,0 +1,16 @@ +import guppy.prelude.quantum as quantum +from guppy.decorator import guppy +from guppy.module import GuppyModule +from guppy.hugr.tys import Qubit +from guppy.prelude.builtins import linst + +module = GuppyModule("test") +module.load(quantum) + + +@guppy(module) +def foo(qs: linst[Qubit], xs: list[int]) -> linst[Qubit]: + return [q for q in qs for x in xs] + + +module.compile() diff --git a/tests/error/comprehension_errors/multi_use2.err b/tests/error/comprehension_errors/multi_use2.err new file mode 100644 index 00000000..83e6f324 --- /dev/null +++ b/tests/error/comprehension_errors/multi_use2.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:13 + +11: @guppy(module) +12: def foo(qs: linst[Qubit], xs: list[int]) -> linst[Qubit]: +13: return [q for x in xs for q in qs] + ^^ +GuppyTypeError: Variable `qs` with linear type `linst[Qubit]` would be used multiple times when evaluating this comprehension diff --git a/tests/error/comprehension_errors/multi_use2.py b/tests/error/comprehension_errors/multi_use2.py new file mode 100644 index 00000000..9dc78ba7 --- /dev/null +++ b/tests/error/comprehension_errors/multi_use2.py @@ -0,0 +1,16 @@ +import guppy.prelude.quantum as quantum +from guppy.decorator import guppy +from guppy.module import GuppyModule +from guppy.hugr.tys import Qubit +from guppy.prelude.builtins import linst + +module = GuppyModule("test") +module.load(quantum) + + +@guppy(module) +def foo(qs: linst[Qubit], xs: list[int]) -> linst[Qubit]: + return [q for x in xs for q in qs] + + +module.compile() diff --git a/tests/error/comprehension_errors/multi_use3.err b/tests/error/comprehension_errors/multi_use3.err new file mode 100644 index 00000000..c8b5ab0b --- /dev/null +++ b/tests/error/comprehension_errors/multi_use3.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:18 + +16: @guppy(module) +17: def foo(qs: linst[Qubit], xs: list[int]) -> list[int]: +18: return [x for q in qs for x in xs if bar(q)] + ^ +GuppyTypeError: Variable `q` with linear type `Qubit` would be used multiple times when evaluating this comprehension diff --git a/tests/error/comprehension_errors/multi_use3.py b/tests/error/comprehension_errors/multi_use3.py new file mode 100644 index 00000000..cdf10bc4 --- /dev/null +++ b/tests/error/comprehension_errors/multi_use3.py @@ -0,0 +1,21 @@ +import guppy.prelude.quantum as quantum +from guppy.decorator import guppy +from guppy.module import GuppyModule +from guppy.hugr.tys import Qubit +from guppy.prelude.builtins import linst + +module = GuppyModule("test") +module.load(quantum) + + +@guppy.declare(module) +def bar(q: Qubit) -> bool: + ... + + +@guppy(module) +def foo(qs: linst[Qubit], xs: list[int]) -> list[int]: + return [x for q in qs for x in xs if bar(q)] + + +module.compile() diff --git a/tests/error/comprehension_errors/not_used.err b/tests/error/comprehension_errors/not_used.err new file mode 100644 index 00000000..9c191989 --- /dev/null +++ b/tests/error/comprehension_errors/not_used.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:13 + +11: @guppy(module) +12: def foo(qs: linst[Qubit]) -> list[int]: +13: return [42 for q in qs] + ^ +GuppyTypeError: Variable `q` with linear type `Qubit` is not used diff --git a/tests/error/comprehension_errors/not_used.py b/tests/error/comprehension_errors/not_used.py new file mode 100644 index 00000000..c5854103 --- /dev/null +++ b/tests/error/comprehension_errors/not_used.py @@ -0,0 +1,16 @@ +import guppy.prelude.quantum as quantum +from guppy.decorator import guppy +from guppy.module import GuppyModule +from guppy.hugr.tys import Qubit +from guppy.prelude.builtins import linst + +module = GuppyModule("test") +module.load(quantum) + + +@guppy(module) +def foo(qs: linst[Qubit]) -> list[int]: + return [42 for q in qs] + + +module.compile() diff --git a/tests/error/comprehension_errors/pattern_override1.err b/tests/error/comprehension_errors/pattern_override1.err new file mode 100644 index 00000000..d8fe0a3b --- /dev/null +++ b/tests/error/comprehension_errors/pattern_override1.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:13 + +11: @guppy(module) +12: def foo(qs: linst[tuple[Qubit, Qubit]]) -> linst[Qubit]: +13: return [q for q, q in qs] + ^ +GuppyError: Variable `q` with linear type `Qubit` is not used diff --git a/tests/error/comprehension_errors/pattern_override1.py b/tests/error/comprehension_errors/pattern_override1.py new file mode 100644 index 00000000..0a46bed3 --- /dev/null +++ b/tests/error/comprehension_errors/pattern_override1.py @@ -0,0 +1,16 @@ +import guppy.prelude.quantum as quantum +from guppy.decorator import guppy +from guppy.module import GuppyModule +from guppy.hugr.tys import Qubit +from guppy.prelude.builtins import linst + +module = GuppyModule("test") +module.load(quantum) + + +@guppy(module) +def foo(qs: linst[tuple[Qubit, Qubit]]) -> linst[Qubit]: + return [q for q, q in qs] + + +module.compile() diff --git a/tests/error/comprehension_errors/pattern_override2.err b/tests/error/comprehension_errors/pattern_override2.err new file mode 100644 index 00000000..58c71958 --- /dev/null +++ b/tests/error/comprehension_errors/pattern_override2.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:13 + +11: @guppy(module) +12: def foo(qs: linst[Qubit], xs: list[int]) -> linst[Qubit]: +13: return [q for q in qs for q in xs] + ^ +GuppyError: Variable `q` with linear type `Qubit` is not used diff --git a/tests/error/comprehension_errors/pattern_override2.py b/tests/error/comprehension_errors/pattern_override2.py new file mode 100644 index 00000000..b684e202 --- /dev/null +++ b/tests/error/comprehension_errors/pattern_override2.py @@ -0,0 +1,16 @@ +import guppy.prelude.quantum as quantum +from guppy.decorator import guppy +from guppy.module import GuppyModule +from guppy.hugr.tys import Qubit +from guppy.prelude.builtins import linst + +module = GuppyModule("test") +module.load(quantum) + + +@guppy(module) +def foo(qs: linst[Qubit], xs: list[int]) -> linst[Qubit]: + return [q for q in qs for q in xs] + + +module.compile() diff --git a/tests/error/comprehension_errors/used_twice1.err b/tests/error/comprehension_errors/used_twice1.err new file mode 100644 index 00000000..013258e9 --- /dev/null +++ b/tests/error/comprehension_errors/used_twice1.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:13 + +11: @guppy(module) +12: def foo(qs: linst[Qubit]) -> linst[tuple[Qubit, Qubit]]: +13: return [(q, q) for q in qs] + ^ +GuppyError: Variable `q` with linear type `Qubit` was already used (at 13:13) diff --git a/tests/error/comprehension_errors/used_twice1.py b/tests/error/comprehension_errors/used_twice1.py new file mode 100644 index 00000000..920ae354 --- /dev/null +++ b/tests/error/comprehension_errors/used_twice1.py @@ -0,0 +1,16 @@ +import guppy.prelude.quantum as quantum +from guppy.decorator import guppy +from guppy.module import GuppyModule +from guppy.hugr.tys import Qubit +from guppy.prelude.builtins import linst + +module = GuppyModule("test") +module.load(quantum) + + +@guppy(module) +def foo(qs: linst[Qubit]) -> linst[tuple[Qubit, Qubit]]: + return [(q, q) for q in qs] + + +module.compile() diff --git a/tests/error/comprehension_errors/used_twice2.err b/tests/error/comprehension_errors/used_twice2.err new file mode 100644 index 00000000..07f49da6 --- /dev/null +++ b/tests/error/comprehension_errors/used_twice2.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:18 + +16: @guppy(module) +17: def foo(qs: linst[Qubit]) -> linst[Qubit]: +18: return [q for q in qs if bar(q)] + ^ +GuppyError: Variable `q` with linear type `Qubit` was already used (at 18:33) diff --git a/tests/error/comprehension_errors/used_twice2.py b/tests/error/comprehension_errors/used_twice2.py new file mode 100644 index 00000000..8d2ffbe1 --- /dev/null +++ b/tests/error/comprehension_errors/used_twice2.py @@ -0,0 +1,21 @@ +import guppy.prelude.quantum as quantum +from guppy.decorator import guppy +from guppy.module import GuppyModule +from guppy.hugr.tys import Qubit +from guppy.prelude.builtins import linst + +module = GuppyModule("test") +module.load(quantum) + + +@guppy.declare(module) +def bar(q: Qubit) -> bool: + ... + + +@guppy(module) +def foo(qs: linst[Qubit]) -> linst[Qubit]: + return [q for q in qs if bar(q)] + + +module.compile() diff --git a/tests/error/comprehension_errors/used_twice3.err b/tests/error/comprehension_errors/used_twice3.err new file mode 100644 index 00000000..f8add102 --- /dev/null +++ b/tests/error/comprehension_errors/used_twice3.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:18 + +16: @guppy(module) +17: def foo(qs: linst[Qubit]) -> linst[Qubit]: +18: return [q for q in qs for x in bar(q)] + ^ +GuppyError: Variable `q` with linear type `Qubit` was already used (at 18:39) diff --git a/tests/error/comprehension_errors/used_twice3.py b/tests/error/comprehension_errors/used_twice3.py new file mode 100644 index 00000000..784070e0 --- /dev/null +++ b/tests/error/comprehension_errors/used_twice3.py @@ -0,0 +1,21 @@ +import guppy.prelude.quantum as quantum +from guppy.decorator import guppy +from guppy.module import GuppyModule +from guppy.hugr.tys import Qubit +from guppy.prelude.builtins import linst + +module = GuppyModule("test") +module.load(quantum) + + +@guppy.declare(module) +def bar(q: Qubit) -> list[int]: + ... + + +@guppy(module) +def foo(qs: linst[Qubit]) -> linst[Qubit]: + return [q for q in qs for x in bar(q)] + + +module.compile() diff --git a/tests/error/errors_on_usage/for_new_var.err b/tests/error/errors_on_usage/for_new_var.err new file mode 100644 index 00000000..7aae5692 --- /dev/null +++ b/tests/error/errors_on_usage/for_new_var.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:8 + +6: for _ in xs: +7: y = 5 +8: return y + ^ +GuppyError: Variable `y` is not defined on all control-flow paths. diff --git a/tests/error/errors_on_usage/for_new_var.py b/tests/error/errors_on_usage/for_new_var.py new file mode 100644 index 00000000..f0f35a0f --- /dev/null +++ b/tests/error/errors_on_usage/for_new_var.py @@ -0,0 +1,8 @@ +from tests.error.util import guppy + + +@guppy +def foo(xs: list[int]) -> int: + for _ in xs: + y = 5 + return y diff --git a/tests/error/errors_on_usage/for_target.err b/tests/error/errors_on_usage/for_target.err new file mode 100644 index 00000000..30602c60 --- /dev/null +++ b/tests/error/errors_on_usage/for_target.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:8 + +6: for x in xs: +7: pass +8: return x + ^ +GuppyError: Variable `x` is not defined on all control-flow paths. diff --git a/tests/error/errors_on_usage/for_target.py b/tests/error/errors_on_usage/for_target.py new file mode 100644 index 00000000..3d9c9354 --- /dev/null +++ b/tests/error/errors_on_usage/for_target.py @@ -0,0 +1,8 @@ +from tests.error.util import guppy + + +@guppy +def foo(xs: list[int]) -> int: + for x in xs: + pass + return x diff --git a/tests/error/errors_on_usage/for_target_type_change.err b/tests/error/errors_on_usage/for_target_type_change.err new file mode 100644 index 00000000..82bffc0e --- /dev/null +++ b/tests/error/errors_on_usage/for_target_type_change.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:9 + +7: for x in xs: +8: pass +9: return x + ^ +GuppyError: Variable `x` can refer to different types: `int` (at 6:4) vs `bool` (at 7:8) diff --git a/tests/error/errors_on_usage/for_target_type_change.py b/tests/error/errors_on_usage/for_target_type_change.py new file mode 100644 index 00000000..5c2cfc54 --- /dev/null +++ b/tests/error/errors_on_usage/for_target_type_change.py @@ -0,0 +1,9 @@ +from tests.error.util import guppy + + +@guppy +def foo(xs: list[bool]) -> int: + x = 5 + for x in xs: + pass + return x diff --git a/tests/error/errors_on_usage/for_type_change.err b/tests/error/errors_on_usage/for_type_change.err new file mode 100644 index 00000000..8b4fb557 --- /dev/null +++ b/tests/error/errors_on_usage/for_type_change.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:9 + +7: for x in xs: +8: y = True +9: return y + ^ +GuppyError: Variable `y` can refer to different types: `int` (at 6:4) vs `bool` (at 8:8) diff --git a/tests/error/errors_on_usage/for_type_change.py b/tests/error/errors_on_usage/for_type_change.py new file mode 100644 index 00000000..b22b20ab --- /dev/null +++ b/tests/error/errors_on_usage/for_type_change.py @@ -0,0 +1,9 @@ +from tests.error.util import guppy + + +@guppy +def foo(xs: list[int]) -> int: + y = 5 + for x in xs: + y = True + return y diff --git a/tests/error/iter_errors/__init__.py b/tests/error/iter_errors/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/error/iter_errors/end_missing.err b/tests/error/iter_errors/end_missing.err new file mode 100644 index 00000000..42f0022b --- /dev/null +++ b/tests/error/iter_errors/end_missing.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file /Users/mark.koch/code/guppy/tests/error/iter_errors/end_missing.py:33 + +31: @guppy(module) +32: def test(x: MyType) -> None: +33: for _ in x: + ^ +GuppyTypeError: Expression of type `MyIter` is not an iterator since it does not implement the `__end__` method diff --git a/tests/error/iter_errors/end_missing.py b/tests/error/iter_errors/end_missing.py new file mode 100644 index 00000000..4a8f6e0f --- /dev/null +++ b/tests/error/iter_errors/end_missing.py @@ -0,0 +1,37 @@ +from guppy.decorator import guppy +from guppy.hugr import tys +from guppy.module import GuppyModule + + +module = GuppyModule("test") + + +@guppy.type(module, tys.Tuple(inner=[])) +class MyIter: + """An iterator that is missing the `__end__` method.""" + + @guppy.declare(module) + def __next__(self: "MyIter") -> tuple[None, "MyIter"]: + ... + + @guppy.declare(module) + def __hasnext__(self: "MyIter") -> tuple[bool, "MyIter"]: + ... + + +@guppy.type(module, tys.Tuple(inner=[])) +class MyType: + """Type that produces the iterator above.""" + + @guppy.declare(module) + def __iter__(self: "MyType") -> MyIter: + ... + + +@guppy(module) +def test(x: MyType) -> None: + for _ in x: + pass + + +module.compile() diff --git a/tests/error/iter_errors/end_wrong_type.err b/tests/error/iter_errors/end_wrong_type.err new file mode 100644 index 00000000..635d6f84 --- /dev/null +++ b/tests/error/iter_errors/end_wrong_type.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:37 + +35: @guppy(module) +36: def test(x: MyType) -> None: +37: for _ in x: + ^ +GuppyError: Method `MyIter.__end__` has signature `MyIter -> MyIter`, but expected `MyIter -> None` diff --git a/tests/error/iter_errors/end_wrong_type.py b/tests/error/iter_errors/end_wrong_type.py new file mode 100644 index 00000000..edcb175e --- /dev/null +++ b/tests/error/iter_errors/end_wrong_type.py @@ -0,0 +1,41 @@ +from guppy.decorator import guppy +from guppy.hugr import tys +from guppy.module import GuppyModule + + +module = GuppyModule("test") + + +@guppy.type(module, tys.Tuple(inner=[])) +class MyIter: + """An iterator where the `__end__` method has the wrong signature.""" + + @guppy.declare(module) + def __next__(self: "MyIter") -> tuple[None, "MyIter"]: + ... + + @guppy.declare(module) + def __hasnext__(self: "MyIter") -> tuple[bool, "MyIter"]: + ... + + @guppy.declare(module) + def __end__(self: "MyIter") -> "MyIter": + ... + + +@guppy.type(module, tys.Tuple(inner=[])) +class MyType: + """Type that produces the iterator above.""" + + @guppy.declare(module) + def __iter__(self: "MyType") -> MyIter: + ... + + +@guppy(module) +def test(x: MyType) -> None: + for _ in x: + pass + + +module.compile() diff --git a/tests/error/iter_errors/hasnext_missing.err b/tests/error/iter_errors/hasnext_missing.err new file mode 100644 index 00000000..949b776f --- /dev/null +++ b/tests/error/iter_errors/hasnext_missing.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:33 + +31: @guppy(module) +32: def test(x: MyType) -> None: +33: for _ in x: + ^ +GuppyTypeError: Expression of type `MyIter` is not an iterator since it does not implement the `__hasnext__` method diff --git a/tests/error/iter_errors/hasnext_missing.py b/tests/error/iter_errors/hasnext_missing.py new file mode 100644 index 00000000..a0046e4e --- /dev/null +++ b/tests/error/iter_errors/hasnext_missing.py @@ -0,0 +1,37 @@ +from guppy.decorator import guppy +from guppy.hugr import tys +from guppy.module import GuppyModule + + +module = GuppyModule("test") + + +@guppy.type(module, tys.Tuple(inner=[])) +class MyIter: + """An iterator that is missing the `__hasnext__` method.""" + + @guppy.declare(module) + def __next__(self: "MyIter") -> tuple[None, "MyIter"]: + ... + + @guppy.declare(module) + def __end__(self: "MyIter") -> None: + ... + + +@guppy.type(module, tys.Tuple(inner=[])) +class MyType: + """Type that produces the iterator above.""" + + @guppy.declare(module) + def __iter__(self: "MyType") -> MyIter: + ... + + +@guppy(module) +def test(x: MyType) -> None: + for _ in x: + pass + + +module.compile() diff --git a/tests/error/iter_errors/hasnext_wrong_type.err b/tests/error/iter_errors/hasnext_wrong_type.err new file mode 100644 index 00000000..234e217a --- /dev/null +++ b/tests/error/iter_errors/hasnext_wrong_type.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:37 + +35: @guppy(module) +36: def test(x: MyType) -> None: +37: for _ in x: + ^ +GuppyError: Method `MyIter.__hasnext__` has signature `MyIter -> bool`, but expected `MyIter -> (bool, MyIter)` diff --git a/tests/error/iter_errors/hasnext_wrong_type.py b/tests/error/iter_errors/hasnext_wrong_type.py new file mode 100644 index 00000000..9f4e6929 --- /dev/null +++ b/tests/error/iter_errors/hasnext_wrong_type.py @@ -0,0 +1,41 @@ +from guppy.decorator import guppy +from guppy.hugr import tys +from guppy.module import GuppyModule + + +module = GuppyModule("test") + + +@guppy.type(module, tys.Tuple(inner=[])) +class MyIter: + """An iterator where the `__hasnext__` method has the wrong signature.""" + + @guppy.declare(module) + def __next__(self: "MyIter") -> tuple[None, "MyIter"]: + ... + + @guppy.declare(module) + def __hasnext__(self: "MyIter") -> bool: + ... + + @guppy.declare(module) + def __end__(self: "MyIter") -> None: + ... + + +@guppy.type(module, tys.Tuple(inner=[])) +class MyType: + """Type that produces the iterator above.""" + + @guppy.declare(module) + def __iter__(self: "MyType") -> MyIter: + ... + + +@guppy(module) +def test(x: MyType) -> None: + for _ in x: + pass + + +module.compile() diff --git a/tests/error/iter_errors/iter_missing.err b/tests/error/iter_errors/iter_missing.err new file mode 100644 index 00000000..3e3c571d --- /dev/null +++ b/tests/error/iter_errors/iter_missing.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:16 + +14: @guppy(module) +15: def test(x: MyType) -> None: +16: for _ in x: + ^ +GuppyTypeError: Expression of type `MyType` is not iterable diff --git a/tests/error/iter_errors/iter_missing.py b/tests/error/iter_errors/iter_missing.py new file mode 100644 index 00000000..dab0c6c6 --- /dev/null +++ b/tests/error/iter_errors/iter_missing.py @@ -0,0 +1,20 @@ +from guppy.decorator import guppy +from guppy.hugr import tys +from guppy.module import GuppyModule + + +module = GuppyModule("test") + + +@guppy.type(module, tys.Tuple(inner=[])) +class MyType: + """A non-iterable type.""" + + +@guppy(module) +def test(x: MyType) -> None: + for _ in x: + pass + + +module.compile() diff --git a/tests/error/iter_errors/iter_wrong_type.err b/tests/error/iter_errors/iter_wrong_type.err new file mode 100644 index 00000000..cc10baab --- /dev/null +++ b/tests/error/iter_errors/iter_wrong_type.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:20 + +18: @guppy(module) +19: def test(x: MyType) -> None: +20: for _ in x: + ^ +GuppyError: Method `MyType.__iter__` has signature `(MyType, int) -> MyType`, but expected `MyType -> ?Iter` diff --git a/tests/error/iter_errors/iter_wrong_type.py b/tests/error/iter_errors/iter_wrong_type.py new file mode 100644 index 00000000..4fa63c37 --- /dev/null +++ b/tests/error/iter_errors/iter_wrong_type.py @@ -0,0 +1,24 @@ +from guppy.decorator import guppy +from guppy.hugr import tys +from guppy.module import GuppyModule + + +module = GuppyModule("test") + + +@guppy.type(module, tys.Tuple(inner=[])) +class MyType: + """A type where the `__iter__` method has the wrong signature.""" + + @guppy.declare(module) + def __iter__(self: "MyType", x: int) -> "MyType": + ... + + +@guppy(module) +def test(x: MyType) -> None: + for _ in x: + pass + + +module.compile() diff --git a/tests/error/iter_errors/next_missing.err b/tests/error/iter_errors/next_missing.err new file mode 100644 index 00000000..8ba51c01 --- /dev/null +++ b/tests/error/iter_errors/next_missing.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:33 + +31: @guppy(module) +32: def test(x: MyType) -> None: +33: for _ in x: + ^ +GuppyTypeError: Expression of type `MyIter` is not an iterator since it does not implement the `__next__` method diff --git a/tests/error/iter_errors/next_missing.py b/tests/error/iter_errors/next_missing.py new file mode 100644 index 00000000..8a1e509b --- /dev/null +++ b/tests/error/iter_errors/next_missing.py @@ -0,0 +1,37 @@ +from guppy.decorator import guppy +from guppy.hugr import tys +from guppy.module import GuppyModule + + +module = GuppyModule("test") + + +@guppy.type(module, tys.Tuple(inner=[])) +class MyIter: + """An iterator that is missing the `__next__` method.""" + + @guppy.declare(module) + def __hasnext__(self: "MyIter") -> tuple[bool, "MyIter"]: + ... + + @guppy.declare(module) + def __end__(self: "MyIter") -> None: + ... + + +@guppy.type(module, tys.Tuple(inner=[])) +class MyType: + """Type that produces the iterator above.""" + + @guppy.declare(module) + def __iter__(self: "MyType") -> MyIter: + ... + + +@guppy(module) +def test(x: MyType) -> None: + for _ in x: + pass + + +module.compile() diff --git a/tests/error/iter_errors/next_wrong_type.err b/tests/error/iter_errors/next_wrong_type.err new file mode 100644 index 00000000..b9bac527 --- /dev/null +++ b/tests/error/iter_errors/next_wrong_type.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:37 + +35: @guppy(module) +36: def test(x: MyType) -> None: +37: for _ in x: + ^ +GuppyError: Method `MyIter.__next__` has signature `MyIter -> (MyIter, float)`, but expected `MyIter -> (?T, MyIter)` diff --git a/tests/error/iter_errors/next_wrong_type.py b/tests/error/iter_errors/next_wrong_type.py new file mode 100644 index 00000000..4b454616 --- /dev/null +++ b/tests/error/iter_errors/next_wrong_type.py @@ -0,0 +1,41 @@ +from guppy.decorator import guppy +from guppy.hugr import tys +from guppy.module import GuppyModule + + +module = GuppyModule("test") + + +@guppy.type(module, tys.Tuple(inner=[])) +class MyIter: + """An iterator where the `__next__` method has the wrong signature.""" + + @guppy.declare(module) + def __next__(self: "MyIter") -> tuple["MyIter", float]: + ... + + @guppy.declare(module) + def __hasnext__(self: "MyIter") -> tuple[bool, "MyIter"]: + ... + + @guppy.declare(module) + def __end__(self: "MyIter") -> None: + ... + + +@guppy.type(module, tys.Tuple(inner=[])) +class MyType: + """Type that produces the iterator above.""" + + @guppy.declare(module) + def __iter__(self: "MyType") -> MyIter: + ... + + +@guppy(module) +def test(x: MyType) -> None: + for _ in x: + pass + + +module.compile() diff --git a/tests/error/linear_errors/for_break.err b/tests/error/linear_errors/for_break.err new file mode 100644 index 00000000..2d4c195f --- /dev/null +++ b/tests/error/linear_errors/for_break.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:17 + +15: rs += [q] +16: if b: +17: break + ^^^^^ +GuppyTypeError: Loop over iterator with linear type `linst[(Qubit, bool)]` cannot be terminated (cannot ensure that all values have been used) diff --git a/tests/error/linear_errors/for_break.py b/tests/error/linear_errors/for_break.py new file mode 100644 index 00000000..3f5f6cff --- /dev/null +++ b/tests/error/linear_errors/for_break.py @@ -0,0 +1,21 @@ +import guppy.prelude.quantum as quantum +from guppy.decorator import guppy +from guppy.module import GuppyModule +from guppy.hugr.tys import Qubit +from guppy.prelude.builtins import linst + +module = GuppyModule("test") +module.load(quantum) + + +@guppy(module) +def foo(qs: linst[tuple[Qubit, bool]]) -> linst[Qubit]: + rs: linst[Qubit] = [] + for q, b in qs: + rs += [q] + if b: + break + return rs + + +module.compile() diff --git a/tests/error/linear_errors/for_return.err b/tests/error/linear_errors/for_return.err new file mode 100644 index 00000000..aae47b8a --- /dev/null +++ b/tests/error/linear_errors/for_return.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:17 + +15: rs += [q] +16: if b: +17: return [] + ^^^^^^^^^ +GuppyTypeError: Loop over iterator with linear type `linst[(Qubit, bool)]` cannot be terminated (cannot ensure that all values have been used) diff --git a/tests/error/linear_errors/for_return.py b/tests/error/linear_errors/for_return.py new file mode 100644 index 00000000..05982f9c --- /dev/null +++ b/tests/error/linear_errors/for_return.py @@ -0,0 +1,21 @@ +import guppy.prelude.quantum as quantum +from guppy.decorator import guppy +from guppy.module import GuppyModule +from guppy.hugr.tys import Qubit +from guppy.prelude.builtins import linst + +module = GuppyModule("test") +module.load(quantum) + + +@guppy(module) +def foo(qs: linst[tuple[Qubit, bool]]) -> linst[Qubit]: + rs: linst[Qubit] = [] + for q, b in qs: + rs += [q] + if b: + return [] + return rs + + +module.compile() diff --git a/tests/error/misc_errors/list_linear.err b/tests/error/misc_errors/list_linear.err new file mode 100644 index 00000000..39da0457 --- /dev/null +++ b/tests/error/misc_errors/list_linear.err @@ -0,0 +1,6 @@ +Guppy compilation failed. Error in file $FILE:13 + +11: @guppy(module) +12: def foo() -> list[Qubit]: + ^^^^^^^^^^^ +GuppyError: Type `list` cannot store linear data, use `linst` instead diff --git a/tests/error/misc_errors/list_linear.py b/tests/error/misc_errors/list_linear.py new file mode 100644 index 00000000..5b377e53 --- /dev/null +++ b/tests/error/misc_errors/list_linear.py @@ -0,0 +1,17 @@ +from guppy.decorator import guppy +from guppy.module import GuppyModule +from guppy.prelude.quantum import Qubit + +import guppy.prelude.quantum as quantum + + +module = GuppyModule("test") +module.load(quantum) + + +@guppy(module) +def foo() -> list[Qubit]: + return [] + + +module.compile() diff --git a/tests/error/test_comprehension_errors.py b/tests/error/test_comprehension_errors.py new file mode 100644 index 00000000..9e40f5fa --- /dev/null +++ b/tests/error/test_comprehension_errors.py @@ -0,0 +1,20 @@ +import pathlib +import pytest + +from tests.error.util import run_error_test + +path = pathlib.Path(__file__).parent.resolve() / "comprehension_errors" +files = [ + x + for x in path.iterdir() + if x.is_file() + if x.suffix == ".py" and x.name != "__init__.py" +] + +# Turn paths into strings, otherwise pytest doesn't display the names +files = [str(f) for f in files] + + +@pytest.mark.parametrize("file", files) +def test_comprehension_errors(file, capsys): + run_error_test(file, capsys) diff --git a/tests/error/test_iter_errors.py b/tests/error/test_iter_errors.py new file mode 100644 index 00000000..60c8c2a6 --- /dev/null +++ b/tests/error/test_iter_errors.py @@ -0,0 +1,20 @@ +import pathlib +import pytest + +from tests.error.util import run_error_test + +path = pathlib.Path(__file__).parent.resolve() / "iter_errors" +files = [ + x + for x in path.iterdir() + if x.is_file() + if x.suffix == ".py" and x.name != "__init__.py" +] + +# Turn paths into strings, otherwise pytest doesn't display the names +files = [str(f) for f in files] + + +@pytest.mark.parametrize("file", files) +def test_iter_errors(file, capsys): + run_error_test(file, capsys) diff --git a/tests/integration/test_comprehension.py b/tests/integration/test_comprehension.py new file mode 100644 index 00000000..9587f775 --- /dev/null +++ b/tests/integration/test_comprehension.py @@ -0,0 +1,260 @@ +from guppy.decorator import guppy +from guppy.hugr import tys +from guppy.module import GuppyModule +from guppy.prelude.builtins import linst +from guppy.prelude.quantum import Qubit, h, cx + +import guppy.prelude.quantum as quantum + + +def test_basic(validate): + @guppy + def test(xs: list[float]) -> list[int]: + return [int(x) for x in xs] + + validate(test) + + +def test_basic_linear(validate): + module = GuppyModule("test") + module.load(quantum) + + @guppy(module) + def test(qs: linst[Qubit]) -> linst[Qubit]: + return [h(q) for q in qs] + + validate(module.compile()) + + +def test_guarded(validate): + @guppy + def test(xs: list[int]) -> list[int]: + return [2 * x for x in xs if x > 0 if x < 20] + + validate(test) + + +def test_multiple(validate): + @guppy + def test(xs: list[int], ys: list[int]) -> list[int]: + return [x + y for x in xs for y in ys if x + y > 42] + + validate(test) + + +def test_tuple_pat(validate): + @guppy + def test(xs: list[tuple[int, int, float]]) -> list[float]: + return [x + y * z for x, y, z in xs if x - y > z] + + validate(test) + + +def test_tuple_pat_linear(validate): + module = GuppyModule("test") + module.load(quantum) + + @guppy(module) + def test(qs: linst[tuple[int, Qubit, Qubit]]) -> linst[tuple[Qubit, Qubit]]: + return [cx(q1, q2) for _, q1, q2 in qs] + + validate(module.compile()) + + +def test_tuple_return(validate): + @guppy + def test(xs: list[int], ys: list[float]) -> list[tuple[int, float]]: + return [(x, y) for x in xs for y in ys] + + validate(test) + + +def test_dependent(validate): + module = GuppyModule("test") + + @guppy.declare(module) + def process(x: float) -> list[int]: + ... + + @guppy(module) + def test(xs: list[float]) -> list[float]: + return [x * y for x in xs if x > 0 for y in process(x) if y > x] + + validate(module.compile()) + + +def test_capture(validate): + @guppy + def test(xs: list[int], y: int) -> list[int]: + return [x + y for x in xs if x > y] + + validate(test) + + +def test_scope(validate): + @guppy + def test(xs: list[None]) -> float: + x = 42.0 + [x for x in xs] + return x + + validate(test) + + +def test_nested_left(validate): + @guppy + def test(xs: list[int], ys: list[float]) -> list[list[float]]: + return [[x + y for y in ys] for x in xs] + + validate(test) + + +def test_nested_right(validate): + @guppy + def test(xs: list[int]) -> list[int]: + return [-x for x in [2 * x for x in xs]] + + validate(test) + + +def test_nested_linear(validate): + module = GuppyModule("test") + module.load(quantum) + + @guppy(module) + def test(qs: linst[Qubit]) -> linst[Qubit]: + return [h(q) for q in [h(q) for q in qs]] + + validate(module.compile()) + + +def test_classical_linst_comp(validate): + module = GuppyModule("test") + module.load(quantum) + + @guppy(module) + def test(xs: list[int]) -> linst[int]: + return [x for x in xs] + + validate(module.compile()) + + +def test_linear_discard(validate): + module = GuppyModule("test") + module.load(quantum) + + @guppy.declare(module) + def discard(q: Qubit) -> None: + ... + + @guppy(module) + def test(qs: linst[Qubit]) -> list[None]: + return [discard(q) for q in qs] + + validate(module.compile()) + + +def test_linear_consume_in_guard(validate): + module = GuppyModule("test") + module.load(quantum) + + @guppy.declare(module) + def cond(q: Qubit) -> bool: + ... + + @guppy(module) + def test(qs: linst[tuple[int, Qubit]]) -> list[int]: + return [x for x, q in qs if cond(q)] + + validate(module.compile()) + + +def test_linear_consume_in_iter(validate): + module = GuppyModule("test") + module.load(quantum) + + @guppy.declare(module) + def make_list(q: Qubit) -> list[int]: + ... + + @guppy(module) + def test(qs: linst[Qubit]) -> list[int]: + return [x for q in qs for x in make_list(q)] + + validate(module.compile()) + + +def test_linear_next_nonlinear_iter(validate): + module = GuppyModule("test") + module.load(quantum) + + @guppy.type(module, tys.Tuple(inner=[])) + class MyIter: + """An iterator that yields linear values but is not linear itself.""" + + @guppy.declare(module) + def __hasnext__(self: "MyIter") -> tuple[bool, "MyIter"]: + ... + + @guppy.declare(module) + def __next__(self: "MyIter") -> tuple[Qubit, "MyIter"]: + ... + + @guppy.declare(module) + def __end__(self: "MyIter") -> None: + ... + + @guppy.type(module, tys.Tuple(inner=[])) + class MyType: + """Type that produces the iterator above.""" + + @guppy.declare(module) + def __iter__(self: "MyType") -> MyIter: + ... + + @guppy(module) + def test(mt: MyType, xs: list[int]) -> linst[tuple[int, Qubit]]: + # We can use `mt` in an inner loop since it's not linear + return [(x, q) for x in xs for q in mt] + + validate(module.compile()) + + +def test_nonlinear_next_linear_iter(validate): + module = GuppyModule("test") + module.load(quantum) + + @guppy.type( + module, + tys.Opaque(extension="prelude", id="qubit", args=[], bound=tys.TypeBound.Any), + linear=True, + ) + class MyIter: + """A linear iterator that yields non-linear values.""" + + @guppy.declare(module) + def __hasnext__(self: "MyIter") -> tuple[bool, "MyIter"]: + ... + + @guppy.declare(module) + def __next__(self: "MyIter") -> tuple[int, "MyIter"]: + ... + + @guppy.declare(module) + def __end__(self: "MyIter") -> None: + ... + + @guppy.type(module, tys.Tuple(inner=[])) + class MyType: + """Type that produces the iterator above.""" + + @guppy.declare(module) + def __iter__(self: "MyType") -> MyIter: + ... + + @guppy(module) + def test(mt: MyType, xs: list[int]) -> linst[tuple[int, int]]: + # We can use `mt` in an outer loop since the target `x` is not linear + return [(x, x + y) for x in mt for y in xs] + + validate(module.compile()) diff --git a/tests/integration/test_for.py b/tests/integration/test_for.py new file mode 100644 index 00000000..585d75e4 --- /dev/null +++ b/tests/integration/test_for.py @@ -0,0 +1,131 @@ +from guppy.decorator import guppy + + +def test_basic(validate): + @guppy + def foo(xs: list[int]) -> int: + for x in xs: + pass + return 0 + + validate(foo) + + +def test_counting_loop(validate): + @guppy + def foo(xs: list[int]) -> int: + s = 0 + for x in xs: + s += x + return s + + validate(foo) + + +def test_multi_targets(validate): + @guppy + def foo(xs: list[tuple[int, float]]) -> float: + s = 0.0 + for x, y in xs: + s += x * y + return s + + validate(foo) + + +def test_multi_targets_same(validate): + @guppy + def foo(xs: list[tuple[int, float]]) -> float: + s = 1.0 + for x, x in xs: + s *= x + return s + + validate(foo) + + +def test_reassign_iter(validate): + @guppy + def foo(xs: list[int]) -> int: + s = 1 + for x in xs: + xs = False + s += x + return s + + validate(foo) + + +def test_break(validate): + @guppy + def foo(xs: list[int]) -> int: + i = 1 + for x in xs: + if x >= 42: + break + i *= x + return i + + validate(foo) + + +def test_continue(validate): + @guppy + def foo(xs: list[int]) -> int: + i = len(xs) + for x in xs: + if x >= 42: + continue + i -= 1 + return i + + validate(foo) + + validate(foo) + + +def test_return_in_loop(validate): + @guppy + def foo(xs: list[int]) -> int: + y = 42 + for x in xs: + if x >= 1337: + return x * y + y = y + x + return y + + validate(foo) + + +def test_nested_loop(validate): + @guppy + def foo(xs: list[int], ys: list[int]) -> int: + p = 0 + for x in xs: + s = 0 + for y in ys: + s += y * x + p += s - x + return p + + validate(foo) + + +def test_nested_loop_break_continue(validate): + @guppy + def foo(xs: list[int], ys: list[int]) -> int: + p = 0 + for x in xs: + s = 0 + for y in ys: + if x % 2 == 0: + continue + s += x + if s > y: + s = y + else: + break + p += s * x + return p + + validate(foo) diff --git a/tests/integration/test_linear.py b/tests/integration/test_linear.py index 35091b61..b7597416 100644 --- a/tests/integration/test_linear.py +++ b/tests/integration/test_linear.py @@ -1,5 +1,7 @@ from guppy.decorator import guppy +from guppy.hugr import tys from guppy.module import GuppyModule +from guppy.prelude.builtins import linst from guppy.prelude.quantum import Qubit import guppy.prelude.quantum as quantum @@ -207,6 +209,101 @@ def foo(i: bool) -> bool: return b +def test_for(validate): + module = GuppyModule("test") + module.load(quantum) + + @guppy(module) + def test(qs: linst[tuple[Qubit, Qubit]]) -> linst[Qubit]: + rs: linst[Qubit] = [] + for q1, q2 in qs: + q1, q2 = cx(q1, q2) + rs += [q1, q2] + return rs + + validate(module.compile()) + + +def test_for_measure(validate): + module = GuppyModule("test") + module.load(quantum) + + @guppy.declare(module) + def measure(q: Qubit) -> bool: + ... + + @guppy(module) + def test(qs: linst[Qubit]) -> bool: + parity = False + for q in qs: + parity |= measure(q) + return parity + + validate(module.compile()) + + +def test_for_continue(validate): + module = GuppyModule("test") + module.load(quantum) + + @guppy.declare(module) + def measure(q: Qubit) -> bool: + ... + + @guppy(module) + def test(qs: linst[Qubit]) -> int: + x = 0 + for q in qs: + if measure(q): + continue + x += 1 + return x + + validate(module.compile()) + + +def test_for_nonlinear_break(validate): + module = GuppyModule("test") + module.load(quantum) + + @guppy.type(module, tys.Tuple(inner=[])) + class MyIter: + """An iterator that yields linear values but is not linear itself.""" + + @guppy.declare(module) + def __hasnext__(self: "MyIter") -> tuple[bool, "MyIter"]: + ... + + @guppy.declare(module) + def __next__(self: "MyIter") -> tuple[Qubit, "MyIter"]: + ... + + @guppy.declare(module) + def __end__(self: "MyIter") -> None: + ... + + @guppy.type(module, tys.Tuple(inner=[])) + class MyType: + """Type that produces the iterator above.""" + + @guppy.declare(module) + def __iter__(self: "MyType") -> MyIter: + ... + + @guppy.declare(module) + def measure(q: Qubit) -> bool: + ... + + @guppy(module) + def test(mt: MyType, xs: list[int]) -> None: + # We can break, since `mt` itself is not linear + for q in mt: + if measure(q): + break + + validate(module.compile()) + + def test_rus(validate): module = GuppyModule("test") module.load(quantum) diff --git a/tests/integration/test_linst.py b/tests/integration/test_linst.py new file mode 100644 index 00000000..4086d234 --- /dev/null +++ b/tests/integration/test_linst.py @@ -0,0 +1,66 @@ +from guppy.decorator import guppy +from guppy.module import GuppyModule +from guppy.prelude.builtins import linst +from guppy.prelude.quantum import Qubit, h + +import guppy.prelude.quantum as quantum + + +def test_types(validate): + module = GuppyModule("test") + module.load(quantum) + + @guppy(module) + def test( + xs: linst[Qubit], ys: linst[tuple[int, Qubit]] + ) -> tuple[linst[Qubit], linst[tuple[int, Qubit]]]: + return xs, ys + + validate(module.compile()) + + +def test_len(validate): + module = GuppyModule("test") + module.load(quantum) + + @guppy(module) + def test(xs: linst[Qubit]) -> tuple[int, linst[Qubit]]: + return len(xs) + + validate(module.compile()) + + +def test_literal(validate): + module = GuppyModule("test") + module.load(quantum) + + @guppy(module) + def test(q1: Qubit, q2: Qubit) -> linst[Qubit]: + return [q1, h(q2)] + + validate(module.compile()) + + +def test_arith(validate): + module = GuppyModule("test") + module.load(quantum) + + @guppy(module) + def test(xs: linst[Qubit], ys: linst[Qubit], q: Qubit) -> linst[Qubit]: + xs += [q] + return xs + ys + + validate(module.compile()) + + +def test_copyable(validate): + module = GuppyModule("test") + module.load(quantum) + + @guppy(module) + def test() -> linst[int]: + xs: linst[int] = [1, 2, 3] + ys: linst[int] = [] + return xs + xs + + validate(module.compile()) diff --git a/tests/integration/test_list.py b/tests/integration/test_list.py new file mode 100644 index 00000000..e59db276 --- /dev/null +++ b/tests/integration/test_list.py @@ -0,0 +1,45 @@ +from guppy.decorator import guppy + + +def test_types(validate): + @guppy + def test( + xs: list[int], ys: list[tuple[int, float]] + ) -> tuple[list[int], list[tuple[int, float]]]: + return xs, ys + + validate(test) + + +def test_len(validate): + @guppy + def test(xs: list[int]) -> int: + return len(xs) + + validate(test) + + +def test_literal(validate): + @guppy + def test(x: float) -> list[float]: + return [1.0, 2.0, 3.0, 4.0 + x] + + validate(test) + + +def test_arith(validate): + @guppy + def test(xs: list[int]) -> list[int]: + xs += xs + [42] + xs = 3 * xs + return xs * 4 + + validate(test) + + +def test_subscript(validate): + @guppy + def test(xs: list[float], i: int) -> float: + return xs[2 * i] + + validate(test) diff --git a/tests/integration/test_poly.py b/tests/integration/test_poly.py index 2446e308..ee8f112c 100644 --- a/tests/integration/test_poly.py +++ b/tests/integration/test_poly.py @@ -134,6 +134,22 @@ def main() -> None: validate(module.compile()) +def test_infer_list(validate): + module = GuppyModule("test") + T = guppy.type_var(module, "T") + + @guppy.declare(module) + def foo() -> T: + ... + + @guppy(module) + def main() -> None: + xs: list[int] = [foo()] + ys = [1.0, foo()] + + validate(module.compile()) + + def test_infer_nested(validate): module = GuppyModule("test") T = guppy.type_var(module, "T") diff --git a/validator/src/lib.rs b/validator/src/lib.rs index 4d31a136..fb6d463f 100644 --- a/validator/src/lib.rs +++ b/validator/src/lib.rs @@ -1,6 +1,8 @@ use hugr::extension::{ExtensionRegistry, PRELUDE}; use hugr::std_extensions::arithmetic::{float_ops, float_types, int_ops, int_types}; +use hugr::std_extensions::collections; use hugr::std_extensions::logic; +use hugr::HugrView; use lazy_static::lazy_static; use pyo3::prelude::*; @@ -11,7 +13,8 @@ lazy_static! { int_types::extension(), int_ops::EXTENSION.to_owned(), float_types::extension(), - float_ops::extension() + float_ops::extension(), + collections::EXTENSION.to_owned(), ]) .unwrap(); } From f51f0bfec54c5b1edd594def73e36c3877bc85ac Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Mon, 8 Jan 2024 12:20:36 +0100 Subject: [PATCH 06/31] Improve instance method helper function --- guppy/checker/expr_checker.py | 62 +++++++++++++++++++---------------- 1 file changed, 33 insertions(+), 29 deletions(-) diff --git a/guppy/checker/expr_checker.py b/guppy/checker/expr_checker.py index 2b1d9f27..392824e8 100644 --- a/guppy/checker/expr_checker.py +++ b/guppy/checker/expr_checker.py @@ -22,7 +22,7 @@ import ast from contextlib import suppress -from typing import Any, NoReturn, cast +from typing import Any, NoReturn from guppy.ast_util import ( AstNode, @@ -363,11 +363,18 @@ def _synthesize_instance_func( args: list[ast.expr], func_name: str, err: str, - exp_ty: FunctionType | None = None, - var: FreeTypeVar | None = None, + exp_sig: FunctionType | None = None, give_reason: bool = False, ) -> tuple[ast.expr, GuppyType]: - """Helper method for expressions that are implemented via instance methods.""" + """Helper method for expressions that are implemented via instance methods. + + Raises a `GuppyTypeError` if the given instance method is not defined. The error + message can be customised by passing an `err` string and an optional error + reason can be printed. + + Optionally, the signature of the instance function can also be checked against a + given expected signature. + """ node, ty = self.synthesize(node) func = self.ctx.globals.get_instance_func(ty, func_name) if func is None: @@ -376,15 +383,12 @@ def _synthesize_instance_func( f"Expression of type `{ty}` is {err}{reason if give_reason else ''}", node, ) - if exp_ty: - assert var is not None - exp_ty = cast(FunctionType, exp_ty.substitute({var: ty})) - if unify(exp_ty, func.ty.unquantified()[0], {}) is None: - raise GuppyError( - f"Method `{ty.name}.{func_name}` has signature `{func.ty}`, but " - f"expected `{exp_ty}`", - node, - ) + if exp_sig and unify(exp_sig, func.ty.unquantified()[0], {}) is None: + raise GuppyError( + f"Method `{ty.name}.{func_name}` has signature `{func.ty}`, but " + f"expected `{exp_sig}`", + node, + ) return func.synthesize_call([node, *args], node, self.ctx) def visit_BinOp(self, node: ast.BinOp) -> tuple[ast.expr, GuppyType]: @@ -400,13 +404,13 @@ def visit_Compare(self, node: ast.Compare) -> tuple[ast.expr, GuppyType]: return self._synthesize_binary(left_expr, right_expr, op, node) def visit_Subscript(self, node: ast.Subscript) -> tuple[ast.expr, GuppyType]: - var = FreeTypeVar.new("T", False) - exp_ty = FunctionType( - [var, FreeTypeVar.new("Key", False)], + node.value, ty = self.synthesize(node.value) + exp_sig = FunctionType( + [ty, FreeTypeVar.new("Key", False)], FreeTypeVar.new("Val", False), ) return self._synthesize_instance_func( - node.value, [node.slice], "__getitem__", "not subscriptable", exp_ty, var + node.value, [node.slice], "__getitem__", "not subscriptable", exp_sig ) def visit_Call(self, node: ast.Call) -> tuple[ast.expr, GuppyType]: @@ -431,10 +435,10 @@ def visit_Call(self, node: ast.Call) -> tuple[ast.expr, GuppyType]: raise GuppyTypeError(f"Expected function type, got `{ty}`", node.func) def visit_MakeIter(self, node: MakeIter) -> tuple[ast.expr, GuppyType]: - var = FreeTypeVar.new("T", False) - exp_ty = FunctionType([var], FreeTypeVar.new("Iter", False)) + node.value, ty = self.synthesize(node.value) + exp_sig = FunctionType([ty], FreeTypeVar.new("Iter", False)) expr, ty = self._synthesize_instance_func( - node.value, [], "__iter__", "not iterable", exp_ty, var + node.value, [], "__iter__", "not iterable", exp_sig ) # If the iterator was created by a `for` loop, we can add some extra checks to @@ -453,24 +457,24 @@ def visit_MakeIter(self, node: MakeIter) -> tuple[ast.expr, GuppyType]: return expr, ty def visit_IterHasNext(self, node: IterHasNext) -> tuple[ast.expr, GuppyType]: - var = FreeTypeVar.new("Iter", False) - exp_ty = FunctionType([var], TupleType([BoolType(), var])) + node.value, ty = self.synthesize(node.value) + exp_sig = FunctionType([ty], TupleType([BoolType(), ty])) return self._synthesize_instance_func( - node.value, [], "__hasnext__", "not an iterator", exp_ty, var, True + node.value, [], "__hasnext__", "not an iterator", exp_sig, True ) def visit_IterNext(self, node: IterNext) -> tuple[ast.expr, GuppyType]: - var = FreeTypeVar.new("Iter", False) - exp_ty = FunctionType([var], TupleType([FreeTypeVar.new("T", True), var])) + node.value, ty = self.synthesize(node.value) + exp_sig = FunctionType([ty], TupleType([FreeTypeVar.new("T", True), ty])) return self._synthesize_instance_func( - node.value, [], "__next__", "not an iterator", exp_ty, var, True + node.value, [], "__next__", "not an iterator", exp_sig, True ) def visit_IterEnd(self, node: IterEnd) -> tuple[ast.expr, GuppyType]: - var = FreeTypeVar.new("Iter", False) - exp_ty = FunctionType([var], NoneType()) + node.value, ty = self.synthesize(node.value) + exp_sig = FunctionType([ty], NoneType()) return self._synthesize_instance_func( - node.value, [], "__end__", "not an iterator", exp_ty, var, True + node.value, [], "__end__", "not an iterator", exp_sig, True ) def visit_ListComp(self, node: ast.ListComp) -> tuple[ast.expr, GuppyType]: From 01319cc0f0c510411bc5d4f542faf594911f1cdb Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Mon, 8 Jan 2024 12:35:37 +0100 Subject: [PATCH 07/31] Don't override usage stats --- guppy/cfg/bb.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/guppy/cfg/bb.py b/guppy/cfg/bb.py index 84bb6cb1..91ed5455 100644 --- a/guppy/cfg/bb.py +++ b/guppy/cfg/bb.py @@ -135,7 +135,7 @@ def visit_DesugaredListComp(self, node: DesugaredListComp) -> None: inner_visitor.visit(cond) inner_visitor.visit(node.elt) - self.stats.used = { + self.stats.used |= { x: n for x, n in inner_stats.used.items() if x not in self.stats.assigned } From 2359ab0b447872c49fb6a9802440a63af59aef80 Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Mon, 8 Jan 2024 12:41:30 +0100 Subject: [PATCH 08/31] Include parent scope in Locals.__iter__ --- guppy/checker/core.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/guppy/checker/core.py b/guppy/checker/core.py index 9827c7b5..bb0d9dda 100644 --- a/guppy/checker/core.py +++ b/guppy/checker/core.py @@ -125,7 +125,8 @@ def __setitem__(self, key: str, value: Variable) -> None: self.vars[key] = value def __iter__(self) -> Iterator[str]: - return iter(self.keys()) + parent_iter = iter(self.parent_scope) if self.parent_scope else iter(()) + return itertools.chain(iter(self.vars), parent_iter) def __contains__(self, item: str) -> bool: return (item in self.vars) or ( From 81dcddb9f792e95038cdb047b2455dec13d66406 Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Mon, 8 Jan 2024 15:07:00 +0100 Subject: [PATCH 09/31] Don't require linear return for __next__ --- guppy/checker/expr_checker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/guppy/checker/expr_checker.py b/guppy/checker/expr_checker.py index 392824e8..6171fbab 100644 --- a/guppy/checker/expr_checker.py +++ b/guppy/checker/expr_checker.py @@ -465,7 +465,7 @@ def visit_IterHasNext(self, node: IterHasNext) -> tuple[ast.expr, GuppyType]: def visit_IterNext(self, node: IterNext) -> tuple[ast.expr, GuppyType]: node.value, ty = self.synthesize(node.value) - exp_sig = FunctionType([ty], TupleType([FreeTypeVar.new("T", True), ty])) + exp_sig = FunctionType([ty], TupleType([FreeTypeVar.new("T", False), ty])) return self._synthesize_instance_func( node.value, [], "__next__", "not an iterator", exp_sig, True ) From 34e9ce2e2347c3e870401fc0d330507558517287 Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Tue, 9 Jan 2024 14:49:19 +0100 Subject: [PATCH 10/31] Improve error msg --- guppy/checker/expr_checker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/guppy/checker/expr_checker.py b/guppy/checker/expr_checker.py index 6171fbab..31505b7b 100644 --- a/guppy/checker/expr_checker.py +++ b/guppy/checker/expr_checker.py @@ -451,7 +451,7 @@ def visit_MakeIter(self, node: MakeIter) -> tuple[ast.expr, GuppyType]: if breaks: raise GuppyTypeError( f"Loop over iterator with linear type `{ty}` cannot be terminated " - f"(cannot ensure that all values have been used)", + f"prematurely", breaks[0], ) return expr, ty From f8e6768c5e296bb37a60cadbf5f5c955fc5096a0 Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Thu, 11 Jan 2024 10:09:41 +0000 Subject: [PATCH 11/31] Add comment to visit_List --- guppy/compiler/expr_compiler.py | 1 + 1 file changed, 1 insertion(+) diff --git a/guppy/compiler/expr_compiler.py b/guppy/compiler/expr_compiler.py index dff60963..f90a0809 100644 --- a/guppy/compiler/expr_compiler.py +++ b/guppy/compiler/expr_compiler.py @@ -134,6 +134,7 @@ def visit_Tuple(self, node: ast.Tuple) -> OutPortV: ).out_port(0) def visit_List(self, node: ast.List) -> OutPortV: + # Note that this is a list literal (i.e. `[e1, e2, ...]`), not a comprehension return self.graph.add_node( ops.DummyOp(name="MakeList"), inputs=[self.visit(e) for e in node.elts] ).add_out_port(get_type(node)) From b02dc3de46820123994eabea589447587837bab8 Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Thu, 11 Jan 2024 10:18:14 +0000 Subject: [PATCH 12/31] Check uniqueness of input names --- guppy/compiler/expr_compiler.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/guppy/compiler/expr_compiler.py b/guppy/compiler/expr_compiler.py index f90a0809..c7e0ba4f 100644 --- a/guppy/compiler/expr_compiler.py +++ b/guppy/compiler/expr_compiler.py @@ -65,6 +65,8 @@ def _new_dfcontainer( """ old = self.dfg inp = self.graph.add_input(parent=node) + # Check that the input names are unique + assert len({inp.id for inp in inputs}) == len(inputs), "Inputs are not unique" new_locals = { name.id: PortVariable(name.id, inp.add_out_port(get_type(name)), name, None) for name in inputs From 34ba5e19e68a8a66fa6aa7e7d248748a6db651ea Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Thu, 11 Jan 2024 10:28:51 +0000 Subject: [PATCH 13/31] Rename _new_loop inputs to variants --- guppy/compiler/expr_compiler.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/guppy/compiler/expr_compiler.py b/guppy/compiler/expr_compiler.py index c7e0ba4f..b8dda837 100644 --- a/guppy/compiler/expr_compiler.py +++ b/guppy/compiler/expr_compiler.py @@ -79,7 +79,7 @@ def _new_dfcontainer( @contextmanager def _new_loop( self, - inputs: list[ast.Name], + variants: list[ast.Name], branch: ast.Name, parent: DFContainingNode | None = None, ) -> Iterator[None]: @@ -87,15 +87,15 @@ def _new_loop( Automatically adds the `Output` node once the context manager exists. """ - loop = self.graph.add_tail_loop([self.visit(name) for name in inputs], parent) - with self._new_dfcontainer(inputs, loop): + loop = self.graph.add_tail_loop([self.visit(name) for name in variants], parent) + with self._new_dfcontainer(variants, loop): yield # Output the branch predicate and the inputs for the next iteration self.graph.add_output( - [self.visit(branch), *(self.visit(name) for name in inputs)] + [self.visit(branch), *(self.visit(name) for name in variants)] ) # Update the DFG with the outputs from the loop - for name in inputs: + for name in variants: self.dfg[name.id].port = loop.add_out_port(get_type(name)) @contextmanager From 17aca097636d7ef9a1ff3f34f906b420e44ca943 Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Thu, 11 Jan 2024 10:31:07 +0000 Subject: [PATCH 14/31] Fix spelling and clarify docstring --- guppy/compiler/expr_compiler.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/guppy/compiler/expr_compiler.py b/guppy/compiler/expr_compiler.py index b8dda837..a41ab185 100644 --- a/guppy/compiler/expr_compiler.py +++ b/guppy/compiler/expr_compiler.py @@ -85,7 +85,8 @@ def _new_loop( ) -> Iterator[None]: """Context manager to build a graph inside a new `TailLoop` node. - Automatically adds the `Output` node once the context manager exists. + Automatically adds the `Output` node to the loop body once the context manager + exits. """ loop = self.graph.add_tail_loop([self.visit(name) for name in variants], parent) with self._new_dfcontainer(variants, loop): @@ -104,7 +105,7 @@ def _new_case( ) -> Iterator[None]: """Context manager to build a graph inside a new `Case` node. - Automatically adds the `Output` node once the context manager exists. + Automatically adds the `Output` node once the context manager exits. """ with self._new_dfcontainer(inputs, self.graph.add_case(cond_node)): yield From 76390894e3acafec71753d754040e3b703eb1760 Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Thu, 11 Jan 2024 10:40:25 +0000 Subject: [PATCH 15/31] Fix _new_case inputs for outputs --- guppy/compiler/expr_compiler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/guppy/compiler/expr_compiler.py b/guppy/compiler/expr_compiler.py index a41ab185..0579e8cd 100644 --- a/guppy/compiler/expr_compiler.py +++ b/guppy/compiler/expr_compiler.py @@ -113,7 +113,7 @@ def _new_case( # Update the DFG with the outputs from the Conditional node, but only we haven't # already added some if cond_node.num_out_ports == 0: - for name in inputs: + for name in outputs: self.dfg[name.id].port = cond_node.add_out_port(get_type(name)) def visit_Constant(self, node: ast.Constant) -> OutPortV: From 8cb0df01960e284e1dfa66f9cc703e43585de10f Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Thu, 11 Jan 2024 10:43:50 +0000 Subject: [PATCH 16/31] Add comment explaining passing `inputs` twice --- guppy/compiler/expr_compiler.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/guppy/compiler/expr_compiler.py b/guppy/compiler/expr_compiler.py index 0579e8cd..feb6caad 100644 --- a/guppy/compiler/expr_compiler.py +++ b/guppy/compiler/expr_compiler.py @@ -237,7 +237,8 @@ def compile_generators(elt: ast.expr, gens: list[DesugaredGenerator]) -> None: [self.visit(gen.iter), self.visit(list_name)], ) - # If the iterator is finished, output the iterator and list as is + # If the iterator is finished, output the iterator and list as is (this + # is achieved by passing `inputs, inputs` below) with self._new_case(inputs, inputs, cond): pass From 20b2d99e5dc73661628cbd8b32ed511c465b8f48 Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Thu, 11 Jan 2024 11:00:18 +0000 Subject: [PATCH 17/31] Use inputs comprehension --- guppy/compiler/expr_compiler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/guppy/compiler/expr_compiler.py b/guppy/compiler/expr_compiler.py index feb6caad..bcf3b4eb 100644 --- a/guppy/compiler/expr_compiler.py +++ b/guppy/compiler/expr_compiler.py @@ -234,7 +234,7 @@ def compile_generators(elt: ast.expr, gens: list[DesugaredGenerator]) -> None: compiler.compile_stmts([gen.hasnext_assign], self.dfg) cond = self.graph.add_conditional( self.visit(gen.hasnext), - [self.visit(gen.iter), self.visit(list_name)], + [self.visit(inp) for inp in inputs], ) # If the iterator is finished, output the iterator and list as is (this From a14ae4505e23bf2fc504d84c8821657a0afc66c9 Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Thu, 11 Jan 2024 11:01:50 +0000 Subject: [PATCH 18/31] Fix comments --- guppy/hugr/hugr.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/guppy/hugr/hugr.py b/guppy/hugr/hugr.py index a643d065..bbd81588 100644 --- a/guppy/hugr/hugr.py +++ b/guppy/hugr/hugr.py @@ -708,14 +708,14 @@ def insert_order_edges(self) -> "Hugr": continue if src.node.parent != tgt.node.parent: - # Walk up the hierarchy from the src until we hit a node at the same - # level as tgt + # Walk up the hierarchy from the tgt until we hit a node at the same + # level as src node = tgt.node while node.parent != src.node.parent: if node.parent is None: raise ValueError("Invalid non-local edge!") node = node.parent - # Edge order edge to make sure that the src is executed first + # Add order edge to make sure that the src is executed first self.add_order_edge(src.node, node) return self From 8f561e25fc34fe2e4fe4bca266fa4b005b14bbef Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Thu, 11 Jan 2024 11:30:54 +0000 Subject: [PATCH 19/31] Add _if_true context manager --- guppy/compiler/expr_compiler.py | 60 ++++++++++++++++----------------- 1 file changed, 29 insertions(+), 31 deletions(-) diff --git a/guppy/compiler/expr_compiler.py b/guppy/compiler/expr_compiler.py index bcf3b4eb..cd083248 100644 --- a/guppy/compiler/expr_compiler.py +++ b/guppy/compiler/expr_compiler.py @@ -110,11 +110,25 @@ def _new_case( with self._new_dfcontainer(inputs, self.graph.add_case(cond_node)): yield self.graph.add_output([self.visit(name) for name in outputs]) - # Update the DFG with the outputs from the Conditional node, but only we haven't - # already added some - if cond_node.num_out_ports == 0: - for name in outputs: - self.dfg[name.id].port = cond_node.add_out_port(get_type(name)) + + @contextmanager + def _if_true(self, cond: ast.expr, inputs: list[ast.Name]) -> Iterator[None]: + """Context manager to build a graph inside the `true` case of a `Conditional` + + In the `false` case, the inputs are outputted as is. + """ + cond_node = self.graph.add_conditional( + self.visit(cond), [self.visit(inp) for inp in inputs] + ) + # If the condition is false, output the inputs as is + with self._new_case(inputs, inputs, cond_node): + pass + # If the condition is true, we enter the `with` block + with self._new_case(inputs, inputs, cond_node): + yield + # Update the DFG with the outputs from the Conditional node + for name in inputs: + self.dfg[name.id].port = cond_node.add_out_port(get_type(name)) def visit_Constant(self, node: ast.Constant) -> OutPortV: if value := python_value_to_hugr(node.value): @@ -230,38 +244,22 @@ def compile_generators(elt: ast.expr, gens: list[DesugaredGenerator]) -> None: compiler.compile_stmts([gen.iter_assign], self.dfg) inputs = [gen.iter, list_name] with self._new_loop(inputs, gen.hasnext): - # Compile the `hasnext` check and plug it into a conditional - compiler.compile_stmts([gen.hasnext_assign], self.dfg) - cond = self.graph.add_conditional( - self.visit(gen.hasnext), - [self.visit(inp) for inp in inputs], - ) - - # If the iterator is finished, output the iterator and list as is (this - # is achieved by passing `inputs, inputs` below) - with self._new_case(inputs, inputs, cond): - pass - # If there is a next element, compile it and continue with the next # generator - with self._new_case(inputs, inputs, cond): + compiler.compile_stmts([gen.hasnext_assign], self.dfg) + with self._if_true(gen.hasnext, inputs): def compile_ifs(ifs: list[ast.expr]) -> None: - if not ifs: + """Helper function to compile a series of if-guards into nested + Conditional nodes.""" + if ifs: + if_expr, *ifs = ifs + # If the condition is true, continue with the next one + with self._if_true(if_expr, inputs): + compile_ifs(ifs) + else: # If there are no guards left, compile the next generator compile_generators(elt, gens) - return - if_expr, *ifs = ifs - cond = self.graph.add_conditional( - self.visit(if_expr), - [self.visit(gen.iter), self.visit(list_name)], - ) - # If the condition is false, output the iterator and list as is - with self._new_case(inputs, inputs, cond): - pass - # If the condition is true, continue with the next one - with self._new_case(inputs, inputs, cond): - compile_ifs(ifs) compiler.compile_stmts([gen.next_assign], self.dfg) compile_ifs(gen.ifs) From 1fb02a056ae6ff689ef9470ee35cccd98c6c6c2f Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Mon, 15 Jan 2024 10:15:24 +0000 Subject: [PATCH 20/31] Rename variants to loop_vars --- guppy/compiler/expr_compiler.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/guppy/compiler/expr_compiler.py b/guppy/compiler/expr_compiler.py index cd083248..e00dcf05 100644 --- a/guppy/compiler/expr_compiler.py +++ b/guppy/compiler/expr_compiler.py @@ -79,7 +79,7 @@ def _new_dfcontainer( @contextmanager def _new_loop( self, - variants: list[ast.Name], + loop_vars: list[ast.Name], branch: ast.Name, parent: DFContainingNode | None = None, ) -> Iterator[None]: @@ -88,15 +88,15 @@ def _new_loop( Automatically adds the `Output` node to the loop body once the context manager exits. """ - loop = self.graph.add_tail_loop([self.visit(name) for name in variants], parent) - with self._new_dfcontainer(variants, loop): + loop = self.graph.add_tail_loop([self.visit(name) for name in loop_vars], parent) + with self._new_dfcontainer(loop_vars, loop): yield # Output the branch predicate and the inputs for the next iteration self.graph.add_output( - [self.visit(branch), *(self.visit(name) for name in variants)] + [self.visit(branch), *(self.visit(name) for name in loop_vars)] ) # Update the DFG with the outputs from the loop - for name in variants: + for name in loop_vars: self.dfg[name.id].port = loop.add_out_port(get_type(name)) @contextmanager From 13eac52a82d0145707cb3631be59be9e49fac96c Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Mon, 15 Jan 2024 10:18:04 +0000 Subject: [PATCH 21/31] Add comment explaining fresh visit calls --- guppy/compiler/expr_compiler.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/guppy/compiler/expr_compiler.py b/guppy/compiler/expr_compiler.py index e00dcf05..30cace3c 100644 --- a/guppy/compiler/expr_compiler.py +++ b/guppy/compiler/expr_compiler.py @@ -93,6 +93,8 @@ def _new_loop( yield # Output the branch predicate and the inputs for the next iteration self.graph.add_output( + # Note that we have to do fresh calls to `self.visit` here since we're + # in a new context [self.visit(branch), *(self.visit(name) for name in loop_vars)] ) # Update the DFG with the outputs from the loop From a6ca99efe89b2cd6ccf4acf153643abdb082158d Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Mon, 15 Jan 2024 16:36:31 +0000 Subject: [PATCH 22/31] Remove double validate --- tests/integration/test_for.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/integration/test_for.py b/tests/integration/test_for.py index 585d75e4..917c3ffc 100644 --- a/tests/integration/test_for.py +++ b/tests/integration/test_for.py @@ -81,8 +81,6 @@ def foo(xs: list[int]) -> int: validate(foo) - validate(foo) - def test_return_in_loop(validate): @guppy From 7630f7b8cfc13e973ae50fe5dd241c0b0cf0b01f Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Mon, 15 Jan 2024 17:40:17 +0000 Subject: [PATCH 23/31] Add make_assign helper --- guppy/cfg/builder.py | 54 +++++++++++++++++++++++--------------------- 1 file changed, 28 insertions(+), 26 deletions(-) diff --git a/guppy/cfg/builder.py b/guppy/cfg/builder.py index 0c42207c..e67a5bf5 100644 --- a/guppy/cfg/builder.py +++ b/guppy/cfg/builder.py @@ -1,7 +1,7 @@ import ast import itertools from collections.abc import Iterator -from typing import NamedTuple, cast +from typing import NamedTuple from guppy.ast_util import ( AstVisitor, @@ -258,9 +258,8 @@ def build(node: ast.expr, cfg: CFG, bb: BB) -> tuple[ast.expr, BB]: @classmethod def _tmp_assign(cls, tmp_name: str, value: ast.expr, bb: BB) -> None: """Adds a temporary variable assignment to a basic block.""" - node = ast.Assign(targets=[make_var(tmp_name, value)], value=value) - set_location_from(node, value) - bb.statements.append(node) + lhs = make_var(tmp_name, value) + bb.statements.append(make_assign([lhs], value)) def visit_Name(self, node: ast.Name) -> ast.Name: return node @@ -309,31 +308,24 @@ def visit_ListComp(self, node: ast.ListComp) -> ast.AST: if g.is_async: raise GuppyError("Async generators are not supported", g) g.iter = self.visit(g.iter) - gen = DesugaredGenerator() - - template = """ - it = make_iter - b, it = has_next - x, it = get_next - """ it = make_var(next(tmp_vars), g.iter) - b = make_var(next(tmp_vars), g.iter) - [gen.iter_assign, gen.hasnext_assign, gen.next_assign] = cast( - list[ast.Assign], - template_replace( - template, - g.iter, - it=it, - b=b, - x=g.target, - make_iter=with_loc(it, MakeIter(value=g.iter, origin_node=node)), - has_next=with_loc(it, IterHasNext(value=it)), - get_next=with_loc(it, IterNext(value=it)), + hasnext = make_var(next(tmp_vars), g.iter) + desugared = DesugaredGenerator( + iter=it, + hasnext=hasnext, + iter_assign=make_assign( + [it], with_loc(it, MakeIter(value=g.iter, origin_node=node)) + ), + hasnext_assign=make_assign( + [hasnext, it], with_loc(it, IterHasNext(value=it)) ), + next_assign=make_assign( + [g.target, it], with_loc(it, IterNext(value=it)) + ), + iterend=with_loc(it, IterEnd(value=it)), + ifs=g.ifs, ) - gen.iterend = with_loc(it, IterEnd(value=it)) - gen.iter, gen.hasnext, gen.ifs = it, b, g.ifs - gens.append(gen) + gens.append(desugared) node.elt = self.visit(node.elt) return with_loc(node, DesugaredListComp(elt=node.elt, generators=gens)) @@ -507,3 +499,13 @@ def make_var(name: str, loc: ast.AST | None = None) -> ast.Name: if loc is not None: set_location_from(node, loc) return node + + +def make_assign(lhs: list[ast.AST], value: ast.expr) -> ast.Assign: + """Creates an `ast.Assign` node.""" + assert len(lhs) > 0 + if len(lhs) == 1: + target = lhs[0] + else: + target = with_loc(value, ast.Tuple(elts=lhs, ctx=ast.Store())) + return with_loc(value, ast.Assign(targets=[target], value=value)) From 70c5d07b9a9dbee36d06683090dc71f3dcb42774 Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Tue, 16 Jan 2024 12:37:00 +0000 Subject: [PATCH 24/31] Turn TemplateReplacer into dataclass --- guppy/ast_util.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/guppy/ast_util.py b/guppy/ast_util.py index 31bb940f..bfa44a7c 100644 --- a/guppy/ast_util.py +++ b/guppy/ast_util.py @@ -1,6 +1,7 @@ import ast import textwrap from collections.abc import Callable, Mapping, Sequence +from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, cast if TYPE_CHECKING: @@ -154,20 +155,13 @@ def visit_Attribute(self, node: ast.Attribute) -> ast.Attribute: return ast.Attribute(value=self.visit(node.value), attr=node.attr, ctx=self.ctx) +@dataclass(frozen=True, eq=False) class TemplateReplacer(ast.NodeTransformer): """Replaces nodes in a template.""" replacements: Mapping[str, ast.AST | Sequence[ast.AST]] default_loc: ast.AST - def __init__( - self, - replacements: Mapping[str, ast.AST | Sequence[ast.AST]], - default_loc: ast.AST, - ) -> None: - self.replacements = replacements - self.default_loc = default_loc - def _get_replacement(self, x: str) -> ast.AST | Sequence[ast.AST]: if x not in self.replacements: msg = f"No replacement for `{x}` is given" From fd8f56c09c593c4f374ae27ae6085ec10ec67116 Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Tue, 16 Jan 2024 12:40:30 +0000 Subject: [PATCH 25/31] Add missing location --- guppy/ast_util.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/guppy/ast_util.py b/guppy/ast_util.py index bfa44a7c..b5362b46 100644 --- a/guppy/ast_util.py +++ b/guppy/ast_util.py @@ -152,7 +152,10 @@ def visit_Subscript(self, node: ast.Subscript) -> ast.Subscript: ) def visit_Attribute(self, node: ast.Attribute) -> ast.Attribute: - return ast.Attribute(value=self.visit(node.value), attr=node.attr, ctx=self.ctx) + return with_loc( + node, + ast.Attribute(value=self.visit(node.value), attr=node.attr, ctx=self.ctx), + ) @dataclass(frozen=True, eq=False) From 679016121d602909187c9e90bf6b53a5951e1581 Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Tue, 16 Jan 2024 12:48:32 +0000 Subject: [PATCH 26/31] Adjust context in make_assign --- guppy/cfg/builder.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/guppy/cfg/builder.py b/guppy/cfg/builder.py index e67a5bf5..1310562e 100644 --- a/guppy/cfg/builder.py +++ b/guppy/cfg/builder.py @@ -5,6 +5,7 @@ from guppy.ast_util import ( AstVisitor, + ContextAdjuster, find_nodes, set_location_from, template_replace, @@ -504,6 +505,8 @@ def make_var(name: str, loc: ast.AST | None = None) -> ast.Name: def make_assign(lhs: list[ast.AST], value: ast.expr) -> ast.Assign: """Creates an `ast.Assign` node.""" assert len(lhs) > 0 + adjuster = ContextAdjuster(ast.Store()) + lhs = [adjuster.visit(expr) for expr in lhs] if len(lhs) == 1: target = lhs[0] else: From 81d17222c09aade30688af10017665c6859484c1 Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Tue, 16 Jan 2024 12:59:09 +0000 Subject: [PATCH 27/31] Fix formatting --- guppy/compiler/expr_compiler.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/guppy/compiler/expr_compiler.py b/guppy/compiler/expr_compiler.py index dfbc7330..1d71556e 100644 --- a/guppy/compiler/expr_compiler.py +++ b/guppy/compiler/expr_compiler.py @@ -88,7 +88,9 @@ def _new_loop( Automatically adds the `Output` node to the loop body once the context manager exits. """ - loop = self.graph.add_tail_loop([self.visit(name) for name in loop_vars], parent) + loop = self.graph.add_tail_loop( + [self.visit(name) for name in loop_vars], parent + ) with self._new_dfcontainer(loop_vars, loop): yield # Output the branch predicate and the inputs for the next iteration From f36da02ea9d7808642dbe5298be98ef6b4c8eabe Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Tue, 16 Jan 2024 13:14:12 +0000 Subject: [PATCH 28/31] Rename Tuple to TupleType --- tests/error/iter_errors/end_missing.py | 4 ++-- tests/error/iter_errors/end_wrong_type.py | 4 ++-- tests/error/iter_errors/hasnext_missing.py | 4 ++-- tests/error/iter_errors/hasnext_wrong_type.py | 4 ++-- tests/error/iter_errors/iter_missing.py | 2 +- tests/error/iter_errors/iter_wrong_type.py | 2 +- tests/error/iter_errors/next_missing.py | 4 ++-- tests/error/iter_errors/next_wrong_type.py | 4 ++-- tests/integration/test_comprehension.py | 6 +++--- tests/integration/test_linear.py | 4 ++-- 10 files changed, 19 insertions(+), 19 deletions(-) diff --git a/tests/error/iter_errors/end_missing.py b/tests/error/iter_errors/end_missing.py index 4a8f6e0f..33eafe12 100644 --- a/tests/error/iter_errors/end_missing.py +++ b/tests/error/iter_errors/end_missing.py @@ -6,7 +6,7 @@ module = GuppyModule("test") -@guppy.type(module, tys.Tuple(inner=[])) +@guppy.type(module, tys.TupleType(inner=[])) class MyIter: """An iterator that is missing the `__end__` method.""" @@ -19,7 +19,7 @@ def __hasnext__(self: "MyIter") -> tuple[bool, "MyIter"]: ... -@guppy.type(module, tys.Tuple(inner=[])) +@guppy.type(module, tys.TupleType(inner=[])) class MyType: """Type that produces the iterator above.""" diff --git a/tests/error/iter_errors/end_wrong_type.py b/tests/error/iter_errors/end_wrong_type.py index edcb175e..82568d2b 100644 --- a/tests/error/iter_errors/end_wrong_type.py +++ b/tests/error/iter_errors/end_wrong_type.py @@ -6,7 +6,7 @@ module = GuppyModule("test") -@guppy.type(module, tys.Tuple(inner=[])) +@guppy.type(module, tys.TupleType(inner=[])) class MyIter: """An iterator where the `__end__` method has the wrong signature.""" @@ -23,7 +23,7 @@ def __end__(self: "MyIter") -> "MyIter": ... -@guppy.type(module, tys.Tuple(inner=[])) +@guppy.type(module, tys.TupleType(inner=[])) class MyType: """Type that produces the iterator above.""" diff --git a/tests/error/iter_errors/hasnext_missing.py b/tests/error/iter_errors/hasnext_missing.py index a0046e4e..68b11c7a 100644 --- a/tests/error/iter_errors/hasnext_missing.py +++ b/tests/error/iter_errors/hasnext_missing.py @@ -6,7 +6,7 @@ module = GuppyModule("test") -@guppy.type(module, tys.Tuple(inner=[])) +@guppy.type(module, tys.TupleType(inner=[])) class MyIter: """An iterator that is missing the `__hasnext__` method.""" @@ -19,7 +19,7 @@ def __end__(self: "MyIter") -> None: ... -@guppy.type(module, tys.Tuple(inner=[])) +@guppy.type(module, tys.TupleType(inner=[])) class MyType: """Type that produces the iterator above.""" diff --git a/tests/error/iter_errors/hasnext_wrong_type.py b/tests/error/iter_errors/hasnext_wrong_type.py index 9f4e6929..f276f1b7 100644 --- a/tests/error/iter_errors/hasnext_wrong_type.py +++ b/tests/error/iter_errors/hasnext_wrong_type.py @@ -6,7 +6,7 @@ module = GuppyModule("test") -@guppy.type(module, tys.Tuple(inner=[])) +@guppy.type(module, tys.TupleType(inner=[])) class MyIter: """An iterator where the `__hasnext__` method has the wrong signature.""" @@ -23,7 +23,7 @@ def __end__(self: "MyIter") -> None: ... -@guppy.type(module, tys.Tuple(inner=[])) +@guppy.type(module, tys.TupleType(inner=[])) class MyType: """Type that produces the iterator above.""" diff --git a/tests/error/iter_errors/iter_missing.py b/tests/error/iter_errors/iter_missing.py index dab0c6c6..de04fdb6 100644 --- a/tests/error/iter_errors/iter_missing.py +++ b/tests/error/iter_errors/iter_missing.py @@ -6,7 +6,7 @@ module = GuppyModule("test") -@guppy.type(module, tys.Tuple(inner=[])) +@guppy.type(module, tys.TupleType(inner=[])) class MyType: """A non-iterable type.""" diff --git a/tests/error/iter_errors/iter_wrong_type.py b/tests/error/iter_errors/iter_wrong_type.py index 4fa63c37..39a34163 100644 --- a/tests/error/iter_errors/iter_wrong_type.py +++ b/tests/error/iter_errors/iter_wrong_type.py @@ -6,7 +6,7 @@ module = GuppyModule("test") -@guppy.type(module, tys.Tuple(inner=[])) +@guppy.type(module, tys.TupleType(inner=[])) class MyType: """A type where the `__iter__` method has the wrong signature.""" diff --git a/tests/error/iter_errors/next_missing.py b/tests/error/iter_errors/next_missing.py index 8a1e509b..8142d1f6 100644 --- a/tests/error/iter_errors/next_missing.py +++ b/tests/error/iter_errors/next_missing.py @@ -6,7 +6,7 @@ module = GuppyModule("test") -@guppy.type(module, tys.Tuple(inner=[])) +@guppy.type(module, tys.TupleType(inner=[])) class MyIter: """An iterator that is missing the `__next__` method.""" @@ -19,7 +19,7 @@ def __end__(self: "MyIter") -> None: ... -@guppy.type(module, tys.Tuple(inner=[])) +@guppy.type(module, tys.TupleType(inner=[])) class MyType: """Type that produces the iterator above.""" diff --git a/tests/error/iter_errors/next_wrong_type.py b/tests/error/iter_errors/next_wrong_type.py index 4b454616..0d5ae680 100644 --- a/tests/error/iter_errors/next_wrong_type.py +++ b/tests/error/iter_errors/next_wrong_type.py @@ -6,7 +6,7 @@ module = GuppyModule("test") -@guppy.type(module, tys.Tuple(inner=[])) +@guppy.type(module, tys.TupleType(inner=[])) class MyIter: """An iterator where the `__next__` method has the wrong signature.""" @@ -23,7 +23,7 @@ def __end__(self: "MyIter") -> None: ... -@guppy.type(module, tys.Tuple(inner=[])) +@guppy.type(module, tys.TupleType(inner=[])) class MyType: """Type that produces the iterator above.""" diff --git a/tests/integration/test_comprehension.py b/tests/integration/test_comprehension.py index 9587f775..aec2f459 100644 --- a/tests/integration/test_comprehension.py +++ b/tests/integration/test_comprehension.py @@ -188,7 +188,7 @@ def test_linear_next_nonlinear_iter(validate): module = GuppyModule("test") module.load(quantum) - @guppy.type(module, tys.Tuple(inner=[])) + @guppy.type(module, tys.TupleType(inner=[])) class MyIter: """An iterator that yields linear values but is not linear itself.""" @@ -204,7 +204,7 @@ def __next__(self: "MyIter") -> tuple[Qubit, "MyIter"]: def __end__(self: "MyIter") -> None: ... - @guppy.type(module, tys.Tuple(inner=[])) + @guppy.type(module, tys.TupleType(inner=[])) class MyType: """Type that produces the iterator above.""" @@ -244,7 +244,7 @@ def __next__(self: "MyIter") -> tuple[int, "MyIter"]: def __end__(self: "MyIter") -> None: ... - @guppy.type(module, tys.Tuple(inner=[])) + @guppy.type(module, tys.TupleType(inner=[])) class MyType: """Type that produces the iterator above.""" diff --git a/tests/integration/test_linear.py b/tests/integration/test_linear.py index b7597416..cdd975d9 100644 --- a/tests/integration/test_linear.py +++ b/tests/integration/test_linear.py @@ -266,7 +266,7 @@ def test_for_nonlinear_break(validate): module = GuppyModule("test") module.load(quantum) - @guppy.type(module, tys.Tuple(inner=[])) + @guppy.type(module, tys.TupleType(inner=[])) class MyIter: """An iterator that yields linear values but is not linear itself.""" @@ -282,7 +282,7 @@ def __next__(self: "MyIter") -> tuple[Qubit, "MyIter"]: def __end__(self: "MyIter") -> None: ... - @guppy.type(module, tys.Tuple(inner=[])) + @guppy.type(module, tys.TupleType(inner=[])) class MyType: """Type that produces the iterator above.""" From dba93fda4fe304b8507ca02ec07d1d16279d8860 Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Tue, 16 Jan 2024 13:18:11 +0000 Subject: [PATCH 29/31] Stop using old test decorator --- tests/error/errors_on_usage/for_new_var.py | 2 +- tests/error/errors_on_usage/for_target.py | 2 +- tests/error/errors_on_usage/for_target_type_change.py | 2 +- tests/error/errors_on_usage/for_type_change.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/error/errors_on_usage/for_new_var.py b/tests/error/errors_on_usage/for_new_var.py index f0f35a0f..1054ad61 100644 --- a/tests/error/errors_on_usage/for_new_var.py +++ b/tests/error/errors_on_usage/for_new_var.py @@ -1,4 +1,4 @@ -from tests.error.util import guppy +from guppy.decorator import guppy @guppy diff --git a/tests/error/errors_on_usage/for_target.py b/tests/error/errors_on_usage/for_target.py index 3d9c9354..fe4c0eb1 100644 --- a/tests/error/errors_on_usage/for_target.py +++ b/tests/error/errors_on_usage/for_target.py @@ -1,4 +1,4 @@ -from tests.error.util import guppy +from guppy.decorator import guppy @guppy diff --git a/tests/error/errors_on_usage/for_target_type_change.py b/tests/error/errors_on_usage/for_target_type_change.py index 5c2cfc54..7bc2b0b3 100644 --- a/tests/error/errors_on_usage/for_target_type_change.py +++ b/tests/error/errors_on_usage/for_target_type_change.py @@ -1,4 +1,4 @@ -from tests.error.util import guppy +from guppy.decorator import guppy @guppy diff --git a/tests/error/errors_on_usage/for_type_change.py b/tests/error/errors_on_usage/for_type_change.py index b22b20ab..96f4017c 100644 --- a/tests/error/errors_on_usage/for_type_change.py +++ b/tests/error/errors_on_usage/for_type_change.py @@ -1,4 +1,4 @@ -from tests.error.util import guppy +from guppy.decorator import guppy @guppy From a31e46f6c90ce8de12215ded361609bae916659b Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Tue, 16 Jan 2024 13:19:12 +0000 Subject: [PATCH 30/31] Fix updated error message --- tests/error/linear_errors/for_break.err | 2 +- tests/error/linear_errors/for_return.err | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/error/linear_errors/for_break.err b/tests/error/linear_errors/for_break.err index 2d4c195f..271f11ed 100644 --- a/tests/error/linear_errors/for_break.err +++ b/tests/error/linear_errors/for_break.err @@ -4,4 +4,4 @@ Guppy compilation failed. Error in file $FILE:17 16: if b: 17: break ^^^^^ -GuppyTypeError: Loop over iterator with linear type `linst[(Qubit, bool)]` cannot be terminated (cannot ensure that all values have been used) +GuppyTypeError: Loop over iterator with linear type `linst[(Qubit, bool)]` cannot be terminated prematurely diff --git a/tests/error/linear_errors/for_return.err b/tests/error/linear_errors/for_return.err index aae47b8a..33fab35a 100644 --- a/tests/error/linear_errors/for_return.err +++ b/tests/error/linear_errors/for_return.err @@ -4,4 +4,4 @@ Guppy compilation failed. Error in file $FILE:17 16: if b: 17: return [] ^^^^^^^^^ -GuppyTypeError: Loop over iterator with linear type `linst[(Qubit, bool)]` cannot be terminated (cannot ensure that all values have been used) +GuppyTypeError: Loop over iterator with linear type `linst[(Qubit, bool)]` cannot be terminated prematurely From 53abfeaf9f856f015e9bcb93f1f615a53b017651 Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Tue, 16 Jan 2024 13:22:06 +0000 Subject: [PATCH 31/31] Fix error message file placeholder --- tests/error/iter_errors/end_missing.err | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/error/iter_errors/end_missing.err b/tests/error/iter_errors/end_missing.err index 42f0022b..90ad0692 100644 --- a/tests/error/iter_errors/end_missing.err +++ b/tests/error/iter_errors/end_missing.err @@ -1,4 +1,4 @@ -Guppy compilation failed. Error in file /Users/mark.koch/code/guppy/tests/error/iter_errors/end_missing.py:33 +Guppy compilation failed. Error in file $FILE:33 31: @guppy(module) 32: def test(x: MyType) -> None: