diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index f93c879a0..cdeffa99f 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -25,6 +25,6 @@ repos:
     rev: v4.0.1
     hooks:
       - id: check-added-large-files
-        name: Check for file over 1.5MiB
-        args: ['--maxkb=1500', '--enforce-all']
+        name: Check for file over 2.0MiB
+        args: ['--maxkb=2000', '--enforce-all']
diff --git a/core/src/main/scala/com/nvidia/spark/rapids/tool/planparser/SQLPlanParser.scala b/core/src/main/scala/com/nvidia/spark/rapids/tool/planparser/SQLPlanParser.scala
index ce556b8a3..877640d7b 100644
--- a/core/src/main/scala/com/nvidia/spark/rapids/tool/planparser/SQLPlanParser.scala
+++ b/core/src/main/scala/com/nvidia/spark/rapids/tool/planparser/SQLPlanParser.scala
@@ -27,6 +27,7 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.sql.execution.SparkPlanInfo
 import org.apache.spark.sql.execution.ui.{SparkPlanGraph, SparkPlanGraphCluster, SparkPlanGraphNode}
 import org.apache.spark.sql.rapids.tool.{AppBase, BuildSide, JoinType, ToolUtils}
+import org.apache.spark.sql.rapids.tool.util.ToolsPlanGraph
 
 class ExecInfo(
     val sqlID: Long,
@@ -143,7 +144,7 @@ object SQLPlanParser extends Logging {
       sqlDesc: String,
       checker: PluginTypeChecker,
       app: AppBase): PlanInfo = {
-    val planGraph = SparkPlanGraph(planInfo)
+    val planGraph = ToolsPlanGraph(planInfo)
     // Find all the node graphs that should be excluded and send it to the parsePlanNode
     val excludedNodes = buildSkippedReusedNodesForPlan(planGraph)
     // we want the sub-graph nodes to be inside of the wholeStageCodeGen so use nodes
diff --git a/core/src/main/scala/org/apache/spark/sql/rapids/tool/AppBase.scala b/core/src/main/scala/org/apache/spark/sql/rapids/tool/AppBase.scala
index 3aafb4597..f903bed5f 100644
--- a/core/src/main/scala/org/apache/spark/sql/rapids/tool/AppBase.scala
+++ b/core/src/main/scala/org/apache/spark/sql/rapids/tool/AppBase.scala
@@ -34,9 +34,9 @@ import org.apache.spark.deploy.history.{EventLogFileReader, EventLogFileWriter}
 import org.apache.spark.internal.Logging
 import org.apache.spark.scheduler.{SparkListenerEnvironmentUpdate, SparkListenerEvent, SparkListenerJobStart, SparkListenerLogStart, StageInfo}
 import org.apache.spark.sql.execution.SparkPlanInfo
-import org.apache.spark.sql.execution.ui.{SparkPlanGraph, SparkPlanGraphNode}
+import org.apache.spark.sql.execution.ui.SparkPlanGraphNode
 import org.apache.spark.sql.rapids.tool.qualification.MLFunctions
-import org.apache.spark.sql.rapids.tool.util.{EventUtils, RapidsToolsConfUtil}
+import org.apache.spark.sql.rapids.tool.util.{EventUtils, RapidsToolsConfUtil, ToolsPlanGraph}
 import org.apache.spark.util.Utils
 
 // Handles updating and caching Spark Properties for a Spark application.
@@ -340,7 +340,7 @@ abstract class AppBase(
   protected def checkMetadataForReadSchema(sqlID: Long, planInfo: SparkPlanInfo): Unit = {
     // check if planInfo has ReadSchema
     val allMetaWithSchema = getPlanMetaWithSchema(planInfo)
-    val planGraph = SparkPlanGraph(planInfo)
+    val planGraph = ToolsPlanGraph(planInfo)
     val allNodes = planGraph.allNodes
 
     allMetaWithSchema.foreach { plan =>
@@ -365,7 +365,7 @@ abstract class AppBase(
     if (hiveEnabled) { // only scan for hive when the CatalogImplementation is using hive
       val allPlanWithHiveScan = getPlanInfoWithHiveScan(planInfo)
       allPlanWithHiveScan.foreach { hiveReadPlan =>
-        val sqlGraph = SparkPlanGraph(hiveReadPlan)
+        val sqlGraph = ToolsPlanGraph(hiveReadPlan)
         val hiveScanNode = sqlGraph.allNodes.head
         val scanHiveMeta = HiveParseHelper.parseReadNode(hiveScanNode)
         dataSourceInfo += DataSourceCase(sqlID,
diff --git a/core/src/main/scala/org/apache/spark/sql/rapids/tool/profiling/ApplicationInfo.scala b/core/src/main/scala/org/apache/spark/sql/rapids/tool/profiling/ApplicationInfo.scala
index f97c3119c..c623a9a71 100644
--- a/core/src/main/scala/org/apache/spark/sql/rapids/tool/profiling/ApplicationInfo.scala
+++ b/core/src/main/scala/org/apache/spark/sql/rapids/tool/profiling/ApplicationInfo.scala
@@ -28,8 +28,8 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.scheduler._
 import org.apache.spark.sql.execution.SparkPlanInfo
 import org.apache.spark.sql.execution.metric.SQLMetricInfo
-import org.apache.spark.sql.execution.ui.SparkPlanGraph
 import org.apache.spark.sql.rapids.tool.{AppBase, ToolUtils}
+import org.apache.spark.sql.rapids.tool.util.ToolsPlanGraph
 import org.apache.spark.ui.UIUtils
 
@@ -238,7 +238,7 @@ class ApplicationInfo(
   // Connects Operators to Stages using AccumulatorIDs
   def connectOperatorToStage(): Unit = {
     for ((sqlId, planInfo) <- sqlPlans) {
-      val planGraph = SparkPlanGraph(planInfo)
+      val planGraph = ToolsPlanGraph(planInfo)
       // Maps stages to operators by checking for non-zero intersection
       // between nodeMetrics and stageAccumulateIDs
       val nodeIdToStage = planGraph.allNodes.map { node =>
@@ -256,7 +256,7 @@ class ApplicationInfo(
     connectOperatorToStage()
     for ((sqlID, planInfo) <- sqlPlans) {
       checkMetadataForReadSchema(sqlID, planInfo)
-      val planGraph = SparkPlanGraph(planInfo)
+      val planGraph = ToolsPlanGraph(planInfo)
       // SQLPlanMetric is a case Class of
       // (name: String,accumulatorId: Long,metricType: String)
       val allnodes = planGraph.allNodes
@@ -339,7 +339,7 @@ class ApplicationInfo(
         v.contains(s)
       }.keys.toSeq
       val nodeNames = sqlPlans.get(j.sqlID.get).map { planInfo =>
-        val nodes = SparkPlanGraph(planInfo).allNodes
+        val nodes = ToolsPlanGraph(planInfo).allNodes
         val validNodes = nodes.filter { n =>
           nodeIds.contains((j.sqlID.get, n.id))
         }
diff --git a/core/src/main/scala/org/apache/spark/sql/rapids/tool/qualification/QualificationAppInfo.scala b/core/src/main/scala/org/apache/spark/sql/rapids/tool/qualification/QualificationAppInfo.scala
index b4c72140d..c0bd63b9d 100644
--- a/core/src/main/scala/org/apache/spark/sql/rapids/tool/qualification/QualificationAppInfo.scala
+++ b/core/src/main/scala/org/apache/spark/sql/rapids/tool/qualification/QualificationAppInfo.scala
@@ -28,8 +28,8 @@ import org.apache.hadoop.conf.Configuration
 import org.apache.spark.internal.Logging
 import org.apache.spark.scheduler.{SparkListener, SparkListenerEnvironmentUpdate, SparkListenerEvent, SparkListenerJobStart}
 import org.apache.spark.sql.execution.SparkPlanInfo
-import org.apache.spark.sql.execution.ui.SparkPlanGraph
 import org.apache.spark.sql.rapids.tool.{AppBase, GpuEventLogException, SupportedMLFuncsName, ToolUtils}
+import org.apache.spark.sql.rapids.tool.util.ToolsPlanGraph
 
 class QualificationAppInfo(
     eventLogInfo: Option[EventLogInfo],
@@ -765,7 +765,7 @@ class QualificationAppInfo(
 
   private[qualification] def processSQLPlan(sqlID: Long, planInfo: SparkPlanInfo): Unit = {
     checkMetadataForReadSchema(sqlID, planInfo)
-    val planGraph = SparkPlanGraph(planInfo)
+    val planGraph = ToolsPlanGraph(planInfo)
     val allnodes = planGraph.allNodes
     for (node <- allnodes) {
       checkGraphNodeForReads(sqlID, node)
diff --git a/core/src/main/scala/org/apache/spark/sql/rapids/tool/util/ToolsPlanGraph.scala b/core/src/main/scala/org/apache/spark/sql/rapids/tool/util/ToolsPlanGraph.scala
new file mode 100644
index 000000000..308735fa0
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/sql/rapids/tool/util/ToolsPlanGraph.scala
@@ -0,0 +1,132 @@
+/*
+ * Copyright (c) 2024, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.rapids.tool.util
+
+import java.util.concurrent.atomic.AtomicLong
+
+import scala.collection.mutable
+
+import org.apache.spark.sql.execution.SparkPlanInfo
+import org.apache.spark.sql.execution.ui.{SparkPlanGraph, SparkPlanGraphCluster, SparkPlanGraphEdge, SparkPlanGraphNode, SQLPlanMetric}
+
+/**
+ * This code is mostly copied from org.apache.spark.sql.execution.ui.SparkPlanGraph,
+ * with changes to handle GPU nodes. Without this special handling, the default SparkPlanGraph
+ * would not recognize reused exchange/subquery nodes, leading to duplicate nodes.
+ *
+ * Build a SparkPlanGraph from the root of a SparkPlan tree.
+ */
+object ToolsPlanGraph {
+  /**
+   * Build a SparkPlanGraph from the root of a SparkPlan tree.
+   */
+  def apply(planInfo: SparkPlanInfo): SparkPlanGraph = {
+    val nodeIdGenerator = new AtomicLong(0)
+    val nodes = mutable.ArrayBuffer[SparkPlanGraphNode]()
+    val edges = mutable.ArrayBuffer[SparkPlanGraphEdge]()
+    val exchanges = mutable.HashMap[SparkPlanInfo, SparkPlanGraphNode]()
+    buildSparkPlanGraphNode(planInfo, nodeIdGenerator, nodes, edges, null, null, exchanges)
+    new SparkPlanGraph(nodes, edges)
+  }
+
+  private def processPlanInfo(nodeName: String): String = {
+    if (nodeName.startsWith("Gpu")) {
+      nodeName.replaceFirst("Gpu", "")
+    } else {
+      nodeName
+    }
+  }
+
+  private def buildSparkPlanGraphNode(
+      planInfo: SparkPlanInfo,
+      nodeIdGenerator: AtomicLong,
+      nodes: mutable.ArrayBuffer[SparkPlanGraphNode],
+      edges: mutable.ArrayBuffer[SparkPlanGraphEdge],
+      parent: SparkPlanGraphNode,
+      subgraph: SparkPlanGraphCluster,
+      exchanges: mutable.HashMap[SparkPlanInfo, SparkPlanGraphNode]): Unit = {
+    processPlanInfo(planInfo.nodeName) match {
+      case name if name.startsWith("WholeStageCodegen") =>
+        val metrics = planInfo.metrics.map { metric =>
+          SQLPlanMetric(metric.name, metric.accumulatorId, metric.metricType)
+        }
+
+        val cluster = new SparkPlanGraphCluster(
+          nodeIdGenerator.getAndIncrement(),
+          planInfo.nodeName,
+          planInfo.simpleString,
+          mutable.ArrayBuffer[SparkPlanGraphNode](),
+          metrics)
+        nodes += cluster
+
+        buildSparkPlanGraphNode(
+          planInfo.children.head, nodeIdGenerator, nodes, edges, parent, cluster, exchanges)
+      case "InputAdapter" =>
+        buildSparkPlanGraphNode(
+          planInfo.children.head, nodeIdGenerator, nodes, edges, parent, null, exchanges)
+      case "BroadcastQueryStage" | "ShuffleQueryStage" =>
+        if (exchanges.contains(planInfo.children.head)) {
+          // Point to the re-used exchange
+          val node = exchanges(planInfo.children.head)
+          edges += SparkPlanGraphEdge(node.id, parent.id)
+        } else {
+          buildSparkPlanGraphNode(
+            planInfo.children.head, nodeIdGenerator, nodes, edges, parent, null, exchanges)
+        }
+      case "TableCacheQueryStage" =>
+        buildSparkPlanGraphNode(
+          planInfo.children.head, nodeIdGenerator, nodes, edges, parent, null, exchanges)
+      case "Subquery" | "SubqueryBroadcast" if subgraph != null =>
+        // Subquery should not be included in WholeStageCodegen
+        buildSparkPlanGraphNode(planInfo, nodeIdGenerator, nodes, edges, parent, null, exchanges)
+      case "Subquery" | "SubqueryBroadcast" if exchanges.contains(planInfo) =>
+        // Point to the re-used subquery
+        val node = exchanges(planInfo)
+        edges += SparkPlanGraphEdge(node.id, parent.id)
+      case "ReusedSubquery" =>
+        // Re-used subquery might appear before the original subquery, so skip this node and let
+        // the previous `case` make sure the re-used and the original point to the same node.
+        buildSparkPlanGraphNode(
+          planInfo.children.head, nodeIdGenerator, nodes, edges, parent, subgraph, exchanges)
+      case "ReusedExchange" if exchanges.contains(planInfo.children.head) =>
+        // Point to the re-used exchange
+        val node = exchanges(planInfo.children.head)
+        edges += SparkPlanGraphEdge(node.id, parent.id)
+      case name =>
+        val metrics = planInfo.metrics.map { metric =>
+          SQLPlanMetric(metric.name, metric.accumulatorId, metric.metricType)
+        }
+        val node = new SparkPlanGraphNode(
+          nodeIdGenerator.getAndIncrement(), planInfo.nodeName,
+          planInfo.simpleString, metrics)
+        if (subgraph == null) {
+          nodes += node
+        } else {
+          subgraph.nodes += node
+        }
+        if (name.contains("Exchange") || name.contains("Subquery")) {
+          exchanges += planInfo -> node
+        }
+
+        if (parent != null) {
+          edges += SparkPlanGraphEdge(node.id, parent.id)
+        }
+        planInfo.children.foreach(
+          buildSparkPlanGraphNode(_, nodeIdGenerator, nodes, edges, node, subgraph, exchanges))
+    }
+  }
+}
diff --git a/core/src/test/resources/spark-events-profiling/nds_q66_gpu.zstd b/core/src/test/resources/spark-events-profiling/nds_q66_gpu.zstd
new file mode 100644
index 000000000..e6ffca3fe
Binary files /dev/null and b/core/src/test/resources/spark-events-profiling/nds_q66_gpu.zstd differ
diff --git a/core/src/test/scala/com/nvidia/spark/rapids/tool/planparser/SqlPlanParserSuite.scala b/core/src/test/scala/com/nvidia/spark/rapids/tool/planparser/SqlPlanParserSuite.scala
index 16a0f4388..de6abf3be 100644
--- a/core/src/test/scala/com/nvidia/spark/rapids/tool/planparser/SqlPlanParserSuite.scala
+++ b/core/src/test/scala/com/nvidia/spark/rapids/tool/planparser/SqlPlanParserSuite.scala
@@ -847,7 +847,7 @@ class SQLPlanParserSuite extends BaseTestSuite {
     }
   }
 
-  test("get_json_object is supported in Project") {
+  test("get_json_object is unsupported in Project") {
     // get_json_object is disabled by default in the RAPIDS plugin
     TrampolineUtil.withTempDir { parquetoutputLoc =>
       TrampolineUtil.withTempDir { eventLogDir =>
diff --git a/core/src/test/scala/com/nvidia/spark/rapids/tool/profiling/ApplicationInfoSuite.scala b/core/src/test/scala/com/nvidia/spark/rapids/tool/profiling/ApplicationInfoSuite.scala
index df2015956..65dcd7c81 100644
--- a/core/src/test/scala/com/nvidia/spark/rapids/tool/profiling/ApplicationInfoSuite.scala
+++ b/core/src/test/scala/com/nvidia/spark/rapids/tool/profiling/ApplicationInfoSuite.scala
@@ -895,4 +895,14 @@ class ApplicationInfoSuite extends FunSuite with Logging {
       }
     }
   }
+
+  test("test gpu reused subquery") {
+    val apps = ToolTestUtils.processProfileApps(Array(s"$logDir/nds_q66_gpu.zstd"), sparkSession)
+    val collect = new CollectInformation(apps)
+    val sqlToStageInfo = collect.getSQLToStage
+    val countScanParquet = sqlToStageInfo.flatMap(_.nodeNames).count(_.contains("GpuScan parquet"))
+    // There are 12 `GpuScan parquet` raw nodes, but 4 are inside `ReusedExchange`, 1 is inside
+    // `ReusedSubquery`, so we expect 7.
+    assert(countScanParquet == 7)
+  }
 }
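
For reference, a minimal sketch of how the new ToolsPlanGraph is meant to be consumed (the countGpuScanNodes helper and the planInfo value are illustrative only; planInfo would be a SparkPlanInfo captured from an event log, such as an entry of the sqlPlans map used above):

    import org.apache.spark.sql.execution.SparkPlanInfo
    import org.apache.spark.sql.rapids.tool.util.ToolsPlanGraph

    // ToolsPlanGraph strips the "Gpu" prefix before matching node names, so reused
    // exchanges/subqueries in GPU plans collapse into a single node instead of being
    // duplicated as they are with the stock SparkPlanGraph.
    def countGpuScanNodes(planInfo: SparkPlanInfo): Int = {
      val graph = ToolsPlanGraph(planInfo)
      graph.allNodes.count(_.name.contains("GpuScan parquet"))
    }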