From 0a4d9a6c788c1e4ad340943492abb823bd31c4f9 Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Mon, 8 Apr 2024 11:28:59 +0200 Subject: [PATCH 1/5] Consistent LogicalPlan subquery handling in TreeNode::apply and TreeNode::visit (#9913) * fix * clippy * remove accidental extra apply * minor fixes * fix `LogicalPlan::apply_expressions()` and `LogicalPlan::map_subqueries()` * fix `LogicalPlan::visit_with_subqueries()` * Add deprecated LogicalPlan::inspect_expressions --------- Co-authored-by: Andrew Lamb --- datafusion/common/src/tree_node.rs | 3 +- datafusion/core/src/execution/context/mod.rs | 4 +- datafusion/expr/src/logical_plan/plan.rs | 558 ++++++++++++++---- datafusion/expr/src/tree_node/expr.rs | 2 +- datafusion/expr/src/tree_node/plan.rs | 53 +- datafusion/optimizer/src/analyzer/mod.rs | 15 +- datafusion/optimizer/src/analyzer/subquery.rs | 2 +- datafusion/optimizer/src/plan_signature.rs | 4 +- 8 files changed, 475 insertions(+), 166 deletions(-) diff --git a/datafusion/common/src/tree_node.rs b/datafusion/common/src/tree_node.rs index 8e088e7a0b56..42514537e28d 100644 --- a/datafusion/common/src/tree_node.rs +++ b/datafusion/common/src/tree_node.rs @@ -25,10 +25,9 @@ use crate::Result; /// These macros are used to determine continuation during transforming traversals. macro_rules! handle_transform_recursion { ($F_DOWN:expr, $F_CHILD:expr, $F_UP:expr) => {{ - #[allow(clippy::redundant_closure_call)] $F_DOWN? .transform_children(|n| n.map_children($F_CHILD))? 
- .transform_parent(|n| $F_UP(n)) + .transform_parent($F_UP) }}; } diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index f15c1c218db6..9e48c7b8a6f2 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -67,7 +67,7 @@ use datafusion_common::{ alias::AliasGenerator, config::{ConfigExtension, TableOptions}, exec_err, not_impl_err, plan_datafusion_err, plan_err, - tree_node::{TreeNode, TreeNodeRecursion, TreeNodeVisitor}, + tree_node::{TreeNodeRecursion, TreeNodeVisitor}, SchemaReference, TableReference, }; use datafusion_execution::registry::SerializerRegistry; @@ -2298,7 +2298,7 @@ impl SQLOptions { /// Return an error if the [`LogicalPlan`] has any nodes that are /// incompatible with this [`SQLOptions`]. pub fn verify_plan(&self, plan: &LogicalPlan) -> Result<()> { - plan.visit(&mut BadPlanVisitor::new(self))?; + plan.visit_with_subqueries(&mut BadPlanVisitor::new(self))?; Ok(()) } } diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 3d40dcae0e4b..4f55bbfe3f4d 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -34,8 +34,7 @@ use crate::logical_plan::extension::UserDefinedLogicalNode; use crate::logical_plan::{DmlStatement, Statement}; use crate::utils::{ enumerate_grouping_sets, exprlist_to_fields, find_out_reference_exprs, - grouping_set_expr_count, grouping_set_to_exprlist, inspect_expr_pre, - split_conjunction, + grouping_set_expr_count, grouping_set_to_exprlist, split_conjunction, }; use crate::{ build_join_schema, expr_vec_fmt, BinaryExpr, BuiltInWindowFunction, @@ -45,16 +44,19 @@ use crate::{ use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion_common::tree_node::{ - Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeVisitor, + Transformed, TransformedResult, TreeNode, TreeNodeIterator, 
TreeNodeRecursion, + TreeNodeRewriter, TreeNodeVisitor, }; use datafusion_common::{ - aggregate_functional_dependencies, internal_err, plan_err, Column, Constraints, - DFSchema, DFSchemaRef, DataFusionError, Dependency, FunctionalDependence, - FunctionalDependencies, ParamValues, Result, TableReference, UnnestOptions, + aggregate_functional_dependencies, internal_err, map_until_stop_and_collect, + plan_err, Column, Constraints, DFSchema, DFSchemaRef, DataFusionError, Dependency, + FunctionalDependence, FunctionalDependencies, ParamValues, Result, TableReference, + UnnestOptions, }; // backwards compatibility use crate::display::PgJsonVisitor; +use crate::tree_node::expr::transform_option_vec; pub use datafusion_common::display::{PlanType, StringifiedPlan, ToStringifiedPlan}; pub use datafusion_common::{JoinConstraint, JoinType}; @@ -248,9 +250,9 @@ impl LogicalPlan { /// DataFusion's optimizer attempts to optimize them away. pub fn expressions(self: &LogicalPlan) -> Vec { let mut exprs = vec![]; - self.inspect_expressions(|e| { + self.apply_expressions(|e| { exprs.push(e.clone()); - Ok(()) as Result<()> + Ok(TreeNodeRecursion::Continue) }) // closure always returns OK .unwrap(); @@ -261,13 +263,13 @@ impl LogicalPlan { /// logical plan nodes and all its descendant nodes. pub fn all_out_ref_exprs(self: &LogicalPlan) -> Vec { let mut exprs = vec![]; - self.inspect_expressions(|e| { + self.apply_expressions(|e| { find_out_reference_exprs(e).into_iter().for_each(|e| { if !exprs.contains(&e) { exprs.push(e) } }); - Ok(()) as Result<(), DataFusionError> + Ok(TreeNodeRecursion::Continue) }) // closure always returns OK .unwrap(); @@ -282,60 +284,81 @@ impl LogicalPlan { exprs } - /// Calls `f` on all expressions (non-recursively) in the current - /// logical plan node. This does not include expressions in any - /// children. 
+ #[deprecated(since = "37.0.0", note = "Use `apply_expressions` instead")] pub fn inspect_expressions(self: &LogicalPlan, mut f: F) -> Result<(), E> where F: FnMut(&Expr) -> Result<(), E>, { + let mut err = Ok(()); + self.apply_expressions(|e| { + if let Err(e) = f(e) { + // save the error for later (it may not be a DataFusionError + err = Err(e); + Ok(TreeNodeRecursion::Stop) + } else { + Ok(TreeNodeRecursion::Continue) + } + }) + // The closure always returns OK, so this will always too + .expect("no way to return error during recursion"); + + err + } + + /// Calls `f` on all expressions (non-recursively) in the current + /// logical plan node. This does not include expressions in any + /// children. + pub fn apply_expressions Result>( + &self, + mut f: F, + ) -> Result { match self { LogicalPlan::Projection(Projection { expr, .. }) => { - expr.iter().try_for_each(f) - } - LogicalPlan::Values(Values { values, .. }) => { - values.iter().flatten().try_for_each(f) + expr.iter().apply_until_stop(f) } + LogicalPlan::Values(Values { values, .. }) => values + .iter() + .apply_until_stop(|value| value.iter().apply_until_stop(&mut f)), LogicalPlan::Filter(Filter { predicate, .. }) => f(predicate), LogicalPlan::Repartition(Repartition { partitioning_scheme, .. }) => match partitioning_scheme { - Partitioning::Hash(expr, _) => expr.iter().try_for_each(f), - Partitioning::DistributeBy(expr) => expr.iter().try_for_each(f), - Partitioning::RoundRobinBatch(_) => Ok(()), + Partitioning::Hash(expr, _) | Partitioning::DistributeBy(expr) => { + expr.iter().apply_until_stop(f) + } + Partitioning::RoundRobinBatch(_) => Ok(TreeNodeRecursion::Continue), }, LogicalPlan::Window(Window { window_expr, .. }) => { - window_expr.iter().try_for_each(f) + window_expr.iter().apply_until_stop(f) } LogicalPlan::Aggregate(Aggregate { group_expr, aggr_expr, .. 
- }) => group_expr.iter().chain(aggr_expr.iter()).try_for_each(f), + }) => group_expr + .iter() + .chain(aggr_expr.iter()) + .apply_until_stop(f), // There are two part of expression for join, equijoin(on) and non-equijoin(filter). // 1. the first part is `on.len()` equijoin expressions, and the struct of each expr is `left-on = right-on`. // 2. the second part is non-equijoin(filter). LogicalPlan::Join(Join { on, filter, .. }) => { on.iter() + // TODO: why we need to create an `Expr::eq`? Cloning `Expr` is costly... // it not ideal to create an expr here to analyze them, but could cache it on the Join itself .map(|(l, r)| Expr::eq(l.clone(), r.clone())) - .try_for_each(|e| f(&e))?; - - if let Some(filter) = filter.as_ref() { - f(filter) - } else { - Ok(()) - } + .apply_until_stop(|e| f(&e))? + .visit_sibling(|| filter.iter().apply_until_stop(f)) } - LogicalPlan::Sort(Sort { expr, .. }) => expr.iter().try_for_each(f), + LogicalPlan::Sort(Sort { expr, .. }) => expr.iter().apply_until_stop(f), LogicalPlan::Extension(extension) => { // would be nice to avoid this copy -- maybe can // update extension to just observer Exprs - extension.node.expressions().iter().try_for_each(f) + extension.node.expressions().iter().apply_until_stop(f) } LogicalPlan::TableScan(TableScan { filters, .. }) => { - filters.iter().try_for_each(f) + filters.iter().apply_until_stop(f) } LogicalPlan::Unnest(Unnest { column, .. 
}) => { f(&Expr::Column(column.clone())) @@ -348,8 +371,8 @@ impl LogicalPlan { })) => on_expr .iter() .chain(select_expr.iter()) - .chain(sort_expr.clone().unwrap_or(vec![]).iter()) - .try_for_each(f), + .chain(sort_expr.iter().flatten()) + .apply_until_stop(f), // plans without expressions LogicalPlan::EmptyRelation(_) | LogicalPlan::RecursiveQuery(_) @@ -366,10 +389,225 @@ impl LogicalPlan { | LogicalPlan::Ddl(_) | LogicalPlan::Copy(_) | LogicalPlan::DescribeTable(_) - | LogicalPlan::Prepare(_) => Ok(()), + | LogicalPlan::Prepare(_) => Ok(TreeNodeRecursion::Continue), } } + pub fn map_expressions Result>>( + self, + mut f: F, + ) -> Result> { + Ok(match self { + LogicalPlan::Projection(Projection { + expr, + input, + schema, + }) => expr + .into_iter() + .map_until_stop_and_collect(f)? + .update_data(|expr| { + LogicalPlan::Projection(Projection { + expr, + input, + schema, + }) + }), + LogicalPlan::Values(Values { schema, values }) => values + .into_iter() + .map_until_stop_and_collect(|value| { + value.into_iter().map_until_stop_and_collect(&mut f) + })? + .update_data(|values| LogicalPlan::Values(Values { schema, values })), + LogicalPlan::Filter(Filter { predicate, input }) => f(predicate)? + .update_data(|predicate| { + LogicalPlan::Filter(Filter { predicate, input }) + }), + LogicalPlan::Repartition(Repartition { + input, + partitioning_scheme, + }) => match partitioning_scheme { + Partitioning::Hash(expr, usize) => expr + .into_iter() + .map_until_stop_and_collect(f)? + .update_data(|expr| Partitioning::Hash(expr, usize)), + Partitioning::DistributeBy(expr) => expr + .into_iter() + .map_until_stop_and_collect(f)? 
+ .update_data(Partitioning::DistributeBy), + Partitioning::RoundRobinBatch(_) => Transformed::no(partitioning_scheme), + } + .update_data(|partitioning_scheme| { + LogicalPlan::Repartition(Repartition { + input, + partitioning_scheme, + }) + }), + LogicalPlan::Window(Window { + input, + window_expr, + schema, + }) => window_expr + .into_iter() + .map_until_stop_and_collect(f)? + .update_data(|window_expr| { + LogicalPlan::Window(Window { + input, + window_expr, + schema, + }) + }), + LogicalPlan::Aggregate(Aggregate { + input, + group_expr, + aggr_expr, + schema, + }) => map_until_stop_and_collect!( + group_expr.into_iter().map_until_stop_and_collect(&mut f), + aggr_expr, + aggr_expr.into_iter().map_until_stop_and_collect(&mut f) + )? + .update_data(|(group_expr, aggr_expr)| { + LogicalPlan::Aggregate(Aggregate { + input, + group_expr, + aggr_expr, + schema, + }) + }), + + // There are two part of expression for join, equijoin(on) and non-equijoin(filter). + // 1. the first part is `on.len()` equijoin expressions, and the struct of each expr is `left-on = right-on`. + // 2. the second part is non-equijoin(filter). + LogicalPlan::Join(Join { + left, + right, + on, + filter, + join_type, + join_constraint, + schema, + null_equals_null, + }) => map_until_stop_and_collect!( + on.into_iter().map_until_stop_and_collect( + |on| map_until_stop_and_collect!(f(on.0), on.1, f(on.1)) + ), + filter, + filter.map_or(Ok::<_, DataFusionError>(Transformed::no(None)), |e| { + Ok(f(e)?.update_data(Some)) + }) + )? + .update_data(|(on, filter)| { + LogicalPlan::Join(Join { + left, + right, + on, + filter, + join_type, + join_constraint, + schema, + null_equals_null, + }) + }), + LogicalPlan::Sort(Sort { expr, input, fetch }) => expr + .into_iter() + .map_until_stop_and_collect(f)? 
+ .update_data(|expr| LogicalPlan::Sort(Sort { expr, input, fetch })), + LogicalPlan::Extension(Extension { node }) => { + // would be nice to avoid this copy -- maybe can + // update extension to just observer Exprs + node.expressions() + .into_iter() + .map_until_stop_and_collect(f)? + .update_data(|exprs| { + LogicalPlan::Extension(Extension { + node: UserDefinedLogicalNode::from_template( + node.as_ref(), + exprs.as_slice(), + node.inputs() + .into_iter() + .cloned() + .collect::>() + .as_slice(), + ), + }) + }) + } + LogicalPlan::TableScan(TableScan { + table_name, + source, + projection, + projected_schema, + filters, + fetch, + }) => filters + .into_iter() + .map_until_stop_and_collect(f)? + .update_data(|filters| { + LogicalPlan::TableScan(TableScan { + table_name, + source, + projection, + projected_schema, + filters, + fetch, + }) + }), + LogicalPlan::Unnest(Unnest { + input, + column, + schema, + options, + }) => f(Expr::Column(column))?.map_data(|column| match column { + Expr::Column(column) => Ok(LogicalPlan::Unnest(Unnest { + input, + column, + schema, + options, + })), + _ => internal_err!("Transformation should return Column"), + })?, + LogicalPlan::Distinct(Distinct::On(DistinctOn { + on_expr, + select_expr, + sort_expr, + input, + schema, + })) => map_until_stop_and_collect!( + on_expr.into_iter().map_until_stop_and_collect(&mut f), + select_expr, + select_expr.into_iter().map_until_stop_and_collect(&mut f), + sort_expr, + transform_option_vec(sort_expr, &mut f) + )? 
+ .update_data(|(on_expr, select_expr, sort_expr)| { + LogicalPlan::Distinct(Distinct::On(DistinctOn { + on_expr, + select_expr, + sort_expr, + input, + schema, + })) + }), + // plans without expressions + LogicalPlan::EmptyRelation(_) + | LogicalPlan::RecursiveQuery(_) + | LogicalPlan::Subquery(_) + | LogicalPlan::SubqueryAlias(_) + | LogicalPlan::Limit(_) + | LogicalPlan::Statement(_) + | LogicalPlan::CrossJoin(_) + | LogicalPlan::Analyze(_) + | LogicalPlan::Explain(_) + | LogicalPlan::Union(_) + | LogicalPlan::Distinct(Distinct::All(_)) + | LogicalPlan::Dml(_) + | LogicalPlan::Ddl(_) + | LogicalPlan::Copy(_) + | LogicalPlan::DescribeTable(_) + | LogicalPlan::Prepare(_) => Transformed::no(self), + }) + } + /// returns all inputs of this `LogicalPlan` node. Does not /// include inputs to inputs, or subqueries. pub fn inputs(&self) -> Vec<&LogicalPlan> { @@ -417,7 +655,7 @@ impl LogicalPlan { pub fn using_columns(&self) -> Result>, DataFusionError> { let mut using_columns: Vec> = vec![]; - self.apply(&mut |plan| { + self.apply_with_subqueries(&mut |plan| { if let LogicalPlan::Join(Join { join_constraint: JoinConstraint::Using, on, @@ -1079,57 +1317,178 @@ impl LogicalPlan { } } +/// This macro is used to determine continuation during combined transforming +/// traversals. +macro_rules! handle_transform_recursion { + ($F_DOWN:expr, $F_CHILD:expr, $F_UP:expr) => {{ + $F_DOWN? + .transform_children(|n| n.map_subqueries($F_CHILD))? + .transform_sibling(|n| n.map_children($F_CHILD))? + .transform_parent($F_UP) + }}; +} + +macro_rules! handle_transform_recursion_down { + ($F_DOWN:expr, $F_CHILD:expr) => {{ + $F_DOWN? + .transform_children(|n| n.map_subqueries($F_CHILD))? + .transform_sibling(|n| n.map_children($F_CHILD)) + }}; +} + +macro_rules! handle_transform_recursion_up { + ($SELF:expr, $F_CHILD:expr, $F_UP:expr) => {{ + $SELF + .map_subqueries($F_CHILD)? + .transform_sibling(|n| n.map_children($F_CHILD))? 
+ .transform_parent(|n| $F_UP(n)) + }}; +} + impl LogicalPlan { - /// applies `op` to any subqueries in the plan - pub(crate) fn apply_subqueries(&self, op: &mut F) -> Result<()> - where - F: FnMut(&Self) -> Result, - { - self.inspect_expressions(|expr| { - // recursively look for subqueries - inspect_expr_pre(expr, |expr| { - match expr { - Expr::Exists(Exists { subquery, .. }) - | Expr::InSubquery(InSubquery { subquery, .. }) - | Expr::ScalarSubquery(subquery) => { - // use a synthetic plan so the collector sees a - // LogicalPlan::Subquery (even though it is - // actually a Subquery alias) - let synthetic_plan = LogicalPlan::Subquery(subquery.clone()); - synthetic_plan.apply(op)?; - } - _ => {} + pub fn visit_with_subqueries>( + &self, + visitor: &mut V, + ) -> Result { + visitor + .f_down(self)? + .visit_children(|| { + self.apply_subqueries(|c| c.visit_with_subqueries(visitor)) + })? + .visit_sibling(|| self.apply_children(|c| c.visit_with_subqueries(visitor)))? + .visit_parent(|| visitor.f_up(self)) + } + + pub fn rewrite_with_subqueries>( + self, + rewriter: &mut R, + ) -> Result> { + handle_transform_recursion!( + rewriter.f_down(self), + |c| c.rewrite_with_subqueries(rewriter), + |n| rewriter.f_up(n) + ) + } + + pub fn apply_with_subqueries Result>( + &self, + f: &mut F, + ) -> Result { + f(self)? + .visit_children(|| self.apply_subqueries(|c| c.apply_with_subqueries(f)))? 
+ .visit_sibling(|| self.apply_children(|c| c.apply_with_subqueries(f))) + } + + pub fn transform_with_subqueries Result>>( + self, + f: &F, + ) -> Result> { + self.transform_up_with_subqueries(f) + } + + pub fn transform_down_with_subqueries Result>>( + self, + f: &F, + ) -> Result> { + handle_transform_recursion_down!(f(self), |c| c.transform_down_with_subqueries(f)) + } + + pub fn transform_down_mut_with_subqueries< + F: FnMut(Self) -> Result>, + >( + self, + f: &mut F, + ) -> Result> { + handle_transform_recursion_down!(f(self), |c| c + .transform_down_mut_with_subqueries(f)) + } + + pub fn transform_up_with_subqueries Result>>( + self, + f: &F, + ) -> Result> { + handle_transform_recursion_up!(self, |c| c.transform_up_with_subqueries(f), f) + } + + pub fn transform_up_mut_with_subqueries< + F: FnMut(Self) -> Result>, + >( + self, + f: &mut F, + ) -> Result> { + handle_transform_recursion_up!(self, |c| c.transform_up_mut_with_subqueries(f), f) + } + + pub fn transform_down_up_with_subqueries< + FD: FnMut(Self) -> Result>, + FU: FnMut(Self) -> Result>, + >( + self, + f_down: &mut FD, + f_up: &mut FU, + ) -> Result> { + handle_transform_recursion!( + f_down(self), + |c| c.transform_down_up_with_subqueries(f_down, f_up), + f_up + ) + } + + fn apply_subqueries Result>( + &self, + mut f: F, + ) -> Result { + self.apply_expressions(|expr| { + expr.apply(&mut |expr| match expr { + Expr::Exists(Exists { subquery, .. }) + | Expr::InSubquery(InSubquery { subquery, .. 
}) + | Expr::ScalarSubquery(subquery) => { + // use a synthetic plan so the collector sees a + // LogicalPlan::Subquery (even though it is + // actually a Subquery alias) + f(&LogicalPlan::Subquery(subquery.clone())) } - Ok::<(), DataFusionError>(()) + _ => Ok(TreeNodeRecursion::Continue), }) - })?; - Ok(()) + }) } - /// applies visitor to any subqueries in the plan - pub(crate) fn visit_subqueries(&self, v: &mut V) -> Result<()> - where - V: TreeNodeVisitor, - { - self.inspect_expressions(|expr| { - // recursively look for subqueries - inspect_expr_pre(expr, |expr| { - match expr { - Expr::Exists(Exists { subquery, .. }) - | Expr::InSubquery(InSubquery { subquery, .. }) - | Expr::ScalarSubquery(subquery) => { - // use a synthetic plan so the visitor sees a - // LogicalPlan::Subquery (even though it is - // actually a Subquery alias) - let synthetic_plan = LogicalPlan::Subquery(subquery.clone()); - synthetic_plan.visit(v)?; - } - _ => {} + fn map_subqueries Result>>( + self, + mut f: F, + ) -> Result> { + self.map_expressions(|expr| { + expr.transform_down_mut(&mut |expr| match expr { + Expr::Exists(Exists { subquery, negated }) => { + f(LogicalPlan::Subquery(subquery))?.map_data(|s| match s { + LogicalPlan::Subquery(subquery) => { + Ok(Expr::Exists(Exists { subquery, negated })) + } + _ => internal_err!("Transformation should return Subquery"), + }) } - Ok::<(), DataFusionError>(()) + Expr::InSubquery(InSubquery { + expr, + subquery, + negated, + }) => f(LogicalPlan::Subquery(subquery))?.map_data(|s| match s { + LogicalPlan::Subquery(subquery) => Ok(Expr::InSubquery(InSubquery { + expr, + subquery, + negated, + })), + _ => internal_err!("Transformation should return Subquery"), + }), + Expr::ScalarSubquery(subquery) => f(LogicalPlan::Subquery(subquery))? 
+ .map_data(|s| match s { + LogicalPlan::Subquery(subquery) => { + Ok(Expr::ScalarSubquery(subquery)) + } + _ => internal_err!("Transformation should return Subquery"), + }), + _ => Ok(Transformed::no(expr)), }) - })?; - Ok(()) + }) } /// Return a `LogicalPlan` with all placeholders (e.g $1 $2, @@ -1165,8 +1524,8 @@ impl LogicalPlan { ) -> Result>, DataFusionError> { let mut param_types: HashMap> = HashMap::new(); - self.apply(&mut |plan| { - plan.inspect_expressions(|expr| { + self.apply_with_subqueries(&mut |plan| { + plan.apply_expressions(|expr| { expr.apply(&mut |expr| { if let Expr::Placeholder(Placeholder { id, data_type }) = expr { let prev = param_types.get(id); @@ -1183,13 +1542,10 @@ impl LogicalPlan { } } Ok(TreeNodeRecursion::Continue) - })?; - Ok::<(), DataFusionError>(()) - })?; - Ok(TreeNodeRecursion::Continue) - })?; - - Ok(param_types) + }) + }) + }) + .map(|_| param_types) } /// Return an Expr with all placeholders replaced with their @@ -1257,7 +1613,7 @@ impl LogicalPlan { fn fmt(&self, f: &mut Formatter) -> fmt::Result { let with_schema = false; let mut visitor = IndentVisitor::new(f, with_schema); - match self.0.visit(&mut visitor) { + match self.0.visit_with_subqueries(&mut visitor) { Ok(_) => Ok(()), Err(_) => Err(fmt::Error), } @@ -1300,7 +1656,7 @@ impl LogicalPlan { fn fmt(&self, f: &mut Formatter) -> fmt::Result { let with_schema = true; let mut visitor = IndentVisitor::new(f, with_schema); - match self.0.visit(&mut visitor) { + match self.0.visit_with_subqueries(&mut visitor) { Ok(_) => Ok(()), Err(_) => Err(fmt::Error), } @@ -1320,7 +1676,7 @@ impl LogicalPlan { fn fmt(&self, f: &mut Formatter) -> fmt::Result { let mut visitor = PgJsonVisitor::new(f); visitor.with_schema(true); - match self.0.visit(&mut visitor) { + match self.0.visit_with_subqueries(&mut visitor) { Ok(_) => Ok(()), Err(_) => Err(fmt::Error), } @@ -1369,12 +1725,16 @@ impl LogicalPlan { visitor.start_graph()?; visitor.pre_visit_plan("LogicalPlan")?; - 
self.0.visit(&mut visitor).map_err(|_| fmt::Error)?; + self.0 + .visit_with_subqueries(&mut visitor) + .map_err(|_| fmt::Error)?; visitor.post_visit_plan()?; visitor.set_with_schema(true); visitor.pre_visit_plan("Detailed LogicalPlan")?; - self.0.visit(&mut visitor).map_err(|_| fmt::Error)?; + self.0 + .visit_with_subqueries(&mut visitor) + .map_err(|_| fmt::Error)?; visitor.post_visit_plan()?; visitor.end_graph()?; @@ -2908,7 +3268,7 @@ digraph { fn visit_order() { let mut visitor = OkVisitor::default(); let plan = test_plan(); - let res = plan.visit(&mut visitor); + let res = plan.visit_with_subqueries(&mut visitor); assert!(res.is_ok()); assert_eq!( @@ -2984,7 +3344,7 @@ digraph { ..Default::default() }; let plan = test_plan(); - let res = plan.visit(&mut visitor); + let res = plan.visit_with_subqueries(&mut visitor); assert!(res.is_ok()); assert_eq!( @@ -3000,7 +3360,7 @@ digraph { ..Default::default() }; let plan = test_plan(); - let res = plan.visit(&mut visitor); + let res = plan.visit_with_subqueries(&mut visitor); assert!(res.is_ok()); assert_eq!( @@ -3051,7 +3411,7 @@ digraph { ..Default::default() }; let plan = test_plan(); - let res = plan.visit(&mut visitor).unwrap_err(); + let res = plan.visit_with_subqueries(&mut visitor).unwrap_err(); assert_eq!( "This feature is not implemented: Error in pre_visit", res.strip_backtrace() @@ -3069,7 +3429,7 @@ digraph { ..Default::default() }; let plan = test_plan(); - let res = plan.visit(&mut visitor).unwrap_err(); + let res = plan.visit_with_subqueries(&mut visitor).unwrap_err(); assert_eq!( "This feature is not implemented: Error in post_visit", res.strip_backtrace() diff --git a/datafusion/expr/src/tree_node/expr.rs b/datafusion/expr/src/tree_node/expr.rs index 97331720ce7d..85097f6249e1 100644 --- a/datafusion/expr/src/tree_node/expr.rs +++ b/datafusion/expr/src/tree_node/expr.rs @@ -412,7 +412,7 @@ where } /// &mut transform a Option<`Vec` of `Expr`s> -fn transform_option_vec( +pub fn transform_option_vec( 
ove: Option>, f: &mut F, ) -> Result>>> diff --git a/datafusion/expr/src/tree_node/plan.rs b/datafusion/expr/src/tree_node/plan.rs index 7a6b1005fede..482fc96b519b 100644 --- a/datafusion/expr/src/tree_node/plan.rs +++ b/datafusion/expr/src/tree_node/plan.rs @@ -20,58 +20,11 @@ use crate::LogicalPlan; use datafusion_common::tree_node::{ - Transformed, TreeNode, TreeNodeIterator, TreeNodeRecursion, TreeNodeVisitor, + Transformed, TreeNode, TreeNodeIterator, TreeNodeRecursion, }; use datafusion_common::Result; impl TreeNode for LogicalPlan { - fn apply Result>( - &self, - f: &mut F, - ) -> Result { - // Compared to the default implementation, we need to invoke - // [`Self::apply_subqueries`] before visiting its children - f(self)?.visit_children(|| { - self.apply_subqueries(f)?; - self.apply_children(|n| n.apply(f)) - }) - } - - /// To use, define a struct that implements the trait [`TreeNodeVisitor`] and then invoke - /// [`LogicalPlan::visit`]. - /// - /// For example, for a logical plan like: - /// - /// ```text - /// Projection: id - /// Filter: state Eq Utf8(\"CO\")\ - /// CsvScan: employee.csv projection=Some([0, 3])"; - /// ``` - /// - /// The sequence of visit operations would be: - /// ```text - /// visitor.pre_visit(Projection) - /// visitor.pre_visit(Filter) - /// visitor.pre_visit(CsvScan) - /// visitor.post_visit(CsvScan) - /// visitor.post_visit(Filter) - /// visitor.post_visit(Projection) - /// ``` - fn visit>( - &self, - visitor: &mut V, - ) -> Result { - // Compared to the default implementation, we need to invoke - // [`Self::visit_subqueries`] before visiting its children - visitor - .f_down(self)? - .visit_children(|| { - self.visit_subqueries(visitor)?; - self.apply_children(|n| n.visit(visitor)) - })? 
- .visit_parent(|| visitor.f_up(self)) - } - fn apply_children Result>( &self, f: F, @@ -85,8 +38,8 @@ impl TreeNode for LogicalPlan { ) -> Result> { let new_children = self .inputs() - .iter() - .map(|&c| c.clone()) + .into_iter() + .cloned() .map_until_stop_and_collect(f)?; // Propagate up `new_children.transformed` and `new_children.tnr` // along with the node containing transformed children. diff --git a/datafusion/optimizer/src/analyzer/mod.rs b/datafusion/optimizer/src/analyzer/mod.rs index b446fe2f320e..d0b83d24299b 100644 --- a/datafusion/optimizer/src/analyzer/mod.rs +++ b/datafusion/optimizer/src/analyzer/mod.rs @@ -155,8 +155,8 @@ impl Analyzer { /// Do necessary check and fail the invalid plan fn check_plan(plan: &LogicalPlan) -> Result<()> { - plan.apply(&mut |plan: &LogicalPlan| { - plan.inspect_expressions(|expr| { + plan.apply_with_subqueries(&mut |plan: &LogicalPlan| { + plan.apply_expressions(|expr| { // recursively look for subqueries expr.apply(&mut |expr| { match expr { @@ -168,11 +168,8 @@ fn check_plan(plan: &LogicalPlan) -> Result<()> { _ => {} }; Ok(TreeNodeRecursion::Continue) - })?; - Ok::<(), DataFusionError>(()) - })?; - Ok(TreeNodeRecursion::Continue) - })?; - - Ok(()) + }) + }) + }) + .map(|_| ()) } diff --git a/datafusion/optimizer/src/analyzer/subquery.rs b/datafusion/optimizer/src/analyzer/subquery.rs index 038361c3ee8c..79375e52da1f 100644 --- a/datafusion/optimizer/src/analyzer/subquery.rs +++ b/datafusion/optimizer/src/analyzer/subquery.rs @@ -283,7 +283,7 @@ fn strip_inner_query(inner_plan: &LogicalPlan) -> &LogicalPlan { fn get_correlated_expressions(inner_plan: &LogicalPlan) -> Result> { let mut exprs = vec![]; - inner_plan.apply(&mut |plan| { + inner_plan.apply_with_subqueries(&mut |plan| { if let LogicalPlan::Filter(Filter { predicate, .. 
}) = plan { let (correlated, _): (Vec<_>, Vec<_>) = split_conjunction(predicate) .into_iter() diff --git a/datafusion/optimizer/src/plan_signature.rs b/datafusion/optimizer/src/plan_signature.rs index 4143d52a053e..a8e323ff429f 100644 --- a/datafusion/optimizer/src/plan_signature.rs +++ b/datafusion/optimizer/src/plan_signature.rs @@ -21,7 +21,7 @@ use std::{ num::NonZeroUsize, }; -use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; +use datafusion_common::tree_node::TreeNodeRecursion; use datafusion_expr::LogicalPlan; /// Non-unique identifier of a [`LogicalPlan`]. @@ -73,7 +73,7 @@ impl LogicalPlanSignature { /// Get total number of [`LogicalPlan`]s in the plan. fn get_node_number(plan: &LogicalPlan) -> NonZeroUsize { let mut node_number = 0; - plan.apply(&mut |_plan| { + plan.apply_with_subqueries(&mut |_plan| { node_number += 1; Ok(TreeNodeRecursion::Continue) }) From fc29c3e67d43e82e4d1d49b44d150ff710ad7004 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BC=A0=E6=9E=97=E4=BC=9F?= Date: Mon, 8 Apr 2024 18:29:32 +0800 Subject: [PATCH 2/5] Remove unnecessary result (#9990) --- datafusion/common/src/dfschema.rs | 29 +++++++++++++---------------- datafusion/expr/src/utils.rs | 2 +- datafusion/sql/src/statement.rs | 2 +- 3 files changed, 15 insertions(+), 18 deletions(-) diff --git a/datafusion/common/src/dfschema.rs b/datafusion/common/src/dfschema.rs index 9f167fd1f6d9..83e53b3cc6ff 100644 --- a/datafusion/common/src/dfschema.rs +++ b/datafusion/common/src/dfschema.rs @@ -319,7 +319,7 @@ impl DFSchema { &self, qualifier: Option<&TableReference>, name: &str, - ) -> Result> { + ) -> Option { let mut matches = self .iter() .enumerate() @@ -345,19 +345,19 @@ impl DFSchema { (None, Some(_)) | (None, None) => f.name() == name, }) .map(|(idx, _)| idx); - Ok(matches.next()) + matches.next() } /// Find the index of the column with the given qualifier and name pub fn index_of_column(&self, col: &Column) -> Result { - 
self.index_of_column_by_name(col.relation.as_ref(), &col.name)? + self.index_of_column_by_name(col.relation.as_ref(), &col.name) .ok_or_else(|| field_not_found(col.relation.clone(), &col.name, self)) } /// Check if the column is in the current schema - pub fn is_column_from_schema(&self, col: &Column) -> Result { + pub fn is_column_from_schema(&self, col: &Column) -> bool { self.index_of_column_by_name(col.relation.as_ref(), &col.name) - .map(|idx| idx.is_some()) + .is_some() } /// Find the field with the given name @@ -381,7 +381,7 @@ impl DFSchema { ) -> Result<(Option<&TableReference>, &Field)> { if let Some(qualifier) = qualifier { let idx = self - .index_of_column_by_name(Some(qualifier), name)? + .index_of_column_by_name(Some(qualifier), name) .ok_or_else(|| field_not_found(Some(qualifier.clone()), name, self))?; Ok((self.field_qualifiers[idx].as_ref(), self.field(idx))) } else { @@ -519,7 +519,7 @@ impl DFSchema { name: &str, ) -> Result<&Field> { let idx = self - .index_of_column_by_name(Some(qualifier), name)? 
+ .index_of_column_by_name(Some(qualifier), name) .ok_or_else(|| field_not_found(Some(qualifier.clone()), name, self))?; Ok(self.field(idx)) @@ -1190,11 +1190,8 @@ mod tests { .to_string(), expected_help ); - assert!(schema.index_of_column_by_name(None, "y").unwrap().is_none()); - assert!(schema - .index_of_column_by_name(None, "t1.c0") - .unwrap() - .is_none()); + assert!(schema.index_of_column_by_name(None, "y").is_none()); + assert!(schema.index_of_column_by_name(None, "t1.c0").is_none()); Ok(()) } @@ -1284,28 +1281,28 @@ mod tests { { let col = Column::from_qualified_name("t1.c0"); let schema = DFSchema::try_from_qualified_schema("t1", &test_schema_1())?; - assert!(schema.is_column_from_schema(&col)?); + assert!(schema.is_column_from_schema(&col)); } // qualified not exists { let col = Column::from_qualified_name("t1.c2"); let schema = DFSchema::try_from_qualified_schema("t1", &test_schema_1())?; - assert!(!schema.is_column_from_schema(&col)?); + assert!(!schema.is_column_from_schema(&col)); } // unqualified exists { let col = Column::from_name("c0"); let schema = DFSchema::try_from_qualified_schema("t1", &test_schema_1())?; - assert!(schema.is_column_from_schema(&col)?); + assert!(schema.is_column_from_schema(&col)); } // unqualified not exists { let col = Column::from_name("c2"); let schema = DFSchema::try_from_qualified_schema("t1", &test_schema_1())?; - assert!(!schema.is_column_from_schema(&col)?); + assert!(!schema.is_column_from_schema(&col)); } Ok(()) diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index 0d99d0b5028e..8c6b98f17933 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -933,7 +933,7 @@ pub fn check_all_columns_from_schema( schema: DFSchemaRef, ) -> Result { for col in columns.iter() { - let exist = schema.is_column_from_schema(col)?; + let exist = schema.is_column_from_schema(col); if !exist { return Ok(false); } diff --git a/datafusion/sql/src/statement.rs 
b/datafusion/sql/src/statement.rs index b8c9172621c3..6b89f89aaccf 100644 --- a/datafusion/sql/src/statement.rs +++ b/datafusion/sql/src/statement.rs @@ -1350,7 +1350,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { .enumerate() .map(|(i, c)| { let column_index = table_schema - .index_of_column_by_name(None, &c)? + .index_of_column_by_name(None, &c) .ok_or_else(|| unqualified_field_not_found(&c, &table_schema))?; if value_indices[column_index].is_some() { return schema_err!(SchemaError::DuplicateUnqualifiedField { From 820843ff597161c9cdacd0e79cecf20d05755081 Mon Sep 17 00:00:00 2001 From: Edmondo Porcu Date: Mon, 8 Apr 2024 06:31:10 -0400 Subject: [PATCH 3/5] Removes Bloom filter for Int8/Int16/Uint8/Uint16 (#9969) * Removing broken tests * Simplifying tests / removing support for failed tests * Revert "Simplifying tests / removing support for failed tests" This reverts commit 6e50a8064436943d9f42d313cef2c2b017d196f1. * Fixing tests for real * Apply suggestions from code review Thanks @alamb ! 
Co-authored-by: Andrew Lamb --------- Co-authored-by: Andrew Lamb --- .../physical_plan/parquet/row_groups.rs | 4 -- .../core/tests/parquet/row_group_pruning.rs | 54 +++++++++---------- 2 files changed, 26 insertions(+), 32 deletions(-) diff --git a/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs b/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs index 8df4925fc566..6600dd07d7fd 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs @@ -232,12 +232,8 @@ impl PruningStatistics for BloomFilterStatistics { ScalarValue::Float32(Some(v)) => sbbf.check(v), ScalarValue::Int64(Some(v)) => sbbf.check(v), ScalarValue::Int32(Some(v)) => sbbf.check(v), - ScalarValue::Int16(Some(v)) => sbbf.check(v), - ScalarValue::Int8(Some(v)) => sbbf.check(v), ScalarValue::UInt64(Some(v)) => sbbf.check(v), ScalarValue::UInt32(Some(v)) => sbbf.check(v), - ScalarValue::UInt16(Some(v)) => sbbf.check(v), - ScalarValue::UInt8(Some(v)) => sbbf.check(v), ScalarValue::Decimal128(Some(v), p, s) => match parquet_type { Type::INT32 => { //https://github.com/apache/parquet-format/blob/eb4b31c1d64a01088d02a2f9aefc6c17c54cc6fc/Encodings.md?plain=1#L35-L42 diff --git a/datafusion/core/tests/parquet/row_group_pruning.rs b/datafusion/core/tests/parquet/row_group_pruning.rs index b7b434d1c3d3..8fc7936552af 100644 --- a/datafusion/core/tests/parquet/row_group_pruning.rs +++ b/datafusion/core/tests/parquet/row_group_pruning.rs @@ -290,7 +290,7 @@ async fn prune_disabled() { // https://github.com/apache/arrow-datafusion/issues/9779 bug so that tests pass // if and only if Bloom filters on Int8 and Int16 columns are still buggy. macro_rules! int_tests { - ($bits:expr, correct_bloom_filters: $correct_bloom_filters:expr) => { + ($bits:expr) => { paste::item! { #[tokio::test] async fn []() { @@ -329,9 +329,9 @@ macro_rules! 
int_tests { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(1)) .with_pruned_by_stats(Some(3)) - .with_matched_by_bloom_filter(Some(if $correct_bloom_filters { 1 } else { 0 })) - .with_pruned_by_bloom_filter(Some(if $correct_bloom_filters { 0 } else { 1 })) - .with_expected_rows(if $correct_bloom_filters { 1 } else { 0 }) + .with_matched_by_bloom_filter(Some(1)) + .with_pruned_by_bloom_filter(Some(0)) + .with_expected_rows(1) .test_row_group_prune() .await; } @@ -343,9 +343,9 @@ macro_rules! int_tests { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(1)) .with_pruned_by_stats(Some(3)) - .with_matched_by_bloom_filter(Some(if $correct_bloom_filters { 1 } else { 0 })) - .with_pruned_by_bloom_filter(Some(if $correct_bloom_filters { 0 } else { 1 })) - .with_expected_rows(if $correct_bloom_filters { 1 } else { 0 }) + .with_matched_by_bloom_filter(Some(1)) + .with_pruned_by_bloom_filter(Some(0)) + .with_expected_rows(1) .test_row_group_prune() .await; } @@ -404,9 +404,9 @@ macro_rules! int_tests { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(1)) .with_pruned_by_stats(Some(3)) - .with_matched_by_bloom_filter(Some(if $correct_bloom_filters { 1 } else { 0 })) - .with_pruned_by_bloom_filter(Some(if $correct_bloom_filters { 0 } else { 1 })) - .with_expected_rows(if $correct_bloom_filters { 1 } else { 0 }) + .with_matched_by_bloom_filter(Some(1)) + .with_pruned_by_bloom_filter(Some(0)) + .with_expected_rows(1) .test_row_group_prune() .await; } @@ -447,17 +447,16 @@ macro_rules! 
int_tests { }; } -int_tests!(8, correct_bloom_filters: false); -int_tests!(16, correct_bloom_filters: false); -int_tests!(32, correct_bloom_filters: true); -int_tests!(64, correct_bloom_filters: true); +// int8/int16 are incorrect: https://github.com/apache/arrow-datafusion/issues/9779 +int_tests!(32); +int_tests!(64); // $bits: number of bits of the integer to test (8, 16, 32, 64) // $correct_bloom_filters: if false, replicates the // https://github.com/apache/arrow-datafusion/issues/9779 bug so that tests pass // if and only if Bloom filters on UInt8 and UInt16 columns are still buggy. macro_rules! uint_tests { - ($bits:expr, correct_bloom_filters: $correct_bloom_filters:expr) => { + ($bits:expr) => { paste::item! { #[tokio::test] async fn []() { @@ -482,9 +481,9 @@ macro_rules! uint_tests { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(1)) .with_pruned_by_stats(Some(3)) - .with_matched_by_bloom_filter(Some(if $correct_bloom_filters { 1 } else { 0 })) - .with_pruned_by_bloom_filter(Some(if $correct_bloom_filters { 0 } else { 1 })) - .with_expected_rows(if $correct_bloom_filters { 1 } else { 0 }) + .with_matched_by_bloom_filter(Some(1)) + .with_pruned_by_bloom_filter(Some(0)) + .with_expected_rows(1) .test_row_group_prune() .await; } @@ -496,9 +495,9 @@ macro_rules! uint_tests { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(1)) .with_pruned_by_stats(Some(3)) - .with_matched_by_bloom_filter(Some(if $correct_bloom_filters { 1 } else { 0 })) - .with_pruned_by_bloom_filter(Some(if $correct_bloom_filters { 0 } else { 1 })) - .with_expected_rows(if $correct_bloom_filters { 1 } else { 0 }) + .with_matched_by_bloom_filter(Some(1)) + .with_pruned_by_bloom_filter(Some(0)) + .with_expected_rows(1) .test_row_group_prune() .await; } @@ -542,9 +541,9 @@ macro_rules! 
uint_tests { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(1)) .with_pruned_by_stats(Some(3)) - .with_matched_by_bloom_filter(Some(if $correct_bloom_filters { 1 } else { 0 })) - .with_pruned_by_bloom_filter(Some(if $correct_bloom_filters { 0 } else { 1 })) - .with_expected_rows(if $correct_bloom_filters { 1 } else { 0 }) + .with_matched_by_bloom_filter(Some(1)) + .with_pruned_by_bloom_filter(Some(0)) + .with_expected_rows(1) .test_row_group_prune() .await; } @@ -585,10 +584,9 @@ macro_rules! uint_tests { }; } -uint_tests!(8, correct_bloom_filters: false); -uint_tests!(16, correct_bloom_filters: false); -uint_tests!(32, correct_bloom_filters: true); -uint_tests!(64, correct_bloom_filters: true); +// uint8/uint16 are incorrect: https://github.com/apache/arrow-datafusion/issues/9779 +uint_tests!(32); +uint_tests!(64); #[tokio::test] async fn prune_int32_eq_large_in_list() { From 86ad8a580863218f9fb123f09e1d058094ec3ef8 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Mon, 8 Apr 2024 09:46:04 -0400 Subject: [PATCH 4/5] Move LogicalPlan tree_node module (#9995) --- datafusion/expr/src/logical_plan/mod.rs | 1 + datafusion/expr/src/logical_plan/plan.rs | 2 +- .../plan.rs => logical_plan/tree_node.rs} | 0 .../src/{tree_node/expr.rs => tree_node.rs} | 0 datafusion/expr/src/tree_node/mod.rs | 21 ------------------- 5 files changed, 2 insertions(+), 22 deletions(-) rename datafusion/expr/src/{tree_node/plan.rs => logical_plan/tree_node.rs} (100%) rename datafusion/expr/src/{tree_node/expr.rs => tree_node.rs} (100%) delete mode 100644 datafusion/expr/src/tree_node/mod.rs diff --git a/datafusion/expr/src/logical_plan/mod.rs b/datafusion/expr/src/logical_plan/mod.rs index 84781cb2e9ec..a1fe7a6f0a51 100644 --- a/datafusion/expr/src/logical_plan/mod.rs +++ b/datafusion/expr/src/logical_plan/mod.rs @@ -22,6 +22,7 @@ pub mod dml; mod extension; mod plan; mod statement; +mod tree_node; pub use builder::{ build_join_schema, table_scan, union, 
wrap_projection_for_join_if_necessary, diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 4f55bbfe3f4d..860fd7daafbf 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -56,7 +56,7 @@ use datafusion_common::{ // backwards compatibility use crate::display::PgJsonVisitor; -use crate::tree_node::expr::transform_option_vec; +use crate::tree_node::transform_option_vec; pub use datafusion_common::display::{PlanType, StringifiedPlan, ToStringifiedPlan}; pub use datafusion_common::{JoinConstraint, JoinType}; diff --git a/datafusion/expr/src/tree_node/plan.rs b/datafusion/expr/src/logical_plan/tree_node.rs similarity index 100% rename from datafusion/expr/src/tree_node/plan.rs rename to datafusion/expr/src/logical_plan/tree_node.rs diff --git a/datafusion/expr/src/tree_node/expr.rs b/datafusion/expr/src/tree_node.rs similarity index 100% rename from datafusion/expr/src/tree_node/expr.rs rename to datafusion/expr/src/tree_node.rs diff --git a/datafusion/expr/src/tree_node/mod.rs b/datafusion/expr/src/tree_node/mod.rs deleted file mode 100644 index 3f8bb6d3257e..000000000000 --- a/datafusion/expr/src/tree_node/mod.rs +++ /dev/null @@ -1,21 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. 
See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Tree node implementation for logical expr and logical plan - -pub mod expr; -pub mod plan; From 8c9e5678228557aff370b137e9029462230df68a Mon Sep 17 00:00:00 2001 From: Kevin Mingtarja <69668484+kevinmingtarja@users.noreply.github.com> Date: Tue, 9 Apr 2024 00:01:26 +0800 Subject: [PATCH 5/5] Optimize performance of substr_index and add tests (#9973) * Optimize performance of substr_index --- .../functions/src/unicode/substrindex.rs | 153 +++++++++++++++--- .../sqllogictest/test_files/functions.slt | 11 +- 2 files changed, 143 insertions(+), 21 deletions(-) diff --git a/datafusion/functions/src/unicode/substrindex.rs b/datafusion/functions/src/unicode/substrindex.rs index d00108a68fc9..da4ff55828e9 100644 --- a/datafusion/functions/src/unicode/substrindex.rs +++ b/datafusion/functions/src/unicode/substrindex.rs @@ -18,7 +18,7 @@ use std::any::Any; use std::sync::Arc; -use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait}; +use arrow::array::{ArrayRef, OffsetSizeTrait, StringBuilder}; use arrow::datatypes::DataType; use datafusion_common::cast::{as_generic_string_array, as_int64_array}; @@ -101,38 +101,151 @@ pub fn substr_index(args: &[ArrayRef]) -> Result { let delimiter_array = as_generic_string_array::(&args[1])?; let count_array = as_int64_array(&args[2])?; - let result = string_array + let mut builder = StringBuilder::new(); + string_array .iter() .zip(delimiter_array.iter()) .zip(count_array.iter()) - .map(|((string, delimiter), n)| match (string, delimiter, n) { + .for_each(|((string, delimiter), n)| match (string, delimiter, n) { (Some(string), Some(delimiter), Some(n)) => { // In MySQL, these cases will return an empty string. 
if n == 0 || string.is_empty() || delimiter.is_empty() { - return Some(String::new()); + builder.append_value(""); + return; } - let splitted: Box> = if n > 0 { - Box::new(string.split(delimiter)) + let occurrences = usize::try_from(n.unsigned_abs()).unwrap_or(usize::MAX); + let length = if n > 0 { + let splitted = string.split(delimiter); + splitted + .take(occurrences) + .map(|s| s.len() + delimiter.len()) + .sum::() + - delimiter.len() } else { - Box::new(string.rsplit(delimiter)) + let splitted = string.rsplit(delimiter); + splitted + .take(occurrences) + .map(|s| s.len() + delimiter.len()) + .sum::() + - delimiter.len() }; - let occurrences = usize::try_from(n.unsigned_abs()).unwrap_or(usize::MAX); - // The length of the substring covered by substr_index. - let length = splitted - .take(occurrences) // at least 1 element, since n != 0 - .map(|s| s.len() + delimiter.len()) - .sum::() - - delimiter.len(); if n > 0 { - Some(string[..length].to_owned()) + match string.get(..length) { + Some(substring) => builder.append_value(substring), + None => builder.append_null(), + } } else { - Some(string[string.len() - length..].to_owned()) + match string.get(string.len().saturating_sub(length)..) 
{ + Some(substring) => builder.append_value(substring), + None => builder.append_null(), + } } } - _ => None, - }) - .collect::>(); + _ => builder.append_null(), + }); + + Ok(Arc::new(builder.finish()) as ArrayRef) +} - Ok(Arc::new(result) as ArrayRef) +#[cfg(test)] +mod tests { + use arrow::array::{Array, StringArray}; + use arrow::datatypes::DataType::Utf8; + + use datafusion_common::{Result, ScalarValue}; + use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; + + use crate::unicode::substrindex::SubstrIndexFunc; + use crate::utils::test::test_function; + + #[test] + fn test_functions() -> Result<()> { + test_function!( + SubstrIndexFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("www.apache.org")), + ColumnarValue::Scalar(ScalarValue::from(".")), + ColumnarValue::Scalar(ScalarValue::from(1i64)), + ], + Ok(Some("www")), + &str, + Utf8, + StringArray + ); + test_function!( + SubstrIndexFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("www.apache.org")), + ColumnarValue::Scalar(ScalarValue::from(".")), + ColumnarValue::Scalar(ScalarValue::from(2i64)), + ], + Ok(Some("www.apache")), + &str, + Utf8, + StringArray + ); + test_function!( + SubstrIndexFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("www.apache.org")), + ColumnarValue::Scalar(ScalarValue::from(".")), + ColumnarValue::Scalar(ScalarValue::from(-2i64)), + ], + Ok(Some("apache.org")), + &str, + Utf8, + StringArray + ); + test_function!( + SubstrIndexFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("www.apache.org")), + ColumnarValue::Scalar(ScalarValue::from(".")), + ColumnarValue::Scalar(ScalarValue::from(-1i64)), + ], + Ok(Some("org")), + &str, + Utf8, + StringArray + ); + test_function!( + SubstrIndexFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("www.apache.org")), + ColumnarValue::Scalar(ScalarValue::from(".")), + ColumnarValue::Scalar(ScalarValue::from(0i64)), + ], + Ok(Some("")), + &str, + Utf8, + StringArray + ); + test_function!( + 
SubstrIndexFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("")), + ColumnarValue::Scalar(ScalarValue::from(".")), + ColumnarValue::Scalar(ScalarValue::from(1i64)), + ], + Ok(Some("")), + &str, + Utf8, + StringArray + ); + test_function!( + SubstrIndexFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("www.apache.org")), + ColumnarValue::Scalar(ScalarValue::from("")), + ColumnarValue::Scalar(ScalarValue::from(1i64)), + ], + Ok(Some("")), + &str, + Utf8, + StringArray + ); + + Ok(()) + } } diff --git a/datafusion/sqllogictest/test_files/functions.slt b/datafusion/sqllogictest/test_files/functions.slt index 21433ba16810..38ebedf5654a 100644 --- a/datafusion/sqllogictest/test_files/functions.slt +++ b/datafusion/sqllogictest/test_files/functions.slt @@ -940,7 +940,8 @@ SELECT str, n, substring_index(str, '.', n) AS c FROM (VALUES ROW('arrow.apache.org'), ROW('.'), - ROW('...') + ROW('...'), + ROW(NULL) ) AS strings(str), (VALUES ROW(1), @@ -954,6 +955,14 @@ SELECT str, n, substring_index(str, '.', n) AS c FROM ) AS occurrences(n) ORDER BY str DESC, n; ---- +NULL -100 NULL +NULL -3 NULL +NULL -2 NULL +NULL -1 NULL +NULL 1 NULL +NULL 2 NULL +NULL 3 NULL +NULL 100 NULL arrow.apache.org -100 arrow.apache.org arrow.apache.org -3 arrow.apache.org arrow.apache.org -2 apache.org