From 8b50e90b9aa87d32392ef9c0e7ac0ed1fcbab142 Mon Sep 17 00:00:00 2001 From: Luca Mondada <72734770+lmondada@users.noreply.github.com> Date: Thu, 14 Dec 2023 17:30:08 +0000 Subject: [PATCH] feat!: Add lexicographic cost (#270) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Agustín Borgna <121866228+aborgna-q@users.noreply.github.com> --- tket2/src/circuit/cost.rs | 155 +++++++++++++++++++++++++------------- 1 file changed, 101 insertions(+), 54 deletions(-) diff --git a/tket2/src/circuit/cost.rs b/tket2/src/circuit/cost.rs index 4346d168..c94ddef5 100644 --- a/tket2/src/circuit/cost.rs +++ b/tket2/src/circuit/cost.rs @@ -1,8 +1,8 @@ //! Cost definitions for a circuit. -use derive_more::From; use hugr::ops::OpType; -use std::fmt::{Debug, Display}; +use itertools::izip; +use std::fmt::Debug; use std::iter::Sum; use std::num::NonZeroUsize; use std::ops::{Add, AddAssign}; @@ -41,14 +41,32 @@ pub trait CostDelta: /// This is used to order circuits based on major cost first, then minor cost. /// A typical example would be CX count as major cost and total gate count as /// minor cost. -#[derive(Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord, From)] -pub struct MajorMinorCost { - major: T, - minor: T, +pub type MajorMinorCost = LexicographicCost; + +impl From for LexicographicCost +where + V: Into<[T; N]>, +{ + fn from(v: V) -> Self { + Self(v.into()) + } +} + +/// A cost that is ordered lexicographically. +/// +/// An array of cost functions, where the first one is infinitely more important +/// than the second, which is infinitely more important than the third, etc. +#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] +pub struct LexicographicCost([T; N]); + +impl Default for LexicographicCost { + fn default() -> Self { + Self([Default::default(); N]) + } } // Serialise as string so that it is easy to write to CSV -impl serde::Serialize for MajorMinorCost { +impl serde::Serialize for LexicographicCost { fn serialize(&self, serializer: S) -> Result where S: serde::Serializer, @@ -57,70 +75,86 @@ impl serde::Serialize for MajorMinorCost { } } -impl Debug for MajorMinorCost { +impl Debug for LexicographicCost { // TODO: A nicer print for the logs fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "(major={}, minor={})", self.major, self.minor) + write!(f, "{:?}", self.0) } } -impl> Add for MajorMinorCost { - type Output = MajorMinorCost; +impl + Copy, const N: usize> Add for LexicographicCost { + type Output = Self; - fn add(self, rhs: MajorMinorCost) -> Self::Output { - (self.major + rhs.major, self.minor + rhs.minor).into() + fn add(mut self, rhs: Self) -> Self::Output { + for i in 0..N { + self.0[i] = self.0[i] + rhs.0[i]; + } + self } } -impl AddAssign for MajorMinorCost { +impl AddAssign for LexicographicCost { fn add_assign(&mut self, rhs: Self) { - self.major += rhs.major; - self.minor += rhs.minor; + for i in 0..N { + self.0[i] += rhs.0[i]; + } } } -impl + Default> Sum for MajorMinorCost { +impl + Default + Copy, const N: usize> Sum for LexicographicCost { fn sum>(iter: I) -> Self { - iter.reduce(|a, b| (a.major + b.major, a.minor + b.minor).into()) - .unwrap_or_default() + iter.reduce(|a, b| a + b).unwrap_or_default() } } -impl CostDelta for MajorMinorCost { +impl CostDelta for LexicographicCost { #[inline] fn as_isize(&self) -> isize { - self.major + if N > 0 { + self.0[0] + } else { + 0 + } } } -impl CircuitCost for MajorMinorCost { - type CostDelta = MajorMinorCost; +impl CircuitCost for LexicographicCost { + type CostDelta = LexicographicCost; #[inline] fn as_usize(&self) -> usize { - self.major + if N > 0 { + self.0[0] + } else { + 0 + } } #[inline] fn sub_cost(&self, other: &Self) -> Self::CostDelta { - let major = (self.major as isize) - (other.major as isize); - let minor = (self.minor as isize) - (other.minor as isize); - MajorMinorCost { major, minor } + let mut costdelta = [0; N]; + for (delta, &a, &b) in izip!(costdelta.iter_mut(), &self.0, &other.0) { + *delta = (a as isize) - (b as isize); + } + LexicographicCost(costdelta) } #[inline] fn add_delta(&self, delta: &Self::CostDelta) -> Self { - MajorMinorCost { - major: self.major.saturating_add_signed(delta.major), - minor: self.minor.saturating_add_signed(delta.minor), + let mut ret = [0; N]; + for (add, &a, &b) in izip!(ret.iter_mut(), &self.0, &delta.0) { + *add = a.saturating_add_signed(b); } + Self(ret) } #[inline] fn div_cost(&self, n: NonZeroUsize) -> Self { - let major = (self.major.saturating_sub(1)) / n.get() + 1; - let minor = (self.minor.saturating_sub(1)) / n.get() + 1; - Self { major, minor } + let mut ret = [0; N]; + for (div, &a) in ret.iter_mut().zip(&self.0) { + *div = (a.saturating_sub(1)) / n.get() + 1; + } + Self(ret) } } @@ -174,38 +208,51 @@ mod tests { #[test] fn major_minor() { - let a = MajorMinorCost { - major: 10, - minor: 2, - }; - let b = MajorMinorCost { - major: 20, - minor: 1, - }; + let a = LexicographicCost([10, 2]); + let b = LexicographicCost([20, 1]); assert!(a < b); - assert_eq!( - a + b, - MajorMinorCost { - major: 30, - minor: 3 - } - ); + assert_eq!(a + b, LexicographicCost([30, 3])); assert_eq!(a.sub_cost(&b).as_isize(), -10); assert_eq!(b.sub_cost(&a).as_isize(), 10); assert_eq!( a.div_cost(NonZeroUsize::new(2).unwrap()), - MajorMinorCost { major: 5, minor: 1 } + LexicographicCost([5, 1]) ); assert_eq!( a.div_cost(NonZeroUsize::new(3).unwrap()), - MajorMinorCost { major: 4, minor: 1 } + LexicographicCost([4, 1]) ); assert_eq!( a.div_cost(NonZeroUsize::new(1).unwrap()), - MajorMinorCost { - major: 10, - minor: 2 - } + LexicographicCost([10, 2]) ); } + + #[test] + fn zero_dim_cost() { + let a = LexicographicCost::([]); + let b = LexicographicCost::([]); + assert_eq!(a, b); + assert_eq!(a + b, LexicographicCost::([])); + assert_eq!(a.sub_cost(&b).as_isize(), 0); + assert_eq!(b.sub_cost(&a).as_isize(), 0); + assert_eq!(a.div_cost(NonZeroUsize::new(2).unwrap()), a); + assert_eq!(a.div_cost(NonZeroUsize::new(3).unwrap()), a); + assert_eq!(a.div_cost(NonZeroUsize::new(1).unwrap()), a); + } + + #[test] + fn as_usize() { + let a = LexicographicCost([10, 2]); + assert_eq!(a.as_usize(), 10); + let a = LexicographicCost::([]); + assert_eq!(a.as_usize(), 0); + } + + #[test] + fn serde_serialize() { + let a = LexicographicCost([10, 2]); + let s = serde_json::to_string(&a).unwrap(); + assert_eq!(s, "\"[10, 2]\""); + } }