-
Notifications
You must be signed in to change notification settings - Fork 3
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: Build for loops and comprehensions #68
Changes from 3 commits
359c2da
633421d
7630f7b
70c5d07
fd8f56c
6790161
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,13 +3,28 @@ | |
from collections.abc import Iterator | ||
from typing import NamedTuple | ||
|
||
from guppy.ast_util import AstVisitor, set_location_from, with_loc | ||
from guppy.ast_util import ( | ||
AstVisitor, | ||
find_nodes, | ||
set_location_from, | ||
template_replace, | ||
with_loc, | ||
) | ||
from guppy.cfg.bb import BB, BBStatement | ||
from guppy.cfg.cfg import CFG | ||
from guppy.checker.core import Globals | ||
from guppy.error import GuppyError, InternalGuppyError | ||
from guppy.gtypes import NoneType | ||
from guppy.nodes import NestedFunctionDef, PyExpr | ||
from guppy.nodes import ( | ||
DesugaredGenerator, | ||
DesugaredListComp, | ||
IterEnd, | ||
IterHasNext, | ||
IterNext, | ||
MakeIter, | ||
NestedFunctionDef, | ||
PyExpr, | ||
) | ||
|
||
# In order to build expressions, need an endless stream of unique temporary variables | ||
# to store intermediate results | ||
|
@@ -142,6 +157,35 @@ def visit_While(self, node: ast.While, bb: BB, jumps: Jumps) -> BB | None: | |
# its own jumps since the body is not guaranteed to execute | ||
return tail_bb | ||
|
||
def visit_For(self, node: ast.For, bb: BB, jumps: Jumps) -> BB | None: | ||
template = """ | ||
it = make_iter | ||
while True: | ||
b, it = has_next | ||
if b: | ||
x, it = get_next | ||
body | ||
else: | ||
break | ||
end_iter # Consume iterator one last time | ||
""" | ||
|
||
it = make_var(next(tmp_vars), node.iter) | ||
b = make_var(next(tmp_vars), node.iter) | ||
new_nodes = template_replace( | ||
template, | ||
node, | ||
it=it, | ||
b=b, | ||
x=node.target, | ||
make_iter=with_loc(node.iter, MakeIter(value=node.iter, origin_node=node)), | ||
has_next=with_loc(node.iter, IterHasNext(value=it)), | ||
get_next=with_loc(node.iter, IterNext(value=it)), | ||
end_iter=with_loc(node.iter, IterEnd(value=it)), | ||
body=node.body, | ||
) | ||
return self.visit_stmts(new_nodes, bb, jumps) | ||
|
||
def visit_Continue(self, node: ast.Continue, bb: BB, jumps: Jumps) -> BB | None: | ||
if not jumps.continue_bb: | ||
raise InternalGuppyError("Continue BB not defined") | ||
|
@@ -211,20 +255,11 @@ def build(node: ast.expr, cfg: CFG, bb: BB) -> tuple[ast.expr, BB]: | |
builder = ExprBuilder(cfg, bb) | ||
return builder.visit(node), builder.bb | ||
|
||
@classmethod | ||
def _make_var(cls, name: str, loc: ast.expr | None = None) -> ast.Name: | ||
"""Creates an `ast.Name` node.""" | ||
node = ast.Name(id=name, ctx=ast.Load) | ||
if loc is not None: | ||
set_location_from(node, loc) | ||
return node | ||
|
||
@classmethod | ||
def _tmp_assign(cls, tmp_name: str, value: ast.expr, bb: BB) -> None: | ||
"""Adds a temporary variable assignment to a basic block.""" | ||
node = ast.Assign(targets=[cls._make_var(tmp_name, value)], value=value) | ||
set_location_from(node, value) | ||
bb.statements.append(node) | ||
lhs = make_var(tmp_name, value) | ||
bb.statements.append(make_assign([lhs], value)) | ||
|
||
def visit_Name(self, node: ast.Name) -> ast.Name: | ||
return node | ||
|
@@ -256,7 +291,44 @@ def visit_IfExp(self, node: ast.IfExp) -> ast.Name: | |
self.bb = merge_bb | ||
|
||
# The final value is stored in the temporary variable | ||
return self._make_var(tmp, node) | ||
return make_var(tmp, node) | ||
|
||
def visit_ListComp(self, node: ast.ListComp) -> ast.AST: | ||
# Check for illegal expressions | ||
illegals = find_nodes(is_illegal_in_list_comp, node) | ||
if illegals: | ||
raise GuppyError( | ||
"Expression is not supported inside a list comprehension", illegals[0] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should we maybe include all of them and not just the zeroth? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It would be nice, but at the moment we can only report a single location that is highlighted in the error message (see for example here) |
||
) | ||
|
||
# Desugar into statements that create the iterator, check for a next element, | ||
# get the next element, and finalise the iterator. | ||
gens = [] | ||
for g in node.generators: | ||
if g.is_async: | ||
raise GuppyError("Async generators are not supported", g) | ||
g.iter = self.visit(g.iter) | ||
it = make_var(next(tmp_vars), g.iter) | ||
hasnext = make_var(next(tmp_vars), g.iter) | ||
desugared = DesugaredGenerator( | ||
iter=it, | ||
hasnext=hasnext, | ||
iter_assign=make_assign( | ||
[it], with_loc(it, MakeIter(value=g.iter, origin_node=node)) | ||
), | ||
hasnext_assign=make_assign( | ||
[hasnext, it], with_loc(it, IterHasNext(value=it)) | ||
), | ||
next_assign=make_assign( | ||
[g.target, it], with_loc(it, IterNext(value=it)) | ||
), | ||
iterend=with_loc(it, IterEnd(value=it)), | ||
ifs=g.ifs, | ||
) | ||
gens.append(desugared) | ||
|
||
node.elt = self.visit(node.elt) | ||
return with_loc(node, DesugaredListComp(elt=node.elt, generators=gens)) | ||
|
||
def visit_Call(self, node: ast.Call) -> ast.AST: | ||
# Parse compile-time evaluated `py(...)` expression | ||
|
@@ -291,7 +363,7 @@ def generic_visit(self, node: ast.AST) -> ast.AST: | |
self._tmp_assign(tmp, false_const, false_bb) | ||
merge_bb = self.cfg.new_bb(true_bb, false_bb) | ||
self.bb = merge_bb | ||
return self._make_var(tmp, node) | ||
return make_var(tmp, node) | ||
# For all other expressions, just recurse deeper with the node transformer | ||
return super().generic_visit(node) | ||
|
||
|
@@ -414,3 +486,26 @@ def is_short_circuit_expr(node: ast.AST) -> bool: | |
return isinstance(node, ast.BoolOp) or ( | ||
isinstance(node, ast.Compare) and len(node.comparators) > 1 | ||
) | ||
|
||
|
||
def is_illegal_in_list_comp(node: ast.AST) -> bool: | ||
"""Checks if an expression is illegal to use in a list comprehension.""" | ||
return isinstance(node, ast.IfExp | ast.NamedExpr) or is_short_circuit_expr(node) | ||
|
||
|
||
def make_var(name: str, loc: ast.AST | None = None) -> ast.Name: | ||
"""Creates an `ast.Name` node.""" | ||
node = ast.Name(id=name, ctx=ast.Load) | ||
if loc is not None: | ||
set_location_from(node, loc) | ||
return node | ||
|
||
|
||
def make_assign(lhs: list[ast.AST], value: ast.expr) -> ast.Assign: | ||
"""Creates an `ast.Assign` node.""" | ||
assert len(lhs) > 0 | ||
if len(lhs) == 1: | ||
target = lhs[0] | ||
else: | ||
target = with_loc(value, ast.Tuple(elts=lhs, ctx=ast.Store())) | ||
return with_loc(value, ast.Assign(targets=[target], value=value)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Note I think you could just make this an
@dataclass
and avoid the constructor. (@dataclass(eq=False, frozen=True)
or similar if you want to)