Skip to content

Commit

Permalink
fix(udf): fix "udf returned no data" error (#15076) (#15319)
Browse files Browse the repository at this point in the history
  • Loading branch information
wangrunji0408 authored Feb 28, 2024
1 parent 336db6b commit 9e9d549
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 56 deletions.
63 changes: 24 additions & 39 deletions src/expr/core/src/expr/expr_udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ pub struct UserDefinedFunction {
children: Vec<BoxedExpression>,
arg_types: Vec<DataType>,
return_type: DataType,
#[expect(dead_code)]
arg_schema: Arc<Schema>,
imp: UdfImpl,
identifier: String,
Expand Down Expand Up @@ -75,13 +76,19 @@ impl Expression for UserDefinedFunction {

#[cfg_or_panic(not(madsim))]
async fn eval(&self, input: &DataChunk) -> Result<ArrayRef> {
let vis = input.visibility();
if input.cardinality() == 0 {
// early return for empty input
let mut builder = self.return_type.create_array_builder(input.capacity());
builder.append_n_null(input.capacity());
return Ok(builder.finish().into_ref());
}
let mut columns = Vec::with_capacity(self.children.len());
for child in &self.children {
let array = child.eval(input).await?;
columns.push(array);
}
self.eval_inner(columns, vis).await
let chunk = DataChunk::new(columns, input.visibility().clone());
self.eval_inner(&chunk).await
}

#[cfg_or_panic(not(madsim))]
Expand All @@ -93,57 +100,35 @@ impl Expression for UserDefinedFunction {
}
let arg_row = OwnedRow::new(columns);
let chunk = DataChunk::from_rows(std::slice::from_ref(&arg_row), &self.arg_types);
let arg_columns = chunk.columns().to_vec();
let output_array = self.eval_inner(arg_columns, chunk.visibility()).await?;
let output_array = self.eval_inner(&chunk).await?;
Ok(output_array.to_datum())
}
}

impl UserDefinedFunction {
async fn eval_inner(
&self,
columns: Vec<ArrayRef>,
vis: &risingwave_common::buffer::Bitmap,
) -> Result<ArrayRef> {
let chunk = DataChunk::new(columns, vis.clone());
let compacted_chunk = chunk.compact_cow();
let compacted_columns: Vec<arrow_array::ArrayRef> = compacted_chunk
.columns()
.iter()
.map(|c| {
c.as_ref()
.try_into()
.expect("failed covert ArrayRef to arrow_array::ArrayRef")
})
.collect();
let opts = arrow_array::RecordBatchOptions::default()
.with_row_count(Some(compacted_chunk.capacity()));
let input = arrow_array::RecordBatch::try_new_with_options(
self.arg_schema.clone(),
compacted_columns,
&opts,
)
.expect("failed to build record batch");
async fn eval_inner(&self, input: &DataChunk) -> Result<ArrayRef> {
// this will drop invisible rows
let arrow_input = arrow_array::RecordBatch::try_from(input)?;

let output: arrow_array::RecordBatch = match &self.imp {
UdfImpl::Wasm(runtime) => runtime.call(&self.identifier, &input)?,
UdfImpl::JavaScript(runtime) => runtime.call(&self.identifier, &input)?,
let arrow_output: arrow_array::RecordBatch = match &self.imp {
UdfImpl::Wasm(runtime) => runtime.call(&self.identifier, &arrow_input)?,
UdfImpl::JavaScript(runtime) => runtime.call(&self.identifier, &arrow_input)?,
UdfImpl::External(client) => {
let disable_retry_count = self.disable_retry_count.load(Ordering::Relaxed);
let result = if self.always_retry_on_network_error {
client
.call_with_always_retry_on_network_error(&self.identifier, input)
.call_with_always_retry_on_network_error(&self.identifier, arrow_input)
.instrument_await(self.span.clone())
.await
} else {
let result = if disable_retry_count != 0 {
client
.call(&self.identifier, input)
.call(&self.identifier, arrow_input)
.instrument_await(self.span.clone())
.await
} else {
client
.call_with_retry(&self.identifier, input)
.call_with_retry(&self.identifier, arrow_input)
.instrument_await(self.span.clone())
.await
};
Expand All @@ -167,16 +152,16 @@ impl UserDefinedFunction {
result?
}
};
if output.num_rows() != vis.count_ones() {
if arrow_output.num_rows() != input.cardinality() {
bail!(
"UDF returned {} rows, but expected {}",
output.num_rows(),
vis.len(),
arrow_output.num_rows(),
input.cardinality(),
);
}

let data_chunk = DataChunk::try_from(&output)?;
let output = data_chunk.uncompact(vis.clone());
let output = DataChunk::try_from(&arrow_output)?;
let output = output.uncompact(input.visibility().clone());

let Some(array) = output.columns().first() else {
bail!("UDF returned no columns");
Expand Down
40 changes: 23 additions & 17 deletions src/expr/udf/src/external.rs
Original file line number Diff line number Diff line change
Expand Up @@ -175,21 +175,17 @@ impl ArrowFlightUdfClient {
}

async fn call_internal(&self, id: &str, input: RecordBatch) -> Result<RecordBatch> {
let mut output_stream = self.call_stream(id, stream::once(async { input })).await?;
// TODO: support no output
let head = output_stream
.next()
.await
.ok_or_else(Error::no_returned)??;
let remaining = output_stream.try_collect::<Vec<_>>().await?;
if remaining.is_empty() {
Ok(head)
} else {
Ok(arrow_select::concat::concat_batches(
&head.schema(),
std::iter::once(&head).chain(remaining.iter()),
)?)
let mut output_stream = self
.call_stream_internal(id, stream::once(async { input }))
.await?;
let mut batches = vec![];
while let Some(batch) = output_stream.next().await {
batches.push(batch?);
}
Ok(arrow_select::concat::concat_batches(
output_stream.schema().ok_or_else(Error::no_returned)?,
batches.iter(),
)?)
}

/// Call a function, retry up to 5 times / 3s if connection is broken.
Expand Down Expand Up @@ -234,6 +230,17 @@ impl ArrowFlightUdfClient {
id: &str,
inputs: impl Stream<Item = RecordBatch> + Send + 'static,
) -> Result<impl Stream<Item = Result<RecordBatch>> + Send + 'static> {
Ok(self
.call_stream_internal(id, inputs)
.await?
.map_err(|e| e.into()))
}

async fn call_stream_internal(
&self,
id: &str,
inputs: impl Stream<Item = RecordBatch> + Send + 'static,
) -> Result<FlightRecordBatchStream> {
let descriptor = FlightDescriptor::new_path(vec![id.into()]);
let flight_data_stream =
FlightDataEncoderBuilder::new()
Expand All @@ -249,11 +256,10 @@ impl ArrowFlightUdfClient {

// decode response
let stream = response.into_inner();
let record_batch_stream = FlightRecordBatchStream::new_from_flight_data(
Ok(FlightRecordBatchStream::new_from_flight_data(
// convert tonic::Status to FlightError
stream.map_err(|e| e.into()),
);
Ok(record_batch_stream.map_err(|e| e.into()))
))
}
}

Expand Down

0 comments on commit 9e9d549

Please sign in to comment.