From d068c3b53f9ea145e97962eba70243558b1cddac Mon Sep 17 00:00:00 2001 From: Luca Mondada Date: Wed, 27 Sep 2023 19:44:09 +0200 Subject: [PATCH 1/9] feat: Serialisation for ECCRewriter --- Cargo.toml | 2 +- compile-matcher/src/main.rs | 93 ------------------ .../Cargo.toml | 0 .../README.md | 0 .../matcher.bin | Bin compile-rewriter/src/main.rs | 75 ++++++++++++++ src/rewrite/ecc_rewriter.rs | 69 ++++++++++++- 7 files changed, 142 insertions(+), 97 deletions(-) delete mode 100644 compile-matcher/src/main.rs rename {compile-matcher => compile-rewriter}/Cargo.toml (100%) rename {compile-matcher => compile-rewriter}/README.md (100%) rename {compile-matcher => compile-rewriter}/matcher.bin (100%) create mode 100644 compile-rewriter/src/main.rs diff --git a/Cargo.toml b/Cargo.toml index f7a52f64..e80033e6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -69,7 +69,7 @@ harness = false [workspace] -members = ["pyrs", "compile-matcher", "taso-optimiser"] +members = ["pyrs", "compile-rewriter", "taso-optimiser"] [workspace.dependencies] diff --git a/compile-matcher/src/main.rs b/compile-matcher/src/main.rs deleted file mode 100644 index 64f3c6d5..00000000 --- a/compile-matcher/src/main.rs +++ /dev/null @@ -1,93 +0,0 @@ -use std::fs; -use std::path::Path; -use std::process::exit; - -use clap::Parser; -use itertools::Itertools; - -use tket2::json::load_tk1_json_file; -// Import the PatternMatcher struct and its methods -use tket2::optimiser::taso::load_eccs_json_file; -use tket2::portmatching::{CircuitPattern, PatternMatcher}; - -/// Program to precompile patterns from files into a PatternMatcher stored as binary file. -#[derive(Parser, Debug)] -#[clap(version = "1.0", long_about = None)] -#[clap(about = "Precompiles patterns from files into a PatternMatcher stored as binary file.")] -struct CmdLineArgs { - // TODO: Differentiate between TK1 input and ECC input - /// Name of input file/folder - #[arg( - short, - long, - value_name = "FILE", - help = "Sets the input file or folder to use. It is either a JSON file of Quartz-generated ECCs or a folder with TK1 circuits in JSON format." - )] - input: String, - /// Name of output file/folder - #[arg( - short, - long, - value_name = "FILE", - default_value = ".", - help = "Sets the output file or folder to use. Defaults to \"matcher.bin\" if no file name is provided." - )] - output: String, -} - -fn main() { - let opts = CmdLineArgs::parse(); - - let input_path = Path::new(&opts.input); - let output_path = Path::new(&opts.output); - - let all_circs = if input_path.is_file() { - // Input is an ECC file in JSON format - let Ok(eccs) = load_eccs_json_file(input_path) else { - eprintln!( - "Unable to load ECC file {:?}. Is it a JSON file of Quartz-generated ECCs?", - input_path - ); - exit(1); - }; - eccs.into_iter() - .flat_map(|ecc| ecc.into_circuits()) - .collect_vec() - } else if input_path.is_dir() { - // Input is a folder with TK1 circuits in JSON format - fs::read_dir(input_path) - .unwrap() - .map(|file| { - let path = file.unwrap().path(); - load_tk1_json_file(path).unwrap() - }) - .collect_vec() - } else { - panic!("Input must be a file or a directory"); - }; - - let patterns = all_circs - .iter() - .filter_map(|circ| { - // Fail silently on empty or disconnected patterns - CircuitPattern::try_from_circuit(circ).ok() - }) - .collect_vec(); - println!("Loaded {} patterns.", patterns.len()); - - println!("Building matcher..."); - let output_file = if output_path.is_dir() { - output_path.join("matcher.bin") - } else { - output_path.to_path_buf() - }; - let matcher = PatternMatcher::from_patterns(patterns); - matcher.save_binary(output_file.to_str().unwrap()).unwrap(); - println!("Written matcher to {:?}", output_file); - - // Print the file size of output_file in megabytes - if let Ok(metadata) = fs::metadata(&output_file) { - let file_size = metadata.len() as f64 / (1024.0 * 1024.0); - println!("File size: {:.2} MB", file_size); - } -} diff --git a/compile-matcher/Cargo.toml b/compile-rewriter/Cargo.toml similarity index 100% rename from compile-matcher/Cargo.toml rename to compile-rewriter/Cargo.toml diff --git a/compile-matcher/README.md b/compile-rewriter/README.md similarity index 100% rename from compile-matcher/README.md rename to compile-rewriter/README.md diff --git a/compile-matcher/matcher.bin b/compile-rewriter/matcher.bin similarity index 100% rename from compile-matcher/matcher.bin rename to compile-rewriter/matcher.bin diff --git a/compile-rewriter/src/main.rs b/compile-rewriter/src/main.rs new file mode 100644 index 00000000..6960426e --- /dev/null +++ b/compile-rewriter/src/main.rs @@ -0,0 +1,75 @@ +use std::fs; +use std::path::Path; +use std::process::exit; +use std::time::Instant; + +use clap::Parser; + +use tket2::rewrite::ECCRewriter; + +/// Program to precompile patterns from files into a PatternMatcher stored as binary file. +#[derive(Parser, Debug)] +#[clap(version = "1.0", long_about = None)] +#[clap( + about = "Precompiles ECC sets into a TKET2 Rewriter. The resulting binary files can be loaded into TKET2 for circuit optimisation." +)] +struct CmdLineArgs { + // TODO: Differentiate between TK1 input and ECC input + /// Name of input file/folder + #[arg( + short, + long, + value_name = "FILE", + help = "Sets the input file to use. It must be a JSON file of ECC sets in the Quartz format." + )] + input: String, + /// Name of output file/folder + #[arg( + short, + long, + value_name = "FILE", + default_value = ".", + help = "Sets the output file or folder. Defaults to \"matcher.rwr\" if no file name is provided. The extension of the file name will always be set or amended to be `.rwr`." + )] + output: String, +} + +fn main() { + let opts = CmdLineArgs::parse(); + + let input_path = Path::new(&opts.input); + let output_path = Path::new(&opts.output); + + if !input_path.is_file() || input_path.extension().unwrap() != "json" { + panic!("Input must be a JSON file"); + }; + let start_time = Instant::now(); + println!("Compiling rewriter..."); + let Ok(rewriter) = ECCRewriter::try_from_eccs_json_file(input_path) else { + eprintln!( + "Unable to load ECC file {:?}. Is it a JSON file of Quartz-generated ECCs?", + input_path + ); + exit(1); + }; + println!("Saving to file..."); + let output_file = if output_path.is_dir() { + output_path.join("matcher.rwr") + } else { + output_path.to_path_buf() + }; + let output_file = rewriter.save_binary(output_file.to_str().unwrap()).unwrap(); + println!("Written rewriter to {:?}", output_file); + + // Print the file size of output_file in megabytes + if let Ok(metadata) = fs::metadata(&output_file) { + let file_size = metadata.len() as f64 / (1024.0 * 1024.0); + println!("File size: {:.2} MB", file_size); + } + let elapsed = start_time.elapsed(); + println!( + "Done in {}.{:03} seconds", + elapsed.as_secs(), + elapsed.subsec_millis() + ); +} diff --git a/src/rewrite/ecc_rewriter.rs b/src/rewrite/ecc_rewriter.rs index afd4ff10..dba5d3cc 100644 --- a/src/rewrite/ecc_rewriter.rs +++ b/src/rewrite/ecc_rewriter.rs @@ -15,8 +15,10 @@ use derive_more::{From, Into}; use itertools::Itertools; use portmatching::PatternID; -use std::io; +use std::fs::File; use std::path::Path; +use std::{io, path::PathBuf}; +use thiserror::Error; use hugr::Hugr; @@ -28,7 +30,7 @@ use crate::{ use super::{CircuitRewrite, Rewriter}; -#[derive(Debug, Clone, Copy, PartialEq, Eq, From, Into)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, From, Into, serde::Serialize, serde::Deserialize)] struct TargetID(usize); /// A rewriter based on circuit equivalence classes. @@ -37,7 +39,7 @@ struct TargetID(usize); /// Valid rewrites turn a non-representative circuit into its representative, /// or a representative circuit into any of the equivalent non-representative /// circuits. -#[derive(Debug, Clone)] +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] pub struct ECCRewriter { /// Matcher for finding patterns. matcher: PatternMatcher, @@ -91,6 +93,53 @@ impl ECCRewriter { .iter() .map(|id| &self.targets[id.0]) } + + /// Serialise a rewriter to an IO stream. + /// + /// Precomputed rewriters can be serialised as binary and then loaded + /// later using [`ECCRewriter::load_binary_io`]. + pub fn save_binary_io( + &self, + writer: &mut W, + ) -> Result<(), RewriterSerialisationError> { + rmp_serde::encode::write(writer, &self)?; + Ok(()) + } + + /// Load a rewriter from an IO stream. + /// + /// Loads streams as created by [`ECCRewriter::save_binary_io`]. + pub fn load_binary_io(reader: &mut R) -> Result { + let matcher: Self = rmp_serde::decode::from_read(reader)?; + Ok(matcher) + } + + /// Save a rewriter as a binary file. + /// + /// Precomputed rewriters can be saved as binary files and then loaded + /// later using [`ECCRewriter::load_binary`]. + /// + /// The extension of the file name will always be set or amended to be + /// `.rwr`. + /// + /// If successful, returns the path to the newly created file. + pub fn save_binary( + &self, + name: impl AsRef, + ) -> Result { + let mut file_name = PathBuf::from(name.as_ref()); + file_name.set_extension("rwr"); + let mut file = File::create(&file_name)?; + self.save_binary_io(&mut file)?; + Ok(file_name) + } + + /// Loads a rewriter saved using [`ECCRewriter::save_binary`]. + pub fn load_binary(name: impl AsRef) -> Result { + let file = File::open(name)?; + let mut reader = std::io::BufReader::new(file); + Self::load_binary_io(&mut reader) + } } impl Rewriter for ECCRewriter { @@ -109,6 +158,20 @@ impl Rewriter for ECCRewriter { } } +/// Errors that can occur when (de)serialising an [`ECCRewriter`]. +#[derive(Debug, Error)] +pub enum RewriterSerialisationError { + /// An IO error occured + #[error("IO error: {0}")] + Io(#[from] io::Error), + /// An error occured during deserialisation + #[error("Deserialisation error: {0}")] + Deserialisation(#[from] rmp_serde::decode::Error), + /// An error occured during serialisation + #[error("Serialisation error: {0}")] + Serialisation(#[from] rmp_serde::encode::Error), +} + fn into_targets(rep_sets: Vec) -> Vec { rep_sets .into_iter() From b158bcfac04741822b58ec20363ab781dfc9b763 Mon Sep 17 00:00:00 2001 From: Luca Mondada Date: Wed, 27 Sep 2023 11:08:42 +0200 Subject: [PATCH 2/9] feat: Python bindings for TASO --- pyrs/src/lib.rs | 11 ++++++- src/optimiser/taso.rs | 22 ++++++++++---- src/optimiser/taso/pyo3.rs | 62 ++++++++++++++++++++++++++++++++++++++ src/portmatching/pyo3.rs | 11 ++----- src/utils.rs | 18 +++++++++++ 5 files changed, 108 insertions(+), 16 deletions(-) create mode 100644 src/optimiser/taso/pyo3.rs diff --git a/pyrs/src/lib.rs b/pyrs/src/lib.rs index 146c36f7..3fde8675 100644 --- a/pyrs/src/lib.rs +++ b/pyrs/src/lib.rs @@ -2,7 +2,7 @@ #![warn(missing_docs)] use circuit::try_with_hugr; use pyo3::prelude::*; -use tket2::{json::TKETDecode, passes::apply_greedy_commutation}; +use tket2::{json::TKETDecode, optimiser::taso, passes::apply_greedy_commutation}; use tket_json_rs::circuit_json::SerialCircuit; mod circuit; @@ -22,6 +22,7 @@ fn pyrs(py: Python, m: &PyModule) -> PyResult<()> { add_circuit_module(py, m)?; add_pattern_module(py, m)?; add_pass_module(py, m)?; + add_optimiser_module(py, m)?; Ok(()) } @@ -78,3 +79,11 @@ fn add_pass_module(py: Python, parent: &PyModule) -> PyResult<()> { parent.add_submodule(m)?; Ok(()) } + +/// circuit optimisation module +fn add_optimiser_module(py: Python, parent: &PyModule) -> PyResult<()> { + let m = PyModule::new(py, "optimiser")?; + m.add_class::()?; + + parent.add_submodule(m) +} diff --git a/src/optimiser/taso.rs b/src/optimiser/taso.rs index 2c7f4982..99f9e08b 100644 --- a/src/optimiser/taso.rs +++ b/src/optimiser/taso.rs @@ -15,6 +15,9 @@ mod eq_circ_class; mod hugr_pchannel; mod hugr_pqueue; pub mod log; +#[cfg(feature = "pyo3")] +#[cfg(feature = "portmatching")] +pub mod pyo3; mod qtz_circuit; mod worker; @@ -36,9 +39,6 @@ use crate::rewrite::Rewriter; use self::log::TasoLogger; -#[cfg(feature = "portmatching")] -use std::io; - /// The TASO optimiser. /// /// Adapted from [Quartz][], and originally [TASO][]. @@ -295,8 +295,11 @@ where mod taso_default { use hugr::ops::OpType; use hugr::HugrView; + use std::io; + use std::path::Path; use crate::ops::op_matches; + use crate::rewrite::ecc_rewriter::RewriterSerialisationError; use crate::rewrite::strategy::ExhaustiveRewriteStrategy; use crate::rewrite::ECCRewriter; use crate::T2Op; @@ -312,13 +315,20 @@ mod taso_default { impl DefaultTasoOptimiser { /// A sane default optimiser using the given ECC sets. - pub fn default_with_eccs_json_file( - eccs_path: impl AsRef, - ) -> io::Result { + pub fn default_with_eccs_json_file(eccs_path: impl AsRef) -> io::Result { let rewriter = ECCRewriter::try_from_eccs_json_file(eccs_path)?; let strategy = ExhaustiveRewriteStrategy::exhaustive_cx(); Ok(TasoOptimiser::new(rewriter, strategy, num_cx_gates)) } + + /// A sane default optimiser using a precompiled binary rewriter. + pub fn default_with_rewriter_binary( + rewriter_path: impl AsRef, + ) -> Result { + let rewriter = ECCRewriter::load_binary(rewriter_path)?; + let strategy = ExhaustiveRewriteStrategy::exhaustive_cx(); + Ok(TasoOptimiser::new(rewriter, strategy, num_cx_gates)) + } } fn num_cx_gates(circ: &Hugr) -> usize { diff --git a/src/optimiser/taso/pyo3.rs b/src/optimiser/taso/pyo3.rs new file mode 100644 index 00000000..e4a3566d --- /dev/null +++ b/src/optimiser/taso/pyo3.rs @@ -0,0 +1,62 @@ +//! PyO3 wrapper for the TASO optimiser. + +use std::{fs, num::NonZeroUsize}; + +use pyo3::{exceptions::PyTypeError, prelude::*}; +use tket_json_rs::circuit_json::SerialCircuit; + +use crate::{json::TKETDecode, utils::pyobj_as_hugr}; + +use super::{log::TasoLogger, DefaultTasoOptimiser}; + +/// Wrapped [`DefaultTasoOptimiser`]. +/// +/// Currently only exposes loading from an ECC file using the constructor +/// and optimising using default logging settings. +#[pyclass(name = "TasoOptimiser")] +pub struct PyDefaultTasoOptimiser(DefaultTasoOptimiser); + +#[pymethods] +impl PyDefaultTasoOptimiser { + /// Create a new [`PyDefaultTasoOptimiser`] from a precompiled rewriter. + #[staticmethod] + pub fn load_precompiled(path: &str) -> Self { + Self(DefaultTasoOptimiser::default_with_rewriter_binary(path).unwrap()) + } + + /// Create a new [`PyDefaultTasoOptimiser`] from ECC sets. + /// + /// This will compile the rewriter from the provided ECC JSON file. + #[staticmethod] + pub fn compile_eccs(path: &str) -> Self { + Self(DefaultTasoOptimiser::default_with_eccs_json_file(path).unwrap()) + } + + /// Run the optimiser on a circuit. + /// + /// Returns an optimised circuit and log the progress to a CSV + /// file called "best_circs.csv". + pub fn optimise( + &mut self, + circ: PyObject, + timeout: Option, + n_threads: Option, + ) -> PyResult { + let circ = pyobj_as_hugr(circ)?; + let circ_candidates_csv = fs::File::create("best_circs.csv").unwrap(); + + let taso_logger = TasoLogger::new(circ_candidates_csv); + self.0.optimise_with_log( + &circ, + taso_logger, + timeout, + n_threads.unwrap_or(NonZeroUsize::new(1).unwrap()), + ); + let ser_circ = + SerialCircuit::encode(&circ).map_err(|e| PyTypeError::new_err(e.to_string()))?; + let tk1_circ = ser_circ + .to_tket1() + .map_err(|e| PyTypeError::new_err(e.to_string()))?; + Ok(tk1_circ) + } +} diff --git a/src/portmatching/pyo3.rs b/src/portmatching/pyo3.rs index 1ff2a7b6..b2180768 100644 --- a/src/portmatching/pyo3.rs +++ b/src/portmatching/pyo3.rs @@ -4,16 +4,15 @@ use std::fmt; use derive_more::{From, Into}; use hugr::hugr::views::sibling_subgraph::PyInvalidReplacementError; -use hugr::{Hugr, Port}; +use hugr::Port; use itertools::Itertools; use portmatching::{HashMap, PatternID}; use pyo3::{prelude::*, types::PyIterator}; -use tket_json_rs::circuit_json::SerialCircuit; use super::{CircuitPattern, PatternMatch, PatternMatcher}; use crate::circuit::Circuit; -use crate::json::TKETDecode; use crate::rewrite::CircuitRewrite; +use crate::utils::pyobj_as_hugr; #[pymethods] impl CircuitPattern { @@ -196,9 +195,3 @@ impl Node { format!("{:?}", self) } } - -fn pyobj_as_hugr(circ: PyObject) -> PyResult { - let ser_c = SerialCircuit::_from_tket1(circ); - let hugr: Hugr = ser_c.decode()?; - Ok(hugr) -} diff --git a/src/utils.rs b/src/utils.rs index 124cda37..f734fa0d 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -13,6 +13,24 @@ pub(crate) fn type_is_linear(typ: &Type) -> bool { !TypeBound::Copyable.contains(typ.least_upper_bound()) } +// Convert a pytket object to HUGR +#[cfg(feature = "pyo3")] +mod pyo3 { + use hugr::Hugr; + use pyo3::prelude::*; + use tket_json_rs::circuit_json::SerialCircuit; + + use crate::json::TKETDecode; + + pub(crate) fn pyobj_as_hugr(circ: PyObject) -> PyResult { + let ser_c = SerialCircuit::_from_tket1(circ); + let hugr: Hugr = ser_c.decode()?; + Ok(hugr) + } +} +#[cfg(feature = "pyo3")] +pub(crate) use pyo3::pyobj_as_hugr; + // utility for building simple qubit-only circuits. #[allow(unused)] pub(crate) fn build_simple_circuit( From b7dd79c5704ae34f80e42bca14b2e758c66b6da3 Mon Sep 17 00:00:00 2001 From: Luca Mondada Date: Thu, 28 Sep 2023 08:32:45 +0200 Subject: [PATCH 3/9] Do not require mut for fn optimise --- src/optimiser/taso/pyo3.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/optimiser/taso/pyo3.rs b/src/optimiser/taso/pyo3.rs index e4a3566d..73f5bc01 100644 --- a/src/optimiser/taso/pyo3.rs +++ b/src/optimiser/taso/pyo3.rs @@ -37,7 +37,7 @@ impl PyDefaultTasoOptimiser { /// Returns an optimised circuit and log the progress to a CSV /// file called "best_circs.csv". pub fn optimise( - &mut self, + &self, circ: PyObject, timeout: Option, n_threads: Option, From 8c22a9fd2d8c224abad313de5378b3da1e7f9646 Mon Sep 17 00:00:00 2001 From: Luca Mondada Date: Thu, 28 Sep 2023 08:44:15 +0200 Subject: [PATCH 4/9] Actually return optimised circuit --- src/optimiser/taso/pyo3.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/optimiser/taso/pyo3.rs b/src/optimiser/taso/pyo3.rs index 73f5bc01..47dd600c 100644 --- a/src/optimiser/taso/pyo3.rs +++ b/src/optimiser/taso/pyo3.rs @@ -46,14 +46,14 @@ impl PyDefaultTasoOptimiser { let circ_candidates_csv = fs::File::create("best_circs.csv").unwrap(); let taso_logger = TasoLogger::new(circ_candidates_csv); - self.0.optimise_with_log( + let opt_circ = self.0.optimise_with_log( &circ, taso_logger, timeout, n_threads.unwrap_or(NonZeroUsize::new(1).unwrap()), ); let ser_circ = - SerialCircuit::encode(&circ).map_err(|e| PyTypeError::new_err(e.to_string()))?; + SerialCircuit::encode(&opt_circ).map_err(|e| PyTypeError::new_err(e.to_string()))?; let tk1_circ = ser_circ .to_tket1() .map_err(|e| PyTypeError::new_err(e.to_string()))?; From 5681cef69a4b02490e46e40e47010b212b8f174a Mon Sep 17 00:00:00 2001 From: Luca Mondada Date: Thu, 28 Sep 2023 08:46:41 +0200 Subject: [PATCH 5/9] Fix ambiguous import --- src/optimiser/taso.rs | 3 +-- src/utils.rs | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/src/optimiser/taso.rs b/src/optimiser/taso.rs index 99f9e08b..e35b29f7 100644 --- a/src/optimiser/taso.rs +++ b/src/optimiser/taso.rs @@ -15,8 +15,7 @@ mod eq_circ_class; mod hugr_pchannel; mod hugr_pqueue; pub mod log; -#[cfg(feature = "pyo3")] -#[cfg(feature = "portmatching")] +#[cfg(all(feature = "pyo3", feature = "portmatching"))] pub mod pyo3; mod qtz_circuit; mod worker; diff --git a/src/utils.rs b/src/utils.rs index f734fa0d..367e5d1f 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -29,7 +29,7 @@ mod pyo3 { } } #[cfg(feature = "pyo3")] -pub(crate) use pyo3::pyobj_as_hugr; +pub(crate) use self::pyo3::pyobj_as_hugr; // utility for building simple qubit-only circuits. #[allow(unused)] From ee4fb54a1b964da1ef0d14d1847216602a114d34 Mon Sep 17 00:00:00 2001 From: Luca Mondada Date: Fri, 29 Sep 2023 12:59:30 +0200 Subject: [PATCH 6/9] Address comments --- pyrs/Cargo.toml | 1 + pyrs/src/lib.rs | 12 ++------ src/optimiser.rs | 2 +- src/optimiser/taso.rs | 5 +-- src/optimiser/taso/pyo3.rs | 62 -------------------------------------- src/portmatching/pyo3.rs | 11 +++++-- src/utils.rs | 18 ----------- 7 files changed, 15 insertions(+), 96 deletions(-) delete mode 100644 src/optimiser/taso/pyo3.rs diff --git a/pyrs/Cargo.toml b/pyrs/Cargo.toml index 0008e1ce..7f113b24 100644 --- a/pyrs/Cargo.toml +++ b/pyrs/Cargo.toml @@ -18,3 +18,4 @@ tket-json-rs = { git = "https://github.com/CQCL/tket-json-rs", rev="619db15d3", quantinuum-hugr = { workspace = true } portgraph = { workspace = true, features = ["pyo3"] } pyo3 = { workspace = true, features = ["extension-module"] } +itertools.workspace = true diff --git a/pyrs/src/lib.rs b/pyrs/src/lib.rs index d070e07c..d845f2c9 100644 --- a/pyrs/src/lib.rs +++ b/pyrs/src/lib.rs @@ -1,11 +1,13 @@ //! Python bindings for TKET2. #![warn(missing_docs)] use circuit::{add_circuit_module, try_with_hugr}; +use optimiser::add_optimiser_module; use pyo3::prelude::*; -use tket2::{json::TKETDecode, optimiser::taso, passes::apply_greedy_commutation}; +use tket2::{json::TKETDecode, passes::apply_greedy_commutation}; use tket_json_rs::circuit_json::SerialCircuit; mod circuit; +mod optimiser; #[pyfunction] fn greedy_depth_reduce(py_c: PyObject) -> PyResult<(PyObject, u32)> { @@ -55,11 +57,3 @@ fn add_pass_module(py: Python, parent: &PyModule) -> PyResult<()> { parent.add_submodule(m)?; Ok(()) } - -/// circuit optimisation module -fn add_optimiser_module(py: Python, parent: &PyModule) -> PyResult<()> { - let m = PyModule::new(py, "optimiser")?; - m.add_class::()?; - - parent.add_submodule(m) -} diff --git a/src/optimiser.rs b/src/optimiser.rs index 0760d0d3..4656d0be 100644 --- a/src/optimiser.rs +++ b/src/optimiser.rs @@ -6,4 +6,4 @@ pub mod taso; #[cfg(feature = "portmatching")] pub use taso::DefaultTasoOptimiser; -pub use taso::TasoOptimiser; +pub use taso::{TasoLogger, TasoOptimiser}; diff --git a/src/optimiser/taso.rs b/src/optimiser/taso.rs index 6d73e68c..9dde7db6 100644 --- a/src/optimiser/taso.rs +++ b/src/optimiser/taso.rs @@ -16,13 +16,12 @@ mod hugr_hash_set; mod hugr_pchannel; mod hugr_pqueue; pub mod log; -#[cfg(all(feature = "pyo3", feature = "portmatching"))] -pub mod pyo3; mod qtz_circuit; mod worker; use crossbeam_channel::select; pub use eq_circ_class::{load_eccs_json_file, EqCircClass}; +pub use log::TasoLogger; use std::num::NonZeroUsize; use std::time::{Duration, Instant}; @@ -37,8 +36,6 @@ use crate::optimiser::taso::worker::TasoWorker; use crate::rewrite::strategy::RewriteStrategy; use crate::rewrite::Rewriter; -use self::log::TasoLogger; - /// The TASO optimiser. /// /// Adapted from [Quartz][], and originally [TASO][]. diff --git a/src/optimiser/taso/pyo3.rs b/src/optimiser/taso/pyo3.rs deleted file mode 100644 index 47dd600c..00000000 --- a/src/optimiser/taso/pyo3.rs +++ /dev/null @@ -1,62 +0,0 @@ -//! PyO3 wrapper for the TASO optimiser. - -use std::{fs, num::NonZeroUsize}; - -use pyo3::{exceptions::PyTypeError, prelude::*}; -use tket_json_rs::circuit_json::SerialCircuit; - -use crate::{json::TKETDecode, utils::pyobj_as_hugr}; - -use super::{log::TasoLogger, DefaultTasoOptimiser}; - -/// Wrapped [`DefaultTasoOptimiser`]. -/// -/// Currently only exposes loading from an ECC file using the constructor -/// and optimising using default logging settings. -#[pyclass(name = "TasoOptimiser")] -pub struct PyDefaultTasoOptimiser(DefaultTasoOptimiser); - -#[pymethods] -impl PyDefaultTasoOptimiser { - /// Create a new [`PyDefaultTasoOptimiser`] from a precompiled rewriter. - #[staticmethod] - pub fn load_precompiled(path: &str) -> Self { - Self(DefaultTasoOptimiser::default_with_rewriter_binary(path).unwrap()) - } - - /// Create a new [`PyDefaultTasoOptimiser`] from ECC sets. - /// - /// This will compile the rewriter from the provided ECC JSON file. - #[staticmethod] - pub fn compile_eccs(path: &str) -> Self { - Self(DefaultTasoOptimiser::default_with_eccs_json_file(path).unwrap()) - } - - /// Run the optimiser on a circuit. - /// - /// Returns an optimised circuit and log the progress to a CSV - /// file called "best_circs.csv". - pub fn optimise( - &self, - circ: PyObject, - timeout: Option, - n_threads: Option, - ) -> PyResult { - let circ = pyobj_as_hugr(circ)?; - let circ_candidates_csv = fs::File::create("best_circs.csv").unwrap(); - - let taso_logger = TasoLogger::new(circ_candidates_csv); - let opt_circ = self.0.optimise_with_log( - &circ, - taso_logger, - timeout, - n_threads.unwrap_or(NonZeroUsize::new(1).unwrap()), - ); - let ser_circ = - SerialCircuit::encode(&opt_circ).map_err(|e| PyTypeError::new_err(e.to_string()))?; - let tk1_circ = ser_circ - .to_tket1() - .map_err(|e| PyTypeError::new_err(e.to_string()))?; - Ok(tk1_circ) - } -} diff --git a/src/portmatching/pyo3.rs b/src/portmatching/pyo3.rs index b2180768..1ff2a7b6 100644 --- a/src/portmatching/pyo3.rs +++ b/src/portmatching/pyo3.rs @@ -4,15 +4,16 @@ use std::fmt; use derive_more::{From, Into}; use hugr::hugr::views::sibling_subgraph::PyInvalidReplacementError; -use hugr::Port; +use hugr::{Hugr, Port}; use itertools::Itertools; use portmatching::{HashMap, PatternID}; use pyo3::{prelude::*, types::PyIterator}; +use tket_json_rs::circuit_json::SerialCircuit; use super::{CircuitPattern, PatternMatch, PatternMatcher}; use crate::circuit::Circuit; +use crate::json::TKETDecode; use crate::rewrite::CircuitRewrite; -use crate::utils::pyobj_as_hugr; #[pymethods] impl CircuitPattern { @@ -195,3 +196,9 @@ impl Node { format!("{:?}", self) } } + +fn pyobj_as_hugr(circ: PyObject) -> PyResult { + let ser_c = SerialCircuit::_from_tket1(circ); + let hugr: Hugr = ser_c.decode()?; + Ok(hugr) +} diff --git a/src/utils.rs b/src/utils.rs index 367e5d1f..124cda37 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -13,24 +13,6 @@ pub(crate) fn type_is_linear(typ: &Type) -> bool { !TypeBound::Copyable.contains(typ.least_upper_bound()) } -// Convert a pytket object to HUGR -#[cfg(feature = "pyo3")] -mod pyo3 { - use hugr::Hugr; - use pyo3::prelude::*; - use tket_json_rs::circuit_json::SerialCircuit; - - use crate::json::TKETDecode; - - pub(crate) fn pyobj_as_hugr(circ: PyObject) -> PyResult { - let ser_c = SerialCircuit::_from_tket1(circ); - let hugr: Hugr = ser_c.decode()?; - Ok(hugr) - } -} -#[cfg(feature = "pyo3")] -pub(crate) use self::pyo3::pyobj_as_hugr; - // utility for building simple qubit-only circuits. #[allow(unused)] pub(crate) fn build_simple_circuit( From 9b3f9bbc30ea34d93bb84c2a267b16fc82f01626 Mon Sep 17 00:00:00 2001 From: Luca Mondada Date: Fri, 29 Sep 2023 13:04:22 +0200 Subject: [PATCH 7/9] Actually add file --- pyrs/src/optimiser.rs | 67 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 67 insertions(+) create mode 100644 pyrs/src/optimiser.rs diff --git a/pyrs/src/optimiser.rs b/pyrs/src/optimiser.rs new file mode 100644 index 00000000..bd7b6a4c --- /dev/null +++ b/pyrs/src/optimiser.rs @@ -0,0 +1,67 @@ +//! PyO3 wrapper for the TASO circuit optimiser. + +use std::{fs, num::NonZeroUsize, path::PathBuf}; + +use pyo3::prelude::*; +use tket2::optimiser::{DefaultTasoOptimiser, TasoLogger}; + +use crate::circuit::update_hugr; + +/// The circuit optimisation module. +pub fn add_optimiser_module(py: Python, parent: &PyModule) -> PyResult<()> { + let m = PyModule::new(py, "optimiser")?; + m.add_class::()?; + + parent.add_submodule(m) +} + +/// Wrapped [`DefaultTasoOptimiser`]. +/// +/// Currently only exposes loading from an ECC file using the constructor +/// and optimising using default logging settings. +#[pyclass(name = "TasoOptimiser")] +pub struct PyDefaultTasoOptimiser(DefaultTasoOptimiser); + +#[pymethods] +impl PyDefaultTasoOptimiser { + /// Create a new [`PyDefaultTasoOptimiser`] from a precompiled rewriter. + #[staticmethod] + pub fn load_precompiled(path: PathBuf) -> Self { + Self(DefaultTasoOptimiser::default_with_rewriter_binary(path).unwrap()) + } + + /// Create a new [`PyDefaultTasoOptimiser`] from ECC sets. + /// + /// This will compile the rewriter from the provided ECC JSON file. + #[staticmethod] + pub fn compile_eccs(path: &str) -> Self { + Self(DefaultTasoOptimiser::default_with_eccs_json_file(path).unwrap()) + } + + /// Run the optimiser on a circuit. + /// + /// Returns an optimised circuit and optionally log the progress to a CSV + /// file. + pub fn optimise( + &self, + circ: PyObject, + timeout: Option, + n_threads: Option, + log_progress: Option, + ) -> PyResult { + let taso_logger = log_progress + .map(|file_name| { + let log_file = fs::File::create(file_name).unwrap(); + TasoLogger::new(log_file) + }) + .unwrap_or_default(); + update_hugr(circ, |circ| { + self.0.optimise_with_log( + &circ, + taso_logger, + timeout, + n_threads.unwrap_or(NonZeroUsize::new(1).unwrap()), + ) + }) + } +} From ba838a0fd37db3c2654559169cabbd5b27710733 Mon Sep 17 00:00:00 2001 From: Luca Mondada Date: Fri, 29 Sep 2023 13:11:21 +0200 Subject: [PATCH 8/9] feat: Serialisation for ECCRewriter (#141) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Agustín Borgna <121866228+aborgna-q@users.noreply.github.com> # Conflicts: # compile-rewriter/src/main.rs # src/rewrite/ecc_rewriter.rs --- compile-rewriter/src/main.rs | 2 +- src/rewrite/ecc_rewriter.rs | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/compile-rewriter/src/main.rs b/compile-rewriter/src/main.rs index 6960426e..77384f03 100644 --- a/compile-rewriter/src/main.rs +++ b/compile-rewriter/src/main.rs @@ -58,7 +58,7 @@ fn main() { } else { output_path.to_path_buf() }; - let output_file = rewriter.save_binary(output_file.to_str().unwrap()).unwrap(); + let output_file = rewriter.save_binary(output_file).unwrap(); println!("Written rewriter to {:?}", output_file); // Print the file size of output_file in megabytes diff --git a/src/rewrite/ecc_rewriter.rs b/src/rewrite/ecc_rewriter.rs index dba5d3cc..8f0a5a8d 100644 --- a/src/rewrite/ecc_rewriter.rs +++ b/src/rewrite/ecc_rewriter.rs @@ -129,7 +129,8 @@ impl ECCRewriter { ) -> Result { let mut file_name = PathBuf::from(name.as_ref()); file_name.set_extension("rwr"); - let mut file = File::create(&file_name)?; + let file = File::create(&file_name)?; + let mut file = io::BufWriter::new(file); self.save_binary_io(&mut file)?; Ok(file_name) } From 45f26cdc8cd8875e96a2da3e6ee15b5c1fde7b95 Mon Sep 17 00:00:00 2001 From: Luca Mondada Date: Fri, 29 Sep 2023 13:14:00 +0200 Subject: [PATCH 9/9] Remove itertools dep --- pyrs/Cargo.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/pyrs/Cargo.toml b/pyrs/Cargo.toml index 7f113b24..0008e1ce 100644 --- a/pyrs/Cargo.toml +++ b/pyrs/Cargo.toml @@ -18,4 +18,3 @@ tket-json-rs = { git = "https://github.com/CQCL/tket-json-rs", rev="619db15d3", quantinuum-hugr = { workspace = true } portgraph = { workspace = true, features = ["pyo3"] } pyo3 = { workspace = true, features = ["extension-module"] } -itertools.workspace = true