Skip processing apps with invalid platform and spark runtime configurations (#1421)

* Add platform specific runtime check

Signed-off-by: Partho Sarthi <[email protected]>

* Refactor comments

Signed-off-by: Partho Sarthi <[email protected]>

* Update behavior to fail on unsupported Spark Runtime

Signed-off-by: Partho Sarthi <[email protected]>

* Fix trailing comma

Signed-off-by: Partho Sarthi <[email protected]>

---------

Signed-off-by: Partho Sarthi <[email protected]>
parthosa authored Dec 20, 2024
1 parent f0058c0 commit 7308c12
Showing 13 changed files with 140 additions and 42 deletions.
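The change has two halves: each Platform now declares the set of SparkRuntime values it can process, and AppBase validates the runtime parsed from an event log against that set once event processing completes, failing the app with an UnsupportedSparkRuntimeException instead of silently producing output for a configuration the platform cannot run. A minimal sketch of the shape, simplified from the diff below (the concrete subclass and its name string here are illustrative):

object SparkRuntime extends Enumeration {
  type SparkRuntime = Value
  val SPARK, SPARK_RAPIDS, PHOTON = Value
}

abstract class Platform(val platformName: String) {
  // Most platforms accept CPU Spark and the Spark RAPIDS plugin.
  protected val supportedRuntimes: Set[SparkRuntime.SparkRuntime] =
    Set(SparkRuntime.SPARK, SparkRuntime.SPARK_RAPIDS)

  def isRuntimeSupported(runtime: SparkRuntime.SparkRuntime): Boolean =
    supportedRuntimes.contains(runtime)
}

class DatabricksAwsPlatform extends Platform("databricks-aws") {
  // Databricks platforms additionally accept Photon event logs.
  override protected val supportedRuntimes: Set[SparkRuntime.SparkRuntime] =
    Set(SparkRuntime.SPARK, SparkRuntime.SPARK_RAPIDS, SparkRuntime.PHOTON)
}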
19 changes: 18 additions & 1 deletion core/src/main/scala/com/nvidia/spark/rapids/tool/Platform.scala
@@ -22,7 +22,7 @@ import com.nvidia.spark.rapids.tool.profiling.ClusterProperties

import org.apache.spark.internal.Logging
import org.apache.spark.sql.rapids.tool.{ExistingClusterInfo, RecommendedClusterInfo}
import org.apache.spark.sql.rapids.tool.util.StringUtils
import org.apache.spark.sql.rapids.tool.util.{SparkRuntime, StringUtils}

/**
* Utility object containing constants for various platform names.
@@ -132,6 +132,19 @@ abstract class Platform(var gpuDevice: Option[GpuDevice],
var recommendedClusterInfo: Option[RecommendedClusterInfo] = None
// the number of GPUs to use, this might be updated as we handle different cases
var numGpus: Int = 1
// Default runtime for the platform
val defaultRuntime: SparkRuntime.SparkRuntime = SparkRuntime.SPARK
// Set of supported runtimes for the platform
protected val supportedRuntimes: Set[SparkRuntime.SparkRuntime] = Set(
SparkRuntime.SPARK, SparkRuntime.SPARK_RAPIDS
)

/**
* Checks if the given runtime is supported by the platform.
*/
def isRuntimeSupported(runtime: SparkRuntime.SparkRuntime): Boolean = {
supportedRuntimes.contains(runtime)
}

// This function allows us to have one GPU type used by the auto
// tuner recommendations but have a different GPU used for speedup
@@ -511,6 +524,10 @@ abstract class DatabricksPlatform(gpuDevice: Option[GpuDevice],
override val defaultGpuDevice: GpuDevice = T4Gpu
override def isPlatformCSP: Boolean = true

override val supportedRuntimes: Set[SparkRuntime.SparkRuntime] = Set(
SparkRuntime.SPARK, SparkRuntime.SPARK_RAPIDS, SparkRuntime.PHOTON
)

// note that Databricks generally sets the spark.executor.memory for the user. Our
// auto tuner heuristics generally set it lower than Databricks so go ahead and
// allow our auto tuner to take effect for this in anticipation that we will use more
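With the override in place, a runtime check is a plain set lookup. For example (assuming PlatformFactory accepts the PlatformNames constants the tests below pass to it):

val onprem = PlatformFactory.createInstance(PlatformNames.ONPREM)
onprem.isRuntimeSupported(SparkRuntime.PHOTON)   // false: Photon logs are rejected
val dbAws = PlatformFactory.createInstance(PlatformNames.DATABRICKS_AWS)
dbAws.isRuntimeSupported(SparkRuntime.PHOTON)    // true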
core/src/main/scala/com/nvidia/spark/rapids/tool/profiling/Profiler.scala
@@ -22,7 +22,7 @@ import scala.collection.JavaConverters._
import scala.collection.mutable.{ArrayBuffer, HashMap}
import scala.util.control.NonFatal

import com.nvidia.spark.rapids.tool.{AppSummaryInfoBaseProvider, EventLogInfo, EventLogPathProcessor, FailedEventLog, PlatformFactory, ToolBase}
import com.nvidia.spark.rapids.tool.{AppSummaryInfoBaseProvider, EventLogInfo, EventLogPathProcessor, FailedEventLog, Platform, PlatformFactory, ToolBase}
import com.nvidia.spark.rapids.tool.profiling.AutoTuner.loadClusterProps
import com.nvidia.spark.rapids.tool.views._
import org.apache.hadoop.conf.Configuration
@@ -43,6 +43,8 @@ class Profiler(hadoopConf: Configuration, appArgs: ProfileArgs, enablePB: Boolea
private val outputCombined: Boolean = appArgs.combined()
private val useAutoTuner: Boolean = appArgs.autoTuner()
private val outputAlignedSQLIds: Boolean = appArgs.outputSqlIdsAligned()
// Unlike the qualification tool, the profiler does not require a platform per app
private val platform: Platform = PlatformFactory.createInstance(appArgs.platform())

override def getNumThreads: Int = appArgs.numThreads.getOrElse(
Math.ceil(Runtime.getRuntime.availableProcessors() / 4f).toInt)
@@ -295,9 +297,9 @@ class Profiler(hadoopConf: Configuration, appArgs: ProfileArgs, enablePB: Boolea
private def createApp(path: EventLogInfo, index: Int,
hadoopConf: Configuration): Either[FailureApp, ApplicationInfo] = {
try {
// This apps only contains 1 app in each loop.
// Each iteration of this loop creates a single app.
val startTime = System.currentTimeMillis()
val app = new ApplicationInfo(hadoopConf, path, index)
val app = new ApplicationInfo(hadoopConf, path, index, platform)
EventLogPathProcessor.logApplicationInfo(app)
val endTime = System.currentTimeMillis()
if (!app.isAppMetaDefined) {
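Since profiling does not need per-app platform resolution, the single instance built from the CLI flag is reused for every createApp call, so all event logs in one invocation are validated against the same runtime set. The wiring, condensed from the two added lines above:

val platform = PlatformFactory.createInstance(appArgs.platform())   // once per run
val app = new ApplicationInfo(hadoopConf, path, index, platform)    // per event log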
20 changes: 18 additions & 2 deletions core/src/main/scala/org/apache/spark/sql/rapids/tool/AppBase.scala
@@ -23,7 +23,7 @@ import scala.collection.immutable
import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, LinkedHashSet, Map}

import com.nvidia.spark.rapids.SparkRapidsBuildInfoEvent
import com.nvidia.spark.rapids.tool.{DatabricksEventLog, DatabricksRollingEventLogFilesFileReader, EventLogInfo}
import com.nvidia.spark.rapids.tool.{DatabricksEventLog, DatabricksRollingEventLogFilesFileReader, EventLogInfo, Platform}
import com.nvidia.spark.rapids.tool.planparser.{HiveParseHelper, ReadParser}
import com.nvidia.spark.rapids.tool.planparser.HiveParseHelper.isHiveTableScanNode
import com.nvidia.spark.rapids.tool.profiling.{BlockManagerRemovedCase, DriverAccumCase, JobInfoClass, ResourceProfileInfoCase, SQLExecutionInfoClass, SQLPlanMetricsCase}
@@ -42,7 +42,8 @@ import org.apache.spark.util.Utils

abstract class AppBase(
val eventLogInfo: Option[EventLogInfo],
val hadoopConf: Option[Configuration]) extends Logging
val hadoopConf: Option[Configuration],
val platform: Option[Platform] = None) extends Logging
with ClusterTagPropHandler
with AccumToStageRetriever {

@@ -481,6 +482,7 @@
protected def postCompletion(): Unit = {
registerAttemptId()
calculateAppDuration()
validateSparkRuntime()
}

/**
@@ -491,6 +493,20 @@
processEventsInternal()
postCompletion()
}

/**
* Validates that the Spark runtime (parsed from the event log) is supported by the platform.
* If the runtime is not supported, an `UnsupportedSparkRuntimeException`
* is thrown.
*/
private def validateSparkRuntime(): Unit = {
val parsedRuntime = getSparkRuntime
platform.foreach { p =>
require(p.isRuntimeSupported(parsedRuntime),
throw UnsupportedSparkRuntimeException(p, parsedRuntime)
)
}
}
}

object AppBase {
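One subtlety in validateSparkRuntime above: the second argument of require is by-name, so the throw expression is evaluated only when the predicate is false, and the failure surfaces as UnsupportedSparkRuntimeException rather than require's usual IllegalArgumentException. A more conventional spelling of the same check would be:

private def validateSparkRuntime(): Unit = {
  val parsedRuntime = getSparkRuntime
  platform.foreach { p =>
    if (!p.isRuntimeSupported(parsedRuntime)) {
      throw UnsupportedSparkRuntimeException(p, parsedRuntime)
    }
  }
}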
core/src/main/scala/org/apache/spark/sql/rapids/tool/ToolUtils.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.rapids.tool
import scala.collection.mutable
import scala.util.{Failure, Success, Try}

import com.nvidia.spark.rapids.tool.Platform
import com.nvidia.spark.rapids.tool.planparser.SubqueryExecParser
import com.nvidia.spark.rapids.tool.profiling.ProfileUtils.replaceDelimiter
import com.nvidia.spark.rapids.tool.qualification.QualOutputWriter
@@ -28,7 +29,7 @@ import org.apache.spark.internal.{config, Logging}
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.execution.SparkPlanInfo
import org.apache.spark.sql.execution.ui.{SparkPlanGraph, SparkPlanGraphNode}
import org.apache.spark.sql.rapids.tool.util.ToolsPlanGraph
import org.apache.spark.sql.rapids.tool.util.{SparkRuntime, ToolsPlanGraph}

object ToolUtils extends Logging {
// List of recommended file-encodings on the GPUs.
@@ -441,6 +442,12 @@ case class UnsupportedMetricNameException(metricName: String)
extends AppEventlogProcessException(
s"Unsupported metric name found in the event log: $metricName")

case class UnsupportedSparkRuntimeException(
platform: Platform,
sparkRuntime: SparkRuntime.SparkRuntime)
extends AppEventlogProcessException(
s"Platform '${platform.platformName}' does not support the runtime '$sparkRuntime'")

// Class used as a container to hold the information of the Tuple<sqlID, PlanInfo, SparkGraph>
// to simplify arguments of methods and caching.
case class SqlPlanInfoGraphEntry(
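With the message interpolation above, a Photon log processed on a non-Databricks platform fails with a message along the lines of (the exact platformName string depends on the Platform subclass):

Platform 'onprem' does not support the runtime 'PHOTON'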
core/src/main/scala/org/apache/spark/sql/rapids/tool/profiling/ApplicationInfo.scala
@@ -18,7 +18,7 @@ package org.apache.spark.sql.rapids.tool.profiling

import scala.collection.Map

import com.nvidia.spark.rapids.tool.EventLogInfo
import com.nvidia.spark.rapids.tool.{EventLogInfo, Platform, PlatformFactory}
import com.nvidia.spark.rapids.tool.analysis.AppSQLPlanAnalyzer
import org.apache.hadoop.conf.Configuration

@@ -184,8 +184,9 @@ object SparkPlanInfoWithStage {
class ApplicationInfo(
hadoopConf: Configuration,
eLogInfo: EventLogInfo,
val index: Int)
extends AppBase(Some(eLogInfo), Some(hadoopConf)) with Logging {
val index: Int,
platform: Platform = PlatformFactory.createInstance())
extends AppBase(Some(eLogInfo), Some(hadoopConf), Some(platform)) with Logging {

private lazy val eventProcessor = new EventsProcessor(this)

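Because the new parameter defaults to PlatformFactory.createInstance(), existing call sites keep compiling, while callers can pin a platform explicitly. A sketch of direct construction (the event-log path is illustrative):

val hadoopConf = RapidsToolsConfUtil.newHadoopConf()
val eventLogInfo = EventLogPathProcessor
  .getEventLogInfo("/tmp/nds_q88_photon_db_13_3.zstd", hadoopConf)
  .head._1
val platform = PlatformFactory.createInstance(PlatformNames.DATABRICKS_AWS)
// The Photon event log parses only because this platform supports the PHOTON runtime.
val app = new ApplicationInfo(hadoopConf, eventLogInfo, 1, platform)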
core/src/main/scala/org/apache/spark/sql/rapids/tool/qualification/QualificationAppInfo.scala
@@ -41,7 +41,7 @@ class QualificationAppInfo(
mlOpsEnabled: Boolean = false,
penalizeTransitions: Boolean = true,
platform: Platform)
extends AppBase(eventLogInfo, hadoopConf) with Logging {
extends AppBase(eventLogInfo, hadoopConf, Some(platform)) with Logging {

var lastJobEndTime: Option[Long] = None
var lastSQLEndTime: Option[Long] = None
core/src/test/scala/com/nvidia/spark/rapids/tool/ToolTestUtils.scala
@@ -144,12 +144,13 @@ object ToolTestUtils extends Logging {
val apps: ArrayBuffer[ApplicationInfo] = ArrayBuffer[ApplicationInfo]()
val appArgs = new ProfileArgs(logs)
var index: Int = 1
val platform = PlatformFactory.createInstance(appArgs.platform())
for (path <- appArgs.eventlog()) {
val eventLogInfo = EventLogPathProcessor
.getEventLogInfo(path, RapidsToolsConfUtil.newHadoopConf())
assert(eventLogInfo.size >= 1, s"event log not parsed as expected $path")
assert(eventLogInfo.nonEmpty, s"event log not parsed as expected $path")
apps += new ApplicationInfo(RapidsToolsConfUtil.newHadoopConf(),
eventLogInfo.head._1, index)
eventLogInfo.head._1, index, platform)
index += 1
}
apps
core/src/test/scala/com/nvidia/spark/rapids/tool/planparser/BasePlanParserSuite.scala
@@ -17,7 +17,7 @@
package com.nvidia.spark.rapids.tool.planparser

import com.nvidia.spark.rapids.BaseTestSuite
import com.nvidia.spark.rapids.tool.{EventLogPathProcessor, PlatformFactory, ToolTestUtils}
import com.nvidia.spark.rapids.tool.{EventLogPathProcessor, PlatformFactory, PlatformNames, ToolTestUtils}
import com.nvidia.spark.rapids.tool.qualification._

import org.apache.spark.sql.rapids.tool.qualification.QualificationAppInfo
@@ -59,15 +59,16 @@ class BasePlanParserSuite extends BaseTestSuite {
}
}

def createAppFromEventlog(eventLog: String): QualificationAppInfo = {
def createAppFromEventlog(eventLog: String,
platformName: String = PlatformNames.DEFAULT): QualificationAppInfo = {
val hadoopConf = RapidsToolsConfUtil.newHadoopConf()
val (_, allEventLogs) = EventLogPathProcessor.processAllPaths(
None, None, List(eventLog), hadoopConf)
val pluginTypeChecker = new PluginTypeChecker()
assert(allEventLogs.size == 1)
val appResult = QualificationAppInfo.createApp(allEventLogs.head, hadoopConf,
pluginTypeChecker, reportSqlLevel = false, mlOpsEnabled = false, penalizeTransitions = true,
PlatformFactory.createInstance())
PlatformFactory.createInstance(platformName))
appResult match {
case Right(app) => app
case Left(_) => throw new AssertionError("Cannot create application")
core/src/test/scala/com/nvidia/spark/rapids/tool/planparser/PhotonPlanParserSuite.scala
@@ -16,6 +16,7 @@

package com.nvidia.spark.rapids.tool.planparser

import com.nvidia.spark.rapids.tool.PlatformNames
import com.nvidia.spark.rapids.tool.qualification.PluginTypeChecker


@@ -34,7 +35,7 @@ class PhotonPlanParserSuite extends BasePlanParserSuite {
test(s"$photonName is parsed as Spark $sparkName") {
val eventLog = s"$qualLogDir/nds_q88_photon_db_13_3.zstd"
val pluginTypeChecker = new PluginTypeChecker()
val app = createAppFromEventlog(eventLog)
val app = createAppFromEventlog(eventLog, platformName = PlatformNames.DATABRICKS_AWS)
assert(app.sqlPlans.nonEmpty)
val parsedPlans = app.sqlPlans.map { case (sqlID, plan) =>
SQLPlanParser.parseSQLPlan(app.appId, plan, sqlID, "", pluginTypeChecker, app)
core/src/test/scala/com/nvidia/spark/rapids/tool/profiling/AnalysisSuite.scala
@@ -18,7 +18,7 @@ package com.nvidia.spark.rapids.tool.profiling

import java.io.File

import com.nvidia.spark.rapids.tool.ToolTestUtils
import com.nvidia.spark.rapids.tool.{PlatformNames, ToolTestUtils}
import com.nvidia.spark.rapids.tool.views.{ProfDataSourceView, RawMetricProfilerView}
import org.scalatest.FunSuite

@@ -139,7 +139,8 @@ class AnalysisSuite extends FunSuite {
s"${fileName}_${metric}_metrics_agg_expectation.csv"
}
testSqlMetricsAggregation(Array(s"${qualLogDir}/${fileName}.zstd"),
expectFile("sql"), expectFile("job"), expectFile("stage"))
expectFile("sql"), expectFile("job"), expectFile("stage"),
platformName = PlatformNames.DATABRICKS_AWS)
}

test("test stage-level diagnostic aggregation simple") {
@@ -163,8 +164,10 @@
}

private def testSqlMetricsAggregation(logs: Array[String], expectFileSQL: String,
expectFileJob: String, expectFileStage: String): Unit = {
val apps = ToolTestUtils.processProfileApps(logs, sparkSession)
expectFileJob: String, expectFileStage: String,
platformName: String = PlatformNames.DEFAULT): Unit = {
val args = Array("--platform", platformName) ++ logs
val apps = ToolTestUtils.processProfileApps(args, sparkSession)
assert(apps.size == logs.size)
val aggResults = RawMetricProfilerView.getAggMetrics(apps)
import sparkSession.implicits._
@@ -256,9 +259,12 @@
}

test("test photon scan metrics") {
val fileName = "nds_q88_photon_db_13_3"
val logs = Array(s"${qualLogDir}/${fileName}.zstd")
val apps = ToolTestUtils.processProfileApps(logs, sparkSession)
val args = Array(
"--platform",
PlatformNames.DATABRICKS_AWS,
s"$qualLogDir/nds_q88_photon_db_13_3.zstd"
)
val apps = ToolTestUtils.processProfileApps(args, sparkSession)
val dataSourceResults = ProfDataSourceView.getRawView(apps)
assert(dataSourceResults.exists(_.scan_time > 0))
}
core/src/test/scala/com/nvidia/spark/rapids/tool/profiling/ApplicationInfoSuite.scala
@@ -22,14 +22,15 @@ import java.nio.file.{Files, Paths, StandardOpenOption}

import scala.collection.mutable.ArrayBuffer

import com.nvidia.spark.rapids.tool.{EventLogPathProcessor, StatusReportCounts, ToolTestUtils}
import com.nvidia.spark.rapids.tool.{EventLogPathProcessor, PlatformNames, StatusReportCounts, ToolTestUtils}
import com.nvidia.spark.rapids.tool.views.RawMetricProfilerView
import org.apache.hadoop.io.IOUtils
import org.scalatest.FunSuite

import org.apache.spark.internal.Logging
import org.apache.spark.resource.ResourceProfile
import org.apache.spark.sql.{SparkSession, TrampolineUtil}
import org.apache.spark.sql.rapids.tool.UnsupportedSparkRuntimeException
import org.apache.spark.sql.rapids.tool.profiling._
import org.apache.spark.sql.rapids.tool.util.{FSUtils, SparkRuntime}

Expand Down Expand Up @@ -1116,17 +1117,56 @@ class ApplicationInfoSuite extends FunSuite with Logging {
}
}

val sparkRuntimeTestCases: Seq[(SparkRuntime.Value, String)] = Seq(
SparkRuntime.SPARK -> s"$qualLogDir/nds_q86_test",
SparkRuntime.SPARK_RAPIDS -> s"$logDir/nds_q66_gpu.zstd",
SparkRuntime.PHOTON -> s"$qualLogDir/nds_q88_photon_db_13_3.zstd"
// scalastyle:off line.size.limit
val supportedSparkRuntimeTestCases: Map[String, Seq[(String, SparkRuntime.SparkRuntime)]] = Map(
// tests for standard Spark runtime
s"$qualLogDir/nds_q86_test" -> Seq(
(PlatformNames.DATABRICKS_AWS, SparkRuntime.SPARK), // Expected: SPARK on Databricks AWS
(PlatformNames.ONPREM, SparkRuntime.SPARK) // Expected: SPARK on Onprem
),
// tests for Spark Rapids runtime
s"$logDir/nds_q66_gpu.zstd" -> Seq(
(PlatformNames.DATABRICKS_AWS, SparkRuntime.SPARK_RAPIDS), // Expected: SPARK_RAPIDS on Databricks AWS
(PlatformNames.ONPREM, SparkRuntime.SPARK_RAPIDS) // Expected: SPARK_RAPIDS on Onprem
),
// tests for Photon runtime on the platforms that support it
s"$qualLogDir/nds_q88_photon_db_13_3.zstd" -> Seq(
(PlatformNames.DATABRICKS_AWS, SparkRuntime.PHOTON), // Expected: PHOTON on Databricks AWS
(PlatformNames.DATABRICKS_AZURE, SparkRuntime.PHOTON) // Expected: PHOTON on Databricks Azure
)
)
// scalastyle:on line.size.limit

supportedSparkRuntimeTestCases.foreach { case (logPath, platformRuntimeCases) =>
val baseFileName = logPath.split("/").last
platformRuntimeCases.foreach { case (platform, expectedRuntime) =>
test(s"test eventlog $baseFileName on $platform has supported runtime: $expectedRuntime") {
val args = Array("--platform", platform, logPath)
val apps = ToolTestUtils.processProfileApps(args, sparkSession)
assert(apps.size == 1)
assert(apps.head.getSparkRuntime == expectedRuntime)
}
}
}

sparkRuntimeTestCases.foreach { case (expectedSparkRuntime, eventLog) =>
test(s"test spark runtime property for ${expectedSparkRuntime.toString} eventlog") {
val apps = ToolTestUtils.processProfileApps(Array(eventLog), sparkSession)
assert(apps.size == 1)
assert(apps.head.getSparkRuntime == expectedSparkRuntime)
// scalastyle:off line.size.limit
val unsupportedSparkRuntimeTestCases: Map[String, Seq[String]] = Map(
s"$qualLogDir/nds_q88_photon_db_13_3.zstd" -> Seq(
PlatformNames.ONPREM, // Expected: PHOTON runtime on Onprem is not supported
PlatformNames.DATAPROC // Expected: PHOTON runtime on Dataproc is not supported
)
)
// scalastyle:on line.size.limit

unsupportedSparkRuntimeTestCases.foreach { case (logPath, platformNames) =>
val baseFileName = logPath.split("/").last
platformNames.foreach { platform =>
test(s"test eventlog $baseFileName on $platform has unsupported runtime") {
val args = Array("--platform", platform, logPath)
intercept[UnsupportedSparkRuntimeException] {
ToolTestUtils.processProfileApps(args, sparkSession)
}
}
}
}
}
core/src/test/scala/com/nvidia/spark/rapids/tool/qualification/QualificationSuite.scala
@@ -136,12 +136,15 @@ class QualificationSuite extends BaseTestSuite {
}
}

private def runQualificationTest(eventLogs: Array[String], expectFileName: String = "",
private def runQualificationTest(eventLogs: Array[String],
expectFileName: String = "", platformName: String = PlatformNames.DEFAULT,
shouldReturnEmpty: Boolean = false, expectPerSqlFileName: Option[String] = None,
expectedStatus: Option[StatusReportCounts] = None): Unit = {
TrampolineUtil.withTempDir { outpath =>
val qualOutputPrefix = "rapids_4_spark_qualification_output"
val outputArgs = Array(
"--platform",
platformName,
"--output-directory",
outpath.getAbsolutePath())

@@ -1762,7 +1765,8 @@
val logFiles = Array(s"$logDir/nds_q88_photon_db_13_3.zstd") // photon event log
// Status counts: 1 SUCCESS, 0 FAILURE, 0 SKIPPED, 0 UNKNOWN
val expectedStatus = Some(StatusReportCounts(1, 0, 0, 0))
runQualificationTest(logFiles, expectedStatus = expectedStatus)
runQualificationTest(logFiles, platformName = PlatformNames.DATABRICKS_AWS,
expectedStatus = expectedStatus)
}

test("process multiple attempts of the same app ID and skip lower attempts") {