Skip to content

Commit

Permalink
feat!: Update remaining builder methods to "infer by default" (#1386)
Browse files Browse the repository at this point in the history
closes #1318 and also deals with BlockBuilder not mentioned there.
I think this is now all of the nested-structures covered:
| | XBuilder::new | fn x |
|----|----|----|
|Conditional|**here**|in #1226|
|\->Case|takes `Signature`|inherits exts|
|TailLoop|**here**|**here**|
|DFG|takes `Signature`|takes `Signature` (`dfg_builder_endo` in #1219)|
|CFG|takes `Signature`|**here**|
|\->Block|**here**|in #1226|

(FuncDefn takes `Signature` and is *not supported by inference yet*
anyway)

BREAKING CHANGE: `cfg_builder`, `tail_loop_builder`,
`ConditionalBuilder::new`, `BlockBuilder::new` and
`TailLoopBuilder::new` no longer take an ExtensionSet parameter; either
remove the argument (to use extension inference) or use the `_exts`
variant
  • Loading branch information
acl-cqc authored Aug 2, 2024
1 parent 68cfac5 commit b75dd09
Show file tree
Hide file tree
Showing 6 changed files with 108 additions and 47 deletions.
51 changes: 47 additions & 4 deletions hugr-core/src/builder/build_traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,7 @@ pub trait Dataflow: Container {
/// The `inputs` must be an iterable over pairs of the type of the input and
/// the corresponding wire.
/// The `output_types` are the types of the outputs.
/// The Extension delta will be inferred.
///
/// # Errors
///
Expand All @@ -314,7 +315,27 @@ pub trait Dataflow: Container {
&mut self,
inputs: impl IntoIterator<Item = (Type, Wire)>,
output_types: TypeRow,
extension_delta: ExtensionSet,
) -> Result<CFGBuilder<&mut Hugr>, BuildError> {
self.cfg_builder_exts(inputs, output_types, TO_BE_INFERRED)
}

/// Return a builder for a [`crate::ops::CFG`] node,
/// i.e. a nested controlflow subgraph.
/// The `inputs` must be an iterable over pairs of the type of the input and
/// the corresponding wire.
/// The `output_types` are the types of the outputs.
/// `extension_delta` is explicitly specified. Alternatively
/// [cfg_builder](Self::cfg_builder) may be used to infer it.
///
/// # Errors
///
/// This function will return an error if there is an error when building
/// the CFG node.
fn cfg_builder_exts(
&mut self,
inputs: impl IntoIterator<Item = (Type, Wire)>,
output_types: TypeRow,
extension_delta: impl Into<ExtensionSet>,
) -> Result<CFGBuilder<&mut Hugr>, BuildError> {
let (input_types, input_wires): (Vec<Type>, Vec<Wire>) = inputs.into_iter().unzip();

Expand Down Expand Up @@ -405,17 +426,39 @@ pub trait Dataflow: Container {
/// The `inputs` must be an iterable over pairs of the type of the input and
/// the corresponding wire.
/// The `output_types` are the types of the outputs.
/// The extension delta will be inferred.
///
/// # Errors
///
/// This function will return an error if there is an error when building
/// the [`ops::TailLoop`] node.
///
fn tail_loop_builder(
&mut self,
just_inputs: impl IntoIterator<Item = (Type, Wire)>,
inputs_outputs: impl IntoIterator<Item = (Type, Wire)>,
just_out_types: TypeRow,
extension_delta: ExtensionSet,
) -> Result<TailLoopBuilder<&mut Hugr>, BuildError> {
self.tail_loop_builder_exts(just_inputs, inputs_outputs, just_out_types, TO_BE_INFERRED)
}

/// Return a builder for a [`crate::ops::TailLoop`] node.
/// The `inputs` must be an iterable over pairs of the type of the input and
/// the corresponding wire.
/// The `output_types` are the types of the outputs.
/// `extension_delta` explicitly specified. Alternatively
/// [tail_loop_builder](Self::tail_loop_builder) may be used to infer it.
///
/// # Errors
///
/// This function will return an error if there is an error when building
/// the [`ops::TailLoop`] node.
fn tail_loop_builder_exts(
&mut self,
just_inputs: impl IntoIterator<Item = (Type, Wire)>,
inputs_outputs: impl IntoIterator<Item = (Type, Wire)>,
just_out_types: TypeRow,
extension_delta: impl Into<ExtensionSet>,
) -> Result<TailLoopBuilder<&mut Hugr>, BuildError> {
let (input_types, mut input_wires): (Vec<Type>, Vec<Wire>) =
just_inputs.into_iter().unzip();
Expand All @@ -427,7 +470,7 @@ pub trait Dataflow: Container {
just_inputs: input_types.into(),
just_outputs: just_out_types,
rest: rest_types.into(),
extension_delta,
extension_delta: extension_delta.into(),
};
// TODO: Make input extensions a parameter
let (loop_node, _) = add_node_with_wires(self, tail_loop.clone(), input_wires)?;
Expand Down Expand Up @@ -463,7 +506,7 @@ pub trait Dataflow: Container {
/// The `other_inputs` must be an iterable over pairs of the type of the input and
/// the corresponding wire.
/// The `output_types` are the types of the outputs.
/// `exts` explicitly specifies the extension delta. Alternatively
/// `extension_delta` is explicitly specified. Alternatively
/// [conditional_builder](Self::conditional_builder) may be used to infer it.
///
/// # Errors
Expand Down
25 changes: 17 additions & 8 deletions hugr-core/src/builder/cfg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -413,12 +413,24 @@ impl<B: AsMut<Hugr> + AsRef<Hugr>> BlockBuilder<B> {
}

impl BlockBuilder<Hugr> {
/// Initialize a [`DataflowBlock`] rooted HUGR builder
/// Initialize a [`DataflowBlock`] rooted HUGR builder.
/// Extension delta will be inferred.
pub fn new(
inputs: impl Into<TypeRow>,
sum_rows: impl IntoIterator<Item = TypeRow>,
other_outputs: impl Into<TypeRow>,
extension_delta: ExtensionSet,
) -> Result<Self, BuildError> {
Self::new_exts(inputs, sum_rows, other_outputs, TO_BE_INFERRED)
}

/// Initialize a [`DataflowBlock`] rooted HUGR builder.
/// `extension_delta` is explicitly specified; alternatively, [new](Self::new)
/// may be used to infer it.
pub fn new_exts(
inputs: impl Into<TypeRow>,
sum_rows: impl IntoIterator<Item = TypeRow>,
other_outputs: impl Into<TypeRow>,
extension_delta: impl Into<ExtensionSet>,
) -> Result<Self, BuildError> {
let inputs = inputs.into();
let sum_rows: Vec<_> = sum_rows.into_iter().collect();
Expand All @@ -427,7 +439,7 @@ impl BlockBuilder<Hugr> {
inputs: inputs.clone(),
other_outputs: other_outputs.clone(),
sum_rows,
extension_delta,
extension_delta: extension_delta.into(),
};

let base = Hugr::new(op);
Expand Down Expand Up @@ -468,11 +480,8 @@ pub(crate) mod test {
let [int] = func_builder.input_wires_arr();

let cfg_id = {
let mut cfg_builder = func_builder.cfg_builder(
vec![(NAT, int)],
type_row![NAT],
ExtensionSet::new(),
)?;
let mut cfg_builder =
func_builder.cfg_builder(vec![(NAT, int)], type_row![NAT])?;
build_basic_cfg(&mut cfg_builder)?;

cfg_builder.finish_sub_container()?
Expand Down
37 changes: 20 additions & 17 deletions hugr-core/src/builder/conditional.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::extension::ExtensionRegistry;
use crate::extension::{ExtensionRegistry, TO_BE_INFERRED};
use crate::hugr::views::HugrView;
use crate::ops::dataflow::DataflowOpTrait;
use crate::types::{Signature, TypeRow};
Expand Down Expand Up @@ -152,12 +152,23 @@ impl HugrBuilder for ConditionalBuilder<Hugr> {
}

impl ConditionalBuilder<Hugr> {
/// Initialize a Conditional rooted HUGR builder
/// Initialize a Conditional rooted HUGR builder, extension delta will be inferred.
pub fn new(
sum_rows: impl IntoIterator<Item = TypeRow>,
other_inputs: impl Into<TypeRow>,
outputs: impl Into<TypeRow>,
extension_delta: ExtensionSet,
) -> Result<Self, BuildError> {
Self::new_exts(sum_rows, other_inputs, outputs, TO_BE_INFERRED)
}

/// Initialize a Conditional rooted HUGR builder,
/// `extension_delta` explicitly specified. Alternatively,
/// [new](Self::new) may be used to infer it.
pub fn new_exts(
sum_rows: impl IntoIterator<Item = TypeRow>,
other_inputs: impl Into<TypeRow>,
outputs: impl Into<TypeRow>,
extension_delta: impl Into<ExtensionSet>,
) -> Result<Self, BuildError> {
let sum_rows: Vec<_> = sum_rows.into_iter().collect();
let other_inputs = other_inputs.into();
Expand All @@ -170,7 +181,7 @@ impl ConditionalBuilder<Hugr> {
sum_rows,
other_inputs,
outputs,
extension_delta,
extension_delta: extension_delta.into(),
};
let base = Hugr::new(op);
let conditional_node = base.root();
Expand Down Expand Up @@ -216,7 +227,7 @@ mod test {

#[test]
fn basic_conditional() -> Result<(), BuildError> {
let mut conditional_b = ConditionalBuilder::new(
let mut conditional_b = ConditionalBuilder::new_exts(
[type_row![], type_row![]],
type_row![NAT],
type_row![NAT],
Expand Down Expand Up @@ -265,12 +276,8 @@ mod test {

#[test]
fn test_not_all_cases() -> Result<(), BuildError> {
let mut builder = ConditionalBuilder::new(
[type_row![], type_row![]],
type_row![],
type_row![],
ExtensionSet::new(),
)?;
let mut builder =
ConditionalBuilder::new([type_row![], type_row![]], type_row![], type_row![])?;
n_identity(builder.case_builder(0)?)?;
assert_matches!(
builder.finish_sub_container().map(|_| ()),
Expand All @@ -283,12 +290,8 @@ mod test {

#[test]
fn test_case_already_built() -> Result<(), BuildError> {
let mut builder = ConditionalBuilder::new(
[type_row![], type_row![]],
type_row![],
type_row![],
ExtensionSet::new(),
)?;
let mut builder =
ConditionalBuilder::new([type_row![], type_row![]], type_row![], type_row![])?;
n_identity(builder.case_builder(0)?)?;
assert_matches!(
builder.case_builder(0).map(|_| ()),
Expand Down
34 changes: 20 additions & 14 deletions hugr-core/src/builder/tail_loop.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::extension::ExtensionSet;
use crate::extension::{ExtensionSet, TO_BE_INFERRED};
use crate::ops::{self, DataflowOpTrait};

use crate::hugr::views::HugrView;
Expand Down Expand Up @@ -71,18 +71,30 @@ impl<B: AsMut<Hugr> + AsRef<Hugr>> TailLoopBuilder<B> {
}

impl TailLoopBuilder<Hugr> {
/// Initialize new builder for a [`ops::TailLoop`] rooted HUGR
/// Initialize new builder for a [`ops::TailLoop`] rooted HUGR.
/// Extension delta will be inferred.
pub fn new(
just_inputs: impl Into<TypeRow>,
inputs_outputs: impl Into<TypeRow>,
just_outputs: impl Into<TypeRow>,
extension_delta: ExtensionSet,
) -> Result<Self, BuildError> {
Self::new_exts(just_inputs, inputs_outputs, just_outputs, TO_BE_INFERRED)
}

/// Initialize new builder for a [`ops::TailLoop`] rooted HUGR.
/// `extension_delta` is explicitly specified; alternatively, [new](Self::new)
/// may be used to infer it.
pub fn new_exts(
just_inputs: impl Into<TypeRow>,
inputs_outputs: impl Into<TypeRow>,
just_outputs: impl Into<TypeRow>,
extension_delta: impl Into<ExtensionSet>,
) -> Result<Self, BuildError> {
let tail_loop = ops::TailLoop {
just_inputs: just_inputs.into(),
just_outputs: just_outputs.into(),
rest: inputs_outputs.into(),
extension_delta,
extension_delta: extension_delta.into(),
};
let base = Hugr::new(tail_loop.clone());
let root = base.root();
Expand Down Expand Up @@ -111,7 +123,7 @@ mod test {
fn basic_loop() -> Result<(), BuildError> {
let build_result: Result<Hugr, ValidationError> = {
let mut loop_b =
TailLoopBuilder::new(vec![], vec![BIT], vec![USIZE_T], PRELUDE_ID.into())?;
TailLoopBuilder::new_exts(vec![], vec![BIT], vec![USIZE_T], PRELUDE_ID)?;
let [i1] = loop_b.input_wires_arr();
let const_wire = loop_b.add_load_value(ConstUsize::new(1));

Expand Down Expand Up @@ -143,12 +155,8 @@ mod test {
)?
.outputs_arr();
let loop_id = {
let mut loop_b = fbuild.tail_loop_builder(
vec![(BIT, b1)],
vec![],
type_row![NAT],
PRELUDE_ID.into(),
)?;
let mut loop_b =
fbuild.tail_loop_builder(vec![(BIT, b1)], vec![], type_row![NAT])?;
let signature = loop_b.loop_signature()?.clone();
let const_val = Value::true_val();
let const_wire = loop_b.add_load_const(Value::true_val());
Expand Down Expand Up @@ -199,9 +207,7 @@ mod test {
#[test]
// fixed: issue 1257: When building a TailLoop, calling outputs_arr, you are given an OrderEdge "output wire"
fn tailloop_output_arr() {
let mut builder =
TailLoopBuilder::new(type_row![], type_row![], type_row![], ExtensionSet::new())
.unwrap();
let mut builder = TailLoopBuilder::new(type_row![], type_row![], type_row![]).unwrap();
let control = builder.add_load_value(Value::false_val());
let tailloop = builder.finish_with_outputs(control, []).unwrap();
let [] = tailloop.outputs_arr();
Expand Down
4 changes: 2 additions & 2 deletions hugr-core/src/hugr/rewrite/outline_cfg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ impl Rewrite for OutlineCfg {

// 2. new_block contains input node, sub-cfg, exit node all connected
let (new_block, cfg_node) = {
let mut new_block_bldr = BlockBuilder::new(
let mut new_block_bldr = BlockBuilder::new_exts(
inputs.clone(),
vec![type_row![]],
outputs.clone(),
Expand All @@ -134,7 +134,7 @@ impl Rewrite for OutlineCfg {
.unwrap();
let wires_in = inputs.iter().cloned().zip(new_block_bldr.input_wires());
let cfg = new_block_bldr
.cfg_builder(wires_in, outputs, extension_delta)
.cfg_builder_exts(wires_in, outputs, extension_delta)
.unwrap();
let cfg = cfg.finish_sub_container().unwrap();
let unit_sum = new_block_bldr.add_constant(ops::Value::unary_unit_sum());
Expand Down
4 changes: 2 additions & 2 deletions hugr-core/src/hugr/validate/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1166,11 +1166,11 @@ mod extension_tests {
}

fn make_bb(t: Type, es: ExtensionSet) -> DFGWrapper<Hugr, BasicBlockID> {
BlockBuilder::new(t.clone(), vec![t.into()], type_row![], es).unwrap()
BlockBuilder::new_exts(t.clone(), vec![t.into()], type_row![], es).unwrap()
}

fn make_tailloop(t: Type, es: ExtensionSet) -> DFGWrapper<Hugr, BuildHandle<TailLoopID>> {
let row = TypeRow::from(t);
TailLoopBuilder::new(row.clone(), type_row![], row, es).unwrap()
TailLoopBuilder::new_exts(row.clone(), type_row![], row, es).unwrap()
}
}

0 comments on commit b75dd09

Please sign in to comment.