Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix!: Ops require their own extension #1226

Merged
merged 23 commits into from
Jul 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
7832cf0
Insert OpDef extension in compute_signature
acl-cqc Jun 24, 2024
d6802c3
fix tests w/out extension_inference (only 3). 42 failures with --all-…
acl-cqc Jun 24, 2024
2c32b1a
simple_replace/sibling_subgraph - 23 failures remaining
acl-cqc Jun 25, 2024
5cb15ca
validate, note about closed_dfg_root_hugr
acl-cqc Jun 25, 2024
5b05df3
extension (op_def.rs, prelude.rs), serialize
acl-cqc Jun 25, 2024
78f7aed
replace.rs - need better CFG support here
acl-cqc Jun 25, 2024
856bfd4
builder copy_insertion+nested_identity (change from module+function t…
acl-cqc Jun 25, 2024
b2c09d5
circuit (explicit Function)
acl-cqc Jun 25, 2024
10bab5d
views (hugr-core now ok, and so, onto hugr-passes)
acl-cqc Jun 25, 2024
9e47b98
const_fold (WIN) - int_fn+float_fn+others => noargfn calling ft2
acl-cqc Jun 25, 2024
290bb6a
merge_bbs (TODO simple_block_builder, block_builder)
acl-cqc Jun 25, 2024
eca016d
doctests
acl-cqc Jun 25, 2024
abbfa2a
merge_bbs use simple_entry_builder
acl-cqc Jun 25, 2024
3994ee9
(simple_)block_builder variants, fix tests
acl-cqc Jun 25, 2024
0da41d8
Conditional builder (_exts)
acl-cqc Jun 25, 2024
a238f21
benches
acl-cqc Jun 25, 2024
79da288
clippy
acl-cqc Jun 28, 2024
ab4ef40
Merge remote-tracking branch 'origin/main' into ops_require_ext
acl-cqc Jul 15, 2024
1b5eebd
doc improvements
acl-cqc Jul 15, 2024
eb6d1d0
fix sibling_subgraph
acl-cqc Jul 15, 2024
a574c6d
force_order
acl-cqc Jul 15, 2024
aaa1cf3
docs: correct output_types parameter name
acl-cqc Jul 15, 2024
2c56bf6
extension_delta last param everytime
acl-cqc Jul 15, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions hugr-core/src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
//! # use hugr::Hugr;
//! # use hugr::builder::{BuildError, BuildHandle, Container, DFGBuilder, Dataflow, DataflowHugr, ModuleBuilder, DataflowSubContainer, HugrBuilder};
//! use hugr::extension::prelude::BOOL_T;
//! use hugr::std_extensions::logic::{NotOp, LOGIC_REG};
//! use hugr::std_extensions::logic::{EXTENSION_ID, LOGIC_REG, NotOp};
//! use hugr::types::FunctionType;
//!
//! # fn doctest() -> Result<(), BuildError> {
Expand All @@ -42,7 +42,7 @@
//! let _dfg_handle = {
//! let mut dfg = module_builder.define_function(
//! "main",
//! FunctionType::new(vec![BOOL_T], vec![BOOL_T]),
//! FunctionType::new_endo(BOOL_T).with_extension_delta(EXTENSION_ID),
//! )?;
//!
//! // Get the wires from the function inputs.
Expand All @@ -59,7 +59,8 @@
//! let _circuit_handle = {
//! let mut dfg = module_builder.define_function(
//! "circuit",
//! FunctionType::new_endo(vec![BOOL_T, BOOL_T]),
//! FunctionType::new_endo(vec![BOOL_T, BOOL_T])
//! .with_extension_delta(EXTENSION_ID),
//! )?;
//! let mut circuit = dfg.as_circuit(dfg.input_wires());
//!
Expand Down Expand Up @@ -285,10 +286,9 @@ pub(crate) mod test {
cfg_builder.finish_prelude_hugr().unwrap()
}

/// A helper method which creates a DFG rooted hugr with closed resources,
/// for tests which want to avoid having open extension variables after
/// inference. Using DFGBuilder will default to a root node with an open
/// extension variable
/// A helper method which creates a DFG rooted hugr with Input and Output node
/// only (no wires), given a function type with extension delta.
// TODO consider taking two type rows and using TO_BE_INFERRED
pub(crate) fn closed_dfg_root_hugr(signature: FunctionType) -> Hugr {
let mut hugr = Hugr::new(ops::DFG {
signature: signature.clone(),
Expand Down
31 changes: 27 additions & 4 deletions hugr-core/src/builder/build_traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -437,24 +437,47 @@ pub trait Dataflow: Container {
TailLoopBuilder::create_with_io(self.hugr_mut(), loop_node, &tail_loop)
}

/// Return a builder for a [`crate::ops::Conditional`] node.
/// `sum_input` is a tuple of the type of the Sum
/// variants and the corresponding wire.
///
/// The `other_inputs` must be an iterable over pairs of the type of the input and
/// the corresponding wire.
/// The `output_types` are the types of the outputs. Extension delta will be inferred.
///
/// # Errors
///
/// This function will return an error if there is an error when building
/// the Conditional node.
fn conditional_builder(
&mut self,
sum_input: (impl IntoIterator<Item = TypeRow>, Wire),
other_inputs: impl IntoIterator<Item = (Type, Wire)>,
output_types: TypeRow,
) -> Result<ConditionalBuilder<&mut Hugr>, BuildError> {
self.conditional_builder_exts(sum_input, other_inputs, output_types, TO_BE_INFERRED)
}

/// Return a builder for a [`crate::ops::Conditional`] node.
/// `sum_rows` and `sum_wire` define the type of the Sum
/// variants and the wire carrying the Sum respectively.
///
/// The `other_inputs` must be an iterable over pairs of the type of the input and
/// the corresponding wire.
/// The `outputs` are the types of the outputs.
/// The `output_types` are the types of the outputs.
/// `exts` explicitly specifies the extension delta. Alternatively
/// [conditional_builder](Self::conditional_builder) may be used to infer it.
///
/// # Errors
///
/// This function will return an error if there is an error when building
/// the Conditional node.
fn conditional_builder(
fn conditional_builder_exts(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mention extension_delta in the fn docs.

&mut self,
(sum_rows, sum_wire): (impl IntoIterator<Item = TypeRow>, Wire),
other_inputs: impl IntoIterator<Item = (Type, Wire)>,
output_types: TypeRow,
extension_delta: ExtensionSet,
extension_delta: impl Into<ExtensionSet>,
) -> Result<ConditionalBuilder<&mut Hugr>, BuildError> {
let mut input_wires = vec![sum_wire];
let (input_types, rest_input_wires): (Vec<Type>, Vec<Wire>) =
Expand All @@ -471,7 +494,7 @@ pub trait Dataflow: Container {
sum_rows,
other_inputs: inputs,
outputs: output_types,
extension_delta,
extension_delta: extension_delta.into(),
},
input_wires,
)?;
Expand Down
124 changes: 95 additions & 29 deletions hugr-core/src/builder/cfg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@ use super::{
BasicBlockID, BuildError, CfgID, Container, Dataflow, HugrBuilder, Wire,
};

use crate::ops::{self, handle::NodeHandle, DataflowBlock, DataflowParent, ExitBlock, OpType};
use crate::{
extension::TO_BE_INFERRED,
ops::{self, handle::NodeHandle, DataflowBlock, DataflowParent, ExitBlock, OpType},
};
use crate::{
extension::{ExtensionRegistry, ExtensionSet},
types::FunctionType,
Expand Down Expand Up @@ -43,7 +46,7 @@ use crate::{hugr::HugrMut, type_row, Hugr};
/// +------------+
/// */
/// use hugr::{
/// builder::{BuildError, CFGBuilder, Container, Dataflow, HugrBuilder},
/// builder::{BuildError, CFGBuilder, Container, Dataflow, HugrBuilder, endo_ft, inout_ft},
/// extension::{prelude, ExtensionSet},
/// ops, type_row,
/// types::{FunctionType, SumType, Type},
Expand All @@ -62,8 +65,7 @@ use crate::{hugr::HugrMut, type_row, Hugr};
///
/// // The second argument says what types will be passed through to every
/// // successor, in addition to the appropriate `sum_variants` type.
/// let mut entry_b =
/// cfg_builder.entry_builder(sum_variants.clone(), type_row![NAT], ExtensionSet::new())?;
/// let mut entry_b = cfg_builder.entry_builder(sum_variants.clone(), type_row![NAT])?;
///
/// let [inw] = entry_b.input_wires_arr();
/// let entry = {
Expand All @@ -82,7 +84,7 @@ use crate::{hugr::HugrMut, type_row, Hugr};
/// // `NAT` arguments: one from the `sum_variants` type, and another from the
/// // entry node's `other_outputs`.
/// let mut successor_builder = cfg_builder.simple_block_builder(
/// FunctionType::new(type_row![NAT, NAT], type_row![NAT]),
/// inout_ft(type_row![NAT, NAT], NAT),
/// 1, // only one successor to this block
/// )?;
/// let successor_a = {
Expand All @@ -96,8 +98,7 @@ use crate::{hugr::HugrMut, type_row, Hugr};
/// };
///
/// // The only argument to this block is the entry node's `other_outputs`.
/// let mut successor_builder = cfg_builder
/// .simple_block_builder(FunctionType::new(type_row![NAT], type_row![NAT]), 1)?;
/// let mut successor_builder = cfg_builder.simple_block_builder(endo_ft(NAT), 1)?;
/// let successor_b = {
/// let sum_unary = successor_builder.add_load_value(ops::Value::unary_unit_sum());
/// let [in_wire] = successor_builder.input_wires_arr();
Expand Down Expand Up @@ -197,7 +198,7 @@ impl<B: AsMut<Hugr> + AsRef<Hugr>> CFGBuilder<B> {

/// Return a builder for a non-entry [`DataflowBlock`] child graph with `inputs`
/// and `outputs` and the variants of the branching Sum value
/// specified by `sum_rows`.
/// specified by `sum_rows`. Extension delta will be inferred.
///
/// # Errors
///
Expand All @@ -206,18 +207,40 @@ impl<B: AsMut<Hugr> + AsRef<Hugr>> CFGBuilder<B> {
&mut self,
inputs: TypeRow,
sum_rows: impl IntoIterator<Item = TypeRow>,
extension_delta: ExtensionSet,
other_outputs: TypeRow,
) -> Result<BlockBuilder<&mut Hugr>, BuildError> {
self.any_block_builder(inputs, sum_rows, other_outputs, extension_delta, false)
self.block_builder_exts(inputs, sum_rows, other_outputs, TO_BE_INFERRED)
}

fn any_block_builder(
/// Return a builder for a non-entry [`DataflowBlock`] child graph with `inputs`
/// and `outputs` and the variants of the branching Sum value
/// specified by `sum_rows`. Extension delta will be inferred.
///
/// # Errors
///
/// This function will return an error if there is an error adding the node.
pub fn block_builder_exts(
&mut self,
inputs: TypeRow,
sum_rows: impl IntoIterator<Item = TypeRow>,
other_outputs: TypeRow,
extension_delta: impl Into<ExtensionSet>,
) -> Result<BlockBuilder<&mut Hugr>, BuildError> {
self.any_block_builder(
inputs,
extension_delta.into(),
sum_rows,
other_outputs,
false,
)
}

fn any_block_builder(
&mut self,
inputs: TypeRow,
extension_delta: ExtensionSet,
sum_rows: impl IntoIterator<Item = TypeRow>,
other_outputs: TypeRow,
entry: bool,
) -> Result<BlockBuilder<&mut Hugr>, BuildError> {
let sum_rows: Vec<_> = sum_rows.into_iter().collect();
Expand All @@ -241,7 +264,8 @@ impl<B: AsMut<Hugr> + AsRef<Hugr>> CFGBuilder<B> {
}

/// Return a builder for a non-entry [`DataflowBlock`] child graph with `inputs`
/// and `outputs` and a UnitSum type: a Sum of `n_cases` unit types.
/// and `outputs` and `extension_delta` explicitly specified, plus a UnitSum type
/// (a Sum of `n_cases` unit types) to select the successor.
///
/// # Errors
///
Expand All @@ -251,17 +275,17 @@ impl<B: AsMut<Hugr> + AsRef<Hugr>> CFGBuilder<B> {
signature: FunctionType,
n_cases: usize,
) -> Result<BlockBuilder<&mut Hugr>, BuildError> {
self.block_builder(
self.block_builder_exts(
signature.input,
vec![type_row![]; n_cases],
signature.extension_reqs,
signature.output,
signature.extension_reqs,
)
}

/// Return a builder for the entry [`DataflowBlock`] child graph with `inputs`
/// and `outputs` and the variants of the branching Sum value
/// specified by `sum_rows`.
/// Return a builder for the entry [`DataflowBlock`] child graph with `outputs`
/// and the variants of the branching Sum value specified by `sum_rows`.
/// Extension delta will be inferred.
///
/// # Errors
///
Expand All @@ -270,17 +294,39 @@ impl<B: AsMut<Hugr> + AsRef<Hugr>> CFGBuilder<B> {
&mut self,
sum_rows: impl IntoIterator<Item = TypeRow>,
other_outputs: TypeRow,
extension_delta: ExtensionSet,
) -> Result<BlockBuilder<&mut Hugr>, BuildError> {
self.entry_builder_exts(sum_rows, other_outputs, TO_BE_INFERRED)
}

/// Return a builder for the entry [`DataflowBlock`] child graph with `outputs`,
/// the variants of the branching Sum value specified by `sum_rows`, and
/// `extension_delta` explicitly specified. ([entry_builder](Self::entry_builder)
/// may be used to infer.)
///
/// # Errors
///
/// This function will return an error if an entry block has already been built.
pub fn entry_builder_exts(
&mut self,
sum_rows: impl IntoIterator<Item = TypeRow>,
other_outputs: TypeRow,
extension_delta: impl Into<ExtensionSet>,
) -> Result<BlockBuilder<&mut Hugr>, BuildError> {
let inputs = self
.inputs
.take()
.ok_or(BuildError::EntryBuiltError(self.cfg_node))?;
self.any_block_builder(inputs, sum_rows, other_outputs, extension_delta, true)
self.any_block_builder(
inputs,
extension_delta.into(),
sum_rows,
other_outputs,
true,
)
}

/// Return a builder for the entry [`DataflowBlock`] child graph with `inputs`
/// and `outputs` and a UnitSum type: a Sum of `n_cases` unit types.
/// Return a builder for the entry [`DataflowBlock`] child graph with
/// `outputs` and a UnitSum type: a Sum of `n_cases` unit types.
///
/// # Errors
///
Expand All @@ -289,9 +335,24 @@ impl<B: AsMut<Hugr> + AsRef<Hugr>> CFGBuilder<B> {
&mut self,
outputs: TypeRow,
n_cases: usize,
extension_delta: ExtensionSet,
) -> Result<BlockBuilder<&mut Hugr>, BuildError> {
self.entry_builder(vec![type_row![]; n_cases], outputs, extension_delta)
self.entry_builder(vec![type_row![]; n_cases], outputs)
}

/// Return a builder for the entry [`DataflowBlock`] child graph with
/// `outputs` and a Sum of `n_cases` unit types, and explicit `extension_delta`.
/// ([simple_entry_builder](Self::simple_entry_builder) may be used to infer.)
///
/// # Errors
///
/// This function will return an error if there is an error adding the node.
pub fn simple_entry_builder_exts(
&mut self,
outputs: TypeRow,
n_cases: usize,
extension_delta: impl Into<ExtensionSet>,
) -> Result<BlockBuilder<&mut Hugr>, BuildError> {
self.entry_builder_exts(vec![type_row![]; n_cases], outputs, extension_delta)
}

/// Returns the exit block of this [`CFGBuilder`].
Expand Down Expand Up @@ -439,8 +500,11 @@ pub(crate) mod test {
cfg_builder: &mut CFGBuilder<T>,
) -> Result<(), BuildError> {
let sum2_variants = vec![type_row![NAT], type_row![NAT]];
let mut entry_b =
cfg_builder.entry_builder(sum2_variants.clone(), type_row![], ExtensionSet::new())?;
let mut entry_b = cfg_builder.entry_builder_exts(
sum2_variants.clone(),
type_row![],
ExtensionSet::new(),
)?;
let entry = {
let [inw] = entry_b.input_wires_arr();

Expand All @@ -466,8 +530,11 @@ pub(crate) mod test {
let sum_tuple_const = cfg_builder.add_constant(ops::Value::unary_unit_sum());
let sum_variants = vec![type_row![]];

let mut entry_b =
cfg_builder.entry_builder(sum_variants.clone(), type_row![], ExtensionSet::new())?;
let mut entry_b = cfg_builder.entry_builder_exts(
sum_variants.clone(),
type_row![],
ExtensionSet::new(),
)?;
let [inw] = entry_b.input_wires_arr();
let entry = {
let sum = entry_b.load_const(&sum_tuple_const);
Expand Down Expand Up @@ -501,8 +568,7 @@ pub(crate) mod test {
middle_b.finish_with_outputs(c, [inw])?
};

let mut entry_b =
cfg_builder.entry_builder(sum_variants.clone(), type_row![NAT], ExtensionSet::new())?;
let mut entry_b = cfg_builder.entry_builder(sum_variants.clone(), type_row![NAT])?;
let entry = {
let sum = entry_b.load_const(&sum_tuple_const);
// entry block uses wire from middle block even though middle block
Expand Down
11 changes: 8 additions & 3 deletions hugr-core/src/builder/circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ mod test {

use crate::std_extensions::arithmetic::float_types::{self, ConstF64};
use crate::utils::test_quantum_extension::{
cx_gate, h_gate, measure, q_alloc, q_discard, rz_f64,
self, cx_gate, h_gate, measure, q_alloc, q_discard, rz_f64,
};
use crate::{
builder::{
Expand All @@ -261,6 +261,7 @@ mod test {
fn simple_linear() {
let build_res = build_main(
FunctionType::new(type_row![QB, QB], type_row![QB, QB])
.with_extension_delta(test_quantum_extension::EXTENSION_ID)
.with_extension_delta(float_types::EXTENSION_ID)
.into(),
|mut f_build| {
Expand Down Expand Up @@ -302,7 +303,9 @@ mod test {
FunctionType::new(vec![QB, NAT], vec![QB]),
));
let build_res = build_main(
FunctionType::new(type_row![QB, QB, NAT], type_row![QB, QB, BOOL_T]).into(),
FunctionType::new(type_row![QB, QB, NAT], type_row![QB, QB, BOOL_T])
.with_extension_delta(test_quantum_extension::EXTENSION_ID)
.into(),
|mut f_build| {
let [q0, q1, angle]: [Wire; 3] = f_build.input_wires_arr();

Expand All @@ -327,7 +330,9 @@ mod test {
#[test]
fn ancillae() {
let build_res = build_main(
FunctionType::new(type_row![QB], type_row![QB]).into(),
FunctionType::new_endo(QB)
.with_extension_delta(test_quantum_extension::EXTENSION_ID)
.into(),
|mut f_build| {
let mut circ = f_build.as_circuit(f_build.input_wires());
assert_eq!(circ.n_wires(), 1);
Expand Down
1 change: 0 additions & 1 deletion hugr-core/src/builder/conditional.rs
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,6 @@ mod test {
([type_row![], type_row![]], const_wire),
other_inputs,
outputs,
ExtensionSet::new(),
)?;

n_identity(conditional_b.case_builder(0)?)?;
Expand Down
Loading
Loading