From 68c915b1dfd6e265d45f9dd8f74919fca4c3a02f Mon Sep 17 00:00:00 2001 From: Lennart Van Hirtum Date: Wed, 3 Jan 2024 23:24:55 +0100 Subject: [PATCH] Refactor of how interface ports are handled --- src/ast.rs | 4 +- src/codegen_fallback.rs | 12 ++-- src/flattening.rs | 144 +++++++++++++++++++-------------------- src/instantiation/mod.rs | 45 +++++------- src/value.rs | 4 ++ 5 files changed, 102 insertions(+), 107 deletions(-) diff --git a/src/ast.rs b/src/ast.rs index 4ae0186..fd9ffc4 100644 --- a/src/ast.rs +++ b/src/ast.rs @@ -155,8 +155,8 @@ pub struct Module { impl Module { pub fn print_flattened_module(&self, linker : &Linker) { println!("Interface:"); - for port in &self.interface.interface_wires { - let port_direction = if port.is_input {"input"} else {"output"}; + for (port_idx, port) in self.interface.interface_wires.iter().enumerate() { + let port_direction = if port_idx < self.interface.outputs_start {"input"} else {"output"}; let port_type = port.typ.to_string(linker); let port_name = &port.port_name; println!(" {port_direction} {port_type} {port_name} -> {:?}", port.wire_id); diff --git a/src/codegen_fallback.rs b/src/codegen_fallback.rs index 1182d1b..80aec94 100644 --- a/src/codegen_fallback.rs +++ b/src/codegen_fallback.rs @@ -1,6 +1,6 @@ use std::{iter::zip, ops::Deref}; -use crate::{ast::{Module, IdentifierType}, instantiation::{InstantiatedModule, RealWireDataSource, StateInitialValue, ConnectToPathElem}, linker::{NamedUUID, get_builtin_uuid}, typing::ConcreteType, tokenizer::get_token_type_name, flattening::{Instantiation, WireDeclaration}, value::Value}; +use crate::{ast::{Module, IdentifierType}, instantiation::{InstantiatedModule, RealWireDataSource, StateInitialValue, ConnectToPathElem}, linker::{NamedUUID, get_builtin_uuid}, typing::ConcreteType, tokenizer::get_token_type_name, flattening::Instantiation, value::Value}; fn get_type_name_size(id : NamedUUID) -> u64 { if id == get_builtin_uuid("int") { @@ -56,9 +56,9 @@ pub fn gen_verilog_code(md : &Module, instance : &InstantiatedModule) -> String assert!(!instance.errors.did_error(), "Module cannot have experienced an error"); let mut program_text : String = format!("module {}(\n\tinput clk, \n", md.link_info.name); let submodule_interface = instance.interface.as_ref().unwrap(); - for (port, real_port) in zip(&md.interface.interface_wires, submodule_interface) { - let wire = &instance.wires[real_port.id]; - program_text.push_str(if port.is_input {"\tinput"} else {"\toutput /*mux_wire*/ reg"}); + for (port_idx, (port, real_port)) in zip(md.interface.interface_wires.iter(), submodule_interface).enumerate() { + let wire = &instance.wires[*real_port]; + program_text.push_str(if port_idx < md.interface.outputs_start {"\tinput"} else {"\toutput /*mux_wire*/ reg"}); program_text.push_str(&typ_to_verilog_array(&wire.typ)); program_text.push(' '); program_text.push_str(&wire.name); @@ -117,7 +117,7 @@ pub fn gen_verilog_code(md : &Module, instance : &InstantiatedModule) -> String let Some(sm_interface) = &sm.instance.interface else {unreachable!()}; // Having an invalid interface in a submodule is an error! This should have been caught before! for (port, wire) in zip(sm_interface, &sm.wires) { program_text.push_str(",\n."); - program_text.push_str(&sm.instance.wires[port.id].name); + program_text.push_str(&sm.instance.wires[*port].name); program_text.push('('); program_text.push_str(&instance.wires[*wire].name); program_text.push_str(")"); @@ -132,7 +132,7 @@ pub fn gen_verilog_code(md : &Module, instance : &InstantiatedModule) -> String let output_name = w.name.deref(); match is_state { StateInitialValue::Combinatorial => { - program_text.push_str(&format!("/*always_comb*/ always @(*) begin\n\t{output_name} <= 1'bX; // Not defined when not valid\n")); + program_text.push_str(&format!("/*always_comb*/ always @(*) begin\n\t{output_name} <= 1'bX; // Combinatorial wires are not defined when not valid\n")); } StateInitialValue::State{initial_value : _} => { program_text.push_str(&format!("/*always_ff*/ always @(posedge clk) begin\n")); diff --git a/src/flattening.rs b/src/flattening.rs index e6d0af1..dc4871c 100644 --- a/src/flattening.rs +++ b/src/flattening.rs @@ -45,22 +45,27 @@ pub struct Connection { pub condition : Option } -#[derive(Debug,Clone,Copy)] -pub struct InterfacePort { - pub is_input : bool, - pub id : FlatID -} - #[derive(Debug)] pub enum WireSource { WireRead{from_wire : FlatID}, // Used to add a span to the reference of a wire. - //SubModuleOutput{submodule : FlatID, port_idx : usize}, UnaryOp{op : Operator, right : FlatID}, BinaryOp{op : Operator, left : FlatID, right : FlatID}, ArrayAccess{arr : FlatID, arr_idx : FlatID}, Constant{value : Value}, } +impl WireSource { + pub fn for_each_input_wire(&self, func : &mut F) { + match self { + &WireSource::WireRead { from_wire } => {func(from_wire)} + &WireSource::UnaryOp { op:_, right } => {func(right)} + &WireSource::BinaryOp { op:_, left, right } => {func(left); func(right)} + &WireSource::ArrayAccess { arr, arr_idx } => {func(arr); func(arr_idx)} + WireSource::Constant { value:_ } => {} + } + } +} + #[derive(Debug)] pub struct WireInstance { pub typ : Type, @@ -85,9 +90,18 @@ impl WireDeclaration { } } +#[derive(Debug)] +pub struct SubModuleInstance { + pub module_uuid : NamedUUID, + pub name : Box, + pub typ_span : Span, + pub outputs_start : usize, + pub local_wires : Box<[FlatID]> +} + #[derive(Debug)] pub enum Instantiation { - SubModule{module_uuid : NamedUUID, name : Box, typ_span : Span, interface_wires : Vec}, + SubModule(SubModuleInstance), WireDeclaration(WireDeclaration), Wire(WireInstance), Connection(Connection), @@ -174,26 +188,20 @@ impl<'l, 'm, 'fl> FlatteningContext<'l, 'm, 'fl> { Some(()) } fn alloc_module_interface(&self, name : Box, module : &Module, module_uuid : NamedUUID, typ_span : Span) -> Instantiation { - let interface_wires = module.interface.interface_wires.iter().enumerate().map(|(port_idx, port)| { - let identifier_type = if port.is_input { - IdentifierType::Input - } else { - IdentifierType::Output - }; - let id = self.instantiations.alloc(Instantiation::WireDeclaration(WireDeclaration{ + let local_wires : Vec = module.interface.interface_wires.iter().enumerate().map(|(port_idx, port)| { + self.instantiations.alloc(Instantiation::WireDeclaration(WireDeclaration{ typ: port.typ.clone(), typ_span, - read_only : !port.is_input, - identifier_type, - name : format!("{}_{}", &module.link_info.name, &port.port_name).into_boxed_str(), + read_only : port_idx >= module.interface.outputs_start, + identifier_type : IdentifierType::Local, + name : format!("{}_{}", &name, &port.port_name).into_boxed_str(), name_token : None - })); - InterfacePort{is_input : port.is_input, id} + })) }).collect(); - Instantiation::SubModule{name, module_uuid, typ_span, interface_wires} + Instantiation::SubModule(SubModuleInstance{name, module_uuid, typ_span, outputs_start : module.interface.outputs_start, local_wires : local_wires.into_boxed_slice()}) } - fn desugar_func_call(&self, func_and_args : &[SpanExpression], closing_bracket_pos : usize, condition : Option) -> Option<(&Module, &[InterfacePort])> { + fn desugar_func_call(&self, func_and_args : &[SpanExpression], closing_bracket_pos : usize, condition : Option) -> Option<(&Module, &[FlatID])> { let (name_expr, name_expr_span) = &func_and_args[0]; // Function name is always there let func_instantiation_id = match name_expr { Expression::Named(LocalOrGlobal::Local(l)) => { @@ -212,9 +220,9 @@ impl<'l, 'm, 'fl> FlatteningContext<'l, 'm, 'fl> { } }; let func_instantiation = &self.instantiations[func_instantiation_id]; - let Instantiation::SubModule{module_uuid, name : _, typ_span : _, interface_wires} = func_instantiation else {unreachable!("It should be proven {func_instantiation:?} was a Module!");}; + let Instantiation::SubModule(SubModuleInstance{module_uuid, name : _, typ_span : _, outputs_start:_, local_wires}) = func_instantiation else {unreachable!("It should be proven {func_instantiation:?} was a Module!");}; let Named::Module(md) = &self.linker.links.globals[*module_uuid] else {unreachable!("UUID Should be a module!");}; - let (inputs, output_range) = md.interface.get_function_sugar_inputs_outputs(); + let (inputs, output_range) = md.interface.func_call_syntax_interface(); let mut args = &func_and_args[1..]; @@ -239,12 +247,12 @@ impl<'l, 'm, 'fl> FlatteningContext<'l, 'm, 'fl> { if self.typecheck(arg_read_side, &md.interface.interface_wires[field].typ, "submodule output") == None { continue; } - let func_input_port = &interface_wires[field]; - self.create_connection(Connection { num_regs: 0, from: arg_read_side, to: ConnectionWrite::simple(func_input_port.id, *name_expr_span), condition }); + let func_input_port = &local_wires[field]; + self.create_connection(Connection { num_regs: 0, from: arg_read_side, to: ConnectionWrite::simple(*func_input_port, *name_expr_span), condition }); } } - Some((md, &interface_wires[output_range])) + Some((md, &local_wires[output_range])) } fn flatten_single_expr(&self, (expr, expr_span) : &SpanExpression, condition : Option) -> Option { let span = *expr_span; // for more compact constructors @@ -302,7 +310,7 @@ impl<'l, 'm, 'fl> FlatteningContext<'l, 'm, 'fl> { return None; } - outputs[0].id + outputs[0] } }; Some(single_connection_side) @@ -430,8 +438,8 @@ impl<'l, 'm, 'fl> FlatteningContext<'l, 'm, 'fl> { let Some(write_side) = self.flatten_assignable_expr(&to_i.expr, condition) else {return;}; // temporary - let module_port_wire_decl = self.instantiations[field.id].extract_wire_declaration(); - let module_port_proxy = self.instantiations.alloc(Instantiation::Wire(WireInstance{typ : module_port_wire_decl.typ.clone(), is_compiletime : module_port_wire_decl.identifier_type == IdentifierType::Generative, span : *func_span, inst : WireSource::WireRead { from_wire: field.id }})); + let module_port_wire_decl = self.instantiations[*field].extract_wire_declaration(); + let module_port_proxy = self.instantiations.alloc(Instantiation::Wire(WireInstance{typ : module_port_wire_decl.typ.clone(), is_compiletime : module_port_wire_decl.identifier_type == IdentifierType::Generative, span : *func_span, inst : WireSource::WireRead { from_wire: *field }})); self.create_connection(Connection{num_regs : to_i.num_regs, from: module_port_proxy, to: write_side, condition}); } }, @@ -457,7 +465,6 @@ impl<'l, 'm, 'fl> FlatteningContext<'l, 'm, 'fl> { #[derive(Debug)] pub struct FlattenedInterfacePort { pub wire_id : FlatID, - pub is_input : bool, pub typ : Type, pub port_name : Box, pub span : Span @@ -465,34 +472,23 @@ pub struct FlattenedInterfacePort { #[derive(Debug, Default)] pub struct FlattenedInterface { - pub interface_wires : Vec, // Indexed by FieldID + pub interface_wires : Box<[FlattenedInterfacePort]>, // Ordered such that all inputs come first, then all outputs + pub outputs_start : usize } impl FlattenedInterface { pub fn new() -> Self { - FlattenedInterface { interface_wires: Vec::new() } + FlattenedInterface { interface_wires: Box::new([]), outputs_start : 0 } } - pub fn get_function_sugar_inputs_outputs(&self) -> (Range, Range) { - let mut last_output = self.interface_wires.len() - 1; - - while last_output > 0 { - last_output -= 1; - if self.interface_wires[last_output].is_input { - last_output += 1; - break; - } - } - - let mut last_input = last_output - 1; - while last_input > 0 { - last_input -= 1; - if !self.interface_wires[last_input].is_input { - last_input += 1; - break; - } - } - - (last_input..last_output, last_output..self.interface_wires.len()) + // Todo, just treat all inputs and outputs as function call interface + pub fn func_call_syntax_interface(&self) -> (Range, Range) { + (0..self.outputs_start, self.outputs_start..self.interface_wires.len()) + } + pub fn inputs(&self) -> &[FlattenedInterfacePort] { + &self.interface_wires[..self.outputs_start] + } + pub fn outputs(&self) -> &[FlattenedInterfacePort] { + &self.interface_wires[self.outputs_start..] } } @@ -515,8 +511,6 @@ impl FlattenedModule { Must be further processed by flatten, but this requires all modules to have been Initial Flattened for dependency resolution */ pub fn initialize_interfaces(linker : &Linker, module : &Module) -> (FlattenedInterface, FlattenedModule, FlatAlloc, DeclIDMarker>) { - let mut interface = FlattenedInterface::new(); - let flat_mod = FlattenedModule { instantiations: ListAllocator::new(), errors: ErrorCollector::new(module.link_info.file) @@ -530,6 +524,8 @@ impl FlattenedModule { module, }; + let mut inputs = Vec::new(); + let mut outputs = Vec::new(); for (decl_id, decl) in &module.declarations { let is_input = match decl.identifier_type { IdentifierType::Input => true, @@ -546,11 +542,22 @@ impl FlattenedModule { name : decl.name.clone(), name_token : Some(decl.name_token) })); + + let port = FlattenedInterfacePort { wire_id, typ, port_name: decl.name.clone(), span: decl.span }; + if is_input { + inputs.push(port); + } else { + outputs.push(port); + } - interface.interface_wires.push(FlattenedInterfacePort { wire_id, is_input, typ, port_name: decl.name.clone(), span: decl.span }); context.decl_to_flat_map[decl_id] = Some(wire_id); } + let outputs_start = inputs.len(); + inputs.reserve(outputs.len()); + inputs.append(&mut outputs); + let interface = FlattenedInterface{interface_wires: inputs.into_boxed_slice(), outputs_start}; + let decl_to_flat_map = context.decl_to_flat_map; (interface, flat_mod, decl_to_flat_map) } @@ -589,11 +596,9 @@ impl FlattenedModule { let mut wire_to_explore_queue : Vec = Vec::new(); - for port in &md.interface.interface_wires { - if !port.is_input { - is_instance_used_map[port.wire_id] = true; - wire_to_explore_queue.push(port.wire_id); - } + for port in md.interface.outputs() { + is_instance_used_map[port.wire_id] = true; + wire_to_explore_queue.push(port.wire_id); } println!("Pre Explore"); @@ -611,17 +616,12 @@ impl FlattenedModule { match &self.instantiations[item] { Instantiation::WireDeclaration(_) => {} Instantiation::Wire(wire) => { - match &wire.inst { - WireSource::WireRead{from_wire} => { - func(*from_wire); - } - _other => {} - } + wire.inst.for_each_input_wire(&mut func); } - Instantiation::SubModule{module_uuid : _, name : _, typ_span : _, interface_wires} => { - for port in interface_wires { - if port.is_input { - func(port.id); + Instantiation::SubModule(submodule) => { + for (port_id, port) in submodule.local_wires.iter().enumerate() { + if port_id < submodule.outputs_start { + func(*port); } } } diff --git a/src/instantiation/mod.rs b/src/instantiation/mod.rs index 5343d53..a0f0c18 100644 --- a/src/instantiation/mod.rs +++ b/src/instantiation/mod.rs @@ -2,7 +2,7 @@ use std::{rc::Rc, ops::Deref, cell::RefCell}; use num::BigInt; -use crate::{arena_alloc::{UUID, UUIDMarker, FlatAlloc}, ast::{Operator, Module, IdentifierType, Span}, typing::{ConcreteType, Type}, flattening::{FlatID, Instantiation, FlatIDMarker, ConnectionWritePathElement, WireSource, WireInstance, Connection, ConnectionWritePathElementComputed, WireDeclaration}, errors::ErrorCollector, linker::{Linker, get_builtin_uuid}, value::{Value, compute_unary_op, compute_binary_op}}; +use crate::{arena_alloc::{UUID, UUIDMarker, FlatAlloc}, ast::{Operator, Module, IdentifierType, Span}, typing::{ConcreteType, Type}, flattening::{FlatID, Instantiation, FlatIDMarker, ConnectionWritePathElement, WireSource, WireInstance, Connection, ConnectionWritePathElementComputed, WireDeclaration, SubModuleInstance}, errors::ErrorCollector, linker::{Linker, get_builtin_uuid}, value::{Value, compute_unary_op, compute_binary_op}}; pub mod latency; @@ -92,12 +92,6 @@ pub struct RealWire { pub name : Box } -#[derive(Debug,Clone,Copy)] -pub struct InstantiatedInterfacePort { - pub id : WireID, - pub is_input : bool -} - #[derive(Debug)] pub struct SubModule { pub original_flat : FlatID, @@ -109,7 +103,7 @@ pub struct SubModule { #[derive(Debug)] pub struct InstantiatedModule { pub name : Box, // Unique name involving all template arguments - pub interface : Option>, // Interface is only valid if all wires of the interface were valid + pub interface : Option>, // Interface is only valid if all wires of the interface were valid pub wires : FlatAlloc, pub submodules : FlatAlloc, pub errors : ErrorCollector, @@ -283,7 +277,7 @@ impl<'fl, 'l> InstantiationContext<'fl, 'l> { None } } - fn compute_compile_time(&self, wire_inst : &WireSource, typ : &ConcreteType) -> Option { + fn compute_compile_time(&self, wire_inst : &WireSource) -> Option { Some(match wire_inst { &WireSource::WireRead{from_wire} => { self.get_generation_value(from_wire)?.clone() @@ -330,19 +324,19 @@ impl<'fl, 'l> InstantiationContext<'fl, 'l> { } } fn wire_to_real_wire(&mut self, w: &WireInstance, typ : ConcreteType, original_wire : FlatID) -> Option { - let (name, source) = match &w.inst { + let source = match &w.inst { &WireSource::WireRead{from_wire} => { - let Instantiation::WireDeclaration(WireDeclaration{typ:_, typ_span:_, read_only:_, identifier_type:_, name:_, name_token:_}) = &self.module.flattened.instantiations[from_wire] else {unreachable!("WireReads must point to a NamedWire!")}; + /*Assert*/ self.module.flattened.instantiations[from_wire].extract_wire_declaration(); // WireReads must point to a NamedWire! return Some(self.generation_state[from_wire].extract_wire()) } &WireSource::UnaryOp{op, right} => { let right = self.get_wire_or_constant_as_wire(right)?; - (None, RealWireDataSource::UnaryOp{op: op, right}) + RealWireDataSource::UnaryOp{op: op, right} } &WireSource::BinaryOp{op, left, right} => { let left = self.get_wire_or_constant_as_wire(left)?; let right = self.get_wire_or_constant_as_wire(right)?; - (None, RealWireDataSource::BinaryOp{op: op, left, right}) + RealWireDataSource::BinaryOp{op: op, left, right} } &WireSource::ArrayAccess{arr, arr_idx} => { let arr = self.get_wire_or_constant_as_wire(arr)?; @@ -350,12 +344,12 @@ impl<'fl, 'l> InstantiationContext<'fl, 'l> { SubModuleOrWire::SubModule(_) => unreachable!(), SubModuleOrWire::Unnasigned => unreachable!(), SubModuleOrWire::Wire(w) => { - (None, RealWireDataSource::ArrayAccess{arr, arr_idx: *w}) + RealWireDataSource::ArrayAccess{arr, arr_idx: *w} } SubModuleOrWire::CompileTimeValue(v) => { let arr_idx_wire = self.module.flattened.instantiations[arr_idx].extract_wire(); let arr_idx = self.extract_integer_from_value(v, arr_idx_wire.span)?; - (None, RealWireDataSource::ConstArrayAccess{arr, arr_idx}) + RealWireDataSource::ConstArrayAccess{arr, arr_idx} } } } @@ -363,16 +357,16 @@ impl<'fl, 'l> InstantiationContext<'fl, 'l> { unreachable!("Constant cannot be non-compile-time"); } }; - let name = name.unwrap_or_else(|| self.get_unique_name()); + let name = self.get_unique_name(); Some(self.wires.alloc(RealWire{ name, typ, original_wire, source})) } fn instantiate_flattened_module(&mut self) { for (original_wire, inst) in &self.module.flattened.instantiations { let instance_to_add : SubModuleOrWire = match inst { - Instantiation::SubModule{module_uuid, name, typ_span, interface_wires} => { + Instantiation::SubModule(SubModuleInstance{module_uuid, name, typ_span, outputs_start, local_wires}) => { let instance = self.linker.instantiate(*module_uuid); - let interface_real_wires = interface_wires.iter().map(|port| { - self.generation_state[port.id].extract_wire() + let interface_real_wires = local_wires.iter().map(|port| { + self.generation_state[*port].extract_wire() }).collect(); SubModuleOrWire::SubModule(self.submodules.alloc(SubModule { original_flat: original_wire, instance, wires : interface_real_wires, name : name.clone()})) } @@ -404,7 +398,7 @@ impl<'fl, 'l> InstantiationContext<'fl, 'l> { return; // Exit early, do not produce invalid wires in InstantiatedModule }; if w.is_compiletime { - let Some(value_computed) = self.compute_compile_time(&w.inst, &typ) else {return}; + let Some(value_computed) = self.compute_compile_time(&w.inst) else {return}; assert!(value_computed.is_of_type(&typ)); SubModuleOrWire::CompileTimeValue(value_computed) } else { @@ -423,19 +417,16 @@ impl<'fl, 'l> InstantiationContext<'fl, 'l> { } // Returns a proper interface if all ports involved did not produce an error. If a port did produce an error then returns None. - fn make_interface(&self) -> Option> { + fn make_interface(&self) -> Option> { let mut result = Vec::new(); result.reserve(self.module.interface.interface_wires.len()); - for port in &self.module.interface.interface_wires { + for port in self.module.interface.interface_wires.iter() { match &self.generation_state[port.wire_id] { SubModuleOrWire::Wire(w) => { - result.push(InstantiatedInterfacePort { - id: *w, - is_input: port.is_input - }); + result.push(*w) } SubModuleOrWire::Unnasigned => { - return None + return None // Error building interface } _other => unreachable!() // interface wires cannot point to anything else } diff --git a/src/value.rs b/src/value.rs index 3ed2572..4017d3b 100644 --- a/src/value.rs +++ b/src/value.rs @@ -52,10 +52,14 @@ impl Value { _other => false } } + + #[track_caller] pub fn extract_integer(&self) -> &BigInt { let Self::Integer(i) = self else {panic!("{:?} is not an integer!", self)}; i } + + #[track_caller] pub fn extract_bool(&self) -> bool { let Self::Bool(b) = self else {panic!("{:?} is not a bool!", self)}; *b