diff --git a/src/rewrite/ecc_rewriter.rs b/src/rewrite/ecc_rewriter.rs index dbfbae9be..f258fe024 100644 --- a/src/rewrite/ecc_rewriter.rs +++ b/src/rewrite/ecc_rewriter.rs @@ -40,6 +40,7 @@ struct TargetID(usize); /// Valid rewrites turn a non-representative circuit into its representative, /// or a representative circuit into any of the equivalent non-representative /// circuits. +#[derive(Debug, Clone)] pub struct ECCRewriter { /// Matcher for finding patterns. matcher: PatternMatcher, diff --git a/src/rewrite/strategy.rs b/src/rewrite/strategy.rs index cfea36f2b..645b5c161 100644 --- a/src/rewrite/strategy.rs +++ b/src/rewrite/strategy.rs @@ -86,6 +86,7 @@ impl RewriteStrategy for GreedyRewriteStrategy { /// strictly reduces the gate count. The default is gamma = 1.0001, as set /// in the Quartz paper. This essentially allows rewrites that improve or leave /// the number of nodes unchanged. +#[derive(Debug, Clone)] pub struct ExhaustiveRewriteStrategy { /// The gamma parameter. pub gamma: f64, diff --git a/taso-optimiser/src/main.rs b/taso-optimiser/src/main.rs index 3f6919c31..0a42f599c 100644 --- a/taso-optimiser/src/main.rs +++ b/taso-optimiser/src/main.rs @@ -5,7 +5,7 @@ use hugr::Hugr; use tket2::{ circuit::Circuit, json::{load_tk1_json_file, TKETDecode}, - passes::taso::taso, + passes::taso::{taso, taso_mpsc}, rewrite::{strategy::ExhaustiveRewriteStrategy, ECCRewriter}, }; use tket_json_rs::circuit_json::SerialCircuit; @@ -42,6 +42,15 @@ struct CmdLineArgs { help = "Sets the ECC file to use. It is a JSON file of Quartz-generated ECCs." )] eccs: String, + /// Number of threads (default=1) + #[arg( + short, + long, + default_value = 1, + value_name = "N_THREADS", + help = "The number of threads to use." + )] + n_threads: usize, } fn save_tk1_json_file(path: impl AsRef, circ: &Hugr) -> Result<(), std::io::Error> { @@ -66,7 +75,20 @@ fn main() { let strategy = ExhaustiveRewriteStrategy::default(); println!("Optimising..."); - let opt_circ = taso(circ, rewriter, strategy, |c| c.num_gates(), Some(100)); + let opt_circ = if opts.n_threads == 1 { + println!("Using single-threaded TASO"); + taso(circ, rewriter, strategy, |c| c.num_gates(), Some(100)) + } else { + println!("Using multi-threaded TASO with {} threads", opts.n_threads); + taso_mpsc( + circ, + rewriter, + strategy, + |c| c.num_gates(), + Some(100), + opts.n_threads, + ) + }; println!("Saving result"); save_tk1_json_file(output_path, &opt_circ).unwrap();