From 4e5ac313a81e51a90a7b67fdc2a509950dd58ce5 Mon Sep 17 00:00:00 2001 From: Agustin Borgna Date: Thu, 5 Oct 2023 19:39:13 +0100 Subject: [PATCH] feat: queue size argument for TASO --- pyrs/src/optimiser.rs | 2 ++ src/optimiser/taso.rs | 48 +++++++++++++++++++++++++++----------- src/optimiser/taso/log.rs | 10 ++++---- taso-optimiser/src/main.rs | 43 +++++++++++++++++++++++----------- 4 files changed, 70 insertions(+), 33 deletions(-) diff --git a/pyrs/src/optimiser.rs b/pyrs/src/optimiser.rs index a6efab45..2415a613 100644 --- a/pyrs/src/optimiser.rs +++ b/pyrs/src/optimiser.rs @@ -63,6 +63,7 @@ impl PyDefaultTasoOptimiser { n_threads: Option, split_circ: Option, log_progress: Option, + queue_size: Option, ) -> PyResult { let taso_logger = log_progress .map(|file_name| { @@ -78,6 +79,7 @@ impl PyDefaultTasoOptimiser { timeout, n_threads.unwrap_or(NonZeroUsize::new(1).unwrap()), split_circ.unwrap_or(false), + queue_size.unwrap_or(10_000), ) }) } diff --git a/src/optimiser/taso.rs b/src/optimiser/taso.rs index a15e8c5d..e0efd669 100644 --- a/src/optimiser/taso.rs +++ b/src/optimiser/taso.rs @@ -91,8 +91,16 @@ where timeout: Option, n_threads: NonZeroUsize, split_circuit: bool, + queue_size: usize, ) -> Hugr { - self.optimise_with_log(circ, Default::default(), timeout, n_threads, split_circuit) + self.optimise_with_log( + circ, + Default::default(), + timeout, + n_threads, + split_circuit, + queue_size, + ) } /// Run the TASO optimiser on a circuit with logging activated. @@ -105,20 +113,27 @@ where timeout: Option, n_threads: NonZeroUsize, split_circuit: bool, + queue_size: usize, ) -> Hugr { if split_circuit && n_threads.get() > 1 { return self - .split_run(circ, log_config, timeout, n_threads) + .split_run(circ, log_config, timeout, n_threads, queue_size) .unwrap(); } match n_threads.get() { - 1 => self.taso(circ, log_config, timeout), - _ => self.taso_multithreaded(circ, log_config, timeout, n_threads), + 1 => self.taso(circ, log_config, timeout, queue_size), + _ => self.taso_multithreaded(circ, log_config, timeout, n_threads, queue_size), } } #[tracing::instrument(target = "taso::metrics", skip(self, circ, logger))] - fn taso(&self, circ: &Hugr, mut logger: TasoLogger, timeout: Option) -> Hugr { + fn taso( + &self, + circ: &Hugr, + mut logger: TasoLogger, + timeout: Option, + queue_size: usize, + ) -> Hugr { let start_time = Instant::now(); let mut best_circ = circ.clone(); @@ -130,12 +145,11 @@ where seen_hashes.insert(circ.circuit_hash()); // The priority queue of circuits to be processed (this should not get big) - const PRIORITY_QUEUE_CAPACITY: usize = 10_000; let cost_fn = { let strategy = self.strategy.clone(); move |circ: &'_ Hugr| strategy.circuit_cost(circ) }; - let mut pq = HugrPQ::with_capacity(cost_fn, PRIORITY_QUEUE_CAPACITY); + let mut pq = HugrPQ::with_capacity(cost_fn, queue_size); pq.push(circ.clone()); let mut circ_cnt = 1; @@ -160,9 +174,9 @@ where pq.push_unchecked(new_circ, new_circ_hash, new_circ_cost); } - if pq.len() >= PRIORITY_QUEUE_CAPACITY { + if pq.len() >= queue_size { // Haircut to keep the queue size manageable - pq.truncate(PRIORITY_QUEUE_CAPACITY / 2); + pq.truncate(queue_size / 2); } if let Some(timeout) = timeout { @@ -188,16 +202,16 @@ where mut logger: TasoLogger, timeout: Option, n_threads: NonZeroUsize, + queue_size: usize, ) -> Hugr { let n_threads: usize = n_threads.get(); - const PRIORITY_QUEUE_CAPACITY: usize = 10_000; // multi-consumer priority channel for queuing circuits to be processed by the workers let cost_fn = { let strategy = self.strategy.clone(); move |circ: &'_ Hugr| strategy.circuit_cost(circ) }; - let mut pq = HugrPriorityChannel::init(cost_fn, PRIORITY_QUEUE_CAPACITY * n_threads); + let mut pq = HugrPriorityChannel::init(cost_fn, queue_size); let initial_circ_hash = circ.circuit_hash(); let mut best_circ = circ.clone(); @@ -293,6 +307,7 @@ where mut logger: TasoLogger, timeout: Option, n_threads: NonZeroUsize, + queue_size: usize, ) -> Result { let circ_cost = self.cost(circ); let max_chunk_cost = circ_cost.clone().div_cost(n_threads); @@ -317,8 +332,13 @@ where let join = thread::Builder::new() .name(format!("chunk-{}", i)) .spawn(move || { - let res = - taso.optimise(&chunk, timeout, NonZeroUsize::new(1).unwrap(), false); + let res = taso.optimise( + &chunk, + timeout, + NonZeroUsize::new(1).unwrap(), + false, + queue_size, + ); tx.send(res).unwrap(); }) .unwrap(); @@ -428,7 +448,7 @@ mod tests { #[rstest] fn rz_rz_cancellation(rz_rz: Hugr, taso_opt: DefaultTasoOptimiser) { - let opt_rz = taso_opt.optimise(&rz_rz, None, 1.try_into().unwrap(), false); + let opt_rz = taso_opt.optimise(&rz_rz, None, 1.try_into().unwrap(), false, 10_000); let cmds = opt_rz .commands() .map(|cmd| { diff --git a/src/optimiser/taso/log.rs b/src/optimiser/taso/log.rs index e88cbe41..402ed69c 100644 --- a/src/optimiser/taso/log.rs +++ b/src/optimiser/taso/log.rs @@ -52,12 +52,12 @@ impl<'w> TasoLogger<'w> { needs_joining: bool, timeout: bool, ) { - if timeout { - self.log("Timeout"); - } - self.log("Optimisation finished"); + match timeout { + true => self.log("Optimisation finished (timeout)"), + false => self.log("Optimisation finished"), + }; self.log(format!("Tried {circuit_count} circuits")); - self.log(format!("END RESULT: {:?}", best_cost)); + self.log(format!("---- END RESULT: {:?} ----", best_cost)); if needs_joining { self.log("Joining worker threads"); } diff --git a/taso-optimiser/src/main.rs b/taso-optimiser/src/main.rs index abb887c6..cda9a546 100644 --- a/taso-optimiser/src/main.rs +++ b/taso-optimiser/src/main.rs @@ -78,26 +78,41 @@ struct CmdLineArgs { help = "The number of threads to use. By default, use a single thread." )] n_threads: Option, - /// Number of threads (default=1) + /// Split the circuit into chunks, and process them separately. #[arg( long = "split-circ", help = "Split the circuit into chunks and optimize each one in a separate thread. Use `-j` to specify the number of threads to use." )] split_circ: bool, + /// Max queue size. + #[arg( + short = 'q', + long = "queue-size", + default_value = "10000", + value_name = "QUEUE_SIZE", + help = "The priority queue size. Defaults to 10_000." + )] + queue_size: usize, } fn main() -> Result<(), Box> { let opts = CmdLineArgs::parse(); - // Setup tracing subscribers for stdout and file logging. - // - // We need to keep the object around to keep the logging active. - let _tracer = Tracer::setup_tracing(opts.logfile, opts.split_circ); - let input_path = Path::new(&opts.input); let output_path = Path::new(&opts.output); let ecc_path = Path::new(&opts.eccs); + let n_threads = opts + .n_threads + // TODO: Default to multithreading once that produces better results. + //.or_else(|| std::thread::available_parallelism().ok()) + .unwrap_or(NonZeroUsize::new(1).unwrap()); + + // Setup tracing subscribers for stdout and file logging. + // + // We need to keep the object around to keep the logging active. + let _tracer = Tracer::setup_tracing(opts.logfile, n_threads.get() > 1); + // TODO: Remove this from the Logger, and use tracing events instead. let circ_candidates_csv = BufWriter::new(File::create("best_circs.csv")?); @@ -113,12 +128,6 @@ fn main() -> Result<(), Box> { ); exit(1); }; - - let n_threads = opts - .n_threads - // TODO: Default to multithreading once that produces better results. - //.or_else(|| std::thread::available_parallelism().ok()) - .unwrap_or(NonZeroUsize::new(1).unwrap()); println!("Using {n_threads} threads"); if opts.split_circ && n_threads.get() > 1 { @@ -126,8 +135,14 @@ fn main() -> Result<(), Box> { } println!("Optimising..."); - let opt_circ = - optimiser.optimise_with_log(&circ, taso_logger, opts.timeout, n_threads, opts.split_circ); + let opt_circ = optimiser.optimise_with_log( + &circ, + taso_logger, + opts.timeout, + n_threads, + opts.split_circ, + opts.queue_size, + ); println!("Saving result"); save_tk1_json_file(&opt_circ, output_path)?;