From b1c55d1aca6d2f4459af8e8eceb13eb7e555df75 Mon Sep 17 00:00:00 2001
From: Andy Grove
Date: Thu, 21 Sep 2023 16:25:53 -0600
Subject: [PATCH 1/3] Update code with fix for SPARK-44641

---
 .../spark/rapids/shims/GpuBatchScanExec.scala |   1 -
 .../spark/rapids/shims/GpuBatchScanExec.scala | 261 ++++++++++++++++++
 2 files changed, 261 insertions(+), 1 deletion(-)
 create mode 100644 sql-plugin/src/main/spark350/scala/com/nvidia/spark/rapids/shims/GpuBatchScanExec.scala

diff --git a/sql-plugin/src/main/spark340/scala/com/nvidia/spark/rapids/shims/GpuBatchScanExec.scala b/sql-plugin/src/main/spark340/scala/com/nvidia/spark/rapids/shims/GpuBatchScanExec.scala
index 2a951cb9500..4e15a9ea00c 100644
--- a/sql-plugin/src/main/spark340/scala/com/nvidia/spark/rapids/shims/GpuBatchScanExec.scala
+++ b/sql-plugin/src/main/spark340/scala/com/nvidia/spark/rapids/shims/GpuBatchScanExec.scala
@@ -17,7 +17,6 @@
 /*** spark-rapids-shim-json-lines
 {"spark": "340"}
 {"spark": "341"}
-{"spark": "350"}
 spark-rapids-shim-json-lines ***/
 package com.nvidia.spark.rapids.shims
 
diff --git a/sql-plugin/src/main/spark350/scala/com/nvidia/spark/rapids/shims/GpuBatchScanExec.scala b/sql-plugin/src/main/spark350/scala/com/nvidia/spark/rapids/shims/GpuBatchScanExec.scala
new file mode 100644
index 00000000000..338281753a5
--- /dev/null
+++ b/sql-plugin/src/main/spark350/scala/com/nvidia/spark/rapids/shims/GpuBatchScanExec.scala
@@ -0,0 +1,261 @@
+/*
+ * Copyright (c) 2023, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*** spark-rapids-shim-json-lines
+{"spark": "350"}
+spark-rapids-shim-json-lines ***/
+package com.nvidia.spark.rapids.shims
+
+import com.google.common.base.Objects
+import com.nvidia.spark.rapids.{GpuBatchScanExecMetrics, GpuScan}
+
+import org.apache.spark.SparkException
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, DynamicPruningExpression, Expression, Literal, SortOrder}
+import org.apache.spark.sql.catalyst.plans.QueryPlan
+import org.apache.spark.sql.catalyst.plans.physical.{KeyGroupedPartitioning, Partitioning, SinglePartition}
+import org.apache.spark.sql.catalyst.util.{truncatedString, InternalRowComparableWrapper}
+import org.apache.spark.sql.connector.catalog.Table
+import org.apache.spark.sql.connector.read._
+import org.apache.spark.sql.execution.datasources.rapids.DataSourceStrategyUtils
+import org.apache.spark.sql.execution.datasources.v2._
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.vectorized.ColumnarBatch
+
+case class GpuBatchScanExec(
+    output: Seq[AttributeReference],
+    @transient scan: GpuScan,
+    runtimeFilters: Seq[Expression] = Seq.empty,
+    keyGroupedPartitioning: Option[Seq[Expression]] = None,
+    ordering: Option[Seq[SortOrder]] = None,
+    @transient table: Table,
+    commonPartitionValues: Option[Seq[(InternalRow, Int)]] = None,
+    applyPartialClustering: Boolean = false,
+    replicatePartitions: Boolean = false)
+  extends DataSourceV2ScanExecBase with GpuBatchScanExecMetrics {
+  @transient lazy val batch: Batch = scan.toBatch
+
+  // All expressions are filter expressions used on the CPU.
+  override def gpuExpressions: Seq[Expression] = Nil
+
+  // TODO: unify the equal/hashCode implementation for all data source v2 query plans.
+  override def equals(other: Any): Boolean = other match {
+    case other: GpuBatchScanExec =>
+      this.batch == other.batch && this.runtimeFilters == other.runtimeFilters &&
+        this.commonPartitionValues == other.commonPartitionValues &&
+        this.replicatePartitions == other.replicatePartitions &&
+        this.applyPartialClustering == other.applyPartialClustering
+    case _ =>
+      false
+  }
+
+  override def hashCode(): Int = Objects.hashCode(batch, runtimeFilters)
+
+  @transient override lazy val inputPartitions: Seq[InputPartition] = batch.planInputPartitions()
+
+  @transient private lazy val filteredPartitions: Seq[Seq[InputPartition]] = {
+    val dataSourceFilters = runtimeFilters.flatMap {
+      case DynamicPruningExpression(e) => DataSourceStrategyUtils.translateRuntimeFilter(e)
+      case _ => None
+    }
+
+    if (dataSourceFilters.nonEmpty) {
+      val originalPartitioning = outputPartitioning
+
+      // the cast is safe as runtime filters are only assigned if the scan can be filtered
+      val filterableScan = scan.asInstanceOf[SupportsRuntimeV2Filtering]
+      filterableScan.filter(dataSourceFilters.toArray)
+
+      // call toBatch again to get filtered partitions
+      val newPartitions = scan.toBatch.planInputPartitions()
+
+      originalPartitioning match {
+        case p: KeyGroupedPartitioning =>
+          if (newPartitions.exists(!_.isInstanceOf[HasPartitionKey])) {
+            throw new SparkException("Data source must have preserved the original partitioning " +
+              "during runtime filtering: not all partitions implement HasPartitionKey after " +
+              "filtering")
+          }
+
+          val newPartitionValues = newPartitions.map(partition =>
+            InternalRowComparableWrapper(partition.asInstanceOf[HasPartitionKey], p.expressions))
+            .toSet
+          val oldPartitionValues = p.partitionValues
+            .map(partition => InternalRowComparableWrapper(partition, p.expressions)).toSet
+          // We require the new number of partition values to be equal or less than the old number
+          // of partition values here. In the case of less than, empty partitions will be added for
+          // those missing values that are not present in the new input partitions.
+          if (oldPartitionValues.size < newPartitionValues.size) {
+            throw new SparkException("During runtime filtering, data source must either report " +
+              "the same number of partition values, or a subset of partition values from the " +
+              s"original. Before: ${oldPartitionValues.size} partition values. " +
+              s"After: ${newPartitionValues.size} partition values")
+          }
+
+          if (!newPartitionValues.forall(oldPartitionValues.contains)) {
+            throw new SparkException("During runtime filtering, data source must not report new " +
+              "partition values that are not present in the original partitioning.")
+          }
+          groupPartitions(newPartitions).get.map(_._2)
+
+        case _ =>
+          // no validation is needed as the data source did not report any specific partitioning
+          newPartitions.map(Seq(_))
+      }
+
+    } else {
+      partitions
+    }
+  }
+
+  override def outputPartitioning: Partitioning = {
+    super.outputPartitioning match {
+      case k: KeyGroupedPartitioning if commonPartitionValues.isDefined =>
+        // We allow duplicated partition values if
+        // `spark.sql.sources.v2.bucketing.partiallyClusteredDistribution.enabled` is true
+        val newPartValues = commonPartitionValues.get.flatMap { case (partValue, numSplits) =>
+          Seq.fill(numSplits)(partValue)
+        }
+        k.copy(numPartitions = newPartValues.length, partitionValues = newPartValues)
+      case p => p
+    }
+  }
+
+  override lazy val readerFactory: PartitionReaderFactory = batch.createReaderFactory()
+
+  override lazy val inputRDD: RDD[InternalRow] = {
+    scan.metrics = allMetrics
+    val rdd = if (filteredPartitions.isEmpty && outputPartitioning == SinglePartition) {
+      // return an empty RDD with 1 partition if dynamic filtering removed the only split
+      sparkContext.parallelize(Array.empty[InternalRow], 1)
+    } else {
+      var finalPartitions = filteredPartitions
+
+      outputPartitioning match {
+        case p: KeyGroupedPartitioning =>
+          if (conf.v2BucketingPushPartValuesEnabled &&
+              conf.v2BucketingPartiallyClusteredDistributionEnabled) {
+            assert(filteredPartitions.forall(_.size == 1),
+              "Expect partitions to be not grouped when " +
+                s"${SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key} " +
+                "is enabled")
+
+            val groupedPartitions = groupPartitions(finalPartitions.map(_.head), true).get
+
+            // This means the input partitions are not grouped by partition values. We'll need to
+            // check `groupByPartitionValues` and decide whether to group and replicate splits
+            // within a partition.
+            if (commonPartitionValues.isDefined && applyPartialClustering) {
+              // A mapping from the common partition values to how many splits the partition
+              // should contain. Note this no longer maintain the partition key ordering.
+              val commonPartValuesMap = commonPartitionValues
+                .get
+                .map(t => (InternalRowComparableWrapper(t._1, p.expressions), t._2))
+                .toMap
+              val nestGroupedPartitions = groupedPartitions.map {
+                case (partValue, splits) =>
+                  // `commonPartValuesMap` should contain the part value since it's the super set.
+                  val numSplits = commonPartValuesMap
+                    .get(InternalRowComparableWrapper(partValue, p.expressions))
+                  assert(numSplits.isDefined, s"Partition value $partValue does not exist in " +
+                    "common partition values from Spark plan")
+
+                  val newSplits = if (replicatePartitions) {
+                    // We need to also replicate partitions according to the other side of join
+                    Seq.fill(numSplits.get)(splits)
+                  } else {
+                    // Not grouping by partition values: this could be the side with partially
+                    // clustered distribution. Because of dynamic filtering, we'll need to check if
+                    // the final number of splits of a partition is smaller than the original
+                    // number, and fill with empty splits if so. This is necessary so that both
+                    // sides of a join will have the same number of partitions & splits.
+                    splits.map(Seq(_)).padTo(numSplits.get, Seq.empty)
+                  }
+                  (InternalRowComparableWrapper(partValue, p.expressions), newSplits)
+              }
+
+              // Now fill missing partition keys with empty partitions
+              val partitionMapping = nestGroupedPartitions.toMap
+              finalPartitions = commonPartitionValues.get.flatMap { case (partValue, numSplits) =>
+                // Use empty partition for those partition values that are not present.
+                partitionMapping.getOrElse(
+                  InternalRowComparableWrapper(partValue, p.expressions),
+                  Seq.fill(numSplits)(Seq.empty))
+              }
+            } else {
+              val partitionMapping = groupedPartitions.map { case (row, parts) =>
+                InternalRowComparableWrapper(row, p.expressions) -> parts
+              }.toMap
+
+              // In case `commonPartitionValues` is not defined (e.g., SPJ is not used), there
+              // could exist duplicated partition values, as partition grouping is not done
+              // at the beginning and postponed to this method. It is important to use unique
+              // partition values here so that grouped partitions won't get duplicated.
+              finalPartitions = p.uniquePartitionValues.map { partValue =>
+                // Use empty partition for those partition values that are not present
+                partitionMapping.getOrElse(
+                  InternalRowComparableWrapper(partValue, p.expressions), Seq.empty)
+              }
+            }
+          } else {
+            val partitionMapping = finalPartitions.map { parts =>
+              val row = parts.head.asInstanceOf[HasPartitionKey].partitionKey()
+              InternalRowComparableWrapper(row, p.expressions) -> parts
+            }.toMap
+            finalPartitions = p.partitionValues.map { partValue =>
+              // Use empty partition for those partition values that are not present
+              partitionMapping.getOrElse(
+                InternalRowComparableWrapper(partValue, p.expressions), Seq.empty)
+            }
+          }
+
+        case _ =>
+      }
+
+      new GpuDataSourceRDD(sparkContext, finalPartitions, readerFactory)
+    }
+    postDriverMetrics()
+    rdd
+  }
+
+  override def doCanonicalize(): GpuBatchScanExec = {
+    this.copy(
+      output = output.map(QueryPlan.normalizeExpressions(_, output)),
+      runtimeFilters = QueryPlan.normalizePredicates(
+        runtimeFilters.filterNot(_ == DynamicPruningExpression(Literal.TrueLiteral)),
+        output))
+  }
+
+  override def simpleString(maxFields: Int): String = {
+    val truncatedOutputString = truncatedString(output, "[", ", ", "]", maxFields)
+    val runtimeFiltersString = s"RuntimeFilters: ${runtimeFilters.mkString("[", ",", "]")}"
+    val result = s"$nodeName$truncatedOutputString ${scan.description()} $runtimeFiltersString"
+    redact(result)
+  }
+
+  override def internalDoExecuteColumnar(): RDD[ColumnarBatch] = {
+    val numOutputRows = longMetric("numOutputRows")
+    inputRDD.asInstanceOf[RDD[ColumnarBatch]].map { b =>
+      numOutputRows += b.numRows()
+      b
+    }
+  }
+
+  override def nodeName: String = {
+    s"GpuBatchScan ${table.name()}".trim
+  }
+}

From 1ca454e161fd6db7d79b07134a1f6e75b77ed982 Mon Sep 17 00:00:00 2001
From: Andy Grove
Date: Fri, 22 Sep 2023 08:10:59 -0600
Subject: [PATCH 2/3] signoff

Signed-off-by: Andy Grove

From c072447f6d8b19bbcd0d83a928538d8f28b37f2f Mon Sep 17 00:00:00 2001
From: Andy Grove
Date: Mon, 25 Sep 2023 10:58:09 -0600
Subject: [PATCH 3/3] add KeyGroupedPartitioningShim and remove duplicate copy of GpuBatchScanExec

---
 .../spark/rapids/shims/GpuBatchScanExec.scala |   8 +-
 .../shims/KeyGroupedPartitioningShim.scala    |  33 +++
 .../spark/rapids/shims/GpuBatchScanExec.scala | 261 ------------------
 .../shims/KeyGroupedPartitioningShim.scala    |  28 ++
 4 files changed, 68 insertions(+), 262 deletions(-)
 create mode 100644 sql-plugin/src/main/spark340/scala/com/nvidia/spark/rapids/shims/KeyGroupedPartitioningShim.scala
 delete mode 100644 sql-plugin/src/main/spark350/scala/com/nvidia/spark/rapids/shims/GpuBatchScanExec.scala
 create mode 100644 sql-plugin/src/main/spark350/scala/com/nvidia/spark/rapids/shims/KeyGroupedPartitioningShim.scala

diff --git a/sql-plugin/src/main/spark340/scala/com/nvidia/spark/rapids/shims/GpuBatchScanExec.scala b/sql-plugin/src/main/spark340/scala/com/nvidia/spark/rapids/shims/GpuBatchScanExec.scala
index 4e15a9ea00c..4b7a984144e 100644
--- a/sql-plugin/src/main/spark340/scala/com/nvidia/spark/rapids/shims/GpuBatchScanExec.scala
+++ b/sql-plugin/src/main/spark340/scala/com/nvidia/spark/rapids/shims/GpuBatchScanExec.scala
@@ -17,6 +17,7 @@
 /*** spark-rapids-shim-json-lines
 {"spark": "340"}
 {"spark": "341"}
+{"spark": "350"}
 spark-rapids-shim-json-lines ***/
 package com.nvidia.spark.rapids.shims
 
@@ -201,7 +202,12 @@ case class GpuBatchScanExec(
               val partitionMapping = groupedPartitions.map { case (row, parts) =>
                 InternalRowComparableWrapper(row, p.expressions) -> parts
               }.toMap
-              finalPartitions = p.partitionValues.map { partValue =>
+
+              // In case `commonPartitionValues` is not defined (e.g., SPJ is not used), there
+              // could exist duplicated partition values, as partition grouping is not done
+              // at the beginning and postponed to this method. It is important to use unique
+              // partition values here so that grouped partitions won't get duplicated.
+              finalPartitions = KeyGroupedPartitioningShim.getUniquePartitions(p).map { partValue =>
                 // Use empty partition for those partition values that are not present
                 partitionMapping.getOrElse(
                   InternalRowComparableWrapper(partValue, p.expressions), Seq.empty)
diff --git a/sql-plugin/src/main/spark340/scala/com/nvidia/spark/rapids/shims/KeyGroupedPartitioningShim.scala b/sql-plugin/src/main/spark340/scala/com/nvidia/spark/rapids/shims/KeyGroupedPartitioningShim.scala
new file mode 100644
index 00000000000..53644a02804
--- /dev/null
+++ b/sql-plugin/src/main/spark340/scala/com/nvidia/spark/rapids/shims/KeyGroupedPartitioningShim.scala
@@ -0,0 +1,33 @@
+/*
+ * Copyright (c) 2023, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+/*** spark-rapids-shim-json-lines
+{"spark": "340"}
+{"spark": "341"}
+spark-rapids-shim-json-lines ***/
+package com.nvidia.spark.rapids.shims
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.plans.physical.KeyGroupedPartitioning
+import org.apache.spark.sql.catalyst.util.InternalRowComparableWrapper
+
+object KeyGroupedPartitioningShim {
+  def getUniquePartitions(p: KeyGroupedPartitioning): Seq[InternalRow] = {
+    p.partitionValues
+      .map(InternalRowComparableWrapper(_, p.expressions))
+      .distinct
+      .map(_.row)
+  }
+}
\ No newline at end of file
diff --git a/sql-plugin/src/main/spark350/scala/com/nvidia/spark/rapids/shims/GpuBatchScanExec.scala b/sql-plugin/src/main/spark350/scala/com/nvidia/spark/rapids/shims/GpuBatchScanExec.scala
deleted file mode 100644
index 338281753a5..00000000000
--- a/sql-plugin/src/main/spark350/scala/com/nvidia/spark/rapids/shims/GpuBatchScanExec.scala
+++ /dev/null
@@ -1,261 +0,0 @@
-/*
- * Copyright (c) 2023, NVIDIA CORPORATION.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-/*** spark-rapids-shim-json-lines
-{"spark": "350"}
-spark-rapids-shim-json-lines ***/
-package com.nvidia.spark.rapids.shims
-
-import com.google.common.base.Objects
-import com.nvidia.spark.rapids.{GpuBatchScanExecMetrics, GpuScan}
-
-import org.apache.spark.SparkException
-import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{AttributeReference, DynamicPruningExpression, Expression, Literal, SortOrder}
-import org.apache.spark.sql.catalyst.plans.QueryPlan
-import org.apache.spark.sql.catalyst.plans.physical.{KeyGroupedPartitioning, Partitioning, SinglePartition}
-import org.apache.spark.sql.catalyst.util.{truncatedString, InternalRowComparableWrapper}
-import org.apache.spark.sql.connector.catalog.Table
-import org.apache.spark.sql.connector.read._
-import org.apache.spark.sql.execution.datasources.rapids.DataSourceStrategyUtils
-import org.apache.spark.sql.execution.datasources.v2._
-import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.vectorized.ColumnarBatch
-
-case class GpuBatchScanExec(
-    output: Seq[AttributeReference],
-    @transient scan: GpuScan,
-    runtimeFilters: Seq[Expression] = Seq.empty,
-    keyGroupedPartitioning: Option[Seq[Expression]] = None,
-    ordering: Option[Seq[SortOrder]] = None,
-    @transient table: Table,
-    commonPartitionValues: Option[Seq[(InternalRow, Int)]] = None,
-    applyPartialClustering: Boolean = false,
-    replicatePartitions: Boolean = false)
-  extends DataSourceV2ScanExecBase with GpuBatchScanExecMetrics {
-  @transient lazy val batch: Batch = scan.toBatch
-
-  // All expressions are filter expressions used on the CPU.
-  override def gpuExpressions: Seq[Expression] = Nil
-
-  // TODO: unify the equal/hashCode implementation for all data source v2 query plans.
-  override def equals(other: Any): Boolean = other match {
-    case other: GpuBatchScanExec =>
-      this.batch == other.batch && this.runtimeFilters == other.runtimeFilters &&
-        this.commonPartitionValues == other.commonPartitionValues &&
-        this.replicatePartitions == other.replicatePartitions &&
-        this.applyPartialClustering == other.applyPartialClustering
-    case _ =>
-      false
-  }
-
-  override def hashCode(): Int = Objects.hashCode(batch, runtimeFilters)
-
-  @transient override lazy val inputPartitions: Seq[InputPartition] = batch.planInputPartitions()
-
-  @transient private lazy val filteredPartitions: Seq[Seq[InputPartition]] = {
-    val dataSourceFilters = runtimeFilters.flatMap {
-      case DynamicPruningExpression(e) => DataSourceStrategyUtils.translateRuntimeFilter(e)
-      case _ => None
-    }
-
-    if (dataSourceFilters.nonEmpty) {
-      val originalPartitioning = outputPartitioning
-
-      // the cast is safe as runtime filters are only assigned if the scan can be filtered
-      val filterableScan = scan.asInstanceOf[SupportsRuntimeV2Filtering]
-      filterableScan.filter(dataSourceFilters.toArray)
-
-      // call toBatch again to get filtered partitions
-      val newPartitions = scan.toBatch.planInputPartitions()
-
-      originalPartitioning match {
-        case p: KeyGroupedPartitioning =>
-          if (newPartitions.exists(!_.isInstanceOf[HasPartitionKey])) {
-            throw new SparkException("Data source must have preserved the original partitioning " +
-              "during runtime filtering: not all partitions implement HasPartitionKey after " +
-              "filtering")
-          }
-
-          val newPartitionValues = newPartitions.map(partition =>
-            InternalRowComparableWrapper(partition.asInstanceOf[HasPartitionKey], p.expressions))
-            .toSet
-          val oldPartitionValues = p.partitionValues
-            .map(partition => InternalRowComparableWrapper(partition, p.expressions)).toSet
-          // We require the new number of partition values to be equal or less than the old number
-          // of partition values here. In the case of less than, empty partitions will be added for
-          // those missing values that are not present in the new input partitions.
-          if (oldPartitionValues.size < newPartitionValues.size) {
-            throw new SparkException("During runtime filtering, data source must either report " +
-              "the same number of partition values, or a subset of partition values from the " +
-              s"original. Before: ${oldPartitionValues.size} partition values. " +
-              s"After: ${newPartitionValues.size} partition values")
-          }
-
-          if (!newPartitionValues.forall(oldPartitionValues.contains)) {
-            throw new SparkException("During runtime filtering, data source must not report new " +
-              "partition values that are not present in the original partitioning.")
-          }
-          groupPartitions(newPartitions).get.map(_._2)
-
-        case _ =>
-          // no validation is needed as the data source did not report any specific partitioning
-          newPartitions.map(Seq(_))
-      }
-
-    } else {
-      partitions
-    }
-  }
-
-  override def outputPartitioning: Partitioning = {
-    super.outputPartitioning match {
-      case k: KeyGroupedPartitioning if commonPartitionValues.isDefined =>
-        // We allow duplicated partition values if
-        // `spark.sql.sources.v2.bucketing.partiallyClusteredDistribution.enabled` is true
-        val newPartValues = commonPartitionValues.get.flatMap { case (partValue, numSplits) =>
-          Seq.fill(numSplits)(partValue)
-        }
-        k.copy(numPartitions = newPartValues.length, partitionValues = newPartValues)
-      case p => p
-    }
-  }
-
-  override lazy val readerFactory: PartitionReaderFactory = batch.createReaderFactory()
-
-  override lazy val inputRDD: RDD[InternalRow] = {
-    scan.metrics = allMetrics
-    val rdd = if (filteredPartitions.isEmpty && outputPartitioning == SinglePartition) {
-      // return an empty RDD with 1 partition if dynamic filtering removed the only split
-      sparkContext.parallelize(Array.empty[InternalRow], 1)
-    } else {
-      var finalPartitions = filteredPartitions
-
-      outputPartitioning match {
-        case p: KeyGroupedPartitioning =>
-          if (conf.v2BucketingPushPartValuesEnabled &&
-              conf.v2BucketingPartiallyClusteredDistributionEnabled) {
-            assert(filteredPartitions.forall(_.size == 1),
-              "Expect partitions to be not grouped when " +
-                s"${SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key} " +
-                "is enabled")
-
-            val groupedPartitions = groupPartitions(finalPartitions.map(_.head), true).get
-
-            // This means the input partitions are not grouped by partition values. We'll need to
-            // check `groupByPartitionValues` and decide whether to group and replicate splits
-            // within a partition.
-            if (commonPartitionValues.isDefined && applyPartialClustering) {
-              // A mapping from the common partition values to how many splits the partition
-              // should contain. Note this no longer maintain the partition key ordering.
-              val commonPartValuesMap = commonPartitionValues
-                .get
-                .map(t => (InternalRowComparableWrapper(t._1, p.expressions), t._2))
-                .toMap
-              val nestGroupedPartitions = groupedPartitions.map {
-                case (partValue, splits) =>
-                  // `commonPartValuesMap` should contain the part value since it's the super set.
-                  val numSplits = commonPartValuesMap
-                    .get(InternalRowComparableWrapper(partValue, p.expressions))
-                  assert(numSplits.isDefined, s"Partition value $partValue does not exist in " +
-                    "common partition values from Spark plan")
-
-                  val newSplits = if (replicatePartitions) {
-                    // We need to also replicate partitions according to the other side of join
-                    Seq.fill(numSplits.get)(splits)
-                  } else {
-                    // Not grouping by partition values: this could be the side with partially
-                    // clustered distribution. Because of dynamic filtering, we'll need to check if
-                    // the final number of splits of a partition is smaller than the original
-                    // number, and fill with empty splits if so. This is necessary so that both
-                    // sides of a join will have the same number of partitions & splits.
-                    splits.map(Seq(_)).padTo(numSplits.get, Seq.empty)
-                  }
-                  (InternalRowComparableWrapper(partValue, p.expressions), newSplits)
-              }
-
-              // Now fill missing partition keys with empty partitions
-              val partitionMapping = nestGroupedPartitions.toMap
-              finalPartitions = commonPartitionValues.get.flatMap { case (partValue, numSplits) =>
-                // Use empty partition for those partition values that are not present.
-                partitionMapping.getOrElse(
-                  InternalRowComparableWrapper(partValue, p.expressions),
-                  Seq.fill(numSplits)(Seq.empty))
-              }
-            } else {
-              val partitionMapping = groupedPartitions.map { case (row, parts) =>
-                InternalRowComparableWrapper(row, p.expressions) -> parts
-              }.toMap
-
-              // In case `commonPartitionValues` is not defined (e.g., SPJ is not used), there
-              // could exist duplicated partition values, as partition grouping is not done
-              // at the beginning and postponed to this method. It is important to use unique
-              // partition values here so that grouped partitions won't get duplicated.
-              finalPartitions = p.uniquePartitionValues.map { partValue =>
-                // Use empty partition for those partition values that are not present
-                partitionMapping.getOrElse(
-                  InternalRowComparableWrapper(partValue, p.expressions), Seq.empty)
-              }
-            }
-          } else {
-            val partitionMapping = finalPartitions.map { parts =>
-              val row = parts.head.asInstanceOf[HasPartitionKey].partitionKey()
-              InternalRowComparableWrapper(row, p.expressions) -> parts
-            }.toMap
-            finalPartitions = p.partitionValues.map { partValue =>
-              // Use empty partition for those partition values that are not present
-              partitionMapping.getOrElse(
-                InternalRowComparableWrapper(partValue, p.expressions), Seq.empty)
-            }
-          }
-
-        case _ =>
-      }
-
-      new GpuDataSourceRDD(sparkContext, finalPartitions, readerFactory)
-    }
-    postDriverMetrics()
-    rdd
-  }
-
-  override def doCanonicalize(): GpuBatchScanExec = {
-    this.copy(
-      output = output.map(QueryPlan.normalizeExpressions(_, output)),
-      runtimeFilters = QueryPlan.normalizePredicates(
-        runtimeFilters.filterNot(_ == DynamicPruningExpression(Literal.TrueLiteral)),
-        output))
-  }
-
-  override def simpleString(maxFields: Int): String = {
-    val truncatedOutputString = truncatedString(output, "[", ", ", "]", maxFields)
-    val runtimeFiltersString = s"RuntimeFilters: ${runtimeFilters.mkString("[", ",", "]")}"
-    val result = s"$nodeName$truncatedOutputString ${scan.description()} $runtimeFiltersString"
-    redact(result)
-  }
-
-  override def internalDoExecuteColumnar(): RDD[ColumnarBatch] = {
-    val numOutputRows = longMetric("numOutputRows")
-    inputRDD.asInstanceOf[RDD[ColumnarBatch]].map { b =>
-      numOutputRows += b.numRows()
-      b
-    }
-  }
-
-  override def nodeName: String = {
-    s"GpuBatchScan ${table.name()}".trim
-  }
-}
diff --git a/sql-plugin/src/main/spark350/scala/com/nvidia/spark/rapids/shims/KeyGroupedPartitioningShim.scala b/sql-plugin/src/main/spark350/scala/com/nvidia/spark/rapids/shims/KeyGroupedPartitioningShim.scala
new file mode 100644
index 00000000000..af88498cd34
--- /dev/null
+++ b/sql-plugin/src/main/spark350/scala/com/nvidia/spark/rapids/shims/KeyGroupedPartitioningShim.scala
@@ -0,0 +1,28 @@
+/*
+ * Copyright (c) 2023, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+/*** spark-rapids-shim-json-lines
+{"spark": "350"}
+spark-rapids-shim-json-lines ***/
+package com.nvidia.spark.rapids.shims
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.plans.physical.KeyGroupedPartitioning
+
+object KeyGroupedPartitioningShim {
+  def getUniquePartitions(p: KeyGroupedPartitioning): Seq[InternalRow] = {
+    p.uniquePartitionValues
+  }
+}
\ No newline at end of file