diff --git a/docs/additional-functionality/advanced_configs.md b/docs/additional-functionality/advanced_configs.md
index d2593788b67..61d7d7960bf 100644
--- a/docs/additional-functionality/advanced_configs.md
+++ b/docs/additional-functionality/advanced_configs.md
@@ -298,6 +298,7 @@ Name | SQL Function(s) | Description | Default Value | Notes
 spark.rapids.sql.expression.NthValue|`nth_value`|nth window operator|true|None|
 spark.rapids.sql.expression.OctetLength|`octet_length`|The byte length of string data|true|None|
 spark.rapids.sql.expression.Or|`or`|Logical OR|true|None|
+spark.rapids.sql.expression.ParseUrl|`parse_url`|Extracts a part from a URL|true|None|
 spark.rapids.sql.expression.PercentRank|`percent_rank`|Window function that returns the percent rank value within the aggregation window|true|None|
 spark.rapids.sql.expression.Pmod|`pmod`|Pmod|true|None|
 spark.rapids.sql.expression.PosExplode|`posexplode_outer`, `posexplode`|Given an input array produces a sequence of rows for each value in the array|true|None|
diff --git a/docs/compatibility.md b/docs/compatibility.md
index b5cb01757dd..fdbea192390 100644
--- a/docs/compatibility.md
+++ b/docs/compatibility.md
@@ -448,6 +448,23 @@ Spark stores timestamps internally relative to the JVM time zone. Converting an
 between time zones is not currently supported on the GPU. Therefore operations involving timestamps
 will only be GPU-accelerated if the time zone used by the JVM is UTC.
 
+## URL parsing
+
+In Spark, `parse_url` is based on Java's URI library, while the implementation in the RAPIDS
+Accelerator is based on regex extraction. Therefore, the results may differ in some edge cases.
+
+These are the known cases where running on the GPU will produce different results to the CPU:
+
+- Spark allows an empty authority component only when it is followed by a non-empty path, query
+  component, or fragment component. The plugin accepts an empty authority component without
+  checking whether anything follows it, so `parse_url('http://', 'HOST')` returns `null` in Spark
+  but `""` in the plugin.
+- If an input URL has an invalid IPv6 address, Spark returns `null` for all components, but the
+  plugin parses every component except `HOST` as normal. For example, the result for
+  `http://userinfo@[1:2:3:4:5:6:7:8:9:10]/path?query=1#Ref` is
+  `[null,/path,query=1,Ref,http,/path?query=1,userinfo@[1:2:3:4:5:6:7:8:9:10],userinfo]`.
+- For some edge cases (such as the empty string, or a URL that has no PROTOCOL but contains `//`),
+  `PATH` and `FILE` are the empty string in the plugin instead of `null` as in Spark.
+- Only UTF-8 encoding is supported in the plugin. If the input URL contains characters that are
+  not valid UTF-8, the result may differ from Spark.
+
 ## Windowing
 
 ### Window Functions
diff --git a/docs/supported_ops.md b/docs/supported_ops.md
index 1ebf37e95a1..cafe6f6072f 100644
--- a/docs/supported_ops.md
+++ b/docs/supported_ops.md
@@ -10530,6 +10530,95 @@ are limited.
+
+
+
+
+
+
+ParseUrl
+`parse_url`
+Extracts a part from a URL
+None
+project
+url
+
+
+
+
+
+
+
+
+
+S
+
+
+
+
+
+
+
+
+
+
+partToExtract
+
+
+
+
+
+
+
+
+
+PS
Literal value only
+ + + + + + + + + + +key + + + + + + + + + +PS
Literal value only
+ + + + + + + + + + +result + + + + + + + + + +S + + + + + + + + + + PercentRank `percent_rank` Window function that returns the percent rank value within the aggregation window @@ -10645,6 +10734,32 @@ are limited. +Expression +SQL Functions(s) +Description +Notes +Context +Param/Output +BOOLEAN +BYTE +SHORT +INT +LONG +FLOAT +DOUBLE +DATE +TIMESTAMP +STRING +DECIMAL +NULL +BINARY +CALENDAR +ARRAY +MAP +STRUCT +UDT + + PosExplode `posexplode_outer`, `posexplode` Given an input array produces a sequence of rows for each value in the array @@ -10824,32 +10939,6 @@ are limited. -Expression -SQL Functions(s) -Description -Notes -Context -Param/Output -BOOLEAN -BYTE -SHORT -INT -LONG -FLOAT -DOUBLE -DATE -TIMESTAMP -STRING -DECIMAL -NULL -BINARY -CALENDAR -ARRAY -MAP -STRUCT -UDT - - PreciseTimestampConversion Expression used internally to convert the TimestampType to Long and back without losing precision, i.e. in microseconds. Used in time windowing @@ -11120,6 +11209,32 @@ are limited. +Expression +SQL Functions(s) +Description +Notes +Context +Param/Output +BOOLEAN +BYTE +SHORT +INT +LONG +FLOAT +DOUBLE +DATE +TIMESTAMP +STRING +DECIMAL +NULL +BINARY +CALENDAR +ARRAY +MAP +STRUCT +UDT + + Quarter `quarter` Returns the quarter of the year for date, in the range 1 to 4 @@ -11235,32 +11350,6 @@ are limited. -Expression -SQL Functions(s) -Description -Notes -Context -Param/Output -BOOLEAN -BYTE -SHORT -INT -LONG -FLOAT -DOUBLE -DATE -TIMESTAMP -STRING -DECIMAL -NULL -BINARY -CALENDAR -ARRAY -MAP -STRUCT -UDT - - RaiseError `raise_error` Throw an exception @@ -11491,6 +11580,32 @@ are limited. +Expression +SQL Functions(s) +Description +Notes +Context +Param/Output +BOOLEAN +BYTE +SHORT +INT +LONG +FLOAT +DOUBLE +DATE +TIMESTAMP +STRING +DECIMAL +NULL +BINARY +CALENDAR +ARRAY +MAP +STRUCT +UDT + + RegExpExtractAll `regexp_extract_all` Extract all strings matching a regular expression corresponding to the regex group index @@ -11690,32 +11805,6 @@ are limited. -Expression -SQL Functions(s) -Description -Notes -Context -Param/Output -BOOLEAN -BYTE -SHORT -INT -LONG -FLOAT -DOUBLE -DATE -TIMESTAMP -STRING -DECIMAL -NULL -BINARY -CALENDAR -ARRAY -MAP -STRUCT -UDT - - Remainder `%`, `mod` Remainder or modulo @@ -11878,6 +11967,32 @@ are limited. +Expression +SQL Functions(s) +Description +Notes +Context +Param/Output +BOOLEAN +BYTE +SHORT +INT +LONG +FLOAT +DOUBLE +DATE +TIMESTAMP +STRING +DECIMAL +NULL +BINARY +CALENDAR +ARRAY +MAP +STRUCT +UDT + + Rint `rint` Rounds up a double value to the nearest double equal to an integer @@ -12062,32 +12177,6 @@ are limited. -Expression -SQL Functions(s) -Description -Notes -Context -Param/Output -BOOLEAN -BYTE -SHORT -INT -LONG -FLOAT -DOUBLE -DATE -TIMESTAMP -STRING -DECIMAL -NULL -BINARY -CALENDAR -ARRAY -MAP -STRUCT -UDT - - ScalaUDF User Defined Function, the UDF can choose to implement a RAPIDS accelerated interface to get better performance. @@ -12318,6 +12407,32 @@ are limited. +Expression +SQL Functions(s) +Description +Notes +Context +Param/Output +BOOLEAN +BYTE +SHORT +INT +LONG +FLOAT +DOUBLE +DATE +TIMESTAMP +STRING +DECIMAL +NULL +BINARY +CALENDAR +ARRAY +MAP +STRUCT +UDT + + ShiftLeft `shiftleft` Bitwise shift left (<<) @@ -12454,32 +12569,6 @@ are limited. 
-Expression -SQL Functions(s) -Description -Notes -Context -Param/Output -BOOLEAN -BYTE -SHORT -INT -LONG -FLOAT -DOUBLE -DATE -TIMESTAMP -STRING -DECIMAL -NULL -BINARY -CALENDAR -ARRAY -MAP -STRUCT -UDT - - ShiftRightUnsigned `shiftrightunsigned` Bitwise unsigned shift right (>>>) @@ -12685,6 +12774,32 @@ are limited. +Expression +SQL Functions(s) +Description +Notes +Context +Param/Output +BOOLEAN +BYTE +SHORT +INT +LONG +FLOAT +DOUBLE +DATE +TIMESTAMP +STRING +DECIMAL +NULL +BINARY +CALENDAR +ARRAY +MAP +STRUCT +UDT + + Sinh `sinh` Hyperbolic sine @@ -12798,54 +12913,28 @@ are limited. PS
UTC is only supported TZ for child TIMESTAMP;
unsupported child types CALENDAR, UDT
PS
UTC is only supported TZ for child TIMESTAMP;
unsupported child types CALENDAR, UDT
- - - -result - - - -S - - - - - - - - - - - - - - - - -Expression -SQL Functions(s) -Description -Notes -Context -Param/Output -BOOLEAN -BYTE -SHORT -INT -LONG -FLOAT -DOUBLE -DATE -TIMESTAMP -STRING -DECIMAL -NULL -BINARY -CALENDAR -ARRAY -MAP -STRUCT -UDT + + + +result + + + +S + + + + + + + + + + + + + + SortArray @@ -13057,6 +13146,32 @@ are limited. +Expression +SQL Functions(s) +Description +Notes +Context +Param/Output +BOOLEAN +BYTE +SHORT +INT +LONG +FLOAT +DOUBLE +DATE +TIMESTAMP +STRING +DECIMAL +NULL +BINARY +CALENDAR +ARRAY +MAP +STRUCT +UDT + + Sqrt `sqrt` Square root @@ -13215,32 +13330,6 @@ are limited. -Expression -SQL Functions(s) -Description -Notes -Context -Param/Output -BOOLEAN -BYTE -SHORT -INT -LONG -FLOAT -DOUBLE -DATE -TIMESTAMP -STRING -DECIMAL -NULL -BINARY -CALENDAR -ARRAY -MAP -STRUCT -UDT - - StringInstr `instr` Instr string operator @@ -13487,6 +13576,32 @@ are limited. +Expression +SQL Functions(s) +Description +Notes +Context +Param/Output +BOOLEAN +BYTE +SHORT +INT +LONG +FLOAT +DOUBLE +DATE +TIMESTAMP +STRING +DECIMAL +NULL +BINARY +CALENDAR +ARRAY +MAP +STRUCT +UDT + + StringRPad `rpad` Pad a string on the right @@ -13576,32 +13691,6 @@ are limited. -Expression -SQL Functions(s) -Description -Notes -Context -Param/Output -BOOLEAN -BYTE -SHORT -INT -LONG -FLOAT -DOUBLE -DATE -TIMESTAMP -STRING -DECIMAL -NULL -BINARY -CALENDAR -ARRAY -MAP -STRUCT -UDT - - StringRepeat `repeat` StringRepeat operator that repeats the given strings with numbers of times given by repeatTimes @@ -13848,6 +13937,32 @@ are limited. +Expression +SQL Functions(s) +Description +Notes +Context +Param/Output +BOOLEAN +BYTE +SHORT +INT +LONG +FLOAT +DOUBLE +DATE +TIMESTAMP +STRING +DECIMAL +NULL +BINARY +CALENDAR +ARRAY +MAP +STRUCT +UDT + + StringToMap `str_to_map` Creates a map after splitting the input string into pairs of key-value strings @@ -13937,32 +14052,6 @@ are limited. -Expression -SQL Functions(s) -Description -Notes -Context -Param/Output -BOOLEAN -BYTE -SHORT -INT -LONG -FLOAT -DOUBLE -DATE -TIMESTAMP -STRING -DECIMAL -NULL -BINARY -CALENDAR -ARRAY -MAP -STRUCT -UDT - - StringTranslate `translate` StringTranslate operator @@ -14256,6 +14345,32 @@ are limited. +Expression +SQL Functions(s) +Description +Notes +Context +Param/Output +BOOLEAN +BYTE +SHORT +INT +LONG +FLOAT +DOUBLE +DATE +TIMESTAMP +STRING +DECIMAL +NULL +BINARY +CALENDAR +ARRAY +MAP +STRUCT +UDT + + Substring `substr`, `substring` Substring operator @@ -14345,32 +14460,6 @@ are limited. -Expression -SQL Functions(s) -Description -Notes -Context -Param/Output -BOOLEAN -BYTE -SHORT -INT -LONG -FLOAT -DOUBLE -DATE -TIMESTAMP -STRING -DECIMAL -NULL -BINARY -CALENDAR -ARRAY -MAP -STRUCT -UDT - - SubstringIndex `substring_index` substring_index operator @@ -14682,6 +14771,32 @@ are limited. +Expression +SQL Functions(s) +Description +Notes +Context +Param/Output +BOOLEAN +BYTE +SHORT +INT +LONG +FLOAT +DOUBLE +DATE +TIMESTAMP +STRING +DECIMAL +NULL +BINARY +CALENDAR +ARRAY +MAP +STRUCT +UDT + + Tanh `tanh` Hyperbolic tangent @@ -14772,32 +14887,6 @@ are limited. -Expression -SQL Functions(s) -Description -Notes -Context -Param/Output -BOOLEAN -BYTE -SHORT -INT -LONG -FLOAT -DOUBLE -DATE -TIMESTAMP -STRING -DECIMAL -NULL -BINARY -CALENDAR -ARRAY -MAP -STRUCT -UDT - - TimeAdd Adds interval to timestamp @@ -15096,6 +15185,32 @@ are limited. 
+Expression +SQL Functions(s) +Description +Notes +Context +Param/Output +BOOLEAN +BYTE +SHORT +INT +LONG +FLOAT +DOUBLE +DATE +TIMESTAMP +STRING +DECIMAL +NULL +BINARY +CALENDAR +ARRAY +MAP +STRUCT +UDT + + TransformValues `transform_values` Transform values in a map using a transform function @@ -15164,32 +15279,6 @@ are limited. -Expression -SQL Functions(s) -Description -Notes -Context -Param/Output -BOOLEAN -BYTE -SHORT -INT -LONG -FLOAT -DOUBLE -DATE -TIMESTAMP -STRING -DECIMAL -NULL -BINARY -CALENDAR -ARRAY -MAP -STRUCT -UDT - - UnaryMinus `negative` Negate a numeric value @@ -15490,6 +15579,32 @@ are limited. +Expression +SQL Functions(s) +Description +Notes +Context +Param/Output +BOOLEAN +BYTE +SHORT +INT +LONG +FLOAT +DOUBLE +DATE +TIMESTAMP +STRING +DECIMAL +NULL +BINARY +CALENDAR +ARRAY +MAP +STRUCT +UDT + + UnscaledValue Convert a Decimal to an unscaled long value for some aggregation optimizations @@ -15537,32 +15652,6 @@ are limited. -Expression -SQL Functions(s) -Description -Notes -Context -Param/Output -BOOLEAN -BYTE -SHORT -INT -LONG -FLOAT -DOUBLE -DATE -TIMESTAMP -STRING -DECIMAL -NULL -BINARY -CALENDAR -ARRAY -MAP -STRUCT -UDT - - Upper `upper`, `ucase` String uppercase operator @@ -15887,6 +15976,32 @@ are limited. +Expression +SQL Functions(s) +Description +Notes +Context +Param/Output +BOOLEAN +BYTE +SHORT +INT +LONG +FLOAT +DOUBLE +DATE +TIMESTAMP +STRING +DECIMAL +NULL +BINARY +CALENDAR +ARRAY +MAP +STRUCT +UDT + + AggregateExpression Aggregate expression @@ -16083,32 +16198,6 @@ are limited. S -Expression -SQL Functions(s) -Description -Notes -Context -Param/Output -BOOLEAN -BYTE -SHORT -INT -LONG -FLOAT -DOUBLE -DATE -TIMESTAMP -STRING -DECIMAL -NULL -BINARY -CALENDAR -ARRAY -MAP -STRUCT -UDT - - ApproximatePercentile `percentile_approx`, `approx_percentile` Approximate percentile @@ -16283,6 +16372,32 @@ are limited. +Expression +SQL Functions(s) +Description +Notes +Context +Param/Output +BOOLEAN +BYTE +SHORT +INT +LONG +FLOAT +DOUBLE +DATE +TIMESTAMP +STRING +DECIMAL +NULL +BINARY +CALENDAR +ARRAY +MAP +STRUCT +UDT + + Average `avg`, `mean` Average aggregate operator @@ -16525,54 +16640,28 @@ are limited. PS
UTC is only supported TZ for child TIMESTAMP;
unsupported child types CALENDAR, UDT
PS
UTC is only supported TZ for child TIMESTAMP;
unsupported child types CALENDAR, UDT
PS
UTC is only supported TZ for child TIMESTAMP;
unsupported child types CALENDAR, UDT
-NS - - -result - - - - - - - - - - - - - - -PS
UTC is only supported TZ for child TIMESTAMP;
unsupported child types CALENDAR, UDT
- - - - - -Expression -SQL Functions(s) -Description -Notes -Context -Param/Output -BOOLEAN -BYTE -SHORT -INT -LONG -FLOAT -DOUBLE -DATE -TIMESTAMP -STRING -DECIMAL -NULL -BINARY -CALENDAR -ARRAY -MAP -STRUCT -UDT +NS + + +result + + + + + + + + + + + + + + +PS
UTC is only supported TZ for child TIMESTAMP;
unsupported child types CALENDAR, UDT
+ + + CollectSet @@ -16708,6 +16797,32 @@ are limited. +Expression +SQL Functions(s) +Description +Notes +Context +Param/Output +BOOLEAN +BYTE +SHORT +INT +LONG +FLOAT +DOUBLE +DATE +TIMESTAMP +STRING +DECIMAL +NULL +BINARY +CALENDAR +ARRAY +MAP +STRUCT +UDT + + Count `count` Count aggregate operator @@ -16974,32 +17089,6 @@ are limited. NS -Expression -SQL Functions(s) -Description -Notes -Context -Param/Output -BOOLEAN -BYTE -SHORT -INT -LONG -FLOAT -DOUBLE -DATE -TIMESTAMP -STRING -DECIMAL -NULL -BINARY -CALENDAR -ARRAY -MAP -STRUCT -UDT - - Last `last`, `last_value` last aggregate operator @@ -17133,6 +17222,32 @@ are limited. NS +Expression +SQL Functions(s) +Description +Notes +Context +Param/Output +BOOLEAN +BYTE +SHORT +INT +LONG +FLOAT +DOUBLE +DATE +TIMESTAMP +STRING +DECIMAL +NULL +BINARY +CALENDAR +ARRAY +MAP +STRUCT +UDT + + Max `max` Max aggregate operator @@ -17399,32 +17514,6 @@ are limited. NS -Expression -SQL Functions(s) -Description -Notes -Context -Param/Output -BOOLEAN -BYTE -SHORT -INT -LONG -FLOAT -DOUBLE -DATE -TIMESTAMP -STRING -DECIMAL -NULL -BINARY -CALENDAR -ARRAY -MAP -STRUCT -UDT - - PivotFirst PivotFirst operator @@ -17557,6 +17646,32 @@ are limited. NS +Expression +SQL Functions(s) +Description +Notes +Context +Param/Output +BOOLEAN +BYTE +SHORT +INT +LONG +FLOAT +DOUBLE +DATE +TIMESTAMP +STRING +DECIMAL +NULL +BINARY +CALENDAR +ARRAY +MAP +STRUCT +UDT + + StddevPop `stddev_pop` Aggregation computing population standard deviation @@ -17823,32 +17938,6 @@ are limited. -Expression -SQL Functions(s) -Description -Notes -Context -Param/Output -BOOLEAN -BYTE -SHORT -INT -LONG -FLOAT -DOUBLE -DATE -TIMESTAMP -STRING -DECIMAL -NULL -BINARY -CALENDAR -ARRAY -MAP -STRUCT -UDT - - Sum `sum` Sum aggregate operator @@ -17982,6 +18071,32 @@ are limited. +Expression +SQL Functions(s) +Description +Notes +Context +Param/Output +BOOLEAN +BYTE +SHORT +INT +LONG +FLOAT +DOUBLE +DATE +TIMESTAMP +STRING +DECIMAL +NULL +BINARY +CALENDAR +ARRAY +MAP +STRUCT +UDT + + VariancePop `var_pop` Aggregation computing population variance @@ -18248,32 +18363,6 @@ are limited. -Expression -SQL Functions(s) -Description -Notes -Context -Param/Output -BOOLEAN -BYTE -SHORT -INT -LONG -FLOAT -DOUBLE -DATE -TIMESTAMP -STRING -DECIMAL -NULL -BINARY -CALENDAR -ARRAY -MAP -STRUCT -UDT - - NormalizeNaNAndZero Normalize NaN and zero @@ -18347,6 +18436,32 @@ are limited. NS +Expression +SQL Functions(s) +Description +Notes +Context +Param/Output +BOOLEAN +BYTE +SHORT +INT +LONG +FLOAT +DOUBLE +DATE +TIMESTAMP +STRING +DECIMAL +NULL +BINARY +CALENDAR +ARRAY +MAP +STRUCT +UDT + + HiveGenericUDF Hive Generic UDF, the UDF can choose to implement a RAPIDS accelerated interface to get better performance diff --git a/integration_tests/src/main/python/url_test.py b/integration_tests/src/main/python/url_test.py new file mode 100644 index 00000000000..e4f9e42d787 --- /dev/null +++ b/integration_tests/src/main/python/url_test.py @@ -0,0 +1,114 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+
+from asserts import assert_gpu_and_cpu_are_equal_collect, assert_gpu_and_cpu_error, run_with_cpu_and_gpu
+from data_gen import *
+from marks import *
+from pyspark.sql.types import *
+import pyspark.sql.functions as f
+from spark_session import is_before_spark_320
+
+# regex to generate URLs of limited length containing the HOST, PATH, QUERY, REF, PROTOCOL,
+# FILE, AUTHORITY and USERINFO parts
+url_pattern = r'((http|https|ftp)://)(([a-zA-Z][a-zA-Z0-9]{0,2}\.){0,3}([a-zA-Z][a-zA-Z0-9]{0,2})\.([a-zA-Z][a-zA-Z0-9]{0,2}))' \
+    r'(:[0-9]{1,3}){0,1}(/[a-zA-Z0-9]{1,3}){0,3}(\?[a-zA-Z0-9]{1,3}=[a-zA-Z0-9]{1,3}){0,1}(#([a-zA-Z0-9]{1,3})){0,1}'
+
+url_pattern_with_key = r'((http|https|ftp|file)://)(([a-z]{1,3}\.){0,3}([a-z]{1,3})\.([a-z]{1,3}))' \
+    r'(:[0-9]{1,3}){0,1}(/[a-z]{1,3}){0,3}(\?key=[a-z]{1,3}){0,1}(#([a-z]{1,3})){0,1}'
+
+url_gen = StringGen(url_pattern)
+
+def test_parse_url_host():
+    assert_gpu_and_cpu_are_equal_collect(
+        lambda spark : unary_op_df(spark, url_gen).selectExpr(
+            "a",
+            "parse_url(a, 'HOST')"
+        ))
+
+def test_parse_url_path():
+    assert_gpu_and_cpu_are_equal_collect(
+        lambda spark : unary_op_df(spark, url_gen).selectExpr(
+            "a",
+            "parse_url(a, 'PATH')"
+        ))
+
+def test_parse_url_query():
+    assert_gpu_and_cpu_are_equal_collect(
+        lambda spark : unary_op_df(spark, url_gen).selectExpr(
+            "a",
+            "parse_url(a, 'QUERY')"
+        ))
+
+def test_parse_url_ref():
+    assert_gpu_and_cpu_are_equal_collect(
+        lambda spark : unary_op_df(spark, url_gen).selectExpr(
+            "a",
+            "parse_url(a, 'REF')"
+        ))
+
+def test_parse_url_protocol():
+    assert_gpu_and_cpu_are_equal_collect(
+        lambda spark : unary_op_df(spark, url_gen).selectExpr(
+            "a",
+            "parse_url(a, 'PROTOCOL')"
+        ))
+
+def test_parse_url_file():
+    assert_gpu_and_cpu_are_equal_collect(
+        lambda spark : unary_op_df(spark, url_gen).selectExpr(
+            "a",
+            "parse_url(a, 'FILE')"
+        ))
+
+def test_parse_url_authority():
+    assert_gpu_and_cpu_are_equal_collect(
+        lambda spark : unary_op_df(spark, url_gen).selectExpr(
+            "a",
+            "parse_url(a, 'AUTHORITY')"
+        ))
+
+def test_parse_url_userinfo():
+    assert_gpu_and_cpu_are_equal_collect(
+        lambda spark : unary_op_df(spark, url_gen).selectExpr(
+            "a",
+            "parse_url(a, 'USERINFO')"
+        ))
+
+def test_parse_url_with_no_query_key():
+    assert_gpu_and_cpu_are_equal_collect(
+        lambda spark : unary_op_df(spark, url_gen, length=100).selectExpr(
+            "a",
+            "parse_url(a, 'HOST', '')",
+            "parse_url(a, 'PATH', '')",
+            "parse_url(a, 'REF', '')",
+            "parse_url(a, 'PROTOCOL', '')",
+            "parse_url(a, 'FILE', '')",
+            "parse_url(a, 'AUTHORITY', '')",
+            "parse_url(a, 'USERINFO', '')"
+        ))
+
+def test_parse_url_with_query_key():
+    assert_gpu_and_cpu_are_equal_collect(
+        lambda spark : unary_op_df(spark, StringGen(url_pattern_with_key)).selectExpr(
+            "a",
+            "parse_url(a, 'QUERY', 'key')"
+        ))
+
+def test_parse_url_too_many_args():
+    assert_gpu_and_cpu_error(
+        lambda spark : unary_op_df(spark, StringGen()).selectExpr(
+            "a", "parse_url(a, 'USERINFO', 'key', 'value')").collect(),
+        conf={},
+        error_message='parse_url function requires two or three arguments')
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
index 91e2b1ea6bd..cd20c306e11 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
@@ -3153,6 +3153,26 @@ object GpuOverrides extends Logging {
           ParamCheck("regexp", TypeSig.lit(TypeEnum.STRING), TypeSig.STRING),
           ParamCheck("idx", TypeSig.lit(TypeEnum.INT), TypeSig.INT))),
       (a, conf, p, r) => new GpuRegExpExtractAllMeta(a, conf, p, r)),
+    expr[ParseUrl](
+      "Extracts a part from a URL",
+      ExprChecks.projectOnly(TypeSig.STRING, TypeSig.STRING,
+        Seq(ParamCheck("url", TypeSig.STRING, TypeSig.STRING),
+          ParamCheck("partToExtract", TypeSig.lit(TypeEnum.STRING), TypeSig.STRING)),
+        // Should really be an OptionalParam
+        Some(RepeatingParamCheck("key", TypeSig.lit(TypeEnum.STRING), TypeSig.STRING))),
+      (a, conf, p, r) => new ExprMeta[ParseUrl](a, conf, p, r) {
+        val failOnError = a.failOnError
+
+        override def tagExprForGpu(): Unit = {
+          if (failOnError) {
+            willNotWorkOnGpu("Fail on error is not supported on GPU when parsing urls.")
+          }
+        }
+
+        override def convertToGpu(): GpuExpression = {
+          GpuParseUrl(childExprs.map(_.convertToGpu()), failOnError)
+        }
+      }),
     expr[Length](
       "String character length or binary byte length",
       ExprChecks.unaryProject(TypeSig.INT, TypeSig.INT,
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/urlFunctions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/urlFunctions.scala
new file mode 100644
index 00000000000..00bbc860115
--- /dev/null
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/urlFunctions.scala
@@ -0,0 +1,324 @@
+/*
+ * Copyright (c) 2023, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.rapids
+
+import ai.rapids.cudf.{ColumnVector, DType, RegexProgram, Scalar, Table}
+import com.nvidia.spark.rapids._
+import com.nvidia.spark.rapids.Arm._
+import com.nvidia.spark.rapids.RapidsPluginImplicits._
+import com.nvidia.spark.rapids.shims.ShimExpression
+
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.rapids.shims.RapidsErrorUtils
+import org.apache.spark.sql.types._
+import org.apache.spark.sql.vectorized.ColumnarBatch
+import org.apache.spark.unsafe.types.UTF8String
+
+object GpuParseUrl {
+  private val HOST = "HOST"
+  private val PATH = "PATH"
+  private val QUERY = "QUERY"
+  private val REF = "REF"
+  private val PROTOCOL = "PROTOCOL"
+  private val FILE = "FILE"
+  private val AUTHORITY = "AUTHORITY"
+  private val USERINFO = "USERINFO"
+  private val REGEXPREFIX = """(&|^)("""
+  private val REGEXSUFFIX = "=)([^&]*)"
+  // scalastyle:off line.size.limit
+  //
+  // private val HOST_REGEX = """^(?:(?:([^:/?#]+):)?(?://((?:(?:(?:[^\@]*)@)?(\[[\w%.:]+\]|[^/#:?]*))(?::[0-9]+)?))?(([^?#]*)(\?[^#]*)?)(#[\w\-_.!~*'();/?:@&=+$,[\]%]*)?)$"""
+  // """^(?:(?:( [^:/?#]+):)?(?://( (?:(?:(?:[^\@/]*)@)?(\[[\w%.:]+\]|[^/#:?]*))(?::[0-9]+)?))?(([^?#]*)(\?[^#]*)?)(#.*)?)$"""
+  private val HOST_REGEX = """^(?:(?:(?:[\w+\-.]+):)?(?:[0-9]+[^\\#]*|(?://(?:(?:(?:(?:[\w%\-_.!~*'();:&=+$,]*)@)?(\[[\w%.:]+\]|[^/\\#:?"<>\^`{|}]*))(?::[0-9]+)?))?(?:[^?#[\]\\"<>\^`{|}]*)?(?:\?[^\\"#<>\^`{|}]*)?)(?:#[\w\-_.!~*'();/?:@&=+$,[\]%]*)?)$"""
+  private val PATH_REGEX = """^(?:[\w+\-.]+:[^/#][^/#][^#]*|(?:[\w+\-.]+:)?(?://(?:[\w%\-_.!~*'();:&=+$,]*@)?(?:\[[\w%.:]+\]|[^/\\#:?"<>\^`{|}]*)(?::[0-9]+)?)?([^?#[\]\\"<>\^`{|}]*)?(?:\?[^\\"#<>\^`{|}]*)?)(?:#[\w\-_.!~*'();/?:@&=+$,[\]%]*)?$"""
+  private val QUERY_REGEX = """^(?:(?:(?:[\w+\-.]+):)?(?:[0-9]+[^\\#]*|(?://(?:(?:(?:(?:[\w%\-_.!~*'();:&=+$,]*)@)?(?:\[[\w%.:]+\]|[^/\\#:?"<>\^`{|}]*))(?::[0-9]+)?))?(?:[^?#[\]\\"<>\^`{|}]*)?(\?[^\\"#<>\^`{|}]*)?)(?:#[\w\-_.!~*'();/?:@&=+$,[\]%]*)?)$"""
+  private val REF_REGEX = """^(?:(?:(?:[\w+\-.]+):)?(?:[0-9]+[^\\#]*|(?://(?:(?:(?:(?:[\w%\-_.!~*'();:&=+$,]*)@)?(?:\[[\w%.:]+\]|[^/\\#:?"<>\^`{|}]*))(?::[0-9]+)?))?(?:[^?#[\]\\"<>\^`{|}]*)?(?:\?[^\\"#<>\^`{|}]*)?)(#[\w\-_.!~*'();/?:@&=+$,[\]%]*)?)$"""
+  private val PROTOCOL_REGEX = """^(?:(?:([\w+\-.]+):)?(?:(?:[0-9]+[^\\#]*|(?://(?:(?:(?:[\w%\-_.!~*'();:&=+$,]*)@)?(?:\[[\w%.:]+\]|[^/\\#:?"<>\^`{|}]*))(?::[0-9]+)?))?(?:[^?#[\]\\"<>\^`{|}]*)?(?:\?[^\\"#<>\^`{|}]*)?)(?:#[\w\-_.!~*'();/?:@&=+$,[\]%]*)?)$"""
+  private val FILE_REGEX = """^(?:[\w+\-.]+:[^/#][^/#][^#]*|(?:[\w+\-.]+:)?(?://(?:[\w%\-_.!~*'();:&=+$,]*@)?(?:\[[\w%.:]+\]|[^/\\#:?"<>\^`{|}]*)(?::[0-9]+)?)?((?:[^?#[\]\\"<>\^`{|}]*)?(?:\?[^\\"#<>\^`{|}]*)?))(?:#[\w\-_.!~*'();/?:@&=+$,[\]%]*)?$"""
+  private val AUTHORITY_REGEX = """^(?:(?:(?:[\w+\-.]+):)?(?:[0-9]+[^\\#]*|(?://((?:(?:(?:[\w%\-_.!~*'();:&=+$,]*)@)?(?:\[[\w%.:]+\]|[^/\\#:?"<>\^`{|}]*))(?::[0-9]+)?))?(?:[^?#[\]\\"<>\^`{|}]*)?(?:\?[^\\"#<>\^`{|}]*)?)(?:#[\w\-_.!~*'();/?:@&=+$,[\]%]*)?)$"""
+  private val USERINFO_REGEX = """^(?:(?:(?:[\w+\-.]+):)?(?:[0-9]+[^\\#]*|(?://(?:(?:(?:([\w%\-_.!~*'();:&=+$,]*)@)?(?:\[[\w%.:]+\]|[^/\\#:?"<>\^`{|}]*))(?::[0-9]+)?))?(?:[^?#[\]\\"<>\^`{|}]*)?(?:\?[^\\"#<>\^`{|}]*)?)(?:#[\w\-_.!~*'();/?:@&=+$,[\]%]*)?)$"""
+  // HostName parsing follows the rules of the java URI library:
+  //   hostname    = domainlabel [ "." ] | 1*( domainlabel "." ) toplabel [ "." ]
+  //   domainlabel = alphanum | alphanum *( alphanum | "-" ) alphanum
+  //   toplabel    = alpha | alpha *( alphanum | "-" ) alphanum
+  val hostnameRegex = """((([0-9a-zA-Z]|[0-9a-zA-Z][0-9a-zA-Z\-]*[0-9a-zA-Z])|(([0-9a-zA-Z]|[0-9a-zA-Z][0-9a-zA-Z\-]*[0-9a-zA-Z])\.)+([a-zA-Z]|[a-zA-Z][0-9a-zA-Z\-]*[0-9a-zA-Z]))\.?)"""
+  val ipv4Regex = """(((25[0-5]|2[0-4][0-9]|1[0-9][0-9]|[1-9][0-9]|[0-9])\.){3}(25[0-5]|2[0-4][0-9]|1[0-9][0-9]|[1-9][0-9]|[0-9]))"""
+  val simpleIpv6Regex = """(\[[\w%.:]+])"""
+  // based on https://stackoverflow.com/questions/53497/regular-expression-that-matches-valid-ipv6-addresses
+  val ipv6Regex1 = """([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?"""
+  // 1:2:3:4:5:6:7:8
+  val ipv6Regex2 = """([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)?([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)?([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)?([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)?([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)?([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)?:"""
+  // 1::    1:2:3:4:5:6:7::
+  val ipv6Regex3 = """(([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)?([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)?([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)?([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)?([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)?:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)"""
+  // 1::8    1:2:3:4:5:6::8
+  val ipv6Regex4 = """(([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)?([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)?([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)?([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)?(:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)(:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)?)"""
+  // 1::7:8    1:2:3:4:5::7:8    1:2:3:4:5::8
+  val ipv6Regex5 = """(([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)?([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)?([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)?(:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)(:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)?(:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)?)"""
+  // 1::6:7:8    1:2:3:4::6:7:8    1:2:3:4::8
+  val ipv6Regex6 = """(([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)?([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)?(:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)(:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)?(:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)?(:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)?)"""
+  // 1::5:6:7:8    1:2:3::5:6:7:8    1:2:3::8
+  val ipv6Regex7 = """(([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)?(:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)(:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)?(:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)?(:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)?(:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)?)"""
+  // 1::4:5:6:7:8    1:2::4:5:6:7:8    1:2::8
+  val ipv6Regex8 = """([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:((:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)(:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)?(:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)?(:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)?(:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)?(:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)?))"""
+  // 1::3:4:5:6:7:8    1::8
+  val ipv6Regex9 = """(:((:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)(:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)?(:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)?(:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)?(:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)?(:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)?(:[0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)?|:))"""
+  // ::2:3:4:5:6:7:8    ::8    ::
+  val ipv6Regex10 = """(fe80:((:([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)?)(:([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)?)?(:([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)?)?(:([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?)?)?)?%[\w]+)"""
+  // fe80::7:8%eth0    fe80::7:8%1    (link-local IPv6 addresses with zone index)
+  val ipv6Regex11 = """(::((ffff|FFFF)(:00?0?0?)?:)?((25[0-5]|(2[0-4]|1?[0-9])?[0-9])\.)((25[0-5]|(2[0-4]|1?[0-9])?[0-9])\.)?((25[0-5]|(2[0-4]|1?[0-9])?[0-9])\.)?(25[0-5]|(2[0-4]|1?[0-9])?[0-9]))"""
+  // ::255.255.255.255    ::ffff:255.255.255.255    ::ffff:0:255.255.255.255    (IPv4-mapped IPv6 addresses and IPv4-translated addresses)
+  val ipv6Regex12 = """([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)?([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)?([0-9a-fA-F][0-9a-fA-F]?[0-9a-fA-F]?[0-9a-fA-F]?:)?:((25[0-5]|(2[0-4]|1?[0-9])?[0-9])\.)((25[0-5]|(2[0-4]|1?[0-9])?[0-9])\.)((25[0-5]|(2[0-4]|1?[0-9])?[0-9])\.)(25[0-5]|(2[0-4]|1?[0-9])?[0-9])"""
+  // 2001:db8:3:4::192.0.2.33    64:ff9b::192.0.2.33    (IPv4-Embedded IPv6 Address)
+  val ipv6Regex13 = """(0:0:0:0:0:(0|FFFF|ffff):((25[0-5]|2[0-4][0-9]|1[0-9][0-9]|[1-9][0-9]|[0-9])\.){3}(25[0-5]|2[0-4][0-9]|1[0-9][0-9]|[1-9][0-9]|[0-9]))"""
+  // 0:0:0:0:0:0:13.1.68.3
+  // scalastyle:on
+}
+
+case class GpuParseUrl(children: Seq[Expression],
+    failOnErrorOverride: Boolean = SQLConf.get.ansiEnabled)
+  extends GpuExpression with ShimExpression with ExpectsInputTypes {
+
+  def this(children: Seq[Expression]) = this(children, SQLConf.get.ansiEnabled)
+
+  override def nullable: Boolean = true
+  override def inputTypes: Seq[DataType] = Seq.fill(children.size)(StringType)
+  override def dataType: DataType = StringType
+  override def prettyName: String = "parse_url"
+
+  import GpuParseUrl._
+
+  override def checkInputDataTypes(): TypeCheckResult = {
+    if (children.size > 3 || children.size < 2) {
+      RapidsErrorUtils.parseUrlWrongNumArgs(children.size) match {
+        case Some(res) => return res
+        case None => // error has already been thrown by the shim
+      }
+    }
+    super[ExpectsInputTypes].checkInputDataTypes()
+  }
+
+  private def getPattern(key: UTF8String): RegexProgram = {
+    val regex = REGEXPREFIX + key.toString + REGEXSUFFIX
+    new RegexProgram(regex)
+  }
+
+  private def reValid(url: ColumnVector): ColumnVector = {
+    val regex = """([^\s]*\s|([^[]*|[^[]*\[.*].*)%([^0-9a-fA-F]|[0-9a-fA-F][^0-9a-fA-F]|$))"""
+    val prog = new RegexProgram(regex)
+    withResource(url.matchesRe(prog)) { isMatch =>
+      withResource(Scalar.fromNull(DType.STRING)) { nullScalar =>
+        isMatch.ifElse(nullScalar, url)
+      }
+    }
+  }
+
+  private def reMatch(url: ColumnVector, partToExtract: String): ColumnVector = {
+    val regex = partToExtract match {
+      case HOST => HOST_REGEX
+      case PATH => PATH_REGEX
+      case QUERY => QUERY_REGEX
+      case REF => REF_REGEX
+      case PROTOCOL => PROTOCOL_REGEX
+      case FILE => FILE_REGEX
+      case AUTHORITY => AUTHORITY_REGEX
+      case USERINFO => USERINFO_REGEX
+      case _ => throw new IllegalArgumentException(s"Invalid partToExtract: $partToExtract")
+    }
+    val prog = new RegexProgram(regex)
+    withResource(url.extractRe(prog)) { table: Table =>
+      table.getColumn(0).incRefCount()
+    }
+  }
+
+  private def uriPathEmptyToNulls(cv: ColumnVector, isUri: ColumnVector): ColumnVector = {
+    val res = withResource(ColumnVector.fromStrings("")) { empty =>
+      withResource(ColumnVector.fromStrings(null)) { nulls =>
+        cv.findAndReplaceAll(empty, nulls)
+      }
+    }
+    withResource(res) { _ =>
+      isUri.ifElse(res, cv)
+    }
+  }
+
+  private def emptyToNulls(cv: ColumnVector): ColumnVector = {
+    withResource(ColumnVector.fromStrings("")) { empty =>
+      withResource(ColumnVector.fromStrings(null)) { nulls =>
+        cv.findAndReplaceAll(empty, nulls)
+      }
+    }
+  }
+
+  private def unsetInvalidHost(cv: ColumnVector): ColumnVector = {
+    val regex = "^(" + hostnameRegex + "|" + ipv4Regex + "|" + simpleIpv6Regex + ")$"
+    val prog = new RegexProgram(regex)
+    val hostnameIpv4Res = withResource(cv.matchesRe(prog)) { isMatch =>
+      withResource(Scalar.fromNull(DType.STRING)) { nullScalar =>
+        isMatch.ifElse(cv, nullScalar)
+      }
+    }
+    // Match the simple IPv6 form first and fully validate IPv6 only when necessary,
+    // because the full validation regex is very long.
+    val simpleIpv6Prog = new RegexProgram(simpleIpv6Regex)
+    withResource(cv.matchesRe(simpleIpv6Prog)) { isMatch =>
+      val anyIpv6 = withResource(isMatch.any()) { a =>
+        a.isValid && a.getBoolean
+      }
+      if (anyIpv6) {
+        withResource(hostnameIpv4Res) { _ =>
+          unsetInvalidIpv6Host(hostnameIpv4Res, isMatch)
+        }
+      } else {
+        hostnameIpv4Res
+      }
+    }
+  }
+
+  private def unsetInvalidProtocol(cv: ColumnVector): ColumnVector = {
+    val regex = """^[a-zA-Z][\w+\-.]*$"""
+    val prog = new RegexProgram(regex)
+    withResource(cv.matchesRe(prog)) { isMatch =>
+      withResource(Scalar.fromNull(DType.STRING)) { nullScalar =>
+        isMatch.ifElse(cv, nullScalar)
+      }
+    }
+  }
+
+  private def unsetInvalidIpv6Host(cv: ColumnVector, simpleMatched: ColumnVector): ColumnVector = {
+    val regex = """^\[(""" + ipv6Regex1 + "|" + ipv6Regex2 + "|" + ipv6Regex3 + "|" + ipv6Regex4 + "|" +
+      ipv6Regex5 + "|" + ipv6Regex6 + "|" + ipv6Regex7 + "|" + ipv6Regex8 + "|" + ipv6Regex9 + "|" +
+      ipv6Regex10 + "|" + ipv6Regex11 + "|" + ipv6Regex12 + "|" + ipv6Regex13 + """)(%[\w]*)?]$"""
+
+    val prog = new RegexProgram(regex)
+
+    val invalidIpv6 = withResource(cv.matchesRe(prog)) { matched =>
+      withResource(matched.not()) { invalid =>
+        simpleMatched.and(invalid)
+      }
+    }
+    withResource(invalidIpv6) { _ =>
+      withResource(Scalar.fromNull(DType.STRING)) { nullScalar =>
+        invalidIpv6.ifElse(nullScalar, cv)
+      }
+    }
+  }
+
+  def doColumnar(numRows: Int, url: GpuScalar, partToExtract: GpuScalar): ColumnVector = {
+    withResource(GpuColumnVector.from(url, numRows, StringType)) { urlCol =>
+      doColumnar(urlCol, partToExtract)
+    }
+  }
+
+  def doColumnar(url: GpuColumnVector, partToExtract: GpuScalar): ColumnVector = {
+    val part = partToExtract.getValue.asInstanceOf[UTF8String].toString
+    val valid = reValid(url.getBase)
+    val matched = withResource(valid) { _ =>
+      reMatch(valid, part)
+    }
+    if (part == HOST) {
+      val validated = withResource(matched) { _ =>
+        unsetInvalidHost(matched)
+      }
+      withResource(validated) { _ =>
+        emptyToNulls(validated)
+      }
+    } else if (part == QUERY || part == REF) {
+      val resWithNulls = withResource(matched) { _ =>
+        emptyToNulls(matched)
+      }
+      withResource(resWithNulls) { _ =>
+        resWithNulls.substring(1)
+      }
+    } else if (part == PATH || part == FILE) {
+      val isUri = withResource(Scalar.fromString("//")) { doubleslash =>
+        withResource(url.getBase.stringContains(doubleslash)) { isurl =>
+          isurl.not()
+        }
+      }
+      withResource(isUri) { _ =>
+        withResource(matched) { _ =>
+          uriPathEmptyToNulls(matched, isUri)
+        }
+      }
+    } else if (part == PROTOCOL) {
+      val validated = withResource(matched) { _ =>
+        unsetInvalidProtocol(matched)
+      }
+      withResource(validated) { _ =>
+        emptyToNulls(validated)
+      }
+    } else {
+      withResource(matched) { _ =>
+        emptyToNulls(matched)
+      }
+    }
+  }
+
+  def doColumnar(url: GpuColumnVector, partToExtract: GpuScalar, key: GpuScalar): ColumnVector = {
+    val part = partToExtract.getValue.asInstanceOf[UTF8String].toString
+    if (part != QUERY) {
+      // Only QUERY takes a key; return a null column with one row per input row.
+      return withResource(Scalar.fromNull(DType.STRING)) { nullScalar =>
+        ColumnVector.fromScalar(nullScalar, url.getRowCount.toInt)
+      }
+    }
+    val querys = withResource(reMatch(url.getBase, QUERY)) { matched =>
+      matched.substring(1)
+    }
+    val keyStr = key.getValue.asInstanceOf[UTF8String]
+    val queryValue = withResource(querys) { _ =>
+      withResource(querys.extractRe(getPattern(keyStr))) { table: Table =>
+        table.getColumn(2).incRefCount()
+      }
+    }
+    withResource(queryValue) { _ =>
+      emptyToNulls(queryValue)
+    }
+  }
+
+  override def columnarEval(batch: ColumnarBatch): GpuColumnVector = {
+    if (children.size == 2) {
+      val Seq(url, partToExtract) = children
+      withResourceIfAllowed(url.columnarEvalAny(batch)) { val0 =>
+        withResourceIfAllowed(partToExtract.columnarEvalAny(batch)) { val1 =>
+          (val0, val1) match {
+            case (v0: GpuColumnVector, v1: GpuScalar) =>
+              GpuColumnVector.from(doColumnar(v0, v1), dataType)
+            case _ =>
+              throw new UnsupportedOperationException(s"Cannot columnar evaluate expression: $this")
+          }
+        }
+      }
+    } else {
+      // 3-arg, i.e. QUERY with key
+      assert(children.size == 3)
+      val Seq(url, partToExtract, key) = children
+      withResourceIfAllowed(url.columnarEvalAny(batch)) { val0 =>
+        withResourceIfAllowed(partToExtract.columnarEvalAny(batch)) { val1 =>
+          withResourceIfAllowed(key.columnarEvalAny(batch)) { val2 =>
+            (val0, val1, val2) match {
+              case (v0: GpuColumnVector, v1: GpuScalar, v2: GpuScalar) =>
+                GpuColumnVector.from(doColumnar(v0, v1, v2), dataType)
+              case _ =>
+                throw new UnsupportedOperationException(
+                  s"Cannot columnar evaluate expression: $this")
+            }
+          }
+        }
+      }
+    }
+  }
+}
\ No newline at end of file
diff --git a/sql-plugin/src/main/spark311/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala b/sql-plugin/src/main/spark311/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala
index f23229e0956..2084336cced 100644
--- a/sql-plugin/src/main/spark311/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala
+++ b/sql-plugin/src/main/spark311/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala
@@ -23,6 +23,7 @@ package org.apache.spark.sql.rapids.shims
 
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.trees.Origin
 import org.apache.spark.sql.types.{DataType, Decimal, DecimalType}
 
@@ -81,4 +82,8 @@ object RapidsErrorUtils {
   def tableIdentifierExistsError(tableIdentifier: TableIdentifier): Throwable = {
     throw new AnalysisException(s"$tableIdentifier already exists.")
   }
+
+  def parseUrlWrongNumArgs(actual: Int): Option[TypeCheckResult] = {
+    Some(TypeCheckResult.TypeCheckFailure("parse_url function requires two or three arguments"))
+  }
 }
diff --git a/sql-plugin/src/main/spark320/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala b/sql-plugin/src/main/spark320/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala
index b301397255a..c6122375cf2 100644
--- a/sql-plugin/src/main/spark320/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala
+++ b/sql-plugin/src/main/spark320/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala
@@ -25,6 +25,7 @@ spark-rapids-shim-json-lines ***/
 package org.apache.spark.sql.rapids.shims
 
 import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.trees.Origin
 import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
 import org.apache.spark.sql.types.{DataType, Decimal, DecimalType}
 
@@ -85,4 +86,8 @@ object RapidsErrorUtils {
   def tableIdentifierExistsError(tableIdentifier: TableIdentifier): Throwable = {
     QueryCompilationErrors.tableIdentifierExistsError(tableIdentifier)
   }
+
+  def parseUrlWrongNumArgs(actual: Int): Option[TypeCheckResult] = {
+    Some(TypeCheckResult.TypeCheckFailure("parse_url function requires two or three arguments"))
+  }
 }
diff --git a/sql-plugin/src/main/spark321db/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala b/sql-plugin/src/main/spark321db/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala
index 6fa5b8350a5..fed5d0c4f6c 100644
--- a/sql-plugin/src/main/spark321db/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala
+++ b/sql-plugin/src/main/spark321db/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala
@@ -20,6 +20,7 @@ spark-rapids-shim-json-lines ***/
 package org.apache.spark.sql.rapids.shims
 
 import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.trees.Origin
 import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
 import org.apache.spark.sql.types.{DataType, Decimal, DecimalType}
 
@@ -83,4 +84,8 @@ object RapidsErrorUtils {
   def tableIdentifierExistsError(tableIdentifier: TableIdentifier): Throwable = {
     QueryCompilationErrors.tableIdentifierExistsError(tableIdentifier)
   }
+
+  def parseUrlWrongNumArgs(actual: Int): Option[TypeCheckResult] = {
+    Some(TypeCheckResult.TypeCheckFailure("parse_url function requires two or three arguments"))
+  }
 }
diff --git a/sql-plugin/src/main/spark330/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala b/sql-plugin/src/main/spark330/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala
index 4b81f540e40..e428a3377d5 100644
--- a/sql-plugin/src/main/spark330/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala
+++ b/sql-plugin/src/main/spark330/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala
@@ -24,6 +24,7 @@ spark-rapids-shim-json-lines ***/
 package org.apache.spark.sql.rapids.shims
 
 import org.apache.spark.SparkDateTimeException
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.trees.Origin
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.internal.SQLConf
 
@@ -79,4 +80,8 @@ object RapidsErrorUtils extends RapidsErrorUtilsFor330plus {
   def sqlArrayIndexNotStartAtOneError(): RuntimeException = {
     new ArrayIndexOutOfBoundsException("SQL array indices start at 1")
   }
+
+  def parseUrlWrongNumArgs(actual: Int): Option[TypeCheckResult] = {
+    Some(TypeCheckResult.TypeCheckFailure("parse_url function requires two or three arguments"))
+  }
 }
diff --git a/sql-plugin/src/main/spark330db/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala b/sql-plugin/src/main/spark330db/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala
index 3585910993d..9f0a365f6dc 100644
--- a/sql-plugin/src/main/spark330db/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala
+++ b/sql-plugin/src/main/spark330db/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala
@@ -21,10 +21,13 @@ spark-rapids-shim-json-lines ***/
 package org.apache.spark.sql.rapids.shims
 
+import java.net.URISyntaxException
+
 import org.apache.spark.SparkDateTimeException
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.trees.{Origin, SQLQueryContext}
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{DataType, Decimal, DecimalType}
+import org.apache.spark.unsafe.types.UTF8String
 
 object RapidsErrorUtils extends RapidsErrorUtilsFor330plus {
 
@@ -87,4 +90,12 @@ object RapidsErrorUtils extends RapidsErrorUtilsFor330plus {
   override def intervalDivByZeroError(origin: Origin): ArithmeticException = {
     QueryExecutionErrors.intervalDividedByZeroError(origin.context)
   }
+
+  def parseUrlWrongNumArgs(actual: Int): Option[TypeCheckResult] = {
+    Some(TypeCheckResult.TypeCheckFailure("parse_url function requires two or three arguments"))
+  }
+
+  def invalidUrlException(url: UTF8String, e: URISyntaxException): Throwable = {
+    QueryExecutionErrors.invalidUrlError(url, e)
+  }
 }
diff --git a/sql-plugin/src/main/spark340/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala b/sql-plugin/src/main/spark340/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala
index f0b74c1c276..cca1493bcc1 100644
--- a/sql-plugin/src/main/spark340/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala
+++ b/sql-plugin/src/main/spark340/scala/org/apache/spark/sql/rapids/shims/RapidsErrorUtils.scala
@@ -21,8 +21,9 @@ spark-rapids-shim-json-lines ***/
 package org.apache.spark.sql.rapids.shims
 
 import org.apache.spark.SparkDateTimeException
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.trees.{Origin, SQLQueryContext}
-import org.apache.spark.sql.errors.QueryExecutionErrors
+import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{DataType, Decimal, DecimalType}
 
@@ -87,4 +88,10 @@ object RapidsErrorUtils extends RapidsErrorUtilsFor330plus {
   override def intervalDivByZeroError(origin: Origin): ArithmeticException = {
     QueryExecutionErrors.intervalDividedByZeroError(origin.context)
   }
+
+  def parseUrlWrongNumArgs(actual: Int): Option[TypeCheckResult] = {
+    throw QueryCompilationErrors.wrongNumArgsError(
+      "parse_url", Seq("[2, 3]"), actual
+    )
+  }
 }
diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/UrlFunctionsSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/UrlFunctionsSuite.scala
new file mode 100644
index 00000000000..bd992cf4bd5
--- /dev/null
+++ b/tests/src/test/scala/com/nvidia/spark/rapids/UrlFunctionsSuite.scala
@@ -0,0 +1,268 @@
+/*
+ * Copyright (c) 2023, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.nvidia.spark.rapids
+
+import org.apache.spark.sql.{DataFrame, SparkSession}
+
+class UrlFunctionsSuite extends SparkQueryCompareTestSuite {
+  def validUrlEdgeCasesDf(session: SparkSession): DataFrame = {
+    import session.sqlContext.implicits._
+    // [In search of the perfect URL validation regex](https://mathiasbynens.be/demo/url-regex)
+    Seq[String](
+      "http://foo.com/blah_blah",
+      "http://foo.com/blah_blah/",
+      "http://foo.com/blah_blah_(wikipedia)",
+      "http://foo.com/blah_blah_(wikipedia)_(again)",
+      "http://www.example.com/wpstyle/?p=364",
+      "https://www.example.com/foo/?bar=baz&inga=42&quux",
+      "http://✪df.ws/123",
+      "http://userid:password@example.com:8080",
+      "http://userid:password@example.com:8080/",
+      "http://userid:password@example.com",
+      "http://userid:password@example.com/",
+      "http://142.42.1.1/",
+      "http://142.42.1.1:8080/",
+      "http://➡.ws/䨹",
+      "http://⌘.ws",
+      "http://⌘.ws/",
+      "http://foo.com/blah_(wikipedia)#cite-1",
+      "http://foo.com/blah_(wikipedia)_blah#cite-1",
+      "http://foo.com/unicode_(✪)_in_parens",
+      "http://foo.com/(something)?after=parens",
+      "http://☺.damowmow.com/",
+      "http://code.google.com/events/#&product=browser",
+      "http://j.mp",
+      "ftp://foo.bar/baz",
+      "http://foo.bar/?q=Test%20URL-encoded%20stuff",
+      "http://مثال.إختبار",
+      "http://例子.测试",
+      "http://उदाहरण.परीक्षा",
+      "http://-.~_!$&'()*+,;=:%40:80%2f::::::@example.com",
+      "http://1337.net",
+      "http://a.b-c.de",
+      "http://223.255.255.254",
+      "https://foo_bar.example.com/",
+      // "http://",
+      "http://.",
+      "http://..",
+      "http://../",
+      "http://?",
+      "http://??",
+      "http://??/",
+      "http://#",
+      "http://##",
+      "http://##/",
+      "http://foo.bar?q=Spaces should be encoded",
+      // "//",
+      "//a",
+      "///a",
+      "///",
+      "http:///a",
+      "foo.com",
+      "rdar://1234",
+      "h://test",
+      "http:// shouldfail.com",
+      ":// should fail",
+      "http://foo.bar/foo(bar)baz quux",
+      "ftps://foo.bar/",
+      "http://-error-.invalid/",
+      "http://a.b--c.de/",
+      "http://-a.b.co",
+      "http://a.b-.co",
+      "http://0.0.0.0",
+      "http://10.1.1.0",
+      "http://10.1.1.255",
+      "http://224.1.1.1",
+      "http://1.1.1.1.1",
+      "http://123.123.123",
+      "http://3628126748",
+      "http://.www.foo.bar/",
+      "http://www.foo.bar./",
+      "http://.www.foo.bar./",
+      "http://10.1.1.1",
+      "http://10.1.1.254"
+    ).toDF("urls")
+  }
+
+  def urlCasesFromSpark(session: SparkSession): DataFrame = {
+    import session.sqlContext.implicits._
+    Seq[String](
+      "http://userinfo@spark.apache.org/path?query=1#Ref",
+      "https://use%20r:pas%20s@example.com/dir%20/pa%20th.HTML?query=x%20y&q2=2#Ref%20two",
+      "http://user:pass@host",
+      "http://user:pass@host/",
+      "http://user:pass@host/?#",
+      "http://user:pass@host/file;param?query;p2"
+    ).toDF("urls")
+  }
+
+  def urlCasesFromSparkInvalid(session: SparkSession): DataFrame = {
+    import session.sqlContext.implicits._
+    Seq[String](
+      "inva lid://user:pass@host/file;param?query;p2"
+    ).toDF("urls")
+  }
+
+  def urlCasesFromJavaUriLib(session: SparkSession): DataFrame = {
+    import session.sqlContext.implicits._
+    Seq[String](
+      "ftp://ftp.is.co.za/rfc/rfc1808.txt",
+      "http://www.math.uio.no/faq/compression-faq/part1.html",
+      "telnet://melvyl.ucop.edu/",
+      "http://www.w3.org/Addressing/",
+      "ftp://ds.internic.net/rfc/",
+      "http://www.ics.uci.edu/pub/ietf/uri/historical.html#WARNING",
+      "http://www.ics.uci.edu/pub/ietf/uri/#Related",
+      "http://[FEDC:BA98:7654:3210:FEDC:BA98:7654:3210]:80/index.html",
+      "http://[FEDC:BA98:7654:3210:FEDC:BA98:7654:10%12]:80/index.html",
+      "http://[1080:0:0:0:8:800:200C:417A]/index.html",
"http://[1080:0:0:0:8:800:200C:417A%1]/index.html", + "http://[3ffe:2a00:100:7031::1]", + "http://[1080::8:800:200C:417A]/foo", + "http://[::192.9.5.5]/ipng", + "http://[::192.9.5.5%interface]/ipng", + "http://[::FFFF:129.144.52.38]:80/index.html", + "http://[2010:836B:4179::836B:4179]", + "http://[FF01::101]", + "http://[::1]", + "http://[::]", + "http://[::%hme0]", + "http://[0:0:0:0:0:0:13.1.68.3]", + "http://[0:0:0:0:0:FFFF:129.144.52.38]", + "http://[0:0:0:0:0:FFFF:129.144.52.38%33]", + "http://[0:0:0:0:0:ffff:1.2.3.4]", + "http://[::13.1.68.3]" + ).toDF("urls") + } + + def otherEdgeCases(session: SparkSession): DataFrame = { + import session.sqlContext.implicits._ + Seq[String]( + "http://userinfo@spark.apache.org/path/%3C%f9?query=1#Ref", + "http://userinfo@spark.apache.org/path?query=%10%e3#Ref", + "http://userinfo@spark.apache.org/path?query=%10%xx#Ref", + "http://userinfo@spark.apache.org/path?query=1#R%20ef", + "http://abc.com/a%xx%ueue", + "123.foo.bar", + "123.foo.bar:123", + "foo.bar:123", + "http://foo.bar/baduser@xx/yy", + "http://foo.bar/xx/yy?baduser@zz", + "http://foo.bar?query=baduser@key", + "http://foo.bar#baduser@zz", + "http://foo.bar:666@123/xx/yy", + "mailto:xx@yy.com", + "foo.bar/yy?query=key#fragment", + "foo.bar:123", + "foo.bar:123/xx/yy", + "foo.bar:123/xx/yy?query=key", + "foo.bar:123/xx/yy/?query=key&query2=key2", + "foo.bar:123/xx/yy#fragment", + "foo.bar:123/xx/yy/index.html", + "foo.bar:123?query=key" + ).toDF("urls") + } + + def urlWithQueryKey(session: SparkSession): DataFrame = { + import session.sqlContext.implicits._ + Seq[String]( + "http://foo.com/blah_blah?foo=bar&baz=blah#vertical-bar" + ).toDF("urls") + } + + def urlIpv6Host(session: SparkSession): DataFrame = { + import session.sqlContext.implicits._ + Seq[String]( + "http://[1:2:3:4:5:6:7:8]", + "http://[1::]", + "http://[1:2:3:4:5:6:7::]", + "http://[1::8]", + "http://[1:2:3:4:5:6::8]", + "http://[1:2:3:4:5:6::8]", + "http://[1::7:8]", + "http://[1:2:3:4:5::7:8]", + "http://[1:2:3:4:5::8]", + "http://[1::6:7:8]", + "http://[1:2:3:4::6:7:8]", + "http://[1:2:3:4::8]", + "http://[1::5:6:7:8]", + "http://[1:2:3::5:6:7:8]", + "http://[1:2:3::8]", + "http://[1::4:5:6:7:8]", + "http://[1:2::4:5:6:7:8]", + "http://[1:2::8]", + "http://[1::3:4:5:6:7:8]", + "http://[1::3:4:5:6:7:8]", + "http://[1::8]", + "http://[::2:3:4:5:6:7:8]", + "http://[::2:3:4:5:6:7:8]", + "http://[::8]", + "http://[::]", + "http://[fe80::7:8%eth0]", + "http://[fe80::7:8%1]", + "http://[::255.255.255.255]", + "http://[::ffff:255.255.255.255]", + "http://[::ffff:0:255.255.255.255]", + "http://[2001:db8:3:4::192.0.2.33]", + "http://[64:ff9b::192.0.2.33]" + ).toDF("urls") + } + + def parseUrls(frame: DataFrame): DataFrame = { + frame.selectExpr( + "urls", + "parse_url(urls, 'HOST') as HOST", + "parse_url(urls, 'PATH') as PATH", + "parse_url(urls, 'QUERY') as QUERY", + "parse_url(urls, 'REF') as REF", + "parse_url(urls, 'PROTOCOL') as PROTOCOL", + "parse_url(urls, 'FILE') as FILE", + "parse_url(urls, 'AUTHORITY') as AUTHORITY", + "parse_url(urls, 'USERINFO') as USERINFO") + } + + testSparkResultsAreEqual("Test parse_url edge cases from internet", validUrlEdgeCasesDf) { + parseUrls + } + + testSparkResultsAreEqual("Test parse_url cases from Spark", urlCasesFromSpark) { + parseUrls + } + + testSparkResultsAreEqual("Test parse_url invalid cases from Spark", urlCasesFromSparkInvalid) { + parseUrls + } + + testSparkResultsAreEqual("Test parse_url cases from java URI library", urlCasesFromJavaUriLib) { + parseUrls + } + + 
+  testSparkResultsAreEqual("Test parse_url ipv6 host", urlIpv6Host) {
+    parseUrls
+  }
+
+  testSparkResultsAreEqual("Test other edge cases", otherEdgeCases) {
+    parseUrls
+  }
+
+  testSparkResultsAreEqual("Test parse_url with query and key", urlWithQueryKey) {
+    frame => frame.selectExpr(
+      "urls",
+      "parse_url(urls, 'QUERY', 'foo') as QUERY",
+      "parse_url(urls, 'QUERY', 'baz') as QUERY")
+  }
+}
\ No newline at end of file
diff --git a/tools/generated_files/operatorsScore.csv b/tools/generated_files/operatorsScore.csv
index 235c33cca81..99f850d81c0 100644
--- a/tools/generated_files/operatorsScore.csv
+++ b/tools/generated_files/operatorsScore.csv
@@ -178,6 +178,7 @@ Not,4
 NthValue,4
 OctetLength,4
 Or,4
+ParseUrl,4
 PercentRank,4
 PivotFirst,4
 Pmod,4
diff --git a/tools/generated_files/supportedExprs.csv b/tools/generated_files/supportedExprs.csv
index 960ef3d4486..5e7fb90d379 100644
--- a/tools/generated_files/supportedExprs.csv
+++ b/tools/generated_files/supportedExprs.csv
@@ -373,6 +373,10 @@ Or,S,`or`,None,project,result,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,
 Or,S,`or`,None,AST,lhs,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA
 Or,S,`or`,None,AST,rhs,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA
 Or,S,`or`,None,AST,result,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA
+ParseUrl,S,`parse_url`,None,project,url,NA,NA,NA,NA,NA,NA,NA,NA,NA,S,NA,NA,NA,NA,NA,NA,NA,NA
+ParseUrl,S,`parse_url`,None,project,partToExtract,NA,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA,NA,NA,NA,NA,NA
+ParseUrl,S,`parse_url`,None,project,key,NA,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA,NA,NA,NA,NA,NA
+ParseUrl,S,`parse_url`,None,project,result,NA,NA,NA,NA,NA,NA,NA,NA,NA,S,NA,NA,NA,NA,NA,NA,NA,NA
 PercentRank,S,`percent_rank`,None,window,ordering,S,S,S,S,S,S,S,S,PS,S,S,S,NS,NS,NS,NS,NS,NS
 PercentRank,S,`percent_rank`,None,window,result,NA,NA,NA,NA,NA,NA,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA
 Pmod,S,`pmod`,None,project,lhs,NA,S,S,S,S,S,S,NA,NA,NA,PS,NA,NA,NA,NA,NA,NA,NA
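
To try the new expression end to end, the following is a minimal PySpark sketch (illustrative only, not part of this patch; it assumes the RAPIDS Accelerator jar is on the classpath and a GPU is available). It exercises the behaviors documented in the compatibility.md section above, including the `parse_url('http://', 'HOST')` divergence:

```python
# Minimal sketch: enable the plugin and run parse_url on the GPU.
# The two configs below are the standard way to turn the plugin on.
from pyspark.sql import SparkSession

spark = (SparkSession.builder
    .config("spark.plugins", "com.nvidia.spark.SQLPlugin")
    .config("spark.rapids.sql.enabled", "true")
    .getOrCreate())

df = spark.createDataFrame(
    [("http://",),  # HOST: null on the CPU, "" on the GPU (see compatibility.md)
     ("http://userinfo@spark.apache.org/path?query=1#Ref",)],
    ["url"])

df.selectExpr(
    "url",
    "parse_url(url, 'HOST') as host",
    "parse_url(url, 'PROTOCOL') as protocol",
    "parse_url(url, 'QUERY', 'query') as query_value").show(truncate=False)
```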
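If exact CPU parity is required for the edge cases listed in compatibility.md, the expression can be disabled on its own with the new config key documented in advanced_configs.md. Note also that GpuOverrides tags ParseUrl as CPU-only whenever ANSI mode is enabled (the `failOnError` check in `tagExprForGpu`), so ANSI queries fall back automatically:

```python
# Disable only ParseUrl on the GPU (everything else stays accelerated).
spark.conf.set("spark.rapids.sql.expression.ParseUrl", "false")

# Alternatively, ANSI mode forces a CPU fallback for parse_url.
spark.conf.set("spark.sql.ansi.enabled", "true")
```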
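The QUERY-with-key path in `GpuParseUrl` builds the pattern `(&|^)(<key>=)([^&]*)` from `REGEXPREFIX`/`REGEXSUFFIX` and keeps the third capture group. Below is a hypothetical pure-Python mirror of that logic, for illustration only (`query_value` is not part of the patch; `re.escape` is added here for safety, whereas the plugin interpolates the key literally):

```python
import re

def query_value(query: str, key: str):
    # mirrors getPattern: '(&|^)(<key>=)([^&]*)'; the value is capture group 3
    m = re.search(r"(&|^)(" + re.escape(key) + r"=)([^&]*)", query)
    return m.group(3) if m else None

assert query_value("foo=bar&baz=blah", "baz") == "blah"
assert query_value("foo=bar&baz=blah", "missing") is None
# the GPU version additionally maps an empty matched value to null (emptyToNulls)
```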