From 8d30c9fc3a7df580c11f2a06277ecf36c6232588 Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Mon, 18 Sep 2023 17:56:19 +0200 Subject: [PATCH] Add n_elements fix size --- cpp/include/raft/util/bitset.cuh | 36 +++++++++++++++++++++----------- 1 file changed, 24 insertions(+), 12 deletions(-) diff --git a/cpp/include/raft/util/bitset.cuh b/cpp/include/raft/util/bitset.cuh index 25e46c4514..cd2106dba2 100644 --- a/cpp/include/raft/util/bitset.cuh +++ b/cpp/include/raft/util/bitset.cuh @@ -51,7 +51,8 @@ struct bitset_view { * @param bitset_span Device vector view of the bitset */ _RAFT_HOST_DEVICE bitset_view(raft::device_vector_view bitset_span) - : bitset_ptr_{bitset_span.data_handle()}, bitset_len_{bitset_span.extent(0)} + : bitset_ptr_{bitset_span.data_handle()}, + bitset_len_{bitset_span.extent(0) * bitset_element_size} { } /** @@ -75,18 +76,23 @@ struct bitset_view { /** * @brief Get the number of bits of the bitset representation. */ - inline _RAFT_HOST_DEVICE auto size() const -> index_t + inline _RAFT_HOST_DEVICE auto size() const -> index_t { return bitset_len_; } + + /** + * @brief Get the number of elements used by the bitset representation. + */ + inline auto n_elements() const -> index_t { - return bitset_len_ * bitset_element_size; + return raft::ceildiv(bitset_len_, bitset_element_size); } inline auto to_mdspan() -> raft::device_vector_view { - return raft::make_device_vector_view(bitset_ptr_, bitset_len_); + return raft::make_device_vector_view(bitset_ptr_, n_elements()); } inline auto to_mdspan() const -> raft::device_vector_view { - return raft::make_device_vector_view(bitset_ptr_, bitset_len_); + return raft::make_device_vector_view(bitset_ptr_, n_elements()); } private: @@ -100,7 +106,7 @@ struct bitset_view { * This structure encapsulates a bitset in device memory. It provides a view() method to get a * device-usable lightweight view of the bitset. * Each index is represented by a single bit in the bitset. The total number of bytes used is - * ceil(bitset_len / 4). + * ceil(bitset_len / 8). * @tparam bitset_t Underlying type of the bitset array. Default is uint32_t. * @tparam index_t Indexing type used. Default is uint32_t. */ @@ -127,7 +133,7 @@ struct bitset { { cudaMemsetAsync(bitset_.data(), default_value ? 0xff : 0x00, - raft::ceildiv(bitset_len, bitset_element_size) * sizeof(bitset_t), + n_elements() * sizeof(bitset_t), resource::get_cuda_stream(res)); bitset_set(res, view(), mask_index, !default_value); } @@ -147,7 +153,7 @@ struct bitset { { cudaMemsetAsync(bitset_.data(), default_value ? 0xff : 0x00, - raft::ceildiv(bitset_len, bitset_element_size) * sizeof(bitset_t), + n_elements() * sizeof(bitset_t), resource::get_cuda_stream(res)); } // Disable copy constructor @@ -180,16 +186,22 @@ struct bitset { */ inline auto size() const -> index_t { return bitset_len_; } + /** + * @brief Get the number of elements used by the bitset representation. + */ + inline auto n_elements() const -> index_t + { + return raft::ceildiv(bitset_len_, bitset_element_size); + } + /** @brief Get an mdspan view of the current bitset */ inline auto view_mdspan() -> raft::device_vector_view { - return raft::make_device_vector_view( - bitset_.data(), raft::ceildiv(bitset_len_, bitset_element_size)); + return raft::make_device_vector_view(bitset_.data(), n_elements()); } [[nodiscard]] inline auto view_mdspan() const -> raft::device_vector_view { - return raft::make_device_vector_view( - bitset_.data(), raft::ceildiv(bitset_len_, bitset_element_size)); + return raft::make_device_vector_view(bitset_.data(), n_elements()); } /** @brief Resize the bitset. If the requested size is larger, new memory is allocated and set to