Skip to content

Commit

Permalink
implement union
Browse files Browse the repository at this point in the history
  • Loading branch information
xxchan committed Jul 2, 2024
1 parent 3c72107 commit 011b18c
Show file tree
Hide file tree
Showing 7 changed files with 507 additions and 55 deletions.
2 changes: 2 additions & 0 deletions src/common/src/types/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -573,6 +573,8 @@ pub trait ScalarRef<'a>: ScalarBounds<ScalarRefImpl<'a>> + 'a + Copy {
macro_rules! scalar_impl_enum {
($( { $variant_name:ident, $suffix_name:ident, $scalar:ty, $scalar_ref:ty } ),*) => {
/// `ScalarImpl` embeds all possible scalars in the evaluation framework.
///
/// See `for_all_variants` for the definition.
#[derive(Debug, Clone, PartialEq, Eq, EstimateSize)]
pub enum ScalarImpl {
$( $variant_name($scalar) ),*
Expand Down
150 changes: 142 additions & 8 deletions src/connector/codec/src/decoder/avro/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ use risingwave_common::util::iter_util::ZipEqFast;
pub use self::schema::{avro_schema_to_column_descs, MapHandling, ResolvedAvroSchema};
use super::utils::extract_decimal;
use super::{bail_uncategorized, uncategorized, Access, AccessError, AccessResult};
use crate::decoder::avro::schema::avro_schema_to_struct_field_name;

#[derive(Clone)]
/// Options for parsing an `AvroValue` into Datum, with an optional avro schema.
Expand Down Expand Up @@ -107,6 +108,43 @@ impl<'a> AvroParseOptions<'a> {

let v: ScalarImpl = match (type_expected, value) {
(_, Value::Null) => return Ok(DatumCow::NULL),
// ---- Union -----
(DataType::Struct(struct_type_info), Value::Union(variant, v)) => match self.schema {
Some(Schema::Union(u)) => {
let variant_schema = &u.variants()[*variant as usize];

if matches!(variant_schema, &Schema::Null) {
return Ok(DatumCow::NULL);
}

// XXX: can we use the variant idx to find the field idx?
// We will need to get the index of the "null" variant, and then re-map the variant index to the field index.
// Which way is better?
let expected_field_name = avro_schema_to_struct_field_name(variant_schema);

let mut fields = Vec::with_capacity(struct_type_info.len());
for (field_name, field_type) in struct_type_info
.names()
.zip_eq_fast(struct_type_info.types())
{
if field_name == expected_field_name {
let datum = Self {
schema: Some(variant_schema),
relax_numeric: self.relax_numeric,
}
.convert_to_datum(v, field_type)?
.to_owned_datum();

fields.push(datum)
} else {
fields.push(None)
}
}
StructValue::new(fields).into()
}
_ => Err(create_error())?,
},
// nullable Union
(_, Value::Union(_, v)) => {
let schema = self.extract_inner_schema(None);
return Self {
Expand Down Expand Up @@ -290,6 +328,11 @@ impl Access for AvroAccess<'_> {
let mut value = self.value;
let mut options: AvroParseOptions<'_> = self.options.clone();

debug_assert!(
path.len() == 1 || (path.len() == 2 && path[0] == "before"),
"unexpected path access: {:?}",
path
);
let mut i = 0;
while i < path.len() {
let key = path[i];
Expand All @@ -299,6 +342,29 @@ impl Access for AvroAccess<'_> {
};
match value {
Value::Union(_, v) => {
// The debezium "before" field is a nullable union.
// "fields": [
// {
// "name": "before",
// "type": [
// "null",
// {
// "type": "record",
// "name": "Value",
// "fields": [...],
// }
// ],
// "default": null
// },
// {
// "name": "after",
// "type": [
// "null",
// "Value"
// ],
// "default": null
// },
// ...]
value = v;
options.schema = options.extract_inner_schema(None);
continue;
Expand Down Expand Up @@ -341,13 +407,8 @@ pub(crate) fn avro_decimal_to_rust_decimal(
/// If the union schema is `[null, T]` or `[T, null]`, returns `Some(T)`; otherwise returns `None`.
fn get_nullable_union_inner(union_schema: &UnionSchema) -> Option<&'_ Schema> {
let variants = union_schema.variants();
if variants.len() == 2
|| variants
.iter()
.filter(|s| matches!(s, &&Schema::Null))
.count()
== 1
{
// Note: `[null, null] is invalid`, we don't need to worry about that.
if variants.len() == 2 && variants.contains(&Schema::Null) {
let inner_schema = variants
.iter()
.find(|s| !matches!(s, &&Schema::Null))
Expand Down Expand Up @@ -389,6 +450,8 @@ pub fn avro_extract_field_schema<'a>(
Ok(&field.schema)
}
Schema::Array(schema) => Ok(schema),
// Only nullable union should be handled here.
// We will not extract inner schema for real union (and it's not extractable).
Schema::Union(_) => avro_schema_skip_nullable_union(schema),
Schema::Map(schema) => Ok(schema),
_ => bail!("avro schema does not have inner item, schema: {:?}", schema),
Expand Down Expand Up @@ -501,7 +564,78 @@ mod tests {

/// Test the behavior of the Rust Avro lib for handling union with logical type.
#[test]
fn test_union_logical_type() {
fn test_avro_lib_union() {
// duplicate types
let s = Schema::parse_str(r#"["null", "null"]"#);
expect![[r#"
Err(
Unions cannot contain duplicate types,
)
"#]]
.assert_debug_eq(&s);
let s = Schema::parse_str(r#"["int", "int"]"#);
expect![[r#"
Err(
Unions cannot contain duplicate types,
)
"#]]
.assert_debug_eq(&s);
// multiple map/array are considered as the same type, regardless of the element type!
let s = Schema::parse_str(
r#"[
"null",
{
"type": "map",
"values" : "long",
"default": {}
},
{
"type": "map",
"values" : "int",
"default": {}
}
]
"#,
);
expect![[r#"
Err(
Unions cannot contain duplicate types,
)
"#]]
.assert_debug_eq(&s);
let s = Schema::parse_str(
r#"[
"null",
{
"type": "array",
"items" : "long",
"default": {}
},
{
"type": "array",
"items" : "int",
"default": {}
}
]
"#,
);
expect![[r#"
Err(
Unions cannot contain duplicate types,
)
"#]]
.assert_debug_eq(&s);

// union in union
let s = Schema::parse_str(r#"["int", ["null", "int"]]"#);
expect![[r#"
Err(
Unions may not directly contain a union,
)
"#]]
.assert_debug_eq(&s);

// logical type
let s = Schema::parse_str(r#"["null", {"type":"string","logicalType":"uuid"}]"#).unwrap();
expect![[r#"
Union(
Expand Down
102 changes: 58 additions & 44 deletions src/connector/codec/src/decoder/avro/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,10 @@ fn avro_type_mapping(
DataType::List(Box::new(item_type))
}
Schema::Union(union_schema) => {
// Unions may not contain more than one schema with the same type, except for the named types record, fixed and enum.
// Note: Unions may not immediately contain other unions. So a `null` must represent a top-level null.
// e.g., ["null", ["null", "string"]] is not allowed

// Note: Unions may not contain more than one schema with the same type, except for the named types record, fixed and enum.
// https://avro.apache.org/docs/1.11.1/specification/_print/#unions
debug_assert!(
union_schema
Expand All @@ -222,61 +225,21 @@ fn avro_type_mapping(
// Note: Avro union's variant tag is type name, not field name (unlike Rust enum, or Protobuf oneof).

// XXX: do we need to introduce union.handling.mode?

let (fields, field_names) = union_schema
.variants()
.iter()
// null will mean the whole struct is null
.filter(|variant| !matches!(variant, &&Schema::Null))
.map(|variant| {
avro_type_mapping(variant, map_handling).map(|t| {
let name = match variant {
Schema::Null => unreachable!(),
Schema::Boolean => "boolean".to_string(),
Schema::Int => "integer".to_string(),
Schema::Long => "bigint".to_string(),
Schema::Float => "real".to_string(),
Schema::Double => "double precision".to_string(),
Schema::Bytes => "bytea".to_string(),
Schema::String => "text".to_string(),
Schema::Array(_) => "array".to_string(),
Schema::Map(_) =>"map".to_string(),
Schema::Union(_) => "union".to_string(),
// For logical types, should we use the real type or the logical type as the field name?
//
// Example about the representation:
// schema: ["null", {"type":"string","logicalType":"uuid"}]
// data: {"string": "67e55044-10b1-426f-9247-bb680e5fe0c8"}
//
// Note: for union with logical type AND the real type, e.g., ["string", {"type":"string","logicalType":"uuid"}]
// In this case, the uuid cannot be constructed. Some library
// https://issues.apache.org/jira/browse/AVRO-2380
Schema::Uuid => "uuid".to_string(),
Schema::Decimal(_) => todo!(),
Schema::Date => "date".to_string(),
Schema::TimeMillis => "time without time zone".to_string(),
Schema::TimeMicros => "time without time zone".to_string(),
Schema::TimestampMillis => "timestamp without time zone".to_string(),
Schema::TimestampMicros => "timestamp without time zone".to_string(),
Schema::LocalTimestampMillis => "timestamp without time zone".to_string(),
Schema::LocalTimestampMicros => "timestamp without time zone".to_string(),
Schema::Duration => "interval".to_string(),
Schema::Enum(_)
| Schema::Ref { name: _ }
| Schema::Fixed(_) => todo!(),
| Schema::Record(_) => variant.name().unwrap().fullname(None), // XXX: Is the namespace correct here?
};
let name = avro_schema_to_struct_field_name(variant);
(t, name)
})
})
.process_results(|it| it.unzip::<_, _, Vec<_>, Vec<_>>())
.context("failed to convert Avro union to struct")?;

DataType::new_struct(fields, field_names);

bail!(
"unsupported Avro type, only unions like [null, T] is supported: {:?}",
schema
);
DataType::new_struct(fields, field_names)
}
}
}
Expand Down Expand Up @@ -351,3 +314,54 @@ fn supported_avro_to_json_type(schema: &Schema) -> bool {
| Schema::Union(_) => false,
}
}

/// The field name when converting Avro union type to RisingWave struct type.
pub(super) fn avro_schema_to_struct_field_name(schema: &Schema) -> String {
match schema {
Schema::Null => unreachable!(),
Schema::Union(_) => unreachable!(),
// Primitive types
Schema::Boolean => "boolean".to_string(),
Schema::Int => "int".to_string(),
Schema::Long => "long".to_string(),
Schema::Float => "float".to_string(),
Schema::Double => "double".to_string(),
Schema::Bytes => "bytes".to_string(),
Schema::String => "string".to_string(),
// Unnamed Complex types
Schema::Array(_) => "array".to_string(),
Schema::Map(_) => "map".to_string(),
// Named Complex types
// TODO: Verify is the namespace correct here
Schema::Enum(_) | Schema::Ref { name: _ } | Schema::Fixed(_) => todo!(),
Schema::Record(_) => schema.name().unwrap().fullname(None),
// Logical types
// XXX: should we use the real type or the logical type as the field name?
// It seems not to matter much, as we always have the index of the field when we get a Union Value.
//
// Currently choose the logical type because it might be more user-friendly.
//
// Example about the representation:
// schema: ["null", {"type":"string","logicalType":"uuid"}]
// data: {"string": "67e55044-10b1-426f-9247-bb680e5fe0c8"}
//
// Note: for union with logical type AND the real type, e.g., ["string", {"type":"string","logicalType":"uuid"}]
// In this case, the uuid cannot be constructed.
// Actually this should be an invalid schema according to the spec. https://issues.apache.org/jira/browse/AVRO-2380
// But some library like Python and Rust both allow it. See `risingwave_connector_codec::decoder::avro::tests::test_avro_lib_union`
Schema::Uuid => "uuid".to_string(),
Schema::Decimal(_) => "decimal".to_string(),
Schema::Date => "date".to_string(),
// Note: in Avro, the name style is "time-millis", etc.
// But in RisingWave (Postgres), it will require users to use quotes, i.e.,
// select (struct)."time-millis", (struct).time_millies from t;
// The latter might be more user-friendly.
Schema::TimeMillis => "time_millis".to_string(),
Schema::TimeMicros => "time_micros".to_string(),
Schema::TimestampMillis => "timestamp_millis".to_string(),
Schema::TimestampMicros => "timestamp_micros".to_string(),
Schema::LocalTimestampMillis => "local_timestamp_millis".to_string(),
Schema::LocalTimestampMicros => "local_timestamp_micros".to_string(),
Schema::Duration => "duration".to_string(),
}
}
17 changes: 15 additions & 2 deletions src/connector/codec/src/decoder/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,18 +45,31 @@ pub enum AccessError {
pub type AccessResult<T = Datum> = std::result::Result<T, AccessError>;

/// Access to a field in the data structure. Created by `AccessBuilder`.
///
/// It's the `ENCODE ...` part in `FORMAT ... ENCODE ...`
pub trait Access {
/// Accesses `path` in the data structure (*parsed* Avro/JSON/Protobuf data),
/// and then converts it to RisingWave `Datum`.
///
/// `type_expected` might or might not be used during the conversion depending on the implementation.
///
/// # Path
///
/// We usually expect the data is a record (struct), and `path` represents field path.
/// We usually expect the data (`Access` instance) is a record (struct), and `path` represents field path.
/// The data (or part of the data) represents the whole row (`Vec<Datum>`),
/// and we use different `path` to access one column at a time.
///
/// e.g., for Avro, we access `["col_name"]`; for Debezium Avro, we access `["before", "col_name"]`.
/// TODO: the meaning of `path` is a little confusing and maybe over-abstracted.
/// `access` does not need to serve arbitrarily deep `path` access, but just "top-level" access.
/// The API creates an illusion that arbitrary access is supported, but it's not.
/// Perhapts we should separate out another trait like `ToDatum`,
/// which only does type mapping, without caring about the path. And `path` itself is only an `enum` instead of `&[&str]`.
///
/// What `path` to access is decided by the CDC layer, i.e., the `FORMAT ...` part (`ChangeEvent`).
/// e.g.,
/// - `DebeziumChangeEvent` accesses `["before", "col_name"]` for value, `["op"]` for op type.
/// - `MaxwellChangeEvent` accesses `["data", "col_name"]` for value, `["type"]` for op type.
/// - In the simplest case, for `FORMAT PLAIN/UPSERT` (`KvEvent`), they just access `["col_name"]` for value, and op type is derived.
///
/// # Returns
///
Expand Down
Loading

0 comments on commit 011b18c

Please sign in to comment.