From 34f05db865510f5fea69a15de29a3bf6bd50fef0 Mon Sep 17 00:00:00 2001 From: Luca Mondada Date: Tue, 26 Sep 2023 17:22:38 +0200 Subject: [PATCH 1/4] feat[taso]: Use CX count as default cost function --- src/ops.rs | 34 ++++++++++++-- src/optimiser/taso.rs | 38 +++++++++++----- src/rewrite/strategy.rs | 91 +++++++++++++++++++++++++++++++------- taso-optimiser/src/main.rs | 8 ++-- 4 files changed, 135 insertions(+), 36 deletions(-) diff --git a/src/ops.rs b/src/ops.rs index 3d5bd9e8..1c1cb69a 100644 --- a/src/ops.rs +++ b/src/ops.rs @@ -71,6 +71,14 @@ pub enum T2Op { TK1, } +/// Whether an op is a given T2Op. +pub(crate) fn op_matches(op: &OpType, t2op: T2Op) -> bool { + let Ok(op) = T2Op::try_from(op) else { + return false; + }; + op == t2op +} + #[derive(Clone, Copy, Debug, Serialize, Deserialize, EnumIter, Display, PartialEq, PartialOrd)] #[cfg_attr(feature = "pyo3", pyclass)] #[allow(missing_docs)] @@ -294,17 +302,27 @@ impl TryFrom for T2Op { type Error = NotT2Op; fn try_from(op: OpType) -> Result { - let leaf: LeafOp = op.try_into().map_err(|_| NotT2Op)?; + Self::try_from(&op) + } +} + +impl TryFrom<&OpType> for T2Op { + type Error = NotT2Op; + + fn try_from(op: &OpType) -> Result { + let OpType::LeafOp(leaf) = op else { + return Err(NotT2Op); + }; leaf.try_into() } } -impl TryFrom for T2Op { +impl TryFrom<&LeafOp> for T2Op { type Error = NotT2Op; - fn try_from(op: LeafOp) -> Result { + fn try_from(op: &LeafOp) -> Result { match op { - LeafOp::CustomOp(b) => match *b { + LeafOp::CustomOp(b) => match b.as_ref() { ExternalOp::Extension(e) => Self::try_from_op_def(e.def()), ExternalOp::Opaque(o) => from_extension_name(o.extension(), o.name()), }, @@ -313,6 +331,14 @@ impl TryFrom for T2Op { } } +impl TryFrom for T2Op { + type Error = NotT2Op; + + fn try_from(op: LeafOp) -> Result { + Self::try_from(&op) + } +} + /// load all variants of a `SimpleOpEnum` in to an extension as op defs. fn load_all_ops(extension: &mut Extension) -> Result<(), ExtensionBuildError> { for op in T::all_variants() { diff --git a/src/optimiser/taso.rs b/src/optimiser/taso.rs index ad88bc51..75b2dffc 100644 --- a/src/optimiser/taso.rs +++ b/src/optimiser/taso.rs @@ -293,20 +293,36 @@ where #[cfg(feature = "portmatching")] mod taso_default { - use crate::circuit::Circuit; - use crate::rewrite::strategy::ExhaustiveRewriteStrategy; + use hugr::ops::OpType; + use hugr::HugrView; + + use crate::ops::op_matches; + use crate::rewrite::strategy::{exhaustive_cx, ExhaustiveRewriteStrategy}; use crate::rewrite::ECCRewriter; + use crate::T2Op; use super::*; - impl TasoOptimiser usize> { - /// A sane default optimiser using the given ECC sets. - pub fn default_with_eccs_json_file( - eccs_path: impl AsRef, - ) -> io::Result { - let rewriter = ECCRewriter::try_from_eccs_json_file(eccs_path)?; - let strategy = ExhaustiveRewriteStrategy::default(); - Ok(Self::new(rewriter, strategy, |c| c.num_gates())) - } + pub type DefaultTasoOptimiser = TasoOptimiser< + ECCRewriter, + ExhaustiveRewriteStrategy bool>, + fn(&Hugr) -> usize, + >; + + /// A sane default optimiser using the given ECC sets. + pub fn default_with_eccs_json_file( + eccs_path: impl AsRef, + ) -> io::Result { + let rewriter = ECCRewriter::try_from_eccs_json_file(eccs_path)?; + let strategy = exhaustive_cx(); + Ok(TasoOptimiser::new(rewriter, strategy, num_cx_gates)) + } + + fn num_cx_gates(circ: &Hugr) -> usize { + circ.nodes() + .filter(|&n| op_matches(circ.get_optype(n), T2Op::CX)) + .count() } } +#[cfg(feature = "portmatching")] +pub use taso_default::default_with_eccs_json_file; diff --git a/src/rewrite/strategy.rs b/src/rewrite/strategy.rs index 29d5a934..7e963dd3 100644 --- a/src/rewrite/strategy.rs +++ b/src/rewrite/strategy.rs @@ -10,10 +10,10 @@ use std::collections::HashSet; -use hugr::Hugr; +use hugr::{ops::OpType, Hugr, HugrView, Node}; use itertools::Itertools; -use crate::circuit::Circuit; +use crate::{ops::op_matches, T2Op}; use super::CircuitRewrite; @@ -79,27 +79,60 @@ impl RewriteStrategy for GreedyRewriteStrategy { /// circuit. /// /// The parameter gamma controls how greedy the algorithm should be. It allows -/// a rewrite C1 -> C2 if C2 has at most gamma times as many gates as C1: +/// a rewrite C1 -> C2 if C2 has at most gamma times the cost of C1: /// -/// $|C2| < gamma * |C1|$ +/// $cost(C2) < gamma * cost(C1)$ +/// +/// The cost function is given by the number of operations in the circuit that +/// satisfy a given Op predicate. This allows for instance to use the total +/// number of gates (true predicate) or the number of CX gates as cost function. /// /// gamma = 1 is the greedy strategy where a rewrite is only allowed if it -/// strictly reduces the gate count. The default is gamma = 1.0001, as set -/// in the Quartz paper. This essentially allows rewrites that improve or leave -/// the number of nodes unchanged. +/// strictly reduces the gate count. The default is gamma = 1.0001 (as set in +/// the Quartz paper) and the number of CX gates. This essentially allows +/// rewrites that improve or leave the number of CX unchanged. #[derive(Debug, Clone)] -pub struct ExhaustiveRewriteStrategy { +pub struct ExhaustiveRewriteStrategy

{ /// The gamma parameter. pub gamma: f64, + /// Ops to count for cost function. + pub op_predicate: P, } -impl Default for ExhaustiveRewriteStrategy { - fn default() -> Self { - Self { gamma: 1.0001 } +impl

ExhaustiveRewriteStrategy

{ + /// New exhaustive rewrite strategy with provided predicate. + /// + /// The gamma parameter is set to the default 1.0001. + pub fn with_predicate(op_predicate: P) -> Self { + Self { + gamma: 1.0001, + op_predicate, + } + } + + /// New exhaustive rewrite strategy with provided gamma and predicate. + pub fn new(gamma: f64, op_predicate: P) -> Self { + Self { + gamma, + op_predicate, + } } } -impl RewriteStrategy for ExhaustiveRewriteStrategy { +/// Exhaustive rewrite strategy with CX count cost function. +/// +/// The gamma parameter is set to the default 1.0001. This is a good default +/// choice for NISQ-y circuits, where CX gates are the most expensive. +pub fn exhaustive_cx() -> ExhaustiveRewriteStrategy bool> { + ExhaustiveRewriteStrategy::with_predicate(is_cx) +} + +/// Exhaustive rewrite strategy with CX count cost function and provided gamma. +pub fn exhaustive_cx_with_gamma(gamma: f64) -> ExhaustiveRewriteStrategy bool> { + ExhaustiveRewriteStrategy::new(gamma, is_cx) +} + +impl bool> RewriteStrategy for ExhaustiveRewriteStrategy

{ #[tracing::instrument(skip_all)] fn apply_rewrites( &self, @@ -109,8 +142,8 @@ impl RewriteStrategy for ExhaustiveRewriteStrategy { rewrites .into_iter() .filter(|rw| { - let old_count = rw.subcircuit().node_count() as f64; - let new_count = rw.replacement().num_gates() as f64; + let old_count = pre_rewrite_cost(rw, circ, &self.op_predicate) as f64; + let new_count = post_rewrite_cost(rw, circ, &self.op_predicate) as f64; new_count < old_count * self.gamma }) .map(|rw| { @@ -122,6 +155,32 @@ impl RewriteStrategy for ExhaustiveRewriteStrategy { } } +fn is_cx(op: &OpType) -> bool { + op_matches(op, T2Op::CX) +} + +fn cost( + nodes: impl IntoIterator, + circ: &Hugr, + pred: impl Fn(&OpType) -> bool, +) -> usize { + nodes + .into_iter() + .filter(|n| { + let op = circ.get_optype(*n); + pred(op) + }) + .count() +} + +fn pre_rewrite_cost(rw: &CircuitRewrite, circ: &Hugr, pred: impl Fn(&OpType) -> bool) -> usize { + cost(rw.subcircuit().nodes().iter().copied(), circ, pred) +} + +fn post_rewrite_cost(rw: &CircuitRewrite, circ: &Hugr, pred: impl Fn(&OpType) -> bool) -> usize { + cost(rw.replacement().nodes(), circ, pred) +} + #[cfg(test)] mod tests { use super::*; @@ -197,7 +256,7 @@ mod tests { rw_to_empty(&circ, cx_gates[9..10].to_vec()), ]; - let strategy = ExhaustiveRewriteStrategy::default(); + let strategy = exhaustive_cx(); let rewritten = strategy.apply_rewrites(rws, &circ); let exp_circ_lens = HashSet::from_iter([8, 6, 9]); let circ_lens: HashSet<_> = rewritten.iter().map(|c| c.num_gates()).collect(); @@ -219,7 +278,7 @@ mod tests { rw_to_empty(&circ, cx_gates[9..10].to_vec()), ]; - let strategy = ExhaustiveRewriteStrategy { gamma: 10. }; + let strategy = exhaustive_cx_with_gamma(10.); let rewritten = strategy.apply_rewrites(rws, &circ); let exp_circ_lens = HashSet::from_iter([8, 17, 6, 9]); let circ_lens: HashSet<_> = rewritten.iter().map(|c| c.num_gates()).collect(); diff --git a/taso-optimiser/src/main.rs b/taso-optimiser/src/main.rs index 95368318..7dc737c2 100644 --- a/taso-optimiser/src/main.rs +++ b/taso-optimiser/src/main.rs @@ -10,11 +10,9 @@ use std::{fs, path::Path}; use clap::Parser; use hugr::Hugr; +use tket2::json::{load_tk1_json_file, TKETDecode}; +use tket2::optimiser::taso; use tket2::optimiser::taso::log::TasoLogger; -use tket2::{ - json::{load_tk1_json_file, TKETDecode}, - optimiser::TasoOptimiser, -}; use tket_json_rs::circuit_json::SerialCircuit; #[cfg(feature = "peak_alloc")] @@ -111,7 +109,7 @@ fn main() -> Result<(), Box> { let circ = load_tk1_json_file(input_path)?; println!("Compiling rewriter..."); - let Ok(optimiser) = TasoOptimiser::default_with_eccs_json_file(ecc_path) else { + let Ok(optimiser) = taso::default_with_eccs_json_file(ecc_path) else { eprintln!( "Unable to load ECC file {:?}. Is it a JSON file of Quartz-generated ECCs?", ecc_path From 1eebccdd3fd79c967e967c16d9081d3ff049b5df Mon Sep 17 00:00:00 2001 From: Luca Mondada Date: Tue, 26 Sep 2023 19:11:50 +0200 Subject: [PATCH 2/4] Make defaults associated functions --- src/optimiser/taso.rs | 20 ++++++++++---------- src/rewrite/strategy.rs | 26 ++++++++++++++------------ taso-optimiser/src/main.rs | 4 ++-- 3 files changed, 26 insertions(+), 24 deletions(-) diff --git a/src/optimiser/taso.rs b/src/optimiser/taso.rs index 75b2dffc..899783c9 100644 --- a/src/optimiser/taso.rs +++ b/src/optimiser/taso.rs @@ -297,7 +297,7 @@ mod taso_default { use hugr::HugrView; use crate::ops::op_matches; - use crate::rewrite::strategy::{exhaustive_cx, ExhaustiveRewriteStrategy}; + use crate::rewrite::strategy::ExhaustiveRewriteStrategy; use crate::rewrite::ECCRewriter; use crate::T2Op; @@ -309,13 +309,15 @@ mod taso_default { fn(&Hugr) -> usize, >; - /// A sane default optimiser using the given ECC sets. - pub fn default_with_eccs_json_file( - eccs_path: impl AsRef, - ) -> io::Result { - let rewriter = ECCRewriter::try_from_eccs_json_file(eccs_path)?; - let strategy = exhaustive_cx(); - Ok(TasoOptimiser::new(rewriter, strategy, num_cx_gates)) + impl DefaultTasoOptimiser { + /// A sane default optimiser using the given ECC sets. + pub fn default_with_eccs_json_file( + eccs_path: impl AsRef, + ) -> io::Result { + let rewriter = ECCRewriter::try_from_eccs_json_file(eccs_path)?; + let strategy = ExhaustiveRewriteStrategy::exhaustive_cx(); + Ok(TasoOptimiser::new(rewriter, strategy, num_cx_gates)) + } } fn num_cx_gates(circ: &Hugr) -> usize { @@ -324,5 +326,3 @@ mod taso_default { .count() } } -#[cfg(feature = "portmatching")] -pub use taso_default::default_with_eccs_json_file; diff --git a/src/rewrite/strategy.rs b/src/rewrite/strategy.rs index 7e963dd3..3df7ceeb 100644 --- a/src/rewrite/strategy.rs +++ b/src/rewrite/strategy.rs @@ -119,17 +119,19 @@ impl

ExhaustiveRewriteStrategy

{ } } -/// Exhaustive rewrite strategy with CX count cost function. -/// -/// The gamma parameter is set to the default 1.0001. This is a good default -/// choice for NISQ-y circuits, where CX gates are the most expensive. -pub fn exhaustive_cx() -> ExhaustiveRewriteStrategy bool> { - ExhaustiveRewriteStrategy::with_predicate(is_cx) -} +impl ExhaustiveRewriteStrategy bool> { + /// Exhaustive rewrite strategy with CX count cost function. + /// + /// The gamma parameter is set to the default 1.0001. This is a good default + /// choice for NISQ-y circuits, where CX gates are the most expensive. + pub fn exhaustive_cx() -> Self { + ExhaustiveRewriteStrategy::with_predicate(is_cx) + } -/// Exhaustive rewrite strategy with CX count cost function and provided gamma. -pub fn exhaustive_cx_with_gamma(gamma: f64) -> ExhaustiveRewriteStrategy bool> { - ExhaustiveRewriteStrategy::new(gamma, is_cx) + /// Exhaustive rewrite strategy with CX count cost function and provided gamma. + pub fn exhaustive_cx_with_gamma(gamma: f64) -> Self { + ExhaustiveRewriteStrategy::new(gamma, is_cx) + } } impl bool> RewriteStrategy for ExhaustiveRewriteStrategy

{ @@ -256,7 +258,7 @@ mod tests { rw_to_empty(&circ, cx_gates[9..10].to_vec()), ]; - let strategy = exhaustive_cx(); + let strategy = ExhaustiveRewriteStrategy::exhaustive_cx(); let rewritten = strategy.apply_rewrites(rws, &circ); let exp_circ_lens = HashSet::from_iter([8, 6, 9]); let circ_lens: HashSet<_> = rewritten.iter().map(|c| c.num_gates()).collect(); @@ -278,7 +280,7 @@ mod tests { rw_to_empty(&circ, cx_gates[9..10].to_vec()), ]; - let strategy = exhaustive_cx_with_gamma(10.); + let strategy = ExhaustiveRewriteStrategy::exhaustive_cx_with_gamma(10.); let rewritten = strategy.apply_rewrites(rws, &circ); let exp_circ_lens = HashSet::from_iter([8, 17, 6, 9]); let circ_lens: HashSet<_> = rewritten.iter().map(|c| c.num_gates()).collect(); diff --git a/taso-optimiser/src/main.rs b/taso-optimiser/src/main.rs index 7dc737c2..af8b1a5d 100644 --- a/taso-optimiser/src/main.rs +++ b/taso-optimiser/src/main.rs @@ -11,8 +11,8 @@ use std::{fs, path::Path}; use clap::Parser; use hugr::Hugr; use tket2::json::{load_tk1_json_file, TKETDecode}; -use tket2::optimiser::taso; use tket2::optimiser::taso::log::TasoLogger; +use tket2::optimiser::TasoOptimiser; use tket_json_rs::circuit_json::SerialCircuit; #[cfg(feature = "peak_alloc")] @@ -109,7 +109,7 @@ fn main() -> Result<(), Box> { let circ = load_tk1_json_file(input_path)?; println!("Compiling rewriter..."); - let Ok(optimiser) = taso::default_with_eccs_json_file(ecc_path) else { + let Ok(optimiser) = TasoOptimiser::default_with_eccs_json_file(ecc_path) else { eprintln!( "Unable to load ECC file {:?}. Is it a JSON file of Quartz-generated ECCs?", ecc_path From 27c48d494317f25fe3e06b555035ab6fe9dae1e9 Mon Sep 17 00:00:00 2001 From: Luca Mondada Date: Tue, 26 Sep 2023 19:19:23 +0200 Subject: [PATCH 3/4] Expose DefaultTasoOptimiser --- src/optimiser.rs | 2 +- src/optimiser/taso.rs | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/src/optimiser.rs b/src/optimiser.rs index 711caab6..2c86ee0f 100644 --- a/src/optimiser.rs +++ b/src/optimiser.rs @@ -3,4 +3,4 @@ //! Currently, the only optimiser is TASO pub mod taso; -pub use taso::TasoOptimiser; +pub use taso::{DefaultTasoOptimiser, TasoOptimiser}; diff --git a/src/optimiser/taso.rs b/src/optimiser/taso.rs index 899783c9..2c7f4982 100644 --- a/src/optimiser/taso.rs +++ b/src/optimiser/taso.rs @@ -303,6 +303,7 @@ mod taso_default { use super::*; + /// The default TASO optimiser using ECC sets. pub type DefaultTasoOptimiser = TasoOptimiser< ECCRewriter, ExhaustiveRewriteStrategy bool>, @@ -326,3 +327,5 @@ mod taso_default { .count() } } +#[cfg(feature = "portmatching")] +pub use taso_default::DefaultTasoOptimiser; From dfcf787cf24d061845556177fd52da53cfca9fec Mon Sep 17 00:00:00 2001 From: Luca Mondada Date: Tue, 26 Sep 2023 19:24:29 +0200 Subject: [PATCH 4/4] Expose DefaultTasoOptimiser when portmatching is enabled --- src/optimiser.rs | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/optimiser.rs b/src/optimiser.rs index 2c86ee0f..0760d0d3 100644 --- a/src/optimiser.rs +++ b/src/optimiser.rs @@ -3,4 +3,7 @@ //! Currently, the only optimiser is TASO pub mod taso; -pub use taso::{DefaultTasoOptimiser, TasoOptimiser}; + +#[cfg(feature = "portmatching")] +pub use taso::DefaultTasoOptimiser; +pub use taso::TasoOptimiser;