Skip to content

Commit

Permalink
feat: bindings for circuit cost and hash (#252)
Browse files Browse the repository at this point in the history
- Adds `Tk2Circuit::hash` to the python bindings.
- Adds bindings for computing the cost of a circuit.
- Adds tests
  - I split `test_bindings.py` into `test_circuit` and `test_pass`
  • Loading branch information
aborgna-q authored Nov 23, 2023
1 parent 60c6608 commit 85ce5f9
Show file tree
Hide file tree
Showing 7 changed files with 374 additions and 102 deletions.
8 changes: 6 additions & 2 deletions tket2-py/src/circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#![allow(unused)]

pub mod convert;
pub mod cost;

use derive_more::{From, Into};
use pyo3::prelude::*;
Expand All @@ -14,14 +15,17 @@ use tket2::rewrite::CircuitRewrite;
use tket_json_rs::circuit_json::SerialCircuit;

pub use self::convert::{try_update_hugr, try_with_hugr, update_hugr, with_hugr, Tk2Circuit};
pub use self::cost::PyCircuitCost;
pub use tket2::{Pauli, Tk2Op};

/// The module definition
pub fn module(py: Python) -> PyResult<&PyModule> {
let m = PyModule::new(py, "_circuit")?;
m.add_class::<Tk2Circuit>()?;
m.add_class::<PyNode>()?;
m.add_class::<tket2::Tk2Op>()?;
m.add_class::<tket2::Pauli>()?;
m.add_class::<PyCircuitCost>()?;
m.add_class::<Tk2Op>()?;
m.add_class::<Pauli>()?;

m.add_function(wrap_pyfunction!(validate_hugr, m)?)?;
m.add_function(wrap_pyfunction!(to_hugr_dot, m)?)?;
Expand Down
74 changes: 71 additions & 3 deletions tket2-py/src/circuit/convert.rs
Original file line number Diff line number Diff line change
@@ -1,19 +1,43 @@
//! Utilities for calling Hugr functions on generic python objects.
use pyo3::exceptions::PyAttributeError;
use hugr::ops::OpType;
use pyo3::exceptions::{PyAttributeError, PyValueError};
use pyo3::{prelude::*, PyTypeInfo};

use derive_more::From;
use hugr::{Hugr, HugrView};
use serde::Serialize;
use tket2::circuit::CircuitHash;
use tket2::extension::REGISTRY;
use tket2::json::TKETDecode;
use tket2::passes::CircuitChunks;
use tket2::{Circuit, Tk2Op};
use tket_json_rs::circuit_json::SerialCircuit;

use crate::rewrite::PyCircuitRewrite;

/// A manager for tket 2 operations on a tket 1 Circuit.
use super::{cost, PyCircuitCost};

/// A circuit in tket2 format.
///
/// This can be freely converted to and from a `pytket.Circuit`. Prefer using
/// this class when applying multiple tket2 operations on a circuit, as it
/// avoids the overhead of converting to and from a `pytket.Circuit` each time.
///
/// Node indices returned by this class are not stable across conversion to and
/// from a `pytket.Circuit`.
///
/// # Examples
///
/// Convert between `pytket.Circuit`s and `Tk2Circuit`s:
/// ```python
/// from pytket import Circuit
/// c = Circuit(2).H(0).CX(0, 1)
/// # Convert to a Tk2Circuit
/// t2c = Tk2Circuit(c)
/// # Convert back to a pytket.Circuit
/// c2 = t2c.to_tket1()
/// ```
#[pyclass]
#[derive(Clone, Debug, PartialEq, From)]
pub struct Tk2Circuit {
Expand All @@ -37,7 +61,7 @@ impl Tk2Circuit {
}

/// Apply a rewrite on the circuit.
pub fn apply_match(&mut self, rw: PyCircuitRewrite) {
pub fn apply_rewrite(&mut self, rw: PyCircuitRewrite) {
rw.rewrite.apply(&mut self.hugr).expect("Apply error.");
}

Expand Down Expand Up @@ -73,6 +97,50 @@ impl Tk2Circuit {
hugr: tk1.decode()?,
})
}

/// Compute the cost of the circuit based on a per-operation cost function.
///
/// :param cost_fn: A function that takes a `Tk2Op` and returns an arbitrary cost.
/// The cost must implement `__add__`, `__sub__`, `__lt__`,
/// `__eq__`, `__int__`, and integer `__div__`.
///
/// :returns: The sum of all operation costs.
pub fn circuit_cost<'py>(&self, cost_fn: &'py PyAny) -> PyResult<&'py PyAny> {
let py = cost_fn.py();
let cost_fn = |op: &OpType| -> PyResult<PyCircuitCost> {
let tk2_op: Tk2Op = op.try_into().map_err(|e| {
PyErr::new::<PyValueError, _>(format!(
"Could not convert circuit operation to a `Tk2Op`: {e}"
))
})?;
let cost = cost_fn.call1((tk2_op,))?;
Ok(PyCircuitCost {
cost: cost.to_object(py),
})
};
let circ_cost = self.hugr.circuit_cost(cost_fn)?;
Ok(circ_cost.cost.into_ref(py))
}

/// Returns a hash of the circuit.
pub fn hash(&self) -> u64 {
self.hugr.circuit_hash().unwrap()
}

/// Hash the circuit
pub fn __hash__(&self) -> isize {
self.hash() as isize
}

/// Copy the circuit.
pub fn __copy__(&self) -> PyResult<Self> {
Ok(self.clone())
}

/// Copy the circuit.
pub fn __deepcopy__(&self, _memo: Py<PyAny>) -> PyResult<Self> {
Ok(self.clone())
}
}
impl Tk2Circuit {
/// Tries to extract a Tk2Circuit from a python object.
Expand Down
167 changes: 167 additions & 0 deletions tket2-py/src/circuit/cost.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
//!
use std::cmp::Ordering;
use std::iter::Sum;
use std::ops::{Add, AddAssign, Sub};

use pyo3::{prelude::*, PyTypeInfo};
use tket2::circuit::cost::{CircuitCost, CostDelta};

/// A generic circuit cost, backed by an arbitrary python object.
#[pyclass]
#[derive(Clone, Debug)]
#[pyo3(name = "CircuitCost")]
pub struct PyCircuitCost {
/// Generic python cost object.
pub cost: PyObject,
}

#[pymethods]
impl PyCircuitCost {
/// Create a new circuit cost.
#[new]
pub fn new(cost: PyObject) -> Self {
Self { cost }
}
}

impl Default for PyCircuitCost {
fn default() -> Self {
Python::with_gil(|py| PyCircuitCost { cost: py.None() })
}
}

impl Add for PyCircuitCost {
type Output = PyCircuitCost;

fn add(self, rhs: PyCircuitCost) -> Self::Output {
Python::with_gil(|py| {
let cost = self
.cost
.call_method1(py, "__add__", (rhs.cost,))
.expect("Could not add circuit cost objects.");
PyCircuitCost { cost }
})
}
}

impl AddAssign for PyCircuitCost {
fn add_assign(&mut self, rhs: Self) {
Python::with_gil(|py| {
let cost = self
.cost
.call_method1(py, "__add__", (rhs.cost,))
.expect("Could not add circuit cost objects.");
self.cost = cost;
})
}
}

impl Sub for PyCircuitCost {
type Output = PyCircuitCost;

fn sub(self, rhs: PyCircuitCost) -> Self::Output {
Python::with_gil(|py| {
let cost = self
.cost
.call_method1(py, "__sub__", (rhs.cost,))
.expect("Could not subtract circuit cost objects.");
PyCircuitCost { cost }
})
}
}

impl Sum for PyCircuitCost {
fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
Python::with_gil(|py| {
let cost = iter
.fold(None, |acc: Option<PyObject>, c| {
Some(match acc {
None => c.cost,
Some(cost) => cost
.call_method1(py, "__add__", (c.cost,))
.expect("Could not add circuit cost objects."),
})
})
.unwrap_or_else(|| py.None());
PyCircuitCost { cost }
})
}
}

impl PartialEq for PyCircuitCost {
fn eq(&self, other: &Self) -> bool {
Python::with_gil(|py| {
let res = self
.cost
.call_method1(py, "__eq__", (&other.cost,))
.expect("Could not compare circuit cost objects.");
res.is_true(py)
.expect("Could not compare circuit cost objects.")
})
}
}

impl Eq for PyCircuitCost {}

impl PartialOrd for PyCircuitCost {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}

impl Ord for PyCircuitCost {
fn cmp(&self, other: &Self) -> Ordering {
Python::with_gil(|py| -> PyResult<Ordering> {
let res = self.cost.call_method1(py, "__lt__", (&other.cost,))?;
if res.is_true(py)? {
return Ok(Ordering::Less);
}
let res = self.cost.call_method1(py, "__eq__", (&other.cost,))?;
if res.is_true(py)? {
return Ok(Ordering::Equal);
}
Ok(Ordering::Greater)
})
.expect("Could not compare circuit cost objects.")
}
}

impl CostDelta for PyCircuitCost {
fn as_isize(&self) -> isize {
Python::with_gil(|py| {
let res = self
.cost
.call_method0(py, "__int__")
.expect("Could not convert the circuit cost object to an integer.");
res.extract(py)
.expect("Could not convert the circuit cost object to an integer.")
})
}
}

impl CircuitCost for PyCircuitCost {
type CostDelta = PyCircuitCost;

fn as_usize(&self) -> usize {
self.as_isize() as usize
}

fn sub_cost(&self, other: &Self) -> Self::CostDelta {
self.clone() - other.clone()
}

fn add_delta(&self, delta: &Self::CostDelta) -> Self {
self.clone() + delta.clone()
}

fn div_cost(&self, n: std::num::NonZeroUsize) -> Self {
Python::with_gil(|py| {
let res = self
.cost
.call_method0(py, "__div__")
.expect("Could not divide the circuit cost object.");
Self { cost: res }
})
}
}
93 changes: 0 additions & 93 deletions tket2-py/test/test_bindings.py

This file was deleted.

Loading

0 comments on commit 85ce5f9

Please sign in to comment.