diff --git a/hugr-core/src/extension/resolution/test.rs b/hugr-core/src/extension/resolution/test.rs index c4b4bad79..19373b04c 100644 --- a/hugr-core/src/extension/resolution/test.rs +++ b/hugr-core/src/extension/resolution/test.rs @@ -22,11 +22,12 @@ use crate::extension::{ use crate::ops::constant::test::CustomTestValue; use crate::ops::constant::CustomConst; use crate::ops::{CallIndirect, ExtensionOp, Input, OpTrait, OpType, Tag, Value}; -use crate::std_extensions::arithmetic::float_types::{float64_type, ConstF64}; +use crate::std_extensions::arithmetic::float_types::{self, float64_type, ConstF64}; use crate::std_extensions::arithmetic::int_ops; use crate::std_extensions::arithmetic::int_types::{self, int_type}; use crate::std_extensions::collections::list::ListValue; -use crate::types::{Signature, Type}; +use crate::types::type_param::TypeParam; +use crate::types::{PolyFuncType, Signature, Type, TypeArg, TypeBound}; use crate::{std_extensions, type_row, Extension, Hugr, HugrView}; #[rstest] @@ -346,6 +347,46 @@ fn resolve_custom_const(#[case] custom_const: impl CustomConst) { check_extension_resolution(hugr); } +/// Test resolution of function call with type arguments. +#[rstest] +fn resolve_call() { + let dummy_fn_sig = PolyFuncType::new( + vec![TypeParam::Type { b: TypeBound::Any }], + Signature::new(vec![], vec![bool_t()]), + ); + + let generic_type_1 = TypeArg::Type { ty: float64_type() }; + let generic_type_2 = TypeArg::Type { ty: int_type(6) }; + let expected_exts = [ + float_types::EXTENSION_ID.to_owned(), + int_types::EXTENSION_ID.to_owned(), + ] + .into_iter() + .collect::(); + + let mut module = ModuleBuilder::new(); + let dummy_fn = module.declare("called_fn", dummy_fn_sig).unwrap(); + + let mut func = module + .define_function( + "caller_fn", + Signature::new(vec![], vec![bool_t()]) + .with_extension_delta(ExtensionSet::from_iter(expected_exts.clone())), + ) + .unwrap(); + let _load_func = func.load_func(&dummy_fn, &[generic_type_1]).unwrap(); + let call = func.call(&dummy_fn, &[generic_type_2], vec![]).unwrap(); + func.finish_with_outputs(call.outputs()).unwrap(); + + let hugr = module.finish_hugr().unwrap(); + + for ext in expected_exts { + assert!(hugr.extensions().contains(&ext)); + } + + check_extension_resolution(hugr); +} + /// Fail when collecting extensions but the weak pointers are not resolved. #[rstest] fn dropped_weak_extensions() { diff --git a/hugr-core/src/extension/resolution/types.rs b/hugr-core/src/extension/resolution/types.rs index 7559ea3fe..6094f0aee 100644 --- a/hugr-core/src/extension/resolution/types.rs +++ b/hugr-core/src/extension/resolution/types.rs @@ -50,12 +50,18 @@ pub(crate) fn collect_op_types_extensions( OpType::Call(c) => { collect_signature_exts(c.func_sig.body(), &mut used, &mut missing); collect_signature_exts(&c.instantiation, &mut used, &mut missing); + for arg in &c.type_args { + collect_typearg_exts(arg, &mut used, &mut missing); + } } OpType::CallIndirect(c) => collect_signature_exts(&c.signature, &mut used, &mut missing), OpType::LoadConstant(lc) => collect_type_exts(&lc.datatype, &mut used, &mut missing), OpType::LoadFunction(lf) => { collect_signature_exts(lf.func_sig.body(), &mut used, &mut missing); collect_signature_exts(&lf.instantiation, &mut used, &mut missing); + for arg in &lf.type_args { + collect_typearg_exts(arg, &mut used, &mut missing); + } } OpType::DFG(dfg) => collect_signature_exts(&dfg.signature, &mut used, &mut missing), OpType::OpaqueOp(op) => { @@ -203,7 +209,7 @@ pub(super) fn collect_type_exts( /// - `used_extensions`: A The registry where to store the used extensions. /// - `missing_extensions`: A set of `ExtensionId`s of which the /// `Weak` pointer has been invalidated. -fn collect_typearg_exts( +pub(super) fn collect_typearg_exts( arg: &TypeArg, used_extensions: &mut WeakExtensionRegistry, missing_extensions: &mut ExtensionSet, diff --git a/hugr-core/src/extension/resolution/types_mut.rs b/hugr-core/src/extension/resolution/types_mut.rs index 7d30ac4fa..d70d6b861 100644 --- a/hugr-core/src/extension/resolution/types_mut.rs +++ b/hugr-core/src/extension/resolution/types_mut.rs @@ -50,6 +50,9 @@ pub fn resolve_op_types_extensions( OpType::Call(c) => { resolve_signature_exts(node, c.func_sig.body_mut(), extensions, used_extensions)?; resolve_signature_exts(node, &mut c.instantiation, extensions, used_extensions)?; + for arg in &mut c.type_args { + resolve_typearg_exts(node, arg, extensions, used_extensions)?; + } } OpType::CallIndirect(c) => { resolve_signature_exts(node, &mut c.signature, extensions, used_extensions)? @@ -60,6 +63,9 @@ pub fn resolve_op_types_extensions( OpType::LoadFunction(lf) => { resolve_signature_exts(node, lf.func_sig.body_mut(), extensions, used_extensions)?; resolve_signature_exts(node, &mut lf.instantiation, extensions, used_extensions)?; + for arg in &mut lf.type_args { + resolve_typearg_exts(node, arg, extensions, used_extensions)?; + } } OpType::DFG(dfg) => { resolve_signature_exts(node, &mut dfg.signature, extensions, used_extensions)? diff --git a/hugr-core/src/types.rs b/hugr-core/src/types.rs index 1e18bf94d..1c5e8f8e6 100644 --- a/hugr-core/src/types.rs +++ b/hugr-core/src/types.rs @@ -587,6 +587,8 @@ impl From for TypeRV { /// Details a replacement of type variables with a finite list of known values. /// (Variables out of the range of the list will result in a panic) +#[derive(Clone, Debug, derive_more::Display)] +#[display("[{}]", _0.iter().map(|ta| ta.to_string()).join(", "))] pub struct Substitution<'a>(&'a [TypeArg]); impl<'a> Substitution<'a> {