diff --git a/tket2/src/circuit.rs b/tket2/src/circuit.rs index 92b72f6f..cfe3da4f 100644 --- a/tket2/src/circuit.rs +++ b/tket2/src/circuit.rs @@ -11,6 +11,7 @@ use std::iter::Sum; pub use command::{Command, CommandIterator}; pub use hash::CircuitHash; use hugr::hugr::views::{DescendantsGraph, ExtractHugr, HierarchyView}; +use hugr_core::hugr::internal::HugrMutInternals; use itertools::Either::{Left, Right}; use hugr::hugr::hugrmut::HugrMut; @@ -317,7 +318,11 @@ impl Circuit { } else { let view: DescendantsGraph = DescendantsGraph::try_new(&self.hugr, self.parent) .expect("Circuit parent was not a dataflow container."); - view.extract_hugr().into() + let mut hugr = view.extract_hugr(); + // TODO: Remove this once hugr 0.6.0 gets released. + // https://github.com/CQCL/hugr/pull/1239 + hugr.set_num_ports(hugr.root(), 0, 0); + hugr.into() }; extract_dfg::rewrite_into_dfg(&mut circ)?; Ok(circ) diff --git a/tket2/src/circuit/extract_dfg.rs b/tket2/src/circuit/extract_dfg.rs index d7deb6de..17c10ed2 100644 --- a/tket2/src/circuit/extract_dfg.rs +++ b/tket2/src/circuit/extract_dfg.rs @@ -47,7 +47,7 @@ fn remove_cfg_empty_output_tuple( signature: FunctionType, ) -> Result { let sig = signature; - let parent = circ.parent(); + let input_node = circ.input_node(); let output_node = circ.output_node(); let output_nodetype = circ.hugr.get_nodetype(output_node).clone(); @@ -89,8 +89,8 @@ fn remove_cfg_empty_output_tuple( let new_op = Output { types: new_types.clone().into(), }; - let new_node = hugr.add_node_with_parent( - parent, + let new_node = hugr.add_node_after( + input_node, NodeType::new( new_op, output_nodetype diff --git a/tket2/src/passes/pytket.rs b/tket2/src/passes/pytket.rs index c4dd3ada..9def01fa 100644 --- a/tket2/src/passes/pytket.rs +++ b/tket2/src/passes/pytket.rs @@ -38,3 +38,133 @@ pub enum PytketLoweringError { #[error("Non-local operations found. Function calls are not supported.")] NonLocalOperations, } + +#[cfg(test)] +mod test { + use crate::extension::REGISTRY; + use crate::Tk2Op; + + use super::*; + use hugr::builder::{ + Container, Dataflow, DataflowSubContainer, HugrBuilder, ModuleBuilder, SubContainer, + }; + use hugr::extension::prelude::QB_T; + use hugr::extension::{ExtensionSet, PRELUDE_REGISTRY}; + use hugr::ops::handle::NodeHandle; + use hugr::ops::{MakeTuple, OpType, Tag, UnpackTuple}; + use hugr::types::{FunctionType, TypeRow}; + use hugr::{type_row, HugrView}; + use rstest::{fixture, rstest}; + + /// Builds a circuit in the style of guppy's output. + /// + /// This is composed of a `Module`, containing a `FuncDefn`, containing a + /// `CFG`, containing an `Exit` and a `DataflowBlock` with the actual + /// circuit. + #[fixture] + fn guppy_like_circuit() -> Circuit { + fn build() -> Result { + let two_qbs = type_row![QB_T, QB_T]; + let circ_signature = FunctionType::new_endo(two_qbs.clone()); + let circ; + + let mut builder = ModuleBuilder::new(); + let _func = { + let mut func = builder.define_function("main", circ_signature.into())?; + let [q1, q2] = func.input_wires_arr(); + + let cfg = { + let mut cfg = func.cfg_builder( + [(QB_T, q1), (QB_T, q2)], + None, + two_qbs.clone(), + ExtensionSet::new(), + )?; + + circ = { + let mut dfg = + cfg.simple_entry_builder(two_qbs.clone(), 1, ExtensionSet::new())?; + let [q1, q2] = dfg.input_wires_arr(); + + let [q1] = dfg.add_dataflow_op(Tk2Op::H, [q1])?.outputs_arr(); + let [q1, q2] = dfg.add_dataflow_op(Tk2Op::CX, [q1, q2])?.outputs_arr(); + + let [tup] = dfg + .add_dataflow_op(MakeTuple::new(two_qbs.clone()), [q1, q2])? + .outputs_arr(); + let [q1, q2] = dfg + .add_dataflow_op(UnpackTuple::new(two_qbs), [tup])? + .outputs_arr(); + + // Adds an empty Unit branch. + let [branch] = dfg + .add_dataflow_op(Tag::new(0, vec![TypeRow::new()]), [])? + .outputs_arr(); + + dfg.finish_with_outputs(branch, [q1, q2])? + }; + cfg.branch(&circ, 0, &cfg.exit_block())?; + + cfg.finish_sub_container()? + }; + let [q1, q2] = cfg.outputs_arr(); + + func.finish_with_outputs([q1, q2])? + }; + + let hugr = builder.finish_hugr(&PRELUDE_REGISTRY)?; + Ok(Circuit::new(hugr, circ.node())) + } + build().unwrap() + } + + #[rstest] + #[case::guppy_like_circuit(guppy_like_circuit())] + fn test_pytket_lowering(#[case] circ: Circuit) { + use cool_asserts::assert_matches; + + let lowered_circ = lower_to_pytket(&circ).unwrap(); + lowered_circ.hugr().validate(®ISTRY).unwrap(); + + assert_eq!(lowered_circ.parent(), lowered_circ.hugr().root()); + assert_matches!( + lowered_circ.hugr().get_optype(lowered_circ.parent()), + OpType::DFG(_) + ); + assert_matches!( + lowered_circ.hugr().get_optype(lowered_circ.input_node()), + OpType::Input(_) + ); + assert_matches!( + lowered_circ.hugr().get_optype(lowered_circ.output_node()), + OpType::Output(_) + ); + assert_eq!(lowered_circ.num_operations(), circ.num_operations()); + + // Check that the circuit signature is preserved. + let original_sig = circ.circuit_signature(); + let lowered_sig = lowered_circ.circuit_signature(); + assert_eq!(lowered_sig.input(), original_sig.input()); + + // The output signature may have changed due CFG branch tag removal. + let output_count_diff = + original_sig.output().len() as isize - lowered_sig.output().len() as isize; + assert!( + output_count_diff == 0 || output_count_diff == 1, + "Output count mismatch. Original: {}, Lowered: {}", + original_sig, + lowered_sig + ); + assert_eq!( + lowered_sig.output()[..], + original_sig.output()[output_count_diff as usize..] + ); + + // Check that the output node was successfully updated + let output_sig = lowered_circ + .hugr() + .signature(lowered_circ.output_node()) + .unwrap(); + assert_eq!(lowered_sig.output(), output_sig.input()); + } +}