Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Embarrassingly parallel TASO #149

Merged
merged 15 commits into from
Oct 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ fxhash = "0.2.1"
rmp-serde = { version = "1.1.2", optional = true }
delegate = "0.10.0"
csv = { version = "1.2.2" }
chrono = { version ="0.4.30" }
chrono = { version = "0.4.30" }
bytemuck = "1.14.0"
stringreader = "0.1.1"
crossbeam-channel = "0.5.8"
Expand Down
17 changes: 17 additions & 0 deletions pyrs/src/optimiser.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
//! PyO3 wrapper for the TASO circuit optimiser.

use std::io::BufWriter;
use std::{fs, num::NonZeroUsize, path::PathBuf};

use pyo3::prelude::*;
Expand Down Expand Up @@ -42,16 +43,31 @@ impl PyDefaultTasoOptimiser {
///
/// Returns an optimised circuit and optionally logs the progress to a CSV
/// file.
///
/// # Parameters
///
/// * `circ`: The circuit to optimise.
/// * `timeout`: The timeout in seconds.
/// * `n_threads`: The number of threads to use.
/// * `split_circ`: Whether to split the circuit into chunks before
/// processing.
///
/// If this option is set, the optimiser will divide the circuit into
/// `n_threads` chunks and optimise each on a separate thread.
/// * `log_progress`: The path to a CSV file to log progress to.
///
pub fn optimise(
&self,
circ: PyObject,
timeout: Option<u64>,
n_threads: Option<NonZeroUsize>,
split_circ: Option<bool>,
log_progress: Option<PathBuf>,
) -> PyResult<PyObject> {
let taso_logger = log_progress
.map(|file_name| {
let log_file = fs::File::create(file_name).unwrap();
let log_file = BufWriter::new(log_file);
TasoLogger::new(log_file)
})
.unwrap_or_default();
Expand All @@ -61,6 +77,7 @@ impl PyDefaultTasoOptimiser {
taso_logger,
timeout,
n_threads.unwrap_or(NonZeroUsize::new(1).unwrap()),
split_circ.unwrap_or(false),
)
})
}
Expand Down
98 changes: 83 additions & 15 deletions src/optimiser/taso.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,19 +21,24 @@ mod worker;
use crossbeam_channel::select;
pub use eq_circ_class::{load_eccs_json_file, EqCircClass};
use fxhash::FxHashSet;
use hugr::hugr::HugrError;
pub use log::TasoLogger;

use std::num::NonZeroUsize;
use std::time::{Duration, Instant};
use std::{mem, thread};

use hugr::Hugr;

use crate::circuit::cost::CircuitCost;
use crate::circuit::CircuitHash;
use crate::optimiser::taso::hugr_pchannel::{HugrPriorityChannel, PriorityChannelLog};
use crate::optimiser::taso::hugr_pqueue::{Entry, HugrPQ};
use crate::optimiser::taso::worker::TasoWorker;
use crate::passes::CircuitChunks;
use crate::rewrite::strategy::RewriteStrategy;
use crate::rewrite::Rewriter;
use crate::Circuit;

/// The TASO optimiser.
///
Expand Down Expand Up @@ -75,16 +80,19 @@ impl<R, S> TasoOptimiser<R, S>
where
R: Rewriter + Send + Clone + 'static,
S: RewriteStrategy + Send + Sync + Clone + 'static,
S::Cost: serde::Serialize,
S::Cost: serde::Serialize + Send + Sync,
{
/// Run the TASO optimiser on a circuit.
///
/// A timeout (in seconds) can be provided.
pub fn optimise(&self, circ: &Hugr, timeout: Option<u64>, n_threads: NonZeroUsize) -> Hugr
where
S::Cost: Send + Sync + Clone,
{
self.optimise_with_log(circ, Default::default(), timeout, n_threads)
pub fn optimise(
&self,
circ: &Hugr,
timeout: Option<u64>,
n_threads: NonZeroUsize,
split_circuit: bool,
) -> Hugr {
// Convenience entry point: delegates to `optimise_with_log` with a
// default (no-op) logger, forwarding all other arguments unchanged.
self.optimise_with_log(circ, Default::default(), timeout, n_threads, split_circuit)
}

/// Run the TASO optimiser on a circuit with logging activated.
Expand All @@ -96,10 +104,13 @@ where
log_config: TasoLogger,
timeout: Option<u64>,
n_threads: NonZeroUsize,
) -> Hugr
where
S::Cost: Send + Sync + Clone,
{
split_circuit: bool,
) -> Hugr {
if split_circuit && n_threads.get() > 1 {
return self
.split_run(circ, log_config, timeout, n_threads)
.unwrap();
}
match n_threads.get() {
1 => self.taso(circ, log_config, timeout),
_ => self.taso_multithreaded(circ, log_config, timeout, n_threads),
Expand Down Expand Up @@ -177,10 +188,7 @@ where
mut logger: TasoLogger,
timeout: Option<u64>,
n_threads: NonZeroUsize,
) -> Hugr
where
S::Cost: Send + Sync + Clone,
{
) -> Hugr {
let n_threads: usize = n_threads.get();
const PRIORITY_QUEUE_CAPACITY: usize = 10_000;

Expand Down Expand Up @@ -276,6 +284,66 @@ where

best_circ
}

/// Split the circuit into chunks and process each in a separate thread.
///
/// Divides `circ` into chunks whose cost is at most the total circuit cost
/// divided by `n_threads`, optimises each chunk single-threaded on its own
/// worker thread, then reassembles the optimised chunks into one circuit.
///
/// Returns the reassembled circuit, or a `HugrError` if reassembly fails.
///
/// NOTE(review): chunk boundaries are fixed before optimisation, so
/// rewrites cannot cross chunk borders — the combined result may be worse
/// than a whole-circuit multithreaded run. Confirm this trade-off is
/// acceptable for callers enabling `split_circ`.
#[tracing::instrument(target = "taso::metrics", skip(self, circ, logger))]
fn split_run(
&self,
circ: &Hugr,
mut logger: TasoLogger,
timeout: Option<u64>,
n_threads: NonZeroUsize,
) -> Result<Hugr, HugrError> {
// Total cost of the input circuit; used both to size the chunks and as
// the baseline when deciding whether the final result is an improvement.
let circ_cost = self.cost(circ);
let max_chunk_cost = circ_cost.clone().div_cost(n_threads);
logger.log(format!(
"Splitting circuit with cost {:?} into chunks of at most {max_chunk_cost:?}.",
circ_cost.clone()
));
let mut chunks =
CircuitChunks::split_with_cost(circ, max_chunk_cost, |op| self.strategy.op_cost(op));

logger.log_best(circ_cost.clone());

// Spawn one worker thread per chunk. Each worker runs a single-threaded,
// non-splitting optimisation (`split_circuit = false`) on its chunk and
// sends the result back over its own channel. `mem::take` moves the chunk
// out so it can be sent into the thread closure.
let (joins, rx_work): (Vec<_>, Vec<_>) = chunks
.iter_mut()
.enumerate()
.map(|(i, chunk)| {
let (tx, rx) = crossbeam_channel::unbounded();
let taso = self.clone();
let chunk = mem::take(chunk);
let chunk_cx_cost = chunk.circuit_cost(|op| self.strategy.op_cost(op));
// NOTE(review): the message says "CX gates" but the value is the
// generic strategy cost — consider rewording if costs other than
// CX counts are used.
logger.log(format!("Chunk {i} has {chunk_cx_cost:?} CX gates",));
let join = thread::Builder::new()
.name(format!("chunk-{}", i))
.spawn(move || {
let res =
taso.optimise(&chunk, timeout, NonZeroUsize::new(1).unwrap(), false);
tx.send(res).unwrap();
})
.unwrap();
(join, rx)
})
.unzip();

// Collect each worker's optimised chunk, in order. A closed channel means
// the worker panicked before sending its result.
for i in 0..chunks.len() {
let res = rx_work[i]
.recv()
.unwrap_or_else(|_| panic!("Worker thread panicked"));
chunks[i] = res;
}

// Stitch the optimised chunks back into a single circuit and only report
// a new best if the reassembled circuit beats the original cost.
let best_circ = chunks.reassemble()?;
let best_circ_cost = self.cost(&best_circ);
if best_circ_cost.clone() < circ_cost {
logger.log_best(best_circ_cost.clone());
}

logger.log_processing_end(n_threads.get(), best_circ_cost, true, false);
// All results were already received above, so these joins should return
// promptly; a panic here propagates any worker failure.
joins.into_iter().for_each(|j| j.join().unwrap());

Ok(best_circ)
}
}

#[cfg(feature = "portmatching")]
Expand Down Expand Up @@ -360,7 +428,7 @@ mod tests {

#[rstest]
fn rz_rz_cancellation(rz_rz: Hugr, taso_opt: DefaultTasoOptimiser) {
let opt_rz = taso_opt.optimise(&rz_rz, None, 1.try_into().unwrap());
let opt_rz = taso_opt.optimise(&rz_rz, None, 1.try_into().unwrap(), false);
let cmds = opt_rz
.commands()
.map(|cmd| {
Expand Down
8 changes: 4 additions & 4 deletions src/optimiser/taso/log.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,15 +80,15 @@ impl<'w> TasoLogger<'w> {
}
}

/// Internal function to log general events, normally printed to stdout.
/// Log general events, normally printed to stdout.
#[inline]
fn log(&self, msg: impl AsRef<str>) {
pub fn log(&self, msg: impl AsRef<str>) {
tracing::info!(target: LOG_TARGET, "{}", msg.as_ref());
}

/// Internal function to log information on the progress of the optimization.
/// Log verbose information on the progress of the optimization.
#[inline]
fn progress(&self, msg: impl AsRef<str>) {
pub fn progress(&self, msg: impl AsRef<str>) {
tracing::info!(target: PROGRESS_TARGET, "{}", msg.as_ref());
}
}
Expand Down
15 changes: 13 additions & 2 deletions taso-optimiser/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,12 @@ struct CmdLineArgs {
help = "The number of threads to use. By default, use a single thread."
)]
n_threads: Option<NonZeroUsize>,
/// Split the circuit into chunks and optimise each in a separate thread (default=false)
#[arg(
long = "split-circ",
help = "Split the circuit into chunks and optimize each one in a separate thread. Use `-j` to specify the number of threads to use."
)]
split_circ: bool,
}

fn main() -> Result<(), Box<dyn std::error::Error>> {
Expand All @@ -86,7 +92,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
// Setup tracing subscribers for stdout and file logging.
//
// We need to keep the object around to keep the logging active.
let _tracer = Tracer::setup_tracing(opts.logfile);
let _tracer = Tracer::setup_tracing(opts.logfile, opts.split_circ);

let input_path = Path::new(&opts.input);
let output_path = Path::new(&opts.output);
Expand Down Expand Up @@ -115,8 +121,13 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
.unwrap_or(NonZeroUsize::new(1).unwrap());
println!("Using {n_threads} threads");

if opts.split_circ && n_threads.get() > 1 {
println!("Splitting circuit into {n_threads} chunks.");
}

println!("Optimising...");
let opt_circ = optimiser.optimise_with_log(&circ, taso_logger, opts.timeout, n_threads);
let opt_circ =
optimiser.optimise_with_log(&circ, taso_logger, opts.timeout, n_threads, opts.split_circ);

println!("Saving result");
save_tk1_json_file(&opt_circ, output_path)?;
Expand Down
12 changes: 7 additions & 5 deletions taso-optimiser/src/tracing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,11 @@ pub struct Tracer {

impl Tracer {
/// Setup tracing subscribers for stdout and file logging.
pub fn setup_tracing(logfile: Option<PathBuf>) -> Self {
pub fn setup_tracing(logfile: Option<PathBuf>, show_threads: bool) -> Self {
let mut tracer = Self::default();
tracing_subscriber::registry()
.with(tracer.stdout_layer())
.with(logfile.map(|f| tracer.logfile_layer(f)))
.with(tracer.stdout_layer(show_threads))
.with(logfile.map(|f| tracer.logfile_layer(f, show_threads)))
.init();
tracer
}
Expand All @@ -48,25 +48,27 @@ impl Tracer {
}

/// Clean log with the most important events.
fn stdout_layer<S>(&mut self) -> impl Layer<S>
fn stdout_layer<S>(&mut self, show_threads: bool) -> impl Layer<S>
where
S: Subscriber + for<'span> tracing_subscriber::registry::LookupSpan<'span>,
{
tracing_subscriber::fmt::layer()
.without_time()
.with_target(false)
.with_level(false)
.with_thread_names(show_threads)
.with_filter(filter_fn(log_filter))
}

fn logfile_layer<S>(&mut self, logfile: PathBuf) -> impl Layer<S>
fn logfile_layer<S>(&mut self, logfile: PathBuf, show_threads: bool) -> impl Layer<S>
where
S: Subscriber + for<'span> tracing_subscriber::registry::LookupSpan<'span>,
{
let (non_blocking, guard) = self.init_writer(logfile);
self.logfile = Some(guard);
tracing_subscriber::fmt::layer()
.with_ansi(false)
.with_thread_names(show_threads)
.with_writer(non_blocking)
.with_filter(filter_fn(verbose_filter))
}
Expand Down