Skip to content

Commit

Permalink
feat: make FuncDecl/FuncDefn polymorphic (#692)
Browse files Browse the repository at this point in the history
* Change FuncDecl and FuncDefn from storing FunctionType to PolyFuncType
* Change `declare_function` and other builder methods to take
PolyFuncType rather than Signature. The input-extensions of the
Signature were only used for the FuncDefn node itself and it seems
saner, and tests pass ok, for the builder to set the input-extensions of
the FuncDefn node itself to pure==empty in all cases. Note this makes
#695 slightly even more severe, but I feel we should just take this into
account when fixing that bug.
* Much of the bulk of the PR is thus changing
(someFunctionType)`.pure()` to `.into()`...
* Restructure validation. Break out the op-type specific checks that
type args were declared, into a hierarchical traversal (rather than
iterating over the Hugr's nodes), so they can gather binders from the
nearest enclosing FuncDefn (potentially several levels above in the
hierarchy). Move `validate_port` (and hence `validate_edge`) here
because ports/edges inside a FuncDefn may have types using its
parameters.
* Also do not allow Static edges to refer to TVs declared outside. (The
Static edge from a polymorphic FuncDefn contains a PolyFuncType i.e.
declaring its own type vars, so they are not free.)
* Update OpDef to properly check_args etc. reflecting that we may need
to pass some args to the binary function and then more args to the
*returned* PolyFuncType.
   * Also drop `impl TypeParametrised for OpDef`
* Add TypeArgs to Call, and builder `fn call()`; also give the latter an
ExtensionRegistry argument
* Finally add `impl From<TypeBound> for TypeParam`, which we probably
should have done much earlier.
  • Loading branch information
acl-cqc authored Nov 21, 2023
1 parent a63487a commit 2f952ab
Show file tree
Hide file tree
Showing 27 changed files with 372 additions and 195 deletions.
9 changes: 7 additions & 2 deletions src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use thiserror::Error;
#[cfg(feature = "pyo3")]
use pyo3::{create_exception, exceptions::PyException, PyErr};

use crate::extension::SignatureError;
use crate::hugr::{HugrError, ValidationError};
use crate::ops::handle::{BasicBlockID, CfgID, ConditionalID, DfgID, FuncID, TailLoopID};
use crate::types::ConstTypeError;
Expand Down Expand Up @@ -43,6 +44,10 @@ pub enum BuildError {
/// The constructed HUGR is invalid.
#[error("The constructed HUGR is invalid: {0}.")]
InvalidHUGR(#[from] ValidationError),
/// SignatureError in trying to construct a node (differs from
/// [ValidationError::SignatureError] in that we could not construct a node to report about)
#[error(transparent)]
SignatureError(#[from] SignatureError),
/// Tried to add a malformed [Const]
///
/// [Const]: crate::ops::constant::Const
Expand Down Expand Up @@ -100,7 +105,7 @@ pub(crate) mod test {

use crate::hugr::{views::HugrView, HugrMut, NodeType};
use crate::ops;
use crate::types::{FunctionType, Signature, Type};
use crate::types::{FunctionType, PolyFuncType, Type};
use crate::{type_row, Hugr};

use super::handle::BuildHandle;
Expand All @@ -123,7 +128,7 @@ pub(crate) mod test {
}

pub(super) fn build_main(
signature: Signature,
signature: PolyFuncType,
f: impl FnOnce(FunctionBuilder<&mut Hugr>) -> Result<BuildHandle<FuncID<true>>, BuildError>,
) -> Result<Hugr, BuildError> {
let mut module_builder = ModuleBuilder::new();
Expand Down
23 changes: 12 additions & 11 deletions src/builder/build_traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use crate::{
};

use crate::extension::{ExtensionRegistry, ExtensionSet, PRELUDE_REGISTRY};
use crate::types::{FunctionType, Signature, Type, TypeRow};
use crate::types::{FunctionType, PolyFuncType, Type, TypeArg, TypeRow};

use itertools::Itertools;

Expand Down Expand Up @@ -90,22 +90,19 @@ pub trait Container {
fn define_function(
&mut self,
name: impl Into<String>,
signature: Signature,
signature: PolyFuncType,
) -> Result<FunctionBuilder<&mut Hugr>, BuildError> {
let body = signature.body.clone();
let f_node = self.add_child_node(NodeType::new(
ops::FuncDefn {
name: name.into(),
signature: signature.signature.clone(),
signature,
},
signature.input_extensions.clone(),
ExtensionSet::new(),
))?;

let db = DFGBuilder::create_with_io(
self.hugr_mut(),
f_node,
signature.signature,
Some(signature.input_extensions),
)?;
let db =
DFGBuilder::create_with_io(self.hugr_mut(), f_node, body, Some(ExtensionSet::new()))?;
Ok(FunctionBuilder::from_dfg_builder(db))
}

Expand Down Expand Up @@ -598,11 +595,14 @@ pub trait Dataflow: Container {
fn call<const DEFINED: bool>(
&mut self,
function: &FuncID<DEFINED>,
type_args: &[TypeArg],
input_wires: impl IntoIterator<Item = Wire>,
// Sadly required as we substituting in type_args may result in recomputing bounds of types:
exts: &ExtensionRegistry,
) -> Result<BuildHandle<DataflowOpID>, BuildError> {
let hugr = self.hugr();
let def_op = hugr.get_optype(function.node());
let signature = match def_op {
let type_scheme = match def_op {
OpType::FuncDefn(ops::FuncDefn { signature, .. })
| OpType::FuncDecl(ops::FuncDecl { signature, .. }) => signature.clone(),
_ => {
Expand All @@ -612,6 +612,7 @@ pub trait Dataflow: Container {
})
}
};
let signature = type_scheme.instantiate(type_args, exts)?;
let op: OpType = ops::Call { signature }.into();
let const_in_port = op.static_input_port().unwrap();
let op_id = self.add_dataflow_op(op, input_wires)?;
Expand Down
2 changes: 1 addition & 1 deletion src/builder/cfg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,7 @@ mod test {
let build_result = {
let mut module_builder = ModuleBuilder::new();
let mut func_builder = module_builder
.define_function("main", FunctionType::new(vec![NAT], type_row![NAT]).pure())?;
.define_function("main", FunctionType::new(vec![NAT], type_row![NAT]).into())?;
let _f_id = {
let [int] = func_builder.input_wires_arr();

Expand Down
4 changes: 2 additions & 2 deletions src/builder/circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ mod test {
#[test]
fn simple_linear() {
let build_res = build_main(
FunctionType::new(type_row![QB, QB], type_row![QB, QB]).pure(),
FunctionType::new(type_row![QB, QB], type_row![QB, QB]).into(),
|mut f_build| {
let wires = f_build.input_wires().collect();

Expand Down Expand Up @@ -184,7 +184,7 @@ mod test {
.into(),
);
let build_res = build_main(
FunctionType::new(type_row![QB, QB, NAT], type_row![QB, QB, BOOL_T]).pure(),
FunctionType::new(type_row![QB, QB, NAT], type_row![QB, QB, BOOL_T]).into(),
|mut f_build| {
let [q0, q1, angle]: [Wire; 3] = f_build.input_wires_arr();

Expand Down
2 changes: 1 addition & 1 deletion src/builder/conditional.rs
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ mod test {
let mut module_builder = ModuleBuilder::new();
let mut fbuild = module_builder.define_function(
"main",
FunctionType::new(type_row![NAT], type_row![NAT]).pure(),
FunctionType::new(type_row![NAT], type_row![NAT]).into(),
)?;
let tru_const = fbuild.add_constant(Const::true_val(), ExtensionSet::new())?;
let _fdef = {
Expand Down
28 changes: 12 additions & 16 deletions src/builder/dataflow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use std::marker::PhantomData;
use crate::hugr::{HugrView, NodeType, ValidationError};
use crate::ops;

use crate::types::{FunctionType, Signature};
use crate::types::{FunctionType, PolyFuncType};

use crate::extension::{ExtensionRegistry, ExtensionSet};
use crate::Node;
Expand Down Expand Up @@ -146,21 +146,17 @@ impl FunctionBuilder<Hugr> {
/// # Errors
///
/// Error in adding DFG child nodes.
pub fn new(name: impl Into<String>, signature: Signature) -> Result<Self, BuildError> {
pub fn new(name: impl Into<String>, signature: PolyFuncType) -> Result<Self, BuildError> {
let body = signature.body.clone();
let op = ops::FuncDefn {
signature: signature.clone().into(),
signature,
name: name.into(),
};

let base = Hugr::new(NodeType::new(op, signature.input_extensions.clone()));
let base = Hugr::new(NodeType::new_pure(op));
let root = base.root();

let db = DFGBuilder::create_with_io(
base,
root,
signature.signature,
Some(signature.input_extensions),
)?;
let db = DFGBuilder::create_with_io(base, root, body, Some(ExtensionSet::new()))?;
Ok(Self::from_dfg_builder(db))
}
}
Expand Down Expand Up @@ -239,7 +235,7 @@ pub(crate) mod test {
let _f_id = {
let mut func_builder = module_builder.define_function(
"main",
FunctionType::new(type_row![NAT, QB], type_row![NAT, QB]).pure(),
FunctionType::new(type_row![NAT, QB], type_row![NAT, QB]).into(),
)?;

let [int, qb] = func_builder.input_wires_arr();
Expand Down Expand Up @@ -273,7 +269,7 @@ pub(crate) mod test {

let f_build = module_builder.define_function(
"main",
FunctionType::new(type_row![BOOL_T], type_row![BOOL_T, BOOL_T]).pure(),
FunctionType::new(type_row![BOOL_T], type_row![BOOL_T, BOOL_T]).into(),
)?;

f(f_build)?;
Expand Down Expand Up @@ -323,7 +319,7 @@ pub(crate) mod test {

let f_build = module_builder.define_function(
"main",
FunctionType::new(type_row![QB], type_row![QB, QB]).pure(),
FunctionType::new(type_row![QB], type_row![QB, QB]).into(),
)?;

let [q1] = f_build.input_wires_arr();
Expand All @@ -340,7 +336,7 @@ pub(crate) mod test {
let builder = || -> Result<Hugr, BuildError> {
let mut f_build = FunctionBuilder::new(
"main",
FunctionType::new(type_row![BIT], type_row![BIT]).pure(),
FunctionType::new(type_row![BIT], type_row![BIT]).into(),
)?;

let [i1] = f_build.input_wires_arr();
Expand All @@ -364,7 +360,7 @@ pub(crate) mod test {
fn error_on_linear_inter_graph_edge() -> Result<(), BuildError> {
let mut f_build = FunctionBuilder::new(
"main",
FunctionType::new(type_row![QB], type_row![QB]).pure(),
FunctionType::new(type_row![QB], type_row![QB]).into(),
)?;

let [i1] = f_build.input_wires_arr();
Expand Down Expand Up @@ -408,7 +404,7 @@ pub(crate) mod test {
let (dfg_node, f_node) = {
let mut f_build = module_builder.define_function(
"main",
FunctionType::new(type_row![BIT], type_row![BIT]).pure(),
FunctionType::new(type_row![BIT], type_row![BIT]).into(),
)?;

let [i1] = f_build.input_wires_arr();
Expand Down
42 changes: 17 additions & 25 deletions src/builder/module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,11 @@ use crate::{
extension::ExtensionRegistry,
hugr::{hugrmut::sealed::HugrMutInternals, views::HugrView, ValidationError},
ops,
types::{Type, TypeBound},
types::{PolyFuncType, Type, TypeBound},
};

use crate::ops::handle::{AliasID, FuncID, NodeHandle};

use crate::types::Signature;

use crate::Node;
use smol_str::SmolStr;

Expand Down Expand Up @@ -86,16 +84,13 @@ impl<T: AsMut<Hugr> + AsRef<Hugr>> ModuleBuilder<T> {
op_desc: "crate::ops::OpType::FuncDecl",
})?
.clone();

let body = signature.body.clone();
self.hugr_mut().replace_op(
f_node,
NodeType::new_pure(ops::FuncDefn {
name,
signature: signature.clone(),
}),
NodeType::new_pure(ops::FuncDefn { name, signature }),
)?;

let db = DFGBuilder::create_with_io(self.hugr_mut(), f_node, signature, None)?;
let db = DFGBuilder::create_with_io(self.hugr_mut(), f_node, body, None)?;
Ok(FunctionBuilder::from_dfg_builder(db))
}

Expand All @@ -108,17 +103,13 @@ impl<T: AsMut<Hugr> + AsRef<Hugr>> ModuleBuilder<T> {
pub fn declare(
&mut self,
name: impl Into<String>,
signature: Signature,
signature: PolyFuncType,
) -> Result<FuncID<false>, BuildError> {
// TODO add param names to metadata
let rs = signature.input_extensions.clone();
let declare_n = self.add_child_node(NodeType::new(
ops::FuncDecl {
signature: signature.into(),
name: name.into(),
},
rs,
))?;
let declare_n = self.add_child_node(NodeType::new_pure(ops::FuncDecl {
signature,
name: name.into(),
}))?;

Ok(declare_n.into())
}
Expand Down Expand Up @@ -176,7 +167,7 @@ mod test {
test::{n_identity, NAT},
Dataflow, DataflowSubContainer,
},
extension::EMPTY_REG,
extension::{EMPTY_REG, PRELUDE_REGISTRY},
type_row,
types::FunctionType,
};
Expand All @@ -189,11 +180,11 @@ mod test {

let f_id = module_builder.declare(
"main",
FunctionType::new(type_row![NAT], type_row![NAT]).pure(),
FunctionType::new(type_row![NAT], type_row![NAT]).into(),
)?;

let mut f_build = module_builder.define_declaration(&f_id)?;
let call = f_build.call(&f_id, f_build.input_wires())?;
let call = f_build.call(&f_id, &[], f_build.input_wires(), &PRELUDE_REGISTRY)?;

f_build.finish_with_outputs(call.outputs())?;
module_builder.finish_prelude_hugr()
Expand All @@ -216,7 +207,7 @@ mod test {
vec![qubit_state_type.get_alias_type()],
vec![qubit_state_type.get_alias_type()],
)
.pure(),
.into(),
)?;
n_identity(f_build)?;
module_builder.finish_hugr(&EMPTY_REG)
Expand All @@ -232,16 +223,17 @@ mod test {

let mut f_build = module_builder.define_function(
"main",
FunctionType::new(type_row![NAT], type_row![NAT, NAT]).pure(),
FunctionType::new(type_row![NAT], type_row![NAT, NAT]).into(),
)?;
let local_build = f_build.define_function(
"local",
FunctionType::new(type_row![NAT], type_row![NAT, NAT]).pure(),
FunctionType::new(type_row![NAT], type_row![NAT, NAT]).into(),
)?;
let [wire] = local_build.input_wires_arr();
let f_id = local_build.finish_with_outputs([wire, wire])?;

let call = f_build.call(f_id.handle(), f_build.input_wires())?;
let call =
f_build.call(f_id.handle(), &[], f_build.input_wires(), &PRELUDE_REGISTRY)?;

f_build.finish_with_outputs(call.outputs())?;
module_builder.finish_prelude_hugr()
Expand Down
13 changes: 11 additions & 2 deletions src/builder/tail_loop.rs
Original file line number Diff line number Diff line change
Expand Up @@ -130,10 +130,19 @@ mod test {
let mut fbuild = module_builder.define_function(
"main",
FunctionType::new(type_row![BIT], type_row![NAT])
.with_input_extensions(ExtensionSet::singleton(&PRELUDE_ID)),
.with_extension_delta(&ExtensionSet::singleton(&PRELUDE_ID))
.into(),
)?;
let _fdef = {
let [b1] = fbuild.input_wires_arr();
let [b1] = fbuild
.add_dataflow_op(
ops::LeafOp::Lift {
type_row: type_row![BIT],
new_extension: PRELUDE_ID,
},
fbuild.input_wires(),
)?
.outputs_arr();
let loop_id = {
let mut loop_b =
fbuild.tail_loop_builder(vec![(BIT, b1)], vec![], type_row![NAT])?;
Expand Down
6 changes: 3 additions & 3 deletions src/extension/infer/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -955,7 +955,7 @@ fn simple_funcdefn() -> Result<(), Box<dyn Error>> {
"F",
FunctionType::new(vec![NAT], vec![NAT])
.with_extension_delta(&ExtensionSet::singleton(&A))
.pure(),
.into(),
)?;

let [w] = func_builder.input_wires_arr();
Expand All @@ -979,7 +979,7 @@ fn funcdefn_signature_mismatch() -> Result<(), Box<dyn Error>> {
"F",
FunctionType::new(vec![NAT], vec![NAT])
.with_extension_delta(&ExtensionSet::singleton(&A))
.pure(),
.into(),
)?;

let [w] = func_builder.input_wires_arr();
Expand Down Expand Up @@ -1013,7 +1013,7 @@ fn funcdefn_signature_mismatch2() -> Result<(), Box<dyn Error>> {
"F",
FunctionType::new(vec![NAT], vec![NAT])
.with_extension_delta(&ExtensionSet::singleton(&A))
.pure(),
.into(),
)?;

let [w] = func_builder.input_wires_arr();
Expand Down
Loading

0 comments on commit 2f952ab

Please sign in to comment.