Remove ReusedSubquery from SparkPlanGraph construction #741

Merged: 2 commits, Jan 26, 2024
Changes from 1 commit
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
@@ -25,6 +25,6 @@ repos:
     rev: v4.0.1
     hooks:
       - id: check-added-large-files
-        name: Check for file over 1.5MiB
-        args: ['--maxkb=1500', '--enforce-all']
+        name: Check for file over 2.0MiB
+        args: ['--maxkb=2000', '--enforce-all']

@@ -27,6 +27,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.sql.execution.SparkPlanInfo
import org.apache.spark.sql.execution.ui.{SparkPlanGraph, SparkPlanGraphCluster, SparkPlanGraphNode}
import org.apache.spark.sql.rapids.tool.{AppBase, BuildSide, JoinType, ToolUtils}
+import org.apache.spark.sql.rapids.tool.util.ToolsPlanGraph

class ExecInfo(
val sqlID: Long,
@@ -143,7 +144,7 @@ object SQLPlanParser extends Logging {
sqlDesc: String,
checker: PluginTypeChecker,
app: AppBase): PlanInfo = {
-    val planGraph = SparkPlanGraph(planInfo)
+    val planGraph = ToolsPlanGraph(planInfo)
// Find all the node graphs that should be excluded and send it to the parsePlanNode
val excludedNodes = buildSkippedReusedNodesForPlan(planGraph)
// we want the sub-graph nodes to be inside of the wholeStageCodeGen so use nodes
@@ -34,9 +34,9 @@ import org.apache.spark.deploy.history.{EventLogFileReader, EventLogFileWriter}
import org.apache.spark.internal.Logging
import org.apache.spark.scheduler.{SparkListenerEnvironmentUpdate, SparkListenerEvent, SparkListenerJobStart, SparkListenerLogStart, StageInfo}
import org.apache.spark.sql.execution.SparkPlanInfo
-import org.apache.spark.sql.execution.ui.{SparkPlanGraph, SparkPlanGraphNode}
+import org.apache.spark.sql.execution.ui.SparkPlanGraphNode
import org.apache.spark.sql.rapids.tool.qualification.MLFunctions
-import org.apache.spark.sql.rapids.tool.util.{EventUtils, RapidsToolsConfUtil}
+import org.apache.spark.sql.rapids.tool.util.{EventUtils, RapidsToolsConfUtil, ToolsPlanGraph}
import org.apache.spark.util.Utils

// Handles updating and caching Spark Properties for a Spark application.
@@ -340,7 +340,7 @@ abstract class AppBase(
protected def checkMetadataForReadSchema(sqlID: Long, planInfo: SparkPlanInfo): Unit = {
// check if planInfo has ReadSchema
val allMetaWithSchema = getPlanMetaWithSchema(planInfo)
-    val planGraph = SparkPlanGraph(planInfo)
+    val planGraph = ToolsPlanGraph(planInfo)
val allNodes = planGraph.allNodes

allMetaWithSchema.foreach { plan =>
@@ -365,7 +365,7 @@
if (hiveEnabled) { // only scan for hive when the CatalogImplementation is using hive
val allPlanWithHiveScan = getPlanInfoWithHiveScan(planInfo)
allPlanWithHiveScan.foreach { hiveReadPlan =>
-        val sqlGraph = SparkPlanGraph(hiveReadPlan)
+        val sqlGraph = ToolsPlanGraph(hiveReadPlan)
val hiveScanNode = sqlGraph.allNodes.head
val scanHiveMeta = HiveParseHelper.parseReadNode(hiveScanNode)
dataSourceInfo += DataSourceCase(sqlID,
@@ -28,8 +28,8 @@ import org.apache.spark.internal.Logging
import org.apache.spark.scheduler._
import org.apache.spark.sql.execution.SparkPlanInfo
import org.apache.spark.sql.execution.metric.SQLMetricInfo
-import org.apache.spark.sql.execution.ui.SparkPlanGraph
import org.apache.spark.sql.rapids.tool.{AppBase, ToolUtils}
+import org.apache.spark.sql.rapids.tool.util.ToolsPlanGraph
import org.apache.spark.ui.UIUtils


@@ -238,7 +238,7 @@ class ApplicationInfo(
// Connects Operators to Stages using AccumulatorIDs
def connectOperatorToStage(): Unit = {
for ((sqlId, planInfo) <- sqlPlans) {
-      val planGraph = SparkPlanGraph(planInfo)
+      val planGraph = ToolsPlanGraph(planInfo)
// Maps stages to operators by checking for non-zero intersection
// between nodeMetrics and stageAccumulateIDs
val nodeIdToStage = planGraph.allNodes.map { node =>
@@ -256,7 +256,7 @@
connectOperatorToStage()
for ((sqlID, planInfo) <- sqlPlans) {
checkMetadataForReadSchema(sqlID, planInfo)
-      val planGraph = SparkPlanGraph(planInfo)
+      val planGraph = ToolsPlanGraph(planInfo)
// SQLPlanMetric is a case Class of
// (name: String,accumulatorId: Long,metricType: String)
val allnodes = planGraph.allNodes
@@ -339,7 +339,7 @@
v.contains(s)
}.keys.toSeq
val nodeNames = sqlPlans.get(j.sqlID.get).map { planInfo =>
-      val nodes = SparkPlanGraph(planInfo).allNodes
+      val nodes = ToolsPlanGraph(planInfo).allNodes
val validNodes = nodes.filter { n =>
nodeIds.contains((j.sqlID.get, n.id))
}
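The connectOperatorToStage hunk above describes the core mapping idea: an operator (plan-graph node) is attributed to a stage when the node's SQL metric accumulator IDs overlap the accumulables recorded for that stage. A minimal sketch of that intersection, where the stageToAccumIds map and the helper name are illustrative assumptions rather than code from this diff:

import org.apache.spark.sql.execution.ui.SparkPlanGraphNode

object OperatorStageMapping {
  // stageToAccumIds: for each stage ID, the accumulator IDs updated by its tasks.
  def stagesForNode(
      node: SparkPlanGraphNode,
      stageToAccumIds: Map[Int, Set[Long]]): Set[Int] = {
    // Accumulator IDs registered on the node's SQL metrics.
    val nodeAccumIds = node.metrics.map(_.accumulatorId).toSet
    // A non-empty intersection means tasks of that stage updated this operator's metrics.
    stageToAccumIds.collect {
      case (stageId, accumIds) if accumIds.intersect(nodeAccumIds).nonEmpty => stageId
    }.toSet
  }
}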
@@ -28,8 +28,8 @@ import org.apache.hadoop.conf.Configuration
import org.apache.spark.internal.Logging
import org.apache.spark.scheduler.{SparkListener, SparkListenerEnvironmentUpdate, SparkListenerEvent, SparkListenerJobStart}
import org.apache.spark.sql.execution.SparkPlanInfo
-import org.apache.spark.sql.execution.ui.SparkPlanGraph
import org.apache.spark.sql.rapids.tool.{AppBase, GpuEventLogException, SupportedMLFuncsName, ToolUtils}
+import org.apache.spark.sql.rapids.tool.util.ToolsPlanGraph

class QualificationAppInfo(
eventLogInfo: Option[EventLogInfo],
@@ -765,7 +765,7 @@

private[qualification] def processSQLPlan(sqlID: Long, planInfo: SparkPlanInfo): Unit = {
checkMetadataForReadSchema(sqlID, planInfo)
-    val planGraph = SparkPlanGraph(planInfo)
+    val planGraph = ToolsPlanGraph(planInfo)
val allnodes = planGraph.allNodes
for (node <- allnodes) {
checkGraphNodeForReads(sqlID, node)
@@ -0,0 +1,125 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.rapids.tool.util

import java.util.concurrent.atomic.AtomicLong

import scala.collection.mutable

import org.apache.spark.sql.execution.SparkPlanInfo
import org.apache.spark.sql.execution.ui.{SparkPlanGraph, SparkPlanGraphCluster, SparkPlanGraphEdge, SparkPlanGraphNode, SQLPlanMetric}

object ToolsPlanGraph {
  /**
   * Build a SparkPlanGraph from the root of a SparkPlan tree.
   */
  def apply(planInfo: SparkPlanInfo): SparkPlanGraph = {
    val nodeIdGenerator = new AtomicLong(0)
    val nodes = mutable.ArrayBuffer[SparkPlanGraphNode]()
    val edges = mutable.ArrayBuffer[SparkPlanGraphEdge]()
    val exchanges = mutable.HashMap[SparkPlanInfo, SparkPlanGraphNode]()
    buildSparkPlanGraphNode(planInfo, nodeIdGenerator, nodes, edges, null, null, exchanges)
    new SparkPlanGraph(nodes, edges)
  }

  private def processPlanInfo(nodeName: String): String = {
    if (nodeName.startsWith("Gpu")) {
      nodeName.replaceFirst("Gpu", "")
    } else {
      nodeName
    }
  }

  private def buildSparkPlanGraphNode(
      planInfo: SparkPlanInfo,
      nodeIdGenerator: AtomicLong,
      nodes: mutable.ArrayBuffer[SparkPlanGraphNode],
      edges: mutable.ArrayBuffer[SparkPlanGraphEdge],
      parent: SparkPlanGraphNode,
      subgraph: SparkPlanGraphCluster,
      exchanges: mutable.HashMap[SparkPlanInfo, SparkPlanGraphNode]): Unit = {
    processPlanInfo(planInfo.nodeName) match {
      case name if name.startsWith("WholeStageCodegen") =>
        val metrics = planInfo.metrics.map { metric =>
          SQLPlanMetric(metric.name, metric.accumulatorId, metric.metricType)
        }

        val cluster = new SparkPlanGraphCluster(
          nodeIdGenerator.getAndIncrement(),
          planInfo.nodeName,
          planInfo.simpleString,
          mutable.ArrayBuffer[SparkPlanGraphNode](),
          metrics)
        nodes += cluster

        buildSparkPlanGraphNode(
          planInfo.children.head, nodeIdGenerator, nodes, edges, parent, cluster, exchanges)
      case "InputAdapter" =>
        buildSparkPlanGraphNode(
          planInfo.children.head, nodeIdGenerator, nodes, edges, parent, null, exchanges)
      case "BroadcastQueryStage" | "ShuffleQueryStage" =>
        if (exchanges.contains(planInfo.children.head)) {
          // Point to the re-used exchange
          val node = exchanges(planInfo.children.head)
          edges += SparkPlanGraphEdge(node.id, parent.id)
        } else {
          buildSparkPlanGraphNode(
            planInfo.children.head, nodeIdGenerator, nodes, edges, parent, null, exchanges)
        }
      case "TableCacheQueryStage" =>
        buildSparkPlanGraphNode(
          planInfo.children.head, nodeIdGenerator, nodes, edges, parent, null, exchanges)
      case "Subquery" | "SubqueryBroadcast" if subgraph != null =>
        // Subquery should not be included in WholeStageCodegen
        buildSparkPlanGraphNode(planInfo, nodeIdGenerator, nodes, edges, parent, null, exchanges)
      case "Subquery" | "SubqueryBroadcast" if exchanges.contains(planInfo) =>
        // Point to the re-used subquery
        val node = exchanges(planInfo)
        edges += SparkPlanGraphEdge(node.id, parent.id)
      case "ReusedSubquery" =>
        // Re-used subquery might appear before the original subquery, so skip this node and let
        // the previous `case` make sure the re-used and the original point to the same node.
        buildSparkPlanGraphNode(
          planInfo.children.head, nodeIdGenerator, nodes, edges, parent, subgraph, exchanges)
      case "ReusedExchange" if exchanges.contains(planInfo.children.head) =>
        // Point to the re-used exchange
        val node = exchanges(planInfo.children.head)
        edges += SparkPlanGraphEdge(node.id, parent.id)
      case name =>
        val metrics = planInfo.metrics.map { metric =>
          SQLPlanMetric(metric.name, metric.accumulatorId, metric.metricType)
        }
        val node = new SparkPlanGraphNode(
          nodeIdGenerator.getAndIncrement(), planInfo.nodeName,
          planInfo.simpleString, metrics)
        if (subgraph == null) {
          nodes += node
        } else {
          subgraph.nodes += node
        }
        if (name.contains("Exchange") || name.contains("Subquery")) {
          exchanges += planInfo -> node
        }

        if (parent != null) {
          edges += SparkPlanGraphEdge(node.id, parent.id)
        }
        planInfo.children.foreach(
          buildSparkPlanGraphNode(_, nodeIdGenerator, nodes, edges, node, subgraph, exchanges))
    }
  }
}
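The object above closely follows the plan-graph construction in Spark's SparkPlanGraph; the tools-specific piece is processPlanInfo, which strips a leading "Gpu" prefix before pattern matching so that GPU node names reach the same special handling (WholeStageCodegen clusters, reused exchanges and subqueries) as their CPU counterparts. A minimal usage sketch, assuming a SparkPlanInfo has already been recovered from an event log; the helper name and the counting predicate are illustrative, not part of this PR:

import org.apache.spark.sql.execution.SparkPlanInfo
import org.apache.spark.sql.rapids.tool.util.ToolsPlanGraph

object ToolsPlanGraphExample {
  // `planInfo` would typically be parsed from a SparkListenerSQLExecutionStart
  // event in an event log; how it is obtained is outside the scope of this sketch.
  def countParquetScans(planInfo: SparkPlanInfo): Int = {
    val graph = ToolsPlanGraph(planInfo)
    // allNodes flattens WholeStageCodegen clusters; subtrees referenced only through
    // ReusedExchange or ReusedSubquery become edges back to the original nodes rather
    // than duplicate nodes, so each physical scan is counted once.
    graph.allNodes.count(_.name.contains("Scan parquet"))
  }
}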
Binary file not shown.
@@ -847,7 +847,7 @@ class SQLPlanParserSuite extends BaseTestSuite {
}
}

test("get_json_object is supported in Project") {
test("get_json_object is unsupported in Project") {
// get_json_object is disabled by default in the RAPIDS plugin
TrampolineUtil.withTempDir { parquetoutputLoc =>
TrampolineUtil.withTempDir { eventLogDir =>
@@ -895,4 +895,14 @@ class ApplicationInfoSuite extends FunSuite with Logging {
}
}
}

test("test gpu reused subquery") {
val apps = ToolTestUtils.processProfileApps(Array(s"$logDir/nds_q66_gpu.zstd"), sparkSession)
val collect = new CollectInformation(apps)
val sqlToStageInfo = collect.getSQLToStage
val countScanParquet = sqlToStageInfo.flatMap(_.nodeNames).count(_.contains("GpuScan parquet"))
// There are 12 `GpuScan parquet` raw nodes, but 4 are inside `ReusedExchange`, 1 is inside
// `ReusedSubquery`, so we expect 7.
assert(countScanParquet == 7)
}
}
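As a quick sanity check of the arithmetic in the new test's comment (the counts are taken from that comment; the snippet is an illustrative REPL-style sketch, not code from this PR):

// Raw `GpuScan parquet` occurrences in the nds_q66 GPU plan, per the test comment.
val rawGpuScans = 12
// Scans that appear only under a ReusedExchange: the graph adds an edge back to the
// original exchange instead of duplicating its subtree.
val scansUnderReusedExchange = 4
// Scans that appear only under a ReusedSubquery, now de-duplicated the same way.
val scansUnderReusedSubquery = 1
val expectedScans = rawGpuScans - scansUnderReusedExchange - scansUnderReusedSubquery
assert(expectedScans == 7) // matches the assertion in the test above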