From 89d609da028c2ef7f186aed56d66b8e089f513ac Mon Sep 17 00:00:00 2001 From: Lennart Van Hirtum Date: Mon, 22 Jan 2024 23:32:23 +0100 Subject: [PATCH] First promising version of back-and-forth latency algorithm --- multiply_add.sus | 2 +- src/flattening.rs | 2 +- src/instantiation/latency.rs | 105 +----------------- src/instantiation/latency_algorithm.rs | 148 +++++++++++++++++++++++++ src/instantiation/mod.rs | 1 + 5 files changed, 157 insertions(+), 101 deletions(-) create mode 100644 src/instantiation/latency_algorithm.rs diff --git a/multiply_add.sus b/multiply_add.sus index fb30d42..ac75864 100644 --- a/multiply_add.sus +++ b/multiply_add.sus @@ -235,7 +235,7 @@ module blur2 : gen int a; gen bool b = true; - bool bb = false; + gen bool bb = false; if bb { a = 5; diff --git a/src/flattening.rs b/src/flattening.rs index 25963ad..391bdb5 100644 --- a/src/flattening.rs +++ b/src/flattening.rs @@ -700,7 +700,7 @@ impl<'inst, 'l, 'm> FlatteningContext<'inst, 'l, 'm> { } /* - ==== Additional Warnings + ==== Additional Warnings ==== */ fn find_unused_variables(&self, interface : &InterfacePorts) { // Setup Wire Fanouts List for faster processing diff --git a/src/instantiation/latency.rs b/src/instantiation/latency.rs index 9b5f555..936cc7c 100644 --- a/src/instantiation/latency.rs +++ b/src/instantiation/latency.rs @@ -1,50 +1,17 @@ -use std::{iter::zip, collections::VecDeque}; +use crate::arena_alloc::FlatAlloc; -use crate::{arena_alloc::FlatAlloc, errors::ErrorCollector}; +use super::{latency_algorithm::FanInOut, RealWire, SubModule, SubModuleIDMarker, WireIDMarker}; -use super::{WireID, WireIDMarker, RealWire, SubModule, SubModuleIDMarker}; - - - - -struct FanInOut { - other : WireID, - delta_latency : i64 -} - -/* - Algorithm: - Initialize all inputs at latency 0 - Perform full forward pass, making latencies the maximum of all incoming latencies - Then backward pass, moving nodes forward in latency as much as possible. 
- Only moving forward is possible, and only when not confliciting with a later node -*/ struct LatencyComputer { - fanins : FlatAlloc, WireIDMarker>, - fanouts : FlatAlloc, WireIDMarker> -} - -fn convert_fanin_to_fanout(fanins : &FlatAlloc, WireIDMarker>) -> FlatAlloc, WireIDMarker> { - let mut fanouts : FlatAlloc, WireIDMarker> = fanins.iter().map(|_| { - Vec::new() - }).collect(); - - for (id, fin) in fanins { - for f in fin { - fanouts[f.other].push(FanInOut { other: id, delta_latency: f.delta_latency }) - } - } - - fanouts + } - impl LatencyComputer { fn setup(wires : &FlatAlloc, submodules : &FlatAlloc) -> Self { // Wire to wire Fanin - let mut fanins : FlatAlloc, WireIDMarker> = wires.iter().map(|(id, wire)| { + let mut fanins : Vec> = wires.iter().map(|(id, wire)| { let mut fanin = Vec::new(); wire.source.iter_sources_with_min_latency(&mut |from, delta_latency| { - fanin.push(FanInOut{other : from, delta_latency}); + fanin.push(FanInOut{other : from.get_hidden_value(), delta_latency}); }); fanin }).collect(); @@ -68,66 +35,6 @@ impl LatencyComputer { }*/ // Process fanouts - let fanouts = convert_fanin_to_fanout(&fanins); - - Self {fanins, fanouts} - } - - fn compute_latencies_forward(&self) -> FlatAlloc { - let mut latencies : FlatAlloc = self.fanins.iter().map(|_| 0).collect(); - - let mut queue : VecDeque = VecDeque::new(); - queue.reserve(self.fanins.len()); - - let mut order : Vec = Vec::new(); - order.reserve(self.fanins.len()); - - for (id, v) in &self.fanouts { - if v.is_empty() { - queue.push_back(id); - latencies[id] = 1; // Initialize with 1 - } - } - - while let Some(s) = queue.pop_front() { - let mut all_explored = false; - for from in &self.fanins[s] { - - } - } - - latencies - } -} - -struct RuntimeData { - part_of_path : bool, - current_absolute_latency : i64 -} - -fn process_node_recursive(runtime_data : &mut FlatAlloc, fanouts : &FlatAlloc, WireIDMarker>, cur_node : WireID) { - runtime_data[cur_node].part_of_path = true; - for 
&FanInOut{other, delta_latency} in &fanouts[cur_node] { - let to_node_min_latency = runtime_data[cur_node].current_absolute_latency + delta_latency; - if to_node_min_latency > runtime_data[other].current_absolute_latency { - if runtime_data[other].part_of_path { - todo!("Cycles for positive net latency error!"); - } else { - runtime_data[other].current_absolute_latency = to_node_min_latency; - process_node_recursive(runtime_data, fanouts, other); - } - } + todo!() } - runtime_data[cur_node].part_of_path = false; -} - -fn find_all_cycles_starting_from(fanouts : &FlatAlloc, WireIDMarker>, starting_node : WireID) -> FlatAlloc { - let mut runtime_data : FlatAlloc = fanouts.iter().map(|_| RuntimeData{ - part_of_path: false, - current_absolute_latency : i64::MIN // Such that new nodes will always be overwritten - }).collect(); - - runtime_data[starting_node].current_absolute_latency = 0; - process_node_recursive(&mut runtime_data, fanouts, starting_node); - runtime_data } diff --git a/src/instantiation/latency_algorithm.rs b/src/instantiation/latency_algorithm.rs new file mode 100644 index 0000000..3ae8d22 --- /dev/null +++ b/src/instantiation/latency_algorithm.rs @@ -0,0 +1,148 @@ +use std::iter::zip; + + +pub struct FanInOut { + pub other : usize, + pub delta_latency : i64 +} +/* + Algorithm: + Initialize all inputs at latency 0 + Perform full forward pass, making latencies the maximum of all incoming latencies + Then backward pass, moving nodes forward in latency as much as possible. 
+ Only moving forward is possible, and only when not conflicting with a later node +*/ +fn count_latency_recursive(part_of_path : &mut [bool], absolute_latency : &mut [i64], fanouts : &Vec>, cur_node : usize) { + part_of_path[cur_node] = true; + for &FanInOut{other, delta_latency} in &fanouts[cur_node] { + let to_node_min_latency = absolute_latency[cur_node] + delta_latency; + if to_node_min_latency > absolute_latency[other] { + if part_of_path[other] { + todo!("Cycles for positive net latency error!"); + } else { + absolute_latency[other] = to_node_min_latency; + count_latency_recursive(part_of_path, absolute_latency, fanouts, other); + } + } + } + part_of_path[cur_node] = false; +} + +fn count_latency(part_of_path : &mut [bool], absolute_latency : &mut [i64], fanouts : &Vec>, start_node : usize, start_value : i64) -> Option<()> { + for p in part_of_path.iter() {assert!(!*p);} + + if absolute_latency[start_node] != i64::MIN { + if absolute_latency[start_node] == start_value { + Some(()) // Return with no error, latency is already set and is correct value + } else { + todo!("Report latency error"); + None // Latency error. One of the ends has a different new latency! 
+ } + } else { + absolute_latency[start_node] = start_value; + count_latency_recursive(part_of_path, absolute_latency, fanouts, start_node); + + for p in part_of_path.iter() {assert!(!*p);} + Some(()) + } +} + +fn solve_latencies(fanins : &Vec>) -> Option> { + let fanouts_holder = convert_fanin_to_fanout(fanins); + let fanouts = &fanouts_holder; + + let inputs : Vec = fanins.iter().enumerate().filter_map(|(idx, v)| v.is_empty().then_some(idx)).collect(); + let outputs : Vec = fanouts.iter().enumerate().filter_map(|(idx, v)| v.is_empty().then_some(idx)).collect(); + + let mut part_of_path : Box<[bool]> = fanouts.iter().map(|_| false).collect(); + let mut absolute_latencies_forward : Box<[i64]> = fanouts.iter().map(|_| i64::MIN).collect(); + let mut absolute_latencies_backward : Box<[i64]> = fanouts.iter().map(|_| i64::MIN).collect(); + + let Some(starting_node) = inputs.get(0) else {todo!("Output-only modules")}; + + absolute_latencies_backward[*starting_node] = 0; // Provide a seed to start the algorithm + let mut last_num_valid_inputs = 0; + loop { + let mut num_valid_inputs = 0; + // Copy over latencies from backward pass + for input in &inputs { + if absolute_latencies_backward[*input] != i64::MIN { // Once an extremity node has been assigned, its value can never change + count_latency(&mut part_of_path, &mut absolute_latencies_forward, fanouts, *input, -absolute_latencies_backward[*input])?; + num_valid_inputs += 1; + } + } + + // Check end conditions + if num_valid_inputs == inputs.len() { + break; // All inputs covered. Done! + } + if num_valid_inputs == last_num_valid_inputs { + // No change, we can't explore further, but haven't seen all inputs. 
+ todo!("Proper error for disjoint inputs and outputs"); + return None; + } + last_num_valid_inputs = num_valid_inputs; + + // Copy over latencies from forward pass + for output in &outputs { + if absolute_latencies_forward[*output] != i64::MIN { + count_latency(&mut part_of_path, &mut absolute_latencies_backward, fanins, *output, -absolute_latencies_forward[*output])?; + } + } + } + Some(absolute_latencies_forward) +} + +fn convert_fanin_to_fanout(fanins : &Vec>) -> Vec> { + let mut fanouts : Vec> = fanins.iter().map(|_| { + Vec::new() + }).collect(); + + for (id, fin) in fanins.iter().enumerate() { + for f in fin { + fanouts[f.other].push(FanInOut { other: id, delta_latency: f.delta_latency }) + } + } + + fanouts +} + +fn latencies_equal(a : &[i64], b : &[i64]) -> bool { + let diff = a[0] - b[0]; + + for (x, y) in zip(a.iter(), b.iter()) { + if *x - *y != diff { + return false; + } + } + return true; +} + +#[cfg(test)] +mod tests { + use super::*; + + fn mk_fan(other : usize, delta_latency : i64) -> FanInOut { + FanInOut{other, delta_latency} + } + + #[test] + fn check_correct_latency_basic() { + let graph = vec![ + /*0*/vec![], + /*1*/vec![mk_fan(0, 0)], + /*2*/vec![mk_fan(1, 1),mk_fan(5, 1)], + /*3*/vec![mk_fan(2, 0)], + /*4*/vec![], + /*5*/vec![mk_fan(4, 0),mk_fan(1, 1)], + /*6*/vec![mk_fan(5, 0)] + ]; + + let correct_latencies = vec![-1,-1,1,1,0,0,0]; + + let found_latencies = solve_latencies(&graph).unwrap(); + + assert!(latencies_equal(&found_latencies, &correct_latencies)); + } +} + diff --git a/src/instantiation/mod.rs b/src/instantiation/mod.rs index 506b80f..e4c2862 100644 --- a/src/instantiation/mod.rs +++ b/src/instantiation/mod.rs @@ -4,6 +4,7 @@ use num::BigInt; use crate::{arena_alloc::{UUID, UUIDMarker, FlatAlloc, UUIDRange}, ast::{Operator, IdentifierType, Span, InterfacePorts}, typing::{ConcreteType, Type, BOOL_CONCRETE_TYPE, INT_CONCRETE_TYPE}, flattening::{FlatID, Instantiation, FlatIDMarker, ConnectionWritePathElement, WireSource, 
WireInstance, Connection, ConnectionWritePathElementComputed, FlattenedModule, FlatIDRange}, errors::ErrorCollector, linker::{Linker, NamedConstant}, value::{Value, compute_unary_op, compute_binary_op}, tokenizer::kw}; +pub mod latency_algorithm; pub mod latency; #[derive(Debug,Clone,Copy,PartialEq,Eq,Hash)]