Skip to content

Commit

Permalink
feat!: Use CX count as default cost function (#134)
Browse files Browse the repository at this point in the history
  • Loading branch information
lmondada authored Sep 26, 2023
1 parent 33cbc71 commit 969d463
Show file tree
Hide file tree
Showing 5 changed files with 135 additions and 28 deletions.
34 changes: 30 additions & 4 deletions src/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,14 @@ pub enum T2Op {
TK1,
}

/// Whether an op is a given T2Op.
pub(crate) fn op_matches(op: &OpType, t2op: T2Op) -> bool {
let Ok(op) = T2Op::try_from(op) else {
return false;
};
op == t2op
}

#[derive(Clone, Copy, Debug, Serialize, Deserialize, EnumIter, Display, PartialEq, PartialOrd)]
#[cfg_attr(feature = "pyo3", pyclass)]
#[allow(missing_docs)]
Expand Down Expand Up @@ -294,17 +302,27 @@ impl TryFrom<OpType> for T2Op {
type Error = NotT2Op;

fn try_from(op: OpType) -> Result<Self, Self::Error> {
let leaf: LeafOp = op.try_into().map_err(|_| NotT2Op)?;
Self::try_from(&op)
}
}

impl TryFrom<&OpType> for T2Op {
type Error = NotT2Op;

fn try_from(op: &OpType) -> Result<Self, Self::Error> {
let OpType::LeafOp(leaf) = op else {
return Err(NotT2Op);
};
leaf.try_into()
}
}

impl TryFrom<LeafOp> for T2Op {
impl TryFrom<&LeafOp> for T2Op {
type Error = NotT2Op;

fn try_from(op: LeafOp) -> Result<Self, Self::Error> {
fn try_from(op: &LeafOp) -> Result<Self, Self::Error> {
match op {
LeafOp::CustomOp(b) => match *b {
LeafOp::CustomOp(b) => match b.as_ref() {
ExternalOp::Extension(e) => Self::try_from_op_def(e.def()),
ExternalOp::Opaque(o) => from_extension_name(o.extension(), o.name()),
},
Expand All @@ -313,6 +331,14 @@ impl TryFrom<LeafOp> for T2Op {
}
}

impl TryFrom<LeafOp> for T2Op {
type Error = NotT2Op;

fn try_from(op: LeafOp) -> Result<Self, Self::Error> {
Self::try_from(&op)
}
}

/// load all variants of a `SimpleOpEnum` in to an extension as op defs.
fn load_all_ops<T: SimpleOpEnum>(extension: &mut Extension) -> Result<(), ExtensionBuildError> {
for op in T::all_variants() {
Expand Down
3 changes: 3 additions & 0 deletions src/optimiser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,7 @@
//! Currently, the only optimiser is TASO
pub mod taso;

#[cfg(feature = "portmatching")]
pub use taso::DefaultTasoOptimiser;
pub use taso::TasoOptimiser;
27 changes: 23 additions & 4 deletions src/optimiser/taso.rs
Original file line number Diff line number Diff line change
Expand Up @@ -293,20 +293,39 @@ where

#[cfg(feature = "portmatching")]
mod taso_default {
use crate::circuit::Circuit;
use hugr::ops::OpType;
use hugr::HugrView;

use crate::ops::op_matches;
use crate::rewrite::strategy::ExhaustiveRewriteStrategy;
use crate::rewrite::ECCRewriter;
use crate::T2Op;

use super::*;

impl TasoOptimiser<ECCRewriter, ExhaustiveRewriteStrategy, fn(&Hugr) -> usize> {
/// The default TASO optimiser using ECC sets.
pub type DefaultTasoOptimiser = TasoOptimiser<
ECCRewriter,
ExhaustiveRewriteStrategy<fn(&OpType) -> bool>,
fn(&Hugr) -> usize,
>;

impl DefaultTasoOptimiser {
/// A sane default optimiser using the given ECC sets.
pub fn default_with_eccs_json_file(
eccs_path: impl AsRef<std::path::Path>,
) -> io::Result<Self> {
let rewriter = ECCRewriter::try_from_eccs_json_file(eccs_path)?;
let strategy = ExhaustiveRewriteStrategy::default();
Ok(Self::new(rewriter, strategy, |c| c.num_gates()))
let strategy = ExhaustiveRewriteStrategy::exhaustive_cx();
Ok(TasoOptimiser::new(rewriter, strategy, num_cx_gates))
}
}

fn num_cx_gates(circ: &Hugr) -> usize {
circ.nodes()
.filter(|&n| op_matches(circ.get_optype(n), T2Op::CX))
.count()
}
}
#[cfg(feature = "portmatching")]
pub use taso_default::DefaultTasoOptimiser;
93 changes: 77 additions & 16 deletions src/rewrite/strategy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@
use std::collections::HashSet;

use hugr::Hugr;
use hugr::{ops::OpType, Hugr, HugrView, Node};
use itertools::Itertools;

use crate::circuit::Circuit;
use crate::{ops::op_matches, T2Op};

use super::CircuitRewrite;

Expand Down Expand Up @@ -79,27 +79,62 @@ impl RewriteStrategy for GreedyRewriteStrategy {
/// circuit.
///
/// The parameter gamma controls how greedy the algorithm should be. It allows
/// a rewrite C1 -> C2 if C2 has at most gamma times as many gates as C1:
/// a rewrite C1 -> C2 if C2 has at most gamma times the cost of C1:
///
/// $|C2| < gamma * |C1|$
/// $cost(C2) < gamma * cost(C1)$
///
/// The cost function is given by the number of operations in the circuit that
/// satisfy a given Op predicate. This allows for instance to use the total
/// number of gates (true predicate) or the number of CX gates as cost function.
///
/// gamma = 1 is the greedy strategy where a rewrite is only allowed if it
/// strictly reduces the gate count. The default is gamma = 1.0001, as set
/// in the Quartz paper. This essentially allows rewrites that improve or leave
/// the number of nodes unchanged.
/// strictly reduces the gate count. The default is gamma = 1.0001 (as set in
/// the Quartz paper) and the number of CX gates. This essentially allows
/// rewrites that improve or leave the number of CX unchanged.
#[derive(Debug, Clone)]
pub struct ExhaustiveRewriteStrategy {
pub struct ExhaustiveRewriteStrategy<P> {
/// The gamma parameter.
pub gamma: f64,
/// Ops to count for cost function.
pub op_predicate: P,
}

impl<P> ExhaustiveRewriteStrategy<P> {
/// New exhaustive rewrite strategy with provided predicate.
///
/// The gamma parameter is set to the default 1.0001.
pub fn with_predicate(op_predicate: P) -> Self {
Self {
gamma: 1.0001,
op_predicate,
}
}

/// New exhaustive rewrite strategy with provided gamma and predicate.
pub fn new(gamma: f64, op_predicate: P) -> Self {
Self {
gamma,
op_predicate,
}
}
}

impl Default for ExhaustiveRewriteStrategy {
fn default() -> Self {
Self { gamma: 1.0001 }
impl ExhaustiveRewriteStrategy<fn(&OpType) -> bool> {
/// Exhaustive rewrite strategy with CX count cost function.
///
/// The gamma parameter is set to the default 1.0001. This is a good default
/// choice for NISQ-y circuits, where CX gates are the most expensive.
pub fn exhaustive_cx() -> Self {
ExhaustiveRewriteStrategy::with_predicate(is_cx)
}

/// Exhaustive rewrite strategy with CX count cost function and provided gamma.
pub fn exhaustive_cx_with_gamma(gamma: f64) -> Self {
ExhaustiveRewriteStrategy::new(gamma, is_cx)
}
}

impl RewriteStrategy for ExhaustiveRewriteStrategy {
impl<P: Fn(&OpType) -> bool> RewriteStrategy for ExhaustiveRewriteStrategy<P> {
#[tracing::instrument(skip_all)]
fn apply_rewrites(
&self,
Expand All @@ -109,8 +144,8 @@ impl RewriteStrategy for ExhaustiveRewriteStrategy {
rewrites
.into_iter()
.filter(|rw| {
let old_count = rw.subcircuit().node_count() as f64;
let new_count = rw.replacement().num_gates() as f64;
let old_count = pre_rewrite_cost(rw, circ, &self.op_predicate) as f64;
let new_count = post_rewrite_cost(rw, circ, &self.op_predicate) as f64;
new_count < old_count * self.gamma
})
.map(|rw| {
Expand All @@ -122,6 +157,32 @@ impl RewriteStrategy for ExhaustiveRewriteStrategy {
}
}

fn is_cx(op: &OpType) -> bool {
op_matches(op, T2Op::CX)
}

fn cost(
nodes: impl IntoIterator<Item = Node>,
circ: &Hugr,
pred: impl Fn(&OpType) -> bool,
) -> usize {
nodes
.into_iter()
.filter(|n| {
let op = circ.get_optype(*n);
pred(op)
})
.count()
}

fn pre_rewrite_cost(rw: &CircuitRewrite, circ: &Hugr, pred: impl Fn(&OpType) -> bool) -> usize {
cost(rw.subcircuit().nodes().iter().copied(), circ, pred)
}

fn post_rewrite_cost(rw: &CircuitRewrite, circ: &Hugr, pred: impl Fn(&OpType) -> bool) -> usize {
cost(rw.replacement().nodes(), circ, pred)
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down Expand Up @@ -197,7 +258,7 @@ mod tests {
rw_to_empty(&circ, cx_gates[9..10].to_vec()),
];

let strategy = ExhaustiveRewriteStrategy::default();
let strategy = ExhaustiveRewriteStrategy::exhaustive_cx();
let rewritten = strategy.apply_rewrites(rws, &circ);
let exp_circ_lens = HashSet::from_iter([8, 6, 9]);
let circ_lens: HashSet<_> = rewritten.iter().map(|c| c.num_gates()).collect();
Expand All @@ -219,7 +280,7 @@ mod tests {
rw_to_empty(&circ, cx_gates[9..10].to_vec()),
];

let strategy = ExhaustiveRewriteStrategy { gamma: 10. };
let strategy = ExhaustiveRewriteStrategy::exhaustive_cx_with_gamma(10.);
let rewritten = strategy.apply_rewrites(rws, &circ);
let exp_circ_lens = HashSet::from_iter([8, 17, 6, 9]);
let circ_lens: HashSet<_> = rewritten.iter().map(|c| c.num_gates()).collect();
Expand Down
6 changes: 2 additions & 4 deletions taso-optimiser/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,9 @@ use std::{fs, path::Path};

use clap::Parser;
use hugr::Hugr;
use tket2::json::{load_tk1_json_file, TKETDecode};
use tket2::optimiser::taso::log::TasoLogger;
use tket2::{
json::{load_tk1_json_file, TKETDecode},
optimiser::TasoOptimiser,
};
use tket2::optimiser::TasoOptimiser;
use tket_json_rs::circuit_json::SerialCircuit;

#[cfg(feature = "peak_alloc")]
Expand Down

0 comments on commit 969d463

Please sign in to comment.