From c8ef54537c9e180eac10e13c1327f81e5e005394 Mon Sep 17 00:00:00 2001 From: Jonah Gao Date: Wed, 24 Jul 2024 09:47:25 +0800 Subject: [PATCH 01/17] fix: panic and incorrect results in `LogFunc::output_ordering()` (#11571) * fix: panic and incorrect results in `LogFunc::output_ordering()` * fix for nulls_first --- datafusion/functions/src/math/log.rs | 123 ++++++++++++++++++- datafusion/sqllogictest/test_files/order.slt | 31 +++-- 2 files changed, 138 insertions(+), 16 deletions(-) diff --git a/datafusion/functions/src/math/log.rs b/datafusion/functions/src/math/log.rs index ea424c14749e..0e181aa61250 100644 --- a/datafusion/functions/src/math/log.rs +++ b/datafusion/functions/src/math/log.rs @@ -82,10 +82,16 @@ impl ScalarUDFImpl for LogFunc { } fn output_ordering(&self, input: &[ExprProperties]) -> Result { - match (input[0].sort_properties, input[1].sort_properties) { - (first @ SortProperties::Ordered(value), SortProperties::Ordered(base)) - if !value.descending && base.descending - || value.descending && !base.descending => + let (base_sort_properties, num_sort_properties) = if input.len() == 1 { + // log(x) defaults to log(10, x) + (SortProperties::Singleton, input[0].sort_properties) + } else { + (input[0].sort_properties, input[1].sort_properties) + }; + match (num_sort_properties, base_sort_properties) { + (first @ SortProperties::Ordered(num), SortProperties::Ordered(base)) + if num.descending != base.descending + && num.nulls_first == base.nulls_first => { Ok(first) } @@ -230,6 +236,7 @@ mod tests { use super::*; + use arrow::compute::SortOptions; use datafusion_common::cast::{as_float32_array, as_float64_array}; use datafusion_common::DFSchema; use datafusion_expr::execution_props::ExecutionProps; @@ -334,4 +341,112 @@ mod tests { assert_eq!(args[0], lit(2)); assert_eq!(args[1], lit(3)); } + + #[test] + fn test_log_output_ordering() { + // [Unordered, Ascending, Descending, Literal] + let orders = vec![ + ExprProperties::new_unknown(), + ExprProperties::new_unknown().with_order(SortProperties::Ordered( + SortOptions { + descending: false, + nulls_first: true, + }, + )), + ExprProperties::new_unknown().with_order(SortProperties::Ordered( + SortOptions { + descending: true, + nulls_first: true, + }, + )), + ExprProperties::new_unknown().with_order(SortProperties::Singleton), + ]; + + let log = LogFunc::new(); + + // Test log(num) + for order in orders.iter().cloned() { + let result = log.output_ordering(&[order.clone()]).unwrap(); + assert_eq!(result, order.sort_properties); + } + + // Test log(base, num), where `nulls_first` is the same + let mut results = Vec::with_capacity(orders.len() * orders.len()); + for base_order in orders.iter() { + for num_order in orders.iter().cloned() { + let result = log + .output_ordering(&[base_order.clone(), num_order]) + .unwrap(); + results.push(result); + } + } + let expected = vec![ + // base: Unordered + SortProperties::Unordered, + SortProperties::Unordered, + SortProperties::Unordered, + SortProperties::Unordered, + // base: Ascending, num: Unordered + SortProperties::Unordered, + // base: Ascending, num: Ascending + SortProperties::Unordered, + // base: Ascending, num: Descending + SortProperties::Ordered(SortOptions { + descending: true, + nulls_first: true, + }), + // base: Ascending, num: Literal + SortProperties::Ordered(SortOptions { + descending: true, + nulls_first: true, + }), + // base: Descending, num: Unordered + SortProperties::Unordered, + // base: Descending, num: Ascending + SortProperties::Ordered(SortOptions { + 
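            // (reviewer note: with base DESC and num ASC, output_ordering
            // returns `first` -- num's ascending properties -- because the two
            // directions differ while their nulls_first values agree)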
descending: false, + nulls_first: true, + }), + // base: Descending, num: Descending + SortProperties::Unordered, + // base: Descending, num: Literal + SortProperties::Ordered(SortOptions { + descending: false, + nulls_first: true, + }), + // base: Literal, num: Unordered + SortProperties::Unordered, + // base: Literal, num: Ascending + SortProperties::Ordered(SortOptions { + descending: false, + nulls_first: true, + }), + // base: Literal, num: Descending + SortProperties::Ordered(SortOptions { + descending: true, + nulls_first: true, + }), + // base: Literal, num: Literal + SortProperties::Singleton, + ]; + assert_eq!(results, expected); + + // Test with different `nulls_first` + let base_order = ExprProperties::new_unknown().with_order( + SortProperties::Ordered(SortOptions { + descending: true, + nulls_first: true, + }), + ); + let num_order = ExprProperties::new_unknown().with_order( + SortProperties::Ordered(SortOptions { + descending: false, + nulls_first: false, + }), + ); + assert_eq!( + log.output_ordering(&[base_order, num_order]).unwrap(), + SortProperties::Unordered + ); + } } diff --git a/datafusion/sqllogictest/test_files/order.slt b/datafusion/sqllogictest/test_files/order.slt index 1aeaf9b76d48..d0a6d6adc107 100644 --- a/datafusion/sqllogictest/test_files/order.slt +++ b/datafusion/sqllogictest/test_files/order.slt @@ -326,6 +326,13 @@ select column1 + column2 from foo group by column1, column2 ORDER BY column2 des 7 3 +# Test issue: https://github.com/apache/datafusion/issues/11549 +query I +select column1 from foo order by log(column2); +---- +1 +3 +5 # Cleanup statement ok @@ -512,7 +519,7 @@ CREATE EXTERNAL TABLE aggregate_test_100 ( ) STORED AS CSV WITH ORDER(c11) -WITH ORDER(c12 DESC) +WITH ORDER(c12 DESC NULLS LAST) LOCATION '../../testing/data/csv/aggregate_test_100.csv' OPTIONS ('format.has_header' 'true'); @@ -547,34 +554,34 @@ physical_plan 04)------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c11], output_ordering=[c11@0 ASC NULLS LAST], has_header=true query TT - EXPLAIN SELECT LOG(c11, c12) as log_c11_base_c12 + EXPLAIN SELECT LOG(c12, c11) as log_c11_base_c12 FROM aggregate_test_100 ORDER BY log_c11_base_c12; ---- logical_plan 01)Sort: log_c11_base_c12 ASC NULLS LAST -02)--Projection: log(CAST(aggregate_test_100.c11 AS Float64), aggregate_test_100.c12) AS log_c11_base_c12 +02)--Projection: log(aggregate_test_100.c12, CAST(aggregate_test_100.c11 AS Float64)) AS log_c11_base_c12 03)----TableScan: aggregate_test_100 projection=[c11, c12] physical_plan 01)SortPreservingMergeExec: [log_c11_base_c12@0 ASC NULLS LAST] -02)--ProjectionExec: expr=[log(CAST(c11@0 AS Float64), c12@1) as log_c11_base_c12] +02)--ProjectionExec: expr=[log(c12@1, CAST(c11@0 AS Float64)) as log_c11_base_c12] 03)----RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -04)------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c11, c12], output_orderings=[[c11@0 ASC NULLS LAST], [c12@1 DESC]], has_header=true +04)------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c11, c12], output_orderings=[[c11@0 ASC NULLS LAST], [c12@1 DESC NULLS LAST]], has_header=true query TT -EXPLAIN SELECT LOG(c12, c11) as log_c12_base_c11 +EXPLAIN SELECT LOG(c11, c12) as log_c12_base_c11 FROM aggregate_test_100 -ORDER BY log_c12_base_c11 DESC; +ORDER BY log_c12_base_c11 DESC NULLS LAST; ---- logical_plan -01)Sort: log_c12_base_c11 
DESC NULLS FIRST -02)--Projection: log(aggregate_test_100.c12, CAST(aggregate_test_100.c11 AS Float64)) AS log_c12_base_c11 +01)Sort: log_c12_base_c11 DESC NULLS LAST +02)--Projection: log(CAST(aggregate_test_100.c11 AS Float64), aggregate_test_100.c12) AS log_c12_base_c11 03)----TableScan: aggregate_test_100 projection=[c11, c12] physical_plan -01)SortPreservingMergeExec: [log_c12_base_c11@0 DESC] -02)--ProjectionExec: expr=[log(c12@1, CAST(c11@0 AS Float64)) as log_c12_base_c11] +01)SortPreservingMergeExec: [log_c12_base_c11@0 DESC NULLS LAST] +02)--ProjectionExec: expr=[log(CAST(c11@0 AS Float64), c12@1) as log_c12_base_c11] 03)----RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -04)------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c11, c12], output_orderings=[[c11@0 ASC NULLS LAST], [c12@1 DESC]], has_header=true +04)------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c11, c12], output_orderings=[[c11@0 ASC NULLS LAST], [c12@1 DESC NULLS LAST]], has_header=true statement ok drop table aggregate_test_100; From 72c6491d25fe253b8757028be77e1e6f5cd74c71 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Wed, 24 Jul 2024 08:36:49 -0400 Subject: [PATCH 02/17] Minor: Use upstream concat_batches (#11615) --- .../physical-plan/src/coalesce_batches.rs | 38 ++++--------------- .../physical-plan/src/joins/cross_join.rs | 9 ++--- .../src/joins/nested_loop_join.rs | 6 +-- 3 files changed, 14 insertions(+), 39 deletions(-) diff --git a/datafusion/physical-plan/src/coalesce_batches.rs b/datafusion/physical-plan/src/coalesce_batches.rs index b9bdfcdee712..8cb25827ff8f 100644 --- a/datafusion/physical-plan/src/coalesce_batches.rs +++ b/datafusion/physical-plan/src/coalesce_batches.rs @@ -18,25 +18,23 @@ //! CoalesceBatchesExec combines small batches into larger batches for more efficient use of //! vectorized processing by upstream operators. -use std::any::Any; -use std::pin::Pin; -use std::sync::Arc; -use std::task::{Context, Poll}; - use super::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}; use super::{DisplayAs, ExecutionPlanProperties, PlanProperties, Statistics}; use crate::{ DisplayFormatType, ExecutionPlan, RecordBatchStream, SendableRecordBatchStream, }; +use arrow::compute::concat_batches; +use std::any::Any; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; use arrow::datatypes::SchemaRef; -use arrow::error::Result as ArrowResult; use arrow::record_batch::RecordBatch; use datafusion_common::Result; use datafusion_execution::TaskContext; use futures::stream::{Stream, StreamExt}; -use log::trace; /// CoalesceBatchesExec combines small batches into larger batches for more efficient use of /// vectorized processing by upstream operators. 
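[Reviewer note] The hunks below drop the local `concat_batches` wrapper, which
only added a trace log and an otherwise-unused `row_count` argument on top of
the arrow kernel. A minimal sketch of the upstream call this patch switches to
(the single-column schema and the function name here are illustrative, not part
of the patch):

    use std::sync::Arc;
    use arrow::array::Int32Array;
    use arrow::compute::concat_batches;
    use arrow::datatypes::{DataType, Field, Schema};
    use arrow::record_batch::RecordBatch;

    fn merge_two_batches() -> arrow::error::Result<RecordBatch> {
        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
        let b1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(Int32Array::from(vec![1, 2]))])?;
        let b2 = RecordBatch::try_new(schema.clone(), vec![Arc::new(Int32Array::from(vec![3]))])?;
        // The row count is derived from the batches themselves; no separate
        // `row_count` argument is needed, which is what made the wrapper redundant.
        concat_batches(&schema, &[b1, b2])
    }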
@@ -229,11 +227,7 @@ impl CoalesceBatchesStream { // check to see if we have enough batches yet if self.buffered_rows >= self.target_batch_size { // combine the batches and return - let batch = concat_batches( - &self.schema, - &self.buffer, - self.buffered_rows, - )?; + let batch = concat_batches(&self.schema, &self.buffer)?; // reset buffer state self.buffer.clear(); self.buffered_rows = 0; @@ -250,11 +244,7 @@ impl CoalesceBatchesStream { return Poll::Ready(None); } else { // combine the batches and return - let batch = concat_batches( - &self.schema, - &self.buffer, - self.buffered_rows, - )?; + let batch = concat_batches(&self.schema, &self.buffer)?; // reset buffer state self.buffer.clear(); self.buffered_rows = 0; @@ -276,20 +266,6 @@ impl RecordBatchStream for CoalesceBatchesStream { } } -/// Concatenates an array of `RecordBatch` into one batch -pub fn concat_batches( - schema: &SchemaRef, - batches: &[RecordBatch], - row_count: usize, -) -> ArrowResult { - trace!( - "Combined {} batches containing {} rows", - batches.len(), - row_count - ); - arrow::compute::concat_batches(schema, batches) -} - #[cfg(test)] mod tests { use super::*; diff --git a/datafusion/physical-plan/src/joins/cross_join.rs b/datafusion/physical-plan/src/joins/cross_join.rs index 8304ddc7331a..b1482a9699d5 100644 --- a/datafusion/physical-plan/src/joins/cross_join.rs +++ b/datafusion/physical-plan/src/joins/cross_join.rs @@ -18,13 +18,10 @@ //! Defines the cross join plan for loading the left side of the cross join //! and producing batches in parallel for the right partitions -use std::{any::Any, sync::Arc, task::Poll}; - use super::utils::{ adjust_right_output_partitioning, BuildProbeJoinMetrics, OnceAsync, OnceFut, StatefulStreamResult, }; -use crate::coalesce_batches::concat_batches; use crate::coalesce_partitions::CoalescePartitionsExec; use crate::metrics::{ExecutionPlanMetricsSet, MetricsSet}; use crate::{ @@ -33,6 +30,8 @@ use crate::{ ExecutionPlanProperties, PlanProperties, RecordBatchStream, SendableRecordBatchStream, Statistics, }; +use arrow::compute::concat_batches; +use std::{any::Any, sync::Arc, task::Poll}; use arrow::datatypes::{Fields, Schema, SchemaRef}; use arrow::record_batch::RecordBatch; @@ -155,7 +154,7 @@ async fn load_left_input( let stream = merge.execute(0, context)?; // Load all batches and count the rows - let (batches, num_rows, _, reservation) = stream + let (batches, _num_rows, _, reservation) = stream .try_fold( (Vec::new(), 0usize, metrics, reservation), |mut acc, batch| async { @@ -175,7 +174,7 @@ async fn load_left_input( ) .await?; - let merged_batch = concat_batches(&left_schema, &batches, num_rows)?; + let merged_batch = concat_batches(&left_schema, &batches)?; Ok((merged_batch, reservation)) } diff --git a/datafusion/physical-plan/src/joins/nested_loop_join.rs b/datafusion/physical-plan/src/joins/nested_loop_join.rs index f8ca38980850..eac135bfd0fe 100644 --- a/datafusion/physical-plan/src/joins/nested_loop_join.rs +++ b/datafusion/physical-plan/src/joins/nested_loop_join.rs @@ -26,7 +26,6 @@ use std::sync::Arc; use std::task::Poll; use super::utils::{asymmetric_join_output_partitioning, need_produce_result_in_final}; -use crate::coalesce_batches::concat_batches; use crate::coalesce_partitions::CoalescePartitionsExec; use crate::joins::utils::{ adjust_indices_by_join_type, apply_join_filter_to_indices, build_batch_from_indices, @@ -44,6 +43,7 @@ use crate::{ use arrow::array::{ BooleanBufferBuilder, UInt32Array, UInt32Builder, UInt64Array, UInt64Builder, }; +use 
arrow::compute::concat_batches; use arrow::datatypes::{Schema, SchemaRef}; use arrow::record_batch::RecordBatch; use arrow::util::bit_util; @@ -364,7 +364,7 @@ async fn collect_left_input( let stream = merge.execute(0, context)?; // Load all batches and count the rows - let (batches, num_rows, metrics, mut reservation) = stream + let (batches, _num_rows, metrics, mut reservation) = stream .try_fold( (Vec::new(), 0usize, join_metrics, reservation), |mut acc, batch| async { @@ -384,7 +384,7 @@ async fn collect_left_input( ) .await?; - let merged_batch = concat_batches(&schema, &batches, num_rows)?; + let merged_batch = concat_batches(&schema, &batches)?; // Reserve memory for visited_left_side bitmap if required by join type let visited_left_side = if with_visited_left_side { From 1e06b91d598782f8f732b104fe4c46468c4e3136 Mon Sep 17 00:00:00 2001 From: Jax Liu Date: Wed, 24 Jul 2024 21:09:33 +0800 Subject: [PATCH 03/17] Rename `functions-array` to `functions-nested` (#11602) * rename create to function-nested * rename array_expressions to nested_expression * rename doc and workflow * cargo fmt * update lock * Update readme * rename the missing parts * rename the planner * add backward compatibility --- .github/workflows/rust.yml | 4 ++-- Cargo.toml | 4 ++-- README.md | 2 +- datafusion-cli/Cargo.lock | 4 ++-- datafusion/core/Cargo.toml | 10 ++++---- datafusion/core/benches/map_query_sql.rs | 2 +- .../src/execution/session_state_defaults.rs | 23 ++++++++++--------- datafusion/core/src/lib.rs | 15 ++++++++---- datafusion/core/src/prelude.rs | 4 ++-- .../tests/dataframe/dataframe_functions.rs | 2 +- datafusion/core/tests/expr_api/mod.rs | 2 +- .../user_defined_scalar_functions.rs | 2 +- datafusion/expr/src/expr_rewriter/mod.rs | 2 +- .../Cargo.toml | 6 ++--- .../README.md | 4 ++-- .../benches/array_expression.rs | 2 +- .../benches/map.rs | 6 ++--- .../src/array_has.rs | 0 .../src/cardinality.rs | 0 .../src/concat.rs | 0 .../src/dimension.rs | 0 .../src/empty.rs | 0 .../src/except.rs | 0 .../src/expr_ext.rs | 4 ++-- .../src/extract.rs | 0 .../src/flatten.rs | 0 .../src/length.rs | 0 .../src/lib.rs | 14 +++++------ .../src/macros.rs | 0 .../src/make_array.rs | 0 .../src/map.rs | 0 .../src/planner.rs | 6 ++--- .../src/position.rs | 0 .../src/range.rs | 0 .../src/remove.rs | 0 .../src/repeat.rs | 0 .../src/replace.rs | 0 .../src/resize.rs | 0 .../src/reverse.rs | 0 .../src/set_ops.rs | 0 .../src/sort.rs | 0 .../src/string.rs | 0 .../src/utils.rs | 0 .../tests/cases/roundtrip_logical_plan.rs | 2 +- dev/release/README.md | 2 +- dev/release/crate-deps.dot | 10 ++++---- dev/release/crate-deps.svg | 22 +++++++++--------- dev/update_datafusion_versions.py | 2 +- 48 files changed, 83 insertions(+), 73 deletions(-) rename datafusion/{functions-array => functions-nested}/Cargo.toml (92%) rename datafusion/{functions-array => functions-nested}/README.md (87%) rename datafusion/{functions-array => functions-nested}/benches/array_expression.rs (95%) rename datafusion/{functions-array => functions-nested}/benches/map.rs (95%) rename datafusion/{functions-array => functions-nested}/src/array_has.rs (100%) rename datafusion/{functions-array => functions-nested}/src/cardinality.rs (100%) rename datafusion/{functions-array => functions-nested}/src/concat.rs (100%) rename datafusion/{functions-array => functions-nested}/src/dimension.rs (100%) rename datafusion/{functions-array => functions-nested}/src/empty.rs (100%) rename datafusion/{functions-array => functions-nested}/src/except.rs (100%) rename 
datafusion/{functions-array => functions-nested}/src/expr_ext.rs (95%) rename datafusion/{functions-array => functions-nested}/src/extract.rs (100%) rename datafusion/{functions-array => functions-nested}/src/flatten.rs (100%) rename datafusion/{functions-array => functions-nested}/src/length.rs (100%) rename datafusion/{functions-array => functions-nested}/src/lib.rs (93%) rename datafusion/{functions-array => functions-nested}/src/macros.rs (100%) rename datafusion/{functions-array => functions-nested}/src/make_array.rs (100%) rename datafusion/{functions-array => functions-nested}/src/map.rs (100%) rename datafusion/{functions-array => functions-nested}/src/planner.rs (97%) rename datafusion/{functions-array => functions-nested}/src/position.rs (100%) rename datafusion/{functions-array => functions-nested}/src/range.rs (100%) rename datafusion/{functions-array => functions-nested}/src/remove.rs (100%) rename datafusion/{functions-array => functions-nested}/src/repeat.rs (100%) rename datafusion/{functions-array => functions-nested}/src/replace.rs (100%) rename datafusion/{functions-array => functions-nested}/src/resize.rs (100%) rename datafusion/{functions-array => functions-nested}/src/reverse.rs (100%) rename datafusion/{functions-array => functions-nested}/src/set_ops.rs (100%) rename datafusion/{functions-array => functions-nested}/src/sort.rs (100%) rename datafusion/{functions-array => functions-nested}/src/string.rs (100%) rename datafusion/{functions-array => functions-nested}/src/utils.rs (100%) diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 2ddeebbc558e..4a41fd542e5d 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -90,8 +90,8 @@ jobs: # Ensure that the datafusion crate can be built with only a subset of the function # packages enabled. 
- - name: Check datafusion (array_expressions) - run: cargo check --no-default-features --features=array_expressions -p datafusion + - name: Check datafusion (nested_expressions) + run: cargo check --no-default-features --features=nested_expressions -p datafusion - name: Check datafusion (crypto) run: cargo check --no-default-features --features=crypto_expressions -p datafusion diff --git a/Cargo.toml b/Cargo.toml index 24bde78b3001..cb27a8761a8e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -25,7 +25,7 @@ members = [ "datafusion/execution", "datafusion/functions-aggregate", "datafusion/functions", - "datafusion/functions-array", + "datafusion/functions-nested", "datafusion/optimizer", "datafusion/physical-expr-common", "datafusion/physical-expr", @@ -94,7 +94,7 @@ datafusion-execution = { path = "datafusion/execution", version = "40.0.0" } datafusion-expr = { path = "datafusion/expr", version = "40.0.0" } datafusion-functions = { path = "datafusion/functions", version = "40.0.0" } datafusion-functions-aggregate = { path = "datafusion/functions-aggregate", version = "40.0.0" } -datafusion-functions-array = { path = "datafusion/functions-array", version = "40.0.0" } +datafusion-functions-nested = { path = "datafusion/functions-nested", version = "40.0.0" } datafusion-optimizer = { path = "datafusion/optimizer", version = "40.0.0", default-features = false } datafusion-physical-expr = { path = "datafusion/physical-expr", version = "40.0.0", default-features = false } datafusion-physical-expr-common = { path = "datafusion/physical-expr-common", version = "40.0.0", default-features = false } diff --git a/README.md b/README.md index 197e5d2b3fe1..b1d38b61109f 100644 --- a/README.md +++ b/README.md @@ -75,7 +75,7 @@ This crate has several [features] which can be specified in your `Cargo.toml`. Default features: -- `array_expressions`: functions for working with arrays such as `array_to_string` +- `nested_expressions`: functions for working with nested type function such as `array_to_string` - `compression`: reading files compressed with `xz2`, `bzip2`, `flate2`, and `zstd` - `crypto_expressions`: cryptographic functions such as `md5` and `sha256` - `datetime_expressions`: date and time functions such as `to_timestamp` diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index 84bff8c87190..a4e87f99b5c3 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -1149,7 +1149,7 @@ dependencies = [ "datafusion-expr", "datafusion-functions", "datafusion-functions-aggregate", - "datafusion-functions-array", + "datafusion-functions-nested", "datafusion-optimizer", "datafusion-physical-expr", "datafusion-physical-expr-common", @@ -1315,7 +1315,7 @@ dependencies = [ ] [[package]] -name = "datafusion-functions-array" +name = "datafusion-functions-nested" version = "40.0.0" dependencies = [ "arrow", diff --git a/datafusion/core/Cargo.toml b/datafusion/core/Cargo.toml index 4301396b231f..bed9265ff016 100644 --- a/datafusion/core/Cargo.toml +++ b/datafusion/core/Cargo.toml @@ -40,15 +40,17 @@ name = "datafusion" path = "src/lib.rs" [features] +nested_expressions = ["datafusion-functions-nested"] +# This feature is deprecated. Use the `nested_expressions` feature instead. 
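+# (Reviewer note: the alias below is what the commit message calls "add
+# backward compatibility": existing `--features array_expressions` builds keep
+# working by transparently enabling the renamed feature.)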
+array_expressions = ["nested_expressions"] # Used to enable the avro format -array_expressions = ["datafusion-functions-array"] avro = ["apache-avro", "num-traits", "datafusion-common/avro"] backtrace = ["datafusion-common/backtrace"] compression = ["xz2", "bzip2", "flate2", "zstd", "async-compression", "tokio-util"] crypto_expressions = ["datafusion-functions/crypto_expressions"] datetime_expressions = ["datafusion-functions/datetime_expressions"] default = [ - "array_expressions", + "nested_expressions", "crypto_expressions", "datetime_expressions", "encoding_expressions", @@ -102,7 +104,7 @@ datafusion-execution = { workspace = true } datafusion-expr = { workspace = true } datafusion-functions = { workspace = true } datafusion-functions-aggregate = { workspace = true } -datafusion-functions-array = { workspace = true, optional = true } +datafusion-functions-nested = { workspace = true, optional = true } datafusion-optimizer = { workspace = true } datafusion-physical-expr = { workspace = true } datafusion-physical-expr-common = { workspace = true } @@ -221,4 +223,4 @@ name = "parquet_statistic" [[bench]] harness = false name = "map_query_sql" -required-features = ["array_expressions"] +required-features = ["nested_expressions"] diff --git a/datafusion/core/benches/map_query_sql.rs b/datafusion/core/benches/map_query_sql.rs index b6ac8b6b647a..e4c5f7c5deb3 100644 --- a/datafusion/core/benches/map_query_sql.rs +++ b/datafusion/core/benches/map_query_sql.rs @@ -27,7 +27,7 @@ use tokio::runtime::Runtime; use datafusion::prelude::SessionContext; use datafusion_common::ScalarValue; use datafusion_expr::Expr; -use datafusion_functions_array::map::map; +use datafusion_functions_nested::map::map; mod data_utils; diff --git a/datafusion/core/src/execution/session_state_defaults.rs b/datafusion/core/src/execution/session_state_defaults.rs index 0b0465e44605..b7e7b5f0955f 100644 --- a/datafusion/core/src/execution/session_state_defaults.rs +++ b/datafusion/core/src/execution/session_state_defaults.rs @@ -26,8 +26,8 @@ use crate::datasource::file_format::parquet::ParquetFormatFactory; use crate::datasource::file_format::FileFormatFactory; use crate::datasource::provider::{DefaultTableFactory, TableProviderFactory}; use crate::execution::context::SessionState; -#[cfg(feature = "array_expressions")] -use crate::functions_array; +#[cfg(feature = "nested_expressions")] +use crate::functions_nested; use crate::{functions, functions_aggregate}; use datafusion_execution::config::SessionConfig; use datafusion_execution::object_store::ObjectStoreUrl; @@ -82,11 +82,11 @@ impl SessionStateDefaults { pub fn default_expr_planners() -> Vec> { let expr_planners: Vec> = vec![ Arc::new(functions::core::planner::CoreFunctionPlanner::default()), - // register crate of array expressions (if enabled) - #[cfg(feature = "array_expressions")] - Arc::new(functions_array::planner::ArrayFunctionPlanner), - #[cfg(feature = "array_expressions")] - Arc::new(functions_array::planner::FieldAccessPlanner), + // register crate of nested expressions (if enabled) + #[cfg(feature = "nested_expressions")] + Arc::new(functions_nested::planner::NestedFunctionPlanner), + #[cfg(feature = "nested_expressions")] + Arc::new(functions_nested::planner::FieldAccessPlanner), #[cfg(any( feature = "datetime_expressions", feature = "unicode_expressions" @@ -100,8 +100,8 @@ impl SessionStateDefaults { /// returns the list of default [`ScalarUDF']'s pub fn default_scalar_functions() -> Vec> { let mut functions: Vec> = 
functions::all_default_functions(); - #[cfg(feature = "array_expressions")] - functions.append(&mut functions_array::all_default_array_functions()); + #[cfg(feature = "nested_expressions")] + functions.append(&mut functions_nested::all_default_nested_functions()); functions } @@ -140,8 +140,9 @@ impl SessionStateDefaults { /// registers all the builtin array functions pub fn register_array_functions(state: &mut SessionState) { // register crate of array expressions (if enabled) - #[cfg(feature = "array_expressions")] - functions_array::register_all(state).expect("can not register array expressions"); + #[cfg(feature = "nested_expressions")] + functions_nested::register_all(state) + .expect("can not register nested expressions"); } /// registers all the builtin aggregate functions diff --git a/datafusion/core/src/lib.rs b/datafusion/core/src/lib.rs index 9b9b1db8ff81..9ab6ed527d82 100644 --- a/datafusion/core/src/lib.rs +++ b/datafusion/core/src/lib.rs @@ -458,7 +458,7 @@ //! * [datafusion_execution]: State and structures needed for execution //! * [datafusion_expr]: [`LogicalPlan`], [`Expr`] and related logical planning structure //! * [datafusion_functions]: Scalar function packages -//! * [datafusion_functions_array]: Scalar function packages for `ARRAY`s +//! * [datafusion_functions_nested]: Scalar function packages for `ARRAY`s, `MAP`s and `STRUCT`s //! * [datafusion_optimizer]: [`OptimizerRule`]s and [`AnalyzerRule`]s //! * [datafusion_physical_expr]: [`PhysicalExpr`] and related expressions //! * [datafusion_physical_plan]: [`ExecutionPlan`] and related expressions @@ -569,10 +569,17 @@ pub mod functions { pub use datafusion_functions::*; } -/// re-export of [`datafusion_functions_array`] crate, if "array_expressions" feature is enabled +/// re-export of [`datafusion_functions_nested`] crate, if "nested_expressions" feature is enabled +pub mod functions_nested { + #[cfg(feature = "nested_expressions")] + pub use datafusion_functions_nested::*; +} + +/// re-export of [`datafusion_functions_nested`] crate as [`functions_array`] for backward compatibility, if "nested_expressions" feature is enabled +#[deprecated(since = "41.0.0", note = "use datafusion-functions-nested instead")] pub mod functions_array { - #[cfg(feature = "array_expressions")] - pub use datafusion_functions_array::*; + #[cfg(feature = "nested_expressions")] + pub use datafusion_functions_nested::*; } /// re-export of [`datafusion_functions_aggregate`] crate diff --git a/datafusion/core/src/prelude.rs b/datafusion/core/src/prelude.rs index d82a5a2cc1a1..9c9fcd04bf09 100644 --- a/datafusion/core/src/prelude.rs +++ b/datafusion/core/src/prelude.rs @@ -39,8 +39,8 @@ pub use datafusion_expr::{ Expr, }; pub use datafusion_functions::expr_fn::*; -#[cfg(feature = "array_expressions")] -pub use datafusion_functions_array::expr_fn::*; +#[cfg(feature = "nested_expressions")] +pub use datafusion_functions_nested::expr_fn::*; pub use std::ops::Not; pub use std::ops::{Add, Div, Mul, Neg, Rem, Sub}; diff --git a/datafusion/core/tests/dataframe/dataframe_functions.rs b/datafusion/core/tests/dataframe/dataframe_functions.rs index f7b02196d8ed..7a0e9888a61c 100644 --- a/datafusion/core/tests/dataframe/dataframe_functions.rs +++ b/datafusion/core/tests/dataframe/dataframe_functions.rs @@ -34,7 +34,7 @@ use datafusion_common::{DFSchema, ScalarValue}; use datafusion_expr::expr::Alias; use datafusion_expr::ExprSchemable; use datafusion_functions_aggregate::expr_fn::{approx_median, approx_percentile_cont}; -use 
datafusion_functions_array::map::map; +use datafusion_functions_nested::map::map; fn test_schema() -> SchemaRef { Arc::new(Schema::new(vec![ diff --git a/datafusion/core/tests/expr_api/mod.rs b/datafusion/core/tests/expr_api/mod.rs index f36f2d539845..37d06355d2d3 100644 --- a/datafusion/core/tests/expr_api/mod.rs +++ b/datafusion/core/tests/expr_api/mod.rs @@ -25,7 +25,7 @@ use datafusion_expr::AggregateExt; use datafusion_functions::core::expr_ext::FieldAccessor; use datafusion_functions_aggregate::first_last::first_value_udaf; use datafusion_functions_aggregate::sum::sum_udaf; -use datafusion_functions_array::expr_ext::{IndexAccessor, SliceAccessor}; +use datafusion_functions_nested::expr_ext::{IndexAccessor, SliceAccessor}; use sqlparser::ast::NullTreatment; /// Tests of using and evaluating `Expr`s outside the context of a LogicalPlan use std::sync::{Arc, OnceLock}; diff --git a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs index 219f6c26cf8f..9164e89de8f9 100644 --- a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs @@ -45,7 +45,7 @@ use datafusion_expr::{ LogicalPlanBuilder, OperateFunctionArg, ScalarUDF, ScalarUDFImpl, Signature, Volatility, }; -use datafusion_functions_array::range::range_udf; +use datafusion_functions_nested::range::range_udf; /// test that casting happens on udfs. /// c11 is f32, but `custom_sqrt` requires f64. Casting happens but the logical plan and diff --git a/datafusion/expr/src/expr_rewriter/mod.rs b/datafusion/expr/src/expr_rewriter/mod.rs index 8d460bdc8e7d..bf2bfe2c3932 100644 --- a/datafusion/expr/src/expr_rewriter/mod.rs +++ b/datafusion/expr/src/expr_rewriter/mod.rs @@ -42,7 +42,7 @@ pub use order_by::rewrite_sort_cols_by_aggs; /// /// For example, concatenating arrays `a || b` is represented as /// `Operator::ArrowAt`, but can be implemented by calling a function -/// `array_concat` from the `functions-array` crate. +/// `array_concat` from the `functions-nested` crate. // This is not used in datafusion internally, but it is still helpful for downstream project so don't remove it. pub trait FunctionRewrite { /// Return a human readable name for this rewrite diff --git a/datafusion/functions-array/Cargo.toml b/datafusion/functions-nested/Cargo.toml similarity index 92% rename from datafusion/functions-array/Cargo.toml rename to datafusion/functions-nested/Cargo.toml index de424b259694..6a1973ecfed1 100644 --- a/datafusion/functions-array/Cargo.toml +++ b/datafusion/functions-nested/Cargo.toml @@ -16,8 +16,8 @@ # under the License. 
[package] -name = "datafusion-functions-array" -description = "Array Function packages for the DataFusion query engine" +name = "datafusion-functions-nested" +description = "Nested Type Function packages for the DataFusion query engine" keywords = ["datafusion", "logical", "plan", "expressions"] readme = "README.md" version = { workspace = true } @@ -34,7 +34,7 @@ workspace = true [features] [lib] -name = "datafusion_functions_array" +name = "datafusion_functions_nested" path = "src/lib.rs" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html diff --git a/datafusion/functions-array/README.md b/datafusion/functions-nested/README.md similarity index 87% rename from datafusion/functions-array/README.md rename to datafusion/functions-nested/README.md index 25deca8e1c77..8a5047c838ab 100644 --- a/datafusion/functions-array/README.md +++ b/datafusion/functions-nested/README.md @@ -17,11 +17,11 @@ under the License. --> -# DataFusion Array Function Library +# DataFusion Nested Type Function Library [DataFusion][df] is an extensible query execution framework, written in Rust, that uses Apache Arrow as its in-memory format. -This crate contains functions for working with arrays, such as `array_append` that work with +This crate contains functions for working with arrays, maps and structs, such as `array_append` that work with `ListArray`, `LargeListArray` and `FixedListArray` types from the `arrow` crate. [df]: https://crates.io/crates/datafusion diff --git a/datafusion/functions-array/benches/array_expression.rs b/datafusion/functions-nested/benches/array_expression.rs similarity index 95% rename from datafusion/functions-array/benches/array_expression.rs rename to datafusion/functions-nested/benches/array_expression.rs index 48b829793cef..0e3ecbc72641 100644 --- a/datafusion/functions-array/benches/array_expression.rs +++ b/datafusion/functions-nested/benches/array_expression.rs @@ -21,7 +21,7 @@ extern crate arrow; use crate::criterion::Criterion; use datafusion_expr::lit; -use datafusion_functions_array::expr_fn::{array_replace_all, make_array}; +use datafusion_functions_nested::expr_fn::{array_replace_all, make_array}; fn criterion_benchmark(c: &mut Criterion) { // Construct large arrays for benchmarking diff --git a/datafusion/functions-array/benches/map.rs b/datafusion/functions-nested/benches/map.rs similarity index 95% rename from datafusion/functions-array/benches/map.rs rename to datafusion/functions-nested/benches/map.rs index c2e0e641e80d..c9a12eefa4fa 100644 --- a/datafusion/functions-array/benches/map.rs +++ b/datafusion/functions-nested/benches/map.rs @@ -28,8 +28,8 @@ use std::sync::Arc; use datafusion_common::ScalarValue; use datafusion_expr::planner::ExprPlanner; use datafusion_expr::{ColumnarValue, Expr}; -use datafusion_functions_array::map::map_udf; -use datafusion_functions_array::planner::ArrayFunctionPlanner; +use datafusion_functions_nested::map::map_udf; +use datafusion_functions_nested::planner::NestedFunctionPlanner; fn keys(rng: &mut ThreadRng) -> Vec { let mut keys = vec![]; @@ -58,7 +58,7 @@ fn criterion_benchmark(c: &mut Criterion) { buffer.push(Expr::Literal(ScalarValue::Int32(Some(values[i])))); } - let planner = ArrayFunctionPlanner {}; + let planner = NestedFunctionPlanner {}; b.iter(|| { black_box( diff --git a/datafusion/functions-array/src/array_has.rs b/datafusion/functions-nested/src/array_has.rs similarity index 100% rename from datafusion/functions-array/src/array_has.rs rename to 
datafusion/functions-nested/src/array_has.rs diff --git a/datafusion/functions-array/src/cardinality.rs b/datafusion/functions-nested/src/cardinality.rs similarity index 100% rename from datafusion/functions-array/src/cardinality.rs rename to datafusion/functions-nested/src/cardinality.rs diff --git a/datafusion/functions-array/src/concat.rs b/datafusion/functions-nested/src/concat.rs similarity index 100% rename from datafusion/functions-array/src/concat.rs rename to datafusion/functions-nested/src/concat.rs diff --git a/datafusion/functions-array/src/dimension.rs b/datafusion/functions-nested/src/dimension.rs similarity index 100% rename from datafusion/functions-array/src/dimension.rs rename to datafusion/functions-nested/src/dimension.rs diff --git a/datafusion/functions-array/src/empty.rs b/datafusion/functions-nested/src/empty.rs similarity index 100% rename from datafusion/functions-array/src/empty.rs rename to datafusion/functions-nested/src/empty.rs diff --git a/datafusion/functions-array/src/except.rs b/datafusion/functions-nested/src/except.rs similarity index 100% rename from datafusion/functions-array/src/except.rs rename to datafusion/functions-nested/src/except.rs diff --git a/datafusion/functions-array/src/expr_ext.rs b/datafusion/functions-nested/src/expr_ext.rs similarity index 95% rename from datafusion/functions-array/src/expr_ext.rs rename to datafusion/functions-nested/src/expr_ext.rs index 5505ef746881..3524d62d0bc4 100644 --- a/datafusion/functions-array/src/expr_ext.rs +++ b/datafusion/functions-nested/src/expr_ext.rs @@ -35,7 +35,7 @@ use crate::extract::{array_element, array_slice}; /// /// ``` /// # use datafusion_expr::{lit, col, Expr}; -/// # use datafusion_functions_array::expr_ext::IndexAccessor; +/// # use datafusion_functions_nested::expr_ext::IndexAccessor; /// let expr = col("c1") /// .index(lit(3)); /// assert_eq!(expr.display_name().unwrap(), "c1[Int32(3)]"); @@ -65,7 +65,7 @@ impl IndexAccessor for Expr { /// /// ``` /// # use datafusion_expr::{lit, col}; -/// # use datafusion_functions_array::expr_ext::SliceAccessor; +/// # use datafusion_functions_nested::expr_ext::SliceAccessor; /// let expr = col("c1") /// .range(lit(2), lit(4)); /// assert_eq!(expr.display_name().unwrap(), "c1[Int32(2):Int32(4)]"); diff --git a/datafusion/functions-array/src/extract.rs b/datafusion/functions-nested/src/extract.rs similarity index 100% rename from datafusion/functions-array/src/extract.rs rename to datafusion/functions-nested/src/extract.rs diff --git a/datafusion/functions-array/src/flatten.rs b/datafusion/functions-nested/src/flatten.rs similarity index 100% rename from datafusion/functions-array/src/flatten.rs rename to datafusion/functions-nested/src/flatten.rs diff --git a/datafusion/functions-array/src/length.rs b/datafusion/functions-nested/src/length.rs similarity index 100% rename from datafusion/functions-array/src/length.rs rename to datafusion/functions-nested/src/length.rs diff --git a/datafusion/functions-array/src/lib.rs b/datafusion/functions-nested/src/lib.rs similarity index 93% rename from datafusion/functions-array/src/lib.rs rename to datafusion/functions-nested/src/lib.rs index f68f59dcd6a1..ef2c5e709bc1 100644 --- a/datafusion/functions-array/src/lib.rs +++ b/datafusion/functions-nested/src/lib.rs @@ -17,9 +17,9 @@ // Make cheap clones clear: https://github.com/apache/datafusion/issues/11143 #![deny(clippy::clone_on_ref_ptr)] -//! Array Functions for [DataFusion]. +//! Nested type Functions for [DataFusion]. //! -//! 
This crate contains a collection of array functions implemented using the +//! This crate contains a collection of nested type functions implemented using the //! extension API. //! //! [DataFusion]: https://crates.io/crates/datafusion @@ -102,8 +102,8 @@ pub mod expr_fn { pub use super::string::string_to_array; } -/// Return all default array functions -pub fn all_default_array_functions() -> Vec> { +/// Return all default nested type functions +pub fn all_default_nested_functions() -> Vec> { vec![ string::array_to_string_udf(), string::string_to_array_udf(), @@ -148,7 +148,7 @@ pub fn all_default_array_functions() -> Vec> { /// Registers all enabled packages with a [`FunctionRegistry`] pub fn register_all(registry: &mut dyn FunctionRegistry) -> Result<()> { - let functions: Vec> = all_default_array_functions(); + let functions: Vec> = all_default_nested_functions(); functions.into_iter().try_for_each(|udf| { let existing_udf = registry.register_udf(udf)?; if let Some(existing_udf) = existing_udf { @@ -162,14 +162,14 @@ pub fn register_all(registry: &mut dyn FunctionRegistry) -> Result<()> { #[cfg(test)] mod tests { - use crate::all_default_array_functions; + use crate::all_default_nested_functions; use datafusion_common::Result; use std::collections::HashSet; #[test] fn test_no_duplicate_name() -> Result<()> { let mut names = HashSet::new(); - for func in all_default_array_functions() { + for func in all_default_nested_functions() { assert!( names.insert(func.name().to_string().to_lowercase()), "duplicate function name: {}", diff --git a/datafusion/functions-array/src/macros.rs b/datafusion/functions-nested/src/macros.rs similarity index 100% rename from datafusion/functions-array/src/macros.rs rename to datafusion/functions-nested/src/macros.rs diff --git a/datafusion/functions-array/src/make_array.rs b/datafusion/functions-nested/src/make_array.rs similarity index 100% rename from datafusion/functions-array/src/make_array.rs rename to datafusion/functions-nested/src/make_array.rs diff --git a/datafusion/functions-array/src/map.rs b/datafusion/functions-nested/src/map.rs similarity index 100% rename from datafusion/functions-array/src/map.rs rename to datafusion/functions-nested/src/map.rs diff --git a/datafusion/functions-array/src/planner.rs b/datafusion/functions-nested/src/planner.rs similarity index 97% rename from datafusion/functions-array/src/planner.rs rename to datafusion/functions-nested/src/planner.rs index 3f779c9f111e..97c54cc77beb 100644 --- a/datafusion/functions-array/src/planner.rs +++ b/datafusion/functions-nested/src/planner.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -//! SQL planning extensions like [`ArrayFunctionPlanner`] and [`FieldAccessPlanner`] +//! 
SQL planning extensions like [`NestedFunctionPlanner`] and [`FieldAccessPlanner`] use datafusion_common::{exec_err, utils::list_ndims, DFSchema, Result}; use datafusion_expr::expr::ScalarFunction; @@ -35,9 +35,9 @@ use crate::{ make_array::make_array, }; -pub struct ArrayFunctionPlanner; +pub struct NestedFunctionPlanner; -impl ExprPlanner for ArrayFunctionPlanner { +impl ExprPlanner for NestedFunctionPlanner { fn plan_binary_op( &self, expr: RawBinaryExpr, diff --git a/datafusion/functions-array/src/position.rs b/datafusion/functions-nested/src/position.rs similarity index 100% rename from datafusion/functions-array/src/position.rs rename to datafusion/functions-nested/src/position.rs diff --git a/datafusion/functions-array/src/range.rs b/datafusion/functions-nested/src/range.rs similarity index 100% rename from datafusion/functions-array/src/range.rs rename to datafusion/functions-nested/src/range.rs diff --git a/datafusion/functions-array/src/remove.rs b/datafusion/functions-nested/src/remove.rs similarity index 100% rename from datafusion/functions-array/src/remove.rs rename to datafusion/functions-nested/src/remove.rs diff --git a/datafusion/functions-array/src/repeat.rs b/datafusion/functions-nested/src/repeat.rs similarity index 100% rename from datafusion/functions-array/src/repeat.rs rename to datafusion/functions-nested/src/repeat.rs diff --git a/datafusion/functions-array/src/replace.rs b/datafusion/functions-nested/src/replace.rs similarity index 100% rename from datafusion/functions-array/src/replace.rs rename to datafusion/functions-nested/src/replace.rs diff --git a/datafusion/functions-array/src/resize.rs b/datafusion/functions-nested/src/resize.rs similarity index 100% rename from datafusion/functions-array/src/resize.rs rename to datafusion/functions-nested/src/resize.rs diff --git a/datafusion/functions-array/src/reverse.rs b/datafusion/functions-nested/src/reverse.rs similarity index 100% rename from datafusion/functions-array/src/reverse.rs rename to datafusion/functions-nested/src/reverse.rs diff --git a/datafusion/functions-array/src/set_ops.rs b/datafusion/functions-nested/src/set_ops.rs similarity index 100% rename from datafusion/functions-array/src/set_ops.rs rename to datafusion/functions-nested/src/set_ops.rs diff --git a/datafusion/functions-array/src/sort.rs b/datafusion/functions-nested/src/sort.rs similarity index 100% rename from datafusion/functions-array/src/sort.rs rename to datafusion/functions-nested/src/sort.rs diff --git a/datafusion/functions-array/src/string.rs b/datafusion/functions-nested/src/string.rs similarity index 100% rename from datafusion/functions-array/src/string.rs rename to datafusion/functions-nested/src/string.rs diff --git a/datafusion/functions-array/src/utils.rs b/datafusion/functions-nested/src/utils.rs similarity index 100% rename from datafusion/functions-array/src/utils.rs rename to datafusion/functions-nested/src/utils.rs diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index e17515086ecd..25223c3731be 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -44,7 +44,7 @@ use datafusion::functions_aggregate::expr_fn::{ count_distinct, covar_pop, covar_samp, first_value, grouping, median, stddev, stddev_pop, sum, var_pop, var_sample, }; -use datafusion::functions_array::map::map; +use datafusion::functions_nested::map::map; use datafusion::prelude::*; use 
datafusion::test_util::{TestTableFactory, TestTableProvider}; use datafusion_common::config::TableOptions; diff --git a/dev/release/README.md b/dev/release/README.md index 9486222c5105..c6bc9be2b0db 100644 --- a/dev/release/README.md +++ b/dev/release/README.md @@ -268,7 +268,7 @@ dot -Tsvg dev/release/crate-deps.dot > dev/release/crate-deps.svg (cd datafusion/functions-aggregate && cargo publish) (cd datafusion/physical-expr && cargo publish) (cd datafusion/functions && cargo publish) -(cd datafusion/functions-array && cargo publish) +(cd datafusion/functions-nested && cargo publish) (cd datafusion/sql && cargo publish) (cd datafusion/optimizer && cargo publish) (cd datafusion/common-runtime && cargo publish) diff --git a/dev/release/crate-deps.dot b/dev/release/crate-deps.dot index 69811c7d6109..1d903a56021d 100644 --- a/dev/release/crate-deps.dot +++ b/dev/release/crate-deps.dot @@ -74,15 +74,15 @@ digraph G { datafusion -> datafusion_execution datafusion -> datafusion_expr datafusion -> datafusion_functions - datafusion -> datafusion_functions_array + datafusion -> datafusion_functions_nested datafusion -> datafusion_optimizer datafusion -> datafusion_physical_expr datafusion -> datafusion_physical_plan datafusion -> datafusion_sql - datafusion_functions_array - datafusion_functions_array -> datafusion_common - datafusion_functions_array -> datafusion_execution - datafusion_functions_array -> datafusion_expr + datafusion_functions_nested + datafusion_functions_nested -> datafusion_common + datafusion_functions_nested -> datafusion_execution + datafusion_functions_nested -> datafusion_expr datafusion_execution datafusion_execution -> datafusion_common datafusion_execution -> datafusion_expr diff --git a/dev/release/crate-deps.svg b/dev/release/crate-deps.svg index cf60bf752642..c76fe3abb4ac 100644 --- a/dev/release/crate-deps.svg +++ b/dev/release/crate-deps.svg @@ -153,15 +153,15 @@ - + -datafusion_functions_array +datafusion_functions_nested -datafusion_functions_array +datafusion_functions_nested - + -datafusion->datafusion_functions_array +datafusion->datafusion_functions_nested @@ -411,21 +411,21 @@ - + -datafusion_functions_array->datafusion_common +datafusion_functions_nested->datafusion_common - + -datafusion_functions_array->datafusion_expr +datafusion_functions_nested->datafusion_expr - + -datafusion_functions_array->datafusion_execution +datafusion_functions_nested->datafusion_execution diff --git a/dev/update_datafusion_versions.py b/dev/update_datafusion_versions.py index 74a8a2ebd5b6..2e3374cd920b 100755 --- a/dev/update_datafusion_versions.py +++ b/dev/update_datafusion_versions.py @@ -35,7 +35,7 @@ 'datafusion-expr': 'datafusion/expr/Cargo.toml', 'datafusion-functions': 'datafusion/functions/Cargo.toml', 'datafusion-functions-aggregate': 'datafusion/functions-aggregate/Cargo.toml', - 'datafusion-functions-array': 'datafusion/functions-array/Cargo.toml', + 'datafusion-functions-nested': 'datafusion/functions-nested/Cargo.toml', 'datafusion-optimizer': 'datafusion/optimizer/Cargo.toml', 'datafusion-physical-expr': 'datafusion/physical-expr/Cargo.toml', 'datafusion-physical-expr-common': 'datafusion/physical-expr-common/Cargo.toml', From 8945462ed0baf20eb4fb8e298407d08072030e33 Mon Sep 17 00:00:00 2001 From: Namgung Chan <33323415+getChan@users.noreply.github.com> Date: Wed, 24 Jul 2024 22:11:00 +0900 Subject: [PATCH 04/17] Fix : `signum` function bug when `0.0` input (#11580) * add signum unit test * fix: signum function implementation - input zero output zero * fix: run 
cargo fmt * fix: not specified return type is float64 * fix: sqllogictest --- datafusion/functions/src/math/mod.rs | 3 +- datafusion/functions/src/math/monotonicity.rs | 5 - datafusion/functions/src/math/signum.rs | 215 ++++++++++++++++++ datafusion/sqllogictest/test_files/scalar.slt | 2 +- 4 files changed, 218 insertions(+), 7 deletions(-) create mode 100644 datafusion/functions/src/math/signum.rs diff --git a/datafusion/functions/src/math/mod.rs b/datafusion/functions/src/math/mod.rs index 9ee173bb6176..3b32a158b884 100644 --- a/datafusion/functions/src/math/mod.rs +++ b/datafusion/functions/src/math/mod.rs @@ -35,6 +35,7 @@ pub mod pi; pub mod power; pub mod random; pub mod round; +pub mod signum; pub mod trunc; // Create UDFs @@ -81,7 +82,7 @@ make_math_unary_udf!( ); make_udf_function!(random::RandomFunc, RANDOM, random); make_udf_function!(round::RoundFunc, ROUND, round); -make_math_unary_udf!(SignumFunc, SIGNUM, signum, signum, super::signum_order); +make_udf_function!(signum::SignumFunc, SIGNUM, signum); make_math_unary_udf!(SinFunc, SIN, sin, sin, super::sin_order); make_math_unary_udf!(SinhFunc, SINH, sinh, sinh, super::sinh_order); make_math_unary_udf!(SqrtFunc, SQRT, sqrt, sqrt, super::sqrt_order); diff --git a/datafusion/functions/src/math/monotonicity.rs b/datafusion/functions/src/math/monotonicity.rs index 56c5a45788bc..33c061ee11d0 100644 --- a/datafusion/functions/src/math/monotonicity.rs +++ b/datafusion/functions/src/math/monotonicity.rs @@ -197,11 +197,6 @@ pub fn radians_order(input: &[ExprProperties]) -> Result { Ok(input[0].sort_properties) } -/// Non-decreasing for all real numbers x. -pub fn signum_order(input: &[ExprProperties]) -> Result { - Ok(input[0].sort_properties) -} - /// Non-decreasing on \[0, π\] and then non-increasing on \[π, 2π\]. /// This pattern repeats periodically with a period of 2π. // TODO: Implement ordering rule of the SIN function. diff --git a/datafusion/functions/src/math/signum.rs b/datafusion/functions/src/math/signum.rs new file mode 100644 index 000000000000..d2a806a46e13 --- /dev/null +++ b/datafusion/functions/src/math/signum.rs @@ -0,0 +1,215 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
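+
+// Reviewer note: Rust's `f64::signum` maps +0.0 to 1.0 and -0.0 to -1.0,
+// whereas SQL expects `signum(0) = 0`; the explicit zero checks below are
+// what this fix adds before delegating to `signum`.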
+ +use std::any::Any; +use std::sync::Arc; + +use arrow::array::{ArrayRef, Float32Array, Float64Array}; +use arrow::datatypes::DataType; +use arrow::datatypes::DataType::{Float32, Float64}; + +use datafusion_common::{exec_err, DataFusionError, Result}; +use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; +use datafusion_expr::ColumnarValue; +use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; + +use crate::utils::make_scalar_function; + +#[derive(Debug)] +pub struct SignumFunc { + signature: Signature, +} + +impl Default for SignumFunc { + fn default() -> Self { + SignumFunc::new() + } +} + +impl SignumFunc { + pub fn new() -> Self { + use DataType::*; + Self { + signature: Signature::uniform( + 1, + vec![Float64, Float32], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for SignumFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "signum" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + match &arg_types[0] { + Float32 => Ok(Float32), + _ => Ok(Float64), + } + } + + fn output_ordering(&self, input: &[ExprProperties]) -> Result { + // Non-decreasing for all real numbers x. + Ok(input[0].sort_properties) + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + make_scalar_function(signum, vec![])(args) + } +} + +/// signum SQL function +pub fn signum(args: &[ArrayRef]) -> Result { + match args[0].data_type() { + Float64 => Ok(Arc::new(make_function_scalar_inputs_return_type!( + &args[0], + "signum", + Float64Array, + Float64Array, + { + |x: f64| { + if x == 0_f64 { + 0_f64 + } else { + x.signum() + } + } + } + )) as ArrayRef), + + Float32 => Ok(Arc::new(make_function_scalar_inputs_return_type!( + &args[0], + "signum", + Float32Array, + Float32Array, + { + |x: f32| { + if x == 0_f32 { + 0_f32 + } else { + x.signum() + } + } + } + )) as ArrayRef), + + other => exec_err!("Unsupported data type {other:?} for function signum"), + } +} + +#[cfg(test)] +mod test { + use std::sync::Arc; + + use arrow::array::{Float32Array, Float64Array}; + + use datafusion_common::cast::{as_float32_array, as_float64_array}; + use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; + + use crate::math::signum::SignumFunc; + + #[test] + fn test_signum_f32() { + let args = [ColumnarValue::Array(Arc::new(Float32Array::from(vec![ + -1.0, + -0.0, + 0.0, + 1.0, + -0.01, + 0.01, + f32::NAN, + f32::INFINITY, + f32::NEG_INFINITY, + ])))]; + + let result = SignumFunc::new() + .invoke(&args) + .expect("failed to initialize function signum"); + + match result { + ColumnarValue::Array(arr) => { + let floats = as_float32_array(&arr) + .expect("failed to convert result to a Float32Array"); + + assert_eq!(floats.len(), 9); + assert_eq!(floats.value(0), -1.0); + assert_eq!(floats.value(1), 0.0); + assert_eq!(floats.value(2), 0.0); + assert_eq!(floats.value(3), 1.0); + assert_eq!(floats.value(4), -1.0); + assert_eq!(floats.value(5), 1.0); + assert!(floats.value(6).is_nan()); + assert_eq!(floats.value(7), 1.0); + assert_eq!(floats.value(8), -1.0); + } + ColumnarValue::Scalar(_) => { + panic!("Expected an array value") + } + } + } + + #[test] + fn test_signum_f64() { + let args = [ColumnarValue::Array(Arc::new(Float64Array::from(vec![ + -1.0, + -0.0, + 0.0, + 1.0, + -0.01, + 0.01, + f64::NAN, + f64::INFINITY, + f64::NEG_INFINITY, + ])))]; + + let result = SignumFunc::new() + .invoke(&args) + .expect("failed to initialize function signum"); + + match result { + 
+            ColumnarValue::Array(arr) => {
+                let floats = as_float64_array(&arr)
+                    .expect("failed to convert result to a Float64Array");
+
+                assert_eq!(floats.len(), 9);
+                assert_eq!(floats.value(0), -1.0);
+                assert_eq!(floats.value(1), 0.0);
+                assert_eq!(floats.value(2), 0.0);
+                assert_eq!(floats.value(3), 1.0);
+                assert_eq!(floats.value(4), -1.0);
+                assert_eq!(floats.value(5), 1.0);
+                assert!(floats.value(6).is_nan());
+                assert_eq!(floats.value(7), 1.0);
+                assert_eq!(floats.value(8), -1.0);
+            }
+            ColumnarValue::Scalar(_) => {
+                panic!("Expected an array value")
+            }
+        }
+    }
+}
diff --git a/datafusion/sqllogictest/test_files/scalar.slt b/datafusion/sqllogictest/test_files/scalar.slt
index ff9afa94f40a..188a2c5863e6 100644
--- a/datafusion/sqllogictest/test_files/scalar.slt
+++ b/datafusion/sqllogictest/test_files/scalar.slt
@@ -794,7 +794,7 @@ select round(column1, column2) from values (3.14, 2), (3.14, 3), (3.14, 21474836
 query RRR rowsort
 select signum(-2), signum(0), signum(2);
 ----
--1 1 1
+-1 0 1
 
 # signum scalar nulls
 query R rowsort

From 6efdbe6d4b8df4ef8c149f42e57d9c3aed7f3266 Mon Sep 17 00:00:00 2001
From: Dharan Aditya
Date: Wed, 24 Jul 2024 18:42:17 +0530
Subject: [PATCH 05/17] Enforce uniqueness of `named_struct` field names
 (#11614)

* check struct field names for uniqueness

* add logic test

* improve error log
---
 datafusion/functions/src/core/named_struct.rs | 15 ++++++++++++++-
 datafusion/sqllogictest/test_files/struct.slt |  4 ++++
 2 files changed, 18 insertions(+), 1 deletion(-)

diff --git a/datafusion/functions/src/core/named_struct.rs b/datafusion/functions/src/core/named_struct.rs
index 8ccda977f3a4..f71b1b00f0fe 100644
--- a/datafusion/functions/src/core/named_struct.rs
+++ b/datafusion/functions/src/core/named_struct.rs
@@ -20,6 +20,7 @@ use arrow::datatypes::{DataType, Field, Fields};
 use datafusion_common::{exec_err, internal_err, Result, ScalarValue};
 use datafusion_expr::{ColumnarValue, Expr, ExprSchemable};
 use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
+use hashbrown::HashSet;
 use std::any::Any;
 use std::sync::Arc;
 
@@ -45,7 +46,6 @@ fn named_struct_expr(args: &[ColumnarValue]) -> Result<ColumnarValue> {
         .map(|(i, chunk)| {
             let name_column = &chunk[0];
-
             let name = match name_column {
                 ColumnarValue::Scalar(ScalarValue::Utf8(Some(name_scalar))) => name_scalar,
                 _ => return exec_err!("named_struct even arguments must be string literals, got {name_column:?} instead at position {}", i * 2)
@@ -57,6 +57,19 @@ fn named_struct_expr(args: &[ColumnarValue]) -> Result<ColumnarValue> {
         .into_iter()
         .unzip();
 
+    {
+        // Check to enforce the uniqueness of struct field names
+        let mut unique_field_names = HashSet::new();
+        for name in names.iter() {
+            if unique_field_names.contains(name) {
+                return exec_err!(
+                    "named_struct requires unique field names. Field {name} is used more than once."
+                );
+            }
+            unique_field_names.insert(name);
+        }
+    }
+
     let arrays = ColumnarValue::values_to_arrays(&values)?;
 
     let fields = names
diff --git a/datafusion/sqllogictest/test_files/struct.slt b/datafusion/sqllogictest/test_files/struct.slt
index a7384fd4d8ad..caa612f556fe 100644
--- a/datafusion/sqllogictest/test_files/struct.slt
+++ b/datafusion/sqllogictest/test_files/struct.slt
@@ -122,6 +122,10 @@ physical_plan
 query error
 select named_struct();
 
+# error on duplicate field names
+query error
+select named_struct('c0': 1, 'c1': 2, 'c1': 3);
+
 # error on odd number of arguments #1
 query error DataFusion error: Execution error: named_struct requires an even number of arguments, got 1 instead
 select named_struct('a');

From e90b3ac5cf89ec5b1a94506ac69e85bd9b7d319e Mon Sep 17 00:00:00 2001
From: Andrew Lamb
Date: Wed, 24 Jul 2024 10:20:18 -0400
Subject: [PATCH 06/17] Minor: unnecessary row_count calculation in
 `CrossJoinExec` and `NestedLoopJoinExec` (#11632)

* Minor: remove row_count calculation

* Minor: remove row_count calculation
---
 .../physical-plan/src/joins/cross_join.rs     | 31 ++++++++-----------
 .../src/joins/nested_loop_join.rs             | 14 ++++-----
 2 files changed, 19 insertions(+), 26 deletions(-)

diff --git a/datafusion/physical-plan/src/joins/cross_join.rs b/datafusion/physical-plan/src/joins/cross_join.rs
index b1482a9699d5..2840d3f62bf9 100644
--- a/datafusion/physical-plan/src/joins/cross_join.rs
+++ b/datafusion/physical-plan/src/joins/cross_join.rs
@@ -154,24 +154,19 @@ async fn load_left_input(
     let stream = merge.execute(0, context)?;
 
     // Load all batches and count the rows
-    let (batches, _num_rows, _, reservation) = stream
-        .try_fold(
-            (Vec::new(), 0usize, metrics, reservation),
-            |mut acc, batch| async {
-                let batch_size = batch.get_array_memory_size();
-                // Reserve memory for incoming batch
-                acc.3.try_grow(batch_size)?;
-                // Update metrics
-                acc.2.build_mem_used.add(batch_size);
-                acc.2.build_input_batches.add(1);
-                acc.2.build_input_rows.add(batch.num_rows());
-                // Update rowcount
-                acc.1 += batch.num_rows();
-                // Push batch to output
-                acc.0.push(batch);
-                Ok(acc)
-            },
-        )
+    let (batches, _metrics, reservation) = stream
+        .try_fold((Vec::new(), metrics, reservation), |mut acc, batch| async {
+            let batch_size = batch.get_array_memory_size();
+            // Reserve memory for incoming batch
+            acc.2.try_grow(batch_size)?;
+            // Update metrics
+            acc.1.build_mem_used.add(batch_size);
+            acc.1.build_input_batches.add(1);
+            acc.1.build_input_rows.add(batch.num_rows());
+            // Push batch to output
+            acc.0.push(batch);
+            Ok(acc)
+        })
         .await?;
 
     let merged_batch = concat_batches(&left_schema, &batches)?;
diff --git a/datafusion/physical-plan/src/joins/nested_loop_join.rs b/datafusion/physical-plan/src/joins/nested_loop_join.rs
index eac135bfd0fe..9f1465c2d7c1 100644
--- a/datafusion/physical-plan/src/joins/nested_loop_join.rs
+++ b/datafusion/physical-plan/src/joins/nested_loop_join.rs
@@ -364,19 +364,17 @@ async fn collect_left_input(
     let stream = merge.execute(0, context)?;
 
     // Load all batches and count the rows
-    let (batches, _num_rows, metrics, mut reservation) = stream
+    let (batches, metrics, mut reservation) = stream
         .try_fold(
-            (Vec::new(), 0usize, join_metrics, reservation),
+            (Vec::new(), join_metrics, reservation),
             |mut acc, batch| async {
                 let batch_size = batch.get_array_memory_size();
                 // Reserve memory for incoming batch
-                acc.3.try_grow(batch_size)?;
+                acc.2.try_grow(batch_size)?;
                 // Update metrics
-                acc.2.build_mem_used.add(batch_size);
-                acc.2.build_input_batches.add(1);
-                acc.2.build_input_rows.add(batch.num_rows());
-                // Update rowcount
-                acc.1 += batch.num_rows();
+                acc.1.build_mem_used.add(batch_size);
+                acc.1.build_input_batches.add(1);
+                acc.1.build_input_rows.add(batch.num_rows());
                 // Push batch to output
                 acc.0.push(batch);
                 Ok(acc)

From 13569340bce99e4a317ec4d71e5c46d69dfa733d Mon Sep 17 00:00:00 2001
From: Jay Zhan
Date: Wed, 24 Jul 2024 22:30:05 +0800
Subject: [PATCH 07/17] ExprBuilder for Physical Aggregate Expr (#11617)

* aggregate expr builder

Signed-off-by: jayzhan211

* replace parts of test

Signed-off-by: jayzhan211

* continue

Signed-off-by: jayzhan211

* cleanup all

Signed-off-by: jayzhan211

* clippy

Signed-off-by: jayzhan211

* add sort

Signed-off-by: jayzhan211

* rm field

Signed-off-by: jayzhan211

* address comment

Signed-off-by: jayzhan211

* fix import path

Signed-off-by: jayzhan211

---------

Signed-off-by: jayzhan211
---
 datafusion/core/src/lib.rs                    |   5 +
 .../aggregate_statistics.rs                   |  20 +-
 .../combine_partial_final_agg.rs              |  41 +--
 .../core/tests/fuzz_cases/aggregate_fuzz.rs   |  23 +-
 .../physical-expr-common/src/aggregate/mod.rs | 286 +++++++++++++-----
 .../physical-plan/src/aggregates/mod.rs       | 134 +++-----
 datafusion/physical-plan/src/windows/mod.rs   |  39 +--
 datafusion/proto/src/physical_plan/mod.rs     |  11 +-
 .../tests/cases/roundtrip_physical_plan.rs    | 177 ++++-------
 9 files changed, 369 insertions(+), 367 deletions(-)

diff --git a/datafusion/core/src/lib.rs b/datafusion/core/src/lib.rs
index 9ab6ed527d82..d9ab9e1c07dd 100644
--- a/datafusion/core/src/lib.rs
+++ b/datafusion/core/src/lib.rs
@@ -545,6 +545,11 @@ pub mod optimizer {
     pub use datafusion_optimizer::*;
 }
 
+/// re-export of [`datafusion_physical_expr_common`] crate
+pub mod physical_expr_common {
+    pub use datafusion_physical_expr_common::*;
+}
+
 /// re-export of [`datafusion_physical_expr`] crate
 pub mod physical_expr {
     pub use datafusion_physical_expr::*;
 }
diff --git a/datafusion/core/src/physical_optimizer/aggregate_statistics.rs b/datafusion/core/src/physical_optimizer/aggregate_statistics.rs
index e7580d3e33ef..5f08e4512b3a 100644
--- a/datafusion/core/src/physical_optimizer/aggregate_statistics.rs
+++ b/datafusion/core/src/physical_optimizer/aggregate_statistics.rs
@@ -326,7 +326,7 @@ pub(crate) mod tests {
     use datafusion_functions_aggregate::count::count_udaf;
     use datafusion_physical_expr::expressions::cast;
     use datafusion_physical_expr::PhysicalExpr;
-    use datafusion_physical_expr_common::aggregate::create_aggregate_expr;
+    use datafusion_physical_expr_common::aggregate::AggregateExprBuilder;
     use datafusion_physical_plan::aggregates::AggregateMode;
 
     /// Mock data using a MemoryExec which has an exact count statistic
@@ -419,19 +419,11 @@ pub(crate) mod tests {
         // Return appropriate expr depending if COUNT is for col or table (*)
         pub(crate) fn count_expr(&self, schema: &Schema) -> Arc<dyn AggregateExpr> {
-            create_aggregate_expr(
-                &count_udaf(),
-                &[self.column()],
-                &[],
-                &[],
-                &[],
-                schema,
-                self.column_name(),
-                false,
-                false,
-                false,
-            )
-            .unwrap()
+            AggregateExprBuilder::new(count_udaf(), vec![self.column()])
+                .schema(Arc::new(schema.clone()))
+                .name(self.column_name())
+                .build()
+                .unwrap()
         }
 
         /// what argument would this aggregate need in the plan?
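Every call site in this commit follows the same mechanical migration: the long positional argument list of `create_aggregate_expr` (mostly empty slices and `false` flags) becomes a chain of named builder methods, and the schema is now passed as a `SchemaRef` rather than `&Schema`. A minimal before/after sketch, assuming `count_udaf()`, an `expr: Arc<dyn PhysicalExpr>`, and a `schema: SchemaRef` are in scope, as in the tests above:

    // Before: positional arguments; the trailing booleans are easy to transpose.
    let agg = create_aggregate_expr(
        &count_udaf(), &[expr.clone()], &[], &[], &[],
        &schema, "COUNT(expr)", false, false, false,
    )?;

    // After: each property is named, and omitted settings keep their defaults.
    let agg = AggregateExprBuilder::new(count_udaf(), vec![expr])
        .schema(Arc::clone(&schema))
        .name("COUNT(expr)")
        .build()?;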
diff --git a/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs b/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs index ddb7d36fb595..6f3274820c8c 100644 --- a/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs +++ b/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs @@ -177,7 +177,7 @@ mod tests { use datafusion_functions_aggregate::count::count_udaf; use datafusion_functions_aggregate::sum::sum_udaf; use datafusion_physical_expr::expressions::col; - use datafusion_physical_plan::udaf::create_aggregate_expr; + use datafusion_physical_expr_common::aggregate::AggregateExprBuilder; /// Runs the CombinePartialFinalAggregate optimizer and asserts the plan against the expected macro_rules! assert_optimized { @@ -278,19 +278,11 @@ mod tests { name: &str, schema: &Schema, ) -> Arc { - create_aggregate_expr( - &count_udaf(), - &[expr], - &[], - &[], - &[], - schema, - name, - false, - false, - false, - ) - .unwrap() + AggregateExprBuilder::new(count_udaf(), vec![expr]) + .schema(Arc::new(schema.clone())) + .name(name) + .build() + .unwrap() } #[test] @@ -368,19 +360,14 @@ mod tests { #[test] fn aggregations_with_group_combined() -> Result<()> { let schema = schema(); - - let aggr_expr = vec![create_aggregate_expr( - &sum_udaf(), - &[col("b", &schema)?], - &[], - &[], - &[], - &schema, - "Sum(b)", - false, - false, - false, - )?]; + let aggr_expr = + vec![ + AggregateExprBuilder::new(sum_udaf(), vec![col("b", &schema)?]) + .schema(Arc::clone(&schema)) + .name("Sum(b)") + .build() + .unwrap(), + ]; let groups: Vec<(Arc, String)> = vec![(col("c", &schema)?, "c".to_string())]; diff --git a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs index 736560da97db..6f286c9aeba1 100644 --- a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs @@ -35,7 +35,7 @@ use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion, TreeNodeVisitor} use datafusion_functions_aggregate::sum::sum_udaf; use datafusion_physical_expr::expressions::col; use datafusion_physical_expr::PhysicalSortExpr; -use datafusion_physical_plan::udaf::create_aggregate_expr; +use datafusion_physical_expr_common::aggregate::AggregateExprBuilder; use datafusion_physical_plan::InputOrderMode; use test_utils::{add_empty_batches, StringBatchGenerator}; @@ -103,19 +103,14 @@ async fn run_aggregate_test(input1: Vec, group_by_columns: Vec<&str .with_sort_information(vec![sort_keys]), ); - let aggregate_expr = vec![create_aggregate_expr( - &sum_udaf(), - &[col("d", &schema).unwrap()], - &[], - &[], - &[], - &schema, - "sum1", - false, - false, - false, - ) - .unwrap()]; + let aggregate_expr = + vec![ + AggregateExprBuilder::new(sum_udaf(), vec![col("d", &schema).unwrap()]) + .schema(Arc::clone(&schema)) + .name("sum1") + .build() + .unwrap(), + ]; let expr = group_by_columns .iter() .map(|elem| (col(elem, &schema).unwrap(), elem.to_string())) diff --git a/datafusion/physical-expr-common/src/aggregate/mod.rs b/datafusion/physical-expr-common/src/aggregate/mod.rs index 8c5f9f9e5a7e..b58a5a6faf24 100644 --- a/datafusion/physical-expr-common/src/aggregate/mod.rs +++ b/datafusion/physical-expr-common/src/aggregate/mod.rs @@ -22,8 +22,8 @@ pub mod stats; pub mod tdigest; pub mod utils; -use arrow::datatypes::{DataType, Field, Schema}; -use datafusion_common::{not_impl_err, DFSchema, Result}; +use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; +use 
datafusion_common::{internal_err, not_impl_err, DFSchema, Result}; use datafusion_expr::function::StateFieldsArgs; use datafusion_expr::type_coercion::aggregates::check_arg_count; use datafusion_expr::ReversedUDAF; @@ -33,7 +33,7 @@ use datafusion_expr::{ use std::fmt::Debug; use std::{any::Any, sync::Arc}; -use self::utils::{down_cast_any_ref, ordering_fields}; +use self::utils::down_cast_any_ref; use crate::physical_expr::PhysicalExpr; use crate::sort_expr::{LexOrdering, PhysicalSortExpr}; use crate::utils::reverse_order_bys; @@ -55,6 +55,8 @@ use datafusion_expr::utils::AggregateOrderSensitivity; /// `is_reversed` is used to indicate whether the aggregation is running in reverse order, /// it could be used to hint Accumulator to accumulate in the reversed order, /// you can just set to false if you are not reversing expression +/// +/// You can also create expression by [`AggregateExprBuilder`] #[allow(clippy::too_many_arguments)] pub fn create_aggregate_expr( fun: &AggregateUDF, @@ -66,45 +68,23 @@ pub fn create_aggregate_expr( name: impl Into, ignore_nulls: bool, is_distinct: bool, - is_reversed: bool, ) -> Result> { - debug_assert_eq!(sort_exprs.len(), ordering_req.len()); - - let input_exprs_types = input_phy_exprs - .iter() - .map(|arg| arg.data_type(schema)) - .collect::>>()?; - - check_arg_count( - fun.name(), - &input_exprs_types, - &fun.signature().type_signature, - )?; - - let ordering_types = ordering_req - .iter() - .map(|e| e.expr.data_type(schema)) - .collect::>>()?; - - let ordering_fields = ordering_fields(ordering_req, &ordering_types); - let name = name.into(); - - Ok(Arc::new(AggregateFunctionExpr { - fun: fun.clone(), - args: input_phy_exprs.to_vec(), - logical_args: input_exprs.to_vec(), - data_type: fun.return_type(&input_exprs_types)?, - name, - schema: schema.clone(), - dfschema: DFSchema::empty(), - sort_exprs: sort_exprs.to_vec(), - ordering_req: ordering_req.to_vec(), - ignore_nulls, - ordering_fields, - is_distinct, - input_type: input_exprs_types[0].clone(), - is_reversed, - })) + let mut builder = + AggregateExprBuilder::new(Arc::new(fun.clone()), input_phy_exprs.to_vec()); + builder = builder.sort_exprs(sort_exprs.to_vec()); + builder = builder.order_by(ordering_req.to_vec()); + builder = builder.logical_exprs(input_exprs.to_vec()); + builder = builder.schema(Arc::new(schema.clone())); + builder = builder.name(name); + + if ignore_nulls { + builder = builder.ignore_nulls(); + } + if is_distinct { + builder = builder.distinct(); + } + + builder.build() } #[allow(clippy::too_many_arguments)] @@ -121,44 +101,196 @@ pub fn create_aggregate_expr_with_dfschema( is_distinct: bool, is_reversed: bool, ) -> Result> { - debug_assert_eq!(sort_exprs.len(), ordering_req.len()); - + let mut builder = + AggregateExprBuilder::new(Arc::new(fun.clone()), input_phy_exprs.to_vec()); + builder = builder.sort_exprs(sort_exprs.to_vec()); + builder = builder.order_by(ordering_req.to_vec()); + builder = builder.logical_exprs(input_exprs.to_vec()); + builder = builder.dfschema(dfschema.clone()); let schema: Schema = dfschema.into(); + builder = builder.schema(Arc::new(schema)); + builder = builder.name(name); + + if ignore_nulls { + builder = builder.ignore_nulls(); + } + if is_distinct { + builder = builder.distinct(); + } + if is_reversed { + builder = builder.reversed(); + } + + builder.build() +} + +/// Builder for physical [`AggregateExpr`] +/// +/// `AggregateExpr` contains the information necessary to call +/// an aggregate expression. 
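+///
+/// Every setting starts from a neutral default (empty schema and name, no
+/// ordering, all boolean flags off); callers chain only the setters they
+/// need and then call `build`.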
+#[derive(Debug, Clone)] +pub struct AggregateExprBuilder { + fun: Arc, + /// Physical expressions of the aggregate function + args: Vec>, + /// Logical expressions of the aggregate function, it will be deprecated in + logical_args: Vec, + name: String, + /// Arrow Schema for the aggregate function + schema: SchemaRef, + /// Datafusion Schema for the aggregate function + dfschema: DFSchema, + /// The logical order by expressions, it will be deprecated in + sort_exprs: Vec, + /// The physical order by expressions + ordering_req: LexOrdering, + /// Whether to ignore null values + ignore_nulls: bool, + /// Whether is distinct aggregate function + is_distinct: bool, + /// Whether the expression is reversed + is_reversed: bool, +} + +impl AggregateExprBuilder { + pub fn new(fun: Arc, args: Vec>) -> Self { + Self { + fun, + args, + logical_args: vec![], + name: String::new(), + schema: Arc::new(Schema::empty()), + dfschema: DFSchema::empty(), + sort_exprs: vec![], + ordering_req: vec![], + ignore_nulls: false, + is_distinct: false, + is_reversed: false, + } + } + + pub fn build(self) -> Result> { + let Self { + fun, + args, + logical_args, + name, + schema, + dfschema, + sort_exprs, + ordering_req, + ignore_nulls, + is_distinct, + is_reversed, + } = self; + if args.is_empty() { + return internal_err!("args should not be empty"); + } + + let mut ordering_fields = vec![]; + + debug_assert_eq!(sort_exprs.len(), ordering_req.len()); + if !ordering_req.is_empty() { + let ordering_types = ordering_req + .iter() + .map(|e| e.expr.data_type(&schema)) + .collect::>>()?; + + ordering_fields = utils::ordering_fields(&ordering_req, &ordering_types); + } + + let input_exprs_types = args + .iter() + .map(|arg| arg.data_type(&schema)) + .collect::>>()?; + + check_arg_count( + fun.name(), + &input_exprs_types, + &fun.signature().type_signature, + )?; - let input_exprs_types = input_phy_exprs - .iter() - .map(|arg| arg.data_type(&schema)) - .collect::>>()?; - - check_arg_count( - fun.name(), - &input_exprs_types, - &fun.signature().type_signature, - )?; - - let ordering_types = ordering_req - .iter() - .map(|e| e.expr.data_type(&schema)) - .collect::>>()?; - - let ordering_fields = ordering_fields(ordering_req, &ordering_types); - - Ok(Arc::new(AggregateFunctionExpr { - fun: fun.clone(), - args: input_phy_exprs.to_vec(), - logical_args: input_exprs.to_vec(), - data_type: fun.return_type(&input_exprs_types)?, - name: name.into(), - schema: schema.clone(), - dfschema: dfschema.clone(), - sort_exprs: sort_exprs.to_vec(), - ordering_req: ordering_req.to_vec(), - ignore_nulls, - ordering_fields, - is_distinct, - input_type: input_exprs_types[0].clone(), - is_reversed, - })) + let data_type = fun.return_type(&input_exprs_types)?; + + Ok(Arc::new(AggregateFunctionExpr { + fun: Arc::unwrap_or_clone(fun), + args, + logical_args, + data_type, + name, + schema: Arc::unwrap_or_clone(schema), + dfschema, + sort_exprs, + ordering_req, + ignore_nulls, + ordering_fields, + is_distinct, + input_type: input_exprs_types[0].clone(), + is_reversed, + })) + } + + pub fn name(mut self, name: impl Into) -> Self { + self.name = name.into(); + self + } + + pub fn schema(mut self, schema: SchemaRef) -> Self { + self.schema = schema; + self + } + + pub fn dfschema(mut self, dfschema: DFSchema) -> Self { + self.dfschema = dfschema; + self + } + + pub fn order_by(mut self, order_by: LexOrdering) -> Self { + self.ordering_req = order_by; + self + } + + pub fn reversed(mut self) -> Self { + self.is_reversed = true; + self + } + + pub fn 
with_reversed(mut self, is_reversed: bool) -> Self { + self.is_reversed = is_reversed; + self + } + + pub fn distinct(mut self) -> Self { + self.is_distinct = true; + self + } + + pub fn with_distinct(mut self, is_distinct: bool) -> Self { + self.is_distinct = is_distinct; + self + } + + pub fn ignore_nulls(mut self) -> Self { + self.ignore_nulls = true; + self + } + + pub fn with_ignore_nulls(mut self, ignore_nulls: bool) -> Self { + self.ignore_nulls = ignore_nulls; + self + } + + /// This method will be deprecated in + pub fn sort_exprs(mut self, sort_exprs: Vec) -> Self { + self.sort_exprs = sort_exprs; + self + } + + /// This method will be deprecated in + pub fn logical_exprs(mut self, logical_args: Vec) -> Self { + self.logical_args = logical_args; + self + } } /// An aggregate expression that: diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs index e7cd5cb2725b..d1152038eb2a 100644 --- a/datafusion/physical-plan/src/aggregates/mod.rs +++ b/datafusion/physical-plan/src/aggregates/mod.rs @@ -1211,7 +1211,7 @@ mod tests { use crate::common::collect; use datafusion_physical_expr_common::aggregate::{ - create_aggregate_expr, create_aggregate_expr_with_dfschema, + create_aggregate_expr_with_dfschema, AggregateExprBuilder, }; use datafusion_physical_expr_common::expressions::Literal; use futures::{FutureExt, Stream}; @@ -1351,18 +1351,11 @@ mod tests { ], }; - let aggregates = vec![create_aggregate_expr( - &count_udaf(), - &[lit(1i8)], - &[datafusion_expr::lit(1i8)], - &[], - &[], - &input_schema, - "COUNT(1)", - false, - false, - false, - )?]; + let aggregates = vec![AggregateExprBuilder::new(count_udaf(), vec![lit(1i8)]) + .schema(Arc::clone(&input_schema)) + .name("COUNT(1)") + .logical_exprs(vec![datafusion_expr::lit(1i8)]) + .build()?]; let task_ctx = if spill { new_spill_ctx(4, 1000) @@ -1501,18 +1494,13 @@ mod tests { groups: vec![vec![false]], }; - let aggregates: Vec> = vec![create_aggregate_expr( - &avg_udaf(), - &[col("b", &input_schema)?], - &[datafusion_expr::col("b")], - &[], - &[], - &input_schema, - "AVG(b)", - false, - false, - false, - )?]; + let aggregates: Vec> = + vec![ + AggregateExprBuilder::new(avg_udaf(), vec![col("b", &input_schema)?]) + .schema(Arc::clone(&input_schema)) + .name("AVG(b)") + .build()?, + ]; let task_ctx = if spill { // set to an appropriate value to trigger spill @@ -1803,21 +1791,11 @@ mod tests { } // Median(a) - fn test_median_agg_expr(schema: &Schema) -> Result> { - let args = vec![col("a", schema)?]; - let fun = median_udaf(); - datafusion_physical_expr_common::aggregate::create_aggregate_expr( - &fun, - &args, - &[], - &[], - &[], - schema, - "MEDIAN(a)", - false, - false, - false, - ) + fn test_median_agg_expr(schema: SchemaRef) -> Result> { + AggregateExprBuilder::new(median_udaf(), vec![col("a", &schema)?]) + .schema(schema) + .name("MEDIAN(a)") + .build() } #[tokio::test] @@ -1840,21 +1818,16 @@ mod tests { // something that allocates within the aggregator let aggregates_v0: Vec> = - vec![test_median_agg_expr(&input_schema)?]; + vec![test_median_agg_expr(Arc::clone(&input_schema))?]; // use fast-path in `row_hash.rs`. 
- let aggregates_v2: Vec> = vec![create_aggregate_expr( - &avg_udaf(), - &[col("b", &input_schema)?], - &[datafusion_expr::col("b")], - &[], - &[], - &input_schema, - "AVG(b)", - false, - false, - false, - )?]; + let aggregates_v2: Vec> = + vec![ + AggregateExprBuilder::new(avg_udaf(), vec![col("b", &input_schema)?]) + .schema(Arc::clone(&input_schema)) + .name("AVG(b)") + .build()?, + ]; for (version, groups, aggregates) in [ (0, groups_none, aggregates_v0), @@ -1908,18 +1881,13 @@ mod tests { let groups = PhysicalGroupBy::default(); - let aggregates: Vec> = vec![create_aggregate_expr( - &avg_udaf(), - &[col("a", &schema)?], - &[datafusion_expr::col("a")], - &[], - &[], - &schema, - "AVG(a)", - false, - false, - false, - )?]; + let aggregates: Vec> = + vec![ + AggregateExprBuilder::new(avg_udaf(), vec![col("a", &schema)?]) + .schema(Arc::clone(&schema)) + .name("AVG(a)") + .build()?, + ]; let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 1)); let refs = blocking_exec.refs(); @@ -1953,18 +1921,13 @@ mod tests { let groups = PhysicalGroupBy::new_single(vec![(col("a", &schema)?, "a".to_string())]); - let aggregates: Vec> = vec![create_aggregate_expr( - &avg_udaf(), - &[col("b", &schema)?], - &[datafusion_expr::col("b")], - &[], - &[], - &schema, - "AVG(b)", - false, - false, - false, - )?]; + let aggregates: Vec> = + vec![ + AggregateExprBuilder::new(avg_udaf(), vec![col("b", &schema)?]) + .schema(Arc::clone(&schema)) + .name("AVG(b)") + .build()?, + ]; let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 1)); let refs = blocking_exec.refs(); @@ -2388,18 +2351,11 @@ mod tests { ], ); - let aggregates: Vec> = vec![create_aggregate_expr( - count_udaf().as_ref(), - &[lit(1)], - &[datafusion_expr::lit(1)], - &[], - &[], - schema.as_ref(), - "1", - false, - false, - false, - )?]; + let aggregates: Vec> = + vec![AggregateExprBuilder::new(count_udaf(), vec![lit(1)]) + .schema(Arc::clone(&schema)) + .name("1") + .build()?]; let input_batches = (0..4) .map(|_| { diff --git a/datafusion/physical-plan/src/windows/mod.rs b/datafusion/physical-plan/src/windows/mod.rs index 959796489c19..ffe558e21583 100644 --- a/datafusion/physical-plan/src/windows/mod.rs +++ b/datafusion/physical-plan/src/windows/mod.rs @@ -26,16 +26,16 @@ use crate::{ cume_dist, dense_rank, lag, lead, percent_rank, rank, Literal, NthValue, Ntile, PhysicalSortExpr, RowNumber, }, - udaf, ExecutionPlan, ExecutionPlanProperties, InputOrderMode, PhysicalExpr, + ExecutionPlan, ExecutionPlanProperties, InputOrderMode, PhysicalExpr, }; use arrow::datatypes::Schema; use arrow_schema::{DataType, Field, SchemaRef}; -use datafusion_common::{exec_err, Column, DataFusionError, Result, ScalarValue}; -use datafusion_expr::Expr; +use datafusion_common::{exec_err, DataFusionError, Result, ScalarValue}; +use datafusion_expr::{col, Expr, SortExpr}; use datafusion_expr::{ - BuiltInWindowFunction, PartitionEvaluator, SortExpr, WindowFrame, - WindowFunctionDefinition, WindowUDF, + BuiltInWindowFunction, PartitionEvaluator, WindowFrame, WindowFunctionDefinition, + WindowUDF, }; use datafusion_physical_expr::equivalence::collapse_lex_req; use datafusion_physical_expr::{ @@ -44,6 +44,7 @@ use datafusion_physical_expr::{ AggregateExpr, ConstExpr, EquivalenceProperties, LexOrdering, PhysicalSortRequirement, }; +use datafusion_physical_expr_common::aggregate::AggregateExprBuilder; use itertools::Itertools; mod bounded_window_agg_exec; @@ -95,7 +96,7 @@ pub fn create_window_expr( fun: &WindowFunctionDefinition, name: String, args: 
&[Arc], - logical_args: &[Expr], + _logical_args: &[Expr], partition_by: &[Arc], order_by: &[PhysicalSortExpr], window_frame: Arc, @@ -129,7 +130,6 @@ pub fn create_window_expr( )) } WindowFunctionDefinition::AggregateUDF(fun) => { - // TODO: Ordering not supported for Window UDFs yet // Convert `Vec` into `Vec` let sort_exprs = order_by .iter() @@ -137,28 +137,20 @@ pub fn create_window_expr( let field_name = expr.to_string(); let field_name = field_name.split('@').next().unwrap_or(&field_name); Expr::Sort(SortExpr { - expr: Box::new(Expr::Column(Column::new( - None::, - field_name, - ))), + expr: Box::new(col(field_name)), asc: !options.descending, nulls_first: options.nulls_first, }) }) .collect::>(); - let aggregate = udaf::create_aggregate_expr( - fun.as_ref(), - args, - logical_args, - &sort_exprs, - order_by, - input_schema, - name, - ignore_nulls, - false, - false, - )?; + let aggregate = AggregateExprBuilder::new(Arc::clone(fun), args.to_vec()) + .schema(Arc::new(input_schema.clone())) + .name(name) + .order_by(order_by.to_vec()) + .sort_exprs(sort_exprs) + .with_ignore_nulls(ignore_nulls) + .build()?; window_expr_from_aggregate_expr( partition_by, order_by, @@ -166,6 +158,7 @@ pub fn create_window_expr( aggregate, ) } + // TODO: Ordering not supported for Window UDFs yet WindowFunctionDefinition::WindowUDF(fun) => Arc::new(BuiltInWindowExpr::new( create_udwf_window_expr(fun, args, input_schema, name)?, partition_by, diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs index 8c9e5bbd0e95..5c4d41f0eca6 100644 --- a/datafusion/proto/src/physical_plan/mod.rs +++ b/datafusion/proto/src/physical_plan/mod.rs @@ -18,6 +18,7 @@ use std::fmt::Debug; use std::sync::Arc; +use datafusion::physical_expr_common::aggregate::AggregateExprBuilder; use prost::bytes::BufMut; use prost::Message; @@ -58,7 +59,7 @@ use datafusion::physical_plan::sorts::sort_preserving_merge::SortPreservingMerge use datafusion::physical_plan::union::{InterleaveExec, UnionExec}; use datafusion::physical_plan::windows::{BoundedWindowAggExec, WindowAggExec}; use datafusion::physical_plan::{ - udaf, AggregateExpr, ExecutionPlan, InputOrderMode, PhysicalExpr, WindowExpr, + AggregateExpr, ExecutionPlan, InputOrderMode, PhysicalExpr, WindowExpr, }; use datafusion_common::{internal_err, not_impl_err, DataFusionError, Result}; use datafusion_expr::{AggregateUDF, ScalarUDF}; @@ -501,13 +502,9 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { None => registry.udaf(udaf_name)? }; - // TODO: 'logical_exprs' is not supported for UDAF yet. - // approx_percentile_cont and approx_percentile_cont_weight are not supported for UDAF from protobuf yet. - let logical_exprs = &[]; + // TODO: approx_percentile_cont and approx_percentile_cont_weight are not supported for UDAF from protobuf yet. 
// TODO: `order by` is not supported for UDAF yet - let sort_exprs = &[]; - let ordering_req = &[]; - udaf::create_aggregate_expr(agg_udf.as_ref(), &input_phy_expr, logical_exprs, sort_exprs, ordering_req, &physical_schema, name, agg_node.ignore_nulls, agg_node.distinct, false) + AggregateExprBuilder::new(agg_udf, input_phy_expr).schema(Arc::clone(&physical_schema)).name(name).with_ignore_nulls(agg_node.ignore_nulls).with_distinct(agg_node.distinct).build() } } }).transpose()?.ok_or_else(|| { diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index 31ed0837d2f5..3ddc122e3de2 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -24,6 +24,7 @@ use std::vec; use arrow::array::RecordBatch; use arrow::csv::WriterBuilder; +use datafusion::physical_expr_common::aggregate::AggregateExprBuilder; use prost::Message; use datafusion::arrow::array::ArrayRef; @@ -64,7 +65,6 @@ use datafusion::physical_plan::placeholder_row::PlaceholderRowExec; use datafusion::physical_plan::projection::ProjectionExec; use datafusion::physical_plan::repartition::RepartitionExec; use datafusion::physical_plan::sorts::sort::SortExec; -use datafusion::physical_plan::udaf::create_aggregate_expr; use datafusion::physical_plan::union::{InterleaveExec, UnionExec}; use datafusion::physical_plan::windows::{ BuiltInWindowExpr, PlainAggregateWindowExpr, WindowAggExec, @@ -86,7 +86,7 @@ use datafusion_expr::{ }; use datafusion_functions_aggregate::average::avg_udaf; use datafusion_functions_aggregate::nth_value::nth_value_udaf; -use datafusion_functions_aggregate::string_agg::StringAgg; +use datafusion_functions_aggregate::string_agg::string_agg_udaf; use datafusion_proto::physical_plan::{ AsExecutionPlan, DefaultPhysicalExtensionCodec, PhysicalExtensionCodec, }; @@ -291,18 +291,13 @@ fn roundtrip_window() -> Result<()> { )); let plain_aggr_window_expr = Arc::new(PlainAggregateWindowExpr::new( - create_aggregate_expr( - &avg_udaf(), - &[cast(col("b", &schema)?, &schema, DataType::Float64)?], - &[], - &[], - &[], - &schema, - "avg(b)", - false, - false, - false, - )?, + AggregateExprBuilder::new( + avg_udaf(), + vec![cast(col("b", &schema)?, &schema, DataType::Float64)?], + ) + .schema(Arc::clone(&schema)) + .name("avg(b)") + .build()?, &[], &[], Arc::new(WindowFrame::new(None)), @@ -315,18 +310,10 @@ fn roundtrip_window() -> Result<()> { ); let args = vec![cast(col("a", &schema)?, &schema, DataType::Float64)?]; - let sum_expr = create_aggregate_expr( - &sum_udaf(), - &args, - &[], - &[], - &[], - &schema, - "SUM(a) RANGE BETWEEN CURRENT ROW AND UNBOUNDED PRECEEDING", - false, - false, - false, - )?; + let sum_expr = AggregateExprBuilder::new(sum_udaf(), args) + .schema(Arc::clone(&schema)) + .name("SUM(a) RANGE BETWEEN CURRENT ROW AND UNBOUNDED PRECEEDING") + .build()?; let sliding_aggr_window_expr = Arc::new(SlidingAggregateWindowExpr::new( sum_expr, @@ -357,49 +344,28 @@ fn rountrip_aggregate() -> Result<()> { let groups: Vec<(Arc, String)> = vec![(col("a", &schema)?, "unused".to_string())]; + let avg_expr = AggregateExprBuilder::new(avg_udaf(), vec![col("b", &schema)?]) + .schema(Arc::clone(&schema)) + .name("AVG(b)") + .build()?; + let nth_expr = + AggregateExprBuilder::new(nth_value_udaf(), vec![col("b", &schema)?, lit(1u64)]) + .schema(Arc::clone(&schema)) + .name("NTH_VALUE(b, 1)") + .build()?; + let str_agg_expr = + AggregateExprBuilder::new(string_agg_udaf(), 
vec![col("b", &schema)?, lit(1u64)]) + .schema(Arc::clone(&schema)) + .name("NTH_VALUE(b, 1)") + .build()?; + let test_cases: Vec>> = vec![ // AVG - vec![create_aggregate_expr( - &avg_udaf(), - &[col("b", &schema)?], - &[], - &[], - &[], - &schema, - "AVG(b)", - false, - false, - false, - )?], + vec![avg_expr], // NTH_VALUE - vec![create_aggregate_expr( - &nth_value_udaf(), - &[col("b", &schema)?, lit(1u64)], - &[], - &[], - &[], - &schema, - "NTH_VALUE(b, 1)", - false, - false, - false, - )?], + vec![nth_expr], // STRING_AGG - vec![create_aggregate_expr( - &AggregateUDF::new_from_impl(StringAgg::new()), - &[ - cast(col("b", &schema)?, &schema, DataType::Utf8)?, - lit(ScalarValue::Utf8(Some(",".to_string()))), - ], - &[], - &[], - &[], - &schema, - "STRING_AGG(name, ',')", - false, - false, - false, - )?], + vec![str_agg_expr], ]; for aggregates in test_cases { @@ -426,18 +392,13 @@ fn rountrip_aggregate_with_limit() -> Result<()> { let groups: Vec<(Arc, String)> = vec![(col("a", &schema)?, "unused".to_string())]; - let aggregates: Vec> = vec![create_aggregate_expr( - &avg_udaf(), - &[col("b", &schema)?], - &[], - &[], - &[], - &schema, - "AVG(b)", - false, - false, - false, - )?]; + let aggregates: Vec> = + vec![ + AggregateExprBuilder::new(avg_udaf(), vec![col("b", &schema)?]) + .schema(Arc::clone(&schema)) + .name("AVG(b)") + .build()?, + ]; let agg = AggregateExec::try_new( AggregateMode::Final, @@ -498,18 +459,13 @@ fn roundtrip_aggregate_udaf() -> Result<()> { let groups: Vec<(Arc, String)> = vec![(col("a", &schema)?, "unused".to_string())]; - let aggregates: Vec> = vec![create_aggregate_expr( - &udaf, - &[col("b", &schema)?], - &[], - &[], - &[], - &schema, - "example_agg", - false, - false, - false, - )?]; + let aggregates: Vec> = + vec![ + AggregateExprBuilder::new(Arc::new(udaf), vec![col("b", &schema)?]) + .schema(Arc::clone(&schema)) + .name("example_agg") + .build()?, + ]; roundtrip_test_with_context( Arc::new(AggregateExec::try_new( @@ -994,21 +950,16 @@ fn roundtrip_aggregate_udf_extension_codec() -> Result<()> { DataType::Int64, )); - let udaf = AggregateUDF::from(MyAggregateUDF::new("result".to_string())); - let aggr_args: [Arc; 1] = - [Arc::new(Literal::new(ScalarValue::from(42)))]; - let aggr_expr = create_aggregate_expr( - &udaf, - &aggr_args, - &[], - &[], - &[], - &schema, - "aggregate_udf", - false, - false, - false, - )?; + let udaf = Arc::new(AggregateUDF::from(MyAggregateUDF::new( + "result".to_string(), + ))); + let aggr_args: Vec> = + vec![Arc::new(Literal::new(ScalarValue::from(42)))]; + + let aggr_expr = AggregateExprBuilder::new(Arc::clone(&udaf), aggr_args.clone()) + .schema(Arc::clone(&schema)) + .name("aggregate_udf") + .build()?; let filter = Arc::new(FilterExec::try_new( Arc::new(BinaryExpr::new( @@ -1030,18 +981,12 @@ fn roundtrip_aggregate_udf_extension_codec() -> Result<()> { vec![col("author", &schema)?], )?); - let aggr_expr = create_aggregate_expr( - &udaf, - &aggr_args, - &[], - &[], - &[], - &schema, - "aggregate_udf", - true, - true, - false, - )?; + let aggr_expr = AggregateExprBuilder::new(udaf, aggr_args.clone()) + .schema(Arc::clone(&schema)) + .name("aggregate_udf") + .distinct() + .ignore_nulls() + .build()?; let aggregate = Arc::new(AggregateExec::try_new( AggregateMode::Final, From 5901df58b21b8b4e36011744e7ddc17bcb6a37b3 Mon Sep 17 00:00:00 2001 From: Trent Hauck Date: Wed, 24 Jul 2024 12:21:13 -0700 Subject: [PATCH 08/17] feat: add bounds for unary math scalar functions (#11584) * feat: unary udf function bounds * feat: add bounds for 
more types * feat: remove eprint * fix: add missing bounds file * tests: add tests for unary udf bounds * tests: test f32 and f64 * build: remove unrelated changes * refactor: better unbounded func name * tests: fix tests * refactor: use data_type method * refactor: add more useful intervals to Interval * refactor: use typed bounds for (-inf, inf) * refactor: inf to unbounded * refactor: add lower/upper pi bounds * refactor: consts to consts module * fix: add missing file * fix: docstring typo * refactor: remove unused signum bounds --- datafusion/common/src/scalar/consts.rs | 44 +++ datafusion/common/src/scalar/mod.rs | 119 +++++++ datafusion/expr/src/interval_arithmetic.rs | 32 ++ datafusion/functions/src/macros.rs | 7 +- datafusion/functions/src/math/bounds.rs | 108 +++++++ datafusion/functions/src/math/mod.rs | 302 ++++++++++++++++-- datafusion/functions/src/math/monotonicity.rs | 17 +- 7 files changed, 595 insertions(+), 34 deletions(-) create mode 100644 datafusion/common/src/scalar/consts.rs create mode 100644 datafusion/functions/src/math/bounds.rs diff --git a/datafusion/common/src/scalar/consts.rs b/datafusion/common/src/scalar/consts.rs new file mode 100644 index 000000000000..efcde651841b --- /dev/null +++ b/datafusion/common/src/scalar/consts.rs @@ -0,0 +1,44 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// Constants defined for scalar construction. + +// PI ~ 3.1415927 in f32 +#[allow(clippy::approx_constant)] +pub(super) const PI_UPPER_F32: f32 = 3.141593_f32; + +// PI ~ 3.141592653589793 in f64 +pub(super) const PI_UPPER_F64: f64 = 3.141592653589794_f64; + +// -PI ~ -3.1415927 in f32 +#[allow(clippy::approx_constant)] +pub(super) const NEGATIVE_PI_LOWER_F32: f32 = -3.141593_f32; + +// -PI ~ -3.141592653589793 in f64 +pub(super) const NEGATIVE_PI_LOWER_F64: f64 = -3.141592653589794_f64; + +// PI / 2 ~ 1.5707964 in f32 +pub(super) const FRAC_PI_2_UPPER_F32: f32 = 1.5707965_f32; + +// PI / 2 ~ 1.5707963267948966 in f64 +pub(super) const FRAC_PI_2_UPPER_F64: f64 = 1.5707963267948967_f64; + +// -PI / 2 ~ -1.5707964 in f32 +pub(super) const NEGATIVE_FRAC_PI_2_LOWER_F32: f32 = -1.5707965_f32; + +// -PI / 2 ~ -1.5707963267948966 in f64 +pub(super) const NEGATIVE_FRAC_PI_2_LOWER_F64: f64 = -1.5707963267948967_f64; diff --git a/datafusion/common/src/scalar/mod.rs b/datafusion/common/src/scalar/mod.rs index 92ed897e7185..286df339adcf 100644 --- a/datafusion/common/src/scalar/mod.rs +++ b/datafusion/common/src/scalar/mod.rs @@ -17,7 +17,9 @@ //! 
[`ScalarValue`]: stores single values
+mod consts;
 mod struct_builder;
+
 use std::borrow::Borrow;
 use std::cmp::Ordering;
 use std::collections::{HashSet, VecDeque};
@@ -1007,6 +1009,123 @@ impl ScalarValue {
         }
     }
 
+    /// Returns a [`ScalarValue`] representing PI
+    pub fn new_pi(datatype: &DataType) -> Result<ScalarValue> {
+        match datatype {
+            DataType::Float32 => Ok(ScalarValue::from(std::f32::consts::PI)),
+            DataType::Float64 => Ok(ScalarValue::from(std::f64::consts::PI)),
+            _ => _internal_err!("PI is not supported for data type: {:?}", datatype),
+        }
+    }
+
+    /// Returns a [`ScalarValue`] representing PI's upper bound
+    pub fn new_pi_upper(datatype: &DataType) -> Result<ScalarValue> {
+        // TODO: replace the constants with next_up/next_down when
+        // they are stabilized: https://doc.rust-lang.org/std/primitive.f64.html#method.next_up
+        match datatype {
+            DataType::Float32 => Ok(ScalarValue::from(consts::PI_UPPER_F32)),
+            DataType::Float64 => Ok(ScalarValue::from(consts::PI_UPPER_F64)),
+            _ => {
+                _internal_err!("PI_UPPER is not supported for data type: {:?}", datatype)
+            }
+        }
+    }
+
+    /// Returns a [`ScalarValue`] representing -PI's lower bound
+    pub fn new_negative_pi_lower(datatype: &DataType) -> Result<ScalarValue> {
+        match datatype {
+            DataType::Float32 => Ok(ScalarValue::from(consts::NEGATIVE_PI_LOWER_F32)),
+            DataType::Float64 => Ok(ScalarValue::from(consts::NEGATIVE_PI_LOWER_F64)),
+            _ => {
+                _internal_err!("-PI_LOWER is not supported for data type: {:?}", datatype)
+            }
+        }
+    }
+
+    /// Returns a [`ScalarValue`] representing FRAC_PI_2's upper bound
+    pub fn new_frac_pi_2_upper(datatype: &DataType) -> Result<ScalarValue> {
+        match datatype {
+            DataType::Float32 => Ok(ScalarValue::from(consts::FRAC_PI_2_UPPER_F32)),
+            DataType::Float64 => Ok(ScalarValue::from(consts::FRAC_PI_2_UPPER_F64)),
+            _ => {
+                _internal_err!(
+                    "PI_UPPER/2 is not supported for data type: {:?}",
+                    datatype
+                )
+            }
+        }
+    }
+
+    /// Returns a [`ScalarValue`] representing -FRAC_PI_2's lower bound
+    pub fn new_neg_frac_pi_2_lower(datatype: &DataType) -> Result<ScalarValue> {
+        match datatype {
+            DataType::Float32 => {
+                Ok(ScalarValue::from(consts::NEGATIVE_FRAC_PI_2_LOWER_F32))
+            }
+            DataType::Float64 => {
+                Ok(ScalarValue::from(consts::NEGATIVE_FRAC_PI_2_LOWER_F64))
+            }
+            _ => {
+                _internal_err!(
+                    "-PI/2_LOWER is not supported for data type: {:?}",
+                    datatype
+                )
+            }
+        }
+    }
+
+    /// Returns a [`ScalarValue`] representing -PI
+    pub fn new_negative_pi(datatype: &DataType) -> Result<ScalarValue> {
+        match datatype {
+            DataType::Float32 => Ok(ScalarValue::from(-std::f32::consts::PI)),
+            DataType::Float64 => Ok(ScalarValue::from(-std::f64::consts::PI)),
+            _ => _internal_err!("-PI is not supported for data type: {:?}", datatype),
+        }
+    }
+
+    /// Returns a [`ScalarValue`] representing PI/2
+    pub fn new_frac_pi_2(datatype: &DataType) -> Result<ScalarValue> {
+        match datatype {
+            DataType::Float32 => Ok(ScalarValue::from(std::f32::consts::FRAC_PI_2)),
+            DataType::Float64 => Ok(ScalarValue::from(std::f64::consts::FRAC_PI_2)),
+            _ => _internal_err!("PI/2 is not supported for data type: {:?}", datatype),
+        }
+    }
+
+    /// Returns a [`ScalarValue`] representing -PI/2
+    pub fn new_neg_frac_pi_2(datatype: &DataType) -> Result<ScalarValue> {
+        match datatype {
+            DataType::Float32 => Ok(ScalarValue::from(-std::f32::consts::FRAC_PI_2)),
+            DataType::Float64 => Ok(ScalarValue::from(-std::f64::consts::FRAC_PI_2)),
+            _ => _internal_err!("-PI/2 is not supported for data type: {:?}", datatype),
+        }
+    }
+
+    /// Returns a [`ScalarValue`] representing infinity
+    pub fn new_infinity(datatype: &DataType) -> Result<ScalarValue> {
+        match datatype {
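+            // Only the IEEE 754 floating-point types have an infinity value;
+            // all other types fall through to the internal error arm below.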
DataType::Float32 => Ok(ScalarValue::from(f32::INFINITY)), + DataType::Float64 => Ok(ScalarValue::from(f64::INFINITY)), + _ => { + _internal_err!("Infinity is not supported for data type: {:?}", datatype) + } + } + } + + /// Returns a [`ScalarValue`] representing negative infinity + pub fn new_neg_infinity(datatype: &DataType) -> Result { + match datatype { + DataType::Float32 => Ok(ScalarValue::from(f32::NEG_INFINITY)), + DataType::Float64 => Ok(ScalarValue::from(f64::NEG_INFINITY)), + _ => { + _internal_err!( + "Negative Infinity is not supported for data type: {:?}", + datatype + ) + } + } + } + /// Create a zero value in the given type. pub fn new_zero(datatype: &DataType) -> Result { Ok(match datatype { diff --git a/datafusion/expr/src/interval_arithmetic.rs b/datafusion/expr/src/interval_arithmetic.rs index 18f92334ff14..d0dd418c78e7 100644 --- a/datafusion/expr/src/interval_arithmetic.rs +++ b/datafusion/expr/src/interval_arithmetic.rs @@ -332,6 +332,38 @@ impl Interval { Ok(Self::new(unbounded_endpoint.clone(), unbounded_endpoint)) } + /// Creates an interval between -1 to 1. + pub fn make_symmetric_unit_interval(data_type: &DataType) -> Result { + Self::try_new( + ScalarValue::new_negative_one(data_type)?, + ScalarValue::new_one(data_type)?, + ) + } + + /// Create an interval from -π to π. + pub fn make_symmetric_pi_interval(data_type: &DataType) -> Result { + Self::try_new( + ScalarValue::new_negative_pi_lower(data_type)?, + ScalarValue::new_pi_upper(data_type)?, + ) + } + + /// Create an interval from -π/2 to π/2. + pub fn make_symmetric_half_pi_interval(data_type: &DataType) -> Result { + Self::try_new( + ScalarValue::new_neg_frac_pi_2_lower(data_type)?, + ScalarValue::new_frac_pi_2_upper(data_type)?, + ) + } + + /// Create an interval from 0 to infinity. + pub fn make_non_negative_infinity_interval(data_type: &DataType) -> Result { + Self::try_new( + ScalarValue::new_zero(data_type)?, + ScalarValue::try_from(data_type)?, + ) + } + /// Returns a reference to the lower bound. pub fn lower(&self) -> &ScalarValue { &self.lower diff --git a/datafusion/functions/src/macros.rs b/datafusion/functions/src/macros.rs index cae689b3e0cb..e26c94e1bb79 100644 --- a/datafusion/functions/src/macros.rs +++ b/datafusion/functions/src/macros.rs @@ -162,7 +162,7 @@ macro_rules! downcast_arg { /// $UNARY_FUNC: the unary function to apply to the argument /// $OUTPUT_ORDERING: the output ordering calculation method of the function macro_rules! make_math_unary_udf { - ($UDF:ident, $GNAME:ident, $NAME:ident, $UNARY_FUNC:ident, $OUTPUT_ORDERING:expr) => { + ($UDF:ident, $GNAME:ident, $NAME:ident, $UNARY_FUNC:ident, $OUTPUT_ORDERING:expr, $EVALUATE_BOUNDS:expr) => { make_udf_function!($NAME::$UDF, $GNAME, $NAME); mod $NAME { @@ -172,6 +172,7 @@ macro_rules! make_math_unary_udf { use arrow::array::{ArrayRef, Float32Array, Float64Array}; use arrow::datatypes::DataType; use datafusion_common::{exec_err, DataFusionError, Result}; + use datafusion_expr::interval_arithmetic::Interval; use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; @@ -222,6 +223,10 @@ macro_rules! 
make_math_unary_udf { $OUTPUT_ORDERING(input) } + fn evaluate_bounds(&self, inputs: &[&Interval]) -> Result { + $EVALUATE_BOUNDS(inputs) + } + fn invoke(&self, args: &[ColumnarValue]) -> Result { let args = ColumnarValue::values_to_arrays(args)?; diff --git a/datafusion/functions/src/math/bounds.rs b/datafusion/functions/src/math/bounds.rs new file mode 100644 index 000000000000..894d2bded5eb --- /dev/null +++ b/datafusion/functions/src/math/bounds.rs @@ -0,0 +1,108 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion_common::ScalarValue; +use datafusion_expr::interval_arithmetic::Interval; + +pub(super) fn unbounded_bounds(input: &[&Interval]) -> crate::Result { + let data_type = input[0].data_type(); + + Interval::make_unbounded(&data_type) +} + +pub(super) fn sin_bounds(input: &[&Interval]) -> crate::Result { + // sin(x) is bounded by [-1, 1] + let data_type = input[0].data_type(); + + Interval::make_symmetric_unit_interval(&data_type) +} + +pub(super) fn asin_bounds(input: &[&Interval]) -> crate::Result { + // asin(x) is bounded by [-π/2, π/2] + let data_type = input[0].data_type(); + + Interval::make_symmetric_half_pi_interval(&data_type) +} + +pub(super) fn atan_bounds(input: &[&Interval]) -> crate::Result { + // atan(x) is bounded by [-π/2, π/2] + let data_type = input[0].data_type(); + + Interval::make_symmetric_half_pi_interval(&data_type) +} + +pub(super) fn acos_bounds(input: &[&Interval]) -> crate::Result { + // acos(x) is bounded by [0, π] + let data_type = input[0].data_type(); + + Interval::try_new( + ScalarValue::new_zero(&data_type)?, + ScalarValue::new_pi_upper(&data_type)?, + ) +} + +pub(super) fn acosh_bounds(input: &[&Interval]) -> crate::Result { + // acosh(x) is bounded by [0, ∞) + let data_type = input[0].data_type(); + + Interval::make_non_negative_infinity_interval(&data_type) +} + +pub(super) fn cos_bounds(input: &[&Interval]) -> crate::Result { + // cos(x) is bounded by [-1, 1] + let data_type = input[0].data_type(); + + Interval::make_symmetric_unit_interval(&data_type) +} + +pub(super) fn cosh_bounds(input: &[&Interval]) -> crate::Result { + // cosh(x) is bounded by [1, ∞) + let data_type = input[0].data_type(); + + Interval::try_new( + ScalarValue::new_one(&data_type)?, + ScalarValue::try_from(&data_type)?, + ) +} + +pub(super) fn exp_bounds(input: &[&Interval]) -> crate::Result { + // exp(x) is bounded by [0, ∞) + let data_type = input[0].data_type(); + + Interval::make_non_negative_infinity_interval(&data_type) +} + +pub(super) fn radians_bounds(input: &[&Interval]) -> crate::Result { + // radians(x) is bounded by (-π, π) + let data_type = input[0].data_type(); + + Interval::make_symmetric_pi_interval(&data_type) +} + +pub(super) fn sqrt_bounds(input: &[&Interval]) -> crate::Result { + // 
sqrt(x) is bounded by [0, ∞) + let data_type = input[0].data_type(); + + Interval::make_non_negative_infinity_interval(&data_type) +} + +pub(super) fn tanh_bounds(input: &[&Interval]) -> crate::Result { + // tanh(x) is bounded by (-1, 1) + let data_type = input[0].data_type(); + + Interval::make_symmetric_unit_interval(&data_type) +} diff --git a/datafusion/functions/src/math/mod.rs b/datafusion/functions/src/math/mod.rs index 3b32a158b884..1e41fff289a4 100644 --- a/datafusion/functions/src/math/mod.rs +++ b/datafusion/functions/src/math/mod.rs @@ -22,6 +22,7 @@ use datafusion_expr::ScalarUDF; use std::sync::Arc; pub mod abs; +pub mod bounds; pub mod cot; pub mod factorial; pub mod gcd; @@ -40,36 +41,142 @@ pub mod trunc; // Create UDFs make_udf_function!(abs::AbsFunc, ABS, abs); -make_math_unary_udf!(AcosFunc, ACOS, acos, acos, super::acos_order); -make_math_unary_udf!(AcoshFunc, ACOSH, acosh, acosh, super::acosh_order); -make_math_unary_udf!(AsinFunc, ASIN, asin, asin, super::asin_order); -make_math_unary_udf!(AsinhFunc, ASINH, asinh, asinh, super::asinh_order); -make_math_unary_udf!(AtanFunc, ATAN, atan, atan, super::atan_order); -make_math_unary_udf!(AtanhFunc, ATANH, atanh, atanh, super::atanh_order); +make_math_unary_udf!( + AcosFunc, + ACOS, + acos, + acos, + super::acos_order, + super::bounds::acos_bounds +); +make_math_unary_udf!( + AcoshFunc, + ACOSH, + acosh, + acosh, + super::acosh_order, + super::bounds::acosh_bounds +); +make_math_unary_udf!( + AsinFunc, + ASIN, + asin, + asin, + super::asin_order, + super::bounds::asin_bounds +); +make_math_unary_udf!( + AsinhFunc, + ASINH, + asinh, + asinh, + super::asinh_order, + super::bounds::unbounded_bounds +); +make_math_unary_udf!( + AtanFunc, + ATAN, + atan, + atan, + super::atan_order, + super::bounds::atan_bounds +); +make_math_unary_udf!( + AtanhFunc, + ATANH, + atanh, + atanh, + super::atanh_order, + super::bounds::unbounded_bounds +); make_math_binary_udf!(Atan2, ATAN2, atan2, atan2, super::atan2_order); -make_math_unary_udf!(CbrtFunc, CBRT, cbrt, cbrt, super::cbrt_order); -make_math_unary_udf!(CeilFunc, CEIL, ceil, ceil, super::ceil_order); -make_math_unary_udf!(CosFunc, COS, cos, cos, super::cos_order); -make_math_unary_udf!(CoshFunc, COSH, cosh, cosh, super::cosh_order); +make_math_unary_udf!( + CbrtFunc, + CBRT, + cbrt, + cbrt, + super::cbrt_order, + super::bounds::unbounded_bounds +); +make_math_unary_udf!( + CeilFunc, + CEIL, + ceil, + ceil, + super::ceil_order, + super::bounds::unbounded_bounds +); +make_math_unary_udf!( + CosFunc, + COS, + cos, + cos, + super::cos_order, + super::bounds::cos_bounds +); +make_math_unary_udf!( + CoshFunc, + COSH, + cosh, + cosh, + super::cosh_order, + super::bounds::cosh_bounds +); make_udf_function!(cot::CotFunc, COT, cot); make_math_unary_udf!( DegreesFunc, DEGREES, degrees, to_degrees, - super::degrees_order + super::degrees_order, + super::bounds::unbounded_bounds +); +make_math_unary_udf!( + ExpFunc, + EXP, + exp, + exp, + super::exp_order, + super::bounds::exp_bounds ); -make_math_unary_udf!(ExpFunc, EXP, exp, exp, super::exp_order); make_udf_function!(factorial::FactorialFunc, FACTORIAL, factorial); -make_math_unary_udf!(FloorFunc, FLOOR, floor, floor, super::floor_order); +make_math_unary_udf!( + FloorFunc, + FLOOR, + floor, + floor, + super::floor_order, + super::bounds::unbounded_bounds +); make_udf_function!(log::LogFunc, LOG, log); make_udf_function!(gcd::GcdFunc, GCD, gcd); make_udf_function!(nans::IsNanFunc, ISNAN, isnan); make_udf_function!(iszero::IsZeroFunc, ISZERO, 
iszero); make_udf_function!(lcm::LcmFunc, LCM, lcm); -make_math_unary_udf!(LnFunc, LN, ln, ln, super::ln_order); -make_math_unary_udf!(Log2Func, LOG2, log2, log2, super::log2_order); -make_math_unary_udf!(Log10Func, LOG10, log10, log10, super::log10_order); +make_math_unary_udf!( + LnFunc, + LN, + ln, + ln, + super::ln_order, + super::bounds::unbounded_bounds +); +make_math_unary_udf!( + Log2Func, + LOG2, + log2, + log2, + super::log2_order, + super::bounds::unbounded_bounds +); +make_math_unary_udf!( + Log10Func, + LOG10, + log10, + log10, + super::log10_order, + super::bounds::unbounded_bounds +); make_udf_function!(nanvl::NanvlFunc, NANVL, nanvl); make_udf_function!(pi::PiFunc, PI, pi); make_udf_function!(power::PowerFunc, POWER, power); @@ -78,16 +185,52 @@ make_math_unary_udf!( RADIANS, radians, to_radians, - super::radians_order + super::radians_order, + super::bounds::radians_bounds ); make_udf_function!(random::RandomFunc, RANDOM, random); make_udf_function!(round::RoundFunc, ROUND, round); make_udf_function!(signum::SignumFunc, SIGNUM, signum); -make_math_unary_udf!(SinFunc, SIN, sin, sin, super::sin_order); -make_math_unary_udf!(SinhFunc, SINH, sinh, sinh, super::sinh_order); -make_math_unary_udf!(SqrtFunc, SQRT, sqrt, sqrt, super::sqrt_order); -make_math_unary_udf!(TanFunc, TAN, tan, tan, super::tan_order); -make_math_unary_udf!(TanhFunc, TANH, tanh, tanh, super::tanh_order); +make_math_unary_udf!( + SinFunc, + SIN, + sin, + sin, + super::sin_order, + super::bounds::sin_bounds +); +make_math_unary_udf!( + SinhFunc, + SINH, + sinh, + sinh, + super::sinh_order, + super::bounds::unbounded_bounds +); +make_math_unary_udf!( + SqrtFunc, + SQRT, + sqrt, + sqrt, + super::sqrt_order, + super::bounds::sqrt_bounds +); +make_math_unary_udf!( + TanFunc, + TAN, + tan, + tan, + super::tan_order, + super::bounds::unbounded_bounds +); +make_math_unary_udf!( + TanhFunc, + TANH, + tanh, + tanh, + super::tanh_order, + super::bounds::tanh_bounds +); make_udf_function!(trunc::TruncFunc, TRUNC, trunc); pub mod expr_fn { @@ -175,3 +318,118 @@ pub fn functions() -> Vec> { trunc(), ] } + +#[cfg(test)] +mod tests { + use arrow::datatypes::DataType; + use datafusion_common::ScalarValue; + use datafusion_expr::interval_arithmetic::Interval; + + fn unbounded_interval(data_type: &DataType) -> Interval { + Interval::make_unbounded(data_type).unwrap() + } + + fn one_to_inf_interval(data_type: &DataType) -> Interval { + Interval::try_new( + ScalarValue::new_one(data_type).unwrap(), + ScalarValue::try_from(data_type).unwrap(), + ) + .unwrap() + } + + fn zero_to_pi_interval(data_type: &DataType) -> Interval { + Interval::try_new( + ScalarValue::new_zero(data_type).unwrap(), + ScalarValue::new_pi_upper(data_type).unwrap(), + ) + .unwrap() + } + + fn assert_udf_evaluates_to_bounds( + udf: &datafusion_expr::ScalarUDF, + interval: Interval, + expected: Interval, + ) { + let input = vec![&interval]; + let result = udf.evaluate_bounds(&input).unwrap(); + assert_eq!( + result, + expected, + "Bounds check failed on UDF: {:?}", + udf.name() + ); + } + + #[test] + fn test_cases() -> crate::Result<()> { + let datatypes = [DataType::Float32, DataType::Float64]; + let cases = datatypes + .iter() + .flat_map(|data_type| { + vec![ + ( + super::acos(), + unbounded_interval(data_type), + zero_to_pi_interval(data_type), + ), + ( + super::acosh(), + unbounded_interval(data_type), + Interval::make_non_negative_infinity_interval(data_type).unwrap(), + ), + ( + super::asin(), + unbounded_interval(data_type), + 
Interval::make_symmetric_half_pi_interval(data_type).unwrap(), + ), + ( + super::atan(), + unbounded_interval(data_type), + Interval::make_symmetric_half_pi_interval(data_type).unwrap(), + ), + ( + super::cos(), + unbounded_interval(data_type), + Interval::make_symmetric_unit_interval(data_type).unwrap(), + ), + ( + super::cosh(), + unbounded_interval(data_type), + one_to_inf_interval(data_type), + ), + ( + super::sin(), + unbounded_interval(data_type), + Interval::make_symmetric_unit_interval(data_type).unwrap(), + ), + ( + super::exp(), + unbounded_interval(data_type), + Interval::make_non_negative_infinity_interval(data_type).unwrap(), + ), + ( + super::sqrt(), + unbounded_interval(data_type), + Interval::make_non_negative_infinity_interval(data_type).unwrap(), + ), + ( + super::radians(), + unbounded_interval(data_type), + Interval::make_symmetric_pi_interval(data_type).unwrap(), + ), + ( + super::sqrt(), + unbounded_interval(data_type), + Interval::make_non_negative_infinity_interval(data_type).unwrap(), + ), + ] + }) + .collect::>(); + + for (udf, interval, expected) in cases { + assert_udf_evaluates_to_bounds(&udf, interval, expected); + } + + Ok(()) + } +} diff --git a/datafusion/functions/src/math/monotonicity.rs b/datafusion/functions/src/math/monotonicity.rs index 33c061ee11d0..52f2ec517198 100644 --- a/datafusion/functions/src/math/monotonicity.rs +++ b/datafusion/functions/src/math/monotonicity.rs @@ -15,24 +15,17 @@ // specific language governing permissions and limitations // under the License. -use arrow::datatypes::DataType; use datafusion_common::{exec_err, Result, ScalarValue}; use datafusion_expr::interval_arithmetic::Interval; use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; -fn symmetric_unit_interval(data_type: &DataType) -> Result { - Interval::try_new( - ScalarValue::new_negative_one(data_type)?, - ScalarValue::new_one(data_type)?, - ) -} - /// Non-increasing on the interval \[−1, 1\], undefined otherwise. pub fn acos_order(input: &[ExprProperties]) -> Result { let arg = &input[0]; let range = &arg.range; - let valid_domain = symmetric_unit_interval(&range.lower().data_type())?; + let valid_domain = + Interval::make_symmetric_unit_interval(&range.lower().data_type())?; if valid_domain.contains(range)? == Interval::CERTAINLY_TRUE { Ok(-arg.sort_properties) @@ -63,7 +56,8 @@ pub fn asin_order(input: &[ExprProperties]) -> Result { let arg = &input[0]; let range = &arg.range; - let valid_domain = symmetric_unit_interval(&range.lower().data_type())?; + let valid_domain = + Interval::make_symmetric_unit_interval(&range.lower().data_type())?; if valid_domain.contains(range)? == Interval::CERTAINLY_TRUE { Ok(arg.sort_properties) @@ -87,7 +81,8 @@ pub fn atanh_order(input: &[ExprProperties]) -> Result { let arg = &input[0]; let range = &arg.range; - let valid_domain = symmetric_unit_interval(&range.lower().data_type())?; + let valid_domain = + Interval::make_symmetric_unit_interval(&range.lower().data_type())?; if valid_domain.contains(range)? 
== Interval::CERTAINLY_TRUE { Ok(arg.sort_properties) From bcf715c892f74d48bdbef54ac7165358be6fb741 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Wed, 24 Jul 2024 15:22:02 -0400 Subject: [PATCH 09/17] Minor: avoid copying order by exprs in planner (#11634) --- datafusion/sql/src/expr/function.rs | 6 +++--- datafusion/sql/src/expr/order_by.rs | 10 ++++------ datafusion/sql/src/query.rs | 2 +- datafusion/sql/src/select.rs | 2 +- datafusion/sql/src/statement.rs | 2 +- 5 files changed, 10 insertions(+), 12 deletions(-) diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index 4804752d8389..0c4b125e76d0 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -274,7 +274,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { .map(|e| self.sql_expr_to_logical_expr(e, schema, planner_context)) .collect::>>()?; let mut order_by = self.order_by_to_sort_expr( - &window.order_by, + window.order_by, schema, planner_context, // Numeric literals in window function ORDER BY are treated as constants @@ -350,7 +350,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // User defined aggregate functions (UDAF) have precedence in case it has the same name as a scalar built-in function if let Some(fm) = self.context_provider.get_aggregate_meta(&name) { let order_by = self.order_by_to_sort_expr( - &order_by, + order_by, schema, planner_context, true, @@ -375,7 +375,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // next, aggregate built-ins if let Ok(fun) = AggregateFunction::from_str(&name) { let order_by = self.order_by_to_sort_expr( - &order_by, + order_by, schema, planner_context, true, diff --git a/datafusion/sql/src/expr/order_by.rs b/datafusion/sql/src/expr/order_by.rs index 4dd81517e958..6010da6fd325 100644 --- a/datafusion/sql/src/expr/order_by.rs +++ b/datafusion/sql/src/expr/order_by.rs @@ -37,7 +37,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { /// If false, interpret numeric literals as constant values. pub(crate) fn order_by_to_sort_expr( &self, - exprs: &[OrderByExpr], + exprs: Vec, input_schema: &DFSchema, planner_context: &mut PlannerContext, literal_to_column: bool, @@ -87,11 +87,9 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { input_schema.qualified_field(field_index - 1), )) } - e => self.sql_expr_to_logical_expr( - e.clone(), - order_by_schema, - planner_context, - )?, + e => { + self.sql_expr_to_logical_expr(e, order_by_schema, planner_context)? + } }; let asc = asc.unwrap_or(true); expr_vec.push(Expr::Sort(Sort::new( diff --git a/datafusion/sql/src/query.rs b/datafusion/sql/src/query.rs index cbbff19321d8..00560b5c9308 100644 --- a/datafusion/sql/src/query.rs +++ b/datafusion/sql/src/query.rs @@ -59,7 +59,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { other => { let plan = self.set_expr_to_plan(other, planner_context)?; let order_by_rex = self.order_by_to_sort_expr( - &query.order_by, + query.order_by, plan.schema(), planner_context, true, diff --git a/datafusion/sql/src/select.rs b/datafusion/sql/src/select.rs index fc46c3a841b5..9b105117af15 100644 --- a/datafusion/sql/src/select.rs +++ b/datafusion/sql/src/select.rs @@ -101,7 +101,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // Order-by expressions prioritize referencing columns from the select list, // then from the FROM clause. 
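In sketch form, the pattern every hunk in this commit applies is plain Rust ownership: take the vector by value so each element can be moved instead of cloned. A minimal standalone illustration (toy types, not DataFusion's API):

// Standalone sketch: pass a Vec by value so elements can be moved, not cloned.
fn into_sorted_labels(exprs: Vec<String>) -> Vec<String> {
    // `into_iter` yields owned Strings, so no per-element `.clone()` is needed.
    exprs.into_iter().map(|e| e + " ASC").collect()
}

fn main() {
    let order_by = vec!["a".to_string(), "b".to_string()];
    // Ownership moves into the callee, mirroring `order_by_to_sort_expr`.
    println!("{:?}", into_sorted_labels(order_by));
}
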
let order_by_rex = self.order_by_to_sort_expr( - &order_by, + order_by, projected_plan.schema().as_ref(), planner_context, true, diff --git a/datafusion/sql/src/statement.rs b/datafusion/sql/src/statement.rs index 8eb4113f80a6..67107bae0202 100644 --- a/datafusion/sql/src/statement.rs +++ b/datafusion/sql/src/statement.rs @@ -967,7 +967,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { for expr in order_exprs { // Convert each OrderByExpr to a SortExpr: let expr_vec = - self.order_by_to_sort_expr(&expr, schema, planner_context, true, None)?; + self.order_by_to_sort_expr(expr, schema, planner_context, true, None)?; // Verify that columns of all SortExprs exist in the schema: for expr in expr_vec.iter() { for column in expr.column_refs().iter() { From 20b298e9d82e483e28087e595c409a8cc04872f3 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 24 Jul 2024 15:34:31 -0600 Subject: [PATCH 10/17] perf: Optimize IsNotNullExpr (#11586) * add criterion benchmarks for IsNullExpr and IsNotNullExpr * Improve IsNotNull performance by avoiding calling is_null then not and just calling is_not_null kernel directly * fast path if input array is all nulls or no nulls * revert experimental change * remove unused import * simplify PR --- datafusion/physical-expr/Cargo.toml | 4 + datafusion/physical-expr/benches/is_null.rs | 95 +++++++++++++++++++ .../src/expressions/is_not_null.rs | 4 +- .../physical-expr/src/expressions/is_null.rs | 10 ++ 4 files changed, 110 insertions(+), 3 deletions(-) create mode 100644 datafusion/physical-expr/benches/is_null.rs diff --git a/datafusion/physical-expr/Cargo.toml b/datafusion/physical-expr/Cargo.toml index 067617a697a9..8436b5279bd7 100644 --- a/datafusion/physical-expr/Cargo.toml +++ b/datafusion/physical-expr/Cargo.toml @@ -81,3 +81,7 @@ name = "in_list" [[bench]] harness = false name = "case_when" + +[[bench]] +harness = false +name = "is_null" diff --git a/datafusion/physical-expr/benches/is_null.rs b/datafusion/physical-expr/benches/is_null.rs new file mode 100644 index 000000000000..3dad8e9b456a --- /dev/null +++ b/datafusion/physical-expr/benches/is_null.rs @@ -0,0 +1,95 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
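+
+// Benchmark IsNullExpr / IsNotNullExpr over three Int32 column shapes:
+// a column of all nulls, a column with no nulls, and a roughly 1-in-7
+// null mix (six criterion cases in total). With the [[bench]] entry
+// added to Cargo.toml above, it can be run via, e.g.:
+//   cargo bench -p datafusion-physical-expr --bench is_null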
+ +use arrow::datatypes::{Field, Schema}; +use arrow::record_batch::RecordBatch; +use arrow_array::builder::Int32Builder; +use arrow_schema::DataType; +use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use datafusion_physical_expr::expressions::{IsNotNullExpr, IsNullExpr}; +use datafusion_physical_expr_common::expressions::column::Column; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; +use std::sync::Arc; + +fn criterion_benchmark(c: &mut Criterion) { + // create input data + let mut c1 = Int32Builder::new(); + let mut c2 = Int32Builder::new(); + let mut c3 = Int32Builder::new(); + for i in 0..1000 { + // c1 is always null + c1.append_null(); + // c2 is never null + c2.append_value(i); + // c3 is a mix of values and nulls + if i % 7 == 0 { + c3.append_null(); + } else { + c3.append_value(i); + } + } + let c1 = Arc::new(c1.finish()); + let c2 = Arc::new(c2.finish()); + let c3 = Arc::new(c3.finish()); + let schema = Schema::new(vec![ + Field::new("c1", DataType::Int32, true), + Field::new("c2", DataType::Int32, false), + Field::new("c3", DataType::Int32, true), + ]); + let batch = RecordBatch::try_new(Arc::new(schema), vec![c1, c2, c3]).unwrap(); + + c.bench_function("is_null: column is all nulls", |b| { + let expr = is_null("c1", 0); + b.iter(|| black_box(expr.evaluate(black_box(&batch)).unwrap())) + }); + + c.bench_function("is_null: column is never null", |b| { + let expr = is_null("c2", 1); + b.iter(|| black_box(expr.evaluate(black_box(&batch)).unwrap())) + }); + + c.bench_function("is_null: column is mix of values and nulls", |b| { + let expr = is_null("c3", 2); + b.iter(|| black_box(expr.evaluate(black_box(&batch)).unwrap())) + }); + + c.bench_function("is_not_null: column is all nulls", |b| { + let expr = is_not_null("c1", 0); + b.iter(|| black_box(expr.evaluate(black_box(&batch)).unwrap())) + }); + + c.bench_function("is_not_null: column is never null", |b| { + let expr = is_not_null("c2", 1); + b.iter(|| black_box(expr.evaluate(black_box(&batch)).unwrap())) + }); + + c.bench_function("is_not_null: column is mix of values and nulls", |b| { + let expr = is_not_null("c3", 2); + b.iter(|| black_box(expr.evaluate(black_box(&batch)).unwrap())) + }); +} + +fn is_null(name: &str, index: usize) -> Arc { + Arc::new(IsNullExpr::new(Arc::new(Column::new(name, index)))) +} + +fn is_not_null(name: &str, index: usize) -> Arc { + Arc::new(IsNotNullExpr::new(Arc::new(Column::new(name, index)))) +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/physical-expr/src/expressions/is_not_null.rs b/datafusion/physical-expr/src/expressions/is_not_null.rs index 9f7438d13e05..58559352d44c 100644 --- a/datafusion/physical-expr/src/expressions/is_not_null.rs +++ b/datafusion/physical-expr/src/expressions/is_not_null.rs @@ -22,7 +22,6 @@ use std::{any::Any, sync::Arc}; use crate::physical_expr::down_cast_any_ref; use crate::PhysicalExpr; -use arrow::compute; use arrow::{ datatypes::{DataType, Schema}, record_batch::RecordBatch, @@ -74,8 +73,7 @@ impl PhysicalExpr for IsNotNullExpr { let arg = self.arg.evaluate(batch)?; match arg { ColumnarValue::Array(array) => { - let is_null = super::is_null::compute_is_null(array)?; - let is_not_null = compute::not(&is_null)?; + let is_not_null = super::is_null::compute_is_not_null(array)?; Ok(ColumnarValue::Array(Arc::new(is_not_null))) } ColumnarValue::Scalar(scalar) => Ok(ColumnarValue::Scalar( diff --git a/datafusion/physical-expr/src/expressions/is_null.rs 
b/datafusion/physical-expr/src/expressions/is_null.rs
index e2dc941e26bc..3cdb49bcab42 100644
--- a/datafusion/physical-expr/src/expressions/is_null.rs
+++ b/datafusion/physical-expr/src/expressions/is_null.rs
@@ -117,6 +117,16 @@
     }
 }
 
+/// Workaround for an upstream arrow-rs bug in the `is_not_null` kernel on union arrays;
+/// this can be replaced with a direct call to `arrow::compute::is_not_null` once it's fixed.
+pub(crate) fn compute_is_not_null(array: ArrayRef) -> Result<BooleanArray> {
+    if array.as_any().is::<UnionArray>() {
+        compute::not(&compute_is_null(array)?).map_err(Into::into)
+    } else {
+        compute::is_not_null(array.as_ref()).map_err(Into::into)
+    }
+}
+
 fn dense_union_is_null(
     union_array: &UnionArray,
     offsets: &ScalarBuffer<i32>,

From 76039fadd934a9f0798fee160877a4247c71c352 Mon Sep 17 00:00:00 2001
From: Piotr Findeisen
Date: Thu, 25 Jul 2024 01:30:55 +0200
Subject: [PATCH 11/17] Unify CI and pre-commit hook settings for clippy
 (#11640)

The pre-commit hook runs clippy, and the CI scripts run clippy too. The
commands are, and should stay, the same, so let's define them once.
---
 pre-commit.sh | 8 +-------
 1 file changed, 1 insertion(+), 7 deletions(-)

diff --git a/pre-commit.sh b/pre-commit.sh
index 09cf431a1409..c81dd9f8e5e8 100755
--- a/pre-commit.sh
+++ b/pre-commit.sh
@@ -57,13 +57,7 @@ fi
 
 # 1. cargo clippy
 echo -e "$(GREEN INFO): cargo clippy ..."
-
-# Cargo clippy always return exit code 0, and `tee` doesn't work.
-# So let's just run cargo clippy.
-cargo clippy --all-targets --workspace --features avro,pyarrow -- -D warnings
-pushd datafusion-cli
-cargo clippy --all-targets --all-features -- -D warnings
-popd
+./ci/scripts/rust_clippy.sh
 echo -e "$(GREEN INFO): cargo clippy done"
 
 # 2. cargo fmt: format with nightly and stable.

From 886e8accdaa85d7b3dca45340b955437786a9b6a Mon Sep 17 00:00:00 2001
From: Tim Saucer
Date: Wed, 24 Jul 2024 19:54:57 -0400
Subject: [PATCH 12/17] Consistent API to set parameters of aggregate and
 window functions (`AggregateExt` --> `ExprFunctionExt`) (#11550)

* Moving over AggregateExt to ExprFunctionExt and adding in function
  settings for window functions

* Switch WindowFrame to only need the window function definition and
  arguments.
  Other parameters will be set via the ExprFuncBuilder

* Changing null_treatment to take an option, but this is mostly for code
  cleanliness and not strictly required

* Moving functions in ExprFuncBuilder over to explicitly implement the
  ExprFunctionExt trait so we can guarantee a consistent user experience,
  no matter which methods are called on the Expr and which on the builder

* Apply cargo fmt

* Add deprecated trait AggregateExt so that users get a warning but their
  code still builds

* Window helper functions should return Expr

* Update documentation to show window function example

* Add license info

* Update comments that are no longer applicable

* Remove first_value and last_value since these are already implemented in
  the aggregate functions

* Update to use WindowFunction::new to set additional parameters for
  order_by using ExprFunctionExt

* Apply cargo fmt

* Fix up clippy

* fix doc example

* fmt

* doc tweaks

* more doc tweaks

* fix up links

* fix integration test

* fix another doc example

---------

Co-authored-by: Tim Saucer
Co-authored-by: Andrew Lamb
---
 datafusion-examples/examples/advanced_udwf.rs |  12 +-
 datafusion-examples/examples/expr_api.rs      |   4 +-
 datafusion-examples/examples/simple_udwf.rs   |  12 +-
 datafusion/core/src/dataframe/mod.rs          |  13 +-
 datafusion/core/tests/dataframe/mod.rs        |  22 +-
 datafusion/core/tests/expr_api/mod.rs         |   2 +-
 datafusion/expr/src/expr.rs                   |  85 ++++--
 datafusion/expr/src/expr_fn.rs                | 279 +++++++++++++++++-
 datafusion/expr/src/lib.rs                    |   3 +-
 datafusion/expr/src/tree_node.rs              |  17 +-
 datafusion/expr/src/udaf.rs                   | 177 +----------
 datafusion/expr/src/udwf.rs                   |  47 ++-
 datafusion/expr/src/utils.rs                  |  89 +++---
 datafusion/expr/src/window_function.rs        |  99 +++++++
 .../functions-aggregate/src/first_last.rs     |   4 +-
 .../src/analyzer/count_wildcard_rule.rs       |  18 +-
 .../optimizer/src/analyzer/type_coercion.rs   |  21 +-
 .../optimizer/src/optimize_projections/mod.rs |  17 +-
 .../src/replace_distinct_aggregate.rs         |   2 +-
 .../simplify_expressions/expr_simplifier.rs   |  24 +-
 .../src/single_distinct_to_groupby.rs         |   2 +-
 .../proto/src/logical_plan/from_proto.rs      |  46 +--
 .../tests/cases/roundtrip_logical_plan.rs     |  77 ++---
 datafusion/sql/src/expr/function.rs           |  25 +-
 datafusion/sql/src/unparser/expr.rs           |   2 +-
 docs/source/user-guide/expressions.md         |   2 +-
 26 files changed, 657 insertions(+), 444 deletions(-)
 create mode 100644 datafusion/expr/src/window_function.rs

diff --git a/datafusion-examples/examples/advanced_udwf.rs b/datafusion-examples/examples/advanced_udwf.rs
index 11fb6f6ccc48..ec0318a561b9 100644
--- a/datafusion-examples/examples/advanced_udwf.rs
+++ b/datafusion-examples/examples/advanced_udwf.rs
@@ -216,12 +216,12 @@ async fn main() -> Result<()> {
     df.show().await?;
 
     // Now, run the function using the DataFrame API:
-    let window_expr = smooth_it.call(
-        vec![col("speed")],                 // smooth_it(speed)
-        vec![col("car")],                   // PARTITION BY car
-        vec![col("time").sort(true, true)], // ORDER BY time ASC
-        WindowFrame::new(None),
-    );
+    let window_expr = smooth_it
+        .call(vec![col("speed")])                     // smooth_it(speed)
+        .partition_by(vec![col("car")])               // PARTITION BY car
+        .order_by(vec![col("time").sort(true, true)]) // ORDER BY time ASC
+        .window_frame(WindowFrame::new(None))
+        .build()?;
     let df = ctx.table("cars").await?.window(vec![window_expr])?;
 
     // print the results
diff --git a/datafusion-examples/examples/expr_api.rs b/datafusion-examples/examples/expr_api.rs
index a48171c625a8..0eb823302acf 100644
--- a/datafusion-examples/examples/expr_api.rs
+++ b/datafusion-examples/examples/expr_api.rs
@@ -33,7 +33,7 @@
use datafusion_expr::execution_props::ExecutionProps; use datafusion_expr::expr::BinaryExpr; use datafusion_expr::interval_arithmetic::Interval; use datafusion_expr::simplify::SimplifyContext; -use datafusion_expr::{AggregateExt, ColumnarValue, ExprSchemable, Operator}; +use datafusion_expr::{ColumnarValue, ExprFunctionExt, ExprSchemable, Operator}; /// This example demonstrates the DataFusion [`Expr`] API. /// @@ -95,7 +95,7 @@ fn expr_fn_demo() -> Result<()> { let agg = first_value.call(vec![col("price")]); assert_eq!(agg.to_string(), "first_value(price)"); - // You can use the AggregateExt trait to create more complex aggregates + // You can use the ExprFunctionExt trait to create more complex aggregates // such as `FIRST_VALUE(price FILTER quantity > 100 ORDER BY ts ) let agg = first_value .call(vec![col("price")]) diff --git a/datafusion-examples/examples/simple_udwf.rs b/datafusion-examples/examples/simple_udwf.rs index 563f02cee6a6..22dfbbbf0c3a 100644 --- a/datafusion-examples/examples/simple_udwf.rs +++ b/datafusion-examples/examples/simple_udwf.rs @@ -118,12 +118,12 @@ async fn main() -> Result<()> { df.show().await?; // Now, run the function using the DataFrame API: - let window_expr = smooth_it.call( - vec![col("speed")], // smooth_it(speed) - vec![col("car")], // PARTITION BY car - vec![col("time").sort(true, true)], // ORDER BY time ASC - WindowFrame::new(None), - ); + let window_expr = smooth_it + .call(vec![col("speed")]) // smooth_it(speed) + .partition_by(vec![col("car")]) // PARTITION BY car + .order_by(vec![col("time").sort(true, true)]) // ORDER BY time ASC + .window_frame(WindowFrame::new(None)) + .build()?; let df = ctx.table("cars").await?.window(vec![window_expr])?; // print the results diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index fb28b5c1ab47..ea437cc99a33 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -1696,8 +1696,8 @@ mod tests { use datafusion_common::{Constraint, Constraints, ScalarValue}; use datafusion_common_runtime::SpawnedTask; use datafusion_expr::{ - cast, create_udf, expr, lit, BuiltInWindowFunction, ScalarFunctionImplementation, - Volatility, WindowFrame, WindowFunctionDefinition, + cast, create_udf, expr, lit, BuiltInWindowFunction, ExprFunctionExt, + ScalarFunctionImplementation, Volatility, WindowFunctionDefinition, }; use datafusion_functions_aggregate::expr_fn::{array_agg, count_distinct}; use datafusion_physical_expr::expressions::Column; @@ -1867,11 +1867,10 @@ mod tests { BuiltInWindowFunction::FirstValue, ), vec![col("aggregate_test_100.c1")], - vec![col("aggregate_test_100.c2")], - vec![], - WindowFrame::new(None), - None, - )); + )) + .partition_by(vec![col("aggregate_test_100.c2")]) + .build() + .unwrap(); let t2 = t.select(vec![col("c1"), first_row])?; let plan = t2.plan.clone(); diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs index bc01ada1e04b..d83a47ceb069 100644 --- a/datafusion/core/tests/dataframe/mod.rs +++ b/datafusion/core/tests/dataframe/mod.rs @@ -55,8 +55,8 @@ use datafusion_expr::expr::{GroupingSet, Sort}; use datafusion_expr::var_provider::{VarProvider, VarType}; use datafusion_expr::{ cast, col, exists, expr, in_subquery, lit, max, out_ref_col, placeholder, - scalar_subquery, when, wildcard, Expr, ExprSchemable, WindowFrame, WindowFrameBound, - WindowFrameUnits, WindowFunctionDefinition, + scalar_subquery, when, wildcard, Expr, ExprFunctionExt, ExprSchemable, WindowFrame, + 
WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, }; use datafusion_functions_aggregate::expr_fn::{array_agg, avg, count, sum}; @@ -183,15 +183,15 @@ async fn test_count_wildcard_on_window() -> Result<()> { .select(vec![Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::AggregateUDF(count_udaf()), vec![wildcard()], - vec![], - vec![Expr::Sort(Sort::new(Box::new(col("a")), false, true))], - WindowFrame::new_bounds( - WindowFrameUnits::Range, - WindowFrameBound::Preceding(ScalarValue::UInt32(Some(6))), - WindowFrameBound::Following(ScalarValue::UInt32(Some(2))), - ), - None, - ))])? + )) + .order_by(vec![Expr::Sort(Sort::new(Box::new(col("a")), false, true))]) + .window_frame(WindowFrame::new_bounds( + WindowFrameUnits::Range, + WindowFrameBound::Preceding(ScalarValue::UInt32(Some(6))), + WindowFrameBound::Following(ScalarValue::UInt32(Some(2))), + )) + .build() + .unwrap()])? .explain(false, false)? .collect() .await?; diff --git a/datafusion/core/tests/expr_api/mod.rs b/datafusion/core/tests/expr_api/mod.rs index 37d06355d2d3..051d65652633 100644 --- a/datafusion/core/tests/expr_api/mod.rs +++ b/datafusion/core/tests/expr_api/mod.rs @@ -21,7 +21,7 @@ use arrow_array::{ArrayRef, Int64Array, RecordBatch, StringArray, StructArray}; use arrow_schema::{DataType, Field}; use datafusion::prelude::*; use datafusion_common::{assert_contains, DFSchema, ScalarValue}; -use datafusion_expr::AggregateExt; +use datafusion_expr::ExprFunctionExt; use datafusion_functions::core::expr_ext::FieldAccessor; use datafusion_functions_aggregate::first_last::first_value_udaf; use datafusion_functions_aggregate::sum::sum_udaf; diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 452c05be34f4..68d5504eea48 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -28,8 +28,8 @@ use crate::expr_fn::binary_expr; use crate::logical_plan::Subquery; use crate::utils::expr_to_columns; use crate::{ - aggregate_function, built_in_window_function, udaf, ExprSchemable, Operator, - Signature, + aggregate_function, built_in_window_function, udaf, BuiltInWindowFunction, + ExprSchemable, Operator, Signature, WindowFrame, WindowUDF, }; use crate::{window_frame, Volatility}; @@ -60,6 +60,10 @@ use sqlparser::ast::NullTreatment; /// use the fluent APIs in [`crate::expr_fn`] such as [`col`] and [`lit`], or /// methods such as [`Expr::alias`], [`Expr::cast_to`], and [`Expr::Like`]). /// +/// See also [`ExprFunctionExt`] for creating aggregate and window functions. +/// +/// [`ExprFunctionExt`]: crate::expr_fn::ExprFunctionExt +/// /// # Schema Access /// /// See [`ExprSchemable::get_type`] to access the [`DataType`] and nullability @@ -283,15 +287,17 @@ pub enum Expr { /// This expression is guaranteed to have a fixed type. TryCast(TryCast), /// A sort expression, that can be used to sort values. + /// + /// See [Expr::sort] for more details Sort(Sort), /// Represents the call of a scalar function with a set of arguments. ScalarFunction(ScalarFunction), /// Calls an aggregate function with arguments, and optional /// `ORDER BY`, `FILTER`, `DISTINCT` and `NULL TREATMENT`. /// - /// See also [`AggregateExt`] to set these fields. + /// See also [`ExprFunctionExt`] to set these fields. /// - /// [`AggregateExt`]: crate::udaf::AggregateExt + /// [`ExprFunctionExt`]: crate::expr_fn::ExprFunctionExt AggregateFunction(AggregateFunction), /// Represents the call of a window function with arguments. 
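+    ///
+    /// See also [`ExprFunctionExt`] to set the `OVER` clause fields
+    /// (`PARTITION BY`, `ORDER BY`, and the window frame).
+    ///
+    /// [`ExprFunctionExt`]: crate::expr_fn::ExprFunctionExt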
WindowFunction(WindowFunction),
@@ -641,9 +647,9 @@ impl AggregateFunctionDefinition {
 
 /// Aggregate function
 ///
-/// See also [`AggregateExt`] to set these fields on `Expr`
+/// See also [`ExprFunctionExt`] to set these fields on `Expr`
 ///
-/// [`AggregateExt`]: crate::udaf::AggregateExt
+/// [`ExprFunctionExt`]: crate::expr_fn::ExprFunctionExt
 #[derive(Clone, PartialEq, Eq, Hash, Debug)]
 pub struct AggregateFunction {
     /// Name of the function
@@ -769,7 +775,52 @@ impl fmt::Display for WindowFunctionDefinition {
     }
 }
 
+impl From<aggregate_function::AggregateFunction> for WindowFunctionDefinition {
+    fn from(value: aggregate_function::AggregateFunction) -> Self {
+        Self::AggregateFunction(value)
+    }
+}
+
+impl From<BuiltInWindowFunction> for WindowFunctionDefinition {
+    fn from(value: BuiltInWindowFunction) -> Self {
+        Self::BuiltInWindowFunction(value)
+    }
+}
+
+impl From<Arc<AggregateUDF>> for WindowFunctionDefinition {
+    fn from(value: Arc<AggregateUDF>) -> Self {
+        Self::AggregateUDF(value)
+    }
+}
+
+impl From<Arc<WindowUDF>> for WindowFunctionDefinition {
+    fn from(value: Arc<WindowUDF>) -> Self {
+        Self::WindowUDF(value)
+    }
+}
+
 /// Window function
+///
+/// Holds the actual function to call (`fun`) as well as its
+/// arguments (`args`) and the contents of the `OVER` clause:
+///
+/// 1. `PARTITION BY`
+/// 2. `ORDER BY`
+/// 3. Window frame (e.g. `ROWS 1 PRECEDING AND 1 FOLLOWING`)
+///
+/// # Example
+/// ```
+/// # use datafusion_expr::{Expr, BuiltInWindowFunction, col, ExprFunctionExt};
+/// # use datafusion_expr::expr::WindowFunction;
+/// // Create FIRST_VALUE(a) OVER (PARTITION BY b ORDER BY c)
+/// let expr = Expr::WindowFunction(
+///     WindowFunction::new(BuiltInWindowFunction::FirstValue, vec![col("a")])
+/// )
+/// .partition_by(vec![col("b")])
+/// .order_by(vec![col("c").sort(true, true)])
+/// .build()
+/// .unwrap();
+/// ```
 #[derive(Clone, PartialEq, Eq, Hash, Debug)]
 pub struct WindowFunction {
     /// Name of the function
@@ -787,22 +838,16 @@ pub struct WindowFunction
 }
 
 impl WindowFunction {
-    /// Create a new Window expression
-    pub fn new(
-        fun: WindowFunctionDefinition,
-        args: Vec<Expr>,
-        partition_by: Vec<Expr>,
-        order_by: Vec<Expr>,
-        window_frame: window_frame::WindowFrame,
-        null_treatment: Option<NullTreatment>,
-    ) -> Self {
+    /// Create a new Window expression with the specified argument and an
+    /// empty `OVER` clause
+    pub fn new(fun: impl Into<WindowFunctionDefinition>, args: Vec<Expr>) -> Self {
         Self {
-            fun,
+            fun: fun.into(),
             args,
-            partition_by,
-            order_by,
-            window_frame,
-            null_treatment,
+            partition_by: Vec::default(),
+            order_by: Vec::default(),
+            window_frame: WindowFrame::new(None),
+            null_treatment: None,
         }
     }
 }
diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs
index 9187e8352205..1f51cded2239 100644
--- a/datafusion/expr/src/expr_fn.rs
+++ b/datafusion/expr/src/expr_fn.rs
@@ -19,7 +19,7 @@
 use crate::expr::{
     AggregateFunction, BinaryExpr, Cast, Exists, GroupingSet, InList, InSubquery,
-    Placeholder, TryCast, Unnest,
+    Placeholder, TryCast, Unnest, WindowFunction,
 };
 use crate::function::{
     AccumulatorArgs, AccumulatorFactoryFunction, PartitionEvaluatorFactory,
@@ -30,12 +30,15 @@ use crate::{
     AggregateUDF, Expr, LogicalPlan, Operator, ScalarFunctionImplementation, ScalarUDF,
     Signature, Volatility,
 };
-use crate::{AggregateUDFImpl, ColumnarValue, ScalarUDFImpl, WindowUDF, WindowUDFImpl};
+use crate::{
+    AggregateUDFImpl, ColumnarValue, ScalarUDFImpl, WindowFrame, WindowUDF, WindowUDFImpl,
+};
 use arrow::compute::kernels::cast_utils::{
     parse_interval_day_time, parse_interval_month_day_nano, parse_interval_year_month,
 };
 use arrow::datatypes::{DataType, Field};
-use
datafusion_common::{Column, Result, ScalarValue}; +use datafusion_common::{plan_err, Column, Result, ScalarValue}; +use sqlparser::ast::NullTreatment; use std::any::Any; use std::fmt::Debug; use std::ops::Not; @@ -664,6 +667,276 @@ pub fn interval_month_day_nano_lit(value: &str) -> Expr { Expr::Literal(ScalarValue::IntervalMonthDayNano(interval)) } +/// Extensions for configuring [`Expr::AggregateFunction`] or [`Expr::WindowFunction`] +/// +/// Adds methods to [`Expr`] that make it easy to set optional options +/// such as `ORDER BY`, `FILTER` and `DISTINCT` +/// +/// # Example +/// ```no_run +/// # use datafusion_common::Result; +/// # use datafusion_expr::test::function_stub::count; +/// # use sqlparser::ast::NullTreatment; +/// # use datafusion_expr::{ExprFunctionExt, lit, Expr, col}; +/// # use datafusion_expr::window_function::percent_rank; +/// # // first_value is an aggregate function in another crate +/// # fn first_value(_arg: Expr) -> Expr { +/// unimplemented!() } +/// # fn main() -> Result<()> { +/// // Create an aggregate count, filtering on column y > 5 +/// let agg = count(col("x")).filter(col("y").gt(lit(5))).build()?; +/// +/// // Find the first value in an aggregate sorted by column y +/// // equivalent to: +/// // `FIRST_VALUE(x ORDER BY y ASC IGNORE NULLS)` +/// let sort_expr = col("y").sort(true, true); +/// let agg = first_value(col("x")) +/// .order_by(vec![sort_expr]) +/// .null_treatment(NullTreatment::IgnoreNulls) +/// .build()?; +/// +/// // Create a window expression for percent rank partitioned on column a +/// // equivalent to: +/// // `PERCENT_RANK() OVER (PARTITION BY a ORDER BY b ASC NULLS LAST IGNORE NULLS)` +/// let window = percent_rank() +/// .partition_by(vec![col("a")]) +/// .order_by(vec![col("b").sort(true, true)]) +/// .null_treatment(NullTreatment::IgnoreNulls) +/// .build()?; +/// # Ok(()) +/// # } +/// ``` +pub trait ExprFunctionExt { + /// Add `ORDER BY ` + /// + /// Note: `order_by` must be [`Expr::Sort`] + fn order_by(self, order_by: Vec) -> ExprFuncBuilder; + /// Add `FILTER ` + fn filter(self, filter: Expr) -> ExprFuncBuilder; + /// Add `DISTINCT` + fn distinct(self) -> ExprFuncBuilder; + /// Add `RESPECT NULLS` or `IGNORE NULLS` + fn null_treatment( + self, + null_treatment: impl Into>, + ) -> ExprFuncBuilder; + /// Add `PARTITION BY` + fn partition_by(self, partition_by: Vec) -> ExprFuncBuilder; + /// Add appropriate window frame conditions + fn window_frame(self, window_frame: WindowFrame) -> ExprFuncBuilder; +} + +#[derive(Debug, Clone)] +pub enum ExprFuncKind { + Aggregate(AggregateFunction), + Window(WindowFunction), +} + +/// Implementation of [`ExprFunctionExt`]. 
+/// +/// See [`ExprFunctionExt`] for usage and examples +#[derive(Debug, Clone)] +pub struct ExprFuncBuilder { + fun: Option, + order_by: Option>, + filter: Option, + distinct: bool, + null_treatment: Option, + partition_by: Option>, + window_frame: Option, +} + +impl ExprFuncBuilder { + /// Create a new `ExprFuncBuilder`, see [`ExprFunctionExt`] + fn new(fun: Option) -> Self { + Self { + fun, + order_by: None, + filter: None, + distinct: false, + null_treatment: None, + partition_by: None, + window_frame: None, + } + } + + /// Updates and returns the in progress [`Expr::AggregateFunction`] or [`Expr::WindowFunction`] + /// + /// # Errors: + /// + /// Returns an error if this builder [`ExprFunctionExt`] was used with an + /// `Expr` variant other than [`Expr::AggregateFunction`] or [`Expr::WindowFunction`] + pub fn build(self) -> Result { + let Self { + fun, + order_by, + filter, + distinct, + null_treatment, + partition_by, + window_frame, + } = self; + + let Some(fun) = fun else { + return plan_err!( + "ExprFunctionExt can only be used with Expr::AggregateFunction or Expr::WindowFunction" + ); + }; + + if let Some(order_by) = &order_by { + for expr in order_by.iter() { + if !matches!(expr, Expr::Sort(_)) { + return plan_err!( + "ORDER BY expressions must be Expr::Sort, found {expr:?}" + ); + } + } + } + + let fun_expr = match fun { + ExprFuncKind::Aggregate(mut udaf) => { + udaf.order_by = order_by; + udaf.filter = filter.map(Box::new); + udaf.distinct = distinct; + udaf.null_treatment = null_treatment; + Expr::AggregateFunction(udaf) + } + ExprFuncKind::Window(mut udwf) => { + let has_order_by = order_by.as_ref().map(|o| !o.is_empty()); + udwf.order_by = order_by.unwrap_or_default(); + udwf.partition_by = partition_by.unwrap_or_default(); + udwf.window_frame = + window_frame.unwrap_or(WindowFrame::new(has_order_by)); + udwf.null_treatment = null_treatment; + Expr::WindowFunction(udwf) + } + }; + + Ok(fun_expr) + } +} + +impl ExprFunctionExt for ExprFuncBuilder { + /// Add `ORDER BY ` + /// + /// Note: `order_by` must be [`Expr::Sort`] + fn order_by(mut self, order_by: Vec) -> ExprFuncBuilder { + self.order_by = Some(order_by); + self + } + + /// Add `FILTER ` + fn filter(mut self, filter: Expr) -> ExprFuncBuilder { + self.filter = Some(filter); + self + } + + /// Add `DISTINCT` + fn distinct(mut self) -> ExprFuncBuilder { + self.distinct = true; + self + } + + /// Add `RESPECT NULLS` or `IGNORE NULLS` + fn null_treatment( + mut self, + null_treatment: impl Into>, + ) -> ExprFuncBuilder { + self.null_treatment = null_treatment.into(); + self + } + + fn partition_by(mut self, partition_by: Vec) -> ExprFuncBuilder { + self.partition_by = Some(partition_by); + self + } + + fn window_frame(mut self, window_frame: WindowFrame) -> ExprFuncBuilder { + self.window_frame = Some(window_frame); + self + } +} + +impl ExprFunctionExt for Expr { + fn order_by(self, order_by: Vec) -> ExprFuncBuilder { + let mut builder = match self { + Expr::AggregateFunction(udaf) => { + ExprFuncBuilder::new(Some(ExprFuncKind::Aggregate(udaf))) + } + Expr::WindowFunction(udwf) => { + ExprFuncBuilder::new(Some(ExprFuncKind::Window(udwf))) + } + _ => ExprFuncBuilder::new(None), + }; + if builder.fun.is_some() { + builder.order_by = Some(order_by); + } + builder + } + fn filter(self, filter: Expr) -> ExprFuncBuilder { + match self { + Expr::AggregateFunction(udaf) => { + let mut builder = + ExprFuncBuilder::new(Some(ExprFuncKind::Aggregate(udaf))); + builder.filter = Some(filter); + builder + } + _ => 
ExprFuncBuilder::new(None), + } + } + fn distinct(self) -> ExprFuncBuilder { + match self { + Expr::AggregateFunction(udaf) => { + let mut builder = + ExprFuncBuilder::new(Some(ExprFuncKind::Aggregate(udaf))); + builder.distinct = true; + builder + } + _ => ExprFuncBuilder::new(None), + } + } + fn null_treatment( + self, + null_treatment: impl Into>, + ) -> ExprFuncBuilder { + let mut builder = match self { + Expr::AggregateFunction(udaf) => { + ExprFuncBuilder::new(Some(ExprFuncKind::Aggregate(udaf))) + } + Expr::WindowFunction(udwf) => { + ExprFuncBuilder::new(Some(ExprFuncKind::Window(udwf))) + } + _ => ExprFuncBuilder::new(None), + }; + if builder.fun.is_some() { + builder.null_treatment = null_treatment.into(); + } + builder + } + + fn partition_by(self, partition_by: Vec) -> ExprFuncBuilder { + match self { + Expr::WindowFunction(udwf) => { + let mut builder = ExprFuncBuilder::new(Some(ExprFuncKind::Window(udwf))); + builder.partition_by = Some(partition_by); + builder + } + _ => ExprFuncBuilder::new(None), + } + } + + fn window_frame(self, window_frame: WindowFrame) -> ExprFuncBuilder { + match self { + Expr::WindowFunction(udwf) => { + let mut builder = ExprFuncBuilder::new(Some(ExprFuncKind::Window(udwf))); + builder.window_frame = Some(window_frame); + builder + } + _ => ExprFuncBuilder::new(None), + } + } +} + #[cfg(test)] mod test { use super::*; diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index e1943c890e7c..0a5cf4653a22 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -60,6 +60,7 @@ pub mod type_coercion; pub mod utils; pub mod var_provider; pub mod window_frame; +pub mod window_function; pub mod window_state; pub use accumulator::Accumulator; @@ -86,7 +87,7 @@ pub use signature::{ }; pub use sqlparser; pub use table_source::{TableProviderFilterPushDown, TableSource, TableType}; -pub use udaf::{AggregateExt, AggregateUDF, AggregateUDFImpl, ReversedUDAF}; +pub use udaf::{AggregateUDF, AggregateUDFImpl, ReversedUDAF}; pub use udf::{ScalarUDF, ScalarUDFImpl}; pub use udwf::{WindowUDF, WindowUDFImpl}; pub use window_frame::{WindowFrame, WindowFrameBound, WindowFrameUnits}; diff --git a/datafusion/expr/src/tree_node.rs b/datafusion/expr/src/tree_node.rs index f1df8609f903..a97b9f010f79 100644 --- a/datafusion/expr/src/tree_node.rs +++ b/datafusion/expr/src/tree_node.rs @@ -22,7 +22,7 @@ use crate::expr::{ Cast, GroupingSet, InList, InSubquery, Like, Placeholder, ScalarFunction, Sort, TryCast, Unnest, WindowFunction, }; -use crate::Expr; +use crate::{Expr, ExprFunctionExt}; use datafusion_common::tree_node::{ Transformed, TreeNode, TreeNodeIterator, TreeNodeRecursion, @@ -294,14 +294,13 @@ impl TreeNode for Expr { transform_vec(order_by, &mut f) )? 
.update_data(|(new_args, new_partition_by, new_order_by)| { - Expr::WindowFunction(WindowFunction::new( - fun, - new_args, - new_partition_by, - new_order_by, - window_frame, - null_treatment, - )) + Expr::WindowFunction(WindowFunction::new(fun, new_args)) + .partition_by(new_partition_by) + .order_by(new_order_by) + .window_frame(window_frame) + .null_treatment(null_treatment) + .build() + .unwrap() }), Expr::AggregateFunction(AggregateFunction { args, diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index 2851ca811e0c..8867a478f790 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -24,9 +24,8 @@ use std::sync::Arc; use std::vec; use arrow::datatypes::{DataType, Field}; -use sqlparser::ast::NullTreatment; -use datafusion_common::{exec_err, not_impl_err, plan_err, Result}; +use datafusion_common::{exec_err, not_impl_err, Result}; use crate::expr::AggregateFunction; use crate::function::{ @@ -655,177 +654,3 @@ impl AggregateUDFImpl for AggregateUDFLegacyWrapper { (self.accumulator)(acc_args) } } - -/// Extensions for configuring [`Expr::AggregateFunction`] -/// -/// Adds methods to [`Expr`] that make it easy to set optional aggregate options -/// such as `ORDER BY`, `FILTER` and `DISTINCT` -/// -/// # Example -/// ```no_run -/// # use datafusion_common::Result; -/// # use datafusion_expr::{AggregateUDF, col, Expr, lit}; -/// # use sqlparser::ast::NullTreatment; -/// # fn count(arg: Expr) -> Expr { todo!{} } -/// # fn first_value(arg: Expr) -> Expr { todo!{} } -/// # fn main() -> Result<()> { -/// use datafusion_expr::AggregateExt; -/// -/// // Create COUNT(x FILTER y > 5) -/// let agg = count(col("x")) -/// .filter(col("y").gt(lit(5))) -/// .build()?; -/// // Create FIRST_VALUE(x ORDER BY y IGNORE NULLS) -/// let sort_expr = col("y").sort(true, true); -/// let agg = first_value(col("x")) -/// .order_by(vec![sort_expr]) -/// .null_treatment(NullTreatment::IgnoreNulls) -/// .build()?; -/// # Ok(()) -/// # } -/// ``` -pub trait AggregateExt { - /// Add `ORDER BY ` - /// - /// Note: `order_by` must be [`Expr::Sort`] - fn order_by(self, order_by: Vec) -> AggregateBuilder; - /// Add `FILTER ` - fn filter(self, filter: Expr) -> AggregateBuilder; - /// Add `DISTINCT` - fn distinct(self) -> AggregateBuilder; - /// Add `RESPECT NULLS` or `IGNORE NULLS` - fn null_treatment(self, null_treatment: NullTreatment) -> AggregateBuilder; -} - -/// Implementation of [`AggregateExt`]. 
-/// -/// See [`AggregateExt`] for usage and examples -#[derive(Debug, Clone)] -pub struct AggregateBuilder { - udaf: Option, - order_by: Option>, - filter: Option, - distinct: bool, - null_treatment: Option, -} - -impl AggregateBuilder { - /// Create a new `AggregateBuilder`, see [`AggregateExt`] - - fn new(udaf: Option) -> Self { - Self { - udaf, - order_by: None, - filter: None, - distinct: false, - null_treatment: None, - } - } - - /// Updates and returns the in progress [`Expr::AggregateFunction`] - /// - /// # Errors: - /// - /// Returns an error of this builder [`AggregateExt`] was used with an - /// `Expr` variant other than [`Expr::AggregateFunction`] - pub fn build(self) -> Result { - let Self { - udaf, - order_by, - filter, - distinct, - null_treatment, - } = self; - - let Some(mut udaf) = udaf else { - return plan_err!( - "AggregateExt can only be used with Expr::AggregateFunction" - ); - }; - - if let Some(order_by) = &order_by { - for expr in order_by.iter() { - if !matches!(expr, Expr::Sort(_)) { - return plan_err!( - "ORDER BY expressions must be Expr::Sort, found {expr:?}" - ); - } - } - } - - udaf.order_by = order_by; - udaf.filter = filter.map(Box::new); - udaf.distinct = distinct; - udaf.null_treatment = null_treatment; - Ok(Expr::AggregateFunction(udaf)) - } - - /// Add `ORDER BY ` - /// - /// Note: `order_by` must be [`Expr::Sort`] - pub fn order_by(mut self, order_by: Vec) -> AggregateBuilder { - self.order_by = Some(order_by); - self - } - - /// Add `FILTER ` - pub fn filter(mut self, filter: Expr) -> AggregateBuilder { - self.filter = Some(filter); - self - } - - /// Add `DISTINCT` - pub fn distinct(mut self) -> AggregateBuilder { - self.distinct = true; - self - } - - /// Add `RESPECT NULLS` or `IGNORE NULLS` - pub fn null_treatment(mut self, null_treatment: NullTreatment) -> AggregateBuilder { - self.null_treatment = Some(null_treatment); - self - } -} - -impl AggregateExt for Expr { - fn order_by(self, order_by: Vec) -> AggregateBuilder { - match self { - Expr::AggregateFunction(udaf) => { - let mut builder = AggregateBuilder::new(Some(udaf)); - builder.order_by = Some(order_by); - builder - } - _ => AggregateBuilder::new(None), - } - } - fn filter(self, filter: Expr) -> AggregateBuilder { - match self { - Expr::AggregateFunction(udaf) => { - let mut builder = AggregateBuilder::new(Some(udaf)); - builder.filter = Some(filter); - builder - } - _ => AggregateBuilder::new(None), - } - } - fn distinct(self) -> AggregateBuilder { - match self { - Expr::AggregateFunction(udaf) => { - let mut builder = AggregateBuilder::new(Some(udaf)); - builder.distinct = true; - builder - } - _ => AggregateBuilder::new(None), - } - } - fn null_treatment(self, null_treatment: NullTreatment) -> AggregateBuilder { - match self { - Expr::AggregateFunction(udaf) => { - let mut builder = AggregateBuilder::new(Some(udaf)); - builder.null_treatment = Some(null_treatment); - builder - } - _ => AggregateBuilder::new(None), - } - } -} diff --git a/datafusion/expr/src/udwf.rs b/datafusion/expr/src/udwf.rs index 1a6b21e3dd29..5abce013dfb6 100644 --- a/datafusion/expr/src/udwf.rs +++ b/datafusion/expr/src/udwf.rs @@ -28,9 +28,10 @@ use arrow::datatypes::DataType; use datafusion_common::Result; +use crate::expr::WindowFunction; use crate::{ function::WindowFunctionSimplification, Expr, PartitionEvaluator, - PartitionEvaluatorFactory, ReturnTypeFunction, Signature, WindowFrame, + PartitionEvaluatorFactory, ReturnTypeFunction, Signature, }; /// Logical representation of a user-defined window 
function (UDWF) @@ -123,28 +124,19 @@ impl WindowUDF { Self::new_from_impl(AliasedWindowUDFImpl::new(Arc::clone(&self.inner), aliases)) } - /// creates a [`Expr`] that calls the window function given - /// the `partition_by`, `order_by`, and `window_frame` definition + /// creates a [`Expr`] that calls the window function with default + /// values for `order_by`, `partition_by`, `window_frame`. /// - /// This utility allows using the UDWF without requiring access to - /// the registry, such as with the DataFrame API. - pub fn call( - &self, - args: Vec, - partition_by: Vec, - order_by: Vec, - window_frame: WindowFrame, - ) -> Expr { + /// See [`ExprFunctionExt`] for details on setting these values. + /// + /// This utility allows using a user defined window function without + /// requiring access to the registry, such as with the DataFrame API. + /// + /// [`ExprFunctionExt`]: crate::expr_fn::ExprFunctionExt + pub fn call(&self, args: Vec) -> Expr { let fun = crate::WindowFunctionDefinition::WindowUDF(Arc::new(self.clone())); - Expr::WindowFunction(crate::expr::WindowFunction { - fun, - args, - partition_by, - order_by, - window_frame, - null_treatment: None, - }) + Expr::WindowFunction(WindowFunction::new(fun, args)) } /// Returns this function's name @@ -210,7 +202,7 @@ where /// # use std::any::Any; /// # use arrow::datatypes::DataType; /// # use datafusion_common::{DataFusionError, plan_err, Result}; -/// # use datafusion_expr::{col, Signature, Volatility, PartitionEvaluator, WindowFrame}; +/// # use datafusion_expr::{col, Signature, Volatility, PartitionEvaluator, WindowFrame, ExprFunctionExt}; /// # use datafusion_expr::{WindowUDFImpl, WindowUDF}; /// #[derive(Debug, Clone)] /// struct SmoothIt { @@ -244,12 +236,13 @@ where /// let smooth_it = WindowUDF::from(SmoothIt::new()); /// /// // Call the function `add_one(col)` -/// let expr = smooth_it.call( -/// vec![col("speed")], // smooth_it(speed) -/// vec![col("car")], // PARTITION BY car -/// vec![col("time").sort(true, true)], // ORDER BY time ASC -/// WindowFrame::new(None), -/// ); +/// // smooth_it(speed) OVER (PARTITION BY car ORDER BY time ASC) +/// let expr = smooth_it.call(vec![col("speed")]) +/// .partition_by(vec![col("car")]) +/// .order_by(vec![col("time").sort(true, true)]) +/// .window_frame(WindowFrame::new(None)) +/// .build() +/// .unwrap(); /// ``` pub trait WindowUDFImpl: Debug + Send + Sync { /// Returns this object as an [`Any`] trait object diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index 889aa0952e51..2ef1597abfd1 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -1253,8 +1253,8 @@ mod tests { use super::*; use crate::{ col, cube, expr, expr_vec_fmt, grouping_set, lit, rollup, - test::function_stub::sum_udaf, AggregateFunction, Cast, WindowFrame, - WindowFunctionDefinition, + test::function_stub::sum_udaf, AggregateFunction, Cast, ExprFunctionExt, + WindowFrame, WindowFunctionDefinition, }; #[test] @@ -1270,34 +1270,18 @@ mod tests { let max1 = Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), vec![col("name")], - vec![], - vec![], - WindowFrame::new(None), - None, )); let max2 = Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), vec![col("name")], - vec![], - vec![], - WindowFrame::new(None), - None, )); let min3 = Expr::WindowFunction(expr::WindowFunction::new( 
WindowFunctionDefinition::AggregateFunction(AggregateFunction::Min), vec![col("name")], - vec![], - vec![], - WindowFrame::new(None), - None, )); let sum4 = Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::AggregateUDF(sum_udaf()), vec![col("age")], - vec![], - vec![], - WindowFrame::new(None), - None, )); let exprs = &[max1.clone(), max2.clone(), min3.clone(), sum4.clone()]; let result = group_window_expr_by_sort_keys(exprs.to_vec())?; @@ -1317,35 +1301,32 @@ mod tests { let max1 = Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), vec![col("name")], - vec![], - vec![age_asc.clone(), name_desc.clone()], - WindowFrame::new(Some(false)), - None, - )); + )) + .order_by(vec![age_asc.clone(), name_desc.clone()]) + .build() + .unwrap(); let max2 = Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), vec![col("name")], - vec![], - vec![], - WindowFrame::new(None), - None, )); let min3 = Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::AggregateFunction(AggregateFunction::Min), vec![col("name")], - vec![], - vec![age_asc.clone(), name_desc.clone()], - WindowFrame::new(Some(false)), - None, - )); + )) + .order_by(vec![age_asc.clone(), name_desc.clone()]) + .build() + .unwrap(); let sum4 = Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::AggregateUDF(sum_udaf()), vec![col("age")], - vec![], - vec![name_desc.clone(), age_asc.clone(), created_at_desc.clone()], - WindowFrame::new(Some(false)), - None, - )); + )) + .order_by(vec![ + name_desc.clone(), + age_asc.clone(), + created_at_desc.clone(), + ]) + .build() + .unwrap(); // FIXME use as_ref let exprs = &[max1.clone(), max2.clone(), min3.clone(), sum4.clone()]; let result = group_window_expr_by_sort_keys(exprs.to_vec())?; @@ -1373,26 +1354,26 @@ mod tests { Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), vec![col("name")], - vec![], - vec![ - Expr::Sort(expr::Sort::new(Box::new(col("age")), true, true)), - Expr::Sort(expr::Sort::new(Box::new(col("name")), false, true)), - ], - WindowFrame::new(Some(false)), - None, - )), + )) + .order_by(vec![ + Expr::Sort(expr::Sort::new(Box::new(col("age")), true, true)), + Expr::Sort(expr::Sort::new(Box::new(col("name")), false, true)), + ]) + .window_frame(WindowFrame::new(Some(false))) + .build() + .unwrap(), Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::AggregateUDF(sum_udaf()), vec![col("age")], - vec![], - vec![ - Expr::Sort(expr::Sort::new(Box::new(col("name")), false, true)), - Expr::Sort(expr::Sort::new(Box::new(col("age")), true, true)), - Expr::Sort(expr::Sort::new(Box::new(col("created_at")), false, true)), - ], - WindowFrame::new(Some(false)), - None, - )), + )) + .order_by(vec![ + Expr::Sort(expr::Sort::new(Box::new(col("name")), false, true)), + Expr::Sort(expr::Sort::new(Box::new(col("age")), true, true)), + Expr::Sort(expr::Sort::new(Box::new(col("created_at")), false, true)), + ]) + .window_frame(WindowFrame::new(Some(false))) + .build() + .unwrap(), ]; let expected = vec![ Expr::Sort(expr::Sort::new(Box::new(col("age")), true, true)), diff --git a/datafusion/expr/src/window_function.rs b/datafusion/expr/src/window_function.rs new file mode 100644 index 000000000000..5e81464d39c2 --- /dev/null +++ b/datafusion/expr/src/window_function.rs @@ -0,0 +1,99 @@ +// Licensed to the Apache Software Foundation 
(ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion_common::ScalarValue; + +use crate::{expr::WindowFunction, BuiltInWindowFunction, Expr, Literal}; + +/// Create an expression to represent the `row_number` window function +pub fn row_number() -> Expr { + Expr::WindowFunction(WindowFunction::new( + BuiltInWindowFunction::RowNumber, + vec![], + )) +} + +/// Create an expression to represent the `rank` window function +pub fn rank() -> Expr { + Expr::WindowFunction(WindowFunction::new(BuiltInWindowFunction::Rank, vec![])) +} + +/// Create an expression to represent the `dense_rank` window function +pub fn dense_rank() -> Expr { + Expr::WindowFunction(WindowFunction::new( + BuiltInWindowFunction::DenseRank, + vec![], + )) +} + +/// Create an expression to represent the `percent_rank` window function +pub fn percent_rank() -> Expr { + Expr::WindowFunction(WindowFunction::new( + BuiltInWindowFunction::PercentRank, + vec![], + )) +} + +/// Create an expression to represent the `cume_dist` window function +pub fn cume_dist() -> Expr { + Expr::WindowFunction(WindowFunction::new(BuiltInWindowFunction::CumeDist, vec![])) +} + +/// Create an expression to represent the `ntile` window function +pub fn ntile(arg: Expr) -> Expr { + Expr::WindowFunction(WindowFunction::new(BuiltInWindowFunction::Ntile, vec![arg])) +} + +/// Create an expression to represent the `lag` window function +pub fn lag( + arg: Expr, + shift_offset: Option, + default_value: Option, +) -> Expr { + let shift_offset_lit = shift_offset + .map(|v| v.lit()) + .unwrap_or(ScalarValue::Null.lit()); + let default_lit = default_value.unwrap_or(ScalarValue::Null).lit(); + Expr::WindowFunction(WindowFunction::new( + BuiltInWindowFunction::Lag, + vec![arg, shift_offset_lit, default_lit], + )) +} + +/// Create an expression to represent the `lead` window function +pub fn lead( + arg: Expr, + shift_offset: Option, + default_value: Option, +) -> Expr { + let shift_offset_lit = shift_offset + .map(|v| v.lit()) + .unwrap_or(ScalarValue::Null.lit()); + let default_lit = default_value.unwrap_or(ScalarValue::Null).lit(); + Expr::WindowFunction(WindowFunction::new( + BuiltInWindowFunction::Lead, + vec![arg, shift_offset_lit, default_lit], + )) +} + +/// Create an expression to represent the `nth_value` window function +pub fn nth_value(arg: Expr, n: i64) -> Expr { + Expr::WindowFunction(WindowFunction::new( + BuiltInWindowFunction::NthValue, + vec![arg, n.lit()], + )) +} diff --git a/datafusion/functions-aggregate/src/first_last.rs b/datafusion/functions-aggregate/src/first_last.rs index ba11f7e91e07..8969937d377c 100644 --- a/datafusion/functions-aggregate/src/first_last.rs +++ b/datafusion/functions-aggregate/src/first_last.rs @@ -31,8 +31,8 @@ use datafusion_common::{ use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; 
use datafusion_expr::utils::{format_state_name, AggregateOrderSensitivity}; use datafusion_expr::{ - Accumulator, AggregateExt, AggregateUDFImpl, ArrayFunctionSignature, Expr, Signature, - TypeSignature, Volatility, + Accumulator, AggregateUDFImpl, ArrayFunctionSignature, Expr, ExprFunctionExt, + Signature, TypeSignature, Volatility, }; use datafusion_physical_expr_common::aggregate::utils::get_sort_options; use datafusion_physical_expr_common::sort_expr::{ diff --git a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs index fa8aeb86ed31..338268e299da 100644 --- a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs +++ b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs @@ -101,6 +101,7 @@ mod tests { use arrow::datatypes::DataType; use datafusion_common::ScalarValue; use datafusion_expr::expr::Sort; + use datafusion_expr::ExprFunctionExt; use datafusion_expr::{ col, exists, expr, in_subquery, logical_plan::LogicalPlanBuilder, max, out_ref_col, scalar_subquery, wildcard, WindowFrame, WindowFrameBound, @@ -223,15 +224,14 @@ mod tests { .window(vec![Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::AggregateUDF(count_udaf()), vec![wildcard()], - vec![], - vec![Expr::Sort(Sort::new(Box::new(col("a")), false, true))], - WindowFrame::new_bounds( - WindowFrameUnits::Range, - WindowFrameBound::Preceding(ScalarValue::UInt32(Some(6))), - WindowFrameBound::Following(ScalarValue::UInt32(Some(2))), - ), - None, - ))])? + )) + .order_by(vec![Expr::Sort(Sort::new(Box::new(col("a")), false, true))]) + .window_frame(WindowFrame::new_bounds( + WindowFrameUnits::Range, + WindowFrameBound::Preceding(ScalarValue::UInt32(Some(6))), + WindowFrameBound::Following(ScalarValue::UInt32(Some(2))), + )) + .build()?])? .project(vec![count(wildcard())])? 
.build()?; diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index 50fb1b8193ce..75dbb4d1adcd 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -47,8 +47,9 @@ use datafusion_expr::type_coercion::{is_datetime, is_utf8_or_large_utf8}; use datafusion_expr::utils::merge_schema; use datafusion_expr::{ is_false, is_not_false, is_not_true, is_not_unknown, is_true, is_unknown, not, - type_coercion, AggregateFunction, AggregateUDF, Expr, ExprSchemable, LogicalPlan, - Operator, ScalarUDF, Signature, WindowFrame, WindowFrameBound, WindowFrameUnits, + type_coercion, AggregateFunction, AggregateUDF, Expr, ExprFunctionExt, ExprSchemable, + LogicalPlan, Operator, ScalarUDF, Signature, WindowFrame, WindowFrameBound, + WindowFrameUnits, }; use crate::analyzer::AnalyzerRule; @@ -466,14 +467,14 @@ impl<'a> TreeNodeRewriter for TypeCoercionRewriter<'a> { _ => args, }; - Ok(Transformed::yes(Expr::WindowFunction(WindowFunction::new( - fun, - args, - partition_by, - order_by, - window_frame, - null_treatment, - )))) + Ok(Transformed::yes( + Expr::WindowFunction(WindowFunction::new(fun, args)) + .partition_by(partition_by) + .order_by(order_by) + .window_frame(window_frame) + .null_treatment(null_treatment) + .build()?, + )) } Expr::Alias(_) | Expr::Column(_) diff --git a/datafusion/optimizer/src/optimize_projections/mod.rs b/datafusion/optimizer/src/optimize_projections/mod.rs index 58c1ae297b02..16abf93f3807 100644 --- a/datafusion/optimizer/src/optimize_projections/mod.rs +++ b/datafusion/optimizer/src/optimize_projections/mod.rs @@ -806,7 +806,7 @@ mod tests { use datafusion_common::{ Column, DFSchema, DFSchemaRef, JoinType, Result, TableReference, }; - use datafusion_expr::AggregateExt; + use datafusion_expr::ExprFunctionExt; use datafusion_expr::{ binary_expr, build_join_schema, builder::table_scan_with_filters, @@ -815,7 +815,7 @@ mod tests { lit, logical_plan::{builder::LogicalPlanBuilder, table_scan}, max, min, not, try_cast, when, AggregateFunction, BinaryExpr, Expr, Extension, - Like, LogicalPlan, Operator, Projection, UserDefinedLogicalNodeCore, WindowFrame, + Like, LogicalPlan, Operator, Projection, UserDefinedLogicalNodeCore, WindowFunctionDefinition, }; @@ -1919,19 +1919,14 @@ mod tests { let max1 = Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), vec![col("test.a")], - vec![col("test.b")], - vec![], - WindowFrame::new(None), - None, - )); + )) + .partition_by(vec![col("test.b")]) + .build() + .unwrap(); let max2 = Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), vec![col("test.b")], - vec![], - vec![], - WindowFrame::new(None), - None, )); let col1 = col(max1.display_name()?); let col2 = col(max2.display_name()?); diff --git a/datafusion/optimizer/src/replace_distinct_aggregate.rs b/datafusion/optimizer/src/replace_distinct_aggregate.rs index fcd33be618f7..430517121f2a 100644 --- a/datafusion/optimizer/src/replace_distinct_aggregate.rs +++ b/datafusion/optimizer/src/replace_distinct_aggregate.rs @@ -23,7 +23,7 @@ use datafusion_common::tree_node::Transformed; use datafusion_common::{Column, Result}; use datafusion_expr::expr_rewriter::normalize_cols; use datafusion_expr::utils::expand_wildcard; -use datafusion_expr::{col, AggregateExt, LogicalPlanBuilder}; +use datafusion_expr::{col, ExprFunctionExt, LogicalPlanBuilder}; use 
datafusion_expr::{Aggregate, Distinct, DistinctOn, Expr, LogicalPlan}; /// Optimizer that replaces logical [[Distinct]] with a logical [[Aggregate]] diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 56556f387d1b..38dfbb3ed551 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -3855,15 +3855,9 @@ mod tests { let udwf = WindowFunctionDefinition::WindowUDF( WindowUDF::new_from_impl(SimplifyMockUdwf::new_with_simplify()).into(), ); - let window_function_expr = - Expr::WindowFunction(datafusion_expr::expr::WindowFunction::new( - udwf, - vec![], - vec![], - vec![], - WindowFrame::new(None), - None, - )); + let window_function_expr = Expr::WindowFunction( + datafusion_expr::expr::WindowFunction::new(udwf, vec![]), + ); let expected = col("result_column"); assert_eq!(simplify(window_function_expr), expected); @@ -3871,15 +3865,9 @@ mod tests { let udwf = WindowFunctionDefinition::WindowUDF( WindowUDF::new_from_impl(SimplifyMockUdwf::new_without_simplify()).into(), ); - let window_function_expr = - Expr::WindowFunction(datafusion_expr::expr::WindowFunction::new( - udwf, - vec![], - vec![], - vec![], - WindowFrame::new(None), - None, - )); + let window_function_expr = Expr::WindowFunction( + datafusion_expr::expr::WindowFunction::new(udwf, vec![]), + ); let expected = window_function_expr.clone(); assert_eq!(simplify(window_function_expr), expected); diff --git a/datafusion/optimizer/src/single_distinct_to_groupby.rs b/datafusion/optimizer/src/single_distinct_to_groupby.rs index f2b4abdd6cbd..d776e6598cbe 100644 --- a/datafusion/optimizer/src/single_distinct_to_groupby.rs +++ b/datafusion/optimizer/src/single_distinct_to_groupby.rs @@ -354,7 +354,7 @@ mod tests { use super::*; use crate::test::*; use datafusion_expr::expr::{self, GroupingSet}; - use datafusion_expr::AggregateExt; + use datafusion_expr::ExprFunctionExt; use datafusion_expr::{ lit, logical_plan::builder::LogicalPlanBuilder, max, min, AggregateFunction, }; diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index aea8e454a31c..7b717add3311 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -25,6 +25,7 @@ use datafusion_common::{ use datafusion_expr::expr::Unnest; use datafusion_expr::expr::{Alias, Placeholder}; use datafusion_expr::window_frame::{check_window_frame, regularize_window_order_by}; +use datafusion_expr::ExprFunctionExt; use datafusion_expr::{ expr::{self, InList, Sort, WindowFunction}, logical_plan::{PlanType, StringifiedPlan}, @@ -299,7 +300,6 @@ pub fn parse_expr( ) })?; // TODO: support proto for null treatment - let null_treatment = None; regularize_window_order_by(&window_frame, &mut order_by)?; match window_function { @@ -314,11 +314,12 @@ pub fn parse_expr( "expr", codec, )?], - partition_by, - order_by, - window_frame, - None, - ))) + )) + .partition_by(partition_by) + .order_by(order_by) + .window_frame(window_frame) + .build() + .unwrap()) } window_expr_node::WindowFunction::BuiltInFunction(i) => { let built_in_function = protobuf::BuiltInWindowFunction::try_from(*i) @@ -335,11 +336,12 @@ pub fn parse_expr( built_in_function, ), args, - partition_by, - order_by, - window_frame, - null_treatment, - ))) + )) + .partition_by(partition_by) + .order_by(order_by) + .window_frame(window_frame) + 
.build() + .unwrap()) } window_expr_node::WindowFunction::Udaf(udaf_name) => { let udaf_function = match &expr.fun_definition { @@ -354,11 +356,12 @@ pub fn parse_expr( Ok(Expr::WindowFunction(WindowFunction::new( expr::WindowFunctionDefinition::AggregateUDF(udaf_function), args, - partition_by, - order_by, - window_frame, - None, - ))) + )) + .partition_by(partition_by) + .order_by(order_by) + .window_frame(window_frame) + .build() + .unwrap()) } window_expr_node::WindowFunction::Udwf(udwf_name) => { let udwf_function = match &expr.fun_definition { @@ -373,11 +376,12 @@ pub fn parse_expr( Ok(Expr::WindowFunction(WindowFunction::new( expr::WindowFunctionDefinition::WindowUDF(udwf_function), args, - partition_by, - order_by, - window_frame, - None, - ))) + )) + .partition_by(partition_by) + .order_by(order_by) + .window_frame(window_frame) + .build() + .unwrap()) } } } diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index 25223c3731be..7a4de4f61a38 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -60,7 +60,7 @@ use datafusion_expr::expr::{ }; use datafusion_expr::logical_plan::{Extension, UserDefinedLogicalNodeCore}; use datafusion_expr::{ - Accumulator, AggregateExt, AggregateFunction, AggregateUDF, ColumnarValue, + Accumulator, AggregateFunction, AggregateUDF, ColumnarValue, ExprFunctionExt, ExprSchemable, Literal, LogicalPlan, Operator, PartitionEvaluator, ScalarUDF, Signature, TryCast, Volatility, WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, WindowUDF, WindowUDFImpl, @@ -2073,11 +2073,12 @@ fn roundtrip_window() { datafusion_expr::BuiltInWindowFunction::Rank, ), vec![], - vec![col("col1")], - vec![col("col2")], - WindowFrame::new(Some(false)), - None, - )); + )) + .partition_by(vec![col("col1")]) + .order_by(vec![col("col2").sort(true, false)]) + .window_frame(WindowFrame::new(Some(false))) + .build() + .unwrap(); // 2. with default window_frame let test_expr2 = Expr::WindowFunction(expr::WindowFunction::new( @@ -2085,11 +2086,12 @@ fn roundtrip_window() { datafusion_expr::BuiltInWindowFunction::Rank, ), vec![], - vec![col("col1")], - vec![col("col2")], - WindowFrame::new(Some(false)), - None, - )); + )) + .partition_by(vec![col("col1")]) + .order_by(vec![col("col2").sort(false, true)]) + .window_frame(WindowFrame::new(Some(false))) + .build() + .unwrap(); // 3. with window_frame with row numbers let range_number_frame = WindowFrame::new_bounds( @@ -2103,11 +2105,12 @@ fn roundtrip_window() { datafusion_expr::BuiltInWindowFunction::Rank, ), vec![], - vec![col("col1")], - vec![col("col2")], - range_number_frame, - None, - )); + )) + .partition_by(vec![col("col1")]) + .order_by(vec![col("col2").sort(false, false)]) + .window_frame(range_number_frame) + .build() + .unwrap(); // 4. test with AggregateFunction let row_number_frame = WindowFrame::new_bounds( @@ -2119,11 +2122,12 @@ fn roundtrip_window() { let test_expr4 = Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), vec![col("col1")], - vec![col("col1")], - vec![col("col2")], - row_number_frame.clone(), - None, - )); + )) + .partition_by(vec![col("col1")]) + .order_by(vec![col("col2").sort(true, true)]) + .window_frame(row_number_frame.clone()) + .build() + .unwrap(); // 5. 
test with AggregateUDF #[derive(Debug)] @@ -2168,11 +2172,12 @@ fn roundtrip_window() { let test_expr5 = Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::AggregateUDF(Arc::new(dummy_agg.clone())), vec![col("col1")], - vec![col("col1")], - vec![col("col2")], - row_number_frame.clone(), - None, - )); + )) + .partition_by(vec![col("col1")]) + .order_by(vec![col("col2").sort(true, true)]) + .window_frame(row_number_frame.clone()) + .build() + .unwrap(); ctx.register_udaf(dummy_agg); // 6. test with WindowUDF @@ -2244,20 +2249,20 @@ fn roundtrip_window() { let test_expr6 = Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::WindowUDF(Arc::new(dummy_window_udf.clone())), vec![col("col1")], - vec![col("col1")], - vec![col("col2")], - row_number_frame.clone(), - None, - )); + )) + .partition_by(vec![col("col1")]) + .order_by(vec![col("col2").sort(true, true)]) + .window_frame(row_number_frame.clone()) + .build() + .unwrap(); let text_expr7 = Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::AggregateUDF(avg_udaf()), vec![col("col1")], - vec![], - vec![], - row_number_frame.clone(), - None, - )); + )) + .window_frame(row_number_frame.clone()) + .build() + .unwrap(); ctx.register_udwf(dummy_window_udf); diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index 0c4b125e76d0..fd759c161381 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -24,7 +24,8 @@ use datafusion_common::{ use datafusion_expr::planner::PlannerResult; use datafusion_expr::window_frame::{check_window_frame, regularize_window_order_by}; use datafusion_expr::{ - expr, AggregateFunction, Expr, ExprSchemable, WindowFrame, WindowFunctionDefinition, + expr, AggregateFunction, Expr, ExprFunctionExt, ExprSchemable, WindowFrame, + WindowFunctionDefinition, }; use datafusion_expr::{ expr::{ScalarFunction, Unnest}, @@ -329,20 +330,24 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::AggregateFunction(aggregate_fun), args, - partition_by, - order_by, - window_frame, - null_treatment, )) + .partition_by(partition_by) + .order_by(order_by) + .window_frame(window_frame) + .null_treatment(null_treatment) + .build() + .unwrap() } _ => Expr::WindowFunction(expr::WindowFunction::new( fun, self.function_args_to_expr(args, schema, planner_context)?, - partition_by, - order_by, - window_frame, - null_treatment, - )), + )) + .partition_by(partition_by) + .order_by(order_by) + .window_frame(window_frame) + .null_treatment(null_treatment) + .build() + .unwrap(), }; return Ok(expr); } diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index f4ea44f37d78..3f7a85da276b 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -1507,7 +1507,7 @@ mod tests { table_scan, try_cast, when, wildcard, ColumnarValue, ScalarUDF, ScalarUDFImpl, Signature, Volatility, WindowFrame, WindowFunctionDefinition, }; - use datafusion_expr::{interval_month_day_nano_lit, AggregateExt}; + use datafusion_expr::{interval_month_day_nano_lit, ExprFunctionExt}; use datafusion_functions_aggregate::count::count_udaf; use datafusion_functions_aggregate::expr_fn::sum; diff --git a/docs/source/user-guide/expressions.md b/docs/source/user-guide/expressions.md index 6e693a0e7087..60036e440ffb 100644 --- a/docs/source/user-guide/expressions.md +++ b/docs/source/user-guide/expressions.md @@ -308,7 +308,7 
@@ select log(-1), log(0), sqrt(-1); ## Aggregate Function Builder -You can also use the `AggregateExt` trait to more easily build Aggregate arguments `Expr`. +You can also use the `ExprFunctionExt` trait to more easily build Aggregate arguments `Expr`. See `datafusion-examples/examples/expr_api.rs` for example usage. From c9518245fa8138b19402ac7f124d9091adad2426 Mon Sep 17 00:00:00 2001 From: Lordworms <48054792+Lordworms@users.noreply.github.com> Date: Wed, 24 Jul 2024 19:42:47 -0700 Subject: [PATCH 13/17] Parsing SQL strings to Exprs with the qualified schema (#11562) * Parsing SQL strings to Exprs wtih the qualified schema * refactor code --- .../core/tests/expr_api/parse_sql_expr.rs | 16 ++++- .../optimizer/tests/optimizer_integration.rs | 2 +- datafusion/sql/src/expr/identifier.rs | 60 +++++++++---------- datafusion/sql/tests/sql_integration.rs | 2 +- .../sqllogictest/test_files/group_by.slt | 2 +- 5 files changed, 46 insertions(+), 36 deletions(-) diff --git a/datafusion/core/tests/expr_api/parse_sql_expr.rs b/datafusion/core/tests/expr_api/parse_sql_expr.rs index 991579b5a350..a3defceee247 100644 --- a/datafusion/core/tests/expr_api/parse_sql_expr.rs +++ b/datafusion/core/tests/expr_api/parse_sql_expr.rs @@ -17,10 +17,12 @@ use arrow_schema::{DataType, Field, Schema}; use datafusion::prelude::{CsvReadOptions, SessionContext}; +use datafusion_common::DFSchema; use datafusion_common::{DFSchemaRef, Result, ToDFSchema}; +use datafusion_expr::col; +use datafusion_expr::lit; use datafusion_expr::Expr; use datafusion_sql::unparser::Unparser; - /// A schema like: /// /// a: Int32 (possibly with nulls) @@ -85,6 +87,18 @@ async fn round_trip_dataframe(sql: &str) -> Result<()> { Ok(()) } +#[tokio::test] +async fn roundtrip_qualified_schema() -> Result<()> { + let sql = "a < 5 OR a = 8"; + let expr = col("t.a").lt(lit(5_i64)).or(col("t.a").eq(lit(8_i64))); + let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]); + let df_schema = DFSchema::try_from_qualified_schema("t", &schema).unwrap(); + let ctx = SessionContext::new(); + let parsed_expr = ctx.parse_sql_expr(sql, &df_schema)?; + assert_eq!(parsed_expr, expr); + Ok(()) +} + fn unparse_sql_expr(expr: &Expr) -> Result { let unparser = Unparser::default(); diff --git a/datafusion/optimizer/tests/optimizer_integration.rs b/datafusion/optimizer/tests/optimizer_integration.rs index c0863839dba1..3c77ffaa17f6 100644 --- a/datafusion/optimizer/tests/optimizer_integration.rs +++ b/datafusion/optimizer/tests/optimizer_integration.rs @@ -109,7 +109,7 @@ fn distribute_by() -> Result<()> { // regression test for https://github.com/apache/datafusion/issues/3234 let sql = "SELECT col_int32, col_utf8 FROM test DISTRIBUTE BY (col_utf8)"; let plan = test_sql(sql)?; - let expected = "Repartition: DistributeBy(col_utf8)\ + let expected = "Repartition: DistributeBy(test.col_utf8)\ \n TableScan: test projection=[col_int32, col_utf8]"; assert_eq!(expected, format!("{plan:?}")); Ok(()) diff --git a/datafusion/sql/src/expr/identifier.rs b/datafusion/sql/src/expr/identifier.rs index f8979bde3086..9b8356701a40 100644 --- a/datafusion/sql/src/expr/identifier.rs +++ b/datafusion/sql/src/expr/identifier.rs @@ -26,6 +26,7 @@ use datafusion_expr::planner::PlannerResult; use datafusion_expr::{Case, Expr}; use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; +use datafusion_expr::UNNAMED_TABLE; impl<'a, S: ContextProvider> SqlToRel<'a, S> { pub(super) fn sql_identifier_to_expr( @@ -50,40 +51,35 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // 
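The new `roundtrip_qualified_schema` test above is the user-visible contract of patch 13: parsing a SQL expression against a schema that carries a table qualifier now resolves bare identifiers to qualified columns instead of panicking on or mis-resolving them. A usage sketch along the same lines (the table name `t` is just the qualifier chosen for the example):

```rust
use arrow_schema::{DataType, Field, Schema};
use datafusion::prelude::SessionContext;
use datafusion_common::{DFSchema, Result};
use datafusion_expr::{col, lit};

fn parse_against_qualified_schema() -> Result<()> {
    let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]);
    // Attach the qualifier "t"; identifier resolution carries it through.
    let df_schema = DFSchema::try_from_qualified_schema("t", &schema)?;
    let ctx = SessionContext::new();
    let parsed = ctx.parse_sql_expr("a < 5 OR a = 8", &df_schema)?;
    // The bare `a` comes back as the qualified column `t.a`.
    assert_eq!(parsed, col("t.a").lt(lit(5_i64)).or(col("t.a").eq(lit(8_i64))));
    Ok(())
}
```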
compound identifiers, but this is not a compound // identifier. (e.g. it is "foo.bar" not foo.bar) let normalize_ident = self.normalizer.normalize(id); - match schema.field_with_unqualified_name(normalize_ident.as_str()) { - Ok(_) => { - // found a match without a qualified name, this is a inner table column - Ok(Expr::Column(Column { - relation: None, - name: normalize_ident, - })) - } - Err(_) => { - // check the outer_query_schema and try to find a match - if let Some(outer) = planner_context.outer_query_schema() { - match outer.qualified_field_with_unqualified_name( - normalize_ident.as_str(), - ) { - Ok((qualifier, field)) => { - // found an exact match on a qualified name in the outer plan schema, so this is an outer reference column - Ok(Expr::OuterReferenceColumn( - field.data_type().clone(), - Column::from((qualifier, field)), - )) - } - Err(_) => Ok(Expr::Column(Column { - relation: None, - name: normalize_ident, - })), - } - } else { - Ok(Expr::Column(Column { - relation: None, - name: normalize_ident, - })) - } + + // Check for qualified field with unqualified name + if let Ok((qualifier, _)) = + schema.qualified_field_with_unqualified_name(normalize_ident.as_str()) + { + return Ok(Expr::Column(Column { + relation: qualifier.filter(|q| q.table() != UNNAMED_TABLE).cloned(), + name: normalize_ident, + })); + } + + // Check the outer query schema + if let Some(outer) = planner_context.outer_query_schema() { + if let Ok((qualifier, field)) = + outer.qualified_field_with_unqualified_name(normalize_ident.as_str()) + { + // Found an exact match on a qualified name in the outer plan schema, so this is an outer reference column + return Ok(Expr::OuterReferenceColumn( + field.data_type().clone(), + Column::from((qualifier, field)), + )); } } + + // Default case + Ok(Expr::Column(Column { + relation: None, + name: normalize_ident, + })) } } diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs index 3291560383df..511f97c4750e 100644 --- a/datafusion/sql/tests/sql_integration.rs +++ b/datafusion/sql/tests/sql_integration.rs @@ -3274,7 +3274,7 @@ fn test_offset_before_limit() { #[test] fn test_distribute_by() { let sql = "select id from person distribute by state"; - let expected = "Repartition: DistributeBy(state)\ + let expected = "Repartition: DistributeBy(person.state)\ \n Projection: person.id\ \n TableScan: person"; quick_test(sql, expected); diff --git a/datafusion/sqllogictest/test_files/group_by.slt b/datafusion/sqllogictest/test_files/group_by.slt index a3cc10e1eeb8..b7d466d8bf82 100644 --- a/datafusion/sqllogictest/test_files/group_by.slt +++ b/datafusion/sqllogictest/test_files/group_by.slt @@ -4077,7 +4077,7 @@ FROM (SELECT c, b, a, SUM(d) as sum1 DISTRIBUTE BY a ---- logical_plan -01)Repartition: DistributeBy(a) +01)Repartition: DistributeBy(multiple_ordered_table_with_pk.a) 02)--Projection: multiple_ordered_table_with_pk.a, multiple_ordered_table_with_pk.b, sum(multiple_ordered_table_with_pk.d) AS sum1 03)----Aggregate: groupBy=[[multiple_ordered_table_with_pk.c, multiple_ordered_table_with_pk.a, multiple_ordered_table_with_pk.b]], aggr=[[sum(CAST(multiple_ordered_table_with_pk.d AS Int64))]] 04)------TableScan: multiple_ordered_table_with_pk projection=[a, b, c, d] From f12b3db9c3a0507ed0bc7984ce0e290be0ca9e2d Mon Sep 17 00:00:00 2001 From: Michael J Ward Date: Thu, 25 Jul 2024 03:20:54 -0500 Subject: [PATCH 14/17] fix: expose the fluent API fn for approx_distinct instead of the module (#11644) * fix: expose the fluent API fn for 
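The rewrite flattens the old nested `match` into three early-return steps: a qualified-field lookup in the current schema, then the outer query schema, then a plain unqualified column. A condensed sketch of just the first step (the real code also drops the `UNNAMED_TABLE` qualifier and emits `Expr::OuterReferenceColumn` for outer-schema matches):

```rust
use arrow_schema::{DataType, Field, Schema};
use datafusion_common::{Column, DFSchema, Result};

/// Resolve a bare identifier, preferring the qualifier recorded in the schema.
fn resolve(name: &str, schema: &DFSchema) -> Column {
    match schema.qualified_field_with_unqualified_name(name) {
        Ok((qualifier, field)) => Column::from((qualifier, field)),
        Err(_) => Column::new_unqualified(name),
    }
}

fn main() -> Result<()> {
    let schema = Schema::new(vec![Field::new("state", DataType::Utf8, true)]);
    let df_schema = DFSchema::try_from_qualified_schema("person", &schema)?;
    // This is why DISTRIBUTE BY now plans as DistributeBy(person.state)
    // in the test expectations updated below.
    assert_eq!(resolve("state", &df_schema).flat_name(), "person.state");
    Ok(())
}
```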
approx_distinct instead of the module Fixes: https://github.com/apache/datafusion/issues/11643 * add approx_distinct to roundtrip_expr_api test * lint: cargo fmt --- datafusion/functions-aggregate/src/lib.rs | 2 +- datafusion/proto/tests/cases/roundtrip_logical_plan.rs | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/datafusion/functions-aggregate/src/lib.rs b/datafusion/functions-aggregate/src/lib.rs index b39b1955bb07..32ca05b8cdd9 100644 --- a/datafusion/functions-aggregate/src/lib.rs +++ b/datafusion/functions-aggregate/src/lib.rs @@ -90,7 +90,7 @@ use std::sync::Arc; /// Fluent-style API for creating `Expr`s pub mod expr_fn { - pub use super::approx_distinct; + pub use super::approx_distinct::approx_distinct; pub use super::approx_median::approx_median; pub use super::approx_percentile_cont::approx_percentile_cont; pub use super::approx_percentile_cont_with_weight::approx_percentile_cont_with_weight; diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index 7a4de4f61a38..9c81c4852783 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -67,7 +67,7 @@ use datafusion_expr::{ }; use datafusion_functions_aggregate::average::avg_udaf; use datafusion_functions_aggregate::expr_fn::{ - array_agg, avg, bit_and, bit_or, bit_xor, bool_and, bool_or, corr, + approx_distinct, array_agg, avg, bit_and, bit_or, bit_xor, bool_and, bool_or, corr, }; use datafusion_functions_aggregate::string_agg::string_agg; use datafusion_proto::bytes::{ @@ -717,6 +717,7 @@ async fn roundtrip_expr_api() -> Result<()> { var_pop(lit(2.2)), stddev(lit(2.2)), stddev_pop(lit(2.2)), + approx_distinct(lit(2)), approx_median(lit(2)), approx_percentile_cont(lit(2), lit(0.5)), approx_percentile_cont_with_weight(lit(2), lit(1), lit(0.5)), From 49d9d45f36989cd448ed6513af65948b6b0100ec Mon Sep 17 00:00:00 2001 From: kf zheng <100595273+Kev1n8@users.noreply.github.com> Date: Thu, 25 Jul 2024 16:53:05 +0800 Subject: [PATCH 15/17] Add some zero column tests covering LIMIT, GROUP BY, WHERE, JOIN, and WINDOW (#11624) * add zero column tests covering LIMIT, GROUP BY, WHERE, JOIN, and WINDOW * change from statement to query to be explicit about no rows * Revert "change from statement to query to be explicit about no rows" This reverts commit fd381fca5e9d80f62062c41c4326e4cbe50b2129. 
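The one-line fix above is easy to miss: `pub use super::approx_distinct;` re-exported the *module* of that name, so `expr_fn::approx_distinct(...)` was not callable as a function; re-exporting `approx_distinct::approx_distinct` restores the fluent style. A sketch of the call it unblocks:

```rust
use datafusion_expr::{col, Expr};
use datafusion_functions_aggregate::expr_fn::approx_distinct;

/// Build an approximate-distinct-count aggregate over column "a".
fn distinct_count() -> Expr {
    approx_distinct(col("a"))
}
```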
--------- Co-authored-by: Andrew Lamb --- datafusion/sqllogictest/test_files/select.slt | 57 +++++++++++++++++++ 1 file changed, 57 insertions(+) diff --git a/datafusion/sqllogictest/test_files/select.slt b/datafusion/sqllogictest/test_files/select.slt index 6884efc07e15..a5f31cb9b466 100644 --- a/datafusion/sqllogictest/test_files/select.slt +++ b/datafusion/sqllogictest/test_files/select.slt @@ -1225,6 +1225,63 @@ statement ok SELECT * EXCEPT(a, b, c, d) FROM table1 +# try zero column with LIMIT, 1 row but empty +statement ok +SELECT * EXCEPT (a, b, c, d) +FROM table1 +LIMIT 1 + +# try zero column with GROUP BY, 2 row but empty +statement ok +SELECT * EXCEPT (a, b, c, d) +FROM table1 +GROUP BY a + +# try zero column with WHERE, 1 row but empty +statement ok +SELECT * EXCEPT (a, b, c, d) +FROM table1 +WHERE a = 1 + +# create table2 the same with table1 +statement ok +CREATE TABLE table2 ( + a int, + b int, + c int, + d int +) as values + (1, 10, 100, 1000), + (2, 20, 200, 2000); + +# try zero column with inner JOIN, 2 row but empty +statement ok +WITH t1 AS (SELECT a AS t1_a FROM table1), t2 AS (SELECT a AS t2_a FROM table2) +SELECT * EXCEPT (t1_a, t2_a) +FROM t1 +JOIN t2 ON (t1_a = t2_a) + +# try zero column with more JOIN, 2 row but empty +statement ok +SELECT * EXCEPT (b1, b2) +FROM ( + SELECT b AS b1 FROM table1 +) +JOIN ( + SELECT b AS b2 FROM table2 +) ON b1 = b2 + +# try zero column with Window, 2 row but empty +statement ok +SELECT * EXCEPT (a, b, row_num) +FROM ( + SELECT + a, + b, + ROW_NUMBER() OVER (ORDER BY b) AS row_num + FROM table1 +) + # EXCLUDE order shouldn't matter query II SELECT * EXCLUDE(b, a) From 7db4213b71ed9e914c5a4f16954abfa20b091ae3 Mon Sep 17 00:00:00 2001 From: Mehmet Ozan Kabak Date: Thu, 25 Jul 2024 14:55:58 +0300 Subject: [PATCH 16/17] Refactor/simplify window frame utils (#11648) * Simplify window frame utils * Remove unwrap calls * Fix format * Incorporate review feedback --- .../core/tests/fuzz_cases/window_fuzz.rs | 64 ++++++++----- datafusion/expr/src/window_frame.rs | 89 +++++++++---------- .../proto/src/logical_plan/from_proto.rs | 38 ++++---- datafusion/sql/src/expr/function.rs | 19 ++-- 4 files changed, 105 insertions(+), 105 deletions(-) diff --git a/datafusion/core/tests/fuzz_cases/window_fuzz.rs b/datafusion/core/tests/fuzz_cases/window_fuzz.rs index 5bd19850cacc..c97621ec4d01 100644 --- a/datafusion/core/tests/fuzz_cases/window_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/window_fuzz.rs @@ -17,7 +17,7 @@ use std::sync::Arc; -use arrow::array::{ArrayRef, Int32Array}; +use arrow::array::{ArrayRef, Int32Array, StringArray}; use arrow::compute::{concat_batches, SortOptions}; use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; @@ -45,6 +45,7 @@ use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr}; use test_utils::add_empty_batches; use hashbrown::HashMap; +use rand::distributions::Alphanumeric; use rand::rngs::StdRng; use rand::{Rng, SeedableRng}; @@ -607,25 +608,6 @@ fn convert_bound_to_current_row_if_applicable( } } -/// This utility determines whether a given window frame can be executed with -/// multiple ORDER BY expressions. 
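All of the new sqllogictest cases above exercise plans whose output schema is empty once `EXCEPT` removes every column. A hypothetical Rust reproduction of the LIMIT case, assuming DDL through `SessionContext::sql` and that zero-column batches survive `collect` with their row counts intact:

```rust
use datafusion::prelude::SessionContext;

#[tokio::main]
async fn main() -> datafusion::error::Result<()> {
    let ctx = SessionContext::new();
    ctx.sql("CREATE TABLE t (a INT, b INT) AS VALUES (1, 10), (2, 20)")
        .await?;
    // Every column is excluded, so the surviving batches carry no fields.
    let batches = ctx
        .sql("SELECT * EXCEPT (a, b) FROM t LIMIT 1")
        .await?
        .collect()
        .await?;
    assert!(batches.iter().all(|batch| batch.num_columns() == 0));
    Ok(())
}
```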
As an example, range frames with offset (such -/// as `RANGE BETWEEN 1 PRECEDING AND 1 FOLLOWING`) cannot have ORDER BY clauses -/// of the form `\[ORDER BY a ASC, b ASC, ...]` -fn can_accept_multi_orderby(window_frame: &WindowFrame) -> bool { - match window_frame.units { - WindowFrameUnits::Rows => true, - WindowFrameUnits::Range => { - // Range can only accept multi ORDER BY clauses when bounds are - // CURRENT ROW or UNBOUNDED PRECEDING/FOLLOWING: - (window_frame.start_bound.is_unbounded() - || window_frame.start_bound == WindowFrameBound::CurrentRow) - && (window_frame.end_bound.is_unbounded() - || window_frame.end_bound == WindowFrameBound::CurrentRow) - } - WindowFrameUnits::Groups => true, - } -} - /// Perform batch and running window same input /// and verify outputs of `WindowAggExec` and `BoundedWindowAggExec` are equal async fn run_window_test( @@ -649,7 +631,7 @@ async fn run_window_test( options: SortOptions::default(), }) } - if orderby_exprs.len() > 1 && !can_accept_multi_orderby(&window_frame) { + if orderby_exprs.len() > 1 && !window_frame.can_accept_multi_orderby() { orderby_exprs = orderby_exprs[0..1].to_vec(); } let mut partitionby_exprs = vec![]; @@ -733,11 +715,30 @@ async fn run_window_test( )?) as _; let task_ctx = ctx.task_ctx(); let collected_usual = collect(usual_window_exec, task_ctx.clone()).await?; - let collected_running = collect(running_window_exec, task_ctx).await?; + let collected_running = collect(running_window_exec, task_ctx) + .await? + .into_iter() + .filter(|b| b.num_rows() > 0) + .collect::>(); // BoundedWindowAggExec should produce more chunk than the usual WindowAggExec. // Otherwise it means that we cannot generate result in running mode. - assert!(collected_running.len() > collected_usual.len()); + let err_msg = format!("Inconsistent result for window_frame: {window_frame:?}, window_fn: {window_fn:?}, args:{args:?}, random_seed: {random_seed:?}, search_mode: {search_mode:?}, partition_by_columns:{partition_by_columns:?}, orderby_columns: {orderby_columns:?}"); + // Below check makes sure that, streaming execution generates more chunks than the bulk execution. + // Since algorithms and operators works on sliding windows in the streaming execution. + // However, in the current test setup for some random generated window frame clauses: It is not guaranteed + // for streaming execution to generate more chunk than its non-streaming counter part in the Linear mode. + // As an example window frame `OVER(PARTITION BY d ORDER BY a RANGE BETWEEN CURRENT ROW AND 9 FOLLOWING)` + // needs to receive a=10 to generate result for the rows where a=0. If the input data generated is between the range [0, 9]. + // even in streaming mode, generated result will be single bulk as in the non-streaming version. 
+ if search_mode != Linear { + assert!( + collected_running.len() > collected_usual.len(), + "{}", + err_msg + ); + } + // compare let usual_formatted = pretty_format_batches(&collected_usual)?.to_string(); let running_formatted = pretty_format_batches(&collected_running)?.to_string(); @@ -767,10 +768,17 @@ async fn run_window_test( Ok(()) } +fn generate_random_string(rng: &mut StdRng, length: usize) -> String { + rng.sample_iter(&Alphanumeric) + .take(length) + .map(char::from) + .collect() +} + /// Return randomly sized record batches with: /// three sorted int32 columns 'a', 'b', 'c' ranged from 0..DISTINCT as columns /// one random int32 column x -fn make_staggered_batches( +pub(crate) fn make_staggered_batches( len: usize, n_distinct: usize, random_seed: u64, @@ -779,6 +787,7 @@ fn make_staggered_batches( let mut rng = StdRng::seed_from_u64(random_seed); let mut input123: Vec<(i32, i32, i32)> = vec![(0, 0, 0); len]; let mut input4: Vec = vec![0; len]; + let mut input5: Vec = vec!["".to_string(); len]; input123.iter_mut().for_each(|v| { *v = ( rng.gen_range(0..n_distinct) as i32, @@ -788,10 +797,15 @@ fn make_staggered_batches( }); input123.sort(); rng.fill(&mut input4[..]); + input5.iter_mut().for_each(|v| { + *v = generate_random_string(&mut rng, 1); + }); + input5.sort(); let input1 = Int32Array::from_iter_values(input123.iter().map(|k| k.0)); let input2 = Int32Array::from_iter_values(input123.iter().map(|k| k.1)); let input3 = Int32Array::from_iter_values(input123.iter().map(|k| k.2)); let input4 = Int32Array::from_iter_values(input4); + let input5 = StringArray::from_iter_values(input5); // split into several record batches let mut remainder = RecordBatch::try_from_iter(vec![ @@ -799,6 +813,7 @@ fn make_staggered_batches( ("b", Arc::new(input2) as ArrayRef), ("c", Arc::new(input3) as ArrayRef), ("x", Arc::new(input4) as ArrayRef), + ("string_field", Arc::new(input5) as ArrayRef), ]) .unwrap(); @@ -807,6 +822,7 @@ fn make_staggered_batches( while remainder.num_rows() > 0 { let batch_size = rng.gen_range(0..50); if remainder.num_rows() < batch_size { + batches.push(remainder); break; } batches.push(remainder.slice(0, batch_size)); diff --git a/datafusion/expr/src/window_frame.rs b/datafusion/expr/src/window_frame.rs index c0617eaf4ed4..5b2f8982a559 100644 --- a/datafusion/expr/src/window_frame.rs +++ b/datafusion/expr/src/window_frame.rs @@ -26,8 +26,7 @@ use std::fmt::{self, Formatter}; use std::hash::Hash; -use crate::expr::Sort; -use crate::Expr; +use crate::{lit, Expr}; use datafusion_common::{plan_err, sql_err, DataFusionError, Result, ScalarValue}; use sqlparser::ast; @@ -246,59 +245,51 @@ impl WindowFrame { causal, } } -} -/// Regularizes ORDER BY clause for window definition for implicit corner cases. -pub fn regularize_window_order_by( - frame: &WindowFrame, - order_by: &mut Vec, -) -> Result<()> { - if frame.units == WindowFrameUnits::Range && order_by.len() != 1 { - // Normally, RANGE frames require an ORDER BY clause with exactly one - // column. However, an ORDER BY clause may be absent or present but with - // more than one column in two edge cases: - // 1. start bound is UNBOUNDED or CURRENT ROW - // 2. end bound is CURRENT ROW or UNBOUNDED. - // In these cases, we regularize the ORDER BY clause if the ORDER BY clause - // is absent. If an ORDER BY clause is present but has more than one column, - // the ORDER BY clause is unchanged. Note that this follows Postgres behavior. 
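The fuzz driver above now asks the frame itself via `window_frame.can_accept_multi_orderby()`, the method form of the deleted free function (its definition lands in the `window_frame.rs` hunks just below). A sketch of the call shape, using a hypothetical truncation helper modeled on the adjustment in `run_window_test`:

```rust
use datafusion_expr::WindowFrame;

/// Keep only one ORDER BY key when the frame cannot handle several.
fn truncate_order_bys<T>(frame: &WindowFrame, order_bys: &mut Vec<T>) {
    // ROWS and GROUPS frames accept multiple keys; RANGE frames only do
    // when both bounds are UNBOUNDED or CURRENT ROW ("free range").
    if order_bys.len() > 1 && !frame.can_accept_multi_orderby() {
        order_bys.truncate(1);
    }
}
```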
- if (frame.start_bound.is_unbounded() - || frame.start_bound == WindowFrameBound::CurrentRow) - && (frame.end_bound == WindowFrameBound::CurrentRow - || frame.end_bound.is_unbounded()) - { - // If an ORDER BY clause is absent, it is equivalent to a ORDER BY clause - // with constant value as sort key. - // If an ORDER BY clause is present but has more than one column, it is - // unchanged. - if order_by.is_empty() { - order_by.push(Expr::Sort(Sort::new( - Box::new(Expr::Literal(ScalarValue::UInt64(Some(1)))), - true, - false, - ))); + /// Regularizes the ORDER BY clause of the window frame. + pub fn regularize_order_bys(&self, order_by: &mut Vec) -> Result<()> { + match self.units { + // Normally, RANGE frames require an ORDER BY clause with exactly + // one column. However, an ORDER BY clause may be absent or have + // more than one column when the start/end bounds are UNBOUNDED or + // CURRENT ROW. + WindowFrameUnits::Range if self.free_range() => { + // If an ORDER BY clause is absent, it is equivalent to an + // ORDER BY clause with constant value as sort key. If an + // ORDER BY clause is present but has more than one column, + // it is unchanged. Note that this follows PostgreSQL behavior. + if order_by.is_empty() { + order_by.push(lit(1u64).sort(true, false)); + } + } + WindowFrameUnits::Range if order_by.len() != 1 => { + return plan_err!("RANGE requires exactly one ORDER BY column"); } + WindowFrameUnits::Groups if order_by.is_empty() => { + return plan_err!("GROUPS requires an ORDER BY clause"); + } + _ => {} } + Ok(()) } - Ok(()) -} -/// Checks if given window frame is valid. In particular, if the frame is RANGE -/// with offset PRECEDING/FOLLOWING, it must have exactly one ORDER BY column. -pub fn check_window_frame(frame: &WindowFrame, order_bys: usize) -> Result<()> { - if frame.units == WindowFrameUnits::Range && order_bys != 1 { - // See `regularize_window_order_by`. - if !(frame.start_bound.is_unbounded() - || frame.start_bound == WindowFrameBound::CurrentRow) - || !(frame.end_bound == WindowFrameBound::CurrentRow - || frame.end_bound.is_unbounded()) - { - plan_err!("RANGE requires exactly one ORDER BY column")? + /// Returns whether the window frame can accept multiple ORDER BY expressons. + pub fn can_accept_multi_orderby(&self) -> bool { + match self.units { + WindowFrameUnits::Rows => true, + WindowFrameUnits::Range => self.free_range(), + WindowFrameUnits::Groups => true, } - } else if frame.units == WindowFrameUnits::Groups && order_bys == 0 { - plan_err!("GROUPS requires an ORDER BY clause")? - }; - Ok(()) + } + + /// Returns whether the window frame is "free range"; i.e. its start/end + /// bounds are UNBOUNDED or CURRENT ROW. 
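With `regularize_order_bys`, the two old free functions (`regularize_window_order_by` and `check_window_frame`) collapse into one fallible method on `WindowFrame`. A sketch of a caller, with the outcomes per frame kind noted as they appear in the new code:

```rust
use datafusion_common::Result;
use datafusion_expr::{Expr, WindowFrame};

/// Normalize the ORDER BY list for a window frame before planning.
fn normalize(frame: &WindowFrame, mut order_by: Vec<Expr>) -> Result<Vec<Expr>> {
    // - free-range RANGE frame, empty list: pushes lit(1u64).sort(true, false)
    //   as a constant sort key (PostgreSQL behavior)
    // - RANGE frame with offsets and a key count other than one: plan error
    // - GROUPS frame with no keys: plan error
    // - ROWS frames: left untouched
    frame.regularize_order_bys(&mut order_by)?;
    Ok(order_by)
}
```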
+ fn free_range(&self) -> bool { + (self.start_bound.is_unbounded() + || self.start_bound == WindowFrameBound::CurrentRow) + && (self.end_bound.is_unbounded() + || self.end_bound == WindowFrameBound::CurrentRow) + } } /// There are five ways to describe starting and ending frame boundaries: diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 7b717add3311..5e9b9af49ae9 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -19,18 +19,14 @@ use std::sync::Arc; use datafusion::execution::registry::FunctionRegistry; use datafusion_common::{ - internal_err, plan_datafusion_err, DataFusionError, Result, ScalarValue, + exec_datafusion_err, internal_err, plan_datafusion_err, Result, ScalarValue, TableReference, UnnestOptions, }; -use datafusion_expr::expr::Unnest; -use datafusion_expr::expr::{Alias, Placeholder}; -use datafusion_expr::window_frame::{check_window_frame, regularize_window_order_by}; -use datafusion_expr::ExprFunctionExt; use datafusion_expr::{ - expr::{self, InList, Sort, WindowFunction}, + expr::{self, Alias, InList, Placeholder, Sort, Unnest, WindowFunction}, logical_plan::{PlanType, StringifiedPlan}, AggregateFunction, Between, BinaryExpr, BuiltInWindowFunction, Case, Cast, Expr, - GroupingSet, + ExprFunctionExt, GroupingSet, GroupingSet::GroupingSets, JoinConstraint, JoinType, Like, Operator, TryCast, WindowFrame, WindowFrameBound, WindowFrameUnits, @@ -289,24 +285,22 @@ pub fn parse_expr( .window_frame .as_ref() .map::, _>(|window_frame| { - let window_frame = window_frame.clone().try_into()?; - check_window_frame(&window_frame, order_by.len()) + let window_frame: WindowFrame = window_frame.clone().try_into()?; + window_frame + .regularize_order_bys(&mut order_by) .map(|_| window_frame) }) .transpose()? .ok_or_else(|| { - DataFusionError::Execution( - "missing window frame during deserialization".to_string(), - ) + exec_datafusion_err!("missing window frame during deserialization") })?; - // TODO: support proto for null treatment - regularize_window_order_by(&window_frame, &mut order_by)?; + // TODO: support proto for null treatment match window_function { window_expr_node::WindowFunction::AggrFunction(i) => { let aggr_function = parse_i32_to_aggregate_function(i)?; - Ok(Expr::WindowFunction(WindowFunction::new( + Expr::WindowFunction(WindowFunction::new( expr::WindowFunctionDefinition::AggregateFunction(aggr_function), vec![parse_required_expr( expr.expr.as_deref(), @@ -319,7 +313,7 @@ pub fn parse_expr( .order_by(order_by) .window_frame(window_frame) .build() - .unwrap()) + .map_err(Error::DataFusionError) } window_expr_node::WindowFunction::BuiltInFunction(i) => { let built_in_function = protobuf::BuiltInWindowFunction::try_from(*i) @@ -331,7 +325,7 @@ pub fn parse_expr( .map(|e| vec![e]) .unwrap_or_else(Vec::new); - Ok(Expr::WindowFunction(WindowFunction::new( + Expr::WindowFunction(WindowFunction::new( expr::WindowFunctionDefinition::BuiltInWindowFunction( built_in_function, ), @@ -341,7 +335,7 @@ pub fn parse_expr( .order_by(order_by) .window_frame(window_frame) .build() - .unwrap()) + .map_err(Error::DataFusionError) } window_expr_node::WindowFunction::Udaf(udaf_name) => { let udaf_function = match &expr.fun_definition { @@ -353,7 +347,7 @@ pub fn parse_expr( parse_optional_expr(expr.expr.as_deref(), registry, codec)? 
.map(|e| vec![e]) .unwrap_or_else(Vec::new); - Ok(Expr::WindowFunction(WindowFunction::new( + Expr::WindowFunction(WindowFunction::new( expr::WindowFunctionDefinition::AggregateUDF(udaf_function), args, )) @@ -361,7 +355,7 @@ pub fn parse_expr( .order_by(order_by) .window_frame(window_frame) .build() - .unwrap()) + .map_err(Error::DataFusionError) } window_expr_node::WindowFunction::Udwf(udwf_name) => { let udwf_function = match &expr.fun_definition { @@ -373,7 +367,7 @@ pub fn parse_expr( parse_optional_expr(expr.expr.as_deref(), registry, codec)? .map(|e| vec![e]) .unwrap_or_else(Vec::new); - Ok(Expr::WindowFunction(WindowFunction::new( + Expr::WindowFunction(WindowFunction::new( expr::WindowFunctionDefinition::WindowUDF(udwf_function), args, )) @@ -381,7 +375,7 @@ pub fn parse_expr( .order_by(order_by) .window_frame(window_frame) .build() - .unwrap()) + .map_err(Error::DataFusionError) } } } diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index fd759c161381..2506ef740fde 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -15,14 +15,16 @@ // specific language governing permissions and limitations // under the License. +use std::str::FromStr; + use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; + use arrow_schema::DataType; use datafusion_common::{ internal_datafusion_err, not_impl_err, plan_datafusion_err, plan_err, DFSchema, Dependency, Result, }; use datafusion_expr::planner::PlannerResult; -use datafusion_expr::window_frame::{check_window_frame, regularize_window_order_by}; use datafusion_expr::{ expr, AggregateFunction, Expr, ExprFunctionExt, ExprSchemable, WindowFrame, WindowFunctionDefinition, @@ -36,7 +38,7 @@ use sqlparser::ast::{ FunctionArgExpr, FunctionArgumentClause, FunctionArgumentList, FunctionArguments, NullTreatment, ObjectName, OrderByExpr, WindowType, }; -use std::str::FromStr; + use strum::IntoEnumIterator; /// Suggest a valid function based on an invalid input function name @@ -306,14 +308,14 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { .window_frame .as_ref() .map(|window_frame| { - let window_frame = window_frame.clone().try_into()?; - check_window_frame(&window_frame, order_by.len()) + let window_frame: WindowFrame = window_frame.clone().try_into()?; + window_frame + .regularize_order_bys(&mut order_by) .map(|_| window_frame) }) .transpose()?; let window_frame = if let Some(window_frame) = window_frame { - regularize_window_order_by(&window_frame, &mut order_by)?; window_frame } else if let Some(is_ordering_strict) = is_ordering_strict { WindowFrame::new(Some(is_ordering_strict)) @@ -322,7 +324,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { }; if let Ok(fun) = self.find_window_func(&name) { - let expr = match fun { + return match fun { WindowFunctionDefinition::AggregateFunction(aggregate_fun) => { let args = self.function_args_to_expr(args, schema, planner_context)?; @@ -336,7 +338,6 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { .window_frame(window_frame) .null_treatment(null_treatment) .build() - .unwrap() } _ => Expr::WindowFunction(expr::WindowFunction::new( fun, @@ -346,10 +347,8 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { .order_by(order_by) .window_frame(window_frame) .null_treatment(null_treatment) - .build() - .unwrap(), + .build(), }; - return Ok(expr); } } else { // User defined aggregate functions (UDAF) have precedence in case it has the same name as a scalar built-in function From 71903e1b2c62cda9a92808a71f8b63bcdd43762d Mon Sep 17 
00:00:00 2001 From: Andrew Lamb Date: Thu, 25 Jul 2024 14:04:33 -0400 Subject: [PATCH 17/17] Minor: use `ready!` macro to simplify FilterExec poll loop (#11649) --- datafusion/physical-plan/src/filter.rs | 35 +++++++++++--------------- 1 file changed, 15 insertions(+), 20 deletions(-) diff --git a/datafusion/physical-plan/src/filter.rs b/datafusion/physical-plan/src/filter.rs index a9d78d059f5c..67de0989649e 100644 --- a/datafusion/physical-plan/src/filter.rs +++ b/datafusion/physical-plan/src/filter.rs @@ -18,7 +18,7 @@ use std::any::Any; use std::pin::Pin; use std::sync::Arc; -use std::task::{Context, Poll}; +use std::task::{ready, Context, Poll}; use super::{ ColumnStatistics, DisplayAs, ExecutionPlanProperties, PlanProperties, @@ -59,6 +59,7 @@ pub struct FilterExec { metrics: ExecutionPlanMetricsSet, /// Selectivity for statistics. 0 = no rows, 100 = all rows default_selectivity: u8, + /// Properties equivalence properties, partitioning, etc. cache: PlanProperties, } @@ -375,26 +376,20 @@ impl Stream for FilterExecStream { ) -> Poll> { let poll; loop { - match self.input.poll_next_unpin(cx) { - Poll::Ready(value) => match value { - Some(Ok(batch)) => { - let timer = self.baseline_metrics.elapsed_compute().timer(); - let filtered_batch = batch_filter(&batch, &self.predicate)?; - // skip entirely filtered batches - if filtered_batch.num_rows() == 0 { - continue; - } - timer.done(); - poll = Poll::Ready(Some(Ok(filtered_batch))); - break; + match ready!(self.input.poll_next_unpin(cx)) { + Some(Ok(batch)) => { + let timer = self.baseline_metrics.elapsed_compute().timer(); + let filtered_batch = batch_filter(&batch, &self.predicate)?; + // skip entirely filtered batches + if filtered_batch.num_rows() == 0 { + continue; } - _ => { - poll = Poll::Ready(value); - break; - } - }, - Poll::Pending => { - poll = Poll::Pending; + timer.done(); + poll = Poll::Ready(Some(Ok(filtered_batch))); + break; + } + value => { + poll = Poll::Ready(value); break; } }
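`std::task::ready!` returns `Poll::Pending` to the caller immediately, which is what lets the rewrite above delete the nested `match` over `Poll` and most of the mutable `poll` bookkeeping. The same shape on a generic `futures` stream, as a sketch (`SkipEmpty` and the `Vec<u8>` item type are illustrative only; `FilterExec`'s real loop also evaluates the predicate and records compute-time metrics):

```rust
use std::pin::Pin;
use std::task::{ready, Context, Poll};

use futures::{Stream, StreamExt};

/// Pass items through, silently dropping empty ones.
struct SkipEmpty<S> {
    inner: S,
}

impl<S> Stream for SkipEmpty<S>
where
    S: Stream<Item = Vec<u8>> + Unpin,
{
    type Item = Vec<u8>;

    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        loop {
            // ready! bubbles Poll::Pending up at once; only Ready values
            // reach the match arms below.
            match ready!(self.inner.poll_next_unpin(cx)) {
                // Skip entirely empty items, like FilterExec skips
                // fully filtered batches.
                Some(item) if item.is_empty() => continue,
                other => return Poll::Ready(other),
            }
        }
    }
}
```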