Update behavior to fail on unsupported Spark Runtime
Signed-off-by: Partho Sarthi <[email protected]>
parthosa committed Dec 10, 2024
1 parent 68cdf83 commit c4b8a52
Showing 8 changed files with 78 additions and 43 deletions.
30 changes: 12 additions & 18 deletions core/src/main/scala/org/apache/spark/sql/rapids/tool/AppBase.scala
@@ -37,7 +37,7 @@ import org.apache.spark.scheduler.{SparkListenerEvent, StageInfo}
import org.apache.spark.sql.execution.SparkPlanInfo
import org.apache.spark.sql.execution.ui.SparkPlanGraphNode
import org.apache.spark.sql.rapids.tool.store.{AccumManager, DataSourceRecord, SQLPlanModelManager, StageModel, StageModelManager, TaskModelManager}
import org.apache.spark.sql.rapids.tool.util.{EventUtils, RapidsToolsConfUtil, SparkRuntime, ToolsPlanGraph, UTF8Source}
import org.apache.spark.sql.rapids.tool.util.{EventUtils, RapidsToolsConfUtil, ToolsPlanGraph, UTF8Source}
import org.apache.spark.util.Utils

abstract class AppBase(
@@ -482,6 +482,7 @@ abstract class AppBase(
protected def postCompletion(): Unit = {
registerAttemptId()
calculateAppDuration()
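// Throws if the runtime parsed from the event log is not supported by the platform.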
validateSparkRuntime()
}

/**
@@ -494,24 +495,17 @@
}

/**
* Returns the SparkRuntime environment in which the application is being executed.
* This is calculated based on other cached properties.
*
* If the platform is provided, and it does not support the parsed runtime,
* the method will log a warning and fall back to the platform's default runtime.
* Validates that the Spark runtime (parsed from the event log) is supported
* by the platform. If the runtime is not supported, an
* `UnsupportedSparkRuntimeException` is thrown.
*/
override def getSparkRuntime: SparkRuntime.SparkRuntime = {
val parsedRuntime = super.getSparkRuntime
platform.map { p =>
if (p.isRuntimeSupported(parsedRuntime)) {
parsedRuntime
} else {
logWarning(s"Application $appId: Platform '${p.platformName}' does not support " +
s"the parsed runtime '$parsedRuntime'. Falling back to default runtime - " +
s"'${p.defaultRuntime}'.")
p.defaultRuntime
}
}.getOrElse(parsedRuntime)
private def validateSparkRuntime(): Unit = {
  val parsedRuntime = getSparkRuntime
  platform.foreach { p =>
    // Throw the custom exception directly rather than routing it through
    // require(), so the intent is explicit to the reader.
    if (!p.isRuntimeSupported(parsedRuntime)) {
      throw UnsupportedSparkRuntimeException(p, parsedRuntime)
    }
  }
}
}
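For context, isRuntimeSupported boils down to a membership test on the set of runtimes a platform can execute. Below is a minimal sketch of that contract, assuming Platform exposes its supported runtimes as a set and that the SparkRuntime enumeration also defines SPARK_RAPIDS; the sketch class names are illustrative, not the actual layout of com.nvidia.spark.rapids.tool.Platform.

    import org.apache.spark.sql.rapids.tool.util.SparkRuntime

    // Hypothetical sketch of the contract validateSparkRuntime() relies on.
    abstract class PlatformSketch(val platformName: String) {
      // Runtimes this platform can execute; subclasses widen the set.
      def supportedRuntimes: Set[SparkRuntime.SparkRuntime] =
        Set(SparkRuntime.SPARK, SparkRuntime.SPARK_RAPIDS)

      def isRuntimeSupported(runtime: SparkRuntime.SparkRuntime): Boolean =
        supportedRuntimes.contains(runtime)
    }

    // Databricks platforms additionally execute Photon.
    class DatabricksAwsSketch extends PlatformSketch("databricks-aws") {
      override def supportedRuntimes: Set[SparkRuntime.SparkRuntime] =
        super.supportedRuntimes + SparkRuntime.PHOTON
    }

With this shape, the Photon event logs in the tests below pass on the Databricks platforms and now fail on onprem and Dataproc instead of silently falling back to SPARK.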

ToolUtils.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.rapids.tool
import scala.collection.mutable
import scala.util.{Failure, Success, Try}

import com.nvidia.spark.rapids.tool.Platform
import com.nvidia.spark.rapids.tool.planparser.SubqueryExecParser
import com.nvidia.spark.rapids.tool.profiling.ProfileUtils.replaceDelimiter
import com.nvidia.spark.rapids.tool.qualification.QualOutputWriter
@@ -28,7 +29,7 @@ import org.apache.spark.internal.{config, Logging}
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.execution.SparkPlanInfo
import org.apache.spark.sql.execution.ui.{SparkPlanGraph, SparkPlanGraphNode}
import org.apache.spark.sql.rapids.tool.util.ToolsPlanGraph
import org.apache.spark.sql.rapids.tool.util.{SparkRuntime, ToolsPlanGraph}

object ToolUtils extends Logging {
// List of recommended file-encodings on the GPUs.
@@ -441,6 +442,12 @@ case class UnsupportedMetricNameException(metricName: String)
extends AppEventlogProcessException(
s"Unsupported metric name found in the event log: $metricName")

case class UnsupportedSparkRuntimeException(
platform: Platform,
sparkRuntime: SparkRuntime.SparkRuntime)
extends AppEventlogProcessException(
s"Platform '${platform.platformName}' does not support the runtime '$sparkRuntime'")

// A container class holding the Tuple<sqlID, PlanInfo, SparkGraph>
// to simplify method arguments and caching.
case class SqlPlanInfoGraphEntry(
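The message assembled by this exception is the exact text the e2e expectation at the bottom of this commit matches in stderr. A small hedged illustration, where onpremPlatform stands in for a Platform instance whose platformName is "onprem", and AppEventlogProcessException is assumed to forward its message to Exception:

    // Hypothetical usage; onpremPlatform is an assumed Platform instance.
    val ex = UnsupportedSparkRuntimeException(onpremPlatform, SparkRuntime.PHOTON)
    assert(ex.getMessage ==
      "Platform 'onprem' does not support the runtime 'PHOTON'")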
BasePlanParserSuite.scala
@@ -17,7 +17,7 @@
package com.nvidia.spark.rapids.tool.planparser

import com.nvidia.spark.rapids.BaseTestSuite
import com.nvidia.spark.rapids.tool.{EventLogPathProcessor, PlatformFactory, ToolTestUtils}
import com.nvidia.spark.rapids.tool.{EventLogPathProcessor, PlatformFactory, PlatformNames, ToolTestUtils}
import com.nvidia.spark.rapids.tool.qualification._

import org.apache.spark.sql.rapids.tool.qualification.QualificationAppInfo
@@ -59,15 +59,16 @@ class BasePlanParserSuite extends BaseTestSuite {
}
}

def createAppFromEventlog(eventLog: String): QualificationAppInfo = {
def createAppFromEventlog(eventLog: String,
platformName: String = PlatformNames.DEFAULT): QualificationAppInfo = {
val hadoopConf = RapidsToolsConfUtil.newHadoopConf()
val (_, allEventLogs) = EventLogPathProcessor.processAllPaths(
None, None, List(eventLog), hadoopConf)
val pluginTypeChecker = new PluginTypeChecker()
assert(allEventLogs.size == 1)
val appResult = QualificationAppInfo.createApp(allEventLogs.head, hadoopConf,
pluginTypeChecker, reportSqlLevel = false, mlOpsEnabled = false, penalizeTransitions = true,
PlatformFactory.createInstance())
PlatformFactory.createInstance(platformName))
appResult match {
case Right(app) => app
case Left(_) => throw new AssertionError("Cannot create application")
PhotonPlanParserSuite.scala
@@ -16,6 +16,7 @@

package com.nvidia.spark.rapids.tool.planparser

import com.nvidia.spark.rapids.tool.PlatformNames
import com.nvidia.spark.rapids.tool.qualification.PluginTypeChecker


@@ -34,7 +35,7 @@ class PhotonPlanParserSuite extends BasePlanParserSuite {
test(s"$photonName is parsed as Spark $sparkName") {
val eventLog = s"$qualLogDir/nds_q88_photon_db_13_3.zstd"
val pluginTypeChecker = new PluginTypeChecker()
val app = createAppFromEventlog(eventLog)
val app = createAppFromEventlog(eventLog, platformName = PlatformNames.DATABRICKS_AWS)
assert(app.sqlPlans.nonEmpty)
val parsedPlans = app.sqlPlans.map { case (sqlID, plan) =>
SQLPlanParser.parseSQLPlan(app.appId, plan, sqlID, "", pluginTypeChecker, app)
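The (photonName, sparkName) pairs this suite iterates over map a Photon operator to the Spark CPU operator the parser resolves it to. A hypothetical sample of such pairs, with operator names chosen for illustration rather than copied from the suite's table:

    // Illustrative pairs only; the suite defines its own (photonName, sparkName) list.
    val photonOpToSparkOp = Seq(
      "PhotonProject" -> "Project",
      "PhotonGroupingAgg" -> "HashAggregate"
    )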
AnalysisSuite.scala
@@ -18,7 +18,7 @@ package com.nvidia.spark.rapids.tool.profiling

import java.io.File

import com.nvidia.spark.rapids.tool.ToolTestUtils
import com.nvidia.spark.rapids.tool.{PlatformNames, ToolTestUtils}
import com.nvidia.spark.rapids.tool.views.{ProfDataSourceView, RawMetricProfilerView}
import org.scalatest.FunSuite

@@ -139,7 +139,8 @@ class AnalysisSuite extends FunSuite {
s"${fileName}_${metric}_metrics_agg_expectation.csv"
}
testSqlMetricsAggregation(Array(s"${qualLogDir}/${fileName}.zstd"),
expectFile("sql"), expectFile("job"), expectFile("stage"))
expectFile("sql"), expectFile("job"), expectFile("stage"),
platformName = PlatformNames.DATABRICKS_AWS)
}

test("test stage-level diagnostic aggregation simple") {
@@ -163,8 +164,10 @@
}

private def testSqlMetricsAggregation(logs: Array[String], expectFileSQL: String,
expectFileJob: String, expectFileStage: String): Unit = {
val apps = ToolTestUtils.processProfileApps(logs, sparkSession)
expectFileJob: String, expectFileStage: String,
platformName: String = PlatformNames.DEFAULT): Unit = {
val args = Array("--platform", platformName) ++ logs
val apps = ToolTestUtils.processProfileApps(args, sparkSession)
assert(apps.size == logs.size)
val aggResults = RawMetricProfilerView.getAggMetrics(apps)
import sparkSession.implicits._
@@ -256,9 +259,12 @@
}

test("test photon scan metrics") {
val fileName = "nds_q88_photon_db_13_3"
val logs = Array(s"${qualLogDir}/${fileName}.zstd")
val apps = ToolTestUtils.processProfileApps(logs, sparkSession)
val args = Array(
"--platform",
PlatformNames.DATABRICKS_AWS,
s"$qualLogDir/nds_q88_photon_db_13_3.zstd"
)
val apps = ToolTestUtils.processProfileApps(args, sparkSession)
val dataSourceResults = ProfDataSourceView.getRawView(apps)
assert(dataSourceResults.exists(_.scan_time > 0))
}
ApplicationInfoSuite.scala
@@ -30,6 +30,7 @@ import org.scalatest.FunSuite
import org.apache.spark.internal.Logging
import org.apache.spark.resource.ResourceProfile
import org.apache.spark.sql.{SparkSession, TrampolineUtil}
import org.apache.spark.sql.rapids.tool.UnsupportedSparkRuntimeException
import org.apache.spark.sql.rapids.tool.profiling._
import org.apache.spark.sql.rapids.tool.util.{FSUtils, SparkRuntime}

Expand Down Expand Up @@ -1117,7 +1118,7 @@ class ApplicationInfoSuite extends FunSuite with Logging {
}

// scalastyle:off line.size.limit
val sparkRuntimeTestCases: Map[String, Seq[(String, SparkRuntime.Value)]] = Map(
val supportedSparkRuntimeTestCases: Map[String, Seq[(String, SparkRuntime.SparkRuntime)]] = Map(
// tests for standard Spark runtime
s"$qualLogDir/nds_q86_test" -> Seq(
(PlatformNames.DATABRICKS_AWS, SparkRuntime.SPARK), // Expected: SPARK on Databricks AWS
@@ -1132,21 +1133,40 @@
s"$qualLogDir/nds_q88_photon_db_13_3.zstd" -> Seq(
(PlatformNames.DATABRICKS_AWS, SparkRuntime.PHOTON), // Expected: PHOTON on Databricks AWS
(PlatformNames.DATABRICKS_AZURE, SparkRuntime.PHOTON), // Expected: PHOTON on Databricks Azure
(PlatformNames.ONPREM, SparkRuntime.SPARK), // Expected: Fallback to SPARK on Onprem
(PlatformNames.DATAPROC, SparkRuntime.SPARK) // Expected: Fallback to SPARK on Dataproc
)
)
// scalastyle:on line.size.limit

sparkRuntimeTestCases.foreach { case (logPath, platformRuntimeCases) =>
supportedSparkRuntimeTestCases.foreach { case (logPath, platformRuntimeCases) =>
val baseFileName = logPath.split("/").last
platformRuntimeCases.foreach { case (platform, expectedRuntime) =>
test(s"test eventlog $baseFileName on $platform has runtime: $expectedRuntime") {
test(s"test eventlog $baseFileName on $platform has supported runtime: $expectedRuntime") {
val args = Array("--platform", platform, logPath)
val apps = ToolTestUtils.processProfileApps(args, sparkSession)
assert(apps.size == 1)
assert(apps.head.getSparkRuntime == expectedRuntime)
}
}
}

// scalastyle:off line.size.limit
val unsupportedSparkRuntimeTestCases: Map[String, Seq[String]] = Map(
s"$qualLogDir/nds_q88_photon_db_13_3.zstd" -> Seq(
PlatformNames.ONPREM, // Expected: PHOTON runtime on Onprem is not supported
PlatformNames.DATAPROC // Expected: PHOTON runtime on Dataproc is not supported
)
)
// scalastyle:on line.size.limit

unsupportedSparkRuntimeTestCases.foreach { case (logPath, platformNames) =>
val baseFileName = logPath.split("/").last
platformNames.foreach { platform =>
test(s"test eventlog $baseFileName on $platform has unsupported runtime") {
val args = Array("--platform", platform, logPath)
intercept[UnsupportedSparkRuntimeException] {
ToolTestUtils.processProfileApps(args, sparkSession)
}
}
}
}
}
QualificationSuite.scala
@@ -136,12 +136,15 @@ class QualificationSuite extends BaseTestSuite {
}
}

private def runQualificationTest(eventLogs: Array[String], expectFileName: String = "",
private def runQualificationTest(eventLogs: Array[String],
expectFileName: String = "", platformName: String = PlatformNames.DEFAULT,
shouldReturnEmpty: Boolean = false, expectPerSqlFileName: Option[String] = None,
expectedStatus: Option[StatusReportCounts] = None): Unit = {
TrampolineUtil.withTempDir { outpath =>
val qualOutputPrefix = "rapids_4_spark_qualification_output"
val outputArgs = Array(
"--platform",
platformName,
"--output-directory",
outpath.getAbsolutePath())

@@ -1762,7 +1765,8 @@ class QualificationSuite extends BaseTestSuite {
val logFiles = Array(s"$logDir/nds_q88_photon_db_13_3.zstd") // photon event log
// Status counts: 1 SUCCESS, 0 FAILURE, 0 SKIPPED, 0 UNKNOWN
val expectedStatus = Some(StatusReportCounts(1, 0, 0, 0))
runQualificationTest(logFiles, expectedStatus = expectedStatus)
runQualificationTest(logFiles, platformName = PlatformNames.DATABRICKS_AWS,
expectedStatus = expectedStatus)
}

test("process multiple attempts of the same app ID and skip lower attempts") {
Event Log Processing feature file
@@ -16,6 +16,7 @@ Feature: Event Log Processing

@test_id_ELP_0001
Scenario Outline: Tool spark_rapids runs with different types of event logs
Given platform is "<platform>"
When spark-rapids tool is executed with "<event_logs>" eventlogs
Then stderr contains the following
"""
<expected_stderr>
And return code is "0"

Examples:
| event_logs | expected_stderr | processed_apps_count |
| invalid_path_eventlog | process.failure.count = 1;invalid_path_eventlog not found, skipping! | 0 |
| gpu_eventlog.zstd | process.skipped.count = 1;GpuEventLogException: Cannot parse event logs from GPU run: skipping this file | 0 |
| photon_eventlog.zstd | process.success.count = 1; | 1 |
| streaming_eventlog.zstd | process.skipped.count = 1;StreamingEventLogException: Encountered Spark Structured Streaming Job: skipping this file! | 0 |
| incorrect_app_status_eventlog.zstd | process.NA.count = 1;IncorrectAppStatusException: Application status is incorrect. Missing AppInfo | 0 |
| platform | event_logs | expected_stderr | processed_apps_count |
| onprem | invalid_path_eventlog | process.failure.count = 1;invalid_path_eventlog not found, skipping! | 0 |
| onprem | gpu_eventlog.zstd | process.skipped.count = 1;GpuEventLogException: Cannot parse event logs from GPU run: skipping this file | 0 |
| onprem | streaming_eventlog.zstd | process.skipped.count = 1;StreamingEventLogException: Encountered Spark Structured Streaming Job: skipping this file! | 0 |
| onprem | incorrect_app_status_eventlog.zstd | process.NA.count = 1;IncorrectAppStatusException: Application status is incorrect. Missing AppInfo | 0 |
| onprem | photon_eventlog.zstd | process.skipped.count = 1;UnsupportedSparkRuntimeException: Platform 'onprem' does not support the runtime 'PHOTON' | 0 |
| databricks-aws | photon_eventlog.zstd | process.success.count = 1; | 1 |

@test_id_ELP_0002
Scenario: Qualification tool JAR crashes
