Skip to content

Commit

Permalink
Merge branch 'main' into perf/taso-table-bound
Browse files Browse the repository at this point in the history
# Conflicts:
#	src/optimiser/taso.rs
#	src/rewrite/strategy.rs
  • Loading branch information
lmondada committed Oct 3, 2023
2 parents e628060 + 0dcf5e0 commit 283c0c0
Show file tree
Hide file tree
Showing 10 changed files with 7,834 additions and 131 deletions.
6 changes: 6 additions & 0 deletions src/json.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,14 @@ use self::encoder::JsonEncoder;

/// Prefix used for storing metadata in the hugr nodes.
pub const METADATA_PREFIX: &str = "TKET1_JSON";
/// The global phase specified as metadata.
const METADATA_PHASE: &str = "TKET1_JSON.phase";
/// The implicit permutation of qubits.
const METADATA_IMPLICIT_PERM: &str = "TKET1_JSON.implicit_permutation";
/// Explicit names for the input qubit registers.
const METADATA_Q_REGISTERS: &str = "TKET1_JSON.qubit_registers";
/// Explicit names for the input bit registers.
const METADATA_B_REGISTERS: &str = "TKET1_JSON.bit_registers";

/// A JSON-serialized circuit that can be converted to a [`Hugr`].
pub trait TKETDecode: Sized {
Expand Down
3 changes: 3 additions & 0 deletions src/json/decoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ use tket_json_rs::circuit_json::SerialCircuit;
use super::op::JsonOp;
use super::{try_param_to_constant, METADATA_IMPLICIT_PERM, METADATA_PHASE};
use crate::extension::{LINEAR_BIT, REGISTRY};
use crate::json::{METADATA_B_REGISTERS, METADATA_Q_REGISTERS};
use crate::symbolic_constant_op;

/// The state of an in-progress [`DFGBuilder`] being built from a [`SerialCircuit`].
Expand Down Expand Up @@ -76,6 +77,8 @@ impl JsonDecoder {
"name": serialcirc.name,
METADATA_PHASE: serialcirc.phase,
METADATA_IMPLICIT_PERM: serialcirc.implicit_permutation,
METADATA_Q_REGISTERS: serialcirc.qubits,
METADATA_B_REGISTERS: serialcirc.bits,
});
dfg.set_metadata(metadata);

Expand Down
103 changes: 63 additions & 40 deletions src/json/encoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@ use crate::extension::LINEAR_BIT;
use crate::ops::match_symb_const_op;

use super::op::JsonOp;
use super::{OpConvertError, METADATA_IMPLICIT_PERM, METADATA_PHASE};
use super::{
OpConvertError, METADATA_B_REGISTERS, METADATA_IMPLICIT_PERM, METADATA_PHASE,
METADATA_Q_REGISTERS,
};

/// The state of an in-progress [`SerialCircuit`] being built from a [`Circuit`].
#[derive(Debug, PartialEq)]
Expand All @@ -30,9 +33,13 @@ pub(super) struct JsonEncoder {
/// The current commands
commands: Vec<circuit_json::Command>,
/// The TKET1 qubit registers associated to each qubit unit of the circuit.
qubit_regs: HashMap<CircuitUnit, circuit_json::Register>,
qubit_to_reg: HashMap<CircuitUnit, Register>,
/// The TKET1 bit registers associated to each linear bit unit of the circuit.
bit_regs: HashMap<CircuitUnit, circuit_json::Register>,
bit_to_reg: HashMap<CircuitUnit, Register>,
/// The ordered TKET1 names for the input qubit registers.
qubit_registers: Vec<Register>,
/// The ordered TKET1 names for the input bit registers.
bit_registers: Vec<Register>,
/// A register of wires with constant values, used to recover TKET1
/// parameters.
parameters: HashMap<Wire, String>,
Expand All @@ -43,48 +50,63 @@ impl JsonEncoder {
pub fn new(circ: &impl Circuit) -> Self {
let name = circ.name().map(str::to_string);

// Compute the linear qubit and bit registers. Each one have independent
// indices starting from zero.
//
// TODO Throw an error on non-recognized unit types, or just ignore?
let mut bit_units = HashMap::new();
let mut qubit_units = HashMap::new();
let mut qubit_registers = vec![];
let mut bit_registers = vec![];
let mut phase = "0".to_string();
let mut implicit_permutation = vec![];

// Recover other parameters stored in the metadata
if let Some(meta) = circ.get_metadata(circ.root()).as_object() {
if let Some(p) = meta.get(METADATA_PHASE) {
// TODO: Check for invalid encoded metadata
phase = p.as_str().unwrap().to_string();
}
if let Some(perm) = meta.get(METADATA_IMPLICIT_PERM) {
// TODO: Check for invalid encoded metadata
implicit_permutation = serde_json::from_value(perm.clone()).unwrap();
}
if let Some(q_regs) = meta.get(METADATA_Q_REGISTERS) {
qubit_registers = serde_json::from_value(q_regs.clone()).unwrap();
}
if let Some(b_regs) = meta.get(METADATA_B_REGISTERS) {
bit_registers = serde_json::from_value(b_regs.clone()).unwrap();
}
}

// Map the Hugr units to tket1 register names.
// Uses the names from the metadata if available, or initializes new sequentially-numbered registers.
let mut bit_to_reg = HashMap::new();
let mut qubit_to_reg = HashMap::new();
let get_register = |registers: &mut Vec<Register>, prefix: &str, index| {
registers.get(index).cloned().unwrap_or_else(|| {
let r = Register(prefix.to_string(), vec![index as i64]);
registers.push(r.clone());
r
})
};
for (unit, _, ty) in circ.units() {
if ty == QB_T {
let index = vec![qubit_units.len() as i64];
let reg = circuit_json::Register("q".to_string(), index);
qubit_units.insert(unit, reg);
let index = qubit_to_reg.len();
let reg = get_register(&mut qubit_registers, "q", index);
qubit_to_reg.insert(unit, reg);
} else if ty == *LINEAR_BIT {
let index = vec![bit_units.len() as i64];
let reg = circuit_json::Register("c".to_string(), index);
bit_units.insert(unit, reg);
let index = bit_to_reg.len();
let reg = get_register(&mut bit_registers, "b", index);
bit_to_reg.insert(unit, reg.clone());
}
}

let mut encoder = Self {
Self {
name,
phase: "0".to_string(),
implicit_permutation: vec![],
phase,
implicit_permutation,
commands: vec![],
qubit_regs: qubit_units,
bit_regs: bit_units,
qubit_to_reg,
bit_to_reg,
qubit_registers,
bit_registers,
parameters: HashMap::new(),
};

// Encode other parameters stored in the metadata
if let Some(meta) = circ.get_metadata(circ.root()).as_object() {
if let Some(phase) = meta.get(METADATA_PHASE) {
// TODO: Check for invalid encoded metadata
encoder.phase = phase.as_str().unwrap().to_string();
}
if let Some(implicit_perm) = meta.get(METADATA_IMPLICIT_PERM) {
// TODO: Check for invalid encoded metadata
encoder.implicit_permutation =
serde_json::from_value(implicit_perm.clone()).unwrap();
}
}

encoder
}

/// Add a circuit command to the serialization.
Expand Down Expand Up @@ -139,8 +161,8 @@ impl JsonEncoder {
name: self.name,
phase: self.phase,
commands: self.commands,
qubits: self.qubit_regs.into_values().collect_vec(),
bits: self.bit_regs.into_values().collect_vec(),
qubits: self.qubit_registers,
bits: self.bit_registers,
implicit_permutation: self.implicit_permutation,
}
}
Expand Down Expand Up @@ -208,10 +230,11 @@ impl JsonEncoder {
true
}

fn unit_to_register(&self, unit: CircuitUnit) -> Option<circuit_json::Register> {
self.qubit_regs
/// Translate a linear [`CircuitUnit`] into a [`Register`], if possible.
fn unit_to_register(&self, unit: CircuitUnit) -> Option<Register> {
self.qubit_to_reg
.get(&unit)
.or_else(|| self.bit_regs.get(&unit))
.or_else(|| self.bit_to_reg.get(&unit))
.cloned()
}
}
21 changes: 16 additions & 5 deletions src/json/tests.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
//! General tests.
use std::collections::HashSet;
use std::io::BufReader;

use hugr::Hugr;
use rstest::rstest;
Expand Down Expand Up @@ -64,16 +64,27 @@ fn json_roundtrip(#[case] circ_s: &str, #[case] num_commands: usize, #[case] num
compare_serial_circs(&ser, &reser);
}

#[rstest]
#[cfg_attr(miri, ignore)] // Opening files is not supported in (isolated) miri
#[case::barenco_tof_10("test_files/barenco_tof_10.json")]
fn json_file_roundtrip(#[case] circ: impl AsRef<std::path::Path>) {
let reader = BufReader::new(std::fs::File::open(circ).unwrap());
let ser: circuit_json::SerialCircuit = serde_json::from_reader(reader).unwrap();
let circ: Hugr = ser.clone().decode().unwrap();
let reser: SerialCircuit = SerialCircuit::encode(&circ).unwrap();
compare_serial_circs(&ser, &reser);
}

fn compare_serial_circs(a: &SerialCircuit, b: &SerialCircuit) {
assert_eq!(a.name, b.name);
assert_eq!(a.phase, b.phase);

let qubits_a: HashSet<_> = a.qubits.iter().collect();
let qubits_b: HashSet<_> = b.qubits.iter().collect();
let qubits_a: Vec<_> = a.qubits.iter().collect();
let qubits_b: Vec<_> = b.qubits.iter().collect();
assert_eq!(qubits_a, qubits_b);

let bits_a: HashSet<_> = a.bits.iter().collect();
let bits_b: HashSet<_> = b.bits.iter().collect();
let bits_a: Vec<_> = a.bits.iter().collect();
let bits_b: Vec<_> = b.bits.iter().collect();
assert_eq!(bits_a, bits_b);

assert_eq!(a.implicit_permutation, b.implicit_permutation);
Expand Down
13 changes: 9 additions & 4 deletions src/optimiser/taso.rs
Original file line number Diff line number Diff line change
Expand Up @@ -284,20 +284,25 @@ mod taso_default {
use std::io;
use std::path::Path;

use hugr::ops::OpType;

use crate::rewrite::ecc_rewriter::RewriterSerialisationError;
use crate::rewrite::strategy::NonIncreasingCXCountStrategy;
use crate::rewrite::strategy::NonIncreasingGateCountStrategy;
use crate::rewrite::ECCRewriter;

use super::*;

/// The default TASO optimiser using ECC sets.
pub type DefaultTasoOptimiser = TasoOptimiser<ECCRewriter, NonIncreasingCXCountStrategy>;
pub type DefaultTasoOptimiser = TasoOptimiser<
ECCRewriter,
NonIncreasingGateCountStrategy<fn(&OpType) -> usize, fn(&OpType) -> usize>,
>;

impl DefaultTasoOptimiser {
/// A sane default optimiser using the given ECC sets.
pub fn default_with_eccs_json_file(eccs_path: impl AsRef<Path>) -> io::Result<Self> {
let rewriter = ECCRewriter::try_from_eccs_json_file(eccs_path)?;
let strategy = NonIncreasingCXCountStrategy::default_cx();
let strategy = NonIncreasingGateCountStrategy::default_cx();
Ok(TasoOptimiser::new(rewriter, strategy))
}

Expand All @@ -306,7 +311,7 @@ mod taso_default {
rewriter_path: impl AsRef<Path>,
) -> Result<Self, RewriterSerialisationError> {
let rewriter = ECCRewriter::load_binary(rewriter_path)?;
let strategy = NonIncreasingCXCountStrategy::default_cx();
let strategy = NonIncreasingGateCountStrategy::default_cx();
Ok(TasoOptimiser::new(rewriter, strategy))
}
}
Expand Down
2 changes: 2 additions & 0 deletions src/passes/chunks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,8 @@ impl CircuitChunks {
}
}

reassembled.set_metadata(root, self.root_meta)?;

Ok(reassembled)
}

Expand Down
Loading

0 comments on commit 283c0c0

Please sign in to comment.