From ed374678d1beac56d39e673eb0edb78f34458f68 Mon Sep 17 00:00:00 2001 From: Yang Jiang Date: Thu, 11 Apr 2024 00:48:49 +0800 Subject: [PATCH 1/5] Prune columns are all null in ParquetExec by row_counts , handle IS NOT NULL (#9989) * Prune columns are all null in ParquetExec by row_counts in pruning statistics * fix clippy * Update datafusion/core/tests/parquet/row_group_pruning.rs Co-authored-by: Ruihang Xia * fix comment and support isNotNUll * add test * fix conflict --------- Co-authored-by: Ruihang Xia --- .../physical_plan/parquet/row_groups.rs | 10 +++- .../core/src/physical_optimizer/pruning.rs | 38 +++++++++--- datafusion/core/tests/parquet/mod.rs | 30 ++++++++++ .../core/tests/parquet/row_group_pruning.rs | 60 +++++++++++++++++++ 4 files changed, 128 insertions(+), 10 deletions(-) diff --git a/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs b/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs index 6600dd07d7fd..2b9665954842 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs @@ -338,8 +338,10 @@ impl<'a> PruningStatistics for RowGroupPruningStatistics<'a> { scalar.to_array().ok() } - fn row_counts(&self, _column: &Column) -> Option { - None + fn row_counts(&self, column: &Column) -> Option { + let (c, _) = self.column(&column.name)?; + let scalar = ScalarValue::UInt64(Some(c.num_values() as u64)); + scalar.to_array().ok() } fn contained( @@ -1022,15 +1024,17 @@ mod tests { column_statistics: Vec, ) -> RowGroupMetaData { let mut columns = vec![]; + let number_row = 1000; for (i, s) in column_statistics.iter().enumerate() { let column = ColumnChunkMetaData::builder(schema_descr.column(i)) .set_statistics(s.clone()) + .set_num_values(number_row) .build() .unwrap(); columns.push(column); } RowGroupMetaData::builder(schema_descr.clone()) - .set_num_rows(1000) + .set_num_rows(number_row) .set_total_byte_size(2000) .set_column_metadata(columns) .build() diff --git a/datafusion/core/src/physical_optimizer/pruning.rs b/datafusion/core/src/physical_optimizer/pruning.rs index dc7e0529decb..ebb811408fb3 100644 --- a/datafusion/core/src/physical_optimizer/pruning.rs +++ b/datafusion/core/src/physical_optimizer/pruning.rs @@ -335,6 +335,7 @@ pub trait PruningStatistics { /// `x < 5` | `CASE WHEN x_null_count = x_row_count THEN false ELSE x_max < 5 END` /// `x = 5 AND y = 10` | `CASE WHEN x_null_count = x_row_count THEN false ELSE x_min <= 5 AND 5 <= x_max END AND CASE WHEN y_null_count = y_row_count THEN false ELSE y_min <= 10 AND 10 <= y_max END` /// `x IS NULL` | `x_null_count > 0` +/// `x IS NOT NULL` | `x_null_count = 0` /// `CAST(x as int) = 5` | `CASE WHEN x_null_count = x_row_count THEN false ELSE CAST(x_min as int) <= 5 AND 5 <= CAST(x_max as int) END` /// /// ## Predicate Evaluation @@ -1239,10 +1240,15 @@ fn build_single_column_expr( /// returns a pruning expression in terms of IsNull that will evaluate to true /// if the column may contain null, and false if definitely does not /// contain null. +/// If set `with_not` to true: which means is not null +/// Given an expression reference to `expr`, if `expr` is a column expression, +/// returns a pruning expression in terms of IsNotNull that will evaluate to true +/// if the column not contain any null, and false if definitely contain null. fn build_is_null_column_expr( expr: &Arc, schema: &Schema, required_columns: &mut RequiredColumns, + with_not: bool, ) -> Option> { if let Some(col) = expr.as_any().downcast_ref::() { let field = schema.field_with_name(col.name()).ok()?; @@ -1251,12 +1257,21 @@ fn build_is_null_column_expr( required_columns .null_count_column_expr(col, expr, null_count_field) .map(|null_count_column_expr| { - // IsNull(column) => null_count > 0 - Arc::new(phys_expr::BinaryExpr::new( - null_count_column_expr, - Operator::Gt, - Arc::new(phys_expr::Literal::new(ScalarValue::UInt64(Some(0)))), - )) as _ + if with_not { + // IsNotNull(column) => null_count = 0 + Arc::new(phys_expr::BinaryExpr::new( + null_count_column_expr, + Operator::Eq, + Arc::new(phys_expr::Literal::new(ScalarValue::UInt64(Some(0)))), + )) as _ + } else { + // IsNull(column) => null_count > 0 + Arc::new(phys_expr::BinaryExpr::new( + null_count_column_expr, + Operator::Gt, + Arc::new(phys_expr::Literal::new(ScalarValue::UInt64(Some(0)))), + )) as _ + } }) .ok() } else { @@ -1287,9 +1302,18 @@ fn build_predicate_expression( // predicate expression can only be a binary expression let expr_any = expr.as_any(); if let Some(is_null) = expr_any.downcast_ref::() { - return build_is_null_column_expr(is_null.arg(), schema, required_columns) + return build_is_null_column_expr(is_null.arg(), schema, required_columns, false) .unwrap_or(unhandled); } + if let Some(is_not_null) = expr_any.downcast_ref::() { + return build_is_null_column_expr( + is_not_null.arg(), + schema, + required_columns, + true, + ) + .unwrap_or(unhandled); + } if let Some(col) = expr_any.downcast_ref::() { return build_single_column_expr(col, schema, required_columns, false) .unwrap_or(unhandled); diff --git a/datafusion/core/tests/parquet/mod.rs b/datafusion/core/tests/parquet/mod.rs index b4415d638ada..f36afe1976b1 100644 --- a/datafusion/core/tests/parquet/mod.rs +++ b/datafusion/core/tests/parquet/mod.rs @@ -28,6 +28,7 @@ use arrow::{ record_batch::RecordBatch, util::pretty::pretty_format_batches, }; +use arrow_array::new_null_array; use chrono::{Datelike, Duration, TimeDelta}; use datafusion::{ datasource::{physical_plan::ParquetExec, provider_as_source, TableProvider}, @@ -75,6 +76,7 @@ enum Scenario { DecimalLargePrecisionBloomFilter, ByteArray, PeriodsInColumnNames, + WithNullValues, } enum Unit { @@ -630,6 +632,27 @@ fn make_names_batch(name: &str, service_name_values: Vec<&str>) -> RecordBatch { RecordBatch::try_new(schema, vec![Arc::new(name), Arc::new(service_name)]).unwrap() } +/// Return record batch with i8, i16, i32, and i64 sequences with all Null values +fn make_all_null_values() -> RecordBatch { + let schema = Arc::new(Schema::new(vec![ + Field::new("i8", DataType::Int8, true), + Field::new("i16", DataType::Int16, true), + Field::new("i32", DataType::Int32, true), + Field::new("i64", DataType::Int64, true), + ])); + + RecordBatch::try_new( + schema, + vec![ + new_null_array(&DataType::Int8, 5), + new_null_array(&DataType::Int16, 5), + new_null_array(&DataType::Int32, 5), + new_null_array(&DataType::Int64, 5), + ], + ) + .unwrap() +} + fn create_data_batch(scenario: Scenario) -> Vec { match scenario { Scenario::Timestamps => { @@ -799,6 +822,13 @@ fn create_data_batch(scenario: Scenario) -> Vec { ), ] } + Scenario::WithNullValues => { + vec![ + make_all_null_values(), + make_int_batches(1, 6), + make_all_null_values(), + ] + } } } diff --git a/datafusion/core/tests/parquet/row_group_pruning.rs b/datafusion/core/tests/parquet/row_group_pruning.rs index 8fc7936552af..29bf1ef0a8d4 100644 --- a/datafusion/core/tests/parquet/row_group_pruning.rs +++ b/datafusion/core/tests/parquet/row_group_pruning.rs @@ -1262,3 +1262,63 @@ async fn prune_periods_in_column_names() { .test_row_group_prune() .await; } + +#[tokio::test] +async fn test_row_group_with_null_values() { + // Three row groups: + // 1. all Null values + // 2. values from 1 to 5 + // 3. all Null values + + // After pruning, only row group 2 should be selected + RowGroupPruningTest::new() + .with_scenario(Scenario::WithNullValues) + .with_query("SELECT * FROM t WHERE \"i8\" <= 5") + .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(1)) + .with_pruned_by_stats(Some(2)) + .with_expected_rows(5) + .with_matched_by_bloom_filter(Some(0)) + .with_pruned_by_bloom_filter(Some(0)) + .test_row_group_prune() + .await; + + // After pruning, only row group 1,3 should be selected + RowGroupPruningTest::new() + .with_scenario(Scenario::WithNullValues) + .with_query("SELECT * FROM t WHERE \"i8\" is Null") + .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(2)) + .with_pruned_by_stats(Some(1)) + .with_expected_rows(10) + .with_matched_by_bloom_filter(Some(0)) + .with_pruned_by_bloom_filter(Some(0)) + .test_row_group_prune() + .await; + + // After pruning, only row group 2should be selected + RowGroupPruningTest::new() + .with_scenario(Scenario::WithNullValues) + .with_query("SELECT * FROM t WHERE \"i16\" is Not Null") + .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(1)) + .with_pruned_by_stats(Some(2)) + .with_expected_rows(5) + .with_matched_by_bloom_filter(Some(0)) + .with_pruned_by_bloom_filter(Some(0)) + .test_row_group_prune() + .await; + + // All row groups will be pruned + RowGroupPruningTest::new() + .with_scenario(Scenario::WithNullValues) + .with_query("SELECT * FROM t WHERE \"i32\" > 7") + .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(0)) + .with_pruned_by_stats(Some(3)) + .with_expected_rows(0) + .with_matched_by_bloom_filter(Some(0)) + .with_pruned_by_bloom_filter(Some(0)) + .test_row_group_prune() + .await; +} From fdb2d5761c64273ac7326b4a86b052b9bb9c08c7 Mon Sep 17 00:00:00 2001 From: JasonLi Date: Thu, 11 Apr 2024 00:56:09 +0800 Subject: [PATCH 2/5] Improve the performance of ltrim/rtrim/btrim (#10006) * optimize trim function * fix: the second arg is NULL --- datafusion/functions/Cargo.toml | 5 +++ datafusion/functions/benches/ltrim.rs | 50 +++++++++++++++++++++++ datafusion/functions/src/string/btrim.rs | 11 ++++- datafusion/functions/src/string/common.rs | 17 +++++++- datafusion/functions/src/string/ltrim.rs | 11 ++++- datafusion/functions/src/string/rtrim.rs | 11 ++++- 6 files changed, 98 insertions(+), 7 deletions(-) create mode 100644 datafusion/functions/benches/ltrim.rs diff --git a/datafusion/functions/Cargo.toml b/datafusion/functions/Cargo.toml index a6847f3327c0..66f8b3010fd2 100644 --- a/datafusion/functions/Cargo.toml +++ b/datafusion/functions/Cargo.toml @@ -113,3 +113,8 @@ required-features = ["datetime_expressions"] harness = false name = "substr_index" required-features = ["unicode_expressions"] + +[[bench]] +harness = false +name = "ltrim" +required-features = ["string_expressions"] diff --git a/datafusion/functions/benches/ltrim.rs b/datafusion/functions/benches/ltrim.rs new file mode 100644 index 000000000000..01acb9de3381 --- /dev/null +++ b/datafusion/functions/benches/ltrim.rs @@ -0,0 +1,50 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +extern crate criterion; + +use arrow::array::{ArrayRef, StringArray}; +use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use datafusion_common::ScalarValue; +use datafusion_expr::ColumnarValue; +use datafusion_functions::string; +use std::sync::Arc; + +fn create_args(size: usize, characters: &str) -> Vec { + let iter = + std::iter::repeat(format!("{}datafusion{}", characters, characters)).take(size); + let array = Arc::new(StringArray::from_iter_values(iter)) as ArrayRef; + vec![ + ColumnarValue::Array(array), + ColumnarValue::Scalar(ScalarValue::Utf8(Some(characters.to_string()))), + ] +} + +fn criterion_benchmark(c: &mut Criterion) { + let ltrim = string::ltrim(); + for char in ["\"", "Header:"] { + for size in [1024, 4096, 8192] { + let args = create_args(size, char); + c.bench_function(&format!("ltrim {}: {}", char, size), |b| { + b.iter(|| black_box(ltrim.invoke(&args))) + }); + } + } +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/functions/src/string/btrim.rs b/datafusion/functions/src/string/btrim.rs index b0a85eab6d83..971f7bbd4d92 100644 --- a/datafusion/functions/src/string/btrim.rs +++ b/datafusion/functions/src/string/btrim.rs @@ -24,6 +24,7 @@ use datafusion_common::{exec_err, Result}; use datafusion_expr::TypeSignature::*; use datafusion_expr::{ColumnarValue, Volatility}; use datafusion_expr::{ScalarUDFImpl, Signature}; +use datafusion_physical_expr::functions::Hint; use crate::string::common::*; use crate::utils::{make_scalar_function, utf8_to_str_type}; @@ -72,8 +73,14 @@ impl ScalarUDFImpl for BTrimFunc { fn invoke(&self, args: &[ColumnarValue]) -> Result { match args[0].data_type() { - DataType::Utf8 => make_scalar_function(btrim::, vec![])(args), - DataType::LargeUtf8 => make_scalar_function(btrim::, vec![])(args), + DataType::Utf8 => make_scalar_function( + btrim::, + vec![Hint::Pad, Hint::AcceptsSingular], + )(args), + DataType::LargeUtf8 => make_scalar_function( + btrim::, + vec![Hint::Pad, Hint::AcceptsSingular], + )(args), other => exec_err!("Unsupported data type {other:?} for function btrim"), } } diff --git a/datafusion/functions/src/string/common.rs b/datafusion/functions/src/string/common.rs index 276aad121df2..2b554db3979f 100644 --- a/datafusion/functions/src/string/common.rs +++ b/datafusion/functions/src/string/common.rs @@ -18,7 +18,9 @@ use std::fmt::{Display, Formatter}; use std::sync::Arc; -use arrow::array::{Array, ArrayRef, GenericStringArray, OffsetSizeTrait}; +use arrow::array::{ + new_null_array, Array, ArrayRef, GenericStringArray, OffsetSizeTrait, +}; use arrow::datatypes::DataType; use datafusion_common::cast::as_generic_string_array; @@ -78,6 +80,19 @@ pub(crate) fn general_trim( 2 => { let characters_array = as_generic_string_array::(&args[1])?; + if characters_array.len() == 1 { + if characters_array.is_null(0) { + return Ok(new_null_array(args[0].data_type(), args[0].len())); + } + + let characters = characters_array.value(0); + let result = string_array + .iter() + .map(|item| item.map(|string| func(string, characters))) + .collect::>(); + return Ok(Arc::new(result) as ArrayRef); + } + let result = string_array .iter() .zip(characters_array.iter()) diff --git a/datafusion/functions/src/string/ltrim.rs b/datafusion/functions/src/string/ltrim.rs index ad86259d0d7e..1a6a9d497f66 100644 --- a/datafusion/functions/src/string/ltrim.rs +++ b/datafusion/functions/src/string/ltrim.rs @@ -24,6 +24,7 @@ use datafusion_common::{exec_err, Result}; use datafusion_expr::TypeSignature::*; use datafusion_expr::{ColumnarValue, Volatility}; use datafusion_expr::{ScalarUDFImpl, Signature}; +use datafusion_physical_expr::functions::Hint; use crate::string::common::*; use crate::utils::{make_scalar_function, utf8_to_str_type}; @@ -70,8 +71,14 @@ impl ScalarUDFImpl for LtrimFunc { fn invoke(&self, args: &[ColumnarValue]) -> Result { match args[0].data_type() { - DataType::Utf8 => make_scalar_function(ltrim::, vec![])(args), - DataType::LargeUtf8 => make_scalar_function(ltrim::, vec![])(args), + DataType::Utf8 => make_scalar_function( + ltrim::, + vec![Hint::Pad, Hint::AcceptsSingular], + )(args), + DataType::LargeUtf8 => make_scalar_function( + ltrim::, + vec![Hint::Pad, Hint::AcceptsSingular], + )(args), other => exec_err!("Unsupported data type {other:?} for function ltrim"), } } diff --git a/datafusion/functions/src/string/rtrim.rs b/datafusion/functions/src/string/rtrim.rs index 607e647b2615..e6e93e38c966 100644 --- a/datafusion/functions/src/string/rtrim.rs +++ b/datafusion/functions/src/string/rtrim.rs @@ -24,6 +24,7 @@ use datafusion_common::{exec_err, Result}; use datafusion_expr::TypeSignature::*; use datafusion_expr::{ColumnarValue, Volatility}; use datafusion_expr::{ScalarUDFImpl, Signature}; +use datafusion_physical_expr::functions::Hint; use crate::string::common::*; use crate::utils::{make_scalar_function, utf8_to_str_type}; @@ -70,8 +71,14 @@ impl ScalarUDFImpl for RtrimFunc { fn invoke(&self, args: &[ColumnarValue]) -> Result { match args[0].data_type() { - DataType::Utf8 => make_scalar_function(rtrim::, vec![])(args), - DataType::LargeUtf8 => make_scalar_function(rtrim::, vec![])(args), + DataType::Utf8 => make_scalar_function( + rtrim::, + vec![Hint::Pad, Hint::AcceptsSingular], + )(args), + DataType::LargeUtf8 => make_scalar_function( + rtrim::, + vec![Hint::Pad, Hint::AcceptsSingular], + )(args), other => exec_err!("Unsupported data type {other:?} for function rtrim"), } } From a13c37d1d0e3cd0a1383d1685e1efdc015bb4bc8 Mon Sep 17 00:00:00 2001 From: Marco Neumann Date: Wed, 10 Apr 2024 19:34:53 +0200 Subject: [PATCH 3/5] fix: `RepartitionExec` metrics (#10025) `RepartitionExec` is somewhat special. While most execs operate on "input partition = output partition", `RepartitionExec` drives all of its work using input-bound tasks. The metrics "fetch time" and "repartition time" therefore have to be accounted for the input partition, not for the output partition. The only metric that has an input & output partition label is the "send time". Fixes #10015. --- .../physical-plan/src/repartition/mod.rs | 37 ++++++++++--------- 1 file changed, 20 insertions(+), 17 deletions(-) diff --git a/datafusion/physical-plan/src/repartition/mod.rs b/datafusion/physical-plan/src/repartition/mod.rs index 2ed5da7ced20..59c71dbf89b4 100644 --- a/datafusion/physical-plan/src/repartition/mod.rs +++ b/datafusion/physical-plan/src/repartition/mod.rs @@ -130,8 +130,7 @@ impl RepartitionExecState { }) .collect(); - // TODO: metric input-output mapping is broken - let r_metrics = RepartitionMetrics::new(i, 0, &metrics); + let r_metrics = RepartitionMetrics::new(i, num_output_partitions, &metrics); let input_task = SpawnedTask::spawn(RepartitionExec::pull_from_input( input.clone(), @@ -411,32 +410,36 @@ struct RepartitionMetrics { fetch_time: metrics::Time, /// Time in nanos to perform repartitioning repartition_time: metrics::Time, - /// Time in nanos for sending resulting batches to channels - send_time: metrics::Time, + /// Time in nanos for sending resulting batches to channels. + /// + /// One metric per output partition. + send_time: Vec, } impl RepartitionMetrics { pub fn new( - output_partition: usize, input_partition: usize, + num_output_partitions: usize, metrics: &ExecutionPlanMetricsSet, ) -> Self { - let label = metrics::Label::new("inputPartition", input_partition.to_string()); - // Time in nanos to execute child operator and fetch batches - let fetch_time = MetricBuilder::new(metrics) - .with_label(label.clone()) - .subset_time("fetch_time", output_partition); + let fetch_time = + MetricBuilder::new(metrics).subset_time("fetch_time", input_partition); // Time in nanos to perform repartitioning - let repart_time = MetricBuilder::new(metrics) - .with_label(label.clone()) - .subset_time("repart_time", output_partition); + let repart_time = + MetricBuilder::new(metrics).subset_time("repart_time", input_partition); // Time in nanos for sending resulting batches to channels - let send_time = MetricBuilder::new(metrics) - .with_label(label) - .subset_time("send_time", output_partition); + let send_time = (0..num_output_partitions) + .map(|output_partition| { + let label = + metrics::Label::new("outputPartition", output_partition.to_string()); + MetricBuilder::new(metrics) + .with_label(label) + .subset_time("send_time", input_partition) + }) + .collect(); Self { fetch_time, @@ -786,7 +789,7 @@ impl RepartitionExec { let (partition, batch) = res?; let size = batch.get_array_memory_size(); - let timer = metrics.send_time.timer(); + let timer = metrics.send_time[partition].timer(); // if there is still a receiver, send to it if let Some((tx, reservation)) = output_channels.get_mut(&partition) { reservation.lock().try_grow(size)?; From 69595a48458715aadffc56974665ebbafee35bd7 Mon Sep 17 00:00:00 2001 From: JasonLi Date: Thu, 11 Apr 2024 04:10:27 +0800 Subject: [PATCH 4/5] modify emit() of TopK (#10030) --- datafusion/physical-plan/src/topk/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/physical-plan/src/topk/mod.rs b/datafusion/physical-plan/src/topk/mod.rs index 9120566273d3..6a77bfaf3ccd 100644 --- a/datafusion/physical-plan/src/topk/mod.rs +++ b/datafusion/physical-plan/src/topk/mod.rs @@ -208,7 +208,7 @@ impl TopK { // break into record batches as needed let mut batches = vec![]; loop { - if batch.num_rows() < batch_size { + if batch.num_rows() <= batch_size { batches.push(Ok(batch)); break; } else { From b9759b9810a05b7993c0357a5346197395cfd4cc Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Wed, 10 Apr 2024 17:31:41 -0400 Subject: [PATCH 5/5] Consolidate LogicalPlan tree node walking/rewriting code into one module (#10034) --- datafusion/expr/src/logical_plan/plan.rs | 515 +----------------- datafusion/expr/src/logical_plan/tree_node.rs | 511 ++++++++++++++++- 2 files changed, 515 insertions(+), 511 deletions(-) diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 7bad034a11ea..d16dfb140353 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -25,9 +25,7 @@ use std::sync::Arc; use super::dml::CopyTo; use super::DdlStatement; use crate::builder::change_redundant_column; -use crate::expr::{ - Alias, Exists, InSubquery, Placeholder, Sort as SortExpr, WindowFunction, -}; +use crate::expr::{Alias, Placeholder, Sort as SortExpr, WindowFunction}; use crate::expr_rewriter::{create_col_from_scalar_expr, normalize_cols}; use crate::logical_plan::display::{GraphvizVisitor, IndentVisitor}; use crate::logical_plan::extension::UserDefinedLogicalNode; @@ -44,19 +42,16 @@ use crate::{ use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion_common::tree_node::{ - Transformed, TransformedResult, TreeNode, TreeNodeIterator, TreeNodeRecursion, - TreeNodeRewriter, TreeNodeVisitor, + Transformed, TransformedResult, TreeNode, TreeNodeRecursion, }; use datafusion_common::{ - aggregate_functional_dependencies, internal_err, map_until_stop_and_collect, - plan_err, Column, Constraints, DFSchema, DFSchemaRef, DataFusionError, Dependency, - FunctionalDependence, FunctionalDependencies, ParamValues, Result, TableReference, - UnnestOptions, + aggregate_functional_dependencies, internal_err, plan_err, Column, Constraints, + DFSchema, DFSchemaRef, DataFusionError, Dependency, FunctionalDependence, + FunctionalDependencies, ParamValues, Result, TableReference, UnnestOptions, }; // backwards compatibility use crate::display::PgJsonVisitor; -use crate::tree_node::transform_option_vec; pub use datafusion_common::display::{PlanType, StringifiedPlan, ToStringifiedPlan}; pub use datafusion_common::{JoinConstraint, JoinType}; @@ -315,314 +310,6 @@ impl LogicalPlan { err } - /// Calls `f` on all expressions in the current `LogicalPlan` node. - /// - /// Note this does not include expressions in child `LogicalPlan` nodes. - pub fn apply_expressions Result>( - &self, - mut f: F, - ) -> Result { - match self { - LogicalPlan::Projection(Projection { expr, .. }) => { - expr.iter().apply_until_stop(f) - } - LogicalPlan::Values(Values { values, .. }) => values - .iter() - .apply_until_stop(|value| value.iter().apply_until_stop(&mut f)), - LogicalPlan::Filter(Filter { predicate, .. }) => f(predicate), - LogicalPlan::Repartition(Repartition { - partitioning_scheme, - .. - }) => match partitioning_scheme { - Partitioning::Hash(expr, _) | Partitioning::DistributeBy(expr) => { - expr.iter().apply_until_stop(f) - } - Partitioning::RoundRobinBatch(_) => Ok(TreeNodeRecursion::Continue), - }, - LogicalPlan::Window(Window { window_expr, .. }) => { - window_expr.iter().apply_until_stop(f) - } - LogicalPlan::Aggregate(Aggregate { - group_expr, - aggr_expr, - .. - }) => group_expr - .iter() - .chain(aggr_expr.iter()) - .apply_until_stop(f), - // There are two part of expression for join, equijoin(on) and non-equijoin(filter). - // 1. the first part is `on.len()` equijoin expressions, and the struct of each expr is `left-on = right-on`. - // 2. the second part is non-equijoin(filter). - LogicalPlan::Join(Join { on, filter, .. }) => { - on.iter() - // TODO: why we need to create an `Expr::eq`? Cloning `Expr` is costly... - // it not ideal to create an expr here to analyze them, but could cache it on the Join itself - .map(|(l, r)| Expr::eq(l.clone(), r.clone())) - .apply_until_stop(|e| f(&e))? - .visit_sibling(|| filter.iter().apply_until_stop(f)) - } - LogicalPlan::Sort(Sort { expr, .. }) => expr.iter().apply_until_stop(f), - LogicalPlan::Extension(extension) => { - // would be nice to avoid this copy -- maybe can - // update extension to just observer Exprs - extension.node.expressions().iter().apply_until_stop(f) - } - LogicalPlan::TableScan(TableScan { filters, .. }) => { - filters.iter().apply_until_stop(f) - } - LogicalPlan::Unnest(Unnest { column, .. }) => { - f(&Expr::Column(column.clone())) - } - LogicalPlan::Distinct(Distinct::On(DistinctOn { - on_expr, - select_expr, - sort_expr, - .. - })) => on_expr - .iter() - .chain(select_expr.iter()) - .chain(sort_expr.iter().flatten()) - .apply_until_stop(f), - // plans without expressions - LogicalPlan::EmptyRelation(_) - | LogicalPlan::RecursiveQuery(_) - | LogicalPlan::Subquery(_) - | LogicalPlan::SubqueryAlias(_) - | LogicalPlan::Limit(_) - | LogicalPlan::Statement(_) - | LogicalPlan::CrossJoin(_) - | LogicalPlan::Analyze(_) - | LogicalPlan::Explain(_) - | LogicalPlan::Union(_) - | LogicalPlan::Distinct(Distinct::All(_)) - | LogicalPlan::Dml(_) - | LogicalPlan::Ddl(_) - | LogicalPlan::Copy(_) - | LogicalPlan::DescribeTable(_) - | LogicalPlan::Prepare(_) => Ok(TreeNodeRecursion::Continue), - } - } - - /// Rewrites all expressions in the current `LogicalPlan` node using `f`. - /// - /// Returns the current node. - /// - /// Note this does not include expressions in child `LogicalPlan` nodes. - pub fn map_expressions Result>>( - self, - mut f: F, - ) -> Result> { - Ok(match self { - LogicalPlan::Projection(Projection { - expr, - input, - schema, - }) => expr - .into_iter() - .map_until_stop_and_collect(f)? - .update_data(|expr| { - LogicalPlan::Projection(Projection { - expr, - input, - schema, - }) - }), - LogicalPlan::Values(Values { schema, values }) => values - .into_iter() - .map_until_stop_and_collect(|value| { - value.into_iter().map_until_stop_and_collect(&mut f) - })? - .update_data(|values| LogicalPlan::Values(Values { schema, values })), - LogicalPlan::Filter(Filter { predicate, input }) => f(predicate)? - .update_data(|predicate| { - LogicalPlan::Filter(Filter { predicate, input }) - }), - LogicalPlan::Repartition(Repartition { - input, - partitioning_scheme, - }) => match partitioning_scheme { - Partitioning::Hash(expr, usize) => expr - .into_iter() - .map_until_stop_and_collect(f)? - .update_data(|expr| Partitioning::Hash(expr, usize)), - Partitioning::DistributeBy(expr) => expr - .into_iter() - .map_until_stop_and_collect(f)? - .update_data(Partitioning::DistributeBy), - Partitioning::RoundRobinBatch(_) => Transformed::no(partitioning_scheme), - } - .update_data(|partitioning_scheme| { - LogicalPlan::Repartition(Repartition { - input, - partitioning_scheme, - }) - }), - LogicalPlan::Window(Window { - input, - window_expr, - schema, - }) => window_expr - .into_iter() - .map_until_stop_and_collect(f)? - .update_data(|window_expr| { - LogicalPlan::Window(Window { - input, - window_expr, - schema, - }) - }), - LogicalPlan::Aggregate(Aggregate { - input, - group_expr, - aggr_expr, - schema, - }) => map_until_stop_and_collect!( - group_expr.into_iter().map_until_stop_and_collect(&mut f), - aggr_expr, - aggr_expr.into_iter().map_until_stop_and_collect(&mut f) - )? - .update_data(|(group_expr, aggr_expr)| { - LogicalPlan::Aggregate(Aggregate { - input, - group_expr, - aggr_expr, - schema, - }) - }), - - // There are two part of expression for join, equijoin(on) and non-equijoin(filter). - // 1. the first part is `on.len()` equijoin expressions, and the struct of each expr is `left-on = right-on`. - // 2. the second part is non-equijoin(filter). - LogicalPlan::Join(Join { - left, - right, - on, - filter, - join_type, - join_constraint, - schema, - null_equals_null, - }) => map_until_stop_and_collect!( - on.into_iter().map_until_stop_and_collect( - |on| map_until_stop_and_collect!(f(on.0), on.1, f(on.1)) - ), - filter, - filter.map_or(Ok::<_, DataFusionError>(Transformed::no(None)), |e| { - Ok(f(e)?.update_data(Some)) - }) - )? - .update_data(|(on, filter)| { - LogicalPlan::Join(Join { - left, - right, - on, - filter, - join_type, - join_constraint, - schema, - null_equals_null, - }) - }), - LogicalPlan::Sort(Sort { expr, input, fetch }) => expr - .into_iter() - .map_until_stop_and_collect(f)? - .update_data(|expr| LogicalPlan::Sort(Sort { expr, input, fetch })), - LogicalPlan::Extension(Extension { node }) => { - // would be nice to avoid this copy -- maybe can - // update extension to just observer Exprs - node.expressions() - .into_iter() - .map_until_stop_and_collect(f)? - .update_data(|exprs| { - LogicalPlan::Extension(Extension { - node: UserDefinedLogicalNode::from_template( - node.as_ref(), - exprs.as_slice(), - node.inputs() - .into_iter() - .cloned() - .collect::>() - .as_slice(), - ), - }) - }) - } - LogicalPlan::TableScan(TableScan { - table_name, - source, - projection, - projected_schema, - filters, - fetch, - }) => filters - .into_iter() - .map_until_stop_and_collect(f)? - .update_data(|filters| { - LogicalPlan::TableScan(TableScan { - table_name, - source, - projection, - projected_schema, - filters, - fetch, - }) - }), - LogicalPlan::Unnest(Unnest { - input, - column, - schema, - options, - }) => f(Expr::Column(column))?.map_data(|column| match column { - Expr::Column(column) => Ok(LogicalPlan::Unnest(Unnest { - input, - column, - schema, - options, - })), - _ => internal_err!("Transformation should return Column"), - })?, - LogicalPlan::Distinct(Distinct::On(DistinctOn { - on_expr, - select_expr, - sort_expr, - input, - schema, - })) => map_until_stop_and_collect!( - on_expr.into_iter().map_until_stop_and_collect(&mut f), - select_expr, - select_expr.into_iter().map_until_stop_and_collect(&mut f), - sort_expr, - transform_option_vec(sort_expr, &mut f) - )? - .update_data(|(on_expr, select_expr, sort_expr)| { - LogicalPlan::Distinct(Distinct::On(DistinctOn { - on_expr, - select_expr, - sort_expr, - input, - schema, - })) - }), - // plans without expressions - LogicalPlan::EmptyRelation(_) - | LogicalPlan::RecursiveQuery(_) - | LogicalPlan::Subquery(_) - | LogicalPlan::SubqueryAlias(_) - | LogicalPlan::Limit(_) - | LogicalPlan::Statement(_) - | LogicalPlan::CrossJoin(_) - | LogicalPlan::Analyze(_) - | LogicalPlan::Explain(_) - | LogicalPlan::Union(_) - | LogicalPlan::Distinct(Distinct::All(_)) - | LogicalPlan::Dml(_) - | LogicalPlan::Ddl(_) - | LogicalPlan::Copy(_) - | LogicalPlan::DescribeTable(_) - | LogicalPlan::Prepare(_) => Transformed::no(self), - }) - } - /// Returns all inputs / children of this `LogicalPlan` node. /// /// Note does not include inputs to inputs, or subqueries. @@ -1354,192 +1041,7 @@ impl LogicalPlan { } } -/// This macro is used to determine continuation during combined transforming -/// traversals. -macro_rules! handle_transform_recursion { - ($F_DOWN:expr, $F_CHILD:expr, $F_UP:expr) => {{ - $F_DOWN? - .transform_children(|n| n.map_subqueries($F_CHILD))? - .transform_sibling(|n| n.map_children($F_CHILD))? - .transform_parent($F_UP) - }}; -} - -macro_rules! handle_transform_recursion_down { - ($F_DOWN:expr, $F_CHILD:expr) => {{ - $F_DOWN? - .transform_children(|n| n.map_subqueries($F_CHILD))? - .transform_sibling(|n| n.map_children($F_CHILD)) - }}; -} - -macro_rules! handle_transform_recursion_up { - ($SELF:expr, $F_CHILD:expr, $F_UP:expr) => {{ - $SELF - .map_subqueries($F_CHILD)? - .transform_sibling(|n| n.map_children($F_CHILD))? - .transform_parent(|n| $F_UP(n)) - }}; -} - impl LogicalPlan { - /// Visits a plan similarly to [`Self::visit`], but including embedded subqueries. - pub fn visit_with_subqueries>( - &self, - visitor: &mut V, - ) -> Result { - visitor - .f_down(self)? - .visit_children(|| { - self.apply_subqueries(|c| c.visit_with_subqueries(visitor)) - })? - .visit_sibling(|| self.apply_children(|c| c.visit_with_subqueries(visitor)))? - .visit_parent(|| visitor.f_up(self)) - } - - /// Rewrites a plan similarly t [`Self::visit`], but including embedded subqueries. - pub fn rewrite_with_subqueries>( - self, - rewriter: &mut R, - ) -> Result> { - handle_transform_recursion!( - rewriter.f_down(self), - |c| c.rewrite_with_subqueries(rewriter), - |n| rewriter.f_up(n) - ) - } - - /// Calls `f` recursively on all children of the `LogicalPlan` node. - /// - /// Unlike [`Self::apply`], this method *does* includes `LogicalPlan`s that - /// are referenced in `Expr`s - pub fn apply_with_subqueries Result>( - &self, - f: &mut F, - ) -> Result { - f(self)? - .visit_children(|| self.apply_subqueries(|c| c.apply_with_subqueries(f)))? - .visit_sibling(|| self.apply_children(|c| c.apply_with_subqueries(f))) - } - - pub fn transform_with_subqueries Result>>( - self, - f: &F, - ) -> Result> { - self.transform_up_with_subqueries(f) - } - - pub fn transform_down_with_subqueries Result>>( - self, - f: &F, - ) -> Result> { - handle_transform_recursion_down!(f(self), |c| c.transform_down_with_subqueries(f)) - } - - pub fn transform_down_mut_with_subqueries< - F: FnMut(Self) -> Result>, - >( - self, - f: &mut F, - ) -> Result> { - handle_transform_recursion_down!(f(self), |c| c - .transform_down_mut_with_subqueries(f)) - } - - pub fn transform_up_with_subqueries Result>>( - self, - f: &F, - ) -> Result> { - handle_transform_recursion_up!(self, |c| c.transform_up_with_subqueries(f), f) - } - - pub fn transform_up_mut_with_subqueries< - F: FnMut(Self) -> Result>, - >( - self, - f: &mut F, - ) -> Result> { - handle_transform_recursion_up!(self, |c| c.transform_up_mut_with_subqueries(f), f) - } - - pub fn transform_down_up_with_subqueries< - FD: FnMut(Self) -> Result>, - FU: FnMut(Self) -> Result>, - >( - self, - f_down: &mut FD, - f_up: &mut FU, - ) -> Result> { - handle_transform_recursion!( - f_down(self), - |c| c.transform_down_up_with_subqueries(f_down, f_up), - f_up - ) - } - - /// Calls `f` on all subqueries referenced in expressions of the current - /// `LogicalPlan` node. - pub fn apply_subqueries Result>( - &self, - mut f: F, - ) -> Result { - self.apply_expressions(|expr| { - expr.apply(&mut |expr| match expr { - Expr::Exists(Exists { subquery, .. }) - | Expr::InSubquery(InSubquery { subquery, .. }) - | Expr::ScalarSubquery(subquery) => { - // use a synthetic plan so the collector sees a - // LogicalPlan::Subquery (even though it is - // actually a Subquery alias) - f(&LogicalPlan::Subquery(subquery.clone())) - } - _ => Ok(TreeNodeRecursion::Continue), - }) - }) - } - - /// Rewrites all subquery `LogicalPlan` in the current `LogicalPlan` node - /// using `f`. - /// - /// Returns the current node. - pub fn map_subqueries Result>>( - self, - mut f: F, - ) -> Result> { - self.map_expressions(|expr| { - expr.transform_down_mut(&mut |expr| match expr { - Expr::Exists(Exists { subquery, negated }) => { - f(LogicalPlan::Subquery(subquery))?.map_data(|s| match s { - LogicalPlan::Subquery(subquery) => { - Ok(Expr::Exists(Exists { subquery, negated })) - } - _ => internal_err!("Transformation should return Subquery"), - }) - } - Expr::InSubquery(InSubquery { - expr, - subquery, - negated, - }) => f(LogicalPlan::Subquery(subquery))?.map_data(|s| match s { - LogicalPlan::Subquery(subquery) => Ok(Expr::InSubquery(InSubquery { - expr, - subquery, - negated, - })), - _ => internal_err!("Transformation should return Subquery"), - }), - Expr::ScalarSubquery(subquery) => f(LogicalPlan::Subquery(subquery))? - .map_data(|s| match s { - LogicalPlan::Subquery(subquery) => { - Ok(Expr::ScalarSubquery(subquery)) - } - _ => internal_err!("Transformation should return Subquery"), - }), - _ => Ok(Transformed::no(expr)), - }) - }) - } - /// Return a `LogicalPlan` with all placeholders (e.g $1 $2, /// ...) replaced with corresponding values provided in /// `params_values` @@ -1623,10 +1125,11 @@ impl LogicalPlan { }) .data() } -} -// Various implementations for printing out LogicalPlans -impl LogicalPlan { + // ------------ + // Various implementations for printing out LogicalPlans + // ------------ + /// Return a `format`able structure that produces a single line /// per node. /// diff --git a/datafusion/expr/src/logical_plan/tree_node.rs b/datafusion/expr/src/logical_plan/tree_node.rs index 415343f88685..1eb9d50277dd 100644 --- a/datafusion/expr/src/logical_plan/tree_node.rs +++ b/datafusion/expr/src/logical_plan/tree_node.rs @@ -38,16 +38,22 @@ //! * [`LogicalPlan::expressions`]: Return a copy of the plan's expressions use crate::{ dml::CopyTo, Aggregate, Analyze, CreateMemoryTable, CreateView, CrossJoin, - DdlStatement, Distinct, DistinctOn, DmlStatement, Explain, Extension, Filter, Join, - Limit, LogicalPlan, Prepare, Projection, RecursiveQuery, Repartition, Sort, Subquery, - SubqueryAlias, Union, Unnest, Window, + DdlStatement, Distinct, DistinctOn, DmlStatement, Explain, Expr, Extension, Filter, + Join, Limit, LogicalPlan, Partitioning, Prepare, Projection, RecursiveQuery, + Repartition, Sort, Subquery, SubqueryAlias, TableScan, Union, Unnest, + UserDefinedLogicalNode, Values, Window, }; use std::sync::Arc; +use crate::expr::{Exists, InSubquery}; +use crate::tree_node::transform_option_vec; use datafusion_common::tree_node::{ - Transformed, TreeNode, TreeNodeIterator, TreeNodeRecursion, + Transformed, TreeNode, TreeNodeIterator, TreeNodeRecursion, TreeNodeRewriter, + TreeNodeVisitor, +}; +use datafusion_common::{ + internal_err, map_until_stop_and_collect, DataFusionError, Result, }; -use datafusion_common::{map_until_stop_and_collect, Result}; impl TreeNode for LogicalPlan { fn apply_children Result>( @@ -413,3 +419,498 @@ where }) }) } + +/// This macro is used to determine continuation during combined transforming +/// traversals. +macro_rules! handle_transform_recursion { + ($F_DOWN:expr, $F_CHILD:expr, $F_UP:expr) => {{ + $F_DOWN? + .transform_children(|n| n.map_subqueries($F_CHILD))? + .transform_sibling(|n| n.map_children($F_CHILD))? + .transform_parent($F_UP) + }}; +} + +macro_rules! handle_transform_recursion_down { + ($F_DOWN:expr, $F_CHILD:expr) => {{ + $F_DOWN? + .transform_children(|n| n.map_subqueries($F_CHILD))? + .transform_sibling(|n| n.map_children($F_CHILD)) + }}; +} + +macro_rules! handle_transform_recursion_up { + ($SELF:expr, $F_CHILD:expr, $F_UP:expr) => {{ + $SELF + .map_subqueries($F_CHILD)? + .transform_sibling(|n| n.map_children($F_CHILD))? + .transform_parent(|n| $F_UP(n)) + }}; +} + +impl LogicalPlan { + /// Calls `f` on all expressions in the current `LogicalPlan` node. + /// + /// Note this does not include expressions in child `LogicalPlan` nodes. + pub fn apply_expressions Result>( + &self, + mut f: F, + ) -> Result { + match self { + LogicalPlan::Projection(Projection { expr, .. }) => { + expr.iter().apply_until_stop(f) + } + LogicalPlan::Values(Values { values, .. }) => values + .iter() + .apply_until_stop(|value| value.iter().apply_until_stop(&mut f)), + LogicalPlan::Filter(Filter { predicate, .. }) => f(predicate), + LogicalPlan::Repartition(Repartition { + partitioning_scheme, + .. + }) => match partitioning_scheme { + Partitioning::Hash(expr, _) | Partitioning::DistributeBy(expr) => { + expr.iter().apply_until_stop(f) + } + Partitioning::RoundRobinBatch(_) => Ok(TreeNodeRecursion::Continue), + }, + LogicalPlan::Window(Window { window_expr, .. }) => { + window_expr.iter().apply_until_stop(f) + } + LogicalPlan::Aggregate(Aggregate { + group_expr, + aggr_expr, + .. + }) => group_expr + .iter() + .chain(aggr_expr.iter()) + .apply_until_stop(f), + // There are two part of expression for join, equijoin(on) and non-equijoin(filter). + // 1. the first part is `on.len()` equijoin expressions, and the struct of each expr is `left-on = right-on`. + // 2. the second part is non-equijoin(filter). + LogicalPlan::Join(Join { on, filter, .. }) => { + on.iter() + // TODO: why we need to create an `Expr::eq`? Cloning `Expr` is costly... + // it not ideal to create an expr here to analyze them, but could cache it on the Join itself + .map(|(l, r)| Expr::eq(l.clone(), r.clone())) + .apply_until_stop(|e| f(&e))? + .visit_sibling(|| filter.iter().apply_until_stop(f)) + } + LogicalPlan::Sort(Sort { expr, .. }) => expr.iter().apply_until_stop(f), + LogicalPlan::Extension(extension) => { + // would be nice to avoid this copy -- maybe can + // update extension to just observer Exprs + extension.node.expressions().iter().apply_until_stop(f) + } + LogicalPlan::TableScan(TableScan { filters, .. }) => { + filters.iter().apply_until_stop(f) + } + LogicalPlan::Unnest(Unnest { column, .. }) => { + f(&Expr::Column(column.clone())) + } + LogicalPlan::Distinct(Distinct::On(DistinctOn { + on_expr, + select_expr, + sort_expr, + .. + })) => on_expr + .iter() + .chain(select_expr.iter()) + .chain(sort_expr.iter().flatten()) + .apply_until_stop(f), + // plans without expressions + LogicalPlan::EmptyRelation(_) + | LogicalPlan::RecursiveQuery(_) + | LogicalPlan::Subquery(_) + | LogicalPlan::SubqueryAlias(_) + | LogicalPlan::Limit(_) + | LogicalPlan::Statement(_) + | LogicalPlan::CrossJoin(_) + | LogicalPlan::Analyze(_) + | LogicalPlan::Explain(_) + | LogicalPlan::Union(_) + | LogicalPlan::Distinct(Distinct::All(_)) + | LogicalPlan::Dml(_) + | LogicalPlan::Ddl(_) + | LogicalPlan::Copy(_) + | LogicalPlan::DescribeTable(_) + | LogicalPlan::Prepare(_) => Ok(TreeNodeRecursion::Continue), + } + } + + /// Rewrites all expressions in the current `LogicalPlan` node using `f`. + /// + /// Returns the current node. + /// + /// Note this does not include expressions in child `LogicalPlan` nodes. + pub fn map_expressions Result>>( + self, + mut f: F, + ) -> Result> { + Ok(match self { + LogicalPlan::Projection(Projection { + expr, + input, + schema, + }) => expr + .into_iter() + .map_until_stop_and_collect(f)? + .update_data(|expr| { + LogicalPlan::Projection(Projection { + expr, + input, + schema, + }) + }), + LogicalPlan::Values(Values { schema, values }) => values + .into_iter() + .map_until_stop_and_collect(|value| { + value.into_iter().map_until_stop_and_collect(&mut f) + })? + .update_data(|values| LogicalPlan::Values(Values { schema, values })), + LogicalPlan::Filter(Filter { predicate, input }) => f(predicate)? + .update_data(|predicate| { + LogicalPlan::Filter(Filter { predicate, input }) + }), + LogicalPlan::Repartition(Repartition { + input, + partitioning_scheme, + }) => match partitioning_scheme { + Partitioning::Hash(expr, usize) => expr + .into_iter() + .map_until_stop_and_collect(f)? + .update_data(|expr| Partitioning::Hash(expr, usize)), + Partitioning::DistributeBy(expr) => expr + .into_iter() + .map_until_stop_and_collect(f)? + .update_data(Partitioning::DistributeBy), + Partitioning::RoundRobinBatch(_) => Transformed::no(partitioning_scheme), + } + .update_data(|partitioning_scheme| { + LogicalPlan::Repartition(Repartition { + input, + partitioning_scheme, + }) + }), + LogicalPlan::Window(Window { + input, + window_expr, + schema, + }) => window_expr + .into_iter() + .map_until_stop_and_collect(f)? + .update_data(|window_expr| { + LogicalPlan::Window(Window { + input, + window_expr, + schema, + }) + }), + LogicalPlan::Aggregate(Aggregate { + input, + group_expr, + aggr_expr, + schema, + }) => map_until_stop_and_collect!( + group_expr.into_iter().map_until_stop_and_collect(&mut f), + aggr_expr, + aggr_expr.into_iter().map_until_stop_and_collect(&mut f) + )? + .update_data(|(group_expr, aggr_expr)| { + LogicalPlan::Aggregate(Aggregate { + input, + group_expr, + aggr_expr, + schema, + }) + }), + + // There are two part of expression for join, equijoin(on) and non-equijoin(filter). + // 1. the first part is `on.len()` equijoin expressions, and the struct of each expr is `left-on = right-on`. + // 2. the second part is non-equijoin(filter). + LogicalPlan::Join(Join { + left, + right, + on, + filter, + join_type, + join_constraint, + schema, + null_equals_null, + }) => map_until_stop_and_collect!( + on.into_iter().map_until_stop_and_collect( + |on| map_until_stop_and_collect!(f(on.0), on.1, f(on.1)) + ), + filter, + filter.map_or(Ok::<_, DataFusionError>(Transformed::no(None)), |e| { + Ok(f(e)?.update_data(Some)) + }) + )? + .update_data(|(on, filter)| { + LogicalPlan::Join(Join { + left, + right, + on, + filter, + join_type, + join_constraint, + schema, + null_equals_null, + }) + }), + LogicalPlan::Sort(Sort { expr, input, fetch }) => expr + .into_iter() + .map_until_stop_and_collect(f)? + .update_data(|expr| LogicalPlan::Sort(Sort { expr, input, fetch })), + LogicalPlan::Extension(Extension { node }) => { + // would be nice to avoid this copy -- maybe can + // update extension to just observer Exprs + node.expressions() + .into_iter() + .map_until_stop_and_collect(f)? + .update_data(|exprs| { + LogicalPlan::Extension(Extension { + node: UserDefinedLogicalNode::from_template( + node.as_ref(), + exprs.as_slice(), + node.inputs() + .into_iter() + .cloned() + .collect::>() + .as_slice(), + ), + }) + }) + } + LogicalPlan::TableScan(TableScan { + table_name, + source, + projection, + projected_schema, + filters, + fetch, + }) => filters + .into_iter() + .map_until_stop_and_collect(f)? + .update_data(|filters| { + LogicalPlan::TableScan(TableScan { + table_name, + source, + projection, + projected_schema, + filters, + fetch, + }) + }), + LogicalPlan::Unnest(Unnest { + input, + column, + schema, + options, + }) => f(Expr::Column(column))?.map_data(|column| match column { + Expr::Column(column) => Ok(LogicalPlan::Unnest(Unnest { + input, + column, + schema, + options, + })), + _ => internal_err!("Transformation should return Column"), + })?, + LogicalPlan::Distinct(Distinct::On(DistinctOn { + on_expr, + select_expr, + sort_expr, + input, + schema, + })) => map_until_stop_and_collect!( + on_expr.into_iter().map_until_stop_and_collect(&mut f), + select_expr, + select_expr.into_iter().map_until_stop_and_collect(&mut f), + sort_expr, + transform_option_vec(sort_expr, &mut f) + )? + .update_data(|(on_expr, select_expr, sort_expr)| { + LogicalPlan::Distinct(Distinct::On(DistinctOn { + on_expr, + select_expr, + sort_expr, + input, + schema, + })) + }), + // plans without expressions + LogicalPlan::EmptyRelation(_) + | LogicalPlan::RecursiveQuery(_) + | LogicalPlan::Subquery(_) + | LogicalPlan::SubqueryAlias(_) + | LogicalPlan::Limit(_) + | LogicalPlan::Statement(_) + | LogicalPlan::CrossJoin(_) + | LogicalPlan::Analyze(_) + | LogicalPlan::Explain(_) + | LogicalPlan::Union(_) + | LogicalPlan::Distinct(Distinct::All(_)) + | LogicalPlan::Dml(_) + | LogicalPlan::Ddl(_) + | LogicalPlan::Copy(_) + | LogicalPlan::DescribeTable(_) + | LogicalPlan::Prepare(_) => Transformed::no(self), + }) + } + + /// Visits a plan similarly to [`Self::visit`], but including embedded subqueries. + pub fn visit_with_subqueries>( + &self, + visitor: &mut V, + ) -> Result { + visitor + .f_down(self)? + .visit_children(|| { + self.apply_subqueries(|c| c.visit_with_subqueries(visitor)) + })? + .visit_sibling(|| self.apply_children(|c| c.visit_with_subqueries(visitor)))? + .visit_parent(|| visitor.f_up(self)) + } + + /// Rewrites a plan similarly t [`Self::visit`], but including embedded subqueries. + pub fn rewrite_with_subqueries>( + self, + rewriter: &mut R, + ) -> Result> { + handle_transform_recursion!( + rewriter.f_down(self), + |c| c.rewrite_with_subqueries(rewriter), + |n| rewriter.f_up(n) + ) + } + + /// Calls `f` recursively on all children of the `LogicalPlan` node. + /// + /// Unlike [`Self::apply`], this method *does* includes `LogicalPlan`s that + /// are referenced in `Expr`s + pub fn apply_with_subqueries Result>( + &self, + f: &mut F, + ) -> Result { + f(self)? + .visit_children(|| self.apply_subqueries(|c| c.apply_with_subqueries(f)))? + .visit_sibling(|| self.apply_children(|c| c.apply_with_subqueries(f))) + } + + pub fn transform_with_subqueries Result>>( + self, + f: &F, + ) -> Result> { + self.transform_up_with_subqueries(f) + } + + pub fn transform_down_with_subqueries Result>>( + self, + f: &F, + ) -> Result> { + handle_transform_recursion_down!(f(self), |c| c.transform_down_with_subqueries(f)) + } + + pub fn transform_down_mut_with_subqueries< + F: FnMut(Self) -> Result>, + >( + self, + f: &mut F, + ) -> Result> { + handle_transform_recursion_down!(f(self), |c| c + .transform_down_mut_with_subqueries(f)) + } + + pub fn transform_up_with_subqueries Result>>( + self, + f: &F, + ) -> Result> { + handle_transform_recursion_up!(self, |c| c.transform_up_with_subqueries(f), f) + } + + pub fn transform_up_mut_with_subqueries< + F: FnMut(Self) -> Result>, + >( + self, + f: &mut F, + ) -> Result> { + handle_transform_recursion_up!(self, |c| c.transform_up_mut_with_subqueries(f), f) + } + + pub fn transform_down_up_with_subqueries< + FD: FnMut(Self) -> Result>, + FU: FnMut(Self) -> Result>, + >( + self, + f_down: &mut FD, + f_up: &mut FU, + ) -> Result> { + handle_transform_recursion!( + f_down(self), + |c| c.transform_down_up_with_subqueries(f_down, f_up), + f_up + ) + } + + /// Calls `f` on all subqueries referenced in expressions of the current + /// `LogicalPlan` node. + pub fn apply_subqueries Result>( + &self, + mut f: F, + ) -> Result { + self.apply_expressions(|expr| { + expr.apply(&mut |expr| match expr { + Expr::Exists(Exists { subquery, .. }) + | Expr::InSubquery(InSubquery { subquery, .. }) + | Expr::ScalarSubquery(subquery) => { + // use a synthetic plan so the collector sees a + // LogicalPlan::Subquery (even though it is + // actually a Subquery alias) + f(&LogicalPlan::Subquery(subquery.clone())) + } + _ => Ok(TreeNodeRecursion::Continue), + }) + }) + } + + /// Rewrites all subquery `LogicalPlan` in the current `LogicalPlan` node + /// using `f`. + /// + /// Returns the current node. + pub fn map_subqueries Result>>( + self, + mut f: F, + ) -> Result> { + self.map_expressions(|expr| { + expr.transform_down_mut(&mut |expr| match expr { + Expr::Exists(Exists { subquery, negated }) => { + f(LogicalPlan::Subquery(subquery))?.map_data(|s| match s { + LogicalPlan::Subquery(subquery) => { + Ok(Expr::Exists(Exists { subquery, negated })) + } + _ => internal_err!("Transformation should return Subquery"), + }) + } + Expr::InSubquery(InSubquery { + expr, + subquery, + negated, + }) => f(LogicalPlan::Subquery(subquery))?.map_data(|s| match s { + LogicalPlan::Subquery(subquery) => Ok(Expr::InSubquery(InSubquery { + expr, + subquery, + negated, + })), + _ => internal_err!("Transformation should return Subquery"), + }), + Expr::ScalarSubquery(subquery) => f(LogicalPlan::Subquery(subquery))? + .map_data(|s| match s { + LogicalPlan::Subquery(subquery) => { + Ok(Expr::ScalarSubquery(subquery)) + } + _ => internal_err!("Transformation should return Subquery"), + }), + _ => Ok(Transformed::no(expr)), + }) + }) + } +}