From d69e3a188bb8c2039b6b760a5fdc26328f31a340 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Borgna?= <121866228+aborgna-q@users.noreply.github.com> Date: Mon, 12 Feb 2024 12:03:36 +0000 Subject: [PATCH] feat: Minimal implementation for YAML extensions (#833) This PR add a method `hugr::extensions::declarative::load_extensions` that dynamically loads a set of extensions defined as YAML onto a registry. The code is mostly comprised of struct definitions to match the human-readable serialisation format described in the spec, and some methods to translate them into the internal hugr definitions. There's a myriad of TODOs that should be addressed in future PRs, including: - Most parametric things (operations, type bounds, number of ports in a signature, ...). - Lowering functions, operations with non explicit signatures. - Resolving the signature types. - The syntax for describing these is not defined in the spec, so currently there's just a couple of basic hard-coded types used for testing: "Q" and "USize". Here's an example of a supported definition: ```yaml imports: [prelude] extensions: - name: SimpleExt types: - name: MyType description: A simple type with no parameters operations: - name: MyOperation description: A simple operation with no inputs nor outputs signature: inputs: [] outputs: [] - name: AnotherOperation description: An operation from 2 qubits to 2 qubits signature: inputs: [["Target", Q], ["Control", Q, 1]] outputs: [[null, Q, 2]] ``` --- examples/extension/declarative.yaml | 31 ++ specification/hugr.md | 19 +- src/extension.rs | 84 +++++- src/extension/declarative.rs | 400 +++++++++++++++++++++++++ src/extension/declarative/ops.rs | 115 +++++++ src/extension/declarative/signature.rs | 203 +++++++++++++ src/extension/declarative/types.rs | 170 +++++++++++ src/extension/type_def.rs | 12 +- src/types.rs | 4 + src/types/custom.rs | 10 +- src/utils.rs | 17 ++ 11 files changed, 1031 insertions(+), 34 deletions(-) create mode 100644 examples/extension/declarative.yaml create mode 100644 src/extension/declarative.rs create mode 100644 src/extension/declarative/ops.rs create mode 100644 src/extension/declarative/signature.rs create mode 100644 src/extension/declarative/types.rs diff --git a/examples/extension/declarative.yaml b/examples/extension/declarative.yaml new file mode 100644 index 000000000..ad65b0805 --- /dev/null +++ b/examples/extension/declarative.yaml @@ -0,0 +1,31 @@ +# Optionally import other extensions. The `prelude` is always imported. +imports: [logic] + +extensions: + - # Each extension must have a name + name: SimpleExt + types: + - # Types must have a name. + # Parameters are not currently supported. + name: Copyable type + description: A simple type with no parameters + # Types may have a "Eq", "Copyable", or "Any" bound. + # This field is optional and defaults to "Any". + bound: Copyable + operations: + - # Operations must have a name and a signature. + name: MyOperation + description: A simple operation with no inputs nor outputs + signature: + inputs: [] + outputs: [] + - name: AnotherOperation + description: An operation from 3 qubits to 3 qubits + signature: + # The input and outputs can be written directly as the types + inputs: [Q, Q, Q] + outputs: + - # Or as the type followed by a number of repetitions. + [Q, 1] + - # Or as a description, followed by the type and a number of repetitions. + [Control, Q, 2] diff --git a/specification/hugr.md b/specification/hugr.md index f278c425b..5099739c0 100644 --- a/specification/hugr.md +++ b/specification/hugr.md @@ -1072,6 +1072,7 @@ extensions: # Declare custom types types: - name: QubitVector + description: "A vector of qubits" # Opaque types can take type arguments, with specified names params: [["size", USize]] operations: @@ -1094,16 +1095,16 @@ extensions: - name: SU2 description: "One qubit unitary matrix" params: # per-node values passed to the type-scheme interpreter, but not used in signature - - matrix: Opaque(complex_matrix,2,2) + matrix: Opaque(complex_matrix,2,2) signature: inputs: [[null, Q]] outputs: [[null, Q]] - name: MatMul description: "Multiply matrices of statically-known size" params: # per-node values passed to type-scheme-interpreter and used in signature - - i: USize - - j: USize - - k: USize + i: USize + j: USize + k: USize signature: inputs: [["a", Array(Array(F64))], ["b", Array(Array(F64))]] outputs: [[null, Array(Array(F64))]] @@ -1112,7 +1113,7 @@ extensions: - name: max_float description: "Variable number of inputs" params: - - n: USize + n: USize signature: # Where an element of a signature has three subelements, the third is the number of repeats inputs: [[null, F64, n]] # (defaulting to 1 if omitted) @@ -1120,9 +1121,9 @@ extensions: - name: ArrayConcat description: "Concatenate two arrays. Extension provides a compute_signature implementation." params: - - t: Type # Classic or Quantum - - i: USize - - j: USize + t: Type # Classic or Quantum + i: USize + j: USize # inputs could be: Array(t), Array(t) # outputs would be, in principle: Array(t) # - but default type scheme interpreter does not support such addition @@ -1134,7 +1135,7 @@ extensions: signature: inputs: [[null, Function[r](USize -> USize)], ["arg", USize]] outputs: [[null, USize]] - extensions: r # Indicates that running this operation also invokes extensions r + extensions: [r] # Indicates that running this operation also invokes extensions r lowering: file: "graph_op_hugr.bin" extensions: ["arithmetic.int", r] # r is the ExtensionSet in "params" diff --git a/src/extension.rs b/src/extension.rs index f87925755..6f2859e79 100644 --- a/src/extension.rs +++ b/src/extension.rs @@ -3,7 +3,8 @@ //! TODO: YAML declaration and parsing. This should be similar to a plugin //! system (outside the `types` module), which also parses nested [`OpDef`]s. -use std::collections::hash_map::Entry; +use std::collections::btree_map; +use std::collections::hash_map; use std::collections::{BTreeMap, BTreeSet, HashMap}; use std::fmt::{Debug, Display, Formatter}; use std::sync::Arc; @@ -16,7 +17,9 @@ use crate::ops; use crate::ops::custom::{ExtensionOp, OpaqueOp}; use crate::types::type_param::{check_type_args, TypeArgError}; use crate::types::type_param::{TypeArg, TypeParam}; -use crate::types::{check_typevar_decl, CustomType, PolyFuncType, Substitution, TypeBound}; +use crate::types::{ + check_typevar_decl, CustomType, PolyFuncType, Substitution, TypeBound, TypeName, +}; #[allow(dead_code)] mod infer; @@ -38,6 +41,8 @@ pub mod validate; pub use const_fold::{ConstFold, ConstFoldResult}; pub use prelude::{PRELUDE, PRELUDE_REGISTRY}; +pub mod declarative; + /// Extension Registries store extensions to be looked up e.g. during validation. #[derive(Clone, Debug)] pub struct ExtensionRegistry(BTreeMap); @@ -48,15 +53,22 @@ impl ExtensionRegistry { self.0.get(name) } + /// Returns `true` if the registry contains an extension with the given name. + pub fn contains(&self, name: &str) -> bool { + self.0.contains_key(name) + } + /// Makes a new ExtensionRegistry, validating all the extensions in it pub fn try_new( value: impl IntoIterator, - ) -> Result { + ) -> Result { let mut exts = BTreeMap::new(); for ext in value.into_iter() { let prev = exts.insert(ext.name.clone(), ext); if let Some(prev) = prev { - panic!("Multiple extensions with same name: {}", prev.name) + return Err(ExtensionRegistryError::AlreadyRegistered( + prev.name().clone(), + )); }; } // Note this potentially asks extensions to validate themselves against other extensions that @@ -66,10 +78,38 @@ impl ExtensionRegistry { // cyclically dependent, so there is no perfect solution, and this is at least simple. let res = ExtensionRegistry(exts); for ext in res.0.values() { - ext.validate(&res).map_err(|e| (ext.name().clone(), e))?; + ext.validate(&res) + .map_err(|e| ExtensionRegistryError::InvalidSignature(ext.name().clone(), e))?; } Ok(res) } + + /// Registers a new extension to the registry. + /// + /// Returns a reference to the registered extension if successful. + pub fn register(&mut self, extension: Extension) -> Result<&Extension, ExtensionRegistryError> { + match self.0.entry(extension.name().clone()) { + btree_map::Entry::Occupied(_) => Err(ExtensionRegistryError::AlreadyRegistered( + extension.name().clone(), + )), + btree_map::Entry::Vacant(ve) => Ok(ve.insert(extension)), + } + } + + /// Returns the number of extensions in the registry. + pub fn len(&self) -> usize { + self.0.len() + } + + /// Returns `true` if the registry contains no extensions. + pub fn is_empty(&self) -> bool { + self.0.is_empty() + } + + /// Returns an iterator over the extensions in the registry. + pub fn iter(&self) -> impl Iterator { + self.0.iter() + } } impl IntoIterator for ExtensionRegistry { @@ -92,7 +132,7 @@ pub const EMPTY_REG: ExtensionRegistry = ExtensionRegistry(BTreeMap::new()); pub enum SignatureError { /// Name mismatch #[error("Definition name ({0}) and instantiation name ({1}) do not match.")] - NameMismatch(SmolStr, SmolStr), + NameMismatch(TypeName, TypeName), /// Extension mismatch #[error("Definition extension ({0:?}) and instantiation extension ({1:?}) do not match.")] ExtensionMismatch(ExtensionId, ExtensionId), @@ -107,7 +147,7 @@ pub enum SignatureError { ExtensionNotFound(ExtensionId), /// The Extension was found in the registry, but did not contain the Type(Def) referenced in the Signature #[error("Extension '{exn}' did not contain expected TypeDef '{typ}'")] - ExtensionTypeNotFound { exn: ExtensionId, typ: SmolStr }, + ExtensionTypeNotFound { exn: ExtensionId, typ: TypeName }, /// The bound recorded for a CustomType doesn't match what the TypeDef would compute #[error("Bound on CustomType ({actual}) did not match TypeDef ({expected})")] WrongBound { @@ -136,8 +176,13 @@ pub enum SignatureError { /// Concrete instantiations of types and operations defined in extensions. trait CustomConcrete { + /// A generic identifier to the element. + /// + /// This may either refer to a [`TypeName`] or an [`OpName`]. fn def_name(&self) -> &SmolStr; + /// The concrete type arguments for the instantiation. fn type_args(&self) -> &[TypeArg]; + /// Extension required by the instantiation. fn parent_extension(&self) -> &ExtensionId; } @@ -157,6 +202,7 @@ impl CustomConcrete for OpaqueOp { impl CustomConcrete for CustomType { fn def_name(&self) -> &SmolStr { + // Casts the `TypeName` to a generic string. self.name() } @@ -227,7 +273,7 @@ pub struct Extension { /// for any possible [TypeArg]. pub extension_reqs: ExtensionSet, /// Types defined by this extension. - types: HashMap, + types: HashMap, /// Static values defined by this extension. values: HashMap, /// Operation declarations with serializable definitions. @@ -282,7 +328,7 @@ impl Extension { } /// Iterator over the types of this [`Extension`]. - pub fn types(&self) -> impl Iterator { + pub fn types(&self) -> impl Iterator { self.types.iter() } @@ -298,8 +344,10 @@ impl Extension { typed_value, }; match self.values.entry(extension_value.name.clone()) { - Entry::Occupied(_) => Err(ExtensionBuildError::OpDefExists(extension_value.name)), - Entry::Vacant(ve) => Ok(ve.insert(extension_value)), + hash_map::Entry::Occupied(_) => { + Err(ExtensionBuildError::OpDefExists(extension_value.name)) + } + hash_map::Entry::Vacant(ve) => Ok(ve.insert(extension_value)), } } @@ -331,8 +379,18 @@ impl PartialEq for Extension { } } -/// An error that can occur in computing the signature of a node. -/// TODO: decide on failure modes +/// An error that can occur in defining an extension registry. +#[derive(Debug, Clone, Error, PartialEq, Eq)] +pub enum ExtensionRegistryError { + /// Extension already defined. + #[error("The registry already contains an extension with id {0}.")] + AlreadyRegistered(ExtensionId), + /// A registered extension has invalid signatures. + #[error("The extension {0} contains an invalid signature, {1}.")] + InvalidSignature(ExtensionId, #[source] SignatureError), +} + +/// An error that can occur in building a new extension. #[derive(Debug, Clone, Error, PartialEq, Eq)] pub enum ExtensionBuildError { /// Existing [`OpDef`] diff --git a/src/extension/declarative.rs b/src/extension/declarative.rs new file mode 100644 index 000000000..21cc9e8d7 --- /dev/null +++ b/src/extension/declarative.rs @@ -0,0 +1,400 @@ +//! Declarative extension definitions. +//! +//! This module includes functions to dynamically load HUGR extensions defined in a YAML file. +//! +//! An extension file may define multiple extensions, each with a set of types and operations. +//! +//! See the [specification] for more details. +//! +//! ### Example +//! +//! ```yaml +#![doc = include_str!("../../examples/extension/declarative.yaml")] +//! ``` +//! +//! [specification]: https://github.com/CQCL/hugr/blob/main/specification/hugr.md#declarative-format + +mod ops; +mod signature; +mod types; + +use std::fs::File; +use std::path::Path; + +use crate::extension::prelude::PRELUDE_ID; +use crate::types::TypeName; +use crate::Extension; + +use super::{ + ExtensionBuildError, ExtensionId, ExtensionRegistry, ExtensionRegistryError, ExtensionSet, + PRELUDE, +}; +use ops::OperationDeclaration; +use smol_str::SmolStr; +use types::TypeDeclaration; + +use serde::{Deserialize, Serialize}; + +/// Load a set of extensions from a YAML string into a registry. +/// +/// Any required extensions must already be present in the registry. +pub fn load_extensions( + yaml: &str, + registry: &mut ExtensionRegistry, +) -> Result<(), ExtensionDeclarationError> { + let ext: ExtensionSetDeclaration = serde_yaml::from_str(yaml)?; + ext.add_to_registry(registry) +} + +/// Load a set of extensions from a file into a registry. +/// +/// Any required extensions must already be present in the registry. +pub fn load_extensions_file( + path: &Path, + registry: &mut ExtensionRegistry, +) -> Result<(), ExtensionDeclarationError> { + let file = File::open(path)?; + let ext: ExtensionSetDeclaration = serde_yaml::from_reader(file)?; + ext.add_to_registry(registry) +} + +/// A set of declarative extension definitions with some metadata. +/// +/// These are normally contained in a single YAML file. +// +// TODO: More metadata, "namespace"? +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +struct ExtensionSetDeclaration { + /// A set of extension definitions. + // + // TODO: allow qualified, and maybe locally-scoped? + extensions: Vec, + /// A list of extension IDs that this extension depends on. + /// Optional. + #[serde(default)] + #[serde(skip_serializing_if = "crate::utils::is_default")] + imports: ExtensionSet, +} + +/// A declarative extension definition. +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +struct ExtensionDeclaration { + /// The name of the extension. + name: ExtensionId, + /// A list of types that this extension provides. + #[serde(default)] + #[serde(skip_serializing_if = "crate::utils::is_default")] + types: Vec, + /// A list of operations that this extension provides. + #[serde(default)] + #[serde(skip_serializing_if = "crate::utils::is_default")] + operations: Vec, + // TODO: Values? +} + +impl ExtensionSetDeclaration { + /// Register this set of extensions with the given registry. + pub fn add_to_registry( + &self, + registry: &mut ExtensionRegistry, + ) -> Result<(), ExtensionDeclarationError> { + // All dependencies must be present in the registry. + for imp in self.imports.iter() { + if !registry.contains(imp) { + return Err(ExtensionDeclarationError::MissingExtension { ext: imp.clone() }); + } + } + + // A set of extensions that are in scope for the definition. This is a + // subset of `registry` that includes `self.imports` and the previous + // extensions defined in the declaration. + let mut scope = self.imports.clone(); + + // The prelude is auto-imported. + if !registry.contains(&PRELUDE_ID) { + registry.register(PRELUDE.clone())?; + } + if !scope.contains(&PRELUDE_ID) { + scope.insert(&PRELUDE_ID); + } + + // Registers extensions sequentially, adding them to the current scope. + for decl in &self.extensions { + let ctx = DeclarationContext { + scope: &scope, + registry, + }; + let ext = decl.make_extension(&self.imports, ctx)?; + let ext = registry.register(ext)?; + scope.insert(ext.name()) + } + + Ok(()) + } +} + +impl ExtensionDeclaration { + /// Create an [`Extension`] from this declaration. + pub fn make_extension( + &self, + imports: &ExtensionSet, + ctx: DeclarationContext<'_>, + ) -> Result { + let mut ext = Extension::new_with_reqs(self.name.clone(), imports.clone()); + + for t in &self.types { + t.register(&mut ext, ctx)?; + } + + for o in &self.operations { + o.register(&mut ext, ctx)?; + } + + Ok(ext) + } +} + +/// Some context data used while translating a declarative extension definition. +#[derive(Debug, Copy, Clone)] +struct DeclarationContext<'a> { + /// The set of extensions that are in scope for this extension. + pub scope: &'a ExtensionSet, + /// The registry to use for resolving dependencies. + pub registry: &'a ExtensionRegistry, +} + +/// Errors that can occur while loading an extension set. +#[derive(Debug, thiserror::Error)] +pub enum ExtensionDeclarationError { + /// An error occurred while deserializing the extension set. + #[error("Error while parsing the extension set yaml: {0}")] + Deserialize(#[from] serde_yaml::Error), + /// An error in registering the loaded extensions. + #[error("Error registering the extensions.")] + ExtensionRegistryError(#[from] ExtensionRegistryError), + /// An error occurred while adding operations or types to the extension. + #[error("Error while adding operations or types to the extension: {0}")] + ExtensionBuildError(#[from] ExtensionBuildError), + /// Invalid yaml file. + #[error("Invalid yaml declaration file {0}")] + InvalidFile(#[from] std::io::Error), + /// A required extension is missing. + #[error("Missing required extension {ext}")] + MissingExtension { + /// The missing imported extension. + ext: ExtensionId, + }, + /// Referenced an unknown type. + #[error("Extension {ext} referenced an unknown type {ty}.")] + MissingType { + /// The extension that referenced the unknown type. + ext: ExtensionId, + /// The unknown type. + ty: TypeName, + }, + /// Parametric types are not currently supported as type parameters. + /// + /// TODO: Support this. + #[error("Found a currently unsupported higher-order type parameter {ty} in extension {ext}")] + ParametricTypeParameter { + /// The extension that referenced the unsupported type parameter. + ext: ExtensionId, + /// The unsupported type parameter. + ty: TypeName, + }, + /// Parametric types are not currently supported as type parameters. + /// + /// TODO: Support this. + #[error("Found a currently unsupported parametric operation {op} in extension {ext}")] + ParametricOperation { + /// The extension that referenced the unsupported op parameter. + ext: ExtensionId, + /// The operation. + op: SmolStr, + }, + /// Operation definitions with no signature are not currently supported. + /// + /// TODO: Support this. + #[error( + "Operation {op} in extension {ext} has no signature. This is not currently supported." + )] + MissingSignature { + /// The extension containing the operation. + ext: ExtensionId, + /// The operation with no signature. + op: SmolStr, + }, + /// An unknown type was specified in a signature. + #[error("Type {ty} is not in scope. In extension {ext}.")] + UnknownType { + /// The extension that referenced the type. + ext: ExtensionId, + /// The unsupported type. + ty: String, + }, + /// Parametric port repetitions are not currently supported. + /// + /// TODO: Support this. + #[error("Unsupported port repetition {parametric_repetition} in extension {ext}")] + UnsupportedPortRepetition { + /// The extension that referenced the type. + ext: crate::hugr::IdentList, + /// The repetition expression + parametric_repetition: SmolStr, + }, + /// Lowering definitions for an operation are not currently supported. + /// + /// TODO: Support this. + #[error("Unsupported lowering definition for op {op} in extension {ext}")] + LoweringNotSupported { + /// The extension. + ext: crate::hugr::IdentList, + /// The operation with the lowering definition. + op: SmolStr, + }, +} + +#[cfg(test)] +mod test { + use itertools::Itertools; + use rstest::rstest; + use std::path::PathBuf; + + use crate::extension::PRELUDE_REGISTRY; + use crate::std_extensions; + + use super::*; + + /// A yaml extension defining an empty extension. + const EMPTY_YAML: &str = r#" +extensions: +- name: EmptyExt +"#; + + /// A yaml extension defining an extension with one type and two operations. + const BASIC_YAML: &str = r#" +imports: [prelude] + +extensions: +- name: SimpleExt + types: + - name: MyType + description: A simple type with no parameters + bound: Any + operations: + - name: MyOperation + description: A simple operation with no inputs nor outputs + signature: + inputs: [] + outputs: [] + - name: AnotherOperation + description: An operation from 3 qubits to 3 qubits + signature: + inputs: [Q, Q, Q] + outputs: [[Q, 1], [Control, Q, 2]] +"#; + + /// A yaml extension with unsupported features. + const UNSUPPORTED_YAML: &str = r#" +extensions: +- name: UnsupportedExt + types: + - name: MyType + description: A simple type with no parameters + # Parametric types are not currently supported. + params: [Any, ["An unbounded natural number", USize]] + operations: + - name: UnsupportedOperation + description: An operation from 3 qubits to 3 qubits + params: + # Parametric operations are not currently supported. + param1: USize + signature: + # Type declarations will have their own syntax. + inputs: [] + outputs: ["Array[USize]"] +"#; + + /// The yaml used in the module documentation. + const EXAMPLE_YAML_FILE: &str = "examples/extension/declarative.yaml"; + + #[rstest] + #[case(EMPTY_YAML, 1, 0, 0, &PRELUDE_REGISTRY)] + #[case(BASIC_YAML, 1, 1, 2, &PRELUDE_REGISTRY)] + fn test_decode( + #[case] yaml: &str, + #[case] num_declarations: usize, + #[case] num_types: usize, + #[case] num_operations: usize, + #[case] dependencies: &ExtensionRegistry, + ) -> Result<(), Box> { + let mut reg = dependencies.clone(); + load_extensions(yaml, &mut reg)?; + + let new_exts = new_extensions(®, dependencies).collect_vec(); + + assert_eq!(new_exts.len(), num_declarations); + assert_eq!( + new_exts.iter().flat_map(|(_, e)| e.types()).count(), + num_types + ); + assert_eq!( + new_exts.iter().flat_map(|(_, e)| e.operations()).count(), + num_operations + ); + Ok(()) + } + + #[cfg_attr(miri, ignore)] // Opening files is not supported in (isolated) miri + #[rstest] + #[case(EXAMPLE_YAML_FILE, 1, 1, 2, &std_extensions::logic::LOGIC_REG)] + fn test_decode_file( + #[case] yaml_file: &str, + #[case] num_declarations: usize, + #[case] num_types: usize, + #[case] num_operations: usize, + #[case] dependencies: &ExtensionRegistry, + ) -> Result<(), Box> { + let mut reg = dependencies.clone(); + load_extensions_file(&PathBuf::from(yaml_file), &mut reg)?; + + let new_exts = new_extensions(®, dependencies).collect_vec(); + + assert_eq!(new_exts.len(), num_declarations); + assert_eq!( + new_exts.iter().flat_map(|(_, e)| e.types()).count(), + num_types + ); + assert_eq!( + new_exts.iter().flat_map(|(_, e)| e.operations()).count(), + num_operations + ); + Ok(()) + } + + #[rstest] + #[case(UNSUPPORTED_YAML, &PRELUDE_REGISTRY)] + fn test_unsupported( + #[case] yaml: &str, + #[case] dependencies: &ExtensionRegistry, + ) -> Result<(), Box> { + let mut reg = dependencies.clone(); + + // The parsing should not fail. + let ext: ExtensionSetDeclaration = serde_yaml::from_str(yaml)?; + + assert!(ext.add_to_registry(&mut reg).is_err()); + + Ok(()) + } + + /// Returns a list of new extensions that have been defined in a register, + /// comparing against a set of pre-included dependencies. + fn new_extensions<'a>( + reg: &'a ExtensionRegistry, + dependencies: &'a ExtensionRegistry, + ) -> impl Iterator { + reg.iter() + .filter(move |(id, _)| !dependencies.contains(id) && *id != &PRELUDE_ID) + } +} diff --git a/src/extension/declarative/ops.rs b/src/extension/declarative/ops.rs new file mode 100644 index 000000000..f1716f533 --- /dev/null +++ b/src/extension/declarative/ops.rs @@ -0,0 +1,115 @@ +//! Declarative operation definitions. +//! +//! This module defines a YAML schema for defining operations in a declarative way. +//! +//! See the [specification] and [`ExtensionSetDeclaration`] for more details. +//! +//! [specification]: https://github.com/CQCL/hugr/blob/main/specification/hugr.md#declarative-format +//! [`ExtensionSetDeclaration`]: super::ExtensionSetDeclaration + +use std::collections::HashMap; + +use serde::{Deserialize, Serialize}; +use smol_str::SmolStr; + +use crate::extension::{OpDef, SignatureFunc}; +use crate::types::type_param::TypeParam; +use crate::Extension; + +use super::signature::SignatureDeclaration; +use super::{DeclarationContext, ExtensionDeclarationError}; + +/// A declarative operation definition. +/// +/// TODO: The "Lowering" attribute is not yet supported. +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub(super) struct OperationDeclaration { + /// The identifier the operation. + name: SmolStr, + /// A description for the operation. + #[serde(default)] + #[serde(skip_serializing_if = "crate::utils::is_default")] + description: String, + /// The signature of the operation. + #[serde(default)] + #[serde(skip_serializing_if = "crate::utils::is_default")] + signature: Option, + /// A set of per-node parameters required to instantiate this operation. + #[serde(default)] + #[serde(skip_serializing_if = "crate::utils::is_default")] + params: HashMap, + /// An extra set of data associated to the operation. + /// + /// This data is kept in the Hugr, and may be accessed by the relevant runtime. + #[serde(default)] + #[serde(skip_serializing_if = "crate::utils::is_default")] + misc: HashMap, + /// A pre-compiled lowering routine. + /// + /// This is not yet supported, and will raise an error if present. + #[serde(default)] + #[serde(skip_serializing_if = "crate::utils::is_default")] + lowering: Option, +} + +impl OperationDeclaration { + /// Register this operation in the given extension. + pub fn register<'ext>( + &self, + ext: &'ext mut Extension, + ctx: DeclarationContext<'_>, + ) -> Result<&'ext mut OpDef, ExtensionDeclarationError> { + // We currently only support explicit signatures. + // + // TODO: Support missing signatures? + let Some(signature) = &self.signature else { + return Err(ExtensionDeclarationError::MissingSignature { + ext: ext.name().clone(), + op: self.name.clone(), + }); + }; + + // We currently do not support parametric operations. + if !self.params.is_empty() { + return Err(ExtensionDeclarationError::ParametricOperation { + ext: ext.name().clone(), + op: self.name.clone(), + }); + } + let params: Vec = vec![]; + + if self.lowering.is_some() { + return Err(ExtensionDeclarationError::LoweringNotSupported { + ext: ext.name().clone(), + op: self.name.clone(), + }); + } + + let signature_func: SignatureFunc = signature.make_signature(ext, ctx, ¶ms)?; + + let op_def = ext.add_op(self.name.clone(), self.description.clone(), signature_func)?; + + for (k, v) in &self.misc { + op_def.add_misc(k, v.clone()); + } + + Ok(op_def) + } +} + +/// The type of a per-node operation parameter required to instantiate an operation. +/// +/// TODO: The value should be decoded as a [`TypeParam`]. +/// Valid options include: +/// +/// - `USize` +/// - `Type` +/// +/// [`TypeParam`]: crate::types::type_param::TypeParam +#[derive(Debug, Clone, Hash, PartialEq, Eq, Serialize, Deserialize)] +pub(super) struct ParamDeclaration( + /// TODO: Store a [`TypeParam`], and implement custom parsers. + /// + /// [`TypeParam`]: crate::types::type_param::TypeParam + String, +); diff --git a/src/extension/declarative/signature.rs b/src/extension/declarative/signature.rs new file mode 100644 index 000000000..de7163577 --- /dev/null +++ b/src/extension/declarative/signature.rs @@ -0,0 +1,203 @@ +//! Declarative signature definitions. +//! +//! This module defines a YAML schema for defining the signature of an operation in a declarative way. +//! +//! See the [specification] and [`ExtensionSetDeclaration`] for more details. +//! +//! [specification]: https://github.com/CQCL/hugr/blob/main/specification/hugr.md#declarative-format +//! [`ExtensionSetDeclaration`]: super::ExtensionSetDeclaration + +use itertools::Itertools; +use serde::{Deserialize, Serialize}; +use smol_str::SmolStr; + +use crate::extension::prelude::PRELUDE_ID; +use crate::extension::{CustomValidator, ExtensionSet, SignatureFunc, TypeDef, TypeParametrised}; +use crate::types::type_param::TypeParam; +use crate::types::{CustomType, FunctionType, PolyFuncType, Type, TypeRow}; +use crate::Extension; + +use super::{DeclarationContext, ExtensionDeclarationError}; + +/// A declarative operation signature definition. +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub(super) struct SignatureDeclaration { + /// The inputs to the operation. + inputs: Vec, + /// The outputs of the operation. + outputs: Vec, + /// A set of extensions invoked while running this operation. + #[serde(default)] + #[serde(skip_serializing_if = "crate::utils::is_default")] + extensions: ExtensionSet, +} + +impl SignatureDeclaration { + /// Register this signature in the given extension. + pub fn make_signature( + &self, + ext: &Extension, + ctx: DeclarationContext<'_>, + op_params: &[TypeParam], + ) -> Result { + let make_type_row = + |v: &[SignaturePortDeclaration]| -> Result { + let types = v + .iter() + .map(|port_decl| port_decl.make_types(ext, ctx, op_params)) + .flatten_ok() + .collect::, _>>()?; + Ok(types.into()) + }; + + let body = FunctionType { + input: make_type_row(&self.inputs)?, + output: make_type_row(&self.outputs)?, + extension_reqs: self.extensions.clone(), + }; + + let poly_func = PolyFuncType::new(op_params, body); + Ok(SignatureFunc::TypeScheme(CustomValidator::from_polyfunc( + poly_func, + ))) + } +} + +/// A declarative definition for a number of ports in a signature's input or output. +/// +/// Serialized as a single type, or as a 2 or 3-element lists. +#[derive(Debug, Clone, Hash, PartialEq, Eq, Serialize, Deserialize)] +#[serde(untagged)] +enum SignaturePortDeclaration { + /// A single port type. + Type(TypeDeclaration), + /// A 2-tuple with the type and a repetition declaration. + TypeRepeat(TypeDeclaration, PortRepetitionDeclaration), + /// A 3-tuple with a description, a type declaration, and a repetition declaration. + DescriptionTypeRepeat(String, TypeDeclaration, PortRepetitionDeclaration), +} + +impl SignaturePortDeclaration { + /// Return an iterator with the types for this port declaration. + fn make_types( + &self, + ext: &Extension, + ctx: DeclarationContext<'_>, + op_params: &[TypeParam], + ) -> Result, ExtensionDeclarationError> { + let n: usize = match self.repeat() { + PortRepetitionDeclaration::Count(n) => *n, + PortRepetitionDeclaration::Parameter(parametric_repetition) => { + return Err(ExtensionDeclarationError::UnsupportedPortRepetition { + ext: ext.name().clone(), + parametric_repetition: parametric_repetition.clone(), + }) + } + }; + + let ty = self.type_decl().make_type(ext, ctx, op_params)?; + let ty = Type::new_extension(ty); + + Ok(itertools::repeat_n(ty, n)) + } + + /// Get the type declaration for this port. + fn type_decl(&self) -> &TypeDeclaration { + match self { + SignaturePortDeclaration::Type(ty) => ty, + SignaturePortDeclaration::TypeRepeat(ty, _) => ty, + SignaturePortDeclaration::DescriptionTypeRepeat(_, ty, _) => ty, + } + } + + /// Get the repetition declaration for this port. + fn repeat(&self) -> &PortRepetitionDeclaration { + static DEFAULT_REPEAT: PortRepetitionDeclaration = PortRepetitionDeclaration::Count(1); + match self { + SignaturePortDeclaration::DescriptionTypeRepeat(_, _, repeat) => repeat, + SignaturePortDeclaration::TypeRepeat(_, repeat) => repeat, + _ => &DEFAULT_REPEAT, + } + } +} + +/// A number of repetitions for a signature's port definition. +/// +/// This value must be a number, indicating a repetition of the port that amount of times. +/// +/// Generic expressions are not yet supported. +#[derive(Debug, Clone, Hash, PartialEq, Eq, Serialize, Deserialize)] +#[serde(untagged)] +enum PortRepetitionDeclaration { + /// A constant number of repetitions for the port definition. + Count(usize), + /// An (integer) operation parameter identifier to use as the number of repetitions. + Parameter(SmolStr), +} + +impl Default for PortRepetitionDeclaration { + fn default() -> Self { + PortRepetitionDeclaration::Count(1) + } +} + +/// A type declaration used in signatures. +/// +/// TODO: The spec definition is more complex than just a type identifier, +/// we should be able to support expressions like: +/// +/// - `Q` +/// - `Array(Array(F64))` +/// - `Function[r](USize -> USize)` +/// - `Opaque(complex_matrix,i,j)` +/// +/// Note that `Q` is not the name used for a qubit in the prelude. +/// +/// For now, we just hard-code some basic types. +#[derive(Debug, Clone, Hash, PartialEq, Eq, Serialize, Deserialize)] +#[serde(transparent)] +struct TypeDeclaration( + /// The encoded type description. + String, +); + +impl TypeDeclaration { + /// Parse the type represented by this declaration. + /// + /// Currently hard-codes some basic types. + /// + /// TODO: Support arbitrary types. + /// TODO: Support parametric types. + pub fn make_type( + &self, + ext: &Extension, + ctx: DeclarationContext<'_>, + _op_params: &[TypeParam], + ) -> Result { + // The prelude is always in scope. + debug_assert!(ctx.scope.contains(&PRELUDE_ID)); + + // Only hard-coded prelude types are supported for now. + let prelude = ctx.registry.get(&PRELUDE_ID).unwrap(); + let op_def: &TypeDef = match self.0.as_str() { + "USize" => prelude.get_type("usize"), + "Q" => prelude.get_type("qubit"), + _ => { + return Err(ExtensionDeclarationError::UnknownType { + ext: ext.name().clone(), + ty: self.0.clone(), + }) + } + } + .ok_or(ExtensionDeclarationError::UnknownType { + ext: ext.name().clone(), + ty: self.0.clone(), + })?; + + // The hard-coded types are not parametric. + assert!(op_def.params().is_empty()); + let op = op_def.instantiate(&[]).unwrap(); + + Ok(op) + } +} diff --git a/src/extension/declarative/types.rs b/src/extension/declarative/types.rs new file mode 100644 index 000000000..c90eb08e7 --- /dev/null +++ b/src/extension/declarative/types.rs @@ -0,0 +1,170 @@ +//! Declarative type definitions. +//! +//! This module defines a YAML schema for defining types in a declarative way. +//! +//! See the [specification] and [`ExtensionSetDeclaration`] for more details. +//! +//! [specification]: https://github.com/CQCL/hugr/blob/main/specification/hugr.md#declarative-format +//! [`ExtensionSetDeclaration`]: super::ExtensionSetDeclaration + +use crate::extension::{TypeDef, TypeDefBound, TypeParametrised}; +use crate::types::type_param::TypeParam; +use crate::types::{CustomType, TypeBound, TypeName}; +use crate::Extension; + +use serde::{Deserialize, Serialize}; + +use super::{DeclarationContext, ExtensionDeclarationError}; + +/// A declarative type definition. +#[derive(Debug, Clone, Serialize, Deserialize, Hash, PartialEq, Eq)] +pub(super) struct TypeDeclaration { + /// The name of the type. + name: TypeName, + /// A description for the type. + #[serde(default)] + #[serde(skip_serializing_if = "crate::utils::is_default")] + description: String, + /// The type bound describing what can be done to instances of this type. + /// Options are `Eq`, `Copyable`, or `Any`. + /// + /// See [`TypeBound`] and [`TypeDefBound`]. + /// + /// TODO: Derived bounds from the parameters (see [`TypeDefBound`]) are not yet supported. + #[serde(default)] + #[serde(skip_serializing_if = "crate::utils::is_default")] + bound: TypeDefBoundDeclaration, + /// A list of type parameters for this type. + /// + /// Each element in the list is a 2-element list, where the first element is + /// the human-readable name of the type parameter, and the second element is + /// the type id. + #[serde(default)] + #[serde(skip_serializing_if = "crate::utils::is_default")] + params: Vec, +} + +impl TypeDeclaration { + /// Register this type in the given extension. + /// + /// Types in the definition will be resolved using the extensions in `scope` + /// and the current extension. + pub fn register<'ext>( + &self, + ext: &'ext mut Extension, + ctx: DeclarationContext<'_>, + ) -> Result<&'ext TypeDef, ExtensionDeclarationError> { + let params = self + .params + .iter() + .map(|param| param.make_type_param(ext, ctx)) + .collect::, _>>()?; + let type_def = ext.add_type( + self.name.clone(), + params, + self.description.clone(), + self.bound.into(), + )?; + Ok(type_def) + } +} + +/// A declarative TypeBound definition. +/// +/// Equivalent to a [`TypeDefBound`]. Provides human-friendly serialization, using +/// the full names. +/// +/// TODO: Support derived bounds +#[derive( + Debug, Copy, Clone, Serialize, Deserialize, Hash, PartialEq, Eq, Default, derive_more::Display, +)] +enum TypeDefBoundDeclaration { + /// The equality operation is valid on this type. + Eq, + /// The type can be copied in the program. + Copyable, + /// No bound on the type. + #[default] + Any, +} + +impl From for TypeDefBound { + fn from(bound: TypeDefBoundDeclaration) -> Self { + match bound { + TypeDefBoundDeclaration::Eq => Self::Explicit(TypeBound::Eq), + TypeDefBoundDeclaration::Copyable => Self::Explicit(TypeBound::Copyable), + TypeDefBoundDeclaration::Any => Self::Explicit(TypeBound::Any), + } + } +} + +/// A declarative type parameter definition. +/// +/// Either a type, or a 2-element list containing a human-readable name and a type id. +#[derive(Debug, Clone, Hash, PartialEq, Eq, Serialize, Deserialize)] +#[serde(untagged)] +enum TypeParamDeclaration { + /// Just the type id. + Type(TypeName), + /// A 2-tuple containing a human-readable name and a type id. + WithDescription(String, TypeName), +} + +impl TypeParamDeclaration { + /// Create a [`TypeParam`] from this declaration. + /// + /// Resolves any type ids using both the current extension and any other in `scope`. + /// + /// TODO: Only non-parametric opaque types are supported for now. + /// TODO: The parameter description is currently ignored. + pub fn make_type_param( + &self, + extension: &Extension, + ctx: DeclarationContext<'_>, + ) -> Result { + let instantiate_type = |ty: &TypeDef| -> Result { + match ty.params() { + [] => Ok(ty.instantiate([]).unwrap()), + _ => Err(ExtensionDeclarationError::ParametricTypeParameter { + ext: extension.name().clone(), + ty: self.type_name().clone(), + }), + } + }; + + // First try the previously defined types in the current extension. + if let Some(ty) = extension.get_type(self.type_name()) { + return Ok(TypeParam::Opaque { + ty: instantiate_type(ty)?, + }); + } + + // Try every extension in scope. + // + // TODO: Can we resolve the extension id from the type name instead? + for ext in ctx.scope.iter() { + if let Some(ty) = ctx + .registry + .get(ext) + .and_then(|ext| ext.get_type(self.type_name())) + { + return Ok(TypeParam::Opaque { + ty: instantiate_type(ty)?, + }); + } + } + + Err(ExtensionDeclarationError::MissingType { + ext: extension.name().clone(), + ty: self.type_name().clone(), + }) + } + + /// Returns the type name of this type parameter. + fn type_name(&self) -> &TypeName { + match self { + Self::Type(ty) => ty, + Self::WithDescription(_, ty) => ty, + } + } +} diff --git a/src/extension/type_def.rs b/src/extension/type_def.rs index a931c6383..dd13b453f 100644 --- a/src/extension/type_def.rs +++ b/src/extension/type_def.rs @@ -3,14 +3,12 @@ use std::collections::hash_map::Entry; use super::{CustomConcrete, ExtensionBuildError}; use super::{Extension, ExtensionId, SignatureError, TypeParametrised}; -use crate::types::{least_upper_bound, CustomType}; +use crate::types::{least_upper_bound, CustomType, TypeName}; use crate::types::type_param::TypeArg; use crate::types::type_param::TypeParam; -use smol_str::SmolStr; - use crate::types::TypeBound; /// The type bound of a [`TypeDef`] @@ -36,7 +34,7 @@ pub struct TypeDef { /// The unique Extension owning this TypeDef (of which this TypeDef is a member) extension: ExtensionId, /// The unique name of the type - name: SmolStr, + name: TypeName, /// Declaration of type parameters. The TypeDef must be instantiated /// with the same number of [`TypeArg`]'s to make an actual type. /// @@ -133,7 +131,7 @@ impl TypeParametrised for TypeDef { &self.params } - fn name(&self) -> &SmolStr { + fn name(&self) -> &TypeName { &self.name } @@ -146,7 +144,7 @@ impl Extension { /// Add an exported type to the extension. pub fn add_type( &mut self, - name: SmolStr, + name: TypeName, params: Vec, description: String, bound: TypeDefBound, @@ -183,7 +181,7 @@ mod test { b: TypeBound::Copyable, }], extension: "MyRsrc".try_into().unwrap(), - description: "Some parameterised type".into(), + description: "Some parametrised type".into(), bound: TypeDefBound::FromParams(vec![0]), }; let typ = Type::new_extension( diff --git a/src/types.rs b/src/types.rs index ff8164221..aaf616354 100644 --- a/src/types.rs +++ b/src/types.rs @@ -12,6 +12,7 @@ pub use check::{ConstTypeError, CustomCheckFailure}; pub use custom::CustomType; pub use poly_func::PolyFuncType; pub use signature::FunctionType; +use smol_str::SmolStr; pub use type_param::TypeArg; pub use type_row::TypeRow; @@ -26,6 +27,9 @@ use std::fmt::Debug; use self::type_param::TypeParam; +/// A unique identifier for a type. +pub type TypeName = SmolStr; + /// The kinds of edges in a HUGR, excluding Hierarchy. #[derive(Clone, PartialEq, Eq, Debug, serde::Serialize, serde::Deserialize)] #[non_exhaustive] diff --git a/src/types/custom.rs b/src/types/custom.rs index 68e70b2c5..2224b7325 100644 --- a/src/types/custom.rs +++ b/src/types/custom.rs @@ -1,11 +1,11 @@ //! Opaque types, used to represent a user-defined [`Type`]. //! //! [`Type`]: super::Type -use smol_str::SmolStr; use std::fmt::{self, Display}; use crate::extension::{ExtensionId, ExtensionRegistry, SignatureError, TypeDef}; +use super::TypeName; use super::{ type_param::{TypeArg, TypeParam}, Substitution, TypeBound, @@ -19,7 +19,7 @@ pub struct CustomType { /// Same as the corresponding [`TypeDef`] /// /// [`TypeDef`]: crate::extension::TypeDef - id: SmolStr, + id: TypeName, /// Arguments that fit the [`TypeParam`]s declared by the typedef /// /// [`TypeParam`]: super::type_param::TypeParam @@ -31,7 +31,7 @@ pub struct CustomType { impl CustomType { /// Creates a new opaque type. pub fn new( - id: impl Into, + id: impl Into, args: impl Into>, extension: ExtensionId, bound: TypeBound, @@ -45,7 +45,7 @@ impl CustomType { } /// Creates a new opaque type (constant version, no type arguments) - pub const fn new_simple(id: SmolStr, extension: ExtensionId, bound: TypeBound) -> Self { + pub const fn new_simple(id: TypeName, extension: ExtensionId, bound: TypeBound) -> Self { Self { id, args: vec![], @@ -107,7 +107,7 @@ impl CustomType { } /// unique name of the type. - pub fn name(&self) -> &SmolStr { + pub fn name(&self) -> &TypeName { &self.id } diff --git a/src/utils.rs b/src/utils.rs index 36839ad7e..b610c6e9c 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -27,6 +27,23 @@ pub fn collect_array(arr: &[T]) -> [&T; N] { arr.iter().collect_vec().try_into().unwrap() } +/// Helper method to skip serialization of default values in serde. +/// +/// ```ignore +/// use serde::Serialize; +/// +/// #[derive(Serialize)] +/// struct MyStruct { +/// #[serde(skip_serializing_if = "crate::utils::is_default")] +/// field: i32, +/// } +/// ``` +/// +/// From https://github.com/serde-rs/serde/issues/818. +pub(crate) fn is_default(t: &T) -> bool { + *t == Default::default() +} + #[cfg(test)] pub(crate) mod test_quantum_extension { use smol_str::SmolStr;