Skip to content

Commit

Permalink
feat: Add LoadFunction node
Browse files Browse the repository at this point in the history
  • Loading branch information
mark-koch committed Apr 17, 2024
1 parent d016665 commit 7a15fda
Show file tree
Hide file tree
Showing 10 changed files with 182 additions and 5 deletions.
13 changes: 13 additions & 0 deletions hugr-py/src/hugr/serialization/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,18 @@ class LoadConstant(DataflowOp):
datatype: Type


class LoadFunction(DataflowOp):
"""Load a static function in to the local dataflow graph."""

op: Literal["LoadFunction"] = "LoadFunction"
func_sig: PolyFuncType
type_args: list[tys.TypeArg]
signature: FunctionType = Field(default_factory=FunctionType.empty)

def insert_port_types(self, in_types: TypeRow, out_types: TypeRow) -> None:
self.signature = FunctionType(input=list(in_types), output=list(out_types))


class DFG(DataflowOp):
"""A simply nested dataflow graph."""

Expand Down Expand Up @@ -486,6 +498,7 @@ class OpType(RootModel):
| Call
| CallIndirect
| LoadConstant
| LoadFunction
| CustomOp
| Noop
| MakeTuple
Expand Down
2 changes: 1 addition & 1 deletion hugr-py/src/hugr/serialization/tys.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ class FunctionType(BaseModel):
input: "TypeRow" # Value inputs of the function.
output: "TypeRow" # Value outputs of the function.
# The extension requirements which are added by the operation
extension_reqs: "ExtensionSet" = Field(default_factory=list)
extension_reqs: ExtensionSet = Field(default_factory=ExtensionSet)

@classmethod
def empty(cls) -> "FunctionType":
Expand Down
34 changes: 34 additions & 0 deletions hugr/src/builder/build_traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,40 @@ pub trait Dataflow: Container {
self.load_const(&cid)
}

/// Load a static function and return the local dataflow wire for that function.
/// Adds a [`OpType::LoadFunction`] node.
///
/// The `DEF` const generic is used to indicate whether the function is defined
/// or just declared.
fn load_func<const DEFINED: bool>(
&mut self,
fid: &FuncID<DEFINED>,
type_args: &[TypeArg],
// Sadly required as we substituting in type_args may result in recomputing bounds of types:
exts: &ExtensionRegistry,
) -> Result<Wire, BuildError> {
let func_node = fid.node();
let func_op = self.hugr().get_nodetype(func_node).op();
let func_sig = match func_op {
OpType::FuncDefn(ops::FuncDefn { signature, .. })
| OpType::FuncDecl(ops::FuncDecl { signature, .. }) => signature.clone(),
_ => {
return Err(BuildError::UnexpectedType {
node: func_node,
op_desc: "FuncDecl/FuncDefn",
})
}
};

let load_n = self.add_dataflow_op(
ops::LoadFunction::try_new(func_sig, type_args, exts)?,
// Static wire from the function node
vec![Wire::new(func_node, OutgoingPort::from(0))],
)?;

Ok(load_n.out_wire(0))
}

/// Return a builder for a [`crate::ops::TailLoop`] node.
/// The `inputs` must be an iterable over pairs of the type of the input and
/// the corresponding wire.
Expand Down
11 changes: 11 additions & 0 deletions hugr/src/extension.rs
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,17 @@ pub enum SignatureError {
cached: FunctionType,
expected: FunctionType,
},
/// The result of the type application stored in a [LoadFunction]
/// is not what we get by applying the type-args to the polymorphic function
///
/// [LoadFunction]: crate::ops::dataflow::LoadFunction
#[error(
"Incorrect result of type application in LoadFunction - cached {cached} but expected {expected}"
)]
LoadFunctionIncorrectlyAppliesType {
cached: FunctionType,
expected: FunctionType,
},
}

/// Concrete instantiations of types and operations defined in extensions.
Expand Down
4 changes: 4 additions & 0 deletions hugr/src/hugr/validate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -561,6 +561,10 @@ impl<'a, 'b> ValidationContext<'a, 'b> {
c.validate(self.extension_registry)
.map_err(|cause| ValidationError::SignatureError { node, cause })?;
}
OpType::LoadFunction(c) => {
c.validate(self.extension_registry)
.map_err(|cause| ValidationError::SignatureError { node, cause })?;
}
_ => (),
}

Expand Down
24 changes: 24 additions & 0 deletions hugr/src/hugr/validate/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -558,6 +558,30 @@ fn test_polymorphic_call() -> Result<(), Box<dyn std::error::Error>> {
Ok(())
}

#[test]
fn test_polymorphic_load() -> Result<(), Box<dyn std::error::Error>> {
let mut m = ModuleBuilder::new();
let id = m.declare(
"id",
PolyFuncType::new(
vec![TypeBound::Any.into()],
FunctionType::new_endo(vec![Type::new_var_use(0, TypeBound::Any)]),
),
)?;
let sig = FunctionType::new(
vec![],
vec![Type::new_function(FunctionType::new_endo(vec![USIZE_T]))],
);
let mut f = m.define_function("main", sig.into())?;
let l = f.load_func(&id, &[USIZE_T.into()], &PRELUDE_REGISTRY)?;
f.finish_with_outputs([l])?;
let _ = m.finish_prelude_hugr()?;
Ok(())

// Function(PolyFuncType { params: [Type { b: Any }], body: FunctionType { input: TypeRow { types: [Type(Variable(0, Any), Any)] }, output: TypeRow { types: [Type(Variable(0, Any), Any)] }, extension_reqs: ExtensionSet({}) } })
// Value(Type(Extension(CustomType { extension: IdentList("prelude"), id: "usize", args: []
}

#[cfg(feature = "extension_inference")]
mod extension_tests {
use super::*;
Expand Down
10 changes: 8 additions & 2 deletions hugr/src/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@ use enum_dispatch::enum_dispatch;
pub use constant::Const;
pub use controlflow::{BasicBlock, Case, Conditional, DataflowBlock, ExitBlock, TailLoop, CFG};
pub use custom::CustomOp;
pub use dataflow::{Call, CallIndirect, DataflowParent, Input, LoadConstant, Output, DFG};
pub use dataflow::{
Call, CallIndirect, DataflowParent, Input, LoadConstant, LoadFunction, Output, DFG,
};
pub use leaf::{Lift, MakeTuple, Noop, Tag, UnpackTuple};
pub use module::{AliasDecl, AliasDefn, FuncDecl, FuncDefn, Module};
pub use tag::OpTag;
Expand All @@ -47,6 +49,7 @@ pub enum OpType {
Call,
CallIndirect,
LoadConstant,
LoadFunction,
DFG,
CustomOp,
Noop,
Expand Down Expand Up @@ -105,6 +108,7 @@ impl_op_ref_try_into!(Output);
impl_op_ref_try_into!(Call);
impl_op_ref_try_into!(CallIndirect);
impl_op_ref_try_into!(LoadConstant);
impl_op_ref_try_into!(LoadFunction);
impl_op_ref_try_into!(DFG, dfg);
impl_op_ref_try_into!(CustomOp);
impl_op_ref_try_into!(Noop);
Expand Down Expand Up @@ -226,7 +230,8 @@ impl OpType {
Some(Port::new(dir, self.value_port_count(dir)))
}

/// If the op has a static input ([`Call`] and [`LoadConstant`]), the port of that input.
/// If the op has a static input ([`Call`], [`LoadConstant`], and [`LoadFunction`]), the port of
/// that input.
#[inline]
pub fn static_input_port(&self) -> Option<IncomingPort> {
self.static_port(Direction::Incoming)
Expand Down Expand Up @@ -413,6 +418,7 @@ impl OpParent for Output {}
impl OpParent for Call {}
impl OpParent for CallIndirect {}
impl OpParent for LoadConstant {}
impl OpParent for LoadFunction {}
impl OpParent for CustomOp {}
impl OpParent for Noop {}
impl OpParent for MakeTuple {}
Expand Down
80 changes: 80 additions & 0 deletions hugr/src/ops/dataflow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,86 @@ impl LoadConstant {
}
}

/// Load a static function in to the local dataflow graph.
#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub struct LoadFunction {
/// Signature of the function
func_sig: PolyFuncType,
type_args: Vec<TypeArg>,
signature: FunctionType, // Cache, so we can fail in try_new() not in signature()
}
impl_op_name!(LoadFunction);
impl DataflowOpTrait for LoadFunction {
const TAG: OpTag = OpTag::LoadFunc;

fn description(&self) -> &str {
"Load a static function in to the local dataflow graph"
}

fn signature(&self) -> FunctionType {
self.signature.clone()
}

fn static_input(&self) -> Option<EdgeKind> {
Some(EdgeKind::Function(self.func_sig.clone()))
}
}
impl LoadFunction {
/// Try to make a new LoadFunction op. Returns an error if the `type_args`` do not fit
/// the [TypeParam]s declared by the function.
///
/// [TypeParam]: crate::types::type_param::TypeParam
pub fn try_new(
func_sig: PolyFuncType,
type_args: impl Into<Vec<TypeArg>>,
exts: &ExtensionRegistry,
) -> Result<Self, SignatureError> {
let type_args = type_args.into();
let instantiation = func_sig.instantiate(&type_args, exts)?;
let signature = FunctionType::new(TypeRow::new(), vec![Type::new_function(instantiation)]);
Ok(Self {
func_sig,
type_args,
signature,
})
}

#[inline]
/// Return the type of the function loaded by this op.
pub fn function_type(&self) -> &PolyFuncType {
&self.func_sig
}

/// The IncomingPort which links to the loaded function.
///
/// This matches [`OpType::static_input_port`].
///
/// [`OpType::static_input_port`]: crate::ops::OpType::static_input_port
#[inline]
pub fn function_port(&self) -> IncomingPort {
0.into()
}

pub(crate) fn validate(
&self,
extension_registry: &ExtensionRegistry,
) -> Result<(), SignatureError> {
let other = Self::try_new(
self.func_sig.clone(),
self.type_args.clone(),
extension_registry,
)?;
if other.signature == self.signature {
Ok(())
} else {
Err(SignatureError::LoadFunctionIncorrectlyAppliesType {
cached: self.signature.clone(),
expected: other.signature.clone(),
})
}
}
}

/// Operations that is the parent of a dataflow graph.
pub trait DataflowParent {
/// Signature of the inner dataflow graph.
Expand Down
6 changes: 5 additions & 1 deletion hugr/src/ops/tag.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ pub enum OpTag {
FnCall,
/// A constant load operation.
LoadConst,
/// A function load operation.
LoadFunc,
/// A definition that could be at module level or inside a DSG.
ScopedDefn,
/// A tail-recursive loop.
Expand Down Expand Up @@ -129,6 +131,7 @@ impl OpTag {
OpTag::StaticOutput => &[OpTag::Any],
OpTag::FnCall => &[OpTag::StaticInput, OpTag::DataflowChild],
OpTag::LoadConst => &[OpTag::StaticInput, OpTag::DataflowChild],
OpTag::LoadFunc => &[OpTag::StaticInput, OpTag::DataflowChild],
OpTag::Leaf => &[OpTag::DataflowChild],
OpTag::DataflowParent => &[OpTag::Any],
}
Expand Down Expand Up @@ -156,10 +159,11 @@ impl OpTag {
OpTag::Cfg => "Nested control-flow operation",
OpTag::TailLoop => "Tail-recursive loop",
OpTag::Conditional => "Conditional operation",
OpTag::StaticInput => "Node with static input (LoadConst or FnCall)",
OpTag::StaticInput => "Node with static input (LoadConst, LoadFunc, or FnCall)",
OpTag::StaticOutput => "Node with static output (FuncDefn, FuncDecl, Const)",
OpTag::FnCall => "Function call",
OpTag::LoadConst => "Constant load operation",
OpTag::LoadFunc => "Function load operation",
OpTag::Leaf => "Leaf operation",
OpTag::ScopedDefn => "Definitions that can live at global or local scope",
OpTag::DataflowParent => "Operation whose children form a Dataflow Sibling Graph",
Expand Down
3 changes: 2 additions & 1 deletion hugr/src/ops/validate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -405,7 +405,7 @@ mod test {

use super::{
AliasDecl, AliasDefn, Call, CallIndirect, Const, CustomOp, FuncDecl, Input, Lift, LoadConstant,
MakeTuple, Noop, Output, Tag, UnpackTuple,
LoadFunction, MakeTuple, Noop, Output, Tag, UnpackTuple,
};
impl_validate_op!(FuncDecl);
impl_validate_op!(AliasDecl);
Expand All @@ -415,6 +415,7 @@ impl_validate_op!(Output);
impl_validate_op!(Const);
impl_validate_op!(Call);
impl_validate_op!(LoadConstant);
impl_validate_op!(LoadFunction);
impl_validate_op!(CallIndirect);
impl_validate_op!(CustomOp);
impl_validate_op!(Noop);
Expand Down

0 comments on commit 7a15fda

Please sign in to comment.