Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Improve TASO cost function and rewrite strategies #154

Merged
merged 7 commits into from
Oct 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions src/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,16 @@ impl T2Op {
_ => vec![],
}
}

/// Check if this op is a quantum op.
pub fn is_quantum(&self) -> bool {
use T2Op::*;
match self {
H | CX | T | S | X | Y | Z | Tdg | Sdg | ZZMax | RzF64 | RxF64 | PhasedX | ZZPhase
| CZ | TK1 => true,
AngleAdd | Measure => false,
}
}
}

/// Initialize a new custom symbolic expression constant op from a string.
Expand Down
136 changes: 98 additions & 38 deletions src/optimiser/taso.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ use crossbeam_channel::select;
pub use eq_circ_class::{load_eccs_json_file, EqCircClass};
pub use log::TasoLogger;

use std::fmt;
use std::num::NonZeroUsize;
use std::time::{Duration, Instant};

Expand Down Expand Up @@ -52,28 +53,30 @@ use crate::rewrite::Rewriter;
/// [Quartz]: https://arxiv.org/abs/2204.09033
/// [TASO]: https://dl.acm.org/doi/10.1145/3341301.3359630
#[derive(Clone, Debug)]
pub struct TasoOptimiser<R, S, C> {
pub struct TasoOptimiser<R, S> {
rewriter: R,
strategy: S,
cost: C,
}

impl<R, S, C> TasoOptimiser<R, S, C> {
impl<R, S> TasoOptimiser<R, S> {
/// Create a new TASO optimiser.
pub fn new(rewriter: R, strategy: S, cost: C) -> Self {
Self {
rewriter,
strategy,
cost,
}
pub fn new(rewriter: R, strategy: S) -> Self {
Self { rewriter, strategy }
}

fn cost(&self, circ: &Hugr) -> S::Cost
where
S: RewriteStrategy,
{
self.strategy.circuit_cost(circ)
}
}

impl<R, S, C> TasoOptimiser<R, S, C>
impl<R, S> TasoOptimiser<R, S>
where
R: Rewriter + Send + Clone + 'static,
S: RewriteStrategy + Send + Clone + 'static,
C: Fn(&Hugr) -> usize + Send + Sync + Clone + 'static,
S: RewriteStrategy + Send + Sync + Clone + 'static,
S::Cost: fmt::Debug + serde::Serialize,
{
/// Run the TASO optimiser on a circuit.
///
Expand Down Expand Up @@ -103,15 +106,19 @@ where
let start_time = Instant::now();

let mut best_circ = circ.clone();
let mut best_circ_cost = (self.cost)(circ);
logger.log_best(best_circ_cost);
let mut best_circ_cost = self.cost(circ);
logger.log_best(&best_circ_cost);

// Hash of seen circuits. Dot not store circuits as this map gets huge
let mut seen_hashes: FxHashSet<_> = FromIterator::from_iter([(circ.circuit_hash())]);

// The priority queue of circuits to be processed (this should not get big)
const PRIORITY_QUEUE_CAPACITY: usize = 10_000;
let mut pq = HugrPQ::with_capacity(&self.cost, PRIORITY_QUEUE_CAPACITY);
let cost_fn = {
let strategy = self.strategy.clone();
move |circ: &'_ Hugr| strategy.circuit_cost(circ)
};
let mut pq = HugrPQ::with_capacity(cost_fn, PRIORITY_QUEUE_CAPACITY);
pq.push(circ.clone());

let mut circ_cnt = 1;
Expand All @@ -120,7 +127,7 @@ where
if cost < best_circ_cost {
best_circ = circ.clone();
best_circ_cost = cost;
logger.log_best(best_circ_cost);
logger.log_best(&best_circ_cost);
}

let rewrites = self.rewriter.get_rewrites(&circ);
Expand Down Expand Up @@ -168,15 +175,19 @@ where
const PRIORITY_QUEUE_CAPACITY: usize = 10_000;

// multi-consumer priority channel for queuing circuits to be processed by the workers
let cost_fn = {
let strategy = self.strategy.clone();
move |circ: &'_ Hugr| strategy.circuit_cost(circ)
};
let (tx_work, rx_work) =
HugrPriorityChannel::init((self.cost).clone(), PRIORITY_QUEUE_CAPACITY * n_threads);
HugrPriorityChannel::init(cost_fn, PRIORITY_QUEUE_CAPACITY * n_threads);
// channel for sending circuits from threads back to main
let (tx_result, rx_result) = crossbeam_channel::unbounded();

let initial_circ_hash = circ.circuit_hash();
let mut best_circ = circ.clone();
let mut best_circ_cost = (self.cost)(&best_circ);
logger.log_best(best_circ_cost);
let mut best_circ_cost = self.cost(&best_circ);
logger.log_best(&best_circ_cost);

// Hash of seen circuits. Dot not store circuits as this map gets huge
let mut seen_hashes: FxHashSet<_> = FromIterator::from_iter([(initial_circ_hash)]);
Expand Down Expand Up @@ -239,13 +250,13 @@ where
}
seen_hashes.insert(*circ_hash);

let cost = (self.cost)(circ);
let cost = self.cost(circ);

// Check if we got a new best circuit
if cost < best_circ_cost {
best_circ = circ.clone();
best_circ_cost = cost;
logger.log_best(best_circ_cost);
logger.log_best(&best_circ_cost);
}
jobs_sent += 1;
}
Expand Down Expand Up @@ -289,49 +300,98 @@ where

#[cfg(feature = "portmatching")]
mod taso_default {
use hugr::ops::OpType;
use hugr::HugrView;
use std::io;
use std::path::Path;

use crate::ops::op_matches;
use hugr::ops::OpType;

use crate::rewrite::ecc_rewriter::RewriterSerialisationError;
use crate::rewrite::strategy::ExhaustiveRewriteStrategy;
use crate::rewrite::strategy::NonIncreasingGateCountStrategy;
use crate::rewrite::ECCRewriter;
use crate::T2Op;

use super::*;

/// The default TASO optimiser using ECC sets.
pub type DefaultTasoOptimiser = TasoOptimiser<
ECCRewriter,
ExhaustiveRewriteStrategy<fn(&OpType) -> bool>,
fn(&Hugr) -> usize,
NonIncreasingGateCountStrategy<fn(&OpType) -> usize, fn(&OpType) -> usize>,
>;

impl DefaultTasoOptimiser {
/// A sane default optimiser using the given ECC sets.
pub fn default_with_eccs_json_file(eccs_path: impl AsRef<Path>) -> io::Result<Self> {
let rewriter = ECCRewriter::try_from_eccs_json_file(eccs_path)?;
let strategy = ExhaustiveRewriteStrategy::exhaustive_cx();
Ok(TasoOptimiser::new(rewriter, strategy, num_cx_gates))
let strategy = NonIncreasingGateCountStrategy::default_cx();
Ok(TasoOptimiser::new(rewriter, strategy))
}

/// A sane default optimiser using a precompiled binary rewriter.
pub fn default_with_rewriter_binary(
rewriter_path: impl AsRef<Path>,
) -> Result<Self, RewriterSerialisationError> {
let rewriter = ECCRewriter::load_binary(rewriter_path)?;
let strategy = ExhaustiveRewriteStrategy::exhaustive_cx();
Ok(TasoOptimiser::new(rewriter, strategy, num_cx_gates))
let strategy = NonIncreasingGateCountStrategy::default_cx();
Ok(TasoOptimiser::new(rewriter, strategy))
}
}

fn num_cx_gates(circ: &Hugr) -> usize {
circ.nodes()
.filter(|&n| op_matches(circ.get_optype(n), T2Op::CX))
.count()
}
}
#[cfg(feature = "portmatching")]
pub use taso_default::DefaultTasoOptimiser;

#[cfg(test)]
#[cfg(feature = "portmatching")]
mod tests {
use hugr::{
builder::{DFGBuilder, Dataflow, DataflowHugr},
extension::prelude::QB_T,
std_extensions::arithmetic::float_types::FLOAT64_TYPE,
types::FunctionType,
Hugr,
};
use rstest::{fixture, rstest};

use crate::{extension::REGISTRY, Circuit, T2Op};

use super::{DefaultTasoOptimiser, TasoOptimiser};

#[fixture]
fn rz_rz() -> Hugr {
let input_t = vec![QB_T, FLOAT64_TYPE, FLOAT64_TYPE];
let output_t = vec![QB_T];
let mut h = DFGBuilder::new(FunctionType::new(input_t, output_t)).unwrap();

let mut inps = h.input_wires();
let qb = inps.next().unwrap();
let f1 = inps.next().unwrap();
let f2 = inps.next().unwrap();

let res = h.add_dataflow_op(T2Op::RzF64, [qb, f1]).unwrap();
let qb = res.outputs().next().unwrap();
let res = h.add_dataflow_op(T2Op::RzF64, [qb, f2]).unwrap();
let qb = res.outputs().next().unwrap();

h.finish_hugr_with_outputs([qb], &REGISTRY).unwrap()
}

#[fixture]
fn taso_opt() -> DefaultTasoOptimiser {
TasoOptimiser::default_with_eccs_json_file("test_files/small_eccs.json").unwrap()
}

#[rstest]
fn rz_rz_cancellation(rz_rz: Hugr, taso_opt: DefaultTasoOptimiser) {
let opt_rz = taso_opt.optimise(&rz_rz, None, 1.try_into().unwrap());
let cmds = opt_rz
.commands()
.map(|cmd| {
(
cmd.optype().try_into().unwrap(),
cmd.inputs().count(),
cmd.outputs().count(),
)
})
.collect::<Vec<(T2Op, _, _)>>();
let exp_cmds = vec![(T2Op::AngleAdd, 2, 1), (T2Op::RzF64, 2, 1)];
assert_eq!(cmds, exp_cmds);
}
}
22 changes: 11 additions & 11 deletions src/optimiser/taso/log.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
//! Logging utilities for the TASO optimiser.

use std::io;
use std::{fmt::Debug, io};

/// Logging configuration for the TASO optimiser.
#[derive(Default)]
Expand Down Expand Up @@ -35,8 +35,8 @@ impl<'w> TasoLogger<'w> {

/// Log a new best candidate
#[inline]
pub fn log_best(&mut self, best_cost: usize) {
self.log(format!("new best of size {}", best_cost));
pub fn log_best<C: Debug + serde::Serialize>(&mut self, best_cost: C) {
self.log(format!("new best of size {:?}", best_cost));
if let Some(csv_writer) = self.circ_candidates_csv.as_mut() {
csv_writer.serialize(BestCircSer::new(best_cost)).unwrap();
csv_writer.flush().unwrap();
Expand All @@ -45,10 +45,10 @@ impl<'w> TasoLogger<'w> {

/// Log the final optimised circuit
#[inline]
pub fn log_processing_end(
pub fn log_processing_end<C: Debug>(
&self,
circuit_count: usize,
best_cost: usize,
best_cost: C,
needs_joining: bool,
timeout: bool,
) {
Expand All @@ -57,7 +57,7 @@ impl<'w> TasoLogger<'w> {
}
self.log("Optimisation finished");
self.log(format!("Tried {circuit_count} circuits"));
self.log(format!("END RESULT: {}", best_cost));
self.log(format!("END RESULT: {:?}", best_cost));
if needs_joining {
self.log("Joining worker threads");
}
Expand Down Expand Up @@ -98,14 +98,14 @@ impl<'w> TasoLogger<'w> {
//
// TODO: Replace this fixed logging. Report back intermediate results.
#[derive(serde::Serialize, Clone, Debug)]
struct BestCircSer {
circ_len: usize,
struct BestCircSer<C> {
circ_cost: C,
time: String,
}

impl BestCircSer {
fn new(circ_len: usize) -> Self {
impl<C> BestCircSer<C> {
fn new(circ_cost: C) -> Self {
let time = chrono::Local::now().to_rfc3339();
Self { circ_len, time }
Self { circ_cost, time }
}
}
6 changes: 3 additions & 3 deletions src/rewrite/ecc_rewriter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,7 @@ mod tests {
let test_file = "test_files/small_eccs.json";
let rewriter = ECCRewriter::try_from_eccs_json_file(test_file).unwrap();
assert_eq!(rewriter.rewrite_rules.len(), rewriter.matcher.n_patterns());
assert_eq!(rewriter.targets.len(), 5 * 4 + 4 * 3);
assert_eq!(rewriter.targets.len(), 5 * 4 + 5 * 3);

// Assert that the rewrite rules are correct, i.e that the rewrite
// rules in the slice (k..=k+t) is given by [[k+1, ..., k+t], [k], ..., [k]]
Expand All @@ -361,8 +361,8 @@ mod tests {
curr_repr = TargetID(i);
}
}
// There should be 4x ECCs of size 3 and 5x ECCs of size 4
let exp_n_eccs_of_len = [0, 4 * 2 + 5 * 3, 4, 5];
// There should be 5x ECCs of size 3 and 5x ECCs of size 4
let exp_n_eccs_of_len = [0, 5 * 2 + 5 * 3, 5, 5];
assert_eq!(n_eccs_of_len, exp_n_eccs_of_len);
}

Expand Down
Loading