Remove ReusedSubquery from SparkPlanGraph construction
Signed-off-by: Ahmed Hussein (amahussein) <[email protected]>

Fixes NVIDIA#718

- Reused subqueries should be excluded from the metrics.
- Added a common graph builder to be used for both CPU/GPU logs (see the call-site sketch below)
- Added the unit test for GPU eventlogs
amahussein committed Jan 25, 2024
1 parent a154c0b commit bee74bb
Showing 9 changed files with 150 additions and 14 deletions.
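A minimal sketch of the call-site change applied throughout the files below (the buildGraph helper is hypothetical; the tools call ToolsPlanGraph directly at each site): graph construction now goes through the new ToolsPlanGraph object added by this commit rather than Spark's SparkPlanGraph companion object.

import org.apache.spark.sql.execution.SparkPlanInfo
import org.apache.spark.sql.execution.ui.SparkPlanGraph
import org.apache.spark.sql.rapids.tool.util.ToolsPlanGraph

// Hypothetical helper mirroring the call sites changed in this diff.
def buildGraph(planInfo: SparkPlanInfo): SparkPlanGraph = {
  // was: SparkPlanGraph(planInfo)
  ToolsPlanGraph(planInfo)
}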
.pre-commit-config.yaml: 4 changes (2 additions & 2 deletions)
@@ -25,6 +25,6 @@ repos:
rev: v4.0.1
hooks:
- id: check-added-large-files
-name: Check for file over 1.5MiB
-args: ['--maxkb=1500', '--enforce-all']
+name: Check for file over 2.0MiB
+args: ['--maxkb=2000', '--enforce-all']

@@ -27,6 +27,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.sql.execution.SparkPlanInfo
import org.apache.spark.sql.execution.ui.{SparkPlanGraph, SparkPlanGraphCluster, SparkPlanGraphNode}
import org.apache.spark.sql.rapids.tool.{AppBase, BuildSide, JoinType, ToolUtils}
+import org.apache.spark.sql.rapids.tool.util.ToolsPlanGraph

class ExecInfo(
val sqlID: Long,
@@ -143,7 +144,7 @@ object SQLPlanParser extends Logging {
sqlDesc: String,
checker: PluginTypeChecker,
app: AppBase): PlanInfo = {
-val planGraph = SparkPlanGraph(planInfo)
+val planGraph = ToolsPlanGraph(planInfo)
// Find all the node graphs that should be excluded and send it to the parsePlanNode
val excludedNodes = buildSkippedReusedNodesForPlan(planGraph)
// we want the sub-graph nodes to be inside of the wholeStageCodeGen so use nodes
@@ -34,9 +34,9 @@ import org.apache.spark.deploy.history.{EventLogFileReader, EventLogFileWriter}
import org.apache.spark.internal.Logging
import org.apache.spark.scheduler.{SparkListenerEnvironmentUpdate, SparkListenerEvent, SparkListenerJobStart, SparkListenerLogStart, StageInfo}
import org.apache.spark.sql.execution.SparkPlanInfo
-import org.apache.spark.sql.execution.ui.{SparkPlanGraph, SparkPlanGraphNode}
+import org.apache.spark.sql.execution.ui.SparkPlanGraphNode
import org.apache.spark.sql.rapids.tool.qualification.MLFunctions
-import org.apache.spark.sql.rapids.tool.util.{EventUtils, RapidsToolsConfUtil}
+import org.apache.spark.sql.rapids.tool.util.{EventUtils, RapidsToolsConfUtil, ToolsPlanGraph}
import org.apache.spark.util.Utils

// Handles updating and caching Spark Properties for a Spark application.
@@ -340,7 +340,7 @@ abstract class AppBase(
protected def checkMetadataForReadSchema(sqlID: Long, planInfo: SparkPlanInfo): Unit = {
// check if planInfo has ReadSchema
val allMetaWithSchema = getPlanMetaWithSchema(planInfo)
-val planGraph = SparkPlanGraph(planInfo)
+val planGraph = ToolsPlanGraph(planInfo)
val allNodes = planGraph.allNodes

allMetaWithSchema.foreach { plan =>
@@ -365,7 +365,7 @@
if (hiveEnabled) { // only scan for hive when the CatalogImplementation is using hive
val allPlanWithHiveScan = getPlanInfoWithHiveScan(planInfo)
allPlanWithHiveScan.foreach { hiveReadPlan =>
-val sqlGraph = SparkPlanGraph(hiveReadPlan)
+val sqlGraph = ToolsPlanGraph(hiveReadPlan)
val hiveScanNode = sqlGraph.allNodes.head
val scanHiveMeta = HiveParseHelper.parseReadNode(hiveScanNode)
dataSourceInfo += DataSourceCase(sqlID,
@@ -28,8 +28,8 @@ import org.apache.spark.internal.Logging
import org.apache.spark.scheduler._
import org.apache.spark.sql.execution.SparkPlanInfo
import org.apache.spark.sql.execution.metric.SQLMetricInfo
-import org.apache.spark.sql.execution.ui.SparkPlanGraph
import org.apache.spark.sql.rapids.tool.{AppBase, ToolUtils}
+import org.apache.spark.sql.rapids.tool.util.ToolsPlanGraph
import org.apache.spark.ui.UIUtils


@@ -238,7 +238,7 @@ class ApplicationInfo(
// Connects Operators to Stages using AccumulatorIDs
def connectOperatorToStage(): Unit = {
for ((sqlId, planInfo) <- sqlPlans) {
-val planGraph = SparkPlanGraph(planInfo)
+val planGraph = ToolsPlanGraph(planInfo)
// Maps stages to operators by checking for non-zero intersection
// between nodeMetrics and stageAccumulateIDs
val nodeIdToStage = planGraph.allNodes.map { node =>
@@ -256,7 +256,7 @@
connectOperatorToStage()
for ((sqlID, planInfo) <- sqlPlans) {
checkMetadataForReadSchema(sqlID, planInfo)
-val planGraph = SparkPlanGraph(planInfo)
+val planGraph = ToolsPlanGraph(planInfo)
// SQLPlanMetric is a case Class of
// (name: String,accumulatorId: Long,metricType: String)
val allnodes = planGraph.allNodes
@@ -339,7 +339,7 @@
v.contains(s)
}.keys.toSeq
val nodeNames = sqlPlans.get(j.sqlID.get).map { planInfo =>
-val nodes = SparkPlanGraph(planInfo).allNodes
+val nodes = ToolsPlanGraph(planInfo).allNodes
val validNodes = nodes.filter { n =>
nodeIds.contains((j.sqlID.get, n.id))
}
@@ -28,8 +28,8 @@ import org.apache.hadoop.conf.Configuration
import org.apache.spark.internal.Logging
import org.apache.spark.scheduler.{SparkListener, SparkListenerEnvironmentUpdate, SparkListenerEvent, SparkListenerJobStart}
import org.apache.spark.sql.execution.SparkPlanInfo
-import org.apache.spark.sql.execution.ui.SparkPlanGraph
import org.apache.spark.sql.rapids.tool.{AppBase, GpuEventLogException, SupportedMLFuncsName, ToolUtils}
+import org.apache.spark.sql.rapids.tool.util.ToolsPlanGraph

class QualificationAppInfo(
eventLogInfo: Option[EventLogInfo],
@@ -765,7 +765,7 @@ class QualificationAppInfo(

private[qualification] def processSQLPlan(sqlID: Long, planInfo: SparkPlanInfo): Unit = {
checkMetadataForReadSchema(sqlID, planInfo)
-val planGraph = SparkPlanGraph(planInfo)
+val planGraph = ToolsPlanGraph(planInfo)
val allnodes = planGraph.allNodes
for (node <- allnodes) {
checkGraphNodeForReads(sqlID, node)
@@ -0,0 +1,125 @@
/*
 * Copyright (c) 2024, NVIDIA CORPORATION.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.rapids.tool.util

import java.util.concurrent.atomic.AtomicLong

import scala.collection.mutable

import org.apache.spark.sql.execution.SparkPlanInfo
import org.apache.spark.sql.execution.ui.{SparkPlanGraph, SparkPlanGraphCluster, SparkPlanGraphEdge, SparkPlanGraphNode, SQLPlanMetric}

object ToolsPlanGraph {
  /**
   * Build a SparkPlanGraph from the root of a SparkPlan tree.
   */
  def apply(planInfo: SparkPlanInfo): SparkPlanGraph = {
    val nodeIdGenerator = new AtomicLong(0)
    val nodes = mutable.ArrayBuffer[SparkPlanGraphNode]()
    val edges = mutable.ArrayBuffer[SparkPlanGraphEdge]()
    val exchanges = mutable.HashMap[SparkPlanInfo, SparkPlanGraphNode]()
    buildSparkPlanGraphNode(planInfo, nodeIdGenerator, nodes, edges, null, null, exchanges)
    new SparkPlanGraph(nodes, edges)
  }

  private def processPlanInfo(nodeName: String): String = {
    if (nodeName.startsWith("Gpu")) {
      nodeName.replaceFirst("Gpu", "")
    } else {
      nodeName
    }
  }

  private def buildSparkPlanGraphNode(
      planInfo: SparkPlanInfo,
      nodeIdGenerator: AtomicLong,
      nodes: mutable.ArrayBuffer[SparkPlanGraphNode],
      edges: mutable.ArrayBuffer[SparkPlanGraphEdge],
      parent: SparkPlanGraphNode,
      subgraph: SparkPlanGraphCluster,
      exchanges: mutable.HashMap[SparkPlanInfo, SparkPlanGraphNode]): Unit = {
    processPlanInfo(planInfo.nodeName) match {
      case name if name.startsWith("WholeStageCodegen") =>
        val metrics = planInfo.metrics.map { metric =>
          SQLPlanMetric(metric.name, metric.accumulatorId, metric.metricType)
        }

        val cluster = new SparkPlanGraphCluster(
          nodeIdGenerator.getAndIncrement(),
          planInfo.nodeName,
          planInfo.simpleString,
          mutable.ArrayBuffer[SparkPlanGraphNode](),
          metrics)
        nodes += cluster

        buildSparkPlanGraphNode(
          planInfo.children.head, nodeIdGenerator, nodes, edges, parent, cluster, exchanges)
      case "InputAdapter" =>
        buildSparkPlanGraphNode(
          planInfo.children.head, nodeIdGenerator, nodes, edges, parent, null, exchanges)
      case "BroadcastQueryStage" | "ShuffleQueryStage" =>
        if (exchanges.contains(planInfo.children.head)) {
          // Point to the re-used exchange
          val node = exchanges(planInfo.children.head)
          edges += SparkPlanGraphEdge(node.id, parent.id)
        } else {
          buildSparkPlanGraphNode(
            planInfo.children.head, nodeIdGenerator, nodes, edges, parent, null, exchanges)
        }
      case "TableCacheQueryStage" =>
        buildSparkPlanGraphNode(
          planInfo.children.head, nodeIdGenerator, nodes, edges, parent, null, exchanges)
      case "Subquery" | "SubqueryBroadcast" if subgraph != null =>
        // Subquery should not be included in WholeStageCodegen
        buildSparkPlanGraphNode(planInfo, nodeIdGenerator, nodes, edges, parent, null, exchanges)
      case "Subquery" | "SubqueryBroadcast" if exchanges.contains(planInfo) =>
        // Point to the re-used subquery
        val node = exchanges(planInfo)
        edges += SparkPlanGraphEdge(node.id, parent.id)
      case "ReusedSubquery" =>
        // Re-used subquery might appear before the original subquery, so skip this node and let
        // the previous `case` make sure the re-used and the original point to the same node.
        buildSparkPlanGraphNode(
          planInfo.children.head, nodeIdGenerator, nodes, edges, parent, subgraph, exchanges)
      case "ReusedExchange" if exchanges.contains(planInfo.children.head) =>
        // Point to the re-used exchange
        val node = exchanges(planInfo.children.head)
        edges += SparkPlanGraphEdge(node.id, parent.id)
      case name =>
        val metrics = planInfo.metrics.map { metric =>
          SQLPlanMetric(metric.name, metric.accumulatorId, metric.metricType)
        }
        val node = new SparkPlanGraphNode(
          nodeIdGenerator.getAndIncrement(), planInfo.nodeName,
          planInfo.simpleString, metrics)
        if (subgraph == null) {
          nodes += node
        } else {
          subgraph.nodes += node
        }
        if (name.contains("Exchange") || name.contains("Subquery")) {
          exchanges += planInfo -> node
        }

        if (parent != null) {
          edges += SparkPlanGraphEdge(node.id, parent.id)
        }
        planInfo.children.foreach(
          buildSparkPlanGraphNode(_, nodeIdGenerator, nodes, edges, node, subgraph, exchanges))
    }
  }
}
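A minimal usage sketch of the behavior the new builder and unit test target, using a small hand-built plan (the node names and simpleString values are hypothetical, chosen only to exercise the ReusedSubquery path): the ReusedSubquery wrapper collapses onto the original Subquery node, so the scan beneath it appears in the graph only once and its metrics are not double counted.

import org.apache.spark.sql.execution.SparkPlanInfo
import org.apache.spark.sql.rapids.tool.util.ToolsPlanGraph

// Hypothetical plan: a filter that references the same subquery twice,
// once directly and once through a ReusedSubquery wrapper.
val scan = new SparkPlanInfo("GpuScan parquet", "GpuScan parquet tbl",
  Seq.empty, Map.empty, Seq.empty)
val subquery = new SparkPlanInfo("Subquery", "Subquery sub#0",
  Seq(scan), Map.empty, Seq.empty)
val reused = new SparkPlanInfo("ReusedSubquery", "ReusedSubquery sub#0",
  Seq(subquery), Map.empty, Seq.empty)
val filter = new SparkPlanInfo("GpuFilter", "GpuFilter (c1 > sub#0)",
  Seq(subquery, reused), Map.empty, Seq.empty)

val graph = ToolsPlanGraph(filter)
// The ReusedSubquery branch is wired back to the existing Subquery node rather than
// duplicating its children, so only one "GpuScan parquet" node ends up in the graph.
assert(graph.allNodes.count(_.name.contains("GpuScan parquet")) == 1)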
Binary file not shown.
@@ -847,7 +847,7 @@ class SQLPlanParserSuite extends BaseTestSuite {
}
}

-test("get_json_object is supported in Project") {
+test("get_json_object is unsupported in Project") {
// get_json_object is disabled by default in the RAPIDS plugin
TrampolineUtil.withTempDir { parquetoutputLoc =>
TrampolineUtil.withTempDir { eventLogDir =>
@@ -895,4 +895,14 @@ class ApplicationInfoSuite extends FunSuite with Logging {
}
}
}

+test("test gpu reused subquery") {
+val apps = ToolTestUtils.processProfileApps(Array(s"$logDir/nds_q66_gpu.zstd"), sparkSession)
+val collect = new CollectInformation(apps)
+val sqlToStageInfo = collect.getSQLToStage
+val countScanParquet = sqlToStageInfo.flatMap(_.nodeNames).count(_.contains("GpuScan parquet"))
+// There are 12 `GpuScan parquet` raw nodes, but 4 are inside `ReusedExchange`, 1 is inside
+// `ReusedSubquery`, so we expect 7.
+assert(countScanParquet == 7)
+}
}
