From 2f952aba3d7064b91544195bf1160b6520f1e885 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 21 Nov 2023 17:05:11 +0000 Subject: [PATCH] feat: make FuncDecl/FuncDefn polymorphic (#692) * Change FuncDecl and FuncDefn from storing FunctionType to PolyFuncType * Change `declare_function` and other builder methods to take PolyFuncType rather than Signature. The input-extensions of the Signature were only used for the FuncDefn node itself and it seems saner, and tests pass ok, for the builder to set the input-extensions of the FuncDefn node itself to pure==empty in all cases. Note this makes #695 slightly even more severe, but I feel we should just take this into account when fixing that bug. * Much of the bulk of the PR is thus changing (someFunctionType)`.pure()` to `.into()`... * Restructure validation. Break out the op-type specific checks that type args were declared, into a hierarchical traversal (rather than iterating over the Hugr's nodes), so they can gather binders from the nearest enclosing FuncDefn (potentially several levels above in the hierarchy). Move `validate_port` (and hence `validate_edge`) here because ports/edges inside a FuncDefn may have types using its parameters. * Also do not allow Static edges to refer to TVs declared outside. (The Static edge from a polymorphic FuncDefn contains a PolyFuncType i.e. declaring its own type vars, so they are not free.) * Update OpDef to properly check_args etc. reflecting that we may need to pass some args to the binary function and then more args to the *returned* PolyFuncType. * Also drop `impl TypeParametrised for OpDef` * Add TypeArgs to Call, and builder `fn call()`; also give the latter an ExtensionRegistry argument * Finally add `impl From for TypeParam`, which we probably should have done much earlier. --- src/builder.rs | 9 +- src/builder/build_traits.rs | 23 ++--- src/builder/cfg.rs | 2 +- src/builder/circuit.rs | 4 +- src/builder/conditional.rs | 2 +- src/builder/dataflow.rs | 28 +++--- src/builder/module.rs | 42 ++++----- src/builder/tail_loop.rs | 13 ++- src/extension/infer/test.rs | 6 +- src/extension/op_def.rs | 50 ++++++----- src/hugr/hugrmut.rs | 2 +- src/hugr/rewrite/outline_cfg.rs | 2 +- src/hugr/rewrite/simple_replace.rs | 2 +- src/hugr/serialize.rs | 4 +- src/hugr/validate.rs | 137 +++++++++++++++++++---------- src/hugr/validate/test.rs | 133 ++++++++++++++++++++++++++-- src/hugr/views.rs | 10 +-- src/hugr/views/descendants.rs | 6 +- src/hugr/views/sibling.rs | 8 +- src/hugr/views/sibling_subgraph.rs | 23 ++--- src/hugr/views/tests.rs | 15 +++- src/ops/module.rs | 6 +- src/ops/validate.rs | 12 ++- src/std_extensions/collections.rs | 8 ++ src/types/check.rs | 2 +- src/types/poly_func.rs | 12 +-- src/types/type_param.rs | 6 ++ 27 files changed, 372 insertions(+), 195 deletions(-) diff --git a/src/builder.rs b/src/builder.rs index 5e437bcbc..c24913303 100644 --- a/src/builder.rs +++ b/src/builder.rs @@ -5,6 +5,7 @@ use thiserror::Error; #[cfg(feature = "pyo3")] use pyo3::{create_exception, exceptions::PyException, PyErr}; +use crate::extension::SignatureError; use crate::hugr::{HugrError, ValidationError}; use crate::ops::handle::{BasicBlockID, CfgID, ConditionalID, DfgID, FuncID, TailLoopID}; use crate::types::ConstTypeError; @@ -43,6 +44,10 @@ pub enum BuildError { /// The constructed HUGR is invalid. #[error("The constructed HUGR is invalid: {0}.")] InvalidHUGR(#[from] ValidationError), + /// SignatureError in trying to construct a node (differs from + /// [ValidationError::SignatureError] in that we could not construct a node to report about) + #[error(transparent)] + SignatureError(#[from] SignatureError), /// Tried to add a malformed [Const] /// /// [Const]: crate::ops::constant::Const @@ -100,7 +105,7 @@ pub(crate) mod test { use crate::hugr::{views::HugrView, HugrMut, NodeType}; use crate::ops; - use crate::types::{FunctionType, Signature, Type}; + use crate::types::{FunctionType, PolyFuncType, Type}; use crate::{type_row, Hugr}; use super::handle::BuildHandle; @@ -123,7 +128,7 @@ pub(crate) mod test { } pub(super) fn build_main( - signature: Signature, + signature: PolyFuncType, f: impl FnOnce(FunctionBuilder<&mut Hugr>) -> Result>, BuildError>, ) -> Result { let mut module_builder = ModuleBuilder::new(); diff --git a/src/builder/build_traits.rs b/src/builder/build_traits.rs index e1435fffe..9f467426a 100644 --- a/src/builder/build_traits.rs +++ b/src/builder/build_traits.rs @@ -20,7 +20,7 @@ use crate::{ }; use crate::extension::{ExtensionRegistry, ExtensionSet, PRELUDE_REGISTRY}; -use crate::types::{FunctionType, Signature, Type, TypeRow}; +use crate::types::{FunctionType, PolyFuncType, Type, TypeArg, TypeRow}; use itertools::Itertools; @@ -90,22 +90,19 @@ pub trait Container { fn define_function( &mut self, name: impl Into, - signature: Signature, + signature: PolyFuncType, ) -> Result, BuildError> { + let body = signature.body.clone(); let f_node = self.add_child_node(NodeType::new( ops::FuncDefn { name: name.into(), - signature: signature.signature.clone(), + signature, }, - signature.input_extensions.clone(), + ExtensionSet::new(), ))?; - let db = DFGBuilder::create_with_io( - self.hugr_mut(), - f_node, - signature.signature, - Some(signature.input_extensions), - )?; + let db = + DFGBuilder::create_with_io(self.hugr_mut(), f_node, body, Some(ExtensionSet::new()))?; Ok(FunctionBuilder::from_dfg_builder(db)) } @@ -598,11 +595,14 @@ pub trait Dataflow: Container { fn call( &mut self, function: &FuncID, + type_args: &[TypeArg], input_wires: impl IntoIterator, + // Sadly required as we substituting in type_args may result in recomputing bounds of types: + exts: &ExtensionRegistry, ) -> Result, BuildError> { let hugr = self.hugr(); let def_op = hugr.get_optype(function.node()); - let signature = match def_op { + let type_scheme = match def_op { OpType::FuncDefn(ops::FuncDefn { signature, .. }) | OpType::FuncDecl(ops::FuncDecl { signature, .. }) => signature.clone(), _ => { @@ -612,6 +612,7 @@ pub trait Dataflow: Container { }) } }; + let signature = type_scheme.instantiate(type_args, exts)?; let op: OpType = ops::Call { signature }.into(); let const_in_port = op.static_input_port().unwrap(); let op_id = self.add_dataflow_op(op, input_wires)?; diff --git a/src/builder/cfg.rs b/src/builder/cfg.rs index fe2954bd5..99781ea2c 100644 --- a/src/builder/cfg.rs +++ b/src/builder/cfg.rs @@ -336,7 +336,7 @@ mod test { let build_result = { let mut module_builder = ModuleBuilder::new(); let mut func_builder = module_builder - .define_function("main", FunctionType::new(vec![NAT], type_row![NAT]).pure())?; + .define_function("main", FunctionType::new(vec![NAT], type_row![NAT]).into())?; let _f_id = { let [int] = func_builder.input_wires_arr(); diff --git a/src/builder/circuit.rs b/src/builder/circuit.rs index 043945ff9..054f839fa 100644 --- a/src/builder/circuit.rs +++ b/src/builder/circuit.rs @@ -147,7 +147,7 @@ mod test { #[test] fn simple_linear() { let build_res = build_main( - FunctionType::new(type_row![QB, QB], type_row![QB, QB]).pure(), + FunctionType::new(type_row![QB, QB], type_row![QB, QB]).into(), |mut f_build| { let wires = f_build.input_wires().collect(); @@ -184,7 +184,7 @@ mod test { .into(), ); let build_res = build_main( - FunctionType::new(type_row![QB, QB, NAT], type_row![QB, QB, BOOL_T]).pure(), + FunctionType::new(type_row![QB, QB, NAT], type_row![QB, QB, BOOL_T]).into(), |mut f_build| { let [q0, q1, angle]: [Wire; 3] = f_build.input_wires_arr(); diff --git a/src/builder/conditional.rs b/src/builder/conditional.rs index 895bdf552..0238d14a4 100644 --- a/src/builder/conditional.rs +++ b/src/builder/conditional.rs @@ -240,7 +240,7 @@ mod test { let mut module_builder = ModuleBuilder::new(); let mut fbuild = module_builder.define_function( "main", - FunctionType::new(type_row![NAT], type_row![NAT]).pure(), + FunctionType::new(type_row![NAT], type_row![NAT]).into(), )?; let tru_const = fbuild.add_constant(Const::true_val(), ExtensionSet::new())?; let _fdef = { diff --git a/src/builder/dataflow.rs b/src/builder/dataflow.rs index 3ef28ac8d..1db768b52 100644 --- a/src/builder/dataflow.rs +++ b/src/builder/dataflow.rs @@ -7,7 +7,7 @@ use std::marker::PhantomData; use crate::hugr::{HugrView, NodeType, ValidationError}; use crate::ops; -use crate::types::{FunctionType, Signature}; +use crate::types::{FunctionType, PolyFuncType}; use crate::extension::{ExtensionRegistry, ExtensionSet}; use crate::Node; @@ -146,21 +146,17 @@ impl FunctionBuilder { /// # Errors /// /// Error in adding DFG child nodes. - pub fn new(name: impl Into, signature: Signature) -> Result { + pub fn new(name: impl Into, signature: PolyFuncType) -> Result { + let body = signature.body.clone(); let op = ops::FuncDefn { - signature: signature.clone().into(), + signature, name: name.into(), }; - let base = Hugr::new(NodeType::new(op, signature.input_extensions.clone())); + let base = Hugr::new(NodeType::new_pure(op)); let root = base.root(); - let db = DFGBuilder::create_with_io( - base, - root, - signature.signature, - Some(signature.input_extensions), - )?; + let db = DFGBuilder::create_with_io(base, root, body, Some(ExtensionSet::new()))?; Ok(Self::from_dfg_builder(db)) } } @@ -239,7 +235,7 @@ pub(crate) mod test { let _f_id = { let mut func_builder = module_builder.define_function( "main", - FunctionType::new(type_row![NAT, QB], type_row![NAT, QB]).pure(), + FunctionType::new(type_row![NAT, QB], type_row![NAT, QB]).into(), )?; let [int, qb] = func_builder.input_wires_arr(); @@ -273,7 +269,7 @@ pub(crate) mod test { let f_build = module_builder.define_function( "main", - FunctionType::new(type_row![BOOL_T], type_row![BOOL_T, BOOL_T]).pure(), + FunctionType::new(type_row![BOOL_T], type_row![BOOL_T, BOOL_T]).into(), )?; f(f_build)?; @@ -323,7 +319,7 @@ pub(crate) mod test { let f_build = module_builder.define_function( "main", - FunctionType::new(type_row![QB], type_row![QB, QB]).pure(), + FunctionType::new(type_row![QB], type_row![QB, QB]).into(), )?; let [q1] = f_build.input_wires_arr(); @@ -340,7 +336,7 @@ pub(crate) mod test { let builder = || -> Result { let mut f_build = FunctionBuilder::new( "main", - FunctionType::new(type_row![BIT], type_row![BIT]).pure(), + FunctionType::new(type_row![BIT], type_row![BIT]).into(), )?; let [i1] = f_build.input_wires_arr(); @@ -364,7 +360,7 @@ pub(crate) mod test { fn error_on_linear_inter_graph_edge() -> Result<(), BuildError> { let mut f_build = FunctionBuilder::new( "main", - FunctionType::new(type_row![QB], type_row![QB]).pure(), + FunctionType::new(type_row![QB], type_row![QB]).into(), )?; let [i1] = f_build.input_wires_arr(); @@ -408,7 +404,7 @@ pub(crate) mod test { let (dfg_node, f_node) = { let mut f_build = module_builder.define_function( "main", - FunctionType::new(type_row![BIT], type_row![BIT]).pure(), + FunctionType::new(type_row![BIT], type_row![BIT]).into(), )?; let [i1] = f_build.input_wires_arr(); diff --git a/src/builder/module.rs b/src/builder/module.rs index 3fbce55c4..09ac47da3 100644 --- a/src/builder/module.rs +++ b/src/builder/module.rs @@ -8,13 +8,11 @@ use crate::{ extension::ExtensionRegistry, hugr::{hugrmut::sealed::HugrMutInternals, views::HugrView, ValidationError}, ops, - types::{Type, TypeBound}, + types::{PolyFuncType, Type, TypeBound}, }; use crate::ops::handle::{AliasID, FuncID, NodeHandle}; -use crate::types::Signature; - use crate::Node; use smol_str::SmolStr; @@ -86,16 +84,13 @@ impl + AsRef> ModuleBuilder { op_desc: "crate::ops::OpType::FuncDecl", })? .clone(); - + let body = signature.body.clone(); self.hugr_mut().replace_op( f_node, - NodeType::new_pure(ops::FuncDefn { - name, - signature: signature.clone(), - }), + NodeType::new_pure(ops::FuncDefn { name, signature }), )?; - let db = DFGBuilder::create_with_io(self.hugr_mut(), f_node, signature, None)?; + let db = DFGBuilder::create_with_io(self.hugr_mut(), f_node, body, None)?; Ok(FunctionBuilder::from_dfg_builder(db)) } @@ -108,17 +103,13 @@ impl + AsRef> ModuleBuilder { pub fn declare( &mut self, name: impl Into, - signature: Signature, + signature: PolyFuncType, ) -> Result, BuildError> { // TODO add param names to metadata - let rs = signature.input_extensions.clone(); - let declare_n = self.add_child_node(NodeType::new( - ops::FuncDecl { - signature: signature.into(), - name: name.into(), - }, - rs, - ))?; + let declare_n = self.add_child_node(NodeType::new_pure(ops::FuncDecl { + signature, + name: name.into(), + }))?; Ok(declare_n.into()) } @@ -176,7 +167,7 @@ mod test { test::{n_identity, NAT}, Dataflow, DataflowSubContainer, }, - extension::EMPTY_REG, + extension::{EMPTY_REG, PRELUDE_REGISTRY}, type_row, types::FunctionType, }; @@ -189,11 +180,11 @@ mod test { let f_id = module_builder.declare( "main", - FunctionType::new(type_row![NAT], type_row![NAT]).pure(), + FunctionType::new(type_row![NAT], type_row![NAT]).into(), )?; let mut f_build = module_builder.define_declaration(&f_id)?; - let call = f_build.call(&f_id, f_build.input_wires())?; + let call = f_build.call(&f_id, &[], f_build.input_wires(), &PRELUDE_REGISTRY)?; f_build.finish_with_outputs(call.outputs())?; module_builder.finish_prelude_hugr() @@ -216,7 +207,7 @@ mod test { vec![qubit_state_type.get_alias_type()], vec![qubit_state_type.get_alias_type()], ) - .pure(), + .into(), )?; n_identity(f_build)?; module_builder.finish_hugr(&EMPTY_REG) @@ -232,16 +223,17 @@ mod test { let mut f_build = module_builder.define_function( "main", - FunctionType::new(type_row![NAT], type_row![NAT, NAT]).pure(), + FunctionType::new(type_row![NAT], type_row![NAT, NAT]).into(), )?; let local_build = f_build.define_function( "local", - FunctionType::new(type_row![NAT], type_row![NAT, NAT]).pure(), + FunctionType::new(type_row![NAT], type_row![NAT, NAT]).into(), )?; let [wire] = local_build.input_wires_arr(); let f_id = local_build.finish_with_outputs([wire, wire])?; - let call = f_build.call(f_id.handle(), f_build.input_wires())?; + let call = + f_build.call(f_id.handle(), &[], f_build.input_wires(), &PRELUDE_REGISTRY)?; f_build.finish_with_outputs(call.outputs())?; module_builder.finish_prelude_hugr() diff --git a/src/builder/tail_loop.rs b/src/builder/tail_loop.rs index 8ea5113ae..9ab71182b 100644 --- a/src/builder/tail_loop.rs +++ b/src/builder/tail_loop.rs @@ -130,10 +130,19 @@ mod test { let mut fbuild = module_builder.define_function( "main", FunctionType::new(type_row![BIT], type_row![NAT]) - .with_input_extensions(ExtensionSet::singleton(&PRELUDE_ID)), + .with_extension_delta(&ExtensionSet::singleton(&PRELUDE_ID)) + .into(), )?; let _fdef = { - let [b1] = fbuild.input_wires_arr(); + let [b1] = fbuild + .add_dataflow_op( + ops::LeafOp::Lift { + type_row: type_row![BIT], + new_extension: PRELUDE_ID, + }, + fbuild.input_wires(), + )? + .outputs_arr(); let loop_id = { let mut loop_b = fbuild.tail_loop_builder(vec![(BIT, b1)], vec![], type_row![NAT])?; diff --git a/src/extension/infer/test.rs b/src/extension/infer/test.rs index c54f38206..d8fb8f16b 100644 --- a/src/extension/infer/test.rs +++ b/src/extension/infer/test.rs @@ -955,7 +955,7 @@ fn simple_funcdefn() -> Result<(), Box> { "F", FunctionType::new(vec![NAT], vec![NAT]) .with_extension_delta(&ExtensionSet::singleton(&A)) - .pure(), + .into(), )?; let [w] = func_builder.input_wires_arr(); @@ -979,7 +979,7 @@ fn funcdefn_signature_mismatch() -> Result<(), Box> { "F", FunctionType::new(vec![NAT], vec![NAT]) .with_extension_delta(&ExtensionSet::singleton(&A)) - .pure(), + .into(), )?; let [w] = func_builder.input_wires_arr(); @@ -1013,7 +1013,7 @@ fn funcdefn_signature_mismatch2() -> Result<(), Box> { "F", FunctionType::new(vec![NAT], vec![NAT]) .with_extension_delta(&ExtensionSet::singleton(&A)) - .pure(), + .into(), )?; let [w] = func_builder.input_wires_arr(); diff --git a/src/extension/op_def.rs b/src/extension/op_def.rs index b9107acba..9dc2fcd18 100644 --- a/src/extension/op_def.rs +++ b/src/extension/op_def.rs @@ -8,10 +8,8 @@ use smol_str::SmolStr; use super::{ Extension, ExtensionBuildError, ExtensionId, ExtensionRegistry, ExtensionSet, SignatureError, - TypeParametrised, }; -use crate::ops::custom::OpaqueOp; use crate::types::type_param::{check_type_args, TypeArg, TypeParam}; use crate::types::{FunctionType, PolyFuncType}; use crate::Hugr; @@ -143,26 +141,36 @@ pub struct OpDef { lower_funcs: Vec, } -impl TypeParametrised for OpDef { - type Concrete = OpaqueOp; - - fn params(&self) -> &[TypeParam] { - self.params() - } - - fn name(&self) -> &SmolStr { - self.name() - } - - fn extension(&self) -> &ExtensionId { - self.extension() - } -} - impl OpDef { - /// Check provided type arguments are valid against parameters. - pub fn check_args(&self, args: &[TypeArg]) -> Result<(), SignatureError> { - self.check_args_impl(args) + /// Check provided type arguments are valid against [ExtensionRegistry], + /// against parameters, and that no type variables are used as static arguments + /// (to [compute_signature][CustomSignatureFunc::compute_signature]) + pub fn validate_args( + &self, + args: &[TypeArg], + exts: &ExtensionRegistry, + var_decls: &[TypeParam], + ) -> Result<(), SignatureError> { + let temp: PolyFuncType; // to keep alive + let (pf, args) = match &self.signature_func { + SignatureFunc::TypeScheme(ts) => (ts, args), + SignatureFunc::CustomFunc { + static_params, + func, + } => { + let (static_args, other_args) = args.split_at(min(static_params.len(), args.len())); + static_args + .iter() + .try_for_each(|ta| ta.validate(exts, &[]))?; + check_type_args(static_args, static_params)?; + temp = func.compute_signature(&self.name, static_args, &self.misc, exts)?; + (&temp, other_args) + } + }; + args.iter() + .try_for_each(|ta| ta.validate(exts, var_decls))?; + check_type_args(args, pf.params())?; + Ok(()) } /// Computes the signature of a node, i.e. an instantiation of this diff --git a/src/hugr/hugrmut.rs b/src/hugr/hugrmut.rs index f549929d0..1f5dd6d30 100644 --- a/src/hugr/hugrmut.rs +++ b/src/hugr/hugrmut.rs @@ -636,7 +636,7 @@ mod test { module, ops::FuncDefn { name: "main".into(), - signature: FunctionType::new(type_row![NAT], type_row![NAT, NAT]), + signature: FunctionType::new(type_row![NAT], type_row![NAT, NAT]).into(), }, ) .expect("Failed to add function definition node"); diff --git a/src/hugr/rewrite/outline_cfg.rs b/src/hugr/rewrite/outline_cfg.rs index 3cd5ad2f5..d0640048a 100644 --- a/src/hugr/rewrite/outline_cfg.rs +++ b/src/hugr/rewrite/outline_cfg.rs @@ -362,7 +362,7 @@ mod test { let mut fbuild = module_builder .define_function( "main", - FunctionType::new(type_row![USIZE_T], type_row![USIZE_T]).pure(), + FunctionType::new(type_row![USIZE_T], type_row![USIZE_T]).into(), ) .unwrap(); let [i1] = fbuild.input_wires_arr(); diff --git a/src/hugr/rewrite/simple_replace.rs b/src/hugr/rewrite/simple_replace.rs index b02dacafd..2f14dee3e 100644 --- a/src/hugr/rewrite/simple_replace.rs +++ b/src/hugr/rewrite/simple_replace.rs @@ -251,7 +251,7 @@ pub(in crate::hugr::rewrite) mod test { let _f_id = { let mut func_builder = module_builder.define_function( "main", - FunctionType::new(type_row![QB, QB, QB], type_row![QB, QB, QB]).pure(), + FunctionType::new(type_row![QB, QB, QB], type_row![QB, QB, QB]).into(), )?; let [qb0, qb1, qb2] = func_builder.input_wires_arr(); diff --git a/src/hugr/serialize.rs b/src/hugr/serialize.rs index b1f99a4a2..4b7da1a3b 100644 --- a/src/hugr/serialize.rs +++ b/src/hugr/serialize.rs @@ -363,7 +363,7 @@ pub mod test { let t_row = vec![Type::new_sum(vec![NAT, QB])]; let mut f_build = module_builder - .define_function("main", FunctionType::new(t_row.clone(), t_row).pure()) + .define_function("main", FunctionType::new(t_row.clone(), t_row).into()) .unwrap(); let outputs = f_build @@ -398,7 +398,7 @@ pub mod test { let mut module_builder = ModuleBuilder::new(); let t_row = vec![Type::new_sum(vec![NAT, QB])]; let mut f_build = module_builder - .define_function("main", FunctionType::new(t_row.clone(), t_row).pure()) + .define_function("main", FunctionType::new(t_row.clone(), t_row).into()) .unwrap(); let outputs = f_build diff --git a/src/hugr/validate.rs b/src/hugr/validate.rs index 2725c5e25..bf0021ad8 100644 --- a/src/hugr/validate.rs +++ b/src/hugr/validate.rs @@ -19,9 +19,10 @@ use crate::extension::{ }; use crate::ops::custom::CustomOpError; -use crate::ops::custom::{resolve_opaque_op, ExtensionOp, ExternalOp}; +use crate::ops::custom::{resolve_opaque_op, ExternalOp}; use crate::ops::validate::{ChildrenEdgeData, ChildrenValidationError, EdgeValidationError}; -use crate::ops::{OpTag, OpTrait, OpType, ValidateOp}; +use crate::ops::{FuncDefn, OpTag, OpTrait, OpType, ValidateOp}; +use crate::types::type_param::TypeParam; use crate::types::{EdgeKind, Type}; use crate::{Direction, Hugr, Node, Port}; @@ -93,6 +94,9 @@ impl<'a, 'b> ValidationContext<'a, 'b> { self.validate_node(node)?; } + // Hierarchy and children. No type variables declared outside the root. + self.validate_subtree(self.hugr.root(), &[])?; + Ok(()) } @@ -112,7 +116,7 @@ impl<'a, 'b> ValidationContext<'a, 'b> { /// This includes: /// - Matching the number of ports with the signature /// - Dataflow ports are correct. See `validate_df_port` - fn validate_node(&mut self, node: Node) -> Result<(), ValidationError> { + fn validate_node(&self, node: Node) -> Result<(), ValidationError> { let node_type = self.hugr.get_nodetype(node); let op_type = &node_type.op; @@ -154,49 +158,9 @@ impl<'a, 'b> ValidationContext<'a, 'b> { dir, }); } - - // Check port connections - for (i, port_index) in self.hugr.graph.ports(node.pg_index(), dir).enumerate() { - let port = Port::new(dir, i); - self.validate_port(node, port, port_index, op_type)?; - } } } - // Check operation-specific constraints. - // TODO make a separate method for this (perhaps producing Result<(), SignatureError>) - match op_type { - OpType::LeafOp(crate::ops::LeafOp::CustomOp(b)) => { - // Check TypeArgs are valid (in themselves, not necessarily wrt the TypeParams) - for arg in b.args() { - // Hugrs are monomorphic, so no type variables in scope - arg.validate(self.extension_registry, &[]) - .map_err(|cause| ValidationError::SignatureError { node, cause })?; - } - // Try to resolve serialized names to actual OpDefs in Extensions. - let e: Option; - let ext_op = match &**b { - ExternalOp::Opaque(op) => { - // If resolve_extension_ops has been called first, this would always return Ok(None) - e = resolve_opaque_op(node, op, self.extension_registry)?; - e.as_ref() - } - ExternalOp::Extension(ext) => Some(ext), - }; - // If successful, check TypeArgs are valid for the declared TypeParams - if let Some(ext_op) = ext_op { - ext_op - .def() - .check_args(ext_op.args()) - .map_err(|cause| ValidationError::SignatureError { node, cause })?; - } - } - OpType::LeafOp(crate::ops::LeafOp::TypeApply { ta }) => { - ta.validate(self.extension_registry) - .map_err(|cause| ValidationError::SignatureError { node, cause })?; - } - _ => (), - } // Secondly that the node has correct children self.validate_children(node, node_type)?; @@ -223,6 +187,7 @@ impl<'a, 'b> ValidationContext<'a, 'b> { port: Port, port_index: portgraph::PortIndex, op_type: &OpType, + var_decls: &[TypeParam], ) -> Result<(), ValidationError> { let port_kind = op_type.port_kind(port).unwrap(); let dir = port.direction(); @@ -253,8 +218,13 @@ impl<'a, 'b> ValidationContext<'a, 'b> { } match &port_kind { - EdgeKind::Value(ty) | EdgeKind::Static(ty) => ty - .validate(self.extension_registry, &[]) // no type vars inside the Hugr + EdgeKind::Value(ty) => ty + .validate(self.extension_registry, var_decls) + .map_err(|cause| ValidationError::SignatureError { node, cause })?, + // Static edges must *not* refer to type variables declared by enclosing FuncDefns + // as these are only types at runtime. + EdgeKind::Static(ty) => ty + .validate(self.extension_registry, &[]) .map_err(|cause| ValidationError::SignatureError { node, cause })?, _ => (), } @@ -278,9 +248,7 @@ impl<'a, 'b> ValidationContext<'a, 'b> { let other_op = self.hugr.get_optype(other_node); let Some(other_kind) = other_op.port_kind(other_offset) else { - // The number of ports in `other_node` does not match the operation definition. - // This should be caught by `validate_node`. - return Err(self.validate_node(other_node).unwrap_err()); + panic!("The number of ports in {other_node} does not match the operation definition. This should have been caught by `validate_node`."); }; // TODO: We will require some "unifiable" comparison instead of strict equality, to allow for pre-type inference hugrs. if other_kind != port_kind { @@ -558,6 +526,79 @@ impl<'a, 'b> ValidationContext<'a, 'b> { } .into()) } + + /// Validates that TypeArgs are valid wrt the [ExtensionRegistry] and that nodes + /// only refer to type variables declared by the closest enclosing FuncDefn. + fn validate_subtree( + &mut self, + node: Node, + var_decls: &[TypeParam], + ) -> Result<(), ValidationError> { + let op_type = self.hugr.get_optype(node); + // The op_type must be defined only in terms of type variables defined outside the node + // TODO consider turning this match into a trait method? + match op_type { + OpType::LeafOp(crate::ops::LeafOp::CustomOp(b)) => { + // Try to resolve serialized names to actual OpDefs in Extensions. + let temp: ExternalOp; + let resolved = match &**b { + ExternalOp::Opaque(op) => { + // If resolve_extension_ops has been called first, this would always return Ok(None) + match resolve_opaque_op(node, op, self.extension_registry)? { + Some(exten) => { + temp = ExternalOp::Extension(exten); + &temp + } + None => &**b, + } + } + ExternalOp::Extension(_) => &**b, + }; + // Check TypeArgs are valid, and if we can, fit the declared TypeParams + match resolved { + ExternalOp::Extension(exten) => exten + .def() + .validate_args(exten.args(), self.extension_registry, var_decls) + .map_err(|cause| ValidationError::SignatureError { node, cause })?, + ExternalOp::Opaque(opaq) => { + // Best effort. Just check TypeArgs are valid in themselves, allowing any of them + // to contain type vars (we don't know how many are binary params, so accept if in doubt) + for arg in opaq.args() { + arg.validate(self.extension_registry, var_decls) + .map_err(|cause| ValidationError::SignatureError { node, cause })?; + } + } + } + } + OpType::LeafOp(crate::ops::LeafOp::TypeApply { ta }) => { + ta.validate(self.extension_registry) + .map_err(|cause| ValidationError::SignatureError { node, cause })?; + } + _ => (), + } + + // Check port connections. + for dir in Direction::BOTH { + for (i, port_index) in self.hugr.graph.ports(node.pg_index(), dir).enumerate() { + let port = Port::new(dir, i); + self.validate_port(node, port, port_index, op_type, var_decls)?; + } + } + + // For FuncDefn's, only the type variables declared by the FuncDefn can be referred to by nodes + // inside the function. (The same would be true for FuncDecl's, but they have no child nodes.) + let var_decls = if let OpType::FuncDefn(FuncDefn { signature, .. }) = op_type { + signature.params() + } else { + var_decls + }; + + for child in self.hugr.children(node) { + self.validate_subtree(child, var_decls)?; + } + + Ok(()) + } } /// Errors that can occur while validating a Hugr. diff --git a/src/hugr/validate/test.rs b/src/hugr/validate/test.rs index c72c90615..8a59d8a45 100644 --- a/src/hugr/validate/test.rs +++ b/src/hugr/validate/test.rs @@ -2,7 +2,10 @@ use cool_asserts::assert_matches; use super::*; use crate::builder::test::closed_dfg_root_hugr; -use crate::builder::{BuildError, Container, Dataflow, DataflowSubContainer, ModuleBuilder}; +use crate::builder::{ + BuildError, Container, Dataflow, DataflowHugr, DataflowSubContainer, FunctionBuilder, + ModuleBuilder, +}; use crate::extension::prelude::{BOOL_T, PRELUDE, USIZE_T}; use crate::extension::{ Extension, ExtensionId, ExtensionSet, TypeDefBound, EMPTY_REG, PRELUDE_REGISTRY, @@ -11,11 +14,12 @@ use crate::hugr::hugrmut::sealed::HugrMutInternals; use crate::hugr::{HugrError, HugrMut, NodeType}; use crate::macros::const_extension_ids; use crate::ops::dataflow::IOTrait; -use crate::ops::{self, LeafOp, OpType}; +use crate::ops::{self, Const, LeafOp, OpType}; use crate::std_extensions::logic; use crate::std_extensions::logic::test::{and_op, not_op, or_op}; use crate::types::type_param::{TypeArg, TypeArgError, TypeParam}; -use crate::types::{CustomType, FunctionType, Type, TypeBound, TypeRow}; +use crate::types::{CustomType, FunctionType, PolyFuncType, Type, TypeBound, TypeRow}; +use crate::values::Value; use crate::{type_row, Direction, IncomingPort, Node}; const NAT: Type = crate::extension::prelude::USIZE_T; @@ -27,7 +31,7 @@ const Q: Type = crate::extension::prelude::QB_T; fn make_simple_hugr(copies: usize) -> (Hugr, Node) { let def_op: OpType = ops::FuncDefn { name: "main".into(), - signature: FunctionType::new(type_row![BOOL_T], vec![BOOL_T; copies]), + signature: FunctionType::new(type_row![BOOL_T], vec![BOOL_T; copies]).into(), } .into(); @@ -173,7 +177,7 @@ fn children_restrictions() { .add_node_with_parent( root, ops::FuncDefn { - signature: def_sig, + signature: def_sig.into(), name: "main".into(), }, ) @@ -445,7 +449,7 @@ fn missing_lift_node() -> Result<(), BuildError> { let mut module_builder = ModuleBuilder::new(); let mut main = module_builder.define_function( "main", - FunctionType::new(type_row![NAT], type_row![NAT]).pure(), + FunctionType::new(type_row![NAT], type_row![NAT]).into(), )?; let [main_input] = main.input_wires_arr(); @@ -481,7 +485,7 @@ fn missing_lift_node() -> Result<(), BuildError> { fn too_many_extension() -> Result<(), BuildError> { let mut module_builder = ModuleBuilder::new(); - let main_sig = FunctionType::new(type_row![NAT], type_row![NAT]).pure(); + let main_sig = FunctionType::new(type_row![NAT], type_row![NAT]).into(); let mut main = module_builder.define_function("main", main_sig)?; let [main_input] = main.input_wires_arr(); @@ -521,7 +525,7 @@ fn extensions_mismatch() -> Result<(), BuildError> { let main_sig = FunctionType::new(type_row![], type_row![NAT]) .with_extension_delta(&all_rs) - .with_input_extensions(ExtensionSet::new()); + .into(); let mut main = module_builder.define_function("main", main_sig)?; @@ -636,7 +640,7 @@ fn identity_hugr_with_type(t: Type) -> (Hugr, Node) { b.root(), ops::FuncDefn { name: "main".into(), - signature: FunctionType::new(row.clone(), row.clone()), + signature: FunctionType::new(row.clone(), row.clone()).into(), }, ) .unwrap(); @@ -802,3 +806,114 @@ fn parent_io_mismatch() { )) ); } + +#[test] +fn typevars_declared() -> Result<(), Box> { + // Base case + let f = FunctionBuilder::new( + "myfunc", + PolyFuncType::new( + [TypeParam::Type(TypeBound::Any)], + FunctionType::new_endo(vec![Type::new_var_use(0, TypeBound::Any)]), + ), + )?; + let [w] = f.input_wires_arr(); + f.finish_prelude_hugr_with_outputs([w])?; + // Type refers to undeclared variable + let f = FunctionBuilder::new( + "myfunc", + PolyFuncType::new( + [TypeParam::Type(TypeBound::Any)], + FunctionType::new_endo(vec![Type::new_var_use(1, TypeBound::Any)]), + ), + )?; + let [w] = f.input_wires_arr(); + assert!(f.finish_prelude_hugr_with_outputs([w]).is_err()); + // Variable declaration incorrectly copied to use site + let f = FunctionBuilder::new( + "myfunc", + PolyFuncType::new( + [TypeParam::Type(TypeBound::Any)], + FunctionType::new_endo(vec![Type::new_var_use(1, TypeBound::Copyable)]), + ), + )?; + let [w] = f.input_wires_arr(); + assert!(f.finish_prelude_hugr_with_outputs([w]).is_err()); + Ok(()) +} + +/// Test that nested FuncDefns cannot use Type Variables declared by enclosing FuncDefns +#[test] +fn nested_typevars() -> Result<(), Box> { + const OUTER_BOUND: TypeBound = TypeBound::Any; + const INNER_BOUND: TypeBound = TypeBound::Copyable; + fn build(t: Type) -> Result { + let mut outer = FunctionBuilder::new( + "outer", + PolyFuncType::new( + [OUTER_BOUND.into()], + FunctionType::new_endo(vec![Type::new_var_use(0, TypeBound::Any)]), + ), + )?; + let inner = outer.define_function( + "inner", + PolyFuncType::new([INNER_BOUND.into()], FunctionType::new_endo(vec![t])), + )?; + let [w] = inner.input_wires_arr(); + inner.finish_with_outputs([w])?; + let [w] = outer.input_wires_arr(); + outer.finish_prelude_hugr_with_outputs([w]) + } + assert!(build(Type::new_var_use(0, INNER_BOUND)).is_ok()); + assert_matches!( + build(Type::new_var_use(1, OUTER_BOUND)).unwrap_err(), + BuildError::InvalidHUGR(ValidationError::SignatureError { + cause: SignatureError::FreeTypeVar { + idx: 1, + num_decls: 1 + }, + .. + }) + ); + assert_matches!(build(Type::new_var_use(0, OUTER_BOUND)).unwrap_err(), + BuildError::InvalidHUGR(ValidationError::SignatureError { cause: SignatureError::TypeVarDoesNotMatchDeclaration { actual, cached }, .. }) => + actual == INNER_BOUND.into() && cached == OUTER_BOUND.into()); + Ok(()) +} + +#[test] +fn no_polymorphic_consts() -> Result<(), Box> { + use crate::std_extensions::collections; + const BOUND: TypeParam = TypeParam::Type(TypeBound::Copyable); + let list_of_var = Type::new_extension( + collections::EXTENSION + .get_type(&collections::LIST_TYPENAME) + .unwrap() + .instantiate(vec![TypeArg::new_var_use(0, BOUND)])?, + ); + let reg = ExtensionRegistry::try_new([collections::EXTENSION.to_owned()]).unwrap(); + let just_colns = ExtensionSet::singleton(&collections::EXTENSION_NAME); + let mut def = FunctionBuilder::new( + "myfunc", + PolyFuncType::new( + [BOUND], + FunctionType::new(vec![], vec![list_of_var.clone()]).with_extension_delta(&just_colns), + ), + )?; + let empty_list = Value::Extension { + c: (Box::new(collections::ListValue::new(vec![])),), + }; + let cst = def.add_load_const(Const::new(empty_list, list_of_var)?, just_colns)?; + let res = def.finish_hugr_with_outputs([cst], ®); + assert_matches!( + res.unwrap_err(), + BuildError::InvalidHUGR(ValidationError::SignatureError { + cause: SignatureError::FreeTypeVar { + idx: 0, + num_decls: 0 + }, + .. + }) + ); + Ok(()) +} diff --git a/src/hugr/views.rs b/src/hugr/views.rs index 63fb03ab4..b9f798370 100644 --- a/src/hugr/views.rs +++ b/src/hugr/views.rs @@ -27,7 +27,7 @@ use crate::ops::handle::NodeHandle; use crate::ops::{FuncDecl, FuncDefn, OpName, OpTag, OpTrait, OpType, DFG}; #[rustversion::since(1.75)] // uses impl in return position use crate::types::Type; -use crate::types::{EdgeKind, FunctionType}; +use crate::types::{EdgeKind, FunctionType, PolyFuncType}; use crate::{Direction, IncomingPort, Node, OutgoingPort, Port}; #[rustversion::since(1.75)] // uses impl in return position use itertools::Either; @@ -336,12 +336,12 @@ pub trait HugrView: sealed::HugrInternals { /// For function-like HUGRs (DFG, FuncDefn, FuncDecl), report the function /// type. Otherwise return None. - fn get_function_type(&self) -> Option<&FunctionType> { + fn get_function_type(&self) -> Option { let op = self.get_nodetype(self.root()); match &op.op { - OpType::DFG(DFG { signature }) - | OpType::FuncDecl(FuncDecl { signature, .. }) - | OpType::FuncDefn(FuncDefn { signature, .. }) => Some(signature), + OpType::DFG(DFG { signature }) => Some(signature.clone().into()), + OpType::FuncDecl(FuncDecl { signature, .. }) + | OpType::FuncDefn(FuncDefn { signature, .. }) => Some(signature.clone()), _ => None, } } diff --git a/src/hugr/views/descendants.rs b/src/hugr/views/descendants.rs index bd31107a2..f0c515457 100644 --- a/src/hugr/views/descendants.rs +++ b/src/hugr/views/descendants.rs @@ -222,7 +222,7 @@ pub(super) mod test { let (f_id, inner_id) = { let mut func_builder = module_builder.define_function( "main", - FunctionType::new(type_row![NAT, QB], type_row![NAT, QB]).pure(), + FunctionType::new(type_row![NAT, QB], type_row![NAT, QB]).into(), )?; let [int, qb] = func_builder.input_wires_arr(); @@ -261,12 +261,12 @@ pub(super) mod test { assert_eq!( region.get_function_type(), - Some(&FunctionType::new(type_row![NAT, QB], type_row![NAT, QB])) + Some(FunctionType::new(type_row![NAT, QB], type_row![NAT, QB]).into()) ); let inner_region: DescendantsGraph = DescendantsGraph::try_new(&hugr, inner)?; assert_eq!( inner_region.get_function_type(), - Some(&FunctionType::new(type_row![NAT], type_row![NAT])) + Some(FunctionType::new(type_row![NAT], type_row![NAT]).into()) ); Ok(()) diff --git a/src/hugr/views/sibling.rs b/src/hugr/views/sibling.rs index 7d54469f4..bcc122361 100644 --- a/src/hugr/views/sibling.rs +++ b/src/hugr/views/sibling.rs @@ -401,7 +401,7 @@ mod test { fn nested_flat() -> Result<(), Box> { let mut module_builder = ModuleBuilder::new(); let fty = FunctionType::new(type_row![NAT], type_row![NAT]); - let mut fbuild = module_builder.define_function("main", fty.clone().pure())?; + let mut fbuild = module_builder.define_function("main", fty.clone().into())?; let dfg = fbuild.dfg_builder(fty, None, fbuild.input_wires())?; let ins = dfg.input_wires(); let sub_dfg = dfg.finish_with_outputs(ins)?; @@ -442,7 +442,11 @@ mod test { fn flat_mut(mut simple_dfg_hugr: Hugr) { simple_dfg_hugr.update_validate(&PRELUDE_REGISTRY).unwrap(); let root = simple_dfg_hugr.root(); - let signature = simple_dfg_hugr.get_function_type().unwrap().clone(); + let signature = simple_dfg_hugr + .get_function_type() + .unwrap() + .instantiate(&[], &PRELUDE_REGISTRY) + .unwrap(); let sib_mut = SiblingMut::::try_new(&mut simple_dfg_hugr, root); assert_eq!( diff --git a/src/hugr/views/sibling_subgraph.rs b/src/hugr/views/sibling_subgraph.rs index 592446a54..e8e8b4319 100644 --- a/src/hugr/views/sibling_subgraph.rs +++ b/src/hugr/views/sibling_subgraph.rs @@ -17,12 +17,10 @@ use portgraph::{view::Subgraph, Direction, PortView}; use thiserror::Error; use crate::builder::{Container, FunctionBuilder}; -use crate::extension::ExtensionSet; use crate::hugr::{HugrError, HugrMut, HugrView, RootTagged}; use crate::ops::dataflow::DataflowOpTrait; use crate::ops::handle::{ContainerHandle, DataflowOpID}; use crate::ops::{OpTag, OpTrait}; -use crate::types::Signature; use crate::types::{FunctionType, Type}; use crate::{Hugr, IncomingPort, Node, OutgoingPort, Port, SimpleReplacement}; @@ -407,19 +405,14 @@ impl SiblingSubgraph { /// Create a new Hugr containing only the subgraph. /// - /// The new Hugr will contain a function root wth the same signature as the - /// subgraph and the specified `input_extensions`. + /// The new Hugr will contain a [FuncDefn][crate::ops::FuncDefn] root + /// with the same signature as the subgraph and the specified `name` pub fn extract_subgraph( &self, hugr: &impl HugrView, name: impl Into, - input_extensions: ExtensionSet, ) -> Result { - let signature = Signature { - signature: self.signature(hugr), - input_extensions, - }; - let mut builder = FunctionBuilder::new(name, signature).unwrap(); + let mut builder = FunctionBuilder::new(name, self.signature(hugr).into()).unwrap(); // Take the unfinished Hugr from the builder, to avoid unnecessary // validation checks that require connecting the inputs and outputs. let mut extracted = mem::take(builder.hugr_mut()); @@ -748,7 +741,7 @@ mod tests { let mut mod_builder = ModuleBuilder::new(); let func = mod_builder.declare( "test", - FunctionType::new_endo(type_row![QB_T, QB_T, QB_T]).pure(), + FunctionType::new_endo(type_row![QB_T, QB_T, QB_T]).into(), )?; let func_id = { let mut dfg = mod_builder.define_declaration(&func)?; @@ -764,7 +757,7 @@ mod tests { fn build_3not_hugr() -> Result<(Hugr, Node), BuildError> { let mut mod_builder = ModuleBuilder::new(); - let func = mod_builder.declare("test", FunctionType::new_endo(type_row![BOOL_T]).pure())?; + let func = mod_builder.declare("test", FunctionType::new_endo(type_row![BOOL_T]).into())?; let func_id = { let mut dfg = mod_builder.define_declaration(&func)?; let outs1 = dfg.add_dataflow_op(not_op(), dfg.input_wires())?; @@ -783,7 +776,7 @@ mod tests { let mut mod_builder = ModuleBuilder::new(); let func = mod_builder.declare( "test", - FunctionType::new(type_row![BOOL_T], type_row![BOOL_T]).pure(), + FunctionType::new(type_row![BOOL_T], type_row![BOOL_T]).into(), )?; let func_id = { let mut dfg = mod_builder.define_declaration(&func)?; @@ -976,7 +969,7 @@ mod tests { SiblingGraph::try_new(&hugr, func_root).unwrap(); let func = SiblingSubgraph::try_new_dataflow_subgraph(&func_graph).unwrap(); let func_defn = hugr.get_optype(func_root).as_func_defn().unwrap(); - assert_eq!(func_defn.signature, func.signature(&func_graph)) + assert_eq!(func_defn.signature, func.signature(&func_graph).into()); } #[test] @@ -985,7 +978,7 @@ mod tests { let func_graph: SiblingGraph<'_, FuncID> = SiblingGraph::try_new(&hugr, func_root).unwrap(); let subgraph = SiblingSubgraph::try_new_dataflow_subgraph(&func_graph).unwrap(); - let extracted = subgraph.extract_subgraph(&hugr, "region", ExtensionSet::new())?; + let extracted = subgraph.extract_subgraph(&hugr, "region")?; extracted.validate(&PRELUDE_REGISTRY).unwrap(); diff --git a/src/hugr/views/tests.rs b/src/hugr/views/tests.rs index 44b4ee8df..6d81ff52a 100644 --- a/src/hugr/views/tests.rs +++ b/src/hugr/views/tests.rs @@ -156,7 +156,7 @@ fn static_targets() { #[test] fn test_dataflow_ports_only() { use crate::builder::DataflowSubContainer; - use crate::extension::prelude::BOOL_T; + use crate::extension::{prelude::BOOL_T, PRELUDE_REGISTRY}; use crate::hugr::views::PortIterator; use crate::std_extensions::logic::test::not_op; use itertools::Itertools; @@ -165,7 +165,7 @@ fn test_dataflow_ports_only() { let local_and = dfg .define_function( "and", - FunctionType::new(type_row![BOOL_T; 2], type_row![BOOL_T]).pure(), + FunctionType::new(type_row![BOOL_T; 2], type_row![BOOL_T]).into(), ) .unwrap(); let first_input = local_and.input().out_wire(0); @@ -174,10 +174,17 @@ fn test_dataflow_ports_only() { let [in_bool] = dfg.input_wires_arr(); let not = dfg.add_dataflow_op(not_op(), [in_bool]).unwrap(); - let call = dfg.call(local_and.handle(), [not.out_wire(0); 2]).unwrap(); + let call = dfg + .call( + local_and.handle(), + &[], + [not.out_wire(0); 2], + &PRELUDE_REGISTRY, + ) + .unwrap(); dfg.add_other_wire(not.node(), call.node()).unwrap(); let h = dfg - .finish_hugr_with_outputs(not.outputs(), &crate::extension::PRELUDE_REGISTRY) + .finish_hugr_with_outputs(not.outputs(), &PRELUDE_REGISTRY) .unwrap(); let filtered_ports = h .all_linked_outputs(call.node()) diff --git a/src/ops/module.rs b/src/ops/module.rs index 2bdb5d400..571789183 100644 --- a/src/ops/module.rs +++ b/src/ops/module.rs @@ -2,7 +2,7 @@ use smol_str::SmolStr; -use crate::types::{EdgeKind, FunctionType}; +use crate::types::{EdgeKind, PolyFuncType}; use crate::types::{Type, TypeBound}; use super::StaticTag; @@ -36,7 +36,7 @@ pub struct FuncDefn { /// Name of function pub name: String, /// Signature of the function - pub signature: FunctionType, + pub signature: PolyFuncType, } impl_op_name!(FuncDefn); @@ -63,7 +63,7 @@ pub struct FuncDecl { /// Name of function pub name: String, /// Signature of the function - pub signature: FunctionType, + pub signature: PolyFuncType, } impl_op_name!(FuncDecl); diff --git a/src/ops/validate.rs b/src/ops/validate.rs index ec8b75c24..ae31b8e28 100644 --- a/src/ops/validate.rs +++ b/src/ops/validate.rs @@ -10,7 +10,7 @@ use itertools::Itertools; use portgraph::{NodeIndex, PortOffset}; use thiserror::Error; -use crate::types::{Type, TypeRow}; +use crate::types::{FunctionType, Type, TypeRow}; use super::{impl_validate_op, BasicBlock, OpTag, OpTrait, OpType, ValidateOp}; @@ -77,12 +77,10 @@ impl ValidateOp for super::FuncDefn { &self, children: impl DoubleEndedIterator, ) -> Result<(), ChildrenValidationError> { - validate_io_nodes( - &self.signature.input, - &self.signature.output, - "function definition", - children, - ) + // We check type-variables are declared in `validate_subtree`, so here + // we can just assume all type variables are valid regardless of binders. + let FunctionType { input, output, .. } = &self.signature.body; + validate_io_nodes(input, output, "function definition", children) } } diff --git a/src/std_extensions/collections.rs b/src/std_extensions/collections.rs index 99d208c63..41939d5eb 100644 --- a/src/std_extensions/collections.rs +++ b/src/std_extensions/collections.rs @@ -27,6 +27,14 @@ pub const EXTENSION_NAME: ExtensionId = ExtensionId::new_unchecked("Collections" /// Dynamically sized list of values, all of the same type. pub struct ListValue(Vec); +impl ListValue { + /// Create a new [CustomConst] for a list of values. + /// (The caller will need these to all be of the same type, but that is not checked here.) + pub fn new(contents: Vec) -> Self { + Self(contents) + } +} + #[typetag::serde] impl CustomConst for ListValue { fn name(&self) -> SmolStr { diff --git a/src/types/check.rs b/src/types/check.rs index 31c8d1962..0ac48e69d 100644 --- a/src/types/check.rs +++ b/src/types/check.rs @@ -58,7 +58,7 @@ impl Type { Ok(()) } (TypeEnum::Function(t), Value::Function { hugr: v }) - if v.get_function_type().is_some_and(|f| &**t == f) => + if v.get_function_type().is_some_and(|f| **t == f) => { // exact signature equality, in future this may need to be // relaxed to be compatibility checks between the signatures. diff --git a/src/types/poly_func.rs b/src/types/poly_func.rs index 759776fb1..dcbc0972e 100644 --- a/src/types/poly_func.rs +++ b/src/types/poly_func.rs @@ -15,7 +15,7 @@ use super::{FunctionType, Substitution}; /// [Graph]: crate::values::Value::Function /// [OpDef]: crate::extension::OpDef #[derive( - Clone, PartialEq, Debug, Eq, derive_more::Display, serde::Serialize, serde::Deserialize, + Clone, PartialEq, Debug, Default, Eq, derive_more::Display, serde::Serialize, serde::Deserialize, )] #[display( fmt = "forall {}. {}", @@ -31,7 +31,7 @@ pub struct PolyFuncType { /// [TypeArg]: super::type_param::TypeArg params: Vec, /// Template for the function. May contain variables up to length of [Self::params] - pub(super) body: FunctionType, + pub(crate) body: FunctionType, } impl From for PolyFuncType { @@ -141,12 +141,6 @@ impl PolyFuncType { } } -impl PartialEq for PolyFuncType { - fn eq(&self, other: &FunctionType) -> bool { - self.params.is_empty() && &self.body == other - } -} - /// A [Substitution] with a finite list of known values. /// (Variables out of the range of the list will result in a panic) struct SubstValues<'a>(&'a [TypeArg], &'a ExtensionRegistry); @@ -491,7 +485,7 @@ pub(crate) mod test { let actual = array_max .instantiate_poly(&[USIZE_TA, TypeArg::BoundedNat { n: 3 }], &PRELUDE_REGISTRY)?; - assert_eq!(actual, concrete); + assert_eq!(actual, concrete.into()); // forall N.(Array -> usize) let partial = PolyFuncType::new_validated( diff --git a/src/types/type_param.rs b/src/types/type_param.rs index d8546b58f..b87e2ad82 100644 --- a/src/types/type_param.rs +++ b/src/types/type_param.rs @@ -91,6 +91,12 @@ impl TypeParam { } } +impl From for TypeParam { + fn from(bound: TypeBound) -> Self { + Self::Type(bound) + } +} + /// A statically-known argument value to an operation. #[derive(Clone, Debug, PartialEq, Eq, serde::Deserialize, serde::Serialize)] #[non_exhaustive]