Skip to content

Commit

Permalink
feat: Add Package definition on hugr-core (#1587)
Browse files Browse the repository at this point in the history
Moves the definition of packages out of `hugr-cli`. (We leave a
re-export there to avoid breaking the API, it should be removed on the
next breaking release).

Adds some helper functions to load and validate the package.

Closes #1530 

drive-by: Replace `hugr-core` dependency on `hugr-cli` by `hugr`.
There's no need to access the internals there.
  • Loading branch information
aborgna-q authored Oct 17, 2024
1 parent a66c4b9 commit d899bd3
Show file tree
Hide file tree
Showing 12 changed files with 201 additions and 94 deletions.
3 changes: 2 additions & 1 deletion hugr-cli/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ categories = ["compilers"]
[dependencies]
clap = { workspace = true, features = ["derive"] }
clap-verbosity-flag.workspace = true
hugr-core = { path = "../hugr-core", version = "0.13.1" }
derive_more = { workspace = true, features = ["display", "error", "from"] }
hugr = { path = "../hugr", version = "0.13.1" }
serde_json.workspace = true
serde.workspace = true
thiserror.workspace = true
Expand Down
2 changes: 1 addition & 1 deletion hugr-cli/src/extensions.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
//! Dump standard extensions in serialized form.
use clap::Parser;
use hugr_core::extension::ExtensionRegistry;
use hugr::extension::ExtensionRegistry;
use std::{io::Write, path::PathBuf};

/// Dump the standard extensions.
Expand Down
54 changes: 17 additions & 37 deletions hugr-cli/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,17 @@
use clap::Parser;
use clap_verbosity_flag::{InfoLevel, Verbosity};
use clio::Input;
use hugr_core::{Extension, Hugr};
use derive_more::{Display, Error, From};
use hugr::package::{PackageEncodingError, PackageValidationError};
use std::{ffi::OsString, path::PathBuf};
use thiserror::Error;

pub mod extensions;
pub mod mermaid;
pub mod validate;

// TODO: Deprecated re-export. Remove on a breaking release.
pub use hugr::package::Package;

/// CLI arguments.
#[derive(Parser, Debug)]
#[clap(version = "1.0", long_about = None)]
Expand All @@ -30,18 +33,21 @@ pub enum CliArgs {
}

/// Error type for the CLI.
#[derive(Debug, Error)]
#[error(transparent)]
#[derive(Debug, Display, Error, From)]
#[non_exhaustive]
pub enum CliError {
/// Error reading input.
#[error("Error reading from path: {0}")]
InputFile(#[from] std::io::Error),
#[display("Error reading from path: {_0}")]
InputFile(std::io::Error),
/// Error parsing input.
#[error("Error parsing input: {0}")]
Parse(#[from] serde_json::Error),
#[display("Error parsing input: {_0}")]
Parse(serde_json::Error),
/// Error loading a package.
#[display("Error parsing package: {_0}")]
Package(PackageEncodingError),
#[display("Error validating HUGR: {_0}")]
/// Errors produced by the `validate` subcommand.
Validate(#[from] validate::ValError),
Validate(PackageValidationError),
}

/// Validate and visualise a HUGR file.
Expand All @@ -68,36 +74,10 @@ pub struct HugrArgs {
pub extensions: Vec<PathBuf>,
}

#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
/// Package of module HUGRs and extensions.
/// The HUGRs are validated against the extensions.
pub struct Package {
/// Module HUGRs included in the package.
pub modules: Vec<Hugr>,
/// Extensions to validate against.
pub extensions: Vec<Extension>,
}

impl Package {
/// Create a new package.
pub fn new(modules: Vec<Hugr>, extensions: Vec<Extension>) -> Self {
Self {
modules,
extensions,
}
}
}

impl HugrArgs {
/// Read either a package or a single hugr from the input.
pub fn get_package(&mut self) -> Result<Package, CliError> {
let val: serde_json::Value = serde_json::from_reader(&mut self.input)?;
// read either a package or a single hugr
if let Ok(p) = serde_json::from_value::<Package>(val.clone()) {
Ok(p)
} else {
let hugr: Hugr = serde_json::from_value(val)?;
Ok(Package::new(vec![hugr], vec![]))
}
let pkg = Package::from_json_reader(&mut self.input)?;
Ok(pkg)
}
}
2 changes: 1 addition & 1 deletion hugr-cli/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use clap_verbosity_flag::Level;
fn main() {
match CliArgs::parse() {
CliArgs::Validate(args) => run_validate(args),
CliArgs::GenExtensions(args) => args.run_dump(&hugr_core::std_extensions::STD_REG),
CliArgs::GenExtensions(args) => args.run_dump(&hugr::std_extensions::STD_REG),
CliArgs::Mermaid(mut args) => args.run_print().unwrap(),
CliArgs::External(_) => {
// TODO: Implement support for external commands.
Expand Down
2 changes: 1 addition & 1 deletion hugr-cli/src/mermaid.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::io::Write;

use clap::Parser;
use clio::Output;
use hugr_core::HugrView;
use hugr::HugrView;

/// Dump the standard extensions.
#[derive(Parser, Debug)]
Expand Down
51 changes: 10 additions & 41 deletions hugr-cli/src/validate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
use clap::Parser;
use clap_verbosity_flag::Level;
use hugr_core::{extension::ExtensionRegistry, Extension, Hugr};
use thiserror::Error;
use hugr::package::PackageValidationError;
use hugr::{extension::ExtensionRegistry, Extension, Hugr};

use crate::{CliError, HugrArgs, Package};
use crate::{CliError, HugrArgs};

/// Validate and visualise a HUGR file.
#[derive(Parser, Debug)]
Expand All @@ -19,18 +19,6 @@ pub struct ValArgs {
pub hugr_args: HugrArgs,
}

/// Error type for the CLI.
#[derive(Error, Debug)]
#[non_exhaustive]
pub enum ValError {
/// Error validating HUGR.
#[error("Error validating HUGR: {0}")]
Validate(#[from] hugr_core::hugr::ValidationError),
/// Error registering extension.
#[error("Error registering extension: {0}")]
ExtReg(#[from] hugr_core::extension::ExtensionRegistryError),
}

/// String to print when validation is successful.
pub const VALID_PRINT: &str = "HUGR valid!";

Expand All @@ -50,49 +38,30 @@ impl ValArgs {
}
}

impl Package {
/// Validate the package against an extension registry.
///
/// `reg` is updated with any new extensions.
///
/// Returns the validated modules.
pub fn validate(mut self, reg: &mut ExtensionRegistry) -> Result<Vec<Hugr>, ValError> {
// register packed extensions
for ext in self.extensions {
reg.register_updated(ext)?;
}

for hugr in self.modules.iter_mut() {
hugr.update_validate(reg)?;
}

Ok(self.modules)
}
}

impl HugrArgs {
/// Load the package and validate against an extension registry.
///
/// Returns the validated modules and the extension registry the modules
/// were validated against.
pub fn validate(&mut self) -> Result<(Vec<Hugr>, ExtensionRegistry), CliError> {
let package = self.get_package()?;
let mut package = self.get_package()?;

let mut reg: ExtensionRegistry = if self.no_std {
hugr_core::extension::PRELUDE_REGISTRY.to_owned()
hugr::extension::PRELUDE_REGISTRY.to_owned()
} else {
hugr_core::std_extensions::STD_REG.to_owned()
hugr::std_extensions::STD_REG.to_owned()
};

// register external extensions
for ext in &self.extensions {
let f = std::fs::File::open(ext)?;
let ext: Extension = serde_json::from_reader(f)?;
reg.register_updated(ext).map_err(ValError::ExtReg)?;
reg.register_updated(ext)
.map_err(PackageValidationError::Extension)?;
}

let modules = package.validate(&mut reg)?;
Ok((modules, reg))
package.validate(&mut reg)?;
Ok((package.modules, reg))
}

/// Test whether a `level` message should be output.
Expand Down
14 changes: 7 additions & 7 deletions hugr-cli/tests/validate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,17 @@

use assert_cmd::Command;
use assert_fs::{fixture::FileWriteStr, NamedTempFile};
use hugr_cli::{validate::VALID_PRINT, Package};
use hugr_core::builder::DFGBuilder;
use hugr_core::types::Type;
use hugr_core::{
use hugr::builder::DFGBuilder;
use hugr::types::Type;
use hugr::{
builder::{Container, Dataflow},
extension::prelude::{BOOL_T, QB_T},
std_extensions::arithmetic::float_types::FLOAT64_TYPE,
type_row,
types::Signature,
Hugr,
};
use hugr_cli::{validate::VALID_PRINT, Package};
use predicates::{prelude::*, str::contains};
use rstest::{fixture, rstest};

Expand Down Expand Up @@ -128,7 +128,7 @@ fn test_bad_json(mut val_cmd: Command) {
val_cmd
.assert()
.failure()
.stderr(contains("Error parsing input"));
.stderr(contains("Error parsing package"));
}

#[rstest]
Expand All @@ -139,7 +139,7 @@ fn test_bad_json_silent(mut val_cmd: Command) {
val_cmd
.assert()
.failure()
.stderr(contains("Error parsing input").not());
.stderr(contains("Error parsing package").not());
}

#[rstest]
Expand Down Expand Up @@ -188,7 +188,7 @@ fn test_float_extension(float_hugr_string: String, mut val_cmd: Command) {
#[fixture]
fn package_string(#[with(FLOAT64_TYPE)] test_hugr: Hugr) -> String {
let rdr = std::fs::File::open(FLOAT_EXT_FILE).unwrap();
let float_ext: hugr_core::Extension = serde_json::from_reader(rdr).unwrap();
let float_ext: hugr::Extension = serde_json::from_reader(rdr).unwrap();
let package = Package::new(vec![test_hugr], vec![float_ext]);
serde_json::to_string(&package).unwrap()
}
Expand Down
2 changes: 1 addition & 1 deletion hugr-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ serde = { workspace = true, features = ["derive", "rc"] }
serde_yaml = { workspace = true, optional = true }
typetag = { workspace = true }
smol_str = { workspace = true, features = ["serde"] }
derive_more = { workspace = true, features = ["display", "from"] }
derive_more = { workspace = true, features = ["display", "error", "from"] }
itertools = { workspace = true }
html-escape = { workspace = true }
bitvec = { workspace = true, features = ["serde"] }
Expand Down
42 changes: 39 additions & 3 deletions hugr-core/src/extension.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ pub use prelude::{PRELUDE, PRELUDE_REGISTRY};
pub mod declarative;

/// Extension Registries store extensions to be looked up e.g. during validation.
#[derive(Clone, Debug)]
#[derive(Clone, Debug, PartialEq)]
pub struct ExtensionRegistry(BTreeMap<ExtensionId, Extension>);

impl ExtensionRegistry {
Expand Down Expand Up @@ -92,6 +92,9 @@ impl ExtensionRegistry {
/// If extension IDs match, the extension with the higher version is kept.
/// If versions match, the original extension is kept.
/// Returns a reference to the registered extension if successful.
///
/// Avoids cloning the extension unless required. For a reference version see
/// [`ExtensionRegistry::register_updated_ref`].
pub fn register_updated(
&mut self,
extension: Extension,
Expand All @@ -107,6 +110,30 @@ impl ExtensionRegistry {
}
}

/// Registers a new extension to the registry, keeping most up to date if
/// extension exists.
///
/// If extension IDs match, the extension with the higher version is kept.
/// If versions match, the original extension is kept. Returns a reference
/// to the registered extension if successful.
///
/// Clones the extension if required. For no-cloning version see
/// [`ExtensionRegistry::register_updated`].
pub fn register_updated_ref(
&mut self,
extension: &Extension,
) -> Result<&Extension, ExtensionRegistryError> {
match self.0.entry(extension.name().clone()) {
btree_map::Entry::Occupied(mut prev) => {
if prev.get().version() < extension.version() {
*prev.get_mut() = extension.clone();
}
Ok(prev.into_mut())
}
btree_map::Entry::Vacant(ve) => Ok(ve.insert(extension.clone())),
}
}

/// Returns the number of extensions in the registry.
pub fn len(&self) -> usize {
self.0.len()
Expand Down Expand Up @@ -418,7 +445,7 @@ impl Extension {

impl PartialEq for Extension {
fn eq(&self, other: &Self) -> bool {
self.name == other.name
self.name == other.name && self.version == other.version
}
}

Expand Down Expand Up @@ -612,7 +639,11 @@ pub mod test {

#[test]
fn test_register_update() {
// Two registers that should remain the same.
// We use them to test both `register_updated` and `register_updated_ref`.
let mut reg = ExtensionRegistry::try_new([]).unwrap();
let mut reg_ref = ExtensionRegistry::try_new([]).unwrap();

let ext_1_id = ExtensionId::new("ext1").unwrap();
let ext_2_id = ExtensionId::new("ext2").unwrap();
let ext1 = Extension::new(ext_1_id.clone(), Version::new(1, 0, 0));
Expand All @@ -621,7 +652,8 @@ pub mod test {
let ext2 = Extension::new(ext_2_id, Version::new(1, 0, 0));

reg.register(ext1.clone()).unwrap();
assert_eq!(reg.get("ext1").unwrap().version(), &Version::new(1, 0, 0));
reg_ref.register(ext1.clone()).unwrap();
assert_eq!(&reg, &reg_ref);

// normal registration fails
assert_eq!(
Expand All @@ -634,12 +666,16 @@ pub mod test {
);

// register with update works
reg_ref.register_updated_ref(&ext1_1).unwrap();
reg.register_updated(ext1_1.clone()).unwrap();
assert_eq!(reg.get("ext1").unwrap().version(), &Version::new(1, 1, 0));
assert_eq!(&reg, &reg_ref);

// register with lower version does not change version
reg_ref.register_updated_ref(&ext1_2).unwrap();
reg.register_updated(ext1_2.clone()).unwrap();
assert_eq!(reg.get("ext1").unwrap().version(), &Version::new(1, 1, 0));
assert_eq!(&reg, &reg_ref);

reg.register(ext2.clone()).unwrap();
assert_eq!(reg.get("ext2").unwrap().version(), &Version::new(1, 0, 0));
Expand Down
1 change: 1 addition & 0 deletions hugr-core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ pub mod hugr;
pub mod import;
pub mod macros;
pub mod ops;
pub mod package;
pub mod std_extensions;
pub mod types;
pub mod utils;
Expand Down
Loading

0 comments on commit d899bd3

Please sign in to comment.