From d2a73042a8e5f2a49beaaefb3f34ee7adb20c4bc Mon Sep 17 00:00:00 2001 From: Agustin Borgna Date: Tue, 3 Oct 2023 02:44:49 +0100 Subject: [PATCH] feat: Circuit cost module and methods --- src/circuit.rs | 24 +++++++ src/circuit/cost.rs | 102 ++++++++++++++++++++++++++++++ src/optimiser/taso.rs | 3 +- src/passes/chunks.rs | 89 +++++++++++++++++++++----- src/rewrite/strategy.rs | 135 ++++++++++++---------------------------- 5 files changed, 242 insertions(+), 111 deletions(-) create mode 100644 src/circuit/cost.rs diff --git a/src/circuit.rs b/src/circuit.rs index 5bf0d1ff..3ae2b8a1 100644 --- a/src/circuit.rs +++ b/src/circuit.rs @@ -1,6 +1,7 @@ //! Quantum circuit representation and operations. pub mod command; +pub mod cost; mod hash; pub mod units; @@ -22,6 +23,7 @@ use itertools::Itertools; use portgraph::Direction; use thiserror::Error; +use self::cost::CircuitCost; use self::units::{filter, FilteredUnits, Units}; /// An object behaving like a quantum circuit. @@ -126,6 +128,28 @@ pub trait Circuit: HugrView { // Traverse the circuit in topological order. CommandIterator::new(self) } + + /// Compute the cost of the circuit based on a per-operation cost function. + #[inline] + fn circuit_cost(&self, op_cost: F) -> C + where + Self: Sized, + C: CircuitCost, + F: Fn(&OpType) -> C, + { + self.commands().map(|cmd| op_cost(cmd.optype())).sum() + } + + /// Compute the cost of a group of nodes in a circuit based on a + /// per-operation cost function. + #[inline] + fn nodes_cost(&self, nodes: impl IntoIterator, op_cost: F) -> C + where + C: CircuitCost, + F: Fn(&OpType) -> C, + { + nodes.into_iter().map(|n| op_cost(self.get_optype(n))).sum() + } } /// Remove an empty wire in a dataflow HUGR. diff --git a/src/circuit/cost.rs b/src/circuit/cost.rs new file mode 100644 index 00000000..e2ad3905 --- /dev/null +++ b/src/circuit/cost.rs @@ -0,0 +1,102 @@ +//! Cost definitions for a circuit. + +use derive_more::From; +use hugr::ops::OpType; +use std::fmt::Debug; +use std::iter::Sum; +use std::num::NonZeroUsize; +use std::ops::Add; + +use crate::ops::op_matches; +use crate::T2Op; + +/// The cost for a group of operations in a circuit, each with cost `OpCost`. +pub trait CircuitCost: Add + Sum + Debug + Default + Clone + Ord { + /// Returns true if the cost is above the threshold. + fn check_threshold(self, threshold: Self) -> bool; + + /// Divide the cost, rounded up. + fn div_cost(self, n: NonZeroUsize) -> Self; +} + +/// A pair of major and minor cost. +/// +/// This is used to order circuits based on major cost first, then minor cost. +/// A typical example would be CX count as major cost and total gate count as +/// minor cost. +#[derive(Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord, From)] +pub struct MajorMinorCost { + major: usize, + minor: usize, +} + +// Serialise as string so that it is easy to write to CSV +impl serde::Serialize for MajorMinorCost { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + serializer.serialize_str(&format!("{:?}", self)) + } +} + +impl Debug for MajorMinorCost { + // TODO: A nicer print for the logs + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "(major={}, minor={})", self.major, self.minor) + } +} + +impl Add for MajorMinorCost { + type Output = MajorMinorCost; + + fn add(self, rhs: MajorMinorCost) -> Self::Output { + (self.major + rhs.major, self.minor + rhs.minor).into() + } +} + +impl Sum for MajorMinorCost { + fn sum>(iter: I) -> Self { + iter.reduce(|a, b| (a.major + b.major, a.minor + b.minor).into()) + .unwrap_or_default() + } +} + +impl CircuitCost for MajorMinorCost { + #[inline] + fn check_threshold(self, threshold: Self) -> bool { + self.major > threshold.major + } + + #[inline] + fn div_cost(mut self, n: NonZeroUsize) -> Self { + self.major = (self.major.saturating_sub(1)) / n.get() + 1; + self.minor = (self.minor.saturating_sub(1)) / n.get() + 1; + self + } +} + +impl CircuitCost for usize { + #[inline] + fn check_threshold(self, threshold: Self) -> bool { + self > threshold + } + + #[inline] + fn div_cost(self, n: NonZeroUsize) -> Self { + (self.saturating_sub(1)) / n.get() + 1 + } +} + +/// Returns true if the operation is a controlled X operation. +pub fn is_cx(op: &OpType) -> bool { + op_matches(op, T2Op::CX) +} + +/// Returns true if the operation is a quantum operation. +pub fn is_quantum(op: &OpType) -> bool { + let Ok(op): Result = op.try_into() else { + return false; + }; + op.is_quantum() +} diff --git a/src/optimiser/taso.rs b/src/optimiser/taso.rs index 7fc4c9b0..a4f0991f 100644 --- a/src/optimiser/taso.rs +++ b/src/optimiser/taso.rs @@ -22,7 +22,6 @@ use crossbeam_channel::select; pub use eq_circ_class::{load_eccs_json_file, EqCircClass}; pub use log::TasoLogger; -use std::fmt; use std::num::NonZeroUsize; use std::time::{Duration, Instant}; @@ -76,7 +75,7 @@ impl TasoOptimiser where R: Rewriter + Send + Clone + 'static, S: RewriteStrategy + Send + Sync + Clone + 'static, - S::Cost: fmt::Debug + serde::Serialize, + S::Cost: serde::Serialize, { /// Run the TASO optimiser on a circuit. /// diff --git a/src/passes/chunks.rs b/src/passes/chunks.rs index 2f436cdf..38109966 100644 --- a/src/passes/chunks.rs +++ b/src/passes/chunks.rs @@ -1,21 +1,26 @@ -//! Utility +//! This module provides a utility to split a circuit into chunks, and reassemble them afterwards. +//! +//! See [`CircuitChunks`] for more information. use std::collections::HashMap; +use std::mem; +use std::ops::{Index, IndexMut}; -use hugr::builder::{Dataflow, DataflowHugr, FunctionBuilder}; +use hugr::builder::{Container, FunctionBuilder}; use hugr::extension::ExtensionSet; use hugr::hugr::hugrmut::HugrMut; use hugr::hugr::views::sibling_subgraph::ConvexChecker; use hugr::hugr::views::{HierarchyView, SiblingGraph, SiblingSubgraph}; use hugr::hugr::{HugrError, NodeMetadata}; use hugr::ops::handle::DataflowParentID; +use hugr::ops::OpType; use hugr::types::{FunctionType, Signature}; use hugr::{Hugr, HugrView, Node, Port, Wire}; use itertools::Itertools; -use crate::extension::REGISTRY; use crate::Circuit; +use crate::circuit::cost::CircuitCost; #[cfg(feature = "pyo3")] use crate::json::TKETDecode; #[cfg(feature = "pyo3")] @@ -38,9 +43,9 @@ pub struct Chunk { /// The extracted circuit. pub circ: Hugr, /// The original wires connected to the input. - pub inputs: Vec, + inputs: Vec, /// The original wires connected to the output. - pub outputs: Vec, + outputs: Vec, } impl Chunk { @@ -145,7 +150,12 @@ struct ChunkInsertResult { pub outgoing_connections: HashMap, } -/// An utility for splitting a circuit into chunks, and reassembling them afterwards. +/// An utility for splitting a circuit into chunks, and reassembling them +/// afterwards. +/// +/// Circuits can be split into [`CircuitChunks`] with [`CircuitChunks::split`] +/// or [`CircuitChunks::split_with_cost`], and reassembled with +/// [`CircuitChunks::reassemble`]. #[derive(Debug, Clone)] #[cfg_attr(feature = "pyo3", pyclass)] pub struct CircuitChunks { @@ -170,6 +180,17 @@ impl CircuitChunks { /// /// The circuit is split into chunks of at most `max_size` gates. pub fn split(circ: &impl Circuit, max_size: usize) -> Self { + Self::split_with_cost(circ, max_size, |_| 1) + } + + /// Split a circuit into chunks. + /// + /// The circuit is split into chunks of at most `max_cost`, using the provided cost function. + pub fn split_with_cost( + circ: &H, + max_cost: C, + op_cost: impl Fn(&OpType) -> C, + ) -> Self { let root_meta = circ.get_metadata(circ.root()).clone(); let signature = circ.circuit_signature().clone(); @@ -186,7 +207,18 @@ impl CircuitChunks { let mut chunks = Vec::new(); let mut convex_checker = ConvexChecker::new(circ); - for commands in &circ.commands().map(|cmd| cmd.node()).chunks(max_size) { + let mut running_cost = C::default(); + let mut current_group = 0; + for (_, commands) in &circ.commands().map(|cmd| cmd.node()).group_by(|&node| { + let new_cost = running_cost.clone() + op_cost(circ.get_optype(node)); + if new_cost.clone().check_threshold(max_cost.clone()) { + running_cost = C::default(); + current_group += 1; + } else { + running_cost = new_cost; + } + current_group + }) { chunks.push(Chunk::extract(circ, commands, &mut convex_checker)); } Self { @@ -211,10 +243,10 @@ impl CircuitChunks { input_extensions: ExtensionSet::new(), }; - let builder = FunctionBuilder::new(name, signature).unwrap(); - let inputs = builder.input_wires(); - // TODO: Use the correct REGISTRY if the method accepts custom input resources. - let mut reassembled = builder.finish_hugr_with_outputs(inputs, ®ISTRY).unwrap(); + let mut builder = FunctionBuilder::new(name, signature).unwrap(); + // Take the unfinished Hugr from the builder, to avoid unnecessary + // validation checks that require connecting the inputs an outputs. + let mut reassembled = mem::take(builder.hugr_mut()); let root = reassembled.root(); let [reassembled_input, reassembled_output] = reassembled.get_io(root).unwrap(); @@ -229,7 +261,6 @@ impl CircuitChunks { .iter() .zip(reassembled.node_outputs(reassembled_input)) { - reassembled.disconnect(reassembled_input, port)?; sources.insert(connection, (reassembled_input, port)); } for (&connection, port) in self @@ -269,9 +300,24 @@ impl CircuitChunks { } /// Returns a list of references to the split circuits. - pub fn circuits(&self) -> impl Iterator { + pub fn iter(&self) -> impl Iterator { self.chunks.iter().map(|chunk| &chunk.circ) } + + /// Returns a list of references to the split circuits. + pub fn iter_mut(&mut self) -> impl Iterator { + self.chunks.iter_mut().map(|chunk| &mut chunk.circ) + } + + /// Returns the number of chunks. + pub fn len(&self) -> usize { + self.chunks.len() + } + + /// Returns `true` if there are no chunks. + pub fn is_empty(&self) -> bool { + self.chunks.is_empty() + } } #[cfg(feature = "pyo3")] @@ -287,7 +333,7 @@ impl CircuitChunks { /// Returns clones of the split circuits. #[pyo3(name = "circuits")] fn py_circuits(&self) -> PyResult>> { - self.circuits() + self.iter() .map(|hugr| SerialCircuit::encode(hugr)?.to_tket1()) .collect() } @@ -306,9 +352,24 @@ impl CircuitChunks { } } +impl Index for CircuitChunks { + type Output = Hugr; + + fn index(&self, index: usize) -> &Self::Output { + &self.chunks[index].circ + } +} + +impl IndexMut for CircuitChunks { + fn index_mut(&mut self, index: usize) -> &mut Self::Output { + &mut self.chunks[index].circ + } +} + #[cfg(test)] mod test { use crate::circuit::CircuitHash; + use crate::extension::REGISTRY; use crate::utils::build_simple_circuit; use crate::T2Op; diff --git a/src/rewrite/strategy.rs b/src/rewrite/strategy.rs index 38a63e3f..c5acdb90 100644 --- a/src/rewrite/strategy.rs +++ b/src/rewrite/strategy.rs @@ -12,13 +12,14 @@ //! - [`ExhaustiveGammaStrategy`], which ignores rewrites that increase the //! cost function beyond a threshold given by a f64 parameter gamma. -use std::{collections::HashSet, fmt::Debug, iter::Sum}; +use std::{collections::HashSet, fmt::Debug}; -use derive_more::From; -use hugr::{ops::OpType, Hugr, HugrView, Node}; +use hugr::ops::OpType; +use hugr::{Hugr, HugrView}; use itertools::Itertools; -use crate::{ops::op_matches, Circuit, T2Op}; +use crate::circuit::cost::{is_cx, is_quantum, CircuitCost, MajorMinorCost}; +use crate::Circuit; use super::CircuitRewrite; @@ -32,8 +33,8 @@ use super::CircuitRewrite; /// It also assign every circuit a totally ordered cost that can be used when /// using rewrites for circuit optimisation. pub trait RewriteStrategy { - /// The circuit cost to be minised. - type Cost: Ord; + /// The circuit cost to be minimised. + type Cost: CircuitCost; /// Apply a set of rewrites to a circuit. fn apply_rewrites( @@ -44,6 +45,9 @@ pub trait RewriteStrategy { /// The cost of a circuit. fn circuit_cost(&self, circ: &Hugr) -> Self::Cost; + + /// The cost of a single operation. + fn op_cost(&self, op: &OpType) -> Self::Cost; } /// A rewrite strategy applying as many non-overlapping rewrites as possible. @@ -93,6 +97,10 @@ impl RewriteStrategy for GreedyRewriteStrategy { fn circuit_cost(&self, circ: &Hugr) -> Self::Cost { circ.num_gates() } + + fn op_cost(&self, _op: &OpType) -> Self::Cost { + 1 + } } /// Exhaustive strategies based on cost functions and thresholds. @@ -111,22 +119,17 @@ impl RewriteStrategy for GreedyRewriteStrategy { /// threshold function. pub trait ExhaustiveThresholdStrategy { /// The cost of a single operation. - type OpCost; - /// The sum of the cost of all operations in a circuit. - type SumOpCost; + type OpCost: CircuitCost; - /// Whether the rewrite is allowed or not, based on the cost of the pattern and target. - fn threshold(&self, pattern_cost: &Self::SumOpCost, target_cost: &Self::SumOpCost) -> bool; + /// Returns true if the rewrite is allowed, based on the cost of the pattern and target. + fn under_threshold(&self, pattern_cost: &Self::OpCost, target_cost: &Self::OpCost) -> bool; /// The cost of a single operation. fn op_cost(&self, op: &OpType) -> Self::OpCost; } -impl RewriteStrategy for T -where - T::SumOpCost: Sum + Ord, -{ - type Cost = T::SumOpCost; +impl RewriteStrategy for T { + type Cost = T::OpCost; #[tracing::instrument(skip_all)] fn apply_rewrites( @@ -139,7 +142,7 @@ where .filter(|rw| { let pattern_cost = pre_rewrite_cost(rw, circ, |op| self.op_cost(op)); let target_cost = post_rewrite_cost(rw, circ, |op| self.op_cost(op)); - self.threshold(&pattern_cost, &target_cost) + self.under_threshold(&pattern_cost, &target_cost) }) .map(|rw| { let mut circ = circ.clone(); @@ -150,7 +153,11 @@ where } fn circuit_cost(&self, circ: &Hugr) -> Self::Cost { - cost(circ.nodes(), circ, |op| self.op_cost(op)) + circ.nodes_cost(circ.nodes(), |op| self.op_cost(op)) + } + + fn op_cost(&self, op: &OpType) -> Self::Cost { + ::op_cost(self, op) } } @@ -180,10 +187,9 @@ where C2: Fn(&OpType) -> usize, { type OpCost = MajorMinorCost; - type SumOpCost = MajorMinorCost; - fn threshold(&self, pattern_cost: &Self::SumOpCost, target_cost: &Self::SumOpCost) -> bool { - target_cost.major <= pattern_cost.major + fn under_threshold(&self, pattern_cost: &Self::OpCost, target_cost: &Self::OpCost) -> bool { + !target_cost.check_threshold(*pattern_cost) } fn op_cost(&self, op: &OpType) -> Self::OpCost { @@ -232,9 +238,8 @@ pub struct ExhaustiveGammaStrategy { impl usize> ExhaustiveThresholdStrategy for ExhaustiveGammaStrategy { type OpCost = usize; - type SumOpCost = usize; - fn threshold(&self, &pattern_cost: &Self::SumOpCost, &target_cost: &Self::SumOpCost) -> bool { + fn under_threshold(&self, &pattern_cost: &Self::OpCost, &target_cost: &Self::OpCost) -> bool { (target_cost as f64) < self.gamma * (pattern_cost as f64) } @@ -275,80 +280,20 @@ impl ExhaustiveGammaStrategy usize> { } } -/// A pair of major and minor cost. -/// -/// This is used to order circuits based on major cost first, then minor cost. -/// A typical example would be CX count as major cost and total gate count as -/// minor cost. -#[derive(Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord, From)] -pub struct MajorMinorCost { - major: usize, - minor: usize, -} - -// Serialise as string so that it is easy to write to CSV -impl serde::Serialize for MajorMinorCost { - fn serialize(&self, serializer: S) -> Result - where - S: serde::Serializer, - { - serializer.serialize_str(&format!("{:?}", self)) - } -} - -impl Debug for MajorMinorCost { - // TODO: A nicer print for the logs - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "(major={}, minor={})", self.major, self.minor) - } -} - -impl Sum for MajorMinorCost { - fn sum>(iter: I) -> Self { - iter.reduce(|a, b| (a.major + b.major, a.minor + b.minor).into()) - .unwrap_or_default() - } -} - -fn is_cx(op: &OpType) -> bool { - op_matches(op, T2Op::CX) -} - -fn is_quantum(op: &OpType) -> bool { - let Ok(op): Result = op.try_into() else { - return false; - }; - op.is_quantum() -} - -fn cost(nodes: impl IntoIterator, circ: &Hugr, op_cost: C) -> S -where - C: Fn(&OpType) -> T, - S: Sum, -{ - nodes - .into_iter() - .map(|n| { - let op = circ.get_optype(n); - op_cost(op) - }) - .sum() -} - -fn pre_rewrite_cost(rw: &CircuitRewrite, circ: &Hugr, pred: C) -> S +fn pre_rewrite_cost(rw: &CircuitRewrite, circ: &Hugr, pred: F) -> C where - C: Fn(&OpType) -> T, - S: Sum, + C: CircuitCost, + F: Fn(&OpType) -> C, { - cost(rw.subcircuit().nodes().iter().copied(), circ, pred) + circ.nodes_cost(rw.subcircuit().nodes().iter().copied(), pred) } -fn post_rewrite_cost(rw: &CircuitRewrite, circ: &Hugr, pred: C) -> S +fn post_rewrite_cost(rw: &CircuitRewrite, circ: &Hugr, pred: F) -> C where - C: Fn(&OpType) -> T, - S: Sum, + C: CircuitCost, + F: Fn(&OpType) -> C, { - cost(rw.replacement().nodes(), circ, pred) + circ.nodes_cost(rw.replacement().nodes(), pred) } #[cfg(test)] @@ -473,9 +418,9 @@ mod tests { #[test] fn test_exhaustive_default_cx_threshold() { let strat = NonIncreasingGateCountStrategy::default_cx(); - assert!(strat.threshold(&(3, 0).into(), &(3, 0).into())); - assert!(strat.threshold(&(3, 0).into(), &(3, 5).into())); - assert!(!strat.threshold(&(3, 10).into(), &(4, 0).into())); - assert!(strat.threshold(&(3, 0).into(), &(1, 5).into())); + assert!(strat.under_threshold(&(3, 0).into(), &(3, 0).into())); + assert!(strat.under_threshold(&(3, 0).into(), &(3, 5).into())); + assert!(!strat.under_threshold(&(3, 10).into(), &(4, 0).into())); + assert!(strat.under_threshold(&(3, 0).into(), &(1, 5).into())); } }