From 5f3947b431f3bc4faafe775bfbedcf1a7eee77b1 Mon Sep 17 00:00:00 2001 From: xxchan Date: Wed, 26 Jun 2024 17:01:24 +0800 Subject: [PATCH] implement union --- src/common/src/types/mod.rs | 2 + src/connector/codec/src/decoder/avro/mod.rs | 150 ++++++++- .../codec/src/decoder/avro/schema.rs | 102 +++--- src/connector/codec/src/decoder/mod.rs | 17 +- .../codec/tests/integration_tests/avro.rs | 293 +++++++++++++++++- .../src/parser/debezium/avro_parser.rs | 24 ++ src/connector/src/parser/unified/json.rs | 2 +- 7 files changed, 521 insertions(+), 69 deletions(-) diff --git a/src/common/src/types/mod.rs b/src/common/src/types/mod.rs index 69e727cbde655..3b02b8c38d020 100644 --- a/src/common/src/types/mod.rs +++ b/src/common/src/types/mod.rs @@ -573,6 +573,8 @@ pub trait ScalarRef<'a>: ScalarBounds> + 'a + Copy { macro_rules! scalar_impl_enum { ($( { $variant_name:ident, $suffix_name:ident, $scalar:ty, $scalar_ref:ty } ),*) => { /// `ScalarImpl` embeds all possible scalars in the evaluation framework. + /// + /// See `for_all_variants` for the definition. #[derive(Debug, Clone, PartialEq, Eq, EstimateSize)] pub enum ScalarImpl { $( $variant_name($scalar) ),* diff --git a/src/connector/codec/src/decoder/avro/mod.rs b/src/connector/codec/src/decoder/avro/mod.rs index d67d1b5c8c8d0..71bb430059d87 100644 --- a/src/connector/codec/src/decoder/avro/mod.rs +++ b/src/connector/codec/src/decoder/avro/mod.rs @@ -33,6 +33,7 @@ use risingwave_common::util::iter_util::ZipEqFast; pub use self::schema::{avro_schema_to_column_descs, MapHandling, ResolvedAvroSchema}; use super::utils::extract_decimal; use super::{bail_uncategorized, uncategorized, Access, AccessError, AccessResult}; +use crate::decoder::avro::schema::avro_schema_to_struct_field_name; #[derive(Clone)] /// Options for parsing an `AvroValue` into Datum, with an optional avro schema. 
@@ -107,6 +108,43 @@ impl<'a> AvroParseOptions<'a> { let v: ScalarImpl = match (type_expected, value) { (_, Value::Null) => return Ok(DatumCow::NULL), + // ---- Union ----- + (DataType::Struct(struct_type_info), Value::Union(variant, v)) => match self.schema { + Some(Schema::Union(u)) => { + let variant_schema = &u.variants()[*variant as usize]; + + if matches!(variant_schema, &Schema::Null) { + return Ok(DatumCow::NULL); + } + + // XXX: can we use the variant idx to find the field idx? + // We will need to get the index of the "null" variant, and then re-map the variant index to the field index. + // Which way is better? + let expected_field_name = avro_schema_to_struct_field_name(variant_schema); + + let mut fields = Vec::with_capacity(struct_type_info.len()); + for (field_name, field_type) in struct_type_info + .names() + .zip_eq_fast(struct_type_info.types()) + { + if field_name == expected_field_name { + let datum = Self { + schema: Some(variant_schema), + relax_numeric: self.relax_numeric, + } + .convert_to_datum(v, field_type)? + .to_owned_datum(); + + fields.push(datum) + } else { + fields.push(None) + } + } + StructValue::new(fields).into() + } + _ => Err(create_error())?, + }, + // nullable Union (_, Value::Union(_, v)) => { let schema = self.extract_inner_schema(None); return Self { @@ -290,6 +328,11 @@ impl Access for AvroAccess<'_> { let mut value = self.value; let mut options: AvroParseOptions<'_> = self.options.clone(); + debug_assert!( + path.len() == 1 || (path.len() == 2 && path[0] == "before"), + "unexpected path access: {:?}", + path + ); let mut i = 0; while i < path.len() { let key = path[i]; @@ -299,6 +342,29 @@ impl Access for AvroAccess<'_> { }; match value { Value::Union(_, v) => { + // The debezium "before" field is a nullable union. 
+ // "fields": [ + // { + // "name": "before", + // "type": [ + // "null", + // { + // "type": "record", + // "name": "Value", + // "fields": [...], + // } + // ], + // "default": null + // }, + // { + // "name": "after", + // "type": [ + // "null", + // "Value" + // ], + // "default": null + // }, + // ...] value = v; options.schema = options.extract_inner_schema(None); continue; @@ -341,13 +407,8 @@ pub(crate) fn avro_decimal_to_rust_decimal( /// If the union schema is `[null, T]` or `[T, null]`, returns `Some(T)`; otherwise returns `None`. fn get_nullable_union_inner(union_schema: &UnionSchema) -> Option<&'_ Schema> { let variants = union_schema.variants(); - if variants.len() == 2 - || variants - .iter() - .filter(|s| matches!(s, &&Schema::Null)) - .count() - == 1 - { + // Note: `[null, null]` is invalid, so we don't need to worry about that. + if variants.len() == 2 && variants.contains(&Schema::Null) { let inner_schema = variants .iter() .find(|s| !matches!(s, &&Schema::Null)) @@ -389,6 +450,8 @@ pub fn avro_extract_field_schema<'a>( Ok(&field.schema) } Schema::Array(schema) => Ok(schema), + // Only nullable union should be handled here. + // We will not extract inner schema for real union (and it's not extractable). Schema::Union(_) => avro_schema_skip_nullable_union(schema), Schema::Map(schema) => Ok(schema), _ => bail!("avro schema does not have inner item, schema: {:?}", schema), @@ -501,7 +564,78 @@ mod tests { /// Test the behavior of the Rust Avro lib for handling union with logical type. 
#[test] - fn test_union_logical_type() { + fn test_avro_lib_union() { + // duplicate types + let s = Schema::parse_str(r#"["null", "null"]"#); + expect![[r#" + Err( + Unions cannot contain duplicate types, + ) + "#]] + .assert_debug_eq(&s); + let s = Schema::parse_str(r#"["int", "int"]"#); + expect![[r#" + Err( + Unions cannot contain duplicate types, + ) + "#]] + .assert_debug_eq(&s); + // multiple map/array are considered as the same type, regardless of the element type! + let s = Schema::parse_str( + r#"[ +"null", +{ + "type": "map", + "values" : "long", + "default": {} +}, +{ + "type": "map", + "values" : "int", + "default": {} +} +] +"#, + ); + expect![[r#" + Err( + Unions cannot contain duplicate types, + ) + "#]] + .assert_debug_eq(&s); + let s = Schema::parse_str( + r#"[ +"null", +{ + "type": "array", + "items" : "long", + "default": {} +}, +{ + "type": "array", + "items" : "int", + "default": {} +} +] +"#, + ); + expect![[r#" + Err( + Unions cannot contain duplicate types, + ) + "#]] + .assert_debug_eq(&s); + + // union in union + let s = Schema::parse_str(r#"["int", ["null", "int"]]"#); + expect![[r#" + Err( + Unions may not directly contain a union, + ) + "#]] + .assert_debug_eq(&s); + + // logical type let s = Schema::parse_str(r#"["null", {"type":"string","logicalType":"uuid"}]"#).unwrap(); expect![[r#" Union( diff --git a/src/connector/codec/src/decoder/avro/schema.rs b/src/connector/codec/src/decoder/avro/schema.rs index 066fc45273353..f5a9d779f6500 100644 --- a/src/connector/codec/src/decoder/avro/schema.rs +++ b/src/connector/codec/src/decoder/avro/schema.rs @@ -201,7 +201,10 @@ fn avro_type_mapping( DataType::List(Box::new(item_type)) } Schema::Union(union_schema) => { - // Unions may not contain more than one schema with the same type, except for the named types record, fixed and enum. + // Note: Unions may not immediately contain other unions. So a `null` must represent a top-level null. 
+ // e.g., ["null", ["null", "string"]] is not allowed + + // Note: Unions may not contain more than one schema with the same type, except for the named types record, fixed and enum. // https://avro.apache.org/docs/1.11.1/specification/_print/#unions debug_assert!( union_schema @@ -222,61 +225,21 @@ fn avro_type_mapping( // Note: Avro union's variant tag is type name, not field name (unlike Rust enum, or Protobuf oneof). // XXX: do we need to introduce union.handling.mode? - let (fields, field_names) = union_schema .variants() .iter() + // null will mean the whole struct is null .filter(|variant| !matches!(variant, &&Schema::Null)) .map(|variant| { avro_type_mapping(variant, map_handling).map(|t| { - let name = match variant { - Schema::Null => unreachable!(), - Schema::Boolean => "boolean".to_string(), - Schema::Int => "integer".to_string(), - Schema::Long => "bigint".to_string(), - Schema::Float => "real".to_string(), - Schema::Double => "double precision".to_string(), - Schema::Bytes => "bytea".to_string(), - Schema::String => "text".to_string(), - Schema::Array(_) => "array".to_string(), - Schema::Map(_) =>"map".to_string(), - Schema::Union(_) => "union".to_string(), - // For logical types, should we use the real type or the logical type as the field name? - // - // Example about the representation: - // schema: ["null", {"type":"string","logicalType":"uuid"}] - // data: {"string": "67e55044-10b1-426f-9247-bb680e5fe0c8"} - // - // Note: for union with logical type AND the real type, e.g., ["string", {"type":"string","logicalType":"uuid"}] - // In this case, the uuid cannot be constructed. 
Some library - // https://issues.apache.org/jira/browse/AVRO-2380 - Schema::Uuid => "uuid".to_string(), - Schema::Decimal(_) => todo!(), - Schema::Date => "date".to_string(), - Schema::TimeMillis => "time without time zone".to_string(), - Schema::TimeMicros => "time without time zone".to_string(), - Schema::TimestampMillis => "timestamp without time zone".to_string(), - Schema::TimestampMicros => "timestamp without time zone".to_string(), - Schema::LocalTimestampMillis => "timestamp without time zone".to_string(), - Schema::LocalTimestampMicros => "timestamp without time zone".to_string(), - Schema::Duration => "interval".to_string(), - Schema::Enum(_) - | Schema::Ref { name: _ } - | Schema::Fixed(_) => todo!(), - | Schema::Record(_) => variant.name().unwrap().fullname(None), // XXX: Is the namespace correct here? - }; + let name = avro_schema_to_struct_field_name(variant); (t, name) }) }) .process_results(|it| it.unzip::<_, _, Vec<_>, Vec<_>>()) .context("failed to convert Avro union to struct")?; - DataType::new_struct(fields, field_names); - - bail!( - "unsupported Avro type, only unions like [null, T] is supported: {:?}", - schema - ); + DataType::new_struct(fields, field_names) } } } @@ -351,3 +314,54 @@ fn supported_avro_to_json_type(schema: &Schema) -> bool { | Schema::Union(_) => false, } } + +/// The field name when converting Avro union type to RisingWave struct type. 
+pub(super) fn avro_schema_to_struct_field_name(schema: &Schema) -> String { + match schema { + Schema::Null => unreachable!(), + Schema::Union(_) => unreachable!(), + // Primitive types + Schema::Boolean => "boolean".to_string(), + Schema::Int => "int".to_string(), + Schema::Long => "long".to_string(), + Schema::Float => "float".to_string(), + Schema::Double => "double".to_string(), + Schema::Bytes => "bytes".to_string(), + Schema::String => "string".to_string(), + // Unnamed Complex types + Schema::Array(_) => "array".to_string(), + Schema::Map(_) => "map".to_string(), + // Named Complex types + // TODO: Verify whether the namespace is correct here + Schema::Enum(_) | Schema::Ref { name: _ } | Schema::Fixed(_) => todo!(), + Schema::Record(_) => schema.name().unwrap().fullname(None), + // Logical types + // XXX: should we use the real type or the logical type as the field name? + // It seems not to matter much, as we always have the index of the field when we get a Union Value. + // + // Currently we choose the logical type because it might be more user-friendly. + // + // Example of the representation: + // schema: ["null", {"type":"string","logicalType":"uuid"}] + // data: {"string": "67e55044-10b1-426f-9247-bb680e5fe0c8"} + // + // Note: for union with logical type AND the real type, e.g., ["string", {"type":"string","logicalType":"uuid"}] + // In this case, the uuid cannot be constructed. + // Actually this should be an invalid schema according to the spec. https://issues.apache.org/jira/browse/AVRO-2380 + // But some libraries, such as the Python and Rust ones, allow it. See `risingwave_connector_codec::decoder::avro::tests::test_avro_lib_union` + Schema::Uuid => "uuid".to_string(), + Schema::Decimal(_) => "decimal".to_string(), + Schema::Date => "date".to_string(), + // Note: in Avro, the name style is "time-millis", etc. 
+ // But in RisingWave (Postgres), it will require users to use quotes, i.e., + // select (struct)."time-millis", (struct).time_millis from t; + // The latter might be more user-friendly. + Schema::TimeMillis => "time_millis".to_string(), + Schema::TimeMicros => "time_micros".to_string(), + Schema::TimestampMillis => "timestamp_millis".to_string(), + Schema::TimestampMicros => "timestamp_micros".to_string(), + Schema::LocalTimestampMillis => "local_timestamp_millis".to_string(), + Schema::LocalTimestampMicros => "local_timestamp_micros".to_string(), + Schema::Duration => "duration".to_string(), + } +} diff --git a/src/connector/codec/src/decoder/mod.rs b/src/connector/codec/src/decoder/mod.rs index c7e04ab210a6e..a2823cceca2ff 100644 --- a/src/connector/codec/src/decoder/mod.rs +++ b/src/connector/codec/src/decoder/mod.rs @@ -45,18 +45,31 @@ pub enum AccessError { pub type AccessResult = std::result::Result; /// Access to a field in the data structure. Created by `AccessBuilder`. +/// +/// It's the `ENCODE ...` part in `FORMAT ... ENCODE ...` pub trait Access { /// Accesses `path` in the data structure (*parsed* Avro/JSON/Protobuf data), /// and then converts it to RisingWave `Datum`. + /// /// `type_expected` might or might not be used during the conversion depending on the implementation. /// /// # Path /// - /// We usually expect the data is a record (struct), and `path` represents field path. + /// We usually expect the data (`Access` instance) is a record (struct), and `path` represents field path. /// The data (or part of the data) represents the whole row (`Vec`), /// and we use different `path` to access one column at a time. /// - /// e.g., for Avro, we access `["col_name"]`; for Debezium Avro, we access `["before", "col_name"]`. + /// TODO: the meaning of `path` is a little confusing and maybe over-abstracted. + /// `access` does not need to serve arbitrarily deep `path` access, but just "top-level" access. 
+ /// The API creates an illusion that arbitrary access is supported, but it's not. + /// Perhaps we should separate out another trait like `ToDatum`, + /// which only does type mapping, without caring about the path. And `path` itself is only an `enum` instead of `&[&str]`. + /// + /// What `path` to access is decided by the CDC layer, i.e., the `FORMAT ...` part (`ChangeEvent`). + /// e.g., + /// - `DebeziumChangeEvent` accesses `["before", "col_name"]` for value, `["op"]` for op type. + /// - `MaxwellChangeEvent` accesses `["data", "col_name"]` for value, `["type"]` for op type. + /// - In the simplest case, for `FORMAT PLAIN/UPSERT` (`KvEvent`), they just access `["col_name"]` for value, and op type is derived. /// /// # Returns /// diff --git a/src/connector/codec/tests/integration_tests/avro.rs b/src/connector/codec/tests/integration_tests/avro.rs index 21b167a768531..74a8db5e1f5a7 100644 --- a/src/connector/codec/tests/integration_tests/avro.rs +++ b/src/connector/codec/tests/integration_tests/avro.rs @@ -398,19 +398,284 @@ fn test_1() { rate(#14): Float64, ]"#]], expect![[r#" - Borrowed(Some(Utf8("update"))) - Borrowed(Some(Utf8("id1"))) - Borrowed(Some(Utf8("1"))) - Borrowed(Some(Utf8("6768"))) - Borrowed(Some(Utf8("6970"))) - Borrowed(Some(Utf8("value9"))) - Borrowed(Some(Utf8("7172"))) - Borrowed(Some(Utf8("info9"))) - Borrowed(Some(Utf8("2021-05-18T07:59:58.714Z"))) - Owned(Some(Decimal(Normalized(99999999.99)))) - Owned(None) - Owned(None) - Owned(None) - Owned(Some(Float64(OrderedFloat(NaN))))"#]], + Borrowed(Utf8("update")) + Borrowed(Utf8("id1")) + Borrowed(Utf8("1")) + Borrowed(Utf8("6768")) + Borrowed(Utf8("6970")) + Borrowed(Utf8("value9")) + Borrowed(Utf8("7172")) + Borrowed(Utf8("info9")) + Borrowed(Utf8("2021-05-18T07:59:58.714Z")) + Owned(Decimal(Normalized(99999999.99))) + Owned(null) + Owned(null) + Owned(null) + Owned(Float64(OrderedFloat(NaN)))"#]], + ); +} + +#[test] +fn test_union() { + check( + r#" +{ + "type": "record", + "name": 
"Root", + "fields": [ + { + "name": "unionType", + "type": ["int", "string"] + }, + { + "name": "unionTypeComplex", + "type": [ + "null", + {"type": "record", "name": "Email","fields": [{"name":"inner","type":"string"}]}, + {"type": "record", "name": "Fax","fields": [{"name":"inner","type":"int"}]}, + {"type": "record", "name": "Sms","fields": [{"name":"inner","type":"int"}]} + ] + }, + { + "name": "nullableString", + "type": ["null", "string"] + } + ] +} + "#, + &[ + // { + // "unionType": {"int": 114514}, + // "unionTypeComplex": {"Sms": {"inner":6}}, + // "nullableString": null + // } + "00a4fd0d060c00", + // { + // "unionType": {"int": 114514}, + // "unionTypeComplex": {"Fax": {"inner":6}}, + // "nullableString": null + // } + "00a4fd0d040c00", + // { + // "unionType": {"string": "oops"}, + // "unionTypeComplex": null, + // "nullableString": {"string": "hello"} + // } + "02086f6f707300020a68656c6c6f", + // { + // "unionType": {"string": "oops"}, + // "unionTypeComplex": {"Email": {"inner":"a@b.c"}}, + // "nullableString": null + // } + "02086f6f7073020a6140622e6300", + ], + Config { + map_handling: None, + data_encoding: TestDataEncoding::HexBinary, + }, + // FIXME: why the struct type doesn't have field_descs? 
+ expect![[r#" + [ + unionType(#1): Struct { + int: Int32, + string: Varchar, + }, + unionTypeComplex(#2): Struct { + Email: Struct { inner: Varchar }, + Fax: Struct { inner: Int32 }, + Sms: Struct { inner: Int32 }, + }, + nullableString(#3): Varchar, + ]"#]], + expect![[r#" + Owned(StructValue( + Int32(114514), + null, + )) + Owned(StructValue( + null, + null, + StructValue(Int32(6)), + )) + Owned(null) + ---- + Owned(StructValue( + Int32(114514), + null, + )) + Owned(StructValue( + null, + StructValue(Int32(6)), + null, + )) + Owned(null) + ---- + Owned(StructValue( + null, + Utf8("oops"), + )) + Owned(null) + Borrowed(Utf8("hello")) + ---- + Owned(StructValue( + null, + Utf8("oops"), + )) + Owned(StructValue( + StructValue(Utf8("a@b.c")), + null, + null, + )) + Owned(null)"#]], + ); + + check( + r#" +{ + "namespace": "com.abc.efg.mqtt", + "name": "also.DataMessage", + "type": "record", + "fields": [ + { + "name": "metrics", + "type": { + "type": "array", + "items": { + "name": "also_data_metric", + "type": "record", + "fields": [ + { + "name": "id", + "type": "string" + }, + { + "name": "name", + "type": "string" + }, + { + "name": "norm_name", + "type": [ + "null", + "string" + ], + "default": null + }, + { + "name": "uom", + "type": [ + "null", + "string" + ], + "default": null + }, + { + "name": "data", + "type": { + "type": "array", + "items": { + "name": "dataItem", + "type": "record", + "fields": [ + { + "name": "ts", + "type": "string", + "doc": "Timestamp of the metric." + }, + { + "name": "value", + "type": [ + "null", + "boolean", + "double", + "string" + ], + "doc": "Value of the metric." + } + ] + } + }, + "doc": "The data message" + } + ], + "doc": "A metric object" + } + }, + "doc": "A list of metrics." 
+ } + ] +} + "#, + &[ + // { + // "metrics": [ + // {"id":"foo", "name":"a", "data": [] } + // ] + // } + "0206666f6f026100000000", + // { + // "metrics": [ + // {"id":"foo", "name":"a", "norm_name": null, "uom": {"string":"c"}, "data": [{"ts":"1", "value":null}, {"ts":"2", "value": {"boolean": true }}] } + // ] + // } + "0206666f6f02610002026304023100023202010000", + ], + Config { + map_handling: None, + data_encoding: TestDataEncoding::HexBinary, + }, + expect![[r#" + [ + metrics(#1): List( + Struct { + id: Varchar, + name: Varchar, + norm_name: Varchar, + uom: Varchar, + data: List( + Struct { + ts: Varchar, + value: Struct { + boolean: Boolean, + double: Float64, + string: Varchar, + }, + }, + ), + }, + ), + ]"#]], + expect![[r#" + Owned([ + StructValue( + Utf8("foo"), + Utf8("a"), + null, + null, + [], + ), + ]) + ---- + Owned([ + StructValue( + Utf8("foo"), + Utf8("a"), + null, + Utf8("c"), + [ + StructValue( + Utf8("1"), + null, + ), + StructValue( + Utf8("2"), + StructValue( + Bool(true), + null, + null, + ), + ), + ], + ), + ])"#]], ); } diff --git a/src/connector/src/parser/debezium/avro_parser.rs b/src/connector/src/parser/debezium/avro_parser.rs index 467fb4c7379da..70e7304c0eea8 100644 --- a/src/connector/src/parser/debezium/avro_parser.rs +++ b/src/connector/src/parser/debezium/avro_parser.rs @@ -125,6 +125,30 @@ impl DebeziumAvroParserConfig { } pub fn map_to_columns(&self) -> ConnectorResult> { + // Refer to debezium_avro_msg_schema.avsc for how the schema looks like: + + // "fields": [ + // { + // "name": "before", + // "type": [ + // "null", + // { + // "type": "record", + // "name": "Value", + // "fields": [...], + // } + // ], + // "default": null + // }, + // { + // "name": "after", + // "type": [ + // "null", + // "Value" + // ], + // "default": null + // }, + // ...] avro_schema_to_column_descs( avro_schema_skip_nullable_union(avro_extract_field_schema( // FIXME: use resolved schema here. 
diff --git a/src/connector/src/parser/unified/json.rs b/src/connector/src/parser/unified/json.rs index e4a229bb61b98..ca709e2eebc73 100644 --- a/src/connector/src/parser/unified/json.rs +++ b/src/connector/src/parser/unified/json.rs @@ -534,7 +534,7 @@ impl JsonParseOptions { (DataType::Struct(struct_type_info), ValueType::Object) => { // Collecting into a Result> doesn't reserve the capacity in advance, so we `Vec::with_capacity` instead. // https://github.com/rust-lang/rust/issues/48994 - let mut fields = Vec::with_capacity(struct_type_info.types().len()); + let mut fields = Vec::with_capacity(struct_type_info.len()); for (field_name, field_type) in struct_type_info .names() .zip_eq_fast(struct_type_info.types())