Skip to content

Commit

Permalink
feat(source): support ingesting protobuf map
Browse files Browse the repository at this point in the history
  • Loading branch information
xxchan committed Sep 13, 2024
1 parent 2fc2b5d commit 1833dff
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 27 deletions.
71 changes: 53 additions & 18 deletions src/connector/codec/src/decoder/protobuf/parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use itertools::Itertools;
use prost_reflect::{Cardinality, FieldDescriptor, Kind, MessageDescriptor, ReflectMessage, Value};
use risingwave_common::array::{ListValue, StructValue};
use risingwave_common::types::{
DataType, DatumCow, Decimal, JsonbVal, ScalarImpl, ToOwnedDatum, F32, F64,
DataType, DatumCow, Decimal, JsonbVal, MapType, MapValue, ScalarImpl, ToOwnedDatum, F32, F64,
};
use risingwave_pb::plan_common::{AdditionalColumn, ColumnDesc, ColumnDescVersion};
use thiserror::Error;
Expand Down Expand Up @@ -180,10 +180,43 @@ pub fn from_protobuf_value<'a>(
ScalarImpl::List(ListValue::new(builder.finish()))
}
Value::Bytes(value) => borrowed!(&**value),
_ => {
return Err(AccessError::UnsupportedType {
ty: format!("{kind:?}"),
});
Value::Map(map) => {
let err = || {
AccessError::TypeError {
expected: type_expected.to_string(),
got: format!("{:?}", kind),
value: value.to_string(), // Protobuf TEXT
}
};

let DataType::Map(map_type) = type_expected else {
return Err(err());
};
let map_desc = kind.as_message().ok_or_else(err)?;
if !map_desc.is_map_entry() {
return Err(err());
}

let mut key_builder = map_type.key().create_array_builder(map.len());
let mut value_builder = map_type.value().create_array_builder(map.len());
// NOTE: HashMap's iteration order is non-deterministic, but the order of
// entries in a MapValue matters. We sort by key here to get a deterministic
// order in tests. We might consider removing this, or making all MapValues
// sorted, in the future.
for (key, value) in map.into_iter().sorted_by_key(|(k, _v)| *k) {
key_builder.append(from_protobuf_value(
field_desc,
&key.clone().into(),
map_type.key(),
)?);
value_builder.append(from_protobuf_value(field_desc, &value, map_type.value())?);
}
let keys = key_builder.finish();
let values = value_builder.finish();
ScalarImpl::Map(
MapValue::try_from_kv(ListValue::new(keys), ListValue::new(values))
.map_err(|e| uncategorized!("failed to convert protobuf map: {e}"))?,
)
}
};
Ok(Some(v).into())
Expand All @@ -195,8 +228,7 @@ fn protobuf_type_mapping(
parse_trace: &mut Vec<String>,
) -> std::result::Result<DataType, ProtobufTypeError> {
detect_loop_and_push(parse_trace, field_descriptor)?;
let field_type = field_descriptor.kind();
let mut t = match field_type {
let mut t = match field_descriptor.kind() {
Kind::Bool => DataType::Boolean,
Kind::Double => DataType::Float64,
Kind::Float => DataType::Float32,
Expand All @@ -207,28 +239,31 @@ fn protobuf_type_mapping(
}
Kind::Uint64 | Kind::Fixed64 => DataType::Decimal,
Kind::String => DataType::Varchar,
Kind::Message(m) => match m.full_name() {
// Well-Known Types are identified by their full name
"google.protobuf.Any" => DataType::Jsonb,
_ => {
Kind::Message(m) => {
if m.full_name() == "google.protobuf.Any" {
// Well-Known Types are identified by their full name
DataType::Jsonb
} else if m.is_map_entry() {
// Map is equivalent to `repeated MapFieldEntry map_field = N;`
debug_assert!(field_descriptor.is_map());
let key = protobuf_type_mapping(&m.map_entry_key_field(), parse_trace)?;
let value = protobuf_type_mapping(&m.map_entry_value_field(), parse_trace)?;
_ = parse_trace.pop();
return Ok(DataType::Map(MapType::from_kv(key, value)));
} else {
let fields = m
.fields()
.map(|f| protobuf_type_mapping(&f, parse_trace))
.try_collect()?;
let field_names = m.fields().map(|f| f.name().to_string()).collect_vec();
DataType::new_struct(fields, field_names)
}
},
}
Kind::Enum(_) => DataType::Varchar,
Kind::Bytes => DataType::Bytea,
};
if field_descriptor.is_map() {
bail_protobuf_type_error!(
"protobuf map type (on field `{}`) is not supported",
field_descriptor.full_name()
);
}
if field_descriptor.cardinality() == Cardinality::Repeated {
debug_assert!(!field_descriptor.is_map());
t = DataType::List(Box::new(t))
}
_ = parse_trace.pop();
Expand Down
35 changes: 28 additions & 7 deletions src/connector/codec/tests/integration_tests/protobuf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ mod recursive;
#[rustfmt::skip]
#[allow(clippy::all)]
mod all_types;
use std::collections::HashMap;
use std::path::PathBuf;

use anyhow::Context;
Expand Down Expand Up @@ -516,6 +517,11 @@ fn test_all_types() -> anyhow::Result<()> {
name: "Nested".to_string(),
}),
repeated_int_field: vec![1, 2, 3, 4, 5],
map_field: HashMap::from_iter([
("key1".to_string(), 1),
("key2".to_string(), 2),
("key3".to_string(), 3),
]),
timestamp_field: Some(::prost_types::Timestamp {
seconds: 1630927032,
nanos: 500000000,
Expand Down Expand Up @@ -565,17 +571,18 @@ fn test_all_types() -> anyhow::Result<()> {
oneof_string(#21): Varchar,
oneof_int32(#22): Int32,
oneof_enum(#23): Varchar,
timestamp_field(#26): Struct {
map_field(#26): Map(Varchar,Int32), type_name: all_types.AllTypes.MapFieldEntry, field_descs: [key(#24): Varchar, value(#25): Int32],
timestamp_field(#29): Struct {
seconds: Int64,
nanos: Int32,
}, type_name: google.protobuf.Timestamp, field_descs: [seconds(#24): Int64, nanos(#25): Int32],
duration_field(#29): Struct {
}, type_name: google.protobuf.Timestamp, field_descs: [seconds(#27): Int64, nanos(#28): Int32],
duration_field(#32): Struct {
seconds: Int64,
nanos: Int32,
}, type_name: google.protobuf.Duration, field_descs: [seconds(#27): Int64, nanos(#28): Int32],
any_field(#32): Jsonb, type_name: google.protobuf.Any, field_descs: [type_url(#30): Varchar, value(#31): Bytea],
int32_value_field(#34): Struct { value: Int32 }, type_name: google.protobuf.Int32Value, field_descs: [value(#33): Int32],
string_value_field(#36): Struct { value: Varchar }, type_name: google.protobuf.StringValue, field_descs: [value(#35): Varchar],
}, type_name: google.protobuf.Duration, field_descs: [seconds(#30): Int64, nanos(#31): Int32],
any_field(#35): Jsonb, type_name: google.protobuf.Any, field_descs: [type_url(#33): Varchar, value(#34): Bytea],
int32_value_field(#37): Struct { value: Int32 }, type_name: google.protobuf.Int32Value, field_descs: [value(#36): Int32],
string_value_field(#39): Struct { value: Varchar }, type_name: google.protobuf.StringValue, field_descs: [value(#38): Varchar],
]"#]],
expect![[r#"
Owned(Float64(OrderedFloat(1.2345)))
Expand Down Expand Up @@ -608,6 +615,20 @@ fn test_all_types() -> anyhow::Result<()> {
Owned(Utf8(""))
Owned(Int32(123))
Owned(Utf8("DEFAULT"))
Owned([
StructValue(
Utf8("key1"),
Int32(1),
),
StructValue(
Utf8("key2"),
Int32(2),
),
StructValue(
Utf8("key3"),
Int32(3),
),
])
Owned(StructValue(
Int64(1630927032),
Int32(500000000),
Expand Down
4 changes: 2 additions & 2 deletions src/connector/codec/tests/test_data/all-types.proto
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@ message AllTypes {
EnumType oneof_enum = 21;
}

// // map field
// map<string, int32> map_field = 22;
// map field
map<string, int32> map_field = 22;

// timestamp
google.protobuf.Timestamp timestamp_field = 23;
Expand Down

0 comments on commit 1833dff

Please sign in to comment.