diff --git a/src/udf/src/error.rs b/src/udf/src/error.rs index d787808a32810..f20816ee5b2c0 100644 --- a/src/udf/src/error.rs +++ b/src/udf/src/error.rs @@ -40,6 +40,9 @@ pub enum Error { #[error("UDF service returned no data")] NoReturned, + + #[error("Flight service error: {0}")] + ServiceError(String), } static_assertions::const_assert_eq!(std::mem::size_of::(), 32); diff --git a/src/udf/src/external.rs b/src/udf/src/external.rs index 1a666f4d4e378..585adc7ebec5b 100644 --- a/src/udf/src/external.rs +++ b/src/udf/src/external.rs @@ -49,6 +49,15 @@ impl ArrowFlightUdfClient { let input_num = info.total_records as usize; let full_schema = Schema::try_from(info) .map_err(|e| FlightError::DecodeError(format!("Error decoding schema: {e}")))?; + if input_num > full_schema.fields.len() { + return Err(Error::ServiceError(format!( + "function {:?} schema info not consistency: input_num: {}, total_fields: {}", + id, + input_num, + full_schema.fields.len() + ))); + } + let (input_fields, return_fields) = full_schema.fields.split_at(input_num); let actual_input_types: Vec<_> = input_fields.iter().map(|f| f.data_type()).collect(); let actual_result_types: Vec<_> = return_fields.iter().map(|f| f.data_type()).collect();