From 07760d3c45e557e702b4970c2c800d59b6029894 Mon Sep 17 00:00:00 2001 From: Vladimir Golubev Date: Fri, 24 Jan 2025 18:59:07 +0000 Subject: [PATCH] More single-pass analyzer functionality --- .../sql/catalyst/analysis/resolver.scala | 26 + .../AggregateExpressionResolver.scala | 156 ++++ .../analysis/resolver/AliasResolver.scala | 136 +-- .../resolver/AttributeScopeStack.scala | 2 +- .../resolver/BinaryArithmeticResolver.scala | 38 +- .../resolver/BridgedRelationsProvider.scala | 22 +- .../ConditionalExpressionResolver.scala | 12 +- .../catalyst/analysis/resolver/CteScope.scala | 257 ++++++ .../DelegatesResolutionToExtensions.scala | 8 +- ...ExplicitlyUnsupportedResolverFeature.scala | 1 - .../resolver/ExpressionIdAssigner.scala | 374 ++++++++ .../ExpressionResolutionContext.scala | 43 + .../ExpressionResolutionValidator.scala | 273 +----- .../resolver/ExpressionResolver.scala | 395 +++++++-- .../analysis/resolver/FunctionResolver.scala | 161 +++- .../analysis/resolver/HybridAnalyzer.scala | 16 +- .../analysis/resolver/IdentifierMap.scala | 9 - .../resolver/KeyTransformingMap.scala | 21 +- ...LateralColumnAliasProhibitedRegistry.scala | 44 + .../resolver/LateralColumnAliasRegistry.scala | 46 + .../LateralColumnAliasRegistryImpl.scala | 182 ++++ .../resolver/LimitExpressionResolver.scala | 10 +- .../analysis/resolver/MetadataResolver.scala | 25 +- .../analysis/resolver/NameScope.scala | 630 ++++++++------ .../analysis/resolver/NameTarget.scala | 85 +- .../analysis/resolver/PredicateResolver.scala | 19 +- .../resolver/ProhibitedResolver.scala | 35 + .../analysis/resolver/ProjectResolver.scala | 164 ++++ .../resolver/RelationMetadataProvider.scala | 7 + .../resolver/ResolutionValidator.scala | 137 ++- .../resolver/ResolvedProjectList.scala | 36 + .../catalyst/analysis/resolver/Resolver.scala | 407 +++++++-- .../analysis/resolver/ResolverExtension.scala | 14 +- .../analysis/resolver/ResolverGuard.scala | 135 ++- .../analysis/resolver/ResolverRunner.scala | 72 ++ .../analysis/resolver/TimeAddResolver.scala | 16 +- .../TimezoneAwareExpressionResolver.scala | 13 +- .../resolver/TracksResolvedNodes.scala | 51 -- .../analysis/resolver/TreeNodeResolver.scala | 4 +- .../resolver/TypeCoercionResolver.scala | 108 ++- .../resolver/UnaryMinusResolver.scala | 16 +- .../analysis/resolver/UnionResolver.scala | 378 ++++++++ .../analysis/resolver/ViewResolver.scala | 145 ++++ .../apache/spark/sql/internal/SQLConf.scala | 13 - .../LimitExpressionResolverSuite.scala | 9 +- .../resolver/ResolutionValidatorSuite.scala | 4 +- .../datasources/DataSourceResolver.scala | 17 +- .../execution/datasources/FileResolver.scala | 10 +- .../AggregateExpressionResolverSuite.scala | 72 ++ ...citlyUnsupportedResolverFeatureSuite.scala | 66 +- .../resolver/ExpressionIdAssignerSuite.scala | 818 ++++++++++++++++++ .../resolver/HybridAnalyzerSuite.scala | 30 +- .../resolver/MetadataResolverSuite.scala | 43 +- .../analysis/resolver/NameScopeSuite.scala | 466 ++++------ .../resolver/ResolverGuardSuite.scala | 68 +- .../sql/analysis/resolver/ResolverSuite.scala | 73 +- .../resolver/TracksResolvedNodesSuite.scala | 135 --- .../analysis/resolver/ViewResolverSuite.scala | 194 +++++ .../datasources/DataSourceResolverSuite.scala | 9 +- .../datasources/FileResolverSuite.scala | 10 +- .../sql/hive/DataSourceWithHiveResolver.scala | 32 +- .../DataSourceWithHiveResolverSuite.scala | 12 +- 62 files changed, 5147 insertions(+), 1663 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver.scala 
create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/AggregateExpressionResolver.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/CteScope.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ExpressionIdAssigner.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ExpressionResolutionContext.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/LateralColumnAliasProhibitedRegistry.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/LateralColumnAliasRegistry.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/LateralColumnAliasRegistryImpl.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ProhibitedResolver.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ProjectResolver.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ResolvedProjectList.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ResolverRunner.scala delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/TracksResolvedNodes.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/UnionResolver.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ViewResolver.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/analysis/resolver/AggregateExpressionResolverSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/analysis/resolver/ExpressionIdAssignerSuite.scala delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/analysis/resolver/TracksResolvedNodesSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/analysis/resolver/ViewResolverSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver.scala new file mode 100644 index 0000000000000..3cbe49c3156c6 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver.scala @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan + +package object resolver { + + type LogicalPlanResolver = TreeNodeResolver[LogicalPlan, LogicalPlan] + +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/AggregateExpressionResolver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/AggregateExpressionResolver.scala new file mode 100644 index 0000000000000..ccbb82a0bac34 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/AggregateExpressionResolver.scala @@ -0,0 +1,156 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis.resolver + +import org.apache.spark.sql.catalyst.analysis.{ + AnalysisErrorAt, + AnsiTypeCoercion, + CollationTypeCoercion, + TypeCoercion +} +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, Project} + +/** + * A resolver for [[AggregateExpression]]s which are introduced while resolving an + * [[UnresolvedFunction]]. It is responsible for the following: + * - Handling of the exceptions related to [[AggregateExpressions]]. + * - Updating the [[ExpressionResolver.expressionResolutionContextStack]]. + * - Applying type coercion rules to the [[AggregateExpressions]]s children. This is the only + * resolution that we apply here as we already resolved the children of [[AggregateExpression]] + * in the [[FunctionResolver]]. + */ +class AggregateExpressionResolver( + expressionResolver: ExpressionResolver, + timezoneAwareExpressionResolver: TimezoneAwareExpressionResolver) + extends TreeNodeResolver[AggregateExpression, Expression] + with ResolvesExpressionChildren { + private val typeCoercionTransformations: Seq[Expression => Expression] = + if (conf.ansiEnabled) { + AggregateExpressionResolver.ANSI_TYPE_COERCION_TRANSFORMATIONS + } else { + AggregateExpressionResolver.TYPE_COERCION_TRANSFORMATIONS + } + + private val typeCoercionResolver: TypeCoercionResolver = + new TypeCoercionResolver(timezoneAwareExpressionResolver, typeCoercionTransformations) + + private val expressionResolutionContextStack = + expressionResolver.getExpressionResolutionContextStack + + /** + * Resolves the given [[AggregateExpression]] by applying: + * - Type coercion rules + * - Validity checks. Those include: + * - Whether the [[AggregateExpression]] is under a valid operator. + * - Whether there is a nested [[AggregateExpression]]. + * - Whether there is a nondeterministic child. 
+ * - Updates to the [[ExpressionResolver.expressionResolutionContextStack]] + */ + override def resolve(aggregateExpression: AggregateExpression): Expression = { + val aggregateExpressionWithTypeCoercion = + withResolvedChildren(aggregateExpression, typeCoercionResolver.resolve) + + throwIfNotUnderValidOperator(aggregateExpression) + throwIfNestedAggregateExists(aggregateExpressionWithTypeCoercion) + throwIfHasNondeterministicChildren(aggregateExpressionWithTypeCoercion) + + expressionResolutionContextStack + .peek() + .hasAggregateExpressionsInASubtree = true + + // There are two different cases that we handle regarding the value of the flag: + // + // - We have an attribute under an `AggregateExpression`: + // {{{ SELECT COUNT(col1) FROM VALUES (1); }}} + // In this case, value of the `hasAttributeInASubtree` flag should be `false` as it + // indicates whether there is an attribute in the subtree that's not `AggregateExpression` + // so we can throw the `MISSING_GROUP_BY` exception appropriately. + // + // - In the following example: + // {{{ SELECT COUNT(*), col1 + 1 FROM VALUES (1); }}} + // It would be `true` as described above. + expressionResolutionContextStack.peek().hasAttributeInASubtree = false + + aggregateExpressionWithTypeCoercion + } + + private def throwIfNotUnderValidOperator(aggregateExpression: AggregateExpression): Unit = { + expressionResolver.getParentOperator.get match { + case _: Aggregate | _: Project => + case filter: Filter => + filter.failAnalysis( + errorClass = "INVALID_WHERE_CONDITION", + messageParameters = Map( + "condition" -> toSQLExpr(filter.condition), + "expressionList" -> Seq(aggregateExpression).mkString(", ") + ) + ) + case other => + other.failAnalysis( + errorClass = "UNSUPPORTED_EXPR_FOR_OPERATOR", + messageParameters = Map( + "invalidExprSqls" -> Seq(aggregateExpression).mkString(", ") + ) + ) + } + } + + private def throwIfNestedAggregateExists(aggregateExpression: AggregateExpression): Unit = { + if (expressionResolutionContextStack + .peek() + .hasAggregateExpressionsInASubtree) { + aggregateExpression.failAnalysis( + errorClass = "NESTED_AGGREGATE_FUNCTION", + messageParameters = Map.empty + ) + } + } + + private def throwIfHasNondeterministicChildren(aggregateExpression: AggregateExpression): Unit = { + aggregateExpression.aggregateFunction.children.foreach(child => { + if (!child.deterministic) { + child.failAnalysis( + errorClass = "AGGREGATE_FUNCTION_WITH_NONDETERMINISTIC_EXPRESSION", + messageParameters = Map("sqlExpr" -> toSQLExpr(aggregateExpression)) + ) + } + }) + } +} + +object AggregateExpressionResolver { + // Ordering in the list of type coercions should be in sync with the list in [[TypeCoercion]]. + private val TYPE_COERCION_TRANSFORMATIONS: Seq[Expression => Expression] = Seq( + CollationTypeCoercion.apply, + TypeCoercion.InTypeCoercion.apply, + TypeCoercion.FunctionArgumentTypeCoercion.apply, + TypeCoercion.IfTypeCoercion.apply, + TypeCoercion.ImplicitTypeCoercion.apply + ) + + // Ordering in the list of type coercions should be in sync with the list in [[AnsiTypeCoercion]]. 
+ private val ANSI_TYPE_COERCION_TRANSFORMATIONS: Seq[Expression => Expression] = Seq( + CollationTypeCoercion.apply, + AnsiTypeCoercion.InTypeCoercion.apply, + AnsiTypeCoercion.FunctionArgumentTypeCoercion.apply, + AnsiTypeCoercion.IfTypeCoercion.apply, + AnsiTypeCoercion.ImplicitTypeCoercion.apply + ) +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/AliasResolver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/AliasResolver.scala index 7b652437dbd8b..93efbc68f8077 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/AliasResolver.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/AliasResolver.scala @@ -17,14 +17,8 @@ package org.apache.spark.sql.catalyst.analysis.resolver -import org.apache.spark.sql.catalyst.analysis.{AliasResolution, UnresolvedAlias} -import org.apache.spark.sql.catalyst.expressions.{ - Alias, - Cast, - CreateNamedStruct, - Expression, - NamedExpression -} +import org.apache.spark.sql.catalyst.analysis.{AliasResolution, MultiAlias, UnresolvedAlias} +import org.apache.spark.sql.catalyst.expressions.{Alias, Expression, NamedExpression} /** * Resolver class that resolves unresolved aliases and handles user-specified aliases. @@ -34,109 +28,41 @@ class AliasResolver(expressionResolver: ExpressionResolver, scopes: NameScopeSta with ResolvesExpressionChildren { /** - * Resolves [[UnresolvedAlias]] by handling two specific cases: - * - Alias(CreateNamedStruct(...)) - instead of calling [[CreateNamedStructResolver]] which will - * clean up its inner aliases, we manually resolve [[CreateNamedStruct]]'s children, because we - * need to preserve inner aliases until after the alias name is computed. This is a hack because - * fixed-point analyzer computes [[Alias]] name before removing inner aliases. - * - Alias(...) - recursively call [[ExpressionResolver]] to resolve the child expression. - * - * After the children are resolved, call [[AliasResolution]] to compute the alias name. Finally, - * clean up inner aliases from [[CreateNamedStruct]]. + * Resolves [[UnresolvedAlias]] by resolving its child and computing the alias name by calling + * [[AliasResolution]] on the result. After resolving it, we assign a correct exprId to the + * resulting [[Alias]]. Here we allow inner aliases to persist until the end of single-pass + * resolution, after which they will be removed in the post-processing phase. 
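+ *
+ * For illustration (simplified plan shape, expression IDs omitted), for a query like:
+ *
+ * {{{ SELECT 1 AS a }}}
+ *
+ * the parsed plan is:
+ *
+ * Project [Alias(1, a)]
+ * +- OneRowRelation
+ *
+ * and after this step the [[Alias]] keeps its name but is remapped to a fresh expression ID via
+ * [[ExpressionIdAssigner.mapExpression]].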
*/ - override def resolve(unresolvedAlias: UnresolvedAlias): NamedExpression = { - val aliasWithResolvedChildren = withResolvedChildren( - unresolvedAlias, { - case createNamedStruct: CreateNamedStruct => - withResolvedChildren(createNamedStruct, expressionResolver.resolve) - case other => expressionResolver.resolve(other) - } - ) + override def resolve(unresolvedAlias: UnresolvedAlias): NamedExpression = + scopes.top.lcaRegistry.withNewLcaScope { + val aliasWithResolvedChildren = + withResolvedChildren(unresolvedAlias, expressionResolver.resolve) - val resolvedAlias = - AliasResolution.resolve(aliasWithResolvedChildren).asInstanceOf[NamedExpression] + val resolvedAlias = + AliasResolution.resolve(aliasWithResolvedChildren).asInstanceOf[NamedExpression] - scopes.top.addAlias(resolvedAlias.name) - AliasResolver.cleanupAliases(resolvedAlias) - } + resolvedAlias match { + case multiAlias: MultiAlias => + throw new ExplicitlyUnsupportedResolverFeature( + s"unsupported expression: ${multiAlias.getClass.getName}" + ) + case alias: Alias => + expressionResolver.getExpressionIdAssigner + .mapExpression(alias) + .asInstanceOf[Alias] + } + } /** - * Handle already resolved [[Alias]] nodes, i.e. user-specified aliases. We disallow stacking - * of [[Alias]] nodes by collapsing them so that only the top node remains. - * - * For an example query like: - * - * {{{ SELECT 1 AS a }}} - * - * parsed plan will be: - * - * Project [Alias(1, a)] - * +- OneRowRelation - * + * Handle already resolved [[Alias]] nodes, i.e. user-specified aliases. Here we only need to + * resolve its children and afterwards reassign exprId to the resulting [[Alias]]. */ def handleResolvedAlias(alias: Alias): Alias = { - val aliasWithResolvedChildren = withResolvedChildren(alias, expressionResolver.resolve) - scopes.top.addAlias(aliasWithResolvedChildren.name) - AliasResolver.collapseAlias(aliasWithResolvedChildren) - } -} - -object AliasResolver { - - /** - * For a query like: - * - * {{{ SELECT STRUCT(1 AS a, 2 AS b) AS st }}} - * - * After resolving [[CreateNamedStruct]] the plan will be: - * CreateNamedStruct(Seq("a", Alias(1, "a"), "b", Alias(2, "b"))) - * - * For a query like: - * - * {{{ df.select($"col1".cast("int").cast("double")) }}} - * - * After resolving top-most [[Alias]] the plan will be: - * Alias(Cast(Alias(Cast(col1, int), col1)), double), col1) - * - * Both examples contain inner aliases that are not expected in the analyzed logical plan, - * therefore need to be removed. However, in both examples inner aliases are necessary in order - * for the outer alias to compute its name. To achieve this, we delay removal of inner aliases - * until after the outer alias name is computed. - * - * For cases where there are no dependencies on inner alias, inner alias should be removed by the - * resolver that produces it. - */ - private def cleanupAliases(namedExpression: NamedExpression): NamedExpression = - namedExpression - .withNewChildren(namedExpression.children.map { - case cast @ Cast(alias: Alias, _, _, _) => - cast.copy(child = alias.child) - case createNamedStruct: CreateNamedStruct => - CreateNamedStructResolver.cleanupAliases(createNamedStruct) - case other => other - }) - .asInstanceOf[NamedExpression] - - /** - * If an [[Alias]] node appears on top of another [[Alias]], remove the bottom one. 
Here we don't - * handle a case where a node of different type appears between two [[Alias]] nodes: in this - * case, removal of inner alias (if it is unnecessary) should be handled by respective node's - * resolver, in order to preserve the bottom-up contract. - */ - private def collapseAlias(alias: Alias): Alias = - alias.child match { - case innerAlias: Alias => - val metadata = if (alias.metadata.isEmpty) { - None - } else { - Some(alias.metadata) - } - alias.copy(child = innerAlias.child)( - exprId = alias.exprId, - qualifier = alias.qualifier, - explicitMetadata = metadata, - nonInheritableMetadataKeys = alias.nonInheritableMetadataKeys - ) - case _ => alias + scopes.top.lcaRegistry.withNewLcaScope { + val aliasWithResolvedChildren = withResolvedChildren(alias, expressionResolver.resolve) + expressionResolver.getExpressionIdAssigner + .mapExpression(aliasWithResolvedChildren) + .asInstanceOf[Alias] } + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/AttributeScopeStack.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/AttributeScopeStack.scala index 6f9d6defd2edb..ac50b0f511a12 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/AttributeScopeStack.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/AttributeScopeStack.scala @@ -59,7 +59,7 @@ class AttributeScopeStack { /** * Overwrite current relevant scope with a sequence of attributes which is an output of some * operator. `attributes` can have duplicate IDs if the output of the operator contains multiple - * occurrences of the same attribute. + * occurencies of the same attribute. */ def overwriteTop(attributes: Seq[Attribute]): Unit = { stack.pop() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/BinaryArithmeticResolver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/BinaryArithmeticResolver.scala index 7d9c6752094d7..d2b586f3d372d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/BinaryArithmeticResolver.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/BinaryArithmeticResolver.scala @@ -37,7 +37,6 @@ import org.apache.spark.sql.catalyst.expressions.{ Subtract, SubtractDates } -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{DateType, StringType} /** @@ -90,17 +89,14 @@ class BinaryArithmeticResolver( extends TreeNodeResolver[BinaryArithmetic, Expression] with ProducesUnresolvedSubtree { - private val shouldTrackResolvedNodes = - conf.getConf(SQLConf.ANALYZER_SINGLE_PASS_TRACK_RESOLVED_NODES_ENABLED) - - private val typeCoercionRules: Seq[Expression => Expression] = + private val typeCoercionTransformations: Seq[Expression => Expression] = if (conf.ansiEnabled) { - BinaryArithmeticResolver.ANSI_TYPE_COERCION_RULES + BinaryArithmeticResolver.ANSI_TYPE_COERCION_TRANSFORMATIONS } else { - BinaryArithmeticResolver.TYPE_COERCION_RULES + BinaryArithmeticResolver.TYPE_COERCION_TRANSFORMATIONS } private val typeCoercionResolver: TypeCoercionResolver = - new TypeCoercionResolver(timezoneAwareExpressionResolver, typeCoercionRules) + new TypeCoercionResolver(timezoneAwareExpressionResolver, typeCoercionTransformations) override def resolve(unresolvedBinaryArithmetic: BinaryArithmetic): Expression = { val binaryArithmeticWithResolvedChildren: BinaryArithmetic = @@ -109,11 +105,10 @@ class 
BinaryArithmeticResolver( withResolvedSubtree(binaryArithmeticWithResolvedChildren, expressionResolver.resolve) { transformBinaryArithmeticNode(binaryArithmeticWithResolvedChildren) } - val binaryArithmeticWithResolvedTimezone = timezoneAwareExpressionResolver.withResolvedTimezone( + timezoneAwareExpressionResolver.withResolvedTimezone( binaryArithmeticWithResolvedSubtree, conf.sessionLocalTimeZone ) - reallocateKnownNodesForTracking(binaryArithmeticWithResolvedTimezone) } /** @@ -156,30 +151,11 @@ class BinaryArithmeticResolver( BinaryArithmeticWithDatetimeResolver.resolve(arithmetic) case other => other } - - /** - * Since [[TracksResolvedNodes]] requires all the expressions in the tree to be unique objects, - * we reallocate the known nodes in [[ANALYZER_SINGLE_PASS_TRACK_RESOLVED_NODES_ENABLED]] mode, - * otherwise we preserve the old object to avoid unnecessary memory allocations. - */ - private def reallocateKnownNodesForTracking(expression: Expression): Expression = { - if (shouldTrackResolvedNodes) { - expression match { - case add: Add => add.copy() - case subtract: Subtract => subtract.copy() - case multiply: Multiply => multiply.copy() - case divide: Divide => divide.copy() - case _ => expression - } - } else { - expression - } - } } object BinaryArithmeticResolver { // Ordering in the list of type coercions should be in sync with the list in [[TypeCoercion]]. - private val TYPE_COERCION_RULES: Seq[Expression => Expression] = Seq( + private val TYPE_COERCION_TRANSFORMATIONS: Seq[Expression => Expression] = Seq( StringPromotionTypeCoercion.apply, DecimalPrecisionTypeCoercion.apply, DivisionTypeCoercion.apply, @@ -189,7 +165,7 @@ object BinaryArithmeticResolver { ) // Ordering in the list of type coercions should be in sync with the list in [[AnsiTypeCoercion]]. - private val ANSI_TYPE_COERCION_RULES: Seq[Expression => Expression] = Seq( + private val ANSI_TYPE_COERCION_TRANSFORMATIONS: Seq[Expression => Expression] = Seq( AnsiStringPromotionTypeCoercion.apply, DecimalPrecisionTypeCoercion.apply, DivisionTypeCoercion.apply, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/BridgedRelationsProvider.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/BridgedRelationsProvider.scala index bc7a9df064c33..a33675e9dfd09 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/BridgedRelationsProvider.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/BridgedRelationsProvider.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.analysis.resolver import org.apache.spark.sql.catalyst.analysis.RelationResolution +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.connector.catalog.CatalogManager /** @@ -32,19 +33,28 @@ class BridgedRelationMetadataProvider( override val relationResolution: RelationResolution, analyzerBridgeState: AnalyzerBridgeState ) extends RelationMetadataProvider { - override val relationsWithResolvedMetadata = getRelationsFromBridgeState(analyzerBridgeState) + override val relationsWithResolvedMetadata = new RelationsWithResolvedMetadata + updateRelationsWithResolvedMetadata() - private def getRelationsFromBridgeState( - analyzerBridgeState: AnalyzerBridgeState): RelationsWithResolvedMetadata = { - val result = new RelationsWithResolvedMetadata + /** + * We update relations on each [[resolve]] call, because relation IDs might have changed. 
+ * This can happen for the nested views, since catalog name may differ, and expanded table name + * will differ for the same [[UnresolvedRelation]]. + * + * See [[ViewResolver.resolve]] for more info on how SQL configs are propagated to nested views). + */ + override def resolve(unresolvedPlan: LogicalPlan): Unit = { + updateRelationsWithResolvedMetadata() + } + + private def updateRelationsWithResolvedMetadata(): Unit = { analyzerBridgeState.relationsWithResolvedMetadata.forEach( (unresolvedRelation, relationWithResolvedMetadata) => { - result.put( + relationsWithResolvedMetadata.put( relationIdFromUnresolvedRelation(unresolvedRelation), relationWithResolvedMetadata ) } ) - result } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ConditionalExpressionResolver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ConditionalExpressionResolver.scala index 75ba1b7a01a5c..548c824b24f5e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ConditionalExpressionResolver.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ConditionalExpressionResolver.scala @@ -31,14 +31,14 @@ class ConditionalExpressionResolver( with ResolvesExpressionChildren with SQLConfHelper { - private val typeCoercionRules: Seq[Expression => Expression] = + private val typeCoercionTransformations: Seq[Expression => Expression] = if (conf.ansiEnabled) { - ConditionalExpressionResolver.ANSI_TYPE_COERCION_RULES + ConditionalExpressionResolver.ANSI_TYPE_COERCION_TRANSFORMATIONS } else { - ConditionalExpressionResolver.TYPE_COERCION_RULES + ConditionalExpressionResolver.TYPE_COERCION_TRANSFORMATIONS } private val typeCoercionResolver: TypeCoercionResolver = - new TypeCoercionResolver(timezoneAwareExpressionResolver, typeCoercionRules) + new TypeCoercionResolver(timezoneAwareExpressionResolver, typeCoercionTransformations) override def resolve(unresolvedConditionalExpression: ConditionalExpression): Expression = { val conditionalExpressionWithResolvedChildren = @@ -50,14 +50,14 @@ class ConditionalExpressionResolver( object ConditionalExpressionResolver { // Ordering in the list of type coercions should be in sync with the list in [[TypeCoercion]]. - private val TYPE_COERCION_RULES: Seq[Expression => Expression] = Seq( + private val TYPE_COERCION_TRANSFORMATIONS: Seq[Expression => Expression] = Seq( TypeCoercion.CaseWhenTypeCoercion.apply, TypeCoercion.FunctionArgumentTypeCoercion.apply, TypeCoercion.IfTypeCoercion.apply ) // Ordering in the list of type coercions should be in sync with the list in [[AnsiTypeCoercion]]. - private val ANSI_TYPE_COERCION_RULES: Seq[Expression => Expression] = Seq( + private val ANSI_TYPE_COERCION_TRANSFORMATIONS: Seq[Expression => Expression] = Seq( AnsiTypeCoercion.CaseWhenTypeCoercion.apply, AnsiTypeCoercion.FunctionArgumentTypeCoercion.apply, AnsiTypeCoercion.IfTypeCoercion.apply diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/CteScope.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/CteScope.scala new file mode 100644 index 0000000000000..83019b9b7359a --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/CteScope.scala @@ -0,0 +1,257 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis.resolver + +import java.util.{ArrayDeque, ArrayList} + +import scala.jdk.CollectionConverters._ + +import org.apache.spark.sql.catalyst.plans.logical.CTERelationDef + +/** + * The [[CteScope]] is responsible for keeping track of visible and known CTE definitions at a given + * stage of a SQL query/DataFrame program resolution. These scopes are stacked and the stack is + * managed by the [[CteRegistry]]. The scope is created per single WITH clause. + * + * The CTE operators are: + * - [[UnresolvedWith]]. This is a `host` operator that contains a list of unresolved CTE + * definitions from the WITH clause and a single child operator, which is the actual unresolved + * SELECT query. + * - [[UnresolvedRelation]]. This is a generic unresolved relation operator that will sometimes + * be resolved to a CTE definition and later replaced with a [[CTERelationRef]]. The CTE takes + * precedence over a regular table or a view when resolving this identifier. + * - [[CTERelationDef]]. This is a reusable logical plan, which will later be referenced by the + * lower CTE definitions and [[UnresolvedWith]] child. + * - [[CTERelationRef]]. This is a leaf node similar to a relation operator that references a + * certain [[CTERelationDef]] by its ID. It has a name (unique locally for a WITH clause list) + * and an ID (unique for all the CTEs in a query). + * - [[WithCTE]]. This is a `host` operator that contains a list of resolved CTE definitions from + * the WITH clause and a single child operator, which is the actual resolved SELECT query. + * + * The task of the [[Resolver]] is to correctly place [[WithCTE]] with [[CTERelationDef]]s inside + * and make sure that [[CTERelationRef]]s correctly reference [[CTERelationDef]]s with their IDs. + * The decision whether to inline those CTE subtrees or not is made by the [[Optimizer]], unlike + * what Spark does for the [[View]]s (always inline during the analysis). + * + * There are some caveats in how Spark places those operators and resolves their names: + * - Ambiguous CTE definition names are disallowed only within a single WITH clause, and this is + * validated by the Parser in [[AstBuilder]] + * using [[QueryParsingErrors.duplicateCteDefinitionNamesError]]: + * + * {{{ + * -- This is disallowed. + * WITH cte AS (SELECT 1), + * cte AS (SELECT 2) + * SELECT * FROM cte; + * }}} + * + * - When [[UnresolvedRelation]] identifier is resolved to a [[CTERelationDef]] and there is a + * name conflict on several layers of CTE definitions, the lower definitions take precedence: + * + * {{{ + * -- The result is `3`, lower [[CTERelationDef]] takes precedence. 
+ * WITH cte AS ( + * SELECT 1 + * ) + * SELECT * FROM ( + * WITH cte AS ( + * SELECT 2 + * ) + * SELECT * FROM ( + * WITH cte AS ( + * SELECT 3 + * ) + * SELECT * FROM cte + * ) + * ) + * }}} + * + * - Any subquery can contain [[UnresolvedWith]] on top of it, but [[WithCTE]] is not gonna be + * 1 to 1 to its unresolved counterpart. For example, if we are dealing with simple subqueries, + * [[CTERelationDef]]s will be merged together under a single [[WithCTE]]. The previous example + * would produce the following resolved plan: + * + * {{{ + * WithCTE + * :- CTERelationDef 18, false + * : +- ... + * :- CTERelationDef 19, false + * : +- ... + * :- CTERelationDef 20, false + * : +- ... + * +- Project [3#1203] + * : +- ... + * }}} + * + * - However, if we have any expression subquery (scalar/IN/EXISTS...), the top + * [[CTERelationDef]]s and subquery's [[CTERelationDef]] won't be merged together (as they are + * separated by an expression tree): + * + * {{{ + * WITH cte AS ( + * SELECT 1 AS col1 + * ) + * SELECT * FROM cte WHERE col1 IN ( + * WITH cte AS ( + * SELECT 2 + * ) + * SELECT * FROM cte + * ) + * }}} + * + * -> + * + * {{{ + * WithCTE + * :- CTERelationDef 21, false + * : +- ... + * +- Project [col1#1223] + * +- Filter col1#1223 IN (list#1222 []) + * : +- WithCTE + * : :- CTERelationDef 22, false + * : : +- ... + * : +- Project [2#1241] + * : +- ... + * +- ... + * }}} + * + * - Upper CTEs are visible through subqueries and can be referenced by lower operators, but not + * through the [[View]] boyndary: + * + * {{{ + * CREATE VIEW v1 AS SELECT 1; + * CREATE VIEW v2 AS SELECT * FROM v1; + * + * -- The result is 1. + * -- The `v2` body will be inlined in the main query tree during the analysis, but upper `v1` + * -- CTE definition _won't_ take precedence over the lower `v1` view. + * WITH v1 AS ( + * SELECT 2 + * ) + * SELECT * FROM v2; + * }}} + * + * @param isRoot This marks the place where [[WithCTE]] has to be placed with all the merged + * [[CTERelationDef]] that were collected under it. It will be true for root query, [[View]]s + * and expression subqueries. + * @param isOpaque This flag makes this [[CteScope]] opaque for [[CTERelationDef]] lookups. It will + * be true for root query and [[View]]s. + */ +class CteScope(val isRoot: Boolean, val isOpaque: Boolean) { + + /** + * Known [[CTERelationDef]]s that were already resolved in this scope or in child scopes. This is + * used to merge CTE definitions together in a single [[WithCTE]]. + */ + private val knownCtes = new ArrayList[CTERelationDef] + + /** + * Visible [[CTERelationDef]]s that were already resolved in this scope. Child scope definitions + * are _not_ visible. Upper definitions _are_ visible, but this is handled by + * [[CteRegistry.resolveCteName]] to avoid cascadingly growing [[IdentifierMap]]s. + */ + private val visibleCtes = new IdentifierMap[CTERelationDef] + + /** + * Register a new CTE definition in this scope. Since the scope is created per single WITH clause, + * there can be no name conflicts, but this is validated by the Parser in [[AstBuilder]] + * using [[QueryParsingErrors.duplicateCteDefinitionNamesError]]. This definition will be both + * known and visible. + */ + def registerCte(name: String, cteDef: CTERelationDef): Unit = { + knownCtes.add(cteDef) + visibleCtes.put(name, cteDef) + } + + /** + * Get a visible CTE definition by its name. + */ + def getCte(name: String): Option[CTERelationDef] = { + visibleCtes.get(name) + } + + /** + * Merge the state from a child scope. 
We transfer all the known CTE definitions to later merge + * them in one [[WithCTE]]. Root scopes terminate this chain, since they have their own + * [[WithCTE]]. + */ + def mergeChildScope(childScope: CteScope): Unit = { + if (!childScope.isRoot) { + knownCtes.addAll(childScope.knownCtes) + } + } + + /** + * Get all known (from this and child scopes) [[CTERelationDef]]s. This is used to construct + * [[WithCTE]] from a root scope. + */ + def getKnownCtes: Seq[CTERelationDef] = { + knownCtes.asScala.toSeq + } +} + +/** + * The [[CteRegistry]] is responsible for managing the stack of [[CteScope]]s and resolving visible + * [[CTERelationDef]] names. + */ +class CteRegistry { + private val stack = new ArrayDeque[CteScope] + stack.push(new CteScope(isRoot = true, isOpaque = true)) + + def currentScope: CteScope = stack.peek() + + /** + * A RAII-wrapper for pushing/popping scopes. This is used by the [[Resolver]] to create a new + * scope for each WITH clause. + */ + def withNewScope[R](isRoot: Boolean = false, isOpaque: Boolean = false)(body: => R): R = { + stack.push(new CteScope(isRoot = isRoot, isOpaque = isOpaque)) + + try { + body + } finally { + val childScope = stack.pop() + currentScope.mergeChildScope(childScope) + } + } + + /** + * Resolve `name` to a visible [[CTERelationDef]]. The upper definitions are also visible, and + * the lowest of them takes precedence. Opaque scopes terminate the lookup (e.g. [[View]] + * boundary). + */ + def resolveCteName(name: String): Option[CTERelationDef] = { + val iter = stack.iterator + var done = false + var result: Option[CTERelationDef] = None + while (iter.hasNext() && !done) { + val scope = iter.next() + + done = scope.isOpaque + + scope.getCte(name) match { + case Some(cte) => + result = Some(cte) + done = true + case None => + } + } + + result + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/DelegatesResolutionToExtensions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/DelegatesResolutionToExtensions.scala index 7d57e4683df40..7e1b3bfc1f138 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/DelegatesResolutionToExtensions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/DelegatesResolutionToExtensions.scala @@ -42,19 +42,21 @@ trait DelegatesResolutionToExtensions { * @throws `AMBIGUOUS_RESOLVER_EXTENSION` if there were several matched extensions for this * operator. 
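 *
 * As an illustration only (a hypothetical extension, not part of this patch), and assuming
 * [[ResolverExtension.resolveOperator]] now takes the unresolved operator together with a
 * [[LogicalPlanResolver]] and returns an [[Option]] (as the call sites below suggest), an
 * extension matching a single custom operator type could look roughly like:
 *
 * {{{
 *   class MyCustomRelationResolver extends ResolverExtension {
 *     override def resolveOperator(
 *         operator: LogicalPlan,
 *         resolver: LogicalPlanResolver): Option[LogicalPlan] = operator match {
 *       // Hypothetical operator type; resolve its child with the generic resolver and
 *       // signal a match by returning Some(...). Any other operator returns None.
 *       case custom: MyUnresolvedCustomRelation => Some(resolver.resolve(custom.child))
 *       case _ => None
 *     }
 *   }
 * }}}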
*/ - def tryDelegateResolutionToExtension(unresolvedOperator: LogicalPlan): Option[LogicalPlan] = { + protected def tryDelegateResolutionToExtension( + unresolvedOperator: LogicalPlan, + resolver: LogicalPlanResolver): Option[LogicalPlan] = { var resolutionResult: Option[LogicalPlan] = None var matchedExtension: Option[ResolverExtension] = None extensions.foreach { extension => matchedExtension match { case None => - resolutionResult = extension.resolveOperator.lift(unresolvedOperator) + resolutionResult = extension.resolveOperator(unresolvedOperator, resolver) if (resolutionResult.isDefined) { matchedExtension = Some(extension) } case Some(matchedExtension) => - if (extension.resolveOperator.isDefinedAt(unresolvedOperator)) { + if (extension.resolveOperator(unresolvedOperator, resolver).isDefined) { throw QueryCompilationErrors .ambiguousResolverExtension( unresolvedOperator, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ExplicitlyUnsupportedResolverFeature.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ExplicitlyUnsupportedResolverFeature.scala index e6279c9740395..c9b8dece77cec 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ExplicitlyUnsupportedResolverFeature.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ExplicitlyUnsupportedResolverFeature.scala @@ -49,7 +49,6 @@ class ExplicitlyUnsupportedResolverFeature(reason: String) */ object ExplicitlyUnsupportedResolverFeature { val OPERATORS = Set( - "org.apache.spark.sql.catalyst.plans.logical.View", "org.apache.spark.sql.catalyst.streaming.StreamingRelationV2", "org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation", "org.apache.spark.sql.execution.streaming.StreamingRelation" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ExpressionIdAssigner.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ExpressionIdAssigner.scala new file mode 100644 index 0000000000000..4622a8ebdb563 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ExpressionIdAssigner.scala @@ -0,0 +1,374 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.analysis.resolver + +import java.util.{ArrayDeque, HashMap, HashSet} + +import org.apache.spark.SparkException +import org.apache.spark.sql.catalyst.expressions.{ + Alias, + Attribute, + AttributeReference, + ExprId, + NamedExpression +} +import org.apache.spark.sql.errors.QueryCompilationErrors + +/** + * [[ExpressionIdAssigner]] is used by the [[ExpressionResolver]] to assign unique expression IDs to + * [[NamedExpression]]s ([[AttributeReference]]s and [[Alias]]es). This is necessary to ensure + * that Optimizer performs its work correctly and does not produce correctness issues. + * + * The framework works the following way: + * - Each leaf operator must have unique output IDs (even if it's the same table, view, or CTE). + * - The [[AttributeReference]]s get propagated "upwards" through the operator tree with their IDs + * preserved. + * - Each [[Alias]] gets assigned a new unique ID and it sticks with it after it gets converted to + * an [[AttributeReference]] when it is outputted from the operator that produced it. + * - Any operator may have [[AttributeReference]]s with the same IDs in its output given it is the + * same attribute. + * Thus, **no multi-child operator may have children with conflicting [[AttributeReference]] IDs**. + * In other words, two subtrees must not output the [[AttributeReference]]s with the same IDs, since + * relations, views and CTEs all output unique attributes, and [[Alias]]es get assigned new IDs as + * well. [[ExpressionIdAssigner.assertOutputsHaveNoConflictingExpressionIds]] is used to assert this + * invariant. + * + * For SQL queries, this framework provides correctness just by reallocating relation outputs and + * by validating the invariants mentioned above. Reallocation is done in + * [[Resolver.handleLeafOperator]]. If all the relations (even if it's the same table) have unique + * output IDs, the expression ID assignment will be correct, because there are no duplicate IDs in + * a pure unresolved tree. The old ID -> new ID mapping is not needed in this case. + * For example, consider this query: + * + * {{{ + * SELECT * FROM t AS t1 CROSS JOIN t AS t2 ON t1.col1 = t2.col1 + * }}} + * + * The analyzed plan should be: + * {{{ + * Project [col1#0, col2#1, col1#2, col2#3] + * +- Join Cross, (col1#0 = col1#2) + * :- SubqueryAlias t1 + * : +- Relation t[col1#0,col2#1] parquet + * +- SubqueryAlias t2 + * +- Relation t[col1#2,col2#3] parquet + * }}} + * + * and not: + * {{{ + * Project [col1#0, col2#1, col1#0, col2#1] + * +- Join Cross, (col1#0 = col1#0) + * :- SubqueryAlias t1 + * : +- Relation t[col1#0,col2#1] parquet + * +- SubqueryAlias t2 + * +- Relation t[col1#0,col2#1] parquet + * }}} + * + * Because in the latter case the join condition is always true. + * + * For DataFrame programs we need the full power of [[ExpressionIdAssigner]], and old ID -> new ID + * mapping comes in handy, because DataFrame programs pass _partially_ resolved plans to the + * [[Resolver]], which may consist of duplicate subtrees, and thus will have already assigned + * expression IDs. These already resolved dupliciate subtrees with assigned IDs will conflict. + * Hence, we need to reallocate all the leaf node outputs _and_ remap old IDs to the new ones. + * Also, DataFrame programs may introduce the same [[Alias]]es in different parts of the query plan, + * so we just reallocate all the [[Alias]]es. 
+ * + * For example, consider this DataFrame program: + * + * {{{ + * spark.range(0, 10).select($"id").write.format("parquet").saveAsTable("t") + * val alias = ($"id" + 1).as("id") + * spark.table("t").select(alias).select(alias) + * }}} + * + * The analyzed plan should be: + * {{{ + * Project [(id#6L + cast(1 as bigint)) AS id#13L] + * +- Project [(id#4L + cast(1 as bigint)) AS id#6L] + * +- SubqueryAlias spark_catalog.default.t + * +- Relation spark_catalog.default.t[id#4L] parquet + * }}} + * + * and not: + * {{{ + * Project [(id#6L + cast(1 as bigint)) AS id#6L] + * +- Project [(id#4L + cast(1 as bigint)) AS id#6L] + * +- SubqueryAlias spark_catalog.default.t + * +- Relation spark_catalog.default.t[id#4L] parquet + * }}} + * + * Because the latter case will confuse the Optimizer and the top [[Project]] will be eliminated + * leading to incorrect result. + * + * There's an important caveat here: the leftmost branch of a logical plan tree. In this branch we + * need to preserve the expression IDs wherever possible because DataFrames may reference each other + * using their attributes. This also makes sense for performance reasons. + * + * Consider this example: + * + * {{{ + * val df1 = spark.range(0, 10).select($"id") + * val df2 = spark.range(5, 15).select($"id") + * df1.union(df2).filter(df1("id") === 5) + * }}} + * + * In this example `df("id")` references lower `id` attribute by expression ID, so `union` must not + * reassign expression IDs in `df1` (left child). Referencing `df2` (right child) is not supported + * in Spark. + * + * The [[ExpressionIdAssigner]] covers both SQL and DataFrame scenarios with single approach and is + * integrated in the single-pass analysis framework. + * + * The [[ExpressionIdAssigner]] is used in the following way: + * - When the [[Resolver]] traverses the tree downwards prior to starting bottom-up analysis, + * we build the [[mappingStack]] by calling [[withNewMapping]] (i.e. [[mappingStack.push]]) + * for every child of a multi-child operator, so we have a separate stack entry (separate + * mapping) for each branch. This way sibling branches' mappings are isolated from each other and + * attribute IDs are reused only within the same branch. Initially we push `None`, because + * the mapping needs to be initialized later with the correct output of a resolved operator. + * - When the bottom-up analysis starts, we assign IDs to all the [[NamedExpression]]s which are + * present in operators starting from the [[LeafNode]]s using [[mapExpression]]. + * [[createMapping]] is called right after each [[LeafNode]] is resolved, and first remapped + * attributes come from that [[LeafNode]]. This is done in [[Resolver.handleLeafOperator]] for + * each logical plan tree branch except the leftmost. + * - Once the child branch is resolved, [[withNewMapping]] ends by calling [[mappingStack.pop]]. + * - After the multi-child operator is resolved, we call [[createMapping]] to + * initialize the mapping with attributes _chosen_ (e.g. [[Union.mergeChildOutputs]]) by that + * operator's resolution algorithm and remap _old_ expression IDs to those chosen attributes. + * - Continue remapping expressions until we reach the root of the operator tree. + */ +class ExpressionIdAssigner { + private val mappingStack = new ExpressionIdAssigner.Stack + mappingStack.push(ExpressionIdAssigner.StackEntry(isLeftmostBranch = true)) + + /** + * Returns `true` if the current logical plan branch is the leftmost branch. 
This is important + * in the context of preserving expression IDs in DataFrames. See class doc for more details. + */ + def isLeftmostBranch: Boolean = mappingStack.peek().isLeftmostBranch + + /** + * A RAII-wrapper for [[mappingStack.push]] and [[mappingStack.pop]]. [[Resolver]] uses this for + * every child of a multi-child operator to ensure that each operator branch uses an isolated + * expression ID mapping. + * + * @param isLeftmostChild whether the current child is the leftmost child of the operator that is + * being resolved. This is used to determine whether the new stack entry is gonna be in the + * leftmost logical plan branch. It's `false` by default, because it's safer to remap attributes + * than to leave duplicates (to prevent correctness issues). + */ + def withNewMapping[R](isLeftmostChild: Boolean = false)(body: => R): R = { + mappingStack.push( + ExpressionIdAssigner.StackEntry( + isLeftmostBranch = isLeftmostChild && isLeftmostBranch + ) + ) + try { + body + } finally { + mappingStack.pop() + } + } + + /** + * Create mapping with the given `newOutput` that rewrites the `oldOutput`. This + * is used by the [[Resolver]] after the multi-child operator is resolved to fill the current + * mapping with the attributes _chosen_ by that operator's resolution algorithm and remap _old_ + * expression IDs to those chosen attributes. It's also used by the [[ExpressionResolver]] right + * before remapping the attributes of a [[LeafNode]]. + * + * `oldOutput` is present for already resolved subtrees (e.g. DataFrames), but for SQL queries + * is will be `None`, because that logical plan is analyzed for the first time. + */ + def createMapping( + newOutput: Seq[Attribute] = Seq.empty, + oldOutput: Option[Seq[Attribute]] = None): Unit = { + if (mappingStack.peek().mapping.isDefined) { + throw SparkException.internalError( + s"Attempt to overwrite existing mapping. New output: $newOutput, old output: $oldOutput" + ) + } + + val newMapping = new ExpressionIdAssigner.Mapping + oldOutput match { + case Some(oldOutput) => + if (newOutput.length != oldOutput.length) { + throw SparkException.internalError( + s"Outputs have different lengths. New output: $newOutput, old output: $oldOutput" + ) + } + + newOutput.zip(oldOutput).foreach { + case (newAttribute, oldAttribute) => + newMapping.put(oldAttribute.exprId, newAttribute.exprId) + newMapping.put(newAttribute.exprId, newAttribute.exprId) + } + case None => + newOutput.foreach { newAttribute => + newMapping.put(newAttribute.exprId, newAttribute.exprId) + } + } + + mappingStack.push(mappingStack.pop().copy(mapping = Some(newMapping))) + } + + /** + * Assign a correct ID to the given [[originalExpression]] and return a new instance of that + * expression, or return a corresponding new instance of the same attribute, that was previously + * reallocated and is present in the current [[mappingStack]] entry. + * + * For [[Alias]]es: Try to preserve them if we are in the leftmost logical plan tree branch and + * unless they conflict. Conflicting [[Alias]] IDs are never acceptable. Otherwise, reallocate + * with a new ID and return that instance. + * + * For [[AttributeReference]]s: If the attribute is present in the current [[mappingStack]] entry, + * return that instance, otherwise reallocate with a new ID and return that instance. 
The mapping + * is done both from the original expression ID _and_ from the new expression ID - this way we are + * able to replace old references to that attribute in the current operator branch, and preserve + * already reallocated attributes to make this call idempotent. + * + * When remapping the provided expressions, we don't replace them with the previously seen + * attributes, but replace their IDs ([[NamedExpression.withExprId]]). This is done to preserve + * the properties of attributes at a certain point in the query plan. Examples where it's + * important: + * + * 1) Preserve the name case. In Spark the "requested" name takes precedence over the "original" + * name: + * + * {{{ + * -- The output schema is [col1, COL1] + * SELECT col1, COL1 FROM VALUES (1); + * }}} + * + * 2) Preserve the metadata: + * + * {{{ + * // Metadata "m1" remains, "m2" gets overwritten by the specified schema, "m3" is newly added. + * val metadata1 = new MetadataBuilder().putString("m1", "1").putString("m2", "2").build() + * val metadata2 = new MetadataBuilder().putString("m2", "3").putString("m3", "4").build() + * val schema = new StructType().add("a", IntegerType, nullable = true, metadata = metadata2) + * val df = + * spark.sql("SELECT col1 FROM VALUES (1)").select(col("col1").as("a", metadata1)).to(schema) + * }}} + */ + def mapExpression(originalExpression: NamedExpression): NamedExpression = { + if (!mappingStack.peek().mapping.isDefined) { + throw SparkException.internalError( + "Expression ID mapping doesn't exist. Please call createMapping(...) first. " + + s"Original expression: $originalExpression" + ) + } + + val currentMapping = mappingStack.peek().mapping.get + + val resultExpression = originalExpression match { + case alias: Alias if isLeftmostBranch => + val resultAlias = currentMapping.get(alias.exprId) match { + case null => + alias + case _ => + alias.newInstance() + } + currentMapping.put(resultAlias.exprId, resultAlias.exprId) + resultAlias + case alias: Alias => + reassignExpressionId(alias, currentMapping) + case attributeReference: AttributeReference => + currentMapping.get(attributeReference.exprId) match { + case null => + reassignExpressionId(attributeReference, currentMapping) + case mappedExpressionId => + attributeReference.withExprId(mappedExpressionId) + } + case _ => + throw QueryCompilationErrors.unsupportedSinglePassAnalyzerFeature( + s"${originalExpression.getClass} expression ID assignment" + ) + } + + resultExpression.copyTagsFrom(originalExpression) + resultExpression + } + + private def reassignExpressionId( + originalExpression: NamedExpression, + currentMapping: ExpressionIdAssigner.Mapping): NamedExpression = { + val newExpression = originalExpression.newInstance() + + currentMapping.put(originalExpression.exprId, newExpression.exprId) + currentMapping.put(newExpression.exprId, newExpression.exprId) + + newExpression + } +} + +object ExpressionIdAssigner { + type Mapping = HashMap[ExprId, ExprId] + + case class StackEntry(mapping: Option[Mapping] = None, isLeftmostBranch: Boolean = false) + + type Stack = ArrayDeque[StackEntry] + + /** + * Assert that `outputs` don't have conflicting expression IDs. This is only relevant for child + * outputs of multi-child operators. Conflicting attributes are only checked between different + * child branches, since one branch may output the same attribute multiple times. Hence, we use + * only distinct expression IDs from each output. 
+ * + * {{{ + * -- This is OK, one operator branch outputs its attribute multiple times + * SELECT col1, col1 FROM t1; + * }}} + * + * {{{ + * -- If both children of this [[Union]] operator output `col1` with the same expression ID, + * -- the analyzer is broken. + * SELECT col1 FROM t1 + * UNION ALL + * SELECT col1 FROM t1 + * ; + * }}} + */ + def assertOutputsHaveNoConflictingExpressionIds(outputs: Seq[Seq[Attribute]]): Unit = { + if (doOutputsHaveConflictingExpressionIds(outputs)) { + throw SparkException.internalError(s"Conflicting expression IDs in child outputs: $outputs") + } + } + + private def doOutputsHaveConflictingExpressionIds(outputs: Seq[Seq[Attribute]]): Boolean = { + outputs.length > 1 && { + val expressionIds = new HashSet[ExprId] + + outputs.exists { output => + val outputExpressionIds = new HashSet[ExprId] + + val hasConflicting = output.exists { attribute => + outputExpressionIds.add(attribute.exprId) + expressionIds.contains(attribute.exprId) + } + + if (!hasConflicting) { + expressionIds.addAll(outputExpressionIds) + } + + hasConflicting + } + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ExpressionResolutionContext.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ExpressionResolutionContext.scala new file mode 100644 index 0000000000000..744bae10dbd06 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ExpressionResolutionContext.scala @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis.resolver + +/** + * The [[ExpressionResolutionContext]] is a state that is propagated between the nodes of the + * expression tree during the bottom-up expression resolution process. This way we pass the results + * of [[ExpressionResolver.resolve]] call, which are not the resolved child itself, from children + * to parents. + * + * @hasAggregateExpressionsInASubtree A flag that highlights that a specific node corresponding to + * [[ExpressionResolutionContext]] has aggregate expressions in + * its subtree. + * @hasAttributeInASubtree A flag that highlights that a specific node corresponding to + * [[ExpressionResolutionContext]] has attributes in its subtree. + * @hasLateralColumnAlias A flag that highlights that a specific node corresponding to + * [[ExpressionResolutionContext]] has LCA in its subtree. 
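+ *
+ * A minimal sketch of how a parent context absorbs a child context during bottom-up resolution
+ * (illustrative only; the actual propagation is driven by the
+ * [[ExpressionResolver.expressionResolutionContextStack]]):
+ *
+ * {{{
+ *   val parentContext = new ExpressionResolutionContext()
+ *   val childContext = new ExpressionResolutionContext(hasAttributeInASubtree = true)
+ *   parentContext.merge(childContext)
+ *   // parentContext.hasAttributeInASubtree is now true; the other flags stay false.
+ * }}}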
+ */ +class ExpressionResolutionContext( + var hasAggregateExpressionsInASubtree: Boolean = false, + var hasAttributeInASubtree: Boolean = false, + var hasLateralColumnAlias: Boolean = false) { + def merge(other: ExpressionResolutionContext): Unit = { + hasAggregateExpressionsInASubtree |= other.hasAggregateExpressionsInASubtree + hasAttributeInASubtree |= other.hasAttributeInASubtree + hasLateralColumnAlias |= other.hasLateralColumnAlias + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ExpressionResolutionValidator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ExpressionResolutionValidator.scala index 8c80992e2fa2c..3ca62348e892e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ExpressionResolutionValidator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ExpressionResolutionValidator.scala @@ -19,37 +19,14 @@ package org.apache.spark.sql.catalyst.analysis.resolver import org.apache.spark.sql.catalyst.expressions.{ Alias, - ArrayDistinct, - ArrayInsert, - ArrayJoin, - ArrayMax, - ArrayMin, ArraysZip, AttributeReference, BinaryExpression, - ConditionalExpression, - CreateArray, - CreateMap, - CreateNamedStruct, Expression, - ExtractANSIIntervalDays, - GetArrayStructFields, - GetMapValue, - GetStructField, Literal, - MapConcat, - MapContainsKey, - MapEntries, - MapFromEntries, - MapKeys, - MapValues, NamedExpression, Predicate, - RuntimeReplaceable, - StringRPad, - StringToMap, - TimeZoneAwareExpression, - UnaryMinus + TimeZoneAwareExpression } import org.apache.spark.sql.types.BooleanType @@ -70,62 +47,18 @@ class ExpressionResolutionValidator(resolutionValidator: ResolutionValidator) { validateAttributeReference(attributeReference) case alias: Alias => validateAlias(alias) - case getMapValue: GetMapValue => - validateGetMapValue(getMapValue) case binaryExpression: BinaryExpression => validateBinaryExpression(binaryExpression) - case extractANSIIntervalDay: ExtractANSIIntervalDays => - validateExtractANSIIntervalDays(extractANSIIntervalDay) case literal: Literal => validateLiteral(literal) case predicate: Predicate => validatePredicate(predicate) - case stringRPad: StringRPad => - validateStringRPad(stringRPad) - case unaryMinus: UnaryMinus => - validateUnaryMinus(unaryMinus) - case getStructField: GetStructField => - validateGetStructField(getStructField) - case createNamedStruct: CreateNamedStruct => - validateCreateNamedStruct(createNamedStruct) - case getArrayStructFields: GetArrayStructFields => - validateGetArrayStructFields(getArrayStructFields) - case createMap: CreateMap => - validateCreateMap(createMap) - case stringToMap: StringToMap => - validateStringToMap(stringToMap) - case mapContainsKey: MapContainsKey => - validateMapContainsKey(mapContainsKey) - case mapConcat: MapConcat => - validateMapConcat(mapConcat) - case mapKeys: MapKeys => - validateMapKeys(mapKeys) - case mapValues: MapValues => - validateMapValues(mapValues) - case mapEntries: MapEntries => - validateMapEntries(mapEntries) - case mapFromEntries: MapFromEntries => - validateMapFromEntries(mapFromEntries) - case createArray: CreateArray => - validateCreateArray(createArray) - case arrayDistinct: ArrayDistinct => - validateArrayDistinct(arrayDistinct) - case arrayInsert: ArrayInsert => - validateArrayInsert(arrayInsert) - case arrayJoin: ArrayJoin => - validateArrayJoin(arrayJoin) - case arrayMax: ArrayMax => - validateArrayMax(arrayMax) - case 
arrayMin: ArrayMin => - validateArrayMin(arrayMin) case arraysZip: ArraysZip => validateArraysZip(arraysZip) - case conditionalExpression: ConditionalExpression => - validateConditionalExpression(conditionalExpression) - case runtimeReplaceable: RuntimeReplaceable => - validateRuntimeReplaceable(runtimeReplaceable) case timezoneExpression: TimeZoneAwareExpression => validateTimezoneExpression(timezoneExpression) + case expression: Expression => + validateExpression(expression) } } @@ -144,22 +77,7 @@ class ExpressionResolutionValidator(resolutionValidator: ResolutionValidator) { predicate.dataType == BooleanType, s"Output type of a predicate must be a boolean, but got: ${predicate.dataType.typeName}" ) - assert( - predicate.checkInputDataTypes().isSuccess, - "Input types of a predicate must be valid, but got: " + - predicate.children.map(_.dataType.typeName).mkString(", ") - ) - } - - private def validateStringRPad(stringRPad: StringRPad) = { - validate(stringRPad.first) - validate(stringRPad.second) - validate(stringRPad.third) - assert( - stringRPad.checkInputDataTypes().isSuccess, - "Input types of rpad must be valid, but got: " + - stringRPad.children.map(_.dataType.typeName).mkString(", ") - ) + validateInputDataTypes(predicate) } private def validateAttributeReference(attributeReference: AttributeReference): Unit = { @@ -177,11 +95,7 @@ class ExpressionResolutionValidator(resolutionValidator: ResolutionValidator) { private def validateBinaryExpression(binaryExpression: BinaryExpression): Unit = { validate(binaryExpression.left) validate(binaryExpression.right) - assert( - binaryExpression.checkInputDataTypes().isSuccess, - "Input types of a binary expression must be valid, but got: " + - binaryExpression.children.map(_.dataType.typeName).mkString(", ") - ) + validateInputDataTypes(binaryExpression) binaryExpression match { case timezoneExpression: TimeZoneAwareExpression => @@ -190,178 +104,29 @@ class ExpressionResolutionValidator(resolutionValidator: ResolutionValidator) { } } - private def validateConditionalExpression(conditionalExpression: ConditionalExpression): Unit = - conditionalExpression.children.foreach(validate) - - private def validateExtractANSIIntervalDays( - extractANSIIntervalDays: ExtractANSIIntervalDays): Unit = { - validate(extractANSIIntervalDays.child) - } - private def validateLiteral(literal: Literal): Unit = {} - private def validateUnaryMinus(unaryMinus: UnaryMinus): Unit = { - validate(unaryMinus.child) - assert( - unaryMinus.checkInputDataTypes().isSuccess, - "Input types of a unary minus must be valid, but got: " + - unaryMinus.child.dataType.typeName.mkString(", ") - ) - } - - private def validateGetStructField(getStructField: GetStructField): Unit = { - validate(getStructField.child) - } - - private def validateCreateNamedStruct(createNamedStruct: CreateNamedStruct): Unit = { - createNamedStruct.children.foreach(validate) - assert( - createNamedStruct.checkInputDataTypes().isSuccess, - "Input types of CreateNamedStruct must be valid, but got: " + - createNamedStruct.children.map(_.dataType.typeName).mkString(", ") - ) - } - - private def validateGetArrayStructFields(getArrayStructFields: GetArrayStructFields): Unit = { - validate(getArrayStructFields.child) - } - - private def validateGetMapValue(getMapValue: GetMapValue): Unit = { - validate(getMapValue.child) - validate(getMapValue.key) - assert( - getMapValue.checkInputDataTypes().isSuccess, - "Input types of GetMapValue must be valid, but got: " + - 
getMapValue.children.map(_.dataType.typeName).mkString(", ") - ) - } - - private def validateCreateMap(createMap: CreateMap): Unit = { - createMap.children.foreach(validate) - assert( - createMap.checkInputDataTypes().isSuccess, - "Input types of CreateMap must be valid, but got: " + - createMap.children.map(_.dataType.typeName).mkString(", ") - ) - } - - private def validateStringToMap(stringToMap: StringToMap): Unit = { - validate(stringToMap.text) - validate(stringToMap.pairDelim) - validate(stringToMap.keyValueDelim) - } - - private def validateMapContainsKey(mapContainsKey: MapContainsKey): Unit = { - validate(mapContainsKey.left) - validate(mapContainsKey.right) - assert( - mapContainsKey.checkInputDataTypes().isSuccess, - "Input types of MapContainsKey must be valid, but got: " + - mapContainsKey.children.map(_.dataType.typeName).mkString(", ") - ) - } - - private def validateMapConcat(mapConcat: MapConcat): Unit = { - mapConcat.children.foreach(validate) - assert( - mapConcat.checkInputDataTypes().isSuccess, - "Input types of MapConcat must be valid, but got: " + - mapConcat.children.map(_.dataType.typeName).mkString(", ") - ) - } - - private def validateMapKeys(mapKeys: MapKeys): Unit = { - validate(mapKeys.child) - } - - private def validateMapValues(mapValues: MapValues): Unit = { - validate(mapValues.child) - } - - private def validateMapEntries(mapEntries: MapEntries): Unit = { - validate(mapEntries.child) - } - - private def validateMapFromEntries(mapFromEntries: MapFromEntries): Unit = { - mapFromEntries.children.foreach(validate) - assert( - mapFromEntries.checkInputDataTypes().isSuccess, - "Input types of MapFromEntries must be valid, but got: " + - mapFromEntries.children.map(_.dataType.typeName).mkString(", ") - ) - } - - private def validateCreateArray(createArray: CreateArray): Unit = { - createArray.children.foreach(validate) - assert( - createArray.checkInputDataTypes().isSuccess, - "Input types of CreateArray must be valid, but got: " + - createArray.children.map(_.dataType.typeName).mkString(", ") - ) - } - - private def validateArrayDistinct(arrayDistinct: ArrayDistinct): Unit = { - validate(arrayDistinct.child) - assert( - arrayDistinct.checkInputDataTypes().isSuccess, - "Input types of ArrayDistinct must be valid, but got: " + - arrayDistinct.children.map(_.dataType.typeName).mkString(", ") - ) - } - - private def validateArrayInsert(arrayInsert: ArrayInsert): Unit = { - validate(arrayInsert.srcArrayExpr) - validate(arrayInsert.posExpr) - validate(arrayInsert.itemExpr) - assert( - arrayInsert.checkInputDataTypes().isSuccess, - "Input types of ArrayInsert must be valid, but got: " + - arrayInsert.children.map(_.dataType.typeName).mkString(", ") - ) - } - - private def validateArrayJoin(arrayJoin: ArrayJoin): Unit = { - validate(arrayJoin.array) - validate(arrayJoin.delimiter) - if (arrayJoin.nullReplacement.isDefined) { - validate(arrayJoin.nullReplacement.get) - } - } - - private def validateArrayMax(arrayMax: ArrayMax): Unit = { - validate(arrayMax.child) - assert( - arrayMax.checkInputDataTypes().isSuccess, - "Input types of ArrayMax must be valid, but got: " + - arrayMax.children.map(_.dataType.typeName).mkString(", ") - ) - } - - private def validateArrayMin(arrayMin: ArrayMin): Unit = { - validate(arrayMin.child) - assert( - arrayMin.checkInputDataTypes().isSuccess, - "Input types of ArrayMin must be valid, but got: " + - arrayMin.children.map(_.dataType.typeName).mkString(", ") - ) - } - private def validateArraysZip(arraysZip: ArraysZip): Unit = { 
arraysZip.children.foreach(validate) arraysZip.names.foreach(validate) - assert( - arraysZip.checkInputDataTypes().isSuccess, - "Input types of ArraysZip must be valid, but got: " + - arraysZip.children.map(_.dataType.typeName).mkString(", ") - ) - } - - private def validateRuntimeReplaceable(runtimeReplaceable: RuntimeReplaceable): Unit = { - runtimeReplaceable.children.foreach(validate) + validateInputDataTypes(arraysZip) } private def validateTimezoneExpression(timezoneExpression: TimeZoneAwareExpression): Unit = { timezoneExpression.children.foreach(validate) assert(timezoneExpression.timeZoneId.nonEmpty, "Timezone expression must have a timezone") } + + private def validateExpression(expression: Expression): Unit = { + expression.children.foreach(validate) + validateInputDataTypes(expression) + } + + private def validateInputDataTypes(expression: Expression): Unit = { + assert( + expression.checkInputDataTypes().isSuccess, + s"Input types of ${expression.getClass.getName} must be valid, but got: " + + expression.children.map(_.dataType.typeName).mkString(", ") + ) + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ExpressionResolver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ExpressionResolver.scala index 1d072509626b7..1a09b5b73be6d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ExpressionResolver.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ExpressionResolver.scala @@ -17,9 +17,12 @@ package org.apache.spark.sql.catalyst.analysis.resolver +import java.util.ArrayDeque + import org.apache.spark.sql.catalyst.analysis.{ withPosition, FunctionResolution, + GetViewColumnByNameAndOrdinal, UnresolvedAlias, UnresolvedAttribute, UnresolvedFunction, @@ -31,17 +34,21 @@ import org.apache.spark.sql.catalyst.expressions.{ BinaryArithmetic, ConditionalExpression, CreateNamedStruct, + DateAddYMInterval, Expression, - ExtractANSIIntervalDays, - InheritAnalysisRules, + ExtractIntervalPart, + GetTimeField, Literal, + MakeTimestamp, NamedExpression, Predicate, RuntimeReplaceable, TimeAdd, TimeZoneAwareExpression, - UnaryMinus + UnaryMinus, + UnresolvedNamedLambdaVariable } +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project} import org.apache.spark.sql.catalyst.trees.TreeNodeTag import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.SQLConf @@ -73,12 +80,10 @@ class ExpressionResolver( planLogger: PlanLogger) extends TreeNodeResolver[Expression, Expression] with ProducesUnresolvedSubtree - with ResolvesExpressionChildren - with TracksResolvedNodes[Expression] { - private val shouldTrackResolvedNodes = - conf.getConf(SQLConf.ANALYZER_SINGLE_PASS_TRACK_RESOLVED_NODES_ENABLED) + with ResolvesExpressionChildren { + private val isLcaEnabled = conf.getConf(SQLConf.LATERAL_COLUMN_ALIAS_IMPLICIT_ENABLED) + private val aliasResolver = new AliasResolver(this, scopes) - private val createNamedStructResolver = new CreateNamedStructResolver(this) private val timezoneAwareExpressionResolver = new TimezoneAwareExpressionResolver(this) private val conditionalExpressionResolver = new ConditionalExpressionResolver(this, timezoneAwareExpressionResolver) @@ -90,13 +95,62 @@ class ExpressionResolver( timezoneAwareExpressionResolver ) } + private val limitExpressionResolver = new LimitExpressionResolver + private val typeCoercionResolver = new TypeCoercionResolver(timezoneAwareExpressionResolver) 
+ + private val expressionResolutionContextStack = new ArrayDeque[ExpressionResolutionContext] + + private val aggregateExpressionResolver = + new AggregateExpressionResolver(this, timezoneAwareExpressionResolver) private val functionResolver = new FunctionResolver( this, timezoneAwareExpressionResolver, - functionResolution + functionResolution, + aggregateExpressionResolver, + binaryArithmeticResolver ) private val timeAddResolver = new TimeAddResolver(this, timezoneAwareExpressionResolver) private val unaryMinusResolver = new UnaryMinusResolver(this, timezoneAwareExpressionResolver) + private var isTopOfProjectList: Boolean = false + + /** + * The stack of parent operators which were encountered during the resolution of a certain + * expression tree. This is filled by the [[resolveExpressionTreeInOperatorImpl]], and will + * usually have size 1 (the parent for this expression tree). However, in case of subquery + * expressions we would call [[resolveExpressionTreeInOperatorImpl]] several times recursively + * for each expression tree in the operator tree -> expression tree -> operator tree -> + * expression tree -> ... chain. Consider this example: + * + * {{{ SELECT (SELECT col1 FROM values(1) LIMIT 1) FROM VALUES(1); }}} + * + * Would have the following analyzed tree: + * + * Project [...] + * : +- GlobalLimit 1 + * : +- LocalLimit 1 + * : +- Project [col1] + * : +- LocalRelation [col1] + * +- LocalRelation [col1] + * + * The stack would contain the following operators during the resolution of the nested + * operator/expression trees: + * + * Project -> Project + */ + private val parentOperators = new ArrayDeque[LogicalPlan] + private val expressionIdAssigner = new ExpressionIdAssigner + + /** + * Resolve `unresolvedExpression` which is a child of `parentOperator`. This is the main entry + * point into the [[ExpressionResolver]] for operators. + */ + def resolveExpressionTreeInOperator( + unresolvedExpression: Expression, + parentOperator: LogicalPlan): Expression = { + val (resolvedExpression, _) = + resolveExpressionTreeInOperatorImpl(unresolvedExpression, parentOperator) + resolvedExpression + } /** * This method is an expression analysis entry point. 
The method first checks if the expression @@ -123,18 +177,18 @@ class ExpressionResolver( override def resolve(unresolvedExpression: Expression): Expression = { planLogger.logExpressionTreeResolutionEvent(unresolvedExpression, "Unresolved expression tree") - if (unresolvedExpression - .getTagValue(ExpressionResolver.SINGLE_PASS_SUBTREE_BOUNDARY) - .nonEmpty) { + if (tryPopSinglePassSubtreeBoundary(unresolvedExpression)) { unresolvedExpression } else { - throwIfNodeWasResolvedEarlier(unresolvedExpression) + pushResolutionContext() - val resolvedExpression = unresolvedExpression match { + var resolvedExpression = unresolvedExpression match { case unresolvedBinaryArithmetic: BinaryArithmetic => binaryArithmeticResolver.resolve(unresolvedBinaryArithmetic) - case unresolvedExtractANSIIntervalDays: ExtractANSIIntervalDays => - resolveExtractANSIIntervalDays(unresolvedExtractANSIIntervalDays) + case unresolvedDateAddYMInterval: DateAddYMInterval => + resolveExpressionGenerically(unresolvedDateAddYMInterval) + case extractIntervalPart: ExtractIntervalPart[_] => + resolveExpressionGenerically(extractIntervalPart) case unresolvedNamedExpression: NamedExpression => resolveNamedExpression(unresolvedNamedExpression) case unresolvedFunction: UnresolvedFunction => @@ -148,43 +202,92 @@ class ExpressionResolver( case unresolvedUnaryMinus: UnaryMinus => unaryMinusResolver.resolve(unresolvedUnaryMinus) case createNamedStruct: CreateNamedStruct => - createNamedStructResolver.resolve(createNamedStruct) + resolveExpressionGenerically(createNamedStruct) case unresolvedConditionalExpression: ConditionalExpression => conditionalExpressionResolver.resolve(unresolvedConditionalExpression) + case getViewColumnByNameAndOrdinal: GetViewColumnByNameAndOrdinal => + resolveGetViewColumnByNameAndOrdinal(getViewColumnByNameAndOrdinal) + case getTimeField: GetTimeField => + resolveExpressionGenericallyWithTimezoneWithTypeCoercion(getTimeField) + case makeTimestamp: MakeTimestamp => + resolveExpressionGenericallyWithTimezoneWithTypeCoercion(makeTimestamp) case unresolvedRuntimeReplaceable: RuntimeReplaceable => - resolveRuntimeReplaceable(unresolvedRuntimeReplaceable) + resolveExpressionGenericallyWithTypeCoercion(unresolvedRuntimeReplaceable) case unresolvedTimezoneExpression: TimeZoneAwareExpression => timezoneAwareExpressionResolver.resolve(unresolvedTimezoneExpression) - case _ => - withPosition(unresolvedExpression) { - throwUnsupportedSinglePassAnalyzerFeature(unresolvedExpression) - } + case expression: Expression => + resolveExpressionGenericallyWithTypeCoercion(expression) } - markNodeAsResolved(resolvedExpression) + popResolutionContext() planLogger.logExpressionTreeResolution(unresolvedExpression, resolvedExpression) - resolvedExpression } } - private def resolveNamedExpression( - unresolvedNamedExpression: Expression, - isTopOfProjectList: Boolean = false): Expression = + /** + * Get the expression resolution context stack. + */ + def getExpressionResolutionContextStack: ArrayDeque[ExpressionResolutionContext] = { + expressionResolutionContextStack + } + + def getExpressionIdAssigner: ExpressionIdAssigner = expressionIdAssigner + + /** + * Get the most recent operator (bottommost) from the `parentOperators` stack. + */ + def getParentOperator: Option[LogicalPlan] = { + if (parentOperators.size() > 0) { + Some(parentOperators.peek()) + } else { + None + } + } + + /** + * Resolve the limit expression from either a [[LocalLimit]] or a [[GlobalLimit]] operator. 
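To make the flag propagation concrete, here is a minimal driver sketch (hypothetical code, not the actual resolver internals) of the push/merge/pop discipline applied to [[ExpressionResolutionContext]]s on the stack; the context class is the real one from this patch, the driver is made up:

{{{
import java.util.ArrayDeque

val stack = new ArrayDeque[ExpressionResolutionContext]
stack.push(new ExpressionResolutionContext) // context of the parent expression

stack.push(new ExpressionResolutionContext) // context of a child, e.g. `count(col1)`
stack.peek().hasAggregateExpressionsInASubtree = true

// On the way out of the child, its context is popped and merged into the parent's,
// which is how flags like hasAggregateExpressionsInASubtree bubble up the tree.
val childContext = stack.pop()
stack.peek().merge(childContext)
assert(stack.peek().hasAggregateExpressionsInASubtree)
}}}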
+ */ + def resolveLimitExpression( + unresolvedLimitExpr: Expression, + unresolvedLimit: LogicalPlan): Expression = { + val resolvedLimitExpr = resolveExpressionTreeInOperator( + unresolvedLimitExpr, + unresolvedLimit + ) + limitExpressionResolver.resolve(resolvedLimitExpr) + } + + private def resolveExpressionTreeInOperatorImpl( + unresolvedExpression: Expression, + parentOperator: LogicalPlan): (Expression, ExpressionResolutionContext) = { + this.parentOperators.push(parentOperator) + expressionResolutionContextStack.push(new ExpressionResolutionContext) + try { + val resolvedExpression = resolve(unresolvedExpression) + (resolvedExpression, expressionResolutionContextStack.peek()) + } finally { + expressionResolutionContextStack.pop() + this.parentOperators.pop() + } + } + + private def resolveNamedExpression(unresolvedNamedExpression: Expression): Expression = unresolvedNamedExpression match { case alias: Alias => aliasResolver.handleResolvedAlias(alias) case unresolvedAlias: UnresolvedAlias => aliasResolver.resolve(unresolvedAlias) case unresolvedAttribute: UnresolvedAttribute => - resolveAttribute(unresolvedAttribute, isTopOfProjectList) + resolveAttribute(unresolvedAttribute) case unresolvedStar: UnresolvedStar => - withPosition(unresolvedStar) { - throwInvalidStarUsageError(unresolvedStar) - } + // We don't support edge cases of star usage, e.g. `WHERE col1 IN (*)` + throw new ExplicitlyUnsupportedResolverFeature("Star outside of Project list") case attributeReference: AttributeReference => handleResolvedAttributeReference(attributeReference) + case _: UnresolvedNamedLambdaVariable => + throw new ExplicitlyUnsupportedResolverFeature("Lambda variables") case _ => withPosition(unresolvedNamedExpression) { throwUnsupportedSinglePassAnalyzerFeature(unresolvedNamedExpression) @@ -199,14 +302,60 @@ class ExpressionResolver( * from the [[Resolver]] during [[Project]] resolution. * * The output sequence can be larger than the input sequence due to [[UnresolvedStar]] expansion. + * + * @returns The list of resolved expressions along with flags indicating whether the resolved + * project list contains aggregate expressions or attributes (encapsulated in + * [[ResolvedProjectList]]) which are used during the further resolution of the tree. + * + * The following query: + * + * {{{ SELECT COUNT(col1), 2 FROM VALUES(1); }}} + * + * would have a project list with two expressions: `COUNT(col1)` and `2`. After the resolution it + * would return the following result: + * ResolvedProjectList( + * expressions = [count(col1) as count(col1), 2 AS 2], + * hasAggregateExpressions = true, // because it contains `count(col1)` in the project list + * hasAttributes = false // because it doesn't contain any [[AttributeReference]]s in the + * // project list (only under the aggregate expression, please check + * // [[AggregateExpressionResolver]] for more details). 
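A hypothetical consumption sketch for the returned [[ResolvedProjectList]] (illustrative only, not the actual [[ProjectResolver]] code; assumes it lives in the same package so [[ResolvedProjectList]] is in scope):

{{{
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project}

// A caller can use the flags to decide whether the operator has to be rewritten:
// SELECT COUNT(col1), 2 FROM VALUES (1) must become an Aggregate without grouping keys.
def toResolvedOperator(
    resolvedProjectList: ResolvedProjectList,
    resolvedChild: LogicalPlan): LogicalPlan = {
  if (resolvedProjectList.hasAggregateExpressions) {
    Aggregate(Seq.empty, resolvedProjectList.expressions, resolvedChild)
  } else {
    Project(resolvedProjectList.expressions, resolvedChild)
  }
}
}}}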
*/ - def resolveProjectList(unresolvedProjectList: Seq[NamedExpression]): Seq[NamedExpression] = { - unresolvedProjectList.flatMap { + def resolveProjectList( + unresolvedProjectList: Seq[NamedExpression], + operator: LogicalPlan): ResolvedProjectList = { + val projectListResolutionContext = new ExpressionResolutionContext + val resolvedProjectList = unresolvedProjectList.flatMap { case unresolvedStar: UnresolvedStar => resolveStar(unresolvedStar) case other => - Seq(resolveNamedExpression(other, isTopOfProjectList = true).asInstanceOf[NamedExpression]) + val (resolvedElement, resolvedElementContext) = + resolveExpressionTreeInOperatorImpl(other, operator) + projectListResolutionContext.merge(resolvedElementContext) + Seq(resolvedElement.asInstanceOf[NamedExpression]) } + ResolvedProjectList( + expressions = resolvedProjectList, + hasAggregateExpressions = projectListResolutionContext.hasAggregateExpressionsInASubtree, + hasAttributes = projectListResolutionContext.hasAttributeInASubtree, + hasLateralColumnAlias = projectListResolutionContext.hasLateralColumnAlias + ) + } + + /** + * Resolves [[Expression]] only by resolving its children. This resolution method is used for + * nodes that don't require any special resolution other than resolving its children. + */ + def resolveExpressionGenerically(expression: Expression): Expression = + withResolvedChildren(expression, resolve) + + /** + * Resolves [[Expression]] by resolving its children and applying generic type coercion + * transformations to the resulting expression. This resolution method is used for nodes that + * require type coercion on top of [[resolveExpressionGenerically]]. + */ + def resolveExpressionGenericallyWithTypeCoercion(expression: Expression): Expression = { + val expressionWithResolvedChildren = withResolvedChildren(expression, resolve) + typeCoercionResolver.resolve(expressionWithResolvedChildren) } /** @@ -223,42 +372,60 @@ class ExpressionResolver( * - Single result from the [[NameScope]] means that the attribute was found as in: * {{{ SELECT col1 FROM VALUES (1); }}} * + * If [[NameTarget.lateralAttributeReference]] is defined, it means that we are resolving an + * attribute that is a lateral column alias reference. In that case we mark the referenced + * attribute as referenced and tag the LCA attribute for further [[Alias]] resolution. + * * If the attribute is at the top of the project list (which is indicated by * [[isTopOfProjectList]]), we preserve the [[Alias]] or remove it otherwise. */ - private def resolveAttribute( - unresolvedAttribute: UnresolvedAttribute, - isTopOfProjectList: Boolean): Expression = + private def resolveAttribute(unresolvedAttribute: UnresolvedAttribute): Expression = withPosition(unresolvedAttribute) { - if (scopes.top.isExistingAlias(unresolvedAttribute.nameParts.head)) { - // Temporarily disable referencing aliases until we support LCA resolution. 
- throw new ExplicitlyUnsupportedResolverFeature("unsupported expression: LateralColumnAlias") - } + expressionResolutionContextStack.peek().hasAttributeInASubtree = true - val nameTarget: NameTarget = scopes.top.matchMultipartName(unresolvedAttribute.nameParts) + val nameTarget: NameTarget = + scopes.top.resolveMultipartName(unresolvedAttribute.nameParts, isLcaEnabled) val candidate = nameTarget.pickCandidate(unresolvedAttribute) - if (isTopOfProjectList && nameTarget.aliasName.isDefined) { + + if (isLcaEnabled) { + nameTarget.lateralAttributeReference match { + case Some(lateralAttributeReference) => + scopes.top.lcaRegistry + .markAttributeLaterallyReferenced(lateralAttributeReference) + candidate.setTagValue(ExpressionResolver.SINGLE_PASS_IS_LCA, ()) + expressionResolutionContextStack.peek().hasLateralColumnAlias = true + case None => + } + } + + val resolvedAttribute = if (isTopOfProjectList && nameTarget.aliasName.isDefined) { Alias(candidate, nameTarget.aliasName.get)() } else { candidate } + + resolvedAttribute match { + case namedExpression: NamedExpression => + expressionIdAssigner.mapExpression(namedExpression) + case other => other + } } /** * [[AttributeReference]] is already resolved if it's passed to us from DataFrame `col(...)` * function, for example. */ - private def handleResolvedAttributeReference(attributeReference: AttributeReference) = - tryStripAmbiguousSelfJoinMetadata(attributeReference) + private def handleResolvedAttributeReference(attributeReference: AttributeReference) = { + val strippedAttributeReference = tryStripAmbiguousSelfJoinMetadata(attributeReference) + val resultAttribute = expressionIdAssigner.mapExpression(strippedAttributeReference) - /** - * [[ExtractANSIIntervalDays]] resolution doesn't require any specific resolution logic apart - * from resolving its children. - */ - private def resolveExtractANSIIntervalDays( - unresolvedExtractANSIIntervalDays: ExtractANSIIntervalDays) = - withResolvedChildren(unresolvedExtractANSIIntervalDays, resolve) + if (!scopes.top.hasAttributeWithId(resultAttribute.exprId)) { + throw new ExplicitlyUnsupportedResolverFeature("DataFrame missing attribute propagation") + } + + resultAttribute + } /** * [[UnresolvedStar]] resolution relies on the [[NameScope]]'s ability to get the attributes by a @@ -288,30 +455,117 @@ class ExpressionResolver( /** * [[Literal]] resolution doesn't require any specific resolution logic at this point. + */ + private def resolveLiteral(literal: Literal): Expression = literal + + /** + * The [[GetViewColumnByNameAndOrdinal]] is a special internal expression that is placed by the + * [[SessionCatalog]] in the top [[Project]] operator of the freshly reconstructed unresolved + * view plan. Since the view schema is fixed and persisted in the catalog, we have to extract + * the right attributes from the view plan regardless of the underlying table schema changes. + * [[GetViewColumnByNameAndOrdinal]] contains attribute name and it's ordinal to perform the + * necessary matching. If the matching was not successful, or the number of matched candidates + * differs from the recorded one, we throw an error. 
+ * + * Example of the correct name matching: + * + * {{{ + * CREATE TABLE underlying (col1 INT, col2 STRING); + * CREATE VIEW all_columns AS SELECT * FROM underlying; + * + * -- View plan for the SELECT below will contain a Project node on top with the following + * -- expressions: + * -- getviewcolumnbynameandordinal(`spark_catalog`.`default`.`all_columns`, col1, 0, 1) + * -- getviewcolumnbynameandordinal(`spark_catalog`.`default`.`all_columns`, col2, 0, 1) + * SELECT * FROM all_columns; + * + * ALTER TABLE underlying DROP COLUMN col2; + * ALTER TABLE underlying ADD COLUMN col3 STRING; + * ALTER TABLE underlying ADD COLUMN col2 STRING; * - * Since [[TracksResolvedNodes]] requires all the expressions in the tree to be unique objects, - * we reallocate the literal in [[ANALYZER_SINGLE_PASS_TRACK_RESOLVED_NODES_ENABLED]] mode, - * otherwise we preserve the old object to avoid unnecessary memory allocations. + * -- The output schema for the SELECT below is [col1, col3, col2] + * SELECT * FROM underlying; + * + * -- The output schema for the SELECT below is [col1, col2], because the view schema is fixed. + * -- GetViewColumnByNameAndOrdinal allows us to perform this operation by matching attribute + * -- names. + * SELECT * FROM all_columns; + * }}} + * + * Example of the correct ordinal matching: + * + * {{{ + * CREATE TABLE underlying1 (col1 INT, col2 STRING); + * CREATE TABLE underlying2 (col1 INT, col2 STRING); + * + * CREATE VIEW all_columns (c1, c2, c3, c4) AS SELECT * FROM underlying1, underlying2; + * + * ALTER TABLE underlying1 ADD COLUMN col3 STRING; + * + * -- The output schema for this query has changed to [col1, col2, col3, col1, col2]. + * -- Now we need GetViewColumnByNameAndOrdinal to glue it to the fixed view schema. + * SELECT * FROM underlying1, underlying2; + * + * -- GetViewColumnByNameAndOrdinal helps us to disambiguate the column names from different + * -- tables by matching the same attribute names from those tables by their ordinal in the + * -- Project list, which is dependant on the order of tables in the inner join operator from + * -- the view plan: + * -- getviewcolumnbynameandordinal(`spark_catalog`.`default`.`all_columns`, col1, 0, 2) + * -- getviewcolumnbynameandordinal(`spark_catalog`.`default`.`all_columns`, col2, 0, 2) + * -- getviewcolumnbynameandordinal(`spark_catalog`.`default`.`all_columns`, col1, 1, 2) + * -- getviewcolumnbynameandordinal(`spark_catalog`.`default`.`all_columns`, col2, 1, 2) + * SELECT * FROM all_columns; + * }}} */ - private def resolveLiteral(literal: Literal): Expression = { - if (shouldTrackResolvedNodes) { - literal.copy() - } else { - literal + private def resolveGetViewColumnByNameAndOrdinal( + getViewColumnByNameAndOrdinal: GetViewColumnByNameAndOrdinal): Expression = { + val candidates = scopes.top.findAttributesByName(getViewColumnByNameAndOrdinal.colName) + if (candidates.length != getViewColumnByNameAndOrdinal.expectedNumCandidates) { + throw QueryCompilationErrors.incompatibleViewSchemaChangeError( + getViewColumnByNameAndOrdinal.viewName, + getViewColumnByNameAndOrdinal.colName, + getViewColumnByNameAndOrdinal.expectedNumCandidates, + candidates, + getViewColumnByNameAndOrdinal.viewDDL + ) } + + candidates(getViewColumnByNameAndOrdinal.ordinal) } /** - * When [[RuntimeReplaceable]] is mixed in with [[InheritAnalysisRules]], child expression will - * be runtime replacement. In that case we need to resolve the children of the expression. - * otherwise, no resolution is necessary because replacement is already resolved. 
+ * Resolves [[Expression]] by calling [[timezoneAwareExpressionResolver]] to resolve + * expression's children and apply timezone if needed. Applies generic type coercion + * rules to the result. */ - private def resolveRuntimeReplaceable(unresolvedRuntimeReplaceable: RuntimeReplaceable) = - unresolvedRuntimeReplaceable match { - case inheritAnalysisRules: InheritAnalysisRules => - withResolvedChildren(inheritAnalysisRules, resolve) - case other => other + private def resolveExpressionGenericallyWithTimezoneWithTypeCoercion( + timezoneAwareExpression: TimeZoneAwareExpression): Expression = { + val expressionWithTimezone = timezoneAwareExpressionResolver.resolve(timezoneAwareExpression) + typeCoercionResolver.resolve(expressionWithTimezone) + } + + private def popResolutionContext(): Unit = { + val currentExpressionResolutionContext = expressionResolutionContextStack.pop() + expressionResolutionContextStack.peek().merge(currentExpressionResolutionContext) + } + + private def pushResolutionContext(): Unit = { + isTopOfProjectList = expressionResolutionContextStack + .size() == 1 && parentOperators.peek().isInstanceOf[Project] + + expressionResolutionContextStack.push(new ExpressionResolutionContext) + } + + private def tryPopSinglePassSubtreeBoundary(unresolvedExpression: Expression): Boolean = { + if (unresolvedExpression + .getTagValue(ExpressionResolver.SINGLE_PASS_SUBTREE_BOUNDARY) + .isDefined) { + unresolvedExpression.unsetTagValue(ExpressionResolver.SINGLE_PASS_SUBTREE_BOUNDARY) + true + } else { + false } + } /** * [[DetectAmbiguousSelfJoin]] rule in the fixed-point Analyzer detects ambiguous references in @@ -335,13 +589,10 @@ class ExpressionResolver( throw QueryCompilationErrors.unsupportedSinglePassAnalyzerFeature( s"${unresolvedExpression.getClass} expression resolution" ) - - private def throwInvalidStarUsageError(unresolvedStar: UnresolvedStar): Nothing = - // TODO(vladimirg-db): Use parent operator name instead of "query" - throw QueryCompilationErrors.invalidStarUsageError("query", Seq(unresolvedStar)) } object ExpressionResolver { private val AMBIGUOUS_SELF_JOIN_METADATA = Seq("__dataset_id", "__col_position") val SINGLE_PASS_SUBTREE_BOUNDARY = TreeNodeTag[Unit]("single_pass_subtree_boundary") + val SINGLE_PASS_IS_LCA = TreeNodeTag[Unit]("single_pass_is_lca") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/FunctionResolver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/FunctionResolver.scala index b7311b83e872e..b6648346ca837 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/FunctionResolver.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/FunctionResolver.scala @@ -17,15 +17,23 @@ package org.apache.spark.sql.catalyst.analysis.resolver +import java.util.Locale + +import org.apache.spark.sql.catalyst.FunctionIdentifier import org.apache.spark.sql.catalyst.analysis.{ - AnsiTypeCoercion, - CollationTypeCoercion, + FunctionRegistry, FunctionResolution, - TypeCoercion, + ResolvedStar, UnresolvedFunction, UnresolvedStar } -import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions.{ + BinaryArithmetic, + Expression, + InheritAnalysisRules, + Literal +} +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression /** * A resolver for [[UnresolvedFunction]]s that resolves functions to concrete [[Expression]]s. 
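A small self-contained illustration of the tag mechanics behind `tryPopSinglePassSubtreeBoundary` above (toy expression, not a real resolution flow):

{{{
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.trees.TreeNodeTag

val boundary = TreeNodeTag[Unit]("single_pass_subtree_boundary")
val expr = Literal(1)

// A sub-resolver that has already produced a fully resolved subtree marks it as a boundary...
expr.setTagValue(boundary, ())
assert(expr.getTagValue(boundary).isDefined)

// ...so the generic resolve() pops the tag exactly once and returns the node unchanged,
// while any later resolution attempt proceeds normally.
expr.unsetTagValue(boundary)
assert(expr.getTagValue(boundary).isEmpty)
}}}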
@@ -40,52 +48,125 @@ import org.apache.spark.sql.catalyst.expressions.Expression * {{{ SELECT ARRAY(*) FROM VALUES (1); }}} * it is resolved using [[ExpressionResolver.resolveStar]]. * - * It applies appropriate [[TypeCoercion]] (or [[AnsiTypeCoercion]]) rules after resolving the - * function using the [[FunctionResolution]] code. + * After resolving the function with [[FunctionResolution.resolveFunction]], performs further + * resolution in two cases: + * - result of [[FunctionResolution.resolveFunction]] is [[InheritAnalysisRules]], in which case + * it is necessary to resolve its replacement expression. + * - result of [[FunctionResolution.resolveFunction]] is [[AggregateExpression]] in which case + * it is necessary to perform further resolution and checks on its children. + * + * Finally apply type coercion to the result of previous step and in case that the resulting + * expression is [[TimeZoneAwareExpression]], apply timezone. */ class FunctionResolver( expressionResolver: ExpressionResolver, timezoneAwareExpressionResolver: TimezoneAwareExpressionResolver, - functionResolution: FunctionResolution) - extends TreeNodeResolver[UnresolvedFunction, Expression] - with ProducesUnresolvedSubtree { + functionResolution: FunctionResolution, + aggregateExpressionResolver: AggregateExpressionResolver, + binaryArithmeticResolver: BinaryArithmeticResolver) + extends TreeNodeResolver[UnresolvedFunction, Expression] + with ProducesUnresolvedSubtree { - private val typeCoercionRules: Seq[Expression => Expression] = - if (conf.ansiEnabled) { - FunctionResolver.ANSI_TYPE_COERCION_RULES - } else { - FunctionResolver.TYPE_COERCION_RULES - } private val typeCoercionResolver: TypeCoercionResolver = - new TypeCoercionResolver(timezoneAwareExpressionResolver, typeCoercionRules) + new TypeCoercionResolver(timezoneAwareExpressionResolver) + /** + * Main method used to resolve an [[UnresolvedFunction]]. It resolves it in the following steps: + * - If the function is `count(*)` it is replaced with `count(1)` (please check + * [[normalizeCountExpression]] documentation for more details). Otherwise, we resolve the + * children of it. + * - Resolve the function using [[FunctionResolution.resolveFunction]]. + * - If the result from previous step is an [[AggregateExpression]], [[BinaryArithmetic]] or + * [[InheritAnalysisRules]], perform further resolution on its children. + * - Apply [[TypeCoercion]] rules to the result of previous step. In case that the resulting + * expression of the previous step is [[BinaryArithmetic]], skip this one as type coercion is + * already applied. + * - Apply timezone, if the resulting expression is [[TimeZoneAwareExpression]]. 
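As a hedged illustration of the [[InheritAnalysisRules]] case from the list above (assuming the usual `Nvl`/`Coalesce` definitions in Catalyst, which are not part of this patch): `nvl` is a thin wrapper whose analysis happens entirely on its replacement expression, which is why the resolver has to re-resolve that replacement.

{{{
import org.apache.spark.sql.catalyst.expressions.{Coalesce, Literal, Nvl}

// nvl(a, b) carries a `replacement` of coalesce(a, b); after FunctionResolution.resolveFunction
// returns such a node, it is the replacement subtree that still needs resolution.
val nvl = new Nvl(Literal(null), Literal(0))
assert(nvl.replacement.isInstanceOf[Coalesce])
}}}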
+ */ override def resolve(unresolvedFunction: UnresolvedFunction): Expression = { + checkNotUdf(unresolvedFunction) + val functionWithResolvedChildren = - unresolvedFunction.copy(arguments = unresolvedFunction.arguments.flatMap { - case s: UnresolvedStar => expressionResolver.resolveStar(s) - case other => Seq(expressionResolver.resolve(other)) - }) - val resolvedFunction = functionResolution.resolveFunction(functionWithResolvedChildren) - typeCoercionResolver.resolve(resolvedFunction) + if (isCountStarExpansionAllowed(unresolvedFunction)) { + normalizeCountExpression(unresolvedFunction) + } else { + withResolvedChildren(unresolvedFunction, expressionResolver.resolve) + } + + var resolvedFunction = functionResolution.resolveFunction(functionWithResolvedChildren) + + resolvedFunction = resolvedFunction match { + case inheritAnalysisRules: InheritAnalysisRules => + // Since this [[InheritAnalysisRules]] node is created by + // [[FunctionResolution.resolveFunction]], we need to re-resolve its replacement + // expression. + expressionResolver.resolveExpressionGenericallyWithTypeCoercion(inheritAnalysisRules) + case aggregateExpression: AggregateExpression => + // In case `functionResolution.resolveFunction` produces a `AggregateExpression` we + // need to apply further resolution which is done in the + // `AggregateExpressionResolver`. + val resolvedAggregateExpression = aggregateExpressionResolver.resolve(aggregateExpression) + typeCoercionResolver.resolve(resolvedAggregateExpression) + case binaryArithmetic: BinaryArithmetic => + // In case `functionResolution.resolveFunction` produces a `BinaryArithmetic` we + // need to apply further resolution which is done in the `BinaryArithmeticResolver`. + // + // Examples for this case are following (SQL and Dataframe): + // - {{{ SELECT `+`(1,2); }}} + // - df.select(1+2) + binaryArithmeticResolver.resolve(binaryArithmetic) + case other => + typeCoercionResolver.resolve(other) + } + + timezoneAwareExpressionResolver.withResolvedTimezoneCopyTags( + resolvedFunction, + conf.sessionLocalTimeZone + ) + } + + private def isCountStarExpansionAllowed(unresolvedFunction: UnresolvedFunction): Boolean = + unresolvedFunction.arguments match { + case Seq(UnresolvedStar(None)) => isCount(unresolvedFunction) + case Seq(_: ResolvedStar) => isCount(unresolvedFunction) + case _ => false + } + + /** + * Method used to determine whether the given function should be replaced with another one. + */ + private def isCount(unresolvedFunction: UnresolvedFunction): Boolean = { + !unresolvedFunction.isDistinct && + unresolvedFunction.nameParts.length == 1 && + unresolvedFunction.nameParts.head.toLowerCase(Locale.ROOT) == "count" } -} -object FunctionResolver { - // Ordering in the list of type coercions should be in sync with the list in [[TypeCoercion]]. - private val TYPE_COERCION_RULES: Seq[Expression => Expression] = Seq( - CollationTypeCoercion.apply, - TypeCoercion.InTypeCoercion.apply, - TypeCoercion.FunctionArgumentTypeCoercion.apply, - TypeCoercion.IfTypeCoercion.apply, - TypeCoercion.ImplicitTypeCoercion.apply - ) + /** + * Method used to replace the `count(*)` function with `count(1)` function. Resolution of the + * `count(*)` is done in the following way: + * - SQL: It is done during the construction of the AST (in [[AstBuilder]]). + * - Dataframes: It is done during the analysis phase and that's why we need to do it here. 
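An illustrative before/after of the rewrite performed by `normalizeCountExpression`, built directly from analyzer nodes (a hedged sketch, not resolver code):

{{{
import org.apache.spark.sql.catalyst.analysis.{UnresolvedFunction, UnresolvedStar}
import org.apache.spark.sql.catalyst.expressions.Literal

// DataFrame-style count(*) arrives as count(UnresolvedStar) and is normalized to count(1),
// so the star is never expanded underneath an aggregate function.
val countStar = UnresolvedFunction(Seq("count"), Seq(UnresolvedStar(None)), isDistinct = false)
val normalized = countStar.copy(arguments = Seq(Literal(1)))
assert(normalized.nameParts == Seq("count") && normalized.arguments == Seq(Literal(1)))
}}}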
+ */ + private def normalizeCountExpression( + unresolvedFunction: UnresolvedFunction): UnresolvedFunction = { + unresolvedFunction.copy( + nameParts = Seq("count"), + arguments = Seq(Literal(1)), + filter = unresolvedFunction.filter + ) + } - // Ordering in the list of type coercions should be in sync with the list in [[AnsiTypeCoercion]]. - private val ANSI_TYPE_COERCION_RULES: Seq[Expression => Expression] = Seq( - CollationTypeCoercion.apply, - AnsiTypeCoercion.InTypeCoercion.apply, - AnsiTypeCoercion.FunctionArgumentTypeCoercion.apply, - AnsiTypeCoercion.IfTypeCoercion.apply, - AnsiTypeCoercion.ImplicitTypeCoercion.apply - ) + /** + * Assert that the function that is being resolved is not a UDF, since those are not supported + * in single-pass at the moment. + */ + private def checkNotUdf(unresolvedFunction: UnresolvedFunction): Unit = { + if (!FunctionRegistry.functionSet.contains( + FunctionIdentifier(unresolvedFunction.nameParts.head) + )) { + throw new ExplicitlyUnsupportedResolverFeature( + s"unsupported expression: User Defined Function" + ) + } + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/HybridAnalyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/HybridAnalyzer.scala index e1332a8cfb594..0c1ed75e1e15b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/HybridAnalyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/HybridAnalyzer.scala @@ -54,10 +54,13 @@ class HybridAnalyzer( legacyAnalyzer: Analyzer, resolverGuard: ResolverGuard, resolver: Resolver, + extendedResolutionChecks: Seq[LogicalPlan => Unit] = Seq.empty, checkSupportedSinglePassFeatures: Boolean = true) extends SQLConfHelper { private var singlePassResolutionDuration: Option[Long] = None private var fixedPointResolutionDuration: Option[Long] = None + private val resolverRunner: ResolverRunner = + new ResolverRunner(resolver, extendedResolutionChecks) def apply(plan: LogicalPlan, tracker: QueryPlanningTracker): LogicalPlan = { val dualRun = @@ -176,17 +179,8 @@ class HybridAnalyzer( * This method is used to run the single-pass Analyzer which will return the resolved plan * or throw an exception if the resolution fails. Both cases are handled in the caller method. 
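Circling back to `checkNotUdf` above, a quick illustration of how the static registry separates built-ins from user-defined functions (the UDF name below is made up):

{{{
import org.apache.spark.sql.catalyst.FunctionIdentifier
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry

// Built-in functions are listed in the static function set; anything else is treated as a UDF
// and rejected with ExplicitlyUnsupportedResolverFeature for now.
assert(FunctionRegistry.functionSet.contains(FunctionIdentifier("count")))
assert(!FunctionRegistry.functionSet.contains(FunctionIdentifier("my_custom_udf")))
}}}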
* */ - private def resolveInSinglePass(plan: LogicalPlan): LogicalPlan = { - val resolvedPlan = resolver.lookupMetadataAndResolve( - plan, - analyzerBridgeState = AnalysisContext.get.getSinglePassResolverBridgeState - ) - if (conf.getConf(SQLConf.ANALYZER_SINGLE_PASS_RESOLVER_VALIDATION_ENABLED)) { - val validator = new ResolutionValidator - validator.validatePlan(resolvedPlan) - } - resolvedPlan - } + private def resolveInSinglePass(plan: LogicalPlan): LogicalPlan = + resolverRunner.resolve(plan, AnalysisContext.get.getSinglePassResolverBridgeState) /** * This method is used to run the legacy Analyzer which will return the resolved plan diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/IdentifierMap.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/IdentifierMap.scala index 899eb7d71e813..372f40a6308dd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/IdentifierMap.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/IdentifierMap.scala @@ -26,12 +26,3 @@ import java.util.Locale private class IdentifierMap[V] extends KeyTransformingMap[String, V] { override def mapKey(key: String): String = key.toLowerCase(Locale.ROOT) } - -/** - * The [[OptionalIdentifierMap]] is an implementation of a [[KeyTransformingMap]] that uses optional - * SQL/DataFrame identifiers as keys. The implementation is case-insensitive for non-empty keys. - */ -private class OptionalIdentifierMap[V] extends KeyTransformingMap[Option[String], V] { - override def mapKey(key: Option[String]): Option[String] = - key.map(_.toLowerCase(Locale.ROOT)) -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/KeyTransformingMap.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/KeyTransformingMap.scala index ff6e118fcc3c9..5ca7ae2b08154 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/KeyTransformingMap.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/KeyTransformingMap.scala @@ -17,28 +17,35 @@ package org.apache.spark.sql.catalyst.analysis.resolver -import scala.collection.mutable +import java.util.{HashMap, Iterator} +import java.util.Map.Entry +import java.util.function.Function /** * The [[KeyTransformingMap]] is a partial implementation of [[mutable.Map]] that transforms input * keys with a custom [[mapKey]] method. 
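A tiny usage sketch of the case-insensitive behaviour provided by [[IdentifierMap]] on top of [[KeyTransformingMap]] (hedged; assumes the caller sits in the same package, since the classes are package-private):

{{{
// Keys are lower-cased by mapKey, so lookups ignore identifier case.
val columnOrdinals = new IdentifierMap[Int]
columnOrdinals += ("Col1" -> 0)
assert(columnOrdinals.get("COL1").contains(0))
assert(columnOrdinals.contains("col1"))
}}}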
*/ private abstract class KeyTransformingMap[K, V] { - private val impl = new mutable.HashMap[K, V] + private val impl = new HashMap[K, V] - def get(key: K): Option[V] = impl.get(mapKey(key)) + def get(key: K): Option[V] = Option(impl.get(mapKey(key))) - def contains(key: K): Boolean = impl.contains(mapKey(key)) + def put(key: K, value: V): V = impl.put(mapKey(key), value) - def iterator: Iterator[(K, V)] = impl.iterator + def contains(key: K): Boolean = impl.containsKey(mapKey(key)) + + def computeIfAbsent(key: K, compute: Function[K, V]): V = + impl.computeIfAbsent(mapKey(key), compute) + + def iterator: Iterator[Entry[K, V]] = impl.entrySet().iterator() def +=(kv: (K, V)): this.type = { - impl += (mapKey(kv._1) -> kv._2) + impl.put(mapKey(kv._1), kv._2) this } def -=(key: K): this.type = { - impl -= mapKey(key) + impl.remove(mapKey(key)) this } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/LateralColumnAliasProhibitedRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/LateralColumnAliasProhibitedRegistry.scala new file mode 100644 index 0000000000000..bc0f11f5bd6de --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/LateralColumnAliasProhibitedRegistry.scala @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis.resolver + +import java.util.ArrayList + +import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute} + +/** + * Dummy implementation of [[LateralColumnAliasRegistry]] used when + * [[SQLConf.LATERAL_COLUMN_ALIAS_IMPLICIT_ENABLED]] is disabled. Getter methods throw an exception + * as they should never be called on a dummy implementation. Non-getter methods must remain + * idempotent. 
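A hedged sketch of how the two registry implementations are typically selected (the actual wiring lives in [[NameScope]]; `isLcaEnabled` and `output` are stand-ins here):

{{{
import org.apache.spark.sql.catalyst.expressions.Attribute

def chooseLcaRegistry(
    isLcaEnabled: Boolean,
    output: Seq[Attribute]): LateralColumnAliasRegistry = {
  if (isLcaEnabled) {
    new LateralColumnAliasRegistryImpl(output) // real LCA bookkeeping
  } else {
    new LateralColumnAliasProhibitedRegistry // null object: getters throw if ever called
  }
}
}}}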
+ */ +class LateralColumnAliasProhibitedRegistry extends LateralColumnAliasRegistry { + def withNewLcaScope(body: => Alias): Alias = body + + def getAttribute(attributeName: String): Option[Attribute] = + throwLcaResolutionNotEnabled() + + def getAliasDependencyLevels(): ArrayList[ArrayList[Alias]] = + throwLcaResolutionNotEnabled() + + def markAttributeLaterallyReferenced(attribute: Attribute): Unit = + throwLcaResolutionNotEnabled() + + def isAttributeLaterallyReferenced(attribute: Attribute): Boolean = + throwLcaResolutionNotEnabled() +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/LateralColumnAliasRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/LateralColumnAliasRegistry.scala new file mode 100644 index 0000000000000..45a38417a8eed --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/LateralColumnAliasRegistry.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis.resolver + +import java.util.ArrayList + +import org.apache.spark.SparkException +import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute} + +/** + * Base class for lateral column alias registry. This class is extended by 2 implementations: + * 1. [[LateralColumnAliasRegistryImpl]] - When [[SQLConf.LATERAL_COLUMN_ALIAS_IMPLICIT_ENABLED]] + * is enabled, this class implements logic for LCA resolution. + * 2. [[LateralColumnAliasProhibitedRegistry]] - Dummy class whose methods throw exceptions when + * LCA resolution is disabled by [[SQLConf.LATERAL_COLUMN_ALIAS_IMPLICIT_ENABLED]]. + */ +abstract class LateralColumnAliasRegistry { + def withNewLcaScope(body: => Alias): Alias + + def getAttribute(attributeName: String): Option[Attribute] + + def getAliasDependencyLevels(): ArrayList[ArrayList[Alias]] + + def markAttributeLaterallyReferenced(attribute: Attribute): Unit + + def isAttributeLaterallyReferenced(attribute: Attribute): Boolean + + protected def throwLcaResolutionNotEnabled(): Nothing = { + throw SparkException.internalError("Lateral column alias resolution is not enabled.") + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/LateralColumnAliasRegistryImpl.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/LateralColumnAliasRegistryImpl.scala new file mode 100644 index 0000000000000..c685b098db2d2 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/LateralColumnAliasRegistryImpl.scala @@ -0,0 +1,182 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis.resolver + +import java.util.{ArrayDeque, ArrayList, HashSet} + +import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute} +import org.apache.spark.sql.errors.QueryCompilationErrors + +/** + * [[LateralColumnAliasRegistryImpl]] is a utility class that contains structures required for + * lateral column alias resolution. Here we store: + * - [[currentAttributeDependencyLevelStack]] - Current attribute dependency level in the scope. + * Dependency level is defined as a maximum dependency in that attribute's expression tree. For + * example, in a query like: + * + * {{{ SELECT a, b, a + b AS c, a + c AS d}}} + * + * Dependency levels will be as follows: + * level 0: a, b + * level 1: c + * level 2: d + * + * We add a new entry to the stack for each new [[Alias]] resolution. This is needed because we + * can have nesting Aliases in the plan, that do not belong to the same LCA scope. For example, + * in the following query: + * + * {{{ SELECT STRUCT('alpha' AS A, 'beta' AS B) ST }}} + * + * ST, A and B would be aliases in the same expression tree, but they do not belong in the same + * LCA scope. + * + * - [[availableAttributes]] - All attributes that can be laterally referenced. This map is + * indexed by name, but contains a list of attributes with the same name. This is because it is + * possible to have multiple attributes with the same name in the scope, but they can't be + * laterally referenced. Handling ambiguous references is done in the [[getAttribute]] method. + * For the following query: + * + * {{{ SELECT 0 AS a, 1 AS b, 2 AS c, b AS d, a AS e, d AS f, a AS g, g AS h, h AS i }}} + * + * [[availableAttributes]] will be: {a, b, c, d, e, f, g, h, i} + * - [[referencedAliases]] - Aliases that have been laterally referenced. For the given query + * example, [[referencedAliases]] will be: {a, b, d, g, h} + * - [[aliasDependencyLevels]] - Dependency levels of all aliases, indexed by dependency level. + * For the given query example, dependency levels will be as follows: + * + * level 0: a, b, c + * level 1: d, e, g + * level 2: f, h + * level 3: i + * + * @param attributes Output attributes from currently resolved [[NameScope]], to which the registry + * belongs. 
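A small usage sketch (hypothetical driver code, not the actual resolver flow) of the registry for a query like `SELECT 0 AS a, a + 1 AS d`:

{{{
import org.apache.spark.sql.catalyst.expressions.{Add, Alias, Literal}

val registry = new LateralColumnAliasRegistryImpl(Seq.empty)

// Resolving `0 AS a`: the alias lands on dependency level 0.
val a = registry.withNewLcaScope { Alias(Literal(0), "a")() }

// Resolving `a + 1 AS d`: looking up `a` bumps the current dependency level to 1.
val d = registry.withNewLcaScope {
  val referenced = registry.getAttribute("a").get
  registry.markAttributeLaterallyReferenced(referenced)
  Alias(Add(referenced, Literal(1)), "d")()
}

assert(registry.isAttributeLaterallyReferenced(registry.getAttribute("a").get))
assert(registry.getAliasDependencyLevels().size() == 2) // level 0: [a], level 1: [d]
}}}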
+ */ +class LateralColumnAliasRegistryImpl(attributes: Seq[Attribute]) + extends LateralColumnAliasRegistry { + private case class AliasReference(attribute: Attribute, dependencyLevel: Int) + + private val currentAttributeDependencyLevelStack: ArrayDeque[Int] = new ArrayDeque[Int] + + private val availableAttributes = new IdentifierMap[ArrayList[AliasReference]] + registerAllAttributes(attributes) + + private val referencedAliases = new HashSet[Attribute] + private val aliasDependencyLevels = new ArrayList[ArrayList[Alias]] + + /** + * Creates a new LCA resolution scope for each [[Alias]] resolution. Executes the lambda and + * registers the resolved alias for later LCA resolution. + */ + def withNewLcaScope(body: => Alias): Alias = { + currentAttributeDependencyLevelStack.push(0) + try { + val resolvedAlias = body + registerAlias(resolvedAlias) + resolvedAlias + } finally { + currentAttributeDependencyLevelStack.pop() + } + } + + /** + * Gets the attribute needed for LCA resolution by given name from the set of available + * attributes. If there are multiple matches, throws [[ambiguousLateralColumnAliasError]] error. + * If the method is called while resolving an [[Alias]], updates the dependency level in the + * current scope. + */ + def getAttribute(attributeName: String): Option[Attribute] = { + availableAttributes.get(attributeName) match { + case None => None + case Some(aliasReferenceList: ArrayList[AliasReference]) => + if (aliasReferenceList.size() > 1) { + throw QueryCompilationErrors.ambiguousLateralColumnAliasError( + attributeName, + aliasReferenceList.size() + ) + } + + val aliasReference = aliasReferenceList.get(0) + if (!currentAttributeDependencyLevelStack.isEmpty) { + // compute new dependency as a maximum of current dependency and dependency of the + // referenced attribute incremented by 1. + val maxDependencyLevel = Math.max( + currentAttributeDependencyLevelStack.pop(), + aliasReference.dependencyLevel + 1 + ) + currentAttributeDependencyLevelStack.push(maxDependencyLevel) + } + + Some(aliasReference.attribute) + } + } + + /** + * Returns the dependency levels of all aliases. + */ + def getAliasDependencyLevels(): ArrayList[ArrayList[Alias]] = aliasDependencyLevels + + /** + * Adds an attribute to the set of attributes that have been laterally referenced. + */ + def markAttributeLaterallyReferenced(attribute: Attribute): Unit = + referencedAliases.add(attribute) + + /** + * Returns true if the attribute has been laterally referenced, false otherwise. + */ + def isAttributeLaterallyReferenced(attribute: Attribute): Boolean = + referencedAliases.contains(attribute) + + /** + * Registers an alias for LCA resolution by adding it to correct dependency level. Additionally + * register an attribute for further LCA chaining. + */ + private def registerAlias(alias: Alias): Unit = { + addAliasDependency(alias) + registerAttribute( + alias.toAttribute, + currentAttributeDependencyLevelStack.peek() + ) + } + + private def registerAllAttributes(attributes: Seq[Attribute]) = + attributes.foreach(attribute => registerAttribute(attribute)) + + private def registerAttribute(attribute: Attribute, dependencyLevel: Int = 0): Unit = { + availableAttributes + .computeIfAbsent(attribute.name, _ => new ArrayList[AliasReference]) + .add( + AliasReference( + attribute, + dependencyLevel + ) + ) + } + + private def addAliasDependency(alias: Alias): Unit = { + val dependencyLevel = currentAttributeDependencyLevelStack.peek() + // If targeted dependency level does not exist yet, create it now. 
+ if (aliasDependencyLevels.size() <= dependencyLevel) { + aliasDependencyLevels.add(new ArrayList[Alias]) + } + val dependencyLevelList = aliasDependencyLevels.get(dependencyLevel) + dependencyLevelList.add(alias) + aliasDependencyLevels.set(dependencyLevel, dependencyLevelList) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/LimitExpressionResolver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/LimitExpressionResolver.scala index a25616ba50b6a..d25112d78c6e7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/LimitExpressionResolver.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/LimitExpressionResolver.scala @@ -19,24 +19,20 @@ package org.apache.spark.sql.catalyst.analysis.resolver import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions.Expression -import org.apache.spark.sql.errors.QueryErrorsBase import org.apache.spark.sql.types.IntegerType /** * The [[LimitExpressionResolver]] is a resolver that resolves a [[LocalLimit]] or [[GlobalLimit]] * expression and performs all the necessary validation. */ -class LimitExpressionResolver(expressionResolver: TreeNodeResolver[Expression, Expression]) - extends TreeNodeResolver[Expression, Expression] - with QueryErrorsBase { +class LimitExpressionResolver extends TreeNodeResolver[Expression, Expression] { /** * Resolve a limit expression of [[GlobalLimit]] or [[LocalLimit]] and perform validation. */ override def resolve(unresolvedLimitExpression: Expression): Expression = { - val resolvedLimitExpression = expressionResolver.resolve(unresolvedLimitExpression) - validateLimitExpression(resolvedLimitExpression, expressionName = "limit") - resolvedLimitExpression + validateLimitExpression(unresolvedLimitExpression, expressionName = "limit") + unresolvedLimitExpression } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/MetadataResolver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/MetadataResolver.scala index e1334fc56575e..524db0d6d3faf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/MetadataResolver.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/MetadataResolver.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.analysis.resolver import java.util.ArrayDeque -import org.apache.spark.sql.catalyst.analysis.{withPosition, RelationResolution, UnresolvedRelation} +import org.apache.spark.sql.catalyst.analysis.{RelationResolution, UnresolvedRelation} import org.apache.spark.sql.catalyst.expressions.{Expression, PlanExpression} import org.apache.spark.sql.catalyst.plans.logical.{AnalysisHelper, LogicalPlan} import org.apache.spark.sql.connector.catalog.CatalogManager @@ -43,6 +43,12 @@ class MetadataResolver( with DelegatesResolutionToExtensions { override val relationsWithResolvedMetadata = new RelationsWithResolvedMetadata + /** + * [[ProhibitedResolver]] passed as an argument to [[tryDelegateResolutionToExtensions]], because + * unresolved subtree resolution doesn't make sense during metadata resolution traversal. + */ + private val prohibitedResolver = new ProhibitedResolver + /** * Resolves the relation metadata for `unresolvedPlan`. Usually this involves several blocking * calls for the [[UnresolvedRelation]]s present in that tree. 
During the `unresolvedPlan` @@ -52,7 +58,7 @@ class MetadataResolver( * [[RelationResolution]] wasn't successful, we resort to using [[extensions]]. * Otherwise, we fail with an exception. */ - def resolve(unresolvedPlan: LogicalPlan): Unit = { + override def resolve(unresolvedPlan: LogicalPlan): Unit = { traverseLogicalPlanTree(unresolvedPlan) { unresolvedOperator => unresolvedOperator match { case unresolvedRelation: UnresolvedRelation => @@ -63,7 +69,7 @@ class MetadataResolver( // In case the generic metadata resolution returned `None`, we try to check if any // of the [[extensions]] matches this `unresolvedRelation`, and resolve it using // that extension. - tryDelegateResolutionToExtension(unresolvedRelation) + tryDelegateResolutionToExtension(unresolvedRelation, prohibitedResolver) } relationWithResolvedMetadata match { @@ -73,9 +79,6 @@ class MetadataResolver( relationWithResolvedMetadata ) case None => - withPosition(unresolvedRelation) { - unresolvedRelation.tableNotFound(unresolvedRelation.multipartIdentifier) - } } } case _ => @@ -96,7 +99,8 @@ class MetadataResolver( /** * Traverse the logical plan tree from `root` in a pre-order DFS manner and apply `visitor` to - * each node. + * each node. This function handles the whole operator tree, its child expression subqueries and + * inner children (e.g. [[UnresolvedWith]] CTE definitions). */ private def traverseLogicalPlanTree(root: LogicalPlan)(visitor: LogicalPlan => Unit) = { val stack = new ArrayDeque[Either[LogicalPlan, Expression]] @@ -110,6 +114,13 @@ class MetadataResolver( for (child <- logicalPlan.children) { stack.push(Left(child)) } + for (innerChild <- logicalPlan.innerChildren) { + innerChild match { + case plan: LogicalPlan => + stack.push(Left(plan)) + case _ => + } + } for (expression <- logicalPlan.expressions) { stack.push(Right(expression)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/NameScope.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/NameScope.scala index 8abf4e04b8836..d16e929f08a96 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/NameScope.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/NameScope.scala @@ -17,252 +17,374 @@ package org.apache.spark.sql.catalyst.analysis.resolver -import java.util.{ArrayDeque, ArrayList, HashSet} +import java.util.{ArrayDeque, HashSet} import scala.collection.mutable import org.apache.spark.sql.catalyst.SQLConfHelper import org.apache.spark.sql.catalyst.analysis.{Resolver => NameComparator, UnresolvedStar} -import org.apache.spark.sql.catalyst.expressions.{ - Alias, - Attribute, - AttributeSeq, - Expression, - NamedExpression -} -import org.apache.spark.sql.errors.QueryCompilationErrors +import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeSeq, ExprId, NamedExpression} +import org.apache.spark.sql.internal.SQLConf /** - * The [[NameScope]] is used during the analysis to control the visibility of names: plan names - * and output attributes. New [[NameScope]] can be created both in the [[Resolver]] and in - * the [[ExpressionResolver]] using the [[NameScopeStack]] api. The name resolution for identifiers - * is case-insensitive. + * The [[NameScope]] is used to control the resolution of names (table, column, alias identifiers). + * It's a part of the [[Resolver]]'s state, and is used to manage the output of SQL query/DataFrame + * program operators. 
+ * + * The [[NameScope]] output is immutable. If it's necessary to update the output, + * [[NameScopeStack]] methods are used ([[overwriteTop]] or [[withNewScope]]). The [[NameScope]] + * is always used through the [[NameScopeStack]]. + * + * The resolution of identifiers is case-insensitive. + * + * Name resolution priority is as follows: + * + * 1. Resolution of local references: + * - column reference + * - struct field or map key reference + * 2. Resolution of lateral column aliases (if enabled). + * + * For example, in a query like: * - * In this example: + * {{{ SELECT 1 AS col1, col1 FROM VALUES (2) }}} + * + * Because column resolution has a higher priority than LCA resolution, the result will be [1, 2] + * and not [1, 1]. + * + * Approximate tree of [[NameScope]] manipulations is shown in the following example: * * {{{ - * WITH table_1_cte AS ( - * SELECT - * col1, - * col2, - * col2 - * FROM - * table_1 - * ) + * CREATE TABLE IF NOT EXISTS t1 (col1 INT, col2 INT, col3 STRING); + * * SELECT - * table_1_cte.col1, - * table_2.col1 + * col1, col2 as alias1 * FROM - * table_1_cte - * INNER JOIN - * table_2 - * ON - * table_1_cte.col2 = table_2.col3 + * (SELECT * FROM VALUES (1, 2)) + * UNION + * (SELECT t2.col1, t2.col2 FROM (SELECT col1, col2 FROM t1) AS t2) * ; * }}} * - * there are two named subplans in the scope: table_1_cte -> [col1, col2, col2] and - * table_2 -> [col1, col3]. + * -> * - * State breakout: - * - `planOutputs`: list of named plan outputs. Order matters here (e.g. to correctly expand `*`). - * Can contain duplicate names, since it's possible to select same column twice, or to select - * columns with the same name from different relations. [[OptionalIdentifierMap]] is used here, - * since some plans don't have an explicit name, so output attributes from those plans will reside - * under the `None` key. - * In our example it will be {{{ [(table_1_cte, [col1, col2, col2]), (table_2, [col1, col3])] }}} + * {{{ + * unionAttributes = withNewScope { + * lhsOutput = withNewScope { + * expandedStar = withNewScope { + * scope.overwriteTop(localRelation.output) + * scope.expandStar(star) + * } + * scope.overwriteTop(expandedStar) + * scope.output + * } + * rhsOutput = withNewScope { + * subqueryAttributes = withNewScope { + * scope.overwriteTop(t1.output) + * scope.overwriteTop(prependQualifier(scope.output, "t2")) + * [scope.matchMultiPartName("t2", "col1"), scope.matchMultiPartName("t2", "col2")] + * } + * scope.overwriteTop(subqueryAttributes) + * scope.output + * } + * scope.overwriteTop(coerce(lhsOutput, rhsOutput)) + * [scope.matchMultiPartName("col1"), alias(scope.matchMultiPartName("col2"), "alias1")] + * } + * scope.overwriteTop(unionAttributes) + * }}} * - * - `planNameToOffset`: mapping from plan output names to their offsets in the `planOutputs` array. - * It's used to lookup attributes by plan output names (multipart names are not supported yet). - * In our example it will be {{{ [table_1_cte -> 0, table_2 -> 1] }}} + * @param output These are the attributes visible for lookups in the current scope. + * These may be: + * - Transformed outputs of lower scopes (e.g. type-coerced outputs of [[Union]]'s children). + * - Output of a current operator that is being resolved (leaf nodes like [[Relations]]). 
*/ -class NameScope extends SQLConfHelper { - private val planOutputs = new ArrayList[PlanOutput]() - private val planNameToOffset = new OptionalIdentifierMap[Int] - private val nameComparator: NameComparator = conf.resolver - private val existingAliases = new HashSet[String] +class NameScope(val output: Seq[Attribute] = Seq.empty) extends SQLConfHelper { /** - * Register the named plan output in this [[NameScope]]. The named plan is usually a - * [[NamedRelation]]. `attributes` sequence can contain duplicate names both for this named plan - * and for the scope in general, despite the fact that their further resolution _may_ throw an - * error in case of ambiguous reference. After calling this method, the code can lookup the - * attributes using `get*` methods of this [[NameScope]]. - * - * Duplicate plan names are merged into the same [[PlanOutput]]. For example, this query: - * - * {{{ SELECT t.* FROM (SELECT * FROM VALUES (1)) as t, (SELECT * FROM VALUES (2)) as t; }}} - * - * will have the following output schema: - * - * {{{ [col1, col1] }}} - * - * Same logic applies for the unnamed plan outputs. This query: - * - * {{{ SELECT * FROM (SELECT * FROM VALUES (1)), (SELECT * FROM VALUES (2)); }}} - * - * will have the same output schema: - * - * {{{ [col1, col1] }}} - * - * @param name The name of this named plan. - * @param attributes The output of this named plan. Can contain duplicate names. + * [[nameComparator]] is a function that is used to compare two identifiers. Its implementation + * depends on the "spark.sql.caseSensitive" configuration - whether to respect case sensitivity + * or not. */ - def update(name: String, attributes: Seq[Attribute]): Unit = { - update(attributes, Some(name)) - } + private val nameComparator: NameComparator = conf.resolver /** - * Register the unnamed plan output in this [[NameScope]]. Some examples of the unnamed plan are - * [[Project]] and [[Aggregate]]. - * - * See the [[update]] method for more details. - * - * @param attributes The output of the unnamed plan. Can contain duplicate names. + * [[attributesForResolution]] is an [[AttributeSeq]] that is used for resolution of + * multipart attribute names. It's created from the `attributes` when [[NameScope]] is updated. */ - def +=(attributes: Seq[Attribute]): Unit = { - update(attributes) - } + private val attributesForResolution: AttributeSeq = AttributeSeq.fromNormalOutput(output) /** - * Get all the attributes from all the plans registered in this [[NameScope]]. The output can - * contain duplicate names. This is used for star (`*`) resolution. + * [[attributesByName]] is used to look up attributes by one-part name from the operator's output. + * This is a lazy val, since in most of the cases [[ExpressionResolver]] doesn't need it and + * accesses a generic [[attributesForResolution]] in [[resolveMultipartName]]. */ - def getAllAttributes: Seq[Attribute] = { - val attributes = new mutable.ArrayBuffer[Attribute] + private lazy val attributesByName = createAttributesByName(output) - planOutputs.forEach(planOutput => { - attributes.appendAll(planOutput.attributes) - }) + /** + * Expression IDs from `output`. See [[hasAttributeWithId]] for more details. 
+ */ + private lazy val attributeIds = createAttributeIds(output) - attributes.toSeq + private val isLcaEnabled = conf.getConf(SQLConf.LATERAL_COLUMN_ALIAS_IMPLICIT_ENABLED) + lazy val lcaRegistry: LateralColumnAliasRegistry = if (isLcaEnabled) { + new LateralColumnAliasRegistryImpl(output) + } else { + new LateralColumnAliasProhibitedRegistry } /** - * Expand the [[UnresolvedStar]] using `planOutputs`. The expected use case for this method is - * star expansion inside [[Project]]. Since [[Project]] has only one child, we assert that the - * size of `planOutputs` is 1, otherwise the query is malformed. - * - * Some examples of queries with a star: - * - * - Star without a target: - * {{{ SELECT * FROM VALUES (1, 2, 3) AS t(a, b, c); }}} - * - Star with a multipart name target: - * {{{ SELECT catalog1.database1.table1.* FROM catalog1.database1.table1; }}} - * - Star with a struct target: - * {{{ SELECT d.* FROM VALUES (named_struct('a', 1, 'b', 2)) AS t(d); }}} - * - Star as an argument to a function: - * {{{ SELECT concat_ws('', *) AS result FROM VALUES (1, 2, 3) AS t(a, b, c); }}} - * - * It is resolved by correctly resolving the star qualifier. - * Please check [[UnresolvedStarBase.expandStar]] for more details. - * - * @param unresolvedStar [[UnresolvedStar]] to expand. - * @return The output of a plan expanded from the star. + * Expand the [[UnresolvedStar]]. The expected use case for this method is star expansion inside + * [[Project]]. + * + * Star without a target: + * + * {{{ + * -- Here the star will be expanded to [a, b, c]. + * SELECT * FROM VALUES (1, 2, 3) AS t(a, b, c); + * }}} + * + * Star with a multipart name target: + * + * {{{ + * USE CATALOG catalog1; + * USE DATABASE database1; + * + * CREATE TABLE IF NOT EXISTS table1 (col1 INT, col2 INT); + * + * -- Here the star will be expanded to [col1, col2]. + * SELECT catalog1.database1.table1.* FROM catalog1.database1.table1; + * }}} + * + * Star with a struct target: + * + * {{{ + * -- Here the star will be expanded to [field1, field2]. + * SELECT d.* FROM VALUES (named_struct('field1', 1, 'field2', 2)) AS t(d); + * }}} + * + * Star as an argument to a function: + * + * {{{ + * -- Here the star will be expanded to [col1, col2, col3] and those would be passed as + * -- arguments to `concat_ws`. + * SELECT concat_ws('', *) AS result FROM VALUES (1, 2, 3); + * }}} + * + * Also, see [[UnresolvedStarBase.expandStar]] for more details. */ def expandStar(unresolvedStar: UnresolvedStar): Seq[NamedExpression] = { - if (planOutputs.size != 1) { - throw QueryCompilationErrors.invalidStarUsageError("query", Seq(unresolvedStar)) - } - - planOutputs.get(0).expandStar(unresolvedStar) + unresolvedStar.expandStar( + childOperatorOutput = output, + childOperatorMetadataOutput = Seq.empty, + resolve = + (nameParts, nameComparator) => attributesForResolution.resolve(nameParts, nameComparator), + suggestedAttributes = output, + resolver = nameComparator, + cleanupNestedAliasesDuringStructExpansion = true + ) } /** - * Get all matched attributes by a multipart name. It returns [[Attribute]]s when we resolve a - * simple column or an alias name from a lower operator. However this function can also return - * [[Alias]]es in case we access a struct field or a map value using some key. + * Resolve multipart name into a [[NameTarget]]. [[NameTarget]]'s `candidates` may contain + * simple [[AttributeReference]]s if it's a column or alias, or [[ExtractValue]] expressions if + * it's a struct field, map value or array value. 
The `aliasName` will optionally be set to the + * proposed alias name for the value extracted from a struct, map or array. * - * Example that contains those major use-cases: + * Example that demonstrates those major use-cases: * * {{{ - * SELECT col1, a, col2.field, col3.struct.field, col4.key - * FROM (SELECT *, col5 AS a FROM t); + * CREATE TABLE IF NOT EXISTS t ( + * col1 INT, + * col2 STRUCT, + * col3 STRUCT>, + * col4 MAP, + * col5 STRING + * ); + * + * -- For the SELECT below the top Project list will be resolved using this method like this: + * -- AttributeReference(col1), + * -- AttributeReference(a), + * -- GetStructField(col2, field), + * -- GetStructField(GetStructField(col3, struct), field), + * -- GetMapValue(col4, key) + * SELECT + * col1, a, col2.field, col3.struct.field, col4.key + * FROM + * (SELECT *, col5 AS a FROM t); * }}} * - * has a Project list that looks like this: + * Since there can be several expressions that matched the same multipart name, this method may + * return a [[NameTarget]] with the following `candidates`: + * - 0 values: No matched expressions + * - 1 value: Unique expression matched + * - 1+ values: Ambiguity, several expressions matched + * + * Some examples of ambiguity: * * {{{ - * AttributeReference(col1), - * AttributeReference(a), - * Alias(col2.field, field), - * Alias(col3.struct.field, field), - * Alias(col4[CAST(key AS INT)], key) + * CREATE TABLE IF NOT EXISTS t1 (c1 INT, c2 INT); + * CREATE TABLE IF NOT EXISTS t2 (c2 INT, c3 INT); + * + * -- Identically named columns from different tables. + * -- This will fail with AMBIGUOUS_REFERENCE error. + * SELECT c2 FROM t1, t2; * }}} * - * Also, see [[AttributeSeq.resolve]] for more details. + * {{{ + * CREATE TABLE IF NOT EXISTS foo (c1 INT); + * CREATE TABLE IF NOT EXISTS bar (foo STRUCT); * - * Since there can be several identical attribute names for several named plans, this function - * can return multiple values: - * - 0 values: No matched attributes - * - 1 value: Unique attribute matched - * - 1+ values: Ambiguity, several attributes matched + * -- Ambiguity between a column in a table and a field in a struct. + * -- This will succeed, and column will win over the struct field. + * SELECT foo.c1 FROM foo, bar; + * }}} * - * One example of a query with an attribute that has a multipart name: + * The candidates are deduplicated by expression ID (not by attribute name!): * - * {{{ SELECT catalog1.database1.table1.col1 FROM catalog1.database1.table1; }}} + * {{{ + * CREATE TABLE IF NOT EXISTS t1 (col1 STRING); + * + * -- No ambiguity here, since we are selecting the same column (same expression ID). + * SELECT col1 FROM (SELECT col1, col1 FROM t); + * }}} + * + * The case of the `multipartName` takes precedence over the original name case, so the candidates + * will have names that are case-identical to the `multipartName`: + * + * {{{ + * CREATE TABLE IF NOT EXISTS t1 (col1 STRING); * - * @param multipartName Multipart attribute name. Can be of several forms: - * - `catalog.database.table.column` - * - `database.table.column` - * - `table.column` - * - `column` - * @return All the attributes matched by the `multipartName`, encapsulated in a [[NameTarget]]. + * -- The output schema of this query is [COL1], despite the fact that the column is in + * -- lower-case. + * SELECT COL1 FROM t; + * }}} + * + * We are relying on the [[AttributeSeq]] to perform that work, since it requires complex + * resolution logic involving nested field extraction and multipart name matching. 
+ * + * Also, see [[AttributeSeq.resolve]] for more details. */ - def matchMultipartName(multipartName: Seq[String]): NameTarget = { - val candidates = new mutable.ArrayBuffer[Expression] - val allAttributes = new mutable.ArrayBuffer[Attribute] - var aliasName: Option[String] = None + def resolveMultipartName( + multipartName: Seq[String], + canLaterallyReferenceColumn: Boolean = true): NameTarget = { + val (candidates, nestedFields) = + attributesForResolution.getCandidatesForResolution(multipartName, nameComparator) - planOutputs.forEach(planOutput => { - allAttributes.appendAll(planOutput.attributes) - val nameTarget = planOutput.matchMultipartName(multipartName) - if (nameTarget.aliasName.isDefined) { - aliasName = nameTarget.aliasName + val (candidatesWithLCAs: Seq[Attribute], referencedAttribute: Option[Attribute]) = + if (candidates.isEmpty && canLaterallyReferenceColumn) { + getLcaCandidates(multipartName) + } else { + (candidates, None) } - candidates.appendAll(nameTarget.candidates) - }) - NameTarget(candidates.toSeq, aliasName, allAttributes.toSeq) + val resolvedCandidates = attributesForResolution.resolveCandidates( + multipartName, + nameComparator, + candidatesWithLCAs, + nestedFields + ) + + resolvedCandidates match { + case Seq(Alias(child, aliasName)) => + NameTarget( + candidates = Seq(child), + aliasName = Some(aliasName), + lateralAttributeReference = referencedAttribute, + output = output + ) + case other => + NameTarget( + candidates = other, + lateralAttributeReference = referencedAttribute, + output = output + ) + } } /** - * Add an alias, by name, to the list of existing aliases. + * Find attributes in this [[NameScope]] that match a provided one-part `name`. + * + * This method is simpler and more lightweight than [[resolveMultipartName]], because here we + * just return all the attributes matched by the one-part `name`. This is only suitable + * for situations where name _resolution_ is not required (e.g. accessing struct fields + * from the lower operator's output). + * + * For example, this method is used to look up attributes to match a specific [[View]] schema. + * See [[ExpressionResolver.resolveGetViewColumnByNameAndOrdinal]] for more info on view column + * lookup. + * + * We are relying on a simple [[IdentifierMap]] to perform that work, since we just need to match + * one-part name from the lower operator's output here. */ - def addAlias(aliasName: String): Unit = existingAliases.add(aliasName.toLowerCase()) + def findAttributesByName(name: String): Seq[Attribute] = { + attributesByName.get(name) match { + case Some(attributes) => attributes.toSeq + case None => Seq.empty + } + } /** - * Returns whether an alias exists in the current scope. + * Check if `output` contains attributes with `expressionId`. This is used to disable missing + * attribute propagation for DataFrames, because we don't support it yet. 
*/ - def isExistingAlias(aliasName: String): Boolean = - existingAliases.contains(aliasName.toLowerCase()) + def hasAttributeWithId(expressionId: ExprId): Boolean = { + attributeIds.contains(expressionId) + } - private def update(attributes: Seq[Attribute], name: Option[String] = None): Unit = { - planNameToOffset.get(name) match { - case Some(index) => - val prevPlanOutput = planOutputs.get(index) - planOutputs.set( - index, - new PlanOutput(prevPlanOutput.attributes ++ attributes, name, nameComparator) - ) - case None => - val index = planOutputs.size - planOutputs.add(new PlanOutput(attributes, name, nameComparator)) - planNameToOffset += (name -> index) + /** + * If a candidate was not found from output attributes, returns the candidate from lateral + * columns. Here we do [[AttributeSeq.fromNormalOutput]] because a struct field can also be + * laterally referenced and we need to properly resolve [[GetStructField]] node. + */ + private def getLcaCandidates(multipartName: Seq[String]): (Seq[Attribute], Option[Attribute]) = { + val referencedAttribute = lcaRegistry.getAttribute(multipartName.head) + if (referencedAttribute.isDefined) { + val attributesForResolution = AttributeSeq.fromNormalOutput(Seq(referencedAttribute.get)) + val (newCandidates, _) = + attributesForResolution.getCandidatesForResolution(multipartName, nameComparator) + (newCandidates, Some(referencedAttribute.get)) + } else { + (Seq.empty, None) + } + } + + private def createAttributesByName( + attributes: Seq[Attribute]): IdentifierMap[mutable.ArrayBuffer[Attribute]] = { + val result = new IdentifierMap[mutable.ArrayBuffer[Attribute]] + for (attribute <- attributes) { + result.get(attribute.name) match { + case Some(attributesForThisName) => + attributesForThisName += attribute + case None => + val attributesForThisName = new mutable.ArrayBuffer[Attribute] + attributesForThisName += attribute + + result += (attribute.name, attributesForThisName) + } } + + result + } + + private def createAttributeIds(attributes: Seq[Attribute]): HashSet[ExprId] = { + val result = new HashSet[ExprId] + for (attribute <- attributes) { + result.add(attribute.exprId) + } + + result } } /** - * The [[NameScopeStack]] is a stack of [[NameScope]]s managed by the [[Resolver]] and the - * [[ExpressionResolver]]. Usually a top scope is used for name resolution, but in case of - * correlated subqueries we can lookup names in the parent scopes. Low-level scope creation is - * managed internally, and only high-level api like [[withNewScope]] is available to the resolvers. - * Freshly-created [[NameScopeStack]] contains an empty root [[NameScope]]. + * The [[NameScopeStack]] is a stack of [[NameScope]]s managed by the [[Resolver]]. Usually a top + * scope is used for name resolution, but in case of correlated subqueries we can lookup names in + * the parent scopes. Low-level scope creation is managed internally, and only high-level api like + * [[withNewScope]] is available to the resolvers. Freshly-created [[NameScopeStack]] contains an + * empty root [[NameScope]], which in the context of [[Resolver]] corresponds to the query output. */ class NameScopeStack extends SQLConfHelper { private val stack = new ArrayDeque[NameScope] - push() + stack.push(new NameScope) /** * Get the top scope, which is a default choice for name resolution. @@ -272,122 +394,74 @@ class NameScopeStack extends SQLConfHelper { } /** - * Completely overwrite the top scope state with a named plan output. + * Completely overwrite the top scope state with operator `output`. 
*
-   * See [[NameScope.update]] for more details.
-   */
-  def overwriteTop(name: String, attributes: Seq[Attribute]): Unit = {
-    val newScope = new NameScope
-    newScope.update(name, attributes)
-
-    stack.pop()
-    stack.push(newScope)
-  }
-
-  /**
-   * Completely overwrite the top scope state with an unnamed plan output.
+   * This method is called by the [[Resolver]] when we've calculated the output of an operator that
+   * is being resolved. The new output is calculated based on the outputs of the operator's children.
+   *
+   * Example for [[SubqueryAlias]]: here we rewrite the top [[NameScope]]'s attributes to prepend
+   * the subquery qualifier to their names:
    *
-   * See [[NameScope.+=]] for more details.
+   * {{{
+   * val qualifier = sa.identifier.qualifier :+ sa.alias
+   * scope.overwriteTop(scope.output.map(attribute => attribute.withQualifier(qualifier)))
+   * }}}
+   *
+   * Trivially, we would call this method for every operator in the query plan,
+   * however, some operators just propagate the output of their children without any changes, so
+   * we can omit this call for them (e.g. [[Filter]]).
+   *
+   * This method should be preferred over [[withNewScope]].
    */
-  def overwriteTop(attributes: Seq[Attribute]): Unit = {
-    val newScope = new NameScope
-    newScope += attributes
+  def overwriteTop(output: Seq[Attribute]): Unit = {
+    val newScope = new NameScope(output)
 
     stack.pop()
     stack.push(newScope)
   }
 
   /**
-   * Execute `body` in a context of a fresh scope. It's used during the [[Project]] or the
-   * [[Aggregate]] resolution to avoid calling [[push]] and [[pop]] explicitly.
+   * Execute `body` in the context of a fresh scope.
+   *
+   * This method is called by the [[Resolver]] before recursing into the operator's child
+   * resolution _only_ in cases where a fresh scope is required.
+   *
+   * For example, [[Project]] or [[Aggregate]] introduce their own scopes semantically, so that a
+   * lower resolution can look up correlated names:
+   *
+   * {{{
+   * CREATE TABLE IF NOT EXISTS t1 (col1 INT, col2 STRING);
+   * CREATE TABLE IF NOT EXISTS t2 (col1 INT, col2 STRING);
+   *
+   * -- Here we need a scope for the upper [[Project]], and a separate scope for the correlated
+   * -- subquery, because its [[Filter]] needs to look up `t1.col1` from the upper scope.
+   * -- Those scopes have to be independent to avoid polluting each other's attributes.
+   * SELECT col1, (SELECT col2 FROM t2 WHERE t2.col1 == t1.col1 LIMIT 1) FROM t1;
+   * }}}
+   *
+   * Also, we need separate scopes for the operators with multiple children, so that the next
+   * child's resolution wouldn't try to work with the data from its sibling's scope, to avoid
+   * all kinds of undefined behavior:
+   *
+   * {{{
+   * val resolvedLeftChild = withNewScope {
+   *   resolve(unresolvedExcept.left)
+   * }
+   *
+   * // Right child should not see the left child's resolution data to avoid subtle bugs, so we
+   * // create a fresh scope here.
+   *
+   * val resolvedRightChild = withNewScope {
+   *   resolve(unresolvedExcept.right)
+   * }
+   * }}}
    */
   def withNewScope[R](body: => R): R = {
-    push()
+    stack.push(new NameScope)
     try {
       body
     } finally {
-      pop()
-    }
-  }
-
-  /**
-   * Push a new scope to the stack. Introduced by the [[Project]] or the [[Aggregate]].
-   */
-  private def push(): Unit = {
-    stack.push(new NameScope)
-  }
-
-  /**
-   * Pop a scope from the stack. Called when the resolution process for the pushed scope is done.
-   */
-  private def pop(): Unit = {
-    stack.pop()
+      stack.pop()
     }
   }
-}
-
-/**
- * [[PlanOutput]] represents a sequence of attributes from a plan ([[NamedRelation]], [[Project]],
- * [[Aggregate]], etc).
- * - * It is created from `attributes`, which is an output of a named plan, optional plan `name` and a - * resolver provided by the [[NameScopeStack]]. - * - * @param attributes Plan output. Can contain duplicate names. - * @param name Plan name. Non-empty for named plans like [[NamedRelation]] or [[SubqueryAlias]], - * `None` otherwise. - */ -class PlanOutput( - val attributes: Seq[Attribute], - val name: Option[String], - val nameComparator: NameComparator) { - - /** - * attributesForResolution is an [[AttributeSeq]] that is used for resolution of - * multipart attribute names. It's created from the `attributes` when [[NameScope]] is updated. - */ - private val attributesForResolution: AttributeSeq = - AttributeSeq.fromNormalOutput(attributes) - - /** - * Find attributes by the multipart name. - * - * See [[NameScope.matchMultipartName]] for more details. - * - * @param multipartName Multipart attribute name. - * @return Matched attributes or [[Seq.empty]] otherwise. - */ - def matchMultipartName(multipartName: Seq[String]): NameTarget = { - val (candidates, nestedFields) = - attributesForResolution.getCandidatesForResolution(multipartName, nameComparator) - val resolvedCandidates = attributesForResolution.resolveCandidates( - multipartName, - nameComparator, - candidates, - nestedFields - ) - resolvedCandidates match { - case Seq(Alias(child, aliasName)) => - NameTarget(Seq(child), Some(aliasName)) - case other => - NameTarget(other, None) + stack.pop() } } - - /** - * Method to expand an unresolved star. See [[NameScope.expandStar]] for more details. - * - * @param unresolvedStar Star to resolve. - * @return Attributes expanded from the star. - */ - def expandStar(unresolvedStar: UnresolvedStar): Seq[NamedExpression] = { - unresolvedStar.expandStar( - childOperatorOutput = attributes, - childOperatorMetadataOutput = Seq.empty, - resolve = - (nameParts, nameComparator) => attributesForResolution.resolve(nameParts, nameComparator), - suggestedAttributes = attributes, - resolver = nameComparator, - cleanupNestedAliasesDuringStructExpansion = true - ) - } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/NameTarget.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/NameTarget.scala index 3b31c9b1a9110..9d6a62b7fc60a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/NameTarget.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/NameTarget.scala @@ -23,51 +23,62 @@ import org.apache.spark.sql.catalyst.util.StringUtils.orderSuggestedIdentifiersB import org.apache.spark.sql.errors.QueryCompilationErrors /** - * Class that represents results of name resolution or star expansion. It encapsulates: - * - `candidates` - A list of candidates that are possible matches for a given name. - * - `aliasName` - If the candidates size is 1 and it's type is `ExtractValue` (which means that - * it's a recursive type), then the `aliasName` should be the name with which the candidate is - * aliased. Otherwise, `aliasName` should be `None`. - * - `allAttributes` - A list of all attributes which is used to generate suggestions for - * unresolved column error. + * [[NameTarget]] is a result of a multipart name resolution of the + * [[NameScope.resolveMultipartName]]. * - * Example: + * Attribute resolution: * - * - Attribute resolution: - * {{{ SELECT col1 FROM VALUES (1); }}} will have a [[NameTarget]] with a single candidate `col1`. 
- *   `aliasName` would be `None` in this case because the column is not of recursive type.
+ * {{{
+ * -- [[NameTarget]] with a single candidate `col1`. `aliasName` is `None` in this case because
+ * -- the name is not a field/value/item of some recursive type.
+ * SELECT col1 FROM VALUES (1);
+ * }}}
  *
- * - Recursive attribute resolution:
- *   {{{ SELECT col1.col1 FROM VALUES(STRUCT(1,2), 3) }}} will have a [[NameTarget]] with a
- *   single candidate `col1` and an `aliasName` of `Some("col1")`.
+ * Attribute resolution ambiguity:
+ *
+ * {{{
+ * -- [[NameTarget]] with candidates `col1`, `col1`. [[pickCandidate]] will throw
+ * -- `AMBIGUOUS_REFERENCE`.
+ * SELECT col1 FROM VALUES (1) t1, VALUES (2) t2;
+ * }}}
+ *
+ * Struct field resolution:
+ *
+ * {{{
+ * -- [[NameTarget]] with a single candidate `GetStructField(col1, "field1")`. `aliasName` is
+ * -- `Some("col1")`, since here we extract a field of a struct.
+ * SELECT col1.field1 FROM VALUES (named_struct('field1', 1), 3);
+ * }}}
+ *
+ * @param candidates A list of candidates that are possible matches for a given name.
+ * @param aliasName If there is exactly one candidate and its type is [[ExtractValue]] (which means
+ *   that it's a field/value/item from a recursive type), then the `aliasName` should be the name
+ *   with which the candidate needs to be aliased. Otherwise, `aliasName` is `None`.
+ * @param lateralAttributeReference If the candidate is laterally referencing another column, this
+ *   field is populated with that column's attribute.
+ * @param output [[output]] of a [[NameScope]] that produced this [[NameTarget]]. Used to provide
+ *   suggestions for thrown errors.
  */
 case class NameTarget(
     candidates: Seq[Expression],
     aliasName: Option[String] = None,
-    allAttributes: Seq[Attribute] = Seq.empty) {
+    lateralAttributeReference: Option[Attribute] = None,
+    output: Seq[Attribute] = Seq.empty) {
 
   /**
-   * Picks a candidate from the list of candidates based on the given unresolved attribute.
-   * Its behavior is as follows (based on the number of candidates):
-   *
-   * - If there is only one candidate, it will be returned.
-   *
-   * - If there are multiple candidates, an ambiguous reference error will be thrown.
-   *
-   * - If there are no candidates, an unresolved column error will be thrown.
+   * Pick a single candidate from `candidates`:
+   * - If there are no candidates, throw `UNRESOLVED_COLUMN.WITH_SUGGESTION`.
+   * - If there are several candidates, throw `AMBIGUOUS_REFERENCE`.
+   * - Otherwise, return a single [[Expression]].
*/ def pickCandidate(unresolvedAttribute: UnresolvedAttribute): Expression = { - candidates match { - case Seq() => - throwUnresolvedColumnError(unresolvedAttribute) - case Seq(candidate) => - candidate - case _ => - throw QueryCompilationErrors.ambiguousReferenceError( - unresolvedAttribute.name, - candidates.collect { case attribute: AttributeReference => attribute } - ) + if (candidates.isEmpty) { + throwUnresolvedColumnError(unresolvedAttribute) } + if (candidates.length > 1) { + throwAmbiguousReferenceError(unresolvedAttribute) + } + candidates.head } private def throwUnresolvedColumnError(unresolvedAttribute: UnresolvedAttribute): Nothing = @@ -75,7 +86,13 @@ case class NameTarget( unresolvedAttribute.name, proposal = orderSuggestedIdentifiersBySimilarity( unresolvedAttribute.name, - candidates = allAttributes.map(attribute => attribute.qualifier :+ attribute.name) + candidates = output.map(attribute => attribute.qualifier :+ attribute.name) ) ) + + private def throwAmbiguousReferenceError(unresolvedAttribute: UnresolvedAttribute): Nothing = + throw QueryCompilationErrors.ambiguousReferenceError( + unresolvedAttribute.name, + candidates.collect { case attribute: AttributeReference => attribute } + ) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/PredicateResolver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/PredicateResolver.scala index d94559496d04e..1c4d8dd50113b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/PredicateResolver.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/PredicateResolver.scala @@ -42,13 +42,14 @@ class PredicateResolver( extends TreeNodeResolver[Predicate, Expression] with ResolvesExpressionChildren { - private val typeCoercionRules = if (conf.ansiEnabled) { - PredicateResolver.ANSI_TYPE_COERCION_RULES - } else { - PredicateResolver.TYPE_COERCION_RULES - } - private val typeCoercionResolver = - new TypeCoercionResolver(timezoneAwareExpressionResolver, typeCoercionRules) + private val typeCoercionTransformations: Seq[Expression => Expression] = + if (conf.ansiEnabled) { + PredicateResolver.ANSI_TYPE_COERCION_TRANSFORMATIONS + } else { + PredicateResolver.TYPE_COERCION_TRANSFORMATIONS + } + private val typeCoercionResolver: TypeCoercionResolver = + new TypeCoercionResolver(timezoneAwareExpressionResolver, typeCoercionTransformations) override def resolve(unresolvedPredicate: Predicate): Expression = { val predicateWithResolvedChildren = @@ -86,7 +87,7 @@ class PredicateResolver( object PredicateResolver { // Ordering in the list of type coercions should be in sync with the list in [[TypeCoercion]]. - private val TYPE_COERCION_RULES: Seq[Expression => Expression] = Seq( + private val TYPE_COERCION_TRANSFORMATIONS: Seq[Expression => Expression] = Seq( CollationTypeCoercion.apply, TypeCoercion.InTypeCoercion.apply, StringPromotionTypeCoercion.apply, @@ -99,7 +100,7 @@ object PredicateResolver { ) // Ordering in the list of type coercions should be in sync with the list in [[AnsiTypeCoercion]]. 
- private val ANSI_TYPE_COERCION_RULES: Seq[Expression => Expression] = Seq( + private val ANSI_TYPE_COERCION_TRANSFORMATIONS: Seq[Expression => Expression] = Seq( CollationTypeCoercion.apply, AnsiTypeCoercion.InTypeCoercion.apply, AnsiStringPromotionTypeCoercion.apply, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ProhibitedResolver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ProhibitedResolver.scala new file mode 100644 index 0000000000000..f6ba53ec3ee4b --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ProhibitedResolver.scala @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis.resolver + +import org.apache.spark.SparkException +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan + +/** + * This is a dummy [[LogicalPlanResolver]] whose [[resolve]] is not implemented and throws + * [[SparkException]]. + * + * It's used by the [[MetadataResolver]] to pass it as an argument to + * [[tryDelegateResolutionToExtensions]], because unresolved subtree resolution doesn't make sense + * during metadata resolution traversal. + */ +class ProhibitedResolver extends LogicalPlanResolver { + def resolve(plan: LogicalPlan): LogicalPlan = { + throw SparkException.internalError("Resolver cannot be used here") + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ProjectResolver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ProjectResolver.scala new file mode 100644 index 0000000000000..ec4caae1bd76c --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ProjectResolver.scala @@ -0,0 +1,164 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.spark.sql.catalyst.analysis.resolver
+
+import scala.collection.mutable.ArrayBuffer
+import scala.jdk.CollectionConverters._
+
+import org.apache.spark.sql.catalyst.analysis.{withPosition, AnalysisErrorAt}
+import org.apache.spark.sql.catalyst.expressions.{Alias, Expression, NamedExpression}
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project}
+import org.apache.spark.sql.internal.SQLConf
+
+/**
+ * Resolves an initially unresolved [[Project]] operator to either a resolved [[Project]] or
+ * [[Aggregate]] node, based on whether there are aggregate expressions in the project list. When
+ * LateralColumnAlias resolution is enabled, replaces the output operator with an appropriate
+ * operator structure using information from the scope. A detailed explanation can be found in
+ * the [[buildProjectWithResolvedLCAs]] method.
+ */
+class ProjectResolver(
+    operatorResolver: Resolver,
+    expressionResolver: ExpressionResolver,
+    scopes: NameScopeStack)
+  extends TreeNodeResolver[Project, LogicalPlan] {
+
+  private val isLcaEnabled = conf.getConf(SQLConf.LATERAL_COLUMN_ALIAS_IMPLICIT_ENABLED)
+
+  /**
+   * [[Project]] introduces a new scope to resolve its subtree and project list expressions.
+   * During the resolution we determine whether the output operator will be [[Aggregate]] or
+   * [[Project]] (based on the `hasAggregateExpressions` flag).
+   *
+   * If the output operator is [[Project]] and if lateral column alias resolution is enabled, we
+   * construct a multi-level [[Project]], created from all lateral column aliases and their
+   * dependencies. Finally, we place the original resolved project on top of this multi-level one.
+   *
+   * After the subtree and project-list expressions are resolved in the child scope, we overwrite
+   * the current scope with the resolved operator's output to expose new names to the parent
+   * operators.
+   */
+  override def resolve(unresolvedProject: Project): LogicalPlan = {
+    val (resolvedOperator, resolvedProjectList) = scopes.withNewScope {
+      val resolvedChild = operatorResolver.resolve(unresolvedProject.child)
+      val resolvedProjectList =
+        expressionResolver.resolveProjectList(unresolvedProject.projectList, unresolvedProject)
+      if (resolvedProjectList.hasAggregateExpressions) {
+        if (resolvedProjectList.hasLateralColumnAlias) {
+          // Disable LCA in Aggregates until fully supported.
+          throw new ExplicitlyUnsupportedResolverFeature("LateralColumnAlias in Aggregate")
+        }
+        val aggregate = Aggregate(
+          groupingExpressions = Seq.empty[Expression],
+          aggregateExpressions = resolvedProjectList.expressions,
+          child = resolvedChild,
+          hint = None
+        )
+        if (resolvedProjectList.hasAttributes) {
+          aggregate.failAnalysis(errorClass = "MISSING_GROUP_BY", messageParameters = Map.empty)
+        }
+        (aggregate, resolvedProjectList)
+      } else {
+        val projectWithLca = if (isLcaEnabled) {
+          buildProjectWithResolvedLCAs(resolvedChild, resolvedProjectList.expressions)
+        } else {
+          Project(resolvedProjectList.expressions, resolvedChild)
+        }
+        (projectWithLca, resolvedProjectList)
+      }
+    }
+
+    withPosition(unresolvedProject) {
+      scopes.overwriteTop(
+        resolvedProjectList.expressions.map(namedExpression => namedExpression.toAttribute)
+      )
+    }
+
+    resolvedOperator
+  }
+
+  /**
+   * Builds a multi-level [[Project]] with all lateral column aliases and their dependencies. First,
+   * from the top scope, we acquire dependency levels of all aliases. Dependency level is defined as
+   * the number of attributes that an attribute depends on in the lateral alias reference chain.
For + * example, in a query like: + * + * {{{ SELECT 0 AS a, 1 AS b, 2 AS c, b AS d, a AS e, d AS f, a AS g, g AS h, h AS i }}} + * + * Dependency levels will be as follows: + * + * level 0: a, b, c + * level 1: d, e, g + * level 2: f, h + * level 3: i + * + * Once we have dependency levels, we construct a multi-level [[Project]] in a following way: + * - There is exactly one [[Project]] node per level. + * - Project lists are compounded such that project lists on upper levels must contain all + * attributes from the below levels. + * - Project list on level 0 includes all attributes from the output of the operator below the + * original [[Project]]. + * - Original [[Project]] is placed on top of the multi-level [[Project]]. Any aliases that have + * been laterally referenced need to be replaced with only their names. This is because their + * full definitions ( `attr` as `name` ) have already been defined on lower levels. + * - If an attribute is never referenced, it does not show up in multi-level project lists, but + * instead only in the top-most [[Project]]. + * + * For previously given query, following above rules, resolved [[Project]] would look like: + * + * Project [a, b, 2 AS c, d, a AS e, d AS f, g, h, h AS i] + * +- Project [b, a, d, g, g AS h] + * +- Project [b, a, b AS d, a AS g] + * +- Project [1 AS b, 0 AS a] + * +- OneRowRelation + */ + private def buildProjectWithResolvedLCAs( + resolvedChild: LogicalPlan, + originalProjectList: Seq[NamedExpression]) = { + val aliasDependencyMap = scopes.top.lcaRegistry.getAliasDependencyLevels() + val (finalChildPlan, _) = aliasDependencyMap.asScala.foldLeft( + (resolvedChild, scopes.top.output.map(_.asInstanceOf[NamedExpression])) + ) { + case ((currentPlan, currentProjectList), availableAliases) => + val referencedAliases = new ArrayBuffer[Alias] + availableAliases.forEach( + alias => + if (scopes.top.lcaRegistry.isAttributeLaterallyReferenced(alias.toAttribute)) { + referencedAliases.append(alias) + } + ) + + if (referencedAliases.nonEmpty) { + val newProjectList = currentProjectList.map(_.toAttribute) ++ referencedAliases + (Project(newProjectList, currentPlan), newProjectList) + } else { + (currentPlan, currentProjectList) + } + } + + val finalProjectList = originalProjectList.map( + alias => + if (scopes.top.lcaRegistry.isAttributeLaterallyReferenced(alias.toAttribute)) { + alias.toAttribute + } else { + alias + } + ) + + Project(finalProjectList, finalChildPlan) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/RelationMetadataProvider.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/RelationMetadataProvider.scala index cf352842fd106..3bf15c51977f4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/RelationMetadataProvider.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/RelationMetadataProvider.scala @@ -47,6 +47,13 @@ trait RelationMetadataProvider extends LookupCatalog { */ protected val relationsWithResolvedMetadata: RelationsWithResolvedMetadata + /** + * Resolve metadata for the given `unresolvedPlan`. This method is called once per unresolved + * logical plan by the [[Resolver]] (for each SQL query/ DataFrame program and for each + * nested [[View]] operator). + */ + def resolve(unresolvedPlan: LogicalPlan): Unit + /** * Get the [[LogicalPlan]] with resolved metadata for the given [[UnresolvedRelation]]. 
* diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ResolutionValidator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ResolutionValidator.scala index 6c4de2e6e58d7..93034e931bb0e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ResolutionValidator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ResolutionValidator.scala @@ -17,9 +17,20 @@ package org.apache.spark.sql.catalyst.analysis.resolver -import org.apache.spark.sql.catalyst.analysis.{MultiInstanceRelation, ResolvedInlineTable} -import org.apache.spark.sql.catalyst.expressions.{AttributeReference} +import java.util.HashSet + +import org.apache.spark.sql.catalyst.analysis.{ + GetViewColumnByNameAndOrdinal, + MultiInstanceRelation, + ResolvedInlineTable, + SchemaBinding +} +import org.apache.spark.sql.catalyst.expressions.AttributeReference import org.apache.spark.sql.catalyst.plans.logical.{ + Aggregate, + CTERelationDef, + CTERelationRef, + Distinct, Filter, GlobalLimit, LocalLimit, @@ -27,10 +38,14 @@ import org.apache.spark.sql.catalyst.plans.logical.{ LogicalPlan, OneRowRelation, Project, - SubqueryAlias + SubqueryAlias, + Union, + View, + WithCTE } +import org.apache.spark.sql.catalyst.util.AUTO_GENERATED_ALIAS import org.apache.spark.sql.errors.QueryCompilationErrors -import org.apache.spark.sql.types.BooleanType +import org.apache.spark.sql.types.{BooleanType, DataType, MetadataBuilder, StructType} /** * The [[ResolutionValidator]] performs the validation work after the logical plan tree is @@ -43,6 +58,7 @@ class ResolutionValidator { private val expressionResolutionValidator = new ExpressionResolutionValidator(this) private[resolver] var attributeScopeStack = new AttributeScopeStack + private val cteRelationDefIds = new HashSet[Long] /** * Validate the resolved logical `plan` - assert invariants that should never be false no @@ -56,28 +72,81 @@ class ResolutionValidator { private def validate(operator: LogicalPlan): Unit = { operator match { + case withCte: WithCTE => + validateWith(withCte) + case cteRelationDef: CTERelationDef => + validateCteRelationDef(cteRelationDef) + case cteRelationRef: CTERelationRef => + validateCteRelationRef(cteRelationRef) + case aggregate: Aggregate => + validateAggregate(aggregate) case project: Project => validateProject(project) case filter: Filter => validateFilter(filter) case subqueryAlias: SubqueryAlias => validateSubqueryAlias(subqueryAlias) + case view: View => + validateView(view) case globalLimit: GlobalLimit => validateGlobalLimit(globalLimit) case localLimit: LocalLimit => validateLocalLimit(localLimit) + case distinct: Distinct => + validateDistinct(distinct) case inlineTable: ResolvedInlineTable => validateInlineTable(inlineTable) case localRelation: LocalRelation => validateRelation(localRelation) case oneRowRelation: OneRowRelation => validateRelation(oneRowRelation) + case union: Union => + validateUnion(union) // [[LogicalRelation]], [[HiveTableRelation]] and other specific relations can't be imported // because of a potential circular dependency, so we match a generic Catalyst // [[MultiInstanceRelation]] instead. 
case multiInstanceRelation: MultiInstanceRelation => validateRelation(multiInstanceRelation) } + ExpressionIdAssigner.assertOutputsHaveNoConflictingExpressionIds( + operator.children.map(_.output) + ) + } + + private def validateWith(withCte: WithCTE): Unit = { + for (cteDef <- withCte.cteDefs) { + validate(cteDef) + } + validate(withCte.plan) + } + + private def validateCteRelationDef(cteRelationDef: CTERelationDef): Unit = { + validate(cteRelationDef.child) + + assert( + !cteRelationDefIds.contains(cteRelationDef.id), + s"Duplicate CTE relation def ID: $cteRelationDef" + ) + + cteRelationDefIds.add(cteRelationDef.id) + } + + private def validateCteRelationRef(cteRelationRef: CTERelationRef): Unit = { + assert( + cteRelationDefIds.contains(cteRelationRef.cteId), + s"CTE relation ref ID is not known: $cteRelationRef" + ) + + handleOperatorOutput(cteRelationRef) + } + + private def validateAggregate(aggregate: Aggregate): Unit = { + attributeScopeStack.withNewScope { + validate(aggregate.child) + expressionResolutionValidator.validateProjectList(aggregate.aggregateExpressions) + } + + handleOperatorOutput(aggregate) } private def validateProject(project: Project): Unit = { @@ -105,6 +174,29 @@ class ResolutionValidator { handleOperatorOutput(subqueryAlias) } + private def validateView(view: View): Unit = { + validate(view.child) + + if (view.desc.viewSchemaMode == SchemaBinding) { + assert( + schemaWithExplicitMetadata(view.schema) == schemaWithExplicitMetadata(view.desc.schema), + "View output schema does not match the view description schema. " + + s"View schema: ${view.schema}, description schema: ${view.desc.schema}" + ) + } + view.child match { + case project: Project => + assert( + !project.projectList + .exists(expression => expression.isInstanceOf[GetViewColumnByNameAndOrdinal]), + "Resolved Project operator under a view cannot contain GetViewColumnByNameAndOrdinal" + ) + case _ => + } + + handleOperatorOutput(view) + } + private def validateGlobalLimit(globalLimit: GlobalLimit): Unit = { validate(globalLimit.child) expressionResolutionValidator.validate(globalLimit.limitExpr) @@ -115,6 +207,10 @@ class ResolutionValidator { expressionResolutionValidator.validate(localLimit.limitExpr) } + private def validateDistinct(distinct: Distinct): Unit = { + validate(distinct.child) + } + private def validateInlineTable(inlineTable: ResolvedInlineTable): Unit = { inlineTable.rows.foreach(row => { row.foreach(expression => { @@ -129,6 +225,29 @@ class ResolutionValidator { handleOperatorOutput(relation) } + private def validateUnion(union: Union): Unit = { + union.children.foreach(validate) + + assert(union.children.length > 1, "Union operator has to have at least 2 children") + val firstChildOutput = union.children.head.output + for (child <- union.children.tail) { + val childOutput = child.output + assert( + childOutput.length == firstChildOutput.length, + s"Unexpected output length for Union child $child" + ) + childOutput.zip(firstChildOutput).foreach { + case (current, first) => + assert( + DataType.equalsStructurally(current.dataType, first.dataType, ignoreNullability = true), + s"Unexpected type of Union child attribute $current for $child" + ) + } + } + + handleOperatorOutput(union) + } + private def handleOperatorOutput(operator: LogicalPlan): Unit = { attributeScopeStack.overwriteTop(operator.output) @@ -142,6 +261,16 @@ class ResolutionValidator { }) } + private def schemaWithExplicitMetadata(schema: StructType): StructType = { + StructType(schema.map { structField => + val 
metadataBuilder = new MetadataBuilder().withMetadata(structField.metadata) + metadataBuilder.remove(AUTO_GENERATED_ALIAS) + structField.copy( + metadata = metadataBuilder.build() + ) + }) + } + private def wrapErrors[R](plan: LogicalPlan)(body: => R): Unit = { try { body diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ResolvedProjectList.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ResolvedProjectList.scala new file mode 100644 index 0000000000000..7f3ca796a4949 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ResolvedProjectList.scala @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis.resolver + +import org.apache.spark.sql.catalyst.expressions.NamedExpression + +/** + * Structure used to return results of the resolved project list. + * - expressions: The resolved expressions. It is resolved using the + * `resolveExpressionTreeInOperator`. + * - hasAggregateExpressions: True if the resolved project list contains any aggregate + * expressions. + * - hasAttributes: True if the resolved project list contains any attributes that are not under + * an aggregate expression. + * - hasLateralColumnAlias: True if the resolved project list contains any lateral column aliases. 
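+ *
+ * A purely illustrative sketch (the `a`, `sumOfC` and alias values below are hypothetical and
+ * not part of this patch): for a project list such as `a, a + 1 AS b, sum(c)` the structure
+ * would roughly look like:
+ *
+ * {{{
+ *   ResolvedProjectList(
+ *     expressions = Seq(a, Alias(Add(a, Literal(1)), "b")(), Alias(sumOfC, "sum(c)")()),
+ *     hasAggregateExpressions = true, // `sum(c)` is an aggregate expression
+ *     hasAttributes = true, // `a` is referenced outside of any aggregate expression
+ *     hasLateralColumnAlias = false
+ *   )
+ * }}}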
+ */ +case class ResolvedProjectList( + expressions: Seq[NamedExpression], + hasAggregateExpressions: Boolean, + hasAttributes: Boolean, + hasLateralColumnAlias: Boolean) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/Resolver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/Resolver.scala index 37b875abaade6..b17db704b3f42 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/Resolver.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/Resolver.scala @@ -17,29 +17,50 @@ package org.apache.spark.sql.catalyst.analysis.resolver +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.SparkException import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.EvaluateUnresolvedInlineTable import org.apache.spark.sql.catalyst.analysis.{ withPosition, + AnalysisErrorAt, FunctionResolution, - NamedRelation, + MultiInstanceRelation, RelationResolution, ResolvedInlineTable, UnresolvedInlineTable, - UnresolvedRelation + UnresolvedRelation, + UnresolvedSubqueryColumnAliases +} +import org.apache.spark.sql.catalyst.expressions.{ + Alias, + Attribute, + AttributeSet, + Expression, + NamedExpression } import org.apache.spark.sql.catalyst.plans.logical.{ + AnalysisHelper, + CTERelationDef, + CTERelationRef, + Distinct, Filter, GlobalLimit, + LeafNode, LocalLimit, LocalRelation, LogicalPlan, OneRowRelation, Project, - SubqueryAlias + SubqueryAlias, + Union, + UnresolvedWith, + View, + WithCTE } -import org.apache.spark.sql.connector.catalog.CatalogManager -import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryErrorsBase} +import org.apache.spark.sql.connector.catalog.{CatalogManager} +import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.types.BooleanType /** @@ -58,24 +79,29 @@ import org.apache.spark.sql.types.BooleanType * re-create it for every new analysis run. * * @param catalogManager [[CatalogManager]] for relation and identifier resolution. - * @param extensions A list of [[ResolverExtension]] that can resolve external operators. + * @param extensions A list of [[ResolverExtension]]s that will be used to resolve external + * operators. + * @param metadataResolverExtensions A list of [[ResolverExtension]]s that will be used to resolve + * relation operators in [[MetadataResolver]]. 
*/ class Resolver( catalogManager: CatalogManager, override val extensions: Seq[ResolverExtension] = Seq.empty, metadataResolverExtensions: Seq[ResolverExtension] = Seq.empty) - extends TreeNodeResolver[LogicalPlan, LogicalPlan] - with QueryErrorsBase + extends LogicalPlanResolver with ResolvesOperatorChildren - with TracksResolvedNodes[LogicalPlan] with DelegatesResolutionToExtensions { private val scopes = new NameScopeStack + private val cteRegistry = new CteRegistry private val planLogger = new PlanLogger private val relationResolution = Resolver.createRelationResolution(catalogManager) private val functionResolution = new FunctionResolution(catalogManager, relationResolution) private val expressionResolver = new ExpressionResolver(this, scopes, functionResolution, planLogger) - private val limitExpressionResolver = new LimitExpressionResolver(expressionResolver) + private val expressionIdAssigner = expressionResolver.getExpressionIdAssigner + private val projectResolver = new ProjectResolver(this, expressionResolver, scopes) + private val viewResolver = new ViewResolver(resolver = this, catalogManager = catalogManager) + private val unionResolver = new UnionResolver(this, expressionResolver, scopes) /** * [[relationMetadataProvider]] is used to resolve metadata for relations. It's initialized with @@ -85,7 +111,7 @@ class Resolver( * [[resolveRelation]] to get the plan with resolved metadata (for example, a [[View]] or an * [[UnresolvedCatalogRelation]]) based on the [[UnresolvedRelation]]. * - * If the [[AnalyzerBridgeState]] is provided, we reset this provider to the + * If the [[AnalyzerBridgeState]] is provided, we reset this provider to the * [[BridgedRelationMetadataProvider]] and later stick to it forever without resorting to the * actual blocking metadata resolution. */ private val relationMetadataProvider: RelationMetadataProvider = new MetadataResolver( catalogManager, relationResolution, metadataResolverExtensions ) + /** + * Get the [[CteRegistry]] which is a single instance per query resolution. + */ + def getCteRegistry: CteRegistry = { + cteRegistry + } + /** * This method is an analysis entry point. It resolves the metadata and invokes [[resolve]], * which does most of the analysis work.
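 *
 * A minimal usage sketch (the `catalogManager` and `unresolvedPlan` values are assumed to be
 * available from the session and are not defined in this patch):
 *
 * {{{
 *   val resolver = new Resolver(catalogManager)
 *   val resolvedPlan = resolver.lookupMetadataAndResolve(unresolvedPlan, None)
 * }}}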
@@ -115,11 +148,7 @@ class Resolver( relationMetadataProvider } - relationMetadataProvider match { - case metadataResolver: MetadataResolver => - metadataResolver.resolve(unresolvedPlan) - case _ => - } + relationMetadataProvider.resolve(unresolvedPlan) resolve(unresolvedPlan) } @@ -139,63 +168,108 @@ class Resolver( override def resolve(unresolvedPlan: LogicalPlan): LogicalPlan = { planLogger.logPlanResolutionEvent(unresolvedPlan, "Unresolved plan") - throwIfNodeWasResolvedEarlier(unresolvedPlan) - val resolvedPlan = unresolvedPlan match { + case unresolvedWith: UnresolvedWith => + resolveWith(unresolvedWith) case unresolvedProject: Project => - resolveProject(unresolvedProject) + projectResolver.resolve(unresolvedProject) case unresolvedFilter: Filter => resolveFilter(unresolvedFilter) + case unresolvedSubqueryColumnAliases: UnresolvedSubqueryColumnAliases => + resolveSubqueryColumnAliases(unresolvedSubqueryColumnAliases) case unresolvedSubqueryAlias: SubqueryAlias => resolveSubqueryAlias(unresolvedSubqueryAlias) + case unresolvedView: View => + viewResolver.resolve(unresolvedView) case unresolvedGlobalLimit: GlobalLimit => resolveGlobalLimit(unresolvedGlobalLimit) case unresolvedLocalLimit: LocalLimit => resolveLocalLimit(unresolvedLocalLimit) + case unresolvedDistinct: Distinct => + resolveDistinct(unresolvedDistinct) case unresolvedRelation: UnresolvedRelation => resolveRelation(unresolvedRelation) + case unresolvedCteRelationDef: CTERelationDef => + resolveCteRelationDef(unresolvedCteRelationDef) case unresolvedInlineTable: UnresolvedInlineTable => resolveInlineTable(unresolvedInlineTable) + case unresolvedUnion: Union => + unionResolver.resolve(unresolvedUnion) // See the reason why we have to match both [[LocalRelation]] and [[ResolvedInlineTable]] // in the [[resolveInlineTable]] scaladoc case resolvedInlineTable: ResolvedInlineTable => - updateNameScopeWithPlanOutput(resolvedInlineTable) + handleLeafOperator(resolvedInlineTable) case localRelation: LocalRelation => - updateNameScopeWithPlanOutput(localRelation) + handleLeafOperator(localRelation) case unresolvedOneRowRelation: OneRowRelation => - updateNameScopeWithPlanOutput(unresolvedOneRowRelation) + handleLeafOperator(unresolvedOneRowRelation) case _ => tryDelegateResolutionToExtension(unresolvedPlan).getOrElse { handleUnmatchedOperator(unresolvedPlan) } } - markNodeAsResolved(resolvedPlan) + if (resolvedPlan.children.nonEmpty) { + val missingInput = resolvedPlan.missingInput + if (missingInput.nonEmpty) { + withPosition(unresolvedPlan) { + throwMissingAttributesError(resolvedPlan, missingInput) + } + } + } planLogger.logPlanResolution(unresolvedPlan, resolvedPlan) - resolvedPlan + preservePlanIdTag(unresolvedPlan, resolvedPlan) } /** - * [[Project]] introduces a new scope to resolve its subtree and project list expressions. After - * those are resolved in the child scope we overwrite current scope with resolved [[Project]]'s - * output to expose new names to the parent operators. + * [[UnresolvedWith]] contains a list of unresolved CTE definitions, which are represented by + * (name, subquery) pairs, and an actual child query. First we resolve the CTE definitions + * strictly in their declaration order, so they become available for other lower definitions + * (lower both in this WITH clause list and in the plan tree) and for the [[UnresolvedWith]] child + * query. After that, we resolve the child query. 
Optionally, if this is a root [[CteScope]], + * we return a [[WithCTE]] operator with all the resolved [[CTERelationDef]]s merged together + * from this scope and child scopes. Otherwise, we return the resolved child query so that + * the resolved [[CTERelationDefs]] propagate up and will be merged together later. + * + * See [[CteScope]] scaladoc for all the details on how CTEs are resolved. */ - private def resolveProject(unresolvedProject: Project): LogicalPlan = { - val resolvedProject = scopes.withNewScope { - val resolvedChild = resolve(unresolvedProject.child) - val resolvedProjectList = - expressionResolver.resolveProjectList(unresolvedProject.projectList) - Project(resolvedProjectList, resolvedChild) + private def resolveWith(unresolvedWith: UnresolvedWith): LogicalPlan = { + val childOutputs = new ArrayBuffer[Seq[Attribute]] + + unresolvedWith.cteRelations.map { cteRelation => + val (cteName, ctePlan) = cteRelation + + val resolvedCtePlan = scopes.withNewScope { + expressionIdAssigner.withNewMapping() { + cteRegistry.withNewScope() { + val resolvedCtePlan = resolve(ctePlan) + + childOutputs.append(scopes.top.output) + + resolvedCtePlan + } + } + } + + cteRegistry.currentScope.registerCte(cteName, CTERelationDef(resolvedCtePlan)) } - withPosition(unresolvedProject) { - scopes.overwriteTop(resolvedProject.output) + val resolvedChild = cteRegistry.withNewScope() { + resolve(unresolvedWith.child) } - resolvedProject + childOutputs.append(scopes.top.output) + + ExpressionIdAssigner.assertOutputsHaveNoConflictingExpressionIds(childOutputs.toSeq) + + if (cteRegistry.currentScope.isRoot) { + WithCTE(resolvedChild, cteRegistry.currentScope.getKnownCtes) + } else { + resolvedChild + } } /** @@ -204,7 +278,9 @@ class Resolver( */ private def resolveFilter(unresolvedFilter: Filter): LogicalPlan = { val resolvedChild = resolve(unresolvedFilter.child) - val resolvedCondition = expressionResolver.resolve(unresolvedFilter.condition) + val resolvedCondition = + expressionResolver + .resolveExpressionTreeInOperator(unresolvedFilter.condition, unresolvedFilter) val resolvedFilter = Filter(resolvedCondition, resolvedChild) if (resolvedFilter.condition.dataType != BooleanType) { @@ -216,6 +292,38 @@ class Resolver( resolvedFilter } + /** + * [[UnresolvedSubqueryColumnAliases]] Creates a [[Project]] on top of a [[SubqueryAlias]] with + * the specified attribute aliases. Example: + * + * {{{ + * -- The output schema is [a: INT, b: INT] + * SELECT t.a, t.b FROM VALUES (1, 2) t (a, b); + * }}} + */ + private def resolveSubqueryColumnAliases( + unresolvedSubqueryColumnAliases: UnresolvedSubqueryColumnAliases): LogicalPlan = { + val resolvedChild = resolve(unresolvedSubqueryColumnAliases.child) + + if (unresolvedSubqueryColumnAliases.outputColumnNames.size != scopes.top.output.size) { + withPosition(unresolvedSubqueryColumnAliases) { + throw QueryCompilationErrors.aliasNumberNotMatchColumnNumberError( + unresolvedSubqueryColumnAliases.outputColumnNames.size, + scopes.top.output.size, + unresolvedSubqueryColumnAliases + ) + } + } + + val projectList = scopes.top.output.zip(unresolvedSubqueryColumnAliases.outputColumnNames).map { + case (attr, columnName) => expressionIdAssigner.mapExpression(Alias(attr, columnName)()) + } + + overwriteTopScope(unresolvedSubqueryColumnAliases, projectList.map(_.toAttribute)) + + Project(projectList = projectList, child = resolvedChild) + } + /** * [[SubqueryAlias]] has a single child and an identifier. 
We need to resolve the child and update * the scope with the output, since upper expressions can reference [[SubqueryAlias]]es output by * @@ -223,10 +331,14 @@ class Resolver( */ private def resolveSubqueryAlias(unresolvedSubqueryAlias: SubqueryAlias): LogicalPlan = { val resolvedSubqueryAlias = - SubqueryAlias(unresolvedSubqueryAlias.identifier, resolve(unresolvedSubqueryAlias.child)) - withPosition(unresolvedSubqueryAlias) { - scopes.overwriteTop(unresolvedSubqueryAlias.alias, resolvedSubqueryAlias.output) - } + unresolvedSubqueryAlias.copy(child = resolve(unresolvedSubqueryAlias.child)) + + val qualifier = resolvedSubqueryAlias.identifier.qualifier :+ resolvedSubqueryAlias.alias + overwriteTopScope( + unresolvedSubqueryAlias, + scopes.top.output.map(attribute => attribute.withQualifier(qualifier)) + ) + resolvedSubqueryAlias } @@ -238,7 +350,10 @@ class Resolver( val resolvedChild = resolve(unresolvedGlobalLimit.child) val resolvedLimitExpr = withPosition(unresolvedGlobalLimit) { - limitExpressionResolver.resolve(unresolvedGlobalLimit.limitExpr) + expressionResolver.resolveLimitExpression( + unresolvedGlobalLimit.limitExpr, + unresolvedGlobalLimit + ) } GlobalLimit(resolvedLimitExpr, resolvedChild) @@ -252,12 +367,22 @@ class Resolver( val resolvedChild = resolve(unresolvedLocalLimit.child) val resolvedLimitExpr = withPosition(unresolvedLocalLimit) { - limitExpressionResolver.resolve(unresolvedLocalLimit.limitExpr) + expressionResolver.resolveLimitExpression( + unresolvedLocalLimit.limitExpr, + unresolvedLocalLimit + ) } LocalLimit(resolvedLimitExpr, resolvedChild) } + /** + * [[Distinct]] operator doesn't require any special resolution. + */ + private def resolveDistinct(unresolvedDistinct: Distinct): LogicalPlan = { + withResolvedChildren(unresolvedDistinct, resolve) + } + /** * [[UnresolvedRelation]] was previously looked up by the [[MetadataResolver]] and now we need to: * - Get the specific relation with metadata from `relationsWithResolvedMetadata`, like @@ -265,23 +390,51 @@ class Resolver( * - Resolve it further, usually using extensions, like [[DataSourceResolver]] */ private def resolveRelation(unresolvedRelation: UnresolvedRelation): LogicalPlan = { - relationMetadataProvider.getRelationWithResolvedMetadata(unresolvedRelation) match { - case Some(relationWithResolvedMetadata) => - planLogger.logPlanResolutionEvent( - relationWithResolvedMetadata, - "Relation metadata retrieved" - ) - - withPosition(unresolvedRelation) { - resolve(relationWithResolvedMetadata) + withPosition(unresolvedRelation) { + viewResolver.withSourceUnresolvedRelation(unresolvedRelation) { + val maybeResolvedRelation = cteRegistry.resolveCteName(unresolvedRelation.name).orElse { + relationMetadataProvider.getRelationWithResolvedMetadata(unresolvedRelation) } - case None => - withPosition(unresolvedRelation) { - unresolvedRelation.tableNotFound(unresolvedRelation.multipartIdentifier) + + val resolvedRelation = maybeResolvedRelation match { + case Some(cteRelationDef: CTERelationDef) => + planLogger.logPlanResolutionEvent(cteRelationDef, "CTE definition resolved") + + SubqueryAlias(identifier = unresolvedRelation.name, child = cteRelationDef) + case Some(relationsWithResolvedMetadata) => + planLogger.logPlanResolutionEvent( + relationsWithResolvedMetadata, + "Relation metadata retrieved" + ) + + relationsWithResolvedMetadata + case None => + unresolvedRelation.tableNotFound(unresolvedRelation.multipartIdentifier) } + + resolve(resolvedRelation) + } } } + /** + * Resolve [[CTERelationDef]] by replacing it
with [[CTERelationRef]] with the same ID so that + * the Optimizer can make a decision whether to inline the definition or not. + * + * [[CTERelationDef.statsOpt]] is filled by the Optimizer. + */ + private def resolveCteRelationDef(unresolvedCteRelationDef: CTERelationDef): LogicalPlan = { + val cteRelationRef = CTERelationRef( + cteId = unresolvedCteRelationDef.id, + _resolved = true, + isStreaming = unresolvedCteRelationDef.isStreaming, + output = unresolvedCteRelationDef.output, + recursive = false + ) + + handleLeafOperator(cteRelationRef) + } + /** * [[UnresolvedInlineTable]] resolution requires all the rows to be resolved first. After that we * use [[EvaluateUnresolvedInlineTable]] and try to evaluate the row expressions if possible to @@ -296,7 +449,10 @@ class Resolver( val withResolvedExpressions = UnresolvedInlineTable( unresolvedInlineTable.names, unresolvedInlineTable.rows.map(row => { - row.map(expressionResolver.resolve(_)) + row.map(unresolvedElement => { + expressionResolver + .resolveExpressionTreeInOperator(unresolvedElement, unresolvedInlineTable) + }) }) ) @@ -309,27 +465,86 @@ class Resolver( } /** - * To finish the operator resolution we add its output to the current scope. This is usually - * done for relations. [[NamedRelation]]'s output should be added to the scope under its name. + * Preserve `PLAN_ID_TAG` which is used for DataFrame column resolution in Spark Connect. */ - private def updateNameScopeWithPlanOutput(relation: LogicalPlan): LogicalPlan = { - withPosition(relation) { - relation match { - case namedRelation: NamedRelation => - scopes.top.update(namedRelation.name, namedRelation.output) - case _ => - scopes.top += relation.output - } + private def preservePlanIdTag( + unresolvedOperator: LogicalPlan, + resolvedOperator: LogicalPlan): LogicalPlan = { + unresolvedOperator.getTagValue(LogicalPlan.PLAN_ID_TAG) match { + case Some(planIdTag) => + resolvedOperator.setTagValue(LogicalPlan.PLAN_ID_TAG, planIdTag) + case None => } - relation + resolvedOperator } - override def tryDelegateResolutionToExtension( + private def tryDelegateResolutionToExtension( unresolvedOperator: LogicalPlan): Option[LogicalPlan] = { - val resolutionResult = super.tryDelegateResolutionToExtension(unresolvedOperator) - resolutionResult.map { resolvedOperator => - updateNameScopeWithPlanOutput(resolvedOperator) + val resolutionResult = super.tryDelegateResolutionToExtension(unresolvedOperator, this) + resolutionResult match { + case Some(leafOperator: LeafNode) => + Some(handleLeafOperator(leafOperator)) + case other => + other + } + } + + /** + * Leaf operators introduce original attributes to this operator subtree and need to be handled in + * a special way: + * - Initialize [[ExpressionIdAssigner]] mapping for this operator branch and reassign + * `leafOperator`'s output attribute IDs. We don't reassign expression IDs in the leftmost + * branch, see [[ExpressionIdAssigner]] class doc for more details. + * [[CTERelationRef]]'s output can always be reassigned. + * - Overwrite the current [[NameScope]] with remapped output attributes. It's OK to call + * `output` on a [[LeafNode]], because it's not recursive (this call fits the single-pass + * framework). 
+ */ + private def handleLeafOperator(leafOperator: LeafNode): LogicalPlan = { + val leafOperatorWithAssignedExpressionIds = leafOperator match { + case leafOperator + if expressionIdAssigner.isLeftmostBranch && !leafOperator.isInstanceOf[CTERelationRef] => + expressionIdAssigner.createMapping(newOutput = leafOperator.output) + leafOperator + + /** + * [[InMemoryRelation.statsOfPlanToCache]] is mutable and does not get copied during normal + * [[transformExpressionsUp]]. The easiest way to correctly copy it is via [[newInstance]] + * call. + * + * We match [[MultiInstanceRelation]] to avoid a cyclic import between [[catalyst]] and + * [[execution]]. + */ + case originalRelation: MultiInstanceRelation => + val newRelation = originalRelation.newInstance() + + expressionIdAssigner.createMapping( + newOutput = newRelation.output, + oldOutput = Some(originalRelation.output) + ) + + newRelation + case _ => + expressionIdAssigner.createMapping() + + AnalysisHelper.allowInvokingTransformsInAnalyzer { + leafOperator.transformExpressionsUp { + case expression: NamedExpression => + val newExpression = expressionIdAssigner.mapExpression(expression) + if (newExpression.eq(expression)) { + throw SparkException.internalError( + s"Leaf operator expression ID was not reassigned. Expression: $expression, " + + s"leaf operator: $leafOperator" + ) + } + newExpression + } + } } + + overwriteTopScope(leafOperator, leafOperatorWithAssignedExpressionIds.output) + + leafOperatorWithAssignedExpressionIds } /** @@ -356,11 +571,59 @@ class Resolver( throw new AnalysisException( errorClass = "DATATYPE_MISMATCH.FILTER_NOT_BOOLEAN", messageParameters = Map( - "sqlExpr" -> filter.expressions.map(toSQLExpr).mkString(","), + "sqlExpr" -> makeCommaSeparatedExpressionString(filter.expressions), "filter" -> toSQLExpr(filter.condition), "type" -> toSQLType(filter.condition.dataType) ) ) + + private def throwMissingAttributesError( + operator: LogicalPlan, + missingInput: AttributeSet): Nothing = { + val inputSet = operator.inputSet + + val inputAttributesByName = new IdentifierMap[Attribute] + for (attribute <- inputSet) { + inputAttributesByName.put(attribute.name, attribute) + } + + val attributesWithSameName = missingInput.filter { missingAttribute => + inputAttributesByName.contains(missingAttribute.name) + } + + if (attributesWithSameName.nonEmpty) { + operator.failAnalysis( + errorClass = "MISSING_ATTRIBUTES.RESOLVED_ATTRIBUTE_APPEAR_IN_OPERATION", + messageParameters = Map( + "missingAttributes" -> makeCommaSeparatedExpressionString(missingInput.toSeq), + "input" -> makeCommaSeparatedExpressionString(inputSet.toSeq), + "operator" -> operator.simpleString(conf.maxToStringFields), + "operation" -> makeCommaSeparatedExpressionString(attributesWithSameName.toSeq) + ) + ) + } else { + operator.failAnalysis( + errorClass = "MISSING_ATTRIBUTES.RESOLVED_ATTRIBUTE_MISSING_FROM_INPUT", + messageParameters = Map( + "missingAttributes" -> makeCommaSeparatedExpressionString(missingInput.toSeq), + "input" -> makeCommaSeparatedExpressionString(inputSet.toSeq), + "operator" -> operator.simpleString(conf.maxToStringFields) + ) + ) + } + } + + private def makeCommaSeparatedExpressionString(expressions: Seq[Expression]): String = { + expressions.map(toSQLExpr).mkString(", ") + } + + private def overwriteTopScope( + sourceUnresolvedOperator: LogicalPlan, + output: Seq[Attribute]): Unit = { + withPosition(sourceUnresolvedOperator) { + scopes.overwriteTop(output) + } + } } object Resolver { diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ResolverExtension.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ResolverExtension.scala index 8bed881ec97a1..4c4f41e16d860 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ResolverExtension.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ResolverExtension.scala @@ -23,10 +23,6 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan * The [[ResolverExtension]] is a main interface for single-pass analysis extensions in Catalyst. * External code that needs specific node types to be resolved has to implement this trait and * inject the implementation into the [[Analyzer.singlePassResolverExtensions]]. - * - * Note that resolver extensions are responsible for creating attribute references with IDs that - * are unique from any other subplans. This should be straightforward in most cases because - * creating new attribute references will assign [[NamedExpression.newExprId]] by default. */ trait ResolverExtension { @@ -35,9 +31,11 @@ trait ResolverExtension { * single-pass [[Resolver]] on all the configured extensions when it exhausted its match list * for the known node types. * - * Guarantees: - * - The implementation can rely on children being resolved - * - We commit to performing the partial function check only at most once per unresolved operator + * - The implementation can rely on children being resolved. + * - The implementation can introduce new unresolved subtrees, but has to invoke `resolver` on + * them. */ - def resolveOperator: PartialFunction[LogicalPlan, LogicalPlan] + def resolveOperator( + operator: LogicalPlan, + resolver: TreeNodeResolver[LogicalPlan, LogicalPlan]): Option[LogicalPlan] } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ResolverGuard.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ResolverGuard.scala index b3b3d4def602d..ccd7d90e19161 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ResolverGuard.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ResolverGuard.scala @@ -19,13 +19,15 @@ package org.apache.spark.sql.catalyst.analysis.resolver import org.apache.spark.sql.catalyst.SQLConfHelper import org.apache.spark.sql.catalyst.analysis.{ + GetViewColumnByNameAndOrdinal, ResolvedInlineTable, UnresolvedAlias, UnresolvedAttribute, UnresolvedFunction, UnresolvedInlineTable, UnresolvedRelation, - UnresolvedStar + UnresolvedStar, + UnresolvedSubqueryColumnAliases } import org.apache.spark.sql.catalyst.expressions.{ Alias, @@ -40,6 +42,7 @@ import org.apache.spark.sql.catalyst.expressions.{ SubqueryExpression } import org.apache.spark.sql.catalyst.plans.logical.{ + Distinct, Filter, GlobalLimit, LocalLimit, @@ -47,9 +50,13 @@ import org.apache.spark.sql.catalyst.plans.logical.{ LogicalPlan, OneRowRelation, Project, - SubqueryAlias + SubqueryAlias, + Union, + UnresolvedWith, + View } import org.apache.spark.sql.connector.catalog.CatalogManager +import org.apache.spark.sql.internal.{LegacyBehaviorPolicy, SQLConf} import org.apache.spark.sql.internal.SQLConf.HiveCaseSensitiveInferenceMode /** @@ -72,16 +79,24 @@ class ResolverGuard(catalogManager: CatalogManager) extends SQLConfHelper { * their children. For unimplemented ones, return false. 
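 *
 * A hedged usage sketch (assuming the guard's public entry point is an `apply` check, as it is
 * exercised from [[HybridAnalyzer]]; that method is not shown in this hunk):
 *
 * {{{
 *   val resolverGuard = new ResolverGuard(catalogManager)
 *   if (resolverGuard.apply(unresolvedPlan)) {
 *     // The plan only contains supported operators, expressions and confs, so the
 *     // single-pass Resolver can be attempted.
 *   }
 * }}}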
*/ private def checkOperator(operator: LogicalPlan): Boolean = operator match { + case unresolvedWith: UnresolvedWith => + checkUnresolvedWith(unresolvedWith) case project: Project => checkProject(project) case filter: Filter => checkFilter(filter) + case unresolvedSubqueryColumnAliases: UnresolvedSubqueryColumnAliases => + checkUnresolvedSubqueryColumnAliases(unresolvedSubqueryColumnAliases) case subqueryAlias: SubqueryAlias => checkSubqueryAlias(subqueryAlias) case globalLimit: GlobalLimit => checkGlobalLimit(globalLimit) case localLimit: LocalLimit => checkLocalLimit(localLimit) + case distinct: Distinct => + checkDistinct(distinct) + case view: View => + checkView(view) case unresolvedRelation: UnresolvedRelation => checkUnresolvedRelation(unresolvedRelation) case unresolvedInlineTable: UnresolvedInlineTable => @@ -92,6 +107,8 @@ class ResolverGuard(catalogManager: CatalogManager) extends SQLConfHelper { checkLocalRelation(localRelation) case oneRowRelation: OneRowRelation => checkOneRowRelation(oneRowRelation) + case union: Union => + checkUnion(union) case _ => false } @@ -126,11 +143,20 @@ class ResolverGuard(catalogManager: CatalogManager) extends SQLConfHelper { checkCreateNamedStruct(createNamedStruct) case unresolvedFunction: UnresolvedFunction => checkUnresolvedFunction(unresolvedFunction) + case getViewColumnByNameAndOrdinal: GetViewColumnByNameAndOrdinal => + checkGetViewColumnBynameAndOrdinal(getViewColumnByNameAndOrdinal) case _ => false } } + private def checkUnresolvedWith(unresolvedWith: UnresolvedWith) = { + !unresolvedWith.allowRecursion && unresolvedWith.cteRelations.forall { + case (cteName, ctePlan) => + checkOperator(ctePlan) + } && checkOperator(unresolvedWith.child) + } + private def checkProject(project: Project) = { checkOperator(project.child) && project.projectList.forall(checkExpression) } @@ -138,8 +164,12 @@ class ResolverGuard(catalogManager: CatalogManager) extends SQLConfHelper { private def checkFilter(unresolvedFilter: Filter) = checkOperator(unresolvedFilter.child) && checkExpression(unresolvedFilter.condition) + private def checkUnresolvedSubqueryColumnAliases( + unresolvedSubqueryColumnAliases: UnresolvedSubqueryColumnAliases) = + checkOperator(unresolvedSubqueryColumnAliases.child) + private def checkSubqueryAlias(subqueryAlias: SubqueryAlias) = - subqueryAlias.identifier.qualifier.isEmpty && checkOperator(subqueryAlias.child) + checkOperator(subqueryAlias.child) private def checkGlobalLimit(globalLimit: GlobalLimit) = checkOperator(globalLimit.child) && checkExpression(globalLimit.limitExpr) @@ -147,6 +177,11 @@ class ResolverGuard(catalogManager: CatalogManager) extends SQLConfHelper { private def checkLocalLimit(localLimit: LocalLimit) = checkOperator(localLimit.child) && checkExpression(localLimit.limitExpr) + private def checkDistinct(distinct: Distinct) = + checkOperator(distinct.child) + + private def checkView(view: View) = checkOperator(view.child) + private def checkUnresolvedInlineTable(unresolvedInlineTable: UnresolvedInlineTable) = unresolvedInlineTable.rows.forall(_.forall(checkExpression)) @@ -160,6 +195,9 @@ class ResolverGuard(catalogManager: CatalogManager) extends SQLConfHelper { private def checkLocalRelation(localRelation: LocalRelation) = localRelation.output.forall(checkExpression) + private def checkUnion(union: Union) = + !union.byName && !union.allowMissingCol && union.children.forall(checkOperator) + private def checkOneRowRelation(oneRowRelation: OneRowRelation) = true private def checkAlias(alias: Alias) = 
checkExpression(alias.child) @@ -180,7 +218,8 @@ class ResolverGuard(catalogManager: CatalogManager) extends SQLConfHelper { checkExpression(unresolvedAlias.child) private def checkUnresolvedAttribute(unresolvedAttribute: UnresolvedAttribute) = - !ResolverGuard.UNSUPPORTED_ATTRIBUTE_NAMES.contains(unresolvedAttribute.nameParts.head) + !ResolverGuard.UNSUPPORTED_ATTRIBUTE_NAMES.contains(unresolvedAttribute.nameParts.head) && + !unresolvedAttribute.getTagValue(LogicalPlan.PLAN_ID_TAG).isDefined private def checkUnresolvedPredicate(unresolvedPredicate: Predicate) = { unresolvedPredicate match { @@ -197,17 +236,25 @@ class ResolverGuard(catalogManager: CatalogManager) extends SQLConfHelper { } private def checkUnresolvedFunction(unresolvedFunction: UnresolvedFunction) = - ResolverGuard.SUPPORTED_FUNCTION_NAMES.contains( + !ResolverGuard.UNSUPPORTED_FUNCTION_NAMES.contains( unresolvedFunction.nameParts.head - ) && unresolvedFunction.children.forall(checkExpression) + ) && + unresolvedFunction.children.forall(checkExpression) private def checkLiteral(literal: Literal) = true + private def checkGetViewColumnBynameAndOrdinal( + getViewColumnByNameAndOrdinal: GetViewColumnByNameAndOrdinal) = true + private def checkConfValues() = // Case sensitive analysis is not supported. !conf.caseSensitiveAnalysis && // Case-sensitive inference is not supported for Hive table schema. - conf.caseSensitiveInferenceMode == HiveCaseSensitiveInferenceMode.NEVER_INFER + conf.caseSensitiveInferenceMode == HiveCaseSensitiveInferenceMode.NEVER_INFER && + // Legacy CTE resolution modes are not supported. + !conf.getConf(SQLConf.LEGACY_INLINE_CTE_IN_COMMANDS) && + LegacyBehaviorPolicy.withName(conf.getConf(SQLConf.LEGACY_CTE_PRECEDENCE_POLICY)) == + LegacyBehaviorPolicy.CORRECTED private def checkVariables() = catalogManager.tempVariableManager.isEmpty } @@ -237,47 +284,39 @@ object ResolverGuard { map } - /** - * Most of the functions are not supported, but we allow some explicitly supported ones. - */ - private val SUPPORTED_FUNCTION_NAMES = { + private val UNSUPPORTED_FUNCTION_NAMES = { val map = new IdentifierMap[Unit]() - map += ("array", ()) - // map += ("array_agg", ()) - until aggregate expressions are supported - map += ("array_append", ()) - map += ("array_compact", ()) - map += ("array_contains", ()) - map += ("array_distinct", ()) - map += ("array_except", ()) - map += ("array_insert", ()) - map += ("array_intersect", ()) - map += ("array_join", ()) - map += ("array_max", ()) - map += ("array_min", ()) - map += ("array_position", ()) - map += ("array_prepend", ()) - map += ("array_remove", ()) - map += ("array_repeat", ()) - map += ("array_size", ()) - // map += ("array_sort", ()) - until lambda functions are supported - map += ("array_union", ()) - map += ("arrays_overlap", ()) - map += ("arrays_zip", ()) - map += ("coalesce", ()) - map += ("if", ()) - map += ("map", ()) - map += ("map_concat", ()) - map += ("map_contains_key", ()) - map += ("map_entries", ()) - // map += ("map_filter", ()) - until lambda functions are supported - map += ("map_from_arrays", ()) - map += ("map_from_entries", ()) - map += ("map_keys", ()) - map += ("map_values", ()) - // map += ("map_zip_with", ()) - until lambda functions are supported - map += ("named_struct", ()) - map += ("sort_array", ()) - map += ("str_to_map", ()) - map + // Non-deterministic functions are not supported. 
+ map += ("current_user", ()) + map += ("rand", ()) + map += ("randn", ()) + map += ("random", ()) + map += ("randstr", ()) + map += ("session_user", ()) + map += ("uniform", ()) + map += ("user", ()) + map += ("uuid", ()) + // Functions that require lambda support. + map += ("array_sort", ()) + map += ("transform", ()) + // Functions that require generator support. + map += ("explode", ()) + map += ("explode_outer", ()) + map += ("inline", ()) + map += ("inline_outer", ()) + // Functions that require session/time window resolution. + map += ("session_window", ()) + map += ("window", ()) + map += ("window_time", ()) + // Functions that are not resolved properly. + map += ("collate", ()) + map += ("json_tuple", ()) + map += ("schema_of_unstructured_agg", ()) + map += ("shuffle", ()) + // Functions that produce wrong schemas/plans because of alias assignment. + map += ("from_json", ()) + map += ("schema_of_json", ()) + // Function for which we don't handle exceptions properly. + map += ("schema_of_xml", ()) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ResolverRunner.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ResolverRunner.scala new file mode 100644 index 0000000000000..3a77e0269d59c --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ResolverRunner.scala @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis.resolver + +import org.apache.spark.sql.catalyst.SQLConfHelper +import org.apache.spark.sql.catalyst.analysis.CleanupAliases +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.internal.SQLConf + +/** + * Wrapper class for [[Resolver]] and single-pass resolution. This class encapsulates single-pass + * resolution and post-processing of resolved plan. This post-processing is necessary in order to + * either fully resolve the plan or stay compatible with the fixed-point analyzer. + */ +class ResolverRunner( + resolver: Resolver, + extendedResolutionChecks: Seq[LogicalPlan => Unit] = Seq.empty + ) extends SQLConfHelper { + + private val resolutionPostProcessingExecutor = new RuleExecutor[LogicalPlan] { + override def batches: Seq[Batch] = Seq( + Batch("Post-process", Once, CleanupAliases) + ) + } + + /** + * Entry point for the resolver. This method performs following 3 steps: + * - Resolves the plan in a bottom-up, single-pass manner. + * - Validates the result of single-pass resolution. + * - Applies necessary post-processing rules. 
+ */ + def resolve( + plan: LogicalPlan, + analyzerBridgeState: Option[AnalyzerBridgeState] = None): LogicalPlan = { + val resolvedPlan = resolver.lookupMetadataAndResolve(plan, analyzerBridgeState) + if (conf.getConf(SQLConf.ANALYZER_SINGLE_PASS_RESOLVER_VALIDATION_ENABLED)) { + val validator = new ResolutionValidator + validator.validatePlan(resolvedPlan) + } + finishResolution(resolvedPlan) + } + + /** + * This method performs necessary post-processing rules that aren't suitable for single-pass + * resolver. We apply these rules after the single-pass has finished resolution to stay + * compatible with fixed-point analyzer. + */ + private def finishResolution(plan: LogicalPlan): LogicalPlan = { + val planWithPostProcessing = resolutionPostProcessingExecutor.execute(plan) + + for (rule <- extendedResolutionChecks) { + rule(planWithPostProcessing) + } + planWithPostProcessing + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/TimeAddResolver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/TimeAddResolver.scala index bf27f64598723..9f04addb47998 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/TimeAddResolver.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/TimeAddResolver.scala @@ -31,17 +31,17 @@ import org.apache.spark.sql.catalyst.expressions.{Expression, TimeAdd} class TimeAddResolver( expressionResolver: ExpressionResolver, timezoneAwareExpressionResolver: TimezoneAwareExpressionResolver) - extends TreeNodeResolver[TimeAdd, Expression] - with ResolvesExpressionChildren { + extends TreeNodeResolver[TimeAdd, Expression] + with ResolvesExpressionChildren { - private val typeCoercionRules: Seq[Expression => Expression] = + private val typeCoercionTransformations: Seq[Expression => Expression] = if (conf.ansiEnabled) { - TimeAddResolver.ANSI_TYPE_COERCION_RULES + TimeAddResolver.ANSI_TYPE_COERCION_TRANSFORMATIONS } else { - TimeAddResolver.TYPE_COERCION_RULES + TimeAddResolver.TYPE_COERCION_TRANSFORMATIONS } private val typeCoercionResolver: TypeCoercionResolver = - new TypeCoercionResolver(timezoneAwareExpressionResolver, typeCoercionRules) + new TypeCoercionResolver(timezoneAwareExpressionResolver, typeCoercionTransformations) override def resolve(unresolvedTimeAdd: TimeAdd): Expression = { val timeAddWithResolvedChildren: TimeAdd = @@ -57,14 +57,14 @@ class TimeAddResolver( object TimeAddResolver { // Ordering in the list of type coercions should be in sync with the list in [[TypeCoercion]]. - private val TYPE_COERCION_RULES: Seq[Expression => Expression] = Seq( + private val TYPE_COERCION_TRANSFORMATIONS: Seq[Expression => Expression] = Seq( StringPromotionTypeCoercion.apply, TypeCoercion.ImplicitTypeCoercion.apply, TypeCoercion.DateTimeOperationsTypeCoercion.apply ) // Ordering in the list of type coercions should be in sync with the list in [[AnsiTypeCoercion]]. 
- private val ANSI_TYPE_COERCION_RULES: Seq[Expression => Expression] = Seq( + private val ANSI_TYPE_COERCION_TRANSFORMATIONS: Seq[Expression => Expression] = Seq( AnsiStringPromotionTypeCoercion.apply, AnsiTypeCoercion.ImplicitTypeCoercion.apply, AnsiTypeCoercion.AnsiDateTimeOperationsTypeCoercion.apply diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/TimezoneAwareExpressionResolver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/TimezoneAwareExpressionResolver.scala index a45e9e41cbfb1..5ba08c0c3edb3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/TimezoneAwareExpressionResolver.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/TimezoneAwareExpressionResolver.scala @@ -55,11 +55,14 @@ class TimezoneAwareExpressionResolver(expressionResolver: TreeNodeResolver[Expre * @param timeZoneId The timezone ID to apply. * @return A new [[TimeZoneAwareExpression]] with the specified timezone and original tags. */ - def withResolvedTimezoneCopyTags(expression: Expression, timeZoneId: String): Expression = { - val withTimeZone = withResolvedTimezone(expression, timeZoneId) - withTimeZone.copyTagsFrom(expression) - withTimeZone - } + def withResolvedTimezoneCopyTags(expression: Expression, timeZoneId: String): Expression = + expression match { + case timezoneExpression: TimeZoneAwareExpression if timezoneExpression.timeZoneId.isEmpty => + val withTimezone = timezoneExpression.withTimeZone(timeZoneId) + withTimezone.copyTagsFrom(timezoneExpression) + withTimezone + case other => other + } /** * Apply timezone to [[TimeZoneAwareExpression]] expressions. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/TracksResolvedNodes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/TracksResolvedNodes.scala deleted file mode 100644 index dd86bf843b4ec..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/TracksResolvedNodes.scala +++ /dev/null @@ -1,51 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.analysis.resolver - -import java.util.IdentityHashMap - -import org.apache.spark.SparkException -import org.apache.spark.sql.catalyst.SQLConfHelper -import org.apache.spark.sql.catalyst.trees.TreeNode -import org.apache.spark.sql.internal.SQLConf - -/** - * Trait for top-level resolvers that is used to keep track of resolved nodes and throw an error if - * a node is resolved more than once. This is only used in tests because of the memory overhead of - * using a set to track resolved nodes. 
- */ -trait TracksResolvedNodes[TreeNodeType <: TreeNode[TreeNodeType]] extends SQLConfHelper { - // Using Map because IdentityHashSet is not available in Scala - private val seenResolvedNodes = new IdentityHashMap[TreeNodeType, Unit] - - private val shouldTrackResolvedNodes = - conf.getConf(SQLConf.ANALYZER_SINGLE_PASS_TRACK_RESOLVED_NODES_ENABLED) - - protected def throwIfNodeWasResolvedEarlier(node: TreeNodeType): Unit = - if (shouldTrackResolvedNodes && seenResolvedNodes.containsKey(node)) { - throw SparkException.internalError( - s"Single-pass resolver attempted to resolve the same node more than once: $node" - ) - } - - protected def markNodeAsResolved(node: TreeNodeType): Unit = { - if (shouldTrackResolvedNodes) { - seenResolvedNodes.put(node, ()) - } - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/TreeNodeResolver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/TreeNodeResolver.scala index 5991585995cad..f09d475566de5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/TreeNodeResolver.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/TreeNodeResolver.scala @@ -19,12 +19,14 @@ package org.apache.spark.sql.catalyst.analysis.resolver import org.apache.spark.sql.catalyst.SQLConfHelper import org.apache.spark.sql.catalyst.trees.TreeNode +import org.apache.spark.sql.errors.QueryErrorsBase /** * Base class for [[TreeNode]] resolvers. All resolvers should extend this class with * specific [[UnresolvedNode]] and [[ResolvedNode]] types. */ trait TreeNodeResolver[UnresolvedNode <: TreeNode[_], ResolvedNode <: TreeNode[_]] - extends SQLConfHelper { + extends SQLConfHelper + with QueryErrorsBase { def resolve(unresolvedNode: UnresolvedNode): ResolvedNode } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/TypeCoercionResolver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/TypeCoercionResolver.scala index cf4c2ef0d7504..9c2e729e33e5c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/TypeCoercionResolver.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/TypeCoercionResolver.scala @@ -17,26 +17,70 @@ package org.apache.spark.sql.catalyst.analysis.resolver +import org.apache.spark.sql.catalyst.analysis.{ + AnsiGetDateFieldOperationsTypeCoercion, + AnsiStringPromotionTypeCoercion, + AnsiTypeCoercion, + BooleanEqualityTypeCoercion, + CollationTypeCoercion, + DecimalPrecisionTypeCoercion, + DivisionTypeCoercion, + IntegralDivisionTypeCoercion, + StackTypeCoercion, + StringLiteralTypeCoercion, + StringPromotionTypeCoercion, + TypeCoercion +} import org.apache.spark.sql.catalyst.expressions.{Cast, Expression} /** * [[TypeCoercionResolver]] is used by other resolvers to uniformly apply type coercions to all - * expressions. [[TypeCoercionResolver]] takes in a sequence of type coercion transformations that - * should be applied to an expression and applies them in order. Finally, [[TypeCoercionResolver]] - * applies timezone to expression's children, as a child could be replaced with Cast(child, type), - * therefore [[Cast]] resolution is needed. Timezone is applied only on children that have been - * re-instantiated by [[TypeCoercionResolver]], because otherwise children have already been - * resolved. + * expressions. 
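+ *
+ * A construction sketch: with this change the transformation list is optional, so callers can
+ * either rely on the default (ANSI-aware) list or pass an explicit one, as [[TimeAddResolver]]
+ * and [[UnaryMinusResolver]] do (the value names below mirror those callers):
+ *
+ * {{{
+ *   // Use the default transformation list.
+ *   val genericTypeCoercionResolver = new TypeCoercionResolver(timezoneAwareExpressionResolver)
+ *
+ *   // Use an explicit transformation list.
+ *   val unaryMinusTypeCoercionResolver =
+ *     new TypeCoercionResolver(timezoneAwareExpressionResolver, typeCoercionTransformations)
+ * }}}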
*/ class TypeCoercionResolver( timezoneAwareExpressionResolver: TimezoneAwareExpressionResolver, - typeCoercionRules: Seq[Expression => Expression]) - extends TreeNodeResolver[Expression, Expression] { + typeCoercionTransformations: Seq[Expression => Expression] = Seq.empty) + extends TreeNodeResolver[Expression, Expression] { + private val typeCoercionTransformationsOrDefault = if (typeCoercionTransformations.nonEmpty) { + typeCoercionTransformations + } else { + if (conf.ansiEnabled) { + TypeCoercionResolver.ANSI_TYPE_COERCION_TRANSFORMATIONS + } else { + TypeCoercionResolver.TYPE_COERCION_TRANSFORMATIONS + } + } + + /** + * Resolves type coercion for expression by applying necessary transformations on the expression + * and its children. Because fixed-point sometimes resolves type coercion in multiple passes, we + * apply each provided transformation twice, cyclically, to ensure that types are resolved. For + * example in a query like: + * + * {{{ SELECT '1' + '1' }}} + * + * fixed-point analyzer requires two passes to resolve types. + */ override def resolve(expression: Expression): Expression = { + val withTypeCoercedOnce = applyTypeCoercion(expression) + // This is a hack necessary because fixed-point analyzer sometimes requires multiple passes to + // resolve type coercion. Instead, in single pass, we apply type coercion twice on the same + // node in order to ensure that types are resolved. + applyTypeCoercion(withTypeCoercedOnce) + } + + /** + * Takes in a sequence of type coercion transformations that should be applied to an expression + * and applies them in order. Finally, [[TypeCoercionResolver]] applies timezone to expression's + * children, as a child could be replaced with Cast(child, type), therefore [[Cast]] resolution + * is needed. Timezone is applied only on children that have been re-instantiated, because + * otherwise children are already resolved. + */ + private def applyTypeCoercion(expression: Expression) = { val oldChildren = expression.children - val withTypeCoercion = typeCoercionRules.foldLeft(expression) { + val withTypeCoercion = typeCoercionTransformationsOrDefault.foldLeft(expression) { case (expr, rule) => rule.apply(expr) } @@ -48,3 +92,49 @@ class TypeCoercionResolver( withTypeCoercion.withNewChildren(newChildren) } } + +object TypeCoercionResolver { + + // Ordering in the list of type coercions should be in sync with the list in [[TypeCoercion]]. + private val TYPE_COERCION_TRANSFORMATIONS: Seq[Expression => Expression] = Seq( + CollationTypeCoercion.apply, + TypeCoercion.InTypeCoercion.apply, + StringPromotionTypeCoercion.apply, + DecimalPrecisionTypeCoercion.apply, + BooleanEqualityTypeCoercion.apply, + TypeCoercion.FunctionArgumentTypeCoercion.apply, + TypeCoercion.ConcatTypeCoercion.apply, + TypeCoercion.MapZipWithTypeCoercion.apply, + TypeCoercion.EltTypeCoercion.apply, + TypeCoercion.CaseWhenTypeCoercion.apply, + TypeCoercion.IfTypeCoercion.apply, + StackTypeCoercion.apply, + DivisionTypeCoercion.apply, + IntegralDivisionTypeCoercion.apply, + TypeCoercion.ImplicitTypeCoercion.apply, + TypeCoercion.DateTimeOperationsTypeCoercion.apply, + TypeCoercion.WindowFrameTypeCoercion.apply, + StringLiteralTypeCoercion.apply + ) + + // Ordering in the list of type coercions should be in sync with the list in [[AnsiTypeCoercion]]. 
+ private val ANSI_TYPE_COERCION_TRANSFORMATIONS: Seq[Expression => Expression] = Seq( + CollationTypeCoercion.apply, + AnsiTypeCoercion.InTypeCoercion.apply, + AnsiStringPromotionTypeCoercion.apply, + DecimalPrecisionTypeCoercion.apply, + AnsiTypeCoercion.FunctionArgumentTypeCoercion.apply, + AnsiTypeCoercion.ConcatTypeCoercion.apply, + AnsiTypeCoercion.MapZipWithTypeCoercion.apply, + AnsiTypeCoercion.EltTypeCoercion.apply, + AnsiTypeCoercion.CaseWhenTypeCoercion.apply, + AnsiTypeCoercion.IfTypeCoercion.apply, + StackTypeCoercion.apply, + DivisionTypeCoercion.apply, + IntegralDivisionTypeCoercion.apply, + AnsiTypeCoercion.ImplicitTypeCoercion.apply, + AnsiTypeCoercion.AnsiDateTimeOperationsTypeCoercion.apply, + AnsiTypeCoercion.WindowFrameTypeCoercion.apply, + AnsiGetDateFieldOperationsTypeCoercion.apply + ) +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/UnaryMinusResolver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/UnaryMinusResolver.scala index 739d7cf43c183..04089512b31b2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/UnaryMinusResolver.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/UnaryMinusResolver.scala @@ -26,17 +26,17 @@ import org.apache.spark.sql.catalyst.expressions.{Expression, UnaryMinus} class UnaryMinusResolver( expressionResolver: ExpressionResolver, timezoneAwareExpressionResolver: TimezoneAwareExpressionResolver) - extends TreeNodeResolver[UnaryMinus, Expression] - with ResolvesExpressionChildren { + extends TreeNodeResolver[UnaryMinus, Expression] + with ResolvesExpressionChildren { - private val typeCoercionRules: Seq[Expression => Expression] = + private val typeCoercionTransformations: Seq[Expression => Expression] = if (conf.ansiEnabled) { - UnaryMinusResolver.ANSI_TYPE_COERCION_RULES + UnaryMinusResolver.ANSI_TYPE_COERCION_TRANSFORMATIONS } else { - UnaryMinusResolver.TYPE_COERCION_RULES + UnaryMinusResolver.TYPE_COERCION_TRANSFORMATIONS } private val typeCoercionResolver: TypeCoercionResolver = - new TypeCoercionResolver(timezoneAwareExpressionResolver, typeCoercionRules) + new TypeCoercionResolver(timezoneAwareExpressionResolver, typeCoercionTransformations) override def resolve(unresolvedUnaryMinus: UnaryMinus): Expression = { val unaryMinusWithResolvedChildren: UnaryMinus = @@ -47,13 +47,13 @@ class UnaryMinusResolver( object UnaryMinusResolver { // Ordering in the list of type coercions should be in sync with the list in [[TypeCoercion]]. - private val TYPE_COERCION_RULES: Seq[Expression => Expression] = Seq( + private val TYPE_COERCION_TRANSFORMATIONS: Seq[Expression => Expression] = Seq( TypeCoercion.ImplicitTypeCoercion.apply, TypeCoercion.DateTimeOperationsTypeCoercion.apply ) // Ordering in the list of type coercions should be in sync with the list in [[AnsiTypeCoercion]]. 
- private val ANSI_TYPE_COERCION_RULES: Seq[Expression => Expression] = Seq( + private val ANSI_TYPE_COERCION_TRANSFORMATIONS: Seq[Expression => Expression] = Seq( AnsiTypeCoercion.ImplicitTypeCoercion.apply, AnsiTypeCoercion.AnsiDateTimeOperationsTypeCoercion.apply ) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/UnionResolver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/UnionResolver.scala new file mode 100644 index 0000000000000..77cb45624d6e5 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/UnionResolver.scala @@ -0,0 +1,378 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis.resolver + +import java.util.HashSet + +import org.apache.spark.sql.catalyst.analysis.{ + withPosition, + AnsiTypeCoercion, + TypeCoercion, + TypeCoercionBase +} +import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Cast, ExprId} +import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Project, Union} +import org.apache.spark.sql.errors.QueryCompilationErrors +import org.apache.spark.sql.types.{DataType, MetadataBuilder} + +/** + * The [[UnionResolver]] performs [[Union]] operator resolution. This operator has 2+ + * children. Resolution involves checking and normalizing child output attributes + * (data types and nullability). + */ +class UnionResolver( + resolver: Resolver, + expressionResolver: ExpressionResolver, + scopes: NameScopeStack) + extends TreeNodeResolver[Union, Union] { + private val expressionIdAssigner = expressionResolver.getExpressionIdAssigner + private val typeCoercion: TypeCoercionBase = + if (conf.ansiEnabled) { + AnsiTypeCoercion + } else { + TypeCoercion + } + + /** + * Resolve the [[Union]] operator: + * - Retrieve old output and child outputs if the operator is already resolved. This is relevant + * for partially resolved subtrees from DataFrame programs. + * - Resolve each child in the context of a) New [[NameScope]] b) New [[ExpressionIdAssigner]] + * mapping. Collect child outputs to coerce them later. + * - Perform projection-based expression ID deduplication if required. This is a hack to stay + * compatible with fixed-point [[Analyzer]]. + * - Perform individual output deduplication to handle the distinct union case described in + * [[performIndividualOutputExpressionIdDeduplication]] scaladoc. + * - Validate that child outputs have the same length or throw "NUM_COLUMNS_MISMATCH" otherwise. + * - Compute widened data types for child output attributes using + * [[typeCoercion.findWiderTypeForTwo]] or throw "INCOMPATIBLE_COLUMN_TYPE" if coercion fails.
+ * - Add [[Project]] with [[Cast]] on children needing attribute data type widening. + * - Assert that coerced outputs don't have conflicting expression IDs. + * - Merge transformed outputs: For each column, merge child attributes' types using + * [[StructType.unionLikeMerge]]. Mark column as nullable if any child attribute is. + * - Store merged output in current [[NameScope]]. + * - Create a new mapping in [[ExpressionIdAssigner]] using the coerced and validated outputs. + * - Return the resolved [[Union]] with new children. + */ + override def resolve(unresolvedUnion: Union): Union = { + val (oldOutput, oldChildOutputs) = if (unresolvedUnion.resolved) { + (Some(unresolvedUnion.output), Some(unresolvedUnion.children.map(_.output))) + } else { + (None, None) + } + + val (resolvedChildren, childOutputs) = unresolvedUnion.children.zipWithIndex.map { + case (unresolvedChild, childIndex) => + scopes.withNewScope { + expressionIdAssigner.withNewMapping(isLeftmostChild = (childIndex == 0)) { + val resolvedChild = resolver.resolve(unresolvedChild) + (resolvedChild, scopes.top.output) + } + } + }.unzip + + val (projectBasedDeduplicatedChildren, projectBasedDeduplicatedChildOutputs) = + performProjectionBasedExpressionIdDeduplication( + resolvedChildren, + childOutputs, + oldChildOutputs + ) + val (deduplicatedChildren, deduplicatedChildOutputs) = + performIndividualOutputExpressionIdDeduplication( + projectBasedDeduplicatedChildren, + projectBasedDeduplicatedChildOutputs + ) + + val (newChildren, newChildOutputs) = if (needToCoerceChildOutputs(deduplicatedChildOutputs)) { + coerceChildOutputs( + unresolvedUnion, + deduplicatedChildren, + deduplicatedChildOutputs, + validateAndDeduceTypes(unresolvedUnion, deduplicatedChildOutputs) + ) + } else { + (deduplicatedChildren, deduplicatedChildOutputs) + } + + ExpressionIdAssigner.assertOutputsHaveNoConflictingExpressionIds(newChildOutputs) + + withPosition(unresolvedUnion) { + scopes.overwriteTop(Union.mergeChildOutputs(newChildOutputs)) + } + + expressionIdAssigner.createMapping(scopes.top.output, oldOutput) + + unresolvedUnion.copy(children = newChildren) + } + + /** + * Fixed-point [[Analyzer]] uses [[DeduplicateRelations]] rule to handle duplicate expression IDs + * in multi-child operator outputs. For [[Union]]s it uses a "projection-based deduplication", + * i.e. places another [[Project]] operator with new [[Alias]]es on the right child if duplicate + * expression IDs detected. New [[Alias]] "covers" the original attribute with new expression ID. + * This is done for all child operators except [[LeafNode]]s. + * + * We don't need this operation in single-pass [[Resolver]], since we have + * [[ExpressionIdAssigner]] for expression ID deduplication, but perform it nevertheless to stay + * compatible with fixed-point [[Analyzer]]. Since new outputs are already deduplicated by + * [[ExpressionIdAssigner]], we check the _old_ outputs for duplicates and place a [[Project]] + * only if old outputs are available (i.e. we are dealing with a resolved subtree from + * DataFrame program). 
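+ *
+ * A minimal sketch of the DataFrame case this targets (the table name `t1` is hypothetical):
+ * {{{
+ *   val df = spark.table("t1")
+ *   // Both Union children are the same resolved subtree, so their old outputs share
+ *   // expression IDs and the right (non-leaf) child is wrapped in a Project with new Aliases.
+ *   df.union(df)
+ * }}}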
+ */ + private def performProjectionBasedExpressionIdDeduplication( + children: Seq[LogicalPlan], + childOutputs: Seq[Seq[Attribute]], + oldChildOutputs: Option[Seq[Seq[Attribute]]] + ): (Seq[LogicalPlan], Seq[Seq[Attribute]]) = { + oldChildOutputs match { + case Some(oldChildOutputs) => + val oldExpressionIds = new HashSet[ExprId] + + children + .zip(childOutputs) + .zip(oldChildOutputs) + .map { + case ((child: LeafNode, output), _) => + (child, output) + + case ((child, output), oldOutput) => + val oldOutputExpressionIds = new HashSet[ExprId] + + val hasConflicting = oldOutput.exists { oldAttribute => + oldOutputExpressionIds.add(oldAttribute.exprId) + oldExpressionIds.contains(oldAttribute.exprId) + } + + if (hasConflicting) { + val newExpressions = output.map { attribute => + Alias(attribute, attribute.name)() + } + ( + Project(projectList = newExpressions, child = child), + newExpressions.map(_.toAttribute) + ) + } else { + oldExpressionIds.addAll(oldOutputExpressionIds) + + (child, output) + } + } + .unzip + case _ => + (children, childOutputs) + } + } + + /** + * Deduplicate expression IDs at the scope of each individual child output. This is necessary to + * handle the following case: + * + * {{{ + * -- The correct answer is (1, 1), (1, 2). Without deduplication it would be (1, 1), because + * -- aggregation would be done only based on the first column. + * SELECT + * a, a + * FROM + * VALUES (1, 1), (1, 2) AS t1 (a, b) + * UNION + * SELECT + * a, b + * FROM + * VALUES (1, 1), (1, 2) AS t2 (a, b) + * }}} + * + * Putting [[Alias]] introduces a new expression ID for the attribute duplicates in the output. We + * also add `__is_duplicate` metadata so that [[AttributeSeq.getCandidatesForResolution]] doesn't + * produce conficting candidates when resolving names in the upper [[Project]] - this is + * technically still the same attribute. + * + * Probably there's a better way to do that, but we want to stay compatible with the fixed-point + * [[Analyzer]]. + * + * See SPARK-37865 for more details. + */ + private def performIndividualOutputExpressionIdDeduplication( + children: Seq[LogicalPlan], + childOutputs: Seq[Seq[Attribute]] + ): (Seq[LogicalPlan], Seq[Seq[Attribute]]) = { + children + .zip(childOutputs) + .map { + case (child, childOutput) => + var outputChanged = false + + val expressionIds = new HashSet[ExprId] + val newOutput = childOutput.map { attribute => + if (expressionIds.contains(attribute.exprId)) { + outputChanged = true + + val newMetadata = new MetadataBuilder() + .withMetadata(attribute.metadata) + .putNull("__is_duplicate") + .build() + Alias(attribute, attribute.name)(explicitMetadata = Some(newMetadata)) + } else { + expressionIds.add(attribute.exprId) + + attribute + } + } + + if (outputChanged) { + (Project(projectList = newOutput, child = child), newOutput.map(_.toAttribute)) + } else { + (child, childOutput) + } + } + .unzip + } + + /** + * Check if we need to coerce child output attributes to wider types. We need to do this if: + * - Output length differs between children. We will throw an appropriate error later during type + * coercion with more diagnostics. + * - Output data types differ between children. We don't care about nullability for type coercion, + * it will be correctly assigned later by [[Union.mergeChildOutputs]]. 
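+ *
+ * A small illustration (a sketch, not taken from the tests): coercion is needed for
+ * {{{
+ *   SELECT col1 FROM VALUES (1)
+ *   UNION ALL
+ *   SELECT CAST(col1 AS BIGINT) FROM VALUES (1)
+ * }}}
+ * because INT and BIGINT differ structurally, while outputs that differ only in nullability
+ * don't trigger coercion.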
+ */ + private def needToCoerceChildOutputs(childOutputs: Seq[Seq[Attribute]]): Boolean = { + val firstChildOutput = childOutputs.head + childOutputs.tail.exists { childOutput => + childOutput.length != firstChildOutput.length || + childOutput.zip(firstChildOutput).exists { + case (lhsAttribute, rhsAttribute) => + !DataType.equalsStructurally( + lhsAttribute.dataType, + rhsAttribute.dataType, + ignoreNullability = true + ) + } + } + } + + /** + * Returns a sequence of data types representing the widened data types for each column: + * - Validates that the number of columns in each child of the `Union` operator are equal. + * - Validates that the data types of columns can be widened to a common type. + * - Deduces the widened data types for each column. + */ + private def validateAndDeduceTypes( + unresolvedUnion: Union, + childOutputs: Seq[Seq[Attribute]]): Seq[DataType] = { + val childDataTypes = childOutputs.map(attributes => attributes.map(attr => attr.dataType)) + + val expectedNumColumns = childDataTypes.head.length + + childDataTypes.zipWithIndex.tail.foldLeft(childDataTypes.head) { + case (widenedTypes, (childColumnTypes, childIndex)) => + if (childColumnTypes.length != expectedNumColumns) { + throwNumColumnsMismatch( + expectedNumColumns, + childColumnTypes, + childIndex, + unresolvedUnion + ) + } + + widenedTypes.zip(childColumnTypes).zipWithIndex.map { + case ((widenedColumnType, columnTypeForCurrentRow), columnIndex) => + typeCoercion.findWiderTypeForTwo(widenedColumnType, columnTypeForCurrentRow).getOrElse { + throwIncompatibleColumnTypeError( + unresolvedUnion, + columnIndex, + childIndex + 1, + widenedColumnType, + columnTypeForCurrentRow + ) + } + } + } + } + + /** + * Coerce `childOutputs` to the previously calculated `widenedTypes`. If the data types for + * child output has changed, we have to add a [[Project]] operator with a [[Cast]] to the new + * type. + */ + private def coerceChildOutputs( + unresolvedUnion: Union, + children: Seq[LogicalPlan], + childOutputs: Seq[Seq[Attribute]], + widenedTypes: Seq[DataType]): (Seq[LogicalPlan], Seq[Seq[Attribute]]) = { + children + .zip(childOutputs) + .map { + case (child, output) => + var outputChanged = false + val newExpressions = output.zip(widenedTypes).map { + case (attribute, widenedType) => + /** + * Probably more correct way to compare data types here would be to call + * [[DataType.equalsStructurally]] but fixed-point [[Analyzer]] rule + * [[WidenSetOperationTypes]] uses `==`, so we do the same to stay compatible. 
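+ *
+ * For instance, struct types that differ only in nested field nullability are not `==`-equal,
+ * so a [[Cast]] to the widened type would be added here even though
+ * [[DataType.equalsStructurally]] with `ignoreNullability = true` treats them as the same type.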
+ */ + if (attribute.dataType == widenedType) { + attribute + } else { + outputChanged = true + Alias( + Cast(attribute, widenedType, Some(conf.sessionLocalTimeZone)), + attribute.name + )() + } + } + + if (outputChanged) { + (Project(newExpressions, child), newExpressions.map(_.toAttribute)) + } else { + (child, output) + } + } + .unzip + } + + private def throwNumColumnsMismatch( + expectedNumColumns: Int, + childColumnTypes: Seq[DataType], + columnIndex: Int, + unresolvedUnion: Union): Unit = { + throw QueryCompilationErrors.numColumnsMismatch( + "UNION", + expectedNumColumns, + columnIndex + 1, + childColumnTypes.length, + unresolvedUnion.origin + ) + } + + private def throwIncompatibleColumnTypeError( + unresolvedUnion: Union, + columnIndex: Int, + childIndex: Int, + widenedColumnType: DataType, + columnTypeForCurrentRow: DataType): Nothing = { + throw QueryCompilationErrors.incompatibleColumnTypeError( + "UNION", + columnIndex, + childIndex + 1, + widenedColumnType, + columnTypeForCurrentRow, + hint = "", + origin = unresolvedUnion.origin + ) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ViewResolver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ViewResolver.scala new file mode 100644 index 0000000000000..e7e9c5ec822ae --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ViewResolver.scala @@ -0,0 +1,145 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis.resolver + +import java.util.ArrayDeque + +import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, View} +import org.apache.spark.sql.connector.catalog.CatalogManager +import org.apache.spark.sql.errors.QueryCompilationErrors +import org.apache.spark.sql.internal.SQLConf + +/** + * The [[ViewResolver]] resolves view plans that were already reconstructed by [[SessionCatalog]] + * from the view text and view metadata (schema, configs). + */ +class ViewResolver(resolver: Resolver, catalogManager: CatalogManager) + extends TreeNodeResolver[View, View] { + private val cteRegistry = resolver.getCteRegistry + private val sourceUnresolvedRelationStack = new ArrayDeque[UnresolvedRelation] + private val viewResolutionContextStack = new ArrayDeque[ViewResolutionContext] + + /** + * This method preserves the resolved [[UnresolvedRelation]] for the further view resolution + * process. 
+ * + * [[sourceUnresolvedRelationStack]] is used to save the [[UnresolvedRelation]] after its + * resolution by [[Resolver.resolveRelation]], since [[View]] that was produced from this + * [[UnresolvedRelation]] needs [[UnresolvedRelation.options]] for its resolution. + * We pop from the [[sourceUnresolvedRelationStack]] after the `body` is executed. The stack is + * necessary, since [[withSourceUnresolvedRelation]] calls might be nested: + * [[UnresolvedRelation]] -> [[View]] + * ... + * [[UnresolvedRelation]] -> [[View]] + * ... + */ + def withSourceUnresolvedRelation(unresolvedRelation: UnresolvedRelation)( + body: => LogicalPlan): LogicalPlan = { + sourceUnresolvedRelationStack.push(unresolvedRelation) + try { + body + } finally { + sourceUnresolvedRelationStack.pop() + } + } + + /** + * Resolve the `unresolvedView` and its underlying plan. This method uses parent [[Resolver]] to + * resolve the view child. [[View]] resolution consists of the following steps: + * - Check if the single-pass resolver fully supports the view plan using the [[ResolverGuard]]. + * Throw [[ExplicitlyUnsupportedResolverFeature]] if the view plan is not supported. + * - Set the [[ViewResolutionContext]] for the view plan resolution. + * - Replace the necessary configurations in [[SQLConf]] with those that were stored with the + * view. + * - Resolve the view plan using parent [[Resolver]]. + * - Create a new [[CatalogTable]] description for the resolved view based on the original + * [[UnresolvedRelation.options]], original [[CatalogTable]] description and used + * [[ViewResolutionContext]]. + * - Return the resolved [[View]] with the resolved child and a new [[CatalogTable]] + * description. + */ + override def resolve(unresolvedView: View): View = { + checkResolverGuard(unresolvedView) + + val (resolvedChild, usedViewResolutionContext) = withViewResolutionContext(unresolvedView) { + SQLConf.withExistingConf( + View.effectiveSQLConf(unresolvedView.desc.viewSQLConfigs, unresolvedView.isTempView) + ) { + cteRegistry.withNewScope(isRoot = true, isOpaque = true) { + resolver.lookupMetadataAndResolve(unresolvedView.child) + } + } + } + + unresolvedView.copy(child = resolvedChild) + } + + /** + * Execute `body` with a fresh [[ViewResolutionContext]] specifically constructed for + * `unresolvedView` resolution. The context is popped back to the previous one after the + * `body` is executed, because views may be nested. + */ + private def withViewResolutionContext(unresolvedView: View)( + body: => LogicalPlan): (LogicalPlan, ViewResolutionContext) = { + val viewResolutionContext = if (viewResolutionContextStack.isEmpty()) { + ViewResolutionContext( + nestedViewDepth = 1, + maxNestedViewDepth = conf.maxNestedViewDepth + ) + } else { + val prevContext = viewResolutionContextStack.peek() + prevContext.copy(nestedViewDepth = prevContext.nestedViewDepth + 1) + } + viewResolutionContext.validate(unresolvedView) + + viewResolutionContextStack.push(viewResolutionContext) + try { + (body, viewResolutionContext) + } finally { + viewResolutionContextStack.pop() + } + } + + private def checkResolverGuard(unresolvedView: View): Unit = { + val resolverGuard = new ResolverGuard(catalogManager) + if (!resolverGuard(unresolvedView)) { + throw new ExplicitlyUnsupportedResolverFeature("View body is not supported") + } + } +} + +/** + * The [[ViewResolutionContext]] consists of data, which is specific to the specific view plan + * resolution. This data is also propagated to the subviews. 
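+ *
+ * A sketch of how the depth evolves for nested views (view and table names are hypothetical):
+ * {{{
+ *   CREATE VIEW v1 AS SELECT * FROM t1;
+ *   CREATE VIEW v2 AS SELECT * FROM v1;
+ *   -- Resolving `SELECT * FROM v2` creates a context with nestedViewDepth = 1 for v2 and
+ *   -- copies it with nestedViewDepth = 2 when v1 is resolved inside v2.
+ * }}}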
+ * + * @param nestedViewDepth Current nested view depth. Cannot exceed the `maxNestedViewDepth`. + * @param maxNestedViewDepth Maximum allowed nested view depth. Configured in the upper context + * based on [[SQLConf.MAX_NESTED_VIEW_DEPTH]]. + */ +case class ViewResolutionContext(nestedViewDepth: Int, maxNestedViewDepth: Int) { + def validate(unresolvedView: View): Unit = { + if (nestedViewDepth > maxNestedViewDepth) { + throw QueryCompilationErrors.viewDepthExceedsMaxResolutionDepthError( + unresolvedView.desc.identifier, + maxNestedViewDepth, + unresolvedView + ) + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 4907e7ee6276e..faa4da60a534b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -301,19 +301,6 @@ object SQLConf { .booleanConf .createWithDefault(true) - val ANALYZER_SINGLE_PASS_TRACK_RESOLVED_NODES_ENABLED = - buildConf("spark.sql.analyzer.singlePassResolver.trackResolvedNodes.enabled") - .internal() - .doc( - "When true, keep track of resolved nodes in order to assert that the single-pass " + - "invariant is never broken. While true, if a resolver attempts to resolve the same node " + - "twice, INTERNAL_ERROR exception is thrown. Used only for testing due to memory impact " + - "of storing each node in a HashSet." - ) - .version("4.0.0") - .booleanConf - .createWithDefault(false) - val ANALYZER_SINGLE_PASS_RESOLVER_RELATION_BRIDGING_ENABLED = buildConf("spark.sql.analyzer.singlePassResolver.relationBridging.enabled") .internal() diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/resolver/LimitExpressionResolverSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/resolver/LimitExpressionResolverSuite.scala index fdab4df379a71..c8b4db3e10aa2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/resolver/LimitExpressionResolverSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/resolver/LimitExpressionResolverSuite.scala @@ -19,18 +19,13 @@ package org.apache.spark.sql.catalyst.analysis.resolver import org.apache.spark.SparkFunSuite import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Cast, Expression, Literal} +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Cast, Literal} import org.apache.spark.sql.errors.QueryErrorsBase import org.apache.spark.sql.types.IntegerType class LimitExpressionResolverSuite extends SparkFunSuite with QueryErrorsBase { - private class IdentityExpressionResolver extends TreeNodeResolver[Expression, Expression] { - override def resolve(expression: Expression): Expression = expression - } - - private val expressionResolver = new IdentityExpressionResolver - private val limitExpressionResolver = new LimitExpressionResolver(expressionResolver) + private val limitExpressionResolver = new LimitExpressionResolver test("Basic LIMIT without errors") { val expr = Literal(42, IntegerType) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/resolver/ResolutionValidatorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/resolver/ResolutionValidatorSuite.scala index 922e94ea442b3..913de4b5a19f8 100644 --- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/resolver/ResolutionValidatorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/resolver/ResolutionValidatorSuite.scala @@ -48,7 +48,9 @@ class ResolutionValidatorSuite extends SparkFunSuite with SQLConfHelper { // [[LocalRelation]], but produces only [[ResolvedInlineTable]] and [[LocalRelation]], so // we omit one of them here. // See [[Resolver.resolveInlineTable]] scaladoc for more info. - "resolveResolvedInlineTable" + "resolveResolvedInlineTable", + // [[UnresolvedSubqueryColumnAliases]] turns into a [[Project]] + "resolveSubqueryColumnAliases" ) private val colInteger = AttributeReference(name = "colInteger", dataType = IntegerType)() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceResolver.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceResolver.scala index 4cd75736ea9eb..92e6fe465c856 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceResolver.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceResolver.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.datasources import org.apache.spark.sql.catalyst.analysis.resolver.{ ExplicitlyUnsupportedResolverFeature, + LogicalPlanResolver, ResolverExtension } import org.apache.spark.sql.catalyst.catalog.UnresolvedCatalogRelation @@ -38,18 +39,18 @@ class DataSourceResolver(sparkSession: SparkSession) extends ResolverExtension { /** * Resolve [[UnresolvedCatalogRelation]]: * - Reuse [[FindDataSourceTable]] code to resolve [[UnresolvedCatalogRelation]] - * - Create a new instance of [[LogicalRelation]] to regenerate the expression IDs + * - Return [[LogicalRelation]] if it's resolved * - Explicitly disallow [[StreamingRelation]] and [[StreamingRelationV2]] for now * - [[FileResolver]], which is a [[ResolverExtension]], introduces a new [[LogicalPlan]] node * whose resolution has to be handled here (further resolution of it doesn't need any specific * handling except adding its attributes to the scope).
*/ - override def resolveOperator: PartialFunction[LogicalPlan, LogicalPlan] = { + override def resolveOperator( + operator: LogicalPlan, + resolver: LogicalPlanResolver): Option[LogicalPlan] = operator match { case unresolvedCatalogRelation: UnresolvedCatalogRelation => val result = findDataSourceTable.resolveUnresolvedCatalogRelation(unresolvedCatalogRelation) - result match { - case logicalRelation: LogicalRelation => - logicalRelation.newInstance() + Some(result match { case streamingRelation: StreamingRelation => throw new ExplicitlyUnsupportedResolverFeature( s"unsupported operator: ${streamingRelation.getClass.getName}" @@ -60,8 +61,10 @@ class DataSourceResolver(sparkSession: SparkSession) extends ResolverExtension { ) case other => other - } + }) case logicalRelation: LogicalRelation => - logicalRelation.newInstance() + Some(logicalRelation) + case _ => + None } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileResolver.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileResolver.scala index 0728054625aa2..1a5c33f7dc756 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileResolver.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileResolver.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution.datasources -import org.apache.spark.sql.catalyst.analysis.resolver.ResolverExtension +import org.apache.spark.sql.catalyst.analysis.resolver.{LogicalPlanResolver, ResolverExtension} import org.apache.spark.sql.catalyst.plans.logical.{AnalysisHelper, LogicalPlan} import org.apache.spark.sql.classic.SparkSession @@ -48,8 +48,12 @@ class FileResolver(sparkSession: SparkSession) extends ResolverExtension { /** * Reuse [[ResolveSQLOnFile]] code to resolve [[UnresolvedRelation]] made out of file. */ - override def resolveOperator: PartialFunction[LogicalPlan, LogicalPlan] = { + override def resolveOperator( + operator: LogicalPlan, + resolver: LogicalPlanResolver): Option[LogicalPlan] = operator match { case UnresolvedRelationResolution(resolvedRelation) => - resolvedRelation + Some(resolvedRelation) + case _ => + None } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/analysis/resolver/AggregateExpressionResolverSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/analysis/resolver/AggregateExpressionResolverSuite.scala new file mode 100644 index 0000000000000..eb2a00dfa6c19 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/analysis/resolver/AggregateExpressionResolverSuite.scala @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.analysis.resolver + +import org.apache.spark.sql.{AnalysisException, QueryTest, Row} +import org.apache.spark.sql.catalyst.analysis.resolver.Resolver +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.plans.logical.LocalRelation +import org.apache.spark.sql.test.SharedSparkSession + +class AggregateExpressionResolverSuite extends QueryTest with SharedSparkSession { + private val table = LocalRelation.fromExternalRows( + Seq("a".attr.int), + Seq(Row(1)) + ) + + test("Valid aggregate expression") { + val resolver = createResolver() + val query = table.select($"count".function("a".attr)) + resolver.resolve(query) + } + + test("Unsupported parent operator") { + val resolver = createResolver() + val query = table.limit($"count".function("a".attr)) + checkErrorMatchPVals( + exception = intercept[AnalysisException] { + resolver.resolve(query) + }, + condition = "UNSUPPORTED_EXPR_FOR_OPERATOR", + parameters = Map( + "invalidExprSqls" -> """count\(a#\d+\)""" + ) + ) + } + + test("Nested aggregate expression") { + val resolver = createResolver() + val query = table.select($"count".function($"count".function("a".attr)).as("a")) + + checkError( + exception = intercept[AnalysisException] { + resolver.resolve(query) + }, + condition = "NESTED_AGGREGATE_FUNCTION", + parameters = Map.empty + ) + } + + private def createResolver(): Resolver = { + new Resolver( + catalogManager = spark.sessionState.catalogManager, + extensions = spark.sessionState.analyzer.singlePassResolverExtensions + ) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/analysis/resolver/ExplicitlyUnsupportedResolverFeatureSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/analysis/resolver/ExplicitlyUnsupportedResolverFeatureSuite.scala index 7fd7d570ecfc1..b7133b1e34e6a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/analysis/resolver/ExplicitlyUnsupportedResolverFeatureSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/analysis/resolver/ExplicitlyUnsupportedResolverFeatureSuite.scala @@ -18,10 +18,17 @@ package org.apache.spark.sql.analysis.resolver import org.apache.spark.sql.QueryTest -import org.apache.spark.sql.catalyst.analysis.resolver.Resolver +import org.apache.spark.sql.catalyst.analysis.resolver.{ + ExplicitlyUnsupportedResolverFeature, + Resolver +} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSparkSession class ExplicitlyUnsupportedResolverFeatureSuite extends QueryTest with SharedSparkSession { + import testImplicits._ + test("Unsupported table types") { withTable("csv_table") { spark.sql("CREATE TABLE csv_table (col1 INT) USING CSV;").collect() @@ -41,22 +48,6 @@ class ExplicitlyUnsupportedResolverFeatureSuite extends QueryTest with SharedSpa } } - test("Unsupported view types") { - withTable("src_table") { - spark.sql("CREATE TABLE src_table (col1 INT) USING PARQUET;").collect() - - withView("temporary_view") { - spark.sql("CREATE TEMPORARY VIEW temporary_view AS SELECT * FROM src_table;").collect() - checkResolution("SELECT * FROM temporary_view;") - } - - withView("persistent_view") { - spark.sql("CREATE VIEW persistent_view AS SELECT * FROM src_table;").collect() - checkResolution("SELECT * FROM persistent_view;") - } - } - } - test("Unsupported char type padding") { withTable("char_type_padding") { spark.sql(s"CREATE TABLE t1 (c1 CHAR(3), c2 STRING) 
USING PARQUET") @@ -64,28 +55,55 @@ class ExplicitlyUnsupportedResolverFeatureSuite extends QueryTest with SharedSpa } } - test("Unsupported lateral column alias") { - checkResolution("SELECT 1 AS a, a AS b") - checkResolution("SELECT sum(1), `sum(1)` + 1 AS a") + test("Unsupported star expansion") { + checkResolution("SELECT * FROM VALUES (1, 2) WHERE 3 IN (*)") + } + + test("LateralColumnAlias in Aggregate") { + checkResolution("SELECT 1 AS a, sum(col1) as sum1, a + sum(col2) FROM VALUES(1, 2)") + } + + test("Unsupported UDF") { + spark.sql( + "CREATE FUNCTION udf(x INT) RETURNS INT RETURN x" + ) + checkResolution("SELECT udf(1)") + } + + test("Unsupported lambda") { + checkResolution( + "SELECT array_sort(array(2, 1), (p1, p2) -> CASE WHEN p1 > p2 THEN 1 ELSE 0 END)" + ) + } + + test("Missing attribute propagation") { + val df = Seq(("test1", 0), ("test2", 1)).toDF("name", "id") + checkResolution( + df.select(df("name")).filter(df("id") === 0).queryExecution.logical, + shouldPass = false + ) } private def checkResolution(sqlText: String, shouldPass: Boolean = false): Unit = { + val unresolvedPlan = spark.sessionState.sqlParser.parsePlan(sqlText) + checkResolution(unresolvedPlan, shouldPass) + } + + private def checkResolution(plan: LogicalPlan, shouldPass: Boolean): Unit = { def noopWrapper(body: => Unit) = body val wrapper = if (shouldPass) { noopWrapper _ } else { - intercept[Throwable] _ + intercept[ExplicitlyUnsupportedResolverFeature] _ } - val unresolvedPlan = spark.sql(sqlText).queryExecution.logical - val resolver = new Resolver( spark.sessionState.catalogManager, extensions = spark.sessionState.analyzer.singlePassResolverExtensions ) wrapper { - resolver.lookupMetadataAndResolve(unresolvedPlan) + resolver.lookupMetadataAndResolve(plan) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/analysis/resolver/ExpressionIdAssignerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/analysis/resolver/ExpressionIdAssignerSuite.scala new file mode 100644 index 0000000000000..bbcfda96dad0a --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/analysis/resolver/ExpressionIdAssignerSuite.scala @@ -0,0 +1,818 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.analysis.resolver + +import java.util.IdentityHashMap + +import scala.collection.mutable.{ArrayBuffer, HashMap} + +import org.apache.spark.SparkException +import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.catalyst.analysis.resolver.{ExpressionIdAssigner, Resolver} +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.trees.TreeNodeTag +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.types._ + +class ExpressionIdAssignerSuite extends QueryTest with SharedSparkSession { + private val col1Integer = AttributeReference(name = "col1", dataType = IntegerType)() + private val col1IntegerAlias = Alias(col1Integer, "a")() + private val col2Integer = AttributeReference(name = "col2", dataType = IntegerType)() + private val col2IntegerAlias = Alias(col2Integer, "b")() + private val col3Integer = AttributeReference(name = "col3", dataType = IntegerType)() + + private val CONSTRAINTS_VALIDATED = TreeNodeTag[Boolean]("constraints_validated") + + test("Mapping is not created") { + val assigner = new ExpressionIdAssigner + + intercept[SparkException] { + assigner.mapExpression(col1Integer) + } + + assigner.withNewMapping() { + assigner.withNewMapping() { + intercept[SparkException] { + assigner.mapExpression(col1Integer) + } + } + } + } + + test("Mapping is created twice") { + val assigner = new ExpressionIdAssigner + + intercept[SparkException] { + assigner.createMapping() + assigner.createMapping() + } + + assigner.withNewMapping() { + assigner.createMapping() + + assigner.withNewMapping() { + assigner.createMapping() + + intercept[SparkException] { + assigner.createMapping() + } + } + + intercept[SparkException] { + assigner.createMapping() + } + } + } + + test("Create mapping with new output and old output with different length") { + val assigner = new ExpressionIdAssigner + + intercept[SparkException] { + assigner.createMapping( + newOutput = Seq(col1Integer.newInstance()), + oldOutput = Some(Seq(col1Integer, col2Integer)) + ) + } + } + + test("Left branch: Single AttributeReference") { + val assigner = new ExpressionIdAssigner + + assigner.createMapping() + + val col1IntegerMapped = assigner.mapExpression(col1Integer) + assert(col1IntegerMapped.isInstanceOf[AttributeReference]) + assert(col1IntegerMapped.exprId != col1Integer.exprId) + + val col1IntegerReferenced = assigner.mapExpression(col1Integer) + assert(col1IntegerReferenced.isInstanceOf[AttributeReference]) + assert(col1IntegerReferenced.exprId == col1IntegerMapped.exprId) + + val col1IntegerMappedReferenced = assigner.mapExpression(col1IntegerMapped) + assert(col1IntegerMappedReferenced.isInstanceOf[AttributeReference]) + assert(col1IntegerMappedReferenced.exprId == col1IntegerMapped.exprId) + } + + test("Right branch: Single AttributeReference") { + val assigner = new ExpressionIdAssigner + assigner.withNewMapping() { + assigner.createMapping() + + val col1IntegerMapped = assigner.mapExpression(col1Integer) + assert(col1IntegerMapped.isInstanceOf[AttributeReference]) + assert(col1IntegerMapped.exprId != col1Integer.exprId) + + val col1IntegerReferenced = assigner.mapExpression(col1Integer) + assert(col1IntegerReferenced.isInstanceOf[AttributeReference]) + 
assert(col1IntegerReferenced.exprId == col1IntegerMapped.exprId) + + val col1IntegerMappedReferenced = assigner.mapExpression(col1IntegerMapped) + assert(col1IntegerMappedReferenced.isInstanceOf[AttributeReference]) + assert(col1IntegerMappedReferenced.exprId == col1IntegerMapped.exprId) + } + } + + test("Left branch: Single Alias") { + val assigner = new ExpressionIdAssigner + + assigner.createMapping() + + val col1IntegerAliasMapped = assigner.mapExpression(col1IntegerAlias) + assert(col1IntegerAliasMapped.isInstanceOf[Alias]) + assert(col1IntegerAliasMapped.exprId == col1IntegerAlias.exprId) + + val col1IntegerAliasReferenced = assigner.mapExpression(col1IntegerAlias.toAttribute) + assert(col1IntegerAliasReferenced.isInstanceOf[AttributeReference]) + assert(col1IntegerAliasReferenced.exprId == col1IntegerAliasMapped.exprId) + + val col1IntegerAliasMappedReferenced = + assigner.mapExpression(col1IntegerAliasMapped.toAttribute) + assert(col1IntegerAliasMappedReferenced.isInstanceOf[AttributeReference]) + assert(col1IntegerAliasMappedReferenced.exprId == col1IntegerAliasMapped.exprId) + + val col1IntegerAliasMappedAgain = assigner.mapExpression(col1IntegerAlias) + assert(col1IntegerAliasMappedAgain.isInstanceOf[Alias]) + assert(col1IntegerAliasMappedAgain.exprId != col1IntegerAlias.exprId) + assert(col1IntegerAliasMappedAgain.exprId != col1IntegerAliasMapped.exprId) + } + + test("Right branch: Single Alias") { + val assigner = new ExpressionIdAssigner + assigner.withNewMapping() { + assigner.createMapping() + + val col1IntegerAliasMapped = assigner.mapExpression(col1IntegerAlias) + assert(col1IntegerAliasMapped.isInstanceOf[Alias]) + assert(col1IntegerAliasMapped.exprId != col1IntegerAlias.exprId) + + val col1IntegerAliasReferenced = assigner.mapExpression(col1IntegerAlias.toAttribute) + assert(col1IntegerAliasReferenced.isInstanceOf[AttributeReference]) + assert(col1IntegerAliasReferenced.exprId == col1IntegerAliasMapped.exprId) + + val col1IntegerAliasMappedReferenced = + assigner.mapExpression(col1IntegerAliasMapped.toAttribute) + assert(col1IntegerAliasMappedReferenced.isInstanceOf[AttributeReference]) + assert(col1IntegerAliasMappedReferenced.exprId == col1IntegerAliasMapped.exprId) + + val col1IntegerAliasMappedAgain = assigner.mapExpression(col1IntegerAlias) + assert(col1IntegerAliasMappedAgain.isInstanceOf[Alias]) + assert(col1IntegerAliasMappedAgain.exprId != col1IntegerAlias.exprId) + assert(col1IntegerAliasMappedAgain.exprId != col1IntegerAliasMapped.exprId) + } + } + + test("Left branch: Create mapping with new output") { + val assigner = new ExpressionIdAssigner + + assigner.createMapping(newOutput = Seq(col1Integer, col2Integer)) + + val col1IntegerReferenced = assigner.mapExpression(col1Integer) + assert(col1IntegerReferenced.isInstanceOf[AttributeReference]) + assert(col1IntegerReferenced.exprId == col1Integer.exprId) + + val col2IntegerReferenced = assigner.mapExpression(col2Integer) + assert(col2IntegerReferenced.isInstanceOf[AttributeReference]) + assert(col2IntegerReferenced.exprId == col2Integer.exprId) + + val col2IntegerAliasMapped = assigner.mapExpression(col2IntegerAlias) + assert(col2IntegerAliasMapped.isInstanceOf[Alias]) + assert(col2IntegerAliasMapped.exprId == col2IntegerAlias.exprId) + assert(col2IntegerAliasMapped.exprId != col2Integer.exprId) + + val col3IntegerMapped = assigner.mapExpression(col3Integer) + assert(col3IntegerMapped.isInstanceOf[AttributeReference]) + assert(col3IntegerMapped.exprId != col3Integer.exprId) + } + + test("Right branch: Create 
mapping with new output") { + val assigner = new ExpressionIdAssigner + assigner.withNewMapping() { + assigner.createMapping(newOutput = Seq(col1Integer, col2Integer)) + + val col1IntegerReferenced = assigner.mapExpression(col1Integer) + assert(col1IntegerReferenced.isInstanceOf[AttributeReference]) + assert(col1IntegerReferenced.exprId == col1Integer.exprId) + + val col2IntegerReferenced = assigner.mapExpression(col2Integer) + assert(col2IntegerReferenced.isInstanceOf[AttributeReference]) + assert(col2IntegerReferenced.exprId == col2Integer.exprId) + + val col2IntegerAliasMapped = assigner.mapExpression(col2IntegerAlias) + assert(col2IntegerAliasMapped.isInstanceOf[Alias]) + assert(col2IntegerAliasMapped.exprId != col2IntegerAlias.exprId) + assert(col2IntegerAliasMapped.exprId != col2Integer.exprId) + + val col3IntegerMapped = assigner.mapExpression(col3Integer) + assert(col3IntegerMapped.isInstanceOf[AttributeReference]) + assert(col3IntegerMapped.exprId != col3Integer.exprId) + } + } + + test("Left branch: Create mapping with new output and old output") { + val assigner = new ExpressionIdAssigner + + val col1IntegerNew = col1Integer.newInstance() + assert(col1IntegerNew.exprId != col1Integer.exprId) + + val col2IntegerNew = col2Integer.newInstance() + assert(col2IntegerNew.exprId != col2Integer.exprId) + + assigner.createMapping( + newOutput = Seq(col1IntegerNew, col2IntegerNew), + oldOutput = Some(Seq(col1Integer, col2Integer)) + ) + + val col1IntegerReferenced = assigner.mapExpression(col1Integer) + assert(col1IntegerReferenced.isInstanceOf[AttributeReference]) + assert(col1IntegerReferenced.exprId == col1IntegerNew.exprId) + + val col1IntegerNewReferenced = assigner.mapExpression(col1IntegerNew) + assert(col1IntegerNewReferenced.isInstanceOf[AttributeReference]) + assert(col1IntegerNewReferenced.exprId == col1IntegerNew.exprId) + + val col2IntegerReferenced = assigner.mapExpression(col2Integer) + assert(col2IntegerReferenced.isInstanceOf[AttributeReference]) + assert(col2IntegerReferenced.exprId == col2IntegerNew.exprId) + + val col2IntegerNewReferenced = assigner.mapExpression(col2IntegerNew) + assert(col2IntegerNewReferenced.isInstanceOf[AttributeReference]) + assert(col2IntegerNewReferenced.exprId == col2IntegerNew.exprId) + + val col2IntegerAliasMapped = assigner.mapExpression(col2IntegerAlias) + assert(col2IntegerAliasMapped.isInstanceOf[Alias]) + assert(col2IntegerAliasMapped.exprId == col2IntegerAlias.exprId) + assert(col2IntegerAliasMapped.exprId != col2Integer.exprId) + assert(col2IntegerAliasMapped.exprId != col2IntegerNew.exprId) + + val col3IntegerMapped = assigner.mapExpression(col3Integer) + assert(col3IntegerMapped.isInstanceOf[AttributeReference]) + assert(col3IntegerMapped.exprId != col3Integer.exprId) + } + + test("Right branch: Create mapping with new output and old output") { + val assigner = new ExpressionIdAssigner + assigner.withNewMapping() { + val col1IntegerNew = col1Integer.newInstance() + assert(col1IntegerNew.exprId != col1Integer.exprId) + + val col2IntegerNew = col2Integer.newInstance() + assert(col2IntegerNew.exprId != col2Integer.exprId) + + assigner.createMapping( + newOutput = Seq(col1IntegerNew, col2IntegerNew), + oldOutput = Some(Seq(col1Integer, col2Integer)) + ) + + val col1IntegerReferenced = assigner.mapExpression(col1Integer) + assert(col1IntegerReferenced.isInstanceOf[AttributeReference]) + assert(col1IntegerReferenced.exprId == col1IntegerNew.exprId) + + val col1IntegerNewReferenced = assigner.mapExpression(col1IntegerNew) + 
assert(col1IntegerNewReferenced.isInstanceOf[AttributeReference]) + assert(col1IntegerNewReferenced.exprId == col1IntegerNew.exprId) + + val col2IntegerReferenced = assigner.mapExpression(col2Integer) + assert(col2IntegerReferenced.isInstanceOf[AttributeReference]) + assert(col2IntegerReferenced.exprId == col2IntegerNew.exprId) + + val col2IntegerNewReferenced = assigner.mapExpression(col2IntegerNew) + assert(col2IntegerNewReferenced.isInstanceOf[AttributeReference]) + assert(col2IntegerNewReferenced.exprId == col2IntegerNew.exprId) + + val col2IntegerAliasMapped = assigner.mapExpression(col2IntegerAlias) + assert(col2IntegerAliasMapped.isInstanceOf[Alias]) + assert(col2IntegerAliasMapped.exprId != col2IntegerAlias.exprId) + assert(col2IntegerAliasMapped.exprId != col2Integer.exprId) + assert(col2IntegerAliasMapped.exprId != col2IntegerNew.exprId) + + val col3IntegerMapped = assigner.mapExpression(col3Integer) + assert(col3IntegerMapped.isInstanceOf[AttributeReference]) + assert(col3IntegerMapped.exprId != col3Integer.exprId) + } + } + + test("Several layers") { + val assigner = new ExpressionIdAssigner + val literalAlias1 = Alias(Literal(1), "a")() + val literalAlias2 = Alias(Literal(2), "b")() + + val output1 = assigner.withNewMapping() { + val output1 = assigner.withNewMapping() { + assigner.createMapping() + + Seq( + assigner.mapExpression(col1Integer).toAttribute, + assigner.mapExpression(col2Integer).toAttribute + ) + } + + val output2 = assigner.withNewMapping() { + val col1IntegerNew = col1Integer.newInstance() + val col2IntegerNew = col2Integer.newInstance() + + assigner.createMapping(newOutput = Seq(col1IntegerNew, col2IntegerNew)) + + Seq( + assigner.mapExpression(col1IntegerNew).toAttribute, + assigner.mapExpression(col2IntegerNew).toAttribute + ) + } + + val output3 = assigner.withNewMapping() { + val col1IntegerNew = col1Integer.newInstance() + val col2IntegerNew = col2Integer.newInstance() + + assigner.createMapping( + newOutput = Seq(col1IntegerNew, col2IntegerNew), + oldOutput = Some(Seq(col1Integer, col2Integer)) + ) + + Seq( + assigner.mapExpression(col1Integer).toAttribute, + assigner.mapExpression(col2Integer).toAttribute + ) + } + + output1.zip(output2).zip(output3).zip(Seq(col1Integer, col2Integer)).foreach { + case (((attribute1, attribute2), attribute3), originalAttribute) => + assert(attribute1.exprId != originalAttribute.exprId) + assert(attribute2.exprId != originalAttribute.exprId) + assert(attribute3.exprId != originalAttribute.exprId) + assert(attribute1.exprId != attribute2.exprId) + assert(attribute1.exprId != attribute3.exprId) + assert(attribute2.exprId != attribute3.exprId) + } + + assigner.createMapping(newOutput = output2) + + val literalAlias1Remapped = assigner.mapExpression(literalAlias1) + assert(literalAlias1Remapped.isInstanceOf[Alias]) + assert(literalAlias1Remapped.exprId != literalAlias1.exprId) + + val literalAlias2Remapped = assigner.mapExpression(literalAlias2) + assert(literalAlias2Remapped.isInstanceOf[Alias]) + assert(literalAlias2Remapped.exprId != literalAlias2.exprId) + + Seq(literalAlias1Remapped.toAttribute, literalAlias2Remapped.toAttribute) ++ output2 + } + + val output2 = assigner.withNewMapping() { + assigner.createMapping() + + val literalAlias1Remapped = assigner.mapExpression(literalAlias1) + assert(literalAlias1Remapped.isInstanceOf[Alias]) + assert(literalAlias1Remapped.exprId != literalAlias1.exprId) + + val literalAlias2Remapped = assigner.mapExpression(literalAlias2) + assert(literalAlias2Remapped.isInstanceOf[Alias]) + 
assert(literalAlias2Remapped.exprId != literalAlias2.exprId) + + Seq(literalAlias1Remapped.toAttribute, literalAlias2Remapped.toAttribute) + } + + output1.zip(output2).foreach { + case (aliasReference1, aliasReference2) => + assert(aliasReference1.exprId != aliasReference2.exprId) + } + + assigner.createMapping(newOutput = output1) + + val aliasReferences = output1.map { aliasReference => + assigner.mapExpression(aliasReference) + } + + aliasReferences.zip(output1).zip(output2).foreach { + case ((aliasReference, aliasReference1), aliasReference2) => + assert(aliasReference.exprId == aliasReference1.exprId) + assert(aliasReference.exprId != aliasReference2.exprId) + } + + aliasReferences.map(_.toAttribute) + } + + test("Simple select") { + checkExpressionIdAssignment( + spark + .sql(""" + SELECT + col1, 1 AS a, col1, 1 AS a, col2, 2 AS b, col3, 3 AS c + FROM + VALUES (1, 2, 3) + """) + .queryExecution + .analyzed + ) + } + + test("Simple select, aliases referenced") { + checkExpressionIdAssignment( + spark + .sql(""" + SELECT + col3, c, col2, b, col1, a, col1, a + FROM ( + SELECT + col1, 1 AS a, col1, col2, 2 AS b, col3, 3 AS c + FROM + VALUES (1, 2, 3) + )""") + .queryExecution + .analyzed + ) + } + + test("Simple select, aliases referenced and rewritten") { + checkExpressionIdAssignment( + spark + .sql(""" + SELECT + col3, 3 AS c, col2, 2 AS b, col1, 1 AS a, col1, 1 AS a + FROM ( + SELECT + col2, b, col1, a, col1, a, col3, c + FROM ( + SELECT + col1, 1 AS a, col1, col2, 2 AS b, col3, 3 AS c + FROM + VALUES (1, 2, 3) + ) + )""") + .queryExecution + .analyzed + ) + } + + test("SQL Union, same table") { + withTable("t1") { + spark.sql("CREATE TABLE t1 (col1 INT, col2 INT, col3 INT)") + + checkExpressionIdAssignment( + spark + .sql(""" + SELECT * FROM ( + SELECT col1, 1 AS a, col2, 2 AS b, col3, 3 AS c FROM t1 + UNION ALL + SELECT col1, 1 AS a, col2, 2 AS b, col3, 3 AS c FROM t1 + UNION ALL + SELECT col1, 1 AS a, col2, 2 AS b, col3, 3 AS c FROM t1 + )""") + .queryExecution + .analyzed + ) + } + } + + test("SQL Union, different tables") { + withTable("t1") { + spark.sql("CREATE TABLE t1 (col1 INT, col2 INT, col3 INT)") + withTable("t2") { + spark.sql("CREATE TABLE t2 (col1 INT, col2 INT, col3 INT)") + withTable("t3") { + spark.sql("CREATE TABLE t3 (col1 INT, col2 INT, col3 INT)") + + checkExpressionIdAssignment( + spark + .sql(""" + SELECT * FROM ( + SELECT col1, 1 AS a, col2, 2 AS b, col3, 3 AS c FROM t1 + UNION ALL + SELECT col1, 1 AS a, col2, 2 AS b, col3, 3 AS c FROM t1 + UNION ALL + SELECT col1, 1 AS a, col2, 2 AS b, col3, 3 AS c FROM t1 + )""") + .queryExecution + .analyzed + ) + } + } + } + } + + test("SQL Union, same table, several layers") { + withTable("t1") { + spark.sql("CREATE TABLE t1 (col1 INT, col2 INT, col3 INT)") + + checkExpressionIdAssignment( + spark + .sql(""" + SELECT * FROM ( + SELECT * FROM ( + SELECT col1, 1 AS a, col2, 2 AS b, col3, 3 AS c FROM t1 + UNION ALL + SELECT col1, 1 AS a, col2, 2 AS b, col3, 3 AS c FROM t1 + ) + UNION ALL + SELECT * FROM ( + SELECT col1, 1 AS a, col2, 2 AS b, col3, 3 AS c FROM t1 + UNION ALL + SELECT col1, 1 AS a, col2, 2 AS b, col3, 3 AS c FROM t1 + ) + ) + UNION ALL + SELECT * FROM ( + SELECT * FROM ( + SELECT col1, 1 AS a, col2, 2 AS b, col3, 3 AS c FROM t1 + UNION ALL + SELECT col1, 1 AS a, col2, 2 AS b, col3, 3 AS c FROM t1 + ) + UNION ALL + SELECT * FROM ( + SELECT col1, 1 AS a, col2, 2 AS b, col3, 3 AS c FROM t1 + UNION ALL + SELECT col1, 1 AS a, col2, 2 AS b, col3, 3 AS c FROM t1 + ) + )""") + .queryExecution + .analyzed + ) + } + } 
+ + test("DataFrame Union, same table") { + withTable("t1") { + spark.sql("CREATE TABLE t1 (col1 INT, col2 INT, col3 INT)") + + val df = spark.sql("SELECT col1, 1 AS a, col2, 2 AS b, col3, 3 AS c FROM t1") + checkExpressionIdAssignment(df.union(df).queryExecution.analyzed) + } + } + + test("DataFrame Union, different tables") { + withTable("t1") { + spark.sql("CREATE TABLE t1 (col1 INT, col2 INT, col3 INT)") + + withTable("t2") { + spark.sql("CREATE TABLE t2 (col1 INT, col2 INT, col3 INT)") + + val df1 = spark.sql("SELECT col1, 1 AS a, col2, 2 AS b, col3, 3 AS c FROM t1") + val df2 = spark.sql("SELECT col1, 1 AS a, col2, 2 AS b, col3, 3 AS c FROM t2") + checkExpressionIdAssignment(df1.union(df2).queryExecution.analyzed) + } + } + } + + test("DataFrame Union, same table, several layers") { + withTable("t1") { + spark.sql("CREATE TABLE t1 (col1 INT, col2 INT, col3 INT)") + + val df = spark.sql("SELECT col1, 1 AS a, col2, 2 AS b, col3, 3 AS c FROM t1") + checkExpressionIdAssignment( + df.union(df) + .select("*") + .union(df.union(df).select("*")) + .union(df.union(df).select("*")) + .queryExecution + .analyzed + ) + } + } + + test("The case of output attribute names is preserved") { + val df = spark.sql("SELECT col1, COL1, cOl2, CoL2 FROM VALUES (1, 2)") + + checkExpressionIdAssignment(df.queryExecution.analyzed) + } + + test("The metadata of output attributes is preserved") { + val metadata1 = new MetadataBuilder().putString("m1", "1").putString("m2", "2").build() + val metadata2 = new MetadataBuilder().putString("m2", "3").putString("m3", "4").build() + val schema = new StructType().add("a", IntegerType, nullable = true, metadata = metadata2) + val df = + spark.sql("SELECT col1 FROM VALUES (1)").select(col("col1").as("a", metadata1)).to(schema) + + checkExpressionIdAssignment(df.queryExecution.analyzed) + } + + test("Alias with the same ID in multiple Projects") { + val t = LocalRelation.fromExternalRows( + Seq("a".attr.int, "b".attr.int), + 0.until(10).map(_ => Row(1, 2)) + ) + val alias = ("a".attr + 1).as("a") + val plan = t.select(alias).select(alias).select(alias) + + checkExpressionIdAssignment(plan) + } + + test("Raw union, same table") { + val t = LocalRelation.fromExternalRows( + Seq("col1".attr.int, "col2".attr.int), + 0.until(10).map(_ => Row(1, 2)) + ) + val query = t.select("col1".attr, Literal(1).as("a"), "col2".attr, Literal(2).as("b")) + val plan = query.union(query) + + checkExpressionIdAssignment(plan) + } + + test("DataFrame with binary arithmetic re-resolved") { + val result = withSQLConf(SQLConf.ANALYZER_SINGLE_PASS_RESOLVER_ENABLED.key -> "true") { + val df = spark.sql("SELECT col1 + col2 AS a FROM VALUES (1, 2)") + df.union(df) + } + checkAnswer(result, Array(Row(3), Row(3))) + } + + test("Leftmost branch attributes are not regenerated in DataFrame") { + withTable("t1") { + spark.sql("CREATE TABLE t1 (col1 INT, col2 INT)") + spark.sql("INSERT INTO t1 VALUES (0, 1), (2, 3)") + + var result = withSQLConf(SQLConf.ANALYZER_SINGLE_PASS_RESOLVER_ENABLED.key -> "true") { + val df1 = spark.table("t1") + df1.select(col("col1"), col("col2")).filter(df1("col1") === 0) + } + checkAnswer(result, Array(Row(0, 1))) + + result = withSQLConf(SQLConf.ANALYZER_SINGLE_PASS_RESOLVER_ENABLED.key -> "true") { + val df1 = spark.table("t1").select(col("col1").as("a"), col("col2").as("b")) + df1.select(col("a"), col("b")).filter(df1("a") === 0) + } + checkAnswer(result, Array(Row(0, 1))) + + result = withSQLConf(SQLConf.ANALYZER_SINGLE_PASS_RESOLVER_ENABLED.key -> "true") { + val df1 = 
spark.table("t1") + df1.union(df1).filter(df1("col1") === 0) + } + checkAnswer(result, Array(Row(0, 1), Row(0, 1))) + } + } + + private def checkExpressionIdAssignment(originalPlan: LogicalPlan): Unit = { + val resolver = new Resolver( + catalogManager = spark.sessionState.catalogManager, + extensions = spark.sessionState.analyzer.singlePassResolverExtensions + ) + val newPlan = resolver.resolve(originalPlan) + + checkPlanConstraints(originalPlan, newPlan, leftmostBranch = true) + checkSubtreeConstraints(originalPlan, newPlan, leftmostBranch = true) + } + + private def checkPlanConstraints( + originalPlan: LogicalPlan, + newPlan: LogicalPlan, + leftmostBranch: Boolean): Unit = { + originalPlan.children.zip(newPlan.children).zipWithIndex.foreach { + case ((originalChild, newChild), index) => + checkPlanConstraints(originalChild, newChild, leftmostBranch && index == 0) + } + + if (originalPlan.children.length > 1) { + ExpressionIdAssigner.assertOutputsHaveNoConflictingExpressionIds( + newPlan.children.map(_.output) + ) + originalPlan.children.zip(newPlan.children).zipWithIndex.foreach { + case ((oldChild, newChild), index) => + checkSubtreeConstraints(oldChild, newChild, leftmostBranch && index == 0) + } + } + } + + private def checkSubtreeConstraints( + originalPlan: LogicalPlan, + newPlan: LogicalPlan, + leftmostBranch: Boolean): Unit = { + val originalOperators = new ArrayBuffer[LogicalPlan] + originalPlan.foreach { + case operator if !operator.getTagValue(CONSTRAINTS_VALIDATED).getOrElse(false) => + originalOperators.append(operator) + case _ => + } + + val newOperators = new ArrayBuffer[LogicalPlan] + + val leftmostOperators = new IdentityHashMap[LogicalPlan, Boolean] + if (leftmostBranch) { + leftmostOperators.put(newPlan, true) + } + + newPlan.foreach { + case operator if !operator.getTagValue(CONSTRAINTS_VALIDATED).getOrElse(false) => + newOperators.append(operator) + + if (operator.children.nonEmpty && leftmostOperators.containsKey(operator)) { + leftmostOperators.put(operator.children.head, true) + } + case _ => + } + + val attributesByName = new HashMap[String, ArrayBuffer[AttributeReference]] + val aliasesByName = new HashMap[String, ArrayBuffer[Alias]] + originalOperators + .zip(newOperators) + .collect { + case (originalProject: Project, newProject: Project) => + if (originalProject.resolved) { + (originalProject.projectList, newProject.projectList, newProject) + } else { + (newProject.projectList, newProject) + } + case (originalOperator: LogicalPlan, newOperator: LogicalPlan) => + if (originalOperator.resolved) { + (originalOperator.output, newOperator.output, newOperator) + } else { + (newOperator.output, newOperator) + } + } + .foreach { + case ( + originalExpressions: Seq[NamedExpression], + newExpressions: Seq[NamedExpression], + newOperator: LogicalPlan + ) => + originalExpressions.zip(newExpressions).zipWithIndex.foreach { + case ( + (originalAttribute: AttributeReference, newAttribute: AttributeReference), + index + ) => + if (leftmostOperators.containsKey(newOperator)) { + assert( + originalAttribute.exprId == newAttribute.exprId, + s"Attribute at $index was regenerated: $originalAttribute, $newAttribute" + ) + } else { + assert( + originalAttribute.exprId != newAttribute.exprId, + s"Attribute at $index was not regenerated: $originalAttribute, $newAttribute" + ) + } + + attributesByName + .getOrElseUpdate(newAttribute.name, new ArrayBuffer[AttributeReference]) + .append(newAttribute) + case ((originalAlias: Alias, newAlias: Alias), index) => + if 
(leftmostOperators.containsKey(newOperator)) { + assert( + originalAlias.exprId == newAlias.exprId, + s"Alias at $index was regenerated: $originalAlias, $newAlias" + ) + } else { + assert( + originalAlias.exprId != newAlias.exprId, + s"Alias at $index was not regenerated: $originalAlias, $newAlias" + ) + } + + aliasesByName.getOrElseUpdate(newAlias.name, new ArrayBuffer[Alias]).append(newAlias) + } + case (newExpressions: Seq[NamedExpression], newOperator: LogicalPlan) => + newExpressions.foreach { + case newAttribute: AttributeReference => + attributesByName + .getOrElseUpdate(newAttribute.name, new ArrayBuffer[AttributeReference]) + .append(newAttribute) + case newAlias: Alias => + aliasesByName.getOrElseUpdate(newAlias.name, new ArrayBuffer[Alias]).append(newAlias) + } + } + + attributesByName.values.foreach { + case attributes => + val ids = attributes.map(attribute => attribute.exprId).distinct + assert( + ids.length == 1, + s"Different IDs for the same attribute in the plan: $attributes, $newPlan" + ) + } + aliasesByName.values.foreach { + case aliases => + val ids = aliases.map(alias => alias.exprId).distinct + assert( + ids.length == aliases.length, + s"Duplicate IDs for aliases with the same name: $aliases" + ) + } + + for (operator <- originalOperators) { + operator.setTagValue(CONSTRAINTS_VALIDATED, true) + } + for (operator <- newOperators) { + operator.setTagValue(CONSTRAINTS_VALIDATED, true) + } + + if (originalPlan.resolved) { + assert(newPlan.schema == originalPlan.schema) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/analysis/resolver/HybridAnalyzerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/analysis/resolver/HybridAnalyzerSuite.scala index 587725093f0e5..38d78f846ab90 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/analysis/resolver/HybridAnalyzerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/analysis/resolver/HybridAnalyzerSuite.scala @@ -40,6 +40,7 @@ import org.apache.spark.sql.catalyst.analysis.resolver.{ ResolverGuard } import org.apache.spark.sql.catalyst.expressions.AttributeReference +import org.apache.spark.sql.catalyst.plans.NormalizePlan import org.apache.spark.sql.catalyst.plans.logical.{ LocalRelation, LogicalPlan, @@ -142,13 +143,12 @@ class HybridAnalyzerSuite extends QueryTest with SharedSparkSession { Seq(col1Integer), LocalRelation(Seq(col1Integer)) ) - assert( + assertPlansEqual( new HybridAnalyzer( new ValidatingAnalyzer(bridgeRelations = true), new ResolverGuard(spark.sessionState.catalogManager), new ValidatingResolver(bridgeRelations = true) - ).apply(plan, null) - == + ).apply(plan, null), resolvedPlan ) } @@ -312,7 +312,7 @@ class HybridAnalyzerSuite extends QueryTest with SharedSparkSession { Seq(col1Integer), LocalRelation(Seq(col1Integer)) ) - assert( + assertPlansEqual( withSQLConf( SQLConf.ANALYZER_DUAL_RUN_LEGACY_AND_SINGLE_PASS_RESOLVER.key -> "false" ) { @@ -324,8 +324,7 @@ class HybridAnalyzerSuite extends QueryTest with SharedSparkSession { bridgeRelations = false ) ).apply(plan, null) - } - == + }, resolvedPlan ) } @@ -339,7 +338,7 @@ class HybridAnalyzerSuite extends QueryTest with SharedSparkSession { Seq(col1Integer), LocalRelation(Seq(col1Integer)) ) - assert( + assertPlansEqual( withSQLConf( SQLConf.ANALYZER_DUAL_RUN_LEGACY_AND_SINGLE_PASS_RESOLVER.key -> "false", SQLConf.ANALYZER_SINGLE_PASS_RESOLVER_ENABLED.key -> "true" @@ -352,8 +351,7 @@ class HybridAnalyzerSuite extends QueryTest with SharedSparkSession { new ResolverGuard(spark.sessionState.catalogManager), new 
ValidatingResolver(bridgeRelations = false) ).apply(plan, null) - } - == + }, resolvedPlan ) } @@ -369,7 +367,7 @@ class HybridAnalyzerSuite extends QueryTest with SharedSparkSession { ) val nestedAnalysis = () => { - assert( + assertPlansEqual( withSQLConf( SQLConf.ANALYZER_DUAL_RUN_LEGACY_AND_SINGLE_PASS_RESOLVER.key -> "false", SQLConf.ANALYZER_SINGLE_PASS_RESOLVER_ENABLED.key -> "true" @@ -382,13 +380,12 @@ class HybridAnalyzerSuite extends QueryTest with SharedSparkSession { new ResolverGuard(spark.sessionState.catalogManager), new ValidatingResolver(bridgeRelations = false) ).apply(plan, null) - } - == + }, resolvedPlan ) } - assert( + assertPlansEqual( new HybridAnalyzer( new CustomAnalyzer( customCode = () => { nestedAnalysis() }, @@ -396,9 +393,12 @@ class HybridAnalyzerSuite extends QueryTest with SharedSparkSession { ), new ResolverGuard(spark.sessionState.catalogManager), new ValidatingResolver(bridgeRelations = true) - ).apply(plan, null) - == + ).apply(plan, null), resolvedPlan ) } + + private def assertPlansEqual(actualPlan: LogicalPlan, expectedPlan: LogicalPlan) = { + assert(NormalizePlan(actualPlan) == NormalizePlan(expectedPlan)) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/analysis/resolver/MetadataResolverSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/analysis/resolver/MetadataResolverSuite.scala index 5fd21d7543b33..c742e97447468 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/analysis/resolver/MetadataResolverSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/analysis/resolver/MetadataResolverSuite.scala @@ -33,6 +33,7 @@ import org.apache.spark.sql.catalyst.catalog.UnresolvedCatalogRelation import org.apache.spark.sql.catalyst.expressions.{Expression, PlanExpression} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias} import org.apache.spark.sql.execution.datasources.{FileResolver, HadoopFsRelation, LogicalRelation} +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation import org.apache.spark.sql.test.{SharedSparkSession, SQLTestUtils} import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructField, StructType} import org.apache.spark.sql.util.CaseInsensitiveStringMap @@ -114,6 +115,41 @@ class MetadataResolverSuite extends QueryTest with SharedSparkSession with SQLTe } } + test("Relation inside a CTE definition") { + withTable("src") { + spark.sql("CREATE TABLE src (key INT, value STRING) USING PARQUET;").collect() + + checkResolveUnresolvedCatalogRelation( + sqlText = """ + WITH cte AS (SELECT key FROM src) + SELECT * FROM cte + """, + expectedTableData = Seq(createTableData("src")) + ) + } + } + + test("Relation inside a CTE definition inside a subquery expression") { + withTable("src") { + spark.sql("CREATE TABLE src (key INT, value STRING) USING PARQUET;").collect() + + checkResolveUnresolvedCatalogRelation( + sqlText = """ + SELECT + col1 + ( + SELECT 35 * ( + WITH cte AS (SELECT key FROM src) + SELECT key FROM cte LIMIT 1 + ) * col1 FROM VALUES (2) + ) + FROM + VALUES (1) + """, + expectedTableData = Seq(createTableData("src")) + ) + } + } + test("Relation from a file") { val df = spark.range(100).toDF() withTempPath(f => { @@ -223,7 +259,7 @@ class MetadataResolverSuite extends QueryTest with SharedSparkSession with SQLTe val actualTableData = new mutable.HashMap[RelationId, TestTableData] - def findUnresolvedRelations(unresolvedPlan: LogicalPlan): Unit = unresolvedPlan.foreach { + def findUnresolvedRelations(unresolvedPlan: LogicalPlan): 
Unit = unresolvedPlan match { case unresolvedRelation: UnresolvedRelation => metadataResolver.getRelationWithResolvedMetadata(unresolvedRelation) match { case Some(plan) => @@ -249,6 +285,11 @@ class MetadataResolverSuite extends QueryTest with SharedSparkSession with SQLTe expression.children.foreach(traverseExpressions) } + unresolvedPlan.children.foreach(findUnresolvedRelations) + unresolvedPlan.innerChildren.foreach { + case plan: LogicalPlan => findUnresolvedRelations(plan) + case _ => + } unresolvedPlan.expressions.foreach(traverseExpressions) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/analysis/resolver/NameScopeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/analysis/resolver/NameScopeSuite.scala index ec744af89f000..acb6290b3af09 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/analysis/resolver/NameScopeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/analysis/resolver/NameScopeSuite.scala @@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.SQLConfHelper import org.apache.spark.sql.catalyst.analysis.UnresolvedStar import org.apache.spark.sql.catalyst.analysis.resolver.{NameScope, NameScopeStack, NameTarget} import org.apache.spark.sql.catalyst.expressions.{ + Attribute, AttributeReference, GetArrayItem, GetArrayStructFields, @@ -103,263 +104,113 @@ class NameScopeSuite extends PlanTest with SQLConfHelper { test("Empty scope") { val nameScope = new NameScope - assert(nameScope.getAllAttributes.isEmpty) + assert(nameScope.output.isEmpty) - assert(nameScope.matchMultipartName(Seq("col1")) == NameTarget(candidates = Seq.empty)) - } - - test("Single unnamed plan") { - val nameScope = new NameScope - - nameScope += Seq(col1Integer, col2Integer, col3Boolean) - - assert(nameScope.getAllAttributes == Seq(col1Integer, col2Integer, col3Boolean)) - - assert( - nameScope.matchMultipartName(Seq("col1")) == NameTarget( - candidates = Seq(col1Integer), - allAttributes = Seq(col1Integer, col2Integer, col3Boolean) - ) - ) - assert( - nameScope.matchMultipartName(Seq("col2")) == NameTarget( - candidates = Seq(col2Integer), - allAttributes = Seq(col1Integer, col2Integer, col3Boolean) - ) - ) - assert( - nameScope.matchMultipartName(Seq("col3")) == NameTarget( - candidates = Seq(col3Boolean), - allAttributes = Seq(col1Integer, col2Integer, col3Boolean) - ) - ) - assert( - nameScope.matchMultipartName(Seq("col4")) == NameTarget( - candidates = Seq.empty, - allAttributes = Seq(col1Integer, col2Integer, col3Boolean) - ) + checkOnePartNameLookup( + nameScope, + name = "col1", + candidates = Seq.empty ) } - test("Several unnamed plans") { - val nameScope = new NameScope - - nameScope += Seq(col1Integer) - nameScope += Seq(col2Integer, col3Boolean) - nameScope += Seq(col4String) + test("Distinct attributes") { + val nameScope = new NameScope(Seq(col1Integer, col2Integer, col3Boolean, col4String)) - assert(nameScope.getAllAttributes == Seq(col1Integer, col2Integer, col3Boolean, col4String)) + assert(nameScope.output == Seq(col1Integer, col2Integer, col3Boolean, col4String)) - assert( - nameScope.matchMultipartName(Seq("col1")) == NameTarget( - candidates = Seq(col1Integer), - allAttributes = Seq(col1Integer, col2Integer, col3Boolean, col4String) - ) - ) - assert( - nameScope.matchMultipartName(Seq("col2")) == NameTarget( - candidates = Seq(col2Integer), - allAttributes = Seq(col1Integer, col2Integer, col3Boolean, col4String) - ) - ) - assert( - nameScope.matchMultipartName(Seq("col3")) == NameTarget( - candidates = Seq(col3Boolean), - allAttributes = 
Seq(col1Integer, col2Integer, col3Boolean, col4String) - ) + checkOnePartNameLookup( + nameScope, + name = "col1", + candidates = Seq(col1Integer) ) - assert( - nameScope.matchMultipartName(Seq("col4")) == NameTarget( - candidates = Seq(col4String), - allAttributes = Seq(col1Integer, col2Integer, col3Boolean, col4String) - ) - ) - assert( - nameScope.matchMultipartName(Seq("col5")) == NameTarget( - candidates = Seq.empty, - allAttributes = Seq(col1Integer, col2Integer, col3Boolean, col4String) - ) - ) - } - - test("Single named plan") { - val nameScope = new NameScope - - nameScope("table1") = Seq(col1Integer, col2Integer, col3Boolean) - - assert(nameScope.getAllAttributes == Seq(col1Integer, col2Integer, col3Boolean)) - - assert( - nameScope.matchMultipartName(Seq("col1")) == NameTarget( - candidates = Seq(col1Integer), - allAttributes = Seq(col1Integer, col2Integer, col3Boolean) - ) + checkOnePartNameLookup( + nameScope, + name = "col2", + candidates = Seq(col2Integer) ) - assert( - nameScope.matchMultipartName(Seq("col2")) == NameTarget( - candidates = Seq(col2Integer), - allAttributes = Seq(col1Integer, col2Integer, col3Boolean) - ) + checkOnePartNameLookup( + nameScope, + name = "col3", + candidates = Seq(col3Boolean) ) - assert( - nameScope.matchMultipartName(Seq("col3")) == NameTarget( - candidates = Seq(col3Boolean), - allAttributes = Seq(col1Integer, col2Integer, col3Boolean) - ) + checkOnePartNameLookup( + nameScope, + name = "col4", + candidates = Seq(col4String) ) - assert( - nameScope.matchMultipartName(Seq("col4")) == NameTarget( - candidates = Seq.empty, - allAttributes = Seq(col1Integer, col2Integer, col3Boolean) - ) + checkOnePartNameLookup( + nameScope, + name = "col5", + candidates = Seq.empty ) } - test("Several named plans") { - val nameScope = new NameScope + test("Duplicate attribute names") { + val nameScope = new NameScope(Seq(col1Integer, col1Integer, col1IntegerOther)) - nameScope("table1") = Seq(col1Integer) - nameScope("table2") = Seq(col2Integer, col3Boolean) - nameScope("table2") = Seq(col4String) - nameScope("table3") = Seq(col5String) - - assert( - nameScope.getAllAttributes == Seq( - col1Integer, - col2Integer, - col3Boolean, - col4String, - col5String - ) - ) + assert(nameScope.output == Seq(col1Integer, col1Integer, col1IntegerOther)) - assert( - nameScope.matchMultipartName(Seq("col1")) == NameTarget( - candidates = Seq(col1Integer), - allAttributes = Seq(col1Integer, col2Integer, col3Boolean, col4String, col5String) - ) - ) - assert( - nameScope.matchMultipartName(Seq("col2")) == NameTarget( - candidates = Seq(col2Integer), - allAttributes = Seq(col1Integer, col2Integer, col3Boolean, col4String, col5String) - ) - ) - assert( - nameScope.matchMultipartName(Seq("col3")) == NameTarget( - candidates = Seq(col3Boolean), - allAttributes = Seq(col1Integer, col2Integer, col3Boolean, col4String, col5String) - ) - ) - assert( - nameScope.matchMultipartName(Seq("col4")) == NameTarget( - candidates = Seq(col4String), - allAttributes = Seq(col1Integer, col2Integer, col3Boolean, col4String, col5String) - ) - ) - assert( - nameScope.matchMultipartName(Seq("col5")) == NameTarget( - candidates = Seq(col5String), - allAttributes = Seq(col1Integer, col2Integer, col3Boolean, col4String, col5String) - ) - ) - assert( - nameScope.matchMultipartName(Seq("col6")) == NameTarget( - candidates = Seq.empty, - allAttributes = Seq(col1Integer, col2Integer, col3Boolean, col4String, col5String) - ) + checkOnePartNameLookup( + nameScope, + name = "col1", + candidates = 
Seq(col1Integer, col1Integer, col1IntegerOther) ) } - test("Named and unnamed plans with case insensitive comparison") { + test("Case insensitive comparison") { val col1Integer = AttributeReference(name = "Col1", dataType = IntegerType)() val col2Integer = AttributeReference(name = "col2", dataType = IntegerType)() val col3Boolean = AttributeReference(name = "coL3", dataType = BooleanType)() + val col3BooleanOther = AttributeReference(name = "Col3", dataType = BooleanType)() val col4String = AttributeReference(name = "Col4", dataType = StringType)() - val nameScope = new NameScope - - nameScope("TaBle1") = Seq(col1Integer) - nameScope("table2") = Seq(col2Integer, col3Boolean) - nameScope += Seq(col4String) - - assert(nameScope.getAllAttributes == Seq(col1Integer, col2Integer, col3Boolean, col4String)) + val nameScope = + new NameScope( + Seq(col1Integer, col3Boolean, col2Integer, col2Integer, col3BooleanOther, col4String) + ) assert( - nameScope.matchMultipartName(Seq("cOL1")) == NameTarget( - candidates = Seq(col1Integer.withName("cOL1")), - allAttributes = Seq(col1Integer, col2Integer, col3Boolean, col4String) + nameScope.output == Seq( + col1Integer, + col3Boolean, + col2Integer, + col2Integer, + col3BooleanOther, + col4String ) ) - assert( - nameScope.matchMultipartName(Seq("CoL2")) == NameTarget( - candidates = Seq(col2Integer.withName("CoL2")), - allAttributes = Seq(col1Integer, col2Integer, col3Boolean, col4String) - ) + + checkOnePartNameLookup( + nameScope, + name = "cOL1", + candidates = Seq(col1Integer) ) - assert( - nameScope.matchMultipartName(Seq("col3")) == NameTarget( - candidates = Seq(col3Boolean.withName("col3")), - allAttributes = Seq(col1Integer, col2Integer, col3Boolean, col4String) - ) + checkOnePartNameLookup( + nameScope, + name = "CoL2", + candidates = Seq(col2Integer, col2Integer) ) - assert( - nameScope.matchMultipartName(Seq("COL4")) == NameTarget( - candidates = Seq(col4String.withName("COL4")), - allAttributes = Seq(col1Integer, col2Integer, col3Boolean, col4String) - ) + checkOnePartNameLookup( + nameScope, + name = "col3", + candidates = Seq(col3Boolean, col3BooleanOther) ) - assert( - nameScope.matchMultipartName(Seq("col5")) == NameTarget( - candidates = Seq.empty, - allAttributes = Seq(col1Integer, col2Integer, col3Boolean, col4String) - ) + checkOnePartNameLookup( + nameScope, + name = "COL4", + candidates = Seq(col4String) ) - } - - test("Duplicate attribute names from one plan") { - val nameScope = new NameScope - - nameScope("table1") = Seq(col1Integer, col1Integer) - nameScope("table1") = Seq(col1IntegerOther) - - assert(nameScope.getAllAttributes == Seq(col1Integer, col1Integer, col1IntegerOther)) - - nameScope.matchMultipartName(Seq("col1")) == NameTarget( - candidates = Seq(col1Integer, col1IntegerOther) - ) - } - - test("Duplicate attribute names from several plans") { - val nameScope = new NameScope - - nameScope("table1") = Seq(col1Integer, col1IntegerOther) - nameScope("table2") = Seq(col1Integer, col1IntegerOther) - - assert( - nameScope.getAllAttributes == Seq( - col1Integer, - col1IntegerOther, - col1Integer, - col1IntegerOther - ) - ) - - nameScope.matchMultipartName(Seq("col1")) == NameTarget( - candidates = Seq( - col1Integer, - col1IntegerOther, - col1Integer, - col1IntegerOther - ) + checkOnePartNameLookup( + nameScope, + name = "col5", + candidates = Seq.empty ) } test("Expand star") { - val nameScope = new NameScope - - nameScope("table") = + var nameScope = new NameScope( Seq(col6IntegerWithQualifier, col6IntegerOtherWithQualifier, 
col7StringWithQualifier) + ) Seq(Seq("table"), Seq("database", "table"), Seq("catalog", "database", "table")) .foreach(tableQualifier => { @@ -379,25 +230,10 @@ class NameScopeSuite extends PlanTest with SQLConfHelper { "columns" -> "`col6`, `col6`, `col7`" ) ) - - nameScope("table2") = Seq(col6IntegerWithQualifier) - - checkError( - exception = intercept[AnalysisException]( - nameScope.expandStar(UnresolvedStar(Some(Seq("table2")))) - ), - condition = "INVALID_USAGE_OF_STAR_OR_REGEX", - parameters = Map( - "elem" -> "'*'", - "prettyName" -> "query" - ) - ) } test("Multipart attribute names") { - val nameScope = new NameScope - - nameScope("table") = Seq(col6IntegerWithQualifier) + val nameScope = new NameScope(Seq(col6IntegerWithQualifier)) for (multipartIdentifier <- Seq( Seq("catalog", "database", "table", "col6"), @@ -405,11 +241,9 @@ class NameScopeSuite extends PlanTest with SQLConfHelper { Seq("table", "col6") )) { assert( - nameScope.matchMultipartName(multipartIdentifier) == NameTarget( - candidates = Seq( - col6IntegerWithQualifier - ), - allAttributes = Seq(col6IntegerWithQualifier) + nameScope.resolveMultipartName(multipartIdentifier) == NameTarget( + candidates = Seq(col6IntegerWithQualifier), + output = Seq(col6IntegerWithQualifier) ) ) } @@ -420,34 +254,34 @@ class NameScopeSuite extends PlanTest with SQLConfHelper { Seq("table.col6") )) { assert( - nameScope.matchMultipartName(multipartIdentifier) == NameTarget( + nameScope.resolveMultipartName(multipartIdentifier) == NameTarget( candidates = Seq.empty, - allAttributes = Seq(col6IntegerWithQualifier) + output = Seq(col6IntegerWithQualifier) ) ) } } test("Nested fields") { - val nameScope = new NameScope - - nameScope("table") = Seq( - col8Struct, - col9NestedStruct, - col10Map, - col11MapWithStruct, - col12Array, - col13ArrayWithStruct + var nameScope = new NameScope( + Seq( + col8Struct, + col9NestedStruct, + col10Map, + col11MapWithStruct, + col12Array, + col13ArrayWithStruct + ) ) - var matchedStructs = nameScope.matchMultipartName(Seq("col8", "field")) + var matchedStructs = nameScope.resolveMultipartName(Seq("col8", "field")) assert( matchedStructs == NameTarget( candidates = Seq( GetStructField(col8Struct, 0, Some("field")) ), aliasName = Some("field"), - allAttributes = Seq( + output = Seq( col8Struct, col9NestedStruct, col10Map, @@ -458,7 +292,7 @@ class NameScopeSuite extends PlanTest with SQLConfHelper { ) ) - matchedStructs = nameScope.matchMultipartName(Seq("col9", "field", "subfield")) + matchedStructs = nameScope.resolveMultipartName(Seq("col9", "field", "subfield")) assert( matchedStructs == NameTarget( candidates = Seq( @@ -473,7 +307,7 @@ class NameScopeSuite extends PlanTest with SQLConfHelper { ) ), aliasName = Some("subfield"), - allAttributes = Seq( + output = Seq( col8Struct, col9NestedStruct, col10Map, @@ -484,12 +318,12 @@ class NameScopeSuite extends PlanTest with SQLConfHelper { ) ) - var matchedMaps = nameScope.matchMultipartName(Seq("col10", "key")) + var matchedMaps = nameScope.resolveMultipartName(Seq("col10", "key")) assert( matchedMaps == NameTarget( candidates = Seq(GetMapValue(col10Map, Literal("key"))), aliasName = Some("key"), - allAttributes = Seq( + output = Seq( col8Struct, col9NestedStruct, col10Map, @@ -500,12 +334,12 @@ class NameScopeSuite extends PlanTest with SQLConfHelper { ) ) - matchedMaps = nameScope.matchMultipartName(Seq("col11", "key")) + matchedMaps = nameScope.resolveMultipartName(Seq("col11", "key")) assert( matchedMaps == NameTarget( candidates = 
Seq(GetMapValue(col11MapWithStruct, Literal("key"))), aliasName = Some("key"), - allAttributes = Seq( + output = Seq( col8Struct, col9NestedStruct, col10Map, @@ -516,12 +350,12 @@ class NameScopeSuite extends PlanTest with SQLConfHelper { ) ) - var matchedArrays = nameScope.matchMultipartName(Seq("col12", "element")) + var matchedArrays = nameScope.resolveMultipartName(Seq("col12", "element")) assert( matchedArrays == NameTarget( candidates = Seq(GetArrayItem(col12Array, Literal("element"))), aliasName = Some("element"), - allAttributes = Seq( + output = Seq( col8Struct, col9NestedStruct, col10Map, @@ -532,7 +366,7 @@ class NameScopeSuite extends PlanTest with SQLConfHelper { ) ) - matchedArrays = nameScope.matchMultipartName(Seq("col13", "field")) + matchedArrays = nameScope.resolveMultipartName(Seq("col13", "field")) assert( matchedArrays == NameTarget( candidates = Seq( @@ -545,7 +379,7 @@ class NameScopeSuite extends PlanTest with SQLConfHelper { ) ), aliasName = Some("field"), - allAttributes = Seq( + output = Seq( col8Struct, col9NestedStruct, col10Map, @@ -556,16 +390,11 @@ class NameScopeSuite extends PlanTest with SQLConfHelper { ) ) - nameScope("table2") = Seq(col8Struct) - matchedStructs = nameScope.matchMultipartName(Seq("col8", "field")) + nameScope = new NameScope(nameScope.output :+ col8Struct) + matchedStructs = nameScope.resolveMultipartName(Seq("col8", "field")) assert( matchedStructs == NameTarget( candidates = Seq( - GetStructField( - col8Struct, - 0, - Some("field") - ), GetStructField( col8Struct, 0, @@ -573,7 +402,7 @@ class NameScopeSuite extends PlanTest with SQLConfHelper { ) ), aliasName = Some("field"), - allAttributes = Seq( + output = Seq( col8Struct, col9NestedStruct, col10Map, @@ -585,6 +414,27 @@ class NameScopeSuite extends PlanTest with SQLConfHelper { ) ) } + + /** + * Check both [[resolveMultipartName]] and [[findAttributesByName]] for a single part name. + * + * [[resolveMultipartName]] respects the case sensitivity of the input name, and candidates are + * going to have a new name which is case-identical to the queried name, while + * [[findAttributesByName]] is just a simple case-insensitive lookup. Also, + * [[resolveMultipartName]] deduplicates the candidates.
+ */ + private def checkOnePartNameLookup( + nameScope: NameScope, + name: String, + candidates: Seq[Attribute]): Unit = { + assert( + nameScope.resolveMultipartName(Seq(name)) == NameTarget( + candidates = candidates.distinct.map(attribute => attribute.withName(name)), + output = nameScope.output + ) + ) + assert(nameScope.findAttributesByName(name) == candidates) + } } class NameScopeStackSuite extends PlanTest { @@ -596,64 +446,84 @@ class NameScopeStackSuite extends PlanTest { test("Empty stack") { val stack = new NameScopeStack - assert(stack.top.getAllAttributes.isEmpty) + assert(stack.top.output.isEmpty) + } + + test("Overwrite top with empty sequence") { + val stack = new NameScopeStack + + stack.overwriteTop(Seq.empty) + assert(stack.top.output == Seq.empty) } test("Overwrite top of the stack containing single scope") { val stack = new NameScopeStack - stack.top.update("table1", Seq(col1Integer, col2String)) - assert(stack.top.getAllAttributes == Seq(col1Integer, col2String)) + stack.overwriteTop(Seq(col1Integer, col2String)) + assert(stack.top.output == Seq(col1Integer, col2String)) - stack.overwriteTop("table2", Seq(col3Integer, col4String)) - assert(stack.top.getAllAttributes == Seq(col3Integer, col4String)) + stack.overwriteTop(Seq(col3Integer, col4String)) + assert(stack.top.output == Seq(col3Integer, col4String)) stack.overwriteTop(Seq(col2String)) - assert(stack.top.getAllAttributes == Seq(col2String)) + assert(stack.top.output == Seq(col2String)) } test("Overwrite top of the stack containing several scopes") { val stack = new NameScopeStack - stack.top.update("table2", Seq(col3Integer)) + stack.overwriteTop(Seq(col3Integer)) - stack.withNewScope { - assert(stack.top.getAllAttributes.isEmpty) + val output = stack.withNewScope { + assert(stack.top.output.isEmpty) - stack.top.update("table1", Seq(col1Integer, col2String)) - assert(stack.top.getAllAttributes == Seq(col1Integer, col2String)) + stack.overwriteTop(Seq(col1Integer, col2String)) + assert(stack.top.output == Seq(col1Integer, col2String)) - stack.overwriteTop("table2", Seq(col3Integer, col4String)) - assert(stack.top.getAllAttributes == Seq(col3Integer, col4String)) + stack.overwriteTop(Seq(col3Integer, col4String)) + assert(stack.top.output == Seq(col3Integer, col4String)) stack.overwriteTop(Seq(col2String)) - assert(stack.top.getAllAttributes == Seq(col2String)) + assert(stack.top.output == Seq(col2String)) + + stack.top.output } + + assert(output == Seq(col2String)) } test("Scope stacking") { val stack = new NameScopeStack - stack.top.update("table1", Seq(col1Integer)) + stack.overwriteTop(Seq(col1Integer)) - stack.withNewScope { - stack.top.update("table2", Seq(col2String)) + val output = stack.withNewScope { + stack.overwriteTop(Seq(col2String)) + + val output = stack.withNewScope { + stack.overwriteTop(Seq(col3Integer)) - stack.withNewScope { - stack.top.update("table3", Seq(col3Integer)) + val output = stack.withNewScope { + stack.overwriteTop(Seq(col4String)) - stack.withNewScope { - stack.top.update("table4", Seq(col4String)) + assert(stack.top.output == Seq(col4String)) - assert(stack.top.getAllAttributes == Seq(col4String)) + stack.top.output } - assert(stack.top.getAllAttributes == Seq(col3Integer)) + assert(output == Seq(col4String)) + assert(stack.top.output == Seq(col3Integer)) + + stack.top.output } - assert(stack.top.getAllAttributes == Seq(col2String)) + assert(output == Seq(col3Integer)) + assert(stack.top.output == Seq(col2String)) + + stack.top.output } - assert(stack.top.getAllAttributes == 
Seq(col1Integer)) + assert(output == Seq(col2String)) + assert(stack.top.output == Seq(col1Integer)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/analysis/resolver/ResolverGuardSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/analysis/resolver/ResolverGuardSuite.scala index d512adbb0af37..58048763f76db 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/analysis/resolver/ResolverGuardSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/analysis/resolver/ResolverGuardSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.analysis.resolver import org.apache.spark.sql.QueryTest import org.apache.spark.sql.catalyst.analysis.resolver.ResolverGuard +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project} import org.apache.spark.sql.test.SharedSparkSession class ResolverGuardSuite extends QueryTest with SharedSparkSession { @@ -134,8 +135,6 @@ class ResolverGuardSuite extends QueryTest with SharedSparkSession { checkResolverGuard("SELECT 1 AS alias", shouldPass = true) } - // Queries that shouldn't pass the OperatorResolverGuard - test("Select from table") { withTable("test_table") { sql("CREATE TABLE test_table (col1 INT, col2 INT)") @@ -151,6 +150,43 @@ class ResolverGuardSuite extends QueryTest with SharedSparkSession { checkResolverGuard("SELECT * FROM (SELECT * FROM (SELECT * FROM VALUES(1)))", shouldPass = true) } + test("Union all") { + checkResolverGuard( + "SELECT * FROM VALUES(1) UNION ALL SELECT * FROM VALUES(2)", + shouldPass = true + ) + } + + test("CTE") { + checkResolverGuard( + """ + WITH cte1 AS ( + SELECT * FROM VALUES (1) + ), + cte2 AS ( + SELECT * FROM VALUES (2) + ) + SELECT * FROM cte1 + UNION ALL + SELECT * FROM cte2 + """, + shouldPass = true + ) + } + + test("Subquery column aliases") { + checkResolverGuard( + "SELECT t.a, t.b FROM VALUES (1, 2) t (a, b)", + shouldPass = true + ) + } + + test("Function") { + checkResolverGuard("SELECT assert_true(true)", shouldPass = true) + } + + // Queries that shouldn't pass the OperatorResolverGuard + test("Scalar subquery") { checkResolverGuard("SELECT (SELECT * FROM VALUES(1))", shouldPass = false) } @@ -169,10 +205,6 @@ class ResolverGuardSuite extends QueryTest with SharedSparkSession { ) } - test("Function") { - checkResolverGuard("SELECT current_date()", shouldPass = false) - } - test("Function without the braces") { checkResolverGuard("SELECT current_date", shouldPass = false) } @@ -189,11 +221,29 @@ class ResolverGuardSuite extends QueryTest with SharedSparkSession { } } + test("Union distinct") { + checkResolverGuard( + "SELECT * FROM VALUES (1) UNION DISTINCT SELECT * FROM VALUES (2)", + shouldPass = true + ) + } + + test("PLAN_ID_TAG") { + val plan = spark.sessionState.sqlParser.parsePlan("SELECT col1 FROM VALUES (1)") + + val planId: Long = 0 + plan.asInstanceOf[Project].projectList.head.setTagValue(LogicalPlan.PLAN_ID_TAG, planId) + + checkResolverGuard(plan, shouldPass = false) + } + private def checkResolverGuard(query: String, shouldPass: Boolean): Unit = { + checkResolverGuard(spark.sql(query).queryExecution.logical, shouldPass) + } + + private def checkResolverGuard(plan: LogicalPlan, shouldPass: Boolean): Unit = { val resolverGuard = new ResolverGuard(spark.sessionState.catalogManager) - assert( - resolverGuard.apply(sql(query).queryExecution.logical) == shouldPass - ) + assert(resolverGuard.apply(plan) == shouldPass) } private def withSessionVariable(body: => Unit): Unit = { diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/analysis/resolver/ResolverSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/analysis/resolver/ResolverSuite.scala index 057724758d332..8a54f65209748 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/analysis/resolver/ResolverSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/analysis/resolver/ResolverSuite.scala @@ -19,9 +19,14 @@ package org.apache.spark.sql.analysis.resolver import org.apache.spark.sql.{AnalysisException, QueryTest} import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute -import org.apache.spark.sql.catalyst.analysis.resolver.{Resolver, ResolverExtension} +import org.apache.spark.sql.catalyst.analysis.resolver.{ + LogicalPlanResolver, + Resolver, + ResolverExtension +} import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project} +import org.apache.spark.sql.catalyst.plans.NormalizePlan +import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Project} import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types.IntegerType @@ -42,8 +47,9 @@ class ResolverSuite extends QueryTest with SharedSparkSession { TestRelation(resolutionDone = false, output = Seq(col1Integer)) ) ) - assert( - result == Project( + assertPlansEqual( + result, + Project( Seq(col1Integer), TestRelation(resolutionDone = true, output = Seq(col1Integer)) ) @@ -70,7 +76,7 @@ class ResolverSuite extends QueryTest with SharedSparkSession { condition = "UNSUPPORTED_SINGLE_PASS_ANALYZER_FEATURE", parameters = Map( "feature" -> ("class " + - "org.apache.spark.sql.analysis.resolver.ResolverSuite$UnknownRelation operator resolution") + "org.apache.spark.sql.analysis.resolver.UnknownRelation operator resolution") ) ) } @@ -80,7 +86,7 @@ class ResolverSuite extends QueryTest with SharedSparkSession { Seq( new NoopResolver, new TestRelationResolver, - new TestRelationBrokenResolver + new TestRelationOtherResolver ) ) @@ -95,8 +101,8 @@ class ResolverSuite extends QueryTest with SharedSparkSession { ), condition = "AMBIGUOUS_RESOLVER_EXTENSION", parameters = Map( - "operator" -> "org.apache.spark.sql.analysis.resolver.ResolverSuite$TestRelation", - "extensions" -> "TestRelationResolver, TestRelationBrokenResolver" + "operator" -> "org.apache.spark.sql.analysis.resolver.TestRelation", + "extensions" -> "TestRelationResolver, TestRelationOtherResolver" ) ) } @@ -108,9 +114,13 @@ class ResolverSuite extends QueryTest with SharedSparkSession { private class TestRelationResolver extends ResolverExtension { var timesCalled = 0 - override def resolveOperator: PartialFunction[LogicalPlan, LogicalPlan] = { + override def resolveOperator( + operator: LogicalPlan, + resolver: LogicalPlanResolver): Option[LogicalPlan] = operator match { case testNode: TestRelation if countTimesCalled() => - testNode.copy(resolutionDone = true) + Some(testNode.copy(resolutionDone = true)) + case _ => + None } private def countTimesCalled(): Boolean = { @@ -120,38 +130,35 @@ class ResolverSuite extends QueryTest with SharedSparkSession { } } - private class TestRelationBrokenResolver extends ResolverExtension { - override def resolveOperator: PartialFunction[LogicalPlan, LogicalPlan] = { + private class TestRelationOtherResolver extends ResolverExtension { + override def resolveOperator( + operator: LogicalPlan, + resolver: LogicalPlanResolver): Option[LogicalPlan] = operator match { case testNode: TestRelation => - 
assert(false) - testNode + Some(testNode) + case _ => + None } } private class NoopResolver extends ResolverExtension { - override def resolveOperator: PartialFunction[LogicalPlan, LogicalPlan] = { + override def resolveOperator( + operator: LogicalPlan, + resolver: LogicalPlanResolver): Option[LogicalPlan] = operator match { case node: LogicalPlan if false => assert(false) - node + Some(node) + case _ => + None } } - private case class TestRelation( - resolutionDone: Boolean, - override val output: Seq[Attribute], - override val children: Seq[LogicalPlan] = Seq.empty) - extends LogicalPlan { - override protected def withNewChildrenInternal( - newChildren: IndexedSeq[LogicalPlan]): TestRelation = - copy(children = newChildren) - } - - private case class UnknownRelation( - override val output: Seq[Attribute], - override val children: Seq[LogicalPlan] = Seq.empty) - extends LogicalPlan { - override protected def withNewChildrenInternal( - newChildren: IndexedSeq[LogicalPlan]): UnknownRelation = - copy(children = newChildren) + private def assertPlansEqual(actualPlan: LogicalPlan, expectedPlan: LogicalPlan) = { + assert(NormalizePlan(actualPlan) == NormalizePlan(expectedPlan)) } } + +private case class TestRelation(resolutionDone: Boolean, override val output: Seq[Attribute]) + extends LeafNode {} + +private case class UnknownRelation(override val output: Seq[Attribute]) extends LeafNode {} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/analysis/resolver/TracksResolvedNodesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/analysis/resolver/TracksResolvedNodesSuite.scala deleted file mode 100644 index b7bf73f326fa8..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/analysis/resolver/TracksResolvedNodesSuite.scala +++ /dev/null @@ -1,135 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.analysis.resolver - -import org.apache.spark.SparkException -import org.apache.spark.sql.QueryTest -import org.apache.spark.sql.catalyst.analysis.FunctionResolution -import org.apache.spark.sql.catalyst.analysis.resolver.{ - ExpressionResolver, - NameScopeStack, - PlanLogger, - Resolver -} -import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Cast, ExprId} -import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project} -import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.test.SharedSparkSession -import org.apache.spark.sql.types.{BooleanType, StringType} - -class TracksResolvedNodesSuite extends QueryTest with SharedSparkSession { - - override def beforeAll(): Unit = { - super.beforeAll() - spark.conf.set(SQLConf.ANALYZER_SINGLE_PASS_TRACK_RESOLVED_NODES_ENABLED.key, "true") - } - - test("Single-pass contract preserved for equal expressions with different memory addresses") { - val expressionResolver = createExpressionResolver() - val columnObjFirst = - AttributeReference(name = "column", dataType = BooleanType)(exprId = ExprId(0)) - val columnObjSecond = - AttributeReference(name = "column", dataType = BooleanType)(exprId = ExprId(0)) - - expressionResolver.resolve(columnObjFirst) - expressionResolver.resolve(columnObjSecond) - } - - test("Single-pass contract broken for operators") { - val resolver = createResolver() - - val project = Project( - projectList = Seq(), - child = Project( - projectList = Seq(), - child = OneRowRelation() - ) - ) - - val resolvedProject = resolver.lookupMetadataAndResolve(project) - - checkError( - exception = intercept[SparkException]({ - resolver.lookupMetadataAndResolve(resolvedProject.children.head) - }), - condition = "INTERNAL_ERROR", - parameters = Map( - "message" -> ("Single-pass resolver attempted to resolve the same " + - "node more than once: Project\n+- OneRowRelation\n") - ) - ) - checkError( - exception = intercept[SparkException]({ - resolver.lookupMetadataAndResolve(resolvedProject) - }), - condition = "INTERNAL_ERROR", - parameters = Map( - "message" -> ("Single-pass resolver attempted to resolve the same " + - "node more than once: Project\n+- Project\n +- OneRowRelation\n") - ) - ) - } - - test("Single-pass contract broken for expressions") { - val expressionResolver = createExpressionResolver() - - val cast = Cast( - child = AttributeReference(name = "column", dataType = BooleanType)(exprId = ExprId(0)), - dataType = StringType - ) - - val resolvedCast = expressionResolver.resolve(cast) - - checkError( - exception = intercept[SparkException]({ - expressionResolver.resolve(resolvedCast.children.head) - }), - condition = "INTERNAL_ERROR", - parameters = Map( - "message" -> ("Single-pass resolver attempted " + - "to resolve the same node more than once: column#0") - ) - ) - checkError( - exception = intercept[SparkException]({ - expressionResolver.resolve(resolvedCast) - }), - condition = "INTERNAL_ERROR", - parameters = Map( - "message" -> ("Single-pass resolver attempted " + - "to resolve the same node more than once: cast(column#0 as string)") - ) - ) - } - - private def createResolver(): Resolver = { - new Resolver(spark.sessionState.catalogManager) - } - - private def createExpressionResolver(): ExpressionResolver = { - new ExpressionResolver( - createResolver(), - new NameScopeStack, - new FunctionResolution( - spark.sessionState.catalogManager, - Resolver.createRelationResolution(spark.sessionState.catalogManager) - ), - new PlanLogger - ) - } -} 
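
Note on the reworked extension API exercised in the suites above (ResolverSuite) and below (DataSourceWithHiveResolver): resolveOperator is no longer a PartialFunction but a method returning Option[LogicalPlan], so an extension declines an operator by returning None. A minimal sketch of an extension under that contract; the class name and the choice of LocalRelation as the matched node are illustrative only and not part of this patch:

import org.apache.spark.sql.catalyst.analysis.resolver.{LogicalPlanResolver, ResolverExtension}
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}

// Hypothetical extension: claims LocalRelation nodes and returns them unchanged,
// yielding None for every other operator so that other extensions can be tried.
class PassThroughLocalRelationResolver extends ResolverExtension {
  override def resolveOperator(
      operator: LogicalPlan,
      resolver: LogicalPlanResolver): Option[LogicalPlan] = operator match {
    case relation: LocalRelation => Some(relation)
    case _ => None
  }
}
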
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/analysis/resolver/ViewResolverSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/analysis/resolver/ViewResolverSuite.scala new file mode 100644 index 0000000000000..2c8e240a8fd0b --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/analysis/resolver/ViewResolverSuite.scala @@ -0,0 +1,194 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.analysis.resolver + +import org.apache.spark.sql.{AnalysisException, QueryTest} +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} +import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation +import org.apache.spark.sql.catalyst.analysis.resolver.{MetadataResolver, Resolver} +import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, Cast, NamedExpression} +import org.apache.spark.sql.catalyst.plans.logical.{ + LocalRelation, + LogicalPlan, + OneRowRelation, + Project, + SubqueryAlias, + View +} +import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.types.{IntegerType, StringType} +import org.apache.spark.util.Utils + +class ViewResolverSuite extends QueryTest with SharedSparkSession { + private val catalogName = if (Utils.unityCatalogTestsEnabled) { + "main" + } else { + "spark_catalog" + } + private val col1Integer = + AttributeReference(name = "col1", dataType = IntegerType, nullable = false)() + private val col2String = + AttributeReference(name = "col2", dataType = StringType, nullable = false)() + + test("Temporary view") { + withView("temporary_view") { + spark.sql("CREATE TEMPORARY VIEW temporary_view AS SELECT col1, col2 FROM VALUES (1, 'a');") + + checkViewResolution( + "SELECT * FROM temporary_view", + expectedChild = Project( + projectList = Seq( + Alias(Cast(col1Integer, IntegerType).withTimeZone(conf.sessionLocalTimeZone), "col1")(), + Alias(Cast(col2String, StringType).withTimeZone(conf.sessionLocalTimeZone), "col2")() + ), + child = Project( + projectList = Seq(col1Integer, col2String), + child = LocalRelation( + output = Seq(col1Integer, col2String), + data = Seq( + InternalRow.fromSeq(Seq(1, "a").map(CatalystTypeConverters.convertToCatalyst(_))) + ) + ) + ) + ) + ) + } + } + + test("Persistent view") { + withView("persistent_view") { + spark.sql("CREATE VIEW persistent_view AS SELECT col1, col2 FROM VALUES (1, 'a');") + + checkViewResolution( + "SELECT * FROM persistent_view", + expectedChild = Project( + projectList = Seq( + Alias(Cast(col1Integer, IntegerType).withTimeZone(conf.sessionLocalTimeZone), "col1")(), + Alias(Cast(col2String, StringType).withTimeZone(conf.sessionLocalTimeZone), "col2")() + ), + child = Project( + projectList = Seq(col1Integer, col2String), + child = LocalRelation( + output = 
Seq(col1Integer, col2String), + data = Seq( + InternalRow.fromSeq(Seq(1, "a").map(CatalystTypeConverters.convertToCatalyst(_))) + ) + ) + ) + ) + ) + } + } + + test("Nested views resolution failed") { + withTable("table1") { + spark.sql("CREATE TABLE table1 (col1 INT, col2 STRING);") + withView("view1") { + spark.sql("CREATE VIEW view1 AS SELECT col1, col2 FROM table1;") + withView("view2") { + spark.sql("CREATE VIEW view2 AS SELECT col2, col1 FROM view1;") + withView("view3") { + spark.sql("CREATE VIEW view3 AS SELECT col1, col2 FROM view2;") + + spark.sql("DROP TABLE table1;") + + checkErrorTableNotFound( + exception = intercept[AnalysisException] { + checkViewResolution("SELECT * FROM view3") + }, + tableName = "`table1`", + queryContext = ExpectedContext( + fragment = "view3", + start = 14, + stop = 18 + ) + ) + } + } + } + } + } + + test("Max nested view depth exceeded") { + try { + spark.sql("CREATE VIEW v0 AS SELECT * FROM VALUES (1);") + for (i <- 0 until conf.maxNestedViewDepth) { + spark.sql(s"CREATE VIEW v${i + 1} AS SELECT * FROM v${i};") + } + + checkError( + exception = intercept[AnalysisException] { + checkViewResolution(s"SELECT * FROM v${conf.maxNestedViewDepth}") + }, + condition = "VIEW_EXCEED_MAX_NESTED_DEPTH", + parameters = Map( + "viewName" -> s"`$catalogName`.`default`.`v0`", + "maxNestedDepth" -> conf.maxNestedViewDepth.toString + ), + context = ExpectedContext( + fragment = "v100", + start = 14, + stop = 17 + ) + ) + } finally { + for (i <- 0 until (conf.maxNestedViewDepth + 1)) { + spark.sql(s"DROP VIEW v${conf.maxNestedViewDepth - i};") + } + } + } + + private def checkViewResolution( + sqlText: String, + expectedChild: LogicalPlan = OneRowRelation()) = { + val metadataResolver = new MetadataResolver( + spark.sessionState.catalogManager, + Resolver.createRelationResolution(spark.sessionState.catalogManager) + ) + + val unresolvedPlan = spark.sessionState.sqlParser.parsePlan(sqlText) + + metadataResolver.resolve(unresolvedPlan) + + val unresolvedRelations = unresolvedPlan.collect { + case unresolvedRelation: UnresolvedRelation => unresolvedRelation + } + assert(unresolvedRelations.size == 1) + + val unresolvedView = metadataResolver + .getRelationWithResolvedMetadata(unresolvedRelations.head) + .get + .asInstanceOf[SubqueryAlias] + .child + .asInstanceOf[View] + + val resolver = new Resolver(spark.sessionState.catalogManager) + + val resolvedView = resolver + .lookupMetadataAndResolve(unresolvedPlan) + .asInstanceOf[Project] + .child + .asInstanceOf[SubqueryAlias] + .child + .asInstanceOf[View] + assert(resolvedView.isTempView == unresolvedView.isTempView) + assert( + normalizeExprIds(resolvedView.child).prettyJson == normalizeExprIds(expectedChild).prettyJson + ) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceResolverSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceResolverSuite.scala index 016c1e2f5457d..78b6fcc59b023 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceResolverSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceResolverSuite.scala @@ -19,7 +19,11 @@ package org.apache.spark.sql.execution.datasources import org.apache.spark.sql.QueryTest import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation -import org.apache.spark.sql.catalyst.analysis.resolver.{MetadataResolver, Resolver} +import org.apache.spark.sql.catalyst.analysis.resolver.{ + MetadataResolver, + 
ProhibitedResolver, + Resolver +} import org.apache.spark.sql.catalyst.catalog.UnresolvedCatalogRelation import org.apache.spark.sql.catalyst.plans.logical.SubqueryAlias import org.apache.spark.sql.test.SharedSparkSession @@ -107,7 +111,8 @@ class DataSourceResolverSuite extends QueryTest with SharedSparkSession { .child assert(partiallyResolvedRelation.isInstanceOf[UnresolvedCatalogRelation]) - val result = dataSourceResolver.resolveOperator(partiallyResolvedRelation) + val result = + dataSourceResolver.resolveOperator(partiallyResolvedRelation, new ProhibitedResolver).get val logicalRelation = result.asInstanceOf[LogicalRelation] assert( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileResolverSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileResolverSuite.scala index 1d1b228028bdb..d6c9d497bce44 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileResolverSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileResolverSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.datasources import org.apache.spark.sql.QueryTest import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation +import org.apache.spark.sql.catalyst.analysis.resolver.ProhibitedResolver import org.apache.spark.sql.catalyst.plans.logical.Project import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types.{LongType, StringType, StructType} @@ -83,9 +84,12 @@ class FileResolverSuite extends QueryTest with SharedSparkSession { val unresolvedPlan = spark.sql(sqlText).queryExecution.logical - val result = fileResolver.resolveOperator( - unresolvedPlan.asInstanceOf[Project].child.asInstanceOf[UnresolvedRelation] - ) + val result = fileResolver + .resolveOperator( + unresolvedPlan.asInstanceOf[Project].child.asInstanceOf[UnresolvedRelation], + new ProhibitedResolver + ) + .get val logicalRelation = result.asInstanceOf[LogicalRelation] assert( diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/DataSourceWithHiveResolver.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/DataSourceWithHiveResolver.scala index acbc0cee0e301..d716318a7dc86 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/DataSourceWithHiveResolver.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/DataSourceWithHiveResolver.scala @@ -17,10 +17,11 @@ package org.apache.spark.sql.hive +import org.apache.spark.sql.catalyst.analysis.resolver.LogicalPlanResolver import org.apache.spark.sql.catalyst.catalog.HiveTableRelation import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.classic.SparkSession -import org.apache.spark.sql.execution.datasources.{DataSourceResolver, LogicalRelation} +import org.apache.spark.sql.execution.datasources.DataSourceResolver /** * [[DataSourceWithHiveResolver]] is a [[DataSourceResolver]] that additionally handles @@ -34,24 +35,27 @@ class DataSourceWithHiveResolver(sparkSession: SparkSession, hiveCatalog: HiveSe * Invoke [[DataSourceResolver]] to resolve the input operator. If [[DataSourceResolver]] produces * [[HiveTableRelation]], convert it to [[LogicalRelation]] if possible. 
*/ - override def resolveOperator: PartialFunction[LogicalPlan, LogicalPlan] = { - case operator: LogicalPlan if super.resolveOperator.isDefinedAt(operator) => - val relationAfterDataSourceResolver = super.resolveOperator(operator) - - relationAfterDataSourceResolver match { - case hiveTableRelation: HiveTableRelation => - resolveHiveTableRelation(hiveTableRelation) - case other => other - } + override def resolveOperator( + operator: LogicalPlan, + resolver: LogicalPlanResolver): Option[LogicalPlan] = { + super.resolveOperator(operator, resolver) match { + case Some(relationAfterDataSourceResolver) => + val result = relationAfterDataSourceResolver match { + case hiveTableRelation: HiveTableRelation => + resolveHiveTableRelation(hiveTableRelation) + case other => other + } + Some(result) + case _ => + None + } } private def resolveHiveTableRelation(hiveTableRelation: HiveTableRelation): LogicalPlan = { if (relationConversions.doConvertHiveTableRelationForRead(hiveTableRelation)) { - val logicalRelation: LogicalRelation = - relationConversions.convertHiveTableRelationForRead(hiveTableRelation) - logicalRelation.newInstance() + relationConversions.convertHiveTableRelationForRead(hiveTableRelation) } else { - hiveTableRelation.newInstance() + hiveTableRelation } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/DataSourceWithHiveResolverSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/DataSourceWithHiveResolverSuite.scala index cb26354521b02..86f43796bca8c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/DataSourceWithHiveResolverSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/DataSourceWithHiveResolverSuite.scala @@ -18,11 +18,15 @@ package org.apache.spark.sql.hive import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation -import org.apache.spark.sql.catalyst.analysis.resolver.{MetadataResolver, Resolver} +import org.apache.spark.sql.catalyst.analysis.resolver.{ + MetadataResolver, + ProhibitedResolver, + Resolver +} import org.apache.spark.sql.catalyst.catalog.{HiveTableRelation, UnresolvedCatalogRelation} import org.apache.spark.sql.catalyst.plans.logical.SubqueryAlias import org.apache.spark.sql.execution.datasources.LogicalRelation -import org.apache.spark.sql.hive.HiveUtils +import org.apache.spark.sql.hive.{HiveSessionCatalogUtils, HiveUtils} import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} @@ -93,7 +97,9 @@ class DataSourceWithHiveResolverSuite extends TestHiveSingleton with SQLTestUtil .child assert(partiallyResolvedRelation.isInstanceOf[UnresolvedCatalogRelation]) - dataSourceWithHiveResolver.resolveOperator(partiallyResolvedRelation) match { + dataSourceWithHiveResolver + .resolveOperator(partiallyResolvedRelation, new ProhibitedResolver) + .get match { case logicalRelation: LogicalRelation => assert(convertedToLogicalRelation) assert(logicalRelation.catalogTable.get.identifier.unquotedString == expectedTableName)