From 75c399ce7d4d5360140c64089dd7b05ffd7c49ef Mon Sep 17 00:00:00 2001 From: Marco Neumann Date: Wed, 10 Apr 2024 10:59:40 +0200 Subject: [PATCH 1/5] fix: reduce lock contention in `RepartitionExec::execute` (#10009) * fix: lock contention in `RepartitionExec::execute` The state is initialized ONCE for all partitions. However this may take a short while (on a very busy system 1ms or more). It is quite likely that multiple threads call `execute` at the same time, because we have just fanned out to the number "target partitions" which is likely set to the number of CPU cores which now all try to start to execute the plan at the same time. The solution is to not waste CPU circles in some futex lock but to tell the async runtime (= tokio) that we are performing work and the other threads should rather do something useful. This mostly just moves code around, no functional change intended. * docs: explain design choice Co-authored-by: Andrew Lamb --------- Co-authored-by: Andrew Lamb --- .../physical-plan/src/repartition/mod.rs | 338 +++++++++++------- 1 file changed, 206 insertions(+), 132 deletions(-) diff --git a/datafusion/physical-plan/src/repartition/mod.rs b/datafusion/physical-plan/src/repartition/mod.rs index c0dbf5164e19..2ed5da7ced20 100644 --- a/datafusion/physical-plan/src/repartition/mod.rs +++ b/datafusion/physical-plan/src/repartition/mod.rs @@ -36,6 +36,7 @@ use crate::repartition::distributor_channels::{ channels, partition_aware_channels, DistributionReceiver, DistributionSender, }; use crate::sorts::streaming_merge; +use crate::stream::RecordBatchStreamAdapter; use crate::{DisplayFormatType, ExecutionPlan, Partitioning, PlanProperties, Statistics}; use arrow::array::{ArrayRef, UInt64Builder}; @@ -48,7 +49,7 @@ use datafusion_execution::TaskContext; use datafusion_physical_expr::{EquivalenceProperties, PhysicalExpr, PhysicalSortExpr}; use futures::stream::Stream; -use futures::{FutureExt, StreamExt}; +use futures::{FutureExt, StreamExt, TryStreamExt}; use hashbrown::HashMap; use log::trace; use parking_lot::Mutex; @@ -77,6 +78,102 @@ struct RepartitionExecState { abort_helper: Arc>>, } +impl RepartitionExecState { + fn new( + input: Arc, + partitioning: Partitioning, + metrics: ExecutionPlanMetricsSet, + preserve_order: bool, + name: String, + context: Arc, + ) -> Self { + let num_input_partitions = input.output_partitioning().partition_count(); + let num_output_partitions = partitioning.partition_count(); + + let (txs, rxs) = if preserve_order { + let (txs, rxs) = + partition_aware_channels(num_input_partitions, num_output_partitions); + // Take transpose of senders and receivers. `state.channels` keeps track of entries per output partition + let txs = transpose(txs); + let rxs = transpose(rxs); + (txs, rxs) + } else { + // create one channel per *output* partition + // note we use a custom channel that ensures there is always data for each receiver + // but limits the amount of buffering if required. + let (txs, rxs) = channels(num_output_partitions); + // Clone sender for each input partitions + let txs = txs + .into_iter() + .map(|item| vec![item; num_input_partitions]) + .collect::>(); + let rxs = rxs.into_iter().map(|item| vec![item]).collect::>(); + (txs, rxs) + }; + + let mut channels = HashMap::with_capacity(txs.len()); + for (partition, (tx, rx)) in txs.into_iter().zip(rxs).enumerate() { + let reservation = Arc::new(Mutex::new( + MemoryConsumer::new(format!("{}[{partition}]", name)) + .register(context.memory_pool()), + )); + channels.insert(partition, (tx, rx, reservation)); + } + + // launch one async task per *input* partition + let mut spawned_tasks = Vec::with_capacity(num_input_partitions); + for i in 0..num_input_partitions { + let txs: HashMap<_, _> = channels + .iter() + .map(|(partition, (tx, _rx, reservation))| { + (*partition, (tx[i].clone(), Arc::clone(reservation))) + }) + .collect(); + + // TODO: metric input-output mapping is broken + let r_metrics = RepartitionMetrics::new(i, 0, &metrics); + + let input_task = SpawnedTask::spawn(RepartitionExec::pull_from_input( + input.clone(), + i, + txs.clone(), + partitioning.clone(), + r_metrics, + context.clone(), + )); + + // In a separate task, wait for each input to be done + // (and pass along any errors, including panic!s) + let wait_for_task = SpawnedTask::spawn(RepartitionExec::wait_for_task( + input_task, + txs.into_iter() + .map(|(partition, (tx, _reservation))| (partition, tx)) + .collect(), + )); + spawned_tasks.push(wait_for_task); + } + + Self { + channels, + abort_helper: Arc::new(spawned_tasks), + } + } +} + +/// Lazily initialized state +/// +/// Note that the state is initialized ONCE for all partitions by a single task(thread). +/// This may take a short while. It is also like that multiple threads +/// call execute at the same time, because we have just started "target partitions" tasks +/// which is commonly set to the number of CPU cores and all call execute at the same time. +/// +/// Thus, use a **tokio** `OnceCell` for this initialization so as not to waste CPU cycles +/// in a futex lock but instead allow other threads to do something useful. +/// +/// Uses a parking_lot `Mutex` to control other accesses as they are very short duration +/// (e.g. removing channels on completion) where the overhead of `await` is not warranted. +type LazyState = Arc>>; + /// A utility that can be used to partition batches based on [`Partitioning`] pub struct BatchPartitioner { state: BatchPartitionerState, @@ -298,7 +395,7 @@ pub struct RepartitionExec { /// Partitioning scheme to use partitioning: Partitioning, /// Inner state that is initialized when the first output stream is created. - state: Arc>, + state: LazyState, /// Execution metrics metrics: ExecutionPlanMetricsSet, /// Boolean flag to decide whether to preserve ordering. If true means @@ -453,134 +550,104 @@ impl ExecutionPlan for RepartitionExec { self.name(), partition ); - // lock mutexes - let mut state = self.state.lock(); - - let num_input_partitions = self.input.output_partitioning().partition_count(); - let num_output_partitions = self.partitioning.partition_count(); - - // if this is the first partition to be invoked then we need to set up initial state - if state.channels.is_empty() { - let (txs, rxs) = if self.preserve_order { - let (txs, rxs) = - partition_aware_channels(num_input_partitions, num_output_partitions); - // Take transpose of senders and receivers. `state.channels` keeps track of entries per output partition - let txs = transpose(txs); - let rxs = transpose(rxs); - (txs, rxs) - } else { - // create one channel per *output* partition - // note we use a custom channel that ensures there is always data for each receiver - // but limits the amount of buffering if required. - let (txs, rxs) = channels(num_output_partitions); - // Clone sender for each input partitions - let txs = txs - .into_iter() - .map(|item| vec![item; num_input_partitions]) - .collect::>(); - let rxs = rxs.into_iter().map(|item| vec![item]).collect::>(); - (txs, rxs) - }; - for (partition, (tx, rx)) in txs.into_iter().zip(rxs).enumerate() { - let reservation = Arc::new(Mutex::new( - MemoryConsumer::new(format!("{}[{partition}]", self.name())) - .register(context.memory_pool()), - )); - state.channels.insert(partition, (tx, rx, reservation)); - } - // launch one async task per *input* partition - let mut spawned_tasks = Vec::with_capacity(num_input_partitions); - for i in 0..num_input_partitions { - let txs: HashMap<_, _> = state - .channels - .iter() - .map(|(partition, (tx, _rx, reservation))| { - (*partition, (tx[i].clone(), Arc::clone(reservation))) - }) - .collect(); - - let r_metrics = RepartitionMetrics::new(i, partition, &self.metrics); - - let input_task = SpawnedTask::spawn(Self::pull_from_input( - self.input.clone(), - i, - txs.clone(), - self.partitioning.clone(), - r_metrics, - context.clone(), - )); - - // In a separate task, wait for each input to be done - // (and pass along any errors, including panic!s) - let wait_for_task = SpawnedTask::spawn(Self::wait_for_task( - input_task, - txs.into_iter() - .map(|(partition, (tx, _reservation))| (partition, tx)) - .collect(), - )); - spawned_tasks.push(wait_for_task); - } + let lazy_state = Arc::clone(&self.state); + let input = Arc::clone(&self.input); + let partitioning = self.partitioning.clone(); + let metrics = self.metrics.clone(); + let preserve_order = self.preserve_order; + let name = self.name().to_owned(); + let schema = self.schema(); + let schema_captured = Arc::clone(&schema); + + // Get existing ordering to use for merging + let sort_exprs = self.sort_exprs().unwrap_or(&[]).to_owned(); + + let stream = futures::stream::once(async move { + let num_input_partitions = input.output_partitioning().partition_count(); + + let input_captured = Arc::clone(&input); + let metrics_captured = metrics.clone(); + let name_captured = name.clone(); + let context_captured = Arc::clone(&context); + let state = lazy_state + .get_or_init(|| async move { + Mutex::new(RepartitionExecState::new( + input_captured, + partitioning, + metrics_captured, + preserve_order, + name_captured, + context_captured, + )) + }) + .await; - state.abort_helper = Arc::new(spawned_tasks) - } + // lock scope + let (mut rx, reservation, abort_helper) = { + // lock mutexes + let mut state = state.lock(); - trace!( - "Before returning stream in {}::execute for partition: {}", - self.name(), - partition - ); + // now return stream for the specified *output* partition which will + // read from the channel + let (_tx, rx, reservation) = state + .channels + .remove(&partition) + .expect("partition not used yet"); - // now return stream for the specified *output* partition which will - // read from the channel - let (_tx, mut rx, reservation) = state - .channels - .remove(&partition) - .expect("partition not used yet"); + (rx, reservation, Arc::clone(&state.abort_helper)) + }; - if self.preserve_order { - // Store streams from all the input partitions: - let input_streams = rx - .into_iter() - .map(|receiver| { - Box::pin(PerPartitionStream { - schema: self.schema(), - receiver, - drop_helper: Arc::clone(&state.abort_helper), - reservation: reservation.clone(), - }) as SendableRecordBatchStream - }) - .collect::>(); - // Note that receiver size (`rx.len()`) and `num_input_partitions` are same. - - // Get existing ordering to use for merging - let sort_exprs = self.sort_exprs().unwrap_or(&[]); - - // Merge streams (while preserving ordering) coming from - // input partitions to this partition: - let fetch = None; - let merge_reservation = - MemoryConsumer::new(format!("{}[Merge {partition}]", self.name())) - .register(context.memory_pool()); - streaming_merge( - input_streams, - self.schema(), - sort_exprs, - BaselineMetrics::new(&self.metrics, partition), - context.session_config().batch_size(), - fetch, - merge_reservation, - ) - } else { - Ok(Box::pin(RepartitionStream { - num_input_partitions, - num_input_partitions_processed: 0, - schema: self.input.schema(), - input: rx.swap_remove(0), - drop_helper: Arc::clone(&state.abort_helper), - reservation, - })) - } + trace!( + "Before returning stream in {}::execute for partition: {}", + name, + partition + ); + + if preserve_order { + // Store streams from all the input partitions: + let input_streams = rx + .into_iter() + .map(|receiver| { + Box::pin(PerPartitionStream { + schema: Arc::clone(&schema_captured), + receiver, + drop_helper: Arc::clone(&abort_helper), + reservation: reservation.clone(), + }) as SendableRecordBatchStream + }) + .collect::>(); + // Note that receiver size (`rx.len()`) and `num_input_partitions` are same. + + // Merge streams (while preserving ordering) coming from + // input partitions to this partition: + let fetch = None; + let merge_reservation = + MemoryConsumer::new(format!("{}[Merge {partition}]", name)) + .register(context.memory_pool()); + streaming_merge( + input_streams, + schema_captured, + &sort_exprs, + BaselineMetrics::new(&metrics, partition), + context.session_config().batch_size(), + fetch, + merge_reservation, + ) + } else { + Ok(Box::pin(RepartitionStream { + num_input_partitions, + num_input_partitions_processed: 0, + schema: input.schema(), + input: rx.swap_remove(0), + drop_helper: abort_helper, + reservation, + }) as SendableRecordBatchStream) + } + }) + .try_flatten(); + let stream = RecordBatchStreamAdapter::new(schema, stream); + Ok(Box::pin(stream)) } fn metrics(&self) -> Option { @@ -606,10 +673,7 @@ impl RepartitionExec { Ok(RepartitionExec { input, partitioning, - state: Arc::new(Mutex::new(RepartitionExecState { - channels: HashMap::new(), - abort_helper: Arc::new(Vec::new()), - })), + state: Default::default(), metrics: ExecutionPlanMetricsSet::new(), preserve_order, cache, @@ -951,6 +1015,7 @@ mod tests { use datafusion_execution::runtime_env::{RuntimeConfig, RuntimeEnv}; use futures::FutureExt; + use tokio::task::JoinSet; #[tokio::test] async fn one_to_many_round_robin() -> Result<()> { @@ -1240,7 +1305,10 @@ mod tests { std::mem::drop(output_stream0); // Now, start sending input - input.wait().await; + let mut background_task = JoinSet::new(); + background_task.spawn(async move { + input.wait().await; + }); // output stream 1 should *not* error and have one of the input batches let batches = crate::common::collect(output_stream1).await.unwrap(); @@ -1277,7 +1345,10 @@ mod tests { let input = Arc::new(make_barrier_exec()); let exec = RepartitionExec::try_new(input.clone(), partitioning.clone()).unwrap(); let output_stream1 = exec.execute(1, task_ctx.clone()).unwrap(); - input.wait().await; + let mut background_task = JoinSet::new(); + background_task.spawn(async move { + input.wait().await; + }); let batches_without_drop = crate::common::collect(output_stream1).await.unwrap(); // run some checks on the result @@ -1299,7 +1370,10 @@ mod tests { // now, purposely drop output stream 0 // *before* any outputs are produced std::mem::drop(output_stream0); - input.wait().await; + let mut background_task = JoinSet::new(); + background_task.spawn(async move { + input.wait().await; + }); let batches_with_drop = crate::common::collect(output_stream1).await.unwrap(); assert_eq!(batches_without_drop, batches_with_drop); From 843caea55af8991e9f31e2d17c3a3debbd3965ee Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 10 Apr 2024 12:14:31 +0200 Subject: [PATCH 2/5] chore(deps): update rstest requirement from 0.18.0 to 0.19.0 (#10021) Updates the requirements on [rstest](https://github.com/la10736/rstest) to permit the latest version. - [Release notes](https://github.com/la10736/rstest/releases) - [Changelog](https://github.com/la10736/rstest/blob/master/CHANGELOG.md) - [Commits](https://github.com/la10736/rstest/compare/v0.18.0...v0.18.2) --- updated-dependencies: - dependency-name: rstest dependency-type: direct:production ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index e1e09d9893b1..c5a1aa9c8ef8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -101,7 +101,7 @@ object_store = { version = "0.9.1", default-features = false } parking_lot = "0.12" parquet = { version = "51.0.0", default-features = false, features = ["arrow", "async", "object_store"] } rand = "0.8" -rstest = "0.18.0" +rstest = "0.19.0" serde_json = "1" sqlparser = { version = "0.44.0", features = ["visitor"] } tempfile = "3" From 582050728914650c6d4340ca803a0e9af087d8ec Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Wed, 10 Apr 2024 06:26:44 -0400 Subject: [PATCH 3/5] Minor: Document LogicalPlan tree node transformations (#10010) * Document LogicalPlan tree node transformations * Add exists * touchups, add apply_subqueries, map_subqueries --- datafusion/core/src/lib.rs | 10 ++++-- datafusion/expr/src/logical_plan/mod.rs | 2 +- datafusion/expr/src/logical_plan/plan.rs | 20 +++++++++++- datafusion/expr/src/logical_plan/tree_node.rs | 32 +++++++++++++++---- 4 files changed, 52 insertions(+), 12 deletions(-) diff --git a/datafusion/core/src/lib.rs b/datafusion/core/src/lib.rs index c213f4554fb8..b0e2b6fa9c09 100644 --- a/datafusion/core/src/lib.rs +++ b/datafusion/core/src/lib.rs @@ -296,11 +296,15 @@ //! A [`LogicalPlan`] is a Directed Acyclic Graph (DAG) of other //! [`LogicalPlan`]s, each potentially containing embedded [`Expr`]s. //! -//! [`Expr`]s can be rewritten using the [`TreeNode`] API and simplified using -//! [`ExprSimplifier`]. Examples of working with and executing `Expr`s can be found in the -//! [`expr_api`.rs] example +//! `LogicalPlan`s can be rewritten with [`TreeNode`] API, see the +//! [`tree_node module`] for more details. +//! +//! [`Expr`]s can also be rewritten with [`TreeNode`] API and simplified using +//! [`ExprSimplifier`]. Examples of working with and executing `Expr`s can be +//! found in the [`expr_api`.rs] example //! //! [`TreeNode`]: datafusion_common::tree_node::TreeNode +//! [`tree_node module`]: datafusion_expr::logical_plan::tree_node //! [`ExprSimplifier`]: crate::optimizer::simplify_expressions::ExprSimplifier //! [`expr_api`.rs]: https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/expr_api.rs //! diff --git a/datafusion/expr/src/logical_plan/mod.rs b/datafusion/expr/src/logical_plan/mod.rs index a1fe7a6f0a51..034440643e51 100644 --- a/datafusion/expr/src/logical_plan/mod.rs +++ b/datafusion/expr/src/logical_plan/mod.rs @@ -22,7 +22,7 @@ pub mod dml; mod extension; mod plan; mod statement; -mod tree_node; +pub mod tree_node; pub use builder::{ build_join_schema, table_scan, union, wrap_projection_for_join_if_necessary, diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 02d65973a50b..7bad034a11ea 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -68,6 +68,11 @@ pub use datafusion_common::{JoinConstraint, JoinType}; /// an output relation (table) with a (potentially) different /// schema. A plan represents a dataflow tree where data flows /// from leaves up to the root to produce the query result. +/// +/// # See also: +/// * [`tree_node`]: visiting and rewriting API +/// +/// [`tree_node`]: crate::logical_plan::tree_node #[derive(Clone, PartialEq, Eq, Hash)] pub enum LogicalPlan { /// Evaluates an arbitrary list of expressions (essentially a @@ -238,7 +243,10 @@ impl LogicalPlan { } /// Returns all expressions (non-recursively) evaluated by the current - /// logical plan node. This does not include expressions in any children + /// logical plan node. This does not include expressions in any children. + /// + /// Note this method `clone`s all the expressions. When possible, the + /// [`tree_node`] API should be used instead of this API. /// /// The returned expressions do not necessarily represent or even /// contributed to the output schema of this node. For example, @@ -248,6 +256,8 @@ impl LogicalPlan { /// The expressions do contain all the columns that are used by this plan, /// so if there are columns not referenced by these expressions then /// DataFusion's optimizer attempts to optimize them away. + /// + /// [`tree_node`]: crate::logical_plan::tree_node pub fn expressions(self: &LogicalPlan) -> Vec { let mut exprs = vec![]; self.apply_expressions(|e| { @@ -773,10 +783,16 @@ impl LogicalPlan { /// Returns a new `LogicalPlan` based on `self` with inputs and /// expressions replaced. /// + /// Note this method creates an entirely new node, which requires a large + /// amount of clone'ing. When possible, the [`tree_node`] API should be used + /// instead of this API. + /// /// The exprs correspond to the same order of expressions returned /// by [`Self::expressions`]. This function is used by optimizers /// to rewrite plans using the following pattern: /// + /// [`tree_node`]: crate::logical_plan::tree_node + /// /// ```text /// let new_inputs = optimize_children(..., plan, props); /// @@ -1367,6 +1383,7 @@ macro_rules! handle_transform_recursion_up { } impl LogicalPlan { + /// Visits a plan similarly to [`Self::visit`], but including embedded subqueries. pub fn visit_with_subqueries>( &self, visitor: &mut V, @@ -1380,6 +1397,7 @@ impl LogicalPlan { .visit_parent(|| visitor.f_up(self)) } + /// Rewrites a plan similarly t [`Self::visit`], but including embedded subqueries. pub fn rewrite_with_subqueries>( self, rewriter: &mut R, diff --git a/datafusion/expr/src/logical_plan/tree_node.rs b/datafusion/expr/src/logical_plan/tree_node.rs index ce26cac7970b..415343f88685 100644 --- a/datafusion/expr/src/logical_plan/tree_node.rs +++ b/datafusion/expr/src/logical_plan/tree_node.rs @@ -15,17 +15,35 @@ // specific language governing permissions and limitations // under the License. -//! Tree node implementation for logical plan - +//! [`TreeNode`] based visiting and rewriting for [`LogicalPlan`]s +//! +//! Visiting (read only) APIs +//! * [`LogicalPlan::visit`]: recursively visit the node and all of its inputs +//! * [`LogicalPlan::visit_with_subqueries`]: recursively visit the node and all of its inputs, including subqueries +//! * [`LogicalPlan::apply_children`]: recursively visit all inputs of this node +//! * [`LogicalPlan::apply_expressions`]: (non recursively) visit all expressions of this node +//! * [`LogicalPlan::apply_subqueries`]: (non recursively) visit all subqueries of this node +//! * [`LogicalPlan::apply_with_subqueries`]: recursively visit all inputs and embedded subqueries. +//! +//! Rewriting (update) APIs: +//! * [`LogicalPlan::exists`]: search for an expression in a plan +//! * [`LogicalPlan::rewrite`]: recursively rewrite the node and all of its inputs +//! * [`LogicalPlan::map_children`]: recursively rewrite all inputs of this node +//! * [`LogicalPlan::map_expressions`]: (non recursively) visit all expressions of this node +//! * [`LogicalPlan::map_subqueries`]: (non recursively) rewrite all subqueries of this node +//! * [`LogicalPlan::rewrite_with_subqueries`]: recursively rewrite the node and all of its inputs, including subqueries +//! +//! (Re)creation APIs (these require substantial cloning and thus are slow): +//! * [`LogicalPlan::with_new_exprs`]: Create a new plan with different expressions +//! * [`LogicalPlan::expressions`]: Return a copy of the plan's expressions use crate::{ - Aggregate, Analyze, CreateMemoryTable, CreateView, CrossJoin, DdlStatement, Distinct, - DistinctOn, DmlStatement, Explain, Extension, Filter, Join, Limit, LogicalPlan, - Prepare, Projection, RecursiveQuery, Repartition, Sort, Subquery, SubqueryAlias, - Union, Unnest, Window, + dml::CopyTo, Aggregate, Analyze, CreateMemoryTable, CreateView, CrossJoin, + DdlStatement, Distinct, DistinctOn, DmlStatement, Explain, Extension, Filter, Join, + Limit, LogicalPlan, Prepare, Projection, RecursiveQuery, Repartition, Sort, Subquery, + SubqueryAlias, Union, Unnest, Window, }; use std::sync::Arc; -use crate::dml::CopyTo; use datafusion_common::tree_node::{ Transformed, TreeNode, TreeNodeIterator, TreeNodeRecursion, }; From 03d8ba1f0d94bac6bb8bb33e95f00f9f6fb5275a Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Wed, 10 Apr 2024 09:27:35 -0400 Subject: [PATCH 4/5] Refactor `Optimizer` to use owned plans and `TreeNode` API (10% faster planning) (#9948) * Rewrite Optimizer to use TreeNode API * fmt --- datafusion-examples/examples/rewrite_expr.rs | 2 +- datafusion/common/src/tree_node.rs | 12 +- datafusion/core/src/execution/context/mod.rs | 4 +- .../core/tests/optimizer_integration.rs | 2 +- .../src/decorrelate_predicate_subquery.rs | 96 ++--- .../src/eliminate_duplicated_expr.rs | 6 +- datafusion/optimizer/src/eliminate_filter.rs | 14 +- datafusion/optimizer/src/eliminate_join.rs | 6 +- datafusion/optimizer/src/eliminate_limit.rs | 32 +- .../optimizer/src/eliminate_nested_union.rs | 22 +- .../optimizer/src/eliminate_one_union.rs | 6 +- .../optimizer/src/eliminate_outer_join.rs | 12 +- .../src/extract_equijoin_predicate.rs | 18 +- .../optimizer/src/filter_null_join_keys.rs | 14 +- .../optimizer/src/optimize_projections.rs | 50 +-- datafusion/optimizer/src/optimizer.rs | 348 +++++++++--------- .../optimizer/src/propagate_empty_relation.rs | 22 +- datafusion/optimizer/src/push_down_filter.rs | 151 ++++---- datafusion/optimizer/src/push_down_limit.rs | 70 ++-- .../optimizer/src/push_down_projection.rs | 76 ++-- .../src/replace_distinct_aggregate.rs | 4 +- .../optimizer/src/scalar_subquery_to_join.rs | 28 +- .../src/single_distinct_to_groupby.rs | 40 +- datafusion/optimizer/src/test/mod.rs | 68 ++-- .../optimizer/tests/optimizer_integration.rs | 6 +- datafusion/sqllogictest/test_files/join.slt | 2 +- 26 files changed, 535 insertions(+), 576 deletions(-) diff --git a/datafusion-examples/examples/rewrite_expr.rs b/datafusion-examples/examples/rewrite_expr.rs index 541448ebf149..dcebbb55fb66 100644 --- a/datafusion-examples/examples/rewrite_expr.rs +++ b/datafusion-examples/examples/rewrite_expr.rs @@ -59,7 +59,7 @@ pub fn main() -> Result<()> { // then run the optimizer with our custom rule let optimizer = Optimizer::with_rules(vec![Arc::new(MyOptimizerRule {})]); - let optimized_plan = optimizer.optimize(&analyzed_plan, &config, observe)?; + let optimized_plan = optimizer.optimize(analyzed_plan, &config, observe)?; println!( "Optimized Logical Plan:\n\n{}\n", optimized_plan.display_indent() diff --git a/datafusion/common/src/tree_node.rs b/datafusion/common/src/tree_node.rs index bb268e048d9a..dff22d495958 100644 --- a/datafusion/common/src/tree_node.rs +++ b/datafusion/common/src/tree_node.rs @@ -56,6 +56,9 @@ pub trait TreeNode: Sized { /// Visit the tree node using the given [`TreeNodeVisitor`], performing a /// depth-first walk of the node and its children. /// + /// See also: + /// * [`Self::rewrite`] to rewrite owned `TreeNode`s + /// /// Consider the following tree structure: /// ```text /// ParentNode @@ -93,6 +96,9 @@ pub trait TreeNode: Sized { /// Implements the [visitor pattern](https://en.wikipedia.org/wiki/Visitor_pattern) for /// recursively transforming [`TreeNode`]s. /// + /// See also: + /// * [`Self::visit`] for inspecting (without modification) `TreeNode`s + /// /// Consider the following tree structure: /// ```text /// ParentNode @@ -310,13 +316,15 @@ pub trait TreeNode: Sized { } /// Apply the closure `F` to the node's children. + /// + /// See `mutate_children` for rewriting in place fn apply_children Result>( &self, f: F, ) -> Result; - /// Apply transform `F` to the node's children. Note that the transform `F` - /// might have a direction (pre-order or post-order). + /// Apply transform `F` to potentially rewrite the node's children. Note + /// that the transform `F` might have a direction (pre-order or post-order). fn map_children Result>>( self, f: F, diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index 9e48c7b8a6f2..5cf8969aa46d 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -1881,7 +1881,7 @@ impl SessionState { // optimize the child plan, capturing the output of each optimizer let optimized_plan = self.optimizer.optimize( - &analyzed_plan, + analyzed_plan, self, |optimized_plan, optimizer| { let optimizer_name = optimizer.name().to_string(); @@ -1911,7 +1911,7 @@ impl SessionState { let analyzed_plan = self.analyzer .execute_and_check(plan, self.options(), |_, _| {})?; - self.optimizer.optimize(&analyzed_plan, self, |_, _| {}) + self.optimizer.optimize(analyzed_plan, self, |_, _| {}) } } diff --git a/datafusion/core/tests/optimizer_integration.rs b/datafusion/core/tests/optimizer_integration.rs index 60010bdddfb8..6e938361ddb4 100644 --- a/datafusion/core/tests/optimizer_integration.rs +++ b/datafusion/core/tests/optimizer_integration.rs @@ -110,7 +110,7 @@ fn test_sql(sql: &str) -> Result { let optimizer = Optimizer::new(); // analyze and optimize the logical plan let plan = analyzer.execute_and_check(&plan, config.options(), |_, _| {})?; - optimizer.optimize(&plan, &config, |_, _| {}) + optimizer.optimize(plan, &config, |_, _| {}) } #[derive(Default)] diff --git a/datafusion/optimizer/src/decorrelate_predicate_subquery.rs b/datafusion/optimizer/src/decorrelate_predicate_subquery.rs index 019e7507b122..d9fc5a6ce261 100644 --- a/datafusion/optimizer/src/decorrelate_predicate_subquery.rs +++ b/datafusion/optimizer/src/decorrelate_predicate_subquery.rs @@ -338,7 +338,7 @@ mod tests { Operator, }; - fn assert_optimized_plan_equal(plan: &LogicalPlan, expected: &str) -> Result<()> { + fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> { assert_optimized_plan_eq_display_indent( Arc::new(DecorrelatePredicateSubquery::new()), plan, @@ -378,7 +378,7 @@ mod tests { \n SubqueryAlias: __correlated_sq_2 [c:UInt32]\ \n Projection: sq_2.c [c:UInt32]\ \n TableScan: sq_2 [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } /// Test for IN subquery with additional AND filter @@ -404,7 +404,7 @@ mod tests { \n Projection: sq.c [c:UInt32]\ \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } /// Test for IN subquery with additional OR filter @@ -430,7 +430,7 @@ mod tests { \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -458,7 +458,7 @@ mod tests { \n Projection: sq2.c [c:UInt32]\ \n TableScan: sq2 [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } /// Test for nested IN subqueries @@ -487,7 +487,7 @@ mod tests { \n Projection: sq_nested.c [c:UInt32]\ \n TableScan: sq_nested [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } /// Test for filter input modification in case filter not supported @@ -519,7 +519,7 @@ mod tests { \n Projection: sq_inner.c [c:UInt32]\ \n TableScan: sq_inner [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } /// Test multiple correlated subqueries @@ -557,7 +557,7 @@ mod tests { assert_optimized_plan_eq_display_indent( Arc::new(DecorrelatePredicateSubquery::new()), - &plan, + plan, expected, ); Ok(()) @@ -607,7 +607,7 @@ mod tests { assert_optimized_plan_eq_display_indent( Arc::new(DecorrelatePredicateSubquery::new()), - &plan, + plan, expected, ); Ok(()) @@ -642,7 +642,7 @@ mod tests { assert_optimized_plan_eq_display_indent( Arc::new(DecorrelatePredicateSubquery::new()), - &plan, + plan, expected, ); Ok(()) @@ -675,7 +675,7 @@ mod tests { assert_optimized_plan_eq_display_indent( Arc::new(DecorrelatePredicateSubquery::new()), - &plan, + plan, expected, ); Ok(()) @@ -706,7 +706,7 @@ mod tests { assert_optimized_plan_eq_display_indent( Arc::new(DecorrelatePredicateSubquery::new()), - &plan, + plan, expected, ); Ok(()) @@ -739,7 +739,7 @@ mod tests { assert_optimized_plan_eq_display_indent( Arc::new(DecorrelatePredicateSubquery::new()), - &plan, + plan, expected, ); Ok(()) @@ -772,7 +772,7 @@ mod tests { assert_optimized_plan_eq_display_indent( Arc::new(DecorrelatePredicateSubquery::new()), - &plan, + plan, expected, ); Ok(()) @@ -806,7 +806,7 @@ mod tests { assert_optimized_plan_eq_display_indent( Arc::new(DecorrelatePredicateSubquery::new()), - &plan, + plan, expected, ); @@ -863,7 +863,7 @@ mod tests { assert_optimized_plan_eq_display_indent( Arc::new(DecorrelatePredicateSubquery::new()), - &plan, + plan, expected, ); Ok(()) @@ -896,7 +896,7 @@ mod tests { assert_optimized_plan_eq_display_indent( Arc::new(DecorrelatePredicateSubquery::new()), - &plan, + plan, expected, ); Ok(()) @@ -962,7 +962,7 @@ mod tests { assert_optimized_plan_eq_display_indent( Arc::new(DecorrelatePredicateSubquery::new()), - &plan, + plan, expected, ); Ok(()) @@ -1000,7 +1000,7 @@ mod tests { assert_optimized_plan_eq_display_indent( Arc::new(DecorrelatePredicateSubquery::new()), - &plan, + plan, expected, ); Ok(()) @@ -1030,7 +1030,7 @@ mod tests { assert_optimized_plan_eq_display_indent( Arc::new(DecorrelatePredicateSubquery::new()), - &plan, + plan, expected, ); Ok(()) @@ -1054,7 +1054,7 @@ mod tests { assert_optimized_plan_eq_display_indent( Arc::new(DecorrelatePredicateSubquery::new()), - &plan, + plan, expected, ); Ok(()) @@ -1078,7 +1078,7 @@ mod tests { assert_optimized_plan_eq_display_indent( Arc::new(DecorrelatePredicateSubquery::new()), - &plan, + plan, expected, ); Ok(()) @@ -1107,7 +1107,7 @@ mod tests { assert_optimized_plan_eq_display_indent( Arc::new(DecorrelatePredicateSubquery::new()), - &plan, + plan, expected, ); Ok(()) @@ -1142,7 +1142,7 @@ mod tests { assert_optimized_plan_eq_display_indent( Arc::new(DecorrelatePredicateSubquery::new()), - &plan, + plan, expected, ); Ok(()) @@ -1178,7 +1178,7 @@ mod tests { assert_optimized_plan_eq_display_indent( Arc::new(DecorrelatePredicateSubquery::new()), - &plan, + plan, expected, ); Ok(()) @@ -1224,7 +1224,7 @@ mod tests { assert_optimized_plan_eq_display_indent( Arc::new(DecorrelatePredicateSubquery::new()), - &plan, + plan, expected, ); Ok(()) @@ -1255,7 +1255,7 @@ mod tests { assert_optimized_plan_eq_display_indent( Arc::new(DecorrelatePredicateSubquery::new()), - &plan, + plan, expected, ); Ok(()) @@ -1289,7 +1289,7 @@ mod tests { \n SubqueryAlias: __correlated_sq_2 [o_custkey:Int64]\ \n Projection: orders.o_custkey [o_custkey:Int64]\ \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } /// Test recursive correlated subqueries @@ -1332,7 +1332,7 @@ mod tests { \n SubqueryAlias: __correlated_sq_2 [l_orderkey:Int64]\ \n Projection: lineitem.l_orderkey [l_orderkey:Int64]\ \n TableScan: lineitem [l_orderkey:Int64, l_partkey:Int64, l_suppkey:Int64, l_linenumber:Int32, l_quantity:Float64, l_extendedprice:Float64]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } /// Test for correlated exists subquery filter with additional subquery filters @@ -1362,7 +1362,7 @@ mod tests { \n Filter: orders.o_orderkey = Int32(1) [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -1387,7 +1387,7 @@ mod tests { \n Projection: orders.o_custkey [o_custkey:Int64]\ \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } /// Test for exists subquery with both columns in schema @@ -1405,7 +1405,7 @@ mod tests { .project(vec![col("customer.c_custkey")])? .build()?; - assert_optimization_skipped(Arc::new(DecorrelatePredicateSubquery::new()), &plan) + assert_optimization_skipped(Arc::new(DecorrelatePredicateSubquery::new()), plan) } /// Test for correlated exists subquery not equal @@ -1433,7 +1433,7 @@ mod tests { \n Projection: orders.o_custkey [o_custkey:Int64]\ \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } /// Test for correlated exists subquery less than @@ -1461,7 +1461,7 @@ mod tests { \n Projection: orders.o_custkey [o_custkey:Int64]\ \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } /// Test for correlated exists subquery filter with subquery disjunction @@ -1490,7 +1490,7 @@ mod tests { \n Projection: orders.o_custkey, orders.o_orderkey [o_custkey:Int64, o_orderkey:Int64]\ \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } /// Test for correlated exists without projection @@ -1516,7 +1516,7 @@ mod tests { \n SubqueryAlias: __correlated_sq_1 [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } /// Test for correlated exists expressions @@ -1544,7 +1544,7 @@ mod tests { \n Projection: orders.o_custkey + Int32(1), orders.o_custkey [orders.o_custkey + Int32(1):Int64, o_custkey:Int64]\ \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } /// Test for correlated exists subquery filter with additional filters @@ -1572,7 +1572,7 @@ mod tests { \n Projection: orders.o_custkey [o_custkey:Int64]\ \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } /// Test for correlated exists subquery filter with disjustions @@ -1599,7 +1599,7 @@ mod tests { TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] TableScan: customer [c_custkey:Int64, c_name:Utf8]"#; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } /// Test for correlated EXISTS subquery filter @@ -1624,7 +1624,7 @@ mod tests { \n Projection: sq.c, sq.a [c:UInt32, a:UInt32]\ \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } /// Test for single exists subquery filter @@ -1636,7 +1636,7 @@ mod tests { .project(vec![col("test.b")])? .build()?; - assert_optimization_skipped(Arc::new(DecorrelatePredicateSubquery::new()), &plan) + assert_optimization_skipped(Arc::new(DecorrelatePredicateSubquery::new()), plan) } /// Test for single NOT exists subquery filter @@ -1648,7 +1648,7 @@ mod tests { .project(vec![col("test.b")])? .build()?; - assert_optimization_skipped(Arc::new(DecorrelatePredicateSubquery::new()), &plan) + assert_optimization_skipped(Arc::new(DecorrelatePredicateSubquery::new()), plan) } #[test] @@ -1687,7 +1687,7 @@ mod tests { \n Projection: sq2.c, sq2.a [c:UInt32, a:UInt32]\ \n TableScan: sq2 [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -1713,7 +1713,7 @@ mod tests { \n Projection: UInt32(1), sq.a [UInt32(1):UInt32, a:UInt32]\ \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -1739,7 +1739,7 @@ mod tests { \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -1767,7 +1767,7 @@ mod tests { \n Projection: sq.c, sq.a [c:UInt32, a:UInt32]\ \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -1795,7 +1795,7 @@ mod tests { \n Projection: sq.b + sq.c, sq.a [sq.b + sq.c:UInt32, a:UInt32]\ \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -1823,6 +1823,6 @@ mod tests { \n Projection: UInt32(1), sq.c, sq.a [UInt32(1):UInt32, c:UInt32, a:UInt32]\ \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } } diff --git a/datafusion/optimizer/src/eliminate_duplicated_expr.rs b/datafusion/optimizer/src/eliminate_duplicated_expr.rs index 349d4d8878e0..ee44a328f8b3 100644 --- a/datafusion/optimizer/src/eliminate_duplicated_expr.rs +++ b/datafusion/optimizer/src/eliminate_duplicated_expr.rs @@ -116,7 +116,7 @@ mod tests { use datafusion_expr::{col, logical_plan::builder::LogicalPlanBuilder}; use std::sync::Arc; - fn assert_optimized_plan_eq(plan: &LogicalPlan, expected: &str) -> Result<()> { + fn assert_optimized_plan_eq(plan: LogicalPlan, expected: &str) -> Result<()> { crate::test::assert_optimized_plan_eq( Arc::new(EliminateDuplicatedExpr::new()), plan, @@ -134,7 +134,7 @@ mod tests { let expected = "Limit: skip=5, fetch=10\ \n Sort: test.a, test.b, test.c\ \n TableScan: test"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -153,6 +153,6 @@ mod tests { let expected = "Limit: skip=5, fetch=10\ \n Sort: test.a ASC NULLS FIRST, test.b ASC NULLS LAST\ \n TableScan: test"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } } diff --git a/datafusion/optimizer/src/eliminate_filter.rs b/datafusion/optimizer/src/eliminate_filter.rs index 9411dc192beb..2bf5cfa30390 100644 --- a/datafusion/optimizer/src/eliminate_filter.rs +++ b/datafusion/optimizer/src/eliminate_filter.rs @@ -91,7 +91,7 @@ mod tests { use crate::test::*; - fn assert_optimized_plan_equal(plan: &LogicalPlan, expected: &str) -> Result<()> { + fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> { assert_optimized_plan_eq(Arc::new(EliminateFilter::new()), plan, expected) } @@ -107,7 +107,7 @@ mod tests { // No aggregate / scan / limit let expected = "EmptyRelation"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -122,7 +122,7 @@ mod tests { // No aggregate / scan / limit let expected = "EmptyRelation"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -144,7 +144,7 @@ mod tests { \n EmptyRelation\ \n Aggregate: groupBy=[[test.a]], aggr=[[SUM(test.b)]]\ \n TableScan: test"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -159,7 +159,7 @@ mod tests { let expected = "Aggregate: groupBy=[[test.a]], aggr=[[SUM(test.b)]]\ \n TableScan: test"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -182,7 +182,7 @@ mod tests { \n TableScan: test\ \n Aggregate: groupBy=[[test.a]], aggr=[[SUM(test.b)]]\ \n TableScan: test"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -205,6 +205,6 @@ mod tests { // Filter is removed let expected = "Projection: test.a\ \n EmptyRelation"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } } diff --git a/datafusion/optimizer/src/eliminate_join.rs b/datafusion/optimizer/src/eliminate_join.rs index e685229c61b2..caf45dda9896 100644 --- a/datafusion/optimizer/src/eliminate_join.rs +++ b/datafusion/optimizer/src/eliminate_join.rs @@ -83,7 +83,7 @@ mod tests { use datafusion_expr::{logical_plan::builder::LogicalPlanBuilder, Expr, LogicalPlan}; use std::sync::Arc; - fn assert_optimized_plan_equal(plan: &LogicalPlan, expected: &str) -> Result<()> { + fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> { assert_optimized_plan_eq(Arc::new(EliminateJoin::new()), plan, expected) } @@ -98,7 +98,7 @@ mod tests { .build()?; let expected = "EmptyRelation"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -115,6 +115,6 @@ mod tests { CrossJoin:\ \n EmptyRelation\ \n EmptyRelation"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } } diff --git a/datafusion/optimizer/src/eliminate_limit.rs b/datafusion/optimizer/src/eliminate_limit.rs index fb5d0d17b839..39231d784e00 100644 --- a/datafusion/optimizer/src/eliminate_limit.rs +++ b/datafusion/optimizer/src/eliminate_limit.rs @@ -94,24 +94,19 @@ mod tests { use crate::push_down_limit::PushDownLimit; - fn assert_optimized_plan_eq(plan: &LogicalPlan, expected: &str) -> Result<()> { + fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {} + fn assert_optimized_plan_eq(plan: LogicalPlan, expected: &str) -> Result<()> { let optimizer = Optimizer::with_rules(vec![Arc::new(EliminateLimit::new())]); - let optimized_plan = optimizer - .optimize_recursively( - optimizer.rules.first().unwrap(), - plan, - &OptimizerContext::new(), - )? - .unwrap_or_else(|| plan.clone()); + let optimized_plan = + optimizer.optimize(plan, &OptimizerContext::new(), observe)?; let formatted_plan = format!("{optimized_plan:?}"); assert_eq!(formatted_plan, expected); - assert_eq!(plan.schema(), optimized_plan.schema()); Ok(()) } fn assert_optimized_plan_eq_with_pushdown( - plan: &LogicalPlan, + plan: LogicalPlan, expected: &str, ) -> Result<()> { fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {} @@ -125,7 +120,6 @@ mod tests { .expect("failed to optimize plan"); let formatted_plan = format!("{optimized_plan:?}"); assert_eq!(formatted_plan, expected); - assert_eq!(plan.schema(), optimized_plan.schema()); Ok(()) } @@ -138,7 +132,7 @@ mod tests { .build()?; // No aggregate / scan / limit let expected = "EmptyRelation"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -158,7 +152,7 @@ mod tests { \n EmptyRelation\ \n Aggregate: groupBy=[[test.a]], aggr=[[SUM(test.b)]]\ \n TableScan: test"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -172,7 +166,7 @@ mod tests { // No aggregate / scan / limit let expected = "EmptyRelation"; - assert_optimized_plan_eq_with_pushdown(&plan, expected) + assert_optimized_plan_eq_with_pushdown(plan, expected) } #[test] @@ -192,7 +186,7 @@ mod tests { \n Limit: skip=0, fetch=2\ \n Aggregate: groupBy=[[test.a]], aggr=[[SUM(test.b)]]\ \n TableScan: test"; - assert_optimized_plan_eq_with_pushdown(&plan, expected) + assert_optimized_plan_eq_with_pushdown(plan, expected) } #[test] @@ -210,7 +204,7 @@ mod tests { \n Limit: skip=0, fetch=2\ \n Aggregate: groupBy=[[test.a]], aggr=[[SUM(test.b)]]\ \n TableScan: test"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -228,7 +222,7 @@ mod tests { \n Limit: skip=2, fetch=1\ \n Aggregate: groupBy=[[test.a]], aggr=[[SUM(test.b)]]\ \n TableScan: test"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -250,7 +244,7 @@ mod tests { \n Limit: skip=2, fetch=1\ \n TableScan: test\ \n TableScan: test1"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -263,6 +257,6 @@ mod tests { let expected = "Aggregate: groupBy=[[test.a]], aggr=[[SUM(test.b)]]\ \n TableScan: test"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } } diff --git a/datafusion/optimizer/src/eliminate_nested_union.rs b/datafusion/optimizer/src/eliminate_nested_union.rs index 924a0853418c..da2a6a17214e 100644 --- a/datafusion/optimizer/src/eliminate_nested_union.rs +++ b/datafusion/optimizer/src/eliminate_nested_union.rs @@ -114,7 +114,7 @@ mod tests { ]) } - fn assert_optimized_plan_equal(plan: &LogicalPlan, expected: &str) -> Result<()> { + fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> { assert_optimized_plan_eq(Arc::new(EliminateNestedUnion::new()), plan, expected) } @@ -131,7 +131,7 @@ mod tests { Union\ \n TableScan: table\ \n TableScan: table"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -147,7 +147,7 @@ mod tests { \n Union\ \n TableScan: table\ \n TableScan: table"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -167,7 +167,7 @@ mod tests { \n TableScan: table\ \n TableScan: table\ \n TableScan: table"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -188,7 +188,7 @@ mod tests { \n TableScan: table\ \n TableScan: table\ \n TableScan: table"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -210,7 +210,7 @@ mod tests { \n TableScan: table\ \n TableScan: table\ \n TableScan: table"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -230,7 +230,7 @@ mod tests { \n TableScan: table\ \n TableScan: table\ \n TableScan: table"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } // We don't need to use project_with_column_index in logical optimizer, @@ -261,7 +261,7 @@ mod tests { \n TableScan: table\ \n Projection: table.id AS id, table.key, table.value\ \n TableScan: table"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -291,7 +291,7 @@ mod tests { \n TableScan: table\ \n Projection: table.id AS id, table.key, table.value\ \n TableScan: table"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -337,7 +337,7 @@ mod tests { \n TableScan: table_1\ \n Projection: CAST(table_1.id AS Int64) AS id, table_1.key, CAST(table_1.value AS Float64) AS value\ \n TableScan: table_1"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -384,6 +384,6 @@ mod tests { \n TableScan: table_1\ \n Projection: CAST(table_1.id AS Int64) AS id, table_1.key, CAST(table_1.value AS Float64) AS value\ \n TableScan: table_1"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } } diff --git a/datafusion/optimizer/src/eliminate_one_union.rs b/datafusion/optimizer/src/eliminate_one_union.rs index 63c3e789daa6..95a3370ab1b5 100644 --- a/datafusion/optimizer/src/eliminate_one_union.rs +++ b/datafusion/optimizer/src/eliminate_one_union.rs @@ -76,7 +76,7 @@ mod tests { ]) } - fn assert_optimized_plan_equal(plan: &LogicalPlan, expected: &str) -> Result<()> { + fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> { assert_optimized_plan_eq_with_rules( vec![Arc::new(EliminateOneUnion::new())], plan, @@ -97,7 +97,7 @@ mod tests { Union\ \n TableScan: table\ \n TableScan: table"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -113,6 +113,6 @@ mod tests { }); let expected = "TableScan: table"; - assert_optimized_plan_equal(&single_union_plan, expected) + assert_optimized_plan_equal(single_union_plan, expected) } } diff --git a/datafusion/optimizer/src/eliminate_outer_join.rs b/datafusion/optimizer/src/eliminate_outer_join.rs index a004da2bff19..63b8b887bb32 100644 --- a/datafusion/optimizer/src/eliminate_outer_join.rs +++ b/datafusion/optimizer/src/eliminate_outer_join.rs @@ -306,7 +306,7 @@ mod tests { Operator::{And, Or}, }; - fn assert_optimized_plan_equal(plan: &LogicalPlan, expected: &str) -> Result<()> { + fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> { assert_optimized_plan_eq(Arc::new(EliminateOuterJoin::new()), plan, expected) } @@ -330,7 +330,7 @@ mod tests { \n Left Join: t1.a = t2.a\ \n TableScan: t1\ \n TableScan: t2"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -353,7 +353,7 @@ mod tests { \n Inner Join: t1.a = t2.a\ \n TableScan: t1\ \n TableScan: t2"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -380,7 +380,7 @@ mod tests { \n Inner Join: t1.a = t2.a\ \n TableScan: t1\ \n TableScan: t2"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -407,7 +407,7 @@ mod tests { \n Inner Join: t1.a = t2.a\ \n TableScan: t1\ \n TableScan: t2"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -434,6 +434,6 @@ mod tests { \n Inner Join: t1.a = t2.a\ \n TableScan: t1\ \n TableScan: t2"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } } diff --git a/datafusion/optimizer/src/extract_equijoin_predicate.rs b/datafusion/optimizer/src/extract_equijoin_predicate.rs index 4cfcd07b47d9..60b9ba3031a1 100644 --- a/datafusion/optimizer/src/extract_equijoin_predicate.rs +++ b/datafusion/optimizer/src/extract_equijoin_predicate.rs @@ -164,7 +164,7 @@ mod tests { col, lit, logical_plan::builder::LogicalPlanBuilder, JoinType, }; - fn assert_plan_eq(plan: &LogicalPlan, expected: &str) -> Result<()> { + fn assert_plan_eq(plan: LogicalPlan, expected: &str) -> Result<()> { assert_optimized_plan_eq_display_indent( Arc::new(ExtractEquijoinPredicate {}), plan, @@ -186,7 +186,7 @@ mod tests { \n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]\ \n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]"; - assert_plan_eq(&plan, expected) + assert_plan_eq(plan, expected) } #[test] @@ -205,7 +205,7 @@ mod tests { \n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]\ \n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]"; - assert_plan_eq(&plan, expected) + assert_plan_eq(plan, expected) } #[test] @@ -228,7 +228,7 @@ mod tests { \n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]\ \n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]"; - assert_plan_eq(&plan, expected) + assert_plan_eq(plan, expected) } #[test] @@ -255,7 +255,7 @@ mod tests { \n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]\ \n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]"; - assert_plan_eq(&plan, expected) + assert_plan_eq(plan, expected) } #[test] @@ -281,7 +281,7 @@ mod tests { \n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]\ \n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]"; - assert_plan_eq(&plan, expected) + assert_plan_eq(plan, expected) } #[test] @@ -318,7 +318,7 @@ mod tests { \n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]\ \n TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]"; - assert_plan_eq(&plan, expected) + assert_plan_eq(plan, expected) } #[test] @@ -351,7 +351,7 @@ mod tests { \n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]\ \n TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]"; - assert_plan_eq(&plan, expected) + assert_plan_eq(plan, expected) } #[test] @@ -375,6 +375,6 @@ mod tests { \n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]\ \n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]"; - assert_plan_eq(&plan, expected) + assert_plan_eq(plan, expected) } } diff --git a/datafusion/optimizer/src/filter_null_join_keys.rs b/datafusion/optimizer/src/filter_null_join_keys.rs index 16039b182bb2..fcf85327fdb0 100644 --- a/datafusion/optimizer/src/filter_null_join_keys.rs +++ b/datafusion/optimizer/src/filter_null_join_keys.rs @@ -116,7 +116,7 @@ mod tests { use datafusion_expr::logical_plan::table_scan; use datafusion_expr::{col, lit, logical_plan::JoinType, LogicalPlanBuilder}; - fn assert_optimized_plan_equal(plan: &LogicalPlan, expected: &str) -> Result<()> { + fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> { assert_optimized_plan_eq(Arc::new(FilterNullJoinKeys {}), plan, expected) } @@ -128,7 +128,7 @@ mod tests { \n Filter: t1.optional_id IS NOT NULL\ \n TableScan: t1\ \n TableScan: t2"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -139,7 +139,7 @@ mod tests { \n Filter: t1.optional_id IS NOT NULL\ \n TableScan: t1\ \n TableScan: t2"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -176,7 +176,7 @@ mod tests { \n Filter: t1.optional_id IS NOT NULL\ \n TableScan: t1\ \n TableScan: t2"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -197,7 +197,7 @@ mod tests { \n Filter: t1.optional_id + UInt32(1) IS NOT NULL\ \n TableScan: t1\ \n TableScan: t2"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -218,7 +218,7 @@ mod tests { \n TableScan: t1\ \n Filter: t2.optional_id + UInt32(1) IS NOT NULL\ \n TableScan: t2"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -241,7 +241,7 @@ mod tests { \n TableScan: t1\ \n Filter: t2.optional_id + UInt32(1) IS NOT NULL\ \n TableScan: t2"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } fn build_plan( diff --git a/datafusion/optimizer/src/optimize_projections.rs b/datafusion/optimizer/src/optimize_projections.rs index 69905c990a7f..6967b28f3037 100644 --- a/datafusion/optimizer/src/optimize_projections.rs +++ b/datafusion/optimizer/src/optimize_projections.rs @@ -941,7 +941,7 @@ mod tests { UserDefinedLogicalNodeCore, }; - fn assert_optimized_plan_equal(plan: &LogicalPlan, expected: &str) -> Result<()> { + fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> { assert_optimized_plan_eq(Arc::new(OptimizeProjections::new()), plan, expected) } @@ -1090,7 +1090,7 @@ mod tests { let expected = "Projection: Int32(1) + test.a\ \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -1104,7 +1104,7 @@ mod tests { let expected = "Projection: Int32(1) + test.a\ \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -1117,7 +1117,7 @@ mod tests { let expected = "Projection: test.a AS alias\ \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -1130,7 +1130,7 @@ mod tests { let expected = "Projection: test.a AS alias\ \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -1152,7 +1152,7 @@ mod tests { \n Projection: \ \n Aggregate: groupBy=[[]], aggr=[[COUNT(Int32(1))]]\ \n TableScan: ?table? projection=[]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -1175,7 +1175,7 @@ mod tests { .build()?; let expected = "Projection: (?table?.s)[x]\ \n TableScan: ?table? projection=[s]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -1187,7 +1187,7 @@ mod tests { let expected = "Projection: (- test.a)\ \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -1199,7 +1199,7 @@ mod tests { let expected = "Projection: test.a IS NULL\ \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -1211,7 +1211,7 @@ mod tests { let expected = "Projection: test.a IS NOT NULL\ \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -1223,7 +1223,7 @@ mod tests { let expected = "Projection: test.a IS TRUE\ \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -1235,7 +1235,7 @@ mod tests { let expected = "Projection: test.a IS NOT TRUE\ \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -1247,7 +1247,7 @@ mod tests { let expected = "Projection: test.a IS FALSE\ \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -1259,7 +1259,7 @@ mod tests { let expected = "Projection: test.a IS NOT FALSE\ \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -1271,7 +1271,7 @@ mod tests { let expected = "Projection: test.a IS UNKNOWN\ \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -1283,7 +1283,7 @@ mod tests { let expected = "Projection: test.a IS NOT UNKNOWN\ \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -1295,7 +1295,7 @@ mod tests { let expected = "Projection: NOT test.a\ \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -1307,7 +1307,7 @@ mod tests { let expected = "Projection: TRY_CAST(test.a AS Float64)\ \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -1323,7 +1323,7 @@ mod tests { let expected = "Projection: test.a SIMILAR TO Utf8(\"[0-9]\")\ \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -1335,7 +1335,7 @@ mod tests { let expected = "Projection: test.a BETWEEN Int32(1) AND Int32(3)\ \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } // Test outer projection isn't discarded despite the same schema as inner @@ -1356,7 +1356,7 @@ mod tests { let expected = "Projection: test.a, CASE WHEN test.a = Int32(1) THEN Int32(10) ELSE d END AS d\ \n Projection: test.a, Int32(0) AS d\ \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } // Since only column `a` is referred at the output. Scan should only contain projection=[a]. @@ -1377,7 +1377,7 @@ mod tests { let expected = "Projection: test.a, Int32(0) AS d\ \n NoOpUserDefined\ \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } // Only column `a` is referred at the output. However, User defined node itself uses column `b` @@ -1404,7 +1404,7 @@ mod tests { let expected = "Projection: test.a, Int32(0) AS d\ \n NoOpUserDefined\ \n TableScan: test projection=[a, b]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } // Only column `a` is referred at the output. However, User defined node itself uses expression `b+c` @@ -1439,7 +1439,7 @@ mod tests { let expected = "Projection: test.a, Int32(0) AS d\ \n NoOpUserDefined\ \n TableScan: test projection=[a, b, c]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } // Columns `l.a`, `l.c`, `r.a` is referred at the output. @@ -1464,6 +1464,6 @@ mod tests { \n UserDefinedCrossJoin\ \n TableScan: l projection=[a, c]\ \n TableScan: r projection=[a]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } } diff --git a/datafusion/optimizer/src/optimizer.rs b/datafusion/optimizer/src/optimizer.rs index 03ff402c3e3f..032f9c57321c 100644 --- a/datafusion/optimizer/src/optimizer.rs +++ b/datafusion/optimizer/src/optimizer.rs @@ -20,6 +20,16 @@ use std::collections::HashSet; use std::sync::Arc; +use chrono::{DateTime, Utc}; +use log::{debug, warn}; + +use datafusion_common::alias::AliasGenerator; +use datafusion_common::config::ConfigOptions; +use datafusion_common::instant::Instant; +use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter}; +use datafusion_common::{DFSchema, DataFusionError, Result}; +use datafusion_expr::logical_plan::LogicalPlan; + use crate::common_subexpr_eliminate::CommonSubexprEliminate; use crate::decorrelate_predicate_subquery::DecorrelatePredicateSubquery; use crate::eliminate_cross_join::EliminateCrossJoin; @@ -45,15 +55,6 @@ use crate::single_distinct_to_groupby::SingleDistinctToGroupBy; use crate::unwrap_cast_in_comparison::UnwrapCastInComparison; use crate::utils::log_plan; -use datafusion_common::alias::AliasGenerator; -use datafusion_common::config::ConfigOptions; -use datafusion_common::instant::Instant; -use datafusion_common::{DataFusionError, Result}; -use datafusion_expr::logical_plan::LogicalPlan; - -use chrono::{DateTime, Utc}; -use log::{debug, warn}; - /// `OptimizerRule`s transforms one [`LogicalPlan`] into another which /// computes the same results, but in a potentially more efficient /// way. If there are no suitable transformations for the input plan, @@ -184,41 +185,15 @@ pub struct Optimizer { pub rules: Vec>, } -/// If a rule is with `ApplyOrder`, it means the optimizer will derive to handle children instead of -/// recursively handling in rule. -/// We just need handle a subtree pattern itself. -/// -/// Notice: **sometime** result after optimize still can be optimized, we need apply again. +/// Specifies how recursion for an `OptimizerRule` should be handled. /// -/// Usage Example: Merge Limit (subtree pattern is: Limit-Limit) -/// ```rust -/// use datafusion_expr::{Limit, LogicalPlan, LogicalPlanBuilder}; -/// use datafusion_common::Result; -/// fn merge_limit(parent: &Limit, child: &Limit) -> LogicalPlan { -/// // just for run -/// return parent.input.as_ref().clone(); -/// } -/// fn try_optimize(plan: &LogicalPlan) -> Result> { -/// match plan { -/// LogicalPlan::Limit(limit) => match limit.input.as_ref() { -/// LogicalPlan::Limit(child_limit) => { -/// // merge limit ... -/// let optimized_plan = merge_limit(limit, child_limit); -/// // due to optimized_plan may be optimized again, -/// // for example: plan is Limit-Limit-Limit -/// Ok(Some( -/// try_optimize(&optimized_plan)? -/// .unwrap_or_else(|| optimized_plan.clone()), -/// )) -/// } -/// _ => Ok(None), -/// }, -/// _ => Ok(None), -/// } -/// } -/// ``` +/// * `Some(apply_order)`: The Optimizer will recursively apply the rule to the plan. +/// * `None`: the rule must handle any required recursion itself. +#[derive(Debug, Clone, Copy, PartialEq)] pub enum ApplyOrder { + /// Apply the rule to the node before its inputs TopDown, + /// Apply the rule to the node after its inputs BottomUp, } @@ -274,22 +249,78 @@ impl Optimizer { pub fn with_rules(rules: Vec>) -> Self { Self { rules } } +} + +/// Recursively rewrites LogicalPlans +struct Rewriter<'a> { + apply_order: ApplyOrder, + rule: &'a dyn OptimizerRule, + config: &'a dyn OptimizerConfig, +} +impl<'a> Rewriter<'a> { + fn new( + apply_order: ApplyOrder, + rule: &'a dyn OptimizerRule, + config: &'a dyn OptimizerConfig, + ) -> Self { + Self { + apply_order, + rule, + config, + } + } +} + +impl<'a> TreeNodeRewriter for Rewriter<'a> { + type Node = LogicalPlan; + + fn f_down(&mut self, node: LogicalPlan) -> Result> { + if self.apply_order == ApplyOrder::TopDown { + optimize_plan_node(node, self.rule, self.config) + } else { + Ok(Transformed::no(node)) + } + } + + fn f_up(&mut self, node: LogicalPlan) -> Result> { + if self.apply_order == ApplyOrder::BottomUp { + optimize_plan_node(node, self.rule, self.config) + } else { + Ok(Transformed::no(node)) + } + } +} + +/// Invokes the Optimizer rule to rewrite the LogicalPlan in place. +fn optimize_plan_node( + plan: LogicalPlan, + rule: &dyn OptimizerRule, + config: &dyn OptimizerConfig, +) -> Result> { + // TODO: add API to OptimizerRule to allow rewriting by ownership + rule.try_optimize(&plan, config) + .map(|maybe_plan| match maybe_plan { + Some(new_plan) => Transformed::yes(new_plan), + None => Transformed::no(plan), + }) +} + +impl Optimizer { /// Optimizes the logical plan by applying optimizer rules, and /// invoking observer function after each call pub fn optimize( &self, - plan: &LogicalPlan, + plan: LogicalPlan, config: &dyn OptimizerConfig, mut observer: F, ) -> Result where F: FnMut(&LogicalPlan, &dyn OptimizerRule), { - let options = config.options(); - let mut new_plan = plan.clone(); - let start_time = Instant::now(); + let options = config.options(); + let mut new_plan = plan; let mut previous_plans = HashSet::with_capacity(16); previous_plans.insert(LogicalPlanSignature::new(&new_plan)); @@ -299,44 +330,71 @@ impl Optimizer { log_plan(&format!("Optimizer input (pass {i})"), &new_plan); for rule in &self.rules { - let result = - self.optimize_recursively(rule, &new_plan, config) - .and_then(|plan| { - if let Some(plan) = &plan { - assert_schema_is_the_same(rule.name(), plan, &new_plan)?; - } - Ok(plan) - }); - match result { - Ok(Some(plan)) => { - new_plan = plan; - observer(&new_plan, rule.as_ref()); - log_plan(rule.name(), &new_plan); - } - Ok(None) => { + // If skipping failed rules, copy plan before attempting to rewrite + // as rewriting is destructive + let prev_plan = options + .optimizer + .skip_failed_rules + .then(|| new_plan.clone()); + + let starting_schema = new_plan.schema().clone(); + + let result = match rule.apply_order() { + // optimizer handles recursion + Some(apply_order) => new_plan.rewrite(&mut Rewriter::new( + apply_order, + rule.as_ref(), + config, + )), + // rule handles recursion itself + None => optimize_plan_node(new_plan, rule.as_ref(), config), + } + // verify the rule didn't change the schema + .and_then(|tnr| { + assert_schema_is_the_same(rule.name(), &starting_schema, &tnr.data)?; + Ok(tnr) + }); + + // Handle results + match (result, prev_plan) { + // OptimizerRule was successful + ( + Ok(Transformed { + data, transformed, .. + }), + _, + ) => { + new_plan = data; observer(&new_plan, rule.as_ref()); - debug!( - "Plan unchanged by optimizer rule '{}' (pass {})", - rule.name(), - i - ); + if transformed { + log_plan(rule.name(), &new_plan); + } else { + debug!( + "Plan unchanged by optimizer rule '{}' (pass {})", + rule.name(), + i + ); + } } - Err(e) => { - if options.optimizer.skip_failed_rules { - // Note to future readers: if you see this warning it signals a - // bug in the DataFusion optimizer. Please consider filing a ticket - // https://github.com/apache/arrow-datafusion - warn!( + // OptimizerRule was unsuccessful, but skipped failed rules is on + // so use the previous plan + (Err(e), Some(orig_plan)) => { + // Note to future readers: if you see this warning it signals a + // bug in the DataFusion optimizer. Please consider filing a ticket + // https://github.com/apache/arrow-datafusion + warn!( "Skipping optimizer rule '{}' due to unexpected error: {}", rule.name(), e ); - } else { - return Err(DataFusionError::Context( - format!("Optimizer rule '{}' failed", rule.name(),), - Box::new(e), - )); - } + new_plan = orig_plan; + } + // OptimizerRule was unsuccessful, but skipped failed rules is off, return error + (Err(e), None) => { + return Err(e.context(format!( + "Optimizer rule '{}' failed", + rule.name() + ))); } } } @@ -356,97 +414,22 @@ impl Optimizer { debug!("Optimizer took {} ms", start_time.elapsed().as_millis()); Ok(new_plan) } - - fn optimize_node( - &self, - rule: &Arc, - plan: &LogicalPlan, - config: &dyn OptimizerConfig, - ) -> Result> { - // TODO: future feature: We can do Batch optimize - rule.try_optimize(plan, config) - } - - fn optimize_inputs( - &self, - rule: &Arc, - plan: &LogicalPlan, - config: &dyn OptimizerConfig, - ) -> Result> { - let inputs = plan.inputs(); - let result = inputs - .iter() - .map(|sub_plan| self.optimize_recursively(rule, sub_plan, config)) - .collect::>>()?; - if result.is_empty() || result.iter().all(|o| o.is_none()) { - return Ok(None); - } - - let new_inputs = result - .into_iter() - .zip(inputs) - .map(|(new_plan, old_plan)| match new_plan { - Some(plan) => plan, - None => old_plan.clone(), - }) - .collect(); - - let exprs = plan.expressions(); - plan.with_new_exprs(exprs, new_inputs).map(Some) - } - - /// Use a rule to optimize the whole plan. - /// If the rule with `ApplyOrder`, we don't need to recursively handle children in rule. - pub fn optimize_recursively( - &self, - rule: &Arc, - plan: &LogicalPlan, - config: &dyn OptimizerConfig, - ) -> Result> { - match rule.apply_order() { - Some(order) => match order { - ApplyOrder::TopDown => { - let optimize_self_opt = self.optimize_node(rule, plan, config)?; - let optimize_inputs_opt = match &optimize_self_opt { - Some(optimized_plan) => { - self.optimize_inputs(rule, optimized_plan, config)? - } - _ => self.optimize_inputs(rule, plan, config)?, - }; - Ok(optimize_inputs_opt.or(optimize_self_opt)) - } - ApplyOrder::BottomUp => { - let optimize_inputs_opt = self.optimize_inputs(rule, plan, config)?; - let optimize_self_opt = match &optimize_inputs_opt { - Some(optimized_plan) => { - self.optimize_node(rule, optimized_plan, config)? - } - _ => self.optimize_node(rule, plan, config)?, - }; - Ok(optimize_self_opt.or(optimize_inputs_opt)) - } - }, - _ => rule.try_optimize(plan, config), - } - } } -/// Returns an error if plans have different schemas. +/// Returns an error if `new_plan`'s schema is different than `prev_schema` /// /// It ignores metadata and nullability. pub(crate) fn assert_schema_is_the_same( rule_name: &str, - prev_plan: &LogicalPlan, + prev_schema: &DFSchema, new_plan: &LogicalPlan, ) -> Result<()> { - let equivalent = new_plan - .schema() - .equivalent_names_and_types(prev_plan.schema()); + let equivalent = new_plan.schema().equivalent_names_and_types(prev_schema); if !equivalent { let e = DataFusionError::Internal(format!( "Failed due to a difference in schemas, original schema: {:?}, new schema: {:?}", - prev_plan.schema(), + prev_schema, new_plan.schema() )); Err(DataFusionError::Context( @@ -462,14 +445,15 @@ pub(crate) fn assert_schema_is_the_same( mod tests { use std::sync::{Arc, Mutex}; - use super::ApplyOrder; + use datafusion_common::{plan_err, DFSchema, DFSchemaRef, Result}; + use datafusion_expr::logical_plan::EmptyRelation; + use datafusion_expr::{col, lit, LogicalPlan, LogicalPlanBuilder, Projection}; + use crate::optimizer::Optimizer; use crate::test::test_table_scan; use crate::{OptimizerConfig, OptimizerContext, OptimizerRule}; - use datafusion_common::{plan_err, DFSchema, DFSchemaRef, Result}; - use datafusion_expr::logical_plan::EmptyRelation; - use datafusion_expr::{col, lit, LogicalPlan, LogicalPlanBuilder, Projection}; + use super::ApplyOrder; #[test] fn skip_failing_rule() { @@ -479,7 +463,7 @@ mod tests { produce_one_row: false, schema: Arc::new(DFSchema::empty()), }); - opt.optimize(&plan, &config, &observe).unwrap(); + opt.optimize(plan, &config, &observe).unwrap(); } #[test] @@ -490,7 +474,7 @@ mod tests { produce_one_row: false, schema: Arc::new(DFSchema::empty()), }); - let err = opt.optimize(&plan, &config, &observe).unwrap_err(); + let err = opt.optimize(plan, &config, &observe).unwrap_err(); assert_eq!( "Optimizer rule 'bad rule' failed\ncaused by\n\ Error during planning: rule failed", @@ -506,21 +490,27 @@ mod tests { produce_one_row: false, schema: Arc::new(DFSchema::empty()), }); - let err = opt.optimize(&plan, &config, &observe).unwrap_err(); + let err = opt.optimize(plan, &config, &observe).unwrap_err(); assert_eq!( - "Optimizer rule 'get table_scan rule' failed\ncaused by\nget table_scan rule\ncaused by\n\ - Internal error: Failed due to a difference in schemas, original schema: \ - DFSchema { inner: Schema { fields: \ - [Field { name: \"a\", data_type: UInt32, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, \ - Field { name: \"b\", data_type: UInt32, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, \ - Field { name: \"c\", data_type: UInt32, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }], metadata: {} }, \ - field_qualifiers: [Some(Bare { table: \"test\" }), Some(Bare { table: \"test\" }), Some(Bare { table: \"test\" })], \ - functional_dependencies: FunctionalDependencies { deps: [] } }, \ + "Optimizer rule 'get table_scan rule' failed\n\ + caused by\nget table_scan rule\ncaused by\n\ + Internal error: Failed due to a difference in schemas, \ + original schema: DFSchema { inner: Schema { \ + fields: [], \ + metadata: {} }, \ + field_qualifiers: [], \ + functional_dependencies: FunctionalDependencies { deps: [] } \ + }, \ new schema: DFSchema { inner: Schema { \ - fields: [], metadata: {} }, \ - field_qualifiers: [], \ - functional_dependencies: FunctionalDependencies { deps: [] } }.\n\ - This was likely caused by a bug in DataFusion's code and we would welcome that you file an bug report in our issue tracker", + fields: [\ + Field { name: \"a\", data_type: UInt32, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, \ + Field { name: \"b\", data_type: UInt32, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, \ + Field { name: \"c\", data_type: UInt32, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }\ + ], \ + metadata: {} }, \ + field_qualifiers: [Some(Bare { table: \"test\" }), Some(Bare { table: \"test\" }), Some(Bare { table: \"test\" })], \ + functional_dependencies: FunctionalDependencies { deps: [] } }.\n\ + This was likely caused by a bug in DataFusion's code and we would welcome that you file an bug report in our issue tracker", err.strip_backtrace() ); } @@ -533,7 +523,7 @@ mod tests { produce_one_row: false, schema: Arc::new(DFSchema::empty()), }); - opt.optimize(&plan, &config, &observe).unwrap(); + opt.optimize(plan, &config, &observe).unwrap(); } #[test] @@ -554,7 +544,7 @@ mod tests { // optimizing should be ok, but the schema will have changed (no metadata) assert_ne!(plan.schema().as_ref(), input_schema.as_ref()); - let optimized_plan = opt.optimize(&plan, &config, &observe)?; + let optimized_plan = opt.optimize(plan, &config, &observe)?; // metadata was removed assert_eq!(optimized_plan.schema().as_ref(), input_schema.as_ref()); Ok(()) @@ -575,7 +565,7 @@ mod tests { let mut plans: Vec = Vec::new(); let final_plan = - opt.optimize(&initial_plan, &config, |p, _| plans.push(p.clone()))?; + opt.optimize(initial_plan.clone(), &config, |p, _| plans.push(p.clone()))?; // initial_plan is not observed, so we have 3 plans assert_eq!(3, plans.len()); @@ -601,7 +591,7 @@ mod tests { let mut plans: Vec = Vec::new(); let final_plan = - opt.optimize(&initial_plan, &config, |p, _| plans.push(p.clone()))?; + opt.optimize(initial_plan, &config, |p, _| plans.push(p.clone()))?; // initial_plan is not observed, so we have 4 plans assert_eq!(4, plans.len()); diff --git a/datafusion/optimizer/src/propagate_empty_relation.rs b/datafusion/optimizer/src/propagate_empty_relation.rs index 2aca6f93254a..445109bbdf77 100644 --- a/datafusion/optimizer/src/propagate_empty_relation.rs +++ b/datafusion/optimizer/src/propagate_empty_relation.rs @@ -198,12 +198,12 @@ mod tests { use super::*; - fn assert_eq(plan: &LogicalPlan, expected: &str) -> Result<()> { + fn assert_eq(plan: LogicalPlan, expected: &str) -> Result<()> { assert_optimized_plan_eq(Arc::new(PropagateEmptyRelation::new()), plan, expected) } fn assert_together_optimized_plan_eq( - plan: &LogicalPlan, + plan: LogicalPlan, expected: &str, ) -> Result<()> { assert_optimized_plan_eq_with_rules( @@ -226,7 +226,7 @@ mod tests { .build()?; let expected = "EmptyRelation"; - assert_eq(&plan, expected) + assert_eq(plan, expected) } #[test] @@ -249,7 +249,7 @@ mod tests { .build()?; let expected = "EmptyRelation"; - assert_together_optimized_plan_eq(&plan, expected) + assert_together_optimized_plan_eq(plan, expected) } #[test] @@ -262,7 +262,7 @@ mod tests { let plan = LogicalPlanBuilder::from(left).union(right)?.build()?; let expected = "TableScan: test"; - assert_together_optimized_plan_eq(&plan, expected) + assert_together_optimized_plan_eq(plan, expected) } #[test] @@ -287,7 +287,7 @@ mod tests { let expected = "Union\ \n TableScan: test1\ \n TableScan: test4"; - assert_together_optimized_plan_eq(&plan, expected) + assert_together_optimized_plan_eq(plan, expected) } #[test] @@ -312,7 +312,7 @@ mod tests { .build()?; let expected = "EmptyRelation"; - assert_together_optimized_plan_eq(&plan, expected) + assert_together_optimized_plan_eq(plan, expected) } #[test] @@ -339,7 +339,7 @@ mod tests { let expected = "Union\ \n TableScan: test2\ \n TableScan: test3"; - assert_together_optimized_plan_eq(&plan, expected) + assert_together_optimized_plan_eq(plan, expected) } #[test] @@ -352,7 +352,7 @@ mod tests { let plan = LogicalPlanBuilder::from(left).union(right)?.build()?; let expected = "TableScan: test"; - assert_together_optimized_plan_eq(&plan, expected) + assert_together_optimized_plan_eq(plan, expected) } #[test] @@ -367,7 +367,7 @@ mod tests { .build()?; let expected = "EmptyRelation"; - assert_together_optimized_plan_eq(&plan, expected) + assert_together_optimized_plan_eq(plan, expected) } #[test] @@ -400,6 +400,6 @@ mod tests { let expected = "Projection: a, b, c\ \n TableScan: test"; - assert_together_optimized_plan_eq(&plan, expected) + assert_together_optimized_plan_eq(plan, expected) } } diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index f3ce8bbcde72..2b123e3559f5 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -1028,11 +1028,11 @@ fn contain(e: &Expr, check_map: &HashMap) -> bool { #[cfg(test)] mod tests { + use super::*; use std::any::Any; use std::fmt::{Debug, Formatter}; use std::sync::Arc; - use super::*; use crate::optimizer::Optimizer; use crate::rewrite_disjunctive_predicate::RewriteDisjunctivePredicate; use crate::test::*; @@ -1040,6 +1040,7 @@ mod tests { use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion_common::{DFSchema, DFSchemaRef, ScalarValue}; + use datafusion_expr::expr::ScalarFunction; use datafusion_expr::logical_plan::table_scan; use datafusion_expr::{ and, col, in_list, in_subquery, lit, logical_plan::JoinType, or, sum, BinaryExpr, @@ -1049,9 +1050,9 @@ mod tests { }; use async_trait::async_trait; - use datafusion_expr::expr::ScalarFunction; + fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {} - fn assert_optimized_plan_eq(plan: &LogicalPlan, expected: &str) -> Result<()> { + fn assert_optimized_plan_eq(plan: LogicalPlan, expected: &str) -> Result<()> { crate::test::assert_optimized_plan_eq( Arc::new(PushDownFilter::new()), plan, @@ -1060,29 +1061,17 @@ mod tests { } fn assert_optimized_plan_eq_with_rewrite_predicate( - plan: &LogicalPlan, + plan: LogicalPlan, expected: &str, ) -> Result<()> { let optimizer = Optimizer::with_rules(vec![ Arc::new(RewriteDisjunctivePredicate::new()), Arc::new(PushDownFilter::new()), ]); - let mut optimized_plan = optimizer - .optimize_recursively( - optimizer.rules.first().unwrap(), - plan, - &OptimizerContext::new(), - )? - .unwrap_or_else(|| plan.clone()); - optimized_plan = optimizer - .optimize_recursively( - optimizer.rules.get(1).unwrap(), - &optimized_plan, - &OptimizerContext::new(), - )? - .unwrap_or_else(|| plan.clone()); + let optimized_plan = + optimizer.optimize(plan, &OptimizerContext::new(), observe)?; + let formatted_plan = format!("{optimized_plan:?}"); - assert_eq!(plan.schema(), optimized_plan.schema()); assert_eq!(expected, formatted_plan); Ok(()) } @@ -1098,7 +1087,7 @@ mod tests { let expected = "\ Projection: test.a, test.b\ \n TableScan: test, full_filters=[test.a = Int64(1)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -1115,7 +1104,7 @@ mod tests { \n Limit: skip=0, fetch=10\ \n Projection: test.a, test.b\ \n TableScan: test"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -1125,7 +1114,7 @@ mod tests { .filter(lit(0i64).eq(lit(1i64)))? .build()?; let expected = "TableScan: test, full_filters=[Int64(0) = Int64(1)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -1141,7 +1130,7 @@ mod tests { Projection: test.c, test.b\ \n Projection: test.a, test.b, test.c\ \n TableScan: test, full_filters=[test.a = Int64(1)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -1155,7 +1144,7 @@ mod tests { let expected = "\ Aggregate: groupBy=[[test.a]], aggr=[[SUM(test.b) AS total_salary]]\ \n TableScan: test, full_filters=[test.a > Int64(10)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -1168,7 +1157,7 @@ mod tests { let expected = "Filter: test.b > Int64(10)\ \n Aggregate: groupBy=[[test.b + test.a]], aggr=[[SUM(test.a), test.b]]\ \n TableScan: test"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -1180,7 +1169,7 @@ mod tests { let expected = "Aggregate: groupBy=[[test.b + test.a]], aggr=[[SUM(test.a), test.b]]\ \n TableScan: test, full_filters=[test.b + test.a > Int64(10)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -1195,7 +1184,7 @@ mod tests { Filter: b > Int64(10)\ \n Aggregate: groupBy=[[test.a]], aggr=[[SUM(test.b) AS b]]\ \n TableScan: test"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } /// verifies that a filter is pushed to before a projection, the filter expression is correctly re-written @@ -1210,7 +1199,7 @@ mod tests { let expected = "\ Projection: test.a AS b, test.c\ \n TableScan: test, full_filters=[test.a = Int64(1)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } fn add(left: Expr, right: Expr) -> Expr { @@ -1254,7 +1243,7 @@ mod tests { let expected = "\ Projection: test.a * Int32(2) + test.c AS b, test.c\ \n TableScan: test, full_filters=[test.a * Int32(2) + test.c = Int64(1)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } /// verifies that when a filter is pushed to after 2 projections, the filter expression is correctly re-written @@ -1286,7 +1275,7 @@ mod tests { Projection: b * Int32(3) AS a, test.c\ \n Projection: test.a * Int32(2) + test.c AS b, test.c\ \n TableScan: test, full_filters=[(test.a * Int32(2) + test.c) * Int32(3) = Int64(1)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[derive(Debug, PartialEq, Eq, Hash)] @@ -1349,7 +1338,7 @@ mod tests { let expected = "\ NoopPlan\ \n TableScan: test, full_filters=[test.a = Int64(1)]"; - assert_optimized_plan_eq(&plan, expected)?; + assert_optimized_plan_eq(plan, expected)?; let custom_plan = LogicalPlan::Extension(Extension { node: Arc::new(NoopPlan { @@ -1366,7 +1355,7 @@ mod tests { Filter: test.c = Int64(2)\ \n NoopPlan\ \n TableScan: test, full_filters=[test.a = Int64(1)]"; - assert_optimized_plan_eq(&plan, expected)?; + assert_optimized_plan_eq(plan, expected)?; let custom_plan = LogicalPlan::Extension(Extension { node: Arc::new(NoopPlan { @@ -1383,7 +1372,7 @@ mod tests { NoopPlan\ \n TableScan: test, full_filters=[test.a = Int64(1)]\ \n TableScan: test, full_filters=[test.a = Int64(1)]"; - assert_optimized_plan_eq(&plan, expected)?; + assert_optimized_plan_eq(plan, expected)?; let custom_plan = LogicalPlan::Extension(Extension { node: Arc::new(NoopPlan { @@ -1401,7 +1390,7 @@ mod tests { \n NoopPlan\ \n TableScan: test, full_filters=[test.a = Int64(1)]\ \n TableScan: test, full_filters=[test.a = Int64(1)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } /// verifies that when two filters apply after an aggregation that only allows one to be pushed, one is pushed @@ -1434,7 +1423,7 @@ mod tests { \n Aggregate: groupBy=[[b]], aggr=[[SUM(test.c)]]\ \n Projection: test.a AS b, test.c\ \n TableScan: test, full_filters=[test.a > Int64(10)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } /// verifies that when a filter with two predicates is applied after an aggregation that only allows one to be pushed, one is pushed @@ -1468,7 +1457,7 @@ mod tests { \n Aggregate: groupBy=[[b]], aggr=[[SUM(test.c)]]\ \n Projection: test.a AS b, test.c\ \n TableScan: test, full_filters=[test.a > Int64(10)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } /// verifies that when two limits are in place, we jump neither @@ -1490,7 +1479,7 @@ mod tests { \n Limit: skip=0, fetch=20\ \n Projection: test.a, test.b\ \n TableScan: test"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -1505,7 +1494,7 @@ mod tests { let expected = "Union\ \n TableScan: test, full_filters=[test.a = Int64(1)]\ \n TableScan: test2, full_filters=[test2.a = Int64(1)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -1528,7 +1517,7 @@ mod tests { \n SubqueryAlias: test2\ \n Projection: test.a AS b\ \n TableScan: test, full_filters=[test.a = Int64(1)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -1559,7 +1548,7 @@ mod tests { \n Projection: test1.d, test1.e, test1.f\ \n TableScan: test1, full_filters=[test1.d > Int32(2)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -1585,7 +1574,7 @@ mod tests { \n TableScan: test, full_filters=[test.a = Int32(1)]\ \n Projection: test1.a, test1.b, test1.c\ \n TableScan: test1, full_filters=[test1.a > Int32(2)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } /// verifies that filters with the same columns are correctly placed @@ -1619,7 +1608,7 @@ mod tests { \n Projection: test.a\ \n TableScan: test, full_filters=[test.a <= Int64(1)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } /// verifies that filters to be placed on the same depth are ANDed @@ -1649,7 +1638,7 @@ mod tests { \n Limit: skip=0, fetch=1\ \n TableScan: test"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } /// verifies that filters on a plan with user nodes are not lost @@ -1675,7 +1664,7 @@ mod tests { TestUserDefined\ \n TableScan: test, full_filters=[test.a <= Int64(1)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } /// post-on-join predicates on a column common to both sides is pushed to both sides @@ -1713,7 +1702,7 @@ mod tests { \n TableScan: test, full_filters=[test.a <= Int64(1)]\ \n Projection: test2.a\ \n TableScan: test2, full_filters=[test2.a <= Int64(1)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } /// post-using-join predicates on a column common to both sides is pushed to both sides @@ -1750,7 +1739,7 @@ mod tests { \n TableScan: test, full_filters=[test.a <= Int64(1)]\ \n Projection: test2.a\ \n TableScan: test2, full_filters=[test2.a <= Int64(1)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } /// post-join predicates with columns from both sides are converted to join filterss @@ -1792,7 +1781,7 @@ mod tests { \n TableScan: test\ \n Projection: test2.a, test2.b\ \n TableScan: test2"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } /// post-join predicates with columns from one side of a join are pushed only to that side @@ -1834,7 +1823,7 @@ mod tests { \n TableScan: test, full_filters=[test.b <= Int64(1)]\ \n Projection: test2.a, test2.c\ \n TableScan: test2"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } /// post-join predicates on the right side of a left join are not duplicated @@ -1873,7 +1862,7 @@ mod tests { \n TableScan: test\ \n Projection: test2.a\ \n TableScan: test2"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } /// post-join predicates on the left side of a right join are not duplicated @@ -1911,7 +1900,7 @@ mod tests { \n TableScan: test\ \n Projection: test2.a\ \n TableScan: test2"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } /// post-left-join predicate on a column common to both sides is only pushed to the left side @@ -1949,7 +1938,7 @@ mod tests { \n TableScan: test, full_filters=[test.a <= Int64(1)]\ \n Projection: test2.a\ \n TableScan: test2"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } /// post-right-join predicate on a column common to both sides is only pushed to the right side @@ -1987,7 +1976,7 @@ mod tests { \n TableScan: test\ \n Projection: test2.a\ \n TableScan: test2, full_filters=[test2.a <= Int64(1)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } /// single table predicate parts of ON condition should be pushed to both inputs @@ -2030,7 +2019,7 @@ mod tests { \n TableScan: test, full_filters=[test.c > UInt32(1)]\ \n Projection: test2.a, test2.b, test2.c\ \n TableScan: test2, full_filters=[test2.c > UInt32(4)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } /// join filter should be completely removed after pushdown @@ -2072,7 +2061,7 @@ mod tests { \n TableScan: test, full_filters=[test.b > UInt32(1)]\ \n Projection: test2.a, test2.b, test2.c\ \n TableScan: test2, full_filters=[test2.c > UInt32(4)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } /// predicate on join key in filter expression should be pushed down to both inputs @@ -2112,7 +2101,7 @@ mod tests { \n TableScan: test, full_filters=[test.a > UInt32(1)]\ \n Projection: test2.b\ \n TableScan: test2, full_filters=[test2.b > UInt32(1)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } /// single table predicate parts of ON condition should be pushed to right input @@ -2155,7 +2144,7 @@ mod tests { \n TableScan: test\ \n Projection: test2.a, test2.b, test2.c\ \n TableScan: test2, full_filters=[test2.c > UInt32(4)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } /// single table predicate parts of ON condition should be pushed to left input @@ -2198,7 +2187,7 @@ mod tests { \n TableScan: test, full_filters=[test.a > UInt32(1)]\ \n Projection: test2.a, test2.b, test2.c\ \n TableScan: test2"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } /// single table predicate parts of ON condition should not be pushed @@ -2236,7 +2225,7 @@ mod tests { ); let expected = &format!("{plan:?}"); - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } struct PushDownProvider { @@ -2295,7 +2284,7 @@ mod tests { let expected = "\ TableScan: test, full_filters=[a = Int64(1)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -2306,7 +2295,7 @@ mod tests { let expected = "\ Filter: a = Int64(1)\ \n TableScan: test, partial_filters=[a = Int64(1)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -2314,7 +2303,7 @@ mod tests { let plan = table_scan_with_pushdown_provider(TableProviderFilterPushDown::Inexact)?; - let optimised_plan = PushDownFilter::new() + let optimized_plan = PushDownFilter::new() .try_optimize(&plan, &OptimizerContext::new()) .expect("failed to optimize plan") .unwrap(); @@ -2325,7 +2314,7 @@ mod tests { // Optimizing the same plan multiple times should produce the same plan // each time. - assert_optimized_plan_eq(&optimised_plan, expected) + assert_optimized_plan_eq(optimized_plan, expected) } #[test] @@ -2336,7 +2325,7 @@ mod tests { let expected = "\ Filter: a = Int64(1)\ \n TableScan: test"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -2365,7 +2354,7 @@ mod tests { \n Filter: a = Int64(10) AND b > Int64(11)\ \n TableScan: test projection=[a], partial_filters=[a = Int64(10), b > Int64(11)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -2396,7 +2385,7 @@ Projection: a, b "# .trim(); - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -2424,7 +2413,7 @@ Projection: a, b \n TableScan: test, full_filters=[test.a > Int64(10), test.c > Int64(10)]\ "; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -2456,7 +2445,7 @@ Projection: a, b \n TableScan: test, full_filters=[test.a > Int64(10), test.c > Int64(10)]\ "; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -2481,7 +2470,7 @@ Projection: a, b Projection: test.a AS b, test.c AS d\ \n TableScan: test, full_filters=[test.a > Int64(10), test.c > Int64(10)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } /// predicate on join key in filter expression should be pushed down to both inputs @@ -2521,7 +2510,7 @@ Projection: a, b \n TableScan: test, full_filters=[test.a > UInt32(1)]\ \n Projection: test2.b AS d\ \n TableScan: test2, full_filters=[test2.b > UInt32(1)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -2550,7 +2539,7 @@ Projection: a, b Projection: test.a AS b, test.c\ \n TableScan: test, full_filters=[test.a IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -2582,7 +2571,7 @@ Projection: a, b \n Projection: test.a AS b, test.c\ \n TableScan: test, full_filters=[test.a IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -2618,7 +2607,7 @@ Projection: a, b \n Subquery:\ \n Projection: sq.c\ \n TableScan: sq"; - assert_optimized_plan_eq(&plan, expected_after) + assert_optimized_plan_eq(plan, expected_after) } #[test] @@ -2651,7 +2640,7 @@ Projection: a, b \n Projection: Int64(0) AS a\ \n Filter: Int64(0) = Int64(1)\ \n EmptyRelation"; - assert_optimized_plan_eq(&plan, expected_after) + assert_optimized_plan_eq(plan, expected_after) } #[test] @@ -2679,14 +2668,14 @@ Projection: a, b \n TableScan: test, full_filters=[test.b > UInt32(1) OR test.c < UInt32(10)]\ \n Projection: test1.a AS d, test1.a AS e\ \n TableScan: test1"; - assert_optimized_plan_eq_with_rewrite_predicate(&plan, expected)?; + assert_optimized_plan_eq_with_rewrite_predicate(plan.clone(), expected)?; // Originally global state which can help to avoid duplicate Filters been generated and pushed down. // Now the global state is removed. Need to double confirm that avoid duplicate Filters. let optimized_plan = PushDownFilter::new() .try_optimize(&plan, &OptimizerContext::new())? .expect("failed to optimize plan"); - assert_optimized_plan_eq(&optimized_plan, expected) + assert_optimized_plan_eq(optimized_plan, expected) } #[test] @@ -2727,7 +2716,7 @@ Projection: a, b \n TableScan: test1, full_filters=[test1.b > UInt32(1)]\ \n Projection: test2.a, test2.b\ \n TableScan: test2, full_filters=[test2.b > UInt32(2)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -2768,7 +2757,7 @@ Projection: a, b \n TableScan: test1, full_filters=[test1.b > UInt32(1)]\ \n Projection: test2.a, test2.b\ \n TableScan: test2, full_filters=[test2.b > UInt32(2)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -2814,7 +2803,7 @@ Projection: a, b \n TableScan: test1\ \n Projection: test2.a, test2.b\ \n TableScan: test2, full_filters=[test2.b > UInt32(2)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -2859,7 +2848,7 @@ Projection: a, b \n TableScan: test1, full_filters=[test1.b > UInt32(1)]\ \n Projection: test2.a, test2.b\ \n TableScan: test2"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[derive(Debug)] @@ -2919,7 +2908,7 @@ Projection: a, b \n Projection: test1.a, SUM(test1.b), TestScalarUDF() + Int32(1) AS r\ \n Aggregate: groupBy=[[test1.a]], aggr=[[SUM(test1.b)]]\ \n TableScan: test1, full_filters=[test1.a > Int32(5)]"; - assert_optimized_plan_eq(&plan, expected_after) + assert_optimized_plan_eq(plan, expected_after) } #[test] @@ -2965,6 +2954,6 @@ Projection: a, b \n Inner Join: test1.a = test2.a\ \n TableScan: test1\ \n TableScan: test2"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } } diff --git a/datafusion/optimizer/src/push_down_limit.rs b/datafusion/optimizer/src/push_down_limit.rs index cca6c3fd9bd1..6f1d7bf97cfe 100644 --- a/datafusion/optimizer/src/push_down_limit.rs +++ b/datafusion/optimizer/src/push_down_limit.rs @@ -285,7 +285,7 @@ mod test { max, }; - fn assert_optimized_plan_equal(plan: &LogicalPlan, expected: &str) -> Result<()> { + fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> { assert_optimized_plan_eq(Arc::new(PushDownLimit::new()), plan, expected) } @@ -304,7 +304,7 @@ mod test { \n Limit: skip=0, fetch=1000\ \n TableScan: test, fetch=1000"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -322,7 +322,7 @@ mod test { let expected = "Limit: skip=0, fetch=10\ \n TableScan: test, fetch=10"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -339,7 +339,7 @@ mod test { \n Aggregate: groupBy=[[test.a]], aggr=[[MAX(test.b)]]\ \n TableScan: test"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -359,7 +359,7 @@ mod test { \n Limit: skip=0, fetch=1000\ \n TableScan: test, fetch=1000"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -376,7 +376,7 @@ mod test { \n Sort: test.a, fetch=10\ \n TableScan: test"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -393,7 +393,7 @@ mod test { \n Sort: test.a, fetch=15\ \n TableScan: test"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -412,7 +412,7 @@ mod test { \n Limit: skip=0, fetch=1000\ \n TableScan: test, fetch=1000"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -427,7 +427,7 @@ mod test { let expected = "Limit: skip=10, fetch=None\ \n TableScan: test"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -445,7 +445,7 @@ mod test { \n Limit: skip=10, fetch=1000\ \n TableScan: test, fetch=1010"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -462,7 +462,7 @@ mod test { \n Limit: skip=10, fetch=990\ \n TableScan: test, fetch=1000"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -479,7 +479,7 @@ mod test { \n Limit: skip=10, fetch=1000\ \n TableScan: test, fetch=1010"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -495,7 +495,7 @@ mod test { let expected = "Limit: skip=10, fetch=10\ \n TableScan: test, fetch=20"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -512,7 +512,7 @@ mod test { \n Aggregate: groupBy=[[test.a]], aggr=[[MAX(test.b)]]\ \n TableScan: test"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -532,7 +532,7 @@ mod test { \n Limit: skip=0, fetch=1010\ \n TableScan: test, fetch=1010"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -556,7 +556,7 @@ mod test { \n TableScan: test\ \n TableScan: test2"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -580,7 +580,7 @@ mod test { \n TableScan: test\ \n TableScan: test2"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -609,7 +609,7 @@ mod test { \n Projection: test2.a\ \n TableScan: test2"; - assert_optimized_plan_equal(&outer_query, expected) + assert_optimized_plan_equal(outer_query, expected) } #[test] @@ -638,7 +638,7 @@ mod test { \n Projection: test2.a\ \n TableScan: test2"; - assert_optimized_plan_equal(&outer_query, expected) + assert_optimized_plan_equal(outer_query, expected) } #[test] @@ -664,7 +664,7 @@ mod test { \n Limit: skip=0, fetch=1000\ \n TableScan: test2, fetch=1000"; - assert_optimized_plan_equal(&plan, expected)?; + assert_optimized_plan_equal(plan, expected)?; let plan = LogicalPlanBuilder::from(table_scan_1.clone()) .join( @@ -683,7 +683,7 @@ mod test { \n Limit: skip=0, fetch=1000\ \n TableScan: test2, fetch=1000"; - assert_optimized_plan_equal(&plan, expected)?; + assert_optimized_plan_equal(plan, expected)?; let plan = LogicalPlanBuilder::from(table_scan_1.clone()) .join( @@ -702,7 +702,7 @@ mod test { \n Limit: skip=0, fetch=1000\ \n TableScan: test2, fetch=1000"; - assert_optimized_plan_equal(&plan, expected)?; + assert_optimized_plan_equal(plan, expected)?; let plan = LogicalPlanBuilder::from(table_scan_1.clone()) .join( @@ -720,7 +720,7 @@ mod test { \n TableScan: test, fetch=1000\ \n TableScan: test2"; - assert_optimized_plan_equal(&plan, expected)?; + assert_optimized_plan_equal(plan, expected)?; let plan = LogicalPlanBuilder::from(table_scan_1.clone()) .join( @@ -738,7 +738,7 @@ mod test { \n TableScan: test, fetch=1000\ \n TableScan: test2"; - assert_optimized_plan_equal(&plan, expected)?; + assert_optimized_plan_equal(plan, expected)?; let plan = LogicalPlanBuilder::from(table_scan_1.clone()) .join( @@ -756,7 +756,7 @@ mod test { \n Limit: skip=0, fetch=1000\ \n TableScan: test2, fetch=1000"; - assert_optimized_plan_equal(&plan, expected)?; + assert_optimized_plan_equal(plan, expected)?; let plan = LogicalPlanBuilder::from(table_scan_1) .join( @@ -774,7 +774,7 @@ mod test { \n Limit: skip=0, fetch=1000\ \n TableScan: test2, fetch=1000"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -799,7 +799,7 @@ mod test { \n TableScan: test, fetch=1000\ \n TableScan: test2"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -824,7 +824,7 @@ mod test { \n TableScan: test, fetch=1010\ \n TableScan: test2"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -849,7 +849,7 @@ mod test { \n Limit: skip=0, fetch=1000\ \n TableScan: test2, fetch=1000"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -874,7 +874,7 @@ mod test { \n Limit: skip=0, fetch=1010\ \n TableScan: test2, fetch=1010"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -894,7 +894,7 @@ mod test { \n Limit: skip=0, fetch=1000\ \n TableScan: test2, fetch=1000"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -914,7 +914,7 @@ mod test { \n Limit: skip=0, fetch=2000\ \n TableScan: test2, fetch=2000"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -929,7 +929,7 @@ mod test { let expected = "Limit: skip=1000, fetch=0\ \n TableScan: test, fetch=0"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -944,7 +944,7 @@ mod test { let expected = "Limit: skip=1000, fetch=0\ \n TableScan: test, fetch=0"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -961,6 +961,6 @@ mod test { \n Limit: skip=1000, fetch=0\ \n TableScan: test, fetch=0"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } } diff --git a/datafusion/optimizer/src/push_down_projection.rs b/datafusion/optimizer/src/push_down_projection.rs index ae57ed9e5a34..2f578094b3bc 100644 --- a/datafusion/optimizer/src/push_down_projection.rs +++ b/datafusion/optimizer/src/push_down_projection.rs @@ -24,7 +24,7 @@ mod tests { use crate::optimize_projections::OptimizeProjections; use crate::optimizer::Optimizer; use crate::test::*; - use crate::OptimizerContext; + use crate::{OptimizerContext, OptimizerRule}; use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::{Column, DFSchema, Result}; use datafusion_expr::builder::table_scan_with_filters; @@ -48,7 +48,7 @@ mod tests { let expected = "Aggregate: groupBy=[[]], aggr=[[MAX(test.b)]]\ \n TableScan: test projection=[b]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -62,7 +62,7 @@ mod tests { let expected = "Aggregate: groupBy=[[test.c]], aggr=[[MAX(test.b)]]\ \n TableScan: test projection=[b, c]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -78,7 +78,7 @@ mod tests { \n SubqueryAlias: a\ \n TableScan: test projection=[b, c]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -95,7 +95,7 @@ mod tests { \n Filter: test.c > Int32(1)\ \n TableScan: test projection=[b, c]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -120,7 +120,7 @@ mod tests { Aggregate: groupBy=[[]], aggr=[[MAX(m4.tag.one) AS tag.one]]\ \n TableScan: m4 projection=[tag.one]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -134,7 +134,7 @@ mod tests { let expected = "Projection: test.a, test.c, test.b\ \n TableScan: test projection=[a, b, c]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -144,7 +144,7 @@ mod tests { let plan = table_scan(Some("test"), &schema, Some(vec![1, 0, 2]))?.build()?; let expected = "TableScan: test projection=[b, a, c]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -157,7 +157,7 @@ mod tests { let expected = "Projection: test.a, test.b\ \n TableScan: test projection=[b, a]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -170,7 +170,7 @@ mod tests { let expected = "Projection: test.c, test.b, test.a\ \n TableScan: test projection=[a, b, c]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -192,7 +192,7 @@ mod tests { \n Filter: test.c > Int32(1)\ \n Projection: test.c, test.b, test.a\ \n TableScan: test projection=[a, b, c]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -212,7 +212,7 @@ mod tests { \n TableScan: test projection=[a, b]\ \n TableScan: test2 projection=[c1]"; - let optimized_plan = optimize(&plan)?; + let optimized_plan = optimize(plan)?; let formatted_plan = format!("{optimized_plan:?}"); assert_eq!(formatted_plan, expected); @@ -264,7 +264,7 @@ mod tests { \n TableScan: test projection=[a, b]\ \n TableScan: test2 projection=[c1]"; - let optimized_plan = optimize(&plan)?; + let optimized_plan = optimize(plan)?; let formatted_plan = format!("{optimized_plan:?}"); assert_eq!(formatted_plan, expected); @@ -314,7 +314,7 @@ mod tests { \n TableScan: test projection=[a, b]\ \n TableScan: test2 projection=[a]"; - let optimized_plan = optimize(&plan)?; + let optimized_plan = optimize(plan)?; let formatted_plan = format!("{optimized_plan:?}"); assert_eq!(formatted_plan, expected); @@ -358,7 +358,7 @@ mod tests { let expected = "Projection: CAST(test.c AS Float64)\ \n TableScan: test projection=[c]"; - assert_optimized_plan_eq(&projection, expected) + assert_optimized_plan_eq(projection, expected) } #[test] @@ -374,7 +374,7 @@ mod tests { let expected = "TableScan: test projection=[a, b]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -395,7 +395,7 @@ mod tests { let expected = "TableScan: test projection=[a, b]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -415,7 +415,7 @@ mod tests { \n Projection: test.c, test.a\ \n TableScan: test projection=[a, c]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -424,7 +424,7 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan).build()?; // should expand projection to all columns without projection let expected = "TableScan: test projection=[a, b, c]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -435,7 +435,7 @@ mod tests { .build()?; let expected = "Projection: Int64(1), Int64(2)\ \n TableScan: test projection=[]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } /// tests that it removes unused columns in projections @@ -454,14 +454,14 @@ mod tests { assert_fields_eq(&plan, vec!["c", "MAX(test.a)"]); - let plan = optimize(&plan).expect("failed to optimize plan"); + let plan = optimize(plan).expect("failed to optimize plan"); let expected = "\ Aggregate: groupBy=[[test.c]], aggr=[[MAX(test.a)]]\ \n Filter: test.c > Int32(1)\ \n Projection: test.c, test.a\ \n TableScan: test projection=[a, c]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } /// tests that it removes un-needed projections @@ -483,7 +483,7 @@ mod tests { Projection: Int32(1) AS a\ \n TableScan: test projection=[]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -512,7 +512,7 @@ mod tests { Projection: Int32(1) AS a\ \n TableScan: test projection=[], full_filters=[b = Int32(1)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } /// tests that optimizing twice yields same plan @@ -525,9 +525,9 @@ mod tests { .project(vec![lit(1).alias("a")])? .build()?; - let optimized_plan1 = optimize(&plan).expect("failed to optimize plan"); + let optimized_plan1 = optimize(plan).expect("failed to optimize plan"); let optimized_plan2 = - optimize(&optimized_plan1).expect("failed to optimize plan"); + optimize(optimized_plan1.clone()).expect("failed to optimize plan"); let formatted_plan1 = format!("{optimized_plan1:?}"); let formatted_plan2 = format!("{optimized_plan2:?}"); @@ -556,7 +556,7 @@ mod tests { \n Aggregate: groupBy=[[test.a, test.c]], aggr=[[MAX(test.b)]]\ \n TableScan: test projection=[a, b, c]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -582,7 +582,7 @@ mod tests { let expected = "Aggregate: groupBy=[[test.a]], aggr=[[COUNT(test.b), COUNT(test.b) FILTER (WHERE test.c > Int32(42)) AS count2]]\ \n TableScan: test projection=[a, b, c]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -599,7 +599,7 @@ mod tests { \n Distinct:\ \n TableScan: test projection=[a, b]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -638,25 +638,23 @@ mod tests { \n WindowAggr: windowExpr=[[MAX(test.a) PARTITION BY [test.b] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]]\ \n TableScan: test projection=[a, b]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } - fn assert_optimized_plan_eq(plan: &LogicalPlan, expected: &str) -> Result<()> { + fn assert_optimized_plan_eq(plan: LogicalPlan, expected: &str) -> Result<()> { let optimized_plan = optimize(plan).expect("failed to optimize plan"); let formatted_plan = format!("{optimized_plan:?}"); assert_eq!(formatted_plan, expected); Ok(()) } - fn optimize(plan: &LogicalPlan) -> Result { + fn optimize(plan: LogicalPlan) -> Result { let optimizer = Optimizer::with_rules(vec![Arc::new(OptimizeProjections::new())]); - let optimized_plan = optimizer - .optimize_recursively( - optimizer.rules.first().unwrap(), - plan, - &OptimizerContext::new(), - )? - .unwrap_or_else(|| plan.clone()); + let optimized_plan = + optimizer.optimize(plan, &OptimizerContext::new(), observe)?; + Ok(optimized_plan) } + + fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {} } diff --git a/datafusion/optimizer/src/replace_distinct_aggregate.rs b/datafusion/optimizer/src/replace_distinct_aggregate.rs index 752915be69c0..f464506057ff 100644 --- a/datafusion/optimizer/src/replace_distinct_aggregate.rs +++ b/datafusion/optimizer/src/replace_distinct_aggregate.rs @@ -172,7 +172,7 @@ mod tests { assert_optimized_plan_eq( Arc::new(ReplaceDistinctWithAggregate::new()), - &plan, + plan, expected, ) } @@ -195,7 +195,7 @@ mod tests { assert_optimized_plan_eq( Arc::new(ReplaceDistinctWithAggregate::new()), - &plan, + plan, expected, ) } diff --git a/datafusion/optimizer/src/scalar_subquery_to_join.rs b/datafusion/optimizer/src/scalar_subquery_to_join.rs index a2c4eabcaae6..a8999f9c1d3c 100644 --- a/datafusion/optimizer/src/scalar_subquery_to_join.rs +++ b/datafusion/optimizer/src/scalar_subquery_to_join.rs @@ -429,7 +429,7 @@ mod tests { \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; assert_multi_rules_optimized_plan_eq_display_indent( vec![Arc::new(ScalarSubqueryToJoin::new())], - &plan, + plan, expected, ); Ok(()) @@ -485,7 +485,7 @@ mod tests { \n TableScan: lineitem [l_orderkey:Int64, l_partkey:Int64, l_suppkey:Int64, l_linenumber:Int32, l_quantity:Float64, l_extendedprice:Float64]"; assert_multi_rules_optimized_plan_eq_display_indent( vec![Arc::new(ScalarSubqueryToJoin::new())], - &plan, + plan, expected, ); Ok(()) @@ -523,7 +523,7 @@ mod tests { assert_multi_rules_optimized_plan_eq_display_indent( vec![Arc::new(ScalarSubqueryToJoin::new())], - &plan, + plan, expected, ); Ok(()) @@ -559,7 +559,7 @@ mod tests { \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; assert_multi_rules_optimized_plan_eq_display_indent( vec![Arc::new(ScalarSubqueryToJoin::new())], - &plan, + plan, expected, ); Ok(()) @@ -593,7 +593,7 @@ mod tests { assert_multi_rules_optimized_plan_eq_display_indent( vec![Arc::new(ScalarSubqueryToJoin::new())], - &plan, + plan, expected, ); Ok(()) @@ -732,7 +732,7 @@ mod tests { assert_multi_rules_optimized_plan_eq_display_indent( vec![Arc::new(ScalarSubqueryToJoin::new())], - &plan, + plan, expected, ); Ok(()) @@ -798,7 +798,7 @@ mod tests { assert_multi_rules_optimized_plan_eq_display_indent( vec![Arc::new(ScalarSubqueryToJoin::new())], - &plan, + plan, expected, ); Ok(()) @@ -837,7 +837,7 @@ mod tests { assert_multi_rules_optimized_plan_eq_display_indent( vec![Arc::new(ScalarSubqueryToJoin::new())], - &plan, + plan, expected, ); Ok(()) @@ -877,7 +877,7 @@ mod tests { assert_multi_rules_optimized_plan_eq_display_indent( vec![Arc::new(ScalarSubqueryToJoin::new())], - &plan, + plan, expected, ); Ok(()) @@ -910,7 +910,7 @@ mod tests { assert_multi_rules_optimized_plan_eq_display_indent( vec![Arc::new(ScalarSubqueryToJoin::new())], - &plan, + plan, expected, ); Ok(()) @@ -942,7 +942,7 @@ mod tests { assert_multi_rules_optimized_plan_eq_display_indent( vec![Arc::new(ScalarSubqueryToJoin::new())], - &plan, + plan, expected, ); Ok(()) @@ -973,7 +973,7 @@ mod tests { assert_multi_rules_optimized_plan_eq_display_indent( vec![Arc::new(ScalarSubqueryToJoin::new())], - &plan, + plan, expected, ); Ok(()) @@ -1030,7 +1030,7 @@ mod tests { assert_multi_rules_optimized_plan_eq_display_indent( vec![Arc::new(ScalarSubqueryToJoin::new())], - &plan, + plan, expected, ); Ok(()) @@ -1079,7 +1079,7 @@ mod tests { assert_multi_rules_optimized_plan_eq_display_indent( vec![Arc::new(ScalarSubqueryToJoin::new())], - &plan, + plan, expected, ); Ok(()) diff --git a/datafusion/optimizer/src/single_distinct_to_groupby.rs b/datafusion/optimizer/src/single_distinct_to_groupby.rs index 076bf4e24296..602994a9e3e2 100644 --- a/datafusion/optimizer/src/single_distinct_to_groupby.rs +++ b/datafusion/optimizer/src/single_distinct_to_groupby.rs @@ -313,7 +313,7 @@ mod tests { min, sum, AggregateFunction, }; - fn assert_optimized_plan_equal(plan: &LogicalPlan, expected: &str) -> Result<()> { + fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> { assert_optimized_plan_eq_display_indent( Arc::new(SingleDistinctToGroupBy::new()), plan, @@ -335,7 +335,7 @@ mod tests { "Aggregate: groupBy=[[]], aggr=[[MAX(test.b)]] [MAX(test.b):UInt32;N]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -352,7 +352,7 @@ mod tests { \n Aggregate: groupBy=[[test.b AS alias1]], aggr=[[]] [alias1:UInt32]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } // Currently this optimization is disabled for CUBE/ROLLUP/GROUPING SET @@ -373,7 +373,7 @@ mod tests { let expected = "Aggregate: groupBy=[[GROUPING SETS ((test.a), (test.b))]], aggr=[[COUNT(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, COUNT(DISTINCT test.c):Int64;N]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } // Currently this optimization is disabled for CUBE/ROLLUP/GROUPING SET @@ -391,7 +391,7 @@ mod tests { let expected = "Aggregate: groupBy=[[CUBE (test.a, test.b)]], aggr=[[COUNT(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, COUNT(DISTINCT test.c):Int64;N]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } // Currently this optimization is disabled for CUBE/ROLLUP/GROUPING SET @@ -410,7 +410,7 @@ mod tests { let expected = "Aggregate: groupBy=[[ROLLUP (test.a, test.b)]], aggr=[[COUNT(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, COUNT(DISTINCT test.c):Int64;N]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -426,7 +426,7 @@ mod tests { \n Aggregate: groupBy=[[Int32(2) * test.b AS alias1]], aggr=[[]] [alias1:Int32]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -443,7 +443,7 @@ mod tests { \n Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[]] [a:UInt32, alias1:UInt32]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -461,7 +461,7 @@ mod tests { let expected = "Aggregate: groupBy=[[test.a]], aggr=[[COUNT(DISTINCT test.b), COUNT(DISTINCT test.c)]] [a:UInt32, COUNT(DISTINCT test.b):Int64;N, COUNT(DISTINCT test.c):Int64;N]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -490,7 +490,7 @@ mod tests { \n Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[]] [a:UInt32, alias1:UInt32]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -508,7 +508,7 @@ mod tests { let expected = "Aggregate: groupBy=[[test.a]], aggr=[[COUNT(DISTINCT test.b), COUNT(test.c)]] [a:UInt32, COUNT(DISTINCT test.b):Int64;N, COUNT(test.c):Int64;N]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -525,7 +525,7 @@ mod tests { \n Aggregate: groupBy=[[test.a + Int32(1) AS group_alias_0, test.c AS alias1]], aggr=[[]] [group_alias_0:Int32, alias1:UInt32]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -555,7 +555,7 @@ mod tests { \n Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[SUM(test.c) AS alias2]] [a:UInt32, alias1:UInt32, alias2:UInt64;N]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -574,7 +574,7 @@ mod tests { \n Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[SUM(test.c) AS alias2, MAX(test.c) AS alias3]] [a:UInt32, alias1:UInt32, alias2:UInt64;N, alias3:UInt32;N]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -593,7 +593,7 @@ mod tests { \n Aggregate: groupBy=[[test.c, test.b AS alias1]], aggr=[[MIN(test.a) AS alias2]] [c:UInt32, alias1:UInt32, alias2:UInt32;N]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -616,7 +616,7 @@ mod tests { let expected = "Aggregate: groupBy=[[test.c]], aggr=[[SUM(test.a) FILTER (WHERE test.a > Int32(5)), COUNT(DISTINCT test.b)]] [c:UInt32, SUM(test.a) FILTER (WHERE test.a > Int32(5)):UInt64;N, COUNT(DISTINCT test.b):Int64;N]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -639,7 +639,7 @@ mod tests { let expected = "Aggregate: groupBy=[[test.c]], aggr=[[SUM(test.a), COUNT(DISTINCT test.a) FILTER (WHERE test.a > Int32(5))]] [c:UInt32, SUM(test.a):UInt64;N, COUNT(DISTINCT test.a) FILTER (WHERE test.a > Int32(5)):Int64;N]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -662,7 +662,7 @@ mod tests { let expected = "Aggregate: groupBy=[[test.c]], aggr=[[SUM(test.a) ORDER BY [test.a], COUNT(DISTINCT test.b)]] [c:UInt32, SUM(test.a) ORDER BY [test.a]:UInt64;N, COUNT(DISTINCT test.b):Int64;N]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -685,7 +685,7 @@ mod tests { let expected = "Aggregate: groupBy=[[test.c]], aggr=[[SUM(test.a), COUNT(DISTINCT test.a) ORDER BY [test.a]]] [c:UInt32, SUM(test.a):UInt64;N, COUNT(DISTINCT test.a) ORDER BY [test.a]:Int64;N]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -708,6 +708,6 @@ mod tests { let expected = "Aggregate: groupBy=[[test.c]], aggr=[[SUM(test.a), COUNT(DISTINCT test.a) FILTER (WHERE test.a > Int32(5)) ORDER BY [test.a]]] [c:UInt32, SUM(test.a):UInt64;N, COUNT(DISTINCT test.a) FILTER (WHERE test.a > Int32(5)) ORDER BY [test.a]:Int64;N]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } } diff --git a/datafusion/optimizer/src/test/mod.rs b/datafusion/optimizer/src/test/mod.rs index e691fe9a5351..cafda8359aa3 100644 --- a/datafusion/optimizer/src/test/mod.rs +++ b/datafusion/optimizer/src/test/mod.rs @@ -16,7 +16,7 @@ // under the License. use crate::analyzer::{Analyzer, AnalyzerRule}; -use crate::optimizer::{assert_schema_is_the_same, Optimizer}; +use crate::optimizer::Optimizer; use crate::{OptimizerContext, OptimizerRule}; use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::config::ConfigOptions; @@ -150,22 +150,19 @@ pub fn assert_analyzer_check_err( } } } + +fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {} + pub fn assert_optimized_plan_eq( rule: Arc, - plan: &LogicalPlan, + plan: LogicalPlan, expected: &str, ) -> Result<()> { - let optimizer = Optimizer::with_rules(vec![rule.clone()]); - let optimized_plan = optimizer - .optimize_recursively( - optimizer.rules.first().unwrap(), - plan, - &OptimizerContext::new(), - )? - .unwrap_or_else(|| plan.clone()); + // Apply the rule once + let opt_context = OptimizerContext::new().with_max_passes(1); - // Ensure schemas always match after an optimization - assert_schema_is_the_same(rule.name(), plan, &optimized_plan)?; + let optimizer = Optimizer::with_rules(vec![rule.clone()]); + let optimized_plan = optimizer.optimize(plan, &opt_context, observe)?; let formatted_plan = format!("{optimized_plan:?}"); assert_eq!(formatted_plan, expected); @@ -174,7 +171,7 @@ pub fn assert_optimized_plan_eq( pub fn assert_optimized_plan_eq_with_rules( rules: Vec>, - plan: &LogicalPlan, + plan: LogicalPlan, expected: &str, ) -> Result<()> { fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {} @@ -187,58 +184,44 @@ pub fn assert_optimized_plan_eq_with_rules( .expect("failed to optimize plan"); let formatted_plan = format!("{optimized_plan:?}"); assert_eq!(formatted_plan, expected); - assert_eq!(plan.schema(), optimized_plan.schema()); Ok(()) } pub fn assert_optimized_plan_eq_display_indent( rule: Arc, - plan: &LogicalPlan, + plan: LogicalPlan, expected: &str, ) { let optimizer = Optimizer::with_rules(vec![rule]); let optimized_plan = optimizer - .optimize_recursively( - optimizer.rules.first().unwrap(), - plan, - &OptimizerContext::new(), - ) - .expect("failed to optimize plan") - .unwrap_or_else(|| plan.clone()); + .optimize(plan, &OptimizerContext::new(), observe) + .expect("failed to optimize plan"); let formatted_plan = optimized_plan.display_indent_schema().to_string(); assert_eq!(formatted_plan, expected); } pub fn assert_multi_rules_optimized_plan_eq_display_indent( rules: Vec>, - plan: &LogicalPlan, + plan: LogicalPlan, expected: &str, ) { let optimizer = Optimizer::with_rules(rules); - let mut optimized_plan = plan.clone(); - for rule in &optimizer.rules { - optimized_plan = optimizer - .optimize_recursively(rule, &optimized_plan, &OptimizerContext::new()) - .expect("failed to optimize plan") - .unwrap_or_else(|| optimized_plan.clone()); - } + let optimized_plan = optimizer + .optimize(plan, &OptimizerContext::new(), observe) + .expect("failed to optimize plan"); let formatted_plan = optimized_plan.display_indent_schema().to_string(); assert_eq!(formatted_plan, expected); } pub fn assert_optimizer_err( rule: Arc, - plan: &LogicalPlan, + plan: LogicalPlan, expected: &str, ) { let optimizer = Optimizer::with_rules(vec![rule]); - let res = optimizer.optimize_recursively( - optimizer.rules.first().unwrap(), - plan, - &OptimizerContext::new(), - ); + let res = optimizer.optimize(plan, &OptimizerContext::new(), observe); match res { - Ok(plan) => assert_eq!(format!("{}", plan.unwrap().display_indent()), "An error"), + Ok(plan) => assert_eq!(format!("{}", plan.display_indent()), "An error"), Err(ref e) => { let actual = format!("{e}"); if expected.is_empty() || !actual.contains(expected) { @@ -250,16 +233,11 @@ pub fn assert_optimizer_err( pub fn assert_optimization_skipped( rule: Arc, - plan: &LogicalPlan, + plan: LogicalPlan, ) -> Result<()> { let optimizer = Optimizer::with_rules(vec![rule]); - let new_plan = optimizer - .optimize_recursively( - optimizer.rules.first().unwrap(), - plan, - &OptimizerContext::new(), - )? - .unwrap_or_else(|| plan.clone()); + let new_plan = optimizer.optimize(plan.clone(), &OptimizerContext::new(), observe)?; + assert_eq!( format!("{}", plan.display_indent()), format!("{}", new_plan.display_indent()) diff --git a/datafusion/optimizer/tests/optimizer_integration.rs b/datafusion/optimizer/tests/optimizer_integration.rs index 61d2535930b2..01db5e817c56 100644 --- a/datafusion/optimizer/tests/optimizer_integration.rs +++ b/datafusion/optimizer/tests/optimizer_integration.rs @@ -25,7 +25,7 @@ use datafusion_common::{plan_err, Result}; use datafusion_expr::{AggregateUDF, LogicalPlan, ScalarUDF, TableSource, WindowUDF}; use datafusion_optimizer::analyzer::Analyzer; use datafusion_optimizer::optimizer::Optimizer; -use datafusion_optimizer::{OptimizerConfig, OptimizerContext}; +use datafusion_optimizer::{OptimizerConfig, OptimizerContext, OptimizerRule}; use datafusion_sql::planner::{ContextProvider, SqlToRel}; use datafusion_sql::sqlparser::ast::Statement; use datafusion_sql::sqlparser::dialect::GenericDialect; @@ -315,9 +315,11 @@ fn test_sql(sql: &str) -> Result { let optimizer = Optimizer::new(); // analyze and optimize the logical plan let plan = analyzer.execute_and_check(&plan, config.options(), |_, _| {})?; - optimizer.optimize(&plan, &config, |_, _| {}) + optimizer.optimize(plan, &config, observe) } +fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {} + #[derive(Default)] struct MyContextProvider { options: ConfigOptions, diff --git a/datafusion/sqllogictest/test_files/join.slt b/datafusion/sqllogictest/test_files/join.slt index da9b4168e7e0..135ab8075425 100644 --- a/datafusion/sqllogictest/test_files/join.slt +++ b/datafusion/sqllogictest/test_files/join.slt @@ -587,7 +587,7 @@ FROM t1 ---- 11 11 11 -# subsequent inner join +# subsequent inner join query III rowsort SELECT t1.t1_id, t2.t2_id, t3.t3_id FROM t1 From 1ec65a4a4a697d382d64ac2382b8486709dcf680 Mon Sep 17 00:00:00 2001 From: "D.B. Schwartz" <1689014+cisaacson@users.noreply.github.com> Date: Wed, 10 Apr 2024 10:00:47 -0600 Subject: [PATCH 5/5] Further clarification of the supports_filters_pushdown documentation (#9988) * Further refinement of the comment * Add code example example of how to support a filter * Update supports_filters_pushdown example so that it compiles * Add comments to example code in supports_filters_pushdown doc * Change example to use functional style * Fixed several issues with the supports_filters_pushdown doc; still need all required TableProvider impl fns * cargo fmt * Update example so it compiles and add headings * clean * remove to_string() --------- Co-authored-by: Andrew Lamb --- datafusion/core/src/datasource/provider.rs | 77 ++++++++++++++++++++-- 1 file changed, 70 insertions(+), 7 deletions(-) diff --git a/datafusion/core/src/datasource/provider.rs b/datafusion/core/src/datasource/provider.rs index 9aac072ed4e2..100011952b3b 100644 --- a/datafusion/core/src/datasource/provider.rs +++ b/datafusion/core/src/datasource/provider.rs @@ -161,20 +161,83 @@ pub trait TableProvider: Sync + Send { /// Specify if DataFusion should provide filter expressions to the /// TableProvider to apply *during* the scan. /// - /// The return value must have one element for each filter expression passed - /// in. The value of each element indicates if the TableProvider can apply - /// that particular filter during the scan. - /// /// Some TableProviders can evaluate filters more efficiently than the /// `Filter` operator in DataFusion, for example by using an index. /// - /// By default, returns [`Unsupported`] for all filters, meaning no filters - /// will be provided to [`Self::scan`]. If the TableProvider can implement - /// filter pushdown, it should return either [`Exact`] or [`Inexact`]. + /// # Parameters and Return Value + /// + /// The return `Vec` must have one element for each element of the `filters` + /// argument. The value of each element indicates if the TableProvider can + /// apply the corresponding filter during the scan. The position in the return + /// value corresponds to the expression in the `filters` parameter. + /// + /// If the length of the resulting `Vec` does not match the `filters` input + /// an error will be thrown. + /// + /// Each element in the resulting `Vec` is one of the following: + /// * [`Exact`] or [`Inexact`]: The TableProvider can apply the filter + /// during scan + /// * [`Unsupported`]: The TableProvider cannot apply the filter during scan + /// + /// By default, this function returns [`Unsupported`] for all filters, + /// meaning no filters will be provided to [`Self::scan`]. /// /// [`Unsupported`]: TableProviderFilterPushDown::Unsupported /// [`Exact`]: TableProviderFilterPushDown::Exact /// [`Inexact`]: TableProviderFilterPushDown::Inexact + /// # Example + /// + /// ```rust + /// # use std::any::Any; + /// # use std::sync::Arc; + /// # use arrow_schema::SchemaRef; + /// # use async_trait::async_trait; + /// # use datafusion::datasource::TableProvider; + /// # use datafusion::error::{Result, DataFusionError}; + /// # use datafusion::execution::context::SessionState; + /// # use datafusion_expr::{Expr, TableProviderFilterPushDown, TableType}; + /// # use datafusion_physical_plan::ExecutionPlan; + /// // Define a struct that implements the TableProvider trait + /// struct TestDataSource {} + /// + /// #[async_trait] + /// impl TableProvider for TestDataSource { + /// # fn as_any(&self) -> &dyn Any { todo!() } + /// # fn schema(&self) -> SchemaRef { todo!() } + /// # fn table_type(&self) -> TableType { todo!() } + /// # async fn scan(&self, s: &SessionState, p: Option<&Vec>, f: &[Expr], l: Option) -> Result> { + /// todo!() + /// # } + /// // Override the supports_filters_pushdown to evaluate which expressions + /// // to accept as pushdown predicates. + /// fn supports_filters_pushdown(&self, filters: &[&Expr]) -> Result> { + /// // Process each filter + /// let support: Vec<_> = filters.iter().map(|expr| { + /// match expr { + /// // This example only supports a between expr with a single column named "c1". + /// Expr::Between(between_expr) => { + /// between_expr.expr + /// .try_into_col() + /// .map(|column| { + /// if column.name == "c1" { + /// TableProviderFilterPushDown::Exact + /// } else { + /// TableProviderFilterPushDown::Unsupported + /// } + /// }) + /// // If there is no column in the expr set the filter to unsupported. + /// .unwrap_or(TableProviderFilterPushDown::Unsupported) + /// } + /// _ => { + /// // For all other cases return Unsupported. + /// TableProviderFilterPushDown::Unsupported + /// } + /// } + /// }).collect(); + /// Ok(support) + /// } + /// } + /// ``` fn supports_filters_pushdown( &self, filters: &[&Expr],