Skip to content

Commit

Permalink
Adding query support to parseURI (#1652)
Browse files Browse the repository at this point in the history
* Adding query to parseuri

Signed-off-by: Mike Wilson <[email protected]>
  • Loading branch information
hyperbolic2346 authored Dec 19, 2023
1 parent dadc7a0 commit 98dc423
Show file tree
Hide file tree
Showing 6 changed files with 223 additions and 54 deletions.
14 changes: 14 additions & 0 deletions src/main/cpp/src/ParseURIJni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,4 +47,18 @@ JNIEXPORT jlong JNICALL Java_com_nvidia_spark_rapids_jni_ParseURI_parseHost(JNIE
}
CATCH_STD(env, 0);
}

/**
 * JNI bridge for `com.nvidia.spark.rapids.jni.ParseURI.parseQuery`.
 *
 * Reinterprets `input_column` as a pointer to a `cudf::column_view`, runs
 * `spark_rapids_jni::parse_uri_to_query` on it, and hands ownership of the resulting column
 * back to the caller as a `jlong` handle.
 *
 * A zero `input_column` is rejected up front through `JNI_NULL_CHECK`, and anything thrown
 * inside the `try` body is routed through `CATCH_STD`; both macros are given `0` as the value
 * to return on failure.
 */
JNIEXPORT jlong JNICALL Java_com_nvidia_spark_rapids_jni_ParseURI_parseQuery(
  JNIEnv* env, jclass, jlong input_column)
{
  JNI_NULL_CHECK(env, input_column, "input column is null", 0);

  try {
    cudf::jni::auto_set_device(env);

    // The Java side passes the native column_view address as an opaque 64-bit handle.
    auto const* uri_column = reinterpret_cast<cudf::column_view const*>(input_column);

    // Parse first, then release the unique_ptr so the caller owns the new column.
    auto query_column = spark_rapids_jni::parse_uri_to_query(*uri_column);
    return cudf::jni::ptr_as_jlong(query_column.release());
  }
  CATCH_STD(env, 0);
}
}
87 changes: 62 additions & 25 deletions src/main/cpp/src/parse_uri.cu
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include <cudf/detail/valid_if.cuh>
#include <cudf/lists/lists_column_device_view.cuh>
#include <cudf/scalar/scalar_factories.hpp>
#include <cudf/strings/convert/convert_urls.hpp>
#include <cudf/strings/detail/strings_children.cuh>
#include <cudf/strings/detail/utilities.cuh>
#include <cudf/strings/string_view.cuh>
Expand All @@ -47,10 +48,20 @@ struct uri_parts {
string_view userinfo;
string_view port;
string_view opaque;
bool valid{false};
uint32_t valid{0};
};

enum class URI_chunks : int8_t { PROTOCOL, HOST, AUTHORITY, PATH, QUERY, USERINFO };
// Identifies one component of a parsed URI.
//
// Besides selecting which component a caller wants extracted, each enumerator's underlying
// value is used as a bit position in `uri_parts::valid`: `validate_uri` sets
// `1 << static_cast<int>(chunk)` once that component has been found and validated, and
// `parse_uri_char_counter` tests the same bit to decide whether a row has a value for the
// requested chunk. Because the values are bit positions, they must stay below the bit width
// of `uri_parts::valid`.
enum class URI_chunks : int8_t {
PROTOCOL,   // bit set after the scheme passes validate_scheme
HOST,       // bit set when validate_host reports chunk_validity::VALID
AUTHORITY,  // bit set after a non-empty authority passes validate_authority
PATH,       // bit set after the path passes validate_path
FRAGMENT,   // bit set after the text following '#' passes validate_fragment
QUERY,      // bit set after the text following '?' passes validate_query
USERINFO,   // bit set after the userinfo portion of the authority passes validate_userinfo
PORT,       // bit set after the port portion of the authority passes validate_port
OPAQUE      // bit set after a non-hierarchical remainder passes validate_opaque
};

// Three-state result of validating a single URI component.
//
// As consumed by `validate_uri` for the host check:
//   VALID   - the component is accepted and its bit is set in `uri_parts::valid`.
//   INVALID - the component alone is discarded (the host view is reset to empty) while
//             parsing of the rest of the URI continues.
//   FATAL   - the whole URI is rejected: `uri_parts::valid` is cleared and parsing stops.
enum class chunk_validity : int8_t { VALID, INVALID, FATAL };

Expand Down Expand Up @@ -436,7 +447,7 @@ bool __device__ validate_path(string_view path)
// path can be alphanum and @[]_-!.~'()*?/&,;:$+=
return validate_chunk(path, [] __device__(string_view::const_iterator iter) {
auto const c = *iter;
if (c != '!' && c != '$' && !(c >= '&' && c <= ';') && c != '=' && !(c >= '@' && c <= 'Z') &&
if (c != '!' && c != '$' && !(c >= '&' && c <= ';') && c != '=' && !(c >= '?' && c <= 'Z') &&
c != '_' && !(c >= 'a' && c <= 'z') && c != '~') {
return false;
}
Expand Down Expand Up @@ -474,6 +485,7 @@ uri_parts __device__ validate_uri(const char* str, int len)
{
uri_parts ret;

auto const original_str = str;
// look for :/# characters.
int col = -1;
int slash = -1;
Expand Down Expand Up @@ -503,9 +515,10 @@ uri_parts __device__ validate_uri(const char* str, int len)
if (hash >= 0) {
ret.fragment = {str + hash + 1, len - hash - 1};
if (!validate_fragment(ret.fragment)) {
ret.valid = false;
ret.valid = 0;
return ret;
}
ret.valid |= (1 << static_cast<int>(URI_chunks::FRAGMENT));

len = hash;

Expand All @@ -519,9 +532,10 @@ uri_parts __device__ validate_uri(const char* str, int len)
// we have a scheme up to the :
ret.scheme = {str, col};
if (!validate_scheme(ret.scheme)) {
ret.valid = false;
ret.valid = 0;
return ret;
}
ret.valid |= (1 << static_cast<int>(URI_chunks::PROTOCOL));

// skip over scheme
auto const skip = col + 1;
Expand All @@ -534,20 +548,22 @@ uri_parts __device__ validate_uri(const char* str, int len)

// no more string to parse is an error
if (len <= 0) {
ret.valid = false;
ret.valid = 0;
return ret;
}

// If we have a '/' as the next character, we have a heirarchical uri. If not it is opaque.
bool const heirarchical = str[0] == '/';
// If we have a '/' as the next character or this is still the start of the string, we have a
// hierarchical URI. If not it is opaque.
bool const heirarchical = str[0] == '/' || str == original_str;
if (heirarchical) {
// a '?' will break this into query and path/authority
if (question >= 0) {
ret.query = {str + question + 1, len - question - 1};
if (!validate_query(ret.query)) {
ret.valid = false;
ret.valid = 0;
return ret;
}
ret.valid |= (1 << static_cast<int>(URI_chunks::QUERY));
}
auto const path_len = question >= 0 ? question : len;

Expand All @@ -567,17 +583,17 @@ uri_parts __device__ validate_uri(const char* str, int len)
if (next_slash == -1 && ret.authority.size_bytes() == 0 && ret.query.size_bytes() == 0 &&
ret.fragment.size_bytes() == 0) {
// invalid! - but Spark likes to return things as long as you don't have illegal characters
// ret.valid = false;
ret.valid = true;
// ret.valid = 0;
return ret;
}

if (ret.authority.size_bytes() > 0) {
auto ipv6_address = ret.authority.size_bytes() > 2 && *ret.authority.begin() == '[';
if (!validate_authority(ret.authority, ipv6_address)) {
ret.valid = false;
ret.valid = 0;
return ret;
}
ret.valid |= (1 << static_cast<int>(URI_chunks::AUTHORITY));

// Inspect the authority for userinfo, host, and port
const char* auth = ret.authority.data();
Expand All @@ -604,9 +620,11 @@ uri_parts __device__ validate_uri(const char* str, int len)
if (amp > 0) {
ret.userinfo = {auth, amp};
if (!validate_userinfo(ret.userinfo)) {
ret.valid = false;
ret.valid = 0;
return ret;
}
ret.valid |= (1 << static_cast<int>(URI_chunks::USERINFO));

// skip over the @
amp++;

Expand All @@ -617,36 +635,39 @@ uri_parts __device__ validate_uri(const char* str, int len)
// Found a port, attempt to parse it
ret.port = {auth + last_colon + 1, auth_size - last_colon - 1};
if (!validate_port(ret.port)) {
ret.valid = false;
ret.valid = 0;
return ret;
}
ret.valid |= (1 << static_cast<int>(URI_chunks::PORT));
ret.host = {auth, last_colon};
} else {
ret.host = {auth, auth_size};
}
auto host_ret = validate_host(ret.host);
switch (host_ret) {
case chunk_validity::FATAL: ret.valid = false; return ret;
case chunk_validity::FATAL: ret.valid = 0; return ret;
case chunk_validity::INVALID: ret.host = {}; break;
case chunk_validity::VALID: ret.valid |= (1 << static_cast<int>(URI_chunks::HOST)); break;
}
}
} else {
// path with no authority
ret.path = {str, len};
ret.path = {str, path_len};
}
if (!validate_path(ret.path)) {
ret.valid = false;
ret.valid = 0;
return ret;
}
ret.valid |= (1 << static_cast<int>(URI_chunks::PATH));
} else {
ret.opaque = {str, len};
if (!validate_opaque(ret.opaque)) {
ret.valid = false;
ret.valid = 0;
return ret;
}
ret.valid |= (1 << static_cast<int>(URI_chunks::OPAQUE));
}

ret.valid = true;
return ret;
}

Expand Down Expand Up @@ -697,7 +718,7 @@ __global__ void parse_uri_char_counter(column_device_view const in_strings,
auto const string_length = in_string.size_bytes();

auto const uri = validate_uri(in_chars, string_length);
if (!uri.valid) {
if ((uri.valid & (1 << static_cast<int>(chunk))) == 0) {
out_lengths[row_idx] = 0;
clear_bit(out_validity, row_idx);
} else {
Expand Down Expand Up @@ -727,11 +748,18 @@ __global__ void parse_uri_char_counter(column_device_view const in_strings,
out_lengths[row_idx] = uri.userinfo.size_bytes();
out_offsets[row_idx] = uri.userinfo.data() - base_ptr;
break;
}

if (out_lengths[row_idx] == 0) {
// A URI can be valid, but still have no data for a specific chunk
clear_bit(out_validity, row_idx);
case URI_chunks::PORT:
out_lengths[row_idx] = uri.port.size_bytes();
out_offsets[row_idx] = uri.port.data() - base_ptr;
break;
case URI_chunks::FRAGMENT:
out_lengths[row_idx] = uri.fragment.size_bytes();
out_offsets[row_idx] = uri.fragment.data() - base_ptr;
break;
case URI_chunks::OPAQUE:
out_lengths[row_idx] = uri.opaque.size_bytes();
out_offsets[row_idx] = uri.opaque.data() - base_ptr;
break;
}
}
}
Expand Down Expand Up @@ -858,4 +886,13 @@ std::unique_ptr<column> parse_uri_to_host(strings_column_view const& input,
return detail::parse_uri(input, detail::URI_chunks::HOST, stream, mr);
}

/**
 * @brief Extract the query component from each URI in a strings column.
 *
 * Thin public wrapper that forwards to `detail::parse_uri` with `URI_chunks::QUERY`.
 *
 * @param input  Strings column of URIs to parse.
 * @param stream CUDA stream passed through to `detail::parse_uri`.
 * @param mr     Memory resource passed through to `detail::parse_uri` for the returned column.
 * @return Column produced by `detail::parse_uri` for the query chunk.
 */
std::unique_ptr<column> parse_uri_to_query(strings_column_view const& input,
                                           rmm::cuda_stream_view stream,
                                           rmm::mr::device_memory_resource* mr)
{
  CUDF_FUNC_RANGE();
  // Forward the caller-supplied memory resource. Substituting the current device resource
  // here would silently ignore `mr` and diverge from parse_uri_to_host, which passes it on.
  return detail::parse_uri(input, detail::URI_chunks::QUERY, stream, mr);
}

} // namespace spark_rapids_jni
15 changes: 14 additions & 1 deletion src/main/cpp/src/parse_uri.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,20 @@ std::unique_ptr<cudf::column> parse_uri_to_protocol(
*/
std::unique_ptr<cudf::column> parse_uri_to_host(
cudf::strings_column_view const& input,
rmm::cuda_stream_view stream = rmm::cuda_stream_default,
rmm::cuda_stream_view stream = cudf::get_default_stream(),
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource());

/**
 * @brief Parse the query component out of each URI in the input strings column.
 *
 * The query is the text that follows the first '?' of a hierarchical URI, with any
 * '#fragment' already removed. Rows whose URI fails validation, or that contain no valid
 * query, get a zero length and have their validity bit cleared in the output
 * (NOTE(review): i.e. they should surface as null rows; confirm against the tests).
 *
 * @param input Input string column of URIs to parse.
 * @param stream Stream on which to operate.
 * @param mr Memory resource for the returned column.
 * @return std::unique_ptr<column> String column holding the parsed query of each row.
 */
std::unique_ptr<cudf::column> parse_uri_to_query(
cudf::strings_column_view const& input,
rmm::cuda_stream_view stream = cudf::get_default_stream(),
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource());

} // namespace spark_rapids_jni
Loading

0 comments on commit 98dc423

Please sign in to comment.