diff --git a/src/rewrite/ecc_rewriter.rs b/src/rewrite/ecc_rewriter.rs index 8f0a5a8d..c8e95f0f 100644 --- a/src/rewrite/ecc_rewriter.rs +++ b/src/rewrite/ecc_rewriter.rs @@ -13,17 +13,22 @@ //! of the Quartz repository. use derive_more::{From, Into}; +use hugr::hugr::PortIndex; +use hugr::ops::OpTrait; use itertools::Itertools; use portmatching::PatternID; -use std::fs::File; -use std::path::Path; -use std::{io, path::PathBuf}; +use std::{ + collections::HashSet, + fs::File, + io, + path::{Path, PathBuf}, +}; use thiserror::Error; use hugr::Hugr; use crate::{ - circuit::Circuit, + circuit::{remove_empty_wire, Circuit}, optimiser::taso::{load_eccs_json_file, EqCircClass}, portmatching::{CircuitPattern, PatternMatcher}, }; @@ -49,6 +54,9 @@ pub struct ECCRewriter { /// target TargetIDs. The usize index of PatternID is used to index into /// the outer vector. rewrite_rules: Vec>, + /// Wires that have been removed in the pattern circuit -- to be removed + /// in the target circuit as well when generating a rewrite. + empty_wires: Vec>, } impl ECCRewriter { @@ -72,18 +80,34 @@ impl ECCRewriter { let eccs = eccs.into(); let rewrite_rules = get_rewrite_rules(&eccs); let patterns = get_patterns(&eccs); + let targets = into_targets(eccs); // Remove failed patterns - let (patterns, rewrite_rules): (Vec<_>, Vec<_>) = patterns + let (patterns, empty_wires, rewrite_rules): (Vec<_>, Vec<_>, Vec<_>) = patterns .into_iter() .zip(rewrite_rules) - .filter_map(|(p, r)| Some((p?, r))) - .unzip(); - let targets = into_targets(eccs); + .filter_map(|(p, r)| { + // Filter out target IDs where empty wires are not empty + let (pattern, pattern_empty_wires) = p?; + let targets = r + .into_iter() + .filter(|&id| { + let circ = &targets[id.0]; + let target_empty_wires: HashSet<_> = + empty_wires(&circ).into_iter().collect(); + pattern_empty_wires + .iter() + .all(|&w| target_empty_wires.contains(&w)) + }) + .collect(); + Some((pattern, pattern_empty_wires, targets)) + }) + .multiunzip(); let matcher = PatternMatcher::from_patterns(patterns); Self { matcher, targets, rewrite_rules, + empty_wires, } } @@ -151,7 +175,11 @@ impl Rewriter for ECCRewriter { .flat_map(|m| { let pattern_id = m.pattern_id(); self.get_targets(pattern_id).map(move |repl| { - m.to_rewrite(circ.base_hugr(), repl.clone()) + let mut repl = repl.clone(); + for &empty_qb in self.empty_wires[pattern_id.0].iter().rev() { + remove_empty_wire(&mut repl, empty_qb).unwrap(); + } + m.to_rewrite(circ.base_hugr(), repl) .expect("invalid replacement") }) }) @@ -198,11 +226,43 @@ fn get_rewrite_rules(rep_sets: &[EqCircClass]) -> Vec> { rewrite_rules } -fn get_patterns(rep_sets: &[EqCircClass]) -> Vec> { +/// For an equivalence class, return all valid patterns together with the +/// indices of the wires that have been removed in the pattern circuit. +fn get_patterns(rep_sets: &[EqCircClass]) -> Vec)>> { rep_sets .iter() .flat_map(|rs| rs.circuits()) - .map(|circ| CircuitPattern::try_from_circuit(circ).ok()) + .map(|circ| { + let empty_qbs = empty_wires(circ); + let mut circ = circ.clone(); + for &qb in empty_qbs.iter().rev() { + remove_empty_wire(&mut circ, qb).unwrap(); + } + CircuitPattern::try_from_circuit(&circ) + .ok() + .map(|circ| (circ, empty_qbs)) + }) + .collect() +} + +/// The port offsets of wires that are empty. +fn empty_wires(circ: &impl Circuit) -> Vec { + let inp = circ.input(); + circ.node_outputs(inp) + // Only consider dataflow edges + .filter(|&p| circ.get_optype(inp).signature().get(p).is_some()) + // Only consider ports linked to at most one other port + .filter_map(|p| Some((p, circ.linked_ports(inp, p).at_most_one().ok()?))) + // Ports are either connected to output or nothing + .filter_map(|(from, to)| { + if let Some((n, _)) = to { + // Wires connected to output + (n == circ.output()).then_some(from.index()) + } else { + // Wires connected to nothing + Some(from.index()) + } + }) .collect() } @@ -305,4 +365,15 @@ mod tests { let exp_n_eccs_of_len = [0, 4 * 2 + 5 * 3, 4, 5]; assert_eq!(n_eccs_of_len, exp_n_eccs_of_len); } + + /// Some inputs are left untouched: these parameters should be removed to + /// obtain convex patterns + #[test] + fn ecc_rewriter_empty_params() { + let test_file = "test_files/cx_cx_eccs.json"; + let rewriter = ECCRewriter::try_from_eccs_json_file(test_file).unwrap(); + + let cx_cx = cx_cx(); + assert_eq!(rewriter.get_rewrites(&cx_cx).len(), 1); + } } diff --git a/test_files/cx_cx_eccs.json b/test_files/cx_cx_eccs.json new file mode 100644 index 00000000..8cf933f3 --- /dev/null +++ b/test_files/cx_cx_eccs.json @@ -0,0 +1,8 @@ +[[], +{ +"7779_2": [ +[[3,0,0,0,["2acba3946000"],[7.23975983031379111e-02,-6.01244148396314765e-02]],[]] +,[[3,0,0,2,["2acba3946000"],[7.23975983031379111e-02,-6.01244148396314765e-02]],[["cx", ["Q0", "Q1"],["Q0", "Q1"]],["cx", ["Q0", "Q1"],["Q0", "Q1"]]]] +] +} +] \ No newline at end of file