From 8be7ddf68953068ea3aebdabd566099e4c518513 Mon Sep 17 00:00:00 2001
From: "Hongbin Ma (Mahone)"
Date: Mon, 17 Jun 2024 10:29:03 +0800
Subject: [PATCH] address review comments

Signed-off-by: Hongbin Ma (Mahone)
---
Notes (not part of the commit message):
  * Removes the ThreadLocalMetrics object (a thread-local TreeSet of the
    metrics currently being timed) and replaces it with a per-instance
    `isTimerActive` flag on GpuMetric, driven by tryActivateTimer() and
    deactivateTimer(duration).
  * GpuMetric.ns, NvtxWithMetrics and MetricRange now use that flag, so a
    nested timing of a metric that is already being timed is not added a
    second time.
  * MetricsSuite gains cases for GpuMetric.ns and NvtxWithMetrics, plus
    lower-bound assertions on the recorded durations.
  * NOTE(review): the guard is now per metric instance rather than per
    thread -- confirm one GpuMetric instance is never timed from two
    threads at once.
  * NOTE(review): the explicit `gpuMetric != NoopMetric` checks are gone;
    presumably NoopMetric.add is a no-op so this is harmless -- confirm.

 .../com/nvidia/spark/rapids/GpuExec.scala     | 34 ++++++++++---
 .../nvidia/spark/rapids/NvtxWithMetrics.scala | 50 ++-----------------
 .../nvidia/spark/rapids/MetricsSuite.scala    | 33 +++++++++++-
 3 files changed, 63 insertions(+), 54 deletions(-)

diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuExec.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuExec.scala
index 733cb2cd3d9..d83f20113b2 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuExec.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuExec.scala
@@ -152,16 +152,34 @@ sealed abstract class GpuMetric extends Serializable {
   def +=(v: Long): Unit
   def add(v: Long): Unit
 
+  private var isTimerActive = false
+
+  final def tryActivateTimer(): Boolean = {
+    if (!isTimerActive) {
+      isTimerActive = true
+      true
+    } else {
+      false
+    }
+  }
+
+  final def deactivateTimer(duration: Long): Unit = {
+    if (isTimerActive) {
+      isTimerActive = false
+      add(duration)
+    }
+  }
+
   final def ns[T](f: => T): T = {
-    val needTrack = ThreadLocalMetrics.onMetricsEnter(this)
-    val start = System.nanoTime()
-    try {
-      f
-    } finally {
-      if (needTrack) {
-        add(System.nanoTime() - start)
-        ThreadLocalMetrics.onMetricsExit(this)
+    if (tryActivateTimer()) {
+      val start = System.nanoTime()
+      try {
+        f
+      } finally {
+        deactivateTimer(System.nanoTime() - start)
       }
+    } else {
+      f
     }
   }
 }
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/NvtxWithMetrics.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/NvtxWithMetrics.scala
index 1ff3e0b0b84..538f117e50f 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/NvtxWithMetrics.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/NvtxWithMetrics.scala
@@ -16,8 +16,6 @@
 
 package com.nvidia.spark.rapids
 
-import scala.collection.mutable
-
 import ai.rapids.cudf.{NvtxColor, NvtxRange}
 
 object NvtxWithMetrics {
@@ -29,58 +27,21 @@ object NvtxWithMetrics {
   }
 }
 
-object ThreadLocalMetrics {
-  val addressOrdering: Ordering[GpuMetric] = Ordering.by(System.identityHashCode(_))
-
-  val currentThreadMetrics = new ThreadLocal[mutable.TreeSet[GpuMetric]] {
-    override def initialValue(): mutable.TreeSet[GpuMetric] =
-      mutable.TreeSet[GpuMetric]()(addressOrdering)
-  }
-
-  /**
-   * Check if current metric needs tracking.
-   *
-   * @param gpuMetric the metric to check
-   * @return true if the metric needs tracking,
-   */
-  def onMetricsEnter(gpuMetric: GpuMetric): Boolean = {
-    if (gpuMetric != NoopMetric) {
-      if (ThreadLocalMetrics.currentThreadMetrics.get().contains(gpuMetric)) {
-        return false
-      }
-      ThreadLocalMetrics.currentThreadMetrics.get().add(gpuMetric)
-      true
-    } else {
-      false
-    }
-  }
-
-  def onMetricsExit(gpuMetric: GpuMetric): Unit = {
-    if (gpuMetric != NoopMetric) {
-      if (!ThreadLocalMetrics.currentThreadMetrics.get().contains(gpuMetric)) {
-        throw new IllegalStateException("Metric missing from thread local storage: "
-          + gpuMetric)
-      }
-      ThreadLocalMetrics.currentThreadMetrics.get().remove(gpuMetric)
-    }
-  }
-}
 /**
  * NvtxRange with option to pass one or more nano timing metric(s) that are updated upon close
  * by the amount of time spent in the range
  */
 class NvtxWithMetrics(name: String, color: NvtxColor, val metrics: GpuMetric*)
-  extends NvtxRange(name, color) {
+    extends NvtxRange(name, color) {
 
-  val needTracks = metrics.map(ThreadLocalMetrics.onMetricsEnter)
+  val needTracks = metrics.map(_.tryActivateTimer())
   private val start = System.nanoTime()
 
   override def close(): Unit = {
     val time = System.nanoTime() - start
     metrics.toSeq.zip(needTracks).foreach { pair =>
       if (pair._2) {
-        pair._1 += time
-        ThreadLocalMetrics.onMetricsExit(pair._1)
+        pair._1.deactivateTimer(time)
       }
     }
     super.close()
@@ -88,15 +49,14 @@ class NvtxWithMetrics(name: String, color: NvtxColor, val metrics: GpuMetric*)
 }
 
 class MetricRange(val metrics: GpuMetric*) extends AutoCloseable {
-  val needTracks = metrics.map(ThreadLocalMetrics.onMetricsEnter)
+  val needTracks = metrics.map(_.tryActivateTimer())
   private val start = System.nanoTime()
 
   override def close(): Unit = {
     val time = System.nanoTime() - start
     metrics.toSeq.zip(needTracks).foreach { pair =>
       if (pair._2) {
-        pair._1 += time
-        ThreadLocalMetrics.onMetricsExit(pair._1)
+        pair._1.deactivateTimer(time)
       }
     }
   }
diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/MetricsSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/MetricsSuite.scala
index 529348d33d3..580c5a2ed55 100644
--- a/tests/src/test/scala/com/nvidia/spark/rapids/MetricsSuite.scala
+++ b/tests/src/test/scala/com/nvidia/spark/rapids/MetricsSuite.scala
@@ -16,12 +16,25 @@
 
 package com.nvidia.spark.rapids
 
+import ai.rapids.cudf.NvtxColor
 import com.nvidia.spark.rapids.Arm.withResource
 import org.scalatest.funsuite.AnyFunSuite
 
 class MetricsSuite extends AnyFunSuite {
 
-  test("duplicate timing on the same metrics") {
+  test("GpuMetric.ns: duplicate timing on the same metrics") {
+    val m1 = new LocalGpuMetric()
+    m1.ns(
+      m1.ns(
+        Thread.sleep(100)
+      )
+    )
+    // if the timing is duplicated, the value should be around 200,000,000
+    assert(m1.value < 100000000 * 1.5)
+    assert(m1.value > 100000000 * 0.5)
+  }
+
+  test("MetricRange: duplicate timing on the same metrics") {
     val m1 = new LocalGpuMetric()
     val m2 = new LocalGpuMetric()
     withResource(new MetricRange(m1, m2)) { _ =>
@@ -32,6 +45,24 @@ class MetricsSuite extends AnyFunSuite {
 
     // if the timing is duplicated, the value should be around 200,000,000
     assert(m1.value < 100000000 * 1.5)
+    assert(m1.value > 100000000 * 0.5)
+    assert(m2.value < 100000000 * 1.5)
+    assert(m2.value > 100000000 * 0.5)
+  }
+
+  test("NvtxWithMetrics: duplicate timing on the same metrics") {
+    val m1 = new LocalGpuMetric()
+    val m2 = new LocalGpuMetric()
+    withResource(new NvtxWithMetrics("a", NvtxColor.BLUE, m1, m2)) { _ =>
+      withResource(new NvtxWithMetrics("b", NvtxColor.BLUE, m2, m1)) { _ =>
+        Thread.sleep(100)
+      }
+    }
+
+    // if the timing is duplicated, the value should be around 200,000,000
+    assert(m1.value < 100000000 * 1.5)
+    assert(m1.value > 100000000 * 0.5)
     assert(m2.value < 100000000 * 1.5)
+    assert(m2.value > 100000000 * 0.5)
   }
 }