diff --git a/sql-plugin/src/main/spark330db/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastNestedLoopJoinExec.scala b/sql-plugin/src/main/spark330db/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastNestedLoopJoinExec.scala index 6c04a2aeb57..85acd57c1a9 100644 --- a/sql-plugin/src/main/spark330db/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastNestedLoopJoinExec.scala +++ b/sql-plugin/src/main/spark330db/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastNestedLoopJoinExec.scala @@ -28,7 +28,7 @@ import com.nvidia.spark.rapids.Arm.withResource import org.apache.spark.broadcast.Broadcast import org.apache.spark.rapids.shims.GpuShuffleExchangeExec import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions.{Expression, NamedExpression} import org.apache.spark.sql.catalyst.plans.{ExistenceJoin, InnerLike, JoinType, LeftAnti, LeftOuter, LeftSemi, RightOuter} import org.apache.spark.sql.execution.{CoalescedPartitionSpec, SparkPlan} import org.apache.spark.sql.execution.adaptive.ShuffleQueryStageExec @@ -53,29 +53,66 @@ class GpuBroadcastNestedLoopJoinMeta( } verifyBuildSideWasReplaced(buildSide) - val condition = conditionMeta.map(_.convertToGpu()) - val isAstCondition = conditionMeta.forall(_.canThisBeAst) - join.joinType match { - case _: InnerLike => - case LeftOuter | LeftSemi | LeftAnti if gpuBuildSide == GpuBuildLeft => - throw new IllegalStateException(s"Unsupported build side for join type ${join.joinType}") - case RightOuter if gpuBuildSide == GpuBuildRight => - throw new IllegalStateException(s"Unsupported build side for join type ${join.joinType}") - case LeftOuter | RightOuter | LeftSemi | LeftAnti | ExistenceJoin(_) => - // Cannot post-filter these types of joins - assert(isAstCondition, s"Non-AST condition in ${join.joinType}") - case _ => throw new IllegalStateException(s"Unsupported join type ${join.joinType}") - } + // If ast-able, try to split if needed. Otherwise, do post-filter + val isAstCondition = canJoinCondAstAble() - val joinExec = GpuBroadcastNestedLoopJoinExec( - left, right, - join.joinType, gpuBuildSide, - if (isAstCondition) condition else None, - conf.gpuTargetBatchSizeBytes, - join.isExecutorBroadcast) if (isAstCondition) { - joinExec + // Try to extract non-ast-able conditions from join conditions + val (remains, leftExpr, rightExpr) = AstUtil.extractNonAstFromJoinCond(conditionMeta, + left.output, right.output, true) + + // Reconstruct the child with wrapped project node if needed. + val leftChild = + if (!leftExpr.isEmpty) GpuProjectExec(leftExpr ++ left.output, left)(true) else left + val rightChild = + if (!rightExpr.isEmpty) GpuProjectExec(rightExpr ++ right.output, right)(true) else right + val postBuildCondition = + if (gpuBuildSide == GpuBuildLeft) leftExpr ++ left.output else rightExpr ++ right.output + + // TODO: a code refactor is needed to skip passing in postBuildCondition as a parameter to + // instantiate GpuBroadcastNestedLoopJoinExec. This is because currently output columnar batch + // of broadcast side is handled inside GpuBroadcastNestedLoopJoinExec. Have to manually build + // a project node to build side batch. + val joinExec = GpuBroadcastNestedLoopJoinExec( + leftChild, rightChild, + join.joinType, gpuBuildSide, + remains, + postBuildCondition, + conf.gpuTargetBatchSizeBytes, + join.isExecutorBroadcast) + if (leftExpr.isEmpty && rightExpr.isEmpty) { + joinExec + } else { + // Remove the intermediate attributes from left and right side project nodes. Output + // attributes need to be updated based on types + GpuProjectExec( + GpuBroadcastNestedLoopJoinExecBase.output( + join.joinType, left.output, right.output).toList, + joinExec)(false) + } } else { + val condition = conditionMeta.map(_.convertToGpu()) + + join.joinType match { + case _: InnerLike => + case LeftOuter | LeftSemi | LeftAnti if gpuBuildSide == GpuBuildLeft => + throw new IllegalStateException(s"Unsupported build side for join type ${join.joinType}") + case RightOuter if gpuBuildSide == GpuBuildRight => + throw new IllegalStateException(s"Unsupported build side for join type ${join.joinType}") + case LeftOuter | RightOuter | LeftSemi | LeftAnti | ExistenceJoin(_) => + // Cannot post-filter these types of joins + assert(isAstCondition, s"Non-AST condition in ${join.joinType}") + case _ => throw new IllegalStateException(s"Unsupported join type ${join.joinType}") + } + + val joinExec = GpuBroadcastNestedLoopJoinExec( + left, right, + join.joinType, gpuBuildSide, + None, + List.empty, + conf.gpuTargetBatchSizeBytes, + join.isExecutorBroadcast) + // condition cannot be implemented via AST so fallback to a post-filter if necessary condition.map { // TODO: Restore batch coalescing logic here. @@ -97,9 +134,10 @@ case class GpuBroadcastNestedLoopJoinExec( joinType: JoinType, gpuBuildSide: GpuBuildSide, condition: Option[Expression], + postBroadcastCondition: List[NamedExpression], targetSizeBytes: Long, executorBroadcast: Boolean) extends GpuBroadcastNestedLoopJoinExecBase( - left, right, joinType, gpuBuildSide, condition, targetSizeBytes + left, right, joinType, gpuBuildSide, condition, postBroadcastCondition, targetSizeBytes ) { import GpuMetric._ @@ -166,7 +204,7 @@ case class GpuBroadcastNestedLoopJoinExec( } } - override def makeBuiltBatch( + override def makeBuiltBatchInternal( relation: Any, buildTime: GpuMetric, buildDataSize: GpuMetric): ColumnarBatch = {