From 98dc423dfbacb68e0d5d8d15069455aaffad618f Mon Sep 17 00:00:00 2001 From: Mike Wilson Date: Tue, 19 Dec 2023 13:38:25 -0500 Subject: [PATCH] Adding query support to parseURI (#1652) * Adding query to parseuri Signed-off-by: Mike Wilson --- src/main/cpp/src/ParseURIJni.cpp | 14 +++ src/main/cpp/src/parse_uri.cu | 87 +++++++++++++------ src/main/cpp/src/parse_uri.hpp | 15 +++- src/main/cpp/tests/parse_uri.cpp | 69 ++++++++++++--- .../com/nvidia/spark/rapids/jni/ParseURI.java | 13 ++- .../nvidia/spark/rapids/jni/ParseURITest.java | 79 +++++++++++++---- 6 files changed, 223 insertions(+), 54 deletions(-) diff --git a/src/main/cpp/src/ParseURIJni.cpp b/src/main/cpp/src/ParseURIJni.cpp index 9079d99b9d..3af72687b6 100644 --- a/src/main/cpp/src/ParseURIJni.cpp +++ b/src/main/cpp/src/ParseURIJni.cpp @@ -47,4 +47,18 @@ JNIEXPORT jlong JNICALL Java_com_nvidia_spark_rapids_jni_ParseURI_parseHost(JNIE } CATCH_STD(env, 0); } + +JNIEXPORT jlong JNICALL Java_com_nvidia_spark_rapids_jni_ParseURI_parseQuery(JNIEnv* env, + jclass, + jlong input_column) +{ + JNI_NULL_CHECK(env, input_column, "input column is null", 0); + + try { + cudf::jni::auto_set_device(env); + auto const input = reinterpret_cast(input_column); + return cudf::jni::ptr_as_jlong(spark_rapids_jni::parse_uri_to_query(*input).release()); + } + CATCH_STD(env, 0); +} } diff --git a/src/main/cpp/src/parse_uri.cu b/src/main/cpp/src/parse_uri.cu index 13a8effb37..d75dfc18c1 100644 --- a/src/main/cpp/src/parse_uri.cu +++ b/src/main/cpp/src/parse_uri.cu @@ -22,6 +22,7 @@ #include #include #include +#include #include #include #include @@ -47,10 +48,20 @@ struct uri_parts { string_view userinfo; string_view port; string_view opaque; - bool valid{false}; + uint32_t valid{0}; }; -enum class URI_chunks : int8_t { PROTOCOL, HOST, AUTHORITY, PATH, QUERY, USERINFO }; +enum class URI_chunks : int8_t { + PROTOCOL, + HOST, + AUTHORITY, + PATH, + FRAGMENT, + QUERY, + USERINFO, + PORT, + OPAQUE +}; enum class chunk_validity : int8_t { VALID, INVALID, FATAL }; @@ -436,7 +447,7 @@ bool __device__ validate_path(string_view path) // path can be alphanum and @[]_-!.~'()*?/&,;:$+= return validate_chunk(path, [] __device__(string_view::const_iterator iter) { auto const c = *iter; - if (c != '!' && c != '$' && !(c >= '&' && c <= ';') && c != '=' && !(c >= '@' && c <= 'Z') && + if (c != '!' && c != '$' && !(c >= '&' && c <= ';') && c != '=' && !(c >= '?' && c <= 'Z') && c != '_' && !(c >= 'a' && c <= 'z') && c != '~') { return false; } @@ -474,6 +485,7 @@ uri_parts __device__ validate_uri(const char* str, int len) { uri_parts ret; + auto const original_str = str; // look for :/# characters. int col = -1; int slash = -1; @@ -503,9 +515,10 @@ uri_parts __device__ validate_uri(const char* str, int len) if (hash >= 0) { ret.fragment = {str + hash + 1, len - hash - 1}; if (!validate_fragment(ret.fragment)) { - ret.valid = false; + ret.valid = 0; return ret; } + ret.valid |= (1 << static_cast(URI_chunks::FRAGMENT)); len = hash; @@ -519,9 +532,10 @@ uri_parts __device__ validate_uri(const char* str, int len) // we have a scheme up to the : ret.scheme = {str, col}; if (!validate_scheme(ret.scheme)) { - ret.valid = false; + ret.valid = 0; return ret; } + ret.valid |= (1 << static_cast(URI_chunks::PROTOCOL)); // skip over scheme auto const skip = col + 1; @@ -534,20 +548,22 @@ uri_parts __device__ validate_uri(const char* str, int len) // no more string to parse is an error if (len <= 0) { - ret.valid = false; + ret.valid = 0; return ret; } - // If we have a '/' as the next character, we have a heirarchical uri. If not it is opaque. - bool const heirarchical = str[0] == '/'; + // If we have a '/' as the next character or this is still the start of the string, we have a + // heirarchical uri. If not it is opaque. + bool const heirarchical = str[0] == '/' || str == original_str; if (heirarchical) { // a '?' will break this into query and path/authority if (question >= 0) { ret.query = {str + question + 1, len - question - 1}; if (!validate_query(ret.query)) { - ret.valid = false; + ret.valid = 0; return ret; } + ret.valid |= (1 << static_cast(URI_chunks::QUERY)); } auto const path_len = question >= 0 ? question : len; @@ -567,17 +583,17 @@ uri_parts __device__ validate_uri(const char* str, int len) if (next_slash == -1 && ret.authority.size_bytes() == 0 && ret.query.size_bytes() == 0 && ret.fragment.size_bytes() == 0) { // invalid! - but spark like to return things as long as you don't have illegal characters - // ret.valid = false; - ret.valid = true; + // ret.valid = 0; return ret; } if (ret.authority.size_bytes() > 0) { auto ipv6_address = ret.authority.size_bytes() > 2 && *ret.authority.begin() == '['; if (!validate_authority(ret.authority, ipv6_address)) { - ret.valid = false; + ret.valid = 0; return ret; } + ret.valid |= (1 << static_cast(URI_chunks::AUTHORITY)); // Inspect the authority for userinfo, host, and port const char* auth = ret.authority.data(); @@ -604,9 +620,11 @@ uri_parts __device__ validate_uri(const char* str, int len) if (amp > 0) { ret.userinfo = {auth, amp}; if (!validate_userinfo(ret.userinfo)) { - ret.valid = false; + ret.valid = 0; return ret; } + ret.valid |= (1 << static_cast(URI_chunks::USERINFO)); + // skip over the @ amp++; @@ -617,36 +635,39 @@ uri_parts __device__ validate_uri(const char* str, int len) // Found a port, attempt to parse it ret.port = {auth + last_colon + 1, auth_size - last_colon - 1}; if (!validate_port(ret.port)) { - ret.valid = false; + ret.valid = 0; return ret; } + ret.valid |= (1 << static_cast(URI_chunks::PORT)); ret.host = {auth, last_colon}; } else { ret.host = {auth, auth_size}; } auto host_ret = validate_host(ret.host); switch (host_ret) { - case chunk_validity::FATAL: ret.valid = false; return ret; + case chunk_validity::FATAL: ret.valid = 0; return ret; case chunk_validity::INVALID: ret.host = {}; break; + case chunk_validity::VALID: ret.valid |= (1 << static_cast(URI_chunks::HOST)); break; } } } else { // path with no authority - ret.path = {str, len}; + ret.path = {str, path_len}; } if (!validate_path(ret.path)) { - ret.valid = false; + ret.valid = 0; return ret; } + ret.valid |= (1 << static_cast(URI_chunks::PATH)); } else { ret.opaque = {str, len}; if (!validate_opaque(ret.opaque)) { - ret.valid = false; + ret.valid = 0; return ret; } + ret.valid |= (1 << static_cast(URI_chunks::OPAQUE)); } - ret.valid = true; return ret; } @@ -697,7 +718,7 @@ __global__ void parse_uri_char_counter(column_device_view const in_strings, auto const string_length = in_string.size_bytes(); auto const uri = validate_uri(in_chars, string_length); - if (!uri.valid) { + if ((uri.valid & (1 << static_cast(chunk))) == 0) { out_lengths[row_idx] = 0; clear_bit(out_validity, row_idx); } else { @@ -727,11 +748,18 @@ __global__ void parse_uri_char_counter(column_device_view const in_strings, out_lengths[row_idx] = uri.userinfo.size_bytes(); out_offsets[row_idx] = uri.userinfo.data() - base_ptr; break; - } - - if (out_lengths[row_idx] == 0) { - // A URI can be valid, but still have no data for a specific chunk - clear_bit(out_validity, row_idx); + case URI_chunks::PORT: + out_lengths[row_idx] = uri.port.size_bytes(); + out_offsets[row_idx] = uri.port.data() - base_ptr; + break; + case URI_chunks::FRAGMENT: + out_lengths[row_idx] = uri.fragment.size_bytes(); + out_offsets[row_idx] = uri.fragment.data() - base_ptr; + break; + case URI_chunks::OPAQUE: + out_lengths[row_idx] = uri.opaque.size_bytes(); + out_offsets[row_idx] = uri.opaque.data() - base_ptr; + break; } } } @@ -858,4 +886,13 @@ std::unique_ptr parse_uri_to_host(strings_column_view const& input, return detail::parse_uri(input, detail::URI_chunks::HOST, stream, mr); } +std::unique_ptr parse_uri_to_query(strings_column_view const& input, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr) +{ + CUDF_FUNC_RANGE(); + return detail::parse_uri( + input, detail::URI_chunks::QUERY, stream, rmm::mr::get_current_device_resource()); +} + } // namespace spark_rapids_jni \ No newline at end of file diff --git a/src/main/cpp/src/parse_uri.hpp b/src/main/cpp/src/parse_uri.hpp index 0a76cec1b4..07f6f9cd46 100644 --- a/src/main/cpp/src/parse_uri.hpp +++ b/src/main/cpp/src/parse_uri.hpp @@ -49,7 +49,20 @@ std::unique_ptr parse_uri_to_protocol( */ std::unique_ptr parse_uri_to_host( cudf::strings_column_view const& input, - rmm::cuda_stream_view stream = rmm::cuda_stream_default, + rmm::cuda_stream_view stream = cudf::get_default_stream(), + rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()); + +/** + * @brief Parse query and copy from the input string column to the output char buffer. + * + * @param input Input string column of URIs to parse + * @param stream Stream on which to operate. + * @param mr Memory resource for returned column + * @return std::unique_ptr String column of queries parsed. + */ +std::unique_ptr parse_uri_to_query( + cudf::strings_column_view const& input, + rmm::cuda_stream_view stream = cudf::get_default_stream(), rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()); } // namespace spark_rapids_jni diff --git a/src/main/cpp/tests/parse_uri.cpp b/src/main/cpp/tests/parse_uri.cpp index 1112fea232..36ebbeacc0 100644 --- a/src/main/cpp/tests/parse_uri.cpp +++ b/src/main/cpp/tests/parse_uri.cpp @@ -19,10 +19,12 @@ #include #include #include +#include #include struct ParseURIProtocolTests : public cudf::test::BaseFixture {}; struct ParseURIHostTests : public cudf::test::BaseFixture {}; +struct ParseURIQueryTests : public cudf::test::BaseFixture {}; enum class test_types { SIMPLE, @@ -30,6 +32,7 @@ enum class test_types { IPv6, IPv4, UTF8, + QUERY, }; namespace { @@ -123,6 +126,15 @@ cudf::test::strings_column_wrapper get_test_data(test_types t) "http://✪↩d⁚f„⁈.ws/123", "https:// /path/to/file", }); + case test_types::QUERY: + return cudf::test::strings_column_wrapper({ + "https://www.nvidia.com/path?param0=1¶m2=3¶m4=5", + "https:// /?params=5&cloth=0&metal=1", + "https://[2001:db8::2:1]:443/parms/in/the/uri?a=b", + "https://[::1]/?invalid=param&f„⁈.=7", + "https://[::1]/?invalid=param&~.=!@&^", + "userinfo@www.nvidia.com/path?query=1#Ref", + }); default: CUDF_FAIL("Test type unsupported!"); return cudf::test::strings_column_wrapper(); } } @@ -136,7 +148,7 @@ TEST_F(ParseURIProtocolTests, Simple) cudf::test::strings_column_wrapper const expected( {"https", "http", "file", "smb", "http", "file", "", "", ""}, {1, 1, 1, 1, 1, 1, 0, 0, 0}); - CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(result->view(), expected); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected, result->view()); } TEST_F(ParseURIProtocolTests, SparkEdges) @@ -185,7 +197,7 @@ TEST_F(ParseURIProtocolTests, SparkEdges) {1, 1, 1, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1}); - CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(result->view(), expected); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected, result->view()); } TEST_F(ParseURIProtocolTests, IP6) @@ -197,7 +209,7 @@ TEST_F(ParseURIProtocolTests, IP6) {"https", "https", "https", "https", "http", "https", "https", "https", "", ""}, {1, 1, 1, 1, 1, 1, 1, 1, 0, 0}); - CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(result->view(), expected); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected, result->view()); } TEST_F(ParseURIProtocolTests, IP4) @@ -208,7 +220,7 @@ TEST_F(ParseURIProtocolTests, IP4) cudf::test::strings_column_wrapper const expected( {"https", "https", "https", "https", "https", "https"}); - CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(result->view(), expected); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected, result->view()); } TEST_F(ParseURIProtocolTests, UTF8) @@ -218,7 +230,7 @@ TEST_F(ParseURIProtocolTests, UTF8) cudf::test::strings_column_wrapper const expected({"https", "http", "http", ""}, {1, 1, 1, 0}); - CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(result->view(), expected); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected, result->view()); } TEST_F(ParseURIHostTests, Simple) @@ -230,7 +242,7 @@ TEST_F(ParseURIHostTests, Simple) {"www.nvidia.com", "www.nvidia.com", "path", "network", "", "", "", "", ""}, {1, 1, 1, 1, 0, 0, 0, 0, 0}); - CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(result->view(), expected); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected, result->view()); } TEST_F(ParseURIHostTests, SparkEdges) @@ -279,7 +291,7 @@ TEST_F(ParseURIHostTests, SparkEdges) {1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0}); - CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(result->view(), expected); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected, result->view()); } TEST_F(ParseURIHostTests, IP6) @@ -299,7 +311,7 @@ TEST_F(ParseURIHostTests, IP6) ""}, {1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 0, 0}); - CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(result->view(), expected); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected, result->view()); } TEST_F(ParseURIHostTests, IP4) @@ -310,7 +322,7 @@ TEST_F(ParseURIHostTests, IP4) cudf::test::strings_column_wrapper const expected( {"192.168.1.100", "192.168.1.100", "", "", "", ""}, {1, 1, 0, 0, 0, 0}); - CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(result->view(), expected); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected, result->view()); } TEST_F(ParseURIHostTests, UTF8) @@ -320,5 +332,42 @@ TEST_F(ParseURIHostTests, UTF8) cudf::test::strings_column_wrapper const expected({"nvidia.com", "", "", ""}, {1, 0, 0, 0}); - CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(result->view(), expected); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected, result->view()); +} + +TEST_F(ParseURIQueryTests, Simple) +{ + auto const col = get_test_data(test_types::SIMPLE); + auto const result = spark_rapids_jni::parse_uri_to_query(cudf::strings_column_view{col}); + + cudf::test::strings_column_wrapper const expected({"param1=2", "", "", "", "", "", "", "", ""}, + {1, 0, 0, 0, 0, 0, 0, 0, 0}); + + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected, result->view()); +} + +TEST_F(ParseURIQueryTests, SparkEdges) +{ + auto const col = get_test_data(test_types::SPARK_EDGES); + auto const result = spark_rapids_jni::parse_uri_to_query(cudf::strings_column_view{col}); + + cudf::test::strings_column_wrapper const expected( + {"", "", "", "", "", "", "", "", "", "", "", "", "", "", "", + "", // empty + "?", "?/", "", "query;p2", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", ""}, + {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, + 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected, result->view()); +} + +TEST_F(ParseURIQueryTests, Queries) +{ + auto const col = get_test_data(test_types::QUERY); + auto const result = spark_rapids_jni::parse_uri_to_query(cudf::strings_column_view{col}); + + cudf::test::strings_column_wrapper const expected( + {"param0=1¶m2=3¶m4=5", "", "a=b", "invalid=param&f„⁈.=7", "", "query=1"}, + {1, 0, 1, 1, 0, 1}); + + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected, result->view()); } diff --git a/src/main/java/com/nvidia/spark/rapids/jni/ParseURI.java b/src/main/java/com/nvidia/spark/rapids/jni/ParseURI.java index 0e14f388d4..8f82bfc908 100644 --- a/src/main/java/com/nvidia/spark/rapids/jni/ParseURI.java +++ b/src/main/java/com/nvidia/spark/rapids/jni/ParseURI.java @@ -49,7 +49,18 @@ public static ColumnVector parseURIHost(ColumnView uriColumn) { return new ColumnVector(parseHost(uriColumn.getNativeView())); } + /** + * Parse query for each URI from the incoming column. + * + * @param URIColumn The input strings column in which each row contains a URI. + * @return A string column with query data extracted. + */ + public static ColumnVector parseURIQuery(ColumnView uriColumn) { + assert uriColumn.getType().equals(DType.STRING) : "Input type must be String"; + return new ColumnVector(parseQuery(uriColumn.getNativeView())); + } + private static native long parseProtocol(long jsonColumnHandle); private static native long parseHost(long jsonColumnHandle); - + private static native long parseQuery(long jsonColumnHandle); } diff --git a/src/test/java/com/nvidia/spark/rapids/jni/ParseURITest.java b/src/test/java/com/nvidia/spark/rapids/jni/ParseURITest.java index c6e3b06ed1..ca76df2bf3 100644 --- a/src/test/java/com/nvidia/spark/rapids/jni/ParseURITest.java +++ b/src/test/java/com/nvidia/spark/rapids/jni/ParseURITest.java @@ -25,9 +25,8 @@ import ai.rapids.cudf.ColumnVector; public class ParseURITest { - void buildExpectedAndRun(String[] testData) { + void testProtocol(String[] testData) { String[] expectedProtocolStrings = new String[testData.length]; - String[] expectedHostStrings = new String[testData.length]; for (int i=0; i