diff --git a/src/circuit.rs b/src/circuit.rs index 5bf0d1ff..3ae2b8a1 100644 --- a/src/circuit.rs +++ b/src/circuit.rs @@ -1,6 +1,7 @@ //! Quantum circuit representation and operations. pub mod command; +pub mod cost; mod hash; pub mod units; @@ -22,6 +23,7 @@ use itertools::Itertools; use portgraph::Direction; use thiserror::Error; +use self::cost::CircuitCost; use self::units::{filter, FilteredUnits, Units}; /// An object behaving like a quantum circuit. @@ -126,6 +128,28 @@ pub trait Circuit: HugrView { // Traverse the circuit in topological order. CommandIterator::new(self) } + + /// Compute the cost of the circuit based on a per-operation cost function. + #[inline] + fn circuit_cost(&self, op_cost: F) -> C + where + Self: Sized, + C: CircuitCost, + F: Fn(&OpType) -> C, + { + self.commands().map(|cmd| op_cost(cmd.optype())).sum() + } + + /// Compute the cost of a group of nodes in a circuit based on a + /// per-operation cost function. + #[inline] + fn nodes_cost(&self, nodes: impl IntoIterator, op_cost: F) -> C + where + C: CircuitCost, + F: Fn(&OpType) -> C, + { + nodes.into_iter().map(|n| op_cost(self.get_optype(n))).sum() + } } /// Remove an empty wire in a dataflow HUGR. diff --git a/src/circuit/cost.rs b/src/circuit/cost.rs new file mode 100644 index 00000000..e2ad3905 --- /dev/null +++ b/src/circuit/cost.rs @@ -0,0 +1,102 @@ +//! Cost definitions for a circuit. + +use derive_more::From; +use hugr::ops::OpType; +use std::fmt::Debug; +use std::iter::Sum; +use std::num::NonZeroUsize; +use std::ops::Add; + +use crate::ops::op_matches; +use crate::T2Op; + +/// The cost for a group of operations in a circuit, each with cost `OpCost`. +pub trait CircuitCost: Add + Sum + Debug + Default + Clone + Ord { + /// Returns true if the cost is above the threshold. + fn check_threshold(self, threshold: Self) -> bool; + + /// Divide the cost, rounded up. + fn div_cost(self, n: NonZeroUsize) -> Self; +} + +/// A pair of major and minor cost. +/// +/// This is used to order circuits based on major cost first, then minor cost. +/// A typical example would be CX count as major cost and total gate count as +/// minor cost. +#[derive(Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord, From)] +pub struct MajorMinorCost { + major: usize, + minor: usize, +} + +// Serialise as string so that it is easy to write to CSV +impl serde::Serialize for MajorMinorCost { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + serializer.serialize_str(&format!("{:?}", self)) + } +} + +impl Debug for MajorMinorCost { + // TODO: A nicer print for the logs + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "(major={}, minor={})", self.major, self.minor) + } +} + +impl Add for MajorMinorCost { + type Output = MajorMinorCost; + + fn add(self, rhs: MajorMinorCost) -> Self::Output { + (self.major + rhs.major, self.minor + rhs.minor).into() + } +} + +impl Sum for MajorMinorCost { + fn sum>(iter: I) -> Self { + iter.reduce(|a, b| (a.major + b.major, a.minor + b.minor).into()) + .unwrap_or_default() + } +} + +impl CircuitCost for MajorMinorCost { + #[inline] + fn check_threshold(self, threshold: Self) -> bool { + self.major > threshold.major + } + + #[inline] + fn div_cost(mut self, n: NonZeroUsize) -> Self { + self.major = (self.major.saturating_sub(1)) / n.get() + 1; + self.minor = (self.minor.saturating_sub(1)) / n.get() + 1; + self + } +} + +impl CircuitCost for usize { + #[inline] + fn check_threshold(self, threshold: Self) -> bool { + self > threshold + } + + #[inline] + fn div_cost(self, n: NonZeroUsize) -> Self { + (self.saturating_sub(1)) / n.get() + 1 + } +} + +/// Returns true if the operation is a controlled X operation. +pub fn is_cx(op: &OpType) -> bool { + op_matches(op, T2Op::CX) +} + +/// Returns true if the operation is a quantum operation. +pub fn is_quantum(op: &OpType) -> bool { + let Ok(op): Result = op.try_into() else { + return false; + }; + op.is_quantum() +} diff --git a/src/ops.rs b/src/ops.rs index 1c1cb69a..1ac5c230 100644 --- a/src/ops.rs +++ b/src/ops.rs @@ -200,6 +200,16 @@ impl T2Op { _ => vec![], } } + + /// Check if this op is a quantum op. + pub fn is_quantum(&self) -> bool { + use T2Op::*; + match self { + H | CX | T | S | X | Y | Z | Tdg | Sdg | ZZMax | RzF64 | RxF64 | PhasedX | ZZPhase + | CZ | TK1 => true, + AngleAdd | Measure => false, + } + } } /// Initialize a new custom symbolic expression constant op from a string. diff --git a/src/optimiser/taso.rs b/src/optimiser/taso.rs index 4e1bd326..5f15fbe3 100644 --- a/src/optimiser/taso.rs +++ b/src/optimiser/taso.rs @@ -30,6 +30,7 @@ use std::{mem, thread}; use fxhash::FxHashSet; use hugr::Hugr; +use crate::circuit::cost::CircuitCost; use crate::circuit::CircuitHash; use crate::optimiser::taso::hugr_pchannel::HugrPriorityChannel; use crate::optimiser::taso::hugr_pqueue::{Entry, HugrPQ}; @@ -37,6 +38,7 @@ use crate::optimiser::taso::worker::TasoWorker; use crate::passes::CircuitChunks; use crate::rewrite::strategy::RewriteStrategy; use crate::rewrite::Rewriter; +use crate::Circuit; /// The TASO optimiser. /// @@ -55,28 +57,30 @@ use crate::rewrite::Rewriter; /// [Quartz]: https://arxiv.org/abs/2204.09033 /// [TASO]: https://dl.acm.org/doi/10.1145/3341301.3359630 #[derive(Clone, Debug)] -pub struct TasoOptimiser { +pub struct TasoOptimiser { rewriter: R, strategy: S, - cost: C, } -impl TasoOptimiser { +impl TasoOptimiser { /// Create a new TASO optimiser. - pub fn new(rewriter: R, strategy: S, cost: C) -> Self { - Self { - rewriter, - strategy, - cost, - } + pub fn new(rewriter: R, strategy: S) -> Self { + Self { rewriter, strategy } + } + + fn cost(&self, circ: &Hugr) -> S::Cost + where + S: RewriteStrategy, + { + self.strategy.circuit_cost(circ) } } -impl TasoOptimiser +impl TasoOptimiser where R: Rewriter + Send + Clone + 'static, - S: RewriteStrategy + Send + Clone + 'static, - C: Fn(&Hugr) -> usize + Send + Sync + Clone + 'static, + S: RewriteStrategy + Send + Sync + Clone + 'static, + S::Cost: serde::Serialize, { /// Run the TASO optimiser on a circuit. /// @@ -118,15 +122,19 @@ where let start_time = Instant::now(); let mut best_circ = circ.clone(); - let mut best_circ_cost = (self.cost)(circ); - logger.log_best(best_circ_cost); + let mut best_circ_cost = self.cost(circ); + logger.log_best(&best_circ_cost); // Hash of seen circuits. Dot not store circuits as this map gets huge let mut seen_hashes: FxHashSet<_> = FromIterator::from_iter([(circ.circuit_hash())]); // The priority queue of circuits to be processed (this should not get big) const PRIORITY_QUEUE_CAPACITY: usize = 10_000; - let mut pq = HugrPQ::with_capacity(&self.cost, PRIORITY_QUEUE_CAPACITY); + let cost_fn = { + let strategy = self.strategy.clone(); + move |circ: &'_ Hugr| strategy.circuit_cost(circ) + }; + let mut pq = HugrPQ::with_capacity(cost_fn, PRIORITY_QUEUE_CAPACITY); pq.push(circ.clone()); let mut circ_cnt = 1; @@ -135,7 +143,7 @@ where if cost < best_circ_cost { best_circ = circ.clone(); best_circ_cost = cost; - logger.log_best(best_circ_cost); + logger.log_best(&best_circ_cost); } let rewrites = self.rewriter.get_rewrites(&circ); @@ -183,15 +191,19 @@ where const PRIORITY_QUEUE_CAPACITY: usize = 10_000; // multi-consumer priority channel for queuing circuits to be processed by the workers + let cost_fn = { + let strategy = self.strategy.clone(); + move |circ: &'_ Hugr| strategy.circuit_cost(circ) + }; let (tx_work, rx_work) = - HugrPriorityChannel::init((self.cost).clone(), PRIORITY_QUEUE_CAPACITY * n_threads); + HugrPriorityChannel::init(cost_fn, PRIORITY_QUEUE_CAPACITY * n_threads); // channel for sending circuits from threads back to main let (tx_result, rx_result) = crossbeam_channel::unbounded(); let initial_circ_hash = circ.circuit_hash(); let mut best_circ = circ.clone(); - let mut best_circ_cost = (self.cost)(&best_circ); - logger.log_best(best_circ_cost); + let mut best_circ_cost = self.cost(&best_circ); + logger.log_best(&best_circ_cost); // Hash of seen circuits. Dot not store circuits as this map gets huge let mut seen_hashes: FxHashSet<_> = FromIterator::from_iter([(initial_circ_hash)]); @@ -254,13 +266,13 @@ where } seen_hashes.insert(*circ_hash); - let cost = (self.cost)(circ); + let cost = self.cost(circ); // Check if we got a new best circuit if cost < best_circ_cost { best_circ = circ.clone(); best_circ_cost = cost; - logger.log_best(best_circ_cost); + logger.log_best(&best_circ_cost); } jobs_sent += 1; } @@ -310,17 +322,16 @@ where timeout: Option, n_threads: NonZeroUsize, ) -> Result { - // TODO: Add a parameter to set other split cost functions? - // (In contrast to `self.cost`, this is a counts the per-node cost) - let circ_cx_cost = cost_functions::num_cx_gates(circ); - let max_cx_cost = (circ_cx_cost.saturating_sub(1)) / n_threads.get() + 1; + let circ_cost = self.cost(circ); + let max_chunk_cost = circ_cost.clone().div_cost(n_threads); logger.log(format!( - "Splitting circuit with cost {circ_cx_cost} into chunks of at most {max_cx_cost} CX gates" + "Splitting circuit with cost {:?} into chunks of at most {max_chunk_cost:?}.", + circ_cost.clone() )); - let mut chunks = CircuitChunks::split_with_cost(circ, max_cx_cost, cost_functions::cx_cost); + let mut chunks = + CircuitChunks::split_with_cost(circ, max_chunk_cost, |op| self.strategy.op_cost(op)); - let circ_cost = (self.cost)(circ); - logger.log_best(circ_cost); + logger.log_best(circ_cost.clone()); let (joins, rx_work): (Vec<_>, Vec<_>) = chunks .iter_mut() @@ -329,8 +340,8 @@ where let (tx, rx) = crossbeam_channel::unbounded(); let taso = self.clone(); let chunk = mem::take(chunk); - let chunk_cx_cost = cost_functions::num_cx_gates(&chunk); - logger.log(format!("Chunk {i} has {chunk_cx_cost} CX gates",)); + let chunk_cx_cost = chunk.circuit_cost(|op| self.strategy.op_cost(op)); + logger.log(format!("Chunk {i} has {chunk_cx_cost:?} CX gates",)); let join = thread::Builder::new() .name(format!("chunk-{}", i)) .spawn(move || { @@ -351,9 +362,9 @@ where } let best_circ = chunks.reassemble()?; - let best_circ_cost = (self.cost)(&best_circ); - if best_circ_cost < circ_cost { - logger.log_best(best_circ_cost); + let best_circ_cost = self.cost(&best_circ); + if best_circ_cost.clone() < circ_cost { + logger.log_best(best_circ_cost.clone()); } logger.log_processing_end(n_threads.get(), best_circ_cost, true, false); @@ -365,12 +376,13 @@ where #[cfg(feature = "portmatching")] mod taso_default { - use hugr::ops::OpType; use std::io; use std::path::Path; + use hugr::ops::OpType; + use crate::rewrite::ecc_rewriter::RewriterSerialisationError; - use crate::rewrite::strategy::ExhaustiveRewriteStrategy; + use crate::rewrite::strategy::NonIncreasingGateCountStrategy; use crate::rewrite::ECCRewriter; use super::*; @@ -378,20 +390,15 @@ mod taso_default { /// The default TASO optimiser using ECC sets. pub type DefaultTasoOptimiser = TasoOptimiser< ECCRewriter, - ExhaustiveRewriteStrategy bool>, - fn(&Hugr) -> usize, + NonIncreasingGateCountStrategy usize, fn(&OpType) -> usize>, >; impl DefaultTasoOptimiser { /// A sane default optimiser using the given ECC sets. pub fn default_with_eccs_json_file(eccs_path: impl AsRef) -> io::Result { let rewriter = ECCRewriter::try_from_eccs_json_file(eccs_path)?; - let strategy = ExhaustiveRewriteStrategy::exhaustive_cx(); - Ok(TasoOptimiser::new( - rewriter, - strategy, - cost_functions::num_cx_gates, - )) + let strategy = NonIncreasingGateCountStrategy::default_cx(); + Ok(TasoOptimiser::new(rewriter, strategy)) } /// A sane default optimiser using a precompiled binary rewriter. @@ -399,29 +406,68 @@ mod taso_default { rewriter_path: impl AsRef, ) -> Result { let rewriter = ECCRewriter::load_binary(rewriter_path)?; - let strategy = ExhaustiveRewriteStrategy::exhaustive_cx(); - Ok(TasoOptimiser::new( - rewriter, - strategy, - cost_functions::num_cx_gates, - )) + let strategy = NonIncreasingGateCountStrategy::default_cx(); + Ok(TasoOptimiser::new(rewriter, strategy)) } } } #[cfg(feature = "portmatching")] pub use taso_default::DefaultTasoOptimiser; -mod cost_functions { - use super::*; - use crate::ops::op_matches; - use crate::T2Op; - use hugr::{HugrView, Node}; +#[cfg(test)] +#[cfg(feature = "portmatching")] +mod tests { + use hugr::{ + builder::{DFGBuilder, Dataflow, DataflowHugr}, + extension::prelude::QB_T, + std_extensions::arithmetic::float_types::FLOAT64_TYPE, + types::FunctionType, + Hugr, + }; + use rstest::{fixture, rstest}; + + use crate::{extension::REGISTRY, Circuit, T2Op}; + + use super::{DefaultTasoOptimiser, TasoOptimiser}; + + #[fixture] + fn rz_rz() -> Hugr { + let input_t = vec![QB_T, FLOAT64_TYPE, FLOAT64_TYPE]; + let output_t = vec![QB_T]; + let mut h = DFGBuilder::new(FunctionType::new(input_t, output_t)).unwrap(); + + let mut inps = h.input_wires(); + let qb = inps.next().unwrap(); + let f1 = inps.next().unwrap(); + let f2 = inps.next().unwrap(); + + let res = h.add_dataflow_op(T2Op::RzF64, [qb, f1]).unwrap(); + let qb = res.outputs().next().unwrap(); + let res = h.add_dataflow_op(T2Op::RzF64, [qb, f2]).unwrap(); + let qb = res.outputs().next().unwrap(); + + h.finish_hugr_with_outputs([qb], ®ISTRY).unwrap() + } - pub fn num_cx_gates(circ: &Hugr) -> usize { - circ.nodes().map(|n| cx_cost(circ, n)).sum() + #[fixture] + fn taso_opt() -> DefaultTasoOptimiser { + TasoOptimiser::default_with_eccs_json_file("test_files/small_eccs.json").unwrap() } - pub fn cx_cost(circ: &Hugr, node: Node) -> usize { - op_matches(circ.get_optype(node), T2Op::CX) as usize + #[rstest] + fn rz_rz_cancellation(rz_rz: Hugr, taso_opt: DefaultTasoOptimiser) { + let opt_rz = taso_opt.optimise(&rz_rz, None, 1.try_into().unwrap(), false); + let cmds = opt_rz + .commands() + .map(|cmd| { + ( + cmd.optype().try_into().unwrap(), + cmd.inputs().count(), + cmd.outputs().count(), + ) + }) + .collect::>(); + let exp_cmds = vec![(T2Op::AngleAdd, 2, 1), (T2Op::RzF64, 2, 1)]; + assert_eq!(cmds, exp_cmds); } } diff --git a/src/optimiser/taso/log.rs b/src/optimiser/taso/log.rs index 1bbd9abd..e88cbe41 100644 --- a/src/optimiser/taso/log.rs +++ b/src/optimiser/taso/log.rs @@ -1,6 +1,6 @@ //! Logging utilities for the TASO optimiser. -use std::io; +use std::{fmt::Debug, io}; /// Logging configuration for the TASO optimiser. #[derive(Default)] @@ -35,8 +35,8 @@ impl<'w> TasoLogger<'w> { /// Log a new best candidate #[inline] - pub fn log_best(&mut self, best_cost: usize) { - self.log(format!("new best of size {}", best_cost)); + pub fn log_best(&mut self, best_cost: C) { + self.log(format!("new best of size {:?}", best_cost)); if let Some(csv_writer) = self.circ_candidates_csv.as_mut() { csv_writer.serialize(BestCircSer::new(best_cost)).unwrap(); csv_writer.flush().unwrap(); @@ -45,10 +45,10 @@ impl<'w> TasoLogger<'w> { /// Log the final optimised circuit #[inline] - pub fn log_processing_end( + pub fn log_processing_end( &self, circuit_count: usize, - best_cost: usize, + best_cost: C, needs_joining: bool, timeout: bool, ) { @@ -57,7 +57,7 @@ impl<'w> TasoLogger<'w> { } self.log("Optimisation finished"); self.log(format!("Tried {circuit_count} circuits")); - self.log(format!("END RESULT: {}", best_cost)); + self.log(format!("END RESULT: {:?}", best_cost)); if needs_joining { self.log("Joining worker threads"); } @@ -98,14 +98,14 @@ impl<'w> TasoLogger<'w> { // // TODO: Replace this fixed logging. Report back intermediate results. #[derive(serde::Serialize, Clone, Debug)] -struct BestCircSer { - circ_len: usize, +struct BestCircSer { + circ_cost: C, time: String, } -impl BestCircSer { - fn new(circ_len: usize) -> Self { +impl BestCircSer { + fn new(circ_cost: C) -> Self { let time = chrono::Local::now().to_rfc3339(); - Self { circ_len, time } + Self { circ_cost, time } } } diff --git a/src/passes/chunks.rs b/src/passes/chunks.rs index ff2d1361..38109966 100644 --- a/src/passes/chunks.rs +++ b/src/passes/chunks.rs @@ -13,12 +13,14 @@ use hugr::hugr::views::sibling_subgraph::ConvexChecker; use hugr::hugr::views::{HierarchyView, SiblingGraph, SiblingSubgraph}; use hugr::hugr::{HugrError, NodeMetadata}; use hugr::ops::handle::DataflowParentID; +use hugr::ops::OpType; use hugr::types::{FunctionType, Signature}; use hugr::{Hugr, HugrView, Node, Port, Wire}; use itertools::Itertools; use crate::Circuit; +use crate::circuit::cost::CircuitCost; #[cfg(feature = "pyo3")] use crate::json::TKETDecode; #[cfg(feature = "pyo3")] @@ -178,16 +180,16 @@ impl CircuitChunks { /// /// The circuit is split into chunks of at most `max_size` gates. pub fn split(circ: &impl Circuit, max_size: usize) -> Self { - Self::split_with_cost(circ, max_size, |_, _| 1) + Self::split_with_cost(circ, max_size, |_| 1) } /// Split a circuit into chunks. /// /// The circuit is split into chunks of at most `max_cost`, using the provided cost function. - pub fn split_with_cost( - circ: &C, - max_cost: usize, - node_cost: impl Fn(&C, Node) -> usize, + pub fn split_with_cost( + circ: &H, + max_cost: C, + op_cost: impl Fn(&OpType) -> C, ) -> Self { let root_meta = circ.get_metadata(circ.root()).clone(); let signature = circ.circuit_signature().clone(); @@ -205,11 +207,17 @@ impl CircuitChunks { let mut chunks = Vec::new(); let mut convex_checker = ConvexChecker::new(circ); - let mut running_cost = 0; + let mut running_cost = C::default(); + let mut current_group = 0; for (_, commands) in &circ.commands().map(|cmd| cmd.node()).group_by(|&node| { - let group = running_cost / max_cost; - running_cost += node_cost(circ, node); - group + let new_cost = running_cost.clone() + op_cost(circ.get_optype(node)); + if new_cost.clone().check_threshold(max_cost.clone()) { + running_cost = C::default(); + current_group += 1; + } else { + running_cost = new_cost; + } + current_group }) { chunks.push(Chunk::extract(circ, commands, &mut convex_checker)); } diff --git a/src/rewrite/ecc_rewriter.rs b/src/rewrite/ecc_rewriter.rs index c8e95f0f..6343837c 100644 --- a/src/rewrite/ecc_rewriter.rs +++ b/src/rewrite/ecc_rewriter.rs @@ -338,7 +338,7 @@ mod tests { let test_file = "test_files/small_eccs.json"; let rewriter = ECCRewriter::try_from_eccs_json_file(test_file).unwrap(); assert_eq!(rewriter.rewrite_rules.len(), rewriter.matcher.n_patterns()); - assert_eq!(rewriter.targets.len(), 5 * 4 + 4 * 3); + assert_eq!(rewriter.targets.len(), 5 * 4 + 5 * 3); // Assert that the rewrite rules are correct, i.e that the rewrite // rules in the slice (k..=k+t) is given by [[k+1, ..., k+t], [k], ..., [k]] @@ -361,8 +361,8 @@ mod tests { curr_repr = TargetID(i); } } - // There should be 4x ECCs of size 3 and 5x ECCs of size 4 - let exp_n_eccs_of_len = [0, 4 * 2 + 5 * 3, 4, 5]; + // There should be 5x ECCs of size 3 and 5x ECCs of size 4 + let exp_n_eccs_of_len = [0, 5 * 2 + 5 * 3, 5, 5]; assert_eq!(n_eccs_of_len, exp_n_eccs_of_len); } diff --git a/src/rewrite/strategy.rs b/src/rewrite/strategy.rs index 3df7ceeb..c5acdb90 100644 --- a/src/rewrite/strategy.rs +++ b/src/rewrite/strategy.rs @@ -2,18 +2,24 @@ //! //! This module contains the [`RewriteStrategy`] trait, which is currently //! implemented by -//! - [`GreedyRewriteStrategy`], which applies as many rewrites as possible +//! - [`GreedyRewriteStrategy`], which applies as many rewrites as possible //! on one circuit, and -//! - [`ExhaustiveRewriteStrategy`], which clones the original circuit as many -//! times as there are possible rewrites and applies a different rewrite -//! to every circuit. +//! - exhaustive strategies, which clone the original circuit and explore every +//! possible rewrite (with some pruning strategy): +//! - [`NonIncreasingGateCountStrategy`], which only considers rewrites that +//! do not increase some cost function (e.g. cx gate count, implemented as +//! [`NonIncreasingGateCountStrategy::default_cx`]), and +//! - [`ExhaustiveGammaStrategy`], which ignores rewrites that increase the +//! cost function beyond a threshold given by a f64 parameter gamma. -use std::collections::HashSet; +use std::{collections::HashSet, fmt::Debug}; -use hugr::{ops::OpType, Hugr, HugrView, Node}; +use hugr::ops::OpType; +use hugr::{Hugr, HugrView}; use itertools::Itertools; -use crate::{ops::op_matches, T2Op}; +use crate::circuit::cost::{is_cx, is_quantum, CircuitCost, MajorMinorCost}; +use crate::Circuit; use super::CircuitRewrite; @@ -23,13 +29,25 @@ use super::CircuitRewrite; /// to a circuit according to a strategy. It returns a list of new circuits, /// each obtained by applying one or several non-overlapping rewrites to the /// original circuit. +/// +/// It also assign every circuit a totally ordered cost that can be used when +/// using rewrites for circuit optimisation. pub trait RewriteStrategy { + /// The circuit cost to be minimised. + type Cost: CircuitCost; + /// Apply a set of rewrites to a circuit. fn apply_rewrites( &self, rewrites: impl IntoIterator, circ: &Hugr, ) -> Vec; + + /// The cost of a circuit. + fn circuit_cost(&self, circ: &Hugr) -> Self::Cost; + + /// The cost of a single operation. + fn op_cost(&self, op: &OpType) -> Self::Cost; } /// A rewrite strategy applying as many non-overlapping rewrites as possible. @@ -45,6 +63,8 @@ pub trait RewriteStrategy { pub struct GreedyRewriteStrategy; impl RewriteStrategy for GreedyRewriteStrategy { + type Cost = usize; + #[tracing::instrument(skip_all)] fn apply_rewrites( &self, @@ -73,114 +93,207 @@ impl RewriteStrategy for GreedyRewriteStrategy { } vec![circ] } + + fn circuit_cost(&self, circ: &Hugr) -> Self::Cost { + circ.num_gates() + } + + fn op_cost(&self, _op: &OpType) -> Self::Cost { + 1 + } +} + +/// Exhaustive strategies based on cost functions and thresholds. +/// +/// Every possible rewrite is applied to a copy of the input circuit. Thus for +/// one circuit, up to `n` rewritten circuits will be returned, each obtained +/// by applying one of the `n` rewrites to the original circuit. +/// +/// Whether a rewrite is allowed or not is determined by a cost function and a +/// threshold function: if the cost of the target of the rewrite is below the +/// threshold given by the cost of the original circuit, the rewrite is +/// performed. +/// +/// The cost function must return a value of type `Self::OpCost`. All op costs +/// are summed up to obtain a total cost that is then compared using the +/// threshold function. +pub trait ExhaustiveThresholdStrategy { + /// The cost of a single operation. + type OpCost: CircuitCost; + + /// Returns true if the rewrite is allowed, based on the cost of the pattern and target. + fn under_threshold(&self, pattern_cost: &Self::OpCost, target_cost: &Self::OpCost) -> bool; + + /// The cost of a single operation. + fn op_cost(&self, op: &OpType) -> Self::OpCost; +} + +impl RewriteStrategy for T { + type Cost = T::OpCost; + + #[tracing::instrument(skip_all)] + fn apply_rewrites( + &self, + rewrites: impl IntoIterator, + circ: &Hugr, + ) -> Vec { + rewrites + .into_iter() + .filter(|rw| { + let pattern_cost = pre_rewrite_cost(rw, circ, |op| self.op_cost(op)); + let target_cost = post_rewrite_cost(rw, circ, |op| self.op_cost(op)); + self.under_threshold(&pattern_cost, &target_cost) + }) + .map(|rw| { + let mut circ = circ.clone(); + rw.apply(&mut circ).expect("invalid pattern match"); + circ + }) + .collect() + } + + fn circuit_cost(&self, circ: &Hugr) -> Self::Cost { + circ.nodes_cost(circ.nodes(), |op| self.op_cost(op)) + } + + fn op_cost(&self, op: &OpType) -> Self::Cost { + ::op_cost(self, op) + } } -/// A rewrite strategy that explores applying each rewrite to copies of the -/// circuit. +/// Exhaustive rewrite strategy allowing smaller or equal cost rewrites. +/// +/// Rewrites are permitted based on a cost function called the major cost: if +/// the major cost of the target of the rewrite is smaller or equal to the major +/// cost of the pattern, the rewrite is allowed. +/// +/// A second cost function, the minor cost, is used as a tie breaker: within +/// circuits with the same major cost, the circuit ordering prioritises circuits +/// with a smaller minor cost. +/// +/// An example would be to use the number of CX gates as major cost and the +/// total number of gates as minor cost. Compared to a [`ExhaustiveGammaStrategy`], +/// that would only order circuits based on the number of CX gates, this creates +/// a less flat optimisation landscape. +#[derive(Debug, Clone)] +pub struct NonIncreasingGateCountStrategy { + major_cost: C1, + minor_cost: C2, +} + +impl ExhaustiveThresholdStrategy for NonIncreasingGateCountStrategy +where + C1: Fn(&OpType) -> usize, + C2: Fn(&OpType) -> usize, +{ + type OpCost = MajorMinorCost; + + fn under_threshold(&self, pattern_cost: &Self::OpCost, target_cost: &Self::OpCost) -> bool { + !target_cost.check_threshold(*pattern_cost) + } + + fn op_cost(&self, op: &OpType) -> Self::OpCost { + ((self.major_cost)(op), (self.minor_cost)(op)).into() + } +} + +impl NonIncreasingGateCountStrategy usize, fn(&OpType) -> usize> { + /// Non-increasing rewrite strategy based on CX count. + /// + /// The minor cost to break ties between equal CX counts is the number of + /// quantum gates. + /// + /// This is probably a good default for NISQ-y circuit optimisation. + pub fn default_cx() -> Self { + Self { + major_cost: |op| is_cx(op) as usize, + minor_cost: |op| is_quantum(op) as usize, + } + } +} + +/// Exhaustive rewrite strategy allowing rewrites with bounded cost increase. /// /// The parameter gamma controls how greedy the algorithm should be. It allows /// a rewrite C1 -> C2 if C2 has at most gamma times the cost of C1: /// /// $cost(C2) < gamma * cost(C1)$ /// -/// The cost function is given by the number of operations in the circuit that -/// satisfy a given Op predicate. This allows for instance to use the total -/// number of gates (true predicate) or the number of CX gates as cost function. +/// The cost function is given by the sum of the cost of each operation in the +/// circuit. This allows for instance to use of the total number of gates (true +/// predicate), the number of CX gates or a weighted sum of gate types as cost +/// functions. /// /// gamma = 1 is the greedy strategy where a rewrite is only allowed if it /// strictly reduces the gate count. The default is gamma = 1.0001 (as set in /// the Quartz paper) and the number of CX gates. This essentially allows /// rewrites that improve or leave the number of CX unchanged. #[derive(Debug, Clone)] -pub struct ExhaustiveRewriteStrategy

{ +pub struct ExhaustiveGammaStrategy { /// The gamma parameter. pub gamma: f64, - /// Ops to count for cost function. - pub op_predicate: P, + /// A cost function for each operation. + pub op_cost: C, } -impl

ExhaustiveRewriteStrategy

{ +impl usize> ExhaustiveThresholdStrategy for ExhaustiveGammaStrategy { + type OpCost = usize; + + fn under_threshold(&self, &pattern_cost: &Self::OpCost, &target_cost: &Self::OpCost) -> bool { + (target_cost as f64) < self.gamma * (pattern_cost as f64) + } + + fn op_cost(&self, op: &OpType) -> Self::OpCost { + (self.op_cost)(op) + } +} + +impl ExhaustiveGammaStrategy { /// New exhaustive rewrite strategy with provided predicate. /// /// The gamma parameter is set to the default 1.0001. - pub fn with_predicate(op_predicate: P) -> Self { + pub fn with_cost(op_cost: C) -> Self { Self { gamma: 1.0001, - op_predicate, + op_cost, } } /// New exhaustive rewrite strategy with provided gamma and predicate. - pub fn new(gamma: f64, op_predicate: P) -> Self { - Self { - gamma, - op_predicate, - } + pub fn new(gamma: f64, op_cost: C) -> Self { + Self { gamma, op_cost } } } -impl ExhaustiveRewriteStrategy bool> { +impl ExhaustiveGammaStrategy usize> { /// Exhaustive rewrite strategy with CX count cost function. /// /// The gamma parameter is set to the default 1.0001. This is a good default /// choice for NISQ-y circuits, where CX gates are the most expensive. pub fn exhaustive_cx() -> Self { - ExhaustiveRewriteStrategy::with_predicate(is_cx) + ExhaustiveGammaStrategy::with_cost(|op| is_cx(op) as usize) } /// Exhaustive rewrite strategy with CX count cost function and provided gamma. pub fn exhaustive_cx_with_gamma(gamma: f64) -> Self { - ExhaustiveRewriteStrategy::new(gamma, is_cx) + ExhaustiveGammaStrategy::new(gamma, |op| is_cx(op) as usize) } } -impl bool> RewriteStrategy for ExhaustiveRewriteStrategy

{ - #[tracing::instrument(skip_all)] - fn apply_rewrites( - &self, - rewrites: impl IntoIterator, - circ: &Hugr, - ) -> Vec { - rewrites - .into_iter() - .filter(|rw| { - let old_count = pre_rewrite_cost(rw, circ, &self.op_predicate) as f64; - let new_count = post_rewrite_cost(rw, circ, &self.op_predicate) as f64; - new_count < old_count * self.gamma - }) - .map(|rw| { - let mut circ = circ.clone(); - rw.apply(&mut circ).expect("invalid pattern match"); - circ - }) - .collect() - } -} - -fn is_cx(op: &OpType) -> bool { - op_matches(op, T2Op::CX) -} - -fn cost( - nodes: impl IntoIterator, - circ: &Hugr, - pred: impl Fn(&OpType) -> bool, -) -> usize { - nodes - .into_iter() - .filter(|n| { - let op = circ.get_optype(*n); - pred(op) - }) - .count() -} - -fn pre_rewrite_cost(rw: &CircuitRewrite, circ: &Hugr, pred: impl Fn(&OpType) -> bool) -> usize { - cost(rw.subcircuit().nodes().iter().copied(), circ, pred) +fn pre_rewrite_cost(rw: &CircuitRewrite, circ: &Hugr, pred: F) -> C +where + C: CircuitCost, + F: Fn(&OpType) -> C, +{ + circ.nodes_cost(rw.subcircuit().nodes().iter().copied(), pred) } -fn post_rewrite_cost(rw: &CircuitRewrite, circ: &Hugr, pred: impl Fn(&OpType) -> bool) -> usize { - cost(rw.replacement().nodes(), circ, pred) +fn post_rewrite_cost(rw: &CircuitRewrite, circ: &Hugr, pred: F) -> C +where + C: CircuitCost, + F: Fn(&OpType) -> C, +{ + circ.nodes_cost(rw.replacement().nodes(), pred) } #[cfg(test)] @@ -258,7 +371,7 @@ mod tests { rw_to_empty(&circ, cx_gates[9..10].to_vec()), ]; - let strategy = ExhaustiveRewriteStrategy::exhaustive_cx(); + let strategy = ExhaustiveGammaStrategy::exhaustive_cx(); let rewritten = strategy.apply_rewrites(rws, &circ); let exp_circ_lens = HashSet::from_iter([8, 6, 9]); let circ_lens: HashSet<_> = rewritten.iter().map(|c| c.num_gates()).collect(); @@ -280,10 +393,34 @@ mod tests { rw_to_empty(&circ, cx_gates[9..10].to_vec()), ]; - let strategy = ExhaustiveRewriteStrategy::exhaustive_cx_with_gamma(10.); + let strategy = ExhaustiveGammaStrategy::exhaustive_cx_with_gamma(10.); let rewritten = strategy.apply_rewrites(rws, &circ); let exp_circ_lens = HashSet::from_iter([8, 17, 6, 9]); let circ_lens: HashSet<_> = rewritten.iter().map(|c| c.num_gates()).collect(); assert_eq!(circ_lens, exp_circ_lens); } + + #[test] + fn test_exhaustive_default_cx_cost() { + let strat = NonIncreasingGateCountStrategy::default_cx(); + let circ = n_cx(3); + assert_eq!(strat.circuit_cost(&circ), (3, 3).into()); + let circ = build_simple_circuit(2, |circ| { + circ.append(T2Op::CX, [0, 1])?; + circ.append(T2Op::X, [0])?; + circ.append(T2Op::X, [1])?; + Ok(()) + }) + .unwrap(); + assert_eq!(strat.circuit_cost(&circ), (1, 3).into()); + } + + #[test] + fn test_exhaustive_default_cx_threshold() { + let strat = NonIncreasingGateCountStrategy::default_cx(); + assert!(strat.under_threshold(&(3, 0).into(), &(3, 0).into())); + assert!(strat.under_threshold(&(3, 0).into(), &(3, 5).into())); + assert!(!strat.under_threshold(&(3, 10).into(), &(4, 0).into())); + assert!(strat.under_threshold(&(3, 0).into(), &(1, 5).into())); + } } diff --git a/test_files/small_eccs.json b/test_files/small_eccs.json index 787d10a6..06e29e54 100644 --- a/test_files/small_eccs.json +++ b/test_files/small_eccs.json @@ -50,5 +50,10 @@ ,[[2,0,0,3,["116baec0e00cd"],[5.53564910528592469e-01,2.63161770391809990e-01]],[["cx", ["Q1", "Q0"],["Q1", "Q0"]],["x", ["Q1"],["Q1"]],["cx", ["Q1", "Q0"],["Q1", "Q0"]]]] ,[[2,0,0,3,["116baec0e00cd"],[5.53564910528592469e-01,2.63161770391809990e-01]],[["cx", ["Q1", "Q0"],["Q1", "Q0"]],["x", ["Q1"],["Q1"]],["cx", ["Q1", "Q0"],["Q1", "Q0"]]]] ] +,"6701_3": [ +[[1,2,3,2,["a720832fadf2"],[2.22710267824423158e-01,2.92349563045663841e-01]],[["add", ["P2"],["P0", "P1"]],["rz", ["Q0"],["Q0", "P2"]]]] +,[[1,2,2,2,["a720832fadf2"],[2.22710267824423158e-01,2.92349563045663841e-01]],[["rz", ["Q0"],["Q0", "P0"]],["rz", ["Q0"],["Q0", "P1"]]]] +,[[1,2,2,2,["a720832fadf2"],[2.22710267824423103e-01,2.92349563045663619e-01]],[["rz", ["Q0"],["Q0", "P1"]],["rz", ["Q0"],["Q0", "P0"]]]] +] } ] \ No newline at end of file