-
Notifications
You must be signed in to change notification settings - Fork 0
/
task.py
306 lines (258 loc) · 12.2 KB
/
task.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
import re
import sys
import utils
from generic_fn import GenericFn
from typed_generic_fn import TypedGenericFn
from typing import cast, Optional
import pprint
class Task:
"""
Represents a task that needs to be resolved by inserting the concrete definition
for the function in the Jasmin source code.
Attributes:
fn_name (str): The name of the function to be resolved in the Jasmin code.
template_params (list[str]): A list of the value of the parameters to be used in the function.
global_params (dict[str, int]): The dictionary of the global parameters
"""
def __init__(
self,
fn_name: str,
template_params: list[str | int],
global_params: dict[str, int],
# unless specified otherwise, a task is considered to be a simple
# generic function (i.e. not typed)
typed_fn_names: list[str] | None = None,
# TODO: from typing import Optional
typed_fn_types: list[str] | None = None,
):
self.fn_name = fn_name
self.template_params = template_params
self.global_params = global_params
if isinstance(typed_fn_names, str):
self.typed_fn_names = utils.string_to_singleton_list(typed_fn_names)
else:
self.typed_fn_names = typed_fn_names
if isinstance(typed_fn_types, str):
self.typed_fn_types = utils.string_to_singleton_list(typed_fn_types)
else:
self.typed_fn_types = typed_fn_types
self.is_typed_task: bool = typed_fn_names is not None and typed_fn_types is not None
def __eq__(self, other: object):
if not isinstance(other, Task):
return False
template_params_int_self: list[int] = [int(val) for val in self.template_params]
template_params_int_other: list[int] = [int(val) for val in other.template_params]
if self.is_typed_task and other.is_typed_task:
return (
self.fn_name == other.fn_name
and template_params_int_self == template_params_int_other
and self.typed_fn_names == other.typed_fn_names
and self.typed_fn_types == other.typed_fn_types
)
elif (not self.is_typed_task) and (not other.is_typed_task):
return self.fn_name == other.fn_name and template_params_int_self == template_params_int_other
else:
return False
def __hash__(self):
# Convert the list to a tuple for hashing (TypeError: unhashable type:
# 'list')
return hash((self.fn_name, tuple(self.template_params)))
def __repr__(self) -> str:
params_str = ", ".join(str(param) for param in self.template_params)
return f"Task(fn_name='{self.fn_name}', params=[{params_str}], typed_fn_names={self.typed_fn_names}, typed_fn_types={self.typed_fn_types}, typed_task={self.is_typed_task})"
def is_valid(self) -> bool:
return self.fn_name != ""
def resolve(
self,
text: str,
generic_fn_dict: dict[str, GenericFn],
typed_generic_fn_dict: dict[str, TypedGenericFn],
) -> str:
"""
Resolve the function and return the concrete definition
"""
pattern = rf"// Place concrete instances of the {self.fn_name} function here"
generic_fn: GenericFn | TypedGenericFn = None
try:
generic_fn = generic_fn_dict[self.fn_name]
except KeyError:
try:
generic_fn = typed_generic_fn_dict[self.fn_name]
except KeyError:
sys.stderr.write(f"Could not find {self.fn_name} in generic_fn_dict in Task.resolve [1]\n")
sys.exit(-1)
replacement_dict: dict[str, int] = dict(zip(generic_fn.params, self.template_params))
if self.is_typed_task:
# suppress pytype warning
generic_fn = cast(TypedGenericFn, generic_fn)
tmp = dict(zip(generic_fn.generic_fn_names, self.typed_fn_names))
tm2 = dict(zip(generic_fn.generic_fn_types, self.typed_fn_types))
replacement_dict = {**replacement_dict, **tmp, **tm2}
# print(f"DEBUG: Printing replacement DICT of task {self.fn_name}")
# pprint.pprint(replacement_dict)
return re.sub(
pattern,
lambda match: match.group()
+ "\n"
+ utils.build_concrete_fn(generic_fn, replacement_dict, self.is_typed_task)
+ "\n",
text,
)
# pytype: disable=name-error
def get_sub_tasks(
self,
generic_fn_dict: dict[str, GenericFn],
typed_generic_fn_dict: dict[str, TypedGenericFn],
context_params: Optional[dict[str, int]] = None,
) -> list["Task"]:
# pytype: enable=name-error
"""
Get the sub-tasks for the current task by resolving nested generic function calls.
"""
subtasks: list[Task] = []
# print(f'\n\nDebug in Get Sub Task for task {self.fn_name}')
# print("Context params at the beginning")
# print(context_params)
# print("Generic_Fn_Dict")
# pprint.pprint(generic_fn_dict.keys())
# print("Typed_Generic_Fn_Dict")
# pprint.pprint(typed_generic_fn_dict.keys())
generic_fn: GenericFn | TypedGenericFn = None
function_name = self.fn_name
if context_params is not None:
try:
function_name = eval(self.fn_name, {}, context_params)
print(f"Name: {function_name}")
except NameError:
pass
try:
generic_fn = generic_fn_dict[function_name]
except KeyError:
try:
generic_fn = typed_generic_fn_dict[function_name]
except KeyError:
sys.stderr.write(
f"Could not find {self.fn_name} in generic_fn_dict/typed_generic_fn_dict in Task.get_sub_tasks [2]\n"
)
sys.exit(-1)
resolved_fn_body: str = self.resolve(generic_fn.fn_body, generic_fn_dict, typed_generic_fn_dict)
# The 1st function call does not have
if context_params is None:
context_params = dict(zip(generic_fn.params, self.template_params))
if self.is_typed_task:
t = dict(zip(generic_fn.generic_fn_types, self.typed_fn_types))
u = dict(zip(generic_fn.generic_fn_names, self.typed_fn_names))
context_params.update(t)
context_params.update(u)
else:
tmp = dict(zip(generic_fn.params, self.template_params))
context_params.update(tmp)
# 'Simple' generic fn
generic_fn_call_pattern = r"(\w+)<([^>]+)>\(([^)]+)\);"
for match in re.finditer(generic_fn_call_pattern, resolved_fn_body):
fn_name, generic_params, _ = match.groups()
if fn_name == self.fn_name:
sys.stderr.write(f"Recursive functions not supported: {self.fn_name}\n")
sys.exit(-1)
generic_params = [p.strip() for p in generic_params.split(",")]
for param in generic_params:
try:
# Context params may override global params
context_params[param] = int(context_params[param])
except KeyError:
try:
context_params[param] = context_params.get(param, self.global_params[param])
except KeyError:
template_params_int: list[int] = [int(val) for val in self.template_params]
template_dict: dict[str, int] = dict(
zip(
generic_fn_dict[self.fn_name].params,
template_params_int,
)
) # TODO: handle key error
context_params[param] = eval(
param.replace("/", "//"),
{},
{
**self.global_params,
**context_params,
**template_dict,
},
)
except Exception:
sys.stderr.write("Could not evaluate parameter: {param}")
concrete_args: list[int] = [int(context_params.get(p, p)) for p in generic_params]
subtask = Task(fn_name, concrete_args, self.global_params)
subtasks.append(subtask)
# 'Typed' generic fn
typed_generic_fn_call_pattern = r"(\w+)<([^>]+)>\s*\[([^\]]+)]\(([^)]+)\);"
for match in re.finditer(typed_generic_fn_call_pattern, resolved_fn_body):
fn_name, generic_params, type_info, _ = match.groups()
fn_names = type_info.split(";")[0]
if isinstance(fn_names, str):
fn_names: list[str] = utils.string_to_singleton_list(fn_names)
fn_names = [v.strip() for v in fn_names]
fn_types = type_info.split(";")[-1]
if isinstance(fn_types, str):
fn_types: list[str] = utils.string_to_singleton_list(fn_types)
fn_types = [v.strip() for v in fn_types]
if fn_name == self.fn_name:
sys.stderr.write(f"Recursive functions not supported: {self.fn_name}\n")
sys.exit(-1)
generic_params: list[str] = [p.strip() for p in generic_params.split(",")]
for param in generic_params:
try:
# Context params may override global params
context_params[param] = int(context_params[param])
except KeyError:
try:
context_params[param] = context_params.get(param, self.global_params[param])
except KeyError:
template_params_int: list[int] = [int(val) for val in self.template_params]
template_dict: dict[str, int] = dict(
zip(
generic_fn_dict[self.fn_name].params,
template_params_int,
)
) # TODO: handle key error
context_params[param] = eval(
param.replace("/", "//"),
{},
{
**self.global_params,
**context_params,
**template_dict,
},
)
except Exception:
sys.stderr.write("Could not evaluate parameter: {param}")
concrete_args = [int(context_params.get(p, p)) for p in generic_params]
fn_names: list[str] = [name for name in fn_names if name != ""]
print(f"Debug: evaluating fn_names: {fn_names}")
fn_types: list[str] = [type for type in fn_types if type != ""]
print(f"Debug: evaluating fn_types: {fn_types}")
fn_names: list[str] = utils.eval_list(fn_names, context_params)
fn_types: list[str] = utils.eval_list(fn_types, context_params)
subtask = Task(fn_name, concrete_args, self.global_params, fn_names, fn_types)
subtasks.append(subtask)
# Recursive step: Find and collect sub-tasks from the resolved function
# body
replacement_dict: dict[str, int] = dict(zip(generic_fn.params, self.template_params))
if self.is_typed_task:
# suppress pytype warning
generic_fn = cast(TypedGenericFn, generic_fn)
tmp = dict(zip(generic_fn.generic_fn_names, self.typed_fn_names))
tm2 = dict(zip(generic_fn.generic_fn_types, self.typed_fn_types))
replacement_dict = {**replacement_dict, **tmp, **tm2}
for subtask in subtasks:
sub_subtasks = subtask.get_sub_tasks(
generic_fn_dict,
typed_generic_fn_dict,
{
**self.global_params,
**context_params,
**replacement_dict,
}, # Pass resolved_params_local in the recursion
)
subtasks.extend(sub_subtasks)
return subtasks