diff --git a/multiply_add.sus b/multiply_add.sus index 6de4806..fb30d42 100644 --- a/multiply_add.sus +++ b/multiply_add.sus @@ -231,6 +231,17 @@ module blur2 : blurred = data + prev; } prev = data; + + gen int a; + + gen bool b = true; + bool bb = false; + + if bb { + a = 5; + } else { + a = 3; + } } diff --git a/src/flattening.rs b/src/flattening.rs index 8f866fc..25963ad 100644 --- a/src/flattening.rs +++ b/src/flattening.rs @@ -57,11 +57,13 @@ impl WireSource { &WireSource::BinaryOp { op:_, left, right } => {func(left); func(right)} &WireSource::ArrayAccess { arr, arr_idx } => {func(arr); func(arr_idx)} WireSource::Constant(_) => {} - &WireSource::NamedConstant(_) => {} + WireSource::NamedConstant(_) => {} } } } +const IS_GEN_UNINIT : bool = false; + #[derive(Debug)] pub struct WireInstance { pub typ : Type, @@ -100,7 +102,6 @@ pub struct SubModuleInstance { #[derive(Debug)] pub struct IfStatement { - pub is_compiletime : bool, pub condition : FlatID, pub then_start : FlatID, pub then_end_else_start : FlatID, @@ -180,15 +181,6 @@ struct FlatteningContext<'inst, 'l, 'm> { module : &'m Module, } -fn must_be_compiletime_with_info Vec>(wire : &WireInstance, context : &str, errors : &ErrorCollector, ctx_func : CtxFunc) { - if !wire.is_compiletime { - errors.error_with_info(wire.span, format!("{context} must be compile time"), ctx_func()); - } -} -fn must_be_compiletime(wire : &WireInstance, context : &str, errors : &ErrorCollector) { - must_be_compiletime_with_info(wire, context, errors, || Vec::new()); -} - impl<'inst, 'l, 'm> FlatteningContext<'inst, 'l, 'm> { fn map_to_type(&mut self, type_expr : &SpanTypeExpression) -> Type { match &type_expr.0 { @@ -203,7 +195,6 @@ impl<'inst, 'l, 'm> FlatteningContext<'inst, 'l, 'm> { let (array_type_expr, array_size_expr) = b.deref(); let array_element_type = self.map_to_type(&array_type_expr); if let Some(array_size_wire_id) = self.flatten_expr(array_size_expr) { - must_be_compiletime(self.instantiations[array_size_wire_id].extract_wire(), "Array size", &self.errors); Type::Array(Box::new((array_element_type, array_size_wire_id))) } else { Type::Error @@ -238,7 +229,7 @@ impl<'inst, 'l, 'm> FlatteningContext<'inst, 'l, 'm> { let latency_specifier = if let Some(lat_expr) = &decl.latency_expr { if let Some(latency_spec) = self.flatten_expr(lat_expr) { - must_be_compiletime(self.instantiations[latency_spec].extract_wire(), "Latency specifier", &self.errors); + self.must_be_compiletime(self.instantiations[latency_spec].extract_wire(), "Latency specifier"); Some(latency_spec) } else { None @@ -346,41 +337,34 @@ impl<'inst, 'l, 'm> FlatteningContext<'inst, 'l, 'm> { Some((md, submodule_local_wires)) } fn flatten_expr(&mut self, (expr, expr_span) : &SpanExpression) -> Option { - let (is_compiletime, source) = match expr { + let source = match expr { Expression::Named(LocalOrGlobal::Local(l)) => { let from_wire = self.decl_to_flat_map[*l].unwrap(); - let decl = self.instantiations[from_wire].extract_wire_declaration(); - (decl.identifier_type == IdentifierType::Generative, WireSource::WireRead(from_wire)) + WireSource::WireRead(from_wire) } Expression::Named(LocalOrGlobal::Global(ref_span)) => { let cst = self.linker.resolve_constant(*ref_span, &self.errors)?; - (true, WireSource::NamedConstant(cst)) + WireSource::NamedConstant(cst) } Expression::Constant(cst) => { - (true, WireSource::Constant(cst.clone())) + WireSource::Constant(cst.clone()) } Expression::UnaryOp(op_box) => { let (op, _op_pos, operate_on) = op_box.deref(); let right = self.flatten_expr(operate_on)?; - let right_wire = self.instantiations[right].extract_wire(); - (right_wire.is_compiletime, WireSource::UnaryOp{op : *op, right}) + WireSource::UnaryOp{op : *op, right} } Expression::BinOp(binop_box) => { let (left_expr, op, _op_pos, right_expr) = binop_box.deref(); let left = self.flatten_expr(left_expr)?; let right = self.flatten_expr(right_expr)?; - let left_wire = self.instantiations[left].extract_wire(); - let right_wire = self.instantiations[right].extract_wire(); - let is_compiletime = left_wire.is_compiletime && right_wire.is_compiletime; - (is_compiletime, WireSource::BinaryOp{op : *op, left, right}) + WireSource::BinaryOp{op : *op, left, right} } Expression::Array(arr_box) => { let (left, right, _bracket_span) = arr_box.deref(); let arr = self.flatten_expr(left)?; let arr_idx = self.flatten_expr(right)?; - let arr_wire = self.instantiations[arr].extract_wire(); - let arr_idx_wire = self.instantiations[arr_idx].extract_wire(); - (arr_wire.is_compiletime && arr_idx_wire.is_compiletime, WireSource::ArrayAccess{arr, arr_idx}) + WireSource::ArrayAccess{arr, arr_idx} } Expression::FuncCall(func_and_args) => { let (md, interface_wires) = self.desugar_func_call(func_and_args, expr_span.1)?; @@ -397,7 +381,7 @@ impl<'inst, 'l, 'm> FlatteningContext<'inst, 'l, 'm> { } }; - let wire_instance = WireInstance{typ : Type::Unknown, span : *expr_span, is_compiletime, source, is_remote_declaration : self.is_remote_declaration,}; + let wire_instance = WireInstance{typ : Type::Unknown, is_compiletime : IS_GEN_UNINIT, span : *expr_span, source, is_remote_declaration : self.is_remote_declaration,}; Some(self.instantiations.alloc(Instantiation::Wire(wire_instance))) } fn flatten_assignable_expr(&mut self, (expr, span) : &SpanAssignableExpression) -> Option { @@ -457,7 +441,7 @@ impl<'inst, 'l, 'm> FlatteningContext<'inst, 'l, 'm> { // temporary let module_port_wire_decl = self.instantiations[*field].extract_wire_declaration(); - let module_port_proxy = self.instantiations.alloc(Instantiation::Wire(WireInstance{typ : module_port_wire_decl.typ.clone(), is_compiletime : module_port_wire_decl.identifier_type == IdentifierType::Generative, span : *func_span, is_remote_declaration : self.is_remote_declaration, source : WireSource::WireRead(*field)})); + let module_port_proxy = self.instantiations.alloc(Instantiation::Wire(WireInstance{typ : module_port_wire_decl.typ.clone(), is_compiletime : IS_GEN_UNINIT, span : *func_span, is_remote_declaration : self.is_remote_declaration, source : WireSource::WireRead(*field)})); self.instantiations.alloc(Instantiation::Connection(Connection{num_regs : to_i.num_regs, from: module_port_proxy, to: write_side})); } }, @@ -477,9 +461,7 @@ impl<'inst, 'l, 'm> FlatteningContext<'inst, 'l, 'm> { Statement::If{condition : condition_expr, then, els} => { let Some(condition) = self.flatten_expr(condition_expr) else {continue}; - let is_compiletime = self.instantiations[condition].extract_wire().is_compiletime; - - let if_id = self.instantiations.alloc(Instantiation::IfStatement(IfStatement{is_compiletime, condition, then_start : UUID::PLACEHOLDER, then_end_else_start : UUID::PLACEHOLDER, else_end : UUID::PLACEHOLDER})); + let if_id = self.instantiations.alloc(Instantiation::IfStatement(IfStatement{condition, then_start : UUID::PLACEHOLDER, then_end_else_start : UUID::PLACEHOLDER, else_end : UUID::PLACEHOLDER})); let then_start = self.instantiations.get_next_alloc_id(); self.flatten_code(then); @@ -522,29 +504,13 @@ impl<'inst, 'l, 'm> FlatteningContext<'inst, 'l, 'm> { } } - /* - ==== Type Checking ==== + ==== Typechecking ==== */ fn typecheck_wire_is_of_type(&self, wire : &WireInstance, expected : &Type, context : &str) { typecheck(&wire.typ, wire.span, expected, context, self.type_list_for_naming, &self.errors); } - // Typechecks things like that arrays have compiletime integer sizes - fn typecheck_type_generic_parameters(&self, typ : &Type) { - match typ { - Type::Error => {} - Type::Unknown => unreachable!(), // Should only run this on types that have been properly resolved! - Type::Named{id:_, span:_} => {} - Type::Array(arr_box) => { - let (arr_typ, size_val) = arr_box.deref(); - self.typecheck_type_generic_parameters(arr_typ); - let size_val_wire = &self.instantiations[*size_val].extract_wire(); - self.typecheck_wire_is_of_type(size_val_wire, &INT_TYPE, "Array size"); - } - } - } - fn typecheck(&mut self) { let look_at_queue : Vec = self.instantiations.iter().map(|(id,_)| id).collect(); @@ -556,7 +522,10 @@ impl<'inst, 'l, 'm> FlatteningContext<'inst, 'l, 'm> { let latency_spec_wire = &self.instantiations[latency_spec].extract_wire(); self.typecheck_wire_is_of_type(latency_spec_wire, &INT_TYPE, "latency specifier"); } - self.typecheck_type_generic_parameters(&decl.typ); + + decl.typ.for_each_generative_input(&mut |param_id| { + self.typecheck_wire_is_of_type(self.instantiations[param_id].extract_wire(), &INT_TYPE, "Array size"); + }); } Instantiation::IfStatement(stm) => { let wire = &self.instantiations[stm.condition].extract_wire(); @@ -624,13 +593,8 @@ impl<'inst, 'l, 'm> FlatteningContext<'inst, 'l, 'm> { } } - // Typecheck compile-time ness - let from_wire = self.instantiations[conn.from].extract_wire(); - if conn_root.identifier_type == IdentifierType::Generative { - must_be_compiletime_with_info(from_wire, "Assignments to generative variables", &self.errors, || vec![error_info(conn_root.get_full_decl_span(), self.errors.file, "Declared here")]); - } - // Typecheck the value with target type + let from_wire = self.instantiations[conn.from].extract_wire(); if let Some(target_type) = write_to_type { self.typecheck_wire_is_of_type(from_wire, &target_type, "connection"); } @@ -648,7 +612,96 @@ impl<'inst, 'l, 'm> FlatteningContext<'inst, 'l, 'm> { } } - /* Additional Warnings */ + /* + ==== Generative Code Checking ==== + */ + fn must_be_compiletime_with_info Vec>(&self, wire : &WireInstance, context : &str, ctx_func : CtxFunc) { + if !wire.is_compiletime { + self.errors.error_with_info(wire.span, format!("{context} must be compile time"), ctx_func()); + } + } + fn must_be_compiletime(&self, wire : &WireInstance, context : &str) { + self.must_be_compiletime_with_info(wire, context, || Vec::new()); + } + + fn generative_check(&mut self) { + let mut runtime_if_stack : Vec<(FlatID, Span)> = Vec::new(); + + let mut declaration_depths : FlatAlloc, FlatIDMarker> = self.instantiations.iter().map(|_| None).collect(); + + for inst_id in self.instantiations.id_range() { + while let Some((end_id, span)) = runtime_if_stack.pop() { + if end_id != inst_id { + runtime_if_stack.push((end_id, span)); + break; + } + } + match &self.instantiations[inst_id] { + Instantiation::SubModule(_) => {} + Instantiation::WireDeclaration(decl) => { + if decl.identifier_type == IdentifierType::Generative { + assert!(declaration_depths[inst_id].is_none()); + declaration_depths[inst_id] = Some(runtime_if_stack.len()) + } + + decl.typ.for_each_generative_input(&mut |param_id| { + self.must_be_compiletime(self.instantiations[param_id].extract_wire(), "Array size"); + }); + } + + Instantiation::Wire(wire) => { + let mut is_generative = true; + if let WireSource::WireRead(from) = &wire.source { + let decl = self.instantiations[*from].extract_wire_declaration(); + if decl.identifier_type != IdentifierType::Generative { + is_generative = false; + } + } else { + wire.source.for_each_input_wire(&mut |source_id| { + let source_wire = self.instantiations[source_id].extract_wire(); + if !source_wire.is_compiletime { + is_generative = false; + } + }); + } + let Instantiation::Wire(wire) = &mut self.instantiations[inst_id] else {unreachable!()}; + wire.is_compiletime = is_generative; + } + Instantiation::Connection(conn) => { + let conn_root_decl = self.instantiations[conn.to.root].extract_wire_declaration(); + + if conn_root_decl.identifier_type == IdentifierType::Generative { + let from_wire = self.instantiations[conn.from].extract_wire(); + // Check that whatever's written to this declaration is also generative + self.must_be_compiletime_with_info(from_wire, "Assignments to generative variables", || vec![error_info(conn_root_decl.get_full_decl_span(), self.errors.file, "Declared here")]); + + // Check that this declaration isn't used in a non-compiletime if + let declared_at_depth = declaration_depths[conn.to.root].unwrap(); + + if runtime_if_stack.len() > declared_at_depth { + let mut infos = Vec::new(); + infos.push(error_info(conn_root_decl.get_full_decl_span(), self.errors.file, "Declared here")); + for (_, if_cond_span) in &runtime_if_stack[declared_at_depth..] { + infos.push(error_info(*if_cond_span, self.errors.file, "Runtime Condition here")); + } + self.errors.error_with_info(conn.to.span, "Cannot write to generative variables in runtime conditional block", infos); + } + } + } + Instantiation::IfStatement(if_stmt) => { + let condition_wire = self.instantiations[if_stmt.condition].extract_wire(); + if !condition_wire.is_compiletime { + runtime_if_stack.push((if_stmt.else_end, condition_wire.span)); + } + } + Instantiation::ForStatement(_) => {} + } + } + } + + /* + ==== Additional Warnings + */ fn find_unused_variables(&self, interface : &InterfacePorts) { // Setup Wire Fanouts List for faster processing let mut gathered_connection_fanin : FlatAlloc, FlatIDMarker> = self.instantiations.iter().map(|_| Vec::new()).collect(); @@ -774,6 +827,7 @@ impl FlattenedModule { context.flatten_code(&module.code); context.typecheck(); + context.generative_check(); context.find_unused_variables(&interface_ports); FlattenedModule { diff --git a/src/instantiation/mod.rs b/src/instantiation/mod.rs index d0b22d3..506b80f 100644 --- a/src/instantiation/mod.rs +++ b/src/instantiation/mod.rs @@ -436,7 +436,8 @@ impl<'fl, 'l> InstantiationContext<'fl, 'l> { Instantiation::IfStatement(stm) => { let then_range = UUIDRange(stm.then_start, stm.then_end_else_start); let else_range = UUIDRange(stm.then_end_else_start, stm.else_end); - if stm.is_compiletime { + let if_condition_wire = self.flattened.instantiations[stm.condition].extract_wire(); + if if_condition_wire.is_compiletime { let condition_val = self.get_generation_value(stm.condition)?; let run_range = if condition_val.extract_bool() { then_range