diff --git a/src/builder/circuit.rs b/src/builder/circuit.rs index d730566dd..127f62479 100644 --- a/src/builder/circuit.rs +++ b/src/builder/circuit.rs @@ -137,22 +137,33 @@ mod test { test::{build_main, NAT, QB}, Dataflow, DataflowSubContainer, Wire, }, - extension::prelude::BOOL_T, + extension::{prelude::BOOL_T, ExtensionSet}, ops::{custom::OpaqueOp, LeafOp}, type_row, types::FunctionType, - utils::test_quantum_extension::{cx_gate, h_gate, measure}, + utils::test_quantum_extension::{cx_gate, h_gate, measure, EXTENSION_ID}, }; #[test] fn simple_linear() { let build_res = build_main( - FunctionType::new(type_row![QB, QB], type_row![QB, QB]).into(), + FunctionType::new_endo(type_row![QB, QB]) + .with_extension_delta(&ExtensionSet::singleton(&EXTENSION_ID)) + .into(), |mut f_build| { - let wires = f_build.input_wires().collect(); + let mut wires: [Wire; 2] = f_build.input_wires_arr(); + [wires[1]] = f_build + .add_dataflow_op( + LeafOp::Lift { + type_row: vec![QB].into(), + new_extension: EXTENSION_ID, + }, + [wires[1]], + )? + .outputs_arr(); let mut linear = CircuitBuilder { - wires, + wires: Vec::from(wires), builder: &mut f_build, }; @@ -184,10 +195,20 @@ mod test { .into(), ); let build_res = build_main( - FunctionType::new(type_row![QB, QB, NAT], type_row![QB, QB, BOOL_T]).into(), + FunctionType::new(type_row![QB, QB, NAT], type_row![QB, QB, BOOL_T]) + .with_extension_delta(&ExtensionSet::singleton(&EXTENSION_ID)) + .into(), |mut f_build| { let [q0, q1, angle]: [Wire; 3] = f_build.input_wires_arr(); - + let [angle] = f_build + .add_dataflow_op( + LeafOp::Lift { + type_row: vec![NAT].into(), + new_extension: EXTENSION_ID, + }, + [angle], + )? + .outputs_arr(); let mut linear = f_build.as_circuit(vec![q0, q1]); let measure_out = linear diff --git a/src/builder/dataflow.rs b/src/builder/dataflow.rs index 879b43f93..1e711737d 100644 --- a/src/builder/dataflow.rs +++ b/src/builder/dataflow.rs @@ -213,9 +213,9 @@ pub(crate) mod test { use crate::hugr::validate::InterGraphEdgeError; use crate::ops::{handle::NodeHandle, LeafOp, OpTag}; - use crate::std_extensions::logic::test::and_op; + use crate::std_extensions::logic::{self, test::and_op}; use crate::types::Type; - use crate::utils::test_quantum_extension::h_gate; + use crate::utils::test_quantum_extension::{self, h_gate}; use crate::{ builder::{ test::{n_identity, BIT, NAT, QB}, @@ -235,13 +235,25 @@ pub(crate) mod test { let _f_id = { let mut func_builder = module_builder.define_function( "main", - FunctionType::new(type_row![NAT, QB], type_row![NAT, QB]).into(), + FunctionType::new(type_row![NAT, QB], type_row![NAT, QB]) + .with_extension_delta(&ExtensionSet::singleton( + &test_quantum_extension::EXTENSION_ID, + )) + .into(), )?; let [int, qb] = func_builder.input_wires_arr(); let q_out = func_builder.add_dataflow_op(h_gate(), vec![qb])?; - + let [int] = func_builder + .add_dataflow_op( + LeafOp::Lift { + type_row: vec![NAT].into(), + new_extension: test_quantum_extension::EXTENSION_ID, + }, + [int], + )? + .outputs_arr(); let inner_builder = func_builder.dfg_builder( FunctionType::new(type_row![NAT], type_row![NAT]), None, @@ -260,7 +272,7 @@ pub(crate) mod test { } // Scaffolding for copy insertion tests - fn copy_scaffold(f: F, msg: &'static str) -> Result<(), BuildError> + fn copy_scaffold(f: F, delta: &ExtensionSet, msg: &'static str) -> Result<(), BuildError> where F: FnOnce(FunctionBuilder<&mut Hugr>) -> Result>, BuildError>, { @@ -269,7 +281,9 @@ pub(crate) mod test { let f_build = module_builder.define_function( "main", - FunctionType::new(type_row![BOOL_T], type_row![BOOL_T, BOOL_T]).into(), + FunctionType::new(type_row![BOOL_T], type_row![BOOL_T, BOOL_T]) + .with_extension_delta(delta) + .into(), )?; f(f_build)?; @@ -287,15 +301,27 @@ pub(crate) mod test { let [b1] = f_build.input_wires_arr(); f_build.finish_with_outputs([b1, b1]) }, + &ExtensionSet::new(), "Copy input and output", )?; + let es = ExtensionSet::singleton(&logic::EXTENSION_ID); copy_scaffold( |mut f_build| { let [b1] = f_build.input_wires_arr(); let xor = f_build.add_dataflow_op(and_op(), [b1, b1])?; + let [b1] = f_build + .add_dataflow_op( + LeafOp::Lift { + type_row: vec![BOOL_T].into(), + new_extension: logic::EXTENSION_ID, + }, + [b1], + )? + .outputs_arr(); f_build.finish_with_outputs([xor.out_wire(0), b1]) }, + &es, "Copy input and use with binary function", )?; @@ -303,9 +329,19 @@ pub(crate) mod test { |mut f_build| { let [b1] = f_build.input_wires_arr(); let xor1 = f_build.add_dataflow_op(and_op(), [b1, b1])?; + let [b1] = f_build + .add_dataflow_op( + LeafOp::Lift { + type_row: vec![BOOL_T].into(), + new_extension: logic::EXTENSION_ID, + }, + [b1], + )? + .outputs_arr(); let xor2 = f_build.add_dataflow_op(and_op(), [b1, xor1.out_wire(0)])?; f_build.finish_with_outputs([xor2.out_wire(0), b1]) }, + &es, "Copy multiple times", )?; diff --git a/src/extension/infer.rs b/src/extension/infer.rs index 84e22b65a..5187f32eb 100644 --- a/src/extension/infer.rs +++ b/src/extension/infer.rs @@ -648,11 +648,12 @@ impl UnificationContext { fn search_variable_deps(&self) -> HashSet { let mut seen = HashSet::new(); let mut new_variables: HashSet = self.variables.clone(); + let constraints_for_solved = HashSet::new(); while !new_variables.is_empty() { new_variables = new_variables .into_iter() .filter(|m| seen.insert(*m)) - .flat_map(|m| self.get_constraints(&m).unwrap()) + .flat_map(|m| self.get_constraints(&m).unwrap_or(&constraints_for_solved)) .map(|c| match c { Constraint::Plus(_, other) => self.resolve(*other), Constraint::Equal(other) => self.resolve(*other), diff --git a/src/extension/op_def.rs b/src/extension/op_def.rs index 143426123..3e39f6523 100644 --- a/src/extension/op_def.rs +++ b/src/extension/op_def.rs @@ -223,7 +223,7 @@ impl SignatureFunc { /// /// This function will return an error if the type arguments are invalid or /// there is some error in type computation. - pub fn compute_signature( + fn compute_signature( &self, def: &OpDef, args: &[TypeArg], @@ -245,10 +245,10 @@ impl SignatureFunc { } }; - let res = pf.instantiate(args, exts)?; - // TODO bring this assert back once resource inference is done? - // https://github.com/CQCL/hugr/issues/388 - // debug_assert!(res.extension_reqs.contains(def.extension())); + let mut res = pf.instantiate(args, exts)?; + res.extension_reqs = res + .extension_reqs + .union(&ExtensionSet::singleton(def.extension())); Ok(res) } } @@ -541,10 +541,10 @@ mod test { let args = [TypeArg::BoundedNat { n: 3 }, USIZE_T.into()]; assert_eq!( def.compute_signature(&args, &PRELUDE_REGISTRY), - Ok(FunctionType::new( - vec![USIZE_T; 3], - vec![Type::new_tuple(vec![USIZE_T; 3])] - )) + Ok( + FunctionType::new(vec![USIZE_T; 3], vec![Type::new_tuple(vec![USIZE_T; 3])]) + .with_extension_delta(&ExtensionSet::singleton(&EXT_ID)) + ) ); assert_eq!(def.validate_args(&args, &PRELUDE_REGISTRY, &[]), Ok(())); @@ -554,10 +554,10 @@ mod test { let args = [TypeArg::BoundedNat { n: 3 }, tyvar.clone().into()]; assert_eq!( def.compute_signature(&args, &PRELUDE_REGISTRY), - Ok(FunctionType::new( - tyvars.clone(), - vec![Type::new_tuple(tyvars)] - )) + Ok( + FunctionType::new(tyvars.clone(), vec![Type::new_tuple(tyvars)]) + .with_extension_delta(&ExtensionSet::singleton(&EXT_ID)) + ) ); def.validate_args(&args, &PRELUDE_REGISTRY, &[TypeBound::Eq.into()]) .unwrap(); @@ -610,7 +610,8 @@ mod test { def.validate_args(&args, &EMPTY_REG, &decls).unwrap(); assert_eq!( def.compute_signature(&args, &EMPTY_REG), - Ok(FunctionType::new_endo(vec![tv])) + Ok(FunctionType::new_endo(vec![tv]) + .with_extension_delta(&ExtensionSet::singleton(&EXT_ID))) ); Ok(()) } diff --git a/src/hugr/rewrite/replace.rs b/src/hugr/rewrite/replace.rs index 1ca02b4ac..17b41a153 100644 --- a/src/hugr/rewrite/replace.rs +++ b/src/hugr/rewrite/replace.rs @@ -477,12 +477,11 @@ mod test { .unwrap() .into(); let just_list = TypeRow::from(vec![listy.clone()]); + let exset = ExtensionSet::singleton(&collections::EXTENSION_NAME); let intermed = TypeRow::from(vec![listy.clone(), USIZE_T]); let mut cfg = CFGBuilder::new( - // One might expect an extension_delta of "collections" here, but push/pop - // have an empty delta themselves, pending https://github.com/CQCL/hugr/issues/388 - FunctionType::new_endo(just_list.clone()), + FunctionType::new_endo(just_list.clone()).with_extension_delta(&exset), )?; let pred_const = cfg.add_constant(ops::Const::unary_unit_sum())?; @@ -628,13 +627,31 @@ mod test { }, op_sig.input() ); - h.simple_entry_builder(op_sig.output, 1, op_sig.extension_reqs.clone())? + h.simple_entry_builder(op_sig.output.clone(), 1, op_sig.extension_reqs.clone())? } else { - h.simple_block_builder(op_sig, 1)? + h.simple_block_builder(op_sig.clone(), 1)? }; let op: OpType = op.into(); let op = bb.add_dataflow_op(op, bb.input_wires())?; - let load_pred = bb.load_const(pred_const)?; + let mut load_pred = bb.load_const(pred_const)?; + let const_ty = bb + .hugr() + .get_optype(pred_const.node()) + .as_const() + .unwrap() + .const_type() + .clone(); + for e in op_sig.extension_reqs.iter() { + [load_pred] = bb + .add_dataflow_op( + LeafOp::Lift { + type_row: vec![const_ty.clone()].into(), + new_extension: e.clone(), + }, + [load_pred], + )? + .outputs_arr(); + } bb.finish_with_outputs(load_pred, op.outputs()) } diff --git a/src/hugr/rewrite/simple_replace.rs b/src/hugr/rewrite/simple_replace.rs index 2f14dee3e..469017d86 100644 --- a/src/hugr/rewrite/simple_replace.rs +++ b/src/hugr/rewrite/simple_replace.rs @@ -221,7 +221,7 @@ pub(in crate::hugr::rewrite) mod test { HugrBuilder, ModuleBuilder, }; use crate::extension::prelude::BOOL_T; - use crate::extension::{EMPTY_REG, PRELUDE_REGISTRY}; + use crate::extension::{ExtensionSet, EMPTY_REG, PRELUDE_REGISTRY}; use crate::hugr::views::{HugrView, SiblingSubgraph}; use crate::hugr::{Hugr, HugrMut, Rewrite}; use crate::ops::dataflow::DataflowOpTrait; @@ -230,7 +230,7 @@ pub(in crate::hugr::rewrite) mod test { use crate::std_extensions::logic::test::and_op; use crate::type_row; use crate::types::{FunctionType, Type}; - use crate::utils::test_quantum_extension::{cx_gate, h_gate}; + use crate::utils::test_quantum_extension::{cx_gate, h_gate, EXTENSION_ID}; use crate::{IncomingPort, Node}; use super::SimpleReplacement; @@ -249,9 +249,12 @@ pub(in crate::hugr::rewrite) mod test { fn make_hugr() -> Result { let mut module_builder = ModuleBuilder::new(); let _f_id = { + let delta = ExtensionSet::singleton(&EXTENSION_ID); let mut func_builder = module_builder.define_function( "main", - FunctionType::new(type_row![QB, QB, QB], type_row![QB, QB, QB]).into(), + FunctionType::new_endo(type_row![QB, QB, QB]) + .with_extension_delta(&delta) + .into(), )?; let [qb0, qb1, qb2] = func_builder.input_wires_arr(); @@ -259,7 +262,7 @@ pub(in crate::hugr::rewrite) mod test { let q_out = func_builder.add_dataflow_op(h_gate(), vec![qb2])?; let mut inner_builder = func_builder.dfg_builder( - FunctionType::new(type_row![QB, QB], type_row![QB, QB]), + FunctionType::new_endo(type_row![QB, QB]).with_extension_delta(&delta), None, [qb0, qb1], )?; diff --git a/src/hugr/validate/test.rs b/src/hugr/validate/test.rs index dc8e9add2..c20c8a767 100644 --- a/src/hugr/validate/test.rs +++ b/src/hugr/validate/test.rs @@ -16,7 +16,7 @@ use crate::macros::const_extension_ids; use crate::ops::dataflow::IOTrait; use crate::ops::{self, Const, LeafOp, OpType}; use crate::std_extensions::logic::test::{and_op, or_op}; -use crate::std_extensions::logic::{self, NotOp}; +use crate::std_extensions::logic::{self, NotOp, EXTENSION_ID}; use crate::types::type_param::{TypeArg, TypeArgError, TypeParam}; use crate::types::{CustomType, FunctionType, PolyFuncType, Type, TypeBound, TypeRow}; use crate::values::Value; @@ -360,17 +360,18 @@ fn cfg_children_restrictions() { #[test] fn test_ext_edge() -> Result<(), HugrError> { - let mut h = closed_dfg_root_hugr(FunctionType::new( - type_row![BOOL_T, BOOL_T], - type_row![BOOL_T], - )); + let delta = ExtensionSet::singleton(&EXTENSION_ID); + let mut h = closed_dfg_root_hugr( + FunctionType::new(type_row![BOOL_T, BOOL_T], type_row![BOOL_T]) + .with_extension_delta(&delta), + ); let [input, output] = h.get_io(h.root()).unwrap(); // Nested DFG BOOL_T -> BOOL_T let sub_dfg = h.add_node_with_parent( h.root(), ops::DFG { - signature: FunctionType::new_endo(type_row![BOOL_T]), + signature: FunctionType::new_endo(type_row![BOOL_T]).with_extension_delta(&delta), }, )?; // this Xor has its 2nd input unconnected @@ -411,7 +412,10 @@ const_extension_ids! { #[test] fn test_local_const() -> Result<(), HugrError> { - let mut h = closed_dfg_root_hugr(FunctionType::new(type_row![BOOL_T], type_row![BOOL_T])); + let mut h = closed_dfg_root_hugr( + FunctionType::new_endo(type_row![BOOL_T]) + .with_extension_delta(&ExtensionSet::singleton(&EXTENSION_ID)), + ); let [input, output] = h.get_io(h.root()).unwrap(); let and = h.add_node_with_parent(h.root(), and_op())?; h.connect(input, 0, and, 0)?; diff --git a/src/hugr/views/descendants.rs b/src/hugr/views/descendants.rs index f0c515457..8b89e9836 100644 --- a/src/hugr/views/descendants.rs +++ b/src/hugr/views/descendants.rs @@ -201,10 +201,11 @@ where pub(super) mod test { use crate::{ builder::{Container, Dataflow, DataflowSubContainer, HugrBuilder, ModuleBuilder}, - ops::handle::NodeHandle, + extension::ExtensionSet, + ops::{handle::NodeHandle, LeafOp}, type_row, types::{FunctionType, Type}, - utils::test_quantum_extension::h_gate, + utils::test_quantum_extension::{h_gate, EXTENSION_ID}, }; use super::*; @@ -222,16 +223,27 @@ pub(super) mod test { let (f_id, inner_id) = { let mut func_builder = module_builder.define_function( "main", - FunctionType::new(type_row![NAT, QB], type_row![NAT, QB]).into(), + FunctionType::new_endo(type_row![NAT, QB]) + .with_extension_delta(&ExtensionSet::singleton(&EXTENSION_ID)) + .into(), )?; let [int, qb] = func_builder.input_wires_arr(); let q_out = func_builder.add_dataflow_op(h_gate(), vec![qb])?; + let [int] = func_builder + .add_dataflow_op( + LeafOp::Lift { + type_row: type_row![NAT], + new_extension: EXTENSION_ID, + }, + [int], + )? + .outputs_arr(); let inner_id = { let inner_builder = func_builder.dfg_builder( - FunctionType::new(type_row![NAT], type_row![NAT]), + FunctionType::new_endo(type_row![NAT]), None, [int], )?; @@ -249,11 +261,11 @@ pub(super) mod test { #[test] fn full_region() -> Result<(), Box> { - let (hugr, def, inner) = make_module_hgr()?; + let (hugr, def, inner) = make_module_hgr().unwrap(); let region: DescendantsGraph = DescendantsGraph::try_new(&hugr, def)?; - assert_eq!(region.node_count(), 7); + assert_eq!(region.node_count(), 8); assert!(region.nodes().all(|n| n == def || hugr.get_parent(n) == Some(def) || hugr.get_parent(n) == Some(inner))); @@ -261,7 +273,11 @@ pub(super) mod test { assert_eq!( region.get_function_type(), - Some(FunctionType::new(type_row![NAT, QB], type_row![NAT, QB]).into()) + Some( + FunctionType::new_endo(type_row![NAT, QB]) + .with_extension_delta(&ExtensionSet::singleton(&EXTENSION_ID)) + .into() + ) ); let inner_region: DescendantsGraph = DescendantsGraph::try_new(&hugr, inner)?; assert_eq!( diff --git a/src/hugr/views/sibling.rs b/src/hugr/views/sibling.rs index bcc122361..4bf6c7ac2 100644 --- a/src/hugr/views/sibling.rs +++ b/src/hugr/views/sibling.rs @@ -387,7 +387,7 @@ mod test { let region: SiblingGraph = SiblingGraph::try_new(&hugr, def)?; - assert_eq!(region.node_count(), 5); + assert_eq!(region.node_count(), 6); assert!(region .nodes() .all(|n| n == def || hugr.get_parent(n) == Some(def))); diff --git a/src/hugr/views/sibling_subgraph.rs b/src/hugr/views/sibling_subgraph.rs index 461083378..116b429bf 100644 --- a/src/hugr/views/sibling_subgraph.rs +++ b/src/hugr/views/sibling_subgraph.rs @@ -287,8 +287,9 @@ impl SiblingSubgraph { &self.outputs } - /// The signature of the subgraph. + /// The signature of the subgraph, excluding any extension delta pub fn signature(&self, hugr: &impl HugrView) -> FunctionType { + // We cannot calculate the delta from just the extensions at the input and output! let input = self .inputs .iter() @@ -344,7 +345,12 @@ impl SiblingSubgraph { let Some([rep_input, rep_output]) = replacement.get_io(rep_root) else { return Err(InvalidReplacement::InvalidDataflowParent); }; - if dfg_optype.dataflow_signature() != Some(self.signature(hugr)) { + if !dfg_optype.dataflow_signature().is_some_and(|rep_sig| { + rep_sig + == self + .signature(hugr) + .with_extension_delta(&rep_sig.extension_reqs) + }) { return Err(InvalidReplacement::InvalidSignature); } @@ -408,8 +414,15 @@ impl SiblingSubgraph { &self, hugr: &impl HugrView, name: impl Into, + extension_delta: &crate::extension::ExtensionSet, ) -> Result { - let mut builder = FunctionBuilder::new(name, self.signature(hugr).into()).unwrap(); + let mut builder = FunctionBuilder::new( + name, + self.signature(hugr) + .with_extension_delta(extension_delta) + .into(), + ) + .unwrap(); // Take the unfinished Hugr from the builder, to avoid unnecessary // validation checks that require connecting the inputs and outputs. let mut extracted = mem::take(builder.hugr_mut()); @@ -675,8 +688,9 @@ mod tests { use cool_asserts::assert_matches; - use crate::extension::PRELUDE_REGISTRY; - use crate::utils::test_quantum_extension::cx_gate; + use crate::extension::{ExtensionSet, PRELUDE_REGISTRY}; + use crate::ops::LeafOp; + use crate::utils::test_quantum_extension::{cx_gate, EXTENSION_ID}; use crate::{ builder::{ BuildError, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, HugrBuilder, @@ -723,17 +737,26 @@ mod tests { let mut mod_builder = ModuleBuilder::new(); let func = mod_builder.declare( "test", - FunctionType::new_endo(type_row![QB_T, QB_T, QB_T]).into(), + FunctionType::new_endo(type_row![QB_T, QB_T, QB_T]) + .with_extension_delta(&ExtensionSet::singleton(&EXTENSION_ID)) + .into(), )?; let func_id = { let mut dfg = mod_builder.define_declaration(&func)?; let [w0, w1, w2] = dfg.input_wires_arr(); let [w0, w1] = dfg.add_dataflow_op(cx_gate(), [w0, w1])?.outputs_arr(); + let [w2] = dfg + .add_dataflow_op( + LeafOp::Lift { + type_row: vec![QB_T].into(), + new_extension: EXTENSION_ID, + }, + [w2], + )? + .outputs_arr(); dfg.finish_with_outputs([w0, w1, w2])? }; - let hugr = mod_builder - .finish_prelude_hugr() - .map_err(|e| -> BuildError { e.into() })?; + let hugr = mod_builder.finish_prelude_hugr()?; Ok((hugr, func_id.node())) } @@ -797,18 +820,32 @@ mod tests { let sub = SiblingSubgraph::try_new_dataflow_subgraph(&func)?; let empty_dfg = { - let builder = DFGBuilder::new(FunctionType::new_endo(type_row![QB_T, QB_T])).unwrap(); + let mut builder = DFGBuilder::new( + FunctionType::new_endo(type_row![QB_T, QB_T, QB_T]) + .with_extension_delta(&ExtensionSet::singleton(&EXTENSION_ID)), + ) + .unwrap(); let inputs = builder.input_wires(); - builder.finish_prelude_hugr_with_outputs(inputs).unwrap() + let lifted = builder + .add_dataflow_op( + LeafOp::Lift { + type_row: type_row![QB_T, QB_T, QB_T], + new_extension: EXTENSION_ID, + }, + inputs, + ) + .unwrap(); + builder.set_outputs(lifted.outputs()).unwrap(); + builder.finish_prelude_hugr().unwrap() }; let rep = sub.create_simple_replacement(&func, empty_dfg).unwrap(); - assert_eq!(rep.subgraph().nodes().len(), 1); + assert_eq!(rep.subgraph().nodes().len(), 2); - assert_eq!(hugr.node_count(), 5); // Module + Def + In + CX + Out + assert_eq!(hugr.node_count(), 6); // Module + Def + In + CX + Lift + Out hugr.apply_rewrite(rep).unwrap(); - assert_eq!(hugr.node_count(), 4); // Module + Def + In + Out + assert_eq!(hugr.node_count(), 5); // Module + Def + In + Lift + Out Ok(()) } @@ -818,11 +855,10 @@ mod tests { let (hugr, dfg) = build_hugr().unwrap(); let func: SiblingGraph<'_, FuncID> = SiblingGraph::try_new(&hugr, dfg).unwrap(); let sub = SiblingSubgraph::try_new_dataflow_subgraph(&func)?; - // The identity wire on the third qubit is ignored, so the subgraph's signature only contains - // the first two qubits. + // The third wire is included because of the "Lift" node assert_eq!( sub.signature(&func), - FunctionType::new_endo(type_row![QB_T, QB_T]) + FunctionType::new_endo(type_row![QB_T, QB_T, QB_T]) ); Ok(()) } @@ -855,7 +891,7 @@ mod tests { .unwrap() .nodes() .len(), - 1 + 2 // Include Lift node ) } @@ -960,7 +996,8 @@ mod tests { let func_graph: SiblingGraph<'_, FuncID> = SiblingGraph::try_new(&hugr, func_root).unwrap(); let subgraph = SiblingSubgraph::try_new_dataflow_subgraph(&func_graph).unwrap(); - let extracted = subgraph.extract_subgraph(&hugr, "region")?; + let extracted = + subgraph.extract_subgraph(&hugr, "region", &ExtensionSet::singleton(&EXTENSION_ID))?; extracted.validate(&PRELUDE_REGISTRY).unwrap(); diff --git a/src/hugr/views/tests.rs b/src/hugr/views/tests.rs index 97fb50861..43ccfbaee 100644 --- a/src/hugr/views/tests.rs +++ b/src/hugr/views/tests.rs @@ -161,12 +161,16 @@ fn static_targets() { #[test] fn test_dataflow_ports_only() { use crate::builder::DataflowSubContainer; - use crate::extension::{prelude::BOOL_T, PRELUDE_REGISTRY}; + use crate::extension::{prelude::BOOL_T, ExtensionSet, PRELUDE_REGISTRY}; use crate::hugr::views::PortIterator; - use crate::std_extensions::logic::NotOp; + use crate::std_extensions::logic::{NotOp, EXTENSION_ID}; use itertools::Itertools; - let mut dfg = DFGBuilder::new(FunctionType::new(type_row![BOOL_T], type_row![BOOL_T])).unwrap(); + let mut dfg = DFGBuilder::new( + FunctionType::new_endo(type_row![BOOL_T]) + .with_extension_delta(&ExtensionSet::singleton(&EXTENSION_ID)), + ) + .unwrap(); let local_and = { let local_and = dfg .define_function( @@ -189,6 +193,25 @@ fn test_dataflow_ports_only() { ) .unwrap(); dfg.add_other_wire(not.node(), call.node()).unwrap(); + + // As temporary workaround for https://github.com/CQCL/hugr/issues/695 + // We force the input-extensions of the FuncDefn node to include the logic + // extension, so the static edge from the FuncDefn to the call has the same + // extensions as the result of the "not". + { + let nt = dfg.hugr_mut().op_types.get_mut(local_and.node().pg_index()); + assert_eq!(nt.input_extensions, Some(ExtensionSet::new())); + nt.input_extensions = Some(ExtensionSet::singleton(&EXTENSION_ID)); + } + // Note that presently the builder sets too many input-exts that could be + // left to the inference (https://github.com/CQCL/hugr/issues/702) hence we + // must manually change these too, although we can let inference deal with them + for node in dfg.hugr().get_io(local_and.node()).unwrap() { + let nt = dfg.hugr_mut().op_types.get_mut(node.pg_index()); + assert_eq!(nt.input_extensions, Some(ExtensionSet::new())); + nt.input_extensions = None; + } + let h = dfg .finish_hugr_with_outputs(not.outputs(), &PRELUDE_REGISTRY) .unwrap(); diff --git a/src/std_extensions/arithmetic/int_ops.rs b/src/std_extensions/arithmetic/int_ops.rs index 267fde902..2ebae3e53 100644 --- a/src/std_extensions/arithmetic/int_ops.rs +++ b/src/std_extensions/arithmetic/int_ops.rs @@ -349,13 +349,15 @@ mod test { } #[test] fn test_binary_signatures() { + let delta = ExtensionSet::singleton(&EXTENSION_ID); assert_eq!( IntOpDef::iwiden_s .with_two_widths(3, 4) .to_extension_op() .unwrap() .signature(), - FunctionType::new(vec![int_type(ta(3))], vec![int_type(ta(4))],) + FunctionType::new(vec![int_type(ta(3))], vec![int_type(ta(4))]) + .with_extension_delta(&delta), ); assert_eq!( IntOpDef::iwiden_s @@ -363,7 +365,8 @@ mod test { .to_extension_op() .unwrap() .signature(), - FunctionType::new(vec![int_type(ta(3))], vec![int_type(ta(3))],) + FunctionType::new(vec![int_type(ta(3))], vec![int_type(ta(3))]) + .with_extension_delta(&delta), ); assert_eq!( IntOpDef::inarrow_s @@ -371,7 +374,8 @@ mod test { .to_extension_op() .unwrap() .signature(), - FunctionType::new(vec![int_type(ta(3))], vec![sum_with_error(int_type(ta(3)))],) + FunctionType::new(vec![int_type(ta(3))], vec![sum_with_error(int_type(ta(3)))]) + .with_extension_delta(&delta), ); assert!( IntOpDef::iwiden_u @@ -387,7 +391,8 @@ mod test { .to_extension_op() .unwrap() .signature(), - FunctionType::new(vec![int_type(ta(2))], vec![sum_with_error(int_type(ta(1)))],) + FunctionType::new(vec![int_type(ta(2))], vec![sum_with_error(int_type(ta(1)))]) + .with_extension_delta(&delta) ); assert!(IntOpDef::inarrow_u