Skip to content

Commit

Permalink
feat: Move pre/post rewrite cost to the RewriteStrategy API (#276)
Browse files Browse the repository at this point in the history
Expose the rewrite cost delta functions
  • Loading branch information
aborgna-q authored Dec 18, 2023
1 parent 9332b6e commit e470fa0
Showing 1 changed file with 17 additions and 20 deletions.
37 changes: 17 additions & 20 deletions tket2/src/rewrite/strategy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,19 @@ pub trait RewriteStrategy {
fn circuit_cost(&self, circ: &Hugr) -> Self::Cost {
circ.circuit_cost(|op| self.op_cost(op))
}

/// Returns the cost of a rewrite's matched subcircuit before replacing it.
#[inline]
fn pre_rewrite_cost(&self, rw: &CircuitRewrite, circ: &Hugr) -> Self::Cost {
circ.nodes_cost(rw.subcircuit().nodes().iter().copied(), |op| {
self.op_cost(op)
})
}

/// Returns the expected cost of a rewrite's matched subcircuit after replacing it.
fn post_rewrite_cost(&self, rw: &CircuitRewrite) -> Self::Cost {
rw.replacement().circuit_cost(|op| self.op_cost(op))
}
}

/// The result of a rewrite strategy.
Expand Down Expand Up @@ -205,8 +218,8 @@ impl<T: StrategyCost> RewriteStrategy for ExhaustiveGreedyStrategy<T> {
let rewrites = rewrites
.into_iter()
.filter_map(|rw| {
let pattern_cost = pre_rewrite_cost(&rw, circ, |op| self.op_cost(op));
let target_cost = post_rewrite_cost(&rw, |op| self.op_cost(op));
let pattern_cost = self.pre_rewrite_cost(&rw, circ);
let target_cost = self.post_rewrite_cost(&rw);
if !self.strat_cost.under_threshold(&pattern_cost, &target_cost) {
return None;
}
Expand Down Expand Up @@ -285,8 +298,8 @@ impl<T: StrategyCost> RewriteStrategy for ExhaustiveThresholdStrategy<T> {
let (circs, cost_deltas) = rewrites
.into_iter()
.filter_map(|rw| {
let pattern_cost = pre_rewrite_cost(&rw, circ, |op| self.op_cost(op));
let target_cost = post_rewrite_cost(&rw, |op| self.op_cost(op));
let pattern_cost = self.pre_rewrite_cost(&rw, circ);
let target_cost = self.post_rewrite_cost(&rw);
if !self.strat_cost.under_threshold(&pattern_cost, &target_cost) {
return None;
}
Expand Down Expand Up @@ -446,22 +459,6 @@ impl GammaStrategyCost<fn(&OpType) -> usize> {
}
}

fn pre_rewrite_cost<F, C>(rw: &CircuitRewrite, circ: &Hugr, pred: F) -> C
where
C: CircuitCost,
F: Fn(&OpType) -> C,
{
circ.nodes_cost(rw.subcircuit().nodes().iter().copied(), pred)
}

fn post_rewrite_cost<F, C>(rw: &CircuitRewrite, pred: F) -> C
where
C: CircuitCost,
F: Fn(&OpType) -> C,
{
rw.replacement().circuit_cost(pred)
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down

0 comments on commit e470fa0

Please sign in to comment.