Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update Databricks shim GpuBroadcastNestedLoopJoinExec for AST splitting change [databricks] #9696

Closed
wants to merge 2 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ import com.nvidia.spark.rapids.Arm.withResource
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rapids.shims.GpuShuffleExchangeExec
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.expressions.{Expression, NamedExpression}
import org.apache.spark.sql.catalyst.plans.{ExistenceJoin, InnerLike, JoinType, LeftAnti, LeftOuter, LeftSemi, RightOuter}
import org.apache.spark.sql.execution.{CoalescedPartitionSpec, SparkPlan}
import org.apache.spark.sql.execution.adaptive.ShuffleQueryStageExec
Expand All @@ -53,29 +53,66 @@ class GpuBroadcastNestedLoopJoinMeta(
}
verifyBuildSideWasReplaced(buildSide)

val condition = conditionMeta.map(_.convertToGpu())
val isAstCondition = conditionMeta.forall(_.canThisBeAst)
join.joinType match {
case _: InnerLike =>
case LeftOuter | LeftSemi | LeftAnti if gpuBuildSide == GpuBuildLeft =>
throw new IllegalStateException(s"Unsupported build side for join type ${join.joinType}")
case RightOuter if gpuBuildSide == GpuBuildRight =>
throw new IllegalStateException(s"Unsupported build side for join type ${join.joinType}")
case LeftOuter | RightOuter | LeftSemi | LeftAnti | ExistenceJoin(_) =>
// Cannot post-filter these types of joins
assert(isAstCondition, s"Non-AST condition in ${join.joinType}")
case _ => throw new IllegalStateException(s"Unsupported join type ${join.joinType}")
}
// If ast-able, try to split if needed. Otherwise, do post-filter
val isAstCondition = canJoinCondAstAble()

val joinExec = GpuBroadcastNestedLoopJoinExec(
left, right,
join.joinType, gpuBuildSide,
if (isAstCondition) condition else None,
conf.gpuTargetBatchSizeBytes,
join.isExecutorBroadcast)
if (isAstCondition) {
joinExec
// Try to extract non-ast-able conditions from join conditions
val (remains, leftExpr, rightExpr) = AstUtil.extractNonAstFromJoinCond(conditionMeta,
left.output, right.output, true)

// Reconstruct the child with wrapped project node if needed.
val leftChild =
if (!leftExpr.isEmpty) GpuProjectExec(leftExpr ++ left.output, left)(true) else left
val rightChild =
if (!rightExpr.isEmpty) GpuProjectExec(rightExpr ++ right.output, right)(true) else right
val postBuildCondition =
if (gpuBuildSide == GpuBuildLeft) leftExpr ++ left.output else rightExpr ++ right.output

// TODO: a code refactor is needed to skip passing in postBuildCondition as a parameter to
// instantiate GpuBroadcastNestedLoopJoinExec. This is because currently output columnar batch
// of broadcast side is handled inside GpuBroadcastNestedLoopJoinExec. Have to manually build
// a project node to build side batch.
val joinExec = GpuBroadcastNestedLoopJoinExec(
leftChild, rightChild,
join.joinType, gpuBuildSide,
remains,
postBuildCondition,
conf.gpuTargetBatchSizeBytes,
join.isExecutorBroadcast)
if (leftExpr.isEmpty && rightExpr.isEmpty) {
joinExec
} else {
// Remove the intermediate attributes from left and right side project nodes. Output
// attributes need to be updated based on types
GpuProjectExec(
GpuBroadcastNestedLoopJoinExecBase.output(
join.joinType, left.output, right.output).toList,
joinExec)(false)
}
} else {
val condition = conditionMeta.map(_.convertToGpu())

join.joinType match {
case _: InnerLike =>
case LeftOuter | LeftSemi | LeftAnti if gpuBuildSide == GpuBuildLeft =>
throw new IllegalStateException(s"Unsupported build side for join type ${join.joinType}")
case RightOuter if gpuBuildSide == GpuBuildRight =>
throw new IllegalStateException(s"Unsupported build side for join type ${join.joinType}")
case LeftOuter | RightOuter | LeftSemi | LeftAnti | ExistenceJoin(_) =>
// Cannot post-filter these types of joins
assert(isAstCondition, s"Non-AST condition in ${join.joinType}")
case _ => throw new IllegalStateException(s"Unsupported join type ${join.joinType}")
}

val joinExec = GpuBroadcastNestedLoopJoinExec(
left, right,
join.joinType, gpuBuildSide,
None,
List.empty,
conf.gpuTargetBatchSizeBytes,
join.isExecutorBroadcast)

// condition cannot be implemented via AST so fallback to a post-filter if necessary
condition.map {
// TODO: Restore batch coalescing logic here.
Expand All @@ -97,9 +134,10 @@ case class GpuBroadcastNestedLoopJoinExec(
joinType: JoinType,
gpuBuildSide: GpuBuildSide,
condition: Option[Expression],
postBroadcastCondition: List[NamedExpression],
targetSizeBytes: Long,
executorBroadcast: Boolean) extends GpuBroadcastNestedLoopJoinExecBase(
left, right, joinType, gpuBuildSide, condition, targetSizeBytes
left, right, joinType, gpuBuildSide, condition, postBroadcastCondition, targetSizeBytes
) {
import GpuMetric._

Expand All @@ -118,7 +156,7 @@ case class GpuBroadcastNestedLoopJoinExec(
executorBroadcast
}

def shuffleExchange: GpuShuffleExchangeExec = buildPlan match {
def shuffleExchange: GpuShuffleExchangeExec = getBroadcastPlan(buildPlan) match {
case bqse: ShuffleQueryStageExec if bqse.plan.isInstanceOf[GpuShuffleExchangeExec] =>
bqse.plan.asInstanceOf[GpuShuffleExchangeExec]
case bqse: ShuffleQueryStageExec if bqse.plan.isInstanceOf[ReusedExchangeExec] =>
Expand All @@ -127,6 +165,15 @@ case class GpuBroadcastNestedLoopJoinExec(
case reused: ReusedExchangeExec => reused.child.asInstanceOf[GpuShuffleExchangeExec]
}

private[this] def getBroadcastPlan(plan: SparkPlan): SparkPlan = {
plan match {
// In case has post broadcast project. It happens when join condition contains non-AST
// expression which results in a project right after broadcast.
case plan: GpuProjectExec => plan.child
case _ => plan
}
}

override def getBroadcastRelation(): Any = {
if (executorBroadcast) {
// Get all the broadcast data from the shuffle coalesced into a single partition
Expand Down Expand Up @@ -166,7 +213,7 @@ case class GpuBroadcastNestedLoopJoinExec(
}
}

override def makeBuiltBatch(
override def makeBuiltBatchInternal(
relation: Any,
buildTime: GpuMetric,
buildDataSize: GpuMetric): ColumnarBatch = {
Expand Down