Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revert "Support split non-AST-able join condition for BroadcastNested… [databricks] #9695

Merged
merged 1 commit into from
Nov 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 2 additions & 32 deletions integration_tests/src/main/python/join_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import pytest
from _pytest.mark.structures import ParameterSet
from pyspark.sql.functions import array_contains, broadcast, col
from pyspark.sql.functions import broadcast, col
from pyspark.sql.types import *
from asserts import assert_gpu_and_cpu_are_equal_collect, assert_gpu_fallback_collect, assert_cpu_and_gpu_are_equal_collect_with_capture
from conftest import is_databricks_runtime, is_emr_runtime
Expand Down Expand Up @@ -397,47 +397,17 @@ def do_join(spark):
return left.join(broadcast(right), left.a > f.log(right.r_a), join_type)
assert_gpu_and_cpu_are_equal_collect(do_join)

@ignore_order(local=True)
@pytest.mark.parametrize(
    'data_gen',
    [IntegerGen(), LongGen(), pytest.param(FloatGen(), marks=[incompat]), pytest.param(DoubleGen(), marks=[incompat])],
    ids=idfn)
@pytest.mark.parametrize('join_type', ['Cross', 'Left', 'LeftSemi', 'LeftAnti'], ids=idfn)
def test_broadcast_nested_loop_join_with_condition(data_gen, join_type):
    """Check that CPU and GPU agree on a join against a broadcast right side whose
    condition is built from round, cast and log expressions."""
    def do_join(spark):
        left, right = create_df(spark, data_gen, 50, 25)
        # AST does not support cast or logarithm yet; those sub-expressions are
        # supposed to be extracted into child nodes.
        # Join types not covered here, for reasons unrelated to AST:
        #   (1) Right: building on the right side is not supported
        #   (2) FullOuter: currently not supported
        # Compared with test_broadcast_nested_loop_join_with_condition_fallback,
        # this test:
        #   (1) converts double to integer, since AST currently does not support double
        #   (2) builds on the right side so that 'Left', 'LeftSemi' and 'LeftAnti'
        #       pass the join type checks
        left_operand = f.round(left.a).cast('integer')
        right_operand = f.round(f.log(right.r_a).cast('integer'))
        return left.join(broadcast(right), left_operand > right_operand, join_type)
    cast_conf = {"spark.rapids.sql.castFloatToIntegralTypes.enabled": True}
    assert_gpu_and_cpu_are_equal_collect(do_join, conf=cast_conf)

@allow_non_gpu('BroadcastExchangeExec', 'BroadcastNestedLoopJoinExec', 'Cast', 'GreaterThan', 'Log')
@ignore_order(local=True)
@pytest.mark.parametrize(
    'data_gen',
    [IntegerGen(), LongGen(), pytest.param(FloatGen(), marks=[incompat]), pytest.param(DoubleGen(), marks=[incompat])],
    ids=idfn)
@pytest.mark.parametrize('join_type', ['Left', 'Right', 'FullOuter', 'LeftSemi', 'LeftAnti'], ids=idfn)
def test_broadcast_nested_loop_join_with_condition_fallback(data_gen, join_type):
    """Check that a join with the left side broadcast and a log-based condition
    falls back from the GPU for BroadcastNestedLoopJoinExec."""
    def do_join(spark):
        left, right = create_df(spark, data_gen, 50, 25)
        # AST does not support the double type, which cannot be split into child nodes.
        # AST does not support cast or logarithm yet.
        condition = left.a > f.log(right.r_a)
        return broadcast(left).join(right, condition, join_type)
    assert_gpu_fallback_collect(do_join, 'BroadcastNestedLoopJoinExec')

@ignore_order(local=True)
@pytest.mark.parametrize(
    'data_gen',
    [byte_gen, short_gen, int_gen, long_gen,
     float_gen, double_gen,
     string_gen, boolean_gen, date_gen, timestamp_gen],
    ids=idfn)
@pytest.mark.parametrize('join_type', ['Left', 'Right', 'FullOuter', 'LeftSemi', 'LeftAnti'], ids=idfn)
def test_broadcast_nested_loop_join_with_array_contains(data_gen, join_type):
    """Check that CPU and GPU agree on a join with the left side broadcast whose
    condition compares array_contains results from both sides."""
    # NOTE(review): join_type is parametrized but never passed to join() below, so
    # every parametrization runs the same join with join()'s default join type.
    # Confirm whether it should be forwarded.
    arr_gen = ArrayGen(data_gen)
    literal = with_cpu_session(lambda spark: gen_scalar(data_gen))
    def do_join(spark):
        left, right = create_df(spark, arr_gen, 50, 25)
        # array_contains will be pushed down into project child nodes.
        left_contains = array_contains(left.a, literal.cast(data_gen.data_type))
        right_contains = array_contains(right.r_a, literal.cast(data_gen.data_type))
        return broadcast(left).join(right, left_contains < right_contains)
    assert_gpu_and_cpu_are_equal_collect(do_join)

@ignore_order(local=True)
@pytest.mark.parametrize('data_gen', all_gen, ids=idfn)
@pytest.mark.parametrize('join_type', ['Left', 'LeftSemi', 'LeftAnti'], ids=idfn)
Expand Down
122 changes: 0 additions & 122 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/AstUtil.scala

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -1115,15 +1115,6 @@ abstract class BaseExprMeta[INPUT <: Expression](
childExprs.forall(_.canThisBeAst) && cannotBeAstReasons.isEmpty
}

/**
 * Whether this node, taken on its own, can be converted to AST. Children are not
 * inspected. It is used to check join condition AST-ability in a top-down fashion.
 */
lazy val canSelfBeAst: Boolean = {
  // Run AST tagging for this node first; the reasons checked below are presumably
  // recorded by it.
  tagForAst()
  val hasNoBlockingReasons = cannotBeAstReasons.isEmpty
  hasNoBlockingReasons
}

final def requireAstForGpu(): Unit = {
tagForAst()
cannotBeAstReasons.foreach { reason =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -240,22 +240,6 @@ class GpuEquivalentExpressions {
}

object GpuEquivalentExpressions {
/**
 * Recursively swaps each expression that has an entry in `substitutionMap` for its
 * proxy expression. Lookups are keyed by `GpuExpressionEquals` (semantic equality,
 * per the original description of this method).
 *
 * @param expr root of the expression tree to rewrite
 * @param substitutionMap maps a wrapped expression to the proxy that replaces it
 * @return `expr` itself when it is an `AttributeReference`; the mapped proxy when
 *         one exists; otherwise `expr` with the same rewrite applied to its children
 */
def replaceWithSemanticCommonRef(
    expr: Expression,
    substitutionMap: mutable.HashMap[GpuExpressionEquals, Expression]): Expression = {
  expr match {
    // Attribute references are returned as-is without consulting the map.
    case attrRef: AttributeReference => attrRef
    case other =>
      // The default argument is by-name, so the children are only visited when
      // no proxy is registered for this expression.
      substitutionMap.getOrElse(
        GpuExpressionEquals(other),
        other.mapChildren(replaceWithSemanticCommonRef(_, substitutionMap)))
  }
}

/**
* Recursively replaces expression with its proxy expression in `substitutionMap`.
*/
Expand Down
Loading