diff --git a/sql-plugin/src/main/spark340/scala/com/nvidia/spark/rapids/shims/GpuBatchScanExec.scala b/sql-plugin/src/main/spark340/scala/com/nvidia/spark/rapids/shims/GpuBatchScanExec.scala index 2a951cb9500..4b7a984144e 100644 --- a/sql-plugin/src/main/spark340/scala/com/nvidia/spark/rapids/shims/GpuBatchScanExec.scala +++ b/sql-plugin/src/main/spark340/scala/com/nvidia/spark/rapids/shims/GpuBatchScanExec.scala @@ -202,7 +202,12 @@ case class GpuBatchScanExec( val partitionMapping = groupedPartitions.map { case (row, parts) => InternalRowComparableWrapper(row, p.expressions) -> parts }.toMap - finalPartitions = p.partitionValues.map { partValue => + + // In case `commonPartitionValues` is not defined (e.g., SPJ is not used), there + // could exist duplicated partition values, as partition grouping is not done + // at the beginning and postponed to this method. It is important to use unique + // partition values here so that grouped partitions won't get duplicated. + finalPartitions = KeyGroupedPartitioningShim.getUniquePartitions(p).map { partValue => // Use empty partition for those partition values that are not present partitionMapping.getOrElse( InternalRowComparableWrapper(partValue, p.expressions), Seq.empty) diff --git a/sql-plugin/src/main/spark340/scala/com/nvidia/spark/rapids/shims/KeyGroupedPartitioningShim.scala b/sql-plugin/src/main/spark340/scala/com/nvidia/spark/rapids/shims/KeyGroupedPartitioningShim.scala new file mode 100644 index 00000000000..53644a02804 --- /dev/null +++ b/sql-plugin/src/main/spark340/scala/com/nvidia/spark/rapids/shims/KeyGroupedPartitioningShim.scala @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +/*** spark-rapids-shim-json-lines +{"spark": "340"} +{"spark": "341"} +spark-rapids-shim-json-lines ***/ +package com.nvidia.spark.rapids.shims + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.plans.physical.KeyGroupedPartitioning +import org.apache.spark.sql.catalyst.util.InternalRowComparableWrapper + +object KeyGroupedPartitioningShim { + def getUniquePartitions(p: KeyGroupedPartitioning): Seq[InternalRow] = { + p.partitionValues + .map(InternalRowComparableWrapper(_, p.expressions)) + .distinct + .map(_.row) + } +} \ No newline at end of file diff --git a/sql-plugin/src/main/spark350/scala/com/nvidia/spark/rapids/shims/KeyGroupedPartitioningShim.scala b/sql-plugin/src/main/spark350/scala/com/nvidia/spark/rapids/shims/KeyGroupedPartitioningShim.scala new file mode 100644 index 00000000000..af88498cd34 --- /dev/null +++ b/sql-plugin/src/main/spark350/scala/com/nvidia/spark/rapids/shims/KeyGroupedPartitioningShim.scala @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +/*** spark-rapids-shim-json-lines +{"spark": "350"} +spark-rapids-shim-json-lines ***/ +package com.nvidia.spark.rapids.shims + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.plans.physical.KeyGroupedPartitioning + +object KeyGroupedPartitioningShim { + def getUniquePartitions(p: KeyGroupedPartitioning): Seq[InternalRow] = { + p.uniquePartitionValues + } +} \ No newline at end of file