From 969d46329a85980efceae77c9416a99ad2908d54 Mon Sep 17 00:00:00 2001 From: Luca Mondada <72734770+lmondada@users.noreply.github.com> Date: Tue, 26 Sep 2023 19:27:42 +0200 Subject: [PATCH] feat!: Use CX count as default cost function (#134) --- src/ops.rs | 34 ++++++++++++-- src/optimiser.rs | 3 ++ src/optimiser/taso.rs | 27 +++++++++-- src/rewrite/strategy.rs | 93 +++++++++++++++++++++++++++++++------- taso-optimiser/src/main.rs | 6 +-- 5 files changed, 135 insertions(+), 28 deletions(-) diff --git a/src/ops.rs b/src/ops.rs index 3d5bd9e8..1c1cb69a 100644 --- a/src/ops.rs +++ b/src/ops.rs @@ -71,6 +71,14 @@ pub enum T2Op { TK1, } +/// Whether an op is a given T2Op. +pub(crate) fn op_matches(op: &OpType, t2op: T2Op) -> bool { + let Ok(op) = T2Op::try_from(op) else { + return false; + }; + op == t2op +} + #[derive(Clone, Copy, Debug, Serialize, Deserialize, EnumIter, Display, PartialEq, PartialOrd)] #[cfg_attr(feature = "pyo3", pyclass)] #[allow(missing_docs)] @@ -294,17 +302,27 @@ impl TryFrom for T2Op { type Error = NotT2Op; fn try_from(op: OpType) -> Result { - let leaf: LeafOp = op.try_into().map_err(|_| NotT2Op)?; + Self::try_from(&op) + } +} + +impl TryFrom<&OpType> for T2Op { + type Error = NotT2Op; + + fn try_from(op: &OpType) -> Result { + let OpType::LeafOp(leaf) = op else { + return Err(NotT2Op); + }; leaf.try_into() } } -impl TryFrom for T2Op { +impl TryFrom<&LeafOp> for T2Op { type Error = NotT2Op; - fn try_from(op: LeafOp) -> Result { + fn try_from(op: &LeafOp) -> Result { match op { - LeafOp::CustomOp(b) => match *b { + LeafOp::CustomOp(b) => match b.as_ref() { ExternalOp::Extension(e) => Self::try_from_op_def(e.def()), ExternalOp::Opaque(o) => from_extension_name(o.extension(), o.name()), }, @@ -313,6 +331,14 @@ impl TryFrom for T2Op { } } +impl TryFrom for T2Op { + type Error = NotT2Op; + + fn try_from(op: LeafOp) -> Result { + Self::try_from(&op) + } +} + /// load all variants of a `SimpleOpEnum` in to an extension as op defs. fn load_all_ops(extension: &mut Extension) -> Result<(), ExtensionBuildError> { for op in T::all_variants() { diff --git a/src/optimiser.rs b/src/optimiser.rs index 711caab6..0760d0d3 100644 --- a/src/optimiser.rs +++ b/src/optimiser.rs @@ -3,4 +3,7 @@ //! Currently, the only optimiser is TASO pub mod taso; + +#[cfg(feature = "portmatching")] +pub use taso::DefaultTasoOptimiser; pub use taso::TasoOptimiser; diff --git a/src/optimiser/taso.rs b/src/optimiser/taso.rs index ad88bc51..2c7f4982 100644 --- a/src/optimiser/taso.rs +++ b/src/optimiser/taso.rs @@ -293,20 +293,39 @@ where #[cfg(feature = "portmatching")] mod taso_default { - use crate::circuit::Circuit; + use hugr::ops::OpType; + use hugr::HugrView; + + use crate::ops::op_matches; use crate::rewrite::strategy::ExhaustiveRewriteStrategy; use crate::rewrite::ECCRewriter; + use crate::T2Op; use super::*; - impl TasoOptimiser usize> { + /// The default TASO optimiser using ECC sets. + pub type DefaultTasoOptimiser = TasoOptimiser< + ECCRewriter, + ExhaustiveRewriteStrategy bool>, + fn(&Hugr) -> usize, + >; + + impl DefaultTasoOptimiser { /// A sane default optimiser using the given ECC sets. pub fn default_with_eccs_json_file( eccs_path: impl AsRef, ) -> io::Result { let rewriter = ECCRewriter::try_from_eccs_json_file(eccs_path)?; - let strategy = ExhaustiveRewriteStrategy::default(); - Ok(Self::new(rewriter, strategy, |c| c.num_gates())) + let strategy = ExhaustiveRewriteStrategy::exhaustive_cx(); + Ok(TasoOptimiser::new(rewriter, strategy, num_cx_gates)) } } + + fn num_cx_gates(circ: &Hugr) -> usize { + circ.nodes() + .filter(|&n| op_matches(circ.get_optype(n), T2Op::CX)) + .count() + } } +#[cfg(feature = "portmatching")] +pub use taso_default::DefaultTasoOptimiser; diff --git a/src/rewrite/strategy.rs b/src/rewrite/strategy.rs index 29d5a934..3df7ceeb 100644 --- a/src/rewrite/strategy.rs +++ b/src/rewrite/strategy.rs @@ -10,10 +10,10 @@ use std::collections::HashSet; -use hugr::Hugr; +use hugr::{ops::OpType, Hugr, HugrView, Node}; use itertools::Itertools; -use crate::circuit::Circuit; +use crate::{ops::op_matches, T2Op}; use super::CircuitRewrite; @@ -79,27 +79,62 @@ impl RewriteStrategy for GreedyRewriteStrategy { /// circuit. /// /// The parameter gamma controls how greedy the algorithm should be. It allows -/// a rewrite C1 -> C2 if C2 has at most gamma times as many gates as C1: +/// a rewrite C1 -> C2 if C2 has at most gamma times the cost of C1: /// -/// $|C2| < gamma * |C1|$ +/// $cost(C2) < gamma * cost(C1)$ +/// +/// The cost function is given by the number of operations in the circuit that +/// satisfy a given Op predicate. This allows for instance to use the total +/// number of gates (true predicate) or the number of CX gates as cost function. /// /// gamma = 1 is the greedy strategy where a rewrite is only allowed if it -/// strictly reduces the gate count. The default is gamma = 1.0001, as set -/// in the Quartz paper. This essentially allows rewrites that improve or leave -/// the number of nodes unchanged. +/// strictly reduces the gate count. The default is gamma = 1.0001 (as set in +/// the Quartz paper) and the number of CX gates. This essentially allows +/// rewrites that improve or leave the number of CX unchanged. #[derive(Debug, Clone)] -pub struct ExhaustiveRewriteStrategy { +pub struct ExhaustiveRewriteStrategy

{ /// The gamma parameter. pub gamma: f64, + /// Ops to count for cost function. + pub op_predicate: P, +} + +impl

ExhaustiveRewriteStrategy

{ + /// New exhaustive rewrite strategy with provided predicate. + /// + /// The gamma parameter is set to the default 1.0001. + pub fn with_predicate(op_predicate: P) -> Self { + Self { + gamma: 1.0001, + op_predicate, + } + } + + /// New exhaustive rewrite strategy with provided gamma and predicate. + pub fn new(gamma: f64, op_predicate: P) -> Self { + Self { + gamma, + op_predicate, + } + } } -impl Default for ExhaustiveRewriteStrategy { - fn default() -> Self { - Self { gamma: 1.0001 } +impl ExhaustiveRewriteStrategy bool> { + /// Exhaustive rewrite strategy with CX count cost function. + /// + /// The gamma parameter is set to the default 1.0001. This is a good default + /// choice for NISQ-y circuits, where CX gates are the most expensive. + pub fn exhaustive_cx() -> Self { + ExhaustiveRewriteStrategy::with_predicate(is_cx) + } + + /// Exhaustive rewrite strategy with CX count cost function and provided gamma. + pub fn exhaustive_cx_with_gamma(gamma: f64) -> Self { + ExhaustiveRewriteStrategy::new(gamma, is_cx) } } -impl RewriteStrategy for ExhaustiveRewriteStrategy { +impl bool> RewriteStrategy for ExhaustiveRewriteStrategy

{ #[tracing::instrument(skip_all)] fn apply_rewrites( &self, @@ -109,8 +144,8 @@ impl RewriteStrategy for ExhaustiveRewriteStrategy { rewrites .into_iter() .filter(|rw| { - let old_count = rw.subcircuit().node_count() as f64; - let new_count = rw.replacement().num_gates() as f64; + let old_count = pre_rewrite_cost(rw, circ, &self.op_predicate) as f64; + let new_count = post_rewrite_cost(rw, circ, &self.op_predicate) as f64; new_count < old_count * self.gamma }) .map(|rw| { @@ -122,6 +157,32 @@ impl RewriteStrategy for ExhaustiveRewriteStrategy { } } +fn is_cx(op: &OpType) -> bool { + op_matches(op, T2Op::CX) +} + +fn cost( + nodes: impl IntoIterator, + circ: &Hugr, + pred: impl Fn(&OpType) -> bool, +) -> usize { + nodes + .into_iter() + .filter(|n| { + let op = circ.get_optype(*n); + pred(op) + }) + .count() +} + +fn pre_rewrite_cost(rw: &CircuitRewrite, circ: &Hugr, pred: impl Fn(&OpType) -> bool) -> usize { + cost(rw.subcircuit().nodes().iter().copied(), circ, pred) +} + +fn post_rewrite_cost(rw: &CircuitRewrite, circ: &Hugr, pred: impl Fn(&OpType) -> bool) -> usize { + cost(rw.replacement().nodes(), circ, pred) +} + #[cfg(test)] mod tests { use super::*; @@ -197,7 +258,7 @@ mod tests { rw_to_empty(&circ, cx_gates[9..10].to_vec()), ]; - let strategy = ExhaustiveRewriteStrategy::default(); + let strategy = ExhaustiveRewriteStrategy::exhaustive_cx(); let rewritten = strategy.apply_rewrites(rws, &circ); let exp_circ_lens = HashSet::from_iter([8, 6, 9]); let circ_lens: HashSet<_> = rewritten.iter().map(|c| c.num_gates()).collect(); @@ -219,7 +280,7 @@ mod tests { rw_to_empty(&circ, cx_gates[9..10].to_vec()), ]; - let strategy = ExhaustiveRewriteStrategy { gamma: 10. }; + let strategy = ExhaustiveRewriteStrategy::exhaustive_cx_with_gamma(10.); let rewritten = strategy.apply_rewrites(rws, &circ); let exp_circ_lens = HashSet::from_iter([8, 17, 6, 9]); let circ_lens: HashSet<_> = rewritten.iter().map(|c| c.num_gates()).collect(); diff --git a/taso-optimiser/src/main.rs b/taso-optimiser/src/main.rs index 95368318..af8b1a5d 100644 --- a/taso-optimiser/src/main.rs +++ b/taso-optimiser/src/main.rs @@ -10,11 +10,9 @@ use std::{fs, path::Path}; use clap::Parser; use hugr::Hugr; +use tket2::json::{load_tk1_json_file, TKETDecode}; use tket2::optimiser::taso::log::TasoLogger; -use tket2::{ - json::{load_tk1_json_file, TKETDecode}, - optimiser::TasoOptimiser, -}; +use tket2::optimiser::TasoOptimiser; use tket_json_rs::circuit_json::SerialCircuit; #[cfg(feature = "peak_alloc")]