diff --git a/hugr-core/src/builder/build_traits.rs b/hugr-core/src/builder/build_traits.rs index 7581cd0cf..6ea962473 100644 --- a/hugr-core/src/builder/build_traits.rs +++ b/hugr-core/src/builder/build_traits.rs @@ -437,6 +437,27 @@ pub trait Dataflow: Container { TailLoopBuilder::create_with_io(self.hugr_mut(), loop_node, &tail_loop) } + /// Return a builder for a [`crate::ops::Conditional`] node. + /// `sum_input` is a tuple of the type of the Sum + /// variants and the corresponding wire. + /// + /// The `other_inputs` must be an iterable over pairs of the type of the input and + /// the corresponding wire. + /// The `outputs` are the types of the outputs. Extension delta will be inferred. + /// + /// # Errors + /// + /// This function will return an error if there is an error when building + /// the Conditional node. + fn conditional_builder( + &mut self, + sum_input: (impl IntoIterator, Wire), + other_inputs: impl IntoIterator, + output_types: TypeRow, + ) -> Result, BuildError> { + self.conditional_builder_exts(sum_input, other_inputs, TO_BE_INFERRED, output_types) + } + /// Return a builder for a [`crate::ops::Conditional`] node. /// `sum_rows` and `sum_wire` define the type of the Sum /// variants and the wire carrying the Sum respectively. @@ -449,12 +470,12 @@ pub trait Dataflow: Container { /// /// This function will return an error if there is an error when building /// the Conditional node. - fn conditional_builder( + fn conditional_builder_exts( &mut self, (sum_rows, sum_wire): (impl IntoIterator, Wire), other_inputs: impl IntoIterator, + extension_delta: impl Into, output_types: TypeRow, - extension_delta: ExtensionSet, ) -> Result, BuildError> { let mut input_wires = vec![sum_wire]; let (input_types, rest_input_wires): (Vec, Vec) = @@ -471,7 +492,7 @@ pub trait Dataflow: Container { sum_rows, other_inputs: inputs, outputs: output_types, - extension_delta, + extension_delta: extension_delta.into(), }, input_wires, )?; diff --git a/hugr-core/src/builder/conditional.rs b/hugr-core/src/builder/conditional.rs index ab4716d66..bfdd859a5 100644 --- a/hugr-core/src/builder/conditional.rs +++ b/hugr-core/src/builder/conditional.rs @@ -245,7 +245,6 @@ mod test { ([type_row![], type_row![]], const_wire), other_inputs, outputs, - ExtensionSet::new(), )?; n_identity(conditional_b.case_builder(0)?)?; diff --git a/hugr-core/src/builder/tail_loop.rs b/hugr-core/src/builder/tail_loop.rs index 1d0983e24..9f0ad44cd 100644 --- a/hugr-core/src/builder/tail_loop.rs +++ b/hugr-core/src/builder/tail_loop.rs @@ -165,7 +165,6 @@ mod test { ([type_row![], type_row![]], const_wire), vec![(BIT, b1)], output_row, - PRELUDE_ID.into(), )?; let mut branch_0 = conditional_b.case_builder(0)?; diff --git a/hugr-core/src/hugr/rewrite/inline_dfg.rs b/hugr-core/src/hugr/rewrite/inline_dfg.rs index 35971dc2e..fa6c15560 100644 --- a/hugr-core/src/hugr/rewrite/inline_dfg.rs +++ b/hugr-core/src/hugr/rewrite/inline_dfg.rs @@ -352,11 +352,11 @@ mod test { .add_dataflow_op(test_quantum_extension::measure(), r.outputs())? .outputs_arr(); // Node using the boolean. Here we just select between two empty computations. - let mut if_n = inner.conditional_builder( + let mut if_n = inner.conditional_builder_exts( ([type_row![], type_row![]], b), [], - type_row![], ExtensionSet::new(), + type_row![], )?; if_n.case_builder(0)?.finish_with_outputs([])?; if_n.case_builder(1)?.finish_with_outputs([])?; diff --git a/hugr-core/src/hugr/rewrite/replace.rs b/hugr-core/src/hugr/rewrite/replace.rs index fe40bd0cb..10808bf14 100644 --- a/hugr-core/src/hugr/rewrite/replace.rs +++ b/hugr-core/src/hugr/rewrite/replace.rs @@ -665,7 +665,6 @@ mod test { (vec![type_row![]; 2], b), [(USIZE_T, i)], type_row![USIZE_T], - ExtensionSet::new(), )?; let mut case1 = cond.case_builder(0)?; let foo = case1.add_dataflow_op(mk_op("foo"), case1.input_wires())?; diff --git a/hugr-core/src/hugr/validate/test.rs b/hugr-core/src/hugr/validate/test.rs index 3e65454b3..6fc6e3a83 100644 --- a/hugr-core/src/hugr/validate/test.rs +++ b/hugr-core/src/hugr/validate/test.rs @@ -792,7 +792,6 @@ fn test_polymorphic_call() -> Result<(), Box> { (vec![type_row![USIZE_T; 2]], tup), vec![], type_row![USIZE_T;2], - TO_BE_INFERRED.into(), //TODO make default //es.clone().union(EXT_ID.into()), )?; let mut cc = c.case_builder(0)?; let [i1, i2] = cc.input_wires_arr();