diff --git a/hugr-core/src/extension/prelude.rs b/hugr-core/src/extension/prelude.rs index 6590bfd5a..34a46409b 100644 --- a/hugr-core/src/extension/prelude.rs +++ b/hugr-core/src/extension/prelude.rs @@ -31,6 +31,9 @@ mod unwrap_builder; pub use unwrap_builder::UnwrapBuilder; +/// Operation to load generic bounded nat parameter. +pub mod generic; + /// Name of prelude extension. pub const PRELUDE_ID: ExtensionId = ExtensionId::new_unchecked("prelude"); /// Extension version. @@ -109,6 +112,7 @@ lazy_static! { TupleOpDef::load_all_ops(prelude, extension_ref).unwrap(); NoopDef.add_to_extension(prelude, extension_ref).unwrap(); LiftDef.add_to_extension(prelude, extension_ref).unwrap(); + generic::LoadNatDef.add_to_extension(prelude, extension_ref).unwrap(); }) }; diff --git a/hugr-core/src/extension/prelude/generic.rs b/hugr-core/src/extension/prelude/generic.rs new file mode 100644 index 000000000..b79bd40bf --- /dev/null +++ b/hugr-core/src/extension/prelude/generic.rs @@ -0,0 +1,214 @@ +use std::str::FromStr; +use std::sync::{Arc, Weak}; + +use crate::extension::prelude::usize_custom_t; +use crate::extension::simple_op::{ + HasConcrete, HasDef, MakeExtensionOp, MakeOpDef, MakeRegisteredOp, OpLoadError, +}; +use crate::extension::OpDef; +use crate::extension::SignatureFunc; +use crate::extension::{ConstFold, ExtensionId}; +use crate::ops::ExtensionOp; +use crate::ops::NamedOp; +use crate::ops::OpName; +use crate::type_row; +use crate::types::FuncValueType; + +use crate::types::Type; + +use crate::extension::SignatureError; + +use crate::types::PolyFuncTypeRV; + +use crate::types::type_param::TypeArg; +use crate::Extension; + +use super::{ConstUsize, PRELUDE_ID}; +use super::{PRELUDE, PRELUDE_REGISTRY}; +use crate::types::type_param::TypeParam; + +/// Name of the operation for loading generic BoundedNat parameters. +pub const LOAD_NAT_OP_ID: OpName = OpName::new_inline("load_nat"); + +/// Definition of the load nat operation. +#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)] +pub struct LoadNatDef; + +impl NamedOp for LoadNatDef { + fn name(&self) -> OpName { + LOAD_NAT_OP_ID + } +} + +impl FromStr for LoadNatDef { + type Err = (); + + fn from_str(s: &str) -> Result { + if s == LoadNatDef.name() { + Ok(Self) + } else { + Err(()) + } + } +} + +impl ConstFold for LoadNatDef { + fn fold( + &self, + type_args: &[TypeArg], + _consts: &[(crate::IncomingPort, crate::ops::Value)], + ) -> crate::extension::ConstFoldResult { + let [arg] = type_args else { + return None; + }; + let nat = arg.as_nat(); + if let Some(n) = nat { + let n_const = ConstUsize::new(n); + Some(vec![(0.into(), n_const.into())]) + } else { + None + } + } +} + +impl MakeOpDef for LoadNatDef { + fn from_def(op_def: &OpDef) -> Result + where + Self: Sized, + { + crate::extension::simple_op::try_from_name(op_def.name(), op_def.extension_id()) + } + + fn init_signature(&self, _extension_ref: &Weak) -> SignatureFunc { + let usize_t: Type = usize_custom_t(_extension_ref).into(); + let params = vec![TypeParam::max_nat()]; + PolyFuncTypeRV::new(params, FuncValueType::new(type_row![], vec![usize_t])).into() + } + + fn extension_ref(&self) -> Weak { + Arc::downgrade(&PRELUDE) + } + + fn extension(&self) -> ExtensionId { + PRELUDE_ID + } + + fn description(&self) -> String { + "Loads a generic bounded nat parameter into a usize runtime value.".into() + } + + fn post_opdef(&self, def: &mut OpDef) { + def.set_constant_folder(*self); + } +} + +/// Concrete load nat operation. +#[derive(Clone, Debug, PartialEq)] +pub struct LoadNat { + nat: TypeArg, +} + +impl LoadNat { + fn new(nat: TypeArg) -> Self { + LoadNat { nat } + } +} + +impl NamedOp for LoadNat { + fn name(&self) -> OpName { + LOAD_NAT_OP_ID + } +} + +impl MakeExtensionOp for LoadNat { + fn from_extension_op(ext_op: &ExtensionOp) -> Result + where + Self: Sized, + { + let def = LoadNatDef::from_def(ext_op.def())?; + def.instantiate(ext_op.args()) + } + + fn type_args(&self) -> Vec { + vec![self.nat.clone()] + } +} + +impl MakeRegisteredOp for LoadNat { + fn extension_id(&self) -> ExtensionId { + PRELUDE_ID + } + + fn registry<'s, 'r: 's>(&'s self) -> &'r crate::extension::ExtensionRegistry { + &PRELUDE_REGISTRY + } +} + +impl HasDef for LoadNat { + type Def = LoadNatDef; +} + +impl HasConcrete for LoadNatDef { + type Concrete = LoadNat; + + fn instantiate(&self, type_args: &[TypeArg]) -> Result { + match type_args { + [n] => Ok(LoadNat::new(n.clone())), + _ => Err(SignatureError::InvalidTypeArgs.into()), + } + } +} + +#[cfg(test)] +mod tests { + use crate::{ + builder::{inout_sig, DFGBuilder, Dataflow, DataflowHugr}, + extension::prelude::{usize_t, ConstUsize}, + ops::{constant, OpType}, + type_row, + types::TypeArg, + HugrView, OutgoingPort, + }; + + use super::LoadNat; + + #[test] + fn test_load_nat() { + let mut b = DFGBuilder::new(inout_sig(type_row![], vec![usize_t()])).unwrap(); + + let arg = TypeArg::BoundedNat { n: 4 }; + let op = LoadNat::new(arg); + + let out = b.add_dataflow_op(op.clone(), []).unwrap(); + + let result = b.finish_hugr_with_outputs(out.outputs()).unwrap(); + + let exp_optype: OpType = op.into(); + + for child in result.children(result.root()) { + let node_optype = result.get_optype(child); + // The only node in the HUGR besides Input and Output should be LoadNat. + if !node_optype.is_input() && !node_optype.is_output() { + assert_eq!(node_optype, &exp_optype) + } + } + } + + #[test] + fn test_load_nat_fold() { + let arg = TypeArg::BoundedNat { n: 5 }; + let op = LoadNat::new(arg); + + let optype: OpType = op.into(); + + match optype { + OpType::ExtensionOp(ext_op) => { + let result = ext_op.constant_fold(&[]); + let exp_port: OutgoingPort = 0.into(); + let exp_val: constant::Value = ConstUsize::new(5).into(); + assert_eq!(result, Some(vec![(exp_port, exp_val)])) + } + _ => panic!(), + } + } +} diff --git a/hugr-core/src/types/type_param.rs b/hugr-core/src/types/type_param.rs index 3c40552aa..a079812aa 100644 --- a/hugr-core/src/types/type_param.rs +++ b/hugr-core/src/types/type_param.rs @@ -10,7 +10,7 @@ use std::num::NonZeroU64; use thiserror::Error; use super::row_var::MaybeRV; -use super::{check_typevar_decl, RowVariable, Substitution, Type, TypeBase, TypeBound}; +use super::{check_typevar_decl, NoRV, RowVariable, Substitution, Type, TypeBase, TypeBound}; use crate::extension::ExtensionRegistry; use crate::extension::ExtensionSet; use crate::extension::SignatureError; @@ -252,6 +252,30 @@ impl TypeArg { } } + /// Returns an integer if the TypeArg is an instance of BoundedNat. + pub fn as_nat(&self) -> Option { + match self { + TypeArg::BoundedNat { n } => Some(*n), + _ => None, + } + } + + /// Returns a type if the TypeArg is an instance of Type. + pub fn as_type(&self) -> Option> { + match self { + TypeArg::Type { ty } => Some(ty.clone()), + _ => None, + } + } + + /// Returns a string if the TypeArg is an instance of String. + pub fn as_string(&self) -> Option { + match self { + TypeArg::String { arg } => Some(arg.clone()), + _ => None, + } + } + /// Much as [Type::validate], also checks that the type of any [TypeArg::Opaque] /// is valid and closed. pub(crate) fn validate( diff --git a/hugr-py/src/hugr/std/_json_defs/prelude.json b/hugr-py/src/hugr/std/_json_defs/prelude.json index 1a12a6383..5aff50947 100644 --- a/hugr-py/src/hugr/std/_json_defs/prelude.json +++ b/hugr-py/src/hugr/std/_json_defs/prelude.json @@ -200,6 +200,29 @@ }, "binary": false }, + "load_nat": { + "extension": "prelude", + "name": "load_nat", + "description": "Loads a generic bounded nat parameter into a usize runtime value.", + "signature": { + "params": [ + { + "tp": "BoundedNat", + "bound": null + } + ], + "body": { + "input": [], + "output": [ + { + "t": "I" + } + ], + "extension_reqs": [] + } + }, + "binary": false + }, "panic": { "extension": "prelude", "name": "panic", diff --git a/specification/std_extensions/prelude.json b/specification/std_extensions/prelude.json index 1a12a6383..5aff50947 100644 --- a/specification/std_extensions/prelude.json +++ b/specification/std_extensions/prelude.json @@ -200,6 +200,29 @@ }, "binary": false }, + "load_nat": { + "extension": "prelude", + "name": "load_nat", + "description": "Loads a generic bounded nat parameter into a usize runtime value.", + "signature": { + "params": [ + { + "tp": "BoundedNat", + "bound": null + } + ], + "body": { + "input": [], + "output": [ + { + "t": "I" + } + ], + "extension_reqs": [] + } + }, + "binary": false + }, "panic": { "extension": "prelude", "name": "panic",