From 1c7209b280ba0d8a7faa6a4a63ff5bc52a6fd9bc Mon Sep 17 00:00:00 2001 From: Georgi Krastev Date: Wed, 21 Aug 2024 18:19:03 +0300 Subject: [PATCH 01/10] Use `schema_name` to create the `physical_name` (#11977) More consistency and less opportunity for column name mismatch. --- datafusion/core/src/physical_planner.rs | 13 +- datafusion/expr/src/expr.rs | 272 +----------------- .../src/aggregate.rs | 4 +- .../physical-plan/src/aggregates/mod.rs | 1 + 4 files changed, 17 insertions(+), 273 deletions(-) diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 6536f9a01439f..8d6c5089fa34d 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -73,8 +73,7 @@ use datafusion_common::{ }; use datafusion_expr::dml::CopyTo; use datafusion_expr::expr::{ - self, create_function_physical_name, physical_name, AggregateFunction, Alias, - GroupingSet, WindowFunction, + self, physical_name, AggregateFunction, Alias, GroupingSet, WindowFunction, }; use datafusion_expr::expr_rewriter::unnormalize_cols; use datafusion_expr::logical_plan::builder::wrap_projection_for_join_if_necessary; @@ -1569,12 +1568,7 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter( let name = if let Some(name) = name { name } else { - create_function_physical_name( - func.name(), - *distinct, - args, - order_by.as_ref(), - )? + physical_name(e)? }; let physical_args = @@ -1588,8 +1582,7 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter( None => None, }; - let ignore_nulls = null_treatment - .unwrap_or(sqlparser::ast::NullTreatment::RespectNulls) + let ignore_nulls = null_treatment.unwrap_or(NullTreatment::RespectNulls) == NullTreatment::IgnoreNulls; let (agg_expr, filter, order_by) = { diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 88939ccf41b8c..85ba80396c8e8 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -38,8 +38,7 @@ use datafusion_common::tree_node::{ Transformed, TransformedResult, TreeNode, TreeNodeRecursion, }; use datafusion_common::{ - internal_err, not_impl_err, plan_err, Column, DFSchema, Result, ScalarValue, - TableReference, + plan_err, Column, DFSchema, Result, ScalarValue, TableReference, }; use sqlparser::ast::{ display_comma_separated, ExceptSelectItem, ExcludeSelectItem, IlikeSelectItem, @@ -1082,7 +1081,7 @@ impl Expr { /// For example, for a projection (e.g. `SELECT `) the resulting arrow /// [`Schema`] will have a field with this name. /// - /// Note that the resulting string is subtlety different than the `Display` + /// Note that the resulting string is subtlety different from the `Display` /// representation for certain `Expr`. Some differences: /// /// 1. [`Expr::Alias`], which shows only the alias itself @@ -1104,6 +1103,7 @@ impl Expr { } /// Returns a full and complete string representation of this expression. + #[deprecated(note = "use format! instead")] pub fn canonical_name(&self) -> String { format!("{self}") } @@ -2386,263 +2386,13 @@ fn fmt_function( write!(f, "{}({}{})", fun, distinct_str, args.join(", ")) } -pub fn create_function_physical_name( - fun: &str, - distinct: bool, - args: &[Expr], - order_by: Option<&Vec>, -) -> Result { - let names: Vec = args - .iter() - .map(|e| create_physical_name(e, false)) - .collect::>()?; - - let distinct_str = match distinct { - true => "DISTINCT ", - false => "", - }; - - let phys_name = format!("{}({}{})", fun, distinct_str, names.join(",")); - - Ok(order_by - .map(|order_by| format!("{} ORDER BY [{}]", phys_name, expr_vec_fmt!(order_by))) - .unwrap_or(phys_name)) -} - -pub fn physical_name(e: &Expr) -> Result { - create_physical_name(e, true) -} - -fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result { - match e { - Expr::Unnest(_) => { - internal_err!( - "Expr::Unnest should have been converted to LogicalPlan::Unnest" - ) - } - Expr::Column(c) => { - if is_first_expr { - Ok(c.name.clone()) - } else { - Ok(c.flat_name()) - } - } - Expr::Alias(Alias { name, .. }) => Ok(name.clone()), - Expr::ScalarVariable(_, variable_names) => Ok(variable_names.join(".")), - Expr::Literal(value) => Ok(format!("{value:?}")), - Expr::BinaryExpr(BinaryExpr { left, op, right }) => { - let left = create_physical_name(left, false)?; - let right = create_physical_name(right, false)?; - Ok(format!("{left} {op} {right}")) - } - Expr::Case(case) => { - let mut name = "CASE ".to_string(); - if let Some(e) = &case.expr { - let _ = write!(name, "{} ", create_physical_name(e, false)?); - } - for (w, t) in &case.when_then_expr { - let _ = write!( - name, - "WHEN {} THEN {} ", - create_physical_name(w, false)?, - create_physical_name(t, false)? - ); - } - if let Some(e) = &case.else_expr { - let _ = write!(name, "ELSE {} ", create_physical_name(e, false)?); - } - name += "END"; - Ok(name) - } - Expr::Cast(Cast { expr, .. }) => { - // CAST does not change the expression name - create_physical_name(expr, false) - } - Expr::TryCast(TryCast { expr, .. }) => { - // CAST does not change the expression name - create_physical_name(expr, false) - } - Expr::Not(expr) => { - let expr = create_physical_name(expr, false)?; - Ok(format!("NOT {expr}")) - } - Expr::Negative(expr) => { - let expr = create_physical_name(expr, false)?; - Ok(format!("(- {expr})")) - } - Expr::IsNull(expr) => { - let expr = create_physical_name(expr, false)?; - Ok(format!("{expr} IS NULL")) - } - Expr::IsNotNull(expr) => { - let expr = create_physical_name(expr, false)?; - Ok(format!("{expr} IS NOT NULL")) - } - Expr::IsTrue(expr) => { - let expr = create_physical_name(expr, false)?; - Ok(format!("{expr} IS TRUE")) - } - Expr::IsFalse(expr) => { - let expr = create_physical_name(expr, false)?; - Ok(format!("{expr} IS FALSE")) - } - Expr::IsUnknown(expr) => { - let expr = create_physical_name(expr, false)?; - Ok(format!("{expr} IS UNKNOWN")) - } - Expr::IsNotTrue(expr) => { - let expr = create_physical_name(expr, false)?; - Ok(format!("{expr} IS NOT TRUE")) - } - Expr::IsNotFalse(expr) => { - let expr = create_physical_name(expr, false)?; - Ok(format!("{expr} IS NOT FALSE")) - } - Expr::IsNotUnknown(expr) => { - let expr = create_physical_name(expr, false)?; - Ok(format!("{expr} IS NOT UNKNOWN")) - } - Expr::ScalarFunction(fun) => fun.func.schema_name(&fun.args), - Expr::WindowFunction(WindowFunction { - fun, - args, - order_by, - .. - }) => { - create_function_physical_name(&fun.to_string(), false, args, Some(order_by)) - } - Expr::AggregateFunction(AggregateFunction { - func, - distinct, - args, - filter: _, - order_by, - null_treatment: _, - }) => { - create_function_physical_name(func.name(), *distinct, args, order_by.as_ref()) - } - Expr::GroupingSet(grouping_set) => match grouping_set { - GroupingSet::Rollup(exprs) => Ok(format!( - "ROLLUP ({})", - exprs - .iter() - .map(|e| create_physical_name(e, false)) - .collect::>>()? - .join(", ") - )), - GroupingSet::Cube(exprs) => Ok(format!( - "CUBE ({})", - exprs - .iter() - .map(|e| create_physical_name(e, false)) - .collect::>>()? - .join(", ") - )), - GroupingSet::GroupingSets(lists_of_exprs) => { - let mut strings = vec![]; - for exprs in lists_of_exprs { - let exprs_str = exprs - .iter() - .map(|e| create_physical_name(e, false)) - .collect::>>()? - .join(", "); - strings.push(format!("({exprs_str})")); - } - Ok(format!("GROUPING SETS ({})", strings.join(", "))) - } - }, - - Expr::InList(InList { - expr, - list, - negated, - }) => { - let expr = create_physical_name(expr, false)?; - let list = list.iter().map(|expr| create_physical_name(expr, false)); - if *negated { - Ok(format!("{expr} NOT IN ({list:?})")) - } else { - Ok(format!("{expr} IN ({list:?})")) - } - } - Expr::Exists { .. } => { - not_impl_err!("EXISTS is not yet supported in the physical plan") - } - Expr::InSubquery(_) => { - not_impl_err!("IN subquery is not yet supported in the physical plan") - } - Expr::ScalarSubquery(_) => { - not_impl_err!("Scalar subqueries are not yet supported in the physical plan") - } - Expr::Between(Between { - expr, - negated, - low, - high, - }) => { - let expr = create_physical_name(expr, false)?; - let low = create_physical_name(low, false)?; - let high = create_physical_name(high, false)?; - if *negated { - Ok(format!("{expr} NOT BETWEEN {low} AND {high}")) - } else { - Ok(format!("{expr} BETWEEN {low} AND {high}")) - } - } - Expr::Like(Like { - negated, - expr, - pattern, - escape_char, - case_insensitive, - }) => { - let expr = create_physical_name(expr, false)?; - let pattern = create_physical_name(pattern, false)?; - let op_name = if *case_insensitive { "ILIKE" } else { "LIKE" }; - let escape = if let Some(char) = escape_char { - format!("CHAR '{char}'") - } else { - "".to_string() - }; - if *negated { - Ok(format!("{expr} NOT {op_name} {pattern}{escape}")) - } else { - Ok(format!("{expr} {op_name} {pattern}{escape}")) - } - } - Expr::SimilarTo(Like { - negated, - expr, - pattern, - escape_char, - case_insensitive: _, - }) => { - let expr = create_physical_name(expr, false)?; - let pattern = create_physical_name(pattern, false)?; - let escape = if let Some(char) = escape_char { - format!("CHAR '{char}'") - } else { - "".to_string() - }; - if *negated { - Ok(format!("{expr} NOT SIMILAR TO {pattern}{escape}")) - } else { - Ok(format!("{expr} SIMILAR TO {pattern}{escape}")) - } - } - Expr::Sort { .. } => { - internal_err!("Create physical name does not support sort expression") - } - Expr::Wildcard { qualifier, options } => match qualifier { - Some(qualifier) => Ok(format!("{}.*{}", qualifier, options)), - None => Ok(format!("*{}", options)), - }, - Expr::Placeholder(_) => { - internal_err!("Create physical name does not support placeholder") - } - Expr::OuterReferenceColumn(_, _) => { - internal_err!("Create physical name does not support OuterReferenceColumn") - } +/// The name of the column (field) that this `Expr` will produce in the physical plan. +/// The difference from [Expr::schema_name] is that top-level columns are unqualified. +pub fn physical_name(expr: &Expr) -> Result { + if let Expr::Column(col) = expr { + Ok(col.name.clone()) + } else { + Ok(expr.schema_name().to_string()) } } @@ -2658,6 +2408,7 @@ mod test { use std::any::Any; #[test] + #[allow(deprecated)] fn format_case_when() -> Result<()> { let expr = case(col("a")) .when(lit(1), lit(true)) @@ -2670,6 +2421,7 @@ mod test { } #[test] + #[allow(deprecated)] fn format_cast() -> Result<()> { let expr = Expr::Cast(Cast { expr: Box::new(Expr::Literal(ScalarValue::Float32(Some(1.23)))), diff --git a/datafusion/physical-expr-functions-aggregate/src/aggregate.rs b/datafusion/physical-expr-functions-aggregate/src/aggregate.rs index aa1d1999a3395..fd986e00a7ef3 100644 --- a/datafusion/physical-expr-functions-aggregate/src/aggregate.rs +++ b/datafusion/physical-expr-functions-aggregate/src/aggregate.rs @@ -18,7 +18,6 @@ use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion_common::ScalarValue; use datafusion_common::{internal_err, not_impl_err, Result}; -use datafusion_expr::expr::create_function_physical_name; use datafusion_expr::AggregateUDF; use datafusion_expr::ReversedUDAF; use datafusion_expr_common::accumulator::Accumulator; @@ -112,8 +111,7 @@ impl AggregateExprBuilder { let data_type = fun.return_type(&input_exprs_types)?; let is_nullable = fun.is_nullable(); let name = match alias { - // TODO: Ideally, we should build the name from physical expressions - None => create_function_physical_name(fun.name(), is_distinct, &[], None)?, + None => return internal_err!("alias should be provided"), Some(alias) => alias, }; diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs index 89d4c452cca65..5aa255e7c341a 100644 --- a/datafusion/physical-plan/src/aggregates/mod.rs +++ b/datafusion/physical-plan/src/aggregates/mod.rs @@ -2179,6 +2179,7 @@ mod tests { .map(|order_by_expr| { let ordering_req = order_by_expr.unwrap_or_default(); AggregateExprBuilder::new(array_agg_udaf(), vec![Arc::clone(col_a)]) + .alias("a") .order_by(ordering_req.to_vec()) .schema(Arc::clone(&test_schema)) .build() From 78f58c80476ef2d2b10f4551230db4a610a9a32d Mon Sep 17 00:00:00 2001 From: JC <1950050+jc4x4@users.noreply.github.com> Date: Wed, 21 Aug 2024 23:21:52 +0800 Subject: [PATCH 02/10] Add new user doc to translate logical plan to physical plan (#12026) * Add new user doc to translate logical plan to physical plan https://github.com/apache/datafusion/issues/7306 * prettier * Run doc examples as part of cargo --doc * Update first example to run * Fix next example * fix last example * prettier * clarify table source * prettier * Revert changes --------- Co-authored-by: Andrew Lamb --- datafusion/core/src/lib.rs | 6 + datafusion/expr/src/logical_plan/mod.rs | 2 +- .../building-logical-plans.md | 187 ++++++++++++------ 3 files changed, 139 insertions(+), 56 deletions(-) diff --git a/datafusion/core/src/lib.rs b/datafusion/core/src/lib.rs index daeb21db9d05c..735a381586ad1 100644 --- a/datafusion/core/src/lib.rs +++ b/datafusion/core/src/lib.rs @@ -678,6 +678,12 @@ doc_comment::doctest!( library_user_guide_sql_api ); +#[cfg(doctest)] +doc_comment::doctest!( + "../../../docs/source/library-user-guide/building-logical-plans.md", + library_user_guide_logical_plans +); + #[cfg(doctest)] doc_comment::doctest!( "../../../docs/source/library-user-guide/using-the-dataframe-api.md", diff --git a/datafusion/expr/src/logical_plan/mod.rs b/datafusion/expr/src/logical_plan/mod.rs index b58208591920b..5b5a842fa4cf8 100644 --- a/datafusion/expr/src/logical_plan/mod.rs +++ b/datafusion/expr/src/logical_plan/mod.rs @@ -26,7 +26,7 @@ pub mod tree_node; pub use builder::{ build_join_schema, table_scan, union, wrap_projection_for_join_if_necessary, - LogicalPlanBuilder, UNNAMED_TABLE, + LogicalPlanBuilder, LogicalTableSource, UNNAMED_TABLE, }; pub use ddl::{ CreateCatalog, CreateCatalogSchema, CreateExternalTable, CreateFunction, diff --git a/docs/source/library-user-guide/building-logical-plans.md b/docs/source/library-user-guide/building-logical-plans.md index fe922d8eaeb11..556deb02e9800 100644 --- a/docs/source/library-user-guide/building-logical-plans.md +++ b/docs/source/library-user-guide/building-logical-plans.md @@ -31,44 +31,52 @@ explained in more detail in the [Query Planning and Execution Overview] section DataFusion's [LogicalPlan] is an enum containing variants representing all the supported operators, and also contains an `Extension` variant that allows projects building on DataFusion to add custom logical operators. -It is possible to create logical plans by directly creating instances of the [LogicalPlan] enum as follows, but is is +It is possible to create logical plans by directly creating instances of the [LogicalPlan] enum as shown, but it is much easier to use the [LogicalPlanBuilder], which is described in the next section. Here is an example of building a logical plan directly: - - ```rust -// create a logical table source -let schema = Schema::new(vec![ - Field::new("id", DataType::Int32, true), - Field::new("name", DataType::Utf8, true), -]); -let table_source = LogicalTableSource::new(SchemaRef::new(schema)); - -// create a TableScan plan -let projection = None; // optional projection -let filters = vec![]; // optional filters to push down -let fetch = None; // optional LIMIT -let table_scan = LogicalPlan::TableScan(TableScan::try_new( - "person", - Arc::new(table_source), - projection, - filters, - fetch, -)?); - -// create a Filter plan that evaluates `id > 500` that wraps the TableScan -let filter_expr = col("id").gt(lit(500)); -let plan = LogicalPlan::Filter(Filter::try_new(filter_expr, Arc::new(table_scan))?); - -// print the plan -println!("{}", plan.display_indent_schema()); +use datafusion::common::DataFusionError; +use datafusion::arrow::datatypes::{DataType, Field, Schema, SchemaRef}; +use datafusion::logical_expr::{Filter, LogicalPlan, TableScan, LogicalTableSource}; +use datafusion::prelude::*; +use std::sync::Arc; + +fn main() -> Result<(), DataFusionError> { + // create a logical table source + let schema = Schema::new(vec![ + Field::new("id", DataType::Int32, true), + Field::new("name", DataType::Utf8, true), + ]); + let table_source = LogicalTableSource::new(SchemaRef::new(schema)); + + // create a TableScan plan + let projection = None; // optional projection + let filters = vec![]; // optional filters to push down + let fetch = None; // optional LIMIT + let table_scan = LogicalPlan::TableScan(TableScan::try_new( + "person", + Arc::new(table_source), + projection, + filters, + fetch, + )? + ); + + // create a Filter plan that evaluates `id > 500` that wraps the TableScan + let filter_expr = col("id").gt(lit(500)); + let plan = LogicalPlan::Filter(Filter::try_new(filter_expr, Arc::new(table_scan)) ? ); + + // print the plan + println!("{}", plan.display_indent_schema()); + Ok(()) +} ``` This example produces the following plan: -``` +```text Filter: person.id > Int32(500) [id:Int32;N, name:Utf8;N] TableScan: person [id:Int32;N, name:Utf8;N] ``` @@ -78,7 +86,7 @@ Filter: person.id > Int32(500) [id:Int32;N, name:Utf8;N] DataFusion logical plans can be created using the [LogicalPlanBuilder] struct. There is also a [DataFrame] API which is a higher-level API that delegates to [LogicalPlanBuilder]. -The following associated functions can be used to create a new builder: +There are several functions that can can be used to create a new builder, such as - `empty` - create an empty plan with no fields - `values` - create a plan from a set of literal values @@ -102,41 +110,107 @@ The following example demonstrates building the same simple query plan as the pr ```rust -// create a logical table source -let schema = Schema::new(vec![ - Field::new("id", DataType::Int32, true), - Field::new("name", DataType::Utf8, true), -]); -let table_source = LogicalTableSource::new(SchemaRef::new(schema)); - -// optional projection -let projection = None; - -// create a LogicalPlanBuilder for a table scan -let builder = LogicalPlanBuilder::scan("person", Arc::new(table_source), projection)?; - -// perform a filter operation and build the plan -let plan = builder - .filter(col("id").gt(lit(500)))? // WHERE id > 500 - .build()?; - -// print the plan -println!("{}", plan.display_indent_schema()); +use datafusion::common::DataFusionError; +use datafusion::arrow::datatypes::{DataType, Field, Schema, SchemaRef}; +use datafusion::logical_expr::{LogicalPlanBuilder, LogicalTableSource}; +use datafusion::prelude::*; +use std::sync::Arc; + +fn main() -> Result<(), DataFusionError> { + // create a logical table source + let schema = Schema::new(vec![ + Field::new("id", DataType::Int32, true), + Field::new("name", DataType::Utf8, true), + ]); + let table_source = LogicalTableSource::new(SchemaRef::new(schema)); + + // optional projection + let projection = None; + + // create a LogicalPlanBuilder for a table scan + let builder = LogicalPlanBuilder::scan("person", Arc::new(table_source), projection)?; + + // perform a filter operation and build the plan + let plan = builder + .filter(col("id").gt(lit(500)))? // WHERE id > 500 + .build()?; + + // print the plan + println!("{}", plan.display_indent_schema()); + Ok(()) +} ``` This example produces the following plan: -``` +```text Filter: person.id > Int32(500) [id:Int32;N, name:Utf8;N] TableScan: person [id:Int32;N, name:Utf8;N] ``` +## Translating Logical Plan to Physical Plan + +Logical plans can not be directly executed. They must be "compiled" into an +[`ExecutionPlan`], which is often referred to as a "physical plan". + +Compared to `LogicalPlan`s `ExecutionPlans` have many more details such as +specific algorithms and detailed optimizations compared to. Given a +`LogicalPlan` the easiest way to create an `ExecutionPlan` is using +[`SessionState::create_physical_plan`] as shown below + +```rust +use datafusion::datasource::{provider_as_source, MemTable}; +use datafusion::common::DataFusionError; +use datafusion::physical_plan::display::DisplayableExecutionPlan; +use datafusion::arrow::datatypes::{DataType, Field, Schema, SchemaRef}; +use datafusion::logical_expr::{LogicalPlanBuilder, LogicalTableSource}; +use datafusion::prelude::*; +use std::sync::Arc; + +// Creating physical plans may access remote catalogs and data sources +// thus it must be run with an async runtime. +#[tokio::main] +async fn main() -> Result<(), DataFusionError> { + + // create a default table source + let schema = Schema::new(vec![ + Field::new("id", DataType::Int32, true), + Field::new("name", DataType::Utf8, true), + ]); + // To create an ExecutionPlan we must provide an actual + // TableProvider. For this example, we don't provide any data + // but in production code, this would have `RecordBatch`es with + // in memory data + let table_provider = Arc::new(MemTable::try_new(Arc::new(schema), vec![])?); + // Use the provider_as_source function to convert the TableProvider to a table source + let table_source = provider_as_source(table_provider); + + // create a LogicalPlanBuilder for a table scan without projection or filters + let logical_plan = LogicalPlanBuilder::scan("person", table_source, None)?.build()?; + + // Now create the physical plan by calling `create_physical_plan` + let ctx = SessionContext::new(); + let physical_plan = ctx.state().create_physical_plan(&logical_plan).await?; + + // print the plan + println!("{}", DisplayableExecutionPlan::new(physical_plan.as_ref()).indent(true)); + Ok(()) +} +``` + +This example produces the following physical plan: + +```text +MemoryExec: partitions=0, partition_sizes=[] +``` + ## Table Sources -The previous example used a [LogicalTableSource], which is used for tests and documentation in DataFusion, and is also -suitable if you are using DataFusion to build logical plans but do not use DataFusion's physical planner. However, if you -want to use a [TableSource] that can be executed in DataFusion then you will need to use [DefaultTableSource], which is a -wrapper for a [TableProvider]. +The previous examples use a [LogicalTableSource], which is used for tests and documentation in DataFusion, and is also +suitable if you are using DataFusion to build logical plans but do not use DataFusion's physical planner. + +However, it is more common to use a [TableProvider]. To get a [TableSource] from a +[TableProvider], use [provider_as_source] or [DefaultTableSource]. [query planning and execution overview]: https://docs.rs/datafusion/latest/datafusion/index.html#query-planning-and-execution-overview [architecture guide]: https://docs.rs/datafusion/latest/datafusion/index.html#architecture @@ -145,5 +219,8 @@ wrapper for a [TableProvider]. [dataframe]: using-the-dataframe-api.md [logicaltablesource]: https://docs.rs/datafusion-expr/latest/datafusion_expr/logical_plan/builder/struct.LogicalTableSource.html [defaulttablesource]: https://docs.rs/datafusion/latest/datafusion/datasource/default_table_source/struct.DefaultTableSource.html +[provider_as_source]: https://docs.rs/datafusion/latest/datafusion/datasource/default_table_source/fn.provider_as_source.html [tableprovider]: https://docs.rs/datafusion/latest/datafusion/datasource/provider/trait.TableProvider.html [tablesource]: https://docs.rs/datafusion-expr/latest/datafusion_expr/trait.TableSource.html +[`executionplan`]: https://docs.rs/datafusion/latest/datafusion/physical_plan/trait.ExecutionPlan.html +[`sessionstate::create_physical_plan`]: https://docs.rs/datafusion/latest/datafusion/execution/session_state/struct.SessionState.html#method.create_physical_plan From 7eeac2f5c25c8bf606e463172916004b9d645da7 Mon Sep 17 00:00:00 2001 From: Lordworms <48054792+Lordworms@users.noreply.github.com> Date: Wed, 21 Aug 2024 08:23:28 -0700 Subject: [PATCH 03/10] Improve rpad udf by using a GenericStringBuilder (#12070) * Improve rpad udf by using a GenericStringBuilder * fix format * refine code --- datafusion/functions/benches/pad.rs | 11 +- datafusion/functions/src/unicode/rpad.rs | 333 ++++++++++++----------- 2 files changed, 180 insertions(+), 164 deletions(-) diff --git a/datafusion/functions/benches/pad.rs b/datafusion/functions/benches/pad.rs index 5ff1e2fb860d4..0c496bc633477 100644 --- a/datafusion/functions/benches/pad.rs +++ b/datafusion/functions/benches/pad.rs @@ -127,11 +127,12 @@ fn criterion_benchmark(c: &mut Criterion) { group.bench_function(BenchmarkId::new("largeutf8 type", size), |b| { b.iter(|| criterion::black_box(rpad().invoke(&args).unwrap())) }); - // - // let args = create_args::(size, 32, true); - // group.bench_function(BenchmarkId::new("stringview type", size), |b| { - // b.iter(|| criterion::black_box(rpad().invoke(&args).unwrap())) - // }); + + // rpad for stringview type + let args = create_args::(size, 32, true); + group.bench_function(BenchmarkId::new("stringview type", size), |b| { + b.iter(|| criterion::black_box(rpad().invoke(&args).unwrap())) + }); group.finish(); } diff --git a/datafusion/functions/src/unicode/rpad.rs b/datafusion/functions/src/unicode/rpad.rs index 4bcf102c8793d..c1d6f327928f2 100644 --- a/datafusion/functions/src/unicode/rpad.rs +++ b/datafusion/functions/src/unicode/rpad.rs @@ -15,20 +15,23 @@ // specific language governing permissions and limitations // under the License. -use std::any::Any; -use std::sync::Arc; - -use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait}; -use arrow::datatypes::DataType; -use datafusion_common::cast::{ - as_generic_string_array, as_int64_array, as_string_view_array, -}; -use unicode_segmentation::UnicodeSegmentation; - +use crate::string::common::StringArrayType; use crate::utils::{make_scalar_function, utf8_to_str_type}; +use arrow::array::{ + ArrayRef, AsArray, GenericStringArray, GenericStringBuilder, Int64Array, + OffsetSizeTrait, StringViewArray, +}; +use arrow::datatypes::DataType; +use datafusion_common::cast::as_int64_array; +use datafusion_common::DataFusionError; use datafusion_common::{exec_err, Result}; use datafusion_expr::TypeSignature::Exact; use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use std::any::Any; +use std::fmt::Write; +use std::sync::Arc; +use unicode_segmentation::UnicodeSegmentation; +use DataType::{LargeUtf8, Utf8, Utf8View}; #[derive(Debug)] pub struct RPadFunc { @@ -84,170 +87,182 @@ impl ScalarUDFImpl for RPadFunc { } fn invoke(&self, args: &[ColumnarValue]) -> Result { - match args.len() { - 2 => match args[0].data_type() { - DataType::Utf8 | DataType::Utf8View => { - make_scalar_function(rpad::, vec![])(args) - } - DataType::LargeUtf8 => { - make_scalar_function(rpad::, vec![])(args) - } - other => exec_err!("Unsupported data type {other:?} for function rpad"), - }, - 3 => match (args[0].data_type(), args[2].data_type()) { - ( - DataType::Utf8 | DataType::Utf8View, - DataType::Utf8 | DataType::Utf8View, - ) => make_scalar_function(rpad::, vec![])(args), - (DataType::LargeUtf8, DataType::LargeUtf8) => { - make_scalar_function(rpad::, vec![])(args) - } - (DataType::LargeUtf8, DataType::Utf8View | DataType::Utf8) => { - make_scalar_function(rpad::, vec![])(args) - } - (DataType::Utf8View | DataType::Utf8, DataType::LargeUtf8) => { - make_scalar_function(rpad::, vec![])(args) - } - (first_type, last_type) => { - exec_err!("unsupported arguments type for rpad, first argument type is {}, last argument type is {}", first_type, last_type) - } - }, - number => { - exec_err!("unsupported arguments number {} for rpad", number) + match ( + args.len(), + args[0].data_type(), + args.get(2).map(|arg| arg.data_type()), + ) { + (2, Utf8 | Utf8View, _) => { + make_scalar_function(rpad::, vec![])(args) + } + (2, LargeUtf8, _) => make_scalar_function(rpad::, vec![])(args), + (3, Utf8 | Utf8View, Some(Utf8 | Utf8View)) => { + make_scalar_function(rpad::, vec![])(args) + } + (3, LargeUtf8, Some(LargeUtf8)) => { + make_scalar_function(rpad::, vec![])(args) + } + (3, Utf8 | Utf8View, Some(LargeUtf8)) => { + make_scalar_function(rpad::, vec![])(args) + } + (3, LargeUtf8, Some(Utf8 | Utf8View)) => { + make_scalar_function(rpad::, vec![])(args) + } + (_, _, _) => { + exec_err!("Unsupported combination of data types for function rpad") } } } } -macro_rules! process_rpad { - // For the two-argument case - ($string_array:expr, $length_array:expr) => {{ - $string_array - .iter() - .zip($length_array.iter()) - .map(|(string, length)| match (string, length) { - (Some(string), Some(length)) => { - if length > i32::MAX as i64 { - return exec_err!("rpad requested length {} too large", length); - } - - let length = if length < 0 { 0 } else { length as usize }; - if length == 0 { - Ok(Some("".to_string())) - } else { - let graphemes = string.graphemes(true).collect::>(); - if length < graphemes.len() { - Ok(Some(graphemes[..length].concat())) - } else { - let mut s = string.to_string(); - s.push_str(" ".repeat(length - graphemes.len()).as_str()); - Ok(Some(s)) - } - } - } - _ => Ok(None), - }) - .collect::>>() - }}; - - // For the three-argument case - ($string_array:expr, $length_array:expr, $fill_array:expr) => {{ - $string_array - .iter() - .zip($length_array.iter()) - .zip($fill_array.iter()) - .map(|((string, length), fill)| match (string, length, fill) { - (Some(string), Some(length), Some(fill)) => { - if length > i32::MAX as i64 { - return exec_err!("rpad requested length {} too large", length); - } - - let length = if length < 0 { 0 } else { length as usize }; - let graphemes = string.graphemes(true).collect::>(); - let fill_chars = fill.chars().collect::>(); +pub fn rpad( + args: &[ArrayRef], +) -> Result { + if args.len() < 2 || args.len() > 3 { + return exec_err!( + "rpad was called with {} arguments. It requires 2 or 3 arguments.", + args.len() + ); + } - if length < graphemes.len() { - Ok(Some(graphemes[..length].concat())) - } else if fill_chars.is_empty() { - Ok(Some(string.to_string())) - } else { - let mut s = string.to_string(); - let char_vector: Vec = (0..length - graphemes.len()) - .map(|l| fill_chars[l % fill_chars.len()]) - .collect(); - s.push_str(&char_vector.iter().collect::()); - Ok(Some(s)) - } - } - _ => Ok(None), - }) - .collect::>>() - }}; + let length_array = as_int64_array(&args[1])?; + match ( + args.len(), + args[0].data_type(), + args.get(2).map(|arg| arg.data_type()), + ) { + (2, Utf8View, _) => { + rpad_impl::<&StringViewArray, &StringViewArray, StringArrayLen>( + args[0].as_string_view(), + length_array, + None, + ) + } + (3, Utf8View, Some(Utf8View)) => { + rpad_impl::<&StringViewArray, &StringViewArray, StringArrayLen>( + args[0].as_string_view(), + length_array, + Some(args[2].as_string_view()), + ) + } + (3, Utf8View, Some(Utf8 | LargeUtf8)) => { + rpad_impl::<&StringViewArray, &GenericStringArray, StringArrayLen>( + args[0].as_string_view(), + length_array, + Some(args[2].as_string::()), + ) + } + (3, Utf8 | LargeUtf8, Some(Utf8View)) => rpad_impl::< + &GenericStringArray, + &StringViewArray, + StringArrayLen, + >( + args[0].as_string::(), + length_array, + Some(args[2].as_string_view()), + ), + (_, _, _) => rpad_impl::< + &GenericStringArray, + &GenericStringArray, + StringArrayLen, + >( + args[0].as_string::(), + length_array, + args.get(2).map(|arg| arg.as_string::()), + ), + } } /// Extends the string to length 'length' by appending the characters fill (a space by default). If the string is already longer than length then it is truncated. /// rpad('hi', 5, 'xy') = 'hixyx' -pub fn rpad( - args: &[ArrayRef], -) -> Result { - match (args.len(), args[0].data_type()) { - (2, DataType::Utf8View) => { - let string_array = as_string_view_array(&args[0])?; - let length_array = as_int64_array(&args[1])?; +pub fn rpad_impl<'a, StringArrType, FillArrType, StringArrayLen>( + string_array: StringArrType, + length_array: &Int64Array, + fill_array: Option, +) -> Result +where + StringArrType: StringArrayType<'a>, + FillArrType: StringArrayType<'a>, + StringArrayLen: OffsetSizeTrait, +{ + let mut builder: GenericStringBuilder = GenericStringBuilder::new(); - let result = process_rpad!(string_array, length_array)?; - Ok(Arc::new(result) as ArrayRef) + match fill_array { + None => { + string_array.iter().zip(length_array.iter()).try_for_each( + |(string, length)| -> Result<(), DataFusionError> { + match (string, length) { + (Some(string), Some(length)) => { + if length > i32::MAX as i64 { + return exec_err!( + "rpad requested length {} too large", + length + ); + } + let length = if length < 0 { 0 } else { length as usize }; + if length == 0 { + builder.append_value(""); + } else { + let graphemes = + string.graphemes(true).collect::>(); + if length < graphemes.len() { + builder.append_value(graphemes[..length].concat()); + } else { + builder.write_str(string)?; + builder.write_str( + &" ".repeat(length - graphemes.len()), + )?; + builder.append_value(""); + } + } + } + _ => builder.append_null(), + } + Ok(()) + }, + )?; } - (2, _) => { - let string_array = as_generic_string_array::(&args[0])?; - let length_array = as_int64_array(&args[1])?; + Some(fill_array) => { + string_array + .iter() + .zip(length_array.iter()) + .zip(fill_array.iter()) + .try_for_each( + |((string, length), fill)| -> Result<(), DataFusionError> { + match (string, length, fill) { + (Some(string), Some(length), Some(fill)) => { + if length > i32::MAX as i64 { + return exec_err!( + "rpad requested length {} too large", + length + ); + } + let length = if length < 0 { 0 } else { length as usize }; + let graphemes = + string.graphemes(true).collect::>(); - let result = process_rpad!(string_array, length_array)?; - Ok(Arc::new(result) as ArrayRef) - } - (3, DataType::Utf8View) => { - let string_array = as_string_view_array(&args[0])?; - let length_array = as_int64_array(&args[1])?; - match args[2].data_type() { - DataType::Utf8View => { - let fill_array = as_string_view_array(&args[2])?; - let result = process_rpad!(string_array, length_array, fill_array)?; - Ok(Arc::new(result) as ArrayRef) - } - DataType::Utf8 | DataType::LargeUtf8 => { - let fill_array = as_generic_string_array::(&args[2])?; - let result = process_rpad!(string_array, length_array, fill_array)?; - Ok(Arc::new(result) as ArrayRef) - } - other_type => { - exec_err!("unsupported type for rpad's third operator: {}", other_type) - } - } - } - (3, _) => { - let string_array = as_generic_string_array::(&args[0])?; - let length_array = as_int64_array(&args[1])?; - match args[2].data_type() { - DataType::Utf8View => { - let fill_array = as_string_view_array(&args[2])?; - let result = process_rpad!(string_array, length_array, fill_array)?; - Ok(Arc::new(result) as ArrayRef) - } - DataType::Utf8 | DataType::LargeUtf8 => { - let fill_array = as_generic_string_array::(&args[2])?; - let result = process_rpad!(string_array, length_array, fill_array)?; - Ok(Arc::new(result) as ArrayRef) - } - other_type => { - exec_err!("unsupported type for rpad's third operator: {}", other_type) - } - } + if length < graphemes.len() { + builder.append_value(graphemes[..length].concat()); + } else if fill.is_empty() { + builder.append_value(string); + } else { + builder.write_str(string)?; + fill.chars() + .cycle() + .take(length - graphemes.len()) + .for_each(|ch| builder.write_char(ch).unwrap()); + builder.append_value(""); + } + } + _ => builder.append_null(), + } + Ok(()) + }, + )?; } - (other, other_type) => exec_err!( - "rpad requires 2 or 3 arguments with corresponding types, but got {}. number of arguments with {}", - other, other_type - ), } + + Ok(Arc::new(builder.finish()) as ArrayRef) } #[cfg(test)] From 9d076bde4fb099329abf33f57e319fd47c523561 Mon Sep 17 00:00:00 2001 From: Alex Huang Date: Wed, 21 Aug 2024 23:23:49 +0800 Subject: [PATCH 04/10] fix: Panic non-integer for the second argument of `nth_value` function (#12076) * fix: Panic non-integer for nth_value function * chore: Display actual value * Update datafusion/physical-plan/src/windows/mod.rs Co-authored-by: Marco Neumann * chore --------- Co-authored-by: Marco Neumann --- datafusion/physical-plan/src/windows/mod.rs | 8 ++++++-- datafusion/sqllogictest/test_files/window.slt | 13 +++++++++++++ 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/datafusion/physical-plan/src/windows/mod.rs b/datafusion/physical-plan/src/windows/mod.rs index f938f4410a992..63f4ffcfaacc2 100644 --- a/datafusion/physical-plan/src/windows/mod.rs +++ b/datafusion/physical-plan/src/windows/mod.rs @@ -30,7 +30,9 @@ use crate::{ use arrow::datatypes::Schema; use arrow_schema::{DataType, Field, SchemaRef}; -use datafusion_common::{exec_err, DataFusionError, Result, ScalarValue}; +use datafusion_common::{ + exec_datafusion_err, exec_err, DataFusionError, Result, ScalarValue, +}; use datafusion_expr::{ BuiltInWindowFunction, PartitionEvaluator, WindowFrame, WindowFunctionDefinition, WindowUDF, @@ -284,7 +286,9 @@ fn create_built_in_window_expr( args[1] .as_any() .downcast_ref::() - .unwrap() + .ok_or_else(|| { + exec_datafusion_err!("Expected a signed integer literal for the second argument of nth_value, got {}", args[1]) + })? .value() .clone(), )?; diff --git a/datafusion/sqllogictest/test_files/window.slt b/datafusion/sqllogictest/test_files/window.slt index 78055f8c1c11b..5bf5cf83284f6 100644 --- a/datafusion/sqllogictest/test_files/window.slt +++ b/datafusion/sqllogictest/test_files/window.slt @@ -4861,3 +4861,16 @@ select a, row_number(a) over (order by b) as rn from t; statement ok drop table t; + +statement ok +DROP TABLE t1; + +# https://github.com/apache/datafusion/issues/12073 +statement ok +CREATE TABLE t1(v1 BIGINT); + +query error DataFusion error: Execution error: Expected a signed integer literal for the second argument of nth_value, got v1@0 +SELECT NTH_VALUE('+Inf'::Double, v1) OVER (PARTITION BY v1) FROM t1; + +statement ok +DROP TABLE t1; \ No newline at end of file From eca71c4cb144795d1c073cbf918e94bd9d0e4102 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Wed, 21 Aug 2024 11:24:03 -0400 Subject: [PATCH 05/10] Remove vestigal `datafusion-docs` module compilation (#12081) * Remove vestigal `datafusion-docs` module compilation * fix build --- Cargo.toml | 1 - docs/Cargo.toml | 35 -------------- docs/src/lib.rs | 19 -------- docs/src/library_logical_plan.rs | 78 -------------------------------- 4 files changed, 133 deletions(-) delete mode 100644 docs/Cargo.toml delete mode 100644 docs/src/lib.rs delete mode 100644 docs/src/library_logical_plan.rs diff --git a/Cargo.toml b/Cargo.toml index ae344a46a1bd3..d82443f5d1c8d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -45,7 +45,6 @@ members = [ "datafusion/substrait", "datafusion/wasmtest", "datafusion-examples", - "docs", "test-utils", "benchmarks", ] diff --git a/docs/Cargo.toml b/docs/Cargo.toml deleted file mode 100644 index 14398c8415791..0000000000000 --- a/docs/Cargo.toml +++ /dev/null @@ -1,35 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -[package] -name = "datafusion-docs-tests" -description = "DataFusion Documentation Tests" -publish = false -version = { workspace = true } -edition = { workspace = true } -readme = { workspace = true } -homepage = { workspace = true } -repository = { workspace = true } -license = { workspace = true } -authors = { workspace = true } -rust-version = { workspace = true } - -[lints] -workspace = true - -[dependencies] -datafusion = { workspace = true } diff --git a/docs/src/lib.rs b/docs/src/lib.rs deleted file mode 100644 index f73132468ec9e..0000000000000 --- a/docs/src/lib.rs +++ /dev/null @@ -1,19 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -#[cfg(test)] -mod library_logical_plan; diff --git a/docs/src/library_logical_plan.rs b/docs/src/library_logical_plan.rs deleted file mode 100644 index 3550039415706..0000000000000 --- a/docs/src/library_logical_plan.rs +++ /dev/null @@ -1,78 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use datafusion::arrow::datatypes::{DataType, Field, Schema, SchemaRef}; -use datafusion::error::Result; -use datafusion::logical_expr::builder::LogicalTableSource; -use datafusion::logical_expr::{Filter, LogicalPlan, LogicalPlanBuilder, TableScan}; -use datafusion::prelude::*; -use std::sync::Arc; - -#[test] -fn plan_1() -> Result<()> { - // create a logical table source - let schema = Schema::new(vec![ - Field::new("id", DataType::Int32, true), - Field::new("name", DataType::Utf8, true), - ]); - let table_source = LogicalTableSource::new(SchemaRef::new(schema)); - - // create a TableScan plan - let projection = None; // optional projection - let filters = vec![]; // optional filters to push down - let fetch = None; // optional LIMIT - let table_scan = LogicalPlan::TableScan(TableScan::try_new( - "person", - Arc::new(table_source), - projection, - filters, - fetch, - )?); - - // create a Filter plan that evaluates `id > 500` and wraps the TableScan - let filter_expr = col("id").gt(lit(500)); - let plan = LogicalPlan::Filter(Filter::try_new(filter_expr, Arc::new(table_scan))?); - - // print the plan - println!("{}", plan.display_indent_schema()); - - Ok(()) -} - -#[test] -fn plan_builder_1() -> Result<()> { - // create a logical table source - let schema = Schema::new(vec![ - Field::new("id", DataType::Int32, true), - Field::new("name", DataType::Utf8, true), - ]); - let table_source = LogicalTableSource::new(SchemaRef::new(schema)); - - // optional projection - let projection = None; - - // create a LogicalPlanBuilder for a table scan - let builder = LogicalPlanBuilder::scan("person", Arc::new(table_source), projection)?; - - // perform a filter that evaluates `id > 500`, and build the plan - let plan = builder.filter(col("id").gt(lit(500)))?.build()?; - - // print the plan - println!("{}", plan.display_indent_schema()); - - Ok(()) -} From ad583a8dfa1cdf269c7a52eb4cb030d65d370a4c Mon Sep 17 00:00:00 2001 From: HuSen Date: Thu, 22 Aug 2024 01:03:09 +0800 Subject: [PATCH 06/10] Add test to verify count aggregate function should not be nullable (#12100) --- datafusion/sqllogictest/test_files/aggregate.slt | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index b8b93b28aff61..d39bf6538ecbc 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -5387,6 +5387,18 @@ physical_plan statement ok DROP TABLE empty; +# verify count aggregate function should not be nullable +statement ok +create table empty; + +query I +select distinct count() from empty; +---- +0 + +statement ok +DROP TABLE empty; + statement ok CREATE TABLE t(col0 INTEGER) as VALUES(2); From 121f330a6ccca008da4bf6ffc4efa4ffbf961fd7 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Wed, 21 Aug 2024 13:12:37 -0400 Subject: [PATCH 07/10] Minor: Extract `BatchCoalescer` to its own module (#12047) --- datafusion/physical-plan/src/coalesce/mod.rs | 588 ++++++++++++++++++ .../physical-plan/src/coalesce_batches.rs | 546 +--------------- datafusion/physical-plan/src/lib.rs | 1 + 3 files changed, 593 insertions(+), 542 deletions(-) create mode 100644 datafusion/physical-plan/src/coalesce/mod.rs diff --git a/datafusion/physical-plan/src/coalesce/mod.rs b/datafusion/physical-plan/src/coalesce/mod.rs new file mode 100644 index 0000000000000..5befa5ecda99b --- /dev/null +++ b/datafusion/physical-plan/src/coalesce/mod.rs @@ -0,0 +1,588 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::compute::concat_batches; +use arrow_array::builder::StringViewBuilder; +use arrow_array::cast::AsArray; +use arrow_array::{Array, ArrayRef, RecordBatch}; +use arrow_schema::SchemaRef; +use std::sync::Arc; + +/// Concatenate multiple [`RecordBatch`]es +/// +/// `BatchCoalescer` concatenates multiple small [`RecordBatch`]es, produced by +/// operations such as `FilterExec` and `RepartitionExec`, into larger ones for +/// more efficient processing by subsequent operations. +/// +/// # Background +/// +/// Generally speaking, larger [`RecordBatch`]es are more efficient to process +/// than smaller record batches (until the CPU cache is exceeded) because there +/// is fixed processing overhead per batch. DataFusion tries to operate on +/// batches of `target_batch_size` rows to amortize this overhead +/// +/// ```text +/// ┌────────────────────┐ +/// │ RecordBatch │ +/// │ num_rows = 23 │ +/// └────────────────────┘ ┌────────────────────┐ +/// │ │ +/// ┌────────────────────┐ Coalesce │ │ +/// │ │ Batches │ │ +/// │ RecordBatch │ │ │ +/// │ num_rows = 50 │ ─ ─ ─ ─ ─ ─ ▶ │ │ +/// │ │ │ RecordBatch │ +/// │ │ │ num_rows = 106 │ +/// └────────────────────┘ │ │ +/// │ │ +/// ┌────────────────────┐ │ │ +/// │ │ │ │ +/// │ RecordBatch │ │ │ +/// │ num_rows = 33 │ └────────────────────┘ +/// │ │ +/// └────────────────────┘ +/// ``` +/// +/// # Notes: +/// +/// 1. Output rows are produced in the same order as the input rows +/// +/// 2. The output is a sequence of batches, with all but the last being at least +/// `target_batch_size` rows. +/// +/// 3. Eventually this may also be able to handle other optimizations such as a +/// combined filter/coalesce operation. +/// +#[derive(Debug)] +pub struct BatchCoalescer { + /// The input schema + schema: SchemaRef, + /// Minimum number of rows for coalesces batches + target_batch_size: usize, + /// Total number of rows returned so far + total_rows: usize, + /// Buffered batches + buffer: Vec, + /// Buffered row count + buffered_rows: usize, + /// Limit: maximum number of rows to fetch, `None` means fetch all rows + fetch: Option, +} + +impl BatchCoalescer { + /// Create a new `BatchCoalescer` + /// + /// # Arguments + /// - `schema` - the schema of the output batches + /// - `target_batch_size` - the minimum number of rows for each + /// output batch (until limit reached) + /// - `fetch` - the maximum number of rows to fetch, `None` means fetch all rows + pub fn new( + schema: SchemaRef, + target_batch_size: usize, + fetch: Option, + ) -> Self { + Self { + schema, + target_batch_size, + total_rows: 0, + buffer: vec![], + buffered_rows: 0, + fetch, + } + } + + /// Return the schema of the output batches + pub fn schema(&self) -> SchemaRef { + Arc::clone(&self.schema) + } + + /// Push next batch, and returns [`CoalescerState`] indicating the current + /// state of the buffer. + pub fn push_batch(&mut self, batch: RecordBatch) -> CoalescerState { + let batch = gc_string_view_batch(&batch); + if self.limit_reached(&batch) { + CoalescerState::LimitReached + } else if self.target_reached(batch) { + CoalescerState::TargetReached + } else { + CoalescerState::Continue + } + } + + /// Return true if the there is no data buffered + pub fn is_empty(&self) -> bool { + self.buffer.is_empty() + } + + /// Checks if the buffer will reach the specified limit after getting + /// `batch`. + /// + /// If fetch would be exceeded, slices the received batch, updates the + /// buffer with it, and returns `true`. + /// + /// Otherwise: does nothing and returns `false`. + fn limit_reached(&mut self, batch: &RecordBatch) -> bool { + match self.fetch { + Some(fetch) if self.total_rows + batch.num_rows() >= fetch => { + // Limit is reached + let remaining_rows = fetch - self.total_rows; + debug_assert!(remaining_rows > 0); + + let batch = batch.slice(0, remaining_rows); + self.buffered_rows += batch.num_rows(); + self.total_rows = fetch; + self.buffer.push(batch); + true + } + _ => false, + } + } + + /// Updates the buffer with the given batch. + /// + /// If the target batch size is reached, returns `true`. Otherwise, returns + /// `false`. + fn target_reached(&mut self, batch: RecordBatch) -> bool { + if batch.num_rows() == 0 { + false + } else { + self.total_rows += batch.num_rows(); + self.buffered_rows += batch.num_rows(); + self.buffer.push(batch); + self.buffered_rows >= self.target_batch_size + } + } + + /// Concatenates and returns all buffered batches, and clears the buffer. + pub fn finish_batch(&mut self) -> datafusion_common::Result { + let batch = concat_batches(&self.schema, &self.buffer)?; + self.buffer.clear(); + self.buffered_rows = 0; + Ok(batch) + } +} + +/// Indicates the state of the [`BatchCoalescer`] buffer after the +/// [`BatchCoalescer::push_batch()`] operation. +/// +/// The caller should take diferent actions, depending on the variant returned. +pub enum CoalescerState { + /// Neither the limit nor the target batch size is reached. + /// + /// Action: continue pushing batches. + Continue, + /// The limit has been reached. + /// + /// Action: call [`BatchCoalescer::finish_batch()`] to get the final + /// buffered results as a batch and finish the query. + LimitReached, + /// The specified minimum number of rows a batch should have is reached. + /// + /// Action: call [`BatchCoalescer::finish_batch()`] to get the current + /// buffered results as a batch and then continue pushing batches. + TargetReached, +} + +/// Heuristically compact `StringViewArray`s to reduce memory usage, if needed +/// +/// Decides when to consolidate the StringView into a new buffer to reduce +/// memory usage and improve string locality for better performance. +/// +/// This differs from `StringViewArray::gc` because: +/// 1. It may not compact the array depending on a heuristic. +/// 2. It uses a precise block size to reduce the number of buffers to track. +/// +/// # Heuristic +/// +/// If the average size of each view is larger than 32 bytes, we compact the array. +/// +/// `StringViewArray` include pointers to buffer that hold the underlying data. +/// One of the great benefits of `StringViewArray` is that many operations +/// (e.g., `filter`) can be done without copying the underlying data. +/// +/// However, after a while (e.g., after `FilterExec` or `HashJoinExec`) the +/// `StringViewArray` may only refer to a small portion of the buffer, +/// significantly increasing memory usage. +fn gc_string_view_batch(batch: &RecordBatch) -> RecordBatch { + let new_columns: Vec = batch + .columns() + .iter() + .map(|c| { + // Try to re-create the `StringViewArray` to prevent holding the underlying buffer too long. + let Some(s) = c.as_string_view_opt() else { + return Arc::clone(c); + }; + let ideal_buffer_size: usize = s + .views() + .iter() + .map(|v| { + let len = (*v as u32) as usize; + if len > 12 { + len + } else { + 0 + } + }) + .sum(); + let actual_buffer_size = s.get_buffer_memory_size(); + + // Re-creating the array copies data and can be time consuming. + // We only do it if the array is sparse + if actual_buffer_size > (ideal_buffer_size * 2) { + // We set the block size to `ideal_buffer_size` so that the new StringViewArray only has one buffer, which accelerate later concat_batches. + // See https://github.com/apache/arrow-rs/issues/6094 for more details. + let mut builder = StringViewBuilder::with_capacity(s.len()); + if ideal_buffer_size > 0 { + builder = builder.with_block_size(ideal_buffer_size as u32); + } + + for v in s.iter() { + builder.append_option(v); + } + + let gc_string = builder.finish(); + + debug_assert!(gc_string.data_buffers().len() <= 1); // buffer count can be 0 if the `ideal_buffer_size` is 0 + + Arc::new(gc_string) + } else { + Arc::clone(c) + } + }) + .collect(); + RecordBatch::try_new(batch.schema(), new_columns) + .expect("Failed to re-create the gc'ed record batch") +} + +#[cfg(test)] +mod tests { + use std::ops::Range; + + use super::*; + + use arrow::datatypes::{DataType, Field, Schema}; + use arrow_array::builder::ArrayBuilder; + use arrow_array::{StringViewArray, UInt32Array}; + + #[test] + fn test_coalesce() { + let batch = uint32_batch(0..8); + Test::new() + .with_batches(std::iter::repeat(batch).take(10)) + // expected output is batches of at least 20 rows (except for the final batch) + .with_target_batch_size(21) + .with_expected_output_sizes(vec![24, 24, 24, 8]) + .run() + } + + #[test] + fn test_coalesce_with_fetch_larger_than_input_size() { + let batch = uint32_batch(0..8); + Test::new() + .with_batches(std::iter::repeat(batch).take(10)) + // input is 10 batches x 8 rows (80 rows) with fetch limit of 100 + // expected to behave the same as `test_concat_batches` + .with_target_batch_size(21) + .with_fetch(Some(100)) + .with_expected_output_sizes(vec![24, 24, 24, 8]) + .run(); + } + + #[test] + fn test_coalesce_with_fetch_less_than_input_size() { + let batch = uint32_batch(0..8); + Test::new() + .with_batches(std::iter::repeat(batch).take(10)) + // input is 10 batches x 8 rows (80 rows) with fetch limit of 50 + .with_target_batch_size(21) + .with_fetch(Some(50)) + .with_expected_output_sizes(vec![24, 24, 2]) + .run(); + } + + #[test] + fn test_coalesce_with_fetch_less_than_target_and_no_remaining_rows() { + let batch = uint32_batch(0..8); + Test::new() + .with_batches(std::iter::repeat(batch).take(10)) + // input is 10 batches x 8 rows (80 rows) with fetch limit of 48 + .with_target_batch_size(21) + .with_fetch(Some(48)) + .with_expected_output_sizes(vec![24, 24]) + .run(); + } + + #[test] + fn test_coalesce_with_fetch_less_target_batch_size() { + let batch = uint32_batch(0..8); + Test::new() + .with_batches(std::iter::repeat(batch).take(10)) + // input is 10 batches x 8 rows (80 rows) with fetch limit of 10 + .with_target_batch_size(21) + .with_fetch(Some(10)) + .with_expected_output_sizes(vec![10]) + .run(); + } + + #[test] + fn test_coalesce_single_large_batch_over_fetch() { + let large_batch = uint32_batch(0..100); + Test::new() + .with_batch(large_batch) + .with_target_batch_size(20) + .with_fetch(Some(7)) + .with_expected_output_sizes(vec![7]) + .run() + } + + /// Test for [`BatchCoalescer`] + /// + /// Pushes the input batches to the coalescer and verifies that the resulting + /// batches have the expected number of rows and contents. + #[derive(Debug, Clone, Default)] + struct Test { + /// Batches to feed to the coalescer. Tests must have at least one + /// schema + input_batches: Vec, + /// Expected output sizes of the resulting batches + expected_output_sizes: Vec, + /// target batch size + target_batch_size: usize, + /// Fetch (limit) + fetch: Option, + } + + impl Test { + fn new() -> Self { + Self::default() + } + + /// Set the target batch size + fn with_target_batch_size(mut self, target_batch_size: usize) -> Self { + self.target_batch_size = target_batch_size; + self + } + + /// Set the fetch (limit) + fn with_fetch(mut self, fetch: Option) -> Self { + self.fetch = fetch; + self + } + + /// Extend the input batches with `batch` + fn with_batch(mut self, batch: RecordBatch) -> Self { + self.input_batches.push(batch); + self + } + + /// Extends the input batches with `batches` + fn with_batches( + mut self, + batches: impl IntoIterator, + ) -> Self { + self.input_batches.extend(batches); + self + } + + /// Extends `sizes` to expected output sizes + fn with_expected_output_sizes( + mut self, + sizes: impl IntoIterator, + ) -> Self { + self.expected_output_sizes.extend(sizes); + self + } + + /// Runs the test -- see documentation on [`Test`] for details + fn run(self) { + let Self { + input_batches, + target_batch_size, + fetch, + expected_output_sizes, + } = self; + + let schema = input_batches[0].schema(); + + // create a single large input batch for output comparison + let single_input_batch = concat_batches(&schema, &input_batches).unwrap(); + + let mut coalescer = + BatchCoalescer::new(Arc::clone(&schema), target_batch_size, fetch); + + let mut output_batches = vec![]; + for batch in input_batches { + match coalescer.push_batch(batch) { + CoalescerState::Continue => {} + CoalescerState::LimitReached => { + output_batches.push(coalescer.finish_batch().unwrap()); + break; + } + CoalescerState::TargetReached => { + coalescer.buffered_rows = 0; + output_batches.push(coalescer.finish_batch().unwrap()); + } + } + } + if coalescer.buffered_rows != 0 { + output_batches.extend(coalescer.buffer); + } + + // make sure we got the expected number of output batches and content + let mut starting_idx = 0; + assert_eq!(expected_output_sizes.len(), output_batches.len()); + for (i, (expected_size, batch)) in + expected_output_sizes.iter().zip(output_batches).enumerate() + { + assert_eq!( + *expected_size, + batch.num_rows(), + "Unexpected number of rows in Batch {i}" + ); + + // compare the contents of the batch (using `==` compares the + // underlying memory layout too) + let expected_batch = + single_input_batch.slice(starting_idx, *expected_size); + let batch_strings = batch_to_pretty_strings(&batch); + let expected_batch_strings = batch_to_pretty_strings(&expected_batch); + let batch_strings = batch_strings.lines().collect::>(); + let expected_batch_strings = + expected_batch_strings.lines().collect::>(); + assert_eq!( + expected_batch_strings, batch_strings, + "Unexpected content in Batch {i}:\ + \n\nExpected:\n{expected_batch_strings:#?}\n\nActual:\n{batch_strings:#?}" + ); + starting_idx += *expected_size; + } + } + } + + /// Return a batch of UInt32 with the specified range + fn uint32_batch(range: Range) -> RecordBatch { + let schema = + Arc::new(Schema::new(vec![Field::new("c0", DataType::UInt32, false)])); + + RecordBatch::try_new( + Arc::clone(&schema), + vec![Arc::new(UInt32Array::from_iter_values(range))], + ) + .unwrap() + } + + #[test] + fn test_gc_string_view_batch_small_no_compact() { + // view with only short strings (no buffers) --> no need to compact + let array = StringViewTest { + rows: 1000, + strings: vec![Some("a"), Some("b"), Some("c")], + } + .build(); + + let gc_array = do_gc(array.clone()); + compare_string_array_values(&array, &gc_array); + assert_eq!(array.data_buffers().len(), 0); + assert_eq!(array.data_buffers().len(), gc_array.data_buffers().len()); // no compaction + } + + #[test] + fn test_gc_string_view_batch_large_no_compact() { + // view with large strings (has buffers) but full --> no need to compact + let array = StringViewTest { + rows: 1000, + strings: vec![Some("This string is longer than 12 bytes")], + } + .build(); + + let gc_array = do_gc(array.clone()); + compare_string_array_values(&array, &gc_array); + assert_eq!(array.data_buffers().len(), 5); + assert_eq!(array.data_buffers().len(), gc_array.data_buffers().len()); // no compaction + } + + #[test] + fn test_gc_string_view_batch_large_slice_compact() { + // view with large strings (has buffers) and only partially used --> no need to compact + let array = StringViewTest { + rows: 1000, + strings: vec![Some("this string is longer than 12 bytes")], + } + .build(); + + // slice only 11 rows, so most of the buffer is not used + let array = array.slice(11, 22); + + let gc_array = do_gc(array.clone()); + compare_string_array_values(&array, &gc_array); + assert_eq!(array.data_buffers().len(), 5); + assert_eq!(gc_array.data_buffers().len(), 1); // compacted into a single buffer + } + + /// Compares the values of two string view arrays + fn compare_string_array_values(arr1: &StringViewArray, arr2: &StringViewArray) { + assert_eq!(arr1.len(), arr2.len()); + for (s1, s2) in arr1.iter().zip(arr2.iter()) { + assert_eq!(s1, s2); + } + } + + /// runs garbage collection on string view array + /// and ensures the number of rows are the same + fn do_gc(array: StringViewArray) -> StringViewArray { + let batch = + RecordBatch::try_from_iter(vec![("a", Arc::new(array) as ArrayRef)]).unwrap(); + let gc_batch = gc_string_view_batch(&batch); + assert_eq!(batch.num_rows(), gc_batch.num_rows()); + assert_eq!(batch.schema(), gc_batch.schema()); + gc_batch + .column(0) + .as_any() + .downcast_ref::() + .unwrap() + .clone() + } + + /// Describes parameters for creating a `StringViewArray` + struct StringViewTest { + /// The number of rows in the array + rows: usize, + /// The strings to use in the array (repeated over and over + strings: Vec>, + } + + impl StringViewTest { + /// Create a `StringViewArray` with the parameters specified in this struct + fn build(self) -> StringViewArray { + let mut builder = StringViewBuilder::with_capacity(100).with_block_size(8192); + loop { + for &v in self.strings.iter() { + builder.append_option(v); + if builder.len() >= self.rows { + return builder.finish(); + } + } + } + } + } + fn batch_to_pretty_strings(batch: &RecordBatch) -> String { + arrow::util::pretty::pretty_format_batches(&[batch.clone()]) + .unwrap() + .to_string() + } +} diff --git a/datafusion/physical-plan/src/coalesce_batches.rs b/datafusion/physical-plan/src/coalesce_batches.rs index 5589027694fe4..7caf5b8ab65a3 100644 --- a/datafusion/physical-plan/src/coalesce_batches.rs +++ b/datafusion/physical-plan/src/coalesce_batches.rs @@ -28,19 +28,17 @@ use crate::{ DisplayFormatType, ExecutionPlan, RecordBatchStream, SendableRecordBatchStream, }; -use arrow::array::{AsArray, StringViewBuilder}; -use arrow::compute::concat_batches; use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; -use arrow_array::{Array, ArrayRef}; use datafusion_common::Result; use datafusion_execution::TaskContext; +use crate::coalesce::{BatchCoalescer, CoalescerState}; use futures::ready; use futures::stream::{Stream, StreamExt}; /// `CoalesceBatchesExec` combines small batches into larger batches for more -/// efficient use of vectorized processing by later operators. +/// efficient vectorized processing by later operators. /// /// The operator buffers batches until it collects `target_batch_size` rows and /// then emits a single concatenated batch. When only a limited number of rows @@ -48,35 +46,7 @@ use futures::stream::{Stream, StreamExt}; /// buffering and returns the final batch once the number of collected rows /// reaches the `fetch` value. /// -/// # Background -/// -/// Generally speaking, larger RecordBatches are more efficient to process than -/// smaller record batches (until the CPU cache is exceeded) because there is -/// fixed processing overhead per batch. This code concatenates multiple small -/// record batches into larger ones to amortize this overhead. -/// -/// ```text -/// ┌────────────────────┐ -/// │ RecordBatch │ -/// │ num_rows = 23 │ -/// └────────────────────┘ ┌────────────────────┐ -/// │ │ -/// ┌────────────────────┐ Coalesce │ │ -/// │ │ Batches │ │ -/// │ RecordBatch │ │ │ -/// │ num_rows = 50 │ ─ ─ ─ ─ ─ ─ ▶ │ │ -/// │ │ │ RecordBatch │ -/// │ │ │ num_rows = 106 │ -/// └────────────────────┘ │ │ -/// │ │ -/// ┌────────────────────┐ │ │ -/// │ │ │ │ -/// │ RecordBatch │ │ │ -/// │ num_rows = 33 │ └────────────────────┘ -/// │ │ -/// └────────────────────┘ -/// ``` - +/// See [`BatchCoalescer`] for more information #[derive(Debug)] pub struct CoalesceBatchesExec { /// The input plan @@ -346,7 +316,7 @@ impl CoalesceBatchesStream { } CoalesceBatchesStreamState::Exhausted => { // Handle the end of the input stream. - return if self.coalescer.buffer.is_empty() { + return if self.coalescer.is_empty() { // If buffer is empty, return None indicating the stream is fully consumed. Poll::Ready(None) } else { @@ -365,511 +335,3 @@ impl RecordBatchStream for CoalesceBatchesStream { self.coalescer.schema() } } - -/// Concatenate multiple record batches into larger batches -/// -/// See [`CoalesceBatchesExec`] for more details. -/// -/// Notes: -/// -/// 1. The output rows is the same order as the input rows -/// -/// 2. The output is a sequence of batches, with all but the last being at least -/// `target_batch_size` rows. -/// -/// 3. Eventually this may also be able to handle other optimizations such as a -/// combined filter/coalesce operation. -#[derive(Debug)] -struct BatchCoalescer { - /// The input schema - schema: SchemaRef, - /// Minimum number of rows for coalesces batches - target_batch_size: usize, - /// Total number of rows returned so far - total_rows: usize, - /// Buffered batches - buffer: Vec, - /// Buffered row count - buffered_rows: usize, - /// Maximum number of rows to fetch, `None` means fetching all rows - fetch: Option, -} - -impl BatchCoalescer { - /// Create a new `BatchCoalescer` - /// - /// # Arguments - /// - `schema` - the schema of the output batches - /// - `target_batch_size` - the minimum number of rows for each - /// output batch (until limit reached) - /// - `fetch` - the maximum number of rows to fetch, `None` means fetch all rows - fn new(schema: SchemaRef, target_batch_size: usize, fetch: Option) -> Self { - Self { - schema, - target_batch_size, - total_rows: 0, - buffer: vec![], - buffered_rows: 0, - fetch, - } - } - - /// Return the schema of the output batches - fn schema(&self) -> SchemaRef { - Arc::clone(&self.schema) - } - - /// Given a batch, it updates the buffer of [`BatchCoalescer`]. It returns - /// a variant of [`CoalescerState`] indicating the final state of the buffer. - fn push_batch(&mut self, batch: RecordBatch) -> CoalescerState { - let batch = gc_string_view_batch(&batch); - if self.limit_reached(&batch) { - CoalescerState::LimitReached - } else if self.target_reached(batch) { - CoalescerState::TargetReached - } else { - CoalescerState::Continue - } - } - - /// The function checks if the buffer can reach the specified limit after getting `batch`. - /// If it does, it slices the received batch as needed, updates the buffer with it, and - /// finally returns `true`. Otherwise; the function does nothing and returns `false`. - fn limit_reached(&mut self, batch: &RecordBatch) -> bool { - match self.fetch { - Some(fetch) if self.total_rows + batch.num_rows() >= fetch => { - // Limit is reached - let remaining_rows = fetch - self.total_rows; - debug_assert!(remaining_rows > 0); - - let batch = batch.slice(0, remaining_rows); - self.buffered_rows += batch.num_rows(); - self.total_rows = fetch; - self.buffer.push(batch); - true - } - _ => false, - } - } - - /// Updates the buffer with the given batch. If the target batch size is reached, - /// the function returns `true`. Otherwise, it returns `false`. - fn target_reached(&mut self, batch: RecordBatch) -> bool { - if batch.num_rows() == 0 { - false - } else { - self.total_rows += batch.num_rows(); - self.buffered_rows += batch.num_rows(); - self.buffer.push(batch); - self.buffered_rows >= self.target_batch_size - } - } - - /// Concatenates and returns all buffered batches, and clears the buffer. - fn finish_batch(&mut self) -> Result { - let batch = concat_batches(&self.schema, &self.buffer)?; - self.buffer.clear(); - self.buffered_rows = 0; - Ok(batch) - } -} - -/// This enumeration acts as a status indicator for the [`BatchCoalescer`] after a -/// [`BatchCoalescer::push_batch()`] operation. -enum CoalescerState { - /// Neither the limit nor the target batch size is reached. - Continue, - /// The sufficient row count to produce a complete query result is reached. - LimitReached, - /// The specified minimum number of rows a batch should have is reached. - TargetReached, -} - -/// Heuristically compact `StringViewArray`s to reduce memory usage, if needed -/// -/// This function decides when to consolidate the StringView into a new buffer -/// to reduce memory usage and improve string locality for better performance. -/// -/// This differs from `StringViewArray::gc` because: -/// 1. It may not compact the array depending on a heuristic. -/// 2. It uses a precise block size to reduce the number of buffers to track. -/// -/// # Heuristic -/// -/// If the average size of each view is larger than 32 bytes, we compact the array. -/// -/// `StringViewArray` include pointers to buffer that hold the underlying data. -/// One of the great benefits of `StringViewArray` is that many operations -/// (e.g., `filter`) can be done without copying the underlying data. -/// -/// However, after a while (e.g., after `FilterExec` or `HashJoinExec`) the -/// `StringViewArray` may only refer to a small portion of the buffer, -/// significantly increasing memory usage. -fn gc_string_view_batch(batch: &RecordBatch) -> RecordBatch { - let new_columns: Vec = batch - .columns() - .iter() - .map(|c| { - // Try to re-create the `StringViewArray` to prevent holding the underlying buffer too long. - let Some(s) = c.as_string_view_opt() else { - return Arc::clone(c); - }; - let ideal_buffer_size: usize = s - .views() - .iter() - .map(|v| { - let len = (*v as u32) as usize; - if len > 12 { - len - } else { - 0 - } - }) - .sum(); - let actual_buffer_size = s.get_buffer_memory_size(); - - // Re-creating the array copies data and can be time consuming. - // We only do it if the array is sparse - if actual_buffer_size > (ideal_buffer_size * 2) { - // We set the block size to `ideal_buffer_size` so that the new StringViewArray only has one buffer, which accelerate later concat_batches. - // See https://github.com/apache/arrow-rs/issues/6094 for more details. - let mut builder = StringViewBuilder::with_capacity(s.len()); - if ideal_buffer_size > 0 { - builder = builder.with_block_size(ideal_buffer_size as u32); - } - - for v in s.iter() { - builder.append_option(v); - } - - let gc_string = builder.finish(); - - debug_assert!(gc_string.data_buffers().len() <= 1); // buffer count can be 0 if the `ideal_buffer_size` is 0 - - Arc::new(gc_string) - } else { - Arc::clone(c) - } - }) - .collect(); - RecordBatch::try_new(batch.schema(), new_columns) - .expect("Failed to re-create the gc'ed record batch") -} - -#[cfg(test)] -mod tests { - use std::ops::Range; - - use super::*; - - use arrow::datatypes::{DataType, Field, Schema}; - use arrow_array::builder::ArrayBuilder; - use arrow_array::{StringViewArray, UInt32Array}; - - #[test] - fn test_coalesce() { - let batch = uint32_batch(0..8); - Test::new() - .with_batches(std::iter::repeat(batch).take(10)) - // expected output is batches of at least 20 rows (except for the final batch) - .with_target_batch_size(21) - .with_expected_output_sizes(vec![24, 24, 24, 8]) - .run() - } - - #[test] - fn test_coalesce_with_fetch_larger_than_input_size() { - let batch = uint32_batch(0..8); - Test::new() - .with_batches(std::iter::repeat(batch).take(10)) - // input is 10 batches x 8 rows (80 rows) with fetch limit of 100 - // expected to behave the same as `test_concat_batches` - .with_target_batch_size(21) - .with_fetch(Some(100)) - .with_expected_output_sizes(vec![24, 24, 24, 8]) - .run(); - } - - #[test] - fn test_coalesce_with_fetch_less_than_input_size() { - let batch = uint32_batch(0..8); - Test::new() - .with_batches(std::iter::repeat(batch).take(10)) - // input is 10 batches x 8 rows (80 rows) with fetch limit of 50 - .with_target_batch_size(21) - .with_fetch(Some(50)) - .with_expected_output_sizes(vec![24, 24, 2]) - .run(); - } - - #[test] - fn test_coalesce_with_fetch_less_than_target_and_no_remaining_rows() { - let batch = uint32_batch(0..8); - Test::new() - .with_batches(std::iter::repeat(batch).take(10)) - // input is 10 batches x 8 rows (80 rows) with fetch limit of 48 - .with_target_batch_size(21) - .with_fetch(Some(48)) - .with_expected_output_sizes(vec![24, 24]) - .run(); - } - - #[test] - fn test_coalesce_with_fetch_less_target_batch_size() { - let batch = uint32_batch(0..8); - Test::new() - .with_batches(std::iter::repeat(batch).take(10)) - // input is 10 batches x 8 rows (80 rows) with fetch limit of 10 - .with_target_batch_size(21) - .with_fetch(Some(10)) - .with_expected_output_sizes(vec![10]) - .run(); - } - - #[test] - fn test_coalesce_single_large_batch_over_fetch() { - let large_batch = uint32_batch(0..100); - Test::new() - .with_batch(large_batch) - .with_target_batch_size(20) - .with_fetch(Some(7)) - .with_expected_output_sizes(vec![7]) - .run() - } - - /// Test for [`BatchCoalescer`] - /// - /// Pushes the input batches to the coalescer and verifies that the resulting - /// batches have the expected number of rows and contents. - #[derive(Debug, Clone, Default)] - struct Test { - /// Batches to feed to the coalescer. Tests must have at least one - /// schema - input_batches: Vec, - /// Expected output sizes of the resulting batches - expected_output_sizes: Vec, - /// target batch size - target_batch_size: usize, - /// Fetch (limit) - fetch: Option, - } - - impl Test { - fn new() -> Self { - Self::default() - } - - /// Set the target batch size - fn with_target_batch_size(mut self, target_batch_size: usize) -> Self { - self.target_batch_size = target_batch_size; - self - } - - /// Set the fetch (limit) - fn with_fetch(mut self, fetch: Option) -> Self { - self.fetch = fetch; - self - } - - /// Extend the input batches with `batch` - fn with_batch(mut self, batch: RecordBatch) -> Self { - self.input_batches.push(batch); - self - } - - /// Extends the input batches with `batches` - fn with_batches( - mut self, - batches: impl IntoIterator, - ) -> Self { - self.input_batches.extend(batches); - self - } - - /// Extends `sizes` to expected output sizes - fn with_expected_output_sizes( - mut self, - sizes: impl IntoIterator, - ) -> Self { - self.expected_output_sizes.extend(sizes); - self - } - - /// Runs the test -- see documentation on [`Test`] for details - fn run(self) { - let Self { - input_batches, - target_batch_size, - fetch, - expected_output_sizes, - } = self; - - let schema = input_batches[0].schema(); - - // create a single large input batch for output comparison - let single_input_batch = concat_batches(&schema, &input_batches).unwrap(); - - let mut coalescer = - BatchCoalescer::new(Arc::clone(&schema), target_batch_size, fetch); - - let mut output_batches = vec![]; - for batch in input_batches { - match coalescer.push_batch(batch) { - CoalescerState::Continue => {} - CoalescerState::LimitReached => { - output_batches.push(coalescer.finish_batch().unwrap()); - break; - } - CoalescerState::TargetReached => { - coalescer.buffered_rows = 0; - output_batches.push(coalescer.finish_batch().unwrap()); - } - } - } - if coalescer.buffered_rows != 0 { - output_batches.extend(coalescer.buffer); - } - - // make sure we got the expected number of output batches and content - let mut starting_idx = 0; - assert_eq!(expected_output_sizes.len(), output_batches.len()); - for (i, (expected_size, batch)) in - expected_output_sizes.iter().zip(output_batches).enumerate() - { - assert_eq!( - *expected_size, - batch.num_rows(), - "Unexpected number of rows in Batch {i}" - ); - - // compare the contents of the batch (using `==` compares the - // underlying memory layout too) - let expected_batch = - single_input_batch.slice(starting_idx, *expected_size); - let batch_strings = batch_to_pretty_strings(&batch); - let expected_batch_strings = batch_to_pretty_strings(&expected_batch); - let batch_strings = batch_strings.lines().collect::>(); - let expected_batch_strings = - expected_batch_strings.lines().collect::>(); - assert_eq!( - expected_batch_strings, batch_strings, - "Unexpected content in Batch {i}:\ - \n\nExpected:\n{expected_batch_strings:#?}\n\nActual:\n{batch_strings:#?}" - ); - starting_idx += *expected_size; - } - } - } - - /// Return a batch of UInt32 with the specified range - fn uint32_batch(range: Range) -> RecordBatch { - let schema = - Arc::new(Schema::new(vec![Field::new("c0", DataType::UInt32, false)])); - - RecordBatch::try_new( - Arc::clone(&schema), - vec![Arc::new(UInt32Array::from_iter_values(range))], - ) - .unwrap() - } - - #[test] - fn test_gc_string_view_batch_small_no_compact() { - // view with only short strings (no buffers) --> no need to compact - let array = StringViewTest { - rows: 1000, - strings: vec![Some("a"), Some("b"), Some("c")], - } - .build(); - - let gc_array = do_gc(array.clone()); - compare_string_array_values(&array, &gc_array); - assert_eq!(array.data_buffers().len(), 0); - assert_eq!(array.data_buffers().len(), gc_array.data_buffers().len()); // no compaction - } - - #[test] - fn test_gc_string_view_batch_large_no_compact() { - // view with large strings (has buffers) but full --> no need to compact - let array = StringViewTest { - rows: 1000, - strings: vec![Some("This string is longer than 12 bytes")], - } - .build(); - - let gc_array = do_gc(array.clone()); - compare_string_array_values(&array, &gc_array); - assert_eq!(array.data_buffers().len(), 5); - assert_eq!(array.data_buffers().len(), gc_array.data_buffers().len()); // no compaction - } - - #[test] - fn test_gc_string_view_batch_large_slice_compact() { - // view with large strings (has buffers) and only partially used --> no need to compact - let array = StringViewTest { - rows: 1000, - strings: vec![Some("this string is longer than 12 bytes")], - } - .build(); - - // slice only 11 rows, so most of the buffer is not used - let array = array.slice(11, 22); - - let gc_array = do_gc(array.clone()); - compare_string_array_values(&array, &gc_array); - assert_eq!(array.data_buffers().len(), 5); - assert_eq!(gc_array.data_buffers().len(), 1); // compacted into a single buffer - } - - /// Compares the values of two string view arrays - fn compare_string_array_values(arr1: &StringViewArray, arr2: &StringViewArray) { - assert_eq!(arr1.len(), arr2.len()); - for (s1, s2) in arr1.iter().zip(arr2.iter()) { - assert_eq!(s1, s2); - } - } - - /// runs garbage collection on string view array - /// and ensures the number of rows are the same - fn do_gc(array: StringViewArray) -> StringViewArray { - let batch = - RecordBatch::try_from_iter(vec![("a", Arc::new(array) as ArrayRef)]).unwrap(); - let gc_batch = gc_string_view_batch(&batch); - assert_eq!(batch.num_rows(), gc_batch.num_rows()); - assert_eq!(batch.schema(), gc_batch.schema()); - gc_batch - .column(0) - .as_any() - .downcast_ref::() - .unwrap() - .clone() - } - - /// Describes parameters for creating a `StringViewArray` - struct StringViewTest { - /// The number of rows in the array - rows: usize, - /// The strings to use in the array (repeated over and over - strings: Vec>, - } - - impl StringViewTest { - /// Create a `StringViewArray` with the parameters specified in this struct - fn build(self) -> StringViewArray { - let mut builder = StringViewBuilder::with_capacity(100).with_block_size(8192); - loop { - for &v in self.strings.iter() { - builder.append_option(v); - if builder.len() >= self.rows { - return builder.finish(); - } - } - } - } - } - fn batch_to_pretty_strings(batch: &RecordBatch) -> String { - arrow::util::pretty::pretty_format_batches(&[batch.clone()]) - .unwrap() - .to_string() - } -} diff --git a/datafusion/physical-plan/src/lib.rs b/datafusion/physical-plan/src/lib.rs index 59c5da6b6fb20..fb86a008e2cd6 100644 --- a/datafusion/physical-plan/src/lib.rs +++ b/datafusion/physical-plan/src/lib.rs @@ -85,5 +85,6 @@ pub mod udaf { pub use datafusion_physical_expr_functions_aggregate::aggregate::AggregateFunctionExpr; } +pub mod coalesce; #[cfg(test)] pub mod test; From c6be00d5012e89ea29611f02b9edf15c806db6c5 Mon Sep 17 00:00:00 2001 From: Dmitry Bugakov Date: Wed, 21 Aug 2024 21:05:03 +0200 Subject: [PATCH 08/10] Add Utf8View support to STRPOS function (#12087) * Add Utf8View support to STRPOS function * fix type inconsistency * fix type inconsistency * refactor tests --- datafusion/functions/src/unicode/strpos.rs | 175 ++++++++++++------ .../sqllogictest/test_files/string_view.slt | 5 +- 2 files changed, 121 insertions(+), 59 deletions(-) diff --git a/datafusion/functions/src/unicode/strpos.rs b/datafusion/functions/src/unicode/strpos.rs index 702baf6e8fa77..cf10b18ae3383 100644 --- a/datafusion/functions/src/unicode/strpos.rs +++ b/datafusion/functions/src/unicode/strpos.rs @@ -19,11 +19,10 @@ use std::any::Any; use std::sync::Arc; use arrow::array::{ - ArrayRef, ArrowPrimitiveType, GenericStringArray, OffsetSizeTrait, PrimitiveArray, + ArrayAccessor, ArrayIter, ArrayRef, ArrowPrimitiveType, AsArray, PrimitiveArray, }; use arrow::datatypes::{ArrowNativeType, DataType, Int32Type, Int64Type}; -use datafusion_common::cast::as_generic_string_array; use datafusion_common::{exec_err, Result}; use datafusion_expr::TypeSignature::Exact; use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; @@ -52,6 +51,9 @@ impl StrposFunc { Exact(vec![Utf8, LargeUtf8]), Exact(vec![LargeUtf8, Utf8]), Exact(vec![LargeUtf8, LargeUtf8]), + Exact(vec![Utf8View, Utf8View]), + Exact(vec![Utf8View, Utf8]), + Exact(vec![Utf8View, LargeUtf8]), ], Volatility::Immutable, ), @@ -78,21 +80,7 @@ impl ScalarUDFImpl for StrposFunc { } fn invoke(&self, args: &[ColumnarValue]) -> Result { - match (args[0].data_type(), args[1].data_type()) { - (DataType::Utf8, DataType::Utf8) => { - make_scalar_function(strpos::, vec![])(args) - } - (DataType::Utf8, DataType::LargeUtf8) => { - make_scalar_function(strpos::, vec![])(args) - } - (DataType::LargeUtf8, DataType::Utf8) => { - make_scalar_function(strpos::, vec![])(args) - } - (DataType::LargeUtf8, DataType::LargeUtf8) => { - make_scalar_function(strpos::, vec![])(args) - } - other => exec_err!("Unsupported data type {other:?} for function strpos"), - } + make_scalar_function(strpos, vec![])(args) } fn aliases(&self) -> &[String] { @@ -100,30 +88,71 @@ impl ScalarUDFImpl for StrposFunc { } } +fn strpos(args: &[ArrayRef]) -> Result { + match (args[0].data_type(), args[1].data_type()) { + (DataType::Utf8, DataType::Utf8) => { + let string_array = args[0].as_string::(); + let substring_array = args[1].as_string::(); + calculate_strpos::<_, _, Int32Type>(string_array, substring_array) + } + (DataType::Utf8, DataType::LargeUtf8) => { + let string_array = args[0].as_string::(); + let substring_array = args[1].as_string::(); + calculate_strpos::<_, _, Int32Type>(string_array, substring_array) + } + (DataType::LargeUtf8, DataType::Utf8) => { + let string_array = args[0].as_string::(); + let substring_array = args[1].as_string::(); + calculate_strpos::<_, _, Int64Type>(string_array, substring_array) + } + (DataType::LargeUtf8, DataType::LargeUtf8) => { + let string_array = args[0].as_string::(); + let substring_array = args[1].as_string::(); + calculate_strpos::<_, _, Int64Type>(string_array, substring_array) + } + (DataType::Utf8View, DataType::Utf8View) => { + let string_array = args[0].as_string_view(); + let substring_array = args[1].as_string_view(); + calculate_strpos::<_, _, Int32Type>(string_array, substring_array) + } + (DataType::Utf8View, DataType::Utf8) => { + let string_array = args[0].as_string_view(); + let substring_array = args[1].as_string::(); + calculate_strpos::<_, _, Int32Type>(string_array, substring_array) + } + (DataType::Utf8View, DataType::LargeUtf8) => { + let string_array = args[0].as_string_view(); + let substring_array = args[1].as_string::(); + calculate_strpos::<_, _, Int32Type>(string_array, substring_array) + } + + other => { + exec_err!("Unsupported data type combination {other:?} for function strpos") + } + } +} + /// Returns starting index of specified substring within string, or zero if it's not present. (Same as position(substring in string), but note the reversed argument order.) /// strpos('high', 'ig') = 2 /// The implementation uses UTF-8 code points as characters -fn strpos( - args: &[ArrayRef], +fn calculate_strpos<'a, V1, V2, T: ArrowPrimitiveType>( + string_array: V1, + substring_array: V2, ) -> Result where - T0::Native: OffsetSizeTrait, - T1::Native: OffsetSizeTrait, + V1: ArrayAccessor, + V2: ArrayAccessor, { - let string_array: &GenericStringArray = - as_generic_string_array::(&args[0])?; - - let substring_array: &GenericStringArray = - as_generic_string_array::(&args[1])?; + let string_iter = ArrayIter::new(string_array); + let substring_iter = ArrayIter::new(substring_array); - let result = string_array - .iter() - .zip(substring_array.iter()) + let result = string_iter + .zip(substring_iter) .map(|(string, substring)| match (string, substring) { (Some(string), Some(substring)) => { - // the find method returns the byte index of the substring - // Next, we count the number of the chars until that byte - T0::Native::from_usize( + // The `find` method returns the byte index of the substring. + // We count the number of chars up to that byte index. + T::Native::from_usize( string .find(substring) .map(|x| string[..x].chars().count() + 1) @@ -132,20 +161,21 @@ where } _ => None, }) - .collect::>(); + .collect::>(); Ok(Arc::new(result) as ArrayRef) } #[cfg(test)] -mod test { - use super::*; +mod tests { + use arrow::array::{Array, Int32Array, Int64Array}; + use arrow::datatypes::DataType::{Int32, Int64}; + + use datafusion_common::{Result, ScalarValue}; + use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; + + use crate::unicode::strpos::StrposFunc; use crate::utils::test::test_function; - use arrow::{ - array::{Array as _, Int32Array, Int64Array}, - datatypes::DataType::{Int32, Int64}, - }; - use datafusion_common::ScalarValue; macro_rules! test_strpos { ($lhs:literal, $rhs:literal -> $result:literal; $t1:ident $t2:ident $t3:ident $t4:ident $t5:ident) => { @@ -164,21 +194,54 @@ mod test { } #[test] - fn strpos() { - test_strpos!("foo", "bar" -> 0; Utf8 Utf8 i32 Int32 Int32Array); - test_strpos!("foobar", "foo" -> 1; Utf8 Utf8 i32 Int32 Int32Array); - test_strpos!("foobar", "bar" -> 4; Utf8 Utf8 i32 Int32 Int32Array); - - test_strpos!("foo", "bar" -> 0; LargeUtf8 LargeUtf8 i64 Int64 Int64Array); - test_strpos!("foobar", "foo" -> 1; LargeUtf8 LargeUtf8 i64 Int64 Int64Array); - test_strpos!("foobar", "bar" -> 4; LargeUtf8 LargeUtf8 i64 Int64 Int64Array); - - test_strpos!("foo", "bar" -> 0; Utf8 LargeUtf8 i32 Int32 Int32Array); - test_strpos!("foobar", "foo" -> 1; Utf8 LargeUtf8 i32 Int32 Int32Array); - test_strpos!("foobar", "bar" -> 4; Utf8 LargeUtf8 i32 Int32 Int32Array); - - test_strpos!("foo", "bar" -> 0; LargeUtf8 Utf8 i64 Int64 Int64Array); - test_strpos!("foobar", "foo" -> 1; LargeUtf8 Utf8 i64 Int64 Int64Array); - test_strpos!("foobar", "bar" -> 4; LargeUtf8 Utf8 i64 Int64 Int64Array); + fn test_strpos_functions() { + // Utf8 and Utf8 combinations + test_strpos!("alphabet", "ph" -> 3; Utf8 Utf8 i32 Int32 Int32Array); + test_strpos!("alphabet", "a" -> 1; Utf8 Utf8 i32 Int32 Int32Array); + test_strpos!("alphabet", "z" -> 0; Utf8 Utf8 i32 Int32 Int32Array); + test_strpos!("alphabet", "" -> 1; Utf8 Utf8 i32 Int32 Int32Array); + test_strpos!("", "a" -> 0; Utf8 Utf8 i32 Int32 Int32Array); + + // LargeUtf8 and LargeUtf8 combinations + test_strpos!("alphabet", "ph" -> 3; LargeUtf8 LargeUtf8 i64 Int64 Int64Array); + test_strpos!("alphabet", "a" -> 1; LargeUtf8 LargeUtf8 i64 Int64 Int64Array); + test_strpos!("alphabet", "z" -> 0; LargeUtf8 LargeUtf8 i64 Int64 Int64Array); + test_strpos!("alphabet", "" -> 1; LargeUtf8 LargeUtf8 i64 Int64 Int64Array); + test_strpos!("", "a" -> 0; LargeUtf8 LargeUtf8 i64 Int64 Int64Array); + + // Utf8 and LargeUtf8 combinations + test_strpos!("alphabet", "ph" -> 3; Utf8 LargeUtf8 i32 Int32 Int32Array); + test_strpos!("alphabet", "a" -> 1; Utf8 LargeUtf8 i32 Int32 Int32Array); + test_strpos!("alphabet", "z" -> 0; Utf8 LargeUtf8 i32 Int32 Int32Array); + test_strpos!("alphabet", "" -> 1; Utf8 LargeUtf8 i32 Int32 Int32Array); + test_strpos!("", "a" -> 0; Utf8 LargeUtf8 i32 Int32 Int32Array); + + // LargeUtf8 and Utf8 combinations + test_strpos!("alphabet", "ph" -> 3; LargeUtf8 Utf8 i64 Int64 Int64Array); + test_strpos!("alphabet", "a" -> 1; LargeUtf8 Utf8 i64 Int64 Int64Array); + test_strpos!("alphabet", "z" -> 0; LargeUtf8 Utf8 i64 Int64 Int64Array); + test_strpos!("alphabet", "" -> 1; LargeUtf8 Utf8 i64 Int64 Int64Array); + test_strpos!("", "a" -> 0; LargeUtf8 Utf8 i64 Int64 Int64Array); + + // Utf8View and Utf8View combinations + test_strpos!("alphabet", "ph" -> 3; Utf8View Utf8View i32 Int32 Int32Array); + test_strpos!("alphabet", "a" -> 1; Utf8View Utf8View i32 Int32 Int32Array); + test_strpos!("alphabet", "z" -> 0; Utf8View Utf8View i32 Int32 Int32Array); + test_strpos!("alphabet", "" -> 1; Utf8View Utf8View i32 Int32 Int32Array); + test_strpos!("", "a" -> 0; Utf8View Utf8View i32 Int32 Int32Array); + + // Utf8View and Utf8 combinations + test_strpos!("alphabet", "ph" -> 3; Utf8View Utf8 i32 Int32 Int32Array); + test_strpos!("alphabet", "a" -> 1; Utf8View Utf8 i32 Int32 Int32Array); + test_strpos!("alphabet", "z" -> 0; Utf8View Utf8 i32 Int32 Int32Array); + test_strpos!("alphabet", "" -> 1; Utf8View Utf8 i32 Int32 Int32Array); + test_strpos!("", "a" -> 0; Utf8View Utf8 i32 Int32 Int32Array); + + // Utf8View and LargeUtf8 combinations + test_strpos!("alphabet", "ph" -> 3; Utf8View LargeUtf8 i32 Int32 Int32Array); + test_strpos!("alphabet", "a" -> 1; Utf8View LargeUtf8 i32 Int32 Int32Array); + test_strpos!("alphabet", "z" -> 0; Utf8View LargeUtf8 i32 Int32 Int32Array); + test_strpos!("alphabet", "" -> 1; Utf8View LargeUtf8 i32 Int32 Int32Array); + test_strpos!("", "a" -> 0; Utf8View LargeUtf8 i32 Int32 Int32Array); } } diff --git a/datafusion/sqllogictest/test_files/string_view.slt b/datafusion/sqllogictest/test_files/string_view.slt index 0b441bcbeb8fe..4b4eba0522e40 100644 --- a/datafusion/sqllogictest/test_files/string_view.slt +++ b/datafusion/sqllogictest/test_files/string_view.slt @@ -1066,9 +1066,8 @@ EXPLAIN SELECT FROM test; ---- logical_plan -01)Projection: strpos(__common_expr_1, Utf8("f")) AS c, strpos(__common_expr_1, CAST(test.column2_utf8view AS Utf8)) AS c2 -02)--Projection: CAST(test.column1_utf8view AS Utf8) AS __common_expr_1, test.column2_utf8view -03)----TableScan: test projection=[column1_utf8view, column2_utf8view] +01)Projection: strpos(test.column1_utf8view, Utf8("f")) AS c, strpos(test.column1_utf8view, test.column2_utf8view) AS c2 +02)--TableScan: test projection=[column1_utf8view, column2_utf8view] ## Ensure no casts for SUBSTR ## TODO file ticket From b2ac83ff821b8434cfcc9b3391d7039120c7b259 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 21 Aug 2024 22:36:29 +0300 Subject: [PATCH 09/10] Update itertools requirement from 0.12 to 0.13 (#10556) * Update itertools requirement from 0.12 to 0.13 Updates the requirements on [itertools](https://github.com/rust-itertools/itertools) to permit the latest version. - [Changelog](https://github.com/rust-itertools/itertools/blob/master/CHANGELOG.md) - [Commits](https://github.com/rust-itertools/itertools/compare/v0.12.0...v0.13.0) --- updated-dependencies: - dependency-name: itertools dependency-type: direct:production ... Signed-off-by: dependabot[bot] * Update Cargo.lock * Avoid deprecated API * nested-functions: workspace version of itertools --------- Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Andrew Lamb Co-authored-by: Eduard Karacharov --- Cargo.toml | 2 +- datafusion-cli/Cargo.lock | 25 ++++++------------- .../datasource/physical_plan/file_groups.rs | 2 +- datafusion/functions-nested/Cargo.toml | 2 +- 4 files changed, 11 insertions(+), 20 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index d82443f5d1c8d..124747999041f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -120,7 +120,7 @@ futures = "0.3" half = { version = "2.2.1", default-features = false } hashbrown = { version = "0.14.5", features = ["raw"] } indexmap = "2.0.0" -itertools = "0.12" +itertools = "0.13" log = "^0.4" num_cpus = "1.13.0" object_store = { version = "0.10.2", default-features = false } diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index a164b74c55a5e..e35eb3906b9a2 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -1216,7 +1216,7 @@ dependencies = [ "half", "hashbrown", "indexmap", - "itertools 0.12.1", + "itertools", "log", "num-traits", "num_cpus", @@ -1367,7 +1367,7 @@ dependencies = [ "datafusion-expr", "hashbrown", "hex", - "itertools 0.12.1", + "itertools", "log", "md-5", "rand", @@ -1422,7 +1422,7 @@ dependencies = [ "datafusion-expr", "datafusion-functions", "datafusion-functions-aggregate", - "itertools 0.12.1", + "itertools", "log", "paste", "rand", @@ -1450,7 +1450,7 @@ dependencies = [ "datafusion-physical-expr", "hashbrown", "indexmap", - "itertools 0.12.1", + "itertools", "log", "paste", "regex-syntax", @@ -1479,7 +1479,7 @@ dependencies = [ "hashbrown", "hex", "indexmap", - "itertools 0.12.1", + "itertools", "log", "paste", "petgraph", @@ -1520,7 +1520,7 @@ dependencies = [ "datafusion-execution", "datafusion-physical-expr", "datafusion-physical-plan", - "itertools 0.12.1", + "itertools", ] [[package]] @@ -1548,7 +1548,7 @@ dependencies = [ "half", "hashbrown", "indexmap", - "itertools 0.12.1", + "itertools", "log", "once_cell", "parking_lot", @@ -2245,15 +2245,6 @@ version = "1.70.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7943c866cc5cd64cbc25b2e01621d07fa8eb2a1a23160ee81ce38704e97b8ecf" -[[package]] -name = "itertools" -version = "0.12.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ba291022dbbd398a455acf126c1e341954079855bc60dfdda641363bd6922569" -dependencies = [ - "either", -] - [[package]] name = "itertools" version = "0.13.0" @@ -2645,7 +2636,7 @@ dependencies = [ "futures", "humantime", "hyper 1.4.1", - "itertools 0.13.0", + "itertools", "md-5", "parking_lot", "percent-encoding", diff --git a/datafusion/core/src/datasource/physical_plan/file_groups.rs b/datafusion/core/src/datasource/physical_plan/file_groups.rs index 6456bd5c72766..fb2cd4ad06ec9 100644 --- a/datafusion/core/src/datasource/physical_plan/file_groups.rs +++ b/datafusion/core/src/datasource/physical_plan/file_groups.rs @@ -256,7 +256,7 @@ impl FileGroupPartitioner { }, ) .flatten() - .group_by(|(partition_idx, _)| *partition_idx) + .chunk_by(|(partition_idx, _)| *partition_idx) .into_iter() .map(|(_, group)| group.map(|(_, vals)| vals).collect_vec()) .collect_vec(); diff --git a/datafusion/functions-nested/Cargo.toml b/datafusion/functions-nested/Cargo.toml index 6a1973ecfed17..5e1a15233cb52 100644 --- a/datafusion/functions-nested/Cargo.toml +++ b/datafusion/functions-nested/Cargo.toml @@ -50,7 +50,7 @@ datafusion-execution = { workspace = true } datafusion-expr = { workspace = true } datafusion-functions = { workspace = true } datafusion-functions-aggregate = { workspace = true } -itertools = { version = "0.12", features = ["use_std"] } +itertools = { workspace = true, features = ["use_std"] } log = { workspace = true } paste = "1.0.14" rand = "0.8.5" From a50aeefcbfc84d491495887d57fa8ebc0db57ff2 Mon Sep 17 00:00:00 2001 From: Jax Liu Date: Thu, 22 Aug 2024 13:06:39 +0800 Subject: [PATCH 10/10] Fix wildcard expansion for `HAVING` clause (#12046) * fix the wildcard expand for filter plan * expand the wildcard for the error message * add the tests * fix recompute_schema * fix clippy * cargo fmt * change the check for having clause * rename the function and moving the tests * fix check * expand the schema for aggregate plan * reduce the time to expand wildcard * clean the testing table after tested * fmt and address review * stop expand wildcard and add more check for group-by and selects * simplify the having check --- datafusion/expr/src/logical_plan/builder.rs | 8 ++ datafusion/expr/src/logical_plan/plan.rs | 26 +++++- datafusion/expr/src/logical_plan/tree_node.rs | 28 ++++-- datafusion/expr/src/utils.rs | 9 ++ .../src/analyzer/expand_wildcard_rule.rs | 5 +- datafusion/sql/src/select.rs | 2 +- datafusion/sql/src/utils.rs | 3 +- .../sqllogictest/test_files/aggregate.slt | 91 +++++++++++++++++++ 8 files changed, 157 insertions(+), 15 deletions(-) diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index a96caa03d6110..d8e5d8bbdc0c9 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -385,6 +385,14 @@ impl LogicalPlanBuilder { .map(Self::from) } + /// Apply a filter which is used for a having clause + pub fn having(self, expr: impl Into) -> Result { + let expr = normalize_col(expr.into(), &self.plan)?; + Filter::try_new_with_having(expr, Arc::new(self.plan)) + .map(LogicalPlan::Filter) + .map(Self::from) + } + /// Make a builder for a prepare logical plan from the builder's plan pub fn prepare(self, name: String, data_types: Vec) -> Result { Ok(Self::from(LogicalPlan::Prepare(Prepare { diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index f93b7c0fedd09..ca7d04b9b03ec 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -643,9 +643,12 @@ impl LogicalPlan { // todo it isn't clear why the schema is not recomputed here Ok(LogicalPlan::Values(Values { schema, values })) } - LogicalPlan::Filter(Filter { predicate, input }) => { - Filter::try_new(predicate, input).map(LogicalPlan::Filter) - } + LogicalPlan::Filter(Filter { + predicate, + input, + having, + }) => Filter::try_new_internal(predicate, input, having) + .map(LogicalPlan::Filter), LogicalPlan::Repartition(_) => Ok(self), LogicalPlan::Window(Window { input, @@ -2080,6 +2083,8 @@ pub struct Filter { pub predicate: Expr, /// The incoming logical plan pub input: Arc, + /// The flag to indicate if the filter is a having clause + pub having: bool, } impl Filter { @@ -2088,6 +2093,20 @@ impl Filter { /// Notes: as Aliases have no effect on the output of a filter operator, /// they are removed from the predicate expression. pub fn try_new(predicate: Expr, input: Arc) -> Result { + Self::try_new_internal(predicate, input, false) + } + + /// Create a new filter operator for a having clause. + /// This is similar to a filter, but its having flag is set to true. + pub fn try_new_with_having(predicate: Expr, input: Arc) -> Result { + Self::try_new_internal(predicate, input, true) + } + + fn try_new_internal( + predicate: Expr, + input: Arc, + having: bool, + ) -> Result { // Filter predicates must return a boolean value so we try and validate that here. // Note that it is not always possible to resolve the predicate expression during plan // construction (such as with correlated subqueries) so we make a best effort here and @@ -2104,6 +2123,7 @@ impl Filter { Ok(Self { predicate: predicate.unalias_nested().data, input, + having, }) } diff --git a/datafusion/expr/src/logical_plan/tree_node.rs b/datafusion/expr/src/logical_plan/tree_node.rs index dbe43128fd384..539cb1cf5fb22 100644 --- a/datafusion/expr/src/logical_plan/tree_node.rs +++ b/datafusion/expr/src/logical_plan/tree_node.rs @@ -87,8 +87,17 @@ impl TreeNode for LogicalPlan { schema, }) }), - LogicalPlan::Filter(Filter { predicate, input }) => rewrite_arc(input, f)? - .update_data(|input| LogicalPlan::Filter(Filter { predicate, input })), + LogicalPlan::Filter(Filter { + predicate, + input, + having, + }) => rewrite_arc(input, f)?.update_data(|input| { + LogicalPlan::Filter(Filter { + predicate, + input, + having, + }) + }), LogicalPlan::Repartition(Repartition { input, partitioning_scheme, @@ -561,10 +570,17 @@ impl LogicalPlan { value.into_iter().map_until_stop_and_collect(&mut f) })? .update_data(|values| LogicalPlan::Values(Values { schema, values })), - LogicalPlan::Filter(Filter { predicate, input }) => f(predicate)? - .update_data(|predicate| { - LogicalPlan::Filter(Filter { predicate, input }) - }), + LogicalPlan::Filter(Filter { + predicate, + input, + having, + }) => f(predicate)?.update_data(|predicate| { + LogicalPlan::Filter(Filter { + predicate, + input, + having, + }) + }), LogicalPlan::Repartition(Repartition { input, partitioning_scheme, diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index 5f5c468fa2f59..11a244a944f81 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -804,6 +804,15 @@ pub fn find_base_plan(input: &LogicalPlan) -> &LogicalPlan { match input { LogicalPlan::Window(window) => find_base_plan(&window.input), LogicalPlan::Aggregate(agg) => find_base_plan(&agg.input), + LogicalPlan::Filter(filter) => { + if filter.having { + // If a filter is used for a having clause, its input plan is an aggregation. + // We should expand the wildcard expression based on the aggregation's input plan. + find_base_plan(&filter.input) + } else { + input + } + } _ => input, } } diff --git a/datafusion/optimizer/src/analyzer/expand_wildcard_rule.rs b/datafusion/optimizer/src/analyzer/expand_wildcard_rule.rs index 53ba3042f522e..dd422f7aab954 100644 --- a/datafusion/optimizer/src/analyzer/expand_wildcard_rule.rs +++ b/datafusion/optimizer/src/analyzer/expand_wildcard_rule.rs @@ -160,14 +160,13 @@ fn replace_columns( mod tests { use arrow::datatypes::{DataType, Field, Schema}; + use crate::test::{assert_analyzed_plan_eq_display_indent, test_table_scan}; + use crate::Analyzer; use datafusion_common::{JoinType, TableReference}; use datafusion_expr::{ col, in_subquery, qualified_wildcard, table_scan, wildcard, LogicalPlanBuilder, }; - use crate::test::{assert_analyzed_plan_eq_display_indent, test_table_scan}; - use crate::Analyzer; - use super::*; fn assert_plan_eq(plan: LogicalPlan, expected: &str) -> Result<()> { diff --git a/datafusion/sql/src/select.rs b/datafusion/sql/src/select.rs index 4e0ce33f1334d..45fda094557b0 100644 --- a/datafusion/sql/src/select.rs +++ b/datafusion/sql/src/select.rs @@ -215,7 +215,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let plan = if let Some(having_expr_post_aggr) = having_expr_post_aggr { LogicalPlanBuilder::from(plan) - .filter(having_expr_post_aggr)? + .having(having_expr_post_aggr)? .build()? } else { plan diff --git a/datafusion/sql/src/utils.rs b/datafusion/sql/src/utils.rs index af161bba45c14..c32acecaae5fd 100644 --- a/datafusion/sql/src/utils.rs +++ b/datafusion/sql/src/utils.rs @@ -17,8 +17,6 @@ //! SQL Utility Functions -use std::collections::HashMap; - use arrow_schema::{ DataType, DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION, DECIMAL_DEFAULT_SCALE, }; @@ -33,6 +31,7 @@ use datafusion_expr::expr::{Alias, GroupingSet, Unnest, WindowFunction}; use datafusion_expr::utils::{expr_as_column_expr, find_column_exprs}; use datafusion_expr::{expr_vec_fmt, Expr, ExprSchemable, LogicalPlan}; use sqlparser::ast::{Ident, Value}; +use std::collections::HashMap; /// Make a best-effort attempt at resolving all columns in the expression tree pub(crate) fn resolve_columns(expr: &Expr, plan: &LogicalPlan) -> Result { diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index d39bf6538ecbc..09fc397bf915c 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -5655,6 +5655,97 @@ select count(null), min(null), max(null), bit_and(NULL), bit_or(NULL), bit_xor(N ---- 0 NULL NULL NULL NULL NULL NULL NULL +statement ok +create table having_test(v1 int, v2 int) + +statement ok +create table join_table(v1 int, v2 int) + +statement ok +insert into having_test values (1, 2), (2, 3), (3, 4) + +statement ok +insert into join_table values (1, 2), (2, 3), (3, 4) + + +query II +select * from having_test group by v1, v2 having max(v1) = 3 +---- +3 4 + +query TT +EXPLAIN select * from having_test group by v1, v2 having max(v1) = 3 +---- +logical_plan +01)Projection: having_test.v1, having_test.v2 +02)--Filter: max(having_test.v1) = Int32(3) +03)----Aggregate: groupBy=[[having_test.v1, having_test.v2]], aggr=[[max(having_test.v1)]] +04)------TableScan: having_test projection=[v1, v2] +physical_plan +01)ProjectionExec: expr=[v1@0 as v1, v2@1 as v2] +02)--CoalesceBatchesExec: target_batch_size=8192 +03)----FilterExec: max(having_test.v1)@2 = 3 +04)------AggregateExec: mode=FinalPartitioned, gby=[v1@0 as v1, v2@1 as v2], aggr=[max(having_test.v1)] +05)--------CoalesceBatchesExec: target_batch_size=8192 +06)----------RepartitionExec: partitioning=Hash([v1@0, v2@1], 4), input_partitions=4 +07)------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +08)--------------AggregateExec: mode=Partial, gby=[v1@0 as v1, v2@1 as v2], aggr=[max(having_test.v1)] +09)----------------MemoryExec: partitions=1, partition_sizes=[1] + + +query error +select * from having_test having max(v1) = 3 + +query I +select max(v1) from having_test having max(v1) = 3 +---- +3 + +query I +select max(v1), * exclude (v1, v2) from having_test having max(v1) = 3 +---- +3 + +# because v1, v2 is not in the group by clause, the sql is invalid +query III +select max(v1), * replace ('v1' as v3) from having_test group by v1, v2 having max(v1) = 3 +---- +3 3 4 + +query III +select max(v1), t.* from having_test t group by v1, v2 having max(v1) = 3 +---- +3 3 4 + +# j.* should also be included in the group-by clause +query error +select max(t.v1), j.* from having_test t join join_table j on t.v1 = j.v1 group by t.v1, t.v2 having max(t.v1) = 3 + +query III +select max(t.v1), j.* from having_test t join join_table j on t.v1 = j.v1 group by j.v1, j.v2 having max(t.v1) = 3 +---- +3 3 4 + +# If the select items only contain scalar expressions, the having clause is valid. +query P +select now() from having_test having max(v1) = 4 +---- + +# If the select items only contain scalar expressions, the having clause is valid. +query I +select 0 from having_test having max(v1) = 4 +---- + +# v2 should also be included in group-by clause +query error +select * from having_test group by v1 having max(v1) = 3 + +statement ok +drop table having_test + +statement ok +drop table join_table + # test min/max Float16 without group expression query RRTT WITH data AS (