From 05ab5e1cabbdde10b1e95d6461b1b0dbcfd4fa74 Mon Sep 17 00:00:00 2001 From: Bugen Zhao Date: Wed, 22 Nov 2023 16:07:11 +0800 Subject: [PATCH] refactor more Signed-off-by: Bugen Zhao --- src/batch/src/error.rs | 7 + src/batch/src/executor/sys_row_seq_scan.rs | 6 +- src/common/src/catalog/column.rs | 19 -- src/common/src/catalog/mod.rs | 4 +- src/common/src/types/interval.rs | 92 ++++----- src/common/src/types/mod.rs | 195 ++++++++---------- src/common/src/types/num256.rs | 5 +- src/common/src/types/postgres_type.rs | 9 +- src/common/src/util/match_util.rs | 40 ++-- src/frontend/src/binder/bind_param.rs | 9 +- src/frontend/src/binder/expr/value.rs | 3 +- .../src/catalog/system_catalog/mod.rs | 6 +- src/frontend/src/handler/util.rs | 5 +- src/utils/pgwire/src/pg_protocol.rs | 8 +- 14 files changed, 195 insertions(+), 213 deletions(-) diff --git a/src/batch/src/error.rs b/src/batch/src/error.rs index 4336b86055d8b..6894bea33b515 100644 --- a/src/batch/src/error.rs +++ b/src/batch/src/error.rs @@ -94,6 +94,13 @@ pub enum BatchError { BoxedError, ), + #[error("Failed to read from system table: {0}")] + SystemTable( + #[from] + #[backtrace] + BoxedError, + ), + // Make the ref-counted type to be a variant for easier code structuring. #[error(transparent)] Shared( diff --git a/src/batch/src/executor/sys_row_seq_scan.rs b/src/batch/src/executor/sys_row_seq_scan.rs index d0103d9883869..d28b0b95c5a38 100644 --- a/src/batch/src/executor/sys_row_seq_scan.rs +++ b/src/batch/src/executor/sys_row_seq_scan.rs @@ -107,7 +107,11 @@ impl Executor for SysRowSeqScanExecutor { impl SysRowSeqScanExecutor { #[try_stream(boxed, ok = DataChunk, error = BatchError)] async fn do_executor(self: Box) { - let rows = self.sys_catalog_reader.read_table(&self.table_id).await?; + let rows = self + .sys_catalog_reader + .read_table(&self.table_id) + .await + .map_err(BatchError::SystemTable)?; let filtered_rows = rows .iter() .map(|row| { diff --git a/src/common/src/catalog/column.rs b/src/common/src/catalog/column.rs index 7a984724bf116..68c1618073169 100644 --- a/src/common/src/catalog/column.rs +++ b/src/common/src/catalog/column.rs @@ -21,7 +21,6 @@ use risingwave_pb::plan_common::{PbColumnCatalog, PbColumnDesc}; use super::row_id_column_desc; use crate::catalog::{cdc_table_name_column_desc, offset_column_desc, Field, ROW_ID_COLUMN_ID}; -use crate::error::ErrorCode; use crate::types::DataType; /// Column ID is the unique identifier of a column in a table. Different from table ID, column ID is @@ -161,24 +160,6 @@ impl ColumnDesc { descs } - /// Find `column_desc` in `field_descs` by name. - pub fn field(&self, name: &String) -> crate::error::Result<(ColumnDesc, i32)> { - if let DataType::Struct { .. } = self.data_type { - for (index, col) in self.field_descs.iter().enumerate() { - if col.name == *name { - return Ok((col.clone(), index as i32)); - } - } - Err(ErrorCode::ItemNotFound(format!("Invalid field name: {}", name)).into()) - } else { - Err(ErrorCode::ItemNotFound(format!( - "Cannot get field from non nested column: {}", - self.name - )) - .into()) - } - } - pub fn new_atomic(data_type: DataType, name: &str, column_id: i32) -> Self { Self { data_type, diff --git a/src/common/src/catalog/mod.rs b/src/common/src/catalog/mod.rs index 204a5005cd2de..a8a698128d9b6 100644 --- a/src/common/src/catalog/mod.rs +++ b/src/common/src/catalog/mod.rs @@ -32,7 +32,7 @@ use risingwave_pb::catalog::HandleConflictBehavior as PbHandleConflictBehavior; pub use schema::{test_utils as schema_test_utils, Field, FieldDisplay, Schema}; pub use crate::constants::hummock; -use crate::error::Result; +use crate::error::BoxedError; use crate::row::OwnedRow; use crate::types::DataType; @@ -134,7 +134,7 @@ pub fn cdc_table_name_column_desc() -> ColumnDesc { /// The local system catalog reader in the frontend node. #[async_trait] pub trait SysCatalogReader: Sync + Send + 'static { - async fn read_table(&self, table_id: &TableId) -> Result>; + async fn read_table(&self, table_id: &TableId) -> Result, BoxedError>; } pub type SysCatalogReaderRef = Arc; diff --git a/src/common/src/types/interval.rs b/src/common/src/types/interval.rs index ca29b9a28abd3..a95bc412124b9 100644 --- a/src/common/src/types/interval.rs +++ b/src/common/src/types/interval.rs @@ -31,7 +31,6 @@ use rust_decimal::prelude::Decimal; use super::to_binary::ToBinary; use super::*; -use crate::error::{ErrorCode, Result, RwError}; use crate::estimate_size::EstimateSize; /// Every interval can be represented by a `Interval`. @@ -1001,6 +1000,23 @@ impl ToText for crate::types::Interval { } } +#[derive(thiserror::Error, Debug, thiserror_ext::Construct)] +pub enum IntervalParseError { + #[error("Invalid interval: {0}")] + Invalid(String), + + #[error("Invalid interval: {0}, expected format PYMDTHMS")] + InvalidIso8601(String), + + #[error("Invalid unit: {0}")] + InvalidUnit(String), + + #[error("{0}")] + Uncategorized(String), +} + +type ParseResult = std::result::Result; + impl Interval { pub fn as_iso_8601(&self) -> String { // ISO pattern - PnYnMnDTnHnMnS @@ -1029,7 +1045,7 @@ impl Interval { /// /// Example /// - P1Y2M3DT4H5M6.78S - pub fn from_iso_8601(s: &str) -> Result { + pub fn from_iso_8601(s: &str) -> ParseResult { // ISO pattern - PnYnMnDTnHnMnS static ISO_8601_REGEX: LazyLock = LazyLock::new(|| { Regex::new(r"P([0-9]+)Y([0-9]+)M([0-9]+)DT([0-9]+)H([0-9]+)M([0-9]+(?:\.[0-9]+)?)S") @@ -1061,7 +1077,7 @@ impl Interval { .checked_add(usecs)?, )) }; - f().ok_or_else(|| ErrorCode::InvalidInputSyntax(format!("Invalid interval: {}, expected format PYMDTHMS", s)).into()) + f().ok_or_else(|| IntervalParseError::invalid_iso8601(s)) } } @@ -1184,9 +1200,9 @@ pub enum DateTimeField { } impl FromStr for DateTimeField { - type Err = RwError; + type Err = IntervalParseError; - fn from_str(s: &str) -> Result { + fn from_str(s: &str) -> ParseResult { match s.to_lowercase().as_str() { "years" | "year" | "yrs" | "yr" | "y" => Ok(Self::Year), "days" | "day" | "d" => Ok(Self::Day), @@ -1194,7 +1210,7 @@ impl FromStr for DateTimeField { "minutes" | "minute" | "mins" | "min" | "m" => Ok(Self::Minute), "months" | "month" | "mons" | "mon" => Ok(Self::Month), "seconds" | "second" | "secs" | "sec" | "s" => Ok(Self::Second), - _ => Err(ErrorCode::InvalidInputSyntax(format!("unknown unit {}", s)).into()), + _ => Err(IntervalParseError::invalid_unit(s)), } } } @@ -1206,7 +1222,7 @@ enum TimeStrToken { TimeUnit(DateTimeField), } -fn parse_interval(s: &str) -> Result> { +fn parse_interval(s: &str) -> ParseResult> { let s = s.trim(); let mut tokens = Vec::new(); let mut num_buf = "".to_string(); @@ -1235,21 +1251,16 @@ fn parse_interval(s: &str) -> Result> { ':' => { // there must be a digit before the ':' if num_buf.is_empty() { - return Err(ErrorCode::InvalidInputSyntax(format!( - "invalid interval format: {}", - s - )) - .into()); + return Err(IntervalParseError::invalid(s)); } hour_min_sec.push(num_buf.clone()); num_buf.clear(); } _ => { - return Err(ErrorCode::InvalidInputSyntax(format!( + return Err(IntervalParseError::uncategorized(format!( "Invalid character at offset {} in {}: {:?}. Only support digit or alphabetic now", i,s, c - )) - .into()); + ))); } }; } @@ -1262,23 +1273,20 @@ fn parse_interval(s: &str) -> Result> { convert_digit(&mut num_buf, &mut tokens)?; } convert_unit(&mut char_buf, &mut tokens)?; - convert_hms(&hour_min_sec, &mut tokens).ok_or_else(|| { - ErrorCode::InvalidInputSyntax(format!("Invalid interval: {:?}", hour_min_sec)) - })?; + convert_hms(&hour_min_sec, &mut tokens) + .ok_or_else(|| IntervalParseError::invalid(format!("{hour_min_sec:?}")))?; Ok(tokens) } -fn convert_digit(c: &mut String, t: &mut Vec) -> Result<()> { +fn convert_digit(c: &mut String, t: &mut Vec) -> ParseResult<()> { if !c.is_empty() { match c.parse::() { Ok(num) => { t.push(TimeStrToken::Num(num)); } Err(_) => { - return Err( - ErrorCode::InvalidInputSyntax(format!("Invalid interval: {}", c)).into(), - ); + return Err(IntervalParseError::invalid(c.clone())); } } c.clear(); @@ -1286,7 +1294,7 @@ fn convert_digit(c: &mut String, t: &mut Vec) -> Result<()> { Ok(()) } -fn convert_unit(c: &mut String, t: &mut Vec) -> Result<()> { +fn convert_unit(c: &mut String, t: &mut Vec) -> ParseResult<()> { if !c.is_empty() { t.push(TimeStrToken::TimeUnit(c.parse()?)); c.clear(); @@ -1338,25 +1346,17 @@ fn convert_hms(c: &Vec, t: &mut Vec) -> Option<()> { } impl Interval { - fn parse_sql_standard(s: &str, leading_field: DateTimeField) -> Result { + fn parse_sql_standard(s: &str, leading_field: DateTimeField) -> ParseResult { use DateTimeField::*; let tokens = parse_interval(s)?; // Todo: support more syntax if tokens.len() > 1 { - return Err(ErrorCode::InvalidInputSyntax(format!( - "(standard sql format) Can't support syntax of interval {}.", - &s - )) - .into()); + return Err(IntervalParseError::invalid(s)); } let num = match tokens.first() { Some(TimeStrToken::Num(num)) => *num, _ => { - return Err(ErrorCode::InvalidInputSyntax(format!( - "(standard sql format)Invalid interval {}.", - &s - )) - .into()); + return Err(IntervalParseError::invalid(s)); } }; @@ -1380,10 +1380,10 @@ impl Interval { Some(Interval::from_month_day_usec(0, 0, usecs)) } })() - .ok_or_else(|| ErrorCode::InvalidInputSyntax(format!("Invalid interval {}.", s)).into()) + .ok_or_else(|| IntervalParseError::invalid(s)) } - fn parse_postgres(s: &str) -> Result { + fn parse_postgres(s: &str) -> ParseResult { use DateTimeField::*; let mut tokens = parse_interval(s)?; if tokens.len() % 2 != 0 @@ -1392,7 +1392,7 @@ impl Interval { tokens.push(TimeStrToken::TimeUnit(DateTimeField::Second)); } if tokens.len() % 2 != 0 { - return Err(ErrorCode::InvalidInputSyntax(format!("Invalid interval {}.", &s)).into()); + return Err(IntervalParseError::invalid(s)); } let mut token_iter = tokens.into_iter(); let mut result = Interval::from_month_day_usec(0, 0, 0); @@ -1422,9 +1422,7 @@ impl Interval { } })() .and_then(|rhs| result.checked_add(&rhs)) - .ok_or_else(|| { - ErrorCode::InvalidInputSyntax(format!("Invalid interval {}.", s)) - })?; + .ok_or_else(|| IntervalParseError::invalid(s))?; } (TimeStrToken::Second(second), TimeStrToken::TimeUnit(interval_unit)) => { result = match interval_unit { @@ -1438,21 +1436,17 @@ impl Interval { _ => None, } .and_then(|rhs| result.checked_add(&rhs)) - .ok_or_else(|| { - ErrorCode::InvalidInputSyntax(format!("Invalid interval {}.", s)) - })?; + .ok_or_else(|| IntervalParseError::invalid(s))?; } _ => { - return Err( - ErrorCode::InvalidInputSyntax(format!("Invalid interval {}.", &s)).into(), - ); + return Err(IntervalParseError::invalid(s)); } } } Ok(result) } - pub fn parse_with_fields(s: &str, leading_field: Option) -> Result { + pub fn parse_with_fields(s: &str, leading_field: Option) -> ParseResult { if let Some(leading_field) = leading_field { Self::parse_sql_standard(s, leading_field) } else { @@ -1462,9 +1456,9 @@ impl Interval { } impl FromStr for Interval { - type Err = RwError; + type Err = IntervalParseError; - fn from_str(s: &str) -> Result { + fn from_str(s: &str) -> ParseResult { Self::parse_with_fields(s, None) } } diff --git a/src/common/src/types/mod.rs b/src/common/src/types/mod.rs index 70438b493d896..a700e5bcb0579 100644 --- a/src/common/src/types/mod.rs +++ b/src/common/src/types/mod.rs @@ -38,7 +38,7 @@ use crate::array::{ }; pub use crate::array::{ListRef, ListValue, StructRef, StructValue}; use crate::cast::{str_to_bool, str_to_bytea}; -use crate::error::{BoxedError, ErrorCode, Result as RwResult}; +use crate::error::BoxedError; use crate::estimate_size::EstimateSize; use crate::util::iter_util::ZipEqDebug; use crate::{ @@ -754,94 +754,94 @@ impl From> for ScalarImpl { } } +#[derive(Debug, thiserror::Error, thiserror_ext::Construct)] +pub enum FromSqlError { + #[error(transparent)] + FromBinary(BoxedError), + + #[error("Invalid param: {0}")] + FromText(String), + + #[error("Unsupported data type: {0}")] + Unsupported(DataType), +} + impl ScalarImpl { - pub fn from_binary(bytes: &Bytes, data_type: &DataType) -> RwResult { + pub fn from_binary(bytes: &Bytes, data_type: &DataType) -> Result { let res = match data_type { DataType::Varchar => Self::Utf8( String::from_sql(&Type::VARCHAR, bytes) - .map_err(|err| ErrorCode::InvalidInputSyntax(err.to_string()))? + .map_err(FromSqlError::from_binary)? .into(), ), DataType::Bytea => Self::Bytea( Vec::::from_sql(&Type::BYTEA, bytes) - .map_err(|err| ErrorCode::InvalidInputSyntax(err.to_string()))? + .map_err(FromSqlError::from_binary)? .into(), ), - DataType::Boolean => Self::Bool( - bool::from_sql(&Type::BOOL, bytes) - .map_err(|err| ErrorCode::InvalidInputSyntax(err.to_string()))?, - ), - DataType::Int16 => Self::Int16( - i16::from_sql(&Type::INT2, bytes) - .map_err(|err| ErrorCode::InvalidInputSyntax(err.to_string()))?, - ), - DataType::Int32 => Self::Int32( - i32::from_sql(&Type::INT4, bytes) - .map_err(|err| ErrorCode::InvalidInputSyntax(err.to_string()))?, - ), - DataType::Int64 => Self::Int64( - i64::from_sql(&Type::INT8, bytes) - .map_err(|err| ErrorCode::InvalidInputSyntax(err.to_string()))?, - ), + DataType::Boolean => { + Self::Bool(bool::from_sql(&Type::BOOL, bytes).map_err(FromSqlError::from_binary)?) + } + DataType::Int16 => { + Self::Int16(i16::from_sql(&Type::INT2, bytes).map_err(FromSqlError::from_binary)?) + } + DataType::Int32 => { + Self::Int32(i32::from_sql(&Type::INT4, bytes).map_err(FromSqlError::from_binary)?) + } + DataType::Int64 => { + Self::Int64(i64::from_sql(&Type::INT8, bytes).map_err(FromSqlError::from_binary)?) + } DataType::Serial => Self::Serial(Serial::from( - i64::from_sql(&Type::INT8, bytes) - .map_err(|err| ErrorCode::InvalidInputSyntax(err.to_string()))?, + i64::from_sql(&Type::INT8, bytes).map_err(FromSqlError::from_binary)?, )), DataType::Float32 => Self::Float32( f32::from_sql(&Type::FLOAT4, bytes) - .map_err(|err| ErrorCode::InvalidInputSyntax(err.to_string()))? + .map_err(FromSqlError::from_binary)? .into(), ), DataType::Float64 => Self::Float64( f64::from_sql(&Type::FLOAT8, bytes) - .map_err(|err| ErrorCode::InvalidInputSyntax(err.to_string()))? + .map_err(FromSqlError::from_binary)? .into(), ), DataType::Decimal => Self::Decimal( rust_decimal::Decimal::from_sql(&Type::NUMERIC, bytes) - .map_err(|err| ErrorCode::InvalidInputSyntax(err.to_string()))? + .map_err(FromSqlError::from_binary)? .into(), ), DataType::Date => Self::Date( chrono::NaiveDate::from_sql(&Type::DATE, bytes) - .map_err(|err| ErrorCode::InvalidInputSyntax(err.to_string()))? + .map_err(FromSqlError::from_binary)? .into(), ), DataType::Time => Self::Time( chrono::NaiveTime::from_sql(&Type::TIME, bytes) - .map_err(|err| ErrorCode::InvalidInputSyntax(err.to_string()))? + .map_err(FromSqlError::from_binary)? .into(), ), DataType::Timestamp => Self::Timestamp( chrono::NaiveDateTime::from_sql(&Type::TIMESTAMP, bytes) - .map_err(|err| ErrorCode::InvalidInputSyntax(err.to_string()))? + .map_err(FromSqlError::from_binary)? .into(), ), DataType::Timestamptz => Self::Timestamptz( chrono::DateTime::::from_sql(&Type::TIMESTAMPTZ, bytes) - .map_err(|err| ErrorCode::InvalidInputSyntax(err.to_string()))? + .map_err(FromSqlError::from_binary)? .into(), ), DataType::Interval => Self::Interval( - Interval::from_sql(&Type::INTERVAL, bytes) - .map_err(|err| ErrorCode::InvalidInputSyntax(err.to_string()))?, + Interval::from_sql(&Type::INTERVAL, bytes).map_err(FromSqlError::from_binary)?, ), - DataType::Jsonb => { - Self::Jsonb(JsonbVal::value_deserialize(bytes).ok_or_else(|| { - ErrorCode::InvalidInputSyntax("Invalid value of Jsonb".to_string()) - })?) - } - DataType::Int256 => Self::Int256( - Int256::from_binary(bytes) - .map_err(|err| ErrorCode::InvalidInputSyntax(err.to_string()))?, + DataType::Jsonb => Self::Jsonb( + JsonbVal::value_deserialize(bytes) + .ok_or_else(|| FromSqlError::from_binary("Invalid value of Jsonb"))?, ), + DataType::Int256 => { + Self::Int256(Int256::from_binary(bytes).map_err(FromSqlError::from_binary)?) + } DataType::Struct(_) | DataType::List { .. } => { - return Err(ErrorCode::NotSupported( - format!("param type: {}", data_type), - "".to_string(), - ) - .into()) + return Err(FromSqlError::Unsupported(data_type.clone())); } }; Ok(res) @@ -856,78 +856,66 @@ impl ScalarImpl { std::str::from_utf8(without_null) } - pub fn from_text(bytes: &[u8], data_type: &DataType) -> RwResult { - let str = Self::cstr_to_str(bytes).map_err(|_| { - ErrorCode::InvalidInputSyntax(format!("Invalid param string: {:?}", bytes)) - })?; + pub fn from_text(bytes: &[u8], data_type: &DataType) -> Result { + let str = + Self::cstr_to_str(bytes).map_err(|_| FromSqlError::from_text(format!("{bytes:?}")))?; let res = match data_type { DataType::Varchar => Self::Utf8(str.to_string().into()), - DataType::Boolean => Self::Bool(bool::from_str(str).map_err(|_| { - ErrorCode::InvalidInputSyntax(format!("Invalid param string: {}", str)) - })?), - DataType::Int16 => Self::Int16(i16::from_str(str).map_err(|_| { - ErrorCode::InvalidInputSyntax(format!("Invalid param string: {}", str)) - })?), - DataType::Int32 => Self::Int32(i32::from_str(str).map_err(|_| { - ErrorCode::InvalidInputSyntax(format!("Invalid param string: {}", str)) - })?), - DataType::Int64 => Self::Int64(i64::from_str(str).map_err(|_| { - ErrorCode::InvalidInputSyntax(format!("Invalid param string: {}", str)) - })?), - DataType::Int256 => Self::Int256(Int256::from_str(str).map_err(|_| { - ErrorCode::InvalidInputSyntax(format!("Invalid param string: {}", str)) - })?), - DataType::Serial => Self::Serial(Serial::from(i64::from_str(str).map_err(|_| { - ErrorCode::InvalidInputSyntax(format!("Invalid param string: {}", str)) - })?)), + DataType::Boolean => { + Self::Bool(bool::from_str(str).map_err(|_| FromSqlError::from_text(str))?) + } + DataType::Int16 => { + Self::Int16(i16::from_str(str).map_err(|_| FromSqlError::from_text(str))?) + } + DataType::Int32 => { + Self::Int32(i32::from_str(str).map_err(|_| FromSqlError::from_text(str))?) + } + DataType::Int64 => { + Self::Int64(i64::from_str(str).map_err(|_| FromSqlError::from_text(str))?) + } + DataType::Int256 => { + Self::Int256(Int256::from_str(str).map_err(|_| FromSqlError::from_text(str))?) + } + DataType::Serial => Self::Serial(Serial::from( + i64::from_str(str).map_err(|_| FromSqlError::from_text(str))?, + )), DataType::Float32 => Self::Float32( f32::from_str(str) - .map_err(|_| { - ErrorCode::InvalidInputSyntax(format!("Invalid param string: {}", str)) - })? + .map_err(|_| FromSqlError::from_text(str))? .into(), ), DataType::Float64 => Self::Float64( f64::from_str(str) - .map_err(|_| { - ErrorCode::InvalidInputSyntax(format!("Invalid param string: {}", str)) - })? + .map_err(|_| FromSqlError::from_text(str))? .into(), ), DataType::Decimal => Self::Decimal( rust_decimal::Decimal::from_str(str) - .map_err(|_| { - ErrorCode::InvalidInputSyntax(format!("Invalid param string: {}", str)) - })? + .map_err(|_| FromSqlError::from_text(str))? .into(), ), - DataType::Date => Self::Date(Date::from_str(str).map_err(|_| { - ErrorCode::InvalidInputSyntax(format!("Invalid param string: {}", str)) - })?), - DataType::Time => Self::Time(Time::from_str(str).map_err(|_| { - ErrorCode::InvalidInputSyntax(format!("Invalid param string: {}", str)) - })?), - DataType::Timestamp => Self::Timestamp(Timestamp::from_str(str).map_err(|_| { - ErrorCode::InvalidInputSyntax(format!("Invalid param string: {}", str)) - })?), - DataType::Timestamptz => { - Self::Timestamptz(Timestamptz::from_str(str).map_err(|_| { - ErrorCode::InvalidInputSyntax(format!("Invalid param string: {}", str)) - })?) + DataType::Date => { + Self::Date(Date::from_str(str).map_err(|_| FromSqlError::from_text(str))?) + } + DataType::Time => { + Self::Time(Time::from_str(str).map_err(|_| FromSqlError::from_text(str))?) + } + DataType::Timestamp => { + Self::Timestamp(Timestamp::from_str(str).map_err(|_| FromSqlError::from_text(str))?) + } + DataType::Timestamptz => Self::Timestamptz( + Timestamptz::from_str(str).map_err(|_| FromSqlError::from_text(str))?, + ), + DataType::Interval => { + Self::Interval(Interval::from_str(str).map_err(|_| FromSqlError::from_text(str))?) + } + DataType::Jsonb => { + Self::Jsonb(JsonbVal::from_str(str).map_err(|_| FromSqlError::from_text(str))?) } - DataType::Interval => Self::Interval(Interval::from_str(str).map_err(|_| { - ErrorCode::InvalidInputSyntax(format!("Invalid param string: {}", str)) - })?), - DataType::Jsonb => Self::Jsonb(JsonbVal::from_str(str).map_err(|_| { - ErrorCode::InvalidInputSyntax(format!("Invalid param string: {}", str)) - })?), DataType::List(datatype) => { // TODO: support nested list if !(str.starts_with('{') && str.ends_with('}')) { - return Err(ErrorCode::InvalidInputSyntax(format!( - "Invalid param string: {str}", - )) - .into()); + return Err(FromSqlError::from_text(str)); } let mut values = vec![]; for s in str[1..str.len() - 1].split(',') { @@ -937,10 +925,7 @@ impl ScalarImpl { } DataType::Struct(s) => { if !(str.starts_with('{') && str.ends_with('}')) { - return Err(ErrorCode::InvalidInputSyntax(format!( - "Invalid param string: {str}", - )) - .into()); + return Err(FromSqlError::from_text(str)); } let mut fields = Vec::with_capacity(s.len()); for (s, ty) in str[1..str.len() - 1].split(',').zip_eq_debug(s.types()) { @@ -949,11 +934,7 @@ impl ScalarImpl { ScalarImpl::Struct(StructValue::new(fields)) } DataType::Bytea => { - return Err(ErrorCode::NotSupported( - format!("param type: {}", data_type), - "".to_string(), - ) - .into()) + return Err(FromSqlError::unsupported(data_type.clone())); } }; Ok(res) diff --git a/src/common/src/types/num256.rs b/src/common/src/types/num256.rs index 26c0edaef59e4..864af97deb374 100644 --- a/src/common/src/types/num256.rs +++ b/src/common/src/types/num256.rs @@ -165,7 +165,10 @@ macro_rules! impl_common_for_num256 { } impl ToBinary for $scalar_ref<'_> { - fn to_binary_with_type(&self, _ty: &DataType) -> super::to_binary::Result> { + fn to_binary_with_type( + &self, + _ty: &DataType, + ) -> super::to_binary::Result> { let mut output = bytes::BytesMut::new(); let buffer = self.to_be_bytes(); output.put_slice(&buffer); diff --git a/src/common/src/types/postgres_type.rs b/src/common/src/types/postgres_type.rs index 5e470182bad63..b43b955f8ae23 100644 --- a/src/common/src/types/postgres_type.rs +++ b/src/common/src/types/postgres_type.rs @@ -13,7 +13,6 @@ // limitations under the License. use super::DataType; -use crate::error::ErrorCode; /// `DataType` information extracted from PostgreSQL `pg_type` /// @@ -49,6 +48,10 @@ macro_rules! for_all_base_types { }; } +#[derive(Debug, thiserror::Error)] +#[error("Unsupported oid {0}")] +pub struct UnsupportedOid(i32); + /// Get type information compatible with Postgres type, such as oid, type length. impl DataType { pub fn type_len(&self) -> i16 { @@ -73,7 +76,7 @@ impl DataType { // Such as: // https://github.com/postgres/postgres/blob/master/src/include/catalog/pg_type.dat#L347 // For Numeric(aka Decimal): oid = 1700, array_type_oid = 1231 - pub fn from_oid(oid: i32) -> crate::error::Result { + pub fn from_oid(oid: i32) -> Result { macro_rules! impl_from_oid { ($( { $enum:ident | $oid:literal | $oid_array:literal | $name:ident | $input:ident | $len:literal } )*) => { match oid { @@ -86,7 +89,7 @@ impl DataType { // workaround to support text in extended mode. 25 => Ok(DataType::Varchar), 1009 => Ok(DataType::List(Box::new(DataType::Varchar))), - _ => Err(ErrorCode::InternalError(format!("Unsupported oid {}", oid)).into()), + _ => Err(UnsupportedOid(oid)), } } } diff --git a/src/common/src/util/match_util.rs b/src/common/src/util/match_util.rs index 26982812d6499..9591f05340761 100644 --- a/src/common/src/util/match_util.rs +++ b/src/common/src/util/match_util.rs @@ -12,6 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. +/// Try to match an enum variant and return the internal value. +/// +/// Return an [`anyhow::Error`] if the enum variant does not match. #[macro_export] macro_rules! try_match_expand { ($e:expr, $variant:path) => { @@ -32,6 +35,9 @@ macro_rules! try_match_expand { }; } +/// Match an enum variant and return the internal value. +/// +/// Panic if the enum variant does not match. #[macro_export] macro_rules! must_match { ($expression:expr, $(|)? $( $pattern:pat_param )|+ $( if $guard: expr )? => $action:expr) => { @@ -43,41 +49,39 @@ macro_rules! must_match { } mod tests { + #[derive(thiserror::Error, Debug)] + #[error(transparent)] + struct ExpandError(#[from] anyhow::Error); + + #[allow(dead_code)] + enum MyEnum { + A(String), + B, + } + #[test] - fn test_try_match() -> crate::error::Result<()> { + fn test_try_match() -> Result<(), ExpandError> { assert_eq!( - try_match_expand!( - crate::error::ErrorCode::InternalError("failure".to_string()), - crate::error::ErrorCode::InternalError - )?, + try_match_expand!(MyEnum::A("failure".to_string()), MyEnum::A)?, "failure" ); assert_eq!( - try_match_expand!( - crate::error::ErrorCode::InternalError("failure".to_string()), - crate::error::ErrorCode::InternalError - )?, + try_match_expand!(MyEnum::A("failure".to_string()), MyEnum::A)?, "failure" ); assert_eq!( - try_match_expand!( - crate::error::ErrorCode::InternalError("failure".to_string()), - crate::error::ErrorCode::InternalError - )?, + try_match_expand!(MyEnum::A("failure".to_string()), MyEnum::A)?, "failure" ); // Test let statement is compilable. - let err_str = try_match_expand!( - crate::error::ErrorCode::InternalError("failure".to_string()), - crate::error::ErrorCode::InternalError - )?; + let err_str = try_match_expand!(MyEnum::A("failure".to_string()), MyEnum::A)?; assert_eq!(err_str, "failure"); Ok(()) } #[test] - fn test_must_match() -> crate::error::Result<()> { + fn test_must_match() -> Result<(), ExpandError> { #[allow(dead_code)] enum A { Foo, diff --git a/src/frontend/src/binder/bind_param.rs b/src/frontend/src/binder/bind_param.rs index 44c11f62393ae..7f35b107c9dca 100644 --- a/src/frontend/src/binder/bind_param.rs +++ b/src/frontend/src/binder/bind_param.rs @@ -14,8 +14,9 @@ use bytes::Bytes; use pgwire::types::{Format, FormatIterator}; -use risingwave_common::error::{ErrorCode, Result, RwError}; -use risingwave_common::types::{Datum, ScalarImpl}; +use risingwave_common::bail; +use risingwave_common::error::{ErrorCode, Result}; +use risingwave_common::types::{Datum, FromSqlError, ScalarImpl}; use super::statement::RewriteExprsRecursive; use super::BoundStatement; @@ -26,7 +27,7 @@ pub(crate) struct ParamRewriter { pub(crate) params: Vec>, pub(crate) parsed_params: Vec, pub(crate) param_formats: Vec, - pub(crate) error: Option, + pub(crate) error: Option, } impl ParamRewriter { @@ -107,7 +108,7 @@ impl BoundStatement { self.rewrite_exprs_recursive(&mut rewriter); if let Some(err) = rewriter.error { - return Err(err); + bail!(err); } Ok((self, rewriter.parsed_params)) diff --git a/src/frontend/src/binder/expr/value.rs b/src/frontend/src/binder/expr/value.rs index e5ae8bb4e9156..54559266a136f 100644 --- a/src/frontend/src/binder/expr/value.rs +++ b/src/frontend/src/binder/expr/value.rs @@ -72,7 +72,8 @@ impl Binder { leading_field: Option, ) -> Result { let interval = - Interval::parse_with_fields(&s, leading_field.map(Self::bind_date_time_field))?; + Interval::parse_with_fields(&s, leading_field.map(Self::bind_date_time_field)) + .map_err(|e| ErrorCode::BindError(e.to_string()))?; let datum = Some(ScalarImpl::Interval(interval)); let literal = Literal::new(datum, DataType::Interval); diff --git a/src/frontend/src/catalog/system_catalog/mod.rs b/src/frontend/src/catalog/system_catalog/mod.rs index 897171cfd38b4..c85cddd4deab9 100644 --- a/src/frontend/src/catalog/system_catalog/mod.rs +++ b/src/frontend/src/catalog/system_catalog/mod.rs @@ -26,7 +26,7 @@ use risingwave_common::catalog::{ ColumnCatalog, ColumnDesc, Field, SysCatalogReader, TableDesc, TableId, DEFAULT_SUPER_USER_ID, NON_RESERVED_SYS_CATALOG_ID, }; -use risingwave_common::error::Result; +use risingwave_common::error::BoxedError; use risingwave_common::row::OwnedRow; use risingwave_common::types::DataType; use risingwave_pb::user::grant_privilege::Object; @@ -314,14 +314,14 @@ macro_rules! prepare_sys_catalog { #[async_trait] impl SysCatalogReader for SysCatalogReaderImpl { - async fn read_table(&self, table_id: &TableId) -> Result> { + async fn read_table(&self, table_id: &TableId) -> Result, BoxedError> { let table_name = SYS_CATALOGS.table_name_by_id.get(table_id).unwrap(); $( if $builtin_catalog.name() == *table_name { $( let rows = self.$func(); $(let rows = rows.$await;)? - return rows; + return Ok(rows?); )? } )* diff --git a/src/frontend/src/handler/util.rs b/src/frontend/src/handler/util.rs index 1be30c2d470eb..1c8dbd9d0714e 100644 --- a/src/frontend/src/handler/util.rs +++ b/src/frontend/src/handler/util.rs @@ -17,6 +17,7 @@ use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; +use anyhow::Context as _; use bytes::Bytes; use futures::Stream; use itertools::Itertools; @@ -125,7 +126,9 @@ fn pg_value_format( Ok(d.text_format(data_type).into()) } } - Format::Binary => d.binary_format(data_type), + Format::Binary => Ok(d + .binary_format(data_type) + .context("failed to format binary value")?), } } diff --git a/src/utils/pgwire/src/pg_protocol.rs b/src/utils/pgwire/src/pg_protocol.rs index 0fba456b39207..f912860794a3f 100644 --- a/src/utils/pgwire/src/pg_protocol.rs +++ b/src/utils/pgwire/src/pg_protocol.rs @@ -27,7 +27,6 @@ use futures::future::Either; use futures::stream::StreamExt; use itertools::Itertools; use openssl::ssl::{SslAcceptor, SslContext, SslContextRef, SslMethod}; -use risingwave_common::error::RwError; use risingwave_common::types::DataType; use risingwave_common::util::panic::FutureCatchUnwindExt; use risingwave_sqlparser::ast::Statement; @@ -648,11 +647,12 @@ where if id == 0 { Ok(None) } else { - Ok(Some(DataType::from_oid(id)?)) + DataType::from_oid(id) + .map(Some) + .map_err(|e| PsqlError::ParseError(e.into())) } }) - .try_collect() - .map_err(|err: RwError| PsqlError::ParseError(err.into()))?; + .try_collect()?; let prepare_statement = session .parse(stmt, param_types)