Skip to content

Commit

Permalink
feat: Export/import of JSON metadata (#1622)
Browse files Browse the repository at this point in the history
Metadata in `hugr-core` is attached to nodes and to `OpDef`s, consisting
of a map from string names to `serde_json::Value`s. On top of that,
`OpDef` also has a `description` field. This PR imports and exports node
metadata by serializing it to a JSON string and wrapping that string
with a `prelude.json` constructor. It also exports the metadata of
`OpDef`s, in which case the description field can be exported as a
string directly. This PR also introduces string escaping for the text
format (#1549).

By wrapping the metadata in a JSON type on the `hugr-model` side, we
leave open the option to have typed metadata via the usual term system
in the future.

Closes #1631.
  • Loading branch information
zrho authored Nov 13, 2024
1 parent 36bbbcf commit 935c61b
Show file tree
Hide file tree
Showing 11 changed files with 216 additions and 38 deletions.
63 changes: 58 additions & 5 deletions hugr-core/src/export.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
//! Exporting HUGR graphs to their `hugr-model` representation.
use crate::{
extension::{ExtensionId, ExtensionSet, OpDef, SignatureFunc},
hugr::IdentList,
hugr::{IdentList, NodeMetadataMap},
ops::{DataflowBlock, OpName, OpTrait, OpType},
types::{
type_param::{TypeArgVariable, TypeParam},
Expand All @@ -21,6 +21,8 @@ type FxIndexSet<T> = IndexSet<T, fxhash::FxBuildHasher>;

pub(crate) const OP_FUNC_CALL_INDIRECT: &str = "func.call-indirect";
const TERM_PARAM_TUPLE: &str = "param.tuple";
const TERM_JSON: &str = "prelude.json";
const META_DESCRIPTION: &str = "docs.description";

/// Export a [`Hugr`] graph to its representation in the model.
pub fn export_hugr<'a>(hugr: &'a Hugr, bump: &'a Bump) -> model::Module<'a> {
Expand Down Expand Up @@ -392,14 +394,19 @@ impl<'a> Context<'a> {
let inputs = self.make_ports(node, Direction::Incoming, num_inputs);
let outputs = self.make_ports(node, Direction::Outgoing, num_outputs);

let meta = match self.hugr.get_node_metadata(node) {
Some(metadata_map) => self.export_node_metadata(metadata_map),
None => &[],
};

// Replace the placeholder node with the actual node.
*self.module.get_node_mut(node_id).unwrap() = model::Node {
operation,
inputs,
outputs,
params,
regions,
meta: &[], // TODO: Export metadata
meta,
signature,
};

Expand Down Expand Up @@ -435,7 +442,7 @@ impl<'a> Context<'a> {
outputs: &[],
params: &[],
regions: &[],
meta: &[], // TODO: Metadata
meta: &[],
signature: None,
}))
}
Expand All @@ -452,8 +459,29 @@ impl<'a> Context<'a> {
decl
});

self.module.get_node_mut(node).unwrap().operation =
model::Operation::DeclareOperation { decl };
let meta = {
let description = Some(opdef.description()).filter(|d| !d.is_empty());
let meta_len = opdef.iter_misc().len() + description.is_some() as usize;
let mut meta = BumpVec::with_capacity_in(meta_len, self.bump);

if let Some(description) = description {
let name = META_DESCRIPTION;
let value = self.make_term(model::Term::Str(self.bump.alloc_str(description)));
meta.push(model::MetaItem { name, value })
}

for (name, value) in opdef.iter_misc() {
let name = self.bump.alloc_str(name);
let value = self.export_json(value);
meta.push(model::MetaItem { name, value });
}

self.bump.alloc_slice_copy(&meta)
};

let node_data = self.module.get_node_mut(node).unwrap();
node_data.operation = model::Operation::DeclareOperation { decl };
node_data.meta = meta;

model::GlobalRef::Direct(node)
}
Expand Down Expand Up @@ -843,6 +871,31 @@ impl<'a> Context<'a> {

self.make_term(model::Term::ExtSet { extensions, rest })
}

pub fn export_node_metadata(
&mut self,
metadata_map: &NodeMetadataMap,
) -> &'a [model::MetaItem<'a>] {
let mut meta = BumpVec::with_capacity_in(metadata_map.len(), self.bump);

for (name, value) in metadata_map {
let name = self.bump.alloc_str(name);
let value = self.export_json(value);
meta.push(model::MetaItem { name, value });
}

meta.into_bump_slice()
}

pub fn export_json(&mut self, value: &serde_json::Value) -> model::TermId {
let value = serde_json::to_string(value).expect("json values are always serializable");
let value = self.make_term(model::Term::Str(self.bump.alloc_str(&value)));
let value = self.bump.alloc_slice_copy(&[value]);
self.make_term(model::Term::ApplyFull {
global: model::GlobalRef::Named(TERM_JSON),
args: value,
})
}
}

#[cfg(test)]
Expand Down
5 changes: 5 additions & 0 deletions hugr-core/src/extension/op_def.rs
Original file line number Diff line number Diff line change
Expand Up @@ -435,6 +435,11 @@ impl OpDef {
self.misc.insert(k.to_string(), v)
}

/// Iterate over all miscellaneous data in the [OpDef].
pub(crate) fn iter_misc(&self) -> impl ExactSizeIterator<Item = (&str, &serde_json::Value)> {
self.misc.iter().map(|(k, v)| (k.as_str(), v))
}

/// Set the constant folding function for this Op, which can evaluate it
/// given constant inputs.
pub fn set_constant_folder(&mut self, fold: impl ConstFold + 'static) {
Expand Down
39 changes: 39 additions & 0 deletions hugr-core/src/import.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ use itertools::Either;
use smol_str::{SmolStr, ToSmolStr};
use thiserror::Error;

const TERM_JSON: &str = "prelude.json";

type FxIndexMap<K, V> = IndexMap<K, V, fxhash::FxBuildHasher>;

/// Error during import.
Expand Down Expand Up @@ -184,6 +186,14 @@ impl<'a> Context<'a> {
let node_data = self.get_node(node_id)?;
self.record_links(node, Direction::Incoming, node_data.inputs);
self.record_links(node, Direction::Outgoing, node_data.outputs);

for meta_item in node_data.meta {
// TODO: For now we expect all metadata to be JSON since this is how
// it is handled in `hugr-core`.
let value = self.import_json_value(meta_item.value)?;
self.hugr.set_metadata(node, meta_item.name, value);
}

Ok(node)
}

Expand Down Expand Up @@ -1200,6 +1210,35 @@ impl<'a> Context<'a> {
}
}
}

fn import_json_value(
&mut self,
term_id: model::TermId,
) -> Result<serde_json::Value, ImportError> {
let (global, args) = match self.get_term(term_id)? {
model::Term::Apply { global, args } | model::Term::ApplyFull { global, args } => {
(global, args)
}
_ => return Err(model::ModelError::TypeError(term_id).into()),
};

if global != &GlobalRef::Named(TERM_JSON) {
return Err(model::ModelError::TypeError(term_id).into());
}

let [json_arg] = args else {
return Err(model::ModelError::TypeError(term_id).into());
};

let model::Term::Str(json_str) = self.get_term(*json_arg)? else {
return Err(model::ModelError::TypeError(term_id).into());
};

let json_value =
serde_json::from_str(json_str).map_err(|_| model::ModelError::TypeError(term_id))?;

Ok(json_value)
}
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
Expand Down
11 changes: 9 additions & 2 deletions hugr-core/tests/snapshots/model__roundtrip_call.snap
Original file line number Diff line number Diff line change
@@ -1,19 +1,26 @@
---
source: hugr-core/tests/model.rs
expression: "roundtrip(include_str!(\"fixtures/model-call.edn\"))"
expression: "roundtrip(include_str!(\"../../hugr-model/tests/fixtures/model-call.edn\"))"
---
(hugr 0)

(declare-func example.callee
(forall ?0 ext-set)
[(@ arithmetic.int.types.int)]
[(@ arithmetic.int.types.int)]
(ext arithmetic.int . ?0))
(ext arithmetic.int . ?0)
(meta doc.description (@ prelude.json "\"This is a function declaration.\""))
(meta doc.title (@ prelude.json "\"Callee\"")))

(define-func example.caller
[(@ arithmetic.int.types.int)]
[(@ arithmetic.int.types.int)]
(ext arithmetic.int)
(meta doc.description
(@
prelude.json
"\"This defines a function that calls the function which we declared earlier.\""))
(meta doc.title (@ prelude.json "\"Caller\""))
(dfg
[%0] [%1]
(signature
Expand Down
8 changes: 6 additions & 2 deletions hugr-model/src/v0/text/hugr.pest
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,12 @@ ext_name = @{ identifier ~ ("." ~ identifier)* }
symbol = @{ identifier ~ ("." ~ identifier)+ }
tag = @{ (ASCII_NONZERO_DIGIT ~ ASCII_DIGIT*) | "0" }

string = @{ "\"" ~ (!("\"") ~ ANY)* ~ "\"" }
list_tail = { "." }
string = { "\"" ~ (string_raw | string_escape | string_unicode)* ~ "\"" }
string_raw = @{ (!("\\" | "\"") ~ ANY)+ }
string_escape = @{ "\\" ~ ("\"" | "\\" | "n" | "r" | "t") }
string_unicode = @{ "\\u" ~ "{" ~ ASCII_HEX_DIGIT+ ~ "}" }

list_tail = { "." }

module = { "(" ~ "hugr" ~ "0" ~ ")" ~ meta* ~ node* ~ EOI }

Expand Down
75 changes: 62 additions & 13 deletions hugr-model/src/v0/text/parse.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use bumpalo::Bump;
use bumpalo::{collections::String as BumpString, Bump};
use pest::{
iterators::{Pair, Pairs},
Parser, RuleType,
Expand Down Expand Up @@ -60,7 +60,7 @@ impl<'a> ParseContext<'a> {
}

fn parse_module(&mut self, pair: Pair<'a, Rule>) -> ParseResult<()> {
debug_assert!(matches!(pair.as_rule(), Rule::module));
debug_assert_eq!(pair.as_rule(), Rule::module);
let mut inner = pair.into_inner();
let meta = self.parse_meta(&mut inner)?;

Expand All @@ -81,7 +81,7 @@ impl<'a> ParseContext<'a> {
}

fn parse_term(&mut self, pair: Pair<'a, Rule>) -> ParseResult<TermId> {
debug_assert!(matches!(pair.as_rule(), Rule::term));
debug_assert_eq!(pair.as_rule(), Rule::term);
let pair = pair.into_inner().next().unwrap();
let rule = pair.as_rule();
let mut inner = pair.into_inner();
Expand Down Expand Up @@ -160,9 +160,7 @@ impl<'a> ParseContext<'a> {
}

Rule::term_str => {
// TODO: Escaping?
let value = inner.next().unwrap().as_str();
let value = &value[1..value.len() - 1];
let value = self.parse_string(inner.next().unwrap())?;
Term::Str(value)
}

Expand Down Expand Up @@ -218,7 +216,7 @@ impl<'a> ParseContext<'a> {
}

fn parse_node(&mut self, pair: Pair<'a, Rule>) -> ParseResult<NodeId> {
debug_assert!(matches!(pair.as_rule(), Rule::node));
debug_assert_eq!(pair.as_rule(), Rule::node);
let pair = pair.into_inner().next().unwrap();
let rule = pair.as_rule();

Expand Down Expand Up @@ -503,7 +501,7 @@ impl<'a> ParseContext<'a> {
}

fn parse_region(&mut self, pair: Pair<'a, Rule>) -> ParseResult<RegionId> {
debug_assert!(matches!(pair.as_rule(), Rule::region));
debug_assert_eq!(pair.as_rule(), Rule::region);
let pair = pair.into_inner().next().unwrap();
let rule = pair.as_rule();
let mut inner = pair.into_inner();
Expand Down Expand Up @@ -541,7 +539,7 @@ impl<'a> ParseContext<'a> {
}

fn parse_func_header(&mut self, pair: Pair<'a, Rule>) -> ParseResult<&'a FuncDecl<'a>> {
debug_assert!(matches!(pair.as_rule(), Rule::func_header));
debug_assert_eq!(pair.as_rule(), Rule::func_header);

let mut inner = pair.into_inner();
let name = self.parse_symbol(&mut inner)?;
Expand All @@ -566,7 +564,7 @@ impl<'a> ParseContext<'a> {
}

fn parse_alias_header(&mut self, pair: Pair<'a, Rule>) -> ParseResult<&'a AliasDecl<'a>> {
debug_assert!(matches!(pair.as_rule(), Rule::alias_header));
debug_assert_eq!(pair.as_rule(), Rule::alias_header);

let mut inner = pair.into_inner();
let name = self.parse_symbol(&mut inner)?;
Expand All @@ -581,7 +579,7 @@ impl<'a> ParseContext<'a> {
}

fn parse_ctr_header(&mut self, pair: Pair<'a, Rule>) -> ParseResult<&'a ConstructorDecl<'a>> {
debug_assert!(matches!(pair.as_rule(), Rule::ctr_header));
debug_assert_eq!(pair.as_rule(), Rule::ctr_header);

let mut inner = pair.into_inner();
let name = self.parse_symbol(&mut inner)?;
Expand All @@ -596,7 +594,7 @@ impl<'a> ParseContext<'a> {
}

fn parse_op_header(&mut self, pair: Pair<'a, Rule>) -> ParseResult<&'a OperationDecl<'a>> {
debug_assert!(matches!(pair.as_rule(), Rule::operation_header));
debug_assert_eq!(pair.as_rule(), Rule::operation_header);

let mut inner = pair.into_inner();
let name = self.parse_symbol(&mut inner)?;
Expand Down Expand Up @@ -670,7 +668,7 @@ impl<'a> ParseContext<'a> {
}

fn parse_port(&mut self, pair: Pair<'a, Rule>) -> ParseResult<LinkRef<'a>> {
debug_assert!(matches!(pair.as_rule(), Rule::port));
debug_assert_eq!(pair.as_rule(), Rule::port);
let mut inner = pair.into_inner();
let link = LinkRef::Named(&inner.next().unwrap().as_str()[1..]);
Ok(link)
Expand All @@ -697,6 +695,47 @@ impl<'a> ParseContext<'a> {
unreachable!("expected a symbol");
}
}

fn parse_string(&self, token: Pair<'a, Rule>) -> ParseResult<&'a str> {
debug_assert_eq!(token.as_rule(), Rule::string);

// Any escape sequence is longer than the character it represents.
// Therefore the length of this token (minus 2 for the quotes on either
// side) is an upper bound for the length of the string.
let capacity = token.as_str().len() - 2;
let mut string = BumpString::with_capacity_in(capacity, self.bump);
let tokens = token.into_inner();

for token in tokens {
match token.as_rule() {
Rule::string_raw => string.push_str(token.as_str()),
Rule::string_escape => match token.as_str().chars().nth(1).unwrap() {
'"' => string.push('"'),
'\\' => string.push('\\'),
'n' => string.push('\n'),
'r' => string.push('\r'),
't' => string.push('\t'),
_ => unreachable!(),
},
Rule::string_unicode => {
let token_str = token.as_str();
debug_assert_eq!(&token_str[0..3], r"\u{");
debug_assert_eq!(&token_str[token_str.len() - 1..], "}");
let code_str = &token_str[3..token_str.len() - 1];
let code = u32::from_str_radix(code_str, 16).map_err(|_| {
ParseError::custom("invalid unicode escape sequence", token.as_span())
})?;
let char = std::char::from_u32(code).ok_or_else(|| {
ParseError::custom("invalid unicode code point", token.as_span())
})?;
string.push(char);
}
_ => unreachable!(),
}
}

Ok(string.into_bump_str())
}
}

/// Draw from a pest pair iterator only the pairs that match a given rule.
Expand Down Expand Up @@ -750,6 +789,16 @@ impl ParseError {
InputLocation::Span((offset, _)) => offset,
}
}

fn custom(message: &str, span: pest::Span) -> Self {
let error = pest::error::Error::new_from_span(
pest::error::ErrorVariant::CustomError {
message: message.to_string(),
},
span,
);
ParseError(Box::new(error))
}
}

// NOTE: `ParseError` does not implement `From<pest::error::Error<Rule>>` so that
Expand Down
Loading

0 comments on commit 935c61b

Please sign in to comment.