cherrypick fix(udf): handle visibility of input chunks in UDTF (#12357) (#12368)

Signed-off-by: Kevin Axel <[email protected]>
Signed-off-by: Runji Wang <[email protected]>
Co-authored-by: Kevin Axel <[email protected]>
Co-authored-by: stonepage <[email protected]>
3 people authored Sep 19, 2023
1 parent 2673a39 commit 9f0c199
Showing 10 changed files with 156 additions and 12 deletions.
29 changes: 29 additions & 0 deletions e2e_test/udf/udf.slt
@@ -224,6 +224,35 @@ select (extract_tcp_info(E'\\x45000034a8a8400040065b8ac0a8000ec0a80001035d20b6d9
----
192.168.0.14 192.168.0.1 861 8374

# streaming
# to ensure UDF & UDTF respect visibility
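# (the `where x <> 2` filter keeps the filtered row in the stream chunk as an
# invisible row instead of removing it, so `gcd` and `series` must honor the
# visibility bitmap rather than evaluate every physical row)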

statement ok
create table t (x int);

statement ok
create materialized view mv as select gcd(x, x), series(x) from t where x <> 2;

statement ok
insert into t values (1), (2), (3);

statement ok
flush;

query II
select * from mv;
----
1 0
3 0
3 1
3 2

statement ok
drop materialized view mv;

statement ok
drop table t;

# error handling

statement error
6 changes: 4 additions & 2 deletions src/common/src/array/arrow.rs
@@ -27,6 +27,7 @@ use crate::util::iter_util::ZipEqDebug;

// Implement bi-directional `From` between `DataChunk` and `arrow_array::RecordBatch`.

// note: DataChunk -> arrow RecordBatch will IGNORE the visibilities.
impl TryFrom<&DataChunk> for arrow_array::RecordBatch {
type Error = ArrayError;

@@ -47,8 +48,9 @@ impl TryFrom<&DataChunk> for arrow_array::RecordBatch {
.collect();

let schema = Arc::new(Schema::new(fields));

arrow_array::RecordBatch::try_new(schema, columns)
let opts =
arrow_array::RecordBatchOptions::default().with_row_count(Some(chunk.capacity()));
arrow_array::RecordBatch::try_new_with_options(schema, columns, &opts)
.map_err(|err| ArrayError::ToArrow(err.to_string()))
}
}
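Why the explicit row count matters here: `RecordBatch::try_new` cannot infer a row count when a chunk has zero columns, while a `DataChunk` can still report a positive capacity. A minimal standalone sketch against arrow-rs (not part of the commit) illustrating the difference:

use std::sync::Arc;
use arrow_array::{RecordBatch, RecordBatchOptions};
use arrow_schema::Schema;

fn main() {
    let schema = Arc::new(Schema::empty());
    // With zero columns, `try_new` has no way to infer the number of rows.
    assert!(RecordBatch::try_new(schema.clone(), vec![]).is_err());
    // An explicit row count fixes that, hence the `chunk.capacity()` above.
    let opts = RecordBatchOptions::default().with_row_count(Some(3));
    let batch = RecordBatch::try_new_with_options(schema, vec![], &opts).unwrap();
    assert_eq!(batch.num_rows(), 3);
}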
23 changes: 23 additions & 0 deletions src/common/src/array/data_chunk.rs
@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::borrow::Cow;
use std::hash::BuildHasher;
use std::sync::Arc;
use std::{fmt, usize};
@@ -261,6 +262,28 @@ impl DataChunk {
}
}

/// Convert the chunk to compact format.
///
/// If the chunk is not compacted, return a new compacted chunk, otherwise return a reference to
/// self.
pub fn compact_cow(&self) -> Cow<'_, Self> {
match &self.vis2 {
Vis::Compact(_) => Cow::Borrowed(self),
Vis::Bitmap(visibility) => {
let cardinality = visibility.count_ones();
let columns = self
.columns
.iter()
.map(|col| col.compact(visibility, cardinality).into())
.collect::<Vec<_>>();
Cow::Owned(Self::new(columns, cardinality))
}
}
}

pub fn from_protobuf(proto: &PbDataChunk) -> ArrayResult<Self> {
let mut columns = vec![];
for any_col in proto.get_columns() {
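A usage sketch for `compact_cow` (standalone; `from_pretty` is RisingWave's test-utility constructor, which builds a fully-visible chunk):

use std::borrow::Cow;
use risingwave_common::array::DataChunk;

fn main() {
    let chunk = DataChunk::from_pretty(
        "i
         1
         2
         3",
    );
    // A fully-visible chunk is already compact, so no copy is made.
    assert!(matches!(chunk.compact_cow(), Cow::Borrowed(_)));
    // Had a visibility bitmap hidden some rows, `compact_cow` would return
    // `Cow::Owned` holding only the visible rows (cardinality = count of ones).
}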
8 changes: 8 additions & 0 deletions src/common/src/types/mod.rs
@@ -431,6 +431,14 @@ impl DataType {
}
d
}

/// Compares the datatype with another, ignoring nested field names and metadata.
pub fn equals_datatype(&self, other: &DataType) -> bool {
match (self, other) {
(Self::Struct(s1), Self::Struct(s2)) => s1.equals_datatype(s2),
_ => self == other,
}
}
}

impl From<DataType> for PbDataType {
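The comparison semantics in a nutshell (standalone sketch):

use risingwave_common::types::DataType;

fn main() {
    // Non-struct types compare by plain equality.
    assert!(DataType::Int32.equals_datatype(&DataType::Int32));
    assert!(!DataType::Int32.equals_datatype(&DataType::Int64));
    // Struct types recurse into `StructType::equals_datatype`, which ignores
    // field names and compares field types only (sketched after the next hunk).
}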
10 changes: 10 additions & 0 deletions src/common/src/types/struct_type.rs
@@ -117,6 +117,16 @@ impl StructType {
.chain(std::iter::repeat("").take(self.0.field_types.len() - self.0.field_names.len()))
.zip_eq_debug(self.0.field_types.iter())
}

/// Compares the datatype with another, ignoring nested field names and metadata.
pub fn equals_datatype(&self, other: &StructType) -> bool {
if self.0.field_types.len() != other.0.field_types.len() {
return false;
}
(self.0.field_types.iter())
.zip_eq_fast(other.0.field_types.iter())
.all(|(a, b)| a.equals_datatype(b))
}
}

impl Display for StructType {
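And the struct case (the exact `StructType::new` constructor shape here is an assumption for illustration; only the comparison itself is from this commit):

use risingwave_common::types::{DataType, StructType};

fn main() {
    // Hypothetical constructor calls: same field types, different field names.
    let a = StructType::new(vec![("x", DataType::Int32), ("y", DataType::Varchar)]);
    let b = StructType::new(vec![("u", DataType::Int32), ("v", DataType::Varchar)]);
    // Field names are ignored; field types are compared pairwise.
    assert!(a.equals_datatype(&b));
}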
7 changes: 7 additions & 0 deletions src/expr/src/expr/expr_udf.rs
@@ -103,6 +103,13 @@ impl UdfExpression {
};
let mut array = ArrayImpl::try_from(arrow_array)?;
array.set_bitmap(array.null_bitmap() & vis);
if !array.data_type().equals_datatype(&self.return_type) {
bail!(
"UDF returned {:?}, but expected {:?}",
array.data_type(),
self.return_type,
);
}
Ok(Arc::new(array))
}
}
4 changes: 2 additions & 2 deletions src/expr/src/table_function/mod.rs
@@ -49,7 +49,7 @@ pub trait TableFunction: std::fmt::Debug + Sync + Send {
/// # Contract of the output
///
/// The returned `DataChunk` contains exact two columns:
/// - The first column is an I32Array containing row indexes of input chunk. It should be
/// - The first column is an I32Array containing row indices of input chunk. It should be
/// monotonically increasing.
/// - The second column is the output values. The data type of the column is `return_type`.
///
@@ -80,7 +80,7 @@ pub trait TableFunction: std::fmt::Debug + Sync + Send {
/// (You don't need to understand this section to implement a `TableFunction`)
///
/// The output of the `TableFunction` is different from the output of the `ProjectSet` executor.
/// `ProjectSet` executor uses the row indexes to stitch multiple table functions and produces
/// `ProjectSet` executor uses the row indices to stitch multiple table functions and produces
/// `projected_row_id`.
///
/// ## Example
69 changes: 61 additions & 8 deletions src/expr/src/table_function/user_defined.rs
@@ -14,9 +14,10 @@

use std::sync::Arc;

use arrow_array::RecordBatch;
use arrow_schema::{Field, Fields, Schema, SchemaRef};
use futures_util::stream;
use risingwave_common::array::DataChunk;
use risingwave_common::array::{DataChunk, I32Array};
use risingwave_common::bail;
use risingwave_udf::ArrowFlightUdfClient;

@@ -25,6 +26,7 @@ use super::*;
#[derive(Debug)]
pub struct UserDefinedTableFunction {
children: Vec<BoxedExpression>,
#[allow(dead_code)]
arg_schema: SchemaRef,
return_type: DataType,
client: Arc<ArrowFlightUdfClient>,
@@ -49,27 +51,78 @@ impl UserDefinedTableFunction {
impl UserDefinedTableFunction {
#[try_stream(boxed, ok = DataChunk, error = ExprError)]
async fn eval_inner<'a>(&'a self, input: &'a DataChunk) {
// evaluate children expressions
let mut columns = Vec::with_capacity(self.children.len());
for c in &self.children {
let val = c.eval_checked(input).await?.as_ref().try_into()?;
let val = c.eval_checked(input).await?;
columns.push(val);
}
let direct_input = DataChunk::new(columns, input.vis().clone());

// compact the input chunk and record the row mapping
let visible_rows = direct_input.vis().iter_ones().collect_vec();
let compacted_input = direct_input.compact_cow();
let arrow_input = RecordBatch::try_from(compacted_input.as_ref())?;

let opts =
arrow_array::RecordBatchOptions::default().with_row_count(Some(input.cardinality()));
let input =
arrow_array::RecordBatch::try_new_with_options(self.arg_schema.clone(), columns, &opts)
.expect("failed to build record batch");
// call UDTF
#[for_await]
for res in self
.client
.call_stream(&self.identifier, stream::once(async { input }))
.call_stream(&self.identifier, stream::once(async { arrow_input }))
.await?
{
let output = DataChunk::try_from(&res?)?;
self.check_output(&output)?;

// we send the compacted input to UDF, so we need to map the row indices back to the
// original input
let origin_indices = output
.column_at(0)
.as_int32()
.raw_iter()
// we have checked all indices are non-negative
.map(|idx| visible_rows[idx as usize] as i32)
.collect::<I32Array>();

let output = DataChunk::new(
vec![origin_indices.into_ref(), output.column_at(1).clone()],
output.vis().clone(),
);
yield output;
}
}

/// Check if the output chunk is valid.
fn check_output(&self, output: &DataChunk) -> Result<()> {
if output.columns().len() != 2 {
bail!(
"UDF returned {} columns, but expected 2",
output.columns().len()
);
}
if output.column_at(0).data_type() != DataType::Int32 {
bail!(
"UDF returned {:?} at column 0, but expected {:?}",
output.column_at(0).data_type(),
DataType::Int32,
);
}
if output.column_at(0).as_int32().raw_iter().any(|i| i < 0) {
bail!("UDF returned negative row index");
}
if !output
.column_at(1)
.data_type()
.equals_datatype(&self.return_type)
{
bail!(
"UDF returned {:?} at column 1, but expected {:?}",
output.column_at(1).data_type(),
&self.return_type,
);
}
Ok(())
}
}

#[cfg(not(madsim))]
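The heart of this fix is the index remapping in `eval_inner` above: the UDTF only ever sees the compacted input, so the row indices it returns address visible rows and must be translated back to positions in the original chunk. A standalone sketch of just that arithmetic (plain Rust, hypothetical values):

fn main() {
    // Rows 0 and 2 of a 3-row input are visible; the UDTF sees a 2-row chunk.
    let visible_rows: Vec<usize> = vec![0, 2]; // from `vis().iter_ones()`
    // Say the UDTF emitted two rows for its row 0 and one row for its row 1.
    let output_indices: Vec<i32> = vec![0, 0, 1];
    let origin_indices: Vec<i32> = output_indices
        .iter()
        .map(|&idx| visible_rows[idx as usize] as i32)
        .collect();
    assert_eq!(origin_indices, vec![0, 0, 2]);
}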
3 changes: 3 additions & 0 deletions src/udf/src/error.rs
@@ -40,6 +40,9 @@ pub enum Error {

#[error("UDF service returned no data")]
NoReturned,

#[error("Flight service error: {0}")]
ServiceError(String),
}

static_assertions::const_assert_eq!(std::mem::size_of::<Error>(), 32);
9 changes: 9 additions & 0 deletions src/udf/src/external.rs
@@ -49,6 +49,15 @@ impl ArrowFlightUdfClient {
let input_num = info.total_records as usize;
let full_schema = Schema::try_from(info)
.map_err(|e| FlightError::DecodeError(format!("Error decoding schema: {e}")))?;
if input_num > full_schema.fields.len() {
return Err(Error::ServiceError(format!(
"function {:?} schema info not consistency: input_num: {}, total_fields: {}",
id,
input_num,
full_schema.fields.len()
)));
}

let (input_fields, return_fields) = full_schema.fields.split_at(input_num);
let actual_input_types: Vec<_> = input_fields.iter().map(|f| f.data_type()).collect();
let actual_result_types: Vec<_> = return_fields.iter().map(|f| f.data_type()).collect();
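For context: the Flight schema returned by the server packs the input fields followed by the return fields into a single field list, and `FlightInfo::total_records` is repurposed to carry the number of input fields. Since `split_at` panics when the split point exceeds the slice length, this new guard turns a would-be panic on inconsistent server metadata into a proper `ServiceError`.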
