Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Python bindings for TASO #142

Merged
merged 13 commits into from
Sep 29, 2023
3 changes: 3 additions & 0 deletions pyrs/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
//! Python bindings for TKET2.
#![warn(missing_docs)]
use circuit::{add_circuit_module, try_with_hugr};
use optimiser::add_optimiser_module;
use pyo3::prelude::*;
use tket2::{json::TKETDecode, passes::apply_greedy_commutation};
use tket_json_rs::circuit_json::SerialCircuit;

mod circuit;
mod optimiser;

#[pyfunction]
fn greedy_depth_reduce(py_c: PyObject) -> PyResult<(PyObject, u32)> {
Expand All @@ -22,6 +24,7 @@ fn pyrs(py: Python, m: &PyModule) -> PyResult<()> {
add_circuit_module(py, m)?;
add_pattern_module(py, m)?;
add_pass_module(py, m)?;
add_optimiser_module(py, m)?;
Ok(())
}

Expand Down
67 changes: 67 additions & 0 deletions pyrs/src/optimiser.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
//! PyO3 wrapper for the TASO circuit optimiser.

use std::{fs, num::NonZeroUsize, path::PathBuf};

use pyo3::prelude::*;
use tket2::optimiser::{DefaultTasoOptimiser, TasoLogger};

use crate::circuit::update_hugr;

/// The circuit optimisation module.
pub fn add_optimiser_module(py: Python, parent: &PyModule) -> PyResult<()> {
let m = PyModule::new(py, "optimiser")?;
m.add_class::<PyDefaultTasoOptimiser>()?;

parent.add_submodule(m)
}

/// Wrapped [`DefaultTasoOptimiser`].
///
/// Currently only exposes loading from an ECC file using the constructor
/// and optimising using default logging settings.
#[pyclass(name = "TasoOptimiser")]
pub struct PyDefaultTasoOptimiser(DefaultTasoOptimiser);

#[pymethods]
impl PyDefaultTasoOptimiser {
/// Create a new [`PyDefaultTasoOptimiser`] from a precompiled rewriter.
#[staticmethod]
pub fn load_precompiled(path: PathBuf) -> Self {
Self(DefaultTasoOptimiser::default_with_rewriter_binary(path).unwrap())
}

/// Create a new [`PyDefaultTasoOptimiser`] from ECC sets.
///
/// This will compile the rewriter from the provided ECC JSON file.
#[staticmethod]
pub fn compile_eccs(path: &str) -> Self {
Self(DefaultTasoOptimiser::default_with_eccs_json_file(path).unwrap())
}

/// Run the optimiser on a circuit.
///
/// Returns an optimised circuit and optionally log the progress to a CSV
/// file.
pub fn optimise(
&self,
circ: PyObject,
timeout: Option<u64>,
n_threads: Option<NonZeroUsize>,
log_progress: Option<PathBuf>,
) -> PyResult<PyObject> {
let taso_logger = log_progress
.map(|file_name| {
let log_file = fs::File::create(file_name).unwrap();
TasoLogger::new(log_file)
})
.unwrap_or_default();
update_hugr(circ, |circ| {
self.0.optimise_with_log(
&circ,
taso_logger,
timeout,
n_threads.unwrap_or(NonZeroUsize::new(1).unwrap()),
)
})
}
}
2 changes: 1 addition & 1 deletion src/optimiser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,4 @@ pub mod taso;

#[cfg(feature = "portmatching")]
pub use taso::DefaultTasoOptimiser;
pub use taso::TasoOptimiser;
pub use taso::{TasoLogger, TasoOptimiser};
22 changes: 14 additions & 8 deletions src/optimiser/taso.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ mod worker;

use crossbeam_channel::select;
pub use eq_circ_class::{load_eccs_json_file, EqCircClass};
pub use log::TasoLogger;

use std::num::NonZeroUsize;
use std::time::{Duration, Instant};
Expand All @@ -34,11 +35,6 @@ use crate::optimiser::taso::worker::TasoWorker;
use crate::rewrite::strategy::RewriteStrategy;
use crate::rewrite::Rewriter;

use self::log::TasoLogger;

#[cfg(feature = "portmatching")]
use std::io;

/// The TASO optimiser.
///
/// Adapted from [Quartz][], and originally [TASO][].
Expand Down Expand Up @@ -295,8 +291,11 @@ where
mod taso_default {
use hugr::ops::OpType;
use hugr::HugrView;
use std::io;
use std::path::Path;

use crate::ops::op_matches;
use crate::rewrite::ecc_rewriter::RewriterSerialisationError;
use crate::rewrite::strategy::ExhaustiveRewriteStrategy;
use crate::rewrite::ECCRewriter;
use crate::T2Op;
Expand All @@ -312,13 +311,20 @@ mod taso_default {

impl DefaultTasoOptimiser {
/// A sane default optimiser using the given ECC sets.
pub fn default_with_eccs_json_file(
eccs_path: impl AsRef<std::path::Path>,
) -> io::Result<Self> {
pub fn default_with_eccs_json_file(eccs_path: impl AsRef<Path>) -> io::Result<Self> {
let rewriter = ECCRewriter::try_from_eccs_json_file(eccs_path)?;
let strategy = ExhaustiveRewriteStrategy::exhaustive_cx();
Ok(TasoOptimiser::new(rewriter, strategy, num_cx_gates))
}

/// A sane default optimiser using a precompiled binary rewriter.
pub fn default_with_rewriter_binary(
rewriter_path: impl AsRef<Path>,
) -> Result<Self, RewriterSerialisationError> {
let rewriter = ECCRewriter::load_binary(rewriter_path)?;
let strategy = ExhaustiveRewriteStrategy::exhaustive_cx();
Ok(TasoOptimiser::new(rewriter, strategy, num_cx_gates))
}
}

fn num_cx_gates(circ: &Hugr) -> usize {
Expand Down