From a2e4d050f7f270e992f9d498cc1fabb0ec2c49ab Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 2 Apr 2024 12:19:04 +0100 Subject: [PATCH] fix: polymorphic calls (#901) * Store the polymorphic type of the called function in Call node, use for the static input * Also store the typeargs, and the cache of the instantiation BREAKING CHANGE: `Call` nodes now constructed via `try_new` function --- quantinuum-hugr/src/builder/build_traits.rs | 3 +- quantinuum-hugr/src/hugr/validate/test.rs | 18 +++++++++ quantinuum-hugr/src/ops/dataflow.rs | 43 ++++++++++++++++----- 3 files changed, 53 insertions(+), 11 deletions(-) diff --git a/quantinuum-hugr/src/builder/build_traits.rs b/quantinuum-hugr/src/builder/build_traits.rs index 277bcd0a6..7b510df68 100644 --- a/quantinuum-hugr/src/builder/build_traits.rs +++ b/quantinuum-hugr/src/builder/build_traits.rs @@ -576,8 +576,7 @@ pub trait Dataflow: Container { }) } }; - let signature = type_scheme.instantiate(type_args, exts)?; - let op: OpType = ops::Call { signature }.into(); + let op: OpType = ops::Call::try_new(type_scheme, type_args, exts)?.into(); let const_in_port = op.static_input_port().unwrap(); let op_id = self.add_dataflow_op(op, input_wires)?; let src_port = self.hugr_mut().num_outputs(function.node()) - 1; diff --git a/quantinuum-hugr/src/hugr/validate/test.rs b/quantinuum-hugr/src/hugr/validate/test.rs index 3bf93758c..6ecf00699 100644 --- a/quantinuum-hugr/src/hugr/validate/test.rs +++ b/quantinuum-hugr/src/hugr/validate/test.rs @@ -4,6 +4,7 @@ use super::*; use crate::builder::test::closed_dfg_root_hugr; use crate::builder::{ BuildError, Container, Dataflow, DataflowHugr, DataflowSubContainer, FunctionBuilder, + HugrBuilder, ModuleBuilder, }; use crate::extension::prelude::{BOOL_T, PRELUDE, USIZE_T}; use crate::extension::{Extension, ExtensionId, TypeDefBound, EMPTY_REG, PRELUDE_REGISTRY}; @@ -540,6 +541,23 @@ fn no_polymorphic_consts() -> Result<(), Box> { Ok(()) } +#[test] +fn test_polymorphic_call() -> Result<(), Box> { + let mut m = ModuleBuilder::new(); + let id = m.declare( + "id", + PolyFuncType::new( + vec![TypeBound::Any.into()], + FunctionType::new_endo(vec![Type::new_var_use(0, TypeBound::Any)]), + ), + )?; + let mut f = m.define_function("main", FunctionType::new_endo(vec![USIZE_T]).into())?; + let c = f.call(&id, &[USIZE_T.into()], f.input_wires(), &PRELUDE_REGISTRY)?; + f.finish_with_outputs(c.outputs())?; + let _ = m.finish_prelude_hugr()?; + Ok(()) +} + #[cfg(feature = "extension_inference")] mod extension_tests { use super::*; diff --git a/quantinuum-hugr/src/ops/dataflow.rs b/quantinuum-hugr/src/ops/dataflow.rs index 74ef64dde..9929ac5ac 100644 --- a/quantinuum-hugr/src/ops/dataflow.rs +++ b/quantinuum-hugr/src/ops/dataflow.rs @@ -2,9 +2,9 @@ use super::{impl_op_name, OpTag, OpTrait}; -use crate::extension::ExtensionSet; +use crate::extension::{ExtensionRegistry, ExtensionSet, SignatureError}; use crate::ops::StaticTag; -use crate::types::{EdgeKind, FunctionType, Type, TypeRow}; +use crate::types::{EdgeKind, FunctionType, PolyFuncType, Type, TypeArg, TypeRow}; use crate::IncomingPort; pub(crate) trait DataflowOpTrait { @@ -153,7 +153,9 @@ impl StaticTag for T { #[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)] pub struct Call { /// Signature of function being called - pub signature: FunctionType, + func_sig: PolyFuncType, + type_args: Vec, + instantiation: FunctionType, // Cache, so we can fail in try_new() not in signature() } impl_op_name!(Call); @@ -165,7 +167,7 @@ impl DataflowOpTrait for Call { } fn signature(&self) -> FunctionType { - self.signature.clone() + self.instantiation.clone() } fn static_input(&self) -> Option { @@ -174,10 +176,28 @@ impl DataflowOpTrait for Call { } } impl Call { + /// Try to make a new Call. Returns an error if the `type_args`` do not fit the [TypeParam]s + /// declared by the function. + /// + /// [TypeParam]: crate::types::type_param::TypeParam + pub fn try_new( + func_sig: PolyFuncType, + type_args: impl Into>, + exts: &ExtensionRegistry, + ) -> Result { + let type_args = type_args.into(); + let instantiation = func_sig.instantiate(&type_args, exts)?; + Ok(Self { + func_sig, + type_args, + instantiation, + }) + } + #[inline] /// Return the signature of the function called by this op. - pub fn called_function_type(&self) -> &FunctionType { - &self.signature + pub fn called_function_type(&self) -> &PolyFuncType { + &self.func_sig } /// The IncomingPort which links to the function being called. @@ -189,8 +209,9 @@ impl Call { /// # use hugr::ops::OpType; /// # use hugr::types::FunctionType; /// # use hugr::extension::prelude::QB_T; + /// # use hugr::extension::PRELUDE_REGISTRY; /// let signature = FunctionType::new(vec![QB_T, QB_T], vec![QB_T, QB_T]); - /// let call = Call { signature }; + /// let call = Call::try_new(signature.into(), &[], &PRELUDE_REGISTRY).unwrap(); /// let op = OpType::Call(call.clone()); /// assert_eq!(op.static_input_port(), Some(call.called_function_port())); /// ``` @@ -198,11 +219,15 @@ impl Call { /// [`OpType::static_input_port`]: crate::ops::OpType::static_input_port #[inline] pub fn called_function_port(&self) -> IncomingPort { - self.called_function_type().input_count().into() + self.instantiation.input_count().into() } } -/// Call a function indirectly. Like call, but the first input is a standard dataflow graph type. +/// Call a function indirectly. Like call, but the function input is a value +/// (runtime, not static) dataflow edge, and we assume all its binders have +/// already been given [TypeArg]s by [TypeApply] nodes. +/// +/// [TypeApply]: crate::ops::LeafOp::TypeApply #[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)] pub struct CallIndirect { /// Signature of function being called