diff --git a/src/connector/codec/src/decoder/protobuf/parser.rs b/src/connector/codec/src/decoder/protobuf/parser.rs index 15778727fc46..edae3433a84c 100644 --- a/src/connector/codec/src/decoder/protobuf/parser.rs +++ b/src/connector/codec/src/decoder/protobuf/parser.rs @@ -17,7 +17,7 @@ use itertools::Itertools; use prost_reflect::{Cardinality, FieldDescriptor, Kind, MessageDescriptor, ReflectMessage, Value}; use risingwave_common::array::{ListValue, StructValue}; use risingwave_common::types::{ - DataType, DatumCow, Decimal, JsonbVal, ScalarImpl, ToOwnedDatum, F32, F64, + DataType, DatumCow, Decimal, JsonbVal, MapType, MapValue, ScalarImpl, ToOwnedDatum, F32, F64, }; use risingwave_pb::plan_common::{AdditionalColumn, ColumnDesc, ColumnDescVersion}; use thiserror::Error; @@ -180,10 +180,43 @@ pub fn from_protobuf_value<'a>( ScalarImpl::List(ListValue::new(builder.finish())) } Value::Bytes(value) => borrowed!(&**value), - _ => { - return Err(AccessError::UnsupportedType { - ty: format!("{kind:?}"), - }); + Value::Map(map) => { + let err = || { + AccessError::TypeError { + expected: type_expected.to_string(), + got: format!("{:?}", kind), + value: value.to_string(), // Protobuf TEXT + } + }; + + let DataType::Map(map_type) = type_expected else { + return Err(err()); + }; + let map_desc = kind.as_message().ok_or_else(err)?; + if !map_desc.is_map_entry() { + return Err(err()); + } + + let mut key_builder = map_type.key().create_array_builder(map.len()); + let mut value_builder = map_type.value().create_array_builder(map.len()); + // NOTE: HashMap's iter order is non-deterministic, but MapValue's + // order matters. We sort by key here to have deterministic order + // in tests. We might consider removing this, or make all MapValue sorted + // in the future. + for (key, value) in map.into_iter().sorted_by_key(|(k, _v)| *k) { + key_builder.append(from_protobuf_value( + field_desc, + &key.clone().into(), + map_type.key(), + )?); + value_builder.append(from_protobuf_value(field_desc, &value, map_type.value())?); + } + let keys = key_builder.finish(); + let values = value_builder.finish(); + ScalarImpl::Map( + MapValue::try_from_kv(ListValue::new(keys), ListValue::new(values)) + .map_err(|e| uncategorized!("failed to convert protobuf map: {e}"))?, + ) } }; Ok(Some(v).into()) @@ -195,8 +228,7 @@ fn protobuf_type_mapping( parse_trace: &mut Vec, ) -> std::result::Result { detect_loop_and_push(parse_trace, field_descriptor)?; - let field_type = field_descriptor.kind(); - let mut t = match field_type { + let mut t = match field_descriptor.kind() { Kind::Bool => DataType::Boolean, Kind::Double => DataType::Float64, Kind::Float => DataType::Float32, @@ -207,10 +239,18 @@ fn protobuf_type_mapping( } Kind::Uint64 | Kind::Fixed64 => DataType::Decimal, Kind::String => DataType::Varchar, - Kind::Message(m) => match m.full_name() { - // Well-Known Types are identified by their full name - "google.protobuf.Any" => DataType::Jsonb, - _ => { + Kind::Message(m) => { + if m.full_name() == "google.protobuf.Any" { + // Well-Known Types are identified by their full name + DataType::Jsonb + } else if m.is_map_entry() { + // Map is equivalent to `repeated MapFieldEntry map_field = N;` + debug_assert!(field_descriptor.is_map()); + let key = protobuf_type_mapping(&m.map_entry_key_field(), parse_trace)?; + let value = protobuf_type_mapping(&m.map_entry_value_field(), parse_trace)?; + _ = parse_trace.pop(); + return Ok(DataType::Map(MapType::from_kv(key, value))); + } else { let fields = m .fields() .map(|f| protobuf_type_mapping(&f, parse_trace)) @@ -218,17 +258,12 @@ fn protobuf_type_mapping( let field_names = m.fields().map(|f| f.name().to_string()).collect_vec(); DataType::new_struct(fields, field_names) } - }, + } Kind::Enum(_) => DataType::Varchar, Kind::Bytes => DataType::Bytea, }; - if field_descriptor.is_map() { - bail_protobuf_type_error!( - "protobuf map type (on field `{}`) is not supported", - field_descriptor.full_name() - ); - } if field_descriptor.cardinality() == Cardinality::Repeated { + debug_assert!(!field_descriptor.is_map()); t = DataType::List(Box::new(t)) } _ = parse_trace.pop(); diff --git a/src/connector/codec/tests/integration_tests/protobuf.rs b/src/connector/codec/tests/integration_tests/protobuf.rs index b07d5f739b81..be71ece4e9a6 100644 --- a/src/connector/codec/tests/integration_tests/protobuf.rs +++ b/src/connector/codec/tests/integration_tests/protobuf.rs @@ -18,6 +18,7 @@ mod recursive; #[rustfmt::skip] #[allow(clippy::all)] mod all_types; +use std::collections::HashMap; use std::path::PathBuf; use anyhow::Context; @@ -516,6 +517,11 @@ fn test_all_types() -> anyhow::Result<()> { name: "Nested".to_string(), }), repeated_int_field: vec![1, 2, 3, 4, 5], + map_field: HashMap::from_iter([ + ("key1".to_string(), 1), + ("key2".to_string(), 2), + ("key3".to_string(), 3), + ]), timestamp_field: Some(::prost_types::Timestamp { seconds: 1630927032, nanos: 500000000, @@ -565,17 +571,18 @@ fn test_all_types() -> anyhow::Result<()> { oneof_string(#21): Varchar, oneof_int32(#22): Int32, oneof_enum(#23): Varchar, - timestamp_field(#26): Struct { + map_field(#26): Map(Varchar,Int32), type_name: all_types.AllTypes.MapFieldEntry, field_descs: [key(#24): Varchar, value(#25): Int32], + timestamp_field(#29): Struct { seconds: Int64, nanos: Int32, - }, type_name: google.protobuf.Timestamp, field_descs: [seconds(#24): Int64, nanos(#25): Int32], - duration_field(#29): Struct { + }, type_name: google.protobuf.Timestamp, field_descs: [seconds(#27): Int64, nanos(#28): Int32], + duration_field(#32): Struct { seconds: Int64, nanos: Int32, - }, type_name: google.protobuf.Duration, field_descs: [seconds(#27): Int64, nanos(#28): Int32], - any_field(#32): Jsonb, type_name: google.protobuf.Any, field_descs: [type_url(#30): Varchar, value(#31): Bytea], - int32_value_field(#34): Struct { value: Int32 }, type_name: google.protobuf.Int32Value, field_descs: [value(#33): Int32], - string_value_field(#36): Struct { value: Varchar }, type_name: google.protobuf.StringValue, field_descs: [value(#35): Varchar], + }, type_name: google.protobuf.Duration, field_descs: [seconds(#30): Int64, nanos(#31): Int32], + any_field(#35): Jsonb, type_name: google.protobuf.Any, field_descs: [type_url(#33): Varchar, value(#34): Bytea], + int32_value_field(#37): Struct { value: Int32 }, type_name: google.protobuf.Int32Value, field_descs: [value(#36): Int32], + string_value_field(#39): Struct { value: Varchar }, type_name: google.protobuf.StringValue, field_descs: [value(#38): Varchar], ]"#]], expect![[r#" Owned(Float64(OrderedFloat(1.2345))) @@ -608,6 +615,20 @@ fn test_all_types() -> anyhow::Result<()> { Owned(Utf8("")) Owned(Int32(123)) Owned(Utf8("DEFAULT")) + Owned([ + StructValue( + Utf8("key1"), + Int32(1), + ), + StructValue( + Utf8("key2"), + Int32(2), + ), + StructValue( + Utf8("key3"), + Int32(3), + ), + ]) Owned(StructValue( Int64(1630927032), Int32(500000000), diff --git a/src/connector/codec/tests/test_data/all-types.proto b/src/connector/codec/tests/test_data/all-types.proto index 7dcad51a645d..3d019a70167d 100644 --- a/src/connector/codec/tests/test_data/all-types.proto +++ b/src/connector/codec/tests/test_data/all-types.proto @@ -53,8 +53,8 @@ message AllTypes { EnumType oneof_enum = 21; } - // // map field - // map map_field = 22; + // map field + map map_field = 22; // timestamp google.protobuf.Timestamp timestamp_field = 23;