diff --git a/core/src/main/scala/com/nvidia/spark/rapids/tool/analysis/AppSQLPlanAnalyzer.scala b/core/src/main/scala/com/nvidia/spark/rapids/tool/analysis/AppSQLPlanAnalyzer.scala index 78f29adae..74b793fa7 100644 --- a/core/src/main/scala/com/nvidia/spark/rapids/tool/analysis/AppSQLPlanAnalyzer.scala +++ b/core/src/main/scala/com/nvidia/spark/rapids/tool/analysis/AppSQLPlanAnalyzer.scala @@ -109,6 +109,16 @@ class AppSQLPlanAnalyzer(app: AppBase, appIndex: Int) extends AppAnalysisBase(ap IODiagnosticMetricsMap(key) += accum } + /** + * Retrieves the task IDs associated with a specific stage. + * + * @param stageId The ID of the stage. + * @return A seq of task IDs corresponding to the given stage ID. + */ + private def getStageTaskIds(stageId: Int): Seq[Long] = { + app.taskManager.getAllTasksStageAttempt(stageId).map(_.taskId).toSet.toSeq + } + /** * Retrieves task update values from the accumulator info for the specified stage ID. * @@ -117,9 +127,9 @@ class AppSQLPlanAnalyzer(app: AppBase, appIndex: Int) extends AppAnalysisBase(ap * @return A sorted sequence of task update values (`Long`) corresponding to the tasks * in the specified stage. */ - private def filterAccumTaskUpdatesForStage(accumInfo: AccumInfo, stageId: Int): Seq[Long] = { - val stageTaskIds = app.taskManager.getAllTasksStageAttempt(stageId).map(_.taskId).toSet - stageTaskIds.toSeq.collect { + private def filterAccumTaskUpdatesForStage(accumInfo: AccumInfo, stageTaskIds: Seq[Long]) + : Seq[Long] = { + stageTaskIds.collect { case taskId if accumInfo.taskUpdatesMap.contains(taskId) => accumInfo.taskUpdatesMap(taskId) }.toSeq.sorted @@ -406,6 +416,7 @@ class AppSQLPlanAnalyzer(app: AppBase, appIndex: Int) extends AppAnalysisBase(ap IODiagnosticMetricsMap.toSeq.flatMap { case ((sqlId, nodeId, stageIds), sqlAccums) => // Process each stage ID and compute diagnostic results stageIds.split(",").filter(_.nonEmpty).map(_.toInt).flatMap { stageId => + val stageTaskIds = getStageTaskIds(stageId) val nodeName = sqlAccums.head.nodeName // Initialize a map to store statistics for each IO metric @@ -418,7 +429,7 @@ class AppSQLPlanAnalyzer(app: AppBase, appIndex: Int) extends AppAnalysisBase(ap sqlAccum.accumulatorId, emptyAccumInfo) // Retrieve and sort task updates correspond to the current stage - val filteredTaskUpdates = filterAccumTaskUpdatesForStage(accumInfo, stageId) + val filteredTaskUpdates = filterAccumTaskUpdatesForStage(accumInfo, stageTaskIds) // Compute the metric's statistics and store the results if available if (filteredTaskUpdates.nonEmpty) { @@ -476,7 +487,8 @@ class AppSQLPlanAnalyzer(app: AppBase, appIndex: Int) extends AppAnalysisBase(ap app.accumManager.accumInfoMap.flatMap { accumMapEntry => val accumInfo = accumMapEntry._2 accumInfo.stageValuesMap.keySet.flatMap( stageId => { - val filteredTaskUpdates = filterAccumTaskUpdatesForStage(accumInfo, stageId) + val filteredTaskUpdates = + filterAccumTaskUpdatesForStage(accumInfo, getStageTaskIds(stageId)) if (filteredTaskUpdates.isEmpty) { None