Skip to content

Commit

Permalink
Merge pull request #8 from kavigupta/class-bases-and-decorators
Browse files Browse the repository at this point in the history
Class bases and decorators
  • Loading branch information
kavigupta authored Feb 13, 2024
2 parents 0d84efe + 1986980 commit f40b1ab
Show file tree
Hide file tree
Showing 7 changed files with 80 additions and 11 deletions.
2 changes: 1 addition & 1 deletion ast_scope/annotate.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def static_dependency_graph(self):
g.add_nodes_from(variables)
varis = self.global_scope.variables
for construct in varis.functions | varis.classes:
for node in get_all_nodes(construct):
for node in get_all_nodes(*construct.body):
if node not in self:
continue
if self[node] is not self._global_scope:
Expand Down
12 changes: 11 additions & 1 deletion ast_scope/annotator.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,17 @@ def visit_ClassDef(self, class_node):
subscope = self.create_subannotator(
IntermediateClassScope(class_node, self.scope, self.class_binds_near)
)
ast.NodeVisitor.generic_visit(subscope, class_node)
assert class_node._fields == (
"name",
"bases",
"keywords",
"body",
"decorator_list",
)
visit_all(subscope, class_node.body)
visit_all(
self, class_node.bases, class_node.keywords, class_node.decorator_list
)

def visit_Global(self, global_node):
for name in global_node.names:
Expand Down
10 changes: 9 additions & 1 deletion ast_scope/pull_scope.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,5 +90,13 @@ def visit_ClassDef(self, node):
scope = self.pull_scope(node)
if node not in self.node_to_corresponding_scope:
self.node_to_corresponding_scope[node] = ClassScope(node, scope)
assert node._fields == (
"name",
"bases",
"keywords",
"body",
"decorator_list",
)
visit_all(self, node.bases, node.keywords, node.decorator_list)
scope.add_class(node, self.node_to_corresponding_scope[node])
super().generic_visit(node)
visit_all(self, node.body)
8 changes: 5 additions & 3 deletions ast_scope/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,12 @@ def generic_visit(self, node):
super().generic_visit(node)


def get_all_nodes(node):
def get_all_nodes(*nodes):
getter = GetAllNodes()
getter.visit(node)
return [subnode for subnode in getter.nodes if subnode is not node]
for node in nodes:
getter.visit(node)
nodes = set(nodes)
return [subnode for subnode in getter.nodes if subnode not in nodes]


class GetName(GroupSimilarConstructsVisitor):
Expand Down
10 changes: 10 additions & 0 deletions tests/class_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,3 +133,13 @@ def test_listcomp_value(self):
[{g}x for {~@3:11}t in {-X@1:0}x]
"""
)

def test_subclass(self):
self.assertAnnotationWorks(
"""
{g}class X:
{-X@1:0}x = 2
{g}class Y({g}X):
pass
"""
)
34 changes: 34 additions & 0 deletions tests/types_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import ast

import ast_scope
from .utils import DisplayAnnotatedTestCase, remove_directives


class TypeAnnotationTest(DisplayAnnotatedTestCase):
def test_basic_assignment(self):
annotated_code = """
{<3.8!g}@{g}f
{>=3.8!g}class A:
pass
{g}class B({g}A, x={g}A):
pass
{g}C = {g}B
{g}CONSTANT: {g}C = {g}C()
{g}CONST: {g}C
"""
self.assertAnnotationWorks(annotated_code)
scope_info = ast_scope.annotate(ast.parse(remove_directives(annotated_code)))

self.assertEqual(
scope_info.static_dependency_graph._DiGraph__adjacency_list,
{
"B": set(),
"C": set(),
"CONSTANT": set(),
"CONST": set(),
"A": set(),
"f": set(),
},
)
15 changes: 10 additions & 5 deletions tests/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

import sys
import ast
import re
Expand Down Expand Up @@ -93,6 +92,10 @@ def all_nodes_gen_for_scope(scope):
for node in all_nodes_gen_for_variables(scope.variables):
yield scope, node


def remove_directives(annotated_code):
return trim(re.sub(r"\{[^\}]+\}", "", annotated_code))

class DisplayAnnotatedTestCase(unittest.TestCase):
def _check_nodes(self, mapping, *scopes):
overall_scope = [item for scope in scopes for item in all_nodes_gen_for_scope(scope)]
Expand All @@ -102,15 +105,17 @@ def _check_nodes(self, mapping, *scopes):

def assertAnnotationWorks(self, annotated_code, code=None, *, class_binds_near=False):
# directives of the form {>version!scope} are removed unless the version is satisfied
regex = r"\{>=(\d+\.\d+)!([^\}]+)\}"
regex = r"\{(<|>=)(\d+\.\d+)!([^\}]+)\}"
def replacer(match):
version, scope = match.groups()
if sys.version_info >= tuple(map(int, version.split("."))):
comparator, version, scope = match.groups()
geq_expected = comparator == ">="
geq_actual = sys.version_info >= tuple(map(int, version.split(".")))
if geq_actual == geq_expected:
return "{" + scope + "}"
return ""
annotated_code = re.sub(regex, replacer, annotated_code)
if code is None:
code = trim(re.sub(r"\{[^\}]+\}", "", annotated_code))
code = remove_directives(annotated_code)

scope_info = annotate(ast.parse(code), class_binds_near)
scope_info.static_dependency_graph # just check for errors
Expand Down

0 comments on commit f40b1ab

Please sign in to comment.