Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat!: Use CX count as default cost function #134

Merged
merged 4 commits into from
Sep 26, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 30 additions & 4 deletions src/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,14 @@ pub enum T2Op {
TK1,
}

/// Whether an op is a given T2Op.
pub(crate) fn op_matches(op: &OpType, t2op: T2Op) -> bool {
let Ok(op) = T2Op::try_from(op) else {
return false;
};
op == t2op
}

#[derive(Clone, Copy, Debug, Serialize, Deserialize, EnumIter, Display, PartialEq, PartialOrd)]
#[cfg_attr(feature = "pyo3", pyclass)]
#[allow(missing_docs)]
Expand Down Expand Up @@ -294,17 +302,27 @@ impl TryFrom<OpType> for T2Op {
type Error = NotT2Op;

fn try_from(op: OpType) -> Result<Self, Self::Error> {
let leaf: LeafOp = op.try_into().map_err(|_| NotT2Op)?;
Self::try_from(&op)
}
}

impl TryFrom<&OpType> for T2Op {
type Error = NotT2Op;

fn try_from(op: &OpType) -> Result<Self, Self::Error> {
let OpType::LeafOp(leaf) = op else {
return Err(NotT2Op);
};
leaf.try_into()
}
}

impl TryFrom<LeafOp> for T2Op {
impl TryFrom<&LeafOp> for T2Op {
type Error = NotT2Op;

fn try_from(op: LeafOp) -> Result<Self, Self::Error> {
fn try_from(op: &LeafOp) -> Result<Self, Self::Error> {
match op {
LeafOp::CustomOp(b) => match *b {
LeafOp::CustomOp(b) => match b.as_ref() {
ExternalOp::Extension(e) => Self::try_from_op_def(e.def()),
ExternalOp::Opaque(o) => from_extension_name(o.extension(), o.name()),
},
Expand All @@ -313,6 +331,14 @@ impl TryFrom<LeafOp> for T2Op {
}
}

impl TryFrom<LeafOp> for T2Op {
type Error = NotT2Op;

fn try_from(op: LeafOp) -> Result<Self, Self::Error> {
Self::try_from(&op)
}
}

/// load all variants of a `SimpleOpEnum` in to an extension as op defs.
fn load_all_ops<T: SimpleOpEnum>(extension: &mut Extension) -> Result<(), ExtensionBuildError> {
for op in T::all_variants() {
Expand Down
38 changes: 27 additions & 11 deletions src/optimiser/taso.rs
Original file line number Diff line number Diff line change
Expand Up @@ -293,20 +293,36 @@ where

#[cfg(feature = "portmatching")]
mod taso_default {
use crate::circuit::Circuit;
use crate::rewrite::strategy::ExhaustiveRewriteStrategy;
use hugr::ops::OpType;
use hugr::HugrView;

use crate::ops::op_matches;
use crate::rewrite::strategy::{exhaustive_cx, ExhaustiveRewriteStrategy};
use crate::rewrite::ECCRewriter;
use crate::T2Op;

use super::*;

impl TasoOptimiser<ECCRewriter, ExhaustiveRewriteStrategy, fn(&Hugr) -> usize> {
/// A sane default optimiser using the given ECC sets.
pub fn default_with_eccs_json_file(
eccs_path: impl AsRef<std::path::Path>,
) -> io::Result<Self> {
let rewriter = ECCRewriter::try_from_eccs_json_file(eccs_path)?;
let strategy = ExhaustiveRewriteStrategy::default();
Ok(Self::new(rewriter, strategy, |c| c.num_gates()))
}
pub type DefaultTasoOptimiser = TasoOptimiser<
ECCRewriter,
ExhaustiveRewriteStrategy<fn(&OpType) -> bool>,
fn(&Hugr) -> usize,
>;

/// A sane default optimiser using the given ECC sets.
pub fn default_with_eccs_json_file(
eccs_path: impl AsRef<std::path::Path>,
) -> io::Result<DefaultTasoOptimiser> {
let rewriter = ECCRewriter::try_from_eccs_json_file(eccs_path)?;
let strategy = exhaustive_cx();
Ok(TasoOptimiser::new(rewriter, strategy, num_cx_gates))
}

fn num_cx_gates(circ: &Hugr) -> usize {
circ.nodes()
.filter(|&n| op_matches(circ.get_optype(n), T2Op::CX))
.count()
}
}
#[cfg(feature = "portmatching")]
pub use taso_default::default_with_eccs_json_file;
91 changes: 75 additions & 16 deletions src/rewrite/strategy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@

use std::collections::HashSet;

use hugr::Hugr;
use hugr::{ops::OpType, Hugr, HugrView, Node};
use itertools::Itertools;

use crate::circuit::Circuit;
use crate::{ops::op_matches, T2Op};

use super::CircuitRewrite;

Expand Down Expand Up @@ -79,27 +79,60 @@ impl RewriteStrategy for GreedyRewriteStrategy {
/// circuit.
///
/// The parameter gamma controls how greedy the algorithm should be. It allows
/// a rewrite C1 -> C2 if C2 has at most gamma times as many gates as C1:
/// a rewrite C1 -> C2 if C2 has at most gamma times the cost of C1:
///
/// $|C2| < gamma * |C1|$
/// $cost(C2) < gamma * cost(C1)$
///
/// The cost function is given by the number of operations in the circuit that
/// satisfy a given Op predicate. This allows for instance to use the total
/// number of gates (true predicate) or the number of CX gates as cost function.
///
/// gamma = 1 is the greedy strategy where a rewrite is only allowed if it
/// strictly reduces the gate count. The default is gamma = 1.0001, as set
/// in the Quartz paper. This essentially allows rewrites that improve or leave
/// the number of nodes unchanged.
/// strictly reduces the gate count. The default is gamma = 1.0001 (as set in
/// the Quartz paper) and the number of CX gates. This essentially allows
/// rewrites that improve or leave the number of CX unchanged.
#[derive(Debug, Clone)]
pub struct ExhaustiveRewriteStrategy {
pub struct ExhaustiveRewriteStrategy<P> {
/// The gamma parameter.
pub gamma: f64,
/// Ops to count for cost function.
pub op_predicate: P,
}

impl Default for ExhaustiveRewriteStrategy {
fn default() -> Self {
Self { gamma: 1.0001 }
impl<P> ExhaustiveRewriteStrategy<P> {
/// New exhaustive rewrite strategy with provided predicate.
///
/// The gamma parameter is set to the default 1.0001.
pub fn with_predicate(op_predicate: P) -> Self {
Self {
gamma: 1.0001,
op_predicate,
}
}

/// New exhaustive rewrite strategy with provided gamma and predicate.
pub fn new(gamma: f64, op_predicate: P) -> Self {
Self {
gamma,
op_predicate,
}
}
}

impl RewriteStrategy for ExhaustiveRewriteStrategy {
/// Exhaustive rewrite strategy with CX count cost function.
///
/// The gamma parameter is set to the default 1.0001. This is a good default
/// choice for NISQ-y circuits, where CX gates are the most expensive.
pub fn exhaustive_cx() -> ExhaustiveRewriteStrategy<fn(&OpType) -> bool> {
ExhaustiveRewriteStrategy::with_predicate(is_cx)
}

/// Exhaustive rewrite strategy with CX count cost function and provided gamma.
pub fn exhaustive_cx_with_gamma(gamma: f64) -> ExhaustiveRewriteStrategy<fn(&OpType) -> bool> {
ExhaustiveRewriteStrategy::new(gamma, is_cx)
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

these would be nicer as associated functions of ExhaustiveRewriteStrategy

impl ExhaustiveRewriteStrategy<fn(&OpType) -> bool> {
    pub fn exhaustive_cx() -> Self {
        Self::with_predicate(is_cx)
    }
}

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes you are right, I did the same with default_with_eccs_json_file in TasoOptimiser.


impl<P: Fn(&OpType) -> bool> RewriteStrategy for ExhaustiveRewriteStrategy<P> {
#[tracing::instrument(skip_all)]
fn apply_rewrites(
&self,
Expand All @@ -109,8 +142,8 @@ impl RewriteStrategy for ExhaustiveRewriteStrategy {
rewrites
.into_iter()
.filter(|rw| {
let old_count = rw.subcircuit().node_count() as f64;
let new_count = rw.replacement().num_gates() as f64;
let old_count = pre_rewrite_cost(rw, circ, &self.op_predicate) as f64;
let new_count = post_rewrite_cost(rw, circ, &self.op_predicate) as f64;
new_count < old_count * self.gamma
})
.map(|rw| {
Expand All @@ -122,6 +155,32 @@ impl RewriteStrategy for ExhaustiveRewriteStrategy {
}
}

fn is_cx(op: &OpType) -> bool {
op_matches(op, T2Op::CX)
}

fn cost(
nodes: impl IntoIterator<Item = Node>,
circ: &Hugr,
pred: impl Fn(&OpType) -> bool,
) -> usize {
nodes
.into_iter()
.filter(|n| {
let op = circ.get_optype(*n);
pred(op)
})
.count()
}

fn pre_rewrite_cost(rw: &CircuitRewrite, circ: &Hugr, pred: impl Fn(&OpType) -> bool) -> usize {
cost(rw.subcircuit().nodes().iter().copied(), circ, pred)
}

fn post_rewrite_cost(rw: &CircuitRewrite, circ: &Hugr, pred: impl Fn(&OpType) -> bool) -> usize {
cost(rw.replacement().nodes(), circ, pred)
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down Expand Up @@ -197,7 +256,7 @@ mod tests {
rw_to_empty(&circ, cx_gates[9..10].to_vec()),
];

let strategy = ExhaustiveRewriteStrategy::default();
let strategy = exhaustive_cx();
let rewritten = strategy.apply_rewrites(rws, &circ);
let exp_circ_lens = HashSet::from_iter([8, 6, 9]);
let circ_lens: HashSet<_> = rewritten.iter().map(|c| c.num_gates()).collect();
Expand All @@ -219,7 +278,7 @@ mod tests {
rw_to_empty(&circ, cx_gates[9..10].to_vec()),
];

let strategy = ExhaustiveRewriteStrategy { gamma: 10. };
let strategy = exhaustive_cx_with_gamma(10.);
let rewritten = strategy.apply_rewrites(rws, &circ);
let exp_circ_lens = HashSet::from_iter([8, 17, 6, 9]);
let circ_lens: HashSet<_> = rewritten.iter().map(|c| c.num_gates()).collect();
Expand Down
8 changes: 3 additions & 5 deletions taso-optimiser/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,9 @@ use std::{fs, path::Path};

use clap::Parser;
use hugr::Hugr;
use tket2::json::{load_tk1_json_file, TKETDecode};
use tket2::optimiser::taso;
use tket2::optimiser::taso::log::TasoLogger;
use tket2::{
json::{load_tk1_json_file, TKETDecode},
optimiser::TasoOptimiser,
};
use tket_json_rs::circuit_json::SerialCircuit;

#[cfg(feature = "peak_alloc")]
Expand Down Expand Up @@ -111,7 +109,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
let circ = load_tk1_json_file(input_path)?;

println!("Compiling rewriter...");
let Ok(optimiser) = TasoOptimiser::default_with_eccs_json_file(ecc_path) else {
let Ok(optimiser) = taso::default_with_eccs_json_file(ecc_path) else {
eprintln!(
"Unable to load ECC file {:?}. Is it a JSON file of Quartz-generated ECCs?",
ecc_path
Expand Down
Loading