Skip to content

Commit

Permalink
Address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
lmondada committed Oct 3, 2023
1 parent e5d3fe7 commit 6e52e83
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 74 deletions.
13 changes: 9 additions & 4 deletions src/optimiser/taso.rs
Original file line number Diff line number Diff line change
Expand Up @@ -303,20 +303,25 @@ mod taso_default {
use std::io;
use std::path::Path;

use hugr::ops::OpType;

use crate::rewrite::ecc_rewriter::RewriterSerialisationError;
use crate::rewrite::strategy::NonIncreasingCXCountStrategy;
use crate::rewrite::strategy::NonIncreasingGateCountStrategy;
use crate::rewrite::ECCRewriter;

use super::*;

/// The default TASO optimiser using ECC sets.
pub type DefaultTasoOptimiser = TasoOptimiser<ECCRewriter, NonIncreasingCXCountStrategy>;
pub type DefaultTasoOptimiser = TasoOptimiser<
ECCRewriter,
NonIncreasingGateCountStrategy<fn(&OpType) -> usize, fn(&OpType) -> usize>,
>;

impl DefaultTasoOptimiser {
/// A sane default optimiser using the given ECC sets.
pub fn default_with_eccs_json_file(eccs_path: impl AsRef<Path>) -> io::Result<Self> {
let rewriter = ECCRewriter::try_from_eccs_json_file(eccs_path)?;
let strategy = NonIncreasingCXCountStrategy::default_cx();
let strategy = NonIncreasingGateCountStrategy::default_cx();
Ok(TasoOptimiser::new(rewriter, strategy))
}

Expand All @@ -325,7 +330,7 @@ mod taso_default {
rewriter_path: impl AsRef<Path>,
) -> Result<Self, RewriterSerialisationError> {
let rewriter = ECCRewriter::load_binary(rewriter_path)?;
let strategy = NonIncreasingCXCountStrategy::default_cx();
let strategy = NonIncreasingGateCountStrategy::default_cx();
Ok(TasoOptimiser::new(rewriter, strategy))
}
}
Expand Down
138 changes: 68 additions & 70 deletions src/rewrite/strategy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,65 @@ impl RewriteStrategy for GreedyRewriteStrategy {
}
}

/// Exhaustive strategies based on cost functions and thresholds.
///
/// Every possible rewrite is applied to a copy of the input circuit. Thus for
/// one circuit, up to `n` rewritten circuits will be returned, each obtained
/// by applying one of the `n` rewrites to the original circuit.
///
/// Whether a rewrite is allowed or not is determined by a cost function and a
/// threshold function: if the cost of the target of the rewrite is below the
/// threshold given by the cost of the original circuit, the rewrite is
/// performed.
///
/// The cost function must return a value of type `Self::OpCost`. All op costs
/// are summed up to obtain a total cost that is then compared using the
/// threshold function.
pub trait ExhaustiveThresholdStrategy {
/// The cost of a single operation.
type OpCost;
/// The sum of the cost of all operations in a circuit.
type SumOpCost;

/// Whether the rewrite is allowed or not, based on the cost of the pattern and target.
fn threshold(&self, pattern_cost: &Self::SumOpCost, target_cost: &Self::SumOpCost) -> bool;

/// The cost of a single operation.
fn op_cost(&self, op: &OpType) -> Self::OpCost;
}

impl<T: ExhaustiveThresholdStrategy> RewriteStrategy for T
where
T::SumOpCost: Sum<T::OpCost> + Ord,
{
type Cost = T::SumOpCost;

#[tracing::instrument(skip_all)]
fn apply_rewrites(
&self,
rewrites: impl IntoIterator<Item = CircuitRewrite>,
circ: &Hugr,
) -> Vec<Hugr> {
rewrites
.into_iter()
.filter(|rw| {
let pattern_cost = pre_rewrite_cost(rw, circ, |op| self.op_cost(op));
let target_cost = post_rewrite_cost(rw, circ, |op| self.op_cost(op));
self.threshold(&pattern_cost, &target_cost)
})
.map(|rw| {
let mut circ = circ.clone();
rw.apply(&mut circ).expect("invalid pattern match");
circ
})
.collect()
}

fn circuit_cost(&self, circ: &Hugr) -> Self::Cost {
cost(circ.nodes(), circ, |op| self.op_cost(op))
}
}

/// Exhaustive rewrite strategy allowing smaller or equal cost rewrites.
///
/// Rewrites are permitted based on a cost function called the major cost: if
Expand Down Expand Up @@ -132,15 +191,13 @@ where
}
}

/// Non-increasing rewrite strategy based on CX count.
///
/// The minor cost to break ties between equal CX counts is the number of
/// quantum gates.
pub type NonIncreasingCXCountStrategy =
NonIncreasingGateCountStrategy<fn(&OpType) -> usize, fn(&OpType) -> usize>;

impl NonIncreasingCXCountStrategy {
/// Create rewrite strategy based on non-increasing CX count.
impl NonIncreasingGateCountStrategy<fn(&OpType) -> usize, fn(&OpType) -> usize> {
/// Non-increasing rewrite strategy based on CX count.
///
/// The minor cost to break ties between equal CX counts is the number of
/// quantum gates.
///
/// This is probably a good default for NISQ-y circuit optimisation.
pub fn default_cx() -> Self {
Self {
major_cost: |op| is_cx(op) as usize,
Expand Down Expand Up @@ -218,65 +275,6 @@ impl ExhaustiveGammaStrategy<fn(&OpType) -> usize> {
}
}

/// Exhaustive strategies based on cost functions and thresholds.
///
/// Every possible rewrite is applied to a copy of the input circuit. Thus for
/// one circuit, up to `n` rewritten circuits will be returned, each obtained
/// by applying one of the `n` rewrites to the original circuit.
///
/// Whether a rewrite is allowed or not is determined by a cost function and a
/// threshold function: if the cost of the target of the rewrite is below the
/// threshold given by the cost of the original circuit, the rewrite is
/// performed.
///
/// The cost function must return a value of type `Self::OpCost`. All op costs
/// are summed up to obtain a total cost that is then compared using the
/// threshold function.
pub trait ExhaustiveThresholdStrategy {
/// The cost of a single operation.
type OpCost;
/// The sum of the cost of all operations in a circuit.
type SumOpCost;

/// Whether the rewrite is allowed or not, based on the cost of the pattern and target.
fn threshold(&self, pattern_cost: &Self::SumOpCost, target_cost: &Self::SumOpCost) -> bool;

/// The cost of a single operation.
fn op_cost(&self, op: &OpType) -> Self::OpCost;
}

impl<T: ExhaustiveThresholdStrategy> RewriteStrategy for T
where
T::SumOpCost: Sum<T::OpCost> + Ord,
{
type Cost = T::SumOpCost;

#[tracing::instrument(skip_all)]
fn apply_rewrites(
&self,
rewrites: impl IntoIterator<Item = CircuitRewrite>,
circ: &Hugr,
) -> Vec<Hugr> {
rewrites
.into_iter()
.filter(|rw| {
let pattern_cost = pre_rewrite_cost(rw, circ, |op| self.op_cost(op));
let target_cost = post_rewrite_cost(rw, circ, |op| self.op_cost(op));
self.threshold(&pattern_cost, &target_cost)
})
.map(|rw| {
let mut circ = circ.clone();
rw.apply(&mut circ).expect("invalid pattern match");
circ
})
.collect()
}

fn circuit_cost(&self, circ: &Hugr) -> Self::Cost {
cost(circ.nodes(), circ, |op| self.op_cost(op))
}
}

/// A pair of major and minor cost.
///
/// This is used to order circuits based on major cost first, then minor cost.
Expand Down Expand Up @@ -459,7 +457,7 @@ mod tests {

#[test]
fn test_exhaustive_default_cx_cost() {
let strat = NonIncreasingCXCountStrategy::default_cx();
let strat = NonIncreasingGateCountStrategy::default_cx();
let circ = n_cx(3);
assert_eq!(strat.circuit_cost(&circ), (3, 3).into());
let circ = build_simple_circuit(2, |circ| {
Expand All @@ -474,7 +472,7 @@ mod tests {

#[test]
fn test_exhaustive_default_cx_threshold() {
let strat = NonIncreasingCXCountStrategy::default_cx();
let strat = NonIncreasingGateCountStrategy::default_cx();
assert!(strat.threshold(&(3, 0).into(), &(3, 0).into()));
assert!(strat.threshold(&(3, 0).into(), &(3, 5).into()));
assert!(!strat.threshold(&(3, 10).into(), &(4, 0).into()));
Expand Down

0 comments on commit 6e52e83

Please sign in to comment.