diff --git a/src/common/src/array/data_chunk.rs b/src/common/src/array/data_chunk.rs
index 58844ca2e437..98a237814176 100644
--- a/src/common/src/array/data_chunk.rs
+++ b/src/common/src/array/data_chunk.rs
@@ -245,6 +245,50 @@ impl DataChunk {
         Self::new(columns, Bitmap::ones(cardinality))
     }
 
+    pub fn uncompact(self, vis: Bitmap) -> Self {
+        let mut uncompact_builders: Vec<_> = self
+            .columns
+            .iter()
+            .map(|c| c.create_builder(vis.len()))
+            .collect();
+        let mut last_u = None;
+
+        for (idx, u) in vis.iter_ones().enumerate() {
+            // pad invisible rows with NULL
+            let zeros = if let Some(last_u) = last_u {
+                u - last_u - 1
+            } else {
+                u
+            };
+            for _ in 0..zeros {
+                uncompact_builders
+                    .iter_mut()
+                    .for_each(|builder| builder.append_null());
+            }
+            uncompact_builders
+                .iter_mut()
+                .zip_eq_fast(self.columns.iter())
+                .for_each(|(builder, c)| builder.append(c.datum_at(idx)));
+            last_u = Some(u);
+        }
+        let zeros = if let Some(last_u) = last_u {
+            vis.len() - last_u - 1
+        } else {
+            vis.len()
+        };
+        for _ in 0..zeros {
+            uncompact_builders
+                .iter_mut()
+                .for_each(|builder| builder.append_null());
+        }
+        let array: Vec<_> = uncompact_builders
+            .into_iter()
+            .map(|builder| Arc::new(builder.finish()))
+            .collect();
+
+        Self::new(array, vis)
+    }
+
     /// Convert the chunk to compact format.
     ///
     /// If the chunk is not compacted, return a new compacted chunk, otherwise return a reference to self.
diff --git a/src/expr/core/src/expr/expr_udf.rs b/src/expr/core/src/expr/expr_udf.rs
index 54c49317b748..a11af2434b4f 100644
--- a/src/expr/core/src/expr/expr_udf.rs
+++ b/src/expr/core/src/expr/expr_udf.rs
@@ -16,10 +16,10 @@ use std::collections::HashMap;
 use std::convert::TryFrom;
 use std::sync::{Arc, LazyLock, Mutex, Weak};
 
-use arrow_schema::{Field, Fields, Schema, SchemaRef};
+use arrow_schema::{Field, Fields, Schema};
 use await_tree::InstrumentAwait;
 use cfg_or_panic::cfg_or_panic;
-use risingwave_common::array::{ArrayImpl, ArrayRef, DataChunk};
+use risingwave_common::array::{ArrayRef, DataChunk};
 use risingwave_common::row::OwnedRow;
 use risingwave_common::types::{DataType, Datum};
 use risingwave_pb::expr::ExprNode;
@@ -34,7 +34,7 @@ pub struct UdfExpression {
     children: Vec<BoxedExpression>,
     arg_types: Vec<DataType>,
     return_type: DataType,
-    arg_schema: SchemaRef,
+    arg_schema: Arc<Schema>,
     client: Arc<ArrowFlightUdfClient>,
     identifier: String,
     span: await_tree::Span,
@@ -52,7 +52,7 @@ impl Expression for UdfExpression {
         let mut columns = Vec::with_capacity(self.children.len());
         for child in &self.children {
             let array = child.eval(input).await?;
-            columns.push(array.as_ref().try_into()?);
+            columns.push(array);
         }
         self.eval_inner(columns, vis).await
     }
@@ -66,11 +66,7 @@
         }
         let arg_row = OwnedRow::new(columns);
         let chunk = DataChunk::from_rows(std::slice::from_ref(&arg_row), &self.arg_types);
-        let arg_columns = chunk
-            .columns()
-            .iter()
-            .map::<Result<arrow_array::ArrayRef>, _>(|c| Ok(c.as_ref().try_into()?))
-            .try_collect()?;
+        let arg_columns = chunk.columns().to_vec();
         let output_array = self.eval_inner(arg_columns, chunk.visibility()).await?;
         Ok(output_array.to_datum())
     }
@@ -79,30 +75,49 @@ impl UdfExpression {
     async fn eval_inner(
         &self,
-        columns: Vec<arrow_array::ArrayRef>,
+        columns: Vec<ArrayRef>,
         vis: &risingwave_common::buffer::Bitmap,
     ) -> Result<ArrayRef> {
-        let opts = arrow_array::RecordBatchOptions::default().with_row_count(Some(vis.len()));
-        let input =
-            arrow_array::RecordBatch::try_new_with_options(self.arg_schema.clone(), columns, &opts)
-                .expect("failed to build record batch");
+        let chunk = DataChunk::new(columns, vis.clone());
+        let compacted_chunk = chunk.compact_cow();
+        let compacted_columns: Vec<arrow_array::ArrayRef> = compacted_chunk
+            .columns()
+            .iter()
+            .map(|c| {
+                c.as_ref()
+                    .try_into()
+                    .expect("failed to convert ArrayRef to arrow_array::ArrayRef")
+            })
+            .collect();
+        let opts =
+            arrow_array::RecordBatchOptions::default().with_row_count(Some(vis.count_ones()));
+        let input = arrow_array::RecordBatch::try_new_with_options(
+            self.arg_schema.clone(),
+            compacted_columns,
+            &opts,
+        )
+        .expect("failed to build record batch");
+
         let output = self
             .client
             .call(&self.identifier, input)
             .instrument_await(self.span.clone())
             .await?;
 
-        if output.num_rows() != vis.len() {
+        if output.num_rows() != vis.count_ones() {
             bail!(
                 "UDF returned {} rows, but expected {}",
                 output.num_rows(),
                 vis.len(),
            );
         }
-        let Some(arrow_array) = output.columns().get(0) else {
+
+        let data_chunk =
+            DataChunk::try_from(&output).expect("failed to convert UDF output to DataChunk");
+        let output = data_chunk.uncompact(vis.clone());
+
+        let Some(array) = output.columns().get(0) else {
             bail!("UDF returned no columns");
         };
-        let mut array = ArrayImpl::try_from(arrow_array)?;
-        array.set_bitmap(array.null_bitmap() & vis);
         if !array.data_type().equals_datatype(&self.return_type) {
             bail!(
                 "UDF returned {:?}, but expected {:?}",
@@ -110,7 +125,8 @@ impl UdfExpression {
                 self.return_type,
             );
         }
-        Ok(Arc::new(array))
+
+        Ok(array.clone())
     }
 }
 
@@ -123,6 +139,9 @@ impl Build for UdfExpression {
         let return_type = DataType::from(prost.get_return_type().unwrap());
         let udf = prost.get_rex_node().unwrap().as_udf().unwrap();
 
+        // connect to UDF service
+        let client = get_or_create_client(&udf.link)?;
+
         let arg_schema = Arc::new(Schema::new(
             udf.arg_types
                 .iter()
                 .map(|t| { ... })
                 .try_collect::<_, Fields, _>()?,
         ));
-        // connect to UDF service
-        let client = get_or_create_client(&udf.link)?;
 
         Ok(Self {
             children: udf.children.iter().map(build_child).try_collect()?,