-
Notifications
You must be signed in to change notification settings - Fork 1
/
strip.py
159 lines (124 loc) · 4.49 KB
/
strip.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
from __future__ import annotations

import ast
import pathlib
import sys
import tokenize
# Command line: strip.py <input file> [<output file>]
if len(sys.argv) < 2:
    sys.exit(f"Usage: {sys.argv[0]} <input file> [<output file>]")
input_file_path = sys.argv[1]
# tokenize.open() decodes the file as Python itself would: it honours a
# PEP 263 coding cookie or a UTF-8 BOM and otherwise defaults to UTF-8.
# A bare read_text() would use the locale encoding instead.
try:
    with tokenize.open(input_file_path) as source_file:
        input_file = source_file.read()
except (OSError, SyntaxError) as err:
    # tokenize.open() raises SyntaxError for an invalid coding cookie.
    sys.exit(f"{sys.argv[0]}: cannot read {input_file_path}: {err}")
class TypeHintRemover(ast.NodeTransformer):
    """Strip type hints, docstrings, asserts and typing-only imports from an AST.

    Classes decorated with ``dataclass`` are rewritten into plain classes with a
    generated ``__init__``, because the ``dataclasses`` import is stripped too.
    Any statement list that stripping leaves empty is padded with ``pass`` so
    that ``ast.unparse`` still yields valid Python.
    """

    @staticmethod
    def _is_bare_constant(statement: ast.stmt) -> bool:
        """Return True for an expression statement that is only a constant.

        This matches docstrings, and also ``...`` placeholders and other bare
        literals, which are no-ops at runtime.
        """
        return isinstance(statement, ast.Expr) and isinstance(
            statement.value, ast.Constant
        )

    @staticmethod
    def _ensure_non_empty(node: ast.AST) -> None:
        """Pad statement lists emptied by stripping so the output stays valid."""
        if isinstance(node, ast.Module):
            # An empty module is valid source, so there is nothing to pad.
            return
        body = getattr(node, "body", None)
        # Lambda/IfExp also have a ``body``, but theirs is a single expression.
        if isinstance(body, list) and not body:
            node.body = [ast.Pass()]
        # A ``try`` needs at least one ``except`` handler or a ``finally`` clause.
        if isinstance(node, ast.Try) and not node.handlers and not node.finalbody:
            node.finalbody = [ast.Pass()]

    @staticmethod
    def _is_dataclass_decorator(decorator: ast.expr) -> bool:
        """Return True for ``@dataclass``/``@dataclasses.dataclass``, bare or called."""
        target = decorator.func if isinstance(decorator, ast.Call) else decorator
        if isinstance(target, ast.Name):
            return target.id == "dataclass"
        return (
            isinstance(target, ast.Attribute)
            and target.attr == "dataclass"
            and isinstance(target.value, ast.Name)
            and target.value.id == "dataclasses"
        )

    def generic_visit(self, node: ast.AST) -> ast.AST:
        """Transform children as usual, then pad any body the visit left empty."""
        super().generic_visit(node)
        self._ensure_non_empty(node)
        return node

    # remove type annotations and docstrings from functions
    def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.AST | None:
        """Drop the return annotation and bare-constant statements of a function."""
        self.generic_visit(node)
        node.returns = None
        node.body = [
            statement
            for statement in node.body
            if not self._is_bare_constant(statement)
        ]
        # A body that held only a docstring or ``...`` is empty at this point.
        self._ensure_non_empty(node)
        return node

    # ``async def`` nodes have the same fields. Without this, their return
    # annotations would survive even though the typing imports are removed.
    visit_AsyncFunctionDef = visit_FunctionDef

    # remove type annotations from args (positional, *args, keyword-only, **kwargs)
    def visit_arg(self, node: ast.arg) -> ast.AST | None:
        """Drop the annotation of a single parameter."""
        self.generic_visit(node)
        node.annotation = None
        return node

    # remove docstrings from classes and rewrite dataclasses
    def visit_ClassDef(self, node: ast.ClassDef) -> ast.AST | None:
        """Strip a class; dataclasses get their fields turned into an ``__init__``.

        For ordinary classes, annotated attributes are left for ``visit_AnnAssign``,
        which keeps assignments (``x: int = 5`` -> ``x = 5``) and drops bare
        annotations.
        """
        node.body = [
            statement
            for statement in node.body
            if not self._is_bare_constant(statement)
        ]
        kept_decorators = [
            decorator
            for decorator in node.decorator_list
            if not self._is_dataclass_decorator(decorator)
        ]
        if len(kept_decorators) != len(node.decorator_list):
            # Only the dataclass decorator is removed; any others are kept.
            node.decorator_list = kept_decorators
            class_vars = [s for s in node.body if isinstance(s, ast.AnnAssign)]
            node.body = [s for s in node.body if not isinstance(s, ast.AnnAssign)]
            self.transform_dataclass(class_vars, node)
        # Also pads the class body with ``pass`` if nothing is left in it.
        self.generic_visit(node)
        return node

    # manually implement simplified dataclass decorator
    def transform_dataclass(
        self,
        class_vars: list[ast.AnnAssign],
        node: ast.ClassDef,
    ) -> None:
        """Insert a simplified stand-in for the dataclass-generated ``__init__``.

        Each annotated simple name in *class_vars* becomes a parameter that is
        stored on ``self``. Values assigned in the class body become parameter
        defaults when they form a trailing run, the only shape a positional
        signature can express. Otherwise those parameters stay required.
        NOTE(review): ``field(...)`` specifiers, ``kw_only``, ``ClassVar`` and
        ``InitVar`` are not interpreted.
        """
        fields = [
            (var.target.id, var.value)
            for var in class_vars
            if isinstance(var.target, ast.Name)
        ]
        defaults: list[ast.expr] = []
        for _, value in reversed(fields):
            if value is None:
                break
            defaults.insert(0, value)
        body: list[ast.stmt] = [
            ast.Assign(
                targets=[
                    ast.Attribute(
                        value=ast.Name(id="self", ctx=ast.Load()),
                        attr=name,
                        ctx=ast.Store(),
                    )
                ],
                value=ast.Name(id=name, ctx=ast.Load()),
            )
            for name, _ in fields
        ]
        init = ast.FunctionDef(
            name="__init__",
            args=ast.arguments(
                posonlyargs=[],
                args=[ast.arg(arg="self")] + [ast.arg(arg=name) for name, _ in fields],
                vararg=None,
                kwonlyargs=[],
                kw_defaults=[],
                kwarg=None,
                defaults=defaults,
            ),
            # A dataclass without fields still needs a syntactically valid body.
            body=body or [ast.Pass()],
            decorator_list=[],
            returns=None,
        )
        if "type_params" in ast.FunctionDef._fields:  # field added in Python 3.12
            init.type_params = []
        node.body.insert(0, init)

    # remove all type annotations from assignments
    def visit_AnnAssign(self, node: ast.AnnAssign) -> ast.AST | None:
        """Turn ``x: T = v`` into ``x = v``; drop a bare ``x: T``."""
        if node.value is None:
            return None
        return ast.copy_location(
            ast.Assign(targets=[node.target], value=node.value), node
        )

    # remove 'typing' and 'dataclasses' from import statements
    def visit_Import(self, node: ast.Import) -> ast.AST | None:
        """Drop ``typing``/``dataclasses`` aliases; drop the statement if none remain."""
        node.names = [n for n in node.names if n.name not in {"typing", "dataclasses"}]
        return node if node.names else None

    # remove all imports from 'typing', '__future__' and 'dataclasses'
    def visit_ImportFrom(self, node: ast.ImportFrom) -> ast.AST | None:
        """Drop ``from typing/__future__/dataclasses import ...`` statements."""
        return (
            node if node.module not in {"typing", "__future__", "dataclasses"} else None
        )

    # remove all assert statements
    def visit_Assert(self, node: ast.Assert) -> ast.AST | None:
        return None

    # remove PEP 695 ``type X = ...`` alias statements
    def visit_TypeAlias(self, node: ast.TypeAlias) -> ast.AST | None:
        return None

    # remove PEP 695 type parameters (``def f[T]``, ``class C[**P, *Ts]``)
    def visit_TypeVar(self, node: ast.TypeVar) -> ast.AST | None:
        return None

    visit_ParamSpec = visit_TypeVar
    visit_TypeVarTuple = visit_TypeVar
# Parse the source. Passing the path makes syntax errors point at the real file.
parsed_source = ast.parse(input_file, filename=input_file_path)
# Apply the transformer.
transformed = TypeHintRemover().visit(parsed_source)
# Nodes synthesized by the transformer carry no line/column information;
# fill it in so the tree could also be compiled.
ast.fix_missing_locations(transformed)
unparsed = ast.unparse(transformed)

# Without an output path, or when it names the input file, print instead of
# writing, so the input is never overwritten. The paths are compared resolved,
# so "./x.py" and "x.py" count as the same file.
output_file_path = sys.argv[2] if len(sys.argv) >= 3 else None
if output_file_path is None or (
    pathlib.Path(output_file_path).resolve()
    == pathlib.Path(input_file_path).resolve()
):
    print(unparsed)
else:
    # ast.unparse() omits the final newline; add it so the file ends like print() output.
    pathlib.Path(output_file_path).write_text(unparsed + "\n", encoding="utf-8")