-
Notifications
You must be signed in to change notification settings - Fork 0
/
preprocessor
executable file
·285 lines (219 loc) · 8.58 KB
/
preprocessor
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
#!/usr/bin/env python3
"""
$ python3 preprocessor [-d] --input_file <input_file>
--output_file <output_file>
--task <task to generate the export function>
E.g.
python3 preprocessor --input_file test_shake256_32_64.jpp
--output_file test_shake256_32_64.jpp
--task "fn:shake256 p:OUTLEN:32 p:INLEN:64"
"""
import os
import re
import sys
import pprint
import argparse
from argparse import Namespace
import utils
from generic_fn import GenericFn
from task import Task
from typed_generic_fn import TypedGenericFn
import json
def parse_args() -> Namespace:
    """Build the command-line interface and return the parsed arguments.

    Only ``--input_file`` is mandatory; every other option is optional.
    Arguments are read from ``sys.argv`` by ``argparse``.
    """
    cli = argparse.ArgumentParser(description="Jasmin source code preprocessor")

    # Input / output paths.
    cli.add_argument("-in", "--input_file", required=True, help="Input file path", type=str)
    cli.add_argument("-out", "--output_file", required=False, help="Output file path", type=str)

    # Plain boolean flag (takes no value).
    cli.add_argument("-d", "--debug", help="Enable debugging", action="store_true")

    # Boolean switches that print the program after an intermediate stage;
    # registered in a loop because they only differ by name and help text.
    stage_switches = (
        ("--after_macro", "Prints the program after replacing #expand macros"),
        (
            "--after_rm_generic",
            "Prints the program after removing the generic functions from the source code",
        ),
        ("--after_tasks", "Prints the program after resolving the tasks"),
        (
            "--after_generic_fn_calls",
            "Prints the program after resolving generic function calls",
        ),
    )
    for option, description in stage_switches:
        cli.add_argument(option, help=description, action="store_true")

    # One or more task tokens, collected into a list of strings.
    cli.add_argument("--task", required=False, nargs="+")

    cli.add_argument(
        "--workers",
        help="Number of worker threads for finding subtask concurrently",
        type=int,
        required=False,
    )

    # Boolean flag (takes no value).
    cli.add_argument("--print_tasks", help="Print tasks in JSON format", action="store_true")

    return cli.parse_args()
def resolve_templates(
    text: str,
    global_params: dict[str, int],
    generic_fn_dict: dict[str, GenericFn],
    typed_generic_fn_dict: dict[str, TypedGenericFn],
    args: Namespace,
) -> str:
    """
    Resolves templates in the source code.

    The code assumes that all template declarations follow the format:
    '[inline/#[returnaddress="stack"]]
    fn <function_name><template_params>(<parameters>) -> <return_type> { <function_body> } //<>'

    Args:
        text: Source text to process.
        global_params: Known `param int` values; updated in place with the
            ones found in `text`.
        generic_fn_dict: Known generic functions; updated in place with the
            ones found in `text`.
        typed_generic_fn_dict: Known typed generic functions; updated in
            place with the ones found in `text`.
        args: Parsed CLI options. The attributes read here are `debug`,
            `workers`, `print_tasks`, `after_rm_generic`, `after_tasks` and
            `after_generic_fn_calls`.

    Returns:
        The processed text, stripped of leading/trailing whitespace.

    Note:
        When `after_rm_generic`, `print_tasks`, `after_tasks` or
        `after_generic_fn_calls` is set, the function prints the
        intermediate result and terminates the process with `sys.exit(0)`
        instead of returning.
    """
    # pylint: disable=redefined-outer-name

    # Tasks must be collected from the text as it was received, not from the
    # text after generic functions were removed / calls were rewritten, so
    # keep a reference to the unmodified input.
    original_text: str = text

    # 1st step: Update global params dict. The source file may define new
    # `param int` variables or override previously defined ones
    global_params.update(utils.get_params(text))

    # 2nd step: Update generic functions dict from Jasmin source code
    generic_fn_dict.update(utils.get_generic_fn_dict(text))

    # 2.5 step: Update typed generic functions dict from Jasmin source code
    typed_generic_fn_dict.update(utils.get_typed_generic_fn_dict(text))

    # 3rd step: Remove the code of the generic functions from the source code
    text = utils.remove_typed_generic_fn_text(text)
    text = utils.remove_generic_fn_text(text)

    if args.after_rm_generic:
        print(text)
        sys.exit(0)

    # 4th step: Replace calls to generic functions with calls to concrete
    # functions
    text = utils.replace_typed_generic_calls_with_concrete(text, global_params)
    text = utils.replace_generic_calls_with_concrete(text, global_params)

    # 5th step: Get a list of all the tasks that need to be done
    # First, we gather the tasks without handling nested calls, and then, after all
    # the tasks are collected, resolve the nested calls in the Task class.
    # NOTE: `export_tasks` is not a parameter; it is read from module scope
    # (it is assigned by the `__main__` block of this script).
    tasks: list[Task] = (
        export_tasks
        + utils.get_tasks(original_text, global_params)
        + utils.get_typed_fn_tasks(original_text, global_params)
    )

    if args.debug:
        print("Tasks [1st pass]:")
        pprint.pprint(tasks)
        print("\n", end="")

    if args.debug and args.workers:
        print(f"# Workers: {args.workers}")

    # Subtasks are always searched sequentially: `args.workers` is only
    # reported in debug mode above and does not select a different strategy.
    tasks = utils.find_sub_tasks_sequentially(tasks, generic_fn_dict, typed_generic_fn_dict)

    if args.debug:
        print("Tasks [2nd pass (includes subtasks)]:")
        pprint.pprint(tasks)
        print("\n", end="")

    # Remove duplicate tasks
    tasks = utils.remove_duplicates(tasks)

    if args.debug:
        print("Tasks [3rd pass (removes duplicate tasks)]:")
        pprint.pprint(tasks)
        print("\n", end="")

    if args.print_tasks:
        tasks_json = utils.taks_to_json(generic_fn_dict, tasks)
        json_string: str = json.dumps(tasks_json)
        print(json_string)
        sys.exit(0)

    # See if there is any invalid task
    utils.validate_tasks(tasks)

    # 6th step: Resolve each task
    for task in tasks:
        text = task.resolve(text, generic_fn_dict, typed_generic_fn_dict)

    if args.after_tasks:
        print(text)
        sys.exit(0)

    # 7th step: Resolve generic function calls that were not resolved before
    text = utils.resolve_generic_fn_calls(text, global_params)

    if args.after_generic_fn_calls:
        print(text)
        sys.exit(0)

    # Last step: Remove auxiliary comments & extra blank space
    text = re.sub(r"// Place concrete instances of the (\w+) function here", "", text)

    # `\n{3,}` is greedy, so a single pass turns every run of three or more
    # newlines into exactly two; no run can remain afterwards.
    text = re.sub(r"\n{3,}", "\n\n", text)

    return text.strip()
if __name__ == "__main__":
    # Script entry point: parse the CLI, read the input file, resolve the
    # templates and write the result to the output file (or stdout).
    #
    # NOTE: the names assigned below are module-level globals.
    # `resolve_templates` reads `export_tasks` (and, historically,
    # `input_text`) from module scope, so these names must not be renamed or
    # moved into a function without updating it.
    try:
        args: Namespace = parse_args()

        if args.debug:
            print(f"Args.task: {args.task}")

        # Check if the input file exists
        if not os.path.exists(args.input_file):
            sys.stderr.write(f"Error: The file '{args.input_file}' does not exist.\n")
            sys.exit(-1)

        # Read the input exactly once; the same contents feed both the
        # parameter/generic-function discovery and the macro expansion.
        with open(args.input_file, "r", encoding="utf-8") as f:
            text: str = f.read()

        # 1st step: Load global parameters from the input text
        global_params: dict[str, int] = utils.get_params(text)

        if args.debug:
            print("DEBUG: Global Param Dictionary")
            pprint.pprint(global_params)
            print("\n")

        if args.debug:
            # NOTE(review): `text` is still the raw file contents here;
            # #expand macros are only replaced further below.
            print("Text after macro replacement")
            print(text, end="\n\n")

        # 2nd step: Get generic functions dict from Jasmin source code
        generic_fn_dict: dict[str, GenericFn] = utils.get_generic_fn_dict(text)

        # 2.5 step: Get typed generic functions dict from Jasmin source code
        typed_generic_fn_dict: dict[str, TypedGenericFn] = utils.get_typed_generic_fn_dict(text)

        # Tasks requested on the command line through --task
        export_tasks: list[Task] = utils.parse_tasks(args.task, global_params)

        if args.debug:
            print("DEBUG: Export tasks: ")
            pprint.pprint(export_tasks)
            print("\n")

        if args.debug:
            print("DEBUG: Generic Fn Dictionary:")
            pprint.pprint(sorted(generic_fn_dict.keys()))
            print("\n")
            print("DEBUG: Typed Generic Fn Dictionary:")
            pprint.pprint(sorted(typed_generic_fn_dict.keys()))
            print("\n")

        # Expand #expand macros on the contents that were already read
        input_text: str = utils.replace_expand_macros(text)

        if args.after_macro:
            print(input_text)
            sys.exit(0)

        output_text: str = resolve_templates(
            input_text, global_params, generic_fn_dict, typed_generic_fn_dict, args
        )
        output_text += "\n"
        output_text = utils.replace_eval_global_params(output_text, global_params)

        if args.debug:
            print("\n--------------------------------------------------------------\n")

        if args.output_file:
            with open(args.output_file, "w", encoding="utf-8") as f:
                f.write(output_text)
        else:
            print(output_text)
    except KeyboardInterrupt:
        sys.stderr.write("Program interrupted (Ctrl+C)\n")
        # Do not report success after being aborted: 130 is 128 + SIGINT (2),
        # the status shells conventionally report for an interrupted process.
        sys.exit(130)