diff --git a/guppy/ast_util.py b/guppy/ast_util.py index 5a9e66f4..b5362b46 100644 --- a/guppy/ast_util.py +++ b/guppy/ast_util.py @@ -1,5 +1,8 @@ import ast -from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar +import textwrap +from collections.abc import Callable, Mapping, Sequence +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, cast if TYPE_CHECKING: from guppy.gtypes import GuppyType @@ -54,51 +57,161 @@ def generic_visit(self, node: Any, *args: Any, **kwargs: Any) -> T: raise NotImplementedError(f"visit_{node.__class__.__name__} is not implemented") -class NameVisitor(ast.NodeVisitor): - """Visitor to collect all `Name` nodes occurring in an AST.""" - - names: list[ast.Name] - - def __init__(self) -> None: - self.names = [] - - def visit_Name(self, node: ast.Name) -> None: - self.names.append(node) +class AstSearcher(ast.NodeVisitor): + """Visitor that searches for occurrences of specific nodes in an AST.""" + + matcher: Callable[[ast.AST], bool] + dont_recurse_into: set[type[ast.AST]] + found: list[ast.AST] + is_first_node: bool + + def __init__( + self, + matcher: Callable[[ast.AST], bool], + dont_recurse_into: set[type[ast.AST]] | None = None, + ) -> None: + self.matcher = matcher + self.dont_recurse_into = dont_recurse_into or set() + self.found = [] + self.is_first_node = True + + def generic_visit(self, node: ast.AST) -> None: + if self.matcher(node): + self.found.append(node) + if self.is_first_node or type(node) not in self.dont_recurse_into: + self.is_first_node = False + super().generic_visit(node) + + +def find_nodes( + matcher: Callable[[ast.AST], bool], + node: ast.AST, + dont_recurse_into: set[type[ast.AST]] | None = None, +) -> list[ast.AST]: + """Returns all nodes in the AST that satisfy the matcher.""" + v = AstSearcher(matcher, dont_recurse_into) + v.visit(node) + return v.found def name_nodes_in_ast(node: Any) -> list[ast.Name]: """Returns all `Name` nodes occurring in an AST.""" - v = NameVisitor() - v.visit(node) - return v.names - - -class ReturnVisitor(ast.NodeVisitor): - """Visitor to collect all `Return` nodes occurring in an AST.""" + found = find_nodes(lambda n: isinstance(n, ast.Name), node) + return cast(list[ast.Name], found) - returns: list[ast.Return] - inside_func_def: bool - def __init__(self) -> None: - self.returns = [] - self.inside_func_def = False - - def visit_Return(self, node: ast.Return) -> None: - self.returns.append(node) +def return_nodes_in_ast(node: Any) -> list[ast.Return]: + """Returns all `Return` nodes occurring in an AST.""" + found = find_nodes(lambda n: isinstance(n, ast.Return), node, {ast.FunctionDef}) + return cast(list[ast.Return], found) - def visit_FunctionDef(self, node: ast.FunctionDef) -> None: - # Don't descend into nested function definitions - if not self.inside_func_def: - self.inside_func_def = True - for n in node.body: - self.visit(n) +def breaks_in_loop(node: Any) -> list[ast.Break]: + """Returns all `Break` nodes occurring in a loop. -def return_nodes_in_ast(node: Any) -> list[ast.Return]: - """Returns all `Return` nodes occurring in an AST.""" - v = ReturnVisitor() - v.visit(node) - return v.returns + Note that breaks in nested loops are excluded. + """ + found = find_nodes( + lambda n: isinstance(n, ast.Break), node, {ast.For, ast.While, ast.FunctionDef} + ) + return cast(list[ast.Break], found) + + +class ContextAdjuster(ast.NodeTransformer): + """Updates the `ast.Context` indicating if expressions occur on the LHS or RHS.""" + + ctx: ast.expr_context + + def __init__(self, ctx: ast.expr_context) -> None: + self.ctx = ctx + + def visit(self, node: ast.AST) -> ast.AST: + return cast(ast.AST, super().visit(node)) + + def visit_Name(self, node: ast.Name) -> ast.Name: + return with_loc(node, ast.Name(id=node.id, ctx=self.ctx)) + + def visit_Starred(self, node: ast.Starred) -> ast.Starred: + return with_loc(node, ast.Starred(value=self.visit(node.value), ctx=self.ctx)) + + def visit_Tuple(self, node: ast.Tuple) -> ast.Tuple: + return with_loc( + node, ast.Tuple(elts=[self.visit(elt) for elt in node.elts], ctx=self.ctx) + ) + + def visit_List(self, node: ast.List) -> ast.List: + return with_loc( + node, ast.List(elts=[self.visit(elt) for elt in node.elts], ctx=self.ctx) + ) + + def visit_Subscript(self, node: ast.Subscript) -> ast.Subscript: + # Don't adjust the slice! + return with_loc( + node, + ast.Subscript(value=self.visit(node.value), slice=node.slice, ctx=self.ctx), + ) + + def visit_Attribute(self, node: ast.Attribute) -> ast.Attribute: + return with_loc( + node, + ast.Attribute(value=self.visit(node.value), attr=node.attr, ctx=self.ctx), + ) + + +@dataclass(frozen=True, eq=False) +class TemplateReplacer(ast.NodeTransformer): + """Replaces nodes in a template.""" + + replacements: Mapping[str, ast.AST | Sequence[ast.AST]] + default_loc: ast.AST + + def _get_replacement(self, x: str) -> ast.AST | Sequence[ast.AST]: + if x not in self.replacements: + msg = f"No replacement for `{x}` is given" + raise ValueError(msg) + return self.replacements[x] + + def visit_Name(self, node: ast.Name) -> ast.AST: + repl = self._get_replacement(node.id) + if not isinstance(repl, ast.expr): + msg = f"Replacement for `{node.id}` must be an expression" + raise TypeError(msg) + + # Update the context + adjuster = ContextAdjuster(node.ctx) + return with_loc(repl, adjuster.visit(repl)) + + def visit_Expr(self, node: ast.Expr) -> ast.AST | Sequence[ast.AST]: + if isinstance(node.value, ast.Name): + repl = self._get_replacement(node.value.id) + repls = [repl] if not isinstance(repl, Sequence) else repl + # Wrap expressions to turn them into statements + return [ + with_loc(r, ast.Expr(value=r)) if isinstance(r, ast.expr) else r + for r in repls + ] + return self.generic_visit(node) + + def generic_visit(self, node: ast.AST) -> ast.AST: + # Insert the default location + node = super().generic_visit(node) + return with_loc(self.default_loc, node) + + +def template_replace( + template: str, default_loc: ast.AST, **kwargs: ast.AST | Sequence[ast.AST] +) -> list[ast.stmt]: + """Turns a template into a proper AST by substituting all placeholders.""" + nodes = ast.parse(textwrap.dedent(template)).body + replacer = TemplateReplacer(kwargs, default_loc) + new_nodes = [] + for n in nodes: + new = replacer.visit(n) + if isinstance(new, list): + new_nodes.extend(new) + else: + new_nodes.append(new) + return new_nodes def line_col(node: ast.AST) -> tuple[int, int]: diff --git a/guppy/cfg/bb.py b/guppy/cfg/bb.py index a91965ca..d5c89af7 100644 --- a/guppy/cfg/bb.py +++ b/guppy/cfg/bb.py @@ -6,7 +6,7 @@ from typing_extensions import Self from guppy.ast_util import AstNode, name_nodes_in_ast -from guppy.nodes import NestedFunctionDef, PyExpr +from guppy.nodes import DesugaredListComp, NestedFunctionDef, PyExpr if TYPE_CHECKING: from guppy.cfg.cfg import BaseCFG @@ -119,6 +119,25 @@ def visit_AnnAssign(self, node: ast.AnnAssign) -> None: for name in name_nodes_in_ast(node.target): self.stats.assigned[name.id] = node + def visit_DesugaredListComp(self, node: DesugaredListComp) -> None: + # Names bound in the comprehension are only available inside, so we shouldn't + # update `self.stats` with assignments + inner_visitor = VariableVisitor(self.bb) + inner_stats = inner_visitor.stats + + # The generators are evaluated left to right + for gen in node.generators: + inner_visitor.visit(gen.iter_assign) + inner_visitor.visit(gen.hasnext_assign) + inner_visitor.visit(gen.next_assign) + for cond in gen.ifs: + inner_visitor.visit(cond) + inner_visitor.visit(node.elt) + + self.stats.used |= { + x: n for x, n in inner_stats.used.items() if x not in self.stats.assigned + } + def visit_PyExpr(self, node: PyExpr) -> None: # Don't look into `py(...)` expressions pass diff --git a/guppy/cfg/builder.py b/guppy/cfg/builder.py index 3296ec93..1310562e 100644 --- a/guppy/cfg/builder.py +++ b/guppy/cfg/builder.py @@ -3,13 +3,29 @@ from collections.abc import Iterator from typing import NamedTuple -from guppy.ast_util import AstVisitor, set_location_from, with_loc +from guppy.ast_util import ( + AstVisitor, + ContextAdjuster, + find_nodes, + set_location_from, + template_replace, + with_loc, +) from guppy.cfg.bb import BB, BBStatement from guppy.cfg.cfg import CFG from guppy.checker.core import Globals from guppy.error import GuppyError, InternalGuppyError from guppy.gtypes import NoneType -from guppy.nodes import NestedFunctionDef, PyExpr +from guppy.nodes import ( + DesugaredGenerator, + DesugaredListComp, + IterEnd, + IterHasNext, + IterNext, + MakeIter, + NestedFunctionDef, + PyExpr, +) # In order to build expressions, need an endless stream of unique temporary variables # to store intermediate results @@ -142,6 +158,35 @@ def visit_While(self, node: ast.While, bb: BB, jumps: Jumps) -> BB | None: # its own jumps since the body is not guaranteed to execute return tail_bb + def visit_For(self, node: ast.For, bb: BB, jumps: Jumps) -> BB | None: + template = """ + it = make_iter + while True: + b, it = has_next + if b: + x, it = get_next + body + else: + break + end_iter # Consume iterator one last time + """ + + it = make_var(next(tmp_vars), node.iter) + b = make_var(next(tmp_vars), node.iter) + new_nodes = template_replace( + template, + node, + it=it, + b=b, + x=node.target, + make_iter=with_loc(node.iter, MakeIter(value=node.iter, origin_node=node)), + has_next=with_loc(node.iter, IterHasNext(value=it)), + get_next=with_loc(node.iter, IterNext(value=it)), + end_iter=with_loc(node.iter, IterEnd(value=it)), + body=node.body, + ) + return self.visit_stmts(new_nodes, bb, jumps) + def visit_Continue(self, node: ast.Continue, bb: BB, jumps: Jumps) -> BB | None: if not jumps.continue_bb: raise InternalGuppyError("Continue BB not defined") @@ -211,20 +256,11 @@ def build(node: ast.expr, cfg: CFG, bb: BB) -> tuple[ast.expr, BB]: builder = ExprBuilder(cfg, bb) return builder.visit(node), builder.bb - @classmethod - def _make_var(cls, name: str, loc: ast.expr | None = None) -> ast.Name: - """Creates an `ast.Name` node.""" - node = ast.Name(id=name, ctx=ast.Load) - if loc is not None: - set_location_from(node, loc) - return node - @classmethod def _tmp_assign(cls, tmp_name: str, value: ast.expr, bb: BB) -> None: """Adds a temporary variable assignment to a basic block.""" - node = ast.Assign(targets=[cls._make_var(tmp_name, value)], value=value) - set_location_from(node, value) - bb.statements.append(node) + lhs = make_var(tmp_name, value) + bb.statements.append(make_assign([lhs], value)) def visit_Name(self, node: ast.Name) -> ast.Name: return node @@ -256,7 +292,44 @@ def visit_IfExp(self, node: ast.IfExp) -> ast.Name: self.bb = merge_bb # The final value is stored in the temporary variable - return self._make_var(tmp, node) + return make_var(tmp, node) + + def visit_ListComp(self, node: ast.ListComp) -> ast.AST: + # Check for illegal expressions + illegals = find_nodes(is_illegal_in_list_comp, node) + if illegals: + raise GuppyError( + "Expression is not supported inside a list comprehension", illegals[0] + ) + + # Desugar into statements that create the iterator, check for a next element, + # get the next element, and finalise the iterator. + gens = [] + for g in node.generators: + if g.is_async: + raise GuppyError("Async generators are not supported", g) + g.iter = self.visit(g.iter) + it = make_var(next(tmp_vars), g.iter) + hasnext = make_var(next(tmp_vars), g.iter) + desugared = DesugaredGenerator( + iter=it, + hasnext=hasnext, + iter_assign=make_assign( + [it], with_loc(it, MakeIter(value=g.iter, origin_node=node)) + ), + hasnext_assign=make_assign( + [hasnext, it], with_loc(it, IterHasNext(value=it)) + ), + next_assign=make_assign( + [g.target, it], with_loc(it, IterNext(value=it)) + ), + iterend=with_loc(it, IterEnd(value=it)), + ifs=g.ifs, + ) + gens.append(desugared) + + node.elt = self.visit(node.elt) + return with_loc(node, DesugaredListComp(elt=node.elt, generators=gens)) def visit_Call(self, node: ast.Call) -> ast.AST: # Parse compile-time evaluated `py(...)` expression @@ -291,7 +364,7 @@ def generic_visit(self, node: ast.AST) -> ast.AST: self._tmp_assign(tmp, false_const, false_bb) merge_bb = self.cfg.new_bb(true_bb, false_bb) self.bb = merge_bb - return self._make_var(tmp, node) + return make_var(tmp, node) # For all other expressions, just recurse deeper with the node transformer return super().generic_visit(node) @@ -414,3 +487,28 @@ def is_short_circuit_expr(node: ast.AST) -> bool: return isinstance(node, ast.BoolOp) or ( isinstance(node, ast.Compare) and len(node.comparators) > 1 ) + + +def is_illegal_in_list_comp(node: ast.AST) -> bool: + """Checks if an expression is illegal to use in a list comprehension.""" + return isinstance(node, ast.IfExp | ast.NamedExpr) or is_short_circuit_expr(node) + + +def make_var(name: str, loc: ast.AST | None = None) -> ast.Name: + """Creates an `ast.Name` node.""" + node = ast.Name(id=name, ctx=ast.Load) + if loc is not None: + set_location_from(node, loc) + return node + + +def make_assign(lhs: list[ast.AST], value: ast.expr) -> ast.Assign: + """Creates an `ast.Assign` node.""" + assert len(lhs) > 0 + adjuster = ContextAdjuster(ast.Store()) + lhs = [adjuster.visit(expr) for expr in lhs] + if len(lhs) == 1: + target = lhs[0] + else: + target = with_loc(value, ast.Tuple(elts=lhs, ctx=ast.Store())) + return with_loc(value, ast.Assign(targets=[target], value=value)) diff --git a/guppy/checker/cfg_checker.py b/guppy/checker/cfg_checker.py index c5cf8a64..ada6fd02 100644 --- a/guppy/checker/cfg_checker.py +++ b/guppy/checker/cfg_checker.py @@ -11,7 +11,7 @@ from guppy.ast_util import line_col from guppy.cfg.bb import BB from guppy.cfg.cfg import CFG, BaseCFG -from guppy.checker.core import Context, Globals, Variable +from guppy.checker.core import Context, Globals, Locals, Variable from guppy.checker.expr_checker import ExprSynthesizer, to_bool from guppy.checker.stmt_checker import StmtChecker from guppy.error import GuppyError @@ -127,7 +127,7 @@ def check_bb( raise GuppyError(f"Variable `{x}` is not defined", use) # Check the basic block - ctx = Context(globals, {v.name: v for v in inputs}) + ctx = Context(globals, Locals({v.name: v for v in inputs})) checked_stmts = StmtChecker(ctx, bb, return_ty).check_stmts(bb.statements) # If we branch, we also have to check the branch predicate diff --git a/guppy/checker/core.py b/guppy/checker/core.py index aa495311..7a65083b 100644 --- a/guppy/checker/core.py +++ b/guppy/checker/core.py @@ -1,5 +1,8 @@ import ast +import copy +import itertools from abc import ABC, abstractmethod +from collections.abc import Iterable, Iterator from dataclasses import dataclass from typing import Any, NamedTuple @@ -8,6 +11,8 @@ BoolType, FunctionType, GuppyType, + LinstType, + ListType, NoneType, Subst, SumType, @@ -76,6 +81,8 @@ def default() -> "Globals": SumType.name: SumType, NoneType.name: NoneType, BoolType.name: BoolType, + ListType.name: ListType, + LinstType.name: LinstType, } return Globals({}, tys, {}, {}) @@ -106,8 +113,45 @@ def __ior__(self, other: "Globals") -> "Globals": # noqa: PYI034 return self -# Local variable mapping -Locals = dict[str, Variable] +@dataclass +class Locals: + """Scoped mapping from names to variables""" + + vars: dict[str, Variable] + parent_scope: "Locals | None" = None + + def __getitem__(self, item: str) -> Variable: + if item not in self.vars and self.parent_scope: + return self.parent_scope[item] + + return self.vars[item] + + def __setitem__(self, key: str, value: Variable) -> None: + self.vars[key] = value + + def __iter__(self) -> Iterator[str]: + parent_iter = iter(self.parent_scope) if self.parent_scope else iter(()) + return itertools.chain(iter(self.vars), parent_iter) + + def __contains__(self, item: str) -> bool: + return (item in self.vars) or ( + self.parent_scope is not None and item in self.parent_scope + ) + + def __copy__(self) -> "Locals": + # Make a copy of the var map so that mutating the copy doesn't + # mutate our variable mapping + return Locals(self.vars.copy(), copy.copy(self.parent_scope)) + + def keys(self) -> set[str]: + parent_keys = self.parent_scope.keys() if self.parent_scope else set() + return parent_keys | self.vars.keys() + + def items(self) -> Iterable[tuple[str, Variable]]: + parent_items = ( + iter(self.parent_scope.items()) if self.parent_scope else iter(()) + ) + return itertools.chain(self.vars.items(), parent_items) class Context(NamedTuple): diff --git a/guppy/checker/expr_checker.py b/guppy/checker/expr_checker.py index 5ad876bf..776bfe9b 100644 --- a/guppy/checker/expr_checker.py +++ b/guppy/checker/expr_checker.py @@ -26,8 +26,17 @@ from contextlib import suppress from typing import Any, NoReturn, cast -from guppy.ast_util import AstNode, AstVisitor, get_type_opt, with_loc, with_type -from guppy.checker.core import CallableVariable, Context, DummyEvalDict, Globals +from guppy.ast_util import ( + AstNode, + AstVisitor, + breaks_in_loop, + get_type_opt, + name_nodes_in_ast, + return_nodes_in_ast, + with_loc, + with_type, +) +from guppy.checker.core import CallableVariable, Context, DummyEvalDict, Globals, Locals from guppy.error import ( GuppyError, GuppyTypeError, @@ -40,11 +49,26 @@ FunctionType, GuppyType, Inst, + LinstType, + ListType, + NoneType, Subst, TupleType, unify, ) -from guppy.nodes import GlobalName, LocalCall, LocalName, PyExpr, TypeApply +from guppy.nodes import ( + DesugaredGenerator, + DesugaredListComp, + GlobalName, + IterEnd, + IterHasNext, + IterNext, + LocalCall, + LocalName, + MakeIter, + PyExpr, + TypeApply, +) # Mapping from unary AST op to dunder method and display name unary_table: dict[type[ast.unaryop], tuple[str, str]] = { @@ -120,6 +144,14 @@ def check( a new desugared expression with type annotations and a substitution with the resolved type variables. """ + # If we already have a type for the expression, we just have to match it against + # the target + if actual := get_type_opt(expr): + subst, inst = check_type_against(actual, ty, expr, kind) + if inst: + expr = with_loc(expr, TypeApply(value=expr, tys=inst)) + return with_type(ty.substitute(subst), expr), subst + # When checking against a variable, we have to synthesize if isinstance(ty, ExistentialTypeVar): expr, syn_ty = self._synthesize(expr, allow_free_vars=False) @@ -147,6 +179,27 @@ def visit_Tuple(self, node: ast.Tuple, ty: GuppyType) -> tuple[ast.expr, Subst]: subst |= s return node, subst + def visit_List(self, node: ast.List, ty: GuppyType) -> tuple[ast.expr, Subst]: + if not isinstance(ty, ListType | LinstType): + return self._fail(ty, node) + subst: Subst = {} + for i, el in enumerate(node.elts): + node.elts[i], s = self.check(el, ty.element_type.substitute(subst)) + subst |= s + return node, subst + + def visit_DesugaredListComp( + self, node: DesugaredListComp, ty: GuppyType + ) -> tuple[ast.expr, Subst]: + if not isinstance(ty, ListType | LinstType): + return self._fail(ty, node) + node, elt_ty = synthesize_comprehension(node, node.generators, self.ctx) + subst = unify(ty.element_type, elt_ty, {}) + if subst is None: + actual = LinstType(elt_ty) if elt_ty.linear else ListType(elt_ty) + return self._fail(ty, actual, node) + return node, subst + def visit_Call(self, node: ast.Call, ty: GuppyType) -> tuple[ast.expr, Subst]: if len(node.keywords) > 0: raise GuppyError( @@ -217,7 +270,7 @@ def visit_Constant(self, node: ast.Constant) -> tuple[ast.expr, GuppyType]: raise GuppyError("Unsupported constant", node) return node, ty - def visit_Name(self, node: ast.Name) -> tuple[ast.expr, GuppyType]: + def visit_Name(self, node: ast.Name) -> tuple[ast.Name, GuppyType]: x = node.id if x in self.ctx.locals: var = self.ctx.locals[x] @@ -244,6 +297,22 @@ def visit_Tuple(self, node: ast.Tuple) -> tuple[ast.expr, GuppyType]: node.elts = [n for n, _ in elems] return node, TupleType([ty for _, ty in elems]) + def visit_List(self, node: ast.List) -> tuple[ast.expr, GuppyType]: + if len(node.elts) == 0: + raise GuppyTypeInferenceError( + "Cannot infer type variable in expression of type `list[?T]`", node + ) + node.elts[0], el_ty = self.synthesize(node.elts[0]) + node.elts[1:] = [self._check(el, el_ty)[0] for el in node.elts[1:]] + return node, LinstType(el_ty) if el_ty.linear else ListType(el_ty) + + def visit_DesugaredListComp( + self, node: DesugaredListComp + ) -> tuple[ast.expr, GuppyType]: + node, elt_ty = synthesize_comprehension(node, node.generators, self.ctx) + result_ty = LinstType(elt_ty) if elt_ty.linear else ListType(elt_ty) + return node, result_ty + def visit_UnaryOp(self, node: ast.UnaryOp) -> tuple[ast.expr, GuppyType]: # We need to synthesise the argument type, so we can look up dunder methods node.operand, op_ty = self.synthesize(node.operand) @@ -293,6 +362,40 @@ def _synthesize_binary( node, ) + def _synthesize_instance_func( + self, + node: ast.expr, + args: list[ast.expr], + func_name: str, + err: str, + exp_sig: FunctionType | None = None, + give_reason: bool = False, + ) -> tuple[ast.expr, GuppyType]: + """Helper method for expressions that are implemented via instance methods. + + Raises a `GuppyTypeError` if the given instance method is not defined. The error + message can be customised by passing an `err` string and an optional error + reason can be printed. + + Optionally, the signature of the instance function can also be checked against a + given expected signature. + """ + node, ty = self.synthesize(node) + func = self.ctx.globals.get_instance_func(ty, func_name) + if func is None: + reason = f" since it does not implement the `{func_name}` method" + raise GuppyTypeError( + f"Expression of type `{ty}` is {err}{reason if give_reason else ''}", + node, + ) + if exp_sig and unify(exp_sig, func.ty.unquantified()[0], {}) is None: + raise GuppyError( + f"Method `{ty.name}.{func_name}` has signature `{func.ty}`, but " + f"expected `{exp_sig}`", + node, + ) + return func.synthesize_call([node, *args], node, self.ctx) + def visit_BinOp(self, node: ast.BinOp) -> tuple[ast.expr, GuppyType]: return self._synthesize_binary(node.left, node.right, node.op, node) @@ -305,6 +408,16 @@ def visit_Compare(self, node: ast.Compare) -> tuple[ast.expr, GuppyType]: left_expr, [op], [right_expr] = node.left, node.ops, node.comparators return self._synthesize_binary(left_expr, right_expr, op, node) + def visit_Subscript(self, node: ast.Subscript) -> tuple[ast.expr, GuppyType]: + node.value, ty = self.synthesize(node.value) + exp_sig = FunctionType( + [ty, ExistentialTypeVar.new("Key", False)], + ExistentialTypeVar.new("Val", False), + ) + return self._synthesize_instance_func( + node.value, [node.slice], "__getitem__", "not subscriptable", exp_sig + ) + def visit_Call(self, node: ast.Call) -> tuple[ast.expr, GuppyType]: if len(node.keywords) > 0: raise GuppyError("Keyword arguments are not supported", node.keywords[0]) @@ -326,6 +439,57 @@ def visit_Call(self, node: ast.Call) -> tuple[ast.expr, GuppyType]: else: raise GuppyTypeError(f"Expected function type, got `{ty}`", node.func) + def visit_MakeIter(self, node: MakeIter) -> tuple[ast.expr, GuppyType]: + node.value, ty = self.synthesize(node.value) + exp_sig = FunctionType([ty], ExistentialTypeVar.new("Iter", False)) + expr, ty = self._synthesize_instance_func( + node.value, [], "__iter__", "not iterable", exp_sig + ) + + # If the iterator was created by a `for` loop, we can add some extra checks to + # produce nicer errors for linearity violations. Namely, `break` and `return` + # are not allowed when looping over a linear iterator (`continue` is allowed) + if ty.linear and isinstance(node.origin_node, ast.For): + breaks = breaks_in_loop(node.origin_node) or return_nodes_in_ast( + node.origin_node + ) + if breaks: + raise GuppyTypeError( + f"Loop over iterator with linear type `{ty}` cannot be terminated " + f"prematurely", + breaks[0], + ) + return expr, ty + + def visit_IterHasNext(self, node: IterHasNext) -> tuple[ast.expr, GuppyType]: + node.value, ty = self.synthesize(node.value) + exp_sig = FunctionType([ty], TupleType([BoolType(), ty])) + return self._synthesize_instance_func( + node.value, [], "__hasnext__", "not an iterator", exp_sig, True + ) + + def visit_IterNext(self, node: IterNext) -> tuple[ast.expr, GuppyType]: + node.value, ty = self.synthesize(node.value) + exp_sig = FunctionType( + [ty], TupleType([ExistentialTypeVar.new("T", False), ty]) + ) + return self._synthesize_instance_func( + node.value, [], "__next__", "not an iterator", exp_sig, True + ) + + def visit_IterEnd(self, node: IterEnd) -> tuple[ast.expr, GuppyType]: + node.value, ty = self.synthesize(node.value) + exp_sig = FunctionType([ty], NoneType()) + return self._synthesize_instance_func( + node.value, [], "__end__", "not an iterator", exp_sig, True + ) + + def visit_ListComp(self, node: ast.ListComp) -> tuple[ast.expr, GuppyType]: + raise InternalGuppyError( + "BB contains `ListComp`. Should have been removed during CFG" + f"construction: `{ast.unparse(node)}`" + ) + def visit_PyExpr(self, node: PyExpr) -> tuple[ast.expr, GuppyType]: # The method we used for obtaining the Python variables in scope only works in # CPython (see `get_py_scope()`). @@ -639,6 +803,95 @@ def to_bool( return call, return_ty +def synthesize_comprehension( + node: DesugaredListComp, gens: list[DesugaredGenerator], ctx: Context +) -> tuple[DesugaredListComp, GuppyType]: + """Helper function to synthesise the element type of a list comprehension.""" + from guppy.checker.stmt_checker import StmtChecker + + def check_linear_use_from_outer_scope(expr: ast.expr, locals: Locals) -> None: + """Checks if an expression uses a linear variable from an outer scope. + + Since the expression is executed multiple times in the inner scope, this would + mean that the outer linear variable is used multiple times, which is not + allowed. + """ + for name in name_nodes_in_ast(expr): + x = name.id + if x in locals and x not in locals.vars: + var = locals[x] + if var.ty.linear: + raise GuppyTypeError( + f"Variable `{x}` with linear type `{var.ty}` would be used " + "multiple times when evaluating this comprehension", + name, + ) + + # If there are no more generators left, we can check the list element + if not gens: + node.elt, elt_ty = ExprSynthesizer(ctx).synthesize(node.elt) + check_linear_use_from_outer_scope(node.elt, ctx.locals) + return node, elt_ty + + # Check the iterator in the outer context + gen, *gens = gens + gen.iter_assign = StmtChecker(ctx).visit_Assign(gen.iter_assign) + check_linear_use_from_outer_scope(gen.iter_assign.value, ctx.locals) + + # The rest is checked in a new nested context to ensure that variables don't escape + # their scope + inner_locals = Locals({}, parent_scope=ctx.locals) + inner_ctx = Context(ctx.globals, inner_locals) + expr_sth, stmt_chk = ExprSynthesizer(inner_ctx), StmtChecker(inner_ctx) + gen.hasnext_assign = stmt_chk.visit_Assign(gen.hasnext_assign) + gen.next_assign = stmt_chk.visit_Assign(gen.next_assign) + gen.hasnext, hasnext_ty = expr_sth.visit_Name(gen.hasnext) + gen.hasnext = with_type(hasnext_ty, gen.hasnext) + gen.iter, iter_ty = expr_sth.visit_Name(gen.iter) + gen.iter = with_type(iter_ty, gen.iter) + + # `if` guards are generally not allowed when we're iterating over linear variables. + # The only exception is if all linear variables are already consumed by the first + # guard + if gen.ifs: + gen.ifs[0], _ = expr_sth.synthesize(gen.ifs[0]) + + # Now, check if there are linear iteration variables that have not been used by + # the first guard + for target in name_nodes_in_ast(gen.next_assign.targets[0]): + var = inner_ctx.locals[target.id] + if var.ty.linear and not var.used and gen.ifs: + raise GuppyTypeError( + f"Variable `{var.name}` with linear type `{var.ty}` is not used on " + "all control-flow paths of the list comprehension", + target, + ) + + # Now, we can properly check all guards + for i in range(len(gen.ifs)): + gen.ifs[i], if_ty = expr_sth.synthesize(gen.ifs[i]) + gen.ifs[i], _ = to_bool(gen.ifs[i], if_ty, inner_ctx) + check_linear_use_from_outer_scope(gen.ifs[i], inner_locals) + + # Check remaining generators + node, elt_ty = synthesize_comprehension(node, gens, inner_ctx) + + # We have to make sure that all linear variables that were introduced in this scope + # have been used + for x, var in inner_ctx.locals.vars.items(): + if var.ty.linear and not var.used: + raise GuppyTypeError( + f"Variable `{x}` with linear type `{var.ty}` is not used", + var.defined_at, + ) + + # The iter finalizer is again checked in the outer context + ctx.locals[gen.iter.id].used = None + gen.iterend, iterend_ty = ExprSynthesizer(ctx).synthesize(gen.iterend) + gen.iterend = with_type(iterend_ty, gen.iterend) + return node, elt_ty + + def python_value_to_guppy_type( v: Any, node: ast.expr, globals: Globals ) -> GuppyType | None: diff --git a/guppy/checker/func_checker.py b/guppy/checker/func_checker.py index ea7f23b6..0847d609 100644 --- a/guppy/checker/func_checker.py +++ b/guppy/checker/func_checker.py @@ -49,14 +49,14 @@ def check_call( ) -> tuple[ast.expr, Subst]: # Use default implementation from the expression checker args, subst, inst = check_call(self.ty, args, ty, node, ctx) - return GlobalCall(func=self, args=args, type_args=inst), subst + return with_loc(node, GlobalCall(func=self, args=args, type_args=inst)), subst def synthesize_call( self, args: list[ast.expr], node: AstNode, ctx: Context ) -> tuple[GlobalCall, GuppyType]: # Use default implementation from the expression checker args, ty, inst = synthesize_call(self.ty, args, node, ctx) - return GlobalCall(func=self, args=args, type_args=inst), ty + return with_loc(node, GlobalCall(func=self, args=args, type_args=inst)), ty @dataclass diff --git a/guppy/checker/stmt_checker.py b/guppy/checker/stmt_checker.py index 0fa46608..85efc2b5 100644 --- a/guppy/checker/stmt_checker.py +++ b/guppy/checker/stmt_checker.py @@ -22,11 +22,13 @@ class StmtChecker(AstVisitor[BBStatement]): ctx: Context - bb: BB - return_ty: GuppyType + bb: BB | None + return_ty: GuppyType | None - def __init__(self, ctx: Context, bb: BB, return_ty: GuppyType) -> None: - assert not return_ty.unsolved_vars + def __init__( + self, ctx: Context, bb: BB | None = None, return_ty: GuppyType | None = None + ) -> None: + assert not return_ty or not return_ty.unsolved_vars self.ctx = ctx self.bb = bb self.return_ty = return_ty @@ -55,7 +57,7 @@ def _check_assign(self, lhs: ast.expr, ty: GuppyType, node: ast.stmt) -> None: f"Variable `{x}` with linear type `{var.ty}` is not used", var.defined_at, ) - self.ctx.locals[x] = Variable(x, ty, node, None) + self.ctx.locals[x] = Variable(x, ty, lhs, None) # The only other thing we support right now are tuples case ast.Tuple(elts=elts): @@ -76,7 +78,7 @@ def _check_assign(self, lhs: ast.expr, ty: GuppyType, node: ast.stmt) -> None: case _: raise GuppyError("Assignment pattern not supported", lhs) - def visit_Assign(self, node: ast.Assign) -> ast.stmt: + def visit_Assign(self, node: ast.Assign) -> ast.Assign: if len(node.targets) > 1: # This is the case for assignments like `a = b = 1` raise GuppyError("Multi assignment not supported", node) @@ -113,6 +115,9 @@ def visit_Expr(self, node: ast.Expr) -> ast.stmt: return node def visit_Return(self, node: ast.Return) -> ast.stmt: + if not self.return_ty: + raise InternalGuppyError("return_ty required to check return stmt!") + if node.value is not None: node.value, subst = self._check_expr( node.value, self.return_ty, "return value" @@ -127,6 +132,9 @@ def visit_Return(self, node: ast.Return) -> ast.stmt: def visit_NestedFunctionDef(self, node: NestedFunctionDef) -> ast.stmt: from guppy.checker.func_checker import check_nested_func_def + if not self.bb: + raise InternalGuppyError("BB required to check nested function def!") + func_def = check_nested_func_def(node, self.bb, self.ctx) self.ctx.locals[func_def.name] = Variable( func_def.name, func_def.ty, func_def, None diff --git a/guppy/compiler/cfg_compiler.py b/guppy/compiler/cfg_compiler.py index d2f1abb9..0b9c9a9a 100644 --- a/guppy/compiler/cfg_compiler.py +++ b/guppy/compiler/cfg_compiler.py @@ -56,7 +56,7 @@ def compile_bb( for (i, v) in enumerate(inputs) }, ) - dfg = StmtCompiler(graph, globals).compile_stmts(bb.statements, bb, dfg) + dfg = StmtCompiler(graph, globals).compile_stmts(bb.statements, dfg) # If we branch, we also have to compile the branch predicate if len(bb.successors) > 1: diff --git a/guppy/compiler/expr_compiler.py b/guppy/compiler/expr_compiler.py index cf01fc07..1d71556e 100644 --- a/guppy/compiler/expr_compiler.py +++ b/guppy/compiler/expr_compiler.py @@ -1,8 +1,16 @@ import ast +from collections.abc import Iterator +from contextlib import contextmanager from typing import Any -from guppy.ast_util import AstVisitor, get_type -from guppy.compiler.core import CompiledFunction, CompilerBase, DFContainer +from guppy.ast_util import AstVisitor, get_type, with_loc, with_type +from guppy.cfg.builder import tmp_vars +from guppy.compiler.core import ( + CompiledFunction, + CompilerBase, + DFContainer, + PortVariable, +) from guppy.error import GuppyError, InternalGuppyError from guppy.gtypes import ( BoolType, @@ -14,8 +22,16 @@ type_to_row, ) from guppy.hugr import ops, val -from guppy.hugr.hugr import OutPortV -from guppy.nodes import GlobalCall, GlobalName, LocalCall, LocalName, TypeApply +from guppy.hugr.hugr import DFContainingNode, OutPortV, VNode +from guppy.nodes import ( + DesugaredGenerator, + DesugaredListComp, + GlobalCall, + GlobalName, + LocalCall, + LocalName, + TypeApply, +) class ExprCompiler(CompilerBase, AstVisitor[OutPortV]): @@ -39,6 +55,85 @@ def compile_row(self, expr: ast.expr, dfg: DFContainer) -> list[OutPortV]: """ return [self.compile(e, dfg) for e in expr_to_row(expr)] + @contextmanager + def _new_dfcontainer( + self, inputs: list[ast.Name], node: DFContainingNode + ) -> Iterator[None]: + """Context manager to build a graph inside a new `DFContainer`. + + Automatically updates `self.dfg` and makes the inputs available. + """ + old = self.dfg + inp = self.graph.add_input(parent=node) + # Check that the input names are unique + assert len({inp.id for inp in inputs}) == len(inputs), "Inputs are not unique" + new_locals = { + name.id: PortVariable(name.id, inp.add_out_port(get_type(name)), name, None) + for name in inputs + } + self.dfg = DFContainer(node, self.dfg.locals | new_locals) + with self.graph.parent(node): + yield + self.dfg = old + + @contextmanager + def _new_loop( + self, + loop_vars: list[ast.Name], + branch: ast.Name, + parent: DFContainingNode | None = None, + ) -> Iterator[None]: + """Context manager to build a graph inside a new `TailLoop` node. + + Automatically adds the `Output` node to the loop body once the context manager + exits. + """ + loop = self.graph.add_tail_loop( + [self.visit(name) for name in loop_vars], parent + ) + with self._new_dfcontainer(loop_vars, loop): + yield + # Output the branch predicate and the inputs for the next iteration + self.graph.add_output( + # Note that we have to do fresh calls to `self.visit` here since we're + # in a new context + [self.visit(branch), *(self.visit(name) for name in loop_vars)] + ) + # Update the DFG with the outputs from the loop + for name in loop_vars: + self.dfg[name.id].port = loop.add_out_port(get_type(name)) + + @contextmanager + def _new_case( + self, inputs: list[ast.Name], outputs: list[ast.Name], cond_node: VNode + ) -> Iterator[None]: + """Context manager to build a graph inside a new `Case` node. + + Automatically adds the `Output` node once the context manager exits. + """ + with self._new_dfcontainer(inputs, self.graph.add_case(cond_node)): + yield + self.graph.add_output([self.visit(name) for name in outputs]) + + @contextmanager + def _if_true(self, cond: ast.expr, inputs: list[ast.Name]) -> Iterator[None]: + """Context manager to build a graph inside the `true` case of a `Conditional` + + In the `false` case, the inputs are outputted as is. + """ + cond_node = self.graph.add_conditional( + self.visit(cond), [self.visit(inp) for inp in inputs] + ) + # If the condition is false, output the inputs as is + with self._new_case(inputs, inputs, cond_node): + pass + # If the condition is true, we enter the `with` block + with self._new_case(inputs, inputs, cond_node): + yield + # Update the DFG with the outputs from the Conditional node + for name in inputs: + self.dfg[name.id].port = cond_node.add_out_port(get_type(name)) + def visit_Constant(self, node: ast.Constant) -> OutPortV: if value := python_value_to_hugr(node.value): const = self.graph.add_constant(value, get_type(node)).out_port(0) @@ -59,6 +154,12 @@ def visit_Tuple(self, node: ast.Tuple) -> OutPortV: inputs=[self.visit(e) for e in node.elts] ).out_port(0) + def visit_List(self, node: ast.List) -> OutPortV: + # Note that this is a list literal (i.e. `[e1, e2, ...]`), not a comprehension + return self.graph.add_node( + ops.DummyOp(name="MakeList"), inputs=[self.visit(e) for e in node.elts] + ).add_out_port(get_type(node)) + def _pack_returns(self, returns: list[OutPortV]) -> OutPortV: """Groups function return values into a tuple""" if len(returns) != 1: @@ -118,6 +219,61 @@ def visit_UnaryOp(self, node: ast.UnaryOp) -> OutPortV: raise InternalGuppyError("Node should have been removed during type checking.") + def visit_DesugaredListComp(self, node: DesugaredListComp) -> OutPortV: + from guppy.compiler.stmt_compiler import StmtCompiler + + compiler = StmtCompiler(self.graph, self.globals) + + # Make up a name for the list under construction and bind it to an empty list + list_ty = get_type(node) + list_name = with_type(list_ty, with_loc(node, LocalName(id=next(tmp_vars)))) + empty_list = self.graph.add_node(ops.DummyOp(name="MakeList")) + self.dfg[list_name.id] = PortVariable( + list_name.id, empty_list.add_out_port(list_ty), node, None + ) + + def compile_generators(elt: ast.expr, gens: list[DesugaredGenerator]) -> None: + """Helper function to generate nested TailLoop nodes for generators""" + # If there are no more generators left, just append the element to the list + if not gens: + list_port, elt_port = self.visit(list_name), self.visit(elt) + push = self.graph.add_node( + ops.DummyOp(name="Push"), inputs=[list_port, elt_port] + ) + self.dfg[list_name.id].port = push.add_out_port(list_port.ty) + return + + # Otherwise, compile the first iterator and construct a TailLoop + gen, *gens = gens + compiler.compile_stmts([gen.iter_assign], self.dfg) + inputs = [gen.iter, list_name] + with self._new_loop(inputs, gen.hasnext): + # If there is a next element, compile it and continue with the next + # generator + compiler.compile_stmts([gen.hasnext_assign], self.dfg) + with self._if_true(gen.hasnext, inputs): + + def compile_ifs(ifs: list[ast.expr]) -> None: + """Helper function to compile a series of if-guards into nested + Conditional nodes.""" + if ifs: + if_expr, *ifs = ifs + # If the condition is true, continue with the next one + with self._if_true(if_expr, inputs): + compile_ifs(ifs) + else: + # If there are no guards left, compile the next generator + compile_generators(elt, gens) + + compiler.compile_stmts([gen.next_assign], self.dfg) + compile_ifs(gen.ifs) + + # After the loop is done, we have to finalize the iterator + self.visit(gen.iterend) + + compile_generators(node.elt, node.generators) + return self.visit(list_name) + def visit_BinOp(self, node: ast.BinOp) -> OutPortV: raise InternalGuppyError("Node should have been removed during type checking.") diff --git a/guppy/compiler/stmt_compiler.py b/guppy/compiler/stmt_compiler.py index 18147b22..a2ab4f6a 100644 --- a/guppy/compiler/stmt_compiler.py +++ b/guppy/compiler/stmt_compiler.py @@ -2,7 +2,6 @@ from collections.abc import Sequence from guppy.ast_util import AstVisitor -from guppy.checker.cfg_checker import CheckedBB from guppy.compiler.core import ( CompiledGlobals, CompilerBase, @@ -22,7 +21,6 @@ class StmtCompiler(CompilerBase, AstVisitor[None]): expr_compiler: ExprCompiler - bb: CheckedBB dfg: DFContainer def __init__(self, graph: Hugr, globals: CompiledGlobals): @@ -32,7 +30,6 @@ def __init__(self, graph: Hugr, globals: CompiledGlobals): def compile_stmts( self, stmts: Sequence[ast.stmt], - bb: CheckedBB, dfg: DFContainer, ) -> DFContainer: """Compiles a list of basic statements into a dataflow node. @@ -40,7 +37,6 @@ def compile_stmts( Note that the `dfg` is mutated in-place. After compilation, the DFG will also contain all variables that are assigned in the given list of statements. """ - self.bb = bb self.dfg = dfg for s in stmts: self.visit(s) diff --git a/guppy/declared.py b/guppy/declared.py index b868a374..df1c0f29 100644 --- a/guppy/declared.py +++ b/guppy/declared.py @@ -1,7 +1,7 @@ import ast from dataclasses import dataclass -from guppy.ast_util import AstNode, has_empty_body +from guppy.ast_util import AstNode, has_empty_body, with_loc from guppy.checker.core import Context, Globals from guppy.checker.expr_checker import check_call, synthesize_call from guppy.checker.func_checker import check_signature @@ -34,14 +34,14 @@ def check_call( ) -> tuple[ast.expr, Subst]: # Use default implementation from the expression checker args, subst, inst = check_call(self.ty, args, ty, node, ctx) - return GlobalCall(func=self, args=args, type_args=inst), subst + return with_loc(node, GlobalCall(func=self, args=args, type_args=inst)), subst def synthesize_call( self, args: list[ast.expr], node: AstNode, ctx: Context ) -> tuple[GlobalCall, GuppyType]: # Use default implementation from the expression checker args, ty, inst = synthesize_call(self.ty, args, node, ctx) - return GlobalCall(func=self, args=args, type_args=inst), ty + return with_loc(node, GlobalCall(func=self, args=args, type_args=inst)), ty def add_to_graph(self, graph: Hugr, parent: Node) -> None: self.node = graph.add_declare(self.ty, parent, self.name) diff --git a/guppy/decorator.py b/guppy/decorator.py index d4518f7b..d7e728e7 100644 --- a/guppy/decorator.py +++ b/guppy/decorator.py @@ -80,6 +80,7 @@ def type( hugr_ty: tys.Type, name: str = "", linear: bool = False, + bound: tys.TypeBound | None = None, ) -> ClassDecorator: """Decorator to annotate a class definitions as Guppy types. @@ -117,6 +118,9 @@ def linear(self) -> bool: def to_hugr(self) -> tys.Type: return hugr_ty + def hugr_bound(self) -> tys.TypeBound: + return bound or super().hugr_bound() + def transform(self, transformer: TypeTransformer) -> GuppyType: return transformer.transform(self) or NewType( [ty.transform(transformer) for ty in self.args] diff --git a/guppy/gtypes.py b/guppy/gtypes.py index 9ac0b432..66290503 100644 --- a/guppy/gtypes.py +++ b/guppy/gtypes.py @@ -74,6 +74,11 @@ def to_hugr(self) -> tys.Type: def transform(self, transformer: "TypeTransformer") -> "GuppyType": pass + def hugr_bound(self) -> tys.TypeBound: + if self.linear: + return tys.TypeBound.Any + return tys.TypeBound.join(*(ty.hugr_bound() for ty in self.type_args)) + @property def unsolved_vars(self) -> set["ExistentialTypeVar"]: return self._unsolved_vars @@ -100,6 +105,11 @@ def build(*rgs: GuppyType, node: AstNode | None = None) -> GuppyType: def type_args(self) -> Iterator["GuppyType"]: return iter(()) + def hugr_bound(self) -> tys.TypeBound: + # We shouldn't make variables equatable, since we also want to substitute types + # like `float` + return tys.TypeBound.Any if self.linear else tys.TypeBound.Copyable + def transform(self, transformer: "TypeTransformer") -> GuppyType: return transformer.transform(self) or self @@ -107,7 +117,7 @@ def __str__(self) -> str: return self.display_name def to_hugr(self) -> tys.Type: - return tys.Variable(i=self.idx, b=tys.TypeBound.from_linear(self.linear)) + return tys.Variable(i=self.idx, b=self.hugr_bound()) @dataclass(frozen=True) @@ -194,13 +204,14 @@ def to_hugr(self) -> tys.PolyFuncType: outs = [t.to_hugr() for t in type_to_row(self.returns)] func_ty = tys.FunctionType(input=ins, output=outs, extension_reqs=[]) return tys.PolyFuncType( - params=[ - tys.TypeTypeParam(b=tys.TypeBound.from_linear(v.linear)) - for v in self.quantified - ], + params=[tys.TypeTypeParam(b=v.hugr_bound()) for v in self.quantified], body=func_ty, ) + def hugr_bound(self) -> tys.TypeBound: + # Functions are not equatable, only copyable + return tys.TypeBound.Copyable + def transform(self, transformer: "TypeTransformer") -> GuppyType: return transformer.transform(self) or FunctionType( [ty.transform(transformer) for ty in self.args], @@ -314,6 +325,90 @@ def transform(self, transformer: "TypeTransformer") -> GuppyType: ) +@dataclass(frozen=True) +class ListType(GuppyType): + element_type: GuppyType + + name: ClassVar[Literal["list"]] = "list" + linear: bool = field(default=False, init=False) + + @staticmethod + def build(*args: GuppyType, node: AstNode | None = None) -> GuppyType: + from guppy.error import GuppyError + + if len(args) == 0: + raise GuppyError("Missing type parameter for generic type `list`", node) + if len(args) > 1: + raise GuppyError("Too many type arguments for generic type `list`", node) + (arg,) = args + if arg.linear: + raise GuppyError( + "Type `list` cannot store linear data, use `linst` instead", node + ) + return ListType(arg) + + def __str__(self) -> str: + return f"list[{self.element_type}]" + + @property + def type_args(self) -> Iterator[GuppyType]: + return iter((self.element_type,)) + + def to_hugr(self) -> tys.Type: + return tys.Opaque( + extension="Collections", + id="List", + args=[tys.TypeTypeArg(ty=self.element_type.to_hugr())], + bound=self.hugr_bound(), + ) + + def transform(self, transformer: "TypeTransformer") -> GuppyType: + return transformer.transform(self) or ListType( + self.element_type.transform(transformer) + ) + + +@dataclass(frozen=True) +class LinstType(GuppyType): + element_type: GuppyType + + name: ClassVar[Literal["linst"]] = "linst" + + @staticmethod + def build(*args: GuppyType, node: AstNode | None = None) -> GuppyType: + from guppy.error import GuppyError + + if len(args) == 0: + raise GuppyError("Missing type parameter for generic type `linst`", node) + if len(args) > 1: + raise GuppyError("Too many type arguments for generic type `linst`", node) + return LinstType(args[0]) + + def __str__(self) -> str: + return f"linst[{self.element_type}]" + + @property + def linear(self) -> bool: + return self.element_type.linear + + @property + def type_args(self) -> Iterator[GuppyType]: + return iter((self.element_type,)) + + def to_hugr(self) -> tys.Type: + return tys.Opaque( + extension="Collections", + id="List", + args=[tys.TypeTypeArg(ty=self.element_type.to_hugr())], + bound=self.hugr_bound(), + ) + + def transform(self, transformer: "TypeTransformer") -> GuppyType: + return transformer.transform(self) or LinstType( + self.element_type.transform(transformer) + ) + + @dataclass(frozen=True) class NoneType(GuppyType): name: ClassVar[Literal["None"]] = "None" @@ -484,8 +579,11 @@ def type_from_ast( return NoneType() if isinstance(v, str): try: - return type_from_ast(ast.parse(v), globals, type_var_mapping) - except SyntaxError: + [stmt] = ast.parse(v).body + if not isinstance(stmt, ast.Expr): + raise GuppyError("Invalid Guppy type", node) + return type_from_ast(stmt.value, globals, type_var_mapping) + except (SyntaxError, ValueError): raise GuppyError("Invalid Guppy type", node) from None raise GuppyError(f"Constant `{v}` is not a valid type", node) diff --git a/guppy/hugr/hugr.py b/guppy/hugr/hugr.py index 8bc43b71..ab0e2b1a 100644 --- a/guppy/hugr/hugr.py +++ b/guppy/hugr/hugr.py @@ -386,6 +386,7 @@ def add_output( parent: Node | None = None, ) -> VNode: """Adds an `Output` node to the graph.""" + parent = parent or self._default_parent node = self.add_node(ops.Output(), input_tys, [], parent, inputs) if isinstance(parent, DFContainingNode): parent.output_child = node @@ -706,6 +707,25 @@ def insert_order_edges(self) -> "Hugr": elif isinstance(n.op, ops.LoadConstant): assert n.parent.input_child is not None self.add_order_edge(n.parent.input_child, n) + + # Also add order edges for non-local edges + for src, tgt in list(self.edges()): + # Exclude CF and constant edges + if isinstance(src, OutPortCF) or isinstance( + src.node.op, ops.FuncDecl | ops.FuncDefn | ops.Const + ): + continue + + if src.node.parent != tgt.node.parent: + # Walk up the hierarchy from the tgt until we hit a node at the same + # level as src + node = tgt.node + while node.parent != src.node.parent: + if node.parent is None: + raise ValueError("Invalid non-local edge!") + node = node.parent + # Add order edge to make sure that the src is executed first + self.add_order_edge(src.node, node) return self def to_raw(self) -> raw.RawHugr: diff --git a/guppy/hugr/tys.py b/guppy/hugr/tys.py index b598febf..98874cbd 100644 --- a/guppy/hugr/tys.py +++ b/guppy/hugr/tys.py @@ -228,8 +228,15 @@ class TypeBound(Enum): Any = "A" @staticmethod - def from_linear(linear: bool) -> "TypeBound": - return TypeBound.Any if linear else TypeBound.Copyable + def join(*bs: "TypeBound") -> "TypeBound": + """Computes the least upper bound for a sequence of bounds.""" + res = TypeBound.Eq + for b in bs: + if b == TypeBound.Any: + return TypeBound.Any + if res == TypeBound.Eq: + res = b + return res class Opaque(BaseModel): diff --git a/guppy/module.py b/guppy/module.py index bd41d5d2..65c9a81d 100644 --- a/guppy/module.py +++ b/guppy/module.py @@ -18,6 +18,7 @@ from guppy.hugr.hugr import Hugr PyFunc = Callable[..., Any] +PyFuncDefOrDecl = tuple[bool, PyFunc] class GuppyModule: @@ -45,7 +46,7 @@ class GuppyModule: # When `_instance_buffer` is not `None`, then all registered functions will be # buffered in this list. They only get properly registered, once # `_register_buffered_instance_funcs` is called. This way, we can associate - _instance_func_buffer: dict[str, PyFunc | CustomFunction] | None + _instance_func_buffer: dict[str, PyFuncDefOrDecl | CustomFunction] | None def __init__(self, name: str, import_builtins: bool = True): self.name = name @@ -95,7 +96,7 @@ def register_func_def( self._check_not_yet_compiled() func_ast = parse_py_func(f) if self._instance_func_buffer is not None: - self._instance_func_buffer[func_ast.name] = f + self._instance_func_buffer[func_ast.name] = (True, f) else: name = ( qualified_name(instance, func_ast.name) if instance else func_ast.name @@ -103,12 +104,20 @@ def register_func_def( self._check_name_available(name, func_ast) self._func_defs[name] = func_ast, get_py_scope(f) - def register_func_decl(self, f: PyFunc) -> None: + def register_func_decl( + self, f: PyFunc, instance: type[GuppyType] | None = None + ) -> None: """Registers a Python function declaration as belonging to this Guppy module.""" self._check_not_yet_compiled() func_ast = parse_py_func(f) - self._check_name_available(func_ast.name, func_ast) - self._func_decls[func_ast.name] = func_ast + if self._instance_func_buffer is not None: + self._instance_func_buffer[func_ast.name] = (False, f) + else: + name = ( + qualified_name(instance, func_ast.name) if instance else func_ast.name + ) + self._check_name_available(name, func_ast) + self._func_decls[name] = func_ast def register_custom_func( self, func: CustomFunction, instance: type[GuppyType] | None = None @@ -143,7 +152,11 @@ def _register_buffered_instance_funcs(self, instance: type[GuppyType]) -> None: if isinstance(f, CustomFunction): self.register_custom_func(f, instance) else: - self.register_func_def(f, instance) + is_def, pyfunc = f + if is_def: + self.register_func_def(pyfunc, instance) + else: + self.register_func_decl(pyfunc, instance) @property def compiled(self) -> bool: diff --git a/guppy/nodes.py b/guppy/nodes.py index ef50d84c..2fd72a8a 100644 --- a/guppy/nodes.py +++ b/guppy/nodes.py @@ -12,13 +12,13 @@ from guppy.checker.core import CallableVariable, Variable -class LocalName(ast.expr): +class LocalName(ast.Name): id: str _fields = ("id",) -class GlobalName(ast.expr): +class GlobalName(ast.Name): id: str value: "Variable" @@ -60,6 +60,93 @@ class TypeApply(ast.expr): ) +class MakeIter(ast.expr): + """Creates an iterator using the `__iter__` magic method. + + This node is inserted in `for` loops and list comprehensions. + """ + + value: ast.expr + + # Node that triggered the creation of this iterator. For example, a for loop stmt. + # It is not mentioned in `_fields` so that it is not visible to AST visitors + origin_node: ast.AST + + _fields = ("value",) + + +class IterHasNext(ast.expr): + """Checks if an iterator has a next element using the `__hasnext__` magic method. + + This node is inserted in `for` loops and list comprehensions. + """ + + value: ast.expr + + _fields = ("value",) + + +class IterNext(ast.expr): + """Obtains the next element of an iterator using the `__next__` magic method. + + This node is inserted in `for` loops and list comprehensions. + """ + + value: ast.expr + + _fields = ("value",) + + +class IterEnd(ast.expr): + """Finalises an iterator using the `__end__` magic method. + + This node is inserted in `for` loops and list comprehensions. It is needed to + consume linear iterators once they are finished. + """ + + value: ast.expr + + _fields = ("value",) + + +class DesugaredGenerator(ast.expr): + """A single desugared generator in a list comprehension. + + Stores assignments of the original generator targets as well as dummy variables for + the iterator and hasnext test. + """ + + iter_assign: ast.Assign + hasnext_assign: ast.Assign + next_assign: ast.Assign + iterend: ast.expr + iter: ast.Name + hasnext: ast.Name + ifs: list[ast.expr] + + _fields = ( + "iter_assign", + "hasnext_assign", + "next_assign", + "iterend", + "iter", + "hasnext", + "ifs", + ) + + +class DesugaredListComp(ast.expr): + """A desugared list comprehension.""" + + elt: ast.expr + generators: list[DesugaredGenerator] + + _fields = ( + "elt", + "generators", + ) + + class PyExpr(ast.expr): """A compile-time evaluated `py(...)` expression.""" diff --git a/guppy/prelude/_internal.py b/guppy/prelude/_internal.py index d9ee0b84..196b7663 100644 --- a/guppy/prelude/_internal.py +++ b/guppy/prelude/_internal.py @@ -130,6 +130,22 @@ def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, GuppyType]: return expr, ty +class FailingChecker(CustomCallChecker): + """Call checker for Python functions that are not available in Guppy. + + Gives the uses a nicer error message when they try to use an unsupported feature. + """ + + def __init__(self, msg: str) -> None: + self.msg = msg + + def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, GuppyType]: + raise GuppyError(self.msg, self.node) + + def check(self, args: list[ast.expr], ty: GuppyType) -> tuple[ast.expr, Subst]: + raise GuppyError(self.msg, self.node) + + class UnsupportedChecker(CustomCallChecker): """Call checker for Python builtin functions that are not available in Guppy. diff --git a/guppy/prelude/builtins.py b/guppy/prelude/builtins.py index 57c22920..563aeb2f 100644 --- a/guppy/prelude/builtins.py +++ b/guppy/prelude/builtins.py @@ -1,16 +1,17 @@ """Guppy module for builtin types and operations.""" -# mypy: disable-error-code="empty-body, misc, override, no-untyped-def" +# mypy: disable-error-code="empty-body, misc, override, valid-type, no-untyped-def" from guppy.custom import DefaultCallChecker, NoopCompiler from guppy.decorator import guppy -from guppy.gtypes import BoolType +from guppy.gtypes import BoolType, LinstType, ListType from guppy.hugr import ops, tys from guppy.module import GuppyModule from guppy.prelude._internal import ( CallableChecker, CoercingChecker, DunderChecker, + FailingChecker, FloatBoolCompiler, FloatDivmodCompiler, FloatFloordivCompiler, @@ -27,6 +28,9 @@ builtins = GuppyModule("builtins", import_builtins=False) +T = guppy.type_var(builtins, "T") +L = guppy.type_var(builtins, "L", linear=True) + @guppy.extend_type(builtins, BoolType) class Bool: @@ -226,7 +230,7 @@ def __xor__(self: int, other: int) -> int: ... -@guppy.type(builtins, hugr_float_type, name="float") +@guppy.type(builtins, hugr_float_type, name="float", bound=tys.TypeBound.Copyable) class Float: @guppy.hugr_op(builtins, float_op("fabs"), CoercingChecker()) def __abs__(self: float) -> float: @@ -365,6 +369,155 @@ def __trunc__(self: float) -> float: ... +@guppy.extend_type(builtins, ListType) +class List: + @guppy.hugr_op(builtins, ops.DummyOp(name="Concat")) + def __add__(self: list[T], other: list[T]) -> list[T]: + ... + + @guppy.hugr_op(builtins, ops.DummyOp(name="IsEmpty")) + def __bool__(self: list[T]) -> bool: + ... + + @guppy.hugr_op(builtins, ops.DummyOp(name="Contains")) + def __contains__(self: list[T], el: T) -> bool: + ... + + @guppy.hugr_op(builtins, ops.DummyOp(name="AssertEmpty")) + def __end__(self: list[T]) -> None: + ... + + @guppy.hugr_op(builtins, ops.DummyOp(name="Lookup")) + def __getitem__(self: list[T], idx: int) -> T: + ... + + @guppy.hugr_op(builtins, ops.DummyOp(name="IsNonEmpty")) + def __hasnext__(self: list[T]) -> tuple[bool, list[T]]: + ... + + @guppy.custom(builtins, NoopCompiler()) + def __iter__(self: list[T]) -> list[T]: + ... + + @guppy.hugr_op(builtins, ops.DummyOp(name="Length")) + def __len__(self: list[T]) -> int: + ... + + @guppy.hugr_op(builtins, ops.DummyOp(name="Repeat")) + def __mul__(self: list[T], other: int) -> list[T]: + ... + + @guppy.hugr_op(builtins, ops.DummyOp(name="Pop")) + def __next__(self: list[T]) -> tuple[T, list[T]]: + ... + + @guppy.custom(builtins, checker=FailingChecker("Guppy lists are immutable")) + def __setitem__(self: list[T], idx: int, value: T) -> None: + ... + + @guppy.hugr_op(builtins, ops.DummyOp(name="Append"), ReversingChecker()) + def __radd__(self: list[T], other: list[T]) -> list[T]: + ... + + @guppy.hugr_op(builtins, ops.DummyOp(name="Repeat"), ReversingChecker()) + def __rmul__(self: list[T], other: int) -> list[T]: + ... + + @guppy.custom(builtins, checker=FailingChecker("Guppy lists are immutable")) + def append(self: list[T], elt: T) -> None: + ... + + @guppy.custom(builtins, checker=FailingChecker("Guppy lists are immutable")) + def clear(self: list[T]) -> None: + ... + + @guppy.custom(builtins, NoopCompiler()) # Can be noop since lists are immutable + def copy(self: list[T]) -> list[T]: + ... + + @guppy.hugr_op(builtins, ops.DummyOp(name="Count")) + def count(self: list[T], elt: T) -> int: + ... + + @guppy.custom(builtins, checker=FailingChecker("Guppy lists are immutable")) + def extend(self: list[T], seq: None) -> None: + ... + + @guppy.hugr_op(builtins, ops.DummyOp(name="Find")) + def index(self: list[T], elt: T) -> int: + ... + + @guppy.custom(builtins, checker=FailingChecker("Guppy lists are immutable")) + def pop(self: list[T], idx: int) -> None: + ... + + @guppy.custom(builtins, checker=FailingChecker("Guppy lists are immutable")) + def remove(self: list[T], elt: T) -> None: + ... + + @guppy.custom(builtins, checker=FailingChecker("Guppy lists are immutable")) + def reverse(self: list[T]) -> None: + ... + + @guppy.custom(builtins, checker=FailingChecker("Guppy lists are immutable")) + def sort(self: list[T]) -> None: + ... + + +linst = list + + +@guppy.extend_type(builtins, LinstType) +class Linst: + @guppy.hugr_op(builtins, ops.DummyOp(name="Append")) + def __add__(self: linst[L], other: linst[L]) -> linst[L]: + ... + + @guppy.hugr_op(builtins, ops.DummyOp(name="AssertEmpty")) + def __end__(self: linst[L]) -> None: + ... + + @guppy.hugr_op(builtins, ops.DummyOp(name="IsNonempty")) + def __hasnext__(self: linst[L]) -> tuple[bool, linst[L]]: + ... + + @guppy.custom(builtins, NoopCompiler()) + def __iter__(self: linst[L]) -> linst[L]: + ... + + @guppy.hugr_op(builtins, ops.DummyOp(name="Length")) + def __len__(self: linst[L]) -> tuple[int, linst[L]]: + ... + + @guppy.hugr_op(builtins, ops.DummyOp(name="Pop")) + def __next__(self: linst[L]) -> tuple[L, linst[L]]: + ... + + @guppy.hugr_op(builtins, ops.DummyOp(name="Append"), ReversingChecker()) + def __radd__(self: linst[L], other: linst[L]) -> linst[L]: + ... + + @guppy.hugr_op(builtins, ops.DummyOp(name="Repeat"), ReversingChecker()) + def __rmul__(self: linst[L], other: int) -> linst[L]: + ... + + @guppy.hugr_op(builtins, ops.DummyOp(name="Push")) + def append(self: linst[L], elt: L) -> linst[L]: + ... + + @guppy.hugr_op(builtins, ops.DummyOp(name="PopAt")) + def pop(self: linst[L], idx: int) -> tuple[L, linst[L]]: + ... + + @guppy.hugr_op(builtins, ops.DummyOp(name="Reverse")) + def reverse(self: linst[L]) -> linst[L]: + ... + + @guppy.custom(builtins, checker=FailingChecker("Guppy lists are immutable")) + def sort(self: list[T]) -> None: + ... + + @guppy.custom(builtins, checker=DunderChecker("__abs__"), higher_order_value=False) def abs(x): ... @@ -403,6 +556,11 @@ def _int(x): ... +@guppy.custom(builtins, checker=DunderChecker("__len__"), higher_order_value=False) +def len(x): + ... + + @guppy.custom( builtins, checker=DunderChecker("__pow__", num_args=2), higher_order_value=False ) @@ -578,13 +736,10 @@ def iter(x): ... -@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) -def len(x): - ... - - -@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) -def list(x): +@guppy.custom( + builtins, checker=UnsupportedChecker(), name="list", higher_order_value=False +) +def _list(x): ... diff --git a/ruff.toml b/ruff.toml index 8c9e618a..cfe3fd0a 100644 --- a/ruff.toml +++ b/ruff.toml @@ -73,7 +73,7 @@ ignore = [ [per-file-ignores] "guppy/ast_util.py" = ["B009", "B010"] "guppy/decorator.py" = ["B010"] -"tests/integration/*" = ["F841"] +"tests/integration/*" = ["F841", "C416", "RUF005"] "tests/{hugr,integration}/*" = ["B", "FBT", "SIM", "I"] # [pydocstyle] diff --git a/tests/error/comprehension_errors/__init__.py b/tests/error/comprehension_errors/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/error/comprehension_errors/capture1.err b/tests/error/comprehension_errors/capture1.err new file mode 100644 index 00000000..078f9eaf --- /dev/null +++ b/tests/error/comprehension_errors/capture1.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:13 + +11: @guppy(module) +12: def foo(xs: list[int], q: Qubit) -> linst[Qubit]: +13: return [q for x in xs] + ^ +GuppyTypeError: Variable `q` with linear type `Qubit` would be used multiple times when evaluating this comprehension diff --git a/tests/error/comprehension_errors/capture1.py b/tests/error/comprehension_errors/capture1.py new file mode 100644 index 00000000..9af5781a --- /dev/null +++ b/tests/error/comprehension_errors/capture1.py @@ -0,0 +1,16 @@ +import guppy.prelude.quantum as quantum +from guppy.decorator import guppy +from guppy.module import GuppyModule +from guppy.hugr.tys import Qubit +from guppy.prelude.builtins import linst + +module = GuppyModule("test") +module.load(quantum) + + +@guppy(module) +def foo(xs: list[int], q: Qubit) -> linst[Qubit]: + return [q for x in xs] + + +module.compile() diff --git a/tests/error/comprehension_errors/capture2.err b/tests/error/comprehension_errors/capture2.err new file mode 100644 index 00000000..5e16e12a --- /dev/null +++ b/tests/error/comprehension_errors/capture2.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:18 + +16: @guppy(module) +17: def foo(xs: list[int], q: Qubit) -> list[int]: +18: return [x for x in xs if bar(q)] + ^ +GuppyTypeError: Variable `q` with linear type `Qubit` would be used multiple times when evaluating this comprehension diff --git a/tests/error/comprehension_errors/capture2.py b/tests/error/comprehension_errors/capture2.py new file mode 100644 index 00000000..27970631 --- /dev/null +++ b/tests/error/comprehension_errors/capture2.py @@ -0,0 +1,21 @@ +import guppy.prelude.quantum as quantum +from guppy.decorator import guppy +from guppy.module import GuppyModule +from guppy.hugr.tys import Qubit +from guppy.prelude.builtins import linst + +module = GuppyModule("test") +module.load(quantum) + + +@guppy.declare(module) +def bar(q: Qubit) -> bool: + ... + + +@guppy(module) +def foo(xs: list[int], q: Qubit) -> list[int]: + return [x for x in xs if bar(q)] + + +module.compile() diff --git a/tests/error/comprehension_errors/guarded1.err b/tests/error/comprehension_errors/guarded1.err new file mode 100644 index 00000000..02498bc7 --- /dev/null +++ b/tests/error/comprehension_errors/guarded1.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:13 + +11: @guppy(module) +12: def foo(qs: linst[tuple[bool, Qubit]]) -> linst[Qubit]: +13: return [q for b, q in qs if b] + ^ +GuppyTypeError: Variable `q` with linear type `Qubit` is not used on all control-flow paths of the list comprehension diff --git a/tests/error/comprehension_errors/guarded1.py b/tests/error/comprehension_errors/guarded1.py new file mode 100644 index 00000000..ecb264a9 --- /dev/null +++ b/tests/error/comprehension_errors/guarded1.py @@ -0,0 +1,16 @@ +import guppy.prelude.quantum as quantum +from guppy.decorator import guppy +from guppy.module import GuppyModule +from guppy.hugr.tys import Qubit +from guppy.prelude.builtins import linst + +module = GuppyModule("test") +module.load(quantum) + + +@guppy(module) +def foo(qs: linst[tuple[bool, Qubit]]) -> linst[Qubit]: + return [q for b, q in qs if b] + + +module.compile() diff --git a/tests/error/comprehension_errors/guarded2.err b/tests/error/comprehension_errors/guarded2.err new file mode 100644 index 00000000..673103ac --- /dev/null +++ b/tests/error/comprehension_errors/guarded2.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:18 + +16: @guppy(module) +17: def foo(qs: linst[tuple[bool, Qubit]]) -> list[int]: +18: return [42 for b, q in qs if b if q] + ^ +GuppyTypeError: Variable `q` with linear type `Qubit` is not used on all control-flow paths of the list comprehension diff --git a/tests/error/comprehension_errors/guarded2.py b/tests/error/comprehension_errors/guarded2.py new file mode 100644 index 00000000..5413b8c0 --- /dev/null +++ b/tests/error/comprehension_errors/guarded2.py @@ -0,0 +1,21 @@ +import guppy.prelude.quantum as quantum +from guppy.decorator import guppy +from guppy.module import GuppyModule +from guppy.hugr.tys import Qubit +from guppy.prelude.builtins import linst + +module = GuppyModule("test") +module.load(quantum) + + +@guppy.declare(module) +def bar(q: Qubit) -> bool: + ... + + +@guppy(module) +def foo(qs: linst[tuple[bool, Qubit]]) -> list[int]: + return [42 for b, q in qs if b if q] + + +module.compile() diff --git a/tests/error/comprehension_errors/illegal_short_circuit.err b/tests/error/comprehension_errors/illegal_short_circuit.err new file mode 100644 index 00000000..6c8bfb78 --- /dev/null +++ b/tests/error/comprehension_errors/illegal_short_circuit.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:6 + +4: @guppy +5: def foo(xs: list[int]) -> None: +6: [x for x in xs if x < 5 and x != 6] + ^^^^^^^^^^^^^^^^ +GuppyError: Expression is not supported inside a list comprehension diff --git a/tests/error/comprehension_errors/illegal_short_circuit.py b/tests/error/comprehension_errors/illegal_short_circuit.py new file mode 100644 index 00000000..dfd0b0a2 --- /dev/null +++ b/tests/error/comprehension_errors/illegal_short_circuit.py @@ -0,0 +1,6 @@ +from guppy.decorator import guppy + + +@guppy +def foo(xs: list[int]) -> None: + [x for x in xs if x < 5 and x != 6] diff --git a/tests/error/comprehension_errors/illegal_ternary.err b/tests/error/comprehension_errors/illegal_ternary.err new file mode 100644 index 00000000..b7e9e7a9 --- /dev/null +++ b/tests/error/comprehension_errors/illegal_ternary.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:6 + +4: @guppy +5: def foo(xs: list[int], ys: list[int], b: bool) -> None: +6: [x for x in (xs if b else ys)] + ^^^^^^^^^^^^^^^ +GuppyError: Expression is not supported inside a list comprehension diff --git a/tests/error/comprehension_errors/illegal_ternary.py b/tests/error/comprehension_errors/illegal_ternary.py new file mode 100644 index 00000000..3b587938 --- /dev/null +++ b/tests/error/comprehension_errors/illegal_ternary.py @@ -0,0 +1,6 @@ +from guppy.decorator import guppy + + +@guppy +def foo(xs: list[int], ys: list[int], b: bool) -> None: + [x for x in (xs if b else ys)] diff --git a/tests/error/comprehension_errors/illegal_walrus.err b/tests/error/comprehension_errors/illegal_walrus.err new file mode 100644 index 00000000..2fbb77d7 --- /dev/null +++ b/tests/error/comprehension_errors/illegal_walrus.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:6 + +4: @guppy +5: def foo(xs: list[int]) -> None: +6: [y := x for x in xs] + ^^^^^^ +GuppyError: Expression is not supported inside a list comprehension diff --git a/tests/error/comprehension_errors/illegal_walrus.py b/tests/error/comprehension_errors/illegal_walrus.py new file mode 100644 index 00000000..d95a345b --- /dev/null +++ b/tests/error/comprehension_errors/illegal_walrus.py @@ -0,0 +1,6 @@ +from guppy.decorator import guppy + + +@guppy +def foo(xs: list[int]) -> None: + [y := x for x in xs] diff --git a/tests/error/comprehension_errors/multi_use1.err b/tests/error/comprehension_errors/multi_use1.err new file mode 100644 index 00000000..8fcb48f9 --- /dev/null +++ b/tests/error/comprehension_errors/multi_use1.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:13 + +11: @guppy(module) +12: def foo(qs: linst[Qubit], xs: list[int]) -> linst[Qubit]: +13: return [q for q in qs for x in xs] + ^ +GuppyTypeError: Variable `q` with linear type `Qubit` would be used multiple times when evaluating this comprehension diff --git a/tests/error/comprehension_errors/multi_use1.py b/tests/error/comprehension_errors/multi_use1.py new file mode 100644 index 00000000..d80f6a38 --- /dev/null +++ b/tests/error/comprehension_errors/multi_use1.py @@ -0,0 +1,16 @@ +import guppy.prelude.quantum as quantum +from guppy.decorator import guppy +from guppy.module import GuppyModule +from guppy.hugr.tys import Qubit +from guppy.prelude.builtins import linst + +module = GuppyModule("test") +module.load(quantum) + + +@guppy(module) +def foo(qs: linst[Qubit], xs: list[int]) -> linst[Qubit]: + return [q for q in qs for x in xs] + + +module.compile() diff --git a/tests/error/comprehension_errors/multi_use2.err b/tests/error/comprehension_errors/multi_use2.err new file mode 100644 index 00000000..83e6f324 --- /dev/null +++ b/tests/error/comprehension_errors/multi_use2.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:13 + +11: @guppy(module) +12: def foo(qs: linst[Qubit], xs: list[int]) -> linst[Qubit]: +13: return [q for x in xs for q in qs] + ^^ +GuppyTypeError: Variable `qs` with linear type `linst[Qubit]` would be used multiple times when evaluating this comprehension diff --git a/tests/error/comprehension_errors/multi_use2.py b/tests/error/comprehension_errors/multi_use2.py new file mode 100644 index 00000000..9dc78ba7 --- /dev/null +++ b/tests/error/comprehension_errors/multi_use2.py @@ -0,0 +1,16 @@ +import guppy.prelude.quantum as quantum +from guppy.decorator import guppy +from guppy.module import GuppyModule +from guppy.hugr.tys import Qubit +from guppy.prelude.builtins import linst + +module = GuppyModule("test") +module.load(quantum) + + +@guppy(module) +def foo(qs: linst[Qubit], xs: list[int]) -> linst[Qubit]: + return [q for x in xs for q in qs] + + +module.compile() diff --git a/tests/error/comprehension_errors/multi_use3.err b/tests/error/comprehension_errors/multi_use3.err new file mode 100644 index 00000000..c8b5ab0b --- /dev/null +++ b/tests/error/comprehension_errors/multi_use3.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:18 + +16: @guppy(module) +17: def foo(qs: linst[Qubit], xs: list[int]) -> list[int]: +18: return [x for q in qs for x in xs if bar(q)] + ^ +GuppyTypeError: Variable `q` with linear type `Qubit` would be used multiple times when evaluating this comprehension diff --git a/tests/error/comprehension_errors/multi_use3.py b/tests/error/comprehension_errors/multi_use3.py new file mode 100644 index 00000000..cdf10bc4 --- /dev/null +++ b/tests/error/comprehension_errors/multi_use3.py @@ -0,0 +1,21 @@ +import guppy.prelude.quantum as quantum +from guppy.decorator import guppy +from guppy.module import GuppyModule +from guppy.hugr.tys import Qubit +from guppy.prelude.builtins import linst + +module = GuppyModule("test") +module.load(quantum) + + +@guppy.declare(module) +def bar(q: Qubit) -> bool: + ... + + +@guppy(module) +def foo(qs: linst[Qubit], xs: list[int]) -> list[int]: + return [x for q in qs for x in xs if bar(q)] + + +module.compile() diff --git a/tests/error/comprehension_errors/not_used.err b/tests/error/comprehension_errors/not_used.err new file mode 100644 index 00000000..9c191989 --- /dev/null +++ b/tests/error/comprehension_errors/not_used.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:13 + +11: @guppy(module) +12: def foo(qs: linst[Qubit]) -> list[int]: +13: return [42 for q in qs] + ^ +GuppyTypeError: Variable `q` with linear type `Qubit` is not used diff --git a/tests/error/comprehension_errors/not_used.py b/tests/error/comprehension_errors/not_used.py new file mode 100644 index 00000000..c5854103 --- /dev/null +++ b/tests/error/comprehension_errors/not_used.py @@ -0,0 +1,16 @@ +import guppy.prelude.quantum as quantum +from guppy.decorator import guppy +from guppy.module import GuppyModule +from guppy.hugr.tys import Qubit +from guppy.prelude.builtins import linst + +module = GuppyModule("test") +module.load(quantum) + + +@guppy(module) +def foo(qs: linst[Qubit]) -> list[int]: + return [42 for q in qs] + + +module.compile() diff --git a/tests/error/comprehension_errors/pattern_override1.err b/tests/error/comprehension_errors/pattern_override1.err new file mode 100644 index 00000000..d8fe0a3b --- /dev/null +++ b/tests/error/comprehension_errors/pattern_override1.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:13 + +11: @guppy(module) +12: def foo(qs: linst[tuple[Qubit, Qubit]]) -> linst[Qubit]: +13: return [q for q, q in qs] + ^ +GuppyError: Variable `q` with linear type `Qubit` is not used diff --git a/tests/error/comprehension_errors/pattern_override1.py b/tests/error/comprehension_errors/pattern_override1.py new file mode 100644 index 00000000..0a46bed3 --- /dev/null +++ b/tests/error/comprehension_errors/pattern_override1.py @@ -0,0 +1,16 @@ +import guppy.prelude.quantum as quantum +from guppy.decorator import guppy +from guppy.module import GuppyModule +from guppy.hugr.tys import Qubit +from guppy.prelude.builtins import linst + +module = GuppyModule("test") +module.load(quantum) + + +@guppy(module) +def foo(qs: linst[tuple[Qubit, Qubit]]) -> linst[Qubit]: + return [q for q, q in qs] + + +module.compile() diff --git a/tests/error/comprehension_errors/pattern_override2.err b/tests/error/comprehension_errors/pattern_override2.err new file mode 100644 index 00000000..58c71958 --- /dev/null +++ b/tests/error/comprehension_errors/pattern_override2.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:13 + +11: @guppy(module) +12: def foo(qs: linst[Qubit], xs: list[int]) -> linst[Qubit]: +13: return [q for q in qs for q in xs] + ^ +GuppyError: Variable `q` with linear type `Qubit` is not used diff --git a/tests/error/comprehension_errors/pattern_override2.py b/tests/error/comprehension_errors/pattern_override2.py new file mode 100644 index 00000000..b684e202 --- /dev/null +++ b/tests/error/comprehension_errors/pattern_override2.py @@ -0,0 +1,16 @@ +import guppy.prelude.quantum as quantum +from guppy.decorator import guppy +from guppy.module import GuppyModule +from guppy.hugr.tys import Qubit +from guppy.prelude.builtins import linst + +module = GuppyModule("test") +module.load(quantum) + + +@guppy(module) +def foo(qs: linst[Qubit], xs: list[int]) -> linst[Qubit]: + return [q for q in qs for q in xs] + + +module.compile() diff --git a/tests/error/comprehension_errors/used_twice1.err b/tests/error/comprehension_errors/used_twice1.err new file mode 100644 index 00000000..013258e9 --- /dev/null +++ b/tests/error/comprehension_errors/used_twice1.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:13 + +11: @guppy(module) +12: def foo(qs: linst[Qubit]) -> linst[tuple[Qubit, Qubit]]: +13: return [(q, q) for q in qs] + ^ +GuppyError: Variable `q` with linear type `Qubit` was already used (at 13:13) diff --git a/tests/error/comprehension_errors/used_twice1.py b/tests/error/comprehension_errors/used_twice1.py new file mode 100644 index 00000000..920ae354 --- /dev/null +++ b/tests/error/comprehension_errors/used_twice1.py @@ -0,0 +1,16 @@ +import guppy.prelude.quantum as quantum +from guppy.decorator import guppy +from guppy.module import GuppyModule +from guppy.hugr.tys import Qubit +from guppy.prelude.builtins import linst + +module = GuppyModule("test") +module.load(quantum) + + +@guppy(module) +def foo(qs: linst[Qubit]) -> linst[tuple[Qubit, Qubit]]: + return [(q, q) for q in qs] + + +module.compile() diff --git a/tests/error/comprehension_errors/used_twice2.err b/tests/error/comprehension_errors/used_twice2.err new file mode 100644 index 00000000..07f49da6 --- /dev/null +++ b/tests/error/comprehension_errors/used_twice2.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:18 + +16: @guppy(module) +17: def foo(qs: linst[Qubit]) -> linst[Qubit]: +18: return [q for q in qs if bar(q)] + ^ +GuppyError: Variable `q` with linear type `Qubit` was already used (at 18:33) diff --git a/tests/error/comprehension_errors/used_twice2.py b/tests/error/comprehension_errors/used_twice2.py new file mode 100644 index 00000000..8d2ffbe1 --- /dev/null +++ b/tests/error/comprehension_errors/used_twice2.py @@ -0,0 +1,21 @@ +import guppy.prelude.quantum as quantum +from guppy.decorator import guppy +from guppy.module import GuppyModule +from guppy.hugr.tys import Qubit +from guppy.prelude.builtins import linst + +module = GuppyModule("test") +module.load(quantum) + + +@guppy.declare(module) +def bar(q: Qubit) -> bool: + ... + + +@guppy(module) +def foo(qs: linst[Qubit]) -> linst[Qubit]: + return [q for q in qs if bar(q)] + + +module.compile() diff --git a/tests/error/comprehension_errors/used_twice3.err b/tests/error/comprehension_errors/used_twice3.err new file mode 100644 index 00000000..f8add102 --- /dev/null +++ b/tests/error/comprehension_errors/used_twice3.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:18 + +16: @guppy(module) +17: def foo(qs: linst[Qubit]) -> linst[Qubit]: +18: return [q for q in qs for x in bar(q)] + ^ +GuppyError: Variable `q` with linear type `Qubit` was already used (at 18:39) diff --git a/tests/error/comprehension_errors/used_twice3.py b/tests/error/comprehension_errors/used_twice3.py new file mode 100644 index 00000000..784070e0 --- /dev/null +++ b/tests/error/comprehension_errors/used_twice3.py @@ -0,0 +1,21 @@ +import guppy.prelude.quantum as quantum +from guppy.decorator import guppy +from guppy.module import GuppyModule +from guppy.hugr.tys import Qubit +from guppy.prelude.builtins import linst + +module = GuppyModule("test") +module.load(quantum) + + +@guppy.declare(module) +def bar(q: Qubit) -> list[int]: + ... + + +@guppy(module) +def foo(qs: linst[Qubit]) -> linst[Qubit]: + return [q for q in qs for x in bar(q)] + + +module.compile() diff --git a/tests/error/errors_on_usage/for_new_var.err b/tests/error/errors_on_usage/for_new_var.err new file mode 100644 index 00000000..7aae5692 --- /dev/null +++ b/tests/error/errors_on_usage/for_new_var.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:8 + +6: for _ in xs: +7: y = 5 +8: return y + ^ +GuppyError: Variable `y` is not defined on all control-flow paths. diff --git a/tests/error/errors_on_usage/for_new_var.py b/tests/error/errors_on_usage/for_new_var.py new file mode 100644 index 00000000..1054ad61 --- /dev/null +++ b/tests/error/errors_on_usage/for_new_var.py @@ -0,0 +1,8 @@ +from guppy.decorator import guppy + + +@guppy +def foo(xs: list[int]) -> int: + for _ in xs: + y = 5 + return y diff --git a/tests/error/errors_on_usage/for_target.err b/tests/error/errors_on_usage/for_target.err new file mode 100644 index 00000000..30602c60 --- /dev/null +++ b/tests/error/errors_on_usage/for_target.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:8 + +6: for x in xs: +7: pass +8: return x + ^ +GuppyError: Variable `x` is not defined on all control-flow paths. diff --git a/tests/error/errors_on_usage/for_target.py b/tests/error/errors_on_usage/for_target.py new file mode 100644 index 00000000..fe4c0eb1 --- /dev/null +++ b/tests/error/errors_on_usage/for_target.py @@ -0,0 +1,8 @@ +from guppy.decorator import guppy + + +@guppy +def foo(xs: list[int]) -> int: + for x in xs: + pass + return x diff --git a/tests/error/errors_on_usage/for_target_type_change.err b/tests/error/errors_on_usage/for_target_type_change.err new file mode 100644 index 00000000..82bffc0e --- /dev/null +++ b/tests/error/errors_on_usage/for_target_type_change.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:9 + +7: for x in xs: +8: pass +9: return x + ^ +GuppyError: Variable `x` can refer to different types: `int` (at 6:4) vs `bool` (at 7:8) diff --git a/tests/error/errors_on_usage/for_target_type_change.py b/tests/error/errors_on_usage/for_target_type_change.py new file mode 100644 index 00000000..7bc2b0b3 --- /dev/null +++ b/tests/error/errors_on_usage/for_target_type_change.py @@ -0,0 +1,9 @@ +from guppy.decorator import guppy + + +@guppy +def foo(xs: list[bool]) -> int: + x = 5 + for x in xs: + pass + return x diff --git a/tests/error/errors_on_usage/for_type_change.err b/tests/error/errors_on_usage/for_type_change.err new file mode 100644 index 00000000..8b4fb557 --- /dev/null +++ b/tests/error/errors_on_usage/for_type_change.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:9 + +7: for x in xs: +8: y = True +9: return y + ^ +GuppyError: Variable `y` can refer to different types: `int` (at 6:4) vs `bool` (at 8:8) diff --git a/tests/error/errors_on_usage/for_type_change.py b/tests/error/errors_on_usage/for_type_change.py new file mode 100644 index 00000000..96f4017c --- /dev/null +++ b/tests/error/errors_on_usage/for_type_change.py @@ -0,0 +1,9 @@ +from guppy.decorator import guppy + + +@guppy +def foo(xs: list[int]) -> int: + y = 5 + for x in xs: + y = True + return y diff --git a/tests/error/iter_errors/__init__.py b/tests/error/iter_errors/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/error/iter_errors/end_missing.err b/tests/error/iter_errors/end_missing.err new file mode 100644 index 00000000..90ad0692 --- /dev/null +++ b/tests/error/iter_errors/end_missing.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:33 + +31: @guppy(module) +32: def test(x: MyType) -> None: +33: for _ in x: + ^ +GuppyTypeError: Expression of type `MyIter` is not an iterator since it does not implement the `__end__` method diff --git a/tests/error/iter_errors/end_missing.py b/tests/error/iter_errors/end_missing.py new file mode 100644 index 00000000..33eafe12 --- /dev/null +++ b/tests/error/iter_errors/end_missing.py @@ -0,0 +1,37 @@ +from guppy.decorator import guppy +from guppy.hugr import tys +from guppy.module import GuppyModule + + +module = GuppyModule("test") + + +@guppy.type(module, tys.TupleType(inner=[])) +class MyIter: + """An iterator that is missing the `__end__` method.""" + + @guppy.declare(module) + def __next__(self: "MyIter") -> tuple[None, "MyIter"]: + ... + + @guppy.declare(module) + def __hasnext__(self: "MyIter") -> tuple[bool, "MyIter"]: + ... + + +@guppy.type(module, tys.TupleType(inner=[])) +class MyType: + """Type that produces the iterator above.""" + + @guppy.declare(module) + def __iter__(self: "MyType") -> MyIter: + ... + + +@guppy(module) +def test(x: MyType) -> None: + for _ in x: + pass + + +module.compile() diff --git a/tests/error/iter_errors/end_wrong_type.err b/tests/error/iter_errors/end_wrong_type.err new file mode 100644 index 00000000..635d6f84 --- /dev/null +++ b/tests/error/iter_errors/end_wrong_type.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:37 + +35: @guppy(module) +36: def test(x: MyType) -> None: +37: for _ in x: + ^ +GuppyError: Method `MyIter.__end__` has signature `MyIter -> MyIter`, but expected `MyIter -> None` diff --git a/tests/error/iter_errors/end_wrong_type.py b/tests/error/iter_errors/end_wrong_type.py new file mode 100644 index 00000000..82568d2b --- /dev/null +++ b/tests/error/iter_errors/end_wrong_type.py @@ -0,0 +1,41 @@ +from guppy.decorator import guppy +from guppy.hugr import tys +from guppy.module import GuppyModule + + +module = GuppyModule("test") + + +@guppy.type(module, tys.TupleType(inner=[])) +class MyIter: + """An iterator where the `__end__` method has the wrong signature.""" + + @guppy.declare(module) + def __next__(self: "MyIter") -> tuple[None, "MyIter"]: + ... + + @guppy.declare(module) + def __hasnext__(self: "MyIter") -> tuple[bool, "MyIter"]: + ... + + @guppy.declare(module) + def __end__(self: "MyIter") -> "MyIter": + ... + + +@guppy.type(module, tys.TupleType(inner=[])) +class MyType: + """Type that produces the iterator above.""" + + @guppy.declare(module) + def __iter__(self: "MyType") -> MyIter: + ... + + +@guppy(module) +def test(x: MyType) -> None: + for _ in x: + pass + + +module.compile() diff --git a/tests/error/iter_errors/hasnext_missing.err b/tests/error/iter_errors/hasnext_missing.err new file mode 100644 index 00000000..949b776f --- /dev/null +++ b/tests/error/iter_errors/hasnext_missing.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:33 + +31: @guppy(module) +32: def test(x: MyType) -> None: +33: for _ in x: + ^ +GuppyTypeError: Expression of type `MyIter` is not an iterator since it does not implement the `__hasnext__` method diff --git a/tests/error/iter_errors/hasnext_missing.py b/tests/error/iter_errors/hasnext_missing.py new file mode 100644 index 00000000..68b11c7a --- /dev/null +++ b/tests/error/iter_errors/hasnext_missing.py @@ -0,0 +1,37 @@ +from guppy.decorator import guppy +from guppy.hugr import tys +from guppy.module import GuppyModule + + +module = GuppyModule("test") + + +@guppy.type(module, tys.TupleType(inner=[])) +class MyIter: + """An iterator that is missing the `__hasnext__` method.""" + + @guppy.declare(module) + def __next__(self: "MyIter") -> tuple[None, "MyIter"]: + ... + + @guppy.declare(module) + def __end__(self: "MyIter") -> None: + ... + + +@guppy.type(module, tys.TupleType(inner=[])) +class MyType: + """Type that produces the iterator above.""" + + @guppy.declare(module) + def __iter__(self: "MyType") -> MyIter: + ... + + +@guppy(module) +def test(x: MyType) -> None: + for _ in x: + pass + + +module.compile() diff --git a/tests/error/iter_errors/hasnext_wrong_type.err b/tests/error/iter_errors/hasnext_wrong_type.err new file mode 100644 index 00000000..234e217a --- /dev/null +++ b/tests/error/iter_errors/hasnext_wrong_type.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:37 + +35: @guppy(module) +36: def test(x: MyType) -> None: +37: for _ in x: + ^ +GuppyError: Method `MyIter.__hasnext__` has signature `MyIter -> bool`, but expected `MyIter -> (bool, MyIter)` diff --git a/tests/error/iter_errors/hasnext_wrong_type.py b/tests/error/iter_errors/hasnext_wrong_type.py new file mode 100644 index 00000000..f276f1b7 --- /dev/null +++ b/tests/error/iter_errors/hasnext_wrong_type.py @@ -0,0 +1,41 @@ +from guppy.decorator import guppy +from guppy.hugr import tys +from guppy.module import GuppyModule + + +module = GuppyModule("test") + + +@guppy.type(module, tys.TupleType(inner=[])) +class MyIter: + """An iterator where the `__hasnext__` method has the wrong signature.""" + + @guppy.declare(module) + def __next__(self: "MyIter") -> tuple[None, "MyIter"]: + ... + + @guppy.declare(module) + def __hasnext__(self: "MyIter") -> bool: + ... + + @guppy.declare(module) + def __end__(self: "MyIter") -> None: + ... + + +@guppy.type(module, tys.TupleType(inner=[])) +class MyType: + """Type that produces the iterator above.""" + + @guppy.declare(module) + def __iter__(self: "MyType") -> MyIter: + ... + + +@guppy(module) +def test(x: MyType) -> None: + for _ in x: + pass + + +module.compile() diff --git a/tests/error/iter_errors/iter_missing.err b/tests/error/iter_errors/iter_missing.err new file mode 100644 index 00000000..3e3c571d --- /dev/null +++ b/tests/error/iter_errors/iter_missing.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:16 + +14: @guppy(module) +15: def test(x: MyType) -> None: +16: for _ in x: + ^ +GuppyTypeError: Expression of type `MyType` is not iterable diff --git a/tests/error/iter_errors/iter_missing.py b/tests/error/iter_errors/iter_missing.py new file mode 100644 index 00000000..de04fdb6 --- /dev/null +++ b/tests/error/iter_errors/iter_missing.py @@ -0,0 +1,20 @@ +from guppy.decorator import guppy +from guppy.hugr import tys +from guppy.module import GuppyModule + + +module = GuppyModule("test") + + +@guppy.type(module, tys.TupleType(inner=[])) +class MyType: + """A non-iterable type.""" + + +@guppy(module) +def test(x: MyType) -> None: + for _ in x: + pass + + +module.compile() diff --git a/tests/error/iter_errors/iter_wrong_type.err b/tests/error/iter_errors/iter_wrong_type.err new file mode 100644 index 00000000..cc10baab --- /dev/null +++ b/tests/error/iter_errors/iter_wrong_type.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:20 + +18: @guppy(module) +19: def test(x: MyType) -> None: +20: for _ in x: + ^ +GuppyError: Method `MyType.__iter__` has signature `(MyType, int) -> MyType`, but expected `MyType -> ?Iter` diff --git a/tests/error/iter_errors/iter_wrong_type.py b/tests/error/iter_errors/iter_wrong_type.py new file mode 100644 index 00000000..39a34163 --- /dev/null +++ b/tests/error/iter_errors/iter_wrong_type.py @@ -0,0 +1,24 @@ +from guppy.decorator import guppy +from guppy.hugr import tys +from guppy.module import GuppyModule + + +module = GuppyModule("test") + + +@guppy.type(module, tys.TupleType(inner=[])) +class MyType: + """A type where the `__iter__` method has the wrong signature.""" + + @guppy.declare(module) + def __iter__(self: "MyType", x: int) -> "MyType": + ... + + +@guppy(module) +def test(x: MyType) -> None: + for _ in x: + pass + + +module.compile() diff --git a/tests/error/iter_errors/next_missing.err b/tests/error/iter_errors/next_missing.err new file mode 100644 index 00000000..8ba51c01 --- /dev/null +++ b/tests/error/iter_errors/next_missing.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:33 + +31: @guppy(module) +32: def test(x: MyType) -> None: +33: for _ in x: + ^ +GuppyTypeError: Expression of type `MyIter` is not an iterator since it does not implement the `__next__` method diff --git a/tests/error/iter_errors/next_missing.py b/tests/error/iter_errors/next_missing.py new file mode 100644 index 00000000..8142d1f6 --- /dev/null +++ b/tests/error/iter_errors/next_missing.py @@ -0,0 +1,37 @@ +from guppy.decorator import guppy +from guppy.hugr import tys +from guppy.module import GuppyModule + + +module = GuppyModule("test") + + +@guppy.type(module, tys.TupleType(inner=[])) +class MyIter: + """An iterator that is missing the `__next__` method.""" + + @guppy.declare(module) + def __hasnext__(self: "MyIter") -> tuple[bool, "MyIter"]: + ... + + @guppy.declare(module) + def __end__(self: "MyIter") -> None: + ... + + +@guppy.type(module, tys.TupleType(inner=[])) +class MyType: + """Type that produces the iterator above.""" + + @guppy.declare(module) + def __iter__(self: "MyType") -> MyIter: + ... + + +@guppy(module) +def test(x: MyType) -> None: + for _ in x: + pass + + +module.compile() diff --git a/tests/error/iter_errors/next_wrong_type.err b/tests/error/iter_errors/next_wrong_type.err new file mode 100644 index 00000000..b9bac527 --- /dev/null +++ b/tests/error/iter_errors/next_wrong_type.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:37 + +35: @guppy(module) +36: def test(x: MyType) -> None: +37: for _ in x: + ^ +GuppyError: Method `MyIter.__next__` has signature `MyIter -> (MyIter, float)`, but expected `MyIter -> (?T, MyIter)` diff --git a/tests/error/iter_errors/next_wrong_type.py b/tests/error/iter_errors/next_wrong_type.py new file mode 100644 index 00000000..0d5ae680 --- /dev/null +++ b/tests/error/iter_errors/next_wrong_type.py @@ -0,0 +1,41 @@ +from guppy.decorator import guppy +from guppy.hugr import tys +from guppy.module import GuppyModule + + +module = GuppyModule("test") + + +@guppy.type(module, tys.TupleType(inner=[])) +class MyIter: + """An iterator where the `__next__` method has the wrong signature.""" + + @guppy.declare(module) + def __next__(self: "MyIter") -> tuple["MyIter", float]: + ... + + @guppy.declare(module) + def __hasnext__(self: "MyIter") -> tuple[bool, "MyIter"]: + ... + + @guppy.declare(module) + def __end__(self: "MyIter") -> None: + ... + + +@guppy.type(module, tys.TupleType(inner=[])) +class MyType: + """Type that produces the iterator above.""" + + @guppy.declare(module) + def __iter__(self: "MyType") -> MyIter: + ... + + +@guppy(module) +def test(x: MyType) -> None: + for _ in x: + pass + + +module.compile() diff --git a/tests/error/linear_errors/branch_use.err b/tests/error/linear_errors/branch_use.err index 742a0ca7..c096aa16 100644 --- a/tests/error/linear_errors/branch_use.err +++ b/tests/error/linear_errors/branch_use.err @@ -3,5 +3,5 @@ Guppy compilation failed. Error in file $FILE:23 21: @guppy(module) 22: def foo(b: bool) -> bool: 23: q = new_qubit() - ^^^^^^^^^^^^^^^ + ^ GuppyError: Variable `q` with linear type `Qubit` is not used on all control-flow paths diff --git a/tests/error/linear_errors/break_unused.err b/tests/error/linear_errors/break_unused.err index a91d243e..8ee9039f 100644 --- a/tests/error/linear_errors/break_unused.err +++ b/tests/error/linear_errors/break_unused.err @@ -3,5 +3,5 @@ Guppy compilation failed. Error in file $FILE:25 23: b = False 24: while True: 25: q = new_qubit() - ^^^^^^^^^^^^^^^ + ^ GuppyError: Variable `q` with linear type `Qubit` is not used on all control-flow paths diff --git a/tests/error/linear_errors/continue_unused.err b/tests/error/linear_errors/continue_unused.err index 6b91f2be..7ecd4c5a 100644 --- a/tests/error/linear_errors/continue_unused.err +++ b/tests/error/linear_errors/continue_unused.err @@ -3,5 +3,5 @@ Guppy compilation failed. Error in file $FILE:25 23: b = False 24: while i > 0: 25: q = new_qubit() - ^^^^^^^^^^^^^^^ + ^ GuppyError: Variable `q` with linear type `Qubit` is not used on all control-flow paths diff --git a/tests/error/linear_errors/for_break.err b/tests/error/linear_errors/for_break.err new file mode 100644 index 00000000..271f11ed --- /dev/null +++ b/tests/error/linear_errors/for_break.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:17 + +15: rs += [q] +16: if b: +17: break + ^^^^^ +GuppyTypeError: Loop over iterator with linear type `linst[(Qubit, bool)]` cannot be terminated prematurely diff --git a/tests/error/linear_errors/for_break.py b/tests/error/linear_errors/for_break.py new file mode 100644 index 00000000..3f5f6cff --- /dev/null +++ b/tests/error/linear_errors/for_break.py @@ -0,0 +1,21 @@ +import guppy.prelude.quantum as quantum +from guppy.decorator import guppy +from guppy.module import GuppyModule +from guppy.hugr.tys import Qubit +from guppy.prelude.builtins import linst + +module = GuppyModule("test") +module.load(quantum) + + +@guppy(module) +def foo(qs: linst[tuple[Qubit, bool]]) -> linst[Qubit]: + rs: linst[Qubit] = [] + for q, b in qs: + rs += [q] + if b: + break + return rs + + +module.compile() diff --git a/tests/error/linear_errors/for_return.err b/tests/error/linear_errors/for_return.err new file mode 100644 index 00000000..33fab35a --- /dev/null +++ b/tests/error/linear_errors/for_return.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:17 + +15: rs += [q] +16: if b: +17: return [] + ^^^^^^^^^ +GuppyTypeError: Loop over iterator with linear type `linst[(Qubit, bool)]` cannot be terminated prematurely diff --git a/tests/error/linear_errors/for_return.py b/tests/error/linear_errors/for_return.py new file mode 100644 index 00000000..05982f9c --- /dev/null +++ b/tests/error/linear_errors/for_return.py @@ -0,0 +1,21 @@ +import guppy.prelude.quantum as quantum +from guppy.decorator import guppy +from guppy.module import GuppyModule +from guppy.hugr.tys import Qubit +from guppy.prelude.builtins import linst + +module = GuppyModule("test") +module.load(quantum) + + +@guppy(module) +def foo(qs: linst[tuple[Qubit, bool]]) -> linst[Qubit]: + rs: linst[Qubit] = [] + for q, b in qs: + rs += [q] + if b: + return [] + return rs + + +module.compile() diff --git a/tests/error/linear_errors/if_both_unused.err b/tests/error/linear_errors/if_both_unused.err index db518100..f942875c 100644 --- a/tests/error/linear_errors/if_both_unused.err +++ b/tests/error/linear_errors/if_both_unused.err @@ -3,5 +3,5 @@ Guppy compilation failed. Error in file $FILE:19 17: def foo(b: bool) -> int: 18: if b: 19: q = new_qubit() - ^^^^^^^^^^^^^^^ + ^ GuppyError: Variable `q` with linear type `Qubit` is not used on all control-flow paths diff --git a/tests/error/linear_errors/if_both_unused_reassign.err b/tests/error/linear_errors/if_both_unused_reassign.err index cacd42ca..0b0d5dbe 100644 --- a/tests/error/linear_errors/if_both_unused_reassign.err +++ b/tests/error/linear_errors/if_both_unused_reassign.err @@ -3,5 +3,5 @@ Guppy compilation failed. Error in file $FILE:19 17: def foo(b: bool) -> Qubit: 18: if b: 19: q = new_qubit() - ^^^^^^^^^^^^^^^ + ^ GuppyError: Variable `q` with linear type `Qubit` is not used on all control-flow paths diff --git a/tests/error/linear_errors/unused.err b/tests/error/linear_errors/unused.err index 7d62fee3..764d3d95 100644 --- a/tests/error/linear_errors/unused.err +++ b/tests/error/linear_errors/unused.err @@ -3,5 +3,5 @@ Guppy compilation failed. Error in file $FILE:13 11: @guppy(module) 12: def foo(q: Qubit) -> int: 13: x = q - ^^^^^ + ^ GuppyError: Variable `x` with linear type `Qubit` is not used on all control-flow paths diff --git a/tests/error/linear_errors/unused_same_block.err b/tests/error/linear_errors/unused_same_block.err index 960f266c..835fff05 100644 --- a/tests/error/linear_errors/unused_same_block.err +++ b/tests/error/linear_errors/unused_same_block.err @@ -3,5 +3,5 @@ Guppy compilation failed. Error in file $FILE:13 11: @guppy(module) 12: def foo(q: Qubit) -> int: 13: x = q - ^^^^^ + ^ GuppyError: Variable `x` with linear type `Qubit` is not used diff --git a/tests/error/misc_errors/list_linear.err b/tests/error/misc_errors/list_linear.err new file mode 100644 index 00000000..39da0457 --- /dev/null +++ b/tests/error/misc_errors/list_linear.err @@ -0,0 +1,6 @@ +Guppy compilation failed. Error in file $FILE:13 + +11: @guppy(module) +12: def foo() -> list[Qubit]: + ^^^^^^^^^^^ +GuppyError: Type `list` cannot store linear data, use `linst` instead diff --git a/tests/error/misc_errors/list_linear.py b/tests/error/misc_errors/list_linear.py new file mode 100644 index 00000000..5b377e53 --- /dev/null +++ b/tests/error/misc_errors/list_linear.py @@ -0,0 +1,17 @@ +from guppy.decorator import guppy +from guppy.module import GuppyModule +from guppy.prelude.quantum import Qubit + +import guppy.prelude.quantum as quantum + + +module = GuppyModule("test") +module.load(quantum) + + +@guppy(module) +def foo() -> list[Qubit]: + return [] + + +module.compile() diff --git a/tests/error/test_comprehension_errors.py b/tests/error/test_comprehension_errors.py new file mode 100644 index 00000000..9e40f5fa --- /dev/null +++ b/tests/error/test_comprehension_errors.py @@ -0,0 +1,20 @@ +import pathlib +import pytest + +from tests.error.util import run_error_test + +path = pathlib.Path(__file__).parent.resolve() / "comprehension_errors" +files = [ + x + for x in path.iterdir() + if x.is_file() + if x.suffix == ".py" and x.name != "__init__.py" +] + +# Turn paths into strings, otherwise pytest doesn't display the names +files = [str(f) for f in files] + + +@pytest.mark.parametrize("file", files) +def test_comprehension_errors(file, capsys): + run_error_test(file, capsys) diff --git a/tests/error/test_iter_errors.py b/tests/error/test_iter_errors.py new file mode 100644 index 00000000..60c8c2a6 --- /dev/null +++ b/tests/error/test_iter_errors.py @@ -0,0 +1,20 @@ +import pathlib +import pytest + +from tests.error.util import run_error_test + +path = pathlib.Path(__file__).parent.resolve() / "iter_errors" +files = [ + x + for x in path.iterdir() + if x.is_file() + if x.suffix == ".py" and x.name != "__init__.py" +] + +# Turn paths into strings, otherwise pytest doesn't display the names +files = [str(f) for f in files] + + +@pytest.mark.parametrize("file", files) +def test_iter_errors(file, capsys): + run_error_test(file, capsys) diff --git a/tests/integration/test_comprehension.py b/tests/integration/test_comprehension.py new file mode 100644 index 00000000..aec2f459 --- /dev/null +++ b/tests/integration/test_comprehension.py @@ -0,0 +1,260 @@ +from guppy.decorator import guppy +from guppy.hugr import tys +from guppy.module import GuppyModule +from guppy.prelude.builtins import linst +from guppy.prelude.quantum import Qubit, h, cx + +import guppy.prelude.quantum as quantum + + +def test_basic(validate): + @guppy + def test(xs: list[float]) -> list[int]: + return [int(x) for x in xs] + + validate(test) + + +def test_basic_linear(validate): + module = GuppyModule("test") + module.load(quantum) + + @guppy(module) + def test(qs: linst[Qubit]) -> linst[Qubit]: + return [h(q) for q in qs] + + validate(module.compile()) + + +def test_guarded(validate): + @guppy + def test(xs: list[int]) -> list[int]: + return [2 * x for x in xs if x > 0 if x < 20] + + validate(test) + + +def test_multiple(validate): + @guppy + def test(xs: list[int], ys: list[int]) -> list[int]: + return [x + y for x in xs for y in ys if x + y > 42] + + validate(test) + + +def test_tuple_pat(validate): + @guppy + def test(xs: list[tuple[int, int, float]]) -> list[float]: + return [x + y * z for x, y, z in xs if x - y > z] + + validate(test) + + +def test_tuple_pat_linear(validate): + module = GuppyModule("test") + module.load(quantum) + + @guppy(module) + def test(qs: linst[tuple[int, Qubit, Qubit]]) -> linst[tuple[Qubit, Qubit]]: + return [cx(q1, q2) for _, q1, q2 in qs] + + validate(module.compile()) + + +def test_tuple_return(validate): + @guppy + def test(xs: list[int], ys: list[float]) -> list[tuple[int, float]]: + return [(x, y) for x in xs for y in ys] + + validate(test) + + +def test_dependent(validate): + module = GuppyModule("test") + + @guppy.declare(module) + def process(x: float) -> list[int]: + ... + + @guppy(module) + def test(xs: list[float]) -> list[float]: + return [x * y for x in xs if x > 0 for y in process(x) if y > x] + + validate(module.compile()) + + +def test_capture(validate): + @guppy + def test(xs: list[int], y: int) -> list[int]: + return [x + y for x in xs if x > y] + + validate(test) + + +def test_scope(validate): + @guppy + def test(xs: list[None]) -> float: + x = 42.0 + [x for x in xs] + return x + + validate(test) + + +def test_nested_left(validate): + @guppy + def test(xs: list[int], ys: list[float]) -> list[list[float]]: + return [[x + y for y in ys] for x in xs] + + validate(test) + + +def test_nested_right(validate): + @guppy + def test(xs: list[int]) -> list[int]: + return [-x for x in [2 * x for x in xs]] + + validate(test) + + +def test_nested_linear(validate): + module = GuppyModule("test") + module.load(quantum) + + @guppy(module) + def test(qs: linst[Qubit]) -> linst[Qubit]: + return [h(q) for q in [h(q) for q in qs]] + + validate(module.compile()) + + +def test_classical_linst_comp(validate): + module = GuppyModule("test") + module.load(quantum) + + @guppy(module) + def test(xs: list[int]) -> linst[int]: + return [x for x in xs] + + validate(module.compile()) + + +def test_linear_discard(validate): + module = GuppyModule("test") + module.load(quantum) + + @guppy.declare(module) + def discard(q: Qubit) -> None: + ... + + @guppy(module) + def test(qs: linst[Qubit]) -> list[None]: + return [discard(q) for q in qs] + + validate(module.compile()) + + +def test_linear_consume_in_guard(validate): + module = GuppyModule("test") + module.load(quantum) + + @guppy.declare(module) + def cond(q: Qubit) -> bool: + ... + + @guppy(module) + def test(qs: linst[tuple[int, Qubit]]) -> list[int]: + return [x for x, q in qs if cond(q)] + + validate(module.compile()) + + +def test_linear_consume_in_iter(validate): + module = GuppyModule("test") + module.load(quantum) + + @guppy.declare(module) + def make_list(q: Qubit) -> list[int]: + ... + + @guppy(module) + def test(qs: linst[Qubit]) -> list[int]: + return [x for q in qs for x in make_list(q)] + + validate(module.compile()) + + +def test_linear_next_nonlinear_iter(validate): + module = GuppyModule("test") + module.load(quantum) + + @guppy.type(module, tys.TupleType(inner=[])) + class MyIter: + """An iterator that yields linear values but is not linear itself.""" + + @guppy.declare(module) + def __hasnext__(self: "MyIter") -> tuple[bool, "MyIter"]: + ... + + @guppy.declare(module) + def __next__(self: "MyIter") -> tuple[Qubit, "MyIter"]: + ... + + @guppy.declare(module) + def __end__(self: "MyIter") -> None: + ... + + @guppy.type(module, tys.TupleType(inner=[])) + class MyType: + """Type that produces the iterator above.""" + + @guppy.declare(module) + def __iter__(self: "MyType") -> MyIter: + ... + + @guppy(module) + def test(mt: MyType, xs: list[int]) -> linst[tuple[int, Qubit]]: + # We can use `mt` in an inner loop since it's not linear + return [(x, q) for x in xs for q in mt] + + validate(module.compile()) + + +def test_nonlinear_next_linear_iter(validate): + module = GuppyModule("test") + module.load(quantum) + + @guppy.type( + module, + tys.Opaque(extension="prelude", id="qubit", args=[], bound=tys.TypeBound.Any), + linear=True, + ) + class MyIter: + """A linear iterator that yields non-linear values.""" + + @guppy.declare(module) + def __hasnext__(self: "MyIter") -> tuple[bool, "MyIter"]: + ... + + @guppy.declare(module) + def __next__(self: "MyIter") -> tuple[int, "MyIter"]: + ... + + @guppy.declare(module) + def __end__(self: "MyIter") -> None: + ... + + @guppy.type(module, tys.TupleType(inner=[])) + class MyType: + """Type that produces the iterator above.""" + + @guppy.declare(module) + def __iter__(self: "MyType") -> MyIter: + ... + + @guppy(module) + def test(mt: MyType, xs: list[int]) -> linst[tuple[int, int]]: + # We can use `mt` in an outer loop since the target `x` is not linear + return [(x, x + y) for x in mt for y in xs] + + validate(module.compile()) diff --git a/tests/integration/test_for.py b/tests/integration/test_for.py new file mode 100644 index 00000000..917c3ffc --- /dev/null +++ b/tests/integration/test_for.py @@ -0,0 +1,129 @@ +from guppy.decorator import guppy + + +def test_basic(validate): + @guppy + def foo(xs: list[int]) -> int: + for x in xs: + pass + return 0 + + validate(foo) + + +def test_counting_loop(validate): + @guppy + def foo(xs: list[int]) -> int: + s = 0 + for x in xs: + s += x + return s + + validate(foo) + + +def test_multi_targets(validate): + @guppy + def foo(xs: list[tuple[int, float]]) -> float: + s = 0.0 + for x, y in xs: + s += x * y + return s + + validate(foo) + + +def test_multi_targets_same(validate): + @guppy + def foo(xs: list[tuple[int, float]]) -> float: + s = 1.0 + for x, x in xs: + s *= x + return s + + validate(foo) + + +def test_reassign_iter(validate): + @guppy + def foo(xs: list[int]) -> int: + s = 1 + for x in xs: + xs = False + s += x + return s + + validate(foo) + + +def test_break(validate): + @guppy + def foo(xs: list[int]) -> int: + i = 1 + for x in xs: + if x >= 42: + break + i *= x + return i + + validate(foo) + + +def test_continue(validate): + @guppy + def foo(xs: list[int]) -> int: + i = len(xs) + for x in xs: + if x >= 42: + continue + i -= 1 + return i + + validate(foo) + + +def test_return_in_loop(validate): + @guppy + def foo(xs: list[int]) -> int: + y = 42 + for x in xs: + if x >= 1337: + return x * y + y = y + x + return y + + validate(foo) + + +def test_nested_loop(validate): + @guppy + def foo(xs: list[int], ys: list[int]) -> int: + p = 0 + for x in xs: + s = 0 + for y in ys: + s += y * x + p += s - x + return p + + validate(foo) + + +def test_nested_loop_break_continue(validate): + @guppy + def foo(xs: list[int], ys: list[int]) -> int: + p = 0 + for x in xs: + s = 0 + for y in ys: + if x % 2 == 0: + continue + s += x + if s > y: + s = y + else: + break + p += s * x + return p + + validate(foo) diff --git a/tests/integration/test_linear.py b/tests/integration/test_linear.py index 35091b61..cdd975d9 100644 --- a/tests/integration/test_linear.py +++ b/tests/integration/test_linear.py @@ -1,5 +1,7 @@ from guppy.decorator import guppy +from guppy.hugr import tys from guppy.module import GuppyModule +from guppy.prelude.builtins import linst from guppy.prelude.quantum import Qubit import guppy.prelude.quantum as quantum @@ -207,6 +209,101 @@ def foo(i: bool) -> bool: return b +def test_for(validate): + module = GuppyModule("test") + module.load(quantum) + + @guppy(module) + def test(qs: linst[tuple[Qubit, Qubit]]) -> linst[Qubit]: + rs: linst[Qubit] = [] + for q1, q2 in qs: + q1, q2 = cx(q1, q2) + rs += [q1, q2] + return rs + + validate(module.compile()) + + +def test_for_measure(validate): + module = GuppyModule("test") + module.load(quantum) + + @guppy.declare(module) + def measure(q: Qubit) -> bool: + ... + + @guppy(module) + def test(qs: linst[Qubit]) -> bool: + parity = False + for q in qs: + parity |= measure(q) + return parity + + validate(module.compile()) + + +def test_for_continue(validate): + module = GuppyModule("test") + module.load(quantum) + + @guppy.declare(module) + def measure(q: Qubit) -> bool: + ... + + @guppy(module) + def test(qs: linst[Qubit]) -> int: + x = 0 + for q in qs: + if measure(q): + continue + x += 1 + return x + + validate(module.compile()) + + +def test_for_nonlinear_break(validate): + module = GuppyModule("test") + module.load(quantum) + + @guppy.type(module, tys.TupleType(inner=[])) + class MyIter: + """An iterator that yields linear values but is not linear itself.""" + + @guppy.declare(module) + def __hasnext__(self: "MyIter") -> tuple[bool, "MyIter"]: + ... + + @guppy.declare(module) + def __next__(self: "MyIter") -> tuple[Qubit, "MyIter"]: + ... + + @guppy.declare(module) + def __end__(self: "MyIter") -> None: + ... + + @guppy.type(module, tys.TupleType(inner=[])) + class MyType: + """Type that produces the iterator above.""" + + @guppy.declare(module) + def __iter__(self: "MyType") -> MyIter: + ... + + @guppy.declare(module) + def measure(q: Qubit) -> bool: + ... + + @guppy(module) + def test(mt: MyType, xs: list[int]) -> None: + # We can break, since `mt` itself is not linear + for q in mt: + if measure(q): + break + + validate(module.compile()) + + def test_rus(validate): module = GuppyModule("test") module.load(quantum) diff --git a/tests/integration/test_linst.py b/tests/integration/test_linst.py new file mode 100644 index 00000000..4086d234 --- /dev/null +++ b/tests/integration/test_linst.py @@ -0,0 +1,66 @@ +from guppy.decorator import guppy +from guppy.module import GuppyModule +from guppy.prelude.builtins import linst +from guppy.prelude.quantum import Qubit, h + +import guppy.prelude.quantum as quantum + + +def test_types(validate): + module = GuppyModule("test") + module.load(quantum) + + @guppy(module) + def test( + xs: linst[Qubit], ys: linst[tuple[int, Qubit]] + ) -> tuple[linst[Qubit], linst[tuple[int, Qubit]]]: + return xs, ys + + validate(module.compile()) + + +def test_len(validate): + module = GuppyModule("test") + module.load(quantum) + + @guppy(module) + def test(xs: linst[Qubit]) -> tuple[int, linst[Qubit]]: + return len(xs) + + validate(module.compile()) + + +def test_literal(validate): + module = GuppyModule("test") + module.load(quantum) + + @guppy(module) + def test(q1: Qubit, q2: Qubit) -> linst[Qubit]: + return [q1, h(q2)] + + validate(module.compile()) + + +def test_arith(validate): + module = GuppyModule("test") + module.load(quantum) + + @guppy(module) + def test(xs: linst[Qubit], ys: linst[Qubit], q: Qubit) -> linst[Qubit]: + xs += [q] + return xs + ys + + validate(module.compile()) + + +def test_copyable(validate): + module = GuppyModule("test") + module.load(quantum) + + @guppy(module) + def test() -> linst[int]: + xs: linst[int] = [1, 2, 3] + ys: linst[int] = [] + return xs + xs + + validate(module.compile()) diff --git a/tests/integration/test_list.py b/tests/integration/test_list.py new file mode 100644 index 00000000..e59db276 --- /dev/null +++ b/tests/integration/test_list.py @@ -0,0 +1,45 @@ +from guppy.decorator import guppy + + +def test_types(validate): + @guppy + def test( + xs: list[int], ys: list[tuple[int, float]] + ) -> tuple[list[int], list[tuple[int, float]]]: + return xs, ys + + validate(test) + + +def test_len(validate): + @guppy + def test(xs: list[int]) -> int: + return len(xs) + + validate(test) + + +def test_literal(validate): + @guppy + def test(x: float) -> list[float]: + return [1.0, 2.0, 3.0, 4.0 + x] + + validate(test) + + +def test_arith(validate): + @guppy + def test(xs: list[int]) -> list[int]: + xs += xs + [42] + xs = 3 * xs + return xs * 4 + + validate(test) + + +def test_subscript(validate): + @guppy + def test(xs: list[float], i: int) -> float: + return xs[2 * i] + + validate(test) diff --git a/tests/integration/test_poly.py b/tests/integration/test_poly.py index 2446e308..ee8f112c 100644 --- a/tests/integration/test_poly.py +++ b/tests/integration/test_poly.py @@ -134,6 +134,22 @@ def main() -> None: validate(module.compile()) +def test_infer_list(validate): + module = GuppyModule("test") + T = guppy.type_var(module, "T") + + @guppy.declare(module) + def foo() -> T: + ... + + @guppy(module) + def main() -> None: + xs: list[int] = [foo()] + ys = [1.0, foo()] + + validate(module.compile()) + + def test_infer_nested(validate): module = GuppyModule("test") T = guppy.type_var(module, "T") diff --git a/validator/Cargo.lock b/validator/Cargo.lock index 8de1c9f0..ade968f5 100644 --- a/validator/Cargo.lock +++ b/validator/Cargo.lock @@ -4,9 +4,9 @@ version = 3 [[package]] name = "aho-corasick" -version = "1.1.1" +version = "1.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ea5d730647d4fadd988536d06fecce94b7b4f2a7efdae548f1cf4b63205518ab" +checksum = "b2969dcb958b36655471fc61f7e416fa76033bdd4bfed0678d8fee1e2d07a1f0" dependencies = [ "memchr", ] @@ -47,9 +47,9 @@ dependencies = [ [[package]] name = "byteorder" -version = "1.4.3" +version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "14c189c53d098945499cdfa7ecc63567cf3886b3332b312a5b4585d8d3a6a610" +checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" [[package]] name = "cfg-if" @@ -123,27 +123,33 @@ checksum = "9ea835d29036a4087793836fa931b08837ad5e957da9e23886b29586fb9b6650" [[package]] name = "either" -version = "1.8.1" +version = "1.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7fcaabb2fef8c910e7f4c7ce9f67a1283a1715879a7c230ca9d6d1ae31f16d91" +checksum = "a26ae43d7bcc3b814de94796a5e736d4029efb0ee900c12e2d54c993ad1a1e07" [[package]] name = "enum_dispatch" -version = "0.3.11" +version = "0.3.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "11f36e95862220b211a6e2aa5eca09b4fa391b13cd52ceb8035a24bf65a79de2" +checksum = "8f33313078bb8d4d05a2733a94ac4c2d8a0df9a2b84424ebf4f33bfc224a890e" dependencies = [ "once_cell", "proc-macro2", "quote", - "syn 1.0.109", + "syn 2.0.48", ] +[[package]] +name = "equivalent" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5" + [[package]] name = "erased-serde" -version = "0.3.25" +version = "0.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4f2b0c2380453a92ea8b6c8e5f64ecaafccddde8ceab55ff7a8ac1029f894569" +checksum = "55d05712b2d8d88102bc9868020c9e5c7a1f5527c452b9b97450a1d006140ba7" dependencies = [ "serde", ] @@ -160,22 +166,11 @@ version = "2.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e6d5a32815ae3f33302d95fdcb2ce17862f8c65363dcfd29360480ba1001fc9c" -[[package]] -name = "ghost" -version = "0.1.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e77ac7b51b8e6313251737fcef4b1c01a2ea102bde68415b62c0ee9268fec357" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.48", -] - [[package]] name = "hashbrown" -version = "0.12.3" +version = "0.14.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" +checksum = "290f1a1d9242c78d09ce40a5e87e7554ee637af1351968159f4952f028f75604" [[package]] name = "heck" @@ -194,11 +189,11 @@ dependencies = [ [[package]] name = "indexmap" -version = "1.9.3" +version = "2.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bd070e393353796e801d209ad339e89596eb4c8d430d18ede6a1cced8fafbd99" +checksum = "d530e1a18b1cb4c484e6e34556a0d948706958449fca0cab753d649f2bce3d1f" dependencies = [ - "autocfg", + "equivalent", "hashbrown", ] @@ -210,12 +205,9 @@ checksum = "bfa799dd5ed20a7e349f3b4639aa80d74549c81716d9ec4f994c9b5815598306" [[package]] name = "inventory" -version = "0.3.6" +version = "0.3.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e0539b5de9241582ce6bd6b0ba7399313560151e58c9aaf8b74b711b1bdce644" -dependencies = [ - "ghost", -] +checksum = "c8573b2b1fb643a372c73b23f4da5f888677feef3305146d68a539250a9bccc7" [[package]] name = "itertools" @@ -228,9 +220,9 @@ dependencies = [ [[package]] name = "itoa" -version = "1.0.6" +version = "1.0.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "453ad9f582a441959e5f0d088b02ce04cfe8d51a8eaf077f12ac6d3e94164ca6" +checksum = "b1a46d1a171d865aa5f83f92695765caa047a9b4cbae2cbf37dbd613a793fd4c" [[package]] name = "lazy_static" @@ -240,15 +232,15 @@ checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" [[package]] name = "libc" -version = "0.2.146" +version = "0.2.152" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f92be4933c13fd498862a9e02a3055f8a8d9c039ce33db97306fd5a6caa7f29b" +checksum = "13e3bf6590cbc649f4d1a3eefc9d5d6eb746f5200ffb04e5e142700b8faa56e7" [[package]] name = "lock_api" -version = "0.4.10" +version = "0.4.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c1cc9717a20b1bb222f333e6a92fd32f7d8a18ddc5a3191a11af45dcbf4dcd16" +checksum = "3c168f8615b12bc01f9c17e2eb0cc07dcae1940121185446edc3744920e8ef45" dependencies = [ "autocfg", "scopeguard", @@ -256,9 +248,9 @@ dependencies = [ [[package]] name = "memchr" -version = "2.6.3" +version = "2.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f232d6ef707e1956a43342693d2a31e72989554d58299d7a88738cc95b0d35c" +checksum = "523dc4f511e55ab87b694dc30d0f820d60906ef06413f93d4d7a1385599cc149" [[package]] name = "memoffset" @@ -271,9 +263,9 @@ dependencies = [ [[package]] name = "num-bigint" -version = "0.4.3" +version = "0.4.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f93ab6289c7b344a8a9f60f88d80aa20032336fe78da341afc91c8a2341fc75f" +checksum = "608e7659b5c3d7cba262d894801b9ec9d00de989e8a82bd4bef91d08da45cdc0" dependencies = [ "autocfg", "num-integer", @@ -305,18 +297,18 @@ dependencies = [ [[package]] name = "num-traits" -version = "0.2.15" +version = "0.2.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "578ede34cf02f8924ab9447f50c28075b4d3e5b269972345e7e0372b38c6cdcd" +checksum = "39e3200413f237f41ab11ad6d161bc7239c84dcb631773ccd7de3dfe4b5c267c" dependencies = [ "autocfg", ] [[package]] name = "once_cell" -version = "1.18.0" +version = "1.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dd8b5dd2ae5ed71462c540258bedcb51965123ad7e7ccf4b9a8cafaa4a63576d" +checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" [[package]] name = "parking_lot" @@ -330,9 +322,9 @@ dependencies = [ [[package]] name = "parking_lot_core" -version = "0.9.8" +version = "0.9.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "93f00c865fe7cabf650081affecd3871070f26767e7b2070a3ffae14c654b447" +checksum = "4c42a9226546d68acdd9c0a280d17ce19bfe27a46bf68784e4066115788d008e" dependencies = [ "cfg-if", "libc", @@ -343,15 +335,15 @@ dependencies = [ [[package]] name = "paste" -version = "1.0.12" +version = "1.0.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9f746c4065a8fa3fe23974dd82f15431cc8d40779821001404d10d2e79ca7d79" +checksum = "de3145af08024dea9fa9914f381a17b8fc6034dfb00f3a84013f7ff43f29ed4c" [[package]] name = "petgraph" -version = "0.6.3" +version = "0.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4dd7d28ee937e54fe3080c91faa1c3a46c06de6252988a7f4592ba2310ef22a4" +checksum = "e1d3afd2628e69da2be385eb6f2fd57c8ac7977ceeff6dc166ff1657b0e386a9" dependencies = [ "fixedbitset", "indexmap", @@ -382,9 +374,9 @@ dependencies = [ [[package]] name = "pyo3" -version = "0.19.0" +version = "0.19.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cffef52f74ec3b1a1baf295d9b8fcc3070327aefc39a6d00656b13c1d0b8885c" +checksum = "e681a6cfdc4adcc93b4d3cf993749a4552018ee0a9b65fc0ccfad74352c72a38" dependencies = [ "cfg-if", "indoc", @@ -399,9 +391,9 @@ dependencies = [ [[package]] name = "pyo3-build-config" -version = "0.19.0" +version = "0.19.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "713eccf888fb05f1a96eb78c0dbc51907fee42b3377272dc902eb38985f418d5" +checksum = "076c73d0bc438f7a4ef6fdd0c3bb4732149136abd952b110ac93e4edb13a6ba5" dependencies = [ "once_cell", "target-lexicon", @@ -409,9 +401,9 @@ dependencies = [ [[package]] name = "pyo3-ffi" -version = "0.19.0" +version = "0.19.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5b2ecbdcfb01cbbf56e179ce969a048fd7305a66d4cdf3303e0da09d69afe4c3" +checksum = "e53cee42e77ebe256066ba8aa77eff722b3bb91f3419177cf4cd0f304d3284d9" dependencies = [ "libc", "pyo3-build-config", @@ -419,9 +411,9 @@ dependencies = [ [[package]] name = "pyo3-macros" -version = "0.19.0" +version = "0.19.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b78fdc0899f2ea781c463679b20cb08af9247febc8d052de941951024cd8aea0" +checksum = "dfeb4c99597e136528c6dd7d5e3de5434d1ceaf487436a3f03b2d56b6fc9efd1" dependencies = [ "proc-macro2", "pyo3-macros-backend", @@ -431,9 +423,9 @@ dependencies = [ [[package]] name = "pyo3-macros-backend" -version = "0.19.0" +version = "0.19.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "60da7b84f1227c3e2fe7593505de274dcf4c8928b4e0a1c23d551a14e4e80a0f" +checksum = "947dc12175c254889edc0c02e399476c2f652b4b9ebd123aa655c224de259536" dependencies = [ "proc-macro2", "quote", @@ -488,18 +480,18 @@ checksum = "dc33ff2d4973d518d823d61aa239014831e521c75da58e3df4840d3f47749d09" [[package]] name = "redox_syscall" -version = "0.3.5" +version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "567664f262709473930a4bf9e51bf2ebf3348f2e748ccc50dea20646858f8f29" +checksum = "4722d768eff46b75989dd134e5c353f0d6296e5aaa3132e776cbdb56be7731aa" dependencies = [ "bitflags", ] [[package]] name = "regex" -version = "1.9.5" +version = "1.10.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "697061221ea1b4a94a624f67d0ae2bfe4e22b8a17b6a192afb11046542cc8c47" +checksum = "380b951a9c5e80ddfd6136919eef32310721aa4aacd4889a8d39124b026ab343" dependencies = [ "aho-corasick", "memchr", @@ -509,9 +501,9 @@ dependencies = [ [[package]] name = "regex-automata" -version = "0.3.8" +version = "0.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c2f401f4955220693b56f8ec66ee9c78abffd8d1c4f23dc41a23839eb88f0795" +checksum = "5f804c7828047e88b2d32e2d7fe5a105da8ee3264f01902f796c8e067dc2483f" dependencies = [ "aho-corasick", "memchr", @@ -520,15 +512,15 @@ dependencies = [ [[package]] name = "regex-syntax" -version = "0.7.5" +version = "0.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dbb5fb1acd8a1a18b3dd5be62d25485eb770e05afb408a9627d14d451bae12da" +checksum = "c08c74e62047bb2de4ff487b251e4a92e24f48745648451635cec7d591162d9f" [[package]] name = "rmp" -version = "0.8.11" +version = "0.8.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "44519172358fd6d58656c86ab8e7fbc9e1490c3e8f14d35ed78ca0dd07403c9f" +checksum = "7f9860a6cc38ed1da53456442089b4dfa35e7cedaa326df63017af88385e6b20" dependencies = [ "byteorder", "num-traits", @@ -537,9 +529,9 @@ dependencies = [ [[package]] name = "rmp-serde" -version = "1.1.1" +version = "1.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c5b13be192e0220b8afb7222aa5813cb62cc269ebb5cac346ca6487681d2913e" +checksum = "bffea85eea980d8a74453e5d02a8d93028f3c34725de143085a844ebe953258a" dependencies = [ "byteorder", "rmp", @@ -563,21 +555,21 @@ checksum = "7ffc183a10b4478d04cbbbfc96d0873219d962dd5accaff2ffbd4ceb7df837f4" [[package]] name = "ryu" -version = "1.0.13" +version = "1.0.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f91339c0467de62360649f8d3e185ca8de4224ff281f66000de5eb2a77a79041" +checksum = "f98d2aa92eebf49b69786be48e4477826b256916e84a57ff2a4f21923b48eb4c" [[package]] name = "scopeguard" -version = "1.1.0" +version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd" +checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" [[package]] name = "semver" -version = "1.0.17" +version = "1.0.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bebd363326d05ec3e2f532ab7660680f3b02130d780c299bca73469d521bc0ed" +checksum = "b97ed7a9823b74f99c7742f5336af7be5ecd3eeafcb1507d1fa93347b1d589b0" [[package]] name = "serde" @@ -612,9 +604,9 @@ dependencies = [ [[package]] name = "serde_yaml" -version = "0.9.21" +version = "0.9.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d9d684e3ec7de3bf5466b32bd75303ac16f0736426e5a4e0d6e489559ce1249c" +checksum = "b1bf28c79a99f70ee1f1d83d10c875d2e70618417fda01ad1785e027579d9d38" dependencies = [ "indexmap", "itoa", @@ -625,15 +617,15 @@ dependencies = [ [[package]] name = "smallvec" -version = "1.10.0" +version = "1.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a507befe795404456341dfab10cef66ead4c041f62b8b11bbb92bffe5d0953e0" +checksum = "2593d31f82ead8df961d8bd23a64c2ccf2eb5dd34b0a34bfb4dd54011c72009e" [[package]] name = "smol_str" -version = "0.2.0" +version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "74212e6bbe9a4352329b2f68ba3130c15a3f26fe88ff22dbdc6cdd58fa85e99c" +checksum = "e6845563ada680337a52d43bb0b29f396f2d911616f6573012645b9e3d048a49" dependencies = [ "serde", ] @@ -687,24 +679,24 @@ checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369" [[package]] name = "target-lexicon" -version = "0.12.8" +version = "0.12.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1b1c7f239eb94671427157bd93b3694320f3668d4e1eff08c7285366fd777fac" +checksum = "69758bda2e78f098e4ccb393021a0963bb3442eac05f135c30f61b7370bbafae" [[package]] name = "thiserror" -version = "1.0.40" +version = "1.0.56" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "978c9a314bd8dc99be594bc3c175faaa9794be04a5a5e153caba6915336cebac" +checksum = "d54378c645627613241d077a3a79db965db602882668f9136ac42af9ecb730ad" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.40" +version = "1.0.56" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f9456a42c5b0d803c8cd86e73dd7cc9edd429499f37a3550d286d5e86720569f" +checksum = "fa0faa943b50f3db30a20aa7e265dbc66076993efed8463e8de414e5d06d3471" dependencies = [ "proc-macro2", "quote", @@ -713,9 +705,9 @@ dependencies = [ [[package]] name = "typetag" -version = "0.2.8" +version = "0.2.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6a6898cc6f6a32698cc3e14d5632a14d2b23ed9f7b11e6b8e05ce685990acc22" +checksum = "c43148481c7b66502c48f35b8eef38b6ccdc7a9f04bd4cc294226d901ccc9bc7" dependencies = [ "erased-serde", "inventory", @@ -726,9 +718,9 @@ dependencies = [ [[package]] name = "typetag-impl" -version = "0.2.8" +version = "0.2.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2c3e1c30cedd24fc597f7d37a721efdbdc2b1acae012c1ef1218f4c7c2c0f3e7" +checksum = "291db8a81af4840c10d636e047cac67664e343be44e24dfdbd1492df9a5d3390" dependencies = [ "proc-macro2", "quote", @@ -737,9 +729,9 @@ dependencies = [ [[package]] name = "unicode-ident" -version = "1.0.9" +version = "1.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b15811caf2415fb889178633e7724bad2509101cde276048e013b9def5e51fa0" +checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" [[package]] name = "unindent" @@ -749,19 +741,19 @@ checksum = "e1766d682d402817b5ac4490b3c3002d91dfa0d22812f341609f97b08757359c" [[package]] name = "unsafe-libyaml" -version = "0.2.8" +version = "0.2.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1865806a559042e51ab5414598446a5871b561d21b6764f2eabb0dd481d880a6" +checksum = "ab4c90930b95a82d00dc9e9ac071b4991924390d46cbd0dfe566148667605e4b" [[package]] name = "utf8-width" -version = "0.1.6" +version = "0.1.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5190c9442dcdaf0ddd50f37420417d219ae5261bbf5db120d0f9bab996c9cba1" +checksum = "86bd8d4e895da8537e5315b8254664e6b769c4ff3db18321b297a1e7004392e3" [[package]] name = "validator" -version = "0.1.0" +version = "0.2.0" dependencies = [ "lazy_static", "pyo3", @@ -772,9 +764,9 @@ dependencies = [ [[package]] name = "windows-targets" -version = "0.48.0" +version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7b1eb6f0cd7c80c79759c929114ef071b87354ce476d9d94271031c0497adfd5" +checksum = "9a2fa6e2155d7247be68c096456083145c183cbbbc2764150dda45a87197940c" dependencies = [ "windows_aarch64_gnullvm", "windows_aarch64_msvc", @@ -787,45 +779,45 @@ dependencies = [ [[package]] name = "windows_aarch64_gnullvm" -version = "0.48.0" +version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "91ae572e1b79dba883e0d315474df7305d12f569b400fcf90581b06062f7e1bc" +checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8" [[package]] name = "windows_aarch64_msvc" -version = "0.48.0" +version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b2ef27e0d7bdfcfc7b868b317c1d32c641a6fe4629c171b8928c7b08d98d7cf3" +checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc" [[package]] name = "windows_i686_gnu" -version = "0.48.0" +version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "622a1962a7db830d6fd0a69683c80a18fda201879f0f447f065a3b7467daa241" +checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e" [[package]] name = "windows_i686_msvc" -version = "0.48.0" +version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4542c6e364ce21bf45d69fdd2a8e455fa38d316158cfd43b3ac1c5b1b19f8e00" +checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406" [[package]] name = "windows_x86_64_gnu" -version = "0.48.0" +version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ca2b8a661f7628cbd23440e50b05d705db3686f894fc9580820623656af974b1" +checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e" [[package]] name = "windows_x86_64_gnullvm" -version = "0.48.0" +version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7896dbc1f41e08872e9d5e8f8baa8fdd2677f29468c4e156210174edc7f7b953" +checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc" [[package]] name = "windows_x86_64_msvc" -version = "0.48.0" +version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1a515f5799fe4961cb532f983ce2b23082366b898e52ffbce459c86f67c8378a" +checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" [[package]] name = "wyz" diff --git a/validator/Cargo.toml b/validator/Cargo.toml index 305f1d2f..de2e314a 100644 --- a/validator/Cargo.toml +++ b/validator/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "validator" -version = "0.1.0" +version = "0.2.0" edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html diff --git a/validator/src/lib.rs b/validator/src/lib.rs index a809cd44..57f561b7 100644 --- a/validator/src/lib.rs +++ b/validator/src/lib.rs @@ -1,5 +1,6 @@ use hugr::extension::{ExtensionRegistry, PRELUDE}; use hugr::std_extensions::arithmetic::{float_ops, float_types, int_ops, int_types}; +use hugr::std_extensions::collections; use hugr::std_extensions::logic; use lazy_static::lazy_static; use pyo3::prelude::*; @@ -12,6 +13,7 @@ lazy_static! { int_ops::EXTENSION.to_owned(), float_types::EXTENSION.to_owned(), float_ops::EXTENSION.to_owned(), + collections::EXTENSION.to_owned(), ]) .unwrap(); }