From c4a4cb9191a43a64161b2ba2af476fc6e61fb7b2 Mon Sep 17 00:00:00 2001 From: Rossil <40714231+Rossil2012@users.noreply.github.com> Date: Tue, 24 Oct 2023 03:24:06 +0800 Subject: [PATCH] fix(protobuf): recursive Any field (#13008) Co-authored-by: Michael Xu --- src/connector/src/parser/protobuf/parser.rs | 256 +++++++++++-------- src/connector/src/parser/unified/protobuf.rs | 4 +- src/connector/src/test_data/any-schema.pb | 7 +- src/connector/src/test_data/any-schema.proto | 5 + 4 files changed, 157 insertions(+), 115 deletions(-) diff --git a/src/connector/src/parser/protobuf/parser.rs b/src/connector/src/parser/protobuf/parser.rs index 5e5bb7ec15ae1..efa1ae9c96248 100644 --- a/src/connector/src/parser/protobuf/parser.rs +++ b/src/connector/src/parser/protobuf/parser.rs @@ -245,7 +245,12 @@ fn extract_any_info(dyn_msg: &DynamicMessage) -> (String, Value) { let type_url = dyn_msg .get_field_by_name("type_url") - .expect("Expect type_url in dyn_msg"); + .expect("Expect type_url in dyn_msg") + .to_string() + .split('/') + .nth(1) + .map(|part| part[..part.len() - 1].to_string()) + .unwrap_or_default(); let payload = dyn_msg .get_field_by_name("value") @@ -253,10 +258,6 @@ fn extract_any_info(dyn_msg: &DynamicMessage) -> (String, Value) { .as_ref() .clone(); - let type_url = type_url.to_string().split('/').collect::>()[1].to_string(); - - let type_url = type_url[..type_url.len() - 1].to_string(); - (type_url, payload) } @@ -269,14 +270,14 @@ fn extract_any_info(dyn_msg: &DynamicMessage) -> (String, Value) { fn recursive_parse_json(fields: &[Datum], full_name_vec: Option>) -> serde_json::Value { let mut ret: serde_json::Map = serde_json::Map::new(); - for i in 0..fields.len() { + for (idx, field) in fields.iter().enumerate() { let mut key = if full_name_vec.is_some() { - full_name_vec.as_ref().unwrap()[i].to_string() + full_name_vec.as_ref().unwrap()[idx].to_string() } else { "".to_string() }; - match fields[i].clone() { + match field.clone() { Some(ScalarImpl::Int16(v)) => { if key.is_empty() { key = "Int16".to_string(); @@ -342,7 +343,13 @@ fn recursive_parse_json(fields: &[Datum], full_name_vec: Option>) -> } ret.insert(key, recursive_parse_json(v.fields(), None)); } - _ => panic!("Not yet support ScalarImpl type"), + Some(ScalarImpl::Jsonb(v)) => { + if key.is_empty() { + key = "Jsonb".to_string(); + } + ret.insert(key, v.take()); + } + r#type => panic!("Not yet support ScalarImpl type: {:?}", r#type), } } @@ -353,7 +360,6 @@ pub fn from_protobuf_value( field_desc: &FieldDescriptor, value: &Value, descriptor_pool: &Arc, - type_expected: Option<&DataType>, ) -> Result { let v = match value { Value::Bool(v) => ScalarImpl::Bool(*v), @@ -380,30 +386,21 @@ pub fn from_protobuf_value( ScalarImpl::Utf8(enum_symbol.name().into()) } Value::Message(dyn_msg) => { - let any_flag = dyn_msg.descriptor().full_name() == "google.protobuf.Any"; - - if any_flag { + if dyn_msg.descriptor().full_name() == "google.protobuf.Any" { + // Sanity check debug_assert!( - type_expected == Some(&DataType::Jsonb), - "`type_expected` must be of `DataType::Jsonb` for any protobuf type" + dyn_msg.has_field_by_name("type_url") && + dyn_msg.has_field_by_name("value"), + "`type_url` & `value` must exist in fields of `dyn_msg`" ); - } - if dyn_msg.has_field_by_name("type_url") - && dyn_msg.has_field_by_name("value") - && any_flag - { // The message is of type `Any` let (type_url, payload) = extract_any_info(dyn_msg); let payload_field_desc = dyn_msg.descriptor().get_field_by_name("value").unwrap(); - let Some(ScalarImpl::Bytea(payload)) = from_protobuf_value( - &payload_field_desc, - &payload, - descriptor_pool, - type_expected, - )? + let Some(ScalarImpl::Bytea(payload)) = + from_protobuf_value(&payload_field_desc, &payload, descriptor_pool)? else { panic!("Expected ScalarImpl::Bytea for payload"); }; @@ -430,10 +427,6 @@ pub fn from_protobuf_value( field_desc, &Value::Message(decoded_value), descriptor_pool, - // FIXME: Here `type_expected` can not be parsed by context - // Thus this may be error-prone, need refactor / remove this - // when dealing with nested any type - type_expected, )? .unwrap(); @@ -442,9 +435,10 @@ pub fn from_protobuf_value( panic!("Expect ScalarImpl::Struct"); }; - ScalarImpl::Jsonb(JsonbVal::from( - serde_json::json!({"value": recursive_parse_json(v.fields(), Some(f))}), - )) + ScalarImpl::Jsonb(JsonbVal::from(serde_json::json!(recursive_parse_json( + v.fields(), + Some(f) + )))) } else { let mut rw_values = Vec::with_capacity(dyn_msg.descriptor().fields().len()); // fields is a btree map in descriptor @@ -462,12 +456,7 @@ pub fn from_protobuf_value( } // use default value if dyn_msg doesn't has this field let value = dyn_msg.get_field(&field_desc); - rw_values.push(from_protobuf_value( - &field_desc, - &value, - descriptor_pool, - type_expected, - )?); + rw_values.push(from_protobuf_value(&field_desc, &value, descriptor_pool)?); } ScalarImpl::Struct(StructValue::new(rw_values)) } @@ -475,7 +464,7 @@ pub fn from_protobuf_value( Value::List(values) => { let rw_values = values .iter() - .map(|value| from_protobuf_value(field_desc, value, descriptor_pool, type_expected)) + .map(|value| from_protobuf_value(field_desc, value, descriptor_pool)) .collect::>>()?; ScalarImpl::List(ListValue::new(rw_values)) } @@ -516,6 +505,9 @@ fn protobuf_type_mapping( .collect::>>()?; let field_names = m.fields().map(|f| f.name().to_string()).collect_vec(); + // Note that this part is useful for actual parsing + // Since RisingWave will parse message to `ScalarImpl::Jsonb` + // Please do NOT modify it if field_names.len() == 2 && field_names.contains(&"value".to_string()) && field_names.contains(&"type_url".to_string()) @@ -566,6 +558,7 @@ mod test { use risingwave_common::types::{DataType, StructType}; use risingwave_pb::catalog::StreamSourceInfo; use risingwave_pb::data::data_type::PbTypeName; + use serde_json::json; use super::*; use crate::parser::protobuf::recursive::all_types::{EnumType, ExampleOneof, NestedMessage}; @@ -718,9 +711,11 @@ mod test { assert!(columns.is_err()); } - async fn create_recursive_pb_parser_config() -> ProtobufParserConfig { - let location = schema_dir() + "/proto_recursive/recursive.pb"; - let message_name = "recursive.AllTypes"; + async fn create_recursive_pb_parser_config( + location: &str, + message_name: &str, + ) -> ProtobufParserConfig { + let location = schema_dir() + location; let info = StreamSourceInfo { proto_message_name: message_name.to_string(), @@ -742,7 +737,11 @@ mod test { #[tokio::test] async fn test_all_types_create_source() { - let conf = create_recursive_pb_parser_config().await; + let conf = create_recursive_pb_parser_config( + "/proto_recursive/recursive.pb", + "recursive.AllTypes", + ) + .await; // Ensure that the parser can recognize the schema. let columns = conf @@ -786,10 +785,7 @@ mod test { ("seconds", DataType::Int64), ("nanos", DataType::Int32) ])), // duration_field - DataType::Struct(StructType::new(vec![ - ("type_url", DataType::Varchar), - ("value", DataType::Bytea), - ])), // any_field + DataType::Jsonb, // any_field DataType::Struct(StructType::new(vec![("value", DataType::Int32)])), /* int32_value_field */ DataType::Struct(StructType::new(vec![("value", DataType::Varchar)])), /* string_value_field */ ] @@ -802,7 +798,11 @@ mod test { let mut payload = Vec::new(); m.encode(&mut payload).unwrap(); - let conf = create_recursive_pb_parser_config().await; + let conf = create_recursive_pb_parser_config( + "/proto_recursive/recursive.pb", + "recursive.AllTypes", + ) + .await; let mut access_builder = ProtobufAccessBuilder::new(conf).unwrap(); let access = access_builder.generate_accessor(payload).await.unwrap(); if let AccessImpl::Protobuf(a) = access { @@ -947,26 +947,9 @@ mod test { // } static ANY_GEN_PROTO_DATA: &[u8] = b"\x08\xb9\x60\x12\x32\x0a\x24\x74\x79\x70\x65\x2e\x67\x6f\x6f\x67\x6c\x65\x61\x70\x69\x73\x2e\x63\x6f\x6d\x2f\x74\x65\x73\x74\x2e\x53\x74\x72\x69\x6e\x67\x56\x61\x6c\x75\x65\x12\x0a\x0a\x08\x4a\x6f\x68\x6e\x20\x44\x6f\x65"; - #[ignore] #[tokio::test] async fn test_any_schema() -> Result<()> { - let location = schema_dir() + "/any-schema.pb"; - println!("location: {}", location); - let message_name = "test.TestAny"; - let info = StreamSourceInfo { - proto_message_name: message_name.to_string(), - row_schema_location: location.to_string(), - use_schema_registry: false, - ..Default::default() - }; - - let parser_config = SpecificParserConfig::new( - SourceStruct::new(SourceFormat::Plain, SourceEncode::Protobuf), - &info, - &HashMap::new(), - )?; - - let conf = ProtobufParserConfig::new(parser_config.encoding_config).await?; + let conf = create_recursive_pb_parser_config("/any-schema.pb", "test.TestAny").await; println!("Current conf: {:#?}", conf); println!("---------------------------"); @@ -981,8 +964,7 @@ mod test { let field = value.fields().next().unwrap().0; if let Some(ret) = - from_protobuf_value(&field, &Value::Message(value), &conf.descriptor_pool, None) - .unwrap() + from_protobuf_value(&field, &Value::Message(value), &conf.descriptor_pool).unwrap() { println!("Decoded Value for ANY_GEN_PROTO_DATA: {:#?}", ret); println!("---------------------------"); @@ -1002,18 +984,15 @@ mod test { } match fields[1].clone() { - Some(ScalarImpl::Struct(sv)) => { - let fields = sv.fields(); - debug_assert!(fields.len() == 1, "Expected only one field"); - match fields[0].clone() { - Some(ScalarImpl::Utf8(v)) => { - println!("Successfully decode field[0] for any type"); - assert_eq!(v.to_string(), "John Doe"); - } - _ => panic!("Expected ScalarImpl::Int32"), - } + Some(ScalarImpl::Jsonb(jv)) => { + assert_eq!( + jv, + JsonbVal::from(json!({ + "value": "John Doe" + })) + ); } - _ => panic!("Expected ScalarImpl::Struct"), + _ => panic!("Expected ScalarImpl::Jsonb"), } } @@ -1028,26 +1007,9 @@ mod test { // Unpacked Int32Value from Any: value: 114514 static ANY_GEN_PROTO_DATA_1: &[u8] = b"\x08\xb9\x60\x12\x2b\x0a\x23\x74\x79\x70\x65\x2e\x67\x6f\x6f\x67\x6c\x65\x61\x70\x69\x73\x2e\x63\x6f\x6d\x2f\x74\x65\x73\x74\x2e\x49\x6e\x74\x33\x32\x56\x61\x6c\x75\x65\x12\x04\x08\xd2\xfe\x06"; - #[ignore] #[tokio::test] async fn test_any_schema_1() -> Result<()> { - let location = schema_dir() + "/any-schema.pb"; - println!("location: {}", location); - let message_name = "test.TestAny"; - let info = StreamSourceInfo { - proto_message_name: message_name.to_string(), - row_schema_location: location.to_string(), - use_schema_registry: false, - ..Default::default() - }; - - let parser_config = SpecificParserConfig::new( - SourceStruct::new(SourceFormat::Plain, SourceEncode::Protobuf), - &info, - &HashMap::new(), - )?; - - let conf = ProtobufParserConfig::new(parser_config.encoding_config).await?; + let conf = create_recursive_pb_parser_config("/any-schema.pb", "test.TestAny").await; println!("Current conf: {:#?}", conf); println!("---------------------------"); @@ -1062,8 +1024,7 @@ mod test { let field = value.fields().next().unwrap().0; if let Some(ret) = - from_protobuf_value(&field, &Value::Message(value), &conf.descriptor_pool, None) - .unwrap() + from_protobuf_value(&field, &Value::Message(value), &conf.descriptor_pool).unwrap() { println!("Decoded Value for ANY_GEN_PROTO_DATA: {:#?}", ret); println!("---------------------------"); @@ -1083,18 +1044,91 @@ mod test { } match fields[1].clone() { - Some(ScalarImpl::Struct(sv)) => { - let fields = sv.fields(); - debug_assert!(fields.len() == 1, "Expected only one field"); - match fields[0].clone() { - Some(ScalarImpl::Int32(v)) => { - println!("Successfully decode field[0] for any type"); - assert_eq!(v, 114514); - } - _ => panic!("Expected ScalarImpl::Int32"), - } + Some(ScalarImpl::Jsonb(jv)) => { + assert_eq!( + jv, + JsonbVal::from(json!({ + "value": 114514 + })) + ); + } + _ => panic!("Expected ScalarImpl::Jsonb"), + } + } + + Ok(()) + } + + // "id": 12345, + // "any_value": { + // "type_url": "type.googleapis.com/test.AnyValue", + // "value": { + // "any_value_1": { + // "type_url": "type.googleapis.com/test.StringValue", + // "value": "114514" + // }, + // "any_value_2": { + // "type_url": "type.googleapis.com/test.Int32Value", + // "value": 114514 + // } + // } + // } + static ANY_RECURSIVE_GEN_PROTO_DATA: &[u8] = b"\x08\xb9\x60\x12\x84\x01\x0a\x21\x74\x79\x70\x65\x2e\x67\x6f\x6f\x67\x6c\x65\x61\x70\x69\x73\x2e\x63\x6f\x6d\x2f\x74\x65\x73\x74\x2e\x41\x6e\x79\x56\x61\x6c\x75\x65\x12\x5f\x0a\x30\x0a\x24\x74\x79\x70\x65\x2e\x67\x6f\x6f\x67\x6c\x65\x61\x70\x69\x73\x2e\x63\x6f\x6d\x2f\x74\x65\x73\x74\x2e\x53\x74\x72\x69\x6e\x67\x56\x61\x6c\x75\x65\x12\x08\x0a\x06\x31\x31\x34\x35\x31\x34\x12\x2b\x0a\x23\x74\x79\x70\x65\x2e\x67\x6f\x6f\x67\x6c\x65\x61\x70\x69\x73\x2e\x63\x6f\x6d\x2f\x74\x65\x73\x74\x2e\x49\x6e\x74\x33\x32\x56\x61\x6c\x75\x65\x12\x04\x08\xd2\xfe\x06"; + + #[tokio::test] + async fn test_any_recursive() -> Result<()> { + let conf = create_recursive_pb_parser_config("/any-schema.pb", "test.TestAny").await; + + println!("Current conf: {:#?}", conf); + println!("---------------------------"); + + let value = DynamicMessage::decode( + conf.message_descriptor.clone(), + ANY_RECURSIVE_GEN_PROTO_DATA, + ) + .unwrap(); + + println!("Current Value: {:#?}", value); + println!("---------------------------"); + + // This is of no use + let field = value.fields().next().unwrap().0; + + if let Some(ret) = + from_protobuf_value(&field, &Value::Message(value), &conf.descriptor_pool).unwrap() + { + println!("Decoded Value for ANY_RECURSIVE_GEN_PROTO_DATA: {:#?}", ret); + println!("---------------------------"); + + let ScalarImpl::Struct(struct_value) = ret else { + panic!("Expected ScalarImpl::Struct"); + }; + + let fields = struct_value.fields(); + + match fields[0].clone() { + Some(ScalarImpl::Int32(v)) => { + println!("Successfully decode field[0]"); + assert_eq!(v, 12345); + } + _ => panic!("Expected ScalarImpl::Int32"), + } + + match fields[1].clone() { + Some(ScalarImpl::Jsonb(jv)) => { + assert_eq!( + jv, + JsonbVal::from(json!({ + "any_value_1": { + "value": "114514", + }, + "any_value_2": { + "value": 114514, + } + })) + ); } - _ => panic!("Expected ScalarImpl::Struct"), + _ => panic!("Expected ScalarImpl::Jsonb"), } } diff --git a/src/connector/src/parser/unified/protobuf.rs b/src/connector/src/parser/unified/protobuf.rs index 4505cbe45116f..bd447b84cfb62 100644 --- a/src/connector/src/parser/unified/protobuf.rs +++ b/src/connector/src/parser/unified/protobuf.rs @@ -39,7 +39,7 @@ impl ProtobufAccess { } impl Access for ProtobufAccess { - fn access(&self, path: &[&str], type_expected: Option<&DataType>) -> AccessResult { + fn access(&self, path: &[&str], _type_expected: Option<&DataType>) -> AccessResult { debug_assert_eq!(1, path.len()); let field_desc = self .message @@ -52,7 +52,7 @@ impl Access for ProtobufAccess { }) .map_err(|e| AccessError::Other(anyhow!(e)))?; let value = self.message.get_field(&field_desc); - from_protobuf_value(&field_desc, &value, &self.descriptor_pool, type_expected) + from_protobuf_value(&field_desc, &value, &self.descriptor_pool) .map_err(|e| AccessError::Other(anyhow!(e))) } } diff --git a/src/connector/src/test_data/any-schema.pb b/src/connector/src/test_data/any-schema.pb index ac7f71058ddd6..977f64cec3775 100644 --- a/src/connector/src/test_data/any-schema.pb +++ b/src/connector/src/test_data/any-schema.pb @@ -5,7 +5,7 @@ type_url ( RtypeUrl value ( RvalueBv com.google.protobufBAnyProtoPZ,google.golang.org/protobuf/types/known/anypb¢GPBªGoogle.Protobuf.WellKnownTypesbproto3 -é +á any-schema.prototestgoogle/protobuf/any.proto"L TestAny id (Rid1 @@ -14,7 +14,10 @@ value ( Rvalue"" Int32Value -value (Rvalue"@ +value (Rvalue"v +AnyValue4 + any_value_1 ( 2.google.protobuf.AnyR anyValue14 + any_value_2 ( 2.google.protobuf.AnyR anyValue2"@ StringInt32Value first ( Rfirst second (Rsecond"Ž diff --git a/src/connector/src/test_data/any-schema.proto b/src/connector/src/test_data/any-schema.proto index 30190f9b0b939..12a367100ce7d 100644 --- a/src/connector/src/test_data/any-schema.proto +++ b/src/connector/src/test_data/any-schema.proto @@ -16,6 +16,11 @@ message Int32Value { int32 value = 1; } +message AnyValue { + google.protobuf.Any any_value_1 = 1; + google.protobuf.Any any_value_2 = 2; +} + message StringInt32Value { string first = 1; int32 second = 2;