From bf55e90a2aaba84697dbacf65f5e4b7c2883e77c Mon Sep 17 00:00:00 2001 From: xxchan Date: Thu, 30 May 2024 20:40:08 +0800 Subject: [PATCH] refactor(source): refactor Access --- src/connector/src/parser/json_parser.rs | 3 +- src/connector/src/parser/plain_parser.rs | 13 +- src/connector/src/parser/protobuf/parser.rs | 3 +- src/connector/src/parser/unified/avro.rs | 129 ++++++++----------- src/connector/src/parser/unified/bytes.rs | 4 +- src/connector/src/parser/unified/debezium.rs | 34 ++--- src/connector/src/parser/unified/json.rs | 115 ++++++++--------- src/connector/src/parser/unified/maxwell.rs | 4 +- src/connector/src/parser/unified/mod.rs | 30 ++++- src/connector/src/parser/unified/protobuf.rs | 2 +- src/connector/src/parser/unified/upsert.rs | 93 +++++++------ src/connector/src/parser/upsert_parser.rs | 4 +- 12 files changed, 219 insertions(+), 215 deletions(-) diff --git a/src/connector/src/parser/json_parser.rs b/src/connector/src/parser/json_parser.rs index ca8a3d1e7b44f..3ce74f00e2993 100644 --- a/src/connector/src/parser/json_parser.rs +++ b/src/connector/src/parser/json_parser.rs @@ -132,8 +132,7 @@ impl JsonParser { let mut errors = Vec::new(); for value in values { let accessor = JsonAccess::new(value); - match writer.insert(|column| accessor.access(&[&column.name], Some(&column.data_type))) - { + match writer.insert(|column| accessor.access(&[&column.name], &column.data_type)) { Ok(_) => {} Err(err) => errors.push(err), } diff --git a/src/connector/src/parser/plain_parser.rs b/src/connector/src/parser/plain_parser.rs index 2241f786cfdd5..fb4c66819ba9d 100644 --- a/src/connector/src/parser/plain_parser.rs +++ b/src/connector/src/parser/plain_parser.rs @@ -15,7 +15,7 @@ use risingwave_common::bail; use super::unified::json::TimestamptzHandling; -use super::unified::ChangeEvent; +use super::unified::upsert::PlainEvent; use super::{ AccessBuilderImpl, ByteStreamSourceParser, EncodingProperties, EncodingType, SourceStreamChunkRowWriter, SpecificParserConfig, @@ -24,7 +24,6 @@ use crate::error::ConnectorResult; use crate::parser::bytes_parser::BytesAccessBuilder; use crate::parser::simd_json_parser::DebeziumJsonAccessBuilder; use crate::parser::unified::debezium::parse_transaction_meta; -use crate::parser::unified::upsert::UpsertChangeEvent; use crate::parser::unified::AccessImpl; use crate::parser::upsert_parser::get_key_column_name; use crate::parser::{BytesProperties, ParseResult, ParserFormat}; @@ -103,22 +102,20 @@ impl PlainParser { }; } - // reuse upsert component but always insert - let mut row_op: UpsertChangeEvent, AccessImpl<'_, '_>> = - UpsertChangeEvent::default(); + let mut row_op: PlainEvent, AccessImpl<'_, '_>> = PlainEvent::default(); if let Some(data) = key && let Some(key_builder) = self.key_builder.as_mut() { // key is optional in format plain - row_op = row_op.with_key(key_builder.generate_accessor(data).await?); + row_op.with_key(key_builder.generate_accessor(data).await?); } if let Some(data) = payload { // the data part also can be an empty vec - row_op = row_op.with_value(self.payload_builder.generate_accessor(data).await?); + row_op.with_value(self.payload_builder.generate_accessor(data).await?); } - writer.insert(|column: &SourceColumnDesc| row_op.access_field(column))?; + writer.insert(|column: &SourceColumnDesc| row_op.access_field_impl(column))?; Ok(ParseResult::Rows) } diff --git a/src/connector/src/parser/protobuf/parser.rs b/src/connector/src/parser/protobuf/parser.rs index 7dcb502b1f674..7638215c0b27a 100644 --- a/src/connector/src/parser/protobuf/parser.rs +++ b/src/connector/src/parser/protobuf/parser.rs @@ -896,7 +896,8 @@ mod test { } fn pb_eq(a: &ProtobufAccess, field_name: &str, value: ScalarImpl) { - let d = a.access(&[field_name], None).unwrap().unwrap(); + let dummy_type = DataType::Varchar; + let d = a.access(&[field_name], &dummy_type).unwrap().unwrap(); assert_eq!(d, value, "field: {} value: {:?}", field_name, d); } diff --git a/src/connector/src/parser/unified/avro.rs b/src/connector/src/parser/unified/avro.rs index 2c94eb47ccfd1..7ed9cad3dfca1 100644 --- a/src/connector/src/parser/unified/avro.rs +++ b/src/connector/src/parser/unified/avro.rs @@ -24,11 +24,11 @@ use risingwave_common::array::{ListValue, StructValue}; use risingwave_common::bail; use risingwave_common::log::LogSuppresser; use risingwave_common::types::{ - DataType, Date, Datum, Interval, JsonbVal, ScalarImpl, Time, Timestamp, Timestamptz, + DataType, Date, Interval, JsonbVal, ScalarImpl, Time, Timestamp, Timestamptz, }; use risingwave_common::util::iter_util::ZipEqFast; -use super::{bail_uncategorized, uncategorized, Access, AccessError, AccessResult}; +use super::{bail_uncategorized, uncategorized, Access, AccessError, AccessResult, NullableAccess}; use crate::error::ConnectorResult; use crate::parser::avro::util::avro_to_jsonb; #[derive(Clone)] @@ -82,7 +82,7 @@ impl<'a> AvroParseOptions<'a> { pub fn convert_to_datum<'b>( &self, value: &'b Value, - type_expected: Option<&'b DataType>, + type_expected: &'b DataType, ) -> AccessResult where 'b: 'a, @@ -104,25 +104,25 @@ impl<'a> AvroParseOptions<'a> { .convert_to_datum(v, type_expected); } // ---- Boolean ----- - (Some(DataType::Boolean) | None, Value::Boolean(b)) => (*b).into(), + (DataType::Boolean, Value::Boolean(b)) => (*b).into(), // ---- Int16 ----- - (Some(DataType::Int16), Value::Int(i)) if self.relax_numeric => (*i as i16).into(), - (Some(DataType::Int16), Value::Long(i)) if self.relax_numeric => (*i as i16).into(), + (DataType::Int16, Value::Int(i)) if self.relax_numeric => (*i as i16).into(), + (DataType::Int16, Value::Long(i)) if self.relax_numeric => (*i as i16).into(), // ---- Int32 ----- - (Some(DataType::Int32) | None, Value::Int(i)) => (*i).into(), - (Some(DataType::Int32), Value::Long(i)) if self.relax_numeric => (*i as i32).into(), + (DataType::Int32, Value::Int(i)) => (*i).into(), + (DataType::Int32, Value::Long(i)) if self.relax_numeric => (*i as i32).into(), // ---- Int64 ----- - (Some(DataType::Int64) | None, Value::Long(i)) => (*i).into(), - (Some(DataType::Int64), Value::Int(i)) if self.relax_numeric => (*i as i64).into(), + (DataType::Int64, Value::Long(i)) => (*i).into(), + (DataType::Int64, Value::Int(i)) if self.relax_numeric => (*i as i64).into(), // ---- Float32 ----- - (Some(DataType::Float32) | None, Value::Float(i)) => (*i).into(), - (Some(DataType::Float32), Value::Double(i)) => (*i as f32).into(), + (DataType::Float32, Value::Float(i)) => (*i).into(), + (DataType::Float32, Value::Double(i)) => (*i as f32).into(), // ---- Float64 ----- - (Some(DataType::Float64) | None, Value::Double(i)) => (*i).into(), - (Some(DataType::Float64), Value::Float(i)) => (*i as f64).into(), + (DataType::Float64, Value::Double(i)) => (*i).into(), + (DataType::Float64, Value::Float(i)) => (*i as f64).into(), // ---- Decimal ----- - (Some(DataType::Decimal) | None, Value::Decimal(avro_decimal)) => { + (DataType::Decimal, Value::Decimal(avro_decimal)) => { let (precision, scale) = match self.schema { Some(Schema::Decimal(DecimalSchema { precision, scale, .. @@ -133,7 +133,7 @@ impl<'a> AvroParseOptions<'a> { .map_err(|_| create_error())?; ScalarImpl::Decimal(risingwave_common::types::Decimal::Normalized(decimal)) } - (Some(DataType::Decimal), Value::Record(fields)) => { + (DataType::Decimal, Value::Record(fields)) => { // VariableScaleDecimal has fixed fields, scale(int) and value(bytes) let find_in_records = |field_name: &str| { fields @@ -167,56 +167,46 @@ impl<'a> AvroParseOptions<'a> { ScalarImpl::Decimal(risingwave_common::types::Decimal::Normalized(decimal)) } // ---- Time ----- - (Some(DataType::Time), Value::TimeMillis(ms)) => Time::with_milli(*ms as u32) + (DataType::Time, Value::TimeMillis(ms)) => Time::with_milli(*ms as u32) .map_err(|_| create_error())? .into(), - (Some(DataType::Time), Value::TimeMicros(us)) => Time::with_micro(*us as u64) + (DataType::Time, Value::TimeMicros(us)) => Time::with_micro(*us as u64) .map_err(|_| create_error())? .into(), // ---- Date ----- - (Some(DataType::Date) | None, Value::Date(days)) => { - Date::with_days(days + unix_epoch_days()) - .map_err(|_| create_error())? - .into() - } + (DataType::Date, Value::Date(days)) => Date::with_days(days + unix_epoch_days()) + .map_err(|_| create_error())? + .into(), // ---- Varchar ----- - (Some(DataType::Varchar) | None, Value::Enum(_, symbol)) => { - symbol.clone().into_boxed_str().into() - } - (Some(DataType::Varchar) | None, Value::String(s)) => s.clone().into_boxed_str().into(), + (DataType::Varchar, Value::Enum(_, symbol)) => symbol.clone().into_boxed_str().into(), + (DataType::Varchar, Value::String(s)) => s.clone().into_boxed_str().into(), // ---- Timestamp ----- - (Some(DataType::Timestamp) | None, Value::LocalTimestampMillis(ms)) => { - Timestamp::with_millis(*ms) - .map_err(|_| create_error())? - .into() - } - (Some(DataType::Timestamp) | None, Value::LocalTimestampMicros(us)) => { - Timestamp::with_micros(*us) - .map_err(|_| create_error())? - .into() - } + (DataType::Timestamp, Value::LocalTimestampMillis(ms)) => Timestamp::with_millis(*ms) + .map_err(|_| create_error())? + .into(), + (DataType::Timestamp, Value::LocalTimestampMicros(us)) => Timestamp::with_micros(*us) + .map_err(|_| create_error())? + .into(), // ---- TimestampTz ----- - (Some(DataType::Timestamptz) | None, Value::TimestampMillis(ms)) => { - Timestamptz::from_millis(*ms) - .ok_or_else(|| { - uncategorized!("timestamptz with milliseconds {ms} * 1000 is out of range") - })? - .into() - } - (Some(DataType::Timestamptz) | None, Value::TimestampMicros(us)) => { + (DataType::Timestamptz, Value::TimestampMillis(ms)) => Timestamptz::from_millis(*ms) + .ok_or_else(|| { + uncategorized!("timestamptz with milliseconds {ms} * 1000 is out of range") + })? + .into(), + (DataType::Timestamptz, Value::TimestampMicros(us)) => { Timestamptz::from_micros(*us).into() } // ---- Interval ----- - (Some(DataType::Interval) | None, Value::Duration(duration)) => { + (DataType::Interval, Value::Duration(duration)) => { let months = u32::from(duration.months()) as i32; let days = u32::from(duration.days()) as i32; let usecs = (u32::from(duration.millis()) as i64) * 1000; // never overflows ScalarImpl::Interval(Interval::from_month_day_usec(months, days, usecs)) } // ---- Struct ----- - (Some(DataType::Struct(struct_type_info)), Value::Record(descs)) => StructValue::new( + (DataType::Struct(struct_type_info), Value::Record(descs)) => StructValue::new( struct_type_info .names() .zip_eq_fast(struct_type_info.types()) @@ -228,7 +218,7 @@ impl<'a> AvroParseOptions<'a> { schema, relax_numeric: self.relax_numeric, } - .convert_to_datum(value, Some(field_type))?) + .convert_to_datum(value, field_type)?) } else { Ok(None) } @@ -236,22 +226,8 @@ impl<'a> AvroParseOptions<'a> { .collect::>()?, ) .into(), - (None, Value::Record(descs)) => { - let rw_values = descs - .iter() - .map(|(field_name, field_value)| { - let schema = self.extract_inner_schema(Some(field_name)); - Self { - schema, - relax_numeric: self.relax_numeric, - } - .convert_to_datum(field_value, None) - }) - .collect::, AccessError>>()?; - ScalarImpl::Struct(StructValue::new(rw_values)) - } // ---- List ----- - (Some(DataType::List(item_type)), Value::Array(array)) => ListValue::new({ + (DataType::List(item_type), Value::Array(array)) => ListValue::new({ let schema = self.extract_inner_schema(None); let mut builder = item_type.create_array_builder(array.len()); for v in array { @@ -259,18 +235,16 @@ impl<'a> AvroParseOptions<'a> { schema, relax_numeric: self.relax_numeric, } - .convert_to_datum(v, Some(item_type))?; + .convert_to_datum(v, item_type)?; builder.append(value); } builder.finish() }) .into(), // ---- Bytea ----- - (Some(DataType::Bytea) | None, Value::Bytes(value)) => { - value.clone().into_boxed_slice().into() - } + (DataType::Bytea, Value::Bytes(value)) => value.clone().into_boxed_slice().into(), // ---- Jsonb ----- - (Some(DataType::Jsonb), v @ Value::Map(_)) => { + (DataType::Jsonb, v @ Value::Map(_)) => { let mut builder = jsonbb::Builder::default(); avro_to_jsonb(v, &mut builder)?; let jsonb = builder.finish(); @@ -299,7 +273,7 @@ impl<'a, 'b> Access for AvroAccess<'a, 'b> where 'a: 'b, { - fn access(&self, path: &[&str], type_expected: Option<&DataType>) -> AccessResult { + fn access(&self, path: &[&str], type_expected: &DataType) -> AccessResult { let mut value = self.value; let mut options: AvroParseOptions<'_> = self.options.clone(); @@ -333,6 +307,15 @@ where } } +impl<'a, 'b> NullableAccess for AvroAccess<'a, 'b> +where + 'a: 'b, +{ + fn is_null(&self) -> bool { + matches!(self.value, Value::Null) + } +} + pub(crate) fn avro_decimal_to_rust_decimal( avro_decimal: AvroDecimal, _precision: usize, @@ -436,7 +419,7 @@ mod tests { use std::str::FromStr; use apache_avro::Decimal as AvroDecimal; - use risingwave_common::types::Decimal; + use risingwave_common::types::{Datum, Decimal}; use super::*; @@ -489,7 +472,7 @@ mod tests { shape: &DataType, ) -> crate::error::ConnectorResult { AvroParseOptions::create(value_schema) - .convert_to_datum(&value, Some(shape)) + .convert_to_datum(&value, shape) .map_err(Into::into) } @@ -532,7 +515,7 @@ mod tests { let value = Value::Decimal(AvroDecimal::from(bytes)); let options = AvroParseOptions::create(&schema); let resp = options - .convert_to_datum(&value, Some(&DataType::Decimal)) + .convert_to_datum(&value, &DataType::Decimal) .unwrap(); assert_eq!( resp, @@ -571,7 +554,7 @@ mod tests { let options = AvroParseOptions::create(&schema); let resp = options - .convert_to_datum(&value, Some(&DataType::Decimal)) + .convert_to_datum(&value, &DataType::Decimal) .unwrap(); assert_eq!(resp, Some(ScalarImpl::Decimal(Decimal::from(66051)))); } diff --git a/src/connector/src/parser/unified/bytes.rs b/src/connector/src/parser/unified/bytes.rs index ff47424d60acf..f9064c3ec3079 100644 --- a/src/connector/src/parser/unified/bytes.rs +++ b/src/connector/src/parser/unified/bytes.rs @@ -31,8 +31,8 @@ impl<'a> BytesAccess<'a> { impl<'a> Access for BytesAccess<'a> { /// path is empty currently, `type_expected` should be `Bytea` - fn access(&self, path: &[&str], type_expected: Option<&DataType>) -> AccessResult { - if let DataType::Bytea = type_expected.unwrap() { + fn access(&self, path: &[&str], type_expected: &DataType) -> AccessResult { + if let DataType::Bytea = type_expected { if self.column_name.is_none() || (path.len() == 1 && self.column_name.as_ref().unwrap() == path[0]) { diff --git a/src/connector/src/parser/unified/debezium.rs b/src/connector/src/parser/unified/debezium.rs index 966c5f167474c..3c415ad96678b 100644 --- a/src/connector/src/parser/unified/debezium.rs +++ b/src/connector/src/parser/unified/debezium.rs @@ -79,8 +79,8 @@ pub fn parse_transaction_meta( connector_props: &ConnectorProperties, ) -> AccessResult { if let (Some(ScalarImpl::Utf8(status)), Some(ScalarImpl::Utf8(id))) = ( - accessor.access(&[TRANSACTION_STATUS], Some(&DataType::Varchar))?, - accessor.access(&[TRANSACTION_ID], Some(&DataType::Varchar))?, + accessor.access(&[TRANSACTION_STATUS], &DataType::Varchar)?, + accessor.access(&[TRANSACTION_ID], &DataType::Varchar)?, ) { // The id field has different meanings for different databases: // PG: txID:LSN @@ -172,16 +172,16 @@ where .key_accessor .as_ref() .expect("key_accessor must be provided for delete operation") - .access(&[&desc.name], Some(&desc.data_type)); + .access(&[&desc.name], &desc.data_type); } if let Some(va) = self.value_accessor.as_ref() { - va.access(&[BEFORE, &desc.name], Some(&desc.data_type)) + va.access(&[BEFORE, &desc.name], &desc.data_type) } else { self.key_accessor .as_ref() .unwrap() - .access(&[&desc.name], Some(&desc.data_type)) + .access(&[&desc.name], &desc.data_type) } } @@ -193,7 +193,7 @@ where self.value_accessor .as_ref() .expect("value_accessor must be provided for upsert operation") - .access(&[AFTER, &desc.name], Some(&desc.data_type)) + .access(&[AFTER, &desc.name], &desc.data_type) }, |additional_column_type| { match additional_column_type { @@ -203,7 +203,7 @@ where .value_accessor .as_ref() .expect("value_accessor must be provided for upsert operation") - .access(&[SOURCE, SOURCE_TS_MS], Some(&DataType::Int64))?; + .access(&[SOURCE, SOURCE_TS_MS], &DataType::Int64)?; Ok(ts_ms.map(|scalar| { Timestamptz::from_millis(scalar.into_int64()) .expect("source.ts_ms must in millisecond") @@ -222,7 +222,7 @@ where fn op(&self) -> Result { if let Some(accessor) = &self.value_accessor { - if let Some(ScalarImpl::Utf8(op)) = accessor.access(&[OP], Some(&DataType::Varchar))? { + if let Some(ScalarImpl::Utf8(op)) = accessor.access(&[OP], &DataType::Varchar)? { match op.as_ref() { DEBEZIUM_READ_OP | DEBEZIUM_CREATE_OP | DEBEZIUM_UPDATE_OP => { return Ok(ChangeEventOperation::Upsert) @@ -309,15 +309,12 @@ impl Access for MongoJsonAccess where A: Access, { - fn access(&self, path: &[&str], type_expected: Option<&DataType>) -> super::AccessResult { + fn access(&self, path: &[&str], type_expected: &DataType) -> super::AccessResult { match path { ["after" | "before", "_id"] => { - let payload = self.access(&[path[0]], Some(&DataType::Jsonb))?; + let payload = self.access(&[path[0]], &DataType::Jsonb)?; if let Some(ScalarImpl::Jsonb(bson_doc)) = payload { - Ok(extract_bson_id( - type_expected.unwrap_or(&DataType::Jsonb), - &bson_doc.take(), - )?) + Ok(extract_bson_id(type_expected, &bson_doc.take())?) } else { // fail to extract the "_id" field from the message payload Err(AccessError::Undefined { @@ -326,19 +323,16 @@ where })? } } - ["after" | "before", "payload"] => self.access(&[path[0]], Some(&DataType::Jsonb)), + ["after" | "before", "payload"] => self.access(&[path[0]], &DataType::Jsonb), // To handle a DELETE message, we need to extract the "_id" field from the message key, because it is not in the payload. // In addition, the "_id" field is named as "id" in the key. An example of message key: // {"schema":null,"payload":{"id":"{\"$oid\": \"65bc9fb6c485f419a7a877fe\"}"}} ["_id"] => { let ret = self.accessor.access(path, type_expected); if matches!(ret, Err(AccessError::Undefined { .. })) { - let id_bson = self.accessor.access(&["id"], Some(&DataType::Jsonb))?; + let id_bson = self.accessor.access(&["id"], &DataType::Jsonb)?; if let Some(ScalarImpl::Jsonb(bson_doc)) = id_bson { - Ok(extract_bson_id( - type_expected.unwrap_or(&DataType::Jsonb), - &bson_doc.take(), - )?) + Ok(extract_bson_id(type_expected, &bson_doc.take())?) } else { // fail to extract the "_id" field from the message key Err(AccessError::Undefined { diff --git a/src/connector/src/parser/unified/json.rs b/src/connector/src/parser/unified/json.rs index 11c569832268e..efec451d62e78 100644 --- a/src/connector/src/parser/unified/json.rs +++ b/src/connector/src/parser/unified/json.rs @@ -31,7 +31,7 @@ use simd_json::prelude::{ use simd_json::{BorrowedValue, ValueType}; use thiserror_ext::AsReport; -use super::{Access, AccessError, AccessResult}; +use super::{Access, AccessError, AccessResult, NullableAccess}; use crate::parser::common::json_object_get_case_insensitive; use crate::parser::unified::avro::extract_decimal; use crate::schema::{bail_invalid_option_error, InvalidOptionError}; @@ -199,11 +199,7 @@ impl JsonParseOptions { } } - pub fn parse( - &self, - value: &BorrowedValue<'_>, - type_expected: Option<&DataType>, - ) -> AccessResult { + pub fn parse(&self, value: &BorrowedValue<'_>, type_expected: &DataType) -> AccessResult { let create_error = || AccessError::TypeError { expected: format!("{:?}", type_expected), got: value.value_type().to_string(), @@ -213,10 +209,10 @@ impl JsonParseOptions { let v: ScalarImpl = match (type_expected, value.value_type()) { (_, ValueType::Null) => return Ok(None), // ---- Boolean ----- - (Some(DataType::Boolean) | None, ValueType::Bool) => value.as_bool().unwrap().into(), + (DataType::Boolean , ValueType::Bool) => value.as_bool().unwrap().into(), ( - Some(DataType::Boolean), + DataType::Boolean, ValueType::I64 | ValueType::I128 | ValueType::U64 | ValueType::U128, ) if matches!(self.boolean_handling, BooleanHandling::Relax { .. }) && matches!(value.as_i64(), Some(0i64) | Some(1i64)) => @@ -224,7 +220,7 @@ impl JsonParseOptions { (value.as_i64() == Some(1i64)).into() } - (Some(DataType::Boolean), ValueType::String) + (DataType::Boolean, ValueType::String) if matches!( self.boolean_handling, BooleanHandling::Relax { @@ -256,11 +252,11 @@ impl JsonParseOptions { } // ---- Int16 ----- ( - Some(DataType::Int16), + DataType::Int16, ValueType::I64 | ValueType::I128 | ValueType::U64 | ValueType::U128, ) => value.try_as_i16().map_err(|_| create_error())?.into(), - (Some(DataType::Int16), ValueType::String) + (DataType::Int16, ValueType::String) if matches!( self.numeric_handling, NumericHandling::Relax { @@ -277,11 +273,11 @@ impl JsonParseOptions { } // ---- Int32 ----- ( - Some(DataType::Int32), + DataType::Int32, ValueType::I64 | ValueType::I128 | ValueType::U64 | ValueType::U128, ) => value.try_as_i32().map_err(|_| create_error())?.into(), - (Some(DataType::Int32), ValueType::String) + (DataType::Int32, ValueType::String) if matches!( self.numeric_handling, NumericHandling::Relax { @@ -297,15 +293,12 @@ impl JsonParseOptions { .into() } // ---- Int64 ----- - (None, ValueType::I64 | ValueType::U64) => { - value.try_as_i64().map_err(|_| create_error())?.into() - } ( - Some(DataType::Int64), + DataType::Int64, ValueType::I64 | ValueType::I128 | ValueType::U64 | ValueType::U128, ) => value.try_as_i64().map_err(|_| create_error())?.into(), - (Some(DataType::Int64), ValueType::String) + (DataType::Int64, ValueType::String) if matches!( self.numeric_handling, NumericHandling::Relax { @@ -322,12 +315,12 @@ impl JsonParseOptions { } // ---- Float32 ----- ( - Some(DataType::Float32), + DataType::Float32, ValueType::I64 | ValueType::I128 | ValueType::U64 | ValueType::U128, ) if matches!(self.numeric_handling, NumericHandling::Relax { .. }) => { (value.try_as_i64().map_err(|_| create_error())? as f32).into() } - (Some(DataType::Float32), ValueType::String) + (DataType::Float32, ValueType::String) if matches!( self.numeric_handling, NumericHandling::Relax { @@ -342,17 +335,17 @@ impl JsonParseOptions { .map_err(|_| create_error())? .into() } - (Some(DataType::Float32), ValueType::F64) => { + (DataType::Float32, ValueType::F64) => { value.try_as_f32().map_err(|_| create_error())?.into() } // ---- Float64 ----- ( - Some(DataType::Float64), + DataType::Float64, ValueType::I64 | ValueType::I128 | ValueType::U64 | ValueType::U128, ) if matches!(self.numeric_handling, NumericHandling::Relax { .. }) => { (value.try_as_i64().map_err(|_| create_error())? as f64).into() } - (Some(DataType::Float64), ValueType::String) + (DataType::Float64, ValueType::String) if matches!( self.numeric_handling, NumericHandling::Relax { @@ -367,25 +360,25 @@ impl JsonParseOptions { .map_err(|_| create_error())? .into() } - (Some(DataType::Float64) | None, ValueType::F64) => { + (DataType::Float64 , ValueType::F64) => { value.try_as_f64().map_err(|_| create_error())?.into() } // ---- Decimal ----- - (Some(DataType::Decimal) | None, ValueType::I128 | ValueType::U128) => { + (DataType::Decimal , ValueType::I128 | ValueType::U128) => { Decimal::from_str(&value.try_as_i128().map_err(|_| create_error())?.to_string()) .map_err(|_| create_error())? .into() } - (Some(DataType::Decimal), ValueType::I64 | ValueType::U64) => { + (DataType::Decimal, ValueType::I64 | ValueType::U64) => { Decimal::from(value.try_as_i64().map_err(|_| create_error())?).into() } - (Some(DataType::Decimal), ValueType::F64) => { + (DataType::Decimal, ValueType::F64) => { Decimal::try_from(value.try_as_f64().map_err(|_| create_error())?) .map_err(|_| create_error())? .into() } - (Some(DataType::Decimal), ValueType::String) => { + (DataType::Decimal, ValueType::String) => { let str = value.as_str().unwrap(); // the following values are special string generated by Debezium and should be handled separately match str { @@ -395,7 +388,7 @@ impl JsonParseOptions { _ => ScalarImpl::Decimal(Decimal::from_str(str).map_err(|_err| create_error())?), } } - (Some(DataType::Decimal), ValueType::Object) => { + (DataType::Decimal, ValueType::Object) => { // ref https://github.com/risingwavelabs/risingwave/issues/10628 // handle debezium json (variable scale): {"scale": int, "value": bytes} let scale = value @@ -418,21 +411,21 @@ impl JsonParseOptions { } // ---- Date ----- ( - Some(DataType::Date), + DataType::Date, ValueType::I64 | ValueType::I128 | ValueType::U64 | ValueType::U128, ) => Date::with_days_since_unix_epoch(value.try_as_i32().map_err(|_| create_error())?) .map_err(|_| create_error())? .into(), - (Some(DataType::Date), ValueType::String) => value + (DataType::Date, ValueType::String) => value .as_str() .unwrap() .parse::() .map_err(|_| create_error())? .into(), // ---- Varchar ----- - (Some(DataType::Varchar) | None, ValueType::String) => value.as_str().unwrap().into(), + (DataType::Varchar , ValueType::String) => value.as_str().unwrap().into(), ( - Some(DataType::Varchar), + DataType::Varchar, ValueType::Bool | ValueType::I64 | ValueType::I128 @@ -443,7 +436,7 @@ impl JsonParseOptions { value.to_string().into() } ( - Some(DataType::Varchar), + DataType::Varchar, ValueType::Bool | ValueType::I64 | ValueType::I128 @@ -456,14 +449,14 @@ impl JsonParseOptions { value.to_string().into() } // ---- Time ----- - (Some(DataType::Time), ValueType::String) => value + (DataType::Time, ValueType::String) => value .as_str() .unwrap() .parse:: ChangeEvent for (ChangeEventOperation, A) where A: Access, { - fn op(&self) -> std::result::Result { + fn op(&self) -> AccessResult { Ok(self.0) } fn access_field(&self, desc: &SourceColumnDesc) -> AccessResult { - self.1.access(&[desc.name.as_str()], Some(&desc.data_type)) + self.1.access(&[desc.name.as_str()], &desc.data_type) } } diff --git a/src/connector/src/parser/unified/protobuf.rs b/src/connector/src/parser/unified/protobuf.rs index cd9178c7dd08d..4455dea6a905d 100644 --- a/src/connector/src/parser/unified/protobuf.rs +++ b/src/connector/src/parser/unified/protobuf.rs @@ -38,7 +38,7 @@ impl ProtobufAccess { } impl Access for ProtobufAccess { - fn access(&self, path: &[&str], _type_expected: Option<&DataType>) -> AccessResult { + fn access(&self, path: &[&str], _type_expected: &DataType) -> AccessResult { debug_assert_eq!(1, path.len()); let field_desc = self .message diff --git a/src/connector/src/parser/unified/upsert.rs b/src/connector/src/parser/unified/upsert.rs index dbef878fc2685..50080d67020a0 100644 --- a/src/connector/src/parser/unified/upsert.rs +++ b/src/connector/src/parser/unified/upsert.rs @@ -12,60 +12,51 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::ops::{Deref, DerefMut}; + use risingwave_common::types::DataType; use risingwave_pb::plan_common::additional_column::ColumnType as AdditionalColumnType; -use super::{Access, ChangeEvent, ChangeEventOperation}; +use super::{Access, ChangeEvent, ChangeEventOperation, NullableAccess}; use crate::parser::unified::AccessError; use crate::source::SourceColumnDesc; -/// `UpsertAccess` wraps a key-value message format into an upsert source. -/// A key accessor and a value accessor are required. -pub struct UpsertChangeEvent { +pub struct KvEvent { key_accessor: Option, value_accessor: Option, - key_column_name: Option, } -impl Default for UpsertChangeEvent { +impl Default for KvEvent { fn default() -> Self { Self { key_accessor: None, value_accessor: None, - key_column_name: None, } } } -impl UpsertChangeEvent { - pub fn with_key(mut self, key: K) -> Self +impl KvEvent { + pub fn with_key(&mut self, key: K) where K: Access, { self.key_accessor = Some(key); - self } - pub fn with_value(mut self, value: V) -> Self + pub fn with_value(&mut self, value: V) where V: Access, { self.value_accessor = Some(value); - self - } - - pub fn with_key_column_name(mut self, name: impl ToString) -> Self { - self.key_column_name = Some(name.to_string()); - self } } -impl UpsertChangeEvent +impl KvEvent where K: Access, V: Access, { - fn access_key(&self, path: &[&str], type_expected: Option<&DataType>) -> super::AccessResult { + fn access_key(&self, path: &[&str], type_expected: &DataType) -> super::AccessResult { if let Some(ka) = &self.key_accessor { ka.access(path, type_expected) } else { @@ -76,7 +67,7 @@ where } } - fn access_value(&self, path: &[&str], type_expected: Option<&DataType>) -> super::AccessResult { + fn access_value(&self, path: &[&str], type_expected: &DataType) -> super::AccessResult { if let Some(va) = &self.value_accessor { va.access(path, type_expected) } else { @@ -86,34 +77,62 @@ where }) } } + + pub fn access_field_impl(&self, desc: &SourceColumnDesc) -> super::AccessResult { + match desc.additional_column.column_type { + Some(AdditionalColumnType::Key(_)) => self.access_key(&[&desc.name], &desc.data_type), + None => self.access_value(&[&desc.name], &desc.data_type), + _ => unreachable!(), + } + } +} + +/// Wraps a key-value message into an upsert event, which uses `null` value to represent `DELETE`s. +pub struct UpsertChangeEvent(KvEvent); + +impl Default for UpsertChangeEvent { + fn default() -> Self { + Self(KvEvent::default()) + } +} + +impl Deref for UpsertChangeEvent { + type Target = KvEvent; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl DerefMut for UpsertChangeEvent { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } } impl ChangeEvent for UpsertChangeEvent where K: Access, - V: Access, + V: NullableAccess, { fn op(&self) -> std::result::Result { - if let Ok(Some(_)) = self.access_value(&[], None) { - Ok(ChangeEventOperation::Upsert) + if let Some(va) = &self.0.value_accessor { + if va.is_null() { + Ok(ChangeEventOperation::Delete) + } else { + Ok(ChangeEventOperation::Upsert) + } } else { - Ok(ChangeEventOperation::Delete) + Err(AccessError::Undefined { + name: "value".to_string(), + path: String::new(), + }) } } fn access_field(&self, desc: &SourceColumnDesc) -> super::AccessResult { - match desc.additional_column.column_type { - Some(AdditionalColumnType::Key(_)) => { - if let Some(key_as_column_name) = &self.key_column_name - && &desc.name == key_as_column_name - { - self.access_key(&[], Some(&desc.data_type)) - } else { - self.access_key(&[&desc.name], Some(&desc.data_type)) - } - } - None => self.access_value(&[&desc.name], Some(&desc.data_type)), - _ => unreachable!(), - } + self.0.access_field_impl(desc) } } + +pub type PlainEvent = KvEvent; diff --git a/src/connector/src/parser/upsert_parser.rs b/src/connector/src/parser/upsert_parser.rs index 048fd0beca3ff..611a881babe4a 100644 --- a/src/connector/src/parser/upsert_parser.rs +++ b/src/connector/src/parser/upsert_parser.rs @@ -101,13 +101,13 @@ impl UpsertParser { UpsertChangeEvent::default(); let mut change_event_op = ChangeEventOperation::Delete; if let Some(data) = key { - row_op = row_op.with_key(self.key_builder.generate_accessor(data).await?); + row_op.with_key(self.key_builder.generate_accessor(data).await?); } // Empty payload of kafka is Some(vec![]) if let Some(data) = payload && !data.is_empty() { - row_op = row_op.with_value(self.payload_builder.generate_accessor(data).await?); + row_op.with_value(self.payload_builder.generate_accessor(data).await?); change_event_op = ChangeEventOperation::Upsert; }