diff --git a/tket2/src/optimiser/badger.rs b/tket2/src/optimiser/badger.rs index e7da32d9..5487095a 100644 --- a/tket2/src/optimiser/badger.rs +++ b/tket2/src/optimiser/badger.rs @@ -190,13 +190,20 @@ where circ_cnt += 1; let rewrites = self.rewriter.get_rewrites(&circ); - for (new_circ, cost_delta) in self.strategy.apply_rewrites(rewrites, &circ) { - let new_circ_cost = cost.add_delta(&cost_delta); + + // Get combinations of rewrites that can be applied to the circuit, + // and filter them to keep only the ones that + // + // - Don't have a worse cost than the last candidate in the priority queue. + // - Do not invalidate the circuit by creating a loop. + // - We haven't seen yet. + for r in self.strategy.apply_rewrites(rewrites, &circ) { + let new_circ_cost = cost.add_delta(&r.cost_delta); if !pq.check_accepted(&new_circ_cost) { continue; } - let Ok(new_circ_hash) = new_circ.circuit_hash() else { + let Ok(new_circ_hash) = r.circ.circuit_hash() else { // The composed rewrites produced a loop. // // See [https://github.com/CQCL/tket2/discussions/242] @@ -207,7 +214,8 @@ where // Ignore this circuit: we've already seen it continue; } - pq.push_unchecked(new_circ, new_circ_hash, new_circ_cost); + + pq.push_unchecked(r.circ, new_circ_hash, new_circ_cost); logger.log_progress(circ_cnt, Some(pq.len()), seen_hashes.len()); } diff --git a/tket2/src/optimiser/badger/worker.rs b/tket2/src/optimiser/badger/worker.rs index b0924066..6f4b6608 100644 --- a/tket2/src/optimiser/badger/worker.rs +++ b/tket2/src/optimiser/badger/worker.rs @@ -63,17 +63,17 @@ where }; let rewrites = self.rewriter.get_rewrites(&circ); - let rewrite_result = self.strategy.apply_rewrites(rewrites, &circ); let max_cost = self.priority_channel.max_cost(); - let new_circs = rewrite_result - .into_iter() - .filter_map(|(c, cost_delta)| { - let new_cost = cost.add_delta(&cost_delta); + let new_circs = self + .strategy + .apply_rewrites(rewrites, &circ) + .filter_map(|r| { + let new_cost = cost.add_delta(&r.cost_delta); if max_cost.is_some() && &new_cost >= max_cost.as_ref().unwrap() { return None; } - let Ok(hash) = c.circuit_hash() else { + let Ok(hash) = r.circ.circuit_hash() else { // The composed rewrites were not valid. // // See [https://github.com/CQCL/tket2/discussions/242] @@ -83,7 +83,7 @@ where Some(Work { cost: new_cost, hash, - circ: c, + circ: r.circ, }) }) .collect(); diff --git a/tket2/src/rewrite/strategy.rs b/tket2/src/rewrite/strategy.rs index 1d639f15..e6576942 100644 --- a/tket2/src/rewrite/strategy.rs +++ b/tket2/src/rewrite/strategy.rs @@ -20,6 +20,7 @@ //! - [`GammaStrategyCost`] ignores rewrites that increase the cost //! function beyond a percentage given by a f64 parameter gamma. +use std::iter; use std::{collections::HashSet, fmt::Debug}; use derive_more::From; @@ -51,7 +52,7 @@ pub trait RewriteStrategy { &self, rewrites: impl IntoIterator, circ: &Hugr, - ) -> RewriteResult; + ) -> impl Iterator>; /// The cost of a single operation for this strategy's cost function. fn op_cost(&self, op: &OpType) -> Self::Cost; @@ -76,47 +77,19 @@ pub trait RewriteStrategy { } } -/// The result of a rewrite strategy. -/// -/// Returned by [`RewriteStrategy::apply_rewrites`]. -pub struct RewriteResult { - /// The rewritten circuits. - pub circs: Vec, - /// The cost delta of each rewritten circuit. - pub cost_deltas: Vec, -} - -impl RewriteResult { - /// Init a new rewrite result. - pub fn with_capacity(capacity: usize) -> Self { - Self { - circs: Vec::with_capacity(capacity), - cost_deltas: Vec::with_capacity(capacity), - } - } - - /// Returns the number of rewritten circuits. - pub fn len(&self) -> usize { - self.circs.len() - } - - /// Returns true if there are no rewritten circuits. - pub fn is_empty(&self) -> bool { - self.circs.is_empty() - } - - /// Returns an iterator over the rewritten circuits and their cost deltas. - pub fn iter(&self) -> impl Iterator { - self.circs.iter().zip(self.cost_deltas.iter()) - } +/// A possible rewrite result returned by a rewrite strategy. +#[derive(Debug, Clone)] +pub struct RewriteResult { + /// The rewritten circuit. + pub circ: Hugr, + /// The cost delta of the rewrite. + pub cost_delta: C::CostDelta, } -impl IntoIterator for RewriteResult { - type Item = (Hugr, Cost::CostDelta); - type IntoIter = std::iter::Zip, std::vec::IntoIter>; - - fn into_iter(self) -> Self::IntoIter { - self.circs.into_iter().zip(self.cost_deltas) +impl From<(Hugr, C::CostDelta)> for RewriteResult { + #[inline] + fn from((circ, cost_delta): (Hugr, C::CostDelta)) -> Self { + Self { circ, cost_delta } } } @@ -141,7 +114,7 @@ impl RewriteStrategy for GreedyRewriteStrategy { &self, rewrites: impl IntoIterator, circ: &Hugr, - ) -> RewriteResult { + ) -> impl Iterator> { let rewrites = rewrites .into_iter() .sorted_by_key(|rw| rw.node_count_delta()) @@ -164,10 +137,7 @@ impl RewriteStrategy for GreedyRewriteStrategy { .apply(&mut circ) .expect("Could not perform rewrite in greedy strategy"); } - RewriteResult { - circs: vec![circ], - cost_deltas: vec![cost_delta], - } + iter::once((circ, cost_delta).into()) } fn circuit_cost(&self, circ: &Hugr) -> Self::Cost { @@ -215,7 +185,7 @@ impl RewriteStrategy for ExhaustiveGreedyStrategy { &self, rewrites: impl IntoIterator, circ: &Hugr, - ) -> RewriteResult { + ) -> impl Iterator> { // Check only the rewrites that reduce the size of the circuit. let rewrites = rewrites .into_iter() @@ -230,8 +200,7 @@ impl RewriteStrategy for ExhaustiveGreedyStrategy { .sorted_by_key(|(_, delta)| delta.clone()) .collect_vec(); - let mut rewrite_sets = RewriteResult::with_capacity(rewrites.len()); - for i in 0..rewrites.len() { + (0..rewrites.len()).map(move |i| { let mut curr_circ = circ.clone(); let mut changed_nodes = HashSet::new(); let mut cost_delta = Default::default(); @@ -256,10 +225,8 @@ impl RewriteStrategy for ExhaustiveGreedyStrategy { } curr_circ.add_rewrite_trace(RewriteTrace::new(composed_rewrite_count)); - rewrite_sets.circs.push(curr_circ); - rewrite_sets.cost_deltas.push(cost_delta); - } - rewrite_sets + (curr_circ, cost_delta).into() + }) } #[inline] @@ -296,21 +263,17 @@ impl RewriteStrategy for ExhaustiveThresholdStrategy { &self, rewrites: impl IntoIterator, circ: &Hugr, - ) -> RewriteResult { - let (circs, cost_deltas) = rewrites - .into_iter() - .filter_map(|rw| { - let pattern_cost = self.pre_rewrite_cost(&rw, circ); - let target_cost = self.post_rewrite_cost(&rw); - if !self.strat_cost.under_threshold(&pattern_cost, &target_cost) { - return None; - } - let mut circ = circ.clone(); - rw.apply(&mut circ).expect("invalid pattern match"); - Some((circ, target_cost.sub_cost(&pattern_cost))) - }) - .unzip(); - RewriteResult { circs, cost_deltas } + ) -> impl Iterator> { + rewrites.into_iter().filter_map(|rw| { + let pattern_cost = self.pre_rewrite_cost(&rw, circ); + let target_cost = self.post_rewrite_cost(&rw); + if !self.strat_cost.under_threshold(&pattern_cost, &target_cost) { + return None; + } + let mut circ = circ.clone(); + rw.apply(&mut circ).expect("invalid pattern match"); + Some((circ, target_cost.sub_cost(&pattern_cost)).into()) + }) } #[inline] @@ -520,12 +483,12 @@ mod tests { ]; let strategy = GreedyRewriteStrategy; - let rewritten = strategy.apply_rewrites(rws, &circ); + let rewritten = strategy.apply_rewrites(rws, &circ).collect_vec(); assert_eq!(rewritten.len(), 1); - assert_eq!(rewritten.circs[0].num_gates(), 5); + assert_eq!(rewritten[0].circ.num_gates(), 5); if REWRITE_TRACING_ENABLED { - assert_eq!(rewritten.circs[0].rewrite_trace().unwrap().len(), 3); + assert_eq!(rewritten[0].circ.rewrite_trace().unwrap().len(), 3); } } @@ -543,24 +506,24 @@ mod tests { ]; let strategy = LexicographicCostFunction::default_cx(); - let rewritten = strategy.apply_rewrites(rws, &circ); + let rewritten = strategy.apply_rewrites(rws, &circ).collect_vec(); let exp_circ_lens = HashSet::from_iter([3, 7, 9]); - let circ_lens: HashSet<_> = rewritten.circs.iter().map(|c| c.num_gates()).collect(); + let circ_lens: HashSet<_> = rewritten.iter().map(|r| r.circ.num_gates()).collect(); assert_eq!(circ_lens, exp_circ_lens); if REWRITE_TRACING_ENABLED { // Each strategy branch applies a single rewrite, composed of // multiple individual elements from `rws`. assert_eq!( - rewritten.circs[0].rewrite_trace().unwrap(), + rewritten[0].circ.rewrite_trace().unwrap(), vec![RewriteTrace::new(3)] ); assert_eq!( - rewritten.circs[1].rewrite_trace().unwrap(), + rewritten[1].circ.rewrite_trace().unwrap(), vec![RewriteTrace::new(2)] ); assert_eq!( - rewritten.circs[2].rewrite_trace().unwrap(), + rewritten[2].circ.rewrite_trace().unwrap(), vec![RewriteTrace::new(1)] ); } @@ -581,7 +544,7 @@ mod tests { let strategy = GammaStrategyCost::exhaustive_cx_with_gamma(10.); let rewritten = strategy.apply_rewrites(rws, &circ); let exp_circ_lens = HashSet::from_iter([8, 17, 6, 9]); - let circ_lens: HashSet<_> = rewritten.circs.iter().map(|c| c.num_gates()).collect(); + let circ_lens: HashSet<_> = rewritten.map(|r| r.circ.num_gates()).collect(); assert_eq!(circ_lens, exp_circ_lens); }