Skip to content

Commit

Permalink
add back getStageTaskIds to avoid computing stage ids multiple times …
Browse files Browse the repository at this point in the history
…when unnecessary

Signed-off-by: cindyyuanjiang <[email protected]>
  • Loading branch information
cindyyuanjiang committed Dec 19, 2024
1 parent 550406c commit 6d0d798
Showing 1 changed file with 17 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,16 @@ class AppSQLPlanAnalyzer(app: AppBase, appIndex: Int) extends AppAnalysisBase(ap
IODiagnosticMetricsMap(key) += accum
}

/**
* Retrieves the task IDs associated with a specific stage.
*
* @param stageId The ID of the stage.
* @return A seq of task IDs corresponding to the given stage ID.
*/
private def getStageTaskIds(stageId: Int): Seq[Long] = {
app.taskManager.getAllTasksStageAttempt(stageId).map(_.taskId).toSet.toSeq
}

/**
* Retrieves task update values from the accumulator info for the specified stage ID.
*
Expand All @@ -117,9 +127,9 @@ class AppSQLPlanAnalyzer(app: AppBase, appIndex: Int) extends AppAnalysisBase(ap
* @return A sorted sequence of task update values (`Long`) corresponding to the tasks
* in the specified stage.
*/
private def filterAccumTaskUpdatesForStage(accumInfo: AccumInfo, stageId: Int): Seq[Long] = {
val stageTaskIds = app.taskManager.getAllTasksStageAttempt(stageId).map(_.taskId).toSet
stageTaskIds.toSeq.collect {
private def filterAccumTaskUpdatesForStage(accumInfo: AccumInfo, stageTaskIds: Seq[Long])
: Seq[Long] = {
stageTaskIds.collect {
case taskId if accumInfo.taskUpdatesMap.contains(taskId) =>
accumInfo.taskUpdatesMap(taskId)
}.toSeq.sorted
Expand Down Expand Up @@ -406,6 +416,7 @@ class AppSQLPlanAnalyzer(app: AppBase, appIndex: Int) extends AppAnalysisBase(ap
IODiagnosticMetricsMap.toSeq.flatMap { case ((sqlId, nodeId, stageIds), sqlAccums) =>
// Process each stage ID and compute diagnostic results
stageIds.split(",").filter(_.nonEmpty).map(_.toInt).flatMap { stageId =>
val stageTaskIds = getStageTaskIds(stageId)
val nodeName = sqlAccums.head.nodeName

// Initialize a map to store statistics for each IO metric
Expand All @@ -418,7 +429,7 @@ class AppSQLPlanAnalyzer(app: AppBase, appIndex: Int) extends AppAnalysisBase(ap
sqlAccum.accumulatorId, emptyAccumInfo)

// Retrieve and sort task updates correspond to the current stage
val filteredTaskUpdates = filterAccumTaskUpdatesForStage(accumInfo, stageId)
val filteredTaskUpdates = filterAccumTaskUpdatesForStage(accumInfo, stageTaskIds)

// Compute the metric's statistics and store the results if available
if (filteredTaskUpdates.nonEmpty) {
Expand Down Expand Up @@ -476,7 +487,8 @@ class AppSQLPlanAnalyzer(app: AppBase, appIndex: Int) extends AppAnalysisBase(ap
app.accumManager.accumInfoMap.flatMap { accumMapEntry =>
val accumInfo = accumMapEntry._2
accumInfo.stageValuesMap.keySet.flatMap( stageId => {
val filteredTaskUpdates = filterAccumTaskUpdatesForStage(accumInfo, stageId)
val filteredTaskUpdates =
filterAccumTaskUpdatesForStage(accumInfo, getStageTaskIds(stageId))

if (filteredTaskUpdates.isEmpty) {
None
Expand Down

0 comments on commit 6d0d798

Please sign in to comment.