Skip to content

Commit

Permalink
feat: Improve TASO cost function and rewrite strategies (#154)
Browse files Browse the repository at this point in the history
  • Loading branch information
lmondada authored Oct 3, 2023
1 parent 742e64c commit 0dcf5e0
Show file tree
Hide file tree
Showing 6 changed files with 377 additions and 110 deletions.
10 changes: 10 additions & 0 deletions src/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,16 @@ impl T2Op {
_ => vec![],
}
}

/// Check if this op is a quantum op.
pub fn is_quantum(&self) -> bool {
use T2Op::*;
match self {
H | CX | T | S | X | Y | Z | Tdg | Sdg | ZZMax | RzF64 | RxF64 | PhasedX | ZZPhase
| CZ | TK1 => true,
AngleAdd | Measure => false,
}
}
}

/// Initialize a new custom symbolic expression constant op from a string.
Expand Down
136 changes: 98 additions & 38 deletions src/optimiser/taso.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ use crossbeam_channel::select;
pub use eq_circ_class::{load_eccs_json_file, EqCircClass};
pub use log::TasoLogger;

use std::fmt;
use std::num::NonZeroUsize;
use std::time::{Duration, Instant};

Expand Down Expand Up @@ -52,28 +53,30 @@ use crate::rewrite::Rewriter;
/// [Quartz]: https://arxiv.org/abs/2204.09033
/// [TASO]: https://dl.acm.org/doi/10.1145/3341301.3359630
#[derive(Clone, Debug)]
pub struct TasoOptimiser<R, S, C> {
pub struct TasoOptimiser<R, S> {
rewriter: R,
strategy: S,
cost: C,
}

impl<R, S, C> TasoOptimiser<R, S, C> {
impl<R, S> TasoOptimiser<R, S> {
/// Create a new TASO optimiser.
pub fn new(rewriter: R, strategy: S, cost: C) -> Self {
Self {
rewriter,
strategy,
cost,
}
pub fn new(rewriter: R, strategy: S) -> Self {
Self { rewriter, strategy }
}

fn cost(&self, circ: &Hugr) -> S::Cost
where
S: RewriteStrategy,
{
self.strategy.circuit_cost(circ)
}
}

impl<R, S, C> TasoOptimiser<R, S, C>
impl<R, S> TasoOptimiser<R, S>
where
R: Rewriter + Send + Clone + 'static,
S: RewriteStrategy + Send + Clone + 'static,
C: Fn(&Hugr) -> usize + Send + Sync + Clone + 'static,
S: RewriteStrategy + Send + Sync + Clone + 'static,
S::Cost: fmt::Debug + serde::Serialize,
{
/// Run the TASO optimiser on a circuit.
///
Expand Down Expand Up @@ -103,15 +106,19 @@ where
let start_time = Instant::now();

let mut best_circ = circ.clone();
let mut best_circ_cost = (self.cost)(circ);
logger.log_best(best_circ_cost);
let mut best_circ_cost = self.cost(circ);
logger.log_best(&best_circ_cost);

// Hash of seen circuits. Dot not store circuits as this map gets huge
let mut seen_hashes: FxHashSet<_> = FromIterator::from_iter([(circ.circuit_hash())]);

// The priority queue of circuits to be processed (this should not get big)
const PRIORITY_QUEUE_CAPACITY: usize = 10_000;
let mut pq = HugrPQ::with_capacity(&self.cost, PRIORITY_QUEUE_CAPACITY);
let cost_fn = {
let strategy = self.strategy.clone();
move |circ: &'_ Hugr| strategy.circuit_cost(circ)
};
let mut pq = HugrPQ::with_capacity(cost_fn, PRIORITY_QUEUE_CAPACITY);
pq.push(circ.clone());

let mut circ_cnt = 1;
Expand All @@ -120,7 +127,7 @@ where
if cost < best_circ_cost {
best_circ = circ.clone();
best_circ_cost = cost;
logger.log_best(best_circ_cost);
logger.log_best(&best_circ_cost);
}

let rewrites = self.rewriter.get_rewrites(&circ);
Expand Down Expand Up @@ -168,15 +175,19 @@ where
const PRIORITY_QUEUE_CAPACITY: usize = 10_000;

// multi-consumer priority channel for queuing circuits to be processed by the workers
let cost_fn = {
let strategy = self.strategy.clone();
move |circ: &'_ Hugr| strategy.circuit_cost(circ)
};
let (tx_work, rx_work) =
HugrPriorityChannel::init((self.cost).clone(), PRIORITY_QUEUE_CAPACITY * n_threads);
HugrPriorityChannel::init(cost_fn, PRIORITY_QUEUE_CAPACITY * n_threads);
// channel for sending circuits from threads back to main
let (tx_result, rx_result) = crossbeam_channel::unbounded();

let initial_circ_hash = circ.circuit_hash();
let mut best_circ = circ.clone();
let mut best_circ_cost = (self.cost)(&best_circ);
logger.log_best(best_circ_cost);
let mut best_circ_cost = self.cost(&best_circ);
logger.log_best(&best_circ_cost);

// Hash of seen circuits. Dot not store circuits as this map gets huge
let mut seen_hashes: FxHashSet<_> = FromIterator::from_iter([(initial_circ_hash)]);
Expand Down Expand Up @@ -239,13 +250,13 @@ where
}
seen_hashes.insert(*circ_hash);

let cost = (self.cost)(circ);
let cost = self.cost(circ);

// Check if we got a new best circuit
if cost < best_circ_cost {
best_circ = circ.clone();
best_circ_cost = cost;
logger.log_best(best_circ_cost);
logger.log_best(&best_circ_cost);
}
jobs_sent += 1;
}
Expand Down Expand Up @@ -289,49 +300,98 @@ where

#[cfg(feature = "portmatching")]
mod taso_default {
use hugr::ops::OpType;
use hugr::HugrView;
use std::io;
use std::path::Path;

use crate::ops::op_matches;
use hugr::ops::OpType;

use crate::rewrite::ecc_rewriter::RewriterSerialisationError;
use crate::rewrite::strategy::ExhaustiveRewriteStrategy;
use crate::rewrite::strategy::NonIncreasingGateCountStrategy;
use crate::rewrite::ECCRewriter;
use crate::T2Op;

use super::*;

/// The default TASO optimiser using ECC sets.
pub type DefaultTasoOptimiser = TasoOptimiser<
ECCRewriter,
ExhaustiveRewriteStrategy<fn(&OpType) -> bool>,
fn(&Hugr) -> usize,
NonIncreasingGateCountStrategy<fn(&OpType) -> usize, fn(&OpType) -> usize>,
>;

impl DefaultTasoOptimiser {
/// A sane default optimiser using the given ECC sets.
pub fn default_with_eccs_json_file(eccs_path: impl AsRef<Path>) -> io::Result<Self> {
let rewriter = ECCRewriter::try_from_eccs_json_file(eccs_path)?;
let strategy = ExhaustiveRewriteStrategy::exhaustive_cx();
Ok(TasoOptimiser::new(rewriter, strategy, num_cx_gates))
let strategy = NonIncreasingGateCountStrategy::default_cx();
Ok(TasoOptimiser::new(rewriter, strategy))
}

/// A sane default optimiser using a precompiled binary rewriter.
pub fn default_with_rewriter_binary(
rewriter_path: impl AsRef<Path>,
) -> Result<Self, RewriterSerialisationError> {
let rewriter = ECCRewriter::load_binary(rewriter_path)?;
let strategy = ExhaustiveRewriteStrategy::exhaustive_cx();
Ok(TasoOptimiser::new(rewriter, strategy, num_cx_gates))
let strategy = NonIncreasingGateCountStrategy::default_cx();
Ok(TasoOptimiser::new(rewriter, strategy))
}
}

fn num_cx_gates(circ: &Hugr) -> usize {
circ.nodes()
.filter(|&n| op_matches(circ.get_optype(n), T2Op::CX))
.count()
}
}
#[cfg(feature = "portmatching")]
pub use taso_default::DefaultTasoOptimiser;

#[cfg(test)]
#[cfg(feature = "portmatching")]
mod tests {
use hugr::{
builder::{DFGBuilder, Dataflow, DataflowHugr},
extension::prelude::QB_T,
std_extensions::arithmetic::float_types::FLOAT64_TYPE,
types::FunctionType,
Hugr,
};
use rstest::{fixture, rstest};

use crate::{extension::REGISTRY, Circuit, T2Op};

use super::{DefaultTasoOptimiser, TasoOptimiser};

#[fixture]
fn rz_rz() -> Hugr {
let input_t = vec![QB_T, FLOAT64_TYPE, FLOAT64_TYPE];
let output_t = vec![QB_T];
let mut h = DFGBuilder::new(FunctionType::new(input_t, output_t)).unwrap();

let mut inps = h.input_wires();
let qb = inps.next().unwrap();
let f1 = inps.next().unwrap();
let f2 = inps.next().unwrap();

let res = h.add_dataflow_op(T2Op::RzF64, [qb, f1]).unwrap();
let qb = res.outputs().next().unwrap();
let res = h.add_dataflow_op(T2Op::RzF64, [qb, f2]).unwrap();
let qb = res.outputs().next().unwrap();

h.finish_hugr_with_outputs([qb], &REGISTRY).unwrap()
}

#[fixture]
fn taso_opt() -> DefaultTasoOptimiser {
TasoOptimiser::default_with_eccs_json_file("test_files/small_eccs.json").unwrap()
}

#[rstest]
fn rz_rz_cancellation(rz_rz: Hugr, taso_opt: DefaultTasoOptimiser) {
let opt_rz = taso_opt.optimise(&rz_rz, None, 1.try_into().unwrap());
let cmds = opt_rz
.commands()
.map(|cmd| {
(
cmd.optype().try_into().unwrap(),
cmd.inputs().count(),
cmd.outputs().count(),
)
})
.collect::<Vec<(T2Op, _, _)>>();
let exp_cmds = vec![(T2Op::AngleAdd, 2, 1), (T2Op::RzF64, 2, 1)];
assert_eq!(cmds, exp_cmds);
}
}
22 changes: 11 additions & 11 deletions src/optimiser/taso/log.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
//! Logging utilities for the TASO optimiser.
use std::io;
use std::{fmt::Debug, io};

/// Logging configuration for the TASO optimiser.
#[derive(Default)]
Expand Down Expand Up @@ -35,8 +35,8 @@ impl<'w> TasoLogger<'w> {

/// Log a new best candidate
#[inline]
pub fn log_best(&mut self, best_cost: usize) {
self.log(format!("new best of size {}", best_cost));
pub fn log_best<C: Debug + serde::Serialize>(&mut self, best_cost: C) {
self.log(format!("new best of size {:?}", best_cost));
if let Some(csv_writer) = self.circ_candidates_csv.as_mut() {
csv_writer.serialize(BestCircSer::new(best_cost)).unwrap();
csv_writer.flush().unwrap();
Expand All @@ -45,10 +45,10 @@ impl<'w> TasoLogger<'w> {

/// Log the final optimised circuit
#[inline]
pub fn log_processing_end(
pub fn log_processing_end<C: Debug>(
&self,
circuit_count: usize,
best_cost: usize,
best_cost: C,
needs_joining: bool,
timeout: bool,
) {
Expand All @@ -57,7 +57,7 @@ impl<'w> TasoLogger<'w> {
}
self.log("Optimisation finished");
self.log(format!("Tried {circuit_count} circuits"));
self.log(format!("END RESULT: {}", best_cost));
self.log(format!("END RESULT: {:?}", best_cost));
if needs_joining {
self.log("Joining worker threads");
}
Expand Down Expand Up @@ -98,14 +98,14 @@ impl<'w> TasoLogger<'w> {
//
// TODO: Replace this fixed logging. Report back intermediate results.
#[derive(serde::Serialize, Clone, Debug)]
struct BestCircSer {
circ_len: usize,
struct BestCircSer<C> {
circ_cost: C,
time: String,
}

impl BestCircSer {
fn new(circ_len: usize) -> Self {
impl<C> BestCircSer<C> {
fn new(circ_cost: C) -> Self {
let time = chrono::Local::now().to_rfc3339();
Self { circ_len, time }
Self { circ_cost, time }
}
}
6 changes: 3 additions & 3 deletions src/rewrite/ecc_rewriter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,7 @@ mod tests {
let test_file = "test_files/small_eccs.json";
let rewriter = ECCRewriter::try_from_eccs_json_file(test_file).unwrap();
assert_eq!(rewriter.rewrite_rules.len(), rewriter.matcher.n_patterns());
assert_eq!(rewriter.targets.len(), 5 * 4 + 4 * 3);
assert_eq!(rewriter.targets.len(), 5 * 4 + 5 * 3);

// Assert that the rewrite rules are correct, i.e that the rewrite
// rules in the slice (k..=k+t) is given by [[k+1, ..., k+t], [k], ..., [k]]
Expand All @@ -361,8 +361,8 @@ mod tests {
curr_repr = TargetID(i);
}
}
// There should be 4x ECCs of size 3 and 5x ECCs of size 4
let exp_n_eccs_of_len = [0, 4 * 2 + 5 * 3, 4, 5];
// There should be 5x ECCs of size 3 and 5x ECCs of size 4
let exp_n_eccs_of_len = [0, 5 * 2 + 5 * 3, 5, 5];
assert_eq!(n_eccs_of_len, exp_n_eccs_of_len);
}

Expand Down
Loading

0 comments on commit 0dcf5e0

Please sign in to comment.