-
Notifications
You must be signed in to change notification settings - Fork 1
/
main.py
126 lines (99 loc) · 3.91 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
def get_words_from_input_lines(input_lines):
    """Split C++ source lines into whitespace-separated tokens.

    The characters ``( ) { } ,`` always become tokens of their own, even when
    written without surrounding spaces. Every other character (``;``, ``<``,
    ``*``, ``&``, ...) stays attached to its neighbours.

    Args:
        input_lines: Iterable of source lines, with or without trailing
            newlines (e.g. the result of ``readlines()``).

    Returns:
        List of token strings in source order.
    """
    # Join with a space (not '') so that lines lacking a trailing newline do
    # not fuse their last token with the next line's first token. Newlines
    # themselves are whitespace to str.split(), so they need no replacing.
    source_text = ' '.join(input_lines)
    padded = source_text.translate(
        str.maketrans({ch: f' {ch} ' for ch in '(){},'}))
    return padded.split()
def get_class_name(words):
    """Return the token that immediately follows the first ``class`` token.

    Args:
        words: Token list of the C++ source.

    Returns:
        The class name token.

    Raises:
        RuntimeError: If no ``class`` token is present, or if it is the last
            token (so no name follows it).
    """
    if 'class' not in words:
        raise RuntimeError('class is not detected in input.cpp')
    name_pos = words.index('class') + 1
    if name_pos >= len(words):
        raise RuntimeError('class name is not existed')
    return words[name_pos]
def get_func_name(words):
    """Return the names of functions declared at brace depth 1.

    The running ``{``/``}`` nesting level is tracked token by token. Every
    ``(`` seen while exactly one brace is open contributes the token just
    before it. For a top-level class, depth 1 is the class body, so calls
    made inside method bodies (depth >= 2) are not collected.

    Args:
        words: Token list of the C++ source.

    Returns:
        The collected names in source order, duplicates kept.
    """
    brace_step = {'{': 1, '}': -1}
    level = 0
    found = []
    for position, token in enumerate(words):
        level += brace_step.get(token, 0)
        if level != 1 or token != '(':
            continue
        found.append(words[position - 1])
    return found
def _param_type(tokens):
    """Return the C++ type of one parameter from its tokens ('' if none)."""
    # Drop a default argument such as "= 0".
    if '=' in tokens:
        tokens = tokens[:tokens.index('=')]
    # The harness declares its own local variable, so qualifiers are dropped.
    tokens = [tok for tok in tokens if tok != 'const']
    if not tokens:
        return ''
    if len(tokens) > 1:
        # The last token is the parameter name. Move any leading '*'/'&'
        # written against the name ("TreeNode *root") onto the type.
        name = tokens[-1]
        sigils = name[:len(name) - len(name.lstrip('*&'))]
        tokens = tokens[:-1] + ([sigils] if sigils else [])
    type_str = ' '.join(tokens)
    for sym in ('*', '&'):
        type_str = type_str.replace(' ' + sym, sym)
    # Re-join template arguments split by the tokenizer: "map<int , int>".
    type_str = type_str.replace(' , ', ', ')
    return type_str.strip('&')


def get_para_type(words, func_name):
    """Return the parameter types of ``func_name``'s declaration.

    The declaration is located as the first ``func_name`` token directly
    followed by ``(``. Parameters are then split on top-level commas; commas
    inside ``<...>`` belong to the type. Reference markers (``&``) are
    stripped from each type, while pointer markers (``*``) are kept, e.g.
    ``vector<int>& nums`` -> ``vector<int>`` and ``TreeNode *root`` ->
    ``TreeNode*``.

    Args:
        words: Token list of the C++ source.
        func_name: Name of the function whose parameters are wanted.

    Returns:
        List of type strings, one per parameter (empty for ``f()``).

    Raises:
        ValueError: If ``func_name`` is never followed by ``(`` or the
            parameter list is not closed by ``)``.
    """
    open_idx = next(
        (pos + 1
         for pos, (word, following) in enumerate(zip(words, words[1:]))
         if word == func_name and following == '('),
        None,
    )
    if open_idx is None:
        raise ValueError(f'{func_name!r} is not followed by a parameter list')
    close_idx = words.index(')', open_idx + 1)

    params = [[]]
    angle_depth = 0
    for tok in words[open_idx + 1:close_idx]:
        if tok == ',' and angle_depth == 0:
            params.append([])
            continue
        angle_depth += tok.count('<') - tok.count('>')
        params[-1].append(tok)
    return [t for t in (_param_type(p) for p in params) if t]
def call_func(func_name, para_type, data):
    """Generate C++ lines that build the arguments and print one call.

    For each (type, raw value) pair, a local ``var<N>`` is declared. N comes
    from the module-level counter ``var_id``, which is incremented here so
    names stay unique across calls. LeetCode-style ``[...]`` literals become
    brace initializers, and scalars are brace-initialized too (``int var1{3};``).
    ``ListNode*`` / ``TreeNode*`` parameters are declared as ``List`` /
    ``Tree`` objects (presumably from template/template.h) and passed as
    ``.head`` / ``.root``. Tree elements are quoted individually so that
    entries such as ``null`` are passed as strings.

    Args:
        func_name: Member function to call on the module-level
            ``solution_name`` pointer.
        para_type: Parameter type strings, one per argument.
        data: Raw argument values, one per parameter (paired via ``zip``).

    Returns:
        Tab-indented C++ lines ending with
        ``print(<solution_name>-><func_name>(<args>));``.

    Raises:
        RuntimeError: If a value is empty or whitespace-only.
    """
    global var_id
    code_part = ''
    args = []
    for t, v in zip(para_type, data):
        raw = v.strip()
        if not raw:
            raise RuntimeError(f'empty value for parameter of type {t}')
        var_id += 1
        val = raw.replace('[', '{').replace(']', '}')
        arg = 'var' + str(var_id)
        if not val.startswith('{'):
            val = '{' + val + '}'
        if t == 'ListNode*':
            t = 'List'
            arg += '.head'
        elif t == 'TreeNode*':
            t = 'Tree'
            arg += '.root'
            # convert to string: quote each element, trimming spaces so that
            # "[1, null, 2]" yields "null" rather than " null".
            items = [item.strip() for item in val[1:-1].split(',')]
            if any(items):
                val = '{' + ','.join(f'"{item}"' for item in items) + '}'
        code_part += f'\t{t} var{var_id}{val};\n'
        args.append(arg)
    arg_list = ', '.join(args)
    code_part += f'\tprint({solution_name}->{func_name}({arg_list}));\n'
    return code_part
if __name__ == '__main__':
    # Build code/output.cpp: the source of code/input.cpp, followed by a
    # main() that calls the detected entry point once per group of argument
    # lines in code/input.txt and prints each run's time.
    solution_name = 'solution'
    var_id = 0  # module-level counter advanced by call_func via `global`
    with open('code/input.cpp', 'r', encoding='utf-8') as input_cpp_file:
        input_lines = input_cpp_file.readlines()
    words = get_words_from_input_lines(input_lines)
    with open('code/input.txt', 'r', encoding='utf-8') as input_data_file:
        # One argument value per line. Blank lines (e.g. a trailing empty
        # line) are skipped so they cannot break the parameter/data count.
        data = [line.strip() for line in input_data_file if line.strip()]
    print(data)
    class_name = get_class_name(words)
    func_name_list = get_func_name(words)
    code = '#include "template/template.h"\n'
    code += ''.join(input_lines)
    code += '\n'
    # main() is generated only for a class named Solution; for any other
    # class the output is just the include plus the copied source.
    if class_name == "Solution":
        code += f'Solution* {solution_name};\n'
        code += 'Timer timer;\n'
        code += 'int main() {\n'
        if not func_name_list:
            raise RuntimeError('no member function is detected in input.cpp')
        # Heuristic: the longest member-function name is the entry point.
        main_func_name = max(func_name_list, key=len)
        para_type_list = get_para_type(words, main_func_name)
        print(main_func_name, para_type_list)
        para_count = len(para_type_list)
        if para_count:
            if len(data) % para_count:
                raise RuntimeError('mismatch between parameters and data')
            case_count = len(data) // para_count
        else:
            # A parameterless entry point is called once and consumes no data.
            if data:
                raise RuntimeError('mismatch between parameters and data')
            case_count = 1
        case_list = []
        for i in range(case_count):
            case_data = data[i * para_count:(i + 1) * para_count]
            case = '\ttimer.start();\n'
            case += f'\t{solution_name} = new Solution();\n'
            case += call_func(main_func_name, para_type_list, case_data)
            case += f'\tdelete {solution_name};\n'
            case += '\tprint("Runtime: " + to_string(timer.end()) + " ms");\n'
            case_list.append(case)
        code += '\n'.join(case_list)
        code += '}'
    print(class_name)
    print(func_name_list)
    print(code)
    # Write as utf-8, matching how input.cpp is read: the copied source may
    # contain non-ASCII text that the platform default encoding cannot encode.
    with open('code/output.cpp', 'w', encoding='utf-8') as output_cpp_file:
        output_cpp_file.write(code)