From 3ea4834dd00466e3c106917c1e09c0c5b74c5826 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 22 May 2024 10:31:27 +0100 Subject: [PATCH] feat!: Allow "Row Variables" declared as List<Type> (#804) * Add a RowVariable variant of Type(Enum) that stands for potentially multiple types, created via Type::new_row_var_use. * This can appear in a TypeRow, i.e. the TypeRow is then of variable length; and can be instantiated with a list of types (including row vars) or a single row variable * Validation enforces that RowVariables are not used directly as wire/port types, but can appear *inside* other types * OpDefs may be polymorphic over RowVariables (allowing "varargs"-like operators equivalent to e.g. MakeTuple); these must be replaced by non-rowvar types when instantiating the OpDef to an OpType * FuncDefns/FuncDecls may also be polymorphic over RowVariables as long as these are not directly argument/result types * Also add TypeParam::new_sequence closes #787 BREAKING CHANGE: Type::validate takes extra bool (allow_row_vars); renamed {FunctionType, PolyFuncType}::(validate=>validate_var_len).
--------- Co-authored-by: doug-q <141026920+doug-q@users.noreply.github.com> --- hugr-py/src/hugr/serialization/tys.py | 19 +- hugr/src/builder/build_traits.rs | 11 +- hugr/src/builder/dataflow.rs | 38 ++- hugr/src/extension.rs | 3 + hugr/src/extension/op_def.rs | 17 +- hugr/src/hugr/serialize/test.rs | 21 +- hugr/src/hugr/validate.rs | 30 +- hugr/src/hugr/validate/test.rs | 169 +++++++++++- hugr/src/ops/constant.rs | 2 +- hugr/src/types.rs | 122 +++++--- hugr/src/types/poly_func.rs | 115 +++++++- hugr/src/types/serialize.rs | 3 + hugr/src/types/signature.rs | 30 +- hugr/src/types/type_param.rs | 260 ++++++++++++++++-- hugr/src/types/type_row.rs | 29 +- .../schema/hugr_schema_strict_v1.json | 32 +++ specification/schema/hugr_schema_v1.json | 32 +++ .../schema/testing_hugr_schema_strict_v1.json | 32 +++ .../schema/testing_hugr_schema_v1.json | 32 +++ 19 files changed, 912 insertions(+), 85 deletions(-) diff --git a/hugr-py/src/hugr/serialization/tys.py b/hugr-py/src/hugr/serialization/tys.py index 92942fd5d..f19e27340 100644 --- a/hugr-py/src/hugr/serialization/tys.py +++ b/hugr-py/src/hugr/serialization/tys.py @@ -211,6 +211,15 @@ class Variable(ConfiguredBaseModel): b: "TypeBound" +class RowVar(ConfiguredBaseModel): + """A variable standing for a row of some (unknown) number of types. 
+ May occur only within a row; not a node input/output.""" + + t: Literal["R"] = "R" + i: int + b: "TypeBound" + + class USize(ConfiguredBaseModel): """Unsigned integer size type.""" @@ -320,7 +329,15 @@ class Type(RootModel): """A HUGR type.""" root: Annotated[ - Qubit | Variable | USize | FunctionType | Array | SumType | Opaque | Alias, + Qubit + | Variable + | RowVar + | USize + | FunctionType + | Array + | SumType + | Opaque + | Alias, WrapValidator(_json_custom_error_validator), Field(discriminator="t"), ] diff --git a/hugr/src/builder/build_traits.rs b/hugr/src/builder/build_traits.rs index e3ef10de8..9703d65c3 100644 --- a/hugr/src/builder/build_traits.rs +++ b/hugr/src/builder/build_traits.rs @@ -19,7 +19,7 @@ use crate::{ types::EdgeKind, }; -use crate::extension::{ExtensionRegistry, ExtensionSet, PRELUDE_REGISTRY}; +use crate::extension::{ExtensionRegistry, ExtensionSet, SignatureError, PRELUDE_REGISTRY}; use crate::types::{FunctionType, PolyFuncType, Type, TypeArg, TypeRow}; use itertools::Itertools; @@ -645,6 +645,15 @@ fn add_node_with_wires( inputs: impl IntoIterator, ) -> Result<(Node, usize), BuildError> { let nodetype: NodeType = nodetype.into(); + // Check there are no row variables, as that would prevent us + // from indexing into the node's ports in order to wire up + nodetype + .op_signature() + .as_ref() + .and_then(FunctionType::find_rowvar) + .map_or(Ok(()), |(idx, _)| { + Err(SignatureError::RowVarWhereTypeExpected { idx }) + })?; let num_outputs = nodetype.op().value_output_count(); let op_node = data_builder.add_child_node(nodetype.clone()); diff --git a/hugr/src/builder/dataflow.rs b/hugr/src/builder/dataflow.rs index 473e79877..67865e13a 100644 --- a/hugr/src/builder/dataflow.rs +++ b/hugr/src/builder/dataflow.rs @@ -208,13 +208,14 @@ pub(crate) mod test { use crate::builder::build_traits::DataflowHugr; use crate::builder::{BuilderWiringError, DataflowSubContainer, ModuleBuilder}; - use crate::extension::prelude::BOOL_T; - use 
crate::extension::{ExtensionId, EMPTY_REG}; + use crate::extension::prelude::{BOOL_T, USIZE_T}; + use crate::extension::{ExtensionId, SignatureError, EMPTY_REG, PRELUDE_REGISTRY}; use crate::hugr::validate::InterGraphEdgeError; use crate::ops::{handle::NodeHandle, Lift, Noop, OpTag}; use crate::std_extensions::logic::test::and_op; - use crate::types::Type; + use crate::types::type_param::TypeParam; + use crate::types::{Type, TypeBound}; use crate::utils::test_quantum_extension::h_gate; use crate::{ builder::test::{n_identity, BIT, NAT, QB}, @@ -550,4 +551,35 @@ pub(crate) mod test { ); Ok(()) } + + #[test] + fn no_outer_row_variables() -> Result<(), BuildError> { + let e = crate::hugr::validate::test::extension_with_eval_parallel(); + let tv = Type::new_row_var_use(0, TypeBound::Copyable); + let mut fb = FunctionBuilder::new( + "bad_eval", + PolyFuncType::new( + [TypeParam::new_list(TypeBound::Copyable)], + FunctionType::new( + Type::new_function(FunctionType::new(USIZE_T, tv.clone())), + vec![], + ), + ), + )?; + + let [func_arg] = fb.input_wires_arr(); + let i = fb.add_load_value(crate::extension::prelude::ConstUsize::new(5)); + let ev = e.instantiate_extension_op( + "eval", + [vec![USIZE_T.into()].into(), vec![tv.into()].into()], + &PRELUDE_REGISTRY, + )?; + let r = fb.add_dataflow_op(ev, [func_arg, i]); + // This error would be caught in validation, but the builder detects it much earlier + assert_eq!( + r.unwrap_err(), + BuildError::SignatureError(SignatureError::RowVarWhereTypeExpected { idx: 0 }) + ); + Ok(()) + } } diff --git a/hugr/src/extension.rs b/hugr/src/extension.rs index ebae773ec..c2e58b0a8 100644 --- a/hugr/src/extension.rs +++ b/hugr/src/extension.rs @@ -161,6 +161,9 @@ pub enum SignatureError { /// A type variable that was used has not been declared #[error("Type variable {idx} was not declared ({num_decls} in scope)")] FreeTypeVar { idx: usize, num_decls: usize }, + /// A row variable was found outside of a variable-length row + 
#[error("Expected a single type, but found row variable {idx}")] + RowVarWhereTypeExpected { idx: usize }, /// The result of the type application stored in a [Call] /// is not what we get by applying the type-args to the polymorphic function /// diff --git a/hugr/src/extension/op_def.rs b/hugr/src/extension/op_def.rs index 91a49cfa5..d91aa6890 100644 --- a/hugr/src/extension/op_def.rs +++ b/hugr/src/extension/op_def.rs @@ -405,7 +405,10 @@ impl OpDef { // TODO https://github.com/CQCL/hugr/issues/624 validate declared TypeParams // for both type scheme and custom binary if let SignatureFunc::TypeScheme(ts) = &self.signature_func { - ts.poly_func.validate(exts)?; + // The type scheme may contain row variables so be of variable length; + // these will have to be substituted to fixed-length concrete types when + // the OpDef is instantiated into an actual OpType. + ts.poly_func.validate_var_len(exts)?; } Ok(()) } @@ -482,6 +485,7 @@ mod test { use crate::extension::{SignatureError, EMPTY_REG, PRELUDE_REGISTRY}; use crate::ops::{CustomOp, OpName}; use crate::std_extensions::collections::{EXTENSION, LIST_TYPENAME}; + use crate::types::type_param::TypeArgError; use crate::types::Type; use crate::types::{type_param::TypeParam, FunctionType, PolyFuncType, TypeArg, TypeBound}; use crate::{const_extension_ids, Extension}; @@ -639,6 +643,17 @@ mod test { def.compute_signature(&args, &EMPTY_REG), Ok(FunctionType::new_endo(vec![tv])) ); + // But not with an external row variable + let arg: TypeArg = Type::new_row_var_use(0, TypeBound::Eq).into(); + assert_eq!( + def.compute_signature(&[arg.clone()], &EMPTY_REG), + Err(SignatureError::TypeArgMismatch( + TypeArgError::TypeMismatch { + param: TypeBound::Any.into(), + arg + } + )) + ); Ok(()) } diff --git a/hugr/src/hugr/serialize/test.rs b/hugr/src/hugr/serialize/test.rs index 049e3a096..ab86c1e33 100644 --- a/hugr/src/hugr/serialize/test.rs +++ b/hugr/src/hugr/serialize/test.rs @@ -420,13 +420,32 @@ fn polyfunctype1() -> 
PolyFuncType { PolyFuncType::new([TypeParam::max_nat(), TypeParam::Extensions], function_type) } +fn polyfunctype2() -> PolyFuncType { + let tv0 = Type::new_row_var_use(0, TypeBound::Any); + let tv1 = Type::new_row_var_use(1, TypeBound::Eq); + let params = [TypeBound::Any, TypeBound::Eq].map(TypeParam::new_list); + let inputs = vec![ + Type::new_function(FunctionType::new(tv0.clone(), tv1.clone())), + tv0, + ]; + let res = PolyFuncType::new(params, FunctionType::new(inputs, tv1)); + // Just check we've got the arguments the right way round + // (not that it really matters for the serialization schema we have) + res.validate_var_len(&EMPTY_REG).unwrap(); + res +} + #[rstest] #[case(FunctionType::new_endo(type_row![]).into())] #[case(polyfunctype1())] #[case(PolyFuncType::new([TypeParam::Opaque { ty: int_custom_type(TypeArg::BoundedNat { n: 1 }) }], FunctionType::new_endo(type_row![Type::new_var_use(0, TypeBound::Copyable)])))] #[case(PolyFuncType::new([TypeBound::Eq.into()], FunctionType::new_endo(type_row![Type::new_var_use(0, TypeBound::Eq)])))] -#[case(PolyFuncType::new([TypeParam::List { param: Box::new(TypeBound::Any.into()) }], FunctionType::new_endo(type_row![])))] +#[case(PolyFuncType::new([TypeParam::new_list(TypeBound::Any)], FunctionType::new_endo(type_row![])))] #[case(PolyFuncType::new([TypeParam::Tuple { params: [TypeBound::Any.into(), TypeParam::bounded_nat(2.try_into().unwrap())].into() }], FunctionType::new_endo(type_row![])))] +#[case(PolyFuncType::new( + [TypeParam::new_list(TypeBound::Any)], + FunctionType::new_endo(Type::new_tuple(Type::new_row_var_use(0, TypeBound::Any)))))] +#[case(polyfunctype2())] fn roundtrip_polyfunctype(#[case] poly_func_type: PolyFuncType) { check_testing_roundtrip(poly_func_type) } diff --git a/hugr/src/hugr/validate.rs b/hugr/src/hugr/validate.rs index 4f0c2d438..5a85ca8fe 100644 --- a/hugr/src/hugr/validate.rs +++ b/hugr/src/hugr/validate.rs @@ -20,7 +20,7 @@ use crate::ops::custom::{resolve_opaque_op, CustomOp}; use 
crate::ops::validate::{ChildrenEdgeData, ChildrenValidationError, EdgeValidationError}; use crate::ops::{FuncDefn, OpTag, OpTrait, OpType, ValidateOp}; use crate::types::type_param::TypeParam; -use crate::types::EdgeKind; +use crate::types::{EdgeKind, FunctionType}; use crate::{Direction, Hugr, Node, Port}; use super::views::{HierarchyView, HugrView, SiblingGraph}; @@ -211,7 +211,20 @@ impl<'a, 'b> ValidationContext<'a, 'b> { } } - // Secondly that the node has correct children + // Secondly, check that the node signature does not contain any row variables. + // (We do this here so it's before we try indexing into the ports of any nodes). + op_type + .dataflow_signature() + .as_ref() + .and_then(FunctionType::find_rowvar) + .map_or(Ok(()), |(idx, _)| { + Err(ValidationError::SignatureError { + node, + cause: SignatureError::RowVarWhereTypeExpected { idx }, + }) + })?; + + // Thirdly that the node has correct children self.validate_children(node, node_type)?; Ok(()) @@ -301,11 +314,14 @@ impl<'a, 'b> ValidationContext<'a, 'b> { var_decls: &[TypeParam], ) -> Result<(), SignatureError> { match &port_kind { - EdgeKind::Value(ty) => ty.validate(self.extension_registry, var_decls), + EdgeKind::Value(ty) => ty.validate(false, self.extension_registry, var_decls), // Static edges must *not* refer to type variables declared by enclosing FuncDefns - // as these are only types at runtime. - EdgeKind::Const(ty) => ty.validate(self.extension_registry, &[]), - EdgeKind::Function(pf) => pf.validate(self.extension_registry), + // as these are only types at runtime. (Note the choice of `allow_row_vars` as `false` is arbitrary here.) + EdgeKind::Const(ty) => ty.validate(false, self.extension_registry, &[]), + // Allow function "value" to have unknown arity. A Call node will have to provide + // TypeArgs that produce a known arity, but a LoadFunction might pass the function + // value ("function pointer") around without knowing how to call it. 
+ EdgeKind::Function(pf) => pf.validate_var_len(self.extension_registry), _ => Ok(()), } } @@ -794,4 +810,4 @@ pub enum InterGraphEdgeError { } #[cfg(test)] -mod test; +pub(crate) mod test; diff --git a/hugr/src/hugr/validate/test.rs b/hugr/src/hugr/validate/test.rs index a1518fcb2..b83cb751d 100644 --- a/hugr/src/hugr/validate/test.rs +++ b/hugr/src/hugr/validate/test.rs @@ -1,4 +1,5 @@ use cool_asserts::assert_matches; +use rstest::rstest; use super::*; use crate::builder::test::closed_dfg_root_hugr; @@ -10,7 +11,7 @@ use crate::extension::prelude::{BOOL_T, PRELUDE, PRELUDE_ID, USIZE_T}; use crate::extension::{Extension, ExtensionSet, TypeDefBound, EMPTY_REG, PRELUDE_REGISTRY}; use crate::hugr::hugrmut::sealed::HugrMutInternals; use crate::hugr::HugrMut; -use crate::ops::dataflow::IOTrait; +use crate::ops::dataflow::{IOTrait, LoadFunction}; use crate::ops::handle::NodeHandle; use crate::ops::leaf::MakeTuple; use crate::ops::{self, Noop, OpType, Value}; @@ -546,6 +547,170 @@ fn no_polymorphic_consts() -> Result<(), Box> { Ok(()) } +pub(crate) fn extension_with_eval_parallel() -> Extension { + let rowp = TypeParam::new_list(TypeBound::Any); + let mut e = Extension::new(EXT_ID); + + let inputs = Type::new_row_var_use(0, TypeBound::Any); + let outputs = Type::new_row_var_use(1, TypeBound::Any); + let evaled_fn = Type::new_function(FunctionType::new(inputs.clone(), outputs.clone())); + let pf = PolyFuncType::new( + [rowp.clone(), rowp.clone()], + FunctionType::new(vec![evaled_fn, inputs], outputs), + ); + e.add_op("eval".into(), "".into(), pf).unwrap(); + + let rv = |idx| Type::new_row_var_use(idx, TypeBound::Any); + let pf = PolyFuncType::new( + [rowp.clone(), rowp.clone(), rowp.clone(), rowp.clone()], + FunctionType::new( + vec![ + Type::new_function(FunctionType::new(rv(0), rv(2))), + Type::new_function(FunctionType::new(rv(1), rv(3))), + ], + Type::new_function(FunctionType::new(vec![rv(0), rv(1)], vec![rv(2), rv(3)])), + ), + ); + e.add_op("parallel".into(), 
"".into(), pf).unwrap(); + + e +} + +#[test] +fn instantiate_row_variables() -> Result<(), Box> { + fn uint_seq(i: usize) -> TypeArg { + vec![TypeArg::Type { ty: USIZE_T }; i].into() + } + let e = extension_with_eval_parallel(); + let mut dfb = DFGBuilder::new(FunctionType::new( + vec![ + Type::new_function(FunctionType::new(USIZE_T, vec![USIZE_T, USIZE_T])), + USIZE_T, + ], // inputs: function + its argument + vec![USIZE_T; 4], // outputs (*2^2, three calls) + ))?; + let [func, int] = dfb.input_wires_arr(); + let eval = e.instantiate_extension_op("eval", [uint_seq(1), uint_seq(2)], &PRELUDE_REGISTRY)?; + let [a, b] = dfb.add_dataflow_op(eval, [func, int])?.outputs_arr(); + let par = e.instantiate_extension_op( + "parallel", + [uint_seq(1), uint_seq(1), uint_seq(2), uint_seq(2)], + &PRELUDE_REGISTRY, + )?; + let [par_func] = dfb.add_dataflow_op(par, [func, func])?.outputs_arr(); + let eval2 = + e.instantiate_extension_op("eval", [uint_seq(2), uint_seq(4)], &PRELUDE_REGISTRY)?; + let eval2 = dfb.add_dataflow_op(eval2, [par_func, a, b])?; + dfb.finish_hugr_with_outputs( + eval2.outputs(), + &ExtensionRegistry::try_new([PRELUDE.to_owned(), e]).unwrap(), + )?; + Ok(()) +} + +fn seq1ty(t: Type) -> TypeArg { + TypeArg::Sequence { + elems: vec![t.into()], + } +} + +#[test] +fn inner_row_variables() -> Result<(), Box> { + let e = extension_with_eval_parallel(); + let tv = Type::new_row_var_use(0, TypeBound::Any); + let inner_ft = Type::new_function(FunctionType::new_endo(tv.clone())); + let ft_usz = Type::new_function(FunctionType::new_endo(vec![tv.clone(), USIZE_T])); + let mut fb = FunctionBuilder::new( + "id", + PolyFuncType::new( + [TypeParam::new_list(TypeBound::Any)], + FunctionType::new(inner_ft.clone(), ft_usz), + ), + )?; + // All the wires here are carrying higher-order Function values + let [func_arg] = fb.input_wires_arr(); + let [id_usz] = { + let bldr = fb.define_function("id_usz", FunctionType::new_endo(USIZE_T).into())?; + let vals = bldr.input_wires(); + 
let [inner_def] = bldr.finish_with_outputs(vals)?.outputs_arr(); + let loadf = LoadFunction::try_new( + FunctionType::new_endo(USIZE_T).into(), + [], + &PRELUDE_REGISTRY, + ) + .unwrap(); + fb.add_dataflow_op(loadf, [inner_def])?.outputs_arr() + }; + let par = e.instantiate_extension_op( + "parallel", + [tv.clone(), USIZE_T, tv.clone(), USIZE_T].map(seq1ty), + &PRELUDE_REGISTRY, + )?; + let par_func = fb.add_dataflow_op(par, [func_arg, id_usz])?; + fb.finish_hugr_with_outputs( + par_func.outputs(), + &ExtensionRegistry::try_new([PRELUDE.to_owned(), e]).unwrap(), + )?; + Ok(()) +} + +#[rstest] +#[case(false)] +#[case(true)] +fn no_outer_row_variables(#[case] connect: bool) -> Result<(), Box> { + let e = extension_with_eval_parallel(); + let tv = Type::new_row_var_use(0, TypeBound::Copyable); + let fun_ty = Type::new_function(FunctionType::new(USIZE_T, tv.clone())); + let results = if connect { vec![tv.clone()] } else { vec![] }; + let mut fb = Hugr::new( + FuncDefn { + name: "bad_eval".to_string(), + signature: PolyFuncType::new( + [TypeParam::new_list(TypeBound::Copyable)], + FunctionType::new(fun_ty.clone(), results.clone()), + ), + } + .into(), + ); + let inp = fb.add_node_with_parent( + fb.root(), + ops::Input { + types: fun_ty.into(), + }, + ); + let out = fb.add_node_with_parent( + fb.root(), + ops::Output { + types: results.into(), + }, + ); + let cst = fb.add_node_with_parent( + fb.root(), + ops::Const::new(crate::extension::prelude::ConstUsize::new(5).into()), + ); + let i = fb.add_node_with_parent(fb.root(), ops::LoadConstant { datatype: USIZE_T }); + fb.connect(cst, 0, i, 0); + + let ev = fb.add_node_with_parent( + fb.root(), + e.instantiate_extension_op("eval", [seq1ty(USIZE_T), seq1ty(tv)], &PRELUDE_REGISTRY)?, + ); + fb.connect(inp, 0, ev, 0); + fb.connect(i, 0, ev, 1); + if connect { + fb.connect(ev, 0, out, 0); + } + let reg = ExtensionRegistry::try_new([PRELUDE.to_owned(), e]).unwrap(); + assert_matches!( + fb.validate(®).unwrap_err(), + 
ValidationError::SignatureError { + node, + cause: SignatureError::RowVarWhereTypeExpected { idx: 0 } + } => assert!([ev, out].contains(&node)) + ); + Ok(()) +} + #[test] fn test_polymorphic_call() -> Result<(), Box> { let mut e = Extension::new(EXT_ID); @@ -562,7 +727,7 @@ fn test_polymorphic_call() -> Result<(), Box> { ) .with_extension_delta(ExtensionSet::type_var(1)), ); - // The higher-order "eval" operation - takes a function and its argument. + // Single-input/output version of the higher-order "eval" operation, with extension param. // Note the extension-delta of the eval node includes that of the input function. e.add_op( "eval".into(), diff --git a/hugr/src/ops/constant.rs b/hugr/src/ops/constant.rs index d0bd420d3..7e3a1ac37 100644 --- a/hugr/src/ops/constant.rs +++ b/hugr/src/ops/constant.rs @@ -669,7 +669,7 @@ mod test { 32, // Target around 32 total elements 3, // Each collection is up to 3 elements long |element| { - (any::(), vec(element.clone(), 0..3)).prop_map( + (Type::any_non_row_var(), vec(element.clone(), 0..3)).prop_map( |(typ, contents)| { OpaqueValue::new(ListValue::new( typ, diff --git a/hugr/src/types.rs b/hugr/src/types.rs index 147ad4b4c..cb8cd0706 100644 --- a/hugr/src/types.rs +++ b/hugr/src/types.rs @@ -222,6 +222,10 @@ pub enum TypeEnum { #[allow(missing_docs)] #[display(fmt = "Variable({})", _0)] Variable(usize, TypeBound), + /// Variable index, and cache of inner TypeBound - matches a [TypeParam::List] of [TypeParam::Type] + /// of this bound (checked in validation) + #[display(fmt = "RowVar({})", _0)] + RowVariable(usize, TypeBound), #[allow(missing_docs)] #[display(fmt = "{}", "_0")] Sum(#[cfg_attr(test, proptest(strategy = "any_with::(params)"))] SumType), @@ -233,7 +237,7 @@ impl TypeEnum { TypeEnum::Extension(c) => c.bound(), TypeEnum::Alias(a) => a.bound, TypeEnum::Function(_) => TypeBound::Copyable, - TypeEnum::Variable(_, b) => *b, + TypeEnum::Variable(_, b) | TypeEnum::RowVariable(_, b) => *b, TypeEnum::Sum(SumType::Unit 
{ size: _ }) => TypeBound::Eq, TypeEnum::Sum(SumType::General { rows }) => least_upper_bound( rows.iter() @@ -271,14 +275,6 @@ impl TypeEnum { /// ``` pub struct Type(TypeEnum, TypeBound); -fn validate_each<'a>( - extension_registry: &ExtensionRegistry, - var_decls: &[TypeParam], - mut iter: impl Iterator, -) -> Result<(), SignatureError> { - iter.try_for_each(|t| t.validate(extension_registry, var_decls)) -} - impl Type { /// An empty `TypeRow`. Provided here for convenience pub const EMPTY_TYPEROW: TypeRow = type_row![]; @@ -332,13 +328,27 @@ impl Type { } /// New use (occurrence) of the type variable with specified index. - /// For use in type schemes only: `bound` must match that with which the - /// variable was declared (i.e. as a [TypeParam::Type]`(bound)`). + /// `bound` must be exactly that with which the variable was declared + /// (i.e. as a [TypeParam::Type]`(bound)`), which may be narrower + /// than required for the use. pub const fn new_var_use(idx: usize, bound: TypeBound) -> Self { Self(TypeEnum::Variable(idx, bound), bound) } - /// Report the least upper TypeBound, if there is one. + /// New use (occurrence) of the row variable with specified index. + /// `bound` must be exactly that with which the variable was declared + /// (i.e. as a [TypeParam::List]` of a `[TypeParam::Type]` of that bound), + /// which may be narrower than required for the use. + /// For use in [OpDef] type schemes, or function types, only, + /// not [FuncDefn] type schemes or as a Hugr port type. + /// + /// [OpDef]: crate::extension::OpDef + /// [FuncDefn]: crate::ops::FuncDefn + pub const fn new_row_var_use(idx: usize, bound: TypeBound) -> Self { + Self(TypeEnum::RowVariable(idx, bound), bound) + } + + /// Report the least upper [TypeBound] #[inline(always)] pub const fn least_upper_bound(&self) -> TypeBound { self.1 @@ -356,47 +366,69 @@ impl Type { TypeBound::Copyable.contains(self.least_upper_bound()) } + /// Tells if this Type is a row variable, i.e. 
could stand for any number >=0 of Types + pub fn is_row_var(&self) -> bool { + matches!(self.0, TypeEnum::RowVariable(_, _)) + } + /// Checks all variables used in the type are in the provided list - /// of bound variables, and that for each [CustomType] the corresponding + /// of bound variables, rejecting any [RowVariable]s if `allow_row_vars` is False; + /// and that for each [CustomType] the corresponding /// [TypeDef] is in the [ExtensionRegistry] and the type arguments /// [validate] and fit into the def's declared parameters. /// + /// [RowVariable]: TypeEnum::RowVariable /// [validate]: crate::types::type_param::TypeArg::validate /// [TypeDef]: crate::extension::TypeDef pub(crate) fn validate( &self, + allow_row_vars: bool, extension_registry: &ExtensionRegistry, var_decls: &[TypeParam], ) -> Result<(), SignatureError> { // There is no need to check the components against the bound, // that is guaranteed by construction (even for deserialization) match &self.0 { - TypeEnum::Sum(SumType::General { rows }) => validate_each( - extension_registry, - var_decls, - rows.iter().flat_map(|x| x.iter()), - ), + TypeEnum::Sum(SumType::General { rows }) => rows + .iter() + .try_for_each(|row| row.validate_var_len(extension_registry, var_decls)), TypeEnum::Sum(SumType::Unit { .. }) => Ok(()), // No leaves there TypeEnum::Alias(_) => Ok(()), TypeEnum::Extension(custy) => custy.validate(extension_registry, var_decls), - TypeEnum::Function(ft) => ft.validate(extension_registry, var_decls), + // Function values may be passed around without knowing their arity + // (i.e. 
with row vars) as long as they are not called: + TypeEnum::Function(ft) => ft.validate_var_len(extension_registry, var_decls), TypeEnum::Variable(idx, bound) => check_typevar_decl(var_decls, *idx, &(*bound).into()), + TypeEnum::RowVariable(idx, bound) => { + if allow_row_vars { + check_typevar_decl(var_decls, *idx, &TypeParam::new_list(*bound)) + } else { + Err(SignatureError::RowVarWhereTypeExpected { idx: *idx }) + } + } } } - pub(crate) fn substitute(&self, t: &Substitution) -> Self { + /// Applies a substitution to a type. + /// This may result in a row of types, if this [Type] is not really a single type but actually a row variable + /// Invariants may be confirmed by validation: + /// * If [Type::validate]`(false)` returns successfully, this method will return a Vec containing exactly one type + /// * If [Type::validate]`(false)` fails, but `(true)` succeeds, this method may (depending on structure of self) + /// return a Vec containing any number of [Type]s. These may (or not) pass [Type::validate] + fn substitute(&self, t: &Substitution) -> Vec { match &self.0 { - TypeEnum::Alias(_) | TypeEnum::Sum(SumType::Unit { .. }) => self.clone(), + TypeEnum::RowVariable(idx, bound) => t.apply_rowvar(*idx, *bound), + TypeEnum::Alias(_) | TypeEnum::Sum(SumType::Unit { .. 
}) => vec![self.clone()], TypeEnum::Variable(idx, bound) => { let TypeArg::Type { ty } = t.apply_var(*idx, &((*bound).into())) else { panic!("Variable was not a type - try validate() first") }; - ty + vec![ty] } - TypeEnum::Extension(cty) => Type::new_extension(cty.substitute(t)), - TypeEnum::Function(bf) => Type::new_function(bf.substitute(t)), + TypeEnum::Extension(cty) => vec![Type::new_extension(cty.substitute(t))], + TypeEnum::Function(bf) => vec![Type::new_function(bf.substitute(t))], TypeEnum::Sum(SumType::General { rows }) => { - Type::new_sum(rows.iter().map(|x| subst_row(x, t))) + vec![Type::new_sum(rows.iter().map(|r| r.substitute(t)))] } } } @@ -416,20 +448,34 @@ impl<'a> Substitution<'a> { arg.clone() } + fn apply_rowvar(&self, idx: usize, bound: TypeBound) -> Vec { + let arg = self + .0 + .get(idx) + .expect("Undeclared type variable - call validate() ?"); + debug_assert!(check_type_arg(arg, &TypeParam::new_list(bound)).is_ok()); + match arg { + // Row variables are represented as 'TypeArg::Type's (see TypeArg::new_var_use) + TypeArg::Sequence { elems } => elems + .iter() + .map(|ta| match ta { + TypeArg::Type { ty } => ty.clone(), + _ => panic!("Not a list of types - call validate() ?"), + }) + .collect(), + TypeArg::Type { ty } if matches!(ty.0, TypeEnum::RowVariable(_, _)) => { + // Standalone "Type" can be used iff its actually a Row Variable not an actual (single) Type + vec![ty.clone()] + } + _ => panic!("Not a type or list of types - call validate() ?"), + } + } + fn extension_registry(&self) -> &ExtensionRegistry { self.1 } } -fn subst_row(row: &TypeRow, tr: &Substitution) -> TypeRow { - let res = row - .iter() - .map(|ty| ty.substitute(tr)) - .collect::>() - .into(); - res -} - pub(crate) fn check_typevar_decl( decls: &[TypeParam], idx: usize, @@ -524,5 +570,13 @@ pub(crate) mod test { .boxed() } } + + impl super::Type { + pub fn any_non_row_var() -> BoxedStrategy { + any::() + .prop_filter("Cannot be a Row Variable", |t| !t.is_row_var()) + 
.boxed() + } + } } } diff --git a/hugr/src/types/poly_func.rs b/hugr/src/types/poly_func.rs index fc579e9d5..e152a3c60 100644 --- a/hugr/src/types/poly_func.rs +++ b/hugr/src/types/poly_func.rs @@ -82,9 +82,12 @@ impl PolyFuncType { /// Validates this instance, checking that the types in the body are /// wellformed with respect to the registry, and the type variables declared. - pub fn validate(&self, reg: &ExtensionRegistry) -> Result<(), SignatureError> { + /// Allows both inputs and outputs to contain [RowVariable]s + /// + /// [RowVariable]: [crate::types::TypeEnum::RowVariable] + pub fn validate_var_len(&self, reg: &ExtensionRegistry) -> Result<(), SignatureError> { // TODO https://github.com/CQCL/hugr/issues/624 validate TypeParams declared here, too - self.body.validate(reg, &self.params) + self.body.validate_var_len(reg, &self.params) } /// Instantiates an outer [PolyFuncType], i.e. with no free variables @@ -109,11 +112,13 @@ impl PolyFuncType { pub(crate) mod test { use std::num::NonZeroU64; + use cool_asserts::assert_matches; use lazy_static::lazy_static; - use crate::extension::prelude::{PRELUDE_ID, USIZE_CUSTOM_T, USIZE_T}; + use crate::extension::prelude::{BOOL_T, PRELUDE_ID, USIZE_CUSTOM_T, USIZE_T}; use crate::extension::{ - ExtensionId, ExtensionRegistry, SignatureError, TypeDefBound, PRELUDE, PRELUDE_REGISTRY, + ExtensionId, ExtensionRegistry, SignatureError, TypeDefBound, EMPTY_REG, PRELUDE, + PRELUDE_REGISTRY, }; use crate::std_extensions::collections::{EXTENSION, LIST_TYPENAME}; use crate::types::type_param::{TypeArg, TypeArgError, TypeParam}; @@ -134,7 +139,7 @@ pub(crate) mod test { extension_registry: &ExtensionRegistry, ) -> Result { let res = Self::new(params, body); - res.validate(extension_registry)?; + res.validate_var_len(extension_registry)?; Ok(res) } } @@ -341,4 +346,104 @@ pub(crate) mod test { )?; Ok(()) } + + const TP_ANY: TypeParam = TypeParam::Type { b: TypeBound::Any }; + #[test] + fn row_variables_bad_schema() { + // 
Mismatched TypeBound (Copyable vs Any) + let decl = TypeParam::List { + param: Box::new(TP_ANY), + }; + let e = PolyFuncType::new_validated( + [decl.clone()], + FunctionType::new( + vec![USIZE_T], + vec![Type::new_row_var_use(0, TypeBound::Copyable)], + ), + &PRELUDE_REGISTRY, + ) + .unwrap_err(); + assert_matches!(e, SignatureError::TypeVarDoesNotMatchDeclaration { actual, cached } => { + assert_eq!(actual, decl); + assert_eq!(cached, TypeParam::List {param: Box::new(TypeParam::Type {b: TypeBound::Copyable})}); + }); + // Declared as row variable, used as type variable + let e = PolyFuncType::new_validated( + [decl.clone()], + FunctionType::new_endo(vec![Type::new_var_use(0, TypeBound::Any)]), + &EMPTY_REG, + ) + .unwrap_err(); + assert_matches!(e, SignatureError::TypeVarDoesNotMatchDeclaration { actual, cached } => { + assert_eq!(actual, decl); + assert_eq!(cached, TP_ANY); + }); + } + + #[test] + fn row_variables() { + let rty = Type::new_row_var_use(0, TypeBound::Any); + let pf = PolyFuncType::new_validated( + [TypeParam::new_list(TP_ANY)], + FunctionType::new(vec![USIZE_T, rty.clone()], vec![Type::new_tuple(rty)]), + &PRELUDE_REGISTRY, + ) + .unwrap(); + + fn seq2() -> Vec { + vec![USIZE_T.into(), BOOL_T.into()] + } + pf.instantiate(&[TypeArg::Type { ty: USIZE_T }], &PRELUDE_REGISTRY) + .unwrap_err(); + pf.instantiate( + &[TypeArg::Sequence { + elems: vec![USIZE_T.into(), TypeArg::Sequence { elems: seq2() }], + }], + &PRELUDE_REGISTRY, + ) + .unwrap_err(); + + let t2 = pf + .instantiate(&[TypeArg::Sequence { elems: seq2() }], &PRELUDE_REGISTRY) + .unwrap(); + assert_eq!( + t2, + FunctionType::new( + vec![USIZE_T, USIZE_T, BOOL_T], + vec![Type::new_tuple(vec![USIZE_T, BOOL_T])] + ) + ); + } + + #[test] + fn row_variables_inner() { + let inner_fty = Type::new_function(FunctionType::new_endo(vec![Type::new_row_var_use( + 0, + TypeBound::Copyable, + )])); + let pf = PolyFuncType::new_validated( + [TypeParam::List { + param: Box::new(TypeParam::Type { + b: 
TypeBound::Copyable, + }), + }], + FunctionType::new(vec![USIZE_T, inner_fty.clone()], vec![inner_fty]), + &PRELUDE_REGISTRY, + ) + .unwrap(); + + let inner3 = Type::new_function(FunctionType::new_endo(vec![USIZE_T, BOOL_T, USIZE_T])); + let t3 = pf + .instantiate( + &[TypeArg::Sequence { + elems: vec![USIZE_T.into(), BOOL_T.into(), USIZE_T.into()], + }], + &PRELUDE_REGISTRY, + ) + .unwrap(); + assert_eq!( + t3, + FunctionType::new(vec![USIZE_T, inner3.clone()], vec![inner3]) + ); + } } diff --git a/hugr/src/types/serialize.rs b/hugr/src/types/serialize.rs index 4a263af14..e4585d2f5 100644 --- a/hugr/src/types/serialize.rs +++ b/hugr/src/types/serialize.rs @@ -16,6 +16,7 @@ pub(super) enum SerSimpleType { Opaque(CustomType), Alias(AliasDecl), V { i: usize, b: TypeBound }, + R { i: usize, b: TypeBound }, } impl From for SerSimpleType { @@ -33,6 +34,7 @@ impl From for SerSimpleType { TypeEnum::Alias(a) => SerSimpleType::Alias(a), TypeEnum::Function(sig) => SerSimpleType::G(sig), TypeEnum::Variable(i, b) => SerSimpleType::V { i, b }, + TypeEnum::RowVariable(i, b) => SerSimpleType::R { i, b }, TypeEnum::Sum(st) => SerSimpleType::Sum(st), } } @@ -51,6 +53,7 @@ impl From for Type { SerSimpleType::Opaque(o) => Type::new_extension(o), SerSimpleType::Alias(a) => Type::new_alias(a), SerSimpleType::V { i, b } => Type::new_var_use(i, b), + SerSimpleType::R { i, b } => Type::new_row_var_use(i, b), } } } diff --git a/hugr/src/types/signature.rs b/hugr/src/types/signature.rs index 89c832fa1..1cf731eda 100644 --- a/hugr/src/types/signature.rs +++ b/hugr/src/types/signature.rs @@ -5,7 +5,7 @@ use itertools::Either; use std::fmt::{self, Display, Write}; use super::type_param::TypeParam; -use super::{subst_row, Substitution, Type, TypeRow}; +use super::{Substitution, Type, TypeBound, TypeEnum, TypeRow}; use crate::core::PortIndex; use crate::extension::{ExtensionRegistry, ExtensionSet, SignatureError}; @@ -39,22 +39,21 @@ impl FunctionType { self } - pub(crate) fn validate( + 
pub(super) fn validate_var_len( &self, extension_registry: &ExtensionRegistry, var_decls: &[TypeParam], ) -> Result<(), SignatureError> { - self.input - .iter() - .chain(self.output.iter()) - .try_for_each(|t| t.validate(extension_registry, var_decls))?; + self.input.validate_var_len(extension_registry, var_decls)?; + self.output + .validate_var_len(extension_registry, var_decls)?; self.extension_reqs.validate(var_decls) } pub(crate) fn substitute(&self, tr: &Substitution) -> Self { FunctionType { - input: subst_row(&self.input, tr), - output: subst_row(&self.output, tr), + input: self.input.substitute(tr), + output: self.output.substitute(tr), extension_reqs: self.extension_reqs.substitute(tr), } } @@ -99,6 +98,7 @@ impl FunctionType { /// of bounds. #[inline] pub fn in_port_type(&self, port: impl Into) -> Option<&Type> { + debug_assert!(self.find_rowvar().is_none()); self.input.get(port.into().index()) } @@ -106,6 +106,7 @@ impl FunctionType { /// of bounds. #[inline] pub fn out_port_type(&self, port: impl Into) -> Option<&Type> { + debug_assert!(self.find_rowvar().is_none()); self.output.get(port.into().index()) } @@ -113,6 +114,7 @@ impl FunctionType { /// of bounds. #[inline] pub fn in_port_type_mut(&mut self, port: impl Into) -> Option<&mut Type> { + debug_assert!(self.find_rowvar().is_none()); self.input.get_mut(port.into().index()) } @@ -120,6 +122,7 @@ impl FunctionType { /// of bounds. #[inline] pub fn out_port_type_mut(&mut self, port: impl Into) -> Option<&mut Type> { + debug_assert!(self.find_rowvar().is_none()); self.output.get_mut(port.into().index()) } @@ -190,6 +193,17 @@ impl FunctionType { } impl FunctionType { + /// If this FunctionType contains any row variables, return one. 
+ pub fn find_rowvar(&self) -> Option<(usize, TypeBound)> { + self.input + .iter() + .chain(self.output.iter()) + .find_map(|t| match t.0 { + TypeEnum::RowVariable(idx, bound) => Some((idx, bound)), + _ => None, + }) + } + /// Returns the `Port`s in the signature for a given direction. #[inline] pub fn ports(&self, dir: Direction) -> impl Iterator { diff --git a/hugr/src/types/type_param.rs b/hugr/src/types/type_param.rs index 1c4efc617..e7019fe7c 100644 --- a/hugr/src/types/type_param.rs +++ b/hugr/src/types/type_param.rs @@ -72,9 +72,10 @@ pub enum TypeParam { /// The [CustomType] defining the parameter. ty: CustomType, }, - /// Argument is a [TypeArg::Sequence]. A list of indeterminate size containing parameters. + /// Argument is a [TypeArg::Sequence]. A list of indeterminate size containing + /// parameters all of the (same) specified element type. List { - /// The [TypeParam] contained in the list. + /// The [TypeParam] describing each element of the list. param: Box, }, /// Argument is a [TypeArg::Sequence]. A tuple of parameters. @@ -104,6 +105,13 @@ impl TypeParam { } } + /// Make a new `TypeParam::List` (an arbitrary-length homogenous list) + pub fn new_list(elem: impl Into) -> Self { + Self::List { + param: Box::new(elem.into()), + } + } + fn contains(&self, other: &TypeParam) -> bool { match (self, other) { (TypeParam::Type { b: b1 }, TypeParam::Type { b: b2 }) => b1.contains(*b2), @@ -165,8 +173,9 @@ pub enum TypeArg { #[allow(missing_docs)] es: ExtensionSet, }, - /// Variable (used in type schemes only), that is not a [TypeArg::Type] - /// or [TypeArg::Extensions] - see [TypeArg::new_var_use] + /// Variable (used in type schemes or inside polymorphic functions), + /// but not a [TypeArg::Type] (not even a row variable i.e. 
[TypeParam::List] of type) + /// nor [TypeArg::Extensions] - see [TypeArg::new_var_use] Variable { #[allow(missing_docs)] #[serde(flatten)] @@ -213,13 +222,21 @@ pub struct TypeArgVariable { impl TypeArg { /// Makes a TypeArg representing a use (occurrence) of the type variable - /// with the specified index. For use within type schemes only: - /// `bound` must match that with which the variable was declared. + /// with the specified index. + /// `decl` must be exactly that with which the variable was declared. pub fn new_var_use(idx: usize, decl: TypeParam) -> Self { match decl { - TypeParam::Type { b } => TypeArg::Type { - ty: Type::new_var_use(idx, b), - }, + TypeParam::Type { b } => Type::new_var_use(idx, b).into(), + TypeParam::List { param: bx } if matches!(*bx, TypeParam::Type { .. }) => { + // There are two reasonable schemes for representing row variables: + // 1. TypeArg::Variable(idx, TypeParam::List(TypeParam::Type(typebound))) + // 2. TypeArg::Type(Type::new_row_var_use(idx, typebound)) + // Here we prefer the latter for canonicalization: TypeArgVariable's fields are non-pub + // so this pevents constructing malformed cases like the former. + let TypeParam::Type { b } = *bx else { panic!() }; + Type::new_row_var_use(idx, b).into() + } + // Similarly, prevent TypeArg::Variable(idx, TypeParam::Extensions) TypeParam::Extensions => TypeArg::Extensions { es: ExtensionSet::type_var(idx), }, @@ -240,7 +257,8 @@ impl TypeArg { var_decls: &[TypeParam], ) -> Result<(), SignatureError> { match self { - TypeArg::Type { ty } => ty.validate(extension_registry, var_decls), + // Row variables are represented as 'TypeArg::Type's (see TypeArg::new_var_use) + TypeArg::Type { ty } => ty.validate(true, extension_registry, var_decls), TypeArg::BoundedNat { .. 
} => Ok(()), TypeArg::Opaque { arg: custarg } => { // We could also add a facility to Extension to validate that the constant *value* @@ -255,15 +273,35 @@ impl TypeArg { TypeArg::Extensions { es: _ } => Ok(()), TypeArg::Variable { v: TypeArgVariable { idx, cached_decl }, - } => check_typevar_decl(var_decls, *idx, cached_decl), + } => { + assert!( + match cached_decl { + TypeParam::Type { .. } => false, + TypeParam::List { param } if matches!(**param, TypeParam::Type { .. }) => + false, + _ => true, + }, + "Malformed TypeArg::Variable {} - should be inconstructible", + cached_decl + ); + + check_typevar_decl(var_decls, *idx, cached_decl) + } } } pub(crate) fn substitute(&self, t: &Substitution) -> Self { match self { - TypeArg::Type { ty } => TypeArg::Type { - ty: ty.substitute(t), - }, + TypeArg::Type { ty } => { + let tys = ty.substitute(t).into_iter().map_into().collect::>(); + match as TryInto<[TypeArg; 1]>>::try_into(tys) { + Ok([ty]) => ty, + // Multiple elements can only have come from a row variable. + // So, we must be either in a TypeArg::Sequence; or a single Row Variable + // - fitting into a hole declared as a TypeParam::List (as per check_type_arg). + Err(elems) => TypeArg::Sequence { elems }, + } + } TypeArg::BoundedNat { .. } => self.clone(), // We do not allow variables as bounds on BoundedNat's TypeArg::Opaque { arg: CustomTypeArg { typ, .. }, @@ -273,9 +311,28 @@ impl TypeArg { debug_assert_eq!(&typ.substitute(t), typ); self.clone() } - TypeArg::Sequence { elems } => TypeArg::Sequence { - elems: elems.iter().map(|ta| ta.substitute(t)).collect(), - }, + TypeArg::Sequence { elems } => { + let mut are_types = elems.iter().map(|e| matches!(e, TypeArg::Type { .. 
})); + let elems = match are_types.next() { + Some(true) => { + assert!(are_types.all(|b| b)); // If one is a Type, so must the rest be + // So, anything that doesn't produce a Type, was a row variable => multiple Types + elems + .iter() + .flat_map(|ta| match ta.substitute(t) { + ty @ TypeArg::Type { .. } => vec![ty], + TypeArg::Sequence { elems } => elems, + _ => panic!("Expected Type or row of Types"), + }) + .collect() + } + _ => { + // not types, no need to flatten (and mustn't, in case of nested Sequences) + elems.iter().map(|ta| ta.substitute(t)).collect() + } + }; + TypeArg::Sequence { elems } + } TypeArg::Extensions { es } => TypeArg::Extensions { es: es.substitute(t), }, @@ -326,13 +383,29 @@ pub fn check_type_arg(arg: &TypeArg, param: &TypeParam) -> Result<(), TypeArgErr _, ) if param.contains(cached_decl) => Ok(()), (TypeArg::Type { ty }, TypeParam::Type { b: bound }) - if bound.contains(ty.least_upper_bound()) => + if bound.contains(ty.least_upper_bound()) && !ty.is_row_var() => { Ok(()) } (TypeArg::Sequence { elems }, TypeParam::List { param }) => { - elems.iter().try_for_each(|arg| check_type_arg(arg, param)) + elems.iter().try_for_each(|arg| { + // Also allow elements that are RowVars if fitting into a List of Types + if let (TypeArg::Type { ty }, TypeParam::Type { b }) = (arg, &**param) { + if ty.is_row_var() && b.contains(ty.least_upper_bound()) { + return Ok(()); + } + } + check_type_arg(arg, param) + }) } + // Also allow a single "Type" to be used for a List *only* if the Type is a row variable + // (i.e., it's not really a Type, it's multiple Types) + (TypeArg::Type { ty }, TypeParam::List { param }) + if ty.is_row_var() && param.contains(&ty.least_upper_bound().into()) => + { + Ok(()) + } + (TypeArg::Sequence { elems: items }, TypeParam::Tuple { params: types }) => { if items.len() != types.len() { Err(TypeArgError::WrongNumberTuple(items.len(), types.len())) @@ -402,6 +475,155 @@ pub enum TypeArgError { #[cfg(test)] mod test { + use 
itertools::Itertools; + + use super::{check_type_arg, Substitution, TypeArg, TypeParam}; + use crate::extension::prelude::{BOOL_T, PRELUDE_REGISTRY, USIZE_T}; + use crate::types::{type_param::TypeArgError, Type, TypeBound}; + + #[test] + fn type_arg_fits_param() { + let rowvar = Type::new_row_var_use; + fn check(arg: impl Into, parm: &TypeParam) -> Result<(), TypeArgError> { + check_type_arg(&arg.into(), parm) + } + fn check_seq>( + args: &[T], + parm: &TypeParam, + ) -> Result<(), TypeArgError> { + let arg = args.iter().cloned().map_into().collect_vec().into(); + check_type_arg(&arg, parm) + } + // Simple cases: a TypeArg::Type is a TypeParam::Type but singleton sequences are lists + check(USIZE_T, &TypeBound::Eq.into()).unwrap(); + let seq_param = TypeParam::new_list(TypeBound::Eq); + check(USIZE_T, &seq_param).unwrap_err(); + check_seq(&[USIZE_T], &TypeBound::Any.into()).unwrap_err(); + + // Into a list of type, we can fit a single row var + check(rowvar(0, TypeBound::Eq), &seq_param).unwrap(); + // or a list of (types or row vars) + check(vec![], &seq_param).unwrap(); + check_seq(&[rowvar(0, TypeBound::Eq)], &seq_param).unwrap(); + check_seq( + &[rowvar(1, TypeBound::Any), USIZE_T, rowvar(0, TypeBound::Eq)], + &TypeParam::new_list(TypeBound::Any), + ) + .unwrap(); + // Next one fails because a list of Eq is required + check_seq( + &[rowvar(1, TypeBound::Any), USIZE_T, rowvar(0, TypeBound::Eq)], + &seq_param, + ) + .unwrap_err(); + // seq of seq of types is not allowed + check( + vec![USIZE_T.into(), vec![USIZE_T.into()].into()], + &seq_param, + ) + .unwrap_err(); + + // Similar for nats (but no equivalent of fancy row vars) + check(5, &TypeParam::max_nat()).unwrap(); + check_seq(&[5], &TypeParam::max_nat()).unwrap_err(); + let list_of_nat = TypeParam::new_list(TypeParam::max_nat()); + check(5, &list_of_nat).unwrap_err(); + check_seq(&[5], &list_of_nat).unwrap(); + check(TypeArg::new_var_use(0, list_of_nat.clone()), &list_of_nat).unwrap(); + // But no equivalent 
of row vars - can't append a nat onto a list-in-a-var: + check( + vec![5.into(), TypeArg::new_var_use(0, list_of_nat.clone())], + &list_of_nat, + ) + .unwrap_err(); + + // TypeParam::Tuples require a TypeArg::Seq of the same number of elems + let usize_and_ty = TypeParam::Tuple { + params: vec![TypeParam::max_nat(), TypeBound::Eq.into()], + }; + check(vec![5.into(), USIZE_T.into()], &usize_and_ty).unwrap(); + check(vec![USIZE_T.into(), 5.into()], &usize_and_ty).unwrap_err(); // Wrong way around + let two_types = TypeParam::Tuple { + params: vec![TypeBound::Any.into(), TypeBound::Any.into()], + }; + check(TypeArg::new_var_use(0, two_types.clone()), &two_types).unwrap(); + // not a Row Var which could have any number of elems + check(TypeArg::new_var_use(0, seq_param), &two_types).unwrap_err(); + } + + #[test] + fn type_arg_subst_row() { + let row_param = TypeParam::new_list(TypeBound::Copyable); + let row_arg: TypeArg = vec![BOOL_T.into(), Type::UNIT.into()].into(); + check_type_arg(&row_arg, &row_param).unwrap(); + + // Now say a row variable referring to *that* row was used + // to instantiate an outer "row parameter" (list of type). 
+ let outer_param = TypeParam::new_list(TypeBound::Any); + let outer_arg = TypeArg::Sequence { + elems: vec![ + Type::new_row_var_use(0, TypeBound::Copyable).into(), + USIZE_T.into(), + ], + }; + check_type_arg(&outer_arg, &outer_param).unwrap(); + + let outer_arg2 = outer_arg.substitute(&Substitution(&[row_arg], &PRELUDE_REGISTRY)); + assert_eq!( + outer_arg2, + vec![BOOL_T.into(), Type::UNIT.into(), USIZE_T.into()].into() + ); + + // Of course this is still valid (as substitution is guaranteed to preserve validity) + check_type_arg(&outer_arg2, &outer_param).unwrap(); + } + + #[test] + fn subst_list_list() { + let outer_param = TypeParam::new_list(TypeParam::new_list(TypeBound::Any)); + let row_var_decl = TypeParam::new_list(TypeBound::Copyable); + let row_var_use = TypeArg::new_var_use(0, row_var_decl.clone()); + let good_arg = TypeArg::Sequence { + elems: vec![ + // The row variables here refer to `row_var_decl` above + vec![USIZE_T.into()].into(), + row_var_use.clone(), + vec![row_var_use, USIZE_T.into()].into(), + ], + }; + check_type_arg(&good_arg, &outer_param).unwrap(); + + // Outer list cannot include single types: + let TypeArg::Sequence { mut elems } = good_arg.clone() else { + panic!() + }; + elems.push(USIZE_T.into()); + assert_eq!( + check_type_arg(&TypeArg::Sequence { elems }, &outer_param), + Err(TypeArgError::TypeMismatch { + arg: USIZE_T.into(), + // The error reports the type expected for each element of the list: + param: TypeParam::new_list(TypeBound::Any) + }) + ); + + // Now substitute a list of two types for that row-variable + let row_var_arg = vec![USIZE_T.into(), BOOL_T.into()].into(); + check_type_arg(&row_var_arg, &row_var_decl).unwrap(); + let subst_arg = + good_arg.substitute(&Substitution(&[row_var_arg.clone()], &PRELUDE_REGISTRY)); + check_type_arg(&subst_arg, &outer_param).unwrap(); // invariance of substitution + assert_eq!( + subst_arg, + TypeArg::Sequence { + elems: vec![ + vec![USIZE_T.into()].into(), + row_var_arg, + 
vec![USIZE_T.into(), BOOL_T.into(), USIZE_T.into()].into() + ] + } + ); + } mod proptest { diff --git a/hugr/src/types/type_row.rs b/hugr/src/types/type_row.rs index 3d97a8068..5be3ddc2f 100644 --- a/hugr/src/types/type_row.rs +++ b/hugr/src/types/type_row.rs @@ -7,8 +7,11 @@ use std::{ ops::{Deref, DerefMut}, }; -use super::Type; -use crate::utils::display_list; +use super::{type_param::TypeParam, Substitution, Type}; +use crate::{ + extension::{ExtensionRegistry, SignatureError}, + utils::display_list, +}; use delegate::delegate; use itertools::Itertools; @@ -75,6 +78,28 @@ impl TypeRow { pub fn get_mut(&mut self, offset: usize) -> Option<&mut Type>; } } + + /// Applies a substitution to the row. Note this may change the length + /// if-and-only-if the row contains any [RowVariable]s. + /// + /// [RowVariable]: [crate::types::TypeEnum::RowVariable] + pub(super) fn substitute(&self, tr: &Substitution) -> TypeRow { + let res = self + .iter() + .flat_map(|ty| ty.substitute(tr)) + .collect::>() + .into(); + res + } + + pub(super) fn validate_var_len( + &self, + exts: &ExtensionRegistry, + var_decls: &[TypeParam], + ) -> Result<(), SignatureError> { + self.iter() + .try_for_each(|t| t.validate(true, exts, var_decls)) + } } impl Default for TypeRow { diff --git a/specification/schema/hugr_schema_strict_v1.json b/specification/schema/hugr_schema_strict_v1.json index df387a43f..d0c5aa92e 100644 --- a/specification/schema/hugr_schema_strict_v1.json +++ b/specification/schema/hugr_schema_strict_v1.json @@ -1652,6 +1652,34 @@ "title": "Qubit", "type": "object" }, + "RowVar": { + "additionalProperties": false, + "description": "A variable standing for a row of some (unknown) number of types.\nMay occur only within a row; not a node input/output.", + "properties": { + "t": { + "const": "R", + "default": "R", + "enum": [ + "R" + ], + "title": "T", + "type": "string" + }, + "i": { + "title": "I", + "type": "integer" + }, + "b": { + "$ref": "#/$defs/TypeBound" + } + }, + 
"required": [ + "i", + "b" + ], + "title": "RowVar", + "type": "object" + }, "SequenceArg": { "additionalProperties": false, "properties": { @@ -1913,6 +1941,7 @@ "I": "#/$defs/USize", "Opaque": "#/$defs/Opaque", "Q": "#/$defs/Qubit", + "R": "#/$defs/RowVar", "Sum": "#/$defs/SumType", "V": "#/$defs/Variable" }, @@ -1925,6 +1954,9 @@ { "$ref": "#/$defs/Variable" }, + { + "$ref": "#/$defs/RowVar" + }, { "$ref": "#/$defs/USize" }, diff --git a/specification/schema/hugr_schema_v1.json b/specification/schema/hugr_schema_v1.json index 6e8c177eb..d5dab6428 100644 --- a/specification/schema/hugr_schema_v1.json +++ b/specification/schema/hugr_schema_v1.json @@ -1652,6 +1652,34 @@ "title": "Qubit", "type": "object" }, + "RowVar": { + "additionalProperties": true, + "description": "A variable standing for a row of some (unknown) number of types.\nMay occur only within a row; not a node input/output.", + "properties": { + "t": { + "const": "R", + "default": "R", + "enum": [ + "R" + ], + "title": "T", + "type": "string" + }, + "i": { + "title": "I", + "type": "integer" + }, + "b": { + "$ref": "#/$defs/TypeBound" + } + }, + "required": [ + "i", + "b" + ], + "title": "RowVar", + "type": "object" + }, "SequenceArg": { "additionalProperties": true, "properties": { @@ -1913,6 +1941,7 @@ "I": "#/$defs/USize", "Opaque": "#/$defs/Opaque", "Q": "#/$defs/Qubit", + "R": "#/$defs/RowVar", "Sum": "#/$defs/SumType", "V": "#/$defs/Variable" }, @@ -1925,6 +1954,9 @@ { "$ref": "#/$defs/Variable" }, + { + "$ref": "#/$defs/RowVar" + }, { "$ref": "#/$defs/USize" }, diff --git a/specification/schema/testing_hugr_schema_strict_v1.json b/specification/schema/testing_hugr_schema_strict_v1.json index 39687363c..9c272a944 100644 --- a/specification/schema/testing_hugr_schema_strict_v1.json +++ b/specification/schema/testing_hugr_schema_strict_v1.json @@ -1729,6 +1729,34 @@ "title": "Qubit", "type": "object" }, + "RowVar": { + "additionalProperties": false, + "description": "A variable standing for a row 
of some (unknown) number of types.\nMay occur only within a row; not a node input/output.", + "properties": { + "t": { + "const": "R", + "default": "R", + "enum": [ + "R" + ], + "title": "T", + "type": "string" + }, + "i": { + "title": "I", + "type": "integer" + }, + "b": { + "$ref": "#/$defs/TypeBound" + } + }, + "required": [ + "i", + "b" + ], + "title": "RowVar", + "type": "object" + }, "SequenceArg": { "additionalProperties": false, "properties": { @@ -1990,6 +2018,7 @@ "I": "#/$defs/USize", "Opaque": "#/$defs/Opaque", "Q": "#/$defs/Qubit", + "R": "#/$defs/RowVar", "Sum": "#/$defs/SumType", "V": "#/$defs/Variable" }, @@ -2002,6 +2031,9 @@ { "$ref": "#/$defs/Variable" }, + { + "$ref": "#/$defs/RowVar" + }, { "$ref": "#/$defs/USize" }, diff --git a/specification/schema/testing_hugr_schema_v1.json b/specification/schema/testing_hugr_schema_v1.json index 7b87683af..01c3d6bb8 100644 --- a/specification/schema/testing_hugr_schema_v1.json +++ b/specification/schema/testing_hugr_schema_v1.json @@ -1729,6 +1729,34 @@ "title": "Qubit", "type": "object" }, + "RowVar": { + "additionalProperties": true, + "description": "A variable standing for a row of some (unknown) number of types.\nMay occur only within a row; not a node input/output.", + "properties": { + "t": { + "const": "R", + "default": "R", + "enum": [ + "R" + ], + "title": "T", + "type": "string" + }, + "i": { + "title": "I", + "type": "integer" + }, + "b": { + "$ref": "#/$defs/TypeBound" + } + }, + "required": [ + "i", + "b" + ], + "title": "RowVar", + "type": "object" + }, "SequenceArg": { "additionalProperties": true, "properties": { @@ -1990,6 +2018,7 @@ "I": "#/$defs/USize", "Opaque": "#/$defs/Opaque", "Q": "#/$defs/Qubit", + "R": "#/$defs/RowVar", "Sum": "#/$defs/SumType", "V": "#/$defs/Variable" }, @@ -2002,6 +2031,9 @@ { "$ref": "#/$defs/Variable" }, + { + "$ref": "#/$defs/RowVar" + }, { "$ref": "#/$defs/USize" },