Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Embarrassingly parallel TASO #149

Merged
merged 15 commits into from
Oct 5, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ members = ["pyrs", "compile-rewriter", "taso-optimiser"]

[workspace.dependencies]

quantinuum-hugr = { git = "https://github.com/CQCL-DEV/hugr", rev = "af664e3" }
quantinuum-hugr = { git = "https://github.com/CQCL-DEV/hugr", rev = "0ce711b" }
portgraph = { version = "0.9", features = ["serde"] }
pyo3 = { version = "0.19" }
itertools = { version = "0.11.0" }
Expand Down
1 change: 1 addition & 0 deletions pyrs/src/optimiser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ impl PyDefaultTasoOptimiser {
taso_logger,
timeout,
n_threads.unwrap_or(NonZeroUsize::new(1).unwrap()),
false,
)
})
}
Expand Down
116 changes: 103 additions & 13 deletions src/optimiser/taso.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,12 @@ mod worker;

use crossbeam_channel::select;
pub use eq_circ_class::{load_eccs_json_file, EqCircClass};
use hugr::hugr::HugrError;
pub use log::TasoLogger;

use std::num::NonZeroUsize;
use std::time::{Duration, Instant};
use std::{mem, thread};

use fxhash::FxHashSet;
use hugr::Hugr;
Expand All @@ -32,6 +34,7 @@ use crate::circuit::CircuitHash;
use crate::optimiser::taso::hugr_pchannel::HugrPriorityChannel;
use crate::optimiser::taso::hugr_pqueue::{Entry, HugrPQ};
use crate::optimiser::taso::worker::TasoWorker;
use crate::passes::CircuitChunks;
use crate::rewrite::strategy::RewriteStrategy;
use crate::rewrite::Rewriter;

Expand Down Expand Up @@ -78,8 +81,14 @@ where
/// Run the TASO optimiser on a circuit.
///
/// A timeout (in seconds) can be provided.
pub fn optimise(&self, circ: &Hugr, timeout: Option<u64>, n_threads: NonZeroUsize) -> Hugr {
self.optimise_with_log(circ, Default::default(), timeout, n_threads)
pub fn optimise(
&self,
circ: &Hugr,
timeout: Option<u64>,
n_threads: NonZeroUsize,
split_circuit: bool,
) -> Hugr {
self.optimise_with_log(circ, Default::default(), timeout, n_threads, split_circuit)
}

/// Run the TASO optimiser on a circuit with logging activated.
Expand All @@ -91,7 +100,13 @@ where
log_config: TasoLogger,
timeout: Option<u64>,
n_threads: NonZeroUsize,
split_circuit: bool,
) -> Hugr {
if split_circuit && n_threads.get() > 1 {
return self
.split_run(circ, log_config, timeout, n_threads)
.unwrap();
}
match n_threads.get() {
1 => self.taso(circ, log_config, timeout),
_ => self.taso_multithreaded(circ, log_config, timeout, n_threads),
Expand Down Expand Up @@ -285,20 +300,78 @@ where

best_circ
}

/// Optimise the circuit by splitting it into chunks and running the
/// single-threaded optimiser on each chunk in its own thread.
///
/// Chunks are capped at `ceil(cx_cost / n_threads)` CX gates (and at least
/// one). Once every worker has reported back, the optimised chunks are
/// reassembled into a single circuit.
///
/// # Errors
///
/// Returns the [`HugrError`] raised while reassembling the chunks.
///
/// # Panics
///
/// Panics if a worker thread cannot be spawned, or if a worker terminates
/// without sending back its result.
#[tracing::instrument(target = "taso::metrics", skip(self, circ, logger))]
fn split_run(
    &self,
    circ: &Hugr,
    mut logger: TasoLogger,
    timeout: Option<u64>,
    n_threads: NonZeroUsize,
) -> Result<Hugr, HugrError> {
    // TODO: Add a parameter to set other split cost functions?
    // (In contrast to `self.cost`, the split cost is counted per node.)
    let circ_cx_cost = cost_functions::num_cx_gates(circ);
    // Ceiling division that never yields a zero-sized chunk.
    let max_cx_cost = circ_cx_cost.saturating_sub(1) / n_threads.get() + 1;
    logger.log(format!(
        "Splitting circuit with cost {circ_cx_cost} into chunks of at most {max_cx_cost} CX gates"
    ));
    let mut chunks = CircuitChunks::split_with_cost(circ, max_cx_cost, cost_functions::cx_cost);

    let initial_cost = (self.cost)(circ);
    logger.log_best(initial_cost);

    // Spawn one worker per chunk, keeping its join handle and the receiving
    // end of the channel it reports its result on.
    let n_chunks = chunks.len();
    let mut handles = Vec::with_capacity(n_chunks);
    let mut receivers = Vec::with_capacity(n_chunks);
    for (i, slot) in chunks.iter_mut().enumerate() {
        let (tx, rx) = crossbeam_channel::unbounded();
        let optimiser = self.clone();
        // Move the chunk into the worker; the slot is refilled with the
        // optimised circuit below.
        let chunk = mem::take(slot);
        let chunk_cx_cost = cost_functions::num_cx_gates(&chunk);
        logger.log(format!("Chunk {i} has {chunk_cx_cost} CX gates",));
        let handle = thread::Builder::new()
            .name(format!("chunk-{}", i))
            .spawn(move || {
                let optimised =
                    optimiser.optimise(&chunk, timeout, NonZeroUsize::new(1).unwrap(), false);
                tx.send(optimised).unwrap();
            })
            .unwrap();
        handles.push(handle);
        receivers.push(rx);
    }

    // Collect the results in chunk order, blocking on each worker in turn.
    for (slot, rx) in chunks.iter_mut().zip(&receivers) {
        *slot = rx
            .recv()
            .unwrap_or_else(|_| panic!("Worker thread panicked"));
    }

    let best_circ = chunks.reassemble()?;
    let best_circ_cost = (self.cost)(&best_circ);
    if best_circ_cost < initial_cost {
        logger.log_best(best_circ_cost);
    }

    logger.log_processing_end(n_threads.get(), best_circ_cost, true, false);
    for handle in handles {
        handle.join().unwrap();
    }

    Ok(best_circ)
}
}

#[cfg(feature = "portmatching")]
mod taso_default {
use hugr::ops::OpType;
use hugr::HugrView;
use std::io;
use std::path::Path;

use crate::ops::op_matches;
use crate::rewrite::ecc_rewriter::RewriterSerialisationError;
use crate::rewrite::strategy::ExhaustiveRewriteStrategy;
use crate::rewrite::ECCRewriter;
use crate::T2Op;

use super::*;

Expand All @@ -314,7 +387,11 @@ mod taso_default {
pub fn default_with_eccs_json_file(eccs_path: impl AsRef<Path>) -> io::Result<Self> {
let rewriter = ECCRewriter::try_from_eccs_json_file(eccs_path)?;
let strategy = ExhaustiveRewriteStrategy::exhaustive_cx();
Ok(TasoOptimiser::new(rewriter, strategy, num_cx_gates))
Ok(TasoOptimiser::new(
rewriter,
strategy,
cost_functions::num_cx_gates,
))
}

/// A sane default optimiser using a precompiled binary rewriter.
Expand All @@ -323,15 +400,28 @@ mod taso_default {
) -> Result<Self, RewriterSerialisationError> {
let rewriter = ECCRewriter::load_binary(rewriter_path)?;
let strategy = ExhaustiveRewriteStrategy::exhaustive_cx();
Ok(TasoOptimiser::new(rewriter, strategy, num_cx_gates))
Ok(TasoOptimiser::new(
rewriter,
strategy,
cost_functions::num_cx_gates,
))
}
}

fn num_cx_gates(circ: &Hugr) -> usize {
circ.nodes()
.filter(|&n| op_matches(circ.get_optype(n), T2Op::CX))
.count()
}
}
#[cfg(feature = "portmatching")]
pub use taso_default::DefaultTasoOptimiser;

/// Cost functions based on counting CX gates.
mod cost_functions {
    use super::*;
    use crate::ops::op_matches;
    use crate::T2Op;
    use hugr::{HugrView, Node};

    /// Returns the number of CX gates in `circ`, computed as the sum of
    /// [`cx_cost`] over every node of the circuit.
    pub fn num_cx_gates(circ: &Hugr) -> usize {
        circ.nodes()
            .fold(0, |total, node| total + cx_cost(circ, node))
    }

    /// Returns the CX cost of a single node: `1` if the operation at `node`
    /// matches `T2Op::CX`, `0` otherwise.
    pub fn cx_cost(circ: &Hugr, node: Node) -> usize {
        let is_cx = op_matches(circ.get_optype(node), T2Op::CX);
        usize::from(is_cx)
    }
}
4 changes: 2 additions & 2 deletions src/optimiser/taso/log.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,13 +82,13 @@ impl<'w> TasoLogger<'w> {

/// Internal function to log general events, normally printed to stdout.
#[inline]
fn log(&self, msg: impl AsRef<str>) {
pub fn log(&self, msg: impl AsRef<str>) {
tracing::info!(target: LOG_TARGET, "{}", msg.as_ref());
}

/// Internal function to log information on the progress of the optimization.
#[inline]
fn progress(&self, msg: impl AsRef<str>) {
pub fn progress(&self, msg: impl AsRef<str>) {
tracing::info!(target: PROGRESS_TARGET, "{}", msg.as_ref());
aborgna-q marked this conversation as resolved.
Show resolved Hide resolved
}
}
Expand Down
70 changes: 58 additions & 12 deletions src/passes/chunks.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
//! Utilities for splitting a circuit into chunks and reassembling them.

use std::collections::HashMap;
use std::mem;
use std::ops::{Index, IndexMut};

use hugr::builder::{Dataflow, DataflowHugr, FunctionBuilder};
use hugr::builder::{Container, FunctionBuilder};
use hugr::extension::ExtensionSet;
use hugr::hugr::hugrmut::HugrMut;
use hugr::hugr::views::sibling_subgraph::ConvexChecker;
Expand All @@ -13,7 +15,6 @@ use hugr::types::{FunctionType, Signature};
use hugr::{Hugr, HugrView, Node, Port, Wire};
use itertools::Itertools;

use crate::extension::REGISTRY;
use crate::Circuit;

#[cfg(feature = "pyo3")]
Expand All @@ -38,9 +39,9 @@ pub struct Chunk {
/// The extracted circuit.
pub circ: Hugr,
/// The original wires connected to the input.
pub inputs: Vec<ChunkConnection>,
inputs: Vec<ChunkConnection>,
/// The original wires connected to the output.
pub outputs: Vec<ChunkConnection>,
outputs: Vec<ChunkConnection>,
}

impl Chunk {
Expand Down Expand Up @@ -170,6 +171,17 @@ impl CircuitChunks {
///
/// The circuit is split into chunks of at most `max_size` gates.
pub fn split(circ: &impl Circuit, max_size: usize) -> Self {
// Delegate to `split_with_cost` with a constant per-node cost of 1, so
// that `max_size` acts as the cost limit for each chunk.
Self::split_with_cost(circ, max_size, |_, _| 1)
}

/// Split a circuit into chunks.
///
/// The circuit is split into chunks of at most `max_cost`, using the provided cost function.
pub fn split_with_cost<C: Circuit>(
circ: &C,
max_cost: usize,
node_cost: impl Fn(&C, Node) -> usize,
) -> Self {
let root_meta = circ.get_metadata(circ.root()).clone();
let signature = circ.circuit_signature().clone();

Expand All @@ -186,7 +198,12 @@ impl CircuitChunks {

let mut chunks = Vec::new();
let mut convex_checker = ConvexChecker::new(circ);
for commands in &circ.commands().map(|cmd| cmd.node()).chunks(max_size) {
let mut running_cost = 0;
for (_, commands) in &circ.commands().map(|cmd| cmd.node()).group_by(|&node| {
let group = running_cost / max_cost;
running_cost += node_cost(circ, node);
group
}) {
chunks.push(Chunk::extract(circ, commands, &mut convex_checker));
}
Self {
Expand All @@ -211,10 +228,10 @@ impl CircuitChunks {
input_extensions: ExtensionSet::new(),
};

let builder = FunctionBuilder::new(name, signature).unwrap();
let inputs = builder.input_wires();
// TODO: Use the correct REGISTRY if the method accepts custom input resources.
let mut reassembled = builder.finish_hugr_with_outputs(inputs, &REGISTRY).unwrap();
let mut builder = FunctionBuilder::new(name, signature).unwrap();
// Take the unfinished Hugr from the builder, to avoid unnecessary
// validation checks that require connecting the inputs and outputs.
let mut reassembled = mem::take(builder.hugr_mut());
let root = reassembled.root();
let [reassembled_input, reassembled_output] = reassembled.get_io(root).unwrap();

Expand All @@ -229,7 +246,6 @@ impl CircuitChunks {
.iter()
.zip(reassembled.node_outputs(reassembled_input))
{
reassembled.disconnect(reassembled_input, port)?;
sources.insert(connection, (reassembled_input, port));
}
for (&connection, port) in self
Expand Down Expand Up @@ -267,9 +283,24 @@ impl CircuitChunks {
}

/// Returns a list of references to the split circuits.
pub fn circuits(&self) -> impl Iterator<Item = &Hugr> {
pub fn iter(&self) -> impl Iterator<Item = &Hugr> {
self.chunks.iter().map(|chunk| &chunk.circ)
}

/// Returns a list of references to the split circuits.
pub fn iter_mut(&mut self) -> impl Iterator<Item = &mut Hugr> {
self.chunks.iter_mut().map(|chunk| &mut chunk.circ)
}

/// Returns the number of chunks the circuit is currently split into.
pub fn len(&self) -> usize {
self.chunks.len()
}

/// Returns `true` if there are no chunks, i.e. when [`Self::len`] is zero.
pub fn is_empty(&self) -> bool {
self.chunks.is_empty()
}
}

#[cfg(feature = "pyo3")]
Expand All @@ -285,7 +316,7 @@ impl CircuitChunks {
/// Returns clones of the split circuits.
#[pyo3(name = "circuits")]
fn py_circuits(&self) -> PyResult<Vec<Py<PyAny>>> {
self.circuits()
self.iter()
.map(|hugr| SerialCircuit::encode(hugr)?.to_tket1())
.collect()
}
Expand All @@ -304,9 +335,24 @@ impl CircuitChunks {
}
}

/// Indexes the chunks by position, giving access to each chunk's circuit.
impl Index<usize> for CircuitChunks {
    type Output = Hugr;

    /// Returns a reference to the circuit of the chunk at position `index`.
    ///
    /// Panics if `index` is out of bounds of the chunk list.
    fn index(&self, index: usize) -> &Self::Output {
        let chunk = &self.chunks[index];
        &chunk.circ
    }
}

/// Mutable positional access to each chunk's circuit, e.g. to replace a
/// chunk with an optimised version before reassembling.
impl IndexMut<usize> for CircuitChunks {
    /// Returns a mutable reference to the circuit of the chunk at position
    /// `index`.
    ///
    /// Panics if `index` is out of bounds of the chunk list.
    fn index_mut(&mut self, index: usize) -> &mut Self::Output {
        let chunk = &mut self.chunks[index];
        &mut chunk.circ
    }
}

#[cfg(test)]
mod test {
use crate::circuit::CircuitHash;
use crate::extension::REGISTRY;
use crate::utils::build_simple_circuit;
use crate::T2Op;

Expand Down
15 changes: 13 additions & 2 deletions taso-optimiser/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,12 @@ struct CmdLineArgs {
help = "The number of threads to use. By default, use a single thread."
)]
n_threads: Option<NonZeroUsize>,
/// Split the circuit into chunks and optimise each one in a separate thread.
#[arg(
long = "split-circ",
help = "Split the circuit into chunks and optimize each one in a separate thread. Use `-j` to specify the number of threads to use."
)]
split_circ: bool,
}

fn save_tk1_json_file(path: impl AsRef<Path>, circ: &Hugr) -> Result<(), std::io::Error> {
Expand All @@ -96,7 +102,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
// Setup tracing subscribers for stdout and file logging.
//
// We need to keep the object around to keep the logging active.
let _tracer = Tracer::setup_tracing(opts.logfile);
let _tracer = Tracer::setup_tracing(opts.logfile, opts.split_circ);

let input_path = Path::new(&opts.input);
let output_path = Path::new(&opts.output);
Expand Down Expand Up @@ -125,8 +131,13 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
.unwrap_or(NonZeroUsize::new(1).unwrap());
println!("Using {n_threads} threads");

if opts.split_circ && n_threads.get() > 1 {
println!("Splitting circuit into {n_threads} chunks.");
}

println!("Optimising...");
let opt_circ = optimiser.optimise_with_log(&circ, taso_logger, opts.timeout, n_threads);
let opt_circ =
optimiser.optimise_with_log(&circ, taso_logger, opts.timeout, n_threads, opts.split_circ);

println!("Saving result");
save_tk1_json_file(output_path, &opt_circ)?;
Expand Down
Loading