diff --git a/src/common/src/array/map_array.rs b/src/common/src/array/map_array.rs index f519c25981a56..d3f852dda0560 100644 --- a/src/common/src/array/map_array.rs +++ b/src/common/src/array/map_array.rs @@ -18,13 +18,14 @@ use std::fmt::{self, Debug, Display}; use bytes::{Buf, BufMut}; use itertools::Itertools; use risingwave_common_estimate_size::EstimateSize; +use risingwave_error::BoxedError; use risingwave_pb::data::{PbArray, PbArrayType}; use serde::Serializer; use super::{ Array, ArrayBuilder, ArrayImpl, ArrayResult, DatumRef, DefaultOrdered, ListArray, - ListArrayBuilder, ListRef, ListValue, MapType, ScalarRef, ScalarRefImpl, StructArray, - StructRef, + ListArrayBuilder, ListRef, ListValue, MapType, ScalarImpl, ScalarRef, ScalarRefImpl, + StructArray, StructRef, }; use crate::bitmap::Bitmap; use crate::types::{DataType, Scalar, ToText}; @@ -525,3 +526,36 @@ impl ToText for MapRef<'_> { } } } + +impl MapValue { + pub fn from_str_for_test(s: &str, data_type: &MapType) -> Result { + // TODO: this is a quick trivial implementation. Implement the full version later. + + // example: {1:1,2:NULL,3:3} + + if !s.starts_with('{') { + return Err(format!("Missing left parenthesis: {}", s).into()); + } + if !s.ends_with('}') { + return Err(format!("Missing right parenthesis: {}", s).into()); + } + let mut key_builder = data_type.key().create_array_builder(100); + let mut value_builder = data_type.value().create_array_builder(100); + for kv in s[1..s.len() - 1].split(',') { + let (k, v) = kv.split_once(':').ok_or("Invalid map format")?; + key_builder.append(Some(ScalarImpl::from_text(k, data_type.key())?)); + if v == "NULL" { + value_builder.append_null(); + } else { + value_builder.append(Some(ScalarImpl::from_text(v, data_type.value())?)); + } + } + let key_array = key_builder.finish(); + let value_array = value_builder.finish(); + + Ok(MapValue::try_from_kv( + ListValue::new(key_array), + ListValue::new(value_array), + )?) + } +} diff --git a/src/common/src/array/struct_array.rs b/src/common/src/array/struct_array.rs index ebf224f581616..10ded3a64d66c 100644 --- a/src/common/src/array/struct_array.rs +++ b/src/common/src/array/struct_array.rs @@ -337,17 +337,14 @@ impl StructValue { .map(Self::new) } - /// Construct an array from literal string. + /// Construct a struct from literal string. /// /// # Example /// /// ``` /// # use risingwave_common::types::{StructValue, StructType, DataType, ScalarImpl}; /// - /// let ty = DataType::Struct(StructType::unnamed(vec![ - /// DataType::Int32, - /// DataType::Float64, - /// ])); + /// let ty = StructType::unnamed(vec![DataType::Int32, DataType::Float64]); /// let s = StructValue::from_str("(1, 2.0)", &ty).unwrap(); /// assert_eq!(s.fields()[0], Some(ScalarImpl::Int32(1))); /// assert_eq!(s.fields()[1], Some(ScalarImpl::Float64(2.0.into()))); @@ -356,11 +353,8 @@ impl StructValue { /// assert_eq!(s.fields()[0], None); /// assert_eq!(s.fields()[1], None); /// ``` - pub fn from_str(s: &str, data_type: &DataType) -> Result { + pub fn from_str(s: &str, ty: &StructType) -> Result { // FIXME(runji): this is a trivial implementation which does not support nested struct. - let DataType::Struct(ty) = data_type else { - return Err(format!("Expect struct type, got {:?}", data_type).into()); - }; if !s.starts_with('(') { return Err("Missing left parenthesis".into()); } diff --git a/src/common/src/types/mod.rs b/src/common/src/types/mod.rs index bc16398d1fbf2..2c62945b3df80 100644 --- a/src/common/src/types/mod.rs +++ b/src/common/src/types/mod.rs @@ -918,12 +918,17 @@ impl ScalarImpl { DataType::Time => Time::from_str(s)?.into(), DataType::Interval => Interval::from_str(s)?.into(), DataType::List(_) => ListValue::from_str(s, data_type)?.into(), - DataType::Struct(_) => StructValue::from_str(s, data_type)?.into(), + DataType::Struct(st) => StructValue::from_str(s, st)?.into(), DataType::Jsonb => JsonbVal::from_str(s)?.into(), DataType::Bytea => str_to_bytea(s)?.into(), - DataType::Map(_) => { - todo!() - } + DataType::Map(_m) => return Err("map from text is not supported".into()), + }) + } + + pub fn from_text_for_test(s: &str, data_type: &DataType) -> Result { + Ok(match data_type { + DataType::Map(map_type) => MapValue::from_str_for_test(s, map_type)?.into(), + _ => ScalarImpl::from_text(s, data_type)?, }) } } diff --git a/src/connector/src/sink/encoder/avro.rs b/src/connector/src/sink/encoder/avro.rs index 1a9218572814f..3c45c8f572cc5 100644 --- a/src/connector/src/sink/encoder/avro.rs +++ b/src/connector/src/sink/encoder/avro.rs @@ -146,9 +146,13 @@ impl SerTo> for AvroEncoded { } enum OptIdx { + /// `T` NotUnion, + /// `[T]` Single, + /// `[null, T]` NullLeft, + /// `[T, null]` NullRight, } @@ -167,7 +171,9 @@ trait MaybeData: std::fmt::Debug { fn on_list(self, elem: &DataType, avro: &AvroSchema) -> Result; - fn handle_union(out: Self::Out, opt_idx: OptIdx) -> Result; + fn on_map(self, value_type: &DataType, avro_value_schema: &AvroSchema) -> Result; + + fn handle_nullable_union(out: Self::Out, opt_idx: OptIdx) -> Result; } impl MaybeData for () { @@ -182,10 +188,14 @@ impl MaybeData for () { } fn on_list(self, elem: &DataType, avro: &AvroSchema) -> Result { - encode_field(elem, (), avro) + on_field(elem, (), avro) + } + + fn on_map(self, elem: &DataType, avro: &AvroSchema) -> Result { + on_field(elem, (), avro) } - fn handle_union(out: Self::Out, _: OptIdx) -> Result { + fn handle_nullable_union(out: Self::Out, _: OptIdx) -> Result { Ok(out) } } @@ -214,14 +224,27 @@ impl MaybeData for DatumRef<'_> { Some(s) => s.into_list(), None => return Ok(Value::Null), }; + let vs = d.iter().map(|d| on_field(elem, d, avro)).try_collect()?; + Ok(Value::Array(vs)) + } + + fn on_map(self, elem: &DataType, avro: &AvroSchema) -> Result { + let d = match self { + Some(s) => s.into_map(), + None => return Ok(Value::Null), + }; let vs = d .iter() - .map(|d| encode_field(elem, d, avro)) + .map(|(k, v)| { + let k = k.into_utf8().to_string(); + let v = on_field(elem, v, avro)?; + Ok((k, v)) + }) .try_collect()?; - Ok(Value::Array(vs)) + Ok(Value::Map(vs)) } - fn handle_union(out: Self::Out, opt_idx: OptIdx) -> Result { + fn handle_nullable_union(out: Self::Out, opt_idx: OptIdx) -> Result { use OptIdx::*; match out == Value::Null { @@ -264,7 +287,7 @@ fn validate_fields<'rw>( }; present[idx] = true; let avro_field = &fields[idx]; - encode_field(t, (), &avro_field.schema).map_err(|e| e.with_name(name))?; + on_field(t, (), &avro_field.schema).map_err(|e| e.with_name(name))?; } for (p, avro_field) in present.into_iter().zip_eq_fast(fields) { if p { @@ -292,7 +315,7 @@ fn encode_fields<'avro, 'rw>( let idx = lookup[name]; present[idx] = true; let avro_field = &fields[idx]; - let value = encode_field(t, d, &avro_field.schema).map_err(|e| e.with_name(name))?; + let value = on_field(t, d, &avro_field.schema).map_err(|e| e.with_name(name))?; record.put(name, value); } // Unfortunately, the upstream `apache_avro` does not handle missing fields as nullable correctly. @@ -323,11 +346,7 @@ fn encode_fields<'avro, 'rw>( /// Handles both `validate` (without actual data) and `encode`. /// See [`MaybeData`] for more info. -fn encode_field( - data_type: &DataType, - maybe: D, - expected: &AvroSchema, -) -> Result { +fn on_field(data_type: &DataType, maybe: D, expected: &AvroSchema) -> Result { use risingwave_common::types::Interval; let no_match_err = || { @@ -397,6 +416,16 @@ fn encode_field( AvroSchema::Array(avro_elem) => maybe.on_list(elem, avro_elem)?, _ => return no_match_err(), }, + DataType::Map(m) => { + if *m.key() != DataType::Varchar { + return no_match_err(); + } + match inner { + AvroSchema::Map(avro_value_type) => maybe.on_map(m.value(), avro_value_type)?, + _ => return no_match_err(), + } + } + // Group B: match between RisingWave types and Avro logical types DataType::Timestamptz => match inner { AvroSchema::TimestampMicros => maybe.on_base(|s| { @@ -454,40 +483,132 @@ fn encode_field( DataType::Int256 => { return no_match_err(); } - DataType::Map(_) => { - // TODO(map): support map - return no_match_err(); - } }; - D::handle_union(value, opt_idx) + D::handle_nullable_union(value, opt_idx) } #[cfg(test)] mod tests { + use std::collections::HashMap; + use std::str::FromStr; + + use expect_test::expect; + use itertools::Itertools; + use risingwave_common::array::{ArrayBuilder, MapArrayBuilder}; use risingwave_common::catalog::Field; use risingwave_common::row::OwnedRow; use risingwave_common::types::{ - Date, Datum, Interval, ListValue, ScalarImpl, StructValue, Time, Timestamptz, ToDatumRef, + Date, Datum, Interval, ListValue, MapType, MapValue, Scalar, ScalarImpl, StructValue, Time, + Timestamptz, ToDatumRef, }; use super::*; - fn test_ok(t: &DataType, d: Datum, avro: &str, expected: Value) { - let avro_schema = AvroSchema::parse_str(avro).unwrap(); - let actual = encode_field(t, d.to_datum_ref(), &avro_schema).unwrap(); + #[track_caller] + fn test_ok(rw_type: &DataType, rw_datum: Datum, avro_type: &str, expected: Value) { + let avro_schema = AvroSchema::parse_str(avro_type).unwrap(); + let actual = on_field(rw_type, rw_datum.to_datum_ref(), &avro_schema).unwrap(); assert_eq!(actual, expected); } + #[track_caller] fn test_err(t: &DataType, d: D, avro: &str, expected: &str) where D::Out: std::fmt::Debug, { let avro_schema = AvroSchema::parse_str(avro).unwrap(); - let err = encode_field(t, d, &avro_schema).unwrap_err(); + let err = on_field(t, d, &avro_schema).unwrap_err(); assert_eq!(err.to_string(), expected); } + #[track_caller] + fn test_v2(rw_type: &str, rw_scalar: &str, avro_type: &str, expected: expect_test::Expect) { + let avro_schema = AvroSchema::parse_str(avro_type).unwrap(); + let rw_type = DataType::from_str(rw_type).unwrap(); + let rw_datum = ScalarImpl::from_text_for_test(rw_scalar, &rw_type).unwrap(); + + if let Err(validate_err) = on_field(&rw_type, (), &avro_schema) { + expected.assert_debug_eq(&validate_err); + return; + } + let actual = on_field(&rw_type, Some(rw_datum).to_datum_ref(), &avro_schema); + match actual { + Ok(v) => expected.assert_eq(&print_avro_value(&v)), + Err(e) => expected.assert_debug_eq(&e), + } + } + + fn print_avro_value(v: &Value) -> String { + match v { + Value::Map(m) => { + let mut res = "Map({".to_string(); + for (k, v) in m.iter().sorted_by_key(|x| x.0) { + res.push_str(&format!("{}: {}, ", k, print_avro_value(v))); + } + res.push_str("})"); + res + } + _ => format!("{v:?}"), + } + } + + #[test] + fn test_encode_v2() { + test_v2( + "boolean", + "false", + r#""int""#, + expect![[r#" + FieldEncodeError { + message: "cannot encode boolean column as \"int\" field", + rev_path: [], + } + "#]], + ); + test_v2("boolean", "true", r#""boolean""#, expect!["Boolean(true)"]); + + test_v2( + "map(varchar,varchar)", + "{1:1,2:2,3:3}", + r#"{"type": "map","values": "string"}"#, + expect![[r#"Map({1: String("1"), 2: String("2"), 3: String("3"), })"#]], + ); + + test_v2( + "map(varchar,varchar)", + "{1:1,2:NULL,3:3}", + r#"{"type": "map","values": "string"}"#, + expect![[r#" + FieldEncodeError { + message: "found null but required", + rev_path: [], + } + "#]], + ); + + test_v2( + "map(varchar,varchar)", + "{1:1,2:NULL,3:3}", + r#"{"type": "map","values": ["null", "string"]}"#, + expect![[ + r#"Map({1: Union(1, String("1")), 2: Union(0, Null), 3: Union(1, String("3")), })"# + ]], + ); + + test_v2( + "map(int,varchar)", + "{1:1,2:NULL,3:3}", + r#"{"type": "map","values": ["null", "string"]}"#, + expect![[r#" + FieldEncodeError { + message: "cannot encode map(integer,character varying) column as {\"type\":\"map\",\"values\":[\"null\",\"string\"]} field", + rev_path: [], + } + "#]], + ); + } + #[test] fn test_encode_avro_ok() { test_ok( @@ -592,7 +713,59 @@ mod tests { apache_avro::Days::new(2), apache_avro::Millis::new(1000), )), - ) + ); + + let mut inner_map_array_builder = MapArrayBuilder::with_type( + 2, + DataType::Map(MapType::from_kv(DataType::Varchar, DataType::Int32)), + ); + inner_map_array_builder.append(Some( + MapValue::try_from_kv( + ListValue::from_iter(["a", "b"]), + ListValue::from_iter([1, 2]), + ) + .unwrap() + .as_scalar_ref(), + )); + inner_map_array_builder.append(Some( + MapValue::try_from_kv( + ListValue::from_iter(["c", "d"]), + ListValue::from_iter([3, 4]), + ) + .unwrap() + .as_scalar_ref(), + )); + let inner_map_array = inner_map_array_builder.finish(); + test_ok( + &DataType::Map(MapType::from_kv( + DataType::Varchar, + DataType::Map(MapType::from_kv(DataType::Varchar, DataType::Int32)), + )), + Some(ScalarImpl::Map( + MapValue::try_from_kv( + ListValue::from_iter(["k1", "k2"]), + ListValue::new(inner_map_array.into()), + ) + .unwrap(), + )), + r#"{"type": "map","values": {"type": "map","values": "int"}}"#, + Value::Map(HashMap::from_iter([ + ( + "k1".into(), + Value::Map(HashMap::from_iter([ + ("a".into(), Value::Int(1)), + ("b".into(), Value::Int(2)), + ])), + ), + ( + "k2".into(), + Value::Map(HashMap::from_iter([ + ("c".into(), Value::Int(3)), + ("d".into(), Value::Int(4)), + ])), + ), + ])), + ); } #[test]