Skip to content

Commit

Permalink
Merge branch 'main' into feat/chunk-id-wires
Browse files Browse the repository at this point in the history
  • Loading branch information
aborgna-q authored Oct 5, 2023
2 parents 395f0d7 + 4174e14 commit ea64397
Show file tree
Hide file tree
Showing 6 changed files with 125 additions and 27 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ fxhash = "0.2.1"
rmp-serde = { version = "1.1.2", optional = true }
delegate = "0.10.0"
csv = { version = "1.2.2" }
chrono = { version ="0.4.30" }
chrono = { version = "0.4.30" }
bytemuck = "1.14.0"
stringreader = "0.1.1"
crossbeam-channel = "0.5.8"
Expand Down
17 changes: 17 additions & 0 deletions pyrs/src/optimiser.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
//! PyO3 wrapper for the TASO circuit optimiser.
use std::io::BufWriter;
use std::{fs, num::NonZeroUsize, path::PathBuf};

use pyo3::prelude::*;
Expand Down Expand Up @@ -42,16 +43,31 @@ impl PyDefaultTasoOptimiser {
///
/// Returns an optimised circuit and optionally log the progress to a CSV
/// file.
///
/// # Parameters
///
/// * `circ`: The circuit to optimise.
/// * `timeout`: The timeout in seconds.
/// * `n_threads`: The number of threads to use.
/// * `split_circ`: Whether to split the circuit into chunks before
/// processing.
///
/// If this option is set, the optimise will divide the circuit into
/// `n_threads` chunks and optimise each on a separate thread.
/// * `log_progress`: The path to a CSV file to log progress to.
///
pub fn optimise(
&self,
circ: PyObject,
timeout: Option<u64>,
n_threads: Option<NonZeroUsize>,
split_circ: Option<bool>,
log_progress: Option<PathBuf>,
) -> PyResult<PyObject> {
let taso_logger = log_progress
.map(|file_name| {
let log_file = fs::File::create(file_name).unwrap();
let log_file = BufWriter::new(log_file);
TasoLogger::new(log_file)
})
.unwrap_or_default();
Expand All @@ -61,6 +77,7 @@ impl PyDefaultTasoOptimiser {
taso_logger,
timeout,
n_threads.unwrap_or(NonZeroUsize::new(1).unwrap()),
split_circ.unwrap_or(false),
)
})
}
Expand Down
98 changes: 83 additions & 15 deletions src/optimiser/taso.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,19 +21,24 @@ mod worker;
use crossbeam_channel::select;
pub use eq_circ_class::{load_eccs_json_file, EqCircClass};
use fxhash::FxHashSet;
use hugr::hugr::HugrError;
pub use log::TasoLogger;

use std::num::NonZeroUsize;
use std::time::{Duration, Instant};
use std::{mem, thread};

use hugr::Hugr;

use crate::circuit::cost::CircuitCost;
use crate::circuit::CircuitHash;
use crate::optimiser::taso::hugr_pchannel::{HugrPriorityChannel, PriorityChannelLog};
use crate::optimiser::taso::hugr_pqueue::{Entry, HugrPQ};
use crate::optimiser::taso::worker::TasoWorker;
use crate::passes::CircuitChunks;
use crate::rewrite::strategy::RewriteStrategy;
use crate::rewrite::Rewriter;
use crate::Circuit;

/// The TASO optimiser.
///
Expand Down Expand Up @@ -75,16 +80,19 @@ impl<R, S> TasoOptimiser<R, S>
where
R: Rewriter + Send + Clone + 'static,
S: RewriteStrategy + Send + Sync + Clone + 'static,
S::Cost: serde::Serialize,
S::Cost: serde::Serialize + Send + Sync,
{
/// Run the TASO optimiser on a circuit.
///
/// A timeout (in seconds) can be provided.
pub fn optimise(&self, circ: &Hugr, timeout: Option<u64>, n_threads: NonZeroUsize) -> Hugr
where
S::Cost: Send + Sync + Clone,
{
self.optimise_with_log(circ, Default::default(), timeout, n_threads)
pub fn optimise(
&self,
circ: &Hugr,
timeout: Option<u64>,
n_threads: NonZeroUsize,
split_circuit: bool,
) -> Hugr {
self.optimise_with_log(circ, Default::default(), timeout, n_threads, split_circuit)
}

/// Run the TASO optimiser on a circuit with logging activated.
Expand All @@ -96,10 +104,13 @@ where
log_config: TasoLogger,
timeout: Option<u64>,
n_threads: NonZeroUsize,
) -> Hugr
where
S::Cost: Send + Sync + Clone,
{
split_circuit: bool,
) -> Hugr {
if split_circuit && n_threads.get() > 1 {
return self
.split_run(circ, log_config, timeout, n_threads)
.unwrap();
}
match n_threads.get() {
1 => self.taso(circ, log_config, timeout),
_ => self.taso_multithreaded(circ, log_config, timeout, n_threads),
Expand Down Expand Up @@ -177,10 +188,7 @@ where
mut logger: TasoLogger,
timeout: Option<u64>,
n_threads: NonZeroUsize,
) -> Hugr
where
S::Cost: Send + Sync + Clone,
{
) -> Hugr {
let n_threads: usize = n_threads.get();
const PRIORITY_QUEUE_CAPACITY: usize = 10_000;

Expand Down Expand Up @@ -276,6 +284,66 @@ where

best_circ
}

/// Split the circuit into chunks and process each in a separate thread.
#[tracing::instrument(target = "taso::metrics", skip(self, circ, logger))]
fn split_run(
&self,
circ: &Hugr,
mut logger: TasoLogger,
timeout: Option<u64>,
n_threads: NonZeroUsize,
) -> Result<Hugr, HugrError> {
let circ_cost = self.cost(circ);
let max_chunk_cost = circ_cost.clone().div_cost(n_threads);
logger.log(format!(
"Splitting circuit with cost {:?} into chunks of at most {max_chunk_cost:?}.",
circ_cost.clone()
));
let mut chunks =
CircuitChunks::split_with_cost(circ, max_chunk_cost, |op| self.strategy.op_cost(op));

logger.log_best(circ_cost.clone());

let (joins, rx_work): (Vec<_>, Vec<_>) = chunks
.iter_mut()
.enumerate()
.map(|(i, chunk)| {
let (tx, rx) = crossbeam_channel::unbounded();
let taso = self.clone();
let chunk = mem::take(chunk);
let chunk_cx_cost = chunk.circuit_cost(|op| self.strategy.op_cost(op));
logger.log(format!("Chunk {i} has {chunk_cx_cost:?} CX gates",));
let join = thread::Builder::new()
.name(format!("chunk-{}", i))
.spawn(move || {
let res =
taso.optimise(&chunk, timeout, NonZeroUsize::new(1).unwrap(), false);
tx.send(res).unwrap();
})
.unwrap();
(join, rx)
})
.unzip();

for i in 0..chunks.len() {
let res = rx_work[i]
.recv()
.unwrap_or_else(|_| panic!("Worker thread panicked"));
chunks[i] = res;
}

let best_circ = chunks.reassemble()?;
let best_circ_cost = self.cost(&best_circ);
if best_circ_cost.clone() < circ_cost {
logger.log_best(best_circ_cost.clone());
}

logger.log_processing_end(n_threads.get(), best_circ_cost, true, false);
joins.into_iter().for_each(|j| j.join().unwrap());

Ok(best_circ)
}
}

#[cfg(feature = "portmatching")]
Expand Down Expand Up @@ -360,7 +428,7 @@ mod tests {

#[rstest]
fn rz_rz_cancellation(rz_rz: Hugr, taso_opt: DefaultTasoOptimiser) {
let opt_rz = taso_opt.optimise(&rz_rz, None, 1.try_into().unwrap());
let opt_rz = taso_opt.optimise(&rz_rz, None, 1.try_into().unwrap(), false);
let cmds = opt_rz
.commands()
.map(|cmd| {
Expand Down
8 changes: 4 additions & 4 deletions src/optimiser/taso/log.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,15 +80,15 @@ impl<'w> TasoLogger<'w> {
}
}

/// Internal function to log general events, normally printed to stdout.
/// Log general events, normally printed to stdout.
#[inline]
fn log(&self, msg: impl AsRef<str>) {
pub fn log(&self, msg: impl AsRef<str>) {
tracing::info!(target: LOG_TARGET, "{}", msg.as_ref());
}

/// Internal function to log information on the progress of the optimization.
/// Log verbose information on the progress of the optimization.
#[inline]
fn progress(&self, msg: impl AsRef<str>) {
pub fn progress(&self, msg: impl AsRef<str>) {
tracing::info!(target: PROGRESS_TARGET, "{}", msg.as_ref());
}
}
Expand Down
15 changes: 13 additions & 2 deletions taso-optimiser/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,12 @@ struct CmdLineArgs {
help = "The number of threads to use. By default, use a single thread."
)]
n_threads: Option<NonZeroUsize>,
/// Number of threads (default=1)
#[arg(
long = "split-circ",
help = "Split the circuit into chunks and optimize each one in a separate thread. Use `-j` to specify the number of threads to use."
)]
split_circ: bool,
}

fn main() -> Result<(), Box<dyn std::error::Error>> {
Expand All @@ -86,7 +92,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
// Setup tracing subscribers for stdout and file logging.
//
// We need to keep the object around to keep the logging active.
let _tracer = Tracer::setup_tracing(opts.logfile);
let _tracer = Tracer::setup_tracing(opts.logfile, opts.split_circ);

let input_path = Path::new(&opts.input);
let output_path = Path::new(&opts.output);
Expand Down Expand Up @@ -115,8 +121,13 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
.unwrap_or(NonZeroUsize::new(1).unwrap());
println!("Using {n_threads} threads");

if opts.split_circ && n_threads.get() > 1 {
println!("Splitting circuit into {n_threads} chunks.");
}

println!("Optimising...");
let opt_circ = optimiser.optimise_with_log(&circ, taso_logger, opts.timeout, n_threads);
let opt_circ =
optimiser.optimise_with_log(&circ, taso_logger, opts.timeout, n_threads, opts.split_circ);

println!("Saving result");
save_tk1_json_file(&opt_circ, output_path)?;
Expand Down
12 changes: 7 additions & 5 deletions taso-optimiser/src/tracing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,11 @@ pub struct Tracer {

impl Tracer {
/// Setup tracing subscribers for stdout and file logging.
pub fn setup_tracing(logfile: Option<PathBuf>) -> Self {
pub fn setup_tracing(logfile: Option<PathBuf>, show_threads: bool) -> Self {
let mut tracer = Self::default();
tracing_subscriber::registry()
.with(tracer.stdout_layer())
.with(logfile.map(|f| tracer.logfile_layer(f)))
.with(tracer.stdout_layer(show_threads))
.with(logfile.map(|f| tracer.logfile_layer(f, show_threads)))
.init();
tracer
}
Expand All @@ -48,25 +48,27 @@ impl Tracer {
}

/// Clean log with the most important events.
fn stdout_layer<S>(&mut self) -> impl Layer<S>
fn stdout_layer<S>(&mut self, show_threads: bool) -> impl Layer<S>
where
S: Subscriber + for<'span> tracing_subscriber::registry::LookupSpan<'span>,
{
tracing_subscriber::fmt::layer()
.without_time()
.with_target(false)
.with_level(false)
.with_thread_names(show_threads)
.with_filter(filter_fn(log_filter))
}

fn logfile_layer<S>(&mut self, logfile: PathBuf) -> impl Layer<S>
fn logfile_layer<S>(&mut self, logfile: PathBuf, show_threads: bool) -> impl Layer<S>
where
S: Subscriber + for<'span> tracing_subscriber::registry::LookupSpan<'span>,
{
let (non_blocking, guard) = self.init_writer(logfile);
self.logfile = Some(guard);
tracing_subscriber::fmt::layer()
.with_ansi(false)
.with_thread_names(show_threads)
.with_writer(non_blocking)
.with_filter(filter_fn(verbose_filter))
}
Expand Down

0 comments on commit ea64397

Please sign in to comment.