From 4bd7c137e0e205140e273a7c25824c94b457c660 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Berkay=20=C5=9Eahin?= <124376117+berkaysynnada@users.noreply.github.com>
Date: Thu, 4 Apr 2024 12:30:16 +0300
Subject: [PATCH 1/4] CrossJoin Refactor (#9830)

* First iteration

* Wrap the logic inside function

* Send batches in the size of left batches

* Update cross_join.rs

* fuzz tests

* Update cross_join_fuzz.rs

* Update cross_join_fuzz.rs

* Test version 2

* Minor changes

* Minor changes

* Stateful implementation of CJ

* Adding comments

* Update cross_join_fuzz.rs

* Update cross_join.rs

* collect until batch size

* tmp

* revert changes

* Preserve the join strategy, clean the algorithm and states

* Update cross_join.rs

* Review

* Update cross_join.rs

---------

Co-authored-by: Mustafa Akur
Co-authored-by: Mehmet Ozan Kabak
---
 .../physical-plan/src/joins/cross_join.rs    | 142 ++++++++++++------
 1 file changed, 95 insertions(+), 47 deletions(-)

diff --git a/datafusion/physical-plan/src/joins/cross_join.rs b/datafusion/physical-plan/src/joins/cross_join.rs
index 19d34f8048e3..9d1de3715f54 100644
--- a/datafusion/physical-plan/src/joins/cross_join.rs
+++ b/datafusion/physical-plan/src/joins/cross_join.rs
@@ -22,14 +22,15 @@ use std::{any::Any, sync::Arc, task::Poll};
 
 use super::utils::{
     adjust_right_output_partitioning, BuildProbeJoinMetrics, OnceAsync, OnceFut,
+    StatefulStreamResult,
 };
 use crate::coalesce_batches::concat_batches;
 use crate::coalesce_partitions::CoalescePartitionsExec;
 use crate::metrics::{ExecutionPlanMetricsSet, MetricsSet};
-use crate::ExecutionPlanProperties;
 use crate::{
-    execution_mode_from_children, ColumnStatistics, DisplayAs, DisplayFormatType,
-    Distribution, ExecutionMode, ExecutionPlan, PlanProperties, RecordBatchStream,
+    execution_mode_from_children, handle_state, ColumnStatistics, DisplayAs,
+    DisplayFormatType, Distribution, ExecutionMode, ExecutionPlan,
+    ExecutionPlanProperties, PlanProperties, RecordBatchStream,
     SendableRecordBatchStream, Statistics,
 };
 
@@ -37,7 +38,7 @@ use arrow::datatypes::{Fields, Schema, SchemaRef};
 use arrow::record_batch::RecordBatch;
 use arrow_array::RecordBatchOptions;
 use datafusion_common::stats::Precision;
-use datafusion_common::{JoinType, Result, ScalarValue};
+use datafusion_common::{internal_err, JoinType, Result, ScalarValue};
 use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation};
 use datafusion_execution::TaskContext;
 use datafusion_physical_expr::equivalence::join_equivalence_properties;
@@ -257,9 +258,10 @@ impl ExecutionPlan for CrossJoinExec {
             schema: self.schema.clone(),
             left_fut,
             right: stream,
-            right_batch: Arc::new(parking_lot::Mutex::new(None)),
             left_index: 0,
             join_metrics,
+            state: CrossJoinStreamState::WaitBuildSide,
+            left_data: RecordBatch::new_empty(self.left().schema()),
         }))
     }
 
@@ -319,16 +321,18 @@ fn stats_cartesian_product(
 struct CrossJoinStream {
     /// Input schema
     schema: Arc<Schema>,
-    /// future for data from left side
+    /// Future for data from left side
     left_fut: OnceFut<JoinLeftData>,
-    /// right
+    /// Right side stream
     right: SendableRecordBatchStream,
     /// Current value on the left
     left_index: usize,
-    /// Current batch being processed from the right side
-    right_batch: Arc<parking_lot::Mutex<Option<RecordBatch>>>,
-    /// join execution metrics
+    /// Join execution metrics
    join_metrics: BuildProbeJoinMetrics,
+    /// State of the stream
+    state: CrossJoinStreamState,
+    /// Left data
+    left_data: RecordBatch,
 }
 
 impl RecordBatchStream for CrossJoinStream {
@@ -337,6 +341,25 @@ impl RecordBatchStream for CrossJoinStream {
     }
 }
 
+/// Represents states of CrossJoinStream
+enum CrossJoinStreamState {
+    WaitBuildSide,
+    FetchProbeBatch,
+    /// Holds the currently processed right side batch
+    BuildBatches(RecordBatch),
+}
+
+impl CrossJoinStreamState {
+    /// Tries to extract a RecordBatch from the CrossJoinStreamState enum.
+    /// Returns an error if the state is not BuildBatches.
+    fn try_as_record_batch(&mut self) -> Result<&RecordBatch> {
+        match self {
+            CrossJoinStreamState::BuildBatches(rb) => Ok(rb),
+            _ => internal_err!("Expected RecordBatch in BuildBatches state"),
+        }
+    }
+}
+
 fn build_batch(
     left_index: usize,
     batch: &RecordBatch,
@@ -384,58 +407,83 @@ impl CrossJoinStream {
         &mut self,
         cx: &mut std::task::Context<'_>,
     ) -> std::task::Poll<Option<Result<RecordBatch>>> {
+        loop {
+            return match self.state {
+                CrossJoinStreamState::WaitBuildSide => {
+                    handle_state!(ready!(self.collect_build_side(cx)))
+                }
+                CrossJoinStreamState::FetchProbeBatch => {
+                    handle_state!(ready!(self.fetch_probe_batch(cx)))
+                }
+                CrossJoinStreamState::BuildBatches(_) => {
+                    handle_state!(self.build_batches())
+                }
+            };
+        }
+    }
+
+    /// Collects the build (left) side of the join into the state. If the build
+    /// side is empty, execution terminates. Otherwise, the state is updated to
+    /// fetch a probe (right) batch.
+    fn collect_build_side(
+        &mut self,
+        cx: &mut std::task::Context<'_>,
+    ) -> Poll<Result<StatefulStreamResult<Option<RecordBatch>>>> {
         let build_timer = self.join_metrics.build_time.timer();
         let (left_data, _) = match ready!(self.left_fut.get(cx)) {
             Ok(left_data) => left_data,
-            Err(e) => return Poll::Ready(Some(Err(e))),
+            Err(e) => return Poll::Ready(Err(e)),
         };
         build_timer.done();
 
-        if left_data.num_rows() == 0 {
-            return Poll::Ready(None);
-        }
+        let result = if left_data.num_rows() == 0 {
+            StatefulStreamResult::Ready(None)
+        } else {
+            self.left_data = left_data.clone();
+            self.state = CrossJoinStreamState::FetchProbeBatch;
+            StatefulStreamResult::Continue
+        };
+        Poll::Ready(Ok(result))
+    }
+
+    /// Fetches the probe (right) batch, updates the metrics, and saves the
+    /// batch in the state. Then, the state is updated to build result batches.
+    fn fetch_probe_batch(
+        &mut self,
+        cx: &mut std::task::Context<'_>,
+    ) -> Poll<Result<StatefulStreamResult<Option<RecordBatch>>>> {
+        self.left_index = 0;
+        let right_data = match ready!(self.right.poll_next_unpin(cx)) {
+            Some(Ok(right_data)) => right_data,
+            Some(Err(e)) => return Poll::Ready(Err(e)),
+            None => return Poll::Ready(Ok(StatefulStreamResult::Ready(None))),
+        };
+        self.join_metrics.input_batches.add(1);
+        self.join_metrics.input_rows.add(right_data.num_rows());
+
+        self.state = CrossJoinStreamState::BuildBatches(right_data);
+        Poll::Ready(Ok(StatefulStreamResult::Continue))
+    }
 
-        if self.left_index > 0 && self.left_index < left_data.num_rows() {
+    /// Joins the indexed row of the left data with the current probe batch.
+    /// Once all results are produced, the state is set to fetch a new probe batch.
+    fn build_batches(&mut self) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
+        let right_batch = self.state.try_as_record_batch()?;
+        if self.left_index < self.left_data.num_rows() {
             let join_timer = self.join_metrics.join_time.timer();
-            let right_batch = {
-                let right_batch = self.right_batch.lock();
-                right_batch.clone().unwrap()
-            };
             let result =
-                build_batch(self.left_index, &right_batch, left_data, &self.schema);
-            self.join_metrics.input_rows.add(right_batch.num_rows());
+                build_batch(self.left_index, right_batch, &self.left_data, &self.schema);
+            join_timer.done();
+
             if let Ok(ref batch) = result {
-                join_timer.done();
                 self.join_metrics.output_batches.add(1);
                 self.join_metrics.output_rows.add(batch.num_rows());
             }
             self.left_index += 1;
-            return Poll::Ready(Some(result));
+            result.map(|r| StatefulStreamResult::Ready(Some(r)))
+        } else {
+            self.state = CrossJoinStreamState::FetchProbeBatch;
+            Ok(StatefulStreamResult::Continue)
         }
-        self.left_index = 0;
-        self.right
-            .poll_next_unpin(cx)
-            .map(|maybe_batch| match maybe_batch {
-                Some(Ok(batch)) => {
-                    let join_timer = self.join_metrics.join_time.timer();
-                    let result =
-                        build_batch(self.left_index, &batch, left_data, &self.schema);
-                    self.join_metrics.input_batches.add(1);
-                    self.join_metrics.input_rows.add(batch.num_rows());
-                    if let Ok(ref batch) = result {
-                        join_timer.done();
-                        self.join_metrics.output_batches.add(1);
-                        self.join_metrics.output_rows.add(batch.num_rows());
-                    }
-                    self.left_index = 1;
-
-                    let mut right_batch = self.right_batch.lock();
-                    *right_batch = Some(batch);
-
-                    Some(result)
-                }
-                other => other,
-            })
     }
 }
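The refactor above replaces the mutex-guarded `right_batch` with an explicit state machine: `poll_next_impl` loops over the current `CrossJoinStreamState`, and each handler either yields a batch or transitions the state and continues. A minimal, self-contained sketch of that polling pattern follows — `State`, `Machine`, and `next_output` are illustrative stand-ins, not the DataFusion API, and plain `Vec<i32>` values stand in for `RecordBatch`es:

enum State {
    WaitBuildSide,
    FetchProbeBatch,
    BuildBatches(Vec<i32>), // stand-in for the buffered right-side batch
}

struct Machine {
    state: State,
    left: Vec<i32>,               // collected build (left) side
    left_index: usize,            // next left row to join
    probe_batches: Vec<Vec<i32>>, // stand-in for the right-side stream
}

impl Machine {
    /// One output "batch" per left row, crossed with the current probe batch.
    fn next_output(&mut self) -> Option<Vec<(i32, i32)>> {
        loop {
            // Take the state out, act on it, and put the next state back.
            match std::mem::replace(&mut self.state, State::WaitBuildSide) {
                State::WaitBuildSide => {
                    if self.left.is_empty() {
                        return None; // empty build side ends execution
                    }
                    self.state = State::FetchProbeBatch;
                }
                State::FetchProbeBatch => {
                    self.left_index = 0;
                    match self.probe_batches.pop() {
                        Some(batch) => self.state = State::BuildBatches(batch),
                        None => return None, // probe side exhausted
                    }
                }
                State::BuildBatches(batch) => {
                    if self.left_index < self.left.len() {
                        let l = self.left[self.left_index];
                        let out = batch.iter().map(|&r| (l, r)).collect();
                        self.left_index += 1;
                        self.state = State::BuildBatches(batch); // keep the batch
                        return Some(out);
                    }
                    self.state = State::FetchProbeBatch;
                }
            }
        }
    }
}

fn main() {
    let mut m = Machine {
        state: State::WaitBuildSide,
        left: vec![1, 2],
        left_index: 0,
        probe_batches: vec![vec![10, 20]],
    };
    while let Some(batch) = m.next_output() {
        println!("{batch:?}"); // [(1, 10), (1, 20)] then [(2, 10), (2, 20)]
    }
}

Taking the state out with `std::mem::replace` and putting it back keeps each transition explicit without borrow conflicts; the real stream expresses the same Continue-versus-Ready protocol through `handle_state!` and `StatefulStreamResult`.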
From 24fc99c821dbeafb3c586d7402a521a4a80d70f2 Mon Sep 17 00:00:00 2001
From: JasonLi
Date: Thu, 4 Apr 2024 22:25:28 +0800
Subject: [PATCH 2/4] Optimization: concat function (#9732)

* optimization: concat function

fix: concat_ws

chore: add license header

add arrow feature

update concat

* change Cargo.toml

* pass cargo clippy

* chore: add annotation
---
 datafusion/physical-expr/Cargo.toml         |   5 +
 datafusion/physical-expr/benches/concat.rs  |  47 ++
 datafusion/physical-expr/src/functions.rs   |   6 +-
 .../physical-expr/src/string_expressions.rs | 411 ++++++++++++++----
 4 files changed, 390 insertions(+), 79 deletions(-)
 create mode 100644 datafusion/physical-expr/benches/concat.rs

diff --git a/datafusion/physical-expr/Cargo.toml b/datafusion/physical-expr/Cargo.toml
index baca00bea724..56b3f3c91eee 100644
--- a/datafusion/physical-expr/Cargo.toml
+++ b/datafusion/physical-expr/Cargo.toml
@@ -73,6 +73,7 @@ regex = { version = "1.8", optional = true }
 sha2 = { version = "^0.10.1", optional = true }
 
 [dev-dependencies]
+arrow = { workspace = true, features = ["test_utils"] }
 criterion = "0.5"
 rand = { workspace = true }
 rstest = { workspace = true }
@@ -81,3 +82,7 @@ tokio = { workspace = true, features = ["rt-multi-thread"] }
 [[bench]]
 harness = false
 name = "in_list"
+
+[[bench]]
+harness = false
+name = "concat"
diff --git a/datafusion/physical-expr/benches/concat.rs b/datafusion/physical-expr/benches/concat.rs
new file mode 100644
index 000000000000..cdd54d767f1f
--- /dev/null
+++ b/datafusion/physical-expr/benches/concat.rs
@@ -0,0 +1,47 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow::util::bench_util::create_string_array_with_len;
+use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion};
+use datafusion_common::ScalarValue;
+use datafusion_expr::ColumnarValue;
+use datafusion_physical_expr::string_expressions::concat;
+use std::sync::Arc;
+
+fn create_args(size: usize, str_len: usize) -> Vec<ColumnarValue> {
+    let array = Arc::new(create_string_array_with_len::<i32>(size, 0.2, str_len));
+    let scalar = ScalarValue::Utf8(Some(", ".to_string()));
+    vec![
+        ColumnarValue::Array(array.clone()),
+        ColumnarValue::Scalar(scalar),
+        ColumnarValue::Array(array),
+    ]
+}
+
+fn criterion_benchmark(c: &mut Criterion) {
+    for size in [1024, 4096, 8192] {
+        let args = create_args(size, 32);
+        let mut group = c.benchmark_group("concat function");
+        group.bench_function(BenchmarkId::new("concat", size), |b| {
+            b.iter(|| criterion::black_box(concat(&args).unwrap()))
+        });
+        group.finish();
+    }
+}
+
+criterion_group!(benches, criterion_benchmark);
+criterion_main!(benches);
diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs
index f7be2704ab79..f201deb50f41 100644
--- a/datafusion/physical-expr/src/functions.rs
+++ b/datafusion/physical-expr/src/functions.rs
@@ -221,9 +221,9 @@ pub fn create_physical_fun(
         // string functions
         BuiltinScalarFunction::Coalesce => Arc::new(conditional_expressions::coalesce),
         BuiltinScalarFunction::Concat => Arc::new(string_expressions::concat),
-        BuiltinScalarFunction::ConcatWithSeparator => Arc::new(|args| {
-            make_scalar_function_inner(string_expressions::concat_ws)(args)
-        }),
+        BuiltinScalarFunction::ConcatWithSeparator => {
+            Arc::new(string_expressions::concat_ws)
+        }
         BuiltinScalarFunction::InitCap => Arc::new(|args| match args[0].data_type() {
             DataType::Utf8 => {
                 make_scalar_function_inner(string_expressions::initcap::<i32>)(args)
diff --git a/datafusion/physical-expr/src/string_expressions.rs b/datafusion/physical-expr/src/string_expressions.rs
index 2185b7c5b4a1..fd6c8eb6b1d9 100644
--- a/datafusion/physical-expr/src/string_expressions.rs
+++ b/datafusion/physical-expr/src/string_expressions.rs
@@ -23,6 +23,7 @@
 
 use std::sync::Arc;
 
+use arrow::array::ArrayDataBuilder;
 use arrow::{
     array::{
         Array, ArrayRef, GenericStringArray, Int32Array, Int64Array, OffsetSizeTrait,
@@ -30,6 +31,7 @@ use arrow::{
     },
     datatypes::DataType,
 };
+use arrow_buffer::{MutableBuffer, NullBuffer};
 
 use datafusion_common::Result;
 use datafusion_common::{
@@ -38,75 +40,153 @@ use datafusion_common::{
 };
 use datafusion_expr::ColumnarValue;
 
+enum ColumnarValueRef<'a> {
+    Scalar(&'a [u8]),
+    NullableArray(&'a StringArray),
+    NonNullableArray(&'a StringArray),
+}
+
+impl<'a> ColumnarValueRef<'a> {
+    #[inline]
+    fn is_valid(&self, i: usize) -> bool {
+        match &self {
+            Self::Scalar(_) | Self::NonNullableArray(_) => true,
+            Self::NullableArray(array) => array.is_valid(i),
+        }
+    }
+
+    #[inline]
+    fn nulls(&self) -> Option<NullBuffer> {
+        match &self {
+            Self::Scalar(_) | Self::NonNullableArray(_) => None,
+            Self::NullableArray(array) => array.nulls().cloned(),
+        }
+    }
+}
+
+/// Optimized version of the StringBuilder in Arrow that:
+/// 1. Precalculates the expected length of the result, avoiding reallocations.
+/// 2. Avoids creating / incrementally creating a `NullBufferBuilder`.
+struct StringArrayBuilder {
+    offsets_buffer: MutableBuffer,
+    value_buffer: MutableBuffer,
+}
+
+impl StringArrayBuilder {
+    fn with_capacity(item_capacity: usize, data_capacity: usize) -> Self {
+        let mut offsets_buffer = MutableBuffer::with_capacity(
+            (item_capacity + 1) * std::mem::size_of::<i32>(),
+        );
+        // SAFETY: the first offset value is definitely not going to exceed the bounds.
+        unsafe { offsets_buffer.push_unchecked(0_i32) };
+        Self {
+            offsets_buffer,
+            value_buffer: MutableBuffer::with_capacity(data_capacity),
+        }
+    }
+
+    fn write<const CHECK_VALID: bool>(&mut self, column: &ColumnarValueRef, i: usize) {
+        match column {
+            ColumnarValueRef::Scalar(s) => {
+                self.value_buffer.extend_from_slice(s);
+            }
+            ColumnarValueRef::NullableArray(array) => {
+                if !CHECK_VALID || array.is_valid(i) {
+                    self.value_buffer
+                        .extend_from_slice(array.value(i).as_bytes());
+                }
+            }
+            ColumnarValueRef::NonNullableArray(array) => {
+                self.value_buffer
+                    .extend_from_slice(array.value(i).as_bytes());
+            }
+        }
+    }
+
+    fn append_offset(&mut self) {
+        let next_offset: i32 = self
+            .value_buffer
+            .len()
+            .try_into()
+            .expect("byte array offset overflow");
+        unsafe { self.offsets_buffer.push_unchecked(next_offset) };
+    }
+
+    fn finish(self, null_buffer: Option<NullBuffer>) -> StringArray {
+        let array_builder = ArrayDataBuilder::new(DataType::Utf8)
+            .len(self.offsets_buffer.len() / std::mem::size_of::<i32>() - 1)
+            .add_buffer(self.offsets_buffer.into())
+            .add_buffer(self.value_buffer.into())
+            .nulls(null_buffer);
+        // SAFETY: all data that was appended was valid UTF8 and the values
+        // and offsets were created correctly
+        let array_data = unsafe { array_builder.build_unchecked() };
+        StringArray::from(array_data)
+    }
+}
+
 /// Concatenates the text representations of all the arguments. NULL arguments are ignored.
 /// concat('abcde', 2, NULL, 22) = 'abcde222'
 pub fn concat(args: &[ColumnarValue]) -> Result<ColumnarValue> {
-    // do not accept 0 arguments.
-    if args.is_empty() {
-        return exec_err!(
-            "concat was called with {} arguments. It requires at least 1.",
-            args.len()
-        );
+    let array_len = args
+        .iter()
+        .filter_map(|x| match x {
+            ColumnarValue::Array(array) => Some(array.len()),
+            _ => None,
+        })
+        .next();
+
+    // Scalar
+    if array_len.is_none() {
+        let mut result = String::new();
+        for arg in args {
+            if let ColumnarValue::Scalar(ScalarValue::Utf8(Some(v))) = arg {
+                result.push_str(v);
+            }
+        }
+        return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(result))));
     }
 
-    // first, decide whether to return a scalar or a vector.
-    let mut return_array = args.iter().filter_map(|x| match x {
-        ColumnarValue::Array(array) => Some(array.len()),
-        _ => None,
-    });
-    if let Some(size) = return_array.next() {
-        let result = (0..size)
-            .map(|index| {
-                let mut owned_string: String = "".to_owned();
-                for arg in args {
-                    match arg {
-                        ColumnarValue::Scalar(ScalarValue::Utf8(maybe_value)) => {
-                            if let Some(value) = maybe_value {
-                                owned_string.push_str(value);
-                            }
-                        }
-                        ColumnarValue::Array(v) => {
-                            if v.is_valid(index) {
-                                let v = as_string_array(v).unwrap();
-                                owned_string.push_str(v.value(index));
-                            }
-                        }
-                        _ => unreachable!(),
-                    }
+    // Array
+    let len = array_len.unwrap();
+    let mut data_size = 0;
+    let mut columns = Vec::with_capacity(args.len());
+
+    for arg in args {
+        match arg {
+            ColumnarValue::Scalar(ScalarValue::Utf8(maybe_value)) => {
+                if let Some(s) = maybe_value {
+                    data_size += s.len() * len;
+                    columns.push(ColumnarValueRef::Scalar(s.as_bytes()));
                 }
-                Some(owned_string)
-            })
-            .collect::<StringArray>();
-
-        Ok(ColumnarValue::Array(Arc::new(result)))
-    } else {
-        // short avenue with only scalars
-        let initial = Some("".to_string());
-        let result = args.iter().fold(initial, |mut acc, rhs| {
-            if let Some(ref mut inner) = acc {
-                match rhs {
-                    ColumnarValue::Scalar(ScalarValue::Utf8(Some(v))) => {
-                        inner.push_str(v);
-                    }
-                    ColumnarValue::Scalar(ScalarValue::Utf8(None)) => {}
-                    _ => unreachable!(""),
+            }
+            ColumnarValue::Array(array) => {
+                let string_array = as_string_array(array)?;
+                data_size += string_array.values().len();
+                let column = if array.is_nullable() {
+                    ColumnarValueRef::NullableArray(string_array)
+                } else {
+                    ColumnarValueRef::NonNullableArray(string_array)
                 };
-            };
-            acc
-        });
-        Ok(ColumnarValue::Scalar(ScalarValue::Utf8(result)))
+                columns.push(column);
+            }
+            _ => unreachable!(),
+        }
+    }
+
+    let mut builder = StringArrayBuilder::with_capacity(len, data_size);
+    for i in 0..len {
+        columns
+            .iter()
+            .for_each(|column| builder.write::<true>(column, i));
+        builder.append_offset();
     }
+    Ok(ColumnarValue::Array(Arc::new(builder.finish(None))))
 }
 
 /// Concatenates all but the first argument, with separators. The first argument is used as the separator string, and should not be NULL. Other NULL arguments are ignored.
 /// concat_ws(',', 'abcde', 2, NULL, 22) = 'abcde,2,22'
-pub fn concat_ws(args: &[ArrayRef]) -> Result<ArrayRef> {
-    // downcast all arguments to strings
-    let args = args
-        .iter()
-        .map(|e| as_string_array(e))
-        .collect::<Result<Vec<&StringArray>>>()?;
-
+pub fn concat_ws(args: &[ColumnarValue]) -> Result<ColumnarValue> {
     // do not accept 0 or 1 arguments.
     if args.len() < 2 {
         return exec_err!(
@@ -115,28 +195,126 @@ pub fn concat_ws(args: &[ArrayRef]) -> Result<ArrayRef> {
         );
     }
 
-    // first map is the iterator, second is for the `Option<_>`
-    let result = args[0]
+    let array_len = args
         .iter()
-        .enumerate()
-        .map(|(index, x)| {
-            x.map(|sep: &str| {
-                let string_vec = args[1..]
-                    .iter()
-                    .flat_map(|arg| {
-                        if !arg.is_null(index) {
-                            Some(arg.value(index))
-                        } else {
-                            None
-                        }
-                    })
-                    .collect::<Vec<&str>>();
-                string_vec.join(sep)
-            })
+        .filter_map(|x| match x {
+            ColumnarValue::Array(array) => Some(array.len()),
+            _ => None,
         })
-        .collect::<StringArray>();
+        .next();
 
-    Ok(Arc::new(result) as ArrayRef)
+    // Scalar
+    if array_len.is_none() {
+        let sep = match &args[0] {
+            ColumnarValue::Scalar(ScalarValue::Utf8(Some(s))) => s,
+            ColumnarValue::Scalar(ScalarValue::Utf8(None)) => {
+                return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None)));
+            }
+            _ => unreachable!(),
+        };
+
+        let mut result = String::new();
+        let iter = &mut args[1..].iter();
+
+        for arg in iter.by_ref() {
+            match arg {
+                ColumnarValue::Scalar(ScalarValue::Utf8(Some(s))) => {
+                    result.push_str(s);
+                    break;
+                }
+                ColumnarValue::Scalar(ScalarValue::Utf8(None)) => {}
+                _ => unreachable!(),
+            }
+        }
+
+        for arg in iter.by_ref() {
+            match arg {
+                ColumnarValue::Scalar(ScalarValue::Utf8(Some(s))) => {
+                    result.push_str(sep);
+                    result.push_str(s);
+                }
+                ColumnarValue::Scalar(ScalarValue::Utf8(None)) => {}
+                _ => unreachable!(),
+            }
+        }
+
+        return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(result))));
+    }
+
+    // Array
+    let len = array_len.unwrap();
+    let mut data_size = 0;
+
+    // parse sep
+    let sep = match &args[0] {
+        ColumnarValue::Scalar(ScalarValue::Utf8(Some(s))) => {
+            data_size += s.len() * len * (args.len() - 2); // estimate
+            ColumnarValueRef::Scalar(s.as_bytes())
+        }
+        ColumnarValue::Scalar(ScalarValue::Utf8(None)) => {
+            return Ok(ColumnarValue::Array(Arc::new(StringArray::new_null(len))));
+        }
+        ColumnarValue::Array(array) => {
+            let string_array = as_string_array(array)?;
+            data_size += string_array.values().len() * (args.len() - 2); // estimate
+            if array.is_nullable() {
+                ColumnarValueRef::NullableArray(string_array)
+            } else {
+                ColumnarValueRef::NonNullableArray(string_array)
+            }
+        }
+        _ => unreachable!(),
+    };
+
+    let mut columns = Vec::with_capacity(args.len() - 1);
+    for arg in &args[1..] {
+        match arg {
+            ColumnarValue::Scalar(ScalarValue::Utf8(maybe_value)) => {
+                if let Some(s) = maybe_value {
+                    data_size += s.len() * len;
+                    columns.push(ColumnarValueRef::Scalar(s.as_bytes()));
+                }
+            }
+            ColumnarValue::Array(array) => {
+                let string_array = as_string_array(array)?;
+                data_size += string_array.values().len();
+                let column = if array.is_nullable() {
+                    ColumnarValueRef::NullableArray(string_array)
+                } else {
+                    ColumnarValueRef::NonNullableArray(string_array)
+                };
+                columns.push(column);
+            }
+            _ => unreachable!(),
+        }
+    }
+
+    let mut builder = StringArrayBuilder::with_capacity(len, data_size);
+    for i in 0..len {
+        if !sep.is_valid(i) {
+            builder.append_offset();
+            continue;
+        }
+
+        let mut iter = columns.iter();
+        for column in iter.by_ref() {
+            if column.is_valid(i) {
+                builder.write::<false>(column, i);
+                break;
+            }
+        }
+
+        for column in iter {
+            if column.is_valid(i) {
+                builder.write::<false>(&sep, i);
+                builder.write::<false>(column, i);
+            }
+        }
+
+        builder.append_offset();
+    }
+
+    Ok(ColumnarValue::Array(Arc::new(builder.finish(sep.nulls()))))
 }
 
 /// Converts the first letter of each word to upper case and the rest to lower case. Words are sequences of alphanumeric characters separated by non-alphanumeric characters.
@@ -234,3 +412,84 @@ pub fn ends_with(args: &[ArrayRef]) -> Result<ArrayRef> {
 
     Ok(Arc::new(result) as ArrayRef)
 }
+
@@ -234,3 +412,84 @@ pub fn ends_with(args: &[ArrayRef]) -> Result { Ok(Arc::new(result) as ArrayRef) } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn concat() -> Result<()> { + let c0 = + ColumnarValue::Array(Arc::new(StringArray::from(vec!["foo", "bar", "baz"]))); + let c1 = ColumnarValue::Scalar(ScalarValue::Utf8(Some(",".to_string()))); + let c2 = ColumnarValue::Array(Arc::new(StringArray::from(vec![ + Some("x"), + None, + Some("z"), + ]))); + let args = &[c0, c1, c2]; + + let result = super::concat(args)?; + let expected = + Arc::new(StringArray::from(vec!["foo,x", "bar,", "baz,z"])) as ArrayRef; + match &result { + ColumnarValue::Array(array) => { + assert_eq!(&expected, array); + } + _ => panic!(), + } + Ok(()) + } + + #[test] + fn concat_ws() -> Result<()> { + // sep is scalar + let c0 = ColumnarValue::Scalar(ScalarValue::Utf8(Some(",".to_string()))); + let c1 = + ColumnarValue::Array(Arc::new(StringArray::from(vec!["foo", "bar", "baz"]))); + let c2 = ColumnarValue::Array(Arc::new(StringArray::from(vec![ + Some("x"), + None, + Some("z"), + ]))); + let args = &[c0, c1, c2]; + + let result = super::concat_ws(args)?; + let expected = + Arc::new(StringArray::from(vec!["foo,x", "bar", "baz,z"])) as ArrayRef; + match &result { + ColumnarValue::Array(array) => { + assert_eq!(&expected, array); + } + _ => panic!(), + } + + // sep is nullable array + let c0 = ColumnarValue::Array(Arc::new(StringArray::from(vec![ + Some(","), + None, + Some("+"), + ]))); + let c1 = + ColumnarValue::Array(Arc::new(StringArray::from(vec!["foo", "bar", "baz"]))); + let c2 = ColumnarValue::Array(Arc::new(StringArray::from(vec![ + Some("x"), + Some("y"), + Some("z"), + ]))); + let args = &[c0, c1, c2]; + + let result = super::concat_ws(args)?; + let expected = + Arc::new(StringArray::from(vec![Some("foo,x"), None, Some("baz+z")])) + as ArrayRef; + match &result { + ColumnarValue::Array(array) => { + assert_eq!(&expected, array); + } + _ => panic!(), + } + + Ok(()) + } +} From 63888e853b7b094f2f47f53192a94f38327f5f5a Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Thu, 4 Apr 2024 10:43:15 -0400 Subject: [PATCH 3/4] Improve AggregateUDFImpl::state_fields documentation (#9919) --- datafusion/expr/src/udaf.rs | 31 ++++++++++++++++++++++++------- 1 file changed, 24 insertions(+), 7 deletions(-) diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index ba80f39dde43..14e5195116b1 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -171,9 +171,11 @@ impl AggregateUDF { self.inner.accumulator(acc_args) } - /// Return the fields of the intermediate state used by this aggregator, given - /// its state name, value type and ordering fields. See [`AggregateUDFImpl::state_fields`] - /// for more details. Supports multi-phase aggregations + /// Return the fields used to store the intermediate state for this aggregator, given + /// the name of the aggregate, value type and ordering fields. See [`AggregateUDFImpl::state_fields`] + /// for more details. + /// + /// This is used to support multi-phase aggregations pub fn state_fields( &self, name: &str, @@ -283,13 +285,28 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { /// `acc_args`: the arguments to the accumulator. See [`AccumulatorArgs`] for more details. fn accumulator(&self, acc_args: AccumulatorArgs) -> Result>; - /// Return the fields of the intermediate state. + /// Return the fields used to store the intermediate state of this accumulator. + /// + /// # Arguments: + /// 1. 
From 63888e853b7b094f2f47f53192a94f38327f5f5a Mon Sep 17 00:00:00 2001
From: Andrew Lamb
Date: Thu, 4 Apr 2024 10:43:15 -0400
Subject: [PATCH 3/4] Improve AggregateUDFImpl::state_fields documentation (#9919)

---
 datafusion/expr/src/udaf.rs | 31 ++++++++++++++++++++++++-------
 1 file changed, 24 insertions(+), 7 deletions(-)

diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs
index ba80f39dde43..14e5195116b1 100644
--- a/datafusion/expr/src/udaf.rs
+++ b/datafusion/expr/src/udaf.rs
@@ -171,9 +171,11 @@ impl AggregateUDF {
         self.inner.accumulator(acc_args)
     }
 
-    /// Return the fields of the intermediate state used by this aggregator, given
-    /// its state name, value type and ordering fields. See [`AggregateUDFImpl::state_fields`]
-    /// for more details. Supports multi-phase aggregations
+    /// Return the fields used to store the intermediate state for this aggregator, given
+    /// the name of the aggregate, value type and ordering fields. See [`AggregateUDFImpl::state_fields`]
+    /// for more details.
+    ///
+    /// This is used to support multi-phase aggregations
     pub fn state_fields(
         &self,
         name: &str,
@@ -283,13 +285,28 @@ pub trait AggregateUDFImpl: Debug + Send + Sync {
     /// `acc_args`: the arguments to the accumulator. See [`AccumulatorArgs`] for more details.
     fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>>;
 
-    /// Return the fields of the intermediate state.
+    /// Return the fields used to store the intermediate state of this accumulator.
+    ///
+    /// # Arguments:
+    /// 1. `name`: the name of the expression (e.g. AVG, SUM, etc)
+    /// 2. `value_type`: the aggregate's output type (returned by [`Self::return_type`])
+    /// 3. `ordering_fields`: the fields used to order the input arguments, if any.
+    ///    Empty if no ordering expression is provided.
+    ///
+    /// # Notes:
     ///
-    /// name: the name of the state
+    /// The default implementation returns a single state field named `name`
+    /// with the same type as `value_type`. This is suitable for aggregates such
+    /// as `SUM` or `MIN` where partial state can be combined by applying the
+    /// same aggregate.
     ///
-    /// value_type: the type of the value, it should be the result of the `return_type`
+    /// For aggregates such as `AVG` where the partial state is more complex
+    /// (e.g. a COUNT and a SUM), this method is used to define the additional
+    /// fields.
     ///
-    /// ordering_fields: the fields used for ordering, empty if no ordering expression is provided
+    /// The name of the fields must be unique within the query and thus should
+    /// be derived from `name`. See [`format_state_name`] for a utility function
+    /// to generate a unique name.
     fn state_fields(
         &self,
         name: &str,

From 202f415811c0a559dde63108be967855844c14cb Mon Sep 17 00:00:00 2001
From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com>
Date: Thu, 4 Apr 2024 12:16:59 -0400
Subject: [PATCH 4/4] chore(deps): update substrait requirement from 0.28.0 to 0.29.0 (#9942)

Updates the requirements on [substrait](https://github.com/substrait-io/substrait-rs) to permit the latest version.
- [Release notes](https://github.com/substrait-io/substrait-rs/releases)
- [Changelog](https://github.com/substrait-io/substrait-rs/blob/main/CHANGELOG.md)
- [Commits](https://github.com/substrait-io/substrait-rs/compare/v0.28.0...v0.29.0)

---
updated-dependencies:
- dependency-name: substrait
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot]
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
---
 datafusion/substrait/Cargo.toml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/datafusion/substrait/Cargo.toml b/datafusion/substrait/Cargo.toml
index cc79685c9429..f9523446980e 100644
--- a/datafusion/substrait/Cargo.toml
+++ b/datafusion/substrait/Cargo.toml
@@ -36,7 +36,7 @@ itertools = { workspace = true }
 object_store = { workspace = true }
 prost = "0.12"
 prost-types = "0.12"
-substrait = "0.28.0"
+substrait = "0.29.0"
 
 [dev-dependencies]
 tokio = { workspace = true }
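The `state_fields` contract documented in patch 3 is easiest to see on a concrete aggregate. A hedged sketch of what an AVG-style implementation might return — `avg_state_fields` is a free-standing illustration rather than the trait method itself, and the local `format_state_name` mirrors the helper the docs reference:

use arrow::datatypes::{DataType, Field};

// Local stand-in for the `format_state_name` helper referenced by the docs:
// it derives a per-state field name from the aggregate expression's name.
fn format_state_name(name: &str, state_name: &str) -> String {
    format!("{name}[{state_name}]")
}

// An AVG-style aggregate keeps a COUNT and a SUM as its partial state, so it
// returns two fields instead of the single-field default.
fn avg_state_fields(name: &str, value_type: DataType) -> Vec<Field> {
    vec![
        Field::new(format_state_name(name, "count"), DataType::UInt64, true),
        Field::new(format_state_name(name, "sum"), value_type, true),
    ]
}

fn main() {
    for field in avg_state_fields("avg(c1)", DataType::Float64) {
        println!("{} -> {:?}", field.name(), field.data_type());
    }
}

Because both field names are derived from `name`, two different AVG expressions in the same query get distinct state columns, which is exactly the uniqueness requirement the documentation spells out.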