Skip to content

Commit

Permalink
feat: Python bindings for TASO
Browse files Browse the repository at this point in the history
  • Loading branch information
lmondada committed Sep 27, 2023
1 parent d068c3b commit b158bcf
Show file tree
Hide file tree
Showing 5 changed files with 108 additions and 16 deletions.
11 changes: 10 additions & 1 deletion pyrs/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
#![warn(missing_docs)]
use circuit::try_with_hugr;
use pyo3::prelude::*;
use tket2::{json::TKETDecode, passes::apply_greedy_commutation};
use tket2::{json::TKETDecode, optimiser::taso, passes::apply_greedy_commutation};
use tket_json_rs::circuit_json::SerialCircuit;

mod circuit;
Expand All @@ -22,6 +22,7 @@ fn pyrs(py: Python, m: &PyModule) -> PyResult<()> {
add_circuit_module(py, m)?;
add_pattern_module(py, m)?;
add_pass_module(py, m)?;
add_optimiser_module(py, m)?;
Ok(())
}

Expand Down Expand Up @@ -78,3 +79,11 @@ fn add_pass_module(py: Python, parent: &PyModule) -> PyResult<()> {
parent.add_submodule(m)?;
Ok(())
}

/// circuit optimisation module
fn add_optimiser_module(py: Python, parent: &PyModule) -> PyResult<()> {
let m = PyModule::new(py, "optimiser")?;
m.add_class::<taso::pyo3::PyDefaultTasoOptimiser>()?;

parent.add_submodule(m)
}
22 changes: 16 additions & 6 deletions src/optimiser/taso.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ mod eq_circ_class;
mod hugr_pchannel;
mod hugr_pqueue;
pub mod log;
#[cfg(feature = "pyo3")]
#[cfg(feature = "portmatching")]
pub mod pyo3;
mod qtz_circuit;
mod worker;

Expand All @@ -36,9 +39,6 @@ use crate::rewrite::Rewriter;

use self::log::TasoLogger;

#[cfg(feature = "portmatching")]
use std::io;

/// The TASO optimiser.
///
/// Adapted from [Quartz][], and originally [TASO][].
Expand Down Expand Up @@ -295,8 +295,11 @@ where
mod taso_default {
use hugr::ops::OpType;
use hugr::HugrView;
use std::io;
use std::path::Path;

use crate::ops::op_matches;
use crate::rewrite::ecc_rewriter::RewriterSerialisationError;
use crate::rewrite::strategy::ExhaustiveRewriteStrategy;
use crate::rewrite::ECCRewriter;
use crate::T2Op;
Expand All @@ -312,13 +315,20 @@ mod taso_default {

impl DefaultTasoOptimiser {
/// A sane default optimiser using the given ECC sets.
pub fn default_with_eccs_json_file(
eccs_path: impl AsRef<std::path::Path>,
) -> io::Result<Self> {
pub fn default_with_eccs_json_file(eccs_path: impl AsRef<Path>) -> io::Result<Self> {
let rewriter = ECCRewriter::try_from_eccs_json_file(eccs_path)?;
let strategy = ExhaustiveRewriteStrategy::exhaustive_cx();
Ok(TasoOptimiser::new(rewriter, strategy, num_cx_gates))
}

/// A sane default optimiser using a precompiled binary rewriter.
pub fn default_with_rewriter_binary(
rewriter_path: impl AsRef<Path>,
) -> Result<Self, RewriterSerialisationError> {
let rewriter = ECCRewriter::load_binary(rewriter_path)?;
let strategy = ExhaustiveRewriteStrategy::exhaustive_cx();
Ok(TasoOptimiser::new(rewriter, strategy, num_cx_gates))
}
}

fn num_cx_gates(circ: &Hugr) -> usize {
Expand Down
62 changes: 62 additions & 0 deletions src/optimiser/taso/pyo3.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
//! PyO3 wrapper for the TASO optimiser.
use std::{fs, num::NonZeroUsize};

use pyo3::{exceptions::PyTypeError, prelude::*};
use tket_json_rs::circuit_json::SerialCircuit;

use crate::{json::TKETDecode, utils::pyobj_as_hugr};

use super::{log::TasoLogger, DefaultTasoOptimiser};

/// Wrapped [`DefaultTasoOptimiser`].
///
/// Currently only exposes loading from an ECC file using the constructor
/// and optimising using default logging settings.
#[pyclass(name = "TasoOptimiser")]
pub struct PyDefaultTasoOptimiser(DefaultTasoOptimiser);

#[pymethods]
impl PyDefaultTasoOptimiser {
/// Create a new [`PyDefaultTasoOptimiser`] from a precompiled rewriter.
#[staticmethod]
pub fn load_precompiled(path: &str) -> Self {
Self(DefaultTasoOptimiser::default_with_rewriter_binary(path).unwrap())
}

/// Create a new [`PyDefaultTasoOptimiser`] from ECC sets.
///
/// This will compile the rewriter from the provided ECC JSON file.
#[staticmethod]
pub fn compile_eccs(path: &str) -> Self {
Self(DefaultTasoOptimiser::default_with_eccs_json_file(path).unwrap())
}

/// Run the optimiser on a circuit.
///
/// Returns an optimised circuit and log the progress to a CSV
/// file called "best_circs.csv".
pub fn optimise(
&mut self,
circ: PyObject,
timeout: Option<u64>,
n_threads: Option<NonZeroUsize>,
) -> PyResult<PyObject> {
let circ = pyobj_as_hugr(circ)?;
let circ_candidates_csv = fs::File::create("best_circs.csv").unwrap();

let taso_logger = TasoLogger::new(circ_candidates_csv);
self.0.optimise_with_log(
&circ,
taso_logger,
timeout,
n_threads.unwrap_or(NonZeroUsize::new(1).unwrap()),
);
let ser_circ =
SerialCircuit::encode(&circ).map_err(|e| PyTypeError::new_err(e.to_string()))?;
let tk1_circ = ser_circ
.to_tket1()
.map_err(|e| PyTypeError::new_err(e.to_string()))?;
Ok(tk1_circ)
}
}
11 changes: 2 additions & 9 deletions src/portmatching/pyo3.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,15 @@ use std::fmt;

use derive_more::{From, Into};
use hugr::hugr::views::sibling_subgraph::PyInvalidReplacementError;
use hugr::{Hugr, Port};
use hugr::Port;
use itertools::Itertools;
use portmatching::{HashMap, PatternID};
use pyo3::{prelude::*, types::PyIterator};
use tket_json_rs::circuit_json::SerialCircuit;

use super::{CircuitPattern, PatternMatch, PatternMatcher};
use crate::circuit::Circuit;
use crate::json::TKETDecode;
use crate::rewrite::CircuitRewrite;
use crate::utils::pyobj_as_hugr;

#[pymethods]
impl CircuitPattern {
Expand Down Expand Up @@ -196,9 +195,3 @@ impl Node {
format!("{:?}", self)
}
}

fn pyobj_as_hugr(circ: PyObject) -> PyResult<Hugr> {
let ser_c = SerialCircuit::_from_tket1(circ);
let hugr: Hugr = ser_c.decode()?;
Ok(hugr)
}
18 changes: 18 additions & 0 deletions src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,24 @@ pub(crate) fn type_is_linear(typ: &Type) -> bool {
!TypeBound::Copyable.contains(typ.least_upper_bound())
}

// Convert a pytket object to HUGR
#[cfg(feature = "pyo3")]
mod pyo3 {
use hugr::Hugr;
use pyo3::prelude::*;
use tket_json_rs::circuit_json::SerialCircuit;

use crate::json::TKETDecode;

pub(crate) fn pyobj_as_hugr(circ: PyObject) -> PyResult<Hugr> {
let ser_c = SerialCircuit::_from_tket1(circ);
let hugr: Hugr = ser_c.decode()?;
Ok(hugr)
}
}
#[cfg(feature = "pyo3")]
pub(crate) use pyo3::pyobj_as_hugr;

// utility for building simple qubit-only circuits.
#[allow(unused)]
pub(crate) fn build_simple_circuit(
Expand Down

0 comments on commit b158bcf

Please sign in to comment.