diff --git a/pbjson-types/Cargo.toml b/pbjson-types/Cargo.toml index f9069d8..485af0f 100644 --- a/pbjson-types/Cargo.toml +++ b/pbjson-types/Cargo.toml @@ -10,6 +10,7 @@ categories = ["encoding"] repository = "https://github.com/influxdata/pbjson" [dependencies] # In alphabetical order +base64 = "0.13.0" bytes = "1.0" chrono = "0.4" pbjson = { path = "../pbjson", version = "0.2" } @@ -17,6 +18,7 @@ prost = "0.9" serde = { version = "1.0", features = ["derive"] } [dev-dependencies] +pretty_assertions = "1.0.0" serde_json = "1.0" [build-dependencies] # In alphabetical order diff --git a/pbjson-types/build.rs b/pbjson-types/build.rs index 748f403..2aa0063 100644 --- a/pbjson-types/build.rs +++ b/pbjson-types/build.rs @@ -29,6 +29,7 @@ fn main() -> Result<()> { pbjson_build::Builder::new() .register_descriptors(&descriptor_set)? .exclude([ + ".google.protobuf.Any", ".google.protobuf.Duration", ".google.protobuf.Timestamp", ".google.protobuf.Value", diff --git a/pbjson-types/src/any.rs b/pbjson-types/src/any.rs new file mode 100644 index 0000000..0dbb345 --- /dev/null +++ b/pbjson-types/src/any.rs @@ -0,0 +1,247 @@ +const TYPE_FIELD: &str = "@type"; +const VALUE_FIELD: &str = "value"; + +impl serde::Serialize for crate::Any { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + use serde::ser::{SerializeMap, SerializeStruct}; + + const NAME: &str = "google.protobuf.Any"; + + let value = ::decode(self.value.clone()).map_err(|_| { + serde::ser::Error::custom( + "Couldn't transcode google.protobuf.Any value into google.protobuf.Value", + ) + })?; + + let mut field_length = 1; + match value.kind { + Some(crate::value::Kind::StructValue(map)) => { + field_length += map.len(); + let mut map_ser = serializer.serialize_map(Some(field_length))?; + map_ser.serialize_entry(TYPE_FIELD, &self.type_url)?; + + for (k, v) in &map { + map_ser.serialize_entry(k, v)?; + } + + map_ser.end() + } + _ => { + let mut struct_ser = serializer.serialize_struct(NAME, 2)?; + struct_ser.serialize_field(TYPE_FIELD, &self.type_url)?; + struct_ser.serialize_field(VALUE_FIELD, &base64::encode(&self.value))?; + struct_ser.end() + } + } + } +} + +impl<'de> serde::Deserialize<'de> for crate::Any { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + deserializer.deserialize_map(AnyVisitor) + } +} + +struct AnyVisitor; + +impl<'de> serde::de::Visitor<'de> for AnyVisitor { + type Value = crate::Any; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct google.protobuf.Any") + } + + fn visit_map(self, mut map_access: V) -> Result + where + V: serde::de::MapAccess<'de>, + { + enum BytesOrValue { + Bytes(bytes::Bytes), + Value(crate::Value), + } + + let mut type_url = None; + let mut value: Option = None; + let mut map = std::collections::HashMap::new(); + while let Some(k) = map_access.next_key()? { + match k { + AnyField::TypeUrl => { + if type_url.is_some() { + return Err(serde::de::Error::duplicate_field(TYPE_FIELD)); + } + type_url = Some(map_access.next_value()?); + } + AnyField::Value => { + if value.is_some() { + return Err(serde::de::Error::duplicate_field(VALUE_FIELD)); + } + + value = if let Ok(bytes) = + map_access.next_value::<::pbjson::private::BytesDeserialize<_>>() + { + Some(BytesOrValue::Bytes(bytes.0)) + } else { + Some(BytesOrValue::Value(map_access.next_value()?)) + }; + } + AnyField::Unknown(key) => { + if map.contains_key(&key) { + return Err(serde::de::Error::custom(format!( + "Duplicate field: {}", + &key + ))); + } + + map.insert(key, map_access.next_value::()?); + } + } + } + + macro_rules! encode_map { + () => {{ + use prost::Message; + + let mut buffer = Vec::new(); + crate::Value::from(map) + .encode(&mut buffer) + .map_err(serde::de::Error::custom)?; + buffer.into() + }}; + } + + let value = match value { + Some(BytesOrValue::Bytes(bytes)) => bytes, + Some(BytesOrValue::Value(value)) => { + map.insert("value".into(), value); + encode_map!() + } + None if map.is_empty() => return Err(serde::de::Error::missing_field(VALUE_FIELD)), + None => encode_map!(), + }; + + Ok(crate::Any { + type_url: type_url.ok_or_else(|| serde::de::Error::missing_field(TYPE_FIELD))?, + value, + }) + } +} + +#[derive(Debug)] +enum AnyField { + TypeUrl, + Value, + Unknown(String), +} + +impl<'de> serde::Deserialize<'de> for AnyField { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + struct AnyFieldVisitor; + + impl<'de> serde::de::Visitor<'de> for AnyFieldVisitor { + type Value = AnyField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected a string key") + } + + fn visit_str(self, value: &str) -> Result + where + E: serde::de::Error, + { + match value { + TYPE_FIELD => Ok(AnyField::TypeUrl), + VALUE_FIELD => Ok(AnyField::Value), + value => Ok(AnyField::Unknown(value.to_owned())), + } + } + } + + deserializer.deserialize_identifier(AnyFieldVisitor) + } +} + +#[cfg(test)] +mod tests { + use pretty_assertions::assert_eq; + use prost::Message; + + #[test] + fn object() { + let map = crate::Value::from(std::collections::HashMap::from([ + (String::from("bool"), true.into()), + (String::from("unit"), crate::Value::from(None)), + (String::from("number"), 5.0.into()), + (String::from("string"), "string".into()), + (String::from("list"), vec![1.0.into(), 2.0.into()].into()), + ( + String::from("map"), + std::collections::HashMap::from([(String::from("key"), "value".into())]).into(), + ), + ])); + + let any = crate::Any { + type_url: "google.protobuf.Value".into(), + value: map.encode_to_vec().into(), + }; + + let json = serde_json::to_value(&any).unwrap(); + + assert_eq!( + json, + serde_json::json!({ + "@type": "google.protobuf.Value", + "bool": true, + "unit": null, + "number": 5.0, + "string": "string", + "list": [1.0, 2.0], + "map": { + "key": "value", + } + }) + ); + + let decoded = serde_json::from_value::(json).unwrap(); + assert_eq!(decoded.type_url, any.type_url); + assert_eq!( + crate::Value::decode(decoded.value).unwrap(), + crate::Value::decode(any.value).unwrap() + ); + } + + #[test] + fn primitive_value() { + let boolean = crate::Value::from(true); + let protobuf_encoding = boolean.encode_to_vec(); + + let any = crate::Any { + type_url: "google.protobuf.Value".into(), + value: protobuf_encoding.clone().into(), + }; + + let json = serde_json::to_value(&any).unwrap(); + let expected = serde_json::json!({ + "@type": "google.protobuf.Value", + "value": base64::encode(&protobuf_encoding), + }); + + assert_eq!(json, expected); + + let decoded = serde_json::from_value::(expected).unwrap(); + assert_eq!(decoded.type_url, any.type_url); + assert_eq!( + crate::Value::decode(decoded.value).unwrap(), + crate::Value::decode(any.value).unwrap() + ); + } +} diff --git a/pbjson-types/src/lib.rs b/pbjson-types/src/lib.rs index 48b808b..9c1a269 100644 --- a/pbjson-types/src/lib.rs +++ b/pbjson-types/src/lib.rs @@ -36,6 +36,7 @@ mod pb { } } +mod any; mod duration; mod list_value; mod null_value; diff --git a/pbjson-types/src/struct.rs b/pbjson-types/src/struct.rs index 428dcaa..1ea9d79 100644 --- a/pbjson-types/src/struct.rs +++ b/pbjson-types/src/struct.rs @@ -1,15 +1,31 @@ use crate::Struct; -impl From> for Struct { +type ValueMap = std::collections::HashMap; +type ValueEntry = (String, crate::Value); + +impl std::ops::Deref for crate::Struct { + type Target = ValueMap; + fn deref(&self) -> &Self::Target { + &self.fields + } +} + +impl std::ops::DerefMut for crate::Struct { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.fields + } +} + +impl From for crate::Struct { fn from(fields: std::collections::HashMap) -> Self { Self { fields } } } -impl FromIterator<(String, crate::Value)> for Struct { +impl FromIterator for crate::Struct { fn from_iter(iter: T) -> Self where - T: IntoIterator, + T: IntoIterator, { Self { fields: iter.into_iter().collect(), @@ -58,6 +74,35 @@ impl<'de> serde::de::Visitor<'de> for StructVisitor { } } +impl IntoIterator for crate::Struct { + type Item = ::Item; + type IntoIter = ::IntoIter; + + fn into_iter(self) -> Self::IntoIter { + self.fields.into_iter() + } +} + +impl<'r> IntoIterator for &'r crate::Struct { + type Item = <&'r ValueMap as IntoIterator>::Item; + type IntoIter = <&'r ValueMap as IntoIterator>::IntoIter; + + #[inline] + fn into_iter(self) -> Self::IntoIter { + self.fields.iter() + } +} + +impl<'r> IntoIterator for &'r mut crate::Struct { + type Item = <&'r mut ValueMap as IntoIterator>::Item; + type IntoIter = <&'r mut ValueMap as IntoIterator>::IntoIter; + + #[inline] + fn into_iter(self) -> Self::IntoIter { + self.fields.iter_mut() + } +} + #[cfg(test)] mod tests { #[test] diff --git a/pbjson/src/lib.rs b/pbjson/src/lib.rs index 17d9a49..876ab49 100644 --- a/pbjson/src/lib.rs +++ b/pbjson/src/lib.rs @@ -64,7 +64,7 @@ pub mod private { where D: serde::Deserializer<'de>, { - let s: &str = Deserialize::deserialize(deserializer)?; + let s: String = Deserialize::deserialize(deserializer)?; let decoded = base64::decode(s).map_err(serde::de::Error::custom)?; Ok(Self(decoded.into())) }