Skip to content

Commit

Permalink
fix: protobuf fixed64 was incorrectly parsed as int64 (#12126) (#12182)
Browse files Browse the repository at this point in the history
  • Loading branch information
neverchanje authored Sep 8, 2023
1 parent 83ba86a commit 03281e5
Show file tree
Hide file tree
Showing 8 changed files with 261 additions and 22 deletions.
2 changes: 2 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 6 additions & 2 deletions src/connector/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,8 @@ num-bigint = "0.4"
opendal = "0.39"
parking_lot = "0.12"
prometheus = { version = "0.13", features = ["process"] }
prost = { version = "0.11.9", features = ["no-recursion-limit"] }
prost-reflect = "0.11.4"
prost = { version = "0.11", features = ["no-recursion-limit"] }
prost-reflect = "0.11"
protobuf-native = "0.2.1"
pulsar = { version = "6.0", default-features = false, features = [
"tokio-runtime",
Expand Down Expand Up @@ -113,9 +113,13 @@ workspace-hack = { path = "../workspace-hack" }

[dev-dependencies]
criterion = { workspace = true, features = ["async_tokio", "async"] }
prost-types = "0.11"
rand = "0.8"
tempfile = "3"

[build-dependencies]
prost-build = "0.11"

[[bench]]
name = "parser"
harness = false
29 changes: 29 additions & 0 deletions src/connector/build.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
// Copyright 2023 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

fn main() {
let proto_dir = "./src/test_data/proto_recursive";

println!("cargo:rerun-if-changed={}", proto_dir);

let proto_files = vec!["recursive"];
let protos: Vec<String> = proto_files
.iter()
.map(|f| format!("{}/{}.proto", proto_dir, f))
.collect();
prost_build::Config::new()
.out_dir("./src/parser/protobuf")
.compile_protos(&protos, &Vec::<String>::new())
.unwrap();
}
1 change: 1 addition & 0 deletions src/connector/src/parser/protobuf/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
recursive.rs
4 changes: 4 additions & 0 deletions src/connector/src/parser/protobuf/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,7 @@
mod parser;
pub use parser::*;
mod schema_resolver;

#[rustfmt::skip]
#[cfg(test)]
mod recursive;
235 changes: 217 additions & 18 deletions src/connector/src/parser/protobuf/parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -300,11 +300,12 @@ fn protobuf_type_mapping(
Kind::Bool => DataType::Boolean,
Kind::Double => DataType::Float64,
Kind::Float => DataType::Float32,
Kind::Int32 | Kind::Sint32 | Kind::Sfixed32 | Kind::Fixed32 => DataType::Int32,
Kind::Int64 | Kind::Sint64 | Kind::Sfixed64 | Kind::Fixed64 | Kind::Uint32 => {
Kind::Int32 | Kind::Sint32 | Kind::Sfixed32 => DataType::Int32,
// Fixed32 represents [0, 2^32 - 1]. It's equal to u32.
Kind::Int64 | Kind::Sint64 | Kind::Sfixed64 | Kind::Uint32 | Kind::Fixed32 => {
DataType::Int64
}
Kind::Uint64 => DataType::Decimal,
Kind::Uint64 | Kind::Fixed64 => DataType::Decimal,
Kind::String => DataType::Varchar,
Kind::Message(m) => {
let fields = m
Expand All @@ -317,6 +318,12 @@ fn protobuf_type_mapping(
Kind::Enum(_) => DataType::Varchar,
Kind::Bytes => DataType::Bytea,
};
if field_descriptor.is_map() {
return Err(RwError::from(ProtocolError(format!(
"map type is unsupported (field: '{}')",
field_descriptor.full_name()
))));
}
if field_descriptor.cardinality() == Cardinality::Repeated {
t = DataType::List(Box::new(t))
}
Expand All @@ -342,15 +349,18 @@ pub(crate) fn resolve_pb_header(payload: &[u8]) -> Result<&[u8]> {

#[cfg(test)]
mod test {

use std::collections::HashMap;
use std::path::PathBuf;

use bytes::Bytes;
use prost::Message;
use risingwave_common::types::{DataType, StructType};
use risingwave_pb::catalog::StreamSourceInfo;
use risingwave_pb::data::data_type::PbTypeName;

use super::*;
use crate::parser::protobuf::recursive::all_types::{EnumType, ExampleOneof, NestedMessage};
use crate::parser::protobuf::recursive::AllTypes;
use crate::parser::unified::Access;
use crate::parser::SpecificParserConfig;
use crate::source::{SourceEncode, SourceFormat, SourceStruct};

Expand Down Expand Up @@ -498,8 +508,7 @@ mod test {
assert!(columns.is_err());
}

#[tokio::test]
async fn test_all_types() {
async fn create_recursive_pb_parser_config() -> ProtobufParserConfig {
let location = schema_dir() + "/proto_recursive/recursive.pb";
let message_name = "recursive.AllTypes";

Expand All @@ -515,19 +524,209 @@ mod test {
&HashMap::new(),
)
.unwrap();
let conf = ProtobufParserConfig::new(parser_config.encoding_config)

ProtobufParserConfig::new(parser_config.encoding_config)
.await
.unwrap();
// Ensure that the parser can recognize the schema.
conf.map_to_columns().unwrap();
.unwrap()
}

let field_desc = conf
.message_descriptor
.get_field_by_name("bytes_field")
.unwrap();
let d = from_protobuf_value(&field_desc, &Value::Bytes(Bytes::from(vec![1, 2, 3])))
#[tokio::test]
async fn test_all_types_create_source() {
let conf = create_recursive_pb_parser_config().await;

// Ensure that the parser can recognize the schema.
let columns = conf
.map_to_columns()
.unwrap()
.unwrap();
assert_eq!(d, ScalarImpl::Bytea(vec![1, 2, 3].into_boxed_slice()));
.into_iter()
.map(|c| DataType::from(&c.column_type.unwrap()))
.collect_vec();
assert_eq!(
columns,
vec![
DataType::Float64, // double_field
DataType::Float32, // float_field
DataType::Int32, // int32_field
DataType::Int64, // int64_field
DataType::Int64, // uint32_field
DataType::Decimal, // uint64_field
DataType::Int32, // sint32_field
DataType::Int64, // sint64_field
DataType::Int64, // fixed32_field
DataType::Decimal, // fixed64_field
DataType::Int32, // sfixed32_field
DataType::Int64, // sfixed64_field
DataType::Boolean, // bool_field
DataType::Varchar, // string_field
DataType::Bytea, // bytes_field
DataType::Varchar, // enum_field
DataType::Struct(StructType::new(vec![
("id", DataType::Int32),
("name", DataType::Varchar)
])), // nested_message_field
DataType::List(DataType::Int32.into()), // repeated_int_field
DataType::Varchar, // oneof_string
DataType::Int32, // oneof_int32
DataType::Varchar, // oneof_enum
DataType::Struct(StructType::new(vec![
("seconds", DataType::Int64),
("nanos", DataType::Int32)
])), // timestamp_field
DataType::Struct(StructType::new(vec![
("seconds", DataType::Int64),
("nanos", DataType::Int32)
])), // duration_field
DataType::Struct(StructType::new(vec![
("type_url", DataType::Varchar),
("value", DataType::Bytea),
])), // any_field
DataType::Struct(StructType::new(vec![("value", DataType::Int32)])), /* int32_value_field */
DataType::Struct(StructType::new(vec![("value", DataType::Varchar)])), /* string_value_field */
]
)
}

#[tokio::test]
async fn test_all_types_data_parsing() {
let m = create_all_types_message();
let mut payload = Vec::new();
m.encode(&mut payload).unwrap();

let conf = create_recursive_pb_parser_config().await;
let mut access_builder = ProtobufAccessBuilder::new(conf).unwrap();
let access = access_builder.generate_accessor(payload).await.unwrap();
if let AccessImpl::Protobuf(a) = access {
assert_all_types_eq(&a, &m);
} else {
panic!("unexpected")
}
}

fn assert_all_types_eq(a: &ProtobufAccess, m: &AllTypes) {
type S = ScalarImpl;

pb_eq(a, "double_field", S::Float64(m.double_field.into()));
pb_eq(a, "float_field", S::Float32(m.float_field.into()));
pb_eq(a, "int32_field", S::Int32(m.int32_field));
pb_eq(a, "int64_field", S::Int64(m.int64_field));
pb_eq(a, "uint32_field", S::Int64(m.uint32_field.into()));
pb_eq(a, "uint64_field", S::Decimal(m.uint64_field.into()));
pb_eq(a, "sint32_field", S::Int32(m.sint32_field));
pb_eq(a, "sint64_field", S::Int64(m.sint64_field));
pb_eq(a, "fixed32_field", S::Int64(m.fixed32_field.into()));
pb_eq(a, "fixed64_field", S::Decimal(m.fixed64_field.into()));
pb_eq(a, "sfixed32_field", S::Int32(m.sfixed32_field));
pb_eq(a, "sfixed64_field", S::Int64(m.sfixed64_field));
pb_eq(a, "bool_field", S::Bool(m.bool_field));
pb_eq(a, "string_field", S::Utf8(m.string_field.as_str().into()));
pb_eq(a, "bytes_field", S::Bytea(m.bytes_field.clone().into()));
pb_eq(a, "enum_field", S::Utf8("OPTION1".into()));
pb_eq(
a,
"nested_message_field",
S::Struct(StructValue::new(vec![
Some(ScalarImpl::Int32(100)),
Some(ScalarImpl::Utf8("Nested".into())),
])),
);
pb_eq(
a,
"repeated_int_field",
S::List(ListValue::new(
m.repeated_int_field
.iter()
.map(|&x| Some(x.into()))
.collect(),
)),
);
pb_eq(
a,
"timestamp_field",
S::Struct(StructValue::new(vec![
Some(ScalarImpl::Int64(1630927032)),
Some(ScalarImpl::Int32(500000000)),
])),
);
pb_eq(
a,
"duration_field",
S::Struct(StructValue::new(vec![
Some(ScalarImpl::Int64(60)),
Some(ScalarImpl::Int32(500000000)),
])),
);
pb_eq(
a,
"any_field",
S::Struct(StructValue::new(vec![
Some(ScalarImpl::Utf8(
m.any_field.as_ref().unwrap().type_url.as_str().into(),
)),
Some(ScalarImpl::Bytea(
m.any_field.as_ref().unwrap().value.clone().into(),
)),
])),
);
pb_eq(
a,
"int32_value_field",
S::Struct(StructValue::new(vec![Some(ScalarImpl::Int32(42))])),
);
pb_eq(
a,
"string_value_field",
S::Struct(StructValue::new(vec![Some(ScalarImpl::Utf8(
m.string_value_field.as_ref().unwrap().as_str().into(),
))])),
);
pb_eq(a, "oneof_string", S::Utf8("".into()));
pb_eq(a, "oneof_int32", S::Int32(123));
pb_eq(a, "oneof_enum", S::Utf8("DEFAULT".into()));
}

fn pb_eq(a: &ProtobufAccess, field_name: &str, value: ScalarImpl) {
let d = a.access(&[field_name], None).unwrap().unwrap();
assert_eq!(d, value, "field: {} value: {:?}", field_name, d);
}

fn create_all_types_message() -> AllTypes {
AllTypes {
double_field: 1.2345,
float_field: 1.2345,
int32_field: 42,
int64_field: 1234567890,
uint32_field: 98765,
uint64_field: 9876543210,
sint32_field: -12345,
sint64_field: -987654321,
fixed32_field: 1234,
fixed64_field: 5678,
sfixed32_field: -56789,
sfixed64_field: -123456,
bool_field: true,
string_field: "Hello, Prost!".to_string(),
bytes_field: b"byte data".to_vec(),
enum_field: EnumType::Option1 as i32,
nested_message_field: Some(NestedMessage {
id: 100,
name: "Nested".to_string(),
}),
repeated_int_field: vec![1, 2, 3, 4, 5],
timestamp_field: Some(::prost_types::Timestamp {
seconds: 1630927032,
nanos: 500000000,
}),
duration_field: Some(::prost_types::Duration {
seconds: 60,
nanos: 500000000,
}),
any_field: Some(::prost_types::Any {
type_url: "type.googleapis.com/my_custom_type".to_string(),
value: b"My custom data".to_vec(),
}),
int32_value_field: Some(42),
string_value_field: Some("Hello, Wrapper!".to_string()),
example_oneof: Some(ExampleOneof::OneofInt32(123)),
}
}
}
Binary file modified src/connector/src/test_data/proto_recursive/recursive.pb
Binary file not shown.
4 changes: 2 additions & 2 deletions src/connector/src/test_data/proto_recursive/recursive.proto
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,8 @@ message AllTypes {
EnumType oneof_enum = 21;
}

// map field
map<string, int32> map_field = 22;
// // map field
// map<string, int32> map_field = 22;

// timestamp
google.protobuf.Timestamp timestamp_field = 23;
Expand Down

0 comments on commit 03281e5

Please sign in to comment.