Skip to content

Commit

Permalink
address review comments
Browse files Browse the repository at this point in the history
Signed-off-by: Hongbin Ma (Mahone) <[email protected]>
  • Loading branch information
binmahone committed Jun 17, 2024
1 parent 441f9c0 commit 8be7ddf
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 54 deletions.
34 changes: 26 additions & 8 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuExec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -152,16 +152,34 @@ sealed abstract class GpuMetric extends Serializable {
def +=(v: Long): Unit
def add(v: Long): Unit

private var isTimerActive = false

/**
* Marks this metric's timer as active.
*
* @return true when the timer was not active before this call (the caller now owns the
*         timing), false when it was already active
*/
final def tryActivateTimer(): Boolean = {
  // Capture the previous state first; forcing the flag to true is a no-op if it was already set.
  val activatedNow = !isTimerActive
  isTimerActive = true
  activatedNow
}

/**
* Clears the active-timer flag and, if the timer was active, adds `duration` to this metric.
* Does nothing observable when the timer was not active.
*
* @param duration the amount passed to `add` when the timer was active
*/
final def deactivateTimer(duration: Long): Unit = {
  val wasActive = isTimerActive
  // Reset the flag before calling add, so the flag is already cleared even if add throws.
  isTimerActive = false
  if (wasActive) {
    add(duration)
  }
}

/**
* Evaluates `f` and returns its result, adding the elapsed System.nanoTime() delta to this
* metric when this call is the one that is tracking the metric.
*
* NOTE(review): this span looks like a rendered diff with the removed and the added body
* interleaved (the braces do not balance as shown) - confirm against the real file.
*
* @param f the by-name computation to run
* @return whatever `f` evaluates to
*/
final def ns[T](f: => T): T = {
// Variant based on ThreadLocalMetrics: asks whether this metric needs tracking.
val needTrack = ThreadLocalMetrics.onMetricsEnter(this)
val start = System.nanoTime()
try {
f
} finally {
if (needTrack) {
add(System.nanoTime() - start)
ThreadLocalMetrics.onMetricsExit(this)
// Variant based on the timer flag: only time `f` when tryActivateTimer() returns true.
if (tryActivateTimer()) {
val start = System.nanoTime()
try {
f
} finally {
// finally ensures the elapsed time is handed to deactivateTimer even if `f` throws.
deactivateTimer(System.nanoTime() - start)
}
} else {
// tryActivateTimer() returned false: just run `f` without timing it here.
f
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@

package com.nvidia.spark.rapids

import scala.collection.mutable

import ai.rapids.cudf.{NvtxColor, NvtxRange}

object NvtxWithMetrics {
Expand All @@ -29,74 +27,36 @@ object NvtxWithMetrics {
}
}

/**
* Keeps, per thread, the set of metrics that are currently being tracked so that the same
* metric is not tracked twice on one thread.
*/
object ThreadLocalMetrics {
  // Orders metrics by their identity hash code.
  // NOTE(review): identity hash codes are not guaranteed to be unique, so two distinct
  // metrics could compare as equal under this ordering - confirm whether that matters.
  val addressOrdering: Ordering[GpuMetric] =
    Ordering.by((metric: GpuMetric) => System.identityHashCode(metric))

  // Each thread lazily gets its own empty set, sorted with addressOrdering.
  val currentThreadMetrics = new ThreadLocal[mutable.TreeSet[GpuMetric]] {
    override def initialValue(): mutable.TreeSet[GpuMetric] =
      mutable.TreeSet.empty[GpuMetric](addressOrdering)
  }

  /**
   * Registers the metric for the current thread if it is not registered yet.
   *
   * @param gpuMetric the metric to check
   * @return true if the metric was newly registered and therefore needs tracking; false for
   *         NoopMetric or for a metric that is already registered on this thread
   */
  def onMetricsEnter(gpuMetric: GpuMetric): Boolean = {
    // Set.add reports whether the element was newly inserted, which is exactly the answer.
    gpuMetric != NoopMetric && currentThreadMetrics.get().add(gpuMetric)
  }

  /**
   * Unregisters the metric for the current thread. NoopMetric is ignored.
   *
   * @param gpuMetric the metric to unregister
   * @throws IllegalStateException if the metric is not registered on this thread
   */
  def onMetricsExit(gpuMetric: GpuMetric): Unit = {
    if (gpuMetric != NoopMetric) {
      // Set.remove reports whether the element was present before the call.
      val wasTracked = currentThreadMetrics.get().remove(gpuMetric)
      if (!wasTracked) {
        throw new IllegalStateException("Metric missing from thread local storage: "
          + gpuMetric)
      }
    }
  }
}
/**
* NvtxRange with option to pass one or more nano timing metric(s) that are updated upon close
* by the amount of time spent in the range
*
* NOTE(review): several lines below appear twice in slightly different forms (the `extends`
* clause, `needTracks`, and the per-metric update in `close`); this looks like a rendered
* diff showing both the removed and the added line - confirm against the real file.
*/
class NvtxWithMetrics(name: String, color: NvtxColor, val metrics: GpuMetric*)
extends NvtxRange(name, color) {
 extends NvtxRange(name, color) {

// One Boolean per metric, in the same order as `metrics`, computed at construction time.
val needTracks = metrics.map(ThreadLocalMetrics.onMetricsEnter)
val needTracks = metrics.map(_.tryActivateTimer())
// Timestamp taken at construction; close() measures elapsed time from here.
private val start = System.nanoTime()

// Updates every metric whose needTracks flag is true with the elapsed nanoseconds,
// then closes the underlying range via super.close().
override def close(): Unit = {
val time = System.nanoTime() - start
metrics.toSeq.zip(needTracks).foreach { pair =>
// pair._1 is the metric, pair._2 is its needTracks flag.
if (pair._2) {
pair._1 += time
ThreadLocalMetrics.onMetricsExit(pair._1)
pair._1.deactivateTimer(time)
}
}
super.close()
}
}

class MetricRange(val metrics: GpuMetric*) extends AutoCloseable {
val needTracks = metrics.map(ThreadLocalMetrics.onMetricsEnter)
val needTracks = metrics.map(_.tryActivateTimer())
private val start = System.nanoTime()

override def close(): Unit = {
val time = System.nanoTime() - start
metrics.toSeq.zip(needTracks).foreach { pair =>
if (pair._2) {
pair._1 += time
ThreadLocalMetrics.onMetricsExit(pair._1)
pair._1.deactivateTimer(time)
}
}
}
Expand Down
33 changes: 32 additions & 1 deletion tests/src/test/scala/com/nvidia/spark/rapids/MetricsSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,25 @@

package com.nvidia.spark.rapids

import ai.rapids.cudf.NvtxColor
import com.nvidia.spark.rapids.Arm.withResource
import org.scalatest.funsuite.AnyFunSuite

class MetricsSuite extends AnyFunSuite {

test("duplicate timing on the same metrics") {
test("GpuMetric.ns: duplicate timing on the same metrics") {
  // Nanoseconds expected for a single 100 ms sleep.
  val expectedNanos = 100000000
  val metric = new LocalGpuMetric()
  // Nest two ns calls on the same metric around one sleep.
  metric.ns {
    metric.ns {
      Thread.sleep(100)
    }
  }
  // if the timing is duplicated, the value should be around 200,000,000
  assert(metric.value < expectedNanos * 1.5)
  assert(metric.value > expectedNanos * 0.5)
}

test("MetricRange: duplicate timing on the same metrics") {
val m1 = new LocalGpuMetric()
val m2 = new LocalGpuMetric()
withResource(new MetricRange(m1, m2)) { _ =>
Expand All @@ -32,6 +45,24 @@ class MetricsSuite extends AnyFunSuite {

// if the timing is duplicated, the value should be around 200,000,000
assert(m1.value < 100000000 * 1.5)
assert(m1.value > 100000000 * 0.5)
assert(m2.value < 100000000 * 1.5)
assert(m2.value > 100000000 * 0.5)
}

test("NvtxWithMetrics: duplicate timing on the same metrics") {
  // Nanoseconds expected for a single 100 ms sleep.
  val expectedNanos = 100000000
  val first = new LocalGpuMetric()
  val second = new LocalGpuMetric()
  // The inner range lists the same two metrics as the outer one, in reverse order.
  withResource(new NvtxWithMetrics("a", NvtxColor.BLUE, first, second)) { _ =>
    withResource(new NvtxWithMetrics("b", NvtxColor.BLUE, second, first)) { _ =>
      Thread.sleep(100)
    }
  }

  // if the timing is duplicated, the value should be around 200,000,000
  Seq(first, second).foreach { metric =>
    assert(metric.value < expectedNanos * 1.5)
    assert(metric.value > expectedNanos * 0.5)
  }
}
}

0 comments on commit 8be7ddf

Please sign in to comment.