diff --git a/integration_tests/src/main/python/regexp_test.py b/integration_tests/src/main/python/regexp_test.py index c2062605ca1..b67f9dc6679 100644 --- a/integration_tests/src/main/python/regexp_test.py +++ b/integration_tests/src/main/python/regexp_test.py @@ -1012,14 +1012,16 @@ def test_regexp_replace_simple(regexp_enabled): 'REGEXP_REPLACE(a, "ab", "PROD")', 'REGEXP_REPLACE(a, "ae", "PROD")', 'REGEXP_REPLACE(a, "bc", "PROD")', - 'REGEXP_REPLACE(a, "fa", "PROD")' + 'REGEXP_REPLACE(a, "fa", "PROD")', + 'REGEXP_REPLACE(a, "a\n", "PROD")', + 'REGEXP_REPLACE(a, "\n", "PROD")' ), conf=conf ) @pytest.mark.parametrize("regexp_enabled", ['true', 'false']) def test_regexp_replace_multi_optimization(regexp_enabled): - gen = mk_str_gen('[abcdef]{0,2}') + gen = mk_str_gen('[abcdef\t\n\a]{0,3}') conf = { 'spark.rapids.sql.regexp.enabled': regexp_enabled } @@ -1032,7 +1034,9 @@ def test_regexp_replace_multi_optimization(regexp_enabled): 'REGEXP_REPLACE(a, "aa|bb|cc|dd", "PROD")', 'REGEXP_REPLACE(a, "(aa|bb)|(cc|dd)", "PROD")', 'REGEXP_REPLACE(a, "aa|bb|cc|dd|ee", "PROD")', - 'REGEXP_REPLACE(a, "aa|bb|cc|dd|ee|ff", "PROD")' + 'REGEXP_REPLACE(a, "aa|bb|cc|dd|ee|ff", "PROD")', + 'REGEXP_REPLACE(a, "a\n|b\a|c\t", "PROD")', + 'REGEXP_REPLACE(a, "a\ta|b\nb", "PROD")' ), conf=conf ) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala index 45905f0b9e0..07b2d022f67 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala @@ -593,8 +593,9 @@ object GpuOverrides extends Logging { } def isSupportedStringReplacePattern(strLit: String): Boolean = { - // check for regex special characters, except for \u0000 which we can support - !regexList.filterNot(_ == "\u0000").exists(pattern => strLit.contains(pattern)) + // check for regex special characters, except for \u0000, \n, \r, and \t which we can support + val supported = Seq("\u0000", "\n", "\r", "\t") + !regexList.filterNot(supported.contains(_)).exists(pattern => strLit.contains(pattern)) } def isSupportedStringReplacePattern(exp: Expression): Boolean = { @@ -605,7 +606,6 @@ object GpuOverrides extends Logging { if (strLit.isEmpty) { false } else { - // check for regex special characters, except for \u0000 which we can support isSupportedStringReplacePattern(strLit) } case _ => false diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/StringFunctionSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/StringFunctionSuite.scala index 3c3933946c5..25c8c10b26d 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/StringFunctionSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/StringFunctionSuite.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. + * Copyright (c) 2019-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -207,7 +207,8 @@ class RegExpUtilsSuite extends AnyFunSuite { "aa|bb|cc|dd" -> Seq("aa", "bb", "cc", "dd"), "(aa|bb)|(cc|dd)" -> Seq("aa", "bb", "cc", "dd"), "aa|bb|cc|dd|ee" -> Seq("aa", "bb", "cc", "dd", "ee"), - "aa|bb|cc|dd|ee|ff" -> Seq("aa", "bb", "cc", "dd", "ee", "ff") + "aa|bb|cc|dd|ee|ff" -> Seq("aa", "bb", "cc", "dd", "ee", "ff"), + "a\n|b\t|c\r" -> Seq("a\n", "b\t", "c\r") ) regexChoices.foreach { case (pattern, choices) =>