Skip to content

Commit

Permalink
feat!: OpDefs and TypeDefs keep a reference to their extension
Browse files Browse the repository at this point in the history
  • Loading branch information
aborgna-q committed Nov 26, 2024
1 parent c5c8a6f commit 1d8d896
Show file tree
Hide file tree
Showing 28 changed files with 829 additions and 550 deletions.
16 changes: 13 additions & 3 deletions hugr-core/src/builder/circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ mod test {
use super::*;
use cool_asserts::assert_matches;

use crate::extension::{ExtensionId, ExtensionSet};
use crate::extension::{ExtensionId, ExtensionSet, PRELUDE_REGISTRY};
use crate::std_extensions::arithmetic::float_types::{self, ConstF64};
use crate::utils::test_quantum_extension::{
self, cx_gate, h_gate, measure, q_alloc, q_discard, rz_f64,
Expand Down Expand Up @@ -298,8 +298,18 @@ mod test {
#[test]
fn with_nonlinear_and_outputs() {
let my_ext_name: ExtensionId = "MyExt".try_into().unwrap();
let mut my_ext = Extension::new_test(my_ext_name.clone());
let my_custom_op = my_ext.simple_ext_op("MyOp", Signature::new(vec![QB, NAT], vec![QB]));
let my_ext = Extension::new_test_arc(my_ext_name.clone(), |ext, extension_ref| {
ext.add_op(
"MyOp".into(),
"".to_string(),
Signature::new(vec![QB, NAT], vec![QB]),
extension_ref,
)
.unwrap();
});
let my_custom_op = my_ext
.instantiate_extension_op("MyOp", [], &PRELUDE_REGISTRY)
.unwrap();

let build_res = build_main(
Signature::new(type_row![QB, QB, NAT], type_row![QB, QB, BOOL_T])
Expand Down
6 changes: 3 additions & 3 deletions hugr-core/src/export.rs
Original file line number Diff line number Diff line change
Expand Up @@ -443,10 +443,10 @@ impl<'a> Context<'a> {

let poly_func_type = match opdef.signature_func() {
SignatureFunc::PolyFuncType(poly_func_type) => poly_func_type,
_ => return self.make_named_global_ref(opdef.extension(), opdef.name()),
_ => return self.make_named_global_ref(opdef.extension_id(), opdef.name()),
};

let key = (opdef.extension().clone(), opdef.name().clone());
let key = (opdef.extension_id().clone(), opdef.name().clone());
let entry = self.decl_operations.entry(key);

let node = match entry {
Expand All @@ -467,7 +467,7 @@ impl<'a> Context<'a> {
};

let decl = self.with_local_scope(node, |this| {
let name = this.make_qualified_name(opdef.extension(), opdef.name());
let name = this.make_qualified_name(opdef.extension_id(), opdef.name());
let (params, constraints, r#type) = this.export_poly_func_type(poly_func_type);
let decl = this.bump.alloc(model::OperationDecl {
name,
Expand Down
133 changes: 115 additions & 18 deletions hugr-core/src/extension.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ pub use semver::Version;
use std::collections::btree_map;
use std::collections::{BTreeMap, BTreeSet};
use std::fmt::{Debug, Display, Formatter};
use std::sync::Arc;
use std::mem;
use std::sync::{Arc, Weak};

use thiserror::Error;

Expand Down Expand Up @@ -335,6 +336,45 @@ impl ExtensionValue {
pub type ExtensionId = IdentList;

/// A extension is a set of capabilities required to execute a graph.
///
/// These are normally defined once and shared across multiple graphs and
/// operations wrapped in [`Arc`]s inside [`ExtensionRegistry`].
///
/// # Example
///
/// The following example demonstrates how to define a new extension with a
/// custom operation and a custom type.
///
/// When using `arc`s, the extension can only be modified at creation time. The
/// defined operations and types keep a [`Weak`] reference to their extension. We provide a
/// helper method [`Extension::new_arc`] to aid their definition.
///
/// ```
/// # use hugr_core::types::Signature;
/// # use hugr_core::extension::{Extension, ExtensionId, Version};
/// # use hugr_core::extension::{TypeDefBound};
/// Extension::new_arc(
/// ExtensionId::new_unchecked("my.extension"),
/// Version::new(0, 1, 0),
/// |ext, extension_ref| {
/// // Add a custom type definition
/// ext.add_type(
/// "MyType".into(),
/// vec![], // No type parameters
/// "Some type".into(),
/// TypeDefBound::any(),
/// extension_ref,
/// );
/// // Add a custom operation
/// ext.add_op(
/// "MyOp".into(),
/// "Some operation".into(),
/// Signature::new_endo(vec![]),
/// extension_ref,
/// );
/// },
/// );
/// ```
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct Extension {
/// Extension version, follows semver.
Expand All @@ -361,6 +401,12 @@ pub struct Extension {

impl Extension {
/// Creates a new extension with the given name.
///
/// In most cases extensions are contained inside an [`Arc`] so that they
/// can be shared across hugr instances and operation definitions.
///
/// See [`Extension::new_arc`] for a more ergonomic way to create boxed
/// extensions.
pub fn new(name: ExtensionId, version: Version) -> Self {
Self {
name,
Expand All @@ -372,14 +418,63 @@ impl Extension {
}
}

/// Extend the requirements of this extension with another set of extensions.
pub fn with_reqs(self, extension_reqs: impl Into<ExtensionSet>) -> Self {
Self {
extension_reqs: self.extension_reqs.union(extension_reqs.into()),
..self
/// Creates a new extension wrapped in an [`Arc`].
///
/// The closure lets us use a weak reference to the arc while the extension
/// is being built. This is necessary for calling [`Extension::add_op`] and
/// [`Extension::add_type`].
pub fn new_arc(
name: ExtensionId,
version: Version,
init: impl FnOnce(&mut Extension, &Weak<Extension>),
) -> Arc<Self> {
Arc::new_cyclic(|extension_ref| {
let mut ext = Self::new(name, version);
init(&mut ext, extension_ref);
ext
})
}

/// Creates a new extension wrapped in an [`Arc`], using a fallible
/// initialization function.
///
/// The closure lets us use a weak reference to the arc while the extension
/// is being built. This is necessary for calling [`Extension::add_op`] and
/// [`Extension::add_type`].
pub fn try_new_arc<E>(
name: ExtensionId,
version: Version,
init: impl FnOnce(&mut Extension, &Weak<Extension>) -> Result<(), E>,
) -> Result<Arc<Self>, E> {
// Annoying hack around not having `Arc::try_new_cyclic` that can return
// a Result.
// https://github.com/rust-lang/rust/issues/75861#issuecomment-980455381
//
// When there is an error, we store it in `error` and return it at the
// end instead of the partially-initialized extension.
let mut error = None;
let ext = Arc::new_cyclic(|extension_ref| {
let mut ext = Self::new(name, version);
match init(&mut ext, extension_ref) {
Ok(_) => ext,
Err(e) => {
error = Some(e);
ext
}
}
});
match error {
Some(e) => Err(e),
None => Ok(ext),
}
}

/// Extend the requirements of this extension with another set of extensions.
pub fn set_reqs(&mut self, extension_reqs: impl Into<ExtensionSet>) {
let reqs = mem::take(&mut self.extension_reqs);
self.extension_reqs = reqs.union(extension_reqs.into());
}

/// Allows read-only access to the operations in this Extension
pub fn get_op(&self, name: &OpNameRef) -> Option<&Arc<op_def::OpDef>> {
self.operations.get(name)
Expand Down Expand Up @@ -634,20 +729,22 @@ pub mod test {

impl Extension {
/// Create a new extension for testing, with a 0 version.
pub(crate) fn new_test(name: ExtensionId) -> Self {
Self::new(name, Version::new(0, 0, 0))
pub(crate) fn new_test_arc(
name: ExtensionId,
init: impl FnOnce(&mut Extension, &Weak<Extension>),
) -> Arc<Self> {
Self::new_arc(name, Version::new(0, 0, 0), init)
}

/// Add a simple OpDef to the extension and return an extension op for it.
/// No description, no type parameters.
pub(crate) fn simple_ext_op(
&mut self,
name: &str,
signature: impl Into<SignatureFunc>,
) -> ExtensionOp {
self.add_op(name.into(), "".to_string(), signature).unwrap();
self.instantiate_extension_op(name, [], &PRELUDE_REGISTRY)
.unwrap()
/// Create a new extension for testing, with a 0 version.
pub(crate) fn try_new_test_arc(
name: ExtensionId,
init: impl FnOnce(
&mut Extension,
&Weak<Extension>,
) -> Result<(), Box<dyn std::error::Error>>,
) -> Result<Arc<Self>, Box<dyn std::error::Error>> {
Self::try_new_arc(name, Version::new(0, 0, 0), init)
}
}

Expand Down
32 changes: 19 additions & 13 deletions hugr-core/src/extension/declarative.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ mod types;

use std::fs::File;
use std::path::Path;
use std::sync::Arc;

use crate::extension::prelude::PRELUDE_ID;
use crate::ops::OpName;
Expand Down Expand Up @@ -150,19 +151,24 @@ impl ExtensionDeclaration {
&self,
imports: &ExtensionSet,
ctx: DeclarationContext<'_>,
) -> Result<Extension, ExtensionDeclarationError> {
let mut ext = Extension::new(self.name.clone(), crate::extension::Version::new(0, 0, 0))
.with_reqs(imports.clone());

for t in &self.types {
t.register(&mut ext, ctx)?;
}

for o in &self.operations {
o.register(&mut ext, ctx)?;
}

Ok(ext)
) -> Result<Arc<Extension>, ExtensionDeclarationError> {
Extension::try_new_arc(
self.name.clone(),
// TODO: Get the version as a parameter.
crate::extension::Version::new(0, 0, 0),
|ext, extension_ref| {
for t in &self.types {
t.register(ext, ctx, extension_ref)?;
}

for o in &self.operations {
o.register(ext, ctx, extension_ref)?;
}
ext.set_reqs(imports.clone());

Ok(())
},
)
}
}

Expand Down
12 changes: 11 additions & 1 deletion hugr-core/src/extension/declarative/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
//! [`ExtensionSetDeclaration`]: super::ExtensionSetDeclaration
use std::collections::HashMap;
use std::sync::Weak;

use serde::{Deserialize, Serialize};
use smol_str::SmolStr;
Expand Down Expand Up @@ -55,10 +56,14 @@ pub(super) struct OperationDeclaration {

impl OperationDeclaration {
/// Register this operation in the given extension.
///
/// Requires a [`Weak`] reference to the extension defining the operation.
/// This method is intended to be used inside the closure passed to [`Extension::new_arc`].
pub fn register<'ext>(
&self,
ext: &'ext mut Extension,
ctx: DeclarationContext<'_>,
extension_ref: &Weak<Extension>,
) -> Result<&'ext mut OpDef, ExtensionDeclarationError> {
// We currently only support explicit signatures.
//
Expand Down Expand Up @@ -88,7 +93,12 @@ impl OperationDeclaration {

let signature_func: SignatureFunc = signature.make_signature(ext, ctx, &params)?;

let op_def = ext.add_op(self.name.clone(), self.description.clone(), signature_func)?;
let op_def = ext.add_op(
self.name.clone(),
self.description.clone(),
signature_func,
extension_ref,
)?;

for (k, v) in &self.misc {
op_def.add_misc(k, v.clone());
Expand Down
7 changes: 7 additions & 0 deletions hugr-core/src/extension/declarative/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
//! [specification]: https://github.com/CQCL/hugr/blob/main/specification/hugr.md#declarative-format
//! [`ExtensionSetDeclaration`]: super::ExtensionSetDeclaration
use std::sync::Weak;

use crate::extension::{TypeDef, TypeDefBound};
use crate::types::type_param::TypeParam;
use crate::types::{TypeBound, TypeName};
Expand Down Expand Up @@ -49,10 +51,14 @@ impl TypeDeclaration {
///
/// Types in the definition will be resolved using the extensions in `scope`
/// and the current extension.
///
/// Requires a [`Weak`] reference to the extension defining the operation.
/// This method is intended to be used inside the closure passed to [`Extension::new_arc`].
pub fn register<'ext>(
&self,
ext: &'ext mut Extension,
ctx: DeclarationContext<'_>,
extension_ref: &Weak<Extension>,
) -> Result<&'ext TypeDef, ExtensionDeclarationError> {
let params = self
.params
Expand All @@ -64,6 +70,7 @@ impl TypeDeclaration {
params,
self.description.clone(),
self.bound.into(),
extension_ref,
)?;
Ok(type_def)
}
Expand Down
Loading

0 comments on commit 1d8d896

Please sign in to comment.