Skip to content

Commit

Permalink
fix!: Ops require their own extension (#1226)
Browse files Browse the repository at this point in the history
Includes adding new variants of `block_builder`, `entry_builder`,
`simple_entry_builder` and `conditional_builder`: the default version
omits the extension set parameter, the `_exts` variant takes an extra
parameter (being an ExtensionSet). `simple_block_builder` is untouched
(as it takes a FunctionType, so can use `endo_ft`/`inout_ft`)

We'll need similar updates to `cfg_builder`, `tail_loop_builder`,
`ConditionalBuilder::new` and `TailLoopBuilder::new` but I'll leave
those for another PR, there's quite enough here ;)

closes #388

BREAKING CHANGE: (1) container-node extension-deltas will need to be
enlarged to include ops therein; for FuncDefn this will have to be
manually specified but for other containers TO_BE_INFERRED, `endo_ft` or
`inout_ft` all work. (2) `block_builder`, `entry_builder`,
`simple_entry_builder` and `conditional_builder` no longer take an
ExtensionSet; either drop the argument or use the `..._exts` variant.
  • Loading branch information
acl-cqc authored Jul 15, 2024
1 parent 9ed379f commit cfb0674
Show file tree
Hide file tree
Showing 25 changed files with 346 additions and 293 deletions.
14 changes: 7 additions & 7 deletions hugr-core/src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
//! # use hugr::Hugr;
//! # use hugr::builder::{BuildError, BuildHandle, Container, DFGBuilder, Dataflow, DataflowHugr, ModuleBuilder, DataflowSubContainer, HugrBuilder};
//! use hugr::extension::prelude::BOOL_T;
//! use hugr::std_extensions::logic::{NotOp, LOGIC_REG};
//! use hugr::std_extensions::logic::{EXTENSION_ID, LOGIC_REG, NotOp};
//! use hugr::types::FunctionType;
//!
//! # fn doctest() -> Result<(), BuildError> {
Expand All @@ -42,7 +42,7 @@
//! let _dfg_handle = {
//! let mut dfg = module_builder.define_function(
//! "main",
//! FunctionType::new(vec![BOOL_T], vec![BOOL_T]),
//! FunctionType::new_endo(BOOL_T).with_extension_delta(EXTENSION_ID),
//! )?;
//!
//! // Get the wires from the function inputs.
Expand All @@ -59,7 +59,8 @@
//! let _circuit_handle = {
//! let mut dfg = module_builder.define_function(
//! "circuit",
//! FunctionType::new_endo(vec![BOOL_T, BOOL_T]),
//! FunctionType::new_endo(vec![BOOL_T, BOOL_T])
//! .with_extension_delta(EXTENSION_ID),
//! )?;
//! let mut circuit = dfg.as_circuit(dfg.input_wires());
//!
Expand Down Expand Up @@ -285,10 +286,9 @@ pub(crate) mod test {
cfg_builder.finish_prelude_hugr().unwrap()
}

/// A helper method which creates a DFG rooted hugr with closed resources,
/// for tests which want to avoid having open extension variables after
/// inference. Using DFGBuilder will default to a root node with an open
/// extension variable
/// A helper method which creates a DFG rooted hugr with Input and Output node
/// only (no wires), given a function type with extension delta.
// TODO consider taking two type rows and using TO_BE_INFERRED
pub(crate) fn closed_dfg_root_hugr(signature: FunctionType) -> Hugr {
let mut hugr = Hugr::new(ops::DFG {
signature: signature.clone(),
Expand Down
31 changes: 27 additions & 4 deletions hugr-core/src/builder/build_traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -437,24 +437,47 @@ pub trait Dataflow: Container {
TailLoopBuilder::create_with_io(self.hugr_mut(), loop_node, &tail_loop)
}

/// Return a builder for a [`crate::ops::Conditional`] node.
/// `sum_input` is a tuple of the type of the Sum
/// variants and the corresponding wire.
///
/// The `other_inputs` must be an iterable over pairs of the type of the input and
/// the corresponding wire.
/// The `output_types` are the types of the outputs. Extension delta will be inferred.
///
/// # Errors
///
/// This function will return an error if there is an error when building
/// the Conditional node.
fn conditional_builder(
&mut self,
sum_input: (impl IntoIterator<Item = TypeRow>, Wire),
other_inputs: impl IntoIterator<Item = (Type, Wire)>,
output_types: TypeRow,
) -> Result<ConditionalBuilder<&mut Hugr>, BuildError> {
self.conditional_builder_exts(sum_input, other_inputs, output_types, TO_BE_INFERRED)
}

/// Return a builder for a [`crate::ops::Conditional`] node.
/// `sum_rows` and `sum_wire` define the type of the Sum
/// variants and the wire carrying the Sum respectively.
///
/// The `other_inputs` must be an iterable over pairs of the type of the input and
/// the corresponding wire.
/// The `outputs` are the types of the outputs.
/// The `output_types` are the types of the outputs.
/// `exts` explicitly specifies the extension delta. Alternatively
/// [conditional_builder](Self::conditional_builder) may be used to infer it.
///
/// # Errors
///
/// This function will return an error if there is an error when building
/// the Conditional node.
fn conditional_builder(
fn conditional_builder_exts(
&mut self,
(sum_rows, sum_wire): (impl IntoIterator<Item = TypeRow>, Wire),
other_inputs: impl IntoIterator<Item = (Type, Wire)>,
output_types: TypeRow,
extension_delta: ExtensionSet,
extension_delta: impl Into<ExtensionSet>,
) -> Result<ConditionalBuilder<&mut Hugr>, BuildError> {
let mut input_wires = vec![sum_wire];
let (input_types, rest_input_wires): (Vec<Type>, Vec<Wire>) =
Expand All @@ -471,7 +494,7 @@ pub trait Dataflow: Container {
sum_rows,
other_inputs: inputs,
outputs: output_types,
extension_delta,
extension_delta: extension_delta.into(),
},
input_wires,
)?;
Expand Down
124 changes: 95 additions & 29 deletions hugr-core/src/builder/cfg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@ use super::{
BasicBlockID, BuildError, CfgID, Container, Dataflow, HugrBuilder, Wire,
};

use crate::ops::{self, handle::NodeHandle, DataflowBlock, DataflowParent, ExitBlock, OpType};
use crate::{
extension::TO_BE_INFERRED,
ops::{self, handle::NodeHandle, DataflowBlock, DataflowParent, ExitBlock, OpType},
};
use crate::{
extension::{ExtensionRegistry, ExtensionSet},
types::FunctionType,
Expand Down Expand Up @@ -43,7 +46,7 @@ use crate::{hugr::HugrMut, type_row, Hugr};
/// +------------+
/// */
/// use hugr::{
/// builder::{BuildError, CFGBuilder, Container, Dataflow, HugrBuilder},
/// builder::{BuildError, CFGBuilder, Container, Dataflow, HugrBuilder, endo_ft, inout_ft},
/// extension::{prelude, ExtensionSet},
/// ops, type_row,
/// types::{FunctionType, SumType, Type},
Expand All @@ -62,8 +65,7 @@ use crate::{hugr::HugrMut, type_row, Hugr};
///
/// // The second argument says what types will be passed through to every
/// // successor, in addition to the appropriate `sum_variants` type.
/// let mut entry_b =
/// cfg_builder.entry_builder(sum_variants.clone(), type_row![NAT], ExtensionSet::new())?;
/// let mut entry_b = cfg_builder.entry_builder(sum_variants.clone(), type_row![NAT])?;
///
/// let [inw] = entry_b.input_wires_arr();
/// let entry = {
Expand All @@ -82,7 +84,7 @@ use crate::{hugr::HugrMut, type_row, Hugr};
/// // `NAT` arguments: one from the `sum_variants` type, and another from the
/// // entry node's `other_outputs`.
/// let mut successor_builder = cfg_builder.simple_block_builder(
/// FunctionType::new(type_row![NAT, NAT], type_row![NAT]),
/// inout_ft(type_row![NAT, NAT], NAT),
/// 1, // only one successor to this block
/// )?;
/// let successor_a = {
Expand All @@ -96,8 +98,7 @@ use crate::{hugr::HugrMut, type_row, Hugr};
/// };
///
/// // The only argument to this block is the entry node's `other_outputs`.
/// let mut successor_builder = cfg_builder
/// .simple_block_builder(FunctionType::new(type_row![NAT], type_row![NAT]), 1)?;
/// let mut successor_builder = cfg_builder.simple_block_builder(endo_ft(NAT), 1)?;
/// let successor_b = {
/// let sum_unary = successor_builder.add_load_value(ops::Value::unary_unit_sum());
/// let [in_wire] = successor_builder.input_wires_arr();
Expand Down Expand Up @@ -197,7 +198,7 @@ impl<B: AsMut<Hugr> + AsRef<Hugr>> CFGBuilder<B> {

/// Return a builder for a non-entry [`DataflowBlock`] child graph with `inputs`
/// and `outputs` and the variants of the branching Sum value
/// specified by `sum_rows`.
/// specified by `sum_rows`. Extension delta will be inferred.
///
/// # Errors
///
Expand All @@ -206,18 +207,40 @@ impl<B: AsMut<Hugr> + AsRef<Hugr>> CFGBuilder<B> {
&mut self,
inputs: TypeRow,
sum_rows: impl IntoIterator<Item = TypeRow>,
extension_delta: ExtensionSet,
other_outputs: TypeRow,
) -> Result<BlockBuilder<&mut Hugr>, BuildError> {
self.any_block_builder(inputs, sum_rows, other_outputs, extension_delta, false)
self.block_builder_exts(inputs, sum_rows, other_outputs, TO_BE_INFERRED)
}

fn any_block_builder(
/// Return a builder for a non-entry [`DataflowBlock`] child graph with `inputs`
/// and `outputs` and the variants of the branching Sum value
/// specified by `sum_rows`. Extension delta will be inferred.
///
/// # Errors
///
/// This function will return an error if there is an error adding the node.
pub fn block_builder_exts(
&mut self,
inputs: TypeRow,
sum_rows: impl IntoIterator<Item = TypeRow>,
other_outputs: TypeRow,
extension_delta: impl Into<ExtensionSet>,
) -> Result<BlockBuilder<&mut Hugr>, BuildError> {
self.any_block_builder(
inputs,
extension_delta.into(),
sum_rows,
other_outputs,
false,
)
}

fn any_block_builder(
&mut self,
inputs: TypeRow,
extension_delta: ExtensionSet,
sum_rows: impl IntoIterator<Item = TypeRow>,
other_outputs: TypeRow,
entry: bool,
) -> Result<BlockBuilder<&mut Hugr>, BuildError> {
let sum_rows: Vec<_> = sum_rows.into_iter().collect();
Expand All @@ -241,7 +264,8 @@ impl<B: AsMut<Hugr> + AsRef<Hugr>> CFGBuilder<B> {
}

/// Return a builder for a non-entry [`DataflowBlock`] child graph with `inputs`
/// and `outputs` and a UnitSum type: a Sum of `n_cases` unit types.
/// and `outputs` and `extension_delta` explicitly specified, plus a UnitSum type
/// (a Sum of `n_cases` unit types) to select the successor.
///
/// # Errors
///
Expand All @@ -251,17 +275,17 @@ impl<B: AsMut<Hugr> + AsRef<Hugr>> CFGBuilder<B> {
signature: FunctionType,
n_cases: usize,
) -> Result<BlockBuilder<&mut Hugr>, BuildError> {
self.block_builder(
self.block_builder_exts(
signature.input,
vec![type_row![]; n_cases],
signature.extension_reqs,
signature.output,
signature.extension_reqs,
)
}

/// Return a builder for the entry [`DataflowBlock`] child graph with `inputs`
/// and `outputs` and the variants of the branching Sum value
/// specified by `sum_rows`.
/// Return a builder for the entry [`DataflowBlock`] child graph with `outputs`
/// and the variants of the branching Sum value specified by `sum_rows`.
/// Extension delta will be inferred.
///
/// # Errors
///
Expand All @@ -270,17 +294,39 @@ impl<B: AsMut<Hugr> + AsRef<Hugr>> CFGBuilder<B> {
&mut self,
sum_rows: impl IntoIterator<Item = TypeRow>,
other_outputs: TypeRow,
extension_delta: ExtensionSet,
) -> Result<BlockBuilder<&mut Hugr>, BuildError> {
self.entry_builder_exts(sum_rows, other_outputs, TO_BE_INFERRED)
}

/// Return a builder for the entry [`DataflowBlock`] child graph with `outputs`,
/// the variants of the branching Sum value specified by `sum_rows`, and
/// `extension_delta` explicitly specified. ([entry_builder](Self::entry_builder)
/// may be used to infer.)
///
/// # Errors
///
/// This function will return an error if an entry block has already been built.
pub fn entry_builder_exts(
&mut self,
sum_rows: impl IntoIterator<Item = TypeRow>,
other_outputs: TypeRow,
extension_delta: impl Into<ExtensionSet>,
) -> Result<BlockBuilder<&mut Hugr>, BuildError> {
let inputs = self
.inputs
.take()
.ok_or(BuildError::EntryBuiltError(self.cfg_node))?;
self.any_block_builder(inputs, sum_rows, other_outputs, extension_delta, true)
self.any_block_builder(
inputs,
extension_delta.into(),
sum_rows,
other_outputs,
true,
)
}

/// Return a builder for the entry [`DataflowBlock`] child graph with `inputs`
/// and `outputs` and a UnitSum type: a Sum of `n_cases` unit types.
/// Return a builder for the entry [`DataflowBlock`] child graph with
/// `outputs` and a UnitSum type: a Sum of `n_cases` unit types.
///
/// # Errors
///
Expand All @@ -289,9 +335,24 @@ impl<B: AsMut<Hugr> + AsRef<Hugr>> CFGBuilder<B> {
&mut self,
outputs: TypeRow,
n_cases: usize,
extension_delta: ExtensionSet,
) -> Result<BlockBuilder<&mut Hugr>, BuildError> {
self.entry_builder(vec![type_row![]; n_cases], outputs, extension_delta)
self.entry_builder(vec![type_row![]; n_cases], outputs)
}

/// Return a builder for the entry [`DataflowBlock`] child graph with
/// `outputs` and a Sum of `n_cases` unit types, and explicit `extension_delta`.
/// ([simple_entry_builder](Self::simple_entry_builder) may be used to infer.)
///
/// # Errors
///
/// This function will return an error if there is an error adding the node.
pub fn simple_entry_builder_exts(
&mut self,
outputs: TypeRow,
n_cases: usize,
extension_delta: impl Into<ExtensionSet>,
) -> Result<BlockBuilder<&mut Hugr>, BuildError> {
self.entry_builder_exts(vec![type_row![]; n_cases], outputs, extension_delta)
}

/// Returns the exit block of this [`CFGBuilder`].
Expand Down Expand Up @@ -439,8 +500,11 @@ pub(crate) mod test {
cfg_builder: &mut CFGBuilder<T>,
) -> Result<(), BuildError> {
let sum2_variants = vec![type_row![NAT], type_row![NAT]];
let mut entry_b =
cfg_builder.entry_builder(sum2_variants.clone(), type_row![], ExtensionSet::new())?;
let mut entry_b = cfg_builder.entry_builder_exts(
sum2_variants.clone(),
type_row![],
ExtensionSet::new(),
)?;
let entry = {
let [inw] = entry_b.input_wires_arr();

Expand All @@ -466,8 +530,11 @@ pub(crate) mod test {
let sum_tuple_const = cfg_builder.add_constant(ops::Value::unary_unit_sum());
let sum_variants = vec![type_row![]];

let mut entry_b =
cfg_builder.entry_builder(sum_variants.clone(), type_row![], ExtensionSet::new())?;
let mut entry_b = cfg_builder.entry_builder_exts(
sum_variants.clone(),
type_row![],
ExtensionSet::new(),
)?;
let [inw] = entry_b.input_wires_arr();
let entry = {
let sum = entry_b.load_const(&sum_tuple_const);
Expand Down Expand Up @@ -501,8 +568,7 @@ pub(crate) mod test {
middle_b.finish_with_outputs(c, [inw])?
};

let mut entry_b =
cfg_builder.entry_builder(sum_variants.clone(), type_row![NAT], ExtensionSet::new())?;
let mut entry_b = cfg_builder.entry_builder(sum_variants.clone(), type_row![NAT])?;
let entry = {
let sum = entry_b.load_const(&sum_tuple_const);
// entry block uses wire from middle block even though middle block
Expand Down
11 changes: 8 additions & 3 deletions hugr-core/src/builder/circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ mod test {

use crate::std_extensions::arithmetic::float_types::{self, ConstF64};
use crate::utils::test_quantum_extension::{
cx_gate, h_gate, measure, q_alloc, q_discard, rz_f64,
self, cx_gate, h_gate, measure, q_alloc, q_discard, rz_f64,
};
use crate::{
builder::{
Expand All @@ -261,6 +261,7 @@ mod test {
fn simple_linear() {
let build_res = build_main(
FunctionType::new(type_row![QB, QB], type_row![QB, QB])
.with_extension_delta(test_quantum_extension::EXTENSION_ID)
.with_extension_delta(float_types::EXTENSION_ID)
.into(),
|mut f_build| {
Expand Down Expand Up @@ -302,7 +303,9 @@ mod test {
FunctionType::new(vec![QB, NAT], vec![QB]),
));
let build_res = build_main(
FunctionType::new(type_row![QB, QB, NAT], type_row![QB, QB, BOOL_T]).into(),
FunctionType::new(type_row![QB, QB, NAT], type_row![QB, QB, BOOL_T])
.with_extension_delta(test_quantum_extension::EXTENSION_ID)
.into(),
|mut f_build| {
let [q0, q1, angle]: [Wire; 3] = f_build.input_wires_arr();

Expand All @@ -327,7 +330,9 @@ mod test {
#[test]
fn ancillae() {
let build_res = build_main(
FunctionType::new(type_row![QB], type_row![QB]).into(),
FunctionType::new_endo(QB)
.with_extension_delta(test_quantum_extension::EXTENSION_ID)
.into(),
|mut f_build| {
let mut circ = f_build.as_circuit(f_build.input_wires());
assert_eq!(circ.n_wires(), 1);
Expand Down
1 change: 0 additions & 1 deletion hugr-core/src/builder/conditional.rs
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,6 @@ mod test {
([type_row![], type_row![]], const_wire),
other_inputs,
outputs,
ExtensionSet::new(),
)?;

n_identity(conditional_b.case_builder(0)?)?;
Expand Down
Loading

0 comments on commit cfb0674

Please sign in to comment.