Skip to content

Commit

Permalink
feat: Serialisation for ECCRewriter
Browse files Browse the repository at this point in the history
  • Loading branch information
lmondada committed Sep 27, 2023
1 parent be8b9a9 commit d068c3b
Show file tree
Hide file tree
Showing 7 changed files with 142 additions and 97 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ harness = false

[workspace]

members = ["pyrs", "compile-matcher", "taso-optimiser"]
members = ["pyrs", "compile-rewriter", "taso-optimiser"]

[workspace.dependencies]

Expand Down
93 changes: 0 additions & 93 deletions compile-matcher/src/main.rs

This file was deleted.

File renamed without changes.
File renamed without changes.
File renamed without changes.
75 changes: 75 additions & 0 deletions compile-rewriter/src/main.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
use std::fs;
use std::path::Path;
use std::process::exit;
use std::time::Instant;

use clap::Parser;

use tket2::rewrite::ECCRewriter;

/// Program to precompile patterns from files into a PatternMatcher stored as binary file.
#[derive(Parser, Debug)]
#[clap(version = "1.0", long_about = None)]
#[clap(
about = "Precompiles ECC sets into a TKET2 Rewriter. The resulting binary files can be loaded into TKET2 for circuit optimisation."
)]
struct CmdLineArgs {
// TODO: Differentiate between TK1 input and ECC input
/// Name of input file/folder
#[arg(
short,
long,
value_name = "FILE",
help = "Sets the input file to use. It must be a JSON file of ECC sets in the Quartz format."
)]
input: String,
/// Name of output file/folder
#[arg(
short,
long,
value_name = "FILE",
default_value = ".",
help = "Sets the output file or folder. Defaults to \"matcher.rwr\" if no file name is provided. The extension of the file name will always be set or amended to be `.rwr`."
)]
output: String,
}

fn main() {
let opts = CmdLineArgs::parse();

let input_path = Path::new(&opts.input);
let output_path = Path::new(&opts.output);

if !input_path.is_file() || input_path.extension().unwrap() != "json" {
panic!("Input must be a JSON file");
};
let start_time = Instant::now();
println!("Compiling rewriter...");
let Ok(rewriter) = ECCRewriter::try_from_eccs_json_file(input_path) else {
eprintln!(
"Unable to load ECC file {:?}. Is it a JSON file of Quartz-generated ECCs?",
input_path
);
exit(1);
};
println!("Saving to file...");
let output_file = if output_path.is_dir() {
output_path.join("matcher.rwr")
} else {
output_path.to_path_buf()
};
let output_file = rewriter.save_binary(output_file.to_str().unwrap()).unwrap();
println!("Written rewriter to {:?}", output_file);

// Print the file size of output_file in megabytes
if let Ok(metadata) = fs::metadata(&output_file) {
let file_size = metadata.len() as f64 / (1024.0 * 1024.0);
println!("File size: {:.2} MB", file_size);
}
let elapsed = start_time.elapsed();
println!(
"Done in {}.{:03} seconds",
elapsed.as_secs(),
elapsed.subsec_millis()
);
}
69 changes: 66 additions & 3 deletions src/rewrite/ecc_rewriter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,10 @@
use derive_more::{From, Into};
use itertools::Itertools;
use portmatching::PatternID;
use std::io;
use std::fs::File;
use std::path::Path;
use std::{io, path::PathBuf};
use thiserror::Error;

use hugr::Hugr;

Expand All @@ -28,7 +30,7 @@ use crate::{

use super::{CircuitRewrite, Rewriter};

#[derive(Debug, Clone, Copy, PartialEq, Eq, From, Into)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, From, Into, serde::Serialize, serde::Deserialize)]
struct TargetID(usize);

/// A rewriter based on circuit equivalence classes.
Expand All @@ -37,7 +39,7 @@ struct TargetID(usize);
/// Valid rewrites turn a non-representative circuit into its representative,
/// or a representative circuit into any of the equivalent non-representative
/// circuits.
#[derive(Debug, Clone)]
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct ECCRewriter {
/// Matcher for finding patterns.
matcher: PatternMatcher,
Expand Down Expand Up @@ -91,6 +93,53 @@ impl ECCRewriter {
.iter()
.map(|id| &self.targets[id.0])
}

/// Serialise a rewriter to an IO stream.
///
/// Precomputed rewriters can be serialised as binary and then loaded
/// later using [`ECCRewriter::load_binary_io`].
pub fn save_binary_io<W: io::Write>(
&self,
writer: &mut W,
) -> Result<(), RewriterSerialisationError> {
rmp_serde::encode::write(writer, &self)?;
Ok(())
}

/// Load a rewriter from an IO stream.
///
/// Loads streams as created by [`ECCRewriter::save_binary_io`].
pub fn load_binary_io<R: io::Read>(reader: &mut R) -> Result<Self, RewriterSerialisationError> {
let matcher: Self = rmp_serde::decode::from_read(reader)?;
Ok(matcher)
}

/// Save a rewriter as a binary file.
///
/// Precomputed rewriters can be saved as binary files and then loaded
/// later using [`ECCRewriter::load_binary`].
///
/// The extension of the file name will always be set or amended to be
/// `.rwr`.
///
/// If successful, returns the path to the newly created file.
pub fn save_binary(
&self,
name: impl AsRef<Path>,
) -> Result<PathBuf, RewriterSerialisationError> {
let mut file_name = PathBuf::from(name.as_ref());
file_name.set_extension("rwr");
let mut file = File::create(&file_name)?;
self.save_binary_io(&mut file)?;
Ok(file_name)
}

/// Loads a rewriter saved using [`ECCRewriter::save_binary`].
pub fn load_binary(name: impl AsRef<Path>) -> Result<Self, RewriterSerialisationError> {
let file = File::open(name)?;
let mut reader = std::io::BufReader::new(file);
Self::load_binary_io(&mut reader)
}
}

impl Rewriter for ECCRewriter {
Expand All @@ -109,6 +158,20 @@ impl Rewriter for ECCRewriter {
}
}

/// Errors that can occur when (de)serialising an [`ECCRewriter`].
#[derive(Debug, Error)]
pub enum RewriterSerialisationError {
/// An IO error occured
#[error("IO error: {0}")]
Io(#[from] io::Error),
/// An error occured during deserialisation
#[error("Deserialisation error: {0}")]
Deserialisation(#[from] rmp_serde::decode::Error),
/// An error occured during serialisation
#[error("Serialisation error: {0}")]
Serialisation(#[from] rmp_serde::encode::Error),
}

fn into_targets(rep_sets: Vec<EqCircClass>) -> Vec<Hugr> {
rep_sets
.into_iter()
Expand Down

0 comments on commit d068c3b

Please sign in to comment.