From 935c61bbe4f66466fde4423f58b2a06bbe2aa536 Mon Sep 17 00:00:00 2001 From: Lukas Heidemann Date: Wed, 13 Nov 2024 10:13:40 +0000 Subject: [PATCH] feat: Export/import of JSON metadata (#1622) Metadata in `hugr-core` is attached to nodes and to `OpDef`s, consisting of a map from string names to `serde_json::Value`s. On top of that, `OpDef` also has a `description` field. This PR imports and exports node metadata by serializing it to a JSON string and wrapping that string with a `prelude.json` constructor. It also exports the metadata of `OpDef`s, in which case the description field can be exported as a string directly. This PR also introduces string escaping for the text format (#1549). By wrapping the metadata in a JSON type on the `hugr-model` side, we leave open the option to have typed metadata via the usual term system in the future. Closes #1631. --- hugr-core/src/export.rs | 63 ++++++++++++++-- hugr-core/src/extension/op_def.rs | 5 ++ hugr-core/src/import.rs | 39 ++++++++++ .../snapshots/model__roundtrip_call.snap | 11 ++- hugr-model/src/v0/text/hugr.pest | 8 +- hugr-model/src/v0/text/parse.rs | 75 +++++++++++++++---- hugr-model/src/v0/text/print.rs | 30 +++++--- hugr-model/tests/fixtures/model-call.edn | 8 +- hugr-model/tests/fixtures/model-literals.edn | 3 + .../tests/snapshots/text__literals.snap | 7 ++ hugr-model/tests/text.rs | 5 ++ 11 files changed, 216 insertions(+), 38 deletions(-) create mode 100644 hugr-model/tests/fixtures/model-literals.edn create mode 100644 hugr-model/tests/snapshots/text__literals.snap diff --git a/hugr-core/src/export.rs b/hugr-core/src/export.rs index e7a85c98f..68f3a15c0 100644 --- a/hugr-core/src/export.rs +++ b/hugr-core/src/export.rs @@ -1,7 +1,7 @@ //! Exporting HUGR graphs to their `hugr-model` representation. use crate::{ extension::{ExtensionId, ExtensionSet, OpDef, SignatureFunc}, - hugr::IdentList, + hugr::{IdentList, NodeMetadataMap}, ops::{DataflowBlock, OpName, OpTrait, OpType}, types::{ type_param::{TypeArgVariable, TypeParam}, @@ -21,6 +21,8 @@ type FxIndexSet = IndexSet; pub(crate) const OP_FUNC_CALL_INDIRECT: &str = "func.call-indirect"; const TERM_PARAM_TUPLE: &str = "param.tuple"; +const TERM_JSON: &str = "prelude.json"; +const META_DESCRIPTION: &str = "docs.description"; /// Export a [`Hugr`] graph to its representation in the model. pub fn export_hugr<'a>(hugr: &'a Hugr, bump: &'a Bump) -> model::Module<'a> { @@ -392,6 +394,11 @@ impl<'a> Context<'a> { let inputs = self.make_ports(node, Direction::Incoming, num_inputs); let outputs = self.make_ports(node, Direction::Outgoing, num_outputs); + let meta = match self.hugr.get_node_metadata(node) { + Some(metadata_map) => self.export_node_metadata(metadata_map), + None => &[], + }; + // Replace the placeholder node with the actual node. *self.module.get_node_mut(node_id).unwrap() = model::Node { operation, @@ -399,7 +406,7 @@ impl<'a> Context<'a> { outputs, params, regions, - meta: &[], // TODO: Export metadata + meta, signature, }; @@ -435,7 +442,7 @@ impl<'a> Context<'a> { outputs: &[], params: &[], regions: &[], - meta: &[], // TODO: Metadata + meta: &[], signature: None, })) } @@ -452,8 +459,29 @@ impl<'a> Context<'a> { decl }); - self.module.get_node_mut(node).unwrap().operation = - model::Operation::DeclareOperation { decl }; + let meta = { + let description = Some(opdef.description()).filter(|d| !d.is_empty()); + let meta_len = opdef.iter_misc().len() + description.is_some() as usize; + let mut meta = BumpVec::with_capacity_in(meta_len, self.bump); + + if let Some(description) = description { + let name = META_DESCRIPTION; + let value = self.make_term(model::Term::Str(self.bump.alloc_str(description))); + meta.push(model::MetaItem { name, value }) + } + + for (name, value) in opdef.iter_misc() { + let name = self.bump.alloc_str(name); + let value = self.export_json(value); + meta.push(model::MetaItem { name, value }); + } + + self.bump.alloc_slice_copy(&meta) + }; + + let node_data = self.module.get_node_mut(node).unwrap(); + node_data.operation = model::Operation::DeclareOperation { decl }; + node_data.meta = meta; model::GlobalRef::Direct(node) } @@ -843,6 +871,31 @@ impl<'a> Context<'a> { self.make_term(model::Term::ExtSet { extensions, rest }) } + + pub fn export_node_metadata( + &mut self, + metadata_map: &NodeMetadataMap, + ) -> &'a [model::MetaItem<'a>] { + let mut meta = BumpVec::with_capacity_in(metadata_map.len(), self.bump); + + for (name, value) in metadata_map { + let name = self.bump.alloc_str(name); + let value = self.export_json(value); + meta.push(model::MetaItem { name, value }); + } + + meta.into_bump_slice() + } + + pub fn export_json(&mut self, value: &serde_json::Value) -> model::TermId { + let value = serde_json::to_string(value).expect("json values are always serializable"); + let value = self.make_term(model::Term::Str(self.bump.alloc_str(&value))); + let value = self.bump.alloc_slice_copy(&[value]); + self.make_term(model::Term::ApplyFull { + global: model::GlobalRef::Named(TERM_JSON), + args: value, + }) + } } #[cfg(test)] diff --git a/hugr-core/src/extension/op_def.rs b/hugr-core/src/extension/op_def.rs index 6c1a49d9e..110da4124 100644 --- a/hugr-core/src/extension/op_def.rs +++ b/hugr-core/src/extension/op_def.rs @@ -435,6 +435,11 @@ impl OpDef { self.misc.insert(k.to_string(), v) } + /// Iterate over all miscellaneous data in the [OpDef]. + pub(crate) fn iter_misc(&self) -> impl ExactSizeIterator { + self.misc.iter().map(|(k, v)| (k.as_str(), v)) + } + /// Set the constant folding function for this Op, which can evaluate it /// given constant inputs. pub fn set_constant_folder(&mut self, fold: impl ConstFold + 'static) { diff --git a/hugr-core/src/import.rs b/hugr-core/src/import.rs index d981049fb..feabe815f 100644 --- a/hugr-core/src/import.rs +++ b/hugr-core/src/import.rs @@ -26,6 +26,8 @@ use itertools::Either; use smol_str::{SmolStr, ToSmolStr}; use thiserror::Error; +const TERM_JSON: &str = "prelude.json"; + type FxIndexMap = IndexMap; /// Error during import. @@ -184,6 +186,14 @@ impl<'a> Context<'a> { let node_data = self.get_node(node_id)?; self.record_links(node, Direction::Incoming, node_data.inputs); self.record_links(node, Direction::Outgoing, node_data.outputs); + + for meta_item in node_data.meta { + // TODO: For now we expect all metadata to be JSON since this is how + // it is handled in `hugr-core`. + let value = self.import_json_value(meta_item.value)?; + self.hugr.set_metadata(node, meta_item.name, value); + } + Ok(node) } @@ -1200,6 +1210,35 @@ impl<'a> Context<'a> { } } } + + fn import_json_value( + &mut self, + term_id: model::TermId, + ) -> Result { + let (global, args) = match self.get_term(term_id)? { + model::Term::Apply { global, args } | model::Term::ApplyFull { global, args } => { + (global, args) + } + _ => return Err(model::ModelError::TypeError(term_id).into()), + }; + + if global != &GlobalRef::Named(TERM_JSON) { + return Err(model::ModelError::TypeError(term_id).into()); + } + + let [json_arg] = args else { + return Err(model::ModelError::TypeError(term_id).into()); + }; + + let model::Term::Str(json_str) = self.get_term(*json_arg)? else { + return Err(model::ModelError::TypeError(term_id).into()); + }; + + let json_value = + serde_json::from_str(json_str).map_err(|_| model::ModelError::TypeError(term_id))?; + + Ok(json_value) + } } #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] diff --git a/hugr-core/tests/snapshots/model__roundtrip_call.snap b/hugr-core/tests/snapshots/model__roundtrip_call.snap index a799a0944..460d8f4c0 100644 --- a/hugr-core/tests/snapshots/model__roundtrip_call.snap +++ b/hugr-core/tests/snapshots/model__roundtrip_call.snap @@ -1,6 +1,6 @@ --- source: hugr-core/tests/model.rs -expression: "roundtrip(include_str!(\"fixtures/model-call.edn\"))" +expression: "roundtrip(include_str!(\"../../hugr-model/tests/fixtures/model-call.edn\"))" --- (hugr 0) @@ -8,12 +8,19 @@ expression: "roundtrip(include_str!(\"fixtures/model-call.edn\"))" (forall ?0 ext-set) [(@ arithmetic.int.types.int)] [(@ arithmetic.int.types.int)] - (ext arithmetic.int . ?0)) + (ext arithmetic.int . ?0) + (meta doc.description (@ prelude.json "\"This is a function declaration.\"")) + (meta doc.title (@ prelude.json "\"Callee\""))) (define-func example.caller [(@ arithmetic.int.types.int)] [(@ arithmetic.int.types.int)] (ext arithmetic.int) + (meta doc.description + (@ + prelude.json + "\"This defines a function that calls the function which we declared earlier.\"")) + (meta doc.title (@ prelude.json "\"Caller\"")) (dfg [%0] [%1] (signature diff --git a/hugr-model/src/v0/text/hugr.pest b/hugr-model/src/v0/text/hugr.pest index 33974a76a..132d78567 100644 --- a/hugr-model/src/v0/text/hugr.pest +++ b/hugr-model/src/v0/text/hugr.pest @@ -5,8 +5,12 @@ ext_name = @{ identifier ~ ("." ~ identifier)* } symbol = @{ identifier ~ ("." ~ identifier)+ } tag = @{ (ASCII_NONZERO_DIGIT ~ ASCII_DIGIT*) | "0" } -string = @{ "\"" ~ (!("\"") ~ ANY)* ~ "\"" } -list_tail = { "." } +string = { "\"" ~ (string_raw | string_escape | string_unicode)* ~ "\"" } +string_raw = @{ (!("\\" | "\"") ~ ANY)+ } +string_escape = @{ "\\" ~ ("\"" | "\\" | "n" | "r" | "t") } +string_unicode = @{ "\\u" ~ "{" ~ ASCII_HEX_DIGIT+ ~ "}" } + +list_tail = { "." } module = { "(" ~ "hugr" ~ "0" ~ ")" ~ meta* ~ node* ~ EOI } diff --git a/hugr-model/src/v0/text/parse.rs b/hugr-model/src/v0/text/parse.rs index b669ce38c..fa486454b 100644 --- a/hugr-model/src/v0/text/parse.rs +++ b/hugr-model/src/v0/text/parse.rs @@ -1,4 +1,4 @@ -use bumpalo::Bump; +use bumpalo::{collections::String as BumpString, Bump}; use pest::{ iterators::{Pair, Pairs}, Parser, RuleType, @@ -60,7 +60,7 @@ impl<'a> ParseContext<'a> { } fn parse_module(&mut self, pair: Pair<'a, Rule>) -> ParseResult<()> { - debug_assert!(matches!(pair.as_rule(), Rule::module)); + debug_assert_eq!(pair.as_rule(), Rule::module); let mut inner = pair.into_inner(); let meta = self.parse_meta(&mut inner)?; @@ -81,7 +81,7 @@ impl<'a> ParseContext<'a> { } fn parse_term(&mut self, pair: Pair<'a, Rule>) -> ParseResult { - debug_assert!(matches!(pair.as_rule(), Rule::term)); + debug_assert_eq!(pair.as_rule(), Rule::term); let pair = pair.into_inner().next().unwrap(); let rule = pair.as_rule(); let mut inner = pair.into_inner(); @@ -160,9 +160,7 @@ impl<'a> ParseContext<'a> { } Rule::term_str => { - // TODO: Escaping? - let value = inner.next().unwrap().as_str(); - let value = &value[1..value.len() - 1]; + let value = self.parse_string(inner.next().unwrap())?; Term::Str(value) } @@ -218,7 +216,7 @@ impl<'a> ParseContext<'a> { } fn parse_node(&mut self, pair: Pair<'a, Rule>) -> ParseResult { - debug_assert!(matches!(pair.as_rule(), Rule::node)); + debug_assert_eq!(pair.as_rule(), Rule::node); let pair = pair.into_inner().next().unwrap(); let rule = pair.as_rule(); @@ -503,7 +501,7 @@ impl<'a> ParseContext<'a> { } fn parse_region(&mut self, pair: Pair<'a, Rule>) -> ParseResult { - debug_assert!(matches!(pair.as_rule(), Rule::region)); + debug_assert_eq!(pair.as_rule(), Rule::region); let pair = pair.into_inner().next().unwrap(); let rule = pair.as_rule(); let mut inner = pair.into_inner(); @@ -541,7 +539,7 @@ impl<'a> ParseContext<'a> { } fn parse_func_header(&mut self, pair: Pair<'a, Rule>) -> ParseResult<&'a FuncDecl<'a>> { - debug_assert!(matches!(pair.as_rule(), Rule::func_header)); + debug_assert_eq!(pair.as_rule(), Rule::func_header); let mut inner = pair.into_inner(); let name = self.parse_symbol(&mut inner)?; @@ -566,7 +564,7 @@ impl<'a> ParseContext<'a> { } fn parse_alias_header(&mut self, pair: Pair<'a, Rule>) -> ParseResult<&'a AliasDecl<'a>> { - debug_assert!(matches!(pair.as_rule(), Rule::alias_header)); + debug_assert_eq!(pair.as_rule(), Rule::alias_header); let mut inner = pair.into_inner(); let name = self.parse_symbol(&mut inner)?; @@ -581,7 +579,7 @@ impl<'a> ParseContext<'a> { } fn parse_ctr_header(&mut self, pair: Pair<'a, Rule>) -> ParseResult<&'a ConstructorDecl<'a>> { - debug_assert!(matches!(pair.as_rule(), Rule::ctr_header)); + debug_assert_eq!(pair.as_rule(), Rule::ctr_header); let mut inner = pair.into_inner(); let name = self.parse_symbol(&mut inner)?; @@ -596,7 +594,7 @@ impl<'a> ParseContext<'a> { } fn parse_op_header(&mut self, pair: Pair<'a, Rule>) -> ParseResult<&'a OperationDecl<'a>> { - debug_assert!(matches!(pair.as_rule(), Rule::operation_header)); + debug_assert_eq!(pair.as_rule(), Rule::operation_header); let mut inner = pair.into_inner(); let name = self.parse_symbol(&mut inner)?; @@ -670,7 +668,7 @@ impl<'a> ParseContext<'a> { } fn parse_port(&mut self, pair: Pair<'a, Rule>) -> ParseResult> { - debug_assert!(matches!(pair.as_rule(), Rule::port)); + debug_assert_eq!(pair.as_rule(), Rule::port); let mut inner = pair.into_inner(); let link = LinkRef::Named(&inner.next().unwrap().as_str()[1..]); Ok(link) @@ -697,6 +695,47 @@ impl<'a> ParseContext<'a> { unreachable!("expected a symbol"); } } + + fn parse_string(&self, token: Pair<'a, Rule>) -> ParseResult<&'a str> { + debug_assert_eq!(token.as_rule(), Rule::string); + + // Any escape sequence is longer than the character it represents. + // Therefore the length of this token (minus 2 for the quotes on either + // side) is an upper bound for the length of the string. + let capacity = token.as_str().len() - 2; + let mut string = BumpString::with_capacity_in(capacity, self.bump); + let tokens = token.into_inner(); + + for token in tokens { + match token.as_rule() { + Rule::string_raw => string.push_str(token.as_str()), + Rule::string_escape => match token.as_str().chars().nth(1).unwrap() { + '"' => string.push('"'), + '\\' => string.push('\\'), + 'n' => string.push('\n'), + 'r' => string.push('\r'), + 't' => string.push('\t'), + _ => unreachable!(), + }, + Rule::string_unicode => { + let token_str = token.as_str(); + debug_assert_eq!(&token_str[0..3], r"\u{"); + debug_assert_eq!(&token_str[token_str.len() - 1..], "}"); + let code_str = &token_str[3..token_str.len() - 1]; + let code = u32::from_str_radix(code_str, 16).map_err(|_| { + ParseError::custom("invalid unicode escape sequence", token.as_span()) + })?; + let char = std::char::from_u32(code).ok_or_else(|| { + ParseError::custom("invalid unicode code point", token.as_span()) + })?; + string.push(char); + } + _ => unreachable!(), + } + } + + Ok(string.into_bump_str()) + } } /// Draw from a pest pair iterator only the pairs that match a given rule. @@ -750,6 +789,16 @@ impl ParseError { InputLocation::Span((offset, _)) => offset, } } + + fn custom(message: &str, span: pest::Span) -> Self { + let error = pest::error::Error::new_from_span( + pest::error::ErrorVariant::CustomError { + message: message.to_string(), + }, + span, + ); + ParseError(Box::new(error)) + } } // NOTE: `ParseError` does not implement `From>` so that diff --git a/hugr-model/src/v0/text/print.rs b/hugr-model/src/v0/text/print.rs index 494c10df2..01b9d7195 100644 --- a/hugr-model/src/v0/text/print.rs +++ b/hugr-model/src/v0/text/print.rs @@ -1,4 +1,4 @@ -use pretty::{docs, Arena, DocAllocator, RefDoc}; +use pretty::{Arena, DocAllocator, RefDoc}; use std::borrow::Cow; use crate::v0::{ @@ -668,16 +668,22 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { } /// Print a string literal. - fn print_string(&mut self, string: &'p str) { - // TODO: escape - self.docs.push( - docs![ - self.arena, - self.arena.text("\""), - self.arena.text(string), - self.arena.text("\"") - ] - .into_doc(), - ); + fn print_string(&mut self, string: &str) { + let mut output = String::with_capacity(string.len() + 2); + output.push('"'); + + for c in string.chars() { + match c { + '\\' => output.push_str("\\\\"), + '"' => output.push_str("\\\""), + '\n' => output.push_str("\\n"), + '\r' => output.push_str("\\r"), + '\t' => output.push_str("\\t"), + _ => output.push(c), + } + } + + output.push('"'); + self.print_text(output); } } diff --git a/hugr-model/tests/fixtures/model-call.edn b/hugr-model/tests/fixtures/model-call.edn index d463e391a..ce849a772 100644 --- a/hugr-model/tests/fixtures/model-call.edn +++ b/hugr-model/tests/fixtures/model-call.edn @@ -3,13 +3,13 @@ (declare-func example.callee (forall ?ext ext-set) [(@ arithmetic.int.types.int)] [(@ arithmetic.int.types.int)] (ext arithmetic.int . ?ext) - (meta doc.title "Callee") - (meta doc.description "This is a function declaration.")) + (meta doc.title (prelude.json "\"Callee\"")) + (meta doc.description (prelude.json "\"This is a function declaration.\""))) (define-func example.caller [(@ arithmetic.int.types.int)] [(@ arithmetic.int.types.int)] (ext arithmetic.int) - (meta doc.title "Caller") - (meta doc.description "This defines a function that calls the function which we declared earlier.") + (meta doc.title (prelude.json "\"Caller\"")) + (meta doc.description (prelude.json "\"This defines a function that calls the function which we declared earlier.\"")) (dfg [%3] [%4] (signature (fn [(@ arithmetic.int.types.int)] [(@ arithmetic.int.types.int)] (ext))) (call (@ example.callee (ext)) [%3] [%4] diff --git a/hugr-model/tests/fixtures/model-literals.edn b/hugr-model/tests/fixtures/model-literals.edn new file mode 100644 index 000000000..552155dda --- /dev/null +++ b/hugr-model/tests/fixtures/model-literals.edn @@ -0,0 +1,3 @@ +(hugr 0) + +(define-alias mod.string str "\"\n\r\t\\\u{1F44D}") diff --git a/hugr-model/tests/snapshots/text__literals.snap b/hugr-model/tests/snapshots/text__literals.snap new file mode 100644 index 000000000..4a639d8e7 --- /dev/null +++ b/hugr-model/tests/snapshots/text__literals.snap @@ -0,0 +1,7 @@ +--- +source: hugr-model/tests/text.rs +expression: "roundtrip(include_str!(\"fixtures/model-literals.edn\"))" +--- +(hugr 0) + +(define-alias mod.string str "\"\n\r\t\\👍") diff --git a/hugr-model/tests/text.rs b/hugr-model/tests/text.rs index ba0a7c68e..d6e07b4ff 100644 --- a/hugr-model/tests/text.rs +++ b/hugr-model/tests/text.rs @@ -11,3 +11,8 @@ fn roundtrip(source: &str) -> String { pub fn test_declarative_extensions() { insta::assert_snapshot!(roundtrip(include_str!("fixtures/model-decl-exts.edn"))) } + +#[test] +pub fn test_literals() { + insta::assert_snapshot!(roundtrip(include_str!("fixtures/model-literals.edn"))) +}