diff --git a/integration_tests/src/main/python/aqe_test.py b/integration_tests/src/main/python/aqe_test.py
index b7968f8e902..ba0553912d4 100755
--- a/integration_tests/src/main/python/aqe_test.py
+++ b/integration_tests/src/main/python/aqe_test.py
@@ -298,3 +298,40 @@ def do_it(spark):
 
     assert_gpu_and_cpu_are_equal_collect(do_it, conf=bhj_disable_conf)
+
+# See https://github.com/NVIDIA/spark-rapids/issues/10645. Sometimes the exchange can provide multiple
+# batches, so we need to coalesce them into a single batch for the broadcast hash join.
+@ignore_order(local=True)
+@pytest.mark.skipif(not (is_databricks_runtime()), \
+    reason="Executor side broadcast only supported on Databricks")
+def test_aqe_join_executor_broadcast_enforce_single_batch():
+    # Use a small batch size to check whether Databricks sends multiple batches
+    conf = copy_and_update(_adaptive_conf, { "spark.rapids.sql.batchSizeBytes": "25" })
+    def prep(spark):
+        id_gen = RepeatSeqGen(IntegerGen(nullable=False), length=250)
+        name_gen = RepeatSeqGen(["Adam", "Bob", "Cathy"], data_type=StringType())
+        school_gen = RepeatSeqGen(["School1", "School2", "School3"], data_type=StringType())
+
+        df = gen_df(spark, StructGen([('id', id_gen), ('name', name_gen)], nullable=False), length=1000)
+        df.createOrReplaceTempView("df")
+
+        df_school = gen_df(spark, StructGen([('id', id_gen), ('school', school_gen)], nullable=False), length=250)
+        df_school.createOrReplaceTempView("df_school")
+
+    with_cpu_session(prep)
+
+    def do_it(spark):
+        res = spark.sql(
+            """
+            select /*+ BROADCAST(df_school) */ * from df, df_school where df.id == df_school.id
+            """
+        )
+        res.explain()
+        return res
+    # Ensure this is an EXECUTOR_BROADCAST
+    assert_cpu_and_gpu_are_equal_collect_with_capture(
+        do_it,
+        exist_classes="GpuShuffleExchangeExec,GpuBroadcastHashJoinExec",
+        non_exist_classes="GpuBroadcastExchangeExec",
+        conf=conf)
+
diff --git a/sql-plugin/src/main/spark330db/scala/org/apache/spark/sql/rapids/execution/GpuExecutorBroadcastHelper.scala b/sql-plugin/src/main/spark330db/scala/org/apache/spark/sql/rapids/execution/GpuExecutorBroadcastHelper.scala
index 2522de85169..5fabf05069f 100644
--- a/sql-plugin/src/main/spark330db/scala/org/apache/spark/sql/rapids/execution/GpuExecutorBroadcastHelper.scala
+++ b/sql-plugin/src/main/spark330db/scala/org/apache/spark/sql/rapids/execution/GpuExecutorBroadcastHelper.scala
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2023, NVIDIA CORPORATION.
+ * Copyright (c) 2023-2024, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -21,7 +21,7 @@ spark-rapids-shim-json-lines ***/
 package org.apache.spark.sql.rapids.execution
 
-import com.nvidia.spark.rapids.{ConcatAndConsumeAll, GpuColumnVector, GpuMetric, GpuShuffleCoalesceIterator, HostShuffleCoalesceIterator}
+import com.nvidia.spark.rapids.{ConcatAndConsumeAll, GpuCoalesceIterator, GpuColumnVector, GpuMetric, GpuShuffleCoalesceIterator, HostShuffleCoalesceIterator, NoopMetric, RequireSingleBatch}
 import com.nvidia.spark.rapids.Arm.withResource
 
 import org.apache.spark.TaskContext
@@ -40,6 +40,7 @@ import org.apache.spark.sql.vectorized.ColumnarBatch
  * which means they require a format of the data that can be used on the GPU.
  */
 object GpuExecutorBroadcastHelper {
+  import GpuMetric._
 
   // This reads the shuffle data that we have retrieved using `getShuffleRDD` from the shuffle
   // exchange. WARNING: Do not use this method outside of this context. This method can only be
@@ -65,11 +66,28 @@ object GpuExecutorBroadcastHelper {
     // Use the GPU Shuffle Coalesce iterator to concatenate and load batches onto the
     // host as needed. Since we don't have GpuShuffleCoalesceExec in the plan for the
     // executor broadcast scenario, we have to use that logic here to efficiently
-    // grab and release the semaphore while doing I/O
+    // grab and release the semaphore while doing I/O. We wrap this with GpuCoalesceIterator
+    // to ensure this is always a single batch for the following step.
+    val shuffleMetrics = Map(
+      CONCAT_TIME -> metricsMap(CONCAT_TIME),
+      OP_TIME -> metricsMap(OP_TIME)
+    ).withDefaultValue(NoopMetric)
+
     val iter = shuffleDataIterator(shuffleData)
-    new GpuShuffleCoalesceIterator(
-      new HostShuffleCoalesceIterator(iter, targetSize, metricsMap),
-      dataTypes, metricsMap).asInstanceOf[Iterator[ColumnarBatch]]
+    new GpuCoalesceIterator(
+      new GpuShuffleCoalesceIterator(
+        new HostShuffleCoalesceIterator(iter, targetSize, shuffleMetrics),
+        dataTypes, shuffleMetrics).asInstanceOf[Iterator[ColumnarBatch]],
+      dataTypes,
+      RequireSingleBatch,
+      NoopMetric, // numInputRows
+      NoopMetric, // numInputBatches
+      NoopMetric, // numOutputRows
+      NoopMetric, // numOutputBatches
+      NoopMetric, // collectTime
+      metricsMap(CONCAT_TIME), // concatTime
+      metricsMap(OP_TIME), // opTime
+      "GpuBroadcastHashJoinExec").asInstanceOf[Iterator[ColumnarBatch]]
   }
 
   /**
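
For context on the Scala change above: the executor-broadcast hash join consumes its build side as exactly one ColumnarBatch, but with a small spark.rapids.sql.batchSizeBytes the shuffle read can surface several batches, which is what issue #10645 hit. Wrapping the shuffle-coalesce iterators in GpuCoalesceIterator with the RequireSingleBatch goal forces a single concatenated output. The following is a minimal, self-contained sketch of that wrapping idea in plain Scala; Batch, concat, and SingleBatchIterator are hypothetical stand-ins for illustration, not the plugin's ColumnarBatch or GpuCoalesceIterator.

object SingleBatchSketch {
  // Hypothetical stand-in for a columnar batch: just a vector of rows.
  final case class Batch(rows: Vector[Int])

  // Hypothetical stand-in for the real GPU-side concatenation of batches.
  def concat(batches: Seq[Batch]): Batch =
    Batch(batches.iterator.flatMap(_.rows).toVector)

  // Drain an upstream iterator of batches and expose exactly one concatenated
  // batch, mirroring the RequireSingleBatch goal: a downstream consumer that
  // requires a single batch never observes a partial build side.
  final class SingleBatchIterator(upstream: Iterator[Batch]) extends Iterator[Batch] {
    private var consumed = false

    override def hasNext: Boolean = !consumed && upstream.hasNext

    override def next(): Batch = {
      if (!hasNext) throw new NoSuchElementException("no batches to coalesce")
      consumed = true
      concat(upstream.toVector) // pull everything before handing out one batch
    }
  }

  def main(args: Array[String]): Unit = {
    // Three small "shuffle" batches come out as one broadcast-ready batch.
    val batches = Iterator(Batch(Vector(1)), Batch(Vector(2, 3)), Batch(Vector(4)))
    val single = new SingleBatchIterator(batches).next()
    assert(single.rows == Vector(1, 2, 3, 4))
  }
}

One design note visible in the diff itself: only CONCAT_TIME and OP_TIME are forwarded to the coalesce step, while the row and batch counters are passed as NoopMetric, presumably because the join operator already reports those counts and repeating them here would double-count.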