Skip to content

Commit

Permalink
feat: Add LoadNat operation to enable loading generic BoundedNats…
Browse files Browse the repository at this point in the history
… into runtime values (#1763)

Closes #1629
  • Loading branch information
tatiana-s authored Dec 12, 2024
1 parent 187ea8f commit 6f035d6
Show file tree
Hide file tree
Showing 5 changed files with 289 additions and 1 deletion.
4 changes: 4 additions & 0 deletions hugr-core/src/extension/prelude.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ mod unwrap_builder;

pub use unwrap_builder::UnwrapBuilder;

/// Operation to load generic bounded nat parameter.
pub mod generic;

/// Name of prelude extension.
pub const PRELUDE_ID: ExtensionId = ExtensionId::new_unchecked("prelude");
/// Extension version.
Expand Down Expand Up @@ -109,6 +112,7 @@ lazy_static! {
TupleOpDef::load_all_ops(prelude, extension_ref).unwrap();
NoopDef.add_to_extension(prelude, extension_ref).unwrap();
LiftDef.add_to_extension(prelude, extension_ref).unwrap();
generic::LoadNatDef.add_to_extension(prelude, extension_ref).unwrap();
})
};

Expand Down
214 changes: 214 additions & 0 deletions hugr-core/src/extension/prelude/generic.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,214 @@
use std::str::FromStr;
use std::sync::{Arc, Weak};

use crate::extension::prelude::usize_custom_t;
use crate::extension::simple_op::{
HasConcrete, HasDef, MakeExtensionOp, MakeOpDef, MakeRegisteredOp, OpLoadError,
};
use crate::extension::OpDef;
use crate::extension::SignatureFunc;
use crate::extension::{ConstFold, ExtensionId};
use crate::ops::ExtensionOp;
use crate::ops::NamedOp;
use crate::ops::OpName;
use crate::type_row;
use crate::types::FuncValueType;

use crate::types::Type;

use crate::extension::SignatureError;

use crate::types::PolyFuncTypeRV;

use crate::types::type_param::TypeArg;
use crate::Extension;

use super::{ConstUsize, PRELUDE_ID};
use super::{PRELUDE, PRELUDE_REGISTRY};
use crate::types::type_param::TypeParam;

/// Name of the operation for loading generic BoundedNat parameters.
pub const LOAD_NAT_OP_ID: OpName = OpName::new_inline("load_nat");

/// Definition of the load nat operation.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
pub struct LoadNatDef;

impl NamedOp for LoadNatDef {
fn name(&self) -> OpName {
LOAD_NAT_OP_ID
}
}

impl FromStr for LoadNatDef {
type Err = ();

fn from_str(s: &str) -> Result<Self, Self::Err> {
if s == LoadNatDef.name() {
Ok(Self)
} else {
Err(())
}
}
}

impl ConstFold for LoadNatDef {
fn fold(
&self,
type_args: &[TypeArg],
_consts: &[(crate::IncomingPort, crate::ops::Value)],
) -> crate::extension::ConstFoldResult {
let [arg] = type_args else {
return None;
};
let nat = arg.as_nat();
if let Some(n) = nat {
let n_const = ConstUsize::new(n);
Some(vec![(0.into(), n_const.into())])
} else {
None
}
}
}

impl MakeOpDef for LoadNatDef {
fn from_def(op_def: &OpDef) -> Result<Self, OpLoadError>
where
Self: Sized,
{
crate::extension::simple_op::try_from_name(op_def.name(), op_def.extension_id())
}

fn init_signature(&self, _extension_ref: &Weak<Extension>) -> SignatureFunc {
let usize_t: Type = usize_custom_t(_extension_ref).into();
let params = vec![TypeParam::max_nat()];
PolyFuncTypeRV::new(params, FuncValueType::new(type_row![], vec![usize_t])).into()
}

fn extension_ref(&self) -> Weak<Extension> {
Arc::downgrade(&PRELUDE)
}

fn extension(&self) -> ExtensionId {
PRELUDE_ID
}

fn description(&self) -> String {
"Loads a generic bounded nat parameter into a usize runtime value.".into()
}

fn post_opdef(&self, def: &mut OpDef) {
def.set_constant_folder(*self);
}
}

/// Concrete load nat operation.
#[derive(Clone, Debug, PartialEq)]
pub struct LoadNat {
nat: TypeArg,
}

impl LoadNat {
fn new(nat: TypeArg) -> Self {
LoadNat { nat }
}
}

impl NamedOp for LoadNat {
fn name(&self) -> OpName {
LOAD_NAT_OP_ID
}
}

impl MakeExtensionOp for LoadNat {
fn from_extension_op(ext_op: &ExtensionOp) -> Result<Self, OpLoadError>
where
Self: Sized,
{
let def = LoadNatDef::from_def(ext_op.def())?;
def.instantiate(ext_op.args())
}

fn type_args(&self) -> Vec<TypeArg> {
vec![self.nat.clone()]
}
}

impl MakeRegisteredOp for LoadNat {
fn extension_id(&self) -> ExtensionId {
PRELUDE_ID
}

fn registry<'s, 'r: 's>(&'s self) -> &'r crate::extension::ExtensionRegistry {
&PRELUDE_REGISTRY
}
}

impl HasDef for LoadNat {
type Def = LoadNatDef;
}

impl HasConcrete for LoadNatDef {
type Concrete = LoadNat;

fn instantiate(&self, type_args: &[TypeArg]) -> Result<Self::Concrete, OpLoadError> {
match type_args {
[n] => Ok(LoadNat::new(n.clone())),
_ => Err(SignatureError::InvalidTypeArgs.into()),
}
}
}

#[cfg(test)]
mod tests {
use crate::{
builder::{inout_sig, DFGBuilder, Dataflow, DataflowHugr},
extension::prelude::{usize_t, ConstUsize},
ops::{constant, OpType},
type_row,
types::TypeArg,
HugrView, OutgoingPort,
};

use super::LoadNat;

#[test]
fn test_load_nat() {
let mut b = DFGBuilder::new(inout_sig(type_row![], vec![usize_t()])).unwrap();

let arg = TypeArg::BoundedNat { n: 4 };
let op = LoadNat::new(arg);

let out = b.add_dataflow_op(op.clone(), []).unwrap();

let result = b.finish_hugr_with_outputs(out.outputs()).unwrap();

let exp_optype: OpType = op.into();

for child in result.children(result.root()) {
let node_optype = result.get_optype(child);
// The only node in the HUGR besides Input and Output should be LoadNat.
if !node_optype.is_input() && !node_optype.is_output() {
assert_eq!(node_optype, &exp_optype)
}
}
}

#[test]
fn test_load_nat_fold() {
let arg = TypeArg::BoundedNat { n: 5 };
let op = LoadNat::new(arg);

let optype: OpType = op.into();

match optype {
OpType::ExtensionOp(ext_op) => {
let result = ext_op.constant_fold(&[]);
let exp_port: OutgoingPort = 0.into();
let exp_val: constant::Value = ConstUsize::new(5).into();
assert_eq!(result, Some(vec![(exp_port, exp_val)]))
}
_ => panic!(),
}
}
}
26 changes: 25 additions & 1 deletion hugr-core/src/types/type_param.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use std::num::NonZeroU64;
use thiserror::Error;

use super::row_var::MaybeRV;
use super::{check_typevar_decl, RowVariable, Substitution, Type, TypeBase, TypeBound};
use super::{check_typevar_decl, NoRV, RowVariable, Substitution, Type, TypeBase, TypeBound};
use crate::extension::ExtensionRegistry;
use crate::extension::ExtensionSet;
use crate::extension::SignatureError;
Expand Down Expand Up @@ -252,6 +252,30 @@ impl TypeArg {
}
}

/// Returns an integer if the TypeArg is an instance of BoundedNat.
pub fn as_nat(&self) -> Option<u64> {
match self {
TypeArg::BoundedNat { n } => Some(*n),
_ => None,
}
}

/// Returns a type if the TypeArg is an instance of Type.
pub fn as_type(&self) -> Option<TypeBase<NoRV>> {
match self {
TypeArg::Type { ty } => Some(ty.clone()),
_ => None,
}
}

/// Returns a string if the TypeArg is an instance of String.
pub fn as_string(&self) -> Option<String> {
match self {
TypeArg::String { arg } => Some(arg.clone()),
_ => None,
}
}

/// Much as [Type::validate], also checks that the type of any [TypeArg::Opaque]
/// is valid and closed.
pub(crate) fn validate(
Expand Down
23 changes: 23 additions & 0 deletions hugr-py/src/hugr/std/_json_defs/prelude.json
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,29 @@
},
"binary": false
},
"load_nat": {
"extension": "prelude",
"name": "load_nat",
"description": "Loads a generic bounded nat parameter into a usize runtime value.",
"signature": {
"params": [
{
"tp": "BoundedNat",
"bound": null
}
],
"body": {
"input": [],
"output": [
{
"t": "I"
}
],
"extension_reqs": []
}
},
"binary": false
},
"panic": {
"extension": "prelude",
"name": "panic",
Expand Down
23 changes: 23 additions & 0 deletions specification/std_extensions/prelude.json
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,29 @@
},
"binary": false
},
"load_nat": {
"extension": "prelude",
"name": "load_nat",
"description": "Loads a generic bounded nat parameter into a usize runtime value.",
"signature": {
"params": [
{
"tp": "BoundedNat",
"bound": null
}
],
"body": {
"input": [],
"output": [
{
"t": "I"
}
],
"extension_reqs": []
}
},
"binary": false
},
"panic": {
"extension": "prelude",
"name": "panic",
Expand Down

0 comments on commit 6f035d6

Please sign in to comment.