
refactor due to new optimizations from dev
Signed-off-by: cindyyuanjiang <[email protected]>
cindyyuanjiang committed Dec 19, 2024
1 parent f8346dd commit 550406c
Showing 1 changed file with 29 additions and 22 deletions.
@@ -27,7 +27,7 @@ import org.apache.spark.sql.execution.ui.{SparkPlanGraph, SparkPlanGraphCluster,
 import org.apache.spark.sql.rapids.tool.{AppBase, RDDCheckHelper, SqlPlanInfoGraphBuffer, SqlPlanInfoGraphEntry}
 import org.apache.spark.sql.rapids.tool.profiling.ApplicationInfo
 import org.apache.spark.sql.rapids.tool.qualification.QualificationAppInfo
-import org.apache.spark.sql.rapids.tool.store.{AccumInfo, AccumMetaRef, AccumNameRef, DataSourceRecord}
+import org.apache.spark.sql.rapids.tool.store.{AccumInfo, AccumMetaRef, DataSourceRecord}
 import org.apache.spark.sql.rapids.tool.util.ToolsPlanGraph

 /**
@@ -110,13 +110,19 @@ class AppSQLPlanAnalyzer(app: AppBase, appIndex: Int) extends AppAnalysisBase(ap
   }

   /**
-   * Retrieves the set of task IDs associated with a specific stage.
+   * Retrieves task update values from the accumulator info for the specified stage ID.
    *
-   * @param stageId The ID of the stage.
-   * @return A set of task IDs corresponding to the given stage ID.
+   * @param accumInfo AccumInfo object containing the task updates map.
+   * @param stageId The stage ID for which task updates need to be retrieved.
+   * @return A sorted sequence of task update values (`Long`) corresponding to the tasks
+   *         in the specified stage.
    */
-  private def getStageTaskIds(stageId: Int): Set[Long] = {
-    app.taskManager.getAllTasksStageAttempt(stageId).map(_.taskId).toSet
+  private def filterAccumTaskUpdatesForStage(accumInfo: AccumInfo, stageId: Int): Seq[Long] = {
+    val stageTaskIds = app.taskManager.getAllTasksStageAttempt(stageId).map(_.taskId).toSet
+    stageTaskIds.toSeq.collect {
+      case taskId if accumInfo.taskUpdatesMap.contains(taskId) =>
+        accumInfo.taskUpdatesMap(taskId)
+    }.toSeq.sorted
   }

   /**
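The new helper replaces a plain task-ID lookup with a filter-and-sort over the accumulator's per-task updates. For context, a minimal standalone sketch of that behavior, using plain `Set`/`Map` stand-ins for the tool's `taskManager` and `AccumInfo` (hypothetical names, not the project's classes):

```scala
object FilterTaskUpdatesSketch {
  // Same shape as filterAccumTaskUpdatesForStage: keep only the updates reported
  // by tasks that ran in the given stage, then sort the values ascending.
  def filterTaskUpdatesForStage(
      stageTaskIds: Set[Long],
      taskUpdatesMap: Map[Long, Long]): Seq[Long] = {
    stageTaskIds.toSeq.collect {
      case taskId if taskUpdatesMap.contains(taskId) => taskUpdatesMap(taskId)
    }.sorted
  }

  def main(args: Array[String]): Unit = {
    val stageTaskIds = Set(1L, 2L, 3L)
    // Task 9 belongs to another stage, so its update is dropped.
    val taskUpdatesMap = Map(1L -> 40L, 3L -> 10L, 9L -> 99L)
    println(filterTaskUpdatesForStage(stageTaskIds, taskUpdatesMap)) // List(10, 40)
  }
}
```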
@@ -380,7 +386,7 @@ class AppSQLPlanAnalyzer(app: AppBase, appIndex: Int) extends AppAnalysisBase(ap
   }

   /**
-   * Generate IO-related diagnostic metrics for the SQL plan. Metrics include:
+   * Generates IO-related diagnostic metrics for the SQL plan. Metrics include:
    * - Output rows
    * - Scan time
    * - Output batches
@@ -389,29 +395,32 @@ class AppSQLPlanAnalyzer(app: AppBase, appIndex: Int) extends AppAnalysisBase(ap
    * - Fetch wait time
    * - GPU decode time
    *
-   * @return A sequence of `IODiagnosticResult` objects containing diagnostic metrics.
+   * This method processes accumulator information for each SQL stage and node and
+   * computes statistical results (min, median, max, sum) for IO-related metrics.
+   *
+   * @return A sequence of `IODiagnosticResult` objects, one per SQL stage and node.
    */
   def generateIODiagnosticAccums(): Seq[IODiagnosticResult] = {
+    val emptyAccumInfo = new AccumInfo(AccumMetaRef.EMPTY_ACCUM_META_REF)
     // Transform the diagnostic metrics map into a sequence of results
     IODiagnosticMetricsMap.toSeq.flatMap { case ((sqlId, nodeId, stageIds), sqlAccums) =>
       // Process each stage ID and compute diagnostic results
       stageIds.split(",").filter(_.nonEmpty).map(_.toInt).flatMap { stageId =>
         val nodeName = sqlAccums.head.nodeName
-        val stageTaskIds = getStageTaskIds(stageId)
-        // A mapping from metric name to its statistical results (min, median, max, sum)
+
+        // Initialize a map to store statistics for each IO metric
         val metricNameToStatistics = HashMap.empty[String, StatisticsMetrics].
           withDefaultValue(StatisticsMetrics.ZERO_RECORD)

-        // Iterate through each IO metric
+        // Process each accumulator for the current SQL stage
         sqlAccums.foreach { sqlAccum =>
           val accumInfo = app.accumManager.accumInfoMap.getOrElse(
-            sqlAccum.accumulatorId,
-            new AccumInfo(AccumMetaRef(0L, AccumNameRef("")))
-          )
-          // Compute the metric's statistics (min, median, max, sum) for the given stage.
-          // Store the results if available.
-          val filteredTaskUpdates =
-            accumInfo.taskUpdatesMap.filterKeys(stageTaskIds.contains).values.toSeq.sorted
+            sqlAccum.accumulatorId, emptyAccumInfo)
+
+          // Retrieve and sort task updates corresponding to the current stage
+          val filteredTaskUpdates = filterAccumTaskUpdatesForStage(accumInfo, stageId)
+
+          // Compute the metric's statistics and store the results if available
           if (filteredTaskUpdates.nonEmpty) {
             val min = filteredTaskUpdates.head
             val max = filteredTaskUpdates.last
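Because `filteredTaskUpdates` is already sorted, the statistics named in the doc comment fall out of simple positional reads. A sketch under the assumption that the median is taken as the middle element of the sorted sequence (`Stats` is a hypothetical stand-in for the project's `StatisticsMetrics`):

```scala
object StatsSketch {
  // Hypothetical stand-in for the project's StatisticsMetrics result type.
  case class Stats(min: Long, median: Long, max: Long, sum: Long)

  def computeStats(sortedUpdates: Seq[Long]): Option[Stats] =
    if (sortedUpdates.isEmpty) None
    else Some(Stats(
      min = sortedUpdates.head,                       // smallest value
      median = sortedUpdates(sortedUpdates.size / 2), // middle element (upper median)
      max = sortedUpdates.last,                       // largest value
      sum = sortedUpdates.sum))

  def main(args: Array[String]): Unit =
    println(computeStats(Seq(10L, 40L, 99L))) // Some(Stats(10,40,99,149))
}
```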
@@ -429,7 +438,7 @@ class AppSQLPlanAnalyzer(app: AppBase, appIndex: Int) extends AppAnalysisBase(ap
         }

         if (metricNameToStatistics.isEmpty) {
-          // metricNameToStatistics is not updated - there is no IO metrics result for this stage
+          // No IO metric statistics were computed for this stage
           None
         } else {
           Some(IODiagnosticResult(
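The `isEmpty` check above works because `withDefaultValue` only affects reads: looking up a missing metric yields the zero record without inserting an entry. A small sketch of that pattern (the `MetricStats`/`ZeroRecord` names are illustrative only):

```scala
import scala.collection.mutable

object DefaultValueSketch {
  case class MetricStats(min: Long, median: Long, max: Long, sum: Long)
  // Stand-in for StatisticsMetrics.ZERO_RECORD.
  val ZeroRecord = MetricStats(0L, 0L, 0L, 0L)

  def main(args: Array[String]): Unit = {
    val metricNameToStatistics =
      mutable.HashMap.empty[String, MetricStats].withDefaultValue(ZeroRecord)

    println(metricNameToStatistics("output rows")) // MetricStats(0,0,0,0): defaulted read
    println(metricNameToStatistics.isEmpty)        // true: the read inserted nothing
  }
}
```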
@@ -467,9 +476,7 @@ class AppSQLPlanAnalyzer(app: AppBase, appIndex: Int) extends AppAnalysisBase(ap
     app.accumManager.accumInfoMap.flatMap { accumMapEntry =>
       val accumInfo = accumMapEntry._2
       accumInfo.stageValuesMap.keySet.flatMap( stageId => {
-        val stageTaskIds = getStageTaskIds(stageId)
-        val filteredTaskUpdates =
-          accumInfo.taskUpdatesMap.filterKeys(stageTaskIds.contains).values.toSeq.sorted
+        val filteredTaskUpdates = filterAccumTaskUpdatesForStage(accumInfo, stageId)

         if (filteredTaskUpdates.isEmpty) {
           None
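Both call sites now share one preallocated empty `AccumInfo` built from `AccumMetaRef.EMPTY_ACCUM_META_REF` instead of constructing a fresh default per lookup, presumably part of the upstream optimizations the commit message references. A sketch of the shared-default pattern with illustrative types:

```scala
object SharedDefaultSketch {
  // Illustrative type only; the commit uses AccumInfo with EMPTY_ACCUM_META_REF.
  case class Info(taskUpdatesMap: Map[Long, Long])
  // Allocated once, shared by every lookup that misses.
  val emptyInfo = Info(Map.empty)

  def main(args: Array[String]): Unit = {
    val accumInfoMap = Map(42L -> Info(Map(1L -> 10L)))
    val hit = accumInfoMap.getOrElse(42L, emptyInfo)  // existing entry returned
    val miss = accumInfoMap.getOrElse(7L, emptyInfo)  // shared empty default, no new allocation
    println((hit, miss))
  }
}
```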