Skip to content

Commit

Permalink
Add in small optimization for instr comparison (#10584)
Browse files Browse the repository at this point in the history
* Add in small optimization for instr comparison

Signed-off-by: Robert (Bobby) Evans <[email protected]>

* Review Comments

---------

Signed-off-by: Robert (Bobby) Evans <[email protected]>
  • Loading branch information
revans2 authored Mar 15, 2024
1 parent e0ef44a commit 4393e9f
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 5 deletions.
15 changes: 14 additions & 1 deletion integration_tests/src/main/python/string_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,19 @@ def assert_gpu_did_fallback(sql_text):
assert_gpu_did_fallback('locate(a, a, pos)')
assert_gpu_did_fallback('locate(a, "a", pos)')

# There is no contains function exposed in Spark. You can turn it into a
# LIKE %FOO% or we have seen some use instr > 0 to do the same thing.
# Spark optimizes LIKE to be a contains, we also optimize instr to do
# something similar.
def test_instr_as_contains():
gen = mk_str_gen('.{0,3}Z_Z.{0,3}A.{0,3}')
assert_gpu_and_cpu_are_equal_collect(
lambda spark: unary_op_df(spark, gen).selectExpr(
'instr(a, "A") > 0',
'0 < instr(a, "A")',
'1 <= instr(a, "A")',
'instr(a, "A") >= 1',
'a LIKE "%A%"'))

def test_instr():
gen = mk_str_gen('.{0,3}Z_Z.{0,3}A.{0,3}')
Expand Down Expand Up @@ -861,4 +874,4 @@ def test_format_number_float_value():
gpu_results = list(map(lambda x: float(x[0].replace(",", "")), with_gpu_session(
lambda spark: unary_op_df(spark, data_gen).selectExpr('format_number(a, 5)').collect())))
for cpu, gpu in zip(cpu_results, gpu_results):
assert math.isclose(cpu, gpu, rel_tol=1e-7) or math.isclose(cpu, gpu, abs_tol=1.1e-5)
assert math.isclose(cpu, gpu, rel_tol=1e-7) or math.isclose(cpu, gpu, abs_tol=1.1e-5)
Original file line number Diff line number Diff line change
Expand Up @@ -1928,7 +1928,7 @@ object GpuOverrides extends Logging {
TypeSig.orderable)),
(a, conf, p, r) => new BinaryAstExprMeta[GreaterThan](a, conf, p, r) {
override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression =
GpuGreaterThan(lhs, rhs)
GpuStringInstr.optimizeContains(GpuGreaterThan(lhs, rhs))
}),
expr[GreaterThanOrEqual](
">= operator",
Expand All @@ -1943,7 +1943,7 @@ object GpuOverrides extends Logging {
TypeSig.orderable)),
(a, conf, p, r) => new BinaryAstExprMeta[GreaterThanOrEqual](a, conf, p, r) {
override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression =
GpuGreaterThanOrEqual(lhs, rhs)
GpuStringInstr.optimizeContains(GpuGreaterThanOrEqual(lhs, rhs))
}),
expr[In](
"IN operator",
Expand Down Expand Up @@ -1993,7 +1993,7 @@ object GpuOverrides extends Logging {
TypeSig.orderable)),
(a, conf, p, r) => new BinaryAstExprMeta[LessThan](a, conf, p, r) {
override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression =
GpuLessThan(lhs, rhs)
GpuStringInstr.optimizeContains(GpuLessThan(lhs, rhs))
}),
expr[LessThanOrEqual](
"<= operator",
Expand All @@ -2008,7 +2008,7 @@ object GpuOverrides extends Logging {
TypeSig.orderable)),
(a, conf, p, r) => new BinaryAstExprMeta[LessThanOrEqual](a, conf, p, r) {
override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression =
GpuLessThanOrEqual(lhs, rhs)
GpuStringInstr.optimizeContains(GpuLessThanOrEqual(lhs, rhs))
}),
expr[CaseWhen](
"CASE WHEN expression",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1962,6 +1962,27 @@ case class GpuStringToMap(strExpr: Expression,
}
}

object GpuStringInstr {
def optimizeContains(cmp: GpuExpression): GpuExpression = {
cmp match {
case GpuGreaterThan(GpuStringInstr(str, substr: GpuLiteral), GpuLiteral(0, _)) =>
// instr(A, B) > 0 becomes contains(A, B)
GpuContains(str, substr)
case GpuGreaterThanOrEqual(GpuStringInstr(str, substr: GpuLiteral), GpuLiteral(1, _)) =>
// instr(A, B) >= 1 becomes contains(A, B)
GpuContains(str, substr)
case GpuLessThan(GpuLiteral(0, _), GpuStringInstr(str, substr: GpuLiteral)) =>
// 0 < instr(A, B) becomes contains(A, B)
GpuContains(str, substr)
case GpuLessThanOrEqual(GpuLiteral(1, _), GpuStringInstr(str, substr: GpuLiteral)) =>
// 1 <= instr(A, B) becomes contains(A, B)
GpuContains(str, substr)
case _ =>
cmp
}
}
}

case class GpuStringInstr(str: Expression, substr: Expression)
extends GpuBinaryExpressionArgsAnyScalar
with ImplicitCastInputTypes
Expand Down

0 comments on commit 4393e9f

Please sign in to comment.