diff --git a/Cargo.toml b/Cargo.toml
index e1e09d9893b1..c5a1aa9c8ef8 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -101,7 +101,7 @@ object_store = { version = "0.9.1", default-features = false }
 parking_lot = "0.12"
 parquet = { version = "51.0.0", default-features = false, features = ["arrow", "async", "object_store"] }
 rand = "0.8"
-rstest = "0.18.0"
+rstest = "0.19.0"
 serde_json = "1"
 sqlparser = { version = "0.44.0", features = ["visitor"] }
 tempfile = "3"
diff --git a/datafusion-examples/examples/rewrite_expr.rs b/datafusion-examples/examples/rewrite_expr.rs
index 541448ebf149..dcebbb55fb66 100644
--- a/datafusion-examples/examples/rewrite_expr.rs
+++ b/datafusion-examples/examples/rewrite_expr.rs
@@ -59,7 +59,7 @@ pub fn main() -> Result<()> {
     // then run the optimizer with our custom rule
     let optimizer = Optimizer::with_rules(vec![Arc::new(MyOptimizerRule {})]);
-    let optimized_plan = optimizer.optimize(&analyzed_plan, &config, observe)?;
+    let optimized_plan = optimizer.optimize(analyzed_plan, &config, observe)?;
     println!(
         "Optimized Logical Plan:\n\n{}\n",
         optimized_plan.display_indent()
     )
diff --git a/datafusion/common/src/tree_node.rs b/datafusion/common/src/tree_node.rs
index bb268e048d9a..dff22d495958 100644
--- a/datafusion/common/src/tree_node.rs
+++ b/datafusion/common/src/tree_node.rs
@@ -56,6 +56,9 @@ pub trait TreeNode: Sized {
     /// Visit the tree node using the given [`TreeNodeVisitor`], performing a
     /// depth-first walk of the node and its children.
     ///
+    /// See also:
+    /// * [`Self::rewrite`] to rewrite owned `TreeNode`s
+    ///
     /// Consider the following tree structure:
     /// ```text
     /// ParentNode
@@ -93,6 +96,9 @@ pub trait TreeNode: Sized {
     /// Implements the [visitor pattern](https://en.wikipedia.org/wiki/Visitor_pattern) for
     /// recursively transforming [`TreeNode`]s.
     ///
+    /// See also:
+    /// * [`Self::visit`] for inspecting (without modification) `TreeNode`s
+    ///
     /// Consider the following tree structure:
     /// ```text
     /// ParentNode
@@ -310,13 +316,15 @@ pub trait TreeNode: Sized {
     }
 
     /// Apply the closure `F` to the node's children.
+    ///
+    /// See [`Self::map_children`] for rewriting the node's children.
     fn apply_children<F: FnMut(&Self) -> Result<TreeNodeRecursion>>(
         &self,
         f: F,
     ) -> Result<TreeNodeRecursion>;
 
-    /// Apply transform `F` to the node's children. Note that the transform `F`
-    /// might have a direction (pre-order or post-order).
+    /// Apply transform `F` to potentially rewrite the node's children. Note
+    /// that the transform `F` might have a direction (pre-order or post-order).
     fn map_children<F: FnMut(Self) -> Result<Transformed<Self>>>(
         self,
         f: F,
diff --git a/datafusion/core/src/datasource/provider.rs b/datafusion/core/src/datasource/provider.rs
index 9aac072ed4e2..100011952b3b 100644
--- a/datafusion/core/src/datasource/provider.rs
+++ b/datafusion/core/src/datasource/provider.rs
@@ -161,20 +161,83 @@ pub trait TableProvider: Sync + Send {
     /// Specify if DataFusion should provide filter expressions to the
     /// TableProvider to apply *during* the scan.
     ///
-    /// The return value must have one element for each filter expression passed
-    /// in. The value of each element indicates if the TableProvider can apply
-    /// that particular filter during the scan.
-    ///
     /// Some TableProviders can evaluate filters more efficiently than the
     /// `Filter` operator in DataFusion, for example by using an index.
     ///
-    /// By default, returns [`Unsupported`] for all filters, meaning no filters
-    /// will be provided to [`Self::scan`]. If the TableProvider can implement
-    /// filter pushdown, it should return either [`Exact`] or [`Inexact`].
+    /// # Parameters and Return Value
+    ///
+    /// The return `Vec` must have one element for each element of the `filters`
+    /// argument. The value of each element indicates if the TableProvider can
+    /// apply the corresponding filter during the scan. The position in the return
+    /// value corresponds to the expression in the `filters` parameter.
+    ///
+    /// If the length of the resulting `Vec` does not match the `filters` input
+    /// an error will be returned.
+    ///
+    /// Each element in the resulting `Vec` is one of the following:
+    /// * [`Exact`] or [`Inexact`]: The TableProvider can apply the filter
+    ///   during scan
+    /// * [`Unsupported`]: The TableProvider cannot apply the filter during scan
+    ///
+    /// By default, this function returns [`Unsupported`] for all filters,
+    /// meaning no filters will be provided to [`Self::scan`].
     ///
     /// [`Unsupported`]: TableProviderFilterPushDown::Unsupported
     /// [`Exact`]: TableProviderFilterPushDown::Exact
     /// [`Inexact`]: TableProviderFilterPushDown::Inexact
+    /// # Example
+    ///
+    /// ```rust
+    /// # use std::any::Any;
+    /// # use std::sync::Arc;
+    /// # use arrow_schema::SchemaRef;
+    /// # use async_trait::async_trait;
+    /// # use datafusion::datasource::TableProvider;
+    /// # use datafusion::error::{Result, DataFusionError};
+    /// # use datafusion::execution::context::SessionState;
+    /// # use datafusion_expr::{Expr, TableProviderFilterPushDown, TableType};
+    /// # use datafusion_physical_plan::ExecutionPlan;
+    /// // Define a struct that implements the TableProvider trait
+    /// struct TestDataSource {}
+    ///
+    /// #[async_trait]
+    /// impl TableProvider for TestDataSource {
+    /// # fn as_any(&self) -> &dyn Any { todo!() }
+    /// # fn schema(&self) -> SchemaRef { todo!() }
+    /// # fn table_type(&self) -> TableType { todo!() }
+    /// # async fn scan(&self, s: &SessionState, p: Option<&Vec<usize>>, f: &[Expr], l: Option<usize>) -> Result<Arc<dyn ExecutionPlan>> {
+    ///         todo!()
+    /// # }
+    ///     // Override the supports_filters_pushdown to evaluate which expressions
+    ///     // to accept as pushdown predicates.
+    ///     fn supports_filters_pushdown(&self, filters: &[&Expr]) -> Result<Vec<TableProviderFilterPushDown>> {
+    ///         // Process each filter
+    ///         let support: Vec<_> = filters.iter().map(|expr| {
+    ///             match expr {
+    ///                 // This example only supports a between expr with a single column named "c1".
+    ///                 Expr::Between(between_expr) => {
+    ///                     between_expr.expr
+    ///                         .try_into_col()
+    ///                         .map(|column| {
+    ///                             if column.name == "c1" {
+    ///                                 TableProviderFilterPushDown::Exact
+    ///                             } else {
+    ///                                 TableProviderFilterPushDown::Unsupported
+    ///                             }
+    ///                         })
+    ///                         // If there is no column in the expr set the filter to unsupported.
+    ///                         .unwrap_or(TableProviderFilterPushDown::Unsupported)
+    ///                 }
+    ///                 _ => {
+    ///                     // For all other cases return Unsupported.
+ /// TableProviderFilterPushDown::Unsupported + /// } + /// } + /// }).collect(); + /// Ok(support) + /// } + /// } + /// ``` fn supports_filters_pushdown( &self, filters: &[&Expr], diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index 9e48c7b8a6f2..5cf8969aa46d 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -1881,7 +1881,7 @@ impl SessionState { // optimize the child plan, capturing the output of each optimizer let optimized_plan = self.optimizer.optimize( - &analyzed_plan, + analyzed_plan, self, |optimized_plan, optimizer| { let optimizer_name = optimizer.name().to_string(); @@ -1911,7 +1911,7 @@ impl SessionState { let analyzed_plan = self.analyzer .execute_and_check(plan, self.options(), |_, _| {})?; - self.optimizer.optimize(&analyzed_plan, self, |_, _| {}) + self.optimizer.optimize(analyzed_plan, self, |_, _| {}) } } diff --git a/datafusion/core/src/lib.rs b/datafusion/core/src/lib.rs index c213f4554fb8..b0e2b6fa9c09 100644 --- a/datafusion/core/src/lib.rs +++ b/datafusion/core/src/lib.rs @@ -296,11 +296,15 @@ //! A [`LogicalPlan`] is a Directed Acyclic Graph (DAG) of other //! [`LogicalPlan`]s, each potentially containing embedded [`Expr`]s. //! -//! [`Expr`]s can be rewritten using the [`TreeNode`] API and simplified using -//! [`ExprSimplifier`]. Examples of working with and executing `Expr`s can be found in the -//! [`expr_api`.rs] example +//! `LogicalPlan`s can be rewritten with [`TreeNode`] API, see the +//! [`tree_node module`] for more details. +//! +//! [`Expr`]s can also be rewritten with [`TreeNode`] API and simplified using +//! [`ExprSimplifier`]. Examples of working with and executing `Expr`s can be +//! found in the [`expr_api`.rs] example //! //! [`TreeNode`]: datafusion_common::tree_node::TreeNode +//! [`tree_node module`]: datafusion_expr::logical_plan::tree_node //! [`ExprSimplifier`]: crate::optimizer::simplify_expressions::ExprSimplifier //! [`expr_api`.rs]: https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/expr_api.rs //! diff --git a/datafusion/core/tests/optimizer_integration.rs b/datafusion/core/tests/optimizer_integration.rs index 60010bdddfb8..6e938361ddb4 100644 --- a/datafusion/core/tests/optimizer_integration.rs +++ b/datafusion/core/tests/optimizer_integration.rs @@ -110,7 +110,7 @@ fn test_sql(sql: &str) -> Result { let optimizer = Optimizer::new(); // analyze and optimize the logical plan let plan = analyzer.execute_and_check(&plan, config.options(), |_, _| {})?; - optimizer.optimize(&plan, &config, |_, _| {}) + optimizer.optimize(plan, &config, |_, _| {}) } #[derive(Default)] diff --git a/datafusion/expr/src/logical_plan/mod.rs b/datafusion/expr/src/logical_plan/mod.rs index a1fe7a6f0a51..034440643e51 100644 --- a/datafusion/expr/src/logical_plan/mod.rs +++ b/datafusion/expr/src/logical_plan/mod.rs @@ -22,7 +22,7 @@ pub mod dml; mod extension; mod plan; mod statement; -mod tree_node; +pub mod tree_node; pub use builder::{ build_join_schema, table_scan, union, wrap_projection_for_join_if_necessary, diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 02d65973a50b..7bad034a11ea 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -68,6 +68,11 @@ pub use datafusion_common::{JoinConstraint, JoinType}; /// an output relation (table) with a (potentially) different /// schema. 
A plan represents a dataflow tree where data flows
/// from leaves up to the root to produce the query result.
+///
+/// # See also:
+/// * [`tree_node`]: visiting and rewriting API
+///
+/// [`tree_node`]: crate::logical_plan::tree_node
 #[derive(Clone, PartialEq, Eq, Hash)]
 pub enum LogicalPlan {
     /// Evaluates an arbitrary list of expressions (essentially a
@@ -238,7 +243,10 @@ impl LogicalPlan {
     }
 
     /// Returns all expressions (non-recursively) evaluated by the current
-    /// logical plan node. This does not include expressions in any children
+    /// logical plan node. This does not include expressions in any children.
+    ///
+    /// Note this method `clone`s all the expressions. When possible, the
+    /// [`tree_node`] API should be used instead of this API.
     ///
     /// The returned expressions do not necessarily represent or even
     /// contributed to the output schema of this node. For example,
@@ -248,6 +256,8 @@ impl LogicalPlan {
     /// The expressions do contain all the columns that are used by this plan,
     /// so if there are columns not referenced by these expressions then
     /// DataFusion's optimizer attempts to optimize them away.
+    ///
+    /// [`tree_node`]: crate::logical_plan::tree_node
     pub fn expressions(self: &LogicalPlan) -> Vec<Expr> {
         let mut exprs = vec![];
         self.apply_expressions(|e| {
@@ -773,10 +783,16 @@ impl LogicalPlan {
     /// Returns a new `LogicalPlan` based on `self` with inputs and
     /// expressions replaced.
     ///
+    /// Note this method creates an entirely new node, which requires a large
+    /// amount of cloning. When possible, the [`tree_node`] API should be used
+    /// instead of this API.
+    ///
     /// The exprs correspond to the same order of expressions returned
     /// by [`Self::expressions`]. This function is used by optimizers
     /// to rewrite plans using the following pattern:
     ///
+    /// [`tree_node`]: crate::logical_plan::tree_node
+    ///
     /// ```text
     /// let new_inputs = optimize_children(..., plan, props);
     ///
@@ -1367,6 +1383,7 @@ macro_rules! handle_transform_recursion_up {
 }
 
 impl LogicalPlan {
+    /// Visits a plan similarly to [`Self::visit`], but including embedded subqueries.
     pub fn visit_with_subqueries<V: TreeNodeVisitor<Node = Self>>(
         &self,
         visitor: &mut V,
@@ -1380,6 +1397,7 @@ impl LogicalPlan {
             .visit_parent(|| visitor.f_up(self))
     }
 
+    /// Rewrites a plan similarly to [`Self::rewrite`], but including embedded subqueries.
     pub fn rewrite_with_subqueries<R: TreeNodeRewriter<Node = Self>>(
         self,
         rewriter: &mut R,
diff --git a/datafusion/expr/src/logical_plan/tree_node.rs b/datafusion/expr/src/logical_plan/tree_node.rs
index ce26cac7970b..415343f88685 100644
--- a/datafusion/expr/src/logical_plan/tree_node.rs
+++ b/datafusion/expr/src/logical_plan/tree_node.rs
@@ -15,17 +15,35 @@
 // specific language governing permissions and limitations
 // under the License.
 
-//! Tree node implementation for logical plan
-
+//! [`TreeNode`] based visiting and rewriting for [`LogicalPlan`]s
+//!
+//! Visiting (read only) APIs
+//! * [`LogicalPlan::visit`]: recursively visit the node and all of its inputs
+//! * [`LogicalPlan::visit_with_subqueries`]: recursively visit the node and all of its inputs, including subqueries
+//! * [`LogicalPlan::apply_children`]: recursively visit all inputs of this node
+//! * [`LogicalPlan::apply_expressions`]: (non recursively) visit all expressions of this node
+//! * [`LogicalPlan::apply_subqueries`]: (non recursively) visit all subqueries of this node
+//! * [`LogicalPlan::apply_with_subqueries`]: recursively visit all inputs and embedded subqueries.
+//!
+//! Rewriting (update) APIs:
+//!
* [`LogicalPlan::exists`]: search for an expression in a plan +//! * [`LogicalPlan::rewrite`]: recursively rewrite the node and all of its inputs +//! * [`LogicalPlan::map_children`]: recursively rewrite all inputs of this node +//! * [`LogicalPlan::map_expressions`]: (non recursively) visit all expressions of this node +//! * [`LogicalPlan::map_subqueries`]: (non recursively) rewrite all subqueries of this node +//! * [`LogicalPlan::rewrite_with_subqueries`]: recursively rewrite the node and all of its inputs, including subqueries +//! +//! (Re)creation APIs (these require substantial cloning and thus are slow): +//! * [`LogicalPlan::with_new_exprs`]: Create a new plan with different expressions +//! * [`LogicalPlan::expressions`]: Return a copy of the plan's expressions use crate::{ - Aggregate, Analyze, CreateMemoryTable, CreateView, CrossJoin, DdlStatement, Distinct, - DistinctOn, DmlStatement, Explain, Extension, Filter, Join, Limit, LogicalPlan, - Prepare, Projection, RecursiveQuery, Repartition, Sort, Subquery, SubqueryAlias, - Union, Unnest, Window, + dml::CopyTo, Aggregate, Analyze, CreateMemoryTable, CreateView, CrossJoin, + DdlStatement, Distinct, DistinctOn, DmlStatement, Explain, Extension, Filter, Join, + Limit, LogicalPlan, Prepare, Projection, RecursiveQuery, Repartition, Sort, Subquery, + SubqueryAlias, Union, Unnest, Window, }; use std::sync::Arc; -use crate::dml::CopyTo; use datafusion_common::tree_node::{ Transformed, TreeNode, TreeNodeIterator, TreeNodeRecursion, }; diff --git a/datafusion/optimizer/src/decorrelate_predicate_subquery.rs b/datafusion/optimizer/src/decorrelate_predicate_subquery.rs index 019e7507b122..d9fc5a6ce261 100644 --- a/datafusion/optimizer/src/decorrelate_predicate_subquery.rs +++ b/datafusion/optimizer/src/decorrelate_predicate_subquery.rs @@ -338,7 +338,7 @@ mod tests { Operator, }; - fn assert_optimized_plan_equal(plan: &LogicalPlan, expected: &str) -> Result<()> { + fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> { assert_optimized_plan_eq_display_indent( Arc::new(DecorrelatePredicateSubquery::new()), plan, @@ -378,7 +378,7 @@ mod tests { \n SubqueryAlias: __correlated_sq_2 [c:UInt32]\ \n Projection: sq_2.c [c:UInt32]\ \n TableScan: sq_2 [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } /// Test for IN subquery with additional AND filter @@ -404,7 +404,7 @@ mod tests { \n Projection: sq.c [c:UInt32]\ \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } /// Test for IN subquery with additional OR filter @@ -430,7 +430,7 @@ mod tests { \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -458,7 +458,7 @@ mod tests { \n Projection: sq2.c [c:UInt32]\ \n TableScan: sq2 [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } /// Test for nested IN subqueries @@ -487,7 +487,7 @@ mod tests { \n Projection: sq_nested.c [c:UInt32]\ \n TableScan: sq_nested [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } /// Test for filter input modification in case filter not supported @@ -519,7 +519,7 @@ mod tests { \n Projection: sq_inner.c [c:UInt32]\ \n TableScan: 
sq_inner [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } /// Test multiple correlated subqueries @@ -557,7 +557,7 @@ mod tests { assert_optimized_plan_eq_display_indent( Arc::new(DecorrelatePredicateSubquery::new()), - &plan, + plan, expected, ); Ok(()) @@ -607,7 +607,7 @@ mod tests { assert_optimized_plan_eq_display_indent( Arc::new(DecorrelatePredicateSubquery::new()), - &plan, + plan, expected, ); Ok(()) @@ -642,7 +642,7 @@ mod tests { assert_optimized_plan_eq_display_indent( Arc::new(DecorrelatePredicateSubquery::new()), - &plan, + plan, expected, ); Ok(()) @@ -675,7 +675,7 @@ mod tests { assert_optimized_plan_eq_display_indent( Arc::new(DecorrelatePredicateSubquery::new()), - &plan, + plan, expected, ); Ok(()) @@ -706,7 +706,7 @@ mod tests { assert_optimized_plan_eq_display_indent( Arc::new(DecorrelatePredicateSubquery::new()), - &plan, + plan, expected, ); Ok(()) @@ -739,7 +739,7 @@ mod tests { assert_optimized_plan_eq_display_indent( Arc::new(DecorrelatePredicateSubquery::new()), - &plan, + plan, expected, ); Ok(()) @@ -772,7 +772,7 @@ mod tests { assert_optimized_plan_eq_display_indent( Arc::new(DecorrelatePredicateSubquery::new()), - &plan, + plan, expected, ); Ok(()) @@ -806,7 +806,7 @@ mod tests { assert_optimized_plan_eq_display_indent( Arc::new(DecorrelatePredicateSubquery::new()), - &plan, + plan, expected, ); @@ -863,7 +863,7 @@ mod tests { assert_optimized_plan_eq_display_indent( Arc::new(DecorrelatePredicateSubquery::new()), - &plan, + plan, expected, ); Ok(()) @@ -896,7 +896,7 @@ mod tests { assert_optimized_plan_eq_display_indent( Arc::new(DecorrelatePredicateSubquery::new()), - &plan, + plan, expected, ); Ok(()) @@ -962,7 +962,7 @@ mod tests { assert_optimized_plan_eq_display_indent( Arc::new(DecorrelatePredicateSubquery::new()), - &plan, + plan, expected, ); Ok(()) @@ -1000,7 +1000,7 @@ mod tests { assert_optimized_plan_eq_display_indent( Arc::new(DecorrelatePredicateSubquery::new()), - &plan, + plan, expected, ); Ok(()) @@ -1030,7 +1030,7 @@ mod tests { assert_optimized_plan_eq_display_indent( Arc::new(DecorrelatePredicateSubquery::new()), - &plan, + plan, expected, ); Ok(()) @@ -1054,7 +1054,7 @@ mod tests { assert_optimized_plan_eq_display_indent( Arc::new(DecorrelatePredicateSubquery::new()), - &plan, + plan, expected, ); Ok(()) @@ -1078,7 +1078,7 @@ mod tests { assert_optimized_plan_eq_display_indent( Arc::new(DecorrelatePredicateSubquery::new()), - &plan, + plan, expected, ); Ok(()) @@ -1107,7 +1107,7 @@ mod tests { assert_optimized_plan_eq_display_indent( Arc::new(DecorrelatePredicateSubquery::new()), - &plan, + plan, expected, ); Ok(()) @@ -1142,7 +1142,7 @@ mod tests { assert_optimized_plan_eq_display_indent( Arc::new(DecorrelatePredicateSubquery::new()), - &plan, + plan, expected, ); Ok(()) @@ -1178,7 +1178,7 @@ mod tests { assert_optimized_plan_eq_display_indent( Arc::new(DecorrelatePredicateSubquery::new()), - &plan, + plan, expected, ); Ok(()) @@ -1224,7 +1224,7 @@ mod tests { assert_optimized_plan_eq_display_indent( Arc::new(DecorrelatePredicateSubquery::new()), - &plan, + plan, expected, ); Ok(()) @@ -1255,7 +1255,7 @@ mod tests { assert_optimized_plan_eq_display_indent( Arc::new(DecorrelatePredicateSubquery::new()), - &plan, + plan, expected, ); Ok(()) @@ -1289,7 +1289,7 @@ mod tests { \n SubqueryAlias: __correlated_sq_2 [o_custkey:Int64]\ \n Projection: orders.o_custkey [o_custkey:Int64]\ \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, 
o_totalprice:Float64;N]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } /// Test recursive correlated subqueries @@ -1332,7 +1332,7 @@ mod tests { \n SubqueryAlias: __correlated_sq_2 [l_orderkey:Int64]\ \n Projection: lineitem.l_orderkey [l_orderkey:Int64]\ \n TableScan: lineitem [l_orderkey:Int64, l_partkey:Int64, l_suppkey:Int64, l_linenumber:Int32, l_quantity:Float64, l_extendedprice:Float64]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } /// Test for correlated exists subquery filter with additional subquery filters @@ -1362,7 +1362,7 @@ mod tests { \n Filter: orders.o_orderkey = Int32(1) [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -1387,7 +1387,7 @@ mod tests { \n Projection: orders.o_custkey [o_custkey:Int64]\ \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } /// Test for exists subquery with both columns in schema @@ -1405,7 +1405,7 @@ mod tests { .project(vec![col("customer.c_custkey")])? .build()?; - assert_optimization_skipped(Arc::new(DecorrelatePredicateSubquery::new()), &plan) + assert_optimization_skipped(Arc::new(DecorrelatePredicateSubquery::new()), plan) } /// Test for correlated exists subquery not equal @@ -1433,7 +1433,7 @@ mod tests { \n Projection: orders.o_custkey [o_custkey:Int64]\ \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } /// Test for correlated exists subquery less than @@ -1461,7 +1461,7 @@ mod tests { \n Projection: orders.o_custkey [o_custkey:Int64]\ \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } /// Test for correlated exists subquery filter with subquery disjunction @@ -1490,7 +1490,7 @@ mod tests { \n Projection: orders.o_custkey, orders.o_orderkey [o_custkey:Int64, o_orderkey:Int64]\ \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } /// Test for correlated exists without projection @@ -1516,7 +1516,7 @@ mod tests { \n SubqueryAlias: __correlated_sq_1 [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } /// Test for correlated exists expressions @@ -1544,7 +1544,7 @@ mod tests { \n Projection: orders.o_custkey + Int32(1), orders.o_custkey [orders.o_custkey + Int32(1):Int64, o_custkey:Int64]\ \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } /// Test for correlated exists subquery filter with additional filters @@ -1572,7 +1572,7 @@ mod tests { \n Projection: 
orders.o_custkey [o_custkey:Int64]\ \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } /// Test for correlated exists subquery filter with disjustions @@ -1599,7 +1599,7 @@ mod tests { TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] TableScan: customer [c_custkey:Int64, c_name:Utf8]"#; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } /// Test for correlated EXISTS subquery filter @@ -1624,7 +1624,7 @@ mod tests { \n Projection: sq.c, sq.a [c:UInt32, a:UInt32]\ \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } /// Test for single exists subquery filter @@ -1636,7 +1636,7 @@ mod tests { .project(vec![col("test.b")])? .build()?; - assert_optimization_skipped(Arc::new(DecorrelatePredicateSubquery::new()), &plan) + assert_optimization_skipped(Arc::new(DecorrelatePredicateSubquery::new()), plan) } /// Test for single NOT exists subquery filter @@ -1648,7 +1648,7 @@ mod tests { .project(vec![col("test.b")])? .build()?; - assert_optimization_skipped(Arc::new(DecorrelatePredicateSubquery::new()), &plan) + assert_optimization_skipped(Arc::new(DecorrelatePredicateSubquery::new()), plan) } #[test] @@ -1687,7 +1687,7 @@ mod tests { \n Projection: sq2.c, sq2.a [c:UInt32, a:UInt32]\ \n TableScan: sq2 [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -1713,7 +1713,7 @@ mod tests { \n Projection: UInt32(1), sq.a [UInt32(1):UInt32, a:UInt32]\ \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -1739,7 +1739,7 @@ mod tests { \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -1767,7 +1767,7 @@ mod tests { \n Projection: sq.c, sq.a [c:UInt32, a:UInt32]\ \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -1795,7 +1795,7 @@ mod tests { \n Projection: sq.b + sq.c, sq.a [sq.b + sq.c:UInt32, a:UInt32]\ \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -1823,6 +1823,6 @@ mod tests { \n Projection: UInt32(1), sq.c, sq.a [UInt32(1):UInt32, c:UInt32, a:UInt32]\ \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } } diff --git a/datafusion/optimizer/src/eliminate_duplicated_expr.rs b/datafusion/optimizer/src/eliminate_duplicated_expr.rs index 349d4d8878e0..ee44a328f8b3 100644 --- a/datafusion/optimizer/src/eliminate_duplicated_expr.rs +++ b/datafusion/optimizer/src/eliminate_duplicated_expr.rs @@ -116,7 +116,7 @@ mod tests { use datafusion_expr::{col, logical_plan::builder::LogicalPlanBuilder}; use std::sync::Arc; - fn assert_optimized_plan_eq(plan: &LogicalPlan, expected: &str) -> Result<()> { + fn assert_optimized_plan_eq(plan: LogicalPlan, expected: &str) -> Result<()> { crate::test::assert_optimized_plan_eq( 
Arc::new(EliminateDuplicatedExpr::new()), plan, @@ -134,7 +134,7 @@ mod tests { let expected = "Limit: skip=5, fetch=10\ \n Sort: test.a, test.b, test.c\ \n TableScan: test"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -153,6 +153,6 @@ mod tests { let expected = "Limit: skip=5, fetch=10\ \n Sort: test.a ASC NULLS FIRST, test.b ASC NULLS LAST\ \n TableScan: test"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } } diff --git a/datafusion/optimizer/src/eliminate_filter.rs b/datafusion/optimizer/src/eliminate_filter.rs index 9411dc192beb..2bf5cfa30390 100644 --- a/datafusion/optimizer/src/eliminate_filter.rs +++ b/datafusion/optimizer/src/eliminate_filter.rs @@ -91,7 +91,7 @@ mod tests { use crate::test::*; - fn assert_optimized_plan_equal(plan: &LogicalPlan, expected: &str) -> Result<()> { + fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> { assert_optimized_plan_eq(Arc::new(EliminateFilter::new()), plan, expected) } @@ -107,7 +107,7 @@ mod tests { // No aggregate / scan / limit let expected = "EmptyRelation"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -122,7 +122,7 @@ mod tests { // No aggregate / scan / limit let expected = "EmptyRelation"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -144,7 +144,7 @@ mod tests { \n EmptyRelation\ \n Aggregate: groupBy=[[test.a]], aggr=[[SUM(test.b)]]\ \n TableScan: test"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -159,7 +159,7 @@ mod tests { let expected = "Aggregate: groupBy=[[test.a]], aggr=[[SUM(test.b)]]\ \n TableScan: test"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -182,7 +182,7 @@ mod tests { \n TableScan: test\ \n Aggregate: groupBy=[[test.a]], aggr=[[SUM(test.b)]]\ \n TableScan: test"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -205,6 +205,6 @@ mod tests { // Filter is removed let expected = "Projection: test.a\ \n EmptyRelation"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } } diff --git a/datafusion/optimizer/src/eliminate_join.rs b/datafusion/optimizer/src/eliminate_join.rs index e685229c61b2..caf45dda9896 100644 --- a/datafusion/optimizer/src/eliminate_join.rs +++ b/datafusion/optimizer/src/eliminate_join.rs @@ -83,7 +83,7 @@ mod tests { use datafusion_expr::{logical_plan::builder::LogicalPlanBuilder, Expr, LogicalPlan}; use std::sync::Arc; - fn assert_optimized_plan_equal(plan: &LogicalPlan, expected: &str) -> Result<()> { + fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> { assert_optimized_plan_eq(Arc::new(EliminateJoin::new()), plan, expected) } @@ -98,7 +98,7 @@ mod tests { .build()?; let expected = "EmptyRelation"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -115,6 +115,6 @@ mod tests { CrossJoin:\ \n EmptyRelation\ \n EmptyRelation"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } } diff --git a/datafusion/optimizer/src/eliminate_limit.rs b/datafusion/optimizer/src/eliminate_limit.rs index fb5d0d17b839..39231d784e00 100644 --- a/datafusion/optimizer/src/eliminate_limit.rs +++ 
b/datafusion/optimizer/src/eliminate_limit.rs @@ -94,24 +94,19 @@ mod tests { use crate::push_down_limit::PushDownLimit; - fn assert_optimized_plan_eq(plan: &LogicalPlan, expected: &str) -> Result<()> { + fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {} + fn assert_optimized_plan_eq(plan: LogicalPlan, expected: &str) -> Result<()> { let optimizer = Optimizer::with_rules(vec![Arc::new(EliminateLimit::new())]); - let optimized_plan = optimizer - .optimize_recursively( - optimizer.rules.first().unwrap(), - plan, - &OptimizerContext::new(), - )? - .unwrap_or_else(|| plan.clone()); + let optimized_plan = + optimizer.optimize(plan, &OptimizerContext::new(), observe)?; let formatted_plan = format!("{optimized_plan:?}"); assert_eq!(formatted_plan, expected); - assert_eq!(plan.schema(), optimized_plan.schema()); Ok(()) } fn assert_optimized_plan_eq_with_pushdown( - plan: &LogicalPlan, + plan: LogicalPlan, expected: &str, ) -> Result<()> { fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {} @@ -125,7 +120,6 @@ mod tests { .expect("failed to optimize plan"); let formatted_plan = format!("{optimized_plan:?}"); assert_eq!(formatted_plan, expected); - assert_eq!(plan.schema(), optimized_plan.schema()); Ok(()) } @@ -138,7 +132,7 @@ mod tests { .build()?; // No aggregate / scan / limit let expected = "EmptyRelation"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -158,7 +152,7 @@ mod tests { \n EmptyRelation\ \n Aggregate: groupBy=[[test.a]], aggr=[[SUM(test.b)]]\ \n TableScan: test"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -172,7 +166,7 @@ mod tests { // No aggregate / scan / limit let expected = "EmptyRelation"; - assert_optimized_plan_eq_with_pushdown(&plan, expected) + assert_optimized_plan_eq_with_pushdown(plan, expected) } #[test] @@ -192,7 +186,7 @@ mod tests { \n Limit: skip=0, fetch=2\ \n Aggregate: groupBy=[[test.a]], aggr=[[SUM(test.b)]]\ \n TableScan: test"; - assert_optimized_plan_eq_with_pushdown(&plan, expected) + assert_optimized_plan_eq_with_pushdown(plan, expected) } #[test] @@ -210,7 +204,7 @@ mod tests { \n Limit: skip=0, fetch=2\ \n Aggregate: groupBy=[[test.a]], aggr=[[SUM(test.b)]]\ \n TableScan: test"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -228,7 +222,7 @@ mod tests { \n Limit: skip=2, fetch=1\ \n Aggregate: groupBy=[[test.a]], aggr=[[SUM(test.b)]]\ \n TableScan: test"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -250,7 +244,7 @@ mod tests { \n Limit: skip=2, fetch=1\ \n TableScan: test\ \n TableScan: test1"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -263,6 +257,6 @@ mod tests { let expected = "Aggregate: groupBy=[[test.a]], aggr=[[SUM(test.b)]]\ \n TableScan: test"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } } diff --git a/datafusion/optimizer/src/eliminate_nested_union.rs b/datafusion/optimizer/src/eliminate_nested_union.rs index 924a0853418c..da2a6a17214e 100644 --- a/datafusion/optimizer/src/eliminate_nested_union.rs +++ b/datafusion/optimizer/src/eliminate_nested_union.rs @@ -114,7 +114,7 @@ mod tests { ]) } - fn assert_optimized_plan_equal(plan: &LogicalPlan, expected: &str) -> Result<()> { + fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> { 
assert_optimized_plan_eq(Arc::new(EliminateNestedUnion::new()), plan, expected) } @@ -131,7 +131,7 @@ mod tests { Union\ \n TableScan: table\ \n TableScan: table"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -147,7 +147,7 @@ mod tests { \n Union\ \n TableScan: table\ \n TableScan: table"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -167,7 +167,7 @@ mod tests { \n TableScan: table\ \n TableScan: table\ \n TableScan: table"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -188,7 +188,7 @@ mod tests { \n TableScan: table\ \n TableScan: table\ \n TableScan: table"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -210,7 +210,7 @@ mod tests { \n TableScan: table\ \n TableScan: table\ \n TableScan: table"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -230,7 +230,7 @@ mod tests { \n TableScan: table\ \n TableScan: table\ \n TableScan: table"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } // We don't need to use project_with_column_index in logical optimizer, @@ -261,7 +261,7 @@ mod tests { \n TableScan: table\ \n Projection: table.id AS id, table.key, table.value\ \n TableScan: table"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -291,7 +291,7 @@ mod tests { \n TableScan: table\ \n Projection: table.id AS id, table.key, table.value\ \n TableScan: table"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -337,7 +337,7 @@ mod tests { \n TableScan: table_1\ \n Projection: CAST(table_1.id AS Int64) AS id, table_1.key, CAST(table_1.value AS Float64) AS value\ \n TableScan: table_1"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -384,6 +384,6 @@ mod tests { \n TableScan: table_1\ \n Projection: CAST(table_1.id AS Int64) AS id, table_1.key, CAST(table_1.value AS Float64) AS value\ \n TableScan: table_1"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } } diff --git a/datafusion/optimizer/src/eliminate_one_union.rs b/datafusion/optimizer/src/eliminate_one_union.rs index 63c3e789daa6..95a3370ab1b5 100644 --- a/datafusion/optimizer/src/eliminate_one_union.rs +++ b/datafusion/optimizer/src/eliminate_one_union.rs @@ -76,7 +76,7 @@ mod tests { ]) } - fn assert_optimized_plan_equal(plan: &LogicalPlan, expected: &str) -> Result<()> { + fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> { assert_optimized_plan_eq_with_rules( vec![Arc::new(EliminateOneUnion::new())], plan, @@ -97,7 +97,7 @@ mod tests { Union\ \n TableScan: table\ \n TableScan: table"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -113,6 +113,6 @@ mod tests { }); let expected = "TableScan: table"; - assert_optimized_plan_equal(&single_union_plan, expected) + assert_optimized_plan_equal(single_union_plan, expected) } } diff --git a/datafusion/optimizer/src/eliminate_outer_join.rs b/datafusion/optimizer/src/eliminate_outer_join.rs index a004da2bff19..63b8b887bb32 100644 --- a/datafusion/optimizer/src/eliminate_outer_join.rs +++ b/datafusion/optimizer/src/eliminate_outer_join.rs @@ -306,7 +306,7 @@ mod 
tests { Operator::{And, Or}, }; - fn assert_optimized_plan_equal(plan: &LogicalPlan, expected: &str) -> Result<()> { + fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> { assert_optimized_plan_eq(Arc::new(EliminateOuterJoin::new()), plan, expected) } @@ -330,7 +330,7 @@ mod tests { \n Left Join: t1.a = t2.a\ \n TableScan: t1\ \n TableScan: t2"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -353,7 +353,7 @@ mod tests { \n Inner Join: t1.a = t2.a\ \n TableScan: t1\ \n TableScan: t2"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -380,7 +380,7 @@ mod tests { \n Inner Join: t1.a = t2.a\ \n TableScan: t1\ \n TableScan: t2"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -407,7 +407,7 @@ mod tests { \n Inner Join: t1.a = t2.a\ \n TableScan: t1\ \n TableScan: t2"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -434,6 +434,6 @@ mod tests { \n Inner Join: t1.a = t2.a\ \n TableScan: t1\ \n TableScan: t2"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } } diff --git a/datafusion/optimizer/src/extract_equijoin_predicate.rs b/datafusion/optimizer/src/extract_equijoin_predicate.rs index 4cfcd07b47d9..60b9ba3031a1 100644 --- a/datafusion/optimizer/src/extract_equijoin_predicate.rs +++ b/datafusion/optimizer/src/extract_equijoin_predicate.rs @@ -164,7 +164,7 @@ mod tests { col, lit, logical_plan::builder::LogicalPlanBuilder, JoinType, }; - fn assert_plan_eq(plan: &LogicalPlan, expected: &str) -> Result<()> { + fn assert_plan_eq(plan: LogicalPlan, expected: &str) -> Result<()> { assert_optimized_plan_eq_display_indent( Arc::new(ExtractEquijoinPredicate {}), plan, @@ -186,7 +186,7 @@ mod tests { \n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]\ \n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]"; - assert_plan_eq(&plan, expected) + assert_plan_eq(plan, expected) } #[test] @@ -205,7 +205,7 @@ mod tests { \n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]\ \n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]"; - assert_plan_eq(&plan, expected) + assert_plan_eq(plan, expected) } #[test] @@ -228,7 +228,7 @@ mod tests { \n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]\ \n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]"; - assert_plan_eq(&plan, expected) + assert_plan_eq(plan, expected) } #[test] @@ -255,7 +255,7 @@ mod tests { \n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]\ \n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]"; - assert_plan_eq(&plan, expected) + assert_plan_eq(plan, expected) } #[test] @@ -281,7 +281,7 @@ mod tests { \n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]\ \n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]"; - assert_plan_eq(&plan, expected) + assert_plan_eq(plan, expected) } #[test] @@ -318,7 +318,7 @@ mod tests { \n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]\ \n TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]"; - assert_plan_eq(&plan, expected) + assert_plan_eq(plan, expected) } #[test] @@ -351,7 +351,7 @@ mod tests { \n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]\ \n TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]"; - assert_plan_eq(&plan, expected) + assert_plan_eq(plan, expected) } #[test] @@ -375,6 +375,6 @@ mod tests { \n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]\ \n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]"; - assert_plan_eq(&plan, expected) + assert_plan_eq(plan, expected) } } diff 
--git a/datafusion/optimizer/src/filter_null_join_keys.rs b/datafusion/optimizer/src/filter_null_join_keys.rs index 16039b182bb2..fcf85327fdb0 100644 --- a/datafusion/optimizer/src/filter_null_join_keys.rs +++ b/datafusion/optimizer/src/filter_null_join_keys.rs @@ -116,7 +116,7 @@ mod tests { use datafusion_expr::logical_plan::table_scan; use datafusion_expr::{col, lit, logical_plan::JoinType, LogicalPlanBuilder}; - fn assert_optimized_plan_equal(plan: &LogicalPlan, expected: &str) -> Result<()> { + fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> { assert_optimized_plan_eq(Arc::new(FilterNullJoinKeys {}), plan, expected) } @@ -128,7 +128,7 @@ mod tests { \n Filter: t1.optional_id IS NOT NULL\ \n TableScan: t1\ \n TableScan: t2"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -139,7 +139,7 @@ mod tests { \n Filter: t1.optional_id IS NOT NULL\ \n TableScan: t1\ \n TableScan: t2"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -176,7 +176,7 @@ mod tests { \n Filter: t1.optional_id IS NOT NULL\ \n TableScan: t1\ \n TableScan: t2"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -197,7 +197,7 @@ mod tests { \n Filter: t1.optional_id + UInt32(1) IS NOT NULL\ \n TableScan: t1\ \n TableScan: t2"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -218,7 +218,7 @@ mod tests { \n TableScan: t1\ \n Filter: t2.optional_id + UInt32(1) IS NOT NULL\ \n TableScan: t2"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -241,7 +241,7 @@ mod tests { \n TableScan: t1\ \n Filter: t2.optional_id + UInt32(1) IS NOT NULL\ \n TableScan: t2"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } fn build_plan( diff --git a/datafusion/optimizer/src/optimize_projections.rs b/datafusion/optimizer/src/optimize_projections.rs index 69905c990a7f..6967b28f3037 100644 --- a/datafusion/optimizer/src/optimize_projections.rs +++ b/datafusion/optimizer/src/optimize_projections.rs @@ -941,7 +941,7 @@ mod tests { UserDefinedLogicalNodeCore, }; - fn assert_optimized_plan_equal(plan: &LogicalPlan, expected: &str) -> Result<()> { + fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> { assert_optimized_plan_eq(Arc::new(OptimizeProjections::new()), plan, expected) } @@ -1090,7 +1090,7 @@ mod tests { let expected = "Projection: Int32(1) + test.a\ \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -1104,7 +1104,7 @@ mod tests { let expected = "Projection: Int32(1) + test.a\ \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -1117,7 +1117,7 @@ mod tests { let expected = "Projection: test.a AS alias\ \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -1130,7 +1130,7 @@ mod tests { let expected = "Projection: test.a AS alias\ \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -1152,7 +1152,7 @@ mod tests { \n Projection: \ \n Aggregate: groupBy=[[]], aggr=[[COUNT(Int32(1))]]\ \n TableScan: ?table? 
projection=[]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -1175,7 +1175,7 @@ mod tests { .build()?; let expected = "Projection: (?table?.s)[x]\ \n TableScan: ?table? projection=[s]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -1187,7 +1187,7 @@ mod tests { let expected = "Projection: (- test.a)\ \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -1199,7 +1199,7 @@ mod tests { let expected = "Projection: test.a IS NULL\ \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -1211,7 +1211,7 @@ mod tests { let expected = "Projection: test.a IS NOT NULL\ \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -1223,7 +1223,7 @@ mod tests { let expected = "Projection: test.a IS TRUE\ \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -1235,7 +1235,7 @@ mod tests { let expected = "Projection: test.a IS NOT TRUE\ \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -1247,7 +1247,7 @@ mod tests { let expected = "Projection: test.a IS FALSE\ \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -1259,7 +1259,7 @@ mod tests { let expected = "Projection: test.a IS NOT FALSE\ \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -1271,7 +1271,7 @@ mod tests { let expected = "Projection: test.a IS UNKNOWN\ \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -1283,7 +1283,7 @@ mod tests { let expected = "Projection: test.a IS NOT UNKNOWN\ \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -1295,7 +1295,7 @@ mod tests { let expected = "Projection: NOT test.a\ \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -1307,7 +1307,7 @@ mod tests { let expected = "Projection: TRY_CAST(test.a AS Float64)\ \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -1323,7 +1323,7 @@ mod tests { let expected = "Projection: test.a SIMILAR TO Utf8(\"[0-9]\")\ \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -1335,7 +1335,7 @@ mod tests { let expected = "Projection: test.a BETWEEN Int32(1) AND Int32(3)\ \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } // Test outer projection isn't discarded despite the same schema as inner @@ -1356,7 +1356,7 @@ mod tests { let expected = "Projection: test.a, CASE WHEN test.a = Int32(1) THEN Int32(10) ELSE d END AS d\ \n Projection: test.a, Int32(0) AS d\ \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(&plan, 
expected) + assert_optimized_plan_equal(plan, expected) } // Since only column `a` is referred at the output. Scan should only contain projection=[a]. @@ -1377,7 +1377,7 @@ mod tests { let expected = "Projection: test.a, Int32(0) AS d\ \n NoOpUserDefined\ \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } // Only column `a` is referred at the output. However, User defined node itself uses column `b` @@ -1404,7 +1404,7 @@ mod tests { let expected = "Projection: test.a, Int32(0) AS d\ \n NoOpUserDefined\ \n TableScan: test projection=[a, b]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } // Only column `a` is referred at the output. However, User defined node itself uses expression `b+c` @@ -1439,7 +1439,7 @@ mod tests { let expected = "Projection: test.a, Int32(0) AS d\ \n NoOpUserDefined\ \n TableScan: test projection=[a, b, c]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } // Columns `l.a`, `l.c`, `r.a` is referred at the output. @@ -1464,6 +1464,6 @@ mod tests { \n UserDefinedCrossJoin\ \n TableScan: l projection=[a, c]\ \n TableScan: r projection=[a]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } } diff --git a/datafusion/optimizer/src/optimizer.rs b/datafusion/optimizer/src/optimizer.rs index 03ff402c3e3f..032f9c57321c 100644 --- a/datafusion/optimizer/src/optimizer.rs +++ b/datafusion/optimizer/src/optimizer.rs @@ -20,6 +20,16 @@ use std::collections::HashSet; use std::sync::Arc; +use chrono::{DateTime, Utc}; +use log::{debug, warn}; + +use datafusion_common::alias::AliasGenerator; +use datafusion_common::config::ConfigOptions; +use datafusion_common::instant::Instant; +use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter}; +use datafusion_common::{DFSchema, DataFusionError, Result}; +use datafusion_expr::logical_plan::LogicalPlan; + use crate::common_subexpr_eliminate::CommonSubexprEliminate; use crate::decorrelate_predicate_subquery::DecorrelatePredicateSubquery; use crate::eliminate_cross_join::EliminateCrossJoin; @@ -45,15 +55,6 @@ use crate::single_distinct_to_groupby::SingleDistinctToGroupBy; use crate::unwrap_cast_in_comparison::UnwrapCastInComparison; use crate::utils::log_plan; -use datafusion_common::alias::AliasGenerator; -use datafusion_common::config::ConfigOptions; -use datafusion_common::instant::Instant; -use datafusion_common::{DataFusionError, Result}; -use datafusion_expr::logical_plan::LogicalPlan; - -use chrono::{DateTime, Utc}; -use log::{debug, warn}; - /// `OptimizerRule`s transforms one [`LogicalPlan`] into another which /// computes the same results, but in a potentially more efficient /// way. If there are no suitable transformations for the input plan, @@ -184,41 +185,15 @@ pub struct Optimizer { pub rules: Vec>, } -/// If a rule is with `ApplyOrder`, it means the optimizer will derive to handle children instead of -/// recursively handling in rule. -/// We just need handle a subtree pattern itself. -/// -/// Notice: **sometime** result after optimize still can be optimized, we need apply again. +/// Specifies how recursion for an `OptimizerRule` should be handled. 
 ///
-/// Usage Example: Merge Limit (subtree pattern is: Limit-Limit)
-/// ```rust
-/// use datafusion_expr::{Limit, LogicalPlan, LogicalPlanBuilder};
-/// use datafusion_common::Result;
-/// fn merge_limit(parent: &Limit, child: &Limit) -> LogicalPlan {
-///     // just for run
-///     return parent.input.as_ref().clone();
-/// }
-/// fn try_optimize(plan: &LogicalPlan) -> Result<Option<LogicalPlan>> {
-///     match plan {
-///         LogicalPlan::Limit(limit) => match limit.input.as_ref() {
-///             LogicalPlan::Limit(child_limit) => {
-///                 // merge limit ...
-///                 let optimized_plan = merge_limit(limit, child_limit);
-///                 // due to optimized_plan may be optimized again,
-///                 // for example: plan is Limit-Limit-Limit
-///                 Ok(Some(
-///                     try_optimize(&optimized_plan)?
-///                         .unwrap_or_else(|| optimized_plan.clone()),
-///                 ))
-///             }
-///             _ => Ok(None),
-///         },
-///         _ => Ok(None),
-///     }
-/// }
-/// ```
+/// * `Some(apply_order)`: The Optimizer will recursively apply the rule to the plan.
+/// * `None`: the rule must handle any required recursion itself.
+#[derive(Debug, Clone, Copy, PartialEq)]
 pub enum ApplyOrder {
+    /// Apply the rule to the node before its inputs
     TopDown,
+    /// Apply the rule to the node after its inputs
     BottomUp,
 }
 
@@ -274,22 +249,78 @@ impl Optimizer {
     pub fn with_rules(rules: Vec<Arc<dyn OptimizerRule + Send + Sync>>) -> Self {
         Self { rules }
     }
+}
+
+/// Recursively rewrites LogicalPlans
+struct Rewriter<'a> {
+    apply_order: ApplyOrder,
+    rule: &'a dyn OptimizerRule,
+    config: &'a dyn OptimizerConfig,
+}
+impl<'a> Rewriter<'a> {
+    fn new(
+        apply_order: ApplyOrder,
+        rule: &'a dyn OptimizerRule,
+        config: &'a dyn OptimizerConfig,
+    ) -> Self {
+        Self {
+            apply_order,
+            rule,
+            config,
+        }
+    }
+}
+
+impl<'a> TreeNodeRewriter for Rewriter<'a> {
+    type Node = LogicalPlan;
+
+    fn f_down(&mut self, node: LogicalPlan) -> Result<Transformed<LogicalPlan>> {
+        if self.apply_order == ApplyOrder::TopDown {
+            optimize_plan_node(node, self.rule, self.config)
+        } else {
+            Ok(Transformed::no(node))
+        }
+    }
+
+    fn f_up(&mut self, node: LogicalPlan) -> Result<Transformed<LogicalPlan>> {
+        if self.apply_order == ApplyOrder::BottomUp {
+            optimize_plan_node(node, self.rule, self.config)
+        } else {
+            Ok(Transformed::no(node))
+        }
+    }
+}
+
+/// Invokes the Optimizer rule to rewrite the LogicalPlan in place.
+fn optimize_plan_node( + plan: LogicalPlan, + rule: &dyn OptimizerRule, + config: &dyn OptimizerConfig, +) -> Result> { + // TODO: add API to OptimizerRule to allow rewriting by ownership + rule.try_optimize(&plan, config) + .map(|maybe_plan| match maybe_plan { + Some(new_plan) => Transformed::yes(new_plan), + None => Transformed::no(plan), + }) +} + +impl Optimizer { /// Optimizes the logical plan by applying optimizer rules, and /// invoking observer function after each call pub fn optimize( &self, - plan: &LogicalPlan, + plan: LogicalPlan, config: &dyn OptimizerConfig, mut observer: F, ) -> Result where F: FnMut(&LogicalPlan, &dyn OptimizerRule), { - let options = config.options(); - let mut new_plan = plan.clone(); - let start_time = Instant::now(); + let options = config.options(); + let mut new_plan = plan; let mut previous_plans = HashSet::with_capacity(16); previous_plans.insert(LogicalPlanSignature::new(&new_plan)); @@ -299,44 +330,71 @@ impl Optimizer { log_plan(&format!("Optimizer input (pass {i})"), &new_plan); for rule in &self.rules { - let result = - self.optimize_recursively(rule, &new_plan, config) - .and_then(|plan| { - if let Some(plan) = &plan { - assert_schema_is_the_same(rule.name(), plan, &new_plan)?; - } - Ok(plan) - }); - match result { - Ok(Some(plan)) => { - new_plan = plan; - observer(&new_plan, rule.as_ref()); - log_plan(rule.name(), &new_plan); - } - Ok(None) => { + // If skipping failed rules, copy plan before attempting to rewrite + // as rewriting is destructive + let prev_plan = options + .optimizer + .skip_failed_rules + .then(|| new_plan.clone()); + + let starting_schema = new_plan.schema().clone(); + + let result = match rule.apply_order() { + // optimizer handles recursion + Some(apply_order) => new_plan.rewrite(&mut Rewriter::new( + apply_order, + rule.as_ref(), + config, + )), + // rule handles recursion itself + None => optimize_plan_node(new_plan, rule.as_ref(), config), + } + // verify the rule didn't change the schema + .and_then(|tnr| { + assert_schema_is_the_same(rule.name(), &starting_schema, &tnr.data)?; + Ok(tnr) + }); + + // Handle results + match (result, prev_plan) { + // OptimizerRule was successful + ( + Ok(Transformed { + data, transformed, .. + }), + _, + ) => { + new_plan = data; observer(&new_plan, rule.as_ref()); - debug!( - "Plan unchanged by optimizer rule '{}' (pass {})", - rule.name(), - i - ); + if transformed { + log_plan(rule.name(), &new_plan); + } else { + debug!( + "Plan unchanged by optimizer rule '{}' (pass {})", + rule.name(), + i + ); + } } - Err(e) => { - if options.optimizer.skip_failed_rules { - // Note to future readers: if you see this warning it signals a - // bug in the DataFusion optimizer. Please consider filing a ticket - // https://github.com/apache/arrow-datafusion - warn!( + // OptimizerRule was unsuccessful, but skipped failed rules is on + // so use the previous plan + (Err(e), Some(orig_plan)) => { + // Note to future readers: if you see this warning it signals a + // bug in the DataFusion optimizer. 
Please consider filing a ticket + // https://github.com/apache/arrow-datafusion + warn!( "Skipping optimizer rule '{}' due to unexpected error: {}", rule.name(), e ); - } else { - return Err(DataFusionError::Context( - format!("Optimizer rule '{}' failed", rule.name(),), - Box::new(e), - )); - } + new_plan = orig_plan; + } + // OptimizerRule was unsuccessful, but skipped failed rules is off, return error + (Err(e), None) => { + return Err(e.context(format!( + "Optimizer rule '{}' failed", + rule.name() + ))); } } } @@ -356,97 +414,22 @@ impl Optimizer { debug!("Optimizer took {} ms", start_time.elapsed().as_millis()); Ok(new_plan) } - - fn optimize_node( - &self, - rule: &Arc, - plan: &LogicalPlan, - config: &dyn OptimizerConfig, - ) -> Result> { - // TODO: future feature: We can do Batch optimize - rule.try_optimize(plan, config) - } - - fn optimize_inputs( - &self, - rule: &Arc, - plan: &LogicalPlan, - config: &dyn OptimizerConfig, - ) -> Result> { - let inputs = plan.inputs(); - let result = inputs - .iter() - .map(|sub_plan| self.optimize_recursively(rule, sub_plan, config)) - .collect::>>()?; - if result.is_empty() || result.iter().all(|o| o.is_none()) { - return Ok(None); - } - - let new_inputs = result - .into_iter() - .zip(inputs) - .map(|(new_plan, old_plan)| match new_plan { - Some(plan) => plan, - None => old_plan.clone(), - }) - .collect(); - - let exprs = plan.expressions(); - plan.with_new_exprs(exprs, new_inputs).map(Some) - } - - /// Use a rule to optimize the whole plan. - /// If the rule with `ApplyOrder`, we don't need to recursively handle children in rule. - pub fn optimize_recursively( - &self, - rule: &Arc, - plan: &LogicalPlan, - config: &dyn OptimizerConfig, - ) -> Result> { - match rule.apply_order() { - Some(order) => match order { - ApplyOrder::TopDown => { - let optimize_self_opt = self.optimize_node(rule, plan, config)?; - let optimize_inputs_opt = match &optimize_self_opt { - Some(optimized_plan) => { - self.optimize_inputs(rule, optimized_plan, config)? - } - _ => self.optimize_inputs(rule, plan, config)?, - }; - Ok(optimize_inputs_opt.or(optimize_self_opt)) - } - ApplyOrder::BottomUp => { - let optimize_inputs_opt = self.optimize_inputs(rule, plan, config)?; - let optimize_self_opt = match &optimize_inputs_opt { - Some(optimized_plan) => { - self.optimize_node(rule, optimized_plan, config)? - } - _ => self.optimize_node(rule, plan, config)?, - }; - Ok(optimize_self_opt.or(optimize_inputs_opt)) - } - }, - _ => rule.try_optimize(plan, config), - } - } } -/// Returns an error if plans have different schemas. +/// Returns an error if `new_plan`'s schema is different than `prev_schema` /// /// It ignores metadata and nullability. 
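+///
+/// For example (an illustrative sketch, not tied to any particular rule):
+///
+/// ```text
+/// prev: a: Utf8 (nullable)   new: a: Utf8 (non-null)  =>  Ok(())           -- nullability and metadata ignored
+/// prev: a: Utf8              new: a: Int64            =>  Internal error, wrapped with the failing rule's name
+/// ```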
pub(crate) fn assert_schema_is_the_same( rule_name: &str, - prev_plan: &LogicalPlan, + prev_schema: &DFSchema, new_plan: &LogicalPlan, ) -> Result<()> { - let equivalent = new_plan - .schema() - .equivalent_names_and_types(prev_plan.schema()); + let equivalent = new_plan.schema().equivalent_names_and_types(prev_schema); if !equivalent { let e = DataFusionError::Internal(format!( "Failed due to a difference in schemas, original schema: {:?}, new schema: {:?}", - prev_plan.schema(), + prev_schema, new_plan.schema() )); Err(DataFusionError::Context( @@ -462,14 +445,15 @@ pub(crate) fn assert_schema_is_the_same( mod tests { use std::sync::{Arc, Mutex}; - use super::ApplyOrder; + use datafusion_common::{plan_err, DFSchema, DFSchemaRef, Result}; + use datafusion_expr::logical_plan::EmptyRelation; + use datafusion_expr::{col, lit, LogicalPlan, LogicalPlanBuilder, Projection}; + use crate::optimizer::Optimizer; use crate::test::test_table_scan; use crate::{OptimizerConfig, OptimizerContext, OptimizerRule}; - use datafusion_common::{plan_err, DFSchema, DFSchemaRef, Result}; - use datafusion_expr::logical_plan::EmptyRelation; - use datafusion_expr::{col, lit, LogicalPlan, LogicalPlanBuilder, Projection}; + use super::ApplyOrder; #[test] fn skip_failing_rule() { @@ -479,7 +463,7 @@ mod tests { produce_one_row: false, schema: Arc::new(DFSchema::empty()), }); - opt.optimize(&plan, &config, &observe).unwrap(); + opt.optimize(plan, &config, &observe).unwrap(); } #[test] @@ -490,7 +474,7 @@ mod tests { produce_one_row: false, schema: Arc::new(DFSchema::empty()), }); - let err = opt.optimize(&plan, &config, &observe).unwrap_err(); + let err = opt.optimize(plan, &config, &observe).unwrap_err(); assert_eq!( "Optimizer rule 'bad rule' failed\ncaused by\n\ Error during planning: rule failed", @@ -506,21 +490,27 @@ mod tests { produce_one_row: false, schema: Arc::new(DFSchema::empty()), }); - let err = opt.optimize(&plan, &config, &observe).unwrap_err(); + let err = opt.optimize(plan, &config, &observe).unwrap_err(); assert_eq!( - "Optimizer rule 'get table_scan rule' failed\ncaused by\nget table_scan rule\ncaused by\n\ - Internal error: Failed due to a difference in schemas, original schema: \ - DFSchema { inner: Schema { fields: \ - [Field { name: \"a\", data_type: UInt32, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, \ - Field { name: \"b\", data_type: UInt32, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, \ - Field { name: \"c\", data_type: UInt32, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }], metadata: {} }, \ - field_qualifiers: [Some(Bare { table: \"test\" }), Some(Bare { table: \"test\" }), Some(Bare { table: \"test\" })], \ - functional_dependencies: FunctionalDependencies { deps: [] } }, \ + "Optimizer rule 'get table_scan rule' failed\n\ + caused by\nget table_scan rule\ncaused by\n\ + Internal error: Failed due to a difference in schemas, \ + original schema: DFSchema { inner: Schema { \ + fields: [], \ + metadata: {} }, \ + field_qualifiers: [], \ + functional_dependencies: FunctionalDependencies { deps: [] } \ + }, \ new schema: DFSchema { inner: Schema { \ - fields: [], metadata: {} }, \ - field_qualifiers: [], \ - functional_dependencies: FunctionalDependencies { deps: [] } }.\n\ - This was likely caused by a bug in DataFusion's code and we would welcome that you file an bug report in our issue tracker", + fields: [\ + Field { name: \"a\", data_type: UInt32, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, \ 
+ Field { name: \"b\", data_type: UInt32, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, \ + Field { name: \"c\", data_type: UInt32, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }\ + ], \ + metadata: {} }, \ + field_qualifiers: [Some(Bare { table: \"test\" }), Some(Bare { table: \"test\" }), Some(Bare { table: \"test\" })], \ + functional_dependencies: FunctionalDependencies { deps: [] } }.\n\ + This was likely caused by a bug in DataFusion's code and we would welcome that you file an bug report in our issue tracker", err.strip_backtrace() ); } @@ -533,7 +523,7 @@ mod tests { produce_one_row: false, schema: Arc::new(DFSchema::empty()), }); - opt.optimize(&plan, &config, &observe).unwrap(); + opt.optimize(plan, &config, &observe).unwrap(); } #[test] @@ -554,7 +544,7 @@ mod tests { // optimizing should be ok, but the schema will have changed (no metadata) assert_ne!(plan.schema().as_ref(), input_schema.as_ref()); - let optimized_plan = opt.optimize(&plan, &config, &observe)?; + let optimized_plan = opt.optimize(plan, &config, &observe)?; // metadata was removed assert_eq!(optimized_plan.schema().as_ref(), input_schema.as_ref()); Ok(()) @@ -575,7 +565,7 @@ mod tests { let mut plans: Vec = Vec::new(); let final_plan = - opt.optimize(&initial_plan, &config, |p, _| plans.push(p.clone()))?; + opt.optimize(initial_plan.clone(), &config, |p, _| plans.push(p.clone()))?; // initial_plan is not observed, so we have 3 plans assert_eq!(3, plans.len()); @@ -601,7 +591,7 @@ mod tests { let mut plans: Vec = Vec::new(); let final_plan = - opt.optimize(&initial_plan, &config, |p, _| plans.push(p.clone()))?; + opt.optimize(initial_plan, &config, |p, _| plans.push(p.clone()))?; // initial_plan is not observed, so we have 4 plans assert_eq!(4, plans.len()); diff --git a/datafusion/optimizer/src/propagate_empty_relation.rs b/datafusion/optimizer/src/propagate_empty_relation.rs index 2aca6f93254a..445109bbdf77 100644 --- a/datafusion/optimizer/src/propagate_empty_relation.rs +++ b/datafusion/optimizer/src/propagate_empty_relation.rs @@ -198,12 +198,12 @@ mod tests { use super::*; - fn assert_eq(plan: &LogicalPlan, expected: &str) -> Result<()> { + fn assert_eq(plan: LogicalPlan, expected: &str) -> Result<()> { assert_optimized_plan_eq(Arc::new(PropagateEmptyRelation::new()), plan, expected) } fn assert_together_optimized_plan_eq( - plan: &LogicalPlan, + plan: LogicalPlan, expected: &str, ) -> Result<()> { assert_optimized_plan_eq_with_rules( @@ -226,7 +226,7 @@ mod tests { .build()?; let expected = "EmptyRelation"; - assert_eq(&plan, expected) + assert_eq(plan, expected) } #[test] @@ -249,7 +249,7 @@ mod tests { .build()?; let expected = "EmptyRelation"; - assert_together_optimized_plan_eq(&plan, expected) + assert_together_optimized_plan_eq(plan, expected) } #[test] @@ -262,7 +262,7 @@ mod tests { let plan = LogicalPlanBuilder::from(left).union(right)?.build()?; let expected = "TableScan: test"; - assert_together_optimized_plan_eq(&plan, expected) + assert_together_optimized_plan_eq(plan, expected) } #[test] @@ -287,7 +287,7 @@ mod tests { let expected = "Union\ \n TableScan: test1\ \n TableScan: test4"; - assert_together_optimized_plan_eq(&plan, expected) + assert_together_optimized_plan_eq(plan, expected) } #[test] @@ -312,7 +312,7 @@ mod tests { .build()?; let expected = "EmptyRelation"; - assert_together_optimized_plan_eq(&plan, expected) + assert_together_optimized_plan_eq(plan, expected) } #[test] @@ -339,7 +339,7 @@ mod tests { let expected = "Union\ \n 
TableScan: test2\ \n TableScan: test3"; - assert_together_optimized_plan_eq(&plan, expected) + assert_together_optimized_plan_eq(plan, expected) } #[test] @@ -352,7 +352,7 @@ mod tests { let plan = LogicalPlanBuilder::from(left).union(right)?.build()?; let expected = "TableScan: test"; - assert_together_optimized_plan_eq(&plan, expected) + assert_together_optimized_plan_eq(plan, expected) } #[test] @@ -367,7 +367,7 @@ mod tests { .build()?; let expected = "EmptyRelation"; - assert_together_optimized_plan_eq(&plan, expected) + assert_together_optimized_plan_eq(plan, expected) } #[test] @@ -400,6 +400,6 @@ mod tests { let expected = "Projection: a, b, c\ \n TableScan: test"; - assert_together_optimized_plan_eq(&plan, expected) + assert_together_optimized_plan_eq(plan, expected) } } diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index f3ce8bbcde72..2b123e3559f5 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -1028,11 +1028,11 @@ fn contain(e: &Expr, check_map: &HashMap) -> bool { #[cfg(test)] mod tests { + use super::*; use std::any::Any; use std::fmt::{Debug, Formatter}; use std::sync::Arc; - use super::*; use crate::optimizer::Optimizer; use crate::rewrite_disjunctive_predicate::RewriteDisjunctivePredicate; use crate::test::*; @@ -1040,6 +1040,7 @@ mod tests { use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion_common::{DFSchema, DFSchemaRef, ScalarValue}; + use datafusion_expr::expr::ScalarFunction; use datafusion_expr::logical_plan::table_scan; use datafusion_expr::{ and, col, in_list, in_subquery, lit, logical_plan::JoinType, or, sum, BinaryExpr, @@ -1049,9 +1050,9 @@ mod tests { }; use async_trait::async_trait; - use datafusion_expr::expr::ScalarFunction; + fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {} - fn assert_optimized_plan_eq(plan: &LogicalPlan, expected: &str) -> Result<()> { + fn assert_optimized_plan_eq(plan: LogicalPlan, expected: &str) -> Result<()> { crate::test::assert_optimized_plan_eq( Arc::new(PushDownFilter::new()), plan, @@ -1060,29 +1061,17 @@ mod tests { } fn assert_optimized_plan_eq_with_rewrite_predicate( - plan: &LogicalPlan, + plan: LogicalPlan, expected: &str, ) -> Result<()> { let optimizer = Optimizer::with_rules(vec![ Arc::new(RewriteDisjunctivePredicate::new()), Arc::new(PushDownFilter::new()), ]); - let mut optimized_plan = optimizer - .optimize_recursively( - optimizer.rules.first().unwrap(), - plan, - &OptimizerContext::new(), - )? - .unwrap_or_else(|| plan.clone()); - optimized_plan = optimizer - .optimize_recursively( - optimizer.rules.get(1).unwrap(), - &optimized_plan, - &OptimizerContext::new(), - )? - .unwrap_or_else(|| plan.clone()); + let optimized_plan = + optimizer.optimize(plan, &OptimizerContext::new(), observe)?; + let formatted_plan = format!("{optimized_plan:?}"); - assert_eq!(plan.schema(), optimized_plan.schema()); assert_eq!(expected, formatted_plan); Ok(()) } @@ -1098,7 +1087,7 @@ mod tests { let expected = "\ Projection: test.a, test.b\ \n TableScan: test, full_filters=[test.a = Int64(1)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -1115,7 +1104,7 @@ mod tests { \n Limit: skip=0, fetch=10\ \n Projection: test.a, test.b\ \n TableScan: test"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -1125,7 +1114,7 @@ mod tests { .filter(lit(0i64).eq(lit(1i64)))? 
.build()?; let expected = "TableScan: test, full_filters=[Int64(0) = Int64(1)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -1141,7 +1130,7 @@ mod tests { Projection: test.c, test.b\ \n Projection: test.a, test.b, test.c\ \n TableScan: test, full_filters=[test.a = Int64(1)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -1155,7 +1144,7 @@ mod tests { let expected = "\ Aggregate: groupBy=[[test.a]], aggr=[[SUM(test.b) AS total_salary]]\ \n TableScan: test, full_filters=[test.a > Int64(10)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -1168,7 +1157,7 @@ mod tests { let expected = "Filter: test.b > Int64(10)\ \n Aggregate: groupBy=[[test.b + test.a]], aggr=[[SUM(test.a), test.b]]\ \n TableScan: test"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -1180,7 +1169,7 @@ mod tests { let expected = "Aggregate: groupBy=[[test.b + test.a]], aggr=[[SUM(test.a), test.b]]\ \n TableScan: test, full_filters=[test.b + test.a > Int64(10)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -1195,7 +1184,7 @@ mod tests { Filter: b > Int64(10)\ \n Aggregate: groupBy=[[test.a]], aggr=[[SUM(test.b) AS b]]\ \n TableScan: test"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } /// verifies that a filter is pushed to before a projection, the filter expression is correctly re-written @@ -1210,7 +1199,7 @@ mod tests { let expected = "\ Projection: test.a AS b, test.c\ \n TableScan: test, full_filters=[test.a = Int64(1)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } fn add(left: Expr, right: Expr) -> Expr { @@ -1254,7 +1243,7 @@ mod tests { let expected = "\ Projection: test.a * Int32(2) + test.c AS b, test.c\ \n TableScan: test, full_filters=[test.a * Int32(2) + test.c = Int64(1)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } /// verifies that when a filter is pushed to after 2 projections, the filter expression is correctly re-written @@ -1286,7 +1275,7 @@ mod tests { Projection: b * Int32(3) AS a, test.c\ \n Projection: test.a * Int32(2) + test.c AS b, test.c\ \n TableScan: test, full_filters=[(test.a * Int32(2) + test.c) * Int32(3) = Int64(1)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[derive(Debug, PartialEq, Eq, Hash)] @@ -1349,7 +1338,7 @@ mod tests { let expected = "\ NoopPlan\ \n TableScan: test, full_filters=[test.a = Int64(1)]"; - assert_optimized_plan_eq(&plan, expected)?; + assert_optimized_plan_eq(plan, expected)?; let custom_plan = LogicalPlan::Extension(Extension { node: Arc::new(NoopPlan { @@ -1366,7 +1355,7 @@ mod tests { Filter: test.c = Int64(2)\ \n NoopPlan\ \n TableScan: test, full_filters=[test.a = Int64(1)]"; - assert_optimized_plan_eq(&plan, expected)?; + assert_optimized_plan_eq(plan, expected)?; let custom_plan = LogicalPlan::Extension(Extension { node: Arc::new(NoopPlan { @@ -1383,7 +1372,7 @@ mod tests { NoopPlan\ \n TableScan: test, full_filters=[test.a = Int64(1)]\ \n TableScan: test, full_filters=[test.a = Int64(1)]"; - assert_optimized_plan_eq(&plan, expected)?; + assert_optimized_plan_eq(plan, expected)?; let custom_plan = LogicalPlan::Extension(Extension { node: Arc::new(NoopPlan { @@ -1401,7 +1390,7 @@ mod tests { \n NoopPlan\ \n 
TableScan: test, full_filters=[test.a = Int64(1)]\ \n TableScan: test, full_filters=[test.a = Int64(1)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } /// verifies that when two filters apply after an aggregation that only allows one to be pushed, one is pushed @@ -1434,7 +1423,7 @@ mod tests { \n Aggregate: groupBy=[[b]], aggr=[[SUM(test.c)]]\ \n Projection: test.a AS b, test.c\ \n TableScan: test, full_filters=[test.a > Int64(10)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } /// verifies that when a filter with two predicates is applied after an aggregation that only allows one to be pushed, one is pushed @@ -1468,7 +1457,7 @@ mod tests { \n Aggregate: groupBy=[[b]], aggr=[[SUM(test.c)]]\ \n Projection: test.a AS b, test.c\ \n TableScan: test, full_filters=[test.a > Int64(10)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } /// verifies that when two limits are in place, we jump neither @@ -1490,7 +1479,7 @@ mod tests { \n Limit: skip=0, fetch=20\ \n Projection: test.a, test.b\ \n TableScan: test"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -1505,7 +1494,7 @@ mod tests { let expected = "Union\ \n TableScan: test, full_filters=[test.a = Int64(1)]\ \n TableScan: test2, full_filters=[test2.a = Int64(1)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -1528,7 +1517,7 @@ mod tests { \n SubqueryAlias: test2\ \n Projection: test.a AS b\ \n TableScan: test, full_filters=[test.a = Int64(1)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -1559,7 +1548,7 @@ mod tests { \n Projection: test1.d, test1.e, test1.f\ \n TableScan: test1, full_filters=[test1.d > Int32(2)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -1585,7 +1574,7 @@ mod tests { \n TableScan: test, full_filters=[test.a = Int32(1)]\ \n Projection: test1.a, test1.b, test1.c\ \n TableScan: test1, full_filters=[test1.a > Int32(2)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } /// verifies that filters with the same columns are correctly placed @@ -1619,7 +1608,7 @@ mod tests { \n Projection: test.a\ \n TableScan: test, full_filters=[test.a <= Int64(1)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } /// verifies that filters to be placed on the same depth are ANDed @@ -1649,7 +1638,7 @@ mod tests { \n Limit: skip=0, fetch=1\ \n TableScan: test"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } /// verifies that filters on a plan with user nodes are not lost @@ -1675,7 +1664,7 @@ mod tests { TestUserDefined\ \n TableScan: test, full_filters=[test.a <= Int64(1)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } /// post-on-join predicates on a column common to both sides is pushed to both sides @@ -1713,7 +1702,7 @@ mod tests { \n TableScan: test, full_filters=[test.a <= Int64(1)]\ \n Projection: test2.a\ \n TableScan: test2, full_filters=[test2.a <= Int64(1)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } /// post-using-join predicates on a column common to both sides is pushed to both sides @@ -1750,7 +1739,7 @@ mod tests { \n TableScan: test, full_filters=[test.a <= Int64(1)]\ \n 
Projection: test2.a\ \n TableScan: test2, full_filters=[test2.a <= Int64(1)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } /// post-join predicates with columns from both sides are converted to join filterss @@ -1792,7 +1781,7 @@ mod tests { \n TableScan: test\ \n Projection: test2.a, test2.b\ \n TableScan: test2"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } /// post-join predicates with columns from one side of a join are pushed only to that side @@ -1834,7 +1823,7 @@ mod tests { \n TableScan: test, full_filters=[test.b <= Int64(1)]\ \n Projection: test2.a, test2.c\ \n TableScan: test2"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } /// post-join predicates on the right side of a left join are not duplicated @@ -1873,7 +1862,7 @@ mod tests { \n TableScan: test\ \n Projection: test2.a\ \n TableScan: test2"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } /// post-join predicates on the left side of a right join are not duplicated @@ -1911,7 +1900,7 @@ mod tests { \n TableScan: test\ \n Projection: test2.a\ \n TableScan: test2"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } /// post-left-join predicate on a column common to both sides is only pushed to the left side @@ -1949,7 +1938,7 @@ mod tests { \n TableScan: test, full_filters=[test.a <= Int64(1)]\ \n Projection: test2.a\ \n TableScan: test2"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } /// post-right-join predicate on a column common to both sides is only pushed to the right side @@ -1987,7 +1976,7 @@ mod tests { \n TableScan: test\ \n Projection: test2.a\ \n TableScan: test2, full_filters=[test2.a <= Int64(1)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } /// single table predicate parts of ON condition should be pushed to both inputs @@ -2030,7 +2019,7 @@ mod tests { \n TableScan: test, full_filters=[test.c > UInt32(1)]\ \n Projection: test2.a, test2.b, test2.c\ \n TableScan: test2, full_filters=[test2.c > UInt32(4)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } /// join filter should be completely removed after pushdown @@ -2072,7 +2061,7 @@ mod tests { \n TableScan: test, full_filters=[test.b > UInt32(1)]\ \n Projection: test2.a, test2.b, test2.c\ \n TableScan: test2, full_filters=[test2.c > UInt32(4)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } /// predicate on join key in filter expression should be pushed down to both inputs @@ -2112,7 +2101,7 @@ mod tests { \n TableScan: test, full_filters=[test.a > UInt32(1)]\ \n Projection: test2.b\ \n TableScan: test2, full_filters=[test2.b > UInt32(1)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } /// single table predicate parts of ON condition should be pushed to right input @@ -2155,7 +2144,7 @@ mod tests { \n TableScan: test\ \n Projection: test2.a, test2.b, test2.c\ \n TableScan: test2, full_filters=[test2.c > UInt32(4)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } /// single table predicate parts of ON condition should be pushed to left input @@ -2198,7 +2187,7 @@ mod tests { \n TableScan: test, full_filters=[test.a > UInt32(1)]\ \n Projection: test2.a, test2.b, test2.c\ \n TableScan: test2"; - 
assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } /// single table predicate parts of ON condition should not be pushed @@ -2236,7 +2225,7 @@ mod tests { ); let expected = &format!("{plan:?}"); - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } struct PushDownProvider { @@ -2295,7 +2284,7 @@ mod tests { let expected = "\ TableScan: test, full_filters=[a = Int64(1)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -2306,7 +2295,7 @@ mod tests { let expected = "\ Filter: a = Int64(1)\ \n TableScan: test, partial_filters=[a = Int64(1)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -2314,7 +2303,7 @@ mod tests { let plan = table_scan_with_pushdown_provider(TableProviderFilterPushDown::Inexact)?; - let optimised_plan = PushDownFilter::new() + let optimized_plan = PushDownFilter::new() .try_optimize(&plan, &OptimizerContext::new()) .expect("failed to optimize plan") .unwrap(); @@ -2325,7 +2314,7 @@ mod tests { // Optimizing the same plan multiple times should produce the same plan // each time. - assert_optimized_plan_eq(&optimised_plan, expected) + assert_optimized_plan_eq(optimized_plan, expected) } #[test] @@ -2336,7 +2325,7 @@ mod tests { let expected = "\ Filter: a = Int64(1)\ \n TableScan: test"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -2365,7 +2354,7 @@ mod tests { \n Filter: a = Int64(10) AND b > Int64(11)\ \n TableScan: test projection=[a], partial_filters=[a = Int64(10), b > Int64(11)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -2396,7 +2385,7 @@ Projection: a, b "# .trim(); - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -2424,7 +2413,7 @@ Projection: a, b \n TableScan: test, full_filters=[test.a > Int64(10), test.c > Int64(10)]\ "; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -2456,7 +2445,7 @@ Projection: a, b \n TableScan: test, full_filters=[test.a > Int64(10), test.c > Int64(10)]\ "; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -2481,7 +2470,7 @@ Projection: a, b Projection: test.a AS b, test.c AS d\ \n TableScan: test, full_filters=[test.a > Int64(10), test.c > Int64(10)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } /// predicate on join key in filter expression should be pushed down to both inputs @@ -2521,7 +2510,7 @@ Projection: a, b \n TableScan: test, full_filters=[test.a > UInt32(1)]\ \n Projection: test2.b AS d\ \n TableScan: test2, full_filters=[test2.b > UInt32(1)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -2550,7 +2539,7 @@ Projection: a, b Projection: test.a AS b, test.c\ \n TableScan: test, full_filters=[test.a IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -2582,7 +2571,7 @@ Projection: a, b \n Projection: test.a AS b, test.c\ \n TableScan: test, full_filters=[test.a IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -2618,7 +2607,7 @@ Projection: a, b \n Subquery:\ \n Projection: sq.c\ \n TableScan: 
sq"; - assert_optimized_plan_eq(&plan, expected_after) + assert_optimized_plan_eq(plan, expected_after) } #[test] @@ -2651,7 +2640,7 @@ Projection: a, b \n Projection: Int64(0) AS a\ \n Filter: Int64(0) = Int64(1)\ \n EmptyRelation"; - assert_optimized_plan_eq(&plan, expected_after) + assert_optimized_plan_eq(plan, expected_after) } #[test] @@ -2679,14 +2668,14 @@ Projection: a, b \n TableScan: test, full_filters=[test.b > UInt32(1) OR test.c < UInt32(10)]\ \n Projection: test1.a AS d, test1.a AS e\ \n TableScan: test1"; - assert_optimized_plan_eq_with_rewrite_predicate(&plan, expected)?; + assert_optimized_plan_eq_with_rewrite_predicate(plan.clone(), expected)?; // Originally global state which can help to avoid duplicate Filters been generated and pushed down. // Now the global state is removed. Need to double confirm that avoid duplicate Filters. let optimized_plan = PushDownFilter::new() .try_optimize(&plan, &OptimizerContext::new())? .expect("failed to optimize plan"); - assert_optimized_plan_eq(&optimized_plan, expected) + assert_optimized_plan_eq(optimized_plan, expected) } #[test] @@ -2727,7 +2716,7 @@ Projection: a, b \n TableScan: test1, full_filters=[test1.b > UInt32(1)]\ \n Projection: test2.a, test2.b\ \n TableScan: test2, full_filters=[test2.b > UInt32(2)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -2768,7 +2757,7 @@ Projection: a, b \n TableScan: test1, full_filters=[test1.b > UInt32(1)]\ \n Projection: test2.a, test2.b\ \n TableScan: test2, full_filters=[test2.b > UInt32(2)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -2814,7 +2803,7 @@ Projection: a, b \n TableScan: test1\ \n Projection: test2.a, test2.b\ \n TableScan: test2, full_filters=[test2.b > UInt32(2)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -2859,7 +2848,7 @@ Projection: a, b \n TableScan: test1, full_filters=[test1.b > UInt32(1)]\ \n Projection: test2.a, test2.b\ \n TableScan: test2"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[derive(Debug)] @@ -2919,7 +2908,7 @@ Projection: a, b \n Projection: test1.a, SUM(test1.b), TestScalarUDF() + Int32(1) AS r\ \n Aggregate: groupBy=[[test1.a]], aggr=[[SUM(test1.b)]]\ \n TableScan: test1, full_filters=[test1.a > Int32(5)]"; - assert_optimized_plan_eq(&plan, expected_after) + assert_optimized_plan_eq(plan, expected_after) } #[test] @@ -2965,6 +2954,6 @@ Projection: a, b \n Inner Join: test1.a = test2.a\ \n TableScan: test1\ \n TableScan: test2"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } } diff --git a/datafusion/optimizer/src/push_down_limit.rs b/datafusion/optimizer/src/push_down_limit.rs index cca6c3fd9bd1..6f1d7bf97cfe 100644 --- a/datafusion/optimizer/src/push_down_limit.rs +++ b/datafusion/optimizer/src/push_down_limit.rs @@ -285,7 +285,7 @@ mod test { max, }; - fn assert_optimized_plan_equal(plan: &LogicalPlan, expected: &str) -> Result<()> { + fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> { assert_optimized_plan_eq(Arc::new(PushDownLimit::new()), plan, expected) } @@ -304,7 +304,7 @@ mod test { \n Limit: skip=0, fetch=1000\ \n TableScan: test, fetch=1000"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -322,7 +322,7 @@ mod test { let expected = "Limit: skip=0, fetch=10\ \n TableScan: test, fetch=10"; - 
assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -339,7 +339,7 @@ mod test { \n Aggregate: groupBy=[[test.a]], aggr=[[MAX(test.b)]]\ \n TableScan: test"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -359,7 +359,7 @@ mod test { \n Limit: skip=0, fetch=1000\ \n TableScan: test, fetch=1000"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -376,7 +376,7 @@ mod test { \n Sort: test.a, fetch=10\ \n TableScan: test"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -393,7 +393,7 @@ mod test { \n Sort: test.a, fetch=15\ \n TableScan: test"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -412,7 +412,7 @@ mod test { \n Limit: skip=0, fetch=1000\ \n TableScan: test, fetch=1000"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -427,7 +427,7 @@ mod test { let expected = "Limit: skip=10, fetch=None\ \n TableScan: test"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -445,7 +445,7 @@ mod test { \n Limit: skip=10, fetch=1000\ \n TableScan: test, fetch=1010"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -462,7 +462,7 @@ mod test { \n Limit: skip=10, fetch=990\ \n TableScan: test, fetch=1000"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -479,7 +479,7 @@ mod test { \n Limit: skip=10, fetch=1000\ \n TableScan: test, fetch=1010"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -495,7 +495,7 @@ mod test { let expected = "Limit: skip=10, fetch=10\ \n TableScan: test, fetch=20"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -512,7 +512,7 @@ mod test { \n Aggregate: groupBy=[[test.a]], aggr=[[MAX(test.b)]]\ \n TableScan: test"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -532,7 +532,7 @@ mod test { \n Limit: skip=0, fetch=1010\ \n TableScan: test, fetch=1010"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -556,7 +556,7 @@ mod test { \n TableScan: test\ \n TableScan: test2"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -580,7 +580,7 @@ mod test { \n TableScan: test\ \n TableScan: test2"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -609,7 +609,7 @@ mod test { \n Projection: test2.a\ \n TableScan: test2"; - assert_optimized_plan_equal(&outer_query, expected) + assert_optimized_plan_equal(outer_query, expected) } #[test] @@ -638,7 +638,7 @@ mod test { \n Projection: test2.a\ \n TableScan: test2"; - assert_optimized_plan_equal(&outer_query, expected) + assert_optimized_plan_equal(outer_query, expected) } #[test] @@ -664,7 +664,7 @@ mod test { \n Limit: skip=0, fetch=1000\ \n TableScan: test2, fetch=1000"; - assert_optimized_plan_equal(&plan, expected)?; + assert_optimized_plan_equal(plan, expected)?; let plan = LogicalPlanBuilder::from(table_scan_1.clone()) .join( @@ -683,7 +683,7 @@ mod test { \n Limit: skip=0, fetch=1000\ \n TableScan: 
test2, fetch=1000"; - assert_optimized_plan_equal(&plan, expected)?; + assert_optimized_plan_equal(plan, expected)?; let plan = LogicalPlanBuilder::from(table_scan_1.clone()) .join( @@ -702,7 +702,7 @@ mod test { \n Limit: skip=0, fetch=1000\ \n TableScan: test2, fetch=1000"; - assert_optimized_plan_equal(&plan, expected)?; + assert_optimized_plan_equal(plan, expected)?; let plan = LogicalPlanBuilder::from(table_scan_1.clone()) .join( @@ -720,7 +720,7 @@ mod test { \n TableScan: test, fetch=1000\ \n TableScan: test2"; - assert_optimized_plan_equal(&plan, expected)?; + assert_optimized_plan_equal(plan, expected)?; let plan = LogicalPlanBuilder::from(table_scan_1.clone()) .join( @@ -738,7 +738,7 @@ mod test { \n TableScan: test, fetch=1000\ \n TableScan: test2"; - assert_optimized_plan_equal(&plan, expected)?; + assert_optimized_plan_equal(plan, expected)?; let plan = LogicalPlanBuilder::from(table_scan_1.clone()) .join( @@ -756,7 +756,7 @@ mod test { \n Limit: skip=0, fetch=1000\ \n TableScan: test2, fetch=1000"; - assert_optimized_plan_equal(&plan, expected)?; + assert_optimized_plan_equal(plan, expected)?; let plan = LogicalPlanBuilder::from(table_scan_1) .join( @@ -774,7 +774,7 @@ mod test { \n Limit: skip=0, fetch=1000\ \n TableScan: test2, fetch=1000"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -799,7 +799,7 @@ mod test { \n TableScan: test, fetch=1000\ \n TableScan: test2"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -824,7 +824,7 @@ mod test { \n TableScan: test, fetch=1010\ \n TableScan: test2"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -849,7 +849,7 @@ mod test { \n Limit: skip=0, fetch=1000\ \n TableScan: test2, fetch=1000"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -874,7 +874,7 @@ mod test { \n Limit: skip=0, fetch=1010\ \n TableScan: test2, fetch=1010"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -894,7 +894,7 @@ mod test { \n Limit: skip=0, fetch=1000\ \n TableScan: test2, fetch=1000"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -914,7 +914,7 @@ mod test { \n Limit: skip=0, fetch=2000\ \n TableScan: test2, fetch=2000"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -929,7 +929,7 @@ mod test { let expected = "Limit: skip=1000, fetch=0\ \n TableScan: test, fetch=0"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -944,7 +944,7 @@ mod test { let expected = "Limit: skip=1000, fetch=0\ \n TableScan: test, fetch=0"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -961,6 +961,6 @@ mod test { \n Limit: skip=1000, fetch=0\ \n TableScan: test, fetch=0"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } } diff --git a/datafusion/optimizer/src/push_down_projection.rs b/datafusion/optimizer/src/push_down_projection.rs index ae57ed9e5a34..2f578094b3bc 100644 --- a/datafusion/optimizer/src/push_down_projection.rs +++ b/datafusion/optimizer/src/push_down_projection.rs @@ -24,7 +24,7 @@ mod tests { use crate::optimize_projections::OptimizeProjections; use crate::optimizer::Optimizer; use crate::test::*; - 
use crate::OptimizerContext; + use crate::{OptimizerContext, OptimizerRule}; use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::{Column, DFSchema, Result}; use datafusion_expr::builder::table_scan_with_filters; @@ -48,7 +48,7 @@ mod tests { let expected = "Aggregate: groupBy=[[]], aggr=[[MAX(test.b)]]\ \n TableScan: test projection=[b]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -62,7 +62,7 @@ mod tests { let expected = "Aggregate: groupBy=[[test.c]], aggr=[[MAX(test.b)]]\ \n TableScan: test projection=[b, c]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -78,7 +78,7 @@ mod tests { \n SubqueryAlias: a\ \n TableScan: test projection=[b, c]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -95,7 +95,7 @@ mod tests { \n Filter: test.c > Int32(1)\ \n TableScan: test projection=[b, c]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -120,7 +120,7 @@ mod tests { Aggregate: groupBy=[[]], aggr=[[MAX(m4.tag.one) AS tag.one]]\ \n TableScan: m4 projection=[tag.one]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -134,7 +134,7 @@ mod tests { let expected = "Projection: test.a, test.c, test.b\ \n TableScan: test projection=[a, b, c]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -144,7 +144,7 @@ mod tests { let plan = table_scan(Some("test"), &schema, Some(vec![1, 0, 2]))?.build()?; let expected = "TableScan: test projection=[b, a, c]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -157,7 +157,7 @@ mod tests { let expected = "Projection: test.a, test.b\ \n TableScan: test projection=[b, a]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -170,7 +170,7 @@ mod tests { let expected = "Projection: test.c, test.b, test.a\ \n TableScan: test projection=[a, b, c]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -192,7 +192,7 @@ mod tests { \n Filter: test.c > Int32(1)\ \n Projection: test.c, test.b, test.a\ \n TableScan: test projection=[a, b, c]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -212,7 +212,7 @@ mod tests { \n TableScan: test projection=[a, b]\ \n TableScan: test2 projection=[c1]"; - let optimized_plan = optimize(&plan)?; + let optimized_plan = optimize(plan)?; let formatted_plan = format!("{optimized_plan:?}"); assert_eq!(formatted_plan, expected); @@ -264,7 +264,7 @@ mod tests { \n TableScan: test projection=[a, b]\ \n TableScan: test2 projection=[c1]"; - let optimized_plan = optimize(&plan)?; + let optimized_plan = optimize(plan)?; let formatted_plan = format!("{optimized_plan:?}"); assert_eq!(formatted_plan, expected); @@ -314,7 +314,7 @@ mod tests { \n TableScan: test projection=[a, b]\ \n TableScan: test2 projection=[a]"; - let optimized_plan = optimize(&plan)?; + let optimized_plan = optimize(plan)?; let formatted_plan = format!("{optimized_plan:?}"); assert_eq!(formatted_plan, expected); @@ -358,7 +358,7 @@ mod tests { let expected = "Projection: CAST(test.c AS Float64)\ \n TableScan: test projection=[c]"; - assert_optimized_plan_eq(&projection, expected) + assert_optimized_plan_eq(projection, expected) } #[test] @@ -374,7 +374,7 @@ 
mod tests { let expected = "TableScan: test projection=[a, b]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -395,7 +395,7 @@ mod tests { let expected = "TableScan: test projection=[a, b]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -415,7 +415,7 @@ mod tests { \n Projection: test.c, test.a\ \n TableScan: test projection=[a, c]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -424,7 +424,7 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan).build()?; // should expand projection to all columns without projection let expected = "TableScan: test projection=[a, b, c]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -435,7 +435,7 @@ mod tests { .build()?; let expected = "Projection: Int64(1), Int64(2)\ \n TableScan: test projection=[]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } /// tests that it removes unused columns in projections @@ -454,14 +454,14 @@ mod tests { assert_fields_eq(&plan, vec!["c", "MAX(test.a)"]); - let plan = optimize(&plan).expect("failed to optimize plan"); + let plan = optimize(plan).expect("failed to optimize plan"); let expected = "\ Aggregate: groupBy=[[test.c]], aggr=[[MAX(test.a)]]\ \n Filter: test.c > Int32(1)\ \n Projection: test.c, test.a\ \n TableScan: test projection=[a, c]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } /// tests that it removes un-needed projections @@ -483,7 +483,7 @@ mod tests { Projection: Int32(1) AS a\ \n TableScan: test projection=[]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -512,7 +512,7 @@ mod tests { Projection: Int32(1) AS a\ \n TableScan: test projection=[], full_filters=[b = Int32(1)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } /// tests that optimizing twice yields same plan @@ -525,9 +525,9 @@ mod tests { .project(vec![lit(1).alias("a")])? 
.build()?; - let optimized_plan1 = optimize(&plan).expect("failed to optimize plan"); + let optimized_plan1 = optimize(plan).expect("failed to optimize plan"); let optimized_plan2 = - optimize(&optimized_plan1).expect("failed to optimize plan"); + optimize(optimized_plan1.clone()).expect("failed to optimize plan"); let formatted_plan1 = format!("{optimized_plan1:?}"); let formatted_plan2 = format!("{optimized_plan2:?}"); @@ -556,7 +556,7 @@ mod tests { \n Aggregate: groupBy=[[test.a, test.c]], aggr=[[MAX(test.b)]]\ \n TableScan: test projection=[a, b, c]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -582,7 +582,7 @@ mod tests { let expected = "Aggregate: groupBy=[[test.a]], aggr=[[COUNT(test.b), COUNT(test.b) FILTER (WHERE test.c > Int32(42)) AS count2]]\ \n TableScan: test projection=[a, b, c]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -599,7 +599,7 @@ mod tests { \n Distinct:\ \n TableScan: test projection=[a, b]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -638,25 +638,23 @@ mod tests { \n WindowAggr: windowExpr=[[MAX(test.a) PARTITION BY [test.b] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]]\ \n TableScan: test projection=[a, b]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } - fn assert_optimized_plan_eq(plan: &LogicalPlan, expected: &str) -> Result<()> { + fn assert_optimized_plan_eq(plan: LogicalPlan, expected: &str) -> Result<()> { let optimized_plan = optimize(plan).expect("failed to optimize plan"); let formatted_plan = format!("{optimized_plan:?}"); assert_eq!(formatted_plan, expected); Ok(()) } - fn optimize(plan: &LogicalPlan) -> Result { + fn optimize(plan: LogicalPlan) -> Result { let optimizer = Optimizer::with_rules(vec![Arc::new(OptimizeProjections::new())]); - let optimized_plan = optimizer - .optimize_recursively( - optimizer.rules.first().unwrap(), - plan, - &OptimizerContext::new(), - )? 
- .unwrap_or_else(|| plan.clone()); + let optimized_plan = + optimizer.optimize(plan, &OptimizerContext::new(), observe)?; + Ok(optimized_plan) } + + fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {} } diff --git a/datafusion/optimizer/src/replace_distinct_aggregate.rs b/datafusion/optimizer/src/replace_distinct_aggregate.rs index 752915be69c0..f464506057ff 100644 --- a/datafusion/optimizer/src/replace_distinct_aggregate.rs +++ b/datafusion/optimizer/src/replace_distinct_aggregate.rs @@ -172,7 +172,7 @@ mod tests { assert_optimized_plan_eq( Arc::new(ReplaceDistinctWithAggregate::new()), - &plan, + plan, expected, ) } @@ -195,7 +195,7 @@ mod tests { assert_optimized_plan_eq( Arc::new(ReplaceDistinctWithAggregate::new()), - &plan, + plan, expected, ) } diff --git a/datafusion/optimizer/src/scalar_subquery_to_join.rs b/datafusion/optimizer/src/scalar_subquery_to_join.rs index a2c4eabcaae6..a8999f9c1d3c 100644 --- a/datafusion/optimizer/src/scalar_subquery_to_join.rs +++ b/datafusion/optimizer/src/scalar_subquery_to_join.rs @@ -429,7 +429,7 @@ mod tests { \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; assert_multi_rules_optimized_plan_eq_display_indent( vec![Arc::new(ScalarSubqueryToJoin::new())], - &plan, + plan, expected, ); Ok(()) @@ -485,7 +485,7 @@ mod tests { \n TableScan: lineitem [l_orderkey:Int64, l_partkey:Int64, l_suppkey:Int64, l_linenumber:Int32, l_quantity:Float64, l_extendedprice:Float64]"; assert_multi_rules_optimized_plan_eq_display_indent( vec![Arc::new(ScalarSubqueryToJoin::new())], - &plan, + plan, expected, ); Ok(()) @@ -523,7 +523,7 @@ mod tests { assert_multi_rules_optimized_plan_eq_display_indent( vec![Arc::new(ScalarSubqueryToJoin::new())], - &plan, + plan, expected, ); Ok(()) @@ -559,7 +559,7 @@ mod tests { \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; assert_multi_rules_optimized_plan_eq_display_indent( vec![Arc::new(ScalarSubqueryToJoin::new())], - &plan, + plan, expected, ); Ok(()) @@ -593,7 +593,7 @@ mod tests { assert_multi_rules_optimized_plan_eq_display_indent( vec![Arc::new(ScalarSubqueryToJoin::new())], - &plan, + plan, expected, ); Ok(()) @@ -732,7 +732,7 @@ mod tests { assert_multi_rules_optimized_plan_eq_display_indent( vec![Arc::new(ScalarSubqueryToJoin::new())], - &plan, + plan, expected, ); Ok(()) @@ -798,7 +798,7 @@ mod tests { assert_multi_rules_optimized_plan_eq_display_indent( vec![Arc::new(ScalarSubqueryToJoin::new())], - &plan, + plan, expected, ); Ok(()) @@ -837,7 +837,7 @@ mod tests { assert_multi_rules_optimized_plan_eq_display_indent( vec![Arc::new(ScalarSubqueryToJoin::new())], - &plan, + plan, expected, ); Ok(()) @@ -877,7 +877,7 @@ mod tests { assert_multi_rules_optimized_plan_eq_display_indent( vec![Arc::new(ScalarSubqueryToJoin::new())], - &plan, + plan, expected, ); Ok(()) @@ -910,7 +910,7 @@ mod tests { assert_multi_rules_optimized_plan_eq_display_indent( vec![Arc::new(ScalarSubqueryToJoin::new())], - &plan, + plan, expected, ); Ok(()) @@ -942,7 +942,7 @@ mod tests { assert_multi_rules_optimized_plan_eq_display_indent( vec![Arc::new(ScalarSubqueryToJoin::new())], - &plan, + plan, expected, ); Ok(()) @@ -973,7 +973,7 @@ mod tests { assert_multi_rules_optimized_plan_eq_display_indent( vec![Arc::new(ScalarSubqueryToJoin::new())], - &plan, + plan, expected, ); Ok(()) @@ -1030,7 +1030,7 @@ mod tests { assert_multi_rules_optimized_plan_eq_display_indent( vec![Arc::new(ScalarSubqueryToJoin::new())], - &plan, + plan, 
expected, ); Ok(()) @@ -1079,7 +1079,7 @@ mod tests { assert_multi_rules_optimized_plan_eq_display_indent( vec![Arc::new(ScalarSubqueryToJoin::new())], - &plan, + plan, expected, ); Ok(()) diff --git a/datafusion/optimizer/src/single_distinct_to_groupby.rs b/datafusion/optimizer/src/single_distinct_to_groupby.rs index 076bf4e24296..602994a9e3e2 100644 --- a/datafusion/optimizer/src/single_distinct_to_groupby.rs +++ b/datafusion/optimizer/src/single_distinct_to_groupby.rs @@ -313,7 +313,7 @@ mod tests { min, sum, AggregateFunction, }; - fn assert_optimized_plan_equal(plan: &LogicalPlan, expected: &str) -> Result<()> { + fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> { assert_optimized_plan_eq_display_indent( Arc::new(SingleDistinctToGroupBy::new()), plan, @@ -335,7 +335,7 @@ mod tests { "Aggregate: groupBy=[[]], aggr=[[MAX(test.b)]] [MAX(test.b):UInt32;N]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -352,7 +352,7 @@ mod tests { \n Aggregate: groupBy=[[test.b AS alias1]], aggr=[[]] [alias1:UInt32]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } // Currently this optimization is disabled for CUBE/ROLLUP/GROUPING SET @@ -373,7 +373,7 @@ mod tests { let expected = "Aggregate: groupBy=[[GROUPING SETS ((test.a), (test.b))]], aggr=[[COUNT(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, COUNT(DISTINCT test.c):Int64;N]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } // Currently this optimization is disabled for CUBE/ROLLUP/GROUPING SET @@ -391,7 +391,7 @@ mod tests { let expected = "Aggregate: groupBy=[[CUBE (test.a, test.b)]], aggr=[[COUNT(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, COUNT(DISTINCT test.c):Int64;N]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } // Currently this optimization is disabled for CUBE/ROLLUP/GROUPING SET @@ -410,7 +410,7 @@ mod tests { let expected = "Aggregate: groupBy=[[ROLLUP (test.a, test.b)]], aggr=[[COUNT(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, COUNT(DISTINCT test.c):Int64;N]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -426,7 +426,7 @@ mod tests { \n Aggregate: groupBy=[[Int32(2) * test.b AS alias1]], aggr=[[]] [alias1:Int32]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -443,7 +443,7 @@ mod tests { \n Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[]] [a:UInt32, alias1:UInt32]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -461,7 +461,7 @@ mod tests { let expected = "Aggregate: groupBy=[[test.a]], aggr=[[COUNT(DISTINCT test.b), COUNT(DISTINCT test.c)]] [a:UInt32, COUNT(DISTINCT test.b):Int64;N, COUNT(DISTINCT test.c):Int64;N]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -490,7 +490,7 @@ mod tests { \n Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[]] [a:UInt32, 
alias1:UInt32]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -508,7 +508,7 @@ mod tests { let expected = "Aggregate: groupBy=[[test.a]], aggr=[[COUNT(DISTINCT test.b), COUNT(test.c)]] [a:UInt32, COUNT(DISTINCT test.b):Int64;N, COUNT(test.c):Int64;N]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -525,7 +525,7 @@ mod tests { \n Aggregate: groupBy=[[test.a + Int32(1) AS group_alias_0, test.c AS alias1]], aggr=[[]] [group_alias_0:Int32, alias1:UInt32]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -555,7 +555,7 @@ mod tests { \n Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[SUM(test.c) AS alias2]] [a:UInt32, alias1:UInt32, alias2:UInt64;N]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -574,7 +574,7 @@ mod tests { \n Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[SUM(test.c) AS alias2, MAX(test.c) AS alias3]] [a:UInt32, alias1:UInt32, alias2:UInt64;N, alias3:UInt32;N]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -593,7 +593,7 @@ mod tests { \n Aggregate: groupBy=[[test.c, test.b AS alias1]], aggr=[[MIN(test.a) AS alias2]] [c:UInt32, alias1:UInt32, alias2:UInt32;N]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -616,7 +616,7 @@ mod tests { let expected = "Aggregate: groupBy=[[test.c]], aggr=[[SUM(test.a) FILTER (WHERE test.a > Int32(5)), COUNT(DISTINCT test.b)]] [c:UInt32, SUM(test.a) FILTER (WHERE test.a > Int32(5)):UInt64;N, COUNT(DISTINCT test.b):Int64;N]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -639,7 +639,7 @@ mod tests { let expected = "Aggregate: groupBy=[[test.c]], aggr=[[SUM(test.a), COUNT(DISTINCT test.a) FILTER (WHERE test.a > Int32(5))]] [c:UInt32, SUM(test.a):UInt64;N, COUNT(DISTINCT test.a) FILTER (WHERE test.a > Int32(5)):Int64;N]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -662,7 +662,7 @@ mod tests { let expected = "Aggregate: groupBy=[[test.c]], aggr=[[SUM(test.a) ORDER BY [test.a], COUNT(DISTINCT test.b)]] [c:UInt32, SUM(test.a) ORDER BY [test.a]:UInt64;N, COUNT(DISTINCT test.b):Int64;N]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -685,7 +685,7 @@ mod tests { let expected = "Aggregate: groupBy=[[test.c]], aggr=[[SUM(test.a), COUNT(DISTINCT test.a) ORDER BY [test.a]]] [c:UInt32, SUM(test.a):UInt64;N, COUNT(DISTINCT test.a) ORDER BY [test.a]:Int64;N]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -708,6 +708,6 @@ mod tests { let expected = "Aggregate: groupBy=[[test.c]], aggr=[[SUM(test.a), COUNT(DISTINCT test.a) FILTER (WHERE test.a > Int32(5)) 
ORDER BY [test.a]]] [c:UInt32, SUM(test.a):UInt64;N, COUNT(DISTINCT test.a) FILTER (WHERE test.a > Int32(5)) ORDER BY [test.a]:Int64;N]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } } diff --git a/datafusion/optimizer/src/test/mod.rs b/datafusion/optimizer/src/test/mod.rs index e691fe9a5351..cafda8359aa3 100644 --- a/datafusion/optimizer/src/test/mod.rs +++ b/datafusion/optimizer/src/test/mod.rs @@ -16,7 +16,7 @@ // under the License. use crate::analyzer::{Analyzer, AnalyzerRule}; -use crate::optimizer::{assert_schema_is_the_same, Optimizer}; +use crate::optimizer::Optimizer; use crate::{OptimizerContext, OptimizerRule}; use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::config::ConfigOptions; @@ -150,22 +150,19 @@ pub fn assert_analyzer_check_err( } } } + +fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {} + pub fn assert_optimized_plan_eq( rule: Arc, - plan: &LogicalPlan, + plan: LogicalPlan, expected: &str, ) -> Result<()> { - let optimizer = Optimizer::with_rules(vec![rule.clone()]); - let optimized_plan = optimizer - .optimize_recursively( - optimizer.rules.first().unwrap(), - plan, - &OptimizerContext::new(), - )? - .unwrap_or_else(|| plan.clone()); + // Apply the rule once + let opt_context = OptimizerContext::new().with_max_passes(1); - // Ensure schemas always match after an optimization - assert_schema_is_the_same(rule.name(), plan, &optimized_plan)?; + let optimizer = Optimizer::with_rules(vec![rule.clone()]); + let optimized_plan = optimizer.optimize(plan, &opt_context, observe)?; let formatted_plan = format!("{optimized_plan:?}"); assert_eq!(formatted_plan, expected); @@ -174,7 +171,7 @@ pub fn assert_optimized_plan_eq( pub fn assert_optimized_plan_eq_with_rules( rules: Vec>, - plan: &LogicalPlan, + plan: LogicalPlan, expected: &str, ) -> Result<()> { fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {} @@ -187,58 +184,44 @@ pub fn assert_optimized_plan_eq_with_rules( .expect("failed to optimize plan"); let formatted_plan = format!("{optimized_plan:?}"); assert_eq!(formatted_plan, expected); - assert_eq!(plan.schema(), optimized_plan.schema()); Ok(()) } pub fn assert_optimized_plan_eq_display_indent( rule: Arc, - plan: &LogicalPlan, + plan: LogicalPlan, expected: &str, ) { let optimizer = Optimizer::with_rules(vec![rule]); let optimized_plan = optimizer - .optimize_recursively( - optimizer.rules.first().unwrap(), - plan, - &OptimizerContext::new(), - ) - .expect("failed to optimize plan") - .unwrap_or_else(|| plan.clone()); + .optimize(plan, &OptimizerContext::new(), observe) + .expect("failed to optimize plan"); let formatted_plan = optimized_plan.display_indent_schema().to_string(); assert_eq!(formatted_plan, expected); } pub fn assert_multi_rules_optimized_plan_eq_display_indent( rules: Vec>, - plan: &LogicalPlan, + plan: LogicalPlan, expected: &str, ) { let optimizer = Optimizer::with_rules(rules); - let mut optimized_plan = plan.clone(); - for rule in &optimizer.rules { - optimized_plan = optimizer - .optimize_recursively(rule, &optimized_plan, &OptimizerContext::new()) - .expect("failed to optimize plan") - .unwrap_or_else(|| optimized_plan.clone()); - } + let optimized_plan = optimizer + .optimize(plan, &OptimizerContext::new(), observe) + .expect("failed to optimize plan"); let formatted_plan = optimized_plan.display_indent_schema().to_string(); assert_eq!(formatted_plan, expected); } pub fn assert_optimizer_err( rule: Arc, - 
plan: &LogicalPlan, + plan: LogicalPlan, expected: &str, ) { let optimizer = Optimizer::with_rules(vec![rule]); - let res = optimizer.optimize_recursively( - optimizer.rules.first().unwrap(), - plan, - &OptimizerContext::new(), - ); + let res = optimizer.optimize(plan, &OptimizerContext::new(), observe); match res { - Ok(plan) => assert_eq!(format!("{}", plan.unwrap().display_indent()), "An error"), + Ok(plan) => assert_eq!(format!("{}", plan.display_indent()), "An error"), Err(ref e) => { let actual = format!("{e}"); if expected.is_empty() || !actual.contains(expected) { @@ -250,16 +233,11 @@ pub fn assert_optimizer_err( pub fn assert_optimization_skipped( rule: Arc, - plan: &LogicalPlan, + plan: LogicalPlan, ) -> Result<()> { let optimizer = Optimizer::with_rules(vec![rule]); - let new_plan = optimizer - .optimize_recursively( - optimizer.rules.first().unwrap(), - plan, - &OptimizerContext::new(), - )? - .unwrap_or_else(|| plan.clone()); + let new_plan = optimizer.optimize(plan.clone(), &OptimizerContext::new(), observe)?; + assert_eq!( format!("{}", plan.display_indent()), format!("{}", new_plan.display_indent()) diff --git a/datafusion/optimizer/tests/optimizer_integration.rs b/datafusion/optimizer/tests/optimizer_integration.rs index 61d2535930b2..01db5e817c56 100644 --- a/datafusion/optimizer/tests/optimizer_integration.rs +++ b/datafusion/optimizer/tests/optimizer_integration.rs @@ -25,7 +25,7 @@ use datafusion_common::{plan_err, Result}; use datafusion_expr::{AggregateUDF, LogicalPlan, ScalarUDF, TableSource, WindowUDF}; use datafusion_optimizer::analyzer::Analyzer; use datafusion_optimizer::optimizer::Optimizer; -use datafusion_optimizer::{OptimizerConfig, OptimizerContext}; +use datafusion_optimizer::{OptimizerConfig, OptimizerContext, OptimizerRule}; use datafusion_sql::planner::{ContextProvider, SqlToRel}; use datafusion_sql::sqlparser::ast::Statement; use datafusion_sql::sqlparser::dialect::GenericDialect; @@ -315,9 +315,11 @@ fn test_sql(sql: &str) -> Result { let optimizer = Optimizer::new(); // analyze and optimize the logical plan let plan = analyzer.execute_and_check(&plan, config.options(), |_, _| {})?; - optimizer.optimize(&plan, &config, |_, _| {}) + optimizer.optimize(plan, &config, observe) } +fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {} + #[derive(Default)] struct MyContextProvider { options: ConfigOptions, diff --git a/datafusion/physical-plan/src/repartition/mod.rs b/datafusion/physical-plan/src/repartition/mod.rs index c0dbf5164e19..2ed5da7ced20 100644 --- a/datafusion/physical-plan/src/repartition/mod.rs +++ b/datafusion/physical-plan/src/repartition/mod.rs @@ -36,6 +36,7 @@ use crate::repartition::distributor_channels::{ channels, partition_aware_channels, DistributionReceiver, DistributionSender, }; use crate::sorts::streaming_merge; +use crate::stream::RecordBatchStreamAdapter; use crate::{DisplayFormatType, ExecutionPlan, Partitioning, PlanProperties, Statistics}; use arrow::array::{ArrayRef, UInt64Builder}; @@ -48,7 +49,7 @@ use datafusion_execution::TaskContext; use datafusion_physical_expr::{EquivalenceProperties, PhysicalExpr, PhysicalSortExpr}; use futures::stream::Stream; -use futures::{FutureExt, StreamExt}; +use futures::{FutureExt, StreamExt, TryStreamExt}; use hashbrown::HashMap; use log::trace; use parking_lot::Mutex; @@ -77,6 +78,102 @@ struct RepartitionExecState { abort_helper: Arc>>, } +impl RepartitionExecState { + fn new( + input: Arc, + partitioning: Partitioning, + metrics: ExecutionPlanMetricsSet, + preserve_order: bool, 
+ name: String, + context: Arc<TaskContext>, + ) -> Self { + let num_input_partitions = input.output_partitioning().partition_count(); + let num_output_partitions = partitioning.partition_count(); + + let (txs, rxs) = if preserve_order { + let (txs, rxs) = + partition_aware_channels(num_input_partitions, num_output_partitions); + // Take transpose of senders and receivers. `state.channels` keeps track of entries per output partition + let txs = transpose(txs); + let rxs = transpose(rxs); + (txs, rxs) + } else { + // create one channel per *output* partition + // note we use a custom channel that ensures there is always data for each receiver + // but limits the amount of buffering if required. + let (txs, rxs) = channels(num_output_partitions); + // Clone the sender for each input partition + let txs = txs + .into_iter() + .map(|item| vec![item; num_input_partitions]) + .collect::<Vec<_>>(); + let rxs = rxs.into_iter().map(|item| vec![item]).collect::<Vec<_>>(); + (txs, rxs) + }; + + let mut channels = HashMap::with_capacity(txs.len()); + for (partition, (tx, rx)) in txs.into_iter().zip(rxs).enumerate() { + let reservation = Arc::new(Mutex::new( + MemoryConsumer::new(format!("{}[{partition}]", name)) + .register(context.memory_pool()), + )); + channels.insert(partition, (tx, rx, reservation)); + } + + // launch one async task per *input* partition + let mut spawned_tasks = Vec::with_capacity(num_input_partitions); + for i in 0..num_input_partitions { + let txs: HashMap<_, _> = channels + .iter() + .map(|(partition, (tx, _rx, reservation))| { + (*partition, (tx[i].clone(), Arc::clone(reservation))) + }) + .collect(); + + // TODO: metric input-output mapping is broken + let r_metrics = RepartitionMetrics::new(i, 0, &metrics); + + let input_task = SpawnedTask::spawn(RepartitionExec::pull_from_input( + input.clone(), + i, + txs.clone(), + partitioning.clone(), + r_metrics, + context.clone(), + )); + + // In a separate task, wait for each input to be done + // (and pass along any errors, including panic!s) + let wait_for_task = SpawnedTask::spawn(RepartitionExec::wait_for_task( + input_task, + txs.into_iter() + .map(|(partition, (tx, _reservation))| (partition, tx)) + .collect(), + )); + spawned_tasks.push(wait_for_task); + } + + Self { + channels, + abort_helper: Arc::new(spawned_tasks), + } + } +} + +/// Lazily initialized state +/// +/// Note that the state is initialized ONCE for all partitions by a single task (thread). +/// This may take a short while. It is also likely that multiple threads +/// call execute at the same time, because we have just started "target partitions" tasks, +/// which is commonly set to the number of CPU cores, and they all call execute at the same time. +/// +/// Thus, use a **tokio** `OnceCell` for this initialization so as not to waste CPU cycles +/// in a futex lock but instead allow other threads to do something useful. +/// +/// Uses a parking_lot `Mutex` to control other accesses, as they are very short lived +/// (e.g. removing channels on completion) and the overhead of `await` is not warranted. +type LazyState = Arc<tokio::sync::OnceCell<Mutex<RepartitionExecState>>>; + /// A utility that can be used to partition batches based on [`Partitioning`] pub struct BatchPartitioner { state: BatchPartitionerState, @@ -298,7 +395,7 @@ pub struct RepartitionExec { /// Partitioning scheme to use partitioning: Partitioning, /// Inner state that is initialized when the first output stream is created. - state: Arc<Mutex<RepartitionExecState>>, + state: LazyState, /// Execution metrics metrics: ExecutionPlanMetricsSet, /// Boolean flag to decide whether to preserve ordering.
If true means @@ -453,134 +550,104 @@ impl ExecutionPlan for RepartitionExec { self.name(), partition ); - // lock mutexes - let mut state = self.state.lock(); - - let num_input_partitions = self.input.output_partitioning().partition_count(); - let num_output_partitions = self.partitioning.partition_count(); - - // if this is the first partition to be invoked then we need to set up initial state - if state.channels.is_empty() { - let (txs, rxs) = if self.preserve_order { - let (txs, rxs) = - partition_aware_channels(num_input_partitions, num_output_partitions); - // Take transpose of senders and receivers. `state.channels` keeps track of entries per output partition - let txs = transpose(txs); - let rxs = transpose(rxs); - (txs, rxs) - } else { - // create one channel per *output* partition - // note we use a custom channel that ensures there is always data for each receiver - // but limits the amount of buffering if required. - let (txs, rxs) = channels(num_output_partitions); - // Clone sender for each input partitions - let txs = txs - .into_iter() - .map(|item| vec![item; num_input_partitions]) - .collect::>(); - let rxs = rxs.into_iter().map(|item| vec![item]).collect::>(); - (txs, rxs) - }; - for (partition, (tx, rx)) in txs.into_iter().zip(rxs).enumerate() { - let reservation = Arc::new(Mutex::new( - MemoryConsumer::new(format!("{}[{partition}]", self.name())) - .register(context.memory_pool()), - )); - state.channels.insert(partition, (tx, rx, reservation)); - } - // launch one async task per *input* partition - let mut spawned_tasks = Vec::with_capacity(num_input_partitions); - for i in 0..num_input_partitions { - let txs: HashMap<_, _> = state - .channels - .iter() - .map(|(partition, (tx, _rx, reservation))| { - (*partition, (tx[i].clone(), Arc::clone(reservation))) - }) - .collect(); - - let r_metrics = RepartitionMetrics::new(i, partition, &self.metrics); - - let input_task = SpawnedTask::spawn(Self::pull_from_input( - self.input.clone(), - i, - txs.clone(), - self.partitioning.clone(), - r_metrics, - context.clone(), - )); - - // In a separate task, wait for each input to be done - // (and pass along any errors, including panic!s) - let wait_for_task = SpawnedTask::spawn(Self::wait_for_task( - input_task, - txs.into_iter() - .map(|(partition, (tx, _reservation))| (partition, tx)) - .collect(), - )); - spawned_tasks.push(wait_for_task); - } + let lazy_state = Arc::clone(&self.state); + let input = Arc::clone(&self.input); + let partitioning = self.partitioning.clone(); + let metrics = self.metrics.clone(); + let preserve_order = self.preserve_order; + let name = self.name().to_owned(); + let schema = self.schema(); + let schema_captured = Arc::clone(&schema); + + // Get existing ordering to use for merging + let sort_exprs = self.sort_exprs().unwrap_or(&[]).to_owned(); + + let stream = futures::stream::once(async move { + let num_input_partitions = input.output_partitioning().partition_count(); + + let input_captured = Arc::clone(&input); + let metrics_captured = metrics.clone(); + let name_captured = name.clone(); + let context_captured = Arc::clone(&context); + let state = lazy_state + .get_or_init(|| async move { + Mutex::new(RepartitionExecState::new( + input_captured, + partitioning, + metrics_captured, + preserve_order, + name_captured, + context_captured, + )) + }) + .await; - state.abort_helper = Arc::new(spawned_tasks) - } + // lock scope + let (mut rx, reservation, abort_helper) = { + // lock mutexes + let mut state = state.lock(); - trace!( - "Before returning 
stream in {}::execute for partition: {}", - self.name(), - partition - ); + // now return stream for the specified *output* partition which will + // read from the channel + let (_tx, rx, reservation) = state + .channels + .remove(&partition) + .expect("partition not used yet"); - // now return stream for the specified *output* partition which will - // read from the channel - let (_tx, mut rx, reservation) = state - .channels - .remove(&partition) - .expect("partition not used yet"); + (rx, reservation, Arc::clone(&state.abort_helper)) + }; - if self.preserve_order { - // Store streams from all the input partitions: - let input_streams = rx - .into_iter() - .map(|receiver| { - Box::pin(PerPartitionStream { - schema: self.schema(), - receiver, - drop_helper: Arc::clone(&state.abort_helper), - reservation: reservation.clone(), - }) as SendableRecordBatchStream - }) - .collect::>(); - // Note that receiver size (`rx.len()`) and `num_input_partitions` are same. - - // Get existing ordering to use for merging - let sort_exprs = self.sort_exprs().unwrap_or(&[]); - - // Merge streams (while preserving ordering) coming from - // input partitions to this partition: - let fetch = None; - let merge_reservation = - MemoryConsumer::new(format!("{}[Merge {partition}]", self.name())) - .register(context.memory_pool()); - streaming_merge( - input_streams, - self.schema(), - sort_exprs, - BaselineMetrics::new(&self.metrics, partition), - context.session_config().batch_size(), - fetch, - merge_reservation, - ) - } else { - Ok(Box::pin(RepartitionStream { - num_input_partitions, - num_input_partitions_processed: 0, - schema: self.input.schema(), - input: rx.swap_remove(0), - drop_helper: Arc::clone(&state.abort_helper), - reservation, - })) - } + trace!( + "Before returning stream in {}::execute for partition: {}", + name, + partition + ); + + if preserve_order { + // Store streams from all the input partitions: + let input_streams = rx + .into_iter() + .map(|receiver| { + Box::pin(PerPartitionStream { + schema: Arc::clone(&schema_captured), + receiver, + drop_helper: Arc::clone(&abort_helper), + reservation: reservation.clone(), + }) as SendableRecordBatchStream + }) + .collect::>(); + // Note that receiver size (`rx.len()`) and `num_input_partitions` are same. 
+ + // Merge streams (while preserving ordering) coming from + // input partitions to this partition: + let fetch = None; + let merge_reservation = + MemoryConsumer::new(format!("{}[Merge {partition}]", name)) + .register(context.memory_pool()); + streaming_merge( + input_streams, + schema_captured, + &sort_exprs, + BaselineMetrics::new(&metrics, partition), + context.session_config().batch_size(), + fetch, + merge_reservation, + ) + } else { + Ok(Box::pin(RepartitionStream { + num_input_partitions, + num_input_partitions_processed: 0, + schema: input.schema(), + input: rx.swap_remove(0), + drop_helper: abort_helper, + reservation, + }) as SendableRecordBatchStream) + } + }) + .try_flatten(); + let stream = RecordBatchStreamAdapter::new(schema, stream); + Ok(Box::pin(stream)) } fn metrics(&self) -> Option { @@ -606,10 +673,7 @@ impl RepartitionExec { Ok(RepartitionExec { input, partitioning, - state: Arc::new(Mutex::new(RepartitionExecState { - channels: HashMap::new(), - abort_helper: Arc::new(Vec::new()), - })), + state: Default::default(), metrics: ExecutionPlanMetricsSet::new(), preserve_order, cache, @@ -951,6 +1015,7 @@ mod tests { use datafusion_execution::runtime_env::{RuntimeConfig, RuntimeEnv}; use futures::FutureExt; + use tokio::task::JoinSet; #[tokio::test] async fn one_to_many_round_robin() -> Result<()> { @@ -1240,7 +1305,10 @@ mod tests { std::mem::drop(output_stream0); // Now, start sending input - input.wait().await; + let mut background_task = JoinSet::new(); + background_task.spawn(async move { + input.wait().await; + }); // output stream 1 should *not* error and have one of the input batches let batches = crate::common::collect(output_stream1).await.unwrap(); @@ -1277,7 +1345,10 @@ mod tests { let input = Arc::new(make_barrier_exec()); let exec = RepartitionExec::try_new(input.clone(), partitioning.clone()).unwrap(); let output_stream1 = exec.execute(1, task_ctx.clone()).unwrap(); - input.wait().await; + let mut background_task = JoinSet::new(); + background_task.spawn(async move { + input.wait().await; + }); let batches_without_drop = crate::common::collect(output_stream1).await.unwrap(); // run some checks on the result @@ -1299,7 +1370,10 @@ mod tests { // now, purposely drop output stream 0 // *before* any outputs are produced std::mem::drop(output_stream0); - input.wait().await; + let mut background_task = JoinSet::new(); + background_task.spawn(async move { + input.wait().await; + }); let batches_with_drop = crate::common::collect(output_stream1).await.unwrap(); assert_eq!(batches_without_drop, batches_with_drop); diff --git a/datafusion/sqllogictest/test_files/join.slt b/datafusion/sqllogictest/test_files/join.slt index da9b4168e7e0..135ab8075425 100644 --- a/datafusion/sqllogictest/test_files/join.slt +++ b/datafusion/sqllogictest/test_files/join.slt @@ -587,7 +587,7 @@ FROM t1 ---- 11 11 11 -# subsequent inner join +# subsequent inner join query III rowsort SELECT t1.t1_id, t2.t2_id, t3.t3_id FROM t1
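The `LazyState` comment in the repartition changes above motivates building the state exactly once with a tokio `OnceCell`, while keeping a parking_lot `Mutex` for the short, non-async accesses that follow. Below is a minimal, self-contained sketch of that pattern; the names (`State`, `LazyState`, `build_state`) are illustrative stand-ins, not DataFusion's actual types.

```rust
use std::sync::Arc;

use parking_lot::Mutex;
use tokio::sync::OnceCell;

/// Stand-in for state that is expensive to build and must be built only once.
#[derive(Debug)]
struct State {
    channels: Vec<String>,
}

/// Shared, lazily-created state: async `OnceCell` for the one-time init,
/// sync `Mutex` for the short accesses that follow.
type LazyState = Arc<OnceCell<Mutex<State>>>;

/// The async initializer; only one caller of `get_or_init` actually runs it.
async fn build_state() -> Mutex<State> {
    tokio::task::yield_now().await; // simulate async setup work
    Mutex::new(State {
        channels: (0..4).map(|i| format!("channel-{i}")).collect(),
    })
}

#[tokio::main]
async fn main() {
    let lazy: LazyState = Arc::new(OnceCell::new());

    // Several tasks race to use the state; `get_or_init` runs `build_state`
    // once and every other task simply awaits the result.
    let mut handles = Vec::new();
    for partition in 0..4 {
        let lazy = Arc::clone(&lazy);
        handles.push(tokio::spawn(async move {
            let state = lazy.get_or_init(build_state).await;
            // Short critical section: no `.await` while the lock is held.
            let taken = state.lock().channels.pop();
            println!("partition {partition} got {taken:?}");
        }));
    }
    for handle in handles {
        handle.await.unwrap();
    }
}
```

The same split shows up in the rewritten `execute`: `get_or_init` builds `RepartitionExecState` once, and the subsequent `state.lock()` scope only removes the channel for the requested output partition.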
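The rewritten `execute` also defers all of that setup into a `futures::stream::once` future and flattens the stream it yields with `try_flatten`, so the caller gets a stream back immediately and the work only happens on first poll. A small sketch of that shape, assuming simplified item types (`Result<i32, String>` rather than record batches):

```rust
use futures::stream::{self, Stream, StreamExt, TryStreamExt};

/// Returns a stream immediately; the async setup inside `stream::once`
/// only runs when the returned stream is first polled.
fn execute() -> impl Stream<Item = Result<i32, String>> {
    stream::once(async {
        // Pretend this is expensive, fallible setup (channels, tasks, ...).
        let data = vec![1, 2, 3];
        // Yield the "real" stream; returning `Err` here would instead surface
        // to the caller as a single-element error stream.
        Ok::<_, String>(stream::iter(data.into_iter().map(Ok::<i32, String>)))
    })
    .try_flatten()
}

#[tokio::main]
async fn main() {
    let items: Vec<_> = execute().collect().await;
    println!("{items:?}"); // [Ok(1), Ok(2), Ok(3)]
}
```

In the actual operator the inner value is a `SendableRecordBatchStream`, and the flattened stream is wrapped in `RecordBatchStreamAdapter::new(schema, stream)` so the schema is still available up front.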