From ca2d3df9821ea3168c5b22b2682889bf63475352 Mon Sep 17 00:00:00 2001 From: Luca Mondada <72734770+lmondada@users.noreply.github.com> Date: Fri, 29 Sep 2023 15:29:09 +0200 Subject: [PATCH] feat: Python bindings for TASO (#142) --- pyrs/src/lib.rs | 3 ++ pyrs/src/optimiser.rs | 67 +++++++++++++++++++++++++++++++++++++++++++ src/optimiser.rs | 2 +- src/optimiser/taso.rs | 22 ++++++++------ 4 files changed, 85 insertions(+), 9 deletions(-) create mode 100644 pyrs/src/optimiser.rs diff --git a/pyrs/src/lib.rs b/pyrs/src/lib.rs index 062b4ba4..d845f2c9 100644 --- a/pyrs/src/lib.rs +++ b/pyrs/src/lib.rs @@ -1,11 +1,13 @@ //! Python bindings for TKET2. #![warn(missing_docs)] use circuit::{add_circuit_module, try_with_hugr}; +use optimiser::add_optimiser_module; use pyo3::prelude::*; use tket2::{json::TKETDecode, passes::apply_greedy_commutation}; use tket_json_rs::circuit_json::SerialCircuit; mod circuit; +mod optimiser; #[pyfunction] fn greedy_depth_reduce(py_c: PyObject) -> PyResult<(PyObject, u32)> { @@ -22,6 +24,7 @@ fn pyrs(py: Python, m: &PyModule) -> PyResult<()> { add_circuit_module(py, m)?; add_pattern_module(py, m)?; add_pass_module(py, m)?; + add_optimiser_module(py, m)?; Ok(()) } diff --git a/pyrs/src/optimiser.rs b/pyrs/src/optimiser.rs new file mode 100644 index 00000000..bd7b6a4c --- /dev/null +++ b/pyrs/src/optimiser.rs @@ -0,0 +1,67 @@ +//! PyO3 wrapper for the TASO circuit optimiser. + +use std::{fs, num::NonZeroUsize, path::PathBuf}; + +use pyo3::prelude::*; +use tket2::optimiser::{DefaultTasoOptimiser, TasoLogger}; + +use crate::circuit::update_hugr; + +/// The circuit optimisation module. +pub fn add_optimiser_module(py: Python, parent: &PyModule) -> PyResult<()> { + let m = PyModule::new(py, "optimiser")?; + m.add_class::()?; + + parent.add_submodule(m) +} + +/// Wrapped [`DefaultTasoOptimiser`]. +/// +/// Currently only exposes loading from an ECC file using the constructor +/// and optimising using default logging settings. +#[pyclass(name = "TasoOptimiser")] +pub struct PyDefaultTasoOptimiser(DefaultTasoOptimiser); + +#[pymethods] +impl PyDefaultTasoOptimiser { + /// Create a new [`PyDefaultTasoOptimiser`] from a precompiled rewriter. + #[staticmethod] + pub fn load_precompiled(path: PathBuf) -> Self { + Self(DefaultTasoOptimiser::default_with_rewriter_binary(path).unwrap()) + } + + /// Create a new [`PyDefaultTasoOptimiser`] from ECC sets. + /// + /// This will compile the rewriter from the provided ECC JSON file. + #[staticmethod] + pub fn compile_eccs(path: &str) -> Self { + Self(DefaultTasoOptimiser::default_with_eccs_json_file(path).unwrap()) + } + + /// Run the optimiser on a circuit. + /// + /// Returns an optimised circuit and optionally log the progress to a CSV + /// file. + pub fn optimise( + &self, + circ: PyObject, + timeout: Option, + n_threads: Option, + log_progress: Option, + ) -> PyResult { + let taso_logger = log_progress + .map(|file_name| { + let log_file = fs::File::create(file_name).unwrap(); + TasoLogger::new(log_file) + }) + .unwrap_or_default(); + update_hugr(circ, |circ| { + self.0.optimise_with_log( + &circ, + taso_logger, + timeout, + n_threads.unwrap_or(NonZeroUsize::new(1).unwrap()), + ) + }) + } +} diff --git a/src/optimiser.rs b/src/optimiser.rs index 0760d0d3..4656d0be 100644 --- a/src/optimiser.rs +++ b/src/optimiser.rs @@ -6,4 +6,4 @@ pub mod taso; #[cfg(feature = "portmatching")] pub use taso::DefaultTasoOptimiser; -pub use taso::TasoOptimiser; +pub use taso::{TasoLogger, TasoOptimiser}; diff --git a/src/optimiser/taso.rs b/src/optimiser/taso.rs index 2c7f4982..a231aaf8 100644 --- a/src/optimiser/taso.rs +++ b/src/optimiser/taso.rs @@ -20,6 +20,7 @@ mod worker; use crossbeam_channel::select; pub use eq_circ_class::{load_eccs_json_file, EqCircClass}; +pub use log::TasoLogger; use std::num::NonZeroUsize; use std::time::{Duration, Instant}; @@ -34,11 +35,6 @@ use crate::optimiser::taso::worker::TasoWorker; use crate::rewrite::strategy::RewriteStrategy; use crate::rewrite::Rewriter; -use self::log::TasoLogger; - -#[cfg(feature = "portmatching")] -use std::io; - /// The TASO optimiser. /// /// Adapted from [Quartz][], and originally [TASO][]. @@ -295,8 +291,11 @@ where mod taso_default { use hugr::ops::OpType; use hugr::HugrView; + use std::io; + use std::path::Path; use crate::ops::op_matches; + use crate::rewrite::ecc_rewriter::RewriterSerialisationError; use crate::rewrite::strategy::ExhaustiveRewriteStrategy; use crate::rewrite::ECCRewriter; use crate::T2Op; @@ -312,13 +311,20 @@ mod taso_default { impl DefaultTasoOptimiser { /// A sane default optimiser using the given ECC sets. - pub fn default_with_eccs_json_file( - eccs_path: impl AsRef, - ) -> io::Result { + pub fn default_with_eccs_json_file(eccs_path: impl AsRef) -> io::Result { let rewriter = ECCRewriter::try_from_eccs_json_file(eccs_path)?; let strategy = ExhaustiveRewriteStrategy::exhaustive_cx(); Ok(TasoOptimiser::new(rewriter, strategy, num_cx_gates)) } + + /// A sane default optimiser using a precompiled binary rewriter. + pub fn default_with_rewriter_binary( + rewriter_path: impl AsRef, + ) -> Result { + let rewriter = ECCRewriter::load_binary(rewriter_path)?; + let strategy = ExhaustiveRewriteStrategy::exhaustive_cx(); + Ok(TasoOptimiser::new(rewriter, strategy, num_cx_gates)) + } } fn num_cx_gates(circ: &Hugr) -> usize {