Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(connector): avoid anyhow in AccessError and avoid using RwError if possible #14874

Merged
merged 3 commits into from
Feb 1, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 13 additions & 11 deletions src/common/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ pub enum ErrorCode {
InternalError(String),
// TODO: unify with the above
#[error(transparent)]
InternalErrorAnyhow(
Uncategorized(
#[from]
#[backtrace]
anyhow::Error,
Expand Down Expand Up @@ -236,6 +236,12 @@ pub enum ErrorCode {
),
}

impl RwError {
/// Builds an [`RwError`] from any error that converts into [`anyhow::Error`].
///
/// The argument is converted with `Into<anyhow::Error>`, wrapped in
/// [`ErrorCode::Uncategorized`], and then turned into an `RwError` via `From`.
pub fn uncategorized(err: impl Into<anyhow::Error>) -> Self {
Self::from(ErrorCode::Uncategorized(err.into()))
}
}

impl From<RwError> for tonic::Status {
fn from(err: RwError) -> Self {
use tonic::Code;
Expand Down Expand Up @@ -278,13 +284,13 @@ impl From<tonic::Status> for RwError {

impl From<JoinError> for RwError {
fn from(join_error: JoinError) -> Self {
anyhow::anyhow!(join_error).into()
Self::uncategorized(join_error)
}
}

impl From<std::net::AddrParseError> for RwError {
fn from(addr_parse_error: std::net::AddrParseError) -> Self {
anyhow::anyhow!(addr_parse_error).into()
Self::uncategorized(addr_parse_error)
}
}

Expand Down Expand Up @@ -456,7 +462,7 @@ mod tests {
use anyhow::anyhow;

use super::*;
use crate::error::ErrorCode::InternalErrorAnyhow;
use crate::error::ErrorCode::Uncategorized;

#[test]
fn test_display_internal_error() {
Expand All @@ -477,7 +483,7 @@ mod tests {
.unwrap_err();

assert_eq!(
RwError::from(InternalErrorAnyhow(anyhow!(err_msg))).to_string(),
RwError::from(Uncategorized(anyhow!(err_msg))).to_string(),
error.to_string(),
);
}
Expand All @@ -490,7 +496,7 @@ mod tests {
})()
.unwrap_err();
assert_eq!(
RwError::from(InternalErrorAnyhow(anyhow!(err_msg))).to_string(),
RwError::from(Uncategorized(anyhow!(err_msg))).to_string(),
error.to_string()
);
}
Expand All @@ -502,11 +508,7 @@ mod tests {
})()
.unwrap_err();
assert_eq!(
RwError::from(InternalErrorAnyhow(anyhow!(
"error msg with args: {}",
"xx"
)))
.to_string(),
RwError::from(Uncategorized(anyhow!("error msg with args: {}", "xx"))).to_string(),
error.to_string()
);
}
Expand Down
2 changes: 1 addition & 1 deletion src/connector/src/parser/debezium/debezium_parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ impl DebeziumParser {
Err(err) => {
// Only try to access transaction control message if the row operation access failed
// to make it a fast path.
if let Ok(transaction_control) =
if let Some(transaction_control) =
row_op.transaction_control(&self.source_ctx.connector_props)
{
Ok(ParseResult::TransactionControl(transaction_control))
Expand Down
85 changes: 46 additions & 39 deletions src/connector/src/parser/protobuf/parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,19 @@ use prost_reflect::{
ReflectMessage, Value,
};
use risingwave_common::array::{ListValue, StructValue};
use risingwave_common::error::ErrorCode::{InternalError, ProtocolError};
use risingwave_common::error::ErrorCode::ProtocolError;
use risingwave_common::error::{Result, RwError};
use risingwave_common::try_match_expand;
use risingwave_common::types::{DataType, Datum, Decimal, JsonbVal, ScalarImpl, F32, F64};
use risingwave_pb::plan_common::{AdditionalColumn, ColumnDesc, ColumnDescVersion};
use thiserror::Error;
use thiserror_ext::{AsReport, Macro};

use super::schema_resolver::*;
use crate::parser::unified::protobuf::ProtobufAccess;
use crate::parser::unified::AccessImpl;
use crate::parser::unified::{
bail_uncategorized, uncategorized, AccessError, AccessImpl, AccessResult,
};
use crate::parser::util::bytes_from_url;
use crate::parser::{AccessBuilder, EncodingProperties};
use crate::schema::schema_registry::{
Expand Down Expand Up @@ -157,7 +161,8 @@ impl ProtobufParserConfig {
index: &mut i32,
parse_trace: &mut Vec<String>,
) -> Result<ColumnDesc> {
let field_type = protobuf_type_mapping(field_descriptor, parse_trace)?;
let field_type =
protobuf_type_mapping(field_descriptor, parse_trace).map_err(RwError::uncategorized)?;
if let Kind::Message(m) = field_descriptor.kind() {
let field_descs = if let DataType::List { .. } = field_type {
vec![]
Expand Down Expand Up @@ -192,15 +197,22 @@ impl ProtobufParserConfig {
}
}

fn detect_loop_and_push(trace: &mut Vec<String>, fd: &FieldDescriptor) -> Result<()> {
// NOTE(review): the `Macro` derive (imported from `thiserror_ext`) presumably generates the
// `bail_protobuf_type_error!` macro used by the functions below — confirm.
/// Error returned by `detect_loop_and_push` and `protobuf_type_mapping`.
///
/// Wraps a single message string, which is also used verbatim as the `Display`
/// output (`"{0}"`).
#[derive(Error, Debug, Macro)]
#[error("{0}")]
struct ProtobufTypeError(#[message] String);

fn detect_loop_and_push(
trace: &mut Vec<String>,
fd: &FieldDescriptor,
) -> std::result::Result<(), ProtobufTypeError> {
let identifier = format!("{}({})", fd.name(), fd.full_name());
if trace.iter().any(|s| s == identifier.as_str()) {
return Err(RwError::from(ProtocolError(format!(
bail_protobuf_type_error!(
"circular reference detected: {}, conflict with {}, kind {:?}",
trace.iter().join("->"),
trace.iter().format("->"),
identifier,
fd.kind(),
))));
);
}
trace.push(identifier);
Ok(())
Expand Down Expand Up @@ -341,7 +353,9 @@ pub fn from_protobuf_value(
field_desc: &FieldDescriptor,
value: &Value,
descriptor_pool: &Arc<DescriptorPool>,
) -> Result<Datum> {
) -> AccessResult {
let kind = field_desc.kind();

let v = match value {
Value::Bool(v) => ScalarImpl::Bool(*v),
Value::I32(i) => ScalarImpl::Int32(*i),
Expand All @@ -352,17 +366,13 @@ pub fn from_protobuf_value(
Value::F64(f) => ScalarImpl::Float64(F64::from(*f)),
Value::String(s) => ScalarImpl::Utf8(s.as_str().into()),
Value::EnumNumber(idx) => {
let kind = field_desc.kind();
let enum_desc = kind.as_enum().ok_or_else(|| {
let err_msg = format!("protobuf parse error.not a enum desc {:?}", field_desc);
RwError::from(ProtocolError(err_msg))
let enum_desc = kind.as_enum().ok_or_else(|| AccessError::TypeError {
expected: "enum".to_owned(),
got: format!("{kind:?}"),
value: value.to_string(),
})?;
let enum_symbol = enum_desc.get_value(*idx).ok_or_else(|| {
let err_msg = format!(
"protobuf parse error.unknown enum index {} of enum {:?}",
idx, enum_desc
);
RwError::from(ProtocolError(err_msg))
uncategorized!("unknown enum index {} of enum {:?}", idx, enum_desc)
})?;
ScalarImpl::Utf8(enum_symbol.name().into())
}
Expand All @@ -389,18 +399,14 @@ pub fn from_protobuf_value(
let Some(ScalarImpl::Bytea(payload)) =
from_protobuf_value(&payload_field_desc, &payload, descriptor_pool)?
else {
let err_msg = "Expected ScalarImpl::Bytea for payload".to_string();
return Err(RwError::from(ProtocolError(err_msg)));
bail_uncategorized!("expected bytes for dynamic message payload");
};

// Get the corresponding schema from the descriptor pool
let msg_desc = descriptor_pool
.get_message_by_name(&type_url)
.ok_or_else(|| {
ProtocolError(format!(
"Cannot find message {} in from_protobuf_value.\nDescriptor pool is {:#?}",
type_url, descriptor_pool
))
uncategorized!("message `{type_url}` not found in descriptor pool")
})?;

let f = msg_desc
Expand Down Expand Up @@ -439,11 +445,10 @@ pub fn from_protobuf_value(
if !dyn_msg.has_field(&field_desc)
&& field_desc.cardinality() == Cardinality::Required
{
let err_msg = format!(
"protobuf parse error.missing required field {:?}",
field_desc
);
return Err(RwError::from(ProtocolError(err_msg)));
return Err(AccessError::Undefined {
name: field_desc.name().to_owned(),
path: dyn_msg.descriptor().full_name().to_owned(),
});
}
// use default value if dyn_msg doesn't has this field
let value = dyn_msg.get_field(&field_desc);
Expand All @@ -453,7 +458,11 @@ pub fn from_protobuf_value(
}
}
Value::List(values) => {
let data_type = protobuf_type_mapping(field_desc, &mut vec![])?;
let data_type = protobuf_type_mapping(field_desc, &mut vec![]).map_err(|e| {
AccessError::Uncategorized {
BugenZhao marked this conversation as resolved.
Show resolved Hide resolved
message: e.to_report_string(),
}
})?;
let mut builder = data_type.as_list().create_array_builder(values.len());
for value in values {
builder.append(from_protobuf_value(field_desc, value, descriptor_pool)?);
Expand All @@ -462,11 +471,9 @@ pub fn from_protobuf_value(
}
Value::Bytes(value) => ScalarImpl::Bytea(value.to_vec().into_boxed_slice()),
_ => {
let err_msg = format!(
"protobuf parse error.unsupported type {:?}, value {:?}",
field_desc, value
);
return Err(RwError::from(InternalError(err_msg)));
return Err(AccessError::UnsupportedType {
ty: format!("{kind:?}"),
});
}
};
Ok(Some(v))
Expand All @@ -476,7 +483,7 @@ pub fn from_protobuf_value(
fn protobuf_type_mapping(
field_descriptor: &FieldDescriptor,
parse_trace: &mut Vec<String>,
) -> Result<DataType> {
) -> std::result::Result<DataType, ProtobufTypeError> {
detect_loop_and_push(parse_trace, field_descriptor)?;
let field_type = field_descriptor.kind();
let mut t = match field_type {
Expand All @@ -494,7 +501,7 @@ fn protobuf_type_mapping(
let fields = m
.fields()
.map(|f| protobuf_type_mapping(&f, parse_trace))
.collect::<Result<Vec<_>>>()?;
.try_collect()?;
let field_names = m.fields().map(|f| f.name().to_string()).collect_vec();

// Note that this part is useful for actual parsing
Expand All @@ -513,10 +520,10 @@ fn protobuf_type_mapping(
Kind::Bytes => DataType::Bytea,
};
if field_descriptor.is_map() {
return Err(RwError::from(ProtocolError(format!(
"map type is unsupported (field: '{}')",
bail_protobuf_type_error!(
"protobuf map type (on field `{}`) is not supported",
field_descriptor.full_name()
))));
);
}
if field_descriptor.cardinality() == Cardinality::Repeated {
t = DataType::List(Box::new(t))
Expand Down
42 changes: 20 additions & 22 deletions src/connector/src/parser/unified/avro.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
use std::str::FromStr;
use std::sync::LazyLock;

use anyhow::anyhow;
use apache_avro::schema::{DecimalSchema, RecordSchema};
use apache_avro::types::Value;
use apache_avro::{Decimal as AvroDecimal, Schema};
Expand All @@ -30,7 +29,7 @@ use risingwave_common::types::{
};
use risingwave_common::util::iter_util::ZipEqFast;

use super::{Access, AccessError, AccessResult};
use super::{bail_uncategorized, uncategorized, Access, AccessError, AccessResult};
#[derive(Clone)]
/// Options for parsing an `AvroValue` into Datum, with an optional avro schema.
pub struct AvroParseOptions<'a> {
Expand Down Expand Up @@ -136,26 +135,25 @@ impl<'a> AvroParseOptions<'a> {
.iter()
.find(|field| field.0 == field_name)
.map(|field| &field.1)
.ok_or_else(|| {
uncategorized!("`{field_name}` field not found in VariableScaleDecimal")
})
};
let scale = match find_in_records("scale").ok_or_else(|| {
AccessError::Other(anyhow!("scale field not found in VariableScaleDecimal"))
})? {
Value::Int(scale) => Ok(*scale),
avro_value => Err(AccessError::Other(anyhow!(
let scale = match find_in_records("scale")? {
Value::Int(scale) => *scale,
avro_value => bail_uncategorized!(
"scale field in VariableScaleDecimal is not int, got {:?}",
avro_value
))),
}?;

let value: BigInt = match find_in_records("value").ok_or_else(|| {
AccessError::Other(anyhow!("value field not found in VariableScaleDecimal"))
})? {
Value::Bytes(bytes) => Ok(BigInt::from_signed_bytes_be(bytes)),
avro_value => Err(AccessError::Other(anyhow!(
),
};

let value: BigInt = match find_in_records("value")? {
Value::Bytes(bytes) => BigInt::from_signed_bytes_be(bytes),
avro_value => bail_uncategorized!(
"value field in VariableScaleDecimal is not bytes, got {:?}",
avro_value
))),
}?;
),
};

let negative = value.sign() == Sign::Minus;
let (lo, mid, hi) = extract_decimal(value.to_bytes_be().1)?;
Expand Down Expand Up @@ -196,9 +194,9 @@ impl<'a> AvroParseOptions<'a> {
// ---- TimestampTz -----
(Some(DataType::Timestamptz) | None, Value::TimestampMillis(ms)) => {
Timestamptz::from_millis(*ms)
.ok_or(AccessError::Other(anyhow!(
"timestamptz with milliseconds {ms} * 1000 is out of range",
)))?
.ok_or_else(|| {
uncategorized!("timestamptz with milliseconds {ms} * 1000 is out of range")
})?
.into()
}
(Some(DataType::Timestamptz) | None, Value::TimestampMicros(us)) => {
Expand Down Expand Up @@ -350,7 +348,7 @@ pub(crate) fn avro_decimal_to_rust_decimal(
))
}

pub(crate) fn extract_decimal(bytes: Vec<u8>) -> anyhow::Result<(u32, u32, u32)> {
pub(crate) fn extract_decimal(bytes: Vec<u8>) -> AccessResult<(u32, u32, u32)> {
match bytes.len() {
len @ 0..=4 => {
let mut pad = vec![0; 4 - len];
Expand Down Expand Up @@ -383,7 +381,7 @@ pub(crate) fn extract_decimal(bytes: Vec<u8>) -> anyhow::Result<(u32, u32, u32)>
let lo = u32::from_be_bytes(bytes[mid_end..].to_owned().try_into().unwrap());
Ok((lo, mid, hi))
}
_ => Err(anyhow!("decimal bytes len: {:?} > 12", bytes.len())),
_ => bail_uncategorized!("invalid decimal bytes length {}", bytes.len()),
}
}

Expand Down
Loading
Loading