From 00580a04420ff524ac7f3dafb49af1b771ed8cea Mon Sep 17 00:00:00 2001
From: sperlingxx
Date: Tue, 17 Dec 2024 14:44:04 +0900
Subject: [PATCH] fix

---
 .../nvidia/spark/rapids/GpuPartitioning.scala      | 16 ++++++++--------
 .../execution/GpuShuffleExchangeExecBase.scala     |  4 ++--
 2 files changed, 10 insertions(+), 10 deletions(-)

diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuPartitioning.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuPartitioning.scala
index d2338a91384..616e2df721b 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuPartitioning.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuPartitioning.scala
@@ -256,15 +256,15 @@ trait GpuPartitioning extends Partitioning {
   private var memCopyTime: Option[GpuMetric] = None
 
   /**
-   * Setup Spark SQL Metrics for the details of GpuPartition. This method is expected to be called
-   * at the query planning stage for only once.
+   * Setup sub-metrics for the performance debugging of GpuPartition. This method is expected to
+   * be called at the query planning stage. Therefore, this method is NOT thread safe.
    */
-  def setupMetrics(metrics: Map[String, GpuMetric]): Unit = {
-    metrics.get(GpuPartitioning.CopyToHostTime).foreach { metric =>
-      // Check and set GpuPartitioning.CopyToHostTime
-      require(memCopyTime.isEmpty,
-        s"The GpuMetric[${GpuPartitioning.CopyToHostTime}] has already been set")
-      memCopyTime = Some(metric)
+  def setupDebugMetrics(metrics: Map[String, GpuMetric]): Unit = {
+    // Check and set GpuPartitioning.CopyToHostTime
+    if (memCopyTime.isEmpty) {
+      metrics.get(GpuPartitioning.CopyToHostTime).foreach { metric =>
+        memCopyTime = Some(metric)
+      }
     }
   }
 }
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuShuffleExchangeExecBase.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuShuffleExchangeExecBase.scala
index 6fb78d85554..fa755de3dc9 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuShuffleExchangeExecBase.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuShuffleExchangeExecBase.scala
@@ -368,10 +368,10 @@ object GpuShuffleExchangeExecBase {
       rdd
     }
     val partitioner: GpuExpression = getPartitioner(newRdd, outputAttributes, newPartitioning)
-    // Inject detailed Metrics, such as D2HTime before SliceOnCpu
+    // Inject debugging subMetrics, such as D2HTime before SliceOnCpu
     // The injected metrics will be serialized as the members of GpuPartitioning
     partitioner match {
-      case pt: GpuPartitioning => pt.setupMetrics(additionalMetrics)
+      case pt: GpuPartitioning => pt.setupDebugMetrics(additionalMetrics)
       case _ =>
     }
     val partitionTime: GpuMetric = metrics(METRIC_SHUFFLE_PARTITION_TIME)
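
As a reading aid (not part of the patch itself), the following is a minimal standalone Scala sketch of the idempotent metric-injection pattern that the new setupDebugMetrics follows: the old require(...) check is replaced by a set-once guard, so calling the method again during re-planning is a no-op instead of a failure. Metric, DebugMetricsSketch, and the "d2hMemCopyTime" key value are simplified stand-ins for illustration only; they are not the plugin's real GpuMetric, classes, or constants.

// Minimal sketch of the set-once debug-metric injection introduced by this patch.
// All names here are illustrative stand-ins, not the plugin's actual definitions.
object DebugMetricsSketch {
  final case class Metric(name: String)

  // Assumed key name, standing in for GpuPartitioning.CopyToHostTime.
  val CopyToHostTime = "d2hMemCopyTime"

  trait Partitioning {
    // Populated once at query planning time; carried along with the partitioning object.
    private var memCopyTime: Option[Metric] = None

    // Idempotent setup: only the first call that actually carries the metric
    // takes effect, so repeated planning passes do not fail an assertion the
    // way the previous require(...)-based version did.
    def setupDebugMetrics(metrics: Map[String, Metric]): Unit = {
      if (memCopyTime.isEmpty) {
        metrics.get(CopyToHostTime).foreach { metric =>
          memCopyTime = Some(metric)
        }
      }
    }

    def copyToHostTime: Option[Metric] = memCopyTime
  }

  def main(args: Array[String]): Unit = {
    val partitioner = new Partitioning {}
    val additionalMetrics = Map(CopyToHostTime -> Metric("time spent copying to host"))
    partitioner.setupDebugMetrics(additionalMetrics)
    partitioner.setupDebugMetrics(additionalMetrics) // second call is a no-op, not an error
    println(partitioner.copyToHostTime)              // Some(Metric(time spent copying to host))
  }
}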