Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Unpacking assignment of iterable types with static size #688

Merged
merged 13 commits into from
Dec 17, 2024
4 changes: 3 additions & 1 deletion guppylang/cfg/bb.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,13 +142,15 @@ def _handle_assign_target(self, lhs: ast.expr, node: ast.stmt) -> None:
match lhs:
case ast.Name(id=name):
self.stats.assigned[name] = node
case ast.Tuple(elts=elts):
case ast.Tuple(elts=elts) | ast.List(elts=elts):
for elt in elts:
self._handle_assign_target(elt, node)
case ast.Attribute(value=value):
# Setting attributes counts as a use of the value, so we do a regular
# visit instead of treating it like a LHS
self.visit(value)
case ast.Starred(value=value):
self._handle_assign_target(value, node)

def visit_DesugaredListComp(self, node: DesugaredListComp) -> None:
self._handle_comprehension(node.generators, node.elt)
Expand Down
96 changes: 51 additions & 45 deletions guppylang/cfg/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,56 +316,13 @@ def visit_IfExp(self, node: ast.IfExp) -> ast.Name:

def visit_ListComp(self, node: ast.ListComp) -> DesugaredListComp:
check_lists_enabled(node)
generators, elt = self._build_comprehension(node.generators, node.elt, node)
generators, elt = desugar_comprehension(node.generators, node.elt, node)
return with_loc(node, DesugaredListComp(elt=elt, generators=generators))

def visit_GeneratorExp(self, node: ast.GeneratorExp) -> DesugaredGeneratorExpr:
generators, elt = self._build_comprehension(node.generators, node.elt, node)
generators, elt = desugar_comprehension(node.generators, node.elt, node)
return with_loc(node, DesugaredGeneratorExpr(elt=elt, generators=generators))

def _build_comprehension(
self, generators: list[ast.comprehension], elt: ast.expr, node: ast.AST
) -> tuple[list[DesugaredGenerator], ast.expr]:
# Check for illegal expressions
illegals = find_nodes(is_illegal_in_list_comp, node)
if illegals:
err = UnsupportedError(
illegals[0],
"This expression",
singular=True,
unsupported_in="a list comprehension",
)
raise GuppyError(err)

# Desugar into statements that create the iterator, check for a next element,
# get the next element, and finalise the iterator.
gens = []
for g in generators:
if g.is_async:
raise GuppyError(UnsupportedError(g, "Async generators"))
g.iter = self.visit(g.iter)
it = make_var(next(tmp_vars), g.iter)
hasnext = make_var(next(tmp_vars), g.iter)
desugared = DesugaredGenerator(
iter=it,
hasnext=hasnext,
iter_assign=make_assign(
[it], with_loc(it, MakeIter(value=g.iter, origin_node=node))
),
hasnext_assign=make_assign(
[hasnext, it], with_loc(it, IterHasNext(value=it))
),
next_assign=make_assign(
[g.target, it], with_loc(it, IterNext(value=it))
),
iterend=with_loc(it, IterEnd(value=it)),
ifs=g.ifs,
)
gens.append(desugared)

elt = self.visit(elt)
return gens, elt

def visit_Call(self, node: ast.Call) -> ast.AST:
return is_py_expression(node) or self.generic_visit(node)

Expand Down Expand Up @@ -487,6 +444,55 @@ def generic_visit(self, node: ast.expr, bb: BB, true_bb: BB, false_bb: BB) -> No
self.cfg.link(bb, true_bb)


def desugar_comprehension(
generators: list[ast.comprehension], elt: ast.expr, node: ast.AST
) -> tuple[list[DesugaredGenerator], ast.expr]:
"""Helper function to desugar a comprehension node."""
# Check for illegal expressions
illegals = find_nodes(is_illegal_in_list_comp, node)
if illegals:
err = UnsupportedError(
illegals[0],
"This expression",
singular=True,
unsupported_in="a list comprehension",
)
raise GuppyError(err)

# The check above ensures that the comprehension doesn't contain any control-flow
# expressions. Thus, we can use a dummy `ExprBuilder` to desugar the insides.
# TODO: Refactor so that desugaring is separate from control-flow building
dummy_cfg = CFG()
builder = ExprBuilder(dummy_cfg, dummy_cfg.entry_bb)

# Desugar into statements that create the iterator, check for a next element,
# get the next element, and finalise the iterator.
gens = []
for g in generators:
if g.is_async:
raise GuppyError(UnsupportedError(g, "Async generators"))
g.iter = builder.visit(g.iter)
it = make_var(next(tmp_vars), g.iter)
hasnext = make_var(next(tmp_vars), g.iter)
desugared = DesugaredGenerator(
iter=it,
hasnext=hasnext,
iter_assign=make_assign(
[it], with_loc(it, MakeIter(value=g.iter, origin_node=node))
),
hasnext_assign=make_assign(
[hasnext, it], with_loc(it, IterHasNext(value=it))
),
next_assign=make_assign([g.target, it], with_loc(it, IterNext(value=it))),
iterend=with_loc(it, IterEnd(value=it)),
ifs=g.ifs,
)
gens.append(desugared)

elt = builder.visit(elt)
return gens, elt


def is_functional_annotation(stmt: ast.stmt) -> bool:
"""Returns `True` iff the given statement is the functional pseudo-decorator.

Expand Down
39 changes: 38 additions & 1 deletion guppylang/checker/errors/type_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,7 @@ class WrongNumberOfUnpacksError(Error):
title: ClassVar[str] = "{prefix} values to unpack"
expected: int
actual: int
at_least: bool

@property
def prefix(self) -> str:
Expand All @@ -202,7 +203,43 @@ def rendered_span_label(self) -> str:
msg = "Unexpected assignment " + ("targets" if diff < -1 else "target")
else:
msg = "Not enough assignment targets"
return f"{msg} (expected {self.expected}, got {self.actual})"
at_least = "at least " if self.at_least else ""
return f"{msg} (expected {self.expected}, got {at_least}{self.actual})"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The at_least message could probably go inside the diff < 0 branch for clarity, as it should always be false otherwise.



@dataclass(frozen=True)
class UnpackableError(Error):
title: ClassVar[str] = "Unpackable"
span_label: ClassVar[str] = "Expression of type `{ty}` cannot be unpacked"
ty: Type

@dataclass(frozen=True)
class NonStaticIter(Note):
message: ClassVar[str] = (
"Unpacking of iterable types like `{ty}` is only allowed if the number of "
"items yielded by the iterator is statically known. This is not the case "
"for `{ty}`."
)

@dataclass(frozen=True)
class GenericSize(Note):
message: ClassVar[str] = (
"Unpacking of iterable types like `{ty}` is only allowed if the number of "
"items yielded by the iterator is statically known. Here, the number of "
"items `{num}` is generic and can change between different function "
"invocations."
)
num: Const


@dataclass(frozen=True)
class StarredTupleUnpackError(Error):
title: ClassVar[str] = "Invalid starred unpacking"
span_label: ClassVar[str] = (
"Expression of type `{ty}` cannot be collected into a starred assignment since "
"the yielded items have different types"
)
ty: Type


@dataclass(frozen=True)
Expand Down
34 changes: 24 additions & 10 deletions guppylang/checker/expr_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -1050,15 +1050,35 @@ def synthesize_comprehension(
node: AstNode, gens: list[DesugaredGenerator], elt: ast.expr, ctx: Context
) -> tuple[list[DesugaredGenerator], ast.expr, Type]:
"""Helper function to synthesise the element type of a list comprehension."""
from guppylang.checker.stmt_checker import StmtChecker

# If there are no more generators left, we can check the list element
if not gens:
elt, elt_ty = ExprSynthesizer(ctx).synthesize(elt)
return gens, elt, elt_ty

# Check the iterator in the outer context
# Check the first generator
gen, *gens = gens
gen, inner_ctx = check_generator(gen, ctx)

# Check remaining generators in inner context
gens, elt, elt_ty = synthesize_comprehension(node, gens, elt, inner_ctx)

# The iter finalizer is again checked in the outer context
gen.iterend, iterend_ty = ExprSynthesizer(ctx).synthesize(gen.iterend)
gen.iterend = with_type(iterend_ty, gen.iterend)
return [gen, *gens], elt, elt_ty


def check_generator(
gen: DesugaredGenerator, ctx: Context
) -> tuple[DesugaredGenerator, Context]:
"""Helper function to check a single generator.

Returns the type annotated generator together with a new nested context in which the
generator variables are bound.
"""
from guppylang.checker.stmt_checker import StmtChecker

# Check the iterator in the outer context
gen.iter_assign = StmtChecker(ctx).visit_Assign(gen.iter_assign)

# The rest is checked in a new nested context to ensure that variables don't escape
Expand All @@ -1078,13 +1098,7 @@ def synthesize_comprehension(
gen.ifs[i], if_ty = expr_sth.synthesize(gen.ifs[i])
gen.ifs[i], _ = to_bool(gen.ifs[i], if_ty, inner_ctx)

# Check remaining generators
gens, elt, elt_ty = synthesize_comprehension(node, gens, elt, inner_ctx)

# The iter finalizer is again checked in the outer context
gen.iterend, iterend_ty = ExprSynthesizer(ctx).synthesize(gen.iterend)
gen.iterend = with_type(iterend_ty, gen.iterend)
return [gen, *gens], elt, elt_ty
return gen, inner_ctx


def eval_py_expr(node: PyExpr, ctx: Context) -> Any:
Expand Down
Loading
Loading