diff --git a/core/src/main/scala/com/nvidia/spark/rapids/tool/planparser/SQLPlanParser.scala b/core/src/main/scala/com/nvidia/spark/rapids/tool/planparser/SQLPlanParser.scala
index 560431c4f..fc438be3e 100644
--- a/core/src/main/scala/com/nvidia/spark/rapids/tool/planparser/SQLPlanParser.scala
+++ b/core/src/main/scala/com/nvidia/spark/rapids/tool/planparser/SQLPlanParser.scala
@@ -443,10 +443,7 @@ object SQLPlanParser extends Logging {
 
   def getStagesInSQLNode(node: SparkPlanGraphNode, app: AppBase): Set[Int] = {
     val nodeAccums = node.metrics.map(_.accumulatorId)
-    nodeAccums.flatMap { nodeAccumId =>
-      // val res = app.accumManager.getAccStageIds(nodeAccumId)
-      app.stageManager.getStagesIdsByAccumId(nodeAccumId)
-    }.toSet
+    nodeAccums.flatMap(app.accumManager.getAccStageIds).toSet
   }
 
   // Set containing execs that refers to other expressions. We need this to be a list to allow
diff --git a/core/src/main/scala/com/nvidia/spark/rapids/tool/profiling/CompareApplications.scala b/core/src/main/scala/com/nvidia/spark/rapids/tool/profiling/CompareApplications.scala
index 240114678..c67c5e8e3 100644
--- a/core/src/main/scala/com/nvidia/spark/rapids/tool/profiling/CompareApplications.scala
+++ b/core/src/main/scala/com/nvidia/spark/rapids/tool/profiling/CompareApplications.scala
@@ -33,7 +33,7 @@ class CompareApplications(apps: Seq[ApplicationInfo]) extends Logging {
     val normalizedByAppId = apps.map { app =>
       val normalized = app.sqlPlans.mapValues { plan =>
         SparkPlanInfoWithStage(plan,
-          app.stageManager.getAccumToSingleStage()).normalizeForStageComparison
+          app.accumManager.getAccumSingleStage).normalizeForStageComparison
       }
       (app.appId, normalized)
     }.toMap
diff --git a/core/src/main/scala/com/nvidia/spark/rapids/tool/profiling/GenerateDot.scala b/core/src/main/scala/com/nvidia/spark/rapids/tool/profiling/GenerateDot.scala
index fc1963dfe..850290464 100644
--- a/core/src/main/scala/com/nvidia/spark/rapids/tool/profiling/GenerateDot.scala
+++ b/core/src/main/scala/com/nvidia/spark/rapids/tool/profiling/GenerateDot.scala
@@ -88,7 +88,7 @@ object GenerateDot {
     val accumSummary = accums.map { a =>
       Seq(a.sqlID, a.accumulatorId, a.total)
     }
-    val accumIdToStageId = app.stageManager.getAccumToSingleStage()
+    val accumIdToStageId = app.accumManager.getAccumSingleStage
     val formatter = java.text.NumberFormat.getIntegerInstance
     val stageIdToStageMetrics = app.taskManager.stageAttemptToTasks.collect { case (stageId, _) =>
       val tasks = app.taskManager.getAllTasksStageAttempt(stageId)
diff --git a/core/src/main/scala/org/apache/spark/sql/rapids/tool/AppBase.scala b/core/src/main/scala/org/apache/spark/sql/rapids/tool/AppBase.scala
index 23192f93d..b7227c05a 100644
--- a/core/src/main/scala/org/apache/spark/sql/rapids/tool/AppBase.scala
+++ b/core/src/main/scala/org/apache/spark/sql/rapids/tool/AppBase.scala
@@ -181,7 +181,6 @@ abstract class AppBase(
   def cleanupAccumId(accId: Long): Unit = {
     accumManager.removeAccumInfo(accId)
     driverAccumMap.remove(accId)
-    stageManager.removeAccumulatorId(accId)
   }
 
   def cleanupStages(stageIds: Set[Int]): Unit = {
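
The four hunks above are the caller-side half of the refactor: every accumulator-to-stage lookup (and cleanup) now goes through accumManager instead of stageManager. A minimal, self-contained sketch of the new lookup shape; AccumStore, AccumRecord, and the sample data are hypothetical stand-ins for the tool's AccumManager/AccumInfo:

    object AccumLookupSketch {
      // Each accumulator keeps the per-stage values it has been observed in.
      final case class AccumRecord(id: Long, stageValues: Map[Int, Long]) {
        def stageIds: Set[Int] = stageValues.keySet
      }

      final class AccumStore(records: Map[Long, AccumRecord]) {
        // 1-to-N lookup: every stage an accumulator reported in. This is the
        // role that previously lived in the stage manager's accum map.
        def getAccStageIds(id: Long): Set[Int] =
          records.get(id).map(_.stageIds).getOrElse(Set.empty)
      }

      def main(args: Array[String]): Unit = {
        val store = new AccumStore(Map(1L -> AccumRecord(1L, Map(0 -> 10L, 2 -> 5L))))
        val nodeAccums = Seq(1L, 99L) // accumulator ids attached to one plan node
        // Same shape as the new getStagesInSQLNode body: flatMap + toSet.
        println(nodeAccums.flatMap(store.getAccStageIds).toSet) // Set(0, 2)
      }
    }
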
diff --git a/core/src/main/scala/org/apache/spark/sql/rapids/tool/EventProcessorBase.scala b/core/src/main/scala/org/apache/spark/sql/rapids/tool/EventProcessorBase.scala
index 38c4059e2..51beaa690 100644
--- a/core/src/main/scala/org/apache/spark/sql/rapids/tool/EventProcessorBase.scala
+++ b/core/src/main/scala/org/apache/spark/sql/rapids/tool/EventProcessorBase.scala
@@ -349,9 +349,6 @@ abstract class EventProcessorBase[T <: AppBase](app: T) extends SparkListener wi
       app: T,
       event: SparkListenerTaskEnd): Unit = {
     // TODO: this implementation needs to be updated to use attemptID
-    // Update the map between accumulators and stages
-    app.stageManager.addAccumIdToStage(
-      event.stageId, event.taskInfo.accumulables.map(_.id))
     // Parse task accumulables
     for (res <- event.taskInfo.accumulables) {
       try {
diff --git a/core/src/main/scala/org/apache/spark/sql/rapids/tool/store/AccMetaRef.scala b/core/src/main/scala/org/apache/spark/sql/rapids/tool/store/AccMetaRef.scala
index 3deb457de..548229c8d 100644
--- a/core/src/main/scala/org/apache/spark/sql/rapids/tool/store/AccMetaRef.scala
+++ b/core/src/main/scala/org/apache/spark/sql/rapids/tool/store/AccMetaRef.scala
@@ -16,9 +16,7 @@
 
 package org.apache.spark.sql.rapids.tool.store
 
-case class AccMetaRef(id: Long, name: AccNameRef) {
-
-}
+case class AccMetaRef(id: Long, name: AccNameRef)
 
 object AccMetaRef {
   def apply(id: Long, name: Option[String]): AccMetaRef =
diff --git a/core/src/main/scala/org/apache/spark/sql/rapids/tool/store/AccNameRef.scala b/core/src/main/scala/org/apache/spark/sql/rapids/tool/store/AccNameRef.scala
index b55053d63..c40107f41 100644
--- a/core/src/main/scala/org/apache/spark/sql/rapids/tool/store/AccNameRef.scala
+++ b/core/src/main/scala/org/apache/spark/sql/rapids/tool/store/AccNameRef.scala
@@ -20,9 +20,7 @@ import java.util.concurrent.ConcurrentHashMap
 
 import org.apache.spark.sql.rapids.tool.util.EventUtils.normalizeMetricName
 
-case class AccNameRef(value: String) {
-
-}
+case class AccNameRef(value: String)
 
 object AccNameRef {
   val EMPTY_ACC_NAME_REF: AccNameRef = new AccNameRef("N/A")
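
With the addAccumIdToStage bookkeeping removed from the task-end handler (and the empty-bodied case classes trimmed to plain one-liners), stage membership falls out of the accumulator store itself instead of being mirrored in a second map. A rough sketch of that idea, assuming hypothetical AccumSketchStore/TaskAccumulable types rather than the tool's real ones:

    final case class TaskAccumulable(id: Long, update: Option[Long])

    final class AccumSketchStore {
      // One entry per (accumId, stageId) pair; recording an update implicitly
      // records stage membership, so there is nothing extra to keep in sync
      // or to forget to clean up.
      private val perStage = scala.collection.mutable.Map.empty[(Long, Int), Long]

      def addAccToTask(stageId: Int, acc: TaskAccumulable): Unit = {
        val key = (acc.id, stageId)
        perStage(key) = perStage.getOrElse(key, 0L) + acc.update.getOrElse(0L)
      }

      def stagesFor(accumId: Long): Set[Int] =
        perStage.keysIterator.collect { case (`accumId`, stage) => stage }.toSet
    }
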
diff --git a/core/src/main/scala/org/apache/spark/sql/rapids/tool/store/AccumInfo.scala b/core/src/main/scala/org/apache/spark/sql/rapids/tool/store/AccumInfo.scala
index 4c6fb3cf7..fa46f2206 100644
--- a/core/src/main/scala/org/apache/spark/sql/rapids/tool/store/AccumInfo.scala
+++ b/core/src/main/scala/org/apache/spark/sql/rapids/tool/store/AccumInfo.scala
@@ -34,12 +34,20 @@ class AccumInfo(val infoRef: AccMetaRef) {
       accumulableInfo: AccumulableInfo,
       update: Option[Long] = None): Unit = {
     val value = accumulableInfo.value.flatMap(parseAccumFieldToLong)
+    val existingValue = stageValuesMap.getOrElse(stageId, 0L)
     value match {
       case Some(v) =>
+        // This assert guards against processing out-of-order events
+        assert(v >= existingValue,
+          s"Stage $stageId: Out of order events detected.")
         stageValuesMap.put(stageId, v)
       case _ =>
-        // this could be the case when a task update has triggered the stage update
-        stageValuesMap.put(stageId, update.getOrElse(0L))
+        val incomingUpdate = update.getOrElse(0L)
+        assert(incomingUpdate >= existingValue,
+          s"Stage $stageId: Out of order events detected.")
+        // This case covers metric values that could not be parsed as Long;
+        // fall back to the task update so the accum-to-stage mapping is still recorded.
+        stageValuesMap.put(stageId, incomingUpdate)
     }
   }
 
@@ -48,17 +56,19 @@ class AccumInfo(val infoRef: AccMetaRef) {
     // we have to update the stageMap if the stageId does not exist in the map
     var updateStageFlag = !stageValuesMap.contains(stageId)
     // TODO: Task can update an accum multiple times. Should account for that case.
+    // Accumulate the update because the same task can update the same accum multiple times
+    val existingUpdate = taskUpdatesMap.getOrElse(taskId, 0L)
     update match {
       case Some(v) =>
-        taskUpdatesMap.put(taskId, v)
+        taskUpdatesMap.put(taskId, v + existingUpdate)
         // update teh stage if the task's update is non-zero
         updateStageFlag ||= v != 0
       case None =>
-        taskUpdatesMap.put(taskId, 0L)
+        taskUpdatesMap.put(taskId, existingUpdate)
     }
     // update the stage value map if necessary
     if (updateStageFlag) {
-      addAccToStage(stageId, accumulableInfo, update)
+      addAccToStage(stageId, accumulableInfo, update.map(_ + existingUpdate))
     }
   }
 
@@ -66,6 +76,10 @@ class AccumInfo(val infoRef: AccMetaRef) {
     stageValuesMap.keySet.toSet
   }
 
+  def getMinStageId: Int = {
+    stageValuesMap.keys.min
+  }
+
   def calculateAccStats(): StatisticsMetrics = {
     val sortedTaskUpdates = taskUpdatesMap.values.toSeq.sorted
     if (sortedTaskUpdates.isEmpty) {
diff --git a/core/src/main/scala/org/apache/spark/sql/rapids/tool/store/AccumManager.scala b/core/src/main/scala/org/apache/spark/sql/rapids/tool/store/AccumManager.scala
index 3e1a21a78..66514e56a 100644
--- a/core/src/main/scala/org/apache/spark/sql/rapids/tool/store/AccumManager.scala
+++ b/core/src/main/scala/org/apache/spark/sql/rapids/tool/store/AccumManager.scala
@@ -16,7 +16,7 @@
 
 package org.apache.spark.sql.rapids.tool.store
 
-import scala.collection.mutable
+import scala.collection.{mutable, Map}
 
 import com.nvidia.spark.rapids.tool.analysis.StatisticsMetrics
 
@@ -46,6 +46,12 @@ class AccumManager {
     accumInfoMap.get(id).map(_.getStageIds).getOrElse(Set.empty)
   }
 
+  def getAccumSingleStage: Map[Long, Int] = {
+    accumInfoMap.map { case (id, accInfo) =>
+      (id, accInfo.getMinStageId)
+    }.toMap
+  }
+
   def removeAccumInfo(id: Long): Option[AccumInfo] = {
     accumInfoMap.remove(id)
   }
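
The AccumInfo change above switches addAccToTask from overwrite to accumulate, so a task that updates the same accumulator several times no longer loses its earlier updates. A standalone sketch of just that semantics; TaskUpdatesSketch is a made-up name, and the real class additionally folds the running sum into its stage-level map:

    object TaskUpdatesSketch {
      private val taskUpdatesMap = scala.collection.mutable.Map.empty[Long, Long]

      def addUpdate(taskId: Long, update: Option[Long]): Unit = {
        val existingUpdate = taskUpdatesMap.getOrElse(taskId, 0L)
        // Mirrors the new AccumInfo behavior: add to the previous value
        // instead of replacing it.
        taskUpdatesMap(taskId) = existingUpdate + update.getOrElse(0L)
      }

      def main(args: Array[String]): Unit = {
        addUpdate(taskId = 7L, update = Some(3L))
        addUpdate(taskId = 7L, update = Some(4L))
        println(taskUpdatesMap(7L)) // 7, not 4: the two updates are summed
      }
    }
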
diff --git a/core/src/main/scala/org/apache/spark/sql/rapids/tool/store/StageModelManager.scala b/core/src/main/scala/org/apache/spark/sql/rapids/tool/store/StageModelManager.scala
index 1f982ae21..e0da8f62d 100644
--- a/core/src/main/scala/org/apache/spark/sql/rapids/tool/store/StageModelManager.scala
+++ b/core/src/main/scala/org/apache/spark/sql/rapids/tool/store/StageModelManager.scala
@@ -16,8 +16,7 @@
 
 package org.apache.spark.sql.rapids.tool.store
 
-import scala.collection.{mutable, Map}
-import scala.collection.immutable.{SortedSet, TreeSet}
+import scala.collection.mutable
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.scheduler.StageInfo
@@ -46,13 +45,6 @@ class StageModelManager extends Logging {
   private val stageIdToInfo: mutable.SortedMap[Int, mutable.SortedMap[Int, StageModel]] =
     mutable.SortedMap[Int, mutable.SortedMap[Int, StageModel]]()
 
-  // Holds the mapping between AccumulatorIDs to Stages (1-to-N)
-  // [Long: AccumId -> SortedSet[Int: StageId]]
-  // Note that we keep it as primitive type in case we receive a stageID that does not exist
-  // in the stageIdToInfo map.
-  private val accumIdToStageId: mutable.HashMap[Long, SortedSet[Int]] =
-    new mutable.HashMap[Long, SortedSet[Int]]()
-
   /**
    * Returns all StageModels that have been created as a result of handling
    * StageSubmitted/StageCompleted-events. This includes stages with multiple attempts.
@@ -118,46 +110,7 @@ class StageModelManager extends Logging {
    * @return existing or new instance of StageModel with (sInfo.stageId, sInfo.attemptID)
    */
  def addStageInfo(sInfo: StageInfo): StageModel = {
-    // Creating stageModel instance if it does not exist
     val stage = getOrCreateStage(sInfo)
-    // Maintaining the mapping between AccumulatorID and Corresponding Stage IDs
-    val sInfoAccumIds = sInfo.accumulables.keySet
-    if (sInfoAccumIds.nonEmpty) {
-      sInfoAccumIds.foreach { accumId =>
-        val stageIds = accumIdToStageId.getOrElseUpdate(accumId, TreeSet[Int]())
-        accumIdToStageId.put(accumId, stageIds + sInfo.stageId)
-      }
-    }
     stage
   }
-
-  /**
-   * Returns a mapping between AccumulatorID and a single stageId (1-to-1) by taking the head of
-   * the list.
-   * That getter is used as a temporary hack to avoid callers that expect a 1-to-1 mapping between
-   * accumulators and stages. i.e., GenerateDot.writeDotGraph expects a 1-to-1 mapping but it is
-   * rarely used for now.
-   *
-   * @return a Map of AccumulatorID to StageId
-   */
-  def getAccumToSingleStage(): Map[Long, Int] = {
-    accumIdToStageId.map { case (accumId, stageIds) =>
-      accumId -> stageIds.head
-    }.toMap
-  }
-
-  def addAccumIdToStage(stageId: Int, accumIds: Iterable[Long]): Unit = {
-    accumIds.foreach { accumId =>
-      val stageIds = accumIdToStageId.getOrElseUpdate(accumId, TreeSet[Int]())
-      accumIdToStageId.put(accumId, stageIds + stageId)
-    }
-  }
-
-  def getStagesIdsByAccumId(accumId: Long): Iterable[Int] = {
-    accumIdToStageId.getOrElse(accumId, TreeSet[Int]())
-  }
-
-  def removeAccumulatorId(accId: Long): Unit = {
-    accumIdToStageId.remove(accId)
-  }
 }
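
For reference, the replacement for the deleted getAccumToSingleStage is AccumManager.getAccumSingleStage (earlier in this patch), which projects each accumulator onto its minimum stage id. A small sketch showing why min over the stage set reproduces head of the old ascending TreeSet; SingleStageProjection and its input shape are hypothetical:

    object SingleStageProjection {
      def getAccumSingleStage(stagesByAccum: Map[Long, Set[Int]]): Map[Long, Int] =
        stagesByAccum.collect {
          case (accumId, stageIds) if stageIds.nonEmpty =>
            // min over the set == head of a TreeSet[Int] (ascending order),
            // so callers that expect a 1-to-1 mapping see the same stage.
            accumId -> stageIds.min
        }

      def main(args: Array[String]): Unit = {
        val stages = Map(1L -> Set(4, 2, 9), 2L -> Set(5))
        println(getAccumSingleStage(stages)) // Map(1 -> 2, 2 -> 5)
      }
    }

Note that the sketch skips empty stage sets, whereas the real getMinStageId would throw on an accumulator with no recorded stages; the patch relies on every tracked AccumInfo having at least one stage entry.
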