diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/NvtxWithMetrics.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/NvtxWithMetrics.scala index 69c41c7dc61..27b3beb3ba6 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/NvtxWithMetrics.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/NvtxWithMetrics.scala @@ -58,7 +58,7 @@ object ThreadLocalMetrics { def onMetricsExit(gpuMetric: GpuMetric): Unit = { if (gpuMetric != NoopMetric) { if (!ThreadLocalMetrics.currentThreadMetrics.get().contains(gpuMetric)) { - throw new IllegalArgumentException("Metric missing from thread local storage: " + throw new IllegalStateException()("Metric missing from thread local storage: " + gpuMetric) } ThreadLocalMetrics.currentThreadMetrics.get().remove(gpuMetric) diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/MetricsSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/MetricsSuite.scala new file mode 100644 index 00000000000..529348d33d3 --- /dev/null +++ b/tests/src/test/scala/com/nvidia/spark/rapids/MetricsSuite.scala @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids + +import com.nvidia.spark.rapids.Arm.withResource +import org.scalatest.funsuite.AnyFunSuite + +class MetricsSuite extends AnyFunSuite { + + test("duplicate timing on the same metrics") { + val m1 = new LocalGpuMetric() + val m2 = new LocalGpuMetric() + withResource(new MetricRange(m1, m2)) { _ => + withResource(new MetricRange(m2, m1)) { _ => + Thread.sleep(100) + } + } + + // if the timing is duplicated, the value should be around 200,000,000 + assert(m1.value < 100000000 * 1.5) + assert(m2.value < 100000000 * 1.5) + } +}