From 4393e9fc8c711dcf1aac7ef4a805cd0a53a6a4d5 Mon Sep 17 00:00:00 2001 From: "Robert (Bobby) Evans" Date: Thu, 14 Mar 2024 23:46:29 -0500 Subject: [PATCH] Add in small optimization for instr comparison (#10584) * Add in small optimization for instr comparison Signed-off-by: Robert (Bobby) Evans * Review Comments --------- Signed-off-by: Robert (Bobby) Evans --- .../src/main/python/string_test.py | 15 ++++++++++++- .../nvidia/spark/rapids/GpuOverrides.scala | 8 +++---- .../spark/sql/rapids/stringFunctions.scala | 21 +++++++++++++++++++ 3 files changed, 39 insertions(+), 5 deletions(-) diff --git a/integration_tests/src/main/python/string_test.py b/integration_tests/src/main/python/string_test.py index 7ecff4c939b..5631f13f13d 100644 --- a/integration_tests/src/main/python/string_test.py +++ b/integration_tests/src/main/python/string_test.py @@ -214,6 +214,19 @@ def assert_gpu_did_fallback(sql_text): assert_gpu_did_fallback('locate(a, a, pos)') assert_gpu_did_fallback('locate(a, "a", pos)') +# There is no contains function exposed in Spark. You can turn it into a +# LIKE %FOO% or we have seen some use instr > 0 to do the same thing. +# Spark optimizes LIKE to be a contains, we also optimize instr to do +# something similar. +def test_instr_as_contains(): + gen = mk_str_gen('.{0,3}Z_Z.{0,3}A.{0,3}') + assert_gpu_and_cpu_are_equal_collect( + lambda spark: unary_op_df(spark, gen).selectExpr( + 'instr(a, "A") > 0', + '0 < instr(a, "A")', + '1 <= instr(a, "A")', + 'instr(a, "A") >= 1', + 'a LIKE "%A%"')) def test_instr(): gen = mk_str_gen('.{0,3}Z_Z.{0,3}A.{0,3}') @@ -861,4 +874,4 @@ def test_format_number_float_value(): gpu_results = list(map(lambda x: float(x[0].replace(",", "")), with_gpu_session( lambda spark: unary_op_df(spark, data_gen).selectExpr('format_number(a, 5)').collect()))) for cpu, gpu in zip(cpu_results, gpu_results): - assert math.isclose(cpu, gpu, rel_tol=1e-7) or math.isclose(cpu, gpu, abs_tol=1.1e-5) \ No newline at end of file + assert math.isclose(cpu, gpu, rel_tol=1e-7) or math.isclose(cpu, gpu, abs_tol=1.1e-5) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala index 38562dfdb2f..8f1a720b92b 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala @@ -1928,7 +1928,7 @@ object GpuOverrides extends Logging { TypeSig.orderable)), (a, conf, p, r) => new BinaryAstExprMeta[GreaterThan](a, conf, p, r) { override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression = - GpuGreaterThan(lhs, rhs) + GpuStringInstr.optimizeContains(GpuGreaterThan(lhs, rhs)) }), expr[GreaterThanOrEqual]( ">= operator", @@ -1943,7 +1943,7 @@ object GpuOverrides extends Logging { TypeSig.orderable)), (a, conf, p, r) => new BinaryAstExprMeta[GreaterThanOrEqual](a, conf, p, r) { override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression = - GpuGreaterThanOrEqual(lhs, rhs) + GpuStringInstr.optimizeContains(GpuGreaterThanOrEqual(lhs, rhs)) }), expr[In]( "IN operator", @@ -1993,7 +1993,7 @@ object GpuOverrides extends Logging { TypeSig.orderable)), (a, conf, p, r) => new BinaryAstExprMeta[LessThan](a, conf, p, r) { override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression = - GpuLessThan(lhs, rhs) + GpuStringInstr.optimizeContains(GpuLessThan(lhs, rhs)) }), expr[LessThanOrEqual]( "<= operator", @@ -2008,7 +2008,7 @@ object GpuOverrides extends Logging { TypeSig.orderable)), (a, conf, p, r) => new BinaryAstExprMeta[LessThanOrEqual](a, conf, p, r) { override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression = - GpuLessThanOrEqual(lhs, rhs) + GpuStringInstr.optimizeContains(GpuLessThanOrEqual(lhs, rhs)) }), expr[CaseWhen]( "CASE WHEN expression", diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala index 1c9209189c1..0850b1fcafb 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala @@ -1962,6 +1962,27 @@ case class GpuStringToMap(strExpr: Expression, } } +object GpuStringInstr { + def optimizeContains(cmp: GpuExpression): GpuExpression = { + cmp match { + case GpuGreaterThan(GpuStringInstr(str, substr: GpuLiteral), GpuLiteral(0, _)) => + // instr(A, B) > 0 becomes contains(A, B) + GpuContains(str, substr) + case GpuGreaterThanOrEqual(GpuStringInstr(str, substr: GpuLiteral), GpuLiteral(1, _)) => + // instr(A, B) >= 1 becomes contains(A, B) + GpuContains(str, substr) + case GpuLessThan(GpuLiteral(0, _), GpuStringInstr(str, substr: GpuLiteral)) => + // 0 < instr(A, B) becomes contains(A, B) + GpuContains(str, substr) + case GpuLessThanOrEqual(GpuLiteral(1, _), GpuStringInstr(str, substr: GpuLiteral)) => + // 1 <= instr(A, B) becomes contains(A, B) + GpuContains(str, substr) + case _ => + cmp + } + } +} + case class GpuStringInstr(str: Expression, substr: Expression) extends GpuBinaryExpressionArgsAnyScalar with ImplicitCastInputTypes