Skip to content

Commit

Permalink
Merge branch 'main' into feat/commutation
Browse files Browse the repository at this point in the history
  • Loading branch information
ss2165 committed Sep 7, 2023
2 parents 65659d0 + f4ed814 commit 46fd4f8
Show file tree
Hide file tree
Showing 12 changed files with 614 additions and 122 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ strum_macros = "0.25.2"
strum = "0.25.0"
fxhash = "0.2.1"
rmp-serde = { version = "1.1.2", optional = true }
delegate = "0.10.0"

[features]
pyo3 = [
Expand Down
14 changes: 7 additions & 7 deletions compile-matcher/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,14 @@ use hugr::HugrView;
use itertools::Itertools;

use tket2::json::load_tk1_json_file;
// Import the CircuitMatcher struct and its methods
use tket2::passes::taso::rep_sets_from_path;
use tket2::portmatching::{CircuitMatcher, CircuitPattern};
// Import the PatternMatcher struct and its methods
use tket2::passes::taso::load_eccs_json_file;
use tket2::portmatching::{CircuitPattern, PatternMatcher};

/// Program to precompile patterns from files into a CircuitMatcher stored as binary file.
/// Program to precompile patterns from files into a PatternMatcher stored as binary file.
#[derive(Parser, Debug)]
#[clap(version = "1.0", long_about = None)]
#[clap(about = "Precompiles patterns from files into a CircuitMatcher stored as binary file.")]
#[clap(about = "Precompiles patterns from files into a PatternMatcher stored as binary file.")]
struct CmdLineArgs {
// TODO: Differentiate between TK1 input and ECC input
/// Name of input file/folder
Expand Down Expand Up @@ -45,7 +45,7 @@ fn main() {

let all_circs = if input_path.is_file() {
// Input is an ECC file in JSON format
let eccs = rep_sets_from_path(input_path);
let eccs = load_eccs_json_file(input_path);
eccs.into_iter()
.flat_map(|ecc| ecc.into_circuits())
.collect_vec()
Expand Down Expand Up @@ -78,7 +78,7 @@ fn main() {
} else {
output_path.to_path_buf()
};
let matcher = CircuitMatcher::from_patterns(patterns);
let matcher = PatternMatcher::from_patterns(patterns);
matcher.save_binary(output_file.to_str().unwrap()).unwrap();
println!("Written matcher to {:?}", output_file);

Expand Down
4 changes: 2 additions & 2 deletions pyrs/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use tket2::{
circuit::HierarchyView,
json::TKETDecode,
passes::apply_greedy_commutation,
portmatching::{CircuitMatcher, CircuitPattern},
portmatching::{CircuitPattern, PatternMatcher},
};
use tket_json_rs::circuit_json::SerialCircuit;

Expand Down Expand Up @@ -63,7 +63,7 @@ fn pyrs(py: Python, m: &PyModule) -> PyResult<()> {
fn add_patterns_module(py: Python, parent: &PyModule) -> PyResult<()> {
let m = PyModule::new(py, "patterns")?;
m.add_class::<CircuitPattern>()?;
m.add_class::<CircuitMatcher>()?;
m.add_class::<PatternMatcher>()?;
parent.add_submodule(m)?;
Ok(())
}
Expand Down
6 changes: 3 additions & 3 deletions pyrs/test/test_portmatching.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,15 @@ def test_simple_matching():
p1 = patterns.CircuitPattern(Circuit(2).CX(0, 1).H(1))
p2 = patterns.CircuitPattern(Circuit(2).H(0).CX(1, 0))

matcher = patterns.CircuitMatcher(iter([p1, p2]))
matcher = patterns.PatternMatcher(iter([p1, p2]))

assert len(matcher.find_matches(c)) == 2


def test_non_convex_pattern():
""" two-qubit circuits can't match three-qb ones """
p1 = patterns.CircuitPattern(Circuit(3).CX(0, 1).CX(1, 2))
matcher = patterns.CircuitMatcher(iter([p1]))
matcher = patterns.PatternMatcher(iter([p1]))

c = Circuit(2).CX(0, 1).CX(1, 0)
assert len(matcher.find_matches(c)) == 0
Expand All @@ -39,6 +39,6 @@ def test_larger_matching():
p3 = patterns.CircuitPattern(Circuit(2).CX(0, 1).CX(1, 0))
p4 = patterns.CircuitPattern(Circuit(3).CX(0, 1).CX(1, 2))

matcher = patterns.CircuitMatcher(iter([p1, p2, p3, p4]))
matcher = patterns.PatternMatcher(iter([p1, p2, p3, p4]))

assert len(matcher.find_matches(c)) == 6
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ pub mod extension;
pub mod json;
pub(crate) mod ops;
pub mod passes;
pub mod rewrite;
pub use ops::{symbolic_constant_op, Pauli, T2Op};

#[cfg(feature = "portmatching")]
Expand Down
22 changes: 17 additions & 5 deletions src/passes/taso.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,18 @@ use hugr::Hugr;

mod qtz_circuit;

/// An equivalent circuit class (ECC), with a canonical representative.
#[derive(Clone)]
#[allow(unused)] // TODO
pub struct RepCircSet {
pub struct EqCircClass {
rep_circ: Hugr,
others: Vec<Hugr>,
}

impl RepCircSet {
impl EqCircClass {
pub fn new(rep_circ: Hugr, others: Vec<Hugr>) -> Self {
Self { rep_circ, others }
}

/// The representative circuit of the equivalence class.
pub fn rep_circ(&self) -> &Hugr {
&self.rep_circ
Expand All @@ -39,11 +43,19 @@ impl RepCircSet {
pub fn into_circuits(self) -> impl Iterator<Item = Hugr> {
std::iter::once(self.rep_circ).chain(self.others)
}

/// The number of circuits in the equivalence class.
///
/// An ECC always has a representative circuit, so this method will always
/// return an integer strictly greater than 0.
pub fn n_circuits(&self) -> usize {
self.others.len() + 1
}
}

// TODO refactor so both implementations share more code

pub fn rep_sets_from_path(path: impl AsRef<Path>) -> Vec<RepCircSet> {
pub fn load_eccs_json_file(path: impl AsRef<Path>) -> Vec<EqCircClass> {
let all_circs = qtz_circuit::load_ecc_set(path);

all_circs
Expand All @@ -52,7 +64,7 @@ pub fn rep_sets_from_path(path: impl AsRef<Path>) -> Vec<RepCircSet> {
// TODO is the rep circ always the first??
let rep_circ = all.remove(0);

RepCircSet {
EqCircClass {
rep_circ,
others: all,
}
Expand Down
2 changes: 1 addition & 1 deletion src/portmatching.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ pub mod pattern;
#[allow(missing_docs)]
pub mod pyo3;

pub use matcher::{CircuitMatch, CircuitMatcher, CircuitRewrite};
pub use matcher::{PatternMatch, PatternMatcher};
pub use pattern::CircuitPattern;

use hugr::Port;
Expand Down
Loading

0 comments on commit 46fd4f8

Please sign in to comment.