From 4c76248da7ce3e7d2d169d39897611f49be6fbeb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Borgna?= <121866228+aborgna-q@users.noreply.github.com> Date: Thu, 21 Sep 2023 12:52:14 +0200 Subject: [PATCH 1/6] feat: Parallel taso with crossbeam (#113) Simplifies the multithreading logic by using a multi-consumer channel. Includes some small improvements to the TASO bin. Now, why is this a draft: I'm testing with `barenco_tof_5_rm` and `Nam_6_3_complete_ECC_set` on my 12-core laptop. Here are some results with a timeout of `10` seconds: | | circuits seen after 10s | best size after 10s | | ----------------- | ------ | ---- | | single-threaded | 7244 | 178 | | -j 2 | 31391 | 204 | | -j 10 | 42690 | 204 | There's something wrong somewhere :P (note that `-j N` means `N` worker threads, plus the master) --------- Co-authored-by: Luca Mondada Co-authored-by: Luca Mondada <72734770+lmondada@users.noreply.github.com> --- Cargo.toml | 1 + compile-matcher/src/main.rs | 9 +- src/optimiser/taso.rs | 542 ++++++++---------- src/optimiser/taso/eq_circ_class.rs | 9 +- src/optimiser/taso/hugr_pchannel.rs | 94 +++ .../taso/{hugr_pq.rs => hugr_pqueue.rs} | 8 +- src/optimiser/taso/qtz_circuit.rs | 11 +- src/optimiser/taso/worker.rs | 31 + src/rewrite/ecc_rewriter.rs | 9 +- taso-optimiser/src/main.rs | 28 +- 10 files changed, 425 insertions(+), 317 deletions(-) create mode 100644 src/optimiser/taso/hugr_pchannel.rs rename src/optimiser/taso/{hugr_pq.rs => hugr_pqueue.rs} (89%) create mode 100644 src/optimiser/taso/worker.rs diff --git a/Cargo.toml b/Cargo.toml index 41ba49bb..fa4abbfc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -43,6 +43,7 @@ csv = { version = "1.2.2" } chrono = { version ="0.4.30" } bytemuck = "1.14.0" stringreader = "0.1.1" +crossbeam-channel = "0.5.8" [features] pyo3 = [ diff --git a/compile-matcher/src/main.rs b/compile-matcher/src/main.rs index 42c4a757..64f3c6d5 100644 --- a/compile-matcher/src/main.rs +++ b/compile-matcher/src/main.rs @@ -1,5 +1,6 @@ use std::fs; use std::path::Path; +use std::process::exit; use clap::Parser; use itertools::Itertools; @@ -42,7 +43,13 @@ fn main() { let all_circs = if input_path.is_file() { // Input is an ECC file in JSON format - let eccs = load_eccs_json_file(input_path); + let Ok(eccs) = load_eccs_json_file(input_path) else { + eprintln!( + "Unable to load ECC file {:?}. Is it a JSON file of Quartz-generated ECCs?", + input_path + ); + exit(1); + }; eccs.into_iter() .flat_map(|ecc| ecc.into_circuits()) .collect_vec() diff --git a/src/optimiser/taso.rs b/src/optimiser/taso.rs index b1b07031..6a48d942 100644 --- a/src/optimiser/taso.rs +++ b/src/optimiser/taso.rs @@ -12,25 +12,28 @@ //! it gets too large. mod eq_circ_class; -mod hugr_pq; +mod hugr_pchannel; +mod hugr_pqueue; mod qtz_circuit; +mod worker; +use crossbeam_channel::select; pub use eq_circ_class::{load_eccs_json_file, EqCircClass}; -use std::sync::mpsc::{self, Receiver, SyncSender}; -use std::thread::{self, JoinHandle}; -use std::time::Instant; +use std::num::NonZeroUsize; +use std::time::{Duration, Instant}; use std::{fs, io}; use fxhash::FxHashSet; -use hugr::{Hugr, HugrView}; -use itertools::{izip, Itertools}; +use hugr::Hugr; use crate::circuit::CircuitHash; use crate::json::save_tk1_json_writer; use crate::rewrite::strategy::RewriteStrategy; use crate::rewrite::Rewriter; -use hugr_pq::{Entry, HugrPQ}; +use hugr_pqueue::{Entry, HugrPQ}; + +use self::hugr_pchannel::HugrPriorityChannel; /// Logging configuration for the TASO optimiser. #[derive(Default)] @@ -76,6 +79,7 @@ impl<'w> LogConfig<'w> { /// /// [Quartz]: https://arxiv.org/abs/2204.09033 /// [TASO]: https://dl.acm.org/doi/10.1145/3341301.3359630 +#[derive(Clone, Debug)] pub struct TasoOptimiser { rewriter: R, strategy: S, @@ -91,24 +95,19 @@ impl TasoOptimiser { cost, } } +} +impl TasoOptimiser +where + R: Rewriter + Send + Clone + 'static, + S: RewriteStrategy + Send + Clone + 'static, + C: Fn(&Hugr) -> usize + Send + Sync + Clone + 'static, +{ /// Run the TASO optimiser on a circuit. /// /// A timeout (in seconds) can be provided. - pub fn optimise(&self, circ: &Hugr, timeout: Option) -> Hugr - where - R: Rewriter, - S: RewriteStrategy, - C: Fn(&Hugr) -> usize, - { - taso( - circ, - &self.rewriter, - &self.strategy, - &self.cost, - Default::default(), - timeout, - ) + pub fn optimise(&self, circ: &Hugr, timeout: Option, n_threads: NonZeroUsize) -> Hugr { + self.optimise_with_log(circ, Default::default(), timeout, n_threads) } /// Run the TASO optimiser on a circuit with logging activated. @@ -119,20 +118,12 @@ impl TasoOptimiser { circ: &Hugr, log_config: LogConfig, timeout: Option, - ) -> Hugr - where - R: Rewriter, - S: RewriteStrategy, - C: Fn(&Hugr) -> usize, - { - taso( - circ, - &self.rewriter, - &self.strategy, - &self.cost, - log_config, - timeout, - ) + n_threads: NonZeroUsize, + ) -> Hugr { + match n_threads.get() { + 1 => self.taso(circ, log_config, timeout), + _ => self.taso_multithreaded(circ, log_config, timeout, n_threads), + } } /// Run the TASO optimiser on a circuit with default logging. @@ -145,298 +136,257 @@ impl TasoOptimiser { /// If the creation of any of these files fails, an error is returned. /// /// A timeout (in seconds) can be provided. - pub fn optimise_with_default_log(&self, circ: &Hugr, timeout: Option) -> io::Result - where - R: Rewriter, - S: RewriteStrategy, - C: Fn(&Hugr) -> usize, - { + pub fn optimise_with_default_log( + &self, + circ: &Hugr, + timeout: Option, + n_threads: NonZeroUsize, + ) -> io::Result { let final_circ_json = fs::File::create("final_circ.json")?; let circ_candidates_csv = fs::File::create("best_circs.csv")?; let progress_log = fs::File::create("taso-optimisation.log")?; let log_config = LogConfig::new(final_circ_json, circ_candidates_csv, progress_log); - Ok(self.optimise_with_log(circ, log_config, timeout)) - } -} - -#[cfg(feature = "portmatching")] -mod taso_default { - use crate::circuit::Circuit; - use crate::rewrite::strategy::ExhaustiveRewriteStrategy; - use crate::rewrite::ECCRewriter; - - use super::*; - - impl TasoOptimiser usize> { - /// A sane default optimiser using the given ECC sets. - pub fn default_with_eccs_json_file(eccs_path: impl AsRef) -> Self { - let rewriter = ECCRewriter::from_eccs_json_file(eccs_path); - let strategy = ExhaustiveRewriteStrategy::default(); - Self::new(rewriter, strategy, |c| c.num_gates()) - } + Ok(self.optimise_with_log(circ, log_config, timeout, n_threads)) } -} - -fn taso( - circ: &Hugr, - rewriter: &impl Rewriter, - strategy: &impl RewriteStrategy, - cost: impl Fn(&Hugr) -> usize, - mut log_config: LogConfig, - timeout: Option, -) -> Hugr { - let start_time = Instant::now(); - - let mut log_candidates = log_config.circ_candidates_csv.map(csv::Writer::from_writer); - - let mut best_circ = circ.clone(); - let mut best_circ_cost = cost(circ); - log_best(best_circ_cost, log_candidates.as_mut()).unwrap(); - // Hash of seen circuits. Dot not store circuits as this map gets huge - let mut seen_hashes: FxHashSet<_> = FromIterator::from_iter([(circ.circuit_hash())]); - - // The priority queue of circuits to be processed (this should not get big) - let mut pq = HugrPQ::new(&cost); - - pq.push(circ.clone()); - - let mut circ_cnt = 0; - while let Some(Entry { circ, cost, .. }) = pq.pop() { - if cost < best_circ_cost { - best_circ = circ.clone(); - best_circ_cost = cost; - log_best(best_circ_cost, log_candidates.as_mut()).unwrap(); - } - - let rewrites = rewriter.get_rewrites(&circ); - for new_circ in strategy.apply_rewrites(rewrites, &circ) { - let new_circ_hash = new_circ.circuit_hash(); - circ_cnt += 1; - if circ_cnt % 1000 == 0 { - log_progress( - log_config.progress_log.as_mut(), - circ_cnt, - &pq, - &seen_hashes, - ) - .expect("Failed to write to progress log"); - } - if seen_hashes.contains(&new_circ_hash) { - continue; - } - pq.push_with_hash_unchecked(new_circ, new_circ_hash); - seen_hashes.insert(new_circ_hash); - } + fn taso(&self, circ: &Hugr, mut log_config: LogConfig, timeout: Option) -> Hugr { + let start_time = Instant::now(); - if pq.len() >= 10000 { - // Haircut to keep the queue size manageable - pq.truncate(5000); - } + let mut log_candidates = log_config.circ_candidates_csv.map(csv::Writer::from_writer); - if let Some(timeout) = timeout { - if start_time.elapsed().as_secs() > timeout { - println!("Timeout"); - break; - } - } - } + let mut best_circ = circ.clone(); + let mut best_circ_cost = (self.cost)(circ); + log_best(best_circ_cost, log_candidates.as_mut()).unwrap(); - log_final( - &best_circ, - log_config.progress_log.as_mut(), - log_config.final_circ_json.as_mut(), - &cost, - ) - .expect("Failed to write to progress log and/or final circuit JSON"); + // Hash of seen circuits. Dot not store circuits as this map gets huge + let mut seen_hashes: FxHashSet<_> = FromIterator::from_iter([(circ.circuit_hash())]); - best_circ -} + // The priority queue of circuits to be processed (this should not get big) + const PRIORITY_QUEUE_CAPACITY: usize = 10_000; + let mut pq = HugrPQ::with_capacity(&self.cost, PRIORITY_QUEUE_CAPACITY); + pq.push(circ.clone()); -/// Run the TASO optimiser on a circuit. -/// -/// The optimiser will repeatedly rewrite the circuit using the rewriter and -/// the rewrite strategy, optimising the circuit according to the cost function -/// provided. Optionally, a timeout (in seconds) can be provided. -/// -/// A log of the successive best candidate circuits can be found in the file -/// `best_circs.csv`. In addition, the final best circuit is retrievable in the -/// files `final_best_circ.gv` and `final_best_circ.json`. -/// -/// This is the multi-threaded version of the optimiser. See [`TasoOptimiser`] for the -/// single-threaded version. -// TODO Support MPSC and expose in API -#[allow(dead_code)] -fn taso_mpsc( - circ: Hugr, - rewriter: impl Rewriter + Send + Clone + 'static, - strategy: impl RewriteStrategy + Send + Clone + 'static, - cost: impl Fn(&Hugr) -> usize + Send + Sync, - log_config: LogConfig, - timeout: Option, - n_threads: usize, -) -> Hugr { - let start_time = Instant::now(); - - let mut log_candidates = log_config.circ_candidates_csv.map(csv::Writer::from_writer); - - println!("Spinning up {n_threads} threads"); - - // channel for sending circuits from threads back to main - let (t_main, r_main) = mpsc::sync_channel(n_threads * 100); - - let mut best_circ = circ.clone(); - let mut best_circ_cost = cost(&best_circ); - let circ_hash = circ.circuit_hash(); - log_best(best_circ_cost, log_candidates.as_mut()).unwrap(); - - // Hash of seen circuits. Dot not store circuits as this map gets huge - let mut seen_hashes: FxHashSet<_> = FromIterator::from_iter([(circ_hash)]); - - // The priority queue of circuits to be processed (this should not get big) - let mut pq = HugrPQ::new(&cost); - pq.push(circ); - - // each thread scans for rewrites using all the patterns and - // sends rewritten circuits back to main - let (joins, threads_tx, signal_new_data): (Vec<_>, Vec<_>, Vec<_>) = (0..n_threads) - .map(|_| spawn_pattern_matching_thread(t_main.clone(), rewriter.clone(), strategy.clone())) - .multiunzip(); - - let mut cycle_inds = (0..n_threads).cycle(); - let mut threads_empty = vec![true; n_threads]; - - let mut circ_cnt = 0; - loop { - // Fill each thread workqueue with data from pq - while let Some(Entry { - circ, - cost: &cost, - hash, - }) = pq.peek() - { + let mut circ_cnt = 1; + while let Some(Entry { circ, cost, .. }) = pq.pop() { if cost < best_circ_cost { best_circ = circ.clone(); best_circ_cost = cost; log_best(best_circ_cost, log_candidates.as_mut()).unwrap(); - // Now we only care about smaller circuits - seen_hashes.clear(); - seen_hashes.insert(hash); - } - // try to send to first available thread - // TODO: Consider using crossbeam-channel - if let Some(next_ind) = cycle_inds.by_ref().take(n_threads).find(|next_ind| { - let tx = &threads_tx[*next_ind]; - tx.try_send(Some(circ.clone())).is_ok() - }) { - pq.pop(); - // Unblock thread if waiting - let _ = signal_new_data[next_ind].try_recv(); - threads_empty[next_ind] = false; - } else { - // All send channels are full, continue - break; } - } - // Receive data from threads, add to pq - // We compute the hashes in the threads because it's expensive - while let Ok(received) = r_main.try_recv() { - let Some((circ_hash, circ)) = received else { - panic!("A thread panicked"); - }; - circ_cnt += 1; - if circ_cnt % 1000 == 0 { - println!("{circ_cnt} circuits..."); - println!("Queue size: {} circuits", pq.len()); - println!("Total seen: {} circuits", seen_hashes.len()); - } - if seen_hashes.contains(&circ_hash) { - continue; + let rewrites = self.rewriter.get_rewrites(&circ); + for new_circ in self.strategy.apply_rewrites(rewrites, &circ) { + let new_circ_hash = new_circ.circuit_hash(); + circ_cnt += 1; + if circ_cnt % 1000 == 0 { + log_progress( + log_config.progress_log.as_mut(), + circ_cnt, + Some(&pq), + &seen_hashes, + ) + .expect("Failed to write to progress log"); + } + if seen_hashes.contains(&new_circ_hash) { + continue; + } + pq.push_with_hash_unchecked(new_circ, new_circ_hash); + seen_hashes.insert(new_circ_hash); } - pq.push_with_hash_unchecked(circ, circ_hash); - seen_hashes.insert(circ_hash); - } - // Check if all threads are waiting for new data - for (is_waiting, is_empty) in signal_new_data.iter().zip(threads_empty.iter_mut()) { - if is_waiting.try_recv().is_ok() { - *is_empty = true; + if pq.len() >= PRIORITY_QUEUE_CAPACITY { + // Haircut to keep the queue size manageable + pq.truncate(PRIORITY_QUEUE_CAPACITY / 2); } - } - // If everyone is waiting and we do not have new data, we are done - if pq.is_empty() && threads_empty.iter().all(|&x| x) { - break; - } - if let Some(timeout) = timeout { - if start_time.elapsed().as_secs() > timeout { - println!("Timeout"); - break; + + if let Some(timeout) = timeout { + if start_time.elapsed().as_secs() > timeout { + println!("Timeout"); + break; + } } } - if pq.len() >= 10000 { - // Haircut to keep the queue size manageable - pq.truncate(5000); - } - } - println!("Tried {circ_cnt} circuits"); - println!("Joining"); + log_processing_end(circ_cnt, false); - for (join, tx, data_tx) in izip!(joins, threads_tx, signal_new_data) { - // tell all the threads we're done and join the threads - tx.send(None).unwrap(); - let _ = data_tx.try_recv(); - join.join().unwrap(); - } + log_final( + &best_circ, + log_config.progress_log.as_mut(), + log_config.final_circ_json.as_mut(), + &self.cost, + ) + .expect("Failed to write to progress log and/or final circuit JSON"); - println!("END RESULT: {}", cost(&best_circ)); - fs::write("final_best_circ.gv", best_circ.dot_string()).unwrap(); - fs::write( - "final_best_circ.json", - serde_json::to_vec(&best_circ).unwrap(), - ) - .unwrap(); - best_circ -} + best_circ + } -fn spawn_pattern_matching_thread( - tx_main: SyncSender>, - rewriter: impl Rewriter + Send + 'static, - strategy: impl RewriteStrategy + Send + 'static, -) -> (JoinHandle<()>, SyncSender>, Receiver<()>) { - // channel for sending circuits to each thread - let (tx_thread, rx) = mpsc::sync_channel(1000); - // A flag to wait until new data - let (wait_new_data, signal_new_data) = mpsc::sync_channel(0); - - let jn = thread::spawn(move || { + /// Run the TASO optimiser on a circuit, using multiple threads. + /// + /// This is the multi-threaded version of [`taso`]. See [`TasoOptimiser`] for + /// more details. + fn taso_multithreaded( + &self, + circ: &Hugr, + mut log_config: LogConfig, + timeout: Option, + n_threads: NonZeroUsize, + ) -> Hugr { + let n_threads: usize = n_threads.get(); + const PRIORITY_QUEUE_CAPACITY: usize = 10_000; + + let mut log_candidates = log_config.circ_candidates_csv.map(csv::Writer::from_writer); + + // multi-consumer priority channel for queuing circuits to be processed by the workers + let (tx_work, rx_work) = + HugrPriorityChannel::init((self.cost).clone(), PRIORITY_QUEUE_CAPACITY * n_threads); + // channel for sending circuits from threads back to main + let (tx_result, rx_result) = crossbeam_channel::unbounded(); + + let initial_circ_hash = circ.circuit_hash(); + let mut best_circ = circ.clone(); + let mut best_circ_cost = (self.cost)(&best_circ); + log_best(best_circ_cost, log_candidates.as_mut()).unwrap(); + + // Hash of seen circuits. Dot not store circuits as this map gets huge + let mut seen_hashes: FxHashSet<_> = FromIterator::from_iter([(initial_circ_hash)]); + + // Each worker waits for circuits to scan for rewrites using all the + // patterns and sends the results back to main. + let joins: Vec<_> = (0..n_threads) + .map(|_| { + worker::spawn_pattern_matching_thread( + rx_work.clone(), + tx_result.clone(), + self.rewriter.clone(), + self.strategy.clone(), + ) + }) + .collect(); + // Drop our copy of the worker channels, so we don't count as a + // connected worker. + drop(rx_work); + drop(tx_result); + + // Queue the initial circuit + tx_work + .send(vec![(initial_circ_hash, circ.clone())]) + .unwrap(); + + // A counter of circuits seen. + let mut circ_cnt = 1; + + // A counter of jobs sent to the workers. + #[allow(unused)] + let mut jobs_sent = 0usize; + // A counter of completed jobs received from the workers. + #[allow(unused)] + let mut jobs_completed = 0usize; + // TODO: Report dropped jobs in the queue, so we can check for termination. + + // Deadline for the optimization timeout + let timeout_event = match timeout { + None => crossbeam_channel::never(), + Some(t) => crossbeam_channel::at(Instant::now() + Duration::from_secs(t)), + }; + + // Process worker results until we have seen all the circuits, or we run + // out of time. loop { - if let Ok(received) = rx.try_recv() { - let Some(sent_hugr): Option = received else { - // Terminate thread + select! { + recv(rx_result) -> msg => { + match msg { + Ok(hashed_circs) => { + jobs_completed += 1; + for (circ_hash, circ) in &hashed_circs { + circ_cnt += 1; + if circ_cnt % 1000 == 0 { + // TODO: Add a minimum time between logs + log_progress::<_,u64,usize>(log_config.progress_log.as_mut(), circ_cnt, None, &seen_hashes) + .expect("Failed to write to progress log"); + } + if !seen_hashes.insert(*circ_hash) { + continue; + } + + let cost = (self.cost)(circ); + + // Check if we got a new best circuit + if cost < best_circ_cost { + best_circ = circ.clone(); + best_circ_cost = cost; + log_best(best_circ_cost, log_candidates.as_mut()).unwrap(); + } + jobs_sent += 1; + } + // Fill the workqueue with data from pq + if tx_work.send(hashed_circs).is_err() { + eprintln!("All our workers panicked. Stopping optimisation."); + break; + }; + + // If there is no more data to process, we are done. + // + // TODO: Report dropped jobs in the workers, so we can check for termination. + //if jobs_sent == jobs_completed { + // break 'main; + //}; + }, + Err(crossbeam_channel::RecvError) => { + eprintln!("All our workers panicked. Stopping optimisation."); + break; + } + } + } + recv(timeout_event) -> _ => { + println!("Timeout"); break; - }; - let rewrites = rewriter.get_rewrites(&sent_hugr); - for new_circ in strategy.apply_rewrites(rewrites, &sent_hugr) { - let new_circ_hash = new_circ.circuit_hash(); - tx_main.send(Some((new_circ_hash, new_circ))).unwrap(); } - } else { - // We are out of work, wait for new data - wait_new_data.send(()).unwrap(); } } - }); - (jn, tx_thread, signal_new_data) + log_processing_end(circ_cnt, true); + + // Drop the channel so the threads know to stop. + drop(tx_work); + let _ = joins; // joins.into_iter().for_each(|j| j.join().unwrap()); + + log_final( + &best_circ, + log_config.progress_log.as_mut(), + log_config.final_circ_json.as_mut(), + &self.cost, + ) + .expect("Failed to write to progress log and/or final circuit JSON"); + + best_circ + } +} + +#[cfg(feature = "portmatching")] +mod taso_default { + use crate::circuit::Circuit; + use crate::rewrite::strategy::ExhaustiveRewriteStrategy; + use crate::rewrite::ECCRewriter; + + use super::*; + + impl TasoOptimiser usize> { + /// A sane default optimiser using the given ECC sets. + pub fn default_with_eccs_json_file( + eccs_path: impl AsRef, + ) -> io::Result { + let rewriter = ECCRewriter::try_from_eccs_json_file(eccs_path)?; + let strategy = ExhaustiveRewriteStrategy::default(); + Ok(Self::new(rewriter, strategy, |c| c.num_gates())) + } + } } /// A helper struct for logging improvements in circuit size seen during the /// TASO execution. // // TODO: Replace this fixed logging. Report back intermediate results. -#[derive(serde::Serialize, Debug)] +#[derive(serde::Serialize, Clone, Debug)] struct BestCircSer { circ_len: usize, time: String, @@ -458,15 +408,25 @@ fn log_best(cbest: usize, wtr: Option<&mut csv::Writer>) -> io: wtr.flush() } +fn log_processing_end(circuit_count: usize, needs_joining: bool) { + println!("END"); + println!("Tried {circuit_count} circuits"); + if needs_joining { + println!("Joining"); + } +} + fn log_progress( wr: Option<&mut W>, circ_cnt: usize, - pq: &HugrPQ, + pq: Option<&HugrPQ>, seen_hashes: &FxHashSet, ) -> io::Result<()> { if let Some(wr) = wr { writeln!(wr, "{circ_cnt} circuits...")?; - writeln!(wr, "Queue size: {} circuits", pq.len())?; + if let Some(pq) = pq { + writeln!(wr, "Queue size: {} circuits", pq.len())?; + } writeln!(wr, "Total seen: {} circuits", seen_hashes.len())?; } Ok(()) diff --git a/src/optimiser/taso/eq_circ_class.rs b/src/optimiser/taso/eq_circ_class.rs index 865df40f..f45c2846 100644 --- a/src/optimiser/taso/eq_circ_class.rs +++ b/src/optimiser/taso/eq_circ_class.rs @@ -1,3 +1,4 @@ +use std::io; use std::path::Path; use hugr::Hugr; @@ -77,12 +78,12 @@ impl EqCircClass { } /// Load a set of equivalence classes from a JSON file. -pub fn load_eccs_json_file(path: impl AsRef) -> Vec { - let all_circs = load_ecc_set(path); +pub fn load_eccs_json_file(path: impl AsRef) -> io::Result> { + let all_circs = load_ecc_set(path)?; - all_circs + Ok(all_circs .into_values() .map(EqCircClass::from_circuits) .collect::, _>>() - .unwrap() + .unwrap()) } diff --git a/src/optimiser/taso/hugr_pchannel.rs b/src/optimiser/taso/hugr_pchannel.rs new file mode 100644 index 00000000..c3e157cf --- /dev/null +++ b/src/optimiser/taso/hugr_pchannel.rs @@ -0,0 +1,94 @@ +//! A multi-producer multi-consumer min-priority channel of Hugrs. + +use std::marker::PhantomData; +use std::thread; + +use crossbeam_channel::{select, Receiver, Sender}; +use hugr::Hugr; + +use super::hugr_pqueue::{Entry, HugrPQ}; + +/// A priority channel for HUGRs. +/// +/// Queues hugrs using a cost function `C` that produces priority values `P`. +/// +/// Uses a thread internally to orchestrate the queueing. +pub struct HugrPriorityChannel { + _phantom: PhantomData<(P, C)>, +} + +pub type Item = (u64, Hugr); + +impl HugrPriorityChannel +where + C: Fn(&Hugr) -> P + Send + Sync + 'static, +{ + /// Initialize the queueing system. + /// + /// Get back a channel on which to queue hugrs with their hash, and + /// a channel on which to receive the output. + pub fn init(cost_fn: C, queue_capacity: usize) -> (Sender>, Receiver) { + let (ins, inr) = crossbeam_channel::unbounded(); + let (outs, outr) = crossbeam_channel::bounded(0); + Self::run(inr, outs, cost_fn, queue_capacity); + (ins, outr) + } + + /// Run the queuer as a thread. + fn run( + in_channel_orig: Receiver>, + out_channel_orig: Sender<(u64, Hugr)>, + cost_fn: C, + queue_capacity: usize, + ) { + let builder = thread::Builder::new().name("priority queueing".into()); + let in_channel = in_channel_orig.clone(); + let out_channel = out_channel_orig.clone(); + let _ = builder + .spawn(move || { + // The priority queue, local to this thread. + let mut pq: HugrPQ = + HugrPQ::with_capacity(cost_fn, queue_capacity); + + loop { + if pq.is_empty() { + // Nothing queued to go out. Wait for input. + match in_channel.recv() { + Ok(new_circs) => { + for (hash, circ) in new_circs { + pq.push_with_hash_unchecked(circ, hash); + } + } + // The sender has closed the channel, we can stop. + Err(_) => break, + } + } else { + select! { + recv(in_channel) -> result => { + match result { + Ok(new_circs) => { + for (hash, circ) in new_circs { + pq.push_with_hash_unchecked(circ, hash); + } + } + // The sender has closed the channel, we can stop. + Err(_) => break, + } + } + send(out_channel, {let Entry {hash, circ, ..} = pq.pop().unwrap(); (hash, circ)}) -> result => { + match result { + Ok(()) => {}, + // The receivers have closed the channel, we can stop. + Err(_) => break, + } + } + } + } + if pq.len() >= queue_capacity { + pq.truncate(queue_capacity / 2); + } + } + }) + .unwrap(); + } +} diff --git a/src/optimiser/taso/hugr_pq.rs b/src/optimiser/taso/hugr_pqueue.rs similarity index 89% rename from src/optimiser/taso/hugr_pq.rs rename to src/optimiser/taso/hugr_pqueue.rs index 26eea0eb..eba000e0 100644 --- a/src/optimiser/taso/hugr_pq.rs +++ b/src/optimiser/taso/hugr_pqueue.rs @@ -19,20 +19,22 @@ pub(super) struct HugrPQ { pub(super) struct Entry { pub(super) circ: C, pub(super) cost: P, + #[allow(unused)] // TODO remove? pub(super) hash: H, } impl HugrPQ { - /// Create a new HugrPQ with a cost function. - pub(super) fn new(cost_fn: C) -> Self { + /// Create a new HugrPQ with a cost function and some initial capacity. + pub(super) fn with_capacity(cost_fn: C, capacity: usize) -> Self { Self { - queue: DoublePriorityQueue::new(), + queue: DoublePriorityQueue::with_capacity(capacity), hash_lookup: Default::default(), cost_fn, } } /// Reference to the minimal Hugr in the queue. + #[allow(unused)] pub(super) fn peek(&self) -> Option> { let (hash, cost) = self.queue.peek_min()?; let circ = self.hash_lookup.get(hash)?; diff --git a/src/optimiser/taso/qtz_circuit.rs b/src/optimiser/taso/qtz_circuit.rs index 5c4d01b7..a665a518 100644 --- a/src/optimiser/taso/qtz_circuit.rs +++ b/src/optimiser/taso/qtz_circuit.rs @@ -1,4 +1,5 @@ use std::collections::HashMap; +use std::io; use std::path::Path; use hugr::builder::{DFGBuilder, Dataflow, DataflowHugr}; @@ -109,19 +110,19 @@ impl From for Circuit { } } -pub(super) fn load_ecc_set(path: impl AsRef) -> HashMap> { - let jsons = std::fs::read_to_string(path).unwrap(); +pub(super) fn load_ecc_set(path: impl AsRef) -> io::Result>> { + let jsons = std::fs::read_to_string(path)?; let (_, ecc_map): (Vec<()>, HashMap>) = serde_json::from_str(&jsons).unwrap(); - ecc_map + Ok(ecc_map .into_values() .map(|datmap| { let id = datmap[0].meta.id[0].clone(); let circs = datmap.into_iter().map(|rcd| rcd.into()).collect(); (id, circs) }) - .collect() + .collect()) } #[cfg(test)] @@ -155,7 +156,7 @@ mod tests { #[cfg_attr(miri, ignore)] // Opening files is not supported in (isolated) miri fn test_read_complete() { let _ecc: HashMap> = - load_ecc_set("test_files/h_rz_cxcomplete_ECC_set.json"); + load_ecc_set("test_files/h_rz_cxcomplete_ECC_set.json").unwrap(); // ecc.values() // .flatten() diff --git a/src/optimiser/taso/worker.rs b/src/optimiser/taso/worker.rs new file mode 100644 index 00000000..4472031e --- /dev/null +++ b/src/optimiser/taso/worker.rs @@ -0,0 +1,31 @@ +//! Distributed workers for the taso optimiser. + +use std::thread::{self, JoinHandle}; + +use hugr::Hugr; + +use crate::circuit::CircuitHash; +use crate::rewrite::strategy::RewriteStrategy; +use crate::rewrite::Rewriter; + +pub fn spawn_pattern_matching_thread( + rx_work: crossbeam_channel::Receiver<(u64, Hugr)>, + tx_result: crossbeam_channel::Sender>, + rewriter: impl Rewriter + Send + 'static, + strategy: impl RewriteStrategy + Send + 'static, +) -> JoinHandle<()> { + thread::spawn(move || { + // Process work until the main thread closes the channel send or receive + // channel. + while let Ok((_hash, circ)) = rx_work.recv() { + let rewrites = rewriter.get_rewrites(&circ); + let circs = strategy.apply_rewrites(rewrites, &circ); + let hashed_circs = circs.into_iter().map(|c| (c.circuit_hash(), c)).collect(); + let send = tx_result.send(hashed_circs); + if send.is_err() { + // The main thread closed the send channel, we can stop. + break; + } + } + }) +} diff --git a/src/rewrite/ecc_rewriter.rs b/src/rewrite/ecc_rewriter.rs index 40f1b7eb..afd4ff10 100644 --- a/src/rewrite/ecc_rewriter.rs +++ b/src/rewrite/ecc_rewriter.rs @@ -15,6 +15,7 @@ use derive_more::{From, Into}; use itertools::Itertools; use portmatching::PatternID; +use std::io; use std::path::Path; use hugr::Hugr; @@ -56,9 +57,9 @@ impl ECCRewriter { /// the Quartz repository. /// /// Quartz: . - pub fn from_eccs_json_file(path: impl AsRef) -> Self { - let eccs = load_eccs_json_file(path); - Self::from_eccs(eccs) + pub fn try_from_eccs_json_file(path: impl AsRef) -> io::Result { + let eccs = load_eccs_json_file(path)?; + Ok(Self::from_eccs(eccs)) } /// Create a new rewriter from a list of equivalent circuit classes. @@ -211,7 +212,7 @@ mod tests { // In this example, all circuits are valid patterns, thus // PatternID == TargetID. let test_file = "test_files/small_eccs.json"; - let rewriter = ECCRewriter::from_eccs_json_file(test_file); + let rewriter = ECCRewriter::try_from_eccs_json_file(test_file).unwrap(); assert_eq!(rewriter.rewrite_rules.len(), rewriter.matcher.n_patterns()); assert_eq!(rewriter.targets.len(), 5 * 4 + 4 * 3); diff --git a/taso-optimiser/src/main.rs b/taso-optimiser/src/main.rs index 09b022a1..4c277f8e 100644 --- a/taso-optimiser/src/main.rs +++ b/taso-optimiser/src/main.rs @@ -1,3 +1,5 @@ +use std::num::NonZeroUsize; +use std::process::exit; use std::{fs, io, path::Path}; use clap::Parser; @@ -59,11 +61,10 @@ struct CmdLineArgs { #[arg( short = 'j', long, - default_value = "1", value_name = "N_THREADS", - help = "The number of threads to use. Currently only single-threaded TASO is supported." + help = "The number of threads to use. By default, the number of threads is equal to the number of logical cores." )] - n_threads: usize, + n_threads: Option, } fn save_tk1_json_file(path: impl AsRef, circ: &Hugr) -> Result<(), std::io::Error> { @@ -84,15 +85,24 @@ fn main() { let circ = load_tk1_json_file(input_path).unwrap(); println!("Compiling rewriter..."); - let optimiser = if opts.n_threads == 1 { - println!("Using single-threaded TASO"); - TasoOptimiser::default_with_eccs_json_file(ecc_path) - } else { - unimplemented!("Multi-threaded TASO has been disabled until fixed"); + let Ok(optimiser) = TasoOptimiser::default_with_eccs_json_file(ecc_path) else { + eprintln!( + "Unable to load ECC file {:?}. Is it a JSON file of Quartz-generated ECCs?", + ecc_path + ); + exit(1); }; + + let n_threads = opts + .n_threads + // TODO: Default to multithreading once that produces better results. + //.or_else(|| std::thread::available_parallelism().ok()) + .unwrap_or(NonZeroUsize::new(1).unwrap()); + println!("Using {n_threads} threads"); + println!("Optimising..."); let opt_circ = optimiser - .optimise_with_default_log(&circ, opts.timeout) + .optimise_with_default_log(&circ, opts.timeout, n_threads) .unwrap(); println!("Saving result"); From 2aeba3efe3fbbef205dc76b1136ede4ecee680ae Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Borgna?= <121866228+aborgna-q@users.noreply.github.com> Date: Thu, 21 Sep 2023 16:51:13 +0200 Subject: [PATCH 2/6] feat: add more command helper methods (#125) - qubits units iterators - linear_unit_port query - check if a port is linear - implement traits without requiring bounds on Circ Also, publicly export `circuit::units` (this required adding a missing doc). Closes #124. Closes #123 --- src/circuit.rs | 2 +- src/circuit/command.rs | 58 ++++++++++++++++++++++++++++++++++++- src/circuit/units/filter.rs | 1 + 3 files changed, 59 insertions(+), 2 deletions(-) diff --git a/src/circuit.rs b/src/circuit.rs index 7ea2aeef..8f9e9a22 100644 --- a/src/circuit.rs +++ b/src/circuit.rs @@ -2,7 +2,7 @@ pub mod command; mod hash; -mod units; +pub mod units; pub use command::{Command, CommandIterator}; pub use hash::CircuitHash; diff --git a/src/circuit/command.rs b/src/circuit/command.rs index 7ca21054..b830b313 100644 --- a/src/circuit/command.rs +++ b/src/circuit/command.rs @@ -19,7 +19,6 @@ pub use hugr::types::{EdgeKind, Signature, Type, TypeRow}; pub use hugr::{Direction, Node, Port, Wire}; /// An operation applied to specific wires. -#[derive(Eq, PartialOrd, Ord, Hash)] pub struct Command<'circ, Circ> { /// The circuit. circ: &'circ Circ, @@ -63,6 +62,24 @@ impl<'circ, Circ: Circuit> Command<'circ, Circ> { Units::new(self.circ, self.node, direction, self).filter_units::() } + /// Returns the linear units of this command in a given direction. + #[inline] + pub fn qubits(&self, direction: Direction) -> FilteredUnits { + Units::new(self.circ, self.node, direction, self).filter_units::() + } + + /// Returns the linear units of this command in a given direction. + #[inline] + pub fn input_qubits(&self) -> FilteredUnits { + self.qubits(Direction::Incoming) + } + + /// Returns the linear units of this command in a given direction. + #[inline] + pub fn output_qubits(&self) -> FilteredUnits { + self.qubits(Direction::Outgoing) + } + /// Returns the units and wires of this command in a given direction. #[inline] pub fn unit_wires( @@ -121,6 +138,22 @@ impl<'circ, Circ: Circuit> Command<'circ, Circ> { pub fn output_count(&self) -> usize { self.optype().signature().output_count() } + + /// Returns the port in the command given a linear unit. + #[inline] + pub fn linear_unit_port(&self, unit: LinearUnit, direction: Direction) -> Option { + self.linear_units(direction) + .find(|(cu, _, _)| *cu == unit) + .map(|(_, port, _)| port) + } + + /// Returns whether the port is a linear port. + #[inline] + pub fn is_linear_port(&self, port: Port) -> bool { + self.optype() + .port_kind(port) + .map_or(false, |kind| kind.is_linear()) + } } impl<'a, 'circ, Circ: Circuit> UnitLabeller for &'a Command<'circ, Circ> { @@ -162,6 +195,8 @@ impl<'circ, Circ> PartialEq for Command<'circ, Circ> { } } +impl<'circ, Circ> Eq for Command<'circ, Circ> {} + impl<'circ, Circ> Clone for Command<'circ, Circ> { fn clone(&self) -> Self { Self { @@ -172,6 +207,27 @@ impl<'circ, Circ> Clone for Command<'circ, Circ> { } } +impl<'circ, Circ> std::hash::Hash for Command<'circ, Circ> { + fn hash(&self, state: &mut H) { + self.node.hash(state); + self.linear_units.hash(state); + } +} + +impl<'circ, Circ> PartialOrd for Command<'circ, Circ> { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl<'circ, Circ> Ord for Command<'circ, Circ> { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + self.node + .cmp(&other.node) + .then(self.linear_units.cmp(&other.linear_units)) + } +} + /// A non-borrowing topological walker over the nodes of a circuit. type NodeWalker = pv::Topo>; diff --git a/src/circuit/units/filter.rs b/src/circuit/units/filter.rs index fe485704..04d6a6a6 100644 --- a/src/circuit/units/filter.rs +++ b/src/circuit/units/filter.rs @@ -16,6 +16,7 @@ pub type FilteredUnits = std::iter::FilterMap< /// A filter over a [`Units`] iterator. pub trait UnitFilter { + /// The item yielded by the filtered iterator. type Item; /// Filter a [`Units`] iterator item, and unwrap it into a `Self::Item` if From f7e92372bc5f1be9c2c3b548666762e22b604dd3 Mon Sep 17 00:00:00 2001 From: Luca Mondada <72734770+lmondada@users.noreply.github.com> Date: Thu, 21 Sep 2023 17:11:48 +0100 Subject: [PATCH 3/6] feat: Support copies in pattern matching (#127) --- Cargo.toml | 2 +- src/circuit/units.rs | 8 +- src/portmatching.rs | 130 ++++++++++++++++++++++++++++++- src/portmatching/matcher.rs | 61 ++++++++++----- src/portmatching/pattern.rs | 150 ++++++++++++++++++++++++++++++------ src/utils.rs | 5 ++ 6 files changed, 306 insertions(+), 50 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index fa4abbfc..f582ded7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -27,7 +27,7 @@ itertools = { workspace = true } petgraph = { version = "0.6.3", default-features = false } serde_yaml = "0.9.22" # portmatching = { version = "0.2.0", optional = true, features = ["serde"]} -portmatching = { optional = true, git = "https://github.com/lmondada/portmatching", rev = "c4ad0ec", features = [ +portmatching = { optional = true, git = "https://github.com/lmondada/portmatching", rev = "738c91c", features = [ "serde", ] } derive_more = "0.99.17" diff --git a/src/circuit/units.rs b/src/circuit/units.rs index 222705a1..cd09ab55 100644 --- a/src/circuit/units.rs +++ b/src/circuit/units.rs @@ -19,9 +19,11 @@ use std::iter::FusedIterator; use hugr::hugr::CircuitUnit; use hugr::ops::OpTrait; -use hugr::types::{EdgeKind, Type, TypeBound, TypeRow}; +use hugr::types::{EdgeKind, Type, TypeRow}; use hugr::{Direction, Node, Port, Wire}; +use crate::utils::type_is_linear; + use self::filter::UnitFilter; use super::Circuit; @@ -215,7 +217,3 @@ impl UnitLabeller for DefaultUnitLabeller { } } } - -fn type_is_linear(typ: &Type) -> bool { - !TypeBound::Copyable.contains(typ.least_upper_bound()) -} diff --git a/src/portmatching.rs b/src/portmatching.rs index 6cf40f42..423c612c 100644 --- a/src/portmatching.rs +++ b/src/portmatching.rs @@ -5,11 +5,137 @@ pub mod pattern; #[cfg(feature = "pyo3")] pub mod pyo3; +use itertools::Itertools; pub use matcher::{PatternMatch, PatternMatcher}; pub use pattern::CircuitPattern; -use hugr::Port; +use hugr::{ + ops::{OpTag, OpTrait}, + Node, Port, +}; use matcher::MatchOp; +use thiserror::Error; + +use crate::{circuit::Circuit, utils::type_is_linear}; -type PEdge = (Port, Port); type PNode = MatchOp; + +/// An edge property in a circuit pattern. +/// +/// Edges are +/// Edges are reversible if the edge type is linear. +#[derive( + Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, serde::Serialize, serde::Deserialize, +)] +enum PEdge { + /// A "normal" edge between src and dst within a pattern. + InternalEdge { + src: Port, + dst: Port, + is_reversible: bool, + }, + /// An edge from a copied input to src. + /// + /// Edges from inputs are typically not matched as part of the pattern, + /// unless a single input is copied multiple times. In this case, an + /// InputEdge is used to link the source port to the (usually hidden) + /// copy node. + /// + /// Input edges are always irreversible. + InputEdge { src: Port }, +} + +#[derive(Debug, Clone, Error)] +enum InvalidEdgeProperty { + /// The port is linked to multiple edges. + #[error("port {0:?} is linked to multiple edges")] + AmbiguousEdge(Port), + /// The port is not linked to any edge. + #[error("port {0:?} is not linked to any edge")] + NoLinkedEdge(Port), + /// The port does not have a type. + #[error("port {0:?} does not have a type")] + UntypedPort(Port), +} + +impl PEdge { + fn try_from_port( + node: Node, + port: Port, + circ: &impl Circuit, + ) -> Result { + let src = port; + let (dst_node, dst) = circ + .linked_ports(node, src) + .exactly_one() + .map_err(|mut e| { + if e.next().is_some() { + InvalidEdgeProperty::AmbiguousEdge(src) + } else { + InvalidEdgeProperty::NoLinkedEdge(src) + } + })?; + if circ.get_optype(dst_node).tag() == OpTag::Input { + return Ok(Self::InputEdge { src }); + } + let port_type = circ + .get_optype(node) + .signature() + .get(src) + .cloned() + .ok_or(InvalidEdgeProperty::UntypedPort(src))?; + let is_reversible = type_is_linear(&port_type); + Ok(Self::InternalEdge { + src, + dst, + is_reversible, + }) + } +} + +impl portmatching::EdgeProperty for PEdge { + type OffsetID = Port; + + fn reverse(&self) -> Option { + match *self { + Self::InternalEdge { + src, + dst, + is_reversible, + } => is_reversible.then_some(Self::InternalEdge { + src: dst, + dst: src, + is_reversible, + }), + Self::InputEdge { .. } => None, + } + } + + fn offset_id(&self) -> Self::OffsetID { + match *self { + Self::InternalEdge { src, .. } => src, + Self::InputEdge { src, .. } => src, + } + } +} + +/// A node in a pattern. +/// +/// A node is either a real node in the HUGR graph or a hidden copy node +/// that is identified by its node and outgoing port. +/// +/// A NodeID::CopyNode can only be found as a target of a PEdge::InputEdge +/// property. Furthermore, a NodeID::CopyNode never has a node property. +#[derive( + Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, serde::Serialize, serde::Deserialize, +)] +pub(super) enum NodeID { + HugrNode(Node), + CopyNode(Node, Port), +} + +impl From for NodeID { + fn from(node: Node) -> Self { + Self::HugrNode(node) + } +} diff --git a/src/portmatching/matcher.rs b/src/portmatching/matcher.rs index 30075593..960ea346 100644 --- a/src/portmatching/matcher.rs +++ b/src/portmatching/matcher.rs @@ -7,13 +7,13 @@ use std::{ path::{Path, PathBuf}, }; -use super::{CircuitPattern, PEdge, PNode}; +use super::{CircuitPattern, NodeID, PEdge, PNode}; use hugr::hugr::views::sibling_subgraph::{ConvexChecker, InvalidReplacement, InvalidSubgraph}; use hugr::{hugr::views::SiblingSubgraph, ops::OpType, Hugr, Node, Port}; use itertools::Itertools; use portmatching::{ automaton::{LineBuilder, ScopeAutomaton}, - PatternID, + EdgeProperty, PatternID, }; use thiserror::Error; @@ -30,6 +30,7 @@ use crate::{ /// Matchable operations in a circuit. /// /// We currently support [`T2Op`] and a the HUGR load constant operation. +// TODO: Support OpType::Const, but blocked by use of F64 (Eq support required) #[derive( Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, serde::Serialize, serde::Deserialize, )] @@ -257,11 +258,11 @@ impl PatternMatcher { ) -> Vec { self.automaton .run( - root, + root.into(), // Node weights (none) - validate_weighted_node(circ), + validate_circuit_node(circ), // Check edge exist - validate_unweighted_edge(circ), + validate_circuit_edge(circ), ) .filter_map(|pattern_id| { handle_match_error( @@ -380,28 +381,48 @@ impl From for InvalidPatternMatch { } } -fn compatible_offsets((_, pout): &(Port, Port), (pin, _): &(Port, Port)) -> bool { - pout.direction() != pin.direction() && pout.index() == pin.index() +fn compatible_offsets(e1: &PEdge, e2: &PEdge) -> bool { + let PEdge::InternalEdge { dst: dst1, .. } = e1 else { + return false; + }; + let src2 = e2.offset_id(); + dst1.direction() != src2.direction() && dst1.index() == src2.index() } -/// Check if an edge `e` is valid in a portgraph `g` without weights. -pub(crate) fn validate_unweighted_edge( +/// Returns a predicate checking that an edge at `src` satisfies `prop` in `circ`. +pub(super) fn validate_circuit_edge( circ: &impl Circuit, -) -> impl for<'a> Fn(Node, &'a PEdge) -> Option + '_ { - move |src, &(src_port, tgt_port)| { - let (next_node, _) = circ - .linked_ports(src, src_port) - .find(|&(_, tgt)| tgt == tgt_port)?; - Some(next_node) +) -> impl for<'a> Fn(NodeID, &'a PEdge) -> Option + '_ { + move |src, &prop| { + let NodeID::HugrNode(src) = src else { + return None; + }; + match prop { + PEdge::InternalEdge { + src: src_port, + dst: dst_port, + .. + } => { + let (next_node, next_port) = circ.linked_ports(src, src_port).exactly_one().ok()?; + (dst_port == next_port).then_some(NodeID::HugrNode(next_node)) + } + PEdge::InputEdge { src: src_port } => { + let (next_node, next_port) = circ.linked_ports(src, src_port).exactly_one().ok()?; + Some(NodeID::CopyNode(next_node, next_port)) + } + } } } -/// Check if a node `n` is valid in a weighted portgraph `g`. -pub(crate) fn validate_weighted_node( +/// Returns a predicate checking that `node` satisfies `prop` in `circ`. +pub(crate) fn validate_circuit_node( circ: &impl Circuit, -) -> impl for<'a> Fn(Node, &PNode) -> bool + '_ { - move |v, prop| { - let v_weight = MatchOp::try_from(circ.get_optype(v).clone()); +) -> impl for<'a> Fn(NodeID, &PNode) -> bool + '_ { + move |node, prop| { + let NodeID::HugrNode(node) = node else { + return false; + }; + let v_weight = MatchOp::try_from(circ.get_optype(node).clone()); v_weight.is_ok_and(|w| &w == prop) } } diff --git a/src/portmatching/pattern.rs b/src/portmatching/pattern.rs index cf97d739..60b0ab1c 100644 --- a/src/portmatching/pattern.rs +++ b/src/portmatching/pattern.rs @@ -7,10 +7,10 @@ use std::fmt::Debug; use thiserror::Error; use super::{ - matcher::{validate_unweighted_edge, validate_weighted_node}, + matcher::{validate_circuit_edge, validate_circuit_node}, PEdge, PNode, }; -use crate::circuit::Circuit; +use crate::{circuit::Circuit, portmatching::NodeID}; #[cfg(feature = "pyo3")] use pyo3::{create_exception, exceptions::PyException, pyclass, PyErr}; @@ -19,7 +19,7 @@ use pyo3::{create_exception, exceptions::PyException, pyclass, PyErr}; #[cfg_attr(feature = "pyo3", pyclass)] #[derive(Clone, serde::Serialize, serde::Deserialize)] pub struct CircuitPattern { - pub(super) pattern: Pattern, + pub(super) pattern: Pattern, /// The input ports pub(super) inputs: Vec>, /// The output ports @@ -40,14 +40,21 @@ impl CircuitPattern { let mut pattern = Pattern::new(); for cmd in circuit.commands() { let op = cmd.optype().clone(); - pattern.require(cmd.node(), op.try_into().unwrap()); - for out_offset in 0..cmd.output_count() { - let out_offset = Port::new_outgoing(out_offset); - for (next_node, in_offset) in circuit.linked_ports(cmd.node(), out_offset) { - if circuit.get_optype(next_node).tag() != hugr::ops::OpTag::Output { - pattern.add_edge(cmd.node(), next_node, (out_offset, in_offset)); - } - } + pattern.require(cmd.node().into(), op.try_into().unwrap()); + for in_offset in 0..cmd.input_count() { + let in_offset = Port::new_incoming(in_offset); + let edge_prop = + PEdge::try_from_port(cmd.node(), in_offset, circuit).expect("Invalid HUGR"); + let (prev_node, prev_port) = circuit + .linked_ports(cmd.node(), in_offset) + .exactly_one() + .ok() + .expect("invalid HUGR"); + let prev_node = match edge_prop { + PEdge::InternalEdge { .. } => NodeID::HugrNode(prev_node), + PEdge::InputEdge { .. } => NodeID::CopyNode(prev_node, prev_port), + }; + pattern.add_edge(cmd.node().into(), prev_node, edge_prop); } } pattern.set_any_root()?; @@ -83,15 +90,25 @@ impl CircuitPattern { } /// Compute the map from pattern nodes to circuit nodes in `circ`. - pub fn get_match_map(&self, root: Node, circ: &C) -> Option> { + pub fn get_match_map(&self, root: Node, circ: &impl Circuit) -> Option> { let single_matcher = SinglePatternMatcher::from_pattern(self.pattern.clone()); single_matcher .get_match_map( - root, - validate_weighted_node(circ), - validate_unweighted_edge(circ), + root.into(), + validate_circuit_node(circ), + validate_circuit_edge(circ), ) - .map(|m| m.into_iter().collect()) + .map(|m| { + m.into_iter() + .filter_map(|(node_p, node_c)| match (node_p, node_c) { + (NodeID::HugrNode(node_p), NodeID::HugrNode(node_c)) => { + Some((node_p, node_c)) + } + (NodeID::CopyNode(..), NodeID::CopyNode(..)) => None, + _ => panic!("Invalid match map"), + }) + .collect() + }) } } @@ -136,9 +153,17 @@ impl From for PyErr { #[cfg(test)] mod tests { + + use std::collections::HashSet; + + use hugr::builder::{DFGBuilder, Dataflow, DataflowHugr}; + use hugr::extension::prelude::QB_T; + use hugr::ops::LeafOp; + use hugr::std_extensions::arithmetic::float_types::FLOAT64_TYPE; + use hugr::types::FunctionType; use hugr::Hugr; - use itertools::Itertools; + use crate::extension::REGISTRY; use crate::utils::build_simple_circuit; use crate::T2Op; @@ -153,23 +178,68 @@ mod tests { .unwrap() } + /// A circuit with two rotation gates in sequence, sharing a param + fn circ_with_copy() -> Hugr { + let input_t = vec![QB_T, FLOAT64_TYPE]; + let output_t = vec![QB_T]; + let mut h = DFGBuilder::new(FunctionType::new(input_t, output_t)).unwrap(); + + let mut inps = h.input_wires(); + let qb = inps.next().unwrap(); + let f = inps.next().unwrap(); + + let res = h.add_dataflow_op(T2Op::RxF64, [qb, f]).unwrap(); + let qb = res.outputs().next().unwrap(); + let res = h.add_dataflow_op(T2Op::RxF64, [qb, f]).unwrap(); + let qb = res.outputs().next().unwrap(); + + h.finish_hugr_with_outputs([qb], ®ISTRY).unwrap() + } + + /// A circuit with two rotation gates in parallel, sharing a param + fn circ_with_copy_disconnected() -> Hugr { + let input_t = vec![QB_T, QB_T, FLOAT64_TYPE]; + let output_t = vec![QB_T, QB_T]; + let mut h = DFGBuilder::new(FunctionType::new(input_t, output_t)).unwrap(); + + let mut inps = h.input_wires(); + let qb1 = inps.next().unwrap(); + let qb2 = inps.next().unwrap(); + let f = inps.next().unwrap(); + + let res = h.add_dataflow_op(T2Op::RxF64, [qb1, f]).unwrap(); + let qb1 = res.outputs().next().unwrap(); + let res = h.add_dataflow_op(T2Op::RxF64, [qb2, f]).unwrap(); + let qb2 = res.outputs().next().unwrap(); + + h.finish_hugr_with_outputs([qb1, qb2], ®ISTRY).unwrap() + } + #[test] fn construct_pattern() { let hugr = h_cx(); let p = CircuitPattern::try_from_circuit(&hugr).unwrap(); - let edges = p + let edges: HashSet<_> = p .pattern .edges() .unwrap() .iter() .map(|e| (e.source.unwrap(), e.target.unwrap())) - .collect_vec(); + .collect(); + let inp = hugr.input(); + let cx_gate = NodeID::HugrNode(get_nodes_by_t2op(&hugr, T2Op::CX)[0]); + let h_gate = NodeID::HugrNode(get_nodes_by_t2op(&hugr, T2Op::H)[0]); assert_eq!( - // How would I construct hugr::Nodes for testing here? - edges.len(), - 1 + edges, + [ + (cx_gate, h_gate), + (cx_gate, NodeID::CopyNode(inp, Port::new_outgoing(0))), + (cx_gate, NodeID::CopyNode(inp, Port::new_outgoing(1))), + ] + .into_iter() + .collect() ) } @@ -199,4 +269,40 @@ mod tests { InvalidPattern::NotConnected ); } + + fn get_nodes_by_t2op(circ: &impl Circuit, t2_op: T2Op) -> Vec { + circ.nodes() + .filter(|n| { + let Ok(op): Result = circ.get_optype(*n).clone().try_into() else { + return false; + }; + op == t2_op.into() + }) + .collect() + } + + #[test] + fn pattern_with_copy() { + let circ = circ_with_copy(); + let pattern = CircuitPattern::try_from_circuit(&circ).unwrap(); + let edges = pattern.pattern.edges().unwrap(); + let rx_ns = get_nodes_by_t2op(&circ, T2Op::RxF64); + let inp = circ.input(); + for rx_n in rx_ns { + assert!(edges.iter().any(|e| { + e.reverse().is_none() + && e.source.unwrap() == rx_n.into() + && e.target.unwrap() == NodeID::CopyNode(inp, Port::new_outgoing(1)) + })); + } + } + + #[test] + fn pattern_with_copy_disconnected() { + let circ = circ_with_copy_disconnected(); + assert_eq!( + CircuitPattern::try_from_circuit(&circ).unwrap_err(), + InvalidPattern::NotConnected + ); + } } diff --git a/src/utils.rs b/src/utils.rs index c560d63e..124cda37 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -1,6 +1,7 @@ //! Utility functions for the library. use hugr::extension::PRELUDE_REGISTRY; +use hugr::types::{Type, TypeBound}; use hugr::{ builder::{BuildError, CircuitBuilder, DFGBuilder, Dataflow, DataflowHugr}, extension::prelude::QB_T, @@ -8,6 +9,10 @@ use hugr::{ Hugr, }; +pub(crate) fn type_is_linear(typ: &Type) -> bool { + !TypeBound::Copyable.contains(typ.least_upper_bound()) +} + // utility for building simple qubit-only circuits. #[allow(unused)] pub(crate) fn build_simple_circuit( From 6378d33eab756f4dd0bb4210247cfb7e14fb2478 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Borgna?= <121866228+aborgna-q@users.noreply.github.com> Date: Fri, 22 Sep 2023 16:12:29 +0200 Subject: [PATCH 4/6] feat: Tracing instrumentation and structured logging (#115) - Adds some minimal `tracing` instrumentation. The idea is to do some multithreaded perf debugging in the future based on that. - Leverages the same infrastructure to refactor the TASO logger, now events are logged into the global logger, and the executable subscribes to either general events (for stdout), or more granular data (for the logfile). --- Cargo.toml | 2 + src/optimiser/taso.rs | 238 ++++++---------------------- src/optimiser/taso/hugr_pchannel.rs | 1 + src/optimiser/taso/log.rs | 111 +++++++++++++ src/optimiser/taso/worker.rs | 64 ++++++-- src/rewrite/strategy.rs | 2 + taso-optimiser/Cargo.toml | 5 +- taso-optimiser/src/main.rs | 51 ++++-- taso-optimiser/src/tracing.rs | 73 +++++++++ 9 files changed, 331 insertions(+), 216 deletions(-) create mode 100644 src/optimiser/taso/log.rs create mode 100644 taso-optimiser/src/tracing.rs diff --git a/Cargo.toml b/Cargo.toml index f582ded7..6161abcf 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -44,6 +44,7 @@ chrono = { version ="0.4.30" } bytemuck = "1.14.0" stringreader = "0.1.1" crossbeam-channel = "0.5.8" +tracing = { workspace = true } [features] pyo3 = [ @@ -79,3 +80,4 @@ itertools = { version = "0.11.0" } tket-json-rs = { git = "https://github.com/CQCL/tket-json-rs", rev = "619db15d3", features = [ "tket2ops", ] } +tracing = "0.1.37" diff --git a/src/optimiser/taso.rs b/src/optimiser/taso.rs index 6a48d942..66c91ab6 100644 --- a/src/optimiser/taso.rs +++ b/src/optimiser/taso.rs @@ -14,54 +14,28 @@ mod eq_circ_class; mod hugr_pchannel; mod hugr_pqueue; +pub mod log; mod qtz_circuit; mod worker; use crossbeam_channel::select; pub use eq_circ_class::{load_eccs_json_file, EqCircClass}; +use std::io; use std::num::NonZeroUsize; use std::time::{Duration, Instant}; -use std::{fs, io}; use fxhash::FxHashSet; use hugr::Hugr; use crate::circuit::CircuitHash; -use crate::json::save_tk1_json_writer; +use crate::optimiser::taso::hugr_pchannel::HugrPriorityChannel; +use crate::optimiser::taso::hugr_pqueue::{Entry, HugrPQ}; +use crate::optimiser::taso::worker::TasoWorker; use crate::rewrite::strategy::RewriteStrategy; use crate::rewrite::Rewriter; -use hugr_pqueue::{Entry, HugrPQ}; -use self::hugr_pchannel::HugrPriorityChannel; - -/// Logging configuration for the TASO optimiser. -#[derive(Default)] -pub struct LogConfig<'w> { - final_circ_json: Option>, - circ_candidates_csv: Option>, - progress_log: Option>, -} - -impl<'w> LogConfig<'w> { - /// Create a new logging configuration. - /// - /// Three writer objects must be provided: - /// - best_circ_json: for the final optimised circuit, in TK1 JSON format, - /// - circ_candidates_csv: for a log of the successive best candidate circuits, - /// - progress_log: for a log of the progress of the optimisation. - pub fn new( - best_circ_json: impl io::Write + 'w, - circ_candidates_csv: impl io::Write + 'w, - progress_log: impl io::Write + 'w, - ) -> Self { - Self { - final_circ_json: Some(Box::new(best_circ_json)), - circ_candidates_csv: Some(Box::new(circ_candidates_csv)), - progress_log: Some(Box::new(progress_log)), - } - } -} +use self::log::TasoLogger; /// The TASO optimiser. /// @@ -116,7 +90,7 @@ where pub fn optimise_with_log( &self, circ: &Hugr, - log_config: LogConfig, + log_config: TasoLogger, timeout: Option, n_threads: NonZeroUsize, ) -> Hugr { @@ -126,37 +100,13 @@ where } } - /// Run the TASO optimiser on a circuit with default logging. - /// - /// The following files will be created: - /// - `final_circ.json`: the final optimised circuit, in TK1 JSON format, - /// - `best_circs.csv`: a log of the successive best candidate circuits, - /// - `taso-optimisation.log`: a log of the progress of the optimisation. - /// - /// If the creation of any of these files fails, an error is returned. - /// - /// A timeout (in seconds) can be provided. - pub fn optimise_with_default_log( - &self, - circ: &Hugr, - timeout: Option, - n_threads: NonZeroUsize, - ) -> io::Result { - let final_circ_json = fs::File::create("final_circ.json")?; - let circ_candidates_csv = fs::File::create("best_circs.csv")?; - let progress_log = fs::File::create("taso-optimisation.log")?; - let log_config = LogConfig::new(final_circ_json, circ_candidates_csv, progress_log); - Ok(self.optimise_with_log(circ, log_config, timeout, n_threads)) - } - - fn taso(&self, circ: &Hugr, mut log_config: LogConfig, timeout: Option) -> Hugr { + #[tracing::instrument(target = "taso::metrics", skip(self, circ, logger))] + fn taso(&self, circ: &Hugr, mut logger: TasoLogger, timeout: Option) -> Hugr { let start_time = Instant::now(); - let mut log_candidates = log_config.circ_candidates_csv.map(csv::Writer::from_writer); - let mut best_circ = circ.clone(); let mut best_circ_cost = (self.cost)(circ); - log_best(best_circ_cost, log_candidates.as_mut()).unwrap(); + logger.log_best(best_circ_cost); // Hash of seen circuits. Dot not store circuits as this map gets huge let mut seen_hashes: FxHashSet<_> = FromIterator::from_iter([(circ.circuit_hash())]); @@ -167,26 +117,19 @@ where pq.push(circ.clone()); let mut circ_cnt = 1; + let mut timeout_flag = false; while let Some(Entry { circ, cost, .. }) = pq.pop() { if cost < best_circ_cost { best_circ = circ.clone(); best_circ_cost = cost; - log_best(best_circ_cost, log_candidates.as_mut()).unwrap(); + logger.log_best(best_circ_cost); } let rewrites = self.rewriter.get_rewrites(&circ); for new_circ in self.strategy.apply_rewrites(rewrites, &circ) { let new_circ_hash = new_circ.circuit_hash(); circ_cnt += 1; - if circ_cnt % 1000 == 0 { - log_progress( - log_config.progress_log.as_mut(), - circ_cnt, - Some(&pq), - &seen_hashes, - ) - .expect("Failed to write to progress log"); - } + logger.log_progress(circ_cnt, Some(pq.len()), seen_hashes.len()); if seen_hashes.contains(&new_circ_hash) { continue; } @@ -201,22 +144,13 @@ where if let Some(timeout) = timeout { if start_time.elapsed().as_secs() > timeout { - println!("Timeout"); + timeout_flag = true; break; } } } - log_processing_end(circ_cnt, false); - - log_final( - &best_circ, - log_config.progress_log.as_mut(), - log_config.final_circ_json.as_mut(), - &self.cost, - ) - .expect("Failed to write to progress log and/or final circuit JSON"); - + logger.log_processing_end(circ_cnt, best_circ_cost, false, timeout_flag); best_circ } @@ -224,18 +158,17 @@ where /// /// This is the multi-threaded version of [`taso`]. See [`TasoOptimiser`] for /// more details. + #[tracing::instrument(target = "taso::metrics", skip(self, circ, logger))] fn taso_multithreaded( &self, circ: &Hugr, - mut log_config: LogConfig, + mut logger: TasoLogger, timeout: Option, n_threads: NonZeroUsize, ) -> Hugr { let n_threads: usize = n_threads.get(); const PRIORITY_QUEUE_CAPACITY: usize = 10_000; - let mut log_candidates = log_config.circ_candidates_csv.map(csv::Writer::from_writer); - // multi-consumer priority channel for queuing circuits to be processed by the workers let (tx_work, rx_work) = HugrPriorityChannel::init((self.cost).clone(), PRIORITY_QUEUE_CAPACITY * n_threads); @@ -245,7 +178,7 @@ where let initial_circ_hash = circ.circuit_hash(); let mut best_circ = circ.clone(); let mut best_circ_cost = (self.cost)(&best_circ); - log_best(best_circ_cost, log_candidates.as_mut()).unwrap(); + logger.log_best(best_circ_cost); // Hash of seen circuits. Dot not store circuits as this map gets huge let mut seen_hashes: FxHashSet<_> = FromIterator::from_iter([(initial_circ_hash)]); @@ -253,12 +186,13 @@ where // Each worker waits for circuits to scan for rewrites using all the // patterns and sends the results back to main. let joins: Vec<_> = (0..n_threads) - .map(|_| { - worker::spawn_pattern_matching_thread( + .map(|i| { + TasoWorker::spawn( rx_work.clone(), tx_result.clone(), self.rewriter.clone(), self.strategy.clone(), + Some(format!("taso-worker-{i}")), ) }) .collect(); @@ -283,7 +217,7 @@ where let mut jobs_completed = 0usize; // TODO: Report dropped jobs in the queue, so we can check for termination. - // Deadline for the optimization timeout + // Deadline for the optimisation timeout let timeout_event = match timeout { None => crossbeam_channel::never(), Some(t) => crossbeam_channel::at(Instant::now() + Duration::from_secs(t)), @@ -291,38 +225,39 @@ where // Process worker results until we have seen all the circuits, or we run // out of time. + let mut timeout_flag = false; loop { select! { recv(rx_result) -> msg => { match msg { Ok(hashed_circs) => { - jobs_completed += 1; - for (circ_hash, circ) in &hashed_circs { - circ_cnt += 1; - if circ_cnt % 1000 == 0 { - // TODO: Add a minimum time between logs - log_progress::<_,u64,usize>(log_config.progress_log.as_mut(), circ_cnt, None, &seen_hashes) - .expect("Failed to write to progress log"); - } - if !seen_hashes.insert(*circ_hash) { - continue; + let send_result = tracing::trace_span!(target: "taso::metrics", "recv_result").in_scope(|| { + jobs_completed += 1; + for (circ_hash, circ) in &hashed_circs { + circ_cnt += 1; + logger.log_progress(circ_cnt, None, seen_hashes.len()); + if seen_hashes.contains(circ_hash) { + continue; + } + seen_hashes.insert(*circ_hash); + + let cost = (self.cost)(circ); + + // Check if we got a new best circuit + if cost < best_circ_cost { + best_circ = circ.clone(); + best_circ_cost = cost; + logger.log_best(best_circ_cost); + } + jobs_sent += 1; } - - let cost = (self.cost)(circ); - - // Check if we got a new best circuit - if cost < best_circ_cost { - best_circ = circ.clone(); - best_circ_cost = cost; - log_best(best_circ_cost, log_candidates.as_mut()).unwrap(); - } - jobs_sent += 1; - } - // Fill the workqueue with data from pq - if tx_work.send(hashed_circs).is_err() { + // Fill the workqueue with data from pq + tx_work.send(hashed_circs) + }); + if send_result.is_err() { eprintln!("All our workers panicked. Stopping optimisation."); break; - }; + } // If there is no more data to process, we are done. // @@ -338,25 +273,17 @@ where } } recv(timeout_event) -> _ => { - println!("Timeout"); + timeout_flag = true; break; } } } - log_processing_end(circ_cnt, true); + logger.log_processing_end(circ_cnt, best_circ_cost, true, timeout_flag); // Drop the channel so the threads know to stop. drop(tx_work); - let _ = joins; // joins.into_iter().for_each(|j| j.join().unwrap()); - - log_final( - &best_circ, - log_config.progress_log.as_mut(), - log_config.final_circ_json.as_mut(), - &self.cost, - ) - .expect("Failed to write to progress log and/or final circuit JSON"); + joins.into_iter().for_each(|j| j.join().unwrap()); best_circ } @@ -381,68 +308,3 @@ mod taso_default { } } } - -/// A helper struct for logging improvements in circuit size seen during the -/// TASO execution. -// -// TODO: Replace this fixed logging. Report back intermediate results. -#[derive(serde::Serialize, Clone, Debug)] -struct BestCircSer { - circ_len: usize, - time: String, -} - -impl BestCircSer { - fn new(circ_len: usize) -> Self { - let time = chrono::Local::now().to_rfc3339(); - Self { circ_len, time } - } -} - -fn log_best(cbest: usize, wtr: Option<&mut csv::Writer>) -> io::Result<()> { - let Some(wtr) = wtr else { - return Ok(()); - }; - println!("new best of size {}", cbest); - wtr.serialize(BestCircSer::new(cbest)).unwrap(); - wtr.flush() -} - -fn log_processing_end(circuit_count: usize, needs_joining: bool) { - println!("END"); - println!("Tried {circuit_count} circuits"); - if needs_joining { - println!("Joining"); - } -} - -fn log_progress( - wr: Option<&mut W>, - circ_cnt: usize, - pq: Option<&HugrPQ>, - seen_hashes: &FxHashSet, -) -> io::Result<()> { - if let Some(wr) = wr { - writeln!(wr, "{circ_cnt} circuits...")?; - if let Some(pq) = pq { - writeln!(wr, "Queue size: {} circuits", pq.len())?; - } - writeln!(wr, "Total seen: {} circuits", seen_hashes.len())?; - } - Ok(()) -} - -fn log_final( - best_circ: &Hugr, - log: Option<&mut W1>, - final_circ: Option<&mut W2>, - cost: impl Fn(&Hugr) -> usize, -) -> io::Result<()> { - if let Some(log) = log { - writeln!(log, "END RESULT: {}", cost(best_circ))?; - } - if let Some(circ_writer) = final_circ { - save_tk1_json_writer(best_circ, circ_writer).unwrap(); - } - Ok(()) -} diff --git a/src/optimiser/taso/hugr_pchannel.rs b/src/optimiser/taso/hugr_pchannel.rs index c3e157cf..1ec0d2e4 100644 --- a/src/optimiser/taso/hugr_pchannel.rs +++ b/src/optimiser/taso/hugr_pchannel.rs @@ -45,6 +45,7 @@ where let in_channel = in_channel_orig.clone(); let out_channel = out_channel_orig.clone(); let _ = builder + .name("priority-channel".into()) .spawn(move || { // The priority queue, local to this thread. let mut pq: HugrPQ = diff --git a/src/optimiser/taso/log.rs b/src/optimiser/taso/log.rs new file mode 100644 index 00000000..39969ee6 --- /dev/null +++ b/src/optimiser/taso/log.rs @@ -0,0 +1,111 @@ +//! Logging utilities for the TASO optimiser. + +use std::io; + +/// Logging configuration for the TASO optimiser. +#[derive(Default)] +pub struct TasoLogger<'w> { + circ_candidates_csv: Option>>, +} + +/// The logging target for general events. +pub const LOG_TARGET: &str = "taso::log"; +/// The logging target for progress events. More verbose than the general log. +pub const PROGRESS_TARGET: &str = "taso::progress"; +/// The logging target for function spans. +pub const METRICS_TARGET: &str = "taso::metrics"; + +impl<'w> TasoLogger<'w> { + /// Create a new logging configuration. + /// + /// Two writer objects must be provided: + /// - best_progress_csv_writer: for a log of the successive best candidate + /// circuits, + /// + /// Regular events are logged with [`tracing`], with targets [`LOG_TARGET`] + /// or [`PROGRESS_TARGET`]. + /// + /// [`log`]: + pub fn new(best_progress_csv_writer: impl io::Write + 'w) -> Self { + let boxed_candidates_writer: Box = Box::new(best_progress_csv_writer); + Self { + circ_candidates_csv: Some(csv::Writer::from_writer(boxed_candidates_writer)), + } + } + + /// Log a new best candidate + #[inline] + pub fn log_best(&mut self, best_cost: usize) { + self.log(format!("new best of size {}", best_cost)); + if let Some(csv_writer) = self.circ_candidates_csv.as_mut() { + csv_writer.serialize(BestCircSer::new(best_cost)).unwrap(); + csv_writer.flush().unwrap(); + }; + } + + /// Log the final optimised circuit + #[inline] + pub fn log_processing_end( + &self, + circuit_count: usize, + best_cost: usize, + needs_joining: bool, + timeout: bool, + ) { + if timeout { + self.log("Timeout"); + } + self.log("Optimisation finished"); + self.log(format!("Tried {circuit_count} circuits")); + self.log(format!("END RESULT: {}", best_cost)); + if needs_joining { + self.log("Joining worker threads"); + } + } + + /// Log the progress of the optimisation. + #[inline(always)] + pub fn log_progress( + &mut self, + circ_cnt: usize, + workqueue_len: Option, + seen_hashes: usize, + ) { + if circ_cnt % 1000 == 0 { + self.progress(format!("{circ_cnt} circuits...")); + if let Some(workqueue_len) = workqueue_len { + self.progress(format!("Queue size: {workqueue_len} circuits")); + } + self.progress(format!("Total seen: {} circuits", seen_hashes)); + } + } + + /// Internal function to log general events, normally printed to stdout. + #[inline] + fn log(&self, msg: impl AsRef) { + tracing::info!(target: LOG_TARGET, "{}", msg.as_ref()); + } + + /// Internal function to log information on the progress of the optimization. + #[inline] + fn progress(&self, msg: impl AsRef) { + tracing::info!(target: PROGRESS_TARGET, "{}", msg.as_ref()); + } +} + +/// A helper struct for logging improvements in circuit size seen during the +/// TASO execution. +// +// TODO: Replace this fixed logging. Report back intermediate results. +#[derive(serde::Serialize, Clone, Debug)] +struct BestCircSer { + circ_len: usize, + time: String, +} + +impl BestCircSer { + fn new(circ_len: usize) -> Self { + let time = chrono::Local::now().to_rfc3339(); + Self { circ_len, time } + } +} diff --git a/src/optimiser/taso/worker.rs b/src/optimiser/taso/worker.rs index 4472031e..0d8667c8 100644 --- a/src/optimiser/taso/worker.rs +++ b/src/optimiser/taso/worker.rs @@ -8,24 +8,60 @@ use crate::circuit::CircuitHash; use crate::rewrite::strategy::RewriteStrategy; use crate::rewrite::Rewriter; -pub fn spawn_pattern_matching_thread( - rx_work: crossbeam_channel::Receiver<(u64, Hugr)>, - tx_result: crossbeam_channel::Sender>, - rewriter: impl Rewriter + Send + 'static, - strategy: impl RewriteStrategy + Send + 'static, -) -> JoinHandle<()> { - thread::spawn(move || { - // Process work until the main thread closes the channel send or receive - // channel. +/// A worker that processes circuits for the TASO optimiser. +pub struct TasoWorker { + _phantom: std::marker::PhantomData<(R, S)>, +} + +impl TasoWorker +where + R: Rewriter + Send + 'static, + S: RewriteStrategy + Send + 'static, +{ + /// Spawn a new worker thread. + pub fn spawn( + rx_work: crossbeam_channel::Receiver<(u64, Hugr)>, + tx_result: crossbeam_channel::Sender>, + rewriter: R, + strategy: S, + worker_name: Option, + ) -> JoinHandle<()> { + let mut builder = thread::Builder::new(); + if let Some(name) = worker_name { + builder = builder.name(name); + }; + builder + .spawn(move || Self::worker_loop(rx_work, tx_result, rewriter, strategy)) + .unwrap() + } + + /// Main loop of the worker. + /// + /// Processes work until the main thread closes the channel send or receive + /// channel. + #[tracing::instrument(target = "taso::metrics", skip_all)] + fn worker_loop( + rx_work: crossbeam_channel::Receiver<(u64, Hugr)>, + tx_result: crossbeam_channel::Sender>, + rewriter: R, + strategy: S, + ) { while let Ok((_hash, circ)) = rx_work.recv() { - let rewrites = rewriter.get_rewrites(&circ); - let circs = strategy.apply_rewrites(rewrites, &circ); - let hashed_circs = circs.into_iter().map(|c| (c.circuit_hash(), c)).collect(); - let send = tx_result.send(hashed_circs); + let hashed_circs = Self::process_circ(circ, &rewriter, &strategy); + let send = tracing::trace_span!(target: "taso::metrics", "TasoWorker::send_result") + .in_scope(|| tx_result.send(hashed_circs)); if send.is_err() { // The main thread closed the send channel, we can stop. break; } } - }) + } + + /// Process a circuit. + #[tracing::instrument(target = "taso::metrics", skip_all)] + fn process_circ(circ: Hugr, rewriter: &R, strategy: &S) -> Vec<(u64, Hugr)> { + let rewrites = rewriter.get_rewrites(&circ); + let circs = strategy.apply_rewrites(rewrites, &circ); + circs.into_iter().map(|c| (c.circuit_hash(), c)).collect() + } } diff --git a/src/rewrite/strategy.rs b/src/rewrite/strategy.rs index 645b5c16..29d5a934 100644 --- a/src/rewrite/strategy.rs +++ b/src/rewrite/strategy.rs @@ -45,6 +45,7 @@ pub trait RewriteStrategy { pub struct GreedyRewriteStrategy; impl RewriteStrategy for GreedyRewriteStrategy { + #[tracing::instrument(skip_all)] fn apply_rewrites( &self, rewrites: impl IntoIterator, @@ -99,6 +100,7 @@ impl Default for ExhaustiveRewriteStrategy { } impl RewriteStrategy for ExhaustiveRewriteStrategy { + #[tracing::instrument(skip_all)] fn apply_rewrites( &self, rewrites: impl IntoIterator, diff --git a/taso-optimiser/Cargo.toml b/taso-optimiser/Cargo.toml index 787f5654..d3106737 100644 --- a/taso-optimiser/Cargo.toml +++ b/taso-optimiser/Cargo.toml @@ -12,8 +12,11 @@ tket2 = { path = "../", features = ["portmatching"] } quantinuum-hugr = { workspace = true } itertools = { workspace = true } tket-json-rs = { workspace = true } +tracing = { workspace = true } +tracing-subscriber = "0.3.17" +tracing-appender = "0.2.2" peak_alloc = { version = "0.2.0", optional = true } [features] default = ["peak_alloc"] -peak_alloc = ["dep:peak_alloc"] \ No newline at end of file +peak_alloc = ["dep:peak_alloc"] diff --git a/taso-optimiser/src/main.rs b/taso-optimiser/src/main.rs index 4c277f8e..95368318 100644 --- a/taso-optimiser/src/main.rs +++ b/taso-optimiser/src/main.rs @@ -1,9 +1,16 @@ +mod tracing; + +use crate::tracing::Tracer; + +use std::io::BufWriter; use std::num::NonZeroUsize; +use std::path::PathBuf; use std::process::exit; -use std::{fs, io, path::Path}; +use std::{fs, path::Path}; use clap::Parser; use hugr::Hugr; +use tket2::optimiser::taso::log::TasoLogger; use tket2::{ json::{load_tk1_json_file, TKETDecode}, optimiser::TasoOptimiser, @@ -31,7 +38,7 @@ struct CmdLineArgs { value_name = "FILE", help = "Input. A quantum circuit in TK1 JSON format." )] - input: String, + input: PathBuf, /// Output circuit file #[arg( short, @@ -40,7 +47,7 @@ struct CmdLineArgs { value_name = "FILE", help = "Output. A quantum circuit in TK1 JSON format." )] - output: String, + output: PathBuf, /// ECC file #[arg( short, @@ -48,7 +55,16 @@ struct CmdLineArgs { value_name = "ECC_FILE", help = "Sets the ECC file to use. It is a JSON file of Quartz-generated ECCs." )] - eccs: String, + eccs: PathBuf, + /// Log output file + #[arg( + short, + long, + default_value = "taso-optimisation.log", + value_name = "LOGFILE", + help = "Logfile to to output the progress of the optimisation." + )] + logfile: Option, /// Timeout in seconds (default=no timeout) #[arg( short, @@ -62,27 +78,37 @@ struct CmdLineArgs { short = 'j', long, value_name = "N_THREADS", - help = "The number of threads to use. By default, the number of threads is equal to the number of logical cores." + help = "The number of threads to use. By default, use a single thread." )] n_threads: Option, } fn save_tk1_json_file(path: impl AsRef, circ: &Hugr) -> Result<(), std::io::Error> { let file = fs::File::create(path)?; - let writer = io::BufWriter::new(file); + let writer = BufWriter::new(file); let serial_circ = SerialCircuit::encode(circ).unwrap(); serde_json::to_writer_pretty(writer, &serial_circ)?; Ok(()) } -fn main() { +fn main() -> Result<(), Box> { let opts = CmdLineArgs::parse(); + // Setup tracing subscribers for stdout and file logging. + // + // We need to keep the object around to keep the logging active. + let _tracer = Tracer::setup_tracing(opts.logfile); + let input_path = Path::new(&opts.input); let output_path = Path::new(&opts.output); let ecc_path = Path::new(&opts.eccs); - let circ = load_tk1_json_file(input_path).unwrap(); + // TODO: Remove this from the Logger, and use tracing events instead. + let circ_candidates_csv = fs::File::create("best_circs.csv")?; + + let taso_logger = TasoLogger::new(circ_candidates_csv); + + let circ = load_tk1_json_file(input_path)?; println!("Compiling rewriter..."); let Ok(optimiser) = TasoOptimiser::default_with_eccs_json_file(ecc_path) else { @@ -101,15 +127,14 @@ fn main() { println!("Using {n_threads} threads"); println!("Optimising..."); - let opt_circ = optimiser - .optimise_with_default_log(&circ, opts.timeout, n_threads) - .unwrap(); + let opt_circ = optimiser.optimise_with_log(&circ, taso_logger, opts.timeout, n_threads); println!("Saving result"); - save_tk1_json_file(output_path, &opt_circ).unwrap(); + save_tk1_json_file(output_path, &opt_circ)?; #[cfg(feature = "peak_alloc")] println!("Peak memory usage: {} GB", PEAK_ALLOC.peak_usage_as_gb()); - println!("Done.") + println!("Done."); + Ok(()) } diff --git a/taso-optimiser/src/tracing.rs b/taso-optimiser/src/tracing.rs new file mode 100644 index 00000000..0e90f0a5 --- /dev/null +++ b/taso-optimiser/src/tracing.rs @@ -0,0 +1,73 @@ +//! Setup routines for tracing and logging of the optimisation process. +use std::fs::File; +use std::io::BufWriter; +use std::path::PathBuf; + +use tket2::optimiser::taso::log::{LOG_TARGET, METRICS_TARGET, PROGRESS_TARGET}; + +use tracing::{Metadata, Subscriber}; +use tracing_appender::non_blocking; +use tracing_subscriber::filter::filter_fn; +use tracing_subscriber::prelude::__tracing_subscriber_SubscriberExt; +use tracing_subscriber::util::SubscriberInitExt; +use tracing_subscriber::Layer; + +fn log_filter(metadata: &Metadata<'_>) -> bool { + metadata.target().starts_with(LOG_TARGET) +} + +fn verbose_filter(metadata: &Metadata<'_>) -> bool { + metadata.target().starts_with(LOG_TARGET) || metadata.target().starts_with(PROGRESS_TARGET) +} + +#[allow(unused)] +fn metrics_filter(metadata: &Metadata<'_>) -> bool { + metadata.target().starts_with(METRICS_TARGET) +} + +#[derive(Debug, Default)] +pub struct Tracer { + pub logfile: Option, +} + +impl Tracer { + /// Setup tracing subscribers for stdout and file logging. + pub fn setup_tracing(logfile: Option) -> Self { + let mut tracer = Self::default(); + tracing_subscriber::registry() + .with(tracer.stdout_layer()) + .with(logfile.map(|f| tracer.logfile_layer(f))) + .init(); + tracer + } + + /// Initialize a file logger handle and non-blocking worker. + fn init_writer(&self, file: PathBuf) -> (non_blocking::NonBlocking, non_blocking::WorkerGuard) { + let writer = BufWriter::new(File::create(file).unwrap()); + non_blocking(writer) + } + + /// Clean log with the most important events. + fn stdout_layer(&mut self) -> impl Layer + where + S: Subscriber + for<'span> tracing_subscriber::registry::LookupSpan<'span>, + { + tracing_subscriber::fmt::layer() + .without_time() + .with_target(false) + .with_level(false) + .with_filter(filter_fn(log_filter)) + } + + fn logfile_layer(&mut self, logfile: PathBuf) -> impl Layer + where + S: Subscriber + for<'span> tracing_subscriber::registry::LookupSpan<'span>, + { + let (non_blocking, guard) = self.init_writer(logfile); + self.logfile = Some(guard); + tracing_subscriber::fmt::layer() + .with_ansi(false) + .with_writer(non_blocking) + .with_filter(filter_fn(verbose_filter)) + } +} From 0a000e0c693ccfebcd97425bb1560eb3012c3afc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Borgna?= <121866228+aborgna-q@users.noreply.github.com> Date: Fri, 22 Sep 2023 17:15:16 +0200 Subject: [PATCH 5/6] chore: Update hugr (#131) `Port::index()` now requires importing a trait --- Cargo.toml | 2 +- src/circuit/command.rs | 2 +- src/optimiser/taso.rs | 4 +++- src/portmatching/matcher.rs | 1 + 4 files changed, 6 insertions(+), 3 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 6161abcf..f7a52f64 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -73,7 +73,7 @@ members = ["pyrs", "compile-matcher", "taso-optimiser"] [workspace.dependencies] -quantinuum-hugr = { git = "https://github.com/CQCL-DEV/hugr", rev = "981f4f9" } +quantinuum-hugr = { git = "https://github.com/CQCL-DEV/hugr", rev = "19ed0fc" } portgraph = { version = "0.9", features = ["serde"] } pyo3 = { version = "0.19" } itertools = { version = "0.11.0" } diff --git a/src/circuit/command.rs b/src/circuit/command.rs index b830b313..094255e1 100644 --- a/src/circuit/command.rs +++ b/src/circuit/command.rs @@ -5,7 +5,7 @@ use std::collections::{HashMap, HashSet}; use std::iter::FusedIterator; -use hugr::hugr::NodeType; +use hugr::hugr::{NodeType, PortIndex}; use hugr::ops::{OpTag, OpTrait}; use petgraph::visit as pv; diff --git a/src/optimiser/taso.rs b/src/optimiser/taso.rs index 66c91ab6..ad88bc51 100644 --- a/src/optimiser/taso.rs +++ b/src/optimiser/taso.rs @@ -21,7 +21,6 @@ mod worker; use crossbeam_channel::select; pub use eq_circ_class::{load_eccs_json_file, EqCircClass}; -use std::io; use std::num::NonZeroUsize; use std::time::{Duration, Instant}; @@ -37,6 +36,9 @@ use crate::rewrite::Rewriter; use self::log::TasoLogger; +#[cfg(feature = "portmatching")] +use std::io; + /// The TASO optimiser. /// /// Adapted from [Quartz][], and originally [TASO][]. diff --git a/src/portmatching/matcher.rs b/src/portmatching/matcher.rs index 960ea346..2da3fa45 100644 --- a/src/portmatching/matcher.rs +++ b/src/portmatching/matcher.rs @@ -9,6 +9,7 @@ use std::{ use super::{CircuitPattern, NodeID, PEdge, PNode}; use hugr::hugr::views::sibling_subgraph::{ConvexChecker, InvalidReplacement, InvalidSubgraph}; +use hugr::hugr::PortIndex; use hugr::{hugr::views::SiblingSubgraph, ops::OpType, Hugr, Node, Port}; use itertools::Itertools; use portmatching::{ From 33cbc71d30d7c7e475108bd58771ffa8cba2cfc0 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Mon, 25 Sep 2023 11:41:50 +0100 Subject: [PATCH 6/6] chore: add just to devenv (#132) --- devenv.nix | 1 + justfile | 8 ++++++++ 2 files changed, 9 insertions(+) create mode 100644 justfile diff --git a/devenv.nix b/devenv.nix index 2fb98128..35be7a3e 100644 --- a/devenv.nix +++ b/devenv.nix @@ -8,6 +8,7 @@ packages = lib.optionals pkgs.stdenv.isDarwin (with pkgs.darwin.apple_sdk; [ frameworks.CoreServices frameworks.CoreFoundation + pkgs.just ]); # Certain Rust tools won't work without this diff --git a/justfile b/justfile new file mode 100644 index 00000000..bbee89e0 --- /dev/null +++ b/justfile @@ -0,0 +1,8 @@ +test: + cargo test + +fix: + cargo clippy --fix --allow-staged + +ptest: + (cd pyrs && maturin develop && pytest) \ No newline at end of file