Skip to content

Commit

Permalink
Add n_elements fix size
Browse files Browse the repository at this point in the history
  • Loading branch information
lowener committed Sep 18, 2023
1 parent e0c1d24 commit 8d30c9f
Showing 1 changed file with 24 additions and 12 deletions.
36 changes: 24 additions & 12 deletions cpp/include/raft/util/bitset.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ struct bitset_view {
* @param bitset_span Device vector view of the bitset
*/
_RAFT_HOST_DEVICE bitset_view(raft::device_vector_view<bitset_t, index_t> bitset_span)
: bitset_ptr_{bitset_span.data_handle()}, bitset_len_{bitset_span.extent(0)}
: bitset_ptr_{bitset_span.data_handle()},
bitset_len_{bitset_span.extent(0) * bitset_element_size}
{
}
/**
Expand All @@ -75,18 +76,23 @@ struct bitset_view {
/**
* @brief Get the number of bits of the bitset representation.
*/
inline _RAFT_HOST_DEVICE auto size() const -> index_t
inline _RAFT_HOST_DEVICE auto size() const -> index_t { return bitset_len_; }

/**
* @brief Get the number of elements used by the bitset representation.
*/
inline auto n_elements() const -> index_t
{
return bitset_len_ * bitset_element_size;
return raft::ceildiv(bitset_len_, bitset_element_size);
}

inline auto to_mdspan() -> raft::device_vector_view<bitset_t, index_t>
{
return raft::make_device_vector_view<bitset_t, index_t>(bitset_ptr_, bitset_len_);
return raft::make_device_vector_view<bitset_t, index_t>(bitset_ptr_, n_elements());
}
inline auto to_mdspan() const -> raft::device_vector_view<const bitset_t, index_t>
{
return raft::make_device_vector_view<const bitset_t, index_t>(bitset_ptr_, bitset_len_);
return raft::make_device_vector_view<const bitset_t, index_t>(bitset_ptr_, n_elements());
}

private:
Expand All @@ -100,7 +106,7 @@ struct bitset_view {
* This structure encapsulates a bitset in device memory. It provides a view() method to get a
* device-usable lightweight view of the bitset.
* Each index is represented by a single bit in the bitset. The total number of bytes used is
* ceil(bitset_len / 4).
* ceil(bitset_len / 8).
* @tparam bitset_t Underlying type of the bitset array. Default is uint32_t.
* @tparam index_t Indexing type used. Default is uint32_t.
*/
Expand All @@ -127,7 +133,7 @@ struct bitset {
{
cudaMemsetAsync(bitset_.data(),
default_value ? 0xff : 0x00,
raft::ceildiv(bitset_len, bitset_element_size) * sizeof(bitset_t),
n_elements() * sizeof(bitset_t),
resource::get_cuda_stream(res));
bitset_set(res, view(), mask_index, !default_value);
}
Expand All @@ -147,7 +153,7 @@ struct bitset {
{
cudaMemsetAsync(bitset_.data(),
default_value ? 0xff : 0x00,
raft::ceildiv(bitset_len, bitset_element_size) * sizeof(bitset_t),
n_elements() * sizeof(bitset_t),
resource::get_cuda_stream(res));
}
// Disable copy constructor
Expand Down Expand Up @@ -180,16 +186,22 @@ struct bitset {
*/
inline auto size() const -> index_t { return bitset_len_; }

/**
* @brief Get the number of elements used by the bitset representation.
*/
inline auto n_elements() const -> index_t
{
return raft::ceildiv(bitset_len_, bitset_element_size);
}

/** @brief Get an mdspan view of the current bitset */
inline auto view_mdspan() -> raft::device_vector_view<bitset_t, index_t>
{
return raft::make_device_vector_view<bitset_t, index_t>(
bitset_.data(), raft::ceildiv(bitset_len_, bitset_element_size));
return raft::make_device_vector_view<bitset_t, index_t>(bitset_.data(), n_elements());
}
[[nodiscard]] inline auto view_mdspan() const -> raft::device_vector_view<const bitset_t, index_t>
{
return raft::make_device_vector_view<const bitset_t, index_t>(
bitset_.data(), raft::ceildiv(bitset_len_, bitset_element_size));
return raft::make_device_vector_view<const bitset_t, index_t>(bitset_.data(), n_elements());
}

/** @brief Resize the bitset. If the requested size is larger, new memory is allocated and set to
Expand Down

0 comments on commit 8d30c9f

Please sign in to comment.