Skip to content

Commit

Permalink
feat: taso + circuit splitting
Browse files Browse the repository at this point in the history
  • Loading branch information
aborgna-q committed Sep 29, 2023
1 parent ca2d3df commit 2b64a63
Show file tree
Hide file tree
Showing 7 changed files with 182 additions and 33 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ members = ["pyrs", "compile-rewriter", "taso-optimiser"]

[workspace.dependencies]

quantinuum-hugr = { git = "https://github.com/CQCL-DEV/hugr", rev = "af664e3" }
quantinuum-hugr = { git = "https://github.com/CQCL-DEV/hugr", rev = "0ce711b" }
portgraph = { version = "0.9", features = ["serde"] }
pyo3 = { version = "0.19" }
itertools = { version = "0.11.0" }
Expand Down
1 change: 1 addition & 0 deletions pyrs/src/optimiser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ impl PyDefaultTasoOptimiser {
taso_logger,
timeout,
n_threads.unwrap_or(NonZeroUsize::new(1).unwrap()),
false,
)
})
}
Expand Down
116 changes: 103 additions & 13 deletions src/optimiser/taso.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,12 @@ mod worker;

use crossbeam_channel::select;
pub use eq_circ_class::{load_eccs_json_file, EqCircClass};
use hugr::hugr::HugrError;
pub use log::TasoLogger;

use std::num::NonZeroUsize;
use std::time::{Duration, Instant};
use std::{mem, thread};

use fxhash::FxHashSet;
use hugr::Hugr;
Expand All @@ -32,6 +34,7 @@ use crate::circuit::CircuitHash;
use crate::optimiser::taso::hugr_pchannel::HugrPriorityChannel;
use crate::optimiser::taso::hugr_pqueue::{Entry, HugrPQ};
use crate::optimiser::taso::worker::TasoWorker;
use crate::passes::CircuitChunks;
use crate::rewrite::strategy::RewriteStrategy;
use crate::rewrite::Rewriter;

Expand Down Expand Up @@ -78,8 +81,14 @@ where
/// Run the TASO optimiser on a circuit.
///
/// A timeout (in seconds) can be provided.
pub fn optimise(&self, circ: &Hugr, timeout: Option<u64>, n_threads: NonZeroUsize) -> Hugr {
self.optimise_with_log(circ, Default::default(), timeout, n_threads)
pub fn optimise(
&self,
circ: &Hugr,
timeout: Option<u64>,
n_threads: NonZeroUsize,
split_circuit: bool,
) -> Hugr {
self.optimise_with_log(circ, Default::default(), timeout, n_threads, split_circuit)
}

/// Run the TASO optimiser on a circuit with logging activated.
Expand All @@ -91,7 +100,13 @@ where
log_config: TasoLogger,
timeout: Option<u64>,
n_threads: NonZeroUsize,
split_circuit: bool,
) -> Hugr {
if split_circuit && n_threads.get() > 1 {
return self
.split_run(circ, log_config, timeout, n_threads)
.unwrap();
}
match n_threads.get() {
1 => self.taso(circ, log_config, timeout),
_ => self.taso_multithreaded(circ, log_config, timeout, n_threads),
Expand Down Expand Up @@ -285,20 +300,78 @@ where

best_circ
}

/// Split the circuit into chunks and process each in a separate thread.
#[tracing::instrument(target = "taso::metrics", skip(self, circ, logger))]
fn split_run(
&self,
circ: &Hugr,
mut logger: TasoLogger,
timeout: Option<u64>,
n_threads: NonZeroUsize,
) -> Result<Hugr, HugrError> {
// TODO: Add a parameter to set other split cost functions?
// (In contrast to `self.cost`, this is a counts the per-node cost)
let circ_cx_cost = cost_functions::num_cx_gates(circ);
let max_cx_cost = (circ_cx_cost.saturating_sub(1)) / n_threads.get() + 1;
logger.log(format!(
"Splitting circuit with cost {circ_cx_cost} into chunks of at most {max_cx_cost} CX gates"
));
let mut chunks = CircuitChunks::split_with_cost(circ, max_cx_cost, cost_functions::cx_cost);

let circ_cost = (self.cost)(circ);
logger.log_best(circ_cost);

let (joins, rx_work): (Vec<_>, Vec<_>) = chunks
.iter_mut()
.enumerate()
.map(|(i, chunk)| {
let (tx, rx) = crossbeam_channel::unbounded();
let taso = self.clone();
let chunk = mem::take(chunk);
let chunk_cx_cost = cost_functions::num_cx_gates(&chunk);
logger.log(format!("Chunk {i} has {chunk_cx_cost} CX gates",));
let join = thread::Builder::new()
.name(format!("chunk-{}", i))
.spawn(move || {
let res =
taso.optimise(&chunk, timeout, NonZeroUsize::new(1).unwrap(), false);
tx.send(res).unwrap();
})
.unwrap();
(join, rx)
})
.unzip();

for i in 0..chunks.len() {
let res = rx_work[i]
.recv()
.unwrap_or_else(|_| panic!("Worker thread panicked"));
chunks[i] = res;
}

let best_circ = chunks.reassemble()?;
let best_circ_cost = (self.cost)(&best_circ);
if best_circ_cost < circ_cost {
logger.log_best(best_circ_cost);
}

logger.log_processing_end(n_threads.get(), best_circ_cost, true, false);
joins.into_iter().for_each(|j| j.join().unwrap());

Ok(best_circ)
}
}

#[cfg(feature = "portmatching")]
mod taso_default {
use hugr::ops::OpType;
use hugr::HugrView;
use std::io;
use std::path::Path;

use crate::ops::op_matches;
use crate::rewrite::ecc_rewriter::RewriterSerialisationError;
use crate::rewrite::strategy::ExhaustiveRewriteStrategy;
use crate::rewrite::ECCRewriter;
use crate::T2Op;

use super::*;

Expand All @@ -314,7 +387,11 @@ mod taso_default {
pub fn default_with_eccs_json_file(eccs_path: impl AsRef<Path>) -> io::Result<Self> {
let rewriter = ECCRewriter::try_from_eccs_json_file(eccs_path)?;
let strategy = ExhaustiveRewriteStrategy::exhaustive_cx();
Ok(TasoOptimiser::new(rewriter, strategy, num_cx_gates))
Ok(TasoOptimiser::new(
rewriter,
strategy,
cost_functions::num_cx_gates,
))
}

/// A sane default optimiser using a precompiled binary rewriter.
Expand All @@ -323,15 +400,28 @@ mod taso_default {
) -> Result<Self, RewriterSerialisationError> {
let rewriter = ECCRewriter::load_binary(rewriter_path)?;
let strategy = ExhaustiveRewriteStrategy::exhaustive_cx();
Ok(TasoOptimiser::new(rewriter, strategy, num_cx_gates))
Ok(TasoOptimiser::new(
rewriter,
strategy,
cost_functions::num_cx_gates,
))
}
}

fn num_cx_gates(circ: &Hugr) -> usize {
circ.nodes()
.filter(|&n| op_matches(circ.get_optype(n), T2Op::CX))
.count()
}
}
#[cfg(feature = "portmatching")]
pub use taso_default::DefaultTasoOptimiser;

mod cost_functions {
use super::*;
use crate::ops::op_matches;
use crate::T2Op;
use hugr::{HugrView, Node};

pub fn num_cx_gates(circ: &Hugr) -> usize {
circ.nodes().map(|n| cx_cost(circ, n)).sum()
}

pub fn cx_cost(circ: &Hugr, node: Node) -> usize {
op_matches(circ.get_optype(node), T2Op::CX) as usize
}
}
4 changes: 2 additions & 2 deletions src/optimiser/taso/log.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,13 +82,13 @@ impl<'w> TasoLogger<'w> {

/// Internal function to log general events, normally printed to stdout.
#[inline]
fn log(&self, msg: impl AsRef<str>) {
pub fn log(&self, msg: impl AsRef<str>) {
tracing::info!(target: LOG_TARGET, "{}", msg.as_ref());
}

/// Internal function to log information on the progress of the optimization.
#[inline]
fn progress(&self, msg: impl AsRef<str>) {
pub fn progress(&self, msg: impl AsRef<str>) {
tracing::info!(target: PROGRESS_TARGET, "{}", msg.as_ref());
}
}
Expand Down
70 changes: 58 additions & 12 deletions src/passes/chunks.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
//! Utility
use std::collections::HashMap;
use std::mem;
use std::ops::{Index, IndexMut};

use hugr::builder::{Dataflow, DataflowHugr, FunctionBuilder};
use hugr::builder::{Container, FunctionBuilder};
use hugr::extension::ExtensionSet;
use hugr::hugr::hugrmut::HugrMut;
use hugr::hugr::views::sibling_subgraph::ConvexChecker;
Expand All @@ -13,7 +15,6 @@ use hugr::types::{FunctionType, Signature};
use hugr::{Hugr, HugrView, Node, Port, Wire};
use itertools::Itertools;

use crate::extension::REGISTRY;
use crate::Circuit;

#[cfg(feature = "pyo3")]
Expand All @@ -38,9 +39,9 @@ pub struct Chunk {
/// The extracted circuit.
pub circ: Hugr,
/// The original wires connected to the input.
pub inputs: Vec<ChunkConnection>,
inputs: Vec<ChunkConnection>,
/// The original wires connected to the output.
pub outputs: Vec<ChunkConnection>,
outputs: Vec<ChunkConnection>,
}

impl Chunk {
Expand Down Expand Up @@ -170,6 +171,17 @@ impl CircuitChunks {
///
/// The circuit is split into chunks of at most `max_size` gates.
pub fn split(circ: &impl Circuit, max_size: usize) -> Self {
Self::split_with_cost(circ, max_size, |_, _| 1)
}

/// Split a circuit into chunks.
///
/// The circuit is split into chunks of at most `max_cost`, using the provided cost function.
pub fn split_with_cost<C: Circuit>(
circ: &C,
max_cost: usize,
node_cost: impl Fn(&C, Node) -> usize,
) -> Self {
let root_meta = circ.get_metadata(circ.root()).clone();
let signature = circ.circuit_signature().clone();

Expand All @@ -186,7 +198,12 @@ impl CircuitChunks {

let mut chunks = Vec::new();
let mut convex_checker = ConvexChecker::new(circ);
for commands in &circ.commands().map(|cmd| cmd.node()).chunks(max_size) {
let mut running_cost = 0;
for (_, commands) in &circ.commands().map(|cmd| cmd.node()).group_by(|&node| {
let group = running_cost / max_cost;
running_cost += node_cost(circ, node);
group
}) {
chunks.push(Chunk::extract(circ, commands, &mut convex_checker));
}
Self {
Expand All @@ -211,10 +228,10 @@ impl CircuitChunks {
input_extensions: ExtensionSet::new(),
};

let builder = FunctionBuilder::new(name, signature).unwrap();
let inputs = builder.input_wires();
// TODO: Use the correct REGISTRY if the method accepts custom input resources.
let mut reassembled = builder.finish_hugr_with_outputs(inputs, &REGISTRY).unwrap();
let mut builder = FunctionBuilder::new(name, signature).unwrap();
// Take the unfinished Hugr from the builder, to avoid unnecessary
// validation checks that require connecting the inputs an outputs.
let mut reassembled = mem::take(builder.hugr_mut());
let root = reassembled.root();
let [reassembled_input, reassembled_output] = reassembled.get_io(root).unwrap();

Expand All @@ -229,7 +246,6 @@ impl CircuitChunks {
.iter()
.zip(reassembled.node_outputs(reassembled_input))
{
reassembled.disconnect(reassembled_input, port)?;
sources.insert(connection, (reassembled_input, port));
}
for (&connection, port) in self
Expand Down Expand Up @@ -267,9 +283,24 @@ impl CircuitChunks {
}

/// Returns a list of references to the split circuits.
pub fn circuits(&self) -> impl Iterator<Item = &Hugr> {
pub fn iter(&self) -> impl Iterator<Item = &Hugr> {
self.chunks.iter().map(|chunk| &chunk.circ)
}

/// Returns a list of references to the split circuits.
pub fn iter_mut(&mut self) -> impl Iterator<Item = &mut Hugr> {
self.chunks.iter_mut().map(|chunk| &mut chunk.circ)
}

/// Returns the number of chunks.
pub fn len(&self) -> usize {
self.chunks.len()
}

/// Returns `true` if there are no chunks.
pub fn is_empty(&self) -> bool {
self.chunks.is_empty()
}
}

#[cfg(feature = "pyo3")]
Expand All @@ -285,7 +316,7 @@ impl CircuitChunks {
/// Returns clones of the split circuits.
#[pyo3(name = "circuits")]
fn py_circuits(&self) -> PyResult<Vec<Py<PyAny>>> {
self.circuits()
self.iter()
.map(|hugr| SerialCircuit::encode(hugr)?.to_tket1())
.collect()
}
Expand All @@ -304,9 +335,24 @@ impl CircuitChunks {
}
}

impl Index<usize> for CircuitChunks {
type Output = Hugr;

fn index(&self, index: usize) -> &Self::Output {
&self.chunks[index].circ
}
}

impl IndexMut<usize> for CircuitChunks {
fn index_mut(&mut self, index: usize) -> &mut Self::Output {
&mut self.chunks[index].circ
}
}

#[cfg(test)]
mod test {
use crate::circuit::CircuitHash;
use crate::extension::REGISTRY;
use crate::utils::build_simple_circuit;
use crate::T2Op;

Expand Down
15 changes: 13 additions & 2 deletions taso-optimiser/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,12 @@ struct CmdLineArgs {
help = "The number of threads to use. By default, use a single thread."
)]
n_threads: Option<NonZeroUsize>,
/// Number of threads (default=1)
#[arg(
long = "split-circ",
help = "Split the circuit into chunks and optimize each one in a separate thread. Use `-j` to specify the number of threads to use."
)]
split_circ: bool,
}

fn save_tk1_json_file(path: impl AsRef<Path>, circ: &Hugr) -> Result<(), std::io::Error> {
Expand All @@ -96,7 +102,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
// Setup tracing subscribers for stdout and file logging.
//
// We need to keep the object around to keep the logging active.
let _tracer = Tracer::setup_tracing(opts.logfile);
let _tracer = Tracer::setup_tracing(opts.logfile, opts.split_circ);

let input_path = Path::new(&opts.input);
let output_path = Path::new(&opts.output);
Expand Down Expand Up @@ -125,8 +131,13 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
.unwrap_or(NonZeroUsize::new(1).unwrap());
println!("Using {n_threads} threads");

if opts.split_circ && n_threads.get() > 1 {
println!("Splitting circuit into {n_threads} chunks.");
}

println!("Optimising...");
let opt_circ = optimiser.optimise_with_log(&circ, taso_logger, opts.timeout, n_threads);
let opt_circ =
optimiser.optimise_with_log(&circ, taso_logger, opts.timeout, n_threads, opts.split_circ);

println!("Saving result");
save_tk1_json_file(output_path, &opt_circ)?;
Expand Down
Loading

0 comments on commit 2b64a63

Please sign in to comment.