From 1bbb94b4a2c4e2c0d526dbda84d57cb266fdc28d Mon Sep 17 00:00:00 2001 From: Luca Mondada Date: Tue, 3 Oct 2023 09:38:16 +0200 Subject: [PATCH 1/5] feat: Improve TASO cost function and rewrite strategies --- src/ops.rs | 10 ++ src/optimiser/taso.rs | 139 +++++++++++++------ src/optimiser/taso/log.rs | 22 +-- src/rewrite/ecc_rewriter.rs | 6 +- src/rewrite/strategy.rs | 266 ++++++++++++++++++++++++++++++------ test_files/small_eccs.json | 5 + 6 files changed, 354 insertions(+), 94 deletions(-) diff --git a/src/ops.rs b/src/ops.rs index 1c1cb69a..1ac5c230 100644 --- a/src/ops.rs +++ b/src/ops.rs @@ -200,6 +200,16 @@ impl T2Op { _ => vec![], } } + + /// Check if this op is a quantum op. + pub fn is_quantum(&self) -> bool { + use T2Op::*; + match self { + H | CX | T | S | X | Y | Z | Tdg | Sdg | ZZMax | RzF64 | RxF64 | PhasedX | ZZPhase + | CZ | TK1 => true, + AngleAdd | Measure => false, + } + } } /// Initialize a new custom symbolic expression constant op from a string. diff --git a/src/optimiser/taso.rs b/src/optimiser/taso.rs index a231aaf8..7208d841 100644 --- a/src/optimiser/taso.rs +++ b/src/optimiser/taso.rs @@ -22,11 +22,12 @@ use crossbeam_channel::select; pub use eq_circ_class::{load_eccs_json_file, EqCircClass}; pub use log::TasoLogger; +use std::fmt; use std::num::NonZeroUsize; use std::time::{Duration, Instant}; use fxhash::FxHashSet; -use hugr::Hugr; +use hugr::{Hugr, HugrView}; use crate::circuit::CircuitHash; use crate::optimiser::taso::hugr_pchannel::HugrPriorityChannel; @@ -52,28 +53,30 @@ use crate::rewrite::Rewriter; /// [Quartz]: https://arxiv.org/abs/2204.09033 /// [TASO]: https://dl.acm.org/doi/10.1145/3341301.3359630 #[derive(Clone, Debug)] -pub struct TasoOptimiser { +pub struct TasoOptimiser { rewriter: R, strategy: S, - cost: C, } -impl TasoOptimiser { +impl TasoOptimiser { /// Create a new TASO optimiser. - pub fn new(rewriter: R, strategy: S, cost: C) -> Self { - Self { - rewriter, - strategy, - cost, - } + pub fn new(rewriter: R, strategy: S) -> Self { + Self { rewriter, strategy } + } + + fn cost(&self, circ: &Hugr) -> S::Cost + where + S: RewriteStrategy, + { + self.strategy.circuit_cost(circ) } } -impl TasoOptimiser +impl TasoOptimiser where R: Rewriter + Send + Clone + 'static, - S: RewriteStrategy + Send + Clone + 'static, - C: Fn(&Hugr) -> usize + Send + Sync + Clone + 'static, + S: RewriteStrategy + Send + Sync + Clone + 'static, + S::Cost: fmt::Debug + serde::Serialize, { /// Run the TASO optimiser on a circuit. /// @@ -103,15 +106,19 @@ where let start_time = Instant::now(); let mut best_circ = circ.clone(); - let mut best_circ_cost = (self.cost)(circ); - logger.log_best(best_circ_cost); + let mut best_circ_cost = self.cost(circ); + logger.log_best(&best_circ_cost); // Hash of seen circuits. Dot not store circuits as this map gets huge let mut seen_hashes: FxHashSet<_> = FromIterator::from_iter([(circ.circuit_hash())]); // The priority queue of circuits to be processed (this should not get big) const PRIORITY_QUEUE_CAPACITY: usize = 10_000; - let mut pq = HugrPQ::with_capacity(&self.cost, PRIORITY_QUEUE_CAPACITY); + let cost_fn = { + let strategy = self.strategy.clone(); + move |circ: &'_ Hugr| strategy.circuit_cost(circ) + }; + let mut pq = HugrPQ::with_capacity(cost_fn, PRIORITY_QUEUE_CAPACITY); pq.push(circ.clone()); let mut circ_cnt = 1; @@ -120,7 +127,7 @@ where if cost < best_circ_cost { best_circ = circ.clone(); best_circ_cost = cost; - logger.log_best(best_circ_cost); + logger.log_best(&best_circ_cost); } let rewrites = self.rewriter.get_rewrites(&circ); @@ -168,15 +175,19 @@ where const PRIORITY_QUEUE_CAPACITY: usize = 10_000; // multi-consumer priority channel for queuing circuits to be processed by the workers + let cost_fn = { + let strategy = self.strategy.clone(); + move |circ: &'_ Hugr| strategy.circuit_cost(circ) + }; let (tx_work, rx_work) = - HugrPriorityChannel::init((self.cost).clone(), PRIORITY_QUEUE_CAPACITY * n_threads); + HugrPriorityChannel::init(cost_fn, PRIORITY_QUEUE_CAPACITY * n_threads); // channel for sending circuits from threads back to main let (tx_result, rx_result) = crossbeam_channel::unbounded(); let initial_circ_hash = circ.circuit_hash(); let mut best_circ = circ.clone(); - let mut best_circ_cost = (self.cost)(&best_circ); - logger.log_best(best_circ_cost); + let mut best_circ_cost = self.cost(&best_circ); + logger.log_best(&best_circ_cost); // Hash of seen circuits. Dot not store circuits as this map gets huge let mut seen_hashes: FxHashSet<_> = FromIterator::from_iter([(initial_circ_hash)]); @@ -239,13 +250,13 @@ where } seen_hashes.insert(*circ_hash); - let cost = (self.cost)(circ); + let cost = self.cost(circ); // Check if we got a new best circuit if cost < best_circ_cost { best_circ = circ.clone(); best_circ_cost = cost; - logger.log_best(best_circ_cost); + logger.log_best(&best_circ_cost); } jobs_sent += 1; } @@ -289,32 +300,24 @@ where #[cfg(feature = "portmatching")] mod taso_default { - use hugr::ops::OpType; - use hugr::HugrView; use std::io; use std::path::Path; - use crate::ops::op_matches; use crate::rewrite::ecc_rewriter::RewriterSerialisationError; - use crate::rewrite::strategy::ExhaustiveRewriteStrategy; + use crate::rewrite::strategy::NonIncreasingCXCountStrategy; use crate::rewrite::ECCRewriter; - use crate::T2Op; use super::*; /// The default TASO optimiser using ECC sets. - pub type DefaultTasoOptimiser = TasoOptimiser< - ECCRewriter, - ExhaustiveRewriteStrategy bool>, - fn(&Hugr) -> usize, - >; + pub type DefaultTasoOptimiser = TasoOptimiser; impl DefaultTasoOptimiser { /// A sane default optimiser using the given ECC sets. pub fn default_with_eccs_json_file(eccs_path: impl AsRef) -> io::Result { let rewriter = ECCRewriter::try_from_eccs_json_file(eccs_path)?; - let strategy = ExhaustiveRewriteStrategy::exhaustive_cx(); - Ok(TasoOptimiser::new(rewriter, strategy, num_cx_gates)) + let strategy = NonIncreasingCXCountStrategy::default_cx(); + Ok(TasoOptimiser::new(rewriter, strategy)) } /// A sane default optimiser using a precompiled binary rewriter. @@ -322,16 +325,68 @@ mod taso_default { rewriter_path: impl AsRef, ) -> Result { let rewriter = ECCRewriter::load_binary(rewriter_path)?; - let strategy = ExhaustiveRewriteStrategy::exhaustive_cx(); - Ok(TasoOptimiser::new(rewriter, strategy, num_cx_gates)) + let strategy = NonIncreasingCXCountStrategy::default_cx(); + Ok(TasoOptimiser::new(rewriter, strategy)) } } - - fn num_cx_gates(circ: &Hugr) -> usize { - circ.nodes() - .filter(|&n| op_matches(circ.get_optype(n), T2Op::CX)) - .count() - } } #[cfg(feature = "portmatching")] pub use taso_default::DefaultTasoOptimiser; + +#[cfg(test)] +#[cfg(feature = "portmatching")] +mod tests { + use hugr::{ + builder::{DFGBuilder, Dataflow, DataflowHugr}, + extension::prelude::QB_T, + std_extensions::arithmetic::float_types::FLOAT64_TYPE, + types::FunctionType, + Hugr, + }; + use rstest::{fixture, rstest}; + + use crate::{extension::REGISTRY, Circuit, T2Op}; + + use super::{DefaultTasoOptimiser, TasoOptimiser}; + + #[fixture] + fn rz_rz() -> Hugr { + let input_t = vec![QB_T, FLOAT64_TYPE, FLOAT64_TYPE]; + let output_t = vec![QB_T]; + let mut h = DFGBuilder::new(FunctionType::new(input_t, output_t)).unwrap(); + + let mut inps = h.input_wires(); + let qb = inps.next().unwrap(); + let f1 = inps.next().unwrap(); + let f2 = inps.next().unwrap(); + + let res = h.add_dataflow_op(T2Op::RzF64, [qb, f1]).unwrap(); + let qb = res.outputs().next().unwrap(); + let res = h.add_dataflow_op(T2Op::RzF64, [qb, f2]).unwrap(); + let qb = res.outputs().next().unwrap(); + + h.finish_hugr_with_outputs([qb], ®ISTRY).unwrap() + } + + #[fixture] + fn taso_opt() -> DefaultTasoOptimiser { + TasoOptimiser::default_with_eccs_json_file("test_files/small_eccs.json").unwrap() + } + + #[rstest] + fn rz_rz_cancellation(rz_rz: Hugr, taso_opt: DefaultTasoOptimiser) { + let opt_rz = taso_opt.optimise(&rz_rz, None, 1.try_into().unwrap()); + let cmds = opt_rz + .commands() + .map(|cmd| { + ( + cmd.optype().try_into().unwrap(), + cmd.inputs().count(), + cmd.outputs().count(), + ) + }) + .collect::>(); + let exp_cmds = vec![(T2Op::AngleAdd, 2, 1), (T2Op::RzF64, 2, 1)]; + assert_eq!(cmds, exp_cmds); + } +} diff --git a/src/optimiser/taso/log.rs b/src/optimiser/taso/log.rs index 39969ee6..9a44f717 100644 --- a/src/optimiser/taso/log.rs +++ b/src/optimiser/taso/log.rs @@ -1,6 +1,6 @@ //! Logging utilities for the TASO optimiser. -use std::io; +use std::{fmt::Debug, io}; /// Logging configuration for the TASO optimiser. #[derive(Default)] @@ -35,8 +35,8 @@ impl<'w> TasoLogger<'w> { /// Log a new best candidate #[inline] - pub fn log_best(&mut self, best_cost: usize) { - self.log(format!("new best of size {}", best_cost)); + pub fn log_best(&mut self, best_cost: C) { + self.log(format!("new best of size {:?}", best_cost)); if let Some(csv_writer) = self.circ_candidates_csv.as_mut() { csv_writer.serialize(BestCircSer::new(best_cost)).unwrap(); csv_writer.flush().unwrap(); @@ -45,10 +45,10 @@ impl<'w> TasoLogger<'w> { /// Log the final optimised circuit #[inline] - pub fn log_processing_end( + pub fn log_processing_end( &self, circuit_count: usize, - best_cost: usize, + best_cost: C, needs_joining: bool, timeout: bool, ) { @@ -57,7 +57,7 @@ impl<'w> TasoLogger<'w> { } self.log("Optimisation finished"); self.log(format!("Tried {circuit_count} circuits")); - self.log(format!("END RESULT: {}", best_cost)); + self.log(format!("END RESULT: {:?}", best_cost)); if needs_joining { self.log("Joining worker threads"); } @@ -98,14 +98,14 @@ impl<'w> TasoLogger<'w> { // // TODO: Replace this fixed logging. Report back intermediate results. #[derive(serde::Serialize, Clone, Debug)] -struct BestCircSer { - circ_len: usize, +struct BestCircSer { + circ_cost: C, time: String, } -impl BestCircSer { - fn new(circ_len: usize) -> Self { +impl BestCircSer { + fn new(circ_cost: C) -> Self { let time = chrono::Local::now().to_rfc3339(); - Self { circ_len, time } + Self { circ_cost, time } } } diff --git a/src/rewrite/ecc_rewriter.rs b/src/rewrite/ecc_rewriter.rs index 8f0a5a8d..b5fd7a47 100644 --- a/src/rewrite/ecc_rewriter.rs +++ b/src/rewrite/ecc_rewriter.rs @@ -278,7 +278,7 @@ mod tests { let test_file = "test_files/small_eccs.json"; let rewriter = ECCRewriter::try_from_eccs_json_file(test_file).unwrap(); assert_eq!(rewriter.rewrite_rules.len(), rewriter.matcher.n_patterns()); - assert_eq!(rewriter.targets.len(), 5 * 4 + 4 * 3); + assert_eq!(rewriter.targets.len(), 5 * 4 + 5 * 3); // Assert that the rewrite rules are correct, i.e that the rewrite // rules in the slice (k..=k+t) is given by [[k+1, ..., k+t], [k], ..., [k]] @@ -301,8 +301,8 @@ mod tests { curr_repr = TargetID(i); } } - // There should be 4x ECCs of size 3 and 5x ECCs of size 4 - let exp_n_eccs_of_len = [0, 4 * 2 + 5 * 3, 4, 5]; + // There should be 5x ECCs of size 3 and 5x ECCs of size 4 + let exp_n_eccs_of_len = [0, 5 * 2 + 5 * 3, 5, 5]; assert_eq!(n_eccs_of_len, exp_n_eccs_of_len); } } diff --git a/src/rewrite/strategy.rs b/src/rewrite/strategy.rs index 3df7ceeb..fe5c51cb 100644 --- a/src/rewrite/strategy.rs +++ b/src/rewrite/strategy.rs @@ -8,12 +8,13 @@ //! times as there are possible rewrites and applies a different rewrite //! to every circuit. -use std::collections::HashSet; +use std::{collections::HashSet, fmt::Debug, iter::Sum}; +use derive_more::From; use hugr::{ops::OpType, Hugr, HugrView, Node}; use itertools::Itertools; -use crate::{ops::op_matches, T2Op}; +use crate::{ops::op_matches, Circuit, T2Op}; use super::CircuitRewrite; @@ -23,13 +24,22 @@ use super::CircuitRewrite; /// to a circuit according to a strategy. It returns a list of new circuits, /// each obtained by applying one or several non-overlapping rewrites to the /// original circuit. +/// +/// It also assign every circuit a totally ordered cost that can be used when +/// using rewrites for circuit optimisation. pub trait RewriteStrategy { + /// The circuit cost to be minised. + type Cost: Ord; + /// Apply a set of rewrites to a circuit. fn apply_rewrites( &self, rewrites: impl IntoIterator, circ: &Hugr, ) -> Vec; + + /// The cost of a circuit. + fn circuit_cost(&self, circ: &Hugr) -> Self::Cost; } /// A rewrite strategy applying as many non-overlapping rewrites as possible. @@ -45,6 +55,8 @@ pub trait RewriteStrategy { pub struct GreedyRewriteStrategy; impl RewriteStrategy for GreedyRewriteStrategy { + type Cost = usize; + #[tracing::instrument(skip_all)] fn apply_rewrites( &self, @@ -73,68 +85,168 @@ impl RewriteStrategy for GreedyRewriteStrategy { } vec![circ] } + + fn circuit_cost(&self, circ: &Hugr) -> Self::Cost { + circ.num_gates() + } } -/// A rewrite strategy that explores applying each rewrite to copies of the -/// circuit. +/// Exhaustive rewrite strategy allowing smaller or equal cost rewrites. +/// +/// Rewrites are permitted based on a cost function called the major cost: if +/// the major cost of the target of the rewrite is smaller or equal to the major +/// cost of the pattern, the rewrite is allowed. +/// +/// A second cost function, the minor cost, is used as a tie breaker: within +/// circuits with the same major cost, the circuit ordering prioritises circuits +/// with a smaller minor cost. +/// +/// An example would be to use the number of CX gates as major cost and the +/// total number of gates as minor cost. Compared to a [`ExhaustiveGammaStrategy`], +/// that would only order circuits based on the number of CX gates, this creates +/// a less flat optimisation landscape. +#[derive(Debug, Clone)] +pub struct NonIncreasingGateCountStrategy { + major_cost: C1, + minor_cost: C2, +} + +impl ExhaustiveThresholdStrategy for NonIncreasingGateCountStrategy +where + C1: Fn(&OpType) -> usize, + C2: Fn(&OpType) -> usize, +{ + type OpCost = MajorMinorCost; + type SumOpCost = MajorMinorCost; + + fn threshold(&self, pattern_cost: &Self::SumOpCost, target_cost: &Self::SumOpCost) -> bool { + target_cost.major <= pattern_cost.major + } + + fn op_cost(&self, op: &OpType) -> Self::OpCost { + ((self.major_cost)(op), (self.minor_cost)(op)).into() + } +} + +/// Non-increasing rewrite strategy based on CX count. +/// +/// The minor cost to break ties between equal CX counts is the number of +/// quantum gates. +pub type NonIncreasingCXCountStrategy = + NonIncreasingGateCountStrategy usize, fn(&OpType) -> usize>; + +impl NonIncreasingCXCountStrategy { + /// Create rewrite strategy based on non-increasing CX count. + pub fn default_cx() -> Self { + Self { + major_cost: |op| is_cx(op) as usize, + minor_cost: |op| is_quantum(op) as usize, + } + } +} + +/// Exhaustive rewrite strategy allowing rewrites with bounded cost increase. /// /// The parameter gamma controls how greedy the algorithm should be. It allows /// a rewrite C1 -> C2 if C2 has at most gamma times the cost of C1: /// /// $cost(C2) < gamma * cost(C1)$ /// -/// The cost function is given by the number of operations in the circuit that -/// satisfy a given Op predicate. This allows for instance to use the total -/// number of gates (true predicate) or the number of CX gates as cost function. +/// The cost function is given by the sum of the cost of each operation in the +/// circuit. This allows for instance to use of the total number of gates (true +/// predicate), the number of CX gates or a weighted sum of gate types as cost +/// functions. /// /// gamma = 1 is the greedy strategy where a rewrite is only allowed if it /// strictly reduces the gate count. The default is gamma = 1.0001 (as set in /// the Quartz paper) and the number of CX gates. This essentially allows /// rewrites that improve or leave the number of CX unchanged. #[derive(Debug, Clone)] -pub struct ExhaustiveRewriteStrategy

{ +pub struct ExhaustiveGammaStrategy { /// The gamma parameter. pub gamma: f64, - /// Ops to count for cost function. - pub op_predicate: P, + /// A cost function for each operation. + pub op_cost: C, } -impl

ExhaustiveRewriteStrategy

{ +impl usize> ExhaustiveThresholdStrategy for ExhaustiveGammaStrategy { + type OpCost = usize; + type SumOpCost = usize; + + fn threshold(&self, &pattern_cost: &Self::SumOpCost, &target_cost: &Self::SumOpCost) -> bool { + (target_cost as f64) < self.gamma * (pattern_cost as f64) + } + + fn op_cost(&self, op: &OpType) -> Self::OpCost { + (self.op_cost)(op) + } +} + +impl ExhaustiveGammaStrategy { /// New exhaustive rewrite strategy with provided predicate. /// /// The gamma parameter is set to the default 1.0001. - pub fn with_predicate(op_predicate: P) -> Self { + pub fn with_cost(op_cost: C) -> Self { Self { gamma: 1.0001, - op_predicate, + op_cost, } } /// New exhaustive rewrite strategy with provided gamma and predicate. - pub fn new(gamma: f64, op_predicate: P) -> Self { - Self { - gamma, - op_predicate, - } + pub fn new(gamma: f64, op_cost: C) -> Self { + Self { gamma, op_cost } } } -impl ExhaustiveRewriteStrategy bool> { +impl ExhaustiveGammaStrategy usize> { /// Exhaustive rewrite strategy with CX count cost function. /// /// The gamma parameter is set to the default 1.0001. This is a good default /// choice for NISQ-y circuits, where CX gates are the most expensive. pub fn exhaustive_cx() -> Self { - ExhaustiveRewriteStrategy::with_predicate(is_cx) + ExhaustiveGammaStrategy::with_cost(|op| is_cx(op) as usize) } /// Exhaustive rewrite strategy with CX count cost function and provided gamma. pub fn exhaustive_cx_with_gamma(gamma: f64) -> Self { - ExhaustiveRewriteStrategy::new(gamma, is_cx) + ExhaustiveGammaStrategy::new(gamma, |op| is_cx(op) as usize) } } -impl bool> RewriteStrategy for ExhaustiveRewriteStrategy

{ +/// Exhaustive strategies based on cost functions and thresholds. +/// +/// Every possible rewrite is applied to a copy of the input circuit. Thus for +/// one circuit, up to `n` rewritten circuits will be returned, each obtained +/// by applying one of the `n` rewrites to the original circuit. +/// +/// Whether a rewrite is allowed or not is determined by a cost function and a +/// threshold function: if the cost of the target of the rewrite is below the +/// threshold given by the cost of the original circuit, the rewrite is +/// performed. +/// +/// The cost function must return a value of type `Self::OpCost`. All op costs +/// are summed up to obtain a total cost that is then compared using the +/// threshold function. +pub trait ExhaustiveThresholdStrategy { + /// The cost of a single operation. + type OpCost; + /// The sum of the cost of all operations in a circuit. + type SumOpCost; + + /// Whether the rewrite is allowed or not, based on the cost of the pattern and target. + fn threshold(&self, pattern_cost: &Self::SumOpCost, target_cost: &Self::SumOpCost) -> bool; + + /// The cost of a single operation. + fn op_cost(&self, op: &OpType) -> Self::OpCost; +} + +impl RewriteStrategy for T +where + T::SumOpCost: Sum + Ord, +{ + type Cost = T::SumOpCost; + #[tracing::instrument(skip_all)] fn apply_rewrites( &self, @@ -144,9 +256,9 @@ impl bool> RewriteStrategy for ExhaustiveRewriteStrategy

{ rewrites .into_iter() .filter(|rw| { - let old_count = pre_rewrite_cost(rw, circ, &self.op_predicate) as f64; - let new_count = post_rewrite_cost(rw, circ, &self.op_predicate) as f64; - new_count < old_count * self.gamma + let pattern_cost = pre_rewrite_cost(rw, circ, |op| self.op_cost(op)); + let target_cost = post_rewrite_cost(rw, circ, |op| self.op_cost(op)); + self.threshold(&pattern_cost, &target_cost) }) .map(|rw| { let mut circ = circ.clone(); @@ -155,31 +267,85 @@ impl bool> RewriteStrategy for ExhaustiveRewriteStrategy

{ }) .collect() } + + fn circuit_cost(&self, circ: &Hugr) -> Self::Cost { + cost(circ.nodes(), circ, |op| self.op_cost(op)) + } +} + +/// A pair of major and minor cost. +/// +/// This is used to order circuits based on major cost first, then minor cost. +/// A typical example would be CX count as major cost and total gate count as +/// minor cost. +#[derive(Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord, From)] +pub struct MajorMinorCost { + major: usize, + minor: usize, +} + +// Serialise as string so that it is easy to write to CSV +impl serde::Serialize for MajorMinorCost { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + serializer.serialize_str(&format!("{:?}", self)) + } +} + +impl Debug for MajorMinorCost { + // TODO: A nicer print for the logs + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "(major={}, minor={})", self.major, self.minor) + } +} + +impl Sum for MajorMinorCost { + fn sum>(iter: I) -> Self { + iter.reduce(|a, b| (a.major + b.major, a.minor + b.minor).into()) + .unwrap_or_default() + } } fn is_cx(op: &OpType) -> bool { op_matches(op, T2Op::CX) } -fn cost( - nodes: impl IntoIterator, - circ: &Hugr, - pred: impl Fn(&OpType) -> bool, -) -> usize { +fn is_quantum(op: &OpType) -> bool { + let Ok(op): Result = op.try_into() else { + return false; + }; + op.is_quantum() +} + +fn cost(nodes: impl IntoIterator, circ: &Hugr, op_cost: C) -> S +where + C: Fn(&OpType) -> T, + S: Sum, +{ nodes .into_iter() - .filter(|n| { - let op = circ.get_optype(*n); - pred(op) + .map(|n| { + let op = circ.get_optype(n); + op_cost(op) }) - .count() + .sum() } -fn pre_rewrite_cost(rw: &CircuitRewrite, circ: &Hugr, pred: impl Fn(&OpType) -> bool) -> usize { +fn pre_rewrite_cost(rw: &CircuitRewrite, circ: &Hugr, pred: C) -> S +where + C: Fn(&OpType) -> T, + S: Sum, +{ cost(rw.subcircuit().nodes().iter().copied(), circ, pred) } -fn post_rewrite_cost(rw: &CircuitRewrite, circ: &Hugr, pred: impl Fn(&OpType) -> bool) -> usize { +fn post_rewrite_cost(rw: &CircuitRewrite, circ: &Hugr, pred: C) -> S +where + C: Fn(&OpType) -> T, + S: Sum, +{ cost(rw.replacement().nodes(), circ, pred) } @@ -258,7 +424,7 @@ mod tests { rw_to_empty(&circ, cx_gates[9..10].to_vec()), ]; - let strategy = ExhaustiveRewriteStrategy::exhaustive_cx(); + let strategy = ExhaustiveGammaStrategy::exhaustive_cx(); let rewritten = strategy.apply_rewrites(rws, &circ); let exp_circ_lens = HashSet::from_iter([8, 6, 9]); let circ_lens: HashSet<_> = rewritten.iter().map(|c| c.num_gates()).collect(); @@ -280,10 +446,34 @@ mod tests { rw_to_empty(&circ, cx_gates[9..10].to_vec()), ]; - let strategy = ExhaustiveRewriteStrategy::exhaustive_cx_with_gamma(10.); + let strategy = ExhaustiveGammaStrategy::exhaustive_cx_with_gamma(10.); let rewritten = strategy.apply_rewrites(rws, &circ); let exp_circ_lens = HashSet::from_iter([8, 17, 6, 9]); let circ_lens: HashSet<_> = rewritten.iter().map(|c| c.num_gates()).collect(); assert_eq!(circ_lens, exp_circ_lens); } + + #[test] + fn test_exhaustive_default_cx_cost() { + let strat = NonIncreasingCXCountStrategy::default_cx(); + let circ = n_cx(3); + assert_eq!(strat.circuit_cost(&circ), (3, 3).into()); + let circ = build_simple_circuit(2, |circ| { + circ.append(T2Op::CX, [0, 1])?; + circ.append(T2Op::X, [0])?; + circ.append(T2Op::X, [1])?; + Ok(()) + }) + .unwrap(); + assert_eq!(strat.circuit_cost(&circ), (1, 3).into()); + } + + #[test] + fn test_exhaustive_default_cx_threshold() { + let strat = NonIncreasingCXCountStrategy::default_cx(); + assert!(strat.threshold(&(3, 0).into(), &(3, 0).into())); + assert!(strat.threshold(&(3, 0).into(), &(3, 5).into())); + assert!(!strat.threshold(&(3, 10).into(), &(4, 0).into())); + assert!(strat.threshold(&(3, 0).into(), &(1, 5).into())); + } } diff --git a/test_files/small_eccs.json b/test_files/small_eccs.json index 787d10a6..06e29e54 100644 --- a/test_files/small_eccs.json +++ b/test_files/small_eccs.json @@ -50,5 +50,10 @@ ,[[2,0,0,3,["116baec0e00cd"],[5.53564910528592469e-01,2.63161770391809990e-01]],[["cx", ["Q1", "Q0"],["Q1", "Q0"]],["x", ["Q1"],["Q1"]],["cx", ["Q1", "Q0"],["Q1", "Q0"]]]] ,[[2,0,0,3,["116baec0e00cd"],[5.53564910528592469e-01,2.63161770391809990e-01]],[["cx", ["Q1", "Q0"],["Q1", "Q0"]],["x", ["Q1"],["Q1"]],["cx", ["Q1", "Q0"],["Q1", "Q0"]]]] ] +,"6701_3": [ +[[1,2,3,2,["a720832fadf2"],[2.22710267824423158e-01,2.92349563045663841e-01]],[["add", ["P2"],["P0", "P1"]],["rz", ["Q0"],["Q0", "P2"]]]] +,[[1,2,2,2,["a720832fadf2"],[2.22710267824423158e-01,2.92349563045663841e-01]],[["rz", ["Q0"],["Q0", "P0"]],["rz", ["Q0"],["Q0", "P1"]]]] +,[[1,2,2,2,["a720832fadf2"],[2.22710267824423103e-01,2.92349563045663619e-01]],[["rz", ["Q0"],["Q0", "P1"]],["rz", ["Q0"],["Q0", "P0"]]]] +] } ] \ No newline at end of file From 990d78f76098a32e403d2cfe9db0ae7326168a53 Mon Sep 17 00:00:00 2001 From: Luca Mondada Date: Tue, 3 Oct 2023 09:43:47 +0200 Subject: [PATCH 2/5] Remove unused import --- src/optimiser/taso.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/optimiser/taso.rs b/src/optimiser/taso.rs index 7208d841..67a3f249 100644 --- a/src/optimiser/taso.rs +++ b/src/optimiser/taso.rs @@ -27,7 +27,7 @@ use std::num::NonZeroUsize; use std::time::{Duration, Instant}; use fxhash::FxHashSet; -use hugr::{Hugr, HugrView}; +use hugr::Hugr; use crate::circuit::CircuitHash; use crate::optimiser::taso::hugr_pchannel::HugrPriorityChannel; From 4c10e519169b280e930f098b9b628c26b51378cf Mon Sep 17 00:00:00 2001 From: Luca Mondada Date: Tue, 3 Oct 2023 10:03:06 +0200 Subject: [PATCH 3/5] fix docs --- src/rewrite/strategy.rs | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/rewrite/strategy.rs b/src/rewrite/strategy.rs index fe5c51cb..aae2e810 100644 --- a/src/rewrite/strategy.rs +++ b/src/rewrite/strategy.rs @@ -2,11 +2,15 @@ //! //! This module contains the [`RewriteStrategy`] trait, which is currently //! implemented by -//! - [`GreedyRewriteStrategy`], which applies as many rewrites as possible +//! - [`GreedyRewriteStrategy`], which applies as many rewrites as possible //! on one circuit, and -//! - [`ExhaustiveRewriteStrategy`], which clones the original circuit as many -//! times as there are possible rewrites and applies a different rewrite -//! to every circuit. +//! - exhaustive strategies, which clone the original circuit and explore every +//! possible rewrite (with some pruning strategy): +//! - [`NonIncreasingGateCountStrategy`], which only considers rewrites that +//! do not increase some cost function (e.g. cx gate count, implemented as +//! [`NonIncreasingCXCountStrategy`]), and +//! - [`ExhaustiveGammaStrategy`], which ignores rewrites that increase the +//! cost function beyond a threshold given by a f64 parameter gamma. use std::{collections::HashSet, fmt::Debug, iter::Sum}; From 6e52e83ff4efde24ff3bfe9fc0d81d0cbb257ddf Mon Sep 17 00:00:00 2001 From: Luca Mondada Date: Tue, 3 Oct 2023 14:56:41 +0200 Subject: [PATCH 4/5] Address comments --- src/optimiser/taso.rs | 13 ++-- src/rewrite/strategy.rs | 138 ++++++++++++++++++++-------------------- 2 files changed, 77 insertions(+), 74 deletions(-) diff --git a/src/optimiser/taso.rs b/src/optimiser/taso.rs index 67a3f249..7fc4c9b0 100644 --- a/src/optimiser/taso.rs +++ b/src/optimiser/taso.rs @@ -303,20 +303,25 @@ mod taso_default { use std::io; use std::path::Path; + use hugr::ops::OpType; + use crate::rewrite::ecc_rewriter::RewriterSerialisationError; - use crate::rewrite::strategy::NonIncreasingCXCountStrategy; + use crate::rewrite::strategy::NonIncreasingGateCountStrategy; use crate::rewrite::ECCRewriter; use super::*; /// The default TASO optimiser using ECC sets. - pub type DefaultTasoOptimiser = TasoOptimiser; + pub type DefaultTasoOptimiser = TasoOptimiser< + ECCRewriter, + NonIncreasingGateCountStrategy usize, fn(&OpType) -> usize>, + >; impl DefaultTasoOptimiser { /// A sane default optimiser using the given ECC sets. pub fn default_with_eccs_json_file(eccs_path: impl AsRef) -> io::Result { let rewriter = ECCRewriter::try_from_eccs_json_file(eccs_path)?; - let strategy = NonIncreasingCXCountStrategy::default_cx(); + let strategy = NonIncreasingGateCountStrategy::default_cx(); Ok(TasoOptimiser::new(rewriter, strategy)) } @@ -325,7 +330,7 @@ mod taso_default { rewriter_path: impl AsRef, ) -> Result { let rewriter = ECCRewriter::load_binary(rewriter_path)?; - let strategy = NonIncreasingCXCountStrategy::default_cx(); + let strategy = NonIncreasingGateCountStrategy::default_cx(); Ok(TasoOptimiser::new(rewriter, strategy)) } } diff --git a/src/rewrite/strategy.rs b/src/rewrite/strategy.rs index aae2e810..b05424db 100644 --- a/src/rewrite/strategy.rs +++ b/src/rewrite/strategy.rs @@ -95,6 +95,65 @@ impl RewriteStrategy for GreedyRewriteStrategy { } } +/// Exhaustive strategies based on cost functions and thresholds. +/// +/// Every possible rewrite is applied to a copy of the input circuit. Thus for +/// one circuit, up to `n` rewritten circuits will be returned, each obtained +/// by applying one of the `n` rewrites to the original circuit. +/// +/// Whether a rewrite is allowed or not is determined by a cost function and a +/// threshold function: if the cost of the target of the rewrite is below the +/// threshold given by the cost of the original circuit, the rewrite is +/// performed. +/// +/// The cost function must return a value of type `Self::OpCost`. All op costs +/// are summed up to obtain a total cost that is then compared using the +/// threshold function. +pub trait ExhaustiveThresholdStrategy { + /// The cost of a single operation. + type OpCost; + /// The sum of the cost of all operations in a circuit. + type SumOpCost; + + /// Whether the rewrite is allowed or not, based on the cost of the pattern and target. + fn threshold(&self, pattern_cost: &Self::SumOpCost, target_cost: &Self::SumOpCost) -> bool; + + /// The cost of a single operation. + fn op_cost(&self, op: &OpType) -> Self::OpCost; +} + +impl RewriteStrategy for T +where + T::SumOpCost: Sum + Ord, +{ + type Cost = T::SumOpCost; + + #[tracing::instrument(skip_all)] + fn apply_rewrites( + &self, + rewrites: impl IntoIterator, + circ: &Hugr, + ) -> Vec { + rewrites + .into_iter() + .filter(|rw| { + let pattern_cost = pre_rewrite_cost(rw, circ, |op| self.op_cost(op)); + let target_cost = post_rewrite_cost(rw, circ, |op| self.op_cost(op)); + self.threshold(&pattern_cost, &target_cost) + }) + .map(|rw| { + let mut circ = circ.clone(); + rw.apply(&mut circ).expect("invalid pattern match"); + circ + }) + .collect() + } + + fn circuit_cost(&self, circ: &Hugr) -> Self::Cost { + cost(circ.nodes(), circ, |op| self.op_cost(op)) + } +} + /// Exhaustive rewrite strategy allowing smaller or equal cost rewrites. /// /// Rewrites are permitted based on a cost function called the major cost: if @@ -132,15 +191,13 @@ where } } -/// Non-increasing rewrite strategy based on CX count. -/// -/// The minor cost to break ties between equal CX counts is the number of -/// quantum gates. -pub type NonIncreasingCXCountStrategy = - NonIncreasingGateCountStrategy usize, fn(&OpType) -> usize>; - -impl NonIncreasingCXCountStrategy { - /// Create rewrite strategy based on non-increasing CX count. +impl NonIncreasingGateCountStrategy usize, fn(&OpType) -> usize> { + /// Non-increasing rewrite strategy based on CX count. + /// + /// The minor cost to break ties between equal CX counts is the number of + /// quantum gates. + /// + /// This is probably a good default for NISQ-y circuit optimisation. pub fn default_cx() -> Self { Self { major_cost: |op| is_cx(op) as usize, @@ -218,65 +275,6 @@ impl ExhaustiveGammaStrategy usize> { } } -/// Exhaustive strategies based on cost functions and thresholds. -/// -/// Every possible rewrite is applied to a copy of the input circuit. Thus for -/// one circuit, up to `n` rewritten circuits will be returned, each obtained -/// by applying one of the `n` rewrites to the original circuit. -/// -/// Whether a rewrite is allowed or not is determined by a cost function and a -/// threshold function: if the cost of the target of the rewrite is below the -/// threshold given by the cost of the original circuit, the rewrite is -/// performed. -/// -/// The cost function must return a value of type `Self::OpCost`. All op costs -/// are summed up to obtain a total cost that is then compared using the -/// threshold function. -pub trait ExhaustiveThresholdStrategy { - /// The cost of a single operation. - type OpCost; - /// The sum of the cost of all operations in a circuit. - type SumOpCost; - - /// Whether the rewrite is allowed or not, based on the cost of the pattern and target. - fn threshold(&self, pattern_cost: &Self::SumOpCost, target_cost: &Self::SumOpCost) -> bool; - - /// The cost of a single operation. - fn op_cost(&self, op: &OpType) -> Self::OpCost; -} - -impl RewriteStrategy for T -where - T::SumOpCost: Sum + Ord, -{ - type Cost = T::SumOpCost; - - #[tracing::instrument(skip_all)] - fn apply_rewrites( - &self, - rewrites: impl IntoIterator, - circ: &Hugr, - ) -> Vec { - rewrites - .into_iter() - .filter(|rw| { - let pattern_cost = pre_rewrite_cost(rw, circ, |op| self.op_cost(op)); - let target_cost = post_rewrite_cost(rw, circ, |op| self.op_cost(op)); - self.threshold(&pattern_cost, &target_cost) - }) - .map(|rw| { - let mut circ = circ.clone(); - rw.apply(&mut circ).expect("invalid pattern match"); - circ - }) - .collect() - } - - fn circuit_cost(&self, circ: &Hugr) -> Self::Cost { - cost(circ.nodes(), circ, |op| self.op_cost(op)) - } -} - /// A pair of major and minor cost. /// /// This is used to order circuits based on major cost first, then minor cost. @@ -459,7 +457,7 @@ mod tests { #[test] fn test_exhaustive_default_cx_cost() { - let strat = NonIncreasingCXCountStrategy::default_cx(); + let strat = NonIncreasingGateCountStrategy::default_cx(); let circ = n_cx(3); assert_eq!(strat.circuit_cost(&circ), (3, 3).into()); let circ = build_simple_circuit(2, |circ| { @@ -474,7 +472,7 @@ mod tests { #[test] fn test_exhaustive_default_cx_threshold() { - let strat = NonIncreasingCXCountStrategy::default_cx(); + let strat = NonIncreasingGateCountStrategy::default_cx(); assert!(strat.threshold(&(3, 0).into(), &(3, 0).into())); assert!(strat.threshold(&(3, 0).into(), &(3, 5).into())); assert!(!strat.threshold(&(3, 10).into(), &(4, 0).into())); From 16cb3466560bf2da789c5b64e43cedc706efc77d Mon Sep 17 00:00:00 2001 From: Luca Mondada Date: Tue, 3 Oct 2023 15:14:29 +0200 Subject: [PATCH 5/5] Remove ref to defunct CXCountStrat --- src/rewrite/strategy.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rewrite/strategy.rs b/src/rewrite/strategy.rs index b05424db..38a63e3f 100644 --- a/src/rewrite/strategy.rs +++ b/src/rewrite/strategy.rs @@ -8,7 +8,7 @@ //! possible rewrite (with some pruning strategy): //! - [`NonIncreasingGateCountStrategy`], which only considers rewrites that //! do not increase some cost function (e.g. cx gate count, implemented as -//! [`NonIncreasingCXCountStrategy`]), and +//! [`NonIncreasingGateCountStrategy::default_cx`]), and //! - [`ExhaustiveGammaStrategy`], which ignores rewrites that increase the //! cost function beyond a threshold given by a f64 parameter gamma.