From ae292f50f7ef740b7f080ba26c7a04798576b19c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Borgna?= <121866228+aborgna-q@users.noreply.github.com> Date: Thu, 21 Sep 2023 12:52:14 +0200 Subject: [PATCH] feat: Parallel taso with crossbeam (#113) Simplifies the multithreading logic by using a multi-consumer channel. Includes some small improvements to the TASO bin. Now, why is this a draft: I'm testing with `barenco_tof_5_rm` and `Nam_6_3_complete_ECC_set` on my 12-core laptop. Here are some results with a timeout of `10` seconds: | | circuits seen after 10s | best size after 10s | | ----------------- | ------ | ---- | | single-threaded | 7244 | 178 | | -j 2 | 31391 | 204 | | -j 10 | 42690 | 204 | There's something wrong somewhere :P (note that `-j N` means `N` worker threads, plus the master) --------- Co-authored-by: Luca Mondada Co-authored-by: Luca Mondada <72734770+lmondada@users.noreply.github.com> --- Cargo.toml | 1 + compile-matcher/src/main.rs | 9 +- src/optimiser/taso.rs | 542 ++++++++---------- src/optimiser/taso/eq_circ_class.rs | 9 +- src/optimiser/taso/hugr_pchannel.rs | 94 +++ .../taso/{hugr_pq.rs => hugr_pqueue.rs} | 8 +- src/optimiser/taso/qtz_circuit.rs | 11 +- src/optimiser/taso/worker.rs | 31 + src/rewrite/ecc_rewriter.rs | 9 +- taso-optimiser/src/main.rs | 28 +- 10 files changed, 425 insertions(+), 317 deletions(-) create mode 100644 src/optimiser/taso/hugr_pchannel.rs rename src/optimiser/taso/{hugr_pq.rs => hugr_pqueue.rs} (89%) create mode 100644 src/optimiser/taso/worker.rs diff --git a/Cargo.toml b/Cargo.toml index 41ba49bb..fa4abbfc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -43,6 +43,7 @@ csv = { version = "1.2.2" } chrono = { version ="0.4.30" } bytemuck = "1.14.0" stringreader = "0.1.1" +crossbeam-channel = "0.5.8" [features] pyo3 = [ diff --git a/compile-matcher/src/main.rs b/compile-matcher/src/main.rs index 42c4a757..64f3c6d5 100644 --- a/compile-matcher/src/main.rs +++ b/compile-matcher/src/main.rs @@ -1,5 +1,6 @@ use std::fs; use std::path::Path; +use std::process::exit; use clap::Parser; use itertools::Itertools; @@ -42,7 +43,13 @@ fn main() { let all_circs = if input_path.is_file() { // Input is an ECC file in JSON format - let eccs = load_eccs_json_file(input_path); + let Ok(eccs) = load_eccs_json_file(input_path) else { + eprintln!( + "Unable to load ECC file {:?}. Is it a JSON file of Quartz-generated ECCs?", + input_path + ); + exit(1); + }; eccs.into_iter() .flat_map(|ecc| ecc.into_circuits()) .collect_vec() diff --git a/src/optimiser/taso.rs b/src/optimiser/taso.rs index b1b07031..6a48d942 100644 --- a/src/optimiser/taso.rs +++ b/src/optimiser/taso.rs @@ -12,25 +12,28 @@ //! it gets too large. mod eq_circ_class; -mod hugr_pq; +mod hugr_pchannel; +mod hugr_pqueue; mod qtz_circuit; +mod worker; +use crossbeam_channel::select; pub use eq_circ_class::{load_eccs_json_file, EqCircClass}; -use std::sync::mpsc::{self, Receiver, SyncSender}; -use std::thread::{self, JoinHandle}; -use std::time::Instant; +use std::num::NonZeroUsize; +use std::time::{Duration, Instant}; use std::{fs, io}; use fxhash::FxHashSet; -use hugr::{Hugr, HugrView}; -use itertools::{izip, Itertools}; +use hugr::Hugr; use crate::circuit::CircuitHash; use crate::json::save_tk1_json_writer; use crate::rewrite::strategy::RewriteStrategy; use crate::rewrite::Rewriter; -use hugr_pq::{Entry, HugrPQ}; +use hugr_pqueue::{Entry, HugrPQ}; + +use self::hugr_pchannel::HugrPriorityChannel; /// Logging configuration for the TASO optimiser. #[derive(Default)] @@ -76,6 +79,7 @@ impl<'w> LogConfig<'w> { /// /// [Quartz]: https://arxiv.org/abs/2204.09033 /// [TASO]: https://dl.acm.org/doi/10.1145/3341301.3359630 +#[derive(Clone, Debug)] pub struct TasoOptimiser { rewriter: R, strategy: S, @@ -91,24 +95,19 @@ impl TasoOptimiser { cost, } } +} +impl TasoOptimiser +where + R: Rewriter + Send + Clone + 'static, + S: RewriteStrategy + Send + Clone + 'static, + C: Fn(&Hugr) -> usize + Send + Sync + Clone + 'static, +{ /// Run the TASO optimiser on a circuit. /// /// A timeout (in seconds) can be provided. - pub fn optimise(&self, circ: &Hugr, timeout: Option) -> Hugr - where - R: Rewriter, - S: RewriteStrategy, - C: Fn(&Hugr) -> usize, - { - taso( - circ, - &self.rewriter, - &self.strategy, - &self.cost, - Default::default(), - timeout, - ) + pub fn optimise(&self, circ: &Hugr, timeout: Option, n_threads: NonZeroUsize) -> Hugr { + self.optimise_with_log(circ, Default::default(), timeout, n_threads) } /// Run the TASO optimiser on a circuit with logging activated. @@ -119,20 +118,12 @@ impl TasoOptimiser { circ: &Hugr, log_config: LogConfig, timeout: Option, - ) -> Hugr - where - R: Rewriter, - S: RewriteStrategy, - C: Fn(&Hugr) -> usize, - { - taso( - circ, - &self.rewriter, - &self.strategy, - &self.cost, - log_config, - timeout, - ) + n_threads: NonZeroUsize, + ) -> Hugr { + match n_threads.get() { + 1 => self.taso(circ, log_config, timeout), + _ => self.taso_multithreaded(circ, log_config, timeout, n_threads), + } } /// Run the TASO optimiser on a circuit with default logging. @@ -145,298 +136,257 @@ impl TasoOptimiser { /// If the creation of any of these files fails, an error is returned. /// /// A timeout (in seconds) can be provided. - pub fn optimise_with_default_log(&self, circ: &Hugr, timeout: Option) -> io::Result - where - R: Rewriter, - S: RewriteStrategy, - C: Fn(&Hugr) -> usize, - { + pub fn optimise_with_default_log( + &self, + circ: &Hugr, + timeout: Option, + n_threads: NonZeroUsize, + ) -> io::Result { let final_circ_json = fs::File::create("final_circ.json")?; let circ_candidates_csv = fs::File::create("best_circs.csv")?; let progress_log = fs::File::create("taso-optimisation.log")?; let log_config = LogConfig::new(final_circ_json, circ_candidates_csv, progress_log); - Ok(self.optimise_with_log(circ, log_config, timeout)) - } -} - -#[cfg(feature = "portmatching")] -mod taso_default { - use crate::circuit::Circuit; - use crate::rewrite::strategy::ExhaustiveRewriteStrategy; - use crate::rewrite::ECCRewriter; - - use super::*; - - impl TasoOptimiser usize> { - /// A sane default optimiser using the given ECC sets. - pub fn default_with_eccs_json_file(eccs_path: impl AsRef) -> Self { - let rewriter = ECCRewriter::from_eccs_json_file(eccs_path); - let strategy = ExhaustiveRewriteStrategy::default(); - Self::new(rewriter, strategy, |c| c.num_gates()) - } + Ok(self.optimise_with_log(circ, log_config, timeout, n_threads)) } -} - -fn taso( - circ: &Hugr, - rewriter: &impl Rewriter, - strategy: &impl RewriteStrategy, - cost: impl Fn(&Hugr) -> usize, - mut log_config: LogConfig, - timeout: Option, -) -> Hugr { - let start_time = Instant::now(); - - let mut log_candidates = log_config.circ_candidates_csv.map(csv::Writer::from_writer); - - let mut best_circ = circ.clone(); - let mut best_circ_cost = cost(circ); - log_best(best_circ_cost, log_candidates.as_mut()).unwrap(); - // Hash of seen circuits. Dot not store circuits as this map gets huge - let mut seen_hashes: FxHashSet<_> = FromIterator::from_iter([(circ.circuit_hash())]); - - // The priority queue of circuits to be processed (this should not get big) - let mut pq = HugrPQ::new(&cost); - - pq.push(circ.clone()); - - let mut circ_cnt = 0; - while let Some(Entry { circ, cost, .. }) = pq.pop() { - if cost < best_circ_cost { - best_circ = circ.clone(); - best_circ_cost = cost; - log_best(best_circ_cost, log_candidates.as_mut()).unwrap(); - } - - let rewrites = rewriter.get_rewrites(&circ); - for new_circ in strategy.apply_rewrites(rewrites, &circ) { - let new_circ_hash = new_circ.circuit_hash(); - circ_cnt += 1; - if circ_cnt % 1000 == 0 { - log_progress( - log_config.progress_log.as_mut(), - circ_cnt, - &pq, - &seen_hashes, - ) - .expect("Failed to write to progress log"); - } - if seen_hashes.contains(&new_circ_hash) { - continue; - } - pq.push_with_hash_unchecked(new_circ, new_circ_hash); - seen_hashes.insert(new_circ_hash); - } + fn taso(&self, circ: &Hugr, mut log_config: LogConfig, timeout: Option) -> Hugr { + let start_time = Instant::now(); - if pq.len() >= 10000 { - // Haircut to keep the queue size manageable - pq.truncate(5000); - } + let mut log_candidates = log_config.circ_candidates_csv.map(csv::Writer::from_writer); - if let Some(timeout) = timeout { - if start_time.elapsed().as_secs() > timeout { - println!("Timeout"); - break; - } - } - } + let mut best_circ = circ.clone(); + let mut best_circ_cost = (self.cost)(circ); + log_best(best_circ_cost, log_candidates.as_mut()).unwrap(); - log_final( - &best_circ, - log_config.progress_log.as_mut(), - log_config.final_circ_json.as_mut(), - &cost, - ) - .expect("Failed to write to progress log and/or final circuit JSON"); + // Hash of seen circuits. Dot not store circuits as this map gets huge + let mut seen_hashes: FxHashSet<_> = FromIterator::from_iter([(circ.circuit_hash())]); - best_circ -} + // The priority queue of circuits to be processed (this should not get big) + const PRIORITY_QUEUE_CAPACITY: usize = 10_000; + let mut pq = HugrPQ::with_capacity(&self.cost, PRIORITY_QUEUE_CAPACITY); + pq.push(circ.clone()); -/// Run the TASO optimiser on a circuit. -/// -/// The optimiser will repeatedly rewrite the circuit using the rewriter and -/// the rewrite strategy, optimising the circuit according to the cost function -/// provided. Optionally, a timeout (in seconds) can be provided. -/// -/// A log of the successive best candidate circuits can be found in the file -/// `best_circs.csv`. In addition, the final best circuit is retrievable in the -/// files `final_best_circ.gv` and `final_best_circ.json`. -/// -/// This is the multi-threaded version of the optimiser. See [`TasoOptimiser`] for the -/// single-threaded version. -// TODO Support MPSC and expose in API -#[allow(dead_code)] -fn taso_mpsc( - circ: Hugr, - rewriter: impl Rewriter + Send + Clone + 'static, - strategy: impl RewriteStrategy + Send + Clone + 'static, - cost: impl Fn(&Hugr) -> usize + Send + Sync, - log_config: LogConfig, - timeout: Option, - n_threads: usize, -) -> Hugr { - let start_time = Instant::now(); - - let mut log_candidates = log_config.circ_candidates_csv.map(csv::Writer::from_writer); - - println!("Spinning up {n_threads} threads"); - - // channel for sending circuits from threads back to main - let (t_main, r_main) = mpsc::sync_channel(n_threads * 100); - - let mut best_circ = circ.clone(); - let mut best_circ_cost = cost(&best_circ); - let circ_hash = circ.circuit_hash(); - log_best(best_circ_cost, log_candidates.as_mut()).unwrap(); - - // Hash of seen circuits. Dot not store circuits as this map gets huge - let mut seen_hashes: FxHashSet<_> = FromIterator::from_iter([(circ_hash)]); - - // The priority queue of circuits to be processed (this should not get big) - let mut pq = HugrPQ::new(&cost); - pq.push(circ); - - // each thread scans for rewrites using all the patterns and - // sends rewritten circuits back to main - let (joins, threads_tx, signal_new_data): (Vec<_>, Vec<_>, Vec<_>) = (0..n_threads) - .map(|_| spawn_pattern_matching_thread(t_main.clone(), rewriter.clone(), strategy.clone())) - .multiunzip(); - - let mut cycle_inds = (0..n_threads).cycle(); - let mut threads_empty = vec![true; n_threads]; - - let mut circ_cnt = 0; - loop { - // Fill each thread workqueue with data from pq - while let Some(Entry { - circ, - cost: &cost, - hash, - }) = pq.peek() - { + let mut circ_cnt = 1; + while let Some(Entry { circ, cost, .. }) = pq.pop() { if cost < best_circ_cost { best_circ = circ.clone(); best_circ_cost = cost; log_best(best_circ_cost, log_candidates.as_mut()).unwrap(); - // Now we only care about smaller circuits - seen_hashes.clear(); - seen_hashes.insert(hash); - } - // try to send to first available thread - // TODO: Consider using crossbeam-channel - if let Some(next_ind) = cycle_inds.by_ref().take(n_threads).find(|next_ind| { - let tx = &threads_tx[*next_ind]; - tx.try_send(Some(circ.clone())).is_ok() - }) { - pq.pop(); - // Unblock thread if waiting - let _ = signal_new_data[next_ind].try_recv(); - threads_empty[next_ind] = false; - } else { - // All send channels are full, continue - break; } - } - // Receive data from threads, add to pq - // We compute the hashes in the threads because it's expensive - while let Ok(received) = r_main.try_recv() { - let Some((circ_hash, circ)) = received else { - panic!("A thread panicked"); - }; - circ_cnt += 1; - if circ_cnt % 1000 == 0 { - println!("{circ_cnt} circuits..."); - println!("Queue size: {} circuits", pq.len()); - println!("Total seen: {} circuits", seen_hashes.len()); - } - if seen_hashes.contains(&circ_hash) { - continue; + let rewrites = self.rewriter.get_rewrites(&circ); + for new_circ in self.strategy.apply_rewrites(rewrites, &circ) { + let new_circ_hash = new_circ.circuit_hash(); + circ_cnt += 1; + if circ_cnt % 1000 == 0 { + log_progress( + log_config.progress_log.as_mut(), + circ_cnt, + Some(&pq), + &seen_hashes, + ) + .expect("Failed to write to progress log"); + } + if seen_hashes.contains(&new_circ_hash) { + continue; + } + pq.push_with_hash_unchecked(new_circ, new_circ_hash); + seen_hashes.insert(new_circ_hash); } - pq.push_with_hash_unchecked(circ, circ_hash); - seen_hashes.insert(circ_hash); - } - // Check if all threads are waiting for new data - for (is_waiting, is_empty) in signal_new_data.iter().zip(threads_empty.iter_mut()) { - if is_waiting.try_recv().is_ok() { - *is_empty = true; + if pq.len() >= PRIORITY_QUEUE_CAPACITY { + // Haircut to keep the queue size manageable + pq.truncate(PRIORITY_QUEUE_CAPACITY / 2); } - } - // If everyone is waiting and we do not have new data, we are done - if pq.is_empty() && threads_empty.iter().all(|&x| x) { - break; - } - if let Some(timeout) = timeout { - if start_time.elapsed().as_secs() > timeout { - println!("Timeout"); - break; + + if let Some(timeout) = timeout { + if start_time.elapsed().as_secs() > timeout { + println!("Timeout"); + break; + } } } - if pq.len() >= 10000 { - // Haircut to keep the queue size manageable - pq.truncate(5000); - } - } - println!("Tried {circ_cnt} circuits"); - println!("Joining"); + log_processing_end(circ_cnt, false); - for (join, tx, data_tx) in izip!(joins, threads_tx, signal_new_data) { - // tell all the threads we're done and join the threads - tx.send(None).unwrap(); - let _ = data_tx.try_recv(); - join.join().unwrap(); - } + log_final( + &best_circ, + log_config.progress_log.as_mut(), + log_config.final_circ_json.as_mut(), + &self.cost, + ) + .expect("Failed to write to progress log and/or final circuit JSON"); - println!("END RESULT: {}", cost(&best_circ)); - fs::write("final_best_circ.gv", best_circ.dot_string()).unwrap(); - fs::write( - "final_best_circ.json", - serde_json::to_vec(&best_circ).unwrap(), - ) - .unwrap(); - best_circ -} + best_circ + } -fn spawn_pattern_matching_thread( - tx_main: SyncSender>, - rewriter: impl Rewriter + Send + 'static, - strategy: impl RewriteStrategy + Send + 'static, -) -> (JoinHandle<()>, SyncSender>, Receiver<()>) { - // channel for sending circuits to each thread - let (tx_thread, rx) = mpsc::sync_channel(1000); - // A flag to wait until new data - let (wait_new_data, signal_new_data) = mpsc::sync_channel(0); - - let jn = thread::spawn(move || { + /// Run the TASO optimiser on a circuit, using multiple threads. + /// + /// This is the multi-threaded version of [`taso`]. See [`TasoOptimiser`] for + /// more details. + fn taso_multithreaded( + &self, + circ: &Hugr, + mut log_config: LogConfig, + timeout: Option, + n_threads: NonZeroUsize, + ) -> Hugr { + let n_threads: usize = n_threads.get(); + const PRIORITY_QUEUE_CAPACITY: usize = 10_000; + + let mut log_candidates = log_config.circ_candidates_csv.map(csv::Writer::from_writer); + + // multi-consumer priority channel for queuing circuits to be processed by the workers + let (tx_work, rx_work) = + HugrPriorityChannel::init((self.cost).clone(), PRIORITY_QUEUE_CAPACITY * n_threads); + // channel for sending circuits from threads back to main + let (tx_result, rx_result) = crossbeam_channel::unbounded(); + + let initial_circ_hash = circ.circuit_hash(); + let mut best_circ = circ.clone(); + let mut best_circ_cost = (self.cost)(&best_circ); + log_best(best_circ_cost, log_candidates.as_mut()).unwrap(); + + // Hash of seen circuits. Dot not store circuits as this map gets huge + let mut seen_hashes: FxHashSet<_> = FromIterator::from_iter([(initial_circ_hash)]); + + // Each worker waits for circuits to scan for rewrites using all the + // patterns and sends the results back to main. + let joins: Vec<_> = (0..n_threads) + .map(|_| { + worker::spawn_pattern_matching_thread( + rx_work.clone(), + tx_result.clone(), + self.rewriter.clone(), + self.strategy.clone(), + ) + }) + .collect(); + // Drop our copy of the worker channels, so we don't count as a + // connected worker. + drop(rx_work); + drop(tx_result); + + // Queue the initial circuit + tx_work + .send(vec![(initial_circ_hash, circ.clone())]) + .unwrap(); + + // A counter of circuits seen. + let mut circ_cnt = 1; + + // A counter of jobs sent to the workers. + #[allow(unused)] + let mut jobs_sent = 0usize; + // A counter of completed jobs received from the workers. + #[allow(unused)] + let mut jobs_completed = 0usize; + // TODO: Report dropped jobs in the queue, so we can check for termination. + + // Deadline for the optimization timeout + let timeout_event = match timeout { + None => crossbeam_channel::never(), + Some(t) => crossbeam_channel::at(Instant::now() + Duration::from_secs(t)), + }; + + // Process worker results until we have seen all the circuits, or we run + // out of time. loop { - if let Ok(received) = rx.try_recv() { - let Some(sent_hugr): Option = received else { - // Terminate thread + select! { + recv(rx_result) -> msg => { + match msg { + Ok(hashed_circs) => { + jobs_completed += 1; + for (circ_hash, circ) in &hashed_circs { + circ_cnt += 1; + if circ_cnt % 1000 == 0 { + // TODO: Add a minimum time between logs + log_progress::<_,u64,usize>(log_config.progress_log.as_mut(), circ_cnt, None, &seen_hashes) + .expect("Failed to write to progress log"); + } + if !seen_hashes.insert(*circ_hash) { + continue; + } + + let cost = (self.cost)(circ); + + // Check if we got a new best circuit + if cost < best_circ_cost { + best_circ = circ.clone(); + best_circ_cost = cost; + log_best(best_circ_cost, log_candidates.as_mut()).unwrap(); + } + jobs_sent += 1; + } + // Fill the workqueue with data from pq + if tx_work.send(hashed_circs).is_err() { + eprintln!("All our workers panicked. Stopping optimisation."); + break; + }; + + // If there is no more data to process, we are done. + // + // TODO: Report dropped jobs in the workers, so we can check for termination. + //if jobs_sent == jobs_completed { + // break 'main; + //}; + }, + Err(crossbeam_channel::RecvError) => { + eprintln!("All our workers panicked. Stopping optimisation."); + break; + } + } + } + recv(timeout_event) -> _ => { + println!("Timeout"); break; - }; - let rewrites = rewriter.get_rewrites(&sent_hugr); - for new_circ in strategy.apply_rewrites(rewrites, &sent_hugr) { - let new_circ_hash = new_circ.circuit_hash(); - tx_main.send(Some((new_circ_hash, new_circ))).unwrap(); } - } else { - // We are out of work, wait for new data - wait_new_data.send(()).unwrap(); } } - }); - (jn, tx_thread, signal_new_data) + log_processing_end(circ_cnt, true); + + // Drop the channel so the threads know to stop. + drop(tx_work); + let _ = joins; // joins.into_iter().for_each(|j| j.join().unwrap()); + + log_final( + &best_circ, + log_config.progress_log.as_mut(), + log_config.final_circ_json.as_mut(), + &self.cost, + ) + .expect("Failed to write to progress log and/or final circuit JSON"); + + best_circ + } +} + +#[cfg(feature = "portmatching")] +mod taso_default { + use crate::circuit::Circuit; + use crate::rewrite::strategy::ExhaustiveRewriteStrategy; + use crate::rewrite::ECCRewriter; + + use super::*; + + impl TasoOptimiser usize> { + /// A sane default optimiser using the given ECC sets. + pub fn default_with_eccs_json_file( + eccs_path: impl AsRef, + ) -> io::Result { + let rewriter = ECCRewriter::try_from_eccs_json_file(eccs_path)?; + let strategy = ExhaustiveRewriteStrategy::default(); + Ok(Self::new(rewriter, strategy, |c| c.num_gates())) + } + } } /// A helper struct for logging improvements in circuit size seen during the /// TASO execution. // // TODO: Replace this fixed logging. Report back intermediate results. -#[derive(serde::Serialize, Debug)] +#[derive(serde::Serialize, Clone, Debug)] struct BestCircSer { circ_len: usize, time: String, @@ -458,15 +408,25 @@ fn log_best(cbest: usize, wtr: Option<&mut csv::Writer>) -> io: wtr.flush() } +fn log_processing_end(circuit_count: usize, needs_joining: bool) { + println!("END"); + println!("Tried {circuit_count} circuits"); + if needs_joining { + println!("Joining"); + } +} + fn log_progress( wr: Option<&mut W>, circ_cnt: usize, - pq: &HugrPQ, + pq: Option<&HugrPQ>, seen_hashes: &FxHashSet, ) -> io::Result<()> { if let Some(wr) = wr { writeln!(wr, "{circ_cnt} circuits...")?; - writeln!(wr, "Queue size: {} circuits", pq.len())?; + if let Some(pq) = pq { + writeln!(wr, "Queue size: {} circuits", pq.len())?; + } writeln!(wr, "Total seen: {} circuits", seen_hashes.len())?; } Ok(()) diff --git a/src/optimiser/taso/eq_circ_class.rs b/src/optimiser/taso/eq_circ_class.rs index 865df40f..f45c2846 100644 --- a/src/optimiser/taso/eq_circ_class.rs +++ b/src/optimiser/taso/eq_circ_class.rs @@ -1,3 +1,4 @@ +use std::io; use std::path::Path; use hugr::Hugr; @@ -77,12 +78,12 @@ impl EqCircClass { } /// Load a set of equivalence classes from a JSON file. -pub fn load_eccs_json_file(path: impl AsRef) -> Vec { - let all_circs = load_ecc_set(path); +pub fn load_eccs_json_file(path: impl AsRef) -> io::Result> { + let all_circs = load_ecc_set(path)?; - all_circs + Ok(all_circs .into_values() .map(EqCircClass::from_circuits) .collect::, _>>() - .unwrap() + .unwrap()) } diff --git a/src/optimiser/taso/hugr_pchannel.rs b/src/optimiser/taso/hugr_pchannel.rs new file mode 100644 index 00000000..c3e157cf --- /dev/null +++ b/src/optimiser/taso/hugr_pchannel.rs @@ -0,0 +1,94 @@ +//! A multi-producer multi-consumer min-priority channel of Hugrs. + +use std::marker::PhantomData; +use std::thread; + +use crossbeam_channel::{select, Receiver, Sender}; +use hugr::Hugr; + +use super::hugr_pqueue::{Entry, HugrPQ}; + +/// A priority channel for HUGRs. +/// +/// Queues hugrs using a cost function `C` that produces priority values `P`. +/// +/// Uses a thread internally to orchestrate the queueing. +pub struct HugrPriorityChannel { + _phantom: PhantomData<(P, C)>, +} + +pub type Item = (u64, Hugr); + +impl HugrPriorityChannel +where + C: Fn(&Hugr) -> P + Send + Sync + 'static, +{ + /// Initialize the queueing system. + /// + /// Get back a channel on which to queue hugrs with their hash, and + /// a channel on which to receive the output. + pub fn init(cost_fn: C, queue_capacity: usize) -> (Sender>, Receiver) { + let (ins, inr) = crossbeam_channel::unbounded(); + let (outs, outr) = crossbeam_channel::bounded(0); + Self::run(inr, outs, cost_fn, queue_capacity); + (ins, outr) + } + + /// Run the queuer as a thread. + fn run( + in_channel_orig: Receiver>, + out_channel_orig: Sender<(u64, Hugr)>, + cost_fn: C, + queue_capacity: usize, + ) { + let builder = thread::Builder::new().name("priority queueing".into()); + let in_channel = in_channel_orig.clone(); + let out_channel = out_channel_orig.clone(); + let _ = builder + .spawn(move || { + // The priority queue, local to this thread. + let mut pq: HugrPQ = + HugrPQ::with_capacity(cost_fn, queue_capacity); + + loop { + if pq.is_empty() { + // Nothing queued to go out. Wait for input. + match in_channel.recv() { + Ok(new_circs) => { + for (hash, circ) in new_circs { + pq.push_with_hash_unchecked(circ, hash); + } + } + // The sender has closed the channel, we can stop. + Err(_) => break, + } + } else { + select! { + recv(in_channel) -> result => { + match result { + Ok(new_circs) => { + for (hash, circ) in new_circs { + pq.push_with_hash_unchecked(circ, hash); + } + } + // The sender has closed the channel, we can stop. + Err(_) => break, + } + } + send(out_channel, {let Entry {hash, circ, ..} = pq.pop().unwrap(); (hash, circ)}) -> result => { + match result { + Ok(()) => {}, + // The receivers have closed the channel, we can stop. + Err(_) => break, + } + } + } + } + if pq.len() >= queue_capacity { + pq.truncate(queue_capacity / 2); + } + } + }) + .unwrap(); + } +} diff --git a/src/optimiser/taso/hugr_pq.rs b/src/optimiser/taso/hugr_pqueue.rs similarity index 89% rename from src/optimiser/taso/hugr_pq.rs rename to src/optimiser/taso/hugr_pqueue.rs index 26eea0eb..eba000e0 100644 --- a/src/optimiser/taso/hugr_pq.rs +++ b/src/optimiser/taso/hugr_pqueue.rs @@ -19,20 +19,22 @@ pub(super) struct HugrPQ { pub(super) struct Entry { pub(super) circ: C, pub(super) cost: P, + #[allow(unused)] // TODO remove? pub(super) hash: H, } impl HugrPQ { - /// Create a new HugrPQ with a cost function. - pub(super) fn new(cost_fn: C) -> Self { + /// Create a new HugrPQ with a cost function and some initial capacity. + pub(super) fn with_capacity(cost_fn: C, capacity: usize) -> Self { Self { - queue: DoublePriorityQueue::new(), + queue: DoublePriorityQueue::with_capacity(capacity), hash_lookup: Default::default(), cost_fn, } } /// Reference to the minimal Hugr in the queue. + #[allow(unused)] pub(super) fn peek(&self) -> Option> { let (hash, cost) = self.queue.peek_min()?; let circ = self.hash_lookup.get(hash)?; diff --git a/src/optimiser/taso/qtz_circuit.rs b/src/optimiser/taso/qtz_circuit.rs index 5c4d01b7..a665a518 100644 --- a/src/optimiser/taso/qtz_circuit.rs +++ b/src/optimiser/taso/qtz_circuit.rs @@ -1,4 +1,5 @@ use std::collections::HashMap; +use std::io; use std::path::Path; use hugr::builder::{DFGBuilder, Dataflow, DataflowHugr}; @@ -109,19 +110,19 @@ impl From for Circuit { } } -pub(super) fn load_ecc_set(path: impl AsRef) -> HashMap> { - let jsons = std::fs::read_to_string(path).unwrap(); +pub(super) fn load_ecc_set(path: impl AsRef) -> io::Result>> { + let jsons = std::fs::read_to_string(path)?; let (_, ecc_map): (Vec<()>, HashMap>) = serde_json::from_str(&jsons).unwrap(); - ecc_map + Ok(ecc_map .into_values() .map(|datmap| { let id = datmap[0].meta.id[0].clone(); let circs = datmap.into_iter().map(|rcd| rcd.into()).collect(); (id, circs) }) - .collect() + .collect()) } #[cfg(test)] @@ -155,7 +156,7 @@ mod tests { #[cfg_attr(miri, ignore)] // Opening files is not supported in (isolated) miri fn test_read_complete() { let _ecc: HashMap> = - load_ecc_set("test_files/h_rz_cxcomplete_ECC_set.json"); + load_ecc_set("test_files/h_rz_cxcomplete_ECC_set.json").unwrap(); // ecc.values() // .flatten() diff --git a/src/optimiser/taso/worker.rs b/src/optimiser/taso/worker.rs new file mode 100644 index 00000000..4472031e --- /dev/null +++ b/src/optimiser/taso/worker.rs @@ -0,0 +1,31 @@ +//! Distributed workers for the taso optimiser. + +use std::thread::{self, JoinHandle}; + +use hugr::Hugr; + +use crate::circuit::CircuitHash; +use crate::rewrite::strategy::RewriteStrategy; +use crate::rewrite::Rewriter; + +pub fn spawn_pattern_matching_thread( + rx_work: crossbeam_channel::Receiver<(u64, Hugr)>, + tx_result: crossbeam_channel::Sender>, + rewriter: impl Rewriter + Send + 'static, + strategy: impl RewriteStrategy + Send + 'static, +) -> JoinHandle<()> { + thread::spawn(move || { + // Process work until the main thread closes the channel send or receive + // channel. + while let Ok((_hash, circ)) = rx_work.recv() { + let rewrites = rewriter.get_rewrites(&circ); + let circs = strategy.apply_rewrites(rewrites, &circ); + let hashed_circs = circs.into_iter().map(|c| (c.circuit_hash(), c)).collect(); + let send = tx_result.send(hashed_circs); + if send.is_err() { + // The main thread closed the send channel, we can stop. + break; + } + } + }) +} diff --git a/src/rewrite/ecc_rewriter.rs b/src/rewrite/ecc_rewriter.rs index 40f1b7eb..afd4ff10 100644 --- a/src/rewrite/ecc_rewriter.rs +++ b/src/rewrite/ecc_rewriter.rs @@ -15,6 +15,7 @@ use derive_more::{From, Into}; use itertools::Itertools; use portmatching::PatternID; +use std::io; use std::path::Path; use hugr::Hugr; @@ -56,9 +57,9 @@ impl ECCRewriter { /// the Quartz repository. /// /// Quartz: . - pub fn from_eccs_json_file(path: impl AsRef) -> Self { - let eccs = load_eccs_json_file(path); - Self::from_eccs(eccs) + pub fn try_from_eccs_json_file(path: impl AsRef) -> io::Result { + let eccs = load_eccs_json_file(path)?; + Ok(Self::from_eccs(eccs)) } /// Create a new rewriter from a list of equivalent circuit classes. @@ -211,7 +212,7 @@ mod tests { // In this example, all circuits are valid patterns, thus // PatternID == TargetID. let test_file = "test_files/small_eccs.json"; - let rewriter = ECCRewriter::from_eccs_json_file(test_file); + let rewriter = ECCRewriter::try_from_eccs_json_file(test_file).unwrap(); assert_eq!(rewriter.rewrite_rules.len(), rewriter.matcher.n_patterns()); assert_eq!(rewriter.targets.len(), 5 * 4 + 4 * 3); diff --git a/taso-optimiser/src/main.rs b/taso-optimiser/src/main.rs index 09b022a1..4c277f8e 100644 --- a/taso-optimiser/src/main.rs +++ b/taso-optimiser/src/main.rs @@ -1,3 +1,5 @@ +use std::num::NonZeroUsize; +use std::process::exit; use std::{fs, io, path::Path}; use clap::Parser; @@ -59,11 +61,10 @@ struct CmdLineArgs { #[arg( short = 'j', long, - default_value = "1", value_name = "N_THREADS", - help = "The number of threads to use. Currently only single-threaded TASO is supported." + help = "The number of threads to use. By default, the number of threads is equal to the number of logical cores." )] - n_threads: usize, + n_threads: Option, } fn save_tk1_json_file(path: impl AsRef, circ: &Hugr) -> Result<(), std::io::Error> { @@ -84,15 +85,24 @@ fn main() { let circ = load_tk1_json_file(input_path).unwrap(); println!("Compiling rewriter..."); - let optimiser = if opts.n_threads == 1 { - println!("Using single-threaded TASO"); - TasoOptimiser::default_with_eccs_json_file(ecc_path) - } else { - unimplemented!("Multi-threaded TASO has been disabled until fixed"); + let Ok(optimiser) = TasoOptimiser::default_with_eccs_json_file(ecc_path) else { + eprintln!( + "Unable to load ECC file {:?}. Is it a JSON file of Quartz-generated ECCs?", + ecc_path + ); + exit(1); }; + + let n_threads = opts + .n_threads + // TODO: Default to multithreading once that produces better results. + //.or_else(|| std::thread::available_parallelism().ok()) + .unwrap_or(NonZeroUsize::new(1).unwrap()); + println!("Using {n_threads} threads"); + println!("Optimising..."); let opt_circ = optimiser - .optimise_with_default_log(&circ, opts.timeout) + .optimise_with_default_log(&circ, opts.timeout, n_threads) .unwrap(); println!("Saving result");