Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Build for loops and comprehensions #68

Merged
merged 6 commits into from
Jan 16, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
190 changes: 153 additions & 37 deletions guppy/ast_util.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import ast
from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar
import textwrap
from collections.abc import Callable, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, cast

if TYPE_CHECKING:
from guppy.gtypes import GuppyType
Expand Down Expand Up @@ -54,51 +56,165 @@ def generic_visit(self, node: Any, *args: Any, **kwargs: Any) -> T:
raise NotImplementedError(f"visit_{node.__class__.__name__} is not implemented")


class NameVisitor(ast.NodeVisitor):
"""Visitor to collect all `Name` nodes occurring in an AST."""

names: list[ast.Name]

def __init__(self) -> None:
self.names = []

def visit_Name(self, node: ast.Name) -> None:
self.names.append(node)
class AstSearcher(ast.NodeVisitor):
"""Visitor that searches for occurrences of specific nodes in an AST."""

matcher: Callable[[ast.AST], bool]
dont_recurse_into: set[type[ast.AST]]
found: list[ast.AST]
is_first_node: bool

def __init__(
self,
matcher: Callable[[ast.AST], bool],
dont_recurse_into: set[type[ast.AST]] | None = None,
) -> None:
self.matcher = matcher
self.dont_recurse_into = dont_recurse_into or set()
self.found = []
self.is_first_node = True

def generic_visit(self, node: ast.AST) -> None:
if self.matcher(node):
self.found.append(node)
if self.is_first_node or type(node) not in self.dont_recurse_into:
self.is_first_node = False
super().generic_visit(node)


def find_nodes(
    matcher: Callable[[ast.AST], bool],
    node: ast.AST,
    dont_recurse_into: set[type[ast.AST]] | None = None,
) -> list[ast.AST]:
    """Collects all nodes reachable from `node` for which `matcher` returns true.

    The optional `dont_recurse_into` set is handed to `AstSearcher` unchanged
    and limits which node types are descended into.
    """
    searcher = AstSearcher(matcher, dont_recurse_into)
    searcher.visit(node)
    return searcher.found


def name_nodes_in_ast(node: Any) -> list[ast.Name]:
"""Returns all `Name` nodes occurring in an AST."""
v = NameVisitor()
v.visit(node)
return v.names


class ReturnVisitor(ast.NodeVisitor):
"""Visitor to collect all `Return` nodes occurring in an AST."""
found = find_nodes(lambda n: isinstance(n, ast.Name), node)
return cast(list[ast.Name], found)

returns: list[ast.Return]
inside_func_def: bool

def __init__(self) -> None:
self.returns = []
self.inside_func_def = False

def visit_Return(self, node: ast.Return) -> None:
self.returns.append(node)
def return_nodes_in_ast(node: Any) -> list[ast.Return]:
    """Collects the `Return` nodes found when searching from `node`.

    `ast.FunctionDef` is passed to `find_nodes` as a type not to recurse into.
    """

    def is_return(n: ast.AST) -> bool:
        return isinstance(n, ast.Return)

    returns = find_nodes(is_return, node, {ast.FunctionDef})
    return cast(list[ast.Return], returns)

def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
# Don't descend into nested function definitions
if not self.inside_func_def:
self.inside_func_def = True
for n in node.body:
self.visit(n)

def breaks_in_loop(node: Any) -> list[ast.Break]:
"""Returns all `Break` nodes occurring in a loop.

def return_nodes_in_ast(node: Any) -> list[ast.Return]:
"""Returns all `Return` nodes occurring in an AST."""
v = ReturnVisitor()
v.visit(node)
return v.returns
Note that breaks in nested loops are excluded.
"""
found = find_nodes(
lambda n: isinstance(n, ast.Break), node, {ast.For, ast.While, ast.FunctionDef}
)
return cast(list[ast.Break], found)


class ContextAdjuster(ast.NodeTransformer):
    """Updates the `ast.Context` indicating if expressions occur on the LHS or RHS.

    Each handled node kind (`Name`, `Starred`, `Tuple`, `List`, `Subscript`,
    `Attribute`) is rebuilt as a fresh node carrying the target context; the
    input nodes themselves are not modified by these handlers. Every rebuilt
    node is passed through `with_loc` together with the node it replaces
    (presumably so that it keeps the original source location -- `with_loc` is
    defined elsewhere in this module).
    """

    # The expression context that is written into every rebuilt node
    ctx: ast.expr_context

    def __init__(self, ctx: ast.expr_context) -> None:
        self.ctx = ctx

    def visit(self, node: ast.AST) -> ast.AST:
        # Only narrows the return type of `NodeTransformer.visit` for type checkers
        return cast(ast.AST, super().visit(node))

    def visit_Name(self, node: ast.Name) -> ast.Name:
        return with_loc(node, ast.Name(id=node.id, ctx=self.ctx))

    def visit_Starred(self, node: ast.Starred) -> ast.Starred:
        return with_loc(node, ast.Starred(value=self.visit(node.value), ctx=self.ctx))

    def visit_Tuple(self, node: ast.Tuple) -> ast.Tuple:
        return with_loc(
            node, ast.Tuple(elts=[self.visit(elt) for elt in node.elts], ctx=self.ctx)
        )

    def visit_List(self, node: ast.List) -> ast.List:
        return with_loc(
            node, ast.List(elts=[self.visit(elt) for elt in node.elts], ctx=self.ctx)
        )

    def visit_Subscript(self, node: ast.Subscript) -> ast.Subscript:
        # Don't adjust the slice!
        return with_loc(
            node,
            ast.Subscript(value=self.visit(node.value), slice=node.slice, ctx=self.ctx),
        )

    def visit_Attribute(self, node: ast.Attribute) -> ast.Attribute:
        # Wrap in `with_loc` like every other handler; previously the rebuilt
        # attribute node was returned without being passed through `with_loc`.
        return with_loc(
            node,
            ast.Attribute(value=self.visit(node.value), attr=node.attr, ctx=self.ctx),
        )


class TemplateReplacer(ast.NodeTransformer):
"""Replaces nodes in a template."""

replacements: Mapping[str, ast.AST | Sequence[ast.AST]]
default_loc: ast.AST

def __init__(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note I think you could just make this an @dataclass and avoid the constructor. (@dataclass(eq=False, frozen=True) or similar if you want to)

self,
replacements: Mapping[str, ast.AST | Sequence[ast.AST]],
default_loc: ast.AST,
) -> None:
self.replacements = replacements
self.default_loc = default_loc

def _get_replacement(self, x: str) -> ast.AST | Sequence[ast.AST]:
if x not in self.replacements:
msg = f"No replacement for `{x}` is given"
raise ValueError(msg)
return self.replacements[x]

def visit_Name(self, node: ast.Name) -> ast.AST:
    """Substitutes a placeholder name with its replacement expression.

    Raises a `TypeError` if the registered replacement is not an `ast.expr`.
    """
    replacement = self._get_replacement(node.id)
    if isinstance(replacement, ast.expr):
        # The replacement takes over the expression context (load/store/...) of
        # the placeholder it stands in for
        adjusted = ContextAdjuster(node.ctx).visit(replacement)
        return with_loc(replacement, adjusted)

    msg = f"Replacement for `{node.id}` must be an expression"
    raise TypeError(msg)

def visit_Expr(self, node: ast.Expr) -> ast.AST | Sequence[ast.AST]:
if isinstance(node.value, ast.Name):
repl = self._get_replacement(node.value.id)
repls = [repl] if not isinstance(repl, Sequence) else repl
# Wrap expressions to turn them into statements
return [
with_loc(r, ast.Expr(value=r)) if isinstance(r, ast.expr) else r
for r in repls
]
return self.generic_visit(node)

def generic_visit(self, node: ast.AST) -> ast.AST:
    """Transforms the children of `node`, then applies the default location.

    The result of the inherited traversal is passed to `with_loc` together with
    `self.default_loc`.
    """
    visited = super().generic_visit(node)
    # Insert the default location
    return with_loc(self.default_loc, visited)


def template_replace(
    template: str, default_loc: ast.AST, **kwargs: ast.AST | Sequence[ast.AST]
) -> list[ast.stmt]:
    """Parses a source template and substitutes its placeholders.

    The template is dedented and parsed as Python source. Each top-level
    statement is run through a `TemplateReplacer` built from `kwargs` (mapping
    placeholder name to replacement) and `default_loc`. A statement may expand
    into a list of statements, which is spliced into the result in place.
    """
    parsed = ast.parse(textwrap.dedent(template))
    visitor = TemplateReplacer(kwargs, default_loc)

    result = []
    for stmt in parsed.body:
        replaced = visitor.visit(stmt)
        # A single statement may expand into several; flatten those in place
        result += replaced if isinstance(replaced, list) else [replaced]
    return result


def line_col(node: ast.AST) -> tuple[int, int]:
Expand Down
125 changes: 110 additions & 15 deletions guppy/cfg/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,28 @@
from collections.abc import Iterator
from typing import NamedTuple

from guppy.ast_util import AstVisitor, set_location_from, with_loc
from guppy.ast_util import (
AstVisitor,
find_nodes,
set_location_from,
template_replace,
with_loc,
)
from guppy.cfg.bb import BB, BBStatement
from guppy.cfg.cfg import CFG
from guppy.checker.core import Globals
from guppy.error import GuppyError, InternalGuppyError
from guppy.gtypes import NoneType
from guppy.nodes import NestedFunctionDef, PyExpr
from guppy.nodes import (
DesugaredGenerator,
DesugaredListComp,
IterEnd,
IterHasNext,
IterNext,
MakeIter,
NestedFunctionDef,
PyExpr,
)

# In order to build expressions, need an endless stream of unique temporary variables
# to store intermediate results
Expand Down Expand Up @@ -142,6 +157,35 @@ def visit_While(self, node: ast.While, bb: BB, jumps: Jumps) -> BB | None:
# its own jumps since the body is not guaranteed to execute
return tail_bb

def visit_For(self, node: ast.For, bb: BB, jumps: Jumps) -> BB | None:
template = """
it = make_iter
while True:
b, it = has_next
if b:
x, it = get_next
body
else:
break
end_iter # Consume iterator one last time
"""

it = make_var(next(tmp_vars), node.iter)
b = make_var(next(tmp_vars), node.iter)
new_nodes = template_replace(
template,
node,
it=it,
b=b,
x=node.target,
make_iter=with_loc(node.iter, MakeIter(value=node.iter, origin_node=node)),
has_next=with_loc(node.iter, IterHasNext(value=it)),
get_next=with_loc(node.iter, IterNext(value=it)),
end_iter=with_loc(node.iter, IterEnd(value=it)),
body=node.body,
)
return self.visit_stmts(new_nodes, bb, jumps)

def visit_Continue(self, node: ast.Continue, bb: BB, jumps: Jumps) -> BB | None:
if not jumps.continue_bb:
raise InternalGuppyError("Continue BB not defined")
Expand Down Expand Up @@ -211,20 +255,11 @@ def build(node: ast.expr, cfg: CFG, bb: BB) -> tuple[ast.expr, BB]:
builder = ExprBuilder(cfg, bb)
return builder.visit(node), builder.bb

@classmethod
def _make_var(cls, name: str, loc: ast.expr | None = None) -> ast.Name:
"""Creates an `ast.Name` node."""
node = ast.Name(id=name, ctx=ast.Load)
if loc is not None:
set_location_from(node, loc)
return node

@classmethod
def _tmp_assign(cls, tmp_name: str, value: ast.expr, bb: BB) -> None:
"""Adds a temporary variable assignment to a basic block."""
node = ast.Assign(targets=[cls._make_var(tmp_name, value)], value=value)
set_location_from(node, value)
bb.statements.append(node)
lhs = make_var(tmp_name, value)
bb.statements.append(make_assign([lhs], value))

def visit_Name(self, node: ast.Name) -> ast.Name:
return node
Expand Down Expand Up @@ -256,7 +291,44 @@ def visit_IfExp(self, node: ast.IfExp) -> ast.Name:
self.bb = merge_bb

# The final value is stored in the temporary variable
return self._make_var(tmp, node)
return make_var(tmp, node)

def visit_ListComp(self, node: ast.ListComp) -> ast.AST:
# Check for illegal expressions
illegals = find_nodes(is_illegal_in_list_comp, node)
if illegals:
raise GuppyError(
"Expression is not supported inside a list comprehension", illegals[0]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we maybe include all of them and not just the zeroth?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be nice, but at the moment we can only report a single location that is highlighted in the error message (see for example here)

)

# Desugar into statements that create the iterator, check for a next element,
# get the next element, and finalise the iterator.
gens = []
for g in node.generators:
if g.is_async:
raise GuppyError("Async generators are not supported", g)
g.iter = self.visit(g.iter)
it = make_var(next(tmp_vars), g.iter)
hasnext = make_var(next(tmp_vars), g.iter)
desugared = DesugaredGenerator(
iter=it,
hasnext=hasnext,
iter_assign=make_assign(
[it], with_loc(it, MakeIter(value=g.iter, origin_node=node))
),
hasnext_assign=make_assign(
[hasnext, it], with_loc(it, IterHasNext(value=it))
),
next_assign=make_assign(
[g.target, it], with_loc(it, IterNext(value=it))
),
iterend=with_loc(it, IterEnd(value=it)),
ifs=g.ifs,
)
gens.append(desugared)

node.elt = self.visit(node.elt)
return with_loc(node, DesugaredListComp(elt=node.elt, generators=gens))

def visit_Call(self, node: ast.Call) -> ast.AST:
# Parse compile-time evaluated `py(...)` expression
Expand Down Expand Up @@ -291,7 +363,7 @@ def generic_visit(self, node: ast.AST) -> ast.AST:
self._tmp_assign(tmp, false_const, false_bb)
merge_bb = self.cfg.new_bb(true_bb, false_bb)
self.bb = merge_bb
return self._make_var(tmp, node)
return make_var(tmp, node)
# For all other expressions, just recurse deeper with the node transformer
return super().generic_visit(node)

Expand Down Expand Up @@ -414,3 +486,26 @@ def is_short_circuit_expr(node: ast.AST) -> bool:
return isinstance(node, ast.BoolOp) or (
isinstance(node, ast.Compare) and len(node.comparators) > 1
)


def is_illegal_in_list_comp(node: ast.AST) -> bool:
    """Tells whether `node` is an expression rejected inside list comprehensions.

    This is the case for conditional expressions (`ast.IfExp`), assignment
    expressions (`ast.NamedExpr`), and anything `is_short_circuit_expr` accepts.
    """
    if isinstance(node, (ast.IfExp, ast.NamedExpr)):
        return True
    return is_short_circuit_expr(node)


def make_var(name: str, loc: ast.AST | None = None) -> ast.Name:
"""Creates an `ast.Name` node."""
node = ast.Name(id=name, ctx=ast.Load)
if loc is not None:
set_location_from(node, loc)
return node


def make_assign(lhs: list[ast.AST], value: ast.expr) -> ast.Assign:
    """Builds an `ast.Assign` node that assigns `value` to the given targets.

    A single entry in `lhs` is used directly as the assignment target; several
    entries are bundled into one `ast.Tuple` target with a store context. Both
    the tuple and the assignment are passed through `with_loc` with `value`.
    `lhs` must not be empty.
    """
    assert len(lhs) > 0
    target = (
        lhs[0]
        if len(lhs) == 1
        else with_loc(value, ast.Tuple(elts=lhs, ctx=ast.Store()))
    )
    return with_loc(value, ast.Assign(targets=[target], value=value))