diff --git a/cpp/bench/prims/core/bitset.cu b/cpp/bench/prims/core/bitset.cu index 5f44aa9af5..ce3136bcd5 100644 --- a/cpp/bench/prims/core/bitset.cu +++ b/cpp/bench/prims/core/bitset.cu @@ -44,7 +44,7 @@ struct bitset_bench : public fixture { loop_on_state(state, [this]() { auto my_bitset = raft::core::bitset( this->res, raft::make_const_mdspan(mask.view()), params.bitset_len); - my_bitset.test(res, raft::make_const_mdspan(queries.view()), outputs.view()); + my_bitset.test(this->res, raft::make_const_mdspan(queries.view()), outputs.view()); }); } diff --git a/cpp/include/raft/core/bitset.cuh b/cpp/include/raft/core/bitset.cuh index 6747c5fab0..552c2e9ac5 100644 --- a/cpp/include/raft/core/bitset.cuh +++ b/cpp/include/raft/core/bitset.cuh @@ -16,10 +16,13 @@ #pragma once +#include // native_popc #include +#include #include #include #include +#include #include #include @@ -39,7 +42,7 @@ namespace raft::core { */ template struct bitset_view { - index_t static constexpr const bitset_element_size = sizeof(bitset_t) * 8; + static constexpr index_t bitset_element_size = sizeof(bitset_t) * 8; _RAFT_HOST_DEVICE bitset_view(bitset_t* bitset_ptr, index_t bitset_len) : bitset_ptr_{bitset_ptr}, bitset_len_{bitset_len} @@ -69,12 +72,40 @@ struct bitset_view { const bool is_bit_set = (bit_element & (bitset_t{1} << bit_index)) != 0; return is_bit_set; } + /** + * @brief Device function to test if a given index is set in the bitset. + * + * @param sample_index Single index to test + * @return bool True if index has not been unset in the bitset + */ + inline _RAFT_DEVICE auto operator[](const index_t sample_index) const -> bool + { + return test(sample_index); + } + /** + * @brief Device function to set a given index to set_value in the bitset. + * + * @param sample_index index to set + * @param set_value Value to set the bit to (true or false) + */ + inline _RAFT_DEVICE void set(const index_t sample_index, bool set_value) const + { + const index_t bit_element = sample_index / bitset_element_size; + const index_t bit_index = sample_index % bitset_element_size; + const bitset_t bitmask = bitset_t{1} << bit_index; + if (set_value) { + atomicOr(bitset_ptr_ + bit_element, bitmask); + } else { + const bitset_t bitmask2 = ~bitmask; + atomicAnd(bitset_ptr_ + bit_element, bitmask2); + } + } /** * @brief Get the device pointer to the bitset. */ - inline _RAFT_HOST_DEVICE auto data_handle() -> bitset_t* { return bitset_ptr_; } - inline _RAFT_HOST_DEVICE auto data_handle() const -> const bitset_t* { return bitset_ptr_; } + inline _RAFT_HOST_DEVICE auto data() -> bitset_t* { return bitset_ptr_; } + inline _RAFT_HOST_DEVICE auto data() const -> const bitset_t* { return bitset_ptr_; } /** * @brief Get the number of bits of the bitset representation. */ @@ -114,7 +145,7 @@ struct bitset_view { */ template struct bitset { - index_t static constexpr const bitset_element_size = sizeof(bitset_t) * 8; + static constexpr index_t bitset_element_size = sizeof(bitset_t) * 8; /** * @brief Construct a new bitset object with a list of indices to unset. @@ -130,13 +161,9 @@ struct bitset { bool default_value = true) : bitset_{std::size_t(raft::ceildiv(bitset_len, bitset_element_size)), raft::resource::get_cuda_stream(res)}, - bitset_len_{bitset_len}, - default_value_{default_value} + bitset_len_{bitset_len} { - cudaMemsetAsync(bitset_.data(), - default_value ? 0xff : 0x00, - n_elements() * sizeof(bitset_t), - resource::get_cuda_stream(res)); + reset(res, default_value); set(res, mask_index, !default_value); } @@ -150,13 +177,9 @@ struct bitset { bitset(const raft::resources& res, index_t bitset_len, bool default_value = true) : bitset_{std::size_t(raft::ceildiv(bitset_len, bitset_element_size)), resource::get_cuda_stream(res)}, - bitset_len_{bitset_len}, - default_value_{default_value} + bitset_len_{bitset_len} { - cudaMemsetAsync(bitset_.data(), - default_value ? 0xff : 0x00, - n_elements() * sizeof(bitset_t), - resource::get_cuda_stream(res)); + reset(res, default_value); } // Disable copy constructor bitset(const bitset&) = delete; @@ -181,8 +204,8 @@ struct bitset { /** * @brief Get the device pointer to the bitset. */ - inline auto data_handle() -> bitset_t* { return bitset_.data(); } - inline auto data_handle() const -> const bitset_t* { return bitset_.data(); } + inline auto data() -> bitset_t* { return bitset_.data(); } + inline auto data() const -> const bitset_t* { return bitset_.data(); } /** * @brief Get the number of bits of the bitset representation. */ @@ -207,8 +230,12 @@ struct bitset { } /** @brief Resize the bitset. If the requested size is larger, new memory is allocated and set to - * the default value. */ - void resize(const raft::resources& res, index_t new_bitset_len) + * the default value. + * @param res RAFT resources + * @param new_bitset_len new size of the bitset + * @param default_value default value to initialize the new bits to + */ + void resize(const raft::resources& res, index_t new_bitset_len, bool default_value = true) { auto old_size = raft::ceildiv(bitset_len_, bitset_element_size); auto new_size = raft::ceildiv(new_bitset_len, bitset_element_size); @@ -216,10 +243,11 @@ struct bitset { bitset_len_ = new_bitset_len; if (old_size < new_size) { // If the new size is larger, set the new bits to the default value - cudaMemsetAsync(bitset_.data() + old_size, - default_value_ ? 0xff : 0x00, - (new_size - old_size) * sizeof(bitset_t), - resource::get_cuda_stream(res)); + + thrust::fill_n(resource::get_thrust_policy(res), + bitset_.data() + old_size, + new_size - old_size, + default_value ? ~bitset_t{0} : bitset_t{0}); } } @@ -255,25 +283,16 @@ struct bitset { raft::device_vector_view mask_index, bool set_value = false) { - auto* bitset_ptr = this->data_handle(); + auto this_bitset_view = view(); thrust::for_each_n(resource::get_thrust_policy(res), mask_index.data_handle(), mask_index.extent(0), - [bitset_ptr, set_value] __device__(const index_t sample_index) { - const index_t bit_element = sample_index / bitset_element_size; - const index_t bit_index = sample_index % bitset_element_size; - const bitset_t bitmask = bitset_t{1} << bit_index; - if (set_value) { - atomicOr(bitset_ptr + bit_element, bitmask); - } else { - const bitset_t bitmask2 = ~bitmask; - atomicAnd(bitset_ptr + bit_element, bitmask2); - } + [this_bitset_view, set_value] __device__(const index_t sample_index) { + this_bitset_view.set(sample_index, set_value); }); } /** * @brief Flip all the bits in a bitset. - * * @param res RAFT resources */ void flip(const raft::resources& res) @@ -289,19 +308,90 @@ struct bitset { * @brief Reset the bits in a bitset. * * @param res RAFT resources + * @param default_value Value to set the bits to (true or false) */ - void reset(const raft::resources& res) + void reset(const raft::resources& res, bool default_value = true) { - cudaMemsetAsync(bitset_.data(), - default_value_ ? 0xff : 0x00, - n_elements() * sizeof(bitset_t), - resource::get_cuda_stream(res)); + thrust::fill_n(resource::get_thrust_policy(res), + bitset_.data(), + n_elements(), + default_value ? ~bitset_t{0} : bitset_t{0}); } + /** + * @brief Returns the number of bits set to true in count_gpu_scalar. + * + * @param[in] res RAFT resources + * @param[out] count_gpu_scalar Device scalar to store the count + */ + void count(const raft::resources& res, raft::device_scalar_view count_gpu_scalar) + { + auto n_elements_ = n_elements(); + auto count_gpu = + raft::make_device_vector_view(count_gpu_scalar.data_handle(), 1); + auto bitset_matrix_view = raft::make_device_matrix_view( + bitset_.data(), n_elements_, 1); + + bitset_t n_last_element = (bitset_len_ % bitset_element_size); + bitset_t last_element_mask = + n_last_element ? (bitset_t)((bitset_t{1} << n_last_element) - bitset_t{1}) : ~bitset_t{0}; + raft::linalg::coalesced_reduction( + res, + bitset_matrix_view, + count_gpu, + index_t{0}, + false, + [last_element_mask, n_elements_] __device__(bitset_t element, index_t index) { + index_t result = 0; + if constexpr (bitset_element_size == 64) { + if (index == n_elements_ - 1) + result = index_t(raft::detail::popc(element & last_element_mask)); + else + result = index_t(raft::detail::popc(element)); + } else { // Needed because popc is not overloaded for 16 and 8 bit elements + if (index == n_elements_ - 1) + result = index_t(raft::detail::popc(uint32_t{element} & last_element_mask)); + else + result = index_t(raft::detail::popc(uint32_t{element})); + } + + return result; + }); + } + /** + * @brief Returns the number of bits set to true. + * + * @param res RAFT resources + * @return index_t Number of bits set to true + */ + auto count(const raft::resources& res) -> index_t + { + auto count_gpu_scalar = raft::make_device_scalar(res, 0.0); + count(res, count_gpu_scalar.view()); + index_t count_cpu = 0; + raft::update_host( + &count_cpu, count_gpu_scalar.data_handle(), 1, resource::get_cuda_stream(res)); + resource::sync_stream(res); + return count_cpu; + } + /** + * @brief Checks if any of the bits are set to true in the bitset. + * @param res RAFT resources + */ + bool any(const raft::resources& res) { return count(res) > 0; } + /** + * @brief Checks if all of the bits are set to true in the bitset. + * @param res RAFT resources + */ + bool all(const raft::resources& res) { return count(res) == bitset_len_; } + /** + * @brief Checks if none of the bits are set to true in the bitset. + * @param res RAFT resources + */ + bool none(const raft::resources& res) { return count(res) == 0; } private: raft::device_uvector bitset_; index_t bitset_len_; - bool default_value_; }; /** @} */ diff --git a/cpp/test/core/bitset.cu b/cpp/test/core/bitset.cu index 215de98aaf..b799297e8c 100644 --- a/cpp/test/core/bitset.cu +++ b/cpp/test/core/bitset.cu @@ -18,6 +18,7 @@ #include #include +#include #include #include @@ -43,7 +44,7 @@ auto operator<<(std::ostream& os, const test_spec_bitset& ss) -> std::ostream& template void add_cpu_bitset(std::vector& bitset, const std::vector& mask_idx) { - static size_t constexpr const bitset_element_size = sizeof(bitset_t) * 8; + constexpr size_t bitset_element_size = sizeof(bitset_t) * 8; for (size_t i = 0; i < mask_idx.size(); i++) { auto idx = mask_idx[i]; bitset[idx / bitset_element_size] &= ~(bitset_t{1} << (idx % bitset_element_size)); @@ -64,7 +65,7 @@ void test_cpu_bitset(const std::vector& bitset, const std::vector& queries, std::vector& result) { - static size_t constexpr const bitset_element_size = sizeof(bitset_t) * 8; + constexpr size_t bitset_element_size = sizeof(bitset_t) * 8; for (size_t i = 0; i < queries.size(); i++) { result[i] = uint8_t((bitset[queries[i] / bitset_element_size] & (bitset_t{1} << (queries[i] % bitset_element_size))) != 0); @@ -111,7 +112,7 @@ class BitsetTest : public testing::TestWithParam { // calculate the results auto my_bitset = raft::core::bitset( res, raft::make_const_mdspan(mask_device.view()), index_t(spec.bitset_len)); - update_host(bitset_result.data(), my_bitset.data_handle(), bitset_result.size(), stream); + update_host(bitset_result.data(), my_bitset.data(), bitset_result.size(), stream); // calculate the reference create_cpu_bitset(bitset_ref, mask_cpu); @@ -138,18 +139,31 @@ class BitsetTest : public testing::TestWithParam { update_host(mask_cpu.data(), mask_device.data_handle(), mask_device.extent(0), stream); resource::sync_stream(res, stream); my_bitset.set(res, mask_device.view()); - update_host(bitset_result.data(), my_bitset.data_handle(), bitset_result.size(), stream); + update_host(bitset_result.data(), my_bitset.data(), bitset_result.size(), stream); add_cpu_bitset(bitset_ref, mask_cpu); resource::sync_stream(res, stream); ASSERT_TRUE(hostVecMatch(bitset_ref, bitset_result, raft::Compare())); // Flip the bitset and re-test + auto bitset_count = my_bitset.count(res); my_bitset.flip(res); - update_host(bitset_result.data(), my_bitset.data_handle(), bitset_result.size(), stream); + ASSERT_EQ(my_bitset.count(res), spec.bitset_len - bitset_count); + update_host(bitset_result.data(), my_bitset.data(), bitset_result.size(), stream); flip_cpu_bitset(bitset_ref); resource::sync_stream(res, stream); ASSERT_TRUE(hostVecMatch(bitset_ref, bitset_result, raft::Compare())); + + // Test count() operations + my_bitset.reset(res, false); + ASSERT_EQ(my_bitset.any(res), false); + ASSERT_EQ(my_bitset.none(res), true); + raft::linalg::range(query_device.data_handle(), query_device.size(), stream); + my_bitset.set(res, raft::make_const_mdspan(query_device.view()), true); + bitset_count = my_bitset.count(res); + ASSERT_EQ(bitset_count, query_device.size()); + ASSERT_EQ(my_bitset.any(res), true); + ASSERT_EQ(my_bitset.none(res), false); } };