Skip to content

Commit

Permalink
feat: Python bindings for TASO (#142)
Browse files Browse the repository at this point in the history
  • Loading branch information
lmondada authored Sep 29, 2023
1 parent ae8b281 commit ca2d3df
Show file tree
Hide file tree
Showing 4 changed files with 85 additions and 9 deletions.
3 changes: 3 additions & 0 deletions pyrs/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
//! Python bindings for TKET2.
#![warn(missing_docs)]
use circuit::{add_circuit_module, try_with_hugr};
use optimiser::add_optimiser_module;
use pyo3::prelude::*;
use tket2::{json::TKETDecode, passes::apply_greedy_commutation};
use tket_json_rs::circuit_json::SerialCircuit;

mod circuit;
mod optimiser;

#[pyfunction]
fn greedy_depth_reduce(py_c: PyObject) -> PyResult<(PyObject, u32)> {
Expand All @@ -22,6 +24,7 @@ fn pyrs(py: Python, m: &PyModule) -> PyResult<()> {
add_circuit_module(py, m)?;
add_pattern_module(py, m)?;
add_pass_module(py, m)?;
add_optimiser_module(py, m)?;
Ok(())
}

Expand Down
67 changes: 67 additions & 0 deletions pyrs/src/optimiser.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
//! PyO3 wrapper for the TASO circuit optimiser.
use std::{fs, num::NonZeroUsize, path::PathBuf};

use pyo3::prelude::*;
use tket2::optimiser::{DefaultTasoOptimiser, TasoLogger};

use crate::circuit::update_hugr;

/// The circuit optimisation module.
pub fn add_optimiser_module(py: Python, parent: &PyModule) -> PyResult<()> {
let m = PyModule::new(py, "optimiser")?;
m.add_class::<PyDefaultTasoOptimiser>()?;

parent.add_submodule(m)
}

/// Wrapped [`DefaultTasoOptimiser`].
///
/// Currently only exposes loading from an ECC file using the constructor
/// and optimising using default logging settings.
#[pyclass(name = "TasoOptimiser")]
pub struct PyDefaultTasoOptimiser(DefaultTasoOptimiser);

#[pymethods]
impl PyDefaultTasoOptimiser {
/// Create a new [`PyDefaultTasoOptimiser`] from a precompiled rewriter.
#[staticmethod]
pub fn load_precompiled(path: PathBuf) -> Self {
Self(DefaultTasoOptimiser::default_with_rewriter_binary(path).unwrap())
}

/// Create a new [`PyDefaultTasoOptimiser`] from ECC sets.
///
/// This will compile the rewriter from the provided ECC JSON file.
#[staticmethod]
pub fn compile_eccs(path: &str) -> Self {
Self(DefaultTasoOptimiser::default_with_eccs_json_file(path).unwrap())
}

/// Run the optimiser on a circuit.
///
/// Returns an optimised circuit and optionally log the progress to a CSV
/// file.
pub fn optimise(
&self,
circ: PyObject,
timeout: Option<u64>,
n_threads: Option<NonZeroUsize>,
log_progress: Option<PathBuf>,
) -> PyResult<PyObject> {
let taso_logger = log_progress
.map(|file_name| {
let log_file = fs::File::create(file_name).unwrap();
TasoLogger::new(log_file)
})
.unwrap_or_default();
update_hugr(circ, |circ| {
self.0.optimise_with_log(
&circ,
taso_logger,
timeout,
n_threads.unwrap_or(NonZeroUsize::new(1).unwrap()),
)
})
}
}
2 changes: 1 addition & 1 deletion src/optimiser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,4 @@ pub mod taso;

#[cfg(feature = "portmatching")]
pub use taso::DefaultTasoOptimiser;
pub use taso::TasoOptimiser;
pub use taso::{TasoLogger, TasoOptimiser};
22 changes: 14 additions & 8 deletions src/optimiser/taso.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ mod worker;

use crossbeam_channel::select;
pub use eq_circ_class::{load_eccs_json_file, EqCircClass};
pub use log::TasoLogger;

use std::num::NonZeroUsize;
use std::time::{Duration, Instant};
Expand All @@ -34,11 +35,6 @@ use crate::optimiser::taso::worker::TasoWorker;
use crate::rewrite::strategy::RewriteStrategy;
use crate::rewrite::Rewriter;

use self::log::TasoLogger;

#[cfg(feature = "portmatching")]
use std::io;

/// The TASO optimiser.
///
/// Adapted from [Quartz][], and originally [TASO][].
Expand Down Expand Up @@ -295,8 +291,11 @@ where
mod taso_default {
use hugr::ops::OpType;
use hugr::HugrView;
use std::io;
use std::path::Path;

use crate::ops::op_matches;
use crate::rewrite::ecc_rewriter::RewriterSerialisationError;
use crate::rewrite::strategy::ExhaustiveRewriteStrategy;
use crate::rewrite::ECCRewriter;
use crate::T2Op;
Expand All @@ -312,13 +311,20 @@ mod taso_default {

impl DefaultTasoOptimiser {
/// A sane default optimiser using the given ECC sets.
pub fn default_with_eccs_json_file(
eccs_path: impl AsRef<std::path::Path>,
) -> io::Result<Self> {
pub fn default_with_eccs_json_file(eccs_path: impl AsRef<Path>) -> io::Result<Self> {
let rewriter = ECCRewriter::try_from_eccs_json_file(eccs_path)?;
let strategy = ExhaustiveRewriteStrategy::exhaustive_cx();
Ok(TasoOptimiser::new(rewriter, strategy, num_cx_gates))
}

/// A sane default optimiser using a precompiled binary rewriter.
pub fn default_with_rewriter_binary(
rewriter_path: impl AsRef<Path>,
) -> Result<Self, RewriterSerialisationError> {
let rewriter = ECCRewriter::load_binary(rewriter_path)?;
let strategy = ExhaustiveRewriteStrategy::exhaustive_cx();
Ok(TasoOptimiser::new(rewriter, strategy, num_cx_gates))
}
}

fn num_cx_gates(circ: &Hugr) -> usize {
Expand Down

0 comments on commit ca2d3df

Please sign in to comment.