Skip to content

Commit

Permalink
feat: Return rewrite strategies as a generator (#275)
Browse files Browse the repository at this point in the history
...instead of collecting a vector and returning it.

This lets us generate rewrites on-demand, so we avoid unnecessary work.

Closes #269
  • Loading branch information
aborgna-q authored Jan 2, 2024
1 parent 3dae175 commit aead286
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 87 deletions.
16 changes: 12 additions & 4 deletions tket2/src/optimiser/badger.rs
Original file line number Diff line number Diff line change
Expand Up @@ -190,13 +190,20 @@ where
circ_cnt += 1;

let rewrites = self.rewriter.get_rewrites(&circ);
for (new_circ, cost_delta) in self.strategy.apply_rewrites(rewrites, &circ) {
let new_circ_cost = cost.add_delta(&cost_delta);

// Get combinations of rewrites that can be applied to the circuit,
// and filter them to keep only the ones that
//
// - Don't have a worse cost than the last candidate in the priority queue.
// - Do not invalidate the circuit by creating a loop.
// - We haven't seen yet.
for r in self.strategy.apply_rewrites(rewrites, &circ) {
let new_circ_cost = cost.add_delta(&r.cost_delta);
if !pq.check_accepted(&new_circ_cost) {
continue;
}

let Ok(new_circ_hash) = new_circ.circuit_hash() else {
let Ok(new_circ_hash) = r.circ.circuit_hash() else {
// The composed rewrites produced a loop.
//
// See [https://github.com/CQCL/tket2/discussions/242]
Expand All @@ -207,7 +214,8 @@ where
// Ignore this circuit: we've already seen it
continue;
}
pq.push_unchecked(new_circ, new_circ_hash, new_circ_cost);

pq.push_unchecked(r.circ, new_circ_hash, new_circ_cost);
logger.log_progress(circ_cnt, Some(pq.len()), seen_hashes.len());
}

Expand Down
14 changes: 7 additions & 7 deletions tket2/src/optimiser/badger/worker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,17 +63,17 @@ where
};

let rewrites = self.rewriter.get_rewrites(&circ);
let rewrite_result = self.strategy.apply_rewrites(rewrites, &circ);
let max_cost = self.priority_channel.max_cost();
let new_circs = rewrite_result
.into_iter()
.filter_map(|(c, cost_delta)| {
let new_cost = cost.add_delta(&cost_delta);
let new_circs = self
.strategy
.apply_rewrites(rewrites, &circ)
.filter_map(|r| {
let new_cost = cost.add_delta(&r.cost_delta);
if max_cost.is_some() && &new_cost >= max_cost.as_ref().unwrap() {
return None;
}

let Ok(hash) = c.circuit_hash() else {
let Ok(hash) = r.circ.circuit_hash() else {
// The composed rewrites were not valid.
//
// See [https://github.com/CQCL/tket2/discussions/242]
Expand All @@ -83,7 +83,7 @@ where
Some(Work {
cost: new_cost,
hash,
circ: c,
circ: r.circ,
})
})
.collect();
Expand Down
115 changes: 39 additions & 76 deletions tket2/src/rewrite/strategy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
//! - [`GammaStrategyCost`] ignores rewrites that increase the cost
//! function beyond a percentage given by a f64 parameter gamma.
use std::iter;
use std::{collections::HashSet, fmt::Debug};

use derive_more::From;
Expand Down Expand Up @@ -51,7 +52,7 @@ pub trait RewriteStrategy {
&self,
rewrites: impl IntoIterator<Item = CircuitRewrite>,
circ: &Hugr,
) -> RewriteResult<Self::Cost>;
) -> impl Iterator<Item = RewriteResult<Self::Cost>>;

/// The cost of a single operation for this strategy's cost function.
fn op_cost(&self, op: &OpType) -> Self::Cost;
Expand All @@ -76,47 +77,19 @@ pub trait RewriteStrategy {
}
}

/// The result of a rewrite strategy.
///
/// Returned by [`RewriteStrategy::apply_rewrites`].
pub struct RewriteResult<Cost: CircuitCost> {
/// The rewritten circuits.
pub circs: Vec<Hugr>,
/// The cost delta of each rewritten circuit.
pub cost_deltas: Vec<Cost::CostDelta>,
}

impl<Cost: CircuitCost> RewriteResult<Cost> {
/// Init a new rewrite result.
pub fn with_capacity(capacity: usize) -> Self {
Self {
circs: Vec::with_capacity(capacity),
cost_deltas: Vec::with_capacity(capacity),
}
}

/// Returns the number of rewritten circuits.
pub fn len(&self) -> usize {
self.circs.len()
}

/// Returns true if there are no rewritten circuits.
pub fn is_empty(&self) -> bool {
self.circs.is_empty()
}

/// Returns an iterator over the rewritten circuits and their cost deltas.
pub fn iter(&self) -> impl Iterator<Item = (&Hugr, &Cost::CostDelta)> {
self.circs.iter().zip(self.cost_deltas.iter())
}
/// A possible rewrite result returned by a rewrite strategy.
#[derive(Debug, Clone)]
pub struct RewriteResult<C: CircuitCost> {
/// The rewritten circuit.
pub circ: Hugr,
/// The cost delta of the rewrite.
pub cost_delta: C::CostDelta,
}

impl<Cost: CircuitCost> IntoIterator for RewriteResult<Cost> {
type Item = (Hugr, Cost::CostDelta);
type IntoIter = std::iter::Zip<std::vec::IntoIter<Hugr>, std::vec::IntoIter<Cost::CostDelta>>;

fn into_iter(self) -> Self::IntoIter {
self.circs.into_iter().zip(self.cost_deltas)
impl<C: CircuitCost> From<(Hugr, C::CostDelta)> for RewriteResult<C> {
#[inline]
fn from((circ, cost_delta): (Hugr, C::CostDelta)) -> Self {
Self { circ, cost_delta }
}
}

Expand All @@ -141,7 +114,7 @@ impl RewriteStrategy for GreedyRewriteStrategy {
&self,
rewrites: impl IntoIterator<Item = CircuitRewrite>,
circ: &Hugr,
) -> RewriteResult<usize> {
) -> impl Iterator<Item = RewriteResult<Self::Cost>> {
let rewrites = rewrites
.into_iter()
.sorted_by_key(|rw| rw.node_count_delta())
Expand All @@ -164,10 +137,7 @@ impl RewriteStrategy for GreedyRewriteStrategy {
.apply(&mut circ)
.expect("Could not perform rewrite in greedy strategy");
}
RewriteResult {
circs: vec![circ],
cost_deltas: vec![cost_delta],
}
iter::once((circ, cost_delta).into())
}

fn circuit_cost(&self, circ: &Hugr) -> Self::Cost {
Expand Down Expand Up @@ -215,7 +185,7 @@ impl<T: StrategyCost> RewriteStrategy for ExhaustiveGreedyStrategy<T> {
&self,
rewrites: impl IntoIterator<Item = CircuitRewrite>,
circ: &Hugr,
) -> RewriteResult<T::OpCost> {
) -> impl Iterator<Item = RewriteResult<Self::Cost>> {
// Check only the rewrites that reduce the size of the circuit.
let rewrites = rewrites
.into_iter()
Expand All @@ -230,8 +200,7 @@ impl<T: StrategyCost> RewriteStrategy for ExhaustiveGreedyStrategy<T> {
.sorted_by_key(|(_, delta)| delta.clone())
.collect_vec();

let mut rewrite_sets = RewriteResult::with_capacity(rewrites.len());
for i in 0..rewrites.len() {
(0..rewrites.len()).map(move |i| {
let mut curr_circ = circ.clone();
let mut changed_nodes = HashSet::new();
let mut cost_delta = Default::default();
Expand All @@ -256,10 +225,8 @@ impl<T: StrategyCost> RewriteStrategy for ExhaustiveGreedyStrategy<T> {
}

curr_circ.add_rewrite_trace(RewriteTrace::new(composed_rewrite_count));
rewrite_sets.circs.push(curr_circ);
rewrite_sets.cost_deltas.push(cost_delta);
}
rewrite_sets
(curr_circ, cost_delta).into()
})
}

#[inline]
Expand Down Expand Up @@ -296,21 +263,17 @@ impl<T: StrategyCost> RewriteStrategy for ExhaustiveThresholdStrategy<T> {
&self,
rewrites: impl IntoIterator<Item = CircuitRewrite>,
circ: &Hugr,
) -> RewriteResult<T::OpCost> {
let (circs, cost_deltas) = rewrites
.into_iter()
.filter_map(|rw| {
let pattern_cost = self.pre_rewrite_cost(&rw, circ);
let target_cost = self.post_rewrite_cost(&rw);
if !self.strat_cost.under_threshold(&pattern_cost, &target_cost) {
return None;
}
let mut circ = circ.clone();
rw.apply(&mut circ).expect("invalid pattern match");
Some((circ, target_cost.sub_cost(&pattern_cost)))
})
.unzip();
RewriteResult { circs, cost_deltas }
) -> impl Iterator<Item = RewriteResult<Self::Cost>> {
rewrites.into_iter().filter_map(|rw| {
let pattern_cost = self.pre_rewrite_cost(&rw, circ);
let target_cost = self.post_rewrite_cost(&rw);
if !self.strat_cost.under_threshold(&pattern_cost, &target_cost) {
return None;
}
let mut circ = circ.clone();
rw.apply(&mut circ).expect("invalid pattern match");
Some((circ, target_cost.sub_cost(&pattern_cost)).into())
})
}

#[inline]
Expand Down Expand Up @@ -520,12 +483,12 @@ mod tests {
];

let strategy = GreedyRewriteStrategy;
let rewritten = strategy.apply_rewrites(rws, &circ);
let rewritten = strategy.apply_rewrites(rws, &circ).collect_vec();
assert_eq!(rewritten.len(), 1);
assert_eq!(rewritten.circs[0].num_gates(), 5);
assert_eq!(rewritten[0].circ.num_gates(), 5);

if REWRITE_TRACING_ENABLED {
assert_eq!(rewritten.circs[0].rewrite_trace().unwrap().len(), 3);
assert_eq!(rewritten[0].circ.rewrite_trace().unwrap().len(), 3);
}
}

Expand All @@ -543,24 +506,24 @@ mod tests {
];

let strategy = LexicographicCostFunction::default_cx();
let rewritten = strategy.apply_rewrites(rws, &circ);
let rewritten = strategy.apply_rewrites(rws, &circ).collect_vec();
let exp_circ_lens = HashSet::from_iter([3, 7, 9]);
let circ_lens: HashSet<_> = rewritten.circs.iter().map(|c| c.num_gates()).collect();
let circ_lens: HashSet<_> = rewritten.iter().map(|r| r.circ.num_gates()).collect();
assert_eq!(circ_lens, exp_circ_lens);

if REWRITE_TRACING_ENABLED {
// Each strategy branch applies a single rewrite, composed of
// multiple individual elements from `rws`.
assert_eq!(
rewritten.circs[0].rewrite_trace().unwrap(),
rewritten[0].circ.rewrite_trace().unwrap(),
vec![RewriteTrace::new(3)]
);
assert_eq!(
rewritten.circs[1].rewrite_trace().unwrap(),
rewritten[1].circ.rewrite_trace().unwrap(),
vec![RewriteTrace::new(2)]
);
assert_eq!(
rewritten.circs[2].rewrite_trace().unwrap(),
rewritten[2].circ.rewrite_trace().unwrap(),
vec![RewriteTrace::new(1)]
);
}
Expand All @@ -581,7 +544,7 @@ mod tests {
let strategy = GammaStrategyCost::exhaustive_cx_with_gamma(10.);
let rewritten = strategy.apply_rewrites(rws, &circ);
let exp_circ_lens = HashSet::from_iter([8, 17, 6, 9]);
let circ_lens: HashSet<_> = rewritten.circs.iter().map(|c| c.num_gates()).collect();
let circ_lens: HashSet<_> = rewritten.map(|r| r.circ.num_gates()).collect();
assert_eq!(circ_lens, exp_circ_lens);
}

Expand Down

0 comments on commit aead286

Please sign in to comment.