Skip to content

Commit

Permalink
refactor!: one way to add_op to extension (#704)
Browse files Browse the repository at this point in the history
Methods to allow appending extra data to opdef after adding to
extension.

BREAKING_CHANGES: `add_op` behaves like the now-removed `add_op_simple`
  • Loading branch information
ss2165 authored Nov 23, 2023
1 parent c2df319 commit 4ab6be4
Show file tree
Hide file tree
Showing 9 changed files with 114 additions and 107 deletions.
65 changes: 36 additions & 29 deletions src/extension/op_def.rs
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,7 @@ impl OpDef {
/// Fallibly returns a Hugr that may replace an instance of this OpDef
/// given a set of available extensions that may be used in the Hugr.
pub fn try_lower(&self, args: &[TypeArg], available_extensions: &ExtensionSet) -> Option<Hugr> {
// TODO test this
self.lower_funcs
.iter()
.flat_map(|f| match f {
Expand Down Expand Up @@ -384,6 +385,20 @@ impl OpDef {
}
Ok(())
}

/// Add a lowering function to the [OpDef]
pub fn add_lower_func(&mut self, lower: LowerFunc) {
self.lower_funcs.push(lower);
}

/// Insert miscellaneous data `v` to the [OpDef], keyed by `k`.
pub fn add_misc(
&mut self,
k: impl ToString,
v: serde_yaml::Value,
) -> Option<serde_yaml::Value> {
self.misc.insert(k.to_string(), v)
}
}

impl Extension {
Expand All @@ -395,41 +410,23 @@ impl Extension {
&mut self,
name: SmolStr,
description: String,
misc: HashMap<String, serde_yaml::Value>,
lower_funcs: Vec<LowerFunc>,
signature_func: impl Into<SignatureFunc>,
) -> Result<&OpDef, ExtensionBuildError> {
) -> Result<&mut OpDef, ExtensionBuildError> {
let op = OpDef {
extension: self.name.clone(),
name,
description,
misc,
signature_func: signature_func.into(),
lower_funcs,
misc: Default::default(),
lower_funcs: Default::default(),
};

match self.operations.entry(op.name.clone()) {
Entry::Occupied(_) => Err(ExtensionBuildError::OpDefExists(op.name)),
Entry::Vacant(ve) => Ok(ve.insert(Arc::new(op))),
// Just made the arc so should only be one reference to it, can get_mut,
Entry::Vacant(ve) => Ok(Arc::get_mut(ve.insert(Arc::new(op))).unwrap()),
}
}

/// Create an OpDef with `PolyFuncType`, `impl CustomSignatureFunc` or `CustomValidator`
/// ; and no "misc" or "lowering functions" defined.
pub fn add_op_simple(
&mut self,
name: SmolStr,
description: String,
signature_func: impl Into<SignatureFunc>,
) -> Result<&OpDef, ExtensionBuildError> {
self.add_op(
name,
description,
HashMap::default(),
Vec::new(),
signature_func,
)
}
}

#[cfg(test)]
Expand All @@ -440,16 +437,18 @@ mod test {

use super::SignatureFromArgs;
use crate::builder::{DFGBuilder, Dataflow, DataflowHugr};
use crate::extension::op_def::LowerFunc;
use crate::extension::prelude::USIZE_T;
use crate::extension::{
ExtensionRegistry, SignatureError, EMPTY_REG, PRELUDE, PRELUDE_REGISTRY,
};
use crate::extension::{ExtensionRegistry, ExtensionSet, PRELUDE};
use crate::extension::{SignatureError, EMPTY_REG, PRELUDE_REGISTRY};
use crate::ops::custom::ExternalOp;
use crate::ops::LeafOp;
use crate::std_extensions::collections::{EXTENSION, LIST_TYPENAME};
use crate::types::Type;
use crate::types::{type_param::TypeParam, FunctionType, PolyFuncType, TypeArg, TypeBound};
use crate::Hugr;
use crate::{const_extension_ids, Extension};

const_extension_ids! {
const EXT_ID: ExtensionId = "MyExt";
}
Expand All @@ -463,7 +462,14 @@ mod test {
Type::new_extension(list_def.instantiate(vec![TypeArg::new_var_use(0, TP)])?);
const OP_NAME: SmolStr = SmolStr::new_inline("Reverse");
let type_scheme = PolyFuncType::new(vec![TP], FunctionType::new_endo(vec![list_of_var]));
e.add_op(OP_NAME, "".into(), Default::default(), vec![], type_scheme)?;

let def = e.add_op(OP_NAME, "desc".into(), type_scheme)?;
def.add_lower_func(LowerFunc::FixedHugr(ExtensionSet::new(), Hugr::default()));
def.add_misc("key", Default::default());
assert_eq!(def.description(), "desc");
assert_eq!(def.lower_funcs.len(), 1);
assert_eq!(def.misc.len(), 1);

let reg =
ExtensionRegistry::try_new([PRELUDE.to_owned(), EXTENSION.to_owned(), e]).unwrap();
let e = reg.get(&EXT_ID).unwrap();
Expand Down Expand Up @@ -515,7 +521,8 @@ mod test {
}
}
let mut e = Extension::new(EXT_ID);
let def = e.add_op_simple("MyOp".into(), "".to_string(), SigFun())?;
let def: &mut crate::extension::OpDef =
e.add_op("MyOp".into(), "".to_string(), SigFun())?;

// Base case, no type variables:
let args = [TypeArg::BoundedNat { n: 3 }, USIZE_T.into()];
Expand Down Expand Up @@ -576,7 +583,7 @@ mod test {
// Check that we can instantiate a PolyFuncType-scheme with an (external)
// type variable
let mut e = Extension::new(EXT_ID);
let def = e.add_op_simple(
let def = e.add_op(
"SimpleOp".into(),
"".into(),
PolyFuncType::new(
Expand Down
2 changes: 1 addition & 1 deletion src/extension/prelude.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ lazy_static! {
)
.unwrap();
prelude
.add_op_simple(
.add_op(
SmolStr::new_inline(NEW_ARRAY_OP_ID),
"Create a new array from elements".to_string(),
ArrayOpCustom,
Expand Down
6 changes: 3 additions & 3 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,15 +62,15 @@
//! let mut extension = Extension::new(EXTENSION_ID);
//!
//! extension
//! .add_op_simple(SmolStr::new_inline("H"), "Hadamard".into(), one_qb_func())
//! .add_op(SmolStr::new_inline("H"), "Hadamard".into(), one_qb_func())
//! .unwrap();
//!
//! extension
//! .add_op_simple(SmolStr::new_inline("CX"), "CX".into(), two_qb_func())
//! .add_op(SmolStr::new_inline("CX"), "CX".into(), two_qb_func())
//! .unwrap();
//!
//! extension
//! .add_op_simple(
//! .add_op(
//! SmolStr::new_inline("Measure"),
//! "Measure a qubit, returning the qubit and the measurement result.".into(),
//! FunctionType::new(type_row![QB_T], type_row![QB_T, BOOL_T]),
Expand Down
8 changes: 4 additions & 4 deletions src/std_extensions/arithmetic/conversions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,24 +36,24 @@ pub fn extension() -> Extension {
]),
);
extension
.add_op_simple(
.add_op(
"trunc_u".into(),
"float to unsigned int".to_owned(),
ftoi_sig.clone(),
)
.unwrap();
extension
.add_op_simple("trunc_s".into(), "float to signed int".to_owned(), ftoi_sig)
.add_op("trunc_s".into(), "float to signed int".to_owned(), ftoi_sig)
.unwrap();
extension
.add_op_simple(
.add_op(
"convert_u".into(),
"unsigned int to float".to_owned(),
itof_sig.clone(),
)
.unwrap();
extension
.add_op_simple(
.add_op(
"convert_s".into(),
"signed int to float".to_owned(),
itof_sig,
Expand Down
32 changes: 16 additions & 16 deletions src/std_extensions/arithmetic/float_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,72 +29,72 @@ pub fn extension() -> Extension {
let funop_sig: PolyFuncType =
FunctionType::new(type_row![FLOAT64_TYPE], type_row![FLOAT64_TYPE]).into();
extension
.add_op_simple("feq".into(), "equality test".to_owned(), fcmp_sig.clone())
.add_op("feq".into(), "equality test".to_owned(), fcmp_sig.clone())
.unwrap();
extension
.add_op_simple("fne".into(), "inequality test".to_owned(), fcmp_sig.clone())
.add_op("fne".into(), "inequality test".to_owned(), fcmp_sig.clone())
.unwrap();
extension
.add_op_simple("flt".into(), "\"less than\"".to_owned(), fcmp_sig.clone())
.add_op("flt".into(), "\"less than\"".to_owned(), fcmp_sig.clone())
.unwrap();
extension
.add_op_simple(
.add_op(
"fgt".into(),
"\"greater than\"".to_owned(),
fcmp_sig.clone(),
)
.unwrap();
extension
.add_op_simple(
.add_op(
"fle".into(),
"\"less than or equal\"".to_owned(),
fcmp_sig.clone(),
)
.unwrap();
extension
.add_op_simple(
.add_op(
"fge".into(),
"\"greater than or equal\"".to_owned(),
fcmp_sig,
)
.unwrap();
extension
.add_op_simple("fmax".into(), "maximum".to_owned(), fbinop_sig.clone())
.add_op("fmax".into(), "maximum".to_owned(), fbinop_sig.clone())
.unwrap();
extension
.add_op_simple("fmin".into(), "minimum".to_owned(), fbinop_sig.clone())
.add_op("fmin".into(), "minimum".to_owned(), fbinop_sig.clone())
.unwrap();
extension
.add_op_simple("fadd".into(), "addition".to_owned(), fbinop_sig.clone())
.add_op("fadd".into(), "addition".to_owned(), fbinop_sig.clone())
.unwrap();
extension
.add_op_simple("fsub".into(), "subtraction".to_owned(), fbinop_sig.clone())
.add_op("fsub".into(), "subtraction".to_owned(), fbinop_sig.clone())
.unwrap();
extension
.add_op_simple("fneg".into(), "negation".to_owned(), funop_sig.clone())
.add_op("fneg".into(), "negation".to_owned(), funop_sig.clone())
.unwrap();
extension
.add_op_simple(
.add_op(
"fabs".into(),
"absolute value".to_owned(),
funop_sig.clone(),
)
.unwrap();
extension
.add_op_simple(
.add_op(
"fmul".into(),
"multiplication".to_owned(),
fbinop_sig.clone(),
)
.unwrap();
extension
.add_op_simple("fdiv".into(), "division".to_owned(), fbinop_sig)
.add_op("fdiv".into(), "division".to_owned(), fbinop_sig)
.unwrap();
extension
.add_op_simple("ffloor".into(), "floor".to_owned(), funop_sig.clone())
.add_op("ffloor".into(), "floor".to_owned(), funop_sig.clone())
.unwrap();
extension
.add_op_simple("fceil".into(), "ceiling".to_owned(), funop_sig)
.add_op("fceil".into(), "ceiling".to_owned(), funop_sig)
.unwrap();

extension
Expand Down
Loading

0 comments on commit 4ab6be4

Please sign in to comment.