diff --git a/guppy/cfg/builder.py b/guppy/cfg/builder.py index 0c42207c..e67a5bf5 100644 --- a/guppy/cfg/builder.py +++ b/guppy/cfg/builder.py @@ -1,7 +1,7 @@ import ast import itertools from collections.abc import Iterator -from typing import NamedTuple, cast +from typing import NamedTuple from guppy.ast_util import ( AstVisitor, @@ -258,9 +258,8 @@ def build(node: ast.expr, cfg: CFG, bb: BB) -> tuple[ast.expr, BB]: @classmethod def _tmp_assign(cls, tmp_name: str, value: ast.expr, bb: BB) -> None: """Adds a temporary variable assignment to a basic block.""" - node = ast.Assign(targets=[make_var(tmp_name, value)], value=value) - set_location_from(node, value) - bb.statements.append(node) + lhs = make_var(tmp_name, value) + bb.statements.append(make_assign([lhs], value)) def visit_Name(self, node: ast.Name) -> ast.Name: return node @@ -309,31 +308,24 @@ def visit_ListComp(self, node: ast.ListComp) -> ast.AST: if g.is_async: raise GuppyError("Async generators are not supported", g) g.iter = self.visit(g.iter) - gen = DesugaredGenerator() - - template = """ - it = make_iter - b, it = has_next - x, it = get_next - """ it = make_var(next(tmp_vars), g.iter) - b = make_var(next(tmp_vars), g.iter) - [gen.iter_assign, gen.hasnext_assign, gen.next_assign] = cast( - list[ast.Assign], - template_replace( - template, - g.iter, - it=it, - b=b, - x=g.target, - make_iter=with_loc(it, MakeIter(value=g.iter, origin_node=node)), - has_next=with_loc(it, IterHasNext(value=it)), - get_next=with_loc(it, IterNext(value=it)), + hasnext = make_var(next(tmp_vars), g.iter) + desugared = DesugaredGenerator( + iter=it, + hasnext=hasnext, + iter_assign=make_assign( + [it], with_loc(it, MakeIter(value=g.iter, origin_node=node)) + ), + hasnext_assign=make_assign( + [hasnext, it], with_loc(it, IterHasNext(value=it)) ), + next_assign=make_assign( + [g.target, it], with_loc(it, IterNext(value=it)) + ), + iterend=with_loc(it, IterEnd(value=it)), + ifs=g.ifs, ) - gen.iterend = with_loc(it, IterEnd(value=it)) - gen.iter, gen.hasnext, gen.ifs = it, b, g.ifs - gens.append(gen) + gens.append(desugared) node.elt = self.visit(node.elt) return with_loc(node, DesugaredListComp(elt=node.elt, generators=gens)) @@ -507,3 +499,13 @@ def make_var(name: str, loc: ast.AST | None = None) -> ast.Name: if loc is not None: set_location_from(node, loc) return node + + +def make_assign(lhs: list[ast.AST], value: ast.expr) -> ast.Assign: + """Creates an `ast.Assign` node.""" + assert len(lhs) > 0 + if len(lhs) == 1: + target = lhs[0] + else: + target = with_loc(value, ast.Tuple(elts=lhs, ctx=ast.Store())) + return with_loc(value, ast.Assign(targets=[target], value=value))