From e13cd555d6e9514e765e79b6492912b87fec5263 Mon Sep 17 00:00:00 2001
From: Liangcai Li
Date: Fri, 8 Nov 2024 01:06:14 +0800
Subject: [PATCH] Add retry in sub hash join (#11706)

Signed-off-by: Firestarman
---
 .../execution/GpuSubPartitionHashJoin.scala | 16 +++++++++++++---
 1 file changed, 13 insertions(+), 3 deletions(-)

diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuSubPartitionHashJoin.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuSubPartitionHashJoin.scala
index fc4ad412dcc..0fea22356a2 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuSubPartitionHashJoin.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuSubPartitionHashJoin.scala
@@ -18,7 +18,7 @@ package org.apache.spark.sql.rapids.execution
 import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
 
-import com.nvidia.spark.rapids.{GpuBatchUtils, GpuColumnVector, GpuExpression, GpuHashPartitioningBase, GpuMetric, SpillableColumnarBatch, SpillPriorities, TaskAutoCloseableResource}
+import com.nvidia.spark.rapids.{GpuBatchUtils, GpuColumnVector, GpuExpression, GpuHashPartitioningBase, GpuMetric, RmmRapidsRetryIterator, SpillableColumnarBatch, SpillPriorities, TaskAutoCloseableResource}
 import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource}
 import com.nvidia.spark.rapids.RapidsPluginImplicits._
 
@@ -179,9 +179,19 @@ class GpuBatchSubPartitioner(
     // 1) Hash partition on the batch
     val partedTable = GpuHashPartitioningBase.hashPartitionAndClose(
       gpuBatch, inputBoundKeys, realNumPartitions, "Sub-Hash Calculate", hashSeed)
+    val (spillBatch, partitions) = withResource(partedTable) { _ =>
+      // Convert to SpillableColumnarBatch for the following retry.
+      (SpillableColumnarBatch(GpuColumnVector.from(partedTable.getTable, types),
+        SpillPriorities.ACTIVE_BATCHING_PRIORITY),
+        partedTable.getPartitions)
+    }
     // 2) Split into smaller tables according to partitions
-    val subTables = withResource(partedTable) { _ =>
-      partedTable.getTable.contiguousSplit(partedTable.getPartitions.tail: _*)
+    val subTables = RmmRapidsRetryIterator.withRetryNoSplit(spillBatch) { _ =>
+      withResource(spillBatch.getColumnarBatch()) { cb =>
+        withResource(GpuColumnVector.from(cb)) { tbl =>
+          tbl.contiguousSplit(partitions.tail: _*)
+        }
+      }
     }
     // 3) Make each smaller table spillable and cache them in the queue
     withResource(subTables) { _ =>
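
Note (not part of the patch): the second hunk follows the retry pattern visible in the diff itself: make the GPU input spillable first, then run the work inside RmmRapidsRetryIterator.withRetryNoSplit so it can be replayed if a retry-able GPU OOM occurs, with the input free to spill between attempts. Below is a minimal sketch of that pattern for reference. The helper name splitWithRetry and its parameters are hypothetical; only the spark-rapids/cudf calls that appear in the diff are used.

// Illustrative sketch only; mirrors the hunk above, not an official API.
import ai.rapids.cudf.ContiguousTable
import com.nvidia.spark.rapids.{GpuColumnVector, RmmRapidsRetryIterator, SpillableColumnarBatch}
import com.nvidia.spark.rapids.Arm.withResource

def splitWithRetry(
    spillBatch: SpillableColumnarBatch,
    splitIndices: Array[Int]): Array[ContiguousTable] = {
  // withRetryNoSplit owns spillBatch and re-runs the body on a retry-able OOM;
  // keeping the input spillable lets it move off the GPU between attempts.
  RmmRapidsRetryIterator.withRetryNoSplit(spillBatch) { _ =>
    // Re-materialize the batch on each attempt, since it may have been spilled.
    withResource(spillBatch.getColumnarBatch()) { cb =>
      withResource(GpuColumnVector.from(cb)) { tbl =>
        tbl.contiguousSplit(splitIndices: _*)
      }
    }
  }
}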