Skip to content

Commit

Permalink
feat: Circuit cost module and methods
Browse files Browse the repository at this point in the history
  • Loading branch information
aborgna-q committed Oct 3, 2023
1 parent 8e6692c commit d2a7304
Show file tree
Hide file tree
Showing 5 changed files with 242 additions and 111 deletions.
24 changes: 24 additions & 0 deletions src/circuit.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
//! Quantum circuit representation and operations.
pub mod command;
pub mod cost;
mod hash;
pub mod units;

Expand All @@ -22,6 +23,7 @@ use itertools::Itertools;
use portgraph::Direction;
use thiserror::Error;

use self::cost::CircuitCost;
use self::units::{filter, FilteredUnits, Units};

/// An object behaving like a quantum circuit.
Expand Down Expand Up @@ -126,6 +128,28 @@ pub trait Circuit: HugrView {
// Traverse the circuit in topological order.
CommandIterator::new(self)
}

/// Compute the cost of the circuit based on a per-operation cost function.
#[inline]
fn circuit_cost<F, C>(&self, op_cost: F) -> C
where
Self: Sized,
C: CircuitCost,
F: Fn(&OpType) -> C,
{
self.commands().map(|cmd| op_cost(cmd.optype())).sum()
}

/// Compute the cost of a group of nodes in a circuit based on a
/// per-operation cost function.
#[inline]
fn nodes_cost<F, C>(&self, nodes: impl IntoIterator<Item = Node>, op_cost: F) -> C
where
C: CircuitCost,
F: Fn(&OpType) -> C,
{
nodes.into_iter().map(|n| op_cost(self.get_optype(n))).sum()
}
}

/// Remove an empty wire in a dataflow HUGR.
Expand Down
102 changes: 102 additions & 0 deletions src/circuit/cost.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
//! Cost definitions for a circuit.
use derive_more::From;
use hugr::ops::OpType;
use std::fmt::Debug;
use std::iter::Sum;
use std::num::NonZeroUsize;
use std::ops::Add;

use crate::ops::op_matches;
use crate::T2Op;

/// The cost for a group of operations in a circuit, each with cost `OpCost`.
pub trait CircuitCost: Add<Output = Self> + Sum<Self> + Debug + Default + Clone + Ord {
/// Returns true if the cost is above the threshold.
fn check_threshold(self, threshold: Self) -> bool;

/// Divide the cost, rounded up.
fn div_cost(self, n: NonZeroUsize) -> Self;
}

/// A pair of major and minor cost.
///
/// This is used to order circuits based on major cost first, then minor cost.
/// A typical example would be CX count as major cost and total gate count as
/// minor cost.
#[derive(Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord, From)]
pub struct MajorMinorCost {
major: usize,
minor: usize,
}

// Serialise as string so that it is easy to write to CSV
impl serde::Serialize for MajorMinorCost {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
serializer.serialize_str(&format!("{:?}", self))
}
}

impl Debug for MajorMinorCost {
// TODO: A nicer print for the logs
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "(major={}, minor={})", self.major, self.minor)
}
}

impl Add<MajorMinorCost> for MajorMinorCost {
type Output = MajorMinorCost;

fn add(self, rhs: MajorMinorCost) -> Self::Output {
(self.major + rhs.major, self.minor + rhs.minor).into()
}
}

impl Sum for MajorMinorCost {
fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
iter.reduce(|a, b| (a.major + b.major, a.minor + b.minor).into())
.unwrap_or_default()
}
}

impl CircuitCost for MajorMinorCost {
#[inline]
fn check_threshold(self, threshold: Self) -> bool {
self.major > threshold.major
}

#[inline]
fn div_cost(mut self, n: NonZeroUsize) -> Self {
self.major = (self.major.saturating_sub(1)) / n.get() + 1;
self.minor = (self.minor.saturating_sub(1)) / n.get() + 1;
self
}
}

impl CircuitCost for usize {
#[inline]
fn check_threshold(self, threshold: Self) -> bool {
self > threshold
}

#[inline]
fn div_cost(self, n: NonZeroUsize) -> Self {
(self.saturating_sub(1)) / n.get() + 1
}
}

/// Returns true if the operation is a controlled X operation.
pub fn is_cx(op: &OpType) -> bool {
op_matches(op, T2Op::CX)
}

/// Returns true if the operation is a quantum operation.
pub fn is_quantum(op: &OpType) -> bool {
let Ok(op): Result<T2Op, _> = op.try_into() else {
return false;
};
op.is_quantum()
}
3 changes: 1 addition & 2 deletions src/optimiser/taso.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ use crossbeam_channel::select;
pub use eq_circ_class::{load_eccs_json_file, EqCircClass};
pub use log::TasoLogger;

use std::fmt;
use std::num::NonZeroUsize;
use std::time::{Duration, Instant};

Expand Down Expand Up @@ -76,7 +75,7 @@ impl<R, S> TasoOptimiser<R, S>
where
R: Rewriter + Send + Clone + 'static,
S: RewriteStrategy + Send + Sync + Clone + 'static,
S::Cost: fmt::Debug + serde::Serialize,
S::Cost: serde::Serialize,
{
/// Run the TASO optimiser on a circuit.
///
Expand Down
89 changes: 75 additions & 14 deletions src/passes/chunks.rs
Original file line number Diff line number Diff line change
@@ -1,21 +1,26 @@
//! Utility
//! This module provides a utility to split a circuit into chunks, and reassemble them afterwards.
//!
//! See [`CircuitChunks`] for more information.
use std::collections::HashMap;
use std::mem;
use std::ops::{Index, IndexMut};

use hugr::builder::{Dataflow, DataflowHugr, FunctionBuilder};
use hugr::builder::{Container, FunctionBuilder};
use hugr::extension::ExtensionSet;
use hugr::hugr::hugrmut::HugrMut;
use hugr::hugr::views::sibling_subgraph::ConvexChecker;
use hugr::hugr::views::{HierarchyView, SiblingGraph, SiblingSubgraph};
use hugr::hugr::{HugrError, NodeMetadata};
use hugr::ops::handle::DataflowParentID;
use hugr::ops::OpType;
use hugr::types::{FunctionType, Signature};
use hugr::{Hugr, HugrView, Node, Port, Wire};
use itertools::Itertools;

use crate::extension::REGISTRY;
use crate::Circuit;

use crate::circuit::cost::CircuitCost;
#[cfg(feature = "pyo3")]
use crate::json::TKETDecode;
#[cfg(feature = "pyo3")]
Expand All @@ -38,9 +43,9 @@ pub struct Chunk {
/// The extracted circuit.
pub circ: Hugr,
/// The original wires connected to the input.
pub inputs: Vec<ChunkConnection>,
inputs: Vec<ChunkConnection>,
/// The original wires connected to the output.
pub outputs: Vec<ChunkConnection>,
outputs: Vec<ChunkConnection>,
}

impl Chunk {
Expand Down Expand Up @@ -145,7 +150,12 @@ struct ChunkInsertResult {
pub outgoing_connections: HashMap<ChunkConnection, (Node, Port)>,
}

/// An utility for splitting a circuit into chunks, and reassembling them afterwards.
/// An utility for splitting a circuit into chunks, and reassembling them
/// afterwards.
///
/// Circuits can be split into [`CircuitChunks`] with [`CircuitChunks::split`]
/// or [`CircuitChunks::split_with_cost`], and reassembled with
/// [`CircuitChunks::reassemble`].
#[derive(Debug, Clone)]
#[cfg_attr(feature = "pyo3", pyclass)]
pub struct CircuitChunks {
Expand All @@ -170,6 +180,17 @@ impl CircuitChunks {
///
/// The circuit is split into chunks of at most `max_size` gates.
pub fn split(circ: &impl Circuit, max_size: usize) -> Self {
Self::split_with_cost(circ, max_size, |_| 1)
}

/// Split a circuit into chunks.
///
/// The circuit is split into chunks of at most `max_cost`, using the provided cost function.
pub fn split_with_cost<H: Circuit, C: CircuitCost>(
circ: &H,
max_cost: C,
op_cost: impl Fn(&OpType) -> C,
) -> Self {
let root_meta = circ.get_metadata(circ.root()).clone();
let signature = circ.circuit_signature().clone();

Expand All @@ -186,7 +207,18 @@ impl CircuitChunks {

let mut chunks = Vec::new();
let mut convex_checker = ConvexChecker::new(circ);
for commands in &circ.commands().map(|cmd| cmd.node()).chunks(max_size) {
let mut running_cost = C::default();
let mut current_group = 0;
for (_, commands) in &circ.commands().map(|cmd| cmd.node()).group_by(|&node| {
let new_cost = running_cost.clone() + op_cost(circ.get_optype(node));
if new_cost.clone().check_threshold(max_cost.clone()) {
running_cost = C::default();
current_group += 1;
} else {
running_cost = new_cost;
}
current_group
}) {
chunks.push(Chunk::extract(circ, commands, &mut convex_checker));
}
Self {
Expand All @@ -211,10 +243,10 @@ impl CircuitChunks {
input_extensions: ExtensionSet::new(),
};

let builder = FunctionBuilder::new(name, signature).unwrap();
let inputs = builder.input_wires();
// TODO: Use the correct REGISTRY if the method accepts custom input resources.
let mut reassembled = builder.finish_hugr_with_outputs(inputs, &REGISTRY).unwrap();
let mut builder = FunctionBuilder::new(name, signature).unwrap();
// Take the unfinished Hugr from the builder, to avoid unnecessary
// validation checks that require connecting the inputs an outputs.
let mut reassembled = mem::take(builder.hugr_mut());
let root = reassembled.root();
let [reassembled_input, reassembled_output] = reassembled.get_io(root).unwrap();

Expand All @@ -229,7 +261,6 @@ impl CircuitChunks {
.iter()
.zip(reassembled.node_outputs(reassembled_input))
{
reassembled.disconnect(reassembled_input, port)?;
sources.insert(connection, (reassembled_input, port));
}
for (&connection, port) in self
Expand Down Expand Up @@ -269,9 +300,24 @@ impl CircuitChunks {
}

/// Returns a list of references to the split circuits.
pub fn circuits(&self) -> impl Iterator<Item = &Hugr> {
pub fn iter(&self) -> impl Iterator<Item = &Hugr> {
self.chunks.iter().map(|chunk| &chunk.circ)
}

/// Returns a list of references to the split circuits.
pub fn iter_mut(&mut self) -> impl Iterator<Item = &mut Hugr> {
self.chunks.iter_mut().map(|chunk| &mut chunk.circ)
}

/// Returns the number of chunks.
pub fn len(&self) -> usize {
self.chunks.len()
}

/// Returns `true` if there are no chunks.
pub fn is_empty(&self) -> bool {
self.chunks.is_empty()
}
}

#[cfg(feature = "pyo3")]
Expand All @@ -287,7 +333,7 @@ impl CircuitChunks {
/// Returns clones of the split circuits.
#[pyo3(name = "circuits")]
fn py_circuits(&self) -> PyResult<Vec<Py<PyAny>>> {
self.circuits()
self.iter()
.map(|hugr| SerialCircuit::encode(hugr)?.to_tket1())
.collect()
}
Expand All @@ -306,9 +352,24 @@ impl CircuitChunks {
}
}

impl Index<usize> for CircuitChunks {
type Output = Hugr;

fn index(&self, index: usize) -> &Self::Output {
&self.chunks[index].circ
}
}

impl IndexMut<usize> for CircuitChunks {
fn index_mut(&mut self, index: usize) -> &mut Self::Output {
&mut self.chunks[index].circ
}
}

#[cfg(test)]
mod test {
use crate::circuit::CircuitHash;
use crate::extension::REGISTRY;
use crate::utils::build_simple_circuit;
use crate::T2Op;

Expand Down
Loading

0 comments on commit d2a7304

Please sign in to comment.