Skip to content

Commit

Permalink
feat!: Add lexicographic cost (#270)
Browse files Browse the repository at this point in the history
Co-authored-by: Agustín Borgna <[email protected]>
  • Loading branch information
lmondada and aborgna-q authored Dec 14, 2023
1 parent c2875d8 commit 8b50e90
Showing 1 changed file with 101 additions and 54 deletions.
155 changes: 101 additions & 54 deletions tket2/src/circuit/cost.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
//! Cost definitions for a circuit.
use derive_more::From;
use hugr::ops::OpType;
use std::fmt::{Debug, Display};
use itertools::izip;
use std::fmt::Debug;
use std::iter::Sum;
use std::num::NonZeroUsize;
use std::ops::{Add, AddAssign};
Expand Down Expand Up @@ -41,14 +41,32 @@ pub trait CostDelta:
/// This is used to order circuits based on major cost first, then minor cost.
/// A typical example would be CX count as major cost and total gate count as
/// minor cost.
#[derive(Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord, From)]
pub struct MajorMinorCost<T = usize> {
major: T,
minor: T,
pub type MajorMinorCost<T = usize> = LexicographicCost<T, 2>;

impl<const N: usize, V, T> From<V> for LexicographicCost<T, N>
where
V: Into<[T; N]>,
{
fn from(v: V) -> Self {
Self(v.into())
}
}

/// A cost that is ordered lexicographically.
///
/// An array of cost functions, where the first one is infinitely more important
/// than the second, which is infinitely more important than the third, etc.
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub struct LexicographicCost<T, const N: usize>([T; N]);

impl<const N: usize, T: Default + Copy> Default for LexicographicCost<T, N> {
fn default() -> Self {
Self([Default::default(); N])
}
}

// Serialise as string so that it is easy to write to CSV
impl serde::Serialize for MajorMinorCost {
impl<const N: usize> serde::Serialize for LexicographicCost<usize, N> {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
Expand All @@ -57,70 +75,86 @@ impl serde::Serialize for MajorMinorCost {
}
}

impl<T: Display> Debug for MajorMinorCost<T> {
impl<T: Debug, const N: usize> Debug for LexicographicCost<T, N> {
// TODO: A nicer print for the logs
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "(major={}, minor={})", self.major, self.minor)
write!(f, "{:?}", self.0)
}
}

impl<T: Add<Output = T>> Add for MajorMinorCost<T> {
type Output = MajorMinorCost<T>;
impl<T: Add<Output = T> + Copy, const N: usize> Add for LexicographicCost<T, N> {
type Output = Self;

fn add(self, rhs: MajorMinorCost<T>) -> Self::Output {
(self.major + rhs.major, self.minor + rhs.minor).into()
fn add(mut self, rhs: Self) -> Self::Output {
for i in 0..N {
self.0[i] = self.0[i] + rhs.0[i];
}
self
}
}

impl<T: AddAssign> AddAssign for MajorMinorCost<T> {
impl<T: AddAssign + Copy, const N: usize> AddAssign for LexicographicCost<T, N> {
fn add_assign(&mut self, rhs: Self) {
self.major += rhs.major;
self.minor += rhs.minor;
for i in 0..N {
self.0[i] += rhs.0[i];
}
}
}

impl<T: Add<Output = T> + Default> Sum for MajorMinorCost<T> {
impl<T: Add<Output = T> + Default + Copy, const N: usize> Sum for LexicographicCost<T, N> {
fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
iter.reduce(|a, b| (a.major + b.major, a.minor + b.minor).into())
.unwrap_or_default()
iter.reduce(|a, b| a + b).unwrap_or_default()
}
}

impl CostDelta for MajorMinorCost<isize> {
impl<const N: usize> CostDelta for LexicographicCost<isize, N> {
#[inline]
fn as_isize(&self) -> isize {
self.major
if N > 0 {
self.0[0]
} else {
0
}
}
}

impl CircuitCost for MajorMinorCost<usize> {
type CostDelta = MajorMinorCost<isize>;
impl<const N: usize> CircuitCost for LexicographicCost<usize, N> {
type CostDelta = LexicographicCost<isize, N>;

#[inline]
fn as_usize(&self) -> usize {
self.major
if N > 0 {
self.0[0]
} else {
0
}
}

#[inline]
fn sub_cost(&self, other: &Self) -> Self::CostDelta {
let major = (self.major as isize) - (other.major as isize);
let minor = (self.minor as isize) - (other.minor as isize);
MajorMinorCost { major, minor }
let mut costdelta = [0; N];
for (delta, &a, &b) in izip!(costdelta.iter_mut(), &self.0, &other.0) {
*delta = (a as isize) - (b as isize);
}
LexicographicCost(costdelta)
}

#[inline]
fn add_delta(&self, delta: &Self::CostDelta) -> Self {
MajorMinorCost {
major: self.major.saturating_add_signed(delta.major),
minor: self.minor.saturating_add_signed(delta.minor),
let mut ret = [0; N];
for (add, &a, &b) in izip!(ret.iter_mut(), &self.0, &delta.0) {
*add = a.saturating_add_signed(b);
}
Self(ret)
}

#[inline]
fn div_cost(&self, n: NonZeroUsize) -> Self {
let major = (self.major.saturating_sub(1)) / n.get() + 1;
let minor = (self.minor.saturating_sub(1)) / n.get() + 1;
Self { major, minor }
let mut ret = [0; N];
for (div, &a) in ret.iter_mut().zip(&self.0) {
*div = (a.saturating_sub(1)) / n.get() + 1;
}
Self(ret)
}
}

Expand Down Expand Up @@ -174,38 +208,51 @@ mod tests {

#[test]
fn major_minor() {
let a = MajorMinorCost {
major: 10,
minor: 2,
};
let b = MajorMinorCost {
major: 20,
minor: 1,
};
let a = LexicographicCost([10, 2]);
let b = LexicographicCost([20, 1]);
assert!(a < b);
assert_eq!(
a + b,
MajorMinorCost {
major: 30,
minor: 3
}
);
assert_eq!(a + b, LexicographicCost([30, 3]));
assert_eq!(a.sub_cost(&b).as_isize(), -10);
assert_eq!(b.sub_cost(&a).as_isize(), 10);
assert_eq!(
a.div_cost(NonZeroUsize::new(2).unwrap()),
MajorMinorCost { major: 5, minor: 1 }
LexicographicCost([5, 1])
);
assert_eq!(
a.div_cost(NonZeroUsize::new(3).unwrap()),
MajorMinorCost { major: 4, minor: 1 }
LexicographicCost([4, 1])
);
assert_eq!(
a.div_cost(NonZeroUsize::new(1).unwrap()),
MajorMinorCost {
major: 10,
minor: 2
}
LexicographicCost([10, 2])
);
}

#[test]
fn zero_dim_cost() {
let a = LexicographicCost::<usize, 0>([]);
let b = LexicographicCost::<usize, 0>([]);
assert_eq!(a, b);
assert_eq!(a + b, LexicographicCost::<usize, 0>([]));
assert_eq!(a.sub_cost(&b).as_isize(), 0);
assert_eq!(b.sub_cost(&a).as_isize(), 0);
assert_eq!(a.div_cost(NonZeroUsize::new(2).unwrap()), a);
assert_eq!(a.div_cost(NonZeroUsize::new(3).unwrap()), a);
assert_eq!(a.div_cost(NonZeroUsize::new(1).unwrap()), a);
}

#[test]
fn as_usize() {
let a = LexicographicCost([10, 2]);
assert_eq!(a.as_usize(), 10);
let a = LexicographicCost::<usize, 0>([]);
assert_eq!(a.as_usize(), 0);
}

#[test]
fn serde_serialize() {
let a = LexicographicCost([10, 2]);
let s = serde_json::to_string(&a).unwrap();
assert_eq!(s, "\"[10, 2]\"");
}
}

0 comments on commit 8b50e90

Please sign in to comment.