
Commit 58d3f9c: Address comments
winningsix committed Nov 16, 2023 (1 parent: 1364b97)
Showing 5 changed files with 9 additions and 49 deletions.
@@ -71,18 +71,18 @@ class GpuBroadcastNestedLoopJoinMeta(
if (!leftExpr.isEmpty) GpuProjectExec(leftExpr ++ left.output, left)(true) else left
val rightChild =
if (!rightExpr.isEmpty) GpuProjectExec(rightExpr ++ right.output, right)(true) else right
- val postBoardcastCondition =
+ val postBuildCondition =
if (gpuBuildSide == GpuBuildLeft) leftExpr ++ left.output else rightExpr ++ right.output

- // TODO: a code refactor is needed to skip passing in postBoardcastCondition as a parameter to
+ // TODO: a code refactor is needed to skip passing in postBuildCondition as a parameter to
// instantiate GpuBroadcastNestedLoopJoinExec. This is because currently output columnar batch
// of broadcast side is handled inside GpuBroadcastNestedLoopJoinExec. Have to manually build
// a project node to build side batch.
val joinExec = GpuBroadcastNestedLoopJoinExec(
leftChild, rightChild,
join.joinType, gpuBuildSide,
remains,
- postBoardcastCondition,
+ postBuildCondition,
conf.gpuTargetBatchSizeBytes)
if (leftExpr.isEmpty && rightExpr.isEmpty) {
joinExec
@@ -135,7 +135,7 @@ case class GpuBroadcastNestedLoopJoinExec(
joinType: JoinType,
gpuBuildSide: GpuBuildSide,
condition: Option[Expression],
- postBroadcastCondition: List[NamedExpression],
+ postBuildCondition: List[NamedExpression],
targetSizeBytes: Long) extends GpuBroadcastNestedLoopJoinExecBase(
- left, right, joinType, gpuBuildSide, condition, postBroadcastCondition, targetSizeBytes
+ left, right, joinType, gpuBuildSide, condition, postBuildCondition, targetSizeBytes
)
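The TODO above explains why postBuildCondition is still threaded through the constructor: the broadcast (build) side batch is produced inside GpuBroadcastNestedLoopJoinExec, so the projection cannot yet be a separate plan node. The sketch below illustrates the two shapes with hypothetical, simplified plan types (plain Scala, not the real spark-rapids classes): carrying the build-side projection as a constructor parameter versus wrapping the build side in an explicit project node, which is what the suggested refactor would allow.

object PostBuildConditionSketch {
  // Hypothetical stand-ins for plan nodes; illustration only.
  sealed trait Plan { def output: Seq[String] }
  final case class Scan(output: Seq[String]) extends Plan
  // Stand-in for a project node (GpuProjectExec in the plugin).
  final case class Project(exprs: Seq[String], child: Plan) extends Plan {
    def output: Seq[String] = exprs
  }

  // Current shape: the join carries the build-side projection as an extra
  // parameter and applies it internally to the broadcast batch.
  final case class JoinWithPostBuild(
      streamed: Plan, build: Plan, postBuildCondition: Seq[String]) extends Plan {
    def output: Seq[String] = streamed.output ++ postBuildCondition
  }

  // Refactored shape: the projection is an explicit node wrapping the build
  // side, so the join needs no extra constructor parameter.
  final case class JoinPlain(streamed: Plan, build: Plan) extends Plan {
    def output: Seq[String] = streamed.output ++ build.output
  }

  def main(args: Array[String]): Unit = {
    val build = Scan(Seq("b1", "b2"))
    val streamed = Scan(Seq("s1"))
    val current = JoinWithPostBuild(streamed, build, Seq("b_expr", "b1", "b2"))
    val refactored = JoinPlain(streamed, Project(Seq("b_expr", "b1", "b2"), build))
    println(current.output)     // List(s1, b_expr, b1, b2)
    println(refactored.output)  // List(s1, b_expr, b1, b2)
  }
}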
@@ -134,10 +134,10 @@ case class GpuBroadcastNestedLoopJoinExec(
joinType: JoinType,
gpuBuildSide: GpuBuildSide,
condition: Option[Expression],
- postBroadcastCondition: List[NamedExpression],
+ postBuildCondition: List[NamedExpression],
targetSizeBytes: Long,
executorBroadcast: Boolean) extends GpuBroadcastNestedLoopJoinExecBase(
- left, right, joinType, gpuBuildSide, condition, postBroadcastCondition, targetSizeBytes
+ left, right, joinType, gpuBuildSide, condition, postBuildCondition, targetSizeBytes
) {
import GpuMetric._

@@ -124,7 +124,7 @@ case class GpuBroadcastNestedLoopJoinExec(
joinType: JoinType,
gpuBuildSide: GpuBuildSide,
condition: Option[Expression],
- postBroadcastCondition: List[NamedExpression],
+ postBuildCondition: List[NamedExpression],
targetSizeBytes: Long) extends GpuBroadcastNestedLoopJoinExecBase(
- left, right, joinType, gpuBuildSide, condition, postBroadcastCondition, targetSizeBytes
+ left, right, joinType, gpuBuildSide, condition, postBuildCondition, targetSizeBytes
)
16 changes: 0 additions & 16 deletions tests/src/test/scala/com/nvidia/spark/rapids/JoinsSuite.scala
@@ -21,7 +21,6 @@ import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.ExistenceJoin
import org.apache.spark.sql.catalyst.plans.logical.{BROADCAST, HintInfo, Join, JoinHint}
- import org.apache.spark.sql.functions.{length, lower, trim}
import org.apache.spark.sql.rapids.TestTrampolineUtil
import org.apache.spark.sql.types.BooleanType

@@ -103,21 +102,6 @@ class JoinsSuite extends SparkQueryCompareTestSuite {
(A, B) => A.join(B, A("longs") === B("longs"), "LeftAnti")
}

- IGNORE_ORDER_testSparkResultsAreEqual2(
- "join condition pushing down for AST non-supported case in outer join",
- stringWithTailingSpaces, stringWithTailingSpaces2, conf = new SparkConf()) {
- (A, B) =>
- A.join(
- B, length(lower(trim(A("name")))) < length(lower(trim(B("name")))), "leftouter")
- }
-
- IGNORE_ORDER_testSparkResultsAreEqual2(
- "single side join condition pushing down for AST non-supported case in outer join",
- stringWithTailingSpaces, stringWithTailingSpaces2, conf = new SparkConf()) {
- (A, B) =>
- A.join(B, length(lower(trim(A("name")))) < B("number_int"), "leftouter")
- }

for (buildRight <- Seq(false, true)) {
for (leftEmpty <- Seq(false, true)) {
for (rightEmpty <- Seq(false, true)) {
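The two removed tests exercised left outer joins whose conditions use nested string functions (length, lower, trim) rather than simple equi-join keys, the cases their names describe as not supported by AST condition pushdown. A standalone sketch of equivalent joins in plain Spark, assuming only a local SparkSession and illustrative data rather than the removed harness generators:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{length, lower, trim}

object NonAstJoinConditionSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("non-ast-join-sketch").getOrCreate()
    import spark.implicits._

    // Illustrative stand-ins for the removed stringWithTailingSpaces generators.
    val a = Seq(("Foo   ", 1.0), ("Barr  ", 2.0), ("ThuddD", 6.0)).toDF("name", "number_double")
    val b = Seq(("Foo2   ", 1), ("Barr2  ", 3), ("ThuddD2", 6)).toDF("name", "number_int")

    // Condition with string functions on both sides, as in the first removed test.
    a.join(b, length(lower(trim(a("name")))) < length(lower(trim(b("name")))), "leftouter").show()

    // Single-sided condition, as in the second removed test.
    a.join(b, length(lower(trim(a("name")))) < b("number_int"), "leftouter").show()

    spark.stop()
  }
}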
@@ -1831,30 +1831,6 @@ trait SparkQueryCompareTestSuite extends AnyFunSuite with BeforeAndAfterAll {
).toDF("doubles", "more_doubles")
}

- def stringWithTailingSpaces(session: SparkSession): DataFrame = {
- import session.sqlContext.implicits._
- Seq[(java.lang.String, java.lang.Double)](
- ("Foo ", 1.0d),
- ("Barr ", Double.NaN),
- ("BAZZZ ", 3.0d),
- ("QuxXx ", 4.0d),
- ("Freed ", Double.NaN),
- ("ThuddD", 6.0d)
- ).toDF("name", "number_double")
- }
-
- def stringWithTailingSpaces2(session: SparkSession): DataFrame = {
- import session.sqlContext.implicits._
- Seq[(java.lang.String, java.lang.Long)](
- ("Foo2 ", 1),
- ("Barr2 ", null),
- ("BAZZZ2 ", 3),
- ("QuxXx2 ", 4),
- ("Freed2 ", null),
- ("ThuddD2", 6)
- ).toDF("name", "number_int")
- }

def decimals(session: SparkSession): DataFrame = {
import session.sqlContext.implicits._
Seq[(String, BigDecimal)](
