diff --git a/hugr-core/src/hugr/views/sibling_subgraph.rs b/hugr-core/src/hugr/views/sibling_subgraph.rs index 53c92ed56..cc8caa01f 100644 --- a/hugr-core/src/hugr/views/sibling_subgraph.rs +++ b/hugr-core/src/hugr/views/sibling_subgraph.rs @@ -304,11 +304,9 @@ impl SiblingSubgraph { }) .collect_vec(); Signature::new(input, output).with_extension_delta(ExtensionSet::union_over( - self.nodes.iter().map(|n| { - hugr.signature(*n) - .expect("all nodes must have dataflow signature") - .extension_reqs - }), + self.nodes + .iter() + .map(|n| hugr.get_optype(*n).extension_delta()), )) } @@ -729,14 +727,14 @@ pub enum InvalidSubgraphBoundary { #[cfg(test)] mod tests { - use std::error::Error; - use cool_asserts::assert_matches; use crate::builder::inout_sig; - use crate::extension::PRELUDE_REGISTRY; + use crate::extension::{prelude, ExtensionRegistry}; + use crate::ops::Const; + use crate::std_extensions::arithmetic::float_types::{self, ConstF64}; use crate::std_extensions::logic; - use crate::utils::test_quantum_extension::{self, cx_gate}; + use crate::utils::test_quantum_extension::{self, cx_gate, rz_f64}; use crate::{ builder::{ BuildError, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, HugrBuilder, @@ -778,22 +776,36 @@ mod tests { } } + /// A Module with a single function from three qubits to three qubits. + /// The function applies a CX gate to the first two qubits and a Rz gate (with a constant angle) to the last qubit. fn build_hugr() -> Result<(Hugr, Node), BuildError> { let mut mod_builder = ModuleBuilder::new(); let func = mod_builder.declare( "test", Signature::new_endo(type_row![QB_T, QB_T, QB_T]) - .with_extension_delta(test_quantum_extension::EXTENSION_ID) + .with_extension_delta(ExtensionSet::from_iter([ + test_quantum_extension::EXTENSION_ID, + float_types::EXTENSION_ID, + ])) .into(), )?; let func_id = { let mut dfg = mod_builder.define_declaration(&func)?; let [w0, w1, w2] = dfg.input_wires_arr(); let [w0, w1] = dfg.add_dataflow_op(cx_gate(), [w0, w1])?.outputs_arr(); + let c = dfg.add_load_const(Const::new(ConstF64::new(0.5).into())); + let [w2] = dfg.add_dataflow_op(rz_f64(), [w2, c])?.outputs_arr(); dfg.finish_with_outputs([w0, w1, w2])? }; let hugr = mod_builder - .finish_prelude_hugr() + .finish_hugr( + &ExtensionRegistry::try_new([ + prelude::PRELUDE.to_owned(), + test_quantum_extension::EXTENSION.to_owned(), + float_types::EXTENSION.to_owned(), + ]) + .unwrap(), + ) .map_err(|e| -> BuildError { e.into() })?; Ok((hugr, func_id.node())) } @@ -888,16 +900,17 @@ mod tests { let sub = SiblingSubgraph::try_new_dataflow_subgraph(&func)?; let empty_dfg = { - let builder = DFGBuilder::new(Signature::new_endo(type_row![QB_T, QB_T])).unwrap(); + let builder = + DFGBuilder::new(Signature::new_endo(type_row![QB_T, QB_T, QB_T])).unwrap(); let inputs = builder.input_wires(); builder.finish_prelude_hugr_with_outputs(inputs).unwrap() }; let rep = sub.create_simple_replacement(&func, empty_dfg).unwrap(); - assert_eq!(rep.subgraph().nodes().len(), 1); + assert_eq!(rep.subgraph().nodes().len(), 4); - assert_eq!(hugr.node_count(), 5); // Module + Def + In + CX + Out + assert_eq!(hugr.node_count(), 8); // Module + Def + In + CX + Rz + Const + LoadConst + Out hugr.apply_rewrite(rep).unwrap(); assert_eq!(hugr.node_count(), 4); // Module + Def + In + Out @@ -909,12 +922,14 @@ mod tests { let (hugr, dfg) = build_hugr().unwrap(); let func: SiblingGraph<'_, FuncID> = SiblingGraph::try_new(&hugr, dfg).unwrap(); let sub = SiblingSubgraph::try_new_dataflow_subgraph(&func)?; - // The identity wire on the third qubit is ignored, so the subgraph's signature only contains - // the first two qubits. assert_eq!( sub.signature(&func), - Signature::new_endo(type_row![QB_T, QB_T]) - .with_extension_delta(test_quantum_extension::EXTENSION_ID) + Signature::new_endo(type_row![QB_T, QB_T, QB_T]).with_extension_delta( + ExtensionSet::from_iter([ + test_quantum_extension::EXTENSION_ID, + float_types::EXTENSION_ID, + ]) + ) ); Ok(()) } @@ -947,7 +962,7 @@ mod tests { .unwrap() .nodes() .len(), - 1 + 4 ) } @@ -1064,15 +1079,23 @@ mod tests { } #[test] - fn extract_subgraph() -> Result<(), Box> { - let (hugr, func_root) = build_hugr()?; - let func_graph: SiblingGraph<'_, FuncID> = SiblingGraph::try_new(&hugr, func_root)?; - let subgraph = SiblingSubgraph::try_new_dataflow_subgraph(&func_graph)?; + fn extract_subgraph() { + let (hugr, func_root) = build_hugr().unwrap(); + let func_graph: SiblingGraph<'_, FuncID> = + SiblingGraph::try_new(&hugr, func_root).unwrap(); + let subgraph = SiblingSubgraph::try_new_dataflow_subgraph(&func_graph).unwrap(); let extracted = subgraph.extract_subgraph(&hugr, "region"); - extracted.validate(&PRELUDE_REGISTRY)?; - - Ok(()) + extracted + .validate( + &ExtensionRegistry::try_new([ + prelude::PRELUDE.to_owned(), + test_quantum_extension::EXTENSION.to_owned(), + float_types::EXTENSION.to_owned(), + ]) + .unwrap(), + ) + .unwrap(); } #[test]