diff --git a/Cargo.lock b/Cargo.lock index a0b3e08b46441..09f9d7070a414 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -8424,6 +8424,7 @@ dependencies = [ "madsim-tonic", "static_assertions", "thiserror", + "thiserror-ext", ] [[package]] @@ -9966,6 +9967,27 @@ dependencies = [ "thiserror-impl", ] +[[package]] +name = "thiserror-ext" +version = "0.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5f28a4a7351f496662affc257826b85dd2a613406ce3cc2f07b849685e166d8c" +dependencies = [ + "thiserror", + "thiserror-ext-derive", +] + +[[package]] +name = "thiserror-ext-derive" +version = "0.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "67621f6d39449754da63668ddd2423ad0c81c27434c16090f8805ad1db59b621" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.37", +] + [[package]] name = "thiserror-impl" version = "1.0.48" diff --git a/Cargo.toml b/Cargo.toml index dc38b19e237f4..919c990deb7d4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -120,6 +120,7 @@ arrow-buffer = "48" arrow-flight = "48" arrow-select = "48" arrow-ord = "48" +thiserror-ext = "0.0.1" tikv-jemalloc-ctl = { git = "https://github.com/risingwavelabs/jemallocator.git", rev = "64a2d9" } tikv-jemallocator = { git = "https://github.com/risingwavelabs/jemallocator.git", features = [ "profiling", diff --git a/src/expr/core/src/expr/expr_udf.rs b/src/expr/core/src/expr/expr_udf.rs index a11af2434b4f9..8aae5beae60a4 100644 --- a/src/expr/core/src/expr/expr_udf.rs +++ b/src/expr/core/src/expr/expr_udf.rs @@ -150,7 +150,7 @@ impl Build for UdfExpression { "", DataType::from(t) .try_into() - .map_err(risingwave_udf::Error::Unsupported)?, + .map_err(risingwave_udf::Error::unsupported)?, true, )) }) diff --git a/src/expr/core/src/table_function/user_defined.rs b/src/expr/core/src/table_function/user_defined.rs index 60fde34f9df1f..cbb8682ba48a7 100644 --- a/src/expr/core/src/table_function/user_defined.rs +++ b/src/expr/core/src/table_function/user_defined.rs @@ -140,7 +140,7 @@ pub fn new_user_defined(prost: &PbTableFunction, chunk_size: usize) -> Result = std::result::Result; /// The error type for UDF operations. -#[derive(thiserror::Error, Debug)] -pub enum Error { +#[derive(Error, Debug, Box, Construct)] +#[thiserror_ext(type = Error)] +pub enum ErrorInner { #[error("failed to connect to UDF service: {0}")] Connect(#[from] tonic::transport::Error), #[error("failed to send requests to UDF service: {0}")] - Tonic(#[from] Box), + Tonic(#[from] tonic::Status), #[error("failed to call UDF: {0}")] - Flight(#[from] Box), + Flight(#[from] FlightError), #[error("type mismatch: {0}")] TypeMismatch(String), @@ -45,16 +48,4 @@ pub enum Error { ServiceError(String), } -static_assertions::const_assert_eq!(std::mem::size_of::(), 40); - -impl From for Error { - fn from(status: tonic::Status) -> Self { - Error::from(Box::new(status)) - } -} - -impl From for Error { - fn from(error: FlightError) -> Self { - Error::from(Box::new(error)) - } -} +static_assertions::const_assert_eq!(std::mem::size_of::(), 8); diff --git a/src/udf/src/external.rs b/src/udf/src/external.rs index e77b96f2bdab4..c5ed44850f1bf 100644 --- a/src/udf/src/external.rs +++ b/src/udf/src/external.rs @@ -59,7 +59,7 @@ impl ArrowFlightUdfClient { let full_schema = Schema::try_from(info) .map_err(|e| FlightError::DecodeError(format!("Error decoding schema: {e}")))?; if input_num > full_schema.fields.len() { - return Err(Error::ServiceError(format!( + return Err(Error::service_error(format!( "function {:?} schema info not consistency: input_num: {}, total_fields: {}", id, input_num, @@ -73,13 +73,13 @@ impl ArrowFlightUdfClient { let expect_input_types: Vec<_> = args.fields.iter().map(|f| f.data_type()).collect(); let expect_result_types: Vec<_> = returns.fields.iter().map(|f| f.data_type()).collect(); if !data_types_match(&expect_input_types, &actual_input_types) { - return Err(Error::TypeMismatch(format!( + return Err(Error::type_mismatch(format!( "function: {:?}, expect arguments: {:?}, actual: {:?}", id, expect_input_types, actual_input_types ))); } if !data_types_match(&expect_result_types, &actual_result_types) { - return Err(Error::TypeMismatch(format!( + return Err(Error::type_mismatch(format!( "function: {:?}, expect return: {:?}, actual: {:?}", id, expect_result_types, actual_result_types ))); @@ -91,7 +91,10 @@ impl ArrowFlightUdfClient { pub async fn call(&self, id: &str, input: RecordBatch) -> Result { let mut output_stream = self.call_stream(id, stream::once(async { input })).await?; // TODO: support no output - let head = output_stream.next().await.ok_or(Error::NoReturned)??; + let head = output_stream + .next() + .await + .ok_or_else(Error::no_returned)??; let mut remaining = vec![]; while let Some(batch) = output_stream.next().await { remaining.push(batch?); diff --git a/src/udf/src/lib.rs b/src/udf/src/lib.rs index 513551a9108af..25207c2c19edf 100644 --- a/src/udf/src/lib.rs +++ b/src/udf/src/lib.rs @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +#![feature(error_generic_member_access)] + mod error; mod external;