From 2ad527c20056667b8a9940c6cad9c93adb72ea19 Mon Sep 17 00:00:00 2001 From: Agustin Borgna Date: Thu, 6 Jun 2024 10:48:47 +0100 Subject: [PATCH] feat: `Circuit::num_operations` only counting actual operations --- tket2/src/circuit.rs | 97 +++++++++++++++++---- tket2/src/lib.rs | 2 +- tket2/src/optimiser/badger/eq_circ_class.rs | 5 +- tket2/src/portmatching/pattern.rs | 2 +- tket2/src/rewrite.rs | 2 +- tket2/src/rewrite/strategy.rs | 8 +- 6 files changed, 91 insertions(+), 25 deletions(-) diff --git a/tket2/src/circuit.rs b/tket2/src/circuit.rs index d6d90576..1ef9695d 100644 --- a/tket2/src/circuit.rs +++ b/tket2/src/circuit.rs @@ -147,14 +147,31 @@ impl Circuit { .expect("Circuit has no I/O nodes") } - /// The number of quantum gates in the circuit. + /// The number of operations in the circuit. + /// + /// This includes [`Tk2Op`]s, pytket ops, and any other custom operations. + /// + /// Nested circuits are traversed to count their operations. + /// + /// [`Tk2Op`]: crate::Tk2Op #[inline] - pub fn num_gates(&self) -> usize + pub fn num_operations(&self) -> usize where Self: Sized, { - // TODO: Discern quantum gates in the commands iterator. - self.hugr().children(self.parent).count() - 2 + let mut count = 0; + let mut roots = vec![self.parent]; + while let Some(node) = roots.pop() { + for child in self.hugr().children(node) { + let optype = self.hugr().get_optype(child); + if optype.is_custom_op() { + count += 1; + } else if OpTag::DataflowParent.is_superset(optype.tag()) { + roots.push(child); + } + } + } + count } /// Count the number of qubits in the circuit. @@ -471,6 +488,7 @@ fn update_signature( #[cfg(test)] mod tests { use cool_asserts::assert_matches; + use rstest::{fixture, rstest}; use hugr::types::FunctionType; use hugr::{ @@ -479,9 +497,11 @@ mod tests { }; use super::*; + use crate::utils::build_module_with_circuit; use crate::{json::load_tk1_json_str, utils::build_simple_circuit, Tk2Op}; - fn test_circuit() -> Circuit { + #[fixture] + fn tk1_circuit() -> Circuit { load_tk1_json_str( r#"{ "phase": "0", "bits": [["c", [0]]], @@ -489,7 +509,7 @@ mod tests { "commands": [ {"args": [["q", [0]]], "op": {"type": "H"}}, {"args": [["q", [0]], ["q", [1]]], "op": {"type": "CX"}}, - {"args": [["q", [1]]], "op": {"type": "X"}} + {"args": [["q", [1]]], "op": {"params": ["0.25"], "type": "Rz"}} ], "implicit_permutation": [[["q", [0]], ["q", [0]]], [["q", [1]], ["q", [1]]]] }"#, @@ -497,20 +517,63 @@ mod tests { .unwrap() } - #[test] - fn test_circuit_properties() { - let circ = test_circuit(); + /// 2-qubit circuit with a Hadamard, a CNOT, and a Rz gate. + #[fixture] + fn simple_circuit() -> Circuit { + build_simple_circuit(2, |circ| { + circ.append(Tk2Op::H, [0])?; + circ.append(Tk2Op::CX, [0, 1])?; + circ.append(Tk2Op::X, [1])?; - assert_eq!(circ.name(), None); - assert_eq!(circ.circuit_signature().body().input_count(), 3); - assert_eq!(circ.circuit_signature().body().output_count(), 3); - assert_eq!(circ.qubit_count(), 2); - assert_eq!(circ.num_gates(), 3); + // TODO: Replace the `X` with the following once Hugr adds `CircuitBuilder::add_constant`. + // See https://github.com/CQCL/hugr/pull/1168 + + //let angle = circ.add_constant(ConstF64::new(0.5)); + //circ.append_and_consume( + // Tk2Op::RzF64, + // [CircuitUnit::Linear(1), CircuitUnit::Wire(angle)], + //)?; + Ok(()) + }) + .unwrap() + } + + /// 2-qubit circuit with a Hadamard, a CNOT, and a Rz gate, + /// defined inside a module. + #[fixture] + fn simple_module() -> Circuit { + build_module_with_circuit(2, |circ| { + circ.append(Tk2Op::H, [0])?; + circ.append(Tk2Op::CX, [0, 1])?; + circ.append(Tk2Op::X, [1])?; + Ok(()) + }) + .unwrap() + } + + #[rstest] + #[case::simple(simple_circuit(), 2, 0, None)] + #[case::module(simple_module(), 2, 0, None)] + #[case::tk1(tk1_circuit(), 2, 1, None)] + fn test_circuit_properties( + #[case] circ: Circuit, + #[case] qubits: usize, + #[case] bits: usize, + #[case] name: Option<&str>, + ) { + assert_eq!(circ.name(), name); + assert_eq!(circ.circuit_signature().body().input_count(), qubits + bits); + assert_eq!( + circ.circuit_signature().body().output_count(), + qubits + bits + ); + assert_eq!(circ.qubit_count(), qubits); + assert_eq!(circ.num_operations(), 3); - assert_eq!(circ.units().count(), 3); + assert_eq!(circ.units().count(), qubits + bits); assert_eq!(circ.nonlinear_units().count(), 0); - assert_eq!(circ.linear_units().count(), 3); - assert_eq!(circ.qubits().count(), 2); + assert_eq!(circ.linear_units().count(), qubits + bits); + assert_eq!(circ.qubits().count(), qubits); } #[test] diff --git a/tket2/src/lib.rs b/tket2/src/lib.rs index 756b52d8..a11de7ea 100644 --- a/tket2/src/lib.rs +++ b/tket2/src/lib.rs @@ -27,7 +27,7 @@ //! let mut circ: Circuit = tket2::json::load_tk1_json_file("../test_files/barenco_tof_5.json").unwrap(); //! //! assert_eq!(circ.qubit_count(), 9); -//! assert_eq!(circ.num_gates(), 170); +//! assert_eq!(circ.num_operations(), 170); //! //! // Traverse the circuit and print the gates. //! for command in circ.commands() { diff --git a/tket2/src/optimiser/badger/eq_circ_class.rs b/tket2/src/optimiser/badger/eq_circ_class.rs index c2c7f118..732f42ee 100644 --- a/tket2/src/optimiser/badger/eq_circ_class.rs +++ b/tket2/src/optimiser/badger/eq_circ_class.rs @@ -74,7 +74,10 @@ impl EqCircClass { }; // Find the index for the smallest circuit - let min_index = circs.iter().position_min_by_key(|c| c.num_gates()).unwrap(); + let min_index = circs + .iter() + .position_min_by_key(|c| c.num_operations()) + .unwrap(); let representative = circs.swap_remove(min_index); Ok(Self::new(representative, circs)) } diff --git a/tket2/src/portmatching/pattern.rs b/tket2/src/portmatching/pattern.rs index 4a460ebd..7b6d2820 100644 --- a/tket2/src/portmatching/pattern.rs +++ b/tket2/src/portmatching/pattern.rs @@ -32,7 +32,7 @@ impl CircuitPattern { /// Construct a pattern from a circuit. pub fn try_from_circuit(circuit: &Circuit) -> Result { let hugr = circuit.hugr(); - if circuit.num_gates() == 0 { + if circuit.num_operations() == 0 { return Err(InvalidPattern::EmptyCircuit); } let mut pattern = Pattern::new(); diff --git a/tket2/src/rewrite.rs b/tket2/src/rewrite.rs index 85eeceee..a2fb220d 100644 --- a/tket2/src/rewrite.rs +++ b/tket2/src/rewrite.rs @@ -84,7 +84,7 @@ impl CircuitRewrite { /// The difference between the new number of nodes minus the old. A positive /// number is an increase in node count, a negative number is a decrease. pub fn node_count_delta(&self) -> isize { - let new_count = self.replacement().num_gates() as isize; + let new_count = self.replacement().num_operations() as isize; let old_count = self.subcircuit().node_count() as isize; new_count - old_count } diff --git a/tket2/src/rewrite/strategy.rs b/tket2/src/rewrite/strategy.rs index 3e44b638..652021c2 100644 --- a/tket2/src/rewrite/strategy.rs +++ b/tket2/src/rewrite/strategy.rs @@ -144,7 +144,7 @@ impl RewriteStrategy for GreedyRewriteStrategy { } fn circuit_cost(&self, circ: &Circuit) -> Self::Cost { - circ.num_gates() + circ.num_operations() } fn op_cost(&self, _op: &OpType) -> Self::Cost { @@ -488,7 +488,7 @@ mod tests { let strategy = GreedyRewriteStrategy; let rewritten = strategy.apply_rewrites(rws, &circ).collect_vec(); assert_eq!(rewritten.len(), 1); - assert_eq!(rewritten[0].circ.num_gates(), 5); + assert_eq!(rewritten[0].circ.num_operations(), 5); if REWRITE_TRACING_ENABLED { assert_eq!(rewritten[0].circ.rewrite_trace().unwrap().len(), 3); @@ -511,7 +511,7 @@ mod tests { let strategy = LexicographicCostFunction::default_cx(); let rewritten = strategy.apply_rewrites(rws, &circ).collect_vec(); let exp_circ_lens = HashSet::from_iter([3, 7, 9]); - let circ_lens: HashSet<_> = rewritten.iter().map(|r| r.circ.num_gates()).collect(); + let circ_lens: HashSet<_> = rewritten.iter().map(|r| r.circ.num_operations()).collect(); assert_eq!(circ_lens, exp_circ_lens); if REWRITE_TRACING_ENABLED { @@ -547,7 +547,7 @@ mod tests { let strategy = GammaStrategyCost::exhaustive_cx_with_gamma(10.); let rewritten = strategy.apply_rewrites(rws, &circ); let exp_circ_lens = HashSet::from_iter([8, 17, 6, 9]); - let circ_lens: HashSet<_> = rewritten.map(|r| r.circ.num_gates()).collect(); + let circ_lens: HashSet<_> = rewritten.map(|r| r.circ.num_operations()).collect(); assert_eq!(circ_lens, exp_circ_lens); }