Skip to content

Commit

Permalink
Conditional builder (_exts)
Browse files Browse the repository at this point in the history
  • Loading branch information
acl-cqc committed Jun 25, 2024
1 parent 4195736 commit cf99323
Show file tree
Hide file tree
Showing 6 changed files with 26 additions and 9 deletions.
27 changes: 24 additions & 3 deletions hugr-core/src/builder/build_traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,27 @@ pub trait Dataflow: Container {
TailLoopBuilder::create_with_io(self.hugr_mut(), loop_node, &tail_loop)
}

/// Return a builder for a [`crate::ops::Conditional`] node.
/// `sum_input` is a tuple of the type of the Sum
/// variants and the corresponding wire.
///
/// The `other_inputs` must be an iterable over pairs of the type of the input and
/// the corresponding wire.
/// The `outputs` are the types of the outputs. Extension delta will be inferred.
///
/// # Errors
///
/// This function will return an error if there is an error when building
/// the Conditional node.
fn conditional_builder(
&mut self,
sum_input: (impl IntoIterator<Item = TypeRow>, Wire),
other_inputs: impl IntoIterator<Item = (Type, Wire)>,
output_types: TypeRow,
) -> Result<ConditionalBuilder<&mut Hugr>, BuildError> {
self.conditional_builder_exts(sum_input, other_inputs, TO_BE_INFERRED, output_types)
}

/// Return a builder for a [`crate::ops::Conditional`] node.
/// `sum_rows` and `sum_wire` define the type of the Sum
/// variants and the wire carrying the Sum respectively.
Expand All @@ -449,12 +470,12 @@ pub trait Dataflow: Container {
///
/// This function will return an error if there is an error when building
/// the Conditional node.
fn conditional_builder(
fn conditional_builder_exts(
&mut self,
(sum_rows, sum_wire): (impl IntoIterator<Item = TypeRow>, Wire),
other_inputs: impl IntoIterator<Item = (Type, Wire)>,
extension_delta: impl Into<ExtensionSet>,
output_types: TypeRow,
extension_delta: ExtensionSet,
) -> Result<ConditionalBuilder<&mut Hugr>, BuildError> {
let mut input_wires = vec![sum_wire];
let (input_types, rest_input_wires): (Vec<Type>, Vec<Wire>) =
Expand All @@ -471,7 +492,7 @@ pub trait Dataflow: Container {
sum_rows,
other_inputs: inputs,
outputs: output_types,
extension_delta,
extension_delta: extension_delta.into(),
},
input_wires,
)?;
Expand Down
1 change: 0 additions & 1 deletion hugr-core/src/builder/conditional.rs
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,6 @@ mod test {
([type_row![], type_row![]], const_wire),
other_inputs,
outputs,
ExtensionSet::new(),
)?;

n_identity(conditional_b.case_builder(0)?)?;
Expand Down
1 change: 0 additions & 1 deletion hugr-core/src/builder/tail_loop.rs
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,6 @@ mod test {
([type_row![], type_row![]], const_wire),
vec![(BIT, b1)],
output_row,
PRELUDE_ID.into(),
)?;

let mut branch_0 = conditional_b.case_builder(0)?;
Expand Down
4 changes: 2 additions & 2 deletions hugr-core/src/hugr/rewrite/inline_dfg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -352,11 +352,11 @@ mod test {
.add_dataflow_op(test_quantum_extension::measure(), r.outputs())?
.outputs_arr();
// Node using the boolean. Here we just select between two empty computations.
let mut if_n = inner.conditional_builder(
let mut if_n = inner.conditional_builder_exts(
([type_row![], type_row![]], b),
[],
type_row![],
ExtensionSet::new(),
type_row![],
)?;
if_n.case_builder(0)?.finish_with_outputs([])?;
if_n.case_builder(1)?.finish_with_outputs([])?;
Expand Down
1 change: 0 additions & 1 deletion hugr-core/src/hugr/rewrite/replace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -665,7 +665,6 @@ mod test {
(vec![type_row![]; 2], b),
[(USIZE_T, i)],
type_row![USIZE_T],
ExtensionSet::new(),
)?;
let mut case1 = cond.case_builder(0)?;
let foo = case1.add_dataflow_op(mk_op("foo"), case1.input_wires())?;
Expand Down
1 change: 0 additions & 1 deletion hugr-core/src/hugr/validate/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -789,7 +789,6 @@ fn test_polymorphic_call() -> Result<(), Box<dyn std::error::Error>> {
(vec![type_row![USIZE_T; 2]], tup),
vec![],
type_row![USIZE_T;2],
TO_BE_INFERRED.into(), //TODO make default //es.clone().union(EXT_ID.into()),
)?;
let mut cc = c.case_builder(0)?;
let [i1, i2] = cc.input_wires_arr();
Expand Down

0 comments on commit cf99323

Please sign in to comment.