From 99efae25be6ade1ad00fb8850bd94e3154a89199 Mon Sep 17 00:00:00 2001
From: Partho Sarthi
Date: Fri, 17 Nov 2023 14:52:34 -0800
Subject: [PATCH] Replace redundant tests with for loops

Signed-off-by: Partho Sarthi
---
 .../PluginTypeCheckerSuite.scala              |  89 ++---
 .../qualification/QualificationSuite.scala    | 305 ++----------------
 2 files changed, 42 insertions(+), 352 deletions(-)

diff --git a/core/src/test/scala/com/nvidia/spark/rapids/tool/qualification/PluginTypeCheckerSuite.scala b/core/src/test/scala/com/nvidia/spark/rapids/tool/qualification/PluginTypeCheckerSuite.scala
index 17755f64f..7ee9bc9fb 100644
--- a/core/src/test/scala/com/nvidia/spark/rapids/tool/qualification/PluginTypeCheckerSuite.scala
+++ b/core/src/test/scala/com/nvidia/spark/rapids/tool/qualification/PluginTypeCheckerSuite.scala
@@ -153,74 +153,27 @@ class PluginTypeCheckerSuite extends FunSuite with Logging {
     assert(result(2) == "ORC")
   }
 
-  test("supported operator score from onprem") {
-    val platform = PlatformFactory.createInstance(PlatformNames.ONPREM)
-    val checker = new PluginTypeChecker(platform)
-    assert(checker.getSpeedupFactor("UnionExec") == 3.0)
-    assert(checker.getSpeedupFactor("Ceil") == 4)
-  }
-
-  test("supported operator score from dataproc-t4") {
-    val platform = PlatformFactory.createInstance(PlatformNames.DATAPROC_T4)
-    val checker = new PluginTypeChecker(platform)
-    assert(checker.getSpeedupFactor("UnionExec") == 4.88)
-    assert(checker.getSpeedupFactor("Ceil") == 4.88)
-  }
-
-  test("supported operator score from emr-t4") {
-    val platform = PlatformFactory.createInstance(PlatformNames.EMR_T4)
-    val checker = new PluginTypeChecker(platform)
-    assert(checker.getSpeedupFactor("UnionExec") == 2.07)
-    assert(checker.getSpeedupFactor("Ceil") == 2.07)
-  }
-
-  test("supported operator score from databricks-aws") {
-    val platform = PlatformFactory.createInstance(PlatformNames.DATABRICKS_AWS)
-    val checker = new PluginTypeChecker(platform)
-    assert(checker.getSpeedupFactor("UnionExec") == 2.45)
-    assert(checker.getSpeedupFactor("Ceil") == 2.45)
-  }
-
-  test("supported operator score from databricks-azure") {
-    val platform = PlatformFactory.createInstance(PlatformNames.DATABRICKS_AZURE)
-    val checker = new PluginTypeChecker(platform)
-    assert(checker.getSpeedupFactor("UnionExec") == 2.73)
-    assert(checker.getSpeedupFactor("Ceil") == 2.73)
-  }
-
-  test("supported operator score from dataproc-serverless-l4") {
-    val platform = PlatformFactory.createInstance(PlatformNames.DATAPROC_SL_L4)
-    val checker = new PluginTypeChecker(platform)
-    assert(checker.getSpeedupFactor("WindowExec") == 4.25)
-    assert(checker.getSpeedupFactor("Ceil") == 4.25)
-  }
-
-  test("supported operator score from dataproc-l4") {
-    val platform = PlatformFactory.createInstance(PlatformNames.DATAPROC_L4)
-    val checker = new PluginTypeChecker(platform)
-    assert(checker.getSpeedupFactor("UnionExec") == 4.16)
-    assert(checker.getSpeedupFactor("Ceil") == 4.16)
-  }
-
-  test("supported operator score from dataproc-gke-t4") {
-    val platform = PlatformFactory.createInstance(PlatformNames.DATAPROC_GKE_T4)
-    val checker = new PluginTypeChecker(platform)
-    assert(checker.getSpeedupFactor("WindowExec") == 3.65)
-    assert(checker.getSpeedupFactor("Ceil") == 3.65)
-  }
-
-  test("supported operator score from dataproc-gke-l4") {
-    val platform = PlatformFactory.createInstance(PlatformNames.DATAPROC_GKE_L4)
-    val checker = new PluginTypeChecker(platform)
-    assert(checker.getSpeedupFactor("WindowExec") == 3.74)
-    assert(checker.getSpeedupFactor("Ceil") == 3.74)
-  }
-
-  test("supported operator score from emr-a10") {
-    val platform = PlatformFactory.createInstance(PlatformNames.EMR_A10)
-    val checker = new PluginTypeChecker(platform)
-    assert(checker.getSpeedupFactor("UnionExec") == 2.59)
-    assert(checker.getSpeedupFactor("Ceil") == 2.59)
+  val platformSpeedupEntries: Seq[(String, Map[String, Double])] = Seq(
+    (PlatformNames.ONPREM, Map("UnionExec" -> 3.0, "Ceil" -> 4.0)),
+    (PlatformNames.DATAPROC_T4, Map("UnionExec" -> 4.88, "Ceil" -> 4.88)),
+    (PlatformNames.EMR_T4, Map("UnionExec" -> 2.07, "Ceil" -> 2.07)),
+    (PlatformNames.DATABRICKS_AWS, Map("UnionExec" -> 2.45, "Ceil" -> 2.45)),
+    (PlatformNames.DATABRICKS_AZURE, Map("UnionExec" -> 2.73, "Ceil" -> 2.73)),
+    (PlatformNames.DATAPROC_SL_L4, Map("WindowExec" -> 4.25, "Ceil" -> 4.25)),
+    (PlatformNames.DATAPROC_L4, Map("UnionExec" -> 4.16, "Ceil" -> 4.16)),
+    (PlatformNames.DATAPROC_GKE_T4, Map("WindowExec" -> 3.65, "Ceil" -> 3.65)),
+    (PlatformNames.DATAPROC_GKE_L4, Map("WindowExec" -> 3.74, "Ceil" -> 3.74)),
+    (PlatformNames.EMR_A10, Map("UnionExec" -> 2.59, "Ceil" -> 2.59)),
+  )
+
+  platformSpeedupEntries.foreach { case (platformName, speedupMap) =>
+    test(s"supported operator score from $platformName") {
+      val platform = PlatformFactory.createInstance(platformName)
+      val checker = new PluginTypeChecker(platform)
+      speedupMap.foreach { case (operator, speedup) =>
+        assert(checker.getSpeedupFactor(operator) == speedup)
+      }
+    }
   }
 
   test("supported operator score from custom speedup factor file") {
diff --git a/core/src/test/scala/com/nvidia/spark/rapids/tool/qualification/QualificationSuite.scala b/core/src/test/scala/com/nvidia/spark/rapids/tool/qualification/QualificationSuite.scala
index 056db3c6e..08c0fead2 100644
--- a/core/src/test/scala/com/nvidia/spark/rapids/tool/qualification/QualificationSuite.scala
+++ b/core/src/test/scala/com/nvidia/spark/rapids/tool/qualification/QualificationSuite.scala
@@ -23,7 +23,7 @@ import scala.collection.mutable.{ArrayBuffer, ListBuffer}
 import scala.io.Source
 
 import com.nvidia.spark.rapids.BaseTestSuite
-import com.nvidia.spark.rapids.tool.{EventLogPathProcessor, StatusReportCounts, ToolTestUtils}
+import com.nvidia.spark.rapids.tool.{EventLogPathProcessor, PlatformNames, StatusReportCounts, ToolTestUtils}
 import org.apache.hadoop.fs.{FileSystem, Path}
 
 import org.apache.spark.ml.feature.PCA
@@ -1337,292 +1337,29 @@ class QualificationSuite extends BaseTestSuite {
       spark.sql("SELECT id, hour(current_timestamp()), second(to_timestamp(timestamp)) FROM t1")
     }
 
-    // run the qualification tool for onprem
-    TrampolineUtil.withTempDir { outpath =>
-      val appArgs = new QualificationArgs(Array(
-        "--output-directory",
-        outpath.getAbsolutePath,
-        "--platform",
-        "onprem",
-        eventLog))
-
-      val (exit, sumInfo) =
-        QualificationMain.mainInternal(appArgs)
-      assert(exit == 0)
-
-      // the code above that runs the Spark query stops the Sparksession
-      // so create a new one to read in the csv file
-      createSparkSession()
-
-      // validate that the SQL description in the csv file escapes commas properly
-      val outputResults = s"$outpath/rapids_4_spark_qualification_output/" +
-        s"rapids_4_spark_qualification_output.csv"
-      val outputActual = readExpectedFile(new File(outputResults))
-      assert(outputActual.collect().size == 1)
-    }
-
-    // run the qualification tool for emr. It should default to emr-t4.
-    TrampolineUtil.withTempDir { outpath =>
-      val appArgs = new QualificationArgs(Array(
-        "--output-directory",
-        outpath.getAbsolutePath,
-        "--platform",
-        "emr",
-        eventLog))
-
-      val (exit, sumInfo) =
-        QualificationMain.mainInternal(appArgs)
-      assert(exit == 0)
-
-      // the code above that runs the Spark query stops the Sparksession
-      // so create a new one to read in the csv file
-      createSparkSession()
-
-      // validate that the SQL description in the csv file escapes commas properly
-      val outputResults = s"$outpath/rapids_4_spark_qualification_output/" +
-        s"rapids_4_spark_qualification_output.csv"
-      val outputActual = readExpectedFile(new File(outputResults))
-      assert(outputActual.collect().size == 1)
-    }
-
-    // run the qualification tool for emr-t4
-    TrampolineUtil.withTempDir { outpath =>
-      val appArgs = new QualificationArgs(Array(
-        "--output-directory",
-        outpath.getAbsolutePath,
-        "--platform",
-        "emr-t4",
-        eventLog))
-
-      val (exit, sumInfo) =
-        QualificationMain.mainInternal(appArgs)
-      assert(exit == 0)
-
-      // the code above that runs the Spark query stops the Sparksession
-      // so create a new one to read in the csv file
-      createSparkSession()
-
-      // validate that the SQL description in the csv file escapes commas properly
-      val outputResults = s"$outpath/rapids_4_spark_qualification_output/" +
-        s"rapids_4_spark_qualification_output.csv"
-      val outputActual = readExpectedFile(new File(outputResults))
-      assert(outputActual.collect().size == 1)
-    }
-
-    // run the qualification tool for emr-a10
-    TrampolineUtil.withTempDir { outpath =>
-      val appArgs = new QualificationArgs(Array(
-        "--output-directory",
-        outpath.getAbsolutePath,
-        "--platform",
-        "emr-a10",
-        eventLog))
-
-      val (exit, sumInfo) =
-        QualificationMain.mainInternal(appArgs)
-      assert(exit == 0)
-
-      // the code above that runs the Spark query stops the Sparksession
-      // so create a new one to read in the csv file
-      createSparkSession()
-
-      // validate that the SQL description in the csv file escapes commas properly
-      val outputResults = s"$outpath/rapids_4_spark_qualification_output/" +
-        s"rapids_4_spark_qualification_output.csv"
-      val outputActual = readExpectedFile(new File(outputResults))
-      assert(outputActual.collect().size == 1)
-    }
-
-    // run the qualification tool for dataproc. It should default to dataproc-t4
-    TrampolineUtil.withTempDir { outpath =>
-      val appArgs = new QualificationArgs(Array(
-        "--output-directory",
-        outpath.getAbsolutePath,
-        "--platform",
-        "dataproc",
-        eventLog))
-
-      val (exit, sumInfo) =
-        QualificationMain.mainInternal(appArgs)
-      assert(exit == 0)
-
-      // the code above that runs the Spark query stops the Sparksession
-      // so create a new one to read in the csv file
-      createSparkSession()
-
-      // validate that the SQL description in the csv file escapes commas properly
-      val outputResults = s"$outpath/rapids_4_spark_qualification_output/" +
-        s"rapids_4_spark_qualification_output.csv"
-      val outputActual = readExpectedFile(new File(outputResults))
-      assert(outputActual.collect().size == 1)
-    }
-
-    // run the qualification tool for dataproc-t4
-    TrampolineUtil.withTempDir { outpath =>
-      val appArgs = new QualificationArgs(Array(
-        "--output-directory",
-        outpath.getAbsolutePath,
-        "--platform",
-        "dataproc-t4",
-        eventLog))
-
-      val (exit, sumInfo) =
-        QualificationMain.mainInternal(appArgs)
-      assert(exit == 0)
-
-      // the code above that runs the Spark query stops the Sparksession
-      // so create a new one to read in the csv file
-      createSparkSession()
-
-      // validate that the SQL description in the csv file escapes commas properly
-      val outputResults = s"$outpath/rapids_4_spark_qualification_output/" +
-        s"rapids_4_spark_qualification_output.csv"
-      val outputActual = readExpectedFile(new File(outputResults))
-      assert(outputActual.collect().size == 1)
-    }
-
-    // run the qualification tool for dataproc-l4
-    TrampolineUtil.withTempDir { outpath =>
-      val appArgs = new QualificationArgs(Array(
-        "--output-directory",
-        outpath.getAbsolutePath,
-        "--platform",
-        "dataproc-l4",
-        eventLog))
-
-      val (exit, sumInfo) =
-        QualificationMain.mainInternal(appArgs)
-      assert(exit == 0)
-
-      // the code above that runs the Spark query stops the Sparksession
-      // so create a new one to read in the csv file
-      createSparkSession()
-
-      // validate that the SQL description in the csv file escapes commas properly
-      val outputResults = s"$outpath/rapids_4_spark_qualification_output/" +
-        s"rapids_4_spark_qualification_output.csv"
-      val outputActual = readExpectedFile(new File(outputResults))
-      assert(outputActual.collect().size == 1)
-    }
-
-    // run the qualification tool for dataproc-serverless-l4
-    TrampolineUtil.withTempDir { outpath =>
-      val appArgs = new QualificationArgs(Array(
-        "--output-directory",
-        outpath.getAbsolutePath,
-        "--platform",
-        "dataproc-serverless-l4",
-        eventLog))
-
-      val (exit, sumInfo) =
-        QualificationMain.mainInternal(appArgs)
-      assert(exit == 0)
-
-      // the code above that runs the Spark query stops the Sparksession
-      // so create a new one to read in the csv file
-      createSparkSession()
-
-      // validate that the SQL description in the csv file escapes commas properly
-      val outputResults = s"$outpath/rapids_4_spark_qualification_output/" +
-        s"rapids_4_spark_qualification_output.csv"
-      val outputActual = readExpectedFile(new File(outputResults))
-      assert(outputActual.collect().size == 1)
-    }
-
-    // run the qualification tool for dataproc-gke-t4
-    TrampolineUtil.withTempDir { outpath =>
-      val appArgs = new QualificationArgs(Array(
-        "--output-directory",
-        outpath.getAbsolutePath,
-        "--platform",
-        "dataproc-gke-t4",
-        eventLog))
-
-      val (exit, _) =
-        QualificationMain.mainInternal(appArgs)
-      assert(exit == 0)
-
-      // the code above that runs the Spark query stops the Sparksession
-      // so create a new one to read in the csv file
-      createSparkSession()
-
-      // validate that the SQL description in the csv file escapes commas properly
-      val outputResults = s"$outpath/rapids_4_spark_qualification_output/" +
-        s"rapids_4_spark_qualification_output.csv"
-      val outputActual = readExpectedFile(new File(outputResults))
-      assert(outputActual.collect().size == 1)
-    }
-
-    // run the qualification tool for dataproc-gke-l4
-    TrampolineUtil.withTempDir { outpath =>
-      val appArgs = new QualificationArgs(Array(
-        "--output-directory",
-        outpath.getAbsolutePath,
-        "--platform",
-        "dataproc-gke-l4",
-        eventLog))
-
-      val (exit, _) =
-        QualificationMain.mainInternal(appArgs)
-      assert(exit == 0)
-
-      // the code above that runs the Spark query stops the Sparksession
-      // so create a new one to read in the csv file
-      createSparkSession()
-
-      // validate that the SQL description in the csv file escapes commas properly
-      val outputResults = s"$outpath/rapids_4_spark_qualification_output/" +
-        s"rapids_4_spark_qualification_output.csv"
-      val outputActual = readExpectedFile(new File(outputResults))
-      assert(outputActual.collect().size == 1)
-    }
-
-    // run the qualification tool for databricks-aws
-    TrampolineUtil.withTempDir { outpath =>
-      val appArgs = new QualificationArgs(Array(
-        "--output-directory",
-        outpath.getAbsolutePath,
-        "--platform",
-        "databricks-aws",
-        eventLog))
-
-      val (exit, sumInfo) =
-        QualificationMain.mainInternal(appArgs)
-      assert(exit == 0)
-
-      // the code above that runs the Spark query stops the Sparksession
-      // so create a new one to read in the csv file
-      createSparkSession()
-
-      // validate that the SQL description in the csv file escapes commas properly
-      val outputResults = s"$outpath/rapids_4_spark_qualification_output/" +
-        s"rapids_4_spark_qualification_output.csv"
-      val outputActual = readExpectedFile(new File(outputResults))
-      assert(outputActual.collect().size == 1)
-    }
-
-    // run the qualification tool for databricks-azure
-    TrampolineUtil.withTempDir { outpath =>
-      val appArgs = new QualificationArgs(Array(
-        "--output-directory",
-        outpath.getAbsolutePath,
-        "--platform",
-        "databricks-azure",
-        eventLog))
+    PlatformNames.getAllNames.foreach { platform =>
+      // run the qualification tool for each platform
+      TrampolineUtil.withTempDir { outPath =>
+        val appArgs = new QualificationArgs(Array(
+          "--output-directory",
+          outPath.getAbsolutePath,
+          "--platform",
+          platform,
+          eventLog))
 
-      val (exit, sumInfo) =
-        QualificationMain.mainInternal(appArgs)
-      assert(exit == 0)
+        val (exit, _) = QualificationMain.mainInternal(appArgs)
+        assert(exit == 0)
 
-      // the code above that runs the Spark query stops the Sparksession
-      // so create a new one to read in the csv file
-      createSparkSession()
+        // the code above that runs the Spark query stops the Spark Session,
+        // so create a new one to read in the csv file
+        createSparkSession()
 
-      // validate that the SQL description in the csv file escapes commas properly
-      val outputResults = s"$outpath/rapids_4_spark_qualification_output/" +
-        s"rapids_4_spark_qualification_output.csv"
-      val outputActual = readExpectedFile(new File(outputResults))
-      assert(outputActual.collect().size == 1)
+        // validate that the SQL description in the csv file escapes commas properly
+        val outputResults = s"$outPath/rapids_4_spark_qualification_output/" +
+          s"rapids_4_spark_qualification_output.csv"
+        val outputActual = readExpectedFile(new File(outputResults))
+        assert(outputActual.collect().length == 1)
+      }
     }
   }
 }
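
Why the loop-generated tests in both hunks work: ScalaTest's FunSuite registers a test each time test(...) is called while the suite is being constructed, so calling test(...) inside a foreach over a data table produces one independently named and independently reported test per entry, rather than one monolithic test that stops at the first failing platform. Below is a minimal, self-contained sketch of that pattern; the suite name, data table, and score function are illustrative stand-ins, not part of this patch:

    import org.scalatest.FunSuite

    class LoopGeneratedTestsSuite extends FunSuite {
      // Hypothetical expected scores, standing in for the per-platform maps above.
      private val expectedScores: Seq[(String, Double)] = Seq(
        ("UnionExec", 3.0),
        ("Ceil", 4.0))

      // Hypothetical lookup, standing in for PluginTypeChecker.getSpeedupFactor.
      private def score(operator: String): Double =
        Map("UnionExec" -> 3.0, "Ceil" -> 4.0)(operator)

      // Each iteration registers a separate test at suite-construction time,
      // so a failure in one entry does not hide the results of the others.
      expectedScores.foreach { case (operator, expected) =>
        test(s"score for $operator") {
          assert(score(operator) == expected)
        }
      }
    }

One caveat of the pattern: the data table is evaluated during suite construction, so an exception thrown while building it fails suite initialization as a whole rather than a single named test.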