Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add static typing and more exports #5

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion ast_scope/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from .annotate import annotate
from .annotate import annotate, ScopeInfo
from .scope import Scope
17 changes: 9 additions & 8 deletions ast_scope/annotate.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from ast_scope.scope import FunctionScope
from .annotator import AnnotateScope, IntermediateGlobalScope
import ast
from ast_scope.scope import ErrorScope, FunctionScope, GlobalScope, Scope
from .annotator import AnnotateScope, IntermediateGlobalScope, IntermediateScope
from .pull_scope import PullScopes
from .utils import get_all_nodes, get_name
from .graph import DiGraph


class ScopeInfo:
def __init__(self, tree, global_scope, error_scope, node_to_containing_scope):
def __init__(self, tree: ast.AST, global_scope: GlobalScope, error_scope: ErrorScope, node_to_containing_scope: dict[ast.AST, Scope]):
self._tree = tree
self._global_scope = global_scope
self._error_scope = error_scope
Expand Down Expand Up @@ -42,13 +43,13 @@ def static_dependency_graph(self):
def __iter__(self):
return iter(self._node_to_containing_scope)

def __contains__(self, node):
def __contains__(self, node: ast.AST):
return node in self._node_to_containing_scope

def __getitem__(self, node):
def __getitem__(self, node: ast.AST):
return self._node_to_containing_scope[node]

def function_scope_for(self, node):
def function_scope_for(self, node: ast.AST):
"""
Returns the function scope for the given FunctionDef node.
"""
Expand All @@ -61,8 +62,8 @@ def function_scope_for(self, node):
return None


def annotate(tree, class_binds_near=False):
annotation_dict = {}
def annotate(tree: ast.AST, class_binds_near: bool=False):
annotation_dict: dict[ast.AST, tuple[str, IntermediateScope, bool]] = {}
annotator = AnnotateScope(
IntermediateGlobalScope(), annotation_dict, class_binds_near=class_binds_near
)
Expand Down
171 changes: 88 additions & 83 deletions ast_scope/annotator.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import ast
import abc
import ast
from typing import Iterable, Self

from .group_similar_constructs import GroupSimilarConstructsVisitor
from .utils import name_of_alias
Expand All @@ -12,61 +13,66 @@ class IntermediateScope(abc.ABC):
"""

def __init__(self):
self.referenced_variables = set()
self.written_variables = set()
self.nonlocal_variables = set()
self.global_variables = set()
self.referenced_variables: set[str] = set()
self.written_variables: set[str] = set()
self.nonlocal_variables: set[str] = set()
self.global_variables: set[str] = set()

def load(self, variable):
def load(self, variable: str):
self.referenced_variables.add(variable)

def modify(self, variable):
def modify(self, variable: str):
self.written_variables.add(variable)

def globalize(self, variable):
def globalize(self, variable: str):
self.global_variables.add(variable)

def nonlocalize(self, variable):
def nonlocalize(self, variable: str):
self.nonlocal_variables.add(variable)

@abc.abstractmethod
def global_frame(self):
def global_frame(self) -> 'IntermediateGlobalScope':
pass

@abc.abstractmethod
def find(self, name, global_acceptable=True):
def find(self, name: str, is_assignment: bool, global_acceptable: bool=True) -> Self | None:
"""
Finds the actual frame containing the variable name, or None if no frame exists
"""
pass

def true_parent(self):
parent = self.parent
while isinstance(parent, IntermediateClassScope):
parent = parent.parent
return parent


class IntermediateGlobalScope(IntermediateScope):
def find(self, name, is_assignment, global_acceptable=True):
def find(self, name: str, is_assignment: bool, global_acceptable: bool=True):
if not global_acceptable:
return None
return self

def global_frame(self):
return self



class IntermediateFunctionScope(IntermediateScope):
def __init__(self, node, parent):
class IntermediateScopeWithParent(IntermediateScope):
def __init__(self, parent: IntermediateScope):
self.parent = parent
super().__init__()

def true_parent(self) -> IntermediateScope:
parent = self.parent
while isinstance(parent, IntermediateClassScope):
parent = parent.parent
return parent


class IntermediateFunctionScope(IntermediateScopeWithParent):
def __init__(self, node: ast.FunctionDef | ast.AsyncFunctionDef | ast.comprehension | ast.Lambda, parent: IntermediateScope):
super().__init__(parent)
self.node = node
self.parent = parent

def global_frame(self):
def global_frame(self) -> IntermediateGlobalScope:
return self.true_parent().global_frame()

def find(self, name, is_assignment, global_acceptable=True):
def find(self, name: str, is_assignment: bool, global_acceptable: bool=True):
if name in self.global_variables:
return self.global_frame()
if name in self.nonlocal_variables:
Expand All @@ -76,40 +82,38 @@ def find(self, name, is_assignment, global_acceptable=True):
return self.true_parent().find(name, is_assignment, global_acceptable)


class IntermediateClassScope(IntermediateScope):
def __init__(self, node, parent, class_binds_near):
super().__init__()
class IntermediateClassScope(IntermediateScopeWithParent):
def __init__(self, node: ast.ClassDef, parent: IntermediateScope, class_binds_near: bool):
super().__init__(parent)
self.node = node
self.parent = parent
self.class_binds_near = class_binds_near

def global_frame(self):
return self.true_parent().find(self)
def global_frame(self) -> IntermediateGlobalScope:
return self.true_parent().global_frame()

def find(self, name, is_assignment, global_acceptable=True):
def find(self, name: str, is_assignment: bool, global_acceptable: bool=True):
if self.class_binds_near:
# anything can be in a class frame
return self
if is_assignment:
return self
return self.parent.find(name, is_assignment, global_acceptable)


class GrabVariable(ast.NodeVisitor):
"""
Dumps variables from a given name object into the given scope.
"""

def __init__(self, scope, variable, annotation_dict):
def __init__(self, scope: IntermediateScope, variable: ast.Name, annotation_dict: dict[ast.AST, tuple[str, IntermediateScope, bool]]):
self.scope = scope
self.variable = variable
self.annotation_dict = annotation_dict

def visit_generic(self, node):
def visit_generic(self, node: ast.AST):
raise RuntimeError("Unsupported node type: {node}".format(node=node))

def visit_Name(self, node):
super().visit_generic(node)
def visit_Name(self, node: ast.Name):
super().generic_visit(node)

def load(self):
self.annotation_dict[self.variable] = self.variable.id, self.scope, False
Expand All @@ -119,70 +123,71 @@ def modify(self):
self.annotation_dict[self.variable] = self.variable.id, self.scope, True
self.scope.modify(self.variable.id)

def visit_Load(self, _):
def visit_Load(self, node: ast.Load):
self.load()

def visit_Store(self, _):
def visit_Store(self, node: ast.Store):
self.modify()

def visit_Del(self, _):
def visit_Del(self, node: ast.Del):
self.modify()

def visit_AugLoad(self, _):
def visit_AugLoad(self, node: ast.AugLoad):
raise RuntimeError("Unsupported: AugStore")

def visit_AugStore(self, _):
def visit_AugStore(self, node: ast.AugStore):
raise RuntimeError("Unsupported: AugStore")


class ProcessArguments(ast.NodeVisitor):
def __init__(self, expr_scope, arg_scope):
def __init__(self, expr_scope: 'AnnotateScope', arg_scope: 'AnnotateScope'):
self.expr_scope = expr_scope
self.arg_scope = arg_scope

def visit_arg(self, node):
def visit_arg(self, node: ast.arg):
self.arg_scope.visit(node)
visit_all(self.expr_scope, node.annotation, getattr(node, "type_comment", None))

def visit_arguments(self, node):
def visit_arguments(self, node: ast.AST):
super().generic_visit(node)

def generic_visit(self, node):
def generic_visit(self, node: ast.AST):
self.expr_scope.visit(node)


class AnnotateScope(GroupSimilarConstructsVisitor):
def __init__(self, scope, annotation_dict, class_binds_near):
def __init__(self, scope: IntermediateScope, annotation_dict: dict[ast.AST, tuple[str, IntermediateScope, bool]], class_binds_near: bool):
self.scope = scope
self.annotation_dict = annotation_dict
self.class_binds_near = class_binds_near

def annotate_intermediate_scope(self, node, name, is_assign):
def annotate_intermediate_scope(self, node: ast.AST, name: str, is_assign: bool):
self.annotation_dict[node] = name, self.scope, is_assign

def visit_Name(self, name_node):
GrabVariable(self.scope, name_node, self.annotation_dict).generic_visit(
name_node
def visit_Name(self, node: ast.Name):
GrabVariable(self.scope, node, self.annotation_dict).generic_visit(
node
)

def visit_ExceptHandler(self, handler_node):
self.annotate_intermediate_scope(handler_node, handler_node.name, True)
self.scope.modify(handler_node.name)
visit_all(self, handler_node.type, handler_node.body)
def visit_ExceptHandler(self, node: ast.ExceptHandler):
assert node.name
self.annotate_intermediate_scope(node, node.name, True)
self.scope.modify(node.name)
visit_all(self, node.type, node.body)

def visit_alias(self, alias_node):
variable = name_of_alias(alias_node)
self.annotate_intermediate_scope(alias_node, variable, True)
def visit_alias(self, node: ast.alias):
variable = name_of_alias(node)
self.annotate_intermediate_scope(node, variable, True)
self.scope.modify(variable)

def visit_arg(self, arg):
self.annotate_intermediate_scope(arg, arg.arg, True)
self.scope.modify(arg.arg)
def visit_arg(self, node: ast.arg):
self.annotate_intermediate_scope(node, node.arg, True)
self.scope.modify(node.arg)

def create_subannotator(self, scope):
def create_subannotator(self, scope: IntermediateScope):
return AnnotateScope(scope, self.annotation_dict, self.class_binds_near)

def visit_function_def(self, func_node, is_async):
def visit_function_def(self, func_node: ast.FunctionDef | ast.AsyncFunctionDef, is_async: bool):
del is_async
self.annotate_intermediate_scope(func_node, func_node.name, True)
self.scope.modify(func_node.name)
Expand All @@ -195,19 +200,19 @@ def visit_function_def(self, func_node, is_async):
ProcessArguments(self, subscope).visit(func_node.args)
visit_all(subscope, func_node.body, func_node.returns)

def visit_Lambda(self, func_node):
self.annotate_intermediate_scope(func_node, "<lambda>", None)
def visit_Lambda(self, node: ast.Lambda):
self.annotate_intermediate_scope(node, "<lambda>", False)
subscope = self.create_subannotator(
IntermediateFunctionScope(func_node, self.scope)
IntermediateFunctionScope(node, self.scope)
)
ProcessArguments(self, subscope).visit(func_node.args)
visit_all(subscope, func_node.body)
ProcessArguments(self, subscope).visit(node.args)
visit_all(subscope, node.body)

def visit_comprehension_generic(self, targets, comprehensions, typ):
del typ
def visit_comprehension_generic(self, targets: list[ast.expr], comprehensions: list[ast.comprehension], node: ast.AST):
del node
current_scope = self
for comprehension in comprehensions:
self.annotate_intermediate_scope(comprehension, "<comp>", None)
self.annotate_intermediate_scope(comprehension, "<comp>", False)
subscope = self.create_subannotator(
IntermediateFunctionScope(comprehension, current_scope.scope)
)
Expand All @@ -216,38 +221,38 @@ def visit_comprehension_generic(self, targets, comprehensions, typ):
current_scope = subscope
visit_all(current_scope, targets)

def visit_ClassDef(self, class_node):
self.annotate_intermediate_scope(class_node, class_node.name, True)
self.scope.modify(class_node.name)
def visit_ClassDef(self, node: ast.ClassDef):
self.annotate_intermediate_scope(node, node.name, True)
self.scope.modify(node.name)
subscope = self.create_subannotator(
IntermediateClassScope(class_node, self.scope, self.class_binds_near)
IntermediateClassScope(node, self.scope, self.class_binds_near)
)
assert class_node._fields == (
assert node._fields == (
"name",
"bases",
"keywords",
"body",
"decorator_list",
)
visit_all(subscope, class_node.body)
visit_all(subscope, node.body)
visit_all(
self, class_node.bases, class_node.keywords, class_node.decorator_list
self, node.bases, node.keywords, node.decorator_list
)

def visit_Global(self, global_node):
for name in global_node.names:
def visit_Global(self, node: ast.Global):
for name in node.names:
self.scope.globalize(name)

def visit_Nonlocal(self, nonlocal_node):
for name in nonlocal_node.names:
def visit_Nonlocal(self, node: ast.Nonlocal):
for name in node.names:
self.scope.nonlocalize(name)


def visit_all(visitor, *nodes):
def visit_all(visitor: ast.NodeVisitor, *nodes: Iterable[ast.AST] | ast.AST | None):
for node in nodes:
if node is None:
pass
elif isinstance(node, list):
elif isinstance(node, Iterable):
visit_all(visitor, *node)
else:
visitor.visit(node)
Loading