Skip to content

Commit

Permalink
fix(udf): handle visibility of input chunks in UDTF (#12357)
Browse files Browse the repository at this point in the history
Signed-off-by: Runji Wang <[email protected]>
  • Loading branch information
wangrunji0408 committed Sep 18, 2023
1 parent 2ea80bd commit 1616b82
Show file tree
Hide file tree
Showing 5 changed files with 88 additions and 12 deletions.
29 changes: 29 additions & 0 deletions e2e_test/udf/udf.slt
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,35 @@ select (extract_tcp_info(E'\\x45000034a8a8400040065b8ac0a8000ec0a80001035d20b6d9
----
192.168.0.14 192.168.0.1 861 8374

# steaming
# to ensure UDF & UDTF respect visibility

statement ok
create table t (x int);

statement ok
create materialized view mv as select gcd(x, x), series(x) from t where x <> 2;

statement ok
insert into t values (1), (2), (3);

statement ok
flush;

query II
select * from mv;
----
1 0
3 0
3 1
3 2

statement ok
drop materialized view mv;

statement ok
drop table t;

# error handling

statement error
Expand Down
6 changes: 4 additions & 2 deletions src/common/src/array/arrow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ use crate::util::iter_util::ZipEqDebug;

// Implement bi-directional `From` between `DataChunk` and `arrow_array::RecordBatch`.

// note: DataChunk -> arrow RecordBatch will IGNORE the visibilities.
impl TryFrom<&DataChunk> for arrow_array::RecordBatch {
type Error = ArrayError;

Expand All @@ -47,8 +48,9 @@ impl TryFrom<&DataChunk> for arrow_array::RecordBatch {
.collect();

let schema = Arc::new(Schema::new(fields));

arrow_array::RecordBatch::try_new(schema, columns)
let opts =
arrow_array::RecordBatchOptions::default().with_row_count(Some(chunk.capacity()));
arrow_array::RecordBatch::try_new_with_options(schema, columns, &opts)
.map_err(|err| ArrayError::ToArrow(err.to_string()))
}
}
Expand Down
23 changes: 23 additions & 0 deletions src/common/src/array/data_chunk.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::borrow::Cow;
use std::fmt::Display;
use std::hash::BuildHasher;
use std::sync::Arc;
use std::{fmt, usize};
Expand Down Expand Up @@ -261,6 +263,27 @@ impl DataChunk {
}
}

/// Convert the chunk to compact format.
///
/// If the chunk is not compacted, return a new compacted chunk, otherwise return a reference to self.
pub fn compact_cow(&self) -> Cow<'_, Self> {
match &self.vis2 {
Vis::Compact(_) => Cow::Borrowed(self),
Vis::Bitmap(visibility) => {
let cardinality = visibility.count_ones();
let columns = self
.columns
.iter()
.map(|col| {
let array = col;
array.compact(visibility, cardinality).into()
})
.collect::<Vec<_>>();
Cow::Owned(Self::new(columns, cardinality))
}
}
}

pub fn from_protobuf(proto: &PbDataChunk) -> ArrayResult<Self> {
let mut columns = vec![];
for any_col in proto.get_columns() {
Expand Down
4 changes: 2 additions & 2 deletions src/expr/src/table_function/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ pub trait TableFunction: std::fmt::Debug + Sync + Send {
/// # Contract of the output
///
/// The returned `DataChunk` contains exact two columns:
/// - The first column is an I32Array containing row indexes of input chunk. It should be
/// - The first column is an I32Array containing row indices of input chunk. It should be
/// monotonically increasing.
/// - The second column is the output values. The data type of the column is `return_type`.
///
Expand Down Expand Up @@ -80,7 +80,7 @@ pub trait TableFunction: std::fmt::Debug + Sync + Send {
/// (You don't need to understand this section to implement a `TableFunction`)
///
/// The output of the `TableFunction` is different from the output of the `ProjectSet` executor.
/// `ProjectSet` executor uses the row indexes to stitch multiple table functions and produces
/// `ProjectSet` executor uses the row indices to stitch multiple table functions and produces
/// `projected_row_id`.
///
/// ## Example
Expand Down
38 changes: 30 additions & 8 deletions src/expr/src/table_function/user_defined.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,10 @@

use std::sync::Arc;

use arrow_array::RecordBatch;
use arrow_schema::{Field, Fields, Schema, SchemaRef};
use futures_util::stream;
use risingwave_common::array::DataChunk;
use risingwave_common::array::{DataChunk, I32Array};
use risingwave_common::bail;
use risingwave_udf::ArrowFlightUdfClient;

Expand All @@ -25,6 +26,7 @@ use super::*;
#[derive(Debug)]
pub struct UserDefinedTableFunction {
children: Vec<BoxedExpression>,
#[allow(dead_code)]
arg_schema: SchemaRef,
return_type: DataType,
client: Arc<ArrowFlightUdfClient>,
Expand All @@ -49,25 +51,42 @@ impl TableFunction for UserDefinedTableFunction {
impl UserDefinedTableFunction {
#[try_stream(boxed, ok = DataChunk, error = ExprError)]
async fn eval_inner<'a>(&'a self, input: &'a DataChunk) {
// evaluate children expressions
let mut columns = Vec::with_capacity(self.children.len());
for c in &self.children {
let val = c.eval_checked(input).await?.as_ref().try_into()?;
let val = c.eval_checked(input).await?;
columns.push(val);
}
let direct_input = DataChunk::new(columns, input.vis().clone());

// compact the input chunk and record the row mapping
let visible_rows = direct_input.vis().iter_ones().collect_vec();
let compacted_input = direct_input.compact_cow();
let arrow_input = RecordBatch::try_from(compacted_input.as_ref())?;

let opts =
arrow_array::RecordBatchOptions::default().with_row_count(Some(input.cardinality()));
let input =
arrow_array::RecordBatch::try_new_with_options(self.arg_schema.clone(), columns, &opts)
.expect("failed to build record batch");
// call UDTF
#[for_await]
for res in self
.client
.call_stream(&self.identifier, stream::once(async { input }))
.call_stream(&self.identifier, stream::once(async { arrow_input }))
.await?
{
let output = DataChunk::try_from(&res?)?;
self.check_output(&output)?;

// we send the compacted input to UDF, so we need to map the row indices back to the original input
let origin_indices = output
.column_at(0)
.as_int32()
.raw_iter()
// we have checked all indices are non-negative
.map(|idx| visible_rows[idx as usize] as i32)
.collect::<I32Array>();

let output = DataChunk::new(
vec![origin_indices.into_ref(), output.column_at(1).clone()],
output.vis().clone(),
);
yield output;
}
}
Expand All @@ -87,6 +106,9 @@ impl UserDefinedTableFunction {
DataType::Int32,
);
}
if output.column_at(0).as_int32().raw_iter().any(|i| i < 0) {
bail!("UDF returned negative row index");
}
if !output
.column_at(1)
.data_type()
Expand Down

0 comments on commit 1616b82

Please sign in to comment.