Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Python bindings for TASO #142

Merged
merged 13 commits into from
Sep 29, 2023
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ harness = false

[workspace]

members = ["pyrs", "compile-matcher", "taso-optimiser"]
members = ["pyrs", "compile-rewriter", "taso-optimiser"]

[workspace.dependencies]

Expand Down
93 changes: 0 additions & 93 deletions compile-matcher/src/main.rs

This file was deleted.

File renamed without changes.
File renamed without changes.
File renamed without changes.
75 changes: 75 additions & 0 deletions compile-rewriter/src/main.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
use std::fs;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

git doesn't seem to have picked up this rename

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's because there were too many changes, so it views it as a new file. Can I do anything about it? Pretty sure I git mved it.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no don't worry

use std::path::Path;
use std::process::exit;
use std::time::Instant;

use clap::Parser;

use tket2::rewrite::ECCRewriter;

/// Command-line options of the ECC rewriter precompiler.
#[derive(Parser, Debug)]
#[clap(
    version = "1.0",
    long_about = None,
    about = "Precompiles ECC sets into a TKET2 Rewriter. The resulting binary files can be loaded into TKET2 for circuit optimisation."
)]
struct CmdLineArgs {
    // TODO: Differentiate between TK1 input and ECC input
    /// Path of the input file holding the ECC sets.
    #[arg(
        short,
        long,
        value_name = "FILE",
        help = "Sets the input file to use. It must be a JSON file of ECC sets in the Quartz format."
    )]
    input: String,

    /// Path of the output file, or of the folder to write it into.
    #[arg(
        short,
        long,
        value_name = "FILE",
        default_value = ".",
        help = "Sets the output file or folder. Defaults to \"matcher.rwr\" if no file name is provided. The extension of the file name will always be set or amended to be `.rwr`."
    )]
    output: String,
}

/// Compiles a JSON file of ECC sets into a rewriter and writes it to disk.
///
/// The input must be an existing file with a `json` extension. If the output
/// path is a directory, the rewriter is written to `matcher.rwr` inside it;
/// otherwise the output path is used as the file name. Progress, the size of
/// the written file and the total elapsed time are printed to stdout.
///
/// # Panics
///
/// Panics if the input is not a JSON file, if the output path is not valid
/// UTF-8, or if saving the rewriter fails. Exits with status 1 if the ECC file
/// cannot be loaded.
fn main() {
    let opts = CmdLineArgs::parse();

    let input_path = Path::new(&opts.input);
    let output_path = Path::new(&opts.output);

    // A missing (or non-UTF-8) extension is treated like any other non-JSON
    // extension, so the user gets the intended message rather than an opaque
    // `Option::unwrap` panic.
    let has_json_extension = input_path.extension().and_then(|ext| ext.to_str()) == Some("json");
    if !input_path.is_file() || !has_json_extension {
        panic!("Input must be a JSON file");
    }

    let start_time = Instant::now();
    println!("Compiling rewriter...");
    let Ok(rewriter) = ECCRewriter::try_from_eccs_json_file(input_path) else {
        eprintln!(
            "Unable to load ECC file {:?}. Is it a JSON file of Quartz-generated ECCs?",
            input_path
        );
        exit(1);
    };

    println!("Saving to file...");
    let output_file = if output_path.is_dir() {
        output_path.join("matcher.rwr")
    } else {
        output_path.to_path_buf()
    };
    // `save_binary` returns the path it actually wrote to, which is what gets reported.
    let output_file = rewriter.save_binary(output_file.to_str().unwrap()).unwrap();
    println!("Written rewriter to {:?}", output_file);

    // Print the file size of output_file in megabytes
    if let Ok(metadata) = fs::metadata(&output_file) {
        let file_size = metadata.len() as f64 / (1024.0 * 1024.0);
        println!("File size: {:.2} MB", file_size);
    }

    let elapsed = start_time.elapsed();
    println!(
        "Done in {}.{:03} seconds",
        elapsed.as_secs(),
        elapsed.subsec_millis()
    );
}
11 changes: 10 additions & 1 deletion pyrs/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
#![warn(missing_docs)]
use circuit::{add_circuit_module, try_with_hugr};
use pyo3::prelude::*;
use tket2::{json::TKETDecode, passes::apply_greedy_commutation};
use tket2::{json::TKETDecode, optimiser::taso, passes::apply_greedy_commutation};
use tket_json_rs::circuit_json::SerialCircuit;

mod circuit;
Expand All @@ -22,6 +22,7 @@ fn pyrs(py: Python, m: &PyModule) -> PyResult<()> {
add_circuit_module(py, m)?;
add_pattern_module(py, m)?;
add_pass_module(py, m)?;
add_optimiser_module(py, m)?;
Ok(())
}

Expand Down Expand Up @@ -54,3 +55,11 @@ fn add_pass_module(py: Python, parent: &PyModule) -> PyResult<()> {
parent.add_submodule(m)?;
Ok(())
}

/// Registers the `optimiser` submodule, which holds the TASO optimiser class, on `parent`.
fn add_optimiser_module(py: Python, parent: &PyModule) -> PyResult<()> {
    let optimiser = PyModule::new(py, "optimiser")?;
    optimiser.add_class::<taso::pyo3::PyDefaultTasoOptimiser>()?;
    parent.add_submodule(optimiser)?;
    Ok(())
}
21 changes: 15 additions & 6 deletions src/optimiser/taso.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ mod hugr_hash_set;
mod hugr_pchannel;
mod hugr_pqueue;
pub mod log;
#[cfg(all(feature = "pyo3", feature = "portmatching"))]
pub mod pyo3;
mod qtz_circuit;
mod worker;

Expand All @@ -37,9 +39,6 @@ use crate::rewrite::Rewriter;

use self::log::TasoLogger;

#[cfg(feature = "portmatching")]
use std::io;

/// The TASO optimiser.
///
/// Adapted from [Quartz][], and originally [TASO][].
Expand Down Expand Up @@ -274,8 +273,11 @@ where
mod taso_default {
use hugr::ops::OpType;
use hugr::HugrView;
use std::io;
use std::path::Path;

use crate::ops::op_matches;
use crate::rewrite::ecc_rewriter::RewriterSerialisationError;
use crate::rewrite::strategy::ExhaustiveRewriteStrategy;
use crate::rewrite::ECCRewriter;
use crate::T2Op;
Expand All @@ -291,13 +293,20 @@ mod taso_default {

impl DefaultTasoOptimiser {
/// A sane default optimiser using the given ECC sets.
pub fn default_with_eccs_json_file(
eccs_path: impl AsRef<std::path::Path>,
) -> io::Result<Self> {
pub fn default_with_eccs_json_file(eccs_path: impl AsRef<Path>) -> io::Result<Self> {
let rewriter = ECCRewriter::try_from_eccs_json_file(eccs_path)?;
let strategy = ExhaustiveRewriteStrategy::exhaustive_cx();
Ok(TasoOptimiser::new(rewriter, strategy, num_cx_gates))
}

/// Builds the default optimiser from a rewriter previously saved as a binary file.
///
/// # Errors
///
/// Returns the [`RewriterSerialisationError`] raised while loading the
/// rewriter from `rewriter_path`.
pub fn default_with_rewriter_binary(
    rewriter_path: impl AsRef<Path>,
) -> Result<Self, RewriterSerialisationError> {
    let ecc_rewriter = ECCRewriter::load_binary(rewriter_path)?;
    Ok(TasoOptimiser::new(
        ecc_rewriter,
        ExhaustiveRewriteStrategy::exhaustive_cx(),
        num_cx_gates,
    ))
}
}

fn num_cx_gates(circ: &Hugr) -> usize {
Expand Down
62 changes: 62 additions & 0 deletions src/optimiser/taso/pyo3.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
//! PyO3 wrapper for the TASO optimiser.

use std::{fs, num::NonZeroUsize};

use pyo3::{exceptions::PyTypeError, prelude::*};
use tket_json_rs::circuit_json::SerialCircuit;

use crate::{json::TKETDecode, utils::pyobj_as_hugr};

use super::{log::TasoLogger, DefaultTasoOptimiser};

/// Python-facing wrapper around [`DefaultTasoOptimiser`].
///
/// Exposed to Python under the class name `TasoOptimiser`. Instances are
/// created through the static constructors, either from a precompiled
/// rewriter or from an ECC JSON file, and circuits are optimised using
/// default logging settings.
#[pyclass(name = "TasoOptimiser")]
pub struct PyDefaultTasoOptimiser(DefaultTasoOptimiser);

#[pymethods]
impl PyDefaultTasoOptimiser {
/// Create a new [`PyDefaultTasoOptimiser`] from a precompiled rewriter.
#[staticmethod]
pub fn load_precompiled(path: &str) -> Self {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe you can use Path or PathBuf objects and they will be mapped to python pathlib.Path

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Awesome! It seems I need to pass a PathBuf by value, but I assume given this is just for Python bindings it doesn't matter anyways.

Self(DefaultTasoOptimiser::default_with_rewriter_binary(path).unwrap())
}

/// Build a [`PyDefaultTasoOptimiser`] by compiling a rewriter from the ECC
/// sets stored in the JSON file at `path`.
///
/// # Panics
///
/// Panics if the optimiser cannot be created from the given file.
#[staticmethod]
pub fn compile_eccs(path: &str) -> Self {
    let optimiser = DefaultTasoOptimiser::default_with_eccs_json_file(path).unwrap();
    Self(optimiser)
}

/// Run the optimiser on a circuit.
///
/// Returns an optimised circuit and logs the progress to a CSV
/// file called "best_circs.csv".
pub fn optimise(
&self,
circ: PyObject,
timeout: Option<u64>,
n_threads: Option<NonZeroUsize>,
) -> PyResult<PyObject> {
let circ = pyobj_as_hugr(circ)?;
let circ_candidates_csv = fs::File::create("best_circs.csv").unwrap();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

make this path a (optional?) function parameter?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For sure, done.


let taso_logger = TasoLogger::new(circ_candidates_csv);
let opt_circ = self.0.optimise_with_log(
&circ,
taso_logger,
timeout,
n_threads.unwrap_or(NonZeroUsize::new(1).unwrap()),
);
let ser_circ =
SerialCircuit::encode(&opt_circ).map_err(|e| PyTypeError::new_err(e.to_string()))?;
let tk1_circ = ser_circ
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you should be able to use try_update_hugr from circuit.rs for this?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Turns out: had never known about the goodness in this file! I had to move the pyo3.rs file out of tket2 and into pyrs to use it. That was probably the right thing to do in the first place.

.to_tket1()
.map_err(|e| PyTypeError::new_err(e.to_string()))?;
Ok(tk1_circ)
}
}
11 changes: 2 additions & 9 deletions src/portmatching/pyo3.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,15 @@ use std::fmt;

use derive_more::{From, Into};
use hugr::hugr::views::sibling_subgraph::PyInvalidReplacementError;
use hugr::{Hugr, Port};
use hugr::Port;
use itertools::Itertools;
use portmatching::{HashMap, PatternID};
use pyo3::{prelude::*, types::PyIterator};
use tket_json_rs::circuit_json::SerialCircuit;

use super::{CircuitPattern, PatternMatch, PatternMatcher};
use crate::circuit::Circuit;
use crate::json::TKETDecode;
use crate::rewrite::CircuitRewrite;
use crate::utils::pyobj_as_hugr;

#[pymethods]
impl CircuitPattern {
Expand Down Expand Up @@ -196,9 +195,3 @@ impl Node {
format!("{:?}", self)
}
}

/// Converts a Python circuit object into a [`Hugr`] by way of its serialised form.
fn pyobj_as_hugr(circ: PyObject) -> PyResult<Hugr> {
    // `?` is kept so the decoding error is still converted into a `PyErr`.
    Ok(SerialCircuit::_from_tket1(circ).decode()?)
}
Loading