diff --git a/src/optimiser/taso.rs b/src/optimiser/taso.rs index 67a3f249..7fc4c9b0 100644 --- a/src/optimiser/taso.rs +++ b/src/optimiser/taso.rs @@ -303,20 +303,25 @@ mod taso_default { use std::io; use std::path::Path; + use hugr::ops::OpType; + use crate::rewrite::ecc_rewriter::RewriterSerialisationError; - use crate::rewrite::strategy::NonIncreasingCXCountStrategy; + use crate::rewrite::strategy::NonIncreasingGateCountStrategy; use crate::rewrite::ECCRewriter; use super::*; /// The default TASO optimiser using ECC sets. - pub type DefaultTasoOptimiser = TasoOptimiser; + pub type DefaultTasoOptimiser = TasoOptimiser< + ECCRewriter, + NonIncreasingGateCountStrategy usize, fn(&OpType) -> usize>, + >; impl DefaultTasoOptimiser { /// A sane default optimiser using the given ECC sets. pub fn default_with_eccs_json_file(eccs_path: impl AsRef) -> io::Result { let rewriter = ECCRewriter::try_from_eccs_json_file(eccs_path)?; - let strategy = NonIncreasingCXCountStrategy::default_cx(); + let strategy = NonIncreasingGateCountStrategy::default_cx(); Ok(TasoOptimiser::new(rewriter, strategy)) } @@ -325,7 +330,7 @@ mod taso_default { rewriter_path: impl AsRef, ) -> Result { let rewriter = ECCRewriter::load_binary(rewriter_path)?; - let strategy = NonIncreasingCXCountStrategy::default_cx(); + let strategy = NonIncreasingGateCountStrategy::default_cx(); Ok(TasoOptimiser::new(rewriter, strategy)) } } diff --git a/src/rewrite/strategy.rs b/src/rewrite/strategy.rs index aae2e810..b05424db 100644 --- a/src/rewrite/strategy.rs +++ b/src/rewrite/strategy.rs @@ -95,6 +95,65 @@ impl RewriteStrategy for GreedyRewriteStrategy { } } +/// Exhaustive strategies based on cost functions and thresholds. +/// +/// Every possible rewrite is applied to a copy of the input circuit. Thus for +/// one circuit, up to `n` rewritten circuits will be returned, each obtained +/// by applying one of the `n` rewrites to the original circuit. +/// +/// Whether a rewrite is allowed or not is determined by a cost function and a +/// threshold function: if the cost of the target of the rewrite is below the +/// threshold given by the cost of the original circuit, the rewrite is +/// performed. +/// +/// The cost function must return a value of type `Self::OpCost`. All op costs +/// are summed up to obtain a total cost that is then compared using the +/// threshold function. +pub trait ExhaustiveThresholdStrategy { + /// The cost of a single operation. + type OpCost; + /// The sum of the cost of all operations in a circuit. + type SumOpCost; + + /// Whether the rewrite is allowed or not, based on the cost of the pattern and target. + fn threshold(&self, pattern_cost: &Self::SumOpCost, target_cost: &Self::SumOpCost) -> bool; + + /// The cost of a single operation. + fn op_cost(&self, op: &OpType) -> Self::OpCost; +} + +impl RewriteStrategy for T +where + T::SumOpCost: Sum + Ord, +{ + type Cost = T::SumOpCost; + + #[tracing::instrument(skip_all)] + fn apply_rewrites( + &self, + rewrites: impl IntoIterator, + circ: &Hugr, + ) -> Vec { + rewrites + .into_iter() + .filter(|rw| { + let pattern_cost = pre_rewrite_cost(rw, circ, |op| self.op_cost(op)); + let target_cost = post_rewrite_cost(rw, circ, |op| self.op_cost(op)); + self.threshold(&pattern_cost, &target_cost) + }) + .map(|rw| { + let mut circ = circ.clone(); + rw.apply(&mut circ).expect("invalid pattern match"); + circ + }) + .collect() + } + + fn circuit_cost(&self, circ: &Hugr) -> Self::Cost { + cost(circ.nodes(), circ, |op| self.op_cost(op)) + } +} + /// Exhaustive rewrite strategy allowing smaller or equal cost rewrites. /// /// Rewrites are permitted based on a cost function called the major cost: if @@ -132,15 +191,13 @@ where } } -/// Non-increasing rewrite strategy based on CX count. -/// -/// The minor cost to break ties between equal CX counts is the number of -/// quantum gates. -pub type NonIncreasingCXCountStrategy = - NonIncreasingGateCountStrategy usize, fn(&OpType) -> usize>; - -impl NonIncreasingCXCountStrategy { - /// Create rewrite strategy based on non-increasing CX count. +impl NonIncreasingGateCountStrategy usize, fn(&OpType) -> usize> { + /// Non-increasing rewrite strategy based on CX count. + /// + /// The minor cost to break ties between equal CX counts is the number of + /// quantum gates. + /// + /// This is probably a good default for NISQ-y circuit optimisation. pub fn default_cx() -> Self { Self { major_cost: |op| is_cx(op) as usize, @@ -218,65 +275,6 @@ impl ExhaustiveGammaStrategy usize> { } } -/// Exhaustive strategies based on cost functions and thresholds. -/// -/// Every possible rewrite is applied to a copy of the input circuit. Thus for -/// one circuit, up to `n` rewritten circuits will be returned, each obtained -/// by applying one of the `n` rewrites to the original circuit. -/// -/// Whether a rewrite is allowed or not is determined by a cost function and a -/// threshold function: if the cost of the target of the rewrite is below the -/// threshold given by the cost of the original circuit, the rewrite is -/// performed. -/// -/// The cost function must return a value of type `Self::OpCost`. All op costs -/// are summed up to obtain a total cost that is then compared using the -/// threshold function. -pub trait ExhaustiveThresholdStrategy { - /// The cost of a single operation. - type OpCost; - /// The sum of the cost of all operations in a circuit. - type SumOpCost; - - /// Whether the rewrite is allowed or not, based on the cost of the pattern and target. - fn threshold(&self, pattern_cost: &Self::SumOpCost, target_cost: &Self::SumOpCost) -> bool; - - /// The cost of a single operation. - fn op_cost(&self, op: &OpType) -> Self::OpCost; -} - -impl RewriteStrategy for T -where - T::SumOpCost: Sum + Ord, -{ - type Cost = T::SumOpCost; - - #[tracing::instrument(skip_all)] - fn apply_rewrites( - &self, - rewrites: impl IntoIterator, - circ: &Hugr, - ) -> Vec { - rewrites - .into_iter() - .filter(|rw| { - let pattern_cost = pre_rewrite_cost(rw, circ, |op| self.op_cost(op)); - let target_cost = post_rewrite_cost(rw, circ, |op| self.op_cost(op)); - self.threshold(&pattern_cost, &target_cost) - }) - .map(|rw| { - let mut circ = circ.clone(); - rw.apply(&mut circ).expect("invalid pattern match"); - circ - }) - .collect() - } - - fn circuit_cost(&self, circ: &Hugr) -> Self::Cost { - cost(circ.nodes(), circ, |op| self.op_cost(op)) - } -} - /// A pair of major and minor cost. /// /// This is used to order circuits based on major cost first, then minor cost. @@ -459,7 +457,7 @@ mod tests { #[test] fn test_exhaustive_default_cx_cost() { - let strat = NonIncreasingCXCountStrategy::default_cx(); + let strat = NonIncreasingGateCountStrategy::default_cx(); let circ = n_cx(3); assert_eq!(strat.circuit_cost(&circ), (3, 3).into()); let circ = build_simple_circuit(2, |circ| { @@ -474,7 +472,7 @@ mod tests { #[test] fn test_exhaustive_default_cx_threshold() { - let strat = NonIncreasingCXCountStrategy::default_cx(); + let strat = NonIncreasingGateCountStrategy::default_cx(); assert!(strat.threshold(&(3, 0).into(), &(3, 0).into())); assert!(strat.threshold(&(3, 0).into(), &(3, 5).into())); assert!(!strat.threshold(&(3, 10).into(), &(4, 0).into()));