diff --git a/src/optimiser/taso.rs b/src/optimiser/taso.rs index 59843d96..2c7f4982 100644 --- a/src/optimiser/taso.rs +++ b/src/optimiser/taso.rs @@ -12,7 +12,6 @@ //! it gets too large. mod eq_circ_class; -mod hugr_hash_set; mod hugr_pchannel; mod hugr_pqueue; pub mod log; @@ -25,11 +24,11 @@ pub use eq_circ_class::{load_eccs_json_file, EqCircClass}; use std::num::NonZeroUsize; use std::time::{Duration, Instant}; +use fxhash::FxHashSet; use hugr::Hugr; use crate::circuit::CircuitHash; -use crate::optimiser::taso::hugr_hash_set::HugrHashSet; -use crate::optimiser::taso::hugr_pchannel::{HugrPriorityChannel, PriorityChannelLog}; +use crate::optimiser::taso::hugr_pchannel::HugrPriorityChannel; use crate::optimiser::taso::hugr_pqueue::{Entry, HugrPQ}; use crate::optimiser::taso::worker::TasoWorker; use crate::rewrite::strategy::RewriteStrategy; @@ -112,7 +111,7 @@ where logger.log_best(best_circ_cost); // Hash of seen circuits. Dot not store circuits as this map gets huge - let mut seen_hashes = HugrHashSet::singleton(circ.circuit_hash(), best_circ_cost); + let mut seen_hashes: FxHashSet<_> = FromIterator::from_iter([(circ.circuit_hash())]); // The priority queue of circuits to be processed (this should not get big) const PRIORITY_QUEUE_CAPACITY: usize = 10_000; @@ -130,26 +129,19 @@ where let rewrites = self.rewriter.get_rewrites(&circ); for new_circ in self.strategy.apply_rewrites(rewrites, &circ) { - let new_circ_cost = (self.cost)(&new_circ); - if pq.len() > PRIORITY_QUEUE_CAPACITY / 2 && new_circ_cost > *pq.max_cost().unwrap() - { - // Ignore this circuit: it's too big - continue; - } let new_circ_hash = new_circ.circuit_hash(); + circ_cnt += 1; logger.log_progress(circ_cnt, Some(pq.len()), seen_hashes.len()); - if !seen_hashes.insert(new_circ_hash, new_circ_cost) { - // Ignore this circuit: we've already seen it + if seen_hashes.contains(&new_circ_hash) { continue; } - circ_cnt += 1; - pq.push_unchecked(new_circ, new_circ_hash, new_circ_cost); + pq.push_with_hash_unchecked(new_circ, new_circ_hash); + seen_hashes.insert(new_circ_hash); } if pq.len() >= PRIORITY_QUEUE_CAPACITY { // Haircut to keep the queue size manageable pq.truncate(PRIORITY_QUEUE_CAPACITY / 2); - seen_hashes.clear_over(*pq.max_cost().unwrap()); } if let Some(timeout) = timeout { @@ -180,37 +172,51 @@ where const PRIORITY_QUEUE_CAPACITY: usize = 10_000; // multi-consumer priority channel for queuing circuits to be processed by the workers - let mut pq = + let (tx_work, rx_work) = HugrPriorityChannel::init((self.cost).clone(), PRIORITY_QUEUE_CAPACITY * n_threads); + // channel for sending circuits from threads back to main + let (tx_result, rx_result) = crossbeam_channel::unbounded(); let initial_circ_hash = circ.circuit_hash(); let mut best_circ = circ.clone(); let mut best_circ_cost = (self.cost)(&best_circ); + logger.log_best(best_circ_cost); + + // Hash of seen circuits. Dot not store circuits as this map gets huge + let mut seen_hashes: FxHashSet<_> = FromIterator::from_iter([(initial_circ_hash)]); // Each worker waits for circuits to scan for rewrites using all the // patterns and sends the results back to main. let joins: Vec<_> = (0..n_threads) .map(|i| { TasoWorker::spawn( - pq.pop.clone().unwrap(), - pq.push.clone().unwrap(), + rx_work.clone(), + tx_result.clone(), self.rewriter.clone(), self.strategy.clone(), Some(format!("taso-worker-{i}")), ) }) .collect(); + // Drop our copy of the worker channels, so we don't count as a + // connected worker. + drop(rx_work); + drop(tx_result); // Queue the initial circuit - pq.push - .as_ref() - .unwrap() + tx_work .send(vec![(initial_circ_hash, circ.clone())]) .unwrap(); - // Drop our copy of the priority queue channels, so we don't count as a - // connected worker. - pq.drop_pop_push(); + // A counter of circuits seen. + let mut circ_cnt = 1; + + // A counter of jobs sent to the workers. + #[allow(unused)] + let mut jobs_sent = 0usize; + // A counter of completed jobs received from the workers. + #[allow(unused)] + let mut jobs_completed = 0usize; // TODO: Report dropped jobs in the queue, so we can check for termination. // Deadline for the optimisation timeout @@ -219,51 +225,66 @@ where Some(t) => crossbeam_channel::at(Instant::now() + Duration::from_secs(t)), }; - // Main loop: log best circuits as they come in from the priority queue, - // until the timeout is reached. + // Process worker results until we have seen all the circuits, or we run + // out of time. let mut timeout_flag = false; loop { select! { - recv(pq.log) -> msg => { + recv(rx_result) -> msg => { match msg { - Ok(PriorityChannelLog::NewBestCircuit(circ, cost)) => { - best_circ = circ; - best_circ_cost = cost; - logger.log_best(best_circ_cost); + Ok(hashed_circs) => { + let send_result = tracing::trace_span!(target: "taso::metrics", "recv_result").in_scope(|| { + jobs_completed += 1; + for (circ_hash, circ) in &hashed_circs { + circ_cnt += 1; + logger.log_progress(circ_cnt, None, seen_hashes.len()); + if seen_hashes.contains(circ_hash) { + continue; + } + seen_hashes.insert(*circ_hash); + + let cost = (self.cost)(circ); + + // Check if we got a new best circuit + if cost < best_circ_cost { + best_circ = circ.clone(); + best_circ_cost = cost; + logger.log_best(best_circ_cost); + } + jobs_sent += 1; + } + // Fill the workqueue with data from pq + tx_work.send(hashed_circs) + }); + if send_result.is_err() { + eprintln!("All our workers panicked. Stopping optimisation."); + break; + } + + // If there is no more data to process, we are done. + // + // TODO: Report dropped jobs in the workers, so we can check for termination. + //if jobs_sent == jobs_completed { + // break 'main; + //}; }, - Ok(PriorityChannelLog::CircuitCount(circuit_cnt, seen_cnt)) => { - logger.log_progress(circuit_cnt, None, seen_cnt); - } Err(crossbeam_channel::RecvError) => { - eprintln!("Priority queue panicked. Stopping optimisation."); + eprintln!("All our workers panicked. Stopping optimisation."); break; } } } recv(timeout_event) -> _ => { timeout_flag = true; - pq.timeout(); break; } } } - // Empty the log from the priority queue and store final circuit count. - let mut circuit_cnt = None; - while let Ok(log) = pq.log.recv() { - match log { - PriorityChannelLog::NewBestCircuit(circ, cost) => { - best_circ = circ; - best_circ_cost = cost; - logger.log_best(best_circ_cost); - } - PriorityChannelLog::CircuitCount(circ_cnt, _) => { - circuit_cnt = Some(circ_cnt); - } - } - } - logger.log_processing_end(circuit_cnt.unwrap_or(0), best_circ_cost, true, timeout_flag); + logger.log_processing_end(circ_cnt, best_circ_cost, true, timeout_flag); + // Drop the channel so the threads know to stop. + drop(tx_work); joins.into_iter().for_each(|j| j.join().unwrap()); best_circ diff --git a/src/optimiser/taso/hugr_hash_set.rs b/src/optimiser/taso/hugr_hash_set.rs deleted file mode 100644 index e710f32c..00000000 --- a/src/optimiser/taso/hugr_hash_set.rs +++ /dev/null @@ -1,155 +0,0 @@ -use std::collections::VecDeque; - -use fxhash::FxHashSet; - -/// A datastructure storing Hugr hashes. -/// -/// Stores hashes in buckets based on a cost function, to allow clearing -/// the set of hashes that are no longer needed. -pub(super) struct HugrHashSet { - buckets: VecDeque>, - /// The cost at the front of the queue. - min_cost: Option, -} - -impl HugrHashSet { - /// Create a new empty set. - pub(super) fn new() -> Self { - Self { - buckets: VecDeque::new(), - min_cost: None, - } - } - - /// Create a new set with a single hash and cost. - pub(super) fn singleton(hash: u64, cost: usize) -> Self { - let mut set = Self::new(); - set.insert(hash, cost); - set - } - - /// Insert circuit with given hash and cost. - /// - /// Returns whether the insertion was successful, i.e. the negation - /// of whether it was already present. - pub(super) fn insert(&mut self, hash: u64, cost: usize) -> bool { - let Some(min_cost) = self.min_cost.as_mut() else { - self.min_cost = Some(cost); - self.buckets.push_front([hash].into_iter().collect()); - return true; - }; - self.buckets.reserve(min_cost.saturating_sub(cost)); - while cost < *min_cost { - self.buckets.push_front(FxHashSet::default()); - *min_cost -= 1; - } - let bucket_index = cost - *min_cost; - if bucket_index >= self.buckets.len() { - self.buckets - .resize_with(bucket_index + 1, FxHashSet::default); - } - self.buckets[bucket_index].insert(hash) - } - - /// Returns whether the given hash is present in the set. - #[allow(dead_code)] - pub(super) fn contains(&self, hash: u64, cost: usize) -> bool { - let Some(min_cost) = self.min_cost else { - return false; - }; - let Some(index) = cost.checked_sub(min_cost) else { - return false; - }; - let Some(b) = self.buckets.get(index) else { - return false; - }; - b.contains(&hash) - } - - fn max_cost(&self) -> Option { - Some(self.min_cost? + self.buckets.len() - 1) - } - - /// Remove all hashes with cost strictly greater than the given cost. - pub(super) fn clear_over(&mut self, cost: usize) { - while self.max_cost().is_some() && self.max_cost() > Some(cost) { - self.buckets.pop_back(); - if self.buckets.is_empty() { - self.min_cost = None; - } - } - } - - /// The number of hashes in the set - pub(super) fn len(&self) -> usize { - self.buckets.iter().map(|b| b.len()).sum() - } -} - -#[cfg(test)] -mod tests { - use super::HugrHashSet; - - #[test] - fn insert_elements() { - // For simplicity, we use as cost: hash % 10 - let mut set = HugrHashSet::new(); - - assert!(!set.contains(0, 0)); - assert!(!set.contains(2, 0)); - assert!(!set.contains(2, 3)); - - assert!(set.insert(20, 2)); - assert!(!set.contains(0, 0)); - assert!(!set.insert(20, 2)); - assert!(set.insert(22, 2)); - assert!(set.insert(23, 2)); - - assert!(set.contains(22, 2)); - - assert!(set.insert(33, 3)); - assert_eq!(set.min_cost, Some(2)); - assert_eq!(set.max_cost(), Some(3)); - assert_eq!( - set.buckets, - [ - [20, 22, 23].into_iter().collect(), - [33].into_iter().collect() - ] - ); - - assert!(set.insert(3, 0)); - assert!(set.insert(1, 0)); - assert!(!set.insert(22, 2)); - assert!(set.contains(33, 3)); - assert!(set.contains(3, 0)); - assert!(!set.contains(3, 2)); - assert_eq!(set.min_cost, Some(0)); - assert_eq!(set.max_cost(), Some(3)); - - assert_eq!(set.min_cost, Some(0)); - assert_eq!( - set.buckets, - [ - [1, 3].into_iter().collect(), - [].into_iter().collect(), - [20, 22, 23].into_iter().collect(), - [33].into_iter().collect(), - ] - ); - } - - #[test] - fn remove_empty() { - let mut set = HugrHashSet::new(); - assert!(set.insert(20, 2)); - assert!(set.insert(30, 3)); - - assert_eq!(set.len(), 2); - set.clear_over(2); - assert_eq!(set.len(), 1); - set.clear_over(0); - assert_eq!(set.len(), 0); - assert!(set.min_cost.is_none()); - } -} diff --git a/src/optimiser/taso/hugr_pchannel.rs b/src/optimiser/taso/hugr_pchannel.rs index b6b295cb..1ec0d2e4 100644 --- a/src/optimiser/taso/hugr_pchannel.rs +++ b/src/optimiser/taso/hugr_pchannel.rs @@ -1,213 +1,95 @@ //! A multi-producer multi-consumer min-priority channel of Hugrs. +use std::marker::PhantomData; use std::thread; use crossbeam_channel::{select, Receiver, Sender}; use hugr::Hugr; -use super::{ - hugr_hash_set::HugrHashSet, - hugr_pqueue::{Entry, HugrPQ}, -}; +use super::hugr_pqueue::{Entry, HugrPQ}; /// A priority channel for HUGRs. /// /// Queues hugrs using a cost function `C` that produces priority values `P`. /// /// Uses a thread internally to orchestrate the queueing. -pub(super) struct HugrPriorityChannel { - // Channels to add and remove circuits from the queue. - push: Receiver>, - pop: Sender<(u64, Hugr)>, - // Outbound channel to log to main thread. - log: Sender, - // Inbound channel to be terminated. - timeout: Receiver<()>, - // The queue capacity. Queue size is halved when it exceeds this. - queue_capacity: usize, - // The priority queue data structure. - pq: HugrPQ, - // The set of hashes we've seen. - seen_hashes: HugrHashSet, - // The minimum cost we've seen. - min_cost: Option, - // The number of circuits we've seen (for logging). - circ_cnt: usize, +pub struct HugrPriorityChannel { + _phantom: PhantomData<(P, C)>, } -pub(super) type Item = (u64, Hugr); +pub type Item = (u64, Hugr); -/// Logging information from the priority channel. -pub(super) enum PriorityChannelLog { - NewBestCircuit(Hugr, usize), - CircuitCount(usize, usize), -} - -/// Channels for communication with the priority channel. -pub(super) struct PriorityChannelCommunication { - pub(super) push: Option>>, - pub(super) pop: Option>, - pub(super) log: Receiver, - timeout: Sender<()>, -} - -impl PriorityChannelCommunication { - /// Send Timeout signal to the priority channel. - pub(super) fn timeout(&self) { - self.timeout.send(()).unwrap(); - } - - /// Close the local copies of the push and pop channels. - pub(super) fn drop_pop_push(&mut self) { - self.pop = None; - self.push = None; - } -} - -impl HugrPriorityChannel +impl HugrPriorityChannel where - C: Fn(&Hugr) -> usize + Send + Sync + 'static, + C: Fn(&Hugr) -> P + Send + Sync + 'static, { /// Initialize the queueing system. /// - /// Start the Hugr priority queue in a new thread. - /// - /// Get back channels for communication with the priority queue - /// - push/pop channels for adding and removing circuits to/from the queue, - /// - a channel on which to receive logging information, and - /// - a channel on which to send a timeout signal. - pub(super) fn init(cost_fn: C, queue_capacity: usize) -> PriorityChannelCommunication { - // channels for pushing and popping circuits from pqueue - let (tx_push, rx_push) = crossbeam_channel::unbounded(); - let (tx_pop, rx_pop) = crossbeam_channel::bounded(0); - // channels for communication with main (logging, minimum circuits and timeout) - let (tx_log, rx_log) = crossbeam_channel::unbounded(); - let (tx_timeout, rx_timeout) = crossbeam_channel::bounded(0); - let pq = - HugrPriorityChannel::new(rx_push, tx_pop, tx_log, rx_timeout, cost_fn, queue_capacity); - pq.run(); - PriorityChannelCommunication { - push: Some(tx_push), - pop: Some(rx_pop), - log: rx_log, - timeout: tx_timeout, - } + /// Get back a channel on which to queue hugrs with their hash, and + /// a channel on which to receive the output. + pub fn init(cost_fn: C, queue_capacity: usize) -> (Sender>, Receiver) { + let (ins, inr) = crossbeam_channel::unbounded(); + let (outs, outr) = crossbeam_channel::bounded(0); + Self::run(inr, outs, cost_fn, queue_capacity); + (ins, outr) } - fn new( - push: Receiver>, - pop: Sender<(u64, Hugr)>, - log: Sender, - timeout: Receiver<()>, + /// Run the queuer as a thread. + fn run( + in_channel_orig: Receiver>, + out_channel_orig: Sender<(u64, Hugr)>, cost_fn: C, queue_capacity: usize, - ) -> Self { - // The priority queue, local to this thread. - let pq: HugrPQ = HugrPQ::with_capacity(cost_fn, queue_capacity); - // The set of hashes we've seen. - let seen_hashes = HugrHashSet::new(); - // The minimum cost we've seen. - let min_cost = None; - // The number of circuits we've seen (for logging). - let circ_cnt = 0; - - HugrPriorityChannel { - push, - pop, - log, - timeout, - queue_capacity, - pq, - seen_hashes, - min_cost, - circ_cnt, - } - } - - /// Run the queuer as a thread. - fn run(mut self) { + ) { let builder = thread::Builder::new().name("priority queueing".into()); + let in_channel = in_channel_orig.clone(); + let out_channel = out_channel_orig.clone(); let _ = builder .name("priority-channel".into()) .spawn(move || { + // The priority queue, local to this thread. + let mut pq: HugrPQ = + HugrPQ::with_capacity(cost_fn, queue_capacity); + loop { - if self.pq.is_empty() { - let Ok(new_circs) = self.push.recv() else { - // The senders have closed the channel, we can stop. - break; - }; - self.recv(new_circs); + if pq.is_empty() { + // Nothing queued to go out. Wait for input. + match in_channel.recv() { + Ok(new_circs) => { + for (hash, circ) in new_circs { + pq.push_with_hash_unchecked(circ, hash); + } + } + // The sender has closed the channel, we can stop. + Err(_) => break, + } } else { select! { - recv(self.push) -> result => { - let Ok(new_circs) = result else { - // The senders have closed the channel, we can stop. - break; - }; - self.recv(new_circs); + recv(in_channel) -> result => { + match result { + Ok(new_circs) => { + for (hash, circ) in new_circs { + pq.push_with_hash_unchecked(circ, hash); + } + } + // The sender has closed the channel, we can stop. + Err(_) => break, + } } - send(self.pop, {let Entry {hash, circ, ..} = self.pq.pop().unwrap(); (hash, circ)}) -> result => { + send(out_channel, {let Entry {hash, circ, ..} = pq.pop().unwrap(); (hash, circ)}) -> result => { match result { Ok(()) => {}, // The receivers have closed the channel, we can stop. Err(_) => break, } } - recv(self.timeout) -> _ => { - // We've timed out. - break - } } } + if pq.len() >= queue_capacity { + pq.truncate(queue_capacity / 2); + } } - // Send a last set of logs before terminating. - self.log - .send(PriorityChannelLog::CircuitCount( - self.circ_cnt, - self.seen_hashes.len(), - )) - .unwrap(); }) .unwrap(); } - - /// Add circuits to queue. - fn recv(&mut self, circs: Vec<(u64, Hugr)>) { - for (hash, circ) in circs { - let cost = (self.pq.cost_fn)(&circ); - if (self.pq.len() > self.queue_capacity / 2 && cost > *self.pq.max_cost().unwrap()) - || !self.seen_hashes.insert(hash, cost) - { - // Ignore this circuit: it's either too big or we've seen it before. - continue; - } - - // A new best circuit - if self.min_cost.is_none() || Some(cost) < self.min_cost { - self.min_cost = Some(cost); - self.log - .send(PriorityChannelLog::NewBestCircuit(circ.clone(), cost)) - .unwrap(); - } - - self.circ_cnt += 1; - self.pq.push_unchecked(circ, hash, cost); - - // Send logs every 1000 circuits. - if self.circ_cnt % 1000 == 0 { - // TODO: Add a minimum time between logs - self.log - .send(PriorityChannelLog::CircuitCount( - self.circ_cnt, - self.seen_hashes.len(), - )) - .unwrap(); - } - } - // If the queue got too big, truncate it. - if self.pq.len() >= self.queue_capacity { - self.pq.truncate(self.queue_capacity / 2); - self.seen_hashes.clear_over(*self.pq.max_cost().unwrap()); - } - } } diff --git a/src/optimiser/taso/hugr_pqueue.rs b/src/optimiser/taso/hugr_pqueue.rs index 7dd532cc..eba000e0 100644 --- a/src/optimiser/taso/hugr_pqueue.rs +++ b/src/optimiser/taso/hugr_pqueue.rs @@ -13,7 +13,7 @@ use crate::circuit::CircuitHash; pub(super) struct HugrPQ { queue: DoublePriorityQueue, hash_lookup: FxHashMap, - pub(super) cost_fn: C, + cost_fn: C, } pub(super) struct Entry { @@ -51,20 +51,20 @@ impl HugrPQ { C: Fn(&Hugr) -> P, { let hash = hugr.circuit_hash(); - let cost = (self.cost_fn)(&hugr); - self.push_unchecked(hugr, hash, cost); + self.push_with_hash_unchecked(hugr, hash); } - /// Push a Hugr into the queue with a precomputed hash and cost. + /// Push a Hugr into the queue with a precomputed hash. /// - /// This is useful to avoid recomputing the hash and cost function in - /// [`HugrPQ::push`] when they are already known. + /// This is useful to avoid recomputing the hash in [`HugrPQ::push`] when + /// it is already known. /// /// This does not check that the hash is valid. - pub(super) fn push_unchecked(&mut self, hugr: Hugr, hash: u64, cost: P) + pub(super) fn push_with_hash_unchecked(&mut self, hugr: Hugr, hash: u64) where C: Fn(&Hugr) -> P, { + let cost = (self.cost_fn)(&hugr); self.queue.push(hash, cost); self.hash_lookup.insert(hash, hugr); } @@ -85,11 +85,6 @@ impl HugrPQ { } } - /// The largest cost in the queue. - pub(super) fn max_cost(&self) -> Option<&P> { - self.queue.peek_max().map(|(_, cost)| cost) - } - delegate! { to self.queue { pub(super) fn len(&self) -> usize; diff --git a/src/optimiser/taso/worker.rs b/src/optimiser/taso/worker.rs index a2aadbf7..0d8667c8 100644 --- a/src/optimiser/taso/worker.rs +++ b/src/optimiser/taso/worker.rs @@ -51,7 +51,7 @@ where let send = tracing::trace_span!(target: "taso::metrics", "TasoWorker::send_result") .in_scope(|| tx_result.send(hashed_circs)); if send.is_err() { - // The priority queue closed the send channel, we can stop. + // The main thread closed the send channel, we can stop. break; } }