diff --git a/tket2/src/rewrite/strategy.rs b/tket2/src/rewrite/strategy.rs index 899b7909..890a8539 100644 --- a/tket2/src/rewrite/strategy.rs +++ b/tket2/src/rewrite/strategy.rs @@ -59,6 +59,19 @@ pub trait RewriteStrategy { fn circuit_cost(&self, circ: &Hugr) -> Self::Cost { circ.circuit_cost(|op| self.op_cost(op)) } + + /// Returns the cost of a rewrite's matched subcircuit before replacing it. + #[inline] + fn pre_rewrite_cost(&self, rw: &CircuitRewrite, circ: &Hugr) -> Self::Cost { + circ.nodes_cost(rw.subcircuit().nodes().iter().copied(), |op| { + self.op_cost(op) + }) + } + + /// Returns the expected cost of a rewrite's matched subcircuit after replacing it. + fn post_rewrite_cost(&self, rw: &CircuitRewrite) -> Self::Cost { + rw.replacement().circuit_cost(|op| self.op_cost(op)) + } } /// The result of a rewrite strategy. @@ -205,8 +218,8 @@ impl RewriteStrategy for ExhaustiveGreedyStrategy { let rewrites = rewrites .into_iter() .filter_map(|rw| { - let pattern_cost = pre_rewrite_cost(&rw, circ, |op| self.op_cost(op)); - let target_cost = post_rewrite_cost(&rw, |op| self.op_cost(op)); + let pattern_cost = self.pre_rewrite_cost(&rw, circ); + let target_cost = self.post_rewrite_cost(&rw); if !self.strat_cost.under_threshold(&pattern_cost, &target_cost) { return None; } @@ -285,8 +298,8 @@ impl RewriteStrategy for ExhaustiveThresholdStrategy { let (circs, cost_deltas) = rewrites .into_iter() .filter_map(|rw| { - let pattern_cost = pre_rewrite_cost(&rw, circ, |op| self.op_cost(op)); - let target_cost = post_rewrite_cost(&rw, |op| self.op_cost(op)); + let pattern_cost = self.pre_rewrite_cost(&rw, circ); + let target_cost = self.post_rewrite_cost(&rw); if !self.strat_cost.under_threshold(&pattern_cost, &target_cost) { return None; } @@ -446,22 +459,6 @@ impl GammaStrategyCost usize> { } } -fn pre_rewrite_cost(rw: &CircuitRewrite, circ: &Hugr, pred: F) -> C -where - C: CircuitCost, - F: Fn(&OpType) -> C, -{ - circ.nodes_cost(rw.subcircuit().nodes().iter().copied(), pred) -} - -fn post_rewrite_cost(rw: &CircuitRewrite, pred: F) -> C -where - C: CircuitCost, - F: Fn(&OpType) -> C, -{ - rw.replacement().circuit_cost(pred) -} - #[cfg(test)] mod tests { use super::*;