diff --git a/src/connector/src/sink/encoder/avro.rs b/src/connector/src/sink/encoder/avro.rs new file mode 100644 index 0000000000000..923fdde2f5368 --- /dev/null +++ b/src/connector/src/sink/encoder/avro.rs @@ -0,0 +1,327 @@ +// Copyright 2023 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use apache_avro::schema::Schema as AvroSchema; +use apache_avro::types::{Record, Value}; +use apache_avro::Writer; +use risingwave_common::array::{ArrayError, ArrayResult}; +use risingwave_common::catalog::Schema; +use risingwave_common::row::Row; +use risingwave_common::types::{DataType, DatumRef, ScalarRefImpl}; +use risingwave_common::util::iter_util::ZipEqDebug; + +use super::{Result, RowEncoder, SerTo}; +use crate::sink::SinkError; + +pub struct AvroEncoder<'a> { + schema: &'a Schema, + col_indices: Option<&'a [usize]>, + avro_schema: &'a AvroSchema, +} + +impl<'a> AvroEncoder<'a> { + pub fn new( + schema: &'a Schema, + col_indices: Option<&'a [usize]>, + avro_schema: &'a AvroSchema, + ) -> Result { + let AvroSchema::Record { fields, lookup, .. } = avro_schema else { + return Err(SinkError::JsonParse(format!( + "not an avro record: {:?}", + avro_schema + ))); + }; + for idx in col_indices.unwrap() { + let f = &schema[*idx]; + let Some(expected) = lookup.get(&f.name).map(|i| &fields[*i]) else { + return Err(SinkError::JsonParse(format!( + "field {} not in avro", + f.name, + ))); + }; + if !is_valid(&f.data_type, &expected.schema) { + return Err(SinkError::JsonParse(format!( + "field {}:{} cannot output as avro {:?}", + f.name, f.data_type, expected + ))); + } + } + + Ok(Self { + schema, + col_indices, + avro_schema, + }) + } +} + +impl<'a> RowEncoder for AvroEncoder<'a> { + type Output = (Record<'a>, &'a AvroSchema); + + fn schema(&self) -> &Schema { + self.schema + } + + fn col_indices(&self) -> Option<&[usize]> { + self.col_indices + } + + fn encode_cols( + &self, + row: impl Row, + col_indices: impl Iterator, + ) -> Result { + let record = row_to_record( + col_indices.map(|idx| { + let f = &self.schema[idx]; + ((f.name.as_str(), &f.data_type), row.datum_at(idx)) + }), + &self.avro_schema, + ) + .map_err(|e| SinkError::JsonParse(e.to_string()))?; + Ok((record, self.avro_schema)) + } +} + +impl<'a> SerTo> for (Record<'a>, &'a AvroSchema) { + fn ser_to(self) -> Result> { + let mut w = Writer::new(self.1, Vec::new()); + w.append(self.0).unwrap(); + Ok(w.into_inner().unwrap()) + } +} + +fn is_valid(_data_type: &DataType, _expected: &AvroSchema) -> bool { + false +} + +fn row_to_record<'avro, 'rw>( + fields_with_datums: impl Iterator)>, + schema: &'avro AvroSchema, +) -> ArrayResult> { + let mut record = Record::new(schema).unwrap(); + let AvroSchema::Record { fields, lookup, .. } = schema else { + unreachable!() + }; + for ((name, t), d) in fields_with_datums { + let expected = &fields[lookup[name]]; + if let Some(scalar) = d { + let value = scalar_to_avro(name, t, scalar, &expected.schema)?; + record.put(name, value); + } + } + Ok(record) +} + +fn scalar_to_avro( + name: &str, + data_type: &DataType, + scalar_ref: ScalarRefImpl<'_>, + expected: &AvroSchema, +) -> ArrayResult { + tracing::debug!("scalar_to_avro: {:?}, {:?}", data_type, scalar_ref); + + let err = || { + Err(ArrayError::internal( + format!("scalar_to_avro: unsupported data type: field name: {:?}, logical type: {:?}, physical type: {:?}", name, data_type, scalar_ref), + )) + }; + + let value = match &data_type { + // Group A: perfect match between RisingWave types and Avro types + DataType::Boolean => match expected { + AvroSchema::Boolean => Value::Boolean(scalar_ref.into_bool()), + _ => return err(), + }, + DataType::Varchar => match expected { + AvroSchema::String => Value::String(scalar_ref.into_utf8().into()), + _ => return err(), + }, + DataType::Bytea => match expected { + AvroSchema::Bytes => Value::Bytes(scalar_ref.into_bytea().into()), + _ => return err(), + }, + DataType::Float32 => match expected { + AvroSchema::Float => Value::Float(scalar_ref.into_float32().into()), + _ => return err(), + }, + DataType::Float64 => match expected { + AvroSchema::Double => Value::Double(scalar_ref.into_float64().into()), + _ => return err(), + }, + DataType::Int32 => match expected { + AvroSchema::Int => Value::Int(scalar_ref.into_int32()), + _ => return err(), + }, + DataType::Int64 => match expected { + AvroSchema::Long => Value::Long(scalar_ref.into_int64()), + _ => return err(), + }, + DataType::Struct(t_rw) => match expected { + AvroSchema::Record { .. } => { + let d = scalar_ref.into_struct(); + let record = + row_to_record(t_rw.iter().zip_eq_debug(d.iter_fields_ref()), expected)?; + record.into() + } + _ => return err(), + }, + DataType::List(t_rw) => match expected { + AvroSchema::Array(elem) => { + let d = scalar_ref.into_list(); + let vs = d + .iter() + .map(|d| scalar_to_avro(name, t_rw, d.unwrap(), elem).unwrap()) + .collect(); + Value::Array(vs) + } + _ => return err(), + }, + // Group B: match between RisingWave types and Avro logical types + DataType::Timestamptz => match expected { + AvroSchema::TimestampMicros => { + Value::TimestampMicros(scalar_ref.into_timestamptz().timestamp_micros()) + } + AvroSchema::TimestampMillis => { + Value::TimestampMillis(scalar_ref.into_timestamptz().timestamp_millis()) + } + _ => return err(), + }, + DataType::Timestamp => todo!(), + DataType::Date => todo!(), + DataType::Time => todo!(), + DataType::Interval => match expected { + AvroSchema::Duration => { + use apache_avro::{Days, Duration, Millis, Months}; + let iv = scalar_ref.into_interval(); + Value::Duration(Duration::new( + Months::new(iv.months().try_into().unwrap()), + Days::new(iv.days().try_into().unwrap()), + Millis::new((iv.usecs() / 1000).try_into().unwrap()), + )) + } + _ => return err(), + }, + // Group C: experimental + DataType::Int16 => todo!(), + DataType::Decimal => todo!(), + DataType::Jsonb => todo!(), + // Group D: unsupported + DataType::Serial | DataType::Int256 => { + return err(); + } + }; + + Ok(value) +} + +#[cfg(test)] +mod tests { + + use risingwave_common::types::{DataType, Interval, ScalarImpl, Time, Timestamp}; + + use super::*; + + fn mock_avro() -> AvroSchema { + todo!() + } + + fn any_value(_: impl std::fmt::Display) -> Value { + todo!() + } + + #[test] + fn test_to_avro_basic_type() { + let expected = mock_avro(); + let boolean_value = scalar_to_avro( + "", + &DataType::Boolean, + ScalarImpl::Bool(false).as_scalar_ref_impl(), + &expected, + ) + .unwrap(); + assert_eq!(boolean_value, Value::Boolean(false)); + + let int32_value = scalar_to_avro( + "", + &DataType::Int32, + ScalarImpl::Int32(16).as_scalar_ref_impl(), + &expected, + ) + .unwrap(); + assert_eq!(int32_value, Value::Int(16)); + + let int64_value = scalar_to_avro( + "", + &DataType::Int64, + ScalarImpl::Int64(std::i64::MAX).as_scalar_ref_impl(), + &expected, + ) + .unwrap(); + assert_eq!(int64_value, Value::Long(i64::MAX)); + + // https://github.com/debezium/debezium/blob/main/debezium-core/src/main/java/io/debezium/time/ZonedTimestamp.java + let tstz_inner = "2018-01-26T18:30:09.453Z".parse().unwrap(); + let tstz_value = scalar_to_avro( + "", + &DataType::Timestamptz, + ScalarImpl::Timestamptz(tstz_inner).as_scalar_ref_impl(), + &expected, + ) + .unwrap(); + assert_eq!(tstz_value, any_value("2018-01-26 18:30:09.453000")); + + let ts_value = scalar_to_avro( + "", + &DataType::Timestamp, + ScalarImpl::Timestamp(Timestamp::from_timestamp_uncheck(1000, 0)).as_scalar_ref_impl(), + &expected, + ) + .unwrap(); + assert_eq!(ts_value, any_value(1000 * 1000)); + + let ts_value = scalar_to_avro( + "", + &DataType::Timestamp, + ScalarImpl::Timestamp(Timestamp::from_timestamp_uncheck(1000, 0)).as_scalar_ref_impl(), + &expected, + ) + .unwrap(); + assert_eq!( + ts_value, + any_value("1970-01-01 00:16:40.000000".to_string()) + ); + + // Represents the number of microseconds past midnigh, io.debezium.time.Time + let time_value = scalar_to_avro( + "", + &DataType::Time, + ScalarImpl::Time(Time::from_num_seconds_from_midnight_uncheck(1000, 0)) + .as_scalar_ref_impl(), + &expected, + ) + .unwrap(); + assert_eq!(time_value, any_value(1000 * 1000)); + + let interval_value = scalar_to_avro( + "", + &DataType::Interval, + ScalarImpl::Interval(Interval::from_month_day_usec(13, 2, 1000000)) + .as_scalar_ref_impl(), + &expected, + ) + .unwrap(); + assert_eq!(interval_value, any_value("P1Y1M2DT0H0M1S")); + } +} diff --git a/src/connector/src/sink/encoder/mod.rs b/src/connector/src/sink/encoder/mod.rs index 1807fd1d421e8..d1beb5fa402b7 100644 --- a/src/connector/src/sink/encoder/mod.rs +++ b/src/connector/src/sink/encoder/mod.rs @@ -19,9 +19,13 @@ use risingwave_common::row::Row; use crate::sink::Result; +mod avro; mod json; +mod proto; +pub use avro::AvroEncoder; pub use json::JsonEncoder; +pub use proto::ProtoEncoder; /// Encode a row of a relation into /// * an object in json diff --git a/src/connector/src/sink/encoder/proto.rs b/src/connector/src/sink/encoder/proto.rs new file mode 100644 index 0000000000000..41df88251cbe9 --- /dev/null +++ b/src/connector/src/sink/encoder/proto.rs @@ -0,0 +1,298 @@ +// Copyright 2023 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use bytes::Bytes; +use prost::Message; +use prost_reflect::{DynamicMessage, FieldDescriptor, Kind, MessageDescriptor, Value}; +use risingwave_common::array::{ArrayError, ArrayResult}; +use risingwave_common::catalog::Schema; +use risingwave_common::row::Row; +use risingwave_common::types::{DataType, DatumRef, ScalarRefImpl}; +use risingwave_common::util::iter_util::ZipEqDebug; + +use super::{Result, RowEncoder, SerTo}; +use crate::sink::SinkError; + +pub struct ProtoEncoder<'a> { + schema: &'a Schema, + col_indices: Option<&'a [usize]>, + descriptor: MessageDescriptor, +} + +impl<'a> ProtoEncoder<'a> { + pub fn new( + schema: &'a Schema, + col_indices: Option<&'a [usize]>, + descriptor: MessageDescriptor, + ) -> Result { + for idx in col_indices.unwrap() { + let f = &schema[*idx]; + let Some(expected) = descriptor.get_field_by_name(&f.name) else { + return Err(SinkError::JsonParse(format!( + "field {} not in proto", + f.name, + ))); + }; + if !is_valid(&f.data_type, &expected) { + return Err(SinkError::JsonParse(format!( + "field {}:{} cannot output as proto {:?}", + f.name, f.data_type, expected + ))); + } + } + + Ok(Self { + schema, + col_indices, + descriptor, + }) + } +} + +impl<'a> RowEncoder for ProtoEncoder<'a> { + type Output = DynamicMessage; + + fn schema(&self) -> &Schema { + self.schema + } + + fn col_indices(&self) -> Option<&[usize]> { + self.col_indices + } + + fn encode_cols( + &self, + row: impl Row, + col_indices: impl Iterator, + ) -> Result { + row_to_message( + col_indices.map(|idx| { + let f = &self.schema[idx]; + ((f.name.as_str(), &f.data_type), row.datum_at(idx)) + }), + &self.descriptor, + ) + .map_err(|e| SinkError::JsonParse(e.to_string())) + } +} + +impl SerTo> for DynamicMessage { + fn ser_to(self) -> Result> { + Ok(self.encode_to_vec()) + } +} + +fn is_valid(_data_type: &DataType, _expected: &FieldDescriptor) -> bool { + false +} + +fn row_to_message<'a>( + fields_with_datums: impl Iterator)>, + descriptor: &MessageDescriptor, +) -> ArrayResult { + let mut message = DynamicMessage::new(descriptor.clone()); + for ((name, t), d) in fields_with_datums { + let expected = descriptor.get_field_by_name(name).unwrap(); + if let Some(scalar) = d { + let value = scalar_to_proto(name, t, scalar, &expected)?; + message.set_field(&expected, value); + } + } + Ok(message) +} + +fn scalar_to_proto( + name: &str, + data_type: &DataType, + scalar_ref: ScalarRefImpl<'_>, + expected: &FieldDescriptor, +) -> ArrayResult { + tracing::debug!("scalar_to_proto: {:?}, {:?}", data_type, scalar_ref); + + let err = || { + Err(ArrayError::internal( + format!("scalar_to_proto: unsupported data type: field name: {:?}, logical type: {:?}, physical type: {:?}", name, data_type, scalar_ref), + )) + }; + + let value = match &data_type { + // Group A: perfect match between RisingWave types and ProtoBuf types + DataType::Boolean => match expected.kind() { + Kind::Bool if !expected.is_list() => Value::Bool(scalar_ref.into_bool()), + _ => return err(), + }, + DataType::Varchar => match expected.kind() { + Kind::String if !expected.is_list() => Value::String(scalar_ref.into_utf8().into()), + _ => return err(), + }, + DataType::Bytea => match expected.kind() { + Kind::Bytes if !expected.is_list() => { + Value::Bytes(Bytes::copy_from_slice(scalar_ref.into_bytea())) + } + _ => return err(), + }, + DataType::Float32 => match expected.kind() { + Kind::Float if !expected.is_list() => Value::F32(scalar_ref.into_float32().into()), + _ => return err(), + }, + DataType::Float64 => match expected.kind() { + Kind::Double if !expected.is_list() => Value::F64(scalar_ref.into_float64().into()), + _ => return err(), + }, + DataType::Int32 => match expected.kind() { + Kind::Int32 | Kind::Sint32 | Kind::Sfixed32 if !expected.is_list() => { + Value::I32(scalar_ref.into_int32()) + } + _ => return err(), + }, + DataType::Int64 => match expected.kind() { + Kind::Int64 | Kind::Sint64 | Kind::Sfixed64 if !expected.is_list() => { + Value::I64(scalar_ref.into_int64()) + } + _ => return err(), + }, + DataType::Struct(t_rw) => match expected.kind() { + Kind::Message(t_pb) if !expected.is_list() => { + let d = scalar_ref.into_struct(); + let message = row_to_message(t_rw.iter().zip_eq_debug(d.iter_fields_ref()), &t_pb)?; + Value::Message(message) + } + _ => return err(), + }, + DataType::List(t_rw) => match expected.is_list() { + true => { + let d = scalar_ref.into_list(); + let vs = d + .iter() + .map(|d| scalar_to_proto(name, t_rw, d.unwrap(), expected).unwrap()) + .collect(); + Value::List(vs) + } + false => return err(), + }, + // Group B: match between RisingWave types and ProtoBuf Well-Known types + DataType::Timestamptz => todo!(), + DataType::Jsonb => todo!(), + // Group C: experimental + DataType::Int16 => todo!(), + DataType::Date | DataType::Timestamp | DataType::Time | DataType::Decimal => todo!(), + DataType::Interval => todo!(), + // Group D: unsupported + DataType::Serial | DataType::Int256 => { + return err(); + } + }; + + Ok(value) +} + +#[cfg(test)] +mod tests { + + use risingwave_common::types::{DataType, Interval, ScalarImpl, Time, Timestamp}; + + use super::*; + + fn mock_pb() -> FieldDescriptor { + todo!() + } + + fn any_value(_: impl std::fmt::Display) -> Value { + todo!() + } + + #[test] + fn test_to_proto_basic_type() { + let expected = mock_pb(); + let boolean_value = scalar_to_proto( + "", + &DataType::Boolean, + ScalarImpl::Bool(false).as_scalar_ref_impl(), + &expected, + ) + .unwrap(); + assert_eq!(boolean_value, Value::Bool(false)); + + let int32_value = scalar_to_proto( + "", + &DataType::Int32, + ScalarImpl::Int32(16).as_scalar_ref_impl(), + &expected, + ) + .unwrap(); + assert_eq!(int32_value, Value::I32(16)); + + let int64_value = scalar_to_proto( + "", + &DataType::Int64, + ScalarImpl::Int64(std::i64::MAX).as_scalar_ref_impl(), + &expected, + ) + .unwrap(); + assert_eq!(int64_value, Value::I64(i64::MAX)); + + // https://github.com/debezium/debezium/blob/main/debezium-core/src/main/java/io/debezium/time/ZonedTimestamp.java + let tstz_inner = "2018-01-26T18:30:09.453Z".parse().unwrap(); + let tstz_value = scalar_to_proto( + "", + &DataType::Timestamptz, + ScalarImpl::Timestamptz(tstz_inner).as_scalar_ref_impl(), + &expected, + ) + .unwrap(); + assert_eq!(tstz_value, any_value("2018-01-26 18:30:09.453000")); + + let ts_value = scalar_to_proto( + "", + &DataType::Timestamp, + ScalarImpl::Timestamp(Timestamp::from_timestamp_uncheck(1000, 0)).as_scalar_ref_impl(), + &expected, + ) + .unwrap(); + assert_eq!(ts_value, any_value(1000 * 1000)); + + let ts_value = scalar_to_proto( + "", + &DataType::Timestamp, + ScalarImpl::Timestamp(Timestamp::from_timestamp_uncheck(1000, 0)).as_scalar_ref_impl(), + &expected, + ) + .unwrap(); + assert_eq!( + ts_value, + any_value("1970-01-01 00:16:40.000000".to_string()) + ); + + // Represents the number of microseconds past midnigh, io.debezium.time.Time + let time_value = scalar_to_proto( + "", + &DataType::Time, + ScalarImpl::Time(Time::from_num_seconds_from_midnight_uncheck(1000, 0)) + .as_scalar_ref_impl(), + &expected, + ) + .unwrap(); + assert_eq!(time_value, any_value(1000 * 1000)); + + let interval_value = scalar_to_proto( + "", + &DataType::Interval, + ScalarImpl::Interval(Interval::from_month_day_usec(13, 2, 1000000)) + .as_scalar_ref_impl(), + &expected, + ) + .unwrap(); + assert_eq!(interval_value, any_value("P1Y1M2DT0H0M1S")); + } +}