Skip to content

Commit

Permalink
refactor(source): refactor Access
Browse files Browse the repository at this point in the history
  • Loading branch information
xxchan committed May 30, 2024
1 parent 8d4f0a1 commit bf55e90
Show file tree
Hide file tree
Showing 12 changed files with 219 additions and 215 deletions.
3 changes: 1 addition & 2 deletions src/connector/src/parser/json_parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,8 +132,7 @@ impl JsonParser {
let mut errors = Vec::new();
for value in values {
let accessor = JsonAccess::new(value);
match writer.insert(|column| accessor.access(&[&column.name], Some(&column.data_type)))
{
match writer.insert(|column| accessor.access(&[&column.name], &column.data_type)) {
Ok(_) => {}
Err(err) => errors.push(err),
}
Expand Down
13 changes: 5 additions & 8 deletions src/connector/src/parser/plain_parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
use risingwave_common::bail;

use super::unified::json::TimestamptzHandling;
use super::unified::ChangeEvent;
use super::unified::upsert::PlainEvent;
use super::{
AccessBuilderImpl, ByteStreamSourceParser, EncodingProperties, EncodingType,
SourceStreamChunkRowWriter, SpecificParserConfig,
Expand All @@ -24,7 +24,6 @@ use crate::error::ConnectorResult;
use crate::parser::bytes_parser::BytesAccessBuilder;
use crate::parser::simd_json_parser::DebeziumJsonAccessBuilder;
use crate::parser::unified::debezium::parse_transaction_meta;
use crate::parser::unified::upsert::UpsertChangeEvent;
use crate::parser::unified::AccessImpl;
use crate::parser::upsert_parser::get_key_column_name;
use crate::parser::{BytesProperties, ParseResult, ParserFormat};
Expand Down Expand Up @@ -103,22 +102,20 @@ impl PlainParser {
};
}

// reuse upsert component but always insert
let mut row_op: UpsertChangeEvent<AccessImpl<'_, '_>, AccessImpl<'_, '_>> =
UpsertChangeEvent::default();
let mut row_op: PlainEvent<AccessImpl<'_, '_>, AccessImpl<'_, '_>> = PlainEvent::default();

if let Some(data) = key
&& let Some(key_builder) = self.key_builder.as_mut()
{
// key is optional in format plain
row_op = row_op.with_key(key_builder.generate_accessor(data).await?);
row_op.with_key(key_builder.generate_accessor(data).await?);
}
if let Some(data) = payload {
// the data part also can be an empty vec
row_op = row_op.with_value(self.payload_builder.generate_accessor(data).await?);
row_op.with_value(self.payload_builder.generate_accessor(data).await?);
}

writer.insert(|column: &SourceColumnDesc| row_op.access_field(column))?;
writer.insert(|column: &SourceColumnDesc| row_op.access_field_impl(column))?;

Ok(ParseResult::Rows)
}
Expand Down
3 changes: 2 additions & 1 deletion src/connector/src/parser/protobuf/parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -896,7 +896,8 @@ mod test {
}

fn pb_eq(a: &ProtobufAccess, field_name: &str, value: ScalarImpl) {
let d = a.access(&[field_name], None).unwrap().unwrap();
let dummy_type = DataType::Varchar;
let d = a.access(&[field_name], &dummy_type).unwrap().unwrap();
assert_eq!(d, value, "field: {} value: {:?}", field_name, d);
}

Expand Down
129 changes: 56 additions & 73 deletions src/connector/src/parser/unified/avro.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,11 @@ use risingwave_common::array::{ListValue, StructValue};
use risingwave_common::bail;
use risingwave_common::log::LogSuppresser;
use risingwave_common::types::{
DataType, Date, Datum, Interval, JsonbVal, ScalarImpl, Time, Timestamp, Timestamptz,
DataType, Date, Interval, JsonbVal, ScalarImpl, Time, Timestamp, Timestamptz,
};
use risingwave_common::util::iter_util::ZipEqFast;

use super::{bail_uncategorized, uncategorized, Access, AccessError, AccessResult};
use super::{bail_uncategorized, uncategorized, Access, AccessError, AccessResult, NullableAccess};
use crate::error::ConnectorResult;
use crate::parser::avro::util::avro_to_jsonb;
#[derive(Clone)]
Expand Down Expand Up @@ -82,7 +82,7 @@ impl<'a> AvroParseOptions<'a> {
pub fn convert_to_datum<'b>(
&self,
value: &'b Value,
type_expected: Option<&'b DataType>,
type_expected: &'b DataType,
) -> AccessResult
where
'b: 'a,
Expand All @@ -104,25 +104,25 @@ impl<'a> AvroParseOptions<'a> {
.convert_to_datum(v, type_expected);
}
// ---- Boolean -----
(Some(DataType::Boolean) | None, Value::Boolean(b)) => (*b).into(),
(DataType::Boolean, Value::Boolean(b)) => (*b).into(),
// ---- Int16 -----
(Some(DataType::Int16), Value::Int(i)) if self.relax_numeric => (*i as i16).into(),
(Some(DataType::Int16), Value::Long(i)) if self.relax_numeric => (*i as i16).into(),
(DataType::Int16, Value::Int(i)) if self.relax_numeric => (*i as i16).into(),
(DataType::Int16, Value::Long(i)) if self.relax_numeric => (*i as i16).into(),

// ---- Int32 -----
(Some(DataType::Int32) | None, Value::Int(i)) => (*i).into(),
(Some(DataType::Int32), Value::Long(i)) if self.relax_numeric => (*i as i32).into(),
(DataType::Int32, Value::Int(i)) => (*i).into(),
(DataType::Int32, Value::Long(i)) if self.relax_numeric => (*i as i32).into(),
// ---- Int64 -----
(Some(DataType::Int64) | None, Value::Long(i)) => (*i).into(),
(Some(DataType::Int64), Value::Int(i)) if self.relax_numeric => (*i as i64).into(),
(DataType::Int64, Value::Long(i)) => (*i).into(),
(DataType::Int64, Value::Int(i)) if self.relax_numeric => (*i as i64).into(),
// ---- Float32 -----
(Some(DataType::Float32) | None, Value::Float(i)) => (*i).into(),
(Some(DataType::Float32), Value::Double(i)) => (*i as f32).into(),
(DataType::Float32, Value::Float(i)) => (*i).into(),
(DataType::Float32, Value::Double(i)) => (*i as f32).into(),
// ---- Float64 -----
(Some(DataType::Float64) | None, Value::Double(i)) => (*i).into(),
(Some(DataType::Float64), Value::Float(i)) => (*i as f64).into(),
(DataType::Float64, Value::Double(i)) => (*i).into(),
(DataType::Float64, Value::Float(i)) => (*i as f64).into(),
// ---- Decimal -----
(Some(DataType::Decimal) | None, Value::Decimal(avro_decimal)) => {
(DataType::Decimal, Value::Decimal(avro_decimal)) => {
let (precision, scale) = match self.schema {
Some(Schema::Decimal(DecimalSchema {
precision, scale, ..
Expand All @@ -133,7 +133,7 @@ impl<'a> AvroParseOptions<'a> {
.map_err(|_| create_error())?;
ScalarImpl::Decimal(risingwave_common::types::Decimal::Normalized(decimal))
}
(Some(DataType::Decimal), Value::Record(fields)) => {
(DataType::Decimal, Value::Record(fields)) => {
// VariableScaleDecimal has fixed fields, scale(int) and value(bytes)
let find_in_records = |field_name: &str| {
fields
Expand Down Expand Up @@ -167,56 +167,46 @@ impl<'a> AvroParseOptions<'a> {
ScalarImpl::Decimal(risingwave_common::types::Decimal::Normalized(decimal))
}
// ---- Time -----
(Some(DataType::Time), Value::TimeMillis(ms)) => Time::with_milli(*ms as u32)
(DataType::Time, Value::TimeMillis(ms)) => Time::with_milli(*ms as u32)
.map_err(|_| create_error())?
.into(),
(Some(DataType::Time), Value::TimeMicros(us)) => Time::with_micro(*us as u64)
(DataType::Time, Value::TimeMicros(us)) => Time::with_micro(*us as u64)
.map_err(|_| create_error())?
.into(),
// ---- Date -----
(Some(DataType::Date) | None, Value::Date(days)) => {
Date::with_days(days + unix_epoch_days())
.map_err(|_| create_error())?
.into()
}
(DataType::Date, Value::Date(days)) => Date::with_days(days + unix_epoch_days())
.map_err(|_| create_error())?
.into(),
// ---- Varchar -----
(Some(DataType::Varchar) | None, Value::Enum(_, symbol)) => {
symbol.clone().into_boxed_str().into()
}
(Some(DataType::Varchar) | None, Value::String(s)) => s.clone().into_boxed_str().into(),
(DataType::Varchar, Value::Enum(_, symbol)) => symbol.clone().into_boxed_str().into(),
(DataType::Varchar, Value::String(s)) => s.clone().into_boxed_str().into(),
// ---- Timestamp -----
(Some(DataType::Timestamp) | None, Value::LocalTimestampMillis(ms)) => {
Timestamp::with_millis(*ms)
.map_err(|_| create_error())?
.into()
}
(Some(DataType::Timestamp) | None, Value::LocalTimestampMicros(us)) => {
Timestamp::with_micros(*us)
.map_err(|_| create_error())?
.into()
}
(DataType::Timestamp, Value::LocalTimestampMillis(ms)) => Timestamp::with_millis(*ms)
.map_err(|_| create_error())?
.into(),
(DataType::Timestamp, Value::LocalTimestampMicros(us)) => Timestamp::with_micros(*us)
.map_err(|_| create_error())?
.into(),

// ---- TimestampTz -----
(Some(DataType::Timestamptz) | None, Value::TimestampMillis(ms)) => {
Timestamptz::from_millis(*ms)
.ok_or_else(|| {
uncategorized!("timestamptz with milliseconds {ms} * 1000 is out of range")
})?
.into()
}
(Some(DataType::Timestamptz) | None, Value::TimestampMicros(us)) => {
(DataType::Timestamptz, Value::TimestampMillis(ms)) => Timestamptz::from_millis(*ms)
.ok_or_else(|| {
uncategorized!("timestamptz with milliseconds {ms} * 1000 is out of range")
})?
.into(),
(DataType::Timestamptz, Value::TimestampMicros(us)) => {
Timestamptz::from_micros(*us).into()
}

// ---- Interval -----
(Some(DataType::Interval) | None, Value::Duration(duration)) => {
(DataType::Interval, Value::Duration(duration)) => {
let months = u32::from(duration.months()) as i32;
let days = u32::from(duration.days()) as i32;
let usecs = (u32::from(duration.millis()) as i64) * 1000; // never overflows
ScalarImpl::Interval(Interval::from_month_day_usec(months, days, usecs))
}
// ---- Struct -----
(Some(DataType::Struct(struct_type_info)), Value::Record(descs)) => StructValue::new(
(DataType::Struct(struct_type_info), Value::Record(descs)) => StructValue::new(
struct_type_info
.names()
.zip_eq_fast(struct_type_info.types())
Expand All @@ -228,49 +218,33 @@ impl<'a> AvroParseOptions<'a> {
schema,
relax_numeric: self.relax_numeric,
}
.convert_to_datum(value, Some(field_type))?)
.convert_to_datum(value, field_type)?)
} else {
Ok(None)
}
})
.collect::<Result<_, AccessError>>()?,
)
.into(),
(None, Value::Record(descs)) => {
let rw_values = descs
.iter()
.map(|(field_name, field_value)| {
let schema = self.extract_inner_schema(Some(field_name));
Self {
schema,
relax_numeric: self.relax_numeric,
}
.convert_to_datum(field_value, None)
})
.collect::<Result<Vec<Datum>, AccessError>>()?;
ScalarImpl::Struct(StructValue::new(rw_values))
}
// ---- List -----
(Some(DataType::List(item_type)), Value::Array(array)) => ListValue::new({
(DataType::List(item_type), Value::Array(array)) => ListValue::new({
let schema = self.extract_inner_schema(None);
let mut builder = item_type.create_array_builder(array.len());
for v in array {
let value = Self {
schema,
relax_numeric: self.relax_numeric,
}
.convert_to_datum(v, Some(item_type))?;
.convert_to_datum(v, item_type)?;
builder.append(value);
}
builder.finish()
})
.into(),
// ---- Bytea -----
(Some(DataType::Bytea) | None, Value::Bytes(value)) => {
value.clone().into_boxed_slice().into()
}
(DataType::Bytea, Value::Bytes(value)) => value.clone().into_boxed_slice().into(),
// ---- Jsonb -----
(Some(DataType::Jsonb), v @ Value::Map(_)) => {
(DataType::Jsonb, v @ Value::Map(_)) => {
let mut builder = jsonbb::Builder::default();
avro_to_jsonb(v, &mut builder)?;
let jsonb = builder.finish();
Expand Down Expand Up @@ -299,7 +273,7 @@ impl<'a, 'b> Access for AvroAccess<'a, 'b>
where
'a: 'b,
{
fn access(&self, path: &[&str], type_expected: Option<&DataType>) -> AccessResult {
fn access(&self, path: &[&str], type_expected: &DataType) -> AccessResult {
let mut value = self.value;
let mut options: AvroParseOptions<'_> = self.options.clone();

Expand Down Expand Up @@ -333,6 +307,15 @@ where
}
}

/// Null check for `AvroAccess`, provided through the `NullableAccess` trait.
impl<'a, 'b> NullableAccess for AvroAccess<'a, 'b>
where
'a: 'b,
{
/// Returns `true` when the wrapped `value` is `Value::Null`, and `false`
/// for every other variant.
fn is_null(&self) -> bool {
matches!(self.value, Value::Null)
}
}

pub(crate) fn avro_decimal_to_rust_decimal(
avro_decimal: AvroDecimal,
_precision: usize,
Expand Down Expand Up @@ -436,7 +419,7 @@ mod tests {
use std::str::FromStr;

use apache_avro::Decimal as AvroDecimal;
use risingwave_common::types::Decimal;
use risingwave_common::types::{Datum, Decimal};

use super::*;

Expand Down Expand Up @@ -489,7 +472,7 @@ mod tests {
shape: &DataType,
) -> crate::error::ConnectorResult<Datum> {
AvroParseOptions::create(value_schema)
.convert_to_datum(&value, Some(shape))
.convert_to_datum(&value, shape)
.map_err(Into::into)
}

Expand Down Expand Up @@ -532,7 +515,7 @@ mod tests {
let value = Value::Decimal(AvroDecimal::from(bytes));
let options = AvroParseOptions::create(&schema);
let resp = options
.convert_to_datum(&value, Some(&DataType::Decimal))
.convert_to_datum(&value, &DataType::Decimal)
.unwrap();
assert_eq!(
resp,
Expand Down Expand Up @@ -571,7 +554,7 @@ mod tests {

let options = AvroParseOptions::create(&schema);
let resp = options
.convert_to_datum(&value, Some(&DataType::Decimal))
.convert_to_datum(&value, &DataType::Decimal)
.unwrap();
assert_eq!(resp, Some(ScalarImpl::Decimal(Decimal::from(66051))));
}
Expand Down
4 changes: 2 additions & 2 deletions src/connector/src/parser/unified/bytes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ impl<'a> BytesAccess<'a> {

impl<'a> Access for BytesAccess<'a> {
/// path is empty currently, `type_expected` should be `Bytea`
fn access(&self, path: &[&str], type_expected: Option<&DataType>) -> AccessResult {
if let DataType::Bytea = type_expected.unwrap() {
fn access(&self, path: &[&str], type_expected: &DataType) -> AccessResult {
if let DataType::Bytea = type_expected {
if self.column_name.is_none()
|| (path.len() == 1 && self.column_name.as_ref().unwrap() == path[0])
{
Expand Down
Loading

0 comments on commit bf55e90

Please sign in to comment.