diff --git a/hugr-py/src/hugr/serialization/ops.py b/hugr-py/src/hugr/serialization/ops.py index 7ca116684..ea4e216d0 100644 --- a/hugr-py/src/hugr/serialization/ops.py +++ b/hugr-py/src/hugr/serialization/ops.py @@ -300,6 +300,18 @@ class LoadConstant(DataflowOp): datatype: Type +class LoadFunction(DataflowOp): + """Load a static function in to the local dataflow graph.""" + + op: Literal["LoadFunction"] = "LoadFunction" + func_sig: PolyFuncType + type_args: list[tys.TypeArg] + signature: FunctionType = Field(default_factory=FunctionType.empty) + + def insert_port_types(self, in_types: TypeRow, out_types: TypeRow) -> None: + self.signature = FunctionType(input=list(in_types), output=list(out_types)) + + class DFG(DataflowOp): """A simply nested dataflow graph.""" @@ -486,6 +498,7 @@ class OpType(RootModel): | Call | CallIndirect | LoadConstant + | LoadFunction | CustomOp | Noop | MakeTuple diff --git a/hugr-py/src/hugr/serialization/tys.py b/hugr-py/src/hugr/serialization/tys.py index d3c562608..f7f196fb7 100644 --- a/hugr-py/src/hugr/serialization/tys.py +++ b/hugr-py/src/hugr/serialization/tys.py @@ -194,7 +194,7 @@ class FunctionType(BaseModel): input: "TypeRow" # Value inputs of the function. output: "TypeRow" # Value outputs of the function. # The extension requirements which are added by the operation - extension_reqs: "ExtensionSet" = Field(default_factory=list) + extension_reqs: ExtensionSet = Field(default_factory=ExtensionSet) @classmethod def empty(cls) -> "FunctionType": diff --git a/hugr/src/builder/build_traits.rs b/hugr/src/builder/build_traits.rs index 62a5a3b33..5ea1964e0 100644 --- a/hugr/src/builder/build_traits.rs +++ b/hugr/src/builder/build_traits.rs @@ -375,6 +375,40 @@ pub trait Dataflow: Container { self.load_const(&cid) } + /// Load a static function and return the local dataflow wire for that function. + /// Adds a [`OpType::LoadFunction`] node. + /// + /// The `DEF` const generic is used to indicate whether the function is defined + /// or just declared. + fn load_func( + &mut self, + fid: &FuncID, + type_args: &[TypeArg], + // Sadly required as we substituting in type_args may result in recomputing bounds of types: + exts: &ExtensionRegistry, + ) -> Result { + let func_node = fid.node(); + let func_op = self.hugr().get_nodetype(func_node).op(); + let func_sig = match func_op { + OpType::FuncDefn(ops::FuncDefn { signature, .. }) + | OpType::FuncDecl(ops::FuncDecl { signature, .. }) => signature.clone(), + _ => { + return Err(BuildError::UnexpectedType { + node: func_node, + op_desc: "FuncDecl/FuncDefn", + }) + } + }; + + let load_n = self.add_dataflow_op( + ops::LoadFunction::try_new(func_sig, type_args, exts)?, + // Static wire from the function node + vec![Wire::new(func_node, OutgoingPort::from(0))], + )?; + + Ok(load_n.out_wire(0)) + } + /// Return a builder for a [`crate::ops::TailLoop`] node. /// The `inputs` must be an iterable over pairs of the type of the input and /// the corresponding wire. diff --git a/hugr/src/extension.rs b/hugr/src/extension.rs index c1eca4225..ef3564ee0 100644 --- a/hugr/src/extension.rs +++ b/hugr/src/extension.rs @@ -173,6 +173,17 @@ pub enum SignatureError { cached: FunctionType, expected: FunctionType, }, + /// The result of the type application stored in a [LoadFunction] + /// is not what we get by applying the type-args to the polymorphic function + /// + /// [LoadFunction]: crate::ops::dataflow::LoadFunction + #[error( + "Incorrect result of type application in LoadFunction - cached {cached} but expected {expected}" + )] + LoadFunctionIncorrectlyAppliesType { + cached: FunctionType, + expected: FunctionType, + }, } /// Concrete instantiations of types and operations defined in extensions. diff --git a/hugr/src/hugr/validate.rs b/hugr/src/hugr/validate.rs index 30a7dce09..4d317f1c3 100644 --- a/hugr/src/hugr/validate.rs +++ b/hugr/src/hugr/validate.rs @@ -561,6 +561,10 @@ impl<'a, 'b> ValidationContext<'a, 'b> { c.validate(self.extension_registry) .map_err(|cause| ValidationError::SignatureError { node, cause })?; } + OpType::LoadFunction(c) => { + c.validate(self.extension_registry) + .map_err(|cause| ValidationError::SignatureError { node, cause })?; + } _ => (), } diff --git a/hugr/src/hugr/validate/test.rs b/hugr/src/hugr/validate/test.rs index 04e6958c1..a85debedf 100644 --- a/hugr/src/hugr/validate/test.rs +++ b/hugr/src/hugr/validate/test.rs @@ -558,6 +558,30 @@ fn test_polymorphic_call() -> Result<(), Box> { Ok(()) } +#[test] +fn test_polymorphic_load() -> Result<(), Box> { + let mut m = ModuleBuilder::new(); + let id = m.declare( + "id", + PolyFuncType::new( + vec![TypeBound::Any.into()], + FunctionType::new_endo(vec![Type::new_var_use(0, TypeBound::Any)]), + ), + )?; + let sig = FunctionType::new( + vec![], + vec![Type::new_function(FunctionType::new_endo(vec![USIZE_T]))], + ); + let mut f = m.define_function("main", sig.into())?; + let l = f.load_func(&id, &[USIZE_T.into()], &PRELUDE_REGISTRY)?; + f.finish_with_outputs([l])?; + let _ = m.finish_prelude_hugr()?; + Ok(()) + + // Function(PolyFuncType { params: [Type { b: Any }], body: FunctionType { input: TypeRow { types: [Type(Variable(0, Any), Any)] }, output: TypeRow { types: [Type(Variable(0, Any), Any)] }, extension_reqs: ExtensionSet({}) } }) + // Value(Type(Extension(CustomType { extension: IdentList("prelude"), id: "usize", args: [] +} + #[cfg(feature = "extension_inference")] mod extension_tests { use super::*; diff --git a/hugr/src/ops.rs b/hugr/src/ops.rs index 07a59d116..566388e85 100644 --- a/hugr/src/ops.rs +++ b/hugr/src/ops.rs @@ -23,7 +23,9 @@ use enum_dispatch::enum_dispatch; pub use constant::Const; pub use controlflow::{BasicBlock, Case, Conditional, DataflowBlock, ExitBlock, TailLoop, CFG}; pub use custom::CustomOp; -pub use dataflow::{Call, CallIndirect, DataflowParent, Input, LoadConstant, Output, DFG}; +pub use dataflow::{ + Call, CallIndirect, DataflowParent, Input, LoadConstant, LoadFunction, Output, DFG, +}; pub use leaf::{Lift, MakeTuple, Noop, Tag, UnpackTuple}; pub use module::{AliasDecl, AliasDefn, FuncDecl, FuncDefn, Module}; pub use tag::OpTag; @@ -47,6 +49,7 @@ pub enum OpType { Call, CallIndirect, LoadConstant, + LoadFunction, DFG, CustomOp, Noop, @@ -105,6 +108,7 @@ impl_op_ref_try_into!(Output); impl_op_ref_try_into!(Call); impl_op_ref_try_into!(CallIndirect); impl_op_ref_try_into!(LoadConstant); +impl_op_ref_try_into!(LoadFunction); impl_op_ref_try_into!(DFG, dfg); impl_op_ref_try_into!(CustomOp); impl_op_ref_try_into!(Noop); @@ -226,7 +230,8 @@ impl OpType { Some(Port::new(dir, self.value_port_count(dir))) } - /// If the op has a static input ([`Call`] and [`LoadConstant`]), the port of that input. + /// If the op has a static input ([`Call`], [`LoadConstant`], and [`LoadFunction`]), the port of + /// that input. #[inline] pub fn static_input_port(&self) -> Option { self.static_port(Direction::Incoming) @@ -413,6 +418,7 @@ impl OpParent for Output {} impl OpParent for Call {} impl OpParent for CallIndirect {} impl OpParent for LoadConstant {} +impl OpParent for LoadFunction {} impl OpParent for CustomOp {} impl OpParent for Noop {} impl OpParent for MakeTuple {} diff --git a/hugr/src/ops/dataflow.rs b/hugr/src/ops/dataflow.rs index d8a6ed4fc..98255cec0 100644 --- a/hugr/src/ops/dataflow.rs +++ b/hugr/src/ops/dataflow.rs @@ -316,6 +316,86 @@ impl LoadConstant { } } +/// Load a static function in to the local dataflow graph. +#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)] +pub struct LoadFunction { + /// Signature of the function + func_sig: PolyFuncType, + type_args: Vec, + signature: FunctionType, // Cache, so we can fail in try_new() not in signature() +} +impl_op_name!(LoadFunction); +impl DataflowOpTrait for LoadFunction { + const TAG: OpTag = OpTag::LoadFunc; + + fn description(&self) -> &str { + "Load a static function in to the local dataflow graph" + } + + fn signature(&self) -> FunctionType { + self.signature.clone() + } + + fn static_input(&self) -> Option { + Some(EdgeKind::Function(self.func_sig.clone())) + } +} +impl LoadFunction { + /// Try to make a new LoadFunction op. Returns an error if the `type_args`` do not fit + /// the [TypeParam]s declared by the function. + /// + /// [TypeParam]: crate::types::type_param::TypeParam + pub fn try_new( + func_sig: PolyFuncType, + type_args: impl Into>, + exts: &ExtensionRegistry, + ) -> Result { + let type_args = type_args.into(); + let instantiation = func_sig.instantiate(&type_args, exts)?; + let signature = FunctionType::new(TypeRow::new(), vec![Type::new_function(instantiation)]); + Ok(Self { + func_sig, + type_args, + signature, + }) + } + + #[inline] + /// Return the type of the function loaded by this op. + pub fn function_type(&self) -> &PolyFuncType { + &self.func_sig + } + + /// The IncomingPort which links to the loaded function. + /// + /// This matches [`OpType::static_input_port`]. + /// + /// [`OpType::static_input_port`]: crate::ops::OpType::static_input_port + #[inline] + pub fn function_port(&self) -> IncomingPort { + 0.into() + } + + pub(crate) fn validate( + &self, + extension_registry: &ExtensionRegistry, + ) -> Result<(), SignatureError> { + let other = Self::try_new( + self.func_sig.clone(), + self.type_args.clone(), + extension_registry, + )?; + if other.signature == self.signature { + Ok(()) + } else { + Err(SignatureError::LoadFunctionIncorrectlyAppliesType { + cached: self.signature.clone(), + expected: other.signature.clone(), + }) + } + } +} + /// Operations that is the parent of a dataflow graph. pub trait DataflowParent { /// Signature of the inner dataflow graph. diff --git a/hugr/src/ops/tag.rs b/hugr/src/ops/tag.rs index 435f279f1..9bc55843c 100644 --- a/hugr/src/ops/tag.rs +++ b/hugr/src/ops/tag.rs @@ -54,6 +54,8 @@ pub enum OpTag { FnCall, /// A constant load operation. LoadConst, + /// A function load operation. + LoadFunc, /// A definition that could be at module level or inside a DSG. ScopedDefn, /// A tail-recursive loop. @@ -129,6 +131,7 @@ impl OpTag { OpTag::StaticOutput => &[OpTag::Any], OpTag::FnCall => &[OpTag::StaticInput, OpTag::DataflowChild], OpTag::LoadConst => &[OpTag::StaticInput, OpTag::DataflowChild], + OpTag::LoadFunc => &[OpTag::StaticInput, OpTag::DataflowChild], OpTag::Leaf => &[OpTag::DataflowChild], OpTag::DataflowParent => &[OpTag::Any], } @@ -156,10 +159,11 @@ impl OpTag { OpTag::Cfg => "Nested control-flow operation", OpTag::TailLoop => "Tail-recursive loop", OpTag::Conditional => "Conditional operation", - OpTag::StaticInput => "Node with static input (LoadConst or FnCall)", + OpTag::StaticInput => "Node with static input (LoadConst, LoadFunc, or FnCall)", OpTag::StaticOutput => "Node with static output (FuncDefn, FuncDecl, Const)", OpTag::FnCall => "Function call", OpTag::LoadConst => "Constant load operation", + OpTag::LoadFunc => "Function load operation", OpTag::Leaf => "Leaf operation", OpTag::ScopedDefn => "Definitions that can live at global or local scope", OpTag::DataflowParent => "Operation whose children form a Dataflow Sibling Graph", diff --git a/hugr/src/ops/validate.rs b/hugr/src/ops/validate.rs index 3a828a8d0..e9edb2c86 100644 --- a/hugr/src/ops/validate.rs +++ b/hugr/src/ops/validate.rs @@ -405,7 +405,7 @@ mod test { use super::{ AliasDecl, AliasDefn, Call, CallIndirect, Const, CustomOp, FuncDecl, Input, Lift, LoadConstant, - MakeTuple, Noop, Output, Tag, UnpackTuple, + LoadFunction, MakeTuple, Noop, Output, Tag, UnpackTuple, }; impl_validate_op!(FuncDecl); impl_validate_op!(AliasDecl); @@ -415,6 +415,7 @@ impl_validate_op!(Output); impl_validate_op!(Const); impl_validate_op!(Call); impl_validate_op!(LoadConstant); +impl_validate_op!(LoadFunction); impl_validate_op!(CallIndirect); impl_validate_op!(CustomOp); impl_validate_op!(Noop);