From 7a26525eef6d2b235c10850d9cf628fc36ca069d Mon Sep 17 00:00:00 2001 From: Renjie Liu Date: Tue, 2 Jul 2024 09:40:29 +0800 Subject: [PATCH] Introduce LORE framework. (#11084) * Introduce lore id * Introduce lore id * Fix type * Fix type * Conf * style * part * Dump * Introduce lore framework * Add tests. * Rename test case Signed-off-by: liurenjie1024 * Fix AQE test * Fix style * Use args to display lore info. * Fix build break * Fix path in loreinfo * Remove path * Fix comments * Update configs * Fix comments * Fix config --------- Signed-off-by: liurenjie1024 --- .../advanced_configs.md | 2 + docs/dev/lore.md | 70 +++++ .../com/nvidia/spark/rapids/DumpUtils.scala | 28 +- .../spark/rapids/GpuAggregateExec.scala | 9 +- .../com/nvidia/spark/rapids/GpuExec.scala | 30 +- .../nvidia/spark/rapids/GpuOverrides.scala | 8 +- .../spark/rapids/GpuTransitionOverrides.scala | 5 + .../com/nvidia/spark/rapids/RapidsConf.scala | 31 ++ .../nvidia/spark/rapids/lore/GpuLore.scala | 295 ++++++++++++++++++ .../spark/rapids/lore/OutputLoreId.scala | 75 +++++ .../com/nvidia/spark/rapids/lore/dump.scala | 106 +++++++ .../nvidia/spark/rapids/lore/package.scala | 35 +++ .../com/nvidia/spark/rapids/lore/replay.scala | 102 ++++++ .../execution/GpuBroadcastExchangeExec.scala | 18 +- .../execution/datasources/GpuWriteFiles.scala | 2 +- .../spark/rapids/lore/GpuLoreSuite.scala | 169 ++++++++++ .../spark/rapids/lore/OutputLoreIdSuite.scala | 55 ++++ 17 files changed, 1029 insertions(+), 11 deletions(-) create mode 100644 docs/dev/lore.md create mode 100644 sql-plugin/src/main/scala/com/nvidia/spark/rapids/lore/GpuLore.scala create mode 100644 sql-plugin/src/main/scala/com/nvidia/spark/rapids/lore/OutputLoreId.scala create mode 100644 sql-plugin/src/main/scala/com/nvidia/spark/rapids/lore/dump.scala create mode 100644 sql-plugin/src/main/scala/com/nvidia/spark/rapids/lore/package.scala create mode 100644 sql-plugin/src/main/scala/com/nvidia/spark/rapids/lore/replay.scala create mode 100644 tests/src/test/scala/com/nvidia/spark/rapids/lore/GpuLoreSuite.scala create mode 100644 tests/src/test/scala/com/nvidia/spark/rapids/lore/OutputLoreIdSuite.scala diff --git a/docs/additional-functionality/advanced_configs.md b/docs/additional-functionality/advanced_configs.md index 25bba0dbd90..e8f1e7620c6 100644 --- a/docs/additional-functionality/advanced_configs.md +++ b/docs/additional-functionality/advanced_configs.md @@ -137,6 +137,8 @@ Name | Description | Default Value | Applicable at spark.rapids.sql.json.read.decimal.enabled|When reading a quoted string as a decimal Spark supports reading non-ascii unicode digits, and the RAPIDS Accelerator does not.|true|Runtime spark.rapids.sql.json.read.double.enabled|JSON reading is not 100% compatible when reading doubles.|true|Runtime spark.rapids.sql.json.read.float.enabled|JSON reading is not 100% compatible when reading floats.|true|Runtime +spark.rapids.sql.lore.dumpPath|The path to dump the LORE nodes' input data. This must be set if spark.rapids.sql.lore.idsToDump has been set. The data of each LORE node will be dumped to a subfolder with name 'loreId-' under this path. For more details, please refer to [the LORE documentation](../dev/lore.md).|None|Runtime +spark.rapids.sql.lore.idsToDump|Specify the LORE ids of operators to dump. The format is a comma separated list of LORE ids. For example: "1[0]" will dump partition 0 of input of gpu operator with lore id 1. For more details, please refer to [the LORE documentation](../dev/lore.md). If this is not set, no data will be dumped.|None|Runtime spark.rapids.sql.mode|Set the mode for the Rapids Accelerator. The supported modes are explainOnly and executeOnGPU. This config can not be changed at runtime, you must restart the application for it to take affect. The default mode is executeOnGPU, which means the RAPIDS Accelerator plugin convert the Spark operations and execute them on the GPU when possible. The explainOnly mode allows running queries on the CPU and the RAPIDS Accelerator will evaluate the queries as if it was going to run on the GPU. The explanations of what would have run on the GPU and why are output in log messages. When using explainOnly mode, the default explain output is ALL, this can be changed by setting spark.rapids.sql.explain. See that config for more details.|executeongpu|Startup spark.rapids.sql.optimizer.joinReorder.enabled|When enabled, joins may be reordered for improved query performance|true|Runtime spark.rapids.sql.python.gpu.enabled|This is an experimental feature and is likely to change in the future. Enable (true) or disable (false) support for scheduling Python Pandas UDFs with GPU resources. When enabled, pandas UDFs are assumed to share the same GPU that the RAPIDs accelerator uses and will honor the python GPU configs|false|Runtime diff --git a/docs/dev/lore.md b/docs/dev/lore.md new file mode 100644 index 00000000000..d6b28877ae7 --- /dev/null +++ b/docs/dev/lore.md @@ -0,0 +1,70 @@ +--- +layout: page +title: The Local Replay Framework +nav_order: 13 +parent: Developer Overview +--- + +# Local Replay Framework + +## Overview + +LORE (the local replay framework) is a tool that allows developer to replay the execution of a +gpu operator in local environment, so that developer could debug and profile the operator for +performance analysis. In high level it works as follows: + +1. Each gpu operator will be assigned a LORE id, which is a unique identifier for the operator. + This id is guaranteed to be unique within the same query, and guaranteed to be same when two + sql executions have same sql, same configuration, and same data. +2. In the first run of the query, developer could found the LORE id of the operator they are + interested in by checking spark ui, where LORE id usually appears in the arguments of operator. +3. In the second run of the query, developer needs to configure the LORE ids of the operators they + are interested in, and LORE will dump the input data of the operator to given path. +4. Developer could copy the dumped data to local environment, and replay the operator in local + environment. + +## Configuration + +By default, LORE id will always be generated for operators, but user could disable this behavior +by setting `spark.rapids.sql.lore.tag.enabled` to `false`. + +To tell LORE the LORE ids of the operators you are interested in, you need to set +`spark.rapids.sql.lore.idsToDump`. For example, you could set it to "1[*], 2[*], 3[*]" to tell +LORE to dump all partitions of input data of operators with id 1, 2, or 3. You can also only dump +some partition of the operator's input by appending partition numbers to lore ids. For example, +"1[0 4-6 7], 2[*]" tell LORE to dump operator with LORE id 1, but only dump partition 0, 4, 5, 6, +and 7. But for operator with LORE id 2, it will dump all partitions. + +You also need to set `spark.rapids.sql.lore.dumpPath` to tell LORE where to dump the data, the +value of which should point to a directory. All dumped data of a query will live in this +directory. A typical directory hierarchy would look like this: + +```console ++ loreId-10/ + - plan.meta + + input-0/ + - rdd.meta + + partition-0/ + - partition.meta + - batch-0.parquet + - batch-1.parquet + + partition-1/ + - partition.meta + - batch-0.parquet + + input-1/ + - rdd.meta + + partition-0/ + - partition.meta + - batch-0.parquet + - batch-1.parquet + ++ loreId-15/ + - plan.meta + + input-0/ + - rdd.meta + + partition-0/ + - partition.meta + - batch-0.parquet +``` + + diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/DumpUtils.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/DumpUtils.scala index bf949897c78..21d2de6ad68 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/DumpUtils.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/DumpUtils.scala @@ -15,7 +15,7 @@ */ package com.nvidia.spark.rapids -import java.io.{File, FileOutputStream} +import java.io.{File, FileOutputStream, OutputStream} import java.util.Random import scala.collection.mutable @@ -23,7 +23,7 @@ import scala.collection.mutable.ArrayBuffer import ai.rapids.cudf._ import ai.rapids.cudf.ColumnWriterOptions._ -import com.nvidia.spark.rapids.Arm.withResource +import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource} import org.apache.commons.io.IOUtils import org.apache.hadoop.conf.Configuration @@ -82,6 +82,23 @@ object DumpUtils extends Logging { } } + /** + * Dump columnar batch to output stream in parquet format.
+ * + * @param columnarBatch The columnar batch to be dumped, should be GPU columnar batch. It + * should be closed by caller. + * @param outputStream Will be closed after writing. + */ + def dumpToParquet(columnarBatch: ColumnarBatch, outputStream: OutputStream): Unit = { + closeOnExcept(outputStream) { _ => + withResource(GpuColumnVector.from(columnarBatch)) { table => + withResource(new ParquetDumper(outputStream, table)) { dumper => + dumper.writeTable(table) + } + } + } + } + /** * Debug utility to dump table to parquet file.
* It's running on GPU. Parquet column names are generated from table column type info.
@@ -129,12 +146,15 @@ object DumpUtils extends Logging { } // parquet dumper -class ParquetDumper(path: String, table: Table) extends HostBufferConsumer +class ParquetDumper(private val outputStream: OutputStream, table: Table) extends HostBufferConsumer with AutoCloseable { - private[this] val outputStream = new FileOutputStream(path) private[this] val tempBuffer = new Array[Byte](128 * 1024) private[this] val buffers = mutable.Queue[(HostMemoryBuffer, Long)]() + def this(path: String, table: Table) = { + this(new FileOutputStream(path), table) + } + val tableWriter: TableWriter = { // avoid anything conversion, just dump as it is val builder = ParquetDumper.parquetWriterOptionsFromTable(ParquetWriterOptions.builder(), table) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuAggregateExec.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuAggregateExec.scala index 252b9e8a95b..7fe362a6031 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuAggregateExec.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuAggregateExec.scala @@ -2000,6 +2000,7 @@ case class GpuHashAggregateExec( |${ExplainUtils.generateFieldString("Functions", aggregateExpressions)} |${ExplainUtils.generateFieldString("Aggregate Attributes", aggregateAttributes)} |${ExplainUtils.generateFieldString("Results", resultExpressions)} + |Lore: ${loreArgs.mkString(", ")} |""".stripMargin } @@ -2130,10 +2131,12 @@ case class GpuHashAggregateExec( truncatedString(allAggregateExpressions, "[", ", ", "]", maxFields) val outputString = truncatedString(output, "[", ", ", "]", maxFields) if (verbose) { - s"GpuHashAggregate(keys=$keyString, functions=$functionString, output=$outputString)" + s"$nodeName (keys=$keyString, functions=$functionString, output=$outputString) " + + s"""${loreArgs.mkString(", ")}""" } else { - s"GpuHashAggregate(keys=$keyString, functions=$functionString)," + - s" filters=${aggregateExpressions.map(_.filter)})" + s"$nodeName (keys=$keyString, functions=$functionString)," + + s" filters=${aggregateExpressions.map(_.filter)})" + + s""" ${loreArgs.mkString(", ")}""" } } // diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuExec.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuExec.scala index 1cbf899c04d..0c9f1a8ac5a 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuExec.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuExec.scala @@ -19,7 +19,10 @@ package com.nvidia.spark.rapids import ai.rapids.cudf.NvtxColor import com.nvidia.spark.rapids.Arm.withResource import com.nvidia.spark.rapids.filecache.FileCacheConf +import com.nvidia.spark.rapids.lore.{GpuLore, GpuLoreDumpRDD} +import com.nvidia.spark.rapids.lore.GpuLore.{loreIdOf, LORE_DUMP_PATH_TAG, LORE_DUMP_RDD_TAG} import com.nvidia.spark.rapids.shims.SparkShimImpl +import org.apache.hadoop.fs.Path import org.apache.spark.internal.Logging import org.apache.spark.rapids.LocationPreservingMapPartitionsRDD @@ -387,7 +390,8 @@ trait GpuExec extends SparkPlan { this.getTagValue(GpuExec.TASK_METRICS_TAG) final override def doExecuteColumnar(): RDD[ColumnarBatch] = { - val orig = internalDoExecuteColumnar() + this.dumpLoreMetaInfo() + val orig = this.dumpLoreRDD(internalDoExecuteColumnar()) val metrics = getTaskMetrics metrics.map { gpuMetrics => // This is ugly, but it reduces the need to change all exec nodes, so we are doing it here @@ -398,5 +402,29 @@ trait GpuExec extends SparkPlan { }.getOrElse(orig) } + override def stringArgs: Iterator[Any] = super.stringArgs ++ loreArgs + + protected def loreArgs: Iterator[String] = { + val loreIdStr = loreIdOf(this).map(id => s"[loreId=$id]") + val lorePathStr = getTagValue(LORE_DUMP_PATH_TAG).map(path => s"[lorePath=$path]") + val loreRDDInfoStr = getTagValue(LORE_DUMP_RDD_TAG).map(info => s"[loreRDDInfo=$info]") + + List(loreIdStr, lorePathStr, loreRDDInfoStr).flatten.iterator + } + + private def dumpLoreMetaInfo(): Unit = { + getTagValue(LORE_DUMP_PATH_TAG).foreach { rootPath => + GpuLore.dumpPlan(this, new Path(rootPath)) + } + } + + protected def dumpLoreRDD(inner: RDD[ColumnarBatch]): RDD[ColumnarBatch] = { + getTagValue(LORE_DUMP_RDD_TAG).map { info => + val rdd = new GpuLoreDumpRDD(info, inner) + rdd.saveMeta() + rdd + }.getOrElse(inner) + } + protected def internalDoExecuteColumnar(): RDD[ColumnarBatch] } diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala index 938307239d0..c86facfb4d1 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala @@ -25,6 +25,7 @@ import scala.util.control.NonFatal import ai.rapids.cudf.DType import com.nvidia.spark.rapids.RapidsConf.{SUPPRESS_PLANNING_FAILURE, TEST_CONF} import com.nvidia.spark.rapids.jni.GpuTimeZoneDB +import com.nvidia.spark.rapids.lore.GpuLore import com.nvidia.spark.rapids.shims._ import com.nvidia.spark.rapids.window.{GpuDenseRank, GpuLag, GpuLead, GpuPercentRank, GpuRank, GpuRowNumber, GpuSpecialFrameBoundary, GpuWindowExecMeta, GpuWindowSpecDefinitionMeta} import org.apache.hadoop.fs.Path @@ -4733,7 +4734,12 @@ case class GpuOverrides() extends Rule[SparkPlan] with Logging { } } } - GpuOverrides.doConvertPlan(wrap, conf, optimizations) + val convertedPlan = GpuOverrides.doConvertPlan(wrap, conf, optimizations) + if (conf.isTagLoreIdEnabled) { + GpuLore.tagForLore(convertedPlan, conf) + } else { + convertedPlan + } } } } diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuTransitionOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuTransitionOverrides.scala index 48f9de5a61a..c8596f983d9 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuTransitionOverrides.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuTransitionOverrides.scala @@ -21,6 +21,7 @@ import java.util.concurrent.atomic.AtomicInteger import scala.annotation.tailrec import scala.collection.mutable +import com.nvidia.spark.rapids.lore.GpuLore import com.nvidia.spark.rapids.shims.{GpuBatchScanExec, SparkShimImpl} import org.apache.spark.SparkContext @@ -823,6 +824,10 @@ class GpuTransitionOverrides extends Rule[SparkPlan] { updatedPlan = fixupAdaptiveExchangeReuse(updatedPlan) } + if (rapidsConf.isTagLoreIdEnabled) { + updatedPlan = GpuLore.tagForLore(updatedPlan, rapidsConf) + } + if (rapidsConf.logQueryTransformations) { logWarning(s"Transformed query:" + s"\nOriginal Plan:\n$plan\nTransformed Plan:\n$updatedPlan") diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala index 46c2806140e..22b0d6a2501 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala @@ -23,6 +23,7 @@ import scala.collection.mutable.{HashMap, ListBuffer} import ai.rapids.cudf.Cuda import com.nvidia.spark.rapids.jni.RmmSpark.OomInjectionType +import com.nvidia.spark.rapids.lore.{LoreId, OutputLoreId} import org.apache.spark.SparkConf import org.apache.spark.internal.Logging @@ -2315,6 +2316,28 @@ val SHUFFLE_COMPRESSION_LZ4_CHUNK_SIZE = conf("spark.rapids.shuffle.compression. .booleanConf .createWithDefault(false) + val TAG_LORE_ID_ENABLED = conf("spark.rapids.sql.lore.tag.enabled") + .doc("Enable add a LORE id to each gpu plan node") + .internal() + .booleanConf + .createWithDefault(true) + + val LORE_DUMP_IDS = conf("spark.rapids.sql.lore.idsToDump") + .doc("Specify the LORE ids of operators to dump. The format is a comma separated list of " + + "LORE ids. For example: \"1[0]\" will dump partition 0 of input of gpu operator " + + "with lore id 1. For more details, please refer to " + + "[the LORE documentation](../dev/lore.md). If this is not set, no data will be dumped.") + .stringConf + .createOptional + + val LORE_DUMP_PATH = conf("spark.rapids.sql.lore.dumpPath") + .doc(s"The path to dump the LORE nodes' input data. This must be set if ${LORE_DUMP_IDS.key} " + + "has been set. The data of each LORE node will be dumped to a subfolder with name " + + "'loreId-' under this path. For more details, please refer to " + + "[the LORE documentation](../dev/lore.md).") + .stringConf + .createOptional + private def printSectionHeader(category: String): Unit = println(s"\n### $category") @@ -3130,6 +3153,14 @@ class RapidsConf(conf: Map[String, String]) extends Logging { lazy val isDeltaLowShuffleMergeEnabled: Boolean = get(ENABLE_DELTA_LOW_SHUFFLE_MERGE) + lazy val isTagLoreIdEnabled: Boolean = get(TAG_LORE_ID_ENABLED) + + lazy val loreDumpIds: Map[LoreId, OutputLoreId] = get(LORE_DUMP_IDS) + .map(OutputLoreId.parse) + .getOrElse(Map.empty) + + lazy val loreDumpPath: Option[String] = get(LORE_DUMP_PATH) + private val optimizerDefaults = Map( // this is not accurate because CPU projections do have a cost due to appending values // to each row that is produced, but this needs to be a really small number because diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/lore/GpuLore.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/lore/GpuLore.scala new file mode 100644 index 00000000000..a51a1e13a5e --- /dev/null +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/lore/GpuLore.scala @@ -0,0 +1,295 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids.lore + +import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap} +import java.util.concurrent.atomic.AtomicInteger + +import scala.collection.mutable +import scala.reflect.ClassTag + +import com.nvidia.spark.rapids.{GpuColumnarToRowExec, GpuExec, RapidsConf} +import com.nvidia.spark.rapids.Arm.withResource +import com.nvidia.spark.rapids.shims.SparkShimImpl +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path + +import org.apache.spark.SparkEnv +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.trees.TreeNodeTag +import org.apache.spark.sql.execution.{BaseSubqueryExec, ExecSubqueryExpression, ReusedSubqueryExec, SparkPlan, SQLExecution} +import org.apache.spark.sql.execution.adaptive.BroadcastQueryStageExec +import org.apache.spark.sql.rapids.execution.{GpuBroadcastExchangeExec, GpuCustomShuffleReaderExec} +import org.apache.spark.sql.types.DataType +import org.apache.spark.util.SerializableConfiguration + +case class LoreRDDMeta(numPartitions: Int, outputPartitions: Seq[Int], attrs: Seq[Attribute]) + +case class LoreRDDPartitionMeta(numBatches: Int, dataType: Seq[DataType]) + +trait GpuLoreRDD { + def rootPath: Path + + def pathOfMeta: Path = new Path(rootPath, "rdd.meta") + + def pathOfPartition(partitionIndex: Int): Path = { + new Path(rootPath, s"partition-$partitionIndex") + } + + def pathOfPartitionMeta(partitionIndex: Int): Path = { + new Path(pathOfPartition(partitionIndex), "partition.meta") + } + + def pathOfBatch(partitionIndex: Int, batchIndex: Int): Path = { + new Path(pathOfPartition(partitionIndex), s"batch-$batchIndex.parquet") + } +} + + +object GpuLore { + /** + * Lore id of a plan node. + */ + val LORE_ID_TAG: TreeNodeTag[String] = new TreeNodeTag[String]("rapids.gpu.lore.id") + /** + * When a [[GpuExec]] node has this tag, it means that this node is a root node whose meta and + * input should be dumped. + */ + val LORE_DUMP_PATH_TAG: TreeNodeTag[String] = new TreeNodeTag[String]("rapids.gpu.lore.dump.path") + /** + * When a [[GpuExec]] node has this tag, it means that this node is a child node whose data + * should be dumped. + */ + val LORE_DUMP_RDD_TAG: TreeNodeTag[LoreDumpRDDInfo] = new TreeNodeTag[LoreDumpRDDInfo]( + "rapids.gpu.lore.dump.rdd.info") + + def pathOfRootPlanMeta(rootPath: Path): Path = { + new Path(rootPath, "plan.meta") + } + + def dumpPlan[T <: SparkPlan : ClassTag](plan: T, rootPath: Path): Unit = { + dumpObject(plan, pathOfRootPlanMeta(rootPath), + SparkShimImpl.sessionFromPlan(plan).sparkContext.hadoopConfiguration) + } + + def dumpObject[T: ClassTag](obj: T, path: Path, hadoopConf: Configuration): Unit = { + withResource(path.getFileSystem(hadoopConf)) { fs => + withResource(fs.create(path, false)) { fout => + val serializerStream = SparkEnv.get.serializer.newInstance().serializeStream(fout) + withResource(serializerStream) { ser => + ser.writeObject(obj) + } + } + } + } + + def loadObject[T: ClassTag](path: Path, hadoopConf: Configuration): T = { + withResource(path.getFileSystem(hadoopConf)) { fs => + withResource(fs.open(path)) { fin => + val serializerStream = SparkEnv.get.serializer.newInstance().deserializeStream(fin) + withResource(serializerStream) { ser => + ser.readObject().asInstanceOf[T] + } + } + } + } + + def pathOfChild(rootPath: Path, childIndex: Int): Path = { + new Path(rootPath, s"input-$childIndex") + } + + def restoreGpuExec(rootPath: Path, spark: SparkSession): GpuExec = { + val rootExec = loadObject[GpuExec](pathOfRootPlanMeta(rootPath), + spark.sparkContext.hadoopConfiguration) + + checkUnsupportedOperator(rootExec) + + val broadcastHadoopConf = { + val sc = spark.sparkContext + sc.broadcast(new SerializableConfiguration(spark.sparkContext.hadoopConfiguration)) + } + + // Load children + val newChildren = rootExec.children.zipWithIndex.map { case (plan, idx) => + val newChild = GpuLoreReplayExec(idx, rootPath.toString, broadcastHadoopConf) + plan match { + case b: GpuBroadcastExchangeExec => + b.withNewChildren(Seq(newChild)) + case b: BroadcastQueryStageExec => + b.broadcast.withNewChildren(Seq(newChild)) + case _ => newChild + } + } + + var nextId = rootExec.children.length + + rootExec.transformExpressionsUp { + case sub: ExecSubqueryExpression => + val newSub = restoreSubqueryPlan(nextId, sub, rootPath, broadcastHadoopConf) + nextId += 1 + newSub + }.withNewChildren(newChildren).asInstanceOf[GpuExec] + } + + private def restoreSubqueryPlan(id: Int, sub: ExecSubqueryExpression, + rootPath: Path, hadoopConf: Broadcast[SerializableConfiguration]): ExecSubqueryExpression = { + val innerPlan = sub.plan.child + + if (innerPlan.isInstanceOf[GpuExec]) { + var newChild: SparkPlan = GpuLoreReplayExec(id, rootPath.toString, hadoopConf) + + if (!innerPlan.supportsColumnar) { + newChild = GpuColumnarToRowExec(newChild) + } + val newSubqueryExec = sub.plan match { + case ReusedSubqueryExec(subqueryExec) => subqueryExec.withNewChildren(Seq(newChild)) + .asInstanceOf[BaseSubqueryExec] + case p: BaseSubqueryExec => p.withNewChildren(Seq(newChild)) + .asInstanceOf[BaseSubqueryExec] + } + sub.withNewPlan(newSubqueryExec) + } else { + throw new IllegalArgumentException(s"Subquery plan ${innerPlan.getClass.getSimpleName} " + + s"is not a GpuExec") + } + } + + /** + * Lore id generator. Key is [[SQLExecution.EXECUTION_ID_KEY]]. + */ + private val idGen: ConcurrentMap[String, AtomicInteger] = + new ConcurrentHashMap[String, AtomicInteger]() + + private def nextLoreIdOf(plan: SparkPlan): Option[Int] = { + // When the execution id is not set, it means there is no actual execution happening, in this + // case we don't need to generate lore id. + Option(SparkShimImpl.sessionFromPlan(plan) + .sparkContext + .getLocalProperty(SQLExecution.EXECUTION_ID_KEY)) + .map { executionId => + idGen.computeIfAbsent(executionId, _ => new AtomicInteger(0)).getAndIncrement() + } + } + + def tagForLore(sparkPlan: SparkPlan, rapidsConf: RapidsConf): SparkPlan = { + val loreDumpIds = rapidsConf.loreDumpIds + + val newPlan = if (loreDumpIds.nonEmpty) { + // We need to dump the output of nodes with the lore id in the dump ids + val loreOutputRootPath = rapidsConf.loreDumpPath.getOrElse(throw + new IllegalArgumentException(s"${RapidsConf.LORE_DUMP_PATH.key} must be set " + + s"when ${RapidsConf.LORE_DUMP_IDS.key} is set.")) + + val spark = SparkShimImpl.sessionFromPlan(sparkPlan) + val hadoopConf = { + val sc = spark.sparkContext + sc.broadcast(new SerializableConfiguration(sc.hadoopConfiguration)) + } + + val subqueries = mutable.Set.empty[SparkPlan] + + sparkPlan.foreachUp { + case g: GpuExec => + nextLoreIdOf(g).foreach { loreId => + g.setTagValue(LORE_ID_TAG, loreId.toString) + + loreDumpIds.get(loreId).foreach { outputLoreIds => + checkUnsupportedOperator(g) + val currentExecRootPath = new Path(loreOutputRootPath, s"loreId-$loreId") + g.setTagValue(LORE_DUMP_PATH_TAG, currentExecRootPath.toString) + val loreOutputInfo = LoreOutputInfo(outputLoreIds, + currentExecRootPath.toString) + + g.children.zipWithIndex.foreach { + case (child, idx) => + val dumpRDDInfo = LoreDumpRDDInfo(idx, loreOutputInfo, child.output, hadoopConf) + child match { + case c: BroadcastQueryStageExec => + c.broadcast.setTagValue(LORE_DUMP_RDD_TAG, dumpRDDInfo) + case o => o.setTagValue(LORE_DUMP_RDD_TAG, dumpRDDInfo) + } + } + + var nextId = g.children.length + g.transformExpressionsUp { + case sub: ExecSubqueryExpression => + if (spark.sessionState.conf.subqueryReuseEnabled) { + if (!subqueries.contains(sub.plan.canonicalized)) { + subqueries += sub.plan.canonicalized + } else { + throw new IllegalArgumentException("Subquery reuse is enabled, and we found" + + " duplicated subqueries, which is currently not supported by LORE.") + } + } + tagSubqueryPlan(nextId, sub, loreOutputInfo, hadoopConf) + nextId += 1 + sub + } + } + } + case _ => + } + + sparkPlan + + } else { + // We don't need to dump the output of the nodes, just tag the lore id + sparkPlan.foreachUp { + case g: GpuExec => + nextLoreIdOf(g).foreach { loreId => + g.setTagValue(LORE_ID_TAG, loreId.toString) + } + case _ => + } + + sparkPlan + } + + newPlan + } + + def loreIdOf(node: SparkPlan): Option[String] = { + node.getTagValue(LORE_ID_TAG) + } + + private def tagSubqueryPlan(id: Int, sub: ExecSubqueryExpression, + loreOutputInfo: LoreOutputInfo, hadoopConf: Broadcast[SerializableConfiguration]) = { + val innerPlan = sub.plan.child + if (innerPlan.isInstanceOf[GpuExec]) { + val dumpRDDInfo = LoreDumpRDDInfo(id, loreOutputInfo, innerPlan.output, + hadoopConf) + innerPlan match { + case p: GpuColumnarToRowExec => p.child.setTagValue(LORE_DUMP_RDD_TAG, dumpRDDInfo) + case c => c.setTagValue(LORE_DUMP_RDD_TAG, dumpRDDInfo) + } + } else { + throw new IllegalArgumentException(s"Subquery plan ${innerPlan.getClass.getSimpleName} " + + s"is not a GpuExec") + } + } + + private def checkUnsupportedOperator(plan: SparkPlan): Unit = { + if (plan.children.isEmpty || + plan.isInstanceOf[GpuCustomShuffleReaderExec] + ) { + throw new UnsupportedOperationException(s"Currently we don't support dumping input of " + + s"${plan.getClass.getSimpleName} operator.") + } + } +} diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/lore/OutputLoreId.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/lore/OutputLoreId.scala new file mode 100644 index 00000000000..28fa0b2dbbf --- /dev/null +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/lore/OutputLoreId.scala @@ -0,0 +1,75 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids.lore + +import org.apache.hadoop.fs.Path + +case class OutputLoreId(loreId: LoreId, partitionIds: Set[Int]) { + def outputAllParitions: Boolean = partitionIds.isEmpty + + def shouldOutputPartition(partitionId: Int): Boolean = outputAllParitions || + partitionIds.contains(partitionId) +} + +case class LoreOutputInfo(outputLoreId: OutputLoreId, pathStr: String) { + def path: Path = new Path(pathStr) +} + +object OutputLoreId { + private val PARTITION_ID_RANGE_REGEX = raw"(\d+)-(\d+)".r("start", "end") + private val PARTITION_ID_REGEX = raw"(\d+)".r("partitionId") + private val PARTITION_ID_SEP_REGEX = raw" +".r + + private val OUTPUT_LORE_ID_SEP_REGEX = ", *".r + private val OUTPUT_LORE_ID_REGEX = + raw"(?\d+)(\[(?.*)\])?".r + + def apply(loreId: Int): OutputLoreId = OutputLoreId(loreId, Set.empty) + + def apply(inputStr: String): OutputLoreId = { + OUTPUT_LORE_ID_REGEX.findFirstMatchIn(inputStr).map { m => + val loreId = m.group("loreId").toInt + val partitionIds: Set[Int] = m.group("partitionIds") match { + case partitionIdsStr if partitionIdsStr != null => + PARTITION_ID_SEP_REGEX.split(partitionIdsStr).flatMap { + case PARTITION_ID_REGEX(partitionId) => + Seq(partitionId.toInt) + case PARTITION_ID_RANGE_REGEX(start, end) => + start.toInt until end.toInt + case "*" => Set.empty + case partitionIdStr => throw new IllegalArgumentException(s"Invalid partition " + + s"id: $partitionIdStr") + }.toSet + case null => { + throw new IllegalArgumentException(s"Invalid output lore id string: $inputStr, " + + s"partition ids not found!") + } + } + OutputLoreId(loreId, partitionIds) + }.getOrElse(throw new IllegalArgumentException(s"Invalid output lore ids: $inputStr")) + } + + def parse(inputStr: String): OutputLoreIds = { + require(inputStr != null, "inputStr should not be null") + + OUTPUT_LORE_ID_SEP_REGEX.split(inputStr).map(OutputLoreId(_)).map { outputLoreId => + outputLoreId.loreId -> outputLoreId + }.toMap + } +} + + diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/lore/dump.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/lore/dump.scala new file mode 100644 index 00000000000..1b9967e1bf4 --- /dev/null +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/lore/dump.scala @@ -0,0 +1,106 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids.lore + +import com.nvidia.spark.rapids.{DumpUtils, GpuColumnVector} +import com.nvidia.spark.rapids.GpuCoalesceExec.EmptyPartition +import com.nvidia.spark.rapids.lore.GpuLore.pathOfChild +import org.apache.hadoop.fs.Path + +import org.apache.spark.{Partition, SparkContext, TaskContext} +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.rapids.execution.GpuBroadcastHelper +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.vectorized.ColumnarBatch +import org.apache.spark.util.SerializableConfiguration + + +case class LoreDumpRDDInfo(idxInParent: Int, loreOutputInfo: LoreOutputInfo, attrs: Seq[Attribute], + hadoopConf: Broadcast[SerializableConfiguration]) + +class GpuLoreDumpRDD(info: LoreDumpRDDInfo, input: RDD[ColumnarBatch]) + extends RDD[ColumnarBatch](input) with GpuLoreRDD { + override def rootPath: Path = pathOfChild(info.loreOutputInfo.path, info.idxInParent) + + def saveMeta(): Unit = { + val meta = LoreRDDMeta(input.getNumPartitions, this.getPartitions.map(_.index), info.attrs) + GpuLore.dumpObject(meta, pathOfMeta, this.context.hadoopConfiguration) + } + + override def compute(split: Partition, context: TaskContext): Iterator[ColumnarBatch] = { + if (info.loreOutputInfo.outputLoreId.shouldOutputPartition(split.index)) { + val originalIter = input.compute(split, context) + new Iterator[ColumnarBatch] { + var batchIdx: Int = -1 + var nextBatch: Option[ColumnarBatch] = None + + override def hasNext: Boolean = { + if (batchIdx == -1) { + loadNextBatch() + } + nextBatch.isDefined + } + + override def next(): ColumnarBatch = { + val ret = dumpCurrentBatch() + loadNextBatch() + if (!hasNext) { + // This is the last batch, save the partition meta + val partitionMeta = LoreRDDPartitionMeta(batchIdx, GpuColumnVector.extractTypes(ret)) + GpuLore.dumpObject(partitionMeta, pathOfPartitionMeta(split.index), + info.hadoopConf.value.value) + } + ret + } + + private def dumpCurrentBatch(): ColumnarBatch = { + val outputPath = pathOfBatch(split.index, batchIdx) + val outputStream = outputPath.getFileSystem(info.hadoopConf.value.value) + .create(outputPath, false) + DumpUtils.dumpToParquet(nextBatch.get, outputStream) + nextBatch.get + } + + private def loadNextBatch(): Unit = { + if (originalIter.hasNext) { + nextBatch = Some(originalIter.next()) + } else { + nextBatch = None + } + batchIdx += 1 + } + } + } else { + input.compute(split, context) + } + } + + override protected def getPartitions: Array[Partition] = { + input.partitions + } +} + +class SimpleRDD(_sc: SparkContext, data: Broadcast[Any], schema: StructType) extends + RDD[ColumnarBatch](_sc, Nil) { + override def compute(split: Partition, context: TaskContext): Iterator[ColumnarBatch] = { + Seq(GpuBroadcastHelper.getBroadcastBatch(data, schema)).iterator + } + + override protected def getPartitions: Array[Partition] = Array(EmptyPartition(0)) +} diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/lore/package.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/lore/package.scala new file mode 100644 index 00000000000..f304ea07d97 --- /dev/null +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/lore/package.scala @@ -0,0 +1,35 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids + +/** + * Lore framework is used for dumping input data of a gpu executor to disk so that it can be + * replayed in local environment for performance analysis. + *
+ * When [[RapidsConf.TAG_LORE_ID_ENABLED]] is set, during the planning phase we will tag a lore + * id to each gpu operator. Lore id is guaranteed to be unique within a query, and it's supposed + * to be same for operators with same plan. + *
+ * When [[RapidsConf.LORE_DUMP_IDS]] is set, during the execution phase we will dump the input + * data of gpu operators with lore id to disk. The dumped data can be replayed in local + * environment. The dumped data will reside in [[RapidsConf.LORE_DUMP_PATH]]. For more details, + * please refer to `docs/dev/lore.md`. + */ +package object lore { + type LoreId = Int + type OutputLoreIds = Map[LoreId, OutputLoreId] +} diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/lore/replay.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/lore/replay.scala new file mode 100644 index 00000000000..ffbe207646a --- /dev/null +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/lore/replay.scala @@ -0,0 +1,102 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids.lore + +import ai.rapids.cudf.Table +import com.nvidia.spark.rapids.{GpuColumnVector, GpuExec} +import com.nvidia.spark.rapids.Arm.withResource +import org.apache.commons.io.IOUtils +import org.apache.hadoop.fs.Path + +import org.apache.spark.{Partition, SparkContext, TaskContext} +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.execution.LeafExecNode +import org.apache.spark.sql.vectorized.ColumnarBatch +import org.apache.spark.util.SerializableConfiguration + +case class GpuLoreReplayExec(idxInParent: Int, parentRootPath: String, + hadoopConf: Broadcast[SerializableConfiguration]) + extends LeafExecNode + with GpuExec { + private lazy val rdd = new GpuLoreReplayRDD(sparkSession.sparkContext, + GpuLore.pathOfChild(new Path(parentRootPath), idxInParent).toString, hadoopConf) + override def output: Seq[Attribute] = rdd.loreRDDMeta.attrs + + override def doExecute(): RDD[InternalRow] = { + throw new UnsupportedOperationException("LoreReplayExec does not support row mode") + } + + override protected def internalDoExecuteColumnar(): RDD[ColumnarBatch] = { + rdd + } +} + +class GpuLoreReplayRDD(sc: SparkContext, rootPathStr: String, + hadoopConf: Broadcast[SerializableConfiguration]) + extends RDD[ColumnarBatch](sc, Nil) with GpuLoreRDD { + + override def rootPath: Path = new Path(rootPathStr) + + private[lore] val loreRDDMeta: LoreRDDMeta = GpuLore.loadObject(pathOfMeta, sc + .hadoopConfiguration) + + override def compute(split: Partition, context: TaskContext): Iterator[ColumnarBatch] = { + val partitionPath = pathOfPartition(split.index) + withResource(partitionPath.getFileSystem(hadoopConf.value.value)) { fs => + if (!fs.exists(partitionPath)) { + Iterator.empty + } else { + val partitionMeta = GpuLore.loadObject[LoreRDDPartitionMeta]( + pathOfPartitionMeta(split.index), hadoopConf.value.value) + new Iterator[ColumnarBatch] { + private var batchIdx: Int = 0 + + override def hasNext: Boolean = { + batchIdx < partitionMeta.numBatches + } + + override def next(): ColumnarBatch = { + val batchPath = pathOfBatch(split.index, batchIdx) + val ret = withResource(batchPath.getFileSystem(hadoopConf.value.value)) { fs => + if (!fs.exists(batchPath)) { + throw new IllegalStateException(s"Batch file $batchPath does not exist") + } + withResource(fs.open(batchPath)) { fin => + val buffer = IOUtils.toByteArray(fin) + withResource(Table.readParquet(buffer)) { restoredTable => + GpuColumnVector.from(restoredTable, partitionMeta.dataType.toArray) + } + } + + } + batchIdx += 1 + ret + } + } + } + } + } + + override protected def getPartitions: Array[Partition] = { + (0 until loreRDDMeta.numPartitions).map(LoreReplayPartition).toArray + } +} + +case class LoreReplayPartition(override val index: Int) extends Partition diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastExchangeExec.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastExchangeExec.scala index 51c6f52d97e..bd30459d63e 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastExchangeExec.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastExchangeExec.scala @@ -31,6 +31,8 @@ import com.nvidia.spark.rapids._ import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource} import com.nvidia.spark.rapids.GpuMetric._ import com.nvidia.spark.rapids.RapidsPluginImplicits._ +import com.nvidia.spark.rapids.lore.{GpuLoreDumpRDD, SimpleRDD} +import com.nvidia.spark.rapids.lore.GpuLore.LORE_DUMP_RDD_TAG import com.nvidia.spark.rapids.shims.{ShimBroadcastExchangeLike, ShimUnaryExecNode, SparkShimImpl} import org.apache.spark.SparkException @@ -486,7 +488,9 @@ abstract class GpuBroadcastExchangeExecBase( throw new IllegalStateException("A canonicalized plan is not supposed to be executed.") } try { - relationFuture.get(timeout, TimeUnit.SECONDS).asInstanceOf[Broadcast[T]] + val ret = relationFuture.get(timeout, TimeUnit.SECONDS) + doLoreDump(ret) + ret.asInstanceOf[Broadcast[T]] } catch { case ex: TimeoutException => logError(s"Could not execute broadcast in $timeout secs.", ex) @@ -501,6 +505,18 @@ abstract class GpuBroadcastExchangeExecBase( } } + // We have to do this explicitly here rather than similar to the general version one in + // [[GpuExec]] since in adaptive execution, the broadcast value has already been calculated + // before we tag this plan to dump. + private def doLoreDump(result: Broadcast[Any]): Unit = { + val inner = new SimpleRDD(session.sparkContext, result, schema) + getTagValue(LORE_DUMP_RDD_TAG).foreach { info => + val rdd = new GpuLoreDumpRDD(info, inner) + rdd.saveMeta() + rdd.foreach(_.close()) + } + } + override def runtimeStatistics: Statistics = { Statistics( sizeInBytes = metrics("dataSize").value, diff --git a/sql-plugin/src/main/spark332db/scala/org/apache/spark/sql/execution/datasources/GpuWriteFiles.scala b/sql-plugin/src/main/spark332db/scala/org/apache/spark/sql/execution/datasources/GpuWriteFiles.scala index 7cc94359daa..f1ffcf4df1f 100644 --- a/sql-plugin/src/main/spark332db/scala/org/apache/spark/sql/execution/datasources/GpuWriteFiles.scala +++ b/sql-plugin/src/main/spark332db/scala/org/apache/spark/sql/execution/datasources/GpuWriteFiles.scala @@ -157,7 +157,7 @@ case class GpuWriteFilesExec( s" mismatch:\n$this") } - override protected def stringArgs: Iterator[Any] = Iterator(child) + override def stringArgs: Iterator[Any] = Iterator(child) } object GpuWriteFiles { diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/lore/GpuLoreSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/lore/GpuLoreSuite.scala new file mode 100644 index 00000000000..7db46718e89 --- /dev/null +++ b/tests/src/test/scala/com/nvidia/spark/rapids/lore/GpuLoreSuite.scala @@ -0,0 +1,169 @@ +/* + * Copyright (c) 2020-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids.lore + +import com.nvidia.spark.rapids.{FunSuiteWithTempDir, GpuColumnarToRowExec, RapidsConf, SparkQueryCompareTestSuite} +import org.apache.hadoop.fs.Path + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.{functions, DataFrame, SparkSession} +import org.apache.spark.sql.internal.SQLConf + +class GpuLoreSuite extends SparkQueryCompareTestSuite with FunSuiteWithTempDir with Logging { + test("Aggregate") { + doTestReplay("10[*]") { spark => + spark.range(0, 1000, 1, 100) + .selectExpr("id % 10 as key", "id % 100 as value") + .groupBy("key") + .agg(functions.sum("value").as("total")) + } + } + + test("Broadcast join") { + doTestReplay("32[*]") { spark => + val df1 = spark.range(0, 1000, 1, 10) + .selectExpr("id % 10 as key", "id % 100 as value") + .groupBy("key") + .agg(functions.sum("value").as("count")) + + val df2 = spark.range(0, 1000, 1, 10) + .selectExpr("(id % 10 + 5) as key", "id % 100 as value") + .groupBy("key") + .agg(functions.sum("value").as("count")) + + df1.join(df2, Seq("key")) + } + } + + test("Subquery Filter") { + doTestReplay("13[*]") { spark => + spark.range(0, 100, 1, 10) + .createTempView("df1") + + spark.range(50, 1000, 1, 10) + .createTempView("df2") + + spark.sql("select * from df1 where id > (select max(id) from df2)") + } + } + + test("Subquery in projection") { + doTestReplay("11[*]") { spark => + spark.sql( + """ + |CREATE TEMPORARY VIEW t1 + |AS SELECT * FROM VALUES + |(1, "a"), + |(2, "a"), + |(3, "a") t(id, value) + |""".stripMargin) + + spark.sql( + """ + |SELECT *, (SELECT COUNT(*) FROM t1) FROM t1 + |""".stripMargin) + } + } + + test("No broadcast join") { + doTestReplay("30[*]") { spark => + spark.conf.set(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key, "-1") + + val df1 = spark.range(0, 1000, 1, 10) + .selectExpr("id % 10 as key", "id % 100 as value") + .groupBy("key") + .agg(functions.sum("value").as("count")) + + val df2 = spark.range(0, 1000, 1, 10) + .selectExpr("(id % 10 + 5) as key", "id % 100 as value") + .groupBy("key") + .agg(functions.sum("value").as("count")) + + df1.join(df2, Seq("key")) + } + } + + test("AQE broadcast") { + doTestReplay("90[*]") { spark => + spark.conf.set(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key, "true") + + val df1 = spark.range(0, 1000, 1, 10) + .selectExpr("id % 10 as key", "id % 100 as value") + .groupBy("key") + .agg(functions.sum("value").as("count")) + + val df2 = spark.range(0, 1000, 1, 10) + .selectExpr("(id % 10 + 5) as key", "id % 100 as value") + .groupBy("key") + .agg(functions.sum("value").as("count")) + + df1.join(df2, Seq("key")) + } + } + + test("AQE Exchange") { + doTestReplay("28[*]") { spark => + spark.conf.set(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key, "true") + + spark.range(0, 1000, 1, 100) + .selectExpr("id % 10 as key", "id % 100 as value") + .groupBy("key") + .agg(functions.sum("value").as("total")) + } + } + + test("Partition only") { + withGpuSparkSession{ spark => + spark.conf.set(RapidsConf.LORE_DUMP_PATH.key, TEST_FILES_ROOT.getAbsolutePath) + spark.conf.set(RapidsConf.LORE_DUMP_IDS.key, "3[0 2]") + + val df = spark.range(0, 1000, 1, 100) + .selectExpr("id % 10 as key", "id % 100 as value") + + val res = df.collect().length + println(s"Length of original: $res") + + + val restoredRes = GpuColumnarToRowExec(GpuLore.restoreGpuExec( + new Path(s"${TEST_FILES_ROOT.getAbsolutePath}/loreId-3"), spark)) + .executeCollect() + .length + + assert(20 == restoredRes) + } + } + + private def doTestReplay(loreDumpIds: String)(dfFunc: SparkSession => DataFrame) = { + val loreId = OutputLoreId.parse(loreDumpIds).head._1 + withGpuSparkSession { spark => + spark.conf.set(RapidsConf.LORE_DUMP_PATH.key, TEST_FILES_ROOT.getAbsolutePath) + spark.conf.set(RapidsConf.LORE_DUMP_IDS.key, loreDumpIds) + + val df = dfFunc(spark) + + val expectedLength = df.collect().length + + val restoredResultLength = GpuColumnarToRowExec(GpuLore.restoreGpuExec( + new Path(s"${TEST_FILES_ROOT.getAbsolutePath}/loreId-$loreId"), + spark)) + .executeCollect() + .length + + assert(expectedLength == restoredResultLength) + } + } +} diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/lore/OutputLoreIdSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/lore/OutputLoreIdSuite.scala new file mode 100644 index 00000000000..aad3d997b9d --- /dev/null +++ b/tests/src/test/scala/com/nvidia/spark/rapids/lore/OutputLoreIdSuite.scala @@ -0,0 +1,55 @@ +/* + * Copyright (c) 2020-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids.lore + +import org.scalatest.funsuite.AnyFunSuite + +class OutputLoreIdSuite extends AnyFunSuite { + test("Parse one output lore id") { + val expectedLoreIds = Map(1 -> OutputLoreId(1, Set(1, 2, 4, 8))) + val loreIds = OutputLoreId.parse("1[1 2 4 8]") + + assert(loreIds == expectedLoreIds) + } + + test("Parse multi output lore id") { + val expectedLoreIds = Map( + 1 -> OutputLoreId(1, Set(1, 2, 4, 8)), + 2 -> OutputLoreId(2, Set(1, 4, 5, 6, 7, 8, 100)) + ) + val loreIds = OutputLoreId.parse("1[1 2 4 8], 2[1 4-9 100]") + + assert(loreIds == expectedLoreIds) + } + + test("Parse empty output lore id should fail") { + assertThrows[IllegalArgumentException] { + OutputLoreId.parse(" 1, 2 ") + } + } + + test("Parse mixed") { + val expectedLoreIds = Map( + 1 -> OutputLoreId(1), + 2 -> OutputLoreId(2, Set(4, 5, 8)), + 3 -> OutputLoreId(3, Set(1, 2, 4, 8)) + ) + val loreIds = OutputLoreId.parse("1[*], 2[4-6 8] , 3[1 2 4 8]") + + assert(loreIds == expectedLoreIds) + } +}