Skip to content

Commit

Permalink
Move all Compiletime code checking to its own function
Browse files Browse the repository at this point in the history
  • Loading branch information
VonTum committed Jan 20, 2024
1 parent 2d71d73 commit 0294d9b
Show file tree
Hide file tree
Showing 3 changed files with 124 additions and 58 deletions.
11 changes: 11 additions & 0 deletions multiply_add.sus
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,17 @@ module blur2 :
blurred = data + prev;
}
prev = data;

gen int a;

gen bool b = true;
bool bb = false;

if bb {
a = 5;
} else {
a = 3;
}
}


Expand Down
168 changes: 111 additions & 57 deletions src/flattening.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,13 @@ impl WireSource {
&WireSource::BinaryOp { op:_, left, right } => {func(left); func(right)}
&WireSource::ArrayAccess { arr, arr_idx } => {func(arr); func(arr_idx)}
WireSource::Constant(_) => {}
&WireSource::NamedConstant(_) => {}
WireSource::NamedConstant(_) => {}
}
}
}

const IS_GEN_UNINIT : bool = false;

#[derive(Debug)]
pub struct WireInstance {
pub typ : Type,
Expand Down Expand Up @@ -100,7 +102,6 @@ pub struct SubModuleInstance {

#[derive(Debug)]
pub struct IfStatement {
pub is_compiletime : bool,
pub condition : FlatID,
pub then_start : FlatID,
pub then_end_else_start : FlatID,
Expand Down Expand Up @@ -180,15 +181,6 @@ struct FlatteningContext<'inst, 'l, 'm> {
module : &'m Module,
}

fn must_be_compiletime_with_info<CtxFunc : FnOnce() -> Vec<ErrorInfo>>(wire : &WireInstance, context : &str, errors : &ErrorCollector, ctx_func : CtxFunc) {
if !wire.is_compiletime {
errors.error_with_info(wire.span, format!("{context} must be compile time"), ctx_func());
}
}
fn must_be_compiletime(wire : &WireInstance, context : &str, errors : &ErrorCollector) {
must_be_compiletime_with_info(wire, context, errors, || Vec::new());
}

impl<'inst, 'l, 'm> FlatteningContext<'inst, 'l, 'm> {
fn map_to_type(&mut self, type_expr : &SpanTypeExpression) -> Type {
match &type_expr.0 {
Expand All @@ -203,7 +195,6 @@ impl<'inst, 'l, 'm> FlatteningContext<'inst, 'l, 'm> {
let (array_type_expr, array_size_expr) = b.deref();
let array_element_type = self.map_to_type(&array_type_expr);
if let Some(array_size_wire_id) = self.flatten_expr(array_size_expr) {
must_be_compiletime(self.instantiations[array_size_wire_id].extract_wire(), "Array size", &self.errors);
Type::Array(Box::new((array_element_type, array_size_wire_id)))
} else {
Type::Error
Expand Down Expand Up @@ -238,7 +229,7 @@ impl<'inst, 'l, 'm> FlatteningContext<'inst, 'l, 'm> {

let latency_specifier = if let Some(lat_expr) = &decl.latency_expr {
if let Some(latency_spec) = self.flatten_expr(lat_expr) {
must_be_compiletime(self.instantiations[latency_spec].extract_wire(), "Latency specifier", &self.errors);
self.must_be_compiletime(self.instantiations[latency_spec].extract_wire(), "Latency specifier");
Some(latency_spec)
} else {
None
Expand Down Expand Up @@ -346,41 +337,34 @@ impl<'inst, 'l, 'm> FlatteningContext<'inst, 'l, 'm> {
Some((md, submodule_local_wires))
}
fn flatten_expr(&mut self, (expr, expr_span) : &SpanExpression) -> Option<FlatID> {
let (is_compiletime, source) = match expr {
let source = match expr {
Expression::Named(LocalOrGlobal::Local(l)) => {
let from_wire = self.decl_to_flat_map[*l].unwrap();
let decl = self.instantiations[from_wire].extract_wire_declaration();
(decl.identifier_type == IdentifierType::Generative, WireSource::WireRead(from_wire))
WireSource::WireRead(from_wire)
}
Expression::Named(LocalOrGlobal::Global(ref_span)) => {
let cst = self.linker.resolve_constant(*ref_span, &self.errors)?;
(true, WireSource::NamedConstant(cst))
WireSource::NamedConstant(cst)
}
Expression::Constant(cst) => {
(true, WireSource::Constant(cst.clone()))
WireSource::Constant(cst.clone())
}
Expression::UnaryOp(op_box) => {
let (op, _op_pos, operate_on) = op_box.deref();
let right = self.flatten_expr(operate_on)?;
let right_wire = self.instantiations[right].extract_wire();
(right_wire.is_compiletime, WireSource::UnaryOp{op : *op, right})
WireSource::UnaryOp{op : *op, right}
}
Expression::BinOp(binop_box) => {
let (left_expr, op, _op_pos, right_expr) = binop_box.deref();
let left = self.flatten_expr(left_expr)?;
let right = self.flatten_expr(right_expr)?;
let left_wire = self.instantiations[left].extract_wire();
let right_wire = self.instantiations[right].extract_wire();
let is_compiletime = left_wire.is_compiletime && right_wire.is_compiletime;
(is_compiletime, WireSource::BinaryOp{op : *op, left, right})
WireSource::BinaryOp{op : *op, left, right}
}
Expression::Array(arr_box) => {
let (left, right, _bracket_span) = arr_box.deref();
let arr = self.flatten_expr(left)?;
let arr_idx = self.flatten_expr(right)?;
let arr_wire = self.instantiations[arr].extract_wire();
let arr_idx_wire = self.instantiations[arr_idx].extract_wire();
(arr_wire.is_compiletime && arr_idx_wire.is_compiletime, WireSource::ArrayAccess{arr, arr_idx})
WireSource::ArrayAccess{arr, arr_idx}
}
Expression::FuncCall(func_and_args) => {
let (md, interface_wires) = self.desugar_func_call(func_and_args, expr_span.1)?;
Expand All @@ -397,7 +381,7 @@ impl<'inst, 'l, 'm> FlatteningContext<'inst, 'l, 'm> {
}
};

let wire_instance = WireInstance{typ : Type::Unknown, span : *expr_span, is_compiletime, source, is_remote_declaration : self.is_remote_declaration,};
let wire_instance = WireInstance{typ : Type::Unknown, is_compiletime : IS_GEN_UNINIT, span : *expr_span, source, is_remote_declaration : self.is_remote_declaration,};
Some(self.instantiations.alloc(Instantiation::Wire(wire_instance)))
}
fn flatten_assignable_expr(&mut self, (expr, span) : &SpanAssignableExpression) -> Option<ConnectionWrite> {
Expand Down Expand Up @@ -457,7 +441,7 @@ impl<'inst, 'l, 'm> FlatteningContext<'inst, 'l, 'm> {

// temporary
let module_port_wire_decl = self.instantiations[*field].extract_wire_declaration();
let module_port_proxy = self.instantiations.alloc(Instantiation::Wire(WireInstance{typ : module_port_wire_decl.typ.clone(), is_compiletime : module_port_wire_decl.identifier_type == IdentifierType::Generative, span : *func_span, is_remote_declaration : self.is_remote_declaration, source : WireSource::WireRead(*field)}));
let module_port_proxy = self.instantiations.alloc(Instantiation::Wire(WireInstance{typ : module_port_wire_decl.typ.clone(), is_compiletime : IS_GEN_UNINIT, span : *func_span, is_remote_declaration : self.is_remote_declaration, source : WireSource::WireRead(*field)}));
self.instantiations.alloc(Instantiation::Connection(Connection{num_regs : to_i.num_regs, from: module_port_proxy, to: write_side}));
}
},
Expand All @@ -477,9 +461,7 @@ impl<'inst, 'l, 'm> FlatteningContext<'inst, 'l, 'm> {
Statement::If{condition : condition_expr, then, els} => {
let Some(condition) = self.flatten_expr(condition_expr) else {continue};

let is_compiletime = self.instantiations[condition].extract_wire().is_compiletime;

let if_id = self.instantiations.alloc(Instantiation::IfStatement(IfStatement{is_compiletime, condition, then_start : UUID::PLACEHOLDER, then_end_else_start : UUID::PLACEHOLDER, else_end : UUID::PLACEHOLDER}));
let if_id = self.instantiations.alloc(Instantiation::IfStatement(IfStatement{condition, then_start : UUID::PLACEHOLDER, then_end_else_start : UUID::PLACEHOLDER, else_end : UUID::PLACEHOLDER}));
let then_start = self.instantiations.get_next_alloc_id();

self.flatten_code(then);
Expand Down Expand Up @@ -522,29 +504,13 @@ impl<'inst, 'l, 'm> FlatteningContext<'inst, 'l, 'm> {
}
}


/*
==== Type Checking ====
==== Typechecking ====
*/
fn typecheck_wire_is_of_type(&self, wire : &WireInstance, expected : &Type, context : &str) {
typecheck(&wire.typ, wire.span, expected, context, self.type_list_for_naming, &self.errors);
}

// Typechecks things like that arrays have compiletime integer sizes
fn typecheck_type_generic_parameters(&self, typ : &Type) {
match typ {
Type::Error => {}
Type::Unknown => unreachable!(), // Should only run this on types that have been properly resolved!
Type::Named{id:_, span:_} => {}
Type::Array(arr_box) => {
let (arr_typ, size_val) = arr_box.deref();
self.typecheck_type_generic_parameters(arr_typ);
let size_val_wire = &self.instantiations[*size_val].extract_wire();
self.typecheck_wire_is_of_type(size_val_wire, &INT_TYPE, "Array size");
}
}
}

fn typecheck(&mut self) {
let look_at_queue : Vec<FlatID> = self.instantiations.iter().map(|(id,_)| id).collect();

Expand All @@ -556,7 +522,10 @@ impl<'inst, 'l, 'm> FlatteningContext<'inst, 'l, 'm> {
let latency_spec_wire = &self.instantiations[latency_spec].extract_wire();
self.typecheck_wire_is_of_type(latency_spec_wire, &INT_TYPE, "latency specifier");
}
self.typecheck_type_generic_parameters(&decl.typ);

decl.typ.for_each_generative_input(&mut |param_id| {
self.typecheck_wire_is_of_type(self.instantiations[param_id].extract_wire(), &INT_TYPE, "Array size");
});
}
Instantiation::IfStatement(stm) => {
let wire = &self.instantiations[stm.condition].extract_wire();
Expand Down Expand Up @@ -624,13 +593,8 @@ impl<'inst, 'l, 'm> FlatteningContext<'inst, 'l, 'm> {
}
}

// Typecheck compile-time ness
let from_wire = self.instantiations[conn.from].extract_wire();
if conn_root.identifier_type == IdentifierType::Generative {
must_be_compiletime_with_info(from_wire, "Assignments to generative variables", &self.errors, || vec![error_info(conn_root.get_full_decl_span(), self.errors.file, "Declared here")]);
}

// Typecheck the value with target type
let from_wire = self.instantiations[conn.from].extract_wire();
if let Some(target_type) = write_to_type {
self.typecheck_wire_is_of_type(from_wire, &target_type, "connection");
}
Expand All @@ -648,7 +612,96 @@ impl<'inst, 'l, 'm> FlatteningContext<'inst, 'l, 'm> {
}
}

/* Additional Warnings */
/*
==== Generative Code Checking ====
*/
fn must_be_compiletime_with_info<CtxFunc : FnOnce() -> Vec<ErrorInfo>>(&self, wire : &WireInstance, context : &str, ctx_func : CtxFunc) {
if !wire.is_compiletime {
self.errors.error_with_info(wire.span, format!("{context} must be compile time"), ctx_func());
}
}
fn must_be_compiletime(&self, wire : &WireInstance, context : &str) {
self.must_be_compiletime_with_info(wire, context, || Vec::new());
}

fn generative_check(&mut self) {
let mut runtime_if_stack : Vec<(FlatID, Span)> = Vec::new();

let mut declaration_depths : FlatAlloc<Option<usize>, FlatIDMarker> = self.instantiations.iter().map(|_| None).collect();

for inst_id in self.instantiations.id_range() {
while let Some((end_id, span)) = runtime_if_stack.pop() {
if end_id != inst_id {
runtime_if_stack.push((end_id, span));
break;
}
}
match &self.instantiations[inst_id] {
Instantiation::SubModule(_) => {}
Instantiation::WireDeclaration(decl) => {
if decl.identifier_type == IdentifierType::Generative {
assert!(declaration_depths[inst_id].is_none());
declaration_depths[inst_id] = Some(runtime_if_stack.len())
}

decl.typ.for_each_generative_input(&mut |param_id| {
self.must_be_compiletime(self.instantiations[param_id].extract_wire(), "Array size");
});
}

Instantiation::Wire(wire) => {
let mut is_generative = true;
if let WireSource::WireRead(from) = &wire.source {
let decl = self.instantiations[*from].extract_wire_declaration();
if decl.identifier_type != IdentifierType::Generative {
is_generative = false;
}
} else {
wire.source.for_each_input_wire(&mut |source_id| {
let source_wire = self.instantiations[source_id].extract_wire();
if !source_wire.is_compiletime {
is_generative = false;
}
});
}
let Instantiation::Wire(wire) = &mut self.instantiations[inst_id] else {unreachable!()};
wire.is_compiletime = is_generative;
}
Instantiation::Connection(conn) => {
let conn_root_decl = self.instantiations[conn.to.root].extract_wire_declaration();

if conn_root_decl.identifier_type == IdentifierType::Generative {
let from_wire = self.instantiations[conn.from].extract_wire();
// Check that whatever's written to this declaration is also generative
self.must_be_compiletime_with_info(from_wire, "Assignments to generative variables", || vec![error_info(conn_root_decl.get_full_decl_span(), self.errors.file, "Declared here")]);

// Check that this declaration isn't used in a non-compiletime if
let declared_at_depth = declaration_depths[conn.to.root].unwrap();

if runtime_if_stack.len() > declared_at_depth {
let mut infos = Vec::new();
infos.push(error_info(conn_root_decl.get_full_decl_span(), self.errors.file, "Declared here"));
for (_, if_cond_span) in &runtime_if_stack[declared_at_depth..] {
infos.push(error_info(*if_cond_span, self.errors.file, "Runtime Condition here"));
}
self.errors.error_with_info(conn.to.span, "Cannot write to generative variables in runtime conditional block", infos);
}
}
}
Instantiation::IfStatement(if_stmt) => {
let condition_wire = self.instantiations[if_stmt.condition].extract_wire();
if !condition_wire.is_compiletime {
runtime_if_stack.push((if_stmt.else_end, condition_wire.span));
}
}
Instantiation::ForStatement(_) => {}
}
}
}

/*
==== Additional Warnings
*/
fn find_unused_variables(&self, interface : &InterfacePorts<FlatID>) {
// Setup Wire Fanouts List for faster processing
let mut gathered_connection_fanin : FlatAlloc<Vec<FlatID>, FlatIDMarker> = self.instantiations.iter().map(|_| Vec::new()).collect();
Expand Down Expand Up @@ -774,6 +827,7 @@ impl FlattenedModule {

context.flatten_code(&module.code);
context.typecheck();
context.generative_check();
context.find_unused_variables(&interface_ports);

FlattenedModule {
Expand Down
3 changes: 2 additions & 1 deletion src/instantiation/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -436,7 +436,8 @@ impl<'fl, 'l> InstantiationContext<'fl, 'l> {
Instantiation::IfStatement(stm) => {
let then_range = UUIDRange(stm.then_start, stm.then_end_else_start);
let else_range = UUIDRange(stm.then_end_else_start, stm.else_end);
if stm.is_compiletime {
let if_condition_wire = self.flattened.instantiations[stm.condition].extract_wire();
if if_condition_wire.is_compiletime {
let condition_val = self.get_generation_value(stm.condition)?;
let run_range = if condition_val.extract_bool() {
then_range
Expand Down

0 comments on commit 0294d9b

Please sign in to comment.