diff --git a/hugr-cli/Cargo.toml b/hugr-cli/Cargo.toml index f2b09cc2c..a517d8f48 100644 --- a/hugr-cli/Cargo.toml +++ b/hugr-cli/Cargo.toml @@ -15,7 +15,8 @@ categories = ["compilers"] [dependencies] clap = { workspace = true, features = ["derive"] } clap-verbosity-flag.workspace = true -hugr-core = { path = "../hugr-core", version = "0.13.1" } +derive_more = { workspace = true, features = ["display", "error", "from"] } +hugr = { path = "../hugr", version = "0.13.1" } serde_json.workspace = true serde.workspace = true thiserror.workspace = true diff --git a/hugr-cli/src/extensions.rs b/hugr-cli/src/extensions.rs index a6c5ff33c..792d997e7 100644 --- a/hugr-cli/src/extensions.rs +++ b/hugr-cli/src/extensions.rs @@ -1,6 +1,6 @@ //! Dump standard extensions in serialized form. use clap::Parser; -use hugr_core::extension::ExtensionRegistry; +use hugr::extension::ExtensionRegistry; use std::{io::Write, path::PathBuf}; /// Dump the standard extensions. diff --git a/hugr-cli/src/lib.rs b/hugr-cli/src/lib.rs index 8a421aa70..a75752af9 100644 --- a/hugr-cli/src/lib.rs +++ b/hugr-cli/src/lib.rs @@ -3,14 +3,17 @@ use clap::Parser; use clap_verbosity_flag::{InfoLevel, Verbosity}; use clio::Input; -use hugr_core::{Extension, Hugr}; +use derive_more::{Display, Error, From}; +use hugr::package::{PackageEncodingError, PackageValidationError}; use std::{ffi::OsString, path::PathBuf}; -use thiserror::Error; pub mod extensions; pub mod mermaid; pub mod validate; +// TODO: Deprecated re-export. Remove on a breaking release. +pub use hugr::package::Package; + /// CLI arguments. #[derive(Parser, Debug)] #[clap(version = "1.0", long_about = None)] @@ -30,18 +33,21 @@ pub enum CliArgs { } /// Error type for the CLI. -#[derive(Debug, Error)] -#[error(transparent)] +#[derive(Debug, Display, Error, From)] #[non_exhaustive] pub enum CliError { /// Error reading input. - #[error("Error reading from path: {0}")] - InputFile(#[from] std::io::Error), + #[display("Error reading from path: {_0}")] + InputFile(std::io::Error), /// Error parsing input. - #[error("Error parsing input: {0}")] - Parse(#[from] serde_json::Error), + #[display("Error parsing input: {_0}")] + Parse(serde_json::Error), + /// Error loading a package. + #[display("Error parsing package: {_0}")] + Package(PackageEncodingError), + #[display("Error validating HUGR: {_0}")] /// Errors produced by the `validate` subcommand. - Validate(#[from] validate::ValError), + Validate(PackageValidationError), } /// Validate and visualise a HUGR file. @@ -68,36 +74,10 @@ pub struct HugrArgs { pub extensions: Vec, } -#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] -/// Package of module HUGRs and extensions. -/// The HUGRs are validated against the extensions. -pub struct Package { - /// Module HUGRs included in the package. - pub modules: Vec, - /// Extensions to validate against. - pub extensions: Vec, -} - -impl Package { - /// Create a new package. - pub fn new(modules: Vec, extensions: Vec) -> Self { - Self { - modules, - extensions, - } - } -} - impl HugrArgs { /// Read either a package or a single hugr from the input. pub fn get_package(&mut self) -> Result { - let val: serde_json::Value = serde_json::from_reader(&mut self.input)?; - // read either a package or a single hugr - if let Ok(p) = serde_json::from_value::(val.clone()) { - Ok(p) - } else { - let hugr: Hugr = serde_json::from_value(val)?; - Ok(Package::new(vec![hugr], vec![])) - } + let pkg = Package::from_json_reader(&mut self.input)?; + Ok(pkg) } } diff --git a/hugr-cli/src/main.rs b/hugr-cli/src/main.rs index 87fbab41a..20f602fed 100644 --- a/hugr-cli/src/main.rs +++ b/hugr-cli/src/main.rs @@ -9,7 +9,7 @@ use clap_verbosity_flag::Level; fn main() { match CliArgs::parse() { CliArgs::Validate(args) => run_validate(args), - CliArgs::GenExtensions(args) => args.run_dump(&hugr_core::std_extensions::STD_REG), + CliArgs::GenExtensions(args) => args.run_dump(&hugr::std_extensions::STD_REG), CliArgs::Mermaid(mut args) => args.run_print().unwrap(), CliArgs::External(_) => { // TODO: Implement support for external commands. diff --git a/hugr-cli/src/mermaid.rs b/hugr-cli/src/mermaid.rs index 273d40337..0e67a1598 100644 --- a/hugr-cli/src/mermaid.rs +++ b/hugr-cli/src/mermaid.rs @@ -3,7 +3,7 @@ use std::io::Write; use clap::Parser; use clio::Output; -use hugr_core::HugrView; +use hugr::HugrView; /// Dump the standard extensions. #[derive(Parser, Debug)] diff --git a/hugr-cli/src/validate.rs b/hugr-cli/src/validate.rs index badebb8d8..b7fac3b5c 100644 --- a/hugr-cli/src/validate.rs +++ b/hugr-cli/src/validate.rs @@ -2,10 +2,10 @@ use clap::Parser; use clap_verbosity_flag::Level; -use hugr_core::{extension::ExtensionRegistry, Extension, Hugr}; -use thiserror::Error; +use hugr::package::PackageValidationError; +use hugr::{extension::ExtensionRegistry, Extension, Hugr}; -use crate::{CliError, HugrArgs, Package}; +use crate::{CliError, HugrArgs}; /// Validate and visualise a HUGR file. #[derive(Parser, Debug)] @@ -19,18 +19,6 @@ pub struct ValArgs { pub hugr_args: HugrArgs, } -/// Error type for the CLI. -#[derive(Error, Debug)] -#[non_exhaustive] -pub enum ValError { - /// Error validating HUGR. - #[error("Error validating HUGR: {0}")] - Validate(#[from] hugr_core::hugr::ValidationError), - /// Error registering extension. - #[error("Error registering extension: {0}")] - ExtReg(#[from] hugr_core::extension::ExtensionRegistryError), -} - /// String to print when validation is successful. pub const VALID_PRINT: &str = "HUGR valid!"; @@ -50,49 +38,30 @@ impl ValArgs { } } -impl Package { - /// Validate the package against an extension registry. - /// - /// `reg` is updated with any new extensions. - /// - /// Returns the validated modules. - pub fn validate(mut self, reg: &mut ExtensionRegistry) -> Result, ValError> { - // register packed extensions - for ext in self.extensions { - reg.register_updated(ext)?; - } - - for hugr in self.modules.iter_mut() { - hugr.update_validate(reg)?; - } - - Ok(self.modules) - } -} - impl HugrArgs { /// Load the package and validate against an extension registry. /// /// Returns the validated modules and the extension registry the modules /// were validated against. pub fn validate(&mut self) -> Result<(Vec, ExtensionRegistry), CliError> { - let package = self.get_package()?; + let mut package = self.get_package()?; let mut reg: ExtensionRegistry = if self.no_std { - hugr_core::extension::PRELUDE_REGISTRY.to_owned() + hugr::extension::PRELUDE_REGISTRY.to_owned() } else { - hugr_core::std_extensions::STD_REG.to_owned() + hugr::std_extensions::STD_REG.to_owned() }; // register external extensions for ext in &self.extensions { let f = std::fs::File::open(ext)?; let ext: Extension = serde_json::from_reader(f)?; - reg.register_updated(ext).map_err(ValError::ExtReg)?; + reg.register_updated(ext) + .map_err(PackageValidationError::Extension)?; } - let modules = package.validate(&mut reg)?; - Ok((modules, reg)) + package.validate(&mut reg)?; + Ok((package.modules, reg)) } /// Test whether a `level` message should be output. diff --git a/hugr-cli/tests/validate.rs b/hugr-cli/tests/validate.rs index d885ab6cd..fb80e8809 100644 --- a/hugr-cli/tests/validate.rs +++ b/hugr-cli/tests/validate.rs @@ -6,10 +6,9 @@ use assert_cmd::Command; use assert_fs::{fixture::FileWriteStr, NamedTempFile}; -use hugr_cli::{validate::VALID_PRINT, Package}; -use hugr_core::builder::DFGBuilder; -use hugr_core::types::Type; -use hugr_core::{ +use hugr::builder::DFGBuilder; +use hugr::types::Type; +use hugr::{ builder::{Container, Dataflow}, extension::prelude::{BOOL_T, QB_T}, std_extensions::arithmetic::float_types::FLOAT64_TYPE, @@ -17,6 +16,7 @@ use hugr_core::{ types::Signature, Hugr, }; +use hugr_cli::{validate::VALID_PRINT, Package}; use predicates::{prelude::*, str::contains}; use rstest::{fixture, rstest}; @@ -128,7 +128,7 @@ fn test_bad_json(mut val_cmd: Command) { val_cmd .assert() .failure() - .stderr(contains("Error parsing input")); + .stderr(contains("Error parsing package")); } #[rstest] @@ -139,7 +139,7 @@ fn test_bad_json_silent(mut val_cmd: Command) { val_cmd .assert() .failure() - .stderr(contains("Error parsing input").not()); + .stderr(contains("Error parsing package").not()); } #[rstest] @@ -188,7 +188,7 @@ fn test_float_extension(float_hugr_string: String, mut val_cmd: Command) { #[fixture] fn package_string(#[with(FLOAT64_TYPE)] test_hugr: Hugr) -> String { let rdr = std::fs::File::open(FLOAT_EXT_FILE).unwrap(); - let float_ext: hugr_core::Extension = serde_json::from_reader(rdr).unwrap(); + let float_ext: hugr::Extension = serde_json::from_reader(rdr).unwrap(); let package = Package::new(vec![test_hugr], vec![float_ext]); serde_json::to_string(&package).unwrap() } diff --git a/hugr-core/Cargo.toml b/hugr-core/Cargo.toml index f48df9974..0ff9cf535 100644 --- a/hugr-core/Cargo.toml +++ b/hugr-core/Cargo.toml @@ -37,7 +37,7 @@ serde = { workspace = true, features = ["derive", "rc"] } serde_yaml = { workspace = true, optional = true } typetag = { workspace = true } smol_str = { workspace = true, features = ["serde"] } -derive_more = { workspace = true, features = ["display", "from"] } +derive_more = { workspace = true, features = ["display", "error", "from"] } itertools = { workspace = true } html-escape = { workspace = true } bitvec = { workspace = true, features = ["serde"] } diff --git a/hugr-core/src/extension.rs b/hugr-core/src/extension.rs index 060ea97e9..9a1b30b99 100644 --- a/hugr-core/src/extension.rs +++ b/hugr-core/src/extension.rs @@ -37,7 +37,7 @@ pub use prelude::{PRELUDE, PRELUDE_REGISTRY}; pub mod declarative; /// Extension Registries store extensions to be looked up e.g. during validation. -#[derive(Clone, Debug)] +#[derive(Clone, Debug, PartialEq)] pub struct ExtensionRegistry(BTreeMap); impl ExtensionRegistry { @@ -92,6 +92,9 @@ impl ExtensionRegistry { /// If extension IDs match, the extension with the higher version is kept. /// If versions match, the original extension is kept. /// Returns a reference to the registered extension if successful. + /// + /// Avoids cloning the extension unless required. For a reference version see + /// [`ExtensionRegistry::register_updated_ref`]. pub fn register_updated( &mut self, extension: Extension, @@ -107,6 +110,30 @@ impl ExtensionRegistry { } } + /// Registers a new extension to the registry, keeping most up to date if + /// extension exists. + /// + /// If extension IDs match, the extension with the higher version is kept. + /// If versions match, the original extension is kept. Returns a reference + /// to the registered extension if successful. + /// + /// Clones the extension if required. For no-cloning version see + /// [`ExtensionRegistry::register_updated`]. + pub fn register_updated_ref( + &mut self, + extension: &Extension, + ) -> Result<&Extension, ExtensionRegistryError> { + match self.0.entry(extension.name().clone()) { + btree_map::Entry::Occupied(mut prev) => { + if prev.get().version() < extension.version() { + *prev.get_mut() = extension.clone(); + } + Ok(prev.into_mut()) + } + btree_map::Entry::Vacant(ve) => Ok(ve.insert(extension.clone())), + } + } + /// Returns the number of extensions in the registry. pub fn len(&self) -> usize { self.0.len() @@ -418,7 +445,7 @@ impl Extension { impl PartialEq for Extension { fn eq(&self, other: &Self) -> bool { - self.name == other.name + self.name == other.name && self.version == other.version } } @@ -612,7 +639,11 @@ pub mod test { #[test] fn test_register_update() { + // Two registers that should remain the same. + // We use them to test both `register_updated` and `register_updated_ref`. let mut reg = ExtensionRegistry::try_new([]).unwrap(); + let mut reg_ref = ExtensionRegistry::try_new([]).unwrap(); + let ext_1_id = ExtensionId::new("ext1").unwrap(); let ext_2_id = ExtensionId::new("ext2").unwrap(); let ext1 = Extension::new(ext_1_id.clone(), Version::new(1, 0, 0)); @@ -621,7 +652,8 @@ pub mod test { let ext2 = Extension::new(ext_2_id, Version::new(1, 0, 0)); reg.register(ext1.clone()).unwrap(); - assert_eq!(reg.get("ext1").unwrap().version(), &Version::new(1, 0, 0)); + reg_ref.register(ext1.clone()).unwrap(); + assert_eq!(®, ®_ref); // normal registration fails assert_eq!( @@ -634,12 +666,16 @@ pub mod test { ); // register with update works + reg_ref.register_updated_ref(&ext1_1).unwrap(); reg.register_updated(ext1_1.clone()).unwrap(); assert_eq!(reg.get("ext1").unwrap().version(), &Version::new(1, 1, 0)); + assert_eq!(®, ®_ref); // register with lower version does not change version + reg_ref.register_updated_ref(&ext1_2).unwrap(); reg.register_updated(ext1_2.clone()).unwrap(); assert_eq!(reg.get("ext1").unwrap().version(), &Version::new(1, 1, 0)); + assert_eq!(®, ®_ref); reg.register(ext2.clone()).unwrap(); assert_eq!(reg.get("ext2").unwrap().version(), &Version::new(1, 0, 0)); diff --git a/hugr-core/src/lib.rs b/hugr-core/src/lib.rs index f58f8ef77..879ce5744 100644 --- a/hugr-core/src/lib.rs +++ b/hugr-core/src/lib.rs @@ -17,6 +17,7 @@ pub mod hugr; pub mod import; pub mod macros; pub mod ops; +pub mod package; pub mod std_extensions; pub mod types; pub mod utils; diff --git a/hugr-core/src/package.rs b/hugr-core/src/package.rs new file mode 100644 index 000000000..bcdf4282f --- /dev/null +++ b/hugr-core/src/package.rs @@ -0,0 +1,120 @@ +//! Bundles of hugr modules along with the extension required to load them. + +use derive_more::{Display, Error, From}; +use std::path::Path; +use std::{fs, io}; + +use crate::extension::{ExtensionRegistry, ExtensionRegistryError}; +use crate::hugr::ValidationError; +use crate::{Extension, Hugr}; + +#[derive(Debug, Default, Clone, serde::Serialize, serde::Deserialize)] +/// Package of module HUGRs and extensions. +/// The HUGRs are validated against the extensions. +pub struct Package { + /// Module HUGRs included in the package. + pub modules: Vec, + /// Extensions to validate against. + pub extensions: Vec, +} + +impl Package { + /// Create a new package from a list of hugrs and extensions. + pub fn new( + modules: impl IntoIterator, + extensions: impl IntoIterator, + ) -> Self { + Self { + modules: modules.into_iter().collect(), + extensions: extensions.into_iter().collect(), + } + } + + /// Validate the package against an extension registry. + /// + /// `reg` is updated with any new extensions. + pub fn validate(&mut self, reg: &mut ExtensionRegistry) -> Result<(), PackageValidationError> { + for ext in &self.extensions { + reg.register_updated_ref(ext)?; + } + for hugr in self.modules.iter_mut() { + hugr.update_validate(reg)?; + } + Ok(()) + } + + /// Read a Package in json format from an io reader. + /// + /// If the json encodes a single [Hugr] instead, it will be inserted in a new [Package]. + pub fn from_json_reader(reader: impl io::Read) -> Result { + let val: serde_json::Value = serde_json::from_reader(reader)?; + let pkg_load_err = match serde_json::from_value::(val.clone()) { + Ok(p) => return Ok(p), + Err(e) => e, + }; + + if let Ok(hugr) = serde_json::from_value::(val) { + return Ok(Package::new([hugr], [])); + } + + // Return the original error from parsing the package. + Err(PackageEncodingError::JsonEncoding(pkg_load_err)) + } + + /// Read a Package from a json string. + /// + /// If the json encodes a single [Hugr] instead, it will be inserted in a new [Package]. + pub fn from_json(json: impl AsRef) -> Result { + Self::from_json_reader(json.as_ref().as_bytes()) + } + + /// Read a Package from a json file. + /// + /// If the json encodes a single [Hugr] instead, it will be inserted in a new [Package]. + pub fn from_json_file(path: impl AsRef) -> Result { + let file = fs::File::open(path)?; + let reader = io::BufReader::new(file); + Self::from_json_reader(reader) + } + + /// Write the Package in json format into an io writer. + pub fn to_json_writer(&self, writer: impl io::Write) -> Result<(), PackageEncodingError> { + serde_json::to_writer(writer, self)?; + Ok(()) + } + + /// Write the Package into a json string. + /// + /// If the json encodes a single [Hugr] instead, it will be inserted in a new [Package]. + pub fn to_json(&self) -> Result { + let json = serde_json::to_string(self)?; + Ok(json) + } + + /// Write the Package into a json file. + pub fn to_json_file(&self, path: impl AsRef) -> Result<(), PackageEncodingError> { + let file = fs::File::open(path)?; + let writer = io::BufWriter::new(file); + self.to_json_writer(writer) + } +} + +/// Error raised while loading a package. +#[derive(Debug, Display, Error, From)] +#[non_exhaustive] +pub enum PackageEncodingError { + /// Error raised while parsing the package json. + JsonEncoding(serde_json::Error), + /// Error raised while reading from a file. + IOError(io::Error), +} + +/// Error raised while validating a package. +#[derive(Debug, Display, Error, From)] +#[non_exhaustive] +pub enum PackageValidationError { + /// Error raised while processing the package extensions. + Extension(ExtensionRegistryError), + /// Error raised while validating the package hugrs. + Validation(ValidationError), +} diff --git a/hugr/src/lib.rs b/hugr/src/lib.rs index 1b81b98da..e4fa3ee99 100644 --- a/hugr/src/lib.rs +++ b/hugr/src/lib.rs @@ -135,7 +135,7 @@ // These modules are re-exported as-is. If more control is needed, define a new module in this crate with the desired exports. // The doc inline directive is necessary for renamed modules to appear as if they were defined in this crate. -pub use hugr_core::{builder, core, extension, ops, std_extensions, types, utils}; +pub use hugr_core::{builder, core, extension, ops, package, std_extensions, types, utils}; #[doc(inline)] pub use hugr_passes as algorithms;