Skip to content

Commit

Permalink
Merge branch 'main' into feat/more-strategies
Browse files Browse the repository at this point in the history
  • Loading branch information
lmondada authored Oct 3, 2023
2 parents 4c10e51 + d4f1cc6 commit e5d3fe7
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 11 deletions.
93 changes: 82 additions & 11 deletions src/rewrite/ecc_rewriter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,22 @@
//! of the Quartz repository.
use derive_more::{From, Into};
use hugr::hugr::PortIndex;
use hugr::ops::OpTrait;
use itertools::Itertools;
use portmatching::PatternID;
use std::fs::File;
use std::path::Path;
use std::{io, path::PathBuf};
use std::{
collections::HashSet,
fs::File,
io,
path::{Path, PathBuf},
};
use thiserror::Error;

use hugr::Hugr;

use crate::{
circuit::Circuit,
circuit::{remove_empty_wire, Circuit},
optimiser::taso::{load_eccs_json_file, EqCircClass},
portmatching::{CircuitPattern, PatternMatcher},
};
Expand All @@ -49,6 +54,9 @@ pub struct ECCRewriter {
/// target TargetIDs. The usize index of PatternID is used to index into
/// the outer vector.
rewrite_rules: Vec<Vec<TargetID>>,
/// Wires that have been removed in the pattern circuit -- to be removed
/// in the target circuit as well when generating a rewrite.
empty_wires: Vec<Vec<usize>>,
}

impl ECCRewriter {
Expand All @@ -72,18 +80,34 @@ impl ECCRewriter {
let eccs = eccs.into();
let rewrite_rules = get_rewrite_rules(&eccs);
let patterns = get_patterns(&eccs);
let targets = into_targets(eccs);
// Remove failed patterns
let (patterns, rewrite_rules): (Vec<_>, Vec<_>) = patterns
let (patterns, empty_wires, rewrite_rules): (Vec<_>, Vec<_>, Vec<_>) = patterns
.into_iter()
.zip(rewrite_rules)
.filter_map(|(p, r)| Some((p?, r)))
.unzip();
let targets = into_targets(eccs);
.filter_map(|(p, r)| {
// Filter out target IDs where empty wires are not empty
let (pattern, pattern_empty_wires) = p?;
let targets = r
.into_iter()
.filter(|&id| {
let circ = &targets[id.0];
let target_empty_wires: HashSet<_> =
empty_wires(&circ).into_iter().collect();
pattern_empty_wires
.iter()
.all(|&w| target_empty_wires.contains(&w))
})
.collect();
Some((pattern, pattern_empty_wires, targets))
})
.multiunzip();
let matcher = PatternMatcher::from_patterns(patterns);
Self {
matcher,
targets,
rewrite_rules,
empty_wires,
}
}

Expand Down Expand Up @@ -151,7 +175,11 @@ impl Rewriter for ECCRewriter {
.flat_map(|m| {
let pattern_id = m.pattern_id();
self.get_targets(pattern_id).map(move |repl| {
m.to_rewrite(circ.base_hugr(), repl.clone())
let mut repl = repl.clone();
for &empty_qb in self.empty_wires[pattern_id.0].iter().rev() {
remove_empty_wire(&mut repl, empty_qb).unwrap();
}
m.to_rewrite(circ.base_hugr(), repl)
.expect("invalid replacement")
})
})
Expand Down Expand Up @@ -198,11 +226,43 @@ fn get_rewrite_rules(rep_sets: &[EqCircClass]) -> Vec<Vec<TargetID>> {
rewrite_rules
}

fn get_patterns(rep_sets: &[EqCircClass]) -> Vec<Option<CircuitPattern>> {
/// For an equivalence class, return all valid patterns together with the
/// indices of the wires that have been removed in the pattern circuit.
fn get_patterns(rep_sets: &[EqCircClass]) -> Vec<Option<(CircuitPattern, Vec<usize>)>> {
rep_sets
.iter()
.flat_map(|rs| rs.circuits())
.map(|circ| CircuitPattern::try_from_circuit(circ).ok())
.map(|circ| {
let empty_qbs = empty_wires(circ);
let mut circ = circ.clone();
for &qb in empty_qbs.iter().rev() {
remove_empty_wire(&mut circ, qb).unwrap();
}
CircuitPattern::try_from_circuit(&circ)
.ok()
.map(|circ| (circ, empty_qbs))
})
.collect()
}

/// The port offsets of wires that are empty.
fn empty_wires(circ: &impl Circuit) -> Vec<usize> {
let inp = circ.input();
circ.node_outputs(inp)
// Only consider dataflow edges
.filter(|&p| circ.get_optype(inp).signature().get(p).is_some())
// Only consider ports linked to at most one other port
.filter_map(|p| Some((p, circ.linked_ports(inp, p).at_most_one().ok()?)))
// Ports are either connected to output or nothing
.filter_map(|(from, to)| {
if let Some((n, _)) = to {
// Wires connected to output
(n == circ.output()).then_some(from.index())
} else {
// Wires connected to nothing
Some(from.index())
}
})
.collect()
}

Expand Down Expand Up @@ -305,4 +365,15 @@ mod tests {
let exp_n_eccs_of_len = [0, 5 * 2 + 5 * 3, 5, 5];
assert_eq!(n_eccs_of_len, exp_n_eccs_of_len);
}

/// Some inputs are left untouched: these parameters should be removed to
/// obtain convex patterns
#[test]
fn ecc_rewriter_empty_params() {
let test_file = "test_files/cx_cx_eccs.json";
let rewriter = ECCRewriter::try_from_eccs_json_file(test_file).unwrap();

let cx_cx = cx_cx();
assert_eq!(rewriter.get_rewrites(&cx_cx).len(), 1);
}
}
8 changes: 8 additions & 0 deletions test_files/cx_cx_eccs.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
[[],
{
"7779_2": [
[[3,0,0,0,["2acba3946000"],[7.23975983031379111e-02,-6.01244148396314765e-02]],[]]
,[[3,0,0,2,["2acba3946000"],[7.23975983031379111e-02,-6.01244148396314765e-02]],[["cx", ["Q0", "Q1"],["Q0", "Q1"]],["cx", ["Q0", "Q1"],["Q0", "Q1"]]]]
]
}
]

0 comments on commit e5d3fe7

Please sign in to comment.