Skip to content

Commit

Permalink
refactor!: rename predicate to TupleSum/UnitSum (#557)
Browse files Browse the repository at this point in the history
Closes #448

BREAKING CHANGE: previous uses of type/value constructor methods for
"predicates" will need updating
  • Loading branch information
ss2165 authored Oct 23, 2023
1 parent 527cce5 commit bb446e2
Show file tree
Hide file tree
Showing 18 changed files with 171 additions and 192 deletions.
27 changes: 13 additions & 14 deletions specification/hugr.md
Original file line number Diff line number Diff line change
Expand Up @@ -335,18 +335,18 @@ express control flow, i.e. conditional or repeated evaluation.
##### `Conditional` nodes

These are parents to multiple `Case` nodes; the children have no edges.
The first input to the Conditional-node is of Predicate type (see below), whose
The first input to the Conditional-node is of TupleSum type (see below), whose
arity matches the number of children of the Conditional-node. At runtime
the constructor (tag) selects which child to execute; the unpacked
contents of the Predicate with all remaining inputs to Conditional
contents of the TupleSum with all remaining inputs to Conditional
appended are sent to this child, and all outputs of the child are the
outputs of the Conditional; that child is evaluated, but the others are
not. That is, Conditional-nodes act as "if-then-else" followed by a
control-flow merge.

A **Predicate(T0, T1…TN)** type is an algebraic “sum of products” type,
defined as `Sum(Tuple(#t0), Tuple(#t1), ...Tuple(#tn))` (see [type
system](#type-system)), where `#ti` is the *i*th Row defining it.
A **TupleSum(T0, T1…TN)** type is an algebraic “sum of products” type,
defined as `Sum(Tuple(#T0), Tuple(#T1), ...Tuple(#Tn))` (see [type
system](#type-system)), where `#Ti` is the *i*th Row defining it.

```mermaid
flowchart
Expand All @@ -362,7 +362,7 @@ flowchart
end
Case0 ~~~ Case1
end
Pred["case 0 inputs | case 1 inputs"] --> Conditional
TupleSum["case 0 inputs | case 1 inputs"] --> Conditional
OI["other inputs"] --> Conditional
Conditional --> outputs
```
Expand All @@ -371,13 +371,13 @@ flowchart

These provide tail-controlled loops. The dataflow sibling graph within the
TailLoop-node defines the loop body: this computes a row of outputs, whose
first element has type `Predicate(#I, #O)` and the remainder is a row `#X`
first element has type `TupleSum(#I, #O)` and the remainder is a row `#X`
(perhaps empty). Inputs to the contained graph and to the TailLoop node itself
are the row `#I:#X`, where `:` indicates row concatenation (with the tuple
inside the `Predicate` unpacked).
inside the `TupleSum` unpacked).

Evaluation of the node begins by feeding the node inputs into the child graph
and evaluating it. The `Predicate` produced controls iteration of the loop:
and evaluating it. The `TupleSum` produced controls iteration of the loop:
* The first variant (`#I`) means that these values, along with the other
sibling-graph outputs `#X`, are fed back into the top of the loop,
and the body is evaluated again (thus perhaps many times)
Expand Down Expand Up @@ -405,7 +405,7 @@ The first child is the entry block and must be a `DFB`, with inputs the same as
The remaining children are either `DFB`s or [scoped definitions](#scoped-definitions).

The first output of the DSG contained in a `BasicBlock` has type
`Predicate(#t0,...#t(n-1))`, where the node has `n` successors, and the
`TupleSum(#t0,...#t(n-1))`, where the node has `n` successors, and the
remaining outputs are a row `#x`. `#ti` with `#x` appended matches the
inputs of successor `i`.

Expand All @@ -431,7 +431,7 @@ output of each of these is a sum type, whose arity is the number of outgoing
control edges; the remaining outputs are those that are passed to all
succeeding nodes.

The three nodes labelled "Const" are simply generating a predicate with one empty
The three nodes labelled "Const" are simply generating a TupleSum with one empty
value to pass to the Output node.

```mermaid
Expand Down Expand Up @@ -1125,7 +1125,6 @@ run, which removes the `HigherOrder` extension requirement:
```
precompute :: Function[](Function[Quantum,HigherOrder](Array(5, Qubit), (ms: Array(5, Qubit), results: Array(5, Bit))),
Function[Quantum](Array(5, Qubit), (ms: Array(5, Qubit), results: Array(5, Bit))))
>>>>>>> c6abd39 ([doc] Tidy hugr specification)
```
Before we can run the circuit.
Expand Down Expand Up @@ -1391,8 +1390,8 @@ use an empty node in the replacement and have B map this node to the old
one.
We can, for example, implement “turning a Conditional-node with known
predicate into a DFG-node” by a `Replace` where the Conditional (and its
preceding predicate) is replaced by an empty DFG and the map B specifies
TupleSum into a DFG-node” by a `Replace` where the Conditional (and its
preceding TupleSum) is replaced by an empty DFG and the map B specifies
the “good” child of the Conditional as the surrogate parent of the new
DFG’s children. (If the good child was just an Op, we could either
remove it and include it in the replacement, or – to avoid this overhead
Expand Down
18 changes: 6 additions & 12 deletions src/algorithm/nest_cfgs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -605,10 +605,8 @@ pub(crate) mod test {
// \-> right -/ \-<--<-/
let mut cfg_builder = CFGBuilder::new(FunctionType::new(type_row![NAT], type_row![NAT]))?;

let pred_const =
cfg_builder.add_constant(Const::simple_predicate(0, 2), ExtensionSet::new())?; // Nothing here cares which branch
let const_unit =
cfg_builder.add_constant(Const::simple_unary_predicate(), ExtensionSet::new())?;
let pred_const = cfg_builder.add_constant(Const::unit_sum(0, 2), ExtensionSet::new())?; // Nothing here cares which
let const_unit = cfg_builder.add_constant(Const::unary_unit_sum(), ExtensionSet::new())?;

let entry = n_identity(
cfg_builder.simple_entry_builder(type_row![NAT], 1, ExtensionSet::new())?,
Expand Down Expand Up @@ -889,10 +887,8 @@ pub(crate) mod test {
separate: bool,
) -> Result<(Hugr, BasicBlockID, BasicBlockID), BuildError> {
let mut cfg_builder = CFGBuilder::new(FunctionType::new(type_row![NAT], type_row![NAT]))?;
let pred_const =
cfg_builder.add_constant(Const::simple_predicate(0, 2), ExtensionSet::new())?; // Nothing here cares which branch
let const_unit =
cfg_builder.add_constant(Const::simple_unary_predicate(), ExtensionSet::new())?;
let pred_const = cfg_builder.add_constant(Const::unit_sum(0, 2), ExtensionSet::new())?; // Nothing here cares which
let const_unit = cfg_builder.add_constant(Const::unary_unit_sum(), ExtensionSet::new())?;

let entry = n_identity(
cfg_builder.simple_entry_builder(type_row![NAT], 2, ExtensionSet::new())?,
Expand Down Expand Up @@ -933,10 +929,8 @@ pub(crate) mod test {
cfg_builder: &mut CFGBuilder<T>,
separate_headers: bool,
) -> Result<(BasicBlockID, BasicBlockID), BuildError> {
let pred_const =
cfg_builder.add_constant(Const::simple_predicate(0, 2), ExtensionSet::new())?; // Nothing here cares which branch
let const_unit =
cfg_builder.add_constant(Const::simple_unary_predicate(), ExtensionSet::new())?;
let pred_const = cfg_builder.add_constant(Const::unit_sum(0, 2), ExtensionSet::new())?; // Nothing here cares which
let const_unit = cfg_builder.add_constant(Const::unary_unit_sum(), ExtensionSet::new())?;

let entry = n_identity(
cfg_builder.simple_entry_builder(type_row![NAT], 1, ExtensionSet::new())?,
Expand Down
26 changes: 13 additions & 13 deletions src/builder/build_traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -421,8 +421,8 @@ pub trait Dataflow: Container {
}

/// Return a builder for a [`crate::ops::Conditional`] node.
/// `predicate_inputs` and `predicate_wire` define the type of the predicate
/// variants and the wire carrying the predicate respectively.
/// `tuple_sum_rows` and `tuple_sum_wire` define the type of the TupleSum
/// variants and the wire carrying the TupleSum respectively.
///
/// The `other_inputs` must be an iterable over pairs of the type of the input and
/// the corresponding wire.
Expand All @@ -434,24 +434,24 @@ pub trait Dataflow: Container {
/// the Conditional node.
fn conditional_builder(
&mut self,
(predicate_inputs, predicate_wire): (impl IntoIterator<Item = TypeRow>, Wire),
(tuple_sum_rows, tuple_sum_wire): (impl IntoIterator<Item = TypeRow>, Wire),
other_inputs: impl IntoIterator<Item = (Type, Wire)>,
output_types: TypeRow,
extension_delta: ExtensionSet,
) -> Result<ConditionalBuilder<&mut Hugr>, BuildError> {
let mut input_wires = vec![predicate_wire];
let mut input_wires = vec![tuple_sum_wire];
let (input_types, rest_input_wires): (Vec<Type>, Vec<Wire>) =
other_inputs.into_iter().unzip();

input_wires.extend(rest_input_wires);
let inputs: TypeRow = input_types.into();
let predicate_inputs: Vec<_> = predicate_inputs.into_iter().collect();
let n_cases = predicate_inputs.len();
let tuple_sum_rows: Vec<_> = tuple_sum_rows.into_iter().collect();
let n_cases = tuple_sum_rows.len();
let n_out_wires = output_types.len();

let conditional_id = self.add_dataflow_op(
ops::Conditional {
predicate_inputs,
tuple_sum_rows,
other_inputs: inputs,
outputs: output_types,
extension_delta,
Expand Down Expand Up @@ -534,15 +534,15 @@ pub trait Dataflow: Container {
}

/// Add [`LeafOp::MakeTuple`] and [`LeafOp::Tag`] nodes to construct the
/// `tag` variant of a predicate (sum-of-tuples) type.
fn make_predicate(
/// `tag` variant of a TupleSum type.
fn make_tuple_sum(
&mut self,
tag: usize,
predicate_variants: impl IntoIterator<Item = TypeRow>,
tuple_sum_rows: impl IntoIterator<Item = TypeRow>,
values: impl IntoIterator<Item = Wire>,
) -> Result<Wire, BuildError> {
let tuple = self.make_tuple(values)?;
let variants = crate::types::predicate_variants_row(predicate_variants);
let variants = crate::types::tuple_sum_row(tuple_sum_rows);
let make_op = self.add_dataflow_op(LeafOp::Tag { tag, variants }, vec![tuple])?;
Ok(make_op.out_wire(0))
}
Expand All @@ -561,7 +561,7 @@ pub trait Dataflow: Container {
tail_loop: ops::TailLoop,
values: impl IntoIterator<Item = Wire>,
) -> Result<Wire, BuildError> {
self.make_predicate(0, [tail_loop.just_inputs, tail_loop.just_outputs], values)
self.make_tuple_sum(0, [tail_loop.just_inputs, tail_loop.just_outputs], values)
}

/// Use the wires in `values` to return a wire corresponding to the
Expand All @@ -578,7 +578,7 @@ pub trait Dataflow: Container {
loop_op: ops::TailLoop,
values: impl IntoIterator<Item = Wire>,
) -> Result<Wire, BuildError> {
self.make_predicate(1, [loop_op.just_inputs, loop_op.just_outputs], values)
self.make_tuple_sum(1, [loop_op.just_inputs, loop_op.just_outputs], values)
}

/// Add a [`ops::Call`] node, calling `function`, with inputs
Expand Down
58 changes: 26 additions & 32 deletions src/builder/cfg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,22 +103,22 @@ impl<B: AsMut<Hugr> + AsRef<Hugr>> CFGBuilder<B> {
}

/// Return a builder for a non-entry [`BasicBlock::DFB`] child graph with `inputs`
/// and `outputs` and the variants of the branching predicate Sum value
/// specified by `predicate_variants`.
/// and `outputs` and the variants of the branching TupleSum value
/// specified by `tuple_sum_rows`.
///
/// # Errors
///
/// This function will return an error if there is an error adding the node.
pub fn block_builder(
&mut self,
inputs: TypeRow,
predicate_variants: Vec<TypeRow>,
tuple_sum_rows: impl IntoIterator<Item = TypeRow>,
extension_delta: ExtensionSet,
other_outputs: TypeRow,
) -> Result<BlockBuilder<&mut Hugr>, BuildError> {
self.any_block_builder(
inputs,
predicate_variants,
tuple_sum_rows,
other_outputs,
extension_delta,
false,
Expand All @@ -128,15 +128,16 @@ impl<B: AsMut<Hugr> + AsRef<Hugr>> CFGBuilder<B> {
fn any_block_builder(
&mut self,
inputs: TypeRow,
predicate_variants: Vec<TypeRow>,
tuple_sum_rows: impl IntoIterator<Item = TypeRow>,
other_outputs: TypeRow,
extension_delta: ExtensionSet,
entry: bool,
) -> Result<BlockBuilder<&mut Hugr>, BuildError> {
let tuple_sum_rows: Vec<_> = tuple_sum_rows.into_iter().collect();
let op = OpType::BasicBlock(BasicBlock::DFB {
inputs: inputs.clone(),
other_outputs: other_outputs.clone(),
predicate_variants: predicate_variants.clone(),
tuple_sum_rows: tuple_sum_rows.clone(),
extension_delta,
});
let parent = self.container_node();
Expand All @@ -152,14 +153,14 @@ impl<B: AsMut<Hugr> + AsRef<Hugr>> CFGBuilder<B> {
BlockBuilder::create(
self.hugr_mut(),
block_n,
predicate_variants,
tuple_sum_rows,
other_outputs,
inputs,
)
}

/// Return a builder for a non-entry [`BasicBlock::DFB`] child graph with `inputs`
/// and `outputs` and a simple predicate type: a Sum of `n_cases` unit types.
/// and `outputs` and a UnitSum type: a Sum of `n_cases` unit types.
///
/// # Errors
///
Expand All @@ -178,33 +179,27 @@ impl<B: AsMut<Hugr> + AsRef<Hugr>> CFGBuilder<B> {
}

/// Return a builder for the entry [`BasicBlock::DFB`] child graph with `inputs`
/// and `outputs` and the variants of the branching predicate Sum value
/// specified by `predicate_variants`.
/// and `outputs` and the variants of the branching TupleSum value
/// specified by `tuple_sum_rows`.
///
/// # Errors
///
/// This function will return an error if an entry block has already been built.
pub fn entry_builder(
&mut self,
predicate_variants: Vec<TypeRow>,
tuple_sum_rows: impl IntoIterator<Item = TypeRow>,
other_outputs: TypeRow,
extension_delta: ExtensionSet,
) -> Result<BlockBuilder<&mut Hugr>, BuildError> {
let inputs = self
.inputs
.take()
.ok_or(BuildError::EntryBuiltError(self.cfg_node))?;
self.any_block_builder(
inputs,
predicate_variants,
other_outputs,
extension_delta,
true,
)
self.any_block_builder(inputs, tuple_sum_rows, other_outputs, extension_delta, true)
}

/// Return a builder for the entry [`BasicBlock::DFB`] child graph with `inputs`
/// and `outputs` and a simple predicate type: a Sum of `n_cases` unit types.
/// and `outputs` and a UnitSum type: a Sum of `n_cases` unit types.
///
/// # Errors
///
Expand Down Expand Up @@ -244,8 +239,8 @@ impl<B: AsMut<Hugr> + AsRef<Hugr>> CFGBuilder<B> {
pub type BlockBuilder<B> = DFGWrapper<B, BasicBlockID>;

impl<B: AsMut<Hugr> + AsRef<Hugr>> BlockBuilder<B> {
/// Set the outputs of the block, with `branch_wire` being the value of the
/// predicate. `outputs` are the remaining outputs.
/// Set the outputs of the block, with `branch_wire` carrying the value of the
/// branch controlling TupleSum value. `outputs` are the remaining outputs.
pub fn set_outputs(
&mut self,
branch_wire: Wire,
Expand All @@ -256,13 +251,13 @@ impl<B: AsMut<Hugr> + AsRef<Hugr>> BlockBuilder<B> {
fn create(
base: B,
block_n: Node,
predicate_variants: Vec<TypeRow>,
tuple_sum_rows: impl IntoIterator<Item = TypeRow>,
other_outputs: TypeRow,
inputs: TypeRow,
) -> Result<Self, BuildError> {
// The node outputs a predicate before the data outputs of the block node
let predicate_type = Type::new_predicate(predicate_variants);
let mut node_outputs = vec![predicate_type];
// The node outputs a TupleSum before the data outputs of the block node
let tuple_sum_type = Type::new_tuple_sum(tuple_sum_rows);
let mut node_outputs = vec![tuple_sum_type];
node_outputs.extend_from_slice(&other_outputs);
let signature = FunctionType::new(inputs, TypeRow::from(node_outputs));
let inp_ex = base
Expand Down Expand Up @@ -293,23 +288,23 @@ impl BlockBuilder<Hugr> {
pub fn new(
inputs: impl Into<TypeRow>,
input_extensions: impl Into<Option<ExtensionSet>>,
predicate_variants: impl IntoIterator<Item = TypeRow>,
tuple_sum_rows: impl IntoIterator<Item = TypeRow>,
other_outputs: impl Into<TypeRow>,
extension_delta: ExtensionSet,
) -> Result<Self, BuildError> {
let inputs = inputs.into();
let predicate_variants: Vec<_> = predicate_variants.into_iter().collect();
let tuple_sum_rows: Vec<_> = tuple_sum_rows.into_iter().collect();
let other_outputs = other_outputs.into();
let op = BasicBlock::DFB {
inputs: inputs.clone(),
other_outputs: other_outputs.clone(),
predicate_variants: predicate_variants.clone(),
tuple_sum_rows: tuple_sum_rows.clone(),
extension_delta,
};

let base = Hugr::new(NodeType::new(op, input_extensions));
let root = base.root();
Self::create(base, root, predicate_variants, other_outputs, inputs)
Self::create(base, root, tuple_sum_rows, other_outputs, inputs)
}

/// [Set outputs](BlockBuilder::set_outputs) and [finish_hugr](`BlockBuilder::finish_hugr`).
Expand Down Expand Up @@ -382,14 +377,13 @@ mod test {
let entry = {
let [inw] = entry_b.input_wires_arr();

let sum = entry_b.make_predicate(1, sum2_variants, [inw])?;
let sum = entry_b.make_tuple_sum(1, sum2_variants, [inw])?;
entry_b.finish_with_outputs(sum, [])?
};
let mut middle_b = cfg_builder
.simple_block_builder(FunctionType::new(type_row![NAT], type_row![NAT]), 1)?;
let middle = {
let c = middle_b
.add_load_const(ops::Const::simple_unary_predicate(), ExtensionSet::new())?;
let c = middle_b.add_load_const(ops::Const::unary_unit_sum(), ExtensionSet::new())?;
let [inw] = middle_b.input_wires_arr();
middle_b.finish_with_outputs(c, [inw])?
};
Expand Down
Loading

0 comments on commit bb446e2

Please sign in to comment.