Skip to content

Commit

Permalink
fix(udf): check udf schema fields num and total records (#12206)
Browse files Browse the repository at this point in the history
Signed-off-by: Kevin Axel <[email protected]>
  • Loading branch information
KveinAxel authored and wangrunji0408 committed Sep 18, 2023
1 parent e8ad2d6 commit 82af04e
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 0 deletions.
3 changes: 3 additions & 0 deletions src/udf/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@ pub enum Error {

#[error("UDF service returned no data")]
NoReturned,

#[error("Flight service error: {0}")]
ServiceError(String),
}

static_assertions::const_assert_eq!(std::mem::size_of::<Error>(), 32);
Expand Down
9 changes: 9 additions & 0 deletions src/udf/src/external.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,15 @@ impl ArrowFlightUdfClient {
let input_num = info.total_records as usize;
let full_schema = Schema::try_from(info)
.map_err(|e| FlightError::DecodeError(format!("Error decoding schema: {e}")))?;
if input_num > full_schema.fields.len() {
return Err(Error::ServiceError(format!(
"function {:?} schema info not consistency: input_num: {}, total_fields: {}",
id,
input_num,
full_schema.fields.len()
)));
}

let (input_fields, return_fields) = full_schema.fields.split_at(input_num);
let actual_input_types: Vec<_> = input_fields.iter().map(|f| f.data_type()).collect();
let actual_result_types: Vec<_> = return_fields.iter().map(|f| f.data_type()).collect();
Expand Down

0 comments on commit 82af04e

Please sign in to comment.