Skip to content

Commit

Permalink
feat(udf): don't send invisible rows to UDF server (#12486)
Browse files Browse the repository at this point in the history
Signed-off-by: Kevin Axel <[email protected]>
  • Loading branch information
KveinAxel authored Oct 17, 2023
1 parent c256098 commit 203cac8
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 21 deletions.
44 changes: 44 additions & 0 deletions src/common/src/array/data_chunk.rs
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,50 @@ impl DataChunk {
Self::new(columns, Bitmap::ones(cardinality))
}

/// Scatter the rows of this (compacted) chunk back to the positions marked
/// visible in `vis`, padding every invisible position with NULL.
///
/// `self` is expected to hold exactly `vis.count_ones()` rows; row `i` of
/// `self` is written at the position of the `i`-th set bit of `vis`.
/// The returned chunk has `vis.len()` rows and carries `vis` as its
/// visibility bitmap.
pub fn uncompact(self, vis: Bitmap) -> Self {
    let mut builders: Vec<_> = self
        .columns
        .iter()
        .map(|col| col.create_builder(vis.len()))
        .collect();

    // `next_row` is the next output position to fill; every position between
    // it and the current visible position receives NULL padding.
    let mut next_row = 0;
    for (compacted_idx, pos) in vis.iter_ones().enumerate() {
        // Pad the gap of invisible rows before this visible row.
        while next_row < pos {
            builders.iter_mut().for_each(|b| b.append_null());
            next_row += 1;
        }
        // Copy the compacted row into its original position.
        builders
            .iter_mut()
            .zip_eq_fast(self.columns.iter())
            .for_each(|(b, col)| b.append(col.datum_at(compacted_idx)));
        next_row = pos + 1;
    }
    // Pad the tail of invisible rows after the last visible one.
    while next_row < vis.len() {
        builders.iter_mut().for_each(|b| b.append_null());
        next_row += 1;
    }

    let columns: Vec<_> = builders
        .into_iter()
        .map(|b| Arc::new(b.finish()))
        .collect();

    Self::new(columns, vis)
}

/// Convert the chunk to compact format.
///
/// If the chunk is not compacted, return a new compacted chunk, otherwise return a reference to self.
Expand Down
59 changes: 38 additions & 21 deletions src/expr/core/src/expr/expr_udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@ use std::collections::HashMap;
use std::convert::TryFrom;
use std::sync::{Arc, LazyLock, Mutex, Weak};

use arrow_schema::{Field, Fields, Schema, SchemaRef};
use arrow_schema::{Field, Fields, Schema};
use await_tree::InstrumentAwait;
use cfg_or_panic::cfg_or_panic;
use risingwave_common::array::{ArrayImpl, ArrayRef, DataChunk};
use risingwave_common::array::{ArrayRef, DataChunk};
use risingwave_common::row::OwnedRow;
use risingwave_common::types::{DataType, Datum};
use risingwave_pb::expr::ExprNode;
Expand All @@ -34,7 +34,7 @@ pub struct UdfExpression {
children: Vec<BoxedExpression>,
arg_types: Vec<DataType>,
return_type: DataType,
arg_schema: SchemaRef,
arg_schema: Arc<Schema>,
client: Arc<ArrowFlightUdfClient>,
identifier: String,
span: await_tree::Span,
Expand All @@ -52,7 +52,7 @@ impl Expression for UdfExpression {
let mut columns = Vec::with_capacity(self.children.len());
for child in &self.children {
let array = child.eval(input).await?;
columns.push(array.as_ref().try_into()?);
columns.push(array);
}
self.eval_inner(columns, vis).await
}
Expand All @@ -66,11 +66,7 @@ impl Expression for UdfExpression {
}
let arg_row = OwnedRow::new(columns);
let chunk = DataChunk::from_rows(std::slice::from_ref(&arg_row), &self.arg_types);
let arg_columns = chunk
.columns()
.iter()
.map::<Result<_>, _>(|c| Ok(c.as_ref().try_into()?))
.try_collect()?;
let arg_columns = chunk.columns().to_vec();
let output_array = self.eval_inner(arg_columns, chunk.visibility()).await?;
Ok(output_array.to_datum())
}
Expand All @@ -79,38 +75,58 @@ impl Expression for UdfExpression {
impl UdfExpression {
async fn eval_inner(
&self,
columns: Vec<arrow_array::ArrayRef>,
columns: Vec<ArrayRef>,
vis: &risingwave_common::buffer::Bitmap,
) -> Result<ArrayRef> {
let opts = arrow_array::RecordBatchOptions::default().with_row_count(Some(vis.len()));
let input =
arrow_array::RecordBatch::try_new_with_options(self.arg_schema.clone(), columns, &opts)
.expect("failed to build record batch");
let chunk = DataChunk::new(columns, vis.clone());
let compacted_chunk = chunk.compact_cow();
let compacted_columns: Vec<arrow_array::ArrayRef> = compacted_chunk
.columns()
.iter()
.map(|c| {
c.as_ref()
.try_into()
.expect("failed covert ArrayRef to arrow_array::ArrayRef")
})
.collect();
let opts =
arrow_array::RecordBatchOptions::default().with_row_count(Some(vis.count_ones()));
let input = arrow_array::RecordBatch::try_new_with_options(
self.arg_schema.clone(),
compacted_columns,
&opts,
)
.expect("failed to build record batch");

let output = self
.client
.call(&self.identifier, input)
.instrument_await(self.span.clone())
.await?;
if output.num_rows() != vis.len() {
if output.num_rows() != vis.count_ones() {
bail!(
"UDF returned {} rows, but expected {}",
output.num_rows(),
vis.len(),
);
}
let Some(arrow_array) = output.columns().get(0) else {

let data_chunk =
DataChunk::try_from(&output).expect("failed to convert UDF output to DataChunk");
let output = data_chunk.uncompact(vis.clone());

let Some(array) = output.columns().get(0) else {
bail!("UDF returned no columns");
};
let mut array = ArrayImpl::try_from(arrow_array)?;
array.set_bitmap(array.null_bitmap() & vis);
if !array.data_type().equals_datatype(&self.return_type) {
bail!(
"UDF returned {:?}, but expected {:?}",
array.data_type(),
self.return_type,
);
}
Ok(Arc::new(array))

Ok(array.clone())
}
}

Expand All @@ -123,6 +139,9 @@ impl Build for UdfExpression {
let return_type = DataType::from(prost.get_return_type().unwrap());
let udf = prost.get_rex_node().unwrap().as_udf().unwrap();

// connect to UDF service
let client = get_or_create_client(&udf.link)?;

let arg_schema = Arc::new(Schema::new(
udf.arg_types
.iter()
Expand All @@ -137,8 +156,6 @@ impl Build for UdfExpression {
})
.try_collect::<Fields>()?,
));
// connect to UDF service
let client = get_or_create_client(&udf.link)?;

Ok(Self {
children: udf.children.iter().map(build_child).try_collect()?,
Expand Down

0 comments on commit 203cac8

Please sign in to comment.