Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main' into feat/taso-split
Browse files Browse the repository at this point in the history
  • Loading branch information
aborgna-q committed Oct 2, 2023
2 parents e076c80 + 742e64c commit 9a40ac9
Show file tree
Hide file tree
Showing 5 changed files with 7,685 additions and 23 deletions.
10 changes: 3 additions & 7 deletions src/json/encoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,16 +77,12 @@ impl JsonEncoder {
// Uses the names from the metadata if available, or initializes new sequentially-numbered registers.
let mut bit_to_reg = HashMap::new();
let mut qubit_to_reg = HashMap::new();
let get_register = |registers: &mut Vec<Register>, prefix: &str, index| match registers
.get(index)
.cloned()
{
Some(r) => r,
None => {
let get_register = |registers: &mut Vec<Register>, prefix: &str, index| {
registers.get(index).cloned().unwrap_or_else(|| {
let r = Register(prefix.to_string(), vec![index as i64]);
registers.push(r.clone());
r
}
})
};
for (unit, _, ty) in circ.units() {
if ty == QB_T {
Expand Down
21 changes: 16 additions & 5 deletions src/json/tests.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
//! General tests.
use std::collections::HashSet;
use std::io::BufReader;

use hugr::Hugr;
use rstest::rstest;
Expand Down Expand Up @@ -64,16 +64,27 @@ fn json_roundtrip(#[case] circ_s: &str, #[case] num_commands: usize, #[case] num
compare_serial_circs(&ser, &reser);
}

#[rstest]
#[cfg_attr(miri, ignore)] // Opening files is not supported in (isolated) miri
#[case::barenco_tof_10("test_files/barenco_tof_10.json")]
fn json_file_roundtrip(#[case] circ: impl AsRef<std::path::Path>) {
let reader = BufReader::new(std::fs::File::open(circ).unwrap());
let ser: circuit_json::SerialCircuit = serde_json::from_reader(reader).unwrap();
let circ: Hugr = ser.clone().decode().unwrap();
let reser: SerialCircuit = SerialCircuit::encode(&circ).unwrap();
compare_serial_circs(&ser, &reser);
}

fn compare_serial_circs(a: &SerialCircuit, b: &SerialCircuit) {
assert_eq!(a.name, b.name);
assert_eq!(a.phase, b.phase);

let qubits_a: HashSet<_> = a.qubits.iter().collect();
let qubits_b: HashSet<_> = b.qubits.iter().collect();
let qubits_a: Vec<_> = a.qubits.iter().collect();
let qubits_b: Vec<_> = b.qubits.iter().collect();
assert_eq!(qubits_a, qubits_b);

let bits_a: HashSet<_> = a.bits.iter().collect();
let bits_b: HashSet<_> = b.bits.iter().collect();
let bits_a: Vec<_> = a.bits.iter().collect();
let bits_b: Vec<_> = b.bits.iter().collect();
assert_eq!(bits_a, bits_b);

assert_eq!(a.implicit_permutation, b.implicit_permutation);
Expand Down
93 changes: 82 additions & 11 deletions src/rewrite/ecc_rewriter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,22 @@
//! of the Quartz repository.
use derive_more::{From, Into};
use hugr::hugr::PortIndex;
use hugr::ops::OpTrait;
use itertools::Itertools;
use portmatching::PatternID;
use std::fs::File;
use std::path::Path;
use std::{io, path::PathBuf};
use std::{
collections::HashSet,
fs::File,
io,
path::{Path, PathBuf},
};
use thiserror::Error;

use hugr::Hugr;

use crate::{
circuit::Circuit,
circuit::{remove_empty_wire, Circuit},
optimiser::taso::{load_eccs_json_file, EqCircClass},
portmatching::{CircuitPattern, PatternMatcher},
};
Expand All @@ -49,6 +54,9 @@ pub struct ECCRewriter {
/// target TargetIDs. The usize index of PatternID is used to index into
/// the outer vector.
rewrite_rules: Vec<Vec<TargetID>>,
/// Wires that have been removed in the pattern circuit -- to be removed
/// in the target circuit as well when generating a rewrite.
empty_wires: Vec<Vec<usize>>,
}

impl ECCRewriter {
Expand All @@ -72,18 +80,34 @@ impl ECCRewriter {
let eccs = eccs.into();
let rewrite_rules = get_rewrite_rules(&eccs);
let patterns = get_patterns(&eccs);
let targets = into_targets(eccs);
// Remove failed patterns
let (patterns, rewrite_rules): (Vec<_>, Vec<_>) = patterns
let (patterns, empty_wires, rewrite_rules): (Vec<_>, Vec<_>, Vec<_>) = patterns
.into_iter()
.zip(rewrite_rules)
.filter_map(|(p, r)| Some((p?, r)))
.unzip();
let targets = into_targets(eccs);
.filter_map(|(p, r)| {
// Filter out target IDs where empty wires are not empty
let (pattern, pattern_empty_wires) = p?;
let targets = r
.into_iter()
.filter(|&id| {
let circ = &targets[id.0];
let target_empty_wires: HashSet<_> =
empty_wires(&circ).into_iter().collect();
pattern_empty_wires
.iter()
.all(|&w| target_empty_wires.contains(&w))
})
.collect();
Some((pattern, pattern_empty_wires, targets))
})
.multiunzip();
let matcher = PatternMatcher::from_patterns(patterns);
Self {
matcher,
targets,
rewrite_rules,
empty_wires,
}
}

Expand Down Expand Up @@ -151,7 +175,11 @@ impl Rewriter for ECCRewriter {
.flat_map(|m| {
let pattern_id = m.pattern_id();
self.get_targets(pattern_id).map(move |repl| {
m.to_rewrite(circ.base_hugr(), repl.clone())
let mut repl = repl.clone();
for &empty_qb in self.empty_wires[pattern_id.0].iter().rev() {
remove_empty_wire(&mut repl, empty_qb).unwrap();
}
m.to_rewrite(circ.base_hugr(), repl)
.expect("invalid replacement")
})
})
Expand Down Expand Up @@ -198,11 +226,43 @@ fn get_rewrite_rules(rep_sets: &[EqCircClass]) -> Vec<Vec<TargetID>> {
rewrite_rules
}

fn get_patterns(rep_sets: &[EqCircClass]) -> Vec<Option<CircuitPattern>> {
/// For an equivalence class, return all valid patterns together with the
/// indices of the wires that have been removed in the pattern circuit.
fn get_patterns(rep_sets: &[EqCircClass]) -> Vec<Option<(CircuitPattern, Vec<usize>)>> {
rep_sets
.iter()
.flat_map(|rs| rs.circuits())
.map(|circ| CircuitPattern::try_from_circuit(circ).ok())
.map(|circ| {
let empty_qbs = empty_wires(circ);
let mut circ = circ.clone();
for &qb in empty_qbs.iter().rev() {
remove_empty_wire(&mut circ, qb).unwrap();
}
CircuitPattern::try_from_circuit(&circ)
.ok()
.map(|circ| (circ, empty_qbs))
})
.collect()
}

/// The port offsets of wires that are empty.
fn empty_wires(circ: &impl Circuit) -> Vec<usize> {
let inp = circ.input();
circ.node_outputs(inp)
// Only consider dataflow edges
.filter(|&p| circ.get_optype(inp).signature().get(p).is_some())
// Only consider ports linked to at most one other port
.filter_map(|p| Some((p, circ.linked_ports(inp, p).at_most_one().ok()?)))
// Ports are either connected to output or nothing
.filter_map(|(from, to)| {
if let Some((n, _)) = to {
// Wires connected to output
(n == circ.output()).then_some(from.index())
} else {
// Wires connected to nothing
Some(from.index())
}
})
.collect()
}

Expand Down Expand Up @@ -305,4 +365,15 @@ mod tests {
let exp_n_eccs_of_len = [0, 4 * 2 + 5 * 3, 4, 5];
assert_eq!(n_eccs_of_len, exp_n_eccs_of_len);
}

/// Some inputs are left untouched: these parameters should be removed to
/// obtain convex patterns
#[test]
fn ecc_rewriter_empty_params() {
let test_file = "test_files/cx_cx_eccs.json";
let rewriter = ECCRewriter::try_from_eccs_json_file(test_file).unwrap();

let cx_cx = cx_cx();
assert_eq!(rewriter.get_rewrites(&cx_cx).len(), 1);
}
}
Loading

0 comments on commit 9a40ac9

Please sign in to comment.