diff --git a/pyrs/src/lib.rs b/pyrs/src/lib.rs index 146c36f7..3fde8675 100644 --- a/pyrs/src/lib.rs +++ b/pyrs/src/lib.rs @@ -2,7 +2,7 @@ #![warn(missing_docs)] use circuit::try_with_hugr; use pyo3::prelude::*; -use tket2::{json::TKETDecode, passes::apply_greedy_commutation}; +use tket2::{json::TKETDecode, optimiser::taso, passes::apply_greedy_commutation}; use tket_json_rs::circuit_json::SerialCircuit; mod circuit; @@ -22,6 +22,7 @@ fn pyrs(py: Python, m: &PyModule) -> PyResult<()> { add_circuit_module(py, m)?; add_pattern_module(py, m)?; add_pass_module(py, m)?; + add_optimiser_module(py, m)?; Ok(()) } @@ -78,3 +79,11 @@ fn add_pass_module(py: Python, parent: &PyModule) -> PyResult<()> { parent.add_submodule(m)?; Ok(()) } + +/// circuit optimisation module +fn add_optimiser_module(py: Python, parent: &PyModule) -> PyResult<()> { + let m = PyModule::new(py, "optimiser")?; + m.add_class::()?; + + parent.add_submodule(m) +} diff --git a/src/optimiser/taso.rs b/src/optimiser/taso.rs index 2c7f4982..99f9e08b 100644 --- a/src/optimiser/taso.rs +++ b/src/optimiser/taso.rs @@ -15,6 +15,9 @@ mod eq_circ_class; mod hugr_pchannel; mod hugr_pqueue; pub mod log; +#[cfg(feature = "pyo3")] +#[cfg(feature = "portmatching")] +pub mod pyo3; mod qtz_circuit; mod worker; @@ -36,9 +39,6 @@ use crate::rewrite::Rewriter; use self::log::TasoLogger; -#[cfg(feature = "portmatching")] -use std::io; - /// The TASO optimiser. /// /// Adapted from [Quartz][], and originally [TASO][]. @@ -295,8 +295,11 @@ where mod taso_default { use hugr::ops::OpType; use hugr::HugrView; + use std::io; + use std::path::Path; use crate::ops::op_matches; + use crate::rewrite::ecc_rewriter::RewriterSerialisationError; use crate::rewrite::strategy::ExhaustiveRewriteStrategy; use crate::rewrite::ECCRewriter; use crate::T2Op; @@ -312,13 +315,20 @@ mod taso_default { impl DefaultTasoOptimiser { /// A sane default optimiser using the given ECC sets. - pub fn default_with_eccs_json_file( - eccs_path: impl AsRef, - ) -> io::Result { + pub fn default_with_eccs_json_file(eccs_path: impl AsRef) -> io::Result { let rewriter = ECCRewriter::try_from_eccs_json_file(eccs_path)?; let strategy = ExhaustiveRewriteStrategy::exhaustive_cx(); Ok(TasoOptimiser::new(rewriter, strategy, num_cx_gates)) } + + /// A sane default optimiser using a precompiled binary rewriter. + pub fn default_with_rewriter_binary( + rewriter_path: impl AsRef, + ) -> Result { + let rewriter = ECCRewriter::load_binary(rewriter_path)?; + let strategy = ExhaustiveRewriteStrategy::exhaustive_cx(); + Ok(TasoOptimiser::new(rewriter, strategy, num_cx_gates)) + } } fn num_cx_gates(circ: &Hugr) -> usize { diff --git a/src/optimiser/taso/pyo3.rs b/src/optimiser/taso/pyo3.rs new file mode 100644 index 00000000..e4a3566d --- /dev/null +++ b/src/optimiser/taso/pyo3.rs @@ -0,0 +1,62 @@ +//! PyO3 wrapper for the TASO optimiser. + +use std::{fs, num::NonZeroUsize}; + +use pyo3::{exceptions::PyTypeError, prelude::*}; +use tket_json_rs::circuit_json::SerialCircuit; + +use crate::{json::TKETDecode, utils::pyobj_as_hugr}; + +use super::{log::TasoLogger, DefaultTasoOptimiser}; + +/// Wrapped [`DefaultTasoOptimiser`]. +/// +/// Currently only exposes loading from an ECC file using the constructor +/// and optimising using default logging settings. +#[pyclass(name = "TasoOptimiser")] +pub struct PyDefaultTasoOptimiser(DefaultTasoOptimiser); + +#[pymethods] +impl PyDefaultTasoOptimiser { + /// Create a new [`PyDefaultTasoOptimiser`] from a precompiled rewriter. + #[staticmethod] + pub fn load_precompiled(path: &str) -> Self { + Self(DefaultTasoOptimiser::default_with_rewriter_binary(path).unwrap()) + } + + /// Create a new [`PyDefaultTasoOptimiser`] from ECC sets. + /// + /// This will compile the rewriter from the provided ECC JSON file. + #[staticmethod] + pub fn compile_eccs(path: &str) -> Self { + Self(DefaultTasoOptimiser::default_with_eccs_json_file(path).unwrap()) + } + + /// Run the optimiser on a circuit. + /// + /// Returns an optimised circuit and log the progress to a CSV + /// file called "best_circs.csv". + pub fn optimise( + &mut self, + circ: PyObject, + timeout: Option, + n_threads: Option, + ) -> PyResult { + let circ = pyobj_as_hugr(circ)?; + let circ_candidates_csv = fs::File::create("best_circs.csv").unwrap(); + + let taso_logger = TasoLogger::new(circ_candidates_csv); + self.0.optimise_with_log( + &circ, + taso_logger, + timeout, + n_threads.unwrap_or(NonZeroUsize::new(1).unwrap()), + ); + let ser_circ = + SerialCircuit::encode(&circ).map_err(|e| PyTypeError::new_err(e.to_string()))?; + let tk1_circ = ser_circ + .to_tket1() + .map_err(|e| PyTypeError::new_err(e.to_string()))?; + Ok(tk1_circ) + } +} diff --git a/src/portmatching/pyo3.rs b/src/portmatching/pyo3.rs index 1ff2a7b6..b2180768 100644 --- a/src/portmatching/pyo3.rs +++ b/src/portmatching/pyo3.rs @@ -4,16 +4,15 @@ use std::fmt; use derive_more::{From, Into}; use hugr::hugr::views::sibling_subgraph::PyInvalidReplacementError; -use hugr::{Hugr, Port}; +use hugr::Port; use itertools::Itertools; use portmatching::{HashMap, PatternID}; use pyo3::{prelude::*, types::PyIterator}; -use tket_json_rs::circuit_json::SerialCircuit; use super::{CircuitPattern, PatternMatch, PatternMatcher}; use crate::circuit::Circuit; -use crate::json::TKETDecode; use crate::rewrite::CircuitRewrite; +use crate::utils::pyobj_as_hugr; #[pymethods] impl CircuitPattern { @@ -196,9 +195,3 @@ impl Node { format!("{:?}", self) } } - -fn pyobj_as_hugr(circ: PyObject) -> PyResult { - let ser_c = SerialCircuit::_from_tket1(circ); - let hugr: Hugr = ser_c.decode()?; - Ok(hugr) -} diff --git a/src/utils.rs b/src/utils.rs index 124cda37..f734fa0d 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -13,6 +13,24 @@ pub(crate) fn type_is_linear(typ: &Type) -> bool { !TypeBound::Copyable.contains(typ.least_upper_bound()) } +// Convert a pytket object to HUGR +#[cfg(feature = "pyo3")] +mod pyo3 { + use hugr::Hugr; + use pyo3::prelude::*; + use tket_json_rs::circuit_json::SerialCircuit; + + use crate::json::TKETDecode; + + pub(crate) fn pyobj_as_hugr(circ: PyObject) -> PyResult { + let ser_c = SerialCircuit::_from_tket1(circ); + let hugr: Hugr = ser_c.decode()?; + Ok(hugr) + } +} +#[cfg(feature = "pyo3")] +pub(crate) use pyo3::pyobj_as_hugr; + // utility for building simple qubit-only circuits. #[allow(unused)] pub(crate) fn build_simple_circuit(