diff --git a/docs/additional-functionality/advanced_configs.md b/docs/additional-functionality/advanced_configs.md
index 941ab4046e6..7ef72a381d0 100644
--- a/docs/additional-functionality/advanced_configs.md
+++ b/docs/additional-functionality/advanced_configs.md
@@ -135,6 +135,16 @@ Name | Description | Default Value | Applicable at
spark.rapids.sql.json.read.decimal.enabled|When reading a quoted string as a decimal Spark supports reading non-ascii unicode digits, and the RAPIDS Accelerator does not.|true|Runtime
spark.rapids.sql.json.read.double.enabled|JSON reading is not 100% compatible when reading doubles.|true|Runtime
spark.rapids.sql.json.read.float.enabled|JSON reading is not 100% compatible when reading floats.|true|Runtime
+spark.rapids.sql.lore.dumpPath|The path to dump the lore nodes' input data. This must be set if spark.rapids.sql.lore.idsToDump has
+been set.|None|Runtime
+spark.rapids.sql.lore.idsToDump|Specify the lore ids of operators to dump. The format is a comma separated list of
+lore ids. For example: "1,2,3" will dump the gpu exec nodes with lore ids 1, 2, and 3.
+By default, all partitions of operators' input will be dumped. If you want to dump only
+some partitions, you can specify the partition index after the lore id, e.g. 1[0-2 4-5
+7], 2[0 4 5-8] , will dump partitions 0, 1, 2, 4, 5 and 7 of the operator with lore id
+ 1, and partitions 0, 4, 5, 6, 7, 8 of the operator with lore id 2.
+If this is not set, no lore nodes will be dumped.|None|Runtime
+spark.rapids.sql.lore.tag.enabled|Enable add a lore id to each gpu plan node|true|Runtime
spark.rapids.sql.mode|Set the mode for the Rapids Accelerator. The supported modes are explainOnly and executeOnGPU. This config can not be changed at runtime, you must restart the application for it to take affect. The default mode is executeOnGPU, which means the RAPIDS Accelerator plugin convert the Spark operations and execute them on the GPU when possible. The explainOnly mode allows running queries on the CPU and the RAPIDS Accelerator will evaluate the queries as if it was going to run on the GPU. The explanations of what would have run on the GPU and why are output in log messages. When using explainOnly mode, the default explain output is ALL, this can be changed by setting spark.rapids.sql.explain. See that config for more details.|executeongpu|Startup
spark.rapids.sql.optimizer.joinReorder.enabled|When enabled, joins may be reordered for improved query performance|true|Runtime
spark.rapids.sql.python.gpu.enabled|This is an experimental feature and is likely to change in the future. Enable (true) or disable (false) support for scheduling Python Pandas UDFs with GPU resources. When enabled, pandas UDFs are assumed to share the same GPU that the RAPIDs accelerator uses and will honor the python GPU configs|false|Runtime
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/DumpUtils.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/DumpUtils.scala
index bf949897c78..21d2de6ad68 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/DumpUtils.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/DumpUtils.scala
@@ -15,7 +15,7 @@
*/
package com.nvidia.spark.rapids
-import java.io.{File, FileOutputStream}
+import java.io.{File, FileOutputStream, OutputStream}
import java.util.Random
import scala.collection.mutable
@@ -23,7 +23,7 @@ import scala.collection.mutable.ArrayBuffer
import ai.rapids.cudf._
import ai.rapids.cudf.ColumnWriterOptions._
-import com.nvidia.spark.rapids.Arm.withResource
+import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource}
import org.apache.commons.io.IOUtils
import org.apache.hadoop.conf.Configuration
@@ -82,6 +82,23 @@ object DumpUtils extends Logging {
}
}
+ /**
+ * Dump columnar batch to output stream in parquet format.
+ *
+ * @param columnarBatch The columnar batch to be dumped, should be GPU columnar batch. It
+ * should be closed by caller.
+ * @param outputStream Will be closed after writing.
+ */
+ def dumpToParquet(columnarBatch: ColumnarBatch, outputStream: OutputStream): Unit = {
+ closeOnExcept(outputStream) { _ =>
+ withResource(GpuColumnVector.from(columnarBatch)) { table =>
+ withResource(new ParquetDumper(outputStream, table)) { dumper =>
+ dumper.writeTable(table)
+ }
+ }
+ }
+ }
+
/**
* Debug utility to dump table to parquet file.
* It's running on GPU. Parquet column names are generated from table column type info.
@@ -129,12 +146,15 @@ object DumpUtils extends Logging {
}
// parquet dumper
-class ParquetDumper(path: String, table: Table) extends HostBufferConsumer
+class ParquetDumper(private val outputStream: OutputStream, table: Table) extends HostBufferConsumer
with AutoCloseable {
- private[this] val outputStream = new FileOutputStream(path)
private[this] val tempBuffer = new Array[Byte](128 * 1024)
private[this] val buffers = mutable.Queue[(HostMemoryBuffer, Long)]()
+ def this(path: String, table: Table) = {
+ this(new FileOutputStream(path), table)
+ }
+
val tableWriter: TableWriter = {
// avoid anything conversion, just dump as it is
val builder = ParquetDumper.parquetWriterOptionsFromTable(ParquetWriterOptions.builder(), table)
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuAggregateExec.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuAggregateExec.scala
index c58d9862be1..6d79dec447a 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuAggregateExec.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuAggregateExec.scala
@@ -1758,6 +1758,7 @@ case class GpuHashAggregateExec(
|${ExplainUtils.generateFieldString("Functions", aggregateExpressions)}
|${ExplainUtils.generateFieldString("Aggregate Attributes", aggregateAttributes)}
|${ExplainUtils.generateFieldString("Results", resultExpressions)}
+ |Lore: ${loreArgs.mkString(", ")}
|""".stripMargin
}
@@ -1886,9 +1887,9 @@ case class GpuHashAggregateExec(
truncatedString(allAggregateExpressions, "[", ", ", "]", maxFields)
val outputString = truncatedString(output, "[", ", ", "]", maxFields)
if (verbose) {
- s"GpuHashAggregate(keys=$keyString, functions=$functionString, output=$outputString)"
+ s"$nodeName (keys=$keyString, functions=$functionString, output=$outputString)"
} else {
- s"GpuHashAggregate(keys=$keyString, functions=$functionString)," +
+ s"$nodeName (keys=$keyString, functions=$functionString)," +
s" filters=${aggregateExpressions.map(_.filter)})"
}
}
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuExec.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuExec.scala
index ec87dd62d6c..d3eec3f83b2 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuExec.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuExec.scala
@@ -19,7 +19,10 @@ package com.nvidia.spark.rapids
import ai.rapids.cudf.NvtxColor
import com.nvidia.spark.rapids.Arm.withResource
import com.nvidia.spark.rapids.filecache.FileCacheConf
+import com.nvidia.spark.rapids.lore.{GpuLore, GpuLoreDumpRDD}
+import com.nvidia.spark.rapids.lore.GpuLore.{loreIdOf, LORE_DUMP_PATH_TAG, LORE_DUMP_RDD_TAG}
import com.nvidia.spark.rapids.shims.SparkShimImpl
+import org.apache.hadoop.fs.Path
import org.apache.spark.internal.Logging
import org.apache.spark.rapids.LocationPreservingMapPartitionsRDD
@@ -363,7 +366,8 @@ trait GpuExec extends SparkPlan {
this.getTagValue(GpuExec.TASK_METRICS_TAG)
final override def doExecuteColumnar(): RDD[ColumnarBatch] = {
- val orig = internalDoExecuteColumnar()
+ this.dumpLoreMetaInfo()
+ val orig = this.dumpLoreRDD(internalDoExecuteColumnar())
val metrics = getTaskMetrics
metrics.map { gpuMetrics =>
// This is ugly, but it reduces the need to change all exec nodes, so we are doing it here
@@ -374,5 +378,29 @@ trait GpuExec extends SparkPlan {
}.getOrElse(orig)
}
+ override def stringArgs: Iterator[Any] = super.stringArgs ++ loreArgs
+
+ protected def loreArgs: Iterator[String] = {
+ val loreIdStr = loreIdOf(this).map(id => s"[loreId=$id]")
+ val lorePathStr = getTagValue(LORE_DUMP_PATH_TAG).map(path => s"[lorePath=$path]")
+ val loreRDDInfoStr = getTagValue(LORE_DUMP_RDD_TAG).map(info => s"[loreRDDInfo=$info]")
+
+ List(loreIdStr, lorePathStr, loreRDDInfoStr).flatten.iterator
+ }
+
+ private def dumpLoreMetaInfo(): Unit = {
+ getTagValue(LORE_DUMP_PATH_TAG).foreach { rootPath =>
+ GpuLore.dumpPlan(this, new Path(rootPath))
+ }
+ }
+
+ protected def dumpLoreRDD(inner: RDD[ColumnarBatch]): RDD[ColumnarBatch] = {
+ getTagValue(LORE_DUMP_RDD_TAG).map { info =>
+ val rdd = new GpuLoreDumpRDD(info, inner)
+ rdd.saveMeta()
+ rdd
+ }.getOrElse(inner)
+ }
+
protected def internalDoExecuteColumnar(): RDD[ColumnarBatch]
}
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
index 295480d24cc..b309f296089 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
@@ -25,6 +25,7 @@ import scala.util.control.NonFatal
import ai.rapids.cudf.DType
import com.nvidia.spark.rapids.RapidsConf.{SUPPRESS_PLANNING_FAILURE, TEST_CONF}
import com.nvidia.spark.rapids.jni.GpuTimeZoneDB
+import com.nvidia.spark.rapids.lore.GpuLore
import com.nvidia.spark.rapids.shims._
import com.nvidia.spark.rapids.window.{GpuDenseRank, GpuLag, GpuLead, GpuPercentRank, GpuRank, GpuRowNumber, GpuSpecialFrameBoundary, GpuWindowExecMeta, GpuWindowSpecDefinitionMeta}
import org.apache.hadoop.fs.Path
@@ -4707,7 +4708,12 @@ case class GpuOverrides() extends Rule[SparkPlan] with Logging {
}
}
}
- GpuOverrides.doConvertPlan(wrap, conf, optimizations)
+ val convertedPlan = GpuOverrides.doConvertPlan(wrap, conf, optimizations)
+ if (conf.get(RapidsConf.TAG_LORE_ID_ENABLED)) {
+ GpuLore.tagForLore(convertedPlan, conf)
+ } else {
+ convertedPlan
+ }
}
}
}
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuTransitionOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuTransitionOverrides.scala
index 48f9de5a61a..b404b0df104 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuTransitionOverrides.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuTransitionOverrides.scala
@@ -21,6 +21,7 @@ import java.util.concurrent.atomic.AtomicInteger
import scala.annotation.tailrec
import scala.collection.mutable
+import com.nvidia.spark.rapids.lore.GpuLore
import com.nvidia.spark.rapids.shims.{GpuBatchScanExec, SparkShimImpl}
import org.apache.spark.SparkContext
@@ -823,6 +824,10 @@ class GpuTransitionOverrides extends Rule[SparkPlan] {
updatedPlan = fixupAdaptiveExchangeReuse(updatedPlan)
}
+ if (rapidsConf.get(RapidsConf.TAG_LORE_ID_ENABLED)) {
+ updatedPlan = GpuLore.tagForLore(updatedPlan, rapidsConf)
+ }
+
if (rapidsConf.logQueryTransformations) {
logWarning(s"Transformed query:" +
s"\nOriginal Plan:\n$plan\nTransformed Plan:\n$updatedPlan")
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala
index 5203e926efa..5fe896fe100 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala
@@ -2300,6 +2300,29 @@ val SHUFFLE_COMPRESSION_LZ4_CHUNK_SIZE = conf("spark.rapids.shuffle.compression.
.booleanConf
.createWithDefault(false)
+ val TAG_LORE_ID_ENABLED = conf("spark.rapids.sql.lore.tag.enabled")
+ .doc("Enable add a lore id to each gpu plan node")
+ .booleanConf
+ .createWithDefault(true)
+
+ val LORE_DUMP_IDS = conf("spark.rapids.sql.lore.idsToDump")
+ .doc("""Specify the lore ids of operators to dump. The format is a comma separated list of
+ |lore ids. For example: "1,2,3" will dump the gpu exec nodes with lore ids 1, 2, and 3.
+ |By default, all partitions of operators' input will be dumped. If you want to dump only
+ |some partitions, you can specify the partition index after the lore id, e.g. 1[0-2 4-5
+ |7], 2[0 4 5-8] , will dump partitions 0, 1, 2, 4, 5 and 7 of the operator with lore id
+ | 1, and partitions 0, 4, 5, 6, 7, 8 of the operator with lore id 2.
+ |If this is not set, no lore nodes will be dumped.""".stripMargin)
+ .stringConf
+ .createOptional
+
+ val LORE_DUMP_PATH = conf("spark.rapids.sql.lore.dumpPath")
+ .doc(
+ s"""The path to dump the lore nodes' input data. This must be set if ${LORE_DUMP_IDS.key} has
+ |been set.""".stripMargin)
+ .stringConf
+ .createOptional
+
private def printSectionHeader(category: String): Unit =
println(s"\n### $category")
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/lore/GpuLore.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/lore/GpuLore.scala
new file mode 100644
index 00000000000..dc7705b6a2b
--- /dev/null
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/lore/GpuLore.scala
@@ -0,0 +1,262 @@
+/*
+ * Copyright (c) 2024, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.nvidia.spark.rapids.lore
+
+import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap}
+import java.util.concurrent.atomic.AtomicInteger
+
+import scala.reflect.ClassTag
+
+import com.nvidia.spark.rapids.{GpuColumnarToRowExec, GpuExec, GpuFilterExec, RapidsConf}
+import com.nvidia.spark.rapids.Arm.withResource
+import com.nvidia.spark.rapids.shims.{ShimLeafExecNode, SparkShimImpl}
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.Path
+
+import org.apache.spark.SparkEnv
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
+import org.apache.spark.sql.catalyst.trees.TreeNodeTag
+import org.apache.spark.sql.execution.{BaseSubqueryExec, ExecSubqueryExpression, SparkPlan, SQLExecution}
+import org.apache.spark.sql.execution.adaptive.BroadcastQueryStageExec
+import org.apache.spark.sql.rapids.execution.{GpuBroadcastExchangeExec, GpuCustomShuffleReaderExec}
+import org.apache.spark.sql.types.DataType
+
+case class LoreRDDMeta(numPartitions: Int, outputPartitions: Seq[Int], attrs: Seq[Attribute])
+
+case class LoreRDDPartitionMeta(numBatches: Int, dataType: Seq[DataType])
+
+trait GpuLoreRDD {
+ def rootPath: Path
+
+ def pathOfMeta: Path = new Path(rootPath, "rdd.meta")
+
+ def pathOfPartition(partitionIndex: Int): Path = {
+ new Path(rootPath, s"partition-$partitionIndex")
+ }
+
+ def pathOfPartitionMeta(partitionIndex: Int): Path = {
+ new Path(pathOfPartition(partitionIndex), "partition.meta")
+ }
+
+ def pathOfBatch(partitionIndex: Int, batchIndex: Int): Path = {
+ new Path(pathOfPartition(partitionIndex), s"batch-$batchIndex.parquet")
+ }
+}
+
+
+object GpuLore {
+ /**
+ * Lore id of a plan node.
+ */
+ val LORE_ID_TAG: TreeNodeTag[String] = new TreeNodeTag[String]("rapids.gpu.lore.id")
+ /**
+ * When a [[GpuExec]] node has this tag, it means that this node is a root node whose meta and
+ * input should be dumped.
+ */
+ val LORE_DUMP_PATH_TAG: TreeNodeTag[String] = new TreeNodeTag[String]("rapids.gpu.lore.dump.path")
+ /**
+ * When a [[GpuExec]] node has this tag, it means that this node is a child node whose data
+ * should be dumped.
+ */
+ val LORE_DUMP_RDD_TAG: TreeNodeTag[LoreDumpRDDInfo] = new TreeNodeTag[LoreDumpRDDInfo](
+ "rapids.gpu.lore.dump.rdd.info")
+
+ def pathOfRootPlanMeta(rootPath: Path): Path = {
+ new Path(rootPath, "plan.meta")
+ }
+
+ def dumpPlan[T <: SparkPlan : ClassTag](plan: T, rootPath: Path): Unit = {
+ dumpObject(plan, pathOfRootPlanMeta(rootPath),
+ SparkShimImpl.sessionFromPlan(plan).sparkContext.hadoopConfiguration)
+ }
+
+ def dumpObject[T: ClassTag](obj: T, path: Path, hadoopConf: Configuration): Unit = {
+ withResource(path.getFileSystem(hadoopConf)) { fs =>
+ withResource(fs.create(path, false)) { fout =>
+ val serializerStream = SparkEnv.get.serializer.newInstance().serializeStream(fout)
+ withResource(serializerStream) { ser =>
+ ser.writeObject(obj)
+ }
+ }
+ }
+ }
+
+ def loadObject[T: ClassTag](path: Path, hadoopConf: Configuration): T = {
+ withResource(path.getFileSystem(hadoopConf)) { fs =>
+ withResource(fs.open(path)) { fin =>
+ val serializerStream = SparkEnv.get.serializer.newInstance().deserializeStream(fin)
+ withResource(serializerStream) { ser =>
+ ser.readObject().asInstanceOf[T]
+ }
+ }
+ }
+ }
+
+ def pathOfChild(rootPath: Path, childIndex: Int): Path = {
+ new Path(rootPath, s"input-$childIndex")
+ }
+
+ def restoreGpuExec(rootPath: Path, hadoopConf: Configuration): GpuExec = {
+ val rootExec = loadObject[GpuExec](pathOfRootPlanMeta(rootPath), hadoopConf)
+
+ // Load children
+ val newChildren = rootExec.children.zipWithIndex.map { case (plan, idx) =>
+ val newChild = GpuLoreReplayExec(idx, rootPath)
+ plan match {
+ case b: GpuBroadcastExchangeExec =>
+ b.withNewChildren(Seq(newChild))
+ case b: BroadcastQueryStageExec =>
+ b.broadcast.withNewChildren(Seq(newChild))
+ case _ => newChild
+ }
+ }
+
+ rootExec match {
+ case b: GpuFilterExec =>
+ val newExpr = restoreSubqueryExpression(1, b.condition, rootPath)
+ b.makeCopy(Array(newExpr, newChildren.head)).asInstanceOf[GpuExec]
+ case _ => rootExec.withNewChildren(newChildren)
+ .asInstanceOf[GpuExec]
+ }
+ }
+
+ private def restoreSubqueryExpression(startIdx: Int, expression: Expression,
+ rootPath: Path): Expression = {
+ var nextIdx = startIdx
+ val newExpr = expression.transformUp {
+ case sub: ExecSubqueryExpression if sub.plan.child.isInstanceOf[GpuExec] =>
+ var newChild: SparkPlan = GpuLoreReplayExec(nextIdx, rootPath)
+ if (!sub.plan.supportsColumnar) {
+ newChild = GpuColumnarToRowExec(newChild)
+ }
+ val newSubqueryExec = sub.plan.withNewChildren(Seq(newChild)).asInstanceOf[BaseSubqueryExec]
+ nextIdx += 1
+ sub.withNewPlan(newSubqueryExec)
+ }
+ newExpr
+ }
+
+ /**
+ * Lore id generator. Key is [[SQLExecution.EXECUTION_ID_KEY]].
+ */
+ private val idGen: ConcurrentMap[String, AtomicInteger] =
+ new ConcurrentHashMap[String, AtomicInteger]()
+
+ private def nextLoreIdOf(plan: SparkPlan): Option[Int] = {
+ // When the execution id is not set, it means there is no actual execution happening, in this
+ // case we don't need to generate lore id.
+ Option(SparkShimImpl.sessionFromPlan(plan)
+ .sparkContext
+ .getLocalProperty(SQLExecution.EXECUTION_ID_KEY))
+ .map { executionId =>
+ idGen.computeIfAbsent(executionId, _ => new AtomicInteger(0)).getAndIncrement()
+ }
+ }
+
+ def tagForLore(sparkPlan: SparkPlan, rapidsConf: RapidsConf): SparkPlan = {
+ val loreDumpIds = rapidsConf.get(RapidsConf.LORE_DUMP_IDS).map(OutputLoreId.parse)
+
+ val newPlan = loreDumpIds match {
+ case Some(dumpIds) =>
+ // We need to dump the output of the output of nodes with the lore id in the dump ids
+ val loreOutputRootPath = rapidsConf.get(RapidsConf.LORE_DUMP_PATH).getOrElse(throw
+ new IllegalArgumentException(s"${RapidsConf.LORE_DUMP_PATH.key} must be set " +
+ s"when ${RapidsConf.LORE_DUMP_IDS.key} is set."))
+
+ sparkPlan.foreachUp {
+ case g: GpuExec =>
+ nextLoreIdOf(g).foreach { loreId =>
+ g.setTagValue(LORE_ID_TAG, loreId.toString)
+
+ dumpIds.get(loreId).foreach { outputLoreIds =>
+ checkUnsupportedOperator(g)
+ val currentExecRootPath = new Path(loreOutputRootPath, s"loreId-$loreId")
+ g.setTagValue(LORE_DUMP_PATH_TAG, currentExecRootPath.toString)
+ val loreOutputInfo = LoreOutputInfo(outputLoreIds,
+ currentExecRootPath.toString)
+
+ g.children.zipWithIndex.foreach {
+ case (child, idx) =>
+ val dumpRDDInfo = LoreDumpRDDInfo(idx, loreOutputInfo, child.output)
+ child match {
+ case c: BroadcastQueryStageExec =>
+ c.broadcast.setTagValue(LORE_DUMP_RDD_TAG, dumpRDDInfo)
+ case o => o.setTagValue(LORE_DUMP_RDD_TAG, dumpRDDInfo)
+ }
+ }
+
+ g match {
+ case f: GpuFilterExec =>
+ tagForSubqueryPlan(1, f.condition, loreOutputInfo)
+ case _ =>
+ }
+ }
+ }
+ case _ =>
+ }
+
+ sparkPlan
+ case None =>
+ // We don't need to dump the output of the nodes, just tag the lore id
+ sparkPlan.foreachUp {
+ case g: GpuExec =>
+ nextLoreIdOf(g).foreach { loreId =>
+ g.setTagValue(LORE_ID_TAG, loreId.toString)
+ }
+ case _ =>
+ }
+
+ sparkPlan
+ }
+
+ newPlan
+ }
+
+ def loreIdOf(node: SparkPlan): Option[String] = {
+ node.getTagValue(LORE_ID_TAG)
+ }
+
+ private def tagForSubqueryPlan(startId: Int, expression: Expression,
+ loreOutputInfo: LoreOutputInfo): Int = {
+ var nextPlanId = startId
+ expression.foreachUp {
+ case sub: ExecSubqueryExpression =>
+ if (sub.plan.child.isInstanceOf[GpuExec]) {
+ val dumpRDDInfo = LoreDumpRDDInfo(nextPlanId, loreOutputInfo, sub.plan.child.output)
+ sub.plan.child match {
+ case p: GpuColumnarToRowExec => p.child.setTagValue(LORE_DUMP_RDD_TAG, dumpRDDInfo)
+ case c => c.setTagValue(LORE_DUMP_RDD_TAG, dumpRDDInfo)
+ }
+
+ nextPlanId += 1
+ } else {
+ throw new IllegalArgumentException(s"Subquery plan ${sub.plan} is not a GpuExec")
+ }
+ case _ =>
+ }
+ nextPlanId
+ }
+
+ private def checkUnsupportedOperator(plan: SparkPlan): Unit = {
+ if (plan.isInstanceOf[ShimLeafExecNode] ||
+ plan.isInstanceOf[GpuCustomShuffleReaderExec]
+ ) {
+ throw new UnsupportedOperationException(s"Currently we don't support dumping input of " +
+ s"${plan.getClass.getSimpleName} operator.")
+ }
+ }
+}
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/lore/OutputLoreId.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/lore/OutputLoreId.scala
new file mode 100644
index 00000000000..47b7e8d1e63
--- /dev/null
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/lore/OutputLoreId.scala
@@ -0,0 +1,73 @@
+/*
+ * Copyright (c) 2024, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.nvidia.spark.rapids.lore
+
+import org.apache.hadoop.fs.Path
+
+case class OutputLoreId(loreId: LoreId, partitionIds: Set[Int]) {
+ def outputAllParitions: Boolean = partitionIds.isEmpty
+
+ def shouldOutputPartition(partitionId: Int): Boolean = outputAllParitions ||
+ partitionIds.contains(partitionId)
+}
+
+case class LoreOutputInfo(outputLoreId: OutputLoreId, pathStr: String) {
+ def path: Path = new Path(pathStr)
+}
+
+object OutputLoreId {
+ private val PARTITION_ID_RANGE_REGEX = raw"(\d+)-(\d+)".r("start", "end")
+ private val PARTITION_ID_REGEX = raw"(\d+)".r("partitionId")
+ private val PARTITION_IDS_REGEX = raw"($PARTITION_ID_RANGE_REGEX|$PARTITION_ID_REGEX)" +
+ raw"( +($PARTITION_ID_RANGE_REGEX|$PARTITION_ID_REGEX))*".r
+ private val PARTITION_ID_SEP_REGEX = raw" +".r
+
+ private val OUTPUT_LORE_ID_SEP_REGEX = ", *".r
+ private val OUTPUT_LORE_ID_REGEX =
+ raw"(?\d+)(\[(?$PARTITION_IDS_REGEX)\])?".r
+
+ def apply(loreId: Int): OutputLoreId = OutputLoreId(loreId, Set.empty)
+
+ def apply(inputStr: String): OutputLoreId = {
+ OUTPUT_LORE_ID_REGEX.findFirstMatchIn(inputStr).map { m =>
+ val loreId = m.group("loreId").toInt
+ val partitionIds: Set[Int] = m.group("partitionIds") match {
+ case null => Set.empty
+ case partitionIdsStr =>
+ PARTITION_ID_SEP_REGEX.split(partitionIdsStr).flatMap {
+ case PARTITION_ID_REGEX(partitionId) =>
+ Seq(partitionId.toInt)
+ case PARTITION_ID_RANGE_REGEX(start, end) =>
+ start.toInt until end.toInt
+ case partitionIdStr => throw new IllegalArgumentException(s"Invalid partition id: " +
+ s"$partitionIdStr")
+ }.toSet
+ }
+ OutputLoreId(loreId, partitionIds)
+ }.getOrElse(throw new IllegalArgumentException(s"Invalid output lore ids: $inputStr"))
+ }
+
+ def parse(inputStr: String): OutputLoreIds = {
+ require(inputStr != null, "inputStr should not be null")
+
+ OUTPUT_LORE_ID_SEP_REGEX.split(inputStr).map(OutputLoreId(_)).map { outputLoreId =>
+ outputLoreId.loreId -> outputLoreId
+ }.toMap
+ }
+}
+
+
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/lore/dump.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/lore/dump.scala
new file mode 100644
index 00000000000..bd3c88ec462
--- /dev/null
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/lore/dump.scala
@@ -0,0 +1,105 @@
+/*
+ * Copyright (c) 2024, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.nvidia.spark.rapids.lore
+
+import com.nvidia.spark.rapids.{DumpUtils, GpuColumnVector}
+import com.nvidia.spark.rapids.GpuCoalesceExec.EmptyPartition
+import com.nvidia.spark.rapids.lore.GpuLore.pathOfChild
+import org.apache.hadoop.fs.Path
+
+import org.apache.spark.{Partition, SparkContext, TaskContext}
+import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.rapids.execution.GpuBroadcastHelper
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.vectorized.ColumnarBatch
+import org.apache.spark.util.SerializableConfiguration
+
+
+case class LoreDumpRDDInfo(idxInParent: Int, loreOutputInfo: LoreOutputInfo, attrs: Seq[Attribute])
+
+class GpuLoreDumpRDD(info: LoreDumpRDDInfo, input: RDD[ColumnarBatch])
+ extends RDD[ColumnarBatch](input) with GpuLoreRDD {
+ override def rootPath: Path = pathOfChild(info.loreOutputInfo.path, info.idxInParent)
+
+ private val hadoopConf = new SerializableConfiguration(this.context.hadoopConfiguration)
+
+ def saveMeta(): Unit = {
+ val meta = LoreRDDMeta(input.getNumPartitions, this.getPartitions.map(_.index), info.attrs)
+ GpuLore.dumpObject(meta, pathOfMeta, this.context.hadoopConfiguration)
+ }
+
+ override def compute(split: Partition, context: TaskContext): Iterator[ColumnarBatch] = {
+ if (info.loreOutputInfo.outputLoreId.shouldOutputPartition(split.index)) {
+ val originalIter = input.compute(split, context)
+ new Iterator[ColumnarBatch] {
+ var batchIdx: Int = -1
+ var nextBatch: Option[ColumnarBatch] = None
+
+ override def hasNext: Boolean = {
+ if (batchIdx == -1) {
+ loadNextBatch()
+ }
+ nextBatch.isDefined
+ }
+
+ override def next(): ColumnarBatch = {
+ val ret = dumpCurrentBatch()
+ loadNextBatch()
+ if (!hasNext) {
+ // This is the last batch, save the partition meta
+ val partitionMeta = LoreRDDPartitionMeta(batchIdx, GpuColumnVector.extractTypes(ret))
+ GpuLore.dumpObject(partitionMeta, pathOfPartitionMeta(split.index), hadoopConf.value)
+ }
+ ret
+ }
+
+ private def dumpCurrentBatch(): ColumnarBatch = {
+ val outputPath = pathOfBatch(split.index, batchIdx)
+ val outputStream = outputPath.getFileSystem(hadoopConf.value).create(outputPath, false)
+ DumpUtils.dumpToParquet(nextBatch.get, outputStream)
+ nextBatch.get
+ }
+
+ private def loadNextBatch(): Unit = {
+ if (originalIter.hasNext) {
+ nextBatch = Some(originalIter.next())
+ } else {
+ nextBatch = None
+ }
+ batchIdx += 1
+ }
+ }
+ } else {
+ input.compute(split, context)
+ }
+ }
+
+ override protected def getPartitions: Array[Partition] = {
+ input.partitions
+ }
+}
+
+class SimpleRDD(_sc: SparkContext, data: Broadcast[Any], schema: StructType) extends
+ RDD[ColumnarBatch](_sc, Nil) {
+ override def compute(split: Partition, context: TaskContext): Iterator[ColumnarBatch] = {
+ Seq(GpuBroadcastHelper.getBroadcastBatch(data, schema)).iterator
+ }
+
+ override protected def getPartitions: Array[Partition] = Array(EmptyPartition(0))
+}
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/lore/package.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/lore/package.scala
new file mode 100644
index 00000000000..5fb03454c6b
--- /dev/null
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/lore/package.scala
@@ -0,0 +1,54 @@
+/*
+ * Copyright (c) 2024, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.nvidia.spark.rapids
+
+/**
+ * Lore framework is used for dumping input data of a gpu executor to disk so that it can be
+ * replayed in local environment for performance analysis.
+ *
+ * When [[RapidsConf.TAG_LORE_ID_ENABLED]] is set, during the planning phase we will tag a lore
+ * id to each gpu operator. Lore id is guaranteed to be unique within a query, and it's supposed
+ * to be same for operators with same plan.
+ *
+ * When [[RapidsConf.LORE_DUMP_IDS]] is set, during the execution phase we will dump the input
+ * data of gpu operators with lore id to disk. The dumped data can be replayed in local
+ * environment. The dumped data will reside in [[RapidsConf.LORE_DUMP_PATH]], and typically will
+ * following directory hierarchy:
+ * {{{
+ * loreId-10/
+ * input-0/
+ * rdd.meta
+ * partition-0/
+ * partition.meta
+ * batch-0.parquet
+ * batch-1.parquet
+ * partition-1/
+ * partition.meta
+ * batch-0.parquet
+ *
+ *loreId-15/
+ * input-0/
+ * rdd.meta
+ * partition-0/
+ * partition.meta
+ * batch-0.parquet
+ * }}}
+ */
+package object lore {
+ type LoreId = Int
+ type OutputLoreIds = Map[LoreId, OutputLoreId]
+}
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/lore/replay.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/lore/replay.scala
new file mode 100644
index 00000000000..e58297c4ba6
--- /dev/null
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/lore/replay.scala
@@ -0,0 +1,96 @@
+/*
+ * Copyright (c) 2024, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.nvidia.spark.rapids.lore
+
+import ai.rapids.cudf.Table
+import com.nvidia.spark.rapids.{GpuColumnVector, GpuExec}
+import com.nvidia.spark.rapids.Arm.withResource
+import org.apache.commons.io.IOUtils
+import org.apache.hadoop.fs.Path
+
+import org.apache.spark.{Partition, SparkContext, TaskContext}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.execution.LeafExecNode
+import org.apache.spark.sql.vectorized.ColumnarBatch
+import org.apache.spark.util.SerializableConfiguration
+
+case class GpuLoreReplayExec(idxInParent: Int, parentRootPath: Path) extends LeafExecNode
+ with GpuExec {
+ private lazy val rdd = new GpuLoreReplayRDD(sparkSession.sparkContext,
+ GpuLore.pathOfChild(parentRootPath, idxInParent))
+ override def output: Seq[Attribute] = rdd.loreRDDMeta.attrs
+
+ override def doExecute(): RDD[InternalRow] = {
+ throw new UnsupportedOperationException("LoreReplayExec does not support row mode")
+ }
+
+ override protected def internalDoExecuteColumnar(): RDD[ColumnarBatch] = {
+ rdd
+ }
+}
+
+class GpuLoreReplayRDD(sc: SparkContext, override val rootPath: Path)
+ extends RDD[ColumnarBatch](sc, Nil) with GpuLoreRDD {
+ private val hadoopConf = new SerializableConfiguration(sc.hadoopConfiguration)
+ private[lore] val loreRDDMeta: LoreRDDMeta = GpuLore.loadObject(pathOfMeta, sc
+ .hadoopConfiguration)
+
+ override def compute(split: Partition, context: TaskContext): Iterator[ColumnarBatch] = {
+ val partitionPath = pathOfPartition(split.index)
+ withResource(partitionPath.getFileSystem(hadoopConf.value)) { fs =>
+ if (!fs.exists(partitionPath)) {
+ Iterator.empty
+ } else {
+ val partitionMeta = GpuLore.loadObject[LoreRDDPartitionMeta](
+ pathOfPartitionMeta(split.index), hadoopConf.value)
+ new Iterator[ColumnarBatch] {
+ private var batchIdx: Int = 0
+
+ override def hasNext: Boolean = {
+ batchIdx < partitionMeta.numBatches
+ }
+
+ override def next(): ColumnarBatch = {
+ val batchPath = pathOfBatch(split.index, batchIdx)
+ val ret = withResource(batchPath.getFileSystem(hadoopConf.value)) { fs =>
+ if (!fs.exists(batchPath)) {
+ throw new IllegalStateException(s"Batch file $batchPath does not exist")
+ }
+ withResource(fs.open(batchPath)) { fin =>
+ val buffer = IOUtils.toByteArray(fin)
+ withResource(Table.readParquet(buffer)) { restoredTable =>
+ GpuColumnVector.from(restoredTable, partitionMeta.dataType.toArray)
+ }
+ }
+
+ }
+ batchIdx += 1
+ ret
+ }
+ }
+ }
+ }
+ }
+
+ override protected def getPartitions: Array[Partition] = {
+ (0 until loreRDDMeta.numPartitions).map(LoreReplayPartition).toArray
+ }
+}
+
+case class LoreReplayPartition(override val index: Int) extends Partition
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastExchangeExec.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastExchangeExec.scala
index 51c6f52d97e..bd30459d63e 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastExchangeExec.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastExchangeExec.scala
@@ -31,6 +31,8 @@ import com.nvidia.spark.rapids._
import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource}
import com.nvidia.spark.rapids.GpuMetric._
import com.nvidia.spark.rapids.RapidsPluginImplicits._
+import com.nvidia.spark.rapids.lore.{GpuLoreDumpRDD, SimpleRDD}
+import com.nvidia.spark.rapids.lore.GpuLore.LORE_DUMP_RDD_TAG
import com.nvidia.spark.rapids.shims.{ShimBroadcastExchangeLike, ShimUnaryExecNode, SparkShimImpl}
import org.apache.spark.SparkException
@@ -486,7 +488,9 @@ abstract class GpuBroadcastExchangeExecBase(
throw new IllegalStateException("A canonicalized plan is not supposed to be executed.")
}
try {
- relationFuture.get(timeout, TimeUnit.SECONDS).asInstanceOf[Broadcast[T]]
+ val ret = relationFuture.get(timeout, TimeUnit.SECONDS)
+ doLoreDump(ret)
+ ret.asInstanceOf[Broadcast[T]]
} catch {
case ex: TimeoutException =>
logError(s"Could not execute broadcast in $timeout secs.", ex)
@@ -501,6 +505,18 @@ abstract class GpuBroadcastExchangeExecBase(
}
}
+ // We have to do this explicitly here rather than similar to the general version one in
+ // [[GpuExec]] since in adaptive execution, the broadcast value has already been calculated
+ // before we tag this plan to dump.
+ private def doLoreDump(result: Broadcast[Any]): Unit = {
+ val inner = new SimpleRDD(session.sparkContext, result, schema)
+ getTagValue(LORE_DUMP_RDD_TAG).foreach { info =>
+ val rdd = new GpuLoreDumpRDD(info, inner)
+ rdd.saveMeta()
+ rdd.foreach(_.close())
+ }
+ }
+
override def runtimeStatistics: Statistics = {
Statistics(
sizeInBytes = metrics("dataSize").value,
diff --git a/sql-plugin/src/main/spark332db/scala/org/apache/spark/sql/execution/datasources/GpuWriteFiles.scala b/sql-plugin/src/main/spark332db/scala/org/apache/spark/sql/execution/datasources/GpuWriteFiles.scala
index 7cc94359daa..f1ffcf4df1f 100644
--- a/sql-plugin/src/main/spark332db/scala/org/apache/spark/sql/execution/datasources/GpuWriteFiles.scala
+++ b/sql-plugin/src/main/spark332db/scala/org/apache/spark/sql/execution/datasources/GpuWriteFiles.scala
@@ -157,7 +157,7 @@ case class GpuWriteFilesExec(
s" mismatch:\n$this")
}
- override protected def stringArgs: Iterator[Any] = Iterator(child)
+ override def stringArgs: Iterator[Any] = Iterator(child)
}
object GpuWriteFiles {
diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/lore/GpuLoreSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/lore/GpuLoreSuite.scala
new file mode 100644
index 00000000000..ae109746b8e
--- /dev/null
+++ b/tests/src/test/scala/com/nvidia/spark/rapids/lore/GpuLoreSuite.scala
@@ -0,0 +1,152 @@
+/*
+ * Copyright (c) 2020-2024, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.nvidia.spark.rapids.lore
+
+import com.nvidia.spark.rapids.{FunSuiteWithTempDir, GpuColumnarToRowExec, RapidsConf, SparkQueryCompareTestSuite}
+import org.apache.hadoop.fs.Path
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.{functions, DataFrame, SparkSession}
+import org.apache.spark.sql.internal.SQLConf
+
+class GpuLoreSuite extends SparkQueryCompareTestSuite with FunSuiteWithTempDir with Logging {
+ test("Aggregate") {
+ doTestReplay("10") { spark =>
+ spark.range(0, 1000, 1, 100)
+ .selectExpr("id % 10 as key", "id % 100 as value")
+ .groupBy("key")
+ .agg(functions.sum("value").as("total"))
+ }
+ }
+
+ test("Broadcast join") {
+ doTestReplay("32") { spark =>
+ val df1 = spark.range(0, 1000, 1, 10)
+ .selectExpr("id % 10 as key", "id % 100 as value")
+ .groupBy("key")
+ .agg(functions.sum("value").as("count"))
+
+ val df2 = spark.range(0, 1000, 1, 10)
+ .selectExpr("(id % 10 + 5) as key", "id % 100 as value")
+ .groupBy("key")
+ .agg(functions.sum("value").as("count"))
+
+ df1.join(df2, Seq("key"))
+ }
+ }
+
+ test("Subquery") {
+ doTestReplay("13") { spark =>
+ spark.range(0, 100, 1, 10)
+ .createTempView("df1")
+
+ spark.range(50, 1000, 1, 10)
+ .createTempView("df2")
+
+ spark.sql("select * from df1 where id > (select max(id) from df2)")
+ }
+ }
+
+ test("No broadcast join") {
+ doTestReplay("30") { spark =>
+ spark.conf.set(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key, "-1")
+
+ val df1 = spark.range(0, 1000, 1, 10)
+ .selectExpr("id % 10 as key", "id % 100 as value")
+ .groupBy("key")
+ .agg(functions.sum("value").as("count"))
+
+ val df2 = spark.range(0, 1000, 1, 10)
+ .selectExpr("(id % 10 + 5) as key", "id % 100 as value")
+ .groupBy("key")
+ .agg(functions.sum("value").as("count"))
+
+ df1.join(df2, Seq("key"))
+ }
+ }
+
+ test("AQE broadcast") {
+ doTestReplay("90") { spark =>
+ spark.conf.set(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key, "true")
+
+ val df1 = spark.range(0, 1000, 1, 10)
+ .selectExpr("id % 10 as key", "id % 100 as value")
+ .groupBy("key")
+ .agg(functions.sum("value").as("count"))
+
+ val df2 = spark.range(0, 1000, 1, 10)
+ .selectExpr("(id % 10 + 5) as key", "id % 100 as value")
+ .groupBy("key")
+ .agg(functions.sum("value").as("count"))
+
+ df1.join(df2, Seq("key"))
+ }
+ }
+
+ test("AQE Exchange") {
+ doTestReplay("28") { spark =>
+ spark.conf.set(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key, "true")
+
+ spark.range(0, 1000, 1, 100)
+ .selectExpr("id % 10 as key", "id % 100 as value")
+ .groupBy("key")
+ .agg(functions.sum("value").as("total"))
+ }
+ }
+
+ test("Partition only") {
+ withGpuSparkSession{ spark =>
+ spark.conf.set(RapidsConf.LORE_DUMP_PATH.key, TEST_FILES_ROOT.getAbsolutePath)
+ spark.conf.set(RapidsConf.LORE_DUMP_IDS.key, "3[0 2]")
+
+ val df = spark.range(0, 1000, 1, 100)
+ .selectExpr("id % 10 as key", "id % 100 as value")
+
+ val res = df.collect().length
+ println(s"Length of original: $res")
+
+
+ val restoredRes = GpuColumnarToRowExec(GpuLore.restoreGpuExec(
+ new Path(s"${TEST_FILES_ROOT.getAbsolutePath}/loreId-3"),
+ spark.sparkContext.hadoopConfiguration))
+ .executeCollect()
+ .length
+
+ assert(20 == restoredRes)
+ }
+ }
+
+ private def doTestReplay(loreDumpIds: String)(dfFunc: SparkSession => DataFrame) = {
+ val loreId = OutputLoreId.parse(loreDumpIds).head._1
+ withGpuSparkSession { spark =>
+ spark.conf.set(RapidsConf.LORE_DUMP_PATH.key, TEST_FILES_ROOT.getAbsolutePath)
+ spark.conf.set(RapidsConf.LORE_DUMP_IDS.key, loreDumpIds)
+
+ val df = dfFunc(spark)
+
+ val expectedLength = df.collect().length
+
+ val restoredResultLength = GpuColumnarToRowExec(GpuLore.restoreGpuExec(
+ new Path(s"${TEST_FILES_ROOT.getAbsolutePath}/loreId-$loreId"),
+ spark.sparkContext.hadoopConfiguration))
+ .executeCollect()
+ .length
+
+ assert(expectedLength == restoredResultLength)
+ }
+ }
+}
diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/lore/OutputLoreIdSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/lore/OutputLoreIdSuite.scala
new file mode 100644
index 00000000000..d57f895c950
--- /dev/null
+++ b/tests/src/test/scala/com/nvidia/spark/rapids/lore/OutputLoreIdSuite.scala
@@ -0,0 +1,56 @@
+/*
+ * Copyright (c) 2020-2024, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.nvidia.spark.rapids.lore
+
+import org.scalatest.funsuite.AnyFunSuite
+
+class OutputLoreIdSuite extends AnyFunSuite {
+ test("Parse one output lore id") {
+ val expectedLoreIds = Map(1 -> OutputLoreId(1, Set(1, 2, 4, 8)))
+ val loreIds = OutputLoreId.parse("1[1 2 4 8]")
+
+ assert(loreIds == expectedLoreIds)
+ }
+
+ test("Parse multi output lore id") {
+ val expectedLoreIds = Map(
+ 1 -> OutputLoreId(1, Set(1, 2, 4, 8)),
+ 2 -> OutputLoreId(2, Set(1, 4, 5, 6, 7, 8, 100))
+ )
+ val loreIds = OutputLoreId.parse("1[1 2 4 8], 2[1 4-9 100]")
+
+ assert(loreIds == expectedLoreIds)
+ }
+
+ test("Parse empty output lore id") {
+ val expectedLoreIds = Map(1 -> OutputLoreId(1), 2 -> OutputLoreId(2))
+ val loreIds = OutputLoreId.parse("1 , 2")
+
+ assert(loreIds == expectedLoreIds)
+ }
+
+ test("Parse mixed") {
+ val expectedLoreIds = Map(
+ 1 -> OutputLoreId(1),
+ 2 -> OutputLoreId(2, Set(4, 5, 8)),
+ 3 -> OutputLoreId(3, Set(1, 2, 4, 8))
+ )
+ val loreIds = OutputLoreId.parse("1, 2[4-6 8] , 3[1 2 4 8]")
+
+ assert(loreIds == expectedLoreIds)
+ }
+}