Skip to content

Commit

Permalink
Refactor/simplify window frame utils (apache#11648)
Browse files Browse the repository at this point in the history
* Simplify window frame utils

* Remove unwrap calls

* Fix format

* Incorporate review feedback
  • Loading branch information
ozankabak authored Jul 25, 2024
1 parent 49d9d45 commit 7db4213
Show file tree
Hide file tree
Showing 4 changed files with 105 additions and 105 deletions.
64 changes: 40 additions & 24 deletions datafusion/core/tests/fuzz_cases/window_fuzz.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

use std::sync::Arc;

use arrow::array::{ArrayRef, Int32Array};
use arrow::array::{ArrayRef, Int32Array, StringArray};
use arrow::compute::{concat_batches, SortOptions};
use arrow::datatypes::SchemaRef;
use arrow::record_batch::RecordBatch;
Expand Down Expand Up @@ -45,6 +45,7 @@ use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr};
use test_utils::add_empty_batches;

use hashbrown::HashMap;
use rand::distributions::Alphanumeric;
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};

Expand Down Expand Up @@ -607,25 +608,6 @@ fn convert_bound_to_current_row_if_applicable(
}
}

/// This utility determines whether a given window frame can be executed with
/// multiple ORDER BY expressions. As an example, range frames with offset (such
/// as `RANGE BETWEEN 1 PRECEDING AND 1 FOLLOWING`) cannot have ORDER BY clauses
/// of the form `\[ORDER BY a ASC, b ASC, ...]`
fn can_accept_multi_orderby(window_frame: &WindowFrame) -> bool {
match window_frame.units {
WindowFrameUnits::Rows => true,
WindowFrameUnits::Range => {
// Range can only accept multi ORDER BY clauses when bounds are
// CURRENT ROW or UNBOUNDED PRECEDING/FOLLOWING:
(window_frame.start_bound.is_unbounded()
|| window_frame.start_bound == WindowFrameBound::CurrentRow)
&& (window_frame.end_bound.is_unbounded()
|| window_frame.end_bound == WindowFrameBound::CurrentRow)
}
WindowFrameUnits::Groups => true,
}
}

/// Perform batch and running window same input
/// and verify outputs of `WindowAggExec` and `BoundedWindowAggExec` are equal
async fn run_window_test(
Expand All @@ -649,7 +631,7 @@ async fn run_window_test(
options: SortOptions::default(),
})
}
if orderby_exprs.len() > 1 && !can_accept_multi_orderby(&window_frame) {
if orderby_exprs.len() > 1 && !window_frame.can_accept_multi_orderby() {
orderby_exprs = orderby_exprs[0..1].to_vec();
}
let mut partitionby_exprs = vec![];
Expand Down Expand Up @@ -733,11 +715,30 @@ async fn run_window_test(
)?) as _;
let task_ctx = ctx.task_ctx();
let collected_usual = collect(usual_window_exec, task_ctx.clone()).await?;
let collected_running = collect(running_window_exec, task_ctx).await?;
let collected_running = collect(running_window_exec, task_ctx)
.await?
.into_iter()
.filter(|b| b.num_rows() > 0)
.collect::<Vec<_>>();

// BoundedWindowAggExec should produce more chunk than the usual WindowAggExec.
// Otherwise it means that we cannot generate result in running mode.
assert!(collected_running.len() > collected_usual.len());
let err_msg = format!("Inconsistent result for window_frame: {window_frame:?}, window_fn: {window_fn:?}, args:{args:?}, random_seed: {random_seed:?}, search_mode: {search_mode:?}, partition_by_columns:{partition_by_columns:?}, orderby_columns: {orderby_columns:?}");
// Below check makes sure that, streaming execution generates more chunks than the bulk execution.
// Since algorithms and operators works on sliding windows in the streaming execution.
// However, in the current test setup for some random generated window frame clauses: It is not guaranteed
// for streaming execution to generate more chunk than its non-streaming counter part in the Linear mode.
// As an example window frame `OVER(PARTITION BY d ORDER BY a RANGE BETWEEN CURRENT ROW AND 9 FOLLOWING)`
// needs to receive a=10 to generate result for the rows where a=0. If the input data generated is between the range [0, 9].
// even in streaming mode, generated result will be single bulk as in the non-streaming version.
if search_mode != Linear {
assert!(
collected_running.len() > collected_usual.len(),
"{}",
err_msg
);
}

// compare
let usual_formatted = pretty_format_batches(&collected_usual)?.to_string();
let running_formatted = pretty_format_batches(&collected_running)?.to_string();
Expand Down Expand Up @@ -767,10 +768,17 @@ async fn run_window_test(
Ok(())
}

fn generate_random_string(rng: &mut StdRng, length: usize) -> String {
rng.sample_iter(&Alphanumeric)
.take(length)
.map(char::from)
.collect()
}

/// Return randomly sized record batches with:
/// three sorted int32 columns 'a', 'b', 'c' ranged from 0..DISTINCT as columns
/// one random int32 column x
fn make_staggered_batches<const STREAM: bool>(
pub(crate) fn make_staggered_batches<const STREAM: bool>(
len: usize,
n_distinct: usize,
random_seed: u64,
Expand All @@ -779,6 +787,7 @@ fn make_staggered_batches<const STREAM: bool>(
let mut rng = StdRng::seed_from_u64(random_seed);
let mut input123: Vec<(i32, i32, i32)> = vec![(0, 0, 0); len];
let mut input4: Vec<i32> = vec![0; len];
let mut input5: Vec<String> = vec!["".to_string(); len];
input123.iter_mut().for_each(|v| {
*v = (
rng.gen_range(0..n_distinct) as i32,
Expand All @@ -788,17 +797,23 @@ fn make_staggered_batches<const STREAM: bool>(
});
input123.sort();
rng.fill(&mut input4[..]);
input5.iter_mut().for_each(|v| {
*v = generate_random_string(&mut rng, 1);
});
input5.sort();
let input1 = Int32Array::from_iter_values(input123.iter().map(|k| k.0));
let input2 = Int32Array::from_iter_values(input123.iter().map(|k| k.1));
let input3 = Int32Array::from_iter_values(input123.iter().map(|k| k.2));
let input4 = Int32Array::from_iter_values(input4);
let input5 = StringArray::from_iter_values(input5);

// split into several record batches
let mut remainder = RecordBatch::try_from_iter(vec![
("a", Arc::new(input1) as ArrayRef),
("b", Arc::new(input2) as ArrayRef),
("c", Arc::new(input3) as ArrayRef),
("x", Arc::new(input4) as ArrayRef),
("string_field", Arc::new(input5) as ArrayRef),
])
.unwrap();

Expand All @@ -807,6 +822,7 @@ fn make_staggered_batches<const STREAM: bool>(
while remainder.num_rows() > 0 {
let batch_size = rng.gen_range(0..50);
if remainder.num_rows() < batch_size {
batches.push(remainder);
break;
}
batches.push(remainder.slice(0, batch_size));
Expand Down
89 changes: 40 additions & 49 deletions datafusion/expr/src/window_frame.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,7 @@
use std::fmt::{self, Formatter};
use std::hash::Hash;

use crate::expr::Sort;
use crate::Expr;
use crate::{lit, Expr};

use datafusion_common::{plan_err, sql_err, DataFusionError, Result, ScalarValue};
use sqlparser::ast;
Expand Down Expand Up @@ -246,59 +245,51 @@ impl WindowFrame {
causal,
}
}
}

/// Regularizes ORDER BY clause for window definition for implicit corner cases.
pub fn regularize_window_order_by(
frame: &WindowFrame,
order_by: &mut Vec<Expr>,
) -> Result<()> {
if frame.units == WindowFrameUnits::Range && order_by.len() != 1 {
// Normally, RANGE frames require an ORDER BY clause with exactly one
// column. However, an ORDER BY clause may be absent or present but with
// more than one column in two edge cases:
// 1. start bound is UNBOUNDED or CURRENT ROW
// 2. end bound is CURRENT ROW or UNBOUNDED.
// In these cases, we regularize the ORDER BY clause if the ORDER BY clause
// is absent. If an ORDER BY clause is present but has more than one column,
// the ORDER BY clause is unchanged. Note that this follows Postgres behavior.
if (frame.start_bound.is_unbounded()
|| frame.start_bound == WindowFrameBound::CurrentRow)
&& (frame.end_bound == WindowFrameBound::CurrentRow
|| frame.end_bound.is_unbounded())
{
// If an ORDER BY clause is absent, it is equivalent to a ORDER BY clause
// with constant value as sort key.
// If an ORDER BY clause is present but has more than one column, it is
// unchanged.
if order_by.is_empty() {
order_by.push(Expr::Sort(Sort::new(
Box::new(Expr::Literal(ScalarValue::UInt64(Some(1)))),
true,
false,
)));
/// Regularizes the ORDER BY clause of the window frame.
pub fn regularize_order_bys(&self, order_by: &mut Vec<Expr>) -> Result<()> {
match self.units {
// Normally, RANGE frames require an ORDER BY clause with exactly
// one column. However, an ORDER BY clause may be absent or have
// more than one column when the start/end bounds are UNBOUNDED or
// CURRENT ROW.
WindowFrameUnits::Range if self.free_range() => {
// If an ORDER BY clause is absent, it is equivalent to an
// ORDER BY clause with constant value as sort key. If an
// ORDER BY clause is present but has more than one column,
// it is unchanged. Note that this follows PostgreSQL behavior.
if order_by.is_empty() {
order_by.push(lit(1u64).sort(true, false));
}
}
WindowFrameUnits::Range if order_by.len() != 1 => {
return plan_err!("RANGE requires exactly one ORDER BY column");
}
WindowFrameUnits::Groups if order_by.is_empty() => {
return plan_err!("GROUPS requires an ORDER BY clause");
}
_ => {}
}
Ok(())
}
Ok(())
}

/// Checks if given window frame is valid. In particular, if the frame is RANGE
/// with offset PRECEDING/FOLLOWING, it must have exactly one ORDER BY column.
pub fn check_window_frame(frame: &WindowFrame, order_bys: usize) -> Result<()> {
if frame.units == WindowFrameUnits::Range && order_bys != 1 {
// See `regularize_window_order_by`.
if !(frame.start_bound.is_unbounded()
|| frame.start_bound == WindowFrameBound::CurrentRow)
|| !(frame.end_bound == WindowFrameBound::CurrentRow
|| frame.end_bound.is_unbounded())
{
plan_err!("RANGE requires exactly one ORDER BY column")?
/// Returns whether the window frame can accept multiple ORDER BY expressons.
pub fn can_accept_multi_orderby(&self) -> bool {
match self.units {
WindowFrameUnits::Rows => true,
WindowFrameUnits::Range => self.free_range(),
WindowFrameUnits::Groups => true,
}
} else if frame.units == WindowFrameUnits::Groups && order_bys == 0 {
plan_err!("GROUPS requires an ORDER BY clause")?
};
Ok(())
}

/// Returns whether the window frame is "free range"; i.e. its start/end
/// bounds are UNBOUNDED or CURRENT ROW.
fn free_range(&self) -> bool {
(self.start_bound.is_unbounded()
|| self.start_bound == WindowFrameBound::CurrentRow)
&& (self.end_bound.is_unbounded()
|| self.end_bound == WindowFrameBound::CurrentRow)
}
}

/// There are five ways to describe starting and ending frame boundaries:
Expand Down
38 changes: 16 additions & 22 deletions datafusion/proto/src/logical_plan/from_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,14 @@ use std::sync::Arc;

use datafusion::execution::registry::FunctionRegistry;
use datafusion_common::{
internal_err, plan_datafusion_err, DataFusionError, Result, ScalarValue,
exec_datafusion_err, internal_err, plan_datafusion_err, Result, ScalarValue,
TableReference, UnnestOptions,
};
use datafusion_expr::expr::Unnest;
use datafusion_expr::expr::{Alias, Placeholder};
use datafusion_expr::window_frame::{check_window_frame, regularize_window_order_by};
use datafusion_expr::ExprFunctionExt;
use datafusion_expr::{
expr::{self, InList, Sort, WindowFunction},
expr::{self, Alias, InList, Placeholder, Sort, Unnest, WindowFunction},
logical_plan::{PlanType, StringifiedPlan},
AggregateFunction, Between, BinaryExpr, BuiltInWindowFunction, Case, Cast, Expr,
GroupingSet,
ExprFunctionExt, GroupingSet,
GroupingSet::GroupingSets,
JoinConstraint, JoinType, Like, Operator, TryCast, WindowFrame, WindowFrameBound,
WindowFrameUnits,
Expand Down Expand Up @@ -289,24 +285,22 @@ pub fn parse_expr(
.window_frame
.as_ref()
.map::<Result<WindowFrame, _>, _>(|window_frame| {
let window_frame = window_frame.clone().try_into()?;
check_window_frame(&window_frame, order_by.len())
let window_frame: WindowFrame = window_frame.clone().try_into()?;
window_frame
.regularize_order_bys(&mut order_by)
.map(|_| window_frame)
})
.transpose()?
.ok_or_else(|| {
DataFusionError::Execution(
"missing window frame during deserialization".to_string(),
)
exec_datafusion_err!("missing window frame during deserialization")
})?;
// TODO: support proto for null treatment
regularize_window_order_by(&window_frame, &mut order_by)?;

// TODO: support proto for null treatment
match window_function {
window_expr_node::WindowFunction::AggrFunction(i) => {
let aggr_function = parse_i32_to_aggregate_function(i)?;

Ok(Expr::WindowFunction(WindowFunction::new(
Expr::WindowFunction(WindowFunction::new(
expr::WindowFunctionDefinition::AggregateFunction(aggr_function),
vec![parse_required_expr(
expr.expr.as_deref(),
Expand All @@ -319,7 +313,7 @@ pub fn parse_expr(
.order_by(order_by)
.window_frame(window_frame)
.build()
.unwrap())
.map_err(Error::DataFusionError)
}
window_expr_node::WindowFunction::BuiltInFunction(i) => {
let built_in_function = protobuf::BuiltInWindowFunction::try_from(*i)
Expand All @@ -331,7 +325,7 @@ pub fn parse_expr(
.map(|e| vec![e])
.unwrap_or_else(Vec::new);

Ok(Expr::WindowFunction(WindowFunction::new(
Expr::WindowFunction(WindowFunction::new(
expr::WindowFunctionDefinition::BuiltInWindowFunction(
built_in_function,
),
Expand All @@ -341,7 +335,7 @@ pub fn parse_expr(
.order_by(order_by)
.window_frame(window_frame)
.build()
.unwrap())
.map_err(Error::DataFusionError)
}
window_expr_node::WindowFunction::Udaf(udaf_name) => {
let udaf_function = match &expr.fun_definition {
Expand All @@ -353,15 +347,15 @@ pub fn parse_expr(
parse_optional_expr(expr.expr.as_deref(), registry, codec)?
.map(|e| vec![e])
.unwrap_or_else(Vec::new);
Ok(Expr::WindowFunction(WindowFunction::new(
Expr::WindowFunction(WindowFunction::new(
expr::WindowFunctionDefinition::AggregateUDF(udaf_function),
args,
))
.partition_by(partition_by)
.order_by(order_by)
.window_frame(window_frame)
.build()
.unwrap())
.map_err(Error::DataFusionError)
}
window_expr_node::WindowFunction::Udwf(udwf_name) => {
let udwf_function = match &expr.fun_definition {
Expand All @@ -373,15 +367,15 @@ pub fn parse_expr(
parse_optional_expr(expr.expr.as_deref(), registry, codec)?
.map(|e| vec![e])
.unwrap_or_else(Vec::new);
Ok(Expr::WindowFunction(WindowFunction::new(
Expr::WindowFunction(WindowFunction::new(
expr::WindowFunctionDefinition::WindowUDF(udwf_function),
args,
))
.partition_by(partition_by)
.order_by(order_by)
.window_frame(window_frame)
.build()
.unwrap())
.map_err(Error::DataFusionError)
}
}
}
Expand Down
Loading

0 comments on commit 7db4213

Please sign in to comment.