From 0adbbb980168e94693f4af2177ee48e4fce1d099 Mon Sep 17 00:00:00 2001 From: Agustin Borgna Date: Fri, 29 Sep 2023 09:49:22 +0100 Subject: [PATCH 1/4] wip: exhaustive greedy strategy --- src/optimiser/taso.rs | 8 ++-- src/rewrite/strategy.rs | 94 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 98 insertions(+), 4 deletions(-) diff --git a/src/optimiser/taso.rs b/src/optimiser/taso.rs index a231aaf8..1b5bc1ac 100644 --- a/src/optimiser/taso.rs +++ b/src/optimiser/taso.rs @@ -296,7 +296,7 @@ mod taso_default { use crate::ops::op_matches; use crate::rewrite::ecc_rewriter::RewriterSerialisationError; - use crate::rewrite::strategy::ExhaustiveRewriteStrategy; + use crate::rewrite::strategy::ExhaustiveGreedyRewriteStrategy; use crate::rewrite::ECCRewriter; use crate::T2Op; @@ -305,7 +305,7 @@ mod taso_default { /// The default TASO optimiser using ECC sets. pub type DefaultTasoOptimiser = TasoOptimiser< ECCRewriter, - ExhaustiveRewriteStrategy bool>, + ExhaustiveGreedyRewriteStrategy bool>, fn(&Hugr) -> usize, >; @@ -313,7 +313,7 @@ mod taso_default { /// A sane default optimiser using the given ECC sets. pub fn default_with_eccs_json_file(eccs_path: impl AsRef) -> io::Result { let rewriter = ECCRewriter::try_from_eccs_json_file(eccs_path)?; - let strategy = ExhaustiveRewriteStrategy::exhaustive_cx(); + let strategy = ExhaustiveGreedyRewriteStrategy::greedy_exhaustive_cx(); Ok(TasoOptimiser::new(rewriter, strategy, num_cx_gates)) } @@ -322,7 +322,7 @@ mod taso_default { rewriter_path: impl AsRef, ) -> Result { let rewriter = ECCRewriter::load_binary(rewriter_path)?; - let strategy = ExhaustiveRewriteStrategy::exhaustive_cx(); + let strategy = ExhaustiveGreedyRewriteStrategy::greedy_exhaustive_cx(); Ok(TasoOptimiser::new(rewriter, strategy, num_cx_gates)) } } diff --git a/src/rewrite/strategy.rs b/src/rewrite/strategy.rs index 3df7ceeb..3b296c76 100644 --- a/src/rewrite/strategy.rs +++ b/src/rewrite/strategy.rs @@ -42,6 +42,7 @@ pub trait RewriteStrategy { /// with as many rewrites applied as possible. /// /// Rewrites are only applied if they strictly decrease gate count. +#[derive(Debug, Copy, Clone)] pub struct GreedyRewriteStrategy; impl RewriteStrategy for GreedyRewriteStrategy { @@ -157,6 +158,99 @@ impl bool> RewriteStrategy for ExhaustiveRewriteStrategy

{ } } +/// A rewrite strategy that explores applying each rewrite that reduces the size +/// of the circuit to copies of the circuit. +/// +/// Tries to apply as many rewrites as possible at each step, using a greedy +/// strategy. +/// +/// The cost function is given by the number of operations in the circuit that +/// satisfy a given Op predicate. This allows for instance to use the total +/// number of gates (true predicate) or the number of CX gates as cost function. +#[derive(Debug, Clone)] +pub struct ExhaustiveGreedyRewriteStrategy

{ + /// The gamma parameter. + pub gamma: f64, + /// Ops to count for cost function. + pub op_predicate: P, +} + +impl

ExhaustiveGreedyRewriteStrategy

{ + /// New greedy exhaustive rewrite strategy with provided predicate. + /// + /// The gamma parameter is set to the default 1.0001. + pub fn with_predicate(op_predicate: P) -> Self { + Self { + gamma: 1.0001, + op_predicate, + } + } + + /// New greedy exhaustive rewrite strategy with provided gamma and predicate. + pub fn new(gamma: f64, op_predicate: P) -> Self { + Self { + gamma, + op_predicate, + } + } +} + +impl ExhaustiveGreedyRewriteStrategy bool> { + /// Exhaustive rewrite strategy with CX count cost function. + /// + /// The gamma parameter is set to the default 1.0001. This is a good default + /// choice for NISQ-y circuits, where CX gates are the most expensive. + pub fn greedy_exhaustive_cx() -> Self { + Self::with_predicate(is_cx) + } +} + +impl bool> RewriteStrategy for ExhaustiveGreedyRewriteStrategy

{ + #[tracing::instrument(skip_all)] + fn apply_rewrites( + &self, + rewrites: impl IntoIterator, + circ: &Hugr, + ) -> Vec { + // Check only the rewrites that reduce the size of the circuit. + let rewrites = rewrites + .into_iter() + .map(|rw| { + let old_count = pre_rewrite_cost(&rw, circ, &self.op_predicate) as f64; + let new_count = post_rewrite_cost(&rw, circ, &self.op_predicate) as f64; + (new_count - old_count, rw) + }) + .sorted_by_key(|(delta, _)| *delta as isize) + .take_while(|(delta, _)| *delta < 0.001) + .map(|(_, rw)| rw) + .collect_vec(); + + let mut rewrite_sets = Vec::with_capacity(rewrites.len()); + for i in 0..rewrites.len() { + let mut curr_circ = circ.clone(); + let mut changed_nodes = HashSet::new(); + for rewrite in &rewrites[i..] { + if !changed_nodes.is_empty() + && rewrite + .subcircuit() + .nodes() + .iter() + .any(|n| changed_nodes.contains(n)) + { + continue; + } + changed_nodes.extend(rewrite.subcircuit().nodes().iter().copied()); + rewrite + .clone() + .apply(&mut curr_circ) + .expect("Could not perform rewrite in exhaustive greedy strategy"); + } + rewrite_sets.push(curr_circ); + } + rewrite_sets + } +} + fn is_cx(op: &OpType) -> bool { op_matches(op, T2Op::CX) } From 7f3ee82fa4b1831d61b21ca9364aae5910058ef5 Mon Sep 17 00:00:00 2001 From: Agustin Borgna Date: Wed, 11 Oct 2023 12:37:08 +0100 Subject: [PATCH 2/4] Update default cx strategy to be semi-greedy --- Cargo.toml | 2 +- src/circuit/cost.rs | 19 +++- src/optimiser/taso.rs | 15 ++- src/rewrite.rs | 9 ++ src/rewrite/strategy.rs | 222 +++++++++++++++++++++++++++------------- 5 files changed, 183 insertions(+), 84 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 37f63f63..9244e2bd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -73,7 +73,7 @@ members = ["pyrs", "compile-rewriter", "taso-optimiser"] [workspace.dependencies] -quantinuum-hugr = { git = "https://github.com/CQCL-DEV/hugr", rev = "09494f1" } +quantinuum-hugr = { git = "https://github.com/CQCL-DEV/hugr", rev = "a34662f0" } portgraph = { version = "0.9", features = ["serde"] } pyo3 = { version = "0.19" } itertools = { version = "0.11.0" } diff --git a/src/circuit/cost.rs b/src/circuit/cost.rs index 64ec9f40..3c51dda4 100644 --- a/src/circuit/cost.rs +++ b/src/circuit/cost.rs @@ -5,7 +5,7 @@ use hugr::ops::OpType; use std::fmt::{Debug, Display}; use std::iter::Sum; use std::num::NonZeroUsize; -use std::ops::Add; +use std::ops::{Add, AddAssign}; use crate::ops::op_matches; use crate::T2Op; @@ -29,7 +29,9 @@ pub trait CircuitCost: Add + Sum + Debug + Default + Clone } /// The cost for a group of operations in a circuit, each with cost `OpCost`. -pub trait CostDelta: Sum + Debug + Default + Clone + Ord { +pub trait CostDelta: + AddAssign + Add + Sum + Debug + Default + Clone + Ord +{ /// Return the delta as a `isize`. This may discard some of the cost delta information. fn as_isize(&self) -> isize; } @@ -62,14 +64,21 @@ impl Debug for MajorMinorCost { } } -impl Add for MajorMinorCost { - type Output = MajorMinorCost; +impl> Add for MajorMinorCost { + type Output = MajorMinorCost; - fn add(self, rhs: MajorMinorCost) -> Self::Output { + fn add(self, rhs: MajorMinorCost) -> Self::Output { (self.major + rhs.major, self.minor + rhs.minor).into() } } +impl AddAssign for MajorMinorCost { + fn add_assign(&mut self, rhs: Self) { + self.major += rhs.major; + self.minor += rhs.minor; + } +} + impl + Default> Sum for MajorMinorCost { fn sum>(iter: I) -> Self { iter.reduce(|a, b| (a.major + b.major, a.minor + b.minor).into()) diff --git a/src/optimiser/taso.rs b/src/optimiser/taso.rs index d0ec9816..2281dd5a 100644 --- a/src/optimiser/taso.rs +++ b/src/optimiser/taso.rs @@ -169,7 +169,6 @@ where // Ignore this circuit: we've already seen it continue; } - circ_cnt += 1; logger.log_progress(circ_cnt, Some(pq.len()), seen_hashes.len()); let new_circ_cost = cost.add_delta(&cost_delta); pq.push_unchecked(new_circ, new_circ_hash, new_circ_cost); @@ -390,22 +389,22 @@ mod taso_default { use hugr::ops::OpType; use crate::rewrite::ecc_rewriter::RewriterSerialisationError; - use crate::rewrite::strategy::ExhaustiveGreedyRewriteStrategy; + use crate::rewrite::strategy::{ExhaustiveGreedyStrategy, NonIncreasingGateCountCost}; use crate::rewrite::ECCRewriter; use super::*; + pub type StrategyCost = NonIncreasingGateCountCost usize, fn(&OpType) -> usize>; + /// The default TASO optimiser using ECC sets. - pub type DefaultTasoOptimiser = TasoOptimiser< - ECCRewriter, - ExhaustiveGreedyRewriteStrategy bool, fn(&Hugr) -> usize>, - >; + pub type DefaultTasoOptimiser = + TasoOptimiser>; impl DefaultTasoOptimiser { /// A sane default optimiser using the given ECC sets. pub fn default_with_eccs_json_file(eccs_path: impl AsRef) -> io::Result { let rewriter = ECCRewriter::try_from_eccs_json_file(eccs_path)?; - let strategy = ExhaustiveGreedyRewriteStrategy::greedy_exhaustive_cx(); + let strategy = NonIncreasingGateCountCost::default_cx(); Ok(TasoOptimiser::new(rewriter, strategy)) } @@ -414,7 +413,7 @@ mod taso_default { rewriter_path: impl AsRef, ) -> Result { let rewriter = ECCRewriter::load_binary(rewriter_path)?; - let strategy = ExhaustiveGreedyRewriteStrategy::greedy_exhaustive_cx(); + let strategy = NonIncreasingGateCountCost::default_cx(); Ok(TasoOptimiser::new(rewriter, strategy)) } } diff --git a/src/rewrite.rs b/src/rewrite.rs index ebd55d6f..83f7e422 100644 --- a/src/rewrite.rs +++ b/src/rewrite.rs @@ -101,6 +101,15 @@ impl CircuitRewrite { self.0.replacement() } + /// Returns the nodes affected by a rewrite. + /// + /// This includes the nodes in the subcircuit and it's neighbours (contained in + /// the rewrite's boundary). + #[inline] + pub fn affected_nodes(&self) -> impl Iterator + '_ { + self.0.affected_nodes() + } + delegate! { to self.0 { /// Apply the rewrite rule to a circuit. diff --git a/src/rewrite/strategy.rs b/src/rewrite/strategy.rs index b5c6fd71..90e7a176 100644 --- a/src/rewrite/strategy.rs +++ b/src/rewrite/strategy.rs @@ -14,6 +14,7 @@ use std::{collections::HashSet, fmt::Debug}; +use derive_more::From; use hugr::ops::OpType; use hugr::Hugr; use itertools::Itertools; @@ -158,9 +159,13 @@ impl RewriteStrategy for GreedyRewriteStrategy { /// Exhaustive strategies based on cost functions and thresholds. /// -/// Every possible rewrite is applied to a copy of the input circuit. Thus for -/// one circuit, up to `n` rewritten circuits will be returned, each obtained -/// by applying one of the `n` rewrites to the original circuit. +/// Every possible rewrite is applied to a copy of the input circuit. In +/// addition, other non-overlapping rewrites are applied greedily in ascending +/// cost delta. +/// +/// Thus for one circuit, up to `n` rewritten circuits will be returned, each +/// obtained by applying at least one of the `n` rewrites to the original +/// circuit. /// /// Whether a rewrite is allowed or not is determined by a cost function and a /// threshold function: if the cost of the target of the rewrite is below the @@ -170,21 +175,91 @@ impl RewriteStrategy for GreedyRewriteStrategy { /// The cost function must return a value of type `Self::OpCost`. All op costs /// are summed up to obtain a total cost that is then compared using the /// threshold function. -pub trait ExhaustiveThresholdStrategy { - /// The cost of a single operation. - type OpCost: CircuitCost; +/// +/// This kind of strategy is not recommended for thresholds that allow positive +/// cost deltas, as these will always be greedily applied even if they increase +/// the final cost. +#[derive(Debug, Copy, Clone, From)] +pub struct ExhaustiveGreedyStrategy { + /// The cost function. + pub strat_cost: T, +} + +impl RewriteStrategy for ExhaustiveGreedyStrategy { + type Cost = T::OpCost; + + #[tracing::instrument(skip_all)] + fn apply_rewrites( + &self, + rewrites: impl IntoIterator, + circ: &Hugr, + ) -> RewriteResult { + // Check only the rewrites that reduce the size of the circuit. + let rewrites = rewrites + .into_iter() + .filter_map(|rw| { + let pattern_cost = pre_rewrite_cost(&rw, circ, |op| self.op_cost(op)); + let target_cost = post_rewrite_cost(&rw, |op| self.op_cost(op)); + if !self.strat_cost.under_threshold(&pattern_cost, &target_cost) { + return None; + } + Some((rw, target_cost.sub_cost(&pattern_cost))) + }) + .sorted_by_key(|(_, delta)| delta.clone()) + .collect_vec(); + + let mut rewrite_sets = RewriteResult::with_capacity(rewrites.len()); + for i in 0..rewrites.len() { + let mut curr_circ = circ.clone(); + let mut changed_nodes = HashSet::new(); + let mut cost_delta = Default::default(); + for (rewrite, delta) in &rewrites[i..] { + if !changed_nodes.is_empty() + && rewrite.affected_nodes().any(|n| changed_nodes.contains(&n)) + { + continue; + } + changed_nodes.extend(rewrite.affected_nodes()); + cost_delta += delta.clone(); + + rewrite + .clone() + .apply(&mut curr_circ) + .expect("Could not perform rewrite in exhaustive greedy strategy"); + } + rewrite_sets.circs.push(curr_circ); + rewrite_sets.cost_deltas.push(cost_delta); + } + rewrite_sets + } - /// Returns true if the rewrite is allowed, based on the cost of the pattern and target. #[inline] - fn under_threshold(&self, pattern_cost: &Self::OpCost, target_cost: &Self::OpCost) -> bool { - target_cost.sub_cost(pattern_cost).as_isize() <= 0 + fn op_cost(&self, op: &OpType) -> Self::Cost { + self.strat_cost.op_cost(op) } +} - /// The cost of a single operation. - fn op_cost(&self, op: &OpType) -> Self::OpCost; +/// Exhaustive strategies based on cost functions and thresholds. +/// +/// Every possible rewrite is applied to a copy of the input circuit. Thus for +/// one circuit, up to `n` rewritten circuits will be returned, each obtained +/// by applying one of the `n` rewrites to the original circuit. +/// +/// Whether a rewrite is allowed or not is determined by a cost function and a +/// threshold function: if the cost of the target of the rewrite is below the +/// threshold given by the cost of the original circuit, the rewrite is +/// performed. +/// +/// The cost function must return a value of type `Self::OpCost`. All op costs +/// are summed up to obtain a total cost that is then compared using the +/// threshold function. +#[derive(Debug, Copy, Clone, From)] +pub struct ExhaustiveThresholdStrategy { + /// The cost function. + pub strat_cost: T, } -impl RewriteStrategy for T { +impl RewriteStrategy for ExhaustiveThresholdStrategy { type Cost = T::OpCost; #[tracing::instrument(skip_all)] @@ -198,7 +273,7 @@ impl RewriteStrategy for T { .filter_map(|rw| { let pattern_cost = pre_rewrite_cost(&rw, circ, |op| self.op_cost(op)); let target_cost = post_rewrite_cost(&rw, |op| self.op_cost(op)); - if !self.under_threshold(&pattern_cost, &target_cost) { + if !self.strat_cost.under_threshold(&pattern_cost, &target_cost) { return None; } let mut circ = circ.clone(); @@ -211,11 +286,28 @@ impl RewriteStrategy for T { #[inline] fn op_cost(&self, op: &OpType) -> Self::Cost { - ::op_cost(self, op) + self.strat_cost.op_cost(op) + } +} + +/// Cost function definitions required in exhaustive strategies. +/// +/// See [`ExhaustiveThresholdStrategy`], [`ExhaustiveGreedyStrategy`]. +pub trait StrategyCost { + /// The cost of a single operation. + type OpCost: CircuitCost; + + /// Returns true if the rewrite is allowed, based on the cost of the pattern and target. + #[inline] + fn under_threshold(&self, pattern_cost: &Self::OpCost, target_cost: &Self::OpCost) -> bool { + target_cost.sub_cost(pattern_cost).as_isize() <= 0 } + + /// The cost of a single operation. + fn op_cost(&self, op: &OpType) -> Self::OpCost; } -/// Exhaustive rewrite strategy allowing smaller or equal cost rewrites. +/// Rewrite strategy cost allowing smaller or equal cost rewrites. /// /// Rewrites are permitted based on a cost function called the major cost: if /// the major cost of the target of the rewrite is smaller or equal to the major @@ -226,16 +318,29 @@ impl RewriteStrategy for T { /// with a smaller minor cost. /// /// An example would be to use the number of CX gates as major cost and the -/// total number of gates as minor cost. Compared to a [`ExhaustiveGammaStrategy`], -/// that would only order circuits based on the number of CX gates, this creates -/// a less flat optimisation landscape. +/// total number of gates as minor cost. Compared to a +/// [`ExhaustiveGammaStrategy`], that would only order circuits based on the +/// number of CX gates, this creates a less flat optimisation landscape. #[derive(Debug, Clone)] -pub struct NonIncreasingGateCountStrategy { +pub struct NonIncreasingGateCountCost { major_cost: C1, minor_cost: C2, } -impl NonIncreasingGateCountStrategy usize, fn(&OpType) -> usize> { +impl StrategyCost for NonIncreasingGateCountCost +where + C1: Fn(&OpType) -> usize, + C2: Fn(&OpType) -> usize, +{ + type OpCost = MajorMinorCost; + + #[inline] + fn op_cost(&self, op: &OpType) -> Self::OpCost { + ((self.major_cost)(op), (self.minor_cost)(op)).into() + } +} + +impl NonIncreasingGateCountCost usize, fn(&OpType) -> usize> { /// Non-increasing rewrite strategy based on CX count. /// /// The minor cost to break ties between equal CX counts is the number of @@ -243,31 +348,19 @@ impl NonIncreasingGateCountStrategy usize, fn(&OpType) -> usize> /// /// This is probably a good default for NISQ-y circuit optimisation. #[inline] - pub fn default_cx() -> Self { + pub fn default_cx() -> ExhaustiveGreedyStrategy { Self { major_cost: |op| is_cx(op) as usize, minor_cost: |op| is_quantum(op) as usize, } + .into() } } -impl ExhaustiveThresholdStrategy for NonIncreasingGateCountStrategy -where - C1: Fn(&OpType) -> usize, - C2: Fn(&OpType) -> usize, -{ - type OpCost = MajorMinorCost; - - #[inline] - fn op_cost(&self, op: &OpType) -> Self::OpCost { - ((self.major_cost)(op), (self.minor_cost)(op)).into() - } -} - -/// Exhaustive rewrite strategy allowing rewrites with bounded cost increase. +/// Rewrite strategy cost allowing rewrites with bounded cost increase. /// -/// The parameter gamma controls how greedy the algorithm should be. It allows -/// a rewrite C1 -> C2 if C2 has at most gamma times the cost of C1: +/// The parameter gamma controls how greedy the algorithm should be. It allows a +/// rewrite C1 -> C2 if C2 has at most gamma times the cost of C1: /// /// $cost(C2) < gamma * cost(C1)$ /// @@ -281,14 +374,14 @@ where /// the Quartz paper) and the number of CX gates. This essentially allows /// rewrites that improve or leave the number of CX unchanged. #[derive(Debug, Clone)] -pub struct ExhaustiveGammaStrategy { +pub struct GammaStrategyCost { /// The gamma parameter. pub gamma: f64, /// A cost function for each operation. pub op_cost: C, } -impl usize> ExhaustiveThresholdStrategy for ExhaustiveGammaStrategy { +impl usize> StrategyCost for GammaStrategyCost { type OpCost = usize; #[inline] @@ -302,39 +395,40 @@ impl usize> ExhaustiveThresholdStrategy for ExhaustiveGammaStr } } -impl ExhaustiveGammaStrategy { +impl GammaStrategyCost { /// New exhaustive rewrite strategy with provided predicate. /// /// The gamma parameter is set to the default 1.0001. #[inline] - pub fn with_cost(op_cost: C) -> Self { + pub fn with_cost(op_cost: C) -> ExhaustiveThresholdStrategy { Self { gamma: 1.0001, op_cost, } + .into() } /// New exhaustive rewrite strategy with provided gamma and predicate. #[inline] - pub fn new(gamma: f64, op_cost: C) -> Self { - Self { gamma, op_cost } + pub fn new(gamma: f64, op_cost: C) -> ExhaustiveThresholdStrategy { + Self { gamma, op_cost }.into() } } -impl ExhaustiveGammaStrategy usize> { +impl GammaStrategyCost usize> { /// Exhaustive rewrite strategy with CX count cost function. /// /// The gamma parameter is set to the default 1.0001. This is a good default /// choice for NISQ-y circuits, where CX gates are the most expensive. #[inline] - pub fn exhaustive_cx() -> Self { - ExhaustiveGammaStrategy::with_cost(|op| is_cx(op) as usize) + pub fn exhaustive_cx() -> ExhaustiveThresholdStrategy { + GammaStrategyCost::with_cost(|op| is_cx(op) as usize) } /// Exhaustive rewrite strategy with CX count cost function and provided gamma. #[inline] - pub fn exhaustive_cx_with_gamma(gamma: f64) -> Self { - ExhaustiveGammaStrategy::new(gamma, |op| is_cx(op) as usize) + pub fn exhaustive_cx_with_gamma(gamma: f64) -> ExhaustiveThresholdStrategy { + GammaStrategyCost::new(gamma, |op| is_cx(op) as usize) } } @@ -357,10 +451,7 @@ where #[cfg(test)] mod tests { use super::*; - use hugr::{ - ops::{OpTag, OpTrait}, - Hugr, HugrView, Node, - }; + use hugr::{Hugr, Node}; use itertools::Itertools; use crate::{ @@ -396,10 +487,7 @@ mod tests { #[test] fn test_greedy_strategy() { let circ = n_cx(10); - let cx_gates = circ - .nodes() - .filter(|&n| OpTag::Leaf.is_superset(circ.get_optype(n).tag())) - .collect_vec(); + let cx_gates = circ.commands().map(|cmd| cmd.node()).collect_vec(); let rws = [ rw_to_empty(&circ, cx_gates[0..2].to_vec()), @@ -417,10 +505,7 @@ mod tests { #[test] fn test_exhaustive_default_strategy() { let circ = n_cx(10); - let cx_gates = circ - .nodes() - .filter(|&n| OpTag::Leaf.is_superset(circ.get_optype(n).tag())) - .collect_vec(); + let cx_gates = circ.commands().map(|cmd| cmd.node()).collect_vec(); let rws = [ rw_to_empty(&circ, cx_gates[0..2].to_vec()), @@ -429,20 +514,17 @@ mod tests { rw_to_empty(&circ, cx_gates[9..10].to_vec()), ]; - let strategy = ExhaustiveGammaStrategy::exhaustive_cx(); + let strategy = NonIncreasingGateCountCost::default_cx(); let rewritten = strategy.apply_rewrites(rws, &circ); - let exp_circ_lens = HashSet::from_iter([8, 6, 9]); + let exp_circ_lens = HashSet::from_iter([3, 7, 9]); let circ_lens: HashSet<_> = rewritten.circs.iter().map(|c| c.num_gates()).collect(); assert_eq!(circ_lens, exp_circ_lens); } #[test] - fn test_exhaustive_generous_strategy() { + fn test_exhaustive_gamma_strategy() { let circ = n_cx(10); - let cx_gates = circ - .nodes() - .filter(|&n| OpTag::Leaf.is_superset(circ.get_optype(n).tag())) - .collect_vec(); + let cx_gates = circ.commands().map(|cmd| cmd.node()).collect_vec(); let rws = [ rw_to_empty(&circ, cx_gates[0..2].to_vec()), @@ -451,7 +533,7 @@ mod tests { rw_to_empty(&circ, cx_gates[9..10].to_vec()), ]; - let strategy = ExhaustiveGammaStrategy::exhaustive_cx_with_gamma(10.); + let strategy = GammaStrategyCost::exhaustive_cx_with_gamma(10.); let rewritten = strategy.apply_rewrites(rws, &circ); let exp_circ_lens = HashSet::from_iter([8, 17, 6, 9]); let circ_lens: HashSet<_> = rewritten.circs.iter().map(|c| c.num_gates()).collect(); @@ -460,7 +542,7 @@ mod tests { #[test] fn test_exhaustive_default_cx_cost() { - let strat = NonIncreasingGateCountStrategy::default_cx(); + let strat = NonIncreasingGateCountCost::default_cx(); let circ = n_cx(3); assert_eq!(strat.circuit_cost(&circ), (3, 3).into()); let circ = build_simple_circuit(2, |circ| { @@ -475,7 +557,7 @@ mod tests { #[test] fn test_exhaustive_default_cx_threshold() { - let strat = NonIncreasingGateCountStrategy::default_cx(); + let strat = NonIncreasingGateCountCost::default_cx().strat_cost; assert!(strat.under_threshold(&(3, 0).into(), &(3, 0).into())); assert!(strat.under_threshold(&(3, 0).into(), &(3, 5).into())); assert!(!strat.under_threshold(&(3, 10).into(), &(4, 0).into())); From e688bc8b5cdb3f5d14a8cab8ed479ab64460fc4f Mon Sep 17 00:00:00 2001 From: Agustin Borgna Date: Wed, 11 Oct 2023 12:54:38 +0100 Subject: [PATCH 3/4] fix docs --- src/rewrite/strategy.rs | 24 +++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/src/rewrite/strategy.rs b/src/rewrite/strategy.rs index 90e7a176..90439ca2 100644 --- a/src/rewrite/strategy.rs +++ b/src/rewrite/strategy.rs @@ -2,15 +2,21 @@ //! //! This module contains the [`RewriteStrategy`] trait, which is currently //! implemented by -//! - [`GreedyRewriteStrategy`], which applies as many rewrites as possible -//! on one circuit, and -//! - exhaustive strategies, which clone the original circuit and explore every +//! - [`GreedyRewriteStrategy`], which applies as many rewrites as possible on +//! one circuit, and +//! - Exhaustive strategies, which clone the original circuit and explore every //! possible rewrite (with some pruning strategy): -//! - [`NonIncreasingGateCountStrategy`], which only considers rewrites that -//! do not increase some cost function (e.g. cx gate count, implemented as -//! [`NonIncreasingGateCountStrategy::default_cx`]), and -//! - [`ExhaustiveGammaStrategy`], which ignores rewrites that increase the -//! cost function beyond a threshold given by a f64 parameter gamma. +//! - [`ExhaustiveGreedyStrategy`], which applies multiple combinations of +//! non-overlapping rewrites. +//! - [`ExhaustiveThresholdStrategy`], which tries every rewrite below +//! threshold function. +//! +//! The exhaustive strategies are parametrised by a strategy cost function: +//! - [`NonIncreasingGateCountCost`], which only considers rewrites that do +//! not increase some cost function (e.g. cx gate count, implemented as +//! [`NonIncreasingGateCountCost::default_cx`]), and +//! - [`GammaStrategyCost`], which ignores rewrites that increase the cost +//! function beyond a percentage given by a f64 parameter gamma. use std::{collections::HashSet, fmt::Debug}; @@ -319,7 +325,7 @@ pub trait StrategyCost { /// /// An example would be to use the number of CX gates as major cost and the /// total number of gates as minor cost. Compared to a -/// [`ExhaustiveGammaStrategy`], that would only order circuits based on the +/// [`GammaStrategyCost`], that would only order circuits based on the /// number of CX gates, this creates a less flat optimisation landscape. #[derive(Debug, Clone)] pub struct NonIncreasingGateCountCost { From 862408eadcaedc31126ffac3d7e8172803565dbe Mon Sep 17 00:00:00 2001 From: Agustin Borgna Date: Wed, 11 Oct 2023 14:50:21 +0100 Subject: [PATCH 4/4] Update hugr dep --- Cargo.toml | 2 +- src/rewrite.rs | 11 ++++++----- src/rewrite/strategy.rs | 6 ++++-- 3 files changed, 11 insertions(+), 8 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 9244e2bd..fc6dfa39 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -73,7 +73,7 @@ members = ["pyrs", "compile-rewriter", "taso-optimiser"] [workspace.dependencies] -quantinuum-hugr = { git = "https://github.com/CQCL-DEV/hugr", rev = "a34662f0" } +quantinuum-hugr = { git = "https://github.com/CQCL-DEV/hugr", rev = "9254ac7" } portgraph = { version = "0.9", features = ["serde"] } pyo3 = { version = "0.19" } itertools = { version = "0.11.0" } diff --git a/src/rewrite.rs b/src/rewrite.rs index 83f7e422..6d9be652 100644 --- a/src/rewrite.rs +++ b/src/rewrite.rs @@ -101,13 +101,14 @@ impl CircuitRewrite { self.0.replacement() } - /// Returns the nodes affected by a rewrite. + /// Returns a set of nodes referenced by the rewrite. Modifying any these + /// nodes will invalidate it. /// - /// This includes the nodes in the subcircuit and it's neighbours (contained in - /// the rewrite's boundary). + /// Two `CircuitRewrite`s can be composed if their invalidation sets are + /// disjoint. #[inline] - pub fn affected_nodes(&self) -> impl Iterator + '_ { - self.0.affected_nodes() + pub fn invalidation_set(&self) -> impl Iterator + '_ { + self.0.invalidation_set() } delegate! { diff --git a/src/rewrite/strategy.rs b/src/rewrite/strategy.rs index 90439ca2..ad5df340 100644 --- a/src/rewrite/strategy.rs +++ b/src/rewrite/strategy.rs @@ -221,11 +221,13 @@ impl RewriteStrategy for ExhaustiveGreedyStrategy { let mut cost_delta = Default::default(); for (rewrite, delta) in &rewrites[i..] { if !changed_nodes.is_empty() - && rewrite.affected_nodes().any(|n| changed_nodes.contains(&n)) + && rewrite + .invalidation_set() + .any(|n| changed_nodes.contains(&n)) { continue; } - changed_nodes.extend(rewrite.affected_nodes()); + changed_nodes.extend(rewrite.invalidation_set()); cost_delta += delta.clone(); rewrite