Skip to content

Commit

Permalink
fix: polymorphic calls (#901)
Browse files Browse the repository at this point in the history
* Store the polymorphic type of the called function in Call node, use
for the static input
* Also store the typeargs, and the cache of the instantiation

BREAKING CHANGE: `Call` nodes now constructed via `try_new` function
  • Loading branch information
acl-cqc authored Apr 2, 2024
1 parent 8111375 commit a2e4d05
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 11 deletions.
3 changes: 1 addition & 2 deletions quantinuum-hugr/src/builder/build_traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -576,8 +576,7 @@ pub trait Dataflow: Container {
})
}
};
let signature = type_scheme.instantiate(type_args, exts)?;
let op: OpType = ops::Call { signature }.into();
let op: OpType = ops::Call::try_new(type_scheme, type_args, exts)?.into();
let const_in_port = op.static_input_port().unwrap();
let op_id = self.add_dataflow_op(op, input_wires)?;
let src_port = self.hugr_mut().num_outputs(function.node()) - 1;
Expand Down
18 changes: 18 additions & 0 deletions quantinuum-hugr/src/hugr/validate/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use super::*;
use crate::builder::test::closed_dfg_root_hugr;
use crate::builder::{
BuildError, Container, Dataflow, DataflowHugr, DataflowSubContainer, FunctionBuilder,
HugrBuilder, ModuleBuilder,
};
use crate::extension::prelude::{BOOL_T, PRELUDE, USIZE_T};
use crate::extension::{Extension, ExtensionId, TypeDefBound, EMPTY_REG, PRELUDE_REGISTRY};
Expand Down Expand Up @@ -540,6 +541,23 @@ fn no_polymorphic_consts() -> Result<(), Box<dyn std::error::Error>> {
Ok(())
}

#[test]
fn test_polymorphic_call() -> Result<(), Box<dyn std::error::Error>> {
let mut m = ModuleBuilder::new();
let id = m.declare(
"id",
PolyFuncType::new(
vec![TypeBound::Any.into()],
FunctionType::new_endo(vec![Type::new_var_use(0, TypeBound::Any)]),
),
)?;
let mut f = m.define_function("main", FunctionType::new_endo(vec![USIZE_T]).into())?;
let c = f.call(&id, &[USIZE_T.into()], f.input_wires(), &PRELUDE_REGISTRY)?;
f.finish_with_outputs(c.outputs())?;
let _ = m.finish_prelude_hugr()?;
Ok(())
}

#[cfg(feature = "extension_inference")]
mod extension_tests {
use super::*;
Expand Down
43 changes: 34 additions & 9 deletions quantinuum-hugr/src/ops/dataflow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
use super::{impl_op_name, OpTag, OpTrait};

use crate::extension::ExtensionSet;
use crate::extension::{ExtensionRegistry, ExtensionSet, SignatureError};
use crate::ops::StaticTag;
use crate::types::{EdgeKind, FunctionType, Type, TypeRow};
use crate::types::{EdgeKind, FunctionType, PolyFuncType, Type, TypeArg, TypeRow};
use crate::IncomingPort;

pub(crate) trait DataflowOpTrait {
Expand Down Expand Up @@ -153,7 +153,9 @@ impl<T: DataflowOpTrait> StaticTag for T {
#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub struct Call {
/// Signature of function being called
pub signature: FunctionType,
func_sig: PolyFuncType,
type_args: Vec<TypeArg>,
instantiation: FunctionType, // Cache, so we can fail in try_new() not in signature()
}
impl_op_name!(Call);

Expand All @@ -165,7 +167,7 @@ impl DataflowOpTrait for Call {
}

fn signature(&self) -> FunctionType {
self.signature.clone()
self.instantiation.clone()
}

fn static_input(&self) -> Option<EdgeKind> {
Expand All @@ -174,10 +176,28 @@ impl DataflowOpTrait for Call {
}
}
impl Call {
/// Try to make a new Call. Returns an error if the `type_args`` do not fit the [TypeParam]s
/// declared by the function.
///
/// [TypeParam]: crate::types::type_param::TypeParam
pub fn try_new(
func_sig: PolyFuncType,
type_args: impl Into<Vec<TypeArg>>,
exts: &ExtensionRegistry,
) -> Result<Self, SignatureError> {
let type_args = type_args.into();
let instantiation = func_sig.instantiate(&type_args, exts)?;
Ok(Self {
func_sig,
type_args,
instantiation,
})
}

#[inline]
/// Return the signature of the function called by this op.
pub fn called_function_type(&self) -> &FunctionType {
&self.signature
pub fn called_function_type(&self) -> &PolyFuncType {
&self.func_sig
}

/// The IncomingPort which links to the function being called.
Expand All @@ -189,20 +209,25 @@ impl Call {
/// # use hugr::ops::OpType;
/// # use hugr::types::FunctionType;
/// # use hugr::extension::prelude::QB_T;
/// # use hugr::extension::PRELUDE_REGISTRY;
/// let signature = FunctionType::new(vec![QB_T, QB_T], vec![QB_T, QB_T]);
/// let call = Call { signature };
/// let call = Call::try_new(signature.into(), &[], &PRELUDE_REGISTRY).unwrap();
/// let op = OpType::Call(call.clone());
/// assert_eq!(op.static_input_port(), Some(call.called_function_port()));
/// ```
///
/// [`OpType::static_input_port`]: crate::ops::OpType::static_input_port
#[inline]
pub fn called_function_port(&self) -> IncomingPort {
self.called_function_type().input_count().into()
self.instantiation.input_count().into()
}
}

/// Call a function indirectly. Like call, but the first input is a standard dataflow graph type.
/// Call a function indirectly. Like call, but the function input is a value
/// (runtime, not static) dataflow edge, and we assume all its binders have
/// already been given [TypeArg]s by [TypeApply] nodes.
///
/// [TypeApply]: crate::ops::LeafOp::TypeApply
#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub struct CallIndirect {
/// Signature of function being called
Expand Down

0 comments on commit a2e4d05

Please sign in to comment.