From 1a002bccd420ff91ec149ee1ba9c42061510f906 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Sun, 7 Apr 2024 03:39:49 -0400 Subject: [PATCH 01/15] Minor: Improve documentation about optimizer (#9967) * Minor: Improve documentation about optimizer * fix unused commit --- datafusion/optimizer/src/analyzer/mod.rs | 1 + datafusion/optimizer/src/decorrelate.rs | 13 ++++++++++--- .../src/decorrelate_predicate_subquery.rs | 1 + datafusion/optimizer/src/eliminate_cross_join.rs | 2 +- .../optimizer/src/eliminate_duplicated_expr.rs | 2 ++ datafusion/optimizer/src/eliminate_filter.rs | 11 +++++++---- datafusion/optimizer/src/eliminate_join.rs | 3 ++- datafusion/optimizer/src/eliminate_limit.rs | 15 ++++++++------- .../optimizer/src/eliminate_nested_union.rs | 2 +- datafusion/optimizer/src/eliminate_one_union.rs | 2 +- datafusion/optimizer/src/eliminate_outer_join.rs | 2 +- .../optimizer/src/extract_equijoin_predicate.rs | 2 +- .../optimizer/src/filter_null_join_keys.rs | 5 +---- datafusion/optimizer/src/lib.rs | 16 +++++++++++++++- datafusion/optimizer/src/optimize_projections.rs | 16 ++++++++-------- datafusion/optimizer/src/optimizer.rs | 4 ++-- .../optimizer/src/propagate_empty_relation.rs | 1 + datafusion/optimizer/src/push_down_filter.rs | 3 +-- datafusion/optimizer/src/push_down_limit.rs | 7 ++++--- datafusion/optimizer/src/push_down_projection.rs | 3 --- .../optimizer/src/replace_distinct_aggregate.rs | 1 + .../src/rewrite_disjunctive_predicate.rs | 2 ++ .../optimizer/src/scalar_subquery_to_join.rs | 2 ++ .../optimizer/src/simplify_expressions/mod.rs | 3 +++ .../optimizer/src/single_distinct_to_groupby.rs | 2 +- .../optimizer/src/unwrap_cast_in_comparison.rs | 4 +--- datafusion/optimizer/src/utils.rs | 2 +- 27 files changed, 79 insertions(+), 48 deletions(-) diff --git a/datafusion/optimizer/src/analyzer/mod.rs b/datafusion/optimizer/src/analyzer/mod.rs index c7eb6e895d57..b446fe2f320e 100644 --- a/datafusion/optimizer/src/analyzer/mod.rs +++ b/datafusion/optimizer/src/analyzer/mod.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +//! [`Analyzer`] and [`AnalyzerRule`] use std::sync::Arc; use log::debug; diff --git a/datafusion/optimizer/src/decorrelate.rs b/datafusion/optimizer/src/decorrelate.rs index 12e84a63ea15..dbcf02b26ba6 100644 --- a/datafusion/optimizer/src/decorrelate.rs +++ b/datafusion/optimizer/src/decorrelate.rs @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. +//! [`PullUpCorrelatedExpr`] converts correlated subqueries to `Joins` + use std::collections::{BTreeSet, HashMap}; use std::ops::Deref; @@ -31,8 +33,11 @@ use datafusion_expr::utils::{conjunction, find_join_exprs, split_conjunction}; use datafusion_expr::{expr, EmptyRelation, Expr, LogicalPlan, LogicalPlanBuilder}; use datafusion_physical_expr::execution_props::ExecutionProps; -/// This struct rewrite the sub query plan by pull up the correlated expressions(contains outer reference columns) from the inner subquery's 'Filter'. -/// It adds the inner reference columns to the 'Projection' or 'Aggregate' of the subquery if they are missing, so that they can be evaluated by the parent operator as the join condition. +/// This struct rewrite the sub query plan by pull up the correlated +/// expressions(contains outer reference columns) from the inner subquery's +/// 'Filter'. It adds the inner reference columns to the 'Projection' or +/// 'Aggregate' of the subquery if they are missing, so that they can be +/// evaluated by the parent operator as the join condition. pub struct PullUpCorrelatedExpr { pub join_filters: Vec, // mapping from the plan to its holding correlated columns @@ -54,7 +59,9 @@ pub struct PullUpCorrelatedExpr { /// This is used to handle the Count bug pub const UN_MATCHED_ROW_INDICATOR: &str = "__always_true"; -/// Mapping from expr display name to its evaluation result on empty record batch (for example: 'count(*)' is 'ScalarValue(0)', 'count(*) + 2' is 'ScalarValue(2)') +/// Mapping from expr display name to its evaluation result on empty record +/// batch (for example: 'count(*)' is 'ScalarValue(0)', 'count(*) + 2' is +/// 'ScalarValue(2)') pub type ExprResultMap = HashMap; impl TreeNodeRewriter for PullUpCorrelatedExpr { diff --git a/datafusion/optimizer/src/decorrelate_predicate_subquery.rs b/datafusion/optimizer/src/decorrelate_predicate_subquery.rs index b94cf37c5c12..019e7507b122 100644 --- a/datafusion/optimizer/src/decorrelate_predicate_subquery.rs +++ b/datafusion/optimizer/src/decorrelate_predicate_subquery.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +//! [`DecorrelatePredicateSubquery`] converts `IN`/`EXISTS` subquery predicates to `SEMI`/`ANTI` joins use std::collections::BTreeSet; use std::ops::Deref; use std::sync::Arc; diff --git a/datafusion/optimizer/src/eliminate_cross_join.rs b/datafusion/optimizer/src/eliminate_cross_join.rs index 7f65690a4a7c..18a9c05b9dc6 100644 --- a/datafusion/optimizer/src/eliminate_cross_join.rs +++ b/datafusion/optimizer/src/eliminate_cross_join.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -//! Optimizer rule to eliminate cross join to inner join if join predicates are available in filters. +//! [`EliminateCrossJoin`] converts `CROSS JOIN` to `INNER JOIN` if join predicates are available. use std::collections::HashSet; use std::sync::Arc; diff --git a/datafusion/optimizer/src/eliminate_duplicated_expr.rs b/datafusion/optimizer/src/eliminate_duplicated_expr.rs index de05717a72e2..349d4d8878e0 100644 --- a/datafusion/optimizer/src/eliminate_duplicated_expr.rs +++ b/datafusion/optimizer/src/eliminate_duplicated_expr.rs @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. +//! [`EliminateDuplicatedExpr`] Removes redundant expressions + use crate::optimizer::ApplyOrder; use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::Result; diff --git a/datafusion/optimizer/src/eliminate_filter.rs b/datafusion/optimizer/src/eliminate_filter.rs index fea14342ca77..9411dc192beb 100644 --- a/datafusion/optimizer/src/eliminate_filter.rs +++ b/datafusion/optimizer/src/eliminate_filter.rs @@ -15,9 +15,8 @@ // specific language governing permissions and limitations // under the License. -//! Optimizer rule to replace `where false or null` on a plan with an empty relation. -//! This saves time in planning and executing the query. -//! Note that this rule should be applied after simplify expressions optimizer rule. +//! [`EliminateFilter`] replaces `where false` or `where null` with an empty relation. + use crate::optimizer::ApplyOrder; use datafusion_common::{Result, ScalarValue}; use datafusion_expr::{ @@ -27,7 +26,11 @@ use datafusion_expr::{ use crate::{OptimizerConfig, OptimizerRule}; -/// Optimization rule that eliminate the scalar value (true/false/null) filter with an [LogicalPlan::EmptyRelation] +/// Optimization rule that eliminate the scalar value (true/false/null) filter +/// with an [LogicalPlan::EmptyRelation] +/// +/// This saves time in planning and executing the query. +/// Note that this rule should be applied after simplify expressions optimizer rule. #[derive(Default)] pub struct EliminateFilter; diff --git a/datafusion/optimizer/src/eliminate_join.rs b/datafusion/optimizer/src/eliminate_join.rs index 0dbebcc8a051..e685229c61b2 100644 --- a/datafusion/optimizer/src/eliminate_join.rs +++ b/datafusion/optimizer/src/eliminate_join.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +//! [`EliminateJoin`] rewrites `INNER JOIN` with `true`/`null` use crate::optimizer::ApplyOrder; use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::{Result, ScalarValue}; @@ -24,7 +25,7 @@ use datafusion_expr::{ CrossJoin, Expr, }; -/// Eliminates joins when inner join condition is false. +/// Eliminates joins when join condition is false. /// Replaces joins when inner join condition is true with a cross join. #[derive(Default)] pub struct EliminateJoin; diff --git a/datafusion/optimizer/src/eliminate_limit.rs b/datafusion/optimizer/src/eliminate_limit.rs index 4386253740aa..fb5d0d17b839 100644 --- a/datafusion/optimizer/src/eliminate_limit.rs +++ b/datafusion/optimizer/src/eliminate_limit.rs @@ -15,18 +15,19 @@ // specific language governing permissions and limitations // under the License. -//! Optimizer rule to replace `LIMIT 0` or -//! `LIMIT whose ancestor LIMIT's skip is greater than or equal to current's fetch` -//! on a plan with an empty relation. -//! This rule also removes OFFSET 0 from the [LogicalPlan] -//! This saves time in planning and executing the query. +//! [`EliminateLimit`] eliminates `LIMIT` when possible use crate::optimizer::ApplyOrder; use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::Result; use datafusion_expr::logical_plan::{EmptyRelation, LogicalPlan}; -/// Optimization rule that eliminate LIMIT 0 or useless LIMIT(skip:0, fetch:None). -/// It can cooperate with `propagate_empty_relation` and `limit_push_down`. +/// Optimizer rule to replace `LIMIT 0` or `LIMIT` whose ancestor LIMIT's skip is +/// greater than or equal to current's fetch +/// +/// It can cooperate with `propagate_empty_relation` and `limit_push_down`. on a +/// plan with an empty relation. +/// +/// This rule also removes OFFSET 0 from the [LogicalPlan] #[derive(Default)] pub struct EliminateLimit; diff --git a/datafusion/optimizer/src/eliminate_nested_union.rs b/datafusion/optimizer/src/eliminate_nested_union.rs index 5771ea2e19a2..924a0853418c 100644 --- a/datafusion/optimizer/src/eliminate_nested_union.rs +++ b/datafusion/optimizer/src/eliminate_nested_union.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -//! Optimizer rule to replace nested unions to single union. +//! [`EliminateNestedUnion`]: flattens nested `Union` to a single `Union` use crate::optimizer::ApplyOrder; use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::Result; diff --git a/datafusion/optimizer/src/eliminate_one_union.rs b/datafusion/optimizer/src/eliminate_one_union.rs index 70ee490346ff..63c3e789daa6 100644 --- a/datafusion/optimizer/src/eliminate_one_union.rs +++ b/datafusion/optimizer/src/eliminate_one_union.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -//! Optimizer rule to eliminate one union. +//! [`EliminateOneUnion`] eliminates single element `Union` use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::Result; use datafusion_expr::logical_plan::{LogicalPlan, Union}; diff --git a/datafusion/optimizer/src/eliminate_outer_join.rs b/datafusion/optimizer/src/eliminate_outer_join.rs index 56a4a76987f7..a004da2bff19 100644 --- a/datafusion/optimizer/src/eliminate_outer_join.rs +++ b/datafusion/optimizer/src/eliminate_outer_join.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -//! Optimizer rule to eliminate left/right/full join to inner join if possible. +//! [`EliminateOuterJoin`] converts `LEFT/RIGHT/FULL` joins to `INNER` joins use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::{Column, DFSchema, Result}; use datafusion_expr::logical_plan::{Join, JoinType, LogicalPlan}; diff --git a/datafusion/optimizer/src/extract_equijoin_predicate.rs b/datafusion/optimizer/src/extract_equijoin_predicate.rs index 24664d57c38d..4cfcd07b47d9 100644 --- a/datafusion/optimizer/src/extract_equijoin_predicate.rs +++ b/datafusion/optimizer/src/extract_equijoin_predicate.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -//! [`ExtractEquijoinPredicate`] rule that extracts equijoin predicates +//! [`ExtractEquijoinPredicate`] identifies equality join (equijoin) predicates use crate::optimizer::ApplyOrder; use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::DFSchema; diff --git a/datafusion/optimizer/src/filter_null_join_keys.rs b/datafusion/optimizer/src/filter_null_join_keys.rs index 95cd8a9fd36c..16039b182bb2 100644 --- a/datafusion/optimizer/src/filter_null_join_keys.rs +++ b/datafusion/optimizer/src/filter_null_join_keys.rs @@ -15,10 +15,7 @@ // specific language governing permissions and limitations // under the License. -//! The FilterNullJoinKeys rule will identify inner joins with equi-join conditions -//! where the join key is nullable on one side and non-nullable on the other side -//! and then insert an `IsNotNull` filter on the nullable side since null values -//! can never match. +//! [`FilterNullJoinKeys`] adds filters to join inputs when input isn't nullable use crate::optimizer::ApplyOrder; use crate::{OptimizerConfig, OptimizerRule}; diff --git a/datafusion/optimizer/src/lib.rs b/datafusion/optimizer/src/lib.rs index b54facc5d682..f1f49727c39c 100644 --- a/datafusion/optimizer/src/lib.rs +++ b/datafusion/optimizer/src/lib.rs @@ -15,6 +15,19 @@ // specific language governing permissions and limitations // under the License. +//! # DataFusion Optimizer +//! +//! Contains rules for rewriting [`LogicalPlan`]s +//! +//! 1. [`Analyzer`] applies [`AnalyzerRule`]s to transform `LogicalPlan`s +//! to make the plan valid prior to the rest of the DataFusion optimization +//! process (for example, [`TypeCoercion`]). +//! +//! 2. [`Optimizer`] applies [`OptimizerRule`]s to transform `LogicalPlan`s +//! into equivalent, but more efficient plans. +//! +//! [`LogicalPlan`]: datafusion_expr::LogicalPlan +//! [`TypeCoercion`]: analyzer::type_coercion::TypeCoercion pub mod analyzer; pub mod common_subexpr_eliminate; pub mod decorrelate; @@ -46,7 +59,8 @@ pub mod utils; #[cfg(test)] pub mod test; -pub use optimizer::{OptimizerConfig, OptimizerContext, OptimizerRule}; +pub use analyzer::{Analyzer, AnalyzerRule}; +pub use optimizer::{Optimizer, OptimizerConfig, OptimizerContext, OptimizerRule}; pub use utils::optimize_children; mod plan_signature; diff --git a/datafusion/optimizer/src/optimize_projections.rs b/datafusion/optimizer/src/optimize_projections.rs index c40a9bb704eb..147702cc0441 100644 --- a/datafusion/optimizer/src/optimize_projections.rs +++ b/datafusion/optimizer/src/optimize_projections.rs @@ -15,13 +15,7 @@ // specific language governing permissions and limitations // under the License. -//! Optimizer rule to prune unnecessary columns from intermediate schemas -//! inside the [`LogicalPlan`]. This rule: -//! - Removes unnecessary columns that do not appear at the output and/or are -//! not used during any computation step. -//! - Adds projections to decrease table column size before operators that -//! benefit from a smaller memory footprint at its input. -//! - Removes unnecessary [`LogicalPlan::Projection`]s from the [`LogicalPlan`]. +//! [`OptimizeProjections`] identifies and eliminates unused columns use std::collections::HashSet; use std::sync::Arc; @@ -44,7 +38,13 @@ use datafusion_expr::utils::inspect_expr_pre; use hashbrown::HashMap; use itertools::{izip, Itertools}; -/// A rule for optimizing logical plans by removing unused columns/fields. +/// Optimizer rule to prune unnecessary columns from intermediate schemas +/// inside the [`LogicalPlan`]. This rule: +/// - Removes unnecessary columns that do not appear at the output and/or are +/// not used during any computation step. +/// - Adds projections to decrease table column size before operators that +/// benefit from a smaller memory footprint at its input. +/// - Removes unnecessary [`LogicalPlan::Projection`]s from the [`LogicalPlan`]. /// /// `OptimizeProjections` is an optimizer rule that identifies and eliminates /// columns from a logical plan that are not used by downstream operations. diff --git a/datafusion/optimizer/src/optimizer.rs b/datafusion/optimizer/src/optimizer.rs index 3153f72d7ee7..03ff402c3e3f 100644 --- a/datafusion/optimizer/src/optimizer.rs +++ b/datafusion/optimizer/src/optimizer.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -//! Query optimizer traits +//! [`Optimizer`] and [`OptimizerRule`] use std::collections::HashSet; use std::sync::Arc; @@ -54,7 +54,7 @@ use datafusion_expr::logical_plan::LogicalPlan; use chrono::{DateTime, Utc}; use log::{debug, warn}; -/// `OptimizerRule` transforms one [`LogicalPlan`] into another which +/// `OptimizerRule`s transforms one [`LogicalPlan`] into another which /// computes the same results, but in a potentially more efficient /// way. If there are no suitable transformations for the input plan, /// the optimizer should simply return it unmodified. diff --git a/datafusion/optimizer/src/propagate_empty_relation.rs b/datafusion/optimizer/src/propagate_empty_relation.rs index 55fb982d2a87..2aca6f93254a 100644 --- a/datafusion/optimizer/src/propagate_empty_relation.rs +++ b/datafusion/optimizer/src/propagate_empty_relation.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +//! [`PropagateEmptyRelation`] eliminates nodes fed by `EmptyRelation` use datafusion_common::{plan_err, Result}; use datafusion_expr::logical_plan::LogicalPlan; use datafusion_expr::{EmptyRelation, JoinType, Projection, Union}; diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index 83db4b0640a4..ff24df259adf 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -12,8 +12,7 @@ // specific language governing permissions and limitations // under the License. -//! [`PushDownFilter`] Moves filters so they are applied as early as possible in -//! the plan. +//! [`PushDownFilter`] applies filters as early as possible use std::collections::{HashMap, HashSet}; use std::sync::Arc; diff --git a/datafusion/optimizer/src/push_down_limit.rs b/datafusion/optimizer/src/push_down_limit.rs index 33d02d5c5628..cca6c3fd9bd1 100644 --- a/datafusion/optimizer/src/push_down_limit.rs +++ b/datafusion/optimizer/src/push_down_limit.rs @@ -15,8 +15,7 @@ // specific language governing permissions and limitations // under the License. -//! Optimizer rule to push down LIMIT in the query plan -//! It will push down through projection, limits (taking the smaller limit) +//! [`PushDownLimit`] pushes `LIMIT` earlier in the query plan use std::sync::Arc; @@ -29,7 +28,9 @@ use datafusion_expr::logical_plan::{ }; use datafusion_expr::CrossJoin; -/// Optimization rule that tries to push down LIMIT. +/// Optimization rule that tries to push down `LIMIT`. +/// +//. It will push down through projection, limits (taking the smaller limit) #[derive(Default)] pub struct PushDownLimit {} diff --git a/datafusion/optimizer/src/push_down_projection.rs b/datafusion/optimizer/src/push_down_projection.rs index ccdcf2f65bc8..ae57ed9e5a34 100644 --- a/datafusion/optimizer/src/push_down_projection.rs +++ b/datafusion/optimizer/src/push_down_projection.rs @@ -15,9 +15,6 @@ // specific language governing permissions and limitations // under the License. -//! Projection Push Down optimizer rule ensures that only referenced columns are -//! loaded into memory - #[cfg(test)] mod tests { use std::collections::HashMap; diff --git a/datafusion/optimizer/src/replace_distinct_aggregate.rs b/datafusion/optimizer/src/replace_distinct_aggregate.rs index 0055e329c29d..752915be69c0 100644 --- a/datafusion/optimizer/src/replace_distinct_aggregate.rs +++ b/datafusion/optimizer/src/replace_distinct_aggregate.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +//! [`ReplaceDistinctWithAggregate`] replaces `DISTINCT ...` with `GROUP BY ...` use crate::optimizer::{ApplyOrder, ApplyOrder::BottomUp}; use crate::{OptimizerConfig, OptimizerRule}; diff --git a/datafusion/optimizer/src/rewrite_disjunctive_predicate.rs b/datafusion/optimizer/src/rewrite_disjunctive_predicate.rs index 90c96b4b8b8c..059b1452ff3d 100644 --- a/datafusion/optimizer/src/rewrite_disjunctive_predicate.rs +++ b/datafusion/optimizer/src/rewrite_disjunctive_predicate.rs @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. +//! [`RewriteDisjunctivePredicate`] rewrites predicates to reduce redundancy + use crate::optimizer::ApplyOrder; use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::Result; diff --git a/datafusion/optimizer/src/scalar_subquery_to_join.rs b/datafusion/optimizer/src/scalar_subquery_to_join.rs index 8acc36e479ca..a2c4eabcaae6 100644 --- a/datafusion/optimizer/src/scalar_subquery_to_join.rs +++ b/datafusion/optimizer/src/scalar_subquery_to_join.rs @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. +//! [`ScalarSubqueryToJoin`] rewriting scalar subquery filters to `JOIN`s + use std::collections::{BTreeSet, HashMap}; use std::sync::Arc; diff --git a/datafusion/optimizer/src/simplify_expressions/mod.rs b/datafusion/optimizer/src/simplify_expressions/mod.rs index 5244f9a5af88..d0399fef07e6 100644 --- a/datafusion/optimizer/src/simplify_expressions/mod.rs +++ b/datafusion/optimizer/src/simplify_expressions/mod.rs @@ -15,6 +15,9 @@ // specific language governing permissions and limitations // under the License. +//! [`SimplifyExpressions`] simplifies expressions in the logical plan, +//! [`ExprSimplifier`] simplifies individual `Expr`s. + pub mod expr_simplifier; mod guarantees; mod inlist_simplifier; diff --git a/datafusion/optimizer/src/single_distinct_to_groupby.rs b/datafusion/optimizer/src/single_distinct_to_groupby.rs index 5b47abb308d0..076bf4e24296 100644 --- a/datafusion/optimizer/src/single_distinct_to_groupby.rs +++ b/datafusion/optimizer/src/single_distinct_to_groupby.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -//! single distinct to group by optimizer rule +//! [`SingleDistinctToGroupBy`] replaces `AGG(DISTINCT ..)` with `AGG(..) GROUP BY ..` use std::sync::Arc; diff --git a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs index f573ac69377b..fda390f37961 100644 --- a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs +++ b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs @@ -15,9 +15,7 @@ // specific language governing permissions and limitations // under the License. -//! Unwrap-cast binary comparison rule can be used to the binary/inlist comparison expr now, and other type -//! of expr can be added if needed. -//! This rule can reduce adding the `Expr::Cast` the expr instead of adding the `Expr::Cast` to literal expr. +//! [`UnwrapCastInComparison`] rewrites `CAST(col) = lit` to `col = CAST(lit)` use std::cmp::Ordering; use std::sync::Arc; diff --git a/datafusion/optimizer/src/utils.rs b/datafusion/optimizer/src/utils.rs index 0df79550f143..560c63b18882 100644 --- a/datafusion/optimizer/src/utils.rs +++ b/datafusion/optimizer/src/utils.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -//! Collection of utility functions that are leveraged by the query optimizer rules +//! Utility functions leveraged by the query optimizer rules use std::collections::{BTreeSet, HashMap}; From 7acc8f16cf0776a4112a5e62214a44ad20c4c673 Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Sun, 7 Apr 2024 18:03:27 +0200 Subject: [PATCH 02/15] use `Expr::apply()` instead of `inspect_expr_pre()` (#9984) --- datafusion/expr/src/utils.rs | 13 +++++++------ datafusion/optimizer/src/optimize_projections.rs | 6 +++--- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index a93282574e8a..0d99d0b5028e 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -264,7 +264,7 @@ pub fn grouping_set_to_exprlist(group_expr: &[Expr]) -> Result> { /// Recursively walk an expression tree, collecting the unique set of columns /// referenced in the expression pub fn expr_to_columns(expr: &Expr, accum: &mut HashSet) -> Result<()> { - inspect_expr_pre(expr, |expr| { + expr.apply(&mut |expr| { match expr { Expr::Column(qc) => { accum.insert(qc.clone()); @@ -307,8 +307,9 @@ pub fn expr_to_columns(expr: &Expr, accum: &mut HashSet) -> Result<()> { | Expr::Placeholder(_) | Expr::OuterReferenceColumn { .. } => {} } - Ok(()) + Ok(TreeNodeRecursion::Continue) }) + .map(|_| ()) } /// Find excluded columns in the schema, if any @@ -838,11 +839,11 @@ pub fn find_column_exprs(exprs: &[Expr]) -> Vec { pub(crate) fn find_columns_referenced_by_expr(e: &Expr) -> Vec { let mut exprs = vec![]; - inspect_expr_pre(e, |expr| { + e.apply(&mut |expr| { if let Expr::Column(c) = expr { exprs.push(c.clone()) } - Ok(()) as Result<()> + Ok(TreeNodeRecursion::Continue) }) // As the closure always returns Ok, this "can't" error .expect("Unexpected error"); @@ -867,7 +868,7 @@ pub(crate) fn find_column_indexes_referenced_by_expr( schema: &DFSchemaRef, ) -> Vec { let mut indexes = vec![]; - inspect_expr_pre(e, |expr| { + e.apply(&mut |expr| { match expr { Expr::Column(qc) => { if let Ok(idx) = schema.index_of_column(qc) { @@ -879,7 +880,7 @@ pub(crate) fn find_column_indexes_referenced_by_expr( } _ => {} } - Ok(()) as Result<()> + Ok(TreeNodeRecursion::Continue) }) .unwrap(); indexes diff --git a/datafusion/optimizer/src/optimize_projections.rs b/datafusion/optimizer/src/optimize_projections.rs index 147702cc0441..69905c990a7f 100644 --- a/datafusion/optimizer/src/optimize_projections.rs +++ b/datafusion/optimizer/src/optimize_projections.rs @@ -34,7 +34,7 @@ use datafusion_expr::{ Expr, Projection, TableScan, Window, }; -use datafusion_expr::utils::inspect_expr_pre; +use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; use hashbrown::HashMap; use itertools::{izip, Itertools}; @@ -613,7 +613,7 @@ fn rewrite_expr(expr: &Expr, input: &Projection) -> Result> { /// columns are collected. fn outer_columns(expr: &Expr, columns: &mut HashSet) { // inspect_expr_pre doesn't handle subquery references, so find them explicitly - inspect_expr_pre(expr, |expr| { + expr.apply(&mut |expr| { match expr { Expr::OuterReferenceColumn(_, col) => { columns.insert(col.clone()); @@ -632,7 +632,7 @@ fn outer_columns(expr: &Expr, columns: &mut HashSet) { } _ => {} }; - Ok(()) as Result<()> + Ok(TreeNodeRecursion::Continue) }) // unwrap: closure above never returns Err, so can not be Err here .unwrap(); From 85b4e40df9e9a5a71c08760452c2059a271313d1 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Sun, 7 Apr 2024 13:39:44 -0400 Subject: [PATCH 03/15] Update documentation for COPY command (#9931) * Update documentation for COPY command * Fix example * prettier --- docs/source/user-guide/sql/dml.md | 39 +++++++++++++++--- docs/source/user-guide/sql/write_options.md | 45 ++++++++++----------- 2 files changed, 54 insertions(+), 30 deletions(-) diff --git a/docs/source/user-guide/sql/dml.md b/docs/source/user-guide/sql/dml.md index 79c36092fd3d..666e86b46002 100644 --- a/docs/source/user-guide/sql/dml.md +++ b/docs/source/user-guide/sql/dml.md @@ -35,8 +35,22 @@ TO 'file_name' [ OPTIONS( option [, ... ] ) ] +`STORED AS` specifies the file format the `COPY` command will write. If this +clause is not specified, it will be inferred from the file extension if possible. + +`PARTITIONED BY` specifies the columns to use for partitioning the output files into +separate hive-style directories. + +The output format is determined by the first match of the following rules: + +1. Value of `STORED AS` +2. Value of the `OPTION (FORMAT ..)` +3. Filename extension (e.g. `foo.parquet` implies `PARQUET` format) + For a detailed list of valid OPTIONS, see [Write Options](write_options). +### Examples + Copy the contents of `source_table` to `file_name.json` in JSON format: ```sql @@ -72,6 +86,23 @@ of hive-style partitioned parquet files: +-------+ ``` +If the the data contains values of `x` and `y` in column1 and only `a` in +column2, output files will appear in the following directory structure: + +``` +dir_name/ + column1=x/ + column2=a/ + .parquet + .parquet + ... + column1=y/ + column2=a/ + .parquet + .parquet + ... +``` + Run the query `SELECT * from source ORDER BY time` and write the results (maintaining the order) to a parquet file named `output.parquet` with a maximum parquet row group size of 10MB: @@ -85,14 +116,10 @@ results (maintaining the order) to a parquet file named +-------+ ``` -The output format is determined by the first match of the following rules: - -1. Value of `STORED AS` -2. Value of the `OPTION (FORMAT ..)` -3. Filename extension (e.g. `foo.parquet` implies `PARQUET` format) - ## INSERT +### Examples + Insert values into a table.
diff --git a/docs/source/user-guide/sql/write_options.md b/docs/source/user-guide/sql/write_options.md
index ac0a41a97f07..5c204d8fc0e6 100644
--- a/docs/source/user-guide/sql/write_options.md
+++ b/docs/source/user-guide/sql/write_options.md
@@ -35,44 +35,41 @@ If inserting to an external table, table specific write options can be specified
 
 ```sql
 CREATE EXTERNAL TABLE
-my_table(a bigint, b bigint)
-STORED AS csv
-COMPRESSION TYPE gzip
-WITH HEADER ROW
-DELIMITER ';'
-LOCATION '/test/location/my_csv_table/'
-OPTIONS(
-NULL_VALUE 'NAN'
-);
+  my_table(a bigint, b bigint)
+  STORED AS csv
+  COMPRESSION TYPE gzip
+  WITH HEADER ROW
+  DELIMITER ';'
+  LOCATION '/test/location/my_csv_table/'
+  OPTIONS(
+    NULL_VALUE 'NAN'
+  )
 ```
 
 When running `INSERT INTO my_table ...`, the options from the `CREATE TABLE` will be respected (gzip compression, special delimiter, and header row included). There will be a single output file if the output path doesn't have folder format, i.e. ending with a `\`. Note that compression, header, and delimiter settings can also be specified within the `OPTIONS` tuple list. Dedicated syntax within the SQL statement always takes precedence over arbitrary option tuples, so if both are specified the `OPTIONS` setting will be ignored. NULL_VALUE is a CSV format specific option that determines how null values should be encoded within the CSV file.
 
 Finally, options can be passed when running a `COPY` command.
 
+
+
 ```sql
 COPY source_table
-TO 'test/table_with_options'
-(format parquet,
-compression snappy,
-'compression::col1' 'zstd(5)',
-partition_by 'column3, column4'
-)
+  TO 'test/table_with_options'
+  PARTITIONED BY (column3, column4)
+  OPTIONS (
+    format parquet,
+    compression snappy,
+    'compression::column1' 'zstd(5)',
+  )
 ```
 
 In this example, we write the entirety of `source_table` out to a folder of parquet files. One parquet file will be written in parallel to the folder for each partition in the query. The next option `compression` set to `snappy` indicates that unless otherwise specified all columns should use the snappy compression codec. The option `compression::col1` sets an override, so that the column `col1` in the parquet file will use `ZSTD` compression codec with compression level `5`. In general, parquet options which support column specific settings can be specified with the syntax `OPTION::COLUMN.NESTED.PATH`.
 
 ## Available Options
 
-### COPY Specific Options
-
-The following special options are specific to the `COPY` command.
-
-| Option       | Description                                                                                                                                                                         | Default Value |
-| ------------ | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------------- |
-| FORMAT       | Specifies the file format COPY query will write out. If there're more than one output file or the format cannot be inferred from the file extension, then FORMAT must be specified. | N/A           |
-| PARTITION_BY | Specifies the columns that the output files should be partitioned by into separate hive-style directories. Value should be a comma separated string literal, e.g. 'col1,col2'       | N/A           |
-
 ### JSON Format Specific Options
 
 The following options are available when writing JSON files. Note: If any unsupported option is specified, an error will be raised and the query will fail.

From 9ce21e11ea8574ea2b650d80bf09327db343887f Mon Sep 17 00:00:00 2001
From: Andrew Lamb 
Date: Sun, 7 Apr 2024 21:53:41 -0400
Subject: [PATCH 04/15] Minor: fix bug in pruning predicate doc (#9986)

* Minor: fix bug in pruning predicate doc

* formatting
---
 datafusion/core/src/physical_optimizer/pruning.rs | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/datafusion/core/src/physical_optimizer/pruning.rs b/datafusion/core/src/physical_optimizer/pruning.rs
index 80bb5ad42e81..19e71a92a706 100644
--- a/datafusion/core/src/physical_optimizer/pruning.rs
+++ b/datafusion/core/src/physical_optimizer/pruning.rs
@@ -330,7 +330,7 @@ pub trait PruningStatistics {
 /// `x = 5` | `CASE WHEN x_null_count = x_row_count THEN false ELSE x_min <= 5 AND 5 <= x_max END`
 /// `x < 5` | `CASE WHEN x_null_count = x_row_count THEN false ELSE x_max < 5 END`
 /// `x = 5 AND y = 10` | `CASE WHEN x_null_count = x_row_count THEN false ELSE x_min <= 5 AND 5 <= x_max END AND CASE WHEN y_null_count = y_row_count THEN false ELSE y_min <= 10 AND 10 <= y_max END`
-/// `x IS NULL`  | `CASE WHEN x_null_count = x_row_count THEN false ELSE x_null_count > 0 END`
+/// `x IS NULL`  | `x_null_count > 0`
 /// `CAST(x as int) = 5` | `CASE WHEN x_null_count = x_row_count THEN false ELSE CAST(x_min as int) <= 5 AND 5 <= CAST(x_max as int) END`
 ///
 /// ## Predicate Evaluation

From 215f30f74a12e91fd7dca0d30e37014c8c3493ed Mon Sep 17 00:00:00 2001
From: Jonah Gao 
Date: Mon, 8 Apr 2024 11:08:06 +0800
Subject: [PATCH 05/15] fix: improve `unnest_generic_list` handling of null
 list (#9975)

* fix: improve `unnest_generic_list` handling of null list

* fix clippy

* fix comment
---
 datafusion/physical-plan/src/unnest.rs | 139 +++++++++++++++++++++----
 1 file changed, 117 insertions(+), 22 deletions(-)

diff --git a/datafusion/physical-plan/src/unnest.rs b/datafusion/physical-plan/src/unnest.rs
index 324e2ea2d773..6ea1b3c40c83 100644
--- a/datafusion/physical-plan/src/unnest.rs
+++ b/datafusion/physical-plan/src/unnest.rs
@@ -364,32 +364,31 @@ fn unnest_generic_list>(
     options: &UnnestOptions,
 ) -> Result> {
     let values = list_array.values();
-    if list_array.null_count() == 0 || !options.preserve_nulls {
-        Ok(values.clone())
-    } else {
-        let mut take_indicies_builder =
-            PrimitiveArray::

::builder(values.len() + list_array.null_count()); - let mut take_offset = 0; + if list_array.null_count() == 0 { + return Ok(values.clone()); + } - list_array.iter().for_each(|elem| match elem { - Some(array) => { - for i in 0..array.len() { - // take_offset + i is always positive - let take_index = P::Native::from_usize(take_offset + i).unwrap(); - take_indicies_builder.append_value(take_index); - } - take_offset += array.len(); - } - None => { + let mut take_indicies_builder = + PrimitiveArray::

::builder(values.len() + list_array.null_count()); + let offsets = list_array.value_offsets(); + for row in 0..list_array.len() { + if list_array.is_null(row) { + if options.preserve_nulls { take_indicies_builder.append_null(); } - }); - Ok(kernels::take::take( - &values, - &take_indicies_builder.finish(), - None, - )?) + } else { + let start = offsets[row].as_usize(); + let end = offsets[row + 1].as_usize(); + for idx in start..end { + take_indicies_builder.append_value(P::Native::from_usize(idx).unwrap()); + } + } } + Ok(kernels::take::take( + &values, + &take_indicies_builder.finish(), + None, + )?) } fn build_batch_fixedsize_list( @@ -596,3 +595,99 @@ where Ok(RecordBatch::try_new(schema.clone(), arrays.to_vec())?) } + +#[cfg(test)] +mod tests { + use super::*; + use arrow::{ + array::AsArray, + datatypes::{DataType, Field}, + }; + use arrow_array::StringArray; + use arrow_buffer::{BooleanBufferBuilder, NullBuffer, OffsetBuffer}; + + // Create a ListArray with the following list values: + // [A, B, C], [], NULL, [D], NULL, [NULL, F] + fn make_test_array() -> ListArray { + let mut values = vec![]; + let mut offsets = vec![0]; + let mut valid = BooleanBufferBuilder::new(2); + + // [A, B, C] + values.extend_from_slice(&[Some("A"), Some("B"), Some("C")]); + offsets.push(values.len() as i32); + valid.append(true); + + // [] + offsets.push(values.len() as i32); + valid.append(true); + + // NULL with non-zero value length + // Issue https://github.com/apache/arrow-datafusion/issues/9932 + values.push(Some("?")); + offsets.push(values.len() as i32); + valid.append(false); + + // [D] + values.push(Some("D")); + offsets.push(values.len() as i32); + valid.append(true); + + // Another NULL with zero value length + offsets.push(values.len() as i32); + valid.append(false); + + // [NULL, F] + values.extend_from_slice(&[None, Some("F")]); + offsets.push(values.len() as i32); + valid.append(true); + + let field = Arc::new(Field::new("item", DataType::Utf8, true)); + ListArray::new( + field, + OffsetBuffer::new(offsets.into()), + Arc::new(StringArray::from(values)), + Some(NullBuffer::new(valid.finish())), + ) + } + + #[test] + fn test_unnest_generic_list() -> datafusion_common::Result<()> { + let list_array = make_test_array(); + + // Test with preserve_nulls = false + let options = UnnestOptions { + preserve_nulls: false, + }; + let unnested_array = + unnest_generic_list::(&list_array, &options)?; + let strs = unnested_array.as_string::().iter().collect::>(); + assert_eq!( + strs, + vec![Some("A"), Some("B"), Some("C"), Some("D"), None, Some("F")] + ); + + // Test with preserve_nulls = true + let options = UnnestOptions { + preserve_nulls: true, + }; + let unnested_array = + unnest_generic_list::(&list_array, &options)?; + let strs = unnested_array.as_string::().iter().collect::>(); + assert_eq!( + strs, + vec![ + Some("A"), + Some("B"), + Some("C"), + None, + Some("D"), + None, + None, + Some("F") + ] + ); + + Ok(()) + } +} From 0a4d9a6c788c1e4ad340943492abb823bd31c4f9 Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Mon, 8 Apr 2024 11:28:59 +0200 Subject: [PATCH 06/15] Consistent LogicalPlan subquery handling in TreeNode::apply and TreeNode::visit (#9913) * fix * clippy * remove accidental extra apply * minor fixes * fix `LogicalPlan::apply_expressions()` and `LogicalPlan::map_subqueries()` * fix `LogicalPlan::visit_with_subqueries()` * Add deprecated LogicalPlan::inspect_expressions --------- Co-authored-by: Andrew Lamb --- datafusion/common/src/tree_node.rs | 3 +- datafusion/core/src/execution/context/mod.rs | 4 +- datafusion/expr/src/logical_plan/plan.rs | 558 ++++++++++++++---- datafusion/expr/src/tree_node/expr.rs | 2 +- datafusion/expr/src/tree_node/plan.rs | 53 +- datafusion/optimizer/src/analyzer/mod.rs | 15 +- datafusion/optimizer/src/analyzer/subquery.rs | 2 +- datafusion/optimizer/src/plan_signature.rs | 4 +- 8 files changed, 475 insertions(+), 166 deletions(-) diff --git a/datafusion/common/src/tree_node.rs b/datafusion/common/src/tree_node.rs index 8e088e7a0b56..42514537e28d 100644 --- a/datafusion/common/src/tree_node.rs +++ b/datafusion/common/src/tree_node.rs @@ -25,10 +25,9 @@ use crate::Result; /// These macros are used to determine continuation during transforming traversals. macro_rules! handle_transform_recursion { ($F_DOWN:expr, $F_CHILD:expr, $F_UP:expr) => {{ - #[allow(clippy::redundant_closure_call)] $F_DOWN? .transform_children(|n| n.map_children($F_CHILD))? - .transform_parent(|n| $F_UP(n)) + .transform_parent($F_UP) }}; } diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index f15c1c218db6..9e48c7b8a6f2 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -67,7 +67,7 @@ use datafusion_common::{ alias::AliasGenerator, config::{ConfigExtension, TableOptions}, exec_err, not_impl_err, plan_datafusion_err, plan_err, - tree_node::{TreeNode, TreeNodeRecursion, TreeNodeVisitor}, + tree_node::{TreeNodeRecursion, TreeNodeVisitor}, SchemaReference, TableReference, }; use datafusion_execution::registry::SerializerRegistry; @@ -2298,7 +2298,7 @@ impl SQLOptions { /// Return an error if the [`LogicalPlan`] has any nodes that are /// incompatible with this [`SQLOptions`]. pub fn verify_plan(&self, plan: &LogicalPlan) -> Result<()> { - plan.visit(&mut BadPlanVisitor::new(self))?; + plan.visit_with_subqueries(&mut BadPlanVisitor::new(self))?; Ok(()) } } diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 3d40dcae0e4b..4f55bbfe3f4d 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -34,8 +34,7 @@ use crate::logical_plan::extension::UserDefinedLogicalNode; use crate::logical_plan::{DmlStatement, Statement}; use crate::utils::{ enumerate_grouping_sets, exprlist_to_fields, find_out_reference_exprs, - grouping_set_expr_count, grouping_set_to_exprlist, inspect_expr_pre, - split_conjunction, + grouping_set_expr_count, grouping_set_to_exprlist, split_conjunction, }; use crate::{ build_join_schema, expr_vec_fmt, BinaryExpr, BuiltInWindowFunction, @@ -45,16 +44,19 @@ use crate::{ use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion_common::tree_node::{ - Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeVisitor, + Transformed, TransformedResult, TreeNode, TreeNodeIterator, TreeNodeRecursion, + TreeNodeRewriter, TreeNodeVisitor, }; use datafusion_common::{ - aggregate_functional_dependencies, internal_err, plan_err, Column, Constraints, - DFSchema, DFSchemaRef, DataFusionError, Dependency, FunctionalDependence, - FunctionalDependencies, ParamValues, Result, TableReference, UnnestOptions, + aggregate_functional_dependencies, internal_err, map_until_stop_and_collect, + plan_err, Column, Constraints, DFSchema, DFSchemaRef, DataFusionError, Dependency, + FunctionalDependence, FunctionalDependencies, ParamValues, Result, TableReference, + UnnestOptions, }; // backwards compatibility use crate::display::PgJsonVisitor; +use crate::tree_node::expr::transform_option_vec; pub use datafusion_common::display::{PlanType, StringifiedPlan, ToStringifiedPlan}; pub use datafusion_common::{JoinConstraint, JoinType}; @@ -248,9 +250,9 @@ impl LogicalPlan { /// DataFusion's optimizer attempts to optimize them away. pub fn expressions(self: &LogicalPlan) -> Vec { let mut exprs = vec![]; - self.inspect_expressions(|e| { + self.apply_expressions(|e| { exprs.push(e.clone()); - Ok(()) as Result<()> + Ok(TreeNodeRecursion::Continue) }) // closure always returns OK .unwrap(); @@ -261,13 +263,13 @@ impl LogicalPlan { /// logical plan nodes and all its descendant nodes. pub fn all_out_ref_exprs(self: &LogicalPlan) -> Vec { let mut exprs = vec![]; - self.inspect_expressions(|e| { + self.apply_expressions(|e| { find_out_reference_exprs(e).into_iter().for_each(|e| { if !exprs.contains(&e) { exprs.push(e) } }); - Ok(()) as Result<(), DataFusionError> + Ok(TreeNodeRecursion::Continue) }) // closure always returns OK .unwrap(); @@ -282,60 +284,81 @@ impl LogicalPlan { exprs } - /// Calls `f` on all expressions (non-recursively) in the current - /// logical plan node. This does not include expressions in any - /// children. + #[deprecated(since = "37.0.0", note = "Use `apply_expressions` instead")] pub fn inspect_expressions(self: &LogicalPlan, mut f: F) -> Result<(), E> where F: FnMut(&Expr) -> Result<(), E>, { + let mut err = Ok(()); + self.apply_expressions(|e| { + if let Err(e) = f(e) { + // save the error for later (it may not be a DataFusionError + err = Err(e); + Ok(TreeNodeRecursion::Stop) + } else { + Ok(TreeNodeRecursion::Continue) + } + }) + // The closure always returns OK, so this will always too + .expect("no way to return error during recursion"); + + err + } + + /// Calls `f` on all expressions (non-recursively) in the current + /// logical plan node. This does not include expressions in any + /// children. + pub fn apply_expressions Result>( + &self, + mut f: F, + ) -> Result { match self { LogicalPlan::Projection(Projection { expr, .. }) => { - expr.iter().try_for_each(f) - } - LogicalPlan::Values(Values { values, .. }) => { - values.iter().flatten().try_for_each(f) + expr.iter().apply_until_stop(f) } + LogicalPlan::Values(Values { values, .. }) => values + .iter() + .apply_until_stop(|value| value.iter().apply_until_stop(&mut f)), LogicalPlan::Filter(Filter { predicate, .. }) => f(predicate), LogicalPlan::Repartition(Repartition { partitioning_scheme, .. }) => match partitioning_scheme { - Partitioning::Hash(expr, _) => expr.iter().try_for_each(f), - Partitioning::DistributeBy(expr) => expr.iter().try_for_each(f), - Partitioning::RoundRobinBatch(_) => Ok(()), + Partitioning::Hash(expr, _) | Partitioning::DistributeBy(expr) => { + expr.iter().apply_until_stop(f) + } + Partitioning::RoundRobinBatch(_) => Ok(TreeNodeRecursion::Continue), }, LogicalPlan::Window(Window { window_expr, .. }) => { - window_expr.iter().try_for_each(f) + window_expr.iter().apply_until_stop(f) } LogicalPlan::Aggregate(Aggregate { group_expr, aggr_expr, .. - }) => group_expr.iter().chain(aggr_expr.iter()).try_for_each(f), + }) => group_expr + .iter() + .chain(aggr_expr.iter()) + .apply_until_stop(f), // There are two part of expression for join, equijoin(on) and non-equijoin(filter). // 1. the first part is `on.len()` equijoin expressions, and the struct of each expr is `left-on = right-on`. // 2. the second part is non-equijoin(filter). LogicalPlan::Join(Join { on, filter, .. }) => { on.iter() + // TODO: why we need to create an `Expr::eq`? Cloning `Expr` is costly... // it not ideal to create an expr here to analyze them, but could cache it on the Join itself .map(|(l, r)| Expr::eq(l.clone(), r.clone())) - .try_for_each(|e| f(&e))?; - - if let Some(filter) = filter.as_ref() { - f(filter) - } else { - Ok(()) - } + .apply_until_stop(|e| f(&e))? + .visit_sibling(|| filter.iter().apply_until_stop(f)) } - LogicalPlan::Sort(Sort { expr, .. }) => expr.iter().try_for_each(f), + LogicalPlan::Sort(Sort { expr, .. }) => expr.iter().apply_until_stop(f), LogicalPlan::Extension(extension) => { // would be nice to avoid this copy -- maybe can // update extension to just observer Exprs - extension.node.expressions().iter().try_for_each(f) + extension.node.expressions().iter().apply_until_stop(f) } LogicalPlan::TableScan(TableScan { filters, .. }) => { - filters.iter().try_for_each(f) + filters.iter().apply_until_stop(f) } LogicalPlan::Unnest(Unnest { column, .. }) => { f(&Expr::Column(column.clone())) @@ -348,8 +371,8 @@ impl LogicalPlan { })) => on_expr .iter() .chain(select_expr.iter()) - .chain(sort_expr.clone().unwrap_or(vec![]).iter()) - .try_for_each(f), + .chain(sort_expr.iter().flatten()) + .apply_until_stop(f), // plans without expressions LogicalPlan::EmptyRelation(_) | LogicalPlan::RecursiveQuery(_) @@ -366,10 +389,225 @@ impl LogicalPlan { | LogicalPlan::Ddl(_) | LogicalPlan::Copy(_) | LogicalPlan::DescribeTable(_) - | LogicalPlan::Prepare(_) => Ok(()), + | LogicalPlan::Prepare(_) => Ok(TreeNodeRecursion::Continue), } } + pub fn map_expressions Result>>( + self, + mut f: F, + ) -> Result> { + Ok(match self { + LogicalPlan::Projection(Projection { + expr, + input, + schema, + }) => expr + .into_iter() + .map_until_stop_and_collect(f)? + .update_data(|expr| { + LogicalPlan::Projection(Projection { + expr, + input, + schema, + }) + }), + LogicalPlan::Values(Values { schema, values }) => values + .into_iter() + .map_until_stop_and_collect(|value| { + value.into_iter().map_until_stop_and_collect(&mut f) + })? + .update_data(|values| LogicalPlan::Values(Values { schema, values })), + LogicalPlan::Filter(Filter { predicate, input }) => f(predicate)? + .update_data(|predicate| { + LogicalPlan::Filter(Filter { predicate, input }) + }), + LogicalPlan::Repartition(Repartition { + input, + partitioning_scheme, + }) => match partitioning_scheme { + Partitioning::Hash(expr, usize) => expr + .into_iter() + .map_until_stop_and_collect(f)? + .update_data(|expr| Partitioning::Hash(expr, usize)), + Partitioning::DistributeBy(expr) => expr + .into_iter() + .map_until_stop_and_collect(f)? + .update_data(Partitioning::DistributeBy), + Partitioning::RoundRobinBatch(_) => Transformed::no(partitioning_scheme), + } + .update_data(|partitioning_scheme| { + LogicalPlan::Repartition(Repartition { + input, + partitioning_scheme, + }) + }), + LogicalPlan::Window(Window { + input, + window_expr, + schema, + }) => window_expr + .into_iter() + .map_until_stop_and_collect(f)? + .update_data(|window_expr| { + LogicalPlan::Window(Window { + input, + window_expr, + schema, + }) + }), + LogicalPlan::Aggregate(Aggregate { + input, + group_expr, + aggr_expr, + schema, + }) => map_until_stop_and_collect!( + group_expr.into_iter().map_until_stop_and_collect(&mut f), + aggr_expr, + aggr_expr.into_iter().map_until_stop_and_collect(&mut f) + )? + .update_data(|(group_expr, aggr_expr)| { + LogicalPlan::Aggregate(Aggregate { + input, + group_expr, + aggr_expr, + schema, + }) + }), + + // There are two part of expression for join, equijoin(on) and non-equijoin(filter). + // 1. the first part is `on.len()` equijoin expressions, and the struct of each expr is `left-on = right-on`. + // 2. the second part is non-equijoin(filter). + LogicalPlan::Join(Join { + left, + right, + on, + filter, + join_type, + join_constraint, + schema, + null_equals_null, + }) => map_until_stop_and_collect!( + on.into_iter().map_until_stop_and_collect( + |on| map_until_stop_and_collect!(f(on.0), on.1, f(on.1)) + ), + filter, + filter.map_or(Ok::<_, DataFusionError>(Transformed::no(None)), |e| { + Ok(f(e)?.update_data(Some)) + }) + )? + .update_data(|(on, filter)| { + LogicalPlan::Join(Join { + left, + right, + on, + filter, + join_type, + join_constraint, + schema, + null_equals_null, + }) + }), + LogicalPlan::Sort(Sort { expr, input, fetch }) => expr + .into_iter() + .map_until_stop_and_collect(f)? + .update_data(|expr| LogicalPlan::Sort(Sort { expr, input, fetch })), + LogicalPlan::Extension(Extension { node }) => { + // would be nice to avoid this copy -- maybe can + // update extension to just observer Exprs + node.expressions() + .into_iter() + .map_until_stop_and_collect(f)? + .update_data(|exprs| { + LogicalPlan::Extension(Extension { + node: UserDefinedLogicalNode::from_template( + node.as_ref(), + exprs.as_slice(), + node.inputs() + .into_iter() + .cloned() + .collect::>() + .as_slice(), + ), + }) + }) + } + LogicalPlan::TableScan(TableScan { + table_name, + source, + projection, + projected_schema, + filters, + fetch, + }) => filters + .into_iter() + .map_until_stop_and_collect(f)? + .update_data(|filters| { + LogicalPlan::TableScan(TableScan { + table_name, + source, + projection, + projected_schema, + filters, + fetch, + }) + }), + LogicalPlan::Unnest(Unnest { + input, + column, + schema, + options, + }) => f(Expr::Column(column))?.map_data(|column| match column { + Expr::Column(column) => Ok(LogicalPlan::Unnest(Unnest { + input, + column, + schema, + options, + })), + _ => internal_err!("Transformation should return Column"), + })?, + LogicalPlan::Distinct(Distinct::On(DistinctOn { + on_expr, + select_expr, + sort_expr, + input, + schema, + })) => map_until_stop_and_collect!( + on_expr.into_iter().map_until_stop_and_collect(&mut f), + select_expr, + select_expr.into_iter().map_until_stop_and_collect(&mut f), + sort_expr, + transform_option_vec(sort_expr, &mut f) + )? + .update_data(|(on_expr, select_expr, sort_expr)| { + LogicalPlan::Distinct(Distinct::On(DistinctOn { + on_expr, + select_expr, + sort_expr, + input, + schema, + })) + }), + // plans without expressions + LogicalPlan::EmptyRelation(_) + | LogicalPlan::RecursiveQuery(_) + | LogicalPlan::Subquery(_) + | LogicalPlan::SubqueryAlias(_) + | LogicalPlan::Limit(_) + | LogicalPlan::Statement(_) + | LogicalPlan::CrossJoin(_) + | LogicalPlan::Analyze(_) + | LogicalPlan::Explain(_) + | LogicalPlan::Union(_) + | LogicalPlan::Distinct(Distinct::All(_)) + | LogicalPlan::Dml(_) + | LogicalPlan::Ddl(_) + | LogicalPlan::Copy(_) + | LogicalPlan::DescribeTable(_) + | LogicalPlan::Prepare(_) => Transformed::no(self), + }) + } + /// returns all inputs of this `LogicalPlan` node. Does not /// include inputs to inputs, or subqueries. pub fn inputs(&self) -> Vec<&LogicalPlan> { @@ -417,7 +655,7 @@ impl LogicalPlan { pub fn using_columns(&self) -> Result>, DataFusionError> { let mut using_columns: Vec> = vec![]; - self.apply(&mut |plan| { + self.apply_with_subqueries(&mut |plan| { if let LogicalPlan::Join(Join { join_constraint: JoinConstraint::Using, on, @@ -1079,57 +1317,178 @@ impl LogicalPlan { } } +/// This macro is used to determine continuation during combined transforming +/// traversals. +macro_rules! handle_transform_recursion { + ($F_DOWN:expr, $F_CHILD:expr, $F_UP:expr) => {{ + $F_DOWN? + .transform_children(|n| n.map_subqueries($F_CHILD))? + .transform_sibling(|n| n.map_children($F_CHILD))? + .transform_parent($F_UP) + }}; +} + +macro_rules! handle_transform_recursion_down { + ($F_DOWN:expr, $F_CHILD:expr) => {{ + $F_DOWN? + .transform_children(|n| n.map_subqueries($F_CHILD))? + .transform_sibling(|n| n.map_children($F_CHILD)) + }}; +} + +macro_rules! handle_transform_recursion_up { + ($SELF:expr, $F_CHILD:expr, $F_UP:expr) => {{ + $SELF + .map_subqueries($F_CHILD)? + .transform_sibling(|n| n.map_children($F_CHILD))? + .transform_parent(|n| $F_UP(n)) + }}; +} + impl LogicalPlan { - /// applies `op` to any subqueries in the plan - pub(crate) fn apply_subqueries(&self, op: &mut F) -> Result<()> - where - F: FnMut(&Self) -> Result, - { - self.inspect_expressions(|expr| { - // recursively look for subqueries - inspect_expr_pre(expr, |expr| { - match expr { - Expr::Exists(Exists { subquery, .. }) - | Expr::InSubquery(InSubquery { subquery, .. }) - | Expr::ScalarSubquery(subquery) => { - // use a synthetic plan so the collector sees a - // LogicalPlan::Subquery (even though it is - // actually a Subquery alias) - let synthetic_plan = LogicalPlan::Subquery(subquery.clone()); - synthetic_plan.apply(op)?; - } - _ => {} + pub fn visit_with_subqueries>( + &self, + visitor: &mut V, + ) -> Result { + visitor + .f_down(self)? + .visit_children(|| { + self.apply_subqueries(|c| c.visit_with_subqueries(visitor)) + })? + .visit_sibling(|| self.apply_children(|c| c.visit_with_subqueries(visitor)))? + .visit_parent(|| visitor.f_up(self)) + } + + pub fn rewrite_with_subqueries>( + self, + rewriter: &mut R, + ) -> Result> { + handle_transform_recursion!( + rewriter.f_down(self), + |c| c.rewrite_with_subqueries(rewriter), + |n| rewriter.f_up(n) + ) + } + + pub fn apply_with_subqueries Result>( + &self, + f: &mut F, + ) -> Result { + f(self)? + .visit_children(|| self.apply_subqueries(|c| c.apply_with_subqueries(f)))? + .visit_sibling(|| self.apply_children(|c| c.apply_with_subqueries(f))) + } + + pub fn transform_with_subqueries Result>>( + self, + f: &F, + ) -> Result> { + self.transform_up_with_subqueries(f) + } + + pub fn transform_down_with_subqueries Result>>( + self, + f: &F, + ) -> Result> { + handle_transform_recursion_down!(f(self), |c| c.transform_down_with_subqueries(f)) + } + + pub fn transform_down_mut_with_subqueries< + F: FnMut(Self) -> Result>, + >( + self, + f: &mut F, + ) -> Result> { + handle_transform_recursion_down!(f(self), |c| c + .transform_down_mut_with_subqueries(f)) + } + + pub fn transform_up_with_subqueries Result>>( + self, + f: &F, + ) -> Result> { + handle_transform_recursion_up!(self, |c| c.transform_up_with_subqueries(f), f) + } + + pub fn transform_up_mut_with_subqueries< + F: FnMut(Self) -> Result>, + >( + self, + f: &mut F, + ) -> Result> { + handle_transform_recursion_up!(self, |c| c.transform_up_mut_with_subqueries(f), f) + } + + pub fn transform_down_up_with_subqueries< + FD: FnMut(Self) -> Result>, + FU: FnMut(Self) -> Result>, + >( + self, + f_down: &mut FD, + f_up: &mut FU, + ) -> Result> { + handle_transform_recursion!( + f_down(self), + |c| c.transform_down_up_with_subqueries(f_down, f_up), + f_up + ) + } + + fn apply_subqueries Result>( + &self, + mut f: F, + ) -> Result { + self.apply_expressions(|expr| { + expr.apply(&mut |expr| match expr { + Expr::Exists(Exists { subquery, .. }) + | Expr::InSubquery(InSubquery { subquery, .. }) + | Expr::ScalarSubquery(subquery) => { + // use a synthetic plan so the collector sees a + // LogicalPlan::Subquery (even though it is + // actually a Subquery alias) + f(&LogicalPlan::Subquery(subquery.clone())) } - Ok::<(), DataFusionError>(()) + _ => Ok(TreeNodeRecursion::Continue), }) - })?; - Ok(()) + }) } - /// applies visitor to any subqueries in the plan - pub(crate) fn visit_subqueries(&self, v: &mut V) -> Result<()> - where - V: TreeNodeVisitor, - { - self.inspect_expressions(|expr| { - // recursively look for subqueries - inspect_expr_pre(expr, |expr| { - match expr { - Expr::Exists(Exists { subquery, .. }) - | Expr::InSubquery(InSubquery { subquery, .. }) - | Expr::ScalarSubquery(subquery) => { - // use a synthetic plan so the visitor sees a - // LogicalPlan::Subquery (even though it is - // actually a Subquery alias) - let synthetic_plan = LogicalPlan::Subquery(subquery.clone()); - synthetic_plan.visit(v)?; - } - _ => {} + fn map_subqueries Result>>( + self, + mut f: F, + ) -> Result> { + self.map_expressions(|expr| { + expr.transform_down_mut(&mut |expr| match expr { + Expr::Exists(Exists { subquery, negated }) => { + f(LogicalPlan::Subquery(subquery))?.map_data(|s| match s { + LogicalPlan::Subquery(subquery) => { + Ok(Expr::Exists(Exists { subquery, negated })) + } + _ => internal_err!("Transformation should return Subquery"), + }) } - Ok::<(), DataFusionError>(()) + Expr::InSubquery(InSubquery { + expr, + subquery, + negated, + }) => f(LogicalPlan::Subquery(subquery))?.map_data(|s| match s { + LogicalPlan::Subquery(subquery) => Ok(Expr::InSubquery(InSubquery { + expr, + subquery, + negated, + })), + _ => internal_err!("Transformation should return Subquery"), + }), + Expr::ScalarSubquery(subquery) => f(LogicalPlan::Subquery(subquery))? + .map_data(|s| match s { + LogicalPlan::Subquery(subquery) => { + Ok(Expr::ScalarSubquery(subquery)) + } + _ => internal_err!("Transformation should return Subquery"), + }), + _ => Ok(Transformed::no(expr)), }) - })?; - Ok(()) + }) } /// Return a `LogicalPlan` with all placeholders (e.g $1 $2, @@ -1165,8 +1524,8 @@ impl LogicalPlan { ) -> Result>, DataFusionError> { let mut param_types: HashMap> = HashMap::new(); - self.apply(&mut |plan| { - plan.inspect_expressions(|expr| { + self.apply_with_subqueries(&mut |plan| { + plan.apply_expressions(|expr| { expr.apply(&mut |expr| { if let Expr::Placeholder(Placeholder { id, data_type }) = expr { let prev = param_types.get(id); @@ -1183,13 +1542,10 @@ impl LogicalPlan { } } Ok(TreeNodeRecursion::Continue) - })?; - Ok::<(), DataFusionError>(()) - })?; - Ok(TreeNodeRecursion::Continue) - })?; - - Ok(param_types) + }) + }) + }) + .map(|_| param_types) } /// Return an Expr with all placeholders replaced with their @@ -1257,7 +1613,7 @@ impl LogicalPlan { fn fmt(&self, f: &mut Formatter) -> fmt::Result { let with_schema = false; let mut visitor = IndentVisitor::new(f, with_schema); - match self.0.visit(&mut visitor) { + match self.0.visit_with_subqueries(&mut visitor) { Ok(_) => Ok(()), Err(_) => Err(fmt::Error), } @@ -1300,7 +1656,7 @@ impl LogicalPlan { fn fmt(&self, f: &mut Formatter) -> fmt::Result { let with_schema = true; let mut visitor = IndentVisitor::new(f, with_schema); - match self.0.visit(&mut visitor) { + match self.0.visit_with_subqueries(&mut visitor) { Ok(_) => Ok(()), Err(_) => Err(fmt::Error), } @@ -1320,7 +1676,7 @@ impl LogicalPlan { fn fmt(&self, f: &mut Formatter) -> fmt::Result { let mut visitor = PgJsonVisitor::new(f); visitor.with_schema(true); - match self.0.visit(&mut visitor) { + match self.0.visit_with_subqueries(&mut visitor) { Ok(_) => Ok(()), Err(_) => Err(fmt::Error), } @@ -1369,12 +1725,16 @@ impl LogicalPlan { visitor.start_graph()?; visitor.pre_visit_plan("LogicalPlan")?; - self.0.visit(&mut visitor).map_err(|_| fmt::Error)?; + self.0 + .visit_with_subqueries(&mut visitor) + .map_err(|_| fmt::Error)?; visitor.post_visit_plan()?; visitor.set_with_schema(true); visitor.pre_visit_plan("Detailed LogicalPlan")?; - self.0.visit(&mut visitor).map_err(|_| fmt::Error)?; + self.0 + .visit_with_subqueries(&mut visitor) + .map_err(|_| fmt::Error)?; visitor.post_visit_plan()?; visitor.end_graph()?; @@ -2908,7 +3268,7 @@ digraph { fn visit_order() { let mut visitor = OkVisitor::default(); let plan = test_plan(); - let res = plan.visit(&mut visitor); + let res = plan.visit_with_subqueries(&mut visitor); assert!(res.is_ok()); assert_eq!( @@ -2984,7 +3344,7 @@ digraph { ..Default::default() }; let plan = test_plan(); - let res = plan.visit(&mut visitor); + let res = plan.visit_with_subqueries(&mut visitor); assert!(res.is_ok()); assert_eq!( @@ -3000,7 +3360,7 @@ digraph { ..Default::default() }; let plan = test_plan(); - let res = plan.visit(&mut visitor); + let res = plan.visit_with_subqueries(&mut visitor); assert!(res.is_ok()); assert_eq!( @@ -3051,7 +3411,7 @@ digraph { ..Default::default() }; let plan = test_plan(); - let res = plan.visit(&mut visitor).unwrap_err(); + let res = plan.visit_with_subqueries(&mut visitor).unwrap_err(); assert_eq!( "This feature is not implemented: Error in pre_visit", res.strip_backtrace() @@ -3069,7 +3429,7 @@ digraph { ..Default::default() }; let plan = test_plan(); - let res = plan.visit(&mut visitor).unwrap_err(); + let res = plan.visit_with_subqueries(&mut visitor).unwrap_err(); assert_eq!( "This feature is not implemented: Error in post_visit", res.strip_backtrace() diff --git a/datafusion/expr/src/tree_node/expr.rs b/datafusion/expr/src/tree_node/expr.rs index 97331720ce7d..85097f6249e1 100644 --- a/datafusion/expr/src/tree_node/expr.rs +++ b/datafusion/expr/src/tree_node/expr.rs @@ -412,7 +412,7 @@ where } /// &mut transform a Option<`Vec` of `Expr`s> -fn transform_option_vec( +pub fn transform_option_vec( ove: Option>, f: &mut F, ) -> Result>>> diff --git a/datafusion/expr/src/tree_node/plan.rs b/datafusion/expr/src/tree_node/plan.rs index 7a6b1005fede..482fc96b519b 100644 --- a/datafusion/expr/src/tree_node/plan.rs +++ b/datafusion/expr/src/tree_node/plan.rs @@ -20,58 +20,11 @@ use crate::LogicalPlan; use datafusion_common::tree_node::{ - Transformed, TreeNode, TreeNodeIterator, TreeNodeRecursion, TreeNodeVisitor, + Transformed, TreeNode, TreeNodeIterator, TreeNodeRecursion, }; use datafusion_common::Result; impl TreeNode for LogicalPlan { - fn apply Result>( - &self, - f: &mut F, - ) -> Result { - // Compared to the default implementation, we need to invoke - // [`Self::apply_subqueries`] before visiting its children - f(self)?.visit_children(|| { - self.apply_subqueries(f)?; - self.apply_children(|n| n.apply(f)) - }) - } - - /// To use, define a struct that implements the trait [`TreeNodeVisitor`] and then invoke - /// [`LogicalPlan::visit`]. - /// - /// For example, for a logical plan like: - /// - /// ```text - /// Projection: id - /// Filter: state Eq Utf8(\"CO\")\ - /// CsvScan: employee.csv projection=Some([0, 3])"; - /// ``` - /// - /// The sequence of visit operations would be: - /// ```text - /// visitor.pre_visit(Projection) - /// visitor.pre_visit(Filter) - /// visitor.pre_visit(CsvScan) - /// visitor.post_visit(CsvScan) - /// visitor.post_visit(Filter) - /// visitor.post_visit(Projection) - /// ``` - fn visit>( - &self, - visitor: &mut V, - ) -> Result { - // Compared to the default implementation, we need to invoke - // [`Self::visit_subqueries`] before visiting its children - visitor - .f_down(self)? - .visit_children(|| { - self.visit_subqueries(visitor)?; - self.apply_children(|n| n.visit(visitor)) - })? - .visit_parent(|| visitor.f_up(self)) - } - fn apply_children Result>( &self, f: F, @@ -85,8 +38,8 @@ impl TreeNode for LogicalPlan { ) -> Result> { let new_children = self .inputs() - .iter() - .map(|&c| c.clone()) + .into_iter() + .cloned() .map_until_stop_and_collect(f)?; // Propagate up `new_children.transformed` and `new_children.tnr` // along with the node containing transformed children. diff --git a/datafusion/optimizer/src/analyzer/mod.rs b/datafusion/optimizer/src/analyzer/mod.rs index b446fe2f320e..d0b83d24299b 100644 --- a/datafusion/optimizer/src/analyzer/mod.rs +++ b/datafusion/optimizer/src/analyzer/mod.rs @@ -155,8 +155,8 @@ impl Analyzer { /// Do necessary check and fail the invalid plan fn check_plan(plan: &LogicalPlan) -> Result<()> { - plan.apply(&mut |plan: &LogicalPlan| { - plan.inspect_expressions(|expr| { + plan.apply_with_subqueries(&mut |plan: &LogicalPlan| { + plan.apply_expressions(|expr| { // recursively look for subqueries expr.apply(&mut |expr| { match expr { @@ -168,11 +168,8 @@ fn check_plan(plan: &LogicalPlan) -> Result<()> { _ => {} }; Ok(TreeNodeRecursion::Continue) - })?; - Ok::<(), DataFusionError>(()) - })?; - Ok(TreeNodeRecursion::Continue) - })?; - - Ok(()) + }) + }) + }) + .map(|_| ()) } diff --git a/datafusion/optimizer/src/analyzer/subquery.rs b/datafusion/optimizer/src/analyzer/subquery.rs index 038361c3ee8c..79375e52da1f 100644 --- a/datafusion/optimizer/src/analyzer/subquery.rs +++ b/datafusion/optimizer/src/analyzer/subquery.rs @@ -283,7 +283,7 @@ fn strip_inner_query(inner_plan: &LogicalPlan) -> &LogicalPlan { fn get_correlated_expressions(inner_plan: &LogicalPlan) -> Result> { let mut exprs = vec![]; - inner_plan.apply(&mut |plan| { + inner_plan.apply_with_subqueries(&mut |plan| { if let LogicalPlan::Filter(Filter { predicate, .. }) = plan { let (correlated, _): (Vec<_>, Vec<_>) = split_conjunction(predicate) .into_iter() diff --git a/datafusion/optimizer/src/plan_signature.rs b/datafusion/optimizer/src/plan_signature.rs index 4143d52a053e..a8e323ff429f 100644 --- a/datafusion/optimizer/src/plan_signature.rs +++ b/datafusion/optimizer/src/plan_signature.rs @@ -21,7 +21,7 @@ use std::{ num::NonZeroUsize, }; -use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; +use datafusion_common::tree_node::TreeNodeRecursion; use datafusion_expr::LogicalPlan; /// Non-unique identifier of a [`LogicalPlan`]. @@ -73,7 +73,7 @@ impl LogicalPlanSignature { /// Get total number of [`LogicalPlan`]s in the plan. fn get_node_number(plan: &LogicalPlan) -> NonZeroUsize { let mut node_number = 0; - plan.apply(&mut |_plan| { + plan.apply_with_subqueries(&mut |_plan| { node_number += 1; Ok(TreeNodeRecursion::Continue) }) From fc29c3e67d43e82e4d1d49b44d150ff710ad7004 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BC=A0=E6=9E=97=E4=BC=9F?= Date: Mon, 8 Apr 2024 18:29:32 +0800 Subject: [PATCH 07/15] Remove unnecessary result (#9990) --- datafusion/common/src/dfschema.rs | 29 +++++++++++++---------------- datafusion/expr/src/utils.rs | 2 +- datafusion/sql/src/statement.rs | 2 +- 3 files changed, 15 insertions(+), 18 deletions(-) diff --git a/datafusion/common/src/dfschema.rs b/datafusion/common/src/dfschema.rs index 9f167fd1f6d9..83e53b3cc6ff 100644 --- a/datafusion/common/src/dfschema.rs +++ b/datafusion/common/src/dfschema.rs @@ -319,7 +319,7 @@ impl DFSchema { &self, qualifier: Option<&TableReference>, name: &str, - ) -> Result> { + ) -> Option { let mut matches = self .iter() .enumerate() @@ -345,19 +345,19 @@ impl DFSchema { (None, Some(_)) | (None, None) => f.name() == name, }) .map(|(idx, _)| idx); - Ok(matches.next()) + matches.next() } /// Find the index of the column with the given qualifier and name pub fn index_of_column(&self, col: &Column) -> Result { - self.index_of_column_by_name(col.relation.as_ref(), &col.name)? + self.index_of_column_by_name(col.relation.as_ref(), &col.name) .ok_or_else(|| field_not_found(col.relation.clone(), &col.name, self)) } /// Check if the column is in the current schema - pub fn is_column_from_schema(&self, col: &Column) -> Result { + pub fn is_column_from_schema(&self, col: &Column) -> bool { self.index_of_column_by_name(col.relation.as_ref(), &col.name) - .map(|idx| idx.is_some()) + .is_some() } /// Find the field with the given name @@ -381,7 +381,7 @@ impl DFSchema { ) -> Result<(Option<&TableReference>, &Field)> { if let Some(qualifier) = qualifier { let idx = self - .index_of_column_by_name(Some(qualifier), name)? + .index_of_column_by_name(Some(qualifier), name) .ok_or_else(|| field_not_found(Some(qualifier.clone()), name, self))?; Ok((self.field_qualifiers[idx].as_ref(), self.field(idx))) } else { @@ -519,7 +519,7 @@ impl DFSchema { name: &str, ) -> Result<&Field> { let idx = self - .index_of_column_by_name(Some(qualifier), name)? + .index_of_column_by_name(Some(qualifier), name) .ok_or_else(|| field_not_found(Some(qualifier.clone()), name, self))?; Ok(self.field(idx)) @@ -1190,11 +1190,8 @@ mod tests { .to_string(), expected_help ); - assert!(schema.index_of_column_by_name(None, "y").unwrap().is_none()); - assert!(schema - .index_of_column_by_name(None, "t1.c0") - .unwrap() - .is_none()); + assert!(schema.index_of_column_by_name(None, "y").is_none()); + assert!(schema.index_of_column_by_name(None, "t1.c0").is_none()); Ok(()) } @@ -1284,28 +1281,28 @@ mod tests { { let col = Column::from_qualified_name("t1.c0"); let schema = DFSchema::try_from_qualified_schema("t1", &test_schema_1())?; - assert!(schema.is_column_from_schema(&col)?); + assert!(schema.is_column_from_schema(&col)); } // qualified not exists { let col = Column::from_qualified_name("t1.c2"); let schema = DFSchema::try_from_qualified_schema("t1", &test_schema_1())?; - assert!(!schema.is_column_from_schema(&col)?); + assert!(!schema.is_column_from_schema(&col)); } // unqualified exists { let col = Column::from_name("c0"); let schema = DFSchema::try_from_qualified_schema("t1", &test_schema_1())?; - assert!(schema.is_column_from_schema(&col)?); + assert!(schema.is_column_from_schema(&col)); } // unqualified not exists { let col = Column::from_name("c2"); let schema = DFSchema::try_from_qualified_schema("t1", &test_schema_1())?; - assert!(!schema.is_column_from_schema(&col)?); + assert!(!schema.is_column_from_schema(&col)); } Ok(()) diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index 0d99d0b5028e..8c6b98f17933 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -933,7 +933,7 @@ pub fn check_all_columns_from_schema( schema: DFSchemaRef, ) -> Result { for col in columns.iter() { - let exist = schema.is_column_from_schema(col)?; + let exist = schema.is_column_from_schema(col); if !exist { return Ok(false); } diff --git a/datafusion/sql/src/statement.rs b/datafusion/sql/src/statement.rs index b8c9172621c3..6b89f89aaccf 100644 --- a/datafusion/sql/src/statement.rs +++ b/datafusion/sql/src/statement.rs @@ -1350,7 +1350,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { .enumerate() .map(|(i, c)| { let column_index = table_schema - .index_of_column_by_name(None, &c)? + .index_of_column_by_name(None, &c) .ok_or_else(|| unqualified_field_not_found(&c, &table_schema))?; if value_indices[column_index].is_some() { return schema_err!(SchemaError::DuplicateUnqualifiedField { From 820843ff597161c9cdacd0e79cecf20d05755081 Mon Sep 17 00:00:00 2001 From: Edmondo Porcu Date: Mon, 8 Apr 2024 06:31:10 -0400 Subject: [PATCH 08/15] Removes Bloom filter for Int8/Int16/Uint8/Uint16 (#9969) * Removing broken tests * Simplifying tests / removing support for failed tests * Revert "Simplifying tests / removing support for failed tests" This reverts commit 6e50a8064436943d9f42d313cef2c2b017d196f1. * Fixing tests for real * Apply suggestions from code review Thanks @alamb ! Co-authored-by: Andrew Lamb --------- Co-authored-by: Andrew Lamb --- .../physical_plan/parquet/row_groups.rs | 4 -- .../core/tests/parquet/row_group_pruning.rs | 54 +++++++++---------- 2 files changed, 26 insertions(+), 32 deletions(-) diff --git a/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs b/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs index 8df4925fc566..6600dd07d7fd 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs @@ -232,12 +232,8 @@ impl PruningStatistics for BloomFilterStatistics { ScalarValue::Float32(Some(v)) => sbbf.check(v), ScalarValue::Int64(Some(v)) => sbbf.check(v), ScalarValue::Int32(Some(v)) => sbbf.check(v), - ScalarValue::Int16(Some(v)) => sbbf.check(v), - ScalarValue::Int8(Some(v)) => sbbf.check(v), ScalarValue::UInt64(Some(v)) => sbbf.check(v), ScalarValue::UInt32(Some(v)) => sbbf.check(v), - ScalarValue::UInt16(Some(v)) => sbbf.check(v), - ScalarValue::UInt8(Some(v)) => sbbf.check(v), ScalarValue::Decimal128(Some(v), p, s) => match parquet_type { Type::INT32 => { //https://github.com/apache/parquet-format/blob/eb4b31c1d64a01088d02a2f9aefc6c17c54cc6fc/Encodings.md?plain=1#L35-L42 diff --git a/datafusion/core/tests/parquet/row_group_pruning.rs b/datafusion/core/tests/parquet/row_group_pruning.rs index b7b434d1c3d3..8fc7936552af 100644 --- a/datafusion/core/tests/parquet/row_group_pruning.rs +++ b/datafusion/core/tests/parquet/row_group_pruning.rs @@ -290,7 +290,7 @@ async fn prune_disabled() { // https://github.com/apache/arrow-datafusion/issues/9779 bug so that tests pass // if and only if Bloom filters on Int8 and Int16 columns are still buggy. macro_rules! int_tests { - ($bits:expr, correct_bloom_filters: $correct_bloom_filters:expr) => { + ($bits:expr) => { paste::item! { #[tokio::test] async fn []() { @@ -329,9 +329,9 @@ macro_rules! int_tests { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(1)) .with_pruned_by_stats(Some(3)) - .with_matched_by_bloom_filter(Some(if $correct_bloom_filters { 1 } else { 0 })) - .with_pruned_by_bloom_filter(Some(if $correct_bloom_filters { 0 } else { 1 })) - .with_expected_rows(if $correct_bloom_filters { 1 } else { 0 }) + .with_matched_by_bloom_filter(Some(1)) + .with_pruned_by_bloom_filter(Some(0)) + .with_expected_rows(1) .test_row_group_prune() .await; } @@ -343,9 +343,9 @@ macro_rules! int_tests { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(1)) .with_pruned_by_stats(Some(3)) - .with_matched_by_bloom_filter(Some(if $correct_bloom_filters { 1 } else { 0 })) - .with_pruned_by_bloom_filter(Some(if $correct_bloom_filters { 0 } else { 1 })) - .with_expected_rows(if $correct_bloom_filters { 1 } else { 0 }) + .with_matched_by_bloom_filter(Some(1)) + .with_pruned_by_bloom_filter(Some(0)) + .with_expected_rows(1) .test_row_group_prune() .await; } @@ -404,9 +404,9 @@ macro_rules! int_tests { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(1)) .with_pruned_by_stats(Some(3)) - .with_matched_by_bloom_filter(Some(if $correct_bloom_filters { 1 } else { 0 })) - .with_pruned_by_bloom_filter(Some(if $correct_bloom_filters { 0 } else { 1 })) - .with_expected_rows(if $correct_bloom_filters { 1 } else { 0 }) + .with_matched_by_bloom_filter(Some(1)) + .with_pruned_by_bloom_filter(Some(0)) + .with_expected_rows(1) .test_row_group_prune() .await; } @@ -447,17 +447,16 @@ macro_rules! int_tests { }; } -int_tests!(8, correct_bloom_filters: false); -int_tests!(16, correct_bloom_filters: false); -int_tests!(32, correct_bloom_filters: true); -int_tests!(64, correct_bloom_filters: true); +// int8/int16 are incorrect: https://github.com/apache/arrow-datafusion/issues/9779 +int_tests!(32); +int_tests!(64); // $bits: number of bits of the integer to test (8, 16, 32, 64) // $correct_bloom_filters: if false, replicates the // https://github.com/apache/arrow-datafusion/issues/9779 bug so that tests pass // if and only if Bloom filters on UInt8 and UInt16 columns are still buggy. macro_rules! uint_tests { - ($bits:expr, correct_bloom_filters: $correct_bloom_filters:expr) => { + ($bits:expr) => { paste::item! { #[tokio::test] async fn []() { @@ -482,9 +481,9 @@ macro_rules! uint_tests { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(1)) .with_pruned_by_stats(Some(3)) - .with_matched_by_bloom_filter(Some(if $correct_bloom_filters { 1 } else { 0 })) - .with_pruned_by_bloom_filter(Some(if $correct_bloom_filters { 0 } else { 1 })) - .with_expected_rows(if $correct_bloom_filters { 1 } else { 0 }) + .with_matched_by_bloom_filter(Some(1)) + .with_pruned_by_bloom_filter(Some(0)) + .with_expected_rows(1) .test_row_group_prune() .await; } @@ -496,9 +495,9 @@ macro_rules! uint_tests { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(1)) .with_pruned_by_stats(Some(3)) - .with_matched_by_bloom_filter(Some(if $correct_bloom_filters { 1 } else { 0 })) - .with_pruned_by_bloom_filter(Some(if $correct_bloom_filters { 0 } else { 1 })) - .with_expected_rows(if $correct_bloom_filters { 1 } else { 0 }) + .with_matched_by_bloom_filter(Some(1)) + .with_pruned_by_bloom_filter(Some(0)) + .with_expected_rows(1) .test_row_group_prune() .await; } @@ -542,9 +541,9 @@ macro_rules! uint_tests { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(1)) .with_pruned_by_stats(Some(3)) - .with_matched_by_bloom_filter(Some(if $correct_bloom_filters { 1 } else { 0 })) - .with_pruned_by_bloom_filter(Some(if $correct_bloom_filters { 0 } else { 1 })) - .with_expected_rows(if $correct_bloom_filters { 1 } else { 0 }) + .with_matched_by_bloom_filter(Some(1)) + .with_pruned_by_bloom_filter(Some(0)) + .with_expected_rows(1) .test_row_group_prune() .await; } @@ -585,10 +584,9 @@ macro_rules! uint_tests { }; } -uint_tests!(8, correct_bloom_filters: false); -uint_tests!(16, correct_bloom_filters: false); -uint_tests!(32, correct_bloom_filters: true); -uint_tests!(64, correct_bloom_filters: true); +// uint8/uint16 are incorrect: https://github.com/apache/arrow-datafusion/issues/9779 +uint_tests!(32); +uint_tests!(64); #[tokio::test] async fn prune_int32_eq_large_in_list() { From 86ad8a580863218f9fb123f09e1d058094ec3ef8 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Mon, 8 Apr 2024 09:46:04 -0400 Subject: [PATCH 09/15] Move LogicalPlan tree_node modul (#9995) --- datafusion/expr/src/logical_plan/mod.rs | 1 + datafusion/expr/src/logical_plan/plan.rs | 2 +- .../plan.rs => logical_plan/tree_node.rs} | 0 .../src/{tree_node/expr.rs => tree_node.rs} | 0 datafusion/expr/src/tree_node/mod.rs | 21 ------------------- 5 files changed, 2 insertions(+), 22 deletions(-) rename datafusion/expr/src/{tree_node/plan.rs => logical_plan/tree_node.rs} (100%) rename datafusion/expr/src/{tree_node/expr.rs => tree_node.rs} (100%) delete mode 100644 datafusion/expr/src/tree_node/mod.rs diff --git a/datafusion/expr/src/logical_plan/mod.rs b/datafusion/expr/src/logical_plan/mod.rs index 84781cb2e9ec..a1fe7a6f0a51 100644 --- a/datafusion/expr/src/logical_plan/mod.rs +++ b/datafusion/expr/src/logical_plan/mod.rs @@ -22,6 +22,7 @@ pub mod dml; mod extension; mod plan; mod statement; +mod tree_node; pub use builder::{ build_join_schema, table_scan, union, wrap_projection_for_join_if_necessary, diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 4f55bbfe3f4d..860fd7daafbf 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -56,7 +56,7 @@ use datafusion_common::{ // backwards compatibility use crate::display::PgJsonVisitor; -use crate::tree_node::expr::transform_option_vec; +use crate::tree_node::transform_option_vec; pub use datafusion_common::display::{PlanType, StringifiedPlan, ToStringifiedPlan}; pub use datafusion_common::{JoinConstraint, JoinType}; diff --git a/datafusion/expr/src/tree_node/plan.rs b/datafusion/expr/src/logical_plan/tree_node.rs similarity index 100% rename from datafusion/expr/src/tree_node/plan.rs rename to datafusion/expr/src/logical_plan/tree_node.rs diff --git a/datafusion/expr/src/tree_node/expr.rs b/datafusion/expr/src/tree_node.rs similarity index 100% rename from datafusion/expr/src/tree_node/expr.rs rename to datafusion/expr/src/tree_node.rs diff --git a/datafusion/expr/src/tree_node/mod.rs b/datafusion/expr/src/tree_node/mod.rs deleted file mode 100644 index 3f8bb6d3257e..000000000000 --- a/datafusion/expr/src/tree_node/mod.rs +++ /dev/null @@ -1,21 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Tree node implementation for logical expr and logical plan - -pub mod expr; -pub mod plan; From 8c9e5678228557aff370b137e9029462230df68a Mon Sep 17 00:00:00 2001 From: Kevin Mingtarja <69668484+kevinmingtarja@users.noreply.github.com> Date: Tue, 9 Apr 2024 00:01:26 +0800 Subject: [PATCH 10/15] Optimize performance of substr_index and add tests (#9973) * Optimize performance of substr_index --- .../functions/src/unicode/substrindex.rs | 153 +++++++++++++++--- .../sqllogictest/test_files/functions.slt | 11 +- 2 files changed, 143 insertions(+), 21 deletions(-) diff --git a/datafusion/functions/src/unicode/substrindex.rs b/datafusion/functions/src/unicode/substrindex.rs index d00108a68fc9..da4ff55828e9 100644 --- a/datafusion/functions/src/unicode/substrindex.rs +++ b/datafusion/functions/src/unicode/substrindex.rs @@ -18,7 +18,7 @@ use std::any::Any; use std::sync::Arc; -use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait}; +use arrow::array::{ArrayRef, OffsetSizeTrait, StringBuilder}; use arrow::datatypes::DataType; use datafusion_common::cast::{as_generic_string_array, as_int64_array}; @@ -101,38 +101,151 @@ pub fn substr_index(args: &[ArrayRef]) -> Result { let delimiter_array = as_generic_string_array::(&args[1])?; let count_array = as_int64_array(&args[2])?; - let result = string_array + let mut builder = StringBuilder::new(); + string_array .iter() .zip(delimiter_array.iter()) .zip(count_array.iter()) - .map(|((string, delimiter), n)| match (string, delimiter, n) { + .for_each(|((string, delimiter), n)| match (string, delimiter, n) { (Some(string), Some(delimiter), Some(n)) => { // In MySQL, these cases will return an empty string. if n == 0 || string.is_empty() || delimiter.is_empty() { - return Some(String::new()); + builder.append_value(""); + return; } - let splitted: Box> = if n > 0 { - Box::new(string.split(delimiter)) + let occurrences = usize::try_from(n.unsigned_abs()).unwrap_or(usize::MAX); + let length = if n > 0 { + let splitted = string.split(delimiter); + splitted + .take(occurrences) + .map(|s| s.len() + delimiter.len()) + .sum::() + - delimiter.len() } else { - Box::new(string.rsplit(delimiter)) + let splitted = string.rsplit(delimiter); + splitted + .take(occurrences) + .map(|s| s.len() + delimiter.len()) + .sum::() + - delimiter.len() }; - let occurrences = usize::try_from(n.unsigned_abs()).unwrap_or(usize::MAX); - // The length of the substring covered by substr_index. - let length = splitted - .take(occurrences) // at least 1 element, since n != 0 - .map(|s| s.len() + delimiter.len()) - .sum::() - - delimiter.len(); if n > 0 { - Some(string[..length].to_owned()) + match string.get(..length) { + Some(substring) => builder.append_value(substring), + None => builder.append_null(), + } } else { - Some(string[string.len() - length..].to_owned()) + match string.get(string.len().saturating_sub(length)..) { + Some(substring) => builder.append_value(substring), + None => builder.append_null(), + } } } - _ => None, - }) - .collect::>(); + _ => builder.append_null(), + }); + + Ok(Arc::new(builder.finish()) as ArrayRef) +} - Ok(Arc::new(result) as ArrayRef) +#[cfg(test)] +mod tests { + use arrow::array::{Array, StringArray}; + use arrow::datatypes::DataType::Utf8; + + use datafusion_common::{Result, ScalarValue}; + use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; + + use crate::unicode::substrindex::SubstrIndexFunc; + use crate::utils::test::test_function; + + #[test] + fn test_functions() -> Result<()> { + test_function!( + SubstrIndexFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("www.apache.org")), + ColumnarValue::Scalar(ScalarValue::from(".")), + ColumnarValue::Scalar(ScalarValue::from(1i64)), + ], + Ok(Some("www")), + &str, + Utf8, + StringArray + ); + test_function!( + SubstrIndexFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("www.apache.org")), + ColumnarValue::Scalar(ScalarValue::from(".")), + ColumnarValue::Scalar(ScalarValue::from(2i64)), + ], + Ok(Some("www.apache")), + &str, + Utf8, + StringArray + ); + test_function!( + SubstrIndexFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("www.apache.org")), + ColumnarValue::Scalar(ScalarValue::from(".")), + ColumnarValue::Scalar(ScalarValue::from(-2i64)), + ], + Ok(Some("apache.org")), + &str, + Utf8, + StringArray + ); + test_function!( + SubstrIndexFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("www.apache.org")), + ColumnarValue::Scalar(ScalarValue::from(".")), + ColumnarValue::Scalar(ScalarValue::from(-1i64)), + ], + Ok(Some("org")), + &str, + Utf8, + StringArray + ); + test_function!( + SubstrIndexFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("www.apache.org")), + ColumnarValue::Scalar(ScalarValue::from(".")), + ColumnarValue::Scalar(ScalarValue::from(0i64)), + ], + Ok(Some("")), + &str, + Utf8, + StringArray + ); + test_function!( + SubstrIndexFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("")), + ColumnarValue::Scalar(ScalarValue::from(".")), + ColumnarValue::Scalar(ScalarValue::from(1i64)), + ], + Ok(Some("")), + &str, + Utf8, + StringArray + ); + test_function!( + SubstrIndexFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("www.apache.org")), + ColumnarValue::Scalar(ScalarValue::from("")), + ColumnarValue::Scalar(ScalarValue::from(1i64)), + ], + Ok(Some("")), + &str, + Utf8, + StringArray + ); + + Ok(()) + } } diff --git a/datafusion/sqllogictest/test_files/functions.slt b/datafusion/sqllogictest/test_files/functions.slt index 21433ba16810..38ebedf5654a 100644 --- a/datafusion/sqllogictest/test_files/functions.slt +++ b/datafusion/sqllogictest/test_files/functions.slt @@ -940,7 +940,8 @@ SELECT str, n, substring_index(str, '.', n) AS c FROM (VALUES ROW('arrow.apache.org'), ROW('.'), - ROW('...') + ROW('...'), + ROW(NULL) ) AS strings(str), (VALUES ROW(1), @@ -954,6 +955,14 @@ SELECT str, n, substring_index(str, '.', n) AS c FROM ) AS occurrences(n) ORDER BY str DESC, n; ---- +NULL -100 NULL +NULL -3 NULL +NULL -2 NULL +NULL -1 NULL +NULL 1 NULL +NULL 2 NULL +NULL 3 NULL +NULL 100 NULL arrow.apache.org -100 arrow.apache.org arrow.apache.org -3 arrow.apache.org arrow.apache.org -2 apache.org From bece785174c199f4fde4343a27c2213fae11bfb8 Mon Sep 17 00:00:00 2001 From: Bruce Ritchie Date: Mon, 8 Apr 2024 12:03:00 -0400 Subject: [PATCH 11/15] move Floor, Gcd, Lcm, Pi to datafusion-functions (#9976) * move Floor, Gcd, Lcm, Pi to datafusion-functions --- datafusion/expr/src/built_in_function.rs | 29 +--- datafusion/expr/src/expr_fn.rs | 17 +- datafusion/functions/src/math/gcd.rs | 145 ++++++++++++++++++ datafusion/functions/src/math/lcm.rs | 126 +++++++++++++++ datafusion/functions/src/math/mod.rs | 14 +- datafusion/functions/src/math/pi.rs | 76 +++++++++ .../optimizer/src/analyzer/type_coercion.rs | 30 ++-- .../physical-expr/src/equivalence/ordering.rs | 44 ++++-- .../src/equivalence/projection.rs | 41 +++-- .../src/equivalence/properties.rs | 41 ++--- datafusion/physical-expr/src/functions.rs | 10 +- .../physical-expr/src/math_expressions.rs | 118 -------------- datafusion/physical-expr/src/udf.rs | 55 ++----- datafusion/physical-expr/src/utils/mod.rs | 99 +++++++++++- datafusion/proto/proto/datafusion.proto | 8 +- datafusion/proto/src/generated/pbjson.rs | 12 -- datafusion/proto/src/generated/prost.rs | 16 +- .../proto/src/logical_plan/from_proto.rs | 20 +-- datafusion/proto/src/logical_plan/to_proto.rs | 4 - datafusion/sql/src/expr/function.rs | 20 ++- datafusion/sql/src/expr/mod.rs | 7 +- 21 files changed, 588 insertions(+), 344 deletions(-) create mode 100644 datafusion/functions/src/math/gcd.rs create mode 100644 datafusion/functions/src/math/lcm.rs create mode 100644 datafusion/functions/src/math/pi.rs diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index dc1fc98a5c02..7426ccd938e7 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -45,20 +45,12 @@ pub enum BuiltinScalarFunction { Exp, /// factorial Factorial, - /// floor - Floor, - /// gcd, Greatest common divisor - Gcd, - /// lcm, Least common multiple - Lcm, /// iszero Iszero, /// log, same as log10 Log, /// nanvl Nanvl, - /// pi - Pi, /// power Power, /// round @@ -135,13 +127,9 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Coalesce => Volatility::Immutable, BuiltinScalarFunction::Exp => Volatility::Immutable, BuiltinScalarFunction::Factorial => Volatility::Immutable, - BuiltinScalarFunction::Floor => Volatility::Immutable, - BuiltinScalarFunction::Gcd => Volatility::Immutable, BuiltinScalarFunction::Iszero => Volatility::Immutable, - BuiltinScalarFunction::Lcm => Volatility::Immutable, BuiltinScalarFunction::Log => Volatility::Immutable, BuiltinScalarFunction::Nanvl => Volatility::Immutable, - BuiltinScalarFunction::Pi => Volatility::Immutable, BuiltinScalarFunction::Power => Volatility::Immutable, BuiltinScalarFunction::Round => Volatility::Immutable, BuiltinScalarFunction::Cot => Volatility::Immutable, @@ -183,13 +171,10 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::InitCap => { utf8_to_str_type(&input_expr_types[0], "initcap") } - BuiltinScalarFunction::Pi => Ok(Float64), BuiltinScalarFunction::Random => Ok(Float64), BuiltinScalarFunction::EndsWith => Ok(Boolean), - BuiltinScalarFunction::Factorial - | BuiltinScalarFunction::Gcd - | BuiltinScalarFunction::Lcm => Ok(Int64), + BuiltinScalarFunction::Factorial => Ok(Int64), BuiltinScalarFunction::Power => match &input_expr_types[0] { Int64 => Ok(Int64), @@ -210,7 +195,6 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Ceil | BuiltinScalarFunction::Exp - | BuiltinScalarFunction::Floor | BuiltinScalarFunction::Round | BuiltinScalarFunction::Trunc | BuiltinScalarFunction::Cot => match input_expr_types[0] { @@ -248,7 +232,6 @@ impl BuiltinScalarFunction { ], self.volatility(), ), - BuiltinScalarFunction::Pi => Signature::exact(vec![], self.volatility()), BuiltinScalarFunction::Random => Signature::exact(vec![], self.volatility()), BuiltinScalarFunction::Power => Signature::one_of( vec![Exact(vec![Int64, Int64]), Exact(vec![Float64, Float64])], @@ -289,12 +272,8 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Factorial => { Signature::uniform(1, vec![Int64], self.volatility()) } - BuiltinScalarFunction::Gcd | BuiltinScalarFunction::Lcm => { - Signature::uniform(2, vec![Int64], self.volatility()) - } BuiltinScalarFunction::Ceil | BuiltinScalarFunction::Exp - | BuiltinScalarFunction::Floor | BuiltinScalarFunction::Cot => { // math expressions expect 1 argument of type f64 or f32 // priority is given to f64 because e.g. `sqrt(1i32)` is in IR (real numbers) and thus we @@ -319,10 +298,8 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Ceil | BuiltinScalarFunction::Exp | BuiltinScalarFunction::Factorial - | BuiltinScalarFunction::Floor | BuiltinScalarFunction::Round | BuiltinScalarFunction::Trunc - | BuiltinScalarFunction::Pi ) { Some(vec![Some(true)]) } else if *self == BuiltinScalarFunction::Log { @@ -339,13 +316,9 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Cot => &["cot"], BuiltinScalarFunction::Exp => &["exp"], BuiltinScalarFunction::Factorial => &["factorial"], - BuiltinScalarFunction::Floor => &["floor"], - BuiltinScalarFunction::Gcd => &["gcd"], BuiltinScalarFunction::Iszero => &["iszero"], - BuiltinScalarFunction::Lcm => &["lcm"], BuiltinScalarFunction::Log => &["log"], BuiltinScalarFunction::Nanvl => &["nanvl"], - BuiltinScalarFunction::Pi => &["pi"], BuiltinScalarFunction::Power => &["power", "pow"], BuiltinScalarFunction::Random => &["random"], BuiltinScalarFunction::Round => &["round"], diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index f68685a87f13..6c811ff06418 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -297,11 +297,6 @@ pub fn concat_ws(sep: Expr, values: Vec) -> Expr { )) } -/// Returns an approximate value of π -pub fn pi() -> Expr { - Expr::ScalarFunction(ScalarFunction::new(BuiltinScalarFunction::Pi, vec![])) -} - /// Returns a random value in the range 0.0 <= x < 1.0 pub fn random() -> Expr { Expr::ScalarFunction(ScalarFunction::new(BuiltinScalarFunction::Random, vec![])) @@ -537,12 +532,6 @@ macro_rules! nary_scalar_expr { // math functions scalar_expr!(Cot, cot, num, "cotangent of a number"); scalar_expr!(Factorial, factorial, num, "factorial"); -scalar_expr!( - Floor, - floor, - num, - "nearest integer less than or equal to argument" -); scalar_expr!( Ceil, ceil, @@ -556,8 +545,7 @@ nary_scalar_expr!( "truncate toward zero, with optional precision" ); scalar_expr!(Exp, exp, num, "exponential"); -scalar_expr!(Gcd, gcd, arg_1 arg_2, "greatest common divisor"); -scalar_expr!(Lcm, lcm, arg_1 arg_2, "least common multiple"); + scalar_expr!(Power, power, base exponent, "`base` raised to the power of `exponent`"); scalar_expr!(Log, log, base x, "logarithm of a `x` for a particular `base`"); @@ -974,7 +962,6 @@ mod test { fn scalar_function_definitions() { test_unary_scalar_expr!(Cot, cot); test_unary_scalar_expr!(Factorial, factorial); - test_unary_scalar_expr!(Floor, floor); test_unary_scalar_expr!(Ceil, ceil); test_nary_scalar_expr!(Round, round, input); test_nary_scalar_expr!(Round, round, input, decimal_places); @@ -984,8 +971,6 @@ mod test { test_scalar_expr!(Nanvl, nanvl, x, y); test_scalar_expr!(Iszero, iszero, input); - test_scalar_expr!(Gcd, gcd, arg_1, arg_2); - test_scalar_expr!(Lcm, lcm, arg_1, arg_2); test_scalar_expr!(InitCap, initcap, string); test_scalar_expr!(EndsWith, ends_with, string, characters); } diff --git a/datafusion/functions/src/math/gcd.rs b/datafusion/functions/src/math/gcd.rs new file mode 100644 index 000000000000..41c9e4e23314 --- /dev/null +++ b/datafusion/functions/src/math/gcd.rs @@ -0,0 +1,145 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{ArrayRef, Int64Array}; +use std::any::Any; +use std::mem::swap; +use std::sync::Arc; + +use arrow::datatypes::DataType; +use arrow::datatypes::DataType::Int64; + +use crate::utils::make_scalar_function; +use datafusion_common::{exec_err, DataFusionError, Result}; +use datafusion_expr::ColumnarValue; +use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; + +#[derive(Debug)] +pub struct GcdFunc { + signature: Signature, +} + +impl Default for GcdFunc { + fn default() -> Self { + Self::new() + } +} + +impl GcdFunc { + pub fn new() -> Self { + use DataType::*; + Self { + signature: Signature::uniform(2, vec![Int64], Volatility::Immutable), + } + } +} + +impl ScalarUDFImpl for GcdFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "gcd" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(Int64) + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + make_scalar_function(gcd, vec![])(args) + } +} + +/// Gcd SQL function +fn gcd(args: &[ArrayRef]) -> Result { + match args[0].data_type() { + Int64 => Ok(Arc::new(make_function_inputs2!( + &args[0], + &args[1], + "x", + "y", + Int64Array, + Int64Array, + { compute_gcd } + )) as ArrayRef), + other => exec_err!("Unsupported data type {other:?} for function gcd"), + } +} + +/// Computes greatest common divisor using Binary GCD algorithm. +pub fn compute_gcd(x: i64, y: i64) -> i64 { + let mut a = x.wrapping_abs(); + let mut b = y.wrapping_abs(); + + if a == 0 { + return b; + } + if b == 0 { + return a; + } + + let shift = (a | b).trailing_zeros(); + a >>= shift; + b >>= shift; + a >>= a.trailing_zeros(); + + loop { + b >>= b.trailing_zeros(); + if a > b { + swap(&mut a, &mut b); + } + + b -= a; + + if b == 0 { + return a << shift; + } + } +} + +#[cfg(test)] +mod test { + use std::sync::Arc; + + use arrow::array::{ArrayRef, Int64Array}; + + use crate::math::gcd::gcd; + use datafusion_common::cast::as_int64_array; + + #[test] + fn test_gcd_i64() { + let args: Vec = vec![ + Arc::new(Int64Array::from(vec![0, 3, 25, -16])), // x + Arc::new(Int64Array::from(vec![0, -2, 15, 8])), // y + ]; + + let result = gcd(&args).expect("failed to initialize function gcd"); + let ints = as_int64_array(&result).expect("failed to initialize function gcd"); + + assert_eq!(ints.len(), 4); + assert_eq!(ints.value(0), 0); + assert_eq!(ints.value(1), 1); + assert_eq!(ints.value(2), 5); + assert_eq!(ints.value(3), 8); + } +} diff --git a/datafusion/functions/src/math/lcm.rs b/datafusion/functions/src/math/lcm.rs new file mode 100644 index 000000000000..3674f7371de2 --- /dev/null +++ b/datafusion/functions/src/math/lcm.rs @@ -0,0 +1,126 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::sync::Arc; + +use arrow::array::{ArrayRef, Int64Array}; +use arrow::datatypes::DataType; +use arrow::datatypes::DataType::Int64; + +use datafusion_common::{exec_err, DataFusionError, Result}; +use datafusion_expr::ColumnarValue; +use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; + +use crate::math::gcd::compute_gcd; +use crate::utils::make_scalar_function; + +#[derive(Debug)] +pub struct LcmFunc { + signature: Signature, +} + +impl Default for LcmFunc { + fn default() -> Self { + LcmFunc::new() + } +} + +impl LcmFunc { + pub fn new() -> Self { + use DataType::*; + Self { + signature: Signature::uniform(2, vec![Int64], Volatility::Immutable), + } + } +} + +impl ScalarUDFImpl for LcmFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "lcm" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(Int64) + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + make_scalar_function(lcm, vec![])(args) + } +} + +/// Lcm SQL function +fn lcm(args: &[ArrayRef]) -> Result { + let compute_lcm = |x: i64, y: i64| { + let a = x.wrapping_abs(); + let b = y.wrapping_abs(); + + if a == 0 || b == 0 { + return 0; + } + a / compute_gcd(a, b) * b + }; + + match args[0].data_type() { + Int64 => Ok(Arc::new(make_function_inputs2!( + &args[0], + &args[1], + "x", + "y", + Int64Array, + Int64Array, + { compute_lcm } + )) as ArrayRef), + other => exec_err!("Unsupported data type {other:?} for function lcm"), + } +} + +#[cfg(test)] +mod test { + use std::sync::Arc; + + use arrow::array::{ArrayRef, Int64Array}; + + use datafusion_common::cast::as_int64_array; + + use crate::math::lcm::lcm; + + #[test] + fn test_lcm_i64() { + let args: Vec = vec![ + Arc::new(Int64Array::from(vec![0, 3, 25, -16])), // x + Arc::new(Int64Array::from(vec![0, -2, 15, 8])), // y + ]; + + let result = lcm(&args).expect("failed to initialize function lcm"); + let ints = as_int64_array(&result).expect("failed to initialize function lcm"); + + assert_eq!(ints.len(), 4); + assert_eq!(ints.value(0), 0); + assert_eq!(ints.value(1), 6); + assert_eq!(ints.value(2), 75); + assert_eq!(ints.value(3), 16); + } +} diff --git a/datafusion/functions/src/math/mod.rs b/datafusion/functions/src/math/mod.rs index f241c8b3250b..3a1f7cc13bb7 100644 --- a/datafusion/functions/src/math/mod.rs +++ b/datafusion/functions/src/math/mod.rs @@ -18,11 +18,17 @@ //! "math" DataFusion functions pub mod abs; +pub mod gcd; +pub mod lcm; pub mod nans; +pub mod pi; // Create UDFs make_udf_function!(nans::IsNanFunc, ISNAN, isnan); make_udf_function!(abs::AbsFunc, ABS, abs); +make_udf_function!(gcd::GcdFunc, GCD, gcd); +make_udf_function!(lcm::LcmFunc, LCM, lcm); +make_udf_function!(pi::PiFunc, PI, pi); make_math_unary_udf!(Log2Func, LOG2, log2, log2, Some(vec![Some(true)])); make_math_unary_udf!(Log10Func, LOG10, log10, log10, Some(vec![Some(true)])); @@ -50,6 +56,8 @@ make_math_unary_udf!(CosFunc, COS, cos, cos, None); make_math_unary_udf!(CoshFunc, COSH, cosh, cosh, None); make_math_unary_udf!(DegreesFunc, DEGREES, degrees, to_degrees, None); +make_math_unary_udf!(FloorFunc, FLOOR, floor, floor, Some(vec![Some(true)])); + // Export the functions out of this package, both as expr_fn as well as a list of functions export_functions!( ( @@ -86,5 +94,9 @@ export_functions!( (cbrt, num, "cube root of a number"), (cos, num, "cosine"), (cosh, num, "hyperbolic cosine"), - (degrees, num, "converts radians to degrees") + (degrees, num, "converts radians to degrees"), + (gcd, x y, "greatest common divisor"), + (lcm, x y, "least common multiple"), + (floor, num, "nearest integer less than or equal to argument"), + (pi, , "Returns an approximate value of π") ); diff --git a/datafusion/functions/src/math/pi.rs b/datafusion/functions/src/math/pi.rs new file mode 100644 index 000000000000..0801e797511b --- /dev/null +++ b/datafusion/functions/src/math/pi.rs @@ -0,0 +1,76 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::sync::Arc; + +use arrow::array::Float64Array; +use arrow::datatypes::DataType; +use arrow::datatypes::DataType::Float64; + +use datafusion_common::{exec_err, Result}; +use datafusion_expr::{ColumnarValue, FuncMonotonicity, Volatility}; +use datafusion_expr::{ScalarUDFImpl, Signature}; + +#[derive(Debug)] +pub struct PiFunc { + signature: Signature, +} + +impl Default for PiFunc { + fn default() -> Self { + PiFunc::new() + } +} + +impl PiFunc { + pub fn new() -> Self { + Self { + signature: Signature::exact(vec![], Volatility::Immutable), + } + } +} + +impl ScalarUDFImpl for PiFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "pi" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(Float64) + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + if !matches!(&args[0], ColumnarValue::Array(_)) { + return exec_err!("Expect pi function to take no param"); + } + let array = Float64Array::from_value(std::f64::consts::PI, 1); + Ok(ColumnarValue::Array(Arc::new(array))) + } + + fn monotonicity(&self) -> Result> { + Ok(Some(vec![Some(true)])) + } +} diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index 04de243fba07..1ea8b9534e80 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -19,9 +19,8 @@ use std::sync::Arc; -use crate::analyzer::AnalyzerRule; - use arrow::datatypes::{DataType, IntervalUnit}; + use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::{Transformed, TreeNodeRewriter}; use datafusion_common::{ @@ -51,6 +50,8 @@ use datafusion_expr::{ WindowFrameUnits, }; +use crate::analyzer::AnalyzerRule; + #[derive(Default)] pub struct TypeCoercion {} @@ -758,25 +759,25 @@ mod test { use std::any::Any; use std::sync::{Arc, OnceLock}; - use crate::analyzer::type_coercion::{ - coerce_case_expression, TypeCoercion, TypeCoercionRewriter, - }; - use crate::test::assert_analyzed_plan_eq; - use arrow::datatypes::{DataType, Field, TimeUnit}; + use datafusion_common::tree_node::{TransformedResult, TreeNode}; use datafusion_common::{DFSchema, DFSchemaRef, Result, ScalarValue}; use datafusion_expr::expr::{self, InSubquery, Like, ScalarFunction}; use datafusion_expr::logical_plan::{EmptyRelation, Projection}; use datafusion_expr::{ cast, col, concat, concat_ws, create_udaf, is_true, lit, - AccumulatorFactoryFunction, AggregateFunction, AggregateUDF, BinaryExpr, - BuiltinScalarFunction, Case, ColumnarValue, Expr, ExprSchemable, Filter, - LogicalPlan, Operator, ScalarUDF, ScalarUDFImpl, Signature, SimpleAggregateUDF, - Subquery, Volatility, + AccumulatorFactoryFunction, AggregateFunction, AggregateUDF, BinaryExpr, Case, + ColumnarValue, Expr, ExprSchemable, Filter, LogicalPlan, Operator, ScalarUDF, + ScalarUDFImpl, Signature, SimpleAggregateUDF, Subquery, Volatility, }; use datafusion_physical_expr::expressions::AvgAccumulator; + use crate::analyzer::type_coercion::{ + coerce_case_expression, TypeCoercion, TypeCoercionRewriter, + }; + use crate::test::assert_analyzed_plan_eq; + fn empty() -> Arc { Arc::new(LogicalPlan::EmptyRelation(EmptyRelation { produce_one_row: false, @@ -875,14 +876,15 @@ mod test { // test that automatic argument type coercion for scalar functions work let empty = empty(); let lit_expr = lit(10i64); - let fun: BuiltinScalarFunction = BuiltinScalarFunction::Floor; + let fun = ScalarUDF::new_from_impl(TestScalarUDF {}); let scalar_function_expr = - Expr::ScalarFunction(ScalarFunction::new(fun, vec![lit_expr])); + Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(fun), vec![lit_expr])); let plan = LogicalPlan::Projection(Projection::try_new( vec![scalar_function_expr], empty, )?); - let expected = "Projection: floor(CAST(Int64(10) AS Float64))\n EmptyRelation"; + let expected = + "Projection: TestScalarUDF(CAST(Int64(10) AS Float32))\n EmptyRelation"; assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan, expected) } diff --git a/datafusion/physical-expr/src/equivalence/ordering.rs b/datafusion/physical-expr/src/equivalence/ordering.rs index 1364d3a8c028..688cdf798bdd 100644 --- a/datafusion/physical-expr/src/equivalence/ordering.rs +++ b/datafusion/physical-expr/src/equivalence/ordering.rs @@ -15,10 +15,11 @@ // specific language governing permissions and limitations // under the License. -use arrow_schema::SortOptions; use std::hash::Hash; use std::sync::Arc; +use arrow_schema::SortOptions; + use crate::equivalence::add_offset_to_expr; use crate::{LexOrdering, PhysicalExpr, PhysicalSortExpr}; @@ -220,6 +221,16 @@ fn resolve_overlap(orderings: &mut [LexOrdering], idx: usize, pre_idx: usize) -> #[cfg(test)] mod tests { + use std::sync::Arc; + + use arrow::datatypes::{DataType, Field, Schema}; + use arrow_schema::SortOptions; + use itertools::Itertools; + + use datafusion_common::{DFSchema, Result}; + use datafusion_expr::execution_props::ExecutionProps; + use datafusion_expr::{BuiltinScalarFunction, Operator, ScalarUDF}; + use crate::equivalence::tests::{ convert_to_orderings, convert_to_sort_exprs, create_random_schema, create_test_params, generate_table_for_eq_properties, is_table_same_after_sort, @@ -231,14 +242,8 @@ mod tests { use crate::expressions::Column; use crate::expressions::{col, BinaryExpr}; use crate::functions::create_physical_expr; + use crate::utils::tests::TestScalarUDF; use crate::{PhysicalExpr, PhysicalSortExpr}; - use arrow::datatypes::{DataType, Field, Schema}; - use arrow_schema::SortOptions; - use datafusion_common::Result; - use datafusion_expr::execution_props::ExecutionProps; - use datafusion_expr::{BuiltinScalarFunction, Operator}; - use itertools::Itertools; - use std::sync::Arc; #[test] fn test_ordering_satisfy() -> Result<()> { @@ -281,17 +286,20 @@ mod tests { let col_d = &col("d", &test_schema)?; let col_e = &col("e", &test_schema)?; let col_f = &col("f", &test_schema)?; - let floor_a = &create_physical_expr( - &BuiltinScalarFunction::Floor, + let test_fun = ScalarUDF::new_from_impl(TestScalarUDF::new()); + let floor_a = &crate::udf::create_physical_expr( + &test_fun, &[col("a", &test_schema)?], &test_schema, - &ExecutionProps::default(), + &[], + &DFSchema::empty(), )?; - let floor_f = &create_physical_expr( - &BuiltinScalarFunction::Floor, + let floor_f = &crate::udf::create_physical_expr( + &test_fun, &[col("f", &test_schema)?], &test_schema, - &ExecutionProps::default(), + &[], + &DFSchema::empty(), )?; let exp_a = &create_physical_expr( &BuiltinScalarFunction::Exp, @@ -804,11 +812,13 @@ mod tests { let table_data_with_properties = generate_table_for_eq_properties(&eq_properties, N_ELEMENTS, N_DISTINCT)?; - let floor_a = create_physical_expr( - &BuiltinScalarFunction::Floor, + let test_fun = ScalarUDF::new_from_impl(TestScalarUDF::new()); + let floor_a = crate::udf::create_physical_expr( + &test_fun, &[col("a", &test_schema)?], &test_schema, - &ExecutionProps::default(), + &[], + &DFSchema::empty(), )?; let a_plus_b = Arc::new(BinaryExpr::new( col("a", &test_schema)?, diff --git a/datafusion/physical-expr/src/equivalence/projection.rs b/datafusion/physical-expr/src/equivalence/projection.rs index b8231a74c271..5efcf5942c39 100644 --- a/datafusion/physical-expr/src/equivalence/projection.rs +++ b/datafusion/physical-expr/src/equivalence/projection.rs @@ -17,13 +17,14 @@ use std::sync::Arc; -use crate::expressions::Column; -use crate::PhysicalExpr; - use arrow::datatypes::SchemaRef; + use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::Result; +use crate::expressions::Column; +use crate::PhysicalExpr; + /// Stores the mapping between source expressions and target expressions for a /// projection. #[derive(Debug, Clone)] @@ -111,7 +112,14 @@ impl ProjectionMapping { mod tests { use std::sync::Arc; - use super::*; + use arrow::datatypes::{DataType, Field, Schema}; + use arrow_schema::{SortOptions, TimeUnit}; + use itertools::Itertools; + + use datafusion_common::{DFSchema, Result}; + use datafusion_expr::execution_props::ExecutionProps; + use datafusion_expr::{BuiltinScalarFunction, Operator, ScalarUDF}; + use crate::equivalence::tests::{ apply_projection, convert_to_orderings, convert_to_orderings_owned, create_random_schema, generate_table_for_eq_properties, is_table_same_after_sort, @@ -119,16 +127,11 @@ mod tests { }; use crate::equivalence::EquivalenceProperties; use crate::expressions::{col, BinaryExpr}; - use crate::functions::create_physical_expr; + use crate::udf::create_physical_expr; + use crate::utils::tests::TestScalarUDF; use crate::PhysicalSortExpr; - use arrow::datatypes::{DataType, Field, Schema}; - use arrow_schema::{SortOptions, TimeUnit}; - use datafusion_common::Result; - use datafusion_expr::execution_props::ExecutionProps; - use datafusion_expr::{BuiltinScalarFunction, Operator}; - - use itertools::Itertools; + use super::*; #[test] fn project_orderings() -> Result<()> { @@ -646,7 +649,7 @@ mod tests { col_b.clone(), )) as Arc; - let round_c = &create_physical_expr( + let round_c = &crate::functions::create_physical_expr( &BuiltinScalarFunction::Round, &[col_c.clone()], &schema, @@ -973,11 +976,13 @@ mod tests { let table_data_with_properties = generate_table_for_eq_properties(&eq_properties, N_ELEMENTS, N_DISTINCT)?; // Floor(a) + let test_fun = ScalarUDF::new_from_impl(TestScalarUDF::new()); let floor_a = create_physical_expr( - &BuiltinScalarFunction::Floor, + &test_fun, &[col("a", &test_schema)?], &test_schema, - &ExecutionProps::default(), + &[], + &DFSchema::empty(), )?; // a + b let a_plus_b = Arc::new(BinaryExpr::new( @@ -1049,11 +1054,13 @@ mod tests { let table_data_with_properties = generate_table_for_eq_properties(&eq_properties, N_ELEMENTS, N_DISTINCT)?; // Floor(a) + let test_fun = ScalarUDF::new_from_impl(TestScalarUDF::new()); let floor_a = create_physical_expr( - &BuiltinScalarFunction::Floor, + &test_fun, &[col("a", &test_schema)?], &test_schema, - &ExecutionProps::default(), + &[], + &DFSchema::empty(), )?; // a + b let a_plus_b = Arc::new(BinaryExpr::new( diff --git a/datafusion/physical-expr/src/equivalence/properties.rs b/datafusion/physical-expr/src/equivalence/properties.rs index 7ce540b267b2..c14c88d6c69b 100644 --- a/datafusion/physical-expr/src/equivalence/properties.rs +++ b/datafusion/physical-expr/src/equivalence/properties.rs @@ -18,7 +18,13 @@ use std::hash::{Hash, Hasher}; use std::sync::Arc; -use super::ordering::collapse_lex_ordering; +use arrow_schema::{SchemaRef, SortOptions}; +use indexmap::{IndexMap, IndexSet}; +use itertools::Itertools; + +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; +use datafusion_common::{JoinSide, JoinType, Result}; + use crate::equivalence::{ collapse_lex_req, EquivalenceGroup, OrderingEquivalenceClass, ProjectionMapping, }; @@ -30,12 +36,7 @@ use crate::{ PhysicalSortRequirement, }; -use arrow_schema::{SchemaRef, SortOptions}; -use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; -use datafusion_common::{JoinSide, JoinType, Result}; - -use indexmap::{IndexMap, IndexSet}; -use itertools::Itertools; +use super::ordering::collapse_lex_ordering; /// A `EquivalenceProperties` object stores useful information related to a schema. /// Currently, it keeps track of: @@ -1296,7 +1297,13 @@ mod tests { use std::ops::Not; use std::sync::Arc; - use super::*; + use arrow::datatypes::{DataType, Field, Schema}; + use arrow_schema::{Fields, SortOptions, TimeUnit}; + use itertools::Itertools; + + use datafusion_common::{DFSchema, Result}; + use datafusion_expr::{Operator, ScalarUDF}; + use crate::equivalence::add_offset_to_expr; use crate::equivalence::tests::{ convert_to_orderings, convert_to_sort_exprs, convert_to_sort_reqs, @@ -1304,16 +1311,10 @@ mod tests { generate_table_for_eq_properties, is_table_same_after_sort, output_schema, }; use crate::expressions::{col, BinaryExpr, Column}; - use crate::functions::create_physical_expr; + use crate::utils::tests::TestScalarUDF; use crate::PhysicalSortExpr; - use arrow::datatypes::{DataType, Field, Schema}; - use arrow_schema::{Fields, SortOptions, TimeUnit}; - use datafusion_common::Result; - use datafusion_expr::execution_props::ExecutionProps; - use datafusion_expr::{BuiltinScalarFunction, Operator}; - - use itertools::Itertools; + use super::*; #[test] fn project_equivalence_properties_test() -> Result<()> { @@ -1792,11 +1793,13 @@ mod tests { let table_data_with_properties = generate_table_for_eq_properties(&eq_properties, N_ELEMENTS, N_DISTINCT)?; - let floor_a = create_physical_expr( - &BuiltinScalarFunction::Floor, + let test_fun = ScalarUDF::new_from_impl(TestScalarUDF::new()); + let floor_a = crate::udf::create_physical_expr( + &test_fun, &[col("a", &test_schema)?], &test_schema, - &ExecutionProps::default(), + &[], + &DFSchema::empty(), )?; let a_plus_b = Arc::new(BinaryExpr::new( col("a", &test_schema)?, diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs index 770d9184325a..79d69b273d2c 100644 --- a/datafusion/physical-expr/src/functions.rs +++ b/datafusion/physical-expr/src/functions.rs @@ -184,16 +184,9 @@ pub fn create_physical_fun( BuiltinScalarFunction::Factorial => { Arc::new(|args| make_scalar_function_inner(math_expressions::factorial)(args)) } - BuiltinScalarFunction::Floor => Arc::new(math_expressions::floor), - BuiltinScalarFunction::Gcd => { - Arc::new(|args| make_scalar_function_inner(math_expressions::gcd)(args)) - } BuiltinScalarFunction::Iszero => { Arc::new(|args| make_scalar_function_inner(math_expressions::iszero)(args)) } - BuiltinScalarFunction::Lcm => { - Arc::new(|args| make_scalar_function_inner(math_expressions::lcm)(args)) - } BuiltinScalarFunction::Nanvl => { Arc::new(|args| make_scalar_function_inner(math_expressions::nanvl)(args)) } @@ -204,7 +197,6 @@ pub fn create_physical_fun( BuiltinScalarFunction::Trunc => { Arc::new(|args| make_scalar_function_inner(math_expressions::trunc)(args)) } - BuiltinScalarFunction::Pi => Arc::new(math_expressions::pi), BuiltinScalarFunction::Power => { Arc::new(|args| make_scalar_function_inner(math_expressions::power)(args)) } @@ -573,7 +565,7 @@ mod tests { let execution_props = ExecutionProps::new(); let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); - let funs = [BuiltinScalarFunction::Pi, BuiltinScalarFunction::Random]; + let funs = [BuiltinScalarFunction::Random]; for fun in funs.iter() { create_physical_expr_with_type_coercion(fun, &[], &schema, &execution_props)?; diff --git a/datafusion/physical-expr/src/math_expressions.rs b/datafusion/physical-expr/src/math_expressions.rs index f8244ad9525f..384f8d87eb96 100644 --- a/datafusion/physical-expr/src/math_expressions.rs +++ b/datafusion/physical-expr/src/math_expressions.rs @@ -19,7 +19,6 @@ use std::any::type_name; use std::iter; -use std::mem::swap; use std::sync::Arc; use arrow::array::ArrayRef; @@ -161,7 +160,6 @@ math_unary_function!("atan", atan); math_unary_function!("asinh", asinh); math_unary_function!("acosh", acosh); math_unary_function!("atanh", atanh); -math_unary_function!("floor", floor); math_unary_function!("ceil", ceil); math_unary_function!("exp", exp); math_unary_function!("ln", ln); @@ -181,79 +179,6 @@ pub fn factorial(args: &[ArrayRef]) -> Result { } } -/// Computes greatest common divisor using Binary GCD algorithm. -fn compute_gcd(x: i64, y: i64) -> i64 { - let mut a = x.wrapping_abs(); - let mut b = y.wrapping_abs(); - - if a == 0 { - return b; - } - if b == 0 { - return a; - } - - let shift = (a | b).trailing_zeros(); - a >>= shift; - b >>= shift; - a >>= a.trailing_zeros(); - - loop { - b >>= b.trailing_zeros(); - if a > b { - swap(&mut a, &mut b); - } - - b -= a; - - if b == 0 { - return a << shift; - } - } -} - -/// Gcd SQL function -pub fn gcd(args: &[ArrayRef]) -> Result { - match args[0].data_type() { - DataType::Int64 => Ok(Arc::new(make_function_inputs2!( - &args[0], - &args[1], - "x", - "y", - Int64Array, - Int64Array, - { compute_gcd } - )) as ArrayRef), - other => exec_err!("Unsupported data type {other:?} for function gcd"), - } -} - -/// Lcm SQL function -pub fn lcm(args: &[ArrayRef]) -> Result { - let compute_lcm = |x: i64, y: i64| { - let a = x.wrapping_abs(); - let b = y.wrapping_abs(); - - if a == 0 || b == 0 { - return 0; - } - a / compute_gcd(a, b) * b - }; - - match args[0].data_type() { - DataType::Int64 => Ok(Arc::new(make_function_inputs2!( - &args[0], - &args[1], - "x", - "y", - Int64Array, - Int64Array, - { compute_lcm } - )) as ArrayRef), - other => exec_err!("Unsupported data type {other:?} for function lcm"), - } -} - /// Nanvl SQL function pub fn nanvl(args: &[ArrayRef]) -> Result { match args[0].data_type() { @@ -345,15 +270,6 @@ pub fn iszero(args: &[ArrayRef]) -> Result { } } -/// Pi SQL function -pub fn pi(args: &[ColumnarValue]) -> Result { - if !matches!(&args[0], ColumnarValue::Array(_)) { - return exec_err!("Expect pi function to take no param"); - } - let array = Float64Array::from_value(std::f64::consts::PI, 1); - Ok(ColumnarValue::Array(Arc::new(array))) -} - /// Random SQL function pub fn random(args: &[ColumnarValue]) -> Result { let len: usize = match &args[0] { @@ -808,40 +724,6 @@ mod tests { assert_eq!(ints, &expected); } - #[test] - fn test_gcd_i64() { - let args: Vec = vec![ - Arc::new(Int64Array::from(vec![0, 3, 25, -16])), // x - Arc::new(Int64Array::from(vec![0, -2, 15, 8])), // y - ]; - - let result = gcd(&args).expect("failed to initialize function gcd"); - let ints = as_int64_array(&result).expect("failed to initialize function gcd"); - - assert_eq!(ints.len(), 4); - assert_eq!(ints.value(0), 0); - assert_eq!(ints.value(1), 1); - assert_eq!(ints.value(2), 5); - assert_eq!(ints.value(3), 8); - } - - #[test] - fn test_lcm_i64() { - let args: Vec = vec![ - Arc::new(Int64Array::from(vec![0, 3, 25, -16])), // x - Arc::new(Int64Array::from(vec![0, -2, 15, 8])), // y - ]; - - let result = lcm(&args).expect("failed to initialize function lcm"); - let ints = as_int64_array(&result).expect("failed to initialize function lcm"); - - assert_eq!(ints.len(), 4); - assert_eq!(ints.value(0), 0); - assert_eq!(ints.value(1), 6); - assert_eq!(ints.value(2), 75); - assert_eq!(ints.value(3), 16); - } - #[test] fn test_cot_f32() { let args: Vec = diff --git a/datafusion/physical-expr/src/udf.rs b/datafusion/physical-expr/src/udf.rs index 4fc94bfa15ec..368dfdf92f45 100644 --- a/datafusion/physical-expr/src/udf.rs +++ b/datafusion/physical-expr/src/udf.rs @@ -16,14 +16,17 @@ // under the License. //! UDF support -use crate::{PhysicalExpr, ScalarFunctionExpr}; +use std::sync::Arc; + use arrow_schema::Schema; + use datafusion_common::{DFSchema, Result}; pub use datafusion_expr::ScalarUDF; use datafusion_expr::{ type_coercion::functions::data_types, Expr, ScalarFunctionDefinition, }; -use std::sync::Arc; + +use crate::{PhysicalExpr, ScalarFunctionExpr}; /// Create a physical expression of the UDF. /// @@ -60,58 +63,18 @@ pub fn create_physical_expr( #[cfg(test)] mod tests { - use arrow_schema::{DataType, Schema}; + use arrow_schema::Schema; + use datafusion_common::{DFSchema, Result}; - use datafusion_expr::{ - ColumnarValue, FuncMonotonicity, ScalarUDF, ScalarUDFImpl, Signature, Volatility, - }; + use datafusion_expr::ScalarUDF; + use crate::utils::tests::TestScalarUDF; use crate::ScalarFunctionExpr; use super::create_physical_expr; #[test] fn test_functions() -> Result<()> { - #[derive(Debug, Clone)] - struct TestScalarUDF { - signature: Signature, - } - - impl TestScalarUDF { - fn new() -> Self { - let signature = - Signature::exact(vec![DataType::Float64], Volatility::Immutable); - - Self { signature } - } - } - - impl ScalarUDFImpl for TestScalarUDF { - fn as_any(&self) -> &dyn std::any::Any { - self - } - - fn name(&self) -> &str { - "my_fn" - } - - fn signature(&self) -> &Signature { - &self.signature - } - - fn return_type(&self, _arg_types: &[DataType]) -> Result { - Ok(DataType::Float64) - } - - fn invoke(&self, _args: &[ColumnarValue]) -> Result { - unimplemented!("my_fn is not implemented") - } - - fn monotonicity(&self) -> Result> { - Ok(Some(vec![Some(true)])) - } - } - // create and register the udf let udf = ScalarUDF::from(TestScalarUDF::new()); diff --git a/datafusion/physical-expr/src/utils/mod.rs b/datafusion/physical-expr/src/utils/mod.rs index e55bc3d15665..d7bebbff891c 100644 --- a/datafusion/physical-expr/src/utils/mod.rs +++ b/datafusion/physical-expr/src/utils/mod.rs @@ -256,7 +256,9 @@ pub fn merge_vectors( } #[cfg(test)] -mod tests { +pub(crate) mod tests { + use arrow_array::{ArrayRef, Float32Array, Float64Array}; + use std::any::Any; use std::fmt::{Display, Formatter}; use std::sync::Arc; @@ -265,10 +267,103 @@ mod tests { use crate::PhysicalSortExpr; use arrow_schema::{DataType, Field, Schema}; - use datafusion_common::{Result, ScalarValue}; + use datafusion_common::{exec_err, DataFusionError, Result, ScalarValue}; + use datafusion_expr::{ + ColumnarValue, FuncMonotonicity, ScalarUDFImpl, Signature, Volatility, + }; use petgraph::visit::Bfs; + #[derive(Debug, Clone)] + pub struct TestScalarUDF { + signature: Signature, + } + + impl TestScalarUDF { + pub fn new() -> Self { + use DataType::*; + Self { + signature: Signature::uniform( + 1, + vec![Float64, Float32], + Volatility::Immutable, + ), + } + } + } + + impl ScalarUDFImpl for TestScalarUDF { + fn as_any(&self) -> &dyn Any { + self + } + fn name(&self) -> &str { + "test-scalar-udf" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + let arg_type = &arg_types[0]; + + match arg_type { + DataType::Float32 => Ok(DataType::Float32), + _ => Ok(DataType::Float64), + } + } + + fn monotonicity(&self) -> Result> { + Ok(Some(vec![Some(true)])) + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + let args = ColumnarValue::values_to_arrays(args)?; + + let arr: ArrayRef = match args[0].data_type() { + DataType::Float64 => Arc::new({ + let arg = &args[0] + .as_any() + .downcast_ref::() + .ok_or_else(|| { + DataFusionError::Internal(format!( + "could not cast {} to {}", + self.name(), + std::any::type_name::() + )) + })?; + + arg.iter() + .map(|a| a.map(f64::floor)) + .collect::() + }), + DataType::Float32 => Arc::new({ + let arg = &args[0] + .as_any() + .downcast_ref::() + .ok_or_else(|| { + DataFusionError::Internal(format!( + "could not cast {} to {}", + self.name(), + std::any::type_name::() + )) + })?; + + arg.iter() + .map(|a| a.map(f32::floor)) + .collect::() + }), + other => { + return exec_err!( + "Unsupported data type {other:?} for function {}", + self.name() + ); + } + }; + Ok(ColumnarValue::Array(arr)) + } + } + #[derive(Clone)] struct DummyProperty { expr_type: String, diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 7f967657f573..b656bededc07 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -550,7 +550,7 @@ enum ScalarFunction { // 6 was Cos // 7 was Digest Exp = 8; - Floor = 9; + // 9 was Floor // 10 was Ln Log = 11; // 12 was Log10 @@ -621,12 +621,12 @@ enum ScalarFunction { // 77 was Sinh // 78 was Cosh // Tanh = 79 - Pi = 80; + // 80 was Pi // 81 was Degrees // 82 was Radians Factorial = 83; - Lcm = 84; - Gcd = 85; + // 84 was Lcm + // 85 was Gcd // 86 was ArrayAppend // 87 was ArrayConcat // 88 was ArrayDims diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 966d7f7f7487..c13ae045bdb5 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -22794,7 +22794,6 @@ impl serde::Serialize for ScalarFunction { Self::Unknown => "unknown", Self::Ceil => "Ceil", Self::Exp => "Exp", - Self::Floor => "Floor", Self::Log => "Log", Self::Round => "Round", Self::Trunc => "Trunc", @@ -22804,10 +22803,7 @@ impl serde::Serialize for ScalarFunction { Self::Random => "Random", Self::Coalesce => "Coalesce", Self::Power => "Power", - Self::Pi => "Pi", Self::Factorial => "Factorial", - Self::Lcm => "Lcm", - Self::Gcd => "Gcd", Self::Cot => "Cot", Self::Nanvl => "Nanvl", Self::Iszero => "Iszero", @@ -22826,7 +22822,6 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "unknown", "Ceil", "Exp", - "Floor", "Log", "Round", "Trunc", @@ -22836,10 +22831,7 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "Random", "Coalesce", "Power", - "Pi", "Factorial", - "Lcm", - "Gcd", "Cot", "Nanvl", "Iszero", @@ -22887,7 +22879,6 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "unknown" => Ok(ScalarFunction::Unknown), "Ceil" => Ok(ScalarFunction::Ceil), "Exp" => Ok(ScalarFunction::Exp), - "Floor" => Ok(ScalarFunction::Floor), "Log" => Ok(ScalarFunction::Log), "Round" => Ok(ScalarFunction::Round), "Trunc" => Ok(ScalarFunction::Trunc), @@ -22897,10 +22888,7 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "Random" => Ok(ScalarFunction::Random), "Coalesce" => Ok(ScalarFunction::Coalesce), "Power" => Ok(ScalarFunction::Power), - "Pi" => Ok(ScalarFunction::Pi), "Factorial" => Ok(ScalarFunction::Factorial), - "Lcm" => Ok(ScalarFunction::Lcm), - "Gcd" => Ok(ScalarFunction::Gcd), "Cot" => Ok(ScalarFunction::Cot), "Nanvl" => Ok(ScalarFunction::Nanvl), "Iszero" => Ok(ScalarFunction::Iszero), diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index c94aa1f4ed93..092d5c59d081 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -2849,7 +2849,7 @@ pub enum ScalarFunction { /// 6 was Cos /// 7 was Digest Exp = 8, - Floor = 9, + /// 9 was Floor /// 10 was Ln Log = 11, /// 12 was Log10 @@ -2920,12 +2920,12 @@ pub enum ScalarFunction { /// 77 was Sinh /// 78 was Cosh /// Tanh = 79 - Pi = 80, + /// 80 was Pi /// 81 was Degrees /// 82 was Radians Factorial = 83, - Lcm = 84, - Gcd = 85, + /// 84 was Lcm + /// 85 was Gcd /// 86 was ArrayAppend /// 87 was ArrayConcat /// 88 was ArrayDims @@ -2989,7 +2989,6 @@ impl ScalarFunction { ScalarFunction::Unknown => "unknown", ScalarFunction::Ceil => "Ceil", ScalarFunction::Exp => "Exp", - ScalarFunction::Floor => "Floor", ScalarFunction::Log => "Log", ScalarFunction::Round => "Round", ScalarFunction::Trunc => "Trunc", @@ -2999,10 +2998,7 @@ impl ScalarFunction { ScalarFunction::Random => "Random", ScalarFunction::Coalesce => "Coalesce", ScalarFunction::Power => "Power", - ScalarFunction::Pi => "Pi", ScalarFunction::Factorial => "Factorial", - ScalarFunction::Lcm => "Lcm", - ScalarFunction::Gcd => "Gcd", ScalarFunction::Cot => "Cot", ScalarFunction::Nanvl => "Nanvl", ScalarFunction::Iszero => "Iszero", @@ -3015,7 +3011,6 @@ impl ScalarFunction { "unknown" => Some(Self::Unknown), "Ceil" => Some(Self::Ceil), "Exp" => Some(Self::Exp), - "Floor" => Some(Self::Floor), "Log" => Some(Self::Log), "Round" => Some(Self::Round), "Trunc" => Some(Self::Trunc), @@ -3025,10 +3020,7 @@ impl ScalarFunction { "Random" => Some(Self::Random), "Coalesce" => Some(Self::Coalesce), "Power" => Some(Self::Power), - "Pi" => Some(Self::Pi), "Factorial" => Some(Self::Factorial), - "Lcm" => Some(Self::Lcm), - "Gcd" => Some(Self::Gcd), "Cot" => Some(Self::Cot), "Nanvl" => Some(Self::Nanvl), "Iszero" => Some(Self::Iszero), diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 96b3b5942ec3..9c24a3941895 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -39,9 +39,9 @@ use datafusion_expr::window_frame::{check_window_frame, regularize_window_order_ use datafusion_expr::{ ceil, coalesce, concat_expr, concat_ws_expr, cot, ends_with, exp, expr::{self, InList, Sort, WindowFunction}, - factorial, floor, gcd, initcap, iszero, lcm, log, + factorial, initcap, iszero, log, logical_plan::{PlanType, StringifiedPlan}, - nanvl, pi, power, random, round, trunc, AggregateFunction, Between, BinaryExpr, + nanvl, power, random, round, trunc, AggregateFunction, Between, BinaryExpr, BuiltInWindowFunction, BuiltinScalarFunction, Case, Cast, Expr, GetFieldAccess, GetIndexedField, GroupingSet, GroupingSet::GroupingSets, @@ -423,9 +423,6 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction { ScalarFunction::Exp => Self::Exp, ScalarFunction::Log => Self::Log, ScalarFunction::Factorial => Self::Factorial, - ScalarFunction::Gcd => Self::Gcd, - ScalarFunction::Lcm => Self::Lcm, - ScalarFunction::Floor => Self::Floor, ScalarFunction::Ceil => Self::Ceil, ScalarFunction::Round => Self::Round, ScalarFunction::Trunc => Self::Trunc, @@ -435,7 +432,6 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction { ScalarFunction::InitCap => Self::InitCap, ScalarFunction::Random => Self::Random, ScalarFunction::Coalesce => Self::Coalesce, - ScalarFunction::Pi => Self::Pi, ScalarFunction::Power => Self::Power, ScalarFunction::Nanvl => Self::Nanvl, ScalarFunction::Iszero => Self::Iszero, @@ -1301,9 +1297,6 @@ pub fn parse_expr( match scalar_function { ScalarFunction::Unknown => Err(proto_error("Unknown scalar function")), ScalarFunction::Exp => Ok(exp(parse_expr(&args[0], registry, codec)?)), - ScalarFunction::Floor => { - Ok(floor(parse_expr(&args[0], registry, codec)?)) - } ScalarFunction::Factorial => { Ok(factorial(parse_expr(&args[0], registry, codec)?)) } @@ -1313,14 +1306,6 @@ pub fn parse_expr( ScalarFunction::InitCap => { Ok(initcap(parse_expr(&args[0], registry, codec)?)) } - ScalarFunction::Gcd => Ok(gcd( - parse_expr(&args[0], registry, codec)?, - parse_expr(&args[1], registry, codec)?, - )), - ScalarFunction::Lcm => Ok(lcm( - parse_expr(&args[0], registry, codec)?, - parse_expr(&args[1], registry, codec)?, - )), ScalarFunction::Random => Ok(random()), ScalarFunction::Concat => { Ok(concat_expr(parse_exprs(args, registry, codec)?)) @@ -1335,7 +1320,6 @@ pub fn parse_expr( ScalarFunction::Coalesce => { Ok(coalesce(parse_exprs(args, registry, codec)?)) } - ScalarFunction::Pi => Ok(pi()), ScalarFunction::Power => Ok(power( parse_expr(&args[0], registry, codec)?, parse_expr(&args[1], registry, codec)?, diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index a10edb393241..bd964b43d418 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -1410,10 +1410,7 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction { BuiltinScalarFunction::Cot => Self::Cot, BuiltinScalarFunction::Exp => Self::Exp, BuiltinScalarFunction::Factorial => Self::Factorial, - BuiltinScalarFunction::Gcd => Self::Gcd, - BuiltinScalarFunction::Lcm => Self::Lcm, BuiltinScalarFunction::Log => Self::Log, - BuiltinScalarFunction::Floor => Self::Floor, BuiltinScalarFunction::Ceil => Self::Ceil, BuiltinScalarFunction::Round => Self::Round, BuiltinScalarFunction::Trunc => Self::Trunc, @@ -1423,7 +1420,6 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction { BuiltinScalarFunction::InitCap => Self::InitCap, BuiltinScalarFunction::Random => Self::Random, BuiltinScalarFunction::Coalesce => Self::Coalesce, - BuiltinScalarFunction::Pi => Self::Pi, BuiltinScalarFunction::Power => Self::Power, BuiltinScalarFunction::Nanvl => Self::Nanvl, BuiltinScalarFunction::Iszero => Self::Iszero, diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index e97eb1a32b12..4bf0906685ca 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -18,7 +18,8 @@ use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; use arrow_schema::DataType; use datafusion_common::{ - not_impl_err, plan_datafusion_err, plan_err, DFSchema, Dependency, Result, + internal_datafusion_err, not_impl_err, plan_datafusion_err, plan_err, DFSchema, + Dependency, Result, }; use datafusion_expr::window_frame::{check_window_frame, regularize_window_order_by}; use datafusion_expr::{ @@ -264,6 +265,23 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { plan_err!("Invalid function '{name}'.\nDid you mean '{suggested_func_name}'?") } + pub(super) fn sql_fn_name_to_expr( + &self, + expr: SQLExpr, + fn_name: &str, + schema: &DFSchema, + planner_context: &mut PlannerContext, + ) -> Result { + let fun = self + .context_provider + .get_function_meta(fn_name) + .ok_or_else(|| { + internal_datafusion_err!("Unable to find expected '{fn_name}' function") + })?; + let args = vec![self.sql_expr_to_logical_expr(expr, schema, planner_context)?]; + Ok(Expr::ScalarFunction(ScalarFunction::new_udf(fun, args))) + } + pub(super) fn sql_named_function_to_expr( &self, expr: SQLExpr, diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs index c2f72720afcb..7763fa2d8dab 100644 --- a/datafusion/sql/src/expr/mod.rs +++ b/datafusion/sql/src/expr/mod.rs @@ -518,12 +518,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { SQLExpr::Floor { expr, field: _field, - } => self.sql_named_function_to_expr( - *expr, - BuiltinScalarFunction::Floor, - schema, - planner_context, - ), + } => self.sql_fn_name_to_expr(*expr, "floor", schema, planner_context), SQLExpr::Ceil { expr, field: _field, From 1c4c00230afe3058b45ed2df812daa16f7276ce5 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Mon, 8 Apr 2024 14:41:03 -0400 Subject: [PATCH 12/15] Minor: Improve documentation on `LogicalPlan::apply*` and `LogicalPlan::map*` (#9996) --- datafusion/expr/src/logical_plan/plan.rs | 26 +++++++++++++++++++----- 1 file changed, 21 insertions(+), 5 deletions(-) diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 860fd7daafbf..ca8d718ec090 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -305,9 +305,9 @@ impl LogicalPlan { err } - /// Calls `f` on all expressions (non-recursively) in the current - /// logical plan node. This does not include expressions in any - /// children. + /// Calls `f` on all expressions in the current `LogicalPlan` node. + /// + /// Note this does not include expressions in child `LogicalPlan` nodes. pub fn apply_expressions Result>( &self, mut f: F, @@ -393,6 +393,11 @@ impl LogicalPlan { } } + /// Rewrites all expressions in the current `LogicalPlan` node using `f`. + /// + /// Returns the current node. + /// + /// Note this does not include expressions in child `LogicalPlan` nodes. pub fn map_expressions Result>>( self, mut f: F, @@ -608,8 +613,9 @@ impl LogicalPlan { }) } - /// returns all inputs of this `LogicalPlan` node. Does not - /// include inputs to inputs, or subqueries. + /// Returns all inputs / children of this `LogicalPlan` node. + /// + /// Note does not include inputs to inputs, or subqueries. pub fn inputs(&self) -> Vec<&LogicalPlan> { match self { LogicalPlan::Projection(Projection { input, .. }) => vec![input], @@ -1370,6 +1376,10 @@ impl LogicalPlan { ) } + /// Calls `f` recursively on all children of the `LogicalPlan` node. + /// + /// Unlike [`Self::apply`], this method *does* includes `LogicalPlan`s that + /// are referenced in `Expr`s pub fn apply_with_subqueries Result>( &self, f: &mut F, @@ -1434,6 +1444,8 @@ impl LogicalPlan { ) } + /// Calls `f` on all subqueries referenced in expressions of the current + /// `LogicalPlan` node. fn apply_subqueries Result>( &self, mut f: F, @@ -1453,6 +1465,10 @@ impl LogicalPlan { }) } + /// Rewrites all subquery `LogicalPlan` in the current `LogicalPlan` node + /// using `f`. + /// + /// Returns the current node. fn map_subqueries Result>>( self, mut f: F, From ad0abe91752b8da327e97b09d8ca883c782f027d Mon Sep 17 00:00:00 2001 From: Kunal Kundu Date: Tue, 9 Apr 2024 00:39:27 +0530 Subject: [PATCH 13/15] move the Log, Power functions to datafusion-functions (#9983) * move the Log, Power functions to datafusion-functions * match type instead of name * fix formatting errors --- datafusion/core/tests/simplification.rs | 136 +++++++++ datafusion/expr/src/built_in_function.rs | 34 --- datafusion/expr/src/expr_fn.rs | 3 - datafusion/functions/src/macros.rs | 13 + datafusion/functions/src/math/log.rs | 259 ++++++++++++++++++ datafusion/functions/src/math/mod.rs | 6 + datafusion/functions/src/math/power.rs | 218 +++++++++++++++ .../simplify_expressions/expr_simplifier.rs | 74 ----- .../simplify_expressions/simplify_exprs.rs | 38 +-- .../src/simplify_expressions/utils.rs | 74 +---- datafusion/physical-expr/src/functions.rs | 6 - .../physical-expr/src/math_expressions.rs | 153 +---------- datafusion/proto/proto/datafusion.proto | 4 +- datafusion/proto/src/generated/pbjson.rs | 6 - datafusion/proto/src/generated/prost.rs | 8 +- .../proto/src/logical_plan/from_proto.rs | 14 +- datafusion/proto/src/logical_plan/to_proto.rs | 2 - 17 files changed, 641 insertions(+), 407 deletions(-) create mode 100644 datafusion/functions/src/math/log.rs create mode 100644 datafusion/functions/src/math/power.rs diff --git a/datafusion/core/tests/simplification.rs b/datafusion/core/tests/simplification.rs index 25f994d320c1..5a2f040c09d8 100644 --- a/datafusion/core/tests/simplification.rs +++ b/datafusion/core/tests/simplification.rs @@ -25,6 +25,7 @@ use datafusion_common::cast::as_int32_array; use datafusion_common::ScalarValue; use datafusion_common::{DFSchemaRef, ToDFSchema}; use datafusion_expr::expr::ScalarFunction; +use datafusion_expr::logical_plan::builder::table_scan_with_filters; use datafusion_expr::simplify::SimplifyInfo; use datafusion_expr::{ expr, table_scan, BuiltinScalarFunction, Cast, ColumnarValue, Expr, ExprSchemable, @@ -294,6 +295,45 @@ fn select_date_plus_interval() -> Result<()> { Ok(()) } +#[test] +fn simplify_project_scalar_fn() -> Result<()> { + // Issue https://github.com/apache/arrow-datafusion/issues/5996 + let schema = Schema::new(vec![Field::new("f", DataType::Float64, false)]); + let plan = table_scan(Some("test"), &schema, None)? + .project(vec![power(col("f"), lit(1.0))])? + .build()?; + + // before simplify: power(t.f, 1.0) + // after simplify: t.f as "power(t.f, 1.0)" + let expected = "Projection: test.f AS power(test.f,Float64(1))\ + \n TableScan: test"; + let actual = get_optimized_plan_formatted(&plan, &Utc::now()); + assert_eq!(expected, actual); + Ok(()) +} + +#[test] +fn simplify_scan_predicate() -> Result<()> { + let schema = Schema::new(vec![ + Field::new("f", DataType::Float64, false), + Field::new("g", DataType::Float64, false), + ]); + let plan = table_scan_with_filters( + Some("test"), + &schema, + None, + vec![col("g").eq(power(col("f"), lit(1.0)))], + )? + .build()?; + + // before simplify: t.g = power(t.f, 1.0) + // after simplify: (t.g = t.f) as "t.g = power(t.f, 1.0)" + let expected = "TableScan: test, full_filters=[g = f AS g = power(f,Float64(1))]"; + let actual = get_optimized_plan_formatted(&plan, &Utc::now()); + assert_eq!(expected, actual); + Ok(()) +} + #[test] fn test_const_evaluator() { // true --> true @@ -431,3 +471,99 @@ fn multiple_now() -> Result<()> { assert_eq!(expected, actual); Ok(()) } + +// ------------------------------ +// --- Simplifier tests ----- +// ------------------------------ + +fn expr_test_schema() -> DFSchemaRef { + Schema::new(vec![ + Field::new("c1", DataType::Utf8, true), + Field::new("c2", DataType::Boolean, true), + Field::new("c3", DataType::Int64, true), + Field::new("c4", DataType::UInt32, true), + Field::new("c1_non_null", DataType::Utf8, false), + Field::new("c2_non_null", DataType::Boolean, false), + Field::new("c3_non_null", DataType::Int64, false), + Field::new("c4_non_null", DataType::UInt32, false), + ]) + .to_dfschema_ref() + .unwrap() +} + +fn test_simplify(input_expr: Expr, expected_expr: Expr) { + let info: MyInfo = MyInfo { + schema: expr_test_schema(), + execution_props: ExecutionProps::new(), + }; + let simplifier = ExprSimplifier::new(info); + let simplified_expr = simplifier + .simplify(input_expr.clone()) + .expect("successfully evaluated"); + + assert_eq!( + simplified_expr, expected_expr, + "Mismatch evaluating {input_expr}\n Expected:{expected_expr}\n Got:{simplified_expr}" + ); +} + +#[test] +fn test_simplify_log() { + // Log(c3, 1) ===> 0 + { + let expr = log(col("c3_non_null"), lit(1)); + test_simplify(expr, lit(0i64)); + } + // Log(c3, c3) ===> 1 + { + let expr = log(col("c3_non_null"), col("c3_non_null")); + let expected = lit(1i64); + test_simplify(expr, expected); + } + // Log(c3, Power(c3, c4)) ===> c4 + { + let expr = log( + col("c3_non_null"), + power(col("c3_non_null"), col("c4_non_null")), + ); + let expected = col("c4_non_null"); + test_simplify(expr, expected); + } + // Log(c3, c4) ===> Log(c3, c4) + { + let expr = log(col("c3_non_null"), col("c4_non_null")); + let expected = log(col("c3_non_null"), col("c4_non_null")); + test_simplify(expr, expected); + } +} + +#[test] +fn test_simplify_power() { + // Power(c3, 0) ===> 1 + { + let expr = power(col("c3_non_null"), lit(0)); + let expected = lit(1i64); + test_simplify(expr, expected) + } + // Power(c3, 1) ===> c3 + { + let expr = power(col("c3_non_null"), lit(1)); + let expected = col("c3_non_null"); + test_simplify(expr, expected) + } + // Power(c3, Log(c3, c4)) ===> c4 + { + let expr = power( + col("c3_non_null"), + log(col("c3_non_null"), col("c4_non_null")), + ); + let expected = col("c4_non_null"); + test_simplify(expr, expected) + } + // Power(c3, c4) ===> Power(c3, c4) + { + let expr = power(col("c3_non_null"), col("c4_non_null")); + let expected = power(col("c3_non_null"), col("c4_non_null")); + test_simplify(expr, expected) + } +} diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index 7426ccd938e7..d98d7d0abfe2 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -47,12 +47,8 @@ pub enum BuiltinScalarFunction { Factorial, /// iszero Iszero, - /// log, same as log10 - Log, /// nanvl Nanvl, - /// power - Power, /// round Round, /// trunc @@ -128,9 +124,7 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Exp => Volatility::Immutable, BuiltinScalarFunction::Factorial => Volatility::Immutable, BuiltinScalarFunction::Iszero => Volatility::Immutable, - BuiltinScalarFunction::Log => Volatility::Immutable, BuiltinScalarFunction::Nanvl => Volatility::Immutable, - BuiltinScalarFunction::Power => Volatility::Immutable, BuiltinScalarFunction::Round => Volatility::Immutable, BuiltinScalarFunction::Cot => Volatility::Immutable, BuiltinScalarFunction::Trunc => Volatility::Immutable, @@ -176,16 +170,6 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Factorial => Ok(Int64), - BuiltinScalarFunction::Power => match &input_expr_types[0] { - Int64 => Ok(Int64), - _ => Ok(Float64), - }, - - BuiltinScalarFunction::Log => match &input_expr_types[0] { - Float32 => Ok(Float32), - _ => Ok(Float64), - }, - BuiltinScalarFunction::Nanvl => match &input_expr_types[0] { Float32 => Ok(Float32), _ => Ok(Float64), @@ -233,10 +217,6 @@ impl BuiltinScalarFunction { self.volatility(), ), BuiltinScalarFunction::Random => Signature::exact(vec![], self.volatility()), - BuiltinScalarFunction::Power => Signature::one_of( - vec![Exact(vec![Int64, Int64]), Exact(vec![Float64, Float64])], - self.volatility(), - ), BuiltinScalarFunction::Round => Signature::one_of( vec![ Exact(vec![Float64, Int64]), @@ -255,16 +235,6 @@ impl BuiltinScalarFunction { ], self.volatility(), ), - - BuiltinScalarFunction::Log => Signature::one_of( - vec![ - Exact(vec![Float32]), - Exact(vec![Float64]), - Exact(vec![Float32, Float32]), - Exact(vec![Float64, Float64]), - ], - self.volatility(), - ), BuiltinScalarFunction::Nanvl => Signature::one_of( vec![Exact(vec![Float32, Float32]), Exact(vec![Float64, Float64])], self.volatility(), @@ -302,8 +272,6 @@ impl BuiltinScalarFunction { | BuiltinScalarFunction::Trunc ) { Some(vec![Some(true)]) - } else if *self == BuiltinScalarFunction::Log { - Some(vec![Some(true), Some(false)]) } else { None } @@ -317,9 +285,7 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Exp => &["exp"], BuiltinScalarFunction::Factorial => &["factorial"], BuiltinScalarFunction::Iszero => &["iszero"], - BuiltinScalarFunction::Log => &["log"], BuiltinScalarFunction::Nanvl => &["nanvl"], - BuiltinScalarFunction::Power => &["power", "pow"], BuiltinScalarFunction::Random => &["random"], BuiltinScalarFunction::Round => &["round"], BuiltinScalarFunction::Trunc => &["trunc"], diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 6c811ff06418..b554d87bade1 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -546,9 +546,6 @@ nary_scalar_expr!( ); scalar_expr!(Exp, exp, num, "exponential"); -scalar_expr!(Power, power, base exponent, "`base` raised to the power of `exponent`"); -scalar_expr!(Log, log, base x, "logarithm of a `x` for a particular `base`"); - scalar_expr!(InitCap, initcap, string, "converts the first letter of each word in `string` in uppercase and the remaining characters in lowercase"); scalar_expr!(EndsWith, ends_with, string suffix, "whether the `string` ends with the `suffix`"); nary_scalar_expr!(Coalesce, coalesce, "returns `coalesce(args...)`, which evaluates to the value of the first [Expr] which is not NULL"); diff --git a/datafusion/functions/src/macros.rs b/datafusion/functions/src/macros.rs index c92cb27ef5bb..5ee47bd3e8eb 100644 --- a/datafusion/functions/src/macros.rs +++ b/datafusion/functions/src/macros.rs @@ -357,6 +357,19 @@ macro_rules! make_math_binary_udf { }; } +macro_rules! make_function_scalar_inputs { + ($ARG: expr, $NAME:expr, $ARRAY_TYPE:ident, $FUNC: block) => {{ + let arg = downcast_arg!($ARG, $NAME, $ARRAY_TYPE); + + arg.iter() + .map(|a| match a { + Some(a) => Some($FUNC(a)), + _ => None, + }) + .collect::<$ARRAY_TYPE>() + }}; +} + macro_rules! make_function_inputs2 { ($ARG1: expr, $ARG2: expr, $NAME1:expr, $NAME2: expr, $ARRAY_TYPE:ident, $FUNC: block) => {{ let arg1 = downcast_arg!($ARG1, $NAME1, $ARRAY_TYPE); diff --git a/datafusion/functions/src/math/log.rs b/datafusion/functions/src/math/log.rs new file mode 100644 index 000000000000..2131b6aa6705 --- /dev/null +++ b/datafusion/functions/src/math/log.rs @@ -0,0 +1,259 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Math function: `log()`. + +use arrow::datatypes::DataType; +use datafusion_common::{exec_err, DataFusionError, Result, ScalarValue}; +use datafusion_expr::expr::ScalarFunction; +use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; +use datafusion_expr::{ColumnarValue, Expr, FuncMonotonicity, ScalarFunctionDefinition}; + +use arrow::array::{ArrayRef, Float32Array, Float64Array}; +use datafusion_expr::TypeSignature::*; +use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; +use std::any::Any; +use std::sync::Arc; + +use super::power::PowerFunc; + +#[derive(Debug)] +pub struct LogFunc { + signature: Signature, +} + +impl Default for LogFunc { + fn default() -> Self { + Self::new() + } +} + +impl LogFunc { + pub fn new() -> Self { + use DataType::*; + Self { + signature: Signature::one_of( + vec![ + Exact(vec![Float32]), + Exact(vec![Float64]), + Exact(vec![Float32, Float32]), + Exact(vec![Float64, Float64]), + ], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for LogFunc { + fn as_any(&self) -> &dyn Any { + self + } + fn name(&self) -> &str { + "log" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + match &arg_types[0] { + DataType::Float32 => Ok(DataType::Float32), + _ => Ok(DataType::Float64), + } + } + + fn monotonicity(&self) -> Result> { + Ok(Some(vec![Some(true), Some(false)])) + } + + // Support overloaded log(base, x) and log(x) which defaults to log(10, x) + fn invoke(&self, args: &[ColumnarValue]) -> Result { + let args = ColumnarValue::values_to_arrays(args)?; + + let mut base = ColumnarValue::Scalar(ScalarValue::Float32(Some(10.0))); + + let mut x = &args[0]; + if args.len() == 2 { + x = &args[1]; + base = ColumnarValue::Array(args[0].clone()); + } + // note in f64::log params order is different than in sql. e.g in sql log(base, x) == f64::log(x, base) + let arr: ArrayRef = match args[0].data_type() { + DataType::Float64 => match base { + ColumnarValue::Scalar(ScalarValue::Float32(Some(base))) => { + Arc::new(make_function_scalar_inputs!(x, "x", Float64Array, { + |value: f64| f64::log(value, base as f64) + })) + } + ColumnarValue::Array(base) => Arc::new(make_function_inputs2!( + x, + base, + "x", + "base", + Float64Array, + { f64::log } + )), + _ => { + return exec_err!("log function requires a scalar or array for base") + } + }, + + DataType::Float32 => match base { + ColumnarValue::Scalar(ScalarValue::Float32(Some(base))) => { + Arc::new(make_function_scalar_inputs!(x, "x", Float32Array, { + |value: f32| f32::log(value, base) + })) + } + ColumnarValue::Array(base) => Arc::new(make_function_inputs2!( + x, + base, + "x", + "base", + Float32Array, + { f32::log } + )), + _ => { + return exec_err!("log function requires a scalar or array for base") + } + }, + other => { + return exec_err!("Unsupported data type {other:?} for function log") + } + }; + + Ok(ColumnarValue::Array(arr)) + } + + /// Simplify the `log` function by the relevant rules: + /// 1. Log(a, 1) ===> 0 + /// 2. Log(a, Power(a, b)) ===> b + /// 3. Log(a, a) ===> 1 + fn simplify( + &self, + args: Vec, + info: &dyn SimplifyInfo, + ) -> Result { + let mut number = &args[0]; + let mut base = + &Expr::Literal(ScalarValue::new_ten(&info.get_data_type(number)?)?); + if args.len() == 2 { + base = &args[0]; + number = &args[1]; + } + + match number { + Expr::Literal(value) + if value == &ScalarValue::new_one(&info.get_data_type(number)?)? => + { + Ok(ExprSimplifyResult::Simplified(Expr::Literal( + ScalarValue::new_zero(&info.get_data_type(base)?)?, + ))) + } + Expr::ScalarFunction(ScalarFunction { + func_def: ScalarFunctionDefinition::UDF(fun), + args, + }) if base == &args[0] + && fun + .as_ref() + .inner() + .as_any() + .downcast_ref::() + .is_some() => + { + Ok(ExprSimplifyResult::Simplified(args[1].clone())) + } + _ => { + if number == base { + Ok(ExprSimplifyResult::Simplified(Expr::Literal( + ScalarValue::new_one(&info.get_data_type(number)?)?, + ))) + } else { + Ok(ExprSimplifyResult::Original(args)) + } + } + } + } +} + +#[cfg(test)] +mod tests { + use datafusion_common::cast::{as_float32_array, as_float64_array}; + + use super::*; + + #[test] + fn test_log_f64() { + let args = [ + ColumnarValue::Array(Arc::new(Float64Array::from(vec![2.0, 2.0, 3.0, 5.0]))), // base + ColumnarValue::Array(Arc::new(Float64Array::from(vec![ + 8.0, 4.0, 81.0, 625.0, + ]))), // num + ]; + + let result = LogFunc::new() + .invoke(&args) + .expect("failed to initialize function log"); + + match result { + ColumnarValue::Array(arr) => { + let floats = as_float64_array(&arr) + .expect("failed to convert result to a Float64Array"); + + assert_eq!(floats.len(), 4); + assert_eq!(floats.value(0), 3.0); + assert_eq!(floats.value(1), 2.0); + assert_eq!(floats.value(2), 4.0); + assert_eq!(floats.value(3), 4.0); + } + ColumnarValue::Scalar(_) => { + panic!("Expected an array value") + } + } + } + + #[test] + fn test_log_f32() { + let args = [ + ColumnarValue::Array(Arc::new(Float32Array::from(vec![2.0, 2.0, 3.0, 5.0]))), // base + ColumnarValue::Array(Arc::new(Float32Array::from(vec![ + 8.0, 4.0, 81.0, 625.0, + ]))), // num + ]; + + let result = LogFunc::new() + .invoke(&args) + .expect("failed to initialize function log"); + + match result { + ColumnarValue::Array(arr) => { + let floats = as_float32_array(&arr) + .expect("failed to convert result to a Float32Array"); + + assert_eq!(floats.len(), 4); + assert_eq!(floats.value(0), 3.0); + assert_eq!(floats.value(1), 2.0); + assert_eq!(floats.value(2), 4.0); + assert_eq!(floats.value(3), 4.0); + } + ColumnarValue::Scalar(_) => { + panic!("Expected an array value") + } + } + } +} diff --git a/datafusion/functions/src/math/mod.rs b/datafusion/functions/src/math/mod.rs index 3a1f7cc13bb7..2655edfe76dc 100644 --- a/datafusion/functions/src/math/mod.rs +++ b/datafusion/functions/src/math/mod.rs @@ -20,12 +20,16 @@ pub mod abs; pub mod gcd; pub mod lcm; +pub mod log; pub mod nans; pub mod pi; +pub mod power; // Create UDFs make_udf_function!(nans::IsNanFunc, ISNAN, isnan); make_udf_function!(abs::AbsFunc, ABS, abs); +make_udf_function!(log::LogFunc, LOG, log); +make_udf_function!(power::PowerFunc, POWER, power); make_udf_function!(gcd::GcdFunc, GCD, gcd); make_udf_function!(lcm::LcmFunc, LCM, lcm); make_udf_function!(pi::PiFunc, PI, pi); @@ -66,6 +70,8 @@ export_functions!( "returns true if a given number is +NaN or -NaN otherwise returns false" ), (abs, num, "returns the absolute value of a given number"), + (power, base exponent, "`base` raised to the power of `exponent`"), + (log, base num, "logarithm of a number for a particular `base`"), (log2, num, "base 2 logarithm of a number"), (log10, num, "base 10 logarithm of a number"), (ln, num, "natural logarithm (base e) of a number"), diff --git a/datafusion/functions/src/math/power.rs b/datafusion/functions/src/math/power.rs new file mode 100644 index 000000000000..8e3b2cf02405 --- /dev/null +++ b/datafusion/functions/src/math/power.rs @@ -0,0 +1,218 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Math function: `power()`. + +use arrow::datatypes::DataType; +use datafusion_common::{exec_err, DataFusionError, Result, ScalarValue}; +use datafusion_expr::expr::ScalarFunction; +use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; +use datafusion_expr::{ColumnarValue, Expr, ScalarFunctionDefinition}; + +use arrow::array::{ArrayRef, Float64Array, Int64Array}; +use datafusion_expr::TypeSignature::*; +use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; +use std::any::Any; +use std::sync::Arc; + +use super::log::LogFunc; + +#[derive(Debug)] +pub struct PowerFunc { + signature: Signature, + aliases: Vec, +} + +impl Default for PowerFunc { + fn default() -> Self { + Self::new() + } +} + +impl PowerFunc { + pub fn new() -> Self { + use DataType::*; + Self { + signature: Signature::one_of( + vec![Exact(vec![Int64, Int64]), Exact(vec![Float64, Float64])], + Volatility::Immutable, + ), + aliases: vec![String::from("pow")], + } + } +} + +impl ScalarUDFImpl for PowerFunc { + fn as_any(&self) -> &dyn Any { + self + } + fn name(&self) -> &str { + "power" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + match arg_types[0] { + DataType::Int64 => Ok(DataType::Int64), + _ => Ok(DataType::Float64), + } + } + + fn aliases(&self) -> &[String] { + &self.aliases + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + let args = ColumnarValue::values_to_arrays(args)?; + + let arr: ArrayRef = match args[0].data_type() { + DataType::Float64 => Arc::new(make_function_inputs2!( + &args[0], + &args[1], + "base", + "exponent", + Float64Array, + { f64::powf } + )), + + DataType::Int64 => Arc::new(make_function_inputs2!( + &args[0], + &args[1], + "base", + "exponent", + Int64Array, + { i64::pow } + )), + + other => { + return exec_err!( + "Unsupported data type {other:?} for function {}", + self.name() + ) + } + }; + + Ok(ColumnarValue::Array(arr)) + } + + /// Simplify the `power` function by the relevant rules: + /// 1. Power(a, 0) ===> 0 + /// 2. Power(a, 1) ===> a + /// 3. Power(a, Log(a, b)) ===> b + fn simplify( + &self, + args: Vec, + info: &dyn SimplifyInfo, + ) -> Result { + let base = &args[0]; + let exponent = &args[1]; + + match exponent { + Expr::Literal(value) + if value == &ScalarValue::new_zero(&info.get_data_type(exponent)?)? => + { + Ok(ExprSimplifyResult::Simplified(Expr::Literal( + ScalarValue::new_one(&info.get_data_type(base)?)?, + ))) + } + Expr::Literal(value) + if value == &ScalarValue::new_one(&info.get_data_type(exponent)?)? => + { + Ok(ExprSimplifyResult::Simplified(base.clone())) + } + Expr::ScalarFunction(ScalarFunction { + func_def: ScalarFunctionDefinition::UDF(fun), + args, + }) if base == &args[0] + && fun + .as_ref() + .inner() + .as_any() + .downcast_ref::() + .is_some() => + { + Ok(ExprSimplifyResult::Simplified(args[1].clone())) + } + _ => Ok(ExprSimplifyResult::Original(args)), + } + } +} + +#[cfg(test)] +mod tests { + use datafusion_common::cast::{as_float64_array, as_int64_array}; + + use super::*; + + #[test] + fn test_power_f64() { + let args = [ + ColumnarValue::Array(Arc::new(Float64Array::from(vec![2.0, 2.0, 3.0, 5.0]))), // base + ColumnarValue::Array(Arc::new(Float64Array::from(vec![3.0, 2.0, 4.0, 4.0]))), // exponent + ]; + + let result = PowerFunc::new() + .invoke(&args) + .expect("failed to initialize function power"); + + match result { + ColumnarValue::Array(arr) => { + let floats = as_float64_array(&arr) + .expect("failed to convert result to a Float64Array"); + assert_eq!(floats.len(), 4); + assert_eq!(floats.value(0), 8.0); + assert_eq!(floats.value(1), 4.0); + assert_eq!(floats.value(2), 81.0); + assert_eq!(floats.value(3), 625.0); + } + ColumnarValue::Scalar(_) => { + panic!("Expected an array value") + } + } + } + + #[test] + fn test_power_i64() { + let args = [ + ColumnarValue::Array(Arc::new(Int64Array::from(vec![2, 2, 3, 5]))), // base + ColumnarValue::Array(Arc::new(Int64Array::from(vec![3, 2, 4, 4]))), // exponent + ]; + + let result = PowerFunc::new() + .invoke(&args) + .expect("failed to initialize function power"); + + match result { + ColumnarValue::Array(arr) => { + let ints = as_int64_array(&arr) + .expect("failed to convert result to a Int64Array"); + + assert_eq!(ints.len(), 4); + assert_eq!(ints.value(0), 8); + assert_eq!(ints.value(1), 4); + assert_eq!(ints.value(2), 81); + assert_eq!(ints.value(3), 625); + } + ColumnarValue::Scalar(_) => { + panic!("Expected an array value") + } + } + } +} diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 8b70f76617dd..3198807b04cf 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -1318,18 +1318,6 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { ExprSimplifyResult::Simplified(expr) => Transformed::yes(expr), }, - // log - Expr::ScalarFunction(ScalarFunction { - func_def: ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::Log), - args, - }) => Transformed::yes(simpl_log(args, info)?), - - // power - Expr::ScalarFunction(ScalarFunction { - func_def: ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::Power), - args, - }) => Transformed::yes(simpl_power(args, info)?), - // concat Expr::ScalarFunction(ScalarFunction { func_def: ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::Concat), @@ -2665,68 +2653,6 @@ mod tests { assert_eq!(simplify(expr_eq), lit(true)); } - #[test] - fn test_simplify_log() { - // Log(c3, 1) ===> 0 - { - let expr = log(col("c3_non_null"), lit(1)); - let expected = lit(0i64); - assert_eq!(simplify(expr), expected); - } - // Log(c3, c3) ===> 1 - { - let expr = log(col("c3_non_null"), col("c3_non_null")); - let expected = lit(1i64); - assert_eq!(simplify(expr), expected); - } - // Log(c3, Power(c3, c4)) ===> c4 - { - let expr = log( - col("c3_non_null"), - power(col("c3_non_null"), col("c4_non_null")), - ); - let expected = col("c4_non_null"); - assert_eq!(simplify(expr), expected); - } - // Log(c3, c4) ===> Log(c3, c4) - { - let expr = log(col("c3_non_null"), col("c4_non_null")); - let expected = log(col("c3_non_null"), col("c4_non_null")); - assert_eq!(simplify(expr), expected); - } - } - - #[test] - fn test_simplify_power() { - // Power(c3, 0) ===> 1 - { - let expr = power(col("c3_non_null"), lit(0)); - let expected = lit(1i64); - assert_eq!(simplify(expr), expected); - } - // Power(c3, 1) ===> c3 - { - let expr = power(col("c3_non_null"), lit(1)); - let expected = col("c3_non_null"); - assert_eq!(simplify(expr), expected); - } - // Power(c3, Log(c3, c4)) ===> c4 - { - let expr = power( - col("c3_non_null"), - log(col("c3_non_null"), col("c4_non_null")), - ); - let expected = col("c4_non_null"); - assert_eq!(simplify(expr), expected); - } - // Power(c3, c4) ===> Power(c3, c4) - { - let expr = power(col("c3_non_null"), col("c4_non_null")); - let expected = power(col("c3_non_null"), col("c4_non_null")); - assert_eq!(simplify(expr), expected); - } - } - #[test] fn test_simplify_concat_ws() { let null = lit(ScalarValue::Utf8(None)); diff --git a/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs b/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs index 8213af76989f..4e06730133d9 100644 --- a/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs +++ b/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs @@ -144,7 +144,7 @@ mod tests { and, binary_expr, col, lit, logical_plan::builder::LogicalPlanBuilder, Expr, ExprSchemable, JoinType, }; - use datafusion_expr::{call_fn, or, BinaryExpr, Cast, Operator}; + use datafusion_expr::{or, BinaryExpr, Cast, Operator}; use crate::test::{assert_fields_eq, test_table_scan_with_name}; use crate::OptimizerContext; @@ -712,42 +712,6 @@ mod tests { assert_optimized_plan_eq(&plan, expected) } - #[test] - fn simplify_project_scalar_fn() -> Result<()> { - // Issue https://github.com/apache/arrow-datafusion/issues/5996 - let schema = Schema::new(vec![Field::new("f", DataType::Float64, false)]); - let plan = table_scan(Some("test"), &schema, None)? - .project(vec![call_fn("power", vec![col("f"), lit(1.0)])?])? - .build()?; - - // before simplify: power(t.f, 1.0) - // after simplify: t.f as "power(t.f, 1.0)" - let expected = "Projection: test.f AS power(test.f,Float64(1))\ - \n TableScan: test"; - - assert_optimized_plan_eq(&plan, expected) - } - - #[test] - fn simplify_scan_predicate() -> Result<()> { - let schema = Schema::new(vec![ - Field::new("f", DataType::Float64, false), - Field::new("g", DataType::Float64, false), - ]); - let plan = table_scan_with_filters( - Some("test"), - &schema, - None, - vec![col("g").eq(call_fn("power", vec![col("f"), lit(1.0)])?)], - )? - .build()?; - - // before simplify: t.g = power(t.f, 1.0) - // after simplify: (t.g = t.f) as "t.g = power(t.f, 1.0)" - let expected = "TableScan: test, full_filters=[g = f AS g = power(f,Float64(1))]"; - assert_optimized_plan_eq(&plan, expected) - } - #[test] fn simplify_is_not_null() -> Result<()> { let table_scan = test_table_scan(); diff --git a/datafusion/optimizer/src/simplify_expressions/utils.rs b/datafusion/optimizer/src/simplify_expressions/utils.rs index 1dd3a6162894..f0ad4738633f 100644 --- a/datafusion/optimizer/src/simplify_expressions/utils.rs +++ b/datafusion/optimizer/src/simplify_expressions/utils.rs @@ -18,11 +18,10 @@ //! Utility functions for expression simplification use datafusion_common::{internal_err, Result, ScalarValue}; -use datafusion_expr::simplify::SimplifyInfo; use datafusion_expr::{ expr::{Between, BinaryExpr, InList, ScalarFunction}, expr_fn::{and, bitwise_and, bitwise_or, concat_ws, or}, - lit, BuiltinScalarFunction, Expr, Like, Operator, ScalarFunctionDefinition, + lit, BuiltinScalarFunction, Expr, Like, Operator, }; pub static POWS_OF_TEN: [i128; 38] = [ @@ -343,77 +342,6 @@ pub fn distribute_negation(expr: Expr) -> Expr { } } -/// Simplify the `log` function by the relevant rules: -/// 1. Log(a, 1) ===> 0 -/// 2. Log(a, a) ===> 1 -/// 3. Log(a, Power(a, b)) ===> b -pub fn simpl_log(current_args: Vec, info: &dyn SimplifyInfo) -> Result { - let mut number = ¤t_args[0]; - let mut base = &Expr::Literal(ScalarValue::new_ten(&info.get_data_type(number)?)?); - if current_args.len() == 2 { - base = ¤t_args[0]; - number = ¤t_args[1]; - } - - match number { - Expr::Literal(value) - if value == &ScalarValue::new_one(&info.get_data_type(number)?)? => - { - Ok(Expr::Literal(ScalarValue::new_zero( - &info.get_data_type(base)?, - )?)) - } - Expr::ScalarFunction(ScalarFunction { - func_def: ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::Power), - args, - }) if base == &args[0] => Ok(args[1].clone()), - _ => { - if number == base { - Ok(Expr::Literal(ScalarValue::new_one( - &info.get_data_type(number)?, - )?)) - } else { - Ok(Expr::ScalarFunction(ScalarFunction::new( - BuiltinScalarFunction::Log, - vec![base.clone(), number.clone()], - ))) - } - } - } -} - -/// Simplify the `power` function by the relevant rules: -/// 1. Power(a, 0) ===> 0 -/// 2. Power(a, 1) ===> a -/// 3. Power(a, Log(a, b)) ===> b -pub fn simpl_power(current_args: Vec, info: &dyn SimplifyInfo) -> Result { - let base = ¤t_args[0]; - let exponent = ¤t_args[1]; - - match exponent { - Expr::Literal(value) - if value == &ScalarValue::new_zero(&info.get_data_type(exponent)?)? => - { - Ok(Expr::Literal(ScalarValue::new_one( - &info.get_data_type(base)?, - )?)) - } - Expr::Literal(value) - if value == &ScalarValue::new_one(&info.get_data_type(exponent)?)? => - { - Ok(base.clone()) - } - Expr::ScalarFunction(ScalarFunction { - func_def: ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::Log), - args, - }) if base == &args[0] => Ok(args[1].clone()), - _ => Ok(Expr::ScalarFunction(ScalarFunction::new( - BuiltinScalarFunction::Power, - current_args, - ))), - } -} - /// Simplify the `concat` function by /// 1. filtering out all `null` literals /// 2. concatenating contiguous literal arguments diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs index 79d69b273d2c..124acdc7ac78 100644 --- a/datafusion/physical-expr/src/functions.rs +++ b/datafusion/physical-expr/src/functions.rs @@ -197,12 +197,6 @@ pub fn create_physical_fun( BuiltinScalarFunction::Trunc => { Arc::new(|args| make_scalar_function_inner(math_expressions::trunc)(args)) } - BuiltinScalarFunction::Power => { - Arc::new(|args| make_scalar_function_inner(math_expressions::power)(args)) - } - BuiltinScalarFunction::Log => { - Arc::new(|args| make_scalar_function_inner(math_expressions::log)(args)) - } BuiltinScalarFunction::Cot => { Arc::new(|args| make_scalar_function_inner(math_expressions::cot)(args)) } diff --git a/datafusion/physical-expr/src/math_expressions.rs b/datafusion/physical-expr/src/math_expressions.rs index 384f8d87eb96..b29230de1f76 100644 --- a/datafusion/physical-expr/src/math_expressions.rs +++ b/datafusion/physical-expr/src/math_expressions.rs @@ -27,7 +27,7 @@ use arrow::datatypes::DataType; use arrow_array::Array; use rand::{thread_rng, Rng}; -use datafusion_common::ScalarValue::{Float32, Int64}; +use datafusion_common::ScalarValue::Int64; use datafusion_common::{exec_err, ScalarValue}; use datafusion_common::{DataFusionError, Result}; use datafusion_expr::ColumnarValue; @@ -374,85 +374,6 @@ pub fn round(args: &[ArrayRef]) -> Result { } } -/// Power SQL function -pub fn power(args: &[ArrayRef]) -> Result { - match args[0].data_type() { - DataType::Float64 => Ok(Arc::new(make_function_inputs2!( - &args[0], - &args[1], - "base", - "exponent", - Float64Array, - { f64::powf } - )) as ArrayRef), - - DataType::Int64 => Ok(Arc::new(make_function_inputs2!( - &args[0], - &args[1], - "base", - "exponent", - Int64Array, - { i64::pow } - )) as ArrayRef), - - other => exec_err!("Unsupported data type {other:?} for function power"), - } -} - -/// Log SQL function -pub fn log(args: &[ArrayRef]) -> Result { - // Support overloaded log(base, x) and log(x) which defaults to log(10, x) - // note in f64::log params order is different than in sql. e.g in sql log(base, x) == f64::log(x, base) - let mut base = ColumnarValue::Scalar(Float32(Some(10.0))); - - let mut x = &args[0]; - if args.len() == 2 { - x = &args[1]; - base = ColumnarValue::Array(args[0].clone()); - } - match args[0].data_type() { - DataType::Float64 => match base { - ColumnarValue::Scalar(ScalarValue::Float32(Some(base))) => { - let base = base as f64; - Ok( - Arc::new(make_function_scalar_inputs!(x, "x", Float64Array, { - |value: f64| f64::log(value, base) - })) as ArrayRef, - ) - } - ColumnarValue::Array(base) => Ok(Arc::new(make_function_inputs2!( - x, - base, - "x", - "base", - Float64Array, - { f64::log } - )) as ArrayRef), - _ => exec_err!("log function requires a scalar or array for base"), - }, - - DataType::Float32 => match base { - ColumnarValue::Scalar(ScalarValue::Float32(Some(base))) => Ok(Arc::new( - make_function_scalar_inputs!(x, "x", Float32Array, { - |value: f32| f32::log(value, base) - }), - ) - as ArrayRef), - ColumnarValue::Array(base) => Ok(Arc::new(make_function_inputs2!( - x, - base, - "x", - "base", - Float32Array, - { f32::log } - )) as ArrayRef), - _ => exec_err!("log function requires a scalar or array for base"), - }, - - other => exec_err!("Unsupported data type {other:?} for function log"), - } -} - ///cot SQL function pub fn cot(args: &[ArrayRef]) -> Result { match args[0].data_type() { @@ -571,78 +492,6 @@ mod tests { assert!(0.0 <= floats.value(0) && floats.value(0) < 1.0); } - #[test] - fn test_power_f64() { - let args: Vec = vec![ - Arc::new(Float64Array::from(vec![2.0, 2.0, 3.0, 5.0])), // base - Arc::new(Float64Array::from(vec![3.0, 2.0, 4.0, 4.0])), // exponent - ]; - - let result = power(&args).expect("failed to initialize function power"); - let floats = - as_float64_array(&result).expect("failed to initialize function power"); - - assert_eq!(floats.len(), 4); - assert_eq!(floats.value(0), 8.0); - assert_eq!(floats.value(1), 4.0); - assert_eq!(floats.value(2), 81.0); - assert_eq!(floats.value(3), 625.0); - } - - #[test] - fn test_power_i64() { - let args: Vec = vec![ - Arc::new(Int64Array::from(vec![2, 2, 3, 5])), // base - Arc::new(Int64Array::from(vec![3, 2, 4, 4])), // exponent - ]; - - let result = power(&args).expect("failed to initialize function power"); - let floats = - as_int64_array(&result).expect("failed to initialize function power"); - - assert_eq!(floats.len(), 4); - assert_eq!(floats.value(0), 8); - assert_eq!(floats.value(1), 4); - assert_eq!(floats.value(2), 81); - assert_eq!(floats.value(3), 625); - } - - #[test] - fn test_log_f64() { - let args: Vec = vec![ - Arc::new(Float64Array::from(vec![2.0, 2.0, 3.0, 5.0])), // base - Arc::new(Float64Array::from(vec![8.0, 4.0, 81.0, 625.0])), // x - ]; - - let result = log(&args).expect("failed to initialize function log"); - let floats = - as_float64_array(&result).expect("failed to initialize function log"); - - assert_eq!(floats.len(), 4); - assert_eq!(floats.value(0), 3.0); - assert_eq!(floats.value(1), 2.0); - assert_eq!(floats.value(2), 4.0); - assert_eq!(floats.value(3), 4.0); - } - - #[test] - fn test_log_f32() { - let args: Vec = vec![ - Arc::new(Float32Array::from(vec![2.0, 2.0, 3.0, 5.0])), // base - Arc::new(Float32Array::from(vec![8.0, 4.0, 81.0, 625.0])), // x - ]; - - let result = log(&args).expect("failed to initialize function log"); - let floats = - as_float32_array(&result).expect("failed to initialize function log"); - - assert_eq!(floats.len(), 4); - assert_eq!(floats.value(0), 3.0); - assert_eq!(floats.value(1), 2.0); - assert_eq!(floats.value(2), 4.0); - assert_eq!(floats.value(3), 4.0); - } - #[test] fn test_round_f32() { let args: Vec = vec![ diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index b656bededc07..0f245673f6cd 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -552,7 +552,7 @@ enum ScalarFunction { Exp = 8; // 9 was Floor // 10 was Ln - Log = 11; + // 11 was Log // 12 was Log10 // 13 was Log2 Round = 14; @@ -605,7 +605,7 @@ enum ScalarFunction { // Trim = 61; // Upper = 62; Coalesce = 63; - Power = 64; + // 64 was Power // 65 was StructFun // 66 was FromUnixtime // 67 Atan2 diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index c13ae045bdb5..0922fccc7917 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -22794,7 +22794,6 @@ impl serde::Serialize for ScalarFunction { Self::Unknown => "unknown", Self::Ceil => "Ceil", Self::Exp => "Exp", - Self::Log => "Log", Self::Round => "Round", Self::Trunc => "Trunc", Self::Concat => "Concat", @@ -22802,7 +22801,6 @@ impl serde::Serialize for ScalarFunction { Self::InitCap => "InitCap", Self::Random => "Random", Self::Coalesce => "Coalesce", - Self::Power => "Power", Self::Factorial => "Factorial", Self::Cot => "Cot", Self::Nanvl => "Nanvl", @@ -22822,7 +22820,6 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "unknown", "Ceil", "Exp", - "Log", "Round", "Trunc", "Concat", @@ -22830,7 +22827,6 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "InitCap", "Random", "Coalesce", - "Power", "Factorial", "Cot", "Nanvl", @@ -22879,7 +22875,6 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "unknown" => Ok(ScalarFunction::Unknown), "Ceil" => Ok(ScalarFunction::Ceil), "Exp" => Ok(ScalarFunction::Exp), - "Log" => Ok(ScalarFunction::Log), "Round" => Ok(ScalarFunction::Round), "Trunc" => Ok(ScalarFunction::Trunc), "Concat" => Ok(ScalarFunction::Concat), @@ -22887,7 +22882,6 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "InitCap" => Ok(ScalarFunction::InitCap), "Random" => Ok(ScalarFunction::Random), "Coalesce" => Ok(ScalarFunction::Coalesce), - "Power" => Ok(ScalarFunction::Power), "Factorial" => Ok(ScalarFunction::Factorial), "Cot" => Ok(ScalarFunction::Cot), "Nanvl" => Ok(ScalarFunction::Nanvl), diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 092d5c59d081..db7614144983 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -2851,7 +2851,7 @@ pub enum ScalarFunction { Exp = 8, /// 9 was Floor /// 10 was Ln - Log = 11, + /// 11 was Log /// 12 was Log10 /// 13 was Log2 Round = 14, @@ -2904,7 +2904,7 @@ pub enum ScalarFunction { /// Trim = 61; /// Upper = 62; Coalesce = 63, - Power = 64, + /// 64 was Power /// 65 was StructFun /// 66 was FromUnixtime /// 67 Atan2 @@ -2989,7 +2989,6 @@ impl ScalarFunction { ScalarFunction::Unknown => "unknown", ScalarFunction::Ceil => "Ceil", ScalarFunction::Exp => "Exp", - ScalarFunction::Log => "Log", ScalarFunction::Round => "Round", ScalarFunction::Trunc => "Trunc", ScalarFunction::Concat => "Concat", @@ -2997,7 +2996,6 @@ impl ScalarFunction { ScalarFunction::InitCap => "InitCap", ScalarFunction::Random => "Random", ScalarFunction::Coalesce => "Coalesce", - ScalarFunction::Power => "Power", ScalarFunction::Factorial => "Factorial", ScalarFunction::Cot => "Cot", ScalarFunction::Nanvl => "Nanvl", @@ -3011,7 +3009,6 @@ impl ScalarFunction { "unknown" => Some(Self::Unknown), "Ceil" => Some(Self::Ceil), "Exp" => Some(Self::Exp), - "Log" => Some(Self::Log), "Round" => Some(Self::Round), "Trunc" => Some(Self::Trunc), "Concat" => Some(Self::Concat), @@ -3019,7 +3016,6 @@ impl ScalarFunction { "InitCap" => Some(Self::InitCap), "Random" => Some(Self::Random), "Coalesce" => Some(Self::Coalesce), - "Power" => Some(Self::Power), "Factorial" => Some(Self::Factorial), "Cot" => Some(Self::Cot), "Nanvl" => Some(Self::Nanvl), diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 9c24a3941895..6a2e89fe00a3 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -39,9 +39,9 @@ use datafusion_expr::window_frame::{check_window_frame, regularize_window_order_ use datafusion_expr::{ ceil, coalesce, concat_expr, concat_ws_expr, cot, ends_with, exp, expr::{self, InList, Sort, WindowFunction}, - factorial, initcap, iszero, log, + factorial, initcap, iszero, logical_plan::{PlanType, StringifiedPlan}, - nanvl, power, random, round, trunc, AggregateFunction, Between, BinaryExpr, + nanvl, random, round, trunc, AggregateFunction, Between, BinaryExpr, BuiltInWindowFunction, BuiltinScalarFunction, Case, Cast, Expr, GetFieldAccess, GetIndexedField, GroupingSet, GroupingSet::GroupingSets, @@ -421,7 +421,6 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction { ScalarFunction::Unknown => todo!(), ScalarFunction::Cot => Self::Cot, ScalarFunction::Exp => Self::Exp, - ScalarFunction::Log => Self::Log, ScalarFunction::Factorial => Self::Factorial, ScalarFunction::Ceil => Self::Ceil, ScalarFunction::Round => Self::Round, @@ -432,7 +431,6 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction { ScalarFunction::InitCap => Self::InitCap, ScalarFunction::Random => Self::Random, ScalarFunction::Coalesce => Self::Coalesce, - ScalarFunction::Power => Self::Power, ScalarFunction::Nanvl => Self::Nanvl, ScalarFunction::Iszero => Self::Iszero, } @@ -1320,14 +1318,6 @@ pub fn parse_expr( ScalarFunction::Coalesce => { Ok(coalesce(parse_exprs(args, registry, codec)?)) } - ScalarFunction::Power => Ok(power( - parse_expr(&args[0], registry, codec)?, - parse_expr(&args[1], registry, codec)?, - )), - ScalarFunction::Log => Ok(log( - parse_expr(&args[0], registry, codec)?, - parse_expr(&args[1], registry, codec)?, - )), ScalarFunction::Cot => Ok(cot(parse_expr(&args[0], registry, codec)?)), ScalarFunction::Nanvl => Ok(nanvl( parse_expr(&args[0], registry, codec)?, diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index bd964b43d418..db9653e32346 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -1410,7 +1410,6 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction { BuiltinScalarFunction::Cot => Self::Cot, BuiltinScalarFunction::Exp => Self::Exp, BuiltinScalarFunction::Factorial => Self::Factorial, - BuiltinScalarFunction::Log => Self::Log, BuiltinScalarFunction::Ceil => Self::Ceil, BuiltinScalarFunction::Round => Self::Round, BuiltinScalarFunction::Trunc => Self::Trunc, @@ -1420,7 +1419,6 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction { BuiltinScalarFunction::InitCap => Self::InitCap, BuiltinScalarFunction::Random => Self::Random, BuiltinScalarFunction::Coalesce => Self::Coalesce, - BuiltinScalarFunction::Power => Self::Power, BuiltinScalarFunction::Nanvl => Self::Nanvl, BuiltinScalarFunction::Iszero => Self::Iszero, }; From 0088c28254ea7ab1fae66b5f1bfcc66e0c9aa7a7 Mon Sep 17 00:00:00 2001 From: Kunal Kundu Date: Tue, 9 Apr 2024 02:01:41 +0530 Subject: [PATCH 14/15] Remove FORMAT <..> backwards compatibility options from COPY (#9985) * Revert "Add test for reading back file created with FORMAT options (#9753)" This reverts commit b50f3aad043da9de613f422f20f7aa916ce55776. * Revert "support format in options of COPY command (#9744)" This reverts commit 40fb1b859be4dd399922c498d49b9b847874af2b. * update docs and example to remove old syntax --- datafusion/sql/src/parser.rs | 4 +- datafusion/sql/src/statement.rs | 12 ++--- datafusion/sql/tests/sql_integration.rs | 12 ----- datafusion/sqllogictest/test_files/copy.slt | 59 --------------------- docs/source/user-guide/sql/dml.md | 3 +- 5 files changed, 7 insertions(+), 83 deletions(-) diff --git a/datafusion/sql/src/parser.rs b/datafusion/sql/src/parser.rs index 67fa1325eea7..5a999ab21d30 100644 --- a/datafusion/sql/src/parser.rs +++ b/datafusion/sql/src/parser.rs @@ -87,11 +87,11 @@ impl fmt::Display for ExplainStatement { /// /// ```sql /// COPY lineitem TO 'lineitem' -/// (format parquet, +/// STORED AS PARQUET ( /// partitions 16, /// row_group_limit_rows 100000, /// row_group_limit_bytes 200000 -/// ) +/// ) /// /// COPY (SELECT l_orderkey from lineitem) to 'lineitem.parquet'; /// ``` diff --git a/datafusion/sql/src/statement.rs b/datafusion/sql/src/statement.rs index 6b89f89aaccf..1bb024733c34 100644 --- a/datafusion/sql/src/statement.rs +++ b/datafusion/sql/src/statement.rs @@ -850,7 +850,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { return plan_err!("Unsupported Value in COPY statement {}", value); } }; - if !(key.contains('.') || key == "format") { + if !(&key.contains('.')) { // If config does not belong to any namespace, assume it is // a format option and apply the format prefix for backwards // compatibility. @@ -866,16 +866,12 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { FileType::from_str(&file_type).map_err(|_| { DataFusionError::Configuration(format!("Unknown FileType {}", file_type)) })? - } else if let Some(format) = options.remove("format") { - // try to infer file format from the "format" key in options - FileType::from_str(&format) - .map_err(|e| DataFusionError::Configuration(format!("{}", e)))? } else { let e = || { DataFusionError::Configuration( - "Format not explicitly set and unable to get file extension! Use STORED AS to define file format." - .to_string(), - ) + "Format not explicitly set and unable to get file extension! Use STORED AS to define file format." + .to_string(), + ) }; // try to infer file format from file extension let extension: &str = &Path::new(&statement.target) diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs index f2f188105faf..e923a15372d0 100644 --- a/datafusion/sql/tests/sql_integration.rs +++ b/datafusion/sql/tests/sql_integration.rs @@ -444,18 +444,6 @@ CopyTo: format=csv output_url=output.csv options: () quick_test(sql, plan); } -#[test] -fn plan_copy_stored_as_priority() { - let sql = "COPY (select * from (values (1))) to 'output/' STORED AS CSV OPTIONS (format json)"; - let plan = r#" -CopyTo: format=csv output_url=output/ options: (format json) - Projection: column1 - Values: (Int64(1)) - "# - .trim(); - quick_test(sql, plan); -} - #[test] fn plan_insert() { let sql = diff --git a/datafusion/sqllogictest/test_files/copy.slt b/datafusion/sqllogictest/test_files/copy.slt index 95b6d29db407..75f1ccb07aac 100644 --- a/datafusion/sqllogictest/test_files/copy.slt +++ b/datafusion/sqllogictest/test_files/copy.slt @@ -514,65 +514,6 @@ OPTIONS ( ); -# Format Options Support with format in OPTIONS -# -# i.e. COPY { table_name | query } TO 'file_name' OPTIONS (format , ...) - -# Ensure that the format is set in the OPTIONS, not extension -query I -COPY (select * from (values (1))) to 'test_files/scratch/copy/foo.dat' -OPTIONS (format parquet); ----- -1 - -statement ok -CREATE EXTERNAL TABLE foo_dat STORED AS PARQUET LOCATION 'test_files/scratch/copy/foo.dat'; - -query I -select * from foo_dat; ----- -1 - -statement ok -DROP TABLE foo_dat; - - -query I -COPY (select * from (values (1))) to 'test_files/scratch/copy' -OPTIONS (format parquet); ----- -1 - -query I -COPY (select * from (values (1))) to 'test_files/scratch/copy/' -OPTIONS (format parquet, compression 'zstd(10)'); ----- -1 - -query I -COPY (select * from (values (1))) to 'test_files/scratch/copy/' -OPTIONS (format json, compression gzip); ----- -1 - -query I -COPY (select * from (values (1))) to 'test_files/scratch/copy/' -OPTIONS ( - format csv, - has_header false, - compression xz, - datetime_format '%FT%H:%M:%S.%9f', - delimiter ';', - null_value 'NULLVAL' -); ----- -1 - -query error DataFusion error: Invalid or Unsupported Configuration: This feature is not implemented: Unknown FileType: NOTVALIDFORMAT -COPY (select * from (values (1))) to 'test_files/scratch/copy/' -OPTIONS (format notvalidformat, compression 'zstd(5)'); - - # Error cases: # Copy from table with options diff --git a/docs/source/user-guide/sql/dml.md b/docs/source/user-guide/sql/dml.md index 666e86b46002..42e0c8054c9b 100644 --- a/docs/source/user-guide/sql/dml.md +++ b/docs/source/user-guide/sql/dml.md @@ -44,8 +44,7 @@ separate hive-style directories. The output format is determined by the first match of the following rules: 1. Value of `STORED AS` -2. Value of the `OPTION (FORMAT ..)` -3. Filename extension (e.g. `foo.parquet` implies `PARQUET` format) +2. Filename extension (e.g. `foo.parquet` implies `PARQUET` format) For a detailed list of valid OPTIONS, see [Write Options](write_options). From 78f8ef16b89ffa38f80e8c7e3f21948602901787 Mon Sep 17 00:00:00 2001 From: Bruce Ritchie Date: Mon, 8 Apr 2024 19:08:13 -0400 Subject: [PATCH 15/15] move Trunc, Cot, Round, iszero functions to datafusion-functions (#10000) * move Floor, Gcd, Lcm, Pi to datafusion-functions * remove floor fn * move Trunc, Cot, Round, iszero functions to datafusion-functions * Make mod iszero public, minor ordering change to keep the alphabetical ordering theme. --------- Co-authored-by: Andrew Lamb --- datafusion/expr/src/built_in_function.rs | 61 +-- datafusion/expr/src/expr_fn.rs | 46 +- datafusion/functions/src/math/cot.rs | 166 +++++++ datafusion/functions/src/math/iszero.rs | 141 ++++++ datafusion/functions/src/math/mod.rs | 300 +++++++++--- datafusion/functions/src/math/round.rs | 252 ++++++++++ datafusion/functions/src/math/trunc.rs | 235 ++++++++++ .../src/equivalence/projection.rs | 11 +- datafusion/physical-expr/src/functions.rs | 12 - .../physical-expr/src/math_expressions.rs | 435 ------------------ datafusion/proto/proto/datafusion.proto | 8 +- datafusion/proto/src/generated/pbjson.rs | 12 - datafusion/proto/src/generated/prost.rs | 16 +- .../proto/src/logical_plan/from_proto.rs | 20 +- datafusion/proto/src/logical_plan/to_proto.rs | 4 - .../tests/cases/roundtrip_physical_plan.rs | 29 +- datafusion/sql/tests/sql_integration.rs | 5 + 17 files changed, 1061 insertions(+), 692 deletions(-) create mode 100644 datafusion/functions/src/math/cot.rs create mode 100644 datafusion/functions/src/math/iszero.rs create mode 100644 datafusion/functions/src/math/round.rs create mode 100644 datafusion/functions/src/math/trunc.rs diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index d98d7d0abfe2..a6795e99d751 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -45,17 +45,8 @@ pub enum BuiltinScalarFunction { Exp, /// factorial Factorial, - /// iszero - Iszero, /// nanvl Nanvl, - /// round - Round, - /// trunc - Trunc, - /// cot - Cot, - // string functions /// concat Concat, @@ -123,11 +114,7 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Coalesce => Volatility::Immutable, BuiltinScalarFunction::Exp => Volatility::Immutable, BuiltinScalarFunction::Factorial => Volatility::Immutable, - BuiltinScalarFunction::Iszero => Volatility::Immutable, BuiltinScalarFunction::Nanvl => Volatility::Immutable, - BuiltinScalarFunction::Round => Volatility::Immutable, - BuiltinScalarFunction::Cot => Volatility::Immutable, - BuiltinScalarFunction::Trunc => Volatility::Immutable, BuiltinScalarFunction::Concat => Volatility::Immutable, BuiltinScalarFunction::ConcatWithSeparator => Volatility::Immutable, BuiltinScalarFunction::EndsWith => Volatility::Immutable, @@ -175,16 +162,12 @@ impl BuiltinScalarFunction { _ => Ok(Float64), }, - BuiltinScalarFunction::Iszero => Ok(Boolean), - - BuiltinScalarFunction::Ceil - | BuiltinScalarFunction::Exp - | BuiltinScalarFunction::Round - | BuiltinScalarFunction::Trunc - | BuiltinScalarFunction::Cot => match input_expr_types[0] { - Float32 => Ok(Float32), - _ => Ok(Float64), - }, + BuiltinScalarFunction::Ceil | BuiltinScalarFunction::Exp => { + match input_expr_types[0] { + Float32 => Ok(Float32), + _ => Ok(Float64), + } + } } } @@ -217,24 +200,6 @@ impl BuiltinScalarFunction { self.volatility(), ), BuiltinScalarFunction::Random => Signature::exact(vec![], self.volatility()), - BuiltinScalarFunction::Round => Signature::one_of( - vec![ - Exact(vec![Float64, Int64]), - Exact(vec![Float32, Int64]), - Exact(vec![Float64]), - Exact(vec![Float32]), - ], - self.volatility(), - ), - BuiltinScalarFunction::Trunc => Signature::one_of( - vec![ - Exact(vec![Float32, Int64]), - Exact(vec![Float64, Int64]), - Exact(vec![Float64]), - Exact(vec![Float32]), - ], - self.volatility(), - ), BuiltinScalarFunction::Nanvl => Signature::one_of( vec![Exact(vec![Float32, Float32]), Exact(vec![Float64, Float64])], self.volatility(), @@ -242,9 +207,7 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Factorial => { Signature::uniform(1, vec![Int64], self.volatility()) } - BuiltinScalarFunction::Ceil - | BuiltinScalarFunction::Exp - | BuiltinScalarFunction::Cot => { + BuiltinScalarFunction::Ceil | BuiltinScalarFunction::Exp => { // math expressions expect 1 argument of type f64 or f32 // priority is given to f64 because e.g. `sqrt(1i32)` is in IR (real numbers) and thus we // return the best approximation for it (in f64). @@ -252,10 +215,6 @@ impl BuiltinScalarFunction { // will be as good as the number of digits in the number Signature::uniform(1, vec![Float64, Float32], self.volatility()) } - BuiltinScalarFunction::Iszero => Signature::one_of( - vec![Exact(vec![Float32]), Exact(vec![Float64])], - self.volatility(), - ), } } @@ -268,8 +227,6 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Ceil | BuiltinScalarFunction::Exp | BuiltinScalarFunction::Factorial - | BuiltinScalarFunction::Round - | BuiltinScalarFunction::Trunc ) { Some(vec![Some(true)]) } else { @@ -281,14 +238,10 @@ impl BuiltinScalarFunction { pub fn aliases(&self) -> &'static [&'static str] { match self { BuiltinScalarFunction::Ceil => &["ceil"], - BuiltinScalarFunction::Cot => &["cot"], BuiltinScalarFunction::Exp => &["exp"], BuiltinScalarFunction::Factorial => &["factorial"], - BuiltinScalarFunction::Iszero => &["iszero"], BuiltinScalarFunction::Nanvl => &["nanvl"], BuiltinScalarFunction::Random => &["random"], - BuiltinScalarFunction::Round => &["round"], - BuiltinScalarFunction::Trunc => &["trunc"], // conditional functions BuiltinScalarFunction::Coalesce => &["coalesce"], diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index b554d87bade1..1e28e27af1e0 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -530,7 +530,6 @@ macro_rules! nary_scalar_expr { // generate methods for creating the supported unary/binary expressions // math functions -scalar_expr!(Cot, cot, num, "cotangent of a number"); scalar_expr!(Factorial, factorial, num, "factorial"); scalar_expr!( Ceil, @@ -538,12 +537,7 @@ scalar_expr!( num, "nearest integer greater than or equal to argument" ); -nary_scalar_expr!(Round, round, "round to nearest integer"); -nary_scalar_expr!( - Trunc, - trunc, - "truncate toward zero, with optional precision" -); + scalar_expr!(Exp, exp, num, "exponential"); scalar_expr!(InitCap, initcap, string, "converts the first letter of each word in `string` in uppercase and the remaining characters in lowercase"); @@ -557,12 +551,6 @@ nary_scalar_expr!( ); nary_scalar_expr!(Concat, concat_expr, "concatenates several strings"); scalar_expr!(Nanvl, nanvl, x y, "returns x if x is not NaN otherwise returns y"); -scalar_expr!( - Iszero, - iszero, - num, - "returns true if a given number is +0.0 or -0.0 otherwise returns false" -); /// Create a CASE WHEN statement with literal WHEN expressions for comparison to the base expression. pub fn case(expr: Expr) -> CaseBuilder { @@ -872,12 +860,6 @@ impl WindowUDFImpl for SimpleWindowUDF { } /// Calls a named built in function -/// ``` -/// use datafusion_expr::{col, lit, call_fn}; -/// -/// // create the expression trunc(x) < 0.2 -/// let expr = call_fn("trunc", vec![col("x")]).unwrap().lt(lit(0.2)); -/// ``` pub fn call_fn(name: impl AsRef, args: Vec) -> Result { match name.as_ref().parse::() { Ok(fun) => Ok(Expr::ScalarFunction(ScalarFunction::new(fun, args))), @@ -935,38 +917,12 @@ mod test { }; } - macro_rules! test_nary_scalar_expr { - ($ENUM:ident, $FUNC:ident, $($arg:ident),*) => { - let expected = [$(stringify!($arg)),*]; - let result = $FUNC( - vec![ - $( - col(stringify!($arg.to_string())) - ),* - ] - ); - if let Expr::ScalarFunction(ScalarFunction { func_def: ScalarFunctionDefinition::BuiltIn(fun), args }) = result { - let name = built_in_function::BuiltinScalarFunction::$ENUM; - assert_eq!(name, fun); - assert_eq!(expected.len(), args.len()); - } else { - assert!(false, "unexpected: {:?}", result); - } - }; -} - #[test] fn scalar_function_definitions() { - test_unary_scalar_expr!(Cot, cot); test_unary_scalar_expr!(Factorial, factorial); test_unary_scalar_expr!(Ceil, ceil); - test_nary_scalar_expr!(Round, round, input); - test_nary_scalar_expr!(Round, round, input, decimal_places); - test_nary_scalar_expr!(Trunc, trunc, num); - test_nary_scalar_expr!(Trunc, trunc, num, precision); test_unary_scalar_expr!(Exp, exp); test_scalar_expr!(Nanvl, nanvl, x, y); - test_scalar_expr!(Iszero, iszero, input); test_scalar_expr!(InitCap, initcap, string); test_scalar_expr!(EndsWith, ends_with, string, characters); diff --git a/datafusion/functions/src/math/cot.rs b/datafusion/functions/src/math/cot.rs new file mode 100644 index 000000000000..66219960d9a2 --- /dev/null +++ b/datafusion/functions/src/math/cot.rs @@ -0,0 +1,166 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::sync::Arc; + +use arrow::array::{ArrayRef, Float32Array, Float64Array}; +use arrow::datatypes::DataType; +use arrow::datatypes::DataType::{Float32, Float64}; + +use datafusion_common::{exec_err, DataFusionError, Result}; +use datafusion_expr::ColumnarValue; +use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; + +use crate::utils::make_scalar_function; + +#[derive(Debug)] +pub struct CotFunc { + signature: Signature, +} + +impl Default for CotFunc { + fn default() -> Self { + CotFunc::new() + } +} + +impl CotFunc { + pub fn new() -> Self { + use DataType::*; + Self { + // math expressions expect 1 argument of type f64 or f32 + // priority is given to f64 because e.g. `sqrt(1i32)` is in IR (real numbers) and thus we + // return the best approximation for it (in f64). + // We accept f32 because in this case it is clear that the best approximation + // will be as good as the number of digits in the number + signature: Signature::uniform( + 1, + vec![Float64, Float32], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for CotFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "cot" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + match arg_types[0] { + Float32 => Ok(Float32), + _ => Ok(Float64), + } + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + make_scalar_function(cot, vec![])(args) + } +} + +///cot SQL function +fn cot(args: &[ArrayRef]) -> Result { + match args[0].data_type() { + Float64 => Ok(Arc::new(make_function_scalar_inputs!( + &args[0], + "x", + Float64Array, + { compute_cot64 } + )) as ArrayRef), + Float32 => Ok(Arc::new(make_function_scalar_inputs!( + &args[0], + "x", + Float32Array, + { compute_cot32 } + )) as ArrayRef), + other => exec_err!("Unsupported data type {other:?} for function cot"), + } +} + +fn compute_cot32(x: f32) -> f32 { + let a = f32::tan(x); + 1.0 / a +} + +fn compute_cot64(x: f64) -> f64 { + let a = f64::tan(x); + 1.0 / a +} + +#[cfg(test)] +mod test { + use crate::math::cot::cot; + use arrow::array::{ArrayRef, Float32Array, Float64Array}; + use datafusion_common::cast::{as_float32_array, as_float64_array}; + use std::sync::Arc; + + #[test] + fn test_cot_f32() { + let args: Vec = + vec![Arc::new(Float32Array::from(vec![12.1, 30.0, 90.0, -30.0]))]; + let result = cot(&args).expect("failed to initialize function cot"); + let floats = + as_float32_array(&result).expect("failed to initialize function cot"); + + let expected = Float32Array::from(vec![ + -1.986_460_4, + -0.156_119_96, + -0.501_202_8, + 0.156_119_96, + ]); + + let eps = 1e-6; + assert_eq!(floats.len(), 4); + assert!((floats.value(0) - expected.value(0)).abs() < eps); + assert!((floats.value(1) - expected.value(1)).abs() < eps); + assert!((floats.value(2) - expected.value(2)).abs() < eps); + assert!((floats.value(3) - expected.value(3)).abs() < eps); + } + + #[test] + fn test_cot_f64() { + let args: Vec = + vec![Arc::new(Float64Array::from(vec![12.1, 30.0, 90.0, -30.0]))]; + let result = cot(&args).expect("failed to initialize function cot"); + let floats = + as_float64_array(&result).expect("failed to initialize function cot"); + + let expected = Float64Array::from(vec![ + -1.986_458_685_881_4, + -0.156_119_952_161_6, + -0.501_202_783_380_1, + 0.156_119_952_161_6, + ]); + + let eps = 1e-12; + assert_eq!(floats.len(), 4); + assert!((floats.value(0) - expected.value(0)).abs() < eps); + assert!((floats.value(1) - expected.value(1)).abs() < eps); + assert!((floats.value(2) - expected.value(2)).abs() < eps); + assert!((floats.value(3) - expected.value(3)).abs() < eps); + } +} diff --git a/datafusion/functions/src/math/iszero.rs b/datafusion/functions/src/math/iszero.rs new file mode 100644 index 000000000000..e6a728053359 --- /dev/null +++ b/datafusion/functions/src/math/iszero.rs @@ -0,0 +1,141 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::sync::Arc; + +use arrow::array::{ArrayRef, BooleanArray, Float32Array, Float64Array}; +use arrow::datatypes::DataType; +use arrow::datatypes::DataType::{Boolean, Float32, Float64}; + +use datafusion_common::{exec_err, DataFusionError, Result}; +use datafusion_expr::ColumnarValue; +use datafusion_expr::TypeSignature::Exact; +use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; + +use crate::utils::make_scalar_function; + +#[derive(Debug)] +pub struct IsZeroFunc { + signature: Signature, +} + +impl Default for IsZeroFunc { + fn default() -> Self { + IsZeroFunc::new() + } +} + +impl IsZeroFunc { + pub fn new() -> Self { + use DataType::*; + Self { + signature: Signature::one_of( + vec![Exact(vec![Float32]), Exact(vec![Float64])], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for IsZeroFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "iszero" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(Boolean) + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + make_scalar_function(iszero, vec![])(args) + } +} + +/// Iszero SQL function +pub fn iszero(args: &[ArrayRef]) -> Result { + match args[0].data_type() { + Float64 => Ok(Arc::new(make_function_scalar_inputs_return_type!( + &args[0], + "x", + Float64Array, + BooleanArray, + { |x: f64| { x == 0_f64 } } + )) as ArrayRef), + + Float32 => Ok(Arc::new(make_function_scalar_inputs_return_type!( + &args[0], + "x", + Float32Array, + BooleanArray, + { |x: f32| { x == 0_f32 } } + )) as ArrayRef), + + other => exec_err!("Unsupported data type {other:?} for function iszero"), + } +} + +#[cfg(test)] +mod test { + use std::sync::Arc; + + use arrow::array::{ArrayRef, Float32Array, Float64Array}; + + use datafusion_common::cast::as_boolean_array; + + use crate::math::iszero::iszero; + + #[test] + fn test_iszero_f64() { + let args: Vec = + vec![Arc::new(Float64Array::from(vec![1.0, 0.0, 3.0, -0.0]))]; + + let result = iszero(&args).expect("failed to initialize function iszero"); + let booleans = + as_boolean_array(&result).expect("failed to initialize function iszero"); + + assert_eq!(booleans.len(), 4); + assert!(!booleans.value(0)); + assert!(booleans.value(1)); + assert!(!booleans.value(2)); + assert!(booleans.value(3)); + } + + #[test] + fn test_iszero_f32() { + let args: Vec = + vec![Arc::new(Float32Array::from(vec![1.0, 0.0, 3.0, -0.0]))]; + + let result = iszero(&args).expect("failed to initialize function iszero"); + let booleans = + as_boolean_array(&result).expect("failed to initialize function iszero"); + + assert_eq!(booleans.len(), 4); + assert!(!booleans.value(0)); + assert!(booleans.value(1)); + assert!(!booleans.value(2)); + assert!(booleans.value(3)); + } +} diff --git a/datafusion/functions/src/math/mod.rs b/datafusion/functions/src/math/mod.rs index 2655edfe76dc..544de04e4a98 100644 --- a/datafusion/functions/src/math/mod.rs +++ b/datafusion/functions/src/math/mod.rs @@ -17,92 +17,260 @@ //! "math" DataFusion functions +use datafusion_expr::ScalarUDF; +use std::sync::Arc; + pub mod abs; +pub mod cot; pub mod gcd; +pub mod iszero; pub mod lcm; pub mod log; pub mod nans; pub mod pi; pub mod power; +pub mod round; +pub mod trunc; // Create UDFs -make_udf_function!(nans::IsNanFunc, ISNAN, isnan); make_udf_function!(abs::AbsFunc, ABS, abs); -make_udf_function!(log::LogFunc, LOG, log); -make_udf_function!(power::PowerFunc, POWER, power); -make_udf_function!(gcd::GcdFunc, GCD, gcd); -make_udf_function!(lcm::LcmFunc, LCM, lcm); -make_udf_function!(pi::PiFunc, PI, pi); - -make_math_unary_udf!(Log2Func, LOG2, log2, log2, Some(vec![Some(true)])); -make_math_unary_udf!(Log10Func, LOG10, log10, log10, Some(vec![Some(true)])); -make_math_unary_udf!(LnFunc, LN, ln, ln, Some(vec![Some(true)])); - -make_math_unary_udf!(TanhFunc, TANH, tanh, tanh, None); make_math_unary_udf!(AcosFunc, ACOS, acos, acos, None); +make_math_unary_udf!(AcoshFunc, ACOSH, acosh, acosh, Some(vec![Some(true)])); make_math_unary_udf!(AsinFunc, ASIN, asin, asin, None); -make_math_unary_udf!(TanFunc, TAN, tan, tan, None); - -make_math_unary_udf!(AtanhFunc, ATANH, atanh, atanh, Some(vec![Some(true)])); make_math_unary_udf!(AsinhFunc, ASINH, asinh, asinh, Some(vec![Some(true)])); -make_math_unary_udf!(AcoshFunc, ACOSH, acosh, acosh, Some(vec![Some(true)])); make_math_unary_udf!(AtanFunc, ATAN, atan, atan, Some(vec![Some(true)])); +make_math_unary_udf!(AtanhFunc, ATANH, atanh, atanh, Some(vec![Some(true)])); make_math_binary_udf!(Atan2, ATAN2, atan2, atan2, Some(vec![Some(true)])); - +make_math_unary_udf!(CbrtFunc, CBRT, cbrt, cbrt, None); +make_math_unary_udf!(CosFunc, COS, cos, cos, None); +make_math_unary_udf!(CoshFunc, COSH, cosh, cosh, None); +make_udf_function!(cot::CotFunc, COT, cot); +make_math_unary_udf!(DegreesFunc, DEGREES, degrees, to_degrees, None); +make_math_unary_udf!(FloorFunc, FLOOR, floor, floor, Some(vec![Some(true)])); +make_udf_function!(log::LogFunc, LOG, log); +make_udf_function!(gcd::GcdFunc, GCD, gcd); +make_udf_function!(nans::IsNanFunc, ISNAN, isnan); +make_udf_function!(iszero::IsZeroFunc, ISZERO, iszero); +make_udf_function!(lcm::LcmFunc, LCM, lcm); +make_math_unary_udf!(LnFunc, LN, ln, ln, Some(vec![Some(true)])); +make_math_unary_udf!(Log2Func, LOG2, log2, log2, Some(vec![Some(true)])); +make_math_unary_udf!(Log10Func, LOG10, log10, log10, Some(vec![Some(true)])); +make_udf_function!(pi::PiFunc, PI, pi); +make_udf_function!(power::PowerFunc, POWER, power); make_math_unary_udf!(RadiansFunc, RADIANS, radians, to_radians, None); +make_udf_function!(round::RoundFunc, ROUND, round); make_math_unary_udf!(SignumFunc, SIGNUM, signum, signum, None); make_math_unary_udf!(SinFunc, SIN, sin, sin, None); make_math_unary_udf!(SinhFunc, SINH, sinh, sinh, None); make_math_unary_udf!(SqrtFunc, SQRT, sqrt, sqrt, None); +make_math_unary_udf!(TanFunc, TAN, tan, tan, None); +make_math_unary_udf!(TanhFunc, TANH, tanh, tanh, None); +make_udf_function!(trunc::TruncFunc, TRUNC, trunc); -make_math_unary_udf!(CbrtFunc, CBRT, cbrt, cbrt, None); -make_math_unary_udf!(CosFunc, COS, cos, cos, None); -make_math_unary_udf!(CoshFunc, COSH, cosh, cosh, None); -make_math_unary_udf!(DegreesFunc, DEGREES, degrees, to_degrees, None); +pub mod expr_fn { + use datafusion_expr::Expr; -make_math_unary_udf!(FloorFunc, FLOOR, floor, floor, Some(vec![Some(true)])); + #[doc = "returns the absolute value of a given number"] + pub fn abs(num: Expr) -> Expr { + super::abs().call(vec![num]) + } + + #[doc = "returns the arc cosine or inverse cosine of a number"] + pub fn acos(num: Expr) -> Expr { + super::acos().call(vec![num]) + } + + #[doc = "returns inverse hyperbolic cosine"] + pub fn acosh(num: Expr) -> Expr { + super::acosh().call(vec![num]) + } + + #[doc = "returns the arc sine or inverse sine of a number"] + pub fn asin(num: Expr) -> Expr { + super::asin().call(vec![num]) + } + + #[doc = "returns inverse hyperbolic sine"] + pub fn asinh(num: Expr) -> Expr { + super::asinh().call(vec![num]) + } + + #[doc = "returns inverse tangent"] + pub fn atan(num: Expr) -> Expr { + super::atan().call(vec![num]) + } + + #[doc = "returns inverse tangent of a division given in the argument"] + pub fn atan2(y: Expr, x: Expr) -> Expr { + super::atan2().call(vec![y, x]) + } + + #[doc = "returns inverse hyperbolic tangent"] + pub fn atanh(num: Expr) -> Expr { + super::atanh().call(vec![num]) + } + + #[doc = "cube root of a number"] + pub fn cbrt(num: Expr) -> Expr { + super::cbrt().call(vec![num]) + } + + #[doc = "cosine"] + pub fn cos(num: Expr) -> Expr { + super::cos().call(vec![num]) + } + + #[doc = "hyperbolic cosine"] + pub fn cosh(num: Expr) -> Expr { + super::cosh().call(vec![num]) + } + + #[doc = "cotangent of a number"] + pub fn cot(num: Expr) -> Expr { + super::cot().call(vec![num]) + } + + #[doc = "converts radians to degrees"] + pub fn degrees(num: Expr) -> Expr { + super::degrees().call(vec![num]) + } + + #[doc = "nearest integer less than or equal to argument"] + pub fn floor(num: Expr) -> Expr { + super::floor().call(vec![num]) + } + + #[doc = "greatest common divisor"] + pub fn gcd(x: Expr, y: Expr) -> Expr { + super::gcd().call(vec![x, y]) + } + + #[doc = "returns true if a given number is +NaN or -NaN otherwise returns false"] + pub fn isnan(num: Expr) -> Expr { + super::isnan().call(vec![num]) + } + + #[doc = "returns true if a given number is +0.0 or -0.0 otherwise returns false"] + pub fn iszero(num: Expr) -> Expr { + super::iszero().call(vec![num]) + } + + #[doc = "least common multiple"] + pub fn lcm(x: Expr, y: Expr) -> Expr { + super::lcm().call(vec![x, y]) + } + + #[doc = "natural logarithm (base e) of a number"] + pub fn ln(num: Expr) -> Expr { + super::ln().call(vec![num]) + } + + #[doc = "logarithm of a number for a particular `base`"] + pub fn log(base: Expr, num: Expr) -> Expr { + super::log().call(vec![base, num]) + } + + #[doc = "base 2 logarithm of a number"] + pub fn log2(num: Expr) -> Expr { + super::log2().call(vec![num]) + } + + #[doc = "base 10 logarithm of a number"] + pub fn log10(num: Expr) -> Expr { + super::log10().call(vec![num]) + } + + #[doc = "Returns an approximate value of π"] + pub fn pi() -> Expr { + super::pi().call(vec![]) + } + + #[doc = "`base` raised to the power of `exponent`"] + pub fn power(base: Expr, exponent: Expr) -> Expr { + super::power().call(vec![base, exponent]) + } + + #[doc = "converts degrees to radians"] + pub fn radians(num: Expr) -> Expr { + super::radians().call(vec![num]) + } + + #[doc = "round to nearest integer"] + pub fn round(args: Vec) -> Expr { + super::round().call(args) + } + + #[doc = "sign of the argument (-1, 0, +1)"] + pub fn signum(num: Expr) -> Expr { + super::signum().call(vec![num]) + } + + #[doc = "sine"] + pub fn sin(num: Expr) -> Expr { + super::sin().call(vec![num]) + } + + #[doc = "hyperbolic sine"] + pub fn sinh(num: Expr) -> Expr { + super::sinh().call(vec![num]) + } + + #[doc = "square root of a number"] + pub fn sqrt(num: Expr) -> Expr { + super::sqrt().call(vec![num]) + } + + #[doc = "returns the tangent of a number"] + pub fn tan(num: Expr) -> Expr { + super::tan().call(vec![num]) + } + + #[doc = "returns the hyperbolic tangent of a number"] + pub fn tanh(num: Expr) -> Expr { + super::tanh().call(vec![num]) + } + + #[doc = "truncate toward zero, with optional precision"] + pub fn trunc(args: Vec) -> Expr { + super::trunc().call(args) + } +} -// Export the functions out of this package, both as expr_fn as well as a list of functions -export_functions!( - ( - isnan, - num, - "returns true if a given number is +NaN or -NaN otherwise returns false" - ), - (abs, num, "returns the absolute value of a given number"), - (power, base exponent, "`base` raised to the power of `exponent`"), - (log, base num, "logarithm of a number for a particular `base`"), - (log2, num, "base 2 logarithm of a number"), - (log10, num, "base 10 logarithm of a number"), - (ln, num, "natural logarithm (base e) of a number"), - ( - acos, - num, - "returns the arc cosine or inverse cosine of a number" - ), - ( - asin, - num, - "returns the arc sine or inverse sine of a number" - ), - (tan, num, "returns the tangent of a number"), - (tanh, num, "returns the hyperbolic tangent of a number"), - (atanh, num, "returns inverse hyperbolic tangent"), - (asinh, num, "returns inverse hyperbolic sine"), - (acosh, num, "returns inverse hyperbolic cosine"), - (atan, num, "returns inverse tangent"), - (atan2, y x, "returns inverse tangent of a division given in the argument"), - (radians, num, "converts degrees to radians"), - (signum, num, "sign of the argument (-1, 0, +1)"), - (sin, num, "sine"), - (sinh, num, "hyperbolic sine"), - (sqrt, num, "square root of a number"), - (cbrt, num, "cube root of a number"), - (cos, num, "cosine"), - (cosh, num, "hyperbolic cosine"), - (degrees, num, "converts radians to degrees"), - (gcd, x y, "greatest common divisor"), - (lcm, x y, "least common multiple"), - (floor, num, "nearest integer less than or equal to argument"), - (pi, , "Returns an approximate value of π") -); +/// Return a list of all functions in this package +pub fn functions() -> Vec> { + vec![ + abs(), + acos(), + acosh(), + asin(), + asinh(), + atan(), + atan2(), + atanh(), + cbrt(), + cos(), + cosh(), + cot(), + degrees(), + floor(), + gcd(), + isnan(), + iszero(), + lcm(), + ln(), + log(), + log2(), + log10(), + pi(), + power(), + radians(), + round(), + signum(), + sin(), + sinh(), + sqrt(), + tan(), + tanh(), + trunc(), + ] +} diff --git a/datafusion/functions/src/math/round.rs b/datafusion/functions/src/math/round.rs new file mode 100644 index 000000000000..f4a163137a35 --- /dev/null +++ b/datafusion/functions/src/math/round.rs @@ -0,0 +1,252 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::sync::Arc; + +use arrow::array::{ArrayRef, Float32Array, Float64Array, Int64Array}; +use arrow::datatypes::DataType; +use arrow::datatypes::DataType::{Float32, Float64}; + +use crate::utils::make_scalar_function; +use datafusion_common::{exec_err, DataFusionError, Result, ScalarValue}; +use datafusion_expr::TypeSignature::Exact; +use datafusion_expr::{ColumnarValue, FuncMonotonicity}; +use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; + +#[derive(Debug)] +pub struct RoundFunc { + signature: Signature, +} + +impl Default for RoundFunc { + fn default() -> Self { + RoundFunc::new() + } +} + +impl RoundFunc { + pub fn new() -> Self { + use DataType::*; + Self { + signature: Signature::one_of( + vec![ + Exact(vec![Float64, Int64]), + Exact(vec![Float32, Int64]), + Exact(vec![Float64]), + Exact(vec![Float32]), + ], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for RoundFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "round" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + match arg_types[0] { + Float32 => Ok(Float32), + _ => Ok(Float64), + } + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + make_scalar_function(round, vec![])(args) + } + + fn monotonicity(&self) -> Result> { + Ok(Some(vec![Some(true)])) + } +} + +/// Round SQL function +pub fn round(args: &[ArrayRef]) -> Result { + if args.len() != 1 && args.len() != 2 { + return exec_err!( + "round function requires one or two arguments, got {}", + args.len() + ); + } + + let mut decimal_places = ColumnarValue::Scalar(ScalarValue::Int64(Some(0))); + + if args.len() == 2 { + decimal_places = ColumnarValue::Array(args[1].clone()); + } + + match args[0].data_type() { + DataType::Float64 => match decimal_places { + ColumnarValue::Scalar(ScalarValue::Int64(Some(decimal_places))) => { + let decimal_places = decimal_places.try_into().unwrap(); + + Ok(Arc::new(make_function_scalar_inputs!( + &args[0], + "value", + Float64Array, + { + |value: f64| { + (value * 10.0_f64.powi(decimal_places)).round() + / 10.0_f64.powi(decimal_places) + } + } + )) as ArrayRef) + } + ColumnarValue::Array(decimal_places) => Ok(Arc::new(make_function_inputs2!( + &args[0], + decimal_places, + "value", + "decimal_places", + Float64Array, + Int64Array, + { + |value: f64, decimal_places: i64| { + (value * 10.0_f64.powi(decimal_places.try_into().unwrap())) + .round() + / 10.0_f64.powi(decimal_places.try_into().unwrap()) + } + } + )) as ArrayRef), + _ => { + exec_err!("round function requires a scalar or array for decimal_places") + } + }, + + DataType::Float32 => match decimal_places { + ColumnarValue::Scalar(ScalarValue::Int64(Some(decimal_places))) => { + let decimal_places = decimal_places.try_into().unwrap(); + + Ok(Arc::new(make_function_scalar_inputs!( + &args[0], + "value", + Float32Array, + { + |value: f32| { + (value * 10.0_f32.powi(decimal_places)).round() + / 10.0_f32.powi(decimal_places) + } + } + )) as ArrayRef) + } + ColumnarValue::Array(decimal_places) => Ok(Arc::new(make_function_inputs2!( + &args[0], + decimal_places, + "value", + "decimal_places", + Float32Array, + Int64Array, + { + |value: f32, decimal_places: i64| { + (value * 10.0_f32.powi(decimal_places.try_into().unwrap())) + .round() + / 10.0_f32.powi(decimal_places.try_into().unwrap()) + } + } + )) as ArrayRef), + _ => { + exec_err!("round function requires a scalar or array for decimal_places") + } + }, + + other => exec_err!("Unsupported data type {other:?} for function round"), + } +} + +#[cfg(test)] +mod test { + use crate::math::round::round; + use arrow::array::{ArrayRef, Float32Array, Float64Array, Int64Array}; + use datafusion_common::cast::{as_float32_array, as_float64_array}; + use std::sync::Arc; + + #[test] + fn test_round_f32() { + let args: Vec = vec![ + Arc::new(Float32Array::from(vec![125.2345; 10])), // input + Arc::new(Int64Array::from(vec![0, 1, 2, 3, 4, 5, -1, -2, -3, -4])), // decimal_places + ]; + + let result = round(&args).expect("failed to initialize function round"); + let floats = + as_float32_array(&result).expect("failed to initialize function round"); + + let expected = Float32Array::from(vec![ + 125.0, 125.2, 125.23, 125.235, 125.2345, 125.2345, 130.0, 100.0, 0.0, 0.0, + ]); + + assert_eq!(floats, &expected); + } + + #[test] + fn test_round_f64() { + let args: Vec = vec![ + Arc::new(Float64Array::from(vec![125.2345; 10])), // input + Arc::new(Int64Array::from(vec![0, 1, 2, 3, 4, 5, -1, -2, -3, -4])), // decimal_places + ]; + + let result = round(&args).expect("failed to initialize function round"); + let floats = + as_float64_array(&result).expect("failed to initialize function round"); + + let expected = Float64Array::from(vec![ + 125.0, 125.2, 125.23, 125.235, 125.2345, 125.2345, 130.0, 100.0, 0.0, 0.0, + ]); + + assert_eq!(floats, &expected); + } + + #[test] + fn test_round_f32_one_input() { + let args: Vec = vec![ + Arc::new(Float32Array::from(vec![125.2345, 12.345, 1.234, 0.1234])), // input + ]; + + let result = round(&args).expect("failed to initialize function round"); + let floats = + as_float32_array(&result).expect("failed to initialize function round"); + + let expected = Float32Array::from(vec![125.0, 12.0, 1.0, 0.0]); + + assert_eq!(floats, &expected); + } + + #[test] + fn test_round_f64_one_input() { + let args: Vec = vec![ + Arc::new(Float64Array::from(vec![125.2345, 12.345, 1.234, 0.1234])), // input + ]; + + let result = round(&args).expect("failed to initialize function round"); + let floats = + as_float64_array(&result).expect("failed to initialize function round"); + + let expected = Float64Array::from(vec![125.0, 12.0, 1.0, 0.0]); + + assert_eq!(floats, &expected); + } +} diff --git a/datafusion/functions/src/math/trunc.rs b/datafusion/functions/src/math/trunc.rs new file mode 100644 index 000000000000..6f88099889cc --- /dev/null +++ b/datafusion/functions/src/math/trunc.rs @@ -0,0 +1,235 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::sync::Arc; + +use arrow::array::{ArrayRef, Float32Array, Float64Array, Int64Array}; +use arrow::datatypes::DataType; +use arrow::datatypes::DataType::{Float32, Float64}; + +use crate::utils::make_scalar_function; +use datafusion_common::ScalarValue::Int64; +use datafusion_common::{exec_err, DataFusionError, Result}; +use datafusion_expr::TypeSignature::Exact; +use datafusion_expr::{ColumnarValue, FuncMonotonicity}; +use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; + +#[derive(Debug)] +pub struct TruncFunc { + signature: Signature, +} + +impl Default for TruncFunc { + fn default() -> Self { + TruncFunc::new() + } +} + +impl TruncFunc { + pub fn new() -> Self { + use DataType::*; + Self { + // math expressions expect 1 argument of type f64 or f32 + // priority is given to f64 because e.g. `sqrt(1i32)` is in IR (real numbers) and thus we + // return the best approximation for it (in f64). + // We accept f32 because in this case it is clear that the best approximation + // will be as good as the number of digits in the number + signature: Signature::one_of( + vec![ + Exact(vec![Float32, Int64]), + Exact(vec![Float64, Int64]), + Exact(vec![Float64]), + Exact(vec![Float32]), + ], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for TruncFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "trunc" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + match arg_types[0] { + Float32 => Ok(Float32), + _ => Ok(Float64), + } + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + make_scalar_function(trunc, vec![])(args) + } + + fn monotonicity(&self) -> Result> { + Ok(Some(vec![Some(true)])) + } +} + +/// Truncate(numeric, decimalPrecision) and trunc(numeric) SQL function +fn trunc(args: &[ArrayRef]) -> Result { + if args.len() != 1 && args.len() != 2 { + return exec_err!( + "truncate function requires one or two arguments, got {}", + args.len() + ); + } + + //if only one arg then invoke toolchain trunc(num) and precision = 0 by default + //or then invoke the compute_truncate method to process precision + let num = &args[0]; + let precision = if args.len() == 1 { + ColumnarValue::Scalar(Int64(Some(0))) + } else { + ColumnarValue::Array(args[1].clone()) + }; + + match args[0].data_type() { + Float64 => match precision { + ColumnarValue::Scalar(Int64(Some(0))) => Ok(Arc::new( + make_function_scalar_inputs!(num, "num", Float64Array, { f64::trunc }), + ) as ArrayRef), + ColumnarValue::Array(precision) => Ok(Arc::new(make_function_inputs2!( + num, + precision, + "x", + "y", + Float64Array, + Int64Array, + { compute_truncate64 } + )) as ArrayRef), + _ => exec_err!("trunc function requires a scalar or array for precision"), + }, + Float32 => match precision { + ColumnarValue::Scalar(Int64(Some(0))) => Ok(Arc::new( + make_function_scalar_inputs!(num, "num", Float32Array, { f32::trunc }), + ) as ArrayRef), + ColumnarValue::Array(precision) => Ok(Arc::new(make_function_inputs2!( + num, + precision, + "x", + "y", + Float32Array, + Int64Array, + { compute_truncate32 } + )) as ArrayRef), + _ => exec_err!("trunc function requires a scalar or array for precision"), + }, + other => exec_err!("Unsupported data type {other:?} for function trunc"), + } +} + +fn compute_truncate32(x: f32, y: i64) -> f32 { + let factor = 10.0_f32.powi(y as i32); + (x * factor).round() / factor +} + +fn compute_truncate64(x: f64, y: i64) -> f64 { + let factor = 10.0_f64.powi(y as i32); + (x * factor).round() / factor +} + +#[cfg(test)] +mod test { + use crate::math::trunc::trunc; + use arrow::array::{ArrayRef, Float32Array, Float64Array, Int64Array}; + use datafusion_common::cast::{as_float32_array, as_float64_array}; + use std::sync::Arc; + + #[test] + fn test_truncate_32() { + let args: Vec = vec![ + Arc::new(Float32Array::from(vec![ + 15.0, + 1_234.267_8, + 1_233.123_4, + 3.312_979_2, + -21.123_4, + ])), + Arc::new(Int64Array::from(vec![0, 3, 2, 5, 6])), + ]; + + let result = trunc(&args).expect("failed to initialize function truncate"); + let floats = + as_float32_array(&result).expect("failed to initialize function truncate"); + + assert_eq!(floats.len(), 5); + assert_eq!(floats.value(0), 15.0); + assert_eq!(floats.value(1), 1_234.268); + assert_eq!(floats.value(2), 1_233.12); + assert_eq!(floats.value(3), 3.312_98); + assert_eq!(floats.value(4), -21.123_4); + } + + #[test] + fn test_truncate_64() { + let args: Vec = vec![ + Arc::new(Float64Array::from(vec![ + 5.0, + 234.267_812_176, + 123.123_456_789, + 123.312_979_313_2, + -321.123_1, + ])), + Arc::new(Int64Array::from(vec![0, 3, 2, 5, 6])), + ]; + + let result = trunc(&args).expect("failed to initialize function truncate"); + let floats = + as_float64_array(&result).expect("failed to initialize function truncate"); + + assert_eq!(floats.len(), 5); + assert_eq!(floats.value(0), 5.0); + assert_eq!(floats.value(1), 234.268); + assert_eq!(floats.value(2), 123.12); + assert_eq!(floats.value(3), 123.312_98); + assert_eq!(floats.value(4), -321.123_1); + } + + #[test] + fn test_truncate_64_one_arg() { + let args: Vec = vec![Arc::new(Float64Array::from(vec![ + 5.0, + 234.267_812, + 123.123_45, + 123.312_979_313_2, + -321.123, + ]))]; + + let result = trunc(&args).expect("failed to initialize function truncate"); + let floats = + as_float64_array(&result).expect("failed to initialize function truncate"); + + assert_eq!(floats.len(), 5); + assert_eq!(floats.value(0), 5.0); + assert_eq!(floats.value(1), 234.0); + assert_eq!(floats.value(2), 123.0); + assert_eq!(floats.value(3), 123.0); + assert_eq!(floats.value(4), -321.0); + } +} diff --git a/datafusion/physical-expr/src/equivalence/projection.rs b/datafusion/physical-expr/src/equivalence/projection.rs index 5efcf5942c39..92772e4623be 100644 --- a/datafusion/physical-expr/src/equivalence/projection.rs +++ b/datafusion/physical-expr/src/equivalence/projection.rs @@ -117,8 +117,7 @@ mod tests { use itertools::Itertools; use datafusion_common::{DFSchema, Result}; - use datafusion_expr::execution_props::ExecutionProps; - use datafusion_expr::{BuiltinScalarFunction, Operator, ScalarUDF}; + use datafusion_expr::{Operator, ScalarUDF}; use crate::equivalence::tests::{ apply_projection, convert_to_orderings, convert_to_orderings_owned, @@ -649,11 +648,13 @@ mod tests { col_b.clone(), )) as Arc; - let round_c = &crate::functions::create_physical_expr( - &BuiltinScalarFunction::Round, + let test_fun = ScalarUDF::new_from_impl(TestScalarUDF::new()); + let round_c = &create_physical_expr( + &test_fun, &[col_c.clone()], &schema, - &ExecutionProps::default(), + &[], + &DFSchema::empty(), )?; let option_asc = SortOptions { diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs index 124acdc7ac78..2be85a69d7da 100644 --- a/datafusion/physical-expr/src/functions.rs +++ b/datafusion/physical-expr/src/functions.rs @@ -184,22 +184,10 @@ pub fn create_physical_fun( BuiltinScalarFunction::Factorial => { Arc::new(|args| make_scalar_function_inner(math_expressions::factorial)(args)) } - BuiltinScalarFunction::Iszero => { - Arc::new(|args| make_scalar_function_inner(math_expressions::iszero)(args)) - } BuiltinScalarFunction::Nanvl => { Arc::new(|args| make_scalar_function_inner(math_expressions::nanvl)(args)) } BuiltinScalarFunction::Random => Arc::new(math_expressions::random), - BuiltinScalarFunction::Round => { - Arc::new(|args| make_scalar_function_inner(math_expressions::round)(args)) - } - BuiltinScalarFunction::Trunc => { - Arc::new(|args| make_scalar_function_inner(math_expressions::trunc)(args)) - } - BuiltinScalarFunction::Cot => { - Arc::new(|args| make_scalar_function_inner(math_expressions::cot)(args)) - } // string functions BuiltinScalarFunction::Coalesce => Arc::new(conditional_expressions::coalesce), BuiltinScalarFunction::Concat => Arc::new(string_expressions::concat), diff --git a/datafusion/physical-expr/src/math_expressions.rs b/datafusion/physical-expr/src/math_expressions.rs index b29230de1f76..55fb54563787 100644 --- a/datafusion/physical-expr/src/math_expressions.rs +++ b/datafusion/physical-expr/src/math_expressions.rs @@ -27,7 +27,6 @@ use arrow::datatypes::DataType; use arrow_array::Array; use rand::{thread_rng, Rng}; -use datafusion_common::ScalarValue::Int64; use datafusion_common::{exec_err, ScalarValue}; use datafusion_common::{DataFusionError, Result}; use datafusion_expr::ColumnarValue; @@ -154,17 +153,8 @@ macro_rules! make_function_scalar_inputs_return_type { }}; } -math_unary_function!("asin", asin); -math_unary_function!("acos", acos); -math_unary_function!("atan", atan); -math_unary_function!("asinh", asinh); -math_unary_function!("acosh", acosh); -math_unary_function!("atanh", atanh); math_unary_function!("ceil", ceil); math_unary_function!("exp", exp); -math_unary_function!("ln", ln); -math_unary_function!("log2", log2); -math_unary_function!("log10", log10); /// Factorial SQL function pub fn factorial(args: &[ArrayRef]) -> Result { @@ -247,29 +237,6 @@ pub fn isnan(args: &[ArrayRef]) -> Result { } } -/// Iszero SQL function -pub fn iszero(args: &[ArrayRef]) -> Result { - match args[0].data_type() { - DataType::Float64 => Ok(Arc::new(make_function_scalar_inputs_return_type!( - &args[0], - "x", - Float64Array, - BooleanArray, - { |x: f64| { x == 0_f64 } } - )) as ArrayRef), - - DataType::Float32 => Ok(Arc::new(make_function_scalar_inputs_return_type!( - &args[0], - "x", - Float32Array, - BooleanArray, - { |x: f32| { x == 0_f32 } } - )) as ArrayRef), - - other => exec_err!("Unsupported data type {other:?} for function iszero"), - } -} - /// Random SQL function pub fn random(args: &[ColumnarValue]) -> Result { let len: usize = match &args[0] { @@ -282,192 +249,6 @@ pub fn random(args: &[ColumnarValue]) -> Result { Ok(ColumnarValue::Array(Arc::new(array))) } -/// Round SQL function -pub fn round(args: &[ArrayRef]) -> Result { - if args.len() != 1 && args.len() != 2 { - return exec_err!( - "round function requires one or two arguments, got {}", - args.len() - ); - } - - let mut decimal_places = ColumnarValue::Scalar(ScalarValue::Int64(Some(0))); - - if args.len() == 2 { - decimal_places = ColumnarValue::Array(args[1].clone()); - } - - match args[0].data_type() { - DataType::Float64 => match decimal_places { - ColumnarValue::Scalar(ScalarValue::Int64(Some(decimal_places))) => { - let decimal_places = decimal_places.try_into().unwrap(); - - Ok(Arc::new(make_function_scalar_inputs!( - &args[0], - "value", - Float64Array, - { - |value: f64| { - (value * 10.0_f64.powi(decimal_places)).round() - / 10.0_f64.powi(decimal_places) - } - } - )) as ArrayRef) - } - ColumnarValue::Array(decimal_places) => Ok(Arc::new(make_function_inputs2!( - &args[0], - decimal_places, - "value", - "decimal_places", - Float64Array, - Int64Array, - { - |value: f64, decimal_places: i64| { - (value * 10.0_f64.powi(decimal_places.try_into().unwrap())) - .round() - / 10.0_f64.powi(decimal_places.try_into().unwrap()) - } - } - )) as ArrayRef), - _ => { - exec_err!("round function requires a scalar or array for decimal_places") - } - }, - - DataType::Float32 => match decimal_places { - ColumnarValue::Scalar(ScalarValue::Int64(Some(decimal_places))) => { - let decimal_places = decimal_places.try_into().unwrap(); - - Ok(Arc::new(make_function_scalar_inputs!( - &args[0], - "value", - Float32Array, - { - |value: f32| { - (value * 10.0_f32.powi(decimal_places)).round() - / 10.0_f32.powi(decimal_places) - } - } - )) as ArrayRef) - } - ColumnarValue::Array(decimal_places) => Ok(Arc::new(make_function_inputs2!( - &args[0], - decimal_places, - "value", - "decimal_places", - Float32Array, - Int64Array, - { - |value: f32, decimal_places: i64| { - (value * 10.0_f32.powi(decimal_places.try_into().unwrap())) - .round() - / 10.0_f32.powi(decimal_places.try_into().unwrap()) - } - } - )) as ArrayRef), - _ => { - exec_err!("round function requires a scalar or array for decimal_places") - } - }, - - other => exec_err!("Unsupported data type {other:?} for function round"), - } -} - -///cot SQL function -pub fn cot(args: &[ArrayRef]) -> Result { - match args[0].data_type() { - DataType::Float64 => Ok(Arc::new(make_function_scalar_inputs!( - &args[0], - "x", - Float64Array, - { compute_cot64 } - )) as ArrayRef), - - DataType::Float32 => Ok(Arc::new(make_function_scalar_inputs!( - &args[0], - "x", - Float32Array, - { compute_cot32 } - )) as ArrayRef), - - other => exec_err!("Unsupported data type {other:?} for function cot"), - } -} - -fn compute_cot32(x: f32) -> f32 { - let a = f32::tan(x); - 1.0 / a -} - -fn compute_cot64(x: f64) -> f64 { - let a = f64::tan(x); - 1.0 / a -} - -/// Truncate(numeric, decimalPrecision) and trunc(numeric) SQL function -pub fn trunc(args: &[ArrayRef]) -> Result { - if args.len() != 1 && args.len() != 2 { - return exec_err!( - "truncate function requires one or two arguments, got {}", - args.len() - ); - } - - //if only one arg then invoke toolchain trunc(num) and precision = 0 by default - //or then invoke the compute_truncate method to process precision - let num = &args[0]; - let precision = if args.len() == 1 { - ColumnarValue::Scalar(Int64(Some(0))) - } else { - ColumnarValue::Array(args[1].clone()) - }; - - match args[0].data_type() { - DataType::Float64 => match precision { - ColumnarValue::Scalar(Int64(Some(0))) => Ok(Arc::new( - make_function_scalar_inputs!(num, "num", Float64Array, { f64::trunc }), - ) as ArrayRef), - ColumnarValue::Array(precision) => Ok(Arc::new(make_function_inputs2!( - num, - precision, - "x", - "y", - Float64Array, - Int64Array, - { compute_truncate64 } - )) as ArrayRef), - _ => exec_err!("trunc function requires a scalar or array for precision"), - }, - DataType::Float32 => match precision { - ColumnarValue::Scalar(Int64(Some(0))) => Ok(Arc::new( - make_function_scalar_inputs!(num, "num", Float32Array, { f32::trunc }), - ) as ArrayRef), - ColumnarValue::Array(precision) => Ok(Arc::new(make_function_inputs2!( - num, - precision, - "x", - "y", - Float32Array, - Int64Array, - { compute_truncate32 } - )) as ArrayRef), - _ => exec_err!("trunc function requires a scalar or array for precision"), - }, - other => exec_err!("Unsupported data type {other:?} for function trunc"), - } -} - -fn compute_truncate32(x: f32, y: i64) -> f32 { - let factor = 10.0_f32.powi(y as i32); - (x * factor).round() / factor -} - -fn compute_truncate64(x: f64, y: i64) -> f64 { - let factor = 10.0_f64.powi(y as i32); - (x * factor).round() / factor -} - #[cfg(test)] mod tests { use arrow::array::{Float64Array, NullArray}; @@ -492,72 +273,6 @@ mod tests { assert!(0.0 <= floats.value(0) && floats.value(0) < 1.0); } - #[test] - fn test_round_f32() { - let args: Vec = vec![ - Arc::new(Float32Array::from(vec![125.2345; 10])), // input - Arc::new(Int64Array::from(vec![0, 1, 2, 3, 4, 5, -1, -2, -3, -4])), // decimal_places - ]; - - let result = round(&args).expect("failed to initialize function round"); - let floats = - as_float32_array(&result).expect("failed to initialize function round"); - - let expected = Float32Array::from(vec![ - 125.0, 125.2, 125.23, 125.235, 125.2345, 125.2345, 130.0, 100.0, 0.0, 0.0, - ]); - - assert_eq!(floats, &expected); - } - - #[test] - fn test_round_f64() { - let args: Vec = vec![ - Arc::new(Float64Array::from(vec![125.2345; 10])), // input - Arc::new(Int64Array::from(vec![0, 1, 2, 3, 4, 5, -1, -2, -3, -4])), // decimal_places - ]; - - let result = round(&args).expect("failed to initialize function round"); - let floats = - as_float64_array(&result).expect("failed to initialize function round"); - - let expected = Float64Array::from(vec![ - 125.0, 125.2, 125.23, 125.235, 125.2345, 125.2345, 130.0, 100.0, 0.0, 0.0, - ]); - - assert_eq!(floats, &expected); - } - - #[test] - fn test_round_f32_one_input() { - let args: Vec = vec![ - Arc::new(Float32Array::from(vec![125.2345, 12.345, 1.234, 0.1234])), // input - ]; - - let result = round(&args).expect("failed to initialize function round"); - let floats = - as_float32_array(&result).expect("failed to initialize function round"); - - let expected = Float32Array::from(vec![125.0, 12.0, 1.0, 0.0]); - - assert_eq!(floats, &expected); - } - - #[test] - fn test_round_f64_one_input() { - let args: Vec = vec![ - Arc::new(Float64Array::from(vec![125.2345, 12.345, 1.234, 0.1234])), // input - ]; - - let result = round(&args).expect("failed to initialize function round"); - let floats = - as_float64_array(&result).expect("failed to initialize function round"); - - let expected = Float64Array::from(vec![125.0, 12.0, 1.0, 0.0]); - - assert_eq!(floats, &expected); - } - #[test] fn test_factorial_i64() { let args: Vec = vec![ @@ -573,124 +288,6 @@ mod tests { assert_eq!(ints, &expected); } - #[test] - fn test_cot_f32() { - let args: Vec = - vec![Arc::new(Float32Array::from(vec![12.1, 30.0, 90.0, -30.0]))]; - let result = cot(&args).expect("failed to initialize function cot"); - let floats = - as_float32_array(&result).expect("failed to initialize function cot"); - - let expected = Float32Array::from(vec![ - -1.986_460_4, - -0.156_119_96, - -0.501_202_8, - 0.156_119_96, - ]); - - let eps = 1e-6; - assert_eq!(floats.len(), 4); - assert!((floats.value(0) - expected.value(0)).abs() < eps); - assert!((floats.value(1) - expected.value(1)).abs() < eps); - assert!((floats.value(2) - expected.value(2)).abs() < eps); - assert!((floats.value(3) - expected.value(3)).abs() < eps); - } - - #[test] - fn test_cot_f64() { - let args: Vec = - vec![Arc::new(Float64Array::from(vec![12.1, 30.0, 90.0, -30.0]))]; - let result = cot(&args).expect("failed to initialize function cot"); - let floats = - as_float64_array(&result).expect("failed to initialize function cot"); - - let expected = Float64Array::from(vec![ - -1.986_458_685_881_4, - -0.156_119_952_161_6, - -0.501_202_783_380_1, - 0.156_119_952_161_6, - ]); - - let eps = 1e-12; - assert_eq!(floats.len(), 4); - assert!((floats.value(0) - expected.value(0)).abs() < eps); - assert!((floats.value(1) - expected.value(1)).abs() < eps); - assert!((floats.value(2) - expected.value(2)).abs() < eps); - assert!((floats.value(3) - expected.value(3)).abs() < eps); - } - - #[test] - fn test_truncate_32() { - let args: Vec = vec![ - Arc::new(Float32Array::from(vec![ - 15.0, - 1_234.267_8, - 1_233.123_4, - 3.312_979_2, - -21.123_4, - ])), - Arc::new(Int64Array::from(vec![0, 3, 2, 5, 6])), - ]; - - let result = trunc(&args).expect("failed to initialize function truncate"); - let floats = - as_float32_array(&result).expect("failed to initialize function truncate"); - - assert_eq!(floats.len(), 5); - assert_eq!(floats.value(0), 15.0); - assert_eq!(floats.value(1), 1_234.268); - assert_eq!(floats.value(2), 1_233.12); - assert_eq!(floats.value(3), 3.312_98); - assert_eq!(floats.value(4), -21.123_4); - } - - #[test] - fn test_truncate_64() { - let args: Vec = vec![ - Arc::new(Float64Array::from(vec![ - 5.0, - 234.267_812_176, - 123.123_456_789, - 123.312_979_313_2, - -321.123_1, - ])), - Arc::new(Int64Array::from(vec![0, 3, 2, 5, 6])), - ]; - - let result = trunc(&args).expect("failed to initialize function truncate"); - let floats = - as_float64_array(&result).expect("failed to initialize function truncate"); - - assert_eq!(floats.len(), 5); - assert_eq!(floats.value(0), 5.0); - assert_eq!(floats.value(1), 234.268); - assert_eq!(floats.value(2), 123.12); - assert_eq!(floats.value(3), 123.312_98); - assert_eq!(floats.value(4), -321.123_1); - } - - #[test] - fn test_truncate_64_one_arg() { - let args: Vec = vec![Arc::new(Float64Array::from(vec![ - 5.0, - 234.267_812, - 123.123_45, - 123.312_979_313_2, - -321.123, - ]))]; - - let result = trunc(&args).expect("failed to initialize function truncate"); - let floats = - as_float64_array(&result).expect("failed to initialize function truncate"); - - assert_eq!(floats.len(), 5); - assert_eq!(floats.value(0), 5.0); - assert_eq!(floats.value(1), 234.0); - assert_eq!(floats.value(2), 123.0); - assert_eq!(floats.value(3), 123.0); - assert_eq!(floats.value(4), -321.0); - } - #[test] fn test_nanvl_f64() { let args: Vec = vec![ @@ -766,36 +363,4 @@ mod tests { assert!(!booleans.value(2)); assert!(booleans.value(3)); } - - #[test] - fn test_iszero_f64() { - let args: Vec = - vec![Arc::new(Float64Array::from(vec![1.0, 0.0, 3.0, -0.0]))]; - - let result = iszero(&args).expect("failed to initialize function iszero"); - let booleans = - as_boolean_array(&result).expect("failed to initialize function iszero"); - - assert_eq!(booleans.len(), 4); - assert!(!booleans.value(0)); - assert!(booleans.value(1)); - assert!(!booleans.value(2)); - assert!(booleans.value(3)); - } - - #[test] - fn test_iszero_f32() { - let args: Vec = - vec![Arc::new(Float32Array::from(vec![1.0, 0.0, 3.0, -0.0]))]; - - let result = iszero(&args).expect("failed to initialize function iszero"); - let booleans = - as_boolean_array(&result).expect("failed to initialize function iszero"); - - assert_eq!(booleans.len(), 4); - assert!(!booleans.value(0)); - assert!(booleans.value(1)); - assert!(!booleans.value(2)); - assert!(booleans.value(3)); - } } diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 0f245673f6cd..c7c0d9b5a656 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -555,12 +555,12 @@ enum ScalarFunction { // 11 was Log // 12 was Log10 // 13 was Log2 - Round = 14; + // 14 was Round // 15 was Signum // 16 was Sin // 17 was Sqrt // Tan = 18; - Trunc = 19; + // 19 was Trunc // 20 was Array // RegexpMatch = 21; // 22 was BitLength @@ -642,7 +642,7 @@ enum ScalarFunction { // 98 was Cardinality // 99 was ArrayElement // 100 was ArraySlice - Cot = 103; + // 103 was Cot // 104 was ArrayHas // 105 was ArrayHasAny // 106 was ArrayHasAll @@ -653,7 +653,7 @@ enum ScalarFunction { Nanvl = 111; // 112 was Flatten // 113 was IsNan - Iszero = 114; + // 114 was Iszero // 115 was ArrayEmpty // 116 was ArrayPopBack // 117 was StringToArray diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 0922fccc7917..c8a1fba40765 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -22794,17 +22794,13 @@ impl serde::Serialize for ScalarFunction { Self::Unknown => "unknown", Self::Ceil => "Ceil", Self::Exp => "Exp", - Self::Round => "Round", - Self::Trunc => "Trunc", Self::Concat => "Concat", Self::ConcatWithSeparator => "ConcatWithSeparator", Self::InitCap => "InitCap", Self::Random => "Random", Self::Coalesce => "Coalesce", Self::Factorial => "Factorial", - Self::Cot => "Cot", Self::Nanvl => "Nanvl", - Self::Iszero => "Iszero", Self::EndsWith => "EndsWith", }; serializer.serialize_str(variant) @@ -22820,17 +22816,13 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "unknown", "Ceil", "Exp", - "Round", - "Trunc", "Concat", "ConcatWithSeparator", "InitCap", "Random", "Coalesce", "Factorial", - "Cot", "Nanvl", - "Iszero", "EndsWith", ]; @@ -22875,17 +22867,13 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "unknown" => Ok(ScalarFunction::Unknown), "Ceil" => Ok(ScalarFunction::Ceil), "Exp" => Ok(ScalarFunction::Exp), - "Round" => Ok(ScalarFunction::Round), - "Trunc" => Ok(ScalarFunction::Trunc), "Concat" => Ok(ScalarFunction::Concat), "ConcatWithSeparator" => Ok(ScalarFunction::ConcatWithSeparator), "InitCap" => Ok(ScalarFunction::InitCap), "Random" => Ok(ScalarFunction::Random), "Coalesce" => Ok(ScalarFunction::Coalesce), "Factorial" => Ok(ScalarFunction::Factorial), - "Cot" => Ok(ScalarFunction::Cot), "Nanvl" => Ok(ScalarFunction::Nanvl), - "Iszero" => Ok(ScalarFunction::Iszero), "EndsWith" => Ok(ScalarFunction::EndsWith), _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index db7614144983..facf24219810 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -2854,12 +2854,12 @@ pub enum ScalarFunction { /// 11 was Log /// 12 was Log10 /// 13 was Log2 - Round = 14, + /// 14 was Round /// 15 was Signum /// 16 was Sin /// 17 was Sqrt /// Tan = 18; - Trunc = 19, + /// 19 was Trunc /// 20 was Array /// RegexpMatch = 21; /// 22 was BitLength @@ -2941,7 +2941,7 @@ pub enum ScalarFunction { /// 98 was Cardinality /// 99 was ArrayElement /// 100 was ArraySlice - Cot = 103, + /// 103 was Cot /// 104 was ArrayHas /// 105 was ArrayHasAny /// 106 was ArrayHasAll @@ -2952,7 +2952,7 @@ pub enum ScalarFunction { Nanvl = 111, /// 112 was Flatten /// 113 was IsNan - Iszero = 114, + /// 114 was Iszero /// 115 was ArrayEmpty /// 116 was ArrayPopBack /// 117 was StringToArray @@ -2989,17 +2989,13 @@ impl ScalarFunction { ScalarFunction::Unknown => "unknown", ScalarFunction::Ceil => "Ceil", ScalarFunction::Exp => "Exp", - ScalarFunction::Round => "Round", - ScalarFunction::Trunc => "Trunc", ScalarFunction::Concat => "Concat", ScalarFunction::ConcatWithSeparator => "ConcatWithSeparator", ScalarFunction::InitCap => "InitCap", ScalarFunction::Random => "Random", ScalarFunction::Coalesce => "Coalesce", ScalarFunction::Factorial => "Factorial", - ScalarFunction::Cot => "Cot", ScalarFunction::Nanvl => "Nanvl", - ScalarFunction::Iszero => "Iszero", ScalarFunction::EndsWith => "EndsWith", } } @@ -3009,17 +3005,13 @@ impl ScalarFunction { "unknown" => Some(Self::Unknown), "Ceil" => Some(Self::Ceil), "Exp" => Some(Self::Exp), - "Round" => Some(Self::Round), - "Trunc" => Some(Self::Trunc), "Concat" => Some(Self::Concat), "ConcatWithSeparator" => Some(Self::ConcatWithSeparator), "InitCap" => Some(Self::InitCap), "Random" => Some(Self::Random), "Coalesce" => Some(Self::Coalesce), "Factorial" => Some(Self::Factorial), - "Cot" => Some(Self::Cot), "Nanvl" => Some(Self::Nanvl), - "Iszero" => Some(Self::Iszero), "EndsWith" => Some(Self::EndsWith), _ => None, } diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 6a2e89fe00a3..e9eb53e45199 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -37,13 +37,13 @@ use datafusion_expr::expr::Unnest; use datafusion_expr::expr::{Alias, Placeholder}; use datafusion_expr::window_frame::{check_window_frame, regularize_window_order_by}; use datafusion_expr::{ - ceil, coalesce, concat_expr, concat_ws_expr, cot, ends_with, exp, + ceil, coalesce, concat_expr, concat_ws_expr, ends_with, exp, expr::{self, InList, Sort, WindowFunction}, - factorial, initcap, iszero, + factorial, initcap, logical_plan::{PlanType, StringifiedPlan}, - nanvl, random, round, trunc, AggregateFunction, Between, BinaryExpr, - BuiltInWindowFunction, BuiltinScalarFunction, Case, Cast, Expr, GetFieldAccess, - GetIndexedField, GroupingSet, + nanvl, random, AggregateFunction, Between, BinaryExpr, BuiltInWindowFunction, + BuiltinScalarFunction, Case, Cast, Expr, GetFieldAccess, GetIndexedField, + GroupingSet, GroupingSet::GroupingSets, JoinConstraint, JoinType, Like, Operator, TryCast, WindowFrame, WindowFrameBound, WindowFrameUnits, @@ -419,12 +419,9 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction { use protobuf::ScalarFunction; match f { ScalarFunction::Unknown => todo!(), - ScalarFunction::Cot => Self::Cot, ScalarFunction::Exp => Self::Exp, ScalarFunction::Factorial => Self::Factorial, ScalarFunction::Ceil => Self::Ceil, - ScalarFunction::Round => Self::Round, - ScalarFunction::Trunc => Self::Trunc, ScalarFunction::Concat => Self::Concat, ScalarFunction::ConcatWithSeparator => Self::ConcatWithSeparator, ScalarFunction::EndsWith => Self::EndsWith, @@ -432,7 +429,6 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction { ScalarFunction::Random => Self::Random, ScalarFunction::Coalesce => Self::Coalesce, ScalarFunction::Nanvl => Self::Nanvl, - ScalarFunction::Iszero => Self::Iszero, } } } @@ -1299,8 +1295,6 @@ pub fn parse_expr( Ok(factorial(parse_expr(&args[0], registry, codec)?)) } ScalarFunction::Ceil => Ok(ceil(parse_expr(&args[0], registry, codec)?)), - ScalarFunction::Round => Ok(round(parse_exprs(args, registry, codec)?)), - ScalarFunction::Trunc => Ok(trunc(parse_exprs(args, registry, codec)?)), ScalarFunction::InitCap => { Ok(initcap(parse_expr(&args[0], registry, codec)?)) } @@ -1318,14 +1312,10 @@ pub fn parse_expr( ScalarFunction::Coalesce => { Ok(coalesce(parse_exprs(args, registry, codec)?)) } - ScalarFunction::Cot => Ok(cot(parse_expr(&args[0], registry, codec)?)), ScalarFunction::Nanvl => Ok(nanvl( parse_expr(&args[0], registry, codec)?, parse_expr(&args[1], registry, codec)?, )), - ScalarFunction::Iszero => { - Ok(iszero(parse_expr(&args[0], registry, codec)?)) - } } } ExprType::ScalarUdfExpr(protobuf::ScalarUdfExprNode { diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index db9653e32346..ed5e7a302b20 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -1407,12 +1407,9 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction { fn try_from(scalar: &BuiltinScalarFunction) -> Result { let scalar_function = match scalar { - BuiltinScalarFunction::Cot => Self::Cot, BuiltinScalarFunction::Exp => Self::Exp, BuiltinScalarFunction::Factorial => Self::Factorial, BuiltinScalarFunction::Ceil => Self::Ceil, - BuiltinScalarFunction::Round => Self::Round, - BuiltinScalarFunction::Trunc => Self::Trunc, BuiltinScalarFunction::Concat => Self::Concat, BuiltinScalarFunction::ConcatWithSeparator => Self::ConcatWithSeparator, BuiltinScalarFunction::EndsWith => Self::EndsWith, @@ -1420,7 +1417,6 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction { BuiltinScalarFunction::Random => Self::Random, BuiltinScalarFunction::Coalesce => Self::Coalesce, BuiltinScalarFunction::Nanvl => Self::Nanvl, - BuiltinScalarFunction::Iszero => Self::Iszero, }; Ok(scalar_function) diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index 5dacf692e904..a74b1a38935b 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -34,9 +34,7 @@ use datafusion::datasource::physical_plan::{ FileSinkConfig, ParquetExec, }; use datafusion::execution::FunctionRegistry; -use datafusion::logical_expr::{ - create_udf, BuiltinScalarFunction, JoinType, Operator, Volatility, -}; +use datafusion::logical_expr::{create_udf, JoinType, Operator, Volatility}; use datafusion::physical_expr::expressions::NthValueAgg; use datafusion::physical_expr::window::SlidingAggregateWindowExpr; use datafusion::physical_expr::{PhysicalSortRequirement, ScalarFunctionExpr}; @@ -603,31 +601,6 @@ async fn roundtrip_parquet_exec_with_table_partition_cols() -> Result<()> { ))) } -#[test] -fn roundtrip_builtin_scalar_function() -> Result<()> { - let field_a = Field::new("a", DataType::Int64, false); - let field_b = Field::new("b", DataType::Int64, false); - let schema = Arc::new(Schema::new(vec![field_a, field_b])); - - let input = Arc::new(EmptyExec::new(schema.clone())); - - let fun_def = ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::Trunc); - - let expr = ScalarFunctionExpr::new( - "trunc", - fun_def, - vec![col("a", &schema)?], - DataType::Float64, - Some(vec![Some(true)]), - false, - ); - - let project = - ProjectionExec::try_new(vec![(Arc::new(expr), "a".to_string())], input)?; - - roundtrip_test(Arc::new(project)) -} - #[test] fn roundtrip_scalar_udf() -> Result<()> { let field_a = Field::new("a", DataType::Int64, false); diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs index e923a15372d0..19288123558a 100644 --- a/datafusion/sql/tests/sql_integration.rs +++ b/datafusion/sql/tests/sql_integration.rs @@ -2683,6 +2683,11 @@ fn logical_plan_with_dialect_and_options( vec![DataType::Int32, DataType::Int32], DataType::Int32, )) + .with_udf(make_udf( + "round", + vec![DataType::Float64, DataType::Int64], + DataType::Float32, + )) .with_udf(make_udf( "arrow_cast", vec![DataType::Int64, DataType::Utf8],