From 62ade5f3f42bdad200bfd9ca9e8110594f7c12e4 Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Sat, 12 Oct 2024 20:33:33 +0900 Subject: [PATCH] [SPARK-49924][SQL] Keep `containsNull` after `ArrayCompact` replacement ### What changes were proposed in this pull request? Fix `containsNull` of `ArrayCompact`, by adding a new expression `KnownNotContainsNull` ### Why are the changes needed? https://github.com/apache/spark/pull/47430 attempted to set `containsNull = false` for `ArrayCompact` for further optimization, but in an incomplete way: `ArrayCompact` is a runtime-replaceable expression, so it will be replaced in the optimizer, causing the `containsNull` to be reverted, e.g. ```sql select array_compact(array(1, null)) ``` Rule `ReplaceExpressions` changed `containsNull: false -> true` ``` old schema: StructField(array_compact(array(1, NULL)),ArrayType(IntegerType,false),false) new schema StructField(array_compact(array(1, NULL)),ArrayType(IntegerType,true),false) ``` ### Does this PR introduce _any_ user-facing change? no ### How was this patch tested? added tests ### Was this patch authored or co-authored using generative AI tooling? no Closes #48410 from zhengruifeng/fix_array_compact_null. 
Authored-by: Ruifeng Zheng Signed-off-by: Hyukjin Kwon --- .../expressions/collectionOperations.scala | 6 ++--- .../expressions/constraintExpressions.scala | 13 +++++++++- .../catalyst/optimizer/OptimizerSuite.scala | 25 +++++++++++++++++-- .../function_array_compact.explain | 2 +- 4 files changed, 38 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index c091d51fc177f..bb54749126860 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -27,6 +27,7 @@ import org.apache.spark.SparkException.internalError import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion, UnresolvedAttribute, UnresolvedSeed} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch +import org.apache.spark.sql.catalyst.expressions.KnownNotContainsNull import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke @@ -5330,15 +5331,12 @@ case class ArrayCompact(child: Expression) child.dataType.asInstanceOf[ArrayType].elementType, true) lazy val lambda = LambdaFunction(isNotNull(lv), Seq(lv)) - override lazy val replacement: Expression = ArrayFilter(child, lambda) + override lazy val replacement: Expression = KnownNotContainsNull(ArrayFilter(child, lambda)) override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType) override def prettyName: String = "array_compact" - override def dataType: ArrayType = - child.dataType.asInstanceOf[ArrayType].copy(containsNull = false) - override protected def withNewChildInternal(newChild: 
Expression): ArrayCompact = copy(child = newChild) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraintExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraintExpressions.scala index 75d912633a0fc..f05db0b090c90 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraintExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraintExpressions.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, FalseLiteral} -import org.apache.spark.sql.types.DataType +import org.apache.spark.sql.types.{ArrayType, DataType} trait TaggingExpression extends UnaryExpression { override def nullable: Boolean = child.nullable @@ -52,6 +52,17 @@ case class KnownNotNull(child: Expression) extends TaggingExpression { copy(child = newChild) } +case class KnownNotContainsNull(child: Expression) extends TaggingExpression { + override def dataType: DataType = + child.dataType.asInstanceOf[ArrayType].copy(containsNull = false) + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = + child.genCode(ctx) + + override protected def withNewChildInternal(newChild: Expression): KnownNotContainsNull = + copy(child = newChild) +} + case class KnownFloatingPointNormalized(child: Expression) extends TaggingExpression { override protected def withNewChildInternal(newChild: Expression): KnownFloatingPointNormalized = copy(child = newChild) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerSuite.scala index 48cdbbe7be539..70a2ae94109fc 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerSuite.scala +++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerSuite.scala @@ -21,13 +21,13 @@ import org.apache.spark.SparkException import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.expressions.{Add, Alias, AttributeReference, IntegerLiteral, Literal, Multiply, NamedExpression, Remainder} +import org.apache.spark.sql.catalyst.expressions.{Add, Alias, ArrayCompact, AttributeReference, CreateArray, CreateStruct, IntegerLiteral, Literal, MapFromEntries, Multiply, NamedExpression, Remainder} import org.apache.spark.sql.catalyst.expressions.aggregate.Sum import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LocalRelation, LogicalPlan, OneRowRelation, Project} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.IntegerType +import org.apache.spark.sql.types.{ArrayType, IntegerType, MapType, StructField, StructType} /** * A dummy optimizer rule for testing that decrements integer literals until 0. 
@@ -313,4 +313,25 @@ class OptimizerSuite extends PlanTest { assert(message1.contains("not a valid aggregate expression")) } } + + test("SPARK-49924: Keep containsNull after ArrayCompact replacement") { + val optimizer = new SimpleTestOptimizer() { + override def defaultBatches: Seq[Batch] = + Batch("test", fixedPoint, + ReplaceExpressions) :: Nil + } + + val array1 = ArrayCompact(CreateArray(Literal(1) :: Literal.apply(null) :: Nil, false)) + val plan1 = Project(Alias(array1, "arr")() :: Nil, OneRowRelation()).analyze + val optimized1 = optimizer.execute(plan1) + assert(optimized1.schema === + StructType(StructField("arr", ArrayType(IntegerType, false), false) :: Nil)) + + val struct = CreateStruct(Literal(1) :: Literal(2) :: Nil) + val array2 = ArrayCompact(CreateArray(struct :: Literal.apply(null) :: Nil, false)) + val plan2 = Project(Alias(MapFromEntries(array2), "map")() :: Nil, OneRowRelation()).analyze + val optimized2 = optimizer.execute(plan2) + assert(optimized2.schema === + StructType(StructField("map", MapType(IntegerType, IntegerType, false), false) :: Nil)) + } } diff --git a/sql/connect/common/src/test/resources/query-tests/explain-results/function_array_compact.explain b/sql/connect/common/src/test/resources/query-tests/explain-results/function_array_compact.explain index a78195c4ae295..d42d0fd0a46ee 100644 --- a/sql/connect/common/src/test/resources/query-tests/explain-results/function_array_compact.explain +++ b/sql/connect/common/src/test/resources/query-tests/explain-results/function_array_compact.explain @@ -1,2 +1,2 @@ -Project [filter(e#0, lambdafunction(isnotnull(lambda arg#0), lambda arg#0, false)) AS array_compact(e)#0] +Project [knownnotcontainsnull(filter(e#0, lambdafunction(isnotnull(lambda arg#0), lambda arg#0, false))) AS array_compact(e)#0] +- LocalRelation , [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]