Skip to content

Commit

Permalink
fix: Call ops not tracking their parameter extensions (#1805)
Browse files Browse the repository at this point in the history
Fixes #1795 

I forgot to track the `type_args` of `Call` and `LoadFunction` when
doing extension resolution

drive-by: Add derives to `types::Substitution`
  • Loading branch information
aborgna-q authored Dec 17, 2024
1 parent d730c43 commit e40b6c7
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 3 deletions.
45 changes: 43 additions & 2 deletions hugr-core/src/extension/resolution/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,12 @@ use crate::extension::{
use crate::ops::constant::test::CustomTestValue;
use crate::ops::constant::CustomConst;
use crate::ops::{CallIndirect, ExtensionOp, Input, OpTrait, OpType, Tag, Value};
use crate::std_extensions::arithmetic::float_types::{float64_type, ConstF64};
use crate::std_extensions::arithmetic::float_types::{self, float64_type, ConstF64};
use crate::std_extensions::arithmetic::int_ops;
use crate::std_extensions::arithmetic::int_types::{self, int_type};
use crate::std_extensions::collections::list::ListValue;
use crate::types::{Signature, Type};
use crate::types::type_param::TypeParam;
use crate::types::{PolyFuncType, Signature, Type, TypeArg, TypeBound};
use crate::{std_extensions, type_row, Extension, Hugr, HugrView};

#[rstest]
Expand Down Expand Up @@ -346,6 +347,46 @@ fn resolve_custom_const(#[case] custom_const: impl CustomConst) {
check_extension_resolution(hugr);
}

/// Test resolution of function call with type arguments.
#[rstest]
fn resolve_call() {
let dummy_fn_sig = PolyFuncType::new(
vec![TypeParam::Type { b: TypeBound::Any }],
Signature::new(vec![], vec![bool_t()]),
);

let generic_type_1 = TypeArg::Type { ty: float64_type() };
let generic_type_2 = TypeArg::Type { ty: int_type(6) };
let expected_exts = [
float_types::EXTENSION_ID.to_owned(),
int_types::EXTENSION_ID.to_owned(),
]
.into_iter()
.collect::<ExtensionSet>();

let mut module = ModuleBuilder::new();
let dummy_fn = module.declare("called_fn", dummy_fn_sig).unwrap();

let mut func = module
.define_function(
"caller_fn",
Signature::new(vec![], vec![bool_t()])
.with_extension_delta(ExtensionSet::from_iter(expected_exts.clone())),
)
.unwrap();
let _load_func = func.load_func(&dummy_fn, &[generic_type_1]).unwrap();
let call = func.call(&dummy_fn, &[generic_type_2], vec![]).unwrap();
func.finish_with_outputs(call.outputs()).unwrap();

let hugr = module.finish_hugr().unwrap();

for ext in expected_exts {
assert!(hugr.extensions().contains(&ext));
}

check_extension_resolution(hugr);
}

/// Fail when collecting extensions but the weak pointers are not resolved.
#[rstest]
fn dropped_weak_extensions() {
Expand Down
8 changes: 7 additions & 1 deletion hugr-core/src/extension/resolution/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,18 @@ pub(crate) fn collect_op_types_extensions(
OpType::Call(c) => {
collect_signature_exts(c.func_sig.body(), &mut used, &mut missing);
collect_signature_exts(&c.instantiation, &mut used, &mut missing);
for arg in &c.type_args {
collect_typearg_exts(arg, &mut used, &mut missing);
}
}
OpType::CallIndirect(c) => collect_signature_exts(&c.signature, &mut used, &mut missing),
OpType::LoadConstant(lc) => collect_type_exts(&lc.datatype, &mut used, &mut missing),
OpType::LoadFunction(lf) => {
collect_signature_exts(lf.func_sig.body(), &mut used, &mut missing);
collect_signature_exts(&lf.instantiation, &mut used, &mut missing);
for arg in &lf.type_args {
collect_typearg_exts(arg, &mut used, &mut missing);
}
}
OpType::DFG(dfg) => collect_signature_exts(&dfg.signature, &mut used, &mut missing),
OpType::OpaqueOp(op) => {
Expand Down Expand Up @@ -203,7 +209,7 @@ pub(super) fn collect_type_exts<RV: MaybeRV>(
/// - `used_extensions`: A The registry where to store the used extensions.
/// - `missing_extensions`: A set of `ExtensionId`s of which the
/// `Weak<Extension>` pointer has been invalidated.
fn collect_typearg_exts(
pub(super) fn collect_typearg_exts(
arg: &TypeArg,
used_extensions: &mut WeakExtensionRegistry,
missing_extensions: &mut ExtensionSet,
Expand Down
6 changes: 6 additions & 0 deletions hugr-core/src/extension/resolution/types_mut.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@ pub fn resolve_op_types_extensions(
OpType::Call(c) => {
resolve_signature_exts(node, c.func_sig.body_mut(), extensions, used_extensions)?;
resolve_signature_exts(node, &mut c.instantiation, extensions, used_extensions)?;
for arg in &mut c.type_args {
resolve_typearg_exts(node, arg, extensions, used_extensions)?;
}
}
OpType::CallIndirect(c) => {
resolve_signature_exts(node, &mut c.signature, extensions, used_extensions)?
Expand All @@ -60,6 +63,9 @@ pub fn resolve_op_types_extensions(
OpType::LoadFunction(lf) => {
resolve_signature_exts(node, lf.func_sig.body_mut(), extensions, used_extensions)?;
resolve_signature_exts(node, &mut lf.instantiation, extensions, used_extensions)?;
for arg in &mut lf.type_args {
resolve_typearg_exts(node, arg, extensions, used_extensions)?;
}
}
OpType::DFG(dfg) => {
resolve_signature_exts(node, &mut dfg.signature, extensions, used_extensions)?
Expand Down
2 changes: 2 additions & 0 deletions hugr-core/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -587,6 +587,8 @@ impl From<Type> for TypeRV {

/// Details a replacement of type variables with a finite list of known values.
/// (Variables out of the range of the list will result in a panic)
#[derive(Clone, Debug, derive_more::Display)]
#[display("[{}]", _0.iter().map(|ta| ta.to_string()).join(", "))]
pub struct Substitution<'a>(&'a [TypeArg]);

impl<'a> Substitution<'a> {
Expand Down

0 comments on commit e40b6c7

Please sign in to comment.