diff --git a/hugr-core/src/builder/build_traits.rs b/hugr-core/src/builder/build_traits.rs index e91c36782..1cb8e242e 100644 --- a/hugr-core/src/builder/build_traits.rs +++ b/hugr-core/src/builder/build_traits.rs @@ -305,6 +305,7 @@ pub trait Dataflow: Container { /// The `inputs` must be an iterable over pairs of the type of the input and /// the corresponding wire. /// The `output_types` are the types of the outputs. + /// The Extension delta will be inferred. /// /// # Errors /// @@ -314,7 +315,27 @@ pub trait Dataflow: Container { &mut self, inputs: impl IntoIterator, output_types: TypeRow, - extension_delta: ExtensionSet, + ) -> Result, BuildError> { + self.cfg_builder_exts(inputs, output_types, TO_BE_INFERRED) + } + + /// Return a builder for a [`crate::ops::CFG`] node, + /// i.e. a nested controlflow subgraph. + /// The `inputs` must be an iterable over pairs of the type of the input and + /// the corresponding wire. + /// The `output_types` are the types of the outputs. + /// `extension_delta` is explicitly specified. Alternatively + /// [cfg_builder](Self::cfg_builder) may be used to infer it. + /// + /// # Errors + /// + /// This function will return an error if there is an error when building + /// the CFG node. + fn cfg_builder_exts( + &mut self, + inputs: impl IntoIterator, + output_types: TypeRow, + extension_delta: impl Into, ) -> Result, BuildError> { let (input_types, input_wires): (Vec, Vec) = inputs.into_iter().unzip(); @@ -405,17 +426,39 @@ pub trait Dataflow: Container { /// The `inputs` must be an iterable over pairs of the type of the input and /// the corresponding wire. /// The `output_types` are the types of the outputs. + /// The extension delta will be inferred. /// /// # Errors /// /// This function will return an error if there is an error when building /// the [`ops::TailLoop`] node. + /// fn tail_loop_builder( &mut self, just_inputs: impl IntoIterator, inputs_outputs: impl IntoIterator, just_out_types: TypeRow, - extension_delta: ExtensionSet, + ) -> Result, BuildError> { + self.tail_loop_builder_exts(just_inputs, inputs_outputs, just_out_types, TO_BE_INFERRED) + } + + /// Return a builder for a [`crate::ops::TailLoop`] node. + /// The `inputs` must be an iterable over pairs of the type of the input and + /// the corresponding wire. + /// The `output_types` are the types of the outputs. + /// `extension_delta` explicitly specified. Alternatively + /// [tail_loop_builder](Self::tail_loop_builder) may be used to infer it. + /// + /// # Errors + /// + /// This function will return an error if there is an error when building + /// the [`ops::TailLoop`] node. + fn tail_loop_builder_exts( + &mut self, + just_inputs: impl IntoIterator, + inputs_outputs: impl IntoIterator, + just_out_types: TypeRow, + extension_delta: impl Into, ) -> Result, BuildError> { let (input_types, mut input_wires): (Vec, Vec) = just_inputs.into_iter().unzip(); @@ -427,7 +470,7 @@ pub trait Dataflow: Container { just_inputs: input_types.into(), just_outputs: just_out_types, rest: rest_types.into(), - extension_delta, + extension_delta: extension_delta.into(), }; // TODO: Make input extensions a parameter let (loop_node, _) = add_node_with_wires(self, tail_loop.clone(), input_wires)?; @@ -463,7 +506,7 @@ pub trait Dataflow: Container { /// The `other_inputs` must be an iterable over pairs of the type of the input and /// the corresponding wire. /// The `output_types` are the types of the outputs. - /// `exts` explicitly specifies the extension delta. Alternatively + /// `extension_delta` is explicitly specified. Alternatively /// [conditional_builder](Self::conditional_builder) may be used to infer it. /// /// # Errors diff --git a/hugr-core/src/builder/cfg.rs b/hugr-core/src/builder/cfg.rs index 8f974d8c5..bcb6aff3b 100644 --- a/hugr-core/src/builder/cfg.rs +++ b/hugr-core/src/builder/cfg.rs @@ -413,12 +413,24 @@ impl + AsRef> BlockBuilder { } impl BlockBuilder { - /// Initialize a [`DataflowBlock`] rooted HUGR builder + /// Initialize a [`DataflowBlock`] rooted HUGR builder. + /// Extension delta will be inferred. pub fn new( inputs: impl Into, sum_rows: impl IntoIterator, other_outputs: impl Into, - extension_delta: ExtensionSet, + ) -> Result { + Self::new_exts(inputs, sum_rows, other_outputs, TO_BE_INFERRED) + } + + /// Initialize a [`DataflowBlock`] rooted HUGR builder. + /// `extension_delta` is explicitly specified; alternatively, [new](Self::new) + /// may be used to infer it. + pub fn new_exts( + inputs: impl Into, + sum_rows: impl IntoIterator, + other_outputs: impl Into, + extension_delta: impl Into, ) -> Result { let inputs = inputs.into(); let sum_rows: Vec<_> = sum_rows.into_iter().collect(); @@ -427,7 +439,7 @@ impl BlockBuilder { inputs: inputs.clone(), other_outputs: other_outputs.clone(), sum_rows, - extension_delta, + extension_delta: extension_delta.into(), }; let base = Hugr::new(op); @@ -468,11 +480,8 @@ pub(crate) mod test { let [int] = func_builder.input_wires_arr(); let cfg_id = { - let mut cfg_builder = func_builder.cfg_builder( - vec![(NAT, int)], - type_row![NAT], - ExtensionSet::new(), - )?; + let mut cfg_builder = + func_builder.cfg_builder(vec![(NAT, int)], type_row![NAT])?; build_basic_cfg(&mut cfg_builder)?; cfg_builder.finish_sub_container()? diff --git a/hugr-core/src/builder/conditional.rs b/hugr-core/src/builder/conditional.rs index e4e56cdc1..4d7fd6bd3 100644 --- a/hugr-core/src/builder/conditional.rs +++ b/hugr-core/src/builder/conditional.rs @@ -1,4 +1,4 @@ -use crate::extension::ExtensionRegistry; +use crate::extension::{ExtensionRegistry, TO_BE_INFERRED}; use crate::hugr::views::HugrView; use crate::ops::dataflow::DataflowOpTrait; use crate::types::{Signature, TypeRow}; @@ -152,12 +152,23 @@ impl HugrBuilder for ConditionalBuilder { } impl ConditionalBuilder { - /// Initialize a Conditional rooted HUGR builder + /// Initialize a Conditional rooted HUGR builder, extension delta will be inferred. pub fn new( sum_rows: impl IntoIterator, other_inputs: impl Into, outputs: impl Into, - extension_delta: ExtensionSet, + ) -> Result { + Self::new_exts(sum_rows, other_inputs, outputs, TO_BE_INFERRED) + } + + /// Initialize a Conditional rooted HUGR builder, + /// `extension_delta` explicitly specified. Alternatively, + /// [new](Self::new) may be used to infer it. + pub fn new_exts( + sum_rows: impl IntoIterator, + other_inputs: impl Into, + outputs: impl Into, + extension_delta: impl Into, ) -> Result { let sum_rows: Vec<_> = sum_rows.into_iter().collect(); let other_inputs = other_inputs.into(); @@ -170,7 +181,7 @@ impl ConditionalBuilder { sum_rows, other_inputs, outputs, - extension_delta, + extension_delta: extension_delta.into(), }; let base = Hugr::new(op); let conditional_node = base.root(); @@ -216,7 +227,7 @@ mod test { #[test] fn basic_conditional() -> Result<(), BuildError> { - let mut conditional_b = ConditionalBuilder::new( + let mut conditional_b = ConditionalBuilder::new_exts( [type_row![], type_row![]], type_row![NAT], type_row![NAT], @@ -265,12 +276,8 @@ mod test { #[test] fn test_not_all_cases() -> Result<(), BuildError> { - let mut builder = ConditionalBuilder::new( - [type_row![], type_row![]], - type_row![], - type_row![], - ExtensionSet::new(), - )?; + let mut builder = + ConditionalBuilder::new([type_row![], type_row![]], type_row![], type_row![])?; n_identity(builder.case_builder(0)?)?; assert_matches!( builder.finish_sub_container().map(|_| ()), @@ -283,12 +290,8 @@ mod test { #[test] fn test_case_already_built() -> Result<(), BuildError> { - let mut builder = ConditionalBuilder::new( - [type_row![], type_row![]], - type_row![], - type_row![], - ExtensionSet::new(), - )?; + let mut builder = + ConditionalBuilder::new([type_row![], type_row![]], type_row![], type_row![])?; n_identity(builder.case_builder(0)?)?; assert_matches!( builder.case_builder(0).map(|_| ()), diff --git a/hugr-core/src/builder/tail_loop.rs b/hugr-core/src/builder/tail_loop.rs index 7d713cc2c..577960354 100644 --- a/hugr-core/src/builder/tail_loop.rs +++ b/hugr-core/src/builder/tail_loop.rs @@ -1,4 +1,4 @@ -use crate::extension::ExtensionSet; +use crate::extension::{ExtensionSet, TO_BE_INFERRED}; use crate::ops::{self, DataflowOpTrait}; use crate::hugr::views::HugrView; @@ -71,18 +71,30 @@ impl + AsRef> TailLoopBuilder { } impl TailLoopBuilder { - /// Initialize new builder for a [`ops::TailLoop`] rooted HUGR + /// Initialize new builder for a [`ops::TailLoop`] rooted HUGR. + /// Extension delta will be inferred. pub fn new( just_inputs: impl Into, inputs_outputs: impl Into, just_outputs: impl Into, - extension_delta: ExtensionSet, + ) -> Result { + Self::new_exts(just_inputs, inputs_outputs, just_outputs, TO_BE_INFERRED) + } + + /// Initialize new builder for a [`ops::TailLoop`] rooted HUGR. + /// `extension_delta` is explicitly specified; alternatively, [new](Self::new) + /// may be used to infer it. + pub fn new_exts( + just_inputs: impl Into, + inputs_outputs: impl Into, + just_outputs: impl Into, + extension_delta: impl Into, ) -> Result { let tail_loop = ops::TailLoop { just_inputs: just_inputs.into(), just_outputs: just_outputs.into(), rest: inputs_outputs.into(), - extension_delta, + extension_delta: extension_delta.into(), }; let base = Hugr::new(tail_loop.clone()); let root = base.root(); @@ -111,7 +123,7 @@ mod test { fn basic_loop() -> Result<(), BuildError> { let build_result: Result = { let mut loop_b = - TailLoopBuilder::new(vec![], vec![BIT], vec![USIZE_T], PRELUDE_ID.into())?; + TailLoopBuilder::new_exts(vec![], vec![BIT], vec![USIZE_T], PRELUDE_ID)?; let [i1] = loop_b.input_wires_arr(); let const_wire = loop_b.add_load_value(ConstUsize::new(1)); @@ -143,12 +155,8 @@ mod test { )? .outputs_arr(); let loop_id = { - let mut loop_b = fbuild.tail_loop_builder( - vec![(BIT, b1)], - vec![], - type_row![NAT], - PRELUDE_ID.into(), - )?; + let mut loop_b = + fbuild.tail_loop_builder(vec![(BIT, b1)], vec![], type_row![NAT])?; let signature = loop_b.loop_signature()?.clone(); let const_val = Value::true_val(); let const_wire = loop_b.add_load_const(Value::true_val()); @@ -199,9 +207,7 @@ mod test { #[test] // fixed: issue 1257: When building a TailLoop, calling outputs_arr, you are given an OrderEdge "output wire" fn tailloop_output_arr() { - let mut builder = - TailLoopBuilder::new(type_row![], type_row![], type_row![], ExtensionSet::new()) - .unwrap(); + let mut builder = TailLoopBuilder::new(type_row![], type_row![], type_row![]).unwrap(); let control = builder.add_load_value(Value::false_val()); let tailloop = builder.finish_with_outputs(control, []).unwrap(); let [] = tailloop.outputs_arr(); diff --git a/hugr-core/src/hugr/rewrite/outline_cfg.rs b/hugr-core/src/hugr/rewrite/outline_cfg.rs index 1543be973..84f9ed730 100644 --- a/hugr-core/src/hugr/rewrite/outline_cfg.rs +++ b/hugr-core/src/hugr/rewrite/outline_cfg.rs @@ -125,7 +125,7 @@ impl Rewrite for OutlineCfg { // 2. new_block contains input node, sub-cfg, exit node all connected let (new_block, cfg_node) = { - let mut new_block_bldr = BlockBuilder::new( + let mut new_block_bldr = BlockBuilder::new_exts( inputs.clone(), vec![type_row![]], outputs.clone(), @@ -134,7 +134,7 @@ impl Rewrite for OutlineCfg { .unwrap(); let wires_in = inputs.iter().cloned().zip(new_block_bldr.input_wires()); let cfg = new_block_bldr - .cfg_builder(wires_in, outputs, extension_delta) + .cfg_builder_exts(wires_in, outputs, extension_delta) .unwrap(); let cfg = cfg.finish_sub_container().unwrap(); let unit_sum = new_block_bldr.add_constant(ops::Value::unary_unit_sum()); diff --git a/hugr-core/src/hugr/validate/test.rs b/hugr-core/src/hugr/validate/test.rs index 4252b8cd9..c1d8785ff 100644 --- a/hugr-core/src/hugr/validate/test.rs +++ b/hugr-core/src/hugr/validate/test.rs @@ -1166,11 +1166,11 @@ mod extension_tests { } fn make_bb(t: Type, es: ExtensionSet) -> DFGWrapper { - BlockBuilder::new(t.clone(), vec![t.into()], type_row![], es).unwrap() + BlockBuilder::new_exts(t.clone(), vec![t.into()], type_row![], es).unwrap() } fn make_tailloop(t: Type, es: ExtensionSet) -> DFGWrapper> { let row = TypeRow::from(t); - TailLoopBuilder::new(row.clone(), type_row![], row, es).unwrap() + TailLoopBuilder::new_exts(row.clone(), type_row![], row, es).unwrap() } }