From b1c0b3e81643b1d007dfd54177d46457fd742d94 Mon Sep 17 00:00:00 2001 From: Bugen Zhao Date: Wed, 22 Nov 2023 15:09:57 +0800 Subject: [PATCH 1/4] refactor some Signed-off-by: Bugen Zhao --- Cargo.lock | 5 ++ src/common/Cargo.toml | 1 + src/common/common_service/Cargo.toml | 2 + src/common/common_service/src/lib.rs | 1 + .../common_service/src/observer_manager.rs | 46 +++++++++++++------ src/common/heap_profiling/Cargo.toml | 2 + src/common/heap_profiling/src/jeprof.rs | 45 ++++++++++-------- src/common/heap_profiling/src/lib.rs | 2 + src/common/src/array/proto_reader.rs | 22 +++------ src/common/src/types/datetime.rs | 6 +-- src/common/src/types/decimal.rs | 3 +- src/common/src/types/interval.rs | 2 +- src/common/src/types/jsonb.rs | 2 +- src/common/src/types/mod.rs | 2 +- src/common/src/types/num256.rs | 2 +- src/common/src/types/serial.rs | 2 +- src/common/src/types/timestamptz.rs | 3 +- src/common/src/types/to_binary.rs | 33 +++++++++---- src/common/src/util/mod.rs | 43 ----------------- src/common/src/util/sort_util.rs | 27 +++++------ .../src/rpc/service/monitor_service.rs | 6 ++- .../src/bin/replay/replay_impl.rs | 8 ++-- .../src/mock_notification_client.rs | 12 +++-- 23 files changed, 139 insertions(+), 138 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 4065f1e83b0cf..c1299d6e14157 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -7485,6 +7485,7 @@ dependencies = [ "sysinfo", "tempfile", "thiserror", + "thiserror-ext", "tinyvec", "toml 0.8.2", "tower-layer", @@ -7507,6 +7508,8 @@ dependencies = [ "madsim-tokio", "parking_lot 0.12.1", "risingwave_common", + "thiserror", + "thiserror-ext", "tikv-jemalloc-ctl", "tracing", ] @@ -7535,6 +7538,8 @@ dependencies = [ "risingwave_common", "risingwave_pb", "risingwave_rpc_client", + "thiserror", + "thiserror-ext", "tower", "tower-http", "tracing", diff --git a/src/common/Cargo.toml b/src/common/Cargo.toml index 921c02ee6ae4e..cff2dd4123d94 100644 --- a/src/common/Cargo.toml +++ b/src/common/Cargo.toml @@ -89,6 +89,7 @@ strum = "0.25" strum_macros = "0.25" sysinfo = { version = "0.29", default-features = false } thiserror = "1" +thiserror-ext = { workspace = true } tinyvec = { version = "1", features = ["rustc_1_55", "grab_spare_slice"] } tokio = { version = "0.2", package = "madsim-tokio", features = [ "rt", diff --git a/src/common/common_service/Cargo.toml b/src/common/common_service/Cargo.toml index 1eaa14c46b8e9..ea228f9dcba88 100644 --- a/src/common/common_service/Cargo.toml +++ b/src/common/common_service/Cargo.toml @@ -22,6 +22,8 @@ prometheus = { version = "0.13" } risingwave_common = { workspace = true } risingwave_pb = { workspace = true } risingwave_rpc_client = { workspace = true } +thiserror = "1" +thiserror-ext = { workspace = true } tokio = { version = "0.2", package = "madsim-tokio", features = ["rt", "rt-multi-thread", "sync", "macros", "time", "signal"] } tonic = { workspace = true } tower = { version = "0.4", features = ["util", "load-shed"] } diff --git a/src/common/common_service/src/lib.rs b/src/common/common_service/src/lib.rs index 6ee220588d2fa..5456061035fc1 100644 --- a/src/common/common_service/src/lib.rs +++ b/src/common/common_service/src/lib.rs @@ -16,6 +16,7 @@ #![feature(lint_reasons)] #![feature(impl_trait_in_assoc_type)] +#![feature(error_generic_member_access)] pub mod metrics_manager; pub mod observer_manager; diff --git a/src/common/common_service/src/observer_manager.rs b/src/common/common_service/src/observer_manager.rs index d97618b3605c6..c6db757853941 100644 --- a/src/common/common_service/src/observer_manager.rs +++ b/src/common/common_service/src/observer_manager.rs @@ -14,8 +14,6 @@ use std::time::Duration; -use risingwave_common::bail; -use risingwave_common::error::Result; use risingwave_pb::meta::subscribe_response::Info; use risingwave_pb::meta::{SubscribeResponse, SubscribeType}; use risingwave_rpc_client::error::RpcError; @@ -80,6 +78,25 @@ impl ObserverManager { } } +#[derive(thiserror::Error, Debug)] +pub enum ObserverError { + #[error("notification channel closed")] + ChannelClosed, + + #[error(transparent)] + Rpc( + #[from] + #[backtrace] + RpcError, + ), +} + +impl From for ObserverError { + fn from(value: tonic::Status) -> Self { + Self::Rpc(value.into()) + } +} + impl ObserverManager where T: NotificationClient, @@ -97,24 +114,19 @@ where } } - async fn wait_init_notification(&mut self) -> Result<()> { + async fn wait_init_notification(&mut self) -> Result<(), ObserverError> { let mut notification_vec = Vec::new(); let init_notification = loop { // notification before init notification must be received successfully. - match self.rx.message().await { - Ok(Some(notification)) => { + match self.rx.message().await? { + Some(notification) => { if !matches!(notification.info.as_ref().unwrap(), &Info::Snapshot(_)) { notification_vec.push(notification); } else { break notification; } } - Ok(None) => { - bail!("notification channel from meta is closed"); - } - Err(err) => { - bail!("receives meta's notification err: {:?}", err); - } + None => return Err(ObserverError::ChannelClosed), } }; @@ -231,7 +243,10 @@ impl Channel for Streaming { #[async_trait::async_trait] pub trait NotificationClient: Send + Sync + 'static { type Channel: Channel; - async fn subscribe(&self, subscribe_type: SubscribeType) -> Result; + async fn subscribe( + &self, + subscribe_type: SubscribeType, + ) -> Result; } pub struct RpcNotificationClient { @@ -248,10 +263,13 @@ impl RpcNotificationClient { impl NotificationClient for RpcNotificationClient { type Channel = Streaming; - async fn subscribe(&self, subscribe_type: SubscribeType) -> Result { + async fn subscribe( + &self, + subscribe_type: SubscribeType, + ) -> Result { self.meta_client .subscribe(subscribe_type) .await - .map_err(RpcError::into) + .map_err(Into::into) } } diff --git a/src/common/heap_profiling/Cargo.toml b/src/common/heap_profiling/Cargo.toml index c7123eaac5817..0c9b6a1794695 100644 --- a/src/common/heap_profiling/Cargo.toml +++ b/src/common/heap_profiling/Cargo.toml @@ -22,6 +22,8 @@ chrono = { version = "0.4", default-features = false, features = [ ] } parking_lot = "0.12" risingwave_common = { workspace = true } +thiserror = "1" +thiserror-ext = { workspace = true } tikv-jemalloc-ctl = { workspace = true } tokio = { version = "0.2", package = "madsim-tokio" } tracing = "0.1" diff --git a/src/common/heap_profiling/src/jeprof.rs b/src/common/heap_profiling/src/jeprof.rs index 013632f32838e..0a137148c2a16 100644 --- a/src/common/heap_profiling/src/jeprof.rs +++ b/src/common/heap_profiling/src/jeprof.rs @@ -14,12 +14,24 @@ use std::path::Path; use std::process::Command; +use std::result::Result; use std::{env, fs}; -use anyhow::anyhow; -use risingwave_common::error::Result; +#[derive(thiserror::Error, Debug, thiserror_ext::ContextInto)] +pub enum JeprofError { + #[error(transparent)] + IoError(#[from] std::io::Error), -pub async fn run(profile_path: String, collapsed_path: String) -> Result<()> { + #[error("jeprof exit with an error (stdout: {stdout}, stderr: {stderr}): {inner}")] + ExitError { + #[source] + inner: std::process::ExitStatusError, + stdout: String, + stderr: String, + }, +} + +pub async fn run(profile_path: String, collapsed_path: String) -> Result<(), JeprofError> { let executable_path = env::current_exe()?; let prof_cmd = move || { @@ -29,20 +41,15 @@ pub async fn run(profile_path: String, collapsed_path: String) -> Result<()> { .arg(Path::new(&profile_path)) .output() }; - match tokio::task::spawn_blocking(prof_cmd).await.unwrap() { - Ok(output) => { - if output.status.success() { - fs::write(Path::new(&collapsed_path), &output.stdout)?; - Ok(()) - } else { - Err(anyhow!( - "jeprof exit with an error. stdout: {}, stderr: {}", - String::from_utf8_lossy(&output.stdout), - String::from_utf8_lossy(&output.stderr) - ) - .into()) - } - } - Err(e) => Err(e.into()), - } + + let output = tokio::task::spawn_blocking(prof_cmd).await.unwrap()?; + + output.status.exit_ok().into_exit_error( + String::from_utf8_lossy(&output.stdout), + String::from_utf8_lossy(&output.stderr), + )?; + + fs::write(Path::new(&collapsed_path), &output.stdout)?; + + Ok(()) } diff --git a/src/common/heap_profiling/src/lib.rs b/src/common/heap_profiling/src/lib.rs index f6ffb66d836d7..0a5dbabc49fe7 100644 --- a/src/common/heap_profiling/src/lib.rs +++ b/src/common/heap_profiling/src/lib.rs @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +#![feature(exit_status_error)] + pub const MANUALLY_DUMP_SUFFIX: &str = "manual.heap"; pub const AUTO_DUMP_SUFFIX: &str = "auto.heap"; pub const COLLAPSED_SUFFIX: &str = "collapsed"; diff --git a/src/common/src/array/proto_reader.rs b/src/common/src/array/proto_reader.rs index 4ca6bf7b70d05..ae14300955b97 100644 --- a/src/common/src/array/proto_reader.rs +++ b/src/common/src/array/proto_reader.rs @@ -256,13 +256,12 @@ mod tests { DecimalArray, DecimalArrayBuilder, I32Array, I32ArrayBuilder, TimeArray, TimeArrayBuilder, TimestampArray, TimestampArrayBuilder, Utf8Array, Utf8ArrayBuilder, }; - use crate::error::Result; use crate::types::{Date, Decimal, Time, Timestamp}; // Convert a column to protobuf, then convert it back to column, and ensures the two are // identical. #[test] - fn test_column_protobuf_conversion() -> Result<()> { + fn test_column_protobuf_conversion() { let cardinality = 2048; let mut builder = I32ArrayBuilder::new(cardinality); for i in 0..cardinality { @@ -283,11 +282,10 @@ mod tests { assert!(x.is_none()); } }); - Ok(()) } #[test] - fn test_bool_column_protobuf_conversion() -> Result<()> { + fn test_bool_column_protobuf_conversion() { let cardinality = 2048; let mut builder = BoolArrayBuilder::new(cardinality); for i in 0..cardinality { @@ -306,11 +304,10 @@ mod tests { 1 => assert_eq!(Some(true), x), _ => assert_eq!(None, x), }); - Ok(()) } #[test] - fn test_utf8_column_conversion() -> Result<()> { + fn test_utf8_column_conversion() { let cardinality = 2048; let mut builder = Utf8ArrayBuilder::new(cardinality); for i in 0..cardinality { @@ -330,11 +327,10 @@ mod tests { assert!(x.is_none()); } }); - Ok(()) } #[test] - fn test_decimal_protobuf_conversion() -> Result<()> { + fn test_decimal_protobuf_conversion() { let cardinality = 2048; let mut builder = DecimalArrayBuilder::new(cardinality); for i in 0..cardinality { @@ -355,11 +351,10 @@ mod tests { assert!(x.is_none()); } }); - Ok(()) } #[test] - fn test_date_protobuf_conversion() -> Result<()> { + fn test_date_protobuf_conversion() { let cardinality = 2048; let mut builder = DateArrayBuilder::new(cardinality); for i in 0..cardinality { @@ -380,11 +375,10 @@ mod tests { assert!(x.is_none()); } }); - Ok(()) } #[test] - fn test_time_protobuf_conversion() -> Result<()> { + fn test_time_protobuf_conversion() { let cardinality = 2048; let mut builder = TimeArrayBuilder::new(cardinality); for i in 0..cardinality { @@ -410,11 +404,10 @@ mod tests { assert!(x.is_none()); } }); - Ok(()) } #[test] - fn test_timestamp_protobuf_conversion() -> Result<()> { + fn test_timestamp_protobuf_conversion() { let cardinality = 2048; let mut builder = TimestampArrayBuilder::new(cardinality); for i in 0..cardinality { @@ -440,6 +433,5 @@ mod tests { assert!(x.is_none()); } }); - Ok(()) } } diff --git a/src/common/src/types/datetime.rs b/src/common/src/types/datetime.rs index d1e77190a376f..fb2b6a14f63de 100644 --- a/src/common/src/types/datetime.rs +++ b/src/common/src/types/datetime.rs @@ -318,7 +318,7 @@ impl ToText for Timestamp { } impl ToBinary for Date { - fn to_binary_with_type(&self, ty: &DataType) -> crate::error::Result> { + fn to_binary_with_type(&self, ty: &DataType) -> super::to_binary::Result> { match ty { super::DataType::Date => { let mut output = BytesMut::new(); @@ -331,7 +331,7 @@ impl ToBinary for Date { } impl ToBinary for Time { - fn to_binary_with_type(&self, ty: &DataType) -> crate::error::Result> { + fn to_binary_with_type(&self, ty: &DataType) -> super::to_binary::Result> { match ty { super::DataType::Time => { let mut output = BytesMut::new(); @@ -344,7 +344,7 @@ impl ToBinary for Time { } impl ToBinary for Timestamp { - fn to_binary_with_type(&self, ty: &DataType) -> crate::error::Result> { + fn to_binary_with_type(&self, ty: &DataType) -> super::to_binary::Result> { match ty { super::DataType::Timestamp => { let mut output = BytesMut::new(); diff --git a/src/common/src/types/decimal.rs b/src/common/src/types/decimal.rs index 3b63da076748d..861ccce33a575 100644 --- a/src/common/src/types/decimal.rs +++ b/src/common/src/types/decimal.rs @@ -28,7 +28,6 @@ use super::to_binary::ToBinary; use super::to_text::ToText; use super::DataType; use crate::array::ArrayResult; -use crate::error::Result as RwResult; use crate::estimate_size::EstimateSize; use crate::types::ordered_float::OrderedFloat; use crate::types::Decimal::Normalized; @@ -82,7 +81,7 @@ impl Decimal { } impl ToBinary for Decimal { - fn to_binary_with_type(&self, ty: &DataType) -> RwResult> { + fn to_binary_with_type(&self, ty: &DataType) -> super::to_binary::Result> { match ty { DataType::Decimal => { let mut output = BytesMut::new(); diff --git a/src/common/src/types/interval.rs b/src/common/src/types/interval.rs index b7347ed6c8f99..ca29b9a28abd3 100644 --- a/src/common/src/types/interval.rs +++ b/src/common/src/types/interval.rs @@ -1161,7 +1161,7 @@ impl<'a> FromSql<'a> for Interval { } impl ToBinary for Interval { - fn to_binary_with_type(&self, ty: &DataType) -> Result> { + fn to_binary_with_type(&self, ty: &DataType) -> super::to_binary::Result> { match ty { DataType::Interval => { let mut output = BytesMut::new(); diff --git a/src/common/src/types/jsonb.rs b/src/common/src/types/jsonb.rs index 4a625faec1cb8..6e27ff5344198 100644 --- a/src/common/src/types/jsonb.rs +++ b/src/common/src/types/jsonb.rs @@ -128,7 +128,7 @@ impl crate::types::to_binary::ToBinary for JsonbRef<'_> { fn to_binary_with_type( &self, _ty: &crate::types::DataType, - ) -> crate::error::Result> { + ) -> super::to_binary::Result> { Ok(Some(self.value_serialize().into())) } } diff --git a/src/common/src/types/mod.rs b/src/common/src/types/mod.rs index 98ea7ee061bd6..70438b493d896 100644 --- a/src/common/src/types/mod.rs +++ b/src/common/src/types/mod.rs @@ -1047,7 +1047,7 @@ pub fn hash_datum(datum: impl ToDatumRef, state: &mut impl std::hash::Hasher) { impl ScalarRefImpl<'_> { /// Encode the scalar to postgresql binary format. /// The encoder implements encoding using - pub fn binary_format(&self, data_type: &DataType) -> RwResult { + pub fn binary_format(&self, data_type: &DataType) -> to_binary::Result { self.to_binary_with_type(data_type).transpose().unwrap() } diff --git a/src/common/src/types/num256.rs b/src/common/src/types/num256.rs index 42df630509d08..26c0edaef59e4 100644 --- a/src/common/src/types/num256.rs +++ b/src/common/src/types/num256.rs @@ -165,7 +165,7 @@ macro_rules! impl_common_for_num256 { } impl ToBinary for $scalar_ref<'_> { - fn to_binary_with_type(&self, _ty: &DataType) -> crate::error::Result> { + fn to_binary_with_type(&self, _ty: &DataType) -> super::to_binary::Result> { let mut output = bytes::BytesMut::new(); let buffer = self.to_be_bytes(); output.put_slice(&buffer); diff --git a/src/common/src/types/serial.rs b/src/common/src/types/serial.rs index 65cb7ea743bc6..6601676c93029 100644 --- a/src/common/src/types/serial.rs +++ b/src/common/src/types/serial.rs @@ -75,7 +75,7 @@ impl crate::types::to_binary::ToBinary for Serial { fn to_binary_with_type( &self, _ty: &crate::types::DataType, - ) -> crate::error::Result> { + ) -> super::to_binary::Result> { let mut output = bytes::BytesMut::new(); self.0.to_sql(&Type::ANY, &mut output).unwrap(); Ok(Some(output.freeze())) diff --git a/src/common/src/types/timestamptz.rs b/src/common/src/types/timestamptz.rs index 1f9b962c9d376..6be713b446c2a 100644 --- a/src/common/src/types/timestamptz.rs +++ b/src/common/src/types/timestamptz.rs @@ -25,7 +25,6 @@ use super::to_binary::ToBinary; use super::to_text::ToText; use super::DataType; use crate::array::ArrayResult; -use crate::error::Result; use crate::estimate_size::ZeroHeapSize; /// Timestamp with timezone. @@ -38,7 +37,7 @@ pub struct Timestamptz(i64); impl ZeroHeapSize for Timestamptz {} impl ToBinary for Timestamptz { - fn to_binary_with_type(&self, _ty: &DataType) -> Result> { + fn to_binary_with_type(&self, _ty: &DataType) -> super::to_binary::Result> { let instant = self.to_datetime_utc(); let mut out = BytesMut::new(); // postgres_types::Type::ANY is only used as a placeholder. diff --git a/src/common/src/types/to_binary.rs b/src/common/src/types/to_binary.rs index 540167cc4f02c..981678b854c9d 100644 --- a/src/common/src/types/to_binary.rs +++ b/src/common/src/types/to_binary.rs @@ -16,7 +16,19 @@ use bytes::{Bytes, BytesMut}; use postgres_types::{ToSql, Type}; use super::{DataType, DatumRef, ScalarRefImpl, F32, F64}; -use crate::error::{ErrorCode, Result}; +use crate::error::TrackingIssue; + +#[derive(thiserror::Error, Debug)] +pub enum ToBinaryError { + #[error(transparent)] + ToSql(Box), + + #[error("Feature is not yet implemented: {0}\n{1}")] + NotImplemented(String, TrackingIssue), +} + +pub type Result = std::result::Result; + // Used to convert ScalarRef to text format pub trait ToBinary { fn to_binary_with_type(&self, ty: &DataType) -> Result>; @@ -32,7 +44,7 @@ macro_rules! implement_using_to_sql { DataType::$data_type => { let mut output = BytesMut::new(); #[allow(clippy::redundant_closure_call)] - $accessor(self).to_sql(&Type::ANY, &mut output).unwrap(); + $accessor(self).to_sql(&Type::ANY, &mut output).map_err(ToBinaryError::ToSql)?; Ok(Some(output.freeze())) }, _ => unreachable!(), @@ -74,14 +86,15 @@ impl ToBinary for ScalarRefImpl<'_> { ScalarRefImpl::Time(v) => v.to_binary_with_type(ty), ScalarRefImpl::Bytea(v) => v.to_binary_with_type(ty), ScalarRefImpl::Jsonb(v) => v.to_binary_with_type(ty), - ScalarRefImpl::Struct(_) | ScalarRefImpl::List(_) => Err(ErrorCode::NotImplemented( - format!( - "the pgwire extended-mode encoding for {} is unsupported", - ty - ), - Some(7949).into(), - ) - .into()), + ScalarRefImpl::Struct(_) | ScalarRefImpl::List(_) => { + Err(ToBinaryError::NotImplemented( + format!( + "the pgwire extended-mode encoding for {} is unsupported", + ty + ), + Some(7949).into(), + )) + } } } } diff --git a/src/common/src/util/mod.rs b/src/common/src/util/mod.rs index e1f85263e1415..2f2b5cb1f0cd2 100644 --- a/src/common/src/util/mod.rs +++ b/src/common/src/util/mod.rs @@ -12,12 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::any::{type_name, Any}; -use std::sync::Arc; - pub use self::prost::*; -use crate::error::ErrorCode::InternalError; -use crate::error::{Result, RwError}; pub mod addr; pub mod chunk_coalesce; @@ -50,41 +45,3 @@ pub use future_utils::{ }; #[macro_use] pub mod match_util; - -pub fn downcast_ref(source: &S) -> Result<&T> -where - S: AsRef + ?Sized, - T: 'static, -{ - source.as_ref().downcast_ref::().ok_or_else(|| { - RwError::from(InternalError(format!( - "Failed to cast to {}", - type_name::() - ))) - }) -} - -pub fn downcast_arc(source: Arc) -> Result> -where - T: 'static + Send + Sync, -{ - source.downcast::().map_err(|_| { - RwError::from(InternalError(format!( - "Failed to cast to {}", - type_name::() - ))) - }) -} - -pub fn downcast_mut(source: &mut S) -> Result<&mut T> -where - S: AsMut + ?Sized, - T: 'static, -{ - source.as_mut().downcast_mut::().ok_or_else(|| { - RwError::from(InternalError(format!( - "Failed to cast to {}", - type_name::() - ))) - }) -} diff --git a/src/common/src/util/sort_util.rs b/src/common/src/util/sort_util.rs index 3129cbcac5172..091dbe3cfaa14 100644 --- a/src/common/src/util/sort_util.rs +++ b/src/common/src/util/sort_util.rs @@ -23,8 +23,6 @@ use super::iter_util::ZipEqDebug; use crate::array::{Array, DataChunk}; use crate::catalog::{FieldDisplay, Schema}; use crate::dispatch_array_variants; -use crate::error::ErrorCode::InternalError; -use crate::error::Result; use crate::estimate_size::EstimateSize; use crate::row::Row; use crate::types::{DefaultOrdered, ToDatumRef}; @@ -372,7 +370,6 @@ impl Ord for HeapElem { other.elem_idx, self.column_orders.as_ref(), ) - .unwrap() }; ord.reverse() } @@ -439,25 +436,25 @@ where .expect("items in the same `Array` type should be able to compare") } -pub fn compare_rows_in_chunk( +fn compare_rows_in_chunk( lhs_data_chunk: &DataChunk, lhs_idx: usize, rhs_data_chunk: &DataChunk, rhs_idx: usize, column_orders: &[ColumnOrder], -) -> Result { +) -> Ordering { for column_order in column_orders { let lhs_array = lhs_data_chunk.column_at(column_order.column_index); let rhs_array = rhs_data_chunk.column_at(column_order.column_index); let res = dispatch_array_variants!(&**lhs_array, lhs_inner, { - let rhs_inner = (&**rhs_array).try_into().map_err(|_| { - InternalError(format!( + let rhs_inner = (&**rhs_array).try_into().unwrap_or_else(|_| { + panic!( "Unmatched array types, lhs array is: {}, rhs array is: {}", lhs_array.get_ident(), rhs_array.get_ident(), - )) - })?; + ) + }); compare_values_in_array( lhs_inner, lhs_idx, @@ -468,10 +465,10 @@ pub fn compare_rows_in_chunk( }); if res != Ordering::Equal { - return Ok(res); + return res; } } - Ok(Ordering::Equal) + Ordering::Equal } /// Partial compare two `Datum`s with specified order type. @@ -651,11 +648,11 @@ mod tests { assert_eq!( Ordering::Equal, - compare_rows_in_chunk(&chunk, 0, &chunk, 0, &column_orders).unwrap() + compare_rows_in_chunk(&chunk, 0, &chunk, 0, &column_orders) ); assert_eq!( Ordering::Less, - compare_rows_in_chunk(&chunk, 0, &chunk, 1, &column_orders).unwrap() + compare_rows_in_chunk(&chunk, 0, &chunk, 1, &column_orders) ); } @@ -731,11 +728,11 @@ mod tests { ); assert_eq!( Ordering::Equal, - compare_rows_in_chunk(&chunk, 0, &chunk, 0, &column_orders).unwrap() + compare_rows_in_chunk(&chunk, 0, &chunk, 0, &column_orders) ); assert_eq!( Ordering::Less, - compare_rows_in_chunk(&chunk, 0, &chunk, 1, &column_orders).unwrap() + compare_rows_in_chunk(&chunk, 0, &chunk, 1, &column_orders) ); } diff --git a/src/compute/src/rpc/service/monitor_service.rs b/src/compute/src/rpc/service/monitor_service.rs index 8fc24664ec016..fa2d9ed064c40 100644 --- a/src/compute/src/rpc/service/monitor_service.rs +++ b/src/compute/src/rpc/service/monitor_service.rs @@ -27,8 +27,9 @@ use risingwave_pb::monitor_service::{ ListHeapProfilingRequest, ListHeapProfilingResponse, ProfilingRequest, ProfilingResponse, StackTraceRequest, StackTraceResponse, }; +use risingwave_rpc_client::error::ToTonicStatus; use risingwave_stream::task::LocalStreamManager; -use tonic::{Request, Response, Status}; +use tonic::{Code, Request, Response, Status}; #[derive(Clone)] pub struct MonitorServiceImpl { @@ -221,7 +222,8 @@ impl MonitorService for MonitorServiceImpl { dumped_path_str, collapsed_path_str.clone(), ) - .await?; + .await + .map_err(|e| e.to_status(Code::Internal, "monitor"))?; } let file = fs::read(Path::new(&collapsed_path_str))?; diff --git a/src/storage/hummock_test/src/bin/replay/replay_impl.rs b/src/storage/hummock_test/src/bin/replay/replay_impl.rs index 67d299ec34cc2..75e48cd5becdd 100644 --- a/src/storage/hummock_test/src/bin/replay/replay_impl.rs +++ b/src/storage/hummock_test/src/bin/replay/replay_impl.rs @@ -17,9 +17,8 @@ use std::ops::Bound; use futures::stream::BoxStream; use futures::{Stream, StreamExt}; use futures_async_stream::{for_await, try_stream}; -use risingwave_common::error::Result as RwResult; use risingwave_common::util::addr::HostAddr; -use risingwave_common_service::observer_manager::{Channel, NotificationClient}; +use risingwave_common_service::observer_manager::{Channel, NotificationClient, ObserverError}; use risingwave_hummock_sdk::key::TableKey; use risingwave_hummock_sdk::HummockReadEpoch; use risingwave_hummock_trace::{ @@ -320,7 +319,10 @@ impl ReplayNotificationClient { impl NotificationClient for ReplayNotificationClient { type Channel = ReplayChannel; - async fn subscribe(&self, subscribe_type: SubscribeType) -> RwResult { + async fn subscribe( + &self, + subscribe_type: SubscribeType, + ) -> std::result::Result { let (tx, rx) = unbounded_channel(); self.notification_manager diff --git a/src/storage/hummock_test/src/mock_notification_client.rs b/src/storage/hummock_test/src/mock_notification_client.rs index b88f0e467c9b1..991a5a9d5bf84 100644 --- a/src/storage/hummock_test/src/mock_notification_client.rs +++ b/src/storage/hummock_test/src/mock_notification_client.rs @@ -15,9 +15,8 @@ use std::collections::HashMap; use std::sync::Arc; -use risingwave_common::error::Result; use risingwave_common::util::addr::HostAddr; -use risingwave_common_service::observer_manager::{Channel, NotificationClient}; +use risingwave_common_service::observer_manager::{Channel, NotificationClient, ObserverError}; use risingwave_meta::hummock::{HummockManager, HummockManagerRef}; use risingwave_meta::manager::{MessageStatus, MetaSrvEnv, NotificationManagerRef, WorkerKey}; use risingwave_pb::backup_service::MetaBackupManifestId; @@ -50,7 +49,10 @@ impl MockNotificationClient { impl NotificationClient for MockNotificationClient { type Channel = TestChannel; - async fn subscribe(&self, subscribe_type: SubscribeType) -> Result { + async fn subscribe( + &self, + subscribe_type: SubscribeType, + ) -> Result { let (tx, rx) = tokio::sync::mpsc::unbounded_channel(); let worker_key = WorkerKey(self.addr.to_protobuf()); @@ -88,13 +90,13 @@ pub fn get_notification_client_for_test( ) } -pub struct TestChannel(UnboundedReceiver>); +pub struct TestChannel(UnboundedReceiver>); #[async_trait::async_trait] impl Channel for TestChannel { type Item = T; - async fn message(&mut self) -> std::result::Result, MessageStatus> { + async fn message(&mut self) -> Result, MessageStatus> { match self.0.recv().await { None => Ok(None), Some(result) => result.map(|r| Some(r)), From 05ab5e1cabbdde10b1e95d6461b1b0dbcfd4fa74 Mon Sep 17 00:00:00 2001 From: Bugen Zhao Date: Wed, 22 Nov 2023 16:07:11 +0800 Subject: [PATCH 2/4] refactor more Signed-off-by: Bugen Zhao --- src/batch/src/error.rs | 7 + src/batch/src/executor/sys_row_seq_scan.rs | 6 +- src/common/src/catalog/column.rs | 19 -- src/common/src/catalog/mod.rs | 4 +- src/common/src/types/interval.rs | 92 ++++----- src/common/src/types/mod.rs | 195 ++++++++---------- src/common/src/types/num256.rs | 5 +- src/common/src/types/postgres_type.rs | 9 +- src/common/src/util/match_util.rs | 40 ++-- src/frontend/src/binder/bind_param.rs | 9 +- src/frontend/src/binder/expr/value.rs | 3 +- .../src/catalog/system_catalog/mod.rs | 6 +- src/frontend/src/handler/util.rs | 5 +- src/utils/pgwire/src/pg_protocol.rs | 8 +- 14 files changed, 195 insertions(+), 213 deletions(-) diff --git a/src/batch/src/error.rs b/src/batch/src/error.rs index 4336b86055d8b..6894bea33b515 100644 --- a/src/batch/src/error.rs +++ b/src/batch/src/error.rs @@ -94,6 +94,13 @@ pub enum BatchError { BoxedError, ), + #[error("Failed to read from system table: {0}")] + SystemTable( + #[from] + #[backtrace] + BoxedError, + ), + // Make the ref-counted type to be a variant for easier code structuring. #[error(transparent)] Shared( diff --git a/src/batch/src/executor/sys_row_seq_scan.rs b/src/batch/src/executor/sys_row_seq_scan.rs index d0103d9883869..d28b0b95c5a38 100644 --- a/src/batch/src/executor/sys_row_seq_scan.rs +++ b/src/batch/src/executor/sys_row_seq_scan.rs @@ -107,7 +107,11 @@ impl Executor for SysRowSeqScanExecutor { impl SysRowSeqScanExecutor { #[try_stream(boxed, ok = DataChunk, error = BatchError)] async fn do_executor(self: Box) { - let rows = self.sys_catalog_reader.read_table(&self.table_id).await?; + let rows = self + .sys_catalog_reader + .read_table(&self.table_id) + .await + .map_err(BatchError::SystemTable)?; let filtered_rows = rows .iter() .map(|row| { diff --git a/src/common/src/catalog/column.rs b/src/common/src/catalog/column.rs index 7a984724bf116..68c1618073169 100644 --- a/src/common/src/catalog/column.rs +++ b/src/common/src/catalog/column.rs @@ -21,7 +21,6 @@ use risingwave_pb::plan_common::{PbColumnCatalog, PbColumnDesc}; use super::row_id_column_desc; use crate::catalog::{cdc_table_name_column_desc, offset_column_desc, Field, ROW_ID_COLUMN_ID}; -use crate::error::ErrorCode; use crate::types::DataType; /// Column ID is the unique identifier of a column in a table. Different from table ID, column ID is @@ -161,24 +160,6 @@ impl ColumnDesc { descs } - /// Find `column_desc` in `field_descs` by name. - pub fn field(&self, name: &String) -> crate::error::Result<(ColumnDesc, i32)> { - if let DataType::Struct { .. } = self.data_type { - for (index, col) in self.field_descs.iter().enumerate() { - if col.name == *name { - return Ok((col.clone(), index as i32)); - } - } - Err(ErrorCode::ItemNotFound(format!("Invalid field name: {}", name)).into()) - } else { - Err(ErrorCode::ItemNotFound(format!( - "Cannot get field from non nested column: {}", - self.name - )) - .into()) - } - } - pub fn new_atomic(data_type: DataType, name: &str, column_id: i32) -> Self { Self { data_type, diff --git a/src/common/src/catalog/mod.rs b/src/common/src/catalog/mod.rs index 204a5005cd2de..a8a698128d9b6 100644 --- a/src/common/src/catalog/mod.rs +++ b/src/common/src/catalog/mod.rs @@ -32,7 +32,7 @@ use risingwave_pb::catalog::HandleConflictBehavior as PbHandleConflictBehavior; pub use schema::{test_utils as schema_test_utils, Field, FieldDisplay, Schema}; pub use crate::constants::hummock; -use crate::error::Result; +use crate::error::BoxedError; use crate::row::OwnedRow; use crate::types::DataType; @@ -134,7 +134,7 @@ pub fn cdc_table_name_column_desc() -> ColumnDesc { /// The local system catalog reader in the frontend node. #[async_trait] pub trait SysCatalogReader: Sync + Send + 'static { - async fn read_table(&self, table_id: &TableId) -> Result>; + async fn read_table(&self, table_id: &TableId) -> Result, BoxedError>; } pub type SysCatalogReaderRef = Arc; diff --git a/src/common/src/types/interval.rs b/src/common/src/types/interval.rs index ca29b9a28abd3..a95bc412124b9 100644 --- a/src/common/src/types/interval.rs +++ b/src/common/src/types/interval.rs @@ -31,7 +31,6 @@ use rust_decimal::prelude::Decimal; use super::to_binary::ToBinary; use super::*; -use crate::error::{ErrorCode, Result, RwError}; use crate::estimate_size::EstimateSize; /// Every interval can be represented by a `Interval`. @@ -1001,6 +1000,23 @@ impl ToText for crate::types::Interval { } } +#[derive(thiserror::Error, Debug, thiserror_ext::Construct)] +pub enum IntervalParseError { + #[error("Invalid interval: {0}")] + Invalid(String), + + #[error("Invalid interval: {0}, expected format PYMDTHMS")] + InvalidIso8601(String), + + #[error("Invalid unit: {0}")] + InvalidUnit(String), + + #[error("{0}")] + Uncategorized(String), +} + +type ParseResult = std::result::Result; + impl Interval { pub fn as_iso_8601(&self) -> String { // ISO pattern - PnYnMnDTnHnMnS @@ -1029,7 +1045,7 @@ impl Interval { /// /// Example /// - P1Y2M3DT4H5M6.78S - pub fn from_iso_8601(s: &str) -> Result { + pub fn from_iso_8601(s: &str) -> ParseResult { // ISO pattern - PnYnMnDTnHnMnS static ISO_8601_REGEX: LazyLock = LazyLock::new(|| { Regex::new(r"P([0-9]+)Y([0-9]+)M([0-9]+)DT([0-9]+)H([0-9]+)M([0-9]+(?:\.[0-9]+)?)S") @@ -1061,7 +1077,7 @@ impl Interval { .checked_add(usecs)?, )) }; - f().ok_or_else(|| ErrorCode::InvalidInputSyntax(format!("Invalid interval: {}, expected format PYMDTHMS", s)).into()) + f().ok_or_else(|| IntervalParseError::invalid_iso8601(s)) } } @@ -1184,9 +1200,9 @@ pub enum DateTimeField { } impl FromStr for DateTimeField { - type Err = RwError; + type Err = IntervalParseError; - fn from_str(s: &str) -> Result { + fn from_str(s: &str) -> ParseResult { match s.to_lowercase().as_str() { "years" | "year" | "yrs" | "yr" | "y" => Ok(Self::Year), "days" | "day" | "d" => Ok(Self::Day), @@ -1194,7 +1210,7 @@ impl FromStr for DateTimeField { "minutes" | "minute" | "mins" | "min" | "m" => Ok(Self::Minute), "months" | "month" | "mons" | "mon" => Ok(Self::Month), "seconds" | "second" | "secs" | "sec" | "s" => Ok(Self::Second), - _ => Err(ErrorCode::InvalidInputSyntax(format!("unknown unit {}", s)).into()), + _ => Err(IntervalParseError::invalid_unit(s)), } } } @@ -1206,7 +1222,7 @@ enum TimeStrToken { TimeUnit(DateTimeField), } -fn parse_interval(s: &str) -> Result> { +fn parse_interval(s: &str) -> ParseResult> { let s = s.trim(); let mut tokens = Vec::new(); let mut num_buf = "".to_string(); @@ -1235,21 +1251,16 @@ fn parse_interval(s: &str) -> Result> { ':' => { // there must be a digit before the ':' if num_buf.is_empty() { - return Err(ErrorCode::InvalidInputSyntax(format!( - "invalid interval format: {}", - s - )) - .into()); + return Err(IntervalParseError::invalid(s)); } hour_min_sec.push(num_buf.clone()); num_buf.clear(); } _ => { - return Err(ErrorCode::InvalidInputSyntax(format!( + return Err(IntervalParseError::uncategorized(format!( "Invalid character at offset {} in {}: {:?}. Only support digit or alphabetic now", i,s, c - )) - .into()); + ))); } }; } @@ -1262,23 +1273,20 @@ fn parse_interval(s: &str) -> Result> { convert_digit(&mut num_buf, &mut tokens)?; } convert_unit(&mut char_buf, &mut tokens)?; - convert_hms(&hour_min_sec, &mut tokens).ok_or_else(|| { - ErrorCode::InvalidInputSyntax(format!("Invalid interval: {:?}", hour_min_sec)) - })?; + convert_hms(&hour_min_sec, &mut tokens) + .ok_or_else(|| IntervalParseError::invalid(format!("{hour_min_sec:?}")))?; Ok(tokens) } -fn convert_digit(c: &mut String, t: &mut Vec) -> Result<()> { +fn convert_digit(c: &mut String, t: &mut Vec) -> ParseResult<()> { if !c.is_empty() { match c.parse::() { Ok(num) => { t.push(TimeStrToken::Num(num)); } Err(_) => { - return Err( - ErrorCode::InvalidInputSyntax(format!("Invalid interval: {}", c)).into(), - ); + return Err(IntervalParseError::invalid(c.clone())); } } c.clear(); @@ -1286,7 +1294,7 @@ fn convert_digit(c: &mut String, t: &mut Vec) -> Result<()> { Ok(()) } -fn convert_unit(c: &mut String, t: &mut Vec) -> Result<()> { +fn convert_unit(c: &mut String, t: &mut Vec) -> ParseResult<()> { if !c.is_empty() { t.push(TimeStrToken::TimeUnit(c.parse()?)); c.clear(); @@ -1338,25 +1346,17 @@ fn convert_hms(c: &Vec, t: &mut Vec) -> Option<()> { } impl Interval { - fn parse_sql_standard(s: &str, leading_field: DateTimeField) -> Result { + fn parse_sql_standard(s: &str, leading_field: DateTimeField) -> ParseResult { use DateTimeField::*; let tokens = parse_interval(s)?; // Todo: support more syntax if tokens.len() > 1 { - return Err(ErrorCode::InvalidInputSyntax(format!( - "(standard sql format) Can't support syntax of interval {}.", - &s - )) - .into()); + return Err(IntervalParseError::invalid(s)); } let num = match tokens.first() { Some(TimeStrToken::Num(num)) => *num, _ => { - return Err(ErrorCode::InvalidInputSyntax(format!( - "(standard sql format)Invalid interval {}.", - &s - )) - .into()); + return Err(IntervalParseError::invalid(s)); } }; @@ -1380,10 +1380,10 @@ impl Interval { Some(Interval::from_month_day_usec(0, 0, usecs)) } })() - .ok_or_else(|| ErrorCode::InvalidInputSyntax(format!("Invalid interval {}.", s)).into()) + .ok_or_else(|| IntervalParseError::invalid(s)) } - fn parse_postgres(s: &str) -> Result { + fn parse_postgres(s: &str) -> ParseResult { use DateTimeField::*; let mut tokens = parse_interval(s)?; if tokens.len() % 2 != 0 @@ -1392,7 +1392,7 @@ impl Interval { tokens.push(TimeStrToken::TimeUnit(DateTimeField::Second)); } if tokens.len() % 2 != 0 { - return Err(ErrorCode::InvalidInputSyntax(format!("Invalid interval {}.", &s)).into()); + return Err(IntervalParseError::invalid(s)); } let mut token_iter = tokens.into_iter(); let mut result = Interval::from_month_day_usec(0, 0, 0); @@ -1422,9 +1422,7 @@ impl Interval { } })() .and_then(|rhs| result.checked_add(&rhs)) - .ok_or_else(|| { - ErrorCode::InvalidInputSyntax(format!("Invalid interval {}.", s)) - })?; + .ok_or_else(|| IntervalParseError::invalid(s))?; } (TimeStrToken::Second(second), TimeStrToken::TimeUnit(interval_unit)) => { result = match interval_unit { @@ -1438,21 +1436,17 @@ impl Interval { _ => None, } .and_then(|rhs| result.checked_add(&rhs)) - .ok_or_else(|| { - ErrorCode::InvalidInputSyntax(format!("Invalid interval {}.", s)) - })?; + .ok_or_else(|| IntervalParseError::invalid(s))?; } _ => { - return Err( - ErrorCode::InvalidInputSyntax(format!("Invalid interval {}.", &s)).into(), - ); + return Err(IntervalParseError::invalid(s)); } } } Ok(result) } - pub fn parse_with_fields(s: &str, leading_field: Option) -> Result { + pub fn parse_with_fields(s: &str, leading_field: Option) -> ParseResult { if let Some(leading_field) = leading_field { Self::parse_sql_standard(s, leading_field) } else { @@ -1462,9 +1456,9 @@ impl Interval { } impl FromStr for Interval { - type Err = RwError; + type Err = IntervalParseError; - fn from_str(s: &str) -> Result { + fn from_str(s: &str) -> ParseResult { Self::parse_with_fields(s, None) } } diff --git a/src/common/src/types/mod.rs b/src/common/src/types/mod.rs index 70438b493d896..a700e5bcb0579 100644 --- a/src/common/src/types/mod.rs +++ b/src/common/src/types/mod.rs @@ -38,7 +38,7 @@ use crate::array::{ }; pub use crate::array::{ListRef, ListValue, StructRef, StructValue}; use crate::cast::{str_to_bool, str_to_bytea}; -use crate::error::{BoxedError, ErrorCode, Result as RwResult}; +use crate::error::BoxedError; use crate::estimate_size::EstimateSize; use crate::util::iter_util::ZipEqDebug; use crate::{ @@ -754,94 +754,94 @@ impl From> for ScalarImpl { } } +#[derive(Debug, thiserror::Error, thiserror_ext::Construct)] +pub enum FromSqlError { + #[error(transparent)] + FromBinary(BoxedError), + + #[error("Invalid param: {0}")] + FromText(String), + + #[error("Unsupported data type: {0}")] + Unsupported(DataType), +} + impl ScalarImpl { - pub fn from_binary(bytes: &Bytes, data_type: &DataType) -> RwResult { + pub fn from_binary(bytes: &Bytes, data_type: &DataType) -> Result { let res = match data_type { DataType::Varchar => Self::Utf8( String::from_sql(&Type::VARCHAR, bytes) - .map_err(|err| ErrorCode::InvalidInputSyntax(err.to_string()))? + .map_err(FromSqlError::from_binary)? .into(), ), DataType::Bytea => Self::Bytea( Vec::::from_sql(&Type::BYTEA, bytes) - .map_err(|err| ErrorCode::InvalidInputSyntax(err.to_string()))? + .map_err(FromSqlError::from_binary)? .into(), ), - DataType::Boolean => Self::Bool( - bool::from_sql(&Type::BOOL, bytes) - .map_err(|err| ErrorCode::InvalidInputSyntax(err.to_string()))?, - ), - DataType::Int16 => Self::Int16( - i16::from_sql(&Type::INT2, bytes) - .map_err(|err| ErrorCode::InvalidInputSyntax(err.to_string()))?, - ), - DataType::Int32 => Self::Int32( - i32::from_sql(&Type::INT4, bytes) - .map_err(|err| ErrorCode::InvalidInputSyntax(err.to_string()))?, - ), - DataType::Int64 => Self::Int64( - i64::from_sql(&Type::INT8, bytes) - .map_err(|err| ErrorCode::InvalidInputSyntax(err.to_string()))?, - ), + DataType::Boolean => { + Self::Bool(bool::from_sql(&Type::BOOL, bytes).map_err(FromSqlError::from_binary)?) + } + DataType::Int16 => { + Self::Int16(i16::from_sql(&Type::INT2, bytes).map_err(FromSqlError::from_binary)?) + } + DataType::Int32 => { + Self::Int32(i32::from_sql(&Type::INT4, bytes).map_err(FromSqlError::from_binary)?) + } + DataType::Int64 => { + Self::Int64(i64::from_sql(&Type::INT8, bytes).map_err(FromSqlError::from_binary)?) + } DataType::Serial => Self::Serial(Serial::from( - i64::from_sql(&Type::INT8, bytes) - .map_err(|err| ErrorCode::InvalidInputSyntax(err.to_string()))?, + i64::from_sql(&Type::INT8, bytes).map_err(FromSqlError::from_binary)?, )), DataType::Float32 => Self::Float32( f32::from_sql(&Type::FLOAT4, bytes) - .map_err(|err| ErrorCode::InvalidInputSyntax(err.to_string()))? + .map_err(FromSqlError::from_binary)? .into(), ), DataType::Float64 => Self::Float64( f64::from_sql(&Type::FLOAT8, bytes) - .map_err(|err| ErrorCode::InvalidInputSyntax(err.to_string()))? + .map_err(FromSqlError::from_binary)? .into(), ), DataType::Decimal => Self::Decimal( rust_decimal::Decimal::from_sql(&Type::NUMERIC, bytes) - .map_err(|err| ErrorCode::InvalidInputSyntax(err.to_string()))? + .map_err(FromSqlError::from_binary)? .into(), ), DataType::Date => Self::Date( chrono::NaiveDate::from_sql(&Type::DATE, bytes) - .map_err(|err| ErrorCode::InvalidInputSyntax(err.to_string()))? + .map_err(FromSqlError::from_binary)? .into(), ), DataType::Time => Self::Time( chrono::NaiveTime::from_sql(&Type::TIME, bytes) - .map_err(|err| ErrorCode::InvalidInputSyntax(err.to_string()))? + .map_err(FromSqlError::from_binary)? .into(), ), DataType::Timestamp => Self::Timestamp( chrono::NaiveDateTime::from_sql(&Type::TIMESTAMP, bytes) - .map_err(|err| ErrorCode::InvalidInputSyntax(err.to_string()))? + .map_err(FromSqlError::from_binary)? .into(), ), DataType::Timestamptz => Self::Timestamptz( chrono::DateTime::::from_sql(&Type::TIMESTAMPTZ, bytes) - .map_err(|err| ErrorCode::InvalidInputSyntax(err.to_string()))? + .map_err(FromSqlError::from_binary)? .into(), ), DataType::Interval => Self::Interval( - Interval::from_sql(&Type::INTERVAL, bytes) - .map_err(|err| ErrorCode::InvalidInputSyntax(err.to_string()))?, + Interval::from_sql(&Type::INTERVAL, bytes).map_err(FromSqlError::from_binary)?, ), - DataType::Jsonb => { - Self::Jsonb(JsonbVal::value_deserialize(bytes).ok_or_else(|| { - ErrorCode::InvalidInputSyntax("Invalid value of Jsonb".to_string()) - })?) - } - DataType::Int256 => Self::Int256( - Int256::from_binary(bytes) - .map_err(|err| ErrorCode::InvalidInputSyntax(err.to_string()))?, + DataType::Jsonb => Self::Jsonb( + JsonbVal::value_deserialize(bytes) + .ok_or_else(|| FromSqlError::from_binary("Invalid value of Jsonb"))?, ), + DataType::Int256 => { + Self::Int256(Int256::from_binary(bytes).map_err(FromSqlError::from_binary)?) + } DataType::Struct(_) | DataType::List { .. } => { - return Err(ErrorCode::NotSupported( - format!("param type: {}", data_type), - "".to_string(), - ) - .into()) + return Err(FromSqlError::Unsupported(data_type.clone())); } }; Ok(res) @@ -856,78 +856,66 @@ impl ScalarImpl { std::str::from_utf8(without_null) } - pub fn from_text(bytes: &[u8], data_type: &DataType) -> RwResult { - let str = Self::cstr_to_str(bytes).map_err(|_| { - ErrorCode::InvalidInputSyntax(format!("Invalid param string: {:?}", bytes)) - })?; + pub fn from_text(bytes: &[u8], data_type: &DataType) -> Result { + let str = + Self::cstr_to_str(bytes).map_err(|_| FromSqlError::from_text(format!("{bytes:?}")))?; let res = match data_type { DataType::Varchar => Self::Utf8(str.to_string().into()), - DataType::Boolean => Self::Bool(bool::from_str(str).map_err(|_| { - ErrorCode::InvalidInputSyntax(format!("Invalid param string: {}", str)) - })?), - DataType::Int16 => Self::Int16(i16::from_str(str).map_err(|_| { - ErrorCode::InvalidInputSyntax(format!("Invalid param string: {}", str)) - })?), - DataType::Int32 => Self::Int32(i32::from_str(str).map_err(|_| { - ErrorCode::InvalidInputSyntax(format!("Invalid param string: {}", str)) - })?), - DataType::Int64 => Self::Int64(i64::from_str(str).map_err(|_| { - ErrorCode::InvalidInputSyntax(format!("Invalid param string: {}", str)) - })?), - DataType::Int256 => Self::Int256(Int256::from_str(str).map_err(|_| { - ErrorCode::InvalidInputSyntax(format!("Invalid param string: {}", str)) - })?), - DataType::Serial => Self::Serial(Serial::from(i64::from_str(str).map_err(|_| { - ErrorCode::InvalidInputSyntax(format!("Invalid param string: {}", str)) - })?)), + DataType::Boolean => { + Self::Bool(bool::from_str(str).map_err(|_| FromSqlError::from_text(str))?) + } + DataType::Int16 => { + Self::Int16(i16::from_str(str).map_err(|_| FromSqlError::from_text(str))?) + } + DataType::Int32 => { + Self::Int32(i32::from_str(str).map_err(|_| FromSqlError::from_text(str))?) + } + DataType::Int64 => { + Self::Int64(i64::from_str(str).map_err(|_| FromSqlError::from_text(str))?) + } + DataType::Int256 => { + Self::Int256(Int256::from_str(str).map_err(|_| FromSqlError::from_text(str))?) + } + DataType::Serial => Self::Serial(Serial::from( + i64::from_str(str).map_err(|_| FromSqlError::from_text(str))?, + )), DataType::Float32 => Self::Float32( f32::from_str(str) - .map_err(|_| { - ErrorCode::InvalidInputSyntax(format!("Invalid param string: {}", str)) - })? + .map_err(|_| FromSqlError::from_text(str))? .into(), ), DataType::Float64 => Self::Float64( f64::from_str(str) - .map_err(|_| { - ErrorCode::InvalidInputSyntax(format!("Invalid param string: {}", str)) - })? + .map_err(|_| FromSqlError::from_text(str))? .into(), ), DataType::Decimal => Self::Decimal( rust_decimal::Decimal::from_str(str) - .map_err(|_| { - ErrorCode::InvalidInputSyntax(format!("Invalid param string: {}", str)) - })? + .map_err(|_| FromSqlError::from_text(str))? .into(), ), - DataType::Date => Self::Date(Date::from_str(str).map_err(|_| { - ErrorCode::InvalidInputSyntax(format!("Invalid param string: {}", str)) - })?), - DataType::Time => Self::Time(Time::from_str(str).map_err(|_| { - ErrorCode::InvalidInputSyntax(format!("Invalid param string: {}", str)) - })?), - DataType::Timestamp => Self::Timestamp(Timestamp::from_str(str).map_err(|_| { - ErrorCode::InvalidInputSyntax(format!("Invalid param string: {}", str)) - })?), - DataType::Timestamptz => { - Self::Timestamptz(Timestamptz::from_str(str).map_err(|_| { - ErrorCode::InvalidInputSyntax(format!("Invalid param string: {}", str)) - })?) + DataType::Date => { + Self::Date(Date::from_str(str).map_err(|_| FromSqlError::from_text(str))?) + } + DataType::Time => { + Self::Time(Time::from_str(str).map_err(|_| FromSqlError::from_text(str))?) + } + DataType::Timestamp => { + Self::Timestamp(Timestamp::from_str(str).map_err(|_| FromSqlError::from_text(str))?) + } + DataType::Timestamptz => Self::Timestamptz( + Timestamptz::from_str(str).map_err(|_| FromSqlError::from_text(str))?, + ), + DataType::Interval => { + Self::Interval(Interval::from_str(str).map_err(|_| FromSqlError::from_text(str))?) + } + DataType::Jsonb => { + Self::Jsonb(JsonbVal::from_str(str).map_err(|_| FromSqlError::from_text(str))?) } - DataType::Interval => Self::Interval(Interval::from_str(str).map_err(|_| { - ErrorCode::InvalidInputSyntax(format!("Invalid param string: {}", str)) - })?), - DataType::Jsonb => Self::Jsonb(JsonbVal::from_str(str).map_err(|_| { - ErrorCode::InvalidInputSyntax(format!("Invalid param string: {}", str)) - })?), DataType::List(datatype) => { // TODO: support nested list if !(str.starts_with('{') && str.ends_with('}')) { - return Err(ErrorCode::InvalidInputSyntax(format!( - "Invalid param string: {str}", - )) - .into()); + return Err(FromSqlError::from_text(str)); } let mut values = vec![]; for s in str[1..str.len() - 1].split(',') { @@ -937,10 +925,7 @@ impl ScalarImpl { } DataType::Struct(s) => { if !(str.starts_with('{') && str.ends_with('}')) { - return Err(ErrorCode::InvalidInputSyntax(format!( - "Invalid param string: {str}", - )) - .into()); + return Err(FromSqlError::from_text(str)); } let mut fields = Vec::with_capacity(s.len()); for (s, ty) in str[1..str.len() - 1].split(',').zip_eq_debug(s.types()) { @@ -949,11 +934,7 @@ impl ScalarImpl { ScalarImpl::Struct(StructValue::new(fields)) } DataType::Bytea => { - return Err(ErrorCode::NotSupported( - format!("param type: {}", data_type), - "".to_string(), - ) - .into()) + return Err(FromSqlError::unsupported(data_type.clone())); } }; Ok(res) diff --git a/src/common/src/types/num256.rs b/src/common/src/types/num256.rs index 26c0edaef59e4..864af97deb374 100644 --- a/src/common/src/types/num256.rs +++ b/src/common/src/types/num256.rs @@ -165,7 +165,10 @@ macro_rules! impl_common_for_num256 { } impl ToBinary for $scalar_ref<'_> { - fn to_binary_with_type(&self, _ty: &DataType) -> super::to_binary::Result> { + fn to_binary_with_type( + &self, + _ty: &DataType, + ) -> super::to_binary::Result> { let mut output = bytes::BytesMut::new(); let buffer = self.to_be_bytes(); output.put_slice(&buffer); diff --git a/src/common/src/types/postgres_type.rs b/src/common/src/types/postgres_type.rs index 5e470182bad63..b43b955f8ae23 100644 --- a/src/common/src/types/postgres_type.rs +++ b/src/common/src/types/postgres_type.rs @@ -13,7 +13,6 @@ // limitations under the License. use super::DataType; -use crate::error::ErrorCode; /// `DataType` information extracted from PostgreSQL `pg_type` /// @@ -49,6 +48,10 @@ macro_rules! for_all_base_types { }; } +#[derive(Debug, thiserror::Error)] +#[error("Unsupported oid {0}")] +pub struct UnsupportedOid(i32); + /// Get type information compatible with Postgres type, such as oid, type length. impl DataType { pub fn type_len(&self) -> i16 { @@ -73,7 +76,7 @@ impl DataType { // Such as: // https://github.com/postgres/postgres/blob/master/src/include/catalog/pg_type.dat#L347 // For Numeric(aka Decimal): oid = 1700, array_type_oid = 1231 - pub fn from_oid(oid: i32) -> crate::error::Result { + pub fn from_oid(oid: i32) -> Result { macro_rules! impl_from_oid { ($( { $enum:ident | $oid:literal | $oid_array:literal | $name:ident | $input:ident | $len:literal } )*) => { match oid { @@ -86,7 +89,7 @@ impl DataType { // workaround to support text in extended mode. 25 => Ok(DataType::Varchar), 1009 => Ok(DataType::List(Box::new(DataType::Varchar))), - _ => Err(ErrorCode::InternalError(format!("Unsupported oid {}", oid)).into()), + _ => Err(UnsupportedOid(oid)), } } } diff --git a/src/common/src/util/match_util.rs b/src/common/src/util/match_util.rs index 26982812d6499..9591f05340761 100644 --- a/src/common/src/util/match_util.rs +++ b/src/common/src/util/match_util.rs @@ -12,6 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. +/// Try to match an enum variant and return the internal value. +/// +/// Return an [`anyhow::Error`] if the enum variant does not match. #[macro_export] macro_rules! try_match_expand { ($e:expr, $variant:path) => { @@ -32,6 +35,9 @@ macro_rules! try_match_expand { }; } +/// Match an enum variant and return the internal value. +/// +/// Panic if the enum variant does not match. #[macro_export] macro_rules! must_match { ($expression:expr, $(|)? $( $pattern:pat_param )|+ $( if $guard: expr )? => $action:expr) => { @@ -43,41 +49,39 @@ macro_rules! must_match { } mod tests { + #[derive(thiserror::Error, Debug)] + #[error(transparent)] + struct ExpandError(#[from] anyhow::Error); + + #[allow(dead_code)] + enum MyEnum { + A(String), + B, + } + #[test] - fn test_try_match() -> crate::error::Result<()> { + fn test_try_match() -> Result<(), ExpandError> { assert_eq!( - try_match_expand!( - crate::error::ErrorCode::InternalError("failure".to_string()), - crate::error::ErrorCode::InternalError - )?, + try_match_expand!(MyEnum::A("failure".to_string()), MyEnum::A)?, "failure" ); assert_eq!( - try_match_expand!( - crate::error::ErrorCode::InternalError("failure".to_string()), - crate::error::ErrorCode::InternalError - )?, + try_match_expand!(MyEnum::A("failure".to_string()), MyEnum::A)?, "failure" ); assert_eq!( - try_match_expand!( - crate::error::ErrorCode::InternalError("failure".to_string()), - crate::error::ErrorCode::InternalError - )?, + try_match_expand!(MyEnum::A("failure".to_string()), MyEnum::A)?, "failure" ); // Test let statement is compilable. - let err_str = try_match_expand!( - crate::error::ErrorCode::InternalError("failure".to_string()), - crate::error::ErrorCode::InternalError - )?; + let err_str = try_match_expand!(MyEnum::A("failure".to_string()), MyEnum::A)?; assert_eq!(err_str, "failure"); Ok(()) } #[test] - fn test_must_match() -> crate::error::Result<()> { + fn test_must_match() -> Result<(), ExpandError> { #[allow(dead_code)] enum A { Foo, diff --git a/src/frontend/src/binder/bind_param.rs b/src/frontend/src/binder/bind_param.rs index 44c11f62393ae..7f35b107c9dca 100644 --- a/src/frontend/src/binder/bind_param.rs +++ b/src/frontend/src/binder/bind_param.rs @@ -14,8 +14,9 @@ use bytes::Bytes; use pgwire::types::{Format, FormatIterator}; -use risingwave_common::error::{ErrorCode, Result, RwError}; -use risingwave_common::types::{Datum, ScalarImpl}; +use risingwave_common::bail; +use risingwave_common::error::{ErrorCode, Result}; +use risingwave_common::types::{Datum, FromSqlError, ScalarImpl}; use super::statement::RewriteExprsRecursive; use super::BoundStatement; @@ -26,7 +27,7 @@ pub(crate) struct ParamRewriter { pub(crate) params: Vec>, pub(crate) parsed_params: Vec, pub(crate) param_formats: Vec, - pub(crate) error: Option, + pub(crate) error: Option, } impl ParamRewriter { @@ -107,7 +108,7 @@ impl BoundStatement { self.rewrite_exprs_recursive(&mut rewriter); if let Some(err) = rewriter.error { - return Err(err); + bail!(err); } Ok((self, rewriter.parsed_params)) diff --git a/src/frontend/src/binder/expr/value.rs b/src/frontend/src/binder/expr/value.rs index e5ae8bb4e9156..54559266a136f 100644 --- a/src/frontend/src/binder/expr/value.rs +++ b/src/frontend/src/binder/expr/value.rs @@ -72,7 +72,8 @@ impl Binder { leading_field: Option, ) -> Result { let interval = - Interval::parse_with_fields(&s, leading_field.map(Self::bind_date_time_field))?; + Interval::parse_with_fields(&s, leading_field.map(Self::bind_date_time_field)) + .map_err(|e| ErrorCode::BindError(e.to_string()))?; let datum = Some(ScalarImpl::Interval(interval)); let literal = Literal::new(datum, DataType::Interval); diff --git a/src/frontend/src/catalog/system_catalog/mod.rs b/src/frontend/src/catalog/system_catalog/mod.rs index 897171cfd38b4..c85cddd4deab9 100644 --- a/src/frontend/src/catalog/system_catalog/mod.rs +++ b/src/frontend/src/catalog/system_catalog/mod.rs @@ -26,7 +26,7 @@ use risingwave_common::catalog::{ ColumnCatalog, ColumnDesc, Field, SysCatalogReader, TableDesc, TableId, DEFAULT_SUPER_USER_ID, NON_RESERVED_SYS_CATALOG_ID, }; -use risingwave_common::error::Result; +use risingwave_common::error::BoxedError; use risingwave_common::row::OwnedRow; use risingwave_common::types::DataType; use risingwave_pb::user::grant_privilege::Object; @@ -314,14 +314,14 @@ macro_rules! prepare_sys_catalog { #[async_trait] impl SysCatalogReader for SysCatalogReaderImpl { - async fn read_table(&self, table_id: &TableId) -> Result> { + async fn read_table(&self, table_id: &TableId) -> Result, BoxedError> { let table_name = SYS_CATALOGS.table_name_by_id.get(table_id).unwrap(); $( if $builtin_catalog.name() == *table_name { $( let rows = self.$func(); $(let rows = rows.$await;)? - return rows; + return Ok(rows?); )? } )* diff --git a/src/frontend/src/handler/util.rs b/src/frontend/src/handler/util.rs index 1be30c2d470eb..1c8dbd9d0714e 100644 --- a/src/frontend/src/handler/util.rs +++ b/src/frontend/src/handler/util.rs @@ -17,6 +17,7 @@ use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; +use anyhow::Context as _; use bytes::Bytes; use futures::Stream; use itertools::Itertools; @@ -125,7 +126,9 @@ fn pg_value_format( Ok(d.text_format(data_type).into()) } } - Format::Binary => d.binary_format(data_type), + Format::Binary => Ok(d + .binary_format(data_type) + .context("failed to format binary value")?), } } diff --git a/src/utils/pgwire/src/pg_protocol.rs b/src/utils/pgwire/src/pg_protocol.rs index 0fba456b39207..f912860794a3f 100644 --- a/src/utils/pgwire/src/pg_protocol.rs +++ b/src/utils/pgwire/src/pg_protocol.rs @@ -27,7 +27,6 @@ use futures::future::Either; use futures::stream::StreamExt; use itertools::Itertools; use openssl::ssl::{SslAcceptor, SslContext, SslContextRef, SslMethod}; -use risingwave_common::error::RwError; use risingwave_common::types::DataType; use risingwave_common::util::panic::FutureCatchUnwindExt; use risingwave_sqlparser::ast::Statement; @@ -648,11 +647,12 @@ where if id == 0 { Ok(None) } else { - Ok(Some(DataType::from_oid(id)?)) + DataType::from_oid(id) + .map(Some) + .map_err(|e| PsqlError::ParseError(e.into())) } }) - .try_collect() - .map_err(|err: RwError| PsqlError::ParseError(err.into()))?; + .try_collect()?; let prepare_statement = session .parse(stmt, param_types) From 971f77c6db1aa2032bd82d757f9e6356234c137f Mon Sep 17 00:00:00 2001 From: Bugen Zhao Date: Wed, 22 Nov 2023 16:25:38 +0800 Subject: [PATCH 3/4] add documentation Signed-off-by: Bugen Zhao --- src/common/common_service/src/observer_manager.rs | 1 + src/common/heap_profiling/src/jeprof.rs | 1 + src/common/src/types/interval.rs | 1 + src/common/src/types/mod.rs | 1 + src/common/src/types/to_binary.rs | 1 + 5 files changed, 5 insertions(+) diff --git a/src/common/common_service/src/observer_manager.rs b/src/common/common_service/src/observer_manager.rs index c6db757853941..c799995e41adb 100644 --- a/src/common/common_service/src/observer_manager.rs +++ b/src/common/common_service/src/observer_manager.rs @@ -78,6 +78,7 @@ impl ObserverManager { } } +/// Error type for [`ObserverManager`]. #[derive(thiserror::Error, Debug)] pub enum ObserverError { #[error("notification channel closed")] diff --git a/src/common/heap_profiling/src/jeprof.rs b/src/common/heap_profiling/src/jeprof.rs index 0a137148c2a16..0dea5f05f8d16 100644 --- a/src/common/heap_profiling/src/jeprof.rs +++ b/src/common/heap_profiling/src/jeprof.rs @@ -17,6 +17,7 @@ use std::process::Command; use std::result::Result; use std::{env, fs}; +/// Error type for running `jeprof`. #[derive(thiserror::Error, Debug, thiserror_ext::ContextInto)] pub enum JeprofError { #[error(transparent)] diff --git a/src/common/src/types/interval.rs b/src/common/src/types/interval.rs index a95bc412124b9..0990a5c8e61f4 100644 --- a/src/common/src/types/interval.rs +++ b/src/common/src/types/interval.rs @@ -1000,6 +1000,7 @@ impl ToText for crate::types::Interval { } } +/// Error type for parsing an [`Interval`]. #[derive(thiserror::Error, Debug, thiserror_ext::Construct)] pub enum IntervalParseError { #[error("Invalid interval: {0}")] diff --git a/src/common/src/types/mod.rs b/src/common/src/types/mod.rs index a700e5bcb0579..2450c2664e273 100644 --- a/src/common/src/types/mod.rs +++ b/src/common/src/types/mod.rs @@ -754,6 +754,7 @@ impl From> for ScalarImpl { } } +/// Error type for [`ScalarImpl::from_binary`] and [`ScalarImpl::from_text`]. #[derive(Debug, thiserror::Error, thiserror_ext::Construct)] pub enum FromSqlError { #[error(transparent)] diff --git a/src/common/src/types/to_binary.rs b/src/common/src/types/to_binary.rs index 981678b854c9d..46b5a61589493 100644 --- a/src/common/src/types/to_binary.rs +++ b/src/common/src/types/to_binary.rs @@ -18,6 +18,7 @@ use postgres_types::{ToSql, Type}; use super::{DataType, DatumRef, ScalarRefImpl, F32, F64}; use crate::error::TrackingIssue; +/// Error type for [`ToBinary`] trait. #[derive(thiserror::Error, Debug)] pub enum ToBinaryError { #[error(transparent)] From b87318c36c821ed5210e09168821bf1961ebe29f Mon Sep 17 00:00:00 2001 From: Bugen Zhao Date: Fri, 24 Nov 2023 14:12:47 +0800 Subject: [PATCH 4/4] refine session config error Signed-off-by: Bugen Zhao --- src/common/proc_macro/src/session_config.rs | 30 ++++++++++--------- src/common/src/error.rs | 15 ++++++++-- src/common/src/session_config/mod.rs | 21 +++++++++++-- src/common/src/session_config/non_zero64.rs | 6 ++-- src/common/src/session_config/over_window.rs | 4 +-- src/common/src/session_config/query_mode.rs | 4 +-- src/common/src/session_config/search_path.rs | 4 +-- .../src/session_config/sink_decouple.rs | 4 +-- .../transaction_isolation_level.rs | 4 +-- .../src/session_config/visibility_mode.rs | 4 +-- src/frontend/src/session.rs | 10 +++++-- 11 files changed, 70 insertions(+), 36 deletions(-) diff --git a/src/common/proc_macro/src/session_config.rs b/src/common/proc_macro/src/session_config.rs index ea29e0a20fec8..6b622241b1296 100644 --- a/src/common/proc_macro/src/session_config.rs +++ b/src/common/proc_macro/src/session_config.rs @@ -104,10 +104,11 @@ pub(crate) fn derive_config(input: DeriveInput) -> TokenStream { let check_hook = if let Some(check_hook_name) = check_hook_name { quote! { - #check_hook_name(&val).map_err(|_e| { - ErrorCode::InvalidConfigValue { - config_entry: #entry_name.to_string(), - config_value: val.to_string(), + #check_hook_name(&val).map_err(|e| { + SessionConfigError::InvalidValue { + entry: #entry_name, + value: val.to_string(), + source: anyhow::anyhow!(e), } })?; } @@ -131,11 +132,12 @@ pub(crate) fn derive_config(input: DeriveInput) -> TokenStream { &mut self, val: &str, reporter: &mut impl ConfigReporter - ) -> RwResult<()> { - let val_t: #ty = val.parse().map_err(|_e| { - ErrorCode::InvalidConfigValue { - config_entry: #entry_name.to_string(), - config_value: val.to_string(), + ) -> SessionConfigResult<()> { + let val_t = <#ty as ::std::str::FromStr>::from_str(val).map_err(|e| { + SessionConfigError::InvalidValue { + entry: #entry_name, + value: val.to_string(), + source: anyhow::anyhow!(e), } })?; @@ -148,7 +150,7 @@ pub(crate) fn derive_config(input: DeriveInput) -> TokenStream { &mut self, val: #ty, reporter: &mut impl ConfigReporter - ) -> RwResult<()> { + ) -> SessionConfigResult<()> { #check_hook #report_hook @@ -236,18 +238,18 @@ pub(crate) fn derive_config(input: DeriveInput) -> TokenStream { #(#struct_impl_reset)* /// Set a parameter given it's name and value string. - pub fn set(&mut self, key_name: &str, value: String, reporter: &mut impl ConfigReporter) -> RwResult<()> { + pub fn set(&mut self, key_name: &str, value: String, reporter: &mut impl ConfigReporter) -> SessionConfigResult<()> { match key_name.to_ascii_lowercase().as_ref() { #(#set_match_branches)* - _ => Err(ErrorCode::UnrecognizedConfigurationParameter(key_name.to_string()).into()), + _ => Err(SessionConfigError::UnrecognizedEntry(key_name.to_string())), } } /// Get a parameter by it's name. - pub fn get(&self, key_name: &str) -> RwResult { + pub fn get(&self, key_name: &str) -> SessionConfigResult { match key_name.to_ascii_lowercase().as_ref() { #(#get_match_branches)* - _ => Err(ErrorCode::UnrecognizedConfigurationParameter(key_name.to_string()).into()), + _ => Err(SessionConfigError::UnrecognizedEntry(key_name.to_string())), } } diff --git a/src/common/src/error.rs b/src/common/src/error.rs index 3e6cef078ad71..2e1e0be5b343b 100644 --- a/src/common/src/error.rs +++ b/src/common/src/error.rs @@ -26,6 +26,7 @@ use thiserror::Error; use tokio::task::JoinError; use crate::array::ArrayError; +use crate::session_config::SessionConfigError; use crate::util::value_encoding::error::ValueEncodingError; const ERROR_SUPPRESSOR_RESET_DURATION: Duration = Duration::from_millis(60 * 60 * 1000); // 1h @@ -191,8 +192,12 @@ pub enum ErrorCode { ), #[error("Permission denied: {0}")] PermissionDenied(String), - #[error("unrecognized configuration parameter \"{0}\"")] - UnrecognizedConfigurationParameter(String), + #[error("Failed to get/set session config: {0}")] + SessionConfig( + #[from] + #[backtrace] + SessionConfigError, + ), } pub fn internal_error(msg: impl Into) -> RwError { @@ -335,6 +340,12 @@ impl From for RwError { } } +impl From for RwError { + fn from(value: SessionConfigError) -> Self { + ErrorCode::SessionConfig(value).into() + } +} + impl From for RwError { fn from(err: tonic::transport::Error) -> Self { ErrorCode::RpcError(err.into()).into() diff --git a/src/common/src/session_config/mod.rs b/src/common/src/session_config/mod.rs index 0600684b9ff90..6e45acd5c378c 100644 --- a/src/common/src/session_config/mod.rs +++ b/src/common/src/session_config/mod.rs @@ -25,15 +25,30 @@ pub use over_window::OverWindowCachePolicy; pub use query_mode::QueryMode; use risingwave_common_proc_macro::SessionConfig; pub use search_path::{SearchPath, USER_NAME_WILD_CARD}; +use thiserror::Error; use self::non_zero64::ConfigNonZeroU64; -use crate::error::{ErrorCode, Result as RwResult}; use crate::session_config::sink_decouple::SinkDecouple; use crate::session_config::transaction_isolation_level::IsolationLevel; pub use crate::session_config::visibility_mode::VisibilityMode; pub const SESSION_CONFIG_LIST_SEP: &str = ", "; +#[derive(Error, Debug)] +pub enum SessionConfigError { + #[error("Invalid value `{value}` for `{entry}`")] + InvalidValue { + entry: &'static str, + value: String, + source: anyhow::Error, + }, + + #[error("Unrecognized config entry `{0}`")] + UnrecognizedEntry(String), +} + +type SessionConfigResult = std::result::Result; + /// This is the Session Config of RisingWave. #[derive(SessionConfig)] pub struct ConfigMap { @@ -255,7 +270,7 @@ impl ConfigMap { &mut self, val: bool, reporter: &mut impl ConfigReporter, - ) -> RwResult<()> { + ) -> SessionConfigResult<()> { self.set_force_two_phase_agg_inner(val, reporter)?; if self.force_two_phase_agg { self.set_enable_two_phase_agg(true, reporter) @@ -268,7 +283,7 @@ impl ConfigMap { &mut self, val: bool, reporter: &mut impl ConfigReporter, - ) -> RwResult<()> { + ) -> SessionConfigResult<()> { self.set_enable_two_phase_agg_inner(val, reporter)?; if !self.force_two_phase_agg { self.set_force_two_phase_agg(false, reporter) diff --git a/src/common/src/session_config/non_zero64.rs b/src/common/src/session_config/non_zero64.rs index 2bec6eb5f462f..66e678272c41d 100644 --- a/src/common/src/session_config/non_zero64.rs +++ b/src/common/src/session_config/non_zero64.rs @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::num::NonZeroU64; +use std::num::{NonZeroU64, ParseIntError}; use std::str::FromStr; /// When set this config as `0`, the value is `None`, otherwise the value is @@ -21,10 +21,10 @@ use std::str::FromStr; pub struct ConfigNonZeroU64(pub Option); impl FromStr for ConfigNonZeroU64 { - type Err = (); + type Err = ParseIntError; fn from_str(s: &str) -> Result { - let parsed = s.parse::().map_err(|_| ())?; + let parsed = s.parse::()?; if parsed == 0 { Ok(Self(None)) } else { diff --git a/src/common/src/session_config/over_window.rs b/src/common/src/session_config/over_window.rs index d3c4e23433f3a..d3dbff7d61833 100644 --- a/src/common/src/session_config/over_window.rs +++ b/src/common/src/session_config/over_window.rs @@ -33,7 +33,7 @@ pub enum OverWindowCachePolicy { } impl FromStr for OverWindowCachePolicy { - type Err = (); + type Err = &'static str; fn from_str(s: &str) -> Result { let s = s.to_ascii_lowercase().replace('-', "_"); @@ -42,7 +42,7 @@ impl FromStr for OverWindowCachePolicy { "recent" => Ok(Self::Recent), "recent_first_n" => Ok(Self::RecentFirstN), "recent_last_n" => Ok(Self::RecentLastN), - _ => Err(()), + _ => Err("expect one of [full, recent, recent_first_n, recent_last_n]"), } } } diff --git a/src/common/src/session_config/query_mode.rs b/src/common/src/session_config/query_mode.rs index 6879a73bfed9b..2384520a088ea 100644 --- a/src/common/src/session_config/query_mode.rs +++ b/src/common/src/session_config/query_mode.rs @@ -28,7 +28,7 @@ pub enum QueryMode { } impl FromStr for QueryMode { - type Err = (); + type Err = &'static str; fn from_str(s: &str) -> Result { if s.eq_ignore_ascii_case("local") { @@ -38,7 +38,7 @@ impl FromStr for QueryMode { } else if s.eq_ignore_ascii_case("auto") { Ok(Self::Auto) } else { - Err(()) + Err("expect one of [local, distributed, auto]") } } } diff --git a/src/common/src/session_config/search_path.rs b/src/common/src/session_config/search_path.rs index c176503824d43..2802222cda899 100644 --- a/src/common/src/session_config/search_path.rs +++ b/src/common/src/session_config/search_path.rs @@ -12,11 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::convert::Infallible; use std::str::FromStr; use super::SESSION_CONFIG_LIST_SEP; use crate::catalog::{DEFAULT_SCHEMA_NAME, PG_CATALOG_SCHEMA_NAME, RW_CATALOG_SCHEMA_NAME}; -use crate::error::RwError; pub const USER_NAME_WILD_CARD: &str = "\"$user\""; @@ -59,7 +59,7 @@ impl Default for SearchPath { } impl FromStr for SearchPath { - type Err = RwError; + type Err = Infallible; fn from_str(s: &str) -> Result { let paths = s.split(SESSION_CONFIG_LIST_SEP).map(|path| path.trim()); diff --git a/src/common/src/session_config/sink_decouple.rs b/src/common/src/session_config/sink_decouple.rs index 0129654b138da..6b6e3049fd4e1 100644 --- a/src/common/src/session_config/sink_decouple.rs +++ b/src/common/src/session_config/sink_decouple.rs @@ -26,14 +26,14 @@ pub enum SinkDecouple { } impl FromStr for SinkDecouple { - type Err = (); + type Err = &'static str; fn from_str(s: &str) -> Result { match s.to_ascii_lowercase().as_str() { "true" | "enable" => Ok(Self::Enable), "false" | "disable" => Ok(Self::Disable), "default" => Ok(Self::Default), - _ => Err(()), + _ => Err("expect one of [true, enable, false, disable, default]"), } } } diff --git a/src/common/src/session_config/transaction_isolation_level.rs b/src/common/src/session_config/transaction_isolation_level.rs index af558e8525dd9..5f122075f41f8 100644 --- a/src/common/src/session_config/transaction_isolation_level.rs +++ b/src/common/src/session_config/transaction_isolation_level.rs @@ -27,10 +27,10 @@ pub enum IsolationLevel { } impl FromStr for IsolationLevel { - type Err = (); + type Err = &'static str; fn from_str(_s: &str) -> Result { - Err(()) + Err("isolation level is not yet supported") } } diff --git a/src/common/src/session_config/visibility_mode.rs b/src/common/src/session_config/visibility_mode.rs index b8aa4f6faef3a..ece1bdc99e9d2 100644 --- a/src/common/src/session_config/visibility_mode.rs +++ b/src/common/src/session_config/visibility_mode.rs @@ -29,7 +29,7 @@ pub enum VisibilityMode { } impl FromStr for VisibilityMode { - type Err = (); + type Err = &'static str; fn from_str(s: &str) -> Result { if s.eq_ignore_ascii_case("all") { @@ -39,7 +39,7 @@ impl FromStr for VisibilityMode { } else if s.eq_ignore_ascii_case("default") { Ok(Self::Default) } else { - Err(()) + Err("expect one of [all, checkpoint, default]") } } } diff --git a/src/frontend/src/session.rs b/src/frontend/src/session.rs index dc22d5635e58e..55d9756187bbd 100644 --- a/src/frontend/src/session.rs +++ b/src/frontend/src/session.rs @@ -608,7 +608,10 @@ impl SessionImpl { } pub fn set_config(&self, key: &str, value: String) -> Result<()> { - self.config_map.write().set(key, value, &mut ()) + self.config_map + .write() + .set(key, value, &mut ()) + .map_err(Into::into) } pub fn set_config_report( @@ -617,7 +620,10 @@ impl SessionImpl { value: String, mut reporter: impl ConfigReporter, ) -> Result<()> { - self.config_map.write().set(key, value, &mut reporter) + self.config_map + .write() + .set(key, value, &mut reporter) + .map_err(Into::into) } pub fn session_id(&self) -> SessionId {