diff --git a/guppylang/cfg/bb.py b/guppylang/cfg/bb.py index 6bb6b75a..776ff333 100644 --- a/guppylang/cfg/bb.py +++ b/guppylang/cfg/bb.py @@ -142,13 +142,15 @@ def _handle_assign_target(self, lhs: ast.expr, node: ast.stmt) -> None: match lhs: case ast.Name(id=name): self.stats.assigned[name] = node - case ast.Tuple(elts=elts): + case ast.Tuple(elts=elts) | ast.List(elts=elts): for elt in elts: self._handle_assign_target(elt, node) case ast.Attribute(value=value): # Setting attributes counts as a use of the value, so we do a regular # visit instead of treating it like a LHS self.visit(value) + case ast.Starred(value=value): + self._handle_assign_target(value, node) def visit_DesugaredListComp(self, node: DesugaredListComp) -> None: self._handle_comprehension(node.generators, node.elt) diff --git a/guppylang/cfg/builder.py b/guppylang/cfg/builder.py index 29316a0e..87715c27 100644 --- a/guppylang/cfg/builder.py +++ b/guppylang/cfg/builder.py @@ -316,56 +316,13 @@ def visit_IfExp(self, node: ast.IfExp) -> ast.Name: def visit_ListComp(self, node: ast.ListComp) -> DesugaredListComp: check_lists_enabled(node) - generators, elt = self._build_comprehension(node.generators, node.elt, node) + generators, elt = desugar_comprehension(node.generators, node.elt, node) return with_loc(node, DesugaredListComp(elt=elt, generators=generators)) def visit_GeneratorExp(self, node: ast.GeneratorExp) -> DesugaredGeneratorExpr: - generators, elt = self._build_comprehension(node.generators, node.elt, node) + generators, elt = desugar_comprehension(node.generators, node.elt, node) return with_loc(node, DesugaredGeneratorExpr(elt=elt, generators=generators)) - def _build_comprehension( - self, generators: list[ast.comprehension], elt: ast.expr, node: ast.AST - ) -> tuple[list[DesugaredGenerator], ast.expr]: - # Check for illegal expressions - illegals = find_nodes(is_illegal_in_list_comp, node) - if illegals: - err = UnsupportedError( - illegals[0], - "This expression", - singular=True, - 
unsupported_in="a list comprehension", - ) - raise GuppyError(err) - - # Desugar into statements that create the iterator, check for a next element, - # get the next element, and finalise the iterator. - gens = [] - for g in generators: - if g.is_async: - raise GuppyError(UnsupportedError(g, "Async generators")) - g.iter = self.visit(g.iter) - it = make_var(next(tmp_vars), g.iter) - hasnext = make_var(next(tmp_vars), g.iter) - desugared = DesugaredGenerator( - iter=it, - hasnext=hasnext, - iter_assign=make_assign( - [it], with_loc(it, MakeIter(value=g.iter, origin_node=node)) - ), - hasnext_assign=make_assign( - [hasnext, it], with_loc(it, IterHasNext(value=it)) - ), - next_assign=make_assign( - [g.target, it], with_loc(it, IterNext(value=it)) - ), - iterend=with_loc(it, IterEnd(value=it)), - ifs=g.ifs, - ) - gens.append(desugared) - - elt = self.visit(elt) - return gens, elt - def visit_Call(self, node: ast.Call) -> ast.AST: return is_py_expression(node) or self.generic_visit(node) @@ -487,6 +444,56 @@ def generic_visit(self, node: ast.expr, bb: BB, true_bb: BB, false_bb: BB) -> No self.cfg.link(bb, true_bb) +def desugar_comprehension( + generators: list[ast.comprehension], elt: ast.expr, node: ast.AST +) -> tuple[list[DesugaredGenerator], ast.expr]: + """Helper function to desugar a comprehension node.""" + # Check for illegal expressions + illegals = find_nodes(is_illegal_in_list_comp, node) + if illegals: + err = UnsupportedError( + illegals[0], + "This expression", + singular=True, + unsupported_in="a list comprehension", + ) + raise GuppyError(err) + + # The check above ensures that the comprehension doesn't contain any control-flow + # expressions. Thus, we can use a dummy `ExprBuilder` to desugar the insides. 
+ # TODO: Refactor so that desugaring is separate from control-flow building + dummy_cfg = CFG() + builder = ExprBuilder(dummy_cfg, dummy_cfg.entry_bb) + + # Desugar into statements that create the iterator, check for a next element, + # get the next element, and finalise the iterator. + gens = [] + for g in generators: + if g.is_async: + raise GuppyError(UnsupportedError(g, "Async generators")) + g.iter = builder.visit(g.iter) + it = make_var(next(tmp_vars), g.iter) + hasnext = make_var(next(tmp_vars), g.iter) + desugared = DesugaredGenerator( + iter=it, + hasnext=hasnext, + iter_assign=make_assign( + [it], with_loc(it, MakeIter(value=g.iter, origin_node=node)) + ), + hasnext_assign=make_assign( + [hasnext, it], with_loc(it, IterHasNext(value=it)) + ), + next_assign=make_assign([g.target, it], with_loc(it, IterNext(value=it))), + iterend=with_loc(it, IterEnd(value=it)), + ifs=g.ifs, + borrowed_outer_places=[], + ) + gens.append(desugared) + + elt = builder.visit(elt) + return gens, elt + + def is_functional_annotation(stmt: ast.stmt) -> bool: """Returns `True` iff the given statement is the functional pseudo-decorator. 
diff --git a/guppylang/checker/errors/type_errors.py b/guppylang/checker/errors/type_errors.py index 3aa5d406..df7d771b 100644 --- a/guppylang/checker/errors/type_errors.py +++ b/guppylang/checker/errors/type_errors.py @@ -190,6 +190,7 @@ class WrongNumberOfUnpacksError(Error): title: ClassVar[str] = "{prefix} values to unpack" expected: int actual: int + at_least: bool @property def prefix(self) -> str: @@ -200,9 +201,47 @@ def rendered_span_label(self) -> str: diff = self.expected - self.actual if diff < 0: msg = "Unexpected assignment " + ("targets" if diff < -1 else "target") + at_least = "at least " if self.at_least else "" else: msg = "Not enough assignment targets" - return f"{msg} (expected {self.expected}, got {self.actual})" + assert not self.at_least + at_least = "" + return f"{msg} (expected {self.expected}, got {at_least}{self.actual})" + + +@dataclass(frozen=True) +class UnpackableError(Error): + title: ClassVar[str] = "Unpackable" + span_label: ClassVar[str] = "Expression of type `{ty}` cannot be unpacked" + ty: Type + + @dataclass(frozen=True) + class NonStaticIter(Note): + message: ClassVar[str] = ( + "Unpacking of iterable types like `{ty}` is only allowed if the number of " + "items yielded by the iterator is statically known. This is not the case " + "for `{ty}`." + ) + + @dataclass(frozen=True) + class GenericSize(Note): + message: ClassVar[str] = ( + "Unpacking of iterable types like `{ty}` is only allowed if the number of " + "items yielded by the iterator is statically known. Here, the number of " + "items `{num}` is generic and can change between different function " + "invocations." 
+ ) + num: Const + + +@dataclass(frozen=True) +class StarredTupleUnpackError(Error): + title: ClassVar[str] = "Invalid starred unpacking" + span_label: ClassVar[str] = ( + "Expression of type `{ty}` cannot be collected into a starred assignment since " + "the yielded items have different types" + ) + ty: Type @dataclass(frozen=True) diff --git a/guppylang/checker/expr_checker.py b/guppylang/checker/expr_checker.py index 3d23424a..d6753e30 100644 --- a/guppylang/checker/expr_checker.py +++ b/guppylang/checker/expr_checker.py @@ -1091,15 +1091,35 @@ def synthesize_comprehension( node: AstNode, gens: list[DesugaredGenerator], elt: ast.expr, ctx: Context ) -> tuple[list[DesugaredGenerator], ast.expr, Type]: """Helper function to synthesise the element type of a list comprehension.""" - from guppylang.checker.stmt_checker import StmtChecker - # If there are no more generators left, we can check the list element if not gens: elt, elt_ty = ExprSynthesizer(ctx).synthesize(elt) return gens, elt, elt_ty - # Check the iterator in the outer context + # Check the first generator gen, *gens = gens + gen, inner_ctx = check_generator(gen, ctx) + + # Check remaining generators in inner context + gens, elt, elt_ty = synthesize_comprehension(node, gens, elt, inner_ctx) + + # The iter finalizer is again checked in the outer context + gen.iterend, iterend_ty = ExprSynthesizer(ctx).synthesize(gen.iterend) + gen.iterend = with_type(iterend_ty, gen.iterend) + return [gen, *gens], elt, elt_ty + + +def check_generator( + gen: DesugaredGenerator, ctx: Context +) -> tuple[DesugaredGenerator, Context]: + """Helper function to check a single generator. + + Returns the type annotated generator together with a new nested context in which the + generator variables are bound. 
+ """ + from guppylang.checker.stmt_checker import StmtChecker + + # Check the iterator in the outer context gen.iter_assign = StmtChecker(ctx).visit_Assign(gen.iter_assign) # The rest is checked in a new nested context to ensure that variables don't escape @@ -1119,13 +1139,7 @@ def synthesize_comprehension( gen.ifs[i], if_ty = expr_sth.synthesize(gen.ifs[i]) gen.ifs[i], _ = to_bool(gen.ifs[i], if_ty, inner_ctx) - # Check remaining generators - gens, elt, elt_ty = synthesize_comprehension(node, gens, elt, inner_ctx) - - # The iter finalizer is again checked in the outer context - gen.iterend, iterend_ty = ExprSynthesizer(ctx).synthesize(gen.iterend) - gen.iterend = with_type(iterend_ty, gen.iterend) - return [gen, *gens], elt, elt_ty + return gen, inner_ctx def eval_py_expr(node: PyExpr, ctx: Context) -> Any: diff --git a/guppylang/checker/stmt_checker.py b/guppylang/checker/stmt_checker.py index 177ba517..68269f80 100644 --- a/guppylang/checker/stmt_checker.py +++ b/guppylang/checker/stmt_checker.py @@ -9,10 +9,24 @@ """ import ast -from collections.abc import Sequence +import functools +from collections.abc import Iterable, Sequence +from dataclasses import replace +from itertools import takewhile +from typing import TypeVar, cast -from guppylang.ast_util import AstVisitor, with_loc, with_type +from guppylang.ast_util import ( + AstVisitor, + get_type, + with_loc, + with_type, +) from guppylang.cfg.bb import BB, BBStatement +from guppylang.cfg.builder import ( + desugar_comprehension, + make_var, + tmp_vars, +) from guppylang.checker.core import Context, FieldAccess, Variable from guppylang.checker.errors.generic import UnsupportedError from guppylang.checker.errors.type_errors import ( @@ -20,15 +34,44 @@ AssignNonPlaceHelp, AttributeNotFoundError, MissingReturnValueError, + StarredTupleUnpackError, + TypeInferenceError, + UnpackableError, WrongNumberOfUnpacksError, ) -from guppylang.checker.expr_checker import ExprChecker, ExprSynthesizer +from 
guppylang.checker.expr_checker import ( + ExprChecker, + ExprSynthesizer, + synthesize_comprehension, +) from guppylang.error import GuppyError, GuppyTypeError, InternalGuppyError -from guppylang.nodes import NestedFunctionDef, PlaceNode +from guppylang.nodes import ( + AnyUnpack, + DesugaredArrayComp, + IterableUnpack, + MakeIter, + NestedFunctionDef, + PlaceNode, + TupleUnpack, + UnpackPattern, +) from guppylang.span import Span, to_span +from guppylang.tys.builtin import ( + array_type, + get_iter_size, + is_sized_iter_type, + nat_type, +) +from guppylang.tys.const import ConstValue from guppylang.tys.parsing import type_from_ast from guppylang.tys.subst import Subst -from guppylang.tys.ty import NoneType, StructType, TupleType, Type +from guppylang.tys.ty import ( + ExistentialTypeVar, + NoneType, + StructType, + TupleType, + Type, +) class StmtChecker(AstVisitor[BBStatement]): @@ -55,81 +98,176 @@ def _check_expr( ) -> tuple[ast.expr, Subst]: return ExprChecker(self.ctx).check(node, ty, kind) - def _check_assign(self, lhs: ast.expr, ty: Type, node: ast.stmt) -> ast.expr: + @functools.singledispatchmethod + def _check_assign(self, lhs: ast.expr, rhs: ast.expr, rhs_ty: Type) -> ast.expr: """Helper function to check assignments with patterns.""" - match lhs: - # Easiest case is if the LHS pattern is a single variable. - case ast.Name(id=x): - var = Variable(x, ty, lhs) - self.ctx.locals[x] = var - return with_loc(lhs, with_type(ty, PlaceNode(place=var))) - - # The LHS could also be a field `expr.field` - case ast.Attribute(value=value, attr=attr): - # Unfortunately, the `attr` is just a string, not an AST node, so we - # have to compute its span by hand. 
This is fine since linebreaks are - # not allowed in the identifier following the `.` - span = to_span(lhs) - attr_span = Span(span.end.shift_left(len(attr)), span.end) - value, struct_ty = self._synth_expr(value) - if ( - not isinstance(struct_ty, StructType) - or attr not in struct_ty.field_dict - ): - raise GuppyTypeError( - AttributeNotFoundError(attr_span, struct_ty, attr) - ) - field = struct_ty.field_dict[attr] - # TODO: In the future, we could infer some type args here - if field.ty != ty: - # TODO: Get hold of a span for the RHS and use a regular - # `TypeMismatchError` instead (maybe with a custom hint). - raise GuppyTypeError( - AssignFieldTypeMismatchError(attr_span, ty, field) - ) - if not isinstance(value, PlaceNode): - # For now we complain if someone tries to assign to something that - # is not a place, e.g. `f().a = 4`. This would only make sense if - # there is another reference to the return value of `f`, otherwise - # the mutation cannot be observed. We can start supporting this once - # we have proper reference semantics. 
- err = UnsupportedError( - value, "Assigning to this expression", singular=True - ) - err.add_sub_diagnostic(AssignNonPlaceHelp(None, field)) - raise GuppyError(err) - if not field.ty.linear: - raise GuppyError( - UnsupportedError( - attr_span, "Mutation of classical fields", singular=True - ) - ) - place = FieldAccess(value.place, struct_ty.field_dict[attr], lhs) - return with_loc(lhs, with_type(ty, PlaceNode(place=place))) - - # The only other thing we support right now are tuples - case ast.Tuple(elts=elts) as lhs: - tys = ty.element_types if isinstance(ty, TupleType) else [ty] - n, m = len(elts), len(tys) - if n != m: - if n > m: - span = Span(to_span(elts[m]).start, to_span(elts[-1]).end) - else: - span = to_span(lhs) - raise GuppyTypeError(WrongNumberOfUnpacksError(span, m, n)) - lhs.elts = [ - self._check_assign(pat, el_ty, node) - for pat, el_ty in zip(elts, tys, strict=True) - ] - return with_type(ty, lhs) - - # TODO: Python also supports assignments like `[a, b] = [1, 2]` or - # `a, *b = ...`. The former would require some runtime checks but - # the latter should be easier to do (unpack and repack the rest). - case _: - raise GuppyError( - UnsupportedError(lhs, "This assignment pattern", singular=True) + raise InternalGuppyError("Unexpected assignment pattern") + + @_check_assign.register + def _check_variable_assign( + self, lhs: ast.Name, _rhs: ast.expr, rhs_ty: Type + ) -> PlaceNode: + x = lhs.id + var = Variable(x, rhs_ty, lhs) + self.ctx.locals[x] = var + return with_loc(lhs, with_type(rhs_ty, PlaceNode(place=var))) + + @_check_assign.register + def _check_field_assign( + self, lhs: ast.Attribute, _rhs: ast.expr, rhs_ty: Type + ) -> PlaceNode: + # Unfortunately, the `attr` is just a string, not an AST node, so we + # have to compute its span by hand. 
This is fine since linebreaks are + # not allowed in the identifier following the `.` + span = to_span(lhs) + value, attr = lhs.value, lhs.attr + attr_span = Span(span.end.shift_left(len(attr)), span.end) + value, struct_ty = self._synth_expr(value) + if not isinstance(struct_ty, StructType) or attr not in struct_ty.field_dict: + raise GuppyTypeError(AttributeNotFoundError(attr_span, struct_ty, attr)) + field = struct_ty.field_dict[attr] + # TODO: In the future, we could infer some type args here + if field.ty != rhs_ty: + # TODO: Get hold of a span for the RHS and use a regular `TypeMismatchError` + # instead (maybe with a custom hint). + raise GuppyTypeError(AssignFieldTypeMismatchError(attr_span, rhs_ty, field)) + if not isinstance(value, PlaceNode): + # For now we complain if someone tries to assign to something that is not a + # place, e.g. `f().a = 4`. This would only make sense if there is another + # reference to the return value of `f`, otherwise the mutation cannot be + # observed. We can start supporting this once we have proper reference + # semantics. 
+ err = UnsupportedError(value, "Assigning to this expression", singular=True) + err.add_sub_diagnostic(AssignNonPlaceHelp(None, field)) + raise GuppyError(err) + if not field.ty.linear: + raise GuppyError( + UnsupportedError( + attr_span, "Mutation of classical fields", singular=True ) + ) + place = FieldAccess(value.place, struct_ty.field_dict[attr], lhs) + return with_loc(lhs, with_type(rhs_ty, PlaceNode(place=place))) + + @_check_assign.register + def _check_tuple_assign( + self, lhs: ast.Tuple, rhs: ast.expr, rhs_ty: Type + ) -> AnyUnpack: + return self._check_unpack_assign(lhs, rhs, rhs_ty) + + @_check_assign.register + def _check_list_assign( + self, lhs: ast.List, rhs: ast.expr, rhs_ty: Type + ) -> AnyUnpack: + return self._check_unpack_assign(lhs, rhs, rhs_ty) + + def _check_unpack_assign( + self, lhs: ast.Tuple | ast.List, rhs: ast.expr, rhs_ty: Type + ) -> AnyUnpack: + """Helper function to check unpacking assignments. + + These are the ones where the LHS is either a tuple or a list. 
+ """ + # Parse LHS into `left, *starred, right` + pattern = parse_unpack_pattern(lhs) + left, starred, right = pattern.left, pattern.starred, pattern.right + # Check that the RHS has an appropriate type to be unpacked + unpack, rhs_elts, rhs_tys = self._check_unpackable(rhs, rhs_ty, pattern) + + # Check that the numbers match up on the LHS and RHS + num_lhs, num_rhs = len(right) + len(left), len(rhs_tys) + err = WrongNumberOfUnpacksError( + lhs, num_rhs, num_lhs, at_least=starred is not None + ) + if num_lhs > num_rhs: + # Build span that covers the unexpected elts on the LHS + span = Span(to_span(lhs.elts[num_rhs]).start, to_span(lhs.elts[-1]).end) + raise GuppyTypeError(replace(err, span=span)) + elif num_lhs < num_rhs and not starred: + raise GuppyTypeError(err) + + # Recursively check any nested patterns on the left or right + le, rs = len(left), len(rhs_elts) - len(right) # left_end, right_start + unpack.pattern.left = [ + self._check_assign(pat, elt, ty) + for pat, elt, ty in zip(left, rhs_elts[:le], rhs_tys[:le], strict=True) + ] + unpack.pattern.right = [ + self._check_assign(pat, elt, ty) + for pat, elt, ty in zip(right, rhs_elts[rs:], rhs_tys[rs:], strict=True) + ] + + # Starred assignments are collected into an array + if starred: + starred_tys = rhs_tys[le:rs] + assert all_equal(starred_tys) + if starred_tys: + starred_ty, *_ = starred_tys + # Starred part could be empty. 
If it's an iterable unpack, we're still fine + # since we know the yielded type + elif isinstance(unpack, IterableUnpack): + starred_ty = unpack.compr.elt_ty + # For tuple unpacks, there is no way to infer a type for the empty starred + # part + else: + unsolved = array_type(ExistentialTypeVar.fresh("T", False), 0) + raise GuppyError(TypeInferenceError(starred, unsolved)) + array_ty = array_type(starred_ty, len(starred_tys)) + unpack.pattern.starred = self._check_assign(starred, rhs_elts[0], array_ty) + + return with_type(rhs_ty, with_loc(lhs, unpack)) + + def _check_unpackable( + self, expr: ast.expr, ty: Type, pattern: UnpackPattern + ) -> tuple[AnyUnpack, list[ast.expr], Sequence[Type]]: + """Checks that the given expression can be used in an unpacking assignment. + + This is the case for expressions with tuple types or ones that are iterable with + a static size. Also checks that the expression is compatible with the given + unpacking pattern. + + Returns an AST node capturing the unpacking operation together with expressions + and types for all unpacked items. Emits a user error if the given expression is + not unpackable. + """ + left, starred, right = pattern.left, pattern.starred, pattern.right + if isinstance(ty, TupleType): + # Starred assignment of tuples is only allowed if all starred elements have + # the same type + if starred: + starred_tys = ( + ty.element_types[len(left) : -len(right)] + if right + else ty.element_types[len(left) :] + ) + if not all_equal(starred_tys): + tuple_ty = TupleType(starred_tys) + raise GuppyError(StarredTupleUnpackError(starred, tuple_ty)) + tys = ty.element_types + elts = expr.elts if isinstance(expr, ast.Tuple) else [expr] * len(tys) + return TupleUnpack(pattern), elts, tys + + elif self.ctx.globals.get_instance_func(ty, "__iter__"): + size = check_iter_unpack_has_static_size(expr, self.ctx) + # Create a dummy variable and assign the expression to it. This helps us to + # wire it up correctly during Hugr generation. 
+ var = self._check_assign(make_var(next(tmp_vars), expr), expr, ty) + assert isinstance(var, PlaceNode) + # We collect the whole RHS into an array. For this, we can reuse the + # existing array comprehension logic. + elt = make_var(next(tmp_vars), expr) + gen = ast.comprehension(target=elt, iter=var, ifs=[], is_async=False) + [gen], elt = desugar_comprehension([with_loc(expr, gen)], elt, expr) + # Type check the comprehension + [gen], elt, elt_ty = synthesize_comprehension(expr, [gen], elt, self.ctx) + compr = DesugaredArrayComp( + elt, gen, length=ConstValue(nat_type(), size), elt_ty=elt_ty + ) + compr = with_type(array_type(elt_ty, size), compr) + return IterableUnpack(pattern, compr, var), size * [elt], size * [elt_ty] + + # Otherwise, we can't unpack this expression + raise GuppyError(UnpackableError(expr, ty)) def visit_Assign(self, node: ast.Assign) -> ast.Assign: if len(node.targets) > 1: @@ -138,7 +276,7 @@ def visit_Assign(self, node: ast.Assign) -> ast.Assign: [target] = node.targets node.value, ty = self._synth_expr(node.value) - node.targets = [self._check_assign(target, ty, node)] + node.targets = [self._check_assign(target, node.value, ty)] return node def visit_AnnAssign(self, node: ast.AnnAssign) -> ast.stmt: @@ -148,7 +286,7 @@ def visit_AnnAssign(self, node: ast.AnnAssign) -> ast.stmt: node.value, subst = self._check_expr(node.value, ty) assert not ty.unsolved_vars # `ty` must be closed! 
assert len(subst) == 0 - target = self._check_assign(node.target, ty, node) + target = self._check_assign(node.target, node.value, ty) return with_loc(node, ast.Assign(targets=[target], value=node.value)) def visit_AugAssign(self, node: ast.AugAssign) -> ast.stmt: @@ -197,3 +335,56 @@ def visit_Break(self, node: ast.Break) -> None: def visit_Continue(self, node: ast.Continue) -> None: raise InternalGuppyError("Control-flow statement should not be present here.") + + +T = TypeVar("T") + + +def all_equal(xs: Iterable[T]) -> bool: + """Checks if all elements yielded from an iterable are equal.""" + it = iter(xs) + try: + first = next(it) + except StopIteration: + return True + return all(first == x for x in it) + + +def parse_unpack_pattern(lhs: ast.Tuple | ast.List) -> UnpackPattern: + """Parses the LHS of an unpacking assignment like `a, *bs, c = ...` or + `[a, *bs, c] = ...`.""" + # Split up LHS into `left, *starred, right` (the Python grammar ensures + # that there is at most one starred expression) + left = list(takewhile(lambda e: not isinstance(e, ast.Starred), lhs.elts)) + starred = ( + cast(ast.Starred, lhs.elts[len(left)]).value + if len(left) < len(lhs.elts) + else None + ) + right = lhs.elts[len(left) + 1 :] + assert isinstance(starred, ast.Name | None), "Python grammar" + return UnpackPattern(left, starred, right) + + +def check_iter_unpack_has_static_size(expr: ast.expr, ctx: Context) -> int: + """Helper function to check that an iterable expression is suitable to be unpacked + in an assignment. + + This is the case if the iterator has a static, non-generic size. + + Returns the size of the iterator or emits a user error if the iterable is not + suitable. 
+ """ + expr_synth = ExprSynthesizer(ctx) + make_iter = with_loc(expr, MakeIter(expr, expr, unwrap_size_hint=False)) + make_iter, iter_ty = expr_synth.visit_MakeIter(make_iter) + err = UnpackableError(expr, get_type(expr)) + if not is_sized_iter_type(iter_ty): + err.add_sub_diagnostic(UnpackableError.NonStaticIter(None)) + raise GuppyError(err) + match get_iter_size(iter_ty): + case ConstValue(value=int(size)): + return size + case generic_size: + err.add_sub_diagnostic(UnpackableError.GenericSize(None, generic_size)) + raise GuppyError(err) diff --git a/guppylang/compiler/stmt_compiler.py b/guppylang/compiler/stmt_compiler.py index addcb8cb..6f92854e 100644 --- a/guppylang/compiler/stmt_compiler.py +++ b/guppylang/compiler/stmt_compiler.py @@ -1,6 +1,8 @@ import ast +import functools from collections.abc import Sequence +import hugr.tys as ht from hugr import Wire, ops from hugr.build.dfg import DfBase @@ -14,8 +16,21 @@ ) from guppylang.compiler.expr_compiler import ExprCompiler from guppylang.error import InternalGuppyError -from guppylang.nodes import CheckedNestedFunctionDef, PlaceNode -from guppylang.tys.ty import TupleType, Type +from guppylang.nodes import ( + CheckedNestedFunctionDef, + IterableUnpack, + PlaceNode, + TupleUnpack, +) +from guppylang.std._internal.compiler.array import ( + array_discard_empty, + array_new, + array_pop, +) +from guppylang.std._internal.compiler.prelude import build_unwrap +from guppylang.tys.builtin import get_element_type +from guppylang.tys.const import ConstValue +from guppylang.tys.ty import TupleType, Type, type_to_row class StmtCompiler(CompilerBase, AstVisitor[None]): @@ -49,27 +64,86 @@ def builder(self) -> DfBase[ops.DfParentOp]: """The Hugr dataflow graph builder.""" return self.dfg.builder - def _unpack_assign(self, lhs: ast.expr, port: Wire, node: ast.stmt) -> None: + @functools.singledispatchmethod + def _assign(self, lhs: ast.expr, port: Wire) -> None: """Updates the local DFG with assignments.""" - if 
isinstance(lhs, PlaceNode): - self.dfg[lhs.place] = port - elif isinstance(lhs, ast.Tuple): - types = [get_type(e).to_hugr() for e in lhs.elts] - unpack = self.builder.add_op(ops.UnpackTuple(types), port) - for pat, wire in zip(lhs.elts, unpack, strict=True): - self._unpack_assign(pat, wire, node) + raise InternalGuppyError("Invalid assign pattern in compiler") + + @_assign.register + def _assign_place(self, lhs: PlaceNode, port: Wire) -> None: + self.dfg[lhs.place] = port + + @_assign.register + def _assign_tuple(self, lhs: TupleUnpack, port: Wire) -> None: + """Handles assignment where the RHS is a tuple that should be unpacked.""" + # Unpack the RHS tuple + left, starred, right = lhs.pattern.left, lhs.pattern.starred, lhs.pattern.right + types = [ty.to_hugr() for ty in type_to_row(get_type(lhs))] + unpack = self.builder.add_op(ops.UnpackTuple(types), port) + ports = list(unpack) + + # Assign left and right + for pat, wire in zip(left, ports[: len(left)], strict=True): + self._assign(pat, wire) + if right: + for pat, wire in zip(right, ports[-len(right) :], strict=True): + self._assign(pat, wire) + + # Starred assignments are collected into an array + if starred: + array_ty = get_type(starred) + starred_ports = ( + ports[len(left) : -len(right)] if right else ports[len(left) :] + ) + opt_ty = ht.Option(get_element_type(array_ty).to_hugr()) + opts = [self.builder.add_op(ops.Tag(1, opt_ty), p) for p in starred_ports] + array = self.builder.add_op(array_new(opt_ty, len(opts)), *opts) + self._assign(starred, array) + + @_assign.register + def _assign_iterable(self, lhs: IterableUnpack, port: Wire) -> None: + """Handles assignment where the RHS is an iterable that should be unpacked.""" + # Given an assignment pattern `left, *starred, right`, collect the RHS into an + # array and pop from the left and right, leaving us with the starred array in + # the middle + assert isinstance(lhs.compr.length, ConstValue) + length = lhs.compr.length.value + assert 
isinstance(length, int) + opt_elt_ty = ht.Option(lhs.compr.elt_ty.to_hugr()) + + def pop( + array: Wire, length: int, pats: list[ast.expr], from_left: bool + ) -> tuple[Wire, int]: + err = "Internal error: unpacking of iterable failed" + for pat in pats: + res = self.builder.add_op( + array_pop(opt_elt_ty, length, from_left), array + ) + [elt_opt, array] = build_unwrap(self.builder, res, err) + [elt] = build_unwrap(self.builder, elt_opt, err) + self._assign(pat, elt) + length -= 1 + return array, length + + self.dfg[lhs.rhs_var.place] = port + array = self.expr_compiler.visit_DesugaredArrayComp(lhs.compr) + array, length = pop(array, length, lhs.pattern.left, True) + array, length = pop(array, length, lhs.pattern.right, False) + if lhs.pattern.starred: + self._assign(lhs.pattern.starred, array) else: - raise InternalGuppyError("Invalid assign pattern in compiler") + assert length == 0 + self.builder.add_op(array_discard_empty(opt_elt_ty), array) def visit_Assign(self, node: ast.Assign) -> None: [target] = node.targets port = self.expr_compiler.compile(node.value, self.dfg) - self._unpack_assign(target, port, node) + self._assign(target, port) def visit_AnnAssign(self, node: ast.AnnAssign) -> None: assert node.value is not None port = self.expr_compiler.compile(node.value, self.dfg) - self._unpack_assign(node.target, port, node) + self._assign(node.target, port) def visit_AugAssign(self, node: ast.AugAssign) -> None: raise InternalGuppyError("Node should have been removed during type checking.") diff --git a/guppylang/nodes.py b/guppylang/nodes.py index 2eca1ea4..89d03787 100644 --- a/guppylang/nodes.py +++ b/guppylang/nodes.py @@ -290,6 +290,60 @@ class InoutReturnSentinel(ast.expr): _fields = ("var",) +class UnpackPattern(ast.expr): + """The LHS of an unpacking assignment like `a, *bs, c = ...` or + `[a, *bs, c] = ...`.""" + + #: Patterns occurring on the left of the starred target + left: list[ast.expr] + + #: The starred target or `None` if there is none + 
starred: ast.expr | None + + #: Patterns occurring on the right of the starred target. This will be an empty list + #: if there is no starred target + right: list[ast.expr] + + _fields = ("left", "starred", "right") + + +class TupleUnpack(ast.expr): + """The LHS of an unpacking assignment of a tuple.""" + + #: The (possibly starred) unpacking pattern + pattern: UnpackPattern + + _fields = ("pattern",) + + +class IterableUnpack(ast.expr): + """The LHS of an unpacking assignment of an iterable type.""" + + #: The (possibly starred) unpacking pattern + pattern: UnpackPattern + + #: Comprehension that collects the RHS iterable into an array + compr: DesugaredArrayComp + + #: Dummy variable that the RHS should be bound to. This variable is referenced in + #: `compr` + rhs_var: PlaceNode + + # Don't mention the comprehension in _fields to avoid visitors recursing it + _fields = ("pattern",) + + def __init__( + self, pattern: UnpackPattern, compr: DesugaredArrayComp, rhs_var: PlaceNode + ) -> None: + super().__init__(pattern) + self.compr = compr + self.rhs_var = rhs_var + + +#: Any unpacking operation. 
+AnyUnpack = TupleUnpack | IterableUnpack + + class NestedFunctionDef(ast.FunctionDef): cfg: "CFG" ty: FunctionType diff --git a/guppylang/std/_internal/compiler/array.py b/guppylang/std/_internal/compiler/array.py index 66e602cd..e4432360 100644 --- a/guppylang/std/_internal/compiler/array.py +++ b/guppylang/std/_internal/compiler/array.py @@ -72,6 +72,26 @@ def array_set(elem_ty: ht.Type, length: ht.TypeArg) -> ops.ExtOp: ) +def array_pop(elem_ty: ht.Type, length: int, from_left: bool) -> ops.ExtOp: + """Returns an operation that pops an element from the left or right of an array.""" + assert length > 0 + length_arg = ht.BoundedNatArg(length) + arr_ty = array_type(elem_ty, length_arg) + popped_arr_ty = array_type(elem_ty, ht.BoundedNatArg(length - 1)) + op = "pop_left" if from_left else "pop_right" + return _instantiate_array_op( + op, elem_ty, length_arg, [arr_ty], [ht.Option(elem_ty, popped_arr_ty)] + ) + + +def array_discard_empty(elem_ty: ht.Type) -> ops.ExtOp: + """Returns an operation that discards an array of length zero.""" + arr_ty = array_type(elem_ty, ht.BoundedNatArg(0)) + return EXTENSION.get_op("discard_empty").instantiate( + [ht.TypeTypeArg(elem_ty)], ht.FunctionType([arr_ty], []) + ) + + def array_map(elem_ty: ht.Type, length: ht.TypeArg, new_elem_ty: ht.Type) -> ops.ExtOp: """Returns an operation that maps a function across an array.""" # TODO diff --git a/guppylang/tys/const.py b/guppylang/tys/const.py index d942f417..245f0161 100644 --- a/guppylang/tys/const.py +++ b/guppylang/tys/const.py @@ -39,6 +39,11 @@ def unsolved_vars(self) -> set[ExistentialVar]: """The existential type variables contained in this constant.""" return set() + def __str__(self) -> str: + from guppylang.tys.printing import TypePrinter + + return TypePrinter().visit(self.cast()) + def visit(self, visitor: Visitor) -> None: """Accepts a visitor on this constant.""" visitor.visit(self) diff --git a/guppylang/tys/printing.py b/guppylang/tys/printing.py index 633039fd..5ae0b3f9 100644 
--- a/guppylang/tys/printing.py +++ b/guppylang/tys/printing.py @@ -2,7 +2,7 @@ from guppylang.error import InternalGuppyError from guppylang.tys.arg import ConstArg, TypeArg -from guppylang.tys.const import ConstValue +from guppylang.tys.const import Const, ConstValue from guppylang.tys.param import ConstParam, TypeParam from guppylang.tys.ty import ( FunctionType, @@ -53,7 +53,7 @@ def _fresh_name(self, display_name: str) -> str: self.counter[display_name] += 1 return indexed - def visit(self, ty: Type) -> str: + def visit(self, ty: Type | Const) -> str: return self._visit(ty, False) @singledispatchmethod diff --git a/tests/error/type_errors/not_unpackable.err b/tests/error/type_errors/not_unpackable.err new file mode 100644 index 00000000..8f4cb981 --- /dev/null +++ b/tests/error/type_errors/not_unpackable.err @@ -0,0 +1,8 @@ +Error: Unpackable (at $FILE:6:9) + | +4 | @compile_guppy +5 | def foo() -> int: +6 | a, = 1 + | ^ Expression of type `int` cannot be unpacked + +Guppy compilation failed due to 1 previous error diff --git a/tests/error/type_errors/not_unpackable.py b/tests/error/type_errors/not_unpackable.py new file mode 100644 index 00000000..d66ab9cd --- /dev/null +++ b/tests/error/type_errors/not_unpackable.py @@ -0,0 +1,7 @@ +from tests.util import compile_guppy + + +@compile_guppy +def foo() -> int: + a, = 1 + return a diff --git a/tests/error/type_errors/unpack_generic_size.err b/tests/error/type_errors/unpack_generic_size.err new file mode 100644 index 00000000..37925025 --- /dev/null +++ b/tests/error/type_errors/unpack_generic_size.err @@ -0,0 +1,12 @@ +Error: Unpackable (at $FILE:11:13) + | + 9 | @guppy(module) +10 | def foo(xs: array[int, n]) -> int: +11 | a, *bs = xs + | ^^ Expression of type `array[int, n]` cannot be unpacked + +Note: Unpacking of iterable types like `array[int, n]` is only allowed if the +number of items yielded by the iterator is statically known. 
Here, the number of +items `n` is generic and can change between different function invocations. + +Guppy compilation failed due to 1 previous error diff --git a/tests/error/type_errors/unpack_generic_size.py b/tests/error/type_errors/unpack_generic_size.py new file mode 100644 index 00000000..eb7d8013 --- /dev/null +++ b/tests/error/type_errors/unpack_generic_size.py @@ -0,0 +1,15 @@ +from guppylang import GuppyModule, guppy +from guppylang.std.builtins import array + + +module = GuppyModule('main') +n = guppy.nat_var("n", module=module) + + +@guppy(module) +def foo(xs: array[int, n]) -> int: + a, *bs = xs + return a + + +module.compile() diff --git a/tests/error/type_errors/unpack_non_static.err b/tests/error/type_errors/unpack_non_static.err new file mode 100644 index 00000000..a4624983 --- /dev/null +++ b/tests/error/type_errors/unpack_non_static.err @@ -0,0 +1,12 @@ +Error: Unpackable (at $FILE:6:13) + | +4 | @compile_guppy +5 | def foo(xs: list[int]) -> int: +6 | a, *bs = xs + | ^^ Expression of type `list[int]` cannot be unpacked + +Note: Unpacking of iterable types like `list[int]` is only allowed if the number +of items yielded by the iterator is statically known. This is not the case for +`list[int]`. 
+ +Guppy compilation failed due to 1 previous error diff --git a/tests/error/type_errors/unpack_non_static.py b/tests/error/type_errors/unpack_non_static.py new file mode 100644 index 00000000..2bc46f42 --- /dev/null +++ b/tests/error/type_errors/unpack_non_static.py @@ -0,0 +1,7 @@ +from tests.util import compile_guppy + + +@compile_guppy +def foo(xs: list[int]) -> int: + a, *bs = xs + return a diff --git a/tests/error/type_errors/unpack_tuple_starred.err b/tests/error/type_errors/unpack_tuple_starred.err new file mode 100644 index 00000000..7c3776b8 --- /dev/null +++ b/tests/error/type_errors/unpack_tuple_starred.err @@ -0,0 +1,10 @@ +Error: Invalid starred unpacking (at $FILE:6:8) + | +4 | @compile_guppy +5 | def foo() -> int: +6 | a, *bs = 1, 2, True + | ^^ Expression of type `(int, bool)` cannot be collected into a + | starred assignment since the yielded items have + | different types + +Guppy compilation failed due to 1 previous error diff --git a/tests/error/type_errors/unpack_tuple_starred.py b/tests/error/type_errors/unpack_tuple_starred.py new file mode 100644 index 00000000..5c5c845b --- /dev/null +++ b/tests/error/type_errors/unpack_tuple_starred.py @@ -0,0 +1,7 @@ +from tests.util import compile_guppy + + +@compile_guppy +def foo() -> int: + a, *bs = 1, 2, True + return a diff --git a/tests/integration/test_unpack.py b/tests/integration/test_unpack.py new file mode 100644 index 00000000..f399f22d --- /dev/null +++ b/tests/integration/test_unpack.py @@ -0,0 +1,104 @@ +from guppylang.decorator import guppy +from guppylang.module import GuppyModule +from guppylang.std.builtins import array, owned + +from guppylang.std.quantum import qubit + + +def test_unpack_array(validate): + module = GuppyModule("test") + module.load(qubit) + + @guppy(module) + def main(qs: array[qubit, 3] @ owned) -> tuple[qubit, qubit, qubit]: + q1, q2, q3 = qs + return q1, q2, q3 + + validate(module.compile()) + + +def test_unpack_starred(validate): + module = GuppyModule("test") + 
module.load(qubit) + + @guppy(module) + def main( + qs: array[qubit, 10] @ owned, + ) -> tuple[qubit, qubit, qubit, qubit, qubit, qubit, array[qubit, 4]]: + q1, q2, *qs, q3 = qs + [q4, *qs] = qs + *qs, q5, q6 = qs + [*qs] = qs + return q1, q2, q3, q4, q5, q6, qs + + validate(module.compile()) + + +def test_unpack_starred_empty(validate): + module = GuppyModule("test") + module.load(qubit) + + @guppy(module) + def main(qs: array[qubit, 2] @ owned) -> tuple[qubit, array[qubit, 0], qubit]: + q1, *empty, q2 = qs + return q1, empty, q2 + + validate(module.compile()) + + +def test_unpack_big_iterable(validate): + # Test that the compile-time doesn't scale with the size of the unpacked iterable + module = GuppyModule("test") + module.load(qubit) + + @guppy(module) + def main(qs: array[qubit, 1000] @ owned) -> tuple[qubit, array[qubit, 998], qubit]: + q1, *qs, q2 = qs + return q1, qs, q2 + + validate(module.compile()) + + +def test_unpack_range(validate, run_int_fn): + module = GuppyModule("test") + + @guppy(module) + def main() -> int: + [_, x, *_, y, _] = range(10) + return x + y + + compiled = module.compile() + validate(compiled) + # TODO: Enable execution test once array lowering is fully supported + # run_int_fn(compiled, expected=9) + + +def test_unpack_tuple_starred(validate, run_int_fn): + module = GuppyModule("test") + + @guppy(module) + def main() -> array[int, 2]: + x, *ys, z = 1, 2, 3, 4 + return ys + + validate(module.compile()) + + +def test_unpack_nested(validate, run_int_fn): + module = GuppyModule("test") + + @guppy(module) + def main( + xs: array[array[array[int, 5], 10], 20], + ) -> tuple[ + array[int, 5], # x + int, # y + array[int, 3], # z + array[array[int, 5], 8], # a + array[array[array[int, 5], 10], 18], # b + array[array[int, 5], 10], # c + ]: + (x, [y, *z, _], *a), *b, c = xs + return x, y, z, a, b, c + + validate(module.compile())