Removing accum to stage map
Signed-off-by: Sayed Bilal Bari <[email protected]>
bilalbari committed Aug 1, 2024
1 parent c926094 commit 81b98c0
Showing 10 changed files with 33 additions and 70 deletions.
@@ -443,10 +443,7 @@ object SQLPlanParser extends Logging {

   def getStagesInSQLNode(node: SparkPlanGraphNode, app: AppBase): Set[Int] = {
     val nodeAccums = node.metrics.map(_.accumulatorId)
-    nodeAccums.flatMap { nodeAccumId =>
-      // val res = app.accumManager.getAccStageIds(nodeAccumId)
-      app.stageManager.getStagesIdsByAccumId(nodeAccumId)
-    }.toSet
+    nodeAccums.flatMap(app.accumManager.getAccStageIds).toSet
   }

   // Set containing execs that refers to other expressions. We need this to be a list to allow
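This call-site change is the crux of the commit: the stages for a SQL node are now derived from the accumulators themselves (AccumManager.getAccStageIds) instead of a separately maintained map on stageManager. A minimal runnable sketch of the new lookup path; AccumInfoSketch, AccumManagerSketch, and the sample IDs are hypothetical stand-ins, not the tool's actual types:

import scala.collection.mutable

class AccumInfoSketch {
  // stageId -> last accumulator value seen for that stage
  val stageValuesMap: mutable.HashMap[Int, Long] = mutable.HashMap()
  def getStageIds: Set[Int] = stageValuesMap.keySet.toSet
}

class AccumManagerSketch {
  val accumInfoMap: mutable.HashMap[Long, AccumInfoSketch] = mutable.HashMap()
  def getAccStageIds(id: Long): Set[Int] =
    accumInfoMap.get(id).map(_.getStageIds).getOrElse(Set.empty)
}

object LookupExample extends App {
  val mgr = new AccumManagerSketch
  val info = new AccumInfoSketch
  info.stageValuesMap ++= Seq(3 -> 10L, 5 -> 42L)
  mgr.accumInfoMap.put(100L, info)

  // Unknown accumulator IDs (999) simply contribute no stages.
  val nodeAccums = Seq(100L, 999L)
  val stages = nodeAccums.flatMap(mgr.getAccStageIds).toSet
  println(stages) // Set(3, 5)
}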
@@ -33,7 +33,7 @@ class CompareApplications(apps: Seq[ApplicationInfo]) extends Logging {

     val normalizedByAppId = apps.map { app =>
       val normalized = app.sqlPlans.mapValues { plan =>
         SparkPlanInfoWithStage(plan,
-          app.stageManager.getAccumToSingleStage()).normalizeForStageComparison
+          app.accumManager.getAccumSingleStage).normalizeForStageComparison
       }
       (app.appId, normalized)
     }.toMap
@@ -88,7 +88,7 @@ object GenerateDot {

     val accumSummary = accums.map { a =>
       Seq(a.sqlID, a.accumulatorId, a.total)
     }
-    val accumIdToStageId = app.stageManager.getAccumToSingleStage()
+    val accumIdToStageId = app.accumManager.getAccumSingleStage
     val formatter = java.text.NumberFormat.getIntegerInstance
     val stageIdToStageMetrics = app.taskManager.stageAttemptToTasks.collect { case (stageId, _) =>
       val tasks = app.taskManager.getAllTasksStageAttempt(stageId)
@@ -181,7 +181,6 @@ abstract class AppBase(

   def cleanupAccumId(accId: Long): Unit = {
     accumManager.removeAccumInfo(accId)
     driverAccumMap.remove(accId)
-    stageManager.removeAccumulatorId(accId)
   }

   def cleanupStages(stageIds: Set[Int]): Unit = {
@@ -349,9 +349,6 @@ abstract class EventProcessorBase[T <: AppBase](app: T) extends SparkListener wi

       app: T,
       event: SparkListenerTaskEnd): Unit = {
     // TODO: this implementation needs to be updated to use attemptID
-    // Update the map between accumulators and stages
-    app.stageManager.addAccumIdToStage(
-      event.stageId, event.taskInfo.accumulables.map(_.id))
     // Parse task accumulables
     for (res <- event.taskInfo.accumulables) {
       try {
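Dropping addAccumIdToStage is safe because the accumulable-parsing loop kept at the end of the hunk above already visits every (accumId, stageId) pair, so the association can be recorded where the updates are parsed. A hedged sketch of that idea as a standalone listener; AccumStageRecorder and its map are illustrative, not the tool's actual types:

import scala.collection.mutable

import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}

class AccumStageRecorder extends SparkListener {
  // accumId -> stage IDs in which the accumulator was updated
  val accumToStages: mutable.HashMap[Long, mutable.Set[Int]] = mutable.HashMap()

  override def onTaskEnd(event: SparkListenerTaskEnd): Unit = {
    // Walking the task accumulables yields every (accumId, stageId) pair,
    // which is why a separate stageManager-side map became redundant.
    for (acc <- event.taskInfo.accumulables) {
      accumToStages.getOrElseUpdate(acc.id, mutable.Set()) += event.stageId
    }
  }
}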
@@ -16,9 +16,7 @@

 package org.apache.spark.sql.rapids.tool.store

-case class AccMetaRef(id: Long, name: AccNameRef) {
-
-}
+case class AccMetaRef(id: Long, name: AccNameRef)

 object AccMetaRef {
   def apply(id: Long, name: Option[String]): AccMetaRef =
@@ -20,9 +20,7 @@ import java.util.concurrent.ConcurrentHashMap

 import org.apache.spark.sql.rapids.tool.util.EventUtils.normalizeMetricName

-case class AccNameRef(value: String) {
-
-}
+case class AccNameRef(value: String)

 object AccNameRef {
   val EMPTY_ACC_NAME_REF: AccNameRef = new AccNameRef("N/A")
@@ -34,12 +34,20 @@ class AccumInfo(val infoRef: AccMetaRef) {

       accumulableInfo: AccumulableInfo,
       update: Option[Long] = None): Unit = {
     val value = accumulableInfo.value.flatMap(parseAccumFieldToLong)
+    val existingValue = stageValuesMap.getOrElse(stageId, 0L)
     value match {
       case Some(v) =>
+        // This assert prevents out of order events to be processed
+        assert( v >= existingValue,
+          s"Stage $stageId: Out of order events detected.")
         stageValuesMap.put(stageId, v)
       case _ =>
-        // this could be the case when a task update has triggered the stage update
-        stageValuesMap.put(stageId, update.getOrElse(0L))
+        val incomingUpdate = update.getOrElse(0L)
+        assert( incomingUpdate >= existingValue,
+          s"Stage $stageId: Out of order events detected.")
+        // this case is for metrics that are not parsed as long
+        // We track the accumId to stageId and taskId mapping
+        stageValuesMap.put(stageId, incomingUpdate)
     }
   }
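The new asserts encode an invariant worth stating explicitly: within a stage an accumulator's reported value should only grow, so a smaller incoming value means the events arrived out of order. A minimal sketch of the guard, assuming a simplified stageValuesMap (the harness is hypothetical):

import scala.collection.mutable

object OutOfOrderGuardExample extends App {
  val stageValuesMap = mutable.HashMap[Int, Long]()

  def putStageValue(stageId: Int, v: Long): Unit = {
    val existingValue = stageValuesMap.getOrElse(stageId, 0L)
    // Accumulator values only grow within a stage, so a smaller incoming
    // value means the events were processed out of order.
    assert(v >= existingValue, s"Stage $stageId: Out of order events detected.")
    stageValuesMap.put(stageId, v)
  }

  putStageValue(1, 10L)
  putStageValue(1, 25L)   // fine: 25 >= 10
  // putStageValue(1, 5L) // would trip the assert: 5 < 25
  println(stageValuesMap) // Map(1 -> 25)
}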

@@ -48,24 +56,30 @@ class AccumInfo(val infoRef: AccMetaRef) {

     // we have to update the stageMap if the stageId does not exist in the map
     var updateStageFlag = !stageValuesMap.contains(stageId)
     // TODO: Task can update an accum multiple times. Should account for that case.
+    // This is for cases where same task updates the same accum multiple times
+    val existingUpdate = taskUpdatesMap.getOrElse(taskId, 0L)
     update match {
       case Some(v) =>
-        taskUpdatesMap.put(taskId, v)
+        taskUpdatesMap.put(taskId, v + existingUpdate)
         // update teh stage if the task's update is non-zero
         updateStageFlag ||= v != 0
       case None =>
-        taskUpdatesMap.put(taskId, 0L)
+        taskUpdatesMap.put(taskId, existingUpdate)
     }
     // update the stage value map if necessary
     if (updateStageFlag) {
-      addAccToStage(stageId, accumulableInfo, update)
+      addAccToStage(stageId, accumulableInfo, update.map(_ + existingUpdate))
     }
   }

   def getStageIds: Set[Int] = {
     stageValuesMap.keySet.toSet
   }

+  def getMinStageId: Int = {
+    stageValuesMap.keys.min
+  }
+
   def calculateAccStats(): StatisticsMetrics = {
     val sortedTaskUpdates = taskUpdatesMap.values.toSeq.sorted
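addAccumUpdate now folds repeated updates from the same task into a running sum instead of overwriting the previous value, and forwards the summed update to the stage map. A small sketch of the folding behavior under the same simplifying assumptions (map and harness are hypothetical):

import scala.collection.mutable

object TaskUpdateExample extends App {
  val taskUpdatesMap = mutable.HashMap[Long, Long]()

  def addTaskUpdate(taskId: Long, update: Option[Long]): Unit = {
    val existingUpdate = taskUpdatesMap.getOrElse(taskId, 0L)
    // A task may report the same accumulator several times; sum the deltas
    // rather than keeping only the last one.
    taskUpdatesMap.put(taskId, update.getOrElse(0L) + existingUpdate)
  }

  addTaskUpdate(7L, Some(5L))
  addTaskUpdate(7L, Some(3L))
  addTaskUpdate(7L, None)
  println(taskUpdatesMap) // Map(7 -> 8)
}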
@@ -16,7 +16,7 @@

 package org.apache.spark.sql.rapids.tool.store

-import scala.collection.mutable
+import scala.collection.{mutable, Map}

 import com.nvidia.spark.rapids.tool.analysis.StatisticsMetrics

@@ -46,6 +46,13 @@ class AccumManager {

     accumInfoMap.get(id).map(_.getStageIds).getOrElse(Set.empty)
   }

+  def getAccumSingleStage: Map[Long, Int] = {
+    accumInfoMap.map { case (id, accInfo) =>
+      (id, accInfo.getMinStageId)
+    }.toMap
+  }
+
   def removeAccumInfo(id: Long): Option[AccumInfo] = {
     accumInfoMap.remove(id)
   }
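getAccumSingleStage is the replacement for StageModelManager.getAccumToSingleStage used by CompareApplications and GenerateDot above: for callers that expect a 1-to-1 accumulator-to-stage mapping, it picks the smallest stage ID the accumulator was seen in. A sketch with a simplified, hypothetical AccumInfo stand-in:

import scala.collection.mutable

class AccumStages(val stageValuesMap: mutable.HashMap[Int, Long]) {
  def getMinStageId: Int = stageValuesMap.keys.min
}

object SingleStageExample extends App {
  val accumInfoMap = mutable.HashMap(
    100L -> new AccumStages(mutable.HashMap(4 -> 1L, 2 -> 1L)),
    200L -> new AccumStages(mutable.HashMap(7 -> 9L))
  )
  // Each accumulator maps to the minimum stage ID it was updated in.
  val singleStage: Map[Long, Int] =
    accumInfoMap.map { case (id, info) => (id, info.getMinStageId) }.toMap
  println(singleStage) // e.g. Map(100 -> 2, 200 -> 7)
}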
@@ -16,8 +16,7 @@

 package org.apache.spark.sql.rapids.tool.store

-import scala.collection.{mutable, Map}
-import scala.collection.immutable.{SortedSet, TreeSet}
+import scala.collection.mutable

 import org.apache.spark.internal.Logging
 import org.apache.spark.scheduler.StageInfo

@@ -46,13 +45,6 @@ class StageModelManager extends Logging {

   private val stageIdToInfo: mutable.SortedMap[Int, mutable.SortedMap[Int, StageModel]] =
     mutable.SortedMap[Int, mutable.SortedMap[Int, StageModel]]()

-  // Holds the mapping between AccumulatorIDs to Stages (1-to-N)
-  // [Long: AccumId -> SortedSet[Int: StageId]]
-  // Note that we keep it as primitive type in case we receive a stageID that does not exist
-  // in the stageIdToInfo map.
-  private val accumIdToStageId: mutable.HashMap[Long, SortedSet[Int]] =
-    new mutable.HashMap[Long, SortedSet[Int]]()

   /**
    * Returns all StageModels that have been created as a result of handling
    * StageSubmitted/StageCompleted-events. This includes stages with multiple attempts.
@@ -118,46 +110,7 @@

    * @return existing or new instance of StageModel with (sInfo.stageId, sInfo.attemptID)
    */
   def addStageInfo(sInfo: StageInfo): StageModel = {
-    // Creating stageModel instance if it does not exist
-    val stage = getOrCreateStage(sInfo)
-    // Maintaining the mapping between AccumulatorID and Corresponding Stage IDs
-    val sInfoAccumIds = sInfo.accumulables.keySet
-    if (sInfoAccumIds.nonEmpty) {
-      sInfoAccumIds.foreach { accumId =>
-        val stageIds = accumIdToStageId.getOrElseUpdate(accumId, TreeSet[Int]())
-        accumIdToStageId.put(accumId, stageIds + sInfo.stageId)
-      }
-    }
-    stage
-  }
-
-  /**
-   * Returns a mapping between AccumulatorID and a single stageId (1-to-1) by taking the head of
-   * the list.
-   * That getter is used as a temporary hack to avoid callers that expect a 1-to-1 mapping between
-   * accumulators and stages. i.e., GenerateDot.writeDotGraph expects a 1-to-1 mapping but it is
-   * rarely used for now.
-   *
-   * @return a Map of AccumulatorID to StageId
-   */
-  def getAccumToSingleStage(): Map[Long, Int] = {
-    accumIdToStageId.map { case (accumId, stageIds) =>
-      accumId -> stageIds.head
-    }.toMap
-  }
-
-  def addAccumIdToStage(stageId: Int, accumIds: Iterable[Long]): Unit = {
-    accumIds.foreach { accumId =>
-      val stageIds = accumIdToStageId.getOrElseUpdate(accumId, TreeSet[Int]())
-      accumIdToStageId.put(accumId, stageIds + stageId)
-    }
-  }
-
-  def getStagesIdsByAccumId(accumId: Long): Iterable[Int] = {
-    accumIdToStageId.getOrElse(accumId, TreeSet[Int]())
-  }
-
-  def removeAccumulatorId(accId: Long): Unit = {
-    accumIdToStageId.remove(accId)
-  }
+    getOrCreateStage(sInfo)
+  }
 }
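With the 1-to-N map gone, the main semantic question is whether AccumInfo.getMinStageId matches the old getAccumToSingleStage. It does: the removed getter took the head of a SortedSet of stage IDs, and a TreeSet[Int] iterates in ascending order, so head equals the minimum, which is exactly what getMinStageId computes. A one-line check of that equivalence (illustrative, not from the commit):

import scala.collection.immutable.TreeSet

object HeadVsMinExample extends App {
  val stageIds = TreeSet(5, 2, 9) // TreeSet keeps elements sorted ascending
  assert(stageIds.head == stageIds.min)
  println(stageIds.head) // 2
}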
