diff --git a/datafusion/core/tests/fuzz_cases/window_fuzz.rs b/datafusion/core/tests/fuzz_cases/window_fuzz.rs index 5bd19850cacc..c97621ec4d01 100644 --- a/datafusion/core/tests/fuzz_cases/window_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/window_fuzz.rs @@ -17,7 +17,7 @@ use std::sync::Arc; -use arrow::array::{ArrayRef, Int32Array}; +use arrow::array::{ArrayRef, Int32Array, StringArray}; use arrow::compute::{concat_batches, SortOptions}; use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; @@ -45,6 +45,7 @@ use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr}; use test_utils::add_empty_batches; use hashbrown::HashMap; +use rand::distributions::Alphanumeric; use rand::rngs::StdRng; use rand::{Rng, SeedableRng}; @@ -607,25 +608,6 @@ fn convert_bound_to_current_row_if_applicable( } } -/// This utility determines whether a given window frame can be executed with -/// multiple ORDER BY expressions. As an example, range frames with offset (such -/// as `RANGE BETWEEN 1 PRECEDING AND 1 FOLLOWING`) cannot have ORDER BY clauses -/// of the form `\[ORDER BY a ASC, b ASC, ...]` -fn can_accept_multi_orderby(window_frame: &WindowFrame) -> bool { - match window_frame.units { - WindowFrameUnits::Rows => true, - WindowFrameUnits::Range => { - // Range can only accept multi ORDER BY clauses when bounds are - // CURRENT ROW or UNBOUNDED PRECEDING/FOLLOWING: - (window_frame.start_bound.is_unbounded() - || window_frame.start_bound == WindowFrameBound::CurrentRow) - && (window_frame.end_bound.is_unbounded() - || window_frame.end_bound == WindowFrameBound::CurrentRow) - } - WindowFrameUnits::Groups => true, - } -} - /// Perform batch and running window same input /// and verify outputs of `WindowAggExec` and `BoundedWindowAggExec` are equal async fn run_window_test( @@ -649,7 +631,7 @@ async fn run_window_test( options: SortOptions::default(), }) } - if orderby_exprs.len() > 1 && !can_accept_multi_orderby(&window_frame) { + if orderby_exprs.len() > 1 && 
!window_frame.can_accept_multi_orderby() { orderby_exprs = orderby_exprs[0..1].to_vec(); } let mut partitionby_exprs = vec![]; @@ -733,11 +715,30 @@ async fn run_window_test( )?) as _; let task_ctx = ctx.task_ctx(); let collected_usual = collect(usual_window_exec, task_ctx.clone()).await?; - let collected_running = collect(running_window_exec, task_ctx).await?; + let collected_running = collect(running_window_exec, task_ctx) + .await? + .into_iter() + .filter(|b| b.num_rows() > 0) + .collect::>(); // BoundedWindowAggExec should produce more chunk than the usual WindowAggExec. // Otherwise it means that we cannot generate result in running mode. - assert!(collected_running.len() > collected_usual.len()); + let err_msg = format!("Inconsistent result for window_frame: {window_frame:?}, window_fn: {window_fn:?}, args:{args:?}, random_seed: {random_seed:?}, search_mode: {search_mode:?}, partition_by_columns:{partition_by_columns:?}, orderby_columns: {orderby_columns:?}"); + // The check below makes sure that streaming execution generates more chunks than bulk execution, + // since algorithms and operators work on sliding windows in streaming execution. + // However, in the current test setup, for some randomly generated window frame clauses it is not guaranteed + // that streaming execution generates more chunks than its non-streaming counterpart in Linear mode. + // As an example, the window frame `OVER(PARTITION BY d ORDER BY a RANGE BETWEEN CURRENT ROW AND 9 FOLLOWING)` + // needs to receive a=10 to generate results for the rows where a=0. If the generated input data lies in the range [0, 9], + // then even in streaming mode the result will be generated as a single chunk, as in the non-streaming version. 
+ if search_mode != Linear { + assert!( + collected_running.len() > collected_usual.len(), + "{}", + err_msg + ); + } + // compare let usual_formatted = pretty_format_batches(&collected_usual)?.to_string(); let running_formatted = pretty_format_batches(&collected_running)?.to_string(); @@ -767,10 +768,17 @@ async fn run_window_test( Ok(()) } +fn generate_random_string(rng: &mut StdRng, length: usize) -> String { + rng.sample_iter(&Alphanumeric) + .take(length) + .map(char::from) + .collect() +} + /// Return randomly sized record batches with: /// three sorted int32 columns 'a', 'b', 'c' ranged from 0..DISTINCT as columns /// one random int32 column x -fn make_staggered_batches( +pub(crate) fn make_staggered_batches( len: usize, n_distinct: usize, random_seed: u64, @@ -779,6 +787,7 @@ fn make_staggered_batches( let mut rng = StdRng::seed_from_u64(random_seed); let mut input123: Vec<(i32, i32, i32)> = vec![(0, 0, 0); len]; let mut input4: Vec = vec![0; len]; + let mut input5: Vec = vec!["".to_string(); len]; input123.iter_mut().for_each(|v| { *v = ( rng.gen_range(0..n_distinct) as i32, @@ -788,10 +797,15 @@ fn make_staggered_batches( }); input123.sort(); rng.fill(&mut input4[..]); + input5.iter_mut().for_each(|v| { + *v = generate_random_string(&mut rng, 1); + }); + input5.sort(); let input1 = Int32Array::from_iter_values(input123.iter().map(|k| k.0)); let input2 = Int32Array::from_iter_values(input123.iter().map(|k| k.1)); let input3 = Int32Array::from_iter_values(input123.iter().map(|k| k.2)); let input4 = Int32Array::from_iter_values(input4); + let input5 = StringArray::from_iter_values(input5); // split into several record batches let mut remainder = RecordBatch::try_from_iter(vec![ @@ -799,6 +813,7 @@ fn make_staggered_batches( ("b", Arc::new(input2) as ArrayRef), ("c", Arc::new(input3) as ArrayRef), ("x", Arc::new(input4) as ArrayRef), + ("string_field", Arc::new(input5) as ArrayRef), ]) .unwrap(); @@ -807,6 +822,7 @@ fn make_staggered_batches( while 
remainder.num_rows() > 0 { let batch_size = rng.gen_range(0..50); if remainder.num_rows() < batch_size { + batches.push(remainder); break; } batches.push(remainder.slice(0, batch_size)); diff --git a/datafusion/expr/src/window_frame.rs b/datafusion/expr/src/window_frame.rs index c0617eaf4ed4..5b2f8982a559 100644 --- a/datafusion/expr/src/window_frame.rs +++ b/datafusion/expr/src/window_frame.rs @@ -26,8 +26,7 @@ use std::fmt::{self, Formatter}; use std::hash::Hash; -use crate::expr::Sort; -use crate::Expr; +use crate::{lit, Expr}; use datafusion_common::{plan_err, sql_err, DataFusionError, Result, ScalarValue}; use sqlparser::ast; @@ -246,59 +245,51 @@ impl WindowFrame { causal, } } -} -/// Regularizes ORDER BY clause for window definition for implicit corner cases. -pub fn regularize_window_order_by( - frame: &WindowFrame, - order_by: &mut Vec, -) -> Result<()> { - if frame.units == WindowFrameUnits::Range && order_by.len() != 1 { - // Normally, RANGE frames require an ORDER BY clause with exactly one - // column. However, an ORDER BY clause may be absent or present but with - // more than one column in two edge cases: - // 1. start bound is UNBOUNDED or CURRENT ROW - // 2. end bound is CURRENT ROW or UNBOUNDED. - // In these cases, we regularize the ORDER BY clause if the ORDER BY clause - // is absent. If an ORDER BY clause is present but has more than one column, - // the ORDER BY clause is unchanged. Note that this follows Postgres behavior. - if (frame.start_bound.is_unbounded() - || frame.start_bound == WindowFrameBound::CurrentRow) - && (frame.end_bound == WindowFrameBound::CurrentRow - || frame.end_bound.is_unbounded()) - { - // If an ORDER BY clause is absent, it is equivalent to a ORDER BY clause - // with constant value as sort key. - // If an ORDER BY clause is present but has more than one column, it is - // unchanged. 
- if order_by.is_empty() { - order_by.push(Expr::Sort(Sort::new( - Box::new(Expr::Literal(ScalarValue::UInt64(Some(1)))), - true, - false, - ))); + /// Regularizes the ORDER BY clause of the window frame. + pub fn regularize_order_bys(&self, order_by: &mut Vec) -> Result<()> { + match self.units { + // Normally, RANGE frames require an ORDER BY clause with exactly + // one column. However, an ORDER BY clause may be absent or have + // more than one column when the start/end bounds are UNBOUNDED or + // CURRENT ROW. + WindowFrameUnits::Range if self.free_range() => { + // If an ORDER BY clause is absent, it is equivalent to an + // ORDER BY clause with constant value as sort key. If an + // ORDER BY clause is present but has more than one column, + // it is unchanged. Note that this follows PostgreSQL behavior. + if order_by.is_empty() { + order_by.push(lit(1u64).sort(true, false)); + } + } + WindowFrameUnits::Range if order_by.len() != 1 => { + return plan_err!("RANGE requires exactly one ORDER BY column"); + } + WindowFrameUnits::Groups if order_by.is_empty() => { + return plan_err!("GROUPS requires an ORDER BY clause"); + } + _ => {} } + Ok(()) } - Ok(()) -} -/// Checks if given window frame is valid. In particular, if the frame is RANGE -/// with offset PRECEDING/FOLLOWING, it must have exactly one ORDER BY column. -pub fn check_window_frame(frame: &WindowFrame, order_bys: usize) -> Result<()> { - if frame.units == WindowFrameUnits::Range && order_bys != 1 { - // See `regularize_window_order_by`. - if !(frame.start_bound.is_unbounded() - || frame.start_bound == WindowFrameBound::CurrentRow) - || !(frame.end_bound == WindowFrameBound::CurrentRow - || frame.end_bound.is_unbounded()) - { - plan_err!("RANGE requires exactly one ORDER BY column")? + /// Returns whether the window frame can accept multiple ORDER BY expressions. 
+ pub fn can_accept_multi_orderby(&self) -> bool { + match self.units { + WindowFrameUnits::Rows => true, + WindowFrameUnits::Range => self.free_range(), + WindowFrameUnits::Groups => true, } - } else if frame.units == WindowFrameUnits::Groups && order_bys == 0 { - plan_err!("GROUPS requires an ORDER BY clause")? - }; - Ok(()) + } + + /// Returns whether the window frame is "free range"; i.e. its start/end + /// bounds are UNBOUNDED or CURRENT ROW. + fn free_range(&self) -> bool { + (self.start_bound.is_unbounded() + || self.start_bound == WindowFrameBound::CurrentRow) + && (self.end_bound.is_unbounded() + || self.end_bound == WindowFrameBound::CurrentRow) + } } /// There are five ways to describe starting and ending frame boundaries: diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 7b717add3311..5e9b9af49ae9 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -19,18 +19,14 @@ use std::sync::Arc; use datafusion::execution::registry::FunctionRegistry; use datafusion_common::{ - internal_err, plan_datafusion_err, DataFusionError, Result, ScalarValue, + exec_datafusion_err, internal_err, plan_datafusion_err, Result, ScalarValue, TableReference, UnnestOptions, }; -use datafusion_expr::expr::Unnest; -use datafusion_expr::expr::{Alias, Placeholder}; -use datafusion_expr::window_frame::{check_window_frame, regularize_window_order_by}; -use datafusion_expr::ExprFunctionExt; use datafusion_expr::{ - expr::{self, InList, Sort, WindowFunction}, + expr::{self, Alias, InList, Placeholder, Sort, Unnest, WindowFunction}, logical_plan::{PlanType, StringifiedPlan}, AggregateFunction, Between, BinaryExpr, BuiltInWindowFunction, Case, Cast, Expr, - GroupingSet, + ExprFunctionExt, GroupingSet, GroupingSet::GroupingSets, JoinConstraint, JoinType, Like, Operator, TryCast, WindowFrame, WindowFrameBound, WindowFrameUnits, @@ -289,24 +285,22 @@ pub fn 
parse_expr( .window_frame .as_ref() .map::, _>(|window_frame| { - let window_frame = window_frame.clone().try_into()?; - check_window_frame(&window_frame, order_by.len()) + let window_frame: WindowFrame = window_frame.clone().try_into()?; + window_frame + .regularize_order_bys(&mut order_by) .map(|_| window_frame) }) .transpose()? .ok_or_else(|| { - DataFusionError::Execution( - "missing window frame during deserialization".to_string(), - ) + exec_datafusion_err!("missing window frame during deserialization") })?; - // TODO: support proto for null treatment - regularize_window_order_by(&window_frame, &mut order_by)?; + // TODO: support proto for null treatment match window_function { window_expr_node::WindowFunction::AggrFunction(i) => { let aggr_function = parse_i32_to_aggregate_function(i)?; - Ok(Expr::WindowFunction(WindowFunction::new( + Expr::WindowFunction(WindowFunction::new( expr::WindowFunctionDefinition::AggregateFunction(aggr_function), vec![parse_required_expr( expr.expr.as_deref(), @@ -319,7 +313,7 @@ pub fn parse_expr( .order_by(order_by) .window_frame(window_frame) .build() - .unwrap()) + .map_err(Error::DataFusionError) } window_expr_node::WindowFunction::BuiltInFunction(i) => { let built_in_function = protobuf::BuiltInWindowFunction::try_from(*i) @@ -331,7 +325,7 @@ pub fn parse_expr( .map(|e| vec![e]) .unwrap_or_else(Vec::new); - Ok(Expr::WindowFunction(WindowFunction::new( + Expr::WindowFunction(WindowFunction::new( expr::WindowFunctionDefinition::BuiltInWindowFunction( built_in_function, ), @@ -341,7 +335,7 @@ pub fn parse_expr( .order_by(order_by) .window_frame(window_frame) .build() - .unwrap()) + .map_err(Error::DataFusionError) } window_expr_node::WindowFunction::Udaf(udaf_name) => { let udaf_function = match &expr.fun_definition { @@ -353,7 +347,7 @@ pub fn parse_expr( parse_optional_expr(expr.expr.as_deref(), registry, codec)? 
.map(|e| vec![e]) .unwrap_or_else(Vec::new); - Ok(Expr::WindowFunction(WindowFunction::new( + Expr::WindowFunction(WindowFunction::new( expr::WindowFunctionDefinition::AggregateUDF(udaf_function), args, )) @@ -361,7 +355,7 @@ pub fn parse_expr( .order_by(order_by) .window_frame(window_frame) .build() - .unwrap()) + .map_err(Error::DataFusionError) } window_expr_node::WindowFunction::Udwf(udwf_name) => { let udwf_function = match &expr.fun_definition { @@ -373,7 +367,7 @@ pub fn parse_expr( parse_optional_expr(expr.expr.as_deref(), registry, codec)? .map(|e| vec![e]) .unwrap_or_else(Vec::new); - Ok(Expr::WindowFunction(WindowFunction::new( + Expr::WindowFunction(WindowFunction::new( expr::WindowFunctionDefinition::WindowUDF(udwf_function), args, )) @@ -381,7 +375,7 @@ pub fn parse_expr( .order_by(order_by) .window_frame(window_frame) .build() - .unwrap()) + .map_err(Error::DataFusionError) } } } diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index fd759c161381..2506ef740fde 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -15,14 +15,16 @@ // specific language governing permissions and limitations // under the License. 
+use std::str::FromStr; + use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; + use arrow_schema::DataType; use datafusion_common::{ internal_datafusion_err, not_impl_err, plan_datafusion_err, plan_err, DFSchema, Dependency, Result, }; use datafusion_expr::planner::PlannerResult; -use datafusion_expr::window_frame::{check_window_frame, regularize_window_order_by}; use datafusion_expr::{ expr, AggregateFunction, Expr, ExprFunctionExt, ExprSchemable, WindowFrame, WindowFunctionDefinition, @@ -36,7 +38,7 @@ use sqlparser::ast::{ FunctionArgExpr, FunctionArgumentClause, FunctionArgumentList, FunctionArguments, NullTreatment, ObjectName, OrderByExpr, WindowType, }; -use std::str::FromStr; + use strum::IntoEnumIterator; /// Suggest a valid function based on an invalid input function name @@ -306,14 +308,14 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { .window_frame .as_ref() .map(|window_frame| { - let window_frame = window_frame.clone().try_into()?; - check_window_frame(&window_frame, order_by.len()) + let window_frame: WindowFrame = window_frame.clone().try_into()?; + window_frame + .regularize_order_bys(&mut order_by) .map(|_| window_frame) }) .transpose()?; let window_frame = if let Some(window_frame) = window_frame { - regularize_window_order_by(&window_frame, &mut order_by)?; window_frame } else if let Some(is_ordering_strict) = is_ordering_strict { WindowFrame::new(Some(is_ordering_strict)) @@ -322,7 +324,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { }; if let Ok(fun) = self.find_window_func(&name) { - let expr = match fun { + return match fun { WindowFunctionDefinition::AggregateFunction(aggregate_fun) => { let args = self.function_args_to_expr(args, schema, planner_context)?; @@ -336,7 +338,6 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { .window_frame(window_frame) .null_treatment(null_treatment) .build() - .unwrap() } _ => Expr::WindowFunction(expr::WindowFunction::new( fun, @@ -346,10 +347,8 @@ impl<'a, S: ContextProvider> 
SqlToRel<'a, S> { .order_by(order_by) .window_frame(window_frame) .null_treatment(null_treatment) - .build() - .unwrap(), + .build(), }; - return Ok(expr); } } else { // User defined aggregate functions (UDAF) have precedence in case it has the same name as a scalar built-in function