From 6e637488188c6620ecd113bf47987bd98f6d7871 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BC=A0=E6=9E=97=E4=BC=9F?= Date: Thu, 4 Jul 2024 14:47:14 +0800 Subject: [PATCH 01/26] Enable clone_on_ref_ptr clippy lint on expr crate (#11238) --- datafusion/expr/src/expr_rewriter/mod.rs | 2 +- datafusion/expr/src/expr_schema.rs | 2 +- datafusion/expr/src/lib.rs | 2 ++ datafusion/expr/src/logical_plan/builder.rs | 23 +++++++++----- datafusion/expr/src/logical_plan/plan.rs | 30 +++++++++---------- datafusion/expr/src/type_coercion/binary.rs | 8 ++--- .../expr/src/type_coercion/functions.rs | 12 ++++---- datafusion/expr/src/udaf.rs | 9 ++++-- datafusion/expr/src/udf.rs | 8 ++--- datafusion/expr/src/udwf.rs | 6 ++-- datafusion/expr/src/utils.rs | 2 +- 11 files changed, 58 insertions(+), 46 deletions(-) diff --git a/datafusion/expr/src/expr_rewriter/mod.rs b/datafusion/expr/src/expr_rewriter/mod.rs index 024e4a0ceae5..91bec501f4a0 100644 --- a/datafusion/expr/src/expr_rewriter/mod.rs +++ b/datafusion/expr/src/expr_rewriter/mod.rs @@ -215,7 +215,7 @@ pub fn coerce_plan_expr_for_schema( LogicalPlan::Projection(Projection { expr, input, .. }) => { let new_exprs = coerce_exprs_for_schema(expr.clone(), input.schema(), schema)?; - let projection = Projection::try_new(new_exprs, input.clone())?; + let projection = Projection::try_new(new_exprs, Arc::clone(input))?; Ok(LogicalPlan::Projection(projection)) } _ => { diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index 8bb655eda575..a84931398f5b 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -520,7 +520,7 @@ pub fn cast_subquery(subquery: Subquery, cast_to_type: &DataType) -> Result { diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index 5f1d3c9d5c6b..e1943c890e7c 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -14,6 +14,8 @@ // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. +// Make cheap clones clear: https://github.com/apache/datafusion/issues/11143 +#![deny(clippy::clone_on_ref_ptr)] //! [DataFusion](https://github.com/apache/datafusion) //! 
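// With `#![deny(clippy::clone_on_ref_ptr)]` in force, calling `.clone()` on an
// `Arc` or `Rc` is promoted from a clippy lint to a compile error, so every
// cheap reference-count bump must be spelled `Arc::clone(&x)` and is visible at
// the call site. A minimal generic sketch of the rewrite this patch applies
// across the crate (`share` is illustrative, not part of the patch):
fn share<T>(ptr: &std::sync::Arc<T>) -> std::sync::Arc<T> {
    // was: ptr.clone(), which clone_on_ref_ptr now rejects
    std::sync::Arc::clone(ptr)
}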
is an extensible query execution framework that uses diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index cc4348d58c33..4ad3bd5018a4 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -1223,17 +1223,17 @@ pub fn build_join_schema( JoinType::Inner => { // left then right let left_fields = left_fields - .map(|(q, f)| (q.cloned(), f.clone())) + .map(|(q, f)| (q.cloned(), Arc::clone(f))) .collect::>(); let right_fields = right_fields - .map(|(q, f)| (q.cloned(), f.clone())) + .map(|(q, f)| (q.cloned(), Arc::clone(f))) .collect::>(); left_fields.into_iter().chain(right_fields).collect() } JoinType::Left => { // left then right, right set to nullable in case of not matched scenario let left_fields = left_fields - .map(|(q, f)| (q.cloned(), f.clone())) + .map(|(q, f)| (q.cloned(), Arc::clone(f))) .collect::>(); left_fields .into_iter() @@ -1243,7 +1243,7 @@ pub fn build_join_schema( JoinType::Right => { // left then right, left set to nullable in case of not matched scenario let right_fields = right_fields - .map(|(q, f)| (q.cloned(), f.clone())) + .map(|(q, f)| (q.cloned(), Arc::clone(f))) .collect::>(); nullify_fields(left_fields) .into_iter() @@ -1259,11 +1259,15 @@ pub fn build_join_schema( } JoinType::LeftSemi | JoinType::LeftAnti => { // Only use the left side for the schema - left_fields.map(|(q, f)| (q.cloned(), f.clone())).collect() + left_fields + .map(|(q, f)| (q.cloned(), Arc::clone(f))) + .collect() } JoinType::RightSemi | JoinType::RightAnti => { // Only use the right side for the schema - right_fields.map(|(q, f)| (q.cloned(), f.clone())).collect() + right_fields + .map(|(q, f)| (q.cloned(), Arc::clone(f))) + .collect() } }; let func_dependencies = left.functional_dependencies().join( @@ -1577,7 +1581,7 @@ impl TableSource for LogicalTableSource { } fn schema(&self) -> SchemaRef { - self.table_schema.clone() + Arc::clone(&self.table_schema) } fn supports_filters_pushdown( @@ -1691,7 +1695,10 @@ pub fn unnest_with_options( } None => { dependency_indices.push(index); - Ok(vec![(original_qualifier.cloned(), original_field.clone())]) + Ok(vec![( + original_qualifier.cloned(), + Arc::clone(original_field), + )]) } } }) diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 2921541934f8..8fd5982a0f2e 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -762,9 +762,9 @@ impl LogicalPlan { // If inputs are not pruned do not change schema // TODO this seems wrong (shouldn't we always use the schema of the input?) let schema = if schema.fields().len() == input_schema.fields().len() { - schema.clone() + Arc::clone(&schema) } else { - input_schema.clone() + Arc::clone(input_schema) }; Ok(LogicalPlan::Union(Union { inputs, schema })) } @@ -850,7 +850,7 @@ impl LogicalPlan { .. }) => Ok(LogicalPlan::Dml(DmlStatement::new( table_name.clone(), - table_schema.clone(), + Arc::clone(table_schema), op.clone(), Arc::new(inputs.swap_remove(0)), ))), @@ -863,13 +863,13 @@ impl LogicalPlan { }) => Ok(LogicalPlan::Copy(CopyTo { input: Arc::new(inputs.swap_remove(0)), output_url: output_url.clone(), - file_type: file_type.clone(), + file_type: Arc::clone(file_type), options: options.clone(), partition_by: partition_by.clone(), })), LogicalPlan::Values(Values { schema, .. 
}) => { Ok(LogicalPlan::Values(Values { - schema: schema.clone(), + schema: Arc::clone(schema), values: expr .chunks_exact(schema.fields().len()) .map(|s| s.to_vec()) @@ -1027,9 +1027,9 @@ impl LogicalPlan { let input_schema = inputs[0].schema(); // If inputs are not pruned do not change schema. let schema = if schema.fields().len() == input_schema.fields().len() { - schema.clone() + Arc::clone(schema) } else { - input_schema.clone() + Arc::clone(input_schema) }; Ok(LogicalPlan::Union(Union { inputs: inputs.into_iter().map(Arc::new).collect(), @@ -1073,7 +1073,7 @@ impl LogicalPlan { assert_eq!(inputs.len(), 1); Ok(LogicalPlan::Analyze(Analyze { verbose: a.verbose, - schema: a.schema.clone(), + schema: Arc::clone(&a.schema), input: Arc::new(inputs.swap_remove(0)), })) } @@ -1087,7 +1087,7 @@ impl LogicalPlan { verbose: e.verbose, plan: Arc::new(inputs.swap_remove(0)), stringified_plans: e.stringified_plans.clone(), - schema: e.schema.clone(), + schema: Arc::clone(&e.schema), logical_optimization_succeeded: e.logical_optimization_succeeded, })) } @@ -1369,7 +1369,7 @@ impl LogicalPlan { param_values: &ParamValues, ) -> Result { self.transform_up_with_subqueries(|plan| { - let schema = plan.schema().clone(); + let schema = Arc::clone(plan.schema()); plan.map_expressions(|e| { e.infer_placeholder_types(&schema)?.transform_up(|e| { if let Expr::Placeholder(Placeholder { id, .. }) = e { @@ -2227,7 +2227,7 @@ impl Window { let fields: Vec<(Option, Arc)> = input .schema() .iter() - .map(|(q, f)| (q.cloned(), f.clone())) + .map(|(q, f)| (q.cloned(), Arc::clone(f))) .collect(); let input_len = fields.len(); let mut window_fields = fields; @@ -3352,7 +3352,7 @@ digraph { vec![col("a")], Arc::new(LogicalPlan::EmptyRelation(EmptyRelation { produce_one_row: false, - schema: empty_schema.clone(), + schema: Arc::clone(&empty_schema), })), empty_schema, ); @@ -3467,9 +3467,9 @@ digraph { ); let scan = Arc::new(LogicalPlan::TableScan(TableScan { table_name: TableReference::bare("tab"), - source: source.clone(), + source: Arc::clone(&source) as Arc, projection: None, - projected_schema: schema.clone(), + projected_schema: Arc::clone(&schema), filters: vec![], fetch: None, })); @@ -3499,7 +3499,7 @@ digraph { table_name: TableReference::bare("tab"), source, projection: None, - projected_schema: unique_schema.clone(), + projected_schema: Arc::clone(&unique_schema), filters: vec![], fetch: None, })); diff --git a/datafusion/expr/src/type_coercion/binary.rs b/datafusion/expr/src/type_coercion/binary.rs index 5645a2a4dede..442a33bebc99 100644 --- a/datafusion/expr/src/type_coercion/binary.rs +++ b/datafusion/expr/src/type_coercion/binary.rs @@ -1048,16 +1048,16 @@ fn temporal_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option Some(lhs_tz.clone()), - (lhs, rhs) if lhs == rhs => Some(lhs_tz.clone()), + ("UTC", "+00:00") | ("+00:00", "UTC") => Some(Arc::clone(lhs_tz)), + (lhs, rhs) if lhs == rhs => Some(Arc::clone(lhs_tz)), // can't cast across timezones _ => { return None; } } } - (Some(lhs_tz), None) => Some(lhs_tz.clone()), - (None, Some(rhs_tz)) => Some(rhs_tz.clone()), + (Some(lhs_tz), None) => Some(Arc::clone(lhs_tz)), + (None, Some(rhs_tz)) => Some(Arc::clone(rhs_tz)), (None, None) => None, }; diff --git a/datafusion/expr/src/type_coercion/functions.rs b/datafusion/expr/src/type_coercion/functions.rs index 5f060a4a4f16..f9f467098ee4 100644 --- a/datafusion/expr/src/type_coercion/functions.rs +++ b/datafusion/expr/src/type_coercion/functions.rs @@ -598,7 +598,7 @@ fn coerced_from<'a>( 
Arc::new(f_into.as_ref().clone().with_data_type(data_type)); Some(FixedSizeList(new_field, *size_from)) } - Some(_) => Some(FixedSizeList(f_into.clone(), *size_from)), + Some(_) => Some(FixedSizeList(Arc::clone(f_into), *size_from)), _ => None, } } @@ -607,7 +607,7 @@ fn coerced_from<'a>( (Timestamp(unit, Some(tz)), _) if tz.as_ref() == TIMEZONE_WILDCARD => { match type_from { Timestamp(_, Some(from_tz)) => { - Some(Timestamp(unit.clone(), Some(from_tz.clone()))) + Some(Timestamp(unit.clone(), Some(Arc::clone(from_tz)))) } Null | Date32 | Utf8 | LargeUtf8 | Timestamp(_, None) => { // In the absence of any other information assume the time zone is "+00" (UTC). @@ -715,12 +715,12 @@ mod tests { fn test_fixed_list_wildcard_coerce() -> Result<()> { let inner = Arc::new(Field::new("item", DataType::Int32, false)); let current_types = vec![ - DataType::FixedSizeList(inner.clone(), 2), // able to coerce for any size + DataType::FixedSizeList(Arc::clone(&inner), 2), // able to coerce for any size ]; let signature = Signature::exact( vec![DataType::FixedSizeList( - inner.clone(), + Arc::clone(&inner), FIXED_SIZE_LIST_WILDCARD, )], Volatility::Stable, @@ -731,7 +731,7 @@ mod tests { // make sure it can't coerce to a different size let signature = Signature::exact( - vec![DataType::FixedSizeList(inner.clone(), 3)], + vec![DataType::FixedSizeList(Arc::clone(&inner), 3)], Volatility::Stable, ); let coerced_data_types = data_types(¤t_types, &signature); @@ -739,7 +739,7 @@ mod tests { // make sure it works with the same type. let signature = Signature::exact( - vec![DataType::FixedSizeList(inner.clone(), 2)], + vec![DataType::FixedSizeList(Arc::clone(&inner), 2)], Volatility::Stable, ); let coerced_data_types = data_types(¤t_types, &signature).unwrap(); diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index c8362691452b..7a054abea75b 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -106,8 +106,8 @@ impl AggregateUDF { Self::new_from_impl(AggregateUDFLegacyWrapper { name: name.to_owned(), signature: signature.clone(), - return_type: return_type.clone(), - accumulator: accumulator.clone(), + return_type: Arc::clone(return_type), + accumulator: Arc::clone(accumulator), }) } @@ -133,7 +133,10 @@ impl AggregateUDF { /// /// If you implement [`AggregateUDFImpl`] directly you should return aliases directly. pub fn with_aliases(self, aliases: impl IntoIterator) -> Self { - Self::new_from_impl(AliasedAggregateUDFImpl::new(self.inner.clone(), aliases)) + Self::new_from_impl(AliasedAggregateUDFImpl::new( + Arc::clone(&self.inner), + aliases, + )) } /// creates an [`Expr`] that calls the aggregate function. diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs index 03650b1d4748..68d3af6ace3c 100644 --- a/datafusion/expr/src/udf.rs +++ b/datafusion/expr/src/udf.rs @@ -87,8 +87,8 @@ impl ScalarUDF { Self::new_from_impl(ScalarUdfLegacyWrapper { name: name.to_owned(), signature: signature.clone(), - return_type: return_type.clone(), - fun: fun.clone(), + return_type: Arc::clone(return_type), + fun: Arc::clone(fun), }) } @@ -114,7 +114,7 @@ impl ScalarUDF { /// /// If you implement [`ScalarUDFImpl`] directly you should return aliases directly. 
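// `with_aliases` wraps the existing implementation rather than copying it: the
// aliased UDF and the original share one `Arc`-owned inner value. A generic
// sketch of that shape (the `Aliased` type is illustrative, not the DataFusion
// API):
use std::sync::Arc;

struct Aliased<T: ?Sized> {
    inner: Arc<T>,
    aliases: Vec<String>,
}

impl<T: ?Sized> Aliased<T> {
    fn new(inner: &Arc<T>, aliases: impl IntoIterator<Item = String>) -> Self {
        Self {
            // explicit refcount bump; a bare `.clone()` is denied by the lint
            inner: Arc::clone(inner),
            aliases: aliases.into_iter().collect(),
        }
    }
}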
pub fn with_aliases(self, aliases: impl IntoIterator) -> Self { - Self::new_from_impl(AliasedScalarUDFImpl::new(self.inner.clone(), aliases)) + Self::new_from_impl(AliasedScalarUDFImpl::new(Arc::clone(&self.inner), aliases)) } /// Returns a [`Expr`] logical expression to call this UDF with specified @@ -199,7 +199,7 @@ impl ScalarUDF { /// Returns a `ScalarFunctionImplementation` that can invoke the function /// during execution pub fn fun(&self) -> ScalarFunctionImplementation { - let captured = self.inner.clone(); + let captured = Arc::clone(&self.inner); Arc::new(move |args| captured.invoke(args)) } diff --git a/datafusion/expr/src/udwf.rs b/datafusion/expr/src/udwf.rs index a17bb0ade8e3..70b44e5e307a 100644 --- a/datafusion/expr/src/udwf.rs +++ b/datafusion/expr/src/udwf.rs @@ -90,8 +90,8 @@ impl WindowUDF { Self::new_from_impl(WindowUDFLegacyWrapper { name: name.to_owned(), signature: signature.clone(), - return_type: return_type.clone(), - partition_evaluator_factory: partition_evaluator_factory.clone(), + return_type: Arc::clone(return_type), + partition_evaluator_factory: Arc::clone(partition_evaluator_factory), }) } @@ -117,7 +117,7 @@ impl WindowUDF { /// /// If you implement [`WindowUDFImpl`] directly you should return aliases directly. pub fn with_aliases(self, aliases: impl IntoIterator) -> Self { - Self::new_from_impl(AliasedWindowUDFImpl::new(self.inner.clone(), aliases)) + Self::new_from_impl(AliasedWindowUDFImpl::new(Arc::clone(&self.inner), aliases)) } /// creates a [`Expr`] that calls the window function given diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index 286f05309ea7..e3b8db676c98 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -1199,7 +1199,7 @@ pub fn only_or_err(slice: &[T]) -> Result<&T> { /// merge inputs schema into a single schema. 
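// The `ScalarUDF::fun` change above shows the standard closure-capture idiom:
// clone the `Arc` before the `move` closure so the closure owns its own handle
// while `self` keeps the original. A self-contained sketch of the same capture,
// using plain `i64` state instead of the real UDF types:
fn make_adder(state: &std::sync::Arc<i64>) -> impl Fn(i64) -> i64 {
    // the closure gets an independent handle that outlives the borrow of `state`
    let captured = std::sync::Arc::clone(state);
    move |x| x + *captured
}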
pub fn merge_schema(inputs: Vec<&LogicalPlan>) -> DFSchema { if inputs.len() == 1 { - inputs[0].schema().clone().as_ref().clone() + inputs[0].schema().as_ref().clone() } else { inputs.iter().map(|input| input.schema()).fold( DFSchema::empty(), From b46d5b7fc65a647d51339f3e6524879aee9810fa Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Fri, 5 Jul 2024 07:18:05 -0400 Subject: [PATCH 02/26] Optimize PushDownFilter to avoid recreating schema columns (#11211) --- datafusion/optimizer/src/push_down_filter.rs | 80 +++++++++++++------- 1 file changed, 53 insertions(+), 27 deletions(-) diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index 8e310d1f4e8a..1c3186b762b7 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -194,9 +194,50 @@ fn on_lr_is_preserved(join_type: JoinType) -> Result<(bool, bool)> { } } -/// Return true if a predicate only references columns in the specified schema -fn can_pushdown_join_predicate(predicate: &Expr, schema: &DFSchema) -> Result { - let schema_columns = schema +/// Evaluates the columns referenced in the given expression to see if they refer +/// only to the left or right columns +#[derive(Debug)] +struct ColumnChecker<'a> { + /// schema of left join input + left_schema: &'a DFSchema, + /// columns in left_schema, computed on demand + left_columns: Option>, + /// schema of right join input + right_schema: &'a DFSchema, + /// columns in left_schema, computed on demand + right_columns: Option>, +} + +impl<'a> ColumnChecker<'a> { + fn new(left_schema: &'a DFSchema, right_schema: &'a DFSchema) -> Self { + Self { + left_schema, + left_columns: None, + right_schema, + right_columns: None, + } + } + + /// Return true if the expression references only columns from the left side of the join + fn is_left_only(&mut self, predicate: &Expr) -> bool { + if self.left_columns.is_none() { + self.left_columns = Some(schema_columns(self.left_schema)); + } + has_all_column_refs(predicate, self.left_columns.as_ref().unwrap()) + } + + /// Return true if the expression references only columns from the right side of the join + fn is_right_only(&mut self, predicate: &Expr) -> bool { + if self.right_columns.is_none() { + self.right_columns = Some(schema_columns(self.right_schema)); + } + has_all_column_refs(predicate, self.right_columns.as_ref().unwrap()) + } +} + +/// Returns all columns in the schema +fn schema_columns(schema: &DFSchema) -> HashSet { + schema .iter() .flat_map(|(qualifier, field)| { [ @@ -205,8 +246,7 @@ fn can_pushdown_join_predicate(predicate: &Expr, schema: &DFSchema) -> Result>(); - Ok(has_all_column_refs(predicate, &schema_columns)) + .collect::>() } /// Determine whether the predicate can evaluate as the join conditions @@ -291,16 +331,7 @@ fn extract_or_clauses_for_join<'a>( filters: &'a [Expr], schema: &'a DFSchema, ) -> impl Iterator + 'a { - let schema_columns = schema - .iter() - .flat_map(|(qualifier, field)| { - [ - Column::new(qualifier.cloned(), field.name()), - // we need to push down filter using unqualified column as well - Column::new_unqualified(field.name()), - ] - }) - .collect::>(); + let schema_columns = schema_columns(schema); // new formed OR clauses and their column references filters.iter().filter_map(move |expr| { @@ -403,12 +434,11 @@ fn push_down_all_join( let mut right_push = vec![]; let mut keep_predicates = vec![]; let mut join_conditions = vec![]; + let mut checker = ColumnChecker::new(left_schema, right_schema); for 
predicate in predicates { - if left_preserved && can_pushdown_join_predicate(&predicate, left_schema)? { + if left_preserved && checker.is_left_only(&predicate) { left_push.push(predicate); - } else if right_preserved - && can_pushdown_join_predicate(&predicate, right_schema)? - { + } else if right_preserved && checker.is_right_only(&predicate) { right_push.push(predicate); } else if is_inner_join && can_evaluate_as_join_condition(&predicate)? { // Here we do not differ it is eq or non-eq predicate, ExtractEquijoinPredicate will extract the eq predicate @@ -421,11 +451,9 @@ fn push_down_all_join( // For infer predicates, if they can not push through join, just drop them for predicate in inferred_join_predicates { - if left_preserved && can_pushdown_join_predicate(&predicate, left_schema)? { + if left_preserved && checker.is_left_only(&predicate) { left_push.push(predicate); - } else if right_preserved - && can_pushdown_join_predicate(&predicate, right_schema)? - { + } else if right_preserved && checker.is_right_only(&predicate) { right_push.push(predicate); } } @@ -435,11 +463,9 @@ fn push_down_all_join( if !on_filter.is_empty() { for on in on_filter { - if on_left_preserved && can_pushdown_join_predicate(&on, left_schema)? { + if on_left_preserved && checker.is_left_only(&on) { left_push.push(on) - } else if on_right_preserved - && can_pushdown_join_predicate(&on, right_schema)? - { + } else if on_right_preserved && checker.is_right_only(&on) { right_push.push(on) } else { on_filter_join_conditions.push(on) From 351e5f95646035a8f742811af0a61d00480fefc2 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Fri, 5 Jul 2024 07:49:06 -0400 Subject: [PATCH 03/26] Remove outdated `rewrite_expr.rs` example (#11085) * Remove outdated rewrite_expr.rs example * Update docs --- datafusion-examples/README.md | 1 - datafusion-examples/examples/rewrite_expr.rs | 251 ------------------ datafusion/optimizer/README.md | 6 +- .../library-user-guide/working-with-exprs.md | 6 +- 4 files changed, 9 insertions(+), 255 deletions(-) delete mode 100644 datafusion-examples/examples/rewrite_expr.rs diff --git a/datafusion-examples/README.md b/datafusion-examples/README.md index a208eee13587..dc92019035dd 100644 --- a/datafusion-examples/README.md +++ b/datafusion-examples/README.md @@ -72,7 +72,6 @@ cargo run --example dataframe - [`query-aws-s3.rs`](examples/external_dependency/query-aws-s3.rs): Configure `object_store` and run a query against files stored in AWS S3 - [`query-http-csv.rs`](examples/query-http-csv.rs): Configure `object_store` and run a query against files vi HTTP - [`regexp.rs`](examples/regexp.rs): Examples of using regular expression functions -- [`rewrite_expr.rs`](examples/rewrite_expr.rs): Define and invoke a custom Query Optimizer pass - [`simple_udaf.rs`](examples/simple_udaf.rs): Define and invoke a User Defined Aggregate Function (UDAF) - [`simple_udf.rs`](examples/simple_udf.rs): Define and invoke a User Defined Scalar Function (UDF) - [`simple_udfw.rs`](examples/simple_udwf.rs): Define and invoke a User Defined Window Function (UDWF) diff --git a/datafusion-examples/examples/rewrite_expr.rs b/datafusion-examples/examples/rewrite_expr.rs deleted file mode 100644 index 06286d5d66ed..000000000000 --- a/datafusion-examples/examples/rewrite_expr.rs +++ /dev/null @@ -1,251 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. 
The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; -use datafusion_common::config::ConfigOptions; -use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; -use datafusion_common::{plan_err, DataFusionError, Result, ScalarValue}; -use datafusion_expr::{ - AggregateUDF, Between, Expr, Filter, LogicalPlan, ScalarUDF, TableSource, WindowUDF, -}; -use datafusion_optimizer::analyzer::{Analyzer, AnalyzerRule}; -use datafusion_optimizer::optimizer::{ApplyOrder, Optimizer}; -use datafusion_optimizer::{OptimizerConfig, OptimizerContext, OptimizerRule}; -use datafusion_sql::planner::{ContextProvider, SqlToRel}; -use datafusion_sql::sqlparser::dialect::PostgreSqlDialect; -use datafusion_sql::sqlparser::parser::Parser; -use datafusion_sql::TableReference; -use std::any::Any; -use std::sync::Arc; - -pub fn main() -> Result<()> { - // produce a logical plan using the datafusion-sql crate - let dialect = PostgreSqlDialect {}; - let sql = "SELECT * FROM person WHERE age BETWEEN 21 AND 32"; - let statements = Parser::parse_sql(&dialect, sql)?; - - // produce a logical plan using the datafusion-sql crate - let context_provider = MyContextProvider::default(); - let sql_to_rel = SqlToRel::new(&context_provider); - let logical_plan = sql_to_rel.sql_statement_to_plan(statements[0].clone())?; - println!( - "Unoptimized Logical Plan:\n\n{}\n", - logical_plan.display_indent() - ); - - // run the analyzer with our custom rule - let config = OptimizerContext::default().with_skip_failing_rules(false); - let analyzer = Analyzer::with_rules(vec![Arc::new(MyAnalyzerRule {})]); - let analyzed_plan = - analyzer.execute_and_check(logical_plan, config.options(), |_, _| {})?; - println!( - "Analyzed Logical Plan:\n\n{}\n", - analyzed_plan.display_indent() - ); - - // then run the optimizer with our custom rule - let optimizer = Optimizer::with_rules(vec![Arc::new(MyOptimizerRule {})]); - let optimized_plan = optimizer.optimize(analyzed_plan, &config, observe)?; - println!( - "Optimized Logical Plan:\n\n{}\n", - optimized_plan.display_indent() - ); - - Ok(()) -} - -fn observe(plan: &LogicalPlan, rule: &dyn OptimizerRule) { - println!( - "After applying rule '{}':\n{}\n", - rule.name(), - plan.display_indent() - ) -} - -/// An example analyzer rule that changes Int64 literals to UInt64 -struct MyAnalyzerRule {} - -impl AnalyzerRule for MyAnalyzerRule { - fn analyze(&self, plan: LogicalPlan, _config: &ConfigOptions) -> Result { - Self::analyze_plan(plan) - } - - fn name(&self) -> &str { - "my_analyzer_rule" - } -} - -impl MyAnalyzerRule { - fn analyze_plan(plan: LogicalPlan) -> Result { - plan.transform(|plan| { - Ok(match plan { - LogicalPlan::Filter(filter) => { - let predicate = Self::analyze_expr(filter.predicate.clone())?; - Transformed::yes(LogicalPlan::Filter(Filter::try_new( - predicate, - filter.input, - )?)) - } - _ => Transformed::no(plan), - }) - }) - .data() - } - - fn analyze_expr(expr: Expr) -> Result 
{ - expr.transform(|expr| { - // closure is invoked for all sub expressions - Ok(match expr { - Expr::Literal(ScalarValue::Int64(i)) => { - // transform to UInt64 - Transformed::yes(Expr::Literal(ScalarValue::UInt64( - i.map(|i| i as u64), - ))) - } - _ => Transformed::no(expr), - }) - }) - .data() - } -} - -/// An example optimizer rule that rewrite BETWEEN expression to binary compare expressions -struct MyOptimizerRule {} - -impl OptimizerRule for MyOptimizerRule { - fn name(&self) -> &str { - "my_optimizer_rule" - } - - fn apply_order(&self) -> Option { - Some(ApplyOrder::BottomUp) - } - - fn supports_rewrite(&self) -> bool { - true - } - - fn rewrite( - &self, - plan: LogicalPlan, - _config: &dyn OptimizerConfig, - ) -> Result, DataFusionError> { - match plan { - LogicalPlan::Filter(filter) => { - let predicate = my_rewrite(filter.predicate.clone())?; - Ok(Transformed::yes(LogicalPlan::Filter(Filter::try_new( - predicate, - filter.input.clone(), - )?))) - } - _ => Ok(Transformed::no(plan)), - } - } -} - -/// use rewrite_expr to modify the expression tree. -fn my_rewrite(expr: Expr) -> Result { - expr.transform(|expr| { - // closure is invoked for all sub expressions - Ok(match expr { - Expr::Between(Between { - expr, - negated, - low, - high, - }) => { - // unbox - let expr: Expr = *expr; - let low: Expr = *low; - let high: Expr = *high; - if negated { - Transformed::yes(expr.clone().lt(low).or(expr.gt(high))) - } else { - Transformed::yes(expr.clone().gt_eq(low).and(expr.lt_eq(high))) - } - } - _ => Transformed::no(expr), - }) - }) - .data() -} - -#[derive(Default)] -struct MyContextProvider { - options: ConfigOptions, -} - -impl ContextProvider for MyContextProvider { - fn get_table_source(&self, name: TableReference) -> Result> { - if name.table() == "person" { - Ok(Arc::new(MyTableSource { - schema: Arc::new(Schema::new(vec![ - Field::new("name", DataType::Utf8, false), - Field::new("age", DataType::UInt8, false), - ])), - })) - } else { - plan_err!("table not found") - } - } - - fn get_function_meta(&self, _name: &str) -> Option> { - None - } - - fn get_aggregate_meta(&self, _name: &str) -> Option> { - None - } - - fn get_variable_type(&self, _variable_names: &[String]) -> Option { - None - } - - fn get_window_meta(&self, _name: &str) -> Option> { - None - } - - fn options(&self) -> &ConfigOptions { - &self.options - } - - fn udf_names(&self) -> Vec { - Vec::new() - } - - fn udaf_names(&self) -> Vec { - Vec::new() - } - - fn udwf_names(&self) -> Vec { - Vec::new() - } -} - -struct MyTableSource { - schema: SchemaRef, -} - -impl TableSource for MyTableSource { - fn as_any(&self) -> &dyn Any { - self - } - - fn schema(&self) -> SchemaRef { - self.schema.clone() - } -} diff --git a/datafusion/optimizer/README.md b/datafusion/optimizer/README.md index 2f1f85e3a57a..5aacfaf59cb1 100644 --- a/datafusion/optimizer/README.md +++ b/datafusion/optimizer/README.md @@ -67,8 +67,10 @@ let optimizer = Optimizer::with_rules(vec![ ## Writing Optimization Rules -Please refer to the [rewrite_expr example](../../datafusion-examples/examples/rewrite_expr.rs) to learn more about -the general approach to writing optimizer rules and then move onto studying the existing rules. +Please refer to the +[optimizer_rule.rs](../../datafusion-examples/examples/optimizer_rule.rs) +example to learn more about the general approach to writing optimizer rules and +then move onto studying the existing rules. All rules must implement the `OptimizerRule` trait. 
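For orientation, a minimal rule has the same shape as the `MyOptimizerRule` in the example removed above. The following sketch is modeled on that trait surface and performs no rewrite (the `NoopRule` name is ours, not an existing rule):

```rust
use datafusion_common::tree_node::Transformed;
use datafusion_common::{DataFusionError, Result};
use datafusion_expr::LogicalPlan;
use datafusion_optimizer::optimizer::ApplyOrder;
use datafusion_optimizer::{OptimizerConfig, OptimizerRule};

struct NoopRule;

impl OptimizerRule for NoopRule {
    fn name(&self) -> &str {
        "noop_rule"
    }

    fn apply_order(&self) -> Option<ApplyOrder> {
        // rewrite children before their parents
        Some(ApplyOrder::BottomUp)
    }

    fn supports_rewrite(&self) -> bool {
        true
    }

    fn rewrite(
        &self,
        plan: LogicalPlan,
        _config: &dyn OptimizerConfig,
    ) -> Result<Transformed<LogicalPlan>, DataFusionError> {
        // `Transformed::no` tells the optimizer nothing changed
        Ok(Transformed::no(plan))
    }
}
```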
diff --git a/docs/source/library-user-guide/working-with-exprs.md b/docs/source/library-user-guide/working-with-exprs.md index e0c9e69eb6ed..e0b6f434a032 100644 --- a/docs/source/library-user-guide/working-with-exprs.md +++ b/docs/source/library-user-guide/working-with-exprs.md @@ -80,7 +80,11 @@ If you'd like to learn more about `Expr`s, before we get into the details of cre ## Rewriting `Expr`s -[rewrite_expr.rs](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/rewrite_expr.rs) contains example code for rewriting `Expr`s. +There are several examples of rewriting and working with `Exprs`: + +- [expr_api.rs](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/expr_api.rs) +- [analyzer_rule.rs](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/analyzer_rule.rs) +- [optimizer_rule.rs](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/optimizer_rule.rs) Rewriting Expressions is the process of taking an `Expr` and transforming it into another `Expr`. This is useful for a number of reasons, including: From 0d2525e6eaebea8cc3ee94c249a61ab5c2d55a81 Mon Sep 17 00:00:00 2001 From: Lordworms <48054792+Lordworms@users.noreply.github.com> Date: Fri, 5 Jul 2024 04:49:24 -0700 Subject: [PATCH 04/26] Implement TPCH substrait integration test, support tpch_2 (#11234) * integrate tpch query 2 avoid cloning optimize code optimize code * optimize code * refactor code * format --- .../substrait/src/logical_plan/consumer.rs | 142 +- .../tests/cases/consumer_integration.rs | 93 +- .../substrait/tests/testdata/tpch/nation.csv | 2 + .../substrait/tests/testdata/tpch/part.csv | 2 + .../tests/testdata/tpch/partsupp.csv | 2 + .../substrait/tests/testdata/tpch/region.csv | 2 + .../tests/testdata/tpch/supplier.csv | 2 + .../tpch_substrait_plans/query_2.json | 1582 +++++++++++++++++ 8 files changed, 1769 insertions(+), 58 deletions(-) create mode 100644 datafusion/substrait/tests/testdata/tpch/nation.csv create mode 100644 datafusion/substrait/tests/testdata/tpch/part.csv create mode 100644 datafusion/substrait/tests/testdata/tpch/partsupp.csv create mode 100644 datafusion/substrait/tests/testdata/tpch/region.csv create mode 100644 datafusion/substrait/tests/testdata/tpch/supplier.csv create mode 100644 datafusion/substrait/tests/testdata/tpch_substrait_plans/query_2.json diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index af8dd60f6566..cc10ea0619c1 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -15,25 +15,38 @@ // specific language governing permissions and limitations // under the License.
+use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano}; use async_recursion::async_recursion; +use datafusion::arrow::array::GenericListArray; use datafusion::arrow::datatypes::{ DataType, Field, FieldRef, Fields, IntervalUnit, Schema, TimeUnit, }; +use datafusion::common::plan_err; use datafusion::common::{ - not_impl_datafusion_err, not_impl_err, plan_datafusion_err, plan_err, - substrait_datafusion_err, substrait_err, DFSchema, DFSchemaRef, + not_impl_datafusion_err, not_impl_err, plan_datafusion_err, substrait_datafusion_err, + substrait_err, DFSchema, DFSchemaRef, }; -use substrait::proto::expression::literal::IntervalDayToSecond; -use substrait::proto::read_rel::local_files::file_or_files::PathType::UriFile; -use url::Url; - -use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano}; use datafusion::execution::FunctionRegistry; +use datafusion::logical_expr::expr::{InSubquery, Sort}; + use datafusion::logical_expr::{ aggregate_function, expr::find_df_window_func, Aggregate, BinaryExpr, Case, EmptyRelation, Expr, ExprSchemable, LogicalPlan, Operator, Projection, Values, }; +use url::Url; +use crate::variation_const::{ + DATE_32_TYPE_VARIATION_REF, DATE_64_TYPE_VARIATION_REF, + DECIMAL_128_TYPE_VARIATION_REF, DECIMAL_256_TYPE_VARIATION_REF, + DEFAULT_CONTAINER_TYPE_VARIATION_REF, DEFAULT_TYPE_VARIATION_REF, + INTERVAL_DAY_TIME_TYPE_REF, INTERVAL_MONTH_DAY_NANO_TYPE_REF, + INTERVAL_YEAR_MONTH_TYPE_REF, LARGE_CONTAINER_TYPE_VARIATION_REF, + TIMESTAMP_MICRO_TYPE_VARIATION_REF, TIMESTAMP_MILLI_TYPE_VARIATION_REF, + TIMESTAMP_NANO_TYPE_VARIATION_REF, TIMESTAMP_SECOND_TYPE_VARIATION_REF, + UNSIGNED_INTEGER_TYPE_VARIATION_REF, +}; +use datafusion::common::scalar::ScalarStructBuilder; +use datafusion::logical_expr::expr::InList; use datafusion::logical_expr::{ col, expr, Cast, Extension, GroupingSet, Like, LogicalPlanBuilder, Partitioning, Repartition, Subquery, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, @@ -46,10 +59,15 @@ use datafusion::{ prelude::{Column, SessionContext}, scalar::ScalarValue, }; +use std::collections::HashMap; +use std::str::FromStr; +use std::sync::Arc; use substrait::proto::exchange_rel::ExchangeKind; use substrait::proto::expression::literal::user_defined::Val; +use substrait::proto::expression::literal::IntervalDayToSecond; use substrait::proto::expression::subquery::SubqueryType; use substrait::proto::expression::{self, FieldReference, Literal, ScalarFunction}; +use substrait::proto::read_rel::local_files::file_or_files::PathType::UriFile; use substrait::proto::{ aggregate_function::AggregationInvocation, expression::{ @@ -70,24 +88,6 @@ use substrait::proto::{ }; use substrait::proto::{FunctionArgument, SortField}; -use datafusion::arrow::array::GenericListArray; -use datafusion::common::scalar::ScalarStructBuilder; -use datafusion::logical_expr::expr::{InList, InSubquery, Sort}; -use std::collections::HashMap; -use std::str::FromStr; -use std::sync::Arc; - -use crate::variation_const::{ - DATE_32_TYPE_VARIATION_REF, DATE_64_TYPE_VARIATION_REF, - DECIMAL_128_TYPE_VARIATION_REF, DECIMAL_256_TYPE_VARIATION_REF, - DEFAULT_CONTAINER_TYPE_VARIATION_REF, DEFAULT_TYPE_VARIATION_REF, - INTERVAL_DAY_TIME_TYPE_REF, INTERVAL_MONTH_DAY_NANO_TYPE_REF, - INTERVAL_YEAR_MONTH_TYPE_REF, LARGE_CONTAINER_TYPE_VARIATION_REF, - TIMESTAMP_MICRO_TYPE_VARIATION_REF, TIMESTAMP_MILLI_TYPE_VARIATION_REF, - TIMESTAMP_NANO_TYPE_VARIATION_REF, TIMESTAMP_SECOND_TYPE_VARIATION_REF, - UNSIGNED_INTEGER_TYPE_VARIATION_REF, -}; - pub fn name_to_op(name: &str) -> Result { match 
name { "equal" => Ok(Operator::Eq), @@ -1125,17 +1125,32 @@ pub async fn from_substrait_rex( expr::ScalarFunction::new_udf(func.to_owned(), args), ))) } else if let Ok(op) = name_to_op(fn_name) { - if args.len() != 2 { + if f.arguments.len() < 2 { return not_impl_err!( - "Expect two arguments for binary operator {op:?}" + "Expect at least two arguments for binary operator {op:?}, the provided number of operators is {:?}", + f.arguments.len() ); } + // Some expressions are binary in DataFusion but take in a variadic number of args in Substrait. + // In those cases we iterate through all the arguments, applying the binary expression against them all + let combined_expr = args + .into_iter() + .fold(None, |combined_expr: Option>, arg: Expr| { + Some(match combined_expr { + Some(expr) => Arc::new(Expr::BinaryExpr(BinaryExpr { + left: Box::new( + Arc::try_unwrap(expr) + .unwrap_or_else(|arc: Arc| (*arc).clone()), + ), // Avoid cloning if possible + op: op.clone(), + right: Box::new(arg), + })), + None => Arc::new(arg), + }) + }) + .unwrap(); - Ok(Arc::new(Expr::BinaryExpr(BinaryExpr { - left: Box::new(args[0].to_owned()), - op, - right: Box::new(args[1].to_owned()), - }))) + Ok(combined_expr) } else if let Some(builder) = BuiltinExprBuilder::try_from_name(fn_name) { builder.build(ctx, f, input_schema, extensions).await } else { @@ -1269,7 +1284,22 @@ pub async fn from_substrait_rex( } } } - _ => substrait_err!("Subquery type not implemented"), + SubqueryType::Scalar(query) => { + let plan = from_substrait_rel( + ctx, + &(query.input.clone()).unwrap_or_default(), + extensions, + ) + .await?; + let outer_ref_columns = plan.all_out_ref_exprs(); + Ok(Arc::new(Expr::ScalarSubquery(Subquery { + subquery: Arc::new(plan), + outer_ref_columns, + }))) + } + other_type => { + substrait_err!("Subquery type {:?} not implemented", other_type) + } }, None => { substrait_err!("Subquery experssion without SubqueryType is not allowed") @@ -1699,6 +1729,7 @@ fn from_substrait_literal( })) => { ScalarValue::new_interval_dt(*days, (seconds * 1000) + (microseconds / 1000)) } + Some(LiteralType::FixedChar(c)) => ScalarValue::Utf8(Some(c.clone())), Some(LiteralType::UserDefined(user_defined)) => { match user_defined.type_reference { INTERVAL_YEAR_MONTH_TYPE_REF => { @@ -1988,8 +2019,8 @@ impl BuiltinExprBuilder { extensions: &HashMap, ) -> Result> { let fn_name = if case_insensitive { "ILIKE" } else { "LIKE" }; - if f.arguments.len() != 3 { - return substrait_err!("Expect three arguments for `{fn_name}` expr"); + if f.arguments.len() != 2 && f.arguments.len() != 3 { + return substrait_err!("Expect two or three arguments for `{fn_name}` expr"); } let Some(ArgType::Value(expr_substrait)) = &f.arguments[0].arg_type else { @@ -2007,25 +2038,40 @@ impl BuiltinExprBuilder { .await? .as_ref() .clone(); - let Some(ArgType::Value(escape_char_substrait)) = &f.arguments[2].arg_type else { - return substrait_err!("Invalid arguments type for `{fn_name}` expr"); - }; - let escape_char_expr = - from_substrait_rex(ctx, escape_char_substrait, input_schema, extensions) - .await? 
- .as_ref() - .clone(); - let Expr::Literal(ScalarValue::Utf8(escape_char)) = escape_char_expr else { - return substrait_err!( - "Expect Utf8 literal for escape char, but found {escape_char_expr:?}" - ); + + // Default case: escape character is Literal(Utf8(None)) + let escape_char = if f.arguments.len() == 3 { + let Some(ArgType::Value(escape_char_substrait)) = &f.arguments[2].arg_type + else { + return substrait_err!("Invalid arguments type for `{fn_name}` expr"); + }; + + let escape_char_expr = + from_substrait_rex(ctx, escape_char_substrait, input_schema, extensions) + .await? + .as_ref() + .clone(); + + match escape_char_expr { + Expr::Literal(ScalarValue::Utf8(escape_char_string)) => { + // Convert Option to Option + escape_char_string.and_then(|s| s.chars().next()) + } + _ => { + return substrait_err!( + "Expect Utf8 literal for escape char, but found {escape_char_expr:?}" + ) + } + } + } else { + None }; Ok(Arc::new(Expr::Like(Like { negated: false, expr: Box::new(expr), pattern: Box::new(pattern), - escape_char: escape_char.map(|c| c.chars().next().unwrap()), + escape_char, case_insensitive, }))) } diff --git a/datafusion/substrait/tests/cases/consumer_integration.rs b/datafusion/substrait/tests/cases/consumer_integration.rs index 8ea3a69cab61..58f2fc900937 100644 --- a/datafusion/substrait/tests/cases/consumer_integration.rs +++ b/datafusion/substrait/tests/cases/consumer_integration.rs @@ -32,9 +32,51 @@ mod tests { use std::io::BufReader; use substrait::proto::Plan; + async fn register_csv( + ctx: &SessionContext, + table_name: &str, + file_path: &str, + ) -> Result<()> { + ctx.register_csv(table_name, file_path, CsvReadOptions::default()) + .await + } + + async fn create_context_tpch2() -> Result { + let ctx = SessionContext::new(); + + let registrations = vec![ + ("FILENAME_PLACEHOLDER_0", "tests/testdata/tpch/part.csv"), + ("FILENAME_PLACEHOLDER_1", "tests/testdata/tpch/supplier.csv"), + ("FILENAME_PLACEHOLDER_2", "tests/testdata/tpch/partsupp.csv"), + ("FILENAME_PLACEHOLDER_3", "tests/testdata/tpch/nation.csv"), + ("FILENAME_PLACEHOLDER_4", "tests/testdata/tpch/region.csv"), + ("FILENAME_PLACEHOLDER_5", "tests/testdata/tpch/partsupp.csv"), + ("FILENAME_PLACEHOLDER_6", "tests/testdata/tpch/supplier.csv"), + ("FILENAME_PLACEHOLDER_7", "tests/testdata/tpch/nation.csv"), + ("FILENAME_PLACEHOLDER_8", "tests/testdata/tpch/region.csv"), + ]; + + for (table_name, file_path) in registrations { + register_csv(&ctx, table_name, file_path).await?; + } + + Ok(ctx) + } + + async fn create_context_tpch1() -> Result { + let ctx = SessionContext::new(); + register_csv( + &ctx, + "FILENAME_PLACEHOLDER_0", + "tests/testdata/tpch/lineitem.csv", + ) + .await?; + Ok(ctx) + } + #[tokio::test] async fn tpch_test_1() -> Result<()> { - let ctx = create_context().await?; + let ctx = create_context_tpch1().await?; let path = "tests/testdata/tpch_substrait_plans/query_1.json"; let proto = serde_json::from_reader::<_, Plan>(BufReader::new( File::open(path).expect("file not found"), @@ -56,14 +98,45 @@ mod tests { Ok(()) } - async fn create_context() -> datafusion::common::Result { - let ctx = SessionContext::new(); - ctx.register_csv( - "FILENAME_PLACEHOLDER_0", - "tests/testdata/tpch/lineitem.csv", - CsvReadOptions::default(), - ) - .await?; - Ok(ctx) + #[tokio::test] + async fn tpch_test_2() -> Result<()> { + let ctx = create_context_tpch2().await?; + let path = "tests/testdata/tpch_substrait_plans/query_2.json"; + let proto = serde_json::from_reader::<_, Plan>(BufReader::new( + 
File::open(path).expect("file not found"), + )) + .expect("failed to parse json"); + + let plan = from_substrait_plan(&ctx, &proto).await?; + let plan_str = format!("{:?}", plan); + assert_eq!( + plan_str, + "Projection: FILENAME_PLACEHOLDER_1.s_acctbal AS S_ACCTBAL, FILENAME_PLACEHOLDER_1.s_name AS S_NAME, FILENAME_PLACEHOLDER_3.n_name AS N_NAME, FILENAME_PLACEHOLDER_0.p_partkey AS P_PARTKEY, FILENAME_PLACEHOLDER_0.p_mfgr AS P_MFGR, FILENAME_PLACEHOLDER_1.s_address AS S_ADDRESS, FILENAME_PLACEHOLDER_1.s_phone AS S_PHONE, FILENAME_PLACEHOLDER_1.s_comment AS S_COMMENT\ + \n Limit: skip=0, fetch=100\ + \n Sort: FILENAME_PLACEHOLDER_1.s_acctbal DESC NULLS FIRST, FILENAME_PLACEHOLDER_3.n_name ASC NULLS LAST, FILENAME_PLACEHOLDER_1.s_name ASC NULLS LAST, FILENAME_PLACEHOLDER_0.p_partkey ASC NULLS LAST\ + \n Projection: FILENAME_PLACEHOLDER_1.s_acctbal, FILENAME_PLACEHOLDER_1.s_name, FILENAME_PLACEHOLDER_3.n_name, FILENAME_PLACEHOLDER_0.p_partkey, FILENAME_PLACEHOLDER_0.p_mfgr, FILENAME_PLACEHOLDER_1.s_address, FILENAME_PLACEHOLDER_1.s_phone, FILENAME_PLACEHOLDER_1.s_comment\ + \n Filter: FILENAME_PLACEHOLDER_0.p_partkey = FILENAME_PLACEHOLDER_2.ps_partkey AND FILENAME_PLACEHOLDER_1.s_suppkey = FILENAME_PLACEHOLDER_2.ps_suppkey AND FILENAME_PLACEHOLDER_0.p_size = Int32(15) AND FILENAME_PLACEHOLDER_0.p_type LIKE CAST(Utf8(\"%BRASS\") AS Utf8) AND FILENAME_PLACEHOLDER_1.s_nationkey = FILENAME_PLACEHOLDER_3.n_nationkey AND FILENAME_PLACEHOLDER_3.n_regionkey = FILENAME_PLACEHOLDER_4.r_regionkey AND FILENAME_PLACEHOLDER_4.r_name = CAST(Utf8(\"EUROPE\") AS Utf8) AND FILENAME_PLACEHOLDER_2.ps_supplycost = ()\ + \n Subquery:\ + \n Aggregate: groupBy=[[]], aggr=[[MIN(FILENAME_PLACEHOLDER_5.ps_supplycost)]]\ + \n Projection: FILENAME_PLACEHOLDER_5.ps_supplycost\ + \n Filter: FILENAME_PLACEHOLDER_5.ps_partkey = FILENAME_PLACEHOLDER_5.ps_partkey AND FILENAME_PLACEHOLDER_6.s_suppkey = FILENAME_PLACEHOLDER_5.ps_suppkey AND FILENAME_PLACEHOLDER_6.s_nationkey = FILENAME_PLACEHOLDER_7.n_nationkey AND FILENAME_PLACEHOLDER_7.n_regionkey = FILENAME_PLACEHOLDER_8.r_regionkey AND FILENAME_PLACEHOLDER_8.r_name = CAST(Utf8(\"EUROPE\") AS Utf8)\ + \n Inner Join: Filter: Boolean(true)\ + \n Inner Join: Filter: Boolean(true)\ + \n Inner Join: Filter: Boolean(true)\ + \n TableScan: FILENAME_PLACEHOLDER_5 projection=[ps_partkey, ps_suppkey, ps_availqty, ps_supplycost, ps_comment]\ + \n TableScan: FILENAME_PLACEHOLDER_6 projection=[s_suppkey, s_name, s_address, s_nationkey, s_phone, s_acctbal, s_comment]\ + \n TableScan: FILENAME_PLACEHOLDER_7 projection=[n_nationkey, n_name, n_regionkey, n_comment]\ + \n TableScan: FILENAME_PLACEHOLDER_8 projection=[r_regionkey, r_name, r_comment]\ + \n Inner Join: Filter: Boolean(true)\ + \n Inner Join: Filter: Boolean(true)\ + \n Inner Join: Filter: Boolean(true)\ + \n Inner Join: Filter: Boolean(true)\ + \n TableScan: FILENAME_PLACEHOLDER_0 projection=[p_partkey, p_name, p_mfgr, p_brand, p_type, p_size, p_container, p_retailprice, p_comment]\ + \n TableScan: FILENAME_PLACEHOLDER_1 projection=[s_suppkey, s_name, s_address, s_nationkey, s_phone, s_acctbal, s_comment]\ + \n TableScan: FILENAME_PLACEHOLDER_2 projection=[ps_partkey, ps_suppkey, ps_availqty, ps_supplycost, ps_comment]\ + \n TableScan: FILENAME_PLACEHOLDER_3 projection=[n_nationkey, n_name, n_regionkey, n_comment]\ + \n TableScan: FILENAME_PLACEHOLDER_4 projection=[r_regionkey, r_name, r_comment]" + ); + Ok(()) } } diff --git a/datafusion/substrait/tests/testdata/tpch/nation.csv 
b/datafusion/substrait/tests/testdata/tpch/nation.csv new file mode 100644 index 000000000000..fdf7421467d3 --- /dev/null +++ b/datafusion/substrait/tests/testdata/tpch/nation.csv @@ -0,0 +1,2 @@ +n_nationkey,n_name,n_regionkey,n_comment +0,ALGERIA,0, haggle. carefully final deposits detect slyly agai \ No newline at end of file diff --git a/datafusion/substrait/tests/testdata/tpch/part.csv b/datafusion/substrait/tests/testdata/tpch/part.csv new file mode 100644 index 000000000000..ef6d04271117 --- /dev/null +++ b/datafusion/substrait/tests/testdata/tpch/part.csv @@ -0,0 +1,2 @@ +p_partkey,p_name,p_mfgr,p_brand,p_type,p_size,p_container,p_retailprice,p_comment +1,pink powder puff,Manufacturer#1,Brand#13,SMALL PLATED COPPER,7,JUMBO PKG,901.00,ly final dependencies: slyly bold \ No newline at end of file diff --git a/datafusion/substrait/tests/testdata/tpch/partsupp.csv b/datafusion/substrait/tests/testdata/tpch/partsupp.csv new file mode 100644 index 000000000000..5c585abc7733 --- /dev/null +++ b/datafusion/substrait/tests/testdata/tpch/partsupp.csv @@ -0,0 +1,2 @@ +ps_partkey,ps_suppkey,ps_availqty,ps_supplycost,ps_comment +1,1,1000,50.00,slyly final packages boost against the slyly regular \ No newline at end of file diff --git a/datafusion/substrait/tests/testdata/tpch/region.csv b/datafusion/substrait/tests/testdata/tpch/region.csv new file mode 100644 index 000000000000..6c3fb4524355 --- /dev/null +++ b/datafusion/substrait/tests/testdata/tpch/region.csv @@ -0,0 +1,2 @@ +r_regionkey,r_name,r_comment +0,AFRICA,lar deposits. blithely final packages cajole. regular waters are final requests. regular accounts are according to \ No newline at end of file diff --git a/datafusion/substrait/tests/testdata/tpch/supplier.csv b/datafusion/substrait/tests/testdata/tpch/supplier.csv new file mode 100644 index 000000000000..f73d2cbeaf91 --- /dev/null +++ b/datafusion/substrait/tests/testdata/tpch/supplier.csv @@ -0,0 +1,2 @@ +s_suppkey,s_name,s_address,s_nationkey,s_phone,s_acctbal,s_comment +1,Supplier#1,123 Main St,0,555-1234,1000.00,No comments \ No newline at end of file diff --git a/datafusion/substrait/tests/testdata/tpch_substrait_plans/query_2.json b/datafusion/substrait/tests/testdata/tpch_substrait_plans/query_2.json new file mode 100644 index 000000000000..dd570ca06d45 --- /dev/null +++ b/datafusion/substrait/tests/testdata/tpch_substrait_plans/query_2.json @@ -0,0 +1,1582 @@ +{ + "extensionUris": [ + { + "extensionUriAnchor": 1, + "uri": "/functions_boolean.yaml" + }, + { + "extensionUriAnchor": 3, + "uri": "/functions_string.yaml" + }, + { + "extensionUriAnchor": 4, + "uri": "/functions_arithmetic_decimal.yaml" + }, + { + "extensionUriAnchor": 2, + "uri": "/functions_comparison.yaml" + } + ], + "extensions": [ + { + "extensionFunction": { + "extensionUriReference": 1, + "functionAnchor": 0, + "name": "and:bool" + } + }, + { + "extensionFunction": { + "extensionUriReference": 2, + "functionAnchor": 1, + "name": "equal:any1_any1" + } + }, + { + "extensionFunction": { + "extensionUriReference": 3, + "functionAnchor": 2, + "name": "like:vchar_vchar" + } + }, + { + "extensionFunction": { + "extensionUriReference": 4, + "functionAnchor": 3, + "name": "min:decimal" + } + } + ], + "relations": [ + { + "root": { + "input": { + "fetch": { + "common": { + "direct": {} + }, + "input": { + "sort": { + "common": { + "direct": {} + }, + "input": { + "project": { + "common": { + "emit": { + "outputMapping": [ + 28, + 29, + 30, + 31, + 32, + 33, + 34, + 35 + ] + } + }, + "input": { + "filter": { + 
"common": { + "direct": {} + }, + "input": { + "join": { + "common": { + "direct": {} + }, + "left": { + "join": { + "common": { + "direct": {} + }, + "left": { + "join": { + "common": { + "direct": {} + }, + "left": { + "join": { + "common": { + "direct": {} + }, + "left": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "P_PARTKEY", + "P_NAME", + "P_MFGR", + "P_BRAND", + "P_TYPE", + "P_SIZE", + "P_CONTAINER", + "P_RETAILPRICE", + "P_COMMENT" + ], + "struct": { + "types": [ + { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "varchar": { + "length": 55, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "fixedChar": { + "length": 25, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "fixedChar": { + "length": 10, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "varchar": { + "length": 25, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "i32": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "fixedChar": { + "length": 10, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "varchar": { + "length": 23, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "local_files": { + "items": [ + { + "uri_file": "file://FILENAME_PLACEHOLDER_0", + "parquet": {} + } + ] + } + } + }, + "right": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "S_SUPPKEY", + "S_NAME", + "S_ADDRESS", + "S_NATIONKEY", + "S_PHONE", + "S_ACCTBAL", + "S_COMMENT" + ], + "struct": { + "types": [ + { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "fixedChar": { + "length": 25, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "varchar": { + "length": 40, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "fixedChar": { + "length": 15, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "varchar": { + "length": 101, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "local_files": { + "items": [ + { + "uri_file": "file://FILENAME_PLACEHOLDER_1", + "parquet": {} + } + ] + } + } + }, + "expression": { + "literal": { + "boolean": true, + "nullable": false, + "typeVariationReference": 0 + } + }, + "type": "JOIN_TYPE_INNER" + } + }, + "right": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "PS_PARTKEY", + "PS_SUPPKEY", + "PS_AVAILQTY", + "PS_SUPPLYCOST", + "PS_COMMENT" + ], + "struct": { + "types": [ + { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "i32": { + "typeVariationReference": 0, + "nullability": 
"NULLABILITY_NULLABLE" + } + }, + { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "varchar": { + "length": 199, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "local_files": { + "items": [ + { + "uri_file": "file://FILENAME_PLACEHOLDER_2", + "parquet": {} + } + ] + } + } + }, + "expression": { + "literal": { + "boolean": true, + "nullable": false, + "typeVariationReference": 0 + } + }, + "type": "JOIN_TYPE_INNER" + } + }, + "right": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "N_NATIONKEY", + "N_NAME", + "N_REGIONKEY", + "N_COMMENT" + ], + "struct": { + "types": [ + { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "fixedChar": { + "length": 25, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "varchar": { + "length": 152, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "local_files": { + "items": [ + { + "uri_file": "file://FILENAME_PLACEHOLDER_3", + "parquet": {} + } + ] + } + } + }, + "expression": { + "literal": { + "boolean": true, + "nullable": false, + "typeVariationReference": 0 + } + }, + "type": "JOIN_TYPE_INNER" + } + }, + "right": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "R_REGIONKEY", + "R_NAME", + "R_COMMENT" + ], + "struct": { + "types": [ + { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "fixedChar": { + "length": 25, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "varchar": { + "length": 152, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "local_files": { + "items": [ + { + "uri_file": "file://FILENAME_PLACEHOLDER_4", + "parquet": {} + } + ] + } + } + }, + "expression": { + "literal": { + "boolean": true, + "nullable": false, + "typeVariationReference": 0 + } + }, + "type": "JOIN_TYPE_INNER" + } + }, + "condition": { + "scalarFunction": { + "functionReference": 0, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "scalarFunction": { + "functionReference": 1, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": {} + } + } + }, + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 16 + } + }, + "rootReference": {} + } + } + } + ] + } + } + }, + { + "value": { + "scalarFunction": { + "functionReference": 1, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 9 + } + }, + "rootReference": {} + } + } + }, + { + "value": { + "selection": { + "directReference": { + "structField": { + 
"field": 17 + } + }, + "rootReference": {} + } + } + } + ] + } + } + }, + { + "value": { + "scalarFunction": { + "functionReference": 1, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 5 + } + }, + "rootReference": {} + } + } + }, + { + "value": { + "literal": { + "i32": 15, + "nullable": false, + "typeVariationReference": 0 + } + } + } + ] + } + } + }, + { + "value": { + "scalarFunction": { + "functionReference": 2, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 4 + } + }, + "rootReference": {} + } + } + }, + { + "value": { + "cast": { + "type": { + "varchar": { + "length": 25, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "input": { + "literal": { + "fixedChar": "%BRASS", + "nullable": false, + "typeVariationReference": 0 + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_UNSPECIFIED" + } + } + } + ] + } + } + }, + { + "value": { + "scalarFunction": { + "functionReference": 1, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 12 + } + }, + "rootReference": {} + } + } + }, + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 21 + } + }, + "rootReference": {} + } + } + } + ] + } + } + }, + { + "value": { + "scalarFunction": { + "functionReference": 1, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 23 + } + }, + "rootReference": {} + } + } + }, + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 25 + } + }, + "rootReference": {} + } + } + } + ] + } + } + }, + { + "value": { + "scalarFunction": { + "functionReference": 1, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 26 + } + }, + "rootReference": {} + } + } + }, + { + "value": { + "cast": { + "type": { + "fixedChar": { + "length": 25, + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "input": { + "literal": { + "fixedChar": "EUROPE", + "nullable": false, + "typeVariationReference": 0 + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_UNSPECIFIED" + } + } + } + ] + } + } + }, + { + "value": { + "scalarFunction": { + "functionReference": 1, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 19 + } + }, + "rootReference": {} + } + } + }, + { + "value": { + "subquery": { + "scalar": { + "input": { + "aggregate": { + "common": { + "direct": {} + }, + "input": { + "project": { + "common": { + "emit": { + "outputMapping": [ + 19 + ] + } + }, + "input": { + "filter": { + "common": { + "direct": {} + }, + "input": { + "join": { + "common": { + "direct": {} + }, + 
"left": { + "join": { + "common": { + "direct": {} + }, + "left": { + "join": { + "common": { + "direct": {} + }, + "left": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "PS_PARTKEY", + "PS_SUPPKEY", + "PS_AVAILQTY", + "PS_SUPPLYCOST", + "PS_COMMENT" + ], + "struct": { + "types": [ + { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "i32": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "varchar": { + "length": 199, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "local_files": { + "items": [ + { + "uri_file": "file://FILENAME_PLACEHOLDER_5", + "parquet": {} + } + ] + } + } + }, + "right": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "S_SUPPKEY", + "S_NAME", + "S_ADDRESS", + "S_NATIONKEY", + "S_PHONE", + "S_ACCTBAL", + "S_COMMENT" + ], + "struct": { + "types": [ + { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "fixedChar": { + "length": 25, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "varchar": { + "length": 40, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "fixedChar": { + "length": 15, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "varchar": { + "length": 101, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "local_files": { + "items": [ + { + "uri_file": "file://FILENAME_PLACEHOLDER_6", + "parquet": {} + } + ] + } + } + }, + "expression": { + "literal": { + "boolean": true, + "nullable": false, + "typeVariationReference": 0 + } + }, + "type": "JOIN_TYPE_INNER" + } + }, + "right": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "N_NATIONKEY", + "N_NAME", + "N_REGIONKEY", + "N_COMMENT" + ], + "struct": { + "types": [ + { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "fixedChar": { + "length": 25, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "varchar": { + "length": 152, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "local_files": { + "items": [ + { + "uri_file": "file://FILENAME_PLACEHOLDER_7", + "parquet": {} + } + ] + } + } + }, + "expression": { + "literal": { + "boolean": true, + "nullable": false, + "typeVariationReference": 0 + } + }, + "type": "JOIN_TYPE_INNER" + } + }, + "right": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "R_REGIONKEY", + "R_NAME", + "R_COMMENT" + ], + "struct": { + "types": [ + { + "i64": 
{ + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "fixedChar": { + "length": 25, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "varchar": { + "length": 152, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "local_files": { + "items": [ + { + "uri_file": "file://FILENAME_PLACEHOLDER_8", + "parquet": {} + } + ] + } + } + }, + "expression": { + "literal": { + "boolean": true, + "nullable": false, + "typeVariationReference": 0 + } + }, + "type": "JOIN_TYPE_INNER" + } + }, + "condition": { + "scalarFunction": { + "functionReference": 0, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "scalarFunction": { + "functionReference": 1, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "outerReference": { + "stepsOut": 1 + } + } + } + }, + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": {} + } + } + } + ] + } + } + }, + { + "value": { + "scalarFunction": { + "functionReference": 1, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 5 + } + }, + "rootReference": {} + } + } + }, + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": {} + } + } + } + ] + } + } + }, + { + "value": { + "scalarFunction": { + "functionReference": 1, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 8 + } + }, + "rootReference": {} + } + } + }, + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 12 + } + }, + "rootReference": {} + } + } + } + ] + } + } + }, + { + "value": { + "scalarFunction": { + "functionReference": 1, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 14 + } + }, + "rootReference": {} + } + } + }, + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 16 + } + }, + "rootReference": {} + } + } + } + ] + } + } + }, + { + "value": { + "scalarFunction": { + "functionReference": 1, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 17 + } + }, + "rootReference": {} + } + } + }, + { + "value": { + "cast": { + "type": { + "fixedChar": { + "length": 25, + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "input": { + "literal": { + "fixedChar": "EUROPE", + "nullable": false, + "typeVariationReference": 0 + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_UNSPECIFIED" + } + } + } + ] + } + } + } + ] + } + } + } + 
}, + "expressions": [ + { + "selection": { + "directReference": { + "structField": { + "field": 3 + } + }, + "rootReference": {} + } + } + ] + } + }, + "groupings": [ + { + "groupingExpressions": [] + } + ], + "measures": [ + { + "measure": { + "functionReference": 3, + "args": [], + "sorts": [], + "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT", + "outputType": { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "invocation": "AGGREGATION_INVOCATION_ALL", + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": {} + } + } + } + ] + } + } + ] + } + } + } + } + } + } + ] + } + } + } + ] + } + } + } + }, + "expressions": [ + { + "selection": { + "directReference": { + "structField": { + "field": 14 + } + }, + "rootReference": {} + } + }, + { + "selection": { + "directReference": { + "structField": { + "field": 10 + } + }, + "rootReference": {} + } + }, + { + "selection": { + "directReference": { + "structField": { + "field": 22 + } + }, + "rootReference": {} + } + }, + { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": {} + } + }, + { + "selection": { + "directReference": { + "structField": { + "field": 2 + } + }, + "rootReference": {} + } + }, + { + "selection": { + "directReference": { + "structField": { + "field": 11 + } + }, + "rootReference": {} + } + }, + { + "selection": { + "directReference": { + "structField": { + "field": 13 + } + }, + "rootReference": {} + } + }, + { + "selection": { + "directReference": { + "structField": { + "field": 15 + } + }, + "rootReference": {} + } + } + ] + } + }, + "sorts": [ + { + "expr": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": {} + } + }, + "direction": "SORT_DIRECTION_DESC_NULLS_FIRST" + }, + { + "expr": { + "selection": { + "directReference": { + "structField": { + "field": 2 + } + }, + "rootReference": {} + } + }, + "direction": "SORT_DIRECTION_ASC_NULLS_LAST" + }, + { + "expr": { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": {} + } + }, + "direction": "SORT_DIRECTION_ASC_NULLS_LAST" + }, + { + "expr": { + "selection": { + "directReference": { + "structField": { + "field": 3 + } + }, + "rootReference": {} + } + }, + "direction": "SORT_DIRECTION_ASC_NULLS_LAST" + } + ] + } + }, + "offset": "0", + "count": "100" + } + }, + "names": [ + "S_ACCTBAL", + "S_NAME", + "N_NAME", + "P_PARTKEY", + "P_MFGR", + "S_ADDRESS", + "S_PHONE", + "S_COMMENT" + ] + } + } + ], + "expectedTypeUrls": [] +} \ No newline at end of file From 9355f4a5c7cb99df318ea98316cafc9b0494316f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BC=A0=E6=9E=97=E4=BC=9F?= Date: Fri, 5 Jul 2024 20:43:22 +0800 Subject: [PATCH 05/26] Enable `clone_on_ref_ptr` clippy lint on physical-expr crate (#11240) * Enable clone_on_ref_ptr clippy lint on physical-expr crate * Fixup clippy --------- Co-authored-by: Andrew Lamb --- .../physical-expr/src/aggregate/array_agg.rs | 4 +- .../src/aggregate/array_agg_distinct.rs | 2 +- .../src/aggregate/array_agg_ordered.rs | 4 +- .../physical-expr/src/aggregate/build_in.rs | 12 +- .../physical-expr/src/aggregate/min_max.rs | 16 +- .../physical-expr/src/aggregate/nth_value.rs | 4 +- datafusion/physical-expr/src/analysis.rs | 2 +- .../physical-expr/src/equivalence/class.rs | 35 ++-- .../physical-expr/src/equivalence/mod.rs | 28 +-- 
.../physical-expr/src/equivalence/ordering.rs | 26 +-- .../src/equivalence/projection.rs | 59 +++--- .../src/equivalence/properties.rs | 175 ++++++++++-------- .../physical-expr/src/expressions/binary.rs | 65 ++++--- .../physical-expr/src/expressions/case.rs | 34 ++-- .../physical-expr/src/expressions/in_list.rs | 120 ++++++------ .../src/expressions/is_not_null.rs | 2 +- .../physical-expr/src/expressions/is_null.rs | 2 +- .../physical-expr/src/expressions/like.rs | 4 +- .../physical-expr/src/expressions/negative.rs | 2 +- .../physical-expr/src/expressions/not.rs | 2 +- .../physical-expr/src/expressions/try_cast.rs | 4 +- .../physical-expr/src/intervals/cp_solver.rs | 30 +-- .../physical-expr/src/intervals/test_utils.rs | 18 +- datafusion/physical-expr/src/lib.rs | 2 + datafusion/physical-expr/src/partitioning.rs | 4 +- datafusion/physical-expr/src/physical_expr.rs | 20 +- datafusion/physical-expr/src/planner.rs | 18 +- .../physical-expr/src/scalar_function.rs | 2 +- .../physical-expr/src/utils/guarantee.rs | 14 +- datafusion/physical-expr/src/utils/mod.rs | 6 +- .../physical-expr/src/window/built_in.rs | 4 +- .../physical-expr/src/window/lead_lag.rs | 6 +- .../physical-expr/src/window/nth_value.rs | 4 +- .../src/window/sliding_aggregate.rs | 2 +- .../physical-expr/src/window/window_expr.rs | 6 +- 35 files changed, 397 insertions(+), 341 deletions(-) diff --git a/datafusion/physical-expr/src/aggregate/array_agg.rs b/datafusion/physical-expr/src/aggregate/array_agg.rs index c5a0662a2283..634a0a017903 100644 --- a/datafusion/physical-expr/src/aggregate/array_agg.rs +++ b/datafusion/physical-expr/src/aggregate/array_agg.rs @@ -91,7 +91,7 @@ impl AggregateExpr for ArrayAgg { } fn expressions(&self) -> Vec> { - vec![self.expr.clone()] + vec![Arc::clone(&self.expr)] } fn name(&self) -> &str { @@ -137,7 +137,7 @@ impl Accumulator for ArrayAggAccumulator { return Ok(()); } assert!(values.len() == 1, "array_agg can only take 1 param!"); - let val = values[0].clone(); + let val = Arc::clone(&values[0]); self.values.push(val); Ok(()) } diff --git a/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs b/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs index fc838196de20..a59d85e84a20 100644 --- a/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs +++ b/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs @@ -95,7 +95,7 @@ impl AggregateExpr for DistinctArrayAgg { } fn expressions(&self) -> Vec> { - vec![self.expr.clone()] + vec![Arc::clone(&self.expr)] } fn name(&self) -> &str { diff --git a/datafusion/physical-expr/src/aggregate/array_agg_ordered.rs b/datafusion/physical-expr/src/aggregate/array_agg_ordered.rs index 1234ab40c188..3b122fe9f82b 100644 --- a/datafusion/physical-expr/src/aggregate/array_agg_ordered.rs +++ b/datafusion/physical-expr/src/aggregate/array_agg_ordered.rs @@ -127,7 +127,7 @@ impl AggregateExpr for OrderSensitiveArrayAgg { } fn expressions(&self) -> Vec> { - vec![self.expr.clone()] + vec![Arc::clone(&self.expr)] } fn order_bys(&self) -> Option<&[PhysicalSortExpr]> { @@ -146,7 +146,7 @@ impl AggregateExpr for OrderSensitiveArrayAgg { Some(Arc::new(Self { name: self.name.to_string(), input_data_type: self.input_data_type.clone(), - expr: self.expr.clone(), + expr: Arc::clone(&self.expr), nullable: self.nullable, order_by_data_types: self.order_by_data_types.clone(), // Reverse requirement: diff --git a/datafusion/physical-expr/src/aggregate/build_in.rs b/datafusion/physical-expr/src/aggregate/build_in.rs index adbbbd3e631e..1eadf7247f7c 
100644 --- a/datafusion/physical-expr/src/aggregate/build_in.rs +++ b/datafusion/physical-expr/src/aggregate/build_in.rs @@ -61,7 +61,7 @@ pub fn create_aggregate_expr( let input_phy_exprs = input_phy_exprs.to_vec(); Ok(match (fun, distinct) { (AggregateFunction::ArrayAgg, false) => { - let expr = input_phy_exprs[0].clone(); + let expr = Arc::clone(&input_phy_exprs[0]); let nullable = expr.nullable(input_schema)?; if ordering_req.is_empty() { @@ -83,7 +83,7 @@ pub fn create_aggregate_expr( "ARRAY_AGG(DISTINCT ORDER BY a ASC) order-sensitive aggregations are not available" ); } - let expr = input_phy_exprs[0].clone(); + let expr = Arc::clone(&input_phy_exprs[0]); let is_expr_nullable = expr.nullable(input_schema)?; Arc::new(expressions::DistinctArrayAgg::new( expr, @@ -93,12 +93,12 @@ pub fn create_aggregate_expr( )) } (AggregateFunction::Min, _) => Arc::new(expressions::Min::new( - input_phy_exprs[0].clone(), + Arc::clone(&input_phy_exprs[0]), name, data_type, )), (AggregateFunction::Max, _) => Arc::new(expressions::Max::new( - input_phy_exprs[0].clone(), + Arc::clone(&input_phy_exprs[0]), name, data_type, )), @@ -113,7 +113,7 @@ pub fn create_aggregate_expr( }; let nullable = expr.nullable(input_schema)?; Arc::new(expressions::NthValueAgg::new( - expr.clone(), + Arc::clone(expr), n.clone().try_into()?, name, input_phy_types[0].clone(), @@ -320,7 +320,7 @@ mod tests { input_exprs .iter() .zip(coerced_types) - .map(|(expr, coerced_type)| try_cast(expr.clone(), schema, coerced_type)) + .map(|(expr, coerced_type)| try_cast(Arc::clone(expr), schema, coerced_type)) .collect::>>() } } diff --git a/datafusion/physical-expr/src/aggregate/min_max.rs b/datafusion/physical-expr/src/aggregate/min_max.rs index d142f68e417a..65bb9e478c3d 100644 --- a/datafusion/physical-expr/src/aggregate/min_max.rs +++ b/datafusion/physical-expr/src/aggregate/min_max.rs @@ -162,7 +162,7 @@ impl AggregateExpr for Max { } fn expressions(&self) -> Vec> { - vec![self.expr.clone()] + vec![Arc::clone(&self.expr)] } fn create_accumulator(&self) -> Result> { @@ -927,7 +927,7 @@ impl AggregateExpr for Min { } fn expressions(&self) -> Vec> { - vec![self.expr.clone()] + vec![Arc::clone(&self.expr)] } fn name(&self) -> &str { @@ -1169,7 +1169,7 @@ mod tests { let mut min = MinAccumulator::try_new(&DataType::Interval(IntervalUnit::YearMonth)) .unwrap(); - min.update_batch(&[b.clone()]).unwrap(); + min.update_batch(&[Arc::clone(&b)]).unwrap(); let min_res = min.evaluate().unwrap(); assert_eq!( min_res, @@ -1181,7 +1181,7 @@ mod tests { let mut max = MaxAccumulator::try_new(&DataType::Interval(IntervalUnit::YearMonth)) .unwrap(); - max.update_batch(&[b.clone()]).unwrap(); + max.update_batch(&[Arc::clone(&b)]).unwrap(); let max_res = max.evaluate().unwrap(); assert_eq!( max_res, @@ -1202,7 +1202,7 @@ mod tests { let mut min = MinAccumulator::try_new(&DataType::Interval(IntervalUnit::DayTime)).unwrap(); - min.update_batch(&[b.clone()]).unwrap(); + min.update_batch(&[Arc::clone(&b)]).unwrap(); let min_res = min.evaluate().unwrap(); assert_eq!( min_res, @@ -1211,7 +1211,7 @@ mod tests { let mut max = MaxAccumulator::try_new(&DataType::Interval(IntervalUnit::DayTime)).unwrap(); - max.update_batch(&[b.clone()]).unwrap(); + max.update_batch(&[Arc::clone(&b)]).unwrap(); let max_res = max.evaluate().unwrap(); assert_eq!( max_res, @@ -1231,7 +1231,7 @@ mod tests { let mut min = MinAccumulator::try_new(&DataType::Interval(IntervalUnit::MonthDayNano)) .unwrap(); - min.update_batch(&[b.clone()]).unwrap(); + 
min.update_batch(&[Arc::clone(&b)]).unwrap(); let min_res = min.evaluate().unwrap(); assert_eq!( min_res, @@ -1243,7 +1243,7 @@ mod tests { let mut max = MaxAccumulator::try_new(&DataType::Interval(IntervalUnit::MonthDayNano)) .unwrap(); - max.update_batch(&[b.clone()]).unwrap(); + max.update_batch(&[Arc::clone(&b)]).unwrap(); let max_res = max.evaluate().unwrap(); assert_eq!( max_res, diff --git a/datafusion/physical-expr/src/aggregate/nth_value.rs b/datafusion/physical-expr/src/aggregate/nth_value.rs index f6d25348f222..b75ecd1066ca 100644 --- a/datafusion/physical-expr/src/aggregate/nth_value.rs +++ b/datafusion/physical-expr/src/aggregate/nth_value.rs @@ -119,7 +119,7 @@ impl AggregateExpr for NthValueAgg { fn expressions(&self) -> Vec> { let n = Arc::new(Literal::new(ScalarValue::Int64(Some(self.n)))) as _; - vec![self.expr.clone(), n] + vec![Arc::clone(&self.expr), n] } fn order_bys(&self) -> Option<&[PhysicalSortExpr]> { @@ -138,7 +138,7 @@ impl AggregateExpr for NthValueAgg { Some(Arc::new(Self { name: self.name.to_string(), input_data_type: self.input_data_type.clone(), - expr: self.expr.clone(), + expr: Arc::clone(&self.expr), // index should be from the opposite side n: -self.n, nullable: self.nullable, diff --git a/datafusion/physical-expr/src/analysis.rs b/datafusion/physical-expr/src/analysis.rs index e7b199af3743..bcf1c8e510b1 100644 --- a/datafusion/physical-expr/src/analysis.rs +++ b/datafusion/physical-expr/src/analysis.rs @@ -163,7 +163,7 @@ pub fn analyze( ) -> Result { let target_boundaries = context.boundaries; - let mut graph = ExprIntervalGraph::try_new(expr.clone(), schema)?; + let mut graph = ExprIntervalGraph::try_new(Arc::clone(expr), schema)?; let columns = collect_columns(expr) .into_iter() diff --git a/datafusion/physical-expr/src/equivalence/class.rs b/datafusion/physical-expr/src/equivalence/class.rs index 6c12acb934be..4c0edd2a5d9a 100644 --- a/datafusion/physical-expr/src/equivalence/class.rs +++ b/datafusion/physical-expr/src/equivalence/class.rs @@ -267,17 +267,19 @@ impl EquivalenceGroup { } (Some(group_idx), None) => { // Right side is new, extend left side's class: - self.classes[group_idx].push(right.clone()); + self.classes[group_idx].push(Arc::clone(right)); } (None, Some(group_idx)) => { // Left side is new, extend right side's class: - self.classes[group_idx].push(left.clone()); + self.classes[group_idx].push(Arc::clone(left)); } (None, None) => { // None of the expressions is among existing classes. // Create a new equivalence class and extend the group. - self.classes - .push(EquivalenceClass::new(vec![left.clone(), right.clone()])); + self.classes.push(EquivalenceClass::new(vec![ + Arc::clone(left), + Arc::clone(right), + ])); } } } @@ -328,7 +330,7 @@ impl EquivalenceGroup { /// The expression is replaced with the first expression in the equivalence /// class it matches with (if any). 
pub fn normalize_expr(&self, expr: Arc) -> Arc { - expr.clone() + Arc::clone(&expr) .transform(|expr| { for cls in self.iter() { if cls.contains(&expr) { @@ -429,7 +431,7 @@ impl EquivalenceGroup { .get_equivalence_class(source) .map_or(false, |group| group.contains(expr)) { - return Some(target.clone()); + return Some(Arc::clone(target)); } } } @@ -443,7 +445,7 @@ impl EquivalenceGroup { .into_iter() .map(|child| self.project_expr(mapping, child)) .collect::>>() - .map(|children| expr.clone().with_new_children(children).unwrap()) + .map(|children| Arc::clone(expr).with_new_children(children).unwrap()) } /// Projects this equivalence group according to the given projection mapping. @@ -461,13 +463,13 @@ impl EquivalenceGroup { let mut new_classes = vec![]; for (source, target) in mapping.iter() { if new_classes.is_empty() { - new_classes.push((source, vec![target.clone()])); + new_classes.push((source, vec![Arc::clone(target)])); } if let Some((_, values)) = new_classes.iter_mut().find(|(key, _)| key.eq(source)) { if !physical_exprs_contains(values, target) { - values.push(target.clone()); + values.push(Arc::clone(target)); } } } @@ -515,10 +517,9 @@ impl EquivalenceGroup { // are equal in the resulting table. if join_type == &JoinType::Inner { for (lhs, rhs) in on.iter() { - let new_lhs = lhs.clone() as _; + let new_lhs = Arc::clone(lhs) as _; // Rewrite rhs to point to the right side of the join: - let new_rhs = rhs - .clone() + let new_rhs = Arc::clone(rhs) .transform(|expr| { if let Some(column) = expr.as_any().downcast_ref::() @@ -649,7 +650,7 @@ mod tests { let eq_group = eq_properties.eq_group(); for (expr, expected_eq) in expressions { assert!( - expected_eq.eq(&eq_group.normalize_expr(expr.clone())), + expected_eq.eq(&eq_group.normalize_expr(Arc::clone(expr))), "error in test: expr: {expr:?}" ); } @@ -669,9 +670,11 @@ mod tests { Arc::new(Literal::new(ScalarValue::Int32(Some(1)))) as Arc; let col_b_expr = Arc::new(Column::new("b", 1)) as Arc; - let cls1 = EquivalenceClass::new(vec![lit_true.clone(), lit_false.clone()]); - let cls2 = EquivalenceClass::new(vec![lit_true.clone(), col_b_expr.clone()]); - let cls3 = EquivalenceClass::new(vec![lit2.clone(), lit1.clone()]); + let cls1 = + EquivalenceClass::new(vec![Arc::clone(&lit_true), Arc::clone(&lit_false)]); + let cls2 = + EquivalenceClass::new(vec![Arc::clone(&lit_true), Arc::clone(&col_b_expr)]); + let cls3 = EquivalenceClass::new(vec![Arc::clone(&lit2), Arc::clone(&lit1)]); // lit_true is common assert!(cls1.contains_any(&cls2)); diff --git a/datafusion/physical-expr/src/equivalence/mod.rs b/datafusion/physical-expr/src/equivalence/mod.rs index 5eb8a19e3d67..1ed9a4ac217f 100644 --- a/datafusion/physical-expr/src/equivalence/mod.rs +++ b/datafusion/physical-expr/src/equivalence/mod.rs @@ -145,7 +145,7 @@ mod tests { let col_e = &col("e", &test_schema)?; let col_f = &col("f", &test_schema)?; let col_g = &col("g", &test_schema)?; - let mut eq_properties = EquivalenceProperties::new(test_schema.clone()); + let mut eq_properties = EquivalenceProperties::new(Arc::clone(&test_schema)); eq_properties.add_equal_conditions(col_a, col_c)?; let option_asc = SortOptions { @@ -201,11 +201,11 @@ mod tests { let col_f = &col("f", &test_schema)?; let col_exprs = [col_a, col_b, col_c, col_d, col_e, col_f]; - let mut eq_properties = EquivalenceProperties::new(test_schema.clone()); + let mut eq_properties = EquivalenceProperties::new(Arc::clone(&test_schema)); // Define a and f are aliases eq_properties.add_equal_conditions(col_a, col_f)?; // Column 
e has constant value. - eq_properties = eq_properties.add_constants([ConstExpr::new(col_e.clone())]); + eq_properties = eq_properties.add_constants([ConstExpr::new(Arc::clone(col_e))]); // Randomly order columns for sorting let mut rng = StdRng::seed_from_u64(seed); @@ -223,7 +223,7 @@ mod tests { let ordering = remaining_exprs .drain(0..n_sort_expr) .map(|expr| PhysicalSortExpr { - expr: expr.clone(), + expr: Arc::clone(expr), options: options_asc, }) .collect(); @@ -241,7 +241,7 @@ mod tests { in_data .iter() .map(|(expr, options)| { - PhysicalSortRequirement::new((*expr).clone(), *options) + PhysicalSortRequirement::new(Arc::clone(*expr), *options) }) .collect() } @@ -253,7 +253,7 @@ mod tests { in_data .iter() .map(|(expr, options)| PhysicalSortExpr { - expr: (*expr).clone(), + expr: Arc::clone(*expr), options: *options, }) .collect() @@ -276,7 +276,7 @@ mod tests { in_data .iter() .map(|(expr, options)| PhysicalSortExpr { - expr: (*expr).clone(), + expr: Arc::clone(expr), options: *options, }) .collect() @@ -309,9 +309,9 @@ mod tests { .map(|(source, _target)| source.evaluate(input_data)?.into_array(num_rows)) .collect::>>()?; let projected_batch = if projected_values.is_empty() { - RecordBatch::new_empty(output_schema.clone()) + RecordBatch::new_empty(Arc::clone(&output_schema)) } else { - RecordBatch::try_new(output_schema.clone(), projected_values)? + RecordBatch::try_new(Arc::clone(&output_schema), projected_values)? }; let projected_eq = @@ -399,7 +399,7 @@ mod tests { let vals: Vec = (0..n_row).collect::>(); let vals: Vec = vals.into_iter().map(|val| val as f64).collect(); let unique_col = Arc::new(Float64Array::from_iter_values(vals)) as ArrayRef; - columns.push(unique_col.clone()); + columns.push(Arc::clone(&unique_col)); // Create a new schema with the added unique column let unique_col_name = "unique"; @@ -414,7 +414,7 @@ mod tests { let schema = Arc::new(Schema::new(fields)); // Create a new batch with the added column - let new_batch = RecordBatch::try_new(schema.clone(), columns)?; + let new_batch = RecordBatch::try_new(Arc::clone(&schema), columns)?; // Add the unique column to the required ordering to ensure deterministic results required_ordering.push(PhysicalSortExpr { @@ -454,7 +454,7 @@ mod tests { let col = expr.as_any().downcast_ref::().unwrap(); let (idx, _field) = schema.column_with_name(col.name()).unwrap(); if let Some(res) = &existing_vec[idx] { - return Some(res.clone()); + return Some(Arc::clone(res)); } } None @@ -516,13 +516,13 @@ mod tests { // Fill columns based on equivalence groups for eq_group in eq_properties.eq_group.iter() { let representative_array = - get_representative_arr(eq_group, &schema_vec, schema.clone()) + get_representative_arr(eq_group, &schema_vec, Arc::clone(schema)) .unwrap_or_else(|| generate_random_array(n_elem, n_distinct)); for expr in eq_group.iter() { let col = expr.as_any().downcast_ref::().unwrap(); let (idx, _field) = schema.column_with_name(col.name()).unwrap(); - schema_vec[idx] = Some(representative_array.clone()); + schema_vec[idx] = Some(Arc::clone(&representative_array)); } } diff --git a/datafusion/physical-expr/src/equivalence/ordering.rs b/datafusion/physical-expr/src/equivalence/ordering.rs index ac9d64e486ac..d71075dc77e1 100644 --- a/datafusion/physical-expr/src/equivalence/ordering.rs +++ b/datafusion/physical-expr/src/equivalence/ordering.rs @@ -174,7 +174,7 @@ impl OrderingEquivalenceClass { pub fn add_offset(&mut self, offset: usize) { for ordering in self.orderings.iter_mut() { for sort_expr in ordering { - 
sort_expr.expr = add_offset_to_expr(sort_expr.expr.clone(), offset); + sort_expr.expr = add_offset_to_expr(Arc::clone(&sort_expr.expr), offset); } } } @@ -264,12 +264,14 @@ mod tests { }, ]; // finer ordering satisfies, crude ordering should return true - let mut eq_properties_finer = EquivalenceProperties::new(input_schema.clone()); + let mut eq_properties_finer = + EquivalenceProperties::new(Arc::clone(&input_schema)); eq_properties_finer.oeq_class.push(finer.clone()); assert!(eq_properties_finer.ordering_satisfy(&crude)); // Crude ordering doesn't satisfy finer ordering. should return false - let mut eq_properties_crude = EquivalenceProperties::new(input_schema.clone()); + let mut eq_properties_crude = + EquivalenceProperties::new(Arc::clone(&input_schema)); eq_properties_crude.oeq_class.push(crude.clone()); assert!(!eq_properties_crude.ordering_satisfy(&finer)); Ok(()) @@ -307,9 +309,9 @@ mod tests { &DFSchema::empty(), )?; let a_plus_b = Arc::new(BinaryExpr::new( - col_a.clone(), + Arc::clone(col_a), Operator::Plus, - col_b.clone(), + Arc::clone(col_b), )) as Arc; let options = SortOptions { descending: false, @@ -541,7 +543,7 @@ mod tests { for (orderings, eq_group, constants, reqs, expected) in test_cases { let err_msg = format!("error in test orderings: {orderings:?}, eq_group: {eq_group:?}, constants: {constants:?}, reqs: {reqs:?}, expected: {expected:?}"); - let mut eq_properties = EquivalenceProperties::new(test_schema.clone()); + let mut eq_properties = EquivalenceProperties::new(Arc::clone(&test_schema)); let orderings = convert_to_orderings(&orderings); eq_properties.add_new_orderings(orderings); let eq_group = eq_group @@ -554,9 +556,9 @@ mod tests { let eq_group = EquivalenceGroup::new(eq_group); eq_properties.add_equivalence_group(eq_group); - let constants = constants - .into_iter() - .map(|expr| ConstExpr::new(expr.clone()).with_across_partitions(true)); + let constants = constants.into_iter().map(|expr| { + ConstExpr::new(Arc::clone(expr)).with_across_partitions(true) + }); eq_properties = eq_properties.add_constants(constants); let reqs = convert_to_sort_exprs(&reqs); @@ -717,7 +719,7 @@ mod tests { let required = cols .into_iter() .map(|(expr, options)| PhysicalSortExpr { - expr: expr.clone(), + expr: Arc::clone(expr), options, }) .collect::>(); @@ -769,7 +771,7 @@ mod tests { let requirement = exprs .into_iter() .map(|expr| PhysicalSortExpr { - expr: expr.clone(), + expr: Arc::clone(expr), options: SORT_OPTIONS, }) .collect::>(); @@ -842,7 +844,7 @@ mod tests { let requirement = exprs .into_iter() .map(|expr| PhysicalSortExpr { - expr: expr.clone(), + expr: Arc::clone(expr), options: SORT_OPTIONS, }) .collect::>(); diff --git a/datafusion/physical-expr/src/equivalence/projection.rs b/datafusion/physical-expr/src/equivalence/projection.rs index b5ac149d8b71..f1ce3f04489e 100644 --- a/datafusion/physical-expr/src/equivalence/projection.rs +++ b/datafusion/physical-expr/src/equivalence/projection.rs @@ -56,8 +56,7 @@ impl ProjectionMapping { .enumerate() .map(|(expr_idx, (expression, name))| { let target_expr = Arc::new(Column::new(name, expr_idx)) as _; - expression - .clone() + Arc::clone(expression) .transform_down(|e| match e.as_any().downcast_ref::() { Some(col) => { // Sometimes, an expression and its name in the input_schema @@ -107,7 +106,7 @@ impl ProjectionMapping { self.map .iter() .find(|(source, _)| source.eq(expr)) - .map(|(_, target)| target.clone()) + .map(|(_, target)| Arc::clone(target)) } } @@ -149,24 +148,24 @@ mod tests { let col_e = &col("e", 
&schema)?; let col_ts = &col("ts", &schema)?; let a_plus_b = Arc::new(BinaryExpr::new( - col_a.clone(), + Arc::clone(col_a), Operator::Plus, - col_b.clone(), + Arc::clone(col_b), )) as Arc; let b_plus_d = Arc::new(BinaryExpr::new( - col_b.clone(), + Arc::clone(col_b), Operator::Plus, - col_d.clone(), + Arc::clone(col_d), )) as Arc; let b_plus_e = Arc::new(BinaryExpr::new( - col_b.clone(), + Arc::clone(col_b), Operator::Plus, - col_e.clone(), + Arc::clone(col_e), )) as Arc; let c_plus_d = Arc::new(BinaryExpr::new( - col_c.clone(), + Arc::clone(col_c), Operator::Plus, - col_d.clone(), + Arc::clone(col_d), )) as Arc; let option_asc = SortOptions { @@ -587,14 +586,14 @@ mod tests { for (idx, (orderings, proj_exprs, expected)) in test_cases.into_iter().enumerate() { - let mut eq_properties = EquivalenceProperties::new(schema.clone()); + let mut eq_properties = EquivalenceProperties::new(Arc::clone(&schema)); let orderings = convert_to_orderings(&orderings); eq_properties.add_new_orderings(orderings); let proj_exprs = proj_exprs .into_iter() - .map(|(expr, name)| (expr.clone(), name)) + .map(|(expr, name)| (Arc::clone(expr), name)) .collect::>(); let projection_mapping = ProjectionMapping::try_new(&proj_exprs, &schema)?; let output_schema = output_schema(&projection_mapping, &schema)?; @@ -643,15 +642,15 @@ mod tests { let col_c = &col("c", &schema)?; let col_ts = &col("ts", &schema)?; let a_plus_b = Arc::new(BinaryExpr::new( - col_a.clone(), + Arc::clone(col_a), Operator::Plus, - col_b.clone(), + Arc::clone(col_b), )) as Arc; let test_fun = ScalarUDF::new_from_impl(TestScalarUDF::new()); let round_c = &create_physical_expr( &test_fun, - &[col_c.clone()], + &[Arc::clone(col_c)], &schema, &[], &DFSchema::empty(), @@ -670,7 +669,7 @@ mod tests { ]; let proj_exprs = proj_exprs .into_iter() - .map(|(expr, name)| (expr.clone(), name)) + .map(|(expr, name)| (Arc::clone(expr), name)) .collect::>(); let projection_mapping = ProjectionMapping::try_new(&proj_exprs, &schema)?; let output_schema = output_schema(&projection_mapping, &schema)?; @@ -680,9 +679,9 @@ mod tests { let col_c_new = &col("c_new", &output_schema)?; let col_round_c_res = &col("round_c_res", &output_schema)?; let a_new_plus_b_new = Arc::new(BinaryExpr::new( - col_a_new.clone(), + Arc::clone(col_a_new), Operator::Plus, - col_b_new.clone(), + Arc::clone(col_b_new), )) as Arc; let test_cases = vec![ @@ -793,7 +792,7 @@ mod tests { ]; for (idx, (orderings, expected)) in test_cases.iter().enumerate() { - let mut eq_properties = EquivalenceProperties::new(schema.clone()); + let mut eq_properties = EquivalenceProperties::new(Arc::clone(&schema)); let orderings = convert_to_orderings(orderings); eq_properties.add_new_orderings(orderings); @@ -801,7 +800,7 @@ mod tests { let expected = convert_to_orderings(expected); let projected_eq = - eq_properties.project(&projection_mapping, output_schema.clone()); + eq_properties.project(&projection_mapping, Arc::clone(&output_schema)); let orderings = projected_eq.oeq_class(); let err_msg = format!( @@ -834,9 +833,9 @@ mod tests { let col_e = &col("e", &schema)?; let col_f = &col("f", &schema)?; let a_plus_b = Arc::new(BinaryExpr::new( - col_a.clone(), + Arc::clone(col_a), Operator::Plus, - col_b.clone(), + Arc::clone(col_b), )) as Arc; let option_asc = SortOptions { @@ -851,7 +850,7 @@ mod tests { ]; let proj_exprs = proj_exprs .into_iter() - .map(|(expr, name)| (expr.clone(), name)) + .map(|(expr, name)| (Arc::clone(expr), name)) .collect::>(); let projection_mapping = 
ProjectionMapping::try_new(&proj_exprs, &schema)?; let output_schema = output_schema(&projection_mapping, &schema)?; @@ -936,7 +935,7 @@ mod tests { ), ]; for (orderings, equal_columns, expected) in test_cases { - let mut eq_properties = EquivalenceProperties::new(schema.clone()); + let mut eq_properties = EquivalenceProperties::new(Arc::clone(&schema)); for (lhs, rhs) in equal_columns { eq_properties.add_equal_conditions(lhs, rhs)?; } @@ -947,7 +946,7 @@ mod tests { let expected = convert_to_orderings(&expected); let projected_eq = - eq_properties.project(&projection_mapping, output_schema.clone()); + eq_properties.project(&projection_mapping, Arc::clone(&output_schema)); let orderings = projected_eq.oeq_class(); let err_msg = format!( @@ -1006,7 +1005,7 @@ mod tests { for proj_exprs in proj_exprs.iter().combinations(n_req) { let proj_exprs = proj_exprs .into_iter() - .map(|(expr, name)| (expr.clone(), name.to_string())) + .map(|(expr, name)| (Arc::clone(expr), name.to_string())) .collect::>(); let (projected_batch, projected_eq) = apply_projection( proj_exprs.clone(), @@ -1084,7 +1083,7 @@ mod tests { for proj_exprs in proj_exprs.iter().combinations(n_req) { let proj_exprs = proj_exprs .into_iter() - .map(|(expr, name)| (expr.clone(), name.to_string())) + .map(|(expr, name)| (Arc::clone(expr), name.to_string())) .collect::>(); let (projected_batch, projected_eq) = apply_projection( proj_exprs.clone(), @@ -1097,7 +1096,7 @@ mod tests { let projected_exprs = projection_mapping .iter() - .map(|(_source, target)| target.clone()) + .map(|(_source, target)| Arc::clone(target)) .collect::>(); for n_req in 0..=projected_exprs.len() { @@ -1105,7 +1104,7 @@ mod tests { let requirement = exprs .into_iter() .map(|expr| PhysicalSortExpr { - expr: expr.clone(), + expr: Arc::clone(expr), options: SORT_OPTIONS, }) .collect::>(); diff --git a/datafusion/physical-expr/src/equivalence/properties.rs b/datafusion/physical-expr/src/equivalence/properties.rs index e3a2d1c753ca..9a6a17f58c1f 100644 --- a/datafusion/physical-expr/src/equivalence/properties.rs +++ b/datafusion/physical-expr/src/equivalence/properties.rs @@ -213,13 +213,13 @@ impl EquivalenceProperties { // Left expression is constant, add right as constant if !const_exprs_contains(&self.constants, right) { self.constants - .push(ConstExpr::new(right.clone()).with_across_partitions(true)); + .push(ConstExpr::new(Arc::clone(right)).with_across_partitions(true)); } } else if self.is_expr_constant(right) { // Right expression is constant, add left as constant if !const_exprs_contains(&self.constants, left) { self.constants - .push(ConstExpr::new(left.clone()).with_across_partitions(true)); + .push(ConstExpr::new(Arc::clone(left)).with_across_partitions(true)); } } @@ -357,7 +357,7 @@ impl EquivalenceProperties { constant_exprs.extend( self.constants .iter() - .map(|const_expr| const_expr.expr().clone()), + .map(|const_expr| Arc::clone(const_expr.expr())), ); let constants_normalized = self.eq_group.normalize_exprs(constant_exprs); // Prune redundant sections in the requirement: @@ -424,11 +424,11 @@ impl EquivalenceProperties { fn ordering_satisfy_single(&self, req: &PhysicalSortRequirement) -> bool { let ExprProperties { sort_properties, .. 
- } = self.get_expr_properties(req.expr.clone()); + } = self.get_expr_properties(Arc::clone(&req.expr)); match sort_properties { SortProperties::Ordered(options) => { let sort_expr = PhysicalSortExpr { - expr: req.expr.clone(), + expr: Arc::clone(&req.expr), options, }; sort_expr.satisfy(req, self.schema()) @@ -572,7 +572,7 @@ impl EquivalenceProperties { && cast_expr.is_bigger_cast(expr_type) { res.push(PhysicalSortExpr { - expr: r_expr.clone(), + expr: Arc::clone(&r_expr), options: sort_expr.options, }); } @@ -715,8 +715,9 @@ impl EquivalenceProperties { map: mapping .iter() .map(|(source, target)| { - let normalized_source = self.eq_group.normalize_expr(source.clone()); - (normalized_source, target.clone()) + let normalized_source = + self.eq_group.normalize_expr(Arc::clone(source)); + (normalized_source, Arc::clone(target)) }) .collect(), } @@ -758,7 +759,7 @@ impl EquivalenceProperties { }) .flat_map(|(options, relevant_deps)| { let sort_expr = PhysicalSortExpr { - expr: target.clone(), + expr: Arc::clone(target), options, }; // Generate dependent orderings (i.e. prefixes for `sort_expr`): @@ -831,8 +832,9 @@ impl EquivalenceProperties { && !const_exprs_contains(&projected_constants, target) { // Expression evaluates to single value - projected_constants - .push(ConstExpr::new(target.clone()).with_across_partitions(true)); + projected_constants.push( + ConstExpr::new(Arc::clone(target)).with_across_partitions(true), + ); } } projected_constants @@ -889,11 +891,11 @@ impl EquivalenceProperties { .flat_map(|&idx| { let ExprProperties { sort_properties, .. - } = eq_properties.get_expr_properties(exprs[idx].clone()); + } = eq_properties.get_expr_properties(Arc::clone(&exprs[idx])); match sort_properties { SortProperties::Ordered(options) => Some(( PhysicalSortExpr { - expr: exprs[idx].clone(), + expr: Arc::clone(&exprs[idx]), options, }, idx, @@ -903,7 +905,7 @@ impl EquivalenceProperties { let options = SortOptions::default(); Some(( PhysicalSortExpr { - expr: exprs[idx].clone(), + expr: Arc::clone(&exprs[idx]), options, }, idx, @@ -926,7 +928,7 @@ impl EquivalenceProperties { // an implementation strategy confined to this function. for (PhysicalSortExpr { expr, .. }, idx) in &ordered_exprs { eq_properties = eq_properties - .add_constants(std::iter::once(ConstExpr::new(expr.clone()))); + .add_constants(std::iter::once(ConstExpr::new(Arc::clone(expr)))); search_indices.shift_remove(idx); } // Add new ordered section to the state. @@ -954,9 +956,9 @@ impl EquivalenceProperties { let const_exprs = self .constants .iter() - .map(|const_expr| const_expr.expr().clone()); + .map(|const_expr| Arc::clone(const_expr.expr())); let normalized_constants = self.eq_group.normalize_exprs(const_exprs); - let normalized_expr = self.eq_group.normalize_expr(expr.clone()); + let normalized_expr = self.eq_group.normalize_expr(Arc::clone(expr)); is_constant_recurse(&normalized_constants, &normalized_expr) } @@ -1022,7 +1024,9 @@ fn update_properties( Interval::make_unbounded(&node.expr.data_type(eq_properties.schema())?)? 
} // Now, check what we know about orderings: - let normalized_expr = eq_properties.eq_group.normalize_expr(node.expr.clone()); + let normalized_expr = eq_properties + .eq_group + .normalize_expr(Arc::clone(&node.expr)); if eq_properties.is_expr_constant(&normalized_expr) { node.data.sort_properties = SortProperties::Singleton; } else if let Some(options) = eq_properties @@ -1108,7 +1112,7 @@ fn referred_dependencies( .keys() .filter(|sort_expr| expr_refers(source, &sort_expr.expr)) { - let key = ExprWrapper(sort_expr.expr.clone()); + let key = ExprWrapper(Arc::clone(&sort_expr.expr)); expr_to_sort_exprs .entry(key) .or_default() @@ -1484,25 +1488,25 @@ mod tests { Field::new("c", DataType::Int64, true), ])); - let input_properties = EquivalenceProperties::new(input_schema.clone()); + let input_properties = EquivalenceProperties::new(Arc::clone(&input_schema)); let col_a = col("a", &input_schema)?; // a as a1, a as a2, a as a3, a as a3 let proj_exprs = vec![ - (col_a.clone(), "a1".to_string()), - (col_a.clone(), "a2".to_string()), - (col_a.clone(), "a3".to_string()), - (col_a.clone(), "a4".to_string()), + (Arc::clone(&col_a), "a1".to_string()), + (Arc::clone(&col_a), "a2".to_string()), + (Arc::clone(&col_a), "a3".to_string()), + (Arc::clone(&col_a), "a4".to_string()), ]; let projection_mapping = ProjectionMapping::try_new(&proj_exprs, &input_schema)?; let out_schema = output_schema(&projection_mapping, &input_schema)?; // a as a1, a as a2, a as a3, a as a3 let proj_exprs = vec![ - (col_a.clone(), "a1".to_string()), - (col_a.clone(), "a2".to_string()), - (col_a.clone(), "a3".to_string()), - (col_a.clone(), "a4".to_string()), + (Arc::clone(&col_a), "a1".to_string()), + (Arc::clone(&col_a), "a2".to_string()), + (Arc::clone(&col_a), "a3".to_string()), + (Arc::clone(&col_a), "a4".to_string()), ]; let projection_mapping = ProjectionMapping::try_new(&proj_exprs, &input_schema)?; @@ -1532,8 +1536,8 @@ mod tests { let col_b = &col("b", &schema)?; let col_c = &col("c", &schema)?; let offset = schema.fields.len(); - let col_a2 = &add_offset_to_expr(col_a.clone(), offset); - let col_b2 = &add_offset_to_expr(col_b.clone(), offset); + let col_a2 = &add_offset_to_expr(Arc::clone(col_a), offset); + let col_b2 = &add_offset_to_expr(Arc::clone(col_b), offset); let option_asc = SortOptions { descending: false, nulls_first: false, @@ -1577,8 +1581,8 @@ mod tests { ), ]; for (left_orderings, right_orderings, expected) in test_cases { - let mut left_eq_properties = EquivalenceProperties::new(schema.clone()); - let mut right_eq_properties = EquivalenceProperties::new(schema.clone()); + let mut left_eq_properties = EquivalenceProperties::new(Arc::clone(&schema)); + let mut right_eq_properties = EquivalenceProperties::new(Arc::clone(&schema)); let left_orderings = convert_to_orderings(&left_orderings); let right_orderings = convert_to_orderings(&right_orderings); let expected = convert_to_orderings(&expected); @@ -1626,17 +1630,17 @@ mod tests { let col_b = col("b", &schema)?; let col_d = col("d", &schema)?; let b_plus_d = Arc::new(BinaryExpr::new( - col_b.clone(), + Arc::clone(&col_b), Operator::Plus, - col_d.clone(), + Arc::clone(&col_d), )) as Arc; - let constants = vec![col_a.clone(), col_b.clone()]; - let expr = b_plus_d.clone(); + let constants = vec![Arc::clone(&col_a), Arc::clone(&col_b)]; + let expr = Arc::clone(&b_plus_d); assert!(!is_constant_recurse(&constants, &expr)); - let constants = vec![col_a.clone(), col_b.clone(), col_d.clone()]; - let expr = b_plus_d.clone(); + let constants = 
vec![Arc::clone(&col_a), Arc::clone(&col_b), Arc::clone(&col_d)]; + let expr = Arc::clone(&b_plus_d); assert!(is_constant_recurse(&constants, &expr)); Ok(()) } @@ -1726,11 +1730,11 @@ mod tests { eq_properties.add_equal_conditions(&col_a_expr, &col_c_expr)?; let others = vec![ vec![PhysicalSortExpr { - expr: col_b_expr.clone(), + expr: Arc::clone(&col_b_expr), options: sort_options, }], vec![PhysicalSortExpr { - expr: col_c_expr.clone(), + expr: Arc::clone(&col_c_expr), options: sort_options, }], ]; @@ -1739,11 +1743,11 @@ mod tests { let mut expected_eqs = EquivalenceProperties::new(Arc::new(schema)); expected_eqs.add_new_orderings([ vec![PhysicalSortExpr { - expr: col_b_expr.clone(), + expr: Arc::clone(&col_b_expr), options: sort_options, }], vec![PhysicalSortExpr { - expr: col_c_expr.clone(), + expr: Arc::clone(&col_c_expr), options: sort_options, }], ]); @@ -1766,7 +1770,7 @@ mod tests { ]); let col_a = &col("a", &schema)?; let col_b = &col("b", &schema)?; - let required_columns = [col_b.clone(), col_a.clone()]; + let required_columns = [Arc::clone(col_b), Arc::clone(col_a)]; let mut eq_properties = EquivalenceProperties::new(Arc::new(schema)); eq_properties.add_new_orderings([vec![ PhysicalSortExpr { @@ -1784,11 +1788,11 @@ mod tests { result, vec![ PhysicalSortExpr { - expr: col_b.clone(), + expr: Arc::clone(col_b), options: sort_options_not }, PhysicalSortExpr { - expr: col_a.clone(), + expr: Arc::clone(col_a), options: sort_options } ] @@ -1801,7 +1805,7 @@ mod tests { ]); let col_a = &col("a", &schema)?; let col_b = &col("b", &schema)?; - let required_columns = [col_b.clone(), col_a.clone()]; + let required_columns = [Arc::clone(col_b), Arc::clone(col_a)]; let mut eq_properties = EquivalenceProperties::new(Arc::new(schema)); eq_properties.add_new_orderings([ vec![PhysicalSortExpr { @@ -1825,11 +1829,11 @@ mod tests { result, vec![ PhysicalSortExpr { - expr: col_b.clone(), + expr: Arc::clone(col_b), options: sort_options_not }, PhysicalSortExpr { - expr: col_a.clone(), + expr: Arc::clone(col_a), options: sort_options } ] @@ -1890,11 +1894,11 @@ mod tests { // [b ASC], [d ASC] eq_properties.add_new_orderings(vec![ vec![PhysicalSortExpr { - expr: col_b.clone(), + expr: Arc::clone(col_b), options: option_asc, }], vec![PhysicalSortExpr { - expr: col_d.clone(), + expr: Arc::clone(col_d), options: option_asc, }], ]); @@ -1903,22 +1907,22 @@ mod tests { // d + b ( Arc::new(BinaryExpr::new( - col_d.clone(), + Arc::clone(col_d), Operator::Plus, - col_b.clone(), + Arc::clone(col_b), )) as Arc, SortProperties::Ordered(option_asc), ), // b - (col_b.clone(), SortProperties::Ordered(option_asc)), + (Arc::clone(col_b), SortProperties::Ordered(option_asc)), // a - (col_a.clone(), SortProperties::Ordered(option_asc)), + (Arc::clone(col_a), SortProperties::Ordered(option_asc)), // a + c ( Arc::new(BinaryExpr::new( - col_a.clone(), + Arc::clone(col_a), Operator::Plus, - col_c.clone(), + Arc::clone(col_c), )), SortProperties::Unordered, ), @@ -1929,7 +1933,7 @@ mod tests { .iter() .flat_map(|ordering| ordering.first().cloned()) .collect::>(); - let expr_props = eq_properties.get_expr_properties(expr.clone()); + let expr_props = eq_properties.get_expr_properties(Arc::clone(&expr)); let err_msg = format!( "expr:{:?}, expected: {:?}, actual: {:?}, leading_orderings: {leading_orderings:?}", expr, expected, expr_props.sort_properties @@ -1987,7 +1991,7 @@ mod tests { .iter() .zip(ordering.iter()) .map(|(&idx, sort_expr)| PhysicalSortExpr { - expr: exprs[idx].clone(), + expr: Arc::clone(&exprs[idx]), 
options: sort_expr.options, }) .collect::>(); @@ -2034,9 +2038,9 @@ mod tests { let col_h = &col("h", &test_schema)?; // a + d let a_plus_d = Arc::new(BinaryExpr::new( - col_a.clone(), + Arc::clone(col_a), Operator::Plus, - col_d.clone(), + Arc::clone(col_d), )) as Arc; let option_asc = SortOptions { @@ -2050,11 +2054,11 @@ mod tests { // [d ASC, h DESC] also satisfies schema. eq_properties.add_new_orderings([vec![ PhysicalSortExpr { - expr: col_d.clone(), + expr: Arc::clone(col_d), options: option_asc, }, PhysicalSortExpr { - expr: col_h.clone(), + expr: Arc::clone(col_h), options: option_desc, }, ]]); @@ -2143,7 +2147,8 @@ mod tests { let col_h = &col("h", &test_schema)?; // Add column h as constant - eq_properties = eq_properties.add_constants(vec![ConstExpr::new(col_h.clone())]); + eq_properties = + eq_properties.add_constants(vec![ConstExpr::new(Arc::clone(col_h))]); let test_cases = vec![ // TEST CASE 1 @@ -2382,20 +2387,21 @@ mod tests { Field::new("b", DataType::Utf8, true), Field::new("c", DataType::Timestamp(TimeUnit::Nanosecond, None), true), ])); - let base_properties = EquivalenceProperties::new(schema.clone()).with_reorder( - ["a", "b", "c"] - .into_iter() - .map(|c| { - col(c, schema.as_ref()).map(|expr| PhysicalSortExpr { - expr, - options: SortOptions { - descending: false, - nulls_first: true, - }, + let base_properties = EquivalenceProperties::new(Arc::clone(&schema)) + .with_reorder( + ["a", "b", "c"] + .into_iter() + .map(|c| { + col(c, schema.as_ref()).map(|expr| PhysicalSortExpr { + expr, + options: SortOptions { + descending: false, + nulls_first: true, + }, + }) }) - }) - .collect::>>()?, - ); + .collect::>>()?, + ); struct TestCase { name: &'static str, @@ -2414,8 +2420,11 @@ mod tests { TestCase { name: "(a, b, c) -> (c)", // b is constant, so it should be removed from the sort order - constants: vec![col_b.clone()], - equal_conditions: vec![[cast_c.clone(), col_a.clone()]], + constants: vec![Arc::clone(&col_b)], + equal_conditions: vec![[ + Arc::clone(&cast_c) as Arc, + Arc::clone(&col_a), + ]], sort_columns: &["c"], should_satisfy_ordering: true, }, @@ -2425,7 +2434,10 @@ mod tests { name: "(a, b, c) -> (c)", // b is constant, so it should be removed from the sort order constants: vec![col_b], - equal_conditions: vec![[col_a.clone(), cast_c.clone()]], + equal_conditions: vec![[ + Arc::clone(&col_a), + Arc::clone(&cast_c) as Arc, + ]], sort_columns: &["c"], should_satisfy_ordering: true, }, @@ -2434,7 +2446,10 @@ mod tests { // b is not constant anymore constants: vec![], // a and c are still compatible, but this is irrelevant since the original ordering is (a, b, c) - equal_conditions: vec![[cast_c.clone(), col_a.clone()]], + equal_conditions: vec![[ + Arc::clone(&cast_c) as Arc, + Arc::clone(&col_a), + ]], sort_columns: &["c"], should_satisfy_ordering: false, }, diff --git a/datafusion/physical-expr/src/expressions/binary.rs b/datafusion/physical-expr/src/expressions/binary.rs index d19279c20d10..f1e40575bc64 100644 --- a/datafusion/physical-expr/src/expressions/binary.rs +++ b/datafusion/physical-expr/src/expressions/binary.rs @@ -328,9 +328,9 @@ impl PhysicalExpr for BinaryExpr { children: Vec>, ) -> Result> { Ok(Arc::new(BinaryExpr::new( - children[0].clone(), + Arc::clone(&children[0]), self.op.clone(), - children[1].clone(), + Arc::clone(&children[1]), ))) } @@ -1493,8 +1493,11 @@ mod tests { let b = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])); apply_arithmetic::( - schema.clone(), - vec![a.clone(), b.clone()], + Arc::clone(&schema), + vec![ + 
Arc::clone(&a) as Arc, + Arc::clone(&b) as Arc, + ], Operator::Minus, Int32Array::from(vec![0, 0, 1, 4, 11]), )?; @@ -2376,8 +2379,8 @@ mod tests { expected: BooleanArray, ) -> Result<()> { let op = binary_op(col("a", schema)?, op, col("b", schema)?, schema)?; - let data: Vec = vec![left.clone(), right.clone()]; - let batch = RecordBatch::try_new(schema.clone(), data)?; + let data: Vec = vec![Arc::clone(left), Arc::clone(right)]; + let batch = RecordBatch::try_new(Arc::clone(schema), data)?; let result = op .evaluate(&batch)? .into_array(batch.num_rows()) @@ -3471,8 +3474,8 @@ mod tests { expected: ArrayRef, ) -> Result<()> { let arithmetic_op = binary_op(col("a", schema)?, op, col("b", schema)?, schema)?; - let data: Vec = vec![left.clone(), right.clone()]; - let batch = RecordBatch::try_new(schema.clone(), data)?; + let data: Vec = vec![Arc::clone(left), Arc::clone(right)]; + let batch = RecordBatch::try_new(Arc::clone(schema), data)?; let result = arithmetic_op .evaluate(&batch)? .into_array(batch.num_rows()) @@ -3767,15 +3770,15 @@ mod tests { let left = Arc::new(Int32Array::from(vec![Some(12), None, Some(11)])) as ArrayRef; let right = Arc::new(Int32Array::from(vec![Some(1), Some(3), Some(7)])) as ArrayRef; - let mut result = bitwise_and_dyn(left.clone(), right.clone())?; + let mut result = bitwise_and_dyn(Arc::clone(&left), Arc::clone(&right))?; let expected = Int32Array::from(vec![Some(0), None, Some(3)]); assert_eq!(result.as_ref(), &expected); - result = bitwise_or_dyn(left.clone(), right.clone())?; + result = bitwise_or_dyn(Arc::clone(&left), Arc::clone(&right))?; let expected = Int32Array::from(vec![Some(13), None, Some(15)]); assert_eq!(result.as_ref(), &expected); - result = bitwise_xor_dyn(left.clone(), right.clone())?; + result = bitwise_xor_dyn(Arc::clone(&left), Arc::clone(&right))?; let expected = Int32Array::from(vec![Some(13), None, Some(12)]); assert_eq!(result.as_ref(), &expected); @@ -3783,15 +3786,15 @@ mod tests { Arc::new(UInt32Array::from(vec![Some(12), None, Some(11)])) as ArrayRef; let right = Arc::new(UInt32Array::from(vec![Some(1), Some(3), Some(7)])) as ArrayRef; - let mut result = bitwise_and_dyn(left.clone(), right.clone())?; + let mut result = bitwise_and_dyn(Arc::clone(&left), Arc::clone(&right))?; let expected = UInt32Array::from(vec![Some(0), None, Some(3)]); assert_eq!(result.as_ref(), &expected); - result = bitwise_or_dyn(left.clone(), right.clone())?; + result = bitwise_or_dyn(Arc::clone(&left), Arc::clone(&right))?; let expected = UInt32Array::from(vec![Some(13), None, Some(15)]); assert_eq!(result.as_ref(), &expected); - result = bitwise_xor_dyn(left.clone(), right.clone())?; + result = bitwise_xor_dyn(Arc::clone(&left), Arc::clone(&right))?; let expected = UInt32Array::from(vec![Some(13), None, Some(12)]); assert_eq!(result.as_ref(), &expected); @@ -3803,24 +3806,26 @@ mod tests { let input = Arc::new(Int32Array::from(vec![Some(2), None, Some(10)])) as ArrayRef; let modules = Arc::new(Int32Array::from(vec![Some(2), Some(4), Some(8)])) as ArrayRef; - let mut result = bitwise_shift_left_dyn(input.clone(), modules.clone())?; + let mut result = + bitwise_shift_left_dyn(Arc::clone(&input), Arc::clone(&modules))?; let expected = Int32Array::from(vec![Some(8), None, Some(2560)]); assert_eq!(result.as_ref(), &expected); - result = bitwise_shift_right_dyn(result.clone(), modules.clone())?; + result = bitwise_shift_right_dyn(Arc::clone(&result), Arc::clone(&modules))?; assert_eq!(result.as_ref(), &input); let input = 
Arc::new(UInt32Array::from(vec![Some(2), None, Some(10)])) as ArrayRef; let modules = Arc::new(UInt32Array::from(vec![Some(2), Some(4), Some(8)])) as ArrayRef; - let mut result = bitwise_shift_left_dyn(input.clone(), modules.clone())?; + let mut result = + bitwise_shift_left_dyn(Arc::clone(&input), Arc::clone(&modules))?; let expected = UInt32Array::from(vec![Some(8), None, Some(2560)]); assert_eq!(result.as_ref(), &expected); - result = bitwise_shift_right_dyn(result.clone(), modules.clone())?; + result = bitwise_shift_right_dyn(Arc::clone(&result), Arc::clone(&modules))?; assert_eq!(result.as_ref(), &input); Ok(()) } @@ -3829,14 +3834,14 @@ mod tests { fn bitwise_shift_array_overflow_test() -> Result<()> { let input = Arc::new(Int32Array::from(vec![Some(2)])) as ArrayRef; let modules = Arc::new(Int32Array::from(vec![Some(100)])) as ArrayRef; - let result = bitwise_shift_left_dyn(input.clone(), modules.clone())?; + let result = bitwise_shift_left_dyn(Arc::clone(&input), Arc::clone(&modules))?; let expected = Int32Array::from(vec![Some(32)]); assert_eq!(result.as_ref(), &expected); let input = Arc::new(UInt32Array::from(vec![Some(2)])) as ArrayRef; let modules = Arc::new(UInt32Array::from(vec![Some(100)])) as ArrayRef; - let result = bitwise_shift_left_dyn(input.clone(), modules.clone())?; + let result = bitwise_shift_left_dyn(Arc::clone(&input), Arc::clone(&modules))?; let expected = UInt32Array::from(vec![Some(32)]); assert_eq!(result.as_ref(), &expected); @@ -3973,9 +3978,12 @@ mod tests { Arc::new(DictionaryArray::try_new(keys, values).unwrap()) as ArrayRef; // Casting Dictionary to Int32 - let casted = - to_result_type_array(&Operator::Plus, dictionary.clone(), &DataType::Int32) - .unwrap(); + let casted = to_result_type_array( + &Operator::Plus, + Arc::clone(&dictionary), + &DataType::Int32, + ) + .unwrap(); assert_eq!( &casted, &(Arc::new(Int32Array::from(vec![Some(1), None, Some(3), Some(4)])) @@ -3985,16 +3993,19 @@ mod tests { // Array has same datatype as result type, no casting let casted = to_result_type_array( &Operator::Plus, - dictionary.clone(), + Arc::clone(&dictionary), dictionary.data_type(), ) .unwrap(); assert_eq!(&casted, &dictionary); // Not numerical operator, no casting - let casted = - to_result_type_array(&Operator::Eq, dictionary.clone(), &DataType::Int32) - .unwrap(); + let casted = to_result_type_array( + &Operator::Eq, + Arc::clone(&dictionary), + &DataType::Int32, + ) + .unwrap(); assert_eq!(&casted, &dictionary); } } diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index 08d8cd441334..b1707d3abfa1 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -173,8 +173,8 @@ impl CaseExpr { if let Some(e) = &self.else_expr { // keep `else_expr`'s data type and return type consistent - let expr = try_cast(e.clone(), &batch.schema(), return_type.clone()) - .unwrap_or_else(|_| e.clone()); + let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone()) + .unwrap_or_else(|_| Arc::clone(e)); // null and unmatched tuples should be assigned else value remainder = or(&base_nulls, &remainder)?; let else_ = expr @@ -246,8 +246,8 @@ impl CaseExpr { if let Some(e) = &self.else_expr { // keep `else_expr`'s data type and return type consistent - let expr = try_cast(e.clone(), &batch.schema(), return_type.clone()) - .unwrap_or_else(|_| e.clone()); + let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone()) + 
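Taken together, the hunks above apply one mechanical rewrite: every `.clone()` called through an `Arc` (e.g. `self.expr.clone()`, `schema.clone()`, `input.clone()`) becomes an explicit `Arc::clone(&...)`, so a reader can tell a cheap reference-count bump from a potentially deep copy. Below is a minimal, self-contained sketch (not part of the patch) of what denying `clippy::clone_on_ref_ptr` enforces; the `Expr` trait and `Column` type are illustrative stand-ins, not DataFusion APIs.

    // Deny the lint at the crate root, as these patches do per crate.
    // rustc accepts (and ignores) tool-scoped lints; clippy enforces them.
    #![deny(clippy::clone_on_ref_ptr)]

    use std::sync::Arc;

    trait Expr: std::fmt::Debug {}

    #[derive(Debug)]
    struct Column(&'static str);
    impl Expr for Column {}

    fn main() {
        let expr: Arc<dyn Expr> = Arc::new(Column("a"));

        // Rejected by the lint: `expr.clone()` compiles to the same
        // reference-count bump, but reads as if it might clone the
        // underlying value.
        // let copy = expr.clone();

        // Accepted: the Arc clone is explicit at the call site.
        let copy = Arc::clone(&expr);

        println!("{expr:?} {copy:?}");
    }

Under this lint the two forms are behaviorally identical for `Arc`; the rewrite is purely about readability, which is why the diffs in this patch change no logic, only the spelling of each clone.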
.unwrap_or_else(|_| Arc::clone(e)); let else_ = expr .evaluate_selection(batch, &remainder)? .into_array(batch.num_rows())?; @@ -887,26 +887,26 @@ mod tests { let expr1 = generate_case_when_with_type_coercion( Some(col("a", &schema)?), vec![ - (when1.clone(), then1.clone()), - (when2.clone(), then2.clone()), + (Arc::clone(&when1), Arc::clone(&then1)), + (Arc::clone(&when2), Arc::clone(&then2)), ], - Some(else_value.clone()), + Some(Arc::clone(&else_value)), &schema, )?; let expr2 = generate_case_when_with_type_coercion( Some(col("a", &schema)?), vec![ - (when1.clone(), then1.clone()), - (when2.clone(), then2.clone()), + (Arc::clone(&when1), Arc::clone(&then1)), + (Arc::clone(&when2), Arc::clone(&then2)), ], - Some(else_value.clone()), + Some(Arc::clone(&else_value)), &schema, )?; let expr3 = generate_case_when_with_type_coercion( Some(col("a", &schema)?), - vec![(when1.clone(), then1.clone()), (when2, then2)], + vec![(Arc::clone(&when1), Arc::clone(&then1)), (when2, then2)], None, &schema, )?; @@ -943,15 +943,14 @@ mod tests { let expr = generate_case_when_with_type_coercion( Some(col("a", &schema)?), vec![ - (when1.clone(), then1.clone()), - (when2.clone(), then2.clone()), + (Arc::clone(&when1), Arc::clone(&then1)), + (Arc::clone(&when2), Arc::clone(&then2)), ], - Some(else_value.clone()), + Some(Arc::clone(&else_value)), &schema, )?; - let expr2 = expr - .clone() + let expr2 = Arc::clone(&expr) .transform(|e| { let transformed = match e.as_any().downcast_ref::() { @@ -972,8 +971,7 @@ mod tests { .data() .unwrap(); - let expr3 = expr - .clone() + let expr3 = Arc::clone(&expr) .transform_down(|e| { let transformed = match e.as_any().downcast_ref::() { diff --git a/datafusion/physical-expr/src/expressions/in_list.rs b/datafusion/physical-expr/src/expressions/in_list.rs index 53c790ff6b54..8a3885030b9d 100644 --- a/datafusion/physical-expr/src/expressions/in_list.rs +++ b/datafusion/physical-expr/src/expressions/in_list.rs @@ -389,7 +389,7 @@ impl PhysicalExpr for InListExpr { ) -> Result> { // assume the static_filter will not change during the rewrite process Ok(Arc::new(InListExpr::new( - children[0].clone(), + Arc::clone(&children[0]), children[1..].to_vec(), self.negated, self.static_filter.clone(), @@ -540,7 +540,7 @@ mod tests { list, &false, vec![Some(true), Some(false), None], - col_a.clone(), + Arc::clone(&col_a), &schema ); @@ -551,7 +551,7 @@ mod tests { list, &true, vec![Some(false), Some(true), None], - col_a.clone(), + Arc::clone(&col_a), &schema ); @@ -562,7 +562,7 @@ mod tests { list, &false, vec![Some(true), None, None], - col_a.clone(), + Arc::clone(&col_a), &schema ); @@ -573,7 +573,7 @@ mod tests { list, &true, vec![Some(false), None, None], - col_a.clone(), + Arc::clone(&col_a), &schema ); @@ -598,7 +598,7 @@ mod tests { list.clone(), &false, vec![Some(true), Some(false), None], - col_a.clone(), + Arc::clone(&col_a), &schema ); @@ -608,7 +608,7 @@ mod tests { list, &true, vec![Some(false), Some(true), None], - col_a.clone(), + Arc::clone(&col_a), &schema ); @@ -623,7 +623,7 @@ mod tests { list.clone(), &false, vec![Some(true), None, None], - col_a.clone(), + Arc::clone(&col_a), &schema ); @@ -633,7 +633,7 @@ mod tests { list, &true, vec![Some(false), None, None], - col_a.clone(), + Arc::clone(&col_a), &schema ); @@ -654,7 +654,7 @@ mod tests { list, &false, vec![Some(true), Some(false), None], - col_a.clone(), + Arc::clone(&col_a), &schema ); @@ -665,7 +665,7 @@ mod tests { list, &true, vec![Some(false), Some(true), None], - col_a.clone(), + Arc::clone(&col_a), &schema 
); @@ -676,7 +676,7 @@ mod tests { list, &false, vec![Some(true), None, None], - col_a.clone(), + Arc::clone(&col_a), &schema ); @@ -687,7 +687,7 @@ mod tests { list, &true, vec![Some(false), None, None], - col_a.clone(), + Arc::clone(&col_a), &schema ); @@ -714,7 +714,7 @@ mod tests { list, &false, vec![Some(true), Some(false), None, Some(false), Some(false)], - col_a.clone(), + Arc::clone(&col_a), &schema ); @@ -725,7 +725,7 @@ mod tests { list, &true, vec![Some(false), Some(true), None, Some(true), Some(true)], - col_a.clone(), + Arc::clone(&col_a), &schema ); @@ -736,7 +736,7 @@ mod tests { list, &false, vec![Some(true), None, None, None, None], - col_a.clone(), + Arc::clone(&col_a), &schema ); @@ -747,7 +747,7 @@ mod tests { list, &true, vec![Some(false), None, None, None, None], - col_a.clone(), + Arc::clone(&col_a), &schema ); @@ -758,7 +758,7 @@ mod tests { list, &false, vec![Some(true), Some(false), None, Some(true), Some(false)], - col_a.clone(), + Arc::clone(&col_a), &schema ); @@ -769,7 +769,7 @@ mod tests { list, &true, vec![Some(false), Some(true), None, Some(false), Some(true)], - col_a.clone(), + Arc::clone(&col_a), &schema ); @@ -780,7 +780,7 @@ mod tests { list, &false, vec![Some(true), Some(false), None, Some(false), Some(true)], - col_a.clone(), + Arc::clone(&col_a), &schema ); @@ -791,7 +791,7 @@ mod tests { list, &true, vec![Some(false), Some(true), None, Some(true), Some(false)], - col_a.clone(), + Arc::clone(&col_a), &schema ); @@ -812,7 +812,7 @@ mod tests { list, &false, vec![Some(true), None], - col_a.clone(), + Arc::clone(&col_a), &schema ); @@ -823,7 +823,7 @@ mod tests { list, &true, vec![Some(false), None], - col_a.clone(), + Arc::clone(&col_a), &schema ); @@ -834,7 +834,7 @@ mod tests { list, &false, vec![Some(true), None], - col_a.clone(), + Arc::clone(&col_a), &schema ); @@ -845,7 +845,7 @@ mod tests { list, &true, vec![Some(false), None], - col_a.clone(), + Arc::clone(&col_a), &schema ); @@ -869,7 +869,7 @@ mod tests { list, &false, vec![Some(true), Some(false), None], - col_a.clone(), + Arc::clone(&col_a), &schema ); @@ -883,7 +883,7 @@ mod tests { list, &true, vec![Some(false), Some(true), None], - col_a.clone(), + Arc::clone(&col_a), &schema ); @@ -898,7 +898,7 @@ mod tests { list, &false, vec![Some(true), None, None], - col_a.clone(), + Arc::clone(&col_a), &schema ); @@ -913,7 +913,7 @@ mod tests { list, &true, vec![Some(false), None, None], - col_a.clone(), + Arc::clone(&col_a), &schema ); @@ -937,7 +937,7 @@ mod tests { list, &false, vec![Some(true), Some(false), None], - col_a.clone(), + Arc::clone(&col_a), &schema ); @@ -951,7 +951,7 @@ mod tests { list, &true, vec![Some(false), Some(true), None], - col_a.clone(), + Arc::clone(&col_a), &schema ); @@ -966,7 +966,7 @@ mod tests { list, &false, vec![Some(true), None, None], - col_a.clone(), + Arc::clone(&col_a), &schema ); @@ -981,7 +981,7 @@ mod tests { list, &true, vec![Some(false), None, None], - col_a.clone(), + Arc::clone(&col_a), &schema ); @@ -1008,7 +1008,7 @@ mod tests { list, &false, vec![Some(true), None, Some(false)], - col_a.clone(), + Arc::clone(&col_a), &schema ); // expression: "a not in (100,200) @@ -1018,7 +1018,7 @@ mod tests { list, &true, vec![Some(false), None, Some(true)], - col_a.clone(), + Arc::clone(&col_a), &schema ); @@ -1029,7 +1029,7 @@ mod tests { list.clone(), &false, vec![Some(true), None, None], - col_a.clone(), + Arc::clone(&col_a), &schema ); // expression: "a not in (200,NULL), the data type of list is INT32 AND NULL @@ -1038,7 +1038,7 @@ mod tests { list, &true, 
vec![Some(false), None, None], - col_a.clone(), + Arc::clone(&col_a), &schema ); @@ -1049,7 +1049,7 @@ mod tests { list, &false, vec![Some(true), None, Some(true)], - col_a.clone(), + Arc::clone(&col_a), &schema ); @@ -1060,7 +1060,7 @@ mod tests { list, &true, vec![Some(true), None, Some(false)], - col_a.clone(), + Arc::clone(&col_a), &schema ); @@ -1073,7 +1073,7 @@ mod tests { list.clone(), &false, vec![Some(true), None, Some(false)], - col_a.clone(), + Arc::clone(&col_a), &schema ); @@ -1082,7 +1082,7 @@ mod tests { list, &true, vec![Some(false), None, Some(true)], - col_a.clone(), + Arc::clone(&col_a), &schema ); @@ -1168,7 +1168,7 @@ mod tests { list.clone(), &false, vec![Some(true), Some(false), None], - col_a.clone(), + Arc::clone(&col_a), &schema ); @@ -1177,7 +1177,7 @@ mod tests { list.clone(), &true, vec![Some(false), Some(true), None], - col_a.clone(), + Arc::clone(&col_a), &schema ); Ok(()) @@ -1219,13 +1219,13 @@ mod tests { vec![Arc::new(a), Arc::new(b), Arc::new(c)], )?; - let list = vec![col_b.clone(), col_c.clone()]; + let list = vec![Arc::clone(&col_b), Arc::clone(&col_c)]; in_list!( batch, list.clone(), &false, vec![Some(false), Some(true), None, Some(true), Some(true)], - col_a.clone(), + Arc::clone(&col_a), &schema ); @@ -1234,7 +1234,7 @@ mod tests { list, &true, vec![Some(true), Some(false), None, Some(false), Some(false)], - col_a.clone(), + Arc::clone(&col_a), &schema ); @@ -1262,22 +1262,22 @@ mod tests { // static_filter has no nulls let list = vec![lit(1_i64), lit(2_i64)]; - test_nullable!(c1_nullable.clone(), list.clone(), &schema, true); - test_nullable!(c2_non_nullable.clone(), list.clone(), &schema, false); + test_nullable!(Arc::clone(&c1_nullable), list.clone(), &schema, true); + test_nullable!(Arc::clone(&c2_non_nullable), list.clone(), &schema, false); // static_filter has nulls let list = vec![lit(1_i64), lit(2_i64), lit(ScalarValue::Null)]; - test_nullable!(c1_nullable.clone(), list.clone(), &schema, true); - test_nullable!(c2_non_nullable.clone(), list.clone(), &schema, true); + test_nullable!(Arc::clone(&c1_nullable), list.clone(), &schema, true); + test_nullable!(Arc::clone(&c2_non_nullable), list.clone(), &schema, true); - let list = vec![c1_nullable.clone()]; - test_nullable!(c2_non_nullable.clone(), list.clone(), &schema, true); + let list = vec![Arc::clone(&c1_nullable)]; + test_nullable!(Arc::clone(&c2_non_nullable), list.clone(), &schema, true); - let list = vec![c2_non_nullable.clone()]; - test_nullable!(c1_nullable.clone(), list.clone(), &schema, true); + let list = vec![Arc::clone(&c2_non_nullable)]; + test_nullable!(Arc::clone(&c1_nullable), list.clone(), &schema, true); - let list = vec![c2_non_nullable.clone(), c2_non_nullable.clone()]; - test_nullable!(c2_non_nullable.clone(), list.clone(), &schema, false); + let list = vec![Arc::clone(&c2_non_nullable), Arc::clone(&c2_non_nullable)]; + test_nullable!(Arc::clone(&c2_non_nullable), list.clone(), &schema, false); Ok(()) } @@ -1370,7 +1370,7 @@ mod tests { list.clone(), &false, vec![Some(true), Some(false), None], - col_a.clone(), + Arc::clone(&col_a), &schema ); } @@ -1382,7 +1382,7 @@ mod tests { list.clone(), &true, vec![Some(false), Some(true), None], - col_a.clone(), + Arc::clone(&col_a), &schema ); } @@ -1402,7 +1402,7 @@ mod tests { list.clone(), &false, vec![Some(true), None, None], - col_a.clone(), + Arc::clone(&col_a), &schema ); } @@ -1414,7 +1414,7 @@ mod tests { list.clone(), &true, vec![Some(false), None, None], - col_a.clone(), + Arc::clone(&col_a), &schema ); } diff --git 
a/datafusion/physical-expr/src/expressions/is_not_null.rs b/datafusion/physical-expr/src/expressions/is_not_null.rs index 1918f0891fff..d8fa77585b5d 100644 --- a/datafusion/physical-expr/src/expressions/is_not_null.rs +++ b/datafusion/physical-expr/src/expressions/is_not_null.rs @@ -90,7 +90,7 @@ impl PhysicalExpr for IsNotNullExpr { self: Arc, children: Vec>, ) -> Result> { - Ok(Arc::new(IsNotNullExpr::new(children[0].clone()))) + Ok(Arc::new(IsNotNullExpr::new(Arc::clone(&children[0])))) } fn dyn_hash(&self, state: &mut dyn Hasher) { diff --git a/datafusion/physical-expr/src/expressions/is_null.rs b/datafusion/physical-expr/src/expressions/is_null.rs index 3430efcd7635..41becafde6de 100644 --- a/datafusion/physical-expr/src/expressions/is_null.rs +++ b/datafusion/physical-expr/src/expressions/is_null.rs @@ -91,7 +91,7 @@ impl PhysicalExpr for IsNullExpr { self: Arc, children: Vec>, ) -> Result> { - Ok(Arc::new(IsNullExpr::new(children[0].clone()))) + Ok(Arc::new(IsNullExpr::new(Arc::clone(&children[0])))) } fn dyn_hash(&self, state: &mut dyn Hasher) { diff --git a/datafusion/physical-expr/src/expressions/like.rs b/datafusion/physical-expr/src/expressions/like.rs index e0c02b0a90e9..b84ba82b642d 100644 --- a/datafusion/physical-expr/src/expressions/like.rs +++ b/datafusion/physical-expr/src/expressions/like.rs @@ -123,8 +123,8 @@ impl PhysicalExpr for LikeExpr { Ok(Arc::new(LikeExpr::new( self.negated, self.case_insensitive, - children[0].clone(), - children[1].clone(), + Arc::clone(&children[0]), + Arc::clone(&children[1]), ))) } diff --git a/datafusion/physical-expr/src/expressions/negative.rs b/datafusion/physical-expr/src/expressions/negative.rs index aed2675e0447..b5ebc250cb89 100644 --- a/datafusion/physical-expr/src/expressions/negative.rs +++ b/datafusion/physical-expr/src/expressions/negative.rs @@ -97,7 +97,7 @@ impl PhysicalExpr for NegativeExpr { self: Arc, children: Vec>, ) -> Result> { - Ok(Arc::new(NegativeExpr::new(children[0].clone()))) + Ok(Arc::new(NegativeExpr::new(Arc::clone(&children[0])))) } fn dyn_hash(&self, state: &mut dyn Hasher) { diff --git a/datafusion/physical-expr/src/expressions/not.rs b/datafusion/physical-expr/src/expressions/not.rs index 9aaab0658d39..b69954e00bba 100644 --- a/datafusion/physical-expr/src/expressions/not.rs +++ b/datafusion/physical-expr/src/expressions/not.rs @@ -97,7 +97,7 @@ impl PhysicalExpr for NotExpr { self: Arc, children: Vec>, ) -> Result> { - Ok(Arc::new(NotExpr::new(children[0].clone()))) + Ok(Arc::new(NotExpr::new(Arc::clone(&children[0])))) } fn dyn_hash(&self, state: &mut dyn Hasher) { diff --git a/datafusion/physical-expr/src/expressions/try_cast.rs b/datafusion/physical-expr/src/expressions/try_cast.rs index d31306e239bd..3549a3df83bb 100644 --- a/datafusion/physical-expr/src/expressions/try_cast.rs +++ b/datafusion/physical-expr/src/expressions/try_cast.rs @@ -106,7 +106,7 @@ impl PhysicalExpr for TryCastExpr { children: Vec>, ) -> Result> { Ok(Arc::new(TryCastExpr::new( - children[0].clone(), + Arc::clone(&children[0]), self.cast_type.clone(), ))) } @@ -137,7 +137,7 @@ pub fn try_cast( ) -> Result> { let expr_type = expr.data_type(input_schema)?; if expr_type == cast_type { - Ok(expr.clone()) + Ok(Arc::clone(&expr)) } else if can_cast_types(&expr_type, &cast_type) { Ok(Arc::new(TryCastExpr::new(expr, cast_type))) } else { diff --git a/datafusion/physical-expr/src/intervals/cp_solver.rs b/datafusion/physical-expr/src/intervals/cp_solver.rs index 6fbcd461af66..ef9dd36cfb50 100644 --- 
a/datafusion/physical-expr/src/intervals/cp_solver.rs +++ b/datafusion/physical-expr/src/intervals/cp_solver.rs @@ -180,7 +180,7 @@ impl ExprIntervalGraphNode { /// object. Literals are created with definite, singleton intervals while /// any other expression starts with an indefinite interval ([-∞, ∞]). pub fn make_node(node: &ExprTreeNode, schema: &Schema) -> Result { - let expr = node.expr.clone(); + let expr = Arc::clone(&node.expr); if let Some(literal) = expr.as_any().downcast_ref::() { let value = literal.value(); Interval::try_new(value.clone(), value.clone()) @@ -422,7 +422,7 @@ impl ExprIntervalGraph { let mut removals = vec![]; let mut expr_node_indices = exprs .iter() - .map(|e| (e.clone(), usize::MAX)) + .map(|e| (Arc::clone(e), usize::MAX)) .collect::>(); while let Some(node) = bfs.next(graph) { // Get the plan corresponding to this node: @@ -744,16 +744,17 @@ mod tests { schema: &Schema, ) -> Result<()> { let col_stats = vec![ - (exprs_with_interval.0.clone(), left_interval), - (exprs_with_interval.1.clone(), right_interval), + (Arc::clone(&exprs_with_interval.0), left_interval), + (Arc::clone(&exprs_with_interval.1), right_interval), ]; let expected = vec![ - (exprs_with_interval.0.clone(), left_expected), - (exprs_with_interval.1.clone(), right_expected), + (Arc::clone(&exprs_with_interval.0), left_expected), + (Arc::clone(&exprs_with_interval.1), right_expected), ]; let mut graph = ExprIntervalGraph::try_new(expr, schema)?; - let expr_indexes = graph - .gather_node_indices(&col_stats.iter().map(|(e, _)| e.clone()).collect_vec()); + let expr_indexes = graph.gather_node_indices( + &col_stats.iter().map(|(e, _)| Arc::clone(e)).collect_vec(), + ); let mut col_stat_nodes = col_stats .iter() @@ -870,14 +871,21 @@ mod tests { // left_watermark > right_watermark + 5 let left_and_1 = Arc::new(BinaryExpr::new( - left_col.clone(), + Arc::clone(&left_col) as Arc, Operator::Plus, Arc::new(Literal::new(ScalarValue::Int32(Some(5)))), )); - let expr = Arc::new(BinaryExpr::new(left_and_1, Operator::Gt, right_col.clone())); + let expr = Arc::new(BinaryExpr::new( + left_and_1, + Operator::Gt, + Arc::clone(&right_col) as Arc, + )); experiment( expr, - (left_col.clone(), right_col.clone()), + ( + Arc::clone(&left_col) as Arc, + Arc::clone(&right_col) as Arc, + ), Interval::make(Some(10_i32), Some(20_i32))?, Interval::make(Some(100), None)?, Interval::make(Some(10), Some(20))?, diff --git a/datafusion/physical-expr/src/intervals/test_utils.rs b/datafusion/physical-expr/src/intervals/test_utils.rs index 075b8240353d..cedf55bccbf2 100644 --- a/datafusion/physical-expr/src/intervals/test_utils.rs +++ b/datafusion/physical-expr/src/intervals/test_utils.rs @@ -41,12 +41,12 @@ pub fn gen_conjunctive_numerical_expr( ) -> Arc { let (op_1, op_2, op_3, op_4) = op; let left_and_1 = Arc::new(BinaryExpr::new( - left_col.clone(), + Arc::clone(&left_col), op_1, Arc::new(Literal::new(a)), )); let left_and_2 = Arc::new(BinaryExpr::new( - right_col.clone(), + Arc::clone(&right_col), op_2, Arc::new(Literal::new(b)), )); @@ -78,8 +78,18 @@ pub fn gen_conjunctive_temporal_expr( d: ScalarValue, schema: &Schema, ) -> Result, DataFusionError> { - let left_and_1 = binary(left_col.clone(), op_1, Arc::new(Literal::new(a)), schema)?; - let left_and_2 = binary(right_col.clone(), op_2, Arc::new(Literal::new(b)), schema)?; + let left_and_1 = binary( + Arc::clone(&left_col), + op_1, + Arc::new(Literal::new(a)), + schema, + )?; + let left_and_2 = binary( + Arc::clone(&right_col), + op_2, + Arc::new(Literal::new(b)), + schema, 
+ )?; let right_and_1 = binary(left_col, op_3, Arc::new(Literal::new(c)), schema)?; let right_and_2 = binary(right_col, op_4, Arc::new(Literal::new(d)), schema)?; let left_expr = Arc::new(BinaryExpr::new(left_and_1, Operator::Gt, left_and_2)); diff --git a/datafusion/physical-expr/src/lib.rs b/datafusion/physical-expr/src/lib.rs index 06c73636773e..4f83ae01959b 100644 --- a/datafusion/physical-expr/src/lib.rs +++ b/datafusion/physical-expr/src/lib.rs @@ -14,6 +14,8 @@ // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. +// Make cheap clones clear: https://github.com/apache/datafusion/issues/11143 +#![deny(clippy::clone_on_ref_ptr)] pub mod aggregate; pub mod analysis; diff --git a/datafusion/physical-expr/src/partitioning.rs b/datafusion/physical-expr/src/partitioning.rs index 273c77fb1d5e..821b2c9fe17a 100644 --- a/datafusion/physical-expr/src/partitioning.rs +++ b/datafusion/physical-expr/src/partitioning.rs @@ -169,11 +169,11 @@ impl Partitioning { if !eq_groups.is_empty() { let normalized_required_exprs = required_exprs .iter() - .map(|e| eq_groups.normalize_expr(e.clone())) + .map(|e| eq_groups.normalize_expr(Arc::clone(e))) .collect::>(); let normalized_partition_exprs = partition_exprs .iter() - .map(|e| eq_groups.normalize_expr(e.clone())) + .map(|e| eq_groups.normalize_expr(Arc::clone(e))) .collect::>(); return physical_exprs_equal( &normalized_required_exprs, diff --git a/datafusion/physical-expr/src/physical_expr.rs b/datafusion/physical-expr/src/physical_expr.rs index 127194f681a5..c60a772b9ce2 100644 --- a/datafusion/physical-expr/src/physical_expr.rs +++ b/datafusion/physical-expr/src/physical_expr.rs @@ -117,12 +117,12 @@ mod tests { // lit(true), lit(false), lit(4), lit(2), Col(a), Col(b) let physical_exprs: Vec> = vec![ - lit_true.clone(), - lit_false.clone(), - lit4.clone(), - lit2.clone(), - col_a_expr.clone(), - col_b_expr.clone(), + Arc::clone(&lit_true), + Arc::clone(&lit_false), + Arc::clone(&lit4), + Arc::clone(&lit2), + Arc::clone(&col_a_expr), + Arc::clone(&col_b_expr), ]; // below expressions are inside physical_exprs assert!(physical_exprs_contains(&physical_exprs, &lit_true)); @@ -146,10 +146,10 @@ mod tests { Arc::new(Literal::new(ScalarValue::Int32(Some(2)))) as Arc; let col_b_expr = Arc::new(Column::new("b", 1)) as Arc; - let vec1 = vec![lit_true.clone(), lit_false.clone()]; - let vec2 = vec![lit_true.clone(), col_b_expr.clone()]; - let vec3 = vec![lit2.clone(), lit1.clone()]; - let vec4 = vec![lit_true.clone(), lit_false.clone()]; + let vec1 = vec![Arc::clone(&lit_true), Arc::clone(&lit_false)]; + let vec2 = vec![Arc::clone(&lit_true), Arc::clone(&col_b_expr)]; + let vec3 = vec![Arc::clone(&lit2), Arc::clone(&lit1)]; + let vec4 = vec![Arc::clone(&lit_true), Arc::clone(&lit_false)]; // these vectors are same assert!(physical_exprs_equal(&vec1, &vec1)); diff --git a/datafusion/physical-expr/src/planner.rs b/datafusion/physical-expr/src/planner.rs index 8fe99cdca591..dbebf4c18b79 100644 --- a/datafusion/physical-expr/src/planner.rs +++ b/datafusion/physical-expr/src/planner.rs @@ -242,7 +242,7 @@ pub fn create_physical_expr( when_expr .iter() .zip(then_expr.iter()) - .map(|(w, t)| (w.clone(), t.clone())) + .map(|(w, t)| (Arc::clone(w), Arc::clone(t))) .collect(); let else_expr: Option> = if let Some(e) = &case.else_expr { @@ -288,7 +288,7 @@ pub fn create_physical_expr( create_physical_exprs(args, input_dfschema, execution_props)?; scalar_function::create_physical_expr( - 
func.clone().as_ref(), + Arc::clone(func).as_ref(), &physical_args, input_schema, args, @@ -307,9 +307,19 @@ pub fn create_physical_expr( // rewrite the between into the two binary operators let binary_expr = binary( - binary(value_expr.clone(), Operator::GtEq, low_expr, input_schema)?, + binary( + Arc::clone(&value_expr), + Operator::GtEq, + low_expr, + input_schema, + )?, Operator::And, - binary(value_expr.clone(), Operator::LtEq, high_expr, input_schema)?, + binary( + Arc::clone(&value_expr), + Operator::LtEq, + high_expr, + input_schema, + )?, input_schema, ); diff --git a/datafusion/physical-expr/src/scalar_function.rs b/datafusion/physical-expr/src/scalar_function.rs index 10e29b41031d..83272fc9b269 100644 --- a/datafusion/physical-expr/src/scalar_function.rs +++ b/datafusion/physical-expr/src/scalar_function.rs @@ -153,7 +153,7 @@ impl PhysicalExpr for ScalarFunctionExpr { ) -> Result> { Ok(Arc::new(ScalarFunctionExpr::new( &self.name, - self.fun.clone(), + Arc::clone(&self.fun), children, self.return_type().clone(), ))) diff --git a/datafusion/physical-expr/src/utils/guarantee.rs b/datafusion/physical-expr/src/utils/guarantee.rs index 070034116fb4..42e5e6fcf3ac 100644 --- a/datafusion/physical-expr/src/utils/guarantee.rs +++ b/datafusion/physical-expr/src/utils/guarantee.rs @@ -870,14 +870,12 @@ mod test { // Schema for testing fn schema() -> SchemaRef { - SCHEMA - .get_or_init(|| { - Arc::new(Schema::new(vec![ - Field::new("a", DataType::Utf8, false), - Field::new("b", DataType::Int32, false), - ])) - }) - .clone() + Arc::clone(SCHEMA.get_or_init(|| { + Arc::new(Schema::new(vec![ + Field::new("a", DataType::Utf8, false), + Field::new("b", DataType::Int32, false), + ])) + })) } static SCHEMA: OnceLock = OnceLock::new(); diff --git a/datafusion/physical-expr/src/utils/mod.rs b/datafusion/physical-expr/src/utils/mod.rs index 492cb02941df..a33f65f92a61 100644 --- a/datafusion/physical-expr/src/utils/mod.rs +++ b/datafusion/physical-expr/src/utils/mod.rs @@ -111,7 +111,7 @@ pub fn convert_to_expr>( ) -> Vec> { sequence .into_iter() - .map(|elem| elem.borrow().expr.clone()) + .map(|elem| Arc::clone(&elem.borrow().expr)) .collect() } @@ -166,7 +166,7 @@ impl<'a, T, F: Fn(&ExprTreeNode) -> Result> for expr_node in node.children.iter() { self.graph.add_edge(node_idx, expr_node.data.unwrap(), 0); } - self.visited_plans.push((expr.clone(), node_idx)); + self.visited_plans.push((Arc::clone(expr), node_idx)); node_idx } }; @@ -379,7 +379,7 @@ pub(crate) mod tests { } fn make_dummy_node(node: &ExprTreeNode) -> Result { - let expr = node.expr.clone(); + let expr = Arc::clone(&node.expr); let dummy_property = if expr.as_any().is::() { "Binary" } else if expr.as_any().is::() { diff --git a/datafusion/physical-expr/src/window/built_in.rs b/datafusion/physical-expr/src/window/built_in.rs index 065260a73e0b..04d359903eae 100644 --- a/datafusion/physical-expr/src/window/built_in.rs +++ b/datafusion/physical-expr/src/window/built_in.rs @@ -137,7 +137,7 @@ impl WindowExpr for BuiltInWindowExpr { let order_bys_ref = &values[n_args..]; let mut window_frame_ctx = - WindowFrameContext::new(self.window_frame.clone(), sort_options); + WindowFrameContext::new(Arc::clone(&self.window_frame), sort_options); let mut last_range = Range { start: 0, end: 0 }; // We iterate on each row to calculate window frame range and window function result for idx in 0..num_rows { @@ -217,7 +217,7 @@ impl WindowExpr for BuiltInWindowExpr { .window_frame_ctx .get_or_insert_with(|| { WindowFrameContext::new( -
self.window_frame.clone(), + Arc::clone(&self.window_frame), sort_options.clone(), ) }) diff --git a/datafusion/physical-expr/src/window/lead_lag.rs b/datafusion/physical-expr/src/window/lead_lag.rs index 9a7c89dca56c..1656b7c3033a 100644 --- a/datafusion/physical-expr/src/window/lead_lag.rs +++ b/datafusion/physical-expr/src/window/lead_lag.rs @@ -104,7 +104,7 @@ impl BuiltInWindowFunctionExpr for WindowShift { } fn expressions(&self) -> Vec> { - vec![self.expr.clone()] + vec![Arc::clone(&self.expr)] } fn name(&self) -> &str { @@ -125,7 +125,7 @@ impl BuiltInWindowFunctionExpr for WindowShift { name: self.name.clone(), data_type: self.data_type.clone(), shift_offset: -self.shift_offset, - expr: self.expr.clone(), + expr: Arc::clone(&self.expr), default_value: self.default_value.clone(), ignore_nulls: self.ignore_nulls, })) @@ -209,7 +209,7 @@ fn shift_with_default_value( let value_len = array.len() as i64; if offset == 0 { - Ok(array.clone()) + Ok(Arc::clone(array)) } else if offset == i64::MIN || offset.abs() >= value_len { default_value.to_array_of_size(value_len as usize) } else { diff --git a/datafusion/physical-expr/src/window/nth_value.rs b/datafusion/physical-expr/src/window/nth_value.rs index 4bd40066ff34..87c74579c639 100644 --- a/datafusion/physical-expr/src/window/nth_value.rs +++ b/datafusion/physical-expr/src/window/nth_value.rs @@ -116,7 +116,7 @@ impl BuiltInWindowFunctionExpr for NthValue { } fn expressions(&self) -> Vec> { - vec![self.expr.clone()] + vec![Arc::clone(&self.expr)] } fn name(&self) -> &str { @@ -142,7 +142,7 @@ impl BuiltInWindowFunctionExpr for NthValue { }; Some(Arc::new(Self { name: self.name.clone(), - expr: self.expr.clone(), + expr: Arc::clone(&self.expr), data_type: self.data_type.clone(), kind: reversed_kind, ignore_nulls: self.ignore_nulls, diff --git a/datafusion/physical-expr/src/window/sliding_aggregate.rs b/datafusion/physical-expr/src/window/sliding_aggregate.rs index 961f0884dd87..50e9632b2196 100644 --- a/datafusion/physical-expr/src/window/sliding_aggregate.rs +++ b/datafusion/physical-expr/src/window/sliding_aggregate.rs @@ -163,7 +163,7 @@ impl WindowExpr for SlidingAggregateWindowExpr { aggregate: self.aggregate.with_new_expressions(args, vec![])?, partition_by: partition_bys, order_by: new_order_by, - window_frame: self.window_frame.clone(), + window_frame: Arc::clone(&self.window_frame), })) } } diff --git a/datafusion/physical-expr/src/window/window_expr.rs b/datafusion/physical-expr/src/window/window_expr.rs index 3cf68379d72b..7020f7f5cf83 100644 --- a/datafusion/physical-expr/src/window/window_expr.rs +++ b/datafusion/physical-expr/src/window/window_expr.rs @@ -137,7 +137,7 @@ pub trait WindowExpr: Send + Sync + Debug { let order_by_exprs = self .order_by() .iter() - .map(|sort_expr| sort_expr.expr.clone()) + .map(|sort_expr| Arc::clone(&sort_expr.expr)) .collect::>(); WindowPhysicalExpressions { args, @@ -193,7 +193,7 @@ pub trait AggregateWindowExpr: WindowExpr { let sort_options: Vec = self.order_by().iter().map(|o| o.options).collect(); let mut window_frame_ctx = - WindowFrameContext::new(self.get_window_frame().clone(), sort_options); + WindowFrameContext::new(Arc::clone(self.get_window_frame()), sort_options); self.get_result_column( &mut accumulator, batch, @@ -241,7 +241,7 @@ pub trait AggregateWindowExpr: WindowExpr { let window_frame_ctx = state.window_frame_ctx.get_or_insert_with(|| { let sort_options: Vec = self.order_by().iter().map(|o| o.options).collect(); - WindowFrameContext::new(self.get_window_frame().clone(), 
sort_options) + WindowFrameContext::new(Arc::clone(self.get_window_frame()), sort_options) }); let out_col = self.get_result_column( accumulator, From dce77db316beb5bf5dc2326c350b2d642edbfed8 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Fri, 5 Jul 2024 08:45:39 -0400 Subject: [PATCH 06/26] Add standalone `AnalyzerRule` example that implements row level access control (#11089) * Add standalone example AnalyzerRule * Apply suggestions from code review Co-authored-by: Jax Liu * update for api change * Apply suggestions from code review Co-authored-by: Jonah Gao --------- Co-authored-by: Jax Liu Co-authored-by: Jonah Gao --- datafusion-examples/README.md | 1 + datafusion-examples/examples/analyzer_rule.rs | 200 ++++++++++++++++++ 2 files changed, 201 insertions(+) create mode 100644 datafusion-examples/examples/analyzer_rule.rs diff --git a/datafusion-examples/README.md b/datafusion-examples/README.md index dc92019035dd..f868a5310cbe 100644 --- a/datafusion-examples/README.md +++ b/datafusion-examples/README.md @@ -47,6 +47,7 @@ cargo run --example dataframe - [`advanced_udf.rs`](examples/advanced_udf.rs): Define and invoke a more complicated User Defined Scalar Function (UDF) - [`advanced_udwf.rs`](examples/advanced_udwf.rs): Define and invoke a more complicated User Defined Window Function (UDWF) - [`advanced_parquet_index.rs`](examples/advanced_parquet_index.rs): Creates a detailed secondary index that covers the contents of several parquet files +- [`analyzer_rule.rs`](examples/analyzer_rule.rs): Use a custom AnalyzerRule to change a query's semantics (row level access control) - [`catalog.rs`](examples/catalog.rs): Register the table into a custom catalog - [`composed_extension_codec`](examples/composed_extension_codec.rs): Example of using multiple extension codecs for serialization / deserialization - [`csv_sql_streaming.rs`](examples/csv_sql_streaming.rs): Build and run a streaming query plan from a SQL statement against a local CSV file diff --git a/datafusion-examples/examples/analyzer_rule.rs b/datafusion-examples/examples/analyzer_rule.rs new file mode 100644 index 000000000000..bd067be97b8b --- /dev/null +++ b/datafusion-examples/examples/analyzer_rule.rs @@ -0,0 +1,200 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{ArrayRef, Int32Array, RecordBatch, StringArray}; +use datafusion::prelude::SessionContext; +use datafusion_common::config::ConfigOptions; +use datafusion_common::tree_node::{Transformed, TreeNode}; +use datafusion_common::Result; +use datafusion_expr::{col, lit, Expr, LogicalPlan, LogicalPlanBuilder}; +use datafusion_optimizer::analyzer::AnalyzerRule; +use std::sync::{Arc, Mutex}; + +/// This example demonstrates how to add your own [`AnalyzerRule`] to +/// DataFusion.
+/// +/// [`AnalyzerRule`]s transform [`LogicalPlan`]s prior to the DataFusion +/// optimization process, and can be used to change the plan's semantics (e.g. +/// output types). +/// +/// This example shows an `AnalyzerRule` which implements a simplistic row +/// level access control scheme by introducing a filter to the query. +/// +/// See [optimizer_rule.rs] for an example of an optimizer rule +#[tokio::main] +pub async fn main() -> Result<()> { + // AnalyzerRules run before OptimizerRules. + // + // DataFusion includes several built-in AnalyzerRules for tasks such as type + // coercion which change the types of expressions in the plan. Add our new + // rule to the context to run it during the analysis phase. + let rule = Arc::new(RowLevelAccessControl::new()); + let ctx = SessionContext::new(); + ctx.add_analyzer_rule(Arc::clone(&rule) as _); + + ctx.register_batch("employee", employee_batch())?; + + // Now, planning any SQL statement also invokes the AnalyzerRule + let plan = ctx + .sql("SELECT * FROM employee") + .await? + .into_optimized_plan()?; + + // Printing the query plan shows a filter has been added + // + // Filter: employee.position = Utf8("Engineer") + // TableScan: employee projection=[name, age, position] + println!("Logical Plan:\n\n{}\n", plan.display_indent()); + + // Execute the query, and indeed no Managers are returned + // + // +-----------+-----+----------+ + // | name | age | position | + // +-----------+-----+----------+ + // | Andy | 11 | Engineer | + // | Oleks | 33 | Engineer | + // | Xiangpeng | 55 | Engineer | + // +-----------+-----+----------+ + ctx.sql("SELECT * FROM employee").await?.show().await?; + + // We can now change the access level to "Manager" and see the results + // + // +----------+-----+----------+ + // | name | age | position | + // +----------+-----+----------+ + // | Andrew | 22 | Manager | + // | Chunchun | 44 | Manager | + // +----------+-----+----------+ + rule.set_show_position("Manager"); + ctx.sql("SELECT * FROM employee").await?.show().await?; + + // The filters introduced by our AnalyzerRule are treated the same as any + // other filter by the DataFusion optimizer, including predicate push down + // (including into scans), simplifications, and similar optimizations. + // + // For example, adding another predicate to the query + let plan = ctx + .sql("SELECT * FROM employee WHERE age > 30") + .await? + .into_optimized_plan()?; + + // We can see the DataFusion Optimizer has combined the filters together + // when we print out the plan + // + // Filter: employee.age > Int32(30) AND employee.position = Utf8("Manager") + // TableScan: employee projection=[name, age, position] + println!("Logical Plan:\n\n{}\n", plan.display_indent()); + + Ok(()) +} + +/// Example AnalyzerRule that implements a very basic "row level access +/// control" +/// +/// In this case, it adds a filter to the plan that removes all managers from +/// the result set. +#[derive(Debug)] +struct RowLevelAccessControl { + /// Models the current access level of the session + /// + /// This is the value of the position column which should be included in the + /// result set.
It is wrapped in a `Mutex` so we can change it during query execution + show_position: Mutex<String>, +} + +impl RowLevelAccessControl { + fn new() -> Self { + Self { + show_position: Mutex::new("Engineer".to_string()), + } + } + + /// return the current position to show, as an expression + fn show_position(&self) -> Expr { + lit(self.show_position.lock().unwrap().clone()) + } + + /// specifies a different position to show in the result set + fn set_show_position(&self, access_level: impl Into<String>) { + *self.show_position.lock().unwrap() = access_level.into(); + } +} + +impl AnalyzerRule for RowLevelAccessControl { + fn analyze(&self, plan: LogicalPlan, _config: &ConfigOptions) -> Result<LogicalPlan> { + // use the TreeNode API to recursively walk the LogicalPlan tree + // and all of its children (inputs) + let transformed_plan = plan.transform(|plan| { + // This closure is called for each LogicalPlan node + // if it is a Scan node, add a filter to remove all managers + if is_employee_table_scan(&plan) { + // Use the LogicalPlanBuilder to add a filter to the plan + let filter = LogicalPlanBuilder::from(plan) + // Filter Expression: position = <show_position> + .filter(col("position").eq(self.show_position()))? + .build()?; + + // `Transformed::yes` signals the plan was changed + Ok(Transformed::yes(filter)) + } else { + // `Transformed::no` signals the plan was not changed + Ok(Transformed::no(plan)) + } + })?; + + // the result of calling transform is a `Transformed` structure which + // contains + // + // 1. a flag signaling if any rewrite took place + // 2. a flag signaling if the recursion stopped early + // 3. The actual transformed data (a LogicalPlan in this case) + // + // This example does not need the value of either flag, so simply + // extract the LogicalPlan "data" + Ok(transformed_plan.data) + } + + fn name(&self) -> &str { + "table_access" + } +} + +fn is_employee_table_scan(plan: &LogicalPlan) -> bool { + if let LogicalPlan::TableScan(scan) = plan { + scan.table_name.table() == "employee" + } else { + false + } +} + +/// Return a RecordBatch with made-up data about fictional employees +fn employee_batch() -> RecordBatch { + let name: ArrayRef = Arc::new(StringArray::from_iter_values([ + "Andy", + "Andrew", + "Oleks", + "Chunchun", + "Xiangpeng", + ])); + let age: ArrayRef = Arc::new(Int32Array::from(vec![11, 22, 33, 44, 55])); + let position = Arc::new(StringArray::from_iter_values([ + "Engineer", "Manager", "Engineer", "Manager", "Engineer", + ])); + RecordBatch::try_from_iter(vec![("name", name), ("age", age), ("position", position)]) + .unwrap() +} From df999d6675238c063ab29a4d11bffc353b271195 Mon Sep 17 00:00:00 2001 From: Nishi <46855953+Nishi46@users.noreply.github.com> Date: Fri, 5 Jul 2024 07:55:27 -0700 Subject: [PATCH 07/26] Replace println! with assert! if possible in DataFusion examples (#11237) * Replace println! with assert! if possible in DataFusion examples * Replace println! with assert!
if possible in DataFusion examples * Update rewrite_expr.rs * port changes to other examples --------- Co-authored-by: Andrew Lamb --- .../examples/optimizer_rule.rs | 57 +++++++++++-------- 1 file changed, 32 insertions(+), 25 deletions(-) diff --git a/datafusion-examples/examples/optimizer_rule.rs b/datafusion-examples/examples/optimizer_rule.rs index 057852946341..b4663b345f64 100644 --- a/datafusion-examples/examples/optimizer_rule.rs +++ b/datafusion-examples/examples/optimizer_rule.rs @@ -19,7 +19,7 @@ use arrow::array::{ArrayRef, Int32Array, RecordBatch, StringArray}; use arrow_schema::DataType; use datafusion::prelude::SessionContext; use datafusion_common::tree_node::{Transformed, TreeNode}; -use datafusion_common::{Result, ScalarValue}; +use datafusion_common::{assert_batches_eq, Result, ScalarValue}; use datafusion_expr::{ BinaryExpr, ColumnarValue, Expr, LogicalPlan, Operator, ScalarUDF, ScalarUDFImpl, Signature, Volatility, }; @@ -54,39 +54,46 @@ pub async fn main() -> Result<()> { // We can see the effect of our rewrite in the output plan: the filter // has been rewritten to `my_eq` - // - // Filter: my_eq(person.age, Int32(22)) - // TableScan: person projection=[name, age] - println!("Logical Plan:\n\n{}\n", plan.display_indent()); + assert_eq!( + plan.display_indent().to_string(), + "Filter: my_eq(person.age, Int32(22))\ + \n TableScan: person projection=[name, age]" + ); // The query below doesn't respect the filter `where age = 22` because // the plan has been rewritten using a UDF which always returns true // // And the output verifies the predicates have been changed (as the my_eq // function always returns true) - // - // +--------+-----+ - // | name | age | - // +--------+-----+ - // | Andy | 11 | - // | Andrew | 22 | - // | Oleks | 33 | - // +--------+-----+ - ctx.sql(sql).await?.show().await?; + assert_batches_eq!( + [ + "+--------+-----+", + "| name | age |", + "+--------+-----+", + "| Andy | 11 |", + "| Andrew | 22 |", + "| Oleks | 33 |", + "+--------+-----+", + ], + &ctx.sql(sql).await?.collect().await? + ); // However, we can see the rule doesn't trigger for queries with predicates // other than `=` - // - // +-------+-----+ - // | name | age | - // +-------+-----+ - // | Andy | 11 | - // | Oleks | 33 | - // +-------+-----+ - ctx.sql("SELECT * FROM person WHERE age <> 22") - .await? - .show() - .await?; + assert_batches_eq!( + [ + "+-------+-----+", + "| name | age |", + "+-------+-----+", + "| Andy | 11 |", + "| Oleks | 33 |", + "+-------+-----+", + ], + &ctx.sql("SELECT * FROM person WHERE age <> 22") + .await? + .collect() + .await?
+ ); Ok(()) } From a0dd0a14426a05bd85a4fac56fcec1fea20b81d6 Mon Sep 17 00:00:00 2001 From: Jonah Gao Date: Fri, 5 Jul 2024 22:56:23 +0800 Subject: [PATCH 08/26] minor: format `Expr::get_type()` (#11267) --- datafusion/expr/src/expr_schema.rs | 64 ++++++++++++++++-------------- 1 file changed, 34 insertions(+), 30 deletions(-) diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index a84931398f5b..45ade5c5993b 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -118,18 +118,18 @@ impl ExprSchemable for Expr { Expr::Unnest(Unnest { expr }) => { let arg_data_type = expr.get_type(schema)?; // Unnest's output type is the inner type of the list - match arg_data_type{ - DataType::List(field) | DataType::LargeList(field) | DataType::FixedSizeList(field, _) =>{ - Ok(field.data_type().clone()) - } - DataType::Struct(_) => { - Ok(arg_data_type) - } + match arg_data_type { + DataType::List(field) + | DataType::LargeList(field) + | DataType::FixedSizeList(field, _) => Ok(field.data_type().clone()), + DataType::Struct(_) => Ok(arg_data_type), DataType::Null => { not_impl_err!("unnest() does not support null yet") } _ => { - plan_err!("unnest() can only be applied to array, struct and null") + plan_err!( + "unnest() can only be applied to array, struct and null" + ) } } } @@ -138,22 +138,22 @@ impl ExprSchemable for Expr { .iter() .map(|e| e.get_type(schema)) .collect::>>()?; - // verify that function is invoked with correct number and type of arguments as defined in `TypeSignature` - data_types_with_scalar_udf(&arg_data_types, func).map_err(|err| { - plan_datafusion_err!( - "{} {}", - err, - utils::generate_signature_error_msg( - func.name(), - func.signature().clone(), - &arg_data_types, - ) - ) - })?; - - // perform additional function arguments validation (due to limited - // expressiveness of `TypeSignature`), then infer return type - Ok(func.return_type_from_exprs(args, schema, &arg_data_types)?) + // verify that function is invoked with correct number and type of arguments as defined in `TypeSignature` + data_types_with_scalar_udf(&arg_data_types, func).map_err(|err| { + plan_datafusion_err!( + "{} {}", + err, + utils::generate_signature_error_msg( + func.name(), + func.signature().clone(), + &arg_data_types, + ) + ) + })?; + + // perform additional function arguments validation (due to limited + // expressiveness of `TypeSignature`), then infer return type + Ok(func.return_type_from_exprs(args, schema, &arg_data_types)?) } Expr::WindowFunction(WindowFunction { fun, args, .. }) => { let data_types = args @@ -166,7 +166,8 @@ impl ExprSchemable for Expr { .collect::>>()?; match fun { WindowFunctionDefinition::AggregateUDF(udf) => { - let new_types = data_types_with_aggregate_udf(&data_types, udf).map_err(|err| { + let new_types = data_types_with_aggregate_udf(&data_types, udf) + .map_err(|err| { plan_datafusion_err!( "{} {}", err, @@ -179,9 +180,7 @@ impl ExprSchemable for Expr { })?; Ok(fun.return_type(&new_types, &nullability)?) } - _ => { - fun.return_type(&data_types, &nullability) - } + _ => fun.return_type(&data_types, &nullability), } } Expr::AggregateFunction(AggregateFunction { func_def, args, .. 
}) => { @@ -198,7 +197,8 @@ impl ExprSchemable for Expr { fun.return_type(&data_types, &nullability) } AggregateFunctionDefinition::UDF(fun) => { - let new_types = data_types_with_aggregate_udf(&data_types, fun).map_err(|err| { + let new_types = data_types_with_aggregate_udf(&data_types, fun) + .map_err(|err| { plan_datafusion_err!( "{} {}", err, @@ -237,7 +237,11 @@ impl ExprSchemable for Expr { Expr::Like { .. } | Expr::SimilarTo { .. } => Ok(DataType::Boolean), Expr::Placeholder(Placeholder { data_type, .. }) => { data_type.clone().ok_or_else(|| { - plan_datafusion_err!("Placeholder type could not be resolved. Make sure that the placeholder is bound to a concrete type, e.g. by providing parameter values.") + plan_datafusion_err!( + "Placeholder type could not be resolved. Make sure that the \ + placeholder is bound to a concrete type, e.g. by providing \ + parameter values." + ) }) } Expr::Wildcard { qualifier } => { From 7df000a333a7d4018e1446ef900f652288b1f104 Mon Sep 17 00:00:00 2001 From: Emil Ejbyfeldt Date: Fri, 5 Jul 2024 19:04:10 +0200 Subject: [PATCH 09/26] Fix hash join for nested types (#11232) * Fixes to 10749 and generalization * Add e2e tests for joins on struct * PR comments * Add Struct to can_hash method * Add explain query as well * Use EXCEPT to trigger failure * Update datafusion/sqllogictest/test_files/joins.slt Co-authored-by: Liang-Chi Hsieh --------- Co-authored-by: Liang-Chi Hsieh --- datafusion/expr/src/utils.rs | 1 + .../physical-plan/src/joins/hash_join.rs | 111 +++++++++++++++++- datafusion/sqllogictest/test_files/joins.slt | 52 ++++++++ 3 files changed, 161 insertions(+), 3 deletions(-) diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index e3b8db676c98..34e007207427 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -864,6 +864,7 @@ pub fn can_hash(data_type: &DataType) -> bool { DataType::List(_) => true, DataType::LargeList(_) => true, DataType::FixedSizeList(_, _) => true, + DataType::Struct(fields) => fields.iter().all(|f| can_hash(f.data_type())), _ => false, } } diff --git a/datafusion/physical-plan/src/joins/hash_join.rs b/datafusion/physical-plan/src/joins/hash_join.rs index b2f9ef560745..c6ef9936b9c5 100644 --- a/datafusion/physical-plan/src/joins/hash_join.rs +++ b/datafusion/physical-plan/src/joins/hash_join.rs @@ -1212,11 +1212,16 @@ fn eq_dyn_null( right: &dyn Array, null_equals_null: bool, ) -> Result { - // Nested datatypes cannot use the underlying not_distinct function and must use a special + // Nested datatypes cannot use the underlying not_distinct/eq function and must use a special // implementation // - if left.data_type().is_nested() && null_equals_null { - return Ok(compare_op_for_nested(&Operator::Eq, &left, &right)?); + if left.data_type().is_nested() { + let op = if null_equals_null { + Operator::IsNotDistinctFrom + } else { + Operator::Eq + }; + return Ok(compare_op_for_nested(&op, &left, &right)?); } match (left.data_type(), right.data_type()) { _ if null_equals_null => not_distinct(&left, &right), @@ -1546,6 +1551,8 @@ mod tests { use arrow::array::{Date32Array, Int32Array, UInt32Builder, UInt64Builder}; use arrow::datatypes::{DataType, Field}; + use arrow_array::StructArray; + use arrow_buffer::NullBuffer; use datafusion_common::{ assert_batches_eq, assert_batches_sorted_eq, assert_contains, exec_err, ScalarValue, @@ -3844,6 +3851,104 @@ mod tests { Ok(()) } + fn build_table_struct( + struct_name: &str, + field_name_and_values: (&str, &Vec>), + nulls: Option, + ) -> Arc { + 
let (field_name, values) = field_name_and_values; + let inner_fields = vec![Field::new(field_name, DataType::Int32, true)]; + let schema = Schema::new(vec![Field::new( + struct_name, + DataType::Struct(inner_fields.clone().into()), + nulls.is_some(), + )]); + + let batch = RecordBatch::try_new( + Arc::new(schema), + vec![Arc::new(StructArray::new( + inner_fields.into(), + vec![Arc::new(Int32Array::from(values.clone()))], + nulls, + ))], + ) + .unwrap(); + let schema_ref = batch.schema(); + Arc::new(MemoryExec::try_new(&[vec![batch]], schema_ref, None).unwrap()) + } + + #[tokio::test] + async fn join_on_struct() -> Result<()> { + let task_ctx = Arc::new(TaskContext::default()); + let left = + build_table_struct("n1", ("a", &vec![None, Some(1), Some(2), Some(3)]), None); + let right = + build_table_struct("n2", ("a", &vec![None, Some(1), Some(2), Some(4)]), None); + let on = vec![( + Arc::new(Column::new_with_schema("n1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("n2", &right.schema())?) as _, + )]; + + let (columns, batches) = + join_collect(left, right, on, &JoinType::Inner, false, task_ctx).await?; + + assert_eq!(columns, vec!["n1", "n2"]); + + let expected = [ + "+--------+--------+", + "| n1 | n2 |", + "+--------+--------+", + "| {a: } | {a: } |", + "| {a: 1} | {a: 1} |", + "| {a: 2} | {a: 2} |", + "+--------+--------+", + ]; + assert_batches_eq!(expected, &batches); + + Ok(()) + } + + #[tokio::test] + async fn join_on_struct_with_nulls() -> Result<()> { + let task_ctx = Arc::new(TaskContext::default()); + let left = + build_table_struct("n1", ("a", &vec![None]), Some(NullBuffer::new_null(1))); + let right = + build_table_struct("n2", ("a", &vec![None]), Some(NullBuffer::new_null(1))); + let on = vec![( + Arc::new(Column::new_with_schema("n1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("n2", &right.schema())?) 
as _, + )]; + + let (_, batches_null_eq) = join_collect( + left.clone(), + right.clone(), + on.clone(), + &JoinType::Inner, + true, + task_ctx.clone(), + ) + .await?; + + let expected_null_eq = [ + "+----+----+", + "| n1 | n2 |", + "+----+----+", + "| | |", + "+----+----+", + ]; + assert_batches_eq!(expected_null_eq, &batches_null_eq); + + let (_, batches_null_neq) = + join_collect(left, right, on, &JoinType::Inner, false, task_ctx).await?; + + let expected_null_neq = + ["+----+----+", "| n1 | n2 |", "+----+----+", "+----+----+"]; + assert_batches_eq!(expected_null_neq, &batches_null_neq); + + Ok(()) + } + /// Returns the column names on the schema fn columns(schema: &Schema) -> Vec { schema.fields().iter().map(|f| f.name().clone()).collect() diff --git a/datafusion/sqllogictest/test_files/joins.slt b/datafusion/sqllogictest/test_files/joins.slt index 3cbeea0f9222..593de07f7d26 100644 --- a/datafusion/sqllogictest/test_files/joins.slt +++ b/datafusion/sqllogictest/test_files/joins.slt @@ -53,6 +53,20 @@ AS VALUES (44, 'x', 3), (55, 'w', 3); +statement ok +CREATE TABLE join_t3(s3 struct) + AS VALUES + (NULL), + (struct(1)), + (struct(2)); + +statement ok +CREATE TABLE join_t4(s4 struct) + AS VALUES + (NULL), + (struct(2)), + (struct(3)); + # Left semi anti join statement ok @@ -1336,6 +1350,44 @@ physical_plan 10)----------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 11)------------MemoryExec: partitions=1, partition_sizes=[1] +# Join on struct +query TT +explain select join_t3.s3, join_t4.s4 +from join_t3 +inner join join_t4 on join_t3.s3 = join_t4.s4 +---- +logical_plan +01)Inner Join: join_t3.s3 = join_t4.s4 +02)--TableScan: join_t3 projection=[s3] +03)--TableScan: join_t4 projection=[s4] +physical_plan +01)CoalesceBatchesExec: target_batch_size=2 +02)--HashJoinExec: mode=Partitioned, join_type=Inner, on=[(s3@0, s4@0)] +03)----CoalesceBatchesExec: target_batch_size=2 +04)------RepartitionExec: partitioning=Hash([s3@0], 2), input_partitions=2 +05)--------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +06)----------MemoryExec: partitions=1, partition_sizes=[1] +07)----CoalesceBatchesExec: target_batch_size=2 +08)------RepartitionExec: partitioning=Hash([s4@0], 2), input_partitions=2 +09)--------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +10)----------MemoryExec: partitions=1, partition_sizes=[1] + +query ?? +select join_t3.s3, join_t4.s4 +from join_t3 +inner join join_t4 on join_t3.s3 = join_t4.s4 +---- +{id: 2} {id: 2} + +# join with struct key and nulls +# Note that intersect or except applies `null_equals_null` as true for Join. +query ? +SELECT * FROM join_t3 +EXCEPT +SELECT * FROM join_t4 +---- +{id: 1} + query TT EXPLAIN select count(*) From 13cb65e44136711befb87dd75fb8b41f814af16f Mon Sep 17 00:00:00 2001 From: Piotr Findeisen Date: Sat, 6 Jul 2024 06:05:38 +0200 Subject: [PATCH 10/26] Infer count() aggregation is not null (#11256) * Fix test function name typo * Improve formatting * Infer count() aggregation is not null `count([DISTINCT [expr]])` aggregate function never returns null. Infer non-nullness of such aggregate expression. This allows elimination of the HAVING filter for a query such as SELECT ... count(*) AS c FROM ... GROUP BY ... 
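As a quick sanity check outside the test suite, the struct-keyed join added here can also be driven through the public API. The following sketch (not part of this patch; the table names t3/t4 and the tokio runtime setup are assumptions of the sketch) mirrors the join_t3/join_t4 sqllogictest case above:

use datafusion::error::Result;
use datafusion::prelude::SessionContext;

#[tokio::main]
async fn main() -> Result<()> {
    let ctx = SessionContext::new();
    // Two tables with struct-typed join keys, as in the slt file above.
    // `ctx.sql` executes DDL statements eagerly, so the tables exist afterwards.
    ctx.sql("CREATE TABLE t3(s3 struct<id int>) AS VALUES (NULL), (struct(1)), (struct(2))")
        .await?;
    ctx.sql("CREATE TABLE t4(s4 struct<id int>) AS VALUES (NULL), (struct(2)), (struct(3))")
        .await?;
    // The hash join now hashes and compares the nested struct values directly.
    // With the default null_equals_null = false, the NULL structs do not match.
    ctx.sql("SELECT t3.s3, t4.s4 FROM t3 INNER JOIN t4 ON t3.s3 = t4.s4")
        .await?
        .show()
        .await?;
    Ok(())
}

As in the joins.slt expectations above, the only surviving row is the overlapping key, printed as {id: 2} in both columns.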
HAVING c IS NOT NULL --- datafusion/core/tests/dataframe/mod.rs | 1 + datafusion/expr/src/expr_schema.rs | 3 ++ .../src/analyzer/count_wildcard_rule.rs | 32 +++++------ .../src/single_distinct_to_groupby.rs | 54 +++++++++---------- .../optimizer/tests/optimizer_integration.rs | 15 ++++++ 5 files changed, 62 insertions(+), 43 deletions(-) diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs index c3bc2fcca2b5..e46a92e92818 100644 --- a/datafusion/core/tests/dataframe/mod.rs +++ b/datafusion/core/tests/dataframe/mod.rs @@ -234,6 +234,7 @@ async fn test_count_wildcard_on_aggregate() -> Result<()> { Ok(()) } + #[tokio::test] async fn test_count_wildcard_on_where_scalar_subquery() -> Result<()> { let ctx = create_join_context()?; diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index 45ade5c5993b..1df5d6c4d736 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -330,6 +330,9 @@ impl ExprSchemable for Expr { match func_def { AggregateFunctionDefinition::BuiltIn(fun) => fun.nullable(), // TODO: UDF should be able to customize nullability + AggregateFunctionDefinition::UDF(udf) if udf.name() == "count" => { + Ok(false) + } AggregateFunctionDefinition::UDF(_) => Ok(true), } } diff --git a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs index 34f9802b1fd9..959ffdaaa212 100644 --- a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs +++ b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs @@ -127,9 +127,9 @@ mod tests { .project(vec![count(wildcard())])? .sort(vec![count(wildcard()).sort(true, false)])? .build()?; - let expected = "Sort: count(*) ASC NULLS LAST [count(*):Int64;N]\ - \n Projection: count(*) [count(*):Int64;N]\ - \n Aggregate: groupBy=[[test.b]], aggr=[[count(Int64(1)) AS count(*)]] [b:UInt32, count(*):Int64;N]\ + let expected = "Sort: count(*) ASC NULLS LAST [count(*):Int64]\ + \n Projection: count(*) [count(*):Int64]\ + \n Aggregate: groupBy=[[test.b]], aggr=[[count(Int64(1)) AS count(*)]] [b:UInt32, count(*):Int64]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; assert_plan_eq(plan, expected) } @@ -152,9 +152,9 @@ mod tests { .build()?; let expected = "Filter: t1.a IN () [a:UInt32, b:UInt32, c:UInt32]\ - \n Subquery: [count(*):Int64;N]\ - \n Projection: count(*) [count(*):Int64;N]\ - \n Aggregate: groupBy=[[]], aggr=[[count(Int64(1)) AS count(*)]] [count(*):Int64;N]\ + \n Subquery: [count(*):Int64]\ + \n Projection: count(*) [count(*):Int64]\ + \n Aggregate: groupBy=[[]], aggr=[[count(Int64(1)) AS count(*)]] [count(*):Int64]\ \n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]\ \n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]"; assert_plan_eq(plan, expected) @@ -175,9 +175,9 @@ mod tests { .build()?; let expected = "Filter: EXISTS () [a:UInt32, b:UInt32, c:UInt32]\ - \n Subquery: [count(*):Int64;N]\ - \n Projection: count(*) [count(*):Int64;N]\ - \n Aggregate: groupBy=[[]], aggr=[[count(Int64(1)) AS count(*)]] [count(*):Int64;N]\ + \n Subquery: [count(*):Int64]\ + \n Projection: count(*) [count(*):Int64]\ + \n Aggregate: groupBy=[[]], aggr=[[count(Int64(1)) AS count(*)]] [count(*):Int64]\ \n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]\ \n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]"; assert_plan_eq(plan, expected) @@ -207,9 +207,9 @@ mod tests { let expected = "Projection: t1.a, t1.b [a:UInt32, b:UInt32]\ \n Filter: () > UInt8(0) [a:UInt32, b:UInt32, c:UInt32]\ - \n Subquery: 
[count(Int64(1)):Int64;N]\ - \n Projection: count(Int64(1)) [count(Int64(1)):Int64;N]\ - \n Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] [count(Int64(1)):Int64;N]\ + \n Subquery: [count(Int64(1)):Int64]\ + \n Projection: count(Int64(1)) [count(Int64(1)):Int64]\ + \n Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] [count(Int64(1)):Int64]\ \n Filter: outer_ref(t1.a) = t2.a [a:UInt32, b:UInt32, c:UInt32]\ \n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]\ \n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]"; @@ -235,7 +235,7 @@ mod tests { .project(vec![count(wildcard())])? .build()?; - let expected = "Projection: count(Int64(1)) AS count(*) [count(*):Int64;N]\ + let expected = "Projection: count(Int64(1)) AS count(*) [count(*):Int64]\ \n WindowAggr: windowExpr=[[count(Int64(1)) ORDER BY [test.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING AS count(*) ORDER BY [test.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING]] [a:UInt32, b:UInt32, c:UInt32, count(*) ORDER BY [test.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING:Int64;N]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; assert_plan_eq(plan, expected) @@ -249,8 +249,8 @@ mod tests { .project(vec![count(wildcard())])? .build()?; - let expected = "Projection: count(*) [count(*):Int64;N]\ - \n Aggregate: groupBy=[[]], aggr=[[count(Int64(1)) AS count(*)]] [count(*):Int64;N]\ + let expected = "Projection: count(*) [count(*):Int64]\ + \n Aggregate: groupBy=[[]], aggr=[[count(Int64(1)) AS count(*)]] [count(*):Int64]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; assert_plan_eq(plan, expected) } @@ -272,7 +272,7 @@ mod tests { .project(vec![count(wildcard())])? .build()?; - let expected = "Projection: count(Int64(1)) AS count(*) [count(*):Int64;N]\ + let expected = "Projection: count(Int64(1)) AS count(*) [count(*):Int64]\ \n Aggregate: groupBy=[[]], aggr=[[MAX(count(Int64(1))) AS MAX(count(*))]] [MAX(count(*)):Int64;N]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; assert_plan_eq(plan, expected) diff --git a/datafusion/optimizer/src/single_distinct_to_groupby.rs b/datafusion/optimizer/src/single_distinct_to_groupby.rs index b3562b7065e1..7c66d659cbaf 100644 --- a/datafusion/optimizer/src/single_distinct_to_groupby.rs +++ b/datafusion/optimizer/src/single_distinct_to_groupby.rs @@ -396,8 +396,8 @@ mod tests { .build()?; // Should work - let expected = "Projection: count(alias1) AS count(DISTINCT test.b) [count(DISTINCT test.b):Int64;N]\ - \n Aggregate: groupBy=[[]], aggr=[[count(alias1)]] [count(alias1):Int64;N]\ + let expected = "Projection: count(alias1) AS count(DISTINCT test.b) [count(DISTINCT test.b):Int64]\ + \n Aggregate: groupBy=[[]], aggr=[[count(alias1)]] [count(alias1):Int64]\ \n Aggregate: groupBy=[[test.b AS alias1]], aggr=[[]] [alias1:UInt32]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; @@ -419,7 +419,7 @@ mod tests { .build()?; // Should not be optimized - let expected = "Aggregate: groupBy=[[GROUPING SETS ((test.a), (test.b))]], aggr=[[count(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, count(DISTINCT test.c):Int64;N]\ + let expected = "Aggregate: groupBy=[[GROUPING SETS ((test.a), (test.b))]], aggr=[[count(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, count(DISTINCT test.c):Int64]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; assert_optimized_plan_equal(plan, expected) @@ -437,7 +437,7 @@ mod tests { .build()?; // Should not be optimized - let expected = "Aggregate: groupBy=[[CUBE (test.a, test.b)]], aggr=[[count(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, 
count(DISTINCT test.c):Int64;N]\ + let expected = "Aggregate: groupBy=[[CUBE (test.a, test.b)]], aggr=[[count(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, count(DISTINCT test.c):Int64]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; assert_optimized_plan_equal(plan, expected) @@ -456,7 +456,7 @@ mod tests { .build()?; // Should not be optimized - let expected = "Aggregate: groupBy=[[ROLLUP (test.a, test.b)]], aggr=[[count(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, count(DISTINCT test.c):Int64;N]\ + let expected = "Aggregate: groupBy=[[ROLLUP (test.a, test.b)]], aggr=[[count(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, count(DISTINCT test.c):Int64]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; assert_optimized_plan_equal(plan, expected) @@ -470,8 +470,8 @@ mod tests { .aggregate(Vec::::new(), vec![count_distinct(lit(2) * col("b"))])? .build()?; - let expected = "Projection: count(alias1) AS count(DISTINCT Int32(2) * test.b) [count(DISTINCT Int32(2) * test.b):Int64;N]\ - \n Aggregate: groupBy=[[]], aggr=[[count(alias1)]] [count(alias1):Int64;N]\ + let expected = "Projection: count(alias1) AS count(DISTINCT Int32(2) * test.b) [count(DISTINCT Int32(2) * test.b):Int64]\ + \n Aggregate: groupBy=[[]], aggr=[[count(alias1)]] [count(alias1):Int64]\ \n Aggregate: groupBy=[[Int32(2) * test.b AS alias1]], aggr=[[]] [alias1:Int32]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; @@ -487,8 +487,8 @@ mod tests { .build()?; // Should work - let expected = "Projection: test.a, count(alias1) AS count(DISTINCT test.b) [a:UInt32, count(DISTINCT test.b):Int64;N]\ - \n Aggregate: groupBy=[[test.a]], aggr=[[count(alias1)]] [a:UInt32, count(alias1):Int64;N]\ + let expected = "Projection: test.a, count(alias1) AS count(DISTINCT test.b) [a:UInt32, count(DISTINCT test.b):Int64]\ + \n Aggregate: groupBy=[[test.a]], aggr=[[count(alias1)]] [a:UInt32, count(alias1):Int64]\ \n Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[]] [a:UInt32, alias1:UInt32]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; @@ -507,7 +507,7 @@ mod tests { .build()?; // Do nothing - let expected = "Aggregate: groupBy=[[test.a]], aggr=[[count(DISTINCT test.b), count(DISTINCT test.c)]] [a:UInt32, count(DISTINCT test.b):Int64;N, count(DISTINCT test.c):Int64;N]\ + let expected = "Aggregate: groupBy=[[test.a]], aggr=[[count(DISTINCT test.b), count(DISTINCT test.c)]] [a:UInt32, count(DISTINCT test.b):Int64, count(DISTINCT test.c):Int64]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; assert_optimized_plan_equal(plan, expected) @@ -534,8 +534,8 @@ mod tests { )? 
.build()?; // Should work - let expected = "Projection: test.a, count(alias1) AS count(DISTINCT test.b), MAX(alias1) AS MAX(DISTINCT test.b) [a:UInt32, count(DISTINCT test.b):Int64;N, MAX(DISTINCT test.b):UInt32;N]\ - \n Aggregate: groupBy=[[test.a]], aggr=[[count(alias1), MAX(alias1)]] [a:UInt32, count(alias1):Int64;N, MAX(alias1):UInt32;N]\ + let expected = "Projection: test.a, count(alias1) AS count(DISTINCT test.b), MAX(alias1) AS MAX(DISTINCT test.b) [a:UInt32, count(DISTINCT test.b):Int64, MAX(DISTINCT test.b):UInt32;N]\ + \n Aggregate: groupBy=[[test.a]], aggr=[[count(alias1), MAX(alias1)]] [a:UInt32, count(alias1):Int64, MAX(alias1):UInt32;N]\ \n Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[]] [a:UInt32, alias1:UInt32]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; @@ -554,7 +554,7 @@ mod tests { .build()?; // Do nothing - let expected = "Aggregate: groupBy=[[test.a]], aggr=[[count(DISTINCT test.b), count(test.c)]] [a:UInt32, count(DISTINCT test.b):Int64;N, count(test.c):Int64;N]\ + let expected = "Aggregate: groupBy=[[test.a]], aggr=[[count(DISTINCT test.b), count(test.c)]] [a:UInt32, count(DISTINCT test.b):Int64, count(test.c):Int64]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; assert_optimized_plan_equal(plan, expected) @@ -569,8 +569,8 @@ mod tests { .build()?; // Should work - let expected = "Projection: group_alias_0 AS test.a + Int32(1), count(alias1) AS count(DISTINCT test.c) [test.a + Int32(1):Int32, count(DISTINCT test.c):Int64;N]\ - \n Aggregate: groupBy=[[group_alias_0]], aggr=[[count(alias1)]] [group_alias_0:Int32, count(alias1):Int64;N]\ + let expected = "Projection: group_alias_0 AS test.a + Int32(1), count(alias1) AS count(DISTINCT test.c) [test.a + Int32(1):Int32, count(DISTINCT test.c):Int64]\ + \n Aggregate: groupBy=[[group_alias_0]], aggr=[[count(alias1)]] [group_alias_0:Int32, count(alias1):Int64]\ \n Aggregate: groupBy=[[test.a + Int32(1) AS group_alias_0, test.c AS alias1]], aggr=[[]] [group_alias_0:Int32, alias1:UInt32]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; @@ -599,8 +599,8 @@ mod tests { )? .build()?; // Should work - let expected = "Projection: test.a, sum(alias2) AS sum(test.c), count(alias1) AS count(DISTINCT test.b), MAX(alias1) AS MAX(DISTINCT test.b) [a:UInt32, sum(test.c):UInt64;N, count(DISTINCT test.b):Int64;N, MAX(DISTINCT test.b):UInt32;N]\ - \n Aggregate: groupBy=[[test.a]], aggr=[[sum(alias2), count(alias1), MAX(alias1)]] [a:UInt32, sum(alias2):UInt64;N, count(alias1):Int64;N, MAX(alias1):UInt32;N]\ + let expected = "Projection: test.a, sum(alias2) AS sum(test.c), count(alias1) AS count(DISTINCT test.b), MAX(alias1) AS MAX(DISTINCT test.b) [a:UInt32, sum(test.c):UInt64;N, count(DISTINCT test.b):Int64, MAX(DISTINCT test.b):UInt32;N]\ + \n Aggregate: groupBy=[[test.a]], aggr=[[sum(alias2), count(alias1), MAX(alias1)]] [a:UInt32, sum(alias2):UInt64;N, count(alias1):Int64, MAX(alias1):UInt32;N]\ \n Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[sum(test.c) AS alias2]] [a:UInt32, alias1:UInt32, alias2:UInt64;N]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; @@ -608,7 +608,7 @@ mod tests { } #[test] - fn one_distinctand_and_two_common() -> Result<()> { + fn one_distinct_and_two_common() -> Result<()> { let table_scan = test_table_scan()?; let plan = LogicalPlanBuilder::from(table_scan) @@ -618,8 +618,8 @@ mod tests { )? 
.build()?; // Should work - let expected = "Projection: test.a, sum(alias2) AS sum(test.c), MAX(alias3) AS MAX(test.c), count(alias1) AS count(DISTINCT test.b) [a:UInt32, sum(test.c):UInt64;N, MAX(test.c):UInt32;N, count(DISTINCT test.b):Int64;N]\ - \n Aggregate: groupBy=[[test.a]], aggr=[[sum(alias2), MAX(alias3), count(alias1)]] [a:UInt32, sum(alias2):UInt64;N, MAX(alias3):UInt32;N, count(alias1):Int64;N]\ + let expected = "Projection: test.a, sum(alias2) AS sum(test.c), MAX(alias3) AS MAX(test.c), count(alias1) AS count(DISTINCT test.b) [a:UInt32, sum(test.c):UInt64;N, MAX(test.c):UInt32;N, count(DISTINCT test.b):Int64]\ + \n Aggregate: groupBy=[[test.a]], aggr=[[sum(alias2), MAX(alias3), count(alias1)]] [a:UInt32, sum(alias2):UInt64;N, MAX(alias3):UInt32;N, count(alias1):Int64]\ \n Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[sum(test.c) AS alias2, MAX(test.c) AS alias3]] [a:UInt32, alias1:UInt32, alias2:UInt64;N, alias3:UInt32;N]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; @@ -637,8 +637,8 @@ mod tests { )? .build()?; // Should work - let expected = "Projection: test.c, MIN(alias2) AS MIN(test.a), count(alias1) AS count(DISTINCT test.b) [c:UInt32, MIN(test.a):UInt32;N, count(DISTINCT test.b):Int64;N]\ - \n Aggregate: groupBy=[[test.c]], aggr=[[MIN(alias2), count(alias1)]] [c:UInt32, MIN(alias2):UInt32;N, count(alias1):Int64;N]\ + let expected = "Projection: test.c, MIN(alias2) AS MIN(test.a), count(alias1) AS count(DISTINCT test.b) [c:UInt32, MIN(test.a):UInt32;N, count(DISTINCT test.b):Int64]\ + \n Aggregate: groupBy=[[test.c]], aggr=[[MIN(alias2), count(alias1)]] [c:UInt32, MIN(alias2):UInt32;N, count(alias1):Int64]\ \n Aggregate: groupBy=[[test.c, test.b AS alias1]], aggr=[[MIN(test.a) AS alias2]] [c:UInt32, alias1:UInt32, alias2:UInt32;N]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; @@ -662,7 +662,7 @@ mod tests { .aggregate(vec![col("c")], vec![expr, count_distinct(col("b"))])? .build()?; // Do nothing - let expected = "Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a) FILTER (WHERE test.a > Int32(5)), count(DISTINCT test.b)]] [c:UInt32, sum(test.a) FILTER (WHERE test.a > Int32(5)):UInt64;N, count(DISTINCT test.b):Int64;N]\ + let expected = "Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a) FILTER (WHERE test.a > Int32(5)), count(DISTINCT test.b)]] [c:UInt32, sum(test.a) FILTER (WHERE test.a > Int32(5)):UInt64;N, count(DISTINCT test.b):Int64]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; assert_optimized_plan_equal(plan, expected) @@ -682,7 +682,7 @@ mod tests { .aggregate(vec![col("c")], vec![sum(col("a")), expr])? .build()?; // Do nothing - let expected = "Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a), count(DISTINCT test.a) FILTER (WHERE test.a > Int32(5))]] [c:UInt32, sum(test.a):UInt64;N, count(DISTINCT test.a) FILTER (WHERE test.a > Int32(5)):Int64;N]\ + let expected = "Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a), count(DISTINCT test.a) FILTER (WHERE test.a > Int32(5))]] [c:UInt32, sum(test.a):UInt64;N, count(DISTINCT test.a) FILTER (WHERE test.a > Int32(5)):Int64]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; assert_optimized_plan_equal(plan, expected) @@ -705,7 +705,7 @@ mod tests { .aggregate(vec![col("c")], vec![expr, count_distinct(col("b"))])? 
.build()?; // Do nothing - let expected = "Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a) ORDER BY [test.a], count(DISTINCT test.b)]] [c:UInt32, sum(test.a) ORDER BY [test.a]:UInt64;N, count(DISTINCT test.b):Int64;N]\ + let expected = "Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a) ORDER BY [test.a], count(DISTINCT test.b)]] [c:UInt32, sum(test.a) ORDER BY [test.a]:UInt64;N, count(DISTINCT test.b):Int64]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; assert_optimized_plan_equal(plan, expected) @@ -725,7 +725,7 @@ mod tests { .aggregate(vec![col("c")], vec![sum(col("a")), expr])? .build()?; // Do nothing - let expected = "Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a), count(DISTINCT test.a) ORDER BY [test.a ASC NULLS LAST]]] [c:UInt32, sum(test.a):UInt64;N, count(DISTINCT test.a) ORDER BY [test.a ASC NULLS LAST]:Int64;N]\ + let expected = "Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a), count(DISTINCT test.a) ORDER BY [test.a ASC NULLS LAST]]] [c:UInt32, sum(test.a):UInt64;N, count(DISTINCT test.a) ORDER BY [test.a ASC NULLS LAST]:Int64]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; assert_optimized_plan_equal(plan, expected) @@ -746,7 +746,7 @@ mod tests { .aggregate(vec![col("c")], vec![sum(col("a")), expr])? .build()?; // Do nothing - let expected = "Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a), count(DISTINCT test.a) FILTER (WHERE test.a > Int32(5)) ORDER BY [test.a ASC NULLS LAST]]] [c:UInt32, sum(test.a):UInt64;N, count(DISTINCT test.a) FILTER (WHERE test.a > Int32(5)) ORDER BY [test.a ASC NULLS LAST]:Int64;N]\ + let expected = "Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a), count(DISTINCT test.a) FILTER (WHERE test.a > Int32(5)) ORDER BY [test.a ASC NULLS LAST]]] [c:UInt32, sum(test.a):UInt64;N, count(DISTINCT test.a) FILTER (WHERE test.a > Int32(5)) ORDER BY [test.a ASC NULLS LAST]:Int64]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; assert_optimized_plan_equal(plan, expected) diff --git a/datafusion/optimizer/tests/optimizer_integration.rs b/datafusion/optimizer/tests/optimizer_integration.rs index c501d5aaa4bf..c0863839dba1 100644 --- a/datafusion/optimizer/tests/optimizer_integration.rs +++ b/datafusion/optimizer/tests/optimizer_integration.rs @@ -294,6 +294,21 @@ fn eliminate_nested_filters() { assert_eq!(expected, format!("{plan:?}")); } +#[test] +fn eliminate_redundant_null_check_on_count() { + let sql = "\ + SELECT col_int32, count(*) c + FROM test + GROUP BY col_int32 + HAVING c IS NOT NULL"; + let plan = test_sql(sql).unwrap(); + let expected = "\ + Projection: test.col_int32, count(*) AS c\ + \n Aggregate: groupBy=[[test.col_int32]], aggr=[[count(Int64(1)) AS count(*)]]\ + \n TableScan: test projection=[col_int32]"; + assert_eq!(expected, format!("{plan:?}")); +} + #[test] fn test_propagate_empty_relation_inner_join_and_unions() { let sql = "\ From 1b3a7af673e5b590da94cea74892ee2c0c097848 Mon Sep 17 00:00:00 2001 From: Piotr Findeisen Date: Sat, 6 Jul 2024 13:08:58 +0200 Subject: [PATCH 11/26] Fix count() docs around including null values (#11293) The count aggregate was documented to count null values, but it does not do that. The implemented behavior is correct, so let's fix docs. 
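To illustrate the corrected description with a hypothetical table `t` (not part of this patch) whose column `c` holds the values `1`, `NULL`, `2`:

```sql
SELECT count(c) FROM t; -- returns 2: NULL values are not counted
SELECT count(*) FROM t; -- returns 3: counts all rows, including those where c is NULL
```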
--- docs/source/user-guide/sql/aggregate_functions.md | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/docs/source/user-guide/sql/aggregate_functions.md b/docs/source/user-guide/sql/aggregate_functions.md index 427a7bf130a7..edb0e1d0c9f0 100644 --- a/docs/source/user-guide/sql/aggregate_functions.md +++ b/docs/source/user-guide/sql/aggregate_functions.md @@ -123,11 +123,9 @@ bool_or(expression) ### `count` -Returns the number of rows in the specified column. +Returns the number of non-null values in the specified column. -Count includes _null_ values in the total count. -To exclude _null_ values from the total count, include ` IS NOT NULL` -in the `WHERE` clause. +To include _null_ values in the total count, use `count(*)`. ``` count(expression) From 5657886121e5cec4d53d39d6771427c1dd3f910f Mon Sep 17 00:00:00 2001 From: Piotr Findeisen Date: Sat, 6 Jul 2024 13:10:03 +0200 Subject: [PATCH 12/26] Remove unnecessary qualified names (#11292) Leverage existing imports. --- datafusion-examples/examples/advanced_udaf.rs | 2 +- .../src/approx_distinct.rs | 2 +- .../src/approx_percentile_cont.rs | 2 +- datafusion/functions-aggregate/src/average.rs | 2 +- .../functions-aggregate/src/bit_and_or_xor.rs | 8 +- datafusion/functions-aggregate/src/count.rs | 20 ++--- .../physical-expr-common/src/binary_map.rs | 4 +- .../src/expressions/cast.rs | 89 +++++++++---------- .../physical-expr/src/expressions/case.rs | 2 +- 9 files changed, 62 insertions(+), 69 deletions(-) diff --git a/datafusion-examples/examples/advanced_udaf.rs b/datafusion-examples/examples/advanced_udaf.rs index 2c672a18a738..48da09a51236 100644 --- a/datafusion-examples/examples/advanced_udaf.rs +++ b/datafusion-examples/examples/advanced_udaf.rs @@ -92,7 +92,7 @@ impl AggregateUDFImpl for GeoMeanUdaf { } /// This is the description of the state. accumulator's state() must match the types here. - fn state_fields(&self, args: StateFieldsArgs) -> Result> { + fn state_fields(&self, args: StateFieldsArgs) -> Result> { Ok(vec![ Field::new("prod", args.return_type.clone(), true), Field::new("n", DataType::UInt32, true), diff --git a/datafusion/functions-aggregate/src/approx_distinct.rs b/datafusion/functions-aggregate/src/approx_distinct.rs index 6f1e97a16380..7c6aef9944f6 100644 --- a/datafusion/functions-aggregate/src/approx_distinct.rs +++ b/datafusion/functions-aggregate/src/approx_distinct.rs @@ -211,7 +211,7 @@ where impl Accumulator for NumericHLLAccumulator where - T: ArrowPrimitiveType + std::fmt::Debug, + T: ArrowPrimitiveType + Debug, T::Native: Hash, { fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { diff --git a/datafusion/functions-aggregate/src/approx_percentile_cont.rs b/datafusion/functions-aggregate/src/approx_percentile_cont.rs index 5ae5684d9cab..bbe7d21e2486 100644 --- a/datafusion/functions-aggregate/src/approx_percentile_cont.rs +++ b/datafusion/functions-aggregate/src/approx_percentile_cont.rs @@ -191,7 +191,7 @@ impl AggregateUDFImpl for ApproxPercentileCont { } #[allow(rustdoc::private_intra_doc_links)] - /// See [`datafusion_physical_expr_common::aggregate::tdigest::TDigest::to_scalar_state()`] for a description of the serialised + /// See [`TDigest::to_scalar_state()`] for a description of the serialised /// state. 
fn state_fields( &self, diff --git a/datafusion/functions-aggregate/src/average.rs b/datafusion/functions-aggregate/src/average.rs index 1dc1f10afce6..18642fb84329 100644 --- a/datafusion/functions-aggregate/src/average.rs +++ b/datafusion/functions-aggregate/src/average.rs @@ -489,7 +489,7 @@ where .into_iter() .zip(counts.into_iter()) .map(|(sum, count)| (self.avg_fn)(sum, count)) - .collect::>>()?; + .collect::>>()?; PrimitiveArray::new(averages.into(), Some(nulls)) // no copy .with_data_type(self.return_data_type.clone()) }; diff --git a/datafusion/functions-aggregate/src/bit_and_or_xor.rs b/datafusion/functions-aggregate/src/bit_and_or_xor.rs index ba9964270443..9224b06e407a 100644 --- a/datafusion/functions-aggregate/src/bit_and_or_xor.rs +++ b/datafusion/functions-aggregate/src/bit_and_or_xor.rs @@ -245,7 +245,7 @@ struct BitAndAccumulator { } impl std::fmt::Debug for BitAndAccumulator { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { write!(f, "BitAndAccumulator({})", T::DATA_TYPE) } } @@ -290,7 +290,7 @@ struct BitOrAccumulator { } impl std::fmt::Debug for BitOrAccumulator { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { write!(f, "BitOrAccumulator({})", T::DATA_TYPE) } } @@ -335,7 +335,7 @@ struct BitXorAccumulator { } impl std::fmt::Debug for BitXorAccumulator { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { write!(f, "BitXorAccumulator({})", T::DATA_TYPE) } } @@ -380,7 +380,7 @@ struct DistinctBitXorAccumulator { } impl std::fmt::Debug for DistinctBitXorAccumulator { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { write!(f, "DistinctBitXorAccumulator({})", T::DATA_TYPE) } } diff --git a/datafusion/functions-aggregate/src/count.rs b/datafusion/functions-aggregate/src/count.rs index 0fc8e32d7240..bd0155df0271 100644 --- a/datafusion/functions-aggregate/src/count.rs +++ b/datafusion/functions-aggregate/src/count.rs @@ -62,17 +62,15 @@ make_udaf_expr_and_func!( count_udaf ); -pub fn count_distinct(expr: Expr) -> datafusion_expr::Expr { - datafusion_expr::Expr::AggregateFunction( - datafusion_expr::expr::AggregateFunction::new_udf( - count_udaf(), - vec![expr], - true, - None, - None, - None, - ), - ) +pub fn count_distinct(expr: Expr) -> Expr { + Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf( + count_udaf(), + vec![expr], + true, + None, + None, + None, + )) } pub struct Count { diff --git a/datafusion/physical-expr-common/src/binary_map.rs b/datafusion/physical-expr-common/src/binary_map.rs index 6d5ba737a1df..bff571f5b5be 100644 --- a/datafusion/physical-expr-common/src/binary_map.rs +++ b/datafusion/physical-expr-common/src/binary_map.rs @@ -255,7 +255,7 @@ where /// the same output type pub fn take(&mut self) -> Self { let mut new_self = Self::new(self.output_type); - std::mem::swap(self, &mut new_self); + mem::swap(self, &mut new_self); new_self } @@ -538,7 +538,7 @@ where /// this set, not including `self` pub fn size(&self) -> usize { self.map_size - + self.buffer.capacity() * std::mem::size_of::() + + self.buffer.capacity() * mem::size_of::() + self.offsets.allocated_size() + self.hashes_buffer.allocated_size() } diff --git a/datafusion/physical-expr-common/src/expressions/cast.rs 
b/datafusion/physical-expr-common/src/expressions/cast.rs index 31b96889fd62..8aba33932c56 100644 --- a/datafusion/physical-expr-common/src/expressions/cast.rs +++ b/datafusion/physical-expr-common/src/expressions/cast.rs @@ -366,9 +366,9 @@ mod tests { generic_decimal_to_other_test_cast!( decimal_array, - DataType::Decimal128(10, 3), + Decimal128(10, 3), Decimal128Array, - DataType::Decimal128(20, 6), + Decimal128(20, 6), [ Some(1_234_000), Some(2_222_000), @@ -387,9 +387,9 @@ mod tests { generic_decimal_to_other_test_cast!( decimal_array, - DataType::Decimal128(10, 3), + Decimal128(10, 3), Decimal128Array, - DataType::Decimal128(10, 2), + Decimal128(10, 2), [Some(123), Some(222), Some(0), Some(400), Some(500), None], None ); @@ -408,9 +408,9 @@ mod tests { .with_precision_and_scale(10, 0)?; generic_decimal_to_other_test_cast!( decimal_array, - DataType::Decimal128(10, 0), + Decimal128(10, 0), Int8Array, - DataType::Int8, + Int8, [ Some(1_i8), Some(2_i8), @@ -430,9 +430,9 @@ mod tests { .with_precision_and_scale(10, 0)?; generic_decimal_to_other_test_cast!( decimal_array, - DataType::Decimal128(10, 0), + Decimal128(10, 0), Int16Array, - DataType::Int16, + Int16, [ Some(1_i16), Some(2_i16), @@ -452,9 +452,9 @@ mod tests { .with_precision_and_scale(10, 0)?; generic_decimal_to_other_test_cast!( decimal_array, - DataType::Decimal128(10, 0), + Decimal128(10, 0), Int32Array, - DataType::Int32, + Int32, [ Some(1_i32), Some(2_i32), @@ -473,9 +473,9 @@ mod tests { .with_precision_and_scale(10, 0)?; generic_decimal_to_other_test_cast!( decimal_array, - DataType::Decimal128(10, 0), + Decimal128(10, 0), Int64Array, - DataType::Int64, + Int64, [ Some(1_i64), Some(2_i64), @@ -503,9 +503,9 @@ mod tests { .with_precision_and_scale(10, 3)?; generic_decimal_to_other_test_cast!( decimal_array, - DataType::Decimal128(10, 3), + Decimal128(10, 3), Float32Array, - DataType::Float32, + Float32, [ Some(1.234_f32), Some(2.222_f32), @@ -524,9 +524,9 @@ mod tests { .with_precision_and_scale(20, 6)?; generic_decimal_to_other_test_cast!( decimal_array, - DataType::Decimal128(20, 6), + Decimal128(20, 6), Float64Array, - DataType::Float64, + Float64, [ Some(0.001234_f64), Some(0.002222_f64), @@ -545,10 +545,10 @@ mod tests { // int8 generic_test_cast!( Int8Array, - DataType::Int8, + Int8, vec![1, 2, 3, 4, 5], Decimal128Array, - DataType::Decimal128(3, 0), + Decimal128(3, 0), [Some(1), Some(2), Some(3), Some(4), Some(5)], None ); @@ -556,10 +556,10 @@ mod tests { // int16 generic_test_cast!( Int16Array, - DataType::Int16, + Int16, vec![1, 2, 3, 4, 5], Decimal128Array, - DataType::Decimal128(5, 0), + Decimal128(5, 0), [Some(1), Some(2), Some(3), Some(4), Some(5)], None ); @@ -567,10 +567,10 @@ mod tests { // int32 generic_test_cast!( Int32Array, - DataType::Int32, + Int32, vec![1, 2, 3, 4, 5], Decimal128Array, - DataType::Decimal128(10, 0), + Decimal128(10, 0), [Some(1), Some(2), Some(3), Some(4), Some(5)], None ); @@ -578,10 +578,10 @@ mod tests { // int64 generic_test_cast!( Int64Array, - DataType::Int64, + Int64, vec![1, 2, 3, 4, 5], Decimal128Array, - DataType::Decimal128(20, 0), + Decimal128(20, 0), [Some(1), Some(2), Some(3), Some(4), Some(5)], None ); @@ -589,10 +589,10 @@ mod tests { // int64 to different scale generic_test_cast!( Int64Array, - DataType::Int64, + Int64, vec![1, 2, 3, 4, 5], Decimal128Array, - DataType::Decimal128(20, 2), + Decimal128(20, 2), [Some(100), Some(200), Some(300), Some(400), Some(500)], None ); @@ -600,10 +600,10 @@ mod tests { // float32 generic_test_cast!( Float32Array, - 
DataType::Float32, + Float32, vec![1.5, 2.5, 3.0, 1.123_456_8, 5.50], Decimal128Array, - DataType::Decimal128(10, 2), + Decimal128(10, 2), [Some(150), Some(250), Some(300), Some(112), Some(550)], None ); @@ -611,10 +611,10 @@ mod tests { // float64 generic_test_cast!( Float64Array, - DataType::Float64, + Float64, vec![1.5, 2.5, 3.0, 1.123_456_8, 5.50], Decimal128Array, - DataType::Decimal128(20, 4), + Decimal128(20, 4), [ Some(15000), Some(25000), @@ -631,10 +631,10 @@ mod tests { fn test_cast_i32_u32() -> Result<()> { generic_test_cast!( Int32Array, - DataType::Int32, + Int32, vec![1, 2, 3, 4, 5], UInt32Array, - DataType::UInt32, + UInt32, [ Some(1_u32), Some(2_u32), @@ -651,10 +651,10 @@ mod tests { fn test_cast_i32_utf8() -> Result<()> { generic_test_cast!( Int32Array, - DataType::Int32, + Int32, vec![1, 2, 3, 4, 5], StringArray, - DataType::Utf8, + Utf8, [Some("1"), Some("2"), Some("3"), Some("4"), Some("5")], None ); @@ -670,10 +670,10 @@ mod tests { .collect(); generic_test_cast!( Int64Array, - DataType::Int64, + Int64, original, TimestampNanosecondArray, - DataType::Timestamp(TimeUnit::Nanosecond, None), + Timestamp(TimeUnit::Nanosecond, None), expected, None ); @@ -683,7 +683,7 @@ mod tests { #[test] fn invalid_cast() { // Ensure a useful error happens at plan time if invalid casts are used - let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); + let schema = Schema::new(vec![Field::new("a", Int32, false)]); let result = cast( col("a", &schema).unwrap(), @@ -696,11 +696,10 @@ mod tests { #[test] fn invalid_cast_with_options_error() -> Result<()> { // Ensure a useful error happens at plan time if invalid casts are used - let schema = Schema::new(vec![Field::new("a", DataType::Utf8, false)]); + let schema = Schema::new(vec![Field::new("a", Utf8, false)]); let a = StringArray::from(vec!["9.1"]); let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; - let expression = - cast_with_options(col("a", &schema)?, &schema, DataType::Int32, None)?; + let expression = cast_with_options(col("a", &schema)?, &schema, Int32, None)?; let result = expression.evaluate(&batch); match result { @@ -717,15 +716,11 @@ mod tests { #[test] #[ignore] // TODO: https://github.com/apache/datafusion/issues/5396 fn test_cast_decimal() -> Result<()> { - let schema = Schema::new(vec![Field::new("a", DataType::Int64, false)]); + let schema = Schema::new(vec![Field::new("a", Int64, false)]); let a = Int64Array::from(vec![100]); let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; - let expression = cast_with_options( - col("a", &schema)?, - &schema, - DataType::Decimal128(38, 38), - None, - )?; + let expression = + cast_with_options(col("a", &schema)?, &schema, Decimal128(38, 38), None)?; expression.evaluate(&batch)?; Ok(()) } diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index b1707d3abfa1..cd73c5cb579c 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -870,7 +870,7 @@ mod tests { ); assert!(expr.is_ok()); let result_type = expr.unwrap().data_type(schema.as_ref())?; - assert_eq!(DataType::Float64, result_type); + assert_eq!(Float64, result_type); Ok(()) } From 682fc054524e1aec4874bf6e76970b680e79d5aa Mon Sep 17 00:00:00 2001 From: Piotr Findeisen Date: Sat, 6 Jul 2024 13:26:11 +0200 Subject: [PATCH 13/26] Fix running examples readme (#11225) * Fix running examples readme Some examples are runnable from any 
place (e.g. `csv_sql`), but some expect a specific working directory (e.g. `regexp`). Running from `datafusion-examples/examples` is tested on CI, so it is guaranteed to work; let's put this path in the README. As a follow-up, we should look at what it would take to make examples runnable directly from an IDE such as RustRover. * Remove doubled fmt check from CI Examples' format is already checked in the `check-fmt` job, so the check can be skipped in `rust_example.sh`. --- ci/scripts/rust_example.sh | 1 - datafusion-examples/README.md | 3 +++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/ci/scripts/rust_example.sh b/ci/scripts/rust_example.sh index 0415090665d2..1bb97c88106f 100755 --- a/ci/scripts/rust_example.sh +++ b/ci/scripts/rust_example.sh @@ -19,7 +19,6 @@ set -ex cd datafusion-examples/examples/ -cargo fmt --all -- --check cargo check --examples files=$(ls .) diff --git a/datafusion-examples/README.md b/datafusion-examples/README.md index f868a5310cbe..90469e6715a6 100644 --- a/datafusion-examples/README.md +++ b/datafusion-examples/README.md @@ -36,6 +36,9 @@ cd datafusion # Download test data git submodule update --init +# Change to the examples directory +cd datafusion-examples/examples + # Run the `dataframe` example: # ... use the equivalent for other examples cargo run --example dataframe From 6f86bfad2fa12478c29eaa14355d0801b4ebf489 Mon Sep 17 00:00:00 2001 From: Arttu Date: Sat, 6 Jul 2024 13:27:18 +0200 Subject: [PATCH 14/26] feat: enable "substring" as a UDF in addition to "substr" (#11277) * feat: enable "substring" as a UDF in addition to "substr" Substrait uses the name "substring", and it already exists in DF SQL. The setup here is a bit weird; I'd have added substring as an alias for substr, but this "substring" version was already being created as a UDF and exported through the export_functions, with slightly different args than substr (even though in reality the underlying function for both is the same substr impl). I think this PR should work, but if you have suggestions on how to make the situation here cleaner, I'd be happy to hear them! * okay, redo everything: add an alias instead, and add renaming in the substrait producer * add the alias into scalar_functions.md --- datafusion/functions/src/unicode/substr.rs | 6 ++ .../substrait/src/logical_plan/consumer.rs | 60 +++++++++---------- .../substrait/src/logical_plan/producer.rs | 32 ++++++---- .../tests/cases/roundtrip_logical_plan.rs | 2 +- .../source/user-guide/sql/scalar_functions.md | 8 +++ 5 files changed, 65 insertions(+), 43 deletions(-) diff --git a/datafusion/functions/src/unicode/substr.rs b/datafusion/functions/src/unicode/substr.rs index c297182057fe..9d15920bb655 100644 --- a/datafusion/functions/src/unicode/substr.rs +++ b/datafusion/functions/src/unicode/substr.rs @@ -32,6 +32,7 @@ use crate::utils::{make_scalar_function, utf8_to_str_type}; #[derive(Debug)] pub struct SubstrFunc { signature: Signature, + aliases: Vec<String>, } impl Default for SubstrFunc { @@ -53,6 +54,7 @@ impl SubstrFunc { ], Volatility::Immutable, ), + aliases: vec![String::from("substring")], } } } @@ -81,6 +83,10 @@ impl ScalarUDFImpl for SubstrFunc { other => exec_err!("Unsupported data type {other:?} for function substr"), } } + + fn aliases(&self) -> &[String] { + &self.aliases + } } /// Extracts the substring of string starting at the start'th character, and extending for count characters if that is specified. (Same as substring(string from start for count).) 
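A quick sketch of the resulting behavior (illustrative queries, not part of this patch); both spellings resolve to the same underlying `substr` implementation:

```sql
SELECT substr('datafusion', 1, 4);    -- 'data'
SELECT substring('datafusion', 1, 4); -- 'data', via the new alias
```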
diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index cc10ea0619c1..c65943643e8c 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -88,36 +88,36 @@ use substrait::proto::{ }; use substrait::proto::{FunctionArgument, SortField}; -pub fn name_to_op(name: &str) -> Result { +pub fn name_to_op(name: &str) -> Option { match name { - "equal" => Ok(Operator::Eq), - "not_equal" => Ok(Operator::NotEq), - "lt" => Ok(Operator::Lt), - "lte" => Ok(Operator::LtEq), - "gt" => Ok(Operator::Gt), - "gte" => Ok(Operator::GtEq), - "add" => Ok(Operator::Plus), - "subtract" => Ok(Operator::Minus), - "multiply" => Ok(Operator::Multiply), - "divide" => Ok(Operator::Divide), - "mod" => Ok(Operator::Modulo), - "and" => Ok(Operator::And), - "or" => Ok(Operator::Or), - "is_distinct_from" => Ok(Operator::IsDistinctFrom), - "is_not_distinct_from" => Ok(Operator::IsNotDistinctFrom), - "regex_match" => Ok(Operator::RegexMatch), - "regex_imatch" => Ok(Operator::RegexIMatch), - "regex_not_match" => Ok(Operator::RegexNotMatch), - "regex_not_imatch" => Ok(Operator::RegexNotIMatch), - "bitwise_and" => Ok(Operator::BitwiseAnd), - "bitwise_or" => Ok(Operator::BitwiseOr), - "str_concat" => Ok(Operator::StringConcat), - "at_arrow" => Ok(Operator::AtArrow), - "arrow_at" => Ok(Operator::ArrowAt), - "bitwise_xor" => Ok(Operator::BitwiseXor), - "bitwise_shift_right" => Ok(Operator::BitwiseShiftRight), - "bitwise_shift_left" => Ok(Operator::BitwiseShiftLeft), - _ => not_impl_err!("Unsupported function name: {name:?}"), + "equal" => Some(Operator::Eq), + "not_equal" => Some(Operator::NotEq), + "lt" => Some(Operator::Lt), + "lte" => Some(Operator::LtEq), + "gt" => Some(Operator::Gt), + "gte" => Some(Operator::GtEq), + "add" => Some(Operator::Plus), + "subtract" => Some(Operator::Minus), + "multiply" => Some(Operator::Multiply), + "divide" => Some(Operator::Divide), + "mod" => Some(Operator::Modulo), + "and" => Some(Operator::And), + "or" => Some(Operator::Or), + "is_distinct_from" => Some(Operator::IsDistinctFrom), + "is_not_distinct_from" => Some(Operator::IsNotDistinctFrom), + "regex_match" => Some(Operator::RegexMatch), + "regex_imatch" => Some(Operator::RegexIMatch), + "regex_not_match" => Some(Operator::RegexNotMatch), + "regex_not_imatch" => Some(Operator::RegexNotIMatch), + "bitwise_and" => Some(Operator::BitwiseAnd), + "bitwise_or" => Some(Operator::BitwiseOr), + "str_concat" => Some(Operator::StringConcat), + "at_arrow" => Some(Operator::AtArrow), + "arrow_at" => Some(Operator::ArrowAt), + "bitwise_xor" => Some(Operator::BitwiseXor), + "bitwise_shift_right" => Some(Operator::BitwiseShiftRight), + "bitwise_shift_left" => Some(Operator::BitwiseShiftLeft), + _ => None, } } @@ -1124,7 +1124,7 @@ pub async fn from_substrait_rex( Ok(Arc::new(Expr::ScalarFunction( expr::ScalarFunction::new_udf(func.to_owned(), args), ))) - } else if let Ok(op) = name_to_op(fn_name) { + } else if let Some(op) = name_to_op(fn_name) { if f.arguments.len() < 2 { return not_impl_err!( "Expect at least two arguments for binary operator {op:?}, the provided number of operators is {:?}", diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index c3bef1689d14..899fec21f8bb 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -818,7 +818,7 @@ pub fn to_substrait_agg_measure( for arg in args { 
arguments.push(FunctionArgument { arg_type: Some(ArgType::Value(to_substrait_rex(ctx, arg, schema, 0, extension_info)?)) }); } - let function_anchor = _register_function(fun.to_string(), extension_info); + let function_anchor = register_function(fun.to_string(), extension_info); Ok(Measure { measure: Some(AggregateFunction { function_reference: function_anchor, @@ -849,7 +849,7 @@ pub fn to_substrait_agg_measure( for arg in args { arguments.push(FunctionArgument { arg_type: Some(ArgType::Value(to_substrait_rex(ctx, arg, schema, 0, extension_info)?)) }); } - let function_anchor = _register_function(fun.name().to_string(), extension_info); + let function_anchor = register_function(fun.name().to_string(), extension_info); Ok(Measure { measure: Some(AggregateFunction { function_reference: function_anchor, @@ -917,7 +917,7 @@ fn to_substrait_sort_field( } } -fn _register_function( +fn register_function( function_name: String, extension_info: &mut ( Vec, @@ -926,6 +926,14 @@ fn _register_function( ) -> u32 { let (function_extensions, function_set) = extension_info; let function_name = function_name.to_lowercase(); + + // Some functions are named differently in Substrait default extensions than in DF + // Rename those to match the Substrait extensions for interoperability + let function_name = match function_name.as_str() { + "substr" => "substring".to_string(), + _ => function_name, + }; + // To prevent ambiguous references between ScalarFunctions and AggregateFunctions, // a plan-relative identifier starting from 0 is used as the function_anchor. // The consumer is responsible for correctly registering @@ -969,7 +977,7 @@ pub fn make_binary_op_scalar_func( ), ) -> Expression { let function_anchor = - _register_function(operator_to_name(op).to_string(), extension_info); + register_function(operator_to_name(op).to_string(), extension_info); Expression { rex_type: Some(RexType::ScalarFunction(ScalarFunction { function_reference: function_anchor, @@ -1044,7 +1052,7 @@ pub fn to_substrait_rex( if *negated { let function_anchor = - _register_function("not".to_string(), extension_info); + register_function("not".to_string(), extension_info); Ok(Expression { rex_type: Some(RexType::ScalarFunction(ScalarFunction { @@ -1076,7 +1084,7 @@ pub fn to_substrait_rex( } let function_anchor = - _register_function(fun.name().to_string(), extension_info); + register_function(fun.name().to_string(), extension_info); Ok(Expression { rex_type: Some(RexType::ScalarFunction(ScalarFunction { function_reference: function_anchor, @@ -1252,7 +1260,7 @@ pub fn to_substrait_rex( null_treatment: _, }) => { // function reference - let function_anchor = _register_function(fun.to_string(), extension_info); + let function_anchor = register_function(fun.to_string(), extension_info); // arguments let mut arguments: Vec = vec![]; for arg in args { @@ -1330,7 +1338,7 @@ pub fn to_substrait_rex( }; if *negated { let function_anchor = - _register_function("not".to_string(), extension_info); + register_function("not".to_string(), extension_info); Ok(Expression { rex_type: Some(RexType::ScalarFunction(ScalarFunction { @@ -1727,9 +1735,9 @@ fn make_substrait_like_expr( ), ) -> Result { let function_anchor = if ignore_case { - _register_function("ilike".to_string(), extension_info) + register_function("ilike".to_string(), extension_info) } else { - _register_function("like".to_string(), extension_info) + register_function("like".to_string(), extension_info) }; let expr = to_substrait_rex(ctx, expr, schema, col_ref_offset, extension_info)?; 
let pattern = to_substrait_rex(ctx, pattern, schema, col_ref_offset, extension_info)?; @@ -1759,7 +1767,7 @@ fn make_substrait_like_expr( }; if negated { - let function_anchor = _register_function("not".to_string(), extension_info); + let function_anchor = register_function("not".to_string(), extension_info); Ok(Expression { rex_type: Some(RexType::ScalarFunction(ScalarFunction { @@ -2128,7 +2136,7 @@ fn to_substrait_unary_scalar_fn( HashMap, ), ) -> Result { - let function_anchor = _register_function(fn_name.to_string(), extension_info); + let function_anchor = register_function(fn_name.to_string(), extension_info); let substrait_expr = to_substrait_rex(ctx, arg, schema, col_ref_offset, extension_info)?; diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs index 7ed376f62ba0..dbc2e404bf56 100644 --- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs @@ -327,7 +327,7 @@ async fn simple_scalar_function_pow() -> Result<()> { #[tokio::test] async fn simple_scalar_function_substr() -> Result<()> { - roundtrip("SELECT * FROM data WHERE a = SUBSTR('datafusion', 0, 3)").await + roundtrip("SELECT SUBSTR(f, 1, 3) FROM data").await } #[tokio::test] diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md index ec34dbf9ba6c..d636726b45fe 100644 --- a/docs/source/user-guide/sql/scalar_functions.md +++ b/docs/source/user-guide/sql/scalar_functions.md @@ -1132,6 +1132,14 @@ substr(str, start_pos[, length]) - **length**: Number of characters to extract. If not specified, returns the rest of the string after the start position. +#### Aliases + +- substring + +### `substring` + +_Alias of [substr](#substr)._ + ### `translate` Translates characters in a string to specified translation characters. From 00bbb42c96e041fe679631dae9f36bf51d4d998b Mon Sep 17 00:00:00 2001 From: Arttu Date: Sat, 6 Jul 2024 13:27:35 +0200 Subject: [PATCH 15/26] fix: correctly handle Substrait windows with rows bounds (and validate executability of test plans) (#11278) * add a test for window with rows specified that should fail * make roundtrip tests actually execute the plan. Previously we could create plans that seem valid but cannot actually be executed by DF. E.g. 
the newly added test fails now with: Error: Internal("Rows should be Uint") * fix consuming bounds --- .../substrait/src/logical_plan/consumer.rs | 22 +++++--- .../substrait/src/logical_plan/producer.rs | 13 +++++ .../tests/cases/roundtrip_logical_plan.rs | 52 +++++++++---------- datafusion/substrait/tests/testdata/data.csv | 4 +- 4 files changed, 55 insertions(+), 36 deletions(-) diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index c65943643e8c..03692399e1b3 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -1544,12 +1544,22 @@ fn from_substrait_bound( BoundKind::CurrentRow(SubstraitBound::CurrentRow {}) => { Ok(WindowFrameBound::CurrentRow) } - BoundKind::Preceding(SubstraitBound::Preceding { offset }) => Ok( - WindowFrameBound::Preceding(ScalarValue::Int64(Some(*offset))), - ), - BoundKind::Following(SubstraitBound::Following { offset }) => Ok( - WindowFrameBound::Following(ScalarValue::Int64(Some(*offset))), - ), + BoundKind::Preceding(SubstraitBound::Preceding { offset }) => { + if *offset <= 0 { + return plan_err!("Preceding bound must be positive"); + } + Ok(WindowFrameBound::Preceding(ScalarValue::UInt64(Some( + *offset as u64, + )))) + } + BoundKind::Following(SubstraitBound::Following { offset }) => { + if *offset <= 0 { + return plan_err!("Following bound must be positive"); + } + Ok(WindowFrameBound::Following(ScalarValue::UInt64(Some( + *offset as u64, + )))) + } BoundKind::Unbounded(SubstraitBound::Unbounded {}) => { if is_lower { Ok(WindowFrameBound::Preceding(ScalarValue::Null)) diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index 899fec21f8bb..959542080161 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -2230,6 +2230,7 @@ mod test { use crate::logical_plan::consumer::{ from_substrait_literal_without_names, from_substrait_type_without_names, }; + use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano}; use datafusion::arrow::array::GenericListArray; use datafusion::arrow::datatypes::Field; use datafusion::common::scalar::ScalarStructBuilder; @@ -2309,6 +2310,14 @@ mod test { )?; round_trip_literal(ScalarStructBuilder::new_null(vec![c0, c1, c2]))?; + round_trip_literal(ScalarValue::IntervalYearMonth(Some(17)))?; + round_trip_literal(ScalarValue::IntervalMonthDayNano(Some( + IntervalMonthDayNano::new(17, 25, 1234567890), + )))?; + round_trip_literal(ScalarValue::IntervalDayTime(Some(IntervalDayTime::new( + 57, 123456, + ))))?; + Ok(()) } @@ -2376,6 +2385,10 @@ mod test { .into(), ))?; + round_trip_type(DataType::Interval(IntervalUnit::YearMonth))?; + round_trip_type(DataType::Interval(IntervalUnit::MonthDayNano))?; + round_trip_type(DataType::Interval(IntervalUnit::DayTime))?; + Ok(()) } diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs index dbc2e404bf56..52cfa50683a0 100644 --- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs @@ -523,24 +523,6 @@ async fn roundtrip_arithmetic_ops() -> Result<()> { Ok(()) } -#[tokio::test] -async fn roundtrip_interval_literal() -> Result<()> { - roundtrip( - "SELECT g from data where g = arrow_cast(INTERVAL '1 YEAR', 'Interval(YearMonth)')", - ) - .await?; - roundtrip( - "SELECT g from data where g = 
arrow_cast(INTERVAL '1 YEAR', 'Interval(DayTime)')", - ) - .await?; - roundtrip( - "SELECT g from data where g = arrow_cast(INTERVAL '1 YEAR', 'Interval(MonthDayNano)')", - ) - .await?; - - Ok(()) -} - #[tokio::test] async fn roundtrip_like() -> Result<()> { roundtrip("SELECT f FROM data WHERE f LIKE 'a%b'").await @@ -650,6 +632,17 @@ async fn simple_window_function() -> Result<()> { roundtrip("SELECT RANK() OVER (PARTITION BY a ORDER BY b), d, sum(b) OVER (PARTITION BY a) FROM data;").await } +#[tokio::test] +async fn window_with_rows() -> Result<()> { + roundtrip("SELECT sum(b) OVER (PARTITION BY a ROWS BETWEEN 2 PRECEDING AND CURRENT ROW) FROM data;").await?; + roundtrip("SELECT sum(b) OVER (PARTITION BY a ROWS BETWEEN CURRENT ROW AND 2 FOLLOWING) FROM data;").await?; + roundtrip("SELECT sum(b) OVER (PARTITION BY a ROWS BETWEEN 2 PRECEDING AND 2 FOLLOWING) FROM data;").await?; + roundtrip("SELECT sum(b) OVER (PARTITION BY a ROWS BETWEEN UNBOUNDED PRECEDING AND 2 FOLLOWING) FROM data;").await?; + roundtrip("SELECT sum(b) OVER (PARTITION BY a ROWS BETWEEN 2 PRECEDING AND UNBOUNDED FOLLOWING) FROM data;").await?; + roundtrip("SELECT sum(b) OVER (PARTITION BY a ROWS BETWEEN 2 FOLLOWING AND 4 FOLLOWING) FROM data;").await?; + roundtrip("SELECT sum(b) OVER (PARTITION BY a ROWS BETWEEN 4 PRECEDING AND 2 PRECEDING) FROM data;").await +} + #[tokio::test] async fn qualified_schema_table_reference() -> Result<()> { roundtrip("SELECT * FROM public.data;").await @@ -810,23 +803,20 @@ async fn roundtrip_aggregate_udf() -> Result<()> { struct Dummy {} impl Accumulator for Dummy { - fn state(&mut self) -> datafusion::error::Result> { - Ok(vec![]) + fn state(&mut self) -> Result> { + Ok(vec![ScalarValue::Float64(None), ScalarValue::UInt32(None)]) } - fn update_batch( - &mut self, - _values: &[ArrayRef], - ) -> datafusion::error::Result<()> { + fn update_batch(&mut self, _values: &[ArrayRef]) -> Result<()> { Ok(()) } - fn merge_batch(&mut self, _states: &[ArrayRef]) -> datafusion::error::Result<()> { + fn merge_batch(&mut self, _states: &[ArrayRef]) -> Result<()> { Ok(()) } - fn evaluate(&mut self) -> datafusion::error::Result { - Ok(ScalarValue::Float64(None)) + fn evaluate(&mut self) -> Result { + Ok(ScalarValue::Int64(None)) } fn size(&self) -> usize { @@ -1060,6 +1050,8 @@ async fn roundtrip_with_ctx(sql: &str, ctx: SessionContext) -> Result<()> { assert_eq!(plan1str, plan2str); assert_eq!(plan.schema(), plan2.schema()); + + DataFrame::new(ctx.state(), plan2).show().await?; Ok(()) } @@ -1132,7 +1124,6 @@ async fn create_context() -> Result { Field::new("d", DataType::Boolean, true), Field::new("e", DataType::UInt32, true), Field::new("f", DataType::Utf8, true), - Field::new("g", DataType::Interval(IntervalUnit::DayTime), true), ]; let schema = Schema::new(fields); explicit_options.schema = Some(&schema); @@ -1195,6 +1186,11 @@ async fn create_all_type_context() -> Result { ), Field::new("decimal_128_col", DataType::Decimal128(10, 2), true), Field::new("decimal_256_col", DataType::Decimal256(10, 2), true), + Field::new( + "interval_day_time_col", + DataType::Interval(IntervalUnit::DayTime), + true, + ), ]); explicit_options.schema = Some(&schema); explicit_options.has_header = false; diff --git a/datafusion/substrait/tests/testdata/data.csv b/datafusion/substrait/tests/testdata/data.csv index 1b85b166b1df..ef2766d29565 100644 --- a/datafusion/substrait/tests/testdata/data.csv +++ b/datafusion/substrait/tests/testdata/data.csv @@ -1,3 +1,3 @@ a,b,c,d,e,f -1,2.0,2020-01-01,false,4294967296,'a' 
-3,4.5,2020-01-01,true,2147483648,'b' \ No newline at end of file +1,2.0,2020-01-01,false,4294967295,'a' +3,4.5,2020-01-01,true,2147483648,'b' From b9fdc53ac80c68f819191c2f2f9872ae64c5d3d8 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Sat, 6 Jul 2024 07:28:19 -0400 Subject: [PATCH 16/26] Minor: Add `ConstExpr`::from, use in places (#11283) --- .../sort_preserving_repartition_fuzz.rs | 2 +- .../physical-expr/src/equivalence/class.rs | 28 +++++++++++++++++++ .../physical-expr/src/equivalence/mod.rs | 2 +- .../physical-expr/src/equivalence/ordering.rs | 6 ++-- .../src/equivalence/properties.rs | 22 +++++++-------- datafusion/physical-plan/src/filter.rs | 6 ++-- datafusion/physical-plan/src/union.rs | 2 +- datafusion/physical-plan/src/windows/mod.rs | 4 +-- 8 files changed, 47 insertions(+), 25 deletions(-) diff --git a/datafusion/core/tests/fuzz_cases/sort_preserving_repartition_fuzz.rs b/datafusion/core/tests/fuzz_cases/sort_preserving_repartition_fuzz.rs index f00d17a06ffc..ceae13a469f0 100644 --- a/datafusion/core/tests/fuzz_cases/sort_preserving_repartition_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/sort_preserving_repartition_fuzz.rs @@ -80,7 +80,7 @@ mod sp_repartition_fuzz_tests { // Define a and f are aliases eq_properties.add_equal_conditions(col_a, col_f)?; // Column e has constant value. - eq_properties = eq_properties.add_constants([ConstExpr::new(col_e.clone())]); + eq_properties = eq_properties.add_constants([ConstExpr::from(col_e)]); // Randomly order columns for sorting let mut rng = StdRng::seed_from_u64(seed); diff --git a/datafusion/physical-expr/src/equivalence/class.rs b/datafusion/physical-expr/src/equivalence/class.rs index 4c0edd2a5d9a..e483f935b75c 100644 --- a/datafusion/physical-expr/src/equivalence/class.rs +++ b/datafusion/physical-expr/src/equivalence/class.rs @@ -42,12 +42,28 @@ use datafusion_common::JoinType; /// - `across_partitions`: A boolean flag indicating whether the constant expression is /// valid across partitions. If set to `true`, the constant expression has same value for all partitions. /// If set to `false`, the constant expression may have different values for different partitions. +/// +/// # Example +/// +/// ```rust +/// # use datafusion_physical_expr::ConstExpr; +/// # use datafusion_physical_expr_common::expressions::lit; +/// let col = lit(5); +/// // Create a constant expression from a physical expression ref +/// let const_expr = ConstExpr::from(&col); +/// // create a constant expression from a physical expression +/// let const_expr = ConstExpr::from(col); +/// ``` pub struct ConstExpr { expr: Arc, across_partitions: bool, } impl ConstExpr { + /// Create a new constant expression from a physical expression. + /// + /// Note you can also use `ConstExpr::from` to create a constant expression + /// from a reference as well pub fn new(expr: Arc) -> Self { Self { expr, @@ -85,6 +101,18 @@ impl ConstExpr { } } +impl From> for ConstExpr { + fn from(expr: Arc) -> Self { + Self::new(expr) + } +} + +impl From<&Arc> for ConstExpr { + fn from(expr: &Arc) -> Self { + Self::new(Arc::clone(expr)) + } +} + /// Checks whether `expr` is among in the `const_exprs`. 
pub fn const_exprs_contains( const_exprs: &[ConstExpr], diff --git a/datafusion/physical-expr/src/equivalence/mod.rs b/datafusion/physical-expr/src/equivalence/mod.rs index 1ed9a4ac217f..83f94057f740 100644 --- a/datafusion/physical-expr/src/equivalence/mod.rs +++ b/datafusion/physical-expr/src/equivalence/mod.rs @@ -205,7 +205,7 @@ mod tests { // Define a and f are aliases eq_properties.add_equal_conditions(col_a, col_f)?; // Column e has constant value. - eq_properties = eq_properties.add_constants([ConstExpr::new(Arc::clone(col_e))]); + eq_properties = eq_properties.add_constants([ConstExpr::from(col_e)]); // Randomly order columns for sorting let mut rng = StdRng::seed_from_u64(seed); diff --git a/datafusion/physical-expr/src/equivalence/ordering.rs b/datafusion/physical-expr/src/equivalence/ordering.rs index d71075dc77e1..c4b8a5c46563 100644 --- a/datafusion/physical-expr/src/equivalence/ordering.rs +++ b/datafusion/physical-expr/src/equivalence/ordering.rs @@ -556,9 +556,9 @@ mod tests { let eq_group = EquivalenceGroup::new(eq_group); eq_properties.add_equivalence_group(eq_group); - let constants = constants.into_iter().map(|expr| { - ConstExpr::new(Arc::clone(expr)).with_across_partitions(true) - }); + let constants = constants + .into_iter() + .map(|expr| ConstExpr::from(expr).with_across_partitions(true)); eq_properties = eq_properties.add_constants(constants); let reqs = convert_to_sort_exprs(&reqs); diff --git a/datafusion/physical-expr/src/equivalence/properties.rs b/datafusion/physical-expr/src/equivalence/properties.rs index 9a6a17f58c1f..d9d19c0bcf47 100644 --- a/datafusion/physical-expr/src/equivalence/properties.rs +++ b/datafusion/physical-expr/src/equivalence/properties.rs @@ -213,13 +213,13 @@ impl EquivalenceProperties { // Left expression is constant, add right as constant if !const_exprs_contains(&self.constants, right) { self.constants - .push(ConstExpr::new(Arc::clone(right)).with_across_partitions(true)); + .push(ConstExpr::from(right).with_across_partitions(true)); } } else if self.is_expr_constant(right) { // Right expression is constant, add left as constant if !const_exprs_contains(&self.constants, left) { self.constants - .push(ConstExpr::new(Arc::clone(left)).with_across_partitions(true)); + .push(ConstExpr::from(left).with_across_partitions(true)); } } @@ -300,7 +300,7 @@ impl EquivalenceProperties { { if !const_exprs_contains(&self.constants, &expr) { let const_expr = - ConstExpr::new(expr).with_across_partitions(across_partitions); + ConstExpr::from(expr).with_across_partitions(across_partitions); self.constants.push(const_expr); } } @@ -404,7 +404,7 @@ impl EquivalenceProperties { // we add column `a` as constant to the algorithm state. This enables us // to deduce that `(b + c) ASC` is satisfied, given `a` is constant. eq_properties = eq_properties - .add_constants(std::iter::once(ConstExpr::new(normalized_req.expr))); + .add_constants(std::iter::once(ConstExpr::from(normalized_req.expr))); } true } @@ -832,9 +832,8 @@ impl EquivalenceProperties { && !const_exprs_contains(&projected_constants, target) { // Expression evaluates to single value - projected_constants.push( - ConstExpr::new(Arc::clone(target)).with_across_partitions(true), - ); + projected_constants + .push(ConstExpr::from(target).with_across_partitions(true)); } } projected_constants @@ -927,8 +926,8 @@ impl EquivalenceProperties { // Note that these expressions are not properly "constants". This is just // an implementation strategy confined to this function. 
for (PhysicalSortExpr { expr, .. }, idx) in &ordered_exprs { - eq_properties = eq_properties - .add_constants(std::iter::once(ConstExpr::new(Arc::clone(expr)))); + eq_properties = + eq_properties.add_constants(std::iter::once(ConstExpr::from(expr))); search_indices.shift_remove(idx); } // Add new ordered section to the state. @@ -2147,8 +2146,7 @@ mod tests { let col_h = &col("h", &test_schema)?; // Add column h as constant - eq_properties = - eq_properties.add_constants(vec![ConstExpr::new(Arc::clone(col_h))]); + eq_properties = eq_properties.add_constants(vec![ConstExpr::from(col_h)]); let test_cases = vec![ // TEST CASE 1 @@ -2458,7 +2456,7 @@ mod tests { for case in cases { let mut properties = base_properties .clone() - .add_constants(case.constants.into_iter().map(ConstExpr::new)); + .add_constants(case.constants.into_iter().map(ConstExpr::from)); for [left, right] in &case.equal_conditions { properties.add_equal_conditions(left, right)? } diff --git a/datafusion/physical-plan/src/filter.rs b/datafusion/physical-plan/src/filter.rs index c141958c1171..ab7a63e44550 100644 --- a/datafusion/physical-plan/src/filter.rs +++ b/datafusion/physical-plan/src/filter.rs @@ -173,13 +173,11 @@ impl FilterExec { // Filter evaluates to single value for all partitions if input_eqs.is_expr_constant(binary.left()) { res_constants.push( - ConstExpr::new(binary.right().clone()) - .with_across_partitions(true), + ConstExpr::from(binary.right()).with_across_partitions(true), ) } else if input_eqs.is_expr_constant(binary.right()) { res_constants.push( - ConstExpr::new(binary.left().clone()) - .with_across_partitions(true), + ConstExpr::from(binary.left()).with_across_partitions(true), ) } } diff --git a/datafusion/physical-plan/src/union.rs b/datafusion/physical-plan/src/union.rs index 3f88eb4c3732..867cddeb7b41 100644 --- a/datafusion/physical-plan/src/union.rs +++ b/datafusion/physical-plan/src/union.rs @@ -188,7 +188,7 @@ fn calculate_union_eq_properties( // TODO: Check whether constant expressions evaluates the same value or not for each partition let across_partitions = false; return Some( - ConstExpr::new(meet_constant.owned_expr()) + ConstExpr::from(meet_constant.owned_expr()) .with_across_partitions(across_partitions), ); } diff --git a/datafusion/physical-plan/src/windows/mod.rs b/datafusion/physical-plan/src/windows/mod.rs index d798eaacc787..0622aad74cad 100644 --- a/datafusion/physical-plan/src/windows/mod.rs +++ b/datafusion/physical-plan/src/windows/mod.rs @@ -598,9 +598,7 @@ pub fn get_window_mode( options: None, })); // Treat partition by exprs as constant. During analysis of requirements are satisfied. 
- let const_exprs = partitionby_exprs - .iter() - .map(|expr| ConstExpr::new(expr.clone())); + let const_exprs = partitionby_exprs.iter().map(ConstExpr::from); let partition_by_eqs = input_eqs.add_constants(const_exprs); let order_by_reqs = PhysicalSortRequirement::from_sort_exprs(orderby_keys); let reverse_order_by_reqs = From a3e1c3d055a2f4b6d63824ba021ea532d9c83558 Mon Sep 17 00:00:00 2001 From: Lordworms <48054792+Lordworms@users.noreply.github.com> Date: Sat, 6 Jul 2024 04:37:09 -0700 Subject: [PATCH 17/26] Implement TPCH substrait integration test, support tpch_3 (#11298) --- .../tests/cases/consumer_integration.rs | 55 +- .../tests/testdata/tpch/customer.csv | 2 + .../substrait/tests/testdata/tpch/orders.csv | 2 + .../tpch_substrait_plans/query_3.json | 851 ++++++++++++++++++ 4 files changed, 903 insertions(+), 7 deletions(-) create mode 100644 datafusion/substrait/tests/testdata/tpch/customer.csv create mode 100644 datafusion/substrait/tests/testdata/tpch/orders.csv create mode 100644 datafusion/substrait/tests/testdata/tpch_substrait_plans/query_3.json diff --git a/datafusion/substrait/tests/cases/consumer_integration.rs b/datafusion/substrait/tests/cases/consumer_integration.rs index 58f2fc900937..a8bbeb444a23 100644 --- a/datafusion/substrait/tests/cases/consumer_integration.rs +++ b/datafusion/substrait/tests/cases/consumer_integration.rs @@ -41,6 +41,17 @@ mod tests { .await } + async fn create_context_tpch1() -> Result<SessionContext> { + let ctx = SessionContext::new(); + register_csv( + &ctx, + "FILENAME_PLACEHOLDER_0", + "tests/testdata/tpch/lineitem.csv", + ) + .await?; + Ok(ctx) + } + async fn create_context_tpch2() -> Result<SessionContext> { let ctx = SessionContext::new(); @@ -63,14 +74,19 @@ Ok(ctx) } - async fn create_context_tpch1() -> Result<SessionContext> { + async fn create_context_tpch3() -> Result<SessionContext> { let ctx = SessionContext::new(); - register_csv( - &ctx, - "FILENAME_PLACEHOLDER_0", - "tests/testdata/tpch/lineitem.csv", - ) - .await?; + + let registrations = vec![ + ("FILENAME_PLACEHOLDER_0", "tests/testdata/tpch/customer.csv"), + ("FILENAME_PLACEHOLDER_1", "tests/testdata/tpch/orders.csv"), + ("FILENAME_PLACEHOLDER_2", "tests/testdata/tpch/lineitem.csv"), + ]; + + for (table_name, file_path) in registrations { + register_csv(&ctx, table_name, file_path).await?; + } + Ok(ctx) } @@ -139,4 +155,29 @@ mod tests { ); Ok(()) } + + #[tokio::test] + async fn tpch_test_3() -> Result<()> { + let ctx = create_context_tpch3().await?; + let path = "tests/testdata/tpch_substrait_plans/query_3.json"; + let proto = serde_json::from_reader::<_, Plan>(BufReader::new( + File::open(path).expect("file not found"), + )) + .expect("failed to parse json"); + + let plan = from_substrait_plan(&ctx, &proto).await?; + let plan_str = format!("{:?}", plan); + assert_eq!(plan_str, "Projection: FILENAME_PLACEHOLDER_2.l_orderkey AS L_ORDERKEY, sum(FILENAME_PLACEHOLDER_2.l_extendedprice * Int32(1) - FILENAME_PLACEHOLDER_2.l_discount) AS REVENUE, FILENAME_PLACEHOLDER_1.o_orderdate AS O_ORDERDATE, FILENAME_PLACEHOLDER_1.o_shippriority AS O_SHIPPRIORITY\ + \n Limit: skip=0, fetch=10\ + \n Sort: sum(FILENAME_PLACEHOLDER_2.l_extendedprice * Int32(1) - FILENAME_PLACEHOLDER_2.l_discount) DESC NULLS FIRST, FILENAME_PLACEHOLDER_1.o_orderdate ASC NULLS LAST\ + \n Projection: FILENAME_PLACEHOLDER_2.l_orderkey, sum(FILENAME_PLACEHOLDER_2.l_extendedprice * Int32(1) - FILENAME_PLACEHOLDER_2.l_discount), FILENAME_PLACEHOLDER_1.o_orderdate, FILENAME_PLACEHOLDER_1.o_shippriority\ + \n Aggregate: groupBy=[[FILENAME_PLACEHOLDER_2.l_orderkey, 
FILENAME_PLACEHOLDER_1.o_orderdate, FILENAME_PLACEHOLDER_1.o_shippriority]], aggr=[[sum(FILENAME_PLACEHOLDER_2.l_extendedprice * Int32(1) - FILENAME_PLACEHOLDER_2.l_discount)]]\ + \n Projection: FILENAME_PLACEHOLDER_2.l_orderkey, FILENAME_PLACEHOLDER_1.o_orderdate, FILENAME_PLACEHOLDER_1.o_shippriority, FILENAME_PLACEHOLDER_2.l_extendedprice * (CAST(Int32(1) AS Decimal128(19, 0)) - FILENAME_PLACEHOLDER_2.l_discount)\ + \n Filter: FILENAME_PLACEHOLDER_0.c_mktsegment = CAST(Utf8(\"HOUSEHOLD\") AS Utf8) AND FILENAME_PLACEHOLDER_0.c_custkey = FILENAME_PLACEHOLDER_1.o_custkey AND FILENAME_PLACEHOLDER_2.l_orderkey = FILENAME_PLACEHOLDER_1.o_orderkey AND FILENAME_PLACEHOLDER_1.o_orderdate < Date32(\"1995-03-25\") AND FILENAME_PLACEHOLDER_2.l_shipdate > Date32(\"1995-03-25\")\ + \n Inner Join: Filter: Boolean(true)\ + \n Inner Join: Filter: Boolean(true)\ + \n TableScan: FILENAME_PLACEHOLDER_0 projection=[c_custkey, c_name, c_address, c_nationkey, c_phone, c_acctbal, c_mktsegment, c_comment]\ + \n TableScan: FILENAME_PLACEHOLDER_1 projection=[o_orderkey, o_custkey, o_orderstatus, o_totalprice, o_orderdate, o_orderpriority, o_clerk, o_shippriority, o_comment]\n TableScan: FILENAME_PLACEHOLDER_2 projection=[l_orderkey, l_partkey, l_suppkey, l_linenumber, l_quantity, l_extendedprice, l_discount, l_tax, l_returnflag, l_linestatus, l_shipdate, l_commitdate, l_receiptdate, l_shipinstruct, l_shipmode, l_comment]"); + Ok(()) + } } diff --git a/datafusion/substrait/tests/testdata/tpch/customer.csv b/datafusion/substrait/tests/testdata/tpch/customer.csv new file mode 100644 index 000000000000..ed15da17d47d --- /dev/null +++ b/datafusion/substrait/tests/testdata/tpch/customer.csv @@ -0,0 +1,2 @@ +c_custkey,c_name,c_address,c_nationkey,c_phone,c_acctbal,c_mktsegment,c_comment +1,Customer#000000001,Address1,1,123-456-7890,5000.00,BUILDING,No comment \ No newline at end of file diff --git a/datafusion/substrait/tests/testdata/tpch/orders.csv b/datafusion/substrait/tests/testdata/tpch/orders.csv new file mode 100644 index 000000000000..b9abea3cbb5b --- /dev/null +++ b/datafusion/substrait/tests/testdata/tpch/orders.csv @@ -0,0 +1,2 @@ +o_orderkey,o_custkey,o_orderstatus,o_totalprice,o_orderdate,o_orderpriority,o_clerk,o_shippriority,o_comment +1,1,O,1000.00,2023-01-01,5-LOW,Clerk#000000001,0,No comment \ No newline at end of file diff --git a/datafusion/substrait/tests/testdata/tpch_substrait_plans/query_3.json b/datafusion/substrait/tests/testdata/tpch_substrait_plans/query_3.json new file mode 100644 index 000000000000..4ca074d2e8cf --- /dev/null +++ b/datafusion/substrait/tests/testdata/tpch_substrait_plans/query_3.json @@ -0,0 +1,851 @@ +{ + "extensionUris": [{ + "extensionUriAnchor": 1, + "uri": "/functions_boolean.yaml" + }, { + "extensionUriAnchor": 4, + "uri": "/functions_arithmetic_decimal.yaml" + }, { + "extensionUriAnchor": 3, + "uri": "/functions_datetime.yaml" + }, { + "extensionUriAnchor": 2, + "uri": "/functions_comparison.yaml" + }], + "extensions": [{ + "extensionFunction": { + "extensionUriReference": 1, + "functionAnchor": 0, + "name": "and:bool" + } + }, { + "extensionFunction": { + "extensionUriReference": 2, + "functionAnchor": 1, + "name": "equal:any1_any1" + } + }, { + "extensionFunction": { + "extensionUriReference": 3, + "functionAnchor": 2, + "name": "lt:date_date" + } + }, { + "extensionFunction": { + "extensionUriReference": 3, + "functionAnchor": 3, + "name": "gt:date_date" + } + }, { + "extensionFunction": { + "extensionUriReference": 4, + "functionAnchor": 4, + "name": 
"multiply:opt_decimal_decimal" + } + }, { + "extensionFunction": { + "extensionUriReference": 4, + "functionAnchor": 5, + "name": "subtract:opt_decimal_decimal" + } + }, { + "extensionFunction": { + "extensionUriReference": 4, + "functionAnchor": 6, + "name": "sum:opt_decimal" + } + }], + "relations": [{ + "root": { + "input": { + "fetch": { + "common": { + "direct": { + } + }, + "input": { + "sort": { + "common": { + "direct": { + } + }, + "input": { + "project": { + "common": { + "emit": { + "outputMapping": [4, 5, 6, 7] + } + }, + "input": { + "aggregate": { + "common": { + "direct": { + } + }, + "input": { + "project": { + "common": { + "emit": { + "outputMapping": [33, 34, 35, 36] + } + }, + "input": { + "filter": { + "common": { + "direct": { + } + }, + "input": { + "join": { + "common": { + "direct": { + } + }, + "left": { + "join": { + "common": { + "direct": { + } + }, + "left": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["C_CUSTKEY", "C_NAME", "C_ADDRESS", "C_NATIONKEY", "C_PHONE", "C_ACCTBAL", "C_MKTSEGMENT", "C_COMMENT"], + "struct": { + "types": [{ + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "varchar": { + "length": 25, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, { + "varchar": { + "length": 40, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "fixedChar": { + "length": 15, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, { + "fixedChar": { + "length": 10, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, { + "varchar": { + "length": 117, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }], + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "local_files": { + "items": [ + { + "uri_file": "file://FILENAME_PLACEHOLDER_0", + "parquet": {} + } + ] + } + } + }, + "right": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["O_ORDERKEY", "O_CUSTKEY", "O_ORDERSTATUS", "O_TOTALPRICE", "O_ORDERDATE", "O_ORDERPRIORITY", "O_CLERK", "O_SHIPPRIORITY", "O_COMMENT"], + "struct": { + "types": [{ + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "fixedChar": { + "length": 1, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, { + "date": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, { + "fixedChar": { + "length": 15, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, { + "fixedChar": { + "length": 15, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, { + "i32": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, { + "varchar": { + "length": 79, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }], + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "local_files": { + "items": [ + { + "uri_file": 
"file://FILENAME_PLACEHOLDER_1", + "parquet": {} + } + ] + } + } + }, + "expression": { + "literal": { + "boolean": true, + "nullable": false, + "typeVariationReference": 0 + } + }, + "type": "JOIN_TYPE_INNER" + } + }, + "right": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["L_ORDERKEY", "L_PARTKEY", "L_SUPPKEY", "L_LINENUMBER", "L_QUANTITY", "L_EXTENDEDPRICE", "L_DISCOUNT", "L_TAX", "L_RETURNFLAG", "L_LINESTATUS", "L_SHIPDATE", "L_COMMITDATE", "L_RECEIPTDATE", "L_SHIPINSTRUCT", "L_SHIPMODE", "L_COMMENT"], + "struct": { + "types": [{ + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i32": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, { + "fixedChar": { + "length": 1, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, { + "fixedChar": { + "length": 1, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, { + "date": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, { + "date": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, { + "date": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, { + "fixedChar": { + "length": 25, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, { + "fixedChar": { + "length": 10, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, { + "varchar": { + "length": 44, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }], + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "local_files": { + "items": [ + { + "uri_file": "file://FILENAME_PLACEHOLDER_2", + "parquet": {} + } + ] + } + } + }, + "expression": { + "literal": { + "boolean": true, + "nullable": false, + "typeVariationReference": 0 + } + }, + "type": "JOIN_TYPE_INNER" + } + }, + "condition": { + "scalarFunction": { + "functionReference": 0, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [{ + "value": { + "scalarFunction": { + "functionReference": 1, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 6 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "cast": { + "type": { + "fixedChar": { + "length": 10, + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "input": { + "literal": { + "fixedChar": "HOUSEHOLD", + "nullable": false, + "typeVariationReference": 0 + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_UNSPECIFIED" + } + } + }] + } + } + 
}, { + "value": { + "scalarFunction": { + "functionReference": 1, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 9 + } + }, + "rootReference": { + } + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 1, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 17 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 8 + } + }, + "rootReference": { + } + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 2, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 12 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "literal": { + "date": 9214, + "nullable": false, + "typeVariationReference": 0 + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 3, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 27 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "literal": { + "date": 9214, + "nullable": false, + "typeVariationReference": 0 + } + } + }] + } + } + }] + } + } + } + }, + "expressions": [{ + "selection": { + "directReference": { + "structField": { + "field": 17 + } + }, + "rootReference": { + } + } + }, { + "selection": { + "directReference": { + "structField": { + "field": 12 + } + }, + "rootReference": { + } + } + }, { + "selection": { + "directReference": { + "structField": { + "field": 15 + } + }, + "rootReference": { + } + } + }, { + "scalarFunction": { + "functionReference": 4, + "args": [], + "outputType": { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 22 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 5, + "args": [], + "outputType": { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [{ + "value": { + "cast": { + "type": { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "input": { + "literal": { + "i32": 1, + "nullable": false, + "typeVariationReference": 0 + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_UNSPECIFIED" + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 23 + } + }, + "rootReference": { + } + } + } + }] + } + } + }] + } + }] + } + }, + "groupings": [{ + "groupingExpressions": [{ + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": { + } + } + }, { + 
"selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": { + } + } + }, { + "selection": { + "directReference": { + "structField": { + "field": 2 + } + }, + "rootReference": { + } + } + }] + }], + "measures": [{ + "measure": { + "functionReference": 6, + "args": [], + "sorts": [], + "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT", + "outputType": { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "invocation": "AGGREGATION_INVOCATION_ALL", + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 3 + } + }, + "rootReference": { + } + } + } + }] + } + }] + } + }, + "expressions": [{ + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": { + } + } + }, { + "selection": { + "directReference": { + "structField": { + "field": 3 + } + }, + "rootReference": { + } + } + }, { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": { + } + } + }, { + "selection": { + "directReference": { + "structField": { + "field": 2 + } + }, + "rootReference": { + } + } + }] + } + }, + "sorts": [{ + "expr": { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": { + } + } + }, + "direction": "SORT_DIRECTION_DESC_NULLS_FIRST" + }, { + "expr": { + "selection": { + "directReference": { + "structField": { + "field": 2 + } + }, + "rootReference": { + } + } + }, + "direction": "SORT_DIRECTION_ASC_NULLS_LAST" + }] + } + }, + "offset": "0", + "count": "10" + } + }, + "names": ["L_ORDERKEY", "REVENUE", "O_ORDERDATE", "O_SHIPPRIORITY"] + } + }], + "expectedTypeUrls": [] +} From 2af3d3a55ba67726ff915743680380fb02a3e2c2 Mon Sep 17 00:00:00 2001 From: Xin Li <33629085+xinlifoobar@users.noreply.github.com> Date: Sat, 6 Jul 2024 04:55:55 -0700 Subject: [PATCH 18/26] Implement user defined planner for position (#11243) * Implement user defined planner for position * Fix format * Move planner to session_state * Extract function --- .../core/src/execution/session_state.rs | 2 + datafusion/expr/src/planner.rs | 6 +++ datafusion/functions/src/unicode/mod.rs | 1 + datafusion/functions/src/unicode/planner.rs | 36 +++++++++++++++ datafusion/sql/src/expr/mod.rs | 44 +++++++++++-------- 5 files changed, 70 insertions(+), 19 deletions(-) create mode 100644 datafusion/functions/src/unicode/planner.rs diff --git a/datafusion/core/src/execution/session_state.rs b/datafusion/core/src/execution/session_state.rs index a831f92def50..ffaaa2df5e7e 100644 --- a/datafusion/core/src/execution/session_state.rs +++ b/datafusion/core/src/execution/session_state.rs @@ -240,6 +240,8 @@ impl SessionState { Arc::new(functions_array::planner::FieldAccessPlanner), #[cfg(feature = "datetime_expressions")] Arc::new(functions::datetime::planner::ExtractPlanner), + #[cfg(feature = "unicode_expressions")] + Arc::new(functions::unicode::planner::PositionPlanner), ]; let mut new_self = SessionState { diff --git a/datafusion/expr/src/planner.rs b/datafusion/expr/src/planner.rs index bba0228ae0aa..bcbf5eb203ac 100644 --- a/datafusion/expr/src/planner.rs +++ b/datafusion/expr/src/planner.rs @@ -116,6 +116,12 @@ pub trait UserDefinedSQLPlanner: Send + Sync { Ok(PlannerResult::Original(exprs)) } + // Plan the POSITION expression, e.g., POSITION( in ) + // returns origin expression arguments if not possible + fn plan_position(&self, args: Vec) -> Result>> { + 
Ok(PlannerResult::Original(args)) + } + /// Plan the dictionary literal `{ key: value, ...}` /// /// Returns origin expression arguments if not possible diff --git a/datafusion/functions/src/unicode/mod.rs b/datafusion/functions/src/unicode/mod.rs index 9e8c07cd36ed..a391b8ba11dc 100644 --- a/datafusion/functions/src/unicode/mod.rs +++ b/datafusion/functions/src/unicode/mod.rs @@ -25,6 +25,7 @@ pub mod character_length; pub mod find_in_set; pub mod left; pub mod lpad; +pub mod planner; pub mod reverse; pub mod right; pub mod rpad; diff --git a/datafusion/functions/src/unicode/planner.rs b/datafusion/functions/src/unicode/planner.rs new file mode 100644 index 000000000000..4d6f73321b4a --- /dev/null +++ b/datafusion/functions/src/unicode/planner.rs @@ -0,0 +1,36 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! SQL planning extensions like [`PositionPlanner`] + +use datafusion_common::Result; +use datafusion_expr::{ + expr::ScalarFunction, + planner::{PlannerResult, UserDefinedSQLPlanner}, + Expr, +}; + +#[derive(Default)] +pub struct PositionPlanner; + +impl UserDefinedSQLPlanner for PositionPlanner { + fn plan_position(&self, args: Vec) -> Result>> { + Ok(PlannerResult::Planned(Expr::ScalarFunction( + ScalarFunction::new_udf(crate::unicode::strpos(), args), + ))) + } +} diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs index 2ddd2d22c022..6295821fa944 100644 --- a/datafusion/sql/src/expr/mod.rs +++ b/datafusion/sql/src/expr/mod.rs @@ -629,6 +629,31 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } } + fn sql_position_to_expr( + &self, + substr_expr: SQLExpr, + str_expr: SQLExpr, + schema: &DFSchema, + planner_context: &mut PlannerContext, + ) -> Result { + let substr = + self.sql_expr_to_logical_expr(substr_expr, schema, planner_context)?; + let fullstr = self.sql_expr_to_logical_expr(str_expr, schema, planner_context)?; + let mut extract_args = vec![fullstr, substr]; + for planner in self.planners.iter() { + match planner.plan_position(extract_args)? 
{ + PlannerResult::Planned(expr) => return Ok(expr), + PlannerResult::Original(args) => { + extract_args = args; + } + } + } + + not_impl_err!( + "Position not supported by UserDefinedExtensionPlanners: {extract_args:?}" + ) + } + fn try_plan_dictionary_literal( &self, fields: Vec, @@ -924,25 +949,6 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { }; Ok(Expr::ScalarFunction(ScalarFunction::new_udf(fun, args))) } - fn sql_position_to_expr( - &self, - substr_expr: SQLExpr, - str_expr: SQLExpr, - schema: &DFSchema, - planner_context: &mut PlannerContext, - ) -> Result { - let fun = self - .context_provider - .get_function_meta("strpos") - .ok_or_else(|| { - internal_datafusion_err!("Unable to find expected 'strpos' function") - })?; - let substr = - self.sql_expr_to_logical_expr(substr_expr, schema, planner_context)?; - let fullstr = self.sql_expr_to_logical_expr(str_expr, schema, planner_context)?; - let args = vec![fullstr, substr]; - Ok(Expr::ScalarFunction(ScalarFunction::new_udf(fun, args))) - } } #[cfg(test)] From 08c5345e932f1c5c948751e0d06b1fd99e174efa Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Sat, 6 Jul 2024 07:57:54 -0400 Subject: [PATCH 19/26] Upgrade to arrow 52.1.0 (and fix clippy issues on main) (#11302) --- Cargo.toml | 18 +- datafusion-cli/Cargo.lock | 188 +++++++++--------- datafusion-cli/Cargo.toml | 4 +- datafusion/expr/src/type_coercion/binary.rs | 2 +- .../expr/src/type_coercion/functions.rs | 4 +- datafusion/functions/src/datetime/date_bin.rs | 14 +- 6 files changed, 115 insertions(+), 115 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 6a6928e25bdd..968a74e37f10 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -64,22 +64,22 @@ version = "39.0.0" ahash = { version = "0.8", default-features = false, features = [ "runtime-rng", ] } -arrow = { version = "52.0.0", features = [ +arrow = { version = "52.1.0", features = [ "prettyprint", ] } -arrow-array = { version = "52.0.0", default-features = false, features = [ +arrow-array = { version = "52.1.0", default-features = false, features = [ "chrono-tz", ] } -arrow-buffer = { version = "52.0.0", default-features = false } -arrow-flight = { version = "52.0.0", features = [ +arrow-buffer = { version = "52.1.0", default-features = false } +arrow-flight = { version = "52.1.0", features = [ "flight-sql-experimental", ] } -arrow-ipc = { version = "52.0.0", default-features = false, features = [ +arrow-ipc = { version = "52.1.0", default-features = false, features = [ "lz4", ] } -arrow-ord = { version = "52.0.0", default-features = false } -arrow-schema = { version = "52.0.0", default-features = false } -arrow-string = { version = "52.0.0", default-features = false } +arrow-ord = { version = "52.1.0", default-features = false } +arrow-schema = { version = "52.1.0", default-features = false } +arrow-string = { version = "52.1.0", default-features = false } async-trait = "0.1.73" bigdecimal = "=0.4.1" bytes = "1.4" @@ -114,7 +114,7 @@ log = "^0.4" num_cpus = "1.13.0" object_store = { version = "0.10.1", default-features = false } parking_lot = "0.12" -parquet = { version = "52.0.0", default-features = false, features = [ +parquet = { version = "52.1.0", default-features = false, features = [ "arrow", "async", "object_store", diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index 5fc8dbcfdfb3..4fce2ec500e4 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -130,9 +130,9 @@ checksum = "96d30a06541fbafbc7f82ed10c06164cfbd2c401138f6addd8404629c4b16711" [[package]] name = "arrow" -version = 
"52.0.0" +version = "52.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7ae9728f104939be6d8d9b368a354b4929b0569160ea1641f0721b55a861ce38" +checksum = "6127ea5e585a12ec9f742232442828ebaf264dfa5eefdd71282376c599562b77" dependencies = [ "arrow-arith", "arrow-array", @@ -151,9 +151,9 @@ dependencies = [ [[package]] name = "arrow-arith" -version = "52.0.0" +version = "52.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a7029a5b3efbeafbf4a12d12dc16b8f9e9bff20a410b8c25c5d28acc089e1043" +checksum = "7add7f39210b7d726e2a8efc0083e7bf06e8f2d15bdb4896b564dce4410fbf5d" dependencies = [ "arrow-array", "arrow-buffer", @@ -166,9 +166,9 @@ dependencies = [ [[package]] name = "arrow-array" -version = "52.0.0" +version = "52.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d33238427c60271710695f17742f45b1a5dc5bcfc5c15331c25ddfe7abf70d97" +checksum = "81c16ec702d3898c2f5cfdc148443c6cd7dbe5bac28399859eb0a3d38f072827" dependencies = [ "ahash", "arrow-buffer", @@ -183,9 +183,9 @@ dependencies = [ [[package]] name = "arrow-buffer" -version = "52.0.0" +version = "52.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fe9b95e825ae838efaf77e366c00d3fc8cca78134c9db497d6bda425f2e7b7c1" +checksum = "cae6970bab043c4fbc10aee1660ceb5b306d0c42c8cc5f6ae564efcd9759b663" dependencies = [ "bytes", "half", @@ -194,9 +194,9 @@ dependencies = [ [[package]] name = "arrow-cast" -version = "52.0.0" +version = "52.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87cf8385a9d5b5fcde771661dd07652b79b9139fea66193eda6a88664400ccab" +checksum = "1c7ef44f26ef4f8edc392a048324ed5d757ad09135eff6d5509e6450d39e0398" dependencies = [ "arrow-array", "arrow-buffer", @@ -215,9 +215,9 @@ dependencies = [ [[package]] name = "arrow-csv" -version = "52.0.0" +version = "52.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cea5068bef430a86690059665e40034625ec323ffa4dd21972048eebb0127adc" +checksum = "5f843490bd258c5182b66e888161bb6f198f49f3792f7c7f98198b924ae0f564" dependencies = [ "arrow-array", "arrow-buffer", @@ -234,9 +234,9 @@ dependencies = [ [[package]] name = "arrow-data" -version = "52.0.0" +version = "52.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cb29be98f987bcf217b070512bb7afba2f65180858bca462edf4a39d84a23e10" +checksum = "a769666ffac256dd301006faca1ca553d0ae7cffcf4cd07095f73f95eb226514" dependencies = [ "arrow-buffer", "arrow-schema", @@ -246,9 +246,9 @@ dependencies = [ [[package]] name = "arrow-ipc" -version = "52.0.0" +version = "52.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ffc68f6523970aa6f7ce1dc9a33a7d9284cfb9af77d4ad3e617dbe5d79cc6ec8" +checksum = "dbf9c3fb57390a1af0b7bb3b5558c1ee1f63905f3eccf49ae7676a8d1e6e5a72" dependencies = [ "arrow-array", "arrow-buffer", @@ -261,9 +261,9 @@ dependencies = [ [[package]] name = "arrow-json" -version = "52.0.0" +version = "52.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2041380f94bd6437ab648e6c2085a045e45a0c44f91a1b9a4fe3fed3d379bfb1" +checksum = "654e7f3724176b66ddfacba31af397c48e106fbe4d281c8144e7d237df5acfd7" dependencies = [ "arrow-array", "arrow-buffer", @@ -281,9 +281,9 @@ dependencies = [ [[package]] name = "arrow-ord" -version = "52.0.0" +version = "52.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"fcb56ed1547004e12203652f12fe12e824161ff9d1e5cf2a7dc4ff02ba94f413" +checksum = "e8008370e624e8e3c68174faaf793540287106cfda8ad1da862fdc53d8e096b4" dependencies = [ "arrow-array", "arrow-buffer", @@ -296,9 +296,9 @@ dependencies = [ [[package]] name = "arrow-row" -version = "52.0.0" +version = "52.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "575b42f1fc588f2da6977b94a5ca565459f5ab07b60545e17243fb9a7ed6d43e" +checksum = "ca5e3a6b7fda8d9fe03f3b18a2d946354ea7f3c8e4076dbdb502ad50d9d44824" dependencies = [ "ahash", "arrow-array", @@ -311,15 +311,15 @@ dependencies = [ [[package]] name = "arrow-schema" -version = "52.0.0" +version = "52.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "32aae6a60458a2389c0da89c9de0b7932427776127da1a738e2efc21d32f3393" +checksum = "dab1c12b40e29d9f3b699e0203c2a73ba558444c05e388a4377208f8f9c97eee" [[package]] name = "arrow-select" -version = "52.0.0" +version = "52.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "de36abaef8767b4220d7b4a8c2fe5ffc78b47db81b03d77e2136091c3ba39102" +checksum = "e80159088ffe8c48965cb9b1a7c968b2729f29f37363df7eca177fc3281fe7c3" dependencies = [ "ahash", "arrow-array", @@ -331,9 +331,9 @@ dependencies = [ [[package]] name = "arrow-string" -version = "52.0.0" +version = "52.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e435ada8409bcafc910bc3e0077f532a4daa20e99060a496685c0e3e53cc2597" +checksum = "0fd04a6ea7de183648edbcb7a6dd925bbd04c210895f6384c780e27a9b54afcd" dependencies = [ "arrow-array", "arrow-buffer", @@ -375,8 +375,8 @@ dependencies = [ "pin-project-lite", "tokio", "xz2", - "zstd 0.13.0", - "zstd-safe 7.0.0", + "zstd 0.13.2", + "zstd-safe 7.2.0", ] [[package]] @@ -875,9 +875,9 @@ dependencies = [ [[package]] name = "cc" -version = "1.0.103" +version = "1.0.104" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2755ff20a1d93490d26ba33a6f092a38a508398a5320df5d4b3014fcccce9410" +checksum = "74b6a57f98764a267ff415d50a25e6e166f3831a5071af4995296ea97d210490" dependencies = [ "jobserver", "libc", @@ -900,7 +900,7 @@ dependencies = [ "iana-time-zone", "num-traits", "serde", - "windows-targets 0.52.5", + "windows-targets 0.52.6", ] [[package]] @@ -1172,7 +1172,7 @@ dependencies = [ "url", "uuid", "xz2", - "zstd 0.13.0", + "zstd 0.13.2", ] [[package]] @@ -1964,9 +1964,9 @@ dependencies = [ [[package]] name = "hyper" -version = "1.3.1" +version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fe575dd17d0862a9a33781c8c4696a55c320909004a67a00fb286ba8b1bc496d" +checksum = "c4fe55fb7a772d59a5ff1dfbff4fe0258d19b89fec4b233e75d35d5d2316badc" dependencies = [ "bytes", "futures-channel", @@ -2005,10 +2005,10 @@ checksum = "5ee4be2c948921a1a5320b629c4193916ed787a7f7f293fd3f7f5a6c9de74155" dependencies = [ "futures-util", "http 1.1.0", - "hyper 1.3.1", + "hyper 1.4.0", "hyper-util", "rustls 0.23.10", - "rustls-native-certs 0.7.0", + "rustls-native-certs 0.7.1", "rustls-pki-types", "tokio", "tokio-rustls 0.26.0", @@ -2017,16 +2017,16 @@ dependencies = [ [[package]] name = "hyper-util" -version = "0.1.5" +version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7b875924a60b96e5d7b9ae7b066540b1dd1cbd90d1828f54c92e02a283351c56" +checksum = "3ab92f4f49ee4fb4f997c784b7a2e0fa70050211e0b6a287f898c3c9785ca956" dependencies = [ "bytes", "futures-channel", "futures-util", "http 1.1.0", "http-body 1.0.0", - "hyper 1.3.1", + 
"hyper 1.4.0", "pin-project-lite", "socket2", "tokio", @@ -2501,7 +2501,7 @@ dependencies = [ "chrono", "futures", "humantime", - "hyper 1.3.1", + "hyper 1.4.0", "itertools", "md-5", "parking_lot", @@ -2573,14 +2573,14 @@ dependencies = [ "libc", "redox_syscall", "smallvec", - "windows-targets 0.52.5", + "windows-targets 0.52.6", ] [[package]] name = "parquet" -version = "52.0.0" +version = "52.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "29c3b5322cc1bbf67f11c079c42be41a55949099b78732f7dba9e15edde40eab" +checksum = "0f22ba0d95db56dde8685e3fadcb915cdaadda31ab8abbe3ff7f0ad1ef333267" dependencies = [ "ahash", "arrow-array", @@ -2608,7 +2608,7 @@ dependencies = [ "thrift", "tokio", "twox-hash", - "zstd 0.13.0", + "zstd 0.13.2", "zstd-sys", ] @@ -2975,7 +2975,7 @@ dependencies = [ "http 1.1.0", "http-body 1.0.0", "http-body-util", - "hyper 1.3.1", + "hyper 1.4.0", "hyper-rustls 0.27.2", "hyper-util", "ipnet", @@ -2987,7 +2987,7 @@ dependencies = [ "pin-project-lite", "quinn", "rustls 0.23.10", - "rustls-native-certs 0.7.0", + "rustls-native-certs 0.7.1", "rustls-pemfile 2.1.2", "rustls-pki-types", "serde", @@ -3142,9 +3142,9 @@ dependencies = [ [[package]] name = "rustls-native-certs" -version = "0.7.0" +version = "0.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f1fb85efa936c42c6d5fc28d2629bb51e4b2f4b8a5211e297d599cc5a093792" +checksum = "a88d6d420651b496bdd98684116959239430022a115c1240e6c3993be0b15fba" dependencies = [ "openssl-probe", "rustls-pemfile 2.1.2", @@ -3180,9 +3180,9 @@ checksum = "976295e77ce332211c0d24d92c0e83e50f5c5f046d11082cea19f3df13a3562d" [[package]] name = "rustls-webpki" -version = "0.102.4" +version = "0.102.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ff448f7e92e913c4b7d4c6d8e4540a1724b319b4152b8aef6d4cf8339712b33e" +checksum = "f9a6fccd794a42c2c105b513a2f62bc3fd8f3ba57a4593677ceb0bd035164d78" dependencies = [ "ring 0.17.8", "rustls-pki-types", @@ -3315,9 +3315,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.119" +version = "1.0.120" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e8eddb61f0697cc3989c5d64b452f5488e2b8a60fd7d5076a3045076ffef8cb0" +checksum = "4e0d21c9a8cae1235ad58a00c11cb40d4b1e5c784f1ef2c537876ed6ffd8b7c5" dependencies = [ "itoa", "ryu", @@ -3646,9 +3646,9 @@ dependencies = [ [[package]] name = "tinyvec" -version = "1.6.1" +version = "1.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c55115c6fbe2d2bef26eb09ad74bde02d8255476fc0c7b515ef09fbb35742d82" +checksum = "ce6b6a2fb3a985e99cebfaefa9faa3024743da73304ca1c683a36429613d3d22" dependencies = [ "tinyvec_macros", ] @@ -4097,7 +4097,7 @@ version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "33ab640c8d7e35bf8ba19b884ba838ceb4fba93a4e8c65a9059d08afcfc683d9" dependencies = [ - "windows-targets 0.52.5", + "windows-targets 0.52.6", ] [[package]] @@ -4115,7 +4115,7 @@ version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" dependencies = [ - "windows-targets 0.52.5", + "windows-targets 0.52.6", ] [[package]] @@ -4135,18 +4135,18 @@ dependencies = [ [[package]] name = "windows-targets" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6f0713a46559409d202e70e28227288446bf7841d3211583a4b53e3f6d96e7eb" 
+checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973" dependencies = [ - "windows_aarch64_gnullvm 0.52.5", - "windows_aarch64_msvc 0.52.5", - "windows_i686_gnu 0.52.5", + "windows_aarch64_gnullvm 0.52.6", + "windows_aarch64_msvc 0.52.6", + "windows_i686_gnu 0.52.6", "windows_i686_gnullvm", - "windows_i686_msvc 0.52.5", - "windows_x86_64_gnu 0.52.5", - "windows_x86_64_gnullvm 0.52.5", - "windows_x86_64_msvc 0.52.5", + "windows_i686_msvc 0.52.6", + "windows_x86_64_gnu 0.52.6", + "windows_x86_64_gnullvm 0.52.6", + "windows_x86_64_msvc 0.52.6", ] [[package]] @@ -4157,9 +4157,9 @@ checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8" [[package]] name = "windows_aarch64_gnullvm" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7088eed71e8b8dda258ecc8bac5fb1153c5cffaf2578fc8ff5d61e23578d3263" +checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" [[package]] name = "windows_aarch64_msvc" @@ -4169,9 +4169,9 @@ checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc" [[package]] name = "windows_aarch64_msvc" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9985fd1504e250c615ca5f281c3f7a6da76213ebd5ccc9561496568a2752afb6" +checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" [[package]] name = "windows_i686_gnu" @@ -4181,15 +4181,15 @@ checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e" [[package]] name = "windows_i686_gnu" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "88ba073cf16d5372720ec942a8ccbf61626074c6d4dd2e745299726ce8b89670" +checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" [[package]] name = "windows_i686_gnullvm" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87f4261229030a858f36b459e748ae97545d6f1ec60e5e0d6a3d32e0dc232ee9" +checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" [[package]] name = "windows_i686_msvc" @@ -4199,9 +4199,9 @@ checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406" [[package]] name = "windows_i686_msvc" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "db3c2bf3d13d5b658be73463284eaf12830ac9a26a90c717b7f771dfe97487bf" +checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" [[package]] name = "windows_x86_64_gnu" @@ -4211,9 +4211,9 @@ checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e" [[package]] name = "windows_x86_64_gnu" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4e4246f76bdeff09eb48875a0fd3e2af6aada79d409d33011886d3e1581517d9" +checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" [[package]] name = "windows_x86_64_gnullvm" @@ -4223,9 +4223,9 @@ checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc" [[package]] name = "windows_x86_64_gnullvm" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "852298e482cd67c356ddd9570386e2862b5673c85bd5f88df9ab6802b334c596" +checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" [[package]] name = 
"windows_x86_64_msvc" @@ -4235,9 +4235,9 @@ checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" [[package]] name = "windows_x86_64_msvc" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bec47e5bfd1bff0eeaf6d8b485cc1074891a197ab4225d504cb7a1ab88b02bf0" +checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" [[package]] name = "winreg" @@ -4266,18 +4266,18 @@ dependencies = [ [[package]] name = "zerocopy" -version = "0.7.34" +version = "0.7.35" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ae87e3fcd617500e5d106f0380cf7b77f3c6092aae37191433159dda23cfb087" +checksum = "1b9b4fd18abc82b8136838da5d50bae7bdea537c574d8dc1a34ed098d6c166f0" dependencies = [ "zerocopy-derive", ] [[package]] name = "zerocopy-derive" -version = "0.7.34" +version = "0.7.35" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "15e934569e47891f7d9411f1a451d947a60e000ab3bd24fbb970f000387d1b3b" +checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e" dependencies = [ "proc-macro2", "quote", @@ -4301,11 +4301,11 @@ dependencies = [ [[package]] name = "zstd" -version = "0.13.0" +version = "0.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bffb3309596d527cfcba7dfc6ed6052f1d39dfbd7c867aa2e865e4a449c10110" +checksum = "fcf2b778a664581e31e389454a7072dab1647606d44f7feea22cd5abb9c9f3f9" dependencies = [ - "zstd-safe 7.0.0", + "zstd-safe 7.2.0", ] [[package]] @@ -4320,18 +4320,18 @@ dependencies = [ [[package]] name = "zstd-safe" -version = "7.0.0" +version = "7.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "43747c7422e2924c11144d5229878b98180ef8b06cca4ab5af37afc8a8d8ea3e" +checksum = "fa556e971e7b568dc775c136fc9de8c779b1c2fc3a63defaafadffdbd3181afa" dependencies = [ "zstd-sys", ] [[package]] name = "zstd-sys" -version = "2.0.9+zstd.1.5.5" +version = "2.0.11+zstd.1.5.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9e16efa8a874a0481a574084d34cc26fdb3b99627480f785888deb6386506656" +checksum = "75652c55c0b6f3e6f12eb786fe1bc960396bf05a1eb3bf1f3691c3610ac2e6d4" dependencies = [ "cc", "pkg-config", diff --git a/datafusion-cli/Cargo.toml b/datafusion-cli/Cargo.toml index 8578476ed43d..bcacf1d52a9b 100644 --- a/datafusion-cli/Cargo.toml +++ b/datafusion-cli/Cargo.toml @@ -30,7 +30,7 @@ rust-version = "1.76" readme = "README.md" [dependencies] -arrow = { version = "52.0.0" } +arrow = { version = "52.1.0" } async-trait = "0.1.41" aws-config = "0.55" aws-credential-types = "0.55" @@ -51,7 +51,7 @@ futures = "0.3" mimalloc = { version = "0.1", default-features = false } object_store = { version = "0.10.1", features = ["aws", "gcp", "http"] } parking_lot = { version = "0.12" } -parquet = { version = "52.0.0", default-features = false } +parquet = { version = "52.1.0", default-features = false } regex = "1.8" rustyline = "11.0" tokio = { version = "1.24", features = ["macros", "rt", "rt-multi-thread", "sync", "parking_lot", "signal"] } diff --git a/datafusion/expr/src/type_coercion/binary.rs b/datafusion/expr/src/type_coercion/binary.rs index 442a33bebc99..83a7da046844 100644 --- a/datafusion/expr/src/type_coercion/binary.rs +++ b/datafusion/expr/src/type_coercion/binary.rs @@ -1076,7 +1076,7 @@ fn temporal_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option Microsecond, (l, r) => { assert_eq!(l, r); - l.clone() + *l } }; diff --git 
a/datafusion/expr/src/type_coercion/functions.rs b/datafusion/expr/src/type_coercion/functions.rs index f9f467098ee4..b430b343e484 100644 --- a/datafusion/expr/src/type_coercion/functions.rs +++ b/datafusion/expr/src/type_coercion/functions.rs @@ -607,11 +607,11 @@ fn coerced_from<'a>( (Timestamp(unit, Some(tz)), _) if tz.as_ref() == TIMEZONE_WILDCARD => { match type_from { Timestamp(_, Some(from_tz)) => { - Some(Timestamp(unit.clone(), Some(Arc::clone(from_tz)))) + Some(Timestamp(*unit, Some(Arc::clone(from_tz)))) } Null | Date32 | Utf8 | LargeUtf8 | Timestamp(_, None) => { // In the absence of any other information assume the time zone is "+00" (UTC). - Some(Timestamp(unit.clone(), Some("+00".into()))) + Some(Timestamp(*unit, Some("+00".into()))) } _ => None, } diff --git a/datafusion/functions/src/datetime/date_bin.rs b/datafusion/functions/src/datetime/date_bin.rs index e777e5ea95d0..997f1a36ad04 100644 --- a/datafusion/functions/src/datetime/date_bin.rs +++ b/datafusion/functions/src/datetime/date_bin.rs @@ -57,35 +57,35 @@ impl DateBinFunc { vec![ Exact(vec![ DataType::Interval(MonthDayNano), - Timestamp(array_type.clone(), None), + Timestamp(array_type, None), Timestamp(Nanosecond, None), ]), Exact(vec![ DataType::Interval(MonthDayNano), - Timestamp(array_type.clone(), Some(TIMEZONE_WILDCARD.into())), + Timestamp(array_type, Some(TIMEZONE_WILDCARD.into())), Timestamp(Nanosecond, Some(TIMEZONE_WILDCARD.into())), ]), Exact(vec![ DataType::Interval(DayTime), - Timestamp(array_type.clone(), None), + Timestamp(array_type, None), Timestamp(Nanosecond, None), ]), Exact(vec![ DataType::Interval(DayTime), - Timestamp(array_type.clone(), Some(TIMEZONE_WILDCARD.into())), + Timestamp(array_type, Some(TIMEZONE_WILDCARD.into())), Timestamp(Nanosecond, Some(TIMEZONE_WILDCARD.into())), ]), Exact(vec![ DataType::Interval(MonthDayNano), - Timestamp(array_type.clone(), None), + Timestamp(array_type, None), ]), Exact(vec![ DataType::Interval(MonthDayNano), - Timestamp(array_type.clone(), Some(TIMEZONE_WILDCARD.into())), + Timestamp(array_type, Some(TIMEZONE_WILDCARD.into())), ]), Exact(vec![ DataType::Interval(DayTime), - Timestamp(array_type.clone(), None), + Timestamp(array_type, None), ]), Exact(vec![ DataType::Interval(DayTime), From e693ed7a3c3b36405f0a34887e6f8b49d4e97152 Mon Sep 17 00:00:00 2001 From: Dan Harris <1327726+thinkharderdev@users.noreply.github.com> Date: Sun, 7 Jul 2024 06:37:34 -0400 Subject: [PATCH 20/26] AggregateExec: Take grouping sets into account for InputOrderMode (#11301) * AggregateExec: Take grouping sets into account for InputOrderMode * pr comments --- .../physical-plan/src/aggregates/mod.rs | 121 ++++++++++++++++-- 1 file changed, 113 insertions(+), 8 deletions(-) diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs index 258f4140bc1e..6a0ae202c067 100644 --- a/datafusion/physical-plan/src/aggregates/mod.rs +++ b/datafusion/physical-plan/src/aggregates/mod.rs @@ -369,14 +369,26 @@ impl AggregateExec { new_requirement.extend(req); new_requirement = collapse_lex_req(new_requirement); - let input_order_mode = - if indices.len() == groupby_exprs.len() && !indices.is_empty() { - InputOrderMode::Sorted - } else if !indices.is_empty() { - InputOrderMode::PartiallySorted(indices) - } else { - InputOrderMode::Linear - }; + // If our aggregation has grouping sets then our base grouping exprs will + // be expanded based on the flags in `group_by.groups` where for each + // group we swap the grouping expr for `null` if the 
flag is `true` + // That means that each index in `indices` is valid if and only if + // it is not null in every group + let indices: Vec = indices + .into_iter() + .filter(|idx| group_by.groups.iter().all(|group| !group[*idx])) + .collect(); + + let input_order_mode = if indices.len() == groupby_exprs.len() + && !indices.is_empty() + && group_by.groups.len() == 1 + { + InputOrderMode::Sorted + } else if !indices.is_empty() { + InputOrderMode::PartiallySorted(indices) + } else { + InputOrderMode::Linear + }; // construct a map from the input expression to the output expression of the Aggregation group by let projection_mapping = @@ -1180,6 +1192,7 @@ mod tests { use arrow::array::{Float64Array, UInt32Array}; use arrow::compute::{concat_batches, SortOptions}; use arrow::datatypes::DataType; + use arrow_array::{Float32Array, Int32Array}; use datafusion_common::{ assert_batches_eq, assert_batches_sorted_eq, internal_err, DataFusionError, ScalarValue, @@ -1195,7 +1208,9 @@ mod tests { use datafusion_physical_expr::expressions::{lit, OrderSensitiveArrayAgg}; use datafusion_physical_expr::PhysicalSortExpr; + use crate::common::collect; use datafusion_physical_expr_common::aggregate::create_aggregate_expr; + use datafusion_physical_expr_common::expressions::Literal; use futures::{FutureExt, Stream}; // Generate a schema which consists of 5 columns (a, b, c, d, e) @@ -2267,4 +2282,94 @@ mod tests { assert_eq!(new_agg.schema(), aggregate_exec.schema()); Ok(()) } + + #[tokio::test] + async fn test_agg_exec_group_by_const() -> Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Float32, true), + Field::new("b", DataType::Float32, true), + Field::new("const", DataType::Int32, false), + ])); + + let col_a = col("a", &schema)?; + let col_b = col("b", &schema)?; + let const_expr = Arc::new(Literal::new(ScalarValue::Int32(Some(1)))); + + let groups = PhysicalGroupBy::new( + vec![ + (col_a, "a".to_string()), + (col_b, "b".to_string()), + (const_expr, "const".to_string()), + ], + vec![ + ( + Arc::new(Literal::new(ScalarValue::Float32(None))), + "a".to_string(), + ), + ( + Arc::new(Literal::new(ScalarValue::Float32(None))), + "b".to_string(), + ), + ( + Arc::new(Literal::new(ScalarValue::Int32(None))), + "const".to_string(), + ), + ], + vec![ + vec![false, true, true], + vec![true, false, true], + vec![true, true, false], + ], + ); + + let aggregates: Vec> = vec![create_aggregate_expr( + count_udaf().as_ref(), + &[lit(1)], + &[datafusion_expr::lit(1)], + &[], + &[], + schema.as_ref(), + "1", + false, + false, + )?]; + + let input_batches = (0..4) + .map(|_| { + let a = Arc::new(Float32Array::from(vec![0.; 8192])); + let b = Arc::new(Float32Array::from(vec![0.; 8192])); + let c = Arc::new(Int32Array::from(vec![1; 8192])); + + RecordBatch::try_new(schema.clone(), vec![a, b, c]).unwrap() + }) + .collect(); + + let input = + Arc::new(MemoryExec::try_new(&[input_batches], schema.clone(), None)?); + + let aggregate_exec = Arc::new(AggregateExec::try_new( + AggregateMode::Partial, + groups, + aggregates.clone(), + vec![None], + input, + schema, + )?); + + let output = + collect(aggregate_exec.execute(0, Arc::new(TaskContext::default()))?).await?; + + let expected = [ + "+-----+-----+-------+----------+", + "| a | b | const | 1[count] |", + "+-----+-----+-------+----------+", + "| | 0.0 | | 32768 |", + "| 0.0 | | | 32768 |", + "| | | 1 | 32768 |", + "+-----+-----+-------+----------+", + ]; + assert_batches_sorted_eq!(expected, &output); + + Ok(()) + } } From 
9f8ba6ab686434d6abfae42d2b2b2819e4bbbd9d Mon Sep 17 00:00:00 2001 From: Bruce Ritchie Date: Sun, 7 Jul 2024 13:03:10 -0400 Subject: [PATCH 21/26] Add user_defined_sql_planners(..) to FunctionRegistry (#11296) * Add user_defined_sql_planners(..) to FunctionRegistry * Adding simple test for user_defined_sql_planners * Renamed user_defined_sql_planners to expr_planners --- datafusion/core/src/execution/context/mod.rs | 4 ++++ datafusion/core/src/execution/session_state.rs | 4 ++++ .../user_defined/user_defined_scalar_functions.rs | 11 +++++++++++ datafusion/execution/src/task.rs | 10 +++++++--- datafusion/expr/src/registry.rs | 7 +++++++ datafusion/proto/src/bytes/mod.rs | 5 +++++ datafusion/proto/src/bytes/registry.rs | 5 +++++ 7 files changed, 43 insertions(+), 3 deletions(-) diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index 4685f194fe29..04debf498aa9 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -1392,6 +1392,10 @@ impl FunctionRegistry for SessionContext { self.state.write().register_function_rewrite(rewrite) } + fn expr_planners(&self) -> Vec> { + self.state.read().expr_planners() + } + fn register_user_defined_sql_planner( &mut self, user_defined_sql_planner: Arc, diff --git a/datafusion/core/src/execution/session_state.rs b/datafusion/core/src/execution/session_state.rs index ffaaa2df5e7e..ad557b12255c 100644 --- a/datafusion/core/src/execution/session_state.rs +++ b/datafusion/core/src/execution/session_state.rs @@ -1183,6 +1183,10 @@ impl FunctionRegistry for SessionState { Ok(()) } + fn expr_planners(&self) -> Vec> { + self.user_defined_sql_planners.clone() + } + fn register_user_defined_sql_planner( &mut self, user_defined_sql_planner: Arc, diff --git a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs index 5e3c44c039ab..ae8a009c6292 100644 --- a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs @@ -571,6 +571,17 @@ async fn test_user_defined_functions_cast_to_i64() -> Result<()> { Ok(()) } +#[tokio::test] +async fn test_user_defined_sql_functions() -> Result<()> { + let ctx = SessionContext::new(); + + let sql_planners = ctx.expr_planners(); + + assert!(!sql_planners.is_empty()); + + Ok(()) +} + #[tokio::test] async fn deregister_udf() -> Result<()> { let cast2i64 = ScalarUDF::from(CastToI64UDF::new()); diff --git a/datafusion/execution/src/task.rs b/datafusion/execution/src/task.rs index c21ce3d21da1..24d61e6a8b72 100644 --- a/datafusion/execution/src/task.rs +++ b/datafusion/execution/src/task.rs @@ -20,15 +20,15 @@ use std::{ sync::Arc, }; -use datafusion_common::{plan_datafusion_err, DataFusionError, Result}; -use datafusion_expr::{AggregateUDF, ScalarUDF, WindowUDF}; - use crate::{ config::SessionConfig, memory_pool::MemoryPool, registry::FunctionRegistry, runtime_env::{RuntimeConfig, RuntimeEnv}, }; +use datafusion_common::{plan_datafusion_err, DataFusionError, Result}; +use datafusion_expr::planner::UserDefinedSQLPlanner; +use datafusion_expr::{AggregateUDF, ScalarUDF, WindowUDF}; /// Task Execution Context /// @@ -191,6 +191,10 @@ impl FunctionRegistry for TaskContext { }); Ok(self.scalar_functions.insert(udf.name().into(), udf)) } + + fn expr_planners(&self) -> Vec> { + vec![] + } } #[cfg(test)] diff --git a/datafusion/expr/src/registry.rs 
b/datafusion/expr/src/registry.rs index c276fe30f897..6a27c05bb451 100644 --- a/datafusion/expr/src/registry.rs +++ b/datafusion/expr/src/registry.rs @@ -110,6 +110,9 @@ pub trait FunctionRegistry { not_impl_err!("Registering FunctionRewrite") } + /// Set of all registered [`UserDefinedSQLPlanner`]s + fn expr_planners(&self) -> Vec>; + /// Registers a new [`UserDefinedSQLPlanner`] with the registry. fn register_user_defined_sql_planner( &mut self, @@ -192,4 +195,8 @@ impl FunctionRegistry for MemoryFunctionRegistry { fn register_udwf(&mut self, udaf: Arc) -> Result>> { Ok(self.udwfs.insert(udaf.name().into(), udaf)) } + + fn expr_planners(&self) -> Vec> { + vec![] + } } diff --git a/datafusion/proto/src/bytes/mod.rs b/datafusion/proto/src/bytes/mod.rs index 901aa2455e16..83210cb4e41f 100644 --- a/datafusion/proto/src/bytes/mod.rs +++ b/datafusion/proto/src/bytes/mod.rs @@ -39,6 +39,7 @@ use std::sync::Arc; use datafusion::execution::registry::FunctionRegistry; use datafusion::physical_plan::ExecutionPlan; use datafusion::prelude::SessionContext; +use datafusion_expr::planner::UserDefinedSQLPlanner; mod registry; @@ -165,6 +166,10 @@ impl Serializeable for Expr { "register_udwf called in Placeholder Registry!" ) } + + fn expr_planners(&self) -> Vec> { + vec![] + } } Expr::from_bytes_with_registry(&bytes, &PlaceHolderRegistry)?; diff --git a/datafusion/proto/src/bytes/registry.rs b/datafusion/proto/src/bytes/registry.rs index 4bf2bb3d7b79..075993e2ba76 100644 --- a/datafusion/proto/src/bytes/registry.rs +++ b/datafusion/proto/src/bytes/registry.rs @@ -20,6 +20,7 @@ use std::{collections::HashSet, sync::Arc}; use datafusion::execution::registry::FunctionRegistry; use datafusion_common::plan_err; use datafusion_common::Result; +use datafusion_expr::planner::UserDefinedSQLPlanner; use datafusion_expr::{AggregateUDF, ScalarUDF, WindowUDF}; /// A default [`FunctionRegistry`] registry that does not resolve any @@ -54,4 +55,8 @@ impl FunctionRegistry for NoRegistry { fn register_udwf(&mut self, udwf: Arc) -> Result>> { plan_err!("No function registry provided to deserialize, so can not deserialize User Defined Window Function '{}'", udwf.inner().name()) } + + fn expr_planners(&self) -> Vec> { + vec![] + } } From 45599ce310aa6270813091a5c3288abcd7541f59 Mon Sep 17 00:00:00 2001 From: Lordworms <48054792+Lordworms@users.noreply.github.com> Date: Sun, 7 Jul 2024 10:39:37 -0700 Subject: [PATCH 22/26] use safe cast in propagate_constraints (#11297) * use safe cast in propagate_constraints * add test --- .../src/expressions/cast.rs | 11 +++++++--- datafusion/sqllogictest/test_files/cast.slt | 20 +++++++++++++++++++ 2 files changed, 28 insertions(+), 3 deletions(-) diff --git a/datafusion/physical-expr-common/src/expressions/cast.rs b/datafusion/physical-expr-common/src/expressions/cast.rs index 8aba33932c56..dd6131ad65c3 100644 --- a/datafusion/physical-expr-common/src/expressions/cast.rs +++ b/datafusion/physical-expr-common/src/expressions/cast.rs @@ -36,6 +36,11 @@ const DEFAULT_CAST_OPTIONS: CastOptions<'static> = CastOptions { format_options: DEFAULT_FORMAT_OPTIONS, }; +const DEFAULT_SAFE_CAST_OPTIONS: CastOptions<'static> = CastOptions { + safe: true, + format_options: DEFAULT_FORMAT_OPTIONS, +}; + /// CAST expression casts an expression to a specific data type and returns a runtime error on invalid cast #[derive(Debug, Clone)] pub struct CastExpr { @@ -150,9 +155,9 @@ impl PhysicalExpr for CastExpr { let child_interval = children[0]; // Get child's datatype: let cast_type = 
child_interval.data_type(); - Ok(Some( - vec![interval.cast_to(&cast_type, &self.cast_options)?], - )) + Ok(Some(vec![ + interval.cast_to(&cast_type, &DEFAULT_SAFE_CAST_OPTIONS)? + ])) } fn dyn_hash(&self, state: &mut dyn Hasher) { diff --git a/datafusion/sqllogictest/test_files/cast.slt b/datafusion/sqllogictest/test_files/cast.slt index 4554c9292b6e..3466354e54d7 100644 --- a/datafusion/sqllogictest/test_files/cast.slt +++ b/datafusion/sqllogictest/test_files/cast.slt @@ -69,3 +69,23 @@ query ? SELECT CAST(MAKE_ARRAY() AS VARCHAR[]) ---- [] + +statement ok +create table t0(v0 BIGINT); + +statement ok +insert into t0 values (1),(2),(3); + +query I +select * from t0 where v0>1e100; +---- + +query I +select * from t0 where v0<1e100; +---- +1 +2 +3 + +statement ok +drop table t0; From 5aa7c4ae7977cd043f28c4f55e07fa72d278a7b7 Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Mon, 8 Jul 2024 03:52:31 +0800 Subject: [PATCH 23/26] Minor: Remove clone in optimizer (#11315) * rm clone Signed-off-by: jayzhan211 * outer join + fix test Signed-off-by: jayzhan211 --------- Signed-off-by: jayzhan211 --- .../optimizer/src/eliminate_nested_union.rs | 34 +++++++++++-------- .../optimizer/src/eliminate_outer_join.rs | 13 ++++--- .../optimizer/src/propagate_empty_relation.rs | 4 ++- 3 files changed, 31 insertions(+), 20 deletions(-) diff --git a/datafusion/optimizer/src/eliminate_nested_union.rs b/datafusion/optimizer/src/eliminate_nested_union.rs index 09407aed53cd..3732f7ed90c8 100644 --- a/datafusion/optimizer/src/eliminate_nested_union.rs +++ b/datafusion/optimizer/src/eliminate_nested_union.rs @@ -21,7 +21,9 @@ use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::tree_node::Transformed; use datafusion_common::Result; use datafusion_expr::expr_rewriter::coerce_plan_expr_for_schema; +use datafusion_expr::logical_plan::tree_node::unwrap_arc; use datafusion_expr::{Distinct, LogicalPlan, Union}; +use itertools::Itertools; use std::sync::Arc; #[derive(Default)] @@ -56,32 +58,34 @@ impl OptimizerRule for EliminateNestedUnion { match plan { LogicalPlan::Union(Union { inputs, schema }) => { let inputs = inputs - .iter() + .into_iter() .flat_map(extract_plans_from_union) .collect::>(); Ok(Transformed::yes(LogicalPlan::Union(Union { - inputs, + inputs: inputs.into_iter().map(Arc::new).collect_vec(), schema, }))) } - LogicalPlan::Distinct(Distinct::All(ref nested_plan)) => { - match nested_plan.as_ref() { + LogicalPlan::Distinct(Distinct::All(nested_plan)) => { + match unwrap_arc(nested_plan) { LogicalPlan::Union(Union { inputs, schema }) => { let inputs = inputs - .iter() + .into_iter() .map(extract_plan_from_distinct) .flat_map(extract_plans_from_union) .collect::>(); Ok(Transformed::yes(LogicalPlan::Distinct(Distinct::All( Arc::new(LogicalPlan::Union(Union { - inputs, + inputs: inputs.into_iter().map(Arc::new).collect_vec(), schema: schema.clone(), })), )))) } - _ => Ok(Transformed::no(plan)), + nested_plan => Ok(Transformed::no(LogicalPlan::Distinct( + Distinct::All(Arc::new(nested_plan)), + ))), } } _ => Ok(Transformed::no(plan)), @@ -89,20 +93,20 @@ impl OptimizerRule for EliminateNestedUnion { } } -fn extract_plans_from_union(plan: &Arc) -> Vec> { - match plan.as_ref() { +fn extract_plans_from_union(plan: Arc) -> Vec { + match unwrap_arc(plan) { LogicalPlan::Union(Union { inputs, schema }) => inputs - .iter() - .map(|plan| Arc::new(coerce_plan_expr_for_schema(plan, schema).unwrap())) + .into_iter() + .map(|plan| coerce_plan_expr_for_schema(&plan, &schema).unwrap()) .collect::>(), - _ => 
From 5aa7c4ae7977cd043f28c4f55e07fa72d278a7b7 Mon Sep 17 00:00:00 2001
From: Jay Zhan
Date: Mon, 8 Jul 2024 03:52:31 +0800
Subject: [PATCH 23/26] Minor: Remove clone in optimizer (#11315)

* rm clone

Signed-off-by: jayzhan211

* outer join + fix test

Signed-off-by: jayzhan211

---------

Signed-off-by: jayzhan211
---
 .../optimizer/src/eliminate_nested_union.rs   | 34 +++++++++++--------
 .../optimizer/src/eliminate_outer_join.rs     | 13 ++++---
 .../optimizer/src/propagate_empty_relation.rs |  4 ++-
 3 files changed, 31 insertions(+), 20 deletions(-)

diff --git a/datafusion/optimizer/src/eliminate_nested_union.rs b/datafusion/optimizer/src/eliminate_nested_union.rs
index 09407aed53cd..3732f7ed90c8 100644
--- a/datafusion/optimizer/src/eliminate_nested_union.rs
+++ b/datafusion/optimizer/src/eliminate_nested_union.rs
@@ -21,7 +21,9 @@ use crate::{OptimizerConfig, OptimizerRule};
 use datafusion_common::tree_node::Transformed;
 use datafusion_common::Result;
 use datafusion_expr::expr_rewriter::coerce_plan_expr_for_schema;
+use datafusion_expr::logical_plan::tree_node::unwrap_arc;
 use datafusion_expr::{Distinct, LogicalPlan, Union};
+use itertools::Itertools;
 use std::sync::Arc;
 
 #[derive(Default)]
@@ -56,32 +58,34 @@ impl OptimizerRule for EliminateNestedUnion {
         match plan {
             LogicalPlan::Union(Union { inputs, schema }) => {
                 let inputs = inputs
-                    .iter()
+                    .into_iter()
                     .flat_map(extract_plans_from_union)
                     .collect::<Vec<_>>();
 
                 Ok(Transformed::yes(LogicalPlan::Union(Union {
-                    inputs,
+                    inputs: inputs.into_iter().map(Arc::new).collect_vec(),
                     schema,
                 })))
             }
-            LogicalPlan::Distinct(Distinct::All(ref nested_plan)) => {
-                match nested_plan.as_ref() {
+            LogicalPlan::Distinct(Distinct::All(nested_plan)) => {
+                match unwrap_arc(nested_plan) {
                     LogicalPlan::Union(Union { inputs, schema }) => {
                         let inputs = inputs
-                            .iter()
+                            .into_iter()
                            .map(extract_plan_from_distinct)
                             .flat_map(extract_plans_from_union)
                             .collect::<Vec<_>>();
 
                         Ok(Transformed::yes(LogicalPlan::Distinct(Distinct::All(
                             Arc::new(LogicalPlan::Union(Union {
-                                inputs,
+                                inputs: inputs.into_iter().map(Arc::new).collect_vec(),
                                 schema: schema.clone(),
                             })),
                         ))))
                     }
-                    _ => Ok(Transformed::no(plan)),
+                    nested_plan => Ok(Transformed::no(LogicalPlan::Distinct(
+                        Distinct::All(Arc::new(nested_plan)),
+                    ))),
                 }
             }
             _ => Ok(Transformed::no(plan)),
         }
     }
 }
 
-fn extract_plans_from_union(plan: &Arc<LogicalPlan>) -> Vec<Arc<LogicalPlan>> {
-    match plan.as_ref() {
+fn extract_plans_from_union(plan: Arc<LogicalPlan>) -> Vec<LogicalPlan> {
+    match unwrap_arc(plan) {
         LogicalPlan::Union(Union { inputs, schema }) => inputs
-            .iter()
-            .map(|plan| Arc::new(coerce_plan_expr_for_schema(plan, schema).unwrap()))
+            .into_iter()
+            .map(|plan| coerce_plan_expr_for_schema(&plan, &schema).unwrap())
             .collect::<Vec<_>>(),
-        _ => vec![plan.clone()],
+        plan => vec![plan],
     }
 }
 
-fn extract_plan_from_distinct(plan: &Arc<LogicalPlan>) -> &Arc<LogicalPlan> {
-    match plan.as_ref() {
+fn extract_plan_from_distinct(plan: Arc<LogicalPlan>) -> Arc<LogicalPlan> {
+    match unwrap_arc(plan) {
         LogicalPlan::Distinct(Distinct::All(plan)) => plan,
-        _ => plan,
+        plan => Arc::new(plan),
     }
 }
diff --git a/datafusion/optimizer/src/eliminate_outer_join.rs b/datafusion/optimizer/src/eliminate_outer_join.rs
index ccc637a0eb01..13c483c6dfcc 100644
--- a/datafusion/optimizer/src/eliminate_outer_join.rs
+++ b/datafusion/optimizer/src/eliminate_outer_join.rs
@@ -18,6 +18,7 @@
 //! [`EliminateOuterJoin`] converts `LEFT/RIGHT/FULL` joins to `INNER` joins
 use crate::{OptimizerConfig, OptimizerRule};
 use datafusion_common::{Column, DFSchema, Result};
+use datafusion_expr::logical_plan::tree_node::unwrap_arc;
 use datafusion_expr::logical_plan::{Join, JoinType, LogicalPlan};
 use datafusion_expr::{Expr, Filter, Operator};
 
@@ -78,7 +79,7 @@ impl OptimizerRule for EliminateOuterJoin {
         _config: &dyn OptimizerConfig,
     ) -> Result<Transformed<LogicalPlan>> {
         match plan {
-            LogicalPlan::Filter(filter) => match filter.input.as_ref() {
+            LogicalPlan::Filter(mut filter) => match unwrap_arc(filter.input) {
                 LogicalPlan::Join(join) => {
                     let mut non_nullable_cols: Vec<Column> = vec![];
 
@@ -109,9 +110,10 @@ impl OptimizerRule for EliminateOuterJoin {
                     } else {
                         join.join_type
                     };
+
                     let new_join = Arc::new(LogicalPlan::Join(Join {
-                        left: Arc::new((*join.left).clone()),
-                        right: Arc::new((*join.right).clone()),
+                        left: join.left,
+                        right: join.right,
                         join_type: new_join_type,
                         join_constraint: join.join_constraint,
                         on: join.on.clone(),
@@ -122,7 +124,10 @@ impl OptimizerRule for EliminateOuterJoin {
                     Filter::try_new(filter.predicate, new_join)
                         .map(|f| Transformed::yes(LogicalPlan::Filter(f)))
                 }
-                _ => Ok(Transformed::no(LogicalPlan::Filter(filter))),
+                filter_input => {
+                    filter.input = Arc::new(filter_input);
+                    Ok(Transformed::no(LogicalPlan::Filter(filter)))
+                }
             },
             _ => Ok(Transformed::no(plan)),
         }
diff --git a/datafusion/optimizer/src/propagate_empty_relation.rs b/datafusion/optimizer/src/propagate_empty_relation.rs
index 576dabe305e6..88bd1b17883b 100644
--- a/datafusion/optimizer/src/propagate_empty_relation.rs
+++ b/datafusion/optimizer/src/propagate_empty_relation.rs
@@ -182,7 +182,9 @@ impl OptimizerRule for PropagateEmptyRelation {
                     },
                 )))
             } else if new_inputs.len() == 1 {
-                let child = unwrap_arc(new_inputs[0].clone());
+                let mut new_inputs = new_inputs;
+                let input_plan = new_inputs.pop().unwrap(); // length checked
+                let child = unwrap_arc(input_plan);
                 if child.schema().eq(plan.schema()) {
                     Ok(Transformed::yes(child))
                 } else {
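All three rewrites lean on `unwrap_arc` to consume the child plan instead of copying
it. The idea, as a standalone sketch with a simplified generic signature (the real
helper is specialized to `LogicalPlan`):

    use std::sync::Arc;

    /// Take ownership of the value when the Arc is uniquely held,
    /// cloning only when it is actually shared.
    fn unwrap_arc<T: Clone>(v: Arc<T>) -> T {
        Arc::try_unwrap(v).unwrap_or_else(|shared| shared.as_ref().clone())
    }

In the common case of a reference count of one, the optimizer now moves the plan for
free; the deep clone only happens when some other holder still references it.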
From 229c1398d6bcd4267c8f7030a4fcf7d17c096b74 Mon Sep 17 00:00:00 2001
From: Andy Grove
Date: Sun, 7 Jul 2024 13:56:50 -0600
Subject: [PATCH 24/26] minor: Add `PhysicalSortExpr::new` (#11310)

* Add PhysicalSortExpr::new

* update call sites in physical-expr-common crate
---
 .../physical-expr-common/src/sort_expr.rs    | 23 ++++++++++---------
 datafusion/physical-expr-common/src/utils.rs |  5 +----
 2 files changed, 13 insertions(+), 15 deletions(-)

diff --git a/datafusion/physical-expr-common/src/sort_expr.rs b/datafusion/physical-expr-common/src/sort_expr.rs
index f637355519af..8fb1356a8092 100644
--- a/datafusion/physical-expr-common/src/sort_expr.rs
+++ b/datafusion/physical-expr-common/src/sort_expr.rs
@@ -39,6 +39,13 @@ pub struct PhysicalSortExpr {
     pub options: SortOptions,
 }
 
+impl PhysicalSortExpr {
+    /// Create a new PhysicalSortExpr
+    pub fn new(expr: Arc<dyn PhysicalExpr>, options: SortOptions) -> Self {
+        Self { expr, options }
+    }
+}
+
 impl PartialEq for PhysicalSortExpr {
     fn eq(&self, other: &PhysicalSortExpr) -> bool {
         self.options == other.options && self.expr.eq(&other.expr)
@@ -155,10 +162,7 @@ impl From<PhysicalSortRequirement> for PhysicalSortExpr {
             descending: false,
             nulls_first: false,
         });
-        PhysicalSortExpr {
-            expr: value.expr,
-            options,
-        }
+        PhysicalSortExpr::new(value.expr, options)
     }
 }
 
@@ -281,16 +285,13 @@ pub fn limited_convert_logical_sort_exprs_to_physical(
         let Expr::Sort(sort) = expr else {
             return exec_err!("Expects to receive sort expression");
         };
-        sort_exprs.push(PhysicalSortExpr {
-            expr: limited_convert_logical_expr_to_physical_expr(
-                sort.expr.as_ref(),
-                schema,
-            )?,
-            options: SortOptions {
+        sort_exprs.push(PhysicalSortExpr::new(
+            limited_convert_logical_expr_to_physical_expr(sort.expr.as_ref(), schema)?,
+            SortOptions {
                 descending: !sort.asc,
                 nulls_first: sort.nulls_first,
             },
-        });
+        ))
     }
     Ok(sort_exprs)
 }
diff --git a/datafusion/physical-expr-common/src/utils.rs b/datafusion/physical-expr-common/src/utils.rs
index d5cd3c6f4af0..44622bd309df 100644
--- a/datafusion/physical-expr-common/src/utils.rs
+++ b/datafusion/physical-expr-common/src/utils.rs
@@ -104,10 +104,7 @@ pub fn scatter(mask: &BooleanArray, truthy: &dyn Array) -> Result<ArrayRef> {
 pub fn reverse_order_bys(order_bys: &[PhysicalSortExpr]) -> Vec<PhysicalSortExpr> {
     order_bys
         .iter()
-        .map(|e| PhysicalSortExpr {
-            expr: e.expr.clone(),
-            options: !e.options,
-        })
+        .map(|e| PhysicalSortExpr::new(e.expr.clone(), !e.options))
         .collect()
 }
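The constructor is purely a call-site convenience; an illustrative before/after,
assuming a one-column schema and DataFusion's `col` physical-expression helper
(wrapped in a function so `?` works — not code from the patch):

    use arrow::compute::SortOptions;
    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion_common::Result;
    use datafusion_physical_expr::{expressions::col, PhysicalSortExpr};

    fn sort_on_a() -> Result<PhysicalSortExpr> {
        let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
        // Before: a struct literal at every call site.
        let _old_style = PhysicalSortExpr {
            expr: col("a", &schema)?,
            options: SortOptions::default(),
        };
        // After: one expression per call site.
        Ok(PhysicalSortExpr::new(col("a", &schema)?, SortOptions::default()))
    }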
From 6f330c98b390b23ae6ba04dce1b264fcfd28e684 Mon Sep 17 00:00:00 2001
From: Eric Fredine
Date: Sun, 7 Jul 2024 14:42:16 -0700
Subject: [PATCH 25/26] Fix data page statistics when all rows are null in a
 data page (#11295)

* Adds tests for data page statistics when all values on the page are null.

Fixes most of the failing tests for iterators not handling this situation
correctly.

* Fix handling of data page statistics for FixedBinaryArray using a builder.

* Fix data page all nulls stats test for Dictionary DataType.

* Fixes handling of None statistics for Decimal128 and Decimal256.

* Consolidate make_data_page_stats_iterator uses.

* Fix linting error.

* Remove unnecessary collect.

---------

Co-authored-by: Eric Fredine
---
 .../physical_plan/parquet/statistics.rs       | 128 +++++++-------
 .../core/tests/parquet/arrow_statistics.rs    | 158 +++++++++++++-----
 2 files changed, 184 insertions(+), 102 deletions(-)
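The thread running through this patch: every per-page statistics iterator must yield
exactly one entry per data page, even when that entry is `None` because the page is
all nulls. The pre-existing `filter_map` and `unwrap_or_default` calls broke that
invariant — the first by dropping the entry, the second by fabricating a default
value. A standalone sketch of why `filter_map` is the wrong tool here, using plain
arrow arrays and hypothetical page values:

    use arrow_array::{Array, Int32Array};

    fn main() {
        // One min value per data page; the middle page is all nulls.
        let page_mins = vec![Some(1), None, Some(7)];

        // `map` keeps the None as a null slot: index i still means page i.
        let aligned: Int32Array = page_mins.clone().into_iter().collect();
        assert_eq!(aligned.len(), 3);
        assert!(aligned.is_null(1));

        // `filter_map` silently drops the None: two entries for three pages,
        // so every later page's statistic shifts one slot to the left.
        let misaligned: Int32Array =
            page_mins.into_iter().filter_map(|v| v.map(Some)).collect();
        assert_eq!(misaligned.len(), 2);
    }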
diff --git a/datafusion/core/src/datasource/physical_plan/parquet/statistics.rs b/datafusion/core/src/datasource/physical_plan/parquet/statistics.rs
index bd05fe64e62d..b9aca2ac2cc9 100644
--- a/datafusion/core/src/datasource/physical_plan/parquet/statistics.rs
+++ b/datafusion/core/src/datasource/physical_plan/parquet/statistics.rs
@@ -19,6 +19,7 @@
 
 // TODO: potentially move this to arrow-rs: https://github.com/apache/arrow-rs/issues/4328
 
+use arrow::array::builder::FixedSizeBinaryBuilder;
 use arrow::datatypes::i256;
 use arrow::{array::ArrayRef, datatypes::DataType};
 use arrow_array::{
@@ -600,6 +601,31 @@ make_data_page_stats_iterator!(
     Index::DOUBLE,
     f64
 );
+make_data_page_stats_iterator!(
+    MinByteArrayDataPageStatsIterator,
+    |x: &PageIndex<ByteArray>| { x.min.clone() },
+    Index::BYTE_ARRAY,
+    ByteArray
+);
+make_data_page_stats_iterator!(
+    MaxByteArrayDataPageStatsIterator,
+    |x: &PageIndex<ByteArray>| { x.max.clone() },
+    Index::BYTE_ARRAY,
+    ByteArray
+);
+make_data_page_stats_iterator!(
+    MaxFixedLenByteArrayDataPageStatsIterator,
+    |x: &PageIndex<FixedLenByteArray>| { x.max.clone() },
+    Index::FIXED_LEN_BYTE_ARRAY,
+    FixedLenByteArray
+);
+
+make_data_page_stats_iterator!(
+    MinFixedLenByteArrayDataPageStatsIterator,
+    |x: &PageIndex<FixedLenByteArray>| { x.min.clone() },
+    Index::FIXED_LEN_BYTE_ARRAY,
+    FixedLenByteArray
+);
 
 macro_rules! get_decimal_page_stats_iterator {
     ($iterator_type: ident, $func: ident, $stat_value_type: ident, $convert_func: ident) => {
@@ -634,9 +660,7 @@ macro_rules! get_decimal_page_stats_iterator {
                             .indexes
                             .iter()
                             .map(|x| {
-                                Some($stat_value_type::from(
-                                    x.$func.unwrap_or_default(),
-                                ))
+                                x.$func.and_then(|x| Some($stat_value_type::from(x)))
                             })
                             .collect::<Vec<_>>(),
                     ),
@@ -645,9 +669,7 @@ macro_rules! get_decimal_page_stats_iterator {
                             .indexes
                             .iter()
                             .map(|x| {
-                                Some($stat_value_type::from(
-                                    x.$func.unwrap_or_default(),
-                                ))
+                                x.$func.and_then(|x| Some($stat_value_type::from(x)))
                             })
                             .collect::<Vec<_>>(),
                     ),
@@ -656,9 +678,9 @@ macro_rules! get_decimal_page_stats_iterator {
                             .indexes
                             .iter()
                             .map(|x| {
-                                Some($convert_func(
-                                    x.clone().$func.unwrap_or_default().data(),
-                                ))
+                                x.clone()
+                                    .$func
+                                    .and_then(|x| Some($convert_func(x.data())))
                             })
                             .collect::<Vec<_>>(),
                     ),
@@ -667,9 +689,9 @@ macro_rules! get_decimal_page_stats_iterator {
                             .indexes
                             .iter()
                             .map(|x| {
-                                Some($convert_func(
-                                    x.clone().$func.unwrap_or_default().data(),
-                                ))
+                                x.clone()
+                                    .$func
+                                    .and_then(|x| Some($convert_func(x.data())))
                             })
                             .collect::<Vec<_>>(),
                     ),
@@ -713,32 +735,6 @@ get_decimal_page_stats_iterator!(
     i256,
     from_bytes_to_i256
 );
-make_data_page_stats_iterator!(
-    MinByteArrayDataPageStatsIterator,
-    |x: &PageIndex<ByteArray>| { x.min.clone() },
-    Index::BYTE_ARRAY,
-    ByteArray
-);
-make_data_page_stats_iterator!(
-    MaxByteArrayDataPageStatsIterator,
-    |x: &PageIndex<ByteArray>| { x.max.clone() },
-    Index::BYTE_ARRAY,
-    ByteArray
-);
-
-make_data_page_stats_iterator!(
-    MaxFixedLenByteArrayDataPageStatsIterator,
-    |x: &PageIndex<FixedLenByteArray>| { x.max.clone() },
-    Index::FIXED_LEN_BYTE_ARRAY,
-    FixedLenByteArray
-);
-
-make_data_page_stats_iterator!(
-    MinFixedLenByteArrayDataPageStatsIterator,
-    |x: &PageIndex<FixedLenByteArray>| { x.min.clone() },
-    Index::FIXED_LEN_BYTE_ARRAY,
-    FixedLenByteArray
-);
 
 macro_rules! get_data_page_statistics {
                 UInt8Array::from_iter(
                     [<$stat_type_prefix Int32DataPageStatsIterator>]::new($iterator)
                         .map(|x| {
-                            x.into_iter().filter_map(|x| {
+                            x.into_iter().map(|x| {
                                 x.and_then(|x| u8::try_from(x).ok())
                             })
                         })
@@ -768,7 +764,7 @@ macro_rules! get_data_page_statistics {
                 UInt16Array::from_iter(
                     [<$stat_type_prefix Int32DataPageStatsIterator>]::new($iterator)
                         .map(|x| {
-                            x.into_iter().filter_map(|x| {
+                            x.into_iter().map(|x| {
                                 x.and_then(|x| u16::try_from(x).ok())
                            })
                         })
@@ -779,7 +775,7 @@ macro_rules! get_data_page_statistics {
                 UInt32Array::from_iter(
                     [<$stat_type_prefix Int32DataPageStatsIterator>]::new($iterator)
                         .map(|x| {
-                            x.into_iter().filter_map(|x| {
+                            x.into_iter().map(|x| {
                                 x.and_then(|x| Some(x as u32))
                             })
                         })
@@ -789,7 +785,7 @@ macro_rules! get_data_page_statistics {
                 UInt64Array::from_iter(
                     [<$stat_type_prefix Int64DataPageStatsIterator>]::new($iterator)
                         .map(|x| {
-                            x.into_iter().filter_map(|x| {
+                            x.into_iter().map(|x| {
                                 x.and_then(|x| Some(x as u64))
                             })
                         })
@@ -799,7 +795,7 @@ macro_rules! get_data_page_statistics {
                 Int8Array::from_iter(
                     [<$stat_type_prefix Int32DataPageStatsIterator>]::new($iterator)
                         .map(|x| {
-                            x.into_iter().filter_map(|x| {
+                            x.into_iter().map(|x| {
                                 x.and_then(|x| i8::try_from(x).ok())
                             })
                         })
@@ -810,7 +806,7 @@ macro_rules! get_data_page_statistics {
                 Int16Array::from_iter(
                     [<$stat_type_prefix Int32DataPageStatsIterator>]::new($iterator)
                         .map(|x| {
-                            x.into_iter().filter_map(|x| {
+                            x.into_iter().map(|x| {
                                 x.and_then(|x| i16::try_from(x).ok())
                             })
                         })
@@ -823,8 +819,8 @@ macro_rules! get_data_page_statistics {
                 Float16Array::from_iter(
                     [<$stat_type_prefix Float16DataPageStatsIterator>]::new($iterator)
                         .map(|x| {
-                            x.into_iter().filter_map(|x| {
-                                x.and_then(|x| Some(from_bytes_to_f16(x.data())))
+                            x.into_iter().map(|x| {
+                                x.and_then(|x| from_bytes_to_f16(x.data()))
                             })
                         })
                         .flatten()
@@ -836,7 +832,7 @@ macro_rules! get_data_page_statistics {
             Some(DataType::LargeBinary) => Ok(Arc::new(LargeBinaryArray::from_iter([<$stat_type_prefix ByteArrayDataPageStatsIterator>]::new($iterator).flatten()))),
             Some(DataType::Utf8) => Ok(Arc::new(StringArray::from(
                 [<$stat_type_prefix ByteArrayDataPageStatsIterator>]::new($iterator).map(|x| {
-                    x.into_iter().filter_map(|x| {
+                    x.into_iter().map(|x| {
                         x.and_then(|x| {
                             let res = std::str::from_utf8(x.data()).map(|s| s.to_string()).ok();
                             if res.is_none() {
@@ -849,7 +845,7 @@ macro_rules! get_data_page_statistics {
             ))),
             Some(DataType::LargeUtf8) => Ok(Arc::new(LargeStringArray::from(
                 [<$stat_type_prefix ByteArrayDataPageStatsIterator>]::new($iterator).map(|x| {
-                    x.into_iter().filter_map(|x| {
+                    x.into_iter().map(|x| {
                         x.and_then(|x| {
                             let res = std::str::from_utf8(x.data()).map(|s| s.to_string()).ok();
                             if res.is_none() {
@@ -878,10 +874,10 @@ macro_rules! get_data_page_statistics {
                 Date64Array::from([<$stat_type_prefix Int32DataPageStatsIterator>]::new($iterator)
                     .map(|x| {
                         x.into_iter()
-                            .filter_map(|x| {
+                            .map(|x| {
                                 x.and_then(|x| i64::try_from(x).ok())
+                                    .map(|x| x * 24 * 60 * 60 * 1000)
                             })
-                            .map(|x| x * 24 * 60 * 60 * 1000)
                     }).flatten().collect::<Vec<_>>()
                 )
             )
@@ -919,16 +915,28 @@ macro_rules! get_data_page_statistics {
                 })
             },
             Some(DataType::FixedSizeBinary(size)) => {
-                Ok(Arc::new(
-                    FixedSizeBinaryArray::try_from_iter(
-                        [<$stat_type_prefix FixedLenByteArrayDataPageStatsIterator>]::new($iterator)
-                            .flat_map(|x| x.into_iter())
-                            .filter_map(|x| x)
-                    ).unwrap_or_else(|e| {
-                        log::debug!("FixedSizeBinary statistics is invalid: {}", e);
-                        FixedSizeBinaryArray::new(*size, vec![].into(), None)
-                    })
-                ))
+                let mut builder = FixedSizeBinaryBuilder::new(*size);
+                let iterator = [<$stat_type_prefix FixedLenByteArrayDataPageStatsIterator>]::new($iterator);
+                for x in iterator {
+                    for x in x.into_iter() {
+                        let Some(x) = x else {
+                            builder.append_null(); // no statistics value
+                            continue;
+                        };
+
+                        if x.len() == *size as usize {
+                            let _ = builder.append_value(x.data());
+                        } else {
+                            log::debug!(
+                                "FixedSizeBinary({}) statistics is a binary of size {}, ignoring it.",
+                                size,
+                                x.len(),
+                            );
+                            builder.append_null();
+                        }
+                    }
+                }
+                Ok(Arc::new(builder.finish()))
            },
             _ => unimplemented!()
         }
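The builder-based rewrite of the FixedSizeBinary arm matters for the same alignment
reason: `FixedSizeBinaryArray::try_from_iter` has no way to represent a missing page
statistic, while a builder can append an explicit null and keep the slot. A standalone
sketch of that property (values invented):

    use arrow_array::{builder::FixedSizeBinaryBuilder, Array};

    fn main() {
        let mut builder = FixedSizeBinaryBuilder::new(3);
        builder.append_value(b"abc").unwrap();
        builder.append_null(); // an absent page statistic keeps its slot
        let arr = builder.finish();
        assert_eq!(arr.len(), 2);
        assert!(arr.is_null(1));
    }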
diff --git a/datafusion/core/tests/parquet/arrow_statistics.rs b/datafusion/core/tests/parquet/arrow_statistics.rs
index 6bfb9b02d347..2b4ba0b17133 100644
--- a/datafusion/core/tests/parquet/arrow_statistics.rs
+++ b/datafusion/core/tests/parquet/arrow_statistics.rs
@@ -29,15 +29,15 @@ use arrow::datatypes::{
     TimestampNanosecondType, TimestampSecondType,
 };
 use arrow_array::{
-    make_array, Array, ArrayRef, BinaryArray, BooleanArray, Date32Array, Date64Array,
-    Decimal128Array, Decimal256Array, FixedSizeBinaryArray, Float16Array, Float32Array,
-    Float64Array, Int16Array, Int32Array, Int64Array, Int8Array, LargeBinaryArray,
-    LargeStringArray, RecordBatch, StringArray, Time32MillisecondArray,
+    make_array, new_null_array, Array, ArrayRef, BinaryArray, BooleanArray, Date32Array,
+    Date64Array, Decimal128Array, Decimal256Array, FixedSizeBinaryArray, Float16Array,
+    Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, Int8Array,
+    LargeBinaryArray, LargeStringArray, RecordBatch, StringArray, Time32MillisecondArray,
     Time32SecondArray, Time64MicrosecondArray, Time64NanosecondArray,
     TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray,
     TimestampSecondArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array,
 };
-use arrow_schema::{DataType, Field, Schema};
+use arrow_schema::{DataType, Field, Schema, TimeUnit};
 use datafusion::datasource::physical_plan::parquet::StatisticsConverter;
 use half::f16;
 use parquet::arrow::arrow_reader::{
@@ -91,51 +91,60 @@ impl Int64Case {
     // Create a parquet file with the specified settings
     pub fn build(&self) -> ParquetRecordBatchReaderBuilder<File> {
-        let mut output_file = tempfile::Builder::new()
-            .prefix("parquert_statistics_test")
-            .suffix(".parquet")
-            .tempfile()
-            .expect("tempfile creation");
-
-        let mut builder =
-            WriterProperties::builder().set_max_row_group_size(self.row_per_group);
-        if let Some(enable_stats) = self.enable_stats {
-            builder = builder.set_statistics_enabled(enable_stats);
-        }
-        if let Some(data_page_row_count_limit) = self.data_page_row_count_limit {
-            builder = builder.set_data_page_row_count_limit(data_page_row_count_limit);
-        }
-        let props = builder.build();
-
         let batches = vec![self.make_int64_batches_with_null()];
+        build_parquet_file(
+            self.row_per_group,
+            self.enable_stats,
+            self.data_page_row_count_limit,
+            batches,
+        )
+    }
+}
 
-        let schema = batches[0].schema();
+fn build_parquet_file(
+    row_per_group: usize,
+    enable_stats: Option<EnabledStatistics>,
+    data_page_row_count_limit: Option<usize>,
+    batches: Vec<RecordBatch>,
+) -> ParquetRecordBatchReaderBuilder<File> {
+    let mut output_file = tempfile::Builder::new()
+        .prefix("parquet_statistics_test")
+        .suffix(".parquet")
+        .tempfile()
+        .expect("tempfile creation");
+
+    let mut builder = WriterProperties::builder().set_max_row_group_size(row_per_group);
+    if let Some(enable_stats) = enable_stats {
+        builder = builder.set_statistics_enabled(enable_stats);
+    }
+    if let Some(data_page_row_count_limit) = data_page_row_count_limit {
+        builder = builder.set_data_page_row_count_limit(data_page_row_count_limit);
+    }
+    let props = builder.build();
 
-        let mut writer =
-            ArrowWriter::try_new(&mut output_file, schema, Some(props)).unwrap();
+    let schema = batches[0].schema();
 
-        // if we have a datapage limit send the batches in one at a time to give
-        // the writer a chance to be split into multiple pages
-        if self.data_page_row_count_limit.is_some() {
-            for batch in batches {
-                for i in 0..batch.num_rows() {
-                    writer.write(&batch.slice(i, 1)).expect("writing batch");
-                }
-            }
-        } else {
-            for batch in batches {
-                writer.write(&batch).expect("writing batch");
+    let mut writer = ArrowWriter::try_new(&mut output_file, schema, Some(props)).unwrap();
+
+    // if we have a data page limit, send the batches in one at a time to give
+    // the writer a chance to split the data into multiple pages
+    if data_page_row_count_limit.is_some() {
+        for batch in &batches {
+            for i in 0..batch.num_rows() {
+                writer.write(&batch.slice(i, 1)).expect("writing batch");
             }
         }
+    } else {
+        for batch in &batches {
+            writer.write(batch).expect("writing batch");
+        }
+    }
 
-        // close file
-        let _file_meta = writer.close().unwrap();
+    let _file_meta = writer.close().unwrap();
 
-        // open the file & get the reader
-        let file = output_file.reopen().unwrap();
-        let options = ArrowReaderOptions::new().with_page_index(true);
-        ArrowReaderBuilder::try_new_with_options(file, options).unwrap()
-    }
+    let file = output_file.reopen().unwrap();
+    let options = ArrowReaderOptions::new().with_page_index(true);
+    ArrowReaderBuilder::try_new_with_options(file, options).unwrap()
 }
 
 /// Defines what data to create in a parquet file
@@ -503,6 +512,71 @@ async fn test_multiple_data_pages_nulls_and_negatives() {
     .run()
 }
 
+#[tokio::test]
+async fn test_data_page_stats_with_all_null_page() {
+    for data_type in &[
+        DataType::Boolean,
+        DataType::UInt64,
+        DataType::UInt32,
+        DataType::UInt16,
+        DataType::UInt8,
+        DataType::Int64,
+        DataType::Int32,
+        DataType::Int16,
+        DataType::Int8,
+        DataType::Float16,
+        DataType::Float32,
+        DataType::Float64,
+        DataType::Date32,
+        DataType::Date64,
+        DataType::Time32(TimeUnit::Millisecond),
+        DataType::Time32(TimeUnit::Second),
+        DataType::Time64(TimeUnit::Microsecond),
+        DataType::Time64(TimeUnit::Nanosecond),
+        DataType::Timestamp(TimeUnit::Second, None),
+        DataType::Timestamp(TimeUnit::Millisecond, None),
+        DataType::Timestamp(TimeUnit::Microsecond, None),
+        DataType::Timestamp(TimeUnit::Nanosecond, None),
+        DataType::Binary,
+        DataType::LargeBinary,
+        DataType::FixedSizeBinary(3),
+        DataType::Utf8,
+        DataType::LargeUtf8,
+        DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)),
+        DataType::Decimal128(8, 2),  // as INT32
+        DataType::Decimal128(10, 2), // as INT64
+        DataType::Decimal128(20, 2), // as FIXED_LEN_BYTE_ARRAY
+        DataType::Decimal256(8, 2),  // as INT32
+        DataType::Decimal256(10, 2), // as INT64
+        DataType::Decimal256(20, 2), // as FIXED_LEN_BYTE_ARRAY
+    ] {
+        let batch =
+            RecordBatch::try_from_iter(vec![("col", new_null_array(data_type, 4))])
+                .expect("record batch creation");
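+
+        // `new_null_array` yields a single 4-row, entirely-null column, so with
+        // a page row limit of 4 the writer emits exactly one data page whose
+        // min/max statistics are absent — the case the iterators above must
+        // surface as null entries rather than drop.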
+        let reader =
+            build_parquet_file(4, Some(EnabledStatistics::Page), Some(4), vec![batch]);
+
+        let expected_data_type = match data_type {
+            DataType::Dictionary(_, value_type) => value_type.as_ref(),
+            _ => data_type,
+        };
+
+        // There is one data page with 4 nulls
+        // The statistics should be present but null
+        Test {
+            reader: &reader,
+            expected_min: new_null_array(expected_data_type, 1),
+            expected_max: new_null_array(expected_data_type, 1),
+            expected_null_counts: UInt64Array::from(vec![4]),
+            expected_row_counts: Some(UInt64Array::from(vec![4])),
+            column_name: "col",
+            check: Check::DataPage,
+        }
+        .run()
+    }
+}
+
 /////////////// MORE GENERAL TESTS //////////////////////
 // . Many columns in a file
 // . Differnet data types

From 940efd3b4240bcc4a44ab02342a10ca2663318a0 Mon Sep 17 00:00:00 2001
From: Xin Li <33629085+xinlifoobar@users.noreply.github.com>
Date: Sun, 7 Jul 2024 18:04:19 -0700
Subject: [PATCH 26/26] Make UserDefinedFunctionPlanner usage uniform (#11318)
---
 datafusion/core/src/execution/session_state.rs |  9 ++---
 datafusion/functions/src/datetime/mod.rs       |  1 -
 datafusion/functions/src/lib.rs                |  3 ++
 .../functions/src/{datetime => }/planner.rs    | 14 ++++++--
 datafusion/functions/src/unicode/mod.rs        |  1 -
 datafusion/functions/src/unicode/planner.rs    | 36 -------------------
 6 files changed, 19 insertions(+), 45 deletions(-)
 rename datafusion/functions/src/{datetime => }/planner.rs (71%)
 delete mode 100644 datafusion/functions/src/unicode/planner.rs

diff --git a/datafusion/core/src/execution/session_state.rs b/datafusion/core/src/execution/session_state.rs
index ad557b12255c..d056b91c2747 100644
--- a/datafusion/core/src/execution/session_state.rs
+++ b/datafusion/core/src/execution/session_state.rs
@@ -238,10 +238,11 @@ impl SessionState {
             Arc::new(functions_array::planner::ArrayFunctionPlanner),
             #[cfg(feature = "array_expressions")]
             Arc::new(functions_array::planner::FieldAccessPlanner),
-            #[cfg(feature = "datetime_expressions")]
-            Arc::new(functions::datetime::planner::ExtractPlanner),
-            #[cfg(feature = "unicode_expressions")]
-            Arc::new(functions::unicode::planner::PositionPlanner),
+            #[cfg(any(
+                feature = "datetime_expressions",
+                feature = "unicode_expressions"
+            ))]
+            Arc::new(functions::planner::UserDefinedFunctionPlanner),
         ];
 
         let mut new_self = SessionState {
diff --git a/datafusion/functions/src/datetime/mod.rs b/datafusion/functions/src/datetime/mod.rs
index 8365a38f41f2..9c2f80856bf8 100644
--- a/datafusion/functions/src/datetime/mod.rs
+++ b/datafusion/functions/src/datetime/mod.rs
@@ -30,7 +30,6 @@ pub mod date_trunc;
 pub mod from_unixtime;
 pub mod make_date;
 pub mod now;
-pub mod planner;
 pub mod to_char;
 pub mod to_date;
 pub mod to_timestamp;
diff --git a/datafusion/functions/src/lib.rs b/datafusion/functions/src/lib.rs
index 4bc24931d06b..433a4f90d95b 100644
--- a/datafusion/functions/src/lib.rs
+++ b/datafusion/functions/src/lib.rs
@@ -130,6 +130,9 @@ make_stub_package!(crypto, "crypto_expressions");
 pub mod unicode;
 make_stub_package!(unicode, "unicode_expressions");
 
+#[cfg(any(feature = "datetime_expressions", feature = "unicode_expressions"))]
+pub mod planner;
+
 mod utils;
 
 /// Fluent-style API for creating `Expr`s
diff --git a/datafusion/functions/src/datetime/planner.rs b/datafusion/functions/src/planner.rs
similarity index 71%
rename from datafusion/functions/src/datetime/planner.rs
rename to datafusion/functions/src/planner.rs
index 4265ce42a51a..b00d5cf60810 100644
--- a/datafusion/functions/src/datetime/planner.rs
+++ b/datafusion/functions/src/planner.rs
@@ -15,7 +15,7 @@
 // specific language governing permissions and limitations
 // under the License.
 
-//! SQL planning extensions like [`ExtractPlanner`]
+//! SQL planning extensions like [`UserDefinedFunctionPlanner`]
 
 use datafusion_common::Result;
 use datafusion_expr::{
@@ -25,12 +25,20 @@ use datafusion_expr::{
 };
 
 #[derive(Default)]
-pub struct ExtractPlanner;
+pub struct UserDefinedFunctionPlanner;
 
-impl UserDefinedSQLPlanner for ExtractPlanner {
+impl UserDefinedSQLPlanner for UserDefinedFunctionPlanner {
+    #[cfg(feature = "datetime_expressions")]
     fn plan_extract(&self, args: Vec<Expr>) -> Result<PlannerResult<Vec<Expr>>> {
         Ok(PlannerResult::Planned(Expr::ScalarFunction(
             ScalarFunction::new_udf(crate::datetime::date_part(), args),
         )))
     }
+
+    #[cfg(feature = "unicode_expressions")]
+    fn plan_position(&self, args: Vec<Expr>) -> Result<PlannerResult<Vec<Expr>>> {
+        Ok(PlannerResult::Planned(Expr::ScalarFunction(
+            ScalarFunction::new_udf(crate::unicode::strpos(), args),
+        )))
+    }
 }
diff --git a/datafusion/functions/src/unicode/mod.rs b/datafusion/functions/src/unicode/mod.rs
index a391b8ba11dc..9e8c07cd36ed 100644
--- a/datafusion/functions/src/unicode/mod.rs
+++ b/datafusion/functions/src/unicode/mod.rs
@@ -25,7 +25,6 @@ pub mod character_length;
 pub mod find_in_set;
 pub mod left;
 pub mod lpad;
-pub mod planner;
 pub mod reverse;
 pub mod right;
 pub mod rpad;
diff --git a/datafusion/functions/src/unicode/planner.rs b/datafusion/functions/src/unicode/planner.rs
deleted file mode 100644
index 4d6f73321b4a..000000000000
--- a/datafusion/functions/src/unicode/planner.rs
+++ /dev/null
@@ -1,36 +0,0 @@
-// Licensed to the Apache Software Foundation (ASF) under one
-// or more contributor license agreements. See the NOTICE file
-// distributed with this work for additional information
-// regarding copyright ownership. The ASF licenses this file
-// to you under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance
-// with the License. You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing,
-// software distributed under the License is distributed on an
-// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-// KIND, either express or implied. See the License for the
-// specific language governing permissions and limitations
-// under the License.
-
-//! SQL planning extensions like [`PositionPlanner`]
-
-use datafusion_common::Result;
-use datafusion_expr::{
-    expr::ScalarFunction,
-    planner::{PlannerResult, UserDefinedSQLPlanner},
-    Expr,
-};
-
-#[derive(Default)]
-pub struct PositionPlanner;
-
-impl UserDefinedSQLPlanner for PositionPlanner {
-    fn plan_position(&self, args: Vec<Expr>) -> Result<PlannerResult<Vec<Expr>>> {
-        Ok(PlannerResult::Planned(Expr::ScalarFunction(
-            ScalarFunction::new_udf(crate::unicode::strpos(), args),
-        )))
-    }
-}
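For downstream crates, the net effect of this last patch is that one planner type now
carries every feature-gated rewrite, and the same trait hook remains open for custom
planners. A hedged sketch of a user-defined planner following the same shape (the
`MyPlanner` type and its early-return check are invented for illustration; the
`PlannerResult::Original` variant hands unhandled arguments back to the default
planning path):

    use datafusion_common::Result;
    use datafusion_expr::{
        expr::ScalarFunction,
        planner::{PlannerResult, UserDefinedSQLPlanner},
        Expr,
    };

    #[derive(Default)]
    struct MyPlanner;

    impl UserDefinedSQLPlanner for MyPlanner {
        fn plan_position(&self, args: Vec<Expr>) -> Result<PlannerResult<Vec<Expr>>> {
            // Decline anything but the two-argument form and let the
            // default planner handle it.
            if args.len() != 2 {
                return Ok(PlannerResult::Original(args));
            }
            Ok(PlannerResult::Planned(Expr::ScalarFunction(
                ScalarFunction::new_udf(datafusion_functions::unicode::strpos(), args),
            )))
        }
    }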