-
Notifications
You must be signed in to change notification settings - Fork 7
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
fix!: Ops require their own extension #1226
Changes from 17 commits
7832cf0
d6802c3
2c32b1a
5cb15ca
5b05df3
78f7aed
856bfd4
b2c09d5
10bab5d
9e47b98
290bb6a
eca016d
abbfa2a
3994ee9
0da41d8
a238f21
79da288
ab4ef40
1b5eebd
eb6d1d0
a574c6d
aaa1cf3
2c56bf6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -437,6 +437,27 @@ pub trait Dataflow: Container { | |
TailLoopBuilder::create_with_io(self.hugr_mut(), loop_node, &tail_loop) | ||
} | ||
|
||
/// Return a builder for a [`crate::ops::Conditional`] node. | ||
/// `sum_input` is a tuple of the type of the Sum | ||
/// variants and the corresponding wire. | ||
/// | ||
/// The `other_inputs` must be an iterable over pairs of the type of the input and | ||
/// the corresponding wire. | ||
/// The `outputs` are the types of the outputs. Extension delta will be inferred. | ||
/// | ||
/// # Errors | ||
/// | ||
/// This function will return an error if there is an error when building | ||
/// the Conditional node. | ||
fn conditional_builder( | ||
&mut self, | ||
sum_input: (impl IntoIterator<Item = TypeRow>, Wire), | ||
other_inputs: impl IntoIterator<Item = (Type, Wire)>, | ||
output_types: TypeRow, | ||
) -> Result<ConditionalBuilder<&mut Hugr>, BuildError> { | ||
self.conditional_builder_exts(sum_input, other_inputs, TO_BE_INFERRED, output_types) | ||
} | ||
|
||
/// Return a builder for a [`crate::ops::Conditional`] node. | ||
/// `sum_rows` and `sum_wire` define the type of the Sum | ||
/// variants and the wire carrying the Sum respectively. | ||
|
@@ -449,12 +470,12 @@ pub trait Dataflow: Container { | |
/// | ||
/// This function will return an error if there is an error when building | ||
/// the Conditional node. | ||
fn conditional_builder( | ||
fn conditional_builder_exts( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Mention |
||
&mut self, | ||
(sum_rows, sum_wire): (impl IntoIterator<Item = TypeRow>, Wire), | ||
other_inputs: impl IntoIterator<Item = (Type, Wire)>, | ||
extension_delta: impl Into<ExtensionSet>, | ||
output_types: TypeRow, | ||
extension_delta: ExtensionSet, | ||
) -> Result<ConditionalBuilder<&mut Hugr>, BuildError> { | ||
let mut input_wires = vec![sum_wire]; | ||
let (input_types, rest_input_wires): (Vec<Type>, Vec<Wire>) = | ||
|
@@ -471,7 +492,7 @@ pub trait Dataflow: Container { | |
sum_rows, | ||
other_inputs: inputs, | ||
outputs: output_types, | ||
extension_delta, | ||
extension_delta: extension_delta.into(), | ||
}, | ||
input_wires, | ||
)?; | ||
|
Original file line number | Diff line number | Diff line change | ||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -5,7 +5,10 @@ use super::{ | |||||||||||||
BasicBlockID, BuildError, CfgID, Container, Dataflow, HugrBuilder, Wire, | ||||||||||||||
}; | ||||||||||||||
|
||||||||||||||
use crate::ops::{self, handle::NodeHandle, DataflowBlock, DataflowParent, ExitBlock, OpType}; | ||||||||||||||
use crate::{ | ||||||||||||||
extension::TO_BE_INFERRED, | ||||||||||||||
ops::{self, handle::NodeHandle, DataflowBlock, DataflowParent, ExitBlock, OpType}, | ||||||||||||||
}; | ||||||||||||||
use crate::{ | ||||||||||||||
extension::{ExtensionRegistry, ExtensionSet}, | ||||||||||||||
types::FunctionType, | ||||||||||||||
|
@@ -43,7 +46,7 @@ use crate::{hugr::HugrMut, type_row, Hugr}; | |||||||||||||
/// +------------+ | ||||||||||||||
/// */ | ||||||||||||||
/// use hugr::{ | ||||||||||||||
/// builder::{BuildError, CFGBuilder, Container, Dataflow, HugrBuilder}, | ||||||||||||||
/// builder::{BuildError, CFGBuilder, Container, Dataflow, HugrBuilder, ft1, ft2}, | ||||||||||||||
/// extension::{prelude, ExtensionSet}, | ||||||||||||||
/// ops, type_row, | ||||||||||||||
/// types::{FunctionType, SumType, Type}, | ||||||||||||||
|
@@ -62,8 +65,7 @@ use crate::{hugr::HugrMut, type_row, Hugr}; | |||||||||||||
/// | ||||||||||||||
/// // The second argument says what types will be passed through to every | ||||||||||||||
/// // successor, in addition to the appropriate `sum_variants` type. | ||||||||||||||
/// let mut entry_b = | ||||||||||||||
/// cfg_builder.entry_builder(sum_variants.clone(), type_row![NAT], ExtensionSet::new())?; | ||||||||||||||
/// let mut entry_b = cfg_builder.entry_builder(sum_variants.clone(), type_row![NAT])?; | ||||||||||||||
/// | ||||||||||||||
/// let [inw] = entry_b.input_wires_arr(); | ||||||||||||||
/// let entry = { | ||||||||||||||
|
@@ -82,7 +84,7 @@ use crate::{hugr::HugrMut, type_row, Hugr}; | |||||||||||||
/// // `NAT` arguments: one from the `sum_variants` type, and another from the | ||||||||||||||
/// // entry node's `other_outputs`. | ||||||||||||||
/// let mut successor_builder = cfg_builder.simple_block_builder( | ||||||||||||||
/// FunctionType::new(type_row![NAT, NAT], type_row![NAT]), | ||||||||||||||
/// ft2(type_row![NAT, NAT], NAT), | ||||||||||||||
/// 1, // only one successor to this block | ||||||||||||||
/// )?; | ||||||||||||||
/// let successor_a = { | ||||||||||||||
|
@@ -96,8 +98,7 @@ use crate::{hugr::HugrMut, type_row, Hugr}; | |||||||||||||
/// }; | ||||||||||||||
/// | ||||||||||||||
/// // The only argument to this block is the entry node's `other_outputs`. | ||||||||||||||
/// let mut successor_builder = cfg_builder | ||||||||||||||
/// .simple_block_builder(FunctionType::new(type_row![NAT], type_row![NAT]), 1)?; | ||||||||||||||
/// let mut successor_builder = cfg_builder.simple_block_builder(ft1(NAT), 1)?; | ||||||||||||||
/// let successor_b = { | ||||||||||||||
/// let sum_unary = successor_builder.add_load_value(ops::Value::unary_unit_sum()); | ||||||||||||||
/// let [in_wire] = successor_builder.input_wires_arr(); | ||||||||||||||
|
@@ -197,7 +198,7 @@ impl<B: AsMut<Hugr> + AsRef<Hugr>> CFGBuilder<B> { | |||||||||||||
|
||||||||||||||
/// Return a builder for a non-entry [`DataflowBlock`] child graph with `inputs` | ||||||||||||||
/// and `outputs` and the variants of the branching Sum value | ||||||||||||||
/// specified by `sum_rows`. | ||||||||||||||
/// specified by `sum_rows`. Extension delta will be inferred. | ||||||||||||||
/// | ||||||||||||||
/// # Errors | ||||||||||||||
/// | ||||||||||||||
|
@@ -206,18 +207,40 @@ impl<B: AsMut<Hugr> + AsRef<Hugr>> CFGBuilder<B> { | |||||||||||||
&mut self, | ||||||||||||||
inputs: TypeRow, | ||||||||||||||
sum_rows: impl IntoIterator<Item = TypeRow>, | ||||||||||||||
extension_delta: ExtensionSet, | ||||||||||||||
other_outputs: TypeRow, | ||||||||||||||
) -> Result<BlockBuilder<&mut Hugr>, BuildError> { | ||||||||||||||
self.any_block_builder(inputs, sum_rows, other_outputs, extension_delta, false) | ||||||||||||||
self.block_builder_exts(inputs, sum_rows, TO_BE_INFERRED, other_outputs) | ||||||||||||||
} | ||||||||||||||
|
||||||||||||||
fn any_block_builder( | ||||||||||||||
/// Return a builder for a non-entry [`DataflowBlock`] child graph with `inputs` | ||||||||||||||
/// and `outputs` and the variants of the branching Sum value | ||||||||||||||
/// specified by `sum_rows`. Extension delta will be inferred. | ||||||||||||||
/// | ||||||||||||||
/// # Errors | ||||||||||||||
/// | ||||||||||||||
/// This function will return an error if there is an error adding the node. | ||||||||||||||
pub fn block_builder_exts( | ||||||||||||||
&mut self, | ||||||||||||||
inputs: TypeRow, | ||||||||||||||
sum_rows: impl IntoIterator<Item = TypeRow>, | ||||||||||||||
extension_delta: impl Into<ExtensionSet>, | ||||||||||||||
other_outputs: TypeRow, | ||||||||||||||
) -> Result<BlockBuilder<&mut Hugr>, BuildError> { | ||||||||||||||
self.any_block_builder( | ||||||||||||||
inputs, | ||||||||||||||
extension_delta.into(), | ||||||||||||||
sum_rows, | ||||||||||||||
other_outputs, | ||||||||||||||
false, | ||||||||||||||
) | ||||||||||||||
} | ||||||||||||||
|
||||||||||||||
fn any_block_builder( | ||||||||||||||
&mut self, | ||||||||||||||
inputs: TypeRow, | ||||||||||||||
extension_delta: ExtensionSet, | ||||||||||||||
sum_rows: impl IntoIterator<Item = TypeRow>, | ||||||||||||||
other_outputs: TypeRow, | ||||||||||||||
entry: bool, | ||||||||||||||
) -> Result<BlockBuilder<&mut Hugr>, BuildError> { | ||||||||||||||
let sum_rows: Vec<_> = sum_rows.into_iter().collect(); | ||||||||||||||
|
@@ -241,7 +264,8 @@ impl<B: AsMut<Hugr> + AsRef<Hugr>> CFGBuilder<B> { | |||||||||||||
} | ||||||||||||||
|
||||||||||||||
/// Return a builder for a non-entry [`DataflowBlock`] child graph with `inputs` | ||||||||||||||
/// and `outputs` and a UnitSum type: a Sum of `n_cases` unit types. | ||||||||||||||
/// and `outputs` and `extension_delta` explicitly specified, plus a UnitSum type | ||||||||||||||
/// (a Sum of `n_cases` unit types) to select the successor. | ||||||||||||||
/// | ||||||||||||||
/// # Errors | ||||||||||||||
/// | ||||||||||||||
|
@@ -251,32 +275,53 @@ impl<B: AsMut<Hugr> + AsRef<Hugr>> CFGBuilder<B> { | |||||||||||||
signature: FunctionType, | ||||||||||||||
n_cases: usize, | ||||||||||||||
) -> Result<BlockBuilder<&mut Hugr>, BuildError> { | ||||||||||||||
self.block_builder( | ||||||||||||||
self.block_builder_exts( | ||||||||||||||
signature.input, | ||||||||||||||
vec![type_row![]; n_cases], | ||||||||||||||
signature.extension_reqs, | ||||||||||||||
signature.output, | ||||||||||||||
) | ||||||||||||||
} | ||||||||||||||
|
||||||||||||||
/// Return a builder for the entry [`DataflowBlock`] child graph with | ||||||||||||||
/// `outputs` and the variants of the branching Sum value | ||||||||||||||
/// specified by `sum_rows`. | ||||||||||||||
/// | ||||||||||||||
/// # Errors | ||||||||||||||
/// | ||||||||||||||
/// This function will return an error if an entry block has already been built. | ||||||||||||||
pub fn entry_builder( | ||||||||||||||
&mut self, | ||||||||||||||
sum_rows: impl IntoIterator<Item = TypeRow>, | ||||||||||||||
other_outputs: TypeRow, | ||||||||||||||
) -> Result<BlockBuilder<&mut Hugr>, BuildError> { | ||||||||||||||
self.entry_builder_exts(TO_BE_INFERRED, sum_rows, other_outputs) | ||||||||||||||
} | ||||||||||||||
|
||||||||||||||
/// Return a builder for the entry [`DataflowBlock`] child graph with `inputs` | ||||||||||||||
/// and `outputs` and the variants of the branching Sum value | ||||||||||||||
/// specified by `sum_rows`. | ||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, pretty much did this, also noted that delta will be inferred on the implicit version |
||||||||||||||
/// | ||||||||||||||
/// # Errors | ||||||||||||||
/// | ||||||||||||||
/// This function will return an error if an entry block has already been built. | ||||||||||||||
pub fn entry_builder( | ||||||||||||||
pub fn entry_builder_exts( | ||||||||||||||
&mut self, | ||||||||||||||
extension_delta: impl Into<ExtensionSet>, | ||||||||||||||
sum_rows: impl IntoIterator<Item = TypeRow>, | ||||||||||||||
other_outputs: TypeRow, | ||||||||||||||
extension_delta: ExtensionSet, | ||||||||||||||
) -> Result<BlockBuilder<&mut Hugr>, BuildError> { | ||||||||||||||
let inputs = self | ||||||||||||||
.inputs | ||||||||||||||
.take() | ||||||||||||||
.ok_or(BuildError::EntryBuiltError(self.cfg_node))?; | ||||||||||||||
self.any_block_builder(inputs, sum_rows, other_outputs, extension_delta, true) | ||||||||||||||
self.any_block_builder( | ||||||||||||||
inputs, | ||||||||||||||
extension_delta.into(), | ||||||||||||||
sum_rows, | ||||||||||||||
other_outputs, | ||||||||||||||
true, | ||||||||||||||
) | ||||||||||||||
} | ||||||||||||||
|
||||||||||||||
/// Return a builder for the entry [`DataflowBlock`] child graph with `inputs` | ||||||||||||||
|
@@ -289,9 +334,23 @@ impl<B: AsMut<Hugr> + AsRef<Hugr>> CFGBuilder<B> { | |||||||||||||
&mut self, | ||||||||||||||
outputs: TypeRow, | ||||||||||||||
n_cases: usize, | ||||||||||||||
extension_delta: ExtensionSet, | ||||||||||||||
) -> Result<BlockBuilder<&mut Hugr>, BuildError> { | ||||||||||||||
self.entry_builder(vec![type_row![]; n_cases], outputs, extension_delta) | ||||||||||||||
self.entry_builder(vec![type_row![]; n_cases], outputs) | ||||||||||||||
} | ||||||||||||||
|
||||||||||||||
/// Return a builder for the entry [`DataflowBlock`] child graph with `inputs` | ||||||||||||||
/// and `outputs` and a UnitSum type: a Sum of `n_cases` unit types. | ||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In fact I removed mention of non-existent parameter |
||||||||||||||
/// | ||||||||||||||
/// # Errors | ||||||||||||||
/// | ||||||||||||||
/// This function will return an error if there is an error adding the node. | ||||||||||||||
pub fn simple_entry_builder_exts( | ||||||||||||||
&mut self, | ||||||||||||||
outputs: TypeRow, | ||||||||||||||
n_cases: usize, | ||||||||||||||
extension_delta: impl Into<ExtensionSet>, | ||||||||||||||
) -> Result<BlockBuilder<&mut Hugr>, BuildError> { | ||||||||||||||
self.entry_builder_exts(extension_delta, vec![type_row![]; n_cases], outputs) | ||||||||||||||
} | ||||||||||||||
|
||||||||||||||
/// Returns the exit block of this [`CFGBuilder`]. | ||||||||||||||
|
@@ -439,8 +498,11 @@ pub(crate) mod test { | |||||||||||||
cfg_builder: &mut CFGBuilder<T>, | ||||||||||||||
) -> Result<(), BuildError> { | ||||||||||||||
let sum2_variants = vec![type_row![NAT], type_row![NAT]]; | ||||||||||||||
let mut entry_b = | ||||||||||||||
cfg_builder.entry_builder(sum2_variants.clone(), type_row![], ExtensionSet::new())?; | ||||||||||||||
let mut entry_b = cfg_builder.entry_builder_exts( | ||||||||||||||
ExtensionSet::new(), | ||||||||||||||
sum2_variants.clone(), | ||||||||||||||
type_row![], | ||||||||||||||
)?; | ||||||||||||||
let entry = { | ||||||||||||||
let [inw] = entry_b.input_wires_arr(); | ||||||||||||||
|
||||||||||||||
|
@@ -466,8 +528,11 @@ pub(crate) mod test { | |||||||||||||
let sum_tuple_const = cfg_builder.add_constant(ops::Value::unary_unit_sum()); | ||||||||||||||
let sum_variants = vec![type_row![]]; | ||||||||||||||
|
||||||||||||||
let mut entry_b = | ||||||||||||||
cfg_builder.entry_builder(sum_variants.clone(), type_row![], ExtensionSet::new())?; | ||||||||||||||
let mut entry_b = cfg_builder.entry_builder_exts( | ||||||||||||||
ExtensionSet::new(), | ||||||||||||||
sum_variants.clone(), | ||||||||||||||
type_row![], | ||||||||||||||
)?; | ||||||||||||||
let [inw] = entry_b.input_wires_arr(); | ||||||||||||||
let entry = { | ||||||||||||||
let sum = entry_b.load_const(&sum_tuple_const); | ||||||||||||||
|
@@ -501,8 +566,7 @@ pub(crate) mod test { | |||||||||||||
middle_b.finish_with_outputs(c, [inw])? | ||||||||||||||
}; | ||||||||||||||
|
||||||||||||||
let mut entry_b = | ||||||||||||||
cfg_builder.entry_builder(sum_variants.clone(), type_row![NAT], ExtensionSet::new())?; | ||||||||||||||
let mut entry_b = cfg_builder.entry_builder(sum_variants.clone(), type_row![NAT])?; | ||||||||||||||
let entry = { | ||||||||||||||
let sum = entry_b.load_const(&sum_tuple_const); | ||||||||||||||
// entry block uses wire from middle block even though middle block | ||||||||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.