diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index 9e192b0be0f1..7ddc0af4306c 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -1373,6 +1373,7 @@ dependencies = [ "arrow", "arrow-array", "arrow-buffer", + "arrow-ord", "arrow-schema", "async-trait", "chrono", diff --git a/datafusion-cli/src/exec.rs b/datafusion-cli/src/exec.rs index 2072cc7df002..5fbcea0c0683 100644 --- a/datafusion-cli/src/exec.rs +++ b/datafusion-cli/src/exec.rs @@ -42,6 +42,7 @@ use datafusion::sql::parser::{DFParser, Statement}; use datafusion::sql::sqlparser::dialect::dialect_from_str; use datafusion::common::FileType; +use datafusion::sql::sqlparser; use rustyline::error::ReadlineError; use rustyline::Editor; use tokio::signal; @@ -221,15 +222,12 @@ async fn exec_and_print( let statements = DFParser::parse_sql_with_dialect(&sql, dialect.as_ref())?; for statement in statements { + let adjusted = + AdjustedPrintOptions::new(print_options.clone()).with_statement(&statement); + let plan = create_plan(ctx, statement).await?; + let adjusted = adjusted.with_plan(&plan); - // For plans like `Explain` ignore `MaxRows` option and always display all rows - let should_ignore_maxrows = matches!( - plan, - LogicalPlan::Explain(_) - | LogicalPlan::DescribeTable(_) - | LogicalPlan::Analyze(_) - ); let df = ctx.execute_logical_plan(plan).await?; let physical_plan = df.create_physical_plan().await?; @@ -237,21 +235,60 @@ async fn exec_and_print( let stream = execute_stream(physical_plan, task_ctx.clone())?; print_options.print_stream(stream, now).await?; } else { - let mut print_options = print_options.clone(); - if should_ignore_maxrows { - print_options.maxrows = MaxRows::Unlimited; - } - if print_options.format == PrintFormat::Automatic { - print_options.format = PrintFormat::Table; - } let results = collect(physical_plan, task_ctx.clone()).await?; - print_options.print_batches(&results, now)?; + adjusted.into_inner().print_batches(&results, now)?; } } Ok(()) } +/// Track adjustments to the print options based on the plan / statement being executed +#[derive(Debug)] +struct AdjustedPrintOptions { + inner: PrintOptions, +} + +impl AdjustedPrintOptions { + fn new(inner: PrintOptions) -> Self { + Self { inner } + } + /// Adjust print options based on any statement specific requirements + fn with_statement(mut self, statement: &Statement) -> Self { + if let Statement::Statement(sql_stmt) = statement { + // SHOW / SHOW ALL + if let sqlparser::ast::Statement::ShowVariable { .. 
} = sql_stmt.as_ref() { + self.inner.maxrows = MaxRows::Unlimited + } + } + self + } + + /// Adjust print options based on any plan specific requirements + fn with_plan(mut self, plan: &LogicalPlan) -> Self { + // For plans like `Explain` ignore `MaxRows` option and always display + // all rows + if matches!( + plan, + LogicalPlan::Explain(_) + | LogicalPlan::DescribeTable(_) + | LogicalPlan::Analyze(_) + ) { + self.inner.maxrows = MaxRows::Unlimited; + } + self + } + + /// Finalize and return the inner `PrintOptions` + fn into_inner(mut self) -> PrintOptions { + if self.inner.format == PrintFormat::Automatic { + self.inner.format = PrintFormat::Table; + } + + self.inner + } +} + async fn create_plan( ctx: &mut SessionContext, statement: Statement,
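The two-step construction above is deliberate: `create_plan(ctx, statement)` takes the `Statement` by value, so the statement-level adjustment has to be captured before planning and the plan-level adjustment after it. A sketch of the intended call order, restating the patch's own code with the intermediate states called out:

let adjusted = AdjustedPrintOptions::new(print_options.clone())
    .with_statement(&statement);               // statement is only borrowed here
let plan = create_plan(ctx, statement).await?; // statement is moved here
let adjusted = adjusted.with_plan(&plan);      // plan-specific overrides second
let print_options = adjusted.into_inner();     // Automatic -> Table resolved last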
diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index 5cf8969aa46d..fc2cdbb7518d 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -1952,6 +1952,11 @@ impl SessionState { &self.config } + /// Return the mutable [`SessionConfig`]. + pub fn config_mut(&mut self) -> &mut SessionConfig { + &mut self.config + } + /// Return the physical optimizers pub fn physical_optimizers(&self) -> &[Arc<dyn PhysicalOptimizerRule>] { &self.physical_optimizers.rules diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index c25523c5ae33..f5e937bb56a0 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -1168,12 +1168,13 @@ impl DefaultPhysicalPlanner { Ok(Arc::new(GlobalLimitExec::new(input, *skip, *fetch))) } - LogicalPlan::Unnest(Unnest { input, column, schema, options }) => { + LogicalPlan::Unnest(Unnest { input, columns, schema, options }) => { let input = self.create_initial_plan(input, session_state).await?; - let column_exec = schema.index_of_column(column) - .map(|idx| Column::new(&column.name, idx))?; + let column_execs = columns.iter().map(|column| { + schema.index_of_column(column).map(|idx| Column::new(&column.name, idx)) + }).collect::<Result<Vec<_>>>()?; let schema = SchemaRef::new(schema.as_ref().to_owned().into()); - Ok(Arc::new(UnnestExec::new(input, column_exec, schema, options.clone()))) + Ok(Arc::new(UnnestExec::new(input, column_execs, schema, options.clone()))) } LogicalPlan::Ddl(ddl) => { // There is no default plan for DDl statements -- diff --git a/datafusion/execution/src/config.rs b/datafusion/execution/src/config.rs index 0a7a87c7d81a..e29030e61457 100644 --- a/datafusion/execution/src/config.rs +++ b/datafusion/execution/src/config.rs @@ -501,13 +501,54 @@ impl SessionConfig { /// /// [^1]: Compare that to [`ConfigOptions`] which only supports [`ScalarValue`] payloads. pub fn with_extension<T>(mut self, ext: Arc<T>) -> Self + where + T: Send + Sync + 'static, + { + self.set_extension(ext); + self + } + + /// Set extension. Pretty much the same as [`with_extension`](Self::with_extension), but takes a + /// mutable reference instead of taking ownership. Useful if you want to add another extension after + /// the [`SessionConfig`] is created. + /// + /// # Example + /// ``` + /// use std::sync::Arc; + /// use datafusion_execution::config::SessionConfig; + /// + /// // application-specific extension types + /// struct Ext1(u8); + /// struct Ext2(u8); + /// struct Ext3(u8); + /// + /// let ext1a = Arc::new(Ext1(10)); + /// let ext1b = Arc::new(Ext1(11)); + /// let ext2 = Arc::new(Ext2(2)); + /// + /// let mut cfg = SessionConfig::default(); + /// + /// // will only remember the last Ext1 + /// cfg.set_extension(Arc::clone(&ext1a)); + /// cfg.set_extension(Arc::clone(&ext1b)); + /// cfg.set_extension(Arc::clone(&ext2)); + /// + /// let ext1_received = cfg.get_extension::<Ext1>().unwrap(); + /// assert!(!Arc::ptr_eq(&ext1_received, &ext1a)); + /// assert!(Arc::ptr_eq(&ext1_received, &ext1b)); + /// + /// let ext2_received = cfg.get_extension::<Ext2>().unwrap(); + /// assert!(Arc::ptr_eq(&ext2_received, &ext2)); + /// + /// assert!(cfg.get_extension::<Ext3>().is_none()); + /// ``` + pub fn set_extension<T>(&mut self, ext: Arc<T>) where T: Send + Sync + 'static, { let ext = ext as Arc<dyn Any + Send + Sync>; let id = TypeId::of::<T>(); self.extensions.insert(id, ext); - self } /// Get extension, if any for the specified type `T` exists.
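Combined with the new `SessionState::config_mut` above, `set_extension` makes it possible to attach an extension to a state that already exists. A minimal sketch, where `MyExt` is a hypothetical application type, not part of this patch:

use std::sync::Arc;
use datafusion::prelude::SessionContext;

// Application-defined extension payload (assumed for illustration).
struct MyExt {
    tenant_id: u64,
}

let ctx = SessionContext::new();
let mut state = ctx.state(); // a clone of the current SessionState
state.config_mut().set_extension(Arc::new(MyExt { tenant_id: 42 }));
assert_eq!(state.config().get_extension::<MyExt>().unwrap().tenant_id, 42);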
diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index 43cb0c3e0a50..5bfec00ea3b3 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -37,14 +37,8 @@ use strum_macros::EnumIter; #[derive(Debug, Clone, PartialEq, Eq, Hash, EnumIter, Copy)] pub enum BuiltinScalarFunction { // math functions - /// ceil - Ceil, /// coalesce Coalesce, - /// exp - Exp, - /// factorial - Factorial, // string functions /// concat Concat, @@ -106,10 +100,7 @@ impl BuiltinScalarFunction { pub fn volatility(&self) -> Volatility { match self { // Immutable scalar builtins - BuiltinScalarFunction::Ceil => Volatility::Immutable, BuiltinScalarFunction::Coalesce => Volatility::Immutable, - BuiltinScalarFunction::Exp => Volatility::Immutable, - BuiltinScalarFunction::Factorial => Volatility::Immutable, BuiltinScalarFunction::Concat => Volatility::Immutable, BuiltinScalarFunction::ConcatWithSeparator => Volatility::Immutable, BuiltinScalarFunction::EndsWith => Volatility::Immutable, @@ -145,15 +136,6 @@ impl BuiltinScalarFunction { utf8_to_str_type(&input_expr_types[0], "initcap") } BuiltinScalarFunction::EndsWith => Ok(Boolean), - - BuiltinScalarFunction::Factorial => Ok(Int64), - - BuiltinScalarFunction::Ceil | BuiltinScalarFunction::Exp => { - match input_expr_types[0] { - Float32 => Ok(Float32), - _ => Ok(Float64), - } - } } } @@ -185,17 +167,6 @@ impl BuiltinScalarFunction { ], self.volatility(), ), - BuiltinScalarFunction::Factorial => { - Signature::uniform(1, vec![Int64], self.volatility()) - } - BuiltinScalarFunction::Ceil | BuiltinScalarFunction::Exp => { - // math expressions expect 1 argument of type f64 or f32 - // priority is given to f64 because e.g. `sqrt(1i32)` is in IR (real numbers) and thus we - // return the best approximation for it (in f64). - // We accept f32 because in this case it is clear that the best approximation - // will be as good as the number of digits in the number - Signature::uniform(1, vec![Float64, Float32], self.volatility()) - } } } @@ -203,25 +174,12 @@ impl BuiltinScalarFunction { /// The list can be extended, only mathematical and datetime functions are /// considered for the initial implementation of this feature. pub fn monotonicity(&self) -> Option<FuncMonotonicity> { - if matches!( - &self, - BuiltinScalarFunction::Ceil - | BuiltinScalarFunction::Exp - | BuiltinScalarFunction::Factorial - ) { - Some(vec![Some(true)]) - } else { - None - } + None } /// Returns all names that can be used to call this function pub fn aliases(&self) -> &'static [&'static str] { match self { - BuiltinScalarFunction::Ceil => &["ceil"], - BuiltinScalarFunction::Exp => &["exp"], - BuiltinScalarFunction::Factorial => &["factorial"], - // conditional functions BuiltinScalarFunction::Coalesce => &["coalesce"], diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index c7c50d871902..cffb58dadd8e 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -186,7 +186,16 @@ pub enum Expr { #[derive(Clone, PartialEq, Eq, Hash, Debug)] pub struct Unnest { - pub exprs: Vec<Expr>, + pub expr: Box<Expr>, +} + +impl Unnest { + /// Create a new Unnest expression. + pub fn new(expr: Expr) -> Self { + Self { + expr: Box::new(expr), + } + } } /// Alias expression @@ -1567,8 +1576,8 @@ impl fmt::Display for Expr { } }, Expr::Placeholder(Placeholder { id, .. }) => write!(f, "{id}"), - Expr::Unnest(Unnest { exprs }) => { - write!(f, "UNNEST({exprs:?})") + Expr::Unnest(Unnest { expr }) => { + write!(f, "UNNEST({expr:?})") } } } @@ -1757,7 +1766,10 @@ fn create_name(e: &Expr) -> Result<String> { } } } - Expr::Unnest(Unnest { exprs }) => create_function_name("unnest", false, exprs), + Expr::Unnest(Unnest { expr }) => { + let expr_name = create_name(expr)?; + Ok(format!("unnest({expr_name})")) + } Expr::ScalarFunction(fun) => create_function_name(fun.name(), false, &fun.args), Expr::WindowFunction(WindowFunction { fun, diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 6a28275ebfcf..f7900f6b197d 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -525,16 +525,6 @@ macro_rules! nary_scalar_expr { // generate methods for creating the supported unary/binary expressions // math functions -scalar_expr!(Factorial, factorial, num, "factorial"); -scalar_expr!( - Ceil, - ceil, - num, - "nearest integer greater than or equal to argument" -); - -scalar_expr!(Exp, exp, num, "exponential"); - scalar_expr!(InitCap, initcap, string, "converts the first letter of each word in `string` in uppercase and the remaining characters in lowercase"); scalar_expr!(EndsWith, ends_with, string suffix, "whether the `string` ends with the `suffix`"); nary_scalar_expr!(Coalesce, coalesce, "returns `coalesce(args...)`, which evaluates to the value of the first [Expr] which is not NULL"); @@ -877,22 +867,6 @@ mod test { ); } - macro_rules! test_unary_scalar_expr { - ($ENUM:ident, $FUNC:ident) => {{ - if let Expr::ScalarFunction(ScalarFunction { - func_def: ScalarFunctionDefinition::BuiltIn(fun), - args, - }) = $FUNC(col("tableA.a")) - { - let name = built_in_function::BuiltinScalarFunction::$ENUM; - assert_eq!(name, fun); - assert_eq!(1, args.len()); - } else { - assert!(false, "unexpected"); - } - }}; - } - - macro_rules!
test_scalar_expr { ($ENUM:ident, $FUNC:ident, $($arg:ident),*) => { let expected = [$(stringify!($arg)),*]; @@ -913,10 +887,6 @@ mod test { #[test] fn scalar_function_definitions() { - test_unary_scalar_expr!(Factorial, factorial); - test_unary_scalar_expr!(Ceil, ceil); - test_unary_scalar_expr!(Exp, exp); - test_scalar_expr!(InitCap, initcap, string); test_scalar_expr!(EndsWith, ends_with, string, characters); } diff --git a/datafusion/expr/src/expr_rewriter/mod.rs b/datafusion/expr/src/expr_rewriter/mod.rs index d678fe7ee39c..c11619fc0ea2 100644 --- a/datafusion/expr/src/expr_rewriter/mod.rs +++ b/datafusion/expr/src/expr_rewriter/mod.rs @@ -82,13 +82,13 @@ pub fn normalize_col_with_schemas_and_ambiguity_check( using_columns: &[HashSet], ) -> Result { // Normalize column inside Unnest - if let Expr::Unnest(Unnest { exprs }) = expr { + if let Expr::Unnest(Unnest { expr }) = expr { let e = normalize_col_with_schemas_and_ambiguity_check( - exprs[0].clone(), + expr.as_ref().clone(), schemas, using_columns, )?; - return Ok(Expr::Unnest(Unnest { exprs: vec![e] })); + return Ok(Expr::Unnest(Unnest { expr: Box::new(e) })); } expr.transform(&|expr| { diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index 39892d9e0c0d..466fd13ce207 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -115,12 +115,8 @@ impl ExprSchemable for Expr { Expr::Case(case) => case.when_then_expr[0].1.get_type(schema), Expr::Cast(Cast { data_type, .. }) | Expr::TryCast(TryCast { data_type, .. }) => Ok(data_type.clone()), - Expr::Unnest(Unnest { exprs }) => { - let arg_data_types = exprs - .iter() - .map(|e| e.get_type(schema)) - .collect::>>()?; - let arg_data_type = arg_data_types[0].clone(); + Expr::Unnest(Unnest { expr }) => { + let arg_data_type = expr.get_type(schema)?; // Unnest's output type is the inner type of the list match arg_data_type{ DataType::List(field) | DataType::LargeList(field) | DataType::FixedSizeList(field, _) =>{ diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index beac5a7f4eb7..f7c0fbac537b 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -1112,7 +1112,7 @@ impl LogicalPlanBuilder { /// Unnest the given column. 
pub fn unnest_column(self, column: impl Into<Column>) -> Result<Self> { - Ok(Self::from(unnest(self.plan, column.into())?)) + Ok(Self::from(unnest(self.plan, vec![column.into()])?)) } /// Unnest the given column given [`UnnestOptions`] @@ -1123,10 +1123,21 @@ impl LogicalPlanBuilder { ) -> Result<Self> { Ok(Self::from(unnest_with_options( self.plan, - column.into(), + vec![column.into()], options, )?)) } + + /// Unnest the given columns with the given [`UnnestOptions`] + pub fn unnest_columns_with_options( + self, + columns: Vec<Column>, + options: UnnestOptions, + ) -> Result<Self> { + Ok(Self::from(unnest_with_options( + self.plan, columns, options, + )?)) + } } pub fn change_redundant_column(fields: &Fields) -> Vec<Field> { let mut name_map = HashMap::new(); @@ -1534,44 +1545,50 @@ impl TableSource for LogicalTableSource { } /// Create a [`LogicalPlan::Unnest`] plan -pub fn unnest(input: LogicalPlan, column: Column) -> Result<LogicalPlan> { - unnest_with_options(input, column, UnnestOptions::new()) +pub fn unnest(input: LogicalPlan, columns: Vec<Column>) -> Result<LogicalPlan> { + unnest_with_options(input, columns, UnnestOptions::new()) } /// Create a [`LogicalPlan::Unnest`] plan with options pub fn unnest_with_options( input: LogicalPlan, - column: Column, + columns: Vec<Column>, options: UnnestOptions, ) -> Result<LogicalPlan> { - let (unnest_qualifier, unnest_field) = - input.schema().qualified_field_from_column(&column)?; - // Extract the type of the nested field in the list. - let unnested_field = match unnest_field.data_type() { - DataType::List(field) - | DataType::FixedSizeList(field, _) - | DataType::LargeList(field) => Arc::new(Field::new( - unnest_field.name(), - field.data_type().clone(), - unnest_field.is_nullable(), - )), - _ => { - // If the unnest field is not a list type return the input plan. - return Ok(input); - } - }; + let mut unnested_fields: HashMap<usize, _> = HashMap::with_capacity(columns.len()); + // Add qualifiers to the columns. + let mut qualified_columns = Vec::with_capacity(columns.len()); + for c in &columns { + let index = input.schema().index_of_column(c)?; + let (unnest_qualifier, unnest_field) = input.schema().qualified_field(index); + let unnested_field = match unnest_field.data_type() { + DataType::List(field) + | DataType::FixedSizeList(field, _) + | DataType::LargeList(field) => Arc::new(Field::new( + unnest_field.name(), + field.data_type().clone(), + // Unnesting may produce NULLs even if the list is not null. + // For example: unnest([1], []) -> 1, null + true, + )), + _ => { + // If the unnest field is not a list type return the input plan. + return Ok(input); + } + }; + qualified_columns.push(Column::from((unnest_qualifier, unnested_field.as_ref()))); + unnested_fields.insert(index, unnested_field); + } - // Update the schema with the unnest column type changed to contain the nested type. + // Update the schema with the unnest column types changed to contain the nested types. let input_schema = input.schema(); let fields = input_schema .iter() - .map(|(q, f)| { - if f.as_ref() == unnest_field && q == unnest_qualifier { - (unnest_qualifier.cloned(), unnested_field.clone()) - } else { - (q.cloned(), f.clone()) - } + .enumerate() + .map(|(index, (q, f))| match unnested_fields.get(&index) { + Some(unnested_field) => (q.cloned(), unnested_field.clone()), + None => (q.cloned(), f.clone()), }) .collect::<Vec<_>>(); @@ -1580,11 +1597,9 @@ pub fn unnest_with_options( // We can use the existing functional dependencies: let deps = input_schema.functional_dependencies().clone(); let schema = Arc::new(df_schema.with_functional_dependencies(deps)?); - let column = Column::from((unnest_qualifier, unnested_field.as_ref())); - Ok(LogicalPlan::Unnest(Unnest { input: Arc::new(input), - column, + columns: qualified_columns, schema, options, }))
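The single-column entry points above are now thin wrappers over the multi-column path. A minimal usage sketch for the new `unnest_columns_with_options` builder method; the input plan and the `tags`/`scores` list columns are assumed for illustration:

// Unnest two list columns in a single logical Unnest node.
let plan = LogicalPlanBuilder::from(input_plan)
    .unnest_columns_with_options(
        vec!["tags".into(), "scores".into()],
        UnnestOptions::new().with_preserve_nulls(true),
    )?
    .build()?;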
diff --git a/datafusion/expr/src/logical_plan/display.rs b/datafusion/expr/src/logical_plan/display.rs index edc3afd55d63..3a2ed9ffc2d8 100644 --- a/datafusion/expr/src/logical_plan/display.rs +++ b/datafusion/expr/src/logical_plan/display.rs @@ -638,10 +638,10 @@ impl<'a, 'b> PgJsonVisitor<'a, 'b> { "Node Type": "DescribeTable" }) } - LogicalPlan::Unnest(Unnest { column, .. }) => { + LogicalPlan::Unnest(Unnest { columns, .. }) => { json!({ "Node Type": "Unnest", - "Column": format!("{}", column) + "Column": expr_vec_fmt!(columns), }) } } diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 91c8670f3805..dbff5046013b 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -24,7 +24,7 @@ use std::sync::Arc; use super::dml::CopyTo; use super::DdlStatement; -use crate::builder::change_redundant_column; +use crate::builder::{change_redundant_column, unnest_with_options}; use crate::expr::{Alias, Placeholder, Sort as SortExpr, WindowFunction}; use crate::expr_rewriter::{create_col_from_scalar_expr, normalize_cols}; use crate::logical_plan::display::{GraphvizVisitor, IndentVisitor}; @@ -807,51 +807,11 @@ impl LogicalPlan { } LogicalPlan::DescribeTable(_) => Ok(self.clone()), LogicalPlan::Unnest(Unnest { - column, - schema, - options, - .. + columns, options, .. }) => { // Update schema with unnested column type. - let input = Arc::new(inputs.swap_remove(0)); - let (nested_qualifier, nested_field) = - input.schema().qualified_field_from_column(column)?; - let (unnested_qualifier, unnested_field) = - schema.qualified_field_from_column(column)?; - let qualifiers_and_fields = input - .schema() - .iter() - .map(|(qualifier, field)| { - if qualifier.eq(&nested_qualifier) - && field.as_ref() == nested_field - { - ( - unnested_qualifier.cloned(), - Arc::new(unnested_field.clone()), - ) - } else { - (qualifier.cloned(), field.clone()) - } - }) - .collect::<Vec<_>>(); - - let schema = Arc::new( - DFSchema::new_with_metadata( - qualifiers_and_fields, - input.schema().metadata().clone(), - )? - // We can use the existing functional dependencies as is: - .with_functional_dependencies( - input.schema().functional_dependencies().clone(), - )?, - ); - - Ok(LogicalPlan::Unnest(Unnest { - input, - column: column.clone(), - schema, - options: options.clone(), - })) + let input = inputs.swap_remove(0); + unnest_with_options(input, columns.clone(), options.clone()) } } } @@ -1581,8 +1541,8 @@ impl LogicalPlan { LogicalPlan::DescribeTable(DescribeTable { .. }) => { write!(f, "DescribeTable") } - LogicalPlan::Unnest(Unnest { column, ..
}) => { - write!(f, "Unnest: {column}") + LogicalPlan::Unnest(Unnest { columns, .. }) => { + write!(f, "Unnest: {}", expr_vec_fmt!(columns)) } } } @@ -2556,8 +2516,8 @@ pub enum Partitioning { pub struct Unnest { /// The incoming logical plan pub input: Arc<LogicalPlan>, - /// The column to unnest - pub column: Column, + /// The columns to unnest + pub columns: Vec<Column>, /// The output schema, containing the unnested field column. pub schema: DFSchemaRef, /// Options diff --git a/datafusion/expr/src/logical_plan/tree_node.rs b/datafusion/expr/src/logical_plan/tree_node.rs index 3644f89e8b42..48f047c070dd 100644 --- a/datafusion/expr/src/logical_plan/tree_node.rs +++ b/datafusion/expr/src/logical_plan/tree_node.rs @@ -311,13 +311,13 @@ impl TreeNode for LogicalPlan { } LogicalPlan::Unnest(Unnest { input, - column, + columns, schema, options, }) => rewrite_arc(input, f)?.update_data(|input| { LogicalPlan::Unnest(Unnest { input, - column, + columns, schema, options, }) @@ -507,8 +507,12 @@ impl LogicalPlan { LogicalPlan::TableScan(TableScan { filters, .. }) => { filters.iter().apply_until_stop(f) } - LogicalPlan::Unnest(Unnest { column, .. }) => { - f(&Expr::Column(column.clone())) + LogicalPlan::Unnest(Unnest { columns, .. }) => { + let exprs = columns + .iter() + .map(|c| Expr::Column(c.clone())) + .collect::<Vec<_>>(); + exprs.iter().apply_until_stop(f) } LogicalPlan::Distinct(Distinct::On(DistinctOn { on_expr, @@ -706,20 +710,6 @@ impl LogicalPlan { fetch, }) }), - LogicalPlan::Unnest(Unnest { - input, - column, - schema, - options, - }) => f(Expr::Column(column))?.map_data(|column| match column { - Expr::Column(column) => Ok(LogicalPlan::Unnest(Unnest { - input, - column, - schema, - options, - })), - _ => internal_err!("Transformation should return Column"), - })?, LogicalPlan::Distinct(Distinct::On(DistinctOn { on_expr, select_expr, @@ -744,6 +734,7 @@ impl LogicalPlan { }), // plans without expressions LogicalPlan::EmptyRelation(_) + | LogicalPlan::Unnest(_) | LogicalPlan::RecursiveQuery(_) | LogicalPlan::Subquery(_) | LogicalPlan::SubqueryAlias(_) diff --git a/datafusion/expr/src/tree_node.rs b/datafusion/expr/src/tree_node.rs index 85097f6249e1..35fec509c95a 100644 --- a/datafusion/expr/src/tree_node.rs +++ b/datafusion/expr/src/tree_node.rs @@ -36,6 +36,7 @@ impl TreeNode for Expr { ) -> Result<TreeNodeRecursion> { let children = match self { Expr::Alias(Alias{expr,..}) + | Expr::Unnest(Unnest{expr}) | Expr::Not(expr) | Expr::IsNotNull(expr) | Expr::IsTrue(expr) | @@ -60,7 +61,6 @@ impl TreeNode for Expr { GetFieldAccess::NamedStructField { .. } => vec![expr], } } - Expr::Unnest(Unnest { exprs }) | Expr::GroupingSet(GroupingSet::Rollup(exprs)) | Expr::GroupingSet(GroupingSet::Cube(exprs)) => exprs.iter().collect(), Expr::ScalarFunction (ScalarFunction{ args, .. } ) => { diff --git a/datafusion/functions/src/math/factorial.rs b/datafusion/functions/src/math/factorial.rs new file mode 100644 index 000000000000..dc481da79069 --- /dev/null +++ b/datafusion/functions/src/math/factorial.rs @@ -0,0 +1,117 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License.
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{ArrayRef, Int64Array}; +use std::any::Any; +use std::sync::Arc; + +use arrow::datatypes::DataType; +use arrow::datatypes::DataType::Int64; + +use crate::utils::make_scalar_function; +use datafusion_common::{exec_err, DataFusionError, Result}; +use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; + +#[derive(Debug)] +pub struct FactorialFunc { + signature: Signature, +} + +impl Default for FactorialFunc { + fn default() -> Self { + FactorialFunc::new() + } +} + +impl FactorialFunc { + pub fn new() -> Self { + Self { + signature: Signature::uniform(1, vec![Int64], Volatility::Immutable), + } + } +} + +impl ScalarUDFImpl for FactorialFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "factorial" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> { + Ok(Int64) + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> { + make_scalar_function(factorial, vec![])(args) + } +} + +macro_rules! make_function_scalar_inputs { + ($ARG: expr, $NAME:expr, $ARRAY_TYPE:ident, $FUNC: block) => {{ + let arg = downcast_arg!($ARG, $NAME, $ARRAY_TYPE); + + arg.iter() + .map(|a| match a { + Some(a) => Some($FUNC(a)), + _ => None, + }) + .collect::<$ARRAY_TYPE>() + }}; +} + +/// Factorial SQL function +fn factorial(args: &[ArrayRef]) -> Result<ArrayRef> { + match args[0].data_type() { + DataType::Int64 => Ok(Arc::new(make_function_scalar_inputs!( + &args[0], + "value", + Int64Array, + { |value: i64| { (1..=value).product() } } + )) as ArrayRef), + other => exec_err!("Unsupported data type {other:?} for function factorial."), + } +} + +#[cfg(test)] +mod test { + + use datafusion_common::cast::as_int64_array; + + use super::*; + + #[test] + fn test_factorial_i64() { + let args: Vec<ArrayRef> = vec![ + Arc::new(Int64Array::from(vec![0, 1, 2, 4])), // input + ]; + + let result = factorial(&args).expect("failed to initialize function factorial"); + let ints = + as_int64_array(&result).expect("failed to initialize function factorial"); + + let expected = Int64Array::from(vec![1, 1, 2, 24]); + + assert_eq!(ints, &expected); + } +}
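One behavioral note on the kernel above: `(1..=value).product()` overflows `i64` as soon as `value` exceeds 20 (21! is larger than `i64::MAX`), which panics in debug builds and wraps silently in release builds. A checked variant (a sketch of an alternative, not what this patch does) would map overflow to `None` (SQL NULL) instead; the surrounding macro would then need to flatten the resulting nested `Option`:

// Returns None on overflow instead of wrapping or panicking.
|value: i64| (1..=value).try_fold(1_i64, i64::checked_mul)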
diff --git a/datafusion/functions/src/math/mod.rs b/datafusion/functions/src/math/mod.rs index c83a98cb1913..b6e8d26b6460 100644 --- a/datafusion/functions/src/math/mod.rs +++ b/datafusion/functions/src/math/mod.rs @@ -22,6 +22,7 @@ use std::sync::Arc; pub mod abs; pub mod cot; +pub mod factorial; pub mod gcd; pub mod iszero; pub mod lcm; @@ -44,10 +45,13 @@ make_math_unary_udf!(AtanFunc, ATAN, atan, atan, Some(vec![Some(true)])); make_math_unary_udf!(AtanhFunc, ATANH, atanh, atanh, Some(vec![Some(true)])); make_math_binary_udf!(Atan2, ATAN2, atan2, atan2, Some(vec![Some(true)])); make_math_unary_udf!(CbrtFunc, CBRT, cbrt, cbrt, None); +make_math_unary_udf!(CeilFunc, CEIL, ceil, ceil, Some(vec![Some(true)])); make_math_unary_udf!(CosFunc, COS, cos, cos, None); make_math_unary_udf!(CoshFunc, COSH, cosh, cosh, None); make_udf_function!(cot::CotFunc, COT, cot); make_math_unary_udf!(DegreesFunc, DEGREES, degrees, to_degrees, None); +make_math_unary_udf!(ExpFunc, EXP, exp, exp, Some(vec![Some(true)])); +make_udf_function!(factorial::FactorialFunc, FACTORIAL, factorial); make_math_unary_udf!(FloorFunc, FLOOR, floor, floor, Some(vec![Some(true)])); make_udf_function!(log::LogFunc, LOG, log); make_udf_function!(gcd::GcdFunc, GCD, gcd); @@ -119,6 +123,11 @@ pub mod expr_fn { super::cbrt().call(vec![num]) } + #[doc = "nearest integer greater than or equal to argument"] + pub fn ceil(num: Expr) -> Expr { + super::ceil().call(vec![num]) + } + #[doc = "cosine"] pub fn cos(num: Expr) -> Expr { super::cos().call(vec![num]) @@ -139,6 +148,16 @@ pub mod expr_fn { super::degrees().call(vec![num]) } + #[doc = "exponential"] + pub fn exp(num: Expr) -> Expr { + super::exp().call(vec![num]) + } + + #[doc = "factorial"] + pub fn factorial(num: Expr) -> Expr { + super::factorial().call(vec![num]) + } + #[doc = "nearest integer less than or equal to argument"] pub fn floor(num: Expr) -> Expr { super::floor().call(vec![num]) @@ -262,10 +281,13 @@ pub fn functions() -> Vec<Arc<ScalarUDF>> { atan2(), atanh(), cbrt(), + ceil(), cos(), cosh(), cot(), degrees(), + exp(), + factorial(), floor(), gcd(), isnan(), diff --git a/datafusion/optimizer/src/lib.rs b/datafusion/optimizer/src/lib.rs index f1f49727c39c..9176d67c1d18 100644 --- a/datafusion/optimizer/src/lib.rs +++ b/datafusion/optimizer/src/lib.rs @@ -47,7 +47,6 @@ pub mod optimizer; pub mod propagate_empty_relation; pub mod push_down_filter; pub mod push_down_limit; -pub mod push_down_projection; pub mod replace_distinct_aggregate; pub mod rewrite_disjunctive_predicate; pub mod scalar_subquery_to_join; diff --git a/datafusion/optimizer/src/optimize_projections.rs b/datafusion/optimizer/src/optimize_projections.rs index 6967b28f3037..b54fb248a7c7 100644 --- a/datafusion/optimizer/src/optimize_projections.rs +++ b/datafusion/optimizer/src/optimize_projections.rs @@ -925,20 +925,32 @@ fn is_projection_unnecessary(input: &LogicalPlan, proj_exprs: &[Expr]) -> Result<bool> #[cfg(test)] mod tests { + use std::collections::HashMap; use std::fmt::Formatter; use std::sync::Arc; + use std::vec; use crate::optimize_projections::OptimizeProjections; + use crate::optimizer::Optimizer; use crate::test::{ - assert_optimized_plan_eq, test_table_scan, test_table_scan_with_name, + assert_fields_eq, assert_optimized_plan_eq, scan_empty, test_table_scan, + test_table_scan_fields, test_table_scan_with_name, }; + use crate::{OptimizerContext, OptimizerRule}; use arrow::datatypes::{DataType, Field, Schema}; - use datafusion_common::{Column, DFSchemaRef, JoinType, Result, TableReference}; + use datafusion_common::{ + Column, DFSchema, DFSchemaRef, JoinType, Result, TableReference, + }; use datafusion_expr::{ - binary_expr, build_join_schema, col, count, lit, - logical_plan::builder::LogicalPlanBuilder, not, table_scan, try_cast, when, - BinaryExpr, Expr, Extension, Like, LogicalPlan, Operator, - UserDefinedLogicalNodeCore, + binary_expr, build_join_schema, + builder::table_scan_with_filters, + col, count, + expr::{self, Cast}, + lit, + logical_plan::{builder::LogicalPlanBuilder, table_scan}, + max, min, not, try_cast, when, AggregateFunction, BinaryExpr, Expr, Extension, + Like, LogicalPlan, Operator, Projection, UserDefinedLogicalNodeCore, WindowFrame, + WindowFunctionDefinition, }; fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> { @@ -1466,4 +1478,617 @@ mod tests { \n TableScan: r projection=[a]"; assert_optimized_plan_equal(plan, expected) } + + #[test] + fn aggregate_no_group_by() -> Result<()> {
+ let table_scan = test_table_scan()?; + + let plan = LogicalPlanBuilder::from(table_scan) + .aggregate(Vec::<Expr>::new(), vec![max(col("b"))])? + .build()?; + + let expected = "Aggregate: groupBy=[[]], aggr=[[MAX(test.b)]]\ + \n TableScan: test projection=[b]"; + + assert_optimized_plan_equal(plan, expected) + } + + #[test] + fn aggregate_group_by() -> Result<()> { + let table_scan = test_table_scan()?; + + let plan = LogicalPlanBuilder::from(table_scan) + .aggregate(vec![col("c")], vec![max(col("b"))])? + .build()?; + + let expected = "Aggregate: groupBy=[[test.c]], aggr=[[MAX(test.b)]]\ + \n TableScan: test projection=[b, c]"; + + assert_optimized_plan_equal(plan, expected) + } + + #[test] + fn aggregate_group_by_with_table_alias() -> Result<()> { + let table_scan = test_table_scan()?; + + let plan = LogicalPlanBuilder::from(table_scan) + .alias("a")? + .aggregate(vec![col("c")], vec![max(col("b"))])? + .build()?; + + let expected = "Aggregate: groupBy=[[a.c]], aggr=[[MAX(a.b)]]\ + \n SubqueryAlias: a\ + \n TableScan: test projection=[b, c]"; + + assert_optimized_plan_equal(plan, expected) + } + + #[test] + fn aggregate_no_group_by_with_filter() -> Result<()> { + let table_scan = test_table_scan()?; + + let plan = LogicalPlanBuilder::from(table_scan) + .filter(col("c").gt(lit(1)))? + .aggregate(Vec::<Expr>::new(), vec![max(col("b"))])? + .build()?; + + let expected = "Aggregate: groupBy=[[]], aggr=[[MAX(test.b)]]\ + \n Projection: test.b\ + \n Filter: test.c > Int32(1)\ + \n TableScan: test projection=[b, c]"; + + assert_optimized_plan_equal(plan, expected) + } + + #[test] + fn aggregate_with_periods() -> Result<()> { + let schema = Schema::new(vec![Field::new("tag.one", DataType::Utf8, false)]); + + // Build a plan that looks as follows (note "tag.one" is a column named + // "tag.one", not a column named "one" in a table named "tag"): + // + // Projection: tag.one + // Aggregate: groupBy=[], aggr=[MAX("tag.one") AS "tag.one"] + // TableScan + let plan = table_scan(Some("m4"), &schema, None)? + .aggregate( + Vec::<Expr>::new(), + vec![max(col(Column::new_unqualified("tag.one"))).alias("tag.one")], + )? + .project([col(Column::new_unqualified("tag.one"))])? + .build()?; + + let expected = "\ + Aggregate: groupBy=[[]], aggr=[[MAX(m4.tag.one) AS tag.one]]\ + \n TableScan: m4 projection=[tag.one]"; + + assert_optimized_plan_equal(plan, expected) + } + + #[test] + fn redundant_project() -> Result<()> { + let table_scan = test_table_scan()?; + + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![col("a"), col("b"), col("c")])? + .project(vec![col("a"), col("c"), col("b")])? + .build()?; + let expected = "Projection: test.a, test.c, test.b\ + \n TableScan: test projection=[a, b, c]"; + + assert_optimized_plan_equal(plan, expected) + } + + #[test] + fn reorder_scan() -> Result<()> { + let schema = Schema::new(test_table_scan_fields()); + + let plan = table_scan(Some("test"), &schema, Some(vec![1, 0, 2]))?.build()?; + let expected = "TableScan: test projection=[b, a, c]"; + + assert_optimized_plan_equal(plan, expected) + } + + #[test] + fn reorder_scan_projection() -> Result<()> { + let schema = Schema::new(test_table_scan_fields()); + + let plan = table_scan(Some("test"), &schema, Some(vec![1, 0, 2]))? + .project(vec![col("a"), col("b")])?
+ .build()?; + let expected = "Projection: test.a, test.b\ + \n TableScan: test projection=[b, a]"; + + assert_optimized_plan_equal(plan, expected) + } + + #[test] + fn reorder_projection() -> Result<()> { + let table_scan = test_table_scan()?; + + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![col("c"), col("b"), col("a")])? + .build()?; + let expected = "Projection: test.c, test.b, test.a\ + \n TableScan: test projection=[a, b, c]"; + + assert_optimized_plan_equal(plan, expected) + } + + #[test] + fn noncontinuous_redundant_projection() -> Result<()> { + let table_scan = test_table_scan()?; + + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![col("c"), col("b"), col("a")])? + .filter(col("c").gt(lit(1)))? + .project(vec![col("c"), col("a"), col("b")])? + .filter(col("b").gt(lit(1)))? + .filter(col("a").gt(lit(1)))? + .project(vec![col("a"), col("c"), col("b")])? + .build()?; + let expected = "Projection: test.a, test.c, test.b\ + \n Filter: test.a > Int32(1)\ + \n Filter: test.b > Int32(1)\ + \n Projection: test.c, test.a, test.b\ + \n Filter: test.c > Int32(1)\ + \n Projection: test.c, test.b, test.a\ + \n TableScan: test projection=[a, b, c]"; + assert_optimized_plan_equal(plan, expected) + } + + #[test] + fn join_schema_trim_full_join_column_projection() -> Result<()> { + let table_scan = test_table_scan()?; + + let schema = Schema::new(vec![Field::new("c1", DataType::UInt32, false)]); + let table2_scan = scan_empty(Some("test2"), &schema, None)?.build()?; + + let plan = LogicalPlanBuilder::from(table_scan) + .join(table2_scan, JoinType::Left, (vec!["a"], vec!["c1"]), None)? + .project(vec![col("a"), col("b"), col("c1")])? + .build()?; + + // make sure projections are pushed down to both table scans + let expected = "Left Join: test.a = test2.c1\ + \n TableScan: test projection=[a, b]\ + \n TableScan: test2 projection=[c1]"; + + let optimized_plan = optimize(plan)?; + let formatted_plan = format!("{optimized_plan:?}"); + assert_eq!(formatted_plan, expected); + + // make sure schema for join node include both join columns + let optimized_join = optimized_plan; + assert_eq!( + **optimized_join.schema(), + DFSchema::new_with_metadata( + vec![ + ( + Some("test".into()), + Arc::new(Field::new("a", DataType::UInt32, false)) + ), + ( + Some("test".into()), + Arc::new(Field::new("b", DataType::UInt32, false)) + ), + ( + Some("test2".into()), + Arc::new(Field::new("c1", DataType::UInt32, true)) + ), + ], + HashMap::new() + )?, + ); + + Ok(()) + } + + #[test] + fn join_schema_trim_partial_join_column_projection() -> Result<()> { + // test join column push down without explicit column projections + + let table_scan = test_table_scan()?; + + let schema = Schema::new(vec![Field::new("c1", DataType::UInt32, false)]); + let table2_scan = scan_empty(Some("test2"), &schema, None)?.build()?; + + let plan = LogicalPlanBuilder::from(table_scan) + .join(table2_scan, JoinType::Left, (vec!["a"], vec!["c1"]), None)? + // projecting joined column `a` should push the right side column `c1` projection as + // well into test2 table even though `c1` is not referenced in projection. + .project(vec![col("a"), col("b")])? 
+ .build()?; + + // make sure projections are pushed down to both table scans + let expected = "Projection: test.a, test.b\ + \n Left Join: test.a = test2.c1\ + \n TableScan: test projection=[a, b]\ + \n TableScan: test2 projection=[c1]"; + + let optimized_plan = optimize(plan)?; + let formatted_plan = format!("{optimized_plan:?}"); + assert_eq!(formatted_plan, expected); + + // make sure schema for join node include both join columns + let optimized_join = optimized_plan.inputs()[0]; + assert_eq!( + **optimized_join.schema(), + DFSchema::new_with_metadata( + vec![ + ( + Some("test".into()), + Arc::new(Field::new("a", DataType::UInt32, false)) + ), + ( + Some("test".into()), + Arc::new(Field::new("b", DataType::UInt32, false)) + ), + ( + Some("test2".into()), + Arc::new(Field::new("c1", DataType::UInt32, true)) + ), + ], + HashMap::new() + )?, + ); + + Ok(()) + } + + #[test] + fn join_schema_trim_using_join() -> Result<()> { + // shared join columns from using join should be pushed to both sides + + let table_scan = test_table_scan()?; + + let schema = Schema::new(vec![Field::new("a", DataType::UInt32, false)]); + let table2_scan = scan_empty(Some("test2"), &schema, None)?.build()?; + + let plan = LogicalPlanBuilder::from(table_scan) + .join_using(table2_scan, JoinType::Left, vec!["a"])? + .project(vec![col("a"), col("b")])? + .build()?; + + // make sure projections are pushed down to table scan + let expected = "Projection: test.a, test.b\ + \n Left Join: Using test.a = test2.a\ + \n TableScan: test projection=[a, b]\ + \n TableScan: test2 projection=[a]"; + + let optimized_plan = optimize(plan)?; + let formatted_plan = format!("{optimized_plan:?}"); + assert_eq!(formatted_plan, expected); + + // make sure schema for join node include both join columns + let optimized_join = optimized_plan.inputs()[0]; + assert_eq!( + **optimized_join.schema(), + DFSchema::new_with_metadata( + vec![ + ( + Some("test".into()), + Arc::new(Field::new("a", DataType::UInt32, false)) + ), + ( + Some("test".into()), + Arc::new(Field::new("b", DataType::UInt32, false)) + ), + ( + Some("test2".into()), + Arc::new(Field::new("a", DataType::UInt32, true)) + ), + ], + HashMap::new() + )?, + ); + + Ok(()) + } + + #[test] + fn cast() -> Result<()> { + let table_scan = test_table_scan()?; + + let projection = LogicalPlanBuilder::from(table_scan) + .project(vec![Expr::Cast(Cast::new( + Box::new(col("c")), + DataType::Float64, + ))])? + .build()?; + + let expected = "Projection: CAST(test.c AS Float64)\ + \n TableScan: test projection=[c]"; + + assert_optimized_plan_equal(projection, expected) + } + + #[test] + fn table_scan_projected_schema() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(test_table_scan()?) + .project(vec![col("a"), col("b")])? + .build()?; + + assert_eq!(3, table_scan.schema().fields().len()); + assert_fields_eq(&table_scan, vec!["a", "b", "c"]); + assert_fields_eq(&plan, vec!["a", "b"]); + + let expected = "TableScan: test projection=[a, b]"; + + assert_optimized_plan_equal(plan, expected) + } + + #[test] + fn table_scan_projected_schema_non_qualified_relation() -> Result<()> { + let table_scan = test_table_scan()?; + let input_schema = table_scan.schema(); + assert_eq!(3, input_schema.fields().len()); + assert_fields_eq(&table_scan, vec!["a", "b", "c"]); + + // Build the LogicalPlan directly (don't use PlanBuilder), so + // that the Column references are unqualified (e.g. their + // relation is `None`). 
PlanBuilder resolves the expressions + let expr = vec![col("test.a"), col("test.b")]; + let plan = + LogicalPlan::Projection(Projection::try_new(expr, Arc::new(table_scan))?); + + assert_fields_eq(&plan, vec!["a", "b"]); + + let expected = "TableScan: test projection=[a, b]"; + + assert_optimized_plan_equal(plan, expected) + } + + #[test] + fn table_limit() -> Result<()> { + let table_scan = test_table_scan()?; + assert_eq!(3, table_scan.schema().fields().len()); + assert_fields_eq(&table_scan, vec!["a", "b", "c"]); + + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![col("c"), col("a")])? + .limit(0, Some(5))? + .build()?; + + assert_fields_eq(&plan, vec!["c", "a"]); + + let expected = "Limit: skip=0, fetch=5\ + \n Projection: test.c, test.a\ + \n TableScan: test projection=[a, c]"; + + assert_optimized_plan_equal(plan, expected) + } + + #[test] + fn table_scan_without_projection() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan).build()?; + // should expand projection to all columns without projection + let expected = "TableScan: test projection=[a, b, c]"; + assert_optimized_plan_equal(plan, expected) + } + + #[test] + fn table_scan_with_literal_projection() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![lit(1_i64), lit(2_i64)])? + .build()?; + let expected = "Projection: Int64(1), Int64(2)\ + \n TableScan: test projection=[]"; + assert_optimized_plan_equal(plan, expected) + } + + /// tests that it removes unused columns in projections + #[test] + fn table_unused_column() -> Result<()> { + let table_scan = test_table_scan()?; + assert_eq!(3, table_scan.schema().fields().len()); + assert_fields_eq(&table_scan, vec!["a", "b", "c"]); + + // we never use "b" in the first projection => remove it + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![col("c"), col("a"), col("b")])? + .filter(col("c").gt(lit(1)))? + .aggregate(vec![col("c")], vec![max(col("a"))])? + .build()?; + + assert_fields_eq(&plan, vec!["c", "MAX(test.a)"]); + + let plan = optimize(plan).expect("failed to optimize plan"); + let expected = "\ + Aggregate: groupBy=[[test.c]], aggr=[[MAX(test.a)]]\ + \n Filter: test.c > Int32(1)\ + \n Projection: test.c, test.a\ + \n TableScan: test projection=[a, c]"; + + assert_optimized_plan_equal(plan, expected) + } + + /// tests that it removes un-needed projections + #[test] + fn table_unused_projection() -> Result<()> { + let table_scan = test_table_scan()?; + assert_eq!(3, table_scan.schema().fields().len()); + assert_fields_eq(&table_scan, vec!["a", "b", "c"]); + + // there is no need for the first projection + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![col("b")])? + .project(vec![lit(1).alias("a")])? + .build()?; + + assert_fields_eq(&plan, vec!["a"]); + + let expected = "\ + Projection: Int32(1) AS a\ + \n TableScan: test projection=[]"; + + assert_optimized_plan_equal(plan, expected) + } + + #[test] + fn table_full_filter_pushdown() -> Result<()> { + let schema = Schema::new(test_table_scan_fields()); + + let table_scan = table_scan_with_filters( + Some("test"), + &schema, + None, + vec![col("b").eq(lit(1))], + )? + .build()?; + assert_eq!(3, table_scan.schema().fields().len()); + assert_fields_eq(&table_scan, vec!["a", "b", "c"]); + + // there is no need for the first projection + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![col("b")])? + .project(vec![lit(1).alias("a")])? 
+ .build()?; + + assert_fields_eq(&plan, vec!["a"]); + + let expected = "\ + Projection: Int32(1) AS a\ + \n TableScan: test projection=[], full_filters=[b = Int32(1)]"; + + assert_optimized_plan_equal(plan, expected) + } + + /// tests that optimizing twice yields same plan + #[test] + fn test_double_optimization() -> Result<()> { + let table_scan = test_table_scan()?; + + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![col("b")])? + .project(vec![lit(1).alias("a")])? + .build()?; + + let optimized_plan1 = optimize(plan).expect("failed to optimize plan"); + let optimized_plan2 = + optimize(optimized_plan1.clone()).expect("failed to optimize plan"); + + let formatted_plan1 = format!("{optimized_plan1:?}"); + let formatted_plan2 = format!("{optimized_plan2:?}"); + assert_eq!(formatted_plan1, formatted_plan2); + Ok(()) + } + + /// tests that it removes an aggregate is never used downstream + #[test] + fn table_unused_aggregate() -> Result<()> { + let table_scan = test_table_scan()?; + assert_eq!(3, table_scan.schema().fields().len()); + assert_fields_eq(&table_scan, vec!["a", "b", "c"]); + + // we never use "min(b)" => remove it + let plan = LogicalPlanBuilder::from(table_scan) + .aggregate(vec![col("a"), col("c")], vec![max(col("b")), min(col("b"))])? + .filter(col("c").gt(lit(1)))? + .project(vec![col("c"), col("a"), col("MAX(test.b)")])? + .build()?; + + assert_fields_eq(&plan, vec!["c", "a", "MAX(test.b)"]); + + let expected = "Projection: test.c, test.a, MAX(test.b)\ + \n Filter: test.c > Int32(1)\ + \n Aggregate: groupBy=[[test.a, test.c]], aggr=[[MAX(test.b)]]\ + \n TableScan: test projection=[a, b, c]"; + + assert_optimized_plan_equal(plan, expected) + } + + #[test] + fn aggregate_filter_pushdown() -> Result<()> { + let table_scan = test_table_scan()?; + + let aggr_with_filter = Expr::AggregateFunction(expr::AggregateFunction::new( + AggregateFunction::Count, + vec![col("b")], + false, + Some(Box::new(col("c").gt(lit(42)))), + None, + None, + )); + + let plan = LogicalPlanBuilder::from(table_scan) + .aggregate( + vec![col("a")], + vec![count(col("b")), aggr_with_filter.alias("count2")], + )? + .build()?; + + let expected = "Aggregate: groupBy=[[test.a]], aggr=[[COUNT(test.b), COUNT(test.b) FILTER (WHERE test.c > Int32(42)) AS count2]]\ + \n TableScan: test projection=[a, b, c]"; + + assert_optimized_plan_equal(plan, expected) + } + + #[test] + fn pushdown_through_distinct() -> Result<()> { + let table_scan = test_table_scan()?; + + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![col("a"), col("b")])? + .distinct()? + .project(vec![col("a")])? + .build()?; + + let expected = "Projection: test.a\ + \n Distinct:\ + \n TableScan: test projection=[a, b]"; + + assert_optimized_plan_equal(plan, expected) + } + + #[test] + fn test_window() -> Result<()> { + let table_scan = test_table_scan()?; + + let max1 = Expr::WindowFunction(expr::WindowFunction::new( + WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), + vec![col("test.a")], + vec![col("test.b")], + vec![], + WindowFrame::new(None), + None, + )); + + let max2 = Expr::WindowFunction(expr::WindowFunction::new( + WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), + vec![col("test.b")], + vec![], + vec![], + WindowFrame::new(None), + None, + )); + let col1 = col(max1.display_name()?); + let col2 = col(max2.display_name()?); + + let plan = LogicalPlanBuilder::from(table_scan) + .window(vec![max1])? + .window(vec![max2])? + .project(vec![col1, col2])? 
+ .build()?; + + let expected = "Projection: MAX(test.a) PARTITION BY [test.b] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING, MAX(test.b) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING\ + \n WindowAggr: windowExpr=[[MAX(test.b) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]]\ + \n Projection: test.b, MAX(test.a) PARTITION BY [test.b] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING\ + \n WindowAggr: windowExpr=[[MAX(test.a) PARTITION BY [test.b] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]]\ + \n TableScan: test projection=[a, b]"; + + assert_optimized_plan_equal(plan, expected) + } + + fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {} + + fn optimize(plan: LogicalPlan) -> Result<LogicalPlan> { + let optimizer = Optimizer::with_rules(vec![Arc::new(OptimizeProjections::new())]); + let optimized_plan = + optimizer.optimize(plan, &OptimizerContext::new(), observe)?; + Ok(optimized_plan) + } } diff --git a/datafusion/optimizer/src/push_down_projection.rs b/datafusion/optimizer/src/push_down_projection.rs deleted file mode 100644 index 2f578094b3bc..000000000000 --- a/datafusion/optimizer/src/push_down_projection.rs +++ /dev/null @@ -1,660 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -#[cfg(test)] -mod tests { - use std::collections::HashMap; - use std::sync::Arc; - use std::vec; - - use crate::optimize_projections::OptimizeProjections; - use crate::optimizer::Optimizer; - use crate::test::*; - use crate::{OptimizerContext, OptimizerRule}; - use arrow::datatypes::{DataType, Field, Schema}; - use datafusion_common::{Column, DFSchema, Result}; - use datafusion_expr::builder::table_scan_with_filters; - use datafusion_expr::expr::{self, Cast}; - use datafusion_expr::logical_plan::{ - builder::LogicalPlanBuilder, table_scan, JoinType, - }; - use datafusion_expr::{ - col, count, lit, max, min, AggregateFunction, Expr, LogicalPlan, Projection, - WindowFrame, WindowFunctionDefinition, - }; - - #[test] - fn aggregate_no_group_by() -> Result<()> { - let table_scan = test_table_scan()?; - - let plan = LogicalPlanBuilder::from(table_scan) - .aggregate(Vec::<Expr>::new(), vec![max(col("b"))])? - .build()?; - - let expected = "Aggregate: groupBy=[[]], aggr=[[MAX(test.b)]]\ - \n TableScan: test projection=[b]"; - - assert_optimized_plan_eq(plan, expected) - } - - #[test] - fn aggregate_group_by() -> Result<()> { - let table_scan = test_table_scan()?; - - let plan = LogicalPlanBuilder::from(table_scan) - .aggregate(vec![col("c")], vec![max(col("b"))])?
- .build()?; - - let expected = "Aggregate: groupBy=[[test.c]], aggr=[[MAX(test.b)]]\ - \n TableScan: test projection=[b, c]"; - - assert_optimized_plan_eq(plan, expected) - } - - #[test] - fn aggregate_group_by_with_table_alias() -> Result<()> { - let table_scan = test_table_scan()?; - - let plan = LogicalPlanBuilder::from(table_scan) - .alias("a")? - .aggregate(vec![col("c")], vec![max(col("b"))])? - .build()?; - - let expected = "Aggregate: groupBy=[[a.c]], aggr=[[MAX(a.b)]]\ - \n SubqueryAlias: a\ - \n TableScan: test projection=[b, c]"; - - assert_optimized_plan_eq(plan, expected) - } - - #[test] - fn aggregate_no_group_by_with_filter() -> Result<()> { - let table_scan = test_table_scan()?; - - let plan = LogicalPlanBuilder::from(table_scan) - .filter(col("c").gt(lit(1)))? - .aggregate(Vec::<Expr>::new(), vec![max(col("b"))])? - .build()?; - - let expected = "Aggregate: groupBy=[[]], aggr=[[MAX(test.b)]]\ - \n Projection: test.b\ - \n Filter: test.c > Int32(1)\ - \n TableScan: test projection=[b, c]"; - - assert_optimized_plan_eq(plan, expected) - } - - #[test] - fn aggregate_with_periods() -> Result<()> { - let schema = Schema::new(vec![Field::new("tag.one", DataType::Utf8, false)]); - - // Build a plan that looks as follows (note "tag.one" is a column named - // "tag.one", not a column named "one" in a table named "tag"): - // - // Projection: tag.one - // Aggregate: groupBy=[], aggr=[MAX("tag.one") AS "tag.one"] - // TableScan - let plan = table_scan(Some("m4"), &schema, None)? - .aggregate( - Vec::<Expr>::new(), - vec![max(col(Column::new_unqualified("tag.one"))).alias("tag.one")], - )? - .project([col(Column::new_unqualified("tag.one"))])? - .build()?; - - let expected = "\ - Aggregate: groupBy=[[]], aggr=[[MAX(m4.tag.one) AS tag.one]]\ - \n TableScan: m4 projection=[tag.one]"; - - assert_optimized_plan_eq(plan, expected) - } - - #[test] - fn redundant_project() -> Result<()> { - let table_scan = test_table_scan()?; - - let plan = LogicalPlanBuilder::from(table_scan) - .project(vec![col("a"), col("b"), col("c")])? - .project(vec![col("a"), col("c"), col("b")])? - .build()?; - let expected = "Projection: test.a, test.c, test.b\ - \n TableScan: test projection=[a, b, c]"; - - assert_optimized_plan_eq(plan, expected) - } - - #[test] - fn reorder_scan() -> Result<()> { - let schema = Schema::new(test_table_scan_fields()); - - let plan = table_scan(Some("test"), &schema, Some(vec![1, 0, 2]))?.build()?; - let expected = "TableScan: test projection=[b, a, c]"; - - assert_optimized_plan_eq(plan, expected) - } - - #[test] - fn reorder_scan_projection() -> Result<()> { - let schema = Schema::new(test_table_scan_fields()); - - let plan = table_scan(Some("test"), &schema, Some(vec![1, 0, 2]))? - .project(vec![col("a"), col("b")])? - .build()?; - let expected = "Projection: test.a, test.b\ - \n TableScan: test projection=[b, a]"; - - assert_optimized_plan_eq(plan, expected) - } - - #[test] - fn reorder_projection() -> Result<()> { - let table_scan = test_table_scan()?; - - let plan = LogicalPlanBuilder::from(table_scan) - .project(vec![col("c"), col("b"), col("a")])? - .build()?; - let expected = "Projection: test.c, test.b, test.a\ - \n TableScan: test projection=[a, b, c]"; - - assert_optimized_plan_eq(plan, expected) - } - - #[test] - fn noncontinuous_redundant_projection() -> Result<()> { - let table_scan = test_table_scan()?; - - let plan = LogicalPlanBuilder::from(table_scan) - .project(vec![col("c"), col("b"), col("a")])? - .filter(col("c").gt(lit(1)))?
- .project(vec![col("c"), col("a"), col("b")])? - .filter(col("b").gt(lit(1)))? - .filter(col("a").gt(lit(1)))? - .project(vec![col("a"), col("c"), col("b")])? - .build()?; - let expected = "Projection: test.a, test.c, test.b\ - \n Filter: test.a > Int32(1)\ - \n Filter: test.b > Int32(1)\ - \n Projection: test.c, test.a, test.b\ - \n Filter: test.c > Int32(1)\ - \n Projection: test.c, test.b, test.a\ - \n TableScan: test projection=[a, b, c]"; - assert_optimized_plan_eq(plan, expected) - } - - #[test] - fn join_schema_trim_full_join_column_projection() -> Result<()> { - let table_scan = test_table_scan()?; - - let schema = Schema::new(vec![Field::new("c1", DataType::UInt32, false)]); - let table2_scan = scan_empty(Some("test2"), &schema, None)?.build()?; - - let plan = LogicalPlanBuilder::from(table_scan) - .join(table2_scan, JoinType::Left, (vec!["a"], vec!["c1"]), None)? - .project(vec![col("a"), col("b"), col("c1")])? - .build()?; - - // make sure projections are pushed down to both table scans - let expected = "Left Join: test.a = test2.c1\ - \n TableScan: test projection=[a, b]\ - \n TableScan: test2 projection=[c1]"; - - let optimized_plan = optimize(plan)?; - let formatted_plan = format!("{optimized_plan:?}"); - assert_eq!(formatted_plan, expected); - - // make sure schema for join node include both join columns - let optimized_join = optimized_plan; - assert_eq!( - **optimized_join.schema(), - DFSchema::new_with_metadata( - vec![ - ( - Some("test".into()), - Arc::new(Field::new("a", DataType::UInt32, false)) - ), - ( - Some("test".into()), - Arc::new(Field::new("b", DataType::UInt32, false)) - ), - ( - Some("test2".into()), - Arc::new(Field::new("c1", DataType::UInt32, true)) - ), - ], - HashMap::new() - )?, - ); - - Ok(()) - } - - #[test] - fn join_schema_trim_partial_join_column_projection() -> Result<()> { - // test join column push down without explicit column projections - - let table_scan = test_table_scan()?; - - let schema = Schema::new(vec![Field::new("c1", DataType::UInt32, false)]); - let table2_scan = scan_empty(Some("test2"), &schema, None)?.build()?; - - let plan = LogicalPlanBuilder::from(table_scan) - .join(table2_scan, JoinType::Left, (vec!["a"], vec!["c1"]), None)? - // projecting joined column `a` should push the right side column `c1` projection as - // well into test2 table even though `c1` is not referenced in projection. - .project(vec![col("a"), col("b")])? 
- .build()?; - - // make sure projections are pushed down to both table scans - let expected = "Projection: test.a, test.b\ - \n Left Join: test.a = test2.c1\ - \n TableScan: test projection=[a, b]\ - \n TableScan: test2 projection=[c1]"; - - let optimized_plan = optimize(plan)?; - let formatted_plan = format!("{optimized_plan:?}"); - assert_eq!(formatted_plan, expected); - - // make sure schema for join node include both join columns - let optimized_join = optimized_plan.inputs()[0]; - assert_eq!( - **optimized_join.schema(), - DFSchema::new_with_metadata( - vec![ - ( - Some("test".into()), - Arc::new(Field::new("a", DataType::UInt32, false)) - ), - ( - Some("test".into()), - Arc::new(Field::new("b", DataType::UInt32, false)) - ), - ( - Some("test2".into()), - Arc::new(Field::new("c1", DataType::UInt32, true)) - ), - ], - HashMap::new() - )?, - ); - - Ok(()) - } - - #[test] - fn join_schema_trim_using_join() -> Result<()> { - // shared join columns from using join should be pushed to both sides - - let table_scan = test_table_scan()?; - - let schema = Schema::new(vec![Field::new("a", DataType::UInt32, false)]); - let table2_scan = scan_empty(Some("test2"), &schema, None)?.build()?; - - let plan = LogicalPlanBuilder::from(table_scan) - .join_using(table2_scan, JoinType::Left, vec!["a"])? - .project(vec![col("a"), col("b")])? - .build()?; - - // make sure projections are pushed down to table scan - let expected = "Projection: test.a, test.b\ - \n Left Join: Using test.a = test2.a\ - \n TableScan: test projection=[a, b]\ - \n TableScan: test2 projection=[a]"; - - let optimized_plan = optimize(plan)?; - let formatted_plan = format!("{optimized_plan:?}"); - assert_eq!(formatted_plan, expected); - - // make sure schema for join node include both join columns - let optimized_join = optimized_plan.inputs()[0]; - assert_eq!( - **optimized_join.schema(), - DFSchema::new_with_metadata( - vec![ - ( - Some("test".into()), - Arc::new(Field::new("a", DataType::UInt32, false)) - ), - ( - Some("test".into()), - Arc::new(Field::new("b", DataType::UInt32, false)) - ), - ( - Some("test2".into()), - Arc::new(Field::new("a", DataType::UInt32, true)) - ), - ], - HashMap::new() - )?, - ); - - Ok(()) - } - - #[test] - fn cast() -> Result<()> { - let table_scan = test_table_scan()?; - - let projection = LogicalPlanBuilder::from(table_scan) - .project(vec![Expr::Cast(Cast::new( - Box::new(col("c")), - DataType::Float64, - ))])? - .build()?; - - let expected = "Projection: CAST(test.c AS Float64)\ - \n TableScan: test projection=[c]"; - - assert_optimized_plan_eq(projection, expected) - } - - #[test] - fn table_scan_projected_schema() -> Result<()> { - let table_scan = test_table_scan()?; - let plan = LogicalPlanBuilder::from(test_table_scan()?) - .project(vec![col("a"), col("b")])? - .build()?; - - assert_eq!(3, table_scan.schema().fields().len()); - assert_fields_eq(&table_scan, vec!["a", "b", "c"]); - assert_fields_eq(&plan, vec!["a", "b"]); - - let expected = "TableScan: test projection=[a, b]"; - - assert_optimized_plan_eq(plan, expected) - } - - #[test] - fn table_scan_projected_schema_non_qualified_relation() -> Result<()> { - let table_scan = test_table_scan()?; - let input_schema = table_scan.schema(); - assert_eq!(3, input_schema.fields().len()); - assert_fields_eq(&table_scan, vec!["a", "b", "c"]); - - // Build the LogicalPlan directly (don't use PlanBuilder), so - // that the Column references are unqualified (e.g. their - // relation is `None`). 
PlanBuilder resolves the expressions - let expr = vec![col("test.a"), col("test.b")]; - let plan = - LogicalPlan::Projection(Projection::try_new(expr, Arc::new(table_scan))?); - - assert_fields_eq(&plan, vec!["a", "b"]); - - let expected = "TableScan: test projection=[a, b]"; - - assert_optimized_plan_eq(plan, expected) - } - - #[test] - fn table_limit() -> Result<()> { - let table_scan = test_table_scan()?; - assert_eq!(3, table_scan.schema().fields().len()); - assert_fields_eq(&table_scan, vec!["a", "b", "c"]); - - let plan = LogicalPlanBuilder::from(table_scan) - .project(vec![col("c"), col("a")])? - .limit(0, Some(5))? - .build()?; - - assert_fields_eq(&plan, vec!["c", "a"]); - - let expected = "Limit: skip=0, fetch=5\ - \n Projection: test.c, test.a\ - \n TableScan: test projection=[a, c]"; - - assert_optimized_plan_eq(plan, expected) - } - - #[test] - fn table_scan_without_projection() -> Result<()> { - let table_scan = test_table_scan()?; - let plan = LogicalPlanBuilder::from(table_scan).build()?; - // should expand projection to all columns without projection - let expected = "TableScan: test projection=[a, b, c]"; - assert_optimized_plan_eq(plan, expected) - } - - #[test] - fn table_scan_with_literal_projection() -> Result<()> { - let table_scan = test_table_scan()?; - let plan = LogicalPlanBuilder::from(table_scan) - .project(vec![lit(1_i64), lit(2_i64)])? - .build()?; - let expected = "Projection: Int64(1), Int64(2)\ - \n TableScan: test projection=[]"; - assert_optimized_plan_eq(plan, expected) - } - - /// tests that it removes unused columns in projections - #[test] - fn table_unused_column() -> Result<()> { - let table_scan = test_table_scan()?; - assert_eq!(3, table_scan.schema().fields().len()); - assert_fields_eq(&table_scan, vec!["a", "b", "c"]); - - // we never use "b" in the first projection => remove it - let plan = LogicalPlanBuilder::from(table_scan) - .project(vec![col("c"), col("a"), col("b")])? - .filter(col("c").gt(lit(1)))? - .aggregate(vec![col("c")], vec![max(col("a"))])? - .build()?; - - assert_fields_eq(&plan, vec!["c", "MAX(test.a)"]); - - let plan = optimize(plan).expect("failed to optimize plan"); - let expected = "\ - Aggregate: groupBy=[[test.c]], aggr=[[MAX(test.a)]]\ - \n Filter: test.c > Int32(1)\ - \n Projection: test.c, test.a\ - \n TableScan: test projection=[a, c]"; - - assert_optimized_plan_eq(plan, expected) - } - - /// tests that it removes un-needed projections - #[test] - fn table_unused_projection() -> Result<()> { - let table_scan = test_table_scan()?; - assert_eq!(3, table_scan.schema().fields().len()); - assert_fields_eq(&table_scan, vec!["a", "b", "c"]); - - // there is no need for the first projection - let plan = LogicalPlanBuilder::from(table_scan) - .project(vec![col("b")])? - .project(vec![lit(1).alias("a")])? - .build()?; - - assert_fields_eq(&plan, vec!["a"]); - - let expected = "\ - Projection: Int32(1) AS a\ - \n TableScan: test projection=[]"; - - assert_optimized_plan_eq(plan, expected) - } - - #[test] - fn table_full_filter_pushdown() -> Result<()> { - let schema = Schema::new(test_table_scan_fields()); - - let table_scan = table_scan_with_filters( - Some("test"), - &schema, - None, - vec![col("b").eq(lit(1))], - )? - .build()?; - assert_eq!(3, table_scan.schema().fields().len()); - assert_fields_eq(&table_scan, vec!["a", "b", "c"]); - - // there is no need for the first projection - let plan = LogicalPlanBuilder::from(table_scan) - .project(vec![col("b")])? - .project(vec![lit(1).alias("a")])? 
- .build()?; - - assert_fields_eq(&plan, vec!["a"]); - - let expected = "\ - Projection: Int32(1) AS a\ - \n TableScan: test projection=[], full_filters=[b = Int32(1)]"; - - assert_optimized_plan_eq(plan, expected) - } - - /// tests that optimizing twice yields same plan - #[test] - fn test_double_optimization() -> Result<()> { - let table_scan = test_table_scan()?; - - let plan = LogicalPlanBuilder::from(table_scan) - .project(vec![col("b")])? - .project(vec![lit(1).alias("a")])? - .build()?; - - let optimized_plan1 = optimize(plan).expect("failed to optimize plan"); - let optimized_plan2 = - optimize(optimized_plan1.clone()).expect("failed to optimize plan"); - - let formatted_plan1 = format!("{optimized_plan1:?}"); - let formatted_plan2 = format!("{optimized_plan2:?}"); - assert_eq!(formatted_plan1, formatted_plan2); - Ok(()) - } - - /// tests that it removes an aggregate is never used downstream - #[test] - fn table_unused_aggregate() -> Result<()> { - let table_scan = test_table_scan()?; - assert_eq!(3, table_scan.schema().fields().len()); - assert_fields_eq(&table_scan, vec!["a", "b", "c"]); - - // we never use "min(b)" => remove it - let plan = LogicalPlanBuilder::from(table_scan) - .aggregate(vec![col("a"), col("c")], vec![max(col("b")), min(col("b"))])? - .filter(col("c").gt(lit(1)))? - .project(vec![col("c"), col("a"), col("MAX(test.b)")])? - .build()?; - - assert_fields_eq(&plan, vec!["c", "a", "MAX(test.b)"]); - - let expected = "Projection: test.c, test.a, MAX(test.b)\ - \n Filter: test.c > Int32(1)\ - \n Aggregate: groupBy=[[test.a, test.c]], aggr=[[MAX(test.b)]]\ - \n TableScan: test projection=[a, b, c]"; - - assert_optimized_plan_eq(plan, expected) - } - - #[test] - fn aggregate_filter_pushdown() -> Result<()> { - let table_scan = test_table_scan()?; - - let aggr_with_filter = Expr::AggregateFunction(expr::AggregateFunction::new( - AggregateFunction::Count, - vec![col("b")], - false, - Some(Box::new(col("c").gt(lit(42)))), - None, - None, - )); - - let plan = LogicalPlanBuilder::from(table_scan) - .aggregate( - vec![col("a")], - vec![count(col("b")), aggr_with_filter.alias("count2")], - )? - .build()?; - - let expected = "Aggregate: groupBy=[[test.a]], aggr=[[COUNT(test.b), COUNT(test.b) FILTER (WHERE test.c > Int32(42)) AS count2]]\ - \n TableScan: test projection=[a, b, c]"; - - assert_optimized_plan_eq(plan, expected) - } - - #[test] - fn pushdown_through_distinct() -> Result<()> { - let table_scan = test_table_scan()?; - - let plan = LogicalPlanBuilder::from(table_scan) - .project(vec![col("a"), col("b")])? - .distinct()? - .project(vec![col("a")])? - .build()?; - - let expected = "Projection: test.a\ - \n Distinct:\ - \n TableScan: test projection=[a, b]"; - - assert_optimized_plan_eq(plan, expected) - } - - #[test] - fn test_window() -> Result<()> { - let table_scan = test_table_scan()?; - - let max1 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), - vec![col("test.a")], - vec![col("test.b")], - vec![], - WindowFrame::new(None), - None, - )); - - let max2 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), - vec![col("test.b")], - vec![], - vec![], - WindowFrame::new(None), - None, - )); - let col1 = col(max1.display_name()?); - let col2 = col(max2.display_name()?); - - let plan = LogicalPlanBuilder::from(table_scan) - .window(vec![max1])? - .window(vec![max2])? - .project(vec![col1, col2])? 
- .build()?; - - let expected = "Projection: MAX(test.a) PARTITION BY [test.b] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING, MAX(test.b) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING\ - \n WindowAggr: windowExpr=[[MAX(test.b) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]]\ - \n Projection: test.b, MAX(test.a) PARTITION BY [test.b] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING\ - \n WindowAggr: windowExpr=[[MAX(test.a) PARTITION BY [test.b] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]]\ - \n TableScan: test projection=[a, b]"; - - assert_optimized_plan_eq(plan, expected) - } - - fn assert_optimized_plan_eq(plan: LogicalPlan, expected: &str) -> Result<()> { - let optimized_plan = optimize(plan).expect("failed to optimize plan"); - let formatted_plan = format!("{optimized_plan:?}"); - assert_eq!(formatted_plan, expected); - Ok(()) - } - - fn optimize(plan: LogicalPlan) -> Result { - let optimizer = Optimizer::with_rules(vec![Arc::new(OptimizeProjections::new())]); - let optimized_plan = - optimizer.optimize(plan, &OptimizerContext::new(), observe)?; - - Ok(optimized_plan) - } - - fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {} -} diff --git a/datafusion/physical-expr/src/equivalence/ordering.rs b/datafusion/physical-expr/src/equivalence/ordering.rs index 688cdf798bdd..ed4600f2d95e 100644 --- a/datafusion/physical-expr/src/equivalence/ordering.rs +++ b/datafusion/physical-expr/src/equivalence/ordering.rs @@ -228,8 +228,7 @@ mod tests { use itertools::Itertools; use datafusion_common::{DFSchema, Result}; - use datafusion_expr::execution_props::ExecutionProps; - use datafusion_expr::{BuiltinScalarFunction, Operator, ScalarUDF}; + use datafusion_expr::{Operator, ScalarUDF}; use crate::equivalence::tests::{ convert_to_orderings, convert_to_sort_exprs, create_random_schema, @@ -241,7 +240,6 @@ mod tests { }; use crate::expressions::Column; use crate::expressions::{col, BinaryExpr}; - use crate::functions::create_physical_expr; use crate::utils::tests::TestScalarUDF; use crate::{PhysicalExpr, PhysicalSortExpr}; @@ -301,11 +299,12 @@ mod tests { &[], &DFSchema::empty(), )?; - let exp_a = &create_physical_expr( - &BuiltinScalarFunction::Exp, + let exp_a = &crate::udf::create_physical_expr( + &test_fun, &[col("a", &test_schema)?], &test_schema, - &ExecutionProps::default(), + &[], + &DFSchema::empty(), )?; let a_plus_b = Arc::new(BinaryExpr::new( col_a.clone(), diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs index c237e2070675..6efbc4179ff4 100644 --- a/datafusion/physical-expr/src/functions.rs +++ b/datafusion/physical-expr/src/functions.rs @@ -50,8 +50,7 @@ use datafusion_expr::{ use crate::sort_properties::SortProperties; use crate::{ - conditional_expressions, math_expressions, string_expressions, PhysicalExpr, - ScalarFunctionExpr, + conditional_expressions, string_expressions, PhysicalExpr, ScalarFunctionExpr, }; /// Create a physical (function) expression. 
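The `ordering.rs` hunk above swaps the removed `BuiltinScalarFunction::Exp` for a `ScalarUDF` planned through `udf::create_physical_expr`. For downstream code making the same migration, here is a minimal sketch of that pattern, assuming the DataFusion 37-era `ScalarUDFImpl` trait and the five-argument `create_physical_expr` call shape shown in the hunk above; `ExpUdf` and `my_exp` are illustrative names, not part of this patch:

```rust
use std::any::Any;
use std::sync::Arc;

use arrow::array::{ArrayRef, Float64Array};
use arrow::datatypes::{DataType, Field, Schema};
use datafusion_common::{DFSchema, Result};
use datafusion_expr::{ColumnarValue, ScalarUDF, ScalarUDFImpl, Signature, Volatility};
use datafusion_physical_expr::expressions::col;

/// Illustrative stand-in for the removed `BuiltinScalarFunction::Exp`.
#[derive(Debug)]
struct ExpUdf {
    signature: Signature,
}

impl ScalarUDFImpl for ExpUdf {
    fn as_any(&self) -> &dyn Any {
        self
    }
    fn name(&self) -> &str {
        "my_exp"
    }
    fn signature(&self) -> &Signature {
        &self.signature
    }
    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
        Ok(DataType::Float64)
    }
    fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
        // Expand a possible scalar argument into an array, then apply `exp` per row
        let input = args[0].clone().into_array(1)?;
        let input = input.as_any().downcast_ref::<Float64Array>().unwrap();
        let out: Float64Array = input.iter().map(|v| v.map(f64::exp)).collect();
        Ok(ColumnarValue::Array(Arc::new(out) as ArrayRef))
    }
}

fn main() -> Result<()> {
    let schema = Schema::new(vec![Field::new("a", DataType::Float64, true)]);
    let udf = ScalarUDF::new_from_impl(ExpUdf {
        signature: Signature::exact(vec![DataType::Float64], Volatility::Immutable),
    });
    // Same call shape as in the updated test above
    let expr = datafusion_physical_expr::udf::create_physical_expr(
        &udf,
        &[col("a", &schema)?],
        &schema,
        &[],
        &DFSchema::empty(),
    )?;
    println!("{expr:?}");
    Ok(())
}
```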
@@ -178,12 +177,6 @@ pub fn create_physical_fun( fun: &BuiltinScalarFunction, ) -> Result { Ok(match fun { - // math functions - BuiltinScalarFunction::Ceil => Arc::new(math_expressions::ceil), - BuiltinScalarFunction::Exp => Arc::new(math_expressions::exp), - BuiltinScalarFunction::Factorial => { - Arc::new(|args| make_scalar_function_inner(math_expressions::factorial)(args)) - } // string functions BuiltinScalarFunction::Coalesce => Arc::new(conditional_expressions::coalesce), BuiltinScalarFunction::Concat => Arc::new(string_expressions::concat), @@ -279,10 +272,7 @@ fn func_order_in_one_dimension( #[cfg(test)] mod tests { use arrow::{ - array::{ - Array, ArrayRef, BooleanArray, Float32Array, Float64Array, Int32Array, - StringArray, UInt64Array, - }, + array::{Array, ArrayRef, BooleanArray, Int32Array, StringArray, UInt64Array}, datatypes::Field, record_batch::RecordBatch, }; @@ -410,46 +400,6 @@ mod tests { Utf8, StringArray ); - test_function!( - Exp, - &[lit(ScalarValue::Int32(Some(1)))], - Ok(Some((1.0_f64).exp())), - f64, - Float64, - Float64Array - ); - test_function!( - Exp, - &[lit(ScalarValue::UInt32(Some(1)))], - Ok(Some((1.0_f64).exp())), - f64, - Float64, - Float64Array - ); - test_function!( - Exp, - &[lit(ScalarValue::UInt64(Some(1)))], - Ok(Some((1.0_f64).exp())), - f64, - Float64, - Float64Array - ); - test_function!( - Exp, - &[lit(ScalarValue::Float64(Some(1.0)))], - Ok(Some((1.0_f64).exp())), - f64, - Float64, - Float64Array - ); - test_function!( - Exp, - &[lit(ScalarValue::Float32(Some(1.0)))], - Ok(Some((1.0_f32).exp())), - f32, - Float32, - Float32Array - ); test_function!( InitCap, &[lit("hi THOMAS")], diff --git a/datafusion/physical-expr/src/math_expressions.rs b/datafusion/physical-expr/src/math_expressions.rs index 004a9abe7f0b..cee1b8c787e2 100644 --- a/datafusion/physical-expr/src/math_expressions.rs +++ b/datafusion/physical-expr/src/math_expressions.rs @@ -21,69 +21,12 @@ use std::any::type_name; use std::sync::Arc; use arrow::array::ArrayRef; -use arrow::array::{BooleanArray, Float32Array, Float64Array, Int64Array}; +use arrow::array::{BooleanArray, Float32Array, Float64Array}; use arrow::datatypes::DataType; use arrow_array::Array; -use datafusion_common::{exec_err, ScalarValue}; +use datafusion_common::exec_err; use datafusion_common::{DataFusionError, Result}; -use datafusion_expr::ColumnarValue; - -macro_rules! downcast_compute_op { - ($ARRAY:expr, $NAME:expr, $FUNC:ident, $TYPE:ident) => {{ - let n = $ARRAY.as_any().downcast_ref::<$TYPE>(); - match n { - Some(array) => { - let res: $TYPE = - arrow::compute::kernels::arity::unary(array, |x| x.$FUNC()); - Ok(Arc::new(res)) - } - _ => exec_err!("Invalid data type for {}", $NAME), - } - }}; -} - -macro_rules! 
unary_primitive_array_op { - ($VALUE:expr, $NAME:expr, $FUNC:ident) => {{ - match ($VALUE) { - ColumnarValue::Array(array) => match array.data_type() { - DataType::Float32 => { - let result = downcast_compute_op!(array, $NAME, $FUNC, Float32Array); - Ok(ColumnarValue::Array(result?)) - } - DataType::Float64 => { - let result = downcast_compute_op!(array, $NAME, $FUNC, Float64Array); - Ok(ColumnarValue::Array(result?)) - } - other => { - exec_err!("Unsupported data type {:?} for function {}", other, $NAME) - } - }, - ColumnarValue::Scalar(a) => match a { - ScalarValue::Float32(a) => Ok(ColumnarValue::Scalar( - ScalarValue::Float32(a.map(|x| x.$FUNC())), - )), - ScalarValue::Float64(a) => Ok(ColumnarValue::Scalar( - ScalarValue::Float64(a.map(|x| x.$FUNC())), - )), - _ => exec_err!( - "Unsupported data type {:?} for function {}", - ($VALUE).data_type(), - $NAME - ), - }, - } - }}; -} - -macro_rules! math_unary_function { - ($NAME:expr, $FUNC:ident) => { - /// mathematical function that accepts f32 or f64 and returns f64 - pub fn $FUNC(args: &[ColumnarValue]) -> Result { - unary_primitive_array_op!(&args[0], $NAME, $FUNC) - } - }; -} macro_rules! downcast_arg { ($ARG:expr, $NAME:expr, $ARRAY_TYPE:ident) => {{ @@ -98,19 +41,6 @@ macro_rules! downcast_arg { }}; } -macro_rules! make_function_scalar_inputs { - ($ARG: expr, $NAME:expr, $ARRAY_TYPE:ident, $FUNC: block) => {{ - let arg = downcast_arg!($ARG, $NAME, $ARRAY_TYPE); - - arg.iter() - .map(|a| match a { - Some(a) => Some($FUNC(a)), - _ => None, - }) - .collect::<$ARRAY_TYPE>() - }}; -} - macro_rules! make_function_scalar_inputs_return_type { ($ARG: expr, $NAME:expr, $ARGS_TYPE:ident, $RETURN_TYPE:ident, $FUNC: block) => {{ let arg = downcast_arg!($ARG, $NAME, $ARGS_TYPE); @@ -124,22 +54,6 @@ macro_rules! 
make_function_scalar_inputs_return_type { }}; } -math_unary_function!("ceil", ceil); -math_unary_function!("exp", exp); - -/// Factorial SQL function -pub fn factorial(args: &[ArrayRef]) -> Result { - match args[0].data_type() { - DataType::Int64 => Ok(Arc::new(make_function_scalar_inputs!( - &args[0], - "value", - Int64Array, - { |value: i64| { (1..=value).product() } } - )) as ArrayRef), - other => exec_err!("Unsupported data type {other:?} for function factorial."), - } -} - /// Isnan SQL function pub fn isnan(args: &[ArrayRef]) -> Result { match args[0].data_type() { @@ -167,25 +81,10 @@ pub fn isnan(args: &[ArrayRef]) -> Result { mod tests { use arrow::array::Float64Array; - use datafusion_common::cast::{as_boolean_array, as_int64_array}; + use datafusion_common::cast::as_boolean_array; use super::*; - #[test] - fn test_factorial_i64() { - let args: Vec = vec![ - Arc::new(Int64Array::from(vec![0, 1, 2, 4])), // input - ]; - - let result = factorial(&args).expect("failed to initialize function factorial"); - let ints = - as_int64_array(&result).expect("failed to initialize function factorial"); - - let expected = Int64Array::from(vec![1, 1, 2, 24]); - - assert_eq!(ints, &expected); - } - #[test] fn test_isnan_f64() { let args: Vec = vec![Arc::new(Float64Array::from(vec![ diff --git a/datafusion/physical-plan/Cargo.toml b/datafusion/physical-plan/Cargo.toml index 6a78bd596a46..6863f2646000 100644 --- a/datafusion/physical-plan/Cargo.toml +++ b/datafusion/physical-plan/Cargo.toml @@ -39,6 +39,7 @@ ahash = { version = "0.8", default-features = false, features = [ arrow = { workspace = true } arrow-array = { workspace = true } arrow-buffer = { workspace = true } +arrow-ord = { workspace = true } arrow-schema = { workspace = true } async-trait = { workspace = true } chrono = { workspace = true } diff --git a/datafusion/physical-plan/src/unnest.rs b/datafusion/physical-plan/src/unnest.rs index 6ea1b3c40c83..45b848112ba9 100644 --- a/datafusion/physical-plan/src/unnest.rs +++ b/datafusion/physical-plan/src/unnest.rs @@ -15,8 +15,9 @@ // specific language governing permissions and limitations // under the License. -//! Defines the unnest column plan for unnesting values in a column that contains a list -//! type, conceptually is like joining each row with all the values in the list column. +//! Define a plan for unnesting values in columns that contain a list type. 
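Before the executor changes that follow, it may help to see the end behavior this patch enables. A small end-to-end sketch (hypothetical table and values; needs the `datafusion` crate and a tokio runtime), matching the padding semantics exercised by the sqllogictests at the bottom of this diff:

```rust
use datafusion::error::Result;
use datafusion::prelude::*;

#[tokio::main]
async fn main() -> Result<()> {
    let ctx = SessionContext::new();
    // Two list columns whose lengths differ within each row
    ctx.sql("CREATE TABLE t AS VALUES ([1, 2, 3], ['a', 'b']), ([4], ['c'])")
        .await?;
    // Previously rejected with "Only support single unnest expression for now".
    // Now each row expands to its longest list length, and shorter lists are
    // padded with NULLs:
    //   1 a
    //   2 b
    //   3 NULL
    //   4 c
    ctx.sql("SELECT unnest(column1), unnest(column2) FROM t")
        .await?
        .show()
        .await?;
    Ok(())
}
```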
+ +use std::collections::HashMap; use std::{any::Any, sync::Arc}; use super::metrics::{self, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet}; @@ -27,15 +28,17 @@ use crate::{ }; use arrow::array::{ - Array, ArrayRef, ArrowPrimitiveType, FixedSizeListArray, GenericListArray, - LargeListArray, ListArray, OffsetSizeTrait, PrimitiveArray, -}; -use arrow::compute::kernels; -use arrow::datatypes::{ - ArrowNativeType, DataType, Int32Type, Int64Type, Schema, SchemaRef, + Array, ArrayRef, AsArray, FixedSizeListArray, LargeListArray, ListArray, + PrimitiveArray, }; +use arrow::compute::kernels::length::length; +use arrow::compute::kernels::zip::zip; +use arrow::compute::{cast, is_not_null, kernels, sum}; +use arrow::datatypes::{DataType, Int64Type, Schema, SchemaRef}; use arrow::record_batch::RecordBatch; -use datafusion_common::{exec_err, Result, UnnestOptions}; +use arrow_array::{Int64Array, Scalar}; +use arrow_ord::cmp::lt; +use datafusion_common::{exec_datafusion_err, exec_err, Result, UnnestOptions}; use datafusion_execution::TaskContext; use datafusion_physical_expr::EquivalenceProperties; @@ -43,7 +46,7 @@ use async_trait::async_trait; use futures::{Stream, StreamExt}; use log::trace; -/// Unnest the given column by joining the row with each value in the +/// Unnest the given columns by joining the row with each value in the /// nested type. /// /// See [`UnnestOptions`] for more details and an example. @@ -53,8 +56,8 @@ pub struct UnnestExec { input: Arc, /// The schema once the unnest is applied schema: SchemaRef, - /// The unnest column - column: Column, + /// The unnest columns + columns: Vec, /// Options options: UnnestOptions, /// Execution metrics @@ -67,7 +70,7 @@ impl UnnestExec { /// Create a new [UnnestExec]. pub fn new( input: Arc, - column: Column, + columns: Vec, schema: SchemaRef, options: UnnestOptions, ) -> Self { @@ -75,7 +78,7 @@ impl UnnestExec { UnnestExec { input, schema, - column, + columns, options, metrics: Default::default(), cache, @@ -134,7 +137,7 @@ impl ExecutionPlan for UnnestExec { ) -> Result> { Ok(Arc::new(UnnestExec::new( children[0].clone(), - self.column.clone(), + self.columns.clone(), self.schema.clone(), self.options.clone(), ))) @@ -155,7 +158,7 @@ impl ExecutionPlan for UnnestExec { Ok(Box::pin(UnnestStream { input, schema: self.schema.clone(), - column: self.column.clone(), + columns: self.columns.clone(), options: self.options.clone(), metrics, })) @@ -210,8 +213,8 @@ struct UnnestStream { input: SendableRecordBatchStream, /// Unnested schema schema: Arc, - /// The unnest column - column: Column, + /// The unnest columns + columns: Vec, /// Options options: UnnestOptions, /// Metrics @@ -249,7 +252,7 @@ impl UnnestStream { Some(Ok(batch)) => { let timer = self.metrics.elapsed_compute.timer(); let result = - build_batch(&batch, &self.schema, &self.column, &self.options); + build_batch(&batch, &self.schema, &self.columns, &self.options); self.metrics.input_batches.add(1); self.metrics.input_rows.add(batch.num_rows()); if let Ok(ref batch) = result { @@ -276,270 +279,265 @@ impl UnnestStream { } } +/// For each row in a `RecordBatch`, some list columns need to be unnested. +/// We will expand the values in each list into multiple rows, +/// taking the longest length among these lists, and shorter lists are padded with NULLs. +// +/// For columns that don't need to be unnested, repeat their values until reaching the longest length. 
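// A worked example of the expansion described above (illustrative values,
// not part of the patch):
//
//   unnest columns:  c1: [1, 2], [3]      c2: [10], [20, 30]
//   plain column:    c3: 'a', 'b'
//
// The longest list length per row is [2, 2], so the output batch is:
//
//   c1: 1,   2,    3,   NULL
//   c2: 10,  NULL, 20,  30
//   c3: 'a', 'a',  'b', 'b'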
fn build_batch( batch: &RecordBatch, schema: &SchemaRef, - column: &Column, + columns: &[Column], options: &UnnestOptions, ) -> Result { - let list_array = column.evaluate(batch)?.into_array(batch.num_rows())?; - match list_array.data_type() { - DataType::List(_) => { - let list_array = list_array.as_any().downcast_ref::().unwrap(); - build_batch_generic_list::( - batch, - schema, - column.index(), - list_array, - options, - ) - } - DataType::LargeList(_) => { - let list_array = list_array - .as_any() - .downcast_ref::() - .unwrap(); - build_batch_generic_list::( - batch, - schema, - column.index(), - list_array, - options, - ) - } - DataType::FixedSizeList(_, _) => { - let list_array = list_array - .as_any() - .downcast_ref::() - .unwrap(); - build_batch_fixedsize_list(batch, schema, column.index(), list_array, options) - } - _ => exec_err!("Invalid unnest column {column}"), + let list_arrays: Vec = columns + .iter() + .map(|column| column.evaluate(batch)?.into_array(batch.num_rows())) + .collect::>()?; + + let longest_length = find_longest_length(&list_arrays, options)?; + let unnested_length = longest_length.as_primitive::(); + let total_length = if unnested_length.is_empty() { + 0 + } else { + sum(unnested_length).ok_or_else(|| { + exec_datafusion_err!("Failed to calculate the total unnested length") + })? as usize + }; + if total_length == 0 { + return Ok(RecordBatch::new_empty(schema.clone())); } -} -fn build_batch_generic_list>( - batch: &RecordBatch, - schema: &SchemaRef, - unnest_column_idx: usize, - list_array: &GenericListArray, - options: &UnnestOptions, -) -> Result { - let unnested_array = unnest_generic_list::(list_array, options)?; - - let take_indicies = - create_take_indicies_generic::(list_array, unnested_array.len(), options); - - batch_from_indices( - batch, - schema, - unnest_column_idx, - &unnested_array, - &take_indicies, - ) + // Unnest all the list arrays + let unnested_arrays = + unnest_list_arrays(&list_arrays, unnested_length, total_length)?; + let unnested_array_map: HashMap<_, _> = unnested_arrays + .into_iter() + .zip(columns.iter()) + .map(|(array, column)| (column.index(), array)) + .collect(); + + // Create the take indices array for other columns + let take_indicies = create_take_indicies(unnested_length, total_length); + + batch_from_indices(batch, schema, &unnested_array_map, &take_indicies) } -/// Given this `GenericList` list_array: +/// Find the longest list length among the given list arrays for each row. +/// +/// For example if we have the following two list arrays: /// /// ```ignore -/// [1], null, [2, 3, 4], null, [5, 6] +/// l1: [1, 2, 3], null, [], [3] +/// l2: [4,5], [], null, [6, 7] /// ``` -/// Its values array is represented like this: +/// +/// If `preserve_nulls` is false, the longest length array will be: /// /// ```ignore -/// [1, 2, 3, 4, 5, 6] +/// longest_length: [3, 0, 0, 2] /// ``` /// -/// So if there are no null values or `UnnestOptions.preserve_nulls` is false -/// we can return the values array without any copying. 
+/// whereas if `preserve_nulls` is true, the longest length array will be: /// -/// Otherwise we'll transfrom the values array using the take kernel and the following take indicies: /// /// ```ignore -/// 0, null, 1, 2, 3, null, 4, 5 +/// longest_length: [3, 1, 1, 2] /// ``` /// -fn unnest_generic_list>( - list_array: &GenericListArray, +fn find_longest_length( + list_arrays: &[ArrayRef], options: &UnnestOptions, -) -> Result> { - let values = list_array.values(); - if list_array.null_count() == 0 { - return Ok(values.clone()); +) -> Result { + // The length of a NULL list + let null_length = if options.preserve_nulls { + Scalar::new(Int64Array::from_value(1, 1)) + } else { + Scalar::new(Int64Array::from_value(0, 1)) + }; + let list_lengths: Vec = list_arrays + .iter() + .map(|list_array| { + let mut length_array = length(list_array)?; + // Make sure length arrays have the same type. Int64 is the most general one. + length_array = cast(&length_array, &DataType::Int64)?; + length_array = + zip(&is_not_null(&length_array)?, &length_array, &null_length)?; + Ok(length_array) + }) + .collect::>()?; + + let longest_length = list_lengths.iter().skip(1).try_fold( + list_lengths[0].clone(), + |longest, current| { + let is_lt = lt(&longest, ¤t)?; + zip(&is_lt, ¤t, &longest) + }, + )?; + Ok(longest_length) +} + +/// Trait defining common methods used for unnesting, implemented by list array types. +trait ListArrayType: Array { + /// Returns a reference to the values of this list. + fn values(&self) -> &ArrayRef; + + /// Returns the start and end offset of the values for the given row. + fn value_offsets(&self, row: usize) -> (i64, i64); +} + +impl ListArrayType for ListArray { + fn values(&self) -> &ArrayRef { + self.values() } - let mut take_indicies_builder = - PrimitiveArray::
<P>
::builder(values.len() + list_array.null_count()); - let offsets = list_array.value_offsets(); - for row in 0..list_array.len() { - if list_array.is_null(row) { - if options.preserve_nulls { - take_indicies_builder.append_null(); - } - } else { - let start = offsets[row].as_usize(); - let end = offsets[row + 1].as_usize(); - for idx in start..end { - take_indicies_builder.append_value(P::Native::from_usize(idx).unwrap()); - } - } + fn value_offsets(&self, row: usize) -> (i64, i64) { + let offsets = self.value_offsets(); + (offsets[row].into(), offsets[row + 1].into()) } - Ok(kernels::take::take( - &values, - &take_indicies_builder.finish(), - None, - )?) } -fn build_batch_fixedsize_list( - batch: &RecordBatch, - schema: &SchemaRef, - unnest_column_idx: usize, - list_array: &FixedSizeListArray, - options: &UnnestOptions, -) -> Result { - let unnested_array = unnest_fixed_list(list_array, options)?; - - let take_indicies = - create_take_indicies_fixed(list_array, unnested_array.len(), options); - - batch_from_indices( - batch, - schema, - unnest_column_idx, - &unnested_array, - &take_indicies, - ) +impl ListArrayType for LargeListArray { + fn values(&self) -> &ArrayRef { + self.values() + } + + fn value_offsets(&self, row: usize) -> (i64, i64) { + let offsets = self.value_offsets(); + (offsets[row], offsets[row + 1]) + } } -/// Given this `FixedSizeListArray` list_array: +impl ListArrayType for FixedSizeListArray { + fn values(&self) -> &ArrayRef { + self.values() + } + + fn value_offsets(&self, row: usize) -> (i64, i64) { + let start = self.value_offset(row) as i64; + (start, start + self.value_length() as i64) + } +} + +/// Unnest multiple list arrays according to the length array. +fn unnest_list_arrays( + list_arrays: &[ArrayRef], + length_array: &PrimitiveArray, + capacity: usize, +) -> Result> { + let typed_arrays = list_arrays + .iter() + .map(|list_array| match list_array.data_type() { + DataType::List(_) => Ok(list_array.as_list::() as &dyn ListArrayType), + DataType::LargeList(_) => { + Ok(list_array.as_list::() as &dyn ListArrayType) + } + DataType::FixedSizeList(_, _) => { + Ok(list_array.as_fixed_size_list() as &dyn ListArrayType) + } + other => exec_err!("Invalid unnest datatype {other }"), + }) + .collect::>>()?; + + // If there is only one list column to unnest and it doesn't contain any NULL lists, + // we can return the values array directly without any copying. + if typed_arrays.len() == 1 && typed_arrays[0].null_count() == 0 { + Ok(vec![typed_arrays[0].values().clone()]) + } else { + typed_arrays + .iter() + .map(|list_array| unnest_list_array(*list_array, length_array, capacity)) + .collect::>() + } +} + +/// Unnest a list array according the target length array. /// -/// ```ignore -/// [1, 2], null, [3, 4], null, [5, 6] -/// ``` -/// Its values array is represented like this: +/// Consider a list array like this: /// /// ```ignore -/// [1, 2, null, null 3, 4, null, null, 5, 6] +/// [1], [2, 3, 4], null, [5], [], /// ``` /// -/// So if there are no null values -/// we can return the values array without any copying. -/// -/// Otherwise we'll transfrom the values array using the take kernel. -/// -/// If `UnnestOptions.preserve_nulls` is true the take indicies will look like this: +/// and the length array is: /// /// ```ignore -/// 0, 1, null, 4, 5, null, 8, 9 +/// [2, 3, 2, 1, 2] /// ``` -/// Otherwise we drop the nulls and take indicies will look like this: +/// +/// If the length of a certain list is less than the target length, pad with NULLs. 
+/// So the unnested array will look like this: /// /// ```ignore -/// 0, 1, 4, 5, 8, 9 +/// [1, null, 2, 3, 4, null, null, 5, null, null] /// ``` /// -fn unnest_fixed_list( - list_array: &FixedSizeListArray, - options: &UnnestOptions, -) -> Result> { +fn unnest_list_array( + list_array: &dyn ListArrayType, + length_array: &PrimitiveArray, + capacity: usize, +) -> Result { let values = list_array.values(); - - if list_array.null_count() == 0 { - Ok(values.clone()) - } else { - let len_without_nulls = - values.len() - list_array.null_count() * list_array.value_length() as usize; - let null_count = if options.preserve_nulls { - list_array.null_count() - } else { - 0 - }; - let mut builder = - PrimitiveArray::::builder(len_without_nulls + null_count); - let mut take_offset = 0; - let fixed_value_length = list_array.value_length() as usize; - list_array.iter().for_each(|elem| match elem { - Some(_) => { - for i in 0..fixed_value_length { - //take_offset + i is always positive - let take_index = take_offset + i; - builder.append_value(take_index as i32); - } - take_offset += fixed_value_length; - } - None => { - if options.preserve_nulls { - builder.append_null(); - } - take_offset += fixed_value_length; + let mut take_indicies_builder = PrimitiveArray::::builder(capacity); + for row in 0..list_array.len() { + let mut value_length = 0; + if !list_array.is_null(row) { + let (start, end) = list_array.value_offsets(row); + value_length = end - start; + for i in start..end { + take_indicies_builder.append_value(i) } - }); - Ok(kernels::take::take(&values, &builder.finish(), None)?) + } + let target_length = length_array.value(row); + debug_assert!( + value_length <= target_length, + "value length is beyond the longest length" + ); + // Pad with NULL values + for _ in value_length..target_length { + take_indicies_builder.append_null(); + } } + Ok(kernels::take::take( + &values, + &take_indicies_builder.finish(), + None, + )?) } -/// Creates take indicies to be used to expand all other column's data. -/// Every column value needs to be repeated as many times as many elements there is in each corresponding array value. +/// Creates take indicies that will be used to expand all columns except for the unnest [`columns`](UnnestExec::columns). +/// Every column value needs to be repeated multiple times according to the length array. /// -/// If the column being unnested looks like this: +/// If the length array looks like this: /// /// ```ignore -/// [1], null, [2, 3, 4], null, [5, 6] +/// [2, 3, 1] /// ``` -/// Then `create_take_indicies_generic` will return an array like this +/// Then `create_take_indicies` will return an array like this /// /// ```ignore -/// [1, null, 2, 2, 2, null, 4, 4] +/// [0, 0, 1, 1, 1, 2] /// ``` /// -fn create_take_indicies_generic>( - list_array: &GenericListArray, +fn create_take_indicies( + length_array: &PrimitiveArray, capacity: usize, - options: &UnnestOptions, -) -> PrimitiveArray
<P>
{ - let mut builder = PrimitiveArray::
<P>
::builder(capacity); - let null_repeat: usize = if options.preserve_nulls { 1 } else { 0 }; - - for row in 0..list_array.len() { - let repeat = if list_array.is_null(row) { - null_repeat - } else { - list_array.value(row).len() - }; - - // `index` is a positive interger. - let index = P::Native::from_usize(row).unwrap(); - (0..repeat).for_each(|_| builder.append_value(index)); +) -> PrimitiveArray { + // `find_longest_length()` guarantees this. + debug_assert!( + length_array.null_count() == 0, + "length array should not contain nulls" + ); + let mut builder = PrimitiveArray::::builder(capacity); + for (index, repeat) in length_array.iter().enumerate() { + // The length array should not contain nulls, so unwrap is safe + let repeat = repeat.unwrap(); + (0..repeat).for_each(|_| builder.append_value(index as i64)); } - builder.finish() } -fn create_take_indicies_fixed( - list_array: &FixedSizeListArray, - capacity: usize, - options: &UnnestOptions, -) -> PrimitiveArray { - let mut builder = PrimitiveArray::::builder(capacity); - let null_repeat: usize = if options.preserve_nulls { 1 } else { 0 }; - - for row in 0..list_array.len() { - let repeat = if list_array.is_null(row) { - null_repeat - } else { - list_array.value_length() as usize - }; - - // `index` is a positive interger. - let index = ::Native::from_usize(row).unwrap(); - (0..repeat).for_each(|_| builder.append_value(index)); - } - - builder.finish() -} - -/// Create the final batch given the unnested column array and a `indices` array +/// Create the final batch given the unnested column arrays and a `indices` array /// that is used by the take kernel to copy values. /// /// For example if we have the following `RecordBatch`: @@ -549,8 +547,8 @@ fn create_take_indicies_fixed( /// c2: 'a', 'b', 'c', null, 'd' /// ``` /// -/// then the `unnested_array` contains the unnest column that will replace `c1` in -/// the final batch: +/// then the `unnested_list_arrays` contains the unnest column that will replace `c1` in +/// the final batch if `preserve_nulls` is true: /// /// ```ignore /// c1: 1, null, 2, 3, 4, null, 5, 6 @@ -570,26 +568,19 @@ fn create_take_indicies_fixed( /// c2: 'a', 'b', 'c', 'c', 'c', null, 'd', 'd' /// ``` /// -fn batch_from_indices( +fn batch_from_indices( batch: &RecordBatch, schema: &SchemaRef, - unnest_column_idx: usize, - unnested_array: &ArrayRef, - indices: &PrimitiveArray, -) -> Result -where - T: ArrowPrimitiveType, -{ + unnested_list_arrays: &HashMap, + indices: &PrimitiveArray, +) -> Result { let arrays = batch .columns() .iter() .enumerate() - .map(|(col_idx, arr)| { - if col_idx == unnest_column_idx { - Ok(unnested_array.clone()) - } else { - Ok(kernels::take::take(&arr, indices, None)?) 
- } + .map(|(col_idx, arr)| match unnested_list_arrays.get(&col_idx) { + Some(unnested_array) => Ok(unnested_array.clone()), + None => Ok(kernels::take::take(arr, indices, None)?), }) .collect::>>()?; @@ -599,51 +590,51 @@ where #[cfg(test)] mod tests { use super::*; - use arrow::{ - array::AsArray, - datatypes::{DataType, Field}, - }; - use arrow_array::StringArray; + use arrow::datatypes::{DataType, Field}; + use arrow_array::{GenericListArray, OffsetSizeTrait, StringArray}; use arrow_buffer::{BooleanBufferBuilder, NullBuffer, OffsetBuffer}; - // Create a ListArray with the following list values: + // Create a GenericListArray with the following list values: // [A, B, C], [], NULL, [D], NULL, [NULL, F] - fn make_test_array() -> ListArray { + fn make_generic_array() -> GenericListArray + where + OffsetSize: OffsetSizeTrait, + { let mut values = vec![]; - let mut offsets = vec![0]; - let mut valid = BooleanBufferBuilder::new(2); + let mut offsets: Vec = vec![OffsetSize::zero()]; + let mut valid = BooleanBufferBuilder::new(6); // [A, B, C] values.extend_from_slice(&[Some("A"), Some("B"), Some("C")]); - offsets.push(values.len() as i32); + offsets.push(OffsetSize::from_usize(values.len()).unwrap()); valid.append(true); // [] - offsets.push(values.len() as i32); + offsets.push(OffsetSize::from_usize(values.len()).unwrap()); valid.append(true); // NULL with non-zero value length // Issue https://github.com/apache/arrow-datafusion/issues/9932 values.push(Some("?")); - offsets.push(values.len() as i32); + offsets.push(OffsetSize::from_usize(values.len()).unwrap()); valid.append(false); // [D] values.push(Some("D")); - offsets.push(values.len() as i32); + offsets.push(OffsetSize::from_usize(values.len()).unwrap()); valid.append(true); // Another NULL with zero value length - offsets.push(values.len() as i32); + offsets.push(OffsetSize::from_usize(values.len()).unwrap()); valid.append(false); // [NULL, F] values.extend_from_slice(&[None, Some("F")]); - offsets.push(values.len() as i32); + offsets.push(OffsetSize::from_usize(values.len()).unwrap()); valid.append(true); let field = Arc::new(Field::new("item", DataType::Utf8, true)); - ListArray::new( + GenericListArray::::new( field, OffsetBuffer::new(offsets.into()), Arc::new(StringArray::from(values)), @@ -651,43 +642,141 @@ mod tests { ) } - #[test] - fn test_unnest_generic_list() -> datafusion_common::Result<()> { - let list_array = make_test_array(); - - // Test with preserve_nulls = false - let options = UnnestOptions { - preserve_nulls: false, - }; - let unnested_array = - unnest_generic_list::(&list_array, &options)?; - let strs = unnested_array.as_string::().iter().collect::>(); - assert_eq!( - strs, - vec![Some("A"), Some("B"), Some("C"), Some("D"), None, Some("F")] - ); + // Create a FixedSizeListArray with the following list values: + // [A, B], NULL, [C, D], NULL, [NULL, F], [NULL, NULL] + fn make_fixed_list() -> FixedSizeListArray { + let values = Arc::new(StringArray::from_iter([ + Some("A"), + Some("B"), + None, + None, + Some("C"), + Some("D"), + None, + None, + None, + Some("F"), + None, + None, + ])); + let field = Arc::new(Field::new("item", DataType::Utf8, true)); + let valid = NullBuffer::from(vec![true, false, true, false, true, true]); + FixedSizeListArray::new(field, 2, values, Some(valid)) + } - // Test with preserve_nulls = true - let options = UnnestOptions { - preserve_nulls: true, - }; - let unnested_array = - unnest_generic_list::(&list_array, &options)?; + fn verify_unnest_list_array( + list_array: &dyn ListArrayType, 
+ lengths: Vec, + expected: Vec>, + ) -> datafusion_common::Result<()> { + let length_array = Int64Array::from(lengths); + let unnested_array = unnest_list_array(list_array, &length_array, 3 * 6)?; let strs = unnested_array.as_string::().iter().collect::>(); - assert_eq!( - strs, + assert_eq!(strs, expected); + Ok(()) + } + + #[test] + fn test_unnest_list_array() -> datafusion_common::Result<()> { + // [A, B, C], [], NULL, [D], NULL, [NULL, F] + let list_array = make_generic_array::(); + verify_unnest_list_array( + &list_array, + vec![3, 2, 1, 2, 0, 3], vec![ Some("A"), Some("B"), Some("C"), None, + None, + None, Some("D"), None, None, - Some("F") - ] + Some("F"), + None, + ], + )?; + + // [A, B], NULL, [C, D], NULL, [NULL, F], [NULL, NULL] + let list_array = make_fixed_list(); + verify_unnest_list_array( + &list_array, + vec![3, 1, 2, 0, 2, 3], + vec![ + Some("A"), + Some("B"), + None, + None, + Some("C"), + Some("D"), + None, + Some("F"), + None, + None, + None, + ], + )?; + + Ok(()) + } + + fn verify_longest_length( + list_arrays: &[ArrayRef], + preserve_nulls: bool, + expected: Vec, + ) -> datafusion_common::Result<()> { + let options = UnnestOptions { preserve_nulls }; + let longest_length = find_longest_length(list_arrays, &options)?; + let expected_array = Int64Array::from(expected); + assert_eq!( + longest_length + .as_any() + .downcast_ref::() + .unwrap(), + &expected_array ); + Ok(()) + } + + #[test] + fn test_longest_list_length() -> datafusion_common::Result<()> { + // Test with single ListArray + // [A, B, C], [], NULL, [D], NULL, [NULL, F] + let list_array = Arc::new(make_generic_array::()) as ArrayRef; + verify_longest_length(&[list_array.clone()], false, vec![3, 0, 0, 1, 0, 2])?; + verify_longest_length(&[list_array.clone()], true, vec![3, 0, 1, 1, 1, 2])?; + + // Test with single LargeListArray + // [A, B, C], [], NULL, [D], NULL, [NULL, F] + let list_array = Arc::new(make_generic_array::()) as ArrayRef; + verify_longest_length(&[list_array.clone()], false, vec![3, 0, 0, 1, 0, 2])?; + verify_longest_length(&[list_array.clone()], true, vec![3, 0, 1, 1, 1, 2])?; + + // Test with single FixedSizeListArray + // [A, B], NULL, [C, D], NULL, [NULL, F], [NULL, NULL] + let list_array = Arc::new(make_fixed_list()) as ArrayRef; + verify_longest_length(&[list_array.clone()], false, vec![2, 0, 2, 0, 2, 2])?; + verify_longest_length(&[list_array.clone()], true, vec![2, 1, 2, 1, 2, 2])?; + + // Test with multiple list arrays + // [A, B, C], [], NULL, [D], NULL, [NULL, F] + // [A, B], NULL, [C, D], NULL, [NULL, F], [NULL, NULL] + let list1 = Arc::new(make_generic_array::()) as ArrayRef; + let list2 = Arc::new(make_fixed_list()) as ArrayRef; + let list_arrays = vec![list1.clone(), list2.clone()]; + verify_longest_length(&list_arrays, false, vec![3, 0, 2, 1, 2, 2])?; + verify_longest_length(&list_arrays, true, vec![3, 1, 2, 1, 2, 2])?; Ok(()) } + + #[test] + fn test_create_take_indicies() -> datafusion_common::Result<()> { + let length_array = Int64Array::from(vec![2, 3, 1]); + let take_indicies = create_take_indicies(&length_array, 6); + let expected = Int64Array::from(vec![0, 0, 1, 1, 1, 2]); + assert_eq!(take_indicies, expected); + Ok(()) + } } diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index e1bcf33b8254..6578c64cff1f 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -546,10 +546,10 @@ enum ScalarFunction { // 2 was Asin // 3 was Atan // 4 was Ascii - Ceil = 5; + // 5 was Ceil // 6 
was Cos // 7 was Digest - Exp = 8; + // 8 was Exp // 9 was Floor // 10 was Ln // 11 was Log @@ -624,7 +624,7 @@ enum ScalarFunction { // 80 was Pi // 81 was Degrees // 82 was Radians - Factorial = 83; + // 83 was Factorial // 84 was Lcm // 85 was Gcd // 86 was ArrayAppend diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 7beaeef0e58b..1546d75f2acd 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -22792,13 +22792,10 @@ impl serde::Serialize for ScalarFunction { { let variant = match self { Self::Unknown => "unknown", - Self::Ceil => "Ceil", - Self::Exp => "Exp", Self::Concat => "Concat", Self::ConcatWithSeparator => "ConcatWithSeparator", Self::InitCap => "InitCap", Self::Coalesce => "Coalesce", - Self::Factorial => "Factorial", Self::EndsWith => "EndsWith", }; serializer.serialize_str(variant) @@ -22812,13 +22809,10 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { { const FIELDS: &[&str] = &[ "unknown", - "Ceil", - "Exp", "Concat", "ConcatWithSeparator", "InitCap", "Coalesce", - "Factorial", "EndsWith", ]; @@ -22861,13 +22855,10 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { { match value { "unknown" => Ok(ScalarFunction::Unknown), - "Ceil" => Ok(ScalarFunction::Ceil), - "Exp" => Ok(ScalarFunction::Exp), "Concat" => Ok(ScalarFunction::Concat), "ConcatWithSeparator" => Ok(ScalarFunction::ConcatWithSeparator), "InitCap" => Ok(ScalarFunction::InitCap), "Coalesce" => Ok(ScalarFunction::Coalesce), - "Factorial" => Ok(ScalarFunction::Factorial), "EndsWith" => Ok(ScalarFunction::EndsWith), _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index d6a27dbc5652..c752743cbdce 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -2846,10 +2846,10 @@ pub enum ScalarFunction { /// 2 was Asin /// 3 was Atan /// 4 was Ascii - Ceil = 5, + /// 5 was Ceil /// 6 was Cos /// 7 was Digest - Exp = 8, + /// 8 was Exp /// 9 was Floor /// 10 was Ln /// 11 was Log @@ -2924,7 +2924,7 @@ pub enum ScalarFunction { /// 80 was Pi /// 81 was Degrees /// 82 was Radians - Factorial = 83, + /// 83 was Factorial /// 84 was Lcm /// 85 was Gcd /// 86 was ArrayAppend @@ -2988,13 +2988,10 @@ impl ScalarFunction { pub fn as_str_name(&self) -> &'static str { match self { ScalarFunction::Unknown => "unknown", - ScalarFunction::Ceil => "Ceil", - ScalarFunction::Exp => "Exp", ScalarFunction::Concat => "Concat", ScalarFunction::ConcatWithSeparator => "ConcatWithSeparator", ScalarFunction::InitCap => "InitCap", ScalarFunction::Coalesce => "Coalesce", - ScalarFunction::Factorial => "Factorial", ScalarFunction::EndsWith => "EndsWith", } } @@ -3002,13 +2999,10 @@ impl ScalarFunction { pub fn from_str_name(value: &str) -> ::core::option::Option { match value { "unknown" => Some(Self::Unknown), - "Ceil" => Some(Self::Ceil), - "Exp" => Some(Self::Exp), "Concat" => Some(Self::Concat), "ConcatWithSeparator" => Some(Self::ConcatWithSeparator), "InitCap" => Some(Self::InitCap), "Coalesce" => Some(Self::Coalesce), - "Factorial" => Some(Self::Factorial), "EndsWith" => Some(Self::EndsWith), _ => None, } diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 057690aacee6..e66bd1a5f0a9 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -37,9 +37,9 
@@ use datafusion_expr::expr::Unnest; use datafusion_expr::expr::{Alias, Placeholder}; use datafusion_expr::window_frame::{check_window_frame, regularize_window_order_by}; use datafusion_expr::{ - ceil, coalesce, concat_expr, concat_ws_expr, ends_with, exp, + coalesce, concat_expr, concat_ws_expr, ends_with, expr::{self, InList, Sort, WindowFunction}, - factorial, initcap, + initcap, logical_plan::{PlanType, StringifiedPlan}, AggregateFunction, Between, BinaryExpr, BuiltInWindowFunction, BuiltinScalarFunction, Case, Cast, Expr, GetFieldAccess, GetIndexedField, GroupingSet, @@ -418,9 +418,6 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction { use protobuf::ScalarFunction; match f { ScalarFunction::Unknown => todo!(), - ScalarFunction::Exp => Self::Exp, - ScalarFunction::Factorial => Self::Factorial, - ScalarFunction::Ceil => Self::Ceil, ScalarFunction::Concat => Self::Concat, ScalarFunction::ConcatWithSeparator => Self::ConcatWithSeparator, ScalarFunction::EndsWith => Self::EndsWith, @@ -1260,8 +1257,11 @@ pub fn parse_expr( parse_required_expr(negative.expr.as_deref(), registry, "expr", codec)?, ))), ExprType::Unnest(unnest) => { - let exprs = parse_exprs(&unnest.exprs, registry, codec)?; - Ok(Expr::Unnest(Unnest { exprs })) + let mut exprs = parse_exprs(&unnest.exprs, registry, codec)?; + if exprs.len() != 1 { + return Err(proto_error("Unnest must have exactly one expression")); + } + Ok(Expr::Unnest(Unnest::new(exprs.swap_remove(0)))) } ExprType::InList(in_list) => Ok(Expr::InList(InList::new( Box::new(parse_required_expr( @@ -1287,11 +1287,6 @@ pub fn parse_expr( match scalar_function { ScalarFunction::Unknown => Err(proto_error("Unknown scalar function")), - ScalarFunction::Exp => Ok(exp(parse_expr(&args[0], registry, codec)?)), - ScalarFunction::Factorial => { - Ok(factorial(parse_expr(&args[0], registry, codec)?)) - } - ScalarFunction::Ceil => Ok(ceil(parse_expr(&args[0], registry, codec)?)), ScalarFunction::InitCap => { Ok(initcap(parse_expr(&args[0], registry, codec)?)) } diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 358eea785713..4916b4bed9a3 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -963,9 +963,9 @@ pub fn serialize_expr( expr_type: Some(ExprType::Negative(expr)), } } - Expr::Unnest(Unnest { exprs }) => { + Expr::Unnest(Unnest { expr }) => { let expr = protobuf::Unnest { - exprs: serialize_exprs(exprs, codec)?, + exprs: vec![serialize_expr(expr.as_ref(), codec)?], }; protobuf::LogicalExprNode { expr_type: Some(ExprType::Unnest(expr)), @@ -1407,9 +1407,6 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction { fn try_from(scalar: &BuiltinScalarFunction) -> Result { let scalar_function = match scalar { - BuiltinScalarFunction::Exp => Self::Exp, - BuiltinScalarFunction::Factorial => Self::Factorial, - BuiltinScalarFunction::Ceil => Self::Ceil, BuiltinScalarFunction::Concat => Self::Concat, BuiltinScalarFunction::ConcatWithSeparator => Self::ConcatWithSeparator, BuiltinScalarFunction::EndsWith => Self::EndsWith, diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index e680a1b2ff1e..eee15008fbbb 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -1599,7 +1599,7 @@ fn roundtrip_inlist() { #[test] fn roundtrip_unnest() { let test_expr = Expr::Unnest(Unnest { - exprs: 
vec![lit(1), lit(2), lit(3)], + expr: Box::new(col("col")), }); let ctx = SessionContext::new(); diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index 4bf0906685ca..c225afec58d6 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -119,10 +119,14 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // Build Unnest expression if name.eq("unnest") { - let exprs = + let mut exprs = self.function_args_to_expr(args.clone(), schema, planner_context)?; - Self::check_unnest_args(&exprs, schema)?; - return Ok(Expr::Unnest(Unnest { exprs })); + if exprs.len() != 1 { + return plan_err!("unnest() requires exactly one argument"); + } + let expr = exprs.swap_remove(0); + Self::check_unnest_arg(&expr, schema)?; + return Ok(Expr::Unnest(Unnest::new(expr))); } // next, scalar built-in @@ -282,17 +286,6 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { Ok(Expr::ScalarFunction(ScalarFunction::new_udf(fun, args))) } - pub(super) fn sql_named_function_to_expr( - &self, - expr: SQLExpr, - fun: BuiltinScalarFunction, - schema: &DFSchema, - planner_context: &mut PlannerContext, - ) -> Result { - let args = vec![self.sql_expr_to_logical_expr(expr, schema, planner_context)?]; - Ok(Expr::ScalarFunction(ScalarFunction::new(fun, args))) - } - pub(super) fn find_window_func( &self, name: &str, @@ -353,17 +346,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { .collect::>>() } - pub(crate) fn check_unnest_args(args: &[Expr], schema: &DFSchema) -> Result<()> { - // Currently only one argument is supported - let arg = match args.len() { - 0 => { - return plan_err!("unnest() requires at least one argument"); - } - 1 => &args[0], - _ => { - return not_impl_err!("unnest() does not support multiple arguments yet"); - } - }; + pub(crate) fn check_unnest_arg(arg: &Expr, schema: &DFSchema) -> Result<()> { // Check argument type, array types are supported match arg.get_type(schema)? { DataType::List(_) diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs index 7763fa2d8dab..f07377ce50e1 100644 --- a/datafusion/sql/src/expr/mod.rs +++ b/datafusion/sql/src/expr/mod.rs @@ -28,9 +28,8 @@ use datafusion_expr::expr::AggregateFunctionDefinition; use datafusion_expr::expr::InList; use datafusion_expr::expr::ScalarFunction; use datafusion_expr::{ - col, expr, lit, AggregateFunction, Between, BinaryExpr, BuiltinScalarFunction, Cast, - Expr, ExprSchemable, GetFieldAccess, GetIndexedField, Like, Literal, Operator, - TryCast, + col, expr, lit, AggregateFunction, Between, BinaryExpr, Cast, Expr, ExprSchemable, + GetFieldAccess, GetIndexedField, Like, Literal, Operator, TryCast, }; use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; @@ -522,12 +521,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { SQLExpr::Ceil { expr, field: _field, - } => self.sql_named_function_to_expr( - *expr, - BuiltinScalarFunction::Ceil, - schema, - planner_context, - ), + } => self.sql_fn_name_to_expr(*expr, "ceil", schema, planner_context), SQLExpr::Overlay { expr, overlay_what, diff --git a/datafusion/sql/src/relation/mod.rs b/datafusion/sql/src/relation/mod.rs index 1e01205ba618..9380e569f2e4 100644 --- a/datafusion/sql/src/relation/mod.rs +++ b/datafusion/sql/src/relation/mod.rs @@ -105,15 +105,24 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // Unnest table factor has empty input let schema = DFSchema::empty(); let input = LogicalPlanBuilder::empty(true).build()?; - let exprs = array_exprs + // Unnest table factor can have multiple arugments. 
+ // We treat each argument as a separate unnest expression. + let unnest_exprs = array_exprs .into_iter() - .map(|expr| { - self.sql_expr_to_logical_expr(expr, &schema, planner_context) + .map(|sql_expr| { + let expr = self.sql_expr_to_logical_expr( + sql_expr, + &schema, + planner_context, + )?; + Self::check_unnest_arg(&expr, &schema)?; + Ok(Expr::Unnest(Unnest::new(expr))) }) .collect::>>()?; - Self::check_unnest_args(&exprs, &schema)?; - let unnest_expr = Expr::Unnest(Unnest { exprs }); - let logical_plan = self.try_process_unnest(input, vec![unnest_expr])?; + if unnest_exprs.is_empty() { + return plan_err!("UNNEST must have at least one argument"); + } + let logical_plan = self.try_process_unnest(input, unnest_exprs)?; (logical_plan, alias) } TableFactor::UNNEST { .. } => { diff --git a/datafusion/sql/src/select.rs b/datafusion/sql/src/select.rs index 1bfd60a8ce1a..30eacdb44c4a 100644 --- a/datafusion/sql/src/select.rs +++ b/datafusion/sql/src/select.rs @@ -294,13 +294,13 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { transformed, tnr: _, } = expr.transform_up_mut(&mut |expr: Expr| { - if let Expr::Unnest(Unnest { ref exprs }) = expr { + if let Expr::Unnest(Unnest { expr: ref arg }) = expr { let column_name = expr.display_name()?; unnest_columns.push(column_name.clone()); // Add alias for the argument expression, to avoid naming conflicts with other expressions // in the select list. For example: `select unnest(col1), col1 from t`. inner_projection_exprs - .push(exprs[0].clone().alias(column_name.clone())); + .push(arg.clone().alias(column_name.clone())); Ok(Transformed::yes(Expr::Column(Column::from_name( column_name, )))) @@ -332,15 +332,12 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { .project(inner_projection_exprs)? .build() } else { - if unnest_columns.len() > 1 { - return not_impl_err!("Only support single unnest expression for now"); - } - let unnest_column = unnest_columns.pop().unwrap(); + let columns = unnest_columns.into_iter().map(|col| col.into()).collect(); // Set preserve_nulls to false to ensure compatibility with DuckDB and PostgreSQL let unnest_options = UnnestOptions::new().with_preserve_nulls(false); LogicalPlanBuilder::from(input) .project(inner_projection_exprs)? - .unnest_column_with_options(unnest_column, unnest_options)? + .unnest_columns_with_options(columns, unnest_options)? .project(outer_projection_exprs)? 
.build()
 }
diff --git a/datafusion/sqllogictest/test_files/unnest.slt b/datafusion/sqllogictest/test_files/unnest.slt
index 5c178bb392b1..38207fa7d1d6 100644
--- a/datafusion/sqllogictest/test_files/unnest.slt
+++ b/datafusion/sqllogictest/test_files/unnest.slt
@@ -22,12 +22,12 @@
 statement ok
 CREATE TABLE unnest_table
 AS VALUES
- ([1,2,3], [7], 1),
- ([4,5], [8,9,10], 2),
- ([6], [11,12], 3),
- ([12], [null, 42, null], null),
+ ([1,2,3], [7], 1, [13, 14]),
+ ([4,5], [8,9,10], 2, [15, 16]),
+ ([6], [11,12], 3, null),
+ ([12], [null, 42, null], null, null),
 -- null array to verify the `preserve_nulls` option
- (null, null, 4)
+ (null, null, 4, [17, 18])
 ;
 
 ## Basic unnest expression in select list
@@ -93,6 +93,20 @@ NULL
 42
 NULL
 
+## Unnest single column and filter out null lists
+query I
+select unnest(column2) from unnest_table where column2 is not null;
+----
+7
+8
+9
+10
+11
+12
+NULL
+42
+NULL
+
 ## Unnest with additional column
 ## Issue: https://github.com/apache/arrow-datafusion/issues/9349
 query II
@@ -135,9 +149,48 @@ select array_remove(column1, 4), unnest(column2), column3 * 10 from unnest_table
 query error DataFusion error: Error during planning: unnest\(\) can only be applied to array, struct and null
 select unnest(column3) from unnest_table;
 
+## Unnest doesn't work with untyped nulls
+query error DataFusion error: This feature is not implemented: unnest\(\) does not support null yet
+select unnest(null) from unnest_table;
+
 ## Multiple unnest functions in selection
-query error DataFusion error: This feature is not implemented: Only support single unnest expression for now
-select unnest(column1), unnest(column2) from unnest_table;
+query ?I
+select unnest([]), unnest(NULL::int[]);
+----
+
+query III
+select
+ unnest(column1),
+ unnest(arrow_cast(column2, 'LargeList(Int64)')),
+ unnest(arrow_cast(column4, 'FixedSizeList(2, Int64)'))
+from unnest_table where column4 is not null;
+----
+1 7 13
+2 NULL 14
+3 NULL NULL
+4 8 15
+5 9 16
+NULL 10 NULL
+NULL NULL 17
+NULL NULL 18
+
+query IIII
+select
+ unnest(column1), unnest(column2) + 2,
+ column3 * 10, unnest(array_remove(column1, '4'))
+from unnest_table;
+----
+1 9 10 1
+2 NULL 10 2
+3 NULL 10 3
+4 10 20 5
+5 11 20 NULL
+NULL 12 20 NULL
+6 13 30 6
+NULL 14 30 NULL
+12 NULL NULL 12
+NULL 44 NULL NULL
+NULL NULL NULL NULL
 
 ## Unnest scalar in select list
 query error DataFusion error: Error during planning: unnest\(\) can only be applied to array, struct and null
 select * from unnest(1);
 
 ## Unnest empty expression in select list
-query error DataFusion error: Error during planning: unnest\(\) requires at least one argument
+query error DataFusion error: Error during planning: unnest\(\) requires exactly one argument
 select unnest();
 
 ## Unnest empty expression in from clause
 query error DataFusion error: SQL error: ParserError\("Expected an expression:, found: \)"\)
 select * from unnest();
@@ -157,13 +210,26 @@ select * from unnest();
 
 
-## Unnest multiple expressions in select list
-query error DataFusion error: This feature is not implemented: unnest\(\) does not support multiple arguments yet
+## Unnest multiple expressions in select list. Multiple unnest arguments are only allowed in a query's FROM clause.
+query error DataFusion error: Error during planning: unnest\(\) requires exactly one argument select unnest([1,2], [2,3]); ## Unnest multiple expressions in from clause -query error DataFusion error: This feature is not implemented: unnest\(\) does not support multiple arguments yet -select * from unnest([1,2], [2,3]); +query ITII +select * from unnest( + [1,2], + arrow_cast(['a','b', 'c'], 'LargeList(Utf8)'), + arrow_cast([4, NULL], 'FixedSizeList(2, Int64)'), + NULL::int[] +) as t(a, b, c, d); +---- +1 a 4 NULL +2 b NULL NULL +NULL c NULL NULL + +query ?I +select * from unnest([], NULL::int[]); +---- ## Unnest struct expression in select list
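For context, below is a minimal end-to-end sketch of the multi-column unnest API this patch introduces. It is illustrative only and not part of the diff: the table `t`, its values, and the tokio runtime are assumptions. It builds the same plan shape the SQL planner now produces for `select unnest(column1), unnest(column2)`, via `LogicalPlanBuilder::unnest_columns_with_options` with `preserve_nulls` disabled, matching the DuckDB/PostgreSQL-compatible choice made in select.rs above.

use datafusion::common::{Column, UnnestOptions};
use datafusion::error::Result;
use datafusion::logical_expr::LogicalPlanBuilder;
use datafusion::prelude::*;

#[tokio::main]
async fn main() -> Result<()> {
    let ctx = SessionContext::new();
    // Hypothetical table with two list columns, shaped like unnest_table above.
    ctx.sql("CREATE TABLE t AS VALUES ([1,2,3], [7]), ([4,5], [8,9,10])")
        .await?;

    // Unnest both list columns in a single Unnest node. With
    // preserve_nulls = false, rows whose lists are exhausted or null
    // are padded with NULL rather than dropped entirely.
    let plan = LogicalPlanBuilder::from(
        ctx.table("t").await?.into_unoptimized_plan(),
    )
    .unnest_columns_with_options(
        vec![Column::from_name("column1"), Column::from_name("column2")],
        UnnestOptions::new().with_preserve_nulls(false),
    )?
    .build()?;

    // Execute the hand-built plan and print the unnested rows.
    ctx.execute_logical_plan(plan).await?.show().await?;
    Ok(())
}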