From 59a3898001fa23c3e3bc99e888b4b5b0e65dfbff Mon Sep 17 00:00:00 2001 From: rhdong Date: Sat, 14 Sep 2024 14:04:17 -0700 Subject: [PATCH] [Feat] add `repeat`, `sparsity`, `eval_n_elements` APIs to `bitset` --- cpp/include/raft/core/bitset.cuh | 98 +++++++++++++++++++++++++++++ cpp/include/raft/core/bitset.hpp | 53 ++++++++++++++++ cpp/test/core/bitset.cu | 102 ++++++++++++++++++++++++++++--- 3 files changed, 243 insertions(+), 10 deletions(-) diff --git a/cpp/include/raft/core/bitset.cuh b/cpp/include/raft/core/bitset.cuh index b6e6128eca..94958c95e0 100644 --- a/cpp/include/raft/core/bitset.cuh +++ b/cpp/include/raft/core/bitset.cuh @@ -26,6 +26,8 @@ #include #include +#include + #include namespace raft::core { @@ -60,6 +62,102 @@ _RAFT_DEVICE void bitset_view::set(const index_t sample_index } } +template +struct bitset_copy_functor { + const bitset_t* bitset_ptr; + bitset_t* output_device_ptr; + index_t valid_bits; + index_t bits_per_element; + index_t total_bits; + + bitset_copy_functor(const bitset_t* _bitset_ptr, + bitset_t* _output_device_ptr, + index_t _valid_bits, + index_t _bits_per_element, + index_t _total_bits) + : bitset_ptr(_bitset_ptr), + output_device_ptr(_output_device_ptr), + valid_bits(_valid_bits), + bits_per_element(_bits_per_element), + total_bits(_total_bits) + { + } + + __device__ void operator()(index_t i) + { + if (i < total_bits) { + index_t src_bit_index = i % valid_bits; + index_t dst_bit_index = i; + + index_t src_element_index = src_bit_index / bits_per_element; + index_t src_bit_offset = src_bit_index % bits_per_element; + + index_t dst_element_index = dst_bit_index / bits_per_element; + index_t dst_bit_offset = dst_bit_index % bits_per_element; + + bitset_t src_element = bitset_ptr[src_element_index]; + bitset_t src_bit = (src_element >> src_bit_offset) & 1; + + if (src_bit) { + atomicOr(output_device_ptr + dst_element_index, bitset_t(1) << dst_bit_offset); + } else { + atomicAnd(output_device_ptr + dst_element_index, ~(bitset_t(1) << dst_bit_offset)); + } + } + } +}; + +template +void bitset_view::repeat(const raft::resources& res, + index_t times, + bitset_t* output_device_ptr) const +{ + auto thrust_policy = raft::resource::get_thrust_policy(res); + constexpr index_t bits_per_element = sizeof(bitset_t) * 8; + + if (bitset_len_ % bits_per_element == 0) { + index_t num_elements_to_copy = bitset_len_ / bits_per_element; + + for (index_t i = 0; i < times; ++i) { + raft::copy(output_device_ptr + i * num_elements_to_copy, + bitset_ptr_, + num_elements_to_copy, + raft::resource::get_cuda_stream(res)); + } + } else { + index_t valid_bits = bitset_len_; + index_t total_bits = valid_bits * times; + index_t output_row_elements = (total_bits + bits_per_element - 1) / bits_per_element; + thrust::for_each_n(thrust_policy, + thrust::counting_iterator(0), + total_bits, + bitset_copy_functor( + bitset_ptr_, output_device_ptr, valid_bits, bits_per_element, total_bits)); + } +} + +template +double bitset_view::sparsity(const raft::resources& res) const +{ + index_t nnz_h = 0; + index_t size_h = this->size(); + auto stream = raft::resource::get_cuda_stream(res); + + if (0 == size_h) { return static_cast(1.0); } + + rmm::device_scalar nnz(0, stream); + + auto vector_view = raft::make_device_vector_view(data(), n_elements()); + auto nnz_view = raft::make_device_scalar_view(nnz.data()); + auto size_view = raft::make_host_scalar_view(&size_h); + + raft::popc(res, vector_view, size_view, nnz_view); + raft::copy(&nnz_h, nnz.data(), 1, stream); + + raft::resource::sync_stream(res, stream); + return static_cast((1.0 * (size_h - nnz_h)) / (1.0 * size_h)); +} + template bitset::bitset(const raft::resources& res, raft::device_vector_view mask_index, diff --git a/cpp/include/raft/core/bitset.hpp b/cpp/include/raft/core/bitset.hpp index 3608ee43fa..179e2d78f0 100644 --- a/cpp/include/raft/core/bitset.hpp +++ b/cpp/include/raft/core/bitset.hpp @@ -22,6 +22,8 @@ #include #include +#include + namespace raft::core { /** * @defgroup bitset Bitset @@ -104,6 +106,57 @@ struct bitset_view { return raft::make_device_vector_view(bitset_ptr_, n_elements()); } + /** + * @brief Repeats the bitset data and copies it to the output device pointer. + * + * This function takes the original bitset data stored in the device memory + * and repeats it a specified number of times into a new location in the device memory. + * The bits are copied bit-by-bit to ensure that even if the number of bits (bitset_len_) + * is not a multiple of the bitset element size (e.g., 32 for uint32_t), the bits are + * tightly packed without any gaps between rows. + * + * @param res RAFT resources for managing CUDA streams and execution policies. + * @param times Number of times the bitset data should be repeated in the output. + * @param output_device_ptr Device pointer where the repeated bitset data will be stored. + * + * The caller must ensure that the output device pointer has enough memory allocated + * to hold `times * bitset_len` bits, where `bitset_len` is the number of bits in the original + * bitset. This function uses Thrust parallel algorithms to efficiently perform the operation on + * the GPU. + */ + void repeat(const raft::resources& res, index_t times, bitset_t* output_device_ptr) const; + + /** + * @brief Calculate the sparsity (fraction of 0s) of the bitset. + * + * This function computes the sparsity of the bitset, defined as the ratio of unset bits (0s) + * to the total number of bits in the set. If the total number of bits is zero, the function + * returns 1.0, indicating the set is fully sparse. + * + * @param res RAFT resources for managing CUDA streams and execution policies. + * @return double The sparsity of the bitset, i.e., the fraction of unset bits. + * + * This API will synchronize on the stream of `res`. + */ + double sparsity(const raft::resources& res) const; + + /** + * @brief Calculates the number of `bitset_t` elements required to store a bitset. + * + * This function computes the number of `bitset_t` elements needed to store a bitset, ensuring + * that all bits are accounted for. If the bitset length is not a multiple of the `bitset_t` size + * (in bits), the calculation rounds up to include the remaining bits in an additional `bitset_t` + * element. + * + * @param bitset_len The total length of the bitset in bits. + * @return size_t The number of `bitset_t` elements required to store the bitset. + */ + static inline size_t eval_n_elements(size_t bitset_len) + { + const size_t bits_per_element = sizeof(bitset_t) * 8; + return (bitset_len + bits_per_element - 1) / bits_per_element; + } + private: bitset_t* bitset_ptr_; index_t bitset_len_; diff --git a/cpp/test/core/bitset.cu b/cpp/test/core/bitset.cu index b799297e8c..f130139f31 100644 --- a/cpp/test/core/bitset.cu +++ b/cpp/test/core/bitset.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -32,12 +32,13 @@ struct test_spec_bitset { uint64_t bitset_len; uint64_t mask_len; uint64_t query_len; + uint64_t repeat_times; }; auto operator<<(std::ostream& os, const test_spec_bitset& ss) -> std::ostream& { os << "bitset{bitset_len: " << ss.bitset_len << ", mask_len: " << ss.mask_len - << ", query_len: " << ss.query_len << "}"; + << ", query_len: " << ss.query_len << ", repeat_times: " << ss.repeat_times << "}"; return os; } @@ -80,6 +81,48 @@ void flip_cpu_bitset(std::vector& bitset) } } +template +void repeat_cpu_bitset(std::vector& input, + size_t input_bits, + size_t repeat, + std::vector& output) +{ + const size_t output_bits = input_bits * repeat; + const size_t output_units = (output_bits + sizeof(bitset_t) * 8 - 1) / (sizeof(bitset_t) * 8); + + std::memset(output.data(), 0, output_units * sizeof(bitset_t)); + + size_t output_bit_index = 0; + + for (size_t r = 0; r < repeat; ++r) { + for (size_t i = 0; i < input_bits; ++i) { + size_t input_unit_index = i / (sizeof(bitset_t) * 8); + size_t input_bit_offset = i % (sizeof(bitset_t) * 8); + bool bit = (input[input_unit_index] >> input_bit_offset) & 1; + + size_t output_unit_index = output_bit_index / (sizeof(bitset_t) * 8); + size_t output_bit_offset = output_bit_index % (sizeof(bitset_t) * 8); + + output[output_unit_index] |= (static_cast(bit) << output_bit_offset); + + ++output_bit_index; + } + } +} + +template +double sparsity_cpu_bitset(std::vector& data, size_t total_bits) +{ + size_t one_count = 0; + for (size_t i = 0; i < total_bits; ++i) { + size_t unit_index = i / (sizeof(bitset_t) * 8); + size_t bit_offset = i % (sizeof(bitset_t) * 8); + bool bit = (data[unit_index] >> bit_offset) & 1; + if (bit == 1) { ++one_count; } + } + return static_cast((total_bits - one_count) / (1.0 * total_bits)); +} + template class BitsetTest : public testing::TestWithParam { protected: @@ -87,13 +130,19 @@ class BitsetTest : public testing::TestWithParam { const test_spec_bitset spec; std::vector bitset_result; std::vector bitset_ref; + std::vector bitset_repeat_ref; + std::vector bitset_repeat_result; raft::resources res; public: explicit BitsetTest() : spec(testing::TestWithParam::GetParam()), bitset_result(raft::ceildiv(spec.bitset_len, uint64_t(bitset_element_size))), - bitset_ref(raft::ceildiv(spec.bitset_len, uint64_t(bitset_element_size))) + bitset_ref(raft::ceildiv(spec.bitset_len, uint64_t(bitset_element_size))), + bitset_repeat_ref( + raft::ceildiv(spec.bitset_len * spec.repeat_times, uint64_t(bitset_element_size))), + bitset_repeat_result( + raft::ceildiv(spec.bitset_len * spec.repeat_times, uint64_t(bitset_element_size))) { } @@ -145,6 +194,37 @@ class BitsetTest : public testing::TestWithParam { resource::sync_stream(res, stream); ASSERT_TRUE(hostVecMatch(bitset_ref, bitset_result, raft::Compare())); + // test sparsity, repeat and eval_n_elements + if constexpr (std::is_same_v || std::is_same_v) { + auto my_bitset_view = my_bitset.view(); + auto sparsity_result = my_bitset_view.sparsity(res); + auto sparsity_ref = sparsity_cpu_bitset(bitset_ref, size_t(spec.bitset_len)); + ASSERT_EQ(sparsity_result, sparsity_ref); + + auto eval_n_elements = + bitset_view::eval_n_elements(spec.bitset_len * spec.repeat_times); + ASSERT_EQ(bitset_repeat_ref.size(), eval_n_elements); + + auto repeat_device = raft::make_device_vector(res, eval_n_elements); + RAFT_CUDA_TRY(cudaMemsetAsync( + repeat_device.data_handle(), 0, eval_n_elements * sizeof(bitset_t), stream)); + repeat_cpu_bitset( + bitset_ref, size_t(spec.bitset_len), size_t(spec.repeat_times), bitset_repeat_ref); + + my_bitset_view.repeat(res, index_t(spec.repeat_times), repeat_device.data_handle()); + + ASSERT_EQ(bitset_repeat_ref.size(), repeat_device.size()); + update_host( + bitset_repeat_result.data(), repeat_device.data_handle(), repeat_device.size(), stream); + ASSERT_EQ(bitset_repeat_ref.size(), bitset_repeat_result.size()); + ASSERT_TRUE(hostVecMatch(bitset_repeat_ref, bitset_repeat_result, raft::Compare())); + + // recheck the sparsity after repeat + sparsity_result = + sparsity_cpu_bitset(bitset_repeat_result, size_t(spec.bitset_len * spec.repeat_times)); + ASSERT_EQ(sparsity_result, sparsity_ref); + } + // Flip the bitset and re-test auto bitset_count = my_bitset.count(res); my_bitset.flip(res); @@ -167,13 +247,15 @@ class BitsetTest : public testing::TestWithParam { } }; -auto inputs_bitset = ::testing::Values(test_spec_bitset{32, 5, 10}, - test_spec_bitset{100, 30, 10}, - test_spec_bitset{1024, 55, 100}, - test_spec_bitset{10000, 1000, 1000}, - test_spec_bitset{1 << 15, 1 << 3, 1 << 12}, - test_spec_bitset{1 << 15, 1 << 24, 1 << 13}, - test_spec_bitset{1 << 25, 1 << 23, 1 << 14}); +auto inputs_bitset = ::testing::Values(test_spec_bitset{32, 5, 10, 101}, + test_spec_bitset{100, 30, 10, 13}, + test_spec_bitset{1024, 55, 100, 1}, + test_spec_bitset{10000, 1000, 1000, 100}, + test_spec_bitset{1 << 15, 1 << 3, 1 << 12, 5}, + test_spec_bitset{1 << 15, 1 << 24, 1 << 13, 3}, + test_spec_bitset{1 << 25, 1 << 23, 1 << 14, 3}, + test_spec_bitset{1 << 25, 1 << 23, 1 << 14, 201}, + test_spec_bitset{10000000, 1 << 23, 1 << 14, 401}); using Uint16_32 = BitsetTest; TEST_P(Uint16_32, Run) { run(); }