diff --git a/src/circuit/cost.rs b/src/circuit/cost.rs index 2400bc7c..9992aafd 100644 --- a/src/circuit/cost.rs +++ b/src/circuit/cost.rs @@ -12,8 +12,16 @@ use crate::T2Op; /// The cost for a group of operations in a circuit, each with cost `OpCost`. pub trait CircuitCost: Add + Sum + Debug + Default + Clone + Ord { + /// Return the cost as a `usize`. This may discard some of the cost information. + fn as_usize(&self) -> usize; + /// Subtract another cost to get the signed distance between `self` and `rhs`. - fn sub_cost(&self, rhs: &Self) -> isize; + /// + /// Equivalent to `self.as_usize() - rhs.as_usize()`. + #[inline] + fn sub_cost(&self, rhs: &Self) -> isize { + self.as_usize() as isize - rhs.as_usize() as isize + } /// Divide the cost, rounded up. fn div_cost(&self, n: NonZeroUsize) -> Self; @@ -64,8 +72,8 @@ impl Sum for MajorMinorCost { impl CircuitCost for MajorMinorCost { #[inline] - fn sub_cost(&self, rhs: &Self) -> isize { - self.major as isize - rhs.major as isize + fn as_usize(&self) -> usize { + self.major } #[inline] @@ -78,8 +86,8 @@ impl CircuitCost for MajorMinorCost { impl CircuitCost for usize { #[inline] - fn sub_cost(&self, rhs: &Self) -> isize { - *self as isize - *rhs as isize + fn as_usize(&self) -> usize { + *self } #[inline] diff --git a/src/optimiser/taso.rs b/src/optimiser/taso.rs index 20e73078..64b70e67 100644 --- a/src/optimiser/taso.rs +++ b/src/optimiser/taso.rs @@ -152,7 +152,7 @@ where let mut pq = HugrPQ::new(cost_fn, queue_size); pq.push(circ.clone()); - let mut circ_cnt = 1; + let mut circ_cnt = 0; let mut timeout_flag = false; while let Some(Entry { circ, cost, .. }) = pq.pop() { if cost < best_circ_cost { @@ -160,6 +160,7 @@ where best_circ_cost = cost; logger.log_best(&best_circ_cost); } + circ_cnt += 1; let rewrites = self.rewriter.get_rewrites(&circ); for new_circ in self.strategy.apply_rewrites(rewrites, &circ) { @@ -187,7 +188,13 @@ where } } - logger.log_processing_end(circ_cnt, best_circ_cost, false, timeout_flag); + logger.log_processing_end( + circ_cnt, + Some(seen_hashes.len()), + best_circ_cost, + false, + timeout_flag, + ); best_circ } @@ -211,38 +218,34 @@ where let strategy = self.strategy.clone(); move |circ: &'_ Hugr| strategy.circuit_cost(circ) }; - let mut pq = HugrPriorityChannel::init(cost_fn, queue_size); + let (pq, rx_log) = HugrPriorityChannel::init(cost_fn.clone(), queue_size); let initial_circ_hash = circ.circuit_hash(); let mut best_circ = circ.clone(); let mut best_circ_cost = self.cost(&best_circ); + // Initialise the work channels and send the initial circuit. + pq.send(vec![Work { + cost: best_circ_cost.clone(), + hash: initial_circ_hash, + circ: circ.clone(), + }]) + .unwrap(); + // Each worker waits for circuits to scan for rewrites using all the // patterns and sends the results back to main. let joins: Vec<_> = (0..n_threads) .map(|i| { TasoWorker::spawn( - pq.pop.clone().unwrap(), - pq.push.clone().unwrap(), + i, + pq.clone(), self.rewriter.clone(), self.strategy.clone(), - Some(format!("taso-worker-{i}")), + cost_fn.clone(), ) }) .collect(); - // Queue the initial circuit - pq.push - .as_ref() - .unwrap() - .send(vec![(initial_circ_hash, circ.clone())]) - .unwrap(); - // Drop our copy of the priority queue channels, so we don't count as a - // connected worker. - pq.drop_pop_push(); - - // TODO: Report dropped jobs in the queue, so we can check for termination. - // Deadline for the optimisation timeout let timeout_event = match timeout { None => crossbeam_channel::never(), @@ -252,47 +255,68 @@ where // Main loop: log best circuits as they come in from the priority queue, // until the timeout is reached. let mut timeout_flag = false; + let mut processed_count = 0; + let mut seen_count = 0; loop { select! { - recv(pq.log) -> msg => { + recv(rx_log) -> msg => { match msg { Ok(PriorityChannelLog::NewBestCircuit(circ, cost)) => { - best_circ = circ; - best_circ_cost = cost; - logger.log_best(&best_circ_cost); + if cost < best_circ_cost { + best_circ = circ; + best_circ_cost = cost; + logger.log_best(&best_circ_cost); + } }, - Ok(PriorityChannelLog::CircuitCount(circuit_cnt, seen_cnt)) => { - logger.log_progress(circuit_cnt, None, seen_cnt); + Ok(PriorityChannelLog::CircuitCount{processed_count: proc, seen_count: seen, queue_length}) => { + processed_count = proc; + seen_count = seen; + logger.log_progress(processed_count, Some(queue_length), seen_count); } Err(crossbeam_channel::RecvError) => { - eprintln!("Priority queue panicked. Stopping optimisation."); + logger.log("The priority channel panicked. Stopping TASO optimisation."); + let _ = pq.close(); break; } } } recv(timeout_event) -> _ => { timeout_flag = true; - pq.timeout(); + // Signal the workers to stop. + let _ = pq.close(); break; } } } // Empty the log from the priority queue and store final circuit count. - let mut circuit_cnt = None; - while let Ok(log) = pq.log.recv() { + while let Ok(log) = rx_log.recv() { match log { PriorityChannelLog::NewBestCircuit(circ, cost) => { - best_circ = circ; - best_circ_cost = cost; - logger.log_best(&best_circ_cost); + if cost < best_circ_cost { + best_circ = circ; + best_circ_cost = cost; + logger.log_best(&best_circ_cost); + } } - PriorityChannelLog::CircuitCount(circ_cnt, _) => { - circuit_cnt = Some(circ_cnt); + PriorityChannelLog::CircuitCount { + processed_count: proc, + seen_count: seen, + queue_length, + } => { + processed_count = proc; + seen_count = seen; + logger.log_progress(processed_count, Some(queue_length), seen_count); } } } - logger.log_processing_end(circuit_cnt.unwrap_or(0), best_circ_cost, true, timeout_flag); + logger.log_processing_end( + processed_count, + Some(seen_count), + best_circ_cost, + true, + timeout_flag, + ); joins.into_iter().for_each(|j| j.join().unwrap()); @@ -359,7 +383,7 @@ where logger.log_best(best_circ_cost.clone()); } - logger.log_processing_end(n_threads.get(), best_circ_cost, true, false); + logger.log_processing_end(n_threads.get(), None, best_circ_cost, true, false); joins.into_iter().for_each(|j| j.join().unwrap()); Ok(best_circ) @@ -406,6 +430,8 @@ mod taso_default { #[cfg(feature = "portmatching")] pub use taso_default::DefaultTasoOptimiser; +use self::hugr_pchannel::Work; + #[cfg(test)] #[cfg(feature = "portmatching")] mod tests { diff --git a/src/optimiser/taso/hugr_pchannel.rs b/src/optimiser/taso/hugr_pchannel.rs index 13d7e6b0..9242e3d9 100644 --- a/src/optimiser/taso/hugr_pchannel.rs +++ b/src/optimiser/taso/hugr_pchannel.rs @@ -1,101 +1,170 @@ //! A multi-producer multi-consumer min-priority channel of Hugrs. +use std::sync::{Arc, RwLock}; use std::thread; +use std::time::Instant; -use crossbeam_channel::{select, Receiver, Sender}; +use crossbeam_channel::{select, Receiver, RecvError, SendError, Sender}; use fxhash::FxHashSet; use hugr::Hugr; +use crate::circuit::cost::CircuitCost; + use super::hugr_pqueue::{Entry, HugrPQ}; +/// A unit of work for a worker, consisting of a circuit to process, along its +/// hash and cost. +pub type Work

= Entry; + /// A priority channel for HUGRs. /// /// Queues hugrs using a cost function `C` that produces priority values `P`. /// /// Uses a thread internally to orchestrate the queueing. -pub(super) struct HugrPriorityChannel { - // Channels to add and remove circuits from the queue. - push: Receiver>, - pop: Sender<(u64, Hugr)>, - // Outbound channel to log to main thread. +#[derive(Debug, Clone)] +pub struct HugrPriorityChannel { + /// Channel to add circuits from the queue. + push: Receiver>>, + /// Channel to pop circuits from the queue. + pop: Sender>, + /// Outbound channel to log to main thread. log: Sender>, - // Inbound channel to be terminated. - timeout: Receiver<()>, - // The priority queue data structure. + /// Timestamp of the last progress log. + /// Used to avoid spamming the log. + last_progress_log: Instant, + /// The priority queue data structure. pq: HugrPQ, - // The set of hashes we've seen. + /// The set of hashes we've seen. seen_hashes: FxHashSet, - // The minimum cost we've seen. + /// The minimum cost we've seen. min_cost: Option

, - // The number of circuits we've seen (for logging). + /// The number of circuits we've processed. circ_cnt: usize, + /// The maximum cost in the queue. Shared with the workers so they can cull + /// the circuits they generate. + max_cost: Arc>>, + /// Local copy of `max_cost`, used to avoid locking when checking the value. + local_max_cost: Option

, } -pub(super) type Item = (u64, Hugr); - /// Logging information from the priority channel. -pub(super) enum PriorityChannelLog { - NewBestCircuit(Hugr, C), - CircuitCount(usize, usize), +#[derive(Debug, Clone)] +pub enum PriorityChannelLog

{ + NewBestCircuit(Hugr, P), + CircuitCount { + processed_count: usize, + seen_count: usize, + queue_length: usize, + }, } /// Channels for communication with the priority channel. -pub(super) struct PriorityChannelCommunication { - pub(super) push: Option>>, - pub(super) pop: Option>, - pub(super) log: Receiver>, - timeout: Sender<()>, +#[derive(Clone)] +pub struct PriorityChannelCommunication

{ + /// A channel to add batches of circuits to the queue. + push: Sender>>, + /// A channel to remove the best candidate circuit from the queue. + pop: Receiver>, + /// A maximum accepted cost for the queue. Circuits with higher costs will + /// be dropped. + /// + /// Shared with the workers so they can cull the circuits they generate. + max_cost: Arc>>, } -impl PriorityChannelCommunication { - /// Send Timeout signal to the priority channel. - pub(super) fn timeout(&self) { - self.timeout.send(()).unwrap(); +impl PriorityChannelCommunication

{ + /// Signal the priority channel to stop. + /// + /// This will in turn signal the workers to stop. + pub fn close(&self) -> Result<(), SendError>>> { + self.push.send(Vec::new()) } - /// Close the local copies of the push and pop channels. - pub(super) fn drop_pop_push(&mut self) { - self.pop = None; - self.push = None; + /// Send a lot of circuits to the priority channel. + pub fn send(&self, work: Vec>) -> Result<(), SendError>>> { + if work.is_empty() { + return Ok(()); + } + // + match self.max_cost() { + Some(max_cost) => { + let filtered = work + .into_iter() + .filter(|Work { cost, .. }| cost < &max_cost) + .collect(); + self.push.send(filtered)?; + } + _ => self.push.send(work)?, + } + Ok(()) + } + + /// Receive a circuit from the priority channel. + /// + /// Blocks until a circuit is available. + pub fn recv(&self) -> Result, RecvError> { + self.pop.recv() + } + + /// Get the maximum accepted circuit cost. + /// + /// This function requires locking, so its value should be cached where + /// appropriate. + pub fn max_cost(&self) -> Option

{ + self.max_cost.read().as_deref().ok().cloned().flatten() } } impl HugrPriorityChannel where C: Fn(&Hugr) -> P + Send + Sync + 'static, - P: Ord + Send + Sync + Clone + 'static, + P: CircuitCost + Send + Sync + 'static, { /// Initialize the queueing system. /// /// Start the Hugr priority queue in a new thread. /// - /// Get back channels for communication with the priority queue - /// - push/pop channels for adding and removing circuits to/from the queue, - /// - a channel on which to receive logging information, and - /// - a channel on which to send a timeout signal. - pub(super) fn init(cost_fn: C, queue_capacity: usize) -> PriorityChannelCommunication

{ - // channels for pushing and popping circuits from pqueue + /// Get back a [`PriorityChannelCommunication`] for adding and removing circuits to/from the queue, + /// and a channel receiver to receive logging information. + pub fn init( + cost_fn: C, + queue_capacity: usize, + ) -> ( + PriorityChannelCommunication

, + Receiver>, + ) { + // Shared maximum cost in the queue. + let max_cost = Arc::new(RwLock::new(None)); + // Channels for pushing and popping circuits from pqueue let (tx_push, rx_push) = crossbeam_channel::unbounded(); let (tx_pop, rx_pop) = crossbeam_channel::bounded(0); - // channels for communication with main (logging, minimum circuits and timeout) + // Channel for logging results and statistics to the main thread. let (tx_log, rx_log) = crossbeam_channel::unbounded(); - let (tx_timeout, rx_timeout) = crossbeam_channel::bounded(0); - let pq = - HugrPriorityChannel::new(rx_push, tx_pop, tx_log, rx_timeout, cost_fn, queue_capacity); + + let pq = HugrPriorityChannel::new( + rx_push, + tx_pop, + tx_log, + max_cost.clone(), + cost_fn, + queue_capacity, + ); pq.run(); - PriorityChannelCommunication { - push: Some(tx_push), - pop: Some(rx_pop), - log: rx_log, - timeout: tx_timeout, - } + ( + PriorityChannelCommunication { + push: tx_push, + pop: rx_pop, + max_cost, + }, + rx_log, + ) } fn new( - push: Receiver>, - pop: Sender<(u64, Hugr)>, + push: Receiver>>, + pop: Sender>, log: Sender>, - timeout: Receiver<()>, + max_cost: Arc>>, cost_fn: C, queue_capacity: usize, ) -> Self { @@ -112,11 +181,14 @@ where push, pop, log, - timeout, + // Ensure we log the first progress. + last_progress_log: Instant::now() - std::time::Duration::from_secs(60), pq, seen_hashes, min_cost, circ_cnt, + max_cost, + local_max_cost: None, } } @@ -126,42 +198,46 @@ where let _ = builder .name("priority-channel".into()) .spawn(move || { - loop { - if self.pq.is_empty() { + 'main: loop { + while self.pq.is_empty() { let Ok(new_circs) = self.push.recv() else { - // The senders have closed the channel, we can stop. - break; + // Something went wrong + break 'main; }; + if new_circs.is_empty() { + // The main thread signalled us to stop. + break 'main; + } self.enqueue_circs(new_circs); - } else { - select! { - recv(self.push) -> result => { - let Ok(new_circs) = result else { - // The senders have closed the channel, we can stop. - break; - }; - self.enqueue_circs(new_circs); - } - send(self.pop, {let Entry {hash, circ, ..} = self.pq.pop().unwrap(); (hash, circ)}) -> result => { - match result { - Ok(()) => {}, - // The receivers have closed the channel, we can stop. - Err(_) => break, - } + } + select! { + recv(self.push) -> result => { + let Ok(new_circs) = result else { + // Something went wrong + break 'main; + }; + if new_circs.is_empty() { + // The main thread signalled us to stop. + break 'main; } - recv(self.timeout) -> _ => { - // We've timed out. - break + self.enqueue_circs(new_circs); + } + send(self.pop, self.pq.pop().unwrap()) -> result => { + if result.is_err() { + // Something went wrong. + break 'main; } + self.update_max_cost(); } } } // Send a last set of logs before terminating. self.log - .send(PriorityChannelLog::CircuitCount( - self.circ_cnt, - self.seen_hashes.len(), - )) + .send(PriorityChannelLog::CircuitCount { + processed_count: self.circ_cnt, + seen_count: self.seen_hashes.len(), + queue_length: self.pq.len(), + }) .unwrap(); }) .unwrap(); @@ -169,9 +245,8 @@ where /// Add circuits to queue. #[tracing::instrument(target = "taso::metrics", skip(self, circs))] - fn enqueue_circs(&mut self, circs: Vec<(u64, Hugr)>) { - for (hash, circ) in circs { - let cost = (self.pq.cost_fn)(&circ); + fn enqueue_circs(&mut self, circs: Vec>) { + for Work { cost, hash, circ } in circs { if !self.seen_hashes.insert(hash) { // Ignore this circuit: we've seen it before. continue; @@ -188,19 +263,37 @@ where .unwrap(); } - self.circ_cnt += 1; self.pq.push_unchecked(circ, hash, cost); + } + self.update_max_cost(); - // Send logs every 1000 circuits. - if self.circ_cnt % 1000 == 0 { - // TODO: Add a minimum time between logs - self.log - .send(PriorityChannelLog::CircuitCount( - self.circ_cnt, - self.seen_hashes.len(), - )) - .unwrap(); - } + // This is the result from processing a circuit. Add it to the count. + self.circ_cnt += 1; + if Instant::now() - self.last_progress_log > std::time::Duration::from_millis(100) { + self.log + .send(PriorityChannelLog::CircuitCount { + processed_count: self.circ_cnt, + seen_count: self.seen_hashes.len(), + queue_length: self.pq.len(), + }) + .unwrap(); + } + } + + /// Update the shared `max_cost` value. + /// + /// If the priority queue is full, set the `max_cost` to the maximum cost. + /// Otherwise, leave it as `None`. + #[inline] + fn update_max_cost(&mut self) { + if !self.pq.is_full() || self.pq.is_empty() { + return; + } + let queue_max = self.pq.max_cost().unwrap().clone(); + let local_max = self.local_max_cost.clone(); + if local_max.is_some() && queue_max < local_max.unwrap() { + self.local_max_cost = Some(queue_max.clone()); + *self.max_cost.write().unwrap() = Some(queue_max); } } } diff --git a/src/optimiser/taso/hugr_pqueue.rs b/src/optimiser/taso/hugr_pqueue.rs index f6e5440b..2919c161 100644 --- a/src/optimiser/taso/hugr_pqueue.rs +++ b/src/optimiser/taso/hugr_pqueue.rs @@ -10,23 +10,22 @@ use crate::circuit::CircuitHash; /// The cost function provided will be used as the priority of the Hugrs. /// Uses hashes internally to store the Hugrs. #[derive(Debug, Clone, Default)] -pub(super) struct HugrPQ { +pub struct HugrPQ { queue: DoublePriorityQueue, hash_lookup: FxHashMap, - pub(super) cost_fn: C, + cost_fn: C, max_size: usize, } -pub(super) struct Entry { - pub(super) circ: C, - pub(super) cost: P, - #[allow(unused)] // TODO remove? - pub(super) hash: H, +pub struct Entry { + pub circ: C, + pub cost: P, + pub hash: H, } impl HugrPQ { /// Create a new HugrPQ with a cost function and some initial capacity. - pub(super) fn new(cost_fn: C, max_size: usize) -> Self { + pub fn new(cost_fn: C, max_size: usize) -> Self { Self { queue: DoublePriorityQueue::with_capacity(max_size), hash_lookup: Default::default(), @@ -37,7 +36,7 @@ impl HugrPQ { /// Reference to the minimal Hugr in the queue. #[allow(unused)] - pub(super) fn peek(&self) -> Option> { + pub fn peek(&self) -> Option> { let (hash, cost) = self.queue.peek_min()?; let circ = self.hash_lookup.get(hash)?; Some(Entry { @@ -50,7 +49,7 @@ impl HugrPQ { /// Push a Hugr into the queue. /// /// If the queue is full, the element with the highest cost will be dropped. - pub(super) fn push(&mut self, hugr: Hugr) + pub fn push(&mut self, hugr: Hugr) where C: Fn(&Hugr) -> P, { @@ -67,7 +66,7 @@ impl HugrPQ { /// This does not check that the hash is valid. /// /// If the queue is full, the most last will be dropped. - pub(super) fn push_unchecked(&mut self, hugr: Hugr, hash: u64, cost: P) + pub fn push_unchecked(&mut self, hugr: Hugr, hash: u64, cost: P) where C: Fn(&Hugr) -> P, { @@ -85,14 +84,14 @@ impl HugrPQ { } /// Pop the minimal Hugr from the queue. - pub(super) fn pop(&mut self) -> Option> { + pub fn pop(&mut self) -> Option> { let (hash, cost) = self.queue.pop_min()?; let circ = self.hash_lookup.remove(&hash)?; Some(Entry { circ, cost, hash }) } /// Pop the maximal Hugr from the queue. - pub(super) fn pop_max(&mut self) -> Option> { + pub fn pop_max(&mut self) -> Option> { let (hash, cost) = self.queue.pop_max()?; let circ = self.hash_lookup.remove(&hash)?; Some(Entry { circ, cost, hash }) @@ -101,23 +100,32 @@ impl HugrPQ { /// Discard the largest elements of the queue. /// /// Only keep up to `max_size` elements. - pub(super) fn truncate(&mut self, max_size: usize) { + pub fn truncate(&mut self, max_size: usize) { while self.queue.len() > max_size { let (hash, _) = self.queue.pop_max().unwrap(); self.hash_lookup.remove(&hash); } } + /// The cost function used by the queue. + pub fn cost_fn(&self) -> &C { + &self.cost_fn + } + /// The largest cost in the queue. - #[allow(unused)] - pub(super) fn max_cost(&self) -> Option<&P> { + pub fn max_cost(&self) -> Option<&P> { self.queue.peek_max().map(|(_, cost)| cost) } + /// Returns `true` is the queue is at capacity. + pub fn is_full(&self) -> bool { + self.queue.len() >= self.max_size + } + delegate! { to self.queue { - pub(super) fn len(&self) -> usize; - pub(super) fn is_empty(&self) -> bool; + pub fn len(&self) -> usize; + pub fn is_empty(&self) -> bool; } } } diff --git a/src/optimiser/taso/log.rs b/src/optimiser/taso/log.rs index 402ed69c..9ef85554 100644 --- a/src/optimiser/taso/log.rs +++ b/src/optimiser/taso/log.rs @@ -1,11 +1,24 @@ //! Logging utilities for the TASO optimiser. +use std::time::{Duration, Instant}; use std::{fmt::Debug, io}; /// Logging configuration for the TASO optimiser. -#[derive(Default)] pub struct TasoLogger<'w> { circ_candidates_csv: Option>>, + last_circ_processed: usize, + last_progress_time: Instant, +} + +impl<'w> Default for TasoLogger<'w> { + fn default() -> Self { + Self { + circ_candidates_csv: Default::default(), + last_circ_processed: Default::default(), + // Ensure the first progress message is printed. + last_progress_time: Instant::now() - Duration::from_secs(60), + } + } } /// The logging target for general events. @@ -30,6 +43,7 @@ impl<'w> TasoLogger<'w> { let boxed_candidates_writer: Box = Box::new(best_progress_csv_writer); Self { circ_candidates_csv: Some(csv::Writer::from_writer(boxed_candidates_writer)), + ..Default::default() } } @@ -47,19 +61,25 @@ impl<'w> TasoLogger<'w> { #[inline] pub fn log_processing_end( &self, - circuit_count: usize, + circuits_processed: usize, + circuits_seen: Option, best_cost: C, needs_joining: bool, timeout: bool, ) { match timeout { - true => self.log("Optimisation finished (timeout)"), - false => self.log("Optimisation finished"), + true => self.log("Optimisation finished (timeout)."), + false => self.log("Optimisation finished."), }; - self.log(format!("Tried {circuit_count} circuits")); + match circuits_seen { + Some(circuits_seen) => self.log(format!( + "Processed {circuits_processed} circuits (out of {circuits_seen} seen)." + )), + None => self.log(format!("Processed {circuits_processed} circuits.")), + } self.log(format!("---- END RESULT: {:?} ----", best_cost)); if needs_joining { - self.log("Joining worker threads"); + self.log("Joining worker threads."); } } @@ -67,16 +87,21 @@ impl<'w> TasoLogger<'w> { #[inline(always)] pub fn log_progress( &mut self, - circ_cnt: usize, + circuits_processed: usize, workqueue_len: Option, seen_hashes: usize, ) { - if circ_cnt % 1000 == 0 { - self.progress(format!("{circ_cnt} circuits...")); + if circuits_processed > self.last_circ_processed + && Instant::now() - self.last_progress_time > Duration::from_secs(1) + { + self.last_circ_processed = circuits_processed; + self.last_progress_time = Instant::now(); + + self.progress(format!("Processed {circuits_processed} circuits...")); if let Some(workqueue_len) = workqueue_len { - self.progress(format!("Queue size: {workqueue_len} circuits")); + self.progress(format!("Queue size: {workqueue_len} circuits.")); } - self.progress(format!("Total seen: {} circuits", seen_hashes)); + self.progress(format!("Total seen: {} circuits.", seen_hashes)); } } diff --git a/src/optimiser/taso/worker.rs b/src/optimiser/taso/worker.rs index a2aadbf7..7945726c 100644 --- a/src/optimiser/taso/worker.rs +++ b/src/optimiser/taso/worker.rs @@ -4,34 +4,57 @@ use std::thread::{self, JoinHandle}; use hugr::Hugr; +use crate::circuit::cost::CircuitCost; use crate::circuit::CircuitHash; use crate::rewrite::strategy::RewriteStrategy; use crate::rewrite::Rewriter; +use super::hugr_pchannel::{PriorityChannelCommunication, Work}; + /// A worker that processes circuits for the TASO optimiser. -pub struct TasoWorker { - _phantom: std::marker::PhantomData<(R, S)>, +pub struct TasoWorker { + /// The worker ID. + #[allow(unused)] + id: usize, + /// The channel to send and receive work from. + priority_channel: PriorityChannelCommunication

, + /// The rewriter to use. + rewriter: R, + /// The rewrite strategy to use. + strategy: S, + /// The cost function + cost_fn: C, } -impl TasoWorker +impl TasoWorker where R: Rewriter + Send + 'static, S: RewriteStrategy + Send + 'static, + C: Fn(&Hugr) -> P + Send + Sync + 'static, + P: CircuitCost + Send + Sync + 'static, { /// Spawn a new worker thread. + #[allow(clippy::too_many_arguments)] pub fn spawn( - rx_work: crossbeam_channel::Receiver<(u64, Hugr)>, - tx_result: crossbeam_channel::Sender>, + id: usize, + priority_channel: PriorityChannelCommunication

, rewriter: R, strategy: S, - worker_name: Option, + cost_fn: C, ) -> JoinHandle<()> { - let mut builder = thread::Builder::new(); - if let Some(name) = worker_name { - builder = builder.name(name); - }; - builder - .spawn(move || Self::worker_loop(rx_work, tx_result, rewriter, strategy)) + let name = format!("TasoWorker-{id}"); + thread::Builder::new() + .name(name) + .spawn(move || { + let mut worker = Self { + id, + priority_channel, + rewriter, + strategy, + cost_fn, + }; + worker.run_loop() + }) .unwrap() } @@ -39,29 +62,34 @@ where /// /// Processes work until the main thread closes the channel send or receive /// channel. - #[tracing::instrument(target = "taso::metrics", skip_all)] - fn worker_loop( - rx_work: crossbeam_channel::Receiver<(u64, Hugr)>, - tx_result: crossbeam_channel::Sender>, - rewriter: R, - strategy: S, - ) { - while let Ok((_hash, circ)) = rx_work.recv() { - let hashed_circs = Self::process_circ(circ, &rewriter, &strategy); + #[tracing::instrument(target = "taso::metrics", skip(self))] + fn run_loop(&mut self) { + loop { + let Ok(Work { circ, .. }) = self.priority_channel.recv() else { + break; + }; + + let rewrites = self.rewriter.get_rewrites(&circ); + let circs = self.strategy.apply_rewrites(rewrites, &circ); + let new_circs = circs + .into_iter() + .map(|c| { + let hash = c.circuit_hash(); + let cost = (self.cost_fn)(&c); + Work { + cost, + hash, + circ: c, + } + }) + .collect(); + let send = tracing::trace_span!(target: "taso::metrics", "TasoWorker::send_result") - .in_scope(|| tx_result.send(hashed_circs)); + .in_scope(|| self.priority_channel.send(new_circs)); if send.is_err() { - // The priority queue closed the send channel, we can stop. + // Terminating break; } } } - - /// Process a circuit. - #[tracing::instrument(target = "taso::metrics", skip_all)] - fn process_circ(circ: Hugr, rewriter: &R, strategy: &S) -> Vec<(u64, Hugr)> { - let rewrites = rewriter.get_rewrites(&circ); - let circs = strategy.apply_rewrites(rewrites, &circ); - circs.into_iter().map(|c| (c.circuit_hash(), c)).collect() - } } diff --git a/taso-optimiser/src/main.rs b/taso-optimiser/src/main.rs index 5d9d3595..9ca1242a 100644 --- a/taso-optimiser/src/main.rs +++ b/taso-optimiser/src/main.rs @@ -128,7 +128,10 @@ fn main() -> Result<(), Box> { ); exit(1); }; - println!("Using {n_threads} threads"); + println!( + "Using {n_threads} threads. Queue size is {}.", + opts.queue_size + ); if opts.split_circ && n_threads.get() > 1 { println!("Splitting circuit into {n_threads} chunks.");