From ad9d0a547bd2dec759b405feb5ceb738b7655477 Mon Sep 17 00:00:00 2001
From: "Ahmed Hussein (amahussein)"
Date: Wed, 18 Dec 2024 16:33:01 -0600
Subject: [PATCH] Improve implementation of finding median in StatisticsMetrics
Signed-off-by: Ahmed Hussein (amahussein)
Fixes #1461
Adds an InPlace median finding to improve the performance of the metric
aggregates.
We used to sort a sequence to create StatisticsMetrics which turned out
to be very expensive in large eventlogs.
Signed-off-by: Ahmed Hussein (amahussein)
---
.../tool/analysis/AppSQLPlanAnalyzer.scala | 84 ++++------
.../analysis/AppSparkMetricsAnalyzer.scala | 48 ++----
.../tool/analysis/StatisticsMetrics.scala | 29 ++++
.../sql/rapids/tool/store/AccumInfo.scala | 20 +--
.../tool/util/InPlaceMedianArrView.scala | 150 ++++++++++++++++++
.../rapids/tool/util/ToolUtilsSuite.scala | 25 ++-
6 files changed, 251 insertions(+), 105 deletions(-)
create mode 100644 core/src/main/scala/org/apache/spark/sql/rapids/tool/util/InPlaceMedianArrView.scala
diff --git a/core/src/main/scala/com/nvidia/spark/rapids/tool/analysis/AppSQLPlanAnalyzer.scala b/core/src/main/scala/com/nvidia/spark/rapids/tool/analysis/AppSQLPlanAnalyzer.scala
index 9580aa470..5cc645695 100644
--- a/core/src/main/scala/com/nvidia/spark/rapids/tool/analysis/AppSQLPlanAnalyzer.scala
+++ b/core/src/main/scala/com/nvidia/spark/rapids/tool/analysis/AppSQLPlanAnalyzer.scala
@@ -16,6 +16,7 @@
package com.nvidia.spark.rapids.tool.analysis
+import scala.collection.breakOut
import scala.collection.mutable.{AbstractSet, ArrayBuffer, HashMap, LinkedHashSet}
import com.nvidia.spark.rapids.tool.profiling.{AccumProfileResults, SQLAccumProfileResults, SQLMetricInfoCase, SQLStageInfoProfileResult, UnsupportedSQLPlan, WholeStageCodeGenResults}
@@ -29,6 +30,8 @@ import org.apache.spark.sql.rapids.tool.qualification.QualificationAppInfo
import org.apache.spark.sql.rapids.tool.store.DataSourceRecord
import org.apache.spark.sql.rapids.tool.util.ToolsPlanGraph
+
+
/**
* This class processes SQL plan to build some information such as: metrics, wholeStage nodes, and
* connecting operators to nodes. The implementation used to be directly under Profiler's
@@ -265,7 +268,7 @@ class AppSQLPlanAnalyzer(app: AppBase, appIndex: Int) extends AppAnalysisBase(ap
val jobsWithSQL = app.jobIdToInfo.filter { case (_, j) =>
j.sqlID.nonEmpty
}
- val sqlToStages = jobsWithSQL.flatMap { case (jobId, j) =>
+ jobsWithSQL.flatMap { case (jobId, j) =>
val stages = j.stageIds
val stagesInJob = app.stageManager.getStagesByIds(stages)
stagesInJob.map { sModel =>
@@ -283,8 +286,7 @@ class AppSQLPlanAnalyzer(app: AppBase, appIndex: Int) extends AppAnalysisBase(ap
SQLStageInfoProfileResult(appIndex, j.sqlID.get, jobId, sModel.stageInfo.stageId,
sModel.stageInfo.attemptNumber(), sModel.duration, nodeNames)
}
- }
- sqlToStages.toSeq
+ }(breakOut)
}
def generateSQLAccums(): Seq[SQLAccumProfileResults] = {
@@ -294,20 +296,11 @@ class AppSQLPlanAnalyzer(app: AppBase, appIndex: Int) extends AppAnalysisBase(ap
val driverAccumsOpt = app.driverAccumMap.get(metric.accumulatorId)
val driverMax = driverAccumsOpt match {
case Some(accums) =>
- val filtered = accums.filter { a =>
- a.sqlID == metric.sqlID
- }
- val accumValues = filtered.map(_.value).sortWith(_ < _)
- if (accumValues.isEmpty) {
- None
- } else if (accumValues.length <= 1) {
- Some(StatisticsMetrics(0L, 0L, 0L, accumValues.sum))
- } else {
- Some(StatisticsMetrics(accumValues(0), accumValues(accumValues.size / 2),
- accumValues(accumValues.size - 1), accumValues.sum))
- }
- case None =>
- None
+ StatisticsMetrics.createOptionalFromArr(accums.collect {
+ case a if a.sqlID == metric.sqlID =>
+ a.value
+ }(breakOut))
+ case _ => None
}
if (accumTaskStats.isDefined || driverMax.isDefined) {
@@ -325,7 +318,7 @@ class AppSQLPlanAnalyzer(app: AppBase, appIndex: Int) extends AppAnalysisBase(ap
} else {
None
}
- }
+ }(breakOut)
}
/**
@@ -341,40 +334,31 @@ class AppSQLPlanAnalyzer(app: AppBase, appIndex: Int) extends AppAnalysisBase(ap
def generateStageLevelAccums(): Seq[AccumProfileResults] = {
app.accumManager.accumInfoMap.flatMap { accumMapEntry =>
val accumInfo = accumMapEntry._2
- accumInfo.stageValuesMap.keySet.flatMap( stageId => {
- val stageTaskIds = app.taskManager.getAllTasksStageAttempt(stageId).map(_.taskId).toSet
- // get the task updates that belong to that stage
- val taskUpatesSubset =
- accumInfo.taskUpdatesMap.filterKeys(stageTaskIds.contains).values.toSeq.sorted
- if (taskUpatesSubset.isEmpty) {
- None
- } else {
- val min = taskUpatesSubset.head
- val max = taskUpatesSubset.last
- val sum = taskUpatesSubset.sum
- val median = if (taskUpatesSubset.size % 2 == 0) {
- val mid = taskUpatesSubset.size / 2
- (taskUpatesSubset(mid) + taskUpatesSubset(mid - 1)) / 2
- } else {
- taskUpatesSubset(taskUpatesSubset.size / 2)
- }
- // reuse AccumProfileResults to avoid generating extra memory from allocating new objects
- val accumProfileResults = AccumProfileResults(
- appIndex,
- stageId,
- accumInfo.infoRef,
- min = min,
- median = median,
- max = max,
- total = sum)
- if (accumInfo.infoRef.name.isDiagnosticMetrics()) {
- updateStageDiagnosticMetrics(accumProfileResults)
- }
- Some(accumProfileResults)
+ accumInfo.stageValuesMap.keys.flatMap( stageId => {
+ val stageTaskIds: Set[Long] =
+ app.taskManager.getAllTasksStageAttempt(stageId).map(_.taskId)(breakOut)
+ // Get the task updates that belong to that stage
+ StatisticsMetrics.createOptionalFromArr(
+ accumInfo.taskUpdatesMap.filterKeys(stageTaskIds).map(_._2)(breakOut)) match {
+ case Some(stat) =>
+ // Reuse AccumProfileResults to avoid generating allocating new objects
+ val accumProfileResults = AccumProfileResults(
+ appIndex,
+ stageId,
+ accumInfo.infoRef,
+ min = stat.min,
+ median = stat.med,
+ max = stat.max,
+ total = stat.total)
+ if (accumInfo.infoRef.name.isDiagnosticMetrics()) {
+ updateStageDiagnosticMetrics(accumProfileResults)
+ }
+ Some(accumProfileResults)
+ case _ => None
}
})
- }
- }.toSeq
+ }(breakOut)
+ }
}
object AppSQLPlanAnalyzer {
diff --git a/core/src/main/scala/com/nvidia/spark/rapids/tool/analysis/AppSparkMetricsAnalyzer.scala b/core/src/main/scala/com/nvidia/spark/rapids/tool/analysis/AppSparkMetricsAnalyzer.scala
index 3a862097b..6b8c3d5e5 100644
--- a/core/src/main/scala/com/nvidia/spark/rapids/tool/analysis/AppSparkMetricsAnalyzer.scala
+++ b/core/src/main/scala/com/nvidia/spark/rapids/tool/analysis/AppSparkMetricsAnalyzer.scala
@@ -16,6 +16,7 @@
package com.nvidia.spark.rapids.tool.analysis
+import scala.collection.breakOut
import scala.collection.mutable.{ArrayBuffer, HashMap, LinkedHashMap}
import com.nvidia.spark.rapids.tool.analysis.StageAccumDiagnosticMetrics._
@@ -79,7 +80,7 @@ class AppSparkMetricsAnalyzer(app: AppBase) extends AppAnalysisBase(app) {
* @return sequence of JobAggTaskMetricsProfileResult that contains only Job Ids
*/
def aggregateSparkMetricsByJob(index: Int): Seq[JobAggTaskMetricsProfileResult] = {
- val jobRows = app.jobIdToInfo.flatMap { case (id, jc) =>
+ app.jobIdToInfo.flatMap { case (id, jc) =>
if (jc.stageIds.isEmpty) {
None
} else {
@@ -126,8 +127,7 @@ class AppSparkMetricsAnalyzer(app: AppBase) extends AppAnalysisBase(app) {
perJobRec.swWriteTimeSum))
}
}
- }
- jobRows.toSeq
+ }(breakOut)
}
private case class AverageStageInfo(avgDuration: Double, avgShuffleReadBytes: Double)
@@ -163,7 +163,7 @@ class AppSparkMetricsAnalyzer(app: AppBase) extends AppAnalysisBase(app) {
tc.taskId, tc.attempt, tc.duration, avg.avgDuration, tc.sr_totalBytesRead,
avg.avgShuffleReadBytes, tc.peakExecutionMemory, tc.successful, tc.endReason)
}
- }.toSeq
+ }(breakOut)
}
/**
@@ -172,7 +172,7 @@ class AppSparkMetricsAnalyzer(app: AppBase) extends AppAnalysisBase(app) {
* @return sequence of SQLTaskAggMetricsProfileResult
*/
def aggregateSparkMetricsBySql(index: Int): Seq[SQLTaskAggMetricsProfileResult] = {
- val sqlRows = app.sqlIdToInfo.flatMap { case (sqlId, sqlCase) =>
+ app.sqlIdToInfo.flatMap { case (sqlId, sqlCase) =>
if (app.sqlIdToStages.contains(sqlId)) {
val stagesInSQL = app.sqlIdToStages(sqlId)
// TODO: Should we only consider successful tasks?
@@ -229,8 +229,7 @@ class AppSparkMetricsAnalyzer(app: AppBase) extends AppAnalysisBase(app) {
} else {
None
}
- }
- sqlRows.toSeq
+ }(breakOut)
}
/**
@@ -241,7 +240,7 @@ class AppSparkMetricsAnalyzer(app: AppBase) extends AppAnalysisBase(app) {
*/
def aggregateIOMetricsBySql(
sqlMetricsAggs: Seq[SQLTaskAggMetricsProfileResult]): Seq[IOAnalysisProfileResult] = {
- val sqlIORows = sqlMetricsAggs.map { sqlAgg =>
+ sqlMetricsAggs.map { sqlAgg =>
IOAnalysisProfileResult(sqlAgg.appIndex,
app.appId,
sqlAgg.sqlId,
@@ -253,8 +252,7 @@ class AppSparkMetricsAnalyzer(app: AppBase) extends AppAnalysisBase(app) {
sqlAgg.memoryBytesSpilledSum,
sqlAgg.srTotalBytesReadSum,
sqlAgg.swBytesWrittenSum)
- }
- sqlIORows
+ }(breakOut)
}
/**
@@ -289,7 +287,7 @@ class AppSparkMetricsAnalyzer(app: AppBase) extends AppAnalysisBase(app) {
* @return a sequence of SQLDurationExecutorTimeProfileResult or Empty if None.
*/
def aggregateDurationAndCPUTimeBySql(index: Int): Seq[SQLDurationExecutorTimeProfileResult] = {
- val sqlRows = app.sqlIdToInfo.map { case (sqlId, sqlCase) =>
+ app.sqlIdToInfo.map { case (sqlId, sqlCase) =>
// First, build the SQLIssues string by retrieving the potential issues from the
// app.sqlIDtoProblematic map.
val sqlIssues = if (app.sqlIDtoProblematic.contains(sqlId)) {
@@ -301,8 +299,7 @@ class AppSparkMetricsAnalyzer(app: AppBase) extends AppAnalysisBase(app) {
SQLDurationExecutorTimeProfileResult(index, app.appId, sqlCase.rootExecutionID,
sqlId, sqlCase.duration, sqlCase.hasDatasetOrRDD,
app.getAppDuration.orElse(Option(0L)), sqlIssues, sqlCase.sqlCpuTimePercent)
- }
- sqlRows.toSeq
+ }(breakOut)
}
/**
@@ -338,7 +335,7 @@ class AppSparkMetricsAnalyzer(app: AppBase) extends AppAnalysisBase(app) {
.getOrElse(sm.stageInfo.stageId, emptyDiagnosticMetrics)
.withDefaultValue(zeroAccumProfileResults)
val srTotalBytesMetrics =
- AppSparkMetricsAnalyzer.getStatistics(tasksInStage.map(_.sr_totalBytesRead))
+ StatisticsMetrics.createFromArr(tasksInStage.map(_.sr_totalBytesRead)(breakOut))
StageDiagnosticResult(index,
app.getAppName,
@@ -359,7 +356,7 @@ class AppSparkMetricsAnalyzer(app: AppBase) extends AppAnalysisBase(app) {
diagnosticMetricsMap(SW_WRITE_TIME_METRIC),
diagnosticMetricsMap(GPU_SEMAPHORE_WAIT_METRIC),
nodeNames)
- }.toSeq
+ }(breakOut)
}
/**
@@ -456,24 +453,3 @@ class AppSparkMetricsAnalyzer(app: AppBase) extends AppAnalysisBase(app) {
}
}
}
-
-
-object AppSparkMetricsAnalyzer {
- /**
- * Given an input iterable, returns its min, median, max and sum.
- */
- def getStatistics(arr: Iterable[Long]): StatisticsMetrics = {
- if (arr.isEmpty) {
- StatisticsMetrics(0L, 0L, 0L, 0L)
- } else {
- val sortedArr = arr.toSeq.sorted
- val len = sortedArr.size
- val med = if (len % 2 == 0) {
- (sortedArr(len / 2) + sortedArr(len / 2 - 1)) / 2
- } else {
- sortedArr(len / 2)
- }
- StatisticsMetrics(sortedArr.head, med, sortedArr(len - 1), sortedArr.sum)
- }
- }
-}
diff --git a/core/src/main/scala/com/nvidia/spark/rapids/tool/analysis/StatisticsMetrics.scala b/core/src/main/scala/com/nvidia/spark/rapids/tool/analysis/StatisticsMetrics.scala
index 1b88d2d4c..d0a21a6c0 100644
--- a/core/src/main/scala/com/nvidia/spark/rapids/tool/analysis/StatisticsMetrics.scala
+++ b/core/src/main/scala/com/nvidia/spark/rapids/tool/analysis/StatisticsMetrics.scala
@@ -16,6 +16,8 @@
package com.nvidia.spark.rapids.tool.analysis
+import org.apache.spark.sql.rapids.tool.util.InPlaceMedianArrView.{chooseMidpointPivotInPlace, findMedianInPlace}
+
// Store (min, median, max, total) for a given metric
case class StatisticsMetrics(min: Long, med: Long, max: Long, total: Long)
@@ -23,4 +25,31 @@ object StatisticsMetrics {
// a static variable used to represent zero-statistics instead of allocating a dummy record
// on every calculation.
val ZERO_RECORD: StatisticsMetrics = StatisticsMetrics(0L, 0L, 0L, 0L)
+
+ def createFromArr(arr: Array[Long]): StatisticsMetrics = {
+ if (arr.isEmpty) {
+ return ZERO_RECORD
+ }
+ val medV = findMedianInPlace(arr)(chooseMidpointPivotInPlace)
+ var minV = Long.MaxValue
+ var maxV = Long.MinValue
+ var totalV = 0L
+ arr.foreach { v =>
+ if (v < minV) {
+ minV = v
+ }
+ if (v > maxV) {
+ maxV = v
+ }
+ totalV += v
+ }
+ StatisticsMetrics(minV, medV, maxV, totalV)
+ }
+
+ def createOptionalFromArr(arr: Array[Long]): Option[StatisticsMetrics] = {
+ if (arr.isEmpty) {
+ return None
+ }
+ Some(createFromArr(arr))
+ }
}
diff --git a/core/src/main/scala/org/apache/spark/sql/rapids/tool/store/AccumInfo.scala b/core/src/main/scala/org/apache/spark/sql/rapids/tool/store/AccumInfo.scala
index 0f8e520c6..080a34df3 100644
--- a/core/src/main/scala/org/apache/spark/sql/rapids/tool/store/AccumInfo.scala
+++ b/core/src/main/scala/org/apache/spark/sql/rapids/tool/store/AccumInfo.scala
@@ -16,7 +16,7 @@
package org.apache.spark.sql.rapids.tool.store
-import scala.collection.mutable
+import scala.collection.{breakOut, mutable}
import com.nvidia.spark.rapids.tool.analysis.StatisticsMetrics
@@ -98,22 +98,8 @@ class AccumInfo(val infoRef: AccumMetaRef) {
}
def calculateAccStats(): StatisticsMetrics = {
- val sortedTaskUpdates = taskUpdatesMap.values.toSeq.sorted
- if (sortedTaskUpdates.isEmpty) {
- // do not check stage values because the stats is only meant for task updates
- StatisticsMetrics.ZERO_RECORD
- } else {
- val min = sortedTaskUpdates.head
- val max = sortedTaskUpdates.last
- val sum = sortedTaskUpdates.sum
- val median = if (sortedTaskUpdates.size % 2 == 0) {
- val mid = sortedTaskUpdates.size / 2
- (sortedTaskUpdates(mid) + sortedTaskUpdates(mid - 1)) / 2
- } else {
- sortedTaskUpdates(sortedTaskUpdates.size / 2)
- }
- StatisticsMetrics(min, median, max, sum)
- }
+ // do not check stage values because the stats is only meant for task updates
+ StatisticsMetrics.createFromArr(taskUpdatesMap.map(_._2)(breakOut))
}
def getMaxStageValue: Option[Long] = {
diff --git a/core/src/main/scala/org/apache/spark/sql/rapids/tool/util/InPlaceMedianArrView.scala b/core/src/main/scala/org/apache/spark/sql/rapids/tool/util/InPlaceMedianArrView.scala
new file mode 100644
index 000000000..1be48a6a7
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/sql/rapids/tool/util/InPlaceMedianArrView.scala
@@ -0,0 +1,150 @@
+/*
+ * Copyright (c) 2024, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.rapids.tool.util
+
+import scala.annotation.tailrec
+import scala.language.postfixOps
+
+/**
+ * Allows for in-place partitioning and finding the median.
+ * The tools used to find the median of a sequence by sorting the entire sequence, then returning
+ * the elements in the middle. As we started to capture all the accumulators in Spark plans,
+ * sorting is inefficient for large eventlogs that contain huge number of tasks and
+ * Accumulables. Thus, this class is an optimized version to get the median in a linear
+ * complexity while doing it in place to avoid allocating new array to store the sorted elements.
+ * The code is copied from a Stackoverflow thread:
+ * https://stackoverflow.com/questions/4662292/scala-median-implementation
+ *
+ * Notes:
+ * - The implementation assumes that the array is not empty.
+ */
+case class InPlaceMedianArrView(arr: Array[Long], from: Int, until: Int) {
+ def apply(n: Int): Long = {
+ if (from + n < until) {
+ arr(from + n)
+ } else {
+ throw new ArrayIndexOutOfBoundsException(n)
+ }
+ }
+
+ /**
+ * Returns a new view of the array with the same elements but a different range.
+ * @param p a predicate to apply on the elements to proceed with the partitioning.
+ * @return a tuple of 2 views, the first one contains the elements that satisfy the predicate,
+ * and the second one contains the rest.
+ */
+ def partitionInPlace(p: Long => Boolean): (InPlaceMedianArrView, InPlaceMedianArrView) = {
+ var upper = until - 1
+ var lower = from
+ while (lower < upper) {
+ while (lower < until && p(arr(lower))) lower += 1
+ while (upper >= from && !p(arr(upper))) upper -= 1
+ if (lower < upper) { val tmp = arr(lower); arr(lower) = arr(upper); arr(upper) = tmp }
+ }
+ (copy(until = lower), copy(from = lower))
+ }
+
+ def size: Int = {
+ until - from
+ }
+
+ def isEmpty: Boolean = {
+ size <= 0
+ }
+
+ override def toString = {
+ arr mkString ("ArraySize(", ", ", ")")
+ }
+}
+
+/**
+ * Companion object for InPlaceMedianArrView.
+ */
+object InPlaceMedianArrView {
+
+ def apply(arr: Array[Long]): InPlaceMedianArrView = {
+ InPlaceMedianArrView(arr, 0, arr.size)
+ }
+
+ /**
+ * Finds the median of the array in place.
+ * @param arr the Array[Long] to be processed
+ * @param k the index of the median
+ * @param choosePivot a function to choose the pivot index. This useful to choose different
+ * strategies. For example, choosing the midpoint works better for sorted
+ * arrays.
+ * @return the median of the array.
+ */
+ @tailrec
+ def findKMedianInPlace(arr: InPlaceMedianArrView, k: Int)
+ (implicit choosePivot: InPlaceMedianArrView => Long): Long = {
+ val a = choosePivot(arr)
+ val (s, b) = arr partitionInPlace (a >)
+ if (s.size == k) {
+ a
+ } else if (s.isEmpty) {
+ val (s, b) = arr partitionInPlace (a ==)
+ if (s.size > k) {
+ a
+ } else {
+ findKMedianInPlace(b, k - s.size)
+ }
+ } else if (s.size < k) {
+ findKMedianInPlace(b, k - s.size)
+ } else {
+ findKMedianInPlace(s, k)
+ }
+ }
+
+ /**
+ * Choose a random pivot in the array. This can lead to worst case for sorted arrays.
+ * @param arr the array to choose the pivot from.
+ * @return a random element from the array.
+ */
+ def chooseRandomPivotInPlace(arr: InPlaceMedianArrView): Long = {
+ arr(scala.util.Random.nextInt(arr.size))
+ }
+
+ /**
+ * Choose the element in the middle as a pivot. This works better to find median of sorted arrays.
+ * @param arr the array to choose the pivot from.
+ * @return the element in the middle of the array.
+ */
+ def chooseMidpointPivotInPlace(arr: InPlaceMedianArrView): Long = {
+ arr((arr.size - 1) / 2)
+ }
+
+ /**
+ * Finds the median of the array in place.
+ * @param arr the Array[Long] to be processed.
+ * @param choosePivot a function to choose the pivot index.
+ * @return the median of the array.
+ */
+ def findMedianInPlace(
+ arr: Array[Long])(implicit choosePivot: InPlaceMedianArrView => Long): Long = {
+ val midIndex = (arr.size - 1) / 2
+ if (arr.size % 2 == 0) {
+ // For even-length arrays, find the two middle elements and compute their average
+ val mid1 = findKMedianInPlace(InPlaceMedianArrView(arr), midIndex)
+ val mid2 = findKMedianInPlace(InPlaceMedianArrView(arr), midIndex + 1)
+ (mid1 + mid2) / 2
+ } else {
+ // For odd-length arrays, return the middle element
+ findKMedianInPlace(InPlaceMedianArrView(arr), midIndex)
+ }
+ }
+}
diff --git a/core/src/test/scala/com/nvidia/spark/rapids/tool/util/ToolUtilsSuite.scala b/core/src/test/scala/com/nvidia/spark/rapids/tool/util/ToolUtilsSuite.scala
index baba6eb79..5e1b6558b 100644
--- a/core/src/test/scala/com/nvidia/spark/rapids/tool/util/ToolUtilsSuite.scala
+++ b/core/src/test/scala/com/nvidia/spark/rapids/tool/util/ToolUtilsSuite.scala
@@ -24,13 +24,13 @@ import scala.concurrent.duration._
import scala.xml.XML
import com.nvidia.spark.rapids.tool.profiling.{ProfileOutputWriter, ProfileResult}
+import org.scalatest.AppendedClues.convertToClueful
import org.scalatest.FunSuite
import org.scalatest.Matchers.{contain, convertToAnyShouldWrapper, equal, not}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.TrampolineUtil
-import org.apache.spark.sql.rapids.tool.util.{FSUtils, RapidsToolsConfUtil, StringUtils, WebCrawlerUtil}
-
+import org.apache.spark.sql.rapids.tool.util.{FSUtils, InPlaceMedianArrView, RapidsToolsConfUtil, StringUtils, WebCrawlerUtil}
class ToolUtilsSuite extends FunSuite with Logging {
test("get page links of a url") {
@@ -210,6 +210,27 @@ class ToolUtilsSuite extends FunSuite with Logging {
}
}
+ test("Finding median of arrays") {
+ val testSet: Map[String, (Array[Long], Long)] = Map(
+ "All same values" -> (Array[Long](5, 5, 5, 5) -> 5L),
+ "Odd number of values [9, 7, 5, 3, 1]" -> (Array[Long](9, 7, 5, 3, 1) -> 5L),
+ "Even number of values [11, 9, 7, 5, 3, 1]" -> (Array[Long](11, 9, 7, 5, 3, 1) -> 6),
+ "Even number of values(2) [15, 13, 11, 9, 7, 5, 3, 1]" ->
+ (Array[Long](15, 13, 11, 9, 7, 5, 3, 1) -> 8),
+ "Even number of values(3) [3, 13, 11, 9, 7, 5, 15, 1]" ->
+ (Array[Long](3, 13, 11, 9, 7, 5, 15, 1) -> 8),
+ "Single element" -> (Array[Long](1) -> 1),
+ "Two elements" -> (Array[Long](1, 2).reverse -> 1)
+ )
+ for ((desc, (arr, expectedMedian)) <- testSet) {
+ val actualMedian =
+ InPlaceMedianArrView.findMedianInPlace(arr)(InPlaceMedianArrView.chooseMidpointPivotInPlace)
+ actualMedian shouldBe expectedMedian withClue s"Failed for $desc. " +
+ s"Expected: $expectedMedian, " +
+ s"Actual: $actualMedian"
+ }
+ }
+
case class MockProfileResults(appID: String, appIndex: Int, nonEnglishField: String,
parentIDs: String) extends ProfileResult {
override val outputHeaders: Seq[String] = Seq("appID", "appIndex", "nonEnglishField",