diff --git a/Cargo.lock b/Cargo.lock
index c1b1ec57fdec..4e648e08a3fe 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -9216,9 +9216,12 @@ version = "0.14.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "55a6a9143ae25c25fa7b6a48d6cc08b10785372060009c25140a4e7c340e95af"
 dependencies = [
+ "base64 0.22.0",
  "once_cell",
  "prost 0.13.1",
  "prost-types 0.13.1",
+ "serde",
+ "serde-value",
 ]
 
 [[package]]
diff --git a/src/connector/Cargo.toml b/src/connector/Cargo.toml
index d87e89c1cf65..a77e9cb929d1 100644
--- a/src/connector/Cargo.toml
+++ b/src/connector/Cargo.toml
@@ -103,7 +103,7 @@ pg_bigdecimal = { git = "https://github.com/risingwavelabs/rust-pg_bigdecimal",
 postgres-openssl = "0.5.0"
 prometheus = { version = "0.13", features = ["process"] }
 prost = { workspace = true, features = ["no-recursion-limit"] }
-prost-reflect = "0.14"
+prost-reflect = { version = "0.14", features = ["serde"] }
 prost-types = "0.13"
 protobuf-native = "0.2.2"
 pulsar = { version = "6.3", default-features = false, features = [
diff --git a/src/connector/codec/src/decoder/mod.rs b/src/connector/codec/src/decoder/mod.rs
index 814e06a166c6..bbfdbf0a90d7 100644
--- a/src/connector/codec/src/decoder/mod.rs
+++ b/src/connector/codec/src/decoder/mod.rs
@@ -38,6 +38,9 @@ pub enum AccessError {
     #[error("Unsupported additional column `{name}`")]
     UnsupportedAdditionalColumn { name: String },
 
+    #[error("Fail to convert protobuf Any into jsonb: {0}")]
+    ProtobufAnyToJson(#[source] serde_json::Error),
+
     /// Errors that are not categorized into variants above.
     #[error("{message}")]
     Uncategorized { message: String },
diff --git a/src/connector/src/parser/protobuf/parser.rs b/src/connector/src/parser/protobuf/parser.rs
index 8be25074f629..ec8c747cafd5 100644
--- a/src/connector/src/parser/protobuf/parser.rs
+++ b/src/connector/src/parser/protobuf/parser.rs
@@ -12,8 +12,6 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-use std::sync::Arc;
-
 use anyhow::Context;
 use itertools::Itertools;
 use prost_reflect::{
@@ -22,8 +20,7 @@ use prost_reflect::{
 };
 use risingwave_common::array::{ListValue, StructValue};
 use risingwave_common::types::{
-    DataType, Datum, DatumCow, Decimal, JsonbRef, JsonbVal, ScalarImpl, ScalarRefImpl, ToDatumRef,
-    ToOwnedDatum, F32, F64,
+    DataType, DatumCow, Decimal, JsonbVal, ScalarImpl, ToOwnedDatum, F32, F64,
 };
 use risingwave_common::{bail, try_match_expand};
 use risingwave_pb::plan_common::{AdditionalColumn, ColumnDesc, ColumnDescVersion};
@@ -32,9 +29,7 @@ use thiserror_ext::{AsReport, Macro};
 
 use crate::error::ConnectorResult;
 use crate::parser::unified::protobuf::ProtobufAccess;
-use crate::parser::unified::{
-    bail_uncategorized, uncategorized, AccessError, AccessImpl, AccessResult,
-};
+use crate::parser::unified::{uncategorized, AccessError, AccessImpl, AccessResult};
 use crate::parser::util::bytes_from_url;
 use crate::parser::{AccessBuilder, EncodingProperties};
 use crate::schema::schema_registry::{extract_schema_id, handle_sr_list, Client, WireFormatError};
@@ -44,7 +39,6 @@ use crate::schema::SchemaLoader;
 pub struct ProtobufAccessBuilder {
     confluent_wire_type: bool,
     message_descriptor: MessageDescriptor,
-    descriptor_pool: Arc<DescriptorPool>,
 }
 
 impl AccessBuilder for ProtobufAccessBuilder {
@@ -59,10 +53,7 @@ impl AccessBuilder for ProtobufAccessBuilder {
         let message = DynamicMessage::decode(self.message_descriptor.clone(), payload)
             .context("failed to parse message")?;
 
-        Ok(AccessImpl::Protobuf(ProtobufAccess::new(
-            message,
-            Arc::clone(&self.descriptor_pool),
-        )))
+        Ok(AccessImpl::Protobuf(ProtobufAccess::new(message)))
     }
 }
 
@@ -71,13 +62,11 @@ impl ProtobufAccessBuilder {
         let ProtobufParserConfig {
             confluent_wire_type,
             message_descriptor,
-            descriptor_pool,
         } = config;
 
         Ok(Self {
             confluent_wire_type,
             message_descriptor,
-            descriptor_pool,
         })
     }
 }
@@ -86,8 +75,6 @@ impl ProtobufAccessBuilder {
 pub struct ProtobufParserConfig {
     confluent_wire_type: bool,
     pub(crate) message_descriptor: MessageDescriptor,
-    /// Note that the pub(crate) here is merely for testing
-    pub(crate) descriptor_pool: Arc<DescriptorPool>,
 }
 
 impl ProtobufParserConfig {
@@ -132,7 +119,6 @@ impl ProtobufParserConfig {
         Ok(Self {
             message_descriptor,
             confluent_wire_type: protobuf_config.use_schema_registry,
-            descriptor_pool: Arc::new(pool),
         })
     }
 
@@ -216,141 +202,9 @@ fn detect_loop_and_push(
     Ok(())
 }
 
-fn extract_any_info(dyn_msg: &DynamicMessage) -> (String, Value) {
-    debug_assert!(
-        dyn_msg.fields().count() == 2,
-        "Expected only two fields for Any Type MessageDescriptor"
-    );
-
-    let type_url = dyn_msg
-        .get_field_by_name("type_url")
-        .expect("Expect type_url in dyn_msg")
-        .to_string()
-        .split('/')
-        .nth(1)
-        .map(|part| part[..part.len() - 1].to_string())
-        .unwrap_or_default();
-
-    let payload = dyn_msg
-        .get_field_by_name("value")
-        .expect("Expect value (payload) in dyn_msg")
-        .as_ref()
-        .clone();
-
-    (type_url, payload)
-}
-
-/// TODO: Resolve the potential naming conflict in the map
-/// i.e., If the two anonymous type shares the same key (e.g., "Int32"),
-/// the latter will overwrite the former one in `serde_json::Map`.
-/// Possible solution, maintaining a global id map, for the same types
-/// In the same level of fields, add the unique id at the tail of the name.
-/// e.g., "Int32.1" & "Int32.2" in the above example
-fn recursive_parse_json(
-    fields: &[Datum],
-    full_name_vec: Option<Vec<String>>,
-    full_name: Option<String>,
-) -> serde_json::Value {
-    // Note that the key is of no order
-    let mut ret: serde_json::Map<String, serde_json::Value> = serde_json::Map::new();
-
-    // The hidden type hint for user's convenience
-    // i.e., `"_type": message.full_name()`
-    if let Some(full_name) = full_name {
-        ret.insert("_type".to_string(), serde_json::Value::String(full_name));
-    }
-
-    for (idx, field) in fields.iter().enumerate() {
-        let mut key;
-        if let Some(k) = full_name_vec.as_ref() {
-            key = k[idx].to_string();
-        } else {
-            key = "".to_string();
-        }
-
-        match field.clone() {
-            Some(ScalarImpl::Int16(v)) => {
-                if key.is_empty() {
-                    key = "Int16".to_string();
-                }
-                ret.insert(key, serde_json::Value::Number(serde_json::Number::from(v)));
-            }
-            Some(ScalarImpl::Int32(v)) => {
-                if key.is_empty() {
-                    key = "Int32".to_string();
-                }
-                ret.insert(key, serde_json::Value::Number(serde_json::Number::from(v)));
-            }
-            Some(ScalarImpl::Int64(v)) => {
-                if key.is_empty() {
-                    key = "Int64".to_string();
-                }
-                ret.insert(key, serde_json::Value::Number(serde_json::Number::from(v)));
-            }
-            Some(ScalarImpl::Bool(v)) => {
-                if key.is_empty() {
-                    key = "Bool".to_string();
-                }
-                ret.insert(key, serde_json::Value::Bool(v));
-            }
-            Some(ScalarImpl::Bytea(v)) => {
-                if key.is_empty() {
-                    key = "Bytea".to_string();
-                }
-                let s = String::from_utf8(v.to_vec()).unwrap();
-                ret.insert(key, serde_json::Value::String(s));
-            }
-            Some(ScalarImpl::Float32(v)) => {
-                if key.is_empty() {
-                    key = "Int16".to_string();
-                }
-                ret.insert(
-                    key,
-                    serde_json::Value::Number(
-                        serde_json::Number::from_f64(v.into_inner() as f64).unwrap(),
-                    ),
-                );
-            }
-            Some(ScalarImpl::Float64(v)) => {
-                if key.is_empty() {
-                    key = "Float64".to_string();
-                }
-                ret.insert(
-                    key,
-                    serde_json::Value::Number(
-                        serde_json::Number::from_f64(v.into_inner()).unwrap(),
-                    ),
-                );
-            }
-            Some(ScalarImpl::Utf8(v)) => {
-                if key.is_empty() {
-                    key = "Utf8".to_string();
-                }
-                ret.insert(key, serde_json::Value::String(v.to_string()));
-            }
-            Some(ScalarImpl::Struct(v)) => {
-                if key.is_empty() {
-                    key = "Struct".to_string();
-                }
-                ret.insert(key, recursive_parse_json(v.fields(), None, None));
-            }
-            Some(ScalarImpl::Jsonb(v)) => {
-                if key.is_empty() {
-                    key = "Jsonb".to_string();
-                }
-                ret.insert(key, v.take());
-            }
-            r#type => panic!("Not yet support ScalarImpl type: {:?}", r#type),
-        }
-    }
-
-    serde_json::Value::Object(ret)
-}
-
 pub fn from_protobuf_value<'a>(
     field_desc: &FieldDescriptor,
     value: &'a Value,
-    descriptor_pool: &Arc<DescriptorPool>,
 ) -> AccessResult<DatumCow<'a>> {
     let kind = field_desc.kind();
 
@@ -382,62 +236,9 @@ pub fn from_protobuf_value<'a>(
         }
         Value::Message(dyn_msg) => {
             if dyn_msg.descriptor().full_name() == "google.protobuf.Any" {
-                // If the fields are not presented, default value is an empty string
-                if !dyn_msg.has_field_by_name("type_url") || !dyn_msg.has_field_by_name("value") {
-                    borrowed!(JsonbRef::empty_string());
-                }
-
-                // Sanity check
-                debug_assert!(
-                    dyn_msg.has_field_by_name("type_url") && dyn_msg.has_field_by_name("value"),
-                    "`type_url` & `value` must exist in fields of `dyn_msg`"
-                );
-
-                // The message is of type `Any`
-                let (type_url, payload) = extract_any_info(dyn_msg);
-
-                let payload_field_desc = dyn_msg.descriptor().get_field_by_name("value").unwrap();
-
-                let payload = from_protobuf_value(&payload_field_desc, &payload, descriptor_pool)?;
-                let Some(ScalarRefImpl::Bytea(payload)) = payload.to_datum_ref() else {
-                    bail_uncategorized!("expected bytes for dynamic message payload");
-                };
-
-                // Get the corresponding schema from the descriptor pool
-                let msg_desc = descriptor_pool
-                    .get_message_by_name(&type_url)
-                    .ok_or_else(|| {
-                        uncategorized!("message `{type_url}` not found in descriptor pool")
-                    })?;
-
-                let f = msg_desc
-                    .clone()
-                    .fields()
-                    .map(|f| f.name().to_string())
-                    .collect::<Vec<String>>();
-
-                let full_name = msg_desc.clone().full_name().to_string();
-
-                // Decode the payload based on the `msg_desc`
-                let decoded_value = DynamicMessage::decode(msg_desc, payload).unwrap();
-                let decoded_value = from_protobuf_value(
-                    field_desc,
-                    &Value::Message(decoded_value),
-                    descriptor_pool,
-                )?
-                .to_owned_datum()
-                .unwrap();
-
-                // Extract the struct value
-                let ScalarImpl::Struct(v) = decoded_value else {
-                    panic!("Expect ScalarImpl::Struct");
-                };
-
-                ScalarImpl::Jsonb(JsonbVal::from(serde_json::json!(recursive_parse_json(
-                    v.fields(),
-                    Some(f),
-                    Some(full_name),
-                ))))
+                ScalarImpl::Jsonb(JsonbVal::from(
+                    serde_json::to_value(dyn_msg).map_err(AccessError::ProtobufAnyToJson)?,
+                ))
             } else {
                 let mut rw_values = Vec::with_capacity(dyn_msg.descriptor().fields().len());
                 // fields is a btree map in descriptor
@@ -454,9 +255,7 @@ pub fn from_protobuf_value<'a>(
                 }
                 // use default value if dyn_msg doesn't has this field
                 let value = dyn_msg.get_field(&field_desc);
-                rw_values.push(
-                    from_protobuf_value(&field_desc, &value, descriptor_pool)?.to_owned_datum(),
-                );
+                rw_values.push(from_protobuf_value(&field_desc, &value)?.to_owned_datum());
             }
             ScalarImpl::Struct(StructValue::new(rw_values))
         }
@@ -466,7 +265,7 @@ pub fn from_protobuf_value<'a>(
                 .map_err(|e| uncategorized!("{}", e.to_report_string()))?;
             let mut builder = data_type.as_list().create_array_builder(values.len());
             for value in values {
-                builder.append(from_protobuf_value(field_desc, value, descriptor_pool)?);
+                builder.append(from_protobuf_value(field_desc, value)?);
             }
             ScalarImpl::List(ListValue::new(builder.finish()))
         }
@@ -498,25 +297,18 @@ fn protobuf_type_mapping(
         }
         Kind::Uint64 | Kind::Fixed64 => DataType::Decimal,
         Kind::String => DataType::Varchar,
-        Kind::Message(m) => {
-            let fields = m
-                .fields()
-                .map(|f| protobuf_type_mapping(&f, parse_trace))
-                .try_collect()?;
-            let field_names = m.fields().map(|f| f.name().to_string()).collect_vec();
-
-            // Note that this part is useful for actual parsing
-            // Since RisingWave will parse message to `ScalarImpl::Jsonb`
-            // Please do NOT modify it
-            if field_names.len() == 2
-                && field_names.contains(&"value".to_string())
-                && field_names.contains(&"type_url".to_string())
-            {
-                DataType::Jsonb
-            } else {
+        Kind::Message(m) => match m.full_name() {
+            // Well-Known Types are identified by their full name
+            "google.protobuf.Any" => DataType::Jsonb,
+            _ => {
+                let fields = m
+                    .fields()
+                    .map(|f| protobuf_type_mapping(&f, parse_trace))
+                    .try_collect()?;
+                let field_names = m.fields().map(|f| f.name().to_string()).collect_vec();
                 DataType::new_struct(fields, field_names)
             }
-        }
+        },
         Kind::Enum(_) => DataType::Varchar,
         Kind::Bytes => DataType::Bytea,
     };
@@ -973,10 +765,9 @@ mod test {
         // This is of no use
         let field = value.fields().next().unwrap().0;
 
-        if let Some(ret) =
-            from_protobuf_value(&field, &Value::Message(value), &conf.descriptor_pool)
-                .unwrap()
-                .to_owned_datum()
+        if let Some(ret) = from_protobuf_value(&field, &Value::Message(value))
+            .unwrap()
+            .to_owned_datum()
         {
             println!("Decoded Value for ANY_GEN_PROTO_DATA: {:#?}", ret);
             println!("---------------------------");
@@ -1000,7 +791,7 @@ mod test {
             assert_eq!(
                 jv,
                JsonbVal::from(json!({
-                    "_type": "test.StringValue",
+                    "@type": "type.googleapis.com/test.StringValue",
                     "value": "John Doe"
                 }))
             );
@@ -1036,10 +827,9 @@ mod test {
         // This is of no use
        let field = value.fields().next().unwrap().0;
 
-        if let Some(ret) =
-            from_protobuf_value(&field, &Value::Message(value), &conf.descriptor_pool)
-                .unwrap()
-                .to_owned_datum()
+        if let Some(ret) = from_protobuf_value(&field, &Value::Message(value))
+            .unwrap()
+            .to_owned_datum()
         {
             println!("Decoded Value for ANY_GEN_PROTO_DATA: {:#?}", ret);
             println!("---------------------------");
@@ -1063,7 +853,7 @@ mod test {
             assert_eq!(
                 jv,
                 JsonbVal::from(json!({
-                    "_type": "test.Int32Value",
+                    "@type": "type.googleapis.com/test.Int32Value",
                     "value": 114514
                 }))
             );
@@ -1110,10 +900,9 @@ mod test {
         // This is of no use
         let field = value.fields().next().unwrap().0;
 
-        if let Some(ret) =
-            from_protobuf_value(&field, &Value::Message(value), &conf.descriptor_pool)
-                .unwrap()
-                .to_owned_datum()
+        if let Some(ret) = from_protobuf_value(&field, &Value::Message(value))
+            .unwrap()
+            .to_owned_datum()
         {
             println!("Decoded Value for ANY_RECURSIVE_GEN_PROTO_DATA: {:#?}", ret);
             println!("---------------------------");
@@ -1137,13 +926,13 @@ mod test {
             assert_eq!(
                 jv,
                 JsonbVal::from(json!({
-                    "_type": "test.AnyValue",
-                    "any_value_1": {
-                        "_type": "test.StringValue",
+                    "@type": "type.googleapis.com/test.AnyValue",
+                    "anyValue1": {
+                        "@type": "type.googleapis.com/test.StringValue",
                         "value": "114514",
                     },
-                    "any_value_2": {
-                        "_type": "test.Int32Value",
+                    "anyValue2": {
+                        "@type": "type.googleapis.com/test.Int32Value",
                         "value": 114514,
                     }
                 }))
@@ -1156,6 +945,37 @@ mod test {
         Ok(())
     }
 
+    // id: 12345
+    // any_value: {
+    //    type_url: "type.googleapis.com/test.StringXalue"
+    //    value: "\n\010John Doe"
+    // }
+    static ANY_GEN_PROTO_DATA_INVALID: &[u8] = b"\x08\xb9\x60\x12\x32\x0a\x24\x74\x79\x70\x65\x2e\x67\x6f\x6f\x67\x6c\x65\x61\x70\x69\x73\x2e\x63\x6f\x6d\x2f\x74\x65\x73\x74\x2e\x53\x74\x72\x69\x6e\x67\x58\x61\x6c\x75\x65\x12\x0a\x0a\x08\x4a\x6f\x68\x6e\x20\x44\x6f\x65";
+
+    #[tokio::test]
+    async fn test_any_invalid() -> crate::error::ConnectorResult<()> {
+        let conf = create_recursive_pb_parser_config("/any-schema.pb", "test.TestAny").await;
+
+        let value =
+            DynamicMessage::decode(conf.message_descriptor.clone(), ANY_GEN_PROTO_DATA_INVALID)
+                .unwrap();
+
+        // The top-level `Value` is not a proto field, but we need a dummy one.
+        let field = value.fields().next().unwrap().0;
+
+        let err = from_protobuf_value(&field, &Value::Message(value)).unwrap_err();
+
+        let expected = expect_test::expect![[r#"
+            Fail to convert protobuf Any into jsonb
+
+            Caused by:
+              message 'test.StringXalue' not found
+        "#]];
+        expected.assert_eq(err.to_report_string_pretty().as_str());
+
+        Ok(())
+    }
+
     #[test]
     fn test_decode_varint_zigzag() {
         // 1. Positive number
diff --git a/src/connector/src/parser/unified/mod.rs b/src/connector/src/parser/unified/mod.rs
index 8045ce013240..fdfe3aae6aae 100644
--- a/src/connector/src/parser/unified/mod.rs
+++ b/src/connector/src/parser/unified/mod.rs
@@ -17,9 +17,7 @@
 use auto_impl::auto_impl;
 use risingwave_common::types::{DataType, DatumCow};
 use risingwave_connector_codec::decoder::avro::AvroAccess;
-pub use risingwave_connector_codec::decoder::{
-    bail_uncategorized, uncategorized, Access, AccessError, AccessResult,
-};
+pub use risingwave_connector_codec::decoder::{uncategorized, Access, AccessError, AccessResult};
 
 use self::bytes::BytesAccess;
 use self::json::JsonAccess;
diff --git a/src/connector/src/parser/unified/protobuf.rs b/src/connector/src/parser/unified/protobuf.rs
index 02febc22db24..b1d34746b502 100644
--- a/src/connector/src/parser/unified/protobuf.rs
+++ b/src/connector/src/parser/unified/protobuf.rs
@@ -13,9 +13,9 @@
 // limitations under the License.
 
 use std::borrow::Cow;
-use std::sync::{Arc, LazyLock};
+use std::sync::LazyLock;
 
-use prost_reflect::{DescriptorPool, DynamicMessage, ReflectMessage};
+use prost_reflect::{DynamicMessage, ReflectMessage};
 use risingwave_common::log::LogSuppresser;
 use risingwave_common::types::{DataType, DatumCow, ToOwnedDatum};
 use thiserror_ext::AsReport;
@@ -26,15 +26,11 @@ use crate::parser::unified::uncategorized;
 
 pub struct ProtobufAccess {
     message: DynamicMessage,
-    descriptor_pool: Arc<DescriptorPool>,
 }
 
 impl ProtobufAccess {
-    pub fn new(message: DynamicMessage, descriptor_pool: Arc<DescriptorPool>) -> Self {
-        Self {
-            message,
-            descriptor_pool,
-        }
+    pub fn new(message: DynamicMessage) -> Self {
+        Self { message }
     }
 }
 
@@ -59,10 +55,10 @@ impl Access for ProtobufAccess {
         })?;
 
         match self.message.get_field(&field_desc) {
-            Cow::Borrowed(value) => from_protobuf_value(&field_desc, value, &self.descriptor_pool),
+            Cow::Borrowed(value) => from_protobuf_value(&field_desc, value),
 
             // `Owned` variant occurs only if there's no such field and the default value is returned.
-            Cow::Owned(value) => from_protobuf_value(&field_desc, &value, &self.descriptor_pool)
+            Cow::Owned(value) => from_protobuf_value(&field_desc, &value)
                 // enforce `Owned` variant to avoid returning a reference to a temporary value
                 .map(|d| d.to_owned_datum().into()),
         }
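Below is a small illustrative sketch, not part of the patch, of the mechanism the change relies on: with prost-reflect's `serde` feature enabled, `DynamicMessage` implements `serde::Serialize` using the proto3 JSON mapping, so a `google.protobuf.Any` value is rendered as an object with an `@type` field and lowerCamelCase keys, and serialization returns a `serde_json::Error` when the descriptor pool does not contain the type named in `type_url`. The helper name and pool setup here are assumptions for illustration only.

use prost_reflect::{DescriptorPool, DynamicMessage};

// Decode an encoded `google.protobuf.Any` and convert it to JSON the same way
// `from_protobuf_value` now does. Assumes (hypothetically) that `pool` already
// contains both `google.protobuf.Any` and the message type referenced by `type_url`.
fn any_to_json(pool: &DescriptorPool, payload: &[u8]) -> Result<serde_json::Value, serde_json::Error> {
    let any_desc = pool
        .get_message_by_name("google.protobuf.Any")
        .expect("descriptor pool should include the well-known types");
    let msg = DynamicMessage::decode(any_desc, payload).expect("malformed Any payload");

    // `DynamicMessage: Serialize` (via the `serde` feature) produces
    // {"@type": "type.googleapis.com/...", ...decoded fields...}; an unknown
    // `type_url` surfaces as a serde error, as exercised by `test_any_invalid`.
    serde_json::to_value(&msg)
}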