diff --git a/src/connector/codec/src/decoder/protobuf/parser.rs b/src/connector/codec/src/decoder/protobuf/parser.rs index 15778727fc46..852fa9cca48d 100644 --- a/src/connector/codec/src/decoder/protobuf/parser.rs +++ b/src/connector/codec/src/decoder/protobuf/parser.rs @@ -17,7 +17,7 @@ use itertools::Itertools; use prost_reflect::{Cardinality, FieldDescriptor, Kind, MessageDescriptor, ReflectMessage, Value}; use risingwave_common::array::{ListValue, StructValue}; use risingwave_common::types::{ - DataType, DatumCow, Decimal, JsonbVal, ScalarImpl, ToOwnedDatum, F32, F64, + DataType, DatumCow, Decimal, JsonbVal, MapType, MapValue, ScalarImpl, ToOwnedDatum, F32, F64, }; use risingwave_pb::plan_common::{AdditionalColumn, ColumnDesc, ColumnDescVersion}; use thiserror::Error; @@ -180,10 +180,47 @@ pub fn from_protobuf_value<'a>( ScalarImpl::List(ListValue::new(builder.finish())) } Value::Bytes(value) => borrowed!(&**value), - _ => { - return Err(AccessError::UnsupportedType { - ty: format!("{kind:?}"), - }); + Value::Map(map) => { + let err = || { + AccessError::TypeError { + expected: type_expected.to_string(), + got: format!("{:?}", kind), + value: value.to_string(), // Protobuf TEXT + } + }; + + let DataType::Map(map_type) = type_expected else { + return Err(err()); + }; + if !field_desc.is_map() { + return Err(err()); + } + let map_desc = kind.as_message().ok_or_else(err)?; + + let mut key_builder = map_type.key().create_array_builder(map.len()); + let mut value_builder = map_type.value().create_array_builder(map.len()); + // NOTE: HashMap's iter order is non-deterministic, but MapValue's + // order matters. We sort by key here to have deterministic order + // in tests. We might consider removing this, or make all MapValue sorted + // in the future. + for (key, value) in map.iter().sorted_by_key(|(k, _v)| *k) { + key_builder.append(from_protobuf_value( + &map_desc.map_entry_key_field(), + &key.clone().into(), + map_type.key(), + )?); + value_builder.append(from_protobuf_value( + &map_desc.map_entry_value_field(), + value, + map_type.value(), + )?); + } + let keys = key_builder.finish(); + let values = value_builder.finish(); + ScalarImpl::Map( + MapValue::try_from_kv(ListValue::new(keys), ListValue::new(values)) + .map_err(|e| uncategorized!("failed to convert protobuf map: {e}"))?, + ) } }; Ok(Some(v).into()) @@ -195,8 +232,7 @@ fn protobuf_type_mapping( parse_trace: &mut Vec, ) -> std::result::Result { detect_loop_and_push(parse_trace, field_descriptor)?; - let field_type = field_descriptor.kind(); - let mut t = match field_type { + let mut t = match field_descriptor.kind() { Kind::Bool => DataType::Boolean, Kind::Double => DataType::Float64, Kind::Float => DataType::Float32, @@ -207,10 +243,18 @@ fn protobuf_type_mapping( } Kind::Uint64 | Kind::Fixed64 => DataType::Decimal, Kind::String => DataType::Varchar, - Kind::Message(m) => match m.full_name() { - // Well-Known Types are identified by their full name - "google.protobuf.Any" => DataType::Jsonb, - _ => { + Kind::Message(m) => { + if m.full_name() == "google.protobuf.Any" { + // Well-Known Types are identified by their full name + DataType::Jsonb + } else if m.is_map_entry() { + // Map is equivalent to `repeated MapFieldEntry map_field = N;` + debug_assert!(field_descriptor.is_map()); + let key = protobuf_type_mapping(&m.map_entry_key_field(), parse_trace)?; + let value = protobuf_type_mapping(&m.map_entry_value_field(), parse_trace)?; + _ = parse_trace.pop(); + return Ok(DataType::Map(MapType::from_kv(key, value))); + } else { let fields = m .fields() .map(|f| protobuf_type_mapping(&f, parse_trace)) @@ -218,17 +262,12 @@ fn protobuf_type_mapping( let field_names = m.fields().map(|f| f.name().to_string()).collect_vec(); DataType::new_struct(fields, field_names) } - }, + } Kind::Enum(_) => DataType::Varchar, Kind::Bytes => DataType::Bytea, }; - if field_descriptor.is_map() { - bail_protobuf_type_error!( - "protobuf map type (on field `{}`) is not supported", - field_descriptor.full_name() - ); - } if field_descriptor.cardinality() == Cardinality::Repeated { + debug_assert!(!field_descriptor.is_map()); t = DataType::List(Box::new(t)) } _ = parse_trace.pop(); diff --git a/src/connector/codec/tests/integration_tests/protobuf.rs b/src/connector/codec/tests/integration_tests/protobuf.rs index b07d5f739b81..9a70ef5e5c7a 100644 --- a/src/connector/codec/tests/integration_tests/protobuf.rs +++ b/src/connector/codec/tests/integration_tests/protobuf.rs @@ -18,6 +18,7 @@ mod recursive; #[rustfmt::skip] #[allow(clippy::all)] mod all_types; +use std::collections::HashMap; use std::path::PathBuf; use anyhow::Context; @@ -516,6 +517,11 @@ fn test_all_types() -> anyhow::Result<()> { name: "Nested".to_string(), }), repeated_int_field: vec![1, 2, 3, 4, 5], + map_field: HashMap::from_iter([ + ("key1".to_string(), 1), + ("key2".to_string(), 2), + ("key3".to_string(), 3), + ]), timestamp_field: Some(::prost_types::Timestamp { seconds: 1630927032, nanos: 500000000, @@ -531,6 +537,26 @@ fn test_all_types() -> anyhow::Result<()> { int32_value_field: Some(42), string_value_field: Some("Hello, Wrapper!".to_string()), example_oneof: Some(ExampleOneof::OneofInt32(123)), + map_struct_field: HashMap::from_iter([ + ( + "key1".to_string(), + NestedMessage { + id: 1, + name: "A".to_string(), + }, + ), + ( + "key2".to_string(), + NestedMessage { + id: 2, + name: "B".to_string(), + }, + ), + ]), + map_enum_field: HashMap::from_iter([ + (1, EnumType::Option1 as i32), + (2, EnumType::Option2 as i32), + ]), } }; let mut data_bytes = Vec::new(); @@ -565,17 +591,23 @@ fn test_all_types() -> anyhow::Result<()> { oneof_string(#21): Varchar, oneof_int32(#22): Int32, oneof_enum(#23): Varchar, - timestamp_field(#26): Struct { + map_field(#26): Map(Varchar,Int32), type_name: all_types.AllTypes.MapFieldEntry, field_descs: [key(#24): Varchar, value(#25): Int32], + timestamp_field(#29): Struct { seconds: Int64, nanos: Int32, - }, type_name: google.protobuf.Timestamp, field_descs: [seconds(#24): Int64, nanos(#25): Int32], - duration_field(#29): Struct { + }, type_name: google.protobuf.Timestamp, field_descs: [seconds(#27): Int64, nanos(#28): Int32], + duration_field(#32): Struct { seconds: Int64, nanos: Int32, - }, type_name: google.protobuf.Duration, field_descs: [seconds(#27): Int64, nanos(#28): Int32], - any_field(#32): Jsonb, type_name: google.protobuf.Any, field_descs: [type_url(#30): Varchar, value(#31): Bytea], - int32_value_field(#34): Struct { value: Int32 }, type_name: google.protobuf.Int32Value, field_descs: [value(#33): Int32], - string_value_field(#36): Struct { value: Varchar }, type_name: google.protobuf.StringValue, field_descs: [value(#35): Varchar], + }, type_name: google.protobuf.Duration, field_descs: [seconds(#30): Int64, nanos(#31): Int32], + any_field(#35): Jsonb, type_name: google.protobuf.Any, field_descs: [type_url(#33): Varchar, value(#34): Bytea], + int32_value_field(#37): Struct { value: Int32 }, type_name: google.protobuf.Int32Value, field_descs: [value(#36): Int32], + string_value_field(#39): Struct { value: Varchar }, type_name: google.protobuf.StringValue, field_descs: [value(#38): Varchar], + map_struct_field(#44): Map(Varchar,Struct { id: Int32, name: Varchar }), type_name: all_types.AllTypes.MapStructFieldEntry, field_descs: [key(#40): Varchar, value(#43): Struct { + id: Int32, + name: Varchar, + }, type_name: all_types.AllTypes.NestedMessage, field_descs: [id(#41): Int32, name(#42): Varchar]], + map_enum_field(#47): Map(Int32,Varchar), type_name: all_types.AllTypes.MapEnumFieldEntry, field_descs: [key(#45): Int32, value(#46): Varchar], ]"#]], expect![[r#" Owned(Float64(OrderedFloat(1.2345))) @@ -608,6 +640,20 @@ fn test_all_types() -> anyhow::Result<()> { Owned(Utf8("")) Owned(Int32(123)) Owned(Utf8("DEFAULT")) + Owned([ + StructValue( + Utf8("key1"), + Int32(1), + ), + StructValue( + Utf8("key2"), + Int32(2), + ), + StructValue( + Utf8("key3"), + Int32(3), + ), + ]) Owned(StructValue( Int64(1630927032), Int32(500000000), @@ -620,7 +666,33 @@ fn test_all_types() -> anyhow::Result<()> { Error at column `any_field`: Fail to convert protobuf Any into jsonb: message 'my_custom_type' not found ~~~~ Owned(StructValue(Int32(42))) - Owned(StructValue(Utf8("Hello, Wrapper!")))"#]], + Owned(StructValue(Utf8("Hello, Wrapper!"))) + Owned([ + StructValue( + Utf8("key1"), + StructValue( + Int32(1), + Utf8("A"), + ), + ), + StructValue( + Utf8("key2"), + StructValue( + Int32(2), + Utf8("B"), + ), + ), + ]) + Owned([ + StructValue( + Int32(1), + Utf8("OPTION1"), + ), + StructValue( + Int32(2), + Utf8("OPTION2"), + ), + ])"#]], ); Ok(()) diff --git a/src/connector/codec/tests/test_data/all-types.proto b/src/connector/codec/tests/test_data/all-types.proto index 7dcad51a645d..5070328dbf5f 100644 --- a/src/connector/codec/tests/test_data/all-types.proto +++ b/src/connector/codec/tests/test_data/all-types.proto @@ -53,8 +53,8 @@ message AllTypes { EnumType oneof_enum = 21; } - // // map field - // map map_field = 22; + // map field + map map_field = 22; // timestamp google.protobuf.Timestamp timestamp_field = 23; @@ -73,4 +73,7 @@ message AllTypes { // wrapper types google.protobuf.Int32Value int32_value_field = 27; google.protobuf.StringValue string_value_field = 28; + + map map_struct_field = 29; + map map_enum_field = 30; }