Skip to content

Commit

Permalink
Use parse_url kernel for QUERY parsing (#10061)
Browse files Browse the repository at this point in the history
* Use parse_url kernel for QUERY parsing

Signed-off-by: Haoyang Li <haoyangl@nvidia.com>

* Fall back to the CPU for QUERY with a key

Signed-off-by: Haoyang Li <haoyangl@nvidia.com>

---------

Signed-off-by: Haoyang Li <haoyangl@nvidia.com>
thirtiseven authored Dec 27, 2023
1 parent 53ded10 commit a9c1d68
Showing 3 changed files with 30 additions and 8 deletions.
22 changes: 20 additions & 2 deletions integration_tests/src/main/python/url_test.py
Original file line number Diff line number Diff line change
@@ -29,6 +29,7 @@
r'(:[0-9]{1,3}){0,1}(/[a-z]{1,3}){0,3}(\?key=[a-z]{1,3}){0,1}(#([a-z]{1,3})){0,1}'

edge_cases = [
"userinfo@spark.apache.org/path?query=1#Ref",
"http://foo.com/blah_blah",
"http://foo.com/blah_blah/",
"http://foo.com/blah_blah_(wikipedia)",
@@ -103,6 +104,7 @@
"http://10.1.1.254",
"http://userinfo@spark.apache.org/path?query=1#Ref",
r"https://use%20r:pas%20s@example.com/dir%20/pa%20th.HTML?query=x%20y&q2=2#Ref%20two",
r"https://use%20r:pas%20s@example.com/dir%20/pa%20th.HTML?query=x%9Fy&q2=2#Ref%20two",
"http://user:pass@host",
"http://user:pass@host/",
"http://user:pass@host/?#",
@@ -146,8 +148,10 @@

url_gen = StringGen(url_pattern)

# Part names passed as the second argument of parse_url in the tests below.
# NOTE(review): supported_parts and unsupported_parts are each assigned twice
# here; the later assignment (the one that lists QUERY as supported) is the one
# that takes effect.
supported_parts = ['PROTOCOL', 'HOST']
unsupported_parts = ['PATH', 'QUERY', 'REF', 'FILE', 'AUTHORITY', 'USERINFO']
supported_parts = ['PROTOCOL', 'HOST', 'QUERY']
unsupported_parts = ['PATH', 'REF', 'FILE', 'AUTHORITY', 'USERINFO']
# Parts used when a third 'key' argument is also passed to parse_url:
# test_parse_url_with_key checks GPU/CPU equality for the first list, and
# test_parse_url_query_with_key_fallback expects a ParseUrl fallback for the
# second list (QUERY appears there, not in the supported list).
supported_with_key_parts = ['PROTOCOL', 'HOST']
unsupported_with_key_parts = ['QUERY', 'PATH', 'REF', 'FILE', 'AUTHORITY', 'USERINFO']

@pytest.mark.parametrize('data_gen', [url_gen, edge_cases_gen], ids=idfn)
@pytest.mark.parametrize('part', supported_parts, ids=idfn)
@@ -161,3 +165,17 @@ def test_parse_url_unsupported_fallback(part):
assert_gpu_fallback_collect(
lambda spark: unary_op_df(spark, url_gen).selectExpr("a", "parse_url(a, '" + part + "')"),
'ParseUrl')

@pytest.mark.parametrize('part', supported_with_key_parts, ids=idfn)
def test_parse_url_with_key(part):
    """Check that parse_url(a, part, 'key') gives the same rows on GPU and CPU.

    Runs once per entry of supported_with_key_parts against URLs generated
    by url_gen.
    """
    # Build the SQL expression once; `part` is fixed for this test invocation.
    parse_expr = "parse_url(a, '" + part + "', 'key')"

    def project_parsed_urls(spark):
        # Single generated URL column named "a", projected through parse_url.
        url_df = unary_op_df(spark, url_gen)
        return url_df.selectExpr(parse_expr)

    assert_gpu_and_cpu_are_equal_collect(project_parsed_urls)



@allow_non_gpu('ProjectExec', 'ParseUrl')
@pytest.mark.parametrize('part', unsupported_with_key_parts, ids=idfn)
def test_parse_url_query_with_key_fallback(part):
    """Check that parse_url(a, part, 'key') falls back for ParseUrl.

    Runs once per entry of unsupported_with_key_parts against URLs generated
    by url_gen. ProjectExec and ParseUrl are allowed to run off the GPU.
    """
    # Build the SQL expression once; `part` is fixed for this test invocation.
    parse_expr = "parse_url(a, '" + part + "', 'key')"

    def project_parsed_urls(spark):
        # Keep the input column "a" next to the parsed value, as the original did.
        url_df = unary_op_df(spark, url_gen)
        return url_df.selectExpr("a", parse_expr) if False else url_df.selectExpr(parse_expr)

    assert_gpu_fallback_collect(project_parsed_urls, 'ParseUrl')
Original file line number Diff line number Diff line change
@@ -3263,6 +3263,8 @@ object GpuOverrides extends Logging {
}

extractStringLit(a.children(1)).map(_.toUpperCase) match {
case Some("QUERY") if (a.children.size == 3) =>
willNotWorkOnGpu("Part to extract QUERY with key is not supported on GPU")
case Some(part) if GpuParseUrl.isSupportedPart(part) =>
case Some(other) =>
willNotWorkOnGpu(s"Part to extract $other is not supported on GPU")
Original file line number Diff line number Diff line change
@@ -40,7 +40,7 @@ object GpuParseUrl {

def isSupportedPart(part: String): Boolean = {
part match {
case PROTOCOL | HOST =>
case PROTOCOL | HOST | QUERY =>
true
case _ =>
false
@@ -65,22 +65,24 @@ case class GpuParseUrl(children: Seq[Expression])
ParseURI.parseURIProtocol(url.getBase)
case HOST =>
ParseURI.parseURIHost(url.getBase)
case PATH | QUERY | REF | FILE | AUTHORITY | USERINFO =>
case QUERY =>
ParseURI.parseURIQuery(url.getBase)
case PATH | REF | FILE | AUTHORITY | USERINFO =>
throw new UnsupportedOperationException(s"$this is not supported partToExtract=$part. " +
s"Only PROTOCOL and HOST are supported")
s"Only PROTOCOL, HOST and QUERY without a key are supported")
case _ =>
throw new IllegalArgumentException(s"Invalid partToExtract: $partToExtract")
}
}

/**
 * Columnar evaluation for the form that takes a key in addition to the part to extract.
 *
 * When the requested part is anything other than QUERY, a string column containing only
 * nulls is returned. When the part is QUERY, an UnsupportedOperationException is thrown.
 *
 * NOTE(review): this span shows two signature lines, two return statements and two
 * message lines back to back; they differ in the first parameter name (url vs. col), in
 * how the null column is built, and in the wording of the exception message.
 */
def doColumnar(url: GpuColumnVector, partToExtract: GpuScalar, key: GpuScalar): ColumnVector = {
def doColumnar(col: GpuColumnVector, partToExtract: GpuScalar, key: GpuScalar): ColumnVector = {
// The part name arrives as a scalar holding a UTF8String; convert it for comparison.
val part = partToExtract.getValue.asInstanceOf[UTF8String].toString
if (part != QUERY) {
// Any part other than QUERY combined with a key produces an all-null string column.
// The first return builds it from two null strings; the second sizes it by the
// input column's row count.
return ColumnVector.fromStrings(null, null)
return GpuColumnVector.columnVectorFromNull(col.getRowCount.toInt, StringType)
}
// Reaching this point means part == QUERY with a key, which is rejected here.
throw new UnsupportedOperationException(s"$this is not supported partToExtract=$part. " +
s"Only PROTOCOL and HOST are supported")
s"Only PROTOCOL, HOST and QUERY without a key are supported")
}

override def columnarEval(batch: ColumnarBatch): GpuColumnVector = {

0 comments on commit a9c1d68

Please sign in to comment.