diff --git a/tket2/src/circuit/command.rs b/tket2/src/circuit/command.rs index 09af12a1..50aa5815 100644 --- a/tket2/src/circuit/command.rs +++ b/tket2/src/circuit/command.rs @@ -6,6 +6,7 @@ use std::collections::hash_map::Entry; use std::collections::{HashMap, HashSet}; use std::iter::FusedIterator; +use hugr::hugr::views::{HierarchyView, SiblingGraph}; use hugr::hugr::NodeType; use hugr::ops::{OpTag, OpTrait}; use hugr::{HugrView, IncomingPort, OutgoingPort}; @@ -237,12 +238,14 @@ type NodeWalker = pv::Topo>; pub struct CommandIterator<'circ, T> { /// The circuit. circ: &'circ Circuit, + /// A view of the top-level region of the circuit. + region: SiblingGraph<'circ>, /// Toposorted nodes. nodes: NodeWalker, /// Last wire for each [`LinearUnit`] in the circuit. wire_unit: HashMap, - /// Remaining commands, not counting I/O nodes. - remaining: usize, + /// Maximum number of remaining commands, not counting I/O nodes nor root nodes. + max_remaining: usize, /// Delayed output of constant and load const nodes. Contains nodes that /// haven't been yielded yet. /// @@ -275,13 +278,16 @@ impl<'circ, T: HugrView> CommandIterator<'circ, T> { .map(|(linear_unit, port, _)| (Wire::new(circ.input_node(), port), linear_unit.index())) .collect(); - let nodes = pv::Topo::new(&circ.hugr().as_petgraph()); + let region: SiblingGraph = SiblingGraph::try_new(circ.hugr(), circ.parent()).unwrap(); + let node_count = region.node_count(); + let nodes = pv::Topo::new(®ion.as_petgraph()); Self { circ, + region, nodes, wire_unit, // Ignore the input and output nodes, and the root. - remaining: circ.hugr().node_count() - 3, + max_remaining: node_count - 3, delayed_consts: HashSet::new(), delayed_consumers: HashMap::new(), delayed_node: None, @@ -296,7 +302,12 @@ impl<'circ, T: HugrView> CommandIterator<'circ, T> { let node = self .delayed_node .take() - .or_else(|| self.nodes.next(&self.circ.hugr().as_petgraph()))?; + .or_else(|| self.nodes.next(&self.region.as_petgraph()))?; + if node == self.circ.parent() { + // Ignore the root of the circuit. + // This will only happen once. + return self.next_node(); + } // If this node is a constant or load const node, delay it. let tag = self.circ.hugr().get_optype(node).tag(); @@ -432,7 +443,7 @@ impl<'circ, T: HugrView> Iterator for CommandIterator<'circ, T> { let node = self.next_node()?; // Process the node, returning a command if it's not an input or output. if let Some((input_linear_units, output_linear_units)) = self.process_node(node) { - self.remaining -= 1; + self.max_remaining -= 1; return Some(Command { circ: self.circ, node, @@ -445,7 +456,7 @@ impl<'circ, T: HugrView> Iterator for CommandIterator<'circ, T> { #[inline] fn size_hint(&self) -> (usize, Option) { - (self.remaining, Some(self.remaining)) + (0, Some(self.max_remaining)) } } @@ -456,7 +467,7 @@ impl<'circ, T: HugrView> std::fmt::Debug for CommandIterator<'circ, T> { f.debug_struct("CommandIterator") .field("circuit name", &self.circ.name()) .field("wire_unit", &self.wire_unit) - .field("remaining", &self.remaining) + .field("max_remaining", &self.max_remaining) .finish() } } @@ -465,17 +476,19 @@ impl<'circ, T: HugrView> std::fmt::Debug for CommandIterator<'circ, T> { mod test { use hugr::builder::{Container, DFGBuilder, Dataflow, DataflowHugr}; use hugr::extension::prelude::QB_T; + use hugr::hugr::hugrmut::HugrMut; use hugr::ops::handle::NodeHandle; use hugr::ops::{NamedOp, Value}; use hugr::std_extensions::arithmetic::float_ops::FLOAT_OPS_REGISTRY; use hugr::std_extensions::arithmetic::float_types::ConstF64; use hugr::types::FunctionType; use itertools::Itertools; + use rstest::{fixture, rstest}; use std::collections::hash_map::DefaultHasher; use std::hash::{Hash, Hasher}; use crate::extension::REGISTRY; - use crate::utils::build_simple_circuit; + use crate::utils::{build_module_with_circuit, build_simple_circuit}; use crate::Tk2Op; use super::*; @@ -487,22 +500,53 @@ mod test { }; } - #[test] - fn iterate_commands() { - let circ = build_simple_circuit(2, |circ| { + /// 2-qubit circuit with a Hadamard, a CNOT, and a T gate. + #[fixture] + fn simple_circuit() -> Circuit { + build_simple_circuit(2, |circ| { + circ.append(Tk2Op::H, [0])?; + circ.append(Tk2Op::CX, [0, 1])?; + circ.append(Tk2Op::T, [1])?; + Ok(()) + }) + .unwrap() + } + + /// 2-qubit circuit with a Hadamard, a CNOT, and a T gate, + /// defined inside a module. + #[fixture] + fn simple_module() -> Circuit { + build_module_with_circuit(2, |circ| { circ.append(Tk2Op::H, [0])?; circ.append(Tk2Op::CX, [0, 1])?; circ.append(Tk2Op::T, [1])?; Ok(()) }) - .unwrap(); + .unwrap() + } + + /// 2-qubit circuit with a Hadamard, a CNOT, and a T gate, + /// defined inside a module containing other circuits. + #[fixture] + fn module_with_circuits() -> Circuit { + let mut module = simple_module(); + let other_circ = simple_circuit(); + let hugr = module.hugr_mut(); + hugr.insert_hugr(hugr.root(), other_circ.into_hugr()); + return module; + } + #[rstest] + #[case::dfg_rooted(simple_circuit())] + #[case::module_rooted(simple_module())] + #[case::complex_module_rooted(module_with_circuits())] + fn iterate_commands_simple(#[case] circ: Circuit) { assert_eq!(CommandIterator::new(&circ).count(), 3); let tk2op_name = |op: Tk2Op| op.exposed_name(); let mut commands = CommandIterator::new(&circ); - assert_eq!(commands.size_hint(), (3, Some(3))); + assert_eq!(commands.size_hint(), (0, Some(3))); let hadamard = commands.next().unwrap(); assert_eq!(hadamard.optype().name().as_str(), tk2op_name(Tk2Op::H)); diff --git a/tket2/src/utils.rs b/tket2/src/utils.rs index ec4daa87..727a2f09 100644 --- a/tket2/src/utils.rs +++ b/tket2/src/utils.rs @@ -1,12 +1,14 @@ //! Utility functions for the library. +use hugr::builder::{Container, DataflowSubContainer, FunctionBuilder, HugrBuilder, ModuleBuilder}; use hugr::extension::PRELUDE_REGISTRY; +use hugr::ops::handle::NodeHandle; use hugr::types::{Type, TypeBound}; +use hugr::Hugr; use hugr::{ builder::{BuildError, CircuitBuilder, DFGBuilder, Dataflow, DataflowHugr}, extension::prelude::QB_T, types::FunctionType, - Hugr, }; use crate::circuit::Circuit; @@ -15,12 +17,12 @@ pub(crate) fn type_is_linear(typ: &Type) -> bool { !TypeBound::Copyable.contains(typ.least_upper_bound()) } -// utility for building simple qubit-only circuits. +/// Utility for building simple qubit-only circuits. #[allow(unused)] -pub(crate) fn build_simple_circuit( - num_qubits: usize, - f: impl FnOnce(&mut CircuitBuilder>) -> Result<(), BuildError>, -) -> Result { +pub(crate) fn build_simple_circuit(num_qubits: usize, f: F) -> Result +where + F: FnOnce(&mut CircuitBuilder>) -> Result<(), BuildError>, +{ let qb_row = vec![QB_T; num_qubits]; let mut h = DFGBuilder::new(FunctionType::new(qb_row.clone(), qb_row))?; @@ -35,6 +37,26 @@ pub(crate) fn build_simple_circuit( Ok(hugr.into()) } +/// Utility for building a module with a single circuit definition. +#[allow(unused)] +pub(crate) fn build_module_with_circuit(num_qubits: usize, f: F) -> Result +where + F: FnOnce(&mut CircuitBuilder>) -> Result<(), BuildError>, +{ + let mut builder = ModuleBuilder::new(); + let circ = { + let qb_row = vec![QB_T; num_qubits]; + let circ_signature = FunctionType::new(qb_row.clone(), qb_row); + let mut dfg = builder.define_function("main", circ_signature.into())?; + let mut circ = dfg.as_circuit(dfg.input_wires()); + f(&mut circ)?; + let qbs = circ.finish(); + dfg.finish_with_outputs(qbs)? + }; + let hugr = builder.finish_hugr(&PRELUDE_REGISTRY)?; + Ok(Circuit::new(hugr, circ.node())) +} + // Test only utils #[allow(dead_code)] #[allow(unused_imports)]