diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala index 5e1f45e6230..75721b40a77 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala @@ -18,6 +18,7 @@ package com.nvidia.spark.rapids import java.time.ZoneId +import scala.collection.mutable import scala.collection.mutable.ListBuffer import scala.reflect.ClassTag import scala.util.control.NonFatal @@ -53,7 +54,7 @@ import org.apache.spark.sql.execution.datasources.text.TextFileFormat import org.apache.spark.sql.execution.datasources.v2._ import org.apache.spark.sql.execution.datasources.v2.csv.CSVScan import org.apache.spark.sql.execution.datasources.v2.json.JsonScan -import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ENSURE_REQUIREMENTS, ReusedExchangeExec, ShuffleExchangeExec} +import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ENSURE_REQUIREMENTS, Exchange, ReusedExchangeExec, ShuffleExchangeExec} import org.apache.spark.sql.execution.joins._ import org.apache.spark.sql.execution.python._ import org.apache.spark.sql.execution.window.WindowExec @@ -4510,6 +4511,39 @@ case class GpuOverrides() extends Rule[SparkPlan] with Logging { // gets called once for each query stage (where a query stage is an `Exchange`). override def apply(sparkPlan: SparkPlan): SparkPlan = applyWithContext(sparkPlan, None) + private def lookAtReusedExchange(sparkPlan: SparkPlan): Unit = { + val exchanges = mutable.Map.empty[SparkPlan, Exchange] + sparkPlan.foreach { + case exchange: Exchange if conf.exchangeReuseEnabled => + val cachedExchange = exchanges.getOrElseUpdate(exchange.canonicalized, exchange) + if (cachedExchange.ne(exchange)) { + println( + s"""==>REUSED_EX_DEBUG: found an exchange: + | $exchange + | (Canonicalized: ${exchange.canonicalized}) + | can reuse the cached one: + | $cachedExchange + | (Canonicalized: ${cachedExchange.canonicalized}) + """.stripMargin) + } else { + if (exchanges.size > 1) { + println( + s"""==>REUSED_EX_DEBUG: found maybe a different exchange: + | $cachedExchange + | (Canonicalized: ${cachedExchange.canonicalized}) + """.stripMargin) + println( + s"""==>REUSED_EX_DEBUG: current map: + | $exchanges + |""".stripMargin) + } else { + // the first one + } + } + case _ => // ignore + } + } + def applyWithContext(sparkPlan: SparkPlan, context: Option[String]): SparkPlan = GpuOverrideUtil.tryOverride { plan => val conf = new RapidsConf(plan.conf) @@ -4524,6 +4558,7 @@ case class GpuOverrides() extends Rule[SparkPlan] with Logging { logWarning(s"${logPrefix}Transformed query:" + s"\nOriginal Plan:\n$plan\nTransformed Plan:\n$updatedPlan") } + lookAtReusedExchange(updatedPlan) updatedPlan } } else if (conf.isSqlEnabled && conf.isSqlExplainOnlyEnabled) {