Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: protobuf fixed64 was incorrectly parsed as int64 #12126

Merged
merged 5 commits into from
Sep 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 6 additions & 2 deletions src/connector/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,8 @@ num-bigint = "0.4"
opendal = "0.39"
parking_lot = "0.12"
prometheus = { version = "0.13", features = ["process"] }
prost = { version = "0.11.9", features = ["no-recursion-limit"] }
prost-reflect = "0.11.5"
prost = { version = "0.11", features = ["no-recursion-limit"] }
prost-reflect = "0.11"
protobuf-native = "0.2.1"
pulsar = { version = "6.0", default-features = false, features = [
"tokio-runtime",
Expand Down Expand Up @@ -113,9 +113,13 @@ workspace-hack = { path = "../workspace-hack" }

[dev-dependencies]
criterion = { workspace = true, features = ["async_tokio", "async"] }
prost-types = "0.11"
rand = "0.8"
tempfile = "3"

[build-dependencies]
prost-build = "0.11"

[[bench]]
name = "parser"
harness = false
29 changes: 29 additions & 0 deletions src/connector/build.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
// Copyright 2023 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

fn main() {
let proto_dir = "./src/test_data/proto_recursive";

println!("cargo:rerun-if-changed={}", proto_dir);

let proto_files = vec!["recursive"];
let protos: Vec<String> = proto_files
.iter()
.map(|f| format!("{}/{}.proto", proto_dir, f))
.collect();
prost_build::Config::new()
.out_dir("./src/parser/protobuf")
.compile_protos(&protos, &Vec::<String>::new())
.unwrap();
}
1 change: 1 addition & 0 deletions src/connector/src/parser/protobuf/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
recursive.rs
4 changes: 4 additions & 0 deletions src/connector/src/parser/protobuf/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,7 @@
mod parser;
pub use parser::*;
mod schema_resolver;

#[rustfmt::skip]
#[cfg(test)]
mod recursive;
235 changes: 217 additions & 18 deletions src/connector/src/parser/protobuf/parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -300,11 +300,12 @@ fn protobuf_type_mapping(
Kind::Bool => DataType::Boolean,
Kind::Double => DataType::Float64,
Kind::Float => DataType::Float32,
Kind::Int32 | Kind::Sint32 | Kind::Sfixed32 | Kind::Fixed32 => DataType::Int32,
Kind::Int64 | Kind::Sint64 | Kind::Sfixed64 | Kind::Fixed64 | Kind::Uint32 => {
Kind::Int32 | Kind::Sint32 | Kind::Sfixed32 => DataType::Int32,
// Fixed32 represents [0, 2^32 - 1]. It's equal to u32.
Kind::Int64 | Kind::Sint64 | Kind::Sfixed64 | Kind::Uint32 | Kind::Fixed32 => {
DataType::Int64
}
Kind::Uint64 => DataType::Decimal,
Kind::Uint64 | Kind::Fixed64 => DataType::Decimal,
Kind::String => DataType::Varchar,
Kind::Message(m) => {
let fields = m
Expand All @@ -317,6 +318,12 @@ fn protobuf_type_mapping(
Kind::Enum(_) => DataType::Varchar,
Kind::Bytes => DataType::Bytea,
};
if field_descriptor.is_map() {
return Err(RwError::from(ProtocolError(format!(
"map type is unsupported (field: '{}')",
field_descriptor.full_name()
neverchanje marked this conversation as resolved.
Show resolved Hide resolved
))));
}
if field_descriptor.cardinality() == Cardinality::Repeated {
t = DataType::List(Box::new(t))
}
Expand All @@ -342,15 +349,18 @@ pub(crate) fn resolve_pb_header(payload: &[u8]) -> Result<&[u8]> {

#[cfg(test)]
mod test {

use std::collections::HashMap;
use std::path::PathBuf;

use bytes::Bytes;
use prost::Message;
use risingwave_common::types::{DataType, StructType};
use risingwave_pb::catalog::StreamSourceInfo;
use risingwave_pb::data::data_type::PbTypeName;

use super::*;
use crate::parser::protobuf::recursive::all_types::{EnumType, ExampleOneof, NestedMessage};
use crate::parser::protobuf::recursive::AllTypes;
use crate::parser::unified::Access;
use crate::parser::SpecificParserConfig;
use crate::source::{SourceEncode, SourceFormat, SourceStruct};

Expand Down Expand Up @@ -498,8 +508,7 @@ mod test {
assert!(columns.is_err());
}

#[tokio::test]
async fn test_all_types() {
async fn create_recursive_pb_parser_config() -> ProtobufParserConfig {
let location = schema_dir() + "/proto_recursive/recursive.pb";
let message_name = "recursive.AllTypes";

Expand All @@ -515,19 +524,209 @@ mod test {
&HashMap::new(),
)
.unwrap();
let conf = ProtobufParserConfig::new(parser_config.encoding_config)

ProtobufParserConfig::new(parser_config.encoding_config)
.await
.unwrap();
// Ensure that the parser can recognize the schema.
conf.map_to_columns().unwrap();
.unwrap()
}

let field_desc = conf
.message_descriptor
.get_field_by_name("bytes_field")
.unwrap();
let d = from_protobuf_value(&field_desc, &Value::Bytes(Bytes::from(vec![1, 2, 3])))
#[tokio::test]
async fn test_all_types_create_source() {
let conf = create_recursive_pb_parser_config().await;

// Ensure that the parser can recognize the schema.
let columns = conf
.map_to_columns()
.unwrap()
.unwrap();
assert_eq!(d, ScalarImpl::Bytea(vec![1, 2, 3].into_boxed_slice()));
.into_iter()
.map(|c| DataType::from(&c.column_type.unwrap()))
.collect_vec();
assert_eq!(
columns,
vec![
DataType::Float64, // double_field
DataType::Float32, // float_field
DataType::Int32, // int32_field
DataType::Int64, // int64_field
DataType::Int64, // uint32_field
DataType::Decimal, // uint64_field
DataType::Int32, // sint32_field
DataType::Int64, // sint64_field
DataType::Int64, // fixed32_field
DataType::Decimal, // fixed64_field
DataType::Int32, // sfixed32_field
DataType::Int64, // sfixed64_field
DataType::Boolean, // bool_field
DataType::Varchar, // string_field
DataType::Bytea, // bytes_field
DataType::Varchar, // enum_field
DataType::Struct(StructType::new(vec![
("id", DataType::Int32),
("name", DataType::Varchar)
])), // nested_message_field
DataType::List(DataType::Int32.into()), // repeated_int_field
DataType::Varchar, // oneof_string
DataType::Int32, // oneof_int32
DataType::Varchar, // oneof_enum
DataType::Struct(StructType::new(vec![
("seconds", DataType::Int64),
("nanos", DataType::Int32)
])), // timestamp_field
DataType::Struct(StructType::new(vec![
("seconds", DataType::Int64),
("nanos", DataType::Int32)
])), // duration_field
DataType::Struct(StructType::new(vec![
("type_url", DataType::Varchar),
("value", DataType::Bytea),
])), // any_field
DataType::Struct(StructType::new(vec![("value", DataType::Int32)])), /* int32_value_field */
DataType::Struct(StructType::new(vec![("value", DataType::Varchar)])), /* string_value_field */
]
)
}

#[tokio::test]
async fn test_all_types_data_parsing() {
let m = create_all_types_message();
let mut payload = Vec::new();
m.encode(&mut payload).unwrap();

let conf = create_recursive_pb_parser_config().await;
let mut access_builder = ProtobufAccessBuilder::new(conf).unwrap();
let access = access_builder.generate_accessor(payload).await.unwrap();
if let AccessImpl::Protobuf(a) = access {
assert_all_types_eq(&a, &m);
} else {
panic!("unexpected")
}
}

fn assert_all_types_eq(a: &ProtobufAccess, m: &AllTypes) {
type S = ScalarImpl;

pb_eq(a, "double_field", S::Float64(m.double_field.into()));
pb_eq(a, "float_field", S::Float32(m.float_field.into()));
pb_eq(a, "int32_field", S::Int32(m.int32_field));
pb_eq(a, "int64_field", S::Int64(m.int64_field));
pb_eq(a, "uint32_field", S::Int64(m.uint32_field.into()));
pb_eq(a, "uint64_field", S::Decimal(m.uint64_field.into()));
pb_eq(a, "sint32_field", S::Int32(m.sint32_field));
pb_eq(a, "sint64_field", S::Int64(m.sint64_field));
pb_eq(a, "fixed32_field", S::Int64(m.fixed32_field.into()));
pb_eq(a, "fixed64_field", S::Decimal(m.fixed64_field.into()));
pb_eq(a, "sfixed32_field", S::Int32(m.sfixed32_field));
pb_eq(a, "sfixed64_field", S::Int64(m.sfixed64_field));
pb_eq(a, "bool_field", S::Bool(m.bool_field));
pb_eq(a, "string_field", S::Utf8(m.string_field.as_str().into()));
pb_eq(a, "bytes_field", S::Bytea(m.bytes_field.clone().into()));
pb_eq(a, "enum_field", S::Utf8("OPTION1".into()));
pb_eq(
a,
"nested_message_field",
S::Struct(StructValue::new(vec![
Some(ScalarImpl::Int32(100)),
Some(ScalarImpl::Utf8("Nested".into())),
])),
);
pb_eq(
a,
"repeated_int_field",
S::List(ListValue::new(
m.repeated_int_field
.iter()
.map(|&x| Some(x.into()))
.collect(),
)),
);
pb_eq(
a,
"timestamp_field",
S::Struct(StructValue::new(vec![
Some(ScalarImpl::Int64(1630927032)),
Some(ScalarImpl::Int32(500000000)),
])),
);
pb_eq(
a,
"duration_field",
S::Struct(StructValue::new(vec![
Some(ScalarImpl::Int64(60)),
Some(ScalarImpl::Int32(500000000)),
])),
);
pb_eq(
a,
"any_field",
S::Struct(StructValue::new(vec![
Some(ScalarImpl::Utf8(
m.any_field.as_ref().unwrap().type_url.as_str().into(),
)),
Some(ScalarImpl::Bytea(
m.any_field.as_ref().unwrap().value.clone().into(),
)),
])),
);
pb_eq(
a,
"int32_value_field",
S::Struct(StructValue::new(vec![Some(ScalarImpl::Int32(42))])),
);
pb_eq(
a,
"string_value_field",
S::Struct(StructValue::new(vec![Some(ScalarImpl::Utf8(
m.string_value_field.as_ref().unwrap().as_str().into(),
))])),
);
pb_eq(a, "oneof_string", S::Utf8("".into()));
pb_eq(a, "oneof_int32", S::Int32(123));
pb_eq(a, "oneof_enum", S::Utf8("DEFAULT".into()));
}

fn pb_eq(a: &ProtobufAccess, field_name: &str, value: ScalarImpl) {
let d = a.access(&[field_name], None).unwrap().unwrap();
assert_eq!(d, value, "field: {} value: {:?}", field_name, d);
}

fn create_all_types_message() -> AllTypes {
AllTypes {
double_field: 1.2345,
float_field: 1.2345,
int32_field: 42,
int64_field: 1234567890,
uint32_field: 98765,
uint64_field: 9876543210,
sint32_field: -12345,
sint64_field: -987654321,
fixed32_field: 1234,
fixed64_field: 5678,
sfixed32_field: -56789,
sfixed64_field: -123456,
bool_field: true,
string_field: "Hello, Prost!".to_string(),
bytes_field: b"byte data".to_vec(),
enum_field: EnumType::Option1 as i32,
nested_message_field: Some(NestedMessage {
id: 100,
name: "Nested".to_string(),
}),
repeated_int_field: vec![1, 2, 3, 4, 5],
timestamp_field: Some(::prost_types::Timestamp {
seconds: 1630927032,
nanos: 500000000,
}),
duration_field: Some(::prost_types::Duration {
seconds: 60,
nanos: 500000000,
}),
any_field: Some(::prost_types::Any {
type_url: "type.googleapis.com/my_custom_type".to_string(),
value: b"My custom data".to_vec(),
}),
int32_value_field: Some(42),
string_value_field: Some("Hello, Wrapper!".to_string()),
example_oneof: Some(ExampleOneof::OneofInt32(123)),
}
}
}
Binary file modified src/connector/src/test_data/proto_recursive/recursive.pb
Binary file not shown.
4 changes: 2 additions & 2 deletions src/connector/src/test_data/proto_recursive/recursive.proto
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,8 @@ message AllTypes {
EnumType oneof_enum = 21;
}

// map field
map<string, int32> map_field = 22;
// // map field
// map<string, int32> map_field = 22;

// timestamp
google.protobuf.Timestamp timestamp_field = 23;
Expand Down