From ede665b212ec5e3a673986424103645bcb49e3bc Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Sat, 21 Dec 2024 02:19:42 -0500 Subject: [PATCH 1/8] Fix build (#13869) --- datafusion/sql/src/statement.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/datafusion/sql/src/statement.rs b/datafusion/sql/src/statement.rs index e934e5d3ca8c..4fa359ebe00d 100644 --- a/datafusion/sql/src/statement.rs +++ b/datafusion/sql/src/statement.rs @@ -62,8 +62,8 @@ use sqlparser::ast::{ Assignment, AssignmentTarget, ColumnDef, CreateIndex, CreateTable, CreateTableOptions, Delete, DescribeAlias, Expr as SQLExpr, FromTable, Ident, Insert, ObjectName, ObjectType, OneOrManyWithParens, Query, SchemaName, SetExpr, - ShowCreateObject, Statement, TableConstraint, TableFactor, TableWithJoins, - TransactionMode, UnaryOperator, Value, + ShowCreateObject, ShowStatementFilter, Statement, TableConstraint, TableFactor, + TableWithJoins, TransactionMode, UnaryOperator, Value, }; use sqlparser::parser::ParserError::ParserError; From b99400e142a92c6e580fc5364196294c1eb1c91b Mon Sep 17 00:00:00 2001 From: Victor Barua Date: Sat, 21 Dec 2024 06:20:25 -0800 Subject: [PATCH 2/8] feat(substrait): modular substrait consumer (#13803) * feat(substrait): modular substrait consumer * feat(substrait): include Extension Rel handlers in default consumer Include SerializerRegistry based handlers for Extension Relations in the DefaultSubstraitConsumer * refactor(substrait) _selection -> _field_reference * refactor(substrait): remove SubstraitPlannerState usage from consumer * refactor: get_state() -> get_function_registry() * docs: elide imports from example * test: simplify test * refactor: remove Arc from DefaultSubstraitConsumer * doc: add ticket for API improvements * doc: link DefaultSubstraitConsumer to from_subtrait_plan * refactor: remove redundant Extensions parsing --- .../substrait/src/logical_plan/consumer.rs | 2373 ++++++++++------- .../substrait/src/logical_plan/producer.rs | 22 +- .../tests/cases/roundtrip_logical_plan.rs | 3 +- datafusion/substrait/tests/utils.rs | 39 +- 4 files changed, 1442 insertions(+), 995 deletions(-) diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index 9f98fdace6a0..9aa3f008040c 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -21,19 +21,19 @@ use datafusion::arrow::array::{GenericListArray, MapArray}; use datafusion::arrow::datatypes::{ DataType, Field, FieldRef, Fields, IntervalUnit, Schema, TimeUnit, }; -use datafusion::common::plan_err; use datafusion::common::{ - not_impl_err, plan_datafusion_err, substrait_datafusion_err, substrait_err, DFSchema, - DFSchemaRef, + not_impl_datafusion_err, not_impl_err, plan_datafusion_err, plan_err, + substrait_datafusion_err, substrait_err, DFSchema, DFSchemaRef, }; use datafusion::datasource::provider_as_source; use datafusion::logical_expr::expr::{Exists, InSubquery, Sort}; use datafusion::logical_expr::{ - Aggregate, BinaryExpr, Case, EmptyRelation, Expr, ExprSchemable, LogicalPlan, - Operator, Projection, SortExpr, TryCast, Values, + Aggregate, BinaryExpr, Case, Cast, EmptyRelation, Expr, ExprSchemable, Extension, + LogicalPlan, Operator, Projection, SortExpr, Subquery, TryCast, Values, }; use substrait::proto::aggregate_rel::Grouping; +use substrait::proto::expression as substrait_expression; use substrait::proto::expression::subquery::set_predicate::PredicateOp; use substrait::proto::expression_reference::ExprType; use url::Url; @@ -53,14 +53,17 @@ use crate::variation_const::{ TIMESTAMP_MICRO_TYPE_VARIATION_REF, TIMESTAMP_MILLI_TYPE_VARIATION_REF, TIMESTAMP_NANO_TYPE_VARIATION_REF, TIMESTAMP_SECOND_TYPE_VARIATION_REF, }; +use async_trait::async_trait; use datafusion::arrow::array::{new_empty_array, AsArray}; use datafusion::arrow::temporal_conversions::NANOSECONDS; +use datafusion::catalog::TableProvider; use datafusion::common::scalar::ScalarStructBuilder; +use datafusion::execution::{FunctionRegistry, SessionState}; use datafusion::logical_expr::builder::project; use datafusion::logical_expr::expr::InList; use datafusion::logical_expr::{ - col, expr, Cast, Extension, GroupingSet, Like, LogicalPlanBuilder, Partitioning, - Repartition, Subquery, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, + col, expr, GroupingSet, Like, LogicalPlanBuilder, Partitioning, Repartition, + WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, }; use datafusion::prelude::{lit, JoinType}; use datafusion::sql::TableReference; @@ -70,17 +73,21 @@ use datafusion::{ }; use std::collections::HashSet; use std::sync::Arc; +use substrait::proto; use substrait::proto::exchange_rel::ExchangeKind; use substrait::proto::expression::cast::FailureBehavior::ReturnNull; use substrait::proto::expression::literal::user_defined::Val; use substrait::proto::expression::literal::{ interval_day_to_second, IntervalCompound, IntervalDayToSecond, IntervalYearToMonth, - UserDefined, }; use substrait::proto::expression::subquery::SubqueryType; -use substrait::proto::expression::{FieldReference, Literal, ScalarFunction}; +use substrait::proto::expression::{ + Enum, FieldReference, IfThen, Literal, MultiOrList, Nested, ScalarFunction, + SingularOrList, SwitchExpression, WindowFunction, +}; use substrait::proto::read_rel::local_files::file_or_files::PathType::UriFile; use substrait::proto::rel_common::{Emit, EmitKind}; +use substrait::proto::set_rel::SetOp; use substrait::proto::{ aggregate_function::AggregationInvocation, expression::{ @@ -90,17 +97,469 @@ use substrait::proto::{ window_function::bound::Kind as BoundKind, window_function::Bound, window_function::BoundsType, MaskExpression, RexType, }, + fetch_rel, function_argument::ArgType, join_rel, plan_rel, r#type, read_rel::ReadType, rel::RelType, - rel_common, set_rel, + rel_common, sort_field::{SortDirection, SortKind::*}, - AggregateFunction, Expression, NamedStruct, Plan, Rel, RelCommon, Type, + AggregateFunction, AggregateRel, ConsistentPartitionWindowRel, CrossRel, ExchangeRel, + Expression, ExtendedExpression, ExtensionLeafRel, ExtensionMultiRel, + ExtensionSingleRel, FetchRel, FilterRel, FunctionArgument, JoinRel, NamedStruct, + Plan, ProjectRel, ReadRel, Rel, RelCommon, SetRel, SortField, SortRel, Type, }; -use substrait::proto::{fetch_rel, ExtendedExpression, FunctionArgument, SortField}; -use super::state::SubstraitPlanningState; +#[async_trait] +/// This trait is used to consume Substrait plans, converting them into DataFusion Logical Plans. +/// It can be implemented by users to allow for custom handling of relations, expressions, etc. +/// +/// # Example Usage +/// +/// ``` +/// # use async_trait::async_trait; +/// # use datafusion::catalog::TableProvider; +/// # use datafusion::common::{not_impl_err, substrait_err, DFSchema, ScalarValue, TableReference}; +/// # use datafusion::error::Result; +/// # use datafusion::execution::{FunctionRegistry, SessionState}; +/// # use datafusion::logical_expr::{Expr, LogicalPlan, LogicalPlanBuilder}; +/// # use std::sync::Arc; +/// # use substrait::proto; +/// # use substrait::proto::{ExtensionLeafRel, FilterRel, ProjectRel}; +/// # use datafusion::arrow::datatypes::DataType; +/// # use datafusion::logical_expr::expr::ScalarFunction; +/// # use datafusion_substrait::extensions::Extensions; +/// # use datafusion_substrait::logical_plan::consumer::{ +/// # from_project_rel, from_substrait_rel, from_substrait_rex, SubstraitConsumer +/// # }; +/// +/// struct CustomSubstraitConsumer { +/// extensions: Arc, +/// state: Arc, +/// } +/// +/// #[async_trait] +/// impl SubstraitConsumer for CustomSubstraitConsumer { +/// async fn resolve_table_ref( +/// &self, +/// table_ref: &TableReference, +/// ) -> Result>> { +/// let table = table_ref.table().to_string(); +/// let schema = self.state.schema_for_ref(table_ref.clone())?; +/// let table_provider = schema.table(&table).await?; +/// Ok(table_provider) +/// } +/// +/// fn get_extensions(&self) -> &Extensions { +/// self.extensions.as_ref() +/// } +/// +/// fn get_function_registry(&self) -> &impl FunctionRegistry { +/// self.state.as_ref() +/// } +/// +/// // You can reuse existing consumer code to assist in handling advanced extensions +/// async fn consume_project(&self, rel: &ProjectRel) -> Result { +/// let df_plan = from_project_rel(self, rel).await?; +/// if let Some(advanced_extension) = rel.advanced_extension.as_ref() { +/// not_impl_err!( +/// "decode and handle an advanced extension: {:?}", +/// advanced_extension +/// ) +/// } else { +/// Ok(df_plan) +/// } +/// } +/// +/// // You can implement a fully custom consumer method if you need special handling +/// async fn consume_filter(&self, rel: &FilterRel) -> Result { +/// let input = from_substrait_rel(self, rel.input.as_ref().unwrap()).await?; +/// let expression = +/// from_substrait_rex(self, rel.condition.as_ref().unwrap(), input.schema()) +/// .await?; +/// // though this one is quite boring +/// LogicalPlanBuilder::from(input).filter(expression)?.build() +/// } +/// +/// // You can add handlers for extension relations +/// async fn consume_extension_leaf( +/// &self, +/// rel: &ExtensionLeafRel, +/// ) -> Result { +/// not_impl_err!( +/// "handle protobuf Any {} as you need", +/// rel.detail.as_ref().unwrap().type_url +/// ) +/// } +/// +/// // and handlers for user-define types +/// fn consume_user_defined_type(&self, typ: &proto::r#type::UserDefined) -> Result { +/// let type_string = self.extensions.types.get(&typ.type_reference).unwrap(); +/// match type_string.as_str() { +/// "u!foo" => not_impl_err!("handle foo conversion"), +/// "u!bar" => not_impl_err!("handle bar conversion"), +/// _ => substrait_err!("unexpected type") +/// } +/// } +/// +/// // and user-defined literals +/// fn consume_user_defined_literal(&self, literal: &proto::expression::literal::UserDefined) -> Result { +/// let type_string = self.extensions.types.get(&literal.type_reference).unwrap(); +/// match type_string.as_str() { +/// "u!foo" => not_impl_err!("handle foo conversion"), +/// "u!bar" => not_impl_err!("handle bar conversion"), +/// _ => substrait_err!("unexpected type") +/// } +/// } +/// } +/// ``` +/// +pub trait SubstraitConsumer: Send + Sync + Sized { + async fn resolve_table_ref( + &self, + table_ref: &TableReference, + ) -> Result>>; + + // TODO: Remove these two methods + // Ideally, the abstract consumer should not place any constraints on implementations. + // The functionality for which the Extensions and FunctionRegistry is needed should be abstracted + // out into methods on the trait. As an example, resolve_table_reference is such a method. + // See: https://github.com/apache/datafusion/issues/13863 + fn get_extensions(&self) -> &Extensions; + fn get_function_registry(&self) -> &impl FunctionRegistry; + + // Relation Methods + // There is one method per Substrait relation to allow for easy overriding of consumer behaviour. + // These methods have default implementations calling the common handler code, to allow for users + // to re-use common handling logic. + + async fn consume_read(&self, rel: &ReadRel) -> Result { + from_read_rel(self, rel).await + } + + async fn consume_filter(&self, rel: &FilterRel) -> Result { + from_filter_rel(self, rel).await + } + + async fn consume_fetch(&self, rel: &FetchRel) -> Result { + from_fetch_rel(self, rel).await + } + + async fn consume_aggregate(&self, rel: &AggregateRel) -> Result { + from_aggregate_rel(self, rel).await + } + + async fn consume_sort(&self, rel: &SortRel) -> Result { + from_sort_rel(self, rel).await + } + + async fn consume_join(&self, rel: &JoinRel) -> Result { + from_join_rel(self, rel).await + } + + async fn consume_project(&self, rel: &ProjectRel) -> Result { + from_project_rel(self, rel).await + } + + async fn consume_set(&self, rel: &SetRel) -> Result { + from_set_rel(self, rel).await + } + + async fn consume_cross(&self, rel: &CrossRel) -> Result { + from_cross_rel(self, rel).await + } + + async fn consume_consistent_partition_window( + &self, + _rel: &ConsistentPartitionWindowRel, + ) -> Result { + not_impl_err!("Consistent Partition Window Rel not supported") + } + + async fn consume_exchange(&self, rel: &ExchangeRel) -> Result { + from_exchange_rel(self, rel).await + } + + // Expression Methods + // There is one method per Substrait expression to allow for easy overriding of consumer behaviour + // These methods have default implementations calling the common handler code, to allow for users + // to re-use common handling logic. + + async fn consume_literal(&self, expr: &Literal) -> Result { + from_literal(self, expr).await + } + + async fn consume_field_reference( + &self, + expr: &FieldReference, + input_schema: &DFSchema, + ) -> Result { + from_field_reference(self, expr, input_schema).await + } + + async fn consume_scalar_function( + &self, + expr: &ScalarFunction, + input_schema: &DFSchema, + ) -> Result { + from_scalar_function(self, expr, input_schema).await + } + + async fn consume_window_function( + &self, + expr: &WindowFunction, + input_schema: &DFSchema, + ) -> Result { + from_window_function(self, expr, input_schema).await + } + + async fn consume_if_then( + &self, + expr: &IfThen, + input_schema: &DFSchema, + ) -> Result { + from_if_then(self, expr, input_schema).await + } + + async fn consume_switch( + &self, + _expr: &SwitchExpression, + _input_schema: &DFSchema, + ) -> Result { + not_impl_err!("Switch expression not supported") + } + + async fn consume_singular_or_list( + &self, + expr: &SingularOrList, + input_schema: &DFSchema, + ) -> Result { + from_singular_or_list(self, expr, input_schema).await + } + + async fn consume_multi_or_list( + &self, + _expr: &MultiOrList, + _input_schema: &DFSchema, + ) -> Result { + not_impl_err!("Multi Or List expression not supported") + } + + async fn consume_cast( + &self, + expr: &substrait_expression::Cast, + input_schema: &DFSchema, + ) -> Result { + from_cast(self, expr, input_schema).await + } + + async fn consume_subquery( + &self, + expr: &substrait_expression::Subquery, + input_schema: &DFSchema, + ) -> Result { + from_subquery(self, expr, input_schema).await + } + + async fn consume_nested( + &self, + _expr: &Nested, + _input_schema: &DFSchema, + ) -> Result { + not_impl_err!("Nested expression not supported") + } + + async fn consume_enum(&self, _expr: &Enum, _input_schema: &DFSchema) -> Result { + not_impl_err!("Enum expression not supported") + } + + // User-Defined Functionality + + // The details of extension relations, and how to handle them, are fully up to users to specify. + // The following methods allow users to customize the consumer behaviour + + async fn consume_extension_leaf( + &self, + rel: &ExtensionLeafRel, + ) -> Result { + if let Some(detail) = rel.detail.as_ref() { + return substrait_err!( + "Missing handler for ExtensionLeafRel: {}", + detail.type_url + ); + } + substrait_err!("Missing handler for ExtensionLeafRel") + } + + async fn consume_extension_single( + &self, + rel: &ExtensionSingleRel, + ) -> Result { + if let Some(detail) = rel.detail.as_ref() { + return substrait_err!( + "Missing handler for ExtensionSingleRel: {}", + detail.type_url + ); + } + substrait_err!("Missing handler for ExtensionSingleRel") + } + + async fn consume_extension_multi( + &self, + rel: &ExtensionMultiRel, + ) -> Result { + if let Some(detail) = rel.detail.as_ref() { + return substrait_err!( + "Missing handler for ExtensionMultiRel: {}", + detail.type_url + ); + } + substrait_err!("Missing handler for ExtensionMultiRel") + } + + // Users can bring their own types to Substrait which require custom handling + + fn consume_user_defined_type( + &self, + user_defined_type: &r#type::UserDefined, + ) -> Result { + substrait_err!( + "Missing handler for user-defined type: {}", + user_defined_type.type_reference + ) + } + + fn consume_user_defined_literal( + &self, + user_defined_literal: &proto::expression::literal::UserDefined, + ) -> Result { + substrait_err!( + "Missing handler for user-defined literals {}", + user_defined_literal.type_reference + ) + } +} + +/// Convert Substrait Rel to DataFusion DataFrame +#[async_recursion] +pub async fn from_substrait_rel( + consumer: &impl SubstraitConsumer, + relation: &Rel, +) -> Result { + let plan: Result = match &relation.rel_type { + Some(rel_type) => match rel_type { + RelType::Read(rel) => consumer.consume_read(rel).await, + RelType::Filter(rel) => consumer.consume_filter(rel).await, + RelType::Fetch(rel) => consumer.consume_fetch(rel).await, + RelType::Aggregate(rel) => consumer.consume_aggregate(rel).await, + RelType::Sort(rel) => consumer.consume_sort(rel).await, + RelType::Join(rel) => consumer.consume_join(rel).await, + RelType::Project(rel) => consumer.consume_project(rel).await, + RelType::Set(rel) => consumer.consume_set(rel).await, + RelType::ExtensionSingle(rel) => consumer.consume_extension_single(rel).await, + RelType::ExtensionMulti(rel) => consumer.consume_extension_multi(rel).await, + RelType::ExtensionLeaf(rel) => consumer.consume_extension_leaf(rel).await, + RelType::Cross(rel) => consumer.consume_cross(rel).await, + RelType::Window(rel) => { + consumer.consume_consistent_partition_window(rel).await + } + RelType::Exchange(rel) => consumer.consume_exchange(rel).await, + rt => not_impl_err!("{rt:?} rel not supported yet"), + }, + None => return substrait_err!("rel must set rel_type"), + }; + apply_emit_kind(retrieve_rel_common(relation), plan?) +} + +/// Default SubstraitConsumer for converting standard Substrait without user-defined extensions. +/// +/// Used as the consumer in [from_substrait_plan] +pub struct DefaultSubstraitConsumer<'a> { + extensions: &'a Extensions, + state: &'a SessionState, +} + +impl<'a> DefaultSubstraitConsumer<'a> { + pub fn new(extensions: &'a Extensions, state: &'a SessionState) -> Self { + DefaultSubstraitConsumer { extensions, state } + } +} + +#[async_trait] +impl SubstraitConsumer for DefaultSubstraitConsumer<'_> { + async fn resolve_table_ref( + &self, + table_ref: &TableReference, + ) -> Result>> { + let table = table_ref.table().to_string(); + let schema = self.state.schema_for_ref(table_ref.clone())?; + let table_provider = schema.table(&table).await?; + Ok(table_provider) + } + + fn get_extensions(&self) -> &Extensions { + self.extensions + } + + fn get_function_registry(&self) -> &impl FunctionRegistry { + self.state + } + + async fn consume_extension_leaf( + &self, + rel: &ExtensionLeafRel, + ) -> Result { + let Some(ext_detail) = &rel.detail else { + return substrait_err!("Unexpected empty detail in ExtensionLeafRel"); + }; + let plan = self + .state + .serializer_registry() + .deserialize_logical_plan(&ext_detail.type_url, &ext_detail.value)?; + Ok(LogicalPlan::Extension(Extension { node: plan })) + } + + async fn consume_extension_single( + &self, + rel: &ExtensionSingleRel, + ) -> Result { + let Some(ext_detail) = &rel.detail else { + return substrait_err!("Unexpected empty detail in ExtensionSingleRel"); + }; + let plan = self + .state + .serializer_registry() + .deserialize_logical_plan(&ext_detail.type_url, &ext_detail.value)?; + let Some(input_rel) = &rel.input else { + return substrait_err!( + "ExtensionSingleRel missing input rel, try using ExtensionLeafRel instead" + ); + }; + let input_plan = from_substrait_rel(self, input_rel).await?; + let plan = plan.with_exprs_and_inputs(plan.expressions(), vec![input_plan])?; + Ok(LogicalPlan::Extension(Extension { node: plan })) + } + + async fn consume_extension_multi( + &self, + rel: &ExtensionMultiRel, + ) -> Result { + let Some(ext_detail) = &rel.detail else { + return substrait_err!("Unexpected empty detail in ExtensionMultiRel"); + }; + let plan = self + .state + .serializer_registry() + .deserialize_logical_plan(&ext_detail.type_url, &ext_detail.value)?; + let mut inputs = Vec::with_capacity(rel.inputs.len()); + for input in &rel.inputs { + let input_plan = from_substrait_rel(self, input).await?; + inputs.push(input_plan); + } + let plan = plan.with_exprs_and_inputs(plan.expressions(), inputs)?; + Ok(LogicalPlan::Extension(Extension { node: plan })) + } +} // Substrait PrecisionTimestampTz indicates that the timestamp is relative to UTC, which // is the same as the expectation for any non-empty timezone in DF, so any non-empty timezone @@ -202,16 +661,15 @@ fn split_eq_and_noneq_join_predicate_with_nulls_equality( } async fn union_rels( + consumer: &impl SubstraitConsumer, rels: &[Rel], - state: &dyn SubstraitPlanningState, - extensions: &Extensions, is_all: bool, ) -> Result { let mut union_builder = Ok(LogicalPlanBuilder::from( - from_substrait_rel(state, &rels[0], extensions).await?, + from_substrait_rel(consumer, &rels[0]).await?, )); for input in &rels[1..] { - let rel_plan = from_substrait_rel(state, input, extensions).await?; + let rel_plan = from_substrait_rel(consumer, input).await?; union_builder = if is_all { union_builder?.union(rel_plan) @@ -223,17 +681,16 @@ async fn union_rels( } async fn intersect_rels( + consumer: &impl SubstraitConsumer, rels: &[Rel], - state: &dyn SubstraitPlanningState, - extensions: &Extensions, is_all: bool, ) -> Result { - let mut rel = from_substrait_rel(state, &rels[0], extensions).await?; + let mut rel = from_substrait_rel(consumer, &rels[0]).await?; for input in &rels[1..] { rel = LogicalPlanBuilder::intersect( rel, - from_substrait_rel(state, input, extensions).await?, + from_substrait_rel(consumer, input).await?, is_all, )? } @@ -242,17 +699,16 @@ async fn intersect_rels( } async fn except_rels( + consumer: &impl SubstraitConsumer, rels: &[Rel], - state: &dyn SubstraitPlanningState, - extensions: &Extensions, is_all: bool, ) -> Result { - let mut rel = from_substrait_rel(state, &rels[0], extensions).await?; + let mut rel = from_substrait_rel(consumer, &rels[0]).await?; for input in &rels[1..] { rel = LogicalPlanBuilder::except( rel, - from_substrait_rel(state, input, extensions).await?, + from_substrait_rel(consumer, input).await?, is_all, )? } @@ -262,7 +718,7 @@ async fn except_rels( /// Convert Substrait Plan to DataFusion LogicalPlan pub async fn from_substrait_plan( - state: &dyn SubstraitPlanningState, + state: &SessionState, plan: &Plan, ) -> Result { // Register function extension @@ -271,16 +727,27 @@ pub async fn from_substrait_plan( return not_impl_err!("Type variation extensions are not supported"); } - // Parse relations + let consumer = DefaultSubstraitConsumer { + extensions: &extensions, + state, + }; + from_substrait_plan_with_consumer(&consumer, plan).await +} + +/// Convert Substrait Plan to DataFusion LogicalPlan using the given consumer +pub async fn from_substrait_plan_with_consumer( + consumer: &impl SubstraitConsumer, + plan: &Plan, +) -> Result { match plan.relations.len() { 1 => { match plan.relations[0].rel_type.as_ref() { Some(rt) => match rt { plan_rel::RelType::Rel(rel) => { - Ok(from_substrait_rel(state, rel, &extensions).await?) + Ok(from_substrait_rel(consumer, rel).await?) }, plan_rel::RelType::Root(root) => { - let plan = from_substrait_rel(state, root.input.as_ref().unwrap(), &extensions).await?; + let plan = from_substrait_rel(consumer, root.input.as_ref().unwrap()).await?; if root.names.is_empty() { // Backwards compatibility for plans missing names return Ok(plan); @@ -341,7 +808,7 @@ pub struct ExprContainer { /// between systems. This is often useful for scenarios like pushdown where filter /// expressions need to be sent to remote systems. pub async fn from_substrait_extended_expr( - state: &dyn SubstraitPlanningState, + state: &SessionState, extended_expr: &ExtendedExpression, ) -> Result { // Register function extension @@ -350,8 +817,13 @@ pub async fn from_substrait_extended_expr( return not_impl_err!("Type variation extensions are not supported"); } + let consumer = DefaultSubstraitConsumer { + extensions: &extensions, + state, + }; + let input_schema = DFSchemaRef::new(match &extended_expr.base_schema { - Some(base_schema) => from_substrait_named_struct(base_schema, &extensions), + Some(base_schema) => from_substrait_named_struct(&consumer, base_schema), None => { plan_err!("required property `base_schema` missing from Substrait ExtendedExpression message") } @@ -369,8 +841,7 @@ pub async fn from_substrait_extended_expr( plan_err!("required property `expr_type` missing from Substrait ExpressionReference message") } }?; - let expr = - from_substrait_rex(state, scalar_expr, &input_schema, &extensions).await?; + let expr = from_substrait_rex(&consumer, scalar_expr, &input_schema).await?; let (output_type, expected_nullability) = expr.data_type_and_nullable(&input_schema)?; let output_field = Field::new("", output_type, expected_nullability); @@ -557,583 +1028,498 @@ fn make_renamed_schema( ) } -/// Convert Substrait Rel to DataFusion DataFrame -#[allow(deprecated)] #[async_recursion] -pub async fn from_substrait_rel( - state: &dyn SubstraitPlanningState, - rel: &Rel, - extensions: &Extensions, +pub async fn from_project_rel( + consumer: &impl SubstraitConsumer, + p: &ProjectRel, ) -> Result { - let plan: Result = match &rel.rel_type { - Some(RelType::Project(p)) => { - if let Some(input) = p.input.as_ref() { - let mut input = LogicalPlanBuilder::from( - from_substrait_rel(state, input, extensions).await?, - ); - let original_schema = input.schema().clone(); - - // Ensure that all expressions have a unique display name, so that - // validate_unique_names does not fail when constructing the project. - let mut name_tracker = NameTracker::new(); - - // By default, a Substrait Project emits all inputs fields followed by all expressions. - // We build the explicit expressions first, and then the input expressions to avoid - // adding aliases to the explicit expressions (as part of ensuring unique names). - // - // This is helpful for plan visualization and tests, because when DataFusion produces - // Substrait Projects it adds an output mapping that excludes all input columns - // leaving only explicit expressions. - - let mut explicit_exprs: Vec = vec![]; - for expr in &p.expressions { - let e = from_substrait_rex( - state, - expr, - input.clone().schema(), - extensions, - ) - .await?; - // if the expression is WindowFunction, wrap in a Window relation - if let Expr::WindowFunction(_) = &e { - // Adding the same expression here and in the project below - // works because the project's builder uses columnize_expr(..) - // to transform it into a column reference - input = input.window(vec![e.clone()])? - } - explicit_exprs.push(name_tracker.get_uniquely_named_expr(e)?); - } + if let Some(input) = p.input.as_ref() { + let mut input = + LogicalPlanBuilder::from(from_substrait_rel(consumer, input).await?); + let original_schema = input.schema().clone(); + + // Ensure that all expressions have a unique display name, so that + // validate_unique_names does not fail when constructing the project. + let mut name_tracker = NameTracker::new(); + + // By default, a Substrait Project emits all inputs fields followed by all expressions. + // We build the explicit expressions first, and then the input expressions to avoid + // adding aliases to the explicit expressions (as part of ensuring unique names). + // + // This is helpful for plan visualization and tests, because when DataFusion produces + // Substrait Projects it adds an output mapping that excludes all input columns + // leaving only explicit expressions. + + let mut explicit_exprs: Vec = vec![]; + for expr in &p.expressions { + let e = from_substrait_rex(consumer, expr, input.clone().schema()).await?; + // if the expression is WindowFunction, wrap in a Window relation + if let Expr::WindowFunction(_) = &e { + // Adding the same expression here and in the project below + // works because the project's builder uses columnize_expr(..) + // to transform it into a column reference + input = input.window(vec![e.clone()])? + } + explicit_exprs.push(name_tracker.get_uniquely_named_expr(e)?); + } - let mut final_exprs: Vec = vec![]; - for index in 0..original_schema.fields().len() { - let e = Expr::Column(Column::from( - original_schema.qualified_field(index), - )); - final_exprs.push(name_tracker.get_uniquely_named_expr(e)?); - } - final_exprs.append(&mut explicit_exprs); + let mut final_exprs: Vec = vec![]; + for index in 0..original_schema.fields().len() { + let e = Expr::Column(Column::from(original_schema.qualified_field(index))); + final_exprs.push(name_tracker.get_uniquely_named_expr(e)?); + } + final_exprs.append(&mut explicit_exprs); + input.project(final_exprs)?.build() + } else { + not_impl_err!("Projection without an input is not supported") + } +} - input.project(final_exprs)?.build() - } else { - not_impl_err!("Projection without an input is not supported") - } +#[async_recursion] +pub async fn from_filter_rel( + consumer: &impl SubstraitConsumer, + filter: &FilterRel, +) -> Result { + if let Some(input) = filter.input.as_ref() { + let input = LogicalPlanBuilder::from(from_substrait_rel(consumer, input).await?); + if let Some(condition) = filter.condition.as_ref() { + let expr = from_substrait_rex(consumer, condition, input.schema()).await?; + input.filter(expr)?.build() + } else { + not_impl_err!("Filter without an condition is not valid") } - Some(RelType::Filter(filter)) => { - if let Some(input) = filter.input.as_ref() { - let input = LogicalPlanBuilder::from( - from_substrait_rel(state, input, extensions).await?, - ); - if let Some(condition) = filter.condition.as_ref() { - let expr = - from_substrait_rex(state, condition, input.schema(), extensions) - .await?; - input.filter(expr)?.build() - } else { - not_impl_err!("Filter without an condition is not valid") - } - } else { - not_impl_err!("Filter without an input is not valid") + } else { + not_impl_err!("Filter without an input is not valid") + } +} + +#[async_recursion] +pub async fn from_fetch_rel( + consumer: &impl SubstraitConsumer, + fetch: &FetchRel, +) -> Result { + if let Some(input) = fetch.input.as_ref() { + let input = LogicalPlanBuilder::from(from_substrait_rel(consumer, input).await?); + let empty_schema = DFSchemaRef::new(DFSchema::empty()); + let offset = match &fetch.offset_mode { + Some(fetch_rel::OffsetMode::Offset(offset)) => Some(lit(*offset)), + Some(fetch_rel::OffsetMode::OffsetExpr(expr)) => { + Some(from_substrait_rex(consumer, expr, &empty_schema).await?) } - } - Some(RelType::Fetch(fetch)) => { - if let Some(input) = fetch.input.as_ref() { - let input = LogicalPlanBuilder::from( - from_substrait_rel(state, input, extensions).await?, - ); - let empty_schema = DFSchemaRef::new(DFSchema::empty()); - let offset = match &fetch.offset_mode { - Some(fetch_rel::OffsetMode::Offset(offset)) => Some(lit(*offset)), - Some(fetch_rel::OffsetMode::OffsetExpr(expr)) => Some( - from_substrait_rex(state, expr, &empty_schema, extensions) - .await?, - ), - None => None, - }; - let count = match &fetch.count_mode { - Some(fetch_rel::CountMode::Count(count)) => { - // -1 means that ALL records should be returned, equivalent to None - (*count != -1).then(|| lit(*count)) - } - Some(fetch_rel::CountMode::CountExpr(expr)) => Some( - from_substrait_rex(state, expr, &empty_schema, extensions) - .await?, - ), - None => None, - }; - input.limit_by_expr(offset, count)?.build() - } else { - not_impl_err!("Fetch without an input is not valid") + None => None, + }; + let count = match &fetch.count_mode { + Some(fetch_rel::CountMode::Count(count)) => { + // -1 means that ALL records should be returned, equivalent to None + (*count != -1).then(|| lit(*count)) } - } - Some(RelType::Sort(sort)) => { - if let Some(input) = sort.input.as_ref() { - let input = LogicalPlanBuilder::from( - from_substrait_rel(state, input, extensions).await?, - ); - let sorts = - from_substrait_sorts(state, &sort.sorts, input.schema(), extensions) - .await?; - input.sort(sorts)?.build() - } else { - not_impl_err!("Sort without an input is not valid") + Some(fetch_rel::CountMode::CountExpr(expr)) => { + Some(from_substrait_rex(consumer, expr, &empty_schema).await?) } + None => None, + }; + input.limit_by_expr(offset, count)?.build() + } else { + not_impl_err!("Fetch without an input is not valid") + } +} + +pub async fn from_sort_rel( + consumer: &impl SubstraitConsumer, + sort: &SortRel, +) -> Result { + if let Some(input) = sort.input.as_ref() { + let input = LogicalPlanBuilder::from(from_substrait_rel(consumer, input).await?); + let sorts = from_substrait_sorts(consumer, &sort.sorts, input.schema()).await?; + input.sort(sorts)?.build() + } else { + not_impl_err!("Sort without an input is not valid") + } +} + +pub async fn from_aggregate_rel( + consumer: &impl SubstraitConsumer, + agg: &AggregateRel, +) -> Result { + if let Some(input) = agg.input.as_ref() { + let input = LogicalPlanBuilder::from(from_substrait_rel(consumer, input).await?); + let mut ref_group_exprs = vec![]; + + for e in &agg.grouping_expressions { + let x = from_substrait_rex(consumer, e, input.schema()).await?; + ref_group_exprs.push(x); } - Some(RelType::Aggregate(agg)) => { - if let Some(input) = agg.input.as_ref() { - let input = LogicalPlanBuilder::from( - from_substrait_rel(state, input, extensions).await?, - ); - let mut ref_group_exprs = vec![]; - for e in &agg.grouping_expressions { - let x = - from_substrait_rex(state, e, input.schema(), extensions).await?; - ref_group_exprs.push(x); + let mut group_exprs = vec![]; + let mut aggr_exprs = vec![]; + + match agg.groupings.len() { + 1 => { + group_exprs.extend_from_slice( + &from_substrait_grouping( + consumer, + &agg.groupings[0], + &ref_group_exprs, + input.schema(), + ) + .await?, + ); + } + _ => { + let mut grouping_sets = vec![]; + for grouping in &agg.groupings { + let grouping_set = from_substrait_grouping( + consumer, + grouping, + &ref_group_exprs, + input.schema(), + ) + .await?; + grouping_sets.push(grouping_set); } + // Single-element grouping expression of type Expr::GroupingSet. + // Note that GroupingSet::Rollup would become GroupingSet::GroupingSets, when + // parsed by the producer and consumer, since Substrait does not have a type dedicated + // to ROLLUP. Only vector of Groupings (grouping sets) is available. + group_exprs + .push(Expr::GroupingSet(GroupingSet::GroupingSets(grouping_sets))); + } + }; - let mut group_exprs = vec![]; - let mut aggr_exprs = vec![]; - - match agg.groupings.len() { - 1 => { - group_exprs.extend_from_slice( - &from_substrait_grouping( - state, - &agg.groupings[0], - &ref_group_exprs, - input.schema(), - extensions, - ) - .await?, - ); - } - _ => { - let mut grouping_sets = vec![]; - for grouping in &agg.groupings { - let grouping_set = from_substrait_grouping( - state, - grouping, - &ref_group_exprs, - input.schema(), - extensions, - ) - .await?; - grouping_sets.push(grouping_set); + for m in &agg.measures { + let filter = match &m.filter { + Some(fil) => Some(Box::new( + from_substrait_rex(consumer, fil, input.schema()).await?, + )), + None => None, + }; + let agg_func = match &m.measure { + Some(f) => { + let distinct = match f.invocation { + _ if f.invocation == AggregationInvocation::Distinct as i32 => { + true } - // Single-element grouping expression of type Expr::GroupingSet. - // Note that GroupingSet::Rollup would become GroupingSet::GroupingSets, when - // parsed by the producer and consumer, since Substrait does not have a type dedicated - // to ROLLUP. Only vector of Groupings (grouping sets) is available. - group_exprs.push(Expr::GroupingSet(GroupingSet::GroupingSets( - grouping_sets, - ))); - } - }; - - for m in &agg.measures { - let filter = match &m.filter { - Some(fil) => Some(Box::new( - from_substrait_rex(state, fil, input.schema(), extensions) - .await?, - )), - None => None, + _ if f.invocation == AggregationInvocation::All as i32 => false, + _ => false, }; - let agg_func = match &m.measure { - Some(f) => { - let distinct = match f.invocation { - _ if f.invocation - == AggregationInvocation::Distinct as i32 => - { - true - } - _ if f.invocation - == AggregationInvocation::All as i32 => - { - false - } - _ => false, - }; - let order_by = if !f.sorts.is_empty() { - Some( - from_substrait_sorts( - state, - &f.sorts, - input.schema(), - extensions, - ) - .await?, - ) - } else { - None - }; - - from_substrait_agg_func( - state, - f, - input.schema(), - extensions, - filter, - order_by, - distinct, - ) - .await - } - None => not_impl_err!( - "Aggregate without aggregate function is not supported" - ), + let order_by = if !f.sorts.is_empty() { + Some( + from_substrait_sorts(consumer, &f.sorts, input.schema()) + .await?, + ) + } else { + None }; - aggr_exprs.push(agg_func?.as_ref().clone()); - } - input.aggregate(group_exprs, aggr_exprs)?.build() - } else { - not_impl_err!("Aggregate without an input is not valid") - } - } - Some(RelType::Join(join)) => { - if join.post_join_filter.is_some() { - return not_impl_err!( - "JoinRel with post_join_filter is not yet supported" - ); - } - let left: LogicalPlanBuilder = LogicalPlanBuilder::from( - from_substrait_rel(state, join.left.as_ref().unwrap(), extensions) - .await?, - ); - let right = LogicalPlanBuilder::from( - from_substrait_rel(state, join.right.as_ref().unwrap(), extensions) - .await?, - ); - let (left, right) = requalify_sides_if_needed(left, right)?; - - let join_type = from_substrait_jointype(join.r#type)?; - // The join condition expression needs full input schema and not the output schema from join since we lose columns from - // certain join types such as semi and anti joins - let in_join_schema = left.schema().join(right.schema())?; - - // If join expression exists, parse the `on` condition expression, build join and return - // Otherwise, build join with only the filter, without join keys - match &join.expression.as_ref() { - Some(expr) => { - let on = from_substrait_rex(state, expr, &in_join_schema, extensions) - .await?; - // The join expression can contain both equal and non-equal ops. - // As of datafusion 31.0.0, the equal and non equal join conditions are in separate fields. - // So we extract each part as follows: - // - If an Eq or IsNotDistinctFrom op is encountered, add the left column, right column and is_null_equal_nulls to `join_ons` vector - // - Otherwise we add the expression to join_filter (use conjunction if filter already exists) - let (join_ons, nulls_equal_nulls, join_filter) = - split_eq_and_noneq_join_predicate_with_nulls_equality(&on); - let (left_cols, right_cols): (Vec<_>, Vec<_>) = - itertools::multiunzip(join_ons); - left.join_detailed( - right.build()?, - join_type, - (left_cols, right_cols), - join_filter, - nulls_equal_nulls, - )? - .build() + from_substrait_agg_func( + consumer, + f, + input.schema(), + filter, + order_by, + distinct, + ) + .await } None => { - let on: Vec = vec![]; - left.join_detailed( - right.build()?, - join_type, - (on.clone(), on), - None, - false, - )? - .build() + not_impl_err!("Aggregate without aggregate function is not supported") } - } + }; + aggr_exprs.push(agg_func?.as_ref().clone()); } - Some(RelType::Cross(cross)) => { - let left = LogicalPlanBuilder::from( - from_substrait_rel(state, cross.left.as_ref().unwrap(), extensions) - .await?, - ); - let right = LogicalPlanBuilder::from( - from_substrait_rel(state, cross.right.as_ref().unwrap(), extensions) - .await?, - ); - let (left, right) = requalify_sides_if_needed(left, right)?; - left.cross_join(right.build()?)?.build() + input.aggregate(group_exprs, aggr_exprs)?.build() + } else { + not_impl_err!("Aggregate without an input is not valid") + } +} + +pub async fn from_join_rel( + consumer: &impl SubstraitConsumer, + join: &JoinRel, +) -> Result { + if join.post_join_filter.is_some() { + return not_impl_err!("JoinRel with post_join_filter is not yet supported"); + } + + let left: LogicalPlanBuilder = LogicalPlanBuilder::from( + from_substrait_rel(consumer, join.left.as_ref().unwrap()).await?, + ); + let right = LogicalPlanBuilder::from( + from_substrait_rel(consumer, join.right.as_ref().unwrap()).await?, + ); + let (left, right) = requalify_sides_if_needed(left, right)?; + + let join_type = from_substrait_jointype(join.r#type)?; + // The join condition expression needs full input schema and not the output schema from join since we lose columns from + // certain join types such as semi and anti joins + let in_join_schema = left.schema().join(right.schema())?; + + // If join expression exists, parse the `on` condition expression, build join and return + // Otherwise, build join with only the filter, without join keys + match &join.expression.as_ref() { + Some(expr) => { + let on = from_substrait_rex(consumer, expr, &in_join_schema).await?; + // The join expression can contain both equal and non-equal ops. + // As of datafusion 31.0.0, the equal and non equal join conditions are in separate fields. + // So we extract each part as follows: + // - If an Eq or IsNotDistinctFrom op is encountered, add the left column, right column and is_null_equal_nulls to `join_ons` vector + // - Otherwise we add the expression to join_filter (use conjunction if filter already exists) + let (join_ons, nulls_equal_nulls, join_filter) = + split_eq_and_noneq_join_predicate_with_nulls_equality(&on); + let (left_cols, right_cols): (Vec<_>, Vec<_>) = + itertools::multiunzip(join_ons); + left.join_detailed( + right.build()?, + join_type, + (left_cols, right_cols), + join_filter, + nulls_equal_nulls, + )? + .build() } - Some(RelType::Read(read)) => { - async fn read_with_schema( - state: &dyn SubstraitPlanningState, - table_ref: TableReference, - schema: DFSchema, - projection: &Option, - ) -> Result { - let schema = schema.replace_qualifier(table_ref.clone()); - - let plan = { - let provider = match state.table(&table_ref).await? { - Some(ref provider) => Arc::clone(provider), - _ => return plan_err!("No table named '{table_ref}'"), - }; + None => { + let on: Vec = vec![]; + left.join_detailed(right.build()?, join_type, (on.clone(), on), None, false)? + .build() + } + } +} - LogicalPlanBuilder::scan( - table_ref, - provider_as_source(Arc::clone(&provider)), - None, - )? - .build()? - }; +pub async fn from_cross_rel( + consumer: &impl SubstraitConsumer, + cross: &CrossRel, +) -> Result { + let left = LogicalPlanBuilder::from( + from_substrait_rel(consumer, cross.left.as_ref().unwrap()).await?, + ); + let right = LogicalPlanBuilder::from( + from_substrait_rel(consumer, cross.right.as_ref().unwrap()).await?, + ); + let (left, right) = requalify_sides_if_needed(left, right)?; + left.cross_join(right.build()?)?.build() +} + +#[allow(deprecated)] +pub async fn from_read_rel( + consumer: &impl SubstraitConsumer, + read: &ReadRel, +) -> Result { + async fn read_with_schema( + consumer: &impl SubstraitConsumer, + table_ref: TableReference, + schema: DFSchema, + projection: &Option, + ) -> Result { + let schema = schema.replace_qualifier(table_ref.clone()); + + let plan = { + let provider = match consumer.resolve_table_ref(&table_ref).await? { + Some(ref provider) => Arc::clone(provider), + _ => return plan_err!("No table named '{table_ref}'"), + }; - ensure_schema_compatability(plan.schema(), schema.clone())?; + LogicalPlanBuilder::scan( + table_ref, + provider_as_source(Arc::clone(&provider)), + None, + )? + .build()? + }; - let schema = apply_masking(schema, projection)?; + ensure_schema_compatability(plan.schema(), schema.clone())?; - apply_projection(plan, schema) - } + let schema = apply_masking(schema, projection)?; - let named_struct = read.base_schema.as_ref().ok_or_else(|| { - substrait_datafusion_err!("No base schema provided for Read Relation") - })?; + apply_projection(plan, schema) + } - let substrait_schema = from_substrait_named_struct(named_struct, extensions)?; + let named_struct = read.base_schema.as_ref().ok_or_else(|| { + substrait_datafusion_err!("No base schema provided for Read Relation") + })?; - match &read.as_ref().read_type { - Some(ReadType::NamedTable(nt)) => { - let table_reference = match nt.names.len() { - 0 => { - return plan_err!("No table name found in NamedTable"); - } - 1 => TableReference::Bare { - table: nt.names[0].clone().into(), - }, - 2 => TableReference::Partial { - schema: nt.names[0].clone().into(), - table: nt.names[1].clone().into(), - }, - _ => TableReference::Full { - catalog: nt.names[0].clone().into(), - schema: nt.names[1].clone().into(), - table: nt.names[2].clone().into(), - }, - }; + let substrait_schema = from_substrait_named_struct(consumer, named_struct)?; - read_with_schema( - state, - table_reference, - substrait_schema, - &read.projection, - ) - .await + match &read.read_type { + Some(ReadType::NamedTable(nt)) => { + let table_reference = match nt.names.len() { + 0 => { + return plan_err!("No table name found in NamedTable"); } - Some(ReadType::VirtualTable(vt)) => { - if vt.values.is_empty() { - return Ok(LogicalPlan::EmptyRelation(EmptyRelation { - produce_one_row: false, - schema: DFSchemaRef::new(substrait_schema), - })); - } + 1 => TableReference::Bare { + table: nt.names[0].clone().into(), + }, + 2 => TableReference::Partial { + schema: nt.names[0].clone().into(), + table: nt.names[1].clone().into(), + }, + _ => TableReference::Full { + catalog: nt.names[0].clone().into(), + schema: nt.names[1].clone().into(), + table: nt.names[2].clone().into(), + }, + }; - let values = vt - .values - .iter() - .map(|row| { - let mut name_idx = 0; - let lits = row - .fields - .iter() - .map(|lit| { - name_idx += 1; // top-level names are provided through schema - Ok(Expr::Literal(from_substrait_literal( - lit, - extensions, - &named_struct.names, - &mut name_idx, - )?)) - }) - .collect::>()?; - if name_idx != named_struct.names.len() { - return substrait_err!( + read_with_schema( + consumer, + table_reference, + substrait_schema, + &read.projection, + ) + .await + } + Some(ReadType::VirtualTable(vt)) => { + if vt.values.is_empty() { + return Ok(LogicalPlan::EmptyRelation(EmptyRelation { + produce_one_row: false, + schema: DFSchemaRef::new(substrait_schema), + })); + } + + let values = vt + .values + .iter() + .map(|row| { + let mut name_idx = 0; + let lits = row + .fields + .iter() + .map(|lit| { + name_idx += 1; // top-level names are provided through schema + Ok(Expr::Literal(from_substrait_literal( + consumer, + lit, + &named_struct.names, + &mut name_idx, + )?)) + }) + .collect::>()?; + if name_idx != named_struct.names.len() { + return substrait_err!( "Names list must match exactly to nested schema, but found {} uses for {} names", name_idx, named_struct.names.len() ); - } - Ok(lits) - }) - .collect::>()?; + } + Ok(lits) + }) + .collect::>()?; - Ok(LogicalPlan::Values(Values { - schema: DFSchemaRef::new(substrait_schema), - values, - })) - } - Some(ReadType::LocalFiles(lf)) => { - fn extract_filename(name: &str) -> Option { - let corrected_url = if name.starts_with("file://") - && !name.starts_with("file:///") - { - name.replacen("file://", "file:///", 1) - } else { - name.to_string() - }; + Ok(LogicalPlan::Values(Values { + schema: DFSchemaRef::new(substrait_schema), + values, + })) + } + Some(ReadType::LocalFiles(lf)) => { + fn extract_filename(name: &str) -> Option { + let corrected_url = + if name.starts_with("file://") && !name.starts_with("file:///") { + name.replacen("file://", "file:///", 1) + } else { + name.to_string() + }; - Url::parse(&corrected_url).ok().and_then(|url| { - let path = url.path(); - std::path::Path::new(path) - .file_name() - .map(|filename| filename.to_string_lossy().to_string()) - }) - } + Url::parse(&corrected_url).ok().and_then(|url| { + let path = url.path(); + std::path::Path::new(path) + .file_name() + .map(|filename| filename.to_string_lossy().to_string()) + }) + } - // we could use the file name to check the original table provider - // TODO: currently does not support multiple local files - let filename: Option = - lf.items.first().and_then(|x| match x.path_type.as_ref() { - Some(UriFile(name)) => extract_filename(name), - _ => None, - }); + // we could use the file name to check the original table provider + // TODO: currently does not support multiple local files + let filename: Option = + lf.items.first().and_then(|x| match x.path_type.as_ref() { + Some(UriFile(name)) => extract_filename(name), + _ => None, + }); - if lf.items.len() > 1 || filename.is_none() { - return not_impl_err!("Only single file reads are supported"); - } - let name = filename.unwrap(); - // directly use unwrap here since we could determine it is a valid one - let table_reference = TableReference::Bare { table: name.into() }; - - read_with_schema( - state, - table_reference, - substrait_schema, - &read.projection, - ) - .await - } - _ => { - not_impl_err!("Unsupported ReadType: {:?}", &read.as_ref().read_type) - } - } - } - Some(RelType::Set(set)) => match set_rel::SetOp::try_from(set.op) { - Ok(set_op) => { - if set.inputs.len() < 2 { - substrait_err!("Set operation requires at least two inputs") - } else { - match set_op { - set_rel::SetOp::UnionAll => { - union_rels(&set.inputs, state, extensions, true).await - } - set_rel::SetOp::UnionDistinct => { - union_rels(&set.inputs, state, extensions, false).await - } - set_rel::SetOp::IntersectionPrimary => { - LogicalPlanBuilder::intersect( - from_substrait_rel(state, &set.inputs[0], extensions) - .await?, - union_rels(&set.inputs[1..], state, extensions, true) - .await?, - false, - ) - } - set_rel::SetOp::IntersectionMultiset => { - intersect_rels(&set.inputs, state, extensions, false).await - } - set_rel::SetOp::IntersectionMultisetAll => { - intersect_rels(&set.inputs, state, extensions, true).await - } - set_rel::SetOp::MinusPrimary => { - except_rels(&set.inputs, state, extensions, false).await - } - set_rel::SetOp::MinusPrimaryAll => { - except_rels(&set.inputs, state, extensions, true).await - } - _ => not_impl_err!("Unsupported set operator: {set_op:?}"), - } - } + if lf.items.len() > 1 || filename.is_none() { + return not_impl_err!("Only single file reads are supported"); } - Err(e) => not_impl_err!("Invalid set operation type {}: {e}", set.op), - }, - Some(RelType::ExtensionLeaf(extension)) => { - let Some(ext_detail) = &extension.detail else { - return substrait_err!("Unexpected empty detail in ExtensionLeafRel"); - }; - let plan = state - .serializer_registry() - .deserialize_logical_plan(&ext_detail.type_url, &ext_detail.value)?; - Ok(LogicalPlan::Extension(Extension { node: plan })) + let name = filename.unwrap(); + // directly use unwrap here since we could determine it is a valid one + let table_reference = TableReference::Bare { table: name.into() }; + + read_with_schema( + consumer, + table_reference, + substrait_schema, + &read.projection, + ) + .await } - Some(RelType::ExtensionSingle(extension)) => { - let Some(ext_detail) = &extension.detail else { - return substrait_err!("Unexpected empty detail in ExtensionSingleRel"); - }; - let plan = state - .serializer_registry() - .deserialize_logical_plan(&ext_detail.type_url, &ext_detail.value)?; - let Some(input_rel) = &extension.input else { - return substrait_err!( - "ExtensionSingleRel doesn't contains input rel. Try use ExtensionLeafRel instead" - ); - }; - let input_plan = from_substrait_rel(state, input_rel, extensions).await?; - let plan = - plan.with_exprs_and_inputs(plan.expressions(), vec![input_plan])?; - Ok(LogicalPlan::Extension(Extension { node: plan })) + _ => { + not_impl_err!("Unsupported ReadType: {:?}", read.read_type) } - Some(RelType::ExtensionMulti(extension)) => { - let Some(ext_detail) = &extension.detail else { - return substrait_err!("Unexpected empty detail in ExtensionSingleRel"); - }; - let plan = state - .serializer_registry() - .deserialize_logical_plan(&ext_detail.type_url, &ext_detail.value)?; - let mut inputs = Vec::with_capacity(extension.inputs.len()); - for input in &extension.inputs { - let input_plan = from_substrait_rel(state, input, extensions).await?; - inputs.push(input_plan); - } - let plan = plan.with_exprs_and_inputs(plan.expressions(), inputs)?; - Ok(LogicalPlan::Extension(Extension { node: plan })) + } +} + +pub async fn from_set_rel( + consumer: &impl SubstraitConsumer, + set: &SetRel, +) -> Result { + if set.inputs.len() < 2 { + substrait_err!("Set operation requires at least two inputs") + } else { + match set.op() { + SetOp::UnionAll => union_rels(consumer, &set.inputs, true).await, + SetOp::UnionDistinct => union_rels(consumer, &set.inputs, false).await, + SetOp::IntersectionPrimary => LogicalPlanBuilder::intersect( + from_substrait_rel(consumer, &set.inputs[0]).await?, + union_rels(consumer, &set.inputs[1..], true).await?, + false, + ), + SetOp::IntersectionMultiset => { + intersect_rels(consumer, &set.inputs, false).await + } + SetOp::IntersectionMultisetAll => { + intersect_rels(consumer, &set.inputs, true).await + } + SetOp::MinusPrimary => except_rels(consumer, &set.inputs, false).await, + SetOp::MinusPrimaryAll => except_rels(consumer, &set.inputs, true).await, + set_op => not_impl_err!("Unsupported set operator: {set_op:?}"), } - Some(RelType::Exchange(exchange)) => { - let Some(input) = exchange.input.as_ref() else { - return substrait_err!("Unexpected empty input in ExchangeRel"); - }; - let input = Arc::new(from_substrait_rel(state, input, extensions).await?); + } +} - let Some(exchange_kind) = &exchange.exchange_kind else { - return substrait_err!("Unexpected empty input in ExchangeRel"); - }; +pub async fn from_exchange_rel( + consumer: &impl SubstraitConsumer, + exchange: &ExchangeRel, +) -> Result { + let Some(input) = exchange.input.as_ref() else { + return substrait_err!("Unexpected empty input in ExchangeRel"); + }; + let input = Arc::new(from_substrait_rel(consumer, input).await?); - // ref: https://substrait.io/relations/physical_relations/#exchange-types - let partitioning_scheme = match exchange_kind { - ExchangeKind::ScatterByFields(scatter_fields) => { - let mut partition_columns = vec![]; - let input_schema = input.schema(); - for field_ref in &scatter_fields.fields { - let column = - from_substrait_field_reference(field_ref, input_schema)?; - partition_columns.push(column); - } - Partitioning::Hash( - partition_columns, - exchange.partition_count as usize, - ) - } - ExchangeKind::RoundRobin(_) => { - Partitioning::RoundRobinBatch(exchange.partition_count as usize) - } - ExchangeKind::SingleTarget(_) - | ExchangeKind::MultiTarget(_) - | ExchangeKind::Broadcast(_) => { - return not_impl_err!("Unsupported exchange kind: {exchange_kind:?}"); - } - }; - Ok(LogicalPlan::Repartition(Repartition { - input, - partitioning_scheme, - })) + let Some(exchange_kind) = &exchange.exchange_kind else { + return substrait_err!("Unexpected empty input in ExchangeRel"); + }; + + // ref: https://substrait.io/relations/physical_relations/#exchange-types + let partitioning_scheme = match exchange_kind { + ExchangeKind::ScatterByFields(scatter_fields) => { + let mut partition_columns = vec![]; + let input_schema = input.schema(); + for field_ref in &scatter_fields.fields { + let column = from_substrait_field_reference(field_ref, input_schema)?; + partition_columns.push(column); + } + Partitioning::Hash(partition_columns, exchange.partition_count as usize) + } + ExchangeKind::RoundRobin(_) => { + Partitioning::RoundRobinBatch(exchange.partition_count as usize) + } + ExchangeKind::SingleTarget(_) + | ExchangeKind::MultiTarget(_) + | ExchangeKind::Broadcast(_) => { + return not_impl_err!("Unsupported exchange kind: {exchange_kind:?}"); } - _ => not_impl_err!("Unsupported RelType: {:?}", rel.rel_type), }; - apply_emit_kind(retrieve_rel_common(rel), plan?) + Ok(LogicalPlan::Repartition(Repartition { + input, + partitioning_scheme, + })) } fn retrieve_rel_common(rel: &Rel) -> Option<&RelCommon> { @@ -1384,7 +1770,7 @@ fn compatible_nullabilities( } /// (Re)qualify the sides of a join if needed, i.e. if the columns from one side would otherwise -/// conflict with the columns from the other. +/// conflict with the columns from the other. /// Substrait doesn't currently allow specifying aliases, neither for columns nor for tables. For /// Substrait the names don't matter since it only refers to columns by indices, however DataFusion /// requires columns to be uniquely identifiable, in some places (see e.g. DFSchema::check_names). @@ -1430,16 +1816,14 @@ fn from_substrait_jointype(join_type: i32) -> Result { /// Convert Substrait Sorts to DataFusion Exprs pub async fn from_substrait_sorts( - state: &dyn SubstraitPlanningState, + consumer: &impl SubstraitConsumer, substrait_sorts: &Vec, input_schema: &DFSchema, - extensions: &Extensions, ) -> Result> { let mut sorts: Vec = vec![]; for s in substrait_sorts { let expr = - from_substrait_rex(state, s.expr.as_ref().unwrap(), input_schema, extensions) - .await?; + from_substrait_rex(consumer, s.expr.as_ref().unwrap(), input_schema).await?; let asc_nullfirst = match &s.sort_kind { Some(k) => match k { Direction(d) => { @@ -1480,15 +1864,13 @@ pub async fn from_substrait_sorts( /// Convert Substrait Expressions to DataFusion Exprs pub async fn from_substrait_rex_vec( - state: &dyn SubstraitPlanningState, + consumer: &impl SubstraitConsumer, exprs: &Vec, input_schema: &DFSchema, - extensions: &Extensions, ) -> Result> { let mut expressions: Vec = vec![]; for expr in exprs { - let expression = - from_substrait_rex(state, expr, input_schema, extensions).await?; + let expression = from_substrait_rex(consumer, expr, input_schema).await?; expressions.push(expression); } Ok(expressions) @@ -1496,16 +1878,15 @@ pub async fn from_substrait_rex_vec( /// Convert Substrait FunctionArguments to DataFusion Exprs pub async fn from_substrait_func_args( - state: &dyn SubstraitPlanningState, + consumer: &impl SubstraitConsumer, arguments: &Vec, input_schema: &DFSchema, - extensions: &Extensions, ) -> Result> { let mut args: Vec = vec![]; for arg in arguments { let arg_expr = match &arg.arg_type { Some(ArgType::Value(e)) => { - from_substrait_rex(state, e, input_schema, extensions).await + from_substrait_rex(consumer, e, input_schema).await } _ => not_impl_err!("Function argument non-Value type not supported"), }; @@ -1516,370 +1897,416 @@ pub async fn from_substrait_func_args( /// Convert Substrait AggregateFunction to DataFusion Expr pub async fn from_substrait_agg_func( - state: &dyn SubstraitPlanningState, + consumer: &impl SubstraitConsumer, f: &AggregateFunction, input_schema: &DFSchema, - extensions: &Extensions, filter: Option>, order_by: Option>, distinct: bool, ) -> Result> { - let args = - from_substrait_func_args(state, &f.arguments, input_schema, extensions).await?; - - let Some(function_name) = extensions.functions.get(&f.function_reference) else { + let Some(fn_signature) = consumer + .get_extensions() + .functions + .get(&f.function_reference) + else { return plan_err!( "Aggregate function not registered: function anchor = {:?}", f.function_reference ); }; - let function_name = substrait_fun_name(function_name); - // try udaf first, then built-in aggr fn. - if let Ok(fun) = state.udaf(function_name) { - // deal with situation that count(*) got no arguments - let args = if fun.name() == "count" && args.is_empty() { - vec![Expr::Literal(ScalarValue::Int64(Some(1)))] - } else { - args - }; - - Ok(Arc::new(Expr::AggregateFunction( - expr::AggregateFunction::new_udf(fun, args, distinct, filter, order_by, None), - ))) - } else { - not_impl_err!( + let fn_name = substrait_fun_name(fn_signature); + let udaf = consumer.get_function_registry().udaf(fn_name); + let udaf = udaf.map_err(|_| { + not_impl_datafusion_err!( "Aggregate function {} is not supported: function anchor = {:?}", - function_name, + fn_signature, f.function_reference ) - } + })?; + + let args = from_substrait_func_args(consumer, &f.arguments, input_schema).await?; + + // deal with situation that count(*) got no arguments + let args = if udaf.name() == "count" && args.is_empty() { + vec![Expr::Literal(ScalarValue::Int64(Some(1)))] + } else { + args + }; + + Ok(Arc::new(Expr::AggregateFunction( + expr::AggregateFunction::new_udf(udaf, args, distinct, filter, order_by, None), + ))) } /// Convert Substrait Rex to DataFusion Expr -#[async_recursion] pub async fn from_substrait_rex( - state: &dyn SubstraitPlanningState, - e: &Expression, + consumer: &impl SubstraitConsumer, + expression: &Expression, input_schema: &DFSchema, - extensions: &Extensions, ) -> Result { - match &e.rex_type { - Some(RexType::SingularOrList(s)) => { - let substrait_expr = s.value.as_ref().unwrap(); - let substrait_list = s.options.as_ref(); - Ok(Expr::InList(InList { - expr: Box::new( - from_substrait_rex(state, substrait_expr, input_schema, extensions) - .await?, - ), - list: from_substrait_rex_vec( - state, - substrait_list, - input_schema, - extensions, - ) - .await?, - negated: false, - })) - } - Some(RexType::Selection(field_ref)) => { - Ok(from_substrait_field_reference(field_ref, input_schema)?) - } - Some(RexType::IfThen(if_then)) => { - // Parse `ifs` - // If the first element does not have a `then` part, then we can assume it's a base expression - let mut when_then_expr: Vec<(Box, Box)> = vec![]; - let mut expr = None; - for (i, if_expr) in if_then.ifs.iter().enumerate() { - if i == 0 { - // Check if the first element is type base expression - if if_expr.then.is_none() { - expr = Some(Box::new( - from_substrait_rex( - state, - if_expr.r#if.as_ref().unwrap(), - input_schema, - extensions, - ) - .await?, - )); - continue; - } - } - when_then_expr.push(( - Box::new( - from_substrait_rex( - state, - if_expr.r#if.as_ref().unwrap(), - input_schema, - extensions, - ) - .await?, - ), - Box::new( - from_substrait_rex( - state, - if_expr.then.as_ref().unwrap(), - input_schema, - extensions, - ) - .await?, - ), - )); + match &expression.rex_type { + Some(t) => match t { + RexType::Literal(expr) => consumer.consume_literal(expr).await, + RexType::Selection(expr) => { + consumer.consume_field_reference(expr, input_schema).await + } + RexType::ScalarFunction(expr) => { + consumer.consume_scalar_function(expr, input_schema).await + } + RexType::WindowFunction(expr) => { + consumer.consume_window_function(expr, input_schema).await + } + RexType::IfThen(expr) => consumer.consume_if_then(expr, input_schema).await, + RexType::SwitchExpression(expr) => { + consumer.consume_switch(expr, input_schema).await + } + RexType::SingularOrList(expr) => { + consumer.consume_singular_or_list(expr, input_schema).await } - // Parse `else` - let else_expr = match &if_then.r#else { - Some(e) => Some(Box::new( - from_substrait_rex(state, e, input_schema, extensions).await?, - )), - None => None, - }; - Ok(Expr::Case(Case { - expr, - when_then_expr, - else_expr, - })) - } - Some(RexType::ScalarFunction(f)) => { - let Some(fn_name) = extensions.functions.get(&f.function_reference) else { - return plan_err!( - "Scalar function not found: function reference = {:?}", - f.function_reference - ); - }; - let fn_name = substrait_fun_name(fn_name); - let args = - from_substrait_func_args(state, &f.arguments, input_schema, extensions) - .await?; + RexType::MultiOrList(expr) => { + consumer.consume_multi_or_list(expr, input_schema).await + } - // try to first match the requested function into registered udfs, then built-in ops - // and finally built-in expressions - if let Ok(func) = state.udf(fn_name) { - Ok(Expr::ScalarFunction(expr::ScalarFunction::new_udf( - func.to_owned(), - args, - ))) - } else if let Some(op) = name_to_op(fn_name) { - if f.arguments.len() < 2 { - return not_impl_err!( - "Expect at least two arguments for binary operator {op:?}, the provided number of operators is {:?}", - f.arguments.len() - ); - } - // Some expressions are binary in DataFusion but take in a variadic number of args in Substrait. - // In those cases we iterate through all the arguments, applying the binary expression against them all - let combined_expr = args - .into_iter() - .fold(None, |combined_expr: Option, arg: Expr| { - Some(match combined_expr { - Some(expr) => Expr::BinaryExpr(BinaryExpr { - left: Box::new(expr), - op, - right: Box::new(arg), - }), - None => arg, - }) - }) - .unwrap(); + RexType::Cast(expr) => { + consumer.consume_cast(expr.as_ref(), input_schema).await + } - Ok(combined_expr) - } else if let Some(builder) = BuiltinExprBuilder::try_from_name(fn_name) { - builder.build(state, f, input_schema, extensions).await - } else { - not_impl_err!("Unsupported function name: {fn_name:?}") + RexType::Subquery(expr) => { + consumer.consume_subquery(expr.as_ref(), input_schema).await } - } - Some(RexType::Literal(lit)) => { - let scalar_value = from_substrait_literal_without_names(lit, extensions)?; - Ok(Expr::Literal(scalar_value)) - } - Some(RexType::Cast(cast)) => match cast.as_ref().r#type.as_ref() { - Some(output_type) => { - let input_expr = Box::new( + RexType::Nested(expr) => consumer.consume_nested(expr, input_schema).await, + RexType::Enum(expr) => consumer.consume_enum(expr, input_schema).await, + }, + None => substrait_err!("Expression must set rex_type: {:?}", expression), + } +} + +pub async fn from_singular_or_list( + consumer: &impl SubstraitConsumer, + expr: &SingularOrList, + input_schema: &DFSchema, +) -> Result { + let substrait_expr = expr.value.as_ref().unwrap(); + let substrait_list = expr.options.as_ref(); + Ok(Expr::InList(InList { + expr: Box::new(from_substrait_rex(consumer, substrait_expr, input_schema).await?), + list: from_substrait_rex_vec(consumer, substrait_list, input_schema).await?, + negated: false, + })) +} + +pub async fn from_field_reference( + _consumer: &impl SubstraitConsumer, + field_ref: &FieldReference, + input_schema: &DFSchema, +) -> Result { + from_substrait_field_reference(field_ref, input_schema) +} + +pub async fn from_if_then( + consumer: &impl SubstraitConsumer, + if_then: &IfThen, + input_schema: &DFSchema, +) -> Result { + // Parse `ifs` + // If the first element does not have a `then` part, then we can assume it's a base expression + let mut when_then_expr: Vec<(Box, Box)> = vec![]; + let mut expr = None; + for (i, if_expr) in if_then.ifs.iter().enumerate() { + if i == 0 { + // Check if the first element is type base expression + if if_expr.then.is_none() { + expr = Some(Box::new( from_substrait_rex( - state, - cast.as_ref().input.as_ref().unwrap().as_ref(), + consumer, + if_expr.r#if.as_ref().unwrap(), input_schema, - extensions, ) .await?, - ); - let data_type = - from_substrait_type_without_names(output_type, extensions)?; - if cast.failure_behavior() == ReturnNull { - Ok(Expr::TryCast(TryCast::new(input_expr, data_type))) - } else { - Ok(Expr::Cast(Cast::new(input_expr, data_type))) - } + )); + continue; } - None => substrait_err!("Cast expression without output type is not allowed"), - }, - Some(RexType::WindowFunction(window)) => { - let Some(fn_name) = extensions.functions.get(&window.function_reference) - else { - return plan_err!( - "Window function not found: function reference = {:?}", - window.function_reference - ); - }; - let fn_name = substrait_fun_name(fn_name); - - // check udwf first, then udaf, then built-in window and aggregate functions - let fun = if let Ok(udwf) = state.udwf(fn_name) { - Ok(WindowFunctionDefinition::WindowUDF(udwf)) - } else if let Ok(udaf) = state.udaf(fn_name) { - Ok(WindowFunctionDefinition::AggregateUDF(udaf)) - } else { - not_impl_err!( - "Window function {} is not supported: function anchor = {:?}", - fn_name, - window.function_reference + } + when_then_expr.push(( + Box::new( + from_substrait_rex( + consumer, + if_expr.r#if.as_ref().unwrap(), + input_schema, ) - }?; - - let order_by = - from_substrait_sorts(state, &window.sorts, input_schema, extensions) - .await?; - - let bound_units = - match BoundsType::try_from(window.bounds_type).map_err(|e| { - plan_datafusion_err!("Invalid bound type {}: {e}", window.bounds_type) - })? { - BoundsType::Rows => WindowFrameUnits::Rows, - BoundsType::Range => WindowFrameUnits::Range, - BoundsType::Unspecified => { - // If the plan does not specify the bounds type, then we use a simple logic to determine the units - // If there is no `ORDER BY`, then by default, the frame counts each row from the lower up to upper boundary - // If there is `ORDER BY`, then by default, each frame is a range starting from unbounded preceding to current row - if order_by.is_empty() { - WindowFrameUnits::Rows - } else { - WindowFrameUnits::Range - } - } - }; - Ok(Expr::WindowFunction(expr::WindowFunction { - fun, - args: from_substrait_func_args( - state, - &window.arguments, + .await?, + ), + Box::new( + from_substrait_rex( + consumer, + if_expr.then.as_ref().unwrap(), input_schema, - extensions, ) .await?, - partition_by: from_substrait_rex_vec( - state, - &window.partitions, + ), + )); + } + // Parse `else` + let else_expr = match &if_then.r#else { + Some(e) => Some(Box::new( + from_substrait_rex(consumer, e, input_schema).await?, + )), + None => None, + }; + Ok(Expr::Case(Case { + expr, + when_then_expr, + else_expr, + })) +} + +pub async fn from_scalar_function( + consumer: &impl SubstraitConsumer, + f: &ScalarFunction, + input_schema: &DFSchema, +) -> Result { + let Some(fn_signature) = consumer + .get_extensions() + .functions + .get(&f.function_reference) + else { + return plan_err!( + "Scalar function not found: function reference = {:?}", + f.function_reference + ); + }; + let fn_name = substrait_fun_name(fn_signature); + let args = from_substrait_func_args(consumer, &f.arguments, input_schema).await?; + + // try to first match the requested function into registered udfs, then built-in ops + // and finally built-in expressions + if let Ok(func) = consumer.get_function_registry().udf(fn_name) { + Ok(Expr::ScalarFunction(expr::ScalarFunction::new_udf( + func.to_owned(), + args, + ))) + } else if let Some(op) = name_to_op(fn_name) { + if f.arguments.len() < 2 { + return not_impl_err!( + "Expect at least two arguments for binary operator {op:?}, the provided number of operators is {:?}", + f.arguments.len() + ); + } + // Some expressions are binary in DataFusion but take in a variadic number of args in Substrait. + // In those cases we iterate through all the arguments, applying the binary expression against them all + let combined_expr = args + .into_iter() + .fold(None, |combined_expr: Option, arg: Expr| { + Some(match combined_expr { + Some(expr) => Expr::BinaryExpr(BinaryExpr { + left: Box::new(expr), + op, + right: Box::new(arg), + }), + None => arg, + }) + }) + .unwrap(); + + Ok(combined_expr) + } else if let Some(builder) = BuiltinExprBuilder::try_from_name(fn_name) { + builder.build(consumer, f, input_schema).await + } else { + not_impl_err!("Unsupported function name: {fn_name:?}") + } +} + +pub async fn from_literal( + consumer: &impl SubstraitConsumer, + expr: &Literal, +) -> Result { + let scalar_value = from_substrait_literal_without_names(consumer, expr)?; + Ok(Expr::Literal(scalar_value)) +} + +pub async fn from_cast( + consumer: &impl SubstraitConsumer, + cast: &substrait_expression::Cast, + input_schema: &DFSchema, +) -> Result { + match cast.r#type.as_ref() { + Some(output_type) => { + let input_expr = Box::new( + from_substrait_rex( + consumer, + cast.input.as_ref().unwrap().as_ref(), input_schema, - extensions, ) .await?, - order_by, - window_frame: datafusion::logical_expr::WindowFrame::new_bounds( - bound_units, - from_substrait_bound(&window.lower_bound, true)?, - from_substrait_bound(&window.upper_bound, false)?, - ), - null_treatment: None, - })) + ); + let data_type = from_substrait_type_without_names(consumer, output_type)?; + if cast.failure_behavior() == ReturnNull { + Ok(Expr::TryCast(TryCast::new(input_expr, data_type))) + } else { + Ok(Expr::Cast(Cast::new(input_expr, data_type))) + } } - Some(RexType::Subquery(subquery)) => match &subquery.as_ref().subquery_type { - Some(subquery_type) => match subquery_type { - SubqueryType::InPredicate(in_predicate) => { - if in_predicate.needles.len() != 1 { - substrait_err!("InPredicate Subquery type must have exactly one Needle expression") - } else { - let needle_expr = &in_predicate.needles[0]; - let haystack_expr = &in_predicate.haystack; - if let Some(haystack_expr) = haystack_expr { - let haystack_expr = - from_substrait_rel(state, haystack_expr, extensions) - .await?; - let outer_refs = haystack_expr.all_out_ref_exprs(); - Ok(Expr::InSubquery(InSubquery { - expr: Box::new( - from_substrait_rex( - state, - needle_expr, - input_schema, - extensions, - ) + None => substrait_err!("Cast expression without output type is not allowed"), + } +} + +pub async fn from_window_function( + consumer: &impl SubstraitConsumer, + window: &WindowFunction, + input_schema: &DFSchema, +) -> Result { + let Some(fn_signature) = consumer + .get_extensions() + .functions + .get(&window.function_reference) + else { + return plan_err!( + "Window function not found: function reference = {:?}", + window.function_reference + ); + }; + let fn_name = substrait_fun_name(fn_signature); + + // check udwf first, then udaf, then built-in window and aggregate functions + let fun = if let Ok(udwf) = consumer.get_function_registry().udwf(fn_name) { + Ok(WindowFunctionDefinition::WindowUDF(udwf)) + } else if let Ok(udaf) = consumer.get_function_registry().udaf(fn_name) { + Ok(WindowFunctionDefinition::AggregateUDF(udaf)) + } else { + not_impl_err!( + "Window function {} is not supported: function anchor = {:?}", + fn_name, + window.function_reference + ) + }?; + + let order_by = from_substrait_sorts(consumer, &window.sorts, input_schema).await?; + + let bound_units = match BoundsType::try_from(window.bounds_type).map_err(|e| { + plan_datafusion_err!("Invalid bound type {}: {e}", window.bounds_type) + })? { + BoundsType::Rows => WindowFrameUnits::Rows, + BoundsType::Range => WindowFrameUnits::Range, + BoundsType::Unspecified => { + // If the plan does not specify the bounds type, then we use a simple logic to determine the units + // If there is no `ORDER BY`, then by default, the frame counts each row from the lower up to upper boundary + // If there is `ORDER BY`, then by default, each frame is a range starting from unbounded preceding to current row + if order_by.is_empty() { + WindowFrameUnits::Rows + } else { + WindowFrameUnits::Range + } + } + }; + Ok(Expr::WindowFunction(expr::WindowFunction { + fun, + args: from_substrait_func_args(consumer, &window.arguments, input_schema).await?, + partition_by: from_substrait_rex_vec(consumer, &window.partitions, input_schema) + .await?, + order_by, + window_frame: datafusion::logical_expr::WindowFrame::new_bounds( + bound_units, + from_substrait_bound(&window.lower_bound, true)?, + from_substrait_bound(&window.upper_bound, false)?, + ), + null_treatment: None, + })) +} + +pub async fn from_subquery( + consumer: &impl SubstraitConsumer, + subquery: &substrait_expression::Subquery, + input_schema: &DFSchema, +) -> Result { + match &subquery.subquery_type { + Some(subquery_type) => match subquery_type { + SubqueryType::InPredicate(in_predicate) => { + if in_predicate.needles.len() != 1 { + substrait_err!("InPredicate Subquery type must have exactly one Needle expression") + } else { + let needle_expr = &in_predicate.needles[0]; + let haystack_expr = &in_predicate.haystack; + if let Some(haystack_expr) = haystack_expr { + let haystack_expr = + from_substrait_rel(consumer, haystack_expr).await?; + let outer_refs = haystack_expr.all_out_ref_exprs(); + Ok(Expr::InSubquery(InSubquery { + expr: Box::new( + from_substrait_rex(consumer, needle_expr, input_schema) .await?, - ), - subquery: Subquery { - subquery: Arc::new(haystack_expr), - outer_ref_columns: outer_refs, - }, - negated: false, - })) - } else { - substrait_err!("InPredicate Subquery type must have a Haystack expression") - } + ), + subquery: Subquery { + subquery: Arc::new(haystack_expr), + outer_ref_columns: outer_refs, + }, + negated: false, + })) + } else { + substrait_err!( + "InPredicate Subquery type must have a Haystack expression" + ) } } - SubqueryType::Scalar(query) => { - let plan = from_substrait_rel( - state, - &(query.input.clone()).unwrap_or_default(), - extensions, - ) - .await?; - let outer_ref_columns = plan.all_out_ref_exprs(); - Ok(Expr::ScalarSubquery(Subquery { - subquery: Arc::new(plan), - outer_ref_columns, - })) - } - SubqueryType::SetPredicate(predicate) => { - match predicate.predicate_op() { - // exist - PredicateOp::Exists => { - let relation = &predicate.tuples; - let plan = from_substrait_rel( - state, - &relation.clone().unwrap_or_default(), - extensions, - ) - .await?; - let outer_ref_columns = plan.all_out_ref_exprs(); - Ok(Expr::Exists(Exists::new( - Subquery { - subquery: Arc::new(plan), - outer_ref_columns, - }, - false, - ))) - } - other_type => substrait_err!( - "unimplemented type {:?} for set predicate", - other_type - ), + } + SubqueryType::Scalar(query) => { + let plan = from_substrait_rel( + consumer, + &(query.input.clone()).unwrap_or_default(), + ) + .await?; + let outer_ref_columns = plan.all_out_ref_exprs(); + Ok(Expr::ScalarSubquery(Subquery { + subquery: Arc::new(plan), + outer_ref_columns, + })) + } + SubqueryType::SetPredicate(predicate) => { + match predicate.predicate_op() { + // exist + PredicateOp::Exists => { + let relation = &predicate.tuples; + let plan = from_substrait_rel( + consumer, + &relation.clone().unwrap_or_default(), + ) + .await?; + let outer_ref_columns = plan.all_out_ref_exprs(); + Ok(Expr::Exists(Exists::new( + Subquery { + subquery: Arc::new(plan), + outer_ref_columns, + }, + false, + ))) } + other_type => substrait_err!( + "unimplemented type {:?} for set predicate", + other_type + ), } - other_type => { - substrait_err!("Subquery type {:?} not implemented", other_type) - } - }, - None => { - substrait_err!("Subquery expression without SubqueryType is not allowed") + } + other_type => { + substrait_err!("Subquery type {:?} not implemented", other_type) } }, - _ => not_impl_err!("unsupported rex_type"), + None => { + substrait_err!("Subquery expression without SubqueryType is not allowed") + } } } pub(crate) fn from_substrait_type_without_names( + consumer: &impl SubstraitConsumer, dt: &Type, - extensions: &Extensions, ) -> Result { - from_substrait_type(dt, extensions, &[], &mut 0) + from_substrait_type(consumer, dt, &[], &mut 0) } fn from_substrait_type( + consumer: &impl SubstraitConsumer, dt: &Type, - extensions: &Extensions, dfs_names: &[String], name_idx: &mut usize, ) -> Result { @@ -1992,7 +2419,7 @@ fn from_substrait_type( substrait_datafusion_err!("List type must have inner type") })?; let field = Arc::new(Field::new_list_field( - from_substrait_type(inner_type, extensions, dfs_names, name_idx)?, + from_substrait_type(consumer, inner_type, dfs_names, name_idx)?, // We ignore Substrait's nullability here to match to_substrait_literal // which always creates nullable lists true, @@ -2014,12 +2441,12 @@ fn from_substrait_type( })?; let key_field = Arc::new(Field::new( "key", - from_substrait_type(key_type, extensions, dfs_names, name_idx)?, + from_substrait_type(consumer, key_type, dfs_names, name_idx)?, false, )); let value_field = Arc::new(Field::new( "value", - from_substrait_type(value_type, extensions, dfs_names, name_idx)?, + from_substrait_type(consumer, value_type, dfs_names, name_idx)?, true, )); Ok(DataType::Map( @@ -2050,42 +2477,48 @@ fn from_substrait_type( Ok(DataType::Interval(IntervalUnit::MonthDayNano)) } r#type::Kind::UserDefined(u) => { - if let Some(name) = extensions.types.get(&u.type_reference) { + if let Ok(data_type) = consumer.consume_user_defined_type(u) { + return Ok(data_type); + } + + // TODO: remove the code below once the producer has been updated + if let Some(name) = consumer.get_extensions().types.get(&u.type_reference) + { #[allow(deprecated)] - match name.as_ref() { - // Kept for backwards compatibility, producers should use IntervalCompound instead - INTERVAL_MONTH_DAY_NANO_TYPE_NAME => Ok(DataType::Interval(IntervalUnit::MonthDayNano)), + match name.as_ref() { + // Kept for backwards compatibility, producers should use IntervalCompound instead + INTERVAL_MONTH_DAY_NANO_TYPE_NAME => Ok(DataType::Interval(IntervalUnit::MonthDayNano)), _ => not_impl_err!( "Unsupported Substrait user defined type with ref {} and variation {}", u.type_reference, u.type_variation_reference ), - } + } } else { #[allow(deprecated)] - match u.type_reference { - // Kept for backwards compatibility, producers should use IntervalYear instead - INTERVAL_YEAR_MONTH_TYPE_REF => { - Ok(DataType::Interval(IntervalUnit::YearMonth)) - } - // Kept for backwards compatibility, producers should use IntervalDay instead - INTERVAL_DAY_TIME_TYPE_REF => { - Ok(DataType::Interval(IntervalUnit::DayTime)) - } - // Kept for backwards compatibility, producers should use IntervalCompound instead - INTERVAL_MONTH_DAY_NANO_TYPE_REF => { - Ok(DataType::Interval(IntervalUnit::MonthDayNano)) - } - _ => not_impl_err!( + match u.type_reference { + // Kept for backwards compatibility, producers should use IntervalYear instead + INTERVAL_YEAR_MONTH_TYPE_REF => { + Ok(DataType::Interval(IntervalUnit::YearMonth)) + } + // Kept for backwards compatibility, producers should use IntervalDay instead + INTERVAL_DAY_TIME_TYPE_REF => { + Ok(DataType::Interval(IntervalUnit::DayTime)) + } + // Kept for backwards compatibility, producers should use IntervalCompound instead + INTERVAL_MONTH_DAY_NANO_TYPE_REF => { + Ok(DataType::Interval(IntervalUnit::MonthDayNano)) + } + _ => not_impl_err!( "Unsupported Substrait user defined type with ref {} and variation {}", u.type_reference, u.type_variation_reference ), - } + } } } r#type::Kind::Struct(s) => Ok(DataType::Struct(from_substrait_struct_type( - s, extensions, dfs_names, name_idx, + consumer, s, dfs_names, name_idx, )?)), r#type::Kind::Varchar(_) => Ok(DataType::Utf8), r#type::Kind::FixedChar(_) => Ok(DataType::Utf8), @@ -2096,8 +2529,8 @@ fn from_substrait_type( } fn from_substrait_struct_type( + consumer: &impl SubstraitConsumer, s: &r#type::Struct, - extensions: &Extensions, dfs_names: &[String], name_idx: &mut usize, ) -> Result { @@ -2105,7 +2538,7 @@ fn from_substrait_struct_type( for (i, f) in s.types.iter().enumerate() { let field = Field::new( next_struct_field_name(i, dfs_names, name_idx)?, - from_substrait_type(f, extensions, dfs_names, name_idx)?, + from_substrait_type(consumer, f, dfs_names, name_idx)?, true, // We assume everything to be nullable since that's easier than ensuring it matches ); fields.push(field); @@ -2133,15 +2566,15 @@ fn next_struct_field_name( /// Convert Substrait NamedStruct to DataFusion DFSchemaRef pub fn from_substrait_named_struct( + consumer: &impl SubstraitConsumer, base_schema: &NamedStruct, - extensions: &Extensions, ) -> Result { let mut name_idx = 0; let fields = from_substrait_struct_type( + consumer, base_schema.r#struct.as_ref().ok_or_else(|| { substrait_datafusion_err!("Named struct must contain a struct") })?, - extensions, &base_schema.names, &mut name_idx, ); @@ -2202,15 +2635,15 @@ fn from_substrait_bound( } pub(crate) fn from_substrait_literal_without_names( + consumer: &impl SubstraitConsumer, lit: &Literal, - extensions: &Extensions, ) -> Result { - from_substrait_literal(lit, extensions, &vec![], &mut 0) + from_substrait_literal(consumer, lit, &vec![], &mut 0) } fn from_substrait_literal( + consumer: &impl SubstraitConsumer, lit: &Literal, - extensions: &Extensions, dfs_names: &Vec, name_idx: &mut usize, ) -> Result { @@ -2346,12 +2779,7 @@ fn from_substrait_literal( .iter() .map(|el| { element_name_idx = *name_idx; - from_substrait_literal( - el, - extensions, - dfs_names, - &mut element_name_idx, - ) + from_substrait_literal(consumer, el, dfs_names, &mut element_name_idx) }) .collect::>>()?; *name_idx = element_name_idx; @@ -2375,8 +2803,8 @@ fn from_substrait_literal( } Some(LiteralType::EmptyList(l)) => { let element_type = from_substrait_type( + consumer, l.r#type.clone().unwrap().as_ref(), - extensions, dfs_names, name_idx, )?; @@ -2402,14 +2830,14 @@ fn from_substrait_literal( .map(|kv| { entry_name_idx = *name_idx; let key_sv = from_substrait_literal( + consumer, kv.key.as_ref().unwrap(), - extensions, dfs_names, &mut entry_name_idx, )?; let value_sv = from_substrait_literal( + consumer, kv.value.as_ref().unwrap(), - extensions, dfs_names, &mut entry_name_idx, )?; @@ -2447,8 +2875,8 @@ fn from_substrait_literal( Some(v) => Ok(v), _ => plan_err!("Missing value type for empty map"), }?; - let key_type = from_substrait_type(key, extensions, dfs_names, name_idx)?; - let value_type = from_substrait_type(value, extensions, dfs_names, name_idx)?; + let key_type = from_substrait_type(consumer, key, dfs_names, name_idx)?; + let value_type = from_substrait_type(consumer, value, dfs_names, name_idx)?; // new_empty_array on a MapType creates a too empty array // We want it to contain an empty struct array to align with an empty MapBuilder one @@ -2474,7 +2902,7 @@ fn from_substrait_literal( let mut builder = ScalarStructBuilder::new(); for (i, field) in s.fields.iter().enumerate() { let name = next_struct_field_name(i, dfs_names, name_idx)?; - let sv = from_substrait_literal(field, extensions, dfs_names, name_idx)?; + let sv = from_substrait_literal(consumer, field, dfs_names, name_idx)?; // We assume everything to be nullable, since Arrow's strict about things matching // and it's hard to match otherwise. builder = builder.with_scalar(Field::new(name, sv.data_type(), true), sv); @@ -2482,7 +2910,7 @@ fn from_substrait_literal( builder.build()? } Some(LiteralType::Null(ntype)) => { - from_substrait_null(ntype, extensions, dfs_names, name_idx)? + from_substrait_null(consumer, ntype, dfs_names, name_idx)? } Some(LiteralType::IntervalDayToSecond(IntervalDayToSecond { days, @@ -2546,9 +2974,15 @@ fn from_substrait_literal( }, Some(LiteralType::FixedChar(c)) => ScalarValue::Utf8(Some(c.clone())), Some(LiteralType::UserDefined(user_defined)) => { + if let Ok(value) = consumer.consume_user_defined_literal(user_defined) { + return Ok(value); + } + + // TODO: remove the code below once the producer has been updated + // Helper function to prevent duplicating this code - can be inlined once the non-extension path is removed let interval_month_day_nano = - |user_defined: &UserDefined| -> Result { + |user_defined: &proto::expression::literal::UserDefined| -> Result { let Some(Val::Value(raw_val)) = user_defined.val.as_ref() else { return substrait_err!("Interval month day nano value is empty"); }; @@ -2572,7 +3006,11 @@ fn from_substrait_literal( ))) }; - if let Some(name) = extensions.types.get(&user_defined.type_reference) { + if let Some(name) = consumer + .get_extensions() + .types + .get(&user_defined.type_reference) + { match name.as_ref() { // Kept for backwards compatibility - producers should use IntervalCompound instead #[allow(deprecated)] @@ -2645,8 +3083,8 @@ fn from_substrait_literal( } fn from_substrait_null( + consumer: &impl SubstraitConsumer, null_type: &Type, - extensions: &Extensions, dfs_names: &[String], name_idx: &mut usize, ) -> Result { @@ -2764,8 +3202,8 @@ fn from_substrait_null( r#type::Kind::List(l) => { let field = Field::new_list_field( from_substrait_type( + consumer, l.r#type.clone().unwrap().as_ref(), - extensions, dfs_names, name_idx, )?, @@ -2792,9 +3230,9 @@ fn from_substrait_null( })?; let key_type = - from_substrait_type(key_type, extensions, dfs_names, name_idx)?; + from_substrait_type(consumer, key_type, dfs_names, name_idx)?; let value_type = - from_substrait_type(value_type, extensions, dfs_names, name_idx)?; + from_substrait_type(consumer, value_type, dfs_names, name_idx)?; let entries_field = Arc::new(Field::new_struct( "entries", vec![ @@ -2808,7 +3246,7 @@ fn from_substrait_null( } r#type::Kind::Struct(s) => { let fields = - from_substrait_struct_type(s, extensions, dfs_names, name_idx)?; + from_substrait_struct_type(consumer, s, dfs_names, name_idx)?; Ok(ScalarStructBuilder::new_null(fields)) } _ => not_impl_err!("Unsupported Substrait type for null: {kind:?}"), @@ -2820,16 +3258,15 @@ fn from_substrait_null( #[allow(deprecated)] async fn from_substrait_grouping( - state: &dyn SubstraitPlanningState, + consumer: &impl SubstraitConsumer, grouping: &Grouping, expressions: &[Expr], input_schema: &DFSchemaRef, - extensions: &Extensions, ) -> Result> { let mut group_exprs = vec![]; if !grouping.grouping_expressions.is_empty() { for e in &grouping.grouping_expressions { - let expr = from_substrait_rex(state, e, input_schema, extensions).await?; + let expr = from_substrait_rex(consumer, e, input_schema).await?; group_exprs.push(expr); } return Ok(group_exprs); @@ -2882,29 +3319,17 @@ impl BuiltinExprBuilder { pub async fn build( self, - state: &dyn SubstraitPlanningState, + consumer: &impl SubstraitConsumer, f: &ScalarFunction, input_schema: &DFSchema, - extensions: &Extensions, ) -> Result { match self.expr_name.as_str() { - "like" => { - Self::build_like_expr(state, false, f, input_schema, extensions).await - } - "ilike" => { - Self::build_like_expr(state, true, f, input_schema, extensions).await - } + "like" => Self::build_like_expr(consumer, false, f, input_schema).await, + "ilike" => Self::build_like_expr(consumer, true, f, input_schema).await, "not" | "negative" | "negate" | "is_null" | "is_not_null" | "is_true" | "is_false" | "is_not_true" | "is_not_false" | "is_unknown" | "is_not_unknown" => { - Self::build_unary_expr( - state, - &self.expr_name, - f, - input_schema, - extensions, - ) - .await + Self::build_unary_expr(consumer, &self.expr_name, f, input_schema).await } _ => { not_impl_err!("Unsupported builtin expression: {}", self.expr_name) @@ -2913,11 +3338,10 @@ impl BuiltinExprBuilder { } async fn build_unary_expr( - state: &dyn SubstraitPlanningState, + consumer: &impl SubstraitConsumer, fn_name: &str, f: &ScalarFunction, input_schema: &DFSchema, - extensions: &Extensions, ) -> Result { if f.arguments.len() != 1 { return substrait_err!("Expect one argument for {fn_name} expr"); @@ -2925,8 +3349,7 @@ impl BuiltinExprBuilder { let Some(ArgType::Value(expr_substrait)) = &f.arguments[0].arg_type else { return substrait_err!("Invalid arguments type for {fn_name} expr"); }; - let arg = - from_substrait_rex(state, expr_substrait, input_schema, extensions).await?; + let arg = from_substrait_rex(consumer, expr_substrait, input_schema).await?; let arg = Box::new(arg); let expr = match fn_name { @@ -2947,11 +3370,10 @@ impl BuiltinExprBuilder { } async fn build_like_expr( - state: &dyn SubstraitPlanningState, + consumer: &impl SubstraitConsumer, case_insensitive: bool, f: &ScalarFunction, input_schema: &DFSchema, - extensions: &Extensions, ) -> Result { let fn_name = if case_insensitive { "ILIKE" } else { "LIKE" }; if f.arguments.len() != 2 && f.arguments.len() != 3 { @@ -2961,14 +3383,12 @@ impl BuiltinExprBuilder { let Some(ArgType::Value(expr_substrait)) = &f.arguments[0].arg_type else { return substrait_err!("Invalid arguments type for `{fn_name}` expr"); }; - let expr = - from_substrait_rex(state, expr_substrait, input_schema, extensions).await?; + let expr = from_substrait_rex(consumer, expr_substrait, input_schema).await?; let Some(ArgType::Value(pattern_substrait)) = &f.arguments[1].arg_type else { return substrait_err!("Invalid arguments type for `{fn_name}` expr"); }; let pattern = - from_substrait_rex(state, pattern_substrait, input_schema, extensions) - .await?; + from_substrait_rex(consumer, pattern_substrait, input_schema).await?; // Default case: escape character is Literal(Utf8(None)) let escape_char = if f.arguments.len() == 3 { @@ -2977,13 +3397,8 @@ impl BuiltinExprBuilder { return substrait_err!("Invalid arguments type for `{fn_name}` expr"); }; - let escape_char_expr = from_substrait_rex( - state, - escape_char_substrait, - input_schema, - extensions, - ) - .await?; + let escape_char_expr = + from_substrait_rex(consumer, escape_char_substrait, input_schema).await?; match escape_char_expr { Expr::Literal(ScalarValue::Utf8(escape_char_string)) => { @@ -3013,16 +3428,29 @@ impl BuiltinExprBuilder { #[cfg(test)] mod test { use crate::extensions::Extensions; - use crate::logical_plan::consumer::from_substrait_literal_without_names; + use crate::logical_plan::consumer::{ + from_substrait_literal_without_names, DefaultSubstraitConsumer, + }; use arrow_buffer::IntervalMonthDayNano; use datafusion::error::Result; + use datafusion::execution::SessionState; + use datafusion::prelude::SessionContext; use datafusion::scalar::ScalarValue; + use std::sync::OnceLock; use substrait::proto::expression::literal::{ interval_day_to_second, IntervalCompound, IntervalDayToSecond, IntervalYearToMonth, LiteralType, }; use substrait::proto::expression::Literal; + static TEST_SESSION_STATE: OnceLock = OnceLock::new(); + static TEST_EXTENSIONS: OnceLock = OnceLock::new(); + fn test_consumer() -> DefaultSubstraitConsumer<'static> { + let extensions = TEST_EXTENSIONS.get_or_init(Extensions::default); + let state = TEST_SESSION_STATE.get_or_init(|| SessionContext::default().state()); + DefaultSubstraitConsumer::new(extensions, state) + } + #[test] fn interval_compound_different_precision() -> Result<()> { // DF producer (and thus roundtrip) always uses precision = 9, @@ -3046,8 +3474,9 @@ mod test { })), }; + let consumer = test_consumer(); assert_eq!( - from_substrait_literal_without_names(&substrait, &Extensions::default())?, + from_substrait_literal_without_names(&consumer, &substrait)?, ScalarValue::IntervalMonthDayNano(Some(IntervalMonthDayNano { months: 14, days: 3, diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index 375cb734f564..5191a620b473 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -2211,11 +2211,11 @@ fn substrait_field_ref(index: usize) -> Result { #[cfg(test)] mod test { - use super::*; use crate::logical_plan::consumer::{ from_substrait_extended_expr, from_substrait_literal_without_names, from_substrait_named_struct, from_substrait_type_without_names, + DefaultSubstraitConsumer, }; use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano}; use datafusion::arrow::array::{ @@ -2224,7 +2224,17 @@ mod test { use datafusion::arrow::datatypes::{Field, Fields, Schema}; use datafusion::common::scalar::ScalarStructBuilder; use datafusion::common::DFSchema; - use datafusion::execution::SessionStateBuilder; + use datafusion::execution::{SessionState, SessionStateBuilder}; + use datafusion::prelude::SessionContext; + use std::sync::OnceLock; + + static TEST_SESSION_STATE: OnceLock = OnceLock::new(); + static TEST_EXTENSIONS: OnceLock = OnceLock::new(); + fn test_consumer() -> DefaultSubstraitConsumer<'static> { + let extensions = TEST_EXTENSIONS.get_or_init(Extensions::default); + let state = TEST_SESSION_STATE.get_or_init(|| SessionContext::default().state()); + DefaultSubstraitConsumer::new(extensions, state) + } #[test] fn round_trip_literals() -> Result<()> { @@ -2350,7 +2360,7 @@ mod test { let mut extensions = Extensions::default(); let substrait_literal = to_substrait_literal(&scalar, &mut extensions)?; let roundtrip_scalar = - from_substrait_literal_without_names(&substrait_literal, &extensions)?; + from_substrait_literal_without_names(&test_consumer(), &substrait_literal)?; assert_eq!(scalar, roundtrip_scalar); Ok(()) } @@ -2429,8 +2439,8 @@ mod test { // As DataFusion doesn't consider nullability as a property of the type, but field, // it doesn't matter if we set nullability to true or false here. let substrait = to_substrait_type(&dt, true)?; - let roundtrip_dt = - from_substrait_type_without_names(&substrait, &Extensions::default())?; + let consumer = test_consumer(); + let roundtrip_dt = from_substrait_type_without_names(&consumer, &substrait)?; assert_eq!(dt, roundtrip_dt); Ok(()) } @@ -2481,7 +2491,7 @@ mod test { ); let roundtrip_schema = - from_substrait_named_struct(&named_struct, &Extensions::default())?; + from_substrait_named_struct(&test_consumer(), &named_struct)?; assert_eq!(schema.as_ref(), &roundtrip_schema); Ok(()) } diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs index 1291bbd6a244..1ce0eec1b21d 100644 --- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs @@ -30,6 +30,7 @@ use datafusion::common::{not_impl_err, plan_err, DFSchema, DFSchemaRef}; use datafusion::error::Result; use datafusion::execution::registry::SerializerRegistry; use datafusion::execution::runtime_env::RuntimeEnv; +use datafusion::execution::session_state::SessionStateBuilder; use datafusion::logical_expr::{ Extension, LogicalPlan, PartitionEvaluator, Repartition, UserDefinedLogicalNode, Values, Volatility, @@ -38,8 +39,6 @@ use datafusion::optimizer::simplify_expressions::expr_simplifier::THRESHOLD_INLI use datafusion::prelude::*; use std::hash::Hash; use std::sync::Arc; - -use datafusion::execution::session_state::SessionStateBuilder; use substrait::proto::extensions::simple_extension_declaration::MappingType; use substrait::proto::rel::RelType; use substrait::proto::{plan_rel, Plan, Rel}; diff --git a/datafusion/substrait/tests/utils.rs b/datafusion/substrait/tests/utils.rs index 00cbfb0c412c..b9e5e0e5257c 100644 --- a/datafusion/substrait/tests/utils.rs +++ b/datafusion/substrait/tests/utils.rs @@ -24,7 +24,9 @@ pub mod test { use datafusion::error::Result; use datafusion::prelude::SessionContext; use datafusion_substrait::extensions::Extensions; - use datafusion_substrait::logical_plan::consumer::from_substrait_named_struct; + use datafusion_substrait::logical_plan::consumer::{ + from_substrait_named_struct, DefaultSubstraitConsumer, SubstraitConsumer, + }; use std::collections::HashMap; use std::fs::File; use std::io::BufReader; @@ -50,7 +52,18 @@ pub mod test { ctx: SessionContext, plan: &Plan, ) -> Result { - let schemas = TestSchemaCollector::collect_schemas(plan)?; + let extensions = Extensions::default(); + let state = ctx.state(); + let consumer = DefaultSubstraitConsumer::new(&extensions, &state); + add_plan_schemas_to_ctx_with_consumer(&consumer, ctx, plan) + } + + fn add_plan_schemas_to_ctx_with_consumer( + consumer: &impl SubstraitConsumer, + ctx: SessionContext, + plan: &Plan, + ) -> Result { + let schemas = TestSchemaCollector::collect_schemas(consumer, plan)?; let mut schema_map: HashMap> = HashMap::new(); for (table_reference, table) in schemas.into_iter() { @@ -71,21 +84,24 @@ pub mod test { Ok(ctx) } - pub struct TestSchemaCollector { + pub struct TestSchemaCollector<'a, T: SubstraitConsumer> { + consumer: &'a T, schemas: Vec<(TableReference, Arc)>, } - impl TestSchemaCollector { - fn new() -> Self { + impl<'a, T: SubstraitConsumer> TestSchemaCollector<'a, T> { + fn new(consumer: &'a T) -> Self { TestSchemaCollector { schemas: Vec::new(), + consumer, } } fn collect_schemas( + consumer: &'a T, plan: &Plan, ) -> Result)>> { - let mut schema_collector = Self::new(); + let mut schema_collector = Self::new(consumer); for plan_rel in plan.relations.iter() { let rel_type = plan_rel @@ -132,15 +148,8 @@ pub mod test { "No base schema found for NamedTable: {}", table_reference ))?; - let empty_extensions = Extensions { - functions: Default::default(), - types: Default::default(), - type_variations: Default::default(), - }; - - let df_schema = - from_substrait_named_struct(substrait_schema, &empty_extensions)? - .replace_qualifier(table_reference.clone()); + let df_schema = from_substrait_named_struct(self.consumer, substrait_schema)? + .replace_qualifier(table_reference.clone()); let table = EmptyTable::new(df_schema.inner().clone()); self.schemas.push((table_reference, Arc::new(table))); From a50ed3488f77743a192d9e8dd9c99f00df659ef1 Mon Sep 17 00:00:00 2001 From: robtandy Date: Sat, 21 Dec 2024 10:34:05 -0500 Subject: [PATCH 3/8] Minor: fix: Include FetchRel when producing LogicalPlan from Sort (#13862) * include FetchRel when producing LogicalPlan from Sort * add suggested test * address review feedback --- .../substrait/src/logical_plan/producer.rs | 38 +++++++++++++++---- .../tests/cases/roundtrip_logical_plan.rs | 10 +++++ 2 files changed, 41 insertions(+), 7 deletions(-) diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index 5191a620b473..b73d246e1989 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -361,21 +361,45 @@ pub fn to_substrait_rel( }))), })) } - LogicalPlan::Sort(sort) => { - let input = to_substrait_rel(sort.input.as_ref(), state, extensions)?; - let sort_fields = sort - .expr + LogicalPlan::Sort(datafusion::logical_expr::Sort { expr, input, fetch }) => { + let sort_fields = expr .iter() - .map(|e| substrait_sort_field(state, e, sort.input.schema(), extensions)) + .map(|e| substrait_sort_field(state, e, input.schema(), extensions)) .collect::>>()?; - Ok(Box::new(Rel { + + let input = to_substrait_rel(input.as_ref(), state, extensions)?; + + let sort_rel = Box::new(Rel { rel_type: Some(RelType::Sort(Box::new(SortRel { common: None, input: Some(input), sorts: sort_fields, advanced_extension: None, }))), - })) + }); + + match fetch { + Some(amount) => { + let count_mode = + Some(fetch_rel::CountMode::CountExpr(Box::new(Expression { + rex_type: Some(RexType::Literal(Literal { + nullable: false, + type_variation_reference: DEFAULT_TYPE_VARIATION_REF, + literal_type: Some(LiteralType::I64(*amount as i64)), + })), + }))); + Ok(Box::new(Rel { + rel_type: Some(RelType::Fetch(Box::new(FetchRel { + common: None, + input: Some(sort_rel), + offset_mode: None, + count_mode, + advanced_extension: None, + }))), + })) + } + None => Ok(sort_rel), + } } LogicalPlan::Aggregate(agg) => { let input = to_substrait_rel(agg.input.as_ref(), state, extensions)?; diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs index 1ce0eec1b21d..1d1a87015135 100644 --- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs @@ -199,6 +199,16 @@ async fn select_with_filter() -> Result<()> { roundtrip("SELECT * FROM data WHERE a > 1").await } +#[tokio::test] +async fn select_with_filter_sort_limit() -> Result<()> { + roundtrip("SELECT * FROM data WHERE a > 1 ORDER BY b ASC LIMIT 2").await +} + +#[tokio::test] +async fn select_with_filter_sort_limit_offset() -> Result<()> { + roundtrip("SELECT * FROM data WHERE a > 1 ORDER BY b ASC LIMIT 2 OFFSET 1").await +} + #[tokio::test] async fn select_with_reused_functions() -> Result<()> { let ctx = create_context().await?; From 8c48a8cc4d4d849d3e02f667204ee7fd81c150c8 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Sat, 21 Dec 2024 16:55:08 -0500 Subject: [PATCH 4/8] Minor: improve error message when ARRAY literals can not be planned (#13859) * Minor: improve error message when ARRAY literals can not be planned * fmt * Update datafusion/sql/src/expr/value.rs Co-authored-by: Oleks V --------- Co-authored-by: Oleks V --- datafusion/sql/src/expr/value.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/datafusion/sql/src/expr/value.rs b/datafusion/sql/src/expr/value.rs index a70934b5cd5d..847163c6d3b3 100644 --- a/datafusion/sql/src/expr/value.rs +++ b/datafusion/sql/src/expr/value.rs @@ -24,8 +24,8 @@ use arrow_schema::{DataType, DECIMAL256_MAX_PRECISION}; use bigdecimal::num_bigint::BigInt; use bigdecimal::{BigDecimal, Signed, ToPrimitive}; use datafusion_common::{ - internal_datafusion_err, internal_err, not_impl_err, plan_err, DFSchema, - DataFusionError, Result, ScalarValue, + internal_datafusion_err, not_impl_err, plan_err, DFSchema, DataFusionError, Result, + ScalarValue, }; use datafusion_expr::expr::{BinaryExpr, Placeholder}; use datafusion_expr::planner::PlannerResult; @@ -169,7 +169,7 @@ impl SqlToRel<'_, S> { } } - internal_err!("Expected a simplified result, but none was found") + not_impl_err!("Could not plan array literal. Hint: Please try with `nested_expressions` DataFusion feature enabled") } /// Convert a SQL interval expression to a DataFusion logical plan From 2639fe045a7a724d56d014561ac801874f5e4805 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Sat, 21 Dec 2024 16:57:09 -0500 Subject: [PATCH 5/8] Add documentation for `SHOW FUNCTIONS` (#13868) --- .../user-guide/sql/information_schema.md | 56 ++++++++++++++++++- 1 file changed, 53 insertions(+), 3 deletions(-) diff --git a/docs/source/user-guide/sql/information_schema.md b/docs/source/user-guide/sql/information_schema.md index bf4aa00e1dde..db74ec0708b3 100644 --- a/docs/source/user-guide/sql/information_schema.md +++ b/docs/source/user-guide/sql/information_schema.md @@ -22,7 +22,10 @@ DataFusion supports showing metadata about the tables and views available. This information can be accessed using the views of the ISO SQL `information_schema` schema or the DataFusion specific `SHOW TABLES` and `SHOW COLUMNS` commands. -To show tables in the DataFusion catalog, use the `SHOW TABLES` command or the `information_schema.tables` view: +## `SHOW TABLES` + +To show tables in the DataFusion catalog, use the `SHOW TABLES` command or the +`information_schema.tables` view: ```sql > show tables; @@ -39,7 +42,10 @@ or ``` -To show the schema of a table in DataFusion, use the `SHOW COLUMNS` command or the `information_schema.columns` view: +## `SHOW COLUMNS` + +To show the schema of a table in DataFusion, use the `SHOW COLUMNS` command or +the `information_schema.columns` view. ```sql > show columns from t; @@ -52,7 +58,10 @@ or +---------------+--------------+------------+-------------+-----------+-------------+ ``` -To show the current session configuration options, use the `SHOW ALL` command or the `information_schema.df_settings` view: +## `SHOW ALL` (configuration options) + +To show the current session configuration options, use the `SHOW ALL` command or +the `information_schema.df_settings` view: ```sql select * from information_schema.df_settings; @@ -65,7 +74,48 @@ select * from information_schema.df_settings; | datafusion.execution.time_zone | UTC | | datafusion.explain.logical_plan_only | false | | datafusion.explain.physical_plan_only | false | +... | datafusion.optimizer.filter_null_join_keys | false | | datafusion.optimizer.skip_failed_rules | true | +-------------------------------------------------+---------+ ``` + +## `SHOW FUNCTIONS` + +To show the list of functions available, use the `SHOW FUNCTIONS` command or the + +- `information_schema.information_schema.routines` view: functions and descriptions +- `information_schema.information_schema.parameters` view: parameters and descriptions + +Syntax: + +```sql +SHOW FUNCTIONS [ LIKE ]; +``` + +Example output + +```sql +> show functions like '%datetrunc'; ++---------------+-------------------------------------+-------------------------+-------------------------------------------------+---------------+-------------------------------------------------------+-----------------------------------+ +| function_name | return_type | parameters | parameter_types | function_type | description | syntax_example | ++---------------+-------------------------------------+-------------------------+-------------------------------------------------+---------------+-------------------------------------------------------+-----------------------------------+ +| datetrunc | Timestamp(Microsecond, Some("+TZ")) | [precision, expression] | [Utf8, Timestamp(Microsecond, Some("+TZ"))] | SCALAR | Truncates a timestamp value to a specified precision. | date_trunc(precision, expression) | +| datetrunc | Timestamp(Nanosecond, None) | [precision, expression] | [Utf8View, Timestamp(Nanosecond, None)] | SCALAR | Truncates a timestamp value to a specified precision. | date_trunc(precision, expression) | +| datetrunc | Timestamp(Second, Some("+TZ")) | [precision, expression] | [Utf8View, Timestamp(Second, Some("+TZ"))] | SCALAR | Truncates a timestamp value to a specified precision. | date_trunc(precision, expression) | +| datetrunc | Timestamp(Microsecond, None) | [precision, expression] | [Utf8View, Timestamp(Microsecond, None)] | SCALAR | Truncates a timestamp value to a specified precision. | date_trunc(precision, expression) | +| datetrunc | Timestamp(Second, None) | [precision, expression] | [Utf8View, Timestamp(Second, None)] | SCALAR | Truncates a timestamp value to a specified precision. | date_trunc(precision, expression) | +| datetrunc | Timestamp(Microsecond, None) | [precision, expression] | [Utf8, Timestamp(Microsecond, None)] | SCALAR | Truncates a timestamp value to a specified precision. | date_trunc(precision, expression) | +| datetrunc | Timestamp(Second, None) | [precision, expression] | [Utf8, Timestamp(Second, None)] | SCALAR | Truncates a timestamp value to a specified precision. | date_trunc(precision, expression) | +| datetrunc | Timestamp(Microsecond, Some("+TZ")) | [precision, expression] | [Utf8View, Timestamp(Microsecond, Some("+TZ"))] | SCALAR | Truncates a timestamp value to a specified precision. | date_trunc(precision, expression) | +| datetrunc | Timestamp(Nanosecond, Some("+TZ")) | [precision, expression] | [Utf8, Timestamp(Nanosecond, Some("+TZ"))] | SCALAR | Truncates a timestamp value to a specified precision. | date_trunc(precision, expression) | +| datetrunc | Timestamp(Millisecond, None) | [precision, expression] | [Utf8, Timestamp(Millisecond, None)] | SCALAR | Truncates a timestamp value to a specified precision. | date_trunc(precision, expression) | +| datetrunc | Timestamp(Millisecond, Some("+TZ")) | [precision, expression] | [Utf8, Timestamp(Millisecond, Some("+TZ"))] | SCALAR | Truncates a timestamp value to a specified precision. | date_trunc(precision, expression) | +| datetrunc | Timestamp(Second, Some("+TZ")) | [precision, expression] | [Utf8, Timestamp(Second, Some("+TZ"))] | SCALAR | Truncates a timestamp value to a specified precision. | date_trunc(precision, expression) | +| datetrunc | Timestamp(Nanosecond, None) | [precision, expression] | [Utf8, Timestamp(Nanosecond, None)] | SCALAR | Truncates a timestamp value to a specified precision. | date_trunc(precision, expression) | +| datetrunc | Timestamp(Millisecond, None) | [precision, expression] | [Utf8View, Timestamp(Millisecond, None)] | SCALAR | Truncates a timestamp value to a specified precision. | date_trunc(precision, expression) | +| datetrunc | Timestamp(Millisecond, Some("+TZ")) | [precision, expression] | [Utf8View, Timestamp(Millisecond, Some("+TZ"))] | SCALAR | Truncates a timestamp value to a specified precision. | date_trunc(precision, expression) | +| datetrunc | Timestamp(Nanosecond, Some("+TZ")) | [precision, expression] | [Utf8View, Timestamp(Nanosecond, Some("+TZ"))] | SCALAR | Truncates a timestamp value to a specified precision. | date_trunc(precision, expression) | ++---------------+-------------------------------------+-------------------------+-------------------------------------------------+---------------+-------------------------------------------------------+-----------------------------------+ +16 row(s) fetched. +``` From a267784bc60910bfdf558b4f8e600c1890ad6245 Mon Sep 17 00:00:00 2001 From: Tai Le Manh Date: Sun, 22 Dec 2024 18:10:35 +0700 Subject: [PATCH 6/8] Support unicode character for `initcap` function (#13752) * Support unicode character for 'initcap' function Signed-off-by: Tai Le Manh * Update unit tests * Fix clippy warning * Update sqllogictests - initcap * Update scalar_functions.md docs * Add suggestions change Signed-off-by: Tai Le Manh --------- Signed-off-by: Tai Le Manh --- datafusion/functions/Cargo.toml | 2 +- datafusion/functions/benches/initcap.rs | 4 +- datafusion/functions/src/string/mod.rs | 7 -- .../src/{string => unicode}/initcap.rs | 114 +++++++++++++----- datafusion/functions/src/unicode/mod.rs | 7 ++ .../test_files/string/string_query.slt.part | 2 +- .../source/user-guide/sql/scalar_functions.md | 4 +- 7 files changed, 93 insertions(+), 47 deletions(-) rename datafusion/functions/src/{string => unicode}/initcap.rs (68%) diff --git a/datafusion/functions/Cargo.toml b/datafusion/functions/Cargo.toml index de72c7ee946b..fd986c4be41c 100644 --- a/datafusion/functions/Cargo.toml +++ b/datafusion/functions/Cargo.toml @@ -212,4 +212,4 @@ required-features = ["math_expressions"] [[bench]] harness = false name = "initcap" -required-features = ["string_expressions"] +required-features = ["unicode_expressions"] diff --git a/datafusion/functions/benches/initcap.rs b/datafusion/functions/benches/initcap.rs index c88b6b513980..97c76831b33c 100644 --- a/datafusion/functions/benches/initcap.rs +++ b/datafusion/functions/benches/initcap.rs @@ -24,7 +24,7 @@ use arrow::util::bench_util::{ }; use criterion::{black_box, criterion_group, criterion_main, Criterion}; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; -use datafusion_functions::string; +use datafusion_functions::unicode; use std::sync::Arc; fn create_args( @@ -46,7 +46,7 @@ fn create_args( } fn criterion_benchmark(c: &mut Criterion) { - let initcap = string::initcap(); + let initcap = unicode::initcap(); for size in [1024, 4096] { let args = create_args::(size, 8, true); c.bench_function( diff --git a/datafusion/functions/src/string/mod.rs b/datafusion/functions/src/string/mod.rs index f156f070d960..c43aaeccbefe 100644 --- a/datafusion/functions/src/string/mod.rs +++ b/datafusion/functions/src/string/mod.rs @@ -30,7 +30,6 @@ pub mod concat; pub mod concat_ws; pub mod contains; pub mod ends_with; -pub mod initcap; pub mod levenshtein; pub mod lower; pub mod ltrim; @@ -52,7 +51,6 @@ make_udf_function!(chr::ChrFunc, chr); make_udf_function!(concat::ConcatFunc, concat); make_udf_function!(concat_ws::ConcatWsFunc, concat_ws); make_udf_function!(ends_with::EndsWithFunc, ends_with); -make_udf_function!(initcap::InitcapFunc, initcap); make_udf_function!(levenshtein::LevenshteinFunc, levenshtein); make_udf_function!(ltrim::LtrimFunc, ltrim); make_udf_function!(lower::LowerFunc, lower); @@ -94,10 +92,6 @@ pub mod expr_fn { ends_with, "Returns true if the `string` ends with the `suffix`, false otherwise.", string suffix - ),( - initcap, - "Converts the first letter of each word in `string` in uppercase and the remaining characters in lowercase", - string ),( levenshtein, "Returns the Levenshtein distance between the two given strings", @@ -177,7 +171,6 @@ pub fn functions() -> Vec> { concat(), concat_ws(), ends_with(), - initcap(), levenshtein(), lower(), ltrim(), diff --git a/datafusion/functions/src/string/initcap.rs b/datafusion/functions/src/unicode/initcap.rs similarity index 68% rename from datafusion/functions/src/string/initcap.rs rename to datafusion/functions/src/unicode/initcap.rs index 2780dcaeeb83..e9f966b95868 100644 --- a/datafusion/functions/src/string/initcap.rs +++ b/datafusion/functions/src/unicode/initcap.rs @@ -18,7 +18,9 @@ use std::any::Any; use std::sync::{Arc, OnceLock}; -use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait, StringArray}; +use arrow::array::{ + Array, ArrayRef, GenericStringBuilder, OffsetSizeTrait, StringViewBuilder, +}; use arrow::datatypes::DataType; use crate::utils::{make_scalar_function, utf8_to_str_type}; @@ -74,7 +76,7 @@ impl ScalarUDFImpl for InitcapFunc { DataType::LargeUtf8 => make_scalar_function(initcap::, vec![])(args), DataType::Utf8View => make_scalar_function(initcap_utf8view, vec![])(args), other => { - exec_err!("Unsupported data type {other:?} for function initcap") + exec_err!("Unsupported data type {other:?} for function `initcap`") } } } @@ -90,9 +92,8 @@ fn get_initcap_doc() -> &'static Documentation { DOCUMENTATION.get_or_init(|| { Documentation::builder( DOC_SECTION_STRING, - "Capitalizes the first character in each word in the ASCII input string. \ - Words are delimited by non-alphanumeric characters.\n\n\ - Note this function does not support UTF-8 characters.", + "Capitalizes the first character in each word in the input string. \ + Words are delimited by non-alphanumeric characters.", "initcap(str)", ) .with_sql_example( @@ -123,50 +124,70 @@ fn get_initcap_doc() -> &'static Documentation { fn initcap(args: &[ArrayRef]) -> Result { let string_array = as_generic_string_array::(&args[0])?; - // first map is the iterator, second is for the `Option<_>` - let result = string_array - .iter() - .map(initcap_string) - .collect::>(); + let mut builder = GenericStringBuilder::::with_capacity( + string_array.len(), + string_array.value_data().len(), + ); - Ok(Arc::new(result) as ArrayRef) + string_array.iter().for_each(|str| match str { + Some(s) => { + let initcap_str = initcap_string(s); + builder.append_value(initcap_str); + } + None => builder.append_null(), + }); + + Ok(Arc::new(builder.finish()) as ArrayRef) } fn initcap_utf8view(args: &[ArrayRef]) -> Result { let string_view_array = as_string_view_array(&args[0])?; - let result = string_view_array - .iter() - .map(initcap_string) - .collect::(); + let mut builder = StringViewBuilder::with_capacity(string_view_array.len()); + + string_view_array.iter().for_each(|str| match str { + Some(s) => { + let initcap_str = initcap_string(s); + builder.append_value(initcap_str); + } + None => builder.append_null(), + }); - Ok(Arc::new(result) as ArrayRef) + Ok(Arc::new(builder.finish()) as ArrayRef) } -fn initcap_string(input: Option<&str>) -> Option { - input.map(|s| { - let mut result = String::with_capacity(s.len()); - let mut prev_is_alphanumeric = false; +fn initcap_string(input: &str) -> String { + let mut result = String::with_capacity(input.len()); + let mut prev_is_alphanumeric = false; - for c in s.chars() { - let transformed = if prev_is_alphanumeric { - c.to_ascii_lowercase() + if input.is_ascii() { + for c in input.chars() { + if prev_is_alphanumeric { + result.push(c.to_ascii_lowercase()); } else { - c.to_ascii_uppercase() + result.push(c.to_ascii_uppercase()); }; - result.push(transformed); prev_is_alphanumeric = c.is_ascii_alphanumeric(); } + } else { + for c in input.chars() { + if prev_is_alphanumeric { + result.extend(c.to_lowercase()); + } else { + result.extend(c.to_uppercase()); + } + prev_is_alphanumeric = c.is_alphanumeric(); + } + } - result - }) + result } #[cfg(test)] mod tests { - use crate::string::initcap::InitcapFunc; + use crate::unicode::initcap::InitcapFunc; use crate::utils::test::test_function; - use arrow::array::{Array, StringArray}; + use arrow::array::{Array, StringArray, StringViewArray}; use arrow::datatypes::DataType::Utf8; use datafusion_common::{Result, ScalarValue}; use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; @@ -181,6 +202,19 @@ mod tests { Utf8, StringArray ); + test_function!( + InitcapFunc::new(), + vec![ColumnarValue::Scalar(ScalarValue::Utf8(Some( + "êM ả ñAnDÚ ÁrBOL ОлЕГ ИвАНОВИч ÍslENsku ÞjóðaRiNNaR εΛλΗΝΙκΉ" + .to_string() + )))], + Ok(Some( + "Êm Ả Ñandú Árbol Олег Иванович Íslensku Þjóðarinnar Ελληνική" + )), + &str, + Utf8, + StringArray + ); test_function!( InitcapFunc::new(), vec![ColumnarValue::Scalar(ScalarValue::from(""))], @@ -205,6 +239,7 @@ mod tests { Utf8, StringArray ); + test_function!( InitcapFunc::new(), vec![ColumnarValue::Scalar(ScalarValue::Utf8View(Some( @@ -213,7 +248,7 @@ mod tests { Ok(Some("Hi Thomas")), &str, Utf8, - StringArray + StringViewArray ); test_function!( InitcapFunc::new(), @@ -223,7 +258,20 @@ mod tests { Ok(Some("Hi Thomas With M0re Than 12 Chars")), &str, Utf8, - StringArray + StringViewArray + ); + test_function!( + InitcapFunc::new(), + vec![ColumnarValue::Scalar(ScalarValue::Utf8View(Some( + "đẸp đẼ êM ả ñAnDÚ ÁrBOL ОлЕГ ИвАНОВИч ÍslENsku ÞjóðaRiNNaR εΛλΗΝΙκΉ" + .to_string() + )))], + Ok(Some( + "Đẹp Đẽ Êm Ả Ñandú Árbol Олег Иванович Íslensku Þjóðarinnar Ελληνική" + )), + &str, + Utf8, + StringViewArray ); test_function!( InitcapFunc::new(), @@ -233,7 +281,7 @@ mod tests { Ok(Some("")), &str, Utf8, - StringArray + StringViewArray ); test_function!( InitcapFunc::new(), @@ -241,7 +289,7 @@ mod tests { Ok(None), &str, Utf8, - StringArray + StringViewArray ); Ok(()) diff --git a/datafusion/functions/src/unicode/mod.rs b/datafusion/functions/src/unicode/mod.rs index f31ece9196d8..e8e3eb3f4e75 100644 --- a/datafusion/functions/src/unicode/mod.rs +++ b/datafusion/functions/src/unicode/mod.rs @@ -23,6 +23,7 @@ use datafusion_expr::ScalarUDF; pub mod character_length; pub mod find_in_set; +pub mod initcap; pub mod left; pub mod lpad; pub mod reverse; @@ -36,6 +37,7 @@ pub mod translate; // create UDFs make_udf_function!(character_length::CharacterLengthFunc, character_length); make_udf_function!(find_in_set::FindInSetFunc, find_in_set); +make_udf_function!(initcap::InitcapFunc, initcap); make_udf_function!(left::LeftFunc, left); make_udf_function!(lpad::LPadFunc, lpad); make_udf_function!(right::RightFunc, right); @@ -94,6 +96,10 @@ pub mod expr_fn { left, "returns the first `n` characters in the `string`", string n + ),( + initcap, + "converts the first letter of each word in `string` in uppercase and the remaining characters in lowercase", + string ),( find_in_set, "Returns a value in the range of 1 to N if the string str is in the string list strlist consisting of N substrings", @@ -126,6 +132,7 @@ pub fn functions() -> Vec> { vec![ character_length(), find_in_set(), + initcap(), left(), lpad(), reverse(), diff --git a/datafusion/sqllogictest/test_files/string/string_query.slt.part b/datafusion/sqllogictest/test_files/string/string_query.slt.part index 80fcc0102887..2414e5864c99 100644 --- a/datafusion/sqllogictest/test_files/string/string_query.slt.part +++ b/datafusion/sqllogictest/test_files/string/string_query.slt.part @@ -460,7 +460,7 @@ Andrew Datafusion📊🔥 Xiangpeng Datafusion数据融合 Raphael Datafusionдатафусион Under_Score Un Iść Core -Percent Pan Tadeusz Ma Iść W KąT +Percent Pan Tadeusz Ma Iść W Kąt (empty) (empty) (empty) (empty) % (empty) diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md index 2e4147f96e0f..be4f5e56b3af 100644 --- a/docs/source/user-guide/sql/scalar_functions.md +++ b/docs/source/user-guide/sql/scalar_functions.md @@ -1070,9 +1070,7 @@ find_in_set(str, strlist) ### `initcap` -Capitalizes the first character in each word in the ASCII input string. Words are delimited by non-alphanumeric characters. - -Note this function does not support UTF-8 characters. +Capitalizes the first character in each word in the input string. Words are delimited by non-alphanumeric characters. ``` initcap(str) From e6f5cb6daa92fc6ede2792e4eacb2126ab1fa46f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Burak=20=C5=9Een?= Date: Sun, 22 Dec 2024 15:59:56 +0300 Subject: [PATCH 7/8] [minor] make recursive package dependency optional (#13778) * make recursive optional * add to default for common package * cargo update * added to readme * make test conditional * reviews * cargo update --------- Co-authored-by: Andrew Lamb --- README.md | 3 ++- datafusion-cli/Cargo.lock | 1 - datafusion/common/Cargo.toml | 4 +++- datafusion/common/src/tree_node.rs | 14 ++++++------- datafusion/expr/Cargo.toml | 4 +++- datafusion/expr/src/expr_schema.rs | 3 +-- datafusion/expr/src/logical_plan/tree_node.rs | 13 ++++++------ datafusion/optimizer/Cargo.toml | 6 +++++- datafusion/optimizer/src/analyzer/subquery.rs | 7 +++---- .../optimizer/src/common_subexpr_eliminate.rs | 5 ++--- .../optimizer/src/eliminate_cross_join.rs | 21 +++++++++---------- .../optimizer/src/optimize_projections/mod.rs | 3 +-- datafusion/physical-optimizer/Cargo.toml | 6 +++++- .../src/aggregate_statistics.rs | 3 +-- datafusion/sql/Cargo.toml | 5 +++-- datafusion/sql/src/expr/mod.rs | 3 +-- datafusion/sql/src/set_expr.rs | 3 +-- 17 files changed, 54 insertions(+), 50 deletions(-) diff --git a/README.md b/README.md index f199021d7d78..c2ede4833e9b 100644 --- a/README.md +++ b/README.md @@ -112,7 +112,8 @@ Default features: - `parquet`: support for reading the [Apache Parquet] format - `regex_expressions`: regular expression functions, such as `regexp_match` - `unicode_expressions`: Include unicode aware functions such as `character_length` -- `unparser` : enables support to reverse LogicalPlans back into SQL +- `unparser`: enables support to reverse LogicalPlans back into SQL +- `recursive-protection`: uses [recursive](https://docs.rs/recursive/latest/recursive/) for stack overflow protection. Optional features: diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index 9549cfeeb3b8..9af27a90bc2a 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -1544,7 +1544,6 @@ dependencies = [ "indexmap", "itertools", "log", - "recursive", "regex", "regex-syntax", ] diff --git a/datafusion/common/Cargo.toml b/datafusion/common/Cargo.toml index a81ec724dd66..918f0cd583f7 100644 --- a/datafusion/common/Cargo.toml +++ b/datafusion/common/Cargo.toml @@ -36,10 +36,12 @@ name = "datafusion_common" path = "src/lib.rs" [features] +default = ["recursive-protection"] avro = ["apache-avro"] backtrace = [] pyarrow = ["pyo3", "arrow/pyarrow", "parquet"] force_hash_collisions = [] +recursive-protection = ["dep:recursive"] [dependencies] ahash = { workspace = true } @@ -62,7 +64,7 @@ object_store = { workspace = true, optional = true } parquet = { workspace = true, optional = true, default-features = true } paste = "1.0.15" pyo3 = { version = "0.22.0", optional = true } -recursive = { workspace = true } +recursive = { workspace = true, optional = true } sqlparser = { workspace = true } tokio = { workspace = true } diff --git a/datafusion/common/src/tree_node.rs b/datafusion/common/src/tree_node.rs index 0c153583e34b..d92a2cc34b56 100644 --- a/datafusion/common/src/tree_node.rs +++ b/datafusion/common/src/tree_node.rs @@ -18,7 +18,6 @@ //! [`TreeNode`] for visiting and rewriting expression and plan trees use crate::Result; -use recursive::recursive; use std::collections::HashMap; use std::hash::Hash; use std::sync::Arc; @@ -125,7 +124,7 @@ pub trait TreeNode: Sized { /// TreeNodeVisitor::f_up(ChildNode2) /// TreeNodeVisitor::f_up(ParentNode) /// ``` - #[recursive] + #[cfg_attr(feature = "recursive-protection", recursive::recursive)] fn visit<'n, V: TreeNodeVisitor<'n, Node = Self>>( &'n self, visitor: &mut V, @@ -175,7 +174,7 @@ pub trait TreeNode: Sized { /// TreeNodeRewriter::f_up(ChildNode2) /// TreeNodeRewriter::f_up(ParentNode) /// ``` - #[recursive] + #[cfg_attr(feature = "recursive-protection", recursive::recursive)] fn rewrite>( self, rewriter: &mut R, @@ -198,7 +197,7 @@ pub trait TreeNode: Sized { &'n self, mut f: F, ) -> Result { - #[recursive] + #[cfg_attr(feature = "recursive-protection", recursive::recursive)] fn apply_impl<'n, N: TreeNode, F: FnMut(&'n N) -> Result>( node: &'n N, f: &mut F, @@ -233,7 +232,7 @@ pub trait TreeNode: Sized { self, mut f: F, ) -> Result> { - #[recursive] + #[cfg_attr(feature = "recursive-protection", recursive::recursive)] fn transform_down_impl Result>>( node: N, f: &mut F, @@ -257,7 +256,7 @@ pub trait TreeNode: Sized { self, mut f: F, ) -> Result> { - #[recursive] + #[cfg_attr(feature = "recursive-protection", recursive::recursive)] fn transform_up_impl Result>>( node: N, f: &mut F, @@ -372,7 +371,7 @@ pub trait TreeNode: Sized { mut f_down: FD, mut f_up: FU, ) -> Result> { - #[recursive] + #[cfg_attr(feature = "recursive-protection", recursive::recursive)] fn transform_down_up_impl< N: TreeNode, FD: FnMut(N) -> Result>, @@ -2350,6 +2349,7 @@ pub(crate) mod tests { Ok(()) } + #[cfg(feature = "recursive-protection")] #[test] fn test_large_tree() { let mut item = TestTreeNode::new_leaf("initial".to_string()); diff --git a/datafusion/expr/Cargo.toml b/datafusion/expr/Cargo.toml index 2f41292f680f..403a80972c3b 100644 --- a/datafusion/expr/Cargo.toml +++ b/datafusion/expr/Cargo.toml @@ -36,6 +36,8 @@ name = "datafusion_expr" path = "src/lib.rs" [features] +default = ["recursive-protection"] +recursive-protection = ["dep:recursive"] [dependencies] arrow = { workspace = true } @@ -48,7 +50,7 @@ datafusion-functions-window-common = { workspace = true } datafusion-physical-expr-common = { workspace = true } indexmap = { workspace = true } paste = "^1.0" -recursive = { workspace = true } +recursive = { workspace = true, optional = true } serde_json = { workspace = true } sqlparser = { workspace = true } diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index 3317deafbd6c..e61904e24918 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -32,7 +32,6 @@ use datafusion_common::{ TableReference, }; use datafusion_functions_window_common::field::WindowUDFFieldArgs; -use recursive::recursive; use std::collections::HashMap; use std::sync::Arc; @@ -100,7 +99,7 @@ impl ExprSchemable for Expr { /// expression refers to a column that does not exist in the /// schema, or when the expression is incorrectly typed /// (e.g. `[utf8] + [bool]`). - #[recursive] + #[cfg_attr(feature = "recursive-protection", recursive::recursive)] fn get_type(&self, schema: &dyn ExprSchema) -> Result { match self { Expr::Alias(Alias { expr, name, .. }) => match &**expr { diff --git a/datafusion/expr/src/logical_plan/tree_node.rs b/datafusion/expr/src/logical_plan/tree_node.rs index 1539b69b4007..cdc95b84d837 100644 --- a/datafusion/expr/src/logical_plan/tree_node.rs +++ b/datafusion/expr/src/logical_plan/tree_node.rs @@ -45,7 +45,6 @@ use crate::{ UserDefinedLogicalNode, Values, Window, }; use datafusion_common::tree_node::TreeNodeRefContainer; -use recursive::recursive; use crate::expr::{Exists, InSubquery}; use datafusion_common::tree_node::{ @@ -669,7 +668,7 @@ impl LogicalPlan { /// Visits a plan similarly to [`Self::visit`], including subqueries that /// may appear in expressions such as `IN (SELECT ...)`. - #[recursive] + #[cfg_attr(feature = "recursive-protection", recursive::recursive)] pub fn visit_with_subqueries TreeNodeVisitor<'n, Node = Self>>( &self, visitor: &mut V, @@ -688,7 +687,7 @@ impl LogicalPlan { /// Similarly to [`Self::rewrite`], rewrites this node and its inputs using `f`, /// including subqueries that may appear in expressions such as `IN (SELECT /// ...)`. - #[recursive] + #[cfg_attr(feature = "recursive-protection", recursive::recursive)] pub fn rewrite_with_subqueries>( self, rewriter: &mut R, @@ -707,7 +706,7 @@ impl LogicalPlan { &self, mut f: F, ) -> Result { - #[recursive] + #[cfg_attr(feature = "recursive-protection", recursive::recursive)] fn apply_with_subqueries_impl< F: FnMut(&LogicalPlan) -> Result, >( @@ -742,7 +741,7 @@ impl LogicalPlan { self, mut f: F, ) -> Result> { - #[recursive] + #[cfg_attr(feature = "recursive-protection", recursive::recursive)] fn transform_down_with_subqueries_impl< F: FnMut(LogicalPlan) -> Result>, >( @@ -767,7 +766,7 @@ impl LogicalPlan { self, mut f: F, ) -> Result> { - #[recursive] + #[cfg_attr(feature = "recursive-protection", recursive::recursive)] fn transform_up_with_subqueries_impl< F: FnMut(LogicalPlan) -> Result>, >( @@ -795,7 +794,7 @@ impl LogicalPlan { mut f_down: FD, mut f_up: FU, ) -> Result> { - #[recursive] + #[cfg_attr(feature = "recursive-protection", recursive::recursive)] fn transform_down_up_with_subqueries_impl< FD: FnMut(LogicalPlan) -> Result>, FU: FnMut(LogicalPlan) -> Result>, diff --git a/datafusion/optimizer/Cargo.toml b/datafusion/optimizer/Cargo.toml index 9979df689b0a..3032c67682b1 100644 --- a/datafusion/optimizer/Cargo.toml +++ b/datafusion/optimizer/Cargo.toml @@ -35,6 +35,10 @@ workspace = true name = "datafusion_optimizer" path = "src/lib.rs" +[features] +default = ["recursive-protection"] +recursive-protection = ["dep:recursive"] + [dependencies] arrow = { workspace = true } chrono = { workspace = true } @@ -44,7 +48,7 @@ datafusion-physical-expr = { workspace = true } indexmap = { workspace = true } itertools = { workspace = true } log = { workspace = true } -recursive = { workspace = true } +recursive = { workspace = true, optional = true } regex = { workspace = true } regex-syntax = "0.8.0" diff --git a/datafusion/optimizer/src/analyzer/subquery.rs b/datafusion/optimizer/src/analyzer/subquery.rs index fee06eeb9f75..0d04efbcf36a 100644 --- a/datafusion/optimizer/src/analyzer/subquery.rs +++ b/datafusion/optimizer/src/analyzer/subquery.rs @@ -17,7 +17,6 @@ use crate::analyzer::check_plan; use crate::utils::collect_subquery_cols; -use recursive::recursive; use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; use datafusion_common::{plan_err, Result}; @@ -79,7 +78,7 @@ pub fn check_subquery_expr( match outer_plan { LogicalPlan::Projection(_) | LogicalPlan::Filter(_) => Ok(()), - LogicalPlan::Aggregate(Aggregate {group_expr, aggr_expr,..}) => { + LogicalPlan::Aggregate(Aggregate { group_expr, aggr_expr, .. }) => { if group_expr.contains(expr) && !aggr_expr.contains(expr) { // TODO revisit this validation logic plan_err!( @@ -88,7 +87,7 @@ pub fn check_subquery_expr( } else { Ok(()) } - }, + } _ => plan_err!( "Correlated scalar subquery can only be used in Projection, Filter, Aggregate plan nodes" ) @@ -129,7 +128,7 @@ fn check_correlations_in_subquery(inner_plan: &LogicalPlan) -> Result<()> { } // Recursively check the unsupported outer references in the sub query plan. -#[recursive] +#[cfg_attr(feature = "recursive-protection", recursive::recursive)] fn check_inner_plan(inner_plan: &LogicalPlan, can_contain_outer_ref: bool) -> Result<()> { if !can_contain_outer_ref && inner_plan.contains_outer_reference() { return plan_err!("Accessing outer reference columns is not allowed in the plan"); diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index e7c9a198f3ad..ff75a6a60f4b 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -22,7 +22,6 @@ use std::fmt::Debug; use std::sync::Arc; use crate::{OptimizerConfig, OptimizerRule}; -use recursive::recursive; use crate::optimizer::ApplyOrder; use crate::utils::NamePreserver; @@ -532,7 +531,7 @@ impl OptimizerRule for CommonSubexprEliminate { None } - #[recursive] + #[cfg_attr(feature = "recursive-protection", recursive::recursive)] fn rewrite( &self, plan: LogicalPlan, @@ -952,7 +951,7 @@ mod test { )? .build()?; - let expected ="Aggregate: groupBy=[[]], aggr=[[avg(__common_expr_1) AS col1, my_agg(__common_expr_1) AS col2]]\ + let expected = "Aggregate: groupBy=[[]], aggr=[[avg(__common_expr_1) AS col1, my_agg(__common_expr_1) AS col2]]\ \n Projection: UInt32(1) + test.a AS __common_expr_1, test.a, test.b, test.c\ \n TableScan: test"; diff --git a/datafusion/optimizer/src/eliminate_cross_join.rs b/datafusion/optimizer/src/eliminate_cross_join.rs index 32b7ce44a63a..9a47f437e444 100644 --- a/datafusion/optimizer/src/eliminate_cross_join.rs +++ b/datafusion/optimizer/src/eliminate_cross_join.rs @@ -17,7 +17,6 @@ //! [`EliminateCrossJoin`] converts `CROSS JOIN` to `INNER JOIN` if join predicates are available. use crate::{OptimizerConfig, OptimizerRule}; -use recursive::recursive; use std::sync::Arc; use crate::join_key_set::JoinKeySet; @@ -80,7 +79,7 @@ impl OptimizerRule for EliminateCrossJoin { true } - #[recursive] + #[cfg_attr(feature = "recursive-protection", recursive::recursive)] fn rewrite( &self, plan: LogicalPlan, @@ -651,7 +650,7 @@ mod tests { " Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", " Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]", + " TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]", " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]" ]; @@ -1237,10 +1236,10 @@ mod tests { .build()?; let expected = vec![ - "Filter: t1.a + UInt32(100) = t2.a * UInt32(2) OR t2.b = t1.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", + "Filter: t1.a + UInt32(100) = t2.a * UInt32(2) OR t2.b = t1.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", + " Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", + " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", + " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", ]; assert_optimized_plan_eq(plan, expected); @@ -1293,10 +1292,10 @@ mod tests { .build()?; let expected = vec![ - "Filter: t2.c < UInt32(15) OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t1.a + UInt32(100) = t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", + "Filter: t2.c < UInt32(15) OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", + " Inner Join: t1.a + UInt32(100) = t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", + " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", + " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", ]; assert_optimized_plan_eq(plan, expected); diff --git a/datafusion/optimizer/src/optimize_projections/mod.rs b/datafusion/optimizer/src/optimize_projections/mod.rs index 1519c54dbf68..7c8e4120ea20 100644 --- a/datafusion/optimizer/src/optimize_projections/mod.rs +++ b/datafusion/optimizer/src/optimize_projections/mod.rs @@ -21,7 +21,6 @@ mod required_indices; use crate::optimizer::ApplyOrder; use crate::{OptimizerConfig, OptimizerRule}; -use recursive::recursive; use std::collections::HashSet; use std::sync::Arc; @@ -110,7 +109,7 @@ impl OptimizerRule for OptimizeProjections { /// columns. /// - `Ok(None)`: Signal that the given logical plan did not require any change. /// - `Err(error)`: An error occurred during the optimization process. -#[recursive] +#[cfg_attr(feature = "recursive-protection", recursive::recursive)] fn optimize_projections( plan: LogicalPlan, config: &dyn OptimizerConfig, diff --git a/datafusion/physical-optimizer/Cargo.toml b/datafusion/physical-optimizer/Cargo.toml index 838617ae9889..c964ca47e6a0 100644 --- a/datafusion/physical-optimizer/Cargo.toml +++ b/datafusion/physical-optimizer/Cargo.toml @@ -31,6 +31,10 @@ rust-version = { workspace = true } [lints] workspace = true +[features] +default = ["recursive-protection"] +recursive-protection = ["dep:recursive"] + [dependencies] arrow = { workspace = true } datafusion-common = { workspace = true, default-features = true } @@ -40,7 +44,7 @@ datafusion-physical-expr = { workspace = true } datafusion-physical-plan = { workspace = true } itertools = { workspace = true } log = { workspace = true } -recursive = { workspace = true } +recursive = { workspace = true, optional = true } [dev-dependencies] datafusion-expr = { workspace = true } diff --git a/datafusion/physical-optimizer/src/aggregate_statistics.rs b/datafusion/physical-optimizer/src/aggregate_statistics.rs index 87077183110d..dffdc49adf09 100644 --- a/datafusion/physical-optimizer/src/aggregate_statistics.rs +++ b/datafusion/physical-optimizer/src/aggregate_statistics.rs @@ -25,7 +25,6 @@ use datafusion_physical_plan::placeholder_row::PlaceholderRowExec; use datafusion_physical_plan::projection::ProjectionExec; use datafusion_physical_plan::udaf::{AggregateFunctionExpr, StatisticsArgs}; use datafusion_physical_plan::{expressions, ExecutionPlan}; -use recursive::recursive; use std::sync::Arc; use crate::PhysicalOptimizerRule; @@ -42,7 +41,7 @@ impl AggregateStatistics { } impl PhysicalOptimizerRule for AggregateStatistics { - #[recursive] + #[cfg_attr(feature = "recursive-protection", recursive::recursive)] fn optimize( &self, plan: Arc, diff --git a/datafusion/sql/Cargo.toml b/datafusion/sql/Cargo.toml index e1e4d8df3d22..c6500e974206 100644 --- a/datafusion/sql/Cargo.toml +++ b/datafusion/sql/Cargo.toml @@ -36,9 +36,10 @@ name = "datafusion_sql" path = "src/lib.rs" [features] -default = ["unicode_expressions", "unparser"] +default = ["unicode_expressions", "unparser", "recursive-protection"] unicode_expressions = [] unparser = [] +recursive-protection = ["dep:recursive"] [dependencies] arrow = { workspace = true } @@ -49,7 +50,7 @@ datafusion-common = { workspace = true, default-features = true } datafusion-expr = { workspace = true } indexmap = { workspace = true } log = { workspace = true } -recursive = { workspace = true } +recursive = { workspace = true, optional = true } regex = { workspace = true } sqlparser = { workspace = true } diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs index a651d8fa5d35..7c4d8dd21d66 100644 --- a/datafusion/sql/src/expr/mod.rs +++ b/datafusion/sql/src/expr/mod.rs @@ -20,7 +20,6 @@ use arrow_schema::TimeUnit; use datafusion_expr::planner::{ PlannerResult, RawBinaryExpr, RawDictionaryExpr, RawFieldAccessExpr, }; -use recursive::recursive; use sqlparser::ast::{ BinaryOperator, CastFormat, CastKind, DataType as SQLDataType, DictionaryField, Expr as SQLExpr, ExprWithAlias as SQLExprWithAlias, MapEntry, StructField, Subscript, @@ -197,7 +196,7 @@ impl SqlToRel<'_, S> { /// Internal implementation. Use /// [`Self::sql_expr_to_logical_expr`] to plan exprs. - #[recursive] + #[cfg_attr(feature = "recursive-protection", recursive::recursive)] fn sql_expr_to_logical_expr_internal( &self, sql: SQLExpr, diff --git a/datafusion/sql/src/set_expr.rs b/datafusion/sql/src/set_expr.rs index 3b1201d3dd59..d1569c81d350 100644 --- a/datafusion/sql/src/set_expr.rs +++ b/datafusion/sql/src/set_expr.rs @@ -18,11 +18,10 @@ use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; use datafusion_common::{not_impl_err, Result}; use datafusion_expr::{LogicalPlan, LogicalPlanBuilder}; -use recursive::recursive; use sqlparser::ast::{SetExpr, SetOperator, SetQuantifier}; impl SqlToRel<'_, S> { - #[recursive] + #[cfg_attr(feature = "recursive-protection", recursive::recursive)] pub(super) fn set_expr_to_plan( &self, set_expr: SetExpr, From 242f45f4f19c9f25f8f084e8b2c534d9f14fa2d7 Mon Sep 17 00:00:00 2001 From: Jonah Gao Date: Sun, 22 Dec 2024 21:48:00 +0800 Subject: [PATCH 8/8] Minor: remove unused async-compression `futures-io` feature (#13875) * Minor: remove unused async-compression feature * Fix cli cargo lock --- datafusion-cli/Cargo.lock | 1 - datafusion/core/Cargo.toml | 1 - 2 files changed, 2 deletions(-) diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index 9af27a90bc2a..34505bee2e13 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -414,7 +414,6 @@ dependencies = [ "bzip2 0.4.4", "flate2", "futures-core", - "futures-io", "memchr", "pin-project-lite", "tokio", diff --git a/datafusion/core/Cargo.toml b/datafusion/core/Cargo.toml index 9bf530a9d6ac..dca40ab3d67a 100644 --- a/datafusion/core/Cargo.toml +++ b/datafusion/core/Cargo.toml @@ -87,7 +87,6 @@ async-compression = { version = "0.4.0", features = [ "gzip", "xz", "zstd", - "futures-io", "tokio", ], optional = true } async-trait = { workspace = true }