diff --git a/src/connector/Cargo.toml b/src/connector/Cargo.toml index 707fb24651995..df51666e30790 100644 --- a/src/connector/Cargo.toml +++ b/src/connector/Cargo.toml @@ -83,6 +83,7 @@ paste = "1" prometheus = { version = "0.13", features = ["process"] } prost = { version = "0.12", features = ["no-recursion-limit"] } prost-reflect = "0.12" +prost-types = "0.12" protobuf-native = "0.2.1" pulsar = { version = "6.0", default-features = false, features = [ "tokio-runtime", @@ -138,7 +139,6 @@ workspace-hack = { path = "../workspace-hack" } [dev-dependencies] criterion = { workspace = true, features = ["async_tokio", "async"] } -prost-types = "0.12" rand = "0.8" tempfile = "3" tracing-test = "0.2" diff --git a/src/connector/src/lib.rs b/src/connector/src/lib.rs index 1fba061555f44..4dd1691b00f89 100644 --- a/src/connector/src/lib.rs +++ b/src/connector/src/lib.rs @@ -32,6 +32,7 @@ #![feature(impl_trait_in_assoc_type)] #![feature(iter_from_generator)] #![feature(if_let_guard)] +#![feature(iterator_try_collect)] use std::time::Duration; diff --git a/src/connector/src/sink/encoder/avro.rs b/src/connector/src/sink/encoder/avro.rs new file mode 100644 index 0000000000000..0b26d202409b6 --- /dev/null +++ b/src/connector/src/sink/encoder/avro.rs @@ -0,0 +1,896 @@ +// Copyright 2023 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +use apache_avro::schema::Schema as AvroSchema; +use apache_avro::types::{Record, Value}; +use apache_avro::Writer; +use risingwave_common::catalog::Schema; +use risingwave_common::row::Row; +use risingwave_common::types::{DataType, DatumRef, ScalarRefImpl, StructType}; +use risingwave_common::util::iter_util::{ZipEqDebug, ZipEqFast}; + +use super::{FieldEncodeError, Result as SinkResult, RowEncoder, SerTo}; + +type Result = std::result::Result; + +pub struct AvroEncoder<'a> { + schema: &'a Schema, + col_indices: Option<&'a [usize]>, + avro_schema: &'a AvroSchema, +} + +impl<'a> AvroEncoder<'a> { + pub fn new( + schema: &'a Schema, + col_indices: Option<&'a [usize]>, + avro_schema: &'a AvroSchema, + ) -> SinkResult { + match col_indices { + Some(col_indices) => validate_fields( + col_indices.iter().map(|idx| { + let f = &schema[*idx]; + (f.name.as_str(), &f.data_type) + }), + avro_schema, + )?, + None => validate_fields( + schema + .fields + .iter() + .map(|f| (f.name.as_str(), &f.data_type)), + avro_schema, + )?, + }; + + Ok(Self { + schema, + col_indices, + avro_schema, + }) + } +} + +impl<'a> RowEncoder for AvroEncoder<'a> { + type Output = (Record<'a>, &'a AvroSchema); + + fn schema(&self) -> &Schema { + self.schema + } + + fn col_indices(&self) -> Option<&[usize]> { + self.col_indices + } + + fn encode_cols( + &self, + row: impl Row, + col_indices: impl Iterator, + ) -> SinkResult { + let record = encode_fields( + col_indices.map(|idx| { + let f = &self.schema[idx]; + ((f.name.as_str(), &f.data_type), row.datum_at(idx)) + }), + self.avro_schema, + )?; + Ok((record, self.avro_schema)) + } +} + +impl<'a> SerTo> for (Record<'a>, &'a AvroSchema) { + fn ser_to(self) -> SinkResult> { + let mut w = Writer::new(self.1, Vec::new()); + w.append(self.0) + .and_then(|_| w.into_inner()) + .map_err(|e| crate::sink::SinkError::Encode(e.to_string())) + } +} + +enum OptIdx { + NotUnion, + Single, + NullLeft, + NullRight, +} + +/// A trait that assists code reuse between 
`validate` and `encode`. +/// * For `validate`, the inputs are (RisingWave type, Avro type). +/// * For `encode`, the inputs are (RisingWave type, RisingWave data, Avro type). +/// +/// Thus we impl [`MaybeData`] for both [`()`] and [`DatumRef`]. +trait MaybeData: std::fmt::Debug { + type Out; + + fn on_base(self, f: impl FnOnce(ScalarRefImpl<'_>) -> Result) -> Result; + + /// Switch to `RecordSchema` after #12562 + fn on_struct(self, st: &StructType, avro: &AvroSchema) -> Result; + + fn on_list(self, elem: &DataType, avro: &AvroSchema) -> Result; + + fn handle_union(out: Self::Out, opt_idx: OptIdx) -> Result; +} + +impl MaybeData for () { + type Out = (); + + fn on_base(self, _: impl FnOnce(ScalarRefImpl<'_>) -> Result) -> Result { + Ok(self) + } + + fn on_struct(self, st: &StructType, avro: &AvroSchema) -> Result { + validate_fields(st.iter(), avro) + } + + fn on_list(self, elem: &DataType, avro: &AvroSchema) -> Result { + encode_field(elem, (), avro) + } + + fn handle_union(out: Self::Out, _: OptIdx) -> Result { + Ok(out) + } +} + +impl MaybeData for DatumRef<'_> { + type Out = Value; + + fn on_base(self, f: impl FnOnce(ScalarRefImpl<'_>) -> Result) -> Result { + match self { + Some(s) => f(s), + None => Ok(Value::Null), + } + } + + fn on_struct(self, st: &StructType, avro: &AvroSchema) -> Result { + let d = match self { + Some(s) => s.into_struct(), + None => return Ok(Value::Null), + }; + let record = encode_fields(st.iter().zip_eq_debug(d.iter_fields_ref()), avro)?; + Ok(record.into()) + } + + fn on_list(self, elem: &DataType, avro: &AvroSchema) -> Result { + let d = match self { + Some(s) => s.into_list(), + None => return Ok(Value::Null), + }; + let vs = d + .iter() + .map(|d| encode_field(elem, d, avro)) + .try_collect()?; + Ok(Value::Array(vs)) + } + + fn handle_union(out: Self::Out, opt_idx: OptIdx) -> Result { + use OptIdx::*; + + match out == Value::Null { + true => { + let ni = match opt_idx { + NotUnion | Single => { + return 
Err(FieldEncodeError::new("found null but required")) + } + NullLeft => 0, + NullRight => 1, + }; + Ok(Value::Union(ni, out.into())) + } + false => { + let vi = match opt_idx { + NotUnion => return Ok(out), + NullLeft => 1, + Single | NullRight => 0, + }; + Ok(Value::Union(vi, out.into())) + } + } + } +} + +fn validate_fields<'rw>( + rw_fields: impl Iterator, + avro: &AvroSchema, +) -> Result<()> { + let AvroSchema::Record { fields, lookup, .. } = avro else { + return Err(FieldEncodeError::new(format!( + "expect avro record but got {}", + avro.canonical_form(), + ))); + }; + let mut present = vec![false; fields.len()]; + for (name, t) in rw_fields { + let Some(&idx) = lookup.get(name) else { + return Err(FieldEncodeError::new("field not in avro").with_name(name)); + }; + present[idx] = true; + let avro_field = &fields[idx]; + encode_field(t, (), &avro_field.schema).map_err(|e| e.with_name(name))?; + } + for (p, avro_field) in present.into_iter().zip_eq_fast(fields) { + if p { + continue; + } + if !avro_field.is_nullable() { + return Err( + FieldEncodeError::new("field not present but required").with_name(&avro_field.name) + ); + } + } + Ok(()) +} + +fn encode_fields<'avro, 'rw>( + fields_with_datums: impl Iterator)>, + schema: &'avro AvroSchema, +) -> Result> { + let mut record = Record::new(schema).unwrap(); + let AvroSchema::Record { fields, lookup, .. } = schema else { + unreachable!() + }; + let mut present = vec![false; fields.len()]; + for ((name, t), d) in fields_with_datums { + let idx = lookup[name]; + present[idx] = true; + let avro_field = &fields[idx]; + let value = encode_field(t, d, &avro_field.schema).map_err(|e| e.with_name(name))?; + record.put(name, value); + } + // Unfortunately, the upstream `apache_avro` does not handle missing fields as nullable correctly. + // The correct encoding is `Value::Union(null_index, Value::Null)` but it simply writes `Value::Null`. 
+ for (p, avro_field) in present.into_iter().zip_eq_fast(fields) { + if p { + continue; + } + let AvroSchema::Union(u) = &avro_field.schema else { + unreachable!() + }; + // We could have saved null index of each field during [`validate_fields`] to avoid repeated lookup. + // But in most cases it is the 0th. + // Alternatively, we can simplify by enforcing the best practice of `null at 0th`. + let ni = u + .variants() + .iter() + .position(|a| a == &AvroSchema::Null) + .unwrap(); + record.put( + &avro_field.name, + Value::Union(ni.try_into().unwrap(), Value::Null.into()), + ); + } + Ok(record) +} + +/// Handles both `validate` (without actual data) and `encode`. +/// See [`MaybeData`] for more info. +fn encode_field( + data_type: &DataType, + maybe: D, + expected: &AvroSchema, +) -> Result { + use risingwave_common::types::Interval; + + let no_match_err = || { + Err(FieldEncodeError::new(format!( + "cannot encode {} column as {} field", + data_type, + expected.canonical_form() + ))) + }; + + if let AvroSchema::Ref { .. } = expected { + return Err(FieldEncodeError::new("avro name ref unsupported yet")); + } + + // For now, we only support optional single type, rather than general union. + // For example, how do we encode int16 into avro `["int", "long"]`? 
+ let (inner, opt_idx) = match expected { + AvroSchema::Union(union) => match union.variants() { + [] => return no_match_err(), + [one] => (one, OptIdx::Single), + [AvroSchema::Null, r] => (r, OptIdx::NullLeft), + [l, AvroSchema::Null] => (l, OptIdx::NullRight), + _ => return no_match_err(), + }, + _ => (expected, OptIdx::NotUnion), + }; + + let value = match &data_type { + // Group A: perfect match between RisingWave types and Avro types + DataType::Boolean => match inner { + AvroSchema::Boolean => maybe.on_base(|s| Ok(Value::Boolean(s.into_bool())))?, + _ => return no_match_err(), + }, + DataType::Varchar => match inner { + AvroSchema::String => maybe.on_base(|s| Ok(Value::String(s.into_utf8().into())))?, + _ => return no_match_err(), + }, + DataType::Bytea => match inner { + AvroSchema::Bytes => maybe.on_base(|s| Ok(Value::Bytes(s.into_bytea().into())))?, + _ => return no_match_err(), + }, + DataType::Float32 => match inner { + AvroSchema::Float => maybe.on_base(|s| Ok(Value::Float(s.into_float32().into())))?, + _ => return no_match_err(), + }, + DataType::Float64 => match inner { + AvroSchema::Double => maybe.on_base(|s| Ok(Value::Double(s.into_float64().into())))?, + _ => return no_match_err(), + }, + DataType::Int32 => match inner { + AvroSchema::Int => maybe.on_base(|s| Ok(Value::Int(s.into_int32())))?, + _ => return no_match_err(), + }, + DataType::Int64 => match inner { + AvroSchema::Long => maybe.on_base(|s| Ok(Value::Long(s.into_int64())))?, + _ => return no_match_err(), + }, + DataType::Struct(st) => match inner { + AvroSchema::Record { .. 
} => maybe.on_struct(st, inner)?, + _ => return no_match_err(), + }, + DataType::List(elem) => match inner { + AvroSchema::Array(avro_elem) => maybe.on_list(elem, avro_elem)?, + _ => return no_match_err(), + }, + // Group B: match between RisingWave types and Avro logical types + DataType::Timestamptz => match inner { + AvroSchema::TimestampMicros => maybe.on_base(|s| { + Ok(Value::TimestampMicros( + s.into_timestamptz().timestamp_micros(), + )) + })?, + AvroSchema::TimestampMillis => maybe.on_base(|s| { + Ok(Value::TimestampMillis( + s.into_timestamptz().timestamp_millis(), + )) + })?, + _ => return no_match_err(), + }, + DataType::Timestamp => return no_match_err(), + DataType::Date => match inner { + AvroSchema::Date => { + maybe.on_base(|s| Ok(Value::Date(s.into_date().get_nums_days_unix_epoch())))? + } + _ => return no_match_err(), + }, + DataType::Time => match inner { + AvroSchema::TimeMicros => { + maybe.on_base(|s| Ok(Value::TimeMicros(Interval::from(s.into_time()).usecs())))? + } + AvroSchema::TimeMillis => maybe.on_base(|s| { + Ok(Value::TimeMillis( + (Interval::from(s.into_time()).usecs() / 1000) + .try_into() + .unwrap(), + )) + })?, + _ => return no_match_err(), + }, + DataType::Interval => match inner { + AvroSchema::Duration => maybe.on_base(|s| { + use apache_avro::{Days, Duration, Millis, Months}; + let iv = s.into_interval(); + + let overflow = |_| FieldEncodeError::new(format!("{iv} overflows avro duration")); + + Ok(Value::Duration(Duration::new( + Months::new(iv.months().try_into().map_err(overflow)?), + Days::new(iv.days().try_into().map_err(overflow)?), + Millis::new((iv.usecs() / 1000).try_into().map_err(overflow)?), + ))) + })?, + _ => return no_match_err(), + }, + // Group C: experimental + DataType::Int16 => return no_match_err(), + DataType::Decimal => return no_match_err(), + DataType::Jsonb => return no_match_err(), + // Group D: unsupported + DataType::Serial | DataType::Int256 => { + return no_match_err(); + } + }; + + 
D::handle_union(value, opt_idx) +} + +#[cfg(test)] +mod tests { + use risingwave_common::catalog::Field; + use risingwave_common::row::OwnedRow; + use risingwave_common::types::{ + Date, Datum, Interval, ListValue, ScalarImpl, StructValue, Time, Timestamptz, ToDatumRef, + }; + + use super::*; + + fn test_ok(t: &DataType, d: Datum, avro: &str, expected: Value) { + let avro_schema = AvroSchema::parse_str(avro).unwrap(); + let actual = encode_field(t, d.to_datum_ref(), &avro_schema).unwrap(); + assert_eq!(actual, expected); + } + + fn test_err(t: &DataType, d: D, avro: &str, expected: &str) + where + D::Out: std::fmt::Debug, + { + let avro_schema = AvroSchema::parse_str(avro).unwrap(); + let err = encode_field(t, d, &avro_schema).unwrap_err(); + assert_eq!(err.to_string(), expected); + } + + #[test] + fn test_encode_avro_ok() { + test_ok( + &DataType::Boolean, + Some(ScalarImpl::Bool(false)), + r#""boolean""#, + Value::Boolean(false), + ); + + test_ok( + &DataType::Varchar, + Some(ScalarImpl::Utf8("RisingWave".into())), + r#""string""#, + Value::String("RisingWave".into()), + ); + + test_ok( + &DataType::Bytea, + Some(ScalarImpl::Bytea([0xbe, 0xef].into())), + r#""bytes""#, + Value::Bytes([0xbe, 0xef].into()), + ); + + test_ok( + &DataType::Float32, + Some(ScalarImpl::Float32(3.5f32.into())), + r#""float""#, + Value::Float(3.5f32), + ); + + test_ok( + &DataType::Float64, + Some(ScalarImpl::Float64(4.25f64.into())), + r#""double""#, + Value::Double(4.25f64), + ); + + test_ok( + &DataType::Int32, + Some(ScalarImpl::Int32(16)), + r#""int""#, + Value::Int(16), + ); + + test_ok( + &DataType::Int64, + Some(ScalarImpl::Int64(std::i64::MAX)), + r#""long""#, + Value::Long(i64::MAX), + ); + + let tstz = "2018-01-26T18:30:09.453Z".parse().unwrap(); + test_ok( + &DataType::Timestamptz, + Some(ScalarImpl::Timestamptz(tstz)), + r#"{"type": "long", "logicalType": "timestamp-micros"}"#, + Value::TimestampMicros(tstz.timestamp_micros()), + ); + test_ok( + &DataType::Timestamptz, + 
Some(ScalarImpl::Timestamptz(tstz)), + r#"{"type": "long", "logicalType": "timestamp-millis"}"#, + Value::TimestampMillis(tstz.timestamp_millis()), + ); + + test_ok( + &DataType::Date, + Some(ScalarImpl::Date(Date::from_ymd_uncheck(1970, 1, 2))), + r#"{"type": "int", "logicalType": "date"}"#, + Value::Date(1), + ); + + let tm = Time::from_num_seconds_from_midnight_uncheck(1000, 0); + test_ok( + &DataType::Time, + Some(ScalarImpl::Time(tm)), + r#"{"type": "long", "logicalType": "time-micros"}"#, + Value::TimeMicros(1000 * 1_000_000), + ); + test_ok( + &DataType::Time, + Some(ScalarImpl::Time(tm)), + r#"{"type": "int", "logicalType": "time-millis"}"#, + Value::TimeMillis(1000 * 1000), + ); + + test_ok( + &DataType::Interval, + Some(ScalarImpl::Interval(Interval::from_month_day_usec( + 13, 2, 1000000, + ))), + // https://github.com/apache/avro/pull/2283 + // r#"{"type": "fixed", "name": "Duration", "size": 12, "logicalType": "duration"}"#, + r#"{"type": {"type": "fixed", "name": "Duration", "size": 12}, "logicalType": "duration"}"#, + Value::Duration(apache_avro::Duration::new( + apache_avro::Months::new(13), + apache_avro::Days::new(2), + apache_avro::Millis::new(1000), + )), + ) + } + + #[test] + fn test_encode_avro_err() { + test_err( + &DataType::Struct(StructType::new(vec![ + ( + "p", + DataType::Struct(StructType::new(vec![ + ("x", DataType::Int32), + ("y", DataType::Int32), + ])), + ), + ( + "q", + DataType::Struct(StructType::new(vec![ + ("x", DataType::Int32), + ("y", DataType::Int32), + ])), + ), + ])), + Some(ScalarImpl::Struct(StructValue::new(vec![ + Some(ScalarImpl::Struct(StructValue::new(vec![ + Some(ScalarImpl::Int32(-2)), + Some(ScalarImpl::Int32(-1)), + ]))), + Some(ScalarImpl::Struct(StructValue::new(vec![ + Some(ScalarImpl::Int32(2)), + Some(ScalarImpl::Int32(1)), + ]))), + ]))) + .to_datum_ref(), + r#"{ + "type": "record", + "name": "Segment", + "fields": [ + { + "name": "p", + "type": { + "type": "record", + "name": "Point", + "fields": [ + { + 
"name": "x", + "type": "int" + }, + { + "name": "y", + "type": "int" + } + ] + } + }, + { + "name": "q", + "type": "Point" + } + ] + }"#, + "encode q error: avro name ref unsupported yet", + ); + + test_err( + &DataType::Interval, + Some(ScalarRefImpl::Interval(Interval::from_month_day_usec( + -1, + -1, + i64::MAX, + ))), + // https://github.com/apache/avro/pull/2283 + r#"{"type": {"type": "fixed", "name": "Duration", "size": 12}, "logicalType": "duration"}"#, + "encode error: -1 mons -1 days +2562047788:00:54.775807 overflows avro duration", + ); + + let avro_schema = AvroSchema::parse_str( + r#"{"type": "record", "name": "Root", "fields": [ + {"name": "f0", "type": "int"} + ]}"#, + ) + .unwrap(); + let mut record = Record::new(&avro_schema).unwrap(); + record.put("f0", Value::String("2".into())); + let res: SinkResult> = (record, &avro_schema).ser_to(); + assert_eq!(res.unwrap_err().to_string(), "Encode error: Value does not match schema: Reason: Unsupported value-schema combination"); + } + + #[test] + fn test_encode_avro_record() { + let avro_schema = AvroSchema::parse_str( + r#"{ + "type": "record", + "name": "Root", + "fields": [ + {"name": "req", "type": "int"}, + {"name": "opt", "type": ["null", "long"]} + ] + }"#, + ) + .unwrap(); + + let schema = Schema::new(vec![ + Field::with_name(DataType::Int64, "opt"), + Field::with_name(DataType::Int32, "req"), + ]); + let row = OwnedRow::new(vec![ + Some(ScalarImpl::Int64(31)), + Some(ScalarImpl::Int32(15)), + ]); + let encoder = AvroEncoder::new(&schema, None, &avro_schema).unwrap(); + let actual = encoder.encode(row).unwrap(); + assert_eq!( + Value::from(actual.0), + Value::Record(vec![ + ("req".into(), Value::Int(15)), + ("opt".into(), Value::Union(1, Value::Long(31).into())), + ]) + ); + + let schema = Schema::new(vec![Field::with_name(DataType::Int32, "req")]); + let row = OwnedRow::new(vec![Some(ScalarImpl::Int32(15))]); + let encoder = AvroEncoder::new(&schema, None, &avro_schema).unwrap(); + let actual = 
encoder.encode(row).unwrap(); + assert_eq!( + Value::from(actual.0), + Value::Record(vec![ + ("req".into(), Value::Int(15)), + ("opt".into(), Value::Union(0, Value::Null.into())), + ]) + ); + + let schema = Schema::new(vec![Field::with_name(DataType::Int64, "opt")]); + let Err(err) = AvroEncoder::new(&schema, None, &avro_schema) else { + panic!() + }; + assert_eq!( + err.to_string(), + "Encode error: encode req error: field not present but required" + ); + + let schema = Schema::new(vec![ + Field::with_name(DataType::Int64, "opt"), + Field::with_name(DataType::Int32, "req"), + Field::with_name(DataType::Varchar, "extra"), + ]); + let Err(err) = AvroEncoder::new(&schema, None, &avro_schema) else { + panic!() + }; + assert_eq!( + err.to_string(), + "Encode error: encode extra error: field not in avro" + ); + + let avro_schema = AvroSchema::parse_str(r#"["null", "long"]"#).unwrap(); + let schema = Schema::new(vec![Field::with_name(DataType::Int64, "opt")]); + let Err(err) = AvroEncoder::new(&schema, None, &avro_schema) else { + panic!() + }; + assert_eq!( + err.to_string(), + r#"Encode error: encode error: expect avro record but got ["null","long"]"# + ); + + test_err( + &DataType::Struct(StructType::new(vec![("f0", DataType::Boolean)])), + (), + r#"{"type": "record", "name": "T", "fields": [{"name": "f0", "type": "int"}]}"#, + "encode f0 error: cannot encode boolean column as \"int\" field", + ); + } + + #[test] + fn test_encode_avro_array() { + let avro_schema = r#"{ + "type": "array", + "items": "int" + }"#; + + test_ok( + &DataType::List(DataType::Int32.into()), + Some(ScalarImpl::List(ListValue::new(vec![ + Some(ScalarImpl::Int32(4)), + Some(ScalarImpl::Int32(5)), + ]))), + avro_schema, + Value::Array(vec![Value::Int(4), Value::Int(5)]), + ); + + test_err( + &DataType::List(DataType::Int32.into()), + Some(ScalarImpl::List(ListValue::new(vec![ + Some(ScalarImpl::Int32(4)), + None, + ]))) + .to_datum_ref(), + avro_schema, + "encode error: found null but required", 
+ ); + + test_ok( + &DataType::List(DataType::Int32.into()), + Some(ScalarImpl::List(ListValue::new(vec![ + Some(ScalarImpl::Int32(4)), + None, + ]))), + r#"{ + "type": "array", + "items": ["null", "int"] + }"#, + Value::Array(vec![ + Value::Union(1, Value::Int(4).into()), + Value::Union(0, Value::Null.into()), + ]), + ); + + test_ok( + &DataType::List(DataType::List(DataType::Int32.into()).into()), + Some(ScalarImpl::List(ListValue::new(vec![ + Some(ScalarImpl::List(ListValue::new(vec![ + Some(ScalarImpl::Int32(26)), + Some(ScalarImpl::Int32(29)), + ]))), + Some(ScalarImpl::List(ListValue::new(vec![ + Some(ScalarImpl::Int32(46)), + Some(ScalarImpl::Int32(49)), + ]))), + ]))), + r#"{ + "type": "array", + "items": { + "type": "array", + "items": "int" + } + }"#, + Value::Array(vec![ + Value::Array(vec![Value::Int(26), Value::Int(29)]), + Value::Array(vec![Value::Int(46), Value::Int(49)]), + ]), + ); + + test_err( + &DataType::List(DataType::Boolean.into()), + (), + r#"{"type": "array", "items": "int"}"#, + "encode error: cannot encode boolean column as \"int\" field", + ); + } + + #[test] + fn test_encode_avro_union() { + let t = &DataType::Timestamptz; + let datum = Some(ScalarImpl::Timestamptz(Timestamptz::from_micros(1500))); + let opt_micros = r#"["null", {"type": "long", "logicalType": "timestamp-micros"}]"#; + let opt_millis = r#"["null", {"type": "long", "logicalType": "timestamp-millis"}]"#; + let both = r#"[{"type": "long", "logicalType": "timestamp-millis"}, {"type": "long", "logicalType": "timestamp-micros"}]"#; + let empty = "[]"; + let one = r#"[{"type": "long", "logicalType": "timestamp-millis"}]"#; + let right = r#"[{"type": "long", "logicalType": "timestamp-millis"}, "null"]"#; + + test_ok( + t, + datum.clone(), + opt_micros, + Value::Union(1, Value::TimestampMicros(1500).into()), + ); + test_ok(t, None, opt_micros, Value::Union(0, Value::Null.into())); + test_ok( + t, + datum.clone(), + opt_millis, + Value::Union(1, 
Value::TimestampMillis(1).into()), + ); + test_ok(t, None, opt_millis, Value::Union(0, Value::Null.into())); + + test_err( + t, + datum.to_datum_ref(), + both, + r#"encode error: cannot encode timestamp with time zone column as [{"type":"long","logicalType":"timestamp-millis"},{"type":"long","logicalType":"timestamp-micros"}] field"#, + ); + + test_err( + t, + datum.to_datum_ref(), + empty, + "encode error: cannot encode timestamp with time zone column as [] field", + ); + + test_ok( + t, + datum.clone(), + one, + Value::Union(0, Value::TimestampMillis(1).into()), + ); + test_err(t, None, one, "encode error: found null but required"); + + test_ok( + t, + datum.clone(), + right, + Value::Union(0, Value::TimestampMillis(1).into()), + ); + test_ok(t, None, right, Value::Union(1, Value::Null.into())); + } + + /// This just demonstrates bugs of the upstream [`apache_avro`], rather than our encoder. + /// The encoder is not using these buggy calls and is already tested above. + #[test] + fn test_encode_avro_lib_bug() { + use apache_avro::Reader; + + // a record with 2 optional int fields + let avro_schema = AvroSchema::parse_str( + r#"{ + "type": "record", + "name": "Root", + "fields": [ + { + "name": "f0", + "type": ["null", "int"] + }, + { + "name": "f1", + "type": ["null", "int"] + } + ] + }"#, + ) + .unwrap(); + + let mut writer = Writer::new(&avro_schema, Vec::new()); + let mut record = Record::new(writer.schema()).unwrap(); + // f0 omitted, f1 = Int(3) + record.put("f1", Value::Int(3)); + writer.append(record).unwrap(); + let encoded = writer.into_inner().unwrap(); + // writing produced no error, but read fails + let reader = Reader::new(encoded.as_slice()).unwrap(); + for value in reader { + assert_eq!( + value.unwrap_err().to_string(), + "Union index 3 out of bounds: 2" + ); + } + + let mut writer = Writer::new(&avro_schema, Vec::new()); + let mut record = Record::new(writer.schema()).unwrap(); + // f0 omitted, f1 = Union(1, Int(3)) + record.put("f1", 
Value::Union(1, Value::Int(3).into())); + writer.append(record).unwrap(); + let encoded = writer.into_inner().unwrap(); + // writing produced no error, but read returns wrong value + let reader = Reader::new(encoded.as_slice()).unwrap(); + for value in reader { + assert_eq!( + value.unwrap(), + Value::Record(vec![ + ("f0".into(), Value::Union(1, Value::Int(3).into())), + ("f1".into(), Value::Union(0, Value::Null.into())), + ]) + ); + } + } +} diff --git a/src/connector/src/sink/encoder/json.rs b/src/connector/src/sink/encoder/json.rs index 6add09b2cb86e..25d142ad8357b 100644 --- a/src/connector/src/sink/encoder/json.rs +++ b/src/connector/src/sink/encoder/json.rs @@ -90,7 +90,7 @@ impl RowEncoder for JsonEncoder { self.timestamp_handling_mode, &self.custom_json_type, ) - .map_err(|e| SinkError::JsonParse(e.to_string()))?; + .map_err(|e| SinkError::Encode(e.to_string()))?; mappings.insert(key, value); } Ok(mappings) diff --git a/src/connector/src/sink/encoder/mod.rs b/src/connector/src/sink/encoder/mod.rs index 1807fd1d421e8..b0a13606948dd 100644 --- a/src/connector/src/sink/encoder/mod.rs +++ b/src/connector/src/sink/encoder/mod.rs @@ -19,9 +19,13 @@ use risingwave_common::row::Row; use crate::sink::Result; +mod avro; mod json; +mod proto; +pub use avro::AvroEncoder; pub use json::JsonEncoder; +pub use proto::ProtoEncoder; /// Encode a row of a relation into /// * an object in json @@ -57,7 +61,9 @@ pub trait RowEncoder { /// /// This is like `TryInto` but allows us to `impl> SerTo> for T`. /// -/// Shall we consider `impl serde::Serialize` in the future? +/// Note that `serde` does not fit here because its data model does not contain logical types. +/// For example, although `chrono::DateTime` implements `Serialize`, +/// it produces avro String rather than avro `TimestampMicros`. 
pub trait SerTo { fn ser_to(self) -> Result; } @@ -86,3 +92,42 @@ pub enum CustomJsonType { Doris(HashMap), None, } + +#[derive(Debug)] +struct FieldEncodeError { + message: String, + rev_path: Vec, +} + +impl FieldEncodeError { + fn new(message: impl std::fmt::Display) -> Self { + Self { + message: message.to_string(), + rev_path: vec![], + } + } + + fn with_name(mut self, name: &str) -> Self { + self.rev_path.push(name.into()); + self + } +} + +impl std::fmt::Display for FieldEncodeError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + use itertools::Itertools; + + write!( + f, + "encode {} error: {}", + self.rev_path.iter().rev().join("."), + self.message + ) + } +} + +impl From for super::SinkError { + fn from(value: FieldEncodeError) -> Self { + Self::Encode(value.to_string()) + } +} diff --git a/src/connector/src/sink/encoder/proto.rs b/src/connector/src/sink/encoder/proto.rs new file mode 100644 index 0000000000000..106621fbd240b --- /dev/null +++ b/src/connector/src/sink/encoder/proto.rs @@ -0,0 +1,473 @@ +// Copyright 2023 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +use bytes::Bytes; +use prost::Message; +use prost_reflect::{ + DynamicMessage, FieldDescriptor, Kind, MessageDescriptor, ReflectMessage, Value, +}; +use risingwave_common::catalog::Schema; +use risingwave_common::row::Row; +use risingwave_common::types::{DataType, DatumRef, ScalarRefImpl, StructType}; +use risingwave_common::util::iter_util::ZipEqDebug; + +use super::{FieldEncodeError, Result as SinkResult, RowEncoder, SerTo}; + +type Result = std::result::Result; + +pub struct ProtoEncoder<'a> { + schema: &'a Schema, + col_indices: Option<&'a [usize]>, + descriptor: MessageDescriptor, +} + +impl<'a> ProtoEncoder<'a> { + pub fn new( + schema: &'a Schema, + col_indices: Option<&'a [usize]>, + descriptor: MessageDescriptor, + ) -> SinkResult { + match col_indices { + Some(col_indices) => validate_fields( + col_indices.iter().map(|idx| { + let f = &schema[*idx]; + (f.name.as_str(), &f.data_type) + }), + &descriptor, + )?, + None => validate_fields( + schema + .fields + .iter() + .map(|f| (f.name.as_str(), &f.data_type)), + &descriptor, + )?, + }; + + Ok(Self { + schema, + col_indices, + descriptor, + }) + } +} + +impl<'a> RowEncoder for ProtoEncoder<'a> { + type Output = DynamicMessage; + + fn schema(&self) -> &Schema { + self.schema + } + + fn col_indices(&self) -> Option<&[usize]> { + self.col_indices + } + + fn encode_cols( + &self, + row: impl Row, + col_indices: impl Iterator, + ) -> SinkResult { + encode_fields( + col_indices.map(|idx| { + let f = &self.schema[idx]; + ((f.name.as_str(), &f.data_type), row.datum_at(idx)) + }), + &self.descriptor, + ) + .map_err(Into::into) + } +} + +impl SerTo> for DynamicMessage { + fn ser_to(self) -> SinkResult> { + Ok(self.encode_to_vec()) + } +} + +/// A trait that assists code reuse between `validate` and `encode`. +/// * For `validate`, the inputs are (RisingWave type, ProtoBuf type). +/// * For `encode`, the inputs are (RisingWave type, RisingWave data, ProtoBuf type). 
+/// +/// Thus we impl [`MaybeData`] for both [`()`] and [`ScalarRefImpl`]. +trait MaybeData: std::fmt::Debug { + type Out; + + fn on_base(self, f: impl FnOnce(ScalarRefImpl<'_>) -> Result) -> Result; + + fn on_struct(self, st: &StructType, pb: &MessageDescriptor) -> Result; + + fn on_list(self, elem: &DataType, pb: &FieldDescriptor) -> Result; +} + +impl MaybeData for () { + type Out = (); + + fn on_base(self, _: impl FnOnce(ScalarRefImpl<'_>) -> Result) -> Result { + Ok(self) + } + + fn on_struct(self, st: &StructType, pb: &MessageDescriptor) -> Result { + validate_fields(st.iter(), pb) + } + + fn on_list(self, elem: &DataType, pb: &FieldDescriptor) -> Result { + encode_field(elem, (), pb, true) + } +} + +/// Nullability is not part of type system in proto. +/// * Top level is always a message. +/// * All message fields can be omitted in proto3. +/// * All repeated elements must have a value. +/// So we handle [`ScalarRefImpl`] rather than [`DatumRef`] here. +impl MaybeData for ScalarRefImpl<'_> { + type Out = Value; + + fn on_base(self, f: impl FnOnce(ScalarRefImpl<'_>) -> Result) -> Result { + f(self) + } + + fn on_struct(self, st: &StructType, pb: &MessageDescriptor) -> Result { + let d = self.into_struct(); + let message = encode_fields(st.iter().zip_eq_debug(d.iter_fields_ref()), pb)?; + Ok(Value::Message(message)) + } + + fn on_list(self, elem: &DataType, pb: &FieldDescriptor) -> Result { + let d = self.into_list(); + let vs = d + .iter() + .map(|d| { + encode_field( + elem, + d.ok_or_else(|| { + FieldEncodeError::new("array containing null not allowed as repeated field") + })?, + pb, + true, + ) + }) + .try_collect()?; + Ok(Value::List(vs)) + } +} + +fn validate_fields<'a>( + fields: impl Iterator, + descriptor: &MessageDescriptor, +) -> Result<()> { + for (name, t) in fields { + let Some(proto_field) = descriptor.get_field_by_name(name) else { + return Err(FieldEncodeError::new("field not in proto").with_name(name)); + }; + if proto_field.cardinality() == 
prost_reflect::Cardinality::Required { + return Err(FieldEncodeError::new("`required` not supported").with_name(name)); + } + encode_field(t, (), &proto_field, false).map_err(|e| e.with_name(name))?; + } + Ok(()) +} + +fn encode_fields<'a>( + fields_with_datums: impl Iterator)>, + descriptor: &MessageDescriptor, +) -> Result { + let mut message = DynamicMessage::new(descriptor.clone()); + for ((name, t), d) in fields_with_datums { + let proto_field = descriptor.get_field_by_name(name).unwrap(); + // On `null`, simply skip setting the field. + if let Some(scalar) = d { + let value = + encode_field(t, scalar, &proto_field, false).map_err(|e| e.with_name(name))?; + message + .try_set_field(&proto_field, value) + .map_err(|e| FieldEncodeError::new(e).with_name(name))?; + } + } + Ok(message) +} + +// Full name of Well-Known Types +const WKT_TIMESTAMP: &str = "google.protobuf.Timestamp"; +const WKT_BOOL_VALUE: &str = "google.protobuf.BoolValue"; + +/// Handles both `validate` (without actual data) and `encode`. +/// See [`MaybeData`] for more info. +fn encode_field( + data_type: &DataType, + maybe: D, + proto_field: &FieldDescriptor, + in_repeated: bool, +) -> Result { + // Regarding (proto_field.is_list, in_repeated): + // (F, T) => impossible + // (F, F) => encoding to a non-repeated field + // (T, F) => encoding to a repeated field + // (T, T) => encoding to an element of a repeated field + // In the bottom 2 cases, we need to distinguish the same `proto_field` with the help of `in_repeated`. 
+ assert!(proto_field.is_list() || !in_repeated); + let expect_list = proto_field.is_list() && !in_repeated; + if proto_field.is_map() || proto_field.is_group() { + return Err(FieldEncodeError::new( + "proto map or group not supported yet", + )); + } + + let no_match_err = || { + Err(FieldEncodeError::new(format!( + "cannot encode {} column as {}{:?} field", + data_type, + if expect_list { "repeated " } else { "" }, + proto_field.kind() + ))) + }; + + let value = match &data_type { + // Group A: perfect match between RisingWave types and ProtoBuf types + DataType::Boolean => match (expect_list, proto_field.kind()) { + (false, Kind::Bool) => maybe.on_base(|s| Ok(Value::Bool(s.into_bool())))?, + _ => return no_match_err(), + }, + DataType::Varchar => match (expect_list, proto_field.kind()) { + (false, Kind::String) => maybe.on_base(|s| Ok(Value::String(s.into_utf8().into())))?, + _ => return no_match_err(), + }, + DataType::Bytea => match (expect_list, proto_field.kind()) { + (false, Kind::Bytes) => { + maybe.on_base(|s| Ok(Value::Bytes(Bytes::copy_from_slice(s.into_bytea()))))? + } + _ => return no_match_err(), + }, + DataType::Float32 => match (expect_list, proto_field.kind()) { + (false, Kind::Float) => maybe.on_base(|s| Ok(Value::F32(s.into_float32().into())))?, + _ => return no_match_err(), + }, + DataType::Float64 => match (expect_list, proto_field.kind()) { + (false, Kind::Double) => maybe.on_base(|s| Ok(Value::F64(s.into_float64().into())))?, + _ => return no_match_err(), + }, + DataType::Int32 => match (expect_list, proto_field.kind()) { + (false, Kind::Int32 | Kind::Sint32 | Kind::Sfixed32) => { + maybe.on_base(|s| Ok(Value::I32(s.into_int32())))? + } + _ => return no_match_err(), + }, + DataType::Int64 => match (expect_list, proto_field.kind()) { + (false, Kind::Int64 | Kind::Sint64 | Kind::Sfixed64) => { + maybe.on_base(|s| Ok(Value::I64(s.into_int64())))? 
+ } + _ => return no_match_err(), + }, + DataType::Struct(st) => match (expect_list, proto_field.kind()) { + (false, Kind::Message(pb)) => maybe.on_struct(st, &pb)?, + _ => return no_match_err(), + }, + DataType::List(elem) => match expect_list { + true => maybe.on_list(elem, proto_field)?, + false => return no_match_err(), + }, + // Group B: match between RisingWave types and ProtoBuf Well-Known types + DataType::Timestamptz => match (expect_list, proto_field.kind()) { + (false, Kind::Message(pb)) if pb.full_name() == WKT_TIMESTAMP => { + maybe.on_base(|s| { + let d = s.into_timestamptz(); + let message = prost_types::Timestamp { + seconds: d.timestamp(), + nanos: d.timestamp_subsec_nanos().try_into().unwrap(), + }; + Ok(Value::Message(message.transcode_to_dynamic())) + })? + } + _ => return no_match_err(), + }, + DataType::Jsonb => return no_match_err(), // Value, NullValue, Struct (map), ListValue + // Group C: experimental + DataType::Int16 => return no_match_err(), + DataType::Date => return no_match_err(), // google.type.Date + DataType::Time => return no_match_err(), // google.type.TimeOfDay + DataType::Timestamp => return no_match_err(), // google.type.DateTime + DataType::Decimal => return no_match_err(), // google.type.Decimal + DataType::Interval => return no_match_err(), + // Group D: unsupported + DataType::Serial | DataType::Int256 => { + return no_match_err(); + } + }; + + Ok(value) +} + +#[cfg(test)] +mod tests { + use risingwave_common::catalog::Field; + use risingwave_common::row::OwnedRow; + use risingwave_common::types::{ListValue, ScalarImpl, StructValue, Timestamptz}; + + use super::*; + + #[test] + fn test_encode_proto_ok() { + let pool_path = std::path::Path::new(env!("CARGO_MANIFEST_DIR")) + .join("src/test_data/proto_recursive/recursive.pb"); + let pool_bytes = std::fs::read(pool_path).unwrap(); + let pool = prost_reflect::DescriptorPool::decode(pool_bytes.as_ref()).unwrap(); + let descriptor = 
pool.get_message_by_name("recursive.AllTypes").unwrap(); + + let schema = Schema::new(vec![ + Field::with_name(DataType::Boolean, "bool_field"), + Field::with_name(DataType::Varchar, "string_field"), + Field::with_name(DataType::Bytea, "bytes_field"), + Field::with_name(DataType::Float32, "float_field"), + Field::with_name(DataType::Float64, "double_field"), + Field::with_name(DataType::Int32, "int32_field"), + Field::with_name(DataType::Int64, "int64_field"), + Field::with_name(DataType::Int32, "sint32_field"), + Field::with_name(DataType::Int64, "sint64_field"), + Field::with_name(DataType::Int32, "sfixed32_field"), + Field::with_name(DataType::Int64, "sfixed64_field"), + Field::with_name( + DataType::Struct(StructType::new(vec![ + ("id", DataType::Int32), + ("name", DataType::Varchar), + ])), + "nested_message_field", + ), + Field::with_name(DataType::List(DataType::Int32.into()), "repeated_int_field"), + Field::with_name(DataType::Timestamptz, "timestamp_field"), + ]); + let row = OwnedRow::new(vec![ + Some(ScalarImpl::Bool(true)), + Some(ScalarImpl::Utf8("RisingWave".into())), + Some(ScalarImpl::Bytea([0xbe, 0xef].into())), + Some(ScalarImpl::Float32(3.5f32.into())), + Some(ScalarImpl::Float64(4.25f64.into())), + Some(ScalarImpl::Int32(22)), + Some(ScalarImpl::Int64(23)), + Some(ScalarImpl::Int32(24)), + None, + Some(ScalarImpl::Int32(26)), + Some(ScalarImpl::Int64(27)), + Some(ScalarImpl::Struct(StructValue::new(vec![ + Some(ScalarImpl::Int32(1)), + Some(ScalarImpl::Utf8("".into())), + ]))), + Some(ScalarImpl::List(ListValue::new(vec![ + Some(ScalarImpl::Int32(4)), + Some(ScalarImpl::Int32(0)), + Some(ScalarImpl::Int32(4)), + ]))), + Some(ScalarImpl::Timestamptz(Timestamptz::from_micros(3))), + ]); + + let encoder = ProtoEncoder::new(&schema, None, descriptor.clone()).unwrap(); + let m = encoder.encode(row).unwrap(); + let encoded: Vec = m.ser_to().unwrap(); + assert_eq!( + encoded, + // Hint: write the binary output to a file `test.binpb`, and view it with 
`protoc`: + // ``` + // protoc --decode_raw < test.binpb + // protoc --decode=recursive.AllTypes recursive.proto < test.binpb + // ``` + [ + 9, 0, 0, 0, 0, 0, 0, 17, 64, 21, 0, 0, 96, 64, 24, 22, 32, 23, 56, 48, 93, 26, 0, + 0, 0, 97, 27, 0, 0, 0, 0, 0, 0, 0, 104, 1, 114, 10, 82, 105, 115, 105, 110, 103, + 87, 97, 118, 101, 122, 2, 190, 239, 138, 1, 2, 8, 1, 146, 1, 3, 4, 0, 4, 186, 1, 3, + 16, 184, 23 + ] + ); + } + + #[test] + fn test_encode_proto_repeated() { + let pool_path = std::path::Path::new(env!("CARGO_MANIFEST_DIR")) + .join("src/test_data/proto_recursive/recursive.pb"); + let pool_bytes = std::fs::read(pool_path).unwrap(); + let pool = prost_reflect::DescriptorPool::decode(pool_bytes.as_ref()).unwrap(); + let message_descriptor = pool.get_message_by_name("recursive.AllTypes").unwrap(); + + let schema = Schema::new(vec![Field::with_name( + DataType::List(DataType::List(DataType::Int32.into()).into()), + "repeated_int_field", + )]); + + let err = validate_fields( + schema + .fields + .iter() + .map(|f| (f.name.as_str(), &f.data_type)), + &message_descriptor, + ) + .unwrap_err(); + assert_eq!( + err.to_string(), + "encode repeated_int_field error: cannot encode integer[] column as int32 field" + ); + + let schema = Schema::new(vec![Field::with_name( + DataType::List(DataType::Int32.into()), + "repeated_int_field", + )]); + let row = OwnedRow::new(vec![Some(ScalarImpl::List(ListValue::new(vec![ + Some(ScalarImpl::Int32(0)), + None, + Some(ScalarImpl::Int32(2)), + Some(ScalarImpl::Int32(3)), + ])))]); + + let err = encode_fields( + schema + .fields + .iter() + .map(|f| (f.name.as_str(), &f.data_type)) + .zip_eq_debug(row.iter()), + &message_descriptor, + ) + .unwrap_err(); + assert_eq!( + err.to_string(), + "encode repeated_int_field error: array containing null not allowed as repeated field" + ); + } + + #[test] + fn test_encode_proto_err() { + let pool_path = std::path::Path::new(env!("CARGO_MANIFEST_DIR")) + 
.join("src/test_data/proto_recursive/recursive.pb"); + let pool_bytes = std::fs::read(pool_path).unwrap(); + let pool = prost_reflect::DescriptorPool::decode(pool_bytes.as_ref()).unwrap(); + let message_descriptor = pool.get_message_by_name("recursive.AllTypes").unwrap(); + + let err = validate_fields( + std::iter::once(("not_exists", &DataType::Int16)), + &message_descriptor, + ) + .unwrap_err(); + assert_eq!( + err.to_string(), + "encode not_exists error: field not in proto" + ); + + let err = validate_fields( + std::iter::once(("map_field", &DataType::Jsonb)), + &message_descriptor, + ) + .unwrap_err(); + assert_eq!( + err.to_string(), + "encode map_field error: field not in proto" + ); + } +} diff --git a/src/connector/src/sink/mod.rs b/src/connector/src/sink/mod.rs index 31b276994da4a..8c1c84b78bdd9 100644 --- a/src/connector/src/sink/mod.rs +++ b/src/connector/src/sink/mod.rs @@ -363,8 +363,8 @@ pub enum SinkError { Kinesis(anyhow::Error), #[error("Remote sink error: {0}")] Remote(anyhow::Error), - #[error("Json parse error: {0}")] - JsonParse(String), + #[error("Encode error: {0}")] + Encode(String), #[error("Iceberg error: {0}")] Iceberg(anyhow::Error), #[error("config error: {0}")]