Skip to content

Commit

Permalink
fix: Commands iterator ignoring the hierarchy. (#381)
Browse files Browse the repository at this point in the history
The commands iterator did a toposort on the whole Hugr, ignoring the
hierarchy.
This generated problems with #370, since circuits may now be nested
somewhere in the hugr.
It would also have caused problems with circuit boxes, but we don't
support that yet here.

We use a `SiblingGraph` now, to ensure that we only explore the
top-level region of the circuit.
Subcircuits will be returned as a single command.

Adds a `build_module_with_circuit` helper to build circuits inside
modules.

Closes #42.
  • Loading branch information
aborgna-q authored Jun 4, 2024
1 parent 3c7684f commit 50ee0fa
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 20 deletions.
72 changes: 58 additions & 14 deletions tket2/src/circuit/command.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use std::collections::hash_map::Entry;
use std::collections::{HashMap, HashSet};
use std::iter::FusedIterator;

use hugr::hugr::views::{HierarchyView, SiblingGraph};
use hugr::hugr::NodeType;
use hugr::ops::{OpTag, OpTrait};
use hugr::{HugrView, IncomingPort, OutgoingPort};
Expand Down Expand Up @@ -237,12 +238,14 @@ type NodeWalker = pv::Topo<Node, HashSet<Node>>;
pub struct CommandIterator<'circ, T> {
/// The circuit.
circ: &'circ Circuit<T>,
/// A view of the top-level region of the circuit.
region: SiblingGraph<'circ>,
/// Toposorted nodes.
nodes: NodeWalker,
/// Last wire for each [`LinearUnit`] in the circuit.
wire_unit: HashMap<Wire, usize>,
/// Remaining commands, not counting I/O nodes.
remaining: usize,
/// Maximum number of remaining commands, not counting I/O nodes nor root nodes.
max_remaining: usize,
/// Delayed output of constant and load const nodes. Contains nodes that
/// haven't been yielded yet.
///
Expand Down Expand Up @@ -275,13 +278,16 @@ impl<'circ, T: HugrView> CommandIterator<'circ, T> {
.map(|(linear_unit, port, _)| (Wire::new(circ.input_node(), port), linear_unit.index()))
.collect();

let nodes = pv::Topo::new(&circ.hugr().as_petgraph());
let region: SiblingGraph = SiblingGraph::try_new(circ.hugr(), circ.parent()).unwrap();
let node_count = region.node_count();
let nodes = pv::Topo::new(&region.as_petgraph());
Self {
circ,
region,
nodes,
wire_unit,
// Ignore the input and output nodes, and the root.
remaining: circ.hugr().node_count() - 3,
max_remaining: node_count - 3,
delayed_consts: HashSet::new(),
delayed_consumers: HashMap::new(),
delayed_node: None,
Expand All @@ -296,7 +302,12 @@ impl<'circ, T: HugrView> CommandIterator<'circ, T> {
let node = self
.delayed_node
.take()
.or_else(|| self.nodes.next(&self.circ.hugr().as_petgraph()))?;
.or_else(|| self.nodes.next(&self.region.as_petgraph()))?;
if node == self.circ.parent() {
// Ignore the root of the circuit.
// This will only happen once.
return self.next_node();
}

// If this node is a constant or load const node, delay it.
let tag = self.circ.hugr().get_optype(node).tag();
Expand Down Expand Up @@ -432,7 +443,7 @@ impl<'circ, T: HugrView> Iterator for CommandIterator<'circ, T> {
let node = self.next_node()?;
// Process the node, returning a command if it's not an input or output.
if let Some((input_linear_units, output_linear_units)) = self.process_node(node) {
self.remaining -= 1;
self.max_remaining -= 1;
return Some(Command {
circ: self.circ,
node,
Expand All @@ -445,7 +456,7 @@ impl<'circ, T: HugrView> Iterator for CommandIterator<'circ, T> {

#[inline]
fn size_hint(&self) -> (usize, Option<usize>) {
(self.remaining, Some(self.remaining))
(0, Some(self.max_remaining))
}
}

Expand All @@ -456,7 +467,7 @@ impl<'circ, T: HugrView> std::fmt::Debug for CommandIterator<'circ, T> {
f.debug_struct("CommandIterator")
.field("circuit name", &self.circ.name())
.field("wire_unit", &self.wire_unit)
.field("remaining", &self.remaining)
.field("max_remaining", &self.max_remaining)
.finish()
}
}
Expand All @@ -465,17 +476,19 @@ impl<'circ, T: HugrView> std::fmt::Debug for CommandIterator<'circ, T> {
mod test {
use hugr::builder::{Container, DFGBuilder, Dataflow, DataflowHugr};
use hugr::extension::prelude::QB_T;
use hugr::hugr::hugrmut::HugrMut;
use hugr::ops::handle::NodeHandle;
use hugr::ops::{NamedOp, Value};
use hugr::std_extensions::arithmetic::float_ops::FLOAT_OPS_REGISTRY;
use hugr::std_extensions::arithmetic::float_types::ConstF64;
use hugr::types::FunctionType;
use itertools::Itertools;
use rstest::{fixture, rstest};
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};

use crate::extension::REGISTRY;
use crate::utils::build_simple_circuit;
use crate::utils::{build_module_with_circuit, build_simple_circuit};
use crate::Tk2Op;

use super::*;
Expand All @@ -487,22 +500,53 @@ mod test {
};
}

#[test]
fn iterate_commands() {
let circ = build_simple_circuit(2, |circ| {
/// 2-qubit circuit with a Hadamard, a CNOT, and a T gate.
#[fixture]
fn simple_circuit() -> Circuit {
build_simple_circuit(2, |circ| {
circ.append(Tk2Op::H, [0])?;
circ.append(Tk2Op::CX, [0, 1])?;
circ.append(Tk2Op::T, [1])?;
Ok(())
})
.unwrap()
}

/// 2-qubit circuit with a Hadamard, a CNOT, and a T gate,
/// defined inside a module.
#[fixture]
fn simple_module() -> Circuit {
build_module_with_circuit(2, |circ| {
circ.append(Tk2Op::H, [0])?;
circ.append(Tk2Op::CX, [0, 1])?;
circ.append(Tk2Op::T, [1])?;
Ok(())
})
.unwrap();
.unwrap()
}

/// 2-qubit circuit with a Hadamard, a CNOT, and a T gate,
/// defined inside a module containing other circuits.
#[fixture]
fn module_with_circuits() -> Circuit {
let mut module = simple_module();
let other_circ = simple_circuit();
let hugr = module.hugr_mut();
hugr.insert_hugr(hugr.root(), other_circ.into_hugr());
return module;
}

#[rstest]
#[case::dfg_rooted(simple_circuit())]
#[case::module_rooted(simple_module())]
#[case::complex_module_rooted(module_with_circuits())]
fn iterate_commands_simple(#[case] circ: Circuit) {
assert_eq!(CommandIterator::new(&circ).count(), 3);

let tk2op_name = |op: Tk2Op| op.exposed_name();

let mut commands = CommandIterator::new(&circ);
assert_eq!(commands.size_hint(), (3, Some(3)));
assert_eq!(commands.size_hint(), (0, Some(3)));

let hadamard = commands.next().unwrap();
assert_eq!(hadamard.optype().name().as_str(), tk2op_name(Tk2Op::H));
Expand Down
34 changes: 28 additions & 6 deletions tket2/src/utils.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
//! Utility functions for the library.
use hugr::builder::{Container, DataflowSubContainer, FunctionBuilder, HugrBuilder, ModuleBuilder};
use hugr::extension::PRELUDE_REGISTRY;
use hugr::ops::handle::NodeHandle;
use hugr::types::{Type, TypeBound};
use hugr::Hugr;
use hugr::{
builder::{BuildError, CircuitBuilder, DFGBuilder, Dataflow, DataflowHugr},
extension::prelude::QB_T,
types::FunctionType,
Hugr,
};

use crate::circuit::Circuit;
Expand All @@ -15,12 +17,12 @@ pub(crate) fn type_is_linear(typ: &Type) -> bool {
!TypeBound::Copyable.contains(typ.least_upper_bound())
}

// utility for building simple qubit-only circuits.
/// Utility for building simple qubit-only circuits.
#[allow(unused)]
pub(crate) fn build_simple_circuit(
num_qubits: usize,
f: impl FnOnce(&mut CircuitBuilder<DFGBuilder<Hugr>>) -> Result<(), BuildError>,
) -> Result<Circuit, BuildError> {
pub(crate) fn build_simple_circuit<F>(num_qubits: usize, f: F) -> Result<Circuit, BuildError>
where
F: FnOnce(&mut CircuitBuilder<DFGBuilder<Hugr>>) -> Result<(), BuildError>,
{
let qb_row = vec![QB_T; num_qubits];
let mut h = DFGBuilder::new(FunctionType::new(qb_row.clone(), qb_row))?;

Expand All @@ -35,6 +37,26 @@ pub(crate) fn build_simple_circuit(
Ok(hugr.into())
}

/// Utility for building a module with a single circuit definition.
#[allow(unused)]
pub(crate) fn build_module_with_circuit<F>(num_qubits: usize, f: F) -> Result<Circuit, BuildError>
where
F: FnOnce(&mut CircuitBuilder<FunctionBuilder<&mut Hugr>>) -> Result<(), BuildError>,
{
let mut builder = ModuleBuilder::new();
let circ = {
let qb_row = vec![QB_T; num_qubits];
let circ_signature = FunctionType::new(qb_row.clone(), qb_row);
let mut dfg = builder.define_function("main", circ_signature.into())?;
let mut circ = dfg.as_circuit(dfg.input_wires());
f(&mut circ)?;
let qbs = circ.finish();
dfg.finish_with_outputs(qbs)?
};
let hugr = builder.finish_hugr(&PRELUDE_REGISTRY)?;
Ok(Circuit::new(hugr, circ.node()))
}

// Test only utils
#[allow(dead_code)]
#[allow(unused_imports)]
Expand Down

0 comments on commit 50ee0fa

Please sign in to comment.