From 563cce0fe2dc12933f6cf351f7f624fa27dd70c3 Mon Sep 17 00:00:00 2001 From: "andrea.zanelli" Date: Tue, 19 May 2020 10:25:29 +0200 Subject: [PATCH 1/3] update benchmark, add syrk --- benchmarks/run_benchmark.py | 9 +++++++++ benchmarks/run_benchmark_numpy.py | 2 +- examples/riccati_example/riccati.py | 3 ++- examples/riccati_example/riccati_mass_spring.py | 1 - examples/riccati_example/riccati_numpy.py | 1 + prometeo/cgen/code_gen_c.py | 2 ++ prometeo/cpmt/pmat_blasfeo_wrapper.c | 16 ++++++++++++++++ prometeo/cpmt/pmat_blasfeo_wrapper.h | 1 + prometeo/mem/ast_analyzer.py | 6 ++++-- 9 files changed, 36 insertions(+), 5 deletions(-) diff --git a/benchmarks/run_benchmark.py b/benchmarks/run_benchmark.py index 24d88ef..69a2b0d 100644 --- a/benchmarks/run_benchmark.py +++ b/benchmarks/run_benchmark.py @@ -22,6 +22,8 @@ LOAD_BLASFEO_RES = True numpy_res_file = 'riccati_benchmark_numpy.json' LOAD_NUMPY_RES = True +numpy_blasfeo_res_file = 'riccati_benchmark_numpy_blasfeo.json' +LOAD_NUMPY_BLASFEO_RES = True julia_res_file = 'riccati_benchmark_julia.json' LOAD_JULIA_RES = True @@ -89,6 +91,13 @@ plt.semilogy(2*AVG_CPU_TIME_BLASFEO[:,1], AVG_CPU_TIME_BLASFEO[:,0]) legend.append('NumPy') +if LOAD_NUMPY_BLASFEO_RES: + with open(numpy_blasfeo_res_file) as res: + AVG_CPU_TIME_BLASFEO = json.load(res) + AVG_CPU_TIME_BLASFEO = np.array(AVG_CPU_TIME_BLASFEO) + plt.semilogy(2*AVG_CPU_TIME_BLASFEO[:,1], AVG_CPU_TIME_BLASFEO[:,0]) + legend.append('NumPy + BLASFEO') + if LOAD_JULIA_RES: with open(julia_res_file) as res: AVG_CPU_TIME_BLASFEO = json.load(res) diff --git a/benchmarks/run_benchmark_numpy.py b/benchmarks/run_benchmark_numpy.py index 645e882..94e4209 100644 --- a/benchmarks/run_benchmark_numpy.py +++ b/benchmarks/run_benchmark_numpy.py @@ -10,7 +10,7 @@ NREP_medium = 100 NREP_large = 10 AVG_CPU_TIME = [] -res_file = 'riccati_benchmark_numpy.json' +res_file = 'riccati_benchmark_numpy_blasfeo.json' RUN = True UPDATE_res = True diff --git a/examples/riccati_example/riccati.py b/examples/riccati_example/riccati.py index 212dfde..0e369f1 100644 --- a/examples/riccati_example/riccati.py +++ b/examples/riccati_example/riccati.py @@ -39,12 +39,13 @@ def factorize(self) -> None: pmt_gemm_nn(BAtP, BA, M, M) pmat_fill(L, 0.0) pmt_potrf(M, L) + pmat_print(L) Mxx[0:nx, 0:nx] = L[nu:nu+nx, nu:nu+nx] pmat_fill(self.P[N-i-1], 0.0) pmt_gemm_nt(Mxx, Mxx, self.P[N-i-1], self.P[N-i-1]) - pmat_print(self.P[N-i-1]) + # pmat_print(self.P[N-i-1]) return diff --git a/examples/riccati_example/riccati_mass_spring.py b/examples/riccati_example/riccati_mass_spring.py index a62a350..3a1c24e 100644 --- a/examples/riccati_example/riccati_mass_spring.py +++ b/examples/riccati_example/riccati_mass_spring.py @@ -89,6 +89,5 @@ def main() -> int: pmt_syrk_ln(w_nxu_nx, w_nxu_nx, RSQ, M) pmt_potrf(M, M) Lxx[0:nx,0:nx] = M[nu:nu+nx,nu:nu+nx] - # pmat_print(M) return 0 diff --git a/examples/riccati_example/riccati_numpy.py b/examples/riccati_example/riccati_numpy.py index 86eb566..6e674b1 100644 --- a/examples/riccati_example/riccati_numpy.py +++ b/examples/riccati_example/riccati_numpy.py @@ -23,6 +23,7 @@ M[nu:nu+nx, nu:nu+nx] = Q M = M + dot(BAtP, BA) L = linalg.cholesky(M) + print('L:\n', L) Mxx = L[nu:nu+nx, nu:nu+nx] P = dot(transpose(Mxx), Mxx) print('P:\n', P) diff --git a/prometeo/cgen/code_gen_c.py b/prometeo/cgen/code_gen_c.py index e425905..a56a8e8 100644 --- a/prometeo/cgen/code_gen_c.py +++ b/prometeo/cgen/code_gen_c.py @@ -39,6 +39,7 @@ 'pmt_gemm_tn': 'c_pmt_gemm_tn', \ 'pmt_gemm_nt': 'c_pmt_gemm_nt', \ 'pmt_trmm_rlnn': 'c_pmt_trmm_rlnn', \ + 'pmt_syrk_ln': 'c_pmt_syrk_ln', \ 'pmt_gead': 'c_pmt_gead', \ 'pmt_getrf': 'c_pmt_getrf', \ 'pmt_getrsm': 'c_pmt_getrsm', \ @@ -75,6 +76,7 @@ 'pmt_gemm_nn': ['pmat', 'pmat', 'pmat', 'pmat'], \ 'pmt_gemm_tn': ['pmat', 'pmat', 'pmat', 'pmat'], \ 'pmt_trmm_rlnn': ['pmat', 'pmat', 'pmat'], \ + 'pmt_syrk_ln': ['pmat', 'pmat', 'pmat', 'pmat'], \ 'pmt_gead': ['float', 'pmat', 'pmat'], \ 'pmt_getrf': ['pmat', 'pmat', 'List'], \ 'pmt_getrsm': ['pmat', 'List', 'pmat'], \ diff --git a/prometeo/cpmt/pmat_blasfeo_wrapper.c b/prometeo/cpmt/pmat_blasfeo_wrapper.c index 1ce9a64..0582966 100644 --- a/prometeo/cpmt/pmat_blasfeo_wrapper.c +++ b/prometeo/cpmt/pmat_blasfeo_wrapper.c @@ -147,6 +147,22 @@ void c_pmt_trmm_rlnn(struct pmat *A, struct pmat *B, struct pmat *D) { blasfeo_dtrmm_rlnn(mB, nB, 1.0, bA, 0, 0, bB, 0, 0, bD, 0, 0); } +void c_pmt_syrk_ln(struct pmat *A, struct pmat *B, struct pmat *C, struct pmat *D) { + int nA = A->bmat->n; + int mA = A->bmat->m; + struct blasfeo_dmat *bA = A->bmat; + struct blasfeo_dmat *bB = B->bmat; + struct blasfeo_dmat *bD = D->bmat; + + // printf("In dgemm\n"); + // blasfeo_print_dmat(mA, nA, A->bmat, 0, 0); + // blasfeo_print_dmat(mA, nA, B->bmat, 0, 0); + // blasfeo_print_dmat(mA, nA, C->bmat, 0, 0); + // blasfeo_print_dmat(mA, nA, D->bmat, 0, 0); + + blasfeo_dsyrk_ln(mA, nA, 1.0, A->bmat, 0, 0, B->bmat, 0, 0, 1.0, C->bmat, 0, 0, D->bmat, 0, 0); +} + void c_pmt_getrf(struct pmat *A, struct pmat *fact, int *ipiv) { int mA = A->bmat->m; struct blasfeo_dmat *bA = A->bmat; diff --git a/prometeo/cpmt/pmat_blasfeo_wrapper.h b/prometeo/cpmt/pmat_blasfeo_wrapper.h index 6adf670..6c814d4 100644 --- a/prometeo/cpmt/pmat_blasfeo_wrapper.h +++ b/prometeo/cpmt/pmat_blasfeo_wrapper.h @@ -29,6 +29,7 @@ struct pmat * _c_pmt_gemm_nn(struct pmat *A, struct pmat *B); void c_pmt_gemm_tn(struct pmat *A, struct pmat *B, struct pmat *C, struct pmat *D); void c_pmt_gemm_nt(struct pmat *A, struct pmat *B, struct pmat *C, struct pmat *D); void c_pmt_trmm_rlnn(struct pmat *A, struct pmat *B, struct pmat *D); +void c_pmt_syrk_ln(struct pmat *A, struct pmat *B, struct pmat *C, struct pmat *D); void c_pmt_getrf(struct pmat *A, struct pmat *fact, int *ipiv); void c_pmt_potrf(struct pmat *A, struct pmat *fact); void c_pmt_getrsm(struct pmat *fact, int *ipiv, struct pmat *rhs); diff --git a/prometeo/mem/ast_analyzer.py b/prometeo/mem/ast_analyzer.py index 0dabd70..c578b0b 100644 --- a/prometeo/mem/ast_analyzer.py +++ b/prometeo/mem/ast_analyzer.py @@ -22,6 +22,7 @@ 'global@pmt_gemm_tn': [], \ 'global@pmt_gemm_nt': [], \ 'global@pmt_trmm_rlnn': [], \ + 'global@pmt_syrk_ln': [], \ 'global@pmt_gead': [], \ 'global@pmt_potrf': [], \ 'global@pmt_potrsm': [], \ @@ -29,6 +30,7 @@ 'global@pmt_getrsm': [], \ 'global@print': [], \ 'global@pparse': [], \ + 'global@pparse': [], \ } def precedence_setter(AST=ast.AST, get_op_precedence=get_op_precedence, @@ -119,8 +121,8 @@ def visit_Module(self, node): return def visit_FunctionDef(self, node): - if node.name != '__init__': - self.caller_scope = self.caller_scope + '@' + node.name + # if node.name != '__init__': + self.caller_scope = self.caller_scope + '@' + node.name self.callees[self.caller_scope] = set([]) # self.visit_ast(node) self.body(node.body) From 295d3d6c94af3ebaecda5d49ce9c40e10cc5a20d Mon Sep 17 00:00:00 2001 From: "andrea.zanelli" Date: Tue, 19 May 2020 15:41:47 +0200 Subject: [PATCH 2/3] draft implementation of instance attributes --- prometeo/cgen/code_gen_c.py | 237 ++++++++++++++++++++++++------------ 1 file changed, 162 insertions(+), 75 deletions(-) diff --git a/prometeo/cgen/code_gen_c.py b/prometeo/cgen/code_gen_c.py index a56a8e8..ccaa091 100644 --- a/prometeo/cgen/code_gen_c.py +++ b/prometeo/cgen/code_gen_c.py @@ -266,6 +266,26 @@ def check_expression(node, binops, unops, usr_types, ast_types, record): else: raise cgenException('could not resolve expression {}\n'.format(astu.unparse(node)), node.lineno) +def get_pmt_type_value(node, record): + """ + Return prometeo-type of node.value + """ + # simple value + if hasattr(node, 'value'): + if isinstance(node.value, ast.Name): + var_name = Num_or_Name(node) + if var_name in record: + return record[var_name] + + # try to infer basic types + if isinstance(node.value, ast.Num): + if isinstance(node.ast.Num, int): + return 'int' + elif isinstance(node.ast.Num, float): + return 'float' + else: + raise cgenException('Could not determine type of node.value', self.lineno) + # def process_annotation(ann_node): # if isinstance(ann_node, ast.Name): @@ -446,24 +466,24 @@ def body(self, statements): def body_class(self, statements, name): self.indentation += 1 - self.write_class_attributes(*statements, name=name) + self.write_class(*statements, name=name) self.write('};', dest = 'hdr') self.indentation -= 1 self.write_class_method_prototypes(*statements, name=name) self.write('\n', dest = 'src') - self.write_class_init(*statements, name=name) + self.write_class_constructor(*statements, name=name) self.write_class_methods(*statements, name=name) - def write_class_attributes(self, *params, name): - """ self.write is a closure for performance (to reduce the number - of attribute lookups). - """ - self.meta_info[self.scope]['attr'] = dict() - self.meta_info[self.scope]['methods'] = dict() + def write_instance_attributes(self, params, name): for item in params: if isinstance(item, ast.AnnAssign): + # skip non-attribute declarations + if isinstance(item.target, ast.Name): + break + if item.target.value.id != 'self': + raise cgenException('Unrecognized attribute declaration', self.lineno) set_precedence(item, item.target, item.annotation) set_precedence(Precedence.Comma, item.value) need_parens = isinstance(item.target, ast.Name) and not item.simple @@ -473,7 +493,7 @@ def write_class_attributes(self, *params, name): # annotation = ast.parse(item.annotation.s).body[0] # if 'value' in annotation.value.__dict__: type_py = annotation.id - self.meta_info[self.scope]['attr'][item.target.id] = type_py + self.meta_info[self.scope]['attr'][item.target.attr] = type_py if type_py is 'List': if item.value.func.id is not 'plist': @@ -490,10 +510,10 @@ def write_class_attributes(self, *params, name): # ann = item.annotation.slice.value.elts[0].id # dims = Num_or_Name(item.annotation.slice.value.elts[1]) if isinstance(dims, str): - self.typed_record[self.scope][item.target.id] = \ + self.typed_record[self.scope][item.target.attr] = \ 'List[' + ann + ', ' + dims + ']' else: - self.typed_record[self.scope][item.target.id] = \ + self.typed_record[self.scope][item.target.attr] = \ 'List[' + ann + ', ' + str(dims) + ']' if ann in pmt_temp_types: @@ -507,7 +527,7 @@ def write_class_attributes(self, *params, name): array_size = len(dim_list) # array_size = str(Num_or_Name(item.value.args[1])) # self.statement([], ann, ' ', item.target, '[', array_size, '];') - self.write('%s' %ann, ' ', '%s' %item.target.id, \ + self.write('%s' %ann, ' ', '%s' %item.target.attr, \ '[%s' %array_size, '];\n', dest = 'hdr') else: # not a List @@ -515,73 +535,88 @@ def write_class_attributes(self, *params, name): # check for user-defined types if ann in usr_temp_types: - self.write('struct %s' %ann, ' ', item.target.id, '___;\n', dest = 'hdr') - self.write('struct %s *' %ann, ' ', item.target.id, ';\n', dest = 'hdr') + self.write('struct %s' %ann, ' ', item.target.attr, '___;\n', dest = 'hdr') + self.write('struct %s *' %ann, ' ', item.target.attr, ';\n', dest = 'hdr') # self.statement(node, node.annotation, ' ', node.target, '= &', node.target, '___;') else: type_c = pmt_temp_types[type_py] - self.write('%s' %type_c, ' ', '%s' %item.target.id, ';\n', dest = 'hdr') + self.write('%s' %type_c, ' ', '%s' %item.target.attr, ';\n', dest = 'hdr') # self.conditional_write(' = ', item.value, ';') - self.typed_record[self.scope][item.target.id] = type_py + self.typed_record[self.scope][item.target.attr] = type_py + else: + # TODO(andrea): need to support constructor arguments + raise cgenException('Unsupported attribute of type {}'.format(item.name), self.lineno) # else: # type_py = annotation.id # type_c = pmt_temp_types[type_py] # self.write('%s' %type_c, ' ', '%s' %item.target.id, ';\n', dest = 'hdr') # # self.conditional_write(' = ', item.value, ';') # self.typed_record[self.scope][item.target.id] = type_py - elif isinstance(item, ast.FunctionDef): - self.meta_info[self.scope]['methods'][item.name] = dict() - # build argument mangling - f_name_len = len(item.name) - pre_mangl = '_Z%s' %f_name_len - if item.args.args[0].arg is not 'self': - raise cgenException('First argument in method {} \ - must be \'self\'. You have \'{}\''.format(item.name, \ - item.args.args[0].arg), item.lineno) - else: - # store self argument - self_arg = item.args.args[0] - # pop self from argument list - item.args.args.pop(0) - post_mangl = self.build_arg_mangling(item.args) - - if hasattr(self.get_returns(item), 'id'): - ret_type = self.get_returns(item).id - else: - ret_type = self.get_returns(item).value + def write_class(self, *params, name): + """ self.write is a closure for performance (to reduce the number + of attribute lookups). + """ + self.meta_info[self.scope]['attr'] = dict() + self.meta_info[self.scope]['methods'] = dict() + for item in params: + if not isinstance(item, ast.FunctionDef): + raise cgenException('Classes can only contain attributes and methods', item.lineno) + # additional treatment of __init__ (declare attributes) + if item.name == '__init__': + self.write_instance_attributes(item.body, name=name) - self.meta_info[self.scope]['methods'][item.name]['return_type'] = ret_type + self.meta_info[self.scope]['methods'][item.name] = dict() + # build argument mangling + f_name_len = len(item.name) + pre_mangl = '_Z%s' %f_name_len + if item.args.args[0].arg is not 'self': + raise cgenException('First argument in method {} \ + must be \'self\'. You have \'{}\''.format(item.name, \ + item.args.args[0].arg), item.lineno) + else: + # store self argument + self_arg = item.args.args[0] + # pop self from argument list + item.args.args.pop(0) - if ret_type is None: - ret_type = 'None' + post_mangl = self.build_arg_mangling(item.args) + + if hasattr(self.get_returns(item), 'id'): + ret_type = self.get_returns(item).id + else: + ret_type = self.get_returns(item).value - if ret_type in pmt_temp_types: - ret_type = pmt_temp_types[ret_type] - else: + self.meta_info[self.scope]['methods'][item.name]['return_type'] = ret_type - raise cgenException ('Usage of non existing type \ - \033[91m{}\033[0m'.format(ann), item.lineno) - # raise cgenException ('Usage of non existing type {}'.format(ret_type)) + if ret_type is None: + ret_type = 'None' - if len(item.args.args) > 0: - self.write('%s (*%s%s%s' % (ret_type, pre_mangl, \ - item.name, post_mangl) , ')', '(%s *self, ' %name, \ - dest = 'hdr') - else: - self.write('%s (*%s%s%s' % (ret_type, pre_mangl, \ - item.name, post_mangl) , ')', '(%s *self' %name, \ - dest = 'hdr') + if ret_type in pmt_temp_types: + ret_type = pmt_temp_types[ret_type] + else: + raise cgenException ('Usage of non existing type \ + \033[91m{}\033[0m'.format(ann), item.lineno) + # raise cgenException ('Usage of non existing type {}'.format(ret_type)) - args_list = self.visit_arguments(item.args, 'hdr') - self.meta_info[self.scope]['methods'][item.name]['args'] = args_list - self.write(');\n', dest = 'hdr') - # insert back self argument - item.args.args.insert(0, self_arg) + if len(item.args.args) > 0: + self.write('%s (*%s%s%s' % (ret_type, pre_mangl, \ + item.name, post_mangl) , ')', '(%s *self, ' %name, \ + dest = 'hdr') else: - raise cgenException('Classes can only contain attributes and methods', item.lineno) + self.write('%s (*%s%s%s' % (ret_type, pre_mangl, \ + item.name, post_mangl) , ')', '(%s *self' %name, \ + dest = 'hdr') + + + args_list = self.visit_arguments(item.args, 'hdr') + self.meta_info[self.scope]['methods'][item.name]['args'] = args_list + # TODO(andrea): implicit call to visit() in write() - make explicit + self.write(');\n', dest = 'hdr') + # insert back self argument + item.args.args.insert(0, self_arg) def write_class_method_prototypes(self, *params, name): """ self.write is a closure for performance (to reduce the number @@ -629,11 +664,11 @@ def write_class_method_prototypes(self, *params, name): # insert back self argument item.args.args.insert(0, self_arg) - def write_class_init(self, *params, name): + def write_class_constructor(self, *params, name): """ self.write is a closure for performance (to reduce the number of attribute lookups). """ - self.write('void ', name, '_init(struct ', name, ' *object){', dest = 'src') + self.write('void ', name, '_constructor(struct ', name, ' *object){', dest = 'src') self.indentation += 1 for item in params: if isinstance(item, ast.AnnAssign): @@ -717,7 +752,7 @@ def write_class_init(self, *params, name): ' initialization.\n', item.lineno) elif ann in usr_temp_types: self.write('\nobject->', item.target.id, ' = &(object->', item.target.id, '___);\n', dest = 'src') - self.write(ann, '_init(object->', item.target.id, ');\n', dest='src') + self.write(ann, '_constructor(object->', item.target.id, ');\n', dest='src') else: if item.value != None: if hasattr(item.value, 'value') is False: @@ -756,6 +791,8 @@ def write_class_init(self, *params, name): # insert back self argument item.args.args.insert(0, self_arg) + # call __init__ + self.write('\n\tobject->_Z8__init__(object);\n', dest = 'src') self.write('\n}\n', dest = 'src') self.indentation -=1 @@ -944,6 +981,7 @@ def comma_list(self, items, trailing=False): self.write(',' if trailing else '', dest = 'src') # Statements + # TODO(andrea): make visit_Assign and visit_AnnAssign fully consistent def visit_Assign(self, node): if 'targets' in node.__dict__: if len(node.targets) != 1: @@ -971,7 +1009,7 @@ def visit_Assign(self, node): else: if node.targets[0].id not in self.typed_record[self.scope]: raise cgenException('Unknown variable {}.'.format(node.targets[0].id), node.lineno) - if type(node.targets[0]) == ast.Subscript: + if isinstance(node.targets[0], ast.Subscript): if target in self.typed_record[scope]: # map subscript for pmats to blasfeo el assign if self.typed_record[scope][target] == 'pmat': @@ -1090,7 +1128,7 @@ def visit_Assign(self, node): value of type {} not implemented'.format(sub_type), node.lineno) # check if subscripted expression is used in the value - if type(node.value) == ast.Subscript: + if isinstance(node.value, ast.Subscript): # if value is a pmat value = node.value.value.id if value in self.typed_record[self.scope]: @@ -1122,7 +1160,7 @@ def visit_Assign(self, node): self.statement([], 'c_pmt_pvec_set_el(', target, ', {}'.format(index), ', {}'.format(value), ');') return - elif type(node.value) == ast.Subscript: + elif isinstance(node.value, ast.Subscript): target = node.targets[0].id if target not in self.typed_record[self.scope]: raise cgenException('Undefined variable {}.'.format(target), node.lineno) @@ -1156,7 +1194,7 @@ def visit_Assign(self, node): value_expr = 'c_pmt_pvec_get_el(' + value + ', {})'.format(index_value) self.statement([], target, ' = {}'.format(value_expr), ';') return - elif 'id' in node.targets[0].__dict__: + elif isinstance(node.targets[0], ast.Name): # check for Assigns targeting pmats target = node.targets[0].id @@ -1197,7 +1235,6 @@ def visit_Assign(self, node): node.lineno) # elif self.typed_record[self.scope][node.value] == 'float': - # import pdb; pdb.set_trace() # value = Num_or_Name(node.value) # self.statement([], 'c_pmt_pmat_set_el(', target.value.id, ', {}'.format(first_index), ', {}'.format(second_index), ', {}'.format(value_expr), ');') @@ -1223,7 +1260,7 @@ def visit_Assign(self, node): node.lineno) - elif 'attr' in node.targets[0].__dict__: + elif isinstance(node.targets[0], ast.Attribute): # Assign targeting a user-defined class (C struct) struct_name = node.targets[0].value.id if struct_name in self.typed_record[self.scope]: @@ -1237,6 +1274,7 @@ def visit_Assign(self, node): else: raise cgenException('Could not resolve Assign node.', node.lineno) + # default assignment set_precedence(node, node.value, *node.targets) self.newline(node) for target in node.targets: @@ -1261,6 +1299,34 @@ def visit_AnnAssign(self, node): raise cgenException('Cannot declare variable without initialization.', node.lineno) ann = node.annotation.id + + # check for attributes + if hasattr(node.target, 'value'): + if isinstance(node.target, ast.Attribute): + # if hasattr(node.target.value, 'attr'): + if node.target.value.id != 'self' and node.target.attr not in self.typed_record[self.scope]: + raise cgenException('Unknown variable {}.'.format( \ + node.target.attr), node.lineno) + # TODO(andrea): need to handle attributes recursively + target = node.target.attr + obj_name = node.target.value.id + # TODO(andrea): need to compute local scope (find strings + # that contain scope and have a string in common with self.scope) + # this assumes that the class has been defined in the global scope + + # do not update scope if an instance attribute is being defined + if node.target.value.id != 'self': + scope = 'global@' + self.typed_record[self.scope][obj_name] + else: + if node.target.value.id not in self.typed_record[self.scope]: + raise cgenException('variable {} already defined.'.format(node.target.value.id), node.lineno) + + target = node.target.value.id + scope = self.scope + else: + if node.target.id in self.typed_record[self.scope]: + raise cgenException('variable {} already defined.'.format(node.target.id), node.lineno) + # check if a CasADi function is being declared (and skip) if ann == 'ca': return @@ -1345,7 +1411,8 @@ def visit_AnnAssign(self, node): else: raise cgenException('Undefined variable {} of type dims.'.format(dim2), node.lineno) # self.heap64_record[self.scope] = self.heap64_record[self.scope] + int(dim1)*int(dim2)*self.size_of_double - self.heap64_record[self.scope] = self.heap64_record[self.scope] + '+' + str(dim1) + '*' + str(dim2) + '*' + str(self.size_of_double) + self.heap64_record[self.scope] = self.heap64_record[self.scope] + '+' + str(dim1) + \ + '*' + str(dim2) + '*' + str(self.size_of_double) # or pvec[] elif ann == 'pvec': if node.value.func.id != 'pvec': @@ -1359,7 +1426,8 @@ def visit_AnnAssign(self, node): # or dims elif ann == 'dims': - check_expression(node.value, tuple([ast.Mult, ast.Sub, ast.Pow, ast.Add]), tuple([ast.USub]),('dims'), tuple([ast.Num]), self.dim_record) + check_expression(node.value, tuple([ast.Mult, ast.Sub, ast.Pow, ast.Add]), \ + tuple([ast.USub]),('dims'), tuple([ast.Num]), self.dim_record) value = astu.unparse(node.value) self.write('#define %s %s\n' %(node.target.id, value), dest='hdr') self.dim_record[node.target.id] = value @@ -1373,7 +1441,8 @@ def visit_AnnAssign(self, node): self.dim_record[node.target.id].append([]) for j in range(len(node.value.elts[i].elts)): self.dim_record[node.target.id][i].append(node.value.elts[i].elts[j].n) - self.write('#define %s_%s_%s %s\n' %(node.target.id, i, j, node.value.elts[i].elts[j].n), dest='hdr') + self.write('#define %s_%s_%s %s\n' %(node.target.id, i, j, \ + node.value.elts[i].elts[j].n), dest='hdr') # check if annotation corresponds to user-defined class name elif ann in usr_temp_types: @@ -1381,19 +1450,37 @@ def visit_AnnAssign(self, node): node.annotation.id = usr_temp_types[ann] self.statement([], 'struct ', class_name, ' ', node.target, '___;') self.statement(node, node.annotation, ' ', node.target, '= &', node.target, '___;') - self.statement([], class_name, '_init(', node.target, '); //') + self.statement([], class_name, '_constructor(', node.target, '); //') else: if ann in pmt_temp_types: c_ann = pmt_temp_types[ann] - self.statement(node, c_ann, ' ', node.target.id) - self.conditional_write(' = ', node.value, ';', dest = 'src') + if isinstance(node.target, ast.Attribute): + # annotated assign that defined an attribute (i.e. . : = ) + if node.target.value.id != 'self': + raise cgenException('invalid AnnAssign on attribute. AnnAssign on attributes can only be used to ' + 'define instance attributes', self.lineno) + else: + if isinstance(node.value, ast.Name): + if node.value.id not in self.typed_record: + raise cgenException('Unknown variable {}.'.format(node.value.id), node.lineno) + attr_value = Num_or_Name(node.value) + attr_name = node.target.attr + import pdb; pdb.set_trace() + self.statement([], node.target.value.id, '->', attr_name, ' = ', str(attr_value), ';') + + else: + self.statement(node, c_ann, ' ', node.target.id) + self.conditional_write(' = ', node.value, ';', dest = 'src') else: raise cgenException('\033[;1mUsage of non existing type\033[0;0m' ' \033[1;31m{}\033[0;0m.'.format(ann), node.lineno) # print('typed_record = \n', self.typed_record, '\n\n') # print('var_dim_record = \n', self.var_dim_record, '\n\n') - self.typed_record[self.scope][node.target.id] = ann + + # AnnAssigns on attributes are only supported for instance attributes + if not isinstance(node.target, ast.Attribute): + self.typed_record[self.scope][node.target.id] = ann # # switch to avoid double ';' # if type(node.value) != ast.Call: From 47be88611114920c91b662100c841beeca1d803f Mon Sep 17 00:00:00 2001 From: "andrea.zanelli" Date: Tue, 19 May 2020 17:15:23 +0200 Subject: [PATCH 3/3] adapted examples, fixed AnnAssigns targeting attributes (__init__) --- examples/laparser/laparser.py | 1 + examples/riccati_example/riccati.py | 13 ++- examples/riccati_example/riccati_compact.py | 15 ++-- .../riccati_example/riccati_mass_spring_2.py | 17 ++-- prometeo/cgen/code_gen_c.py | 86 ++++++++++++------- 5 files changed, 80 insertions(+), 52 deletions(-) diff --git a/examples/laparser/laparser.py b/examples/laparser/laparser.py index f0ddc2e..9001873 100644 --- a/examples/laparser/laparser.py +++ b/examples/laparser/laparser.py @@ -17,6 +17,7 @@ def main() -> int: B[1,1] = 1.0 C: pmat = pmat(nx, nx) + D: pmat = pmat(nx, nx) pparse('C = A - A.T \ (B * D).T') diff --git a/examples/riccati_example/riccati.py b/examples/riccati_example/riccati.py index 0e369f1..33d9bd2 100644 --- a/examples/riccati_example/riccati.py +++ b/examples/riccati_example/riccati.py @@ -7,13 +7,12 @@ N: dims = 5 class qp_data: - A: List = plist(pmat, sizes) - B: List = plist(pmat, sizes) - Q: List = plist(pmat, sizes) - R: List = plist(pmat, sizes) - P: List = plist(pmat, sizes) - - fact: List = plist(pmat, sizes) + def __init__(self) -> None: + self.A: List = plist(pmat, sizes) + self.B: List = plist(pmat, sizes) + self.Q: List = plist(pmat, sizes) + self.R: List = plist(pmat, sizes) + self.P: List = plist(pmat, sizes) def factorize(self) -> None: M: pmat = pmat(nxu, nxu) diff --git a/examples/riccati_example/riccati_compact.py b/examples/riccati_example/riccati_compact.py index 3a0dc6a..5b2a2bd 100644 --- a/examples/riccati_example/riccati_compact.py +++ b/examples/riccati_example/riccati_compact.py @@ -7,13 +7,14 @@ N: dims = 5 class qp_data: - A: List = plist(pmat, sizes) - B: List = plist(pmat, sizes) - Q: List = plist(pmat, sizes) - R: List = plist(pmat, sizes) - P: List = plist(pmat, sizes) - - fact: List = plist(pmat, sizes) + def __init__(self) -> None: + self.A: List = plist(pmat, sizes) + self.B: List = plist(pmat, sizes) + self.Q: List = plist(pmat, sizes) + self.R: List = plist(pmat, sizes) + self.P: List = plist(pmat, sizes) + + self.fact: pmat = pmat(nx,nx) def factorize(self) -> None: Qk: pmat = pmat(nx, nx) diff --git a/examples/riccati_example/riccati_mass_spring_2.py b/examples/riccati_example/riccati_mass_spring_2.py index fef7a34..d95196f 100644 --- a/examples/riccati_example/riccati_mass_spring_2.py +++ b/examples/riccati_example/riccati_mass_spring_2.py @@ -2,21 +2,20 @@ nm: dims = 4 nx: dims = 2*nm -# nx: dims = 2 -# sizes: dimv = [[2,2], [2,2], [2,2], [2,2], [2,2]] sizes: dimv = [[8,8], [8,8], [8,8], [8,8], [8,8]] nu: dims = nm nxu: dims = nx + nu N: dims = 5 class qp_data: - A: List = plist(pmat, sizes) - B: List = plist(pmat, sizes) - Q: List = plist(pmat, sizes) - R: List = plist(pmat, sizes) - P: List = plist(pmat, sizes) - - fact: List = plist(pmat, sizes) + def __init__(self) -> None: + self.A: List = plist(pmat, sizes) + self.B: List = plist(pmat, sizes) + self.Q: List = plist(pmat, sizes) + self.R: List = plist(pmat, sizes) + self.P: List = plist(pmat, sizes) + + self.fact: List = plist(pmat, sizes) def factorize(self) -> None: M: pmat = pmat(nxu, nxu) diff --git a/prometeo/cgen/code_gen_c.py b/prometeo/cgen/code_gen_c.py index ccaa091..641eef8 100644 --- a/prometeo/cgen/code_gen_c.py +++ b/prometeo/cgen/code_gen_c.py @@ -563,6 +563,9 @@ def write_class(self, *params, name): for item in params: if not isinstance(item, ast.FunctionDef): raise cgenException('Classes can only contain attributes and methods', item.lineno) + if item.returns is None: + raise cgenException('Missing return annotation on class method {}'.format(item.name), item.lineno) + # additional treatment of __init__ (declare attributes) if item.name == '__init__': self.write_instance_attributes(item.body, name=name) @@ -1324,7 +1327,8 @@ def visit_AnnAssign(self, node): target = node.target.value.id scope = self.scope else: - if node.target.id in self.typed_record[self.scope]: + target = node.target.id + if target in self.typed_record[self.scope]: raise cgenException('variable {} already defined.'.format(node.target.id), node.lineno) # check if a CasADi function is being declared (and skip) @@ -1353,26 +1357,36 @@ def visit_AnnAssign(self, node): lann = node.value.args[0].id dims = Num_or_Name(node.value.args[1]) if isinstance(dims, str): - self.typed_record[self.scope][node.target.id] = 'List[' + lann + ', ' + dims + ']' + self.typed_record[self.scope][target] = 'List[' + lann + ', ' + dims + ']' else: - self.typed_record[self.scope][node.target.id] = 'List[' + lann + ', ' + str(dims) + ']' + self.typed_record[self.scope][target] = 'List[' + lann + ', ' + str(dims) + ']' if lann in pmt_temp_types: lann = pmt_temp_types[lann] else: raise cgenException ('Usage of non existing type {}.'.format(lann), node.lineno) - # check is dims is not a numerical value + + # check if dims is not a numerical value if isinstance(dims, str): dim_list = self.dim_record[dims] array_size = len(dim_list) else: array_size = dims - # array_size = str(Num_or_Name(node.value.args[1])) - # self.statement([], lann, ' ', node.target, '[', array_size, '];') - self.write('%s' %lann, ' ', '%s' %node.target.id, '[%s' %array_size, '];\n', dest = 'src') + + # assume that AnnAssigns on attributes are only used to declare instance attributes + if not isinstance(node.target, ast.Attribute): + self.write('%s' %lann, ' ', '%s' %target, '[%s' %array_size, '];\n', dest = 'src') + + + # assume that AnnAssigns on attributes are only used to declare instance attributes + if isinstance(node.target, ast.Attribute): + mod_target = 'self->' + target + else: + mod_target = target + if lann == 'struct pmat *': # build init for List of pmats for i in range(len(dim_list)): - self.statement([], node.target.id, \ + self.statement([], mod_target, \ '[', str(i),'] = c_pmt_create_pmat(', \ str(dim_list[i][0]), ', ', \ str(dim_list[i][1]), ');') @@ -1380,10 +1394,9 @@ def visit_AnnAssign(self, node): elif lann == 'struct pvec *': # build init for List of pvecs for i in range(len(dim_list)): - self.statement([], node.target.id, \ + self.statement([], mod_target, \ '[', str(i),'] = c_pmt_create_pvec(', \ str(dim_list[i][0]), ');') - # self.conditional_write(' = ', node.value, '', dest = 'src') # pmat[,] elif ann == 'pmat': @@ -1392,13 +1405,22 @@ def visit_AnnAssign(self, node): ' the pmat(, ) constructor.', node.lineno) dim1 = Num_or_Name(node.value.args[0]) dim2 = Num_or_Name(node.value.args[1]) - self.var_dim_record[self.scope][node.target.id] = [dim1, dim2] + self.var_dim_record[self.scope][target] = [dim1, dim2] node.annotation.id = pmt_temp_types[ann] - self.statement(node, node.annotation, ' ', node.target) - self.conditional_write(' = ', node.value, '', dest = 'src') + # assume that AnnAssigns on attributes are only used to declare instance attributes + if isinstance(node.target, ast.Attribute): + self.write('\nself->' + str(node.target.attr) + ' = ', node.value, '\n', dest = 'src') + else: + self.statement(node, node.annotation, ' ', node.target) + self.conditional_write(' = ', node.value, '', dest = 'src') + # increment scoped heap usage (3 pointers and 6 ints for pmats) - self.heap8_record[self.scope] = self.heap8_record[self.scope] + '+' + '3*' + str(self.size_of_pointer) - self.heap8_record[self.scope] = self.heap8_record[self.scope] + '+' + '6*' + str(self.size_of_int) + self.heap8_record[self.scope] = self.heap8_record[self.scope] + \ + '+' + '3*' + str(self.size_of_pointer) + + self.heap8_record[self.scope] = self.heap8_record[self.scope] + \ + '+' + '6*' + str(self.size_of_int) + # check is dims is not a numerical value if isinstance(dim1, str): if dim1 in self.dim_record: @@ -1410,7 +1432,6 @@ def visit_AnnAssign(self, node): dim2 = self.dim_record[dim2] else: raise cgenException('Undefined variable {} of type dims.'.format(dim2), node.lineno) - # self.heap64_record[self.scope] = self.heap64_record[self.scope] + int(dim1)*int(dim2)*self.size_of_double self.heap64_record[self.scope] = self.heap64_record[self.scope] + '+' + str(dim1) + \ '*' + str(dim2) + '*' + str(self.size_of_double) # or pvec[] @@ -1420,9 +1441,12 @@ def visit_AnnAssign(self, node): dim1 = Num_or_Name(node.value.args[0]) self.var_dim_record[self.scope][node.target.id] = [dim1] node.annotation.id = pmt_temp_types[ann] - self.statement(node, node.annotation, ' ', node.target) - self.conditional_write(' = ', node.value, '', dest = 'src') - + # assume that AnnAssigns on attributes are only used to declare instance attributes + if isinstance(node.target, ast.Attribute): + self.write('\nself->' + str(node.target.attr) + ' = ', node.value, '\n', dest = 'src') + else: + self.statement(node, node.annotation, ' ', node.target) + self.conditional_write(' = ', node.value, '', dest = 'src') # or dims elif ann == 'dims': @@ -1431,7 +1455,6 @@ def visit_AnnAssign(self, node): value = astu.unparse(node.value) self.write('#define %s %s\n' %(node.target.id, value), dest='hdr') self.dim_record[node.target.id] = value - # self.write('const int %s = %s;\n' %(node.target.id, node.value.n), dest='hdr') # or dimv elif ann == 'dimv': @@ -1448,9 +1471,15 @@ def visit_AnnAssign(self, node): elif ann in usr_temp_types: class_name = node.annotation.id node.annotation.id = usr_temp_types[ann] - self.statement([], 'struct ', class_name, ' ', node.target, '___;') - self.statement(node, node.annotation, ' ', node.target, '= &', node.target, '___;') - self.statement([], class_name, '_constructor(', node.target, '); //') + # assume that AnnAssigns on attributes are only used to declare instance attributes + if isinstance(node.target, ast.Attribute): + self.statement([], 'self->', node.annotation, '= & ', node.target, '___;') + # self.statement(node, node.annotation, ' ', node.target, '= &', node.target, '___;') + self.statement([], 'self->', class_name, '_constructor(', node.target, '); //') + else: + self.statement([], 'struct ', class_name, ' ', node.target, '___;') + self.statement(node, node.annotation, ' ', node.target, '= &', node.target, '___;') + self.statement([], class_name, '_constructor(', node.target, '); //') else: if ann in pmt_temp_types: c_ann = pmt_temp_types[ann] @@ -1465,7 +1494,6 @@ def visit_AnnAssign(self, node): raise cgenException('Unknown variable {}.'.format(node.value.id), node.lineno) attr_value = Num_or_Name(node.value) attr_name = node.target.attr - import pdb; pdb.set_trace() self.statement([], node.target.value.id, '->', attr_name, ' = ', str(attr_value), ';') else: @@ -1541,8 +1569,8 @@ def visit_FunctionDef(self, node, is_async=False): # self.write() returns = self.get_returns(node) if returns is None: - raise cgenException('Function {} does not have a \ - return type hint.', node.lineno) + if item.returns is None: + raise cgenException('Missing return annotation on method {}'.format(item.name), node.lineno) if isinstance(returns, ast.NameConstant): return_type_py = str(returns.value) @@ -1888,11 +1916,11 @@ def write_comma(): p = Precedence.Comma if numargs > 1 else Precedence.call_one_arg set_precedence(p, *args) - if type(node.func) == ast.Name: + if isinstance(node.func, ast.Name): if node.func.id in pmt_temp_functions: func_name = node.func.id node.func.id = pmt_temp_functions[func_name] - elif type(node.func) == ast.Attribute: + elif isinstance(node.func, ast.Attribute): # calling a method of a user-defined class func_name = node.func.attr f_name_len = len(func_name) @@ -1901,7 +1929,7 @@ def write_comma(): node.func.attr = pre_mangl + func_name + post_mangl self.visit(node.func) - if type(node.func) == ast.Attribute: + if isinstance(node.func, ast.Attribute): if len(args) > 0: code = '(' + node.func.value.id + ', ' else: