Skip to content

Commit

Permalink
fix(udf): check the data type returned from UDF server (#12202)
Browse files Browse the repository at this point in the history
Signed-off-by: Runji Wang <[email protected]>
  • Loading branch information
wangrunji0408 authored Sep 12, 2023
1 parent 8b61c92 commit faa1bcc
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 0 deletions.
8 changes: 8 additions & 0 deletions src/common/src/types/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -443,6 +443,14 @@ impl DataType {
}
d
}

/// Compares the datatype with another, ignoring nested field names and metadata.
pub fn equals_datatype(&self, other: &DataType) -> bool {
match (self, other) {
(Self::Struct(s1), Self::Struct(s2)) => s1.equals_datatype(s2),
_ => self == other,
}
}
}

impl From<DataType> for PbDataType {
Expand Down
10 changes: 10 additions & 0 deletions src/common/src/types/struct_type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,16 @@ impl StructType {
.chain(std::iter::repeat("").take(self.0.field_types.len() - self.0.field_names.len()))
.zip_eq_debug(self.0.field_types.iter())
}

/// Compares the datatype with another, ignoring nested field names and metadata.
pub fn equals_datatype(&self, other: &StructType) -> bool {
if self.0.field_types.len() != other.0.field_types.len() {
return false;
}
(self.0.field_types.iter())
.zip_eq_fast(other.0.field_types.iter())
.all(|(a, b)| a.equals_datatype(b))
}
}

impl Display for StructType {
Expand Down
7 changes: 7 additions & 0 deletions src/expr/src/expr/expr_udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,13 @@ impl UdfExpression {
};
let mut array = ArrayImpl::try_from(arrow_array)?;
array.set_bitmap(array.null_bitmap() & vis);
if !array.data_type().equals_datatype(&self.return_type) {
bail!(
"UDF returned {:?}, but expected {:?}",
array.data_type(),
self.return_type,
);
}
Ok(Arc::new(array))
}
}
Expand Down
30 changes: 30 additions & 0 deletions src/expr/src/table_function/user_defined.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,39 @@ impl UserDefinedTableFunction {
.await?
{
let output = DataChunk::try_from(&res?)?;
self.check_output(&output)?;
yield output;
}
}

/// Check if the output chunk is valid.
fn check_output(&self, output: &DataChunk) -> Result<()> {
if output.columns().len() != 2 {
bail!(
"UDF returned {} columns, but expected 2",
output.columns().len()
);
}
if output.column_at(0).data_type() != DataType::Int32 {
bail!(
"UDF returned {:?} at column 0, but expected {:?}",
output.column_at(0).data_type(),
DataType::Int32,
);
}
if !output
.column_at(1)
.data_type()
.equals_datatype(&self.return_type)
{
bail!(
"UDF returned {:?} at column 1, but expected {:?}",
output.column_at(1).data_type(),
&self.return_type,
);
}
Ok(())
}
}

#[cfg(not(madsim))]
Expand Down

0 comments on commit faa1bcc

Please sign in to comment.