Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow cuco::arrow_filter_policy to accept a custom implementation of xxhash_64 #642

Merged
8 changes: 6 additions & 2 deletions include/cuco/bloom_filter_policies.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

#include <cuco/detail/bloom_filter/arrow_filter_policy.cuh>
#include <cuco/detail/bloom_filter/default_filter_policy_impl.cuh>
#include <cuco/hash_functions.cuh>

#include <cstdint>

Expand All @@ -28,9 +29,12 @@ namespace cuco {
* fingerprint.
*
* @tparam Key The type of the values to generate a fingerprint for.
* @tparam Hash Hash function used to generate a key's fingerprint. By default, cuco::xxhash_64 will
* be used.
*
*/
template <class Key>
using arrow_filter_policy = detail::arrow_filter_policy<Key>;
template <class Key, class Hash = cuco::xxhash_64<Key>>
mhaseeb123 marked this conversation as resolved.
Show resolved Hide resolved
using arrow_filter_policy = detail::arrow_filter_policy<Key, Hash>;

/**
* @brief The default policy that defines how a Blocked Bloom Filter generates and stores a key's
Expand Down
6 changes: 3 additions & 3 deletions include/cuco/detail/bloom_filter/arrow_filter_policy.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -79,11 +79,11 @@ namespace cuco::detail {
*
* @tparam Key The type of the values to generate a fingerprint for.
*/
template <class Key>
template <class Key, class Hash>
class arrow_filter_policy {
public:
using hasher = cuco::xxhash_64<Key>; ///< xxhash_64 hasher for Arrow bloom filter policy
using word_type = std::uint32_t; ///< uint32_t for Arrow bloom filter policy
using hasher = Hash; ///< Hash function for Arrow bloom filter policy
using word_type = std::uint32_t; ///< uint32_t for Arrow bloom filter policy
using hash_argument_type = typename hasher::argument_type; ///< Hash function input type
using hash_result_type = decltype(std::declval<hasher>()(
std::declval<hash_argument_type>())); ///< hash function output type
Expand Down
26 changes: 26 additions & 0 deletions include/cuco/detail/hash_functions/xxhash.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,22 @@
#include <cuco/extent.cuh>

#include <cuda/std/cstddef>
#include <cuda/std/span>

#include <cstdint>
#include <type_traits>

// Helper trait to check if a type is `span` like.
// i.e., if it has `data()` and `size()` functions
template <typename T, typename = void>
struct is_span_like : std::false_type {};
mhaseeb123 marked this conversation as resolved.
Show resolved Hide resolved

// Specialization for `span` like type.
template <typename T>
struct is_span_like<
T,
std::void_t<decltype(std::declval<T>().data()), decltype(std::declval<T>().size())>>
mhaseeb123 marked this conversation as resolved.
Show resolved Hide resolved
: std::true_type {};
mhaseeb123 marked this conversation as resolved.
Show resolved Hide resolved

namespace cuco::detail {

Expand Down Expand Up @@ -283,6 +297,18 @@ struct XXHash_64 {
}
}

/**
* @brief Returns a hash value for its `span` like argument, as a value of type `result_type`.
*
* @param key The input argument to hash
* @return The resulting hash value for `span` like `key`
*/
template <typename T = Key, typename = std::enable_if_t<is_span_like<T>::value>>
mhaseeb123 marked this conversation as resolved.
Show resolved Hide resolved
constexpr result_type __host__ __device__ operator()(Key const& key) const noexcept
{
return compute_hash(key.data(), key.size());
}

/**
* @brief Returns a hash value for its argument, as a value of type `result_type`.
*
Expand Down
44 changes: 34 additions & 10 deletions tests/bloom_filter/unique_sequence_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@
#include <test_utils.hpp>

#include <cuco/bloom_filter.cuh>
#include <cuco/utility/key_generator.cuh>

#include <cuda/functional>
#include <cuda/std/span>
#include <thrust/device_vector.h>
#include <thrust/execution_policy.h>
#include <thrust/functional.h>
Expand All @@ -30,15 +32,11 @@

using size_type = int32_t;

template <typename Filter>
void test_unique_sequence(Filter& filter, size_type num_keys)
template <typename Filter, typename Key>
void test_unique_sequence(Filter& filter,
thrust::device_vector<Key> const& keys,
size_type num_keys)
{
using Key = typename Filter::key_type;

thrust::device_vector<Key> keys(num_keys);

thrust::sequence(thrust::device, keys.begin(), keys.end());

thrust::device_vector<bool> contained(num_keys, false);

auto is_even =
Expand Down Expand Up @@ -102,7 +100,11 @@ TEMPLATE_TEST_CASE_SIG(

auto filter = filter_type{1000, {}, {pattern_bits}};

test_unique_sequence(filter, num_keys);
// Generate keys
thrust::device_vector<Key> keys(num_keys);
thrust::sequence(thrust::device, keys.begin(), keys.end());

test_unique_sequence(filter, keys, num_keys);
}

TEMPLATE_TEST_CASE_SIG("Unique sequence with arrow policy",
Expand All @@ -118,5 +120,27 @@ TEMPLATE_TEST_CASE_SIG("Unique sequence with arrow policy",

auto filter = filter_type{1000};

test_unique_sequence(filter, num_keys);
// Generate keys
thrust::device_vector<Key> keys(num_keys);
thrust::sequence(thrust::device, keys.begin(), keys.end());

test_unique_sequence(filter, keys, num_keys);
}

TEMPLATE_TEST_CASE_SIG("Unique string sequence with arrow policy",
"",
((class Key, class Policy), Key, Policy),
(cuda::std::span<cuda::std::byte>,
cuco::arrow_filter_policy<cuda::std::span<cuda::std::byte>>))
{
using filter_type =
cuco::bloom_filter<Key, cuco::extent<size_t>, cuda::thread_scope_device, Policy>;
constexpr size_type num_keys{400};

auto filter = filter_type{1000};

// Generate keys (string spans) and the actual string data
auto [keys, data] = cuco::utility::generate_random_byte_sequences(num_keys, 20, 50);

test_unique_sequence(filter, keys, num_keys);
}