Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Type check lists and comprehensions #69

Merged
merged 14 commits into from
Jan 16, 2024
190 changes: 153 additions & 37 deletions guppy/ast_util.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import ast
from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar
import textwrap
from collections.abc import Callable, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, cast

if TYPE_CHECKING:
from guppy.gtypes import GuppyType
Expand Down Expand Up @@ -54,51 +56,165 @@ def generic_visit(self, node: Any, *args: Any, **kwargs: Any) -> T:
raise NotImplementedError(f"visit_{node.__class__.__name__} is not implemented")


class NameVisitor(ast.NodeVisitor):
"""Visitor to collect all `Name` nodes occurring in an AST."""

names: list[ast.Name]

def __init__(self) -> None:
self.names = []

def visit_Name(self, node: ast.Name) -> None:
self.names.append(node)
class AstSearcher(ast.NodeVisitor):
"""Visitor that searches for occurrences of specific nodes in an AST."""

matcher: Callable[[ast.AST], bool]
dont_recurse_into: set[type[ast.AST]]
found: list[ast.AST]
is_first_node: bool

def __init__(
self,
matcher: Callable[[ast.AST], bool],
dont_recurse_into: set[type[ast.AST]] | None = None,
) -> None:
self.matcher = matcher
self.dont_recurse_into = dont_recurse_into or set()
self.found = []
self.is_first_node = True

def generic_visit(self, node: ast.AST) -> None:
if self.matcher(node):
self.found.append(node)
if self.is_first_node or type(node) not in self.dont_recurse_into:
self.is_first_node = False
super().generic_visit(node)


def find_nodes(
    matcher: Callable[[ast.AST], bool],
    node: ast.AST,
    dont_recurse_into: set[type[ast.AST]] | None = None,
) -> list[ast.AST]:
    """Collects the nodes below (and including) `node` accepted by `matcher`.

    Children of nodes whose type is listed in `dont_recurse_into` are not
    searched. `node` itself is always searched, regardless of its type.
    """
    searcher = AstSearcher(matcher, dont_recurse_into)
    searcher.visit(node)
    return searcher.found


def name_nodes_in_ast(node: Any) -> list[ast.Name]:
"""Returns all `Name` nodes occurring in an AST."""
v = NameVisitor()
v.visit(node)
return v.names


class ReturnVisitor(ast.NodeVisitor):
"""Visitor to collect all `Return` nodes occurring in an AST."""
found = find_nodes(lambda n: isinstance(n, ast.Name), node)
return cast(list[ast.Name], found)

returns: list[ast.Return]
inside_func_def: bool

def __init__(self) -> None:
self.returns = []
self.inside_func_def = False

def visit_Return(self, node: ast.Return) -> None:
self.returns.append(node)
def return_nodes_in_ast(node: Any) -> list[ast.Return]:
"""Returns all `Return` nodes occurring in an AST."""
found = find_nodes(lambda n: isinstance(n, ast.Return), node, {ast.FunctionDef})
return cast(list[ast.Return], found)

def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
# Don't descend into nested function definitions
if not self.inside_func_def:
self.inside_func_def = True
for n in node.body:
self.visit(n)

def breaks_in_loop(node: Any) -> list[ast.Break]:
"""Returns all `Break` nodes occurring in a loop.

def return_nodes_in_ast(node: Any) -> list[ast.Return]:
"""Returns all `Return` nodes occurring in an AST."""
v = ReturnVisitor()
v.visit(node)
return v.returns
Note that breaks in nested loops are excluded.
"""
found = find_nodes(
lambda n: isinstance(n, ast.Break), node, {ast.For, ast.While, ast.FunctionDef}
)
return cast(list[ast.Break], found)


class ContextAdjuster(ast.NodeTransformer):
    """Updates the `ast.Context` indicating if expressions occur on the LHS or RHS.

    Each handled node is replaced by a fresh node of the same kind whose `ctx`
    is the context given at construction. The replacement is passed through
    `with_loc` together with the original node (presumably to carry the source
    location over; `with_loc` is defined elsewhere in this module).
    """

    # The context that is written onto every rebuilt node
    ctx: ast.expr_context

    def __init__(self, ctx: ast.expr_context) -> None:
        self.ctx = ctx

    def visit(self, node: ast.AST) -> ast.AST:
        # Only narrows the return type of `NodeTransformer.visit`
        return cast(ast.AST, super().visit(node))

    def visit_Name(self, node: ast.Name) -> ast.Name:
        return with_loc(node, ast.Name(id=node.id, ctx=self.ctx))

    def visit_Starred(self, node: ast.Starred) -> ast.Starred:
        return with_loc(node, ast.Starred(value=self.visit(node.value), ctx=self.ctx))

    def visit_Tuple(self, node: ast.Tuple) -> ast.Tuple:
        return with_loc(
            node, ast.Tuple(elts=[self.visit(elt) for elt in node.elts], ctx=self.ctx)
        )

    def visit_List(self, node: ast.List) -> ast.List:
        return with_loc(
            node, ast.List(elts=[self.visit(elt) for elt in node.elts], ctx=self.ctx)
        )

    def visit_Subscript(self, node: ast.Subscript) -> ast.Subscript:
        # Don't adjust the slice!
        return with_loc(
            node,
            ast.Subscript(value=self.visit(node.value), slice=node.slice, ctx=self.ctx),
        )

    def visit_Attribute(self, node: ast.Attribute) -> ast.Attribute:
        # Go through `with_loc` like every other handler so that the rebuilt
        # attribute node is treated the same way as its siblings.
        # NOTE(review): the Python parser always gives the `value` of an
        # attribute/subscript a `Load` context, while this class rewrites it to
        # `self.ctx` as well - confirm that downstream code relies on this.
        return with_loc(
            node,
            ast.Attribute(value=self.visit(node.value), attr=node.attr, ctx=self.ctx),
        )


class TemplateReplacer(ast.NodeTransformer):
    """Substitutes the placeholder names of a parsed template.

    Every `Name` in the template is looked up in `replacements` by its
    identifier. All remaining template nodes are passed through `with_loc`
    together with `default_loc`.
    """

    # Placeholder identifier -> node (or sequence of nodes) to insert
    replacements: Mapping[str, ast.AST | Sequence[ast.AST]]
    # Node handed to `with_loc` for template nodes that are not replaced
    default_loc: ast.AST

    def __init__(
        self,
        replacements: Mapping[str, ast.AST | Sequence[ast.AST]],
        default_loc: ast.AST,
    ) -> None:
        self.default_loc = default_loc
        self.replacements = replacements

    def _get_replacement(self, x: str) -> ast.AST | Sequence[ast.AST]:
        """Looks up the replacement for `x`, raising `ValueError` if missing."""
        if x in self.replacements:
            return self.replacements[x]
        raise ValueError(f"No replacement for `{x}` is given")

    def visit_Name(self, node: ast.Name) -> ast.AST:
        """Swaps a placeholder in expression position for its replacement.

        Raises `TypeError` if the replacement is not an `ast.expr`.
        """
        replacement = self._get_replacement(node.id)
        if isinstance(replacement, ast.expr):
            # The replacement takes over the context of the placeholder
            adjusted = ContextAdjuster(node.ctx).visit(replacement)
            return with_loc(replacement, adjusted)
        raise TypeError(f"Replacement for `{node.id}` must be an expression")

    def visit_Expr(self, node: ast.Expr) -> ast.AST | Sequence[ast.AST]:
        """Expands a placeholder that stands alone as a statement.

        The replacement may be a single node or a sequence of nodes; the
        result is always a list in that case.
        """
        if not isinstance(node.value, ast.Name):
            return self.generic_visit(node)

        replacement = self._get_replacement(node.value.id)
        candidates = replacement if isinstance(replacement, Sequence) else [replacement]
        statements = []
        for item in candidates:
            # Bare expressions need an `Expr` wrapper to be used as statements
            if isinstance(item, ast.expr):
                item = with_loc(item, ast.Expr(value=item))
            statements.append(item)
        return statements

    def generic_visit(self, node: ast.AST) -> ast.AST:
        """Visits the children, then applies the default location to `node`."""
        visited = super().generic_visit(node)
        return with_loc(self.default_loc, visited)


def template_replace(
    template: str, default_loc: ast.AST, **kwargs: ast.AST | Sequence[ast.AST]
) -> list[ast.stmt]:
    """Builds statements from a source template with placeholders filled in.

    The template is dedented and parsed as Python source. Each keyword argument
    names a placeholder and gives the node (or nodes) to put in its place.
    A top-level statement may expand into several statements; these are
    spliced into the result in order.
    """
    parsed = ast.parse(textwrap.dedent(template)).body
    replacer = TemplateReplacer(kwargs, default_loc)
    stmts = []
    for stmt in parsed:
        result = replacer.visit(stmt)
        if not isinstance(result, list):
            stmts.append(result)
            continue
        stmts += result
    return stmts


def line_col(node: ast.AST) -> tuple[int, int]:
Expand Down
33 changes: 26 additions & 7 deletions guppy/cfg/bb.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from typing_extensions import Self

from guppy.ast_util import AstNode, name_nodes_in_ast
from guppy.nodes import NestedFunctionDef
from guppy.nodes import DesugaredListComp, NestedFunctionDef

if TYPE_CHECKING:
from guppy.cfg.cfg import BaseCFG
Expand Down Expand Up @@ -99,24 +99,46 @@ def __init__(self, bb: BB):
self.bb = bb
self.stats = VariableStats()

def visit_Name(self, node: ast.Name) -> None:
self.stats.update_used(node)

def visit_Assign(self, node: ast.Assign) -> None:
self.stats.update_used(node.value)
self.visit(node.value)
for t in node.targets:
for name in name_nodes_in_ast(t):
self.stats.assigned[name.id] = node

def visit_AugAssign(self, node: ast.AugAssign) -> None:
self.stats.update_used(node.value)
self.visit(node.value)
self.stats.update_used(node.target) # The target is also used
for name in name_nodes_in_ast(node.target):
self.stats.assigned[name.id] = node

def visit_AnnAssign(self, node: ast.AnnAssign) -> None:
if node.value:
self.stats.update_used(node.value)
self.visit(node.value)
for name in name_nodes_in_ast(node.target):
self.stats.assigned[name.id] = node

def visit_DesugaredListComp(self, node: DesugaredListComp) -> None:
# Names bound in the comprehension are only available inside, so we shouldn't
# update `self.stats` with assignments
inner_visitor = VariableVisitor(self.bb)
inner_stats = inner_visitor.stats

# The generators are evaluated left to right
for gen in node.generators:
inner_visitor.visit(gen.iter_assign)
inner_visitor.visit(gen.hasnext_assign)
inner_visitor.visit(gen.next_assign)
for cond in gen.ifs:
inner_visitor.visit(cond)
inner_visitor.visit(node.elt)

self.stats.used |= {
x: n for x, n in inner_stats.used.items() if x not in self.stats.assigned
croyzor marked this conversation as resolved.
Show resolved Hide resolved
}

def visit_NestedFunctionDef(self, node: NestedFunctionDef) -> None:
# In order to compute the used external variables in a nested function
# definition, we have to run live variable analysis first
Expand All @@ -139,6 +161,3 @@ def visit_NestedFunctionDef(self, node: NestedFunctionDef) -> None:

# The name of the function is now assigned
self.stats.assigned[node.name] = node

def generic_visit(self, node: ast.AST) -> None:
self.stats.update_used(node)
Loading