Skip to content

Commit

Permalink
Adding GC Metrics (#12)
Browse files Browse the repository at this point in the history
* Adding GC Metrics

Signed-off-by: Sayed Bilal Bari <[email protected]>

* Review comment changes

Signed-off-by: Sayed Bilal Bari <[email protected]>

* Correcting output format + refactoring

Signed-off-by: Sayed Bilal Bari <[email protected]>

* Output Formatting Changes

Signed-off-by: Sayed Bilal Bari <[email protected]>

* Formatting + Making qual bench single threaded

Signed-off-by: Sayed Bilal Bari <[email protected]>

---------

Signed-off-by: Sayed Bilal Bari <[email protected]>
Co-authored-by: Sayed Bilal Bari <[email protected]>
  • Loading branch information
bilalbari and bilalbari authored Jul 5, 2024
1 parent 8e3b18b commit 1026e69
Show file tree
Hide file tree
Showing 4 changed files with 96 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import scala.concurrent.duration.NANOSECONDS

import org.apache.commons.io.output.TeeOutputStream

import org.apache.spark.sql.rapids.tool.util.{RuntimeUtil, ToolsTimer}
import org.apache.spark.sql.rapids.tool.util.{MemoryMetricsTracker, RuntimeUtil, ToolsTimer}

/**
* This code is mostly copied from org.apache.spark.benchmark.BenchmarkBase
Expand All @@ -37,7 +37,7 @@ import org.apache.spark.sql.rapids.tool.util.{RuntimeUtil, ToolsTimer}
* This will output the average time to run each function and the rate of each function.
*/
class Benchmark(
name: String,
name: String = "Benchmarker",
valuesPerIteration: Long,
minNumIters: Int,
warmUpIterations: Int,
Expand Down Expand Up @@ -99,18 +99,31 @@ class Benchmark(

val firstBest = results.head.bestMs
// The results are going to be processor specific so it is useful to include that.
out.println(RuntimeUtil.getJVMOSInfo.mkString("\n"))
val jvmInfo = RuntimeUtil.getJVMOSInfo
out.printf(s"%-26s : %s \n","JVM Name", jvmInfo("jvm.name"))
out.printf(s"%-26s : %s \n","Java Version", jvmInfo("jvm.version"))
out.printf(s"%-26s : %s \n","OS Name", jvmInfo("os.name"))
out.printf(s"%-26s : %s \n","OS Version", jvmInfo("os.version"))
out.printf(s"%-26s : %s MB \n","MaxHeapMemory", (Runtime.getRuntime.maxMemory()/1024/1024).toString)
out.printf(s"%-26s : %s \n","Total Warm Up Iterations", warmUpIterations.toString)
out.printf(s"%-26s : %s \n \n","Total Runtime Iterations", minNumIters.toString)
val nameLen = Math.max(40, Math.max(name.length, benchmarks.map(_.name.length).max))
out.printf(s"%-${nameLen}s %14s %14s %11s %10s\n",
name + ":", "Best Time(ms)", "Avg Time(ms)", "Stdev(ms)", "Relative")
out.println("-" * (nameLen + 80))
out.printf(s"%-${nameLen}s %14s %14s %11s %20s %18s %18s %18s %18s %10s\n",
name + ":", "Best Time(ms)", "Avg Time(ms)", "Stdev(ms)","Avg GC Time(ms)",
"Avg GC Count", "Stdev GC Count","Max GC Time(ms)","Max GC Count", "Relative")
out.println("-" * (nameLen + 160))
results.zip(benchmarks).foreach { case (result, benchmark) =>
out.printf(s"%-${nameLen}s %14s %14s %11s %10s\n",
out.printf(s"%-${nameLen}s %14s %14s %11s %20s %18s %18s %18s %18s %10s\n",
benchmark.name,
"%5.0f" format result.bestMs,
"%4.0f" format result.avgMs,
"%5.0f" format result.stdevMs,
"%3.1fX" format (firstBest / result.bestMs))
"%5.1f" format result.memoryParams.avgGCTime,
"%5.1f" format result.memoryParams.avgGCCount,
"%5.0f" format result.memoryParams.stdDevGCCount,
"%5d" format result.memoryParams.maxGcTime,
"%5d" format result.memoryParams.maxGCCount,
"%3.2fX" format (firstBest / result.bestMs))
}
out.println()
}
Expand All @@ -126,14 +139,17 @@ class Benchmark(
}
val minIters = if (overrideNumIters != 0) overrideNumIters else minNumIters
val runTimes = ArrayBuffer[Long]()
var totalTime = 0L
val gcCounts = ArrayBuffer[Long]()
val gcTimes = ArrayBuffer[Long]()
//For tracking maximum GC over iterations
for (i <- 0 until minIters) {
val timer = new ToolsTimer(i)
val memoryTracker = new MemoryMetricsTracker
f(timer)
val runTime = timer.totalTime()
runTimes += runTime
totalTime += runTime

gcCounts += memoryTracker.getTotalGCCount
gcTimes += memoryTracker.getTotalGCTime
if (outputPerIteration) {
// scalastyle:off
println("*"*80)
Expand All @@ -148,17 +164,34 @@ class Benchmark(
println("*"*80)
// scalastyle:on
assert(runTimes.nonEmpty)
val best = runTimes.min
val avg = runTimes.sum / runTimes.size
val stdev = if (runTimes.size > 1) {
math.sqrt(runTimes.map(time => (time - avg) * (time - avg)).sum / (runTimes.size - 1))
} else 0
Benchmark.Result(avg / 1000000.0, best / 1000000.0, stdev / 1000000.0)
val bestRuntime = runTimes.min
val avgRuntime = runTimes.sum / runTimes.size
val stdevRunTime = if (runTimes.size > 1) {
math.sqrt(runTimes.map(time => (time - avgRuntime) *
(time - avgRuntime)).sum / (runTimes.size - 1))
} else {
0
}
val maxGcCount = gcCounts.max
val stdevGcCount = if (gcCounts.size > 1) {
math.sqrt(gcCounts.map(gc => (gc - maxGcCount) *
(gc - maxGcCount)).sum / (gcCounts.size - 1))
} else {
0
}
val avgGcCount = gcCounts.sum / minIters
val avgGcTime = gcTimes.sum / minIters
val maxGcTime = gcTimes.max
Benchmark.Result(avgRuntime / 1000000.0, bestRuntime / 1000000.0, stdevRunTime / 1000000.0,
JVMMemoryParams(avgGcTime, avgGcCount, stdevGcCount, maxGcCount, maxGcTime))
}
}


object Benchmark {
case class Case(name: String, fn: ToolsTimer => Unit, numIters: Int)
case class Result(avgMs: Double, bestMs: Double, stdevMs: Double)
case class JVMMemoryParams( avgGCTime:Double, avgGCCount:Double,
stdDevGCCount: Double, maxGCCount: Long, maxGcTime:Long)
case class Result(avgMs: Double, bestMs: Double, stdevMs: Double,
memoryParams: JVMMemoryParams)
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,23 +27,28 @@ import com.nvidia.spark.rapids.tool.qualification.QualificationMain.mainInternal
* 2. Write the benchmark code in the runBenchmark method passing relevant arguments
* 3. Write benchmarked code inside
*/
object QualificationBenchmark extends BenchmarkBase {
object SingleThreadedQualToolBenchmark extends BenchmarkBase {
override def runBenchmarkSuite(iterations: Int,
warmUpIterations: Int,
outputFormat: String,
mainArgs: Array[String]): Unit = {
runBenchmark("QualificationBenchmark") {
runBenchmark("Benchmark_Per_SQL_Arg_Qualification") {
val benchmarker =
new Benchmark(
"QualificationBenchmark",
2,
valuesPerIteration = 2,
output = output,
outputPerIteration = true,
warmUpIterations = warmUpIterations,
minNumIters = iterations)
benchmarker.addCase("QualificationBenchmark") { _ =>
mainInternal(new QualificationArgs(mainArgs),
printStdout = true, enablePB = true)
val (prefix,suffix) = mainArgs.splitAt(mainArgs.length - 1)
benchmarker.addCase("Enable_Per_SQL_Arg_Qualification") { _ =>
mainInternal(new QualificationArgs(prefix :+ "--per-sql" :+ "--num-threads"
:+ "1" :+ suffix.head),
enablePB = true)
}
benchmarker.addCase("Disable_Per_SQL_Arg_Qualification") { _ =>
mainInternal(new QualificationArgs(prefix :+ "--num-threads" :+ "1" :+ suffix.head),
enablePB = true)
}
benchmarker.run()
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
package org.apache.spark.sql.rapids.tool.util

import java.lang.management.ManagementFactory

import scala.collection.convert.ImplicitConversions.`collection AsScalaIterable`

/**
* Utility class to track memory metrics.
* This class is used to track memory metrics such as GC count, GC time,
* heap memory usage, etc.
*
*/
class MemoryMetricsTracker {
private val startGCMetrics = getCurrentGCMetrics

private def getCurrentGCMetrics: (Long, Long) = {
val gcBeans = ManagementFactory.getGarbageCollectorMXBeans

(gcBeans.map(_.getCollectionCount).sum,
gcBeans.map(_.getCollectionTime).sum)
}

def getTotalGCCount: Long = {
val (newGcCount:Long, _) = getCurrentGCMetrics
newGcCount - startGCMetrics._1
}

def getTotalGCTime: Long = {
val (_, newGcTime:Long) = getCurrentGCMetrics
newGcTime - startGCMetrics._2
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ object RuntimeUtil extends Logging {
def getJVMOSInfo: Map[String, String] = {
Map(
"jvm.name" -> System.getProperty("java.vm.name"),
"jvm.version" -> System.getProperty("java.vm.version"),
"jvm.version" -> System.getProperty("java.version"),
"os.name" -> System.getProperty("os.name"),
"os.version" -> System.getProperty("os.version")
)
Expand Down

0 comments on commit 1026e69

Please sign in to comment.