From 359c2da21201adb41181295e7d9a55e70a4c66da Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Wed, 20 Dec 2023 13:25:08 +0000 Subject: [PATCH 1/5] feat: Build for loops and comprehensions --- guppy/ast_util.py | 190 ++++++++++++++++++++++++++++++++++--------- guppy/cfg/builder.py | 122 +++++++++++++++++++++++---- 2 files changed, 260 insertions(+), 52 deletions(-) diff --git a/guppy/ast_util.py b/guppy/ast_util.py index 5a9e66f4..31bb940f 100644 --- a/guppy/ast_util.py +++ b/guppy/ast_util.py @@ -1,5 +1,7 @@ import ast -from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar +import textwrap +from collections.abc import Callable, Mapping, Sequence +from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, cast if TYPE_CHECKING: from guppy.gtypes import GuppyType @@ -54,51 +56,165 @@ def generic_visit(self, node: Any, *args: Any, **kwargs: Any) -> T: raise NotImplementedError(f"visit_{node.__class__.__name__} is not implemented") -class NameVisitor(ast.NodeVisitor): - """Visitor to collect all `Name` nodes occurring in an AST.""" - - names: list[ast.Name] - - def __init__(self) -> None: - self.names = [] - - def visit_Name(self, node: ast.Name) -> None: - self.names.append(node) +class AstSearcher(ast.NodeVisitor): + """Visitor that searches for occurrences of specific nodes in an AST.""" + + matcher: Callable[[ast.AST], bool] + dont_recurse_into: set[type[ast.AST]] + found: list[ast.AST] + is_first_node: bool + + def __init__( + self, + matcher: Callable[[ast.AST], bool], + dont_recurse_into: set[type[ast.AST]] | None = None, + ) -> None: + self.matcher = matcher + self.dont_recurse_into = dont_recurse_into or set() + self.found = [] + self.is_first_node = True + + def generic_visit(self, node: ast.AST) -> None: + if self.matcher(node): + self.found.append(node) + if self.is_first_node or type(node) not in self.dont_recurse_into: + self.is_first_node = False + super().generic_visit(node) + + +def find_nodes( + matcher: Callable[[ast.AST], bool], + node: ast.AST, + dont_recurse_into: set[type[ast.AST]] | None = None, +) -> list[ast.AST]: + """Returns all nodes in the AST that satisfy the matcher.""" + v = AstSearcher(matcher, dont_recurse_into) + v.visit(node) + return v.found def name_nodes_in_ast(node: Any) -> list[ast.Name]: """Returns all `Name` nodes occurring in an AST.""" - v = NameVisitor() - v.visit(node) - return v.names - - -class ReturnVisitor(ast.NodeVisitor): - """Visitor to collect all `Return` nodes occurring in an AST.""" + found = find_nodes(lambda n: isinstance(n, ast.Name), node) + return cast(list[ast.Name], found) - returns: list[ast.Return] - inside_func_def: bool - def __init__(self) -> None: - self.returns = [] - self.inside_func_def = False - - def visit_Return(self, node: ast.Return) -> None: - self.returns.append(node) +def return_nodes_in_ast(node: Any) -> list[ast.Return]: + """Returns all `Return` nodes occurring in an AST.""" + found = find_nodes(lambda n: isinstance(n, ast.Return), node, {ast.FunctionDef}) + return cast(list[ast.Return], found) - def visit_FunctionDef(self, node: ast.FunctionDef) -> None: - # Don't descend into nested function definitions - if not self.inside_func_def: - self.inside_func_def = True - for n in node.body: - self.visit(n) +def breaks_in_loop(node: Any) -> list[ast.Break]: + """Returns all `Break` nodes occurring in a loop. -def return_nodes_in_ast(node: Any) -> list[ast.Return]: - """Returns all `Return` nodes occurring in an AST.""" - v = ReturnVisitor() - v.visit(node) - return v.returns + Note that breaks in nested loops are excluded. + """ + found = find_nodes( + lambda n: isinstance(n, ast.Break), node, {ast.For, ast.While, ast.FunctionDef} + ) + return cast(list[ast.Break], found) + + +class ContextAdjuster(ast.NodeTransformer): + """Updates the `ast.Context` indicating if expressions occur on the LHS or RHS.""" + + ctx: ast.expr_context + + def __init__(self, ctx: ast.expr_context) -> None: + self.ctx = ctx + + def visit(self, node: ast.AST) -> ast.AST: + return cast(ast.AST, super().visit(node)) + + def visit_Name(self, node: ast.Name) -> ast.Name: + return with_loc(node, ast.Name(id=node.id, ctx=self.ctx)) + + def visit_Starred(self, node: ast.Starred) -> ast.Starred: + return with_loc(node, ast.Starred(value=self.visit(node.value), ctx=self.ctx)) + + def visit_Tuple(self, node: ast.Tuple) -> ast.Tuple: + return with_loc( + node, ast.Tuple(elts=[self.visit(elt) for elt in node.elts], ctx=self.ctx) + ) + + def visit_List(self, node: ast.List) -> ast.List: + return with_loc( + node, ast.List(elts=[self.visit(elt) for elt in node.elts], ctx=self.ctx) + ) + + def visit_Subscript(self, node: ast.Subscript) -> ast.Subscript: + # Don't adjust the slice! + return with_loc( + node, + ast.Subscript(value=self.visit(node.value), slice=node.slice, ctx=self.ctx), + ) + + def visit_Attribute(self, node: ast.Attribute) -> ast.Attribute: + return ast.Attribute(value=self.visit(node.value), attr=node.attr, ctx=self.ctx) + + +class TemplateReplacer(ast.NodeTransformer): + """Replaces nodes in a template.""" + + replacements: Mapping[str, ast.AST | Sequence[ast.AST]] + default_loc: ast.AST + + def __init__( + self, + replacements: Mapping[str, ast.AST | Sequence[ast.AST]], + default_loc: ast.AST, + ) -> None: + self.replacements = replacements + self.default_loc = default_loc + + def _get_replacement(self, x: str) -> ast.AST | Sequence[ast.AST]: + if x not in self.replacements: + msg = f"No replacement for `{x}` is given" + raise ValueError(msg) + return self.replacements[x] + + def visit_Name(self, node: ast.Name) -> ast.AST: + repl = self._get_replacement(node.id) + if not isinstance(repl, ast.expr): + msg = f"Replacement for `{node.id}` must be an expression" + raise TypeError(msg) + + # Update the context + adjuster = ContextAdjuster(node.ctx) + return with_loc(repl, adjuster.visit(repl)) + + def visit_Expr(self, node: ast.Expr) -> ast.AST | Sequence[ast.AST]: + if isinstance(node.value, ast.Name): + repl = self._get_replacement(node.value.id) + repls = [repl] if not isinstance(repl, Sequence) else repl + # Wrap expressions to turn them into statements + return [ + with_loc(r, ast.Expr(value=r)) if isinstance(r, ast.expr) else r + for r in repls + ] + return self.generic_visit(node) + + def generic_visit(self, node: ast.AST) -> ast.AST: + # Insert the default location + node = super().generic_visit(node) + return with_loc(self.default_loc, node) + + +def template_replace( + template: str, default_loc: ast.AST, **kwargs: ast.AST | Sequence[ast.AST] +) -> list[ast.stmt]: + """Turns a template into a proper AST by substituting all placeholders.""" + nodes = ast.parse(textwrap.dedent(template)).body + replacer = TemplateReplacer(kwargs, default_loc) + new_nodes = [] + for n in nodes: + new = replacer.visit(n) + if isinstance(new, list): + new_nodes.extend(new) + else: + new_nodes.append(new) + return new_nodes def line_col(node: ast.AST) -> tuple[int, int]: diff --git a/guppy/cfg/builder.py b/guppy/cfg/builder.py index 878437bf..a664394b 100644 --- a/guppy/cfg/builder.py +++ b/guppy/cfg/builder.py @@ -1,15 +1,29 @@ import ast import itertools from collections.abc import Iterator -from typing import NamedTuple - -from guppy.ast_util import AstVisitor, set_location_from +from typing import NamedTuple, cast + +from guppy.ast_util import ( + AstVisitor, + find_nodes, + set_location_from, + template_replace, + with_loc, +) from guppy.cfg.bb import BB, BBStatement from guppy.cfg.cfg import CFG from guppy.checker.core import Globals from guppy.error import GuppyError, InternalGuppyError from guppy.gtypes import NoneType -from guppy.nodes import NestedFunctionDef +from guppy.nodes import ( + DesugaredGenerator, + DesugaredListComp, + IterEnd, + IterHasNext, + IterNext, + MakeIter, + NestedFunctionDef, +) # In order to build expressions, need an endless stream of unique temporary variables # to store intermediate results @@ -142,6 +156,35 @@ def visit_While(self, node: ast.While, bb: BB, jumps: Jumps) -> BB | None: # its own jumps since the body is not guaranteed to execute return tail_bb + def visit_For(self, node: ast.For, bb: BB, jumps: Jumps) -> BB | None: + template = """ + it = make_iter + while True: + b, it = has_next + if b: + x, it = get_next + body + else: + break + end_iter # Consume iterator one last time + """ + + it = make_var(next(tmp_vars), node.iter) + b = make_var(next(tmp_vars), node.iter) + new_nodes = template_replace( + template, + node, + it=it, + b=b, + x=node.target, + make_iter=with_loc(node.iter, MakeIter(value=node.iter, origin_node=node)), + has_next=with_loc(node.iter, IterHasNext(value=it)), + get_next=with_loc(node.iter, IterNext(value=it)), + end_iter=with_loc(node.iter, IterEnd(value=it)), + body=node.body, + ) + return self.visit_stmts(new_nodes, bb, jumps) + def visit_Continue(self, node: ast.Continue, bb: BB, jumps: Jumps) -> BB | None: if not jumps.continue_bb: raise InternalGuppyError("Continue BB not defined") @@ -211,18 +254,10 @@ def build(node: ast.expr, cfg: CFG, bb: BB) -> tuple[ast.expr, BB]: builder = ExprBuilder(cfg, bb) return builder.visit(node), builder.bb - @classmethod - def _make_var(cls, name: str, loc: ast.expr | None = None) -> ast.Name: - """Creates an `ast.Name` node.""" - node = ast.Name(id=name, ctx=ast.Load) - if loc is not None: - set_location_from(node, loc) - return node - @classmethod def _tmp_assign(cls, tmp_name: str, value: ast.expr, bb: BB) -> None: """Adds a temporary variable assignment to a basic block.""" - node = ast.Assign(targets=[cls._make_var(tmp_name, value)], value=value) + node = ast.Assign(targets=[make_var(tmp_name, value)], value=value) set_location_from(node, value) bb.statements.append(node) @@ -256,7 +291,51 @@ def visit_IfExp(self, node: ast.IfExp) -> ast.Name: self.bb = merge_bb # The final value is stored in the temporary variable - return self._make_var(tmp, node) + return make_var(tmp, node) + + def visit_ListComp(self, node: ast.ListComp) -> ast.AST: + # Check for illegal expressions + illegals = find_nodes(is_illegal_in_list_comp, node) + if illegals: + raise GuppyError( + "Expression is not supported inside a list comprehension", illegals[0] + ) + + # Desugar into statements that create the iterator, check for a next element, + # get the next element, and finalise the iterator. + gens = [] + for g in node.generators: + if g.is_async: + raise GuppyError("Async generators are not supported", g) + g.iter = self.visit(g.iter) + gen = DesugaredGenerator() + + template = """ + it = make_iter + b, it = has_next + x, it = get_next + """ + it = make_var(next(tmp_vars), g.iter) + b = make_var(next(tmp_vars), g.iter) + [gen.iter_assign, gen.hasnext_assign, gen.next_assign] = cast( + list[ast.Assign], + template_replace( + template, + g.iter, + it=it, + b=b, + x=g.target, + make_iter=with_loc(it, MakeIter(value=g.iter, origin_node=node)), + has_next=with_loc(it, IterHasNext(value=it)), + get_next=with_loc(it, IterNext(value=it)), + ), + ) + gen.iterend = with_loc(it, IterEnd(value=it)) + gen.iter, gen.hasnext, gen.ifs = it, b, g.ifs + gens.append(gen) + + node.elt = self.visit(node.elt) + return with_loc(node, DesugaredListComp(elt=node.elt, generators=gens)) def generic_visit(self, node: ast.AST) -> ast.AST: # Short-circuit expressions must be built using the `BranchBuilder`. However, we @@ -275,7 +354,7 @@ def generic_visit(self, node: ast.AST) -> ast.AST: self._tmp_assign(tmp, false_const, false_bb) merge_bb = self.cfg.new_bb(true_bb, false_bb) self.bb = merge_bb - return self._make_var(tmp, node) + return make_var(tmp, node) # For all other expressions, just recurse deeper with the node transformer return super().generic_visit(node) @@ -398,3 +477,16 @@ def is_short_circuit_expr(node: ast.AST) -> bool: return isinstance(node, ast.BoolOp) or ( isinstance(node, ast.Compare) and len(node.comparators) > 1 ) + + +def is_illegal_in_list_comp(node: ast.AST) -> bool: + """Checks if an expression is illegal to use in a list comprehension.""" + return isinstance(node, ast.IfExp | ast.NamedExpr) or is_short_circuit_expr(node) + + +def make_var(name: str, loc: ast.AST | None = None) -> ast.Name: + """Creates an `ast.Name` node.""" + node = ast.Name(id=name, ctx=ast.Load) + if loc is not None: + set_location_from(node, loc) + return node From 7630f7b8cfc13e973ae50fe5dd241c0b0cf0b01f Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Mon, 15 Jan 2024 17:40:17 +0000 Subject: [PATCH 2/5] Add make_assign helper --- guppy/cfg/builder.py | 54 +++++++++++++++++++++++--------------------- 1 file changed, 28 insertions(+), 26 deletions(-) diff --git a/guppy/cfg/builder.py b/guppy/cfg/builder.py index 0c42207c..e67a5bf5 100644 --- a/guppy/cfg/builder.py +++ b/guppy/cfg/builder.py @@ -1,7 +1,7 @@ import ast import itertools from collections.abc import Iterator -from typing import NamedTuple, cast +from typing import NamedTuple from guppy.ast_util import ( AstVisitor, @@ -258,9 +258,8 @@ def build(node: ast.expr, cfg: CFG, bb: BB) -> tuple[ast.expr, BB]: @classmethod def _tmp_assign(cls, tmp_name: str, value: ast.expr, bb: BB) -> None: """Adds a temporary variable assignment to a basic block.""" - node = ast.Assign(targets=[make_var(tmp_name, value)], value=value) - set_location_from(node, value) - bb.statements.append(node) + lhs = make_var(tmp_name, value) + bb.statements.append(make_assign([lhs], value)) def visit_Name(self, node: ast.Name) -> ast.Name: return node @@ -309,31 +308,24 @@ def visit_ListComp(self, node: ast.ListComp) -> ast.AST: if g.is_async: raise GuppyError("Async generators are not supported", g) g.iter = self.visit(g.iter) - gen = DesugaredGenerator() - - template = """ - it = make_iter - b, it = has_next - x, it = get_next - """ it = make_var(next(tmp_vars), g.iter) - b = make_var(next(tmp_vars), g.iter) - [gen.iter_assign, gen.hasnext_assign, gen.next_assign] = cast( - list[ast.Assign], - template_replace( - template, - g.iter, - it=it, - b=b, - x=g.target, - make_iter=with_loc(it, MakeIter(value=g.iter, origin_node=node)), - has_next=with_loc(it, IterHasNext(value=it)), - get_next=with_loc(it, IterNext(value=it)), + hasnext = make_var(next(tmp_vars), g.iter) + desugared = DesugaredGenerator( + iter=it, + hasnext=hasnext, + iter_assign=make_assign( + [it], with_loc(it, MakeIter(value=g.iter, origin_node=node)) + ), + hasnext_assign=make_assign( + [hasnext, it], with_loc(it, IterHasNext(value=it)) ), + next_assign=make_assign( + [g.target, it], with_loc(it, IterNext(value=it)) + ), + iterend=with_loc(it, IterEnd(value=it)), + ifs=g.ifs, ) - gen.iterend = with_loc(it, IterEnd(value=it)) - gen.iter, gen.hasnext, gen.ifs = it, b, g.ifs - gens.append(gen) + gens.append(desugared) node.elt = self.visit(node.elt) return with_loc(node, DesugaredListComp(elt=node.elt, generators=gens)) @@ -507,3 +499,13 @@ def make_var(name: str, loc: ast.AST | None = None) -> ast.Name: if loc is not None: set_location_from(node, loc) return node + + +def make_assign(lhs: list[ast.AST], value: ast.expr) -> ast.Assign: + """Creates an `ast.Assign` node.""" + assert len(lhs) > 0 + if len(lhs) == 1: + target = lhs[0] + else: + target = with_loc(value, ast.Tuple(elts=lhs, ctx=ast.Store())) + return with_loc(value, ast.Assign(targets=[target], value=value)) From 70c5d07b9a9dbee36d06683090dc71f3dcb42774 Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Tue, 16 Jan 2024 12:37:00 +0000 Subject: [PATCH 3/5] Turn TemplateReplacer into dataclass --- guppy/ast_util.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/guppy/ast_util.py b/guppy/ast_util.py index 31bb940f..bfa44a7c 100644 --- a/guppy/ast_util.py +++ b/guppy/ast_util.py @@ -1,6 +1,7 @@ import ast import textwrap from collections.abc import Callable, Mapping, Sequence +from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, cast if TYPE_CHECKING: @@ -154,20 +155,13 @@ def visit_Attribute(self, node: ast.Attribute) -> ast.Attribute: return ast.Attribute(value=self.visit(node.value), attr=node.attr, ctx=self.ctx) +@dataclass(frozen=True, eq=False) class TemplateReplacer(ast.NodeTransformer): """Replaces nodes in a template.""" replacements: Mapping[str, ast.AST | Sequence[ast.AST]] default_loc: ast.AST - def __init__( - self, - replacements: Mapping[str, ast.AST | Sequence[ast.AST]], - default_loc: ast.AST, - ) -> None: - self.replacements = replacements - self.default_loc = default_loc - def _get_replacement(self, x: str) -> ast.AST | Sequence[ast.AST]: if x not in self.replacements: msg = f"No replacement for `{x}` is given" From fd8f56c09c593c4f374ae27ae6085ec10ec67116 Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Tue, 16 Jan 2024 12:40:30 +0000 Subject: [PATCH 4/5] Add missing location --- guppy/ast_util.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/guppy/ast_util.py b/guppy/ast_util.py index bfa44a7c..b5362b46 100644 --- a/guppy/ast_util.py +++ b/guppy/ast_util.py @@ -152,7 +152,10 @@ def visit_Subscript(self, node: ast.Subscript) -> ast.Subscript: ) def visit_Attribute(self, node: ast.Attribute) -> ast.Attribute: - return ast.Attribute(value=self.visit(node.value), attr=node.attr, ctx=self.ctx) + return with_loc( + node, + ast.Attribute(value=self.visit(node.value), attr=node.attr, ctx=self.ctx), + ) @dataclass(frozen=True, eq=False) From 679016121d602909187c9e90bf6b53a5951e1581 Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Tue, 16 Jan 2024 12:48:32 +0000 Subject: [PATCH 5/5] Adjust context in make_assign --- guppy/cfg/builder.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/guppy/cfg/builder.py b/guppy/cfg/builder.py index e67a5bf5..1310562e 100644 --- a/guppy/cfg/builder.py +++ b/guppy/cfg/builder.py @@ -5,6 +5,7 @@ from guppy.ast_util import ( AstVisitor, + ContextAdjuster, find_nodes, set_location_from, template_replace, @@ -504,6 +505,8 @@ def make_var(name: str, loc: ast.AST | None = None) -> ast.Name: def make_assign(lhs: list[ast.AST], value: ast.expr) -> ast.Assign: """Creates an `ast.Assign` node.""" assert len(lhs) > 0 + adjuster = ContextAdjuster(ast.Store()) + lhs = [adjuster.visit(expr) for expr in lhs] if len(lhs) == 1: target = lhs[0] else: