diff --git a/integration_tests/src/main/python/url_test.py b/integration_tests/src/main/python/url_test.py index b7c985edff2..bd8487a4f11 100644 --- a/integration_tests/src/main/python/url_test.py +++ b/integration_tests/src/main/python/url_test.py @@ -172,6 +172,14 @@ def test_parse_url_query_with_key(): lambda spark: unary_op_df(spark, url_gen) .selectExpr("a", "parse_url(a, 'QUERY', 'abc')", "parse_url(a, 'QUERY', 'a')") ) + +@allow_non_gpu('ProjectExec', 'ParseUrl') +def test_parse_url_query_with_key_regex_fallback(): + url_gen = StringGen(url_pattern_with_key) + assert_gpu_fallback_collect( + lambda spark: unary_op_df(spark, url_gen) + .selectExpr("a", "parse_url(a, 'QUERY', 'a?c')", "parse_url(a, 'QUERY', '*')"), + 'ParseUrl') @pytest.mark.parametrize('part', supported_with_key_parts, ids=idfn) def test_parse_url_with_key(part): @@ -183,4 +191,4 @@ def test_parse_url_with_key(part): def test_parse_url_with_key_fallback(part): assert_gpu_fallback_collect( lambda spark: unary_op_df(spark, url_gen).selectExpr("parse_url(a, '" + part + "', 'key')"), - 'ParseUrl') \ No newline at end of file + 'ParseUrl') diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala index b9451b51606..5263d539127 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala @@ -3260,8 +3260,24 @@ object GpuOverrides extends Logging { if (a.failOnError) { willNotWorkOnGpu("Fail on error is not supported on GPU when parsing urls.") } - + + // In Spark, the key in parse_url could act like a regex, but on GPU, we only support + // literal keys. So we need to fallback if the key contains regex special characters. + // See Spark issue: https://issues.apache.org/jira/browse/SPARK-44500 extractStringLit(a.children(1)).map(_.toUpperCase) match { + case Some("QUERY") if (a.children.size == 3) => { + extractLit(a.children(2)).foreach { key => + if (key.value != null) { + val keyStr = key.value.asInstanceOf[UTF8String].toString + val specialCharacters = List("\\", "[", "]", "{", "}", "^", "-", "$", + ".", "+", "*", "?", "|") + if (specialCharacters.exists(keyStr.contains(_))) { + willNotWorkOnGpu(s"Key $keyStr could act like a regex which is not " + + "supported on GPU") + } + } + } + } case Some(part) if GpuParseUrl.isSupportedPart(part) => case Some(other) => willNotWorkOnGpu(s"Part to extract $other is not supported on GPU")