diff --git a/Cargo.lock b/Cargo.lock index c666315abe849..860d96c1a33cd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6909,7 +6909,9 @@ dependencies = [ "parking_lot 0.12.1", "prometheus", "prost", + "prost-build", "prost-reflect", + "prost-types", "protobuf-native", "pulsar", "rand", diff --git a/src/connector/Cargo.toml b/src/connector/Cargo.toml index 4188291614311..6b6aa6b7b8daf 100644 --- a/src/connector/Cargo.toml +++ b/src/connector/Cargo.toml @@ -63,8 +63,8 @@ num-bigint = "0.4" opendal = "0.39" parking_lot = "0.12" prometheus = { version = "0.13", features = ["process"] } -prost = { version = "0.11.9", features = ["no-recursion-limit"] } -prost-reflect = "0.11.5" +prost = { version = "0.11", features = ["no-recursion-limit"] } +prost-reflect = "0.11" protobuf-native = "0.2.1" pulsar = { version = "6.0", default-features = false, features = [ "tokio-runtime", @@ -113,9 +113,13 @@ workspace-hack = { path = "../workspace-hack" } [dev-dependencies] criterion = { workspace = true, features = ["async_tokio", "async"] } +prost-types = "0.11" rand = "0.8" tempfile = "3" +[build-dependencies] +prost-build = "0.11" + [[bench]] name = "parser" harness = false diff --git a/src/connector/build.rs b/src/connector/build.rs new file mode 100644 index 0000000000000..6439fcf00f932 --- /dev/null +++ b/src/connector/build.rs @@ -0,0 +1,29 @@ +// Copyright 2023 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +fn main() { + let proto_dir = "./src/test_data/proto_recursive"; + + println!("cargo:rerun-if-changed={}", proto_dir); + + let proto_files = vec!["recursive"]; + let protos: Vec = proto_files + .iter() + .map(|f| format!("{}/{}.proto", proto_dir, f)) + .collect(); + prost_build::Config::new() + .out_dir("./src/parser/protobuf") + .compile_protos(&protos, &Vec::::new()) + .unwrap(); +} diff --git a/src/connector/src/parser/protobuf/.gitignore b/src/connector/src/parser/protobuf/.gitignore new file mode 100644 index 0000000000000..4109deeeb3337 --- /dev/null +++ b/src/connector/src/parser/protobuf/.gitignore @@ -0,0 +1 @@ +recursive.rs diff --git a/src/connector/src/parser/protobuf/mod.rs b/src/connector/src/parser/protobuf/mod.rs index 8870ee8f67b48..a2874e325e1fd 100644 --- a/src/connector/src/parser/protobuf/mod.rs +++ b/src/connector/src/parser/protobuf/mod.rs @@ -15,3 +15,7 @@ mod parser; pub use parser::*; mod schema_resolver; + +#[rustfmt::skip] +#[cfg(test)] +mod recursive; diff --git a/src/connector/src/parser/protobuf/parser.rs b/src/connector/src/parser/protobuf/parser.rs index 38a718d8cb814..09763bcd6a49d 100644 --- a/src/connector/src/parser/protobuf/parser.rs +++ b/src/connector/src/parser/protobuf/parser.rs @@ -300,11 +300,12 @@ fn protobuf_type_mapping( Kind::Bool => DataType::Boolean, Kind::Double => DataType::Float64, Kind::Float => DataType::Float32, - Kind::Int32 | Kind::Sint32 | Kind::Sfixed32 | Kind::Fixed32 => DataType::Int32, - Kind::Int64 | Kind::Sint64 | Kind::Sfixed64 | Kind::Fixed64 | Kind::Uint32 => { + Kind::Int32 | Kind::Sint32 | Kind::Sfixed32 => DataType::Int32, + // Fixed32 represents [0, 2^32 - 1]. It's equal to u32. + Kind::Int64 | Kind::Sint64 | Kind::Sfixed64 | Kind::Uint32 | Kind::Fixed32 => { DataType::Int64 } - Kind::Uint64 => DataType::Decimal, + Kind::Uint64 | Kind::Fixed64 => DataType::Decimal, Kind::String => DataType::Varchar, Kind::Message(m) => { let fields = m @@ -317,6 +318,12 @@ fn protobuf_type_mapping( Kind::Enum(_) => DataType::Varchar, Kind::Bytes => DataType::Bytea, }; + if field_descriptor.is_map() { + return Err(RwError::from(ProtocolError(format!( + "map type is unsupported (field: '{}')", + field_descriptor.full_name() + )))); + } if field_descriptor.cardinality() == Cardinality::Repeated { t = DataType::List(Box::new(t)) } @@ -342,15 +349,18 @@ pub(crate) fn resolve_pb_header(payload: &[u8]) -> Result<&[u8]> { #[cfg(test)] mod test { - use std::collections::HashMap; use std::path::PathBuf; - use bytes::Bytes; + use prost::Message; + use risingwave_common::types::{DataType, StructType}; use risingwave_pb::catalog::StreamSourceInfo; use risingwave_pb::data::data_type::PbTypeName; use super::*; + use crate::parser::protobuf::recursive::all_types::{EnumType, ExampleOneof, NestedMessage}; + use crate::parser::protobuf::recursive::AllTypes; + use crate::parser::unified::Access; use crate::parser::SpecificParserConfig; use crate::source::{SourceEncode, SourceFormat, SourceStruct}; @@ -498,8 +508,7 @@ mod test { assert!(columns.is_err()); } - #[tokio::test] - async fn test_all_types() { + async fn create_recursive_pb_parser_config() -> ProtobufParserConfig { let location = schema_dir() + "/proto_recursive/recursive.pb"; let message_name = "recursive.AllTypes"; @@ -515,19 +524,209 @@ mod test { &HashMap::new(), ) .unwrap(); - let conf = ProtobufParserConfig::new(parser_config.encoding_config) + + ProtobufParserConfig::new(parser_config.encoding_config) .await - .unwrap(); - // Ensure that the parser can recognize the schema. - conf.map_to_columns().unwrap(); + .unwrap() + } - let field_desc = conf - .message_descriptor - .get_field_by_name("bytes_field") - .unwrap(); - let d = from_protobuf_value(&field_desc, &Value::Bytes(Bytes::from(vec![1, 2, 3]))) + #[tokio::test] + async fn test_all_types_create_source() { + let conf = create_recursive_pb_parser_config().await; + + // Ensure that the parser can recognize the schema. + let columns = conf + .map_to_columns() .unwrap() - .unwrap(); - assert_eq!(d, ScalarImpl::Bytea(vec![1, 2, 3].into_boxed_slice())); + .into_iter() + .map(|c| DataType::from(&c.column_type.unwrap())) + .collect_vec(); + assert_eq!( + columns, + vec![ + DataType::Float64, // double_field + DataType::Float32, // float_field + DataType::Int32, // int32_field + DataType::Int64, // int64_field + DataType::Int64, // uint32_field + DataType::Decimal, // uint64_field + DataType::Int32, // sint32_field + DataType::Int64, // sint64_field + DataType::Int64, // fixed32_field + DataType::Decimal, // fixed64_field + DataType::Int32, // sfixed32_field + DataType::Int64, // sfixed64_field + DataType::Boolean, // bool_field + DataType::Varchar, // string_field + DataType::Bytea, // bytes_field + DataType::Varchar, // enum_field + DataType::Struct(StructType::new(vec![ + ("id", DataType::Int32), + ("name", DataType::Varchar) + ])), // nested_message_field + DataType::List(DataType::Int32.into()), // repeated_int_field + DataType::Varchar, // oneof_string + DataType::Int32, // oneof_int32 + DataType::Varchar, // oneof_enum + DataType::Struct(StructType::new(vec![ + ("seconds", DataType::Int64), + ("nanos", DataType::Int32) + ])), // timestamp_field + DataType::Struct(StructType::new(vec![ + ("seconds", DataType::Int64), + ("nanos", DataType::Int32) + ])), // duration_field + DataType::Struct(StructType::new(vec![ + ("type_url", DataType::Varchar), + ("value", DataType::Bytea), + ])), // any_field + DataType::Struct(StructType::new(vec![("value", DataType::Int32)])), /* int32_value_field */ + DataType::Struct(StructType::new(vec![("value", DataType::Varchar)])), /* string_value_field */ + ] + ) + } + + #[tokio::test] + async fn test_all_types_data_parsing() { + let m = create_all_types_message(); + let mut payload = Vec::new(); + m.encode(&mut payload).unwrap(); + + let conf = create_recursive_pb_parser_config().await; + let mut access_builder = ProtobufAccessBuilder::new(conf).unwrap(); + let access = access_builder.generate_accessor(payload).await.unwrap(); + if let AccessImpl::Protobuf(a) = access { + assert_all_types_eq(&a, &m); + } else { + panic!("unexpected") + } + } + + fn assert_all_types_eq(a: &ProtobufAccess, m: &AllTypes) { + type S = ScalarImpl; + + pb_eq(a, "double_field", S::Float64(m.double_field.into())); + pb_eq(a, "float_field", S::Float32(m.float_field.into())); + pb_eq(a, "int32_field", S::Int32(m.int32_field)); + pb_eq(a, "int64_field", S::Int64(m.int64_field)); + pb_eq(a, "uint32_field", S::Int64(m.uint32_field.into())); + pb_eq(a, "uint64_field", S::Decimal(m.uint64_field.into())); + pb_eq(a, "sint32_field", S::Int32(m.sint32_field)); + pb_eq(a, "sint64_field", S::Int64(m.sint64_field)); + pb_eq(a, "fixed32_field", S::Int64(m.fixed32_field.into())); + pb_eq(a, "fixed64_field", S::Decimal(m.fixed64_field.into())); + pb_eq(a, "sfixed32_field", S::Int32(m.sfixed32_field)); + pb_eq(a, "sfixed64_field", S::Int64(m.sfixed64_field)); + pb_eq(a, "bool_field", S::Bool(m.bool_field)); + pb_eq(a, "string_field", S::Utf8(m.string_field.as_str().into())); + pb_eq(a, "bytes_field", S::Bytea(m.bytes_field.clone().into())); + pb_eq(a, "enum_field", S::Utf8("OPTION1".into())); + pb_eq( + a, + "nested_message_field", + S::Struct(StructValue::new(vec![ + Some(ScalarImpl::Int32(100)), + Some(ScalarImpl::Utf8("Nested".into())), + ])), + ); + pb_eq( + a, + "repeated_int_field", + S::List(ListValue::new( + m.repeated_int_field + .iter() + .map(|&x| Some(x.into())) + .collect(), + )), + ); + pb_eq( + a, + "timestamp_field", + S::Struct(StructValue::new(vec![ + Some(ScalarImpl::Int64(1630927032)), + Some(ScalarImpl::Int32(500000000)), + ])), + ); + pb_eq( + a, + "duration_field", + S::Struct(StructValue::new(vec![ + Some(ScalarImpl::Int64(60)), + Some(ScalarImpl::Int32(500000000)), + ])), + ); + pb_eq( + a, + "any_field", + S::Struct(StructValue::new(vec![ + Some(ScalarImpl::Utf8( + m.any_field.as_ref().unwrap().type_url.as_str().into(), + )), + Some(ScalarImpl::Bytea( + m.any_field.as_ref().unwrap().value.clone().into(), + )), + ])), + ); + pb_eq( + a, + "int32_value_field", + S::Struct(StructValue::new(vec![Some(ScalarImpl::Int32(42))])), + ); + pb_eq( + a, + "string_value_field", + S::Struct(StructValue::new(vec![Some(ScalarImpl::Utf8( + m.string_value_field.as_ref().unwrap().as_str().into(), + ))])), + ); + pb_eq(a, "oneof_string", S::Utf8("".into())); + pb_eq(a, "oneof_int32", S::Int32(123)); + pb_eq(a, "oneof_enum", S::Utf8("DEFAULT".into())); + } + + fn pb_eq(a: &ProtobufAccess, field_name: &str, value: ScalarImpl) { + let d = a.access(&[field_name], None).unwrap().unwrap(); + assert_eq!(d, value, "field: {} value: {:?}", field_name, d); + } + + fn create_all_types_message() -> AllTypes { + AllTypes { + double_field: 1.2345, + float_field: 1.2345, + int32_field: 42, + int64_field: 1234567890, + uint32_field: 98765, + uint64_field: 9876543210, + sint32_field: -12345, + sint64_field: -987654321, + fixed32_field: 1234, + fixed64_field: 5678, + sfixed32_field: -56789, + sfixed64_field: -123456, + bool_field: true, + string_field: "Hello, Prost!".to_string(), + bytes_field: b"byte data".to_vec(), + enum_field: EnumType::Option1 as i32, + nested_message_field: Some(NestedMessage { + id: 100, + name: "Nested".to_string(), + }), + repeated_int_field: vec![1, 2, 3, 4, 5], + timestamp_field: Some(::prost_types::Timestamp { + seconds: 1630927032, + nanos: 500000000, + }), + duration_field: Some(::prost_types::Duration { + seconds: 60, + nanos: 500000000, + }), + any_field: Some(::prost_types::Any { + type_url: "type.googleapis.com/my_custom_type".to_string(), + value: b"My custom data".to_vec(), + }), + int32_value_field: Some(42), + string_value_field: Some("Hello, Wrapper!".to_string()), + example_oneof: Some(ExampleOneof::OneofInt32(123)), + } } } diff --git a/src/connector/src/test_data/proto_recursive/recursive.pb b/src/connector/src/test_data/proto_recursive/recursive.pb index eb5e822e12956..5c611c18d0d30 100644 Binary files a/src/connector/src/test_data/proto_recursive/recursive.pb and b/src/connector/src/test_data/proto_recursive/recursive.pb differ diff --git a/src/connector/src/test_data/proto_recursive/recursive.proto b/src/connector/src/test_data/proto_recursive/recursive.proto index fe664c8d59cf6..93f177055788c 100644 --- a/src/connector/src/test_data/proto_recursive/recursive.proto +++ b/src/connector/src/test_data/proto_recursive/recursive.proto @@ -72,8 +72,8 @@ message AllTypes { EnumType oneof_enum = 21; } - // map field - map map_field = 22; + // // map field + // map map_field = 22; // timestamp google.protobuf.Timestamp timestamp_field = 23;