Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Return rewrite strategies as a generator #275

Merged
merged 1 commit into from
Jan 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions tket2/src/optimiser/badger.rs
Original file line number Diff line number Diff line change
Expand Up @@ -190,13 +190,20 @@ where
circ_cnt += 1;

let rewrites = self.rewriter.get_rewrites(&circ);
for (new_circ, cost_delta) in self.strategy.apply_rewrites(rewrites, &circ) {
let new_circ_cost = cost.add_delta(&cost_delta);

// Get combinations of rewrites that can be applied to the circuit,
// and filter them to keep only the ones that
//
// - Don't have a worse cost than the last candidate in the priority queue.
// - Do not invalidate the circuit by creating a loop.
// - We haven't seen yet.
for r in self.strategy.apply_rewrites(rewrites, &circ) {
let new_circ_cost = cost.add_delta(&r.cost_delta);
if !pq.check_accepted(&new_circ_cost) {
continue;
}

let Ok(new_circ_hash) = new_circ.circuit_hash() else {
let Ok(new_circ_hash) = r.circ.circuit_hash() else {
// The composed rewrites produced a loop.
//
// See [https://github.com/CQCL/tket2/discussions/242]
Expand All @@ -207,7 +214,8 @@ where
// Ignore this circuit: we've already seen it
continue;
}
pq.push_unchecked(new_circ, new_circ_hash, new_circ_cost);

pq.push_unchecked(r.circ, new_circ_hash, new_circ_cost);
logger.log_progress(circ_cnt, Some(pq.len()), seen_hashes.len());
}

Expand Down
14 changes: 7 additions & 7 deletions tket2/src/optimiser/badger/worker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,17 +63,17 @@ where
};

let rewrites = self.rewriter.get_rewrites(&circ);
let rewrite_result = self.strategy.apply_rewrites(rewrites, &circ);
let max_cost = self.priority_channel.max_cost();
let new_circs = rewrite_result
.into_iter()
.filter_map(|(c, cost_delta)| {
let new_cost = cost.add_delta(&cost_delta);
let new_circs = self
.strategy
.apply_rewrites(rewrites, &circ)
.filter_map(|r| {
let new_cost = cost.add_delta(&r.cost_delta);
if max_cost.is_some() && &new_cost >= max_cost.as_ref().unwrap() {
return None;
}

let Ok(hash) = c.circuit_hash() else {
let Ok(hash) = r.circ.circuit_hash() else {
// The composed rewrites were not valid.
//
// See [https://github.com/CQCL/tket2/discussions/242]
Expand All @@ -83,7 +83,7 @@ where
Some(Work {
cost: new_cost,
hash,
circ: c,
circ: r.circ,
})
})
.collect();
Expand Down
115 changes: 39 additions & 76 deletions tket2/src/rewrite/strategy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
//! - [`GammaStrategyCost`] ignores rewrites that increase the cost
//! function beyond a percentage given by a f64 parameter gamma.

use std::iter;
use std::{collections::HashSet, fmt::Debug};

use derive_more::From;
Expand Down Expand Up @@ -51,7 +52,7 @@ pub trait RewriteStrategy {
&self,
rewrites: impl IntoIterator<Item = CircuitRewrite>,
circ: &Hugr,
) -> RewriteResult<Self::Cost>;
) -> impl Iterator<Item = RewriteResult<Self::Cost>>;

/// The cost of a single operation for this strategy's cost function.
fn op_cost(&self, op: &OpType) -> Self::Cost;
Expand All @@ -76,47 +77,19 @@ pub trait RewriteStrategy {
}
}

/// The result of a rewrite strategy.
///
/// Returned by [`RewriteStrategy::apply_rewrites`].
pub struct RewriteResult<Cost: CircuitCost> {
/// The rewritten circuits.
pub circs: Vec<Hugr>,
/// The cost delta of each rewritten circuit.
pub cost_deltas: Vec<Cost::CostDelta>,
}

impl<Cost: CircuitCost> RewriteResult<Cost> {
/// Init a new rewrite result.
pub fn with_capacity(capacity: usize) -> Self {
Self {
circs: Vec::with_capacity(capacity),
cost_deltas: Vec::with_capacity(capacity),
}
}

/// Returns the number of rewritten circuits.
pub fn len(&self) -> usize {
self.circs.len()
}

/// Returns true if there are no rewritten circuits.
pub fn is_empty(&self) -> bool {
self.circs.is_empty()
}

/// Returns an iterator over the rewritten circuits and their cost deltas.
pub fn iter(&self) -> impl Iterator<Item = (&Hugr, &Cost::CostDelta)> {
self.circs.iter().zip(self.cost_deltas.iter())
}
/// A possible rewrite result returned by a rewrite strategy.
#[derive(Debug, Clone)]
pub struct RewriteResult<C: CircuitCost> {
/// The rewritten circuit.
pub circ: Hugr,
/// The cost delta of the rewrite.
pub cost_delta: C::CostDelta,
}

impl<Cost: CircuitCost> IntoIterator for RewriteResult<Cost> {
type Item = (Hugr, Cost::CostDelta);
type IntoIter = std::iter::Zip<std::vec::IntoIter<Hugr>, std::vec::IntoIter<Cost::CostDelta>>;

fn into_iter(self) -> Self::IntoIter {
self.circs.into_iter().zip(self.cost_deltas)
impl<C: CircuitCost> From<(Hugr, C::CostDelta)> for RewriteResult<C> {
#[inline]
fn from((circ, cost_delta): (Hugr, C::CostDelta)) -> Self {
Self { circ, cost_delta }
}
}

Expand All @@ -141,7 +114,7 @@ impl RewriteStrategy for GreedyRewriteStrategy {
&self,
rewrites: impl IntoIterator<Item = CircuitRewrite>,
circ: &Hugr,
) -> RewriteResult<usize> {
) -> impl Iterator<Item = RewriteResult<Self::Cost>> {
let rewrites = rewrites
.into_iter()
.sorted_by_key(|rw| rw.node_count_delta())
Expand All @@ -164,10 +137,7 @@ impl RewriteStrategy for GreedyRewriteStrategy {
.apply(&mut circ)
.expect("Could not perform rewrite in greedy strategy");
}
RewriteResult {
circs: vec![circ],
cost_deltas: vec![cost_delta],
}
iter::once((circ, cost_delta).into())
}

fn circuit_cost(&self, circ: &Hugr) -> Self::Cost {
Expand Down Expand Up @@ -215,7 +185,7 @@ impl<T: StrategyCost> RewriteStrategy for ExhaustiveGreedyStrategy<T> {
&self,
rewrites: impl IntoIterator<Item = CircuitRewrite>,
circ: &Hugr,
) -> RewriteResult<T::OpCost> {
) -> impl Iterator<Item = RewriteResult<Self::Cost>> {
// Check only the rewrites that reduce the size of the circuit.
let rewrites = rewrites
.into_iter()
Expand All @@ -230,8 +200,7 @@ impl<T: StrategyCost> RewriteStrategy for ExhaustiveGreedyStrategy<T> {
.sorted_by_key(|(_, delta)| delta.clone())
.collect_vec();

let mut rewrite_sets = RewriteResult::with_capacity(rewrites.len());
for i in 0..rewrites.len() {
(0..rewrites.len()).map(move |i| {
let mut curr_circ = circ.clone();
let mut changed_nodes = HashSet::new();
let mut cost_delta = Default::default();
Expand All @@ -256,10 +225,8 @@ impl<T: StrategyCost> RewriteStrategy for ExhaustiveGreedyStrategy<T> {
}

curr_circ.add_rewrite_trace(RewriteTrace::new(composed_rewrite_count));
rewrite_sets.circs.push(curr_circ);
rewrite_sets.cost_deltas.push(cost_delta);
}
rewrite_sets
(curr_circ, cost_delta).into()
})
}

#[inline]
Expand Down Expand Up @@ -296,21 +263,17 @@ impl<T: StrategyCost> RewriteStrategy for ExhaustiveThresholdStrategy<T> {
&self,
rewrites: impl IntoIterator<Item = CircuitRewrite>,
circ: &Hugr,
) -> RewriteResult<T::OpCost> {
let (circs, cost_deltas) = rewrites
.into_iter()
.filter_map(|rw| {
let pattern_cost = self.pre_rewrite_cost(&rw, circ);
let target_cost = self.post_rewrite_cost(&rw);
if !self.strat_cost.under_threshold(&pattern_cost, &target_cost) {
return None;
}
let mut circ = circ.clone();
rw.apply(&mut circ).expect("invalid pattern match");
Some((circ, target_cost.sub_cost(&pattern_cost)))
})
.unzip();
RewriteResult { circs, cost_deltas }
) -> impl Iterator<Item = RewriteResult<Self::Cost>> {
rewrites.into_iter().filter_map(|rw| {
let pattern_cost = self.pre_rewrite_cost(&rw, circ);
let target_cost = self.post_rewrite_cost(&rw);
if !self.strat_cost.under_threshold(&pattern_cost, &target_cost) {
return None;
}
let mut circ = circ.clone();
rw.apply(&mut circ).expect("invalid pattern match");
Some((circ, target_cost.sub_cost(&pattern_cost)).into())
})
}

#[inline]
Expand Down Expand Up @@ -520,12 +483,12 @@ mod tests {
];

let strategy = GreedyRewriteStrategy;
let rewritten = strategy.apply_rewrites(rws, &circ);
let rewritten = strategy.apply_rewrites(rws, &circ).collect_vec();
assert_eq!(rewritten.len(), 1);
assert_eq!(rewritten.circs[0].num_gates(), 5);
assert_eq!(rewritten[0].circ.num_gates(), 5);

if REWRITE_TRACING_ENABLED {
assert_eq!(rewritten.circs[0].rewrite_trace().unwrap().len(), 3);
assert_eq!(rewritten[0].circ.rewrite_trace().unwrap().len(), 3);
}
}

Expand All @@ -543,24 +506,24 @@ mod tests {
];

let strategy = LexicographicCostFunction::default_cx();
let rewritten = strategy.apply_rewrites(rws, &circ);
let rewritten = strategy.apply_rewrites(rws, &circ).collect_vec();
let exp_circ_lens = HashSet::from_iter([3, 7, 9]);
let circ_lens: HashSet<_> = rewritten.circs.iter().map(|c| c.num_gates()).collect();
let circ_lens: HashSet<_> = rewritten.iter().map(|r| r.circ.num_gates()).collect();
assert_eq!(circ_lens, exp_circ_lens);

if REWRITE_TRACING_ENABLED {
// Each strategy branch applies a single rewrite, composed of
// multiple individual elements from `rws`.
assert_eq!(
rewritten.circs[0].rewrite_trace().unwrap(),
rewritten[0].circ.rewrite_trace().unwrap(),
vec![RewriteTrace::new(3)]
);
assert_eq!(
rewritten.circs[1].rewrite_trace().unwrap(),
rewritten[1].circ.rewrite_trace().unwrap(),
vec![RewriteTrace::new(2)]
);
assert_eq!(
rewritten.circs[2].rewrite_trace().unwrap(),
rewritten[2].circ.rewrite_trace().unwrap(),
vec![RewriteTrace::new(1)]
);
}
Expand All @@ -581,7 +544,7 @@ mod tests {
let strategy = GammaStrategyCost::exhaustive_cx_with_gamma(10.);
let rewritten = strategy.apply_rewrites(rws, &circ);
let exp_circ_lens = HashSet::from_iter([8, 17, 6, 9]);
let circ_lens: HashSet<_> = rewritten.circs.iter().map(|c| c.num_gates()).collect();
let circ_lens: HashSet<_> = rewritten.map(|r| r.circ.num_gates()).collect();
assert_eq!(circ_lens, exp_circ_lens);
}

Expand Down