diff --git a/benchmarks/.gitignore b/benchmarks/.gitignore index 2c574ff30d12..c35b1a7c1944 100644 --- a/benchmarks/.gitignore +++ b/benchmarks/.gitignore @@ -1,2 +1,3 @@ data -results \ No newline at end of file +results +venv diff --git a/benchmarks/bench.sh b/benchmarks/bench.sh index 77779a12c450..efd56b17c7cb 100755 --- a/benchmarks/bench.sh +++ b/benchmarks/bench.sh @@ -37,6 +37,7 @@ DATA_DIR=${DATA_DIR:-$SCRIPT_DIR/data} #CARGO_COMMAND=${CARGO_COMMAND:-"cargo run --release"} CARGO_COMMAND=${CARGO_COMMAND:-"cargo run --profile release-nonlto"} # for faster iterations PREFER_HASH_JOIN=${PREFER_HASH_JOIN:-true} +VIRTUAL_ENV=${VIRTUAL_ENV:-$SCRIPT_DIR/venv} usage() { echo " @@ -46,6 +47,7 @@ Usage: $0 data [benchmark] $0 run [benchmark] $0 compare +$0 venv ********** Examples: @@ -62,6 +64,7 @@ DATAFUSION_DIR=/source/datafusion ./bench.sh run tpch data: Generates or downloads data needed for benchmarking run: Runs the named benchmark compare: Compares results from benchmark runs +venv: Creates a new venv (unless one already exists) and installs compare's requirements into it ********** * Benchmarks @@ -84,7 +87,8 @@ DATA_DIR directory to store datasets CARGO_COMMAND command that runs the benchmark binary DATAFUSION_DIR directory to use (default $DATAFUSION_DIR) RESULTS_NAME folder where the benchmark files are stored -PREFER_HASH_JOIN Prefer hash join algorithm(default true) +PREFER_HASH_JOIN Prefer hash join algorithm (default true) +VIRTUAL_ENV Python venv to use for compare and venv commands (default ./venv, override by sourcing <venv>/bin/activate) " exit 1 } @@ -243,6 +247,9 @@ main() { compare) compare_benchmarks "$ARG2" "$ARG3" ;; + venv) + setup_venv + ;; "") usage ;; @@ -302,7 +309,7 @@ data_tpch() { else echo " creating parquet files using benchmark binary ..." pushd "${SCRIPT_DIR}" > /dev/null - $CARGO_COMMAND --bin tpch -- convert --input "${TPCH_DIR}" --prefer_hash_join ${PREFER_HASH_JOIN} --output "${TPCH_DIR}" --format parquet + $CARGO_COMMAND --bin tpch -- convert --input "${TPCH_DIR}" --output "${TPCH_DIR}" --format parquet popd > /dev/null fi } @@ -405,7 +412,7 @@ run_clickbench_1() { RESULTS_FILE="${RESULTS_DIR}/clickbench_1.json" echo "RESULTS_FILE: ${RESULTS_FILE}" echo "Running clickbench (1 file) benchmark..." - $CARGO_COMMAND --bin dfbench -- clickbench --iterations 5 --path "${DATA_DIR}/hits.parquet" --prefer_hash_join ${PREFER_HASH_JOIN} --queries-path "${SCRIPT_DIR}/queries/clickbench/queries.sql" -o ${RESULTS_FILE} + $CARGO_COMMAND --bin dfbench -- clickbench --iterations 5 --path "${DATA_DIR}/hits.parquet" --queries-path "${SCRIPT_DIR}/queries/clickbench/queries.sql" -o ${RESULTS_FILE} } # Runs the clickbench benchmark with the partitioned parquet files @@ -413,7 +420,7 @@ run_clickbench_partitioned() { RESULTS_FILE="${RESULTS_DIR}/clickbench_partitioned.json" echo "RESULTS_FILE: ${RESULTS_FILE}" echo "Running clickbench (partitioned, 100 files) benchmark..." 
- $CARGO_COMMAND --bin dfbench -- clickbench --iterations 5 --path "${DATA_DIR}/hits_partitioned" --prefer_hash_join ${PREFER_HASH_JOIN} --queries-path "${SCRIPT_DIR}/queries/clickbench/queries.sql" -o ${RESULTS_FILE} + $CARGO_COMMAND --bin dfbench -- clickbench --iterations 5 --path "${DATA_DIR}/hits_partitioned" --queries-path "${SCRIPT_DIR}/queries/clickbench/queries.sql" -o ${RESULTS_FILE} } # Runs the clickbench "extended" benchmark with a single large parquet file @@ -421,7 +428,7 @@ run_clickbench_extended() { RESULTS_FILE="${RESULTS_DIR}/clickbench_extended.json" echo "RESULTS_FILE: ${RESULTS_FILE}" echo "Running clickbench (1 file) extended benchmark..." - $CARGO_COMMAND --bin dfbench -- clickbench --iterations 5 --path "${DATA_DIR}/hits.parquet" --prefer_hash_join ${PREFER_HASH_JOIN} --queries-path "${SCRIPT_DIR}/queries/clickbench/extended.sql" -o ${RESULTS_FILE} + $CARGO_COMMAND --bin dfbench -- clickbench --iterations 5 --path "${DATA_DIR}/hits.parquet" --queries-path "${SCRIPT_DIR}/queries/clickbench/extended.sql" -o ${RESULTS_FILE} } compare_benchmarks() { @@ -448,7 +455,7 @@ compare_benchmarks() { echo "--------------------" echo "Benchmark ${bench}" echo "--------------------" - python3 "${SCRIPT_DIR}"/compare.py "${RESULTS_FILE1}" "${RESULTS_FILE2}" + PATH=$VIRTUAL_ENV/bin:$PATH python3 "${SCRIPT_DIR}"/compare.py "${RESULTS_FILE1}" "${RESULTS_FILE2}" else echo "Note: Skipping ${RESULTS_FILE1} as ${RESULTS_FILE2} does not exist" fi @@ -456,5 +463,10 @@ compare_benchmarks() { } +setup_venv() { + python3 -m venv $VIRTUAL_ENV + PATH=$VIRTUAL_ENV/bin:$PATH python3 -m pip install -r requirements.txt +} + # And start the process up main diff --git a/benchmarks/compare.py b/benchmarks/compare.py index ec2b28fa0556..2574c0735ca8 100755 --- a/benchmarks/compare.py +++ b/benchmarks/compare.py @@ -29,7 +29,7 @@ from rich.console import Console from rich.table import Table except ImportError: - print("Try `pip install rich` for using this script.") + print("Couldn't import modules -- run `./bench.sh venv` first") raise diff --git a/benchmarks/requirements.txt b/benchmarks/requirements.txt new file mode 100644 index 000000000000..20a5a2bddbf2 --- /dev/null +++ b/benchmarks/requirements.txt @@ -0,0 +1,18 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +rich diff --git a/datafusion-cli/src/main.rs b/datafusion-cli/src/main.rs index 187f856894b2..f2b29fe78690 100644 --- a/datafusion-cli/src/main.rs +++ b/datafusion-cli/src/main.rs @@ -133,7 +133,7 @@ struct Args { #[clap( long, - help = "The max number of rows to display for 'Table' format\n[default: 40] [possible values: numbers(0/10/...), inf(no limit)]", + help = "The max number of rows to display for 'Table' format\n[possible values: numbers(0/10/...), inf(no limit)]", default_value = "40" )] maxrows: MaxRows, diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index 950cb7ddb2d3..b5c58eff577c 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -3100,10 +3100,7 @@ mod tests { let join_schema = physical_plan.schema(); match join_type { - JoinType::Inner - | JoinType::Left - | JoinType::LeftSemi - | JoinType::LeftAnti => { + JoinType::Left | JoinType::LeftSemi | JoinType::LeftAnti => { let left_exprs: Vec<Arc<dyn PhysicalExpr>> = vec![ Arc::new(Column::new_with_schema("c1", &join_schema)?), Arc::new(Column::new_with_schema("c2", &join_schema)?), @@ -3113,7 +3110,10 @@ &Partitioning::Hash(left_exprs, default_partition_count) ); } - JoinType::Right | JoinType::RightSemi | JoinType::RightAnti => { + JoinType::Inner + | JoinType::Right + | JoinType::RightSemi + | JoinType::RightAnti => { let right_exprs: Vec<Arc<dyn PhysicalExpr>> = vec![ Arc::new(Column::new_with_schema("c2_c1", &join_schema)?), Arc::new(Column::new_with_schema("c2_c2", &join_schema)?), @@ -3133,6 +3133,7 @@ Ok(()) } + #[tokio::test] async fn nested_explain_should_fail() -> Result<()> { let ctx = SessionContext::new(); diff --git a/datafusion/core/src/datasource/physical_plan/parquet/access_plan.rs b/datafusion/core/src/datasource/physical_plan/parquet/access_plan.rs index f51f2c49e896..e15e907cd9b8 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/access_plan.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/access_plan.rs @@ -384,7 +384,7 @@ mod test { let access_plan = ParquetAccessPlan::new(vec![ RowGroupAccess::Scan, RowGroupAccess::Selection( - // select / skip all 20 rows in row group 1 + // specifies all 20 rows in row group 1 vec![ RowSelector::select(5), RowSelector::skip(7), @@ -463,7 +463,7 @@ mod test { fn test_invalid_too_few() { let access_plan = ParquetAccessPlan::new(vec![ RowGroupAccess::Scan, - // select 12 rows, but row group 1 has 20 + // specify only 12 rows in selection, but row group 1 has 20 RowGroupAccess::Selection( vec![RowSelector::select(5), RowSelector::skip(7)].into(), ), @@ -484,7 +484,7 @@ mod test { fn test_invalid_too_many() { let access_plan = ParquetAccessPlan::new(vec![ RowGroupAccess::Scan, - // select 22 rows, but row group 1 has only 20 + // specify 22 rows in selection, but row group 1 has only 20 RowGroupAccess::Selection( vec![ RowSelector::select(10), diff --git a/datafusion/core/src/datasource/physical_plan/parquet/mod.rs b/datafusion/core/src/datasource/physical_plan/parquet/mod.rs index 5e5cc93bc54f..ec21c5504c69 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/mod.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/mod.rs @@ -156,9 +156,8 @@ pub use writer::plan_to_parquet; /// used to implement external indexes on top of parquet files and select only /// portions of the files. /// -/// The `ParquetExec` will try and further reduce any provided -/// `ParquetAccessPlan` further based on the contents of `ParquetMetadata` and -/// other settings. 
+/// The `ParquetExec` will try and reduce any provided `ParquetAccessPlan` +/// further based on the contents of `ParquetMetadata` and other settings. /// /// ## Example of providing a ParquetAccessPlan /// diff --git a/datafusion/core/src/datasource/physical_plan/parquet/opener.rs b/datafusion/core/src/datasource/physical_plan/parquet/opener.rs index 8557c6d5f950..36335863032c 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/opener.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/opener.rs @@ -238,6 +238,8 @@ fn create_initial_plan( // check row group count matches the plan return Ok(access_plan.clone()); + } else { + debug!("ParquetExec Ignoring unknown extension specified for {file_name}"); } } diff --git a/datafusion/core/src/datasource/physical_plan/parquet/reader.rs b/datafusion/core/src/datasource/physical_plan/parquet/reader.rs index 265fb9d570cc..8a4ba136fc96 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/reader.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/reader.rs @@ -16,7 +16,7 @@ // under the License. //! [`ParquetFileReaderFactory`] and [`DefaultParquetFileReaderFactory`] for -//! creating parquet file readers +//! low level control of parquet file readers use crate::datasource::physical_plan::{FileMeta, ParquetFileMetrics}; use bytes::Bytes; @@ -33,12 +33,19 @@ use std::sync::Arc; /// /// The combined implementations of [`ParquetFileReaderFactory`] and /// [`AsyncFileReader`] can be used to provide custom data access operations -/// such as pre-cached data, I/O coalescing, etc. +/// such as pre-cached metadata, I/O coalescing, etc. /// /// See [`DefaultParquetFileReaderFactory`] for a simple implementation. pub trait ParquetFileReaderFactory: Debug + Send + Sync + 'static { /// Provides an `AsyncFileReader` for reading data from a parquet file specified /// + /// # Notes + /// + /// If the resulting [`AsyncFileReader`] returns `ParquetMetaData` without + /// page index information, the reader will load it on demand. Thus it is important + /// to ensure that the returned `ParquetMetaData` has the necessary information + /// if you wish to avoid a subsequent I/O + /// /// # Arguments /// * partition_index - Index of the partition (for reporting metrics) /// * file_meta - The file to be read diff --git a/datafusion/core/src/datasource/physical_plan/parquet/statistics.rs b/datafusion/core/src/datasource/physical_plan/parquet/statistics.rs index c0d36f1fc4d7..a2e0d8fa66be 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/statistics.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/statistics.rs @@ -33,7 +33,8 @@ use arrow_array::{ use arrow_schema::{Field, FieldRef, Schema, TimeUnit}; use datafusion_common::{internal_datafusion_err, internal_err, plan_err, Result}; use half::f16; -use parquet::file::metadata::RowGroupMetaData; +use parquet::file::metadata::{ParquetColumnIndex, ParquetOffsetIndex, RowGroupMetaData}; +use parquet::file::page_index::index::Index; use parquet::file::statistics::Statistics as ParquetStatistics; use parquet::schema::types::SchemaDescriptor; use paste::paste; @@ -517,6 +518,74 @@ macro_rules! get_statistics { }}} } +macro_rules! 
make_data_page_stats_iterator { + ($iterator_type: ident, $func: ident, $index_type: path, $stat_value_type: ty) => { + struct $iterator_type<'a, I> + where + I: Iterator<Item = (usize, &'a Index)>, + { + iter: I, + } + + impl<'a, I> $iterator_type<'a, I> + where + I: Iterator<Item = (usize, &'a Index)>, + { + fn new(iter: I) -> Self { + Self { iter } + } + } + + impl<'a, I> Iterator for $iterator_type<'a, I> + where + I: Iterator<Item = (usize, &'a Index)>, + { + type Item = Vec<Option<$stat_value_type>>; + + fn next(&mut self) -> Option<Self::Item> { + let next = self.iter.next(); + match next { + Some((len, index)) => match index { + $index_type(native_index) => Some( + native_index + .indexes + .iter() + .map(|x| x.$func) + .collect::<Vec<_>>(), + ), + // No matching `Index` found; + // thus no statistics that can be extracted. + // We return vec![None; len] to effectively + // create an arrow null-array with the length + // corresponding to the number of entries in + // `ParquetOffsetIndex` per row group per column. + _ => Some(vec![None; len]), + }, + _ => None, + } + } + + fn size_hint(&self) -> (usize, Option<usize>) { + self.iter.size_hint() + } + } + }; +} + +make_data_page_stats_iterator!(MinInt64DataPageStatsIterator, min, Index::INT64, i64); +make_data_page_stats_iterator!(MaxInt64DataPageStatsIterator, max, Index::INT64, i64); + +macro_rules! get_data_page_statistics { + ($stat_type_prefix: ident, $data_type: ident, $iterator: ident) => { + paste! { + match $data_type { + Some(DataType::Int64) => Ok(Arc::new(Int64Array::from_iter([<$stat_type_prefix Int64DataPageStatsIterator>]::new($iterator).flatten()))), + _ => unimplemented!() + } + } + } +} + /// Looks up the parquet column by name /// /// Returns the parquet column index and the corresponding arrow field @@ -563,6 +632,51 @@ fn max_statistics<'a, I: Iterator<Item = Option<&'a ParquetStatistics>>>( get_statistics!(Max, data_type, iterator) } +/// Extracts the min statistics from an iterator +/// of parquet page [`Index`]'es to an [`ArrayRef`] +pub(crate) fn min_page_statistics<'a, I>( + data_type: Option<&DataType>, + iterator: I, +) -> Result<ArrayRef> +where + I: Iterator<Item = (usize, &'a Index)>, +{ + get_data_page_statistics!(Min, data_type, iterator) +} + +/// Extracts the max statistics from an iterator +/// of parquet page [`Index`]'es to an [`ArrayRef`] +pub(crate) fn max_page_statistics<'a, I>( + data_type: Option<&DataType>, + iterator: I, +) -> Result<ArrayRef> +where + I: Iterator<Item = (usize, &'a Index)>, +{ + get_data_page_statistics!(Max, data_type, iterator) +} + +/// Extracts the null count statistics from an iterator +/// of parquet page [`Index`]'es to an [`ArrayRef`] +/// +/// The returned Array is an [`UInt64Array`] +pub(crate) fn null_counts_page_statistics<'a, I>(iterator: I) -> Result<ArrayRef> +where + I: Iterator<Item = (usize, &'a Index)>, +{ + let iter = iterator.flat_map(|(len, index)| match index { + Index::NONE => vec![None; len], + Index::INT64(native_index) => native_index + .indexes + .iter() + .map(|x| x.null_count.map(|x| x as u64)) + .collect::<Vec<_>>(), + _ => unimplemented!(), + }); + + Ok(Arc::new(UInt64Array::from_iter(iter))) +}
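A minimal, self-contained sketch of the iterator pattern this macro expands to, using a toy `Index` enum and invented names rather than the parquet crate's real types: per row group it yields one `Vec<Option<T>>` with one entry per data page, falling back to `vec![None; len]` when the index variant does not match.

    // Toy stand-ins for parquet's Index/NativeIndex; all names here are
    // illustrative only, not the real parquet crate types.
    enum Index {
        None,
        Int64(Vec<Option<i64>>), // per-page minimums, pre-extracted
    }

    struct MinI64Iter<'a, I>
    where
        I: Iterator<Item = (usize, &'a Index)>,
    {
        iter: I,
    }

    impl<'a, I> Iterator for MinI64Iter<'a, I>
    where
        I: Iterator<Item = (usize, &'a Index)>,
    {
        type Item = Vec<Option<i64>>;

        fn next(&mut self) -> Option<Self::Item> {
            self.iter.next().map(|(len, index)| match index {
                // Matching native index: one statistic per data page
                Index::Int64(mins) => mins.clone(),
                // No usable index: one null per data page
                _ => vec![None; len],
            })
        }
    }

    fn main() {
        let rg0 = Index::Int64(vec![Some(-1), Some(3)]); // row group with 2 pages
        let rg1 = Index::None; // row group whose page stats are missing
        let input = vec![(2usize, &rg0), (1usize, &rg1)];
        let mins: Vec<Option<i64>> =
            MinI64Iter { iter: input.into_iter() }.flatten().collect();
        assert_eq!(mins, vec![Some(-1), Some(3), None]);
    }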
 /// Extracts Parquet statistics as Arrow arrays /// /// This is used to convert Parquet statistics to Arrow arrays, with proper type @@ -771,10 +885,205 @@ impl<'a> StatisticsConverter<'a> { Ok(Arc::new(UInt64Array::from_iter(null_counts))) } + /// Extract the minimum values from Data Page statistics. + /// + /// In Parquet files, in addition to the Column Chunk level statistics + /// (stored for each column for each row group) there are also + /// optional statistics stored for each data page, as part of + /// the [`ParquetColumnIndex`]. + /// + /// Since a single Column Chunk is stored as one or more pages, + /// page level statistics can prune at a finer granularity. + /// + /// However since they are stored in a separate metadata + /// structure ([`Index`]) there is different code to extract them as + /// compared to arrow statistics. + /// + /// # Parameters: + /// + /// * `column_page_index`: The parquet column page indices, read from + /// `ParquetMetaData` column_index + /// + /// * `column_offset_index`: The parquet column offset indices, read from + /// `ParquetMetaData` offset_index + /// + /// * `row_group_indices`: The indices of the row groups that are used to + /// extract the column page index and offset index on a per row group + /// per column basis. + /// + /// # Return Value + /// + /// The returned array contains 1 value for each `NativeIndex` + /// in the underlying `Index`es, in the same order as they appear + /// in `row_group_indices`. + /// + /// For example, if there are two `Index`es in the input: + /// 1. the first having `3` `PageIndex` entries + /// 2. the second having `2` `PageIndex` entries + /// + /// The returned array would have 5 rows. + /// + /// Each value is either: + /// * the minimum value for the page + /// * a null value, if the statistics can not be extracted + /// + /// Note that a null value does NOT mean the min value was actually + /// `null`; it means the requested statistic is unknown + /// + /// # Errors + /// + /// Reasons for not being able to extract the statistics include: + /// * the column is not present in the parquet file + /// * statistics for the pages are not present in the row group + /// * the stored statistic value can not be converted to the requested type + pub fn data_page_mins<I>( + &self, + column_page_index: &ParquetColumnIndex, + column_offset_index: &ParquetOffsetIndex, + row_group_indices: I, + ) -> Result<ArrayRef> + where + I: IntoIterator<Item = &'a usize>, + { + let data_type = self.arrow_field.data_type(); + + let Some(parquet_index) = self.parquet_index else { + return Ok(self.make_null_array(data_type, row_group_indices)); + }; + + let iter = row_group_indices.into_iter().map(|rg_index| { + let column_page_index_per_row_group_per_column = + &column_page_index[*rg_index][parquet_index]; + let num_data_pages = &column_offset_index[*rg_index][parquet_index].len(); + + (*num_data_pages, column_page_index_per_row_group_per_column) + }); + + min_page_statistics(Some(data_type), iter) + } + + /// Extract the maximum values from Data Page statistics. + /// + /// See docs on [`Self::data_page_mins`] for details. + pub fn data_page_maxes<I>( + &self, + column_page_index: &ParquetColumnIndex, + column_offset_index: &ParquetOffsetIndex, + row_group_indices: I, + ) -> Result<ArrayRef> + where + I: IntoIterator<Item = &'a usize>, + { + let data_type = self.arrow_field.data_type(); + + let Some(parquet_index) = self.parquet_index else { + return Ok(self.make_null_array(data_type, row_group_indices)); + }; + + let iter = row_group_indices.into_iter().map(|rg_index| { + let column_page_index_per_row_group_per_column = + &column_page_index[*rg_index][parquet_index]; + let num_data_pages = &column_offset_index[*rg_index][parquet_index].len(); + + (*num_data_pages, column_page_index_per_row_group_per_column) + }); + + max_page_statistics(Some(data_type), iter) + }
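A concrete sketch of the return-value shape described in the comment above, flattening per-row-group page minimums (3 pages plus 2 pages) into one 5-row arrow array; the literal values are invented for illustration.

    use std::sync::Arc;
    use arrow_array::{Array, ArrayRef, Int64Array};

    fn main() {
        // Page-level minimums for two row groups: 3 data pages + 2 data pages
        // flatten into a single 5-row array.
        let per_row_group: Vec<Vec<Option<i64>>> = vec![
            vec![Some(-1), Some(3), Some(7)], // row group 0
            vec![Some(10), None],             // row group 1; one page unknown
        ];
        let mins: ArrayRef =
            Arc::new(Int64Array::from_iter(per_row_group.into_iter().flatten()));
        assert_eq!(mins.len(), 5);
        // Null means "statistic unknown", not that the page minimum was null.
        assert!(mins.is_null(4));
    }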
+ /// Extract the null counts from Data Page statistics. + /// + /// The returned Array is an [`UInt64Array`] + /// + /// See docs on [`Self::data_page_mins`] for details. + pub fn data_page_null_counts<I>( + &self, + column_page_index: &ParquetColumnIndex, + column_offset_index: &ParquetOffsetIndex, + row_group_indices: I, + ) -> Result<ArrayRef> + where + I: IntoIterator<Item = &'a usize>, + { + let data_type = self.arrow_field.data_type(); + + let Some(parquet_index) = self.parquet_index else { + return Ok(self.make_null_array(data_type, row_group_indices)); + }; + + let iter = row_group_indices.into_iter().map(|rg_index| { + let column_page_index_per_row_group_per_column = + &column_page_index[*rg_index][parquet_index]; + let num_data_pages = &column_offset_index[*rg_index][parquet_index].len(); + + (*num_data_pages, column_page_index_per_row_group_per_column) + }); + null_counts_page_statistics(iter) + } + + /// Returns an [`ArrayRef`] with row counts for each row group. + /// + /// This function iterates over the given row group indexes and computes + /// the row count for each page in the specified column. + /// + /// # Parameters: + /// + /// * `column_offset_index`: The parquet column offset indices, read from + /// `ParquetMetaData` offset_index + /// + /// * `row_group_metadatas`: The metadata slice of the row groups, read + /// from `ParquetMetaData` row_groups + /// + /// * `row_group_indices`: The indices of the row groups that are used to + /// extract the column offset index on a per row group per column basis. + /// + /// See docs on [`Self::data_page_mins`] for details. + pub fn data_page_row_counts<I>( + &self, + column_offset_index: &ParquetOffsetIndex, + row_group_metadatas: &[RowGroupMetaData], + row_group_indices: I, + ) -> Result<ArrayRef> + where + I: IntoIterator<Item = &'a usize>, + { + let data_type = self.arrow_field.data_type(); + + let Some(parquet_index) = self.parquet_index else { + return Ok(self.make_null_array(data_type, row_group_indices)); + }; + + // `offset_index[row_group_number][column_number][page_number]` holds + // the [`PageLocation`] corresponding to page `page_number` of column + // `column_number` of row group `row_group_number`. + let mut row_count_total = Vec::new(); + for rg_idx in row_group_indices { + let page_locations = &column_offset_index[*rg_idx][parquet_index]; + + let row_count_per_page = page_locations.windows(2).map(|loc| { + Some(loc[1].first_row_index as u64 - loc[0].first_row_index as u64) + }); + + let num_rows_in_row_group = &row_group_metadatas[*rg_idx].num_rows(); + + // append the last page row count + let row_count_per_page = row_count_per_page + .chain(std::iter::once(Some( + *num_rows_in_row_group as u64 + - page_locations.last().unwrap().first_row_index as u64, + ))) + .collect::<Vec<_>>(); + + row_count_total.extend(row_count_per_page); + } + + Ok(Arc::new(UInt64Array::from_iter(row_count_total))) + }
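The row-count arithmetic above, sketched with plain integers standing in for `PageLocation`s (toy values): counts come from consecutive `first_row_index` deltas, and the last page is closed out by the row group's total row count.

    fn main() {
        // first_row_index of each data page within one row group (toy values)
        let first_row_index: Vec<u64> = vec![0, 4, 8, 12];
        let num_rows_in_row_group: u64 = 14;

        // rows in page i = first_row_index[i + 1] - first_row_index[i] ...
        let mut row_counts: Vec<u64> =
            first_row_index.windows(2).map(|w| w[1] - w[0]).collect();
        // ... and the final page runs to the end of the row group.
        row_counts.push(num_rows_in_row_group - first_row_index.last().unwrap());

        assert_eq!(row_counts, vec![4, 4, 4, 2]);
    }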
+ /// Returns a null array of data_type with one element per row group - fn make_null_array<I>(&self, data_type: &DataType, metadatas: I) -> ArrayRef + fn make_null_array<I, A>(&self, data_type: &DataType, metadatas: I) -> ArrayRef where - I: IntoIterator<Item = &'a RowGroupMetaData>, + I: IntoIterator<Item = A>, { // column was in the arrow schema but not in the parquet schema, so return a null array let num_row_groups = metadatas.into_iter().count(); diff --git a/datafusion/core/src/physical_optimizer/aggregate_statistics.rs b/datafusion/core/src/physical_optimizer/aggregate_statistics.rs index eeacc48b85db..ca1582bcb34f 100644 --- a/datafusion/core/src/physical_optimizer/aggregate_statistics.rs +++ b/datafusion/core/src/physical_optimizer/aggregate_statistics.rs @@ -390,6 +390,7 @@ pub(crate) mod tests { &[self.column()], &[], &[], + &[], schema, self.column_name(), false, diff --git a/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs b/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs index 38b92959e841..b57f36f728d7 100644 --- a/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs +++ b/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs @@ -315,6 +315,7 @@ mod tests { &[expr], &[], &[], + &[], schema, name, false, @@ -404,6 +405,7 @@ mod tests { &[col("b", &schema)?], &[], &[], + &[], &schema, "Sum(b)", false, diff --git a/datafusion/core/src/physical_optimizer/test_utils.rs b/datafusion/core/src/physical_optimizer/test_utils.rs index 154e77cd23ae..5320938d2eb8 100644 --- a/datafusion/core/src/physical_optimizer/test_utils.rs +++ b/datafusion/core/src/physical_optimizer/test_utils.rs @@ -245,6 +245,7 @@ pub fn bounded_window_exec( "count".to_owned(), &[col(col_name, &schema).unwrap()], &[], + &[], &sort_exprs, Arc::new(WindowFrame::new(Some(false))), schema.as_ref(), diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 4f9187595018..404bcbb2e7d4 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -1766,7 +1766,8 @@ pub fn create_window_expr_with_name( window_frame, null_treatment, }) => { - let args = create_physical_exprs(args, logical_schema, execution_props)?; + let physical_args = + create_physical_exprs(args, logical_schema, execution_props)?; let partition_by = create_physical_exprs(partition_by, logical_schema, execution_props)?; let order_by = @@ -1780,13 +1781,13 @@ pub fn create_window_expr_with_name( } let window_frame = Arc::new(window_frame.clone()); - let ignore_nulls = null_treatment - .unwrap_or(sqlparser::ast::NullTreatment::RespectNulls) + let ignore_nulls = null_treatment.unwrap_or(NullTreatment::RespectNulls) == NullTreatment::IgnoreNulls; windows::create_window_expr( fun, name, - &args, + &physical_args, + args, &partition_by, &order_by, window_frame, @@ 
-1837,7 +1838,7 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter( order_by, null_treatment, }) => { - let args = + let physical_args = create_physical_exprs(args, logical_input_schema, execution_props)?; let filter = match filter { Some(e) => Some(create_physical_expr( @@ -1867,7 +1868,7 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter( let agg_expr = aggregates::create_aggregate_expr( fun, *distinct, - &args, + &physical_args, &ordering_reqs, physical_input_schema, name, @@ -1889,7 +1890,8 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter( physical_sort_exprs.clone().unwrap_or(vec![]); let agg_expr = udaf::create_aggregate_expr( fun, - &args, + &physical_args, + args, &sort_exprs, &ordering_reqs, physical_input_schema, diff --git a/datafusion/core/tests/dataframe/dataframe_functions.rs b/datafusion/core/tests/dataframe/dataframe_functions.rs index b05769a6ce9d..1c55c48fea40 100644 --- a/datafusion/core/tests/dataframe/dataframe_functions.rs +++ b/datafusion/core/tests/dataframe/dataframe_functions.rs @@ -33,7 +33,7 @@ use datafusion::assert_batches_eq; use datafusion_common::{DFSchema, ScalarValue}; use datafusion_expr::expr::Alias; use datafusion_expr::ExprSchemable; -use datafusion_functions_aggregate::expr_fn::approx_median; +use datafusion_functions_aggregate::expr_fn::{approx_median, approx_percentile_cont}; fn test_schema() -> SchemaRef { Arc::new(Schema::new(vec![ @@ -363,7 +363,7 @@ async fn test_fn_approx_percentile_cont() -> Result<()> { let expected = [ "+---------------------------------------------+", - "| APPROX_PERCENTILE_CONT(test.b,Float64(0.5)) |", + "| approx_percentile_cont(test.b,Float64(0.5)) |", "+---------------------------------------------+", "| 10 |", "+---------------------------------------------+", @@ -384,7 +384,7 @@ async fn test_fn_approx_percentile_cont() -> Result<()> { let df = create_test_table().await?; let expected = [ "+--------------------------------------+", - "| APPROX_PERCENTILE_CONT(test.b,arg_2) |", + "| approx_percentile_cont(test.b,arg_2) |", "+--------------------------------------+", "| 10 |", "+--------------------------------------+", diff --git a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs index c76c1fc2c736..a04f4f349122 100644 --- a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs @@ -108,6 +108,7 @@ async fn run_aggregate_test(input1: Vec, group_by_columns: Vec<&str &[col("d", &schema).unwrap()], &[], &[], + &[], &schema, "sum1", false, diff --git a/datafusion/core/tests/fuzz_cases/join_fuzz.rs b/datafusion/core/tests/fuzz_cases/join_fuzz.rs index a893e780581f..516749e82a53 100644 --- a/datafusion/core/tests/fuzz_cases/join_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/join_fuzz.rs @@ -180,6 +180,9 @@ async fn test_semi_join_1k() { .await } +// The test is flaky +// https://github.com/apache/datafusion/issues/10886 +#[ignore] #[tokio::test] async fn test_semi_join_1k_filtered() { JoinFuzzTestCase::new( diff --git a/datafusion/core/tests/fuzz_cases/window_fuzz.rs b/datafusion/core/tests/fuzz_cases/window_fuzz.rs index 4358691ee5a5..5bd19850cacc 100644 --- a/datafusion/core/tests/fuzz_cases/window_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/window_fuzz.rs @@ -252,6 +252,7 @@ async fn bounded_window_causal_non_causal() -> Result<()> { let partitionby_exprs = vec![]; let orderby_exprs = vec![]; + let logical_exprs = vec![]; // Window frame starts with "UNBOUNDED PRECEDING": let 
start_bound = WindowFrameBound::Preceding(ScalarValue::UInt64(None)); @@ -283,6 +284,7 @@ async fn bounded_window_causal_non_causal() -> Result<()> { &window_fn, fn_name.to_string(), &args, + &logical_exprs, &partitionby_exprs, &orderby_exprs, Arc::new(window_frame), @@ -699,6 +701,7 @@ async fn run_window_test( &window_fn, fn_name.clone(), &args, + &[], &partitionby_exprs, &orderby_exprs, Arc::new(window_frame.clone()), @@ -717,6 +720,7 @@ async fn run_window_test( &window_fn, fn_name, &args, + &[], &partitionby_exprs, &orderby_exprs, Arc::new(window_frame.clone()), diff --git a/datafusion/core/tests/parquet/arrow_statistics.rs b/datafusion/core/tests/parquet/arrow_statistics.rs index 2ea18d7cf823..6b8705441d12 100644 --- a/datafusion/core/tests/parquet/arrow_statistics.rs +++ b/datafusion/core/tests/parquet/arrow_statistics.rs @@ -18,6 +18,7 @@ //! This file contains an end to end test of extracting statistics from parquet files. //! It writes data into a parquet file, reads statistics and verifies they are correct +use std::default::Default; use std::fs::File; use std::sync::Arc; @@ -39,102 +40,102 @@ use arrow_array::{ use arrow_schema::{DataType, Field, Schema}; use datafusion::datasource::physical_plan::parquet::StatisticsConverter; use half::f16; -use parquet::arrow::arrow_reader::{ArrowReaderBuilder, ParquetRecordBatchReaderBuilder}; +use parquet::arrow::arrow_reader::{ + ArrowReaderBuilder, ArrowReaderOptions, ParquetRecordBatchReaderBuilder, +}; use parquet::arrow::ArrowWriter; use parquet::file::properties::{EnabledStatistics, WriterProperties}; use super::make_test_file_rg; -// TEST HELPERS - -/// Return a record batch with i64 with Null values -fn make_int64_batches_with_null( +#[derive(Debug, Default, Clone)] +struct Int64Case { + /// Number of nulls in the column null_values: usize, + /// Non null values in the range `[no_null_values_start, + /// no_null_values_end]`, one value for each row no_null_values_start: i64, no_null_values_end: i64, -) -> RecordBatch { - let schema = Arc::new(Schema::new(vec![Field::new("i64", DataType::Int64, true)])); - - let v64: Vec<i64> = (no_null_values_start as _..no_null_values_end as _).collect(); - - RecordBatch::try_new( - schema, - vec![make_array( - Int64Array::from_iter( - v64.into_iter() - .map(Some) - .chain(std::iter::repeat(None).take(null_values)), - ) - .to_data(), - )], - ) - .unwrap() -} - -// Create a parquet file with one column for data type i64 -// Data of the file include -// . Number of null rows is the given num_null -// . There are non-null values in the range [no_null_values_start, no_null_values_end], one value each row -// . The file is divided into row groups of size row_per_group -pub fn parquet_file_one_column( - num_null: usize, - no_null_values_start: i64, - no_null_values_end: i64, + /// Number of rows per row group row_per_group: usize, -) -> ParquetRecordBatchReaderBuilder<File> { - parquet_file_one_column_stats( - num_null, - no_null_values_start, - no_null_values_end, - row_per_group, - EnabledStatistics::Chunk, - ) + /// if specified, overrides default statistics settings + enable_stats: Option<EnabledStatistics>, + /// If specified, the number of values in each page + data_page_row_count_limit: Option<usize>, } -// Create a parquet file with one column for data type i64 -// Data of the file include -// . Number of null rows is the given num_null -// . There are non-null values in the range [no_null_values_start, no_null_values_end], one value each row -// . The file is divided into row groups of size row_per_group -// . 
Statistics are enabled/disabled based on the given enable_stats -pub fn parquet_file_one_column_stats( - num_null: usize, - no_null_values_start: i64, - no_null_values_end: i64, - row_per_group: usize, - enable_stats: EnabledStatistics, -) -> ParquetRecordBatchReaderBuilder<File> { - let mut output_file = tempfile::Builder::new() - .prefix("parquert_statistics_test") - .suffix(".parquet") - .tempfile() - .expect("tempfile creation"); - - let props = WriterProperties::builder() - .set_max_row_group_size(row_per_group) - .set_statistics_enabled(enable_stats) - .build(); - - let batches = vec![make_int64_batches_with_null( - num_null, - no_null_values_start, - no_null_values_end, - )]; - - let schema = batches[0].schema(); - - let mut writer = ArrowWriter::try_new(&mut output_file, schema, Some(props)).unwrap(); +impl Int64Case { + /// Return a record batch with i64 with Null values + /// The first no_null_values_end - no_null_values_start values + /// are non-null with the specified range, the rest are null + fn make_int64_batches_with_null(&self) -> RecordBatch { + let schema = + Arc::new(Schema::new(vec![Field::new("i64", DataType::Int64, true)])); + + let v64: Vec<i64> = + (self.no_null_values_start as _..self.no_null_values_end as _).collect(); + + RecordBatch::try_new( + schema, + vec![make_array( + Int64Array::from_iter( + v64.into_iter() + .map(Some) + .chain(std::iter::repeat(None).take(self.null_values)), + ) + .to_data(), + )], + ) + .unwrap() + } + + // Create a parquet file with the specified settings + pub fn build(&self) -> ParquetRecordBatchReaderBuilder<File> { + let mut output_file = tempfile::Builder::new() + .prefix("parquert_statistics_test") + .suffix(".parquet") + .tempfile() + .expect("tempfile creation"); + + let mut builder = + WriterProperties::builder().set_max_row_group_size(self.row_per_group); + if let Some(enable_stats) = self.enable_stats { + builder = builder.set_statistics_enabled(enable_stats); + } + if let Some(data_page_row_count_limit) = self.data_page_row_count_limit { + builder = builder.set_data_page_row_count_limit(data_page_row_count_limit); + } + let props = builder.build(); + + let batches = vec![self.make_int64_batches_with_null()]; + + let schema = batches[0].schema(); + + let mut writer = + ArrowWriter::try_new(&mut output_file, schema, Some(props)).unwrap(); + + // if we have a datapage limit, write the rows one at a time to give + // the writer a chance to split the data into multiple pages + if self.data_page_row_count_limit.is_some() { + for batch in batches { + for i in 0..batch.num_rows() { + writer.write(&batch.slice(i, 1)).expect("writing batch"); + } + } + } else { + for batch in batches { + writer.write(&batch).expect("writing batch"); + } + } + + // close file + let _file_meta = writer.close().unwrap(); - for batch in batches { - writer.write(&batch).expect("writing batch"); + // open the file & get the reader + let file = output_file.reopen().unwrap(); + let options = ArrowReaderOptions::new().with_page_index(true); + ArrowReaderBuilder::try_new_with_options(file, options).unwrap() } - - // close file - let _file_meta = writer.close().unwrap(); - - // open the file & get the reader - let file = output_file.reopen().unwrap(); - ArrowReaderBuilder::try_new(file).unwrap() } /// Defines what data to create in a parquet file @@ -158,7 +159,38 @@ impl TestReader { // open the file & get the reader let file = file.reopen().unwrap(); - ArrowReaderBuilder::try_new(file).unwrap() + let options = ArrowReaderOptions::new().with_page_index(true); + ArrowReaderBuilder::try_new_with_options(file, options).unwrap() + } +}
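Note the `with_page_index(true)` above: the parquet reader only decodes the column and offset indexes when asked, so without it `metadata().column_index()` returns `None` and the data page checks below would have nothing to read. A sketch of that setup in isolation (the file path is hypothetical):

    use std::fs::File;
    use parquet::arrow::arrow_reader::{ArrowReaderOptions, ParquetRecordBatchReaderBuilder};

    fn main() -> Result<(), Box<dyn std::error::Error>> {
        let file = File::open("data.parquet")?; // hypothetical input file
        // Request the page index so column_index()/offset_index() are populated.
        let options = ArrowReaderOptions::new().with_page_index(true);
        let builder =
            ParquetRecordBatchReaderBuilder::try_new_with_options(file, options)?;
        println!(
            "column index present: {}",
            builder.metadata().column_index().is_some()
        );
        Ok(())
    }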
 + +/// Which statistics should we check? +#[derive(Clone, Debug, Copy)] +enum Check { + /// Extract and check row group statistics + RowGroup, + /// Extract and check data page statistics + DataPage, + /// Extract and check both row group and data page statistics. + /// + /// Note if a row group contains a single data page, + /// the statistics for row groups and data pages are the same. + Both, +} + +impl Check { + fn row_group(&self) -> bool { + match self { + Self::RowGroup | Self::Both => true, + Self::DataPage => false, + } + } + + fn data_page(&self) -> bool { + match self { + Self::DataPage | Self::Both => true, + Self::RowGroup => false, + } } } @@ -172,6 +204,8 @@ struct Test<'a> { expected_row_counts: UInt64Array, /// Which column to extract statistics from column_name: &'static str, + /// What statistics should be checked? + check: Check, } impl<'a> Test<'a> { @@ -183,6 +217,7 @@ expected_null_counts, expected_row_counts, column_name, + check, } = self; let converter = StatisticsConverter::try_new( @@ -193,36 +228,105 @@ .unwrap(); let row_groups = reader.metadata().row_groups(); - let min = converter.row_group_mins(row_groups).unwrap(); - - assert_eq!( - &min, &expected_min, - "{column_name}: Mismatch with expected minimums" - ); - - let max = converter.row_group_maxes(row_groups).unwrap(); - assert_eq!( - &max, &expected_max, - "{column_name}: Mismatch with expected maximum" - ); - - let null_counts = converter.row_group_null_counts(row_groups).unwrap(); let expected_null_counts = Arc::new(expected_null_counts) as ArrayRef; - assert_eq!( - &null_counts, &expected_null_counts, - "{column_name}: Mismatch with expected null counts. \ - Actual: {null_counts:?}. Expected: {expected_null_counts:?}" - ); - let row_counts = StatisticsConverter::row_group_row_counts( - reader.metadata().row_groups().iter(), - ) - .unwrap(); - assert_eq!( - row_counts, expected_row_counts, - "{column_name}: Mismatch with expected row counts. \ - Actual: {row_counts:?}. Expected: {expected_row_counts:?}" - ); + if check.data_page() { + let column_page_index = reader + .metadata() + .column_index() + .expect("File should have column page indices"); + + let column_offset_index = reader + .metadata() + .offset_index() + .expect("File should have column offset indices"); + + let row_group_indices = row_groups + .iter() + .enumerate() + .map(|(i, _)| i) + .collect::<Vec<_>>(); + + let min = converter + .data_page_mins( + column_page_index, + column_offset_index, + &row_group_indices, + ) + .unwrap(); + assert_eq!( + &min, &expected_min, + "{column_name}: Mismatch with expected data page minimums" + ); + + let max = converter + .data_page_maxes( + column_page_index, + column_offset_index, + &row_group_indices, + ) + .unwrap(); + assert_eq!( + &max, &expected_max, + "{column_name}: Mismatch with expected data page maximum" + ); + + let null_counts = converter + .data_page_null_counts( + column_page_index, + column_offset_index, + &row_group_indices, + ) + .unwrap(); + + assert_eq!( + &null_counts, &expected_null_counts, + "{column_name}: Mismatch with expected data page null counts. \ Actual: {null_counts:?}. 
Expected: {expected_null_counts:?}" + ); + + let row_counts = converter + .data_page_row_counts(column_offset_index, row_groups, &row_group_indices) + .unwrap(); + // https://github.com/apache/datafusion/issues/10926 + let expected_row_counts: ArrayRef = Arc::new(expected_row_counts.clone()); + assert_eq!( + &row_counts, &expected_row_counts, + "{column_name}: Mismatch with expected row counts. \ + Actual: {row_counts:?}. Expected: {expected_row_counts:?}" + ); + } + + if check.row_group() { + let min = converter.row_group_mins(row_groups).unwrap(); + assert_eq!( + &min, &expected_min, + "{column_name}: Mismatch with expected minimums" + ); + + let max = converter.row_group_maxes(row_groups).unwrap(); + assert_eq!( + &max, &expected_max, + "{column_name}: Mismatch with expected maximum" + ); + + let null_counts = converter.row_group_null_counts(row_groups).unwrap(); + assert_eq!( + &null_counts, &expected_null_counts, + "{column_name}: Mismatch with expected null counts. \ + Actual: {null_counts:?}. Expected: {expected_null_counts:?}" + ); + + let row_counts = StatisticsConverter::row_group_row_counts( + reader.metadata().row_groups().iter(), + ) + .unwrap(); + assert_eq!( + row_counts, expected_row_counts, + "{column_name}: Mismatch with expected row counts. \ + Actual: {row_counts:?}. Expected: {expected_row_counts:?}" + ); + } } /// Run the test and expect a column not found error @@ -234,6 +338,7 @@ impl<'a> Test<'a> { expected_null_counts: _, expected_row_counts: _, column_name, + .. } = self; let converter = StatisticsConverter::try_new( @@ -254,8 +359,15 @@ impl<'a> Test<'a> { #[tokio::test] async fn test_one_row_group_without_null() { - let row_per_group = 20; - let reader = parquet_file_one_column(0, 4, 7, row_per_group); + let reader = Int64Case { + null_values: 0, + no_null_values_start: 4, + no_null_values_end: 7, + row_per_group: 20, + ..Default::default() + } + .build(); + Test { reader: &reader, // min is 4 @@ -267,14 +379,21 @@ async fn test_one_row_group_without_null() { // 3 rows expected_row_counts: UInt64Array::from(vec![3]), column_name: "i64", + check: Check::RowGroup, } .run() } #[tokio::test] async fn test_one_row_group_with_null_and_negative() { - let row_per_group = 20; - let reader = parquet_file_one_column(2, -1, 5, row_per_group); + let reader = Int64Case { + null_values: 2, + no_null_values_start: -1, + no_null_values_end: 5, + row_per_group: 20, + ..Default::default() + } + .build(); Test { reader: &reader, @@ -287,14 +406,21 @@ async fn test_one_row_group_with_null_and_negative() { // 8 rows expected_row_counts: UInt64Array::from(vec![8]), column_name: "i64", + check: Check::RowGroup, } .run() } #[tokio::test] async fn test_two_row_group_with_null() { - let row_per_group = 10; - let reader = parquet_file_one_column(2, 4, 17, row_per_group); + let reader = Int64Case { + null_values: 2, + no_null_values_start: 4, + no_null_values_end: 17, + row_per_group: 10, + ..Default::default() + } + .build(); Test { reader: &reader, @@ -307,14 +433,21 @@ async fn test_two_row_group_with_null() { // row counts are [10, 5] expected_row_counts: UInt64Array::from(vec![10, 5]), column_name: "i64", + check: Check::RowGroup, } .run() } #[tokio::test] async fn test_two_row_groups_with_all_nulls_in_one() { - let row_per_group = 5; - let reader = parquet_file_one_column(4, -2, 2, row_per_group); + let reader = Int64Case { + null_values: 4, + no_null_values_start: -2, + no_null_values_end: 2, + row_per_group: 5, + ..Default::default() + } + .build(); Test { reader: &reader, @@ 
-327,6 +460,38 @@ async fn test_two_row_groups_with_all_nulls_in_one() { // row counts are [5, 3] expected_row_counts: UInt64Array::from(vec![5, 3]), column_name: "i64", + check: Check::RowGroup, + } + .run() +} + +#[tokio::test] +async fn test_multiple_data_pages_nulls_and_negatives() { + let reader = Int64Case { + null_values: 3, + no_null_values_start: -1, + no_null_values_end: 10, + row_per_group: 20, + // limit page row count to 4 + data_page_row_count_limit: Some(4), + enable_stats: Some(EnabledStatistics::Page), + } + .build(); + + // Data layout looks like this: + // + // page 0: [-1, 0, 1, 2] + // page 1: [3, 4, 5, 6] + // page 2: [7, 8, 9, null] + // page 3: [null, null] + Test { + reader: &reader, + expected_min: Arc::new(Int64Array::from(vec![Some(-1), Some(3), Some(7), None])), + expected_max: Arc::new(Int64Array::from(vec![Some(2), Some(6), Some(9), None])), + expected_null_counts: UInt64Array::from(vec![0, 0, 1, 2]), + expected_row_counts: UInt64Array::from(vec![4, 4, 4, 2]), + column_name: "i64", + check: Check::DataPage, } .run() } @@ -347,6 +512,7 @@ async fn test_int_64() { .build() .await; + // since each row has only one data page, the statistics are the same Test { reader: &reader, // mins are [-5, -4, 0, 5] @@ -358,6 +524,7 @@ async fn test_int_64() { // row counts are [5, 5, 5, 5] expected_row_counts: UInt64Array::from(vec![5, 5, 5, 5]), column_name: "i64", + check: Check::Both, } .run(); } @@ -383,6 +550,7 @@ async fn test_int_32() { // row counts are [5, 5, 5, 5] expected_row_counts: UInt64Array::from(vec![5, 5, 5, 5]), column_name: "i32", + check: Check::RowGroup, } .run(); } @@ -423,6 +591,7 @@ async fn test_int_16() { // row counts are [5, 5, 5, 5] expected_row_counts: UInt64Array::from(vec![5, 5, 5, 5]), column_name: "i16", + check: Check::RowGroup, } .run(); } @@ -451,6 +620,7 @@ async fn test_int_8() { // row counts are [5, 5, 5, 5] expected_row_counts: UInt64Array::from(vec![5, 5, 5, 5]), column_name: "i8", + check: Check::RowGroup, } .run(); } @@ -500,6 +670,7 @@ async fn test_timestamp() { // row counts are [5, 5, 5, 5] expected_row_counts: UInt64Array::from(vec![5, 5, 5, 5]), column_name: "nanos", + check: Check::RowGroup, } .run(); @@ -528,6 +699,7 @@ async fn test_timestamp() { // row counts are [5, 5, 5, 5] expected_row_counts: UInt64Array::from(vec![5, 5, 5, 5]), column_name: "nanos_timezoned", + check: Check::RowGroup, } .run(); @@ -549,6 +721,7 @@ async fn test_timestamp() { expected_null_counts: UInt64Array::from(vec![1, 1, 1, 1]), expected_row_counts: UInt64Array::from(vec![5, 5, 5, 5]), column_name: "micros", + check: Check::RowGroup, } .run(); @@ -577,6 +750,7 @@ async fn test_timestamp() { // row counts are [5, 5, 5, 5] expected_row_counts: UInt64Array::from(vec![5, 5, 5, 5]), column_name: "micros_timezoned", + check: Check::RowGroup, } .run(); @@ -598,6 +772,7 @@ async fn test_timestamp() { expected_null_counts: UInt64Array::from(vec![1, 1, 1, 1]), expected_row_counts: UInt64Array::from(vec![5, 5, 5, 5]), column_name: "millis", + check: Check::RowGroup, } .run(); @@ -626,6 +801,7 @@ async fn test_timestamp() { // row counts are [5, 5, 5, 5] expected_row_counts: UInt64Array::from(vec![5, 5, 5, 5]), column_name: "millis_timezoned", + check: Check::RowGroup, } .run(); @@ -647,6 +823,7 @@ async fn test_timestamp() { expected_null_counts: UInt64Array::from(vec![1, 1, 1, 1]), expected_row_counts: UInt64Array::from(vec![5, 5, 5, 5]), column_name: "seconds", + check: Check::RowGroup, } .run(); @@ -675,6 +852,7 @@ async fn test_timestamp() { // row 
counts are [5, 5, 5, 5] expected_row_counts: UInt64Array::from(vec![5, 5, 5, 5]), column_name: "seconds_timezoned", + check: Check::RowGroup, } .run(); } @@ -720,6 +898,7 @@ async fn test_timestamp_diff_rg_sizes() { // row counts are [8, 8, 4] expected_row_counts: UInt64Array::from(vec![8, 8, 4]), column_name: "nanos", + check: Check::RowGroup, } .run(); @@ -746,6 +925,7 @@ async fn test_timestamp_diff_rg_sizes() { // row counts are [8, 8, 4] expected_row_counts: UInt64Array::from(vec![8, 8, 4]), column_name: "nanos_timezoned", + check: Check::RowGroup, } .run(); @@ -765,6 +945,7 @@ async fn test_timestamp_diff_rg_sizes() { expected_null_counts: UInt64Array::from(vec![1, 2, 1]), expected_row_counts: UInt64Array::from(vec![8, 8, 4]), column_name: "micros", + check: Check::RowGroup, } .run(); @@ -791,6 +972,7 @@ async fn test_timestamp_diff_rg_sizes() { // row counts are [8, 8, 4] expected_row_counts: UInt64Array::from(vec![8, 8, 4]), column_name: "micros_timezoned", + check: Check::RowGroup, } .run(); @@ -810,6 +992,7 @@ async fn test_timestamp_diff_rg_sizes() { expected_null_counts: UInt64Array::from(vec![1, 2, 1]), expected_row_counts: UInt64Array::from(vec![8, 8, 4]), column_name: "millis", + check: Check::RowGroup, } .run(); @@ -836,6 +1019,7 @@ async fn test_timestamp_diff_rg_sizes() { // row counts are [8, 8, 4] expected_row_counts: UInt64Array::from(vec![8, 8, 4]), column_name: "millis_timezoned", + check: Check::RowGroup, } .run(); @@ -855,6 +1039,7 @@ async fn test_timestamp_diff_rg_sizes() { expected_null_counts: UInt64Array::from(vec![1, 2, 1]), expected_row_counts: UInt64Array::from(vec![8, 8, 4]), column_name: "seconds", + check: Check::RowGroup, } .run(); @@ -881,6 +1066,7 @@ async fn test_timestamp_diff_rg_sizes() { // row counts are [8, 8, 4] expected_row_counts: UInt64Array::from(vec![8, 8, 4]), column_name: "seconds_timezoned", + check: Check::RowGroup, } .run(); } @@ -918,6 +1104,7 @@ async fn test_dates_32_diff_rg_sizes() { // row counts are [13, 7] expected_row_counts: UInt64Array::from(vec![13, 7]), column_name: "date32", + check: Check::RowGroup, } .run(); } @@ -940,6 +1127,7 @@ async fn test_time32_second_diff_rg_sizes() { expected_null_counts: UInt64Array::from(vec![0, 0, 0, 0]), // Assuming 1 null per row group for simplicity expected_row_counts: UInt64Array::from(vec![4, 4, 4, 4]), column_name: "second", + check: Check::RowGroup, } .run(); } @@ -966,6 +1154,7 @@ async fn test_time32_millisecond_diff_rg_sizes() { expected_null_counts: UInt64Array::from(vec![0, 0, 0, 0]), // Assuming 1 null per row group for simplicity expected_row_counts: UInt64Array::from(vec![4, 4, 4, 4]), column_name: "millisecond", + check: Check::RowGroup, } .run(); } @@ -998,6 +1187,7 @@ async fn test_time64_microsecond_diff_rg_sizes() { expected_null_counts: UInt64Array::from(vec![0, 0, 0, 0]), // Assuming 1 null per row group for simplicity expected_row_counts: UInt64Array::from(vec![4, 4, 4, 4]), column_name: "microsecond", + check: Check::RowGroup, } .run(); } @@ -1030,6 +1220,7 @@ async fn test_time64_nanosecond_diff_rg_sizes() { expected_null_counts: UInt64Array::from(vec![0, 0, 0, 0]), // Assuming 1 null per row group for simplicity expected_row_counts: UInt64Array::from(vec![4, 4, 4, 4]), column_name: "nanosecond", + check: Check::RowGroup, } .run(); } @@ -1056,6 +1247,7 @@ async fn test_dates_64_diff_rg_sizes() { expected_null_counts: UInt64Array::from(vec![2, 2]), expected_row_counts: UInt64Array::from(vec![13, 7]), column_name: "date64", + check: Check::RowGroup, } .run(); } @@ 
-1083,6 +1275,7 @@ async fn test_uint() { expected_null_counts: UInt64Array::from(vec![0, 0, 0, 0, 0]), expected_row_counts: UInt64Array::from(vec![4, 4, 4, 4, 4]), column_name: "u8", + check: Check::RowGroup, } .run(); @@ -1093,6 +1286,7 @@ async fn test_uint() { expected_null_counts: UInt64Array::from(vec![0, 0, 0, 0, 0]), expected_row_counts: UInt64Array::from(vec![4, 4, 4, 4, 4]), column_name: "u16", + check: Check::RowGroup, } .run(); @@ -1103,6 +1297,7 @@ async fn test_uint() { expected_null_counts: UInt64Array::from(vec![0, 0, 0, 0, 0]), expected_row_counts: UInt64Array::from(vec![4, 4, 4, 4, 4]), column_name: "u32", + check: Check::RowGroup, } .run(); @@ -1113,6 +1308,7 @@ async fn test_uint() { expected_null_counts: UInt64Array::from(vec![0, 0, 0, 0, 0]), expected_row_counts: UInt64Array::from(vec![4, 4, 4, 4, 4]), column_name: "u64", + check: Check::RowGroup, } .run(); } @@ -1135,6 +1331,7 @@ async fn test_int32_range() { expected_null_counts: UInt64Array::from(vec![0]), expected_row_counts: UInt64Array::from(vec![4]), column_name: "i", + check: Check::RowGroup, } .run(); } @@ -1157,6 +1354,7 @@ async fn test_uint32_range() { expected_null_counts: UInt64Array::from(vec![0]), expected_row_counts: UInt64Array::from(vec![4]), column_name: "u", + check: Check::RowGroup, } .run(); } @@ -1178,6 +1376,7 @@ async fn test_numeric_limits_unsigned() { expected_null_counts: UInt64Array::from(vec![0, 0]), expected_row_counts: UInt64Array::from(vec![5, 2]), column_name: "u8", + check: Check::RowGroup, } .run(); @@ -1188,6 +1387,7 @@ async fn test_numeric_limits_unsigned() { expected_null_counts: UInt64Array::from(vec![0, 0]), expected_row_counts: UInt64Array::from(vec![5, 2]), column_name: "u16", + check: Check::RowGroup, } .run(); @@ -1198,6 +1398,7 @@ async fn test_numeric_limits_unsigned() { expected_null_counts: UInt64Array::from(vec![0, 0]), expected_row_counts: UInt64Array::from(vec![5, 2]), column_name: "u32", + check: Check::RowGroup, } .run(); @@ -1208,6 +1409,7 @@ async fn test_numeric_limits_unsigned() { expected_null_counts: UInt64Array::from(vec![0, 0]), expected_row_counts: UInt64Array::from(vec![5, 2]), column_name: "u64", + check: Check::RowGroup, } .run(); } @@ -1229,6 +1431,7 @@ async fn test_numeric_limits_signed() { expected_null_counts: UInt64Array::from(vec![0, 0]), expected_row_counts: UInt64Array::from(vec![5, 2]), column_name: "i8", + check: Check::RowGroup, } .run(); @@ -1239,6 +1442,7 @@ async fn test_numeric_limits_signed() { expected_null_counts: UInt64Array::from(vec![0, 0]), expected_row_counts: UInt64Array::from(vec![5, 2]), column_name: "i16", + check: Check::RowGroup, } .run(); @@ -1249,6 +1453,7 @@ async fn test_numeric_limits_signed() { expected_null_counts: UInt64Array::from(vec![0, 0]), expected_row_counts: UInt64Array::from(vec![5, 2]), column_name: "i32", + check: Check::RowGroup, } .run(); @@ -1259,6 +1464,7 @@ async fn test_numeric_limits_signed() { expected_null_counts: UInt64Array::from(vec![0, 0]), expected_row_counts: UInt64Array::from(vec![5, 2]), column_name: "i64", + check: Check::RowGroup, } .run(); } @@ -1280,6 +1486,7 @@ async fn test_numeric_limits_float() { expected_null_counts: UInt64Array::from(vec![0, 0]), expected_row_counts: UInt64Array::from(vec![5, 2]), column_name: "f32", + check: Check::RowGroup, } .run(); @@ -1290,6 +1497,7 @@ async fn test_numeric_limits_float() { expected_null_counts: UInt64Array::from(vec![0, 0]), expected_row_counts: UInt64Array::from(vec![5, 2]), column_name: "f64", + check: Check::RowGroup, } .run(); @@ 
-1300,6 +1508,7 @@ async fn test_numeric_limits_float() { expected_null_counts: UInt64Array::from(vec![0, 0]), expected_row_counts: UInt64Array::from(vec![5, 2]), column_name: "f32_nan", + check: Check::RowGroup, } .run(); @@ -1310,6 +1519,7 @@ async fn test_numeric_limits_float() { expected_null_counts: UInt64Array::from(vec![0, 0]), expected_row_counts: UInt64Array::from(vec![5, 2]), column_name: "f64_nan", + check: Check::RowGroup, } .run(); } @@ -1332,6 +1542,7 @@ async fn test_float64() { expected_null_counts: UInt64Array::from(vec![0, 0, 0, 0]), expected_row_counts: UInt64Array::from(vec![5, 5, 5, 5]), column_name: "f", + check: Check::RowGroup, } .run(); } @@ -1364,6 +1575,7 @@ async fn test_float16() { expected_null_counts: UInt64Array::from(vec![0, 0, 0, 0]), expected_row_counts: UInt64Array::from(vec![5, 5, 5, 5]), column_name: "f", + check: Check::RowGroup, } .run(); } @@ -1394,6 +1606,7 @@ async fn test_decimal() { expected_null_counts: UInt64Array::from(vec![0, 0, 0]), expected_row_counts: UInt64Array::from(vec![5, 5, 5]), column_name: "decimal_col", + check: Check::RowGroup, } .run(); } @@ -1431,6 +1644,7 @@ async fn test_decimal_256() { expected_null_counts: UInt64Array::from(vec![0, 0, 0]), expected_row_counts: UInt64Array::from(vec![5, 5, 5]), column_name: "decimal256_col", + check: Check::RowGroup, } .run(); } @@ -1450,6 +1664,7 @@ async fn test_dictionary() { expected_null_counts: UInt64Array::from(vec![1, 0]), expected_row_counts: UInt64Array::from(vec![5, 2]), column_name: "string_dict_i8", + check: Check::RowGroup, } .run(); @@ -1460,6 +1675,7 @@ async fn test_dictionary() { expected_null_counts: UInt64Array::from(vec![1, 0]), expected_row_counts: UInt64Array::from(vec![5, 2]), column_name: "string_dict_i32", + check: Check::RowGroup, } .run(); @@ -1470,6 +1686,7 @@ async fn test_dictionary() { expected_null_counts: UInt64Array::from(vec![1, 0]), expected_row_counts: UInt64Array::from(vec![5, 2]), column_name: "int_dict_i8", + check: Check::RowGroup, } .run(); } @@ -1507,6 +1724,7 @@ async fn test_byte() { expected_null_counts: UInt64Array::from(vec![0, 0, 0]), expected_row_counts: UInt64Array::from(vec![5, 5, 5]), column_name: "name", + check: Check::RowGroup, } .run(); @@ -1526,6 +1744,7 @@ async fn test_byte() { expected_null_counts: UInt64Array::from(vec![0, 0, 0]), expected_row_counts: UInt64Array::from(vec![5, 5, 5]), column_name: "service_string", + check: Check::RowGroup, } .run(); @@ -1544,6 +1763,7 @@ async fn test_byte() { expected_null_counts: UInt64Array::from(vec![0, 0, 0]), expected_row_counts: UInt64Array::from(vec![5, 5, 5]), column_name: "service_binary", + check: Check::RowGroup, } .run(); @@ -1564,6 +1784,7 @@ async fn test_byte() { expected_null_counts: UInt64Array::from(vec![0, 0, 0]), expected_row_counts: UInt64Array::from(vec![5, 5, 5]), column_name: "service_fixedsize", + check: Check::RowGroup, } .run(); @@ -1584,6 +1805,7 @@ async fn test_byte() { expected_null_counts: UInt64Array::from(vec![0, 0, 0]), expected_row_counts: UInt64Array::from(vec![5, 5, 5]), column_name: "service_large_binary", + check: Check::RowGroup, } .run(); } @@ -1616,6 +1838,7 @@ async fn test_period_in_column_names() { expected_null_counts: UInt64Array::from(vec![0, 0, 0]), expected_row_counts: UInt64Array::from(vec![5, 5, 5]), column_name: "name", + check: Check::RowGroup, } .run(); @@ -1629,6 +1852,7 @@ async fn test_period_in_column_names() { expected_null_counts: UInt64Array::from(vec![0, 0, 0]), expected_row_counts: UInt64Array::from(vec![5, 5, 5]), column_name: 
"service.name", + check: Check::RowGroup, } .run(); } @@ -1652,6 +1876,7 @@ async fn test_boolean() { expected_null_counts: UInt64Array::from(vec![1, 0]), expected_row_counts: UInt64Array::from(vec![5, 5]), column_name: "bool", + check: Check::RowGroup, } .run(); } @@ -1678,6 +1903,7 @@ async fn test_struct() { expected_null_counts: UInt64Array::from(vec![0]), expected_row_counts: UInt64Array::from(vec![3]), column_name: "struct", + check: Check::RowGroup, } .run(); } @@ -1700,6 +1926,7 @@ async fn test_utf8() { expected_null_counts: UInt64Array::from(vec![1, 0]), expected_row_counts: UInt64Array::from(vec![5, 5]), column_name: "utf8", + check: Check::RowGroup, } .run(); @@ -1711,6 +1938,7 @@ async fn test_utf8() { expected_null_counts: UInt64Array::from(vec![1, 0]), expected_row_counts: UInt64Array::from(vec![5, 5]), column_name: "large_utf8", + check: Check::RowGroup, } .run(); } @@ -1719,9 +1947,15 @@ async fn test_utf8() { #[tokio::test] async fn test_missing_statistics() { - let row_per_group = 5; - let reader = - parquet_file_one_column_stats(0, 4, 7, row_per_group, EnabledStatistics::None); + let reader = Int64Case { + null_values: 0, + no_null_values_start: 4, + no_null_values_end: 7, + row_per_group: 5, + enable_stats: Some(EnabledStatistics::None), + ..Default::default() + } + .build(); Test { reader: &reader, @@ -1730,6 +1964,7 @@ async fn test_missing_statistics() { expected_null_counts: UInt64Array::from(vec![None]), expected_row_counts: UInt64Array::from(vec![3]), // stil has row count statistics column_name: "i64", + check: Check::RowGroup, } .run(); } @@ -1751,6 +1986,7 @@ async fn test_column_not_found() { expected_null_counts: UInt64Array::from(vec![2, 2]), expected_row_counts: UInt64Array::from(vec![13, 7]), column_name: "not_a_column", + check: Check::RowGroup, } .run_col_not_found(); } diff --git a/datafusion/core/tests/parquet/mod.rs b/datafusion/core/tests/parquet/mod.rs index 9546ab30c9e0..0434a271c32e 100644 --- a/datafusion/core/tests/parquet/mod.rs +++ b/datafusion/core/tests/parquet/mod.rs @@ -43,7 +43,7 @@ use datafusion::{ use datafusion_expr::{Expr, LogicalPlan, LogicalPlanBuilder}; use half::f16; use parquet::arrow::ArrowWriter; -use parquet::file::properties::WriterProperties; +use parquet::file::properties::{EnabledStatistics, WriterProperties}; use std::sync::Arc; use tempfile::NamedTempFile; @@ -1349,6 +1349,7 @@ async fn make_test_file_rg(scenario: Scenario, row_per_group: usize) -> NamedTem let props = WriterProperties::builder() .set_max_row_group_size(row_per_group) .set_bloom_filter_enabled(true) + .set_statistics_enabled(EnabledStatistics::Page) .build(); let batches = create_data_batch(scenario); diff --git a/datafusion/expr/src/aggregate_function.rs b/datafusion/expr/src/aggregate_function.rs index e3d2e6555d5c..441e8953dffc 100644 --- a/datafusion/expr/src/aggregate_function.rs +++ b/datafusion/expr/src/aggregate_function.rs @@ -21,7 +21,7 @@ use std::sync::Arc; use std::{fmt, str::FromStr}; use crate::utils; -use crate::{type_coercion::aggregates::*, Signature, TypeSignature, Volatility}; +use crate::{type_coercion::aggregates::*, Signature, Volatility}; use arrow::datatypes::{DataType, Field}; use datafusion_common::{plan_datafusion_err, plan_err, DataFusionError, Result}; @@ -33,8 +33,6 @@ use strum_macros::EnumIter; // https://datafusion.apache.org/contributor-guide/index.html#how-to-add-a-new-aggregate-function #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash, EnumIter)] pub enum AggregateFunction { - /// Count - Count, /// Minimum 
Min, /// Maximum @@ -47,28 +45,6 @@ pub enum AggregateFunction { NthValue, /// Correlation Correlation, - /// Slope from linear regression - RegrSlope, - /// Intercept from linear regression - RegrIntercept, - /// Number of input rows in which both expressions are not null - RegrCount, - /// R-squared value from linear regression - RegrR2, - /// Average of the independent variable - RegrAvgx, - /// Average of the dependent variable - RegrAvgy, - /// Sum of squares of the independent variable - RegrSXX, - /// Sum of squares of the dependent variable - RegrSYY, - /// Sum of products of pairs of numbers - RegrSXY, - /// Approximate continuous percentile function - ApproxPercentileCont, - /// Approximate continuous percentile function with weight - ApproxPercentileContWithWeight, /// Grouping Grouping, /// Bit And @@ -89,24 +65,12 @@ impl AggregateFunction { pub fn name(&self) -> &str { use AggregateFunction::*; match self { - Count => "COUNT", Min => "MIN", Max => "MAX", Avg => "AVG", ArrayAgg => "ARRAY_AGG", NthValue => "NTH_VALUE", Correlation => "CORR", - RegrSlope => "REGR_SLOPE", - RegrIntercept => "REGR_INTERCEPT", - RegrCount => "REGR_COUNT", - RegrR2 => "REGR_R2", - RegrAvgx => "REGR_AVGX", - RegrAvgy => "REGR_AVGY", - RegrSXX => "REGR_SXX", - RegrSYY => "REGR_SYY", - RegrSXY => "REGR_SXY", - ApproxPercentileCont => "APPROX_PERCENTILE_CONT", - ApproxPercentileContWithWeight => "APPROX_PERCENTILE_CONT_WITH_WEIGHT", Grouping => "GROUPING", BitAnd => "BIT_AND", BitOr => "BIT_OR", @@ -135,7 +99,6 @@ impl FromStr for AggregateFunction { "bit_xor" => AggregateFunction::BitXor, "bool_and" => AggregateFunction::BoolAnd, "bool_or" => AggregateFunction::BoolOr, - "count" => AggregateFunction::Count, "max" => AggregateFunction::Max, "mean" => AggregateFunction::Avg, "min" => AggregateFunction::Min, @@ -144,20 +107,6 @@ impl FromStr for AggregateFunction { "string_agg" => AggregateFunction::StringAgg, // statistical "corr" => AggregateFunction::Correlation, - "regr_slope" => AggregateFunction::RegrSlope, - "regr_intercept" => AggregateFunction::RegrIntercept, - "regr_count" => AggregateFunction::RegrCount, - "regr_r2" => AggregateFunction::RegrR2, - "regr_avgx" => AggregateFunction::RegrAvgx, - "regr_avgy" => AggregateFunction::RegrAvgy, - "regr_sxx" => AggregateFunction::RegrSXX, - "regr_syy" => AggregateFunction::RegrSYY, - "regr_sxy" => AggregateFunction::RegrSXY, - // approximate - "approx_percentile_cont" => AggregateFunction::ApproxPercentileCont, - "approx_percentile_cont_with_weight" => { - AggregateFunction::ApproxPercentileContWithWeight - } // other "grouping" => AggregateFunction::Grouping, _ => { @@ -190,7 +139,6 @@ impl AggregateFunction { })?; match self { - AggregateFunction::Count => Ok(DataType::Int64), AggregateFunction::Max | AggregateFunction::Min => { // For min and max agg function, the returned type is same as input type. // The coerced_data_types is same with input_types. 
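With `Count`, the `regr_*` variants, and the approximate percentile functions removed from the built-in enum, former call sites go through the UDAF registry instead. A minimal sketch of the replacement expression API, using only the `expr_fn` exports this PR adds to `datafusion-functions-aggregate` (the column names are illustrative):

use datafusion_expr::{col, lit};
use datafusion_functions_aggregate::expr_fn::{approx_percentile_cont, count, regr_slope};

// COUNT(a) -- previously AggregateFunction::Count
let c = count(col("a"));
// REGR_SLOPE(y, x) -- previously AggregateFunction::RegrSlope
let slope = regr_slope(col("y"), col("x"));
// APPROX_PERCENTILE_CONT(a, 0.95) -- the percentile must now be a literal,
// validated at planning time (see approx_percentile_cont.rs below)
let p95 = approx_percentile_cont(col("a"), lit(0.95_f64));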
@@ -205,25 +153,12 @@ impl AggregateFunction { AggregateFunction::Correlation => { correlation_return_type(&coerced_data_types[0]) } - AggregateFunction::RegrSlope - | AggregateFunction::RegrIntercept - | AggregateFunction::RegrCount - | AggregateFunction::RegrR2 - | AggregateFunction::RegrAvgx - | AggregateFunction::RegrAvgy - | AggregateFunction::RegrSXX - | AggregateFunction::RegrSYY - | AggregateFunction::RegrSXY => Ok(DataType::Float64), AggregateFunction::Avg => avg_return_type(&coerced_data_types[0]), AggregateFunction::ArrayAgg => Ok(DataType::List(Arc::new(Field::new( "item", coerced_data_types[0].clone(), true, )))), - AggregateFunction::ApproxPercentileCont => Ok(coerced_data_types[0].clone()), - AggregateFunction::ApproxPercentileContWithWeight => { - Ok(coerced_data_types[0].clone()) - } AggregateFunction::Grouping => Ok(DataType::Int32), AggregateFunction::NthValue => Ok(coerced_data_types[0].clone()), AggregateFunction::StringAgg => Ok(DataType::LargeUtf8), @@ -249,7 +184,6 @@ impl AggregateFunction { pub fn signature(&self) -> Signature { // note: the physical expression must accept the type returned by this function or the execution panics. match self { - AggregateFunction::Count => Signature::variadic_any(Volatility::Immutable), AggregateFunction::Grouping | AggregateFunction::ArrayAgg => { Signature::any(1, Volatility::Immutable) } @@ -278,51 +212,9 @@ impl AggregateFunction { Signature::uniform(1, NUMERICS.to_vec(), Volatility::Immutable) } AggregateFunction::NthValue => Signature::any(2, Volatility::Immutable), - AggregateFunction::Correlation - | AggregateFunction::RegrSlope - | AggregateFunction::RegrIntercept - | AggregateFunction::RegrCount - | AggregateFunction::RegrR2 - | AggregateFunction::RegrAvgx - | AggregateFunction::RegrAvgy - | AggregateFunction::RegrSXX - | AggregateFunction::RegrSYY - | AggregateFunction::RegrSXY => { + AggregateFunction::Correlation => { Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable) } - AggregateFunction::ApproxPercentileCont => { - let mut variants = - Vec::with_capacity(NUMERICS.len() * (INTEGERS.len() + 1)); - // Accept any numeric value paired with a float64 percentile - for num in NUMERICS { - variants - .push(TypeSignature::Exact(vec![num.clone(), DataType::Float64])); - // Additionally accept an integer number of centroids for T-Digest - for int in INTEGERS { - variants.push(TypeSignature::Exact(vec![ - num.clone(), - DataType::Float64, - int.clone(), - ])) - } - } - - Signature::one_of(variants, Volatility::Immutable) - } - AggregateFunction::ApproxPercentileContWithWeight => Signature::one_of( - // Accept any numeric value paired with a float64 percentile - NUMERICS - .iter() - .map(|t| { - TypeSignature::Exact(vec![ - t.clone(), - t.clone(), - DataType::Float64, - ]) - }) - .collect(), - Volatility::Immutable, - ), AggregateFunction::StringAgg => { Signature::uniform(2, STRINGS.to_vec(), Volatility::Immutable) } diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 57f5414c13bd..9ba866a4c919 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -2135,18 +2135,6 @@ mod test { use super::*; - #[test] - fn test_count_return_type() -> Result<()> { - let fun = find_df_window_func("count").unwrap(); - let observed = fun.return_type(&[DataType::Utf8])?; - assert_eq!(DataType::Int64, observed); - - let observed = fun.return_type(&[DataType::UInt64])?; - assert_eq!(DataType::Int64, observed); - - Ok(()) - } - #[test] fn test_first_value_return_type() -> Result<()> { let 
fun = find_df_window_func("first_value").unwrap(); @@ -2250,7 +2238,6 @@ mod test { "nth_value", "min", "max", - "count", "avg", ]; for name in names { diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 1fafc63e9665..099851aece46 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -192,19 +192,6 @@ pub fn avg(expr: Expr) -> Expr { )) } -/// Create an expression to represent the count() aggregate function -// TODO: Remove this and use `expr_fn::count` instead -pub fn count(expr: Expr) -> Expr { - Expr::AggregateFunction(AggregateFunction::new( - aggregate_function::AggregateFunction::Count, - vec![expr], - false, - None, - None, - None, - )) -} - /// Return a new expression with bitwise AND pub fn bitwise_and(left: Expr, right: Expr) -> Expr { Expr::BinaryExpr(BinaryExpr::new( @@ -250,52 +237,11 @@ pub fn bitwise_shift_left(left: Expr, right: Expr) -> Expr { )) } -/// Create an expression to represent the count(distinct) aggregate function -// TODO: Remove this and use `expr_fn::count_distinct` instead -pub fn count_distinct(expr: Expr) -> Expr { - Expr::AggregateFunction(AggregateFunction::new( - aggregate_function::AggregateFunction::Count, - vec![expr], - true, - None, - None, - None, - )) -} - /// Create an in_list expression pub fn in_list(expr: Expr, list: Vec, negated: bool) -> Expr { Expr::InList(InList::new(Box::new(expr), list, negated)) } -/// Calculate an approximation of the specified `percentile` for `expr`. -pub fn approx_percentile_cont(expr: Expr, percentile: Expr) -> Expr { - Expr::AggregateFunction(AggregateFunction::new( - aggregate_function::AggregateFunction::ApproxPercentileCont, - vec![expr, percentile], - false, - None, - None, - None, - )) -} - -/// Calculate an approximation of the specified `percentile` for `expr` and `weight_expr`. 
-pub fn approx_percentile_cont_with_weight( - expr: Expr, - weight_expr: Expr, - percentile: Expr, -) -> Expr { - Expr::AggregateFunction(AggregateFunction::new( - aggregate_function::AggregateFunction::ApproxPercentileContWithWeight, - vec![expr, weight_expr, percentile], - false, - None, - None, - None, - )) -} - /// Create an EXISTS subquery expression pub fn exists(subquery: Arc) -> Expr { let outer_ref_columns = subquery.all_out_ref_exprs(); diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index 7ea0313bf776..986f85adebaa 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -141,7 +141,7 @@ impl ExprSchemable for Expr { // verify that function is invoked with correct number and type of arguments as defined in `TypeSignature` data_types_with_scalar_udf(&arg_data_types, func).map_err(|err| { plan_datafusion_err!( - "{} and {}", + "{} {}", err, utils::generate_signature_error_msg( func.name(), @@ -164,7 +164,7 @@ impl ExprSchemable for Expr { WindowFunctionDefinition::AggregateUDF(udf) => { let new_types = data_types_with_aggregate_udf(&data_types, udf).map_err(|err| { plan_datafusion_err!( - "{} and {}", + "{} {}", err, utils::generate_signature_error_msg( fun.name(), @@ -192,7 +192,7 @@ impl ExprSchemable for Expr { AggregateFunctionDefinition::UDF(fun) => { let new_types = data_types_with_aggregate_udf(&data_types, fun).map_err(|err| { plan_datafusion_err!( - "{} and {}", + "{} {}", err, utils::generate_signature_error_msg( fun.name(), diff --git a/datafusion/expr/src/function.rs b/datafusion/expr/src/function.rs index c06f177510e7..169436145aae 100644 --- a/datafusion/expr/src/function.rs +++ b/datafusion/expr/src/function.rs @@ -83,8 +83,8 @@ pub struct AccumulatorArgs<'a> { /// The input type of the aggregate function. pub input_type: &'a DataType, - /// The number of arguments the aggregate function takes. - pub args_num: usize, + /// The logical expression of arguments the aggregate function takes. + pub input_exprs: &'a [Expr], } /// [`StateFieldsArgs`] contains information about the fields that an diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 9ea2abe64ede..02378ab3fc1b 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -2965,11 +2965,13 @@ mod tests { use super::*; use crate::builder::LogicalTableSource; use crate::logical_plan::table_scan; - use crate::{col, count, exists, in_subquery, lit, placeholder, GroupingSet}; + use crate::{col, exists, in_subquery, lit, placeholder, GroupingSet}; use datafusion_common::tree_node::TreeNodeVisitor; use datafusion_common::{not_impl_err, Constraint, ScalarValue}; + use crate::test::function_stub::count; + fn employee_schema() -> Schema { Schema::new(vec![ Field::new("id", DataType::Int32, false), diff --git a/datafusion/expr/src/test/function_stub.rs b/datafusion/expr/src/test/function_stub.rs index b9aa1e636d94..ac98ee9747cc 100644 --- a/datafusion/expr/src/test/function_stub.rs +++ b/datafusion/expr/src/test/function_stub.rs @@ -31,7 +31,7 @@ use crate::{ use arrow::datatypes::{ DataType, Field, DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION, }; -use datafusion_common::{exec_err, Result}; +use datafusion_common::{exec_err, not_impl_err, Result}; macro_rules! 
create_func! { ($UDAF:ty, $AGGREGATE_UDF_FN:ident) => { @@ -69,6 +69,19 @@ pub fn sum(expr: Expr) -> Expr { )) } +create_func!(Count, count_udaf); + +pub fn count(expr: Expr) -> Expr { + Expr::AggregateFunction(AggregateFunction::new_udf( + count_udaf(), + vec![expr], + false, + None, + None, + None, + )) +} + /// Stub `sum` used for optimizer testing #[derive(Debug)] pub struct Sum { @@ -189,3 +202,74 @@ impl AggregateUDFImpl for Sum { AggregateOrderSensitivity::Insensitive } } + +/// Testing stub implementation of COUNT aggregate +pub struct Count { + signature: Signature, + aliases: Vec<String>, +} + +impl std::fmt::Debug for Count { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + f.debug_struct("Count") + .field("name", &self.name()) + .field("signature", &self.signature) + .finish() + } +} + +impl Default for Count { + fn default() -> Self { + Self::new() + } +} + +impl Count { + pub fn new() -> Self { + Self { + aliases: vec!["count".to_string()], + signature: Signature::variadic_any(Volatility::Immutable), + } + } +} + +impl AggregateUDFImpl for Count { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &str { + "COUNT" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> { + Ok(DataType::Int64) + } + + fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<Field>> { + not_impl_err!("no impl for stub") + } + + fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> { + not_impl_err!("no impl for stub") + } + + fn aliases(&self) -> &[String] { + &self.aliases + } + + fn create_groups_accumulator( + &self, + _args: AccumulatorArgs, + ) -> Result<Box<dyn GroupsAccumulator>> { + not_impl_err!("no impl for stub") + } + + fn reverse_expr(&self) -> ReversedUDAF { + ReversedUDAF::Identical + } +} diff --git a/datafusion/expr/src/type_coercion/aggregates.rs b/datafusion/expr/src/type_coercion/aggregates.rs index ab7deaff9885..98324ed6120b 100644 --- a/datafusion/expr/src/type_coercion/aggregates.rs +++ b/datafusion/expr/src/type_coercion/aggregates.rs @@ -17,7 +17,6 @@ use std::ops::Deref; -use super::functions::can_coerce_from; use crate::{AggregateFunction, Signature, TypeSignature}; use arrow::datatypes::{ @@ -96,7 +95,6 @@ pub fn coerce_types( check_arg_count(agg_fun.name(), input_types, &signature.type_signature)?; match agg_fun { - AggregateFunction::Count => Ok(input_types.to_vec()), AggregateFunction::ArrayAgg => Ok(input_types.to_vec()), AggregateFunction::Min | AggregateFunction::Max => { // min and max support the dictionary data type @@ -159,76 +157,6 @@ pub fn coerce_types( } Ok(vec![Float64, Float64]) } - AggregateFunction::RegrSlope - | AggregateFunction::RegrIntercept - | AggregateFunction::RegrCount - | AggregateFunction::RegrR2 - | AggregateFunction::RegrAvgx - | AggregateFunction::RegrAvgy - | AggregateFunction::RegrSXX - | AggregateFunction::RegrSYY - | AggregateFunction::RegrSXY => { - let valid_types = [NUMERICS.to_vec(), vec![Null]].concat(); - let input_types_valid = // number of input already checked before - valid_types.contains(&input_types[0]) && valid_types.contains(&input_types[1]); - if !input_types_valid { - return plan_err!( - "The function {:?} does not support inputs of type {:?}.", - agg_fun, - input_types[0] - ); - } - Ok(vec![Float64, Float64]) - } - AggregateFunction::ApproxPercentileCont => { - if !is_approx_percentile_cont_supported_arg_type(&input_types[0]) { - return plan_err!( - "The function {:?} does not support inputs of type {:?}.", - agg_fun, - input_types[0]
); - } - if input_types.len() == 3 && !input_types[2].is_integer() { - return plan_err!( - "The percentile sample points count for {:?} must be integer, not {:?}.", - agg_fun, input_types[2] - ); - } - let mut result = input_types.to_vec(); - if can_coerce_from(&Float64, &input_types[1]) { - result[1] = Float64; - } else { - return plan_err!( - "Could not coerce the percent argument for {:?} to Float64. Was {:?}.", - agg_fun, input_types[1] - ); - } - Ok(result) - } - AggregateFunction::ApproxPercentileContWithWeight => { - if !is_approx_percentile_cont_supported_arg_type(&input_types[0]) { - return plan_err!( - "The function {:?} does not support inputs of type {:?}.", - agg_fun, - input_types[0] - ); - } - if !is_approx_percentile_cont_supported_arg_type(&input_types[1]) { - return plan_err!( - "The weight argument for {:?} does not support inputs of type {:?}.", - agg_fun, - input_types[1] - ); - } - if !matches!(input_types[2], Float64) { - return plan_err!( - "The percentile argument for {:?} must be Float64, not {:?}.", - agg_fun, - input_types[2] - ); - } - Ok(input_types.to_vec()) - } AggregateFunction::NthValue => Ok(input_types.to_vec()), AggregateFunction::Grouping => Ok(vec![input_types[0].clone()]), AggregateFunction::StringAgg => { @@ -481,15 +409,6 @@ pub fn is_integer_arg_type(arg_type: &DataType) -> bool { arg_type.is_integer() } -/// Return `true` if `arg_type` is of a [`DataType`] that the -/// [`AggregateFunction::ApproxPercentileCont`] aggregation can operate on. -pub fn is_approx_percentile_cont_supported_arg_type(arg_type: &DataType) -> bool { - matches!( - arg_type, - arg_type if NUMERICS.contains(arg_type) - ) -} - /// Return `true` if `arg_type` is of a [`DataType`] that the /// [`AggregateFunction::StringAgg`] aggregation can operate on. pub fn is_string_agg_supported_arg_type(arg_type: &DataType) -> bool { @@ -525,7 +444,6 @@ mod tests { // test count, array_agg, approx_distinct, min, max. 
// the coerced types is same with input types let funs = vec![ - AggregateFunction::Count, AggregateFunction::ArrayAgg, AggregateFunction::Min, AggregateFunction::Max, @@ -555,29 +473,6 @@ mod tests { assert_eq!(r[0], DataType::Decimal128(20, 3)); let r = coerce_types(&fun, &[DataType::Decimal256(20, 3)], &signature).unwrap(); assert_eq!(r[0], DataType::Decimal256(20, 3)); - - // ApproxPercentileCont input types - let input_types = vec![ - vec![DataType::Int8, DataType::Float64], - vec![DataType::Int16, DataType::Float64], - vec![DataType::Int32, DataType::Float64], - vec![DataType::Int64, DataType::Float64], - vec![DataType::UInt8, DataType::Float64], - vec![DataType::UInt16, DataType::Float64], - vec![DataType::UInt32, DataType::Float64], - vec![DataType::UInt64, DataType::Float64], - vec![DataType::Float32, DataType::Float64], - vec![DataType::Float64, DataType::Float64], - ]; - for input_type in &input_types { - let signature = AggregateFunction::ApproxPercentileCont.signature(); - let result = coerce_types( - &AggregateFunction::ApproxPercentileCont, - input_type, - &signature, - ); - assert_eq!(*input_type, result.unwrap()); - } } #[test] diff --git a/datafusion/expr/src/type_coercion/functions.rs b/datafusion/expr/src/type_coercion/functions.rs index 4dd8d6371934..5f060a4a4f16 100644 --- a/datafusion/expr/src/type_coercion/functions.rs +++ b/datafusion/expr/src/type_coercion/functions.rs @@ -49,10 +49,7 @@ pub fn data_types_with_scalar_udf( if signature.type_signature.supports_zero_argument() { return Ok(vec![]); } else { - return plan_err!( - "[data_types_with_scalar_udf] signature {:?} does not support zero arguments.", - &signature.type_signature - ); + return plan_err!("{} does not support zero arguments.", func.name()); } } @@ -79,11 +76,7 @@ pub fn data_types_with_aggregate_udf( if signature.type_signature.supports_zero_argument() { return Ok(vec![]); } else { - return plan_err!( - "[data_types_with_aggregate_udf] Coercion from {:?} to the signature {:?} failed.", - current_types, - &signature.type_signature - ); + return plan_err!("{} does not support zero arguments.", func.name()); } } @@ -118,8 +111,7 @@ pub fn data_types( return Ok(vec![]); } else { return plan_err!( - "[data_types] Coercion from {:?} to the signature {:?} failed.", - current_types, + "signature {:?} does not support zero arguments.", &signature.type_signature ); } diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index 71a3a5fe7309..3ab0c180dcba 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -818,7 +818,7 @@ pub(crate) fn find_column_indexes_referenced_by_expr( } } Expr::Literal(_) => { - indexes.push(std::usize::MAX); + indexes.push(usize::MAX); } _ => {} } diff --git a/datafusion/functions-aggregate/src/approx_median.rs b/datafusion/functions-aggregate/src/approx_median.rs index b8b86d30557a..bc723c862953 100644 --- a/datafusion/functions-aggregate/src/approx_median.rs +++ b/datafusion/functions-aggregate/src/approx_median.rs @@ -28,7 +28,6 @@ use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::type_coercion::aggregates::NUMERICS; use datafusion_expr::utils::format_state_name; use datafusion_expr::{Accumulator, AggregateUDFImpl, Signature, Volatility}; -use datafusion_physical_expr_common::aggregate::utils::down_cast_any_ref; use crate::approx_percentile_cont::ApproxPercentileAccumulator; @@ -118,12 +117,3 @@ impl AggregateUDFImpl for ApproxMedian { ))) } } - -impl PartialEq for ApproxMedian { - fn eq(&self, 
other: &dyn Any) -> bool { - down_cast_any_ref(other) - .downcast_ref::() - .map(|x| self.signature == x.signature) - .unwrap_or(false) - } -} diff --git a/datafusion/functions-aggregate/src/approx_percentile_cont.rs b/datafusion/functions-aggregate/src/approx_percentile_cont.rs index e75417efc684..5ae5684d9cab 100644 --- a/datafusion/functions-aggregate/src/approx_percentile_cont.rs +++ b/datafusion/functions-aggregate/src/approx_percentile_cont.rs @@ -15,6 +15,11 @@ // specific language governing permissions and limitations // under the License. +use std::any::Any; +use std::fmt::{Debug, Formatter}; +use std::sync::Arc; + +use arrow::array::RecordBatch; use arrow::{ array::{ ArrayRef, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, @@ -22,12 +27,238 @@ use arrow::{ }, datatypes::DataType, }; +use arrow_schema::{Field, Schema}; -use datafusion_common::{downcast_value, internal_err, DataFusionError, ScalarValue}; -use datafusion_expr::Accumulator; +use datafusion_common::{ + downcast_value, internal_err, not_impl_err, plan_err, DataFusionError, ScalarValue, +}; +use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; +use datafusion_expr::type_coercion::aggregates::{INTEGERS, NUMERICS}; +use datafusion_expr::utils::format_state_name; +use datafusion_expr::{ + Accumulator, AggregateUDFImpl, ColumnarValue, Expr, Signature, TypeSignature, + Volatility, +}; use datafusion_physical_expr_common::aggregate::tdigest::{ TDigest, TryIntoF64, DEFAULT_MAX_SIZE, }; +use datafusion_physical_expr_common::utils::limited_convert_logical_expr_to_physical_expr; + +make_udaf_expr_and_func!( + ApproxPercentileCont, + approx_percentile_cont, + expression percentile, + "Computes the approximate percentile continuous of a set of numbers", + approx_percentile_cont_udaf +); + +pub struct ApproxPercentileCont { + signature: Signature, +} + +impl Debug for ApproxPercentileCont { + fn fmt(&self, f: &mut Formatter) -> std::fmt::Result { + f.debug_struct("ApproxPercentileCont") + .field("name", &self.name()) + .field("signature", &self.signature) + .finish() + } +} + +impl Default for ApproxPercentileCont { + fn default() -> Self { + Self::new() + } +} + +impl ApproxPercentileCont { + /// Create a new [`ApproxPercentileCont`] aggregate function. + pub fn new() -> Self { + let mut variants = Vec::with_capacity(NUMERICS.len() * (INTEGERS.len() + 1)); + // Accept any numeric value paired with a float64 percentile + for num in NUMERICS { + variants.push(TypeSignature::Exact(vec![num.clone(), DataType::Float64])); + // Additionally accept an integer number of centroids for T-Digest + for int in INTEGERS { + variants.push(TypeSignature::Exact(vec![ + num.clone(), + DataType::Float64, + int.clone(), + ])) + } + } + Self { + signature: Signature::one_of(variants, Volatility::Immutable), + } + } + + pub(crate) fn create_accumulator( + &self, + args: AccumulatorArgs, + ) -> datafusion_common::Result { + let percentile = validate_input_percentile_expr(&args.input_exprs[1])?; + let tdigest_max_size = if args.input_exprs.len() == 3 { + Some(validate_input_max_size_expr(&args.input_exprs[2])?) 
+ } else { + None + }; + + let accumulator: ApproxPercentileAccumulator = match args.input_type { + t @ (DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 + | DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::Float32 + | DataType::Float64) => { + if let Some(max_size) = tdigest_max_size { + ApproxPercentileAccumulator::new_with_max_size(percentile, t.clone(), max_size) + } else { + ApproxPercentileAccumulator::new(percentile, t.clone()) + } + } + other => { + return not_impl_err!( + "Support for 'APPROX_PERCENTILE_CONT' for data type {other} is not implemented" + ) + } + }; + + Ok(accumulator) + } +} + +fn get_lit_value(expr: &Expr) -> datafusion_common::Result<ScalarValue> { + let empty_schema = Arc::new(Schema::empty()); + let empty_batch = RecordBatch::new_empty(Arc::clone(&empty_schema)); + let expr = limited_convert_logical_expr_to_physical_expr(expr, &empty_schema)?; + let result = expr.evaluate(&empty_batch)?; + match result { + ColumnarValue::Array(_) => Err(DataFusionError::Internal(format!( + "The expr {:?} can't be evaluated to scalar value", + expr + ))), + ColumnarValue::Scalar(scalar_value) => Ok(scalar_value), + } +} + +fn validate_input_percentile_expr(expr: &Expr) -> datafusion_common::Result<f64> { + let lit = get_lit_value(expr)?; + let percentile = match &lit { + ScalarValue::Float32(Some(q)) => *q as f64, + ScalarValue::Float64(Some(q)) => *q, + got => return not_impl_err!( + "Percentile value for 'APPROX_PERCENTILE_CONT' must be Float32 or Float64 literal (got data type {})", + got.data_type() + ) + }; + + // Ensure the percentile is between 0 and 1. + if !(0.0..=1.0).contains(&percentile) { + return plan_err!( + "Percentile value must be between 0.0 and 1.0 inclusive, {percentile} is invalid" + ); + } + Ok(percentile) +} + +fn validate_input_max_size_expr(expr: &Expr) -> datafusion_common::Result<usize> { + let lit = get_lit_value(expr)?; + let max_size = match &lit { + ScalarValue::UInt8(Some(q)) => *q as usize, + ScalarValue::UInt16(Some(q)) => *q as usize, + ScalarValue::UInt32(Some(q)) => *q as usize, + ScalarValue::UInt64(Some(q)) => *q as usize, + ScalarValue::Int32(Some(q)) if *q > 0 => *q as usize, + ScalarValue::Int64(Some(q)) if *q > 0 => *q as usize, + ScalarValue::Int16(Some(q)) if *q > 0 => *q as usize, + ScalarValue::Int8(Some(q)) if *q > 0 => *q as usize, + got => return not_impl_err!( + "Tdigest max_size value for 'APPROX_PERCENTILE_CONT' must be UInt > 0 literal (got data type {}).", + got.data_type() + ) + }; + Ok(max_size) +} + +impl AggregateUDFImpl for ApproxPercentileCont { + fn as_any(&self) -> &dyn Any { + self + } + + #[allow(rustdoc::private_intra_doc_links)] + /// See [`datafusion_physical_expr_common::aggregate::tdigest::TDigest::to_scalar_state()`] for a description of the serialised + /// state.
+ fn state_fields( + &self, + args: StateFieldsArgs, + ) -> datafusion_common::Result> { + Ok(vec![ + Field::new( + format_state_name(args.name, "max_size"), + DataType::UInt64, + false, + ), + Field::new( + format_state_name(args.name, "sum"), + DataType::Float64, + false, + ), + Field::new( + format_state_name(args.name, "count"), + DataType::Float64, + false, + ), + Field::new( + format_state_name(args.name, "max"), + DataType::Float64, + false, + ), + Field::new( + format_state_name(args.name, "min"), + DataType::Float64, + false, + ), + Field::new_list( + format_state_name(args.name, "centroids"), + Field::new("item", DataType::Float64, true), + false, + ), + ]) + } + + fn name(&self) -> &str { + "approx_percentile_cont" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + #[inline] + fn accumulator( + &self, + acc_args: AccumulatorArgs, + ) -> datafusion_common::Result> { + Ok(Box::new(self.create_accumulator(acc_args)?)) + } + + fn return_type(&self, arg_types: &[DataType]) -> datafusion_common::Result { + if !arg_types[0].is_numeric() { + return plan_err!("approx_percentile_cont requires numeric input types"); + } + if arg_types.len() == 3 && !arg_types[2].is_integer() { + return plan_err!( + "approx_percentile_cont requires integer max_size input types" + ); + } + Ok(arg_types[0].clone()) + } +} #[derive(Debug)] pub struct ApproxPercentileAccumulator { diff --git a/datafusion/physical-expr/src/aggregate/approx_percentile_cont_with_weight.rs b/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs similarity index 51% rename from datafusion/physical-expr/src/aggregate/approx_percentile_cont_with_weight.rs rename to datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs index 07c2aff3437f..a64218c606c4 100644 --- a/datafusion/physical-expr/src/aggregate/approx_percentile_cont_with_weight.rs +++ b/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs @@ -15,105 +15,140 @@ // specific language governing permissions and limitations // under the License. 
-use crate::expressions::ApproxPercentileCont; -use crate::{AggregateExpr, PhysicalExpr}; +use std::any::Any; +use std::fmt::{Debug, Formatter}; + use arrow::{ array::ArrayRef, datatypes::{DataType, Field}, }; -use datafusion_functions_aggregate::approx_percentile_cont::ApproxPercentileAccumulator; + +use datafusion_common::ScalarValue; +use datafusion_common::{not_impl_err, plan_err, Result}; +use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; +use datafusion_expr::type_coercion::aggregates::NUMERICS; +use datafusion_expr::Volatility::Immutable; +use datafusion_expr::{Accumulator, AggregateUDFImpl, Signature, TypeSignature}; use datafusion_physical_expr_common::aggregate::tdigest::{ Centroid, TDigest, DEFAULT_MAX_SIZE, }; -use datafusion_common::Result; -use datafusion_common::ScalarValue; -use datafusion_expr::Accumulator; +use crate::approx_percentile_cont::{ApproxPercentileAccumulator, ApproxPercentileCont}; -use crate::aggregate::utils::down_cast_any_ref; -use std::{any::Any, sync::Arc}; +make_udaf_expr_and_func!( ApproxPercentileContWithWeight, approx_percentile_cont_with_weight, expression weight percentile, "Computes the approximate percentile continuous with weight of a set of numbers", approx_percentile_cont_with_weight_udaf ); /// APPROX_PERCENTILE_CONT_WITH_WEIGHT aggregate expression -#[derive(Debug)] pub struct ApproxPercentileContWithWeight { + signature: Signature, approx_percentile_cont: ApproxPercentileCont, - column_expr: Arc<dyn PhysicalExpr>, - weight_expr: Arc<dyn PhysicalExpr>, - percentile_expr: Arc<dyn PhysicalExpr>, +} + +impl Debug for ApproxPercentileContWithWeight { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ApproxPercentileContWithWeight") + .field("signature", &self.signature) + .finish() + } +} + +impl Default for ApproxPercentileContWithWeight { + fn default() -> Self { + Self::new() + } } impl ApproxPercentileContWithWeight { /// Create a new [`ApproxPercentileContWithWeight`] aggregate function. - pub fn new( - expr: Vec<Arc<dyn PhysicalExpr>>, - name: impl Into<String>, - return_type: DataType, - ) -> Result<Self> { - // Arguments should be [ColumnExpr, WeightExpr, DesiredPercentileLiteral] - debug_assert_eq!(expr.len(), 3); - - let sub_expr = vec![expr[0].clone(), expr[2].clone()]; - let approx_percentile_cont = - ApproxPercentileCont::new(sub_expr, name, return_type)?; - - Ok(Self { - approx_percentile_cont, - column_expr: expr[0].clone(), - weight_expr: expr[1].clone(), - percentile_expr: expr[2].clone(), - }) + pub fn new() -> Self { + Self { + signature: Signature::one_of( + // Accept any numeric value paired with a float64 percentile + NUMERICS + .iter() + .map(|t| { + TypeSignature::Exact(vec![ + t.clone(), + t.clone(), + DataType::Float64, + ]) + }) + .collect(), + Immutable, + ), + approx_percentile_cont: ApproxPercentileCont::new(), + } } } -impl AggregateExpr for ApproxPercentileContWithWeight { +impl AggregateUDFImpl for ApproxPercentileContWithWeight { fn as_any(&self) -> &dyn Any { self } - fn field(&self) -> Result<Field> { - self.approx_percentile_cont.field() + fn name(&self) -> &str { + "approx_percentile_cont_with_weight" } - #[allow(rustdoc::private_intra_doc_links)] - /// See [`TDigest::to_scalar_state()`] for a description of the serialised - /// state.
- fn state_fields(&self) -> Result> { - self.approx_percentile_cont.state_fields() + fn signature(&self) -> &Signature { + &self.signature } - fn expressions(&self) -> Vec> { - vec![ - self.column_expr.clone(), - self.weight_expr.clone(), - self.percentile_expr.clone(), - ] + fn return_type(&self, arg_types: &[DataType]) -> Result { + if !arg_types[0].is_numeric() { + return plan_err!( + "approx_percentile_cont_with_weight requires numeric input types" + ); + } + if !arg_types[1].is_numeric() { + return plan_err!( + "approx_percentile_cont_with_weight requires numeric weight input types" + ); + } + if arg_types[2] != DataType::Float64 { + return plan_err!("approx_percentile_cont_with_weight requires float64 percentile input types"); + } + Ok(arg_types[0].clone()) } - fn create_accumulator(&self) -> Result> { + fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { + if acc_args.is_distinct { + return not_impl_err!( + "approx_percentile_cont_with_weight(DISTINCT) aggregations are not available" + ); + } + + if acc_args.input_exprs.len() != 3 { + return plan_err!( + "approx_percentile_cont_with_weight requires three arguments: value, weight, percentile" + ); + } + + let sub_args = AccumulatorArgs { + input_exprs: &[ + acc_args.input_exprs[0].clone(), + acc_args.input_exprs[2].clone(), + ], + ..acc_args + }; let approx_percentile_cont_accumulator = - self.approx_percentile_cont.create_plain_accumulator()?; + self.approx_percentile_cont.create_accumulator(sub_args)?; let accumulator = ApproxPercentileWithWeightAccumulator::new( approx_percentile_cont_accumulator, ); Ok(Box::new(accumulator)) } - fn name(&self) -> &str { - self.approx_percentile_cont.name() - } -} - -impl PartialEq for ApproxPercentileContWithWeight { - fn eq(&self, other: &dyn Any) -> bool { - down_cast_any_ref(other) - .downcast_ref::() - .map(|x| { - self.approx_percentile_cont == x.approx_percentile_cont - && self.column_expr.eq(&x.column_expr) - && self.weight_expr.eq(&x.weight_expr) - && self.percentile_expr.eq(&x.percentile_expr) - }) - .unwrap_or(false) + #[allow(rustdoc::private_intra_doc_links)] + /// See [`TDigest::to_scalar_state()`] for a description of the serialised + /// state. 
+ fn state_fields(&self, args: StateFieldsArgs) -> Result> { + self.approx_percentile_cont.state_fields(args) } } diff --git a/datafusion/functions-aggregate/src/count.rs b/datafusion/functions-aggregate/src/count.rs index cfd56619537b..062e148975bf 100644 --- a/datafusion/functions-aggregate/src/count.rs +++ b/datafusion/functions-aggregate/src/count.rs @@ -258,7 +258,7 @@ impl AggregateUDFImpl for Count { if args.is_distinct { return false; } - args.args_num == 1 + args.input_exprs.len() == 1 } fn create_groups_accumulator( diff --git a/datafusion/functions-aggregate/src/lib.rs b/datafusion/functions-aggregate/src/lib.rs index 56fc1305bb59..daddb9d93f78 100644 --- a/datafusion/functions-aggregate/src/lib.rs +++ b/datafusion/functions-aggregate/src/lib.rs @@ -61,13 +61,17 @@ pub mod covariance; pub mod first_last; pub mod hyperloglog; pub mod median; +pub mod regr; pub mod stddev; pub mod sum; pub mod variance; pub mod approx_median; pub mod approx_percentile_cont; +pub mod approx_percentile_cont_with_weight; +use crate::approx_percentile_cont::approx_percentile_cont_udaf; +use crate::approx_percentile_cont_with_weight::approx_percentile_cont_with_weight_udaf; use datafusion_common::Result; use datafusion_execution::FunctionRegistry; use datafusion_expr::AggregateUDF; @@ -78,6 +82,8 @@ use std::sync::Arc; pub mod expr_fn { pub use super::approx_distinct; pub use super::approx_median::approx_median; + pub use super::approx_percentile_cont::approx_percentile_cont; + pub use super::approx_percentile_cont_with_weight::approx_percentile_cont_with_weight; pub use super::count::count; pub use super::count::count_distinct; pub use super::covariance::covar_pop; @@ -85,6 +91,15 @@ pub mod expr_fn { pub use super::first_last::first_value; pub use super::first_last::last_value; pub use super::median::median; + pub use super::regr::regr_avgx; + pub use super::regr::regr_avgy; + pub use super::regr::regr_count; + pub use super::regr::regr_intercept; + pub use super::regr::regr_r2; + pub use super::regr::regr_slope; + pub use super::regr::regr_sxx; + pub use super::regr::regr_sxy; + pub use super::regr::regr_syy; pub use super::stddev::stddev; pub use super::stddev::stddev_pop; pub use super::sum::sum; @@ -102,12 +117,23 @@ pub fn all_default_aggregate_functions() -> Vec> { covariance::covar_pop_udaf(), median::median_udaf(), count::count_udaf(), + regr::regr_slope_udaf(), + regr::regr_intercept_udaf(), + regr::regr_count_udaf(), + regr::regr_r2_udaf(), + regr::regr_avgx_udaf(), + regr::regr_avgy_udaf(), + regr::regr_sxx_udaf(), + regr::regr_syy_udaf(), + regr::regr_sxy_udaf(), variance::var_samp_udaf(), variance::var_pop_udaf(), stddev::stddev_udaf(), stddev::stddev_pop_udaf(), approx_median::approx_median_udaf(), approx_distinct::approx_distinct_udaf(), + approx_percentile_cont_udaf(), + approx_percentile_cont_with_weight_udaf(), ] } diff --git a/datafusion/functions-aggregate/src/macros.rs b/datafusion/functions-aggregate/src/macros.rs index 75bb9dc54719..cae72cf35223 100644 --- a/datafusion/functions-aggregate/src/macros.rs +++ b/datafusion/functions-aggregate/src/macros.rs @@ -32,8 +32,8 @@ // specific language governing permissions and limitations // under the License. -macro_rules! make_udaf_expr_and_func { - ($UDAF:ty, $EXPR_FN:ident, $($arg:ident)*, $DOC:expr, $AGGREGATE_UDF_FN:ident) => { +macro_rules! make_udaf_expr { + ($EXPR_FN:ident, $($arg:ident)*, $DOC:expr, $AGGREGATE_UDF_FN:ident) => { // "fluent expr_fn" style function #[doc = $DOC] pub fn $EXPR_FN( @@ -48,7 +48,12 @@ macro_rules! 
make_udaf_expr_and_func { None, )) } + }; +} +macro_rules! make_udaf_expr_and_func { + ($UDAF:ty, $EXPR_FN:ident, $($arg:ident)*, $DOC:expr, $AGGREGATE_UDF_FN:ident) => { + make_udaf_expr!($EXPR_FN, $($arg)*, $DOC, $AGGREGATE_UDF_FN); create_func!($UDAF, $AGGREGATE_UDF_FN); }; ($UDAF:ty, $EXPR_FN:ident, $DOC:expr, $AGGREGATE_UDF_FN:ident) => { @@ -73,6 +78,9 @@ macro_rules! make_udaf_expr_and_func { macro_rules! create_func { ($UDAF:ty, $AGGREGATE_UDF_FN:ident) => { + create_func!($UDAF, $AGGREGATE_UDF_FN, <$UDAF>::default()); + }; + ($UDAF:ty, $AGGREGATE_UDF_FN:ident, $CREATE:expr) => { paste::paste! { /// Singleton instance of [$UDAF], ensures the UDAF is only created once /// named STATIC_$(UDAF). For example `STATIC_FirstValue` @@ -86,7 +94,7 @@ macro_rules! create_func { pub fn $AGGREGATE_UDF_FN() -> std::sync::Arc { [< STATIC_ $UDAF >] .get_or_init(|| { - std::sync::Arc::new(datafusion_expr::AggregateUDF::from(<$UDAF>::default())) + std::sync::Arc::new(datafusion_expr::AggregateUDF::from($CREATE)) }) .clone() } diff --git a/datafusion/physical-expr/src/aggregate/regr.rs b/datafusion/functions-aggregate/src/regr.rs similarity index 84% rename from datafusion/physical-expr/src/aggregate/regr.rs rename to datafusion/functions-aggregate/src/regr.rs index 36e7b7c9b3e4..8d04ae87157d 100644 --- a/datafusion/physical-expr/src/aggregate/regr.rs +++ b/datafusion/functions-aggregate/src/regr.rs @@ -18,9 +18,8 @@ //! Defines physical expressions that can evaluated at runtime during query execution use std::any::Any; -use std::sync::Arc; +use std::fmt::Debug; -use crate::{AggregateExpr, PhysicalExpr}; use arrow::array::Float64Array; use arrow::{ array::{ArrayRef, UInt64Array}, @@ -28,13 +27,56 @@ use arrow::{ datatypes::DataType, datatypes::Field, }; -use datafusion_common::{downcast_value, unwrap_or_internal_err, ScalarValue}; +use datafusion_common::{downcast_value, plan_err, unwrap_or_internal_err, ScalarValue}; use datafusion_common::{DataFusionError, Result}; -use datafusion_expr::Accumulator; +use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; +use datafusion_expr::type_coercion::aggregates::NUMERICS; +use datafusion_expr::utils::format_state_name; +use datafusion_expr::{Accumulator, AggregateUDFImpl, Signature, Volatility}; + +macro_rules! 
make_regr_udaf_expr_and_func { + ($EXPR_FN:ident, $AGGREGATE_UDF_FN:ident, $REGR_TYPE:expr) => { + make_udaf_expr!($EXPR_FN, expr_y expr_x, concat!("Compute a linear regression of type [", stringify!($REGR_TYPE), "]"), $AGGREGATE_UDF_FN); + create_func!($EXPR_FN, $AGGREGATE_UDF_FN, Regr::new($REGR_TYPE, stringify!($EXPR_FN))); + } +} + +make_regr_udaf_expr_and_func!(regr_slope, regr_slope_udaf, RegrType::Slope); +make_regr_udaf_expr_and_func!(regr_intercept, regr_intercept_udaf, RegrType::Intercept); +make_regr_udaf_expr_and_func!(regr_count, regr_count_udaf, RegrType::Count); +make_regr_udaf_expr_and_func!(regr_r2, regr_r2_udaf, RegrType::R2); +make_regr_udaf_expr_and_func!(regr_avgx, regr_avgx_udaf, RegrType::AvgX); +make_regr_udaf_expr_and_func!(regr_avgy, regr_avgy_udaf, RegrType::AvgY); +make_regr_udaf_expr_and_func!(regr_sxx, regr_sxx_udaf, RegrType::SXX); +make_regr_udaf_expr_and_func!(regr_syy, regr_syy_udaf, RegrType::SYY); +make_regr_udaf_expr_and_func!(regr_sxy, regr_sxy_udaf, RegrType::SXY); + +pub struct Regr { + signature: Signature, + regr_type: RegrType, + func_name: &'static str, +} -use crate::aggregate::utils::down_cast_any_ref; -use crate::expressions::format_state_name; +impl Debug for Regr { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + f.debug_struct("regr") + .field("name", &self.name()) + .field("signature", &self.signature) + .finish() + } +} +impl Regr { + pub fn new(regr_type: RegrType, func_name: &'static str) -> Self { + Self { + signature: Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable), + regr_type, + func_name, + } + } +} + +/* #[derive(Debug)] pub struct Regr { name: String, @@ -48,6 +90,7 @@ impl Regr { self.regr_type.clone() } } +*/ #[derive(Debug, Clone)] #[allow(clippy::upper_case_acronyms)] @@ -92,86 +135,75 @@ pub enum RegrType { SXY, } -impl Regr { - pub fn new( - expr_y: Arc, - expr_x: Arc, - name: impl Into, - regr_type: RegrType, - return_type: DataType, - ) -> Self { - // the result of regr_slope only support FLOAT64 data type. 
- assert!(matches!(return_type, DataType::Float64)); - Self { - name: name.into(), - regr_type, - expr_y, - expr_x, - } - } -} - -impl AggregateExpr for Regr { +impl AggregateUDFImpl for Regr { fn as_any(&self) -> &dyn Any { self } - fn field(&self) -> Result<Field> { - Ok(Field::new(&self.name, DataType::Float64, true)) + fn name(&self) -> &str { + self.func_name + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> { + if !arg_types[0].is_numeric() { + return plan_err!("{} requires numeric input types", self.func_name); + } + + Ok(DataType::Float64) } - fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> { + fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> { Ok(Box::new(RegrAccumulator::try_new(&self.regr_type)?)) } - fn create_sliding_accumulator(&self) -> Result<Box<dyn Accumulator>> { + fn create_sliding_accumulator( + &self, + _args: AccumulatorArgs, + ) -> Result<Box<dyn Accumulator>> { Ok(Box::new(RegrAccumulator::try_new(&self.regr_type)?)) } - fn state_fields(&self) -> Result<Vec<Field>> { + fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> { Ok(vec![ Field::new( - format_state_name(&self.name, "count"), + format_state_name(args.name, "count"), DataType::UInt64, true, ), Field::new( - format_state_name(&self.name, "mean_x"), + format_state_name(args.name, "mean_x"), DataType::Float64, true, ), Field::new( - format_state_name(&self.name, "mean_y"), + format_state_name(args.name, "mean_y"), DataType::Float64, true, ), Field::new( - format_state_name(&self.name, "m2_x"), + format_state_name(args.name, "m2_x"), DataType::Float64, true, ), Field::new( - format_state_name(&self.name, "m2_y"), + format_state_name(args.name, "m2_y"), DataType::Float64, true, ), Field::new( - format_state_name(&self.name, "algo_const"), + format_state_name(args.name, "algo_const"), DataType::Float64, true, ), ]) } - - fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> { - vec![self.expr_y.clone(), self.expr_x.clone()] - } - - fn name(&self) -> &str { - &self.name - } } +/* impl PartialEq<dyn Any> for Regr { fn eq(&self, other: &dyn Any) -> bool { down_cast_any_ref(other) @@ -184,6 +216,7 @@ impl PartialEq<dyn Any> for Regr { .unwrap_or(false) } } +*/ /// `RegrAccumulator` is used to compute linear regression aggregate functions /// by maintaining statistics needed to compute them in an online fashion.
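All nine `regr_*` aggregates share this single `Regr` implementation, parameterized by `RegrType`; the `make_regr_udaf_expr_and_func!` calls above therefore expand to roughly the following (a sketch of the generated singleton, not the literal macro output):

use std::sync::Arc;
use datafusion_expr::AggregateUDF;

// What create_func! generates for regr_slope, approximately: one shared
// AggregateUDF wrapping Regr with the matching RegrType and display name.
let regr_slope_udaf: Arc<AggregateUDF> =
    Arc::new(AggregateUDF::from(Regr::new(RegrType::Slope, "regr_slope")));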
@@ -305,6 +338,10 @@ impl Accumulator for RegrAccumulator { Ok(()) } + fn supports_retract_batch(&self) -> bool { + true + } + fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { let values_y = &cast(&values[0], &DataType::Float64)?; let values_x = &cast(&values[1], &DataType::Float64)?; diff --git a/datafusion/functions-aggregate/src/stddev.rs b/datafusion/functions-aggregate/src/stddev.rs index 4c3effe7650a..42cf44f65d8f 100644 --- a/datafusion/functions-aggregate/src/stddev.rs +++ b/datafusion/functions-aggregate/src/stddev.rs @@ -332,7 +332,7 @@ mod tests { name: "a", is_distinct: false, input_type: &DataType::Float64, - args_num: 1, + input_exprs: &[datafusion_expr::col("a")], }; let args2 = AccumulatorArgs { @@ -343,7 +343,7 @@ mod tests { name: "a", is_distinct: false, input_type: &DataType::Float64, - args_num: 1, + input_exprs: &[datafusion_expr::col("a")], }; let mut accum1 = agg1.accumulator(args1)?; diff --git a/datafusion/functions/src/string/contains.rs b/datafusion/functions/src/string/contains.rs new file mode 100644 index 000000000000..faf979f80614 --- /dev/null +++ b/datafusion/functions/src/string/contains.rs @@ -0,0 +1,143 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use crate::utils::make_scalar_function; +use arrow::array::{ArrayRef, OffsetSizeTrait}; +use arrow::datatypes::DataType; +use arrow::datatypes::DataType::Boolean; +use datafusion_common::cast::as_generic_string_array; +use datafusion_common::DataFusionError; +use datafusion_common::Result; +use datafusion_common::{arrow_datafusion_err, exec_err}; +use datafusion_expr::ScalarUDFImpl; +use datafusion_expr::TypeSignature::Exact; +use datafusion_expr::{ColumnarValue, Signature, Volatility}; +use std::any::Any; +use std::sync::Arc; +#[derive(Debug)] +pub struct ContainsFunc { + signature: Signature, +} + +impl Default for ContainsFunc { + fn default() -> Self { + ContainsFunc::new() + } +} + +impl ContainsFunc { + pub fn new() -> Self { + use DataType::*; + Self { + signature: Signature::one_of( + vec![Exact(vec![Utf8, Utf8]), Exact(vec![LargeUtf8, LargeUtf8])], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for ContainsFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "contains" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _: &[DataType]) -> Result<DataType> { + Ok(Boolean) + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> { + match args[0].data_type() { + DataType::Utf8 => make_scalar_function(contains::<i32>, vec![])(args), + DataType::LargeUtf8 => make_scalar_function(contains::<i64>, vec![])(args), + other => { + exec_err!("unsupported data type {other:?} for function contains") + } + } + } +} + +/// use regexp_is_match_utf8 to do the calculation for contains +pub fn contains<T: OffsetSizeTrait>( + args: &[ArrayRef], +) -> Result<ArrayRef, DataFusionError> { + let mod_str = as_generic_string_array::<T>(&args[0])?; + let match_str = as_generic_string_array::<T>(&args[1])?; + let res = arrow::compute::kernels::comparison::regexp_is_match_utf8( + mod_str, match_str, None, + ) + .map_err(|e| arrow_datafusion_err!(e))?; + + Ok(Arc::new(res) as ArrayRef) +} + +#[cfg(test)] +mod tests { + use crate::string::contains::ContainsFunc; + use crate::utils::test::test_function; + use arrow::array::Array; + use arrow::{array::BooleanArray, datatypes::DataType::Boolean}; + use datafusion_common::Result; + use datafusion_common::ScalarValue; + use datafusion_expr::ColumnarValue; + use datafusion_expr::ScalarUDFImpl; + #[test] + fn test_functions() -> Result<()> { + test_function!( + ContainsFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("alphabet")), + ColumnarValue::Scalar(ScalarValue::from("alph")), + ], + Ok(Some(true)), + bool, + Boolean, + BooleanArray + ); + test_function!( + ContainsFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("alphabet")), + ColumnarValue::Scalar(ScalarValue::from("dddddd")), + ], + Ok(Some(false)), + bool, + Boolean, + BooleanArray + ); + test_function!( + ContainsFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("alphabet")), + ColumnarValue::Scalar(ScalarValue::from("pha")), + ], + Ok(Some(true)), + bool, + Boolean, + BooleanArray + ); + Ok(()) + } +} diff --git a/datafusion/functions/src/string/mod.rs b/datafusion/functions/src/string/mod.rs index 219ef8b5a50f..5bf372c29f2d 100644 --- a/datafusion/functions/src/string/mod.rs +++ b/datafusion/functions/src/string/mod.rs @@ -28,6 +28,7 @@ pub mod chr; pub mod common; pub mod concat; pub mod concat_ws; +pub mod contains; pub mod ends_with; pub mod initcap; pub mod levenshtein; @@ -43,7 +44,6 @@ pub mod starts_with; pub mod to_hex; pub mod upper; pub mod uuid; - // create UDFs make_udf_function!(ascii::AsciiFunc, ASCII, ascii);
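// Note: the `contains` UDF registered below delegates to Arrow's
// `regexp_is_match_utf8` kernel, so the search string is interpreted as a
// regular expression rather than a plain substring. A hedged usage sketch
// (table and column names here are illustrative only):
//
//   SELECT contains('alphabet', 'alph');   -- true
//   SELECT contains(name, 'ph.a') FROM t;  -- '.' matches any character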
make_udf_function!(bit_length::BitLengthFunc, BIT_LENGTH, bit_length); @@ -66,7 +66,7 @@ make_udf_function!(split_part::SplitPartFunc, SPLIT_PART, split_part); make_udf_function!(to_hex::ToHexFunc, TO_HEX, to_hex); make_udf_function!(upper::UpperFunc, UPPER, upper); make_udf_function!(uuid::UuidFunc, UUID, uuid); - +make_udf_function!(contains::ContainsFunc, CONTAINS, contains); pub mod expr_fn { use datafusion_expr::Expr; @@ -149,6 +149,9 @@ pub mod expr_fn { ),( uuid, "returns uuid v4 as a string value", + ), ( + contains, + "Return true if search_string is found within string; the search string is evaluated as a regular expression", )); #[doc = "Removes all characters, spaces by default, from both sides of a string"] @@ -188,5 +191,6 @@ pub fn functions() -> Vec<Arc<ScalarUDF>> { to_hex(), upper(), uuid(), + contains(), ] } diff --git a/datafusion/optimizer/Cargo.toml b/datafusion/optimizer/Cargo.toml index cb14f6bdd4a3..1a9e9630c076 100644 --- a/datafusion/optimizer/Cargo.toml +++ b/datafusion/optimizer/Cargo.toml @@ -56,5 +56,6 @@ regex-syntax = "0.8.0" [dev-dependencies] arrow-buffer = { workspace = true } ctor = { workspace = true } +datafusion-functions-aggregate = { workspace = true } datafusion-sql = { workspace = true } env_logger = { workspace = true } diff --git a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs index af1c99c52390..de2af520053a 100644 --- a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs +++ b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs @@ -25,9 +25,7 @@ use datafusion_expr::expr::{ AggregateFunction, AggregateFunctionDefinition, WindowFunction, }; use datafusion_expr::utils::COUNT_STAR_EXPANSION; -use datafusion_expr::{ - aggregate_function, lit, Expr, LogicalPlan, WindowFunctionDefinition, -}; +use datafusion_expr::{lit, Expr, LogicalPlan, WindowFunctionDefinition}; /// Rewrite `Count(Expr:Wildcard)` to `Count(Expr:Literal)`. /// @@ -56,37 +54,19 @@ fn is_wildcard(expr: &Expr) -> bool { } fn is_count_star_aggregate(aggregate_function: &AggregateFunction) -> bool { - match aggregate_function { + matches!(aggregate_function, AggregateFunction { func_def: AggregateFunctionDefinition::UDF(udf), args, .. - } if udf.name() == "COUNT" && args.len() == 1 && is_wildcard(&args[0]) => true, - AggregateFunction { - func_def: - AggregateFunctionDefinition::BuiltIn( - datafusion_expr::aggregate_function::AggregateFunction::Count, - ), - args, - ..
- } if args.len() == 1 && is_wildcard(&args[0]) => true, - _ => false, - } + } if udf.name() == "COUNT" && args.len() == 1 && is_wildcard(&args[0])) } fn is_count_star_window_aggregate(window_function: &WindowFunction) -> bool { let args = &window_function.args; - match window_function.fun { - WindowFunctionDefinition::AggregateFunction( - aggregate_function::AggregateFunction::Count, - ) if args.len() == 1 && is_wildcard(&args[0]) => true, + matches!(window_function.fun, WindowFunctionDefinition::AggregateUDF(ref udaf) - if udaf.name() == "COUNT" && args.len() == 1 && is_wildcard(&args[0]) => - { - true - } - _ => false, - } + if udaf.name() == "COUNT" && args.len() == 1 && is_wildcard(&args[0])) } fn analyze_internal(plan: LogicalPlan) -> Result> { @@ -121,14 +101,16 @@ mod tests { use arrow::datatypes::DataType; use datafusion_common::ScalarValue; use datafusion_expr::expr::Sort; - use datafusion_expr::test::function_stub::sum; use datafusion_expr::{ - col, count, exists, expr, in_subquery, logical_plan::LogicalPlanBuilder, max, - out_ref_col, scalar_subquery, wildcard, AggregateFunction, WindowFrame, - WindowFrameBound, WindowFrameUnits, + col, exists, expr, in_subquery, logical_plan::LogicalPlanBuilder, max, + out_ref_col, scalar_subquery, wildcard, WindowFrame, WindowFrameBound, + WindowFrameUnits, }; + use datafusion_functions_aggregate::count::count_udaf; use std::sync::Arc; + use datafusion_functions_aggregate::expr_fn::{count, sum}; + fn assert_plan_eq(plan: LogicalPlan, expected: &str) -> Result<()> { assert_analyzed_plan_eq_display_indent( Arc::new(CountWildcardRule::new()), @@ -239,7 +221,7 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan) .window(vec![Expr::WindowFunction(expr::WindowFunction::new( - WindowFunctionDefinition::AggregateFunction(AggregateFunction::Count), + WindowFunctionDefinition::AggregateUDF(count_udaf()), vec![wildcard()], vec![], vec![Expr::Sort(Sort::new(Box::new(col("a")), false, true))], diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index 0c8e4ae34a90..acc21f14f44d 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -1055,31 +1055,6 @@ mod test { Ok(()) } - #[test] - fn agg_function_invalid_input_percentile() { - let empty = empty(); - let fun: AggregateFunction = AggregateFunction::ApproxPercentileCont; - let agg_expr = Expr::AggregateFunction(expr::AggregateFunction::new( - fun, - vec![lit(0.95), lit(42.0), lit(100.0)], - false, - None, - None, - None, - )); - - let err = Projection::try_new(vec![agg_expr], empty) - .err() - .unwrap() - .strip_backtrace(); - - let prefix = "Error during planning: No function matches the given name and argument types 'APPROX_PERCENTILE_CONT(Float64, Float64, Float64)'. You might need to add explicit type casts.\n\tCandidate functions:"; - assert!(!err - .strip_prefix(prefix) - .unwrap() - .contains("APPROX_PERCENTILE_CONT(Float64, Float64, Float64)")); - } - #[test] fn binary_op_date32_op_interval() -> Result<()> { // CAST(Utf8("1998-03-18") AS Date32) + IntervalDayTime("...") diff --git a/datafusion/optimizer/src/decorrelate.rs b/datafusion/optimizer/src/decorrelate.rs index e14ee763a3c0..e949e1921b97 100644 --- a/datafusion/optimizer/src/decorrelate.rs +++ b/datafusion/optimizer/src/decorrelate.rs @@ -432,14 +432,8 @@ fn agg_exprs_evaluation_result_on_empty_batch( Expr::AggregateFunction(expr::AggregateFunction { func_def, .. 
}) => match func_def { - AggregateFunctionDefinition::BuiltIn(fun) => { - if matches!(fun, datafusion_expr::AggregateFunction::Count) { - Transformed::yes(Expr::Literal(ScalarValue::Int64(Some( - 0, - )))) - } else { - Transformed::yes(Expr::Literal(ScalarValue::Null)) - } + AggregateFunctionDefinition::BuiltIn(_fun) => { + Transformed::yes(Expr::Literal(ScalarValue::Null)) } AggregateFunctionDefinition::UDF(fun) => { if fun.name() == "COUNT" { diff --git a/datafusion/optimizer/src/eliminate_group_by_constant.rs b/datafusion/optimizer/src/eliminate_group_by_constant.rs index cef226d67b6c..7a8dd7aac249 100644 --- a/datafusion/optimizer/src/eliminate_group_by_constant.rs +++ b/datafusion/optimizer/src/eliminate_group_by_constant.rs @@ -129,10 +129,12 @@ mod tests { use datafusion_common::Result; use datafusion_expr::expr::ScalarFunction; use datafusion_expr::{ - col, count, lit, ColumnarValue, LogicalPlanBuilder, ScalarUDF, ScalarUDFImpl, - Signature, TypeSignature, + col, lit, ColumnarValue, LogicalPlanBuilder, ScalarUDF, ScalarUDFImpl, Signature, + TypeSignature, }; + use datafusion_functions_aggregate::expr_fn::count; + use std::sync::Arc; #[derive(Debug)] diff --git a/datafusion/optimizer/src/optimize_projections/mod.rs b/datafusion/optimizer/src/optimize_projections/mod.rs index af51814c9686..11540d3e162e 100644 --- a/datafusion/optimizer/src/optimize_projections/mod.rs +++ b/datafusion/optimizer/src/optimize_projections/mod.rs @@ -818,10 +818,11 @@ mod tests { use datafusion_common::{ Column, DFSchema, DFSchemaRef, JoinType, Result, TableReference, }; + use datafusion_expr::AggregateExt; use datafusion_expr::{ binary_expr, build_join_schema, builder::table_scan_with_filters, - col, count, + col, expr::{self, Cast}, lit, logical_plan::{builder::LogicalPlanBuilder, table_scan}, @@ -830,6 +831,9 @@ mod tests { WindowFunctionDefinition, }; + use datafusion_functions_aggregate::count::count_udaf; + use datafusion_functions_aggregate::expr_fn::count; + fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> { assert_optimized_plan_eq(Arc::new(OptimizeProjections::new()), plan, expected) } @@ -1886,16 +1890,10 @@ mod tests { #[test] fn aggregate_filter_pushdown() -> Result<()> { let table_scan = test_table_scan()?; - - let aggr_with_filter = Expr::AggregateFunction(expr::AggregateFunction::new( - AggregateFunction::Count, - vec![col("b")], - false, - Some(Box::new(col("c").gt(lit(42)))), - None, - None, - )); - + let aggr_with_filter = count_udaf() + .call(vec![col("b")]) + .filter(col("c").gt(lit(42))) + .build()?; let plan = LogicalPlanBuilder::from(table_scan) .aggregate( vec![col("a")], diff --git a/datafusion/optimizer/src/single_distinct_to_groupby.rs b/datafusion/optimizer/src/single_distinct_to_groupby.rs index e738209eb4fd..d3d22eb53f39 100644 --- a/datafusion/optimizer/src/single_distinct_to_groupby.rs +++ b/datafusion/optimizer/src/single_distinct_to_groupby.rs @@ -362,11 +362,13 @@ mod tests { use super::*; use crate::test::*; use datafusion_expr::expr::{self, GroupingSet}; - use datafusion_expr::test::function_stub::{sum, sum_udaf}; + use datafusion_expr::AggregateExt; use datafusion_expr::{ - count, count_distinct, lit, logical_plan::builder::LogicalPlanBuilder, max, min, - AggregateFunction, + lit, logical_plan::builder::LogicalPlanBuilder, max, min, AggregateFunction, }; + use datafusion_functions_aggregate::count::count_udaf; + use datafusion_functions_aggregate::expr_fn::{count, count_distinct, sum}; + use 
datafusion_functions_aggregate::sum::sum_udaf; fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> { assert_optimized_plan_eq_display_indent( @@ -679,14 +681,11 @@ mod tests { let table_scan = test_table_scan()?; // COUNT(DISTINCT a) FILTER (WHERE a > 5) - let expr = Expr::AggregateFunction(expr::AggregateFunction::new( - AggregateFunction::Count, - vec![col("a")], - true, - Some(Box::new(col("a").gt(lit(5)))), - None, - None, - )); + let expr = count_udaf() + .call(vec![col("a")]) + .distinct() + .filter(col("a").gt(lit(5))) + .build()?; let plan = LogicalPlanBuilder::from(table_scan) .aggregate(vec![col("c")], vec![sum(col("a")), expr])? .build()?; @@ -725,19 +724,16 @@ mod tests { let table_scan = test_table_scan()?; // COUNT(DISTINCT a ORDER BY a) - let expr = Expr::AggregateFunction(expr::AggregateFunction::new( - AggregateFunction::Count, - vec![col("a")], - true, - None, - Some(vec![col("a")]), - None, - )); + let expr = count_udaf() + .call(vec![col("a")]) + .distinct() + .order_by(vec![col("a").sort(true, false)]) + .build()?; let plan = LogicalPlanBuilder::from(table_scan) .aggregate(vec![col("c")], vec![sum(col("a")), expr])? .build()?; // Do nothing - let expected = "Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a), COUNT(DISTINCT test.a) ORDER BY [test.a]]] [c:UInt32, sum(test.a):UInt64;N, COUNT(DISTINCT test.a) ORDER BY [test.a]:Int64;N]\ + let expected = "Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a), COUNT(DISTINCT test.a) ORDER BY [test.a ASC NULLS LAST]]] [c:UInt32, sum(test.a):UInt64;N, COUNT(DISTINCT test.a) ORDER BY [test.a ASC NULLS LAST]:Int64;N]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; assert_optimized_plan_equal(plan, expected) @@ -748,19 +744,17 @@ mod tests { let table_scan = test_table_scan()?; // COUNT(DISTINCT a ORDER BY a) FILTER (WHERE a > 5) - let expr = Expr::AggregateFunction(expr::AggregateFunction::new( - AggregateFunction::Count, - vec![col("a")], - true, - Some(Box::new(col("a").gt(lit(5)))), - Some(vec![col("a")]), - None, - )); + let expr = count_udaf() + .call(vec![col("a")]) + .distinct() + .filter(col("a").gt(lit(5))) + .order_by(vec![col("a").sort(true, false)]) + .build()?; let plan = LogicalPlanBuilder::from(table_scan) .aggregate(vec![col("c")], vec![sum(col("a")), expr])? 
            .build()?;

        // Do nothing
-        let expected = "Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a), COUNT(DISTINCT test.a) FILTER (WHERE test.a > Int32(5)) ORDER BY [test.a]]] [c:UInt32, sum(test.a):UInt64;N, COUNT(DISTINCT test.a) FILTER (WHERE test.a > Int32(5)) ORDER BY [test.a]:Int64;N]\
+        let expected = "Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a), COUNT(DISTINCT test.a) FILTER (WHERE test.a > Int32(5)) ORDER BY [test.a ASC NULLS LAST]]] [c:UInt32, sum(test.a):UInt64;N, COUNT(DISTINCT test.a) FILTER (WHERE test.a > Int32(5)) ORDER BY [test.a ASC NULLS LAST]:Int64;N]\
         \n  TableScan: test [a:UInt32, b:UInt32, c:UInt32]";

         assert_optimized_plan_equal(plan, expected)
diff --git a/datafusion/optimizer/tests/optimizer_integration.rs b/datafusion/optimizer/tests/optimizer_integration.rs
index b3501cca9efa..f60bf6609005 100644
--- a/datafusion/optimizer/tests/optimizer_integration.rs
+++ b/datafusion/optimizer/tests/optimizer_integration.rs
@@ -25,6 +25,7 @@ use datafusion_common::config::ConfigOptions;
 use datafusion_common::{plan_err, Result};
 use datafusion_expr::test::function_stub::sum_udaf;
 use datafusion_expr::{AggregateUDF, LogicalPlan, ScalarUDF, TableSource, WindowUDF};
+use datafusion_functions_aggregate::count::count_udaf;
 use datafusion_optimizer::analyzer::Analyzer;
 use datafusion_optimizer::optimizer::Optimizer;
 use datafusion_optimizer::{OptimizerConfig, OptimizerContext, OptimizerRule};
@@ -323,7 +324,9 @@ fn test_sql(sql: &str) -> Result<LogicalPlan> {
     let dialect = GenericDialect {}; // or AnsiDialect, or your own dialect ...
     let ast: Vec<Statement> = Parser::parse_sql(&dialect, sql).unwrap();
     let statement = &ast[0];
-    let context_provider = MyContextProvider::default().with_udaf(sum_udaf());
+    let context_provider = MyContextProvider::default()
+        .with_udaf(sum_udaf())
+        .with_udaf(count_udaf());
     let sql_to_rel = SqlToRel::new(&context_provider);
     let plan = sql_to_rel.sql_statement_to_plan(statement.clone()).unwrap();
@@ -345,7 +348,8 @@ struct MyContextProvider {

 impl MyContextProvider {
     fn with_udaf(mut self, udaf: Arc<AggregateUDF>) -> Self {
-        self.udafs.insert(udaf.name().to_string(), udaf);
+        // TODO: change back to to_string() once all function names are converted to lowercase
+        self.udafs.insert(udaf.name().to_lowercase(), udaf);
         self
     }
 }
diff --git a/datafusion/physical-expr-common/src/aggregate/mod.rs b/datafusion/physical-expr-common/src/aggregate/mod.rs
index 21884f840dbd..432267e045b2 100644
--- a/datafusion/physical-expr-common/src/aggregate/mod.rs
+++ b/datafusion/physical-expr-common/src/aggregate/mod.rs
@@ -46,6 +46,7 @@ use datafusion_expr::utils::AggregateOrderSensitivity;
 pub fn create_aggregate_expr(
     fun: &AggregateUDF,
     input_phy_exprs: &[Arc<dyn PhysicalExpr>],
+    input_exprs: &[Expr],
     sort_exprs: &[Expr],
     ordering_req: &[PhysicalSortExpr],
     schema: &Schema,
@@ -76,6 +77,7 @@ pub fn create_aggregate_expr(
     Ok(Arc::new(AggregateFunctionExpr {
         fun: fun.clone(),
         args: input_phy_exprs.to_vec(),
+        logical_args: input_exprs.to_vec(),
         data_type: fun.return_type(&input_exprs_types)?,
         name: name.into(),
         schema: schema.clone(),
@@ -231,6 +233,7 @@ pub struct AggregatePhysicalExpressions {
 pub struct AggregateFunctionExpr {
     fun: AggregateUDF,
     args: Vec<Arc<dyn PhysicalExpr>>,
+    logical_args: Vec<Expr>,
     /// Output / return type of this aggregate
     data_type: DataType,
     name: String,
@@ -293,7 +296,7 @@ impl AggregateExpr for AggregateFunctionExpr {
             sort_exprs: &self.sort_exprs,
             is_distinct: self.is_distinct,
             input_type: &self.input_type,
-            args_num: self.args.len(),
+            input_exprs: &self.logical_args,
             name: &self.name,
         };

@@ -308,7 +311,7 @@ impl
AggregateExpr for AggregateFunctionExpr { sort_exprs: &self.sort_exprs, is_distinct: self.is_distinct, input_type: &self.input_type, - args_num: self.args.len(), + input_exprs: &self.logical_args, name: &self.name, }; @@ -378,7 +381,7 @@ impl AggregateExpr for AggregateFunctionExpr { sort_exprs: &self.sort_exprs, is_distinct: self.is_distinct, input_type: &self.input_type, - args_num: self.args.len(), + input_exprs: &self.logical_args, name: &self.name, }; self.fun.groups_accumulator_supported(args) @@ -392,7 +395,7 @@ impl AggregateExpr for AggregateFunctionExpr { sort_exprs: &self.sort_exprs, is_distinct: self.is_distinct, input_type: &self.input_type, - args_num: self.args.len(), + input_exprs: &self.logical_args, name: &self.name, }; self.fun.create_groups_accumulator(args) @@ -434,6 +437,7 @@ impl AggregateExpr for AggregateFunctionExpr { create_aggregate_expr( &updated_fn, &self.args, + &self.logical_args, &self.sort_exprs, &self.ordering_req, &self.schema, @@ -468,6 +472,7 @@ impl AggregateExpr for AggregateFunctionExpr { let reverse_aggr = create_aggregate_expr( &reverse_udf, &self.args, + &self.logical_args, &reverse_sort_exprs, &reverse_ordering_req, &self.schema, diff --git a/datafusion/physical-expr/src/expressions/literal.rs b/datafusion/physical-expr-common/src/expressions/literal.rs similarity index 98% rename from datafusion/physical-expr/src/expressions/literal.rs rename to datafusion/physical-expr-common/src/expressions/literal.rs index fcaf229af0a8..b3cff1ef69ba 100644 --- a/datafusion/physical-expr/src/expressions/literal.rs +++ b/datafusion/physical-expr-common/src/expressions/literal.rs @@ -21,8 +21,7 @@ use std::any::Any; use std::hash::{Hash, Hasher}; use std::sync::Arc; -use crate::physical_expr::down_cast_any_ref; -use crate::PhysicalExpr; +use crate::physical_expr::{down_cast_any_ref, PhysicalExpr}; use arrow::{ datatypes::{DataType, Schema}, diff --git a/datafusion/physical-expr-common/src/expressions/mod.rs b/datafusion/physical-expr-common/src/expressions/mod.rs index 4b5965e164b5..dd534cc07d20 100644 --- a/datafusion/physical-expr-common/src/expressions/mod.rs +++ b/datafusion/physical-expr-common/src/expressions/mod.rs @@ -17,5 +17,7 @@ mod cast; pub mod column; +pub mod literal; pub use cast::{cast, cast_with_options, CastExpr}; +pub use literal::{lit, Literal}; diff --git a/datafusion/physical-expr-common/src/utils.rs b/datafusion/physical-expr-common/src/utils.rs index f661400fcb10..d5cd3c6f4af0 100644 --- a/datafusion/physical-expr-common/src/utils.rs +++ b/datafusion/physical-expr-common/src/utils.rs @@ -17,18 +17,21 @@ use std::sync::Arc; -use crate::expressions::{self, CastExpr}; -use crate::physical_expr::PhysicalExpr; -use crate::sort_expr::PhysicalSortExpr; -use crate::tree_node::ExprContext; - use arrow::array::{make_array, Array, ArrayRef, BooleanArray, MutableArrayData}; use arrow::compute::{and_kleene, is_not_null, SlicesIterator}; use arrow::datatypes::Schema; + use datafusion_common::{exec_err, Result}; +use datafusion_expr::expr::Alias; use datafusion_expr::sort_properties::ExprProperties; use datafusion_expr::Expr; +use crate::expressions::literal::Literal; +use crate::expressions::{self, CastExpr}; +use crate::physical_expr::PhysicalExpr; +use crate::sort_expr::PhysicalSortExpr; +use crate::tree_node::ExprContext; + /// Represents a [`PhysicalExpr`] node with associated properties (order and /// range) in a context where properties are tracked. 
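+/// Each node stores the [`ExprProperties`] (sort ordering and value range) of
+/// its sub-expression, so analyses can propagate those properties bottom-up
+/// through the expression tree.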
pub type ExprPropertiesNode = ExprContext; @@ -115,6 +118,9 @@ pub fn limited_convert_logical_expr_to_physical_expr( schema: &Schema, ) -> Result> { match expr { + Expr::Alias(Alias { expr, .. }) => { + Ok(limited_convert_logical_expr_to_physical_expr(expr, schema)?) + } Expr::Column(col) => expressions::column::col(&col.name, schema), Expr::Cast(cast_expr) => Ok(Arc::new(CastExpr::new( limited_convert_logical_expr_to_physical_expr( @@ -124,10 +130,7 @@ pub fn limited_convert_logical_expr_to_physical_expr( cast_expr.data_type.clone(), None, ))), - Expr::Alias(alias_expr) => limited_convert_logical_expr_to_physical_expr( - alias_expr.expr.as_ref(), - schema, - ), + Expr::Literal(value) => Ok(Arc::new(Literal::new(value.clone()))), _ => exec_err!( "Unsupported expression: {expr} for conversion to Arc" ), @@ -138,11 +141,12 @@ pub fn limited_convert_logical_expr_to_physical_expr( mod tests { use std::sync::Arc; - use super::*; - use arrow::array::Int32Array; + use datafusion_common::cast::{as_boolean_array, as_int32_array}; + use super::*; + #[test] fn scatter_int() -> Result<()> { let truthy = Arc::new(Int32Array::from(vec![1, 10, 11, 100])); diff --git a/datafusion/physical-expr/src/aggregate/approx_percentile_cont.rs b/datafusion/physical-expr/src/aggregate/approx_percentile_cont.rs deleted file mode 100644 index f2068bbc92cc..000000000000 --- a/datafusion/physical-expr/src/aggregate/approx_percentile_cont.rs +++ /dev/null @@ -1,249 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use std::{any::Any, sync::Arc}; - -use arrow::datatypes::{DataType, Field}; -use arrow_array::RecordBatch; -use arrow_schema::Schema; - -use datafusion_common::{not_impl_err, plan_err, DataFusionError, Result, ScalarValue}; -use datafusion_expr::{Accumulator, ColumnarValue}; -use datafusion_functions_aggregate::approx_percentile_cont::ApproxPercentileAccumulator; - -use crate::aggregate::utils::down_cast_any_ref; -use crate::expressions::format_state_name; -use crate::{AggregateExpr, PhysicalExpr}; - -/// APPROX_PERCENTILE_CONT aggregate expression -#[derive(Debug)] -pub struct ApproxPercentileCont { - name: String, - input_data_type: DataType, - expr: Vec>, - percentile: f64, - tdigest_max_size: Option, -} - -impl ApproxPercentileCont { - /// Create a new [`ApproxPercentileCont`] aggregate function. 
- pub fn new( - expr: Vec>, - name: impl Into, - input_data_type: DataType, - ) -> Result { - // Arguments should be [ColumnExpr, DesiredPercentileLiteral] - debug_assert_eq!(expr.len(), 2); - - let percentile = validate_input_percentile_expr(&expr[1])?; - - Ok(Self { - name: name.into(), - input_data_type, - // The physical expr to evaluate during accumulation - expr, - percentile, - tdigest_max_size: None, - }) - } - - /// Create a new [`ApproxPercentileCont`] aggregate function. - pub fn new_with_max_size( - expr: Vec>, - name: impl Into, - input_data_type: DataType, - ) -> Result { - // Arguments should be [ColumnExpr, DesiredPercentileLiteral, TDigestMaxSize] - debug_assert_eq!(expr.len(), 3); - let percentile = validate_input_percentile_expr(&expr[1])?; - let max_size = validate_input_max_size_expr(&expr[2])?; - Ok(Self { - name: name.into(), - input_data_type, - // The physical expr to evaluate during accumulation - expr, - percentile, - tdigest_max_size: Some(max_size), - }) - } - - pub(crate) fn create_plain_accumulator(&self) -> Result { - let accumulator: ApproxPercentileAccumulator = match &self.input_data_type { - t @ (DataType::UInt8 - | DataType::UInt16 - | DataType::UInt32 - | DataType::UInt64 - | DataType::Int8 - | DataType::Int16 - | DataType::Int32 - | DataType::Int64 - | DataType::Float32 - | DataType::Float64) => { - if let Some(max_size) = self.tdigest_max_size { - ApproxPercentileAccumulator::new_with_max_size(self.percentile, t.clone(), max_size) - - }else{ - ApproxPercentileAccumulator::new(self.percentile, t.clone()) - - } - } - other => { - return not_impl_err!( - "Support for 'APPROX_PERCENTILE_CONT' for data type {other} is not implemented" - ) - } - }; - Ok(accumulator) - } -} - -impl PartialEq for ApproxPercentileCont { - fn eq(&self, other: &ApproxPercentileCont) -> bool { - self.name == other.name - && self.input_data_type == other.input_data_type - && self.percentile == other.percentile - && self.tdigest_max_size == other.tdigest_max_size - && self.expr.len() == other.expr.len() - && self - .expr - .iter() - .zip(other.expr.iter()) - .all(|(this, other)| this.eq(other)) - } -} - -fn get_lit_value(expr: &Arc) -> Result { - let empty_schema = Schema::empty(); - let empty_batch = RecordBatch::new_empty(Arc::new(empty_schema)); - let result = expr.evaluate(&empty_batch)?; - match result { - ColumnarValue::Array(_) => Err(DataFusionError::Internal(format!( - "The expr {:?} can't be evaluated to scalar value", - expr - ))), - ColumnarValue::Scalar(scalar_value) => Ok(scalar_value), - } -} - -fn validate_input_percentile_expr(expr: &Arc) -> Result { - let lit = get_lit_value(expr)?; - let percentile = match &lit { - ScalarValue::Float32(Some(q)) => *q as f64, - ScalarValue::Float64(Some(q)) => *q, - got => return not_impl_err!( - "Percentile value for 'APPROX_PERCENTILE_CONT' must be Float32 or Float64 literal (got data type {})", - got.data_type() - ) - }; - - // Ensure the percentile is between 0 and 1. 
-    if !(0.0..=1.0).contains(&percentile) {
-        return plan_err!(
-            "Percentile value must be between 0.0 and 1.0 inclusive, {percentile} is invalid"
-        );
-    }
-    Ok(percentile)
-}
-
-fn validate_input_max_size_expr(expr: &Arc<dyn PhysicalExpr>) -> Result<usize> {
-    let lit = get_lit_value(expr)?;
-    let max_size = match &lit {
-        ScalarValue::UInt8(Some(q)) => *q as usize,
-        ScalarValue::UInt16(Some(q)) => *q as usize,
-        ScalarValue::UInt32(Some(q)) => *q as usize,
-        ScalarValue::UInt64(Some(q)) => *q as usize,
-        ScalarValue::Int32(Some(q)) if *q > 0 => *q as usize,
-        ScalarValue::Int64(Some(q)) if *q > 0 => *q as usize,
-        ScalarValue::Int16(Some(q)) if *q > 0 => *q as usize,
-        ScalarValue::Int8(Some(q)) if *q > 0 => *q as usize,
-        got => return not_impl_err!(
-            "Tdigest max_size value for 'APPROX_PERCENTILE_CONT' must be UInt > 0 literal (got data type {}).",
-            got.data_type()
-        )
-    };
-    Ok(max_size)
-}
-
-impl AggregateExpr for ApproxPercentileCont {
-    fn as_any(&self) -> &dyn Any {
-        self
-    }
-
-    fn field(&self) -> Result<Field> {
-        Ok(Field::new(&self.name, self.input_data_type.clone(), false))
-    }
-
-    #[allow(rustdoc::private_intra_doc_links)]
-    /// See [`datafusion_physical_expr_common::aggregate::tdigest::TDigest::to_scalar_state()`] for a description of the serialised
-    /// state.
-    fn state_fields(&self) -> Result<Vec<Field>> {
-        Ok(vec![
-            Field::new(
-                format_state_name(&self.name, "max_size"),
-                DataType::UInt64,
-                false,
-            ),
-            Field::new(
-                format_state_name(&self.name, "sum"),
-                DataType::Float64,
-                false,
-            ),
-            Field::new(
-                format_state_name(&self.name, "count"),
-                DataType::Float64,
-                false,
-            ),
-            Field::new(
-                format_state_name(&self.name, "max"),
-                DataType::Float64,
-                false,
-            ),
-            Field::new(
-                format_state_name(&self.name, "min"),
-                DataType::Float64,
-                false,
-            ),
-            Field::new_list(
-                format_state_name(&self.name, "centroids"),
-                Field::new("item", DataType::Float64, true),
-                false,
-            ),
-        ])
-    }
-
-    fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
-        self.expr.clone()
-    }
-
-    fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
-        let accumulator = self.create_plain_accumulator()?;
-        Ok(Box::new(accumulator))
-    }
-
-    fn name(&self) -> &str {
-        &self.name
-    }
-}
-
-impl PartialEq<dyn Any> for ApproxPercentileCont {
-    fn eq(&self, other: &dyn Any) -> bool {
-        down_cast_any_ref(other)
-            .downcast_ref::<ApproxPercentileCont>()
-            .map(|x| self.eq(x))
-            .unwrap_or(false)
-    }
-}
diff --git a/datafusion/physical-expr/src/aggregate/build_in.rs b/datafusion/physical-expr/src/aggregate/build_in.rs
index aee7bca3b88f..a1f5f153a9ff 100644
--- a/datafusion/physical-expr/src/aggregate/build_in.rs
+++ b/datafusion/physical-expr/src/aggregate/build_in.rs
@@ -30,13 +30,13 @@ use std::sync::Arc;

 use arrow::datatypes::Schema;

-use datafusion_common::{exec_err, internal_err, not_impl_err, Result};
+use datafusion_common::{exec_err, not_impl_err, Result};
 use datafusion_expr::AggregateFunction;

 use crate::aggregate::average::Avg;
-use crate::aggregate::regr::RegrType;
 use crate::expressions::{self, Literal};
 use crate::{AggregateExpr, PhysicalExpr, PhysicalSortExpr};
+
 /// Create a physical aggregation expression.
 /// This function errors when `input_phy_exprs` can't be coerced to a valid argument type of the aggregation function.
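+/// Note that `COUNT`, the `REGR_*` family, and the `APPROX_PERCENTILE_CONT`
+/// variants are no longer constructed here: they are now user-defined
+/// aggregate functions and are created through
+/// `datafusion_physical_expr_common::aggregate::create_aggregate_expr` instead.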
pub fn create_aggregate_expr( @@ -61,9 +61,6 @@ pub fn create_aggregate_expr( .collect::>>()?; let input_phy_exprs = input_phy_exprs.to_vec(); Ok(match (fun, distinct) { - (AggregateFunction::Count, _) => { - return internal_err!("Builtin Count will be removed"); - } (AggregateFunction::Grouping, _) => Arc::new(expressions::Grouping::new( input_phy_exprs[0].clone(), name, @@ -158,118 +155,6 @@ pub fn create_aggregate_expr( (AggregateFunction::Correlation, true) => { return not_impl_err!("CORR(DISTINCT) aggregations are not available"); } - (AggregateFunction::RegrSlope, false) => Arc::new(expressions::Regr::new( - input_phy_exprs[0].clone(), - input_phy_exprs[1].clone(), - name, - RegrType::Slope, - data_type, - )), - (AggregateFunction::RegrIntercept, false) => Arc::new(expressions::Regr::new( - input_phy_exprs[0].clone(), - input_phy_exprs[1].clone(), - name, - RegrType::Intercept, - data_type, - )), - (AggregateFunction::RegrCount, false) => Arc::new(expressions::Regr::new( - input_phy_exprs[0].clone(), - input_phy_exprs[1].clone(), - name, - RegrType::Count, - data_type, - )), - (AggregateFunction::RegrR2, false) => Arc::new(expressions::Regr::new( - input_phy_exprs[0].clone(), - input_phy_exprs[1].clone(), - name, - RegrType::R2, - data_type, - )), - (AggregateFunction::RegrAvgx, false) => Arc::new(expressions::Regr::new( - input_phy_exprs[0].clone(), - input_phy_exprs[1].clone(), - name, - RegrType::AvgX, - data_type, - )), - (AggregateFunction::RegrAvgy, false) => Arc::new(expressions::Regr::new( - input_phy_exprs[0].clone(), - input_phy_exprs[1].clone(), - name, - RegrType::AvgY, - data_type, - )), - (AggregateFunction::RegrSXX, false) => Arc::new(expressions::Regr::new( - input_phy_exprs[0].clone(), - input_phy_exprs[1].clone(), - name, - RegrType::SXX, - data_type, - )), - (AggregateFunction::RegrSYY, false) => Arc::new(expressions::Regr::new( - input_phy_exprs[0].clone(), - input_phy_exprs[1].clone(), - name, - RegrType::SYY, - data_type, - )), - (AggregateFunction::RegrSXY, false) => Arc::new(expressions::Regr::new( - input_phy_exprs[0].clone(), - input_phy_exprs[1].clone(), - name, - RegrType::SXY, - data_type, - )), - ( - AggregateFunction::RegrSlope - | AggregateFunction::RegrIntercept - | AggregateFunction::RegrCount - | AggregateFunction::RegrR2 - | AggregateFunction::RegrAvgx - | AggregateFunction::RegrAvgy - | AggregateFunction::RegrSXX - | AggregateFunction::RegrSYY - | AggregateFunction::RegrSXY, - true, - ) => { - return not_impl_err!("{}(DISTINCT) aggregations are not available", fun); - } - (AggregateFunction::ApproxPercentileCont, false) => { - if input_phy_exprs.len() == 2 { - Arc::new(expressions::ApproxPercentileCont::new( - // Pass in the desired percentile expr - input_phy_exprs, - name, - data_type, - )?) - } else { - Arc::new(expressions::ApproxPercentileCont::new_with_max_size( - // Pass in the desired percentile expr - input_phy_exprs, - name, - data_type, - )?) - } - } - (AggregateFunction::ApproxPercentileCont, true) => { - return not_impl_err!( - "approx_percentile_cont(DISTINCT) aggregations are not available" - ); - } - (AggregateFunction::ApproxPercentileContWithWeight, false) => { - Arc::new(expressions::ApproxPercentileContWithWeight::new( - // Pass in the desired percentile expr - input_phy_exprs, - name, - data_type, - )?) 
- } - (AggregateFunction::ApproxPercentileContWithWeight, true) => { - return not_impl_err!( - "approx_percentile_cont_with_weight(DISTINCT) aggregations are not available" - ); - } (AggregateFunction::NthValue, _) => { let expr = &input_phy_exprs[0]; let Some(n) = input_phy_exprs[1] @@ -313,15 +198,15 @@ pub fn create_aggregate_expr( mod tests { use arrow::datatypes::{DataType, Field}; - use super::*; + use datafusion_common::plan_err; + use datafusion_expr::{type_coercion, Signature}; + use crate::expressions::{ - try_cast, ApproxPercentileCont, ArrayAgg, Avg, BitAnd, BitOr, BitXor, BoolAnd, - BoolOr, DistinctArrayAgg, Max, Min, + try_cast, ArrayAgg, Avg, BitAnd, BitOr, BitXor, BoolAnd, BoolOr, + DistinctArrayAgg, Max, Min, }; - use datafusion_common::{plan_err, DataFusionError, ScalarValue}; - use datafusion_expr::type_coercion::aggregates::NUMERICS; - use datafusion_expr::{type_coercion, Signature}; + use super::*; #[test] fn test_approx_expr() -> Result<()> { @@ -385,59 +270,6 @@ mod tests { Ok(()) } - #[test] - fn test_agg_approx_percentile_phy_expr() { - for data_type in NUMERICS { - let input_schema = - Schema::new(vec![Field::new("c1", data_type.clone(), true)]); - let input_phy_exprs: Vec> = vec![ - Arc::new( - expressions::Column::new_with_schema("c1", &input_schema).unwrap(), - ), - Arc::new(expressions::Literal::new(ScalarValue::Float64(Some(0.2)))), - ]; - let result_agg_phy_exprs = create_physical_agg_expr_for_test( - &AggregateFunction::ApproxPercentileCont, - false, - &input_phy_exprs[..], - &input_schema, - "c1", - ) - .expect("failed to create aggregate expr"); - - assert!(result_agg_phy_exprs.as_any().is::()); - assert_eq!("c1", result_agg_phy_exprs.name()); - assert_eq!( - Field::new("c1", data_type.clone(), false), - result_agg_phy_exprs.field().unwrap() - ); - } - } - - #[test] - fn test_agg_approx_percentile_invalid_phy_expr() { - for data_type in NUMERICS { - let input_schema = - Schema::new(vec![Field::new("c1", data_type.clone(), true)]); - let input_phy_exprs: Vec> = vec![ - Arc::new( - expressions::Column::new_with_schema("c1", &input_schema).unwrap(), - ), - Arc::new(expressions::Literal::new(ScalarValue::Float64(Some(4.2)))), - ]; - let err = create_physical_agg_expr_for_test( - &AggregateFunction::ApproxPercentileCont, - false, - &input_phy_exprs[..], - &input_schema, - "c1", - ) - .expect_err("should fail due to invalid percentile"); - - assert!(matches!(err, DataFusionError::Plan(_))); - } - } - #[test] fn test_min_max_expr() -> Result<()> { let funcs = vec![AggregateFunction::Min, AggregateFunction::Max]; @@ -642,20 +474,6 @@ mod tests { Ok(()) } - #[test] - fn test_count_return_type() -> Result<()> { - let observed = AggregateFunction::Count.return_type(&[DataType::Utf8])?; - assert_eq!(DataType::Int64, observed); - - let observed = AggregateFunction::Count.return_type(&[DataType::Int8])?; - assert_eq!(DataType::Int64, observed); - - let observed = - AggregateFunction::Count.return_type(&[DataType::Decimal128(28, 13)])?; - assert_eq!(DataType::Int64, observed); - Ok(()) - } - #[test] fn test_avg_return_type() -> Result<()> { let observed = AggregateFunction::Avg.return_type(&[DataType::Float32])?; diff --git a/datafusion/physical-expr/src/aggregate/mod.rs b/datafusion/physical-expr/src/aggregate/mod.rs index 01105c8559c9..c20902c11b86 100644 --- a/datafusion/physical-expr/src/aggregate/mod.rs +++ b/datafusion/physical-expr/src/aggregate/mod.rs @@ -17,8 +17,6 @@ pub use datafusion_physical_expr_common::aggregate::AggregateExpr; -pub(crate) mod 
approx_percentile_cont; -pub(crate) mod approx_percentile_cont_with_weight; pub(crate) mod array_agg; pub(crate) mod array_agg_distinct; pub(crate) mod array_agg_ordered; @@ -33,7 +31,6 @@ pub(crate) mod string_agg; #[macro_use] pub(crate) mod min_max; pub(crate) mod groups_accumulator; -pub(crate) mod regr; pub(crate) mod stats; pub(crate) mod stddev; pub(crate) mod variance; diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs index 123ada6d7c86..b9a159b21e3d 100644 --- a/datafusion/physical-expr/src/expressions/mod.rs +++ b/datafusion/physical-expr/src/expressions/mod.rs @@ -26,7 +26,6 @@ mod in_list; mod is_not_null; mod is_null; mod like; -mod literal; mod negative; mod no_op; mod not; @@ -36,8 +35,6 @@ mod try_cast; pub mod helpers { pub use crate::aggregate::min_max::{max, min}; } -pub use crate::aggregate::approx_percentile_cont::ApproxPercentileCont; -pub use crate::aggregate::approx_percentile_cont_with_weight::ApproxPercentileContWithWeight; pub use crate::aggregate::array_agg::ArrayAgg; pub use crate::aggregate::array_agg_distinct::DistinctArrayAgg; pub use crate::aggregate::array_agg_ordered::OrderSensitiveArrayAgg; @@ -50,7 +47,6 @@ pub use crate::aggregate::correlation::Correlation; pub use crate::aggregate::grouping::Grouping; pub use crate::aggregate::min_max::{Max, MaxAccumulator, Min, MinAccumulator}; pub use crate::aggregate::nth_value::NthValueAgg; -pub use crate::aggregate::regr::{Regr, RegrType}; pub use crate::aggregate::stats::StatsType; pub use crate::aggregate::string_agg::StringAgg; pub use crate::window::cume_dist::{cume_dist, CumeDist}; @@ -67,12 +63,12 @@ pub use column::UnKnownColumn; pub use datafusion_expr::utils::format_state_name; pub use datafusion_functions_aggregate::first_last::{FirstValue, LastValue}; pub use datafusion_physical_expr_common::expressions::column::{col, Column}; +pub use datafusion_physical_expr_common::expressions::literal::{lit, Literal}; pub use datafusion_physical_expr_common::expressions::{cast, CastExpr}; pub use in_list::{in_list, InListExpr}; pub use is_not_null::{is_not_null, IsNotNullExpr}; pub use is_null::{is_null, IsNullExpr}; pub use like::{like, LikeExpr}; -pub use literal::{lit, Literal}; pub use negative::{negative, NegativeExpr}; pub use no_op::NoOp; pub use not::{not, NotExpr}; diff --git a/datafusion/physical-expr/src/window/nth_value.rs b/datafusion/physical-expr/src/window/nth_value.rs index 55d112e1f6e0..4bd40066ff34 100644 --- a/datafusion/physical-expr/src/window/nth_value.rs +++ b/datafusion/physical-expr/src/window/nth_value.rs @@ -125,7 +125,6 @@ impl BuiltInWindowFunctionExpr for NthValue { fn create_evaluator(&self) -> Result> { let state = NthValueState { - range: Default::default(), finalized_result: None, kind: self.kind, }; diff --git a/datafusion/physical-expr/src/window/window_expr.rs b/datafusion/physical-expr/src/window/window_expr.rs index 065371d9e43e..3cf68379d72b 100644 --- a/datafusion/physical-expr/src/window/window_expr.rs +++ b/datafusion/physical-expr/src/window/window_expr.rs @@ -559,7 +559,6 @@ pub enum NthValueKind { #[derive(Debug, Clone)] pub struct NthValueState { - pub range: Range, // In certain cases, we can finalize the result early. 
Consider this usage:
    // ```
    //  FIRST_VALUE(increasing_col) OVER window AS my_first_value
diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs
index b6fc70be7cbc..b7d8d60f4f35 100644
--- a/datafusion/physical-plan/src/aggregates/mod.rs
+++ b/datafusion/physical-plan/src/aggregates/mod.rs
@@ -1339,6 +1339,7 @@ mod tests {
         let aggregates = vec![create_aggregate_expr(
             &count_udaf(),
             &[lit(1i8)],
+            &[datafusion_expr::lit(1i8)],
             &[],
             &[],
             &input_schema,
@@ -1787,6 +1788,7 @@ mod tests {
             &args,
             &[],
             &[],
+            &[],
             schema,
             "MEDIAN(a)",
             false,
@@ -1975,10 +1977,12 @@ mod tests {
             options: sort_options,
         }];
         let args = vec![col("b", schema)?];
+        let logical_args = vec![datafusion_expr::col("b")];
         let func = datafusion_expr::AggregateUDF::new_from_impl(FirstValue::new());
         datafusion_physical_expr_common::aggregate::create_aggregate_expr(
             &func,
             &args,
+            &logical_args,
             &sort_exprs,
             &ordering_req,
             schema,
@@ -2005,10 +2009,12 @@ mod tests {
             options: sort_options,
         }];
         let args = vec![col("b", schema)?];
+        let logical_args = vec![datafusion_expr::col("b")];
         let func = datafusion_expr::AggregateUDF::new_from_impl(LastValue::new());
-        datafusion_physical_expr_common::aggregate::create_aggregate_expr(
+        create_aggregate_expr(
             &func,
             &args,
+            &logical_args,
             &sort_exprs,
             &ordering_req,
             schema,
diff --git a/datafusion/physical-plan/src/insert.rs b/datafusion/physical-plan/src/insert.rs
index fa30141a1934..30c3353d4b71 100644
--- a/datafusion/physical-plan/src/insert.rs
+++ b/datafusion/physical-plan/src/insert.rs
@@ -175,11 +175,6 @@ impl DataSinkExec {
         &self.sort_order
     }

-    /// Returns the metrics of the underlying [DataSink]
-    pub fn metrics(&self) -> Option<MetricsSet> {
-        self.sink.metrics()
-    }
-
     fn create_schema(
         input: &Arc<dyn ExecutionPlan>,
         schema: SchemaRef,
@@ -289,6 +284,11 @@ impl ExecutionPlan for DataSinkExec {
             stream,
         )))
     }
+
+    /// Returns the metrics of the underlying [DataSink]
+    fn metrics(&self) -> Option<MetricsSet> {
+        self.sink.metrics()
+    }
 }

 /// Create an output record batch with a count
diff --git a/datafusion/physical-plan/src/joins/hash_join.rs b/datafusion/physical-plan/src/joins/hash_join.rs
index 784584f03f0f..5353092d5c45 100644
--- a/datafusion/physical-plan/src/joins/hash_join.rs
+++ b/datafusion/physical-plan/src/joins/hash_join.rs
@@ -21,8 +21,9 @@ use std::fmt;
 use std::sync::atomic::{AtomicUsize, Ordering};
 use std::sync::Arc;
 use std::task::Poll;
-use std::{any::Any, usize, vec};
+use std::{any::Any, vec};

+use super::utils::asymmetric_join_output_partitioning;
 use super::{
     utils::{OnceAsync, OnceFut},
     PartitionMode,
@@ -34,10 +35,10 @@ use crate::{
     execution_mode_from_children, handle_state,
     hash_utils::create_hashes,
     joins::utils::{
-        adjust_indices_by_join_type, adjust_right_output_partitioning,
-        apply_join_filter_to_indices, build_batch_from_indices, build_join_schema,
-        check_join_is_valid, estimate_join_statistics, get_final_indices_from_bit_map,
-        need_produce_result_in_final, partitioned_join_output_partitioning,
+        adjust_indices_by_join_type, apply_join_filter_to_indices,
+        build_batch_from_indices, build_join_schema, check_join_is_valid,
+        estimate_join_statistics, get_final_indices_from_bit_map,
+        need_produce_result_in_final, symmetric_join_output_partitioning,
         BuildProbeJoinMetrics, ColumnIndex, JoinFilter, JoinHashMap, JoinHashMapOffset,
         JoinHashMapType, JoinOn, JoinOnRef, StatefulStreamResult,
     },
@@ -490,33 +491,16 @@ impl HashJoinExec {
             on,
         );

-        // Get output partitioning:
-        let left_columns_len = left.schema().fields.len();
        let mut
output_partitioning = match mode { - PartitionMode::CollectLeft => match join_type { - JoinType::Inner | JoinType::Right => adjust_right_output_partitioning( - right.output_partitioning(), - left_columns_len, - ), - JoinType::RightSemi | JoinType::RightAnti => { - right.output_partitioning().clone() - } - JoinType::Left - | JoinType::LeftSemi - | JoinType::LeftAnti - | JoinType::Full => Partitioning::UnknownPartitioning( - right.output_partitioning().partition_count(), - ), - }, - PartitionMode::Partitioned => partitioned_join_output_partitioning( - join_type, - left.output_partitioning(), - right.output_partitioning(), - left_columns_len, - ), + PartitionMode::CollectLeft => { + asymmetric_join_output_partitioning(left, right, &join_type) + } PartitionMode::Auto => Partitioning::UnknownPartitioning( right.output_partitioning().partition_count(), ), + PartitionMode::Partitioned => { + symmetric_join_output_partitioning(left, right, &join_type) + } }; // Determine execution mode by checking whether this join is pipeline diff --git a/datafusion/physical-plan/src/joins/nested_loop_join.rs b/datafusion/physical-plan/src/joins/nested_loop_join.rs index 18518600ef2f..6be124cce06f 100644 --- a/datafusion/physical-plan/src/joins/nested_loop_join.rs +++ b/datafusion/physical-plan/src/joins/nested_loop_join.rs @@ -25,18 +25,19 @@ use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; use std::task::Poll; +use super::utils::{asymmetric_join_output_partitioning, need_produce_result_in_final}; use crate::coalesce_batches::concat_batches; use crate::coalesce_partitions::CoalescePartitionsExec; use crate::joins::utils::{ - adjust_indices_by_join_type, adjust_right_output_partitioning, - apply_join_filter_to_indices, build_batch_from_indices, build_join_schema, - check_join_is_valid, estimate_join_statistics, get_final_indices_from_bit_map, - BuildProbeJoinMetrics, ColumnIndex, JoinFilter, OnceAsync, OnceFut, + adjust_indices_by_join_type, apply_join_filter_to_indices, build_batch_from_indices, + build_join_schema, check_join_is_valid, estimate_join_statistics, + get_final_indices_from_bit_map, BuildProbeJoinMetrics, ColumnIndex, JoinFilter, + OnceAsync, OnceFut, }; use crate::metrics::{ExecutionPlanMetricsSet, MetricsSet}; use crate::{ execution_mode_from_children, DisplayAs, DisplayFormatType, Distribution, - ExecutionMode, ExecutionPlan, ExecutionPlanProperties, Partitioning, PlanProperties, + ExecutionMode, ExecutionPlan, ExecutionPlanProperties, PlanProperties, RecordBatchStream, SendableRecordBatchStream, }; @@ -55,8 +56,6 @@ use datafusion_physical_expr::equivalence::join_equivalence_properties; use futures::{ready, Stream, StreamExt, TryStreamExt}; use parking_lot::Mutex; -use super::utils::need_produce_result_in_final; - /// Shared bitmap for visited left-side indices type SharedBitmapBuilder = Mutex; /// Left (build-side) data @@ -228,21 +227,8 @@ impl NestedLoopJoinExec { &[], ); - // Get output partitioning, - let output_partitioning = match join_type { - JoinType::Inner | JoinType::Right => adjust_right_output_partitioning( - right.output_partitioning(), - left.schema().fields().len(), - ), - JoinType::RightSemi | JoinType::RightAnti => { - right.output_partitioning().clone() - } - JoinType::Left | JoinType::LeftSemi | JoinType::LeftAnti | JoinType::Full => { - Partitioning::UnknownPartitioning( - right.output_partitioning().partition_count(), - ) - } - }; + let output_partitioning = + asymmetric_join_output_partitioning(left, right, &join_type); // Determine execution mode: let 
mut mode = execution_mode_from_children([left, right]); @@ -673,7 +659,7 @@ mod tests { use datafusion_execution::runtime_env::{RuntimeConfig, RuntimeEnv}; use datafusion_expr::Operator; use datafusion_physical_expr::expressions::{BinaryExpr, Literal}; - use datafusion_physical_expr::PhysicalExpr; + use datafusion_physical_expr::{Partitioning, PhysicalExpr}; fn build_table( a: (&str, &Vec), diff --git a/datafusion/physical-plan/src/joins/sort_merge_join.rs b/datafusion/physical-plan/src/joins/sort_merge_join.rs index 8da345cdfca6..01abb30181d0 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join.rs @@ -30,12 +30,22 @@ use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; +use crate::expressions::PhysicalSortExpr; +use crate::joins::utils::{ + build_join_schema, check_join_is_valid, estimate_join_statistics, + symmetric_join_output_partitioning, JoinFilter, JoinOn, JoinOnRef, +}; +use crate::metrics::{ExecutionPlanMetricsSet, MetricBuilder, MetricsSet}; +use crate::{ + execution_mode_from_children, metrics, DisplayAs, DisplayFormatType, Distribution, + ExecutionPlan, ExecutionPlanProperties, PhysicalExpr, PlanProperties, + RecordBatchStream, SendableRecordBatchStream, Statistics, +}; + use arrow::array::*; use arrow::compute::{self, concat_batches, take, SortOptions}; use arrow::datatypes::{DataType, SchemaRef, TimeUnit}; use arrow::error::ArrowError; -use futures::{Stream, StreamExt}; -use hashbrown::HashSet; use datafusion_common::{ internal_err, not_impl_err, plan_err, DataFusionError, JoinSide, JoinType, Result, @@ -45,17 +55,8 @@ use datafusion_execution::TaskContext; use datafusion_physical_expr::equivalence::join_equivalence_properties; use datafusion_physical_expr::{PhysicalExprRef, PhysicalSortRequirement}; -use crate::expressions::PhysicalSortExpr; -use crate::joins::utils::{ - build_join_schema, check_join_is_valid, estimate_join_statistics, - partitioned_join_output_partitioning, JoinFilter, JoinOn, JoinOnRef, -}; -use crate::metrics::{ExecutionPlanMetricsSet, MetricBuilder, MetricsSet}; -use crate::{ - execution_mode_from_children, metrics, DisplayAs, DisplayFormatType, Distribution, - ExecutionPlan, ExecutionPlanProperties, PhysicalExpr, PlanProperties, - RecordBatchStream, SendableRecordBatchStream, Statistics, -}; +use futures::{Stream, StreamExt}; +use hashbrown::HashSet; /// join execution plan executes partitions in parallel and combines them into a set of /// partitions. 
@@ -220,14 +221,8 @@ impl SortMergeJoinExec { join_on, ); - // Get output partitioning: - let left_columns_len = left.schema().fields.len(); - let output_partitioning = partitioned_join_output_partitioning( - join_type, - left.output_partitioning(), - right.output_partitioning(), - left_columns_len, - ); + let output_partitioning = + symmetric_join_output_partitioning(left, right, &join_type); // Determine execution mode: let mode = execution_mode_from_children([left, right]); diff --git a/datafusion/physical-plan/src/joins/stream_join_utils.rs b/datafusion/physical-plan/src/joins/stream_join_utils.rs index 0a01d84141e7..46d3ac5acf1e 100644 --- a/datafusion/physical-plan/src/joins/stream_join_utils.rs +++ b/datafusion/physical-plan/src/joins/stream_join_utils.rs @@ -20,7 +20,6 @@ use std::collections::{HashMap, VecDeque}; use std::sync::Arc; -use std::usize; use crate::joins::utils::{JoinFilter, JoinHashMapType}; use crate::metrics::{ExecutionPlanMetricsSet, MetricBuilder}; diff --git a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs index 7b4d790479b1..813f670147bc 100644 --- a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs +++ b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs @@ -29,7 +29,7 @@ use std::any::Any; use std::fmt::{self, Debug}; use std::sync::Arc; use std::task::{Context, Poll}; -use std::{usize, vec}; +use std::vec; use crate::common::SharedMemoryReservation; use crate::handle_state; @@ -42,7 +42,7 @@ use crate::joins::stream_join_utils::{ }; use crate::joins::utils::{ apply_join_filter_to_indices, build_batch_from_indices, build_join_schema, - check_join_is_valid, partitioned_join_output_partitioning, ColumnIndex, JoinFilter, + check_join_is_valid, symmetric_join_output_partitioning, ColumnIndex, JoinFilter, JoinHashMapType, JoinOn, JoinOnRef, StatefulStreamResult, }; use crate::{ @@ -271,14 +271,8 @@ impl SymmetricHashJoinExec { join_on, ); - // Get output partitioning: - let left_columns_len = left.schema().fields.len(); - let output_partitioning = partitioned_join_output_partitioning( - join_type, - left.output_partitioning(), - right.output_partitioning(), - left_columns_len, - ); + let output_partitioning = + symmetric_join_output_partitioning(left, right, &join_type); // Determine execution mode: let mode = execution_mode_from_children([left, right]); diff --git a/datafusion/physical-plan/src/joins/test_utils.rs b/datafusion/physical-plan/src/joins/test_utils.rs index 9598ed83aa58..7e05ded6f69d 100644 --- a/datafusion/physical-plan/src/joins/test_utils.rs +++ b/datafusion/physical-plan/src/joins/test_utils.rs @@ -18,7 +18,6 @@ //! 
This file has test utils for hash joins

 use std::sync::Arc;
-use std::usize;

 use crate::joins::utils::{JoinFilter, JoinOn};
 use crate::joins::{
diff --git a/datafusion/physical-plan/src/joins/utils.rs b/datafusion/physical-plan/src/joins/utils.rs
index 0d99d7a16356..dfa1fd4763f4 100644
--- a/datafusion/physical-plan/src/joins/utils.rs
+++ b/datafusion/physical-plan/src/joins/utils.rs
@@ -23,10 +23,11 @@ use std::future::Future;
 use std::ops::{IndexMut, Range};
 use std::sync::Arc;
 use std::task::{Context, Poll};
-use std::usize;

 use crate::metrics::{self, ExecutionPlanMetricsSet, MetricBuilder};
-use crate::{ColumnStatistics, ExecutionPlan, Partitioning, Statistics};
+use crate::{
+    ColumnStatistics, ExecutionPlan, ExecutionPlanProperties, Partitioning, Statistics,
+};

 use arrow::array::{
     downcast_array, new_null_array, Array, BooleanBufferBuilder, UInt32Array,
@@ -429,27 +430,6 @@ fn check_join_set_is_valid(
     Ok(())
 }

-/// Calculate the OutputPartitioning for Partitioned Join
-pub fn partitioned_join_output_partitioning(
-    join_type: JoinType,
-    left_partitioning: &Partitioning,
-    right_partitioning: &Partitioning,
-    left_columns_len: usize,
-) -> Partitioning {
-    match join_type {
-        JoinType::Inner | JoinType::Left | JoinType::LeftSemi | JoinType::LeftAnti => {
-            left_partitioning.clone()
-        }
-        JoinType::RightSemi | JoinType::RightAnti => right_partitioning.clone(),
-        JoinType::Right => {
-            adjust_right_output_partitioning(right_partitioning, left_columns_len)
-        }
-        JoinType::Full => {
-            Partitioning::UnknownPartitioning(right_partitioning.partition_count())
-        }
-    }
-}
-
 /// Adjust the right output partitioning to the new column index
 pub fn adjust_right_output_partitioning(
     right_partitioning: &Partitioning,
@@ -1540,6 +1520,48 @@ pub enum StatefulStreamResult<T> {
     Continue,
 }

+pub(crate) fn symmetric_join_output_partitioning(
+    left: &Arc<dyn ExecutionPlan>,
+    right: &Arc<dyn ExecutionPlan>,
+    join_type: &JoinType,
+) -> Partitioning {
+    let left_columns_len = left.schema().fields.len();
+    let left_partitioning = left.output_partitioning();
+    let right_partitioning = right.output_partitioning();
+    match join_type {
+        JoinType::Left | JoinType::LeftSemi | JoinType::LeftAnti => {
+            left_partitioning.clone()
+        }
+        JoinType::RightSemi | JoinType::RightAnti => right_partitioning.clone(),
+        JoinType::Inner | JoinType::Right => {
+            adjust_right_output_partitioning(right_partitioning, left_columns_len)
+        }
+        JoinType::Full => {
+            // We could also use the left partition count, as the two are necessarily equal.
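+            // A sketch of the invariant this relies on: in partitioned mode both
+            // inputs are hash-repartitioned on the join keys, so
+            //     debug_assert_eq!(
+            //         left_partitioning.partition_count(),
+            //         right_partitioning.partition_count()
+            //     );
+            // would be expected to hold at this point.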
+ Partitioning::UnknownPartitioning(right_partitioning.partition_count()) + } + } +} + +pub(crate) fn asymmetric_join_output_partitioning( + left: &Arc, + right: &Arc, + join_type: &JoinType, +) -> Partitioning { + match join_type { + JoinType::Inner | JoinType::Right => adjust_right_output_partitioning( + right.output_partitioning(), + left.schema().fields().len(), + ), + JoinType::RightSemi | JoinType::RightAnti => right.output_partitioning().clone(), + JoinType::Left | JoinType::LeftSemi | JoinType::LeftAnti | JoinType::Full => { + Partitioning::UnknownPartitioning( + right.output_partitioning().partition_count(), + ) + } + } +} + #[cfg(test)] mod tests { use std::pin::Pin; diff --git a/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs b/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs index 56d780e51394..fc60ab997375 100644 --- a/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs +++ b/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs @@ -1194,7 +1194,7 @@ mod tests { RecordBatchStream, SendableRecordBatchStream, TaskContext, }; use datafusion_expr::{ - WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, + Expr, WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, }; use datafusion_functions_aggregate::count::count_udaf; use datafusion_physical_expr::expressions::{col, Column, NthValue}; @@ -1301,7 +1301,10 @@ mod tests { let window_fn = WindowFunctionDefinition::AggregateUDF(count_udaf()); let col_expr = Arc::new(Column::new(schema.fields[0].name(), 0)) as Arc; + let log_expr = + Expr::Column(datafusion_common::Column::from(schema.fields[0].name())); let args = vec![col_expr]; + let log_args = vec![log_expr]; let partitionby_exprs = vec![col(hash, &schema)?]; let orderby_exprs = vec![PhysicalSortExpr { expr: col(order_by, &schema)?, @@ -1322,6 +1325,7 @@ mod tests { &window_fn, fn_name, &args, + &log_args, &partitionby_exprs, &orderby_exprs, Arc::new(window_frame.clone()), diff --git a/datafusion/physical-plan/src/windows/mod.rs b/datafusion/physical-plan/src/windows/mod.rs index 63ce473fc57e..ecfe123a43af 100644 --- a/datafusion/physical-plan/src/windows/mod.rs +++ b/datafusion/physical-plan/src/windows/mod.rs @@ -90,6 +90,7 @@ pub fn create_window_expr( fun: &WindowFunctionDefinition, name: String, args: &[Arc], + logical_args: &[Expr], partition_by: &[Arc], order_by: &[PhysicalSortExpr], window_frame: Arc, @@ -144,6 +145,7 @@ pub fn create_window_expr( let aggregate = udaf::create_aggregate_expr( fun.as_ref(), args, + logical_args, &sort_exprs, order_by, input_schema, @@ -754,6 +756,7 @@ mod tests { &[col("a", &schema)?], &[], &[], + &[], Arc::new(WindowFrame::new(None)), schema.as_ref(), false, diff --git a/datafusion/proto/Cargo.toml b/datafusion/proto/Cargo.toml index b1897aa58e7d..aa8d0e55b68f 100644 --- a/datafusion/proto/Cargo.toml +++ b/datafusion/proto/Cargo.toml @@ -59,6 +59,7 @@ serde_json = { workspace = true, optional = true } [dev-dependencies] datafusion-functions = { workspace = true, default-features = true } +datafusion-functions-aggregate = { workspace = true } doc-comment = { workspace = true } strum = { version = "0.26.1", features = ["derive"] } tokio = { workspace = true, features = ["rt-multi-thread"] } diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 2bb3ec793d7f..e5578ae62f3e 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -476,7 +476,7 @@ enum 
AggregateFunction { MAX = 1; // SUM = 2; AVG = 3; - COUNT = 4; + // COUNT = 4; // APPROX_DISTINCT = 5; ARRAY_AGG = 6; // VARIANCE = 7; @@ -486,9 +486,9 @@ enum AggregateFunction { // STDDEV = 11; // STDDEV_POP = 12; CORRELATION = 13; - APPROX_PERCENTILE_CONT = 14; + // APPROX_PERCENTILE_CONT = 14; // APPROX_MEDIAN = 15; - APPROX_PERCENTILE_CONT_WITH_WEIGHT = 16; + // APPROX_PERCENTILE_CONT_WITH_WEIGHT = 16; GROUPING = 17; // MEDIAN = 18; BIT_AND = 19; @@ -496,15 +496,15 @@ enum AggregateFunction { BIT_XOR = 21; BOOL_AND = 22; BOOL_OR = 23; - REGR_SLOPE = 26; - REGR_INTERCEPT = 27; - REGR_COUNT = 28; - REGR_R2 = 29; - REGR_AVGX = 30; - REGR_AVGY = 31; - REGR_SXX = 32; - REGR_SYY = 33; - REGR_SXY = 34; + // REGR_SLOPE = 26; + // REGR_INTERCEPT = 27; + // REGR_COUNT = 28; + // REGR_R2 = 29; + // REGR_AVGX = 30; + // REGR_AVGY = 31; + // REGR_SXX = 32; + // REGR_SYY = 33; + // REGR_SXY = 34; STRING_AGG = 35; NTH_VALUE_AGG = 36; } diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 59b7861a6ef1..4a7b9610e5bc 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -535,26 +535,14 @@ impl serde::Serialize for AggregateFunction { Self::Min => "MIN", Self::Max => "MAX", Self::Avg => "AVG", - Self::Count => "COUNT", Self::ArrayAgg => "ARRAY_AGG", Self::Correlation => "CORRELATION", - Self::ApproxPercentileCont => "APPROX_PERCENTILE_CONT", - Self::ApproxPercentileContWithWeight => "APPROX_PERCENTILE_CONT_WITH_WEIGHT", Self::Grouping => "GROUPING", Self::BitAnd => "BIT_AND", Self::BitOr => "BIT_OR", Self::BitXor => "BIT_XOR", Self::BoolAnd => "BOOL_AND", Self::BoolOr => "BOOL_OR", - Self::RegrSlope => "REGR_SLOPE", - Self::RegrIntercept => "REGR_INTERCEPT", - Self::RegrCount => "REGR_COUNT", - Self::RegrR2 => "REGR_R2", - Self::RegrAvgx => "REGR_AVGX", - Self::RegrAvgy => "REGR_AVGY", - Self::RegrSxx => "REGR_SXX", - Self::RegrSyy => "REGR_SYY", - Self::RegrSxy => "REGR_SXY", Self::StringAgg => "STRING_AGG", Self::NthValueAgg => "NTH_VALUE_AGG", }; @@ -571,26 +559,14 @@ impl<'de> serde::Deserialize<'de> for AggregateFunction { "MIN", "MAX", "AVG", - "COUNT", "ARRAY_AGG", "CORRELATION", - "APPROX_PERCENTILE_CONT", - "APPROX_PERCENTILE_CONT_WITH_WEIGHT", "GROUPING", "BIT_AND", "BIT_OR", "BIT_XOR", "BOOL_AND", "BOOL_OR", - "REGR_SLOPE", - "REGR_INTERCEPT", - "REGR_COUNT", - "REGR_R2", - "REGR_AVGX", - "REGR_AVGY", - "REGR_SXX", - "REGR_SYY", - "REGR_SXY", "STRING_AGG", "NTH_VALUE_AGG", ]; @@ -636,26 +612,14 @@ impl<'de> serde::Deserialize<'de> for AggregateFunction { "MIN" => Ok(AggregateFunction::Min), "MAX" => Ok(AggregateFunction::Max), "AVG" => Ok(AggregateFunction::Avg), - "COUNT" => Ok(AggregateFunction::Count), "ARRAY_AGG" => Ok(AggregateFunction::ArrayAgg), "CORRELATION" => Ok(AggregateFunction::Correlation), - "APPROX_PERCENTILE_CONT" => Ok(AggregateFunction::ApproxPercentileCont), - "APPROX_PERCENTILE_CONT_WITH_WEIGHT" => Ok(AggregateFunction::ApproxPercentileContWithWeight), "GROUPING" => Ok(AggregateFunction::Grouping), "BIT_AND" => Ok(AggregateFunction::BitAnd), "BIT_OR" => Ok(AggregateFunction::BitOr), "BIT_XOR" => Ok(AggregateFunction::BitXor), "BOOL_AND" => Ok(AggregateFunction::BoolAnd), "BOOL_OR" => Ok(AggregateFunction::BoolOr), - "REGR_SLOPE" => Ok(AggregateFunction::RegrSlope), - "REGR_INTERCEPT" => Ok(AggregateFunction::RegrIntercept), - "REGR_COUNT" => Ok(AggregateFunction::RegrCount), - "REGR_R2" => Ok(AggregateFunction::RegrR2), - "REGR_AVGX" => Ok(AggregateFunction::RegrAvgx), - 
"REGR_AVGY" => Ok(AggregateFunction::RegrAvgy), - "REGR_SXX" => Ok(AggregateFunction::RegrSxx), - "REGR_SYY" => Ok(AggregateFunction::RegrSyy), - "REGR_SXY" => Ok(AggregateFunction::RegrSxy), "STRING_AGG" => Ok(AggregateFunction::StringAgg), "NTH_VALUE_AGG" => Ok(AggregateFunction::NthValueAgg), _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 0861c287fcfa..ffaef445d668 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -1930,7 +1930,7 @@ pub enum AggregateFunction { Max = 1, /// SUM = 2; Avg = 3, - Count = 4, + /// COUNT = 4; /// APPROX_DISTINCT = 5; ArrayAgg = 6, /// VARIANCE = 7; @@ -1940,9 +1940,9 @@ pub enum AggregateFunction { /// STDDEV = 11; /// STDDEV_POP = 12; Correlation = 13, - ApproxPercentileCont = 14, + /// APPROX_PERCENTILE_CONT = 14; /// APPROX_MEDIAN = 15; - ApproxPercentileContWithWeight = 16, + /// APPROX_PERCENTILE_CONT_WITH_WEIGHT = 16; Grouping = 17, /// MEDIAN = 18; BitAnd = 19, @@ -1950,15 +1950,15 @@ pub enum AggregateFunction { BitXor = 21, BoolAnd = 22, BoolOr = 23, - RegrSlope = 26, - RegrIntercept = 27, - RegrCount = 28, - RegrR2 = 29, - RegrAvgx = 30, - RegrAvgy = 31, - RegrSxx = 32, - RegrSyy = 33, - RegrSxy = 34, + /// REGR_SLOPE = 26; + /// REGR_INTERCEPT = 27; + /// REGR_COUNT = 28; + /// REGR_R2 = 29; + /// REGR_AVGX = 30; + /// REGR_AVGY = 31; + /// REGR_SXX = 32; + /// REGR_SYY = 33; + /// REGR_SXY = 34; StringAgg = 35, NthValueAgg = 36, } @@ -1972,28 +1972,14 @@ impl AggregateFunction { AggregateFunction::Min => "MIN", AggregateFunction::Max => "MAX", AggregateFunction::Avg => "AVG", - AggregateFunction::Count => "COUNT", AggregateFunction::ArrayAgg => "ARRAY_AGG", AggregateFunction::Correlation => "CORRELATION", - AggregateFunction::ApproxPercentileCont => "APPROX_PERCENTILE_CONT", - AggregateFunction::ApproxPercentileContWithWeight => { - "APPROX_PERCENTILE_CONT_WITH_WEIGHT" - } AggregateFunction::Grouping => "GROUPING", AggregateFunction::BitAnd => "BIT_AND", AggregateFunction::BitOr => "BIT_OR", AggregateFunction::BitXor => "BIT_XOR", AggregateFunction::BoolAnd => "BOOL_AND", AggregateFunction::BoolOr => "BOOL_OR", - AggregateFunction::RegrSlope => "REGR_SLOPE", - AggregateFunction::RegrIntercept => "REGR_INTERCEPT", - AggregateFunction::RegrCount => "REGR_COUNT", - AggregateFunction::RegrR2 => "REGR_R2", - AggregateFunction::RegrAvgx => "REGR_AVGX", - AggregateFunction::RegrAvgy => "REGR_AVGY", - AggregateFunction::RegrSxx => "REGR_SXX", - AggregateFunction::RegrSyy => "REGR_SYY", - AggregateFunction::RegrSxy => "REGR_SXY", AggregateFunction::StringAgg => "STRING_AGG", AggregateFunction::NthValueAgg => "NTH_VALUE_AGG", } @@ -2004,28 +1990,14 @@ impl AggregateFunction { "MIN" => Some(Self::Min), "MAX" => Some(Self::Max), "AVG" => Some(Self::Avg), - "COUNT" => Some(Self::Count), "ARRAY_AGG" => Some(Self::ArrayAgg), "CORRELATION" => Some(Self::Correlation), - "APPROX_PERCENTILE_CONT" => Some(Self::ApproxPercentileCont), - "APPROX_PERCENTILE_CONT_WITH_WEIGHT" => { - Some(Self::ApproxPercentileContWithWeight) - } "GROUPING" => Some(Self::Grouping), "BIT_AND" => Some(Self::BitAnd), "BIT_OR" => Some(Self::BitOr), "BIT_XOR" => Some(Self::BitXor), "BOOL_AND" => Some(Self::BoolAnd), "BOOL_OR" => Some(Self::BoolOr), - "REGR_SLOPE" => Some(Self::RegrSlope), - "REGR_INTERCEPT" => Some(Self::RegrIntercept), - "REGR_COUNT" => Some(Self::RegrCount), - "REGR_R2" => Some(Self::RegrR2), - "REGR_AVGX" => 
Some(Self::RegrAvgx), - "REGR_AVGY" => Some(Self::RegrAvgy), - "REGR_SXX" => Some(Self::RegrSxx), - "REGR_SYY" => Some(Self::RegrSyy), - "REGR_SXY" => Some(Self::RegrSxy), "STRING_AGG" => Some(Self::StringAgg), "NTH_VALUE_AGG" => Some(Self::NthValueAgg), _ => None, diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 2ad40d883fe6..25b7413a984a 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -145,24 +145,8 @@ impl From for AggregateFunction { protobuf::AggregateFunction::BitXor => Self::BitXor, protobuf::AggregateFunction::BoolAnd => Self::BoolAnd, protobuf::AggregateFunction::BoolOr => Self::BoolOr, - protobuf::AggregateFunction::Count => Self::Count, protobuf::AggregateFunction::ArrayAgg => Self::ArrayAgg, protobuf::AggregateFunction::Correlation => Self::Correlation, - protobuf::AggregateFunction::RegrSlope => Self::RegrSlope, - protobuf::AggregateFunction::RegrIntercept => Self::RegrIntercept, - protobuf::AggregateFunction::RegrCount => Self::RegrCount, - protobuf::AggregateFunction::RegrR2 => Self::RegrR2, - protobuf::AggregateFunction::RegrAvgx => Self::RegrAvgx, - protobuf::AggregateFunction::RegrAvgy => Self::RegrAvgy, - protobuf::AggregateFunction::RegrSxx => Self::RegrSXX, - protobuf::AggregateFunction::RegrSyy => Self::RegrSYY, - protobuf::AggregateFunction::RegrSxy => Self::RegrSXY, - protobuf::AggregateFunction::ApproxPercentileCont => { - Self::ApproxPercentileCont - } - protobuf::AggregateFunction::ApproxPercentileContWithWeight => { - Self::ApproxPercentileContWithWeight - } protobuf::AggregateFunction::Grouping => Self::Grouping, protobuf::AggregateFunction::NthValueAgg => Self::NthValue, protobuf::AggregateFunction::StringAgg => Self::StringAgg, diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 6a275ed7a1b8..d9548325dac3 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -116,22 +116,8 @@ impl From<&AggregateFunction> for protobuf::AggregateFunction { AggregateFunction::BitXor => Self::BitXor, AggregateFunction::BoolAnd => Self::BoolAnd, AggregateFunction::BoolOr => Self::BoolOr, - AggregateFunction::Count => Self::Count, AggregateFunction::ArrayAgg => Self::ArrayAgg, AggregateFunction::Correlation => Self::Correlation, - AggregateFunction::RegrSlope => Self::RegrSlope, - AggregateFunction::RegrIntercept => Self::RegrIntercept, - AggregateFunction::RegrCount => Self::RegrCount, - AggregateFunction::RegrR2 => Self::RegrR2, - AggregateFunction::RegrAvgx => Self::RegrAvgx, - AggregateFunction::RegrAvgy => Self::RegrAvgy, - AggregateFunction::RegrSXX => Self::RegrSxx, - AggregateFunction::RegrSYY => Self::RegrSyy, - AggregateFunction::RegrSXY => Self::RegrSxy, - AggregateFunction::ApproxPercentileCont => Self::ApproxPercentileCont, - AggregateFunction::ApproxPercentileContWithWeight => { - Self::ApproxPercentileContWithWeight - } AggregateFunction::Grouping => Self::Grouping, AggregateFunction::NthValue => Self::NthValueAgg, AggregateFunction::StringAgg => Self::StringAgg, @@ -391,12 +377,6 @@ pub fn serialize_expr( }) => match func_def { AggregateFunctionDefinition::BuiltIn(fun) => { let aggr_function = match fun { - AggregateFunction::ApproxPercentileCont => { - protobuf::AggregateFunction::ApproxPercentileCont - } - AggregateFunction::ApproxPercentileContWithWeight => { - 
protobuf::AggregateFunction::ApproxPercentileContWithWeight - } AggregateFunction::ArrayAgg => protobuf::AggregateFunction::ArrayAgg, AggregateFunction::Min => protobuf::AggregateFunction::Min, AggregateFunction::Max => protobuf::AggregateFunction::Max, @@ -406,25 +386,9 @@ pub fn serialize_expr( AggregateFunction::BoolAnd => protobuf::AggregateFunction::BoolAnd, AggregateFunction::BoolOr => protobuf::AggregateFunction::BoolOr, AggregateFunction::Avg => protobuf::AggregateFunction::Avg, - AggregateFunction::Count => protobuf::AggregateFunction::Count, AggregateFunction::Correlation => { protobuf::AggregateFunction::Correlation } - AggregateFunction::RegrSlope => { - protobuf::AggregateFunction::RegrSlope - } - AggregateFunction::RegrIntercept => { - protobuf::AggregateFunction::RegrIntercept - } - AggregateFunction::RegrR2 => protobuf::AggregateFunction::RegrR2, - AggregateFunction::RegrAvgx => protobuf::AggregateFunction::RegrAvgx, - AggregateFunction::RegrAvgy => protobuf::AggregateFunction::RegrAvgy, - AggregateFunction::RegrCount => { - protobuf::AggregateFunction::RegrCount - } - AggregateFunction::RegrSXX => protobuf::AggregateFunction::RegrSxx, - AggregateFunction::RegrSYY => protobuf::AggregateFunction::RegrSyy, - AggregateFunction::RegrSXY => protobuf::AggregateFunction::RegrSxy, AggregateFunction::Grouping => protobuf::AggregateFunction::Grouping, AggregateFunction::NthValue => { protobuf::AggregateFunction::NthValueAgg diff --git a/datafusion/proto/src/physical_plan/from_proto.rs b/datafusion/proto/src/physical_plan/from_proto.rs index 0a91df568a1d..b636c77641c7 100644 --- a/datafusion/proto/src/physical_plan/from_proto.rs +++ b/datafusion/proto/src/physical_plan/from_proto.rs @@ -126,7 +126,6 @@ pub fn parse_physical_window_expr( ) -> Result> { let window_node_expr = parse_physical_exprs(&proto.args, registry, input_schema, codec)?; - let partition_by = parse_physical_exprs(&proto.partition_by, registry, input_schema, codec)?; @@ -178,10 +177,13 @@ pub fn parse_physical_window_expr( // TODO: Remove extended_schema if functions are all UDAF let extended_schema = schema_add_window_field(&window_node_expr, input_schema, &fun, &name)?; + // approx_percentile_cont and approx_percentile_cont_weight are not supported for UDAF from protobuf yet. + let logical_exprs = &[]; create_window_expr( &fun, name, &window_node_expr, + logical_exprs, &partition_by, &order_by, Arc::new(window_frame), diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs index d0011e4917bf..8a488d30cf24 100644 --- a/datafusion/proto/src/physical_plan/mod.rs +++ b/datafusion/proto/src/physical_plan/mod.rs @@ -496,11 +496,14 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { } AggregateFunction::UserDefinedAggrFunction(udaf_name) => { let agg_udf = registry.udaf(udaf_name)?; + // TODO: 'logical_exprs' is not supported for UDAF yet. + // approx_percentile_cont and approx_percentile_cont_weight are not supported for UDAF from protobuf yet. 
+ let logical_exprs = &[]; // TODO: `order by` is not supported for UDAF yet let sort_exprs = &[]; let ordering_req = &[]; let ignore_nulls = false; - udaf::create_aggregate_expr(agg_udf.as_ref(), &input_phy_expr, sort_exprs, ordering_req, &physical_schema, name, ignore_nulls, false) + udaf::create_aggregate_expr(agg_udf.as_ref(), &input_phy_expr, logical_exprs, sort_exprs, ordering_req, &physical_schema, name, ignore_nulls, false) } } }).transpose()?.ok_or_else(|| { diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs index e25447b023d8..3a4c35a93e16 100644 --- a/datafusion/proto/src/physical_plan/to_proto.rs +++ b/datafusion/proto/src/physical_plan/to_proto.rs @@ -23,12 +23,11 @@ use datafusion::datasource::file_format::parquet::ParquetSink; use datafusion::physical_expr::window::{NthValueKind, SlidingAggregateWindowExpr}; use datafusion::physical_expr::{PhysicalSortExpr, ScalarFunctionExpr}; use datafusion::physical_plan::expressions::{ - ApproxPercentileCont, ApproxPercentileContWithWeight, ArrayAgg, Avg, BinaryExpr, - BitAnd, BitOr, BitXor, BoolAnd, BoolOr, CaseExpr, CastExpr, Column, Correlation, - CumeDist, DistinctArrayAgg, DistinctBitXor, Grouping, InListExpr, IsNotNullExpr, - IsNullExpr, Literal, Max, Min, NegativeExpr, NotExpr, NthValue, NthValueAgg, Ntile, - OrderSensitiveArrayAgg, Rank, RankType, Regr, RegrType, RowNumber, StringAgg, - TryCastExpr, WindowShift, + ArrayAgg, Avg, BinaryExpr, BitAnd, BitOr, BitXor, BoolAnd, BoolOr, CaseExpr, + CastExpr, Column, Correlation, CumeDist, DistinctArrayAgg, DistinctBitXor, Grouping, + InListExpr, IsNotNullExpr, IsNullExpr, Literal, Max, Min, NegativeExpr, NotExpr, + NthValue, NthValueAgg, Ntile, OrderSensitiveArrayAgg, Rank, RankType, RowNumber, + StringAgg, TryCastExpr, WindowShift, }; use datafusion::physical_plan::udaf::AggregateFunctionExpr; use datafusion::physical_plan::windows::{BuiltInWindowExpr, PlainAggregateWindowExpr}; @@ -270,25 +269,6 @@ fn aggr_expr_to_aggr_fn(expr: &dyn AggregateExpr) -> Result { protobuf::AggregateFunction::Avg } else if aggr_expr.downcast_ref::().is_some() { protobuf::AggregateFunction::Correlation - } else if let Some(regr_expr) = aggr_expr.downcast_ref::() { - match regr_expr.get_regr_type() { - RegrType::Slope => protobuf::AggregateFunction::RegrSlope, - RegrType::Intercept => protobuf::AggregateFunction::RegrIntercept, - RegrType::Count => protobuf::AggregateFunction::RegrCount, - RegrType::R2 => protobuf::AggregateFunction::RegrR2, - RegrType::AvgX => protobuf::AggregateFunction::RegrAvgx, - RegrType::AvgY => protobuf::AggregateFunction::RegrAvgy, - RegrType::SXX => protobuf::AggregateFunction::RegrSxx, - RegrType::SYY => protobuf::AggregateFunction::RegrSyy, - RegrType::SXY => protobuf::AggregateFunction::RegrSxy, - } - } else if aggr_expr.downcast_ref::().is_some() { - protobuf::AggregateFunction::ApproxPercentileCont - } else if aggr_expr - .downcast_ref::() - .is_some() - { - protobuf::AggregateFunction::ApproxPercentileContWithWeight } else if aggr_expr.downcast_ref::().is_some() { protobuf::AggregateFunction::StringAgg } else if aggr_expr.downcast_ref::().is_some() { diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index d9736da69d42..a496e226855a 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -33,10 +33,11 @@ use datafusion::datasource::TableProvider; use 
datafusion::execution::context::SessionState; use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; use datafusion::execution::FunctionRegistry; -use datafusion::functions_aggregate::approx_median::approx_median; +use datafusion::functions_aggregate::count::count_udaf; use datafusion::functions_aggregate::expr_fn::{ - covar_pop, covar_samp, first_value, median, stddev, stddev_pop, sum, var_pop, - var_sample, + approx_median, approx_percentile_cont, approx_percentile_cont_with_weight, count, + count_distinct, covar_pop, covar_samp, first_value, median, stddev, stddev_pop, sum, + var_pop, var_sample, }; use datafusion::prelude::*; use datafusion::test_util::{TestTableFactory, TestTableProvider}; @@ -53,10 +54,10 @@ use datafusion_expr::expr::{ }; use datafusion_expr::logical_plan::{Extension, UserDefinedLogicalNodeCore}; use datafusion_expr::{ - Accumulator, AggregateFunction, ColumnarValue, ExprSchemable, LogicalPlan, Operator, - PartitionEvaluator, ScalarUDF, ScalarUDFImpl, Signature, TryCast, Volatility, - WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, WindowUDF, - WindowUDFImpl, + Accumulator, AggregateExt, AggregateFunction, ColumnarValue, ExprSchemable, + LogicalPlan, Operator, PartitionEvaluator, ScalarUDF, ScalarUDFImpl, Signature, + TryCast, Volatility, WindowFrame, WindowFrameBound, WindowFrameUnits, + WindowFunctionDefinition, WindowUDF, WindowUDFImpl, }; use datafusion_proto::bytes::{ logical_plan_from_bytes, logical_plan_from_bytes_with_extension_codec, @@ -662,6 +663,8 @@ async fn roundtrip_expr_api() -> Result<()> { stddev(lit(2.2)), stddev_pop(lit(2.2)), approx_median(lit(2)), + approx_percentile_cont(lit(2), lit(0.5)), + approx_percentile_cont_with_weight(lit(2), lit(1), lit(0.5)), ]; // ensure expressions created with the expr api can be round tripped @@ -1782,43 +1785,18 @@ fn roundtrip_similar_to() { #[test] fn roundtrip_count() { - let test_expr = Expr::AggregateFunction(expr::AggregateFunction::new( - AggregateFunction::Count, - vec![col("bananas")], - false, - None, - None, - None, - )); + let test_expr = count(col("bananas")); let ctx = SessionContext::new(); roundtrip_expr_test(test_expr, ctx); } #[test] fn roundtrip_count_distinct() { - let test_expr = Expr::AggregateFunction(expr::AggregateFunction::new( - AggregateFunction::Count, - vec![col("bananas")], - true, - None, - None, - None, - )); - let ctx = SessionContext::new(); - roundtrip_expr_test(test_expr, ctx); -} - -#[test] -fn roundtrip_approx_percentile_cont() { - let test_expr = Expr::AggregateFunction(expr::AggregateFunction::new( - AggregateFunction::ApproxPercentileCont, - vec![col("bananas"), lit(0.42_f32)], - false, - None, - None, - None, - )); - + let test_expr = count_udaf() + .call(vec![col("bananas")]) + .distinct() + .build() + .unwrap(); let ctx = SessionContext::new(); roundtrip_expr_test(test_expr, ctx); } diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index e517482f1db0..7f66cdbf7663 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -303,6 +303,7 @@ fn roundtrip_window() -> Result<()> { &args, &[], &[], + &[], &schema, "SUM(a) RANGE BETWEEN CURRENT ROW AND UNBOUNDED PRECEEDING", false, @@ -458,6 +459,7 @@ fn roundtrip_aggregate_udaf() -> Result<()> { &[col("b", &schema)?], &[], &[], + &[], &schema, "example_agg", false, diff --git a/datafusion/sql/examples/sql.rs 
b/datafusion/sql/examples/sql.rs
index 893db018c8af..aee4cf5a38ed 100644
--- a/datafusion/sql/examples/sql.rs
+++ b/datafusion/sql/examples/sql.rs
@@ -18,11 +18,12 @@
 use arrow_schema::{DataType, Field, Schema};
 use datafusion_common::config::ConfigOptions;
 use datafusion_common::{plan_err, Result};
-use datafusion_expr::test::function_stub::sum_udaf;
 use datafusion_expr::WindowUDF;
 use datafusion_expr::{
     logical_plan::builder::LogicalTableSource, AggregateUDF, ScalarUDF, TableSource,
 };
+use datafusion_functions_aggregate::count::count_udaf;
+use datafusion_functions_aggregate::sum::sum_udaf;
 use datafusion_sql::{
     planner::{ContextProvider, SqlToRel},
     sqlparser::{dialect::GenericDialect, parser::Parser},
@@ -50,7 +51,9 @@ fn main() {
     let statement = &ast[0];
 
     // create a logical query plan
-    let context_provider = MyContextProvider::new().with_udaf(sum_udaf());
+    let context_provider = MyContextProvider::new()
+        .with_udaf(sum_udaf())
+        .with_udaf(count_udaf());
     let sql_to_rel = SqlToRel::new(&context_provider);
     let plan = sql_to_rel.sql_statement_to_plan(statement.clone()).unwrap();
 
@@ -66,7 +69,8 @@ struct MyContextProvider {
 
 impl MyContextProvider {
     fn with_udaf(mut self, udaf: Arc<AggregateUDF>) -> Self {
-        self.udafs.insert(udaf.name().to_string(), udaf);
+        // TODO: change back to `to_string()` once all function names are converted to lowercase
+        self.udafs.insert(udaf.name().to_lowercase(), udaf);
         self
     }
diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs
index 0f04281aa23b..a92e64597e82 100644
--- a/datafusion/sql/src/planner.rs
+++ b/datafusion/sql/src/planner.rs
@@ -439,6 +439,25 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
         }
         SQLDataType::Bytea => Ok(DataType::Binary),
         SQLDataType::Interval => Ok(DataType::Interval(IntervalUnit::MonthDayNano)),
+        SQLDataType::Struct(fields) => {
+            let fields = fields
+                .iter()
+                .enumerate()
+                .map(|(idx, field)| {
+                    let data_type = self.convert_data_type(&field.field_type)?;
+                    let field_name = match &field.field_name {
+                        Some(ident) => ident.clone(),
+                        None => Ident::new(format!("c{idx}")),
+                    };
+                    Ok(Arc::new(Field::new(
+                        self.normalizer.normalize(field_name),
+                        data_type,
+                        true,
+                    )))
+                })
+                .collect::<Result<Vec<_>>>()?;
+            Ok(DataType::Struct(Fields::from(fields)))
+        }
         // Explicitly list all other types so that if sqlparser
         // adds/changes the `SQLDataType` the compiler will tell us on upgrade
         // and avoid bugs like https://github.com/apache/datafusion/issues/3059
@@ -472,7 +491,6 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
         | SQLDataType::Bytes(_)
         | SQLDataType::Int64
         | SQLDataType::Float64
-        | SQLDataType::Struct(_)
         | SQLDataType::JSONB
         | SQLDataType::Unspecified => not_impl_err!(
diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs
index dc25a6c33ece..12c48054f1a7 100644
--- a/datafusion/sql/src/unparser/expr.rs
+++ b/datafusion/sql/src/unparser/expr.rs
@@ -960,13 +960,14 @@ mod tests {
     use arrow_schema::DataType::Int8;
     use datafusion_common::TableReference;
+    use datafusion_expr::AggregateExt;
     use datafusion_expr::{
-        case, col, cube, exists,
-        expr::{AggregateFunction, AggregateFunctionDefinition},
-        grouping_set, lit, not, not_exists, out_ref_col, placeholder, rollup, table_scan,
-        try_cast, when, wildcard, ColumnarValue, ScalarUDF, ScalarUDFImpl, Signature,
-        Volatility, WindowFrame, WindowFunctionDefinition,
+        case, col, cube, exists, grouping_set, lit, not, not_exists, out_ref_col,
+        placeholder, rollup, table_scan, try_cast, when, wildcard, ColumnarValue,
+        ScalarUDF, ScalarUDFImpl, Signature,
Volatility, WindowFrame, + WindowFunctionDefinition, }; + use datafusion_functions_aggregate::count::count_udaf; use datafusion_functions_aggregate::expr_fn::sum; use crate::unparser::dialect::CustomDialect; @@ -1127,29 +1128,19 @@ mod tests { ), (sum(col("a")), r#"sum(a)"#), ( - Expr::AggregateFunction(AggregateFunction { - func_def: AggregateFunctionDefinition::BuiltIn( - datafusion_expr::AggregateFunction::Count, - ), - args: vec![Expr::Wildcard { qualifier: None }], - distinct: true, - filter: None, - order_by: None, - null_treatment: None, - }), + count_udaf() + .call(vec![Expr::Wildcard { qualifier: None }]) + .distinct() + .build() + .unwrap(), "COUNT(DISTINCT *)", ), ( - Expr::AggregateFunction(AggregateFunction { - func_def: AggregateFunctionDefinition::BuiltIn( - datafusion_expr::AggregateFunction::Count, - ), - args: vec![Expr::Wildcard { qualifier: None }], - distinct: false, - filter: Some(Box::new(lit(true))), - order_by: None, - null_treatment: None, - }), + count_udaf() + .call(vec![Expr::Wildcard { qualifier: None }]) + .filter(lit(true)) + .build() + .unwrap(), "COUNT(*) FILTER (WHERE true)", ), ( @@ -1167,9 +1158,7 @@ mod tests { ), ( Expr::WindowFunction(WindowFunction { - fun: WindowFunctionDefinition::AggregateFunction( - datafusion_expr::AggregateFunction::Count, - ), + fun: WindowFunctionDefinition::AggregateUDF(count_udaf()), args: vec![wildcard()], partition_by: vec![], order_by: vec![Expr::Sort(Sort::new( diff --git a/datafusion/sql/src/utils.rs b/datafusion/sql/src/utils.rs index 51bacb5f702b..bc27d25cf216 100644 --- a/datafusion/sql/src/utils.rs +++ b/datafusion/sql/src/utils.rs @@ -350,7 +350,8 @@ mod tests { use arrow::datatypes::{DataType as ArrowDataType, Field, Schema}; use arrow_schema::Fields; use datafusion_common::{DFSchema, Result}; - use datafusion_expr::{col, count, lit, unnest, EmptyRelation, LogicalPlan}; + use datafusion_expr::{col, lit, unnest, EmptyRelation, LogicalPlan}; + use datafusion_functions_aggregate::expr_fn::count; use crate::utils::{recursive_transform_unnest, resolve_positions_to_exprs}; diff --git a/datafusion/sql/tests/cases/plan_to_sql.rs b/datafusion/sql/tests/cases/plan_to_sql.rs index 72018371a5f1..33e28e7056b9 100644 --- a/datafusion/sql/tests/cases/plan_to_sql.rs +++ b/datafusion/sql/tests/cases/plan_to_sql.rs @@ -19,7 +19,7 @@ use std::vec; use arrow_schema::*; use datafusion_common::{DFSchema, Result, TableReference}; -use datafusion_expr::test::function_stub::sum_udaf; +use datafusion_expr::test::function_stub::{count_udaf, sum_udaf}; use datafusion_expr::{col, table_scan}; use datafusion_sql::planner::{ContextProvider, PlannerContext, SqlToRel}; use datafusion_sql::unparser::dialect::{ @@ -153,7 +153,9 @@ fn roundtrip_statement() -> Result<()> { .try_with_sql(query)? 
.parse_statement()?;
 
-    let context = MockContextProvider::default().with_udaf(sum_udaf());
+    let context = MockContextProvider::default()
+        .with_udaf(sum_udaf())
+        .with_udaf(count_udaf());
     let sql_to_rel = SqlToRel::new(&context);
     let plan = sql_to_rel.sql_statement_to_plan(statement).unwrap();
 
diff --git a/datafusion/sql/tests/common/mod.rs b/datafusion/sql/tests/common/mod.rs
index d91c09ae1287..893678d6b374 100644
--- a/datafusion/sql/tests/common/mod.rs
+++ b/datafusion/sql/tests/common/mod.rs
@@ -46,7 +46,8 @@ impl MockContextProvider {
     }
 
     pub(crate) fn with_udaf(mut self, udaf: Arc<AggregateUDF>) -> Self {
-        self.udafs.insert(udaf.name().to_string(), udaf);
+        // TODO: change back to `to_string()` once all function names are converted to lowercase
+        self.udafs.insert(udaf.name().to_lowercase(), udaf);
         self
     }
 }
diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs
index 7b9d39a2b51e..8eb2a2b609e7 100644
--- a/datafusion/sql/tests/sql_integration.rs
+++ b/datafusion/sql/tests/sql_integration.rs
@@ -37,7 +37,9 @@ use datafusion_sql::{
     planner::{ParserOptions, SqlToRel},
 };
 
-use datafusion_functions_aggregate::approx_median::approx_median_udaf;
+use datafusion_functions_aggregate::{
+    approx_median::approx_median_udaf, count::count_udaf,
+};
 use rstest::rstest;
 use sqlparser::dialect::{Dialect, GenericDialect, HiveDialect, MySqlDialect};
@@ -2702,7 +2704,8 @@ fn logical_plan_with_dialect_and_options(
     ))
     .with_udf(make_udf("sqrt", vec![DataType::Int64], DataType::Int64))
     .with_udaf(sum_udaf())
-    .with_udaf(approx_median_udaf());
+    .with_udaf(approx_median_udaf())
+    .with_udaf(count_udaf());
     let planner = SqlToRel::new_with_options(&context, options);
     let result = DFParser::parse_sql_with_dialect(sql, dialect);
diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt
index 7ba1893bb11a..0a6def3d6f27 100644
--- a/datafusion/sqllogictest/test_files/aggregate.slt
+++ b/datafusion/sqllogictest/test_files/aggregate.slt
@@ -76,26 +76,26 @@ statement error DataFusion error: Schema error: Schema contains duplicate unqual
 SELECT approx_distinct(c9) count_c9, approx_distinct(cast(c9 as varchar)) count_c9_str FROM aggregate_test_100
 
 # csv_query_approx_percentile_cont_with_weight
-statement error DataFusion error: Error during planning: No function matches the given name and argument types 'APPROX_PERCENTILE_CONT_WITH_WEIGHT\(Utf8, Int8, Float64\)'. You might need to add explicit type casts.
+statement error DataFusion error: Error during planning: Error during planning: Coercion from \[Utf8, Int8, Float64\] to the signature OneOf(.*) failed(.|\n)*
 SELECT approx_percentile_cont_with_weight(c1, c2, 0.95) FROM aggregate_test_100
 
-statement error DataFusion error: Error during planning: No function matches the given name and argument types 'APPROX_PERCENTILE_CONT_WITH_WEIGHT\(Int16, Utf8, Float64\)'\. You might need to add explicit type casts\.
+statement error DataFusion error: Error during planning: Error during planning: Coercion from \[Int16, Utf8, Float64\] to the signature OneOf(.*) failed(.|\n)*
 SELECT approx_percentile_cont_with_weight(c3, c1, 0.95) FROM aggregate_test_100
 
-statement error DataFusion error: Error during planning: No function matches the given name and argument types 'APPROX_PERCENTILE_CONT_WITH_WEIGHT\(Int16, Int8, Utf8\)'\. You might need to add explicit type casts\.
+statement error DataFusion error: Error during planning: Error during planning: Coercion from \[Int16, Int8, Utf8\] to the signature OneOf(.*) failed(.|\n)*
 SELECT approx_percentile_cont_with_weight(c3, c2, c1) FROM aggregate_test_100
 
 # csv_query_approx_percentile_cont_with_histogram_bins
-statement error This feature is not implemented: Tdigest max_size value for 'APPROX_PERCENTILE_CONT' must be UInt > 0 literal \(got data type Int64\).
+statement error DataFusion error: External error: This feature is not implemented: Tdigest max_size value for 'APPROX_PERCENTILE_CONT' must be UInt > 0 literal \(got data type Int64\)\.
 SELECT c1, approx_percentile_cont(c3, 0.95, -1000) AS c3_p95 FROM aggregate_test_100 GROUP BY 1 ORDER BY 1
 
-statement error DataFusion error: Error during planning: No function matches the given name and argument types 'APPROX_PERCENTILE_CONT\(Int16, Float64, Utf8\)'\. You might need to add explicit type casts\.
+statement error DataFusion error: Error during planning: Error during planning: Coercion from \[Int16, Float64, Utf8\] to the signature OneOf(.*) failed(.|\n)*
 SELECT approx_percentile_cont(c3, 0.95, c1) FROM aggregate_test_100
 
-statement error DataFusion error: Error during planning: No function matches the given name and argument types 'APPROX_PERCENTILE_CONT\(Int16, Float64, Float64\)'\. You might need to add explicit type casts\.
+statement error DataFusion error: Error during planning: Error during planning: Coercion from \[Int16, Float64, Float64\] to the signature OneOf(.*) failed(.|\n)*
 SELECT approx_percentile_cont(c3, 0.95, 111.1) FROM aggregate_test_100
 
-statement error DataFusion error: Error during planning: No function matches the given name and argument types 'APPROX_PERCENTILE_CONT\(Float64, Float64, Float64\)'\. You might need to add explicit type casts\.
+statement error DataFusion error: Error during planning: Error during planning: Coercion from \[Float64, Float64, Float64\] to the signature OneOf(.*) failed(.|\n)*
 SELECT approx_percentile_cont(c12, 0.95, 111.1) FROM aggregate_test_100
 
 # array agg can use order by
diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt
index 1612adc643d9..77d1a9da1f55 100644
--- a/datafusion/sqllogictest/test_files/array.slt
+++ b/datafusion/sqllogictest/test_files/array.slt
@@ -1137,7 +1137,7 @@ from arrays_values_without_nulls;
 ## array_element (aliases: array_extract, list_extract, list_element)
 
 # Testing with empty arguments should result in an error
-query error DataFusion error: Error during planning: Error during planning: \[data_types_with_scalar_udf\] signature ArraySignature\(ArrayAndIndex\) does not support zero arguments.
+query error DataFusion error: Error during planning: Error during planning: array_element does not support zero arguments.
 select array_element();
 
 # array_element error
@@ -1979,7 +1979,7 @@ select array_slice(a, -1, 2, 1), array_slice(a, -1, 2),
 [6.0] [6.0] [] []
 
 # Testing with empty arguments should result in an error
-query error DataFusion error: Error during planning: Error during planning: \[data_types_with_scalar_udf\] signature VariadicAny does not support zero arguments.
+query error DataFusion error: Error during planning: Error during planning: array_slice does not support zero arguments.
 select array_slice();
diff --git a/datafusion/sqllogictest/test_files/errors.slt b/datafusion/sqllogictest/test_files/errors.slt
index c7b9808c249d..d51c69496d46 100644
--- a/datafusion/sqllogictest/test_files/errors.slt
+++ b/datafusion/sqllogictest/test_files/errors.slt
@@ -112,11 +112,11 @@ statement error DataFusion error: Error during planning: No function matches the
 select avg(c1, c12) from aggregate_test_100;
 
 # AggregateFunction with wrong argument type
-statement error DataFusion error: Error during planning: No function matches the given name and argument types 'REGR_SLOPE\(Int64, Utf8\)'\. You might need to add explicit type casts\.\n\tCandidate functions:\n\tREGR_SLOPE\(Int8/Int16/Int32/Int64/UInt8/UInt16/UInt32/UInt64/Float32/Float64, Int8/Int16/Int32/Int64/UInt8/UInt16/UInt32/UInt64/Float32/Float64\)
+statement error Coercion
 select regr_slope(1, '2');
 
 # WindowFunction using AggregateFunction wrong signature
-statement error DataFusion error: Error during planning: No function matches the given name and argument types 'REGR_SLOPE\(Float32, Utf8\)'\. You might need to add explicit type casts\.\n\tCandidate functions:\n\tREGR_SLOPE\(Int8/Int16/Int32/Int64/UInt8/UInt16/UInt32/UInt64/Float32/Float64, Int8/Int16/Int32/Int64/UInt8/UInt16/UInt32/UInt64/Float32/Float64\)
+statement error Coercion
 select
 c9,
 regr_slope(c11, '2') over () as min1
diff --git a/datafusion/sqllogictest/test_files/functions.slt b/datafusion/sqllogictest/test_files/functions.slt
index f04d76822124..c3dd791f6ca8 100644
--- a/datafusion/sqllogictest/test_files/functions.slt
+++ b/datafusion/sqllogictest/test_files/functions.slt
@@ -487,7 +487,7 @@ statement error Did you mean 'to_timestamp_seconds'?
 SELECT to_TIMESTAMPS_second(v2) from test;
 
 # Aggregate function
-statement error Did you mean 'COUNT'?
+query error DataFusion error: Error during planning: Invalid function 'counter'
 SELECT counter(*) from test;
 
 # Aggregate function
@@ -1158,3 +1158,21 @@ drop table uuid_table
 
 statement ok
 drop table t
+
+
+# test for contains
+
+query B
+select contains('alphabet', 'pha');
+----
+true
+
+query B
+select contains('alphabet', 'dddd');
+----
+false
+
+query B
+select contains('', '');
+----
+true
diff --git a/datafusion/sqllogictest/test_files/struct.slt b/datafusion/sqllogictest/test_files/struct.slt
index 46a08709c3a3..749daa7e20e7 100644
--- a/datafusion/sqllogictest/test_files/struct.slt
+++ b/datafusion/sqllogictest/test_files/struct.slt
@@ -31,6 +31,33 @@ CREATE TABLE values(
     (3, 3.3, 'c', NULL)
 ;
 
+
+# named and unnamed struct fields
+statement ok
+CREATE TABLE struct_values (
+    s1 struct<int>,
+    s2 struct<a int, b varchar>
+) AS VALUES
+    (struct(1), struct(1, 'string1')),
+    (struct(2), struct(2, 'string2')),
+    (struct(3), struct(3, 'string3'))
+;
+
+query ??
+select * from struct_values; +---- +{c0: 1} {a: 1, b: string1} +{c0: 2} {a: 2, b: string2} +{c0: 3} {a: 3, b: string3} + +query TT +select arrow_typeof(s1), arrow_typeof(s2) from struct_values; +---- +Struct([Field { name: "c0", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]) Struct([Field { name: "a", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "b", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]) +Struct([Field { name: "c0", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]) Struct([Field { name: "a", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "b", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]) +Struct([Field { name: "c0", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]) Struct([Field { name: "a", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "b", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]) + + # struct[i] query IRT select struct(1, 3.14, 'h')['c0'], struct(3, 2.55, 'b')['c1'], struct(2, 6.43, 'a')['c2']; diff --git a/datafusion/substrait/tests/cases/function_test.rs b/datafusion/substrait/tests/cases/function_test.rs new file mode 100644 index 000000000000..b4c5659a3a49 --- /dev/null +++ b/datafusion/substrait/tests/cases/function_test.rs @@ -0,0 +1,58 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! 
Tests for Function Compatibility + +#[cfg(test)] +mod tests { + use datafusion::common::Result; + use datafusion::prelude::{CsvReadOptions, SessionContext}; + use datafusion_substrait::logical_plan::consumer::from_substrait_plan; + use std::fs::File; + use std::io::BufReader; + use substrait::proto::Plan; + + #[tokio::test] + async fn contains_function_test() -> Result<()> { + let ctx = create_context().await?; + + let path = "tests/testdata/contains_plan.substrait.json"; + let proto = serde_json::from_reader::<_, Plan>(BufReader::new( + File::open(path).expect("file not found"), + )) + .expect("failed to parse json"); + + let plan = from_substrait_plan(&ctx, &proto).await?; + + let plan_str = format!("{:?}", plan); + + assert_eq!( + plan_str, + "Projection: nation.b AS n_name\ + \n Filter: contains(nation.b, Utf8(\"IA\"))\ + \n TableScan: nation projection=[a, b, c, d, e, f]" + ); + Ok(()) + } + + async fn create_context() -> datafusion::common::Result { + let ctx = SessionContext::new(); + ctx.register_csv("nation", "tests/testdata/data.csv", CsvReadOptions::new()) + .await?; + Ok(ctx) + } +} diff --git a/datafusion/substrait/tests/cases/mod.rs b/datafusion/substrait/tests/cases/mod.rs index a31f93087d83..d3ea7695e4b9 100644 --- a/datafusion/substrait/tests/cases/mod.rs +++ b/datafusion/substrait/tests/cases/mod.rs @@ -16,6 +16,7 @@ // under the License. mod consumer_integration; +mod function_test; mod logical_plans; mod roundtrip_logical_plan; mod roundtrip_physical_plan; diff --git a/datafusion/substrait/tests/testdata/contains_plan.substrait.json b/datafusion/substrait/tests/testdata/contains_plan.substrait.json new file mode 100644 index 000000000000..76edde34e3b0 --- /dev/null +++ b/datafusion/substrait/tests/testdata/contains_plan.substrait.json @@ -0,0 +1,133 @@ +{ + "extensionUris": [ + { + "extensionUriAnchor": 1, + "uri": "https://github.com/substrait-io/substrait/blob/main/extensions/functions_string.yaml" + } + ], + "extensions": [ + { + "extensionFunction": { + "extensionUriReference": 1, + "functionAnchor": 1, + "name": "contains:str_str" + } + } + ], + "relations": [ + { + "root": { + "input": { + "project": { + "common": { + "emit": { + "outputMapping": [ + 4 + ] + } + }, + "input": { + "filter": { + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "n_nationkey", + "n_name", + "n_regionkey", + "n_comment" + ], + "struct": { + "types": [ + { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + } + ], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": [ + "nation" + ] + } + } + }, + "condition": { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "bool": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": {} + } + } + }, + { + "value": { + "literal": { + "string": "IA" + } + } + } + ] + } + } + } + }, + "expressions": [ + { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": {} + } + } + ] + } + }, + "names": [ + "n_name" + ] + } + } + ], + "version": { + "minorNumber": 38, + "producer": "ibis-substrait" + } +} \ No newline at end of file diff --git a/docs/source/user-guide/cli/usage.md 
b/docs/source/user-guide/cli/usage.md
index 617b462875c7..6a620fc69252 100644
--- a/docs/source/user-guide/cli/usage.md
+++ b/docs/source/user-guide/cli/usage.md
@@ -52,7 +52,7 @@ OPTIONS:
     --maxrows <MAXROWS>
         The max number of rows to display for 'Table' format
-        [default: 40] [possible values: numbers(0/10/...), inf(no limit)]
+        [possible values: numbers(0/10/...), inf(no limit)] [default: 40]
 
     --mem-pool-type <MEM_POOL_TYPE>
         Specify the memory pool type 'greedy' or 'fair', default to 'greedy'
diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md
index 10c52bc5de9e..ec34dbf9ba6c 100644
--- a/docs/source/user-guide/sql/scalar_functions.md
+++ b/docs/source/user-guide/sql/scalar_functions.md
@@ -681,6 +681,7 @@ _Alias of [nvl](#nvl)._
 - [substr_index](#substr_index)
 - [find_in_set](#find_in_set)
 - [position](#position)
+- [contains](#contains)
 
 ### `ascii`
@@ -1443,6 +1444,19 @@ position(substr in origstr)
 - **substr**: The pattern string.
 - **origstr**: The model string.
 
+### `contains`
+
+Return true if `search_string` is found within `string`.
+
+```
+contains(string, search_string)
+```
+
+#### Arguments
+
+- **string**: The string to search in.
+- **search_string**: The string to search for.
+
 ## Time and Date Functions
 
 - [now](#now)
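As a quick illustration of the `contains` semantics documented above, here is a minimal SQL sketch; the expected results are taken directly from the sqllogictest cases this patch adds to `functions.slt`:

```sql
-- Results mirror the new functions.slt cases for contains()
SELECT contains('alphabet', 'pha');  -- true
SELECT contains('alphabet', 'dddd'); -- false
SELECT contains('', '');             -- true (every string contains the empty string)
```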