From 6d2fb4eae6e5958be142454e9f1fa7753a12078a Mon Sep 17 00:00:00 2001 From: Partho Sarthi Date: Tue, 28 Nov 2023 16:48:59 -0800 Subject: [PATCH] Fix platform names as string constants and reduce redundancy in unit tests (#667) * Improve string usage of default platform and docs Signed-off-by: Partho Sarthi * Replace redundant tests with for loops Signed-off-by: Partho Sarthi * Fix typo in tests Signed-off-by: Partho Sarthi * Remove creation of new instance for default case Signed-off-by: Partho Sarthi --------- Signed-off-by: Partho Sarthi --- .../nvidia/spark/rapids/tool/Platform.scala | 24 +- .../rapids/tool/profiling/AutoTuner.scala | 13 +- .../rapids/tool/profiling/ProfileArgs.scala | 7 +- .../rapids/tool/profiling/Profiler.scala | 8 +- .../qualification/PluginTypeChecker.scala | 6 +- .../qualification/QualificationArgs.scala | 7 +- .../qualification/QualificationMain.scala | 8 +- .../tool/profiling/AutoTunerSuite.scala | 8 +- .../PluginTypeCheckerSuite.scala | 85 ++--- .../qualification/QualificationSuite.scala | 305 ++---------------- 10 files changed, 92 insertions(+), 379 deletions(-) diff --git a/core/src/main/scala/com/nvidia/spark/rapids/tool/Platform.scala b/core/src/main/scala/com/nvidia/spark/rapids/tool/Platform.scala index b38b7db14..67a36e8b9 100644 --- a/core/src/main/scala/com/nvidia/spark/rapids/tool/Platform.scala +++ b/core/src/main/scala/com/nvidia/spark/rapids/tool/Platform.scala @@ -15,6 +15,8 @@ */ package com.nvidia.spark.rapids.tool +import scala.annotation.tailrec + import org.apache.spark.internal.Logging /** @@ -33,6 +35,7 @@ object PlatformNames { val EMR_A10 = "emr-a10" val EMR_T4 = "emr-t4" val ONPREM = "onprem" + val DEFAULT: String = ONPREM /** * Return a list of all platform names. @@ -93,6 +96,8 @@ class Platform(platformName: String) { recommendationsToExclude.forall(excluded => !comment.contains(excluded)) } + def getName: String = platformName + def getOperatorScoreFile: String = { s"operatorsScore-$platformName.csv" } @@ -130,25 +135,26 @@ object PlatformFactory extends Logging { * @return An instance of the specified platform. * @throws IllegalArgumentException if the specified platform key is not supported. 
*/ - def createInstance(platformKey: String): Platform = { + @tailrec + def createInstance(platformKey: String = PlatformNames.DEFAULT): Platform = { platformKey match { - case PlatformNames.DATABRICKS_AWS => new DatabricksPlatform(PlatformNames.DATABRICKS_AWS) - case PlatformNames.DATABRICKS_AZURE => new DatabricksPlatform(PlatformNames.DATABRICKS_AZURE) + case PlatformNames.DATABRICKS_AWS | PlatformNames.DATABRICKS_AZURE => + new DatabricksPlatform(platformKey) case PlatformNames.DATAPROC | PlatformNames.DATAPROC_T4 => // if no GPU specified, then default to dataproc-t4 for backward compatibility new DataprocPlatform(PlatformNames.DATAPROC_T4) - case PlatformNames.DATAPROC_L4 => new DataprocPlatform(PlatformNames.DATAPROC_L4) - case PlatformNames.DATAPROC_SL_L4 => new DataprocPlatform(PlatformNames.DATAPROC_SL_L4) - case PlatformNames.DATAPROC_GKE_L4 => new DataprocPlatform(PlatformNames.DATAPROC_GKE_L4) - case PlatformNames.DATAPROC_GKE_T4 => new DataprocPlatform(PlatformNames.DATAPROC_GKE_T4) + case PlatformNames.DATAPROC_L4 | PlatformNames.DATAPROC_SL_L4 | + PlatformNames.DATAPROC_GKE_L4 | PlatformNames.DATAPROC_GKE_T4 => + new DataprocPlatform(platformKey) case PlatformNames.EMR | PlatformNames.EMR_T4 => // if no GPU specified, then default to emr-t4 for backward compatibility new EmrPlatform(PlatformNames.EMR_T4) case PlatformNames.EMR_A10 => new EmrPlatform(PlatformNames.EMR_A10) case PlatformNames.ONPREM => new OnPremPlatform case p if p.isEmpty => - logInfo(s"Platform is not specified. Using ${PlatformNames.ONPREM} as default.") - new OnPremPlatform + logInfo(s"Platform is not specified. Using ${PlatformNames.DEFAULT} " + + "as default.") + PlatformFactory.createInstance(PlatformNames.DEFAULT) case _ => throw new IllegalArgumentException(s"Unsupported platform: $platformKey. " + s"Options include ${PlatformNames.getAllNames.mkString(", ")}.") } diff --git a/core/src/main/scala/com/nvidia/spark/rapids/tool/profiling/AutoTuner.scala b/core/src/main/scala/com/nvidia/spark/rapids/tool/profiling/AutoTuner.scala index 5e3475604..e0f923bfb 100644 --- a/core/src/main/scala/com/nvidia/spark/rapids/tool/profiling/AutoTuner.scala +++ b/core/src/main/scala/com/nvidia/spark/rapids/tool/profiling/AutoTuner.scala @@ -330,7 +330,7 @@ class RecommendationEntry(val name: String, class AutoTuner( val clusterProps: ClusterProperties, val appInfoProvider: AppSummaryInfoBaseProvider, - val platform: String) extends Logging { + val platform: Platform) extends Logging { import AutoTuner._ @@ -344,7 +344,6 @@ class AutoTuner( private val limitedLogicRecommendations: mutable.HashSet[String] = mutable.HashSet[String]() // When enabled, the profiler recommendations should only include updated settings. 
private var filterByUpdatedPropertiesEnabled: Boolean = true - val selectedPlatform: Platform = PlatformFactory.createInstance(platform) private def isCalculationEnabled(prop: String) : Boolean = { !limitedLogicRecommendations.contains(prop) @@ -908,7 +907,7 @@ class AutoTuner( limitedSeq.foreach(_ => limitedLogicRecommendations.add(_)) } skipList.foreach(skipSeq => skipSeq.foreach(_ => skippedRecommendations.add(_))) - skippedRecommendations ++= selectedPlatform.recommendationsToExclude + skippedRecommendations ++= platform.recommendationsToExclude initRecommendations() calculateJobLevelRecommendations() if (processPropsAndCheck) { @@ -918,7 +917,7 @@ class AutoTuner( addDefaultComments() } // add all platform specific recommendations - selectedPlatform.recommendationsToInclude.foreach { + platform.recommendationsToInclude.foreach { case (property, value) => appendRecommendation(property, value) } } @@ -1024,7 +1023,7 @@ object AutoTuner extends Logging { private def handleException( ex: Exception, appInfo: AppSummaryInfoBaseProvider, - platform: String): AutoTuner = { + platform: Platform): AutoTuner = { logError("Exception: " + ex.getStackTrace.mkString("Array(", ", ", ")")) val tuning = new AutoTuner(new ClusterProperties(), appInfo, platform) val msg = ex match { @@ -1076,7 +1075,7 @@ object AutoTuner extends Logging { def buildAutoTunerFromProps( clusterProps: String, singleAppProvider: AppSummaryInfoBaseProvider, - platform: String = Profiler.DEFAULT_PLATFORM): AutoTuner = { + platform: Platform = PlatformFactory.createInstance()): AutoTuner = { try { val clusterPropsOpt = loadClusterPropertiesFromContent(clusterProps) new AutoTuner(clusterPropsOpt.getOrElse(new ClusterProperties()), singleAppProvider, platform) @@ -1089,7 +1088,7 @@ object AutoTuner extends Logging { def buildAutoTuner( filePath: String, singleAppProvider: AppSummaryInfoBaseProvider, - platform: String = Profiler.DEFAULT_PLATFORM): AutoTuner = { + platform: Platform = PlatformFactory.createInstance()): AutoTuner = { try { val clusterPropsOpt = loadClusterProps(filePath) new AutoTuner(clusterPropsOpt.getOrElse(new ClusterProperties()), singleAppProvider, platform) diff --git a/core/src/main/scala/com/nvidia/spark/rapids/tool/profiling/ProfileArgs.scala b/core/src/main/scala/com/nvidia/spark/rapids/tool/profiling/ProfileArgs.scala index c40d6bdca..3723e4aea 100644 --- a/core/src/main/scala/com/nvidia/spark/rapids/tool/profiling/ProfileArgs.scala +++ b/core/src/main/scala/com/nvidia/spark/rapids/tool/profiling/ProfileArgs.scala @@ -15,7 +15,7 @@ */ package com.nvidia.spark.rapids.tool.profiling -import com.nvidia.spark.rapids.tool.PlatformNames +import com.nvidia.spark.rapids.tool.{PlatformFactory, PlatformNames} import org.rogach.scallop.{ScallopConf, ScallopOption} import org.rogach.scallop.exceptions.ScallopException @@ -71,8 +71,9 @@ Usage: java -cp rapids-4-spark-tools_2.12-.jar:$SPARK_HOME/jars/* val platform: ScallopOption[String] = opt[String](required = false, descr = "Cluster platform where Spark GPU workloads were executed. Options include " + - s"${PlatformNames.getAllNames.mkString(", ")}. Default is ${PlatformNames.ONPREM}.", - default = Some(PlatformNames.ONPREM)) + s"${PlatformNames.getAllNames.mkString(", ")}. 
" + + s"Default is ${PlatformNames.DEFAULT}.", + default = Some(PlatformNames.DEFAULT)) val generateTimeline: ScallopOption[Boolean] = opt[Boolean](required = false, descr = "Write an SVG graph out for the full application timeline.") diff --git a/core/src/main/scala/com/nvidia/spark/rapids/tool/profiling/Profiler.scala b/core/src/main/scala/com/nvidia/spark/rapids/tool/profiling/Profiler.scala index 57108ec33..ae99e7343 100644 --- a/core/src/main/scala/com/nvidia/spark/rapids/tool/profiling/Profiler.scala +++ b/core/src/main/scala/com/nvidia/spark/rapids/tool/profiling/Profiler.scala @@ -23,7 +23,7 @@ import scala.collection.mutable.{ArrayBuffer, HashMap} import scala.util.control.NonFatal import com.nvidia.spark.rapids.ThreadFactoryBuilder -import com.nvidia.spark.rapids.tool.{EventLogInfo, EventLogPathProcessor, PlatformNames} +import com.nvidia.spark.rapids.tool.{EventLogInfo, EventLogPathProcessor, PlatformFactory} import org.apache.hadoop.conf.Configuration import org.apache.spark.internal.Logging @@ -511,9 +511,10 @@ class Profiler(hadoopConf: Configuration, appArgs: ProfileArgs, enablePB: Boolea if (useAutoTuner) { val workerInfoPath = appArgs.workerInfo.getOrElse(AutoTuner.DEFAULT_WORKER_INFO_PATH) - val platform = appArgs.platform.getOrElse(Profiler.DEFAULT_PLATFORM) + val platform = appArgs.platform() val autoTuner: AutoTuner = AutoTuner.buildAutoTuner(workerInfoPath, - new SingleAppSummaryInfoProvider(app), platform) + new SingleAppSummaryInfoProvider(app), + PlatformFactory.createInstance(platform)) // the autotuner allows skipping some properties // e.g. getRecommendedProperties(Some(Seq("spark.executor.instances"))) skips the // recommendation related to executor instances. @@ -548,7 +549,6 @@ object Profiler { val COMPARE_LOG_FILE_NAME_PREFIX = "rapids_4_spark_tools_compare" val COMBINED_LOG_FILE_NAME_PREFIX = "rapids_4_spark_tools_combined" val SUBDIR = "rapids_4_spark_profile" - val DEFAULT_PLATFORM: String = PlatformNames.ONPREM def getAutoTunerResultsAsString(props: Seq[RecommendedPropertyResult], comments: Seq[RecommendedCommentResult]): String = { diff --git a/core/src/main/scala/com/nvidia/spark/rapids/tool/qualification/PluginTypeChecker.scala b/core/src/main/scala/com/nvidia/spark/rapids/tool/qualification/PluginTypeChecker.scala index eead523d6..0ff5bd614 100644 --- a/core/src/main/scala/com/nvidia/spark/rapids/tool/qualification/PluginTypeChecker.scala +++ b/core/src/main/scala/com/nvidia/spark/rapids/tool/qualification/PluginTypeChecker.scala @@ -20,7 +20,7 @@ import scala.collection.mutable.{ArrayBuffer,HashMap} import scala.io.{BufferedSource, Source} import scala.util.control.NonFatal -import com.nvidia.spark.rapids.tool.PlatformFactory +import com.nvidia.spark.rapids.tool.{Platform, PlatformFactory} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, Path} @@ -33,7 +33,7 @@ import org.apache.spark.internal.Logging * by the plugin which lists the formats and types supported. * The class also supports a custom speedup factor file as input. 
*/ -class PluginTypeChecker(platform: String = "onprem", +class PluginTypeChecker(platform: Platform = PlatformFactory.createInstance(), speedupFactorFile: Option[String] = None) extends Logging { private val NS = "NS" @@ -92,7 +92,7 @@ class PluginTypeChecker(platform: String = "onprem", speedupFactorFile match { case None => logInfo(s"Reading operators scores with platform: $platform") - val file = PlatformFactory.createInstance(platform).getOperatorScoreFile + val file = platform.getOperatorScoreFile val source = Source.fromResource(file) readSupportedOperators(source, "score").map(x => (x._1, x._2.toDouble)) case Some(file) => diff --git a/core/src/main/scala/com/nvidia/spark/rapids/tool/qualification/QualificationArgs.scala b/core/src/main/scala/com/nvidia/spark/rapids/tool/qualification/QualificationArgs.scala index 509402c9d..24942b4b9 100644 --- a/core/src/main/scala/com/nvidia/spark/rapids/tool/qualification/QualificationArgs.scala +++ b/core/src/main/scala/com/nvidia/spark/rapids/tool/qualification/QualificationArgs.scala @@ -15,7 +15,7 @@ */ package com.nvidia.spark.rapids.tool.qualification -import com.nvidia.spark.rapids.tool.PlatformNames +import com.nvidia.spark.rapids.tool.{PlatformFactory, PlatformNames} import org.rogach.scallop.{ScallopConf, ScallopOption} import org.rogach.scallop.exceptions.ScallopException @@ -156,8 +156,9 @@ Usage: java -cp rapids-4-spark-tools_2.12-.jar:$SPARK_HOME/jars/* val platform: ScallopOption[String] = opt[String](required = false, descr = "Cluster platform where Spark CPU workloads were executed. Options include " + - s"${PlatformNames.getAllNames.mkString(", ")}. Default is ${PlatformNames.ONPREM}.", - default = Some(PlatformNames.ONPREM)) + s"${PlatformNames.getAllNames.mkString(", ")}. " + + s"Default is ${PlatformNames.DEFAULT}.", + default = Some(PlatformNames.DEFAULT)) val speedupFactorFile: ScallopOption[String] = opt[String](required = false, descr = "Custom speedup factor file used to get estimated GPU speedup that is specific " + diff --git a/core/src/main/scala/com/nvidia/spark/rapids/tool/qualification/QualificationMain.scala b/core/src/main/scala/com/nvidia/spark/rapids/tool/qualification/QualificationMain.scala index cb8a3c583..454b27695 100644 --- a/core/src/main/scala/com/nvidia/spark/rapids/tool/qualification/QualificationMain.scala +++ b/core/src/main/scala/com/nvidia/spark/rapids/tool/qualification/QualificationMain.scala @@ -16,7 +16,7 @@ package com.nvidia.spark.rapids.tool.qualification -import com.nvidia.spark.rapids.tool.EventLogPathProcessor +import com.nvidia.spark.rapids.tool.{EventLogPathProcessor, PlatformFactory} import org.apache.spark.internal.Logging import org.apache.spark.sql.rapids.tool.AppFilterImpl @@ -58,14 +58,16 @@ object QualificationMain extends Logging { val order = appArgs.order.getOrElse("desc") val uiEnabled = appArgs.htmlReport.getOrElse(false) val reportSqlLevel = appArgs.perSql.getOrElse(false) - val platform = appArgs.platform.getOrElse("onprem") + val platform = appArgs.platform() val mlOpsEnabled = appArgs.mlFunctions.getOrElse(false) val penalizeTransitions = appArgs.penalizeTransitions.getOrElse(true) val hadoopConf = RapidsToolsConfUtil.newHadoopConf val pluginTypeChecker = try { - new PluginTypeChecker(platform, appArgs.speedupFactorFile.toOption) + new PluginTypeChecker( + PlatformFactory.createInstance(platform), + appArgs.speedupFactorFile.toOption) } catch { case ie: IllegalStateException => logError("Error creating the plugin type checker!", ie) diff --git 
a/core/src/test/scala/com/nvidia/spark/rapids/tool/profiling/AutoTunerSuite.scala b/core/src/test/scala/com/nvidia/spark/rapids/tool/profiling/AutoTunerSuite.scala index 0428005f7..670e60958 100644 --- a/core/src/test/scala/com/nvidia/spark/rapids/tool/profiling/AutoTunerSuite.scala +++ b/core/src/test/scala/com/nvidia/spark/rapids/tool/profiling/AutoTunerSuite.scala @@ -21,6 +21,7 @@ import java.util import scala.collection.JavaConverters._ import scala.collection.mutable +import com.nvidia.spark.rapids.tool.{PlatformFactory, PlatformNames} import org.scalatest.{BeforeAndAfterEach, FunSuite} import org.scalatest.Matchers.convertToAnyShouldWrapper import org.yaml.snakeyaml.{DumperOptions, Yaml} @@ -1285,14 +1286,15 @@ class AutoTunerSuite extends FunSuite with BeforeAndAfterEach with Logging { test("test recommendations for databricks-aws platform argument") { val databricksWorkerInfo = buildWorkerInfoAsString() + val platform = PlatformFactory.createInstance(PlatformNames.DATABRICKS_AWS) val autoTuner = AutoTuner.buildAutoTunerFromProps(databricksWorkerInfo, - getGpuAppMockInfoProvider, "databricks-aws") + getGpuAppMockInfoProvider, platform) val (properties, comments) = autoTuner.getRecommendedProperties() // Assert recommendations are excluded in properties - assert(properties.map(_.property).forall(autoTuner.selectedPlatform.isValidRecommendation)) + assert(properties.map(_.property).forall(autoTuner.platform.isValidRecommendation)) // Assert recommendations are skipped in comments - assert(comments.map(_.comment).forall(autoTuner.selectedPlatform.isValidComment)) + assert(comments.map(_.comment).forall(autoTuner.platform.isValidComment)) } // When spark is running as a standalone, the memoryOverhead should not be listed as a diff --git a/core/src/test/scala/com/nvidia/spark/rapids/tool/qualification/PluginTypeCheckerSuite.scala b/core/src/test/scala/com/nvidia/spark/rapids/tool/qualification/PluginTypeCheckerSuite.scala index 61e8acf40..9a3640986 100644 --- a/core/src/test/scala/com/nvidia/spark/rapids/tool/qualification/PluginTypeCheckerSuite.scala +++ b/core/src/test/scala/com/nvidia/spark/rapids/tool/qualification/PluginTypeCheckerSuite.scala @@ -19,7 +19,7 @@ package com.nvidia.spark.rapids.tool.qualification import java.nio.charset.StandardCharsets import java.nio.file.{Files, Paths} -import com.nvidia.spark.rapids.tool.ToolTestUtils +import com.nvidia.spark.rapids.tool.{PlatformFactory, PlatformNames, ToolTestUtils} import com.nvidia.spark.rapids.tool.planparser.DataWritingCommandExecParser import org.scalatest.FunSuite @@ -153,68 +153,33 @@ class PluginTypeCheckerSuite extends FunSuite with Logging { assert(result(2) == "ORC") } - test("supported operator score from onprem") { - val checker = new PluginTypeChecker("onprem") - assert(checker.getSpeedupFactor("UnionExec") == 3.0) - assert(checker.getSpeedupFactor("Ceil") == 4) - } - - test("supported operator score from dataproc-t4") { - val checker = new PluginTypeChecker("dataproc-t4") - assert(checker.getSpeedupFactor("UnionExec") == 4.88) - assert(checker.getSpeedupFactor("Ceil") == 4.88) - } - - test("supported operator score from emr-t4") { - val checker = new PluginTypeChecker("emr-t4") - assert(checker.getSpeedupFactor("UnionExec") == 2.07) - assert(checker.getSpeedupFactor("Ceil") == 2.07) - } - - test("supported operator score from databricks-aws") { - val checker = new PluginTypeChecker("databricks-aws") - assert(checker.getSpeedupFactor("UnionExec") == 2.45) - assert(checker.getSpeedupFactor("Ceil") == 2.45) - } - - 
test("supported operator score from databricks-azure") { - val checker = new PluginTypeChecker("databricks-azure") - assert(checker.getSpeedupFactor("UnionExec") == 2.73) - assert(checker.getSpeedupFactor("Ceil") == 2.73) - } - - test("supported operator score from dataproc-serverless-l4") { - val checker = new PluginTypeChecker("dataproc-serverless-l4") - assert(checker.getSpeedupFactor("WindowExec") == 4.25) - assert(checker.getSpeedupFactor("Ceil") == 4.25) - } - - test("supported operator score from dataproc-l4") { - val checker = new PluginTypeChecker("dataproc-l4") - assert(checker.getSpeedupFactor("UnionExec") == 4.16) - assert(checker.getSpeedupFactor("Ceil") == 4.16) - } - - test("supported operator score from dataproc-gke-t4") { - val checker = new PluginTypeChecker("dataproc-gke-t4") - assert(checker.getSpeedupFactor("WindowExec") == 3.65) - assert(checker.getSpeedupFactor("Ceil") == 3.65) - } - - test("supported operator score from dataproc-gke-l4") { - val checker = new PluginTypeChecker("dataproc-gke-l4") - assert(checker.getSpeedupFactor("WindowExec") == 3.74) - assert(checker.getSpeedupFactor("Ceil") == 3.74) - } - - test("supported operator score from emr-a10") { - val checker = new PluginTypeChecker("emr-a10") - assert(checker.getSpeedupFactor("UnionExec") == 2.59) - assert(checker.getSpeedupFactor("Ceil") == 2.59) + val platformSpeedupEntries: Seq[(String, Map[String, Double])] = Seq( + (PlatformNames.ONPREM, Map("UnionExec" -> 3.0, "Ceil" -> 4.0)), + (PlatformNames.DATAPROC_T4, Map("UnionExec" -> 4.88, "Ceil" -> 4.88)), + (PlatformNames.EMR_T4, Map("UnionExec" -> 2.07, "Ceil" -> 2.07)), + (PlatformNames.DATABRICKS_AWS, Map("UnionExec" -> 2.45, "Ceil" -> 2.45)), + (PlatformNames.DATABRICKS_AZURE, Map("UnionExec" -> 2.73, "Ceil" -> 2.73)), + (PlatformNames.DATAPROC_SL_L4, Map("WindowExec" -> 4.25, "Ceil" -> 4.25)), + (PlatformNames.DATAPROC_L4, Map("UnionExec" -> 4.16, "Ceil" -> 4.16)), + (PlatformNames.DATAPROC_GKE_T4, Map("WindowExec" -> 3.65, "Ceil" -> 3.65)), + (PlatformNames.DATAPROC_GKE_L4, Map("WindowExec" -> 3.74, "Ceil" -> 3.74)), + (PlatformNames.EMR_A10, Map("UnionExec" -> 2.59, "Ceil" -> 2.59)) + ) + + platformSpeedupEntries.foreach { case (platformName, speedupMap) => + test(s"supported operator score from $platformName") { + val platform = PlatformFactory.createInstance(platformName) + val checker = new PluginTypeChecker(platform) + speedupMap.foreach { case (operator, speedup) => + assert(checker.getSpeedupFactor(operator) == speedup) + } + } } test("supported operator score from custom speedup factor file") { - val speedupFactorFile = ToolTestUtils.getTestResourcePath("operatorsScore-databricks-azure.csv") + // Using databricks azure speedup factor as custom file + val platform = PlatformFactory.createInstance(PlatformNames.DATABRICKS_AZURE) + val speedupFactorFile = ToolTestUtils.getTestResourcePath(platform.getOperatorScoreFile) val checker = new PluginTypeChecker(speedupFactorFile=Some(speedupFactorFile)) assert(checker.getSpeedupFactor("SortExec") == 13.11) assert(checker.getSpeedupFactor("FilterExec") == 3.14) diff --git a/core/src/test/scala/com/nvidia/spark/rapids/tool/qualification/QualificationSuite.scala b/core/src/test/scala/com/nvidia/spark/rapids/tool/qualification/QualificationSuite.scala index 056db3c6e..08c0fead2 100644 --- a/core/src/test/scala/com/nvidia/spark/rapids/tool/qualification/QualificationSuite.scala +++ b/core/src/test/scala/com/nvidia/spark/rapids/tool/qualification/QualificationSuite.scala @@ -23,7 +23,7 @@ import 
scala.collection.mutable.{ArrayBuffer, ListBuffer} import scala.io.Source import com.nvidia.spark.rapids.BaseTestSuite -import com.nvidia.spark.rapids.tool.{EventLogPathProcessor, StatusReportCounts, ToolTestUtils} +import com.nvidia.spark.rapids.tool.{EventLogPathProcessor, PlatformNames, StatusReportCounts, ToolTestUtils} import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.spark.ml.feature.PCA @@ -1337,292 +1337,29 @@ class QualificationSuite extends BaseTestSuite { spark.sql("SELECT id, hour(current_timestamp()), second(to_timestamp(timestamp)) FROM t1") } - // run the qualification tool for onprem - TrampolineUtil.withTempDir { outpath => - val appArgs = new QualificationArgs(Array( - "--output-directory", - outpath.getAbsolutePath, - "--platform", - "onprem", - eventLog)) - - val (exit, sumInfo) = - QualificationMain.mainInternal(appArgs) - assert(exit == 0) - - // the code above that runs the Spark query stops the Sparksession - // so create a new one to read in the csv file - createSparkSession() - - // validate that the SQL description in the csv file escapes commas properly - val outputResults = s"$outpath/rapids_4_spark_qualification_output/" + - s"rapids_4_spark_qualification_output.csv" - val outputActual = readExpectedFile(new File(outputResults)) - assert(outputActual.collect().size == 1) - } - - // run the qualification tool for emr. It should default to emr-t4. - TrampolineUtil.withTempDir { outpath => - val appArgs = new QualificationArgs(Array( - "--output-directory", - outpath.getAbsolutePath, - "--platform", - "emr", - eventLog)) - - val (exit, sumInfo) = - QualificationMain.mainInternal(appArgs) - assert(exit == 0) - - // the code above that runs the Spark query stops the Sparksession - // so create a new one to read in the csv file - createSparkSession() - - // validate that the SQL description in the csv file escapes commas properly - val outputResults = s"$outpath/rapids_4_spark_qualification_output/" + - s"rapids_4_spark_qualification_output.csv" - val outputActual = readExpectedFile(new File(outputResults)) - assert(outputActual.collect().size == 1) - } - - // run the qualification tool for emr-t4 - TrampolineUtil.withTempDir { outpath => - val appArgs = new QualificationArgs(Array( - "--output-directory", - outpath.getAbsolutePath, - "--platform", - "emr-t4", - eventLog)) - - val (exit, sumInfo) = - QualificationMain.mainInternal(appArgs) - assert(exit == 0) - - // the code above that runs the Spark query stops the Sparksession - // so create a new one to read in the csv file - createSparkSession() - - // validate that the SQL description in the csv file escapes commas properly - val outputResults = s"$outpath/rapids_4_spark_qualification_output/" + - s"rapids_4_spark_qualification_output.csv" - val outputActual = readExpectedFile(new File(outputResults)) - assert(outputActual.collect().size == 1) - } - - // run the qualification tool for emr-a10 - TrampolineUtil.withTempDir { outpath => - val appArgs = new QualificationArgs(Array( - "--output-directory", - outpath.getAbsolutePath, - "--platform", - "emr-a10", - eventLog)) - - val (exit, sumInfo) = - QualificationMain.mainInternal(appArgs) - assert(exit == 0) - - // the code above that runs the Spark query stops the Sparksession - // so create a new one to read in the csv file - createSparkSession() - - // validate that the SQL description in the csv file escapes commas properly - val outputResults = s"$outpath/rapids_4_spark_qualification_output/" + - s"rapids_4_spark_qualification_output.csv" - val 
outputActual = readExpectedFile(new File(outputResults)) - assert(outputActual.collect().size == 1) - } - - // run the qualification tool for dataproc. It should default to dataproc-t4 - TrampolineUtil.withTempDir { outpath => - val appArgs = new QualificationArgs(Array( - "--output-directory", - outpath.getAbsolutePath, - "--platform", - "dataproc", - eventLog)) - - val (exit, sumInfo) = - QualificationMain.mainInternal(appArgs) - assert(exit == 0) - - // the code above that runs the Spark query stops the Sparksession - // so create a new one to read in the csv file - createSparkSession() - - // validate that the SQL description in the csv file escapes commas properly - val outputResults = s"$outpath/rapids_4_spark_qualification_output/" + - s"rapids_4_spark_qualification_output.csv" - val outputActual = readExpectedFile(new File(outputResults)) - assert(outputActual.collect().size == 1) - } - - // run the qualification tool for dataproc-t4 - TrampolineUtil.withTempDir { outpath => - val appArgs = new QualificationArgs(Array( - "--output-directory", - outpath.getAbsolutePath, - "--platform", - "dataproc-t4", - eventLog)) - - val (exit, sumInfo) = - QualificationMain.mainInternal(appArgs) - assert(exit == 0) - - // the code above that runs the Spark query stops the Sparksession - // so create a new one to read in the csv file - createSparkSession() - - // validate that the SQL description in the csv file escapes commas properly - val outputResults = s"$outpath/rapids_4_spark_qualification_output/" + - s"rapids_4_spark_qualification_output.csv" - val outputActual = readExpectedFile(new File(outputResults)) - assert(outputActual.collect().size == 1) - } - - // run the qualification tool for dataproc-l4 - TrampolineUtil.withTempDir { outpath => - val appArgs = new QualificationArgs(Array( - "--output-directory", - outpath.getAbsolutePath, - "--platform", - "dataproc-l4", - eventLog)) - - val (exit, sumInfo) = - QualificationMain.mainInternal(appArgs) - assert(exit == 0) - - // the code above that runs the Spark query stops the Sparksession - // so create a new one to read in the csv file - createSparkSession() - - // validate that the SQL description in the csv file escapes commas properly - val outputResults = s"$outpath/rapids_4_spark_qualification_output/" + - s"rapids_4_spark_qualification_output.csv" - val outputActual = readExpectedFile(new File(outputResults)) - assert(outputActual.collect().size == 1) - } - - // run the qualification tool for dataproc-serverless-l4 - TrampolineUtil.withTempDir { outpath => - val appArgs = new QualificationArgs(Array( - "--output-directory", - outpath.getAbsolutePath, - "--platform", - "dataproc-serverless-l4", - eventLog)) - - val (exit, sumInfo) = - QualificationMain.mainInternal(appArgs) - assert(exit == 0) - - // the code above that runs the Spark query stops the Sparksession - // so create a new one to read in the csv file - createSparkSession() - - // validate that the SQL description in the csv file escapes commas properly - val outputResults = s"$outpath/rapids_4_spark_qualification_output/" + - s"rapids_4_spark_qualification_output.csv" - val outputActual = readExpectedFile(new File(outputResults)) - assert(outputActual.collect().size == 1) - } - - // run the qualification tool for dataproc-gke-t4 - TrampolineUtil.withTempDir { outpath => - val appArgs = new QualificationArgs(Array( - "--output-directory", - outpath.getAbsolutePath, - "--platform", - "dataproc-gke-t4", - eventLog)) - - val (exit, _) = - QualificationMain.mainInternal(appArgs) 
- assert(exit == 0) - - // the code above that runs the Spark query stops the Sparksession - // so create a new one to read in the csv file - createSparkSession() - - // validate that the SQL description in the csv file escapes commas properly - val outputResults = s"$outpath/rapids_4_spark_qualification_output/" + - s"rapids_4_spark_qualification_output.csv" - val outputActual = readExpectedFile(new File(outputResults)) - assert(outputActual.collect().size == 1) - } - - // run the qualification tool for dataproc-gke-l4 - TrampolineUtil.withTempDir { outpath => - val appArgs = new QualificationArgs(Array( - "--output-directory", - outpath.getAbsolutePath, - "--platform", - "dataproc-gke-l4", - eventLog)) - - val (exit, _) = - QualificationMain.mainInternal(appArgs) - assert(exit == 0) - - // the code above that runs the Spark query stops the Sparksession - // so create a new one to read in the csv file - createSparkSession() - - // validate that the SQL description in the csv file escapes commas properly - val outputResults = s"$outpath/rapids_4_spark_qualification_output/" + - s"rapids_4_spark_qualification_output.csv" - val outputActual = readExpectedFile(new File(outputResults)) - assert(outputActual.collect().size == 1) - } - - // run the qualification tool for databricks-aws - TrampolineUtil.withTempDir { outpath => - val appArgs = new QualificationArgs(Array( - "--output-directory", - outpath.getAbsolutePath, - "--platform", - "databricks-aws", - eventLog)) - - val (exit, sumInfo) = - QualificationMain.mainInternal(appArgs) - assert(exit == 0) - - // the code above that runs the Spark query stops the Sparksession - // so create a new one to read in the csv file - createSparkSession() - - // validate that the SQL description in the csv file escapes commas properly - val outputResults = s"$outpath/rapids_4_spark_qualification_output/" + - s"rapids_4_spark_qualification_output.csv" - val outputActual = readExpectedFile(new File(outputResults)) - assert(outputActual.collect().size == 1) - } - - // run the qualification tool for databricks-azure - TrampolineUtil.withTempDir { outpath => - val appArgs = new QualificationArgs(Array( - "--output-directory", - outpath.getAbsolutePath, - "--platform", - "databricks-azure", - eventLog)) + PlatformNames.getAllNames.foreach { platform => + // run the qualification tool for each platform + TrampolineUtil.withTempDir { outPath => + val appArgs = new QualificationArgs(Array( + "--output-directory", + outPath.getAbsolutePath, + "--platform", + platform, + eventLog)) - val (exit, sumInfo) = - QualificationMain.mainInternal(appArgs) - assert(exit == 0) + val (exit, _) = QualificationMain.mainInternal(appArgs) + assert(exit == 0) - // the code above that runs the Spark query stops the Sparksession - // so create a new one to read in the csv file - createSparkSession() + // the code above that runs the Spark query stops the Spark Session, + // so create a new one to read in the csv file + createSparkSession() - // validate that the SQL description in the csv file escapes commas properly - val outputResults = s"$outpath/rapids_4_spark_qualification_output/" + - s"rapids_4_spark_qualification_output.csv" - val outputActual = readExpectedFile(new File(outputResults)) - assert(outputActual.collect().size == 1) + // validate that the SQL description in the csv file escapes commas properly + val outputResults = s"$outPath/rapids_4_spark_qualification_output/" + + s"rapids_4_spark_qualification_output.csv" + val outputActual = readExpectedFile(new 
File(outputResults)) + assert(outputActual.collect().length == 1) + } } } }
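
Usage sketch (not part of the patch): a minimal illustration of how the refactored `PlatformFactory` and the `Platform`-typed APIs above are expected to be called after this change. All class, object, and constant names are taken from the diff; the snippet assumes it runs inside the tools project where these classes are on the classpath.

    import com.nvidia.spark.rapids.tool.{PlatformFactory, PlatformNames}
    import com.nvidia.spark.rapids.tool.qualification.PluginTypeChecker

    // No argument resolves to PlatformNames.DEFAULT (currently "onprem").
    val defaultPlatform = PlatformFactory.createInstance()
    assert(defaultPlatform.getName == PlatformNames.ONPREM)

    // Generic names still map to a concrete GPU variant for backward compatibility.
    val dataproc = PlatformFactory.createInstance(PlatformNames.DATAPROC)
    assert(dataproc.getOperatorScoreFile == "operatorsScore-dataproc-t4.csv")

    // Callers now pass the Platform instance itself instead of a raw platform string.
    val checker = new PluginTypeChecker(PlatformFactory.createInstance(PlatformNames.EMR_T4))
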