Skip to content

Commit

Permalink
Improvements on bitset class
Browse files Browse the repository at this point in the history
  • Loading branch information
lowener committed Oct 6, 2023
1 parent d9fde97 commit 8e7bb87
Show file tree
Hide file tree
Showing 2 changed files with 134 additions and 24 deletions.
139 changes: 117 additions & 22 deletions cpp/include/raft/core/bitset.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,13 @@

#pragma once

#include <raft/core/detail/mdspan_util.cuh> // native_popc
#include <raft/core/device_container_policy.hpp>
#include <raft/core/device_mdarray.hpp>
#include <raft/core/resource/thrust_policy.hpp>
#include <raft/core/resources.hpp>
#include <raft/linalg/map.cuh>
#include <raft/linalg/reduce.cuh>
#include <raft/util/device_atomics.cuh>
#include <thrust/for_each.h>

Expand All @@ -39,7 +42,7 @@ namespace raft::core {
*/
template <typename bitset_t = uint32_t, typename index_t = uint32_t>
struct bitset_view {
index_t static constexpr const bitset_element_size = sizeof(bitset_t) * 8;
static constexpr index_t bitset_element_size = sizeof(bitset_t) * 8;

_RAFT_HOST_DEVICE bitset_view(bitset_t* bitset_ptr, index_t bitset_len)
: bitset_ptr_{bitset_ptr}, bitset_len_{bitset_len}
Expand Down Expand Up @@ -69,6 +72,34 @@ struct bitset_view {
const bool is_bit_set = (bit_element & (bitset_t{1} << bit_index)) != 0;
return is_bit_set;
}
/**
 * @brief Subscript-style, read-only access to a single bit of the bitset (device only).
 *
 * Forwards to test() and returns its result unchanged.
 *
 * @param sample_index Index of the bit to query
 * @return bool The value reported by test(sample_index)
 */
inline _RAFT_DEVICE auto operator[](const index_t sample_index) const -> bool
{
  const bool bit_state = test(sample_index);
  return bit_state;
}
/**
 * @brief Device function that writes set_value into the bit located at sample_index.
 *
 * The bit is updated with an atomic OR (to raise it) or an atomic AND with the inverted
 * mask (to clear it) on the bitset element that contains it.
 *
 * @param sample_index Index of the bit to modify
 * @param set_value New value of the bit (true or false)
 */
inline _RAFT_DEVICE void set(const index_t sample_index, bool set_value) const
{
  // Element holding the requested bit, and the single-bit mask inside that element.
  auto* const target_element = bitset_ptr_ + (sample_index / bitset_element_size);
  const bitset_t single_bit  = bitset_t{1} << (sample_index % bitset_element_size);

  if (!set_value) {
    // Clearing: keep every bit except the selected one.
    const bitset_t keep_others = ~single_bit;
    atomicAnd(target_element, keep_others);
    return;
  }
  atomicOr(target_element, single_bit);
}

/**
* @brief Get the device pointer to the bitset.
Expand Down Expand Up @@ -114,7 +145,7 @@ struct bitset_view {
*/
template <typename bitset_t = uint32_t, typename index_t = uint32_t>
struct bitset {
index_t static constexpr const bitset_element_size = sizeof(bitset_t) * 8;
static constexpr index_t bitset_element_size = sizeof(bitset_t) * 8;

/**
* @brief Construct a new bitset object with a list of indices to unset.
Expand All @@ -130,8 +161,7 @@ struct bitset {
bool default_value = true)
: bitset_{std::size_t(raft::ceildiv(bitset_len, bitset_element_size)),
raft::resource::get_cuda_stream(res)},
bitset_len_{bitset_len},
default_value_{default_value}
bitset_len_{bitset_len}
{
cudaMemsetAsync(bitset_.data(),
default_value ? 0xff : 0x00,
Expand All @@ -150,8 +180,7 @@ struct bitset {
bitset(const raft::resources& res, index_t bitset_len, bool default_value = true)
: bitset_{std::size_t(raft::ceildiv(bitset_len, bitset_element_size)),
resource::get_cuda_stream(res)},
bitset_len_{bitset_len},
default_value_{default_value}
bitset_len_{bitset_len}
{
cudaMemsetAsync(bitset_.data(),
default_value ? 0xff : 0x00,
Expand Down Expand Up @@ -208,7 +237,7 @@ struct bitset {

/** @brief Resize the bitset. If the requested size is larger, new memory is allocated and set to
* the default value. */
void resize(const raft::resources& res, index_t new_bitset_len)
void resize(const raft::resources& res, index_t new_bitset_len, bool default_value = true)
{
auto old_size = raft::ceildiv(bitset_len_, bitset_element_size);
auto new_size = raft::ceildiv(new_bitset_len, bitset_element_size);
Expand All @@ -217,7 +246,7 @@ struct bitset {
if (old_size < new_size) {
// If the new size is larger, set the new bits to the default value
cudaMemsetAsync(bitset_.data() + old_size,
default_value_ ? 0xff : 0x00,
default_value ? 0xff : 0x00,
(new_size - old_size) * sizeof(bitset_t),
resource::get_cuda_stream(res));
}
Expand Down Expand Up @@ -255,20 +284,12 @@ struct bitset {
raft::device_vector_view<const index_t, index_t> mask_index,
bool set_value = false)
{
auto* bitset_ptr = this->data_handle();
auto this_bitset_view = view();
thrust::for_each_n(resource::get_thrust_policy(res),
mask_index.data_handle(),
mask_index.extent(0),
[bitset_ptr, set_value] __device__(const index_t sample_index) {
const index_t bit_element = sample_index / bitset_element_size;
const index_t bit_index = sample_index % bitset_element_size;
const bitset_t bitmask = bitset_t{1} << bit_index;
if (set_value) {
atomicOr(bitset_ptr + bit_element, bitmask);
} else {
const bitset_t bitmask2 = ~bitmask;
atomicAnd(bitset_ptr + bit_element, bitmask2);
}
[this_bitset_view, set_value] __device__(const index_t sample_index) {
this_bitset_view.set(sample_index, set_value);
});
}
/**
Expand All @@ -289,19 +310,93 @@ struct bitset {
* @brief Reset the bits in a bitset.
*
* @param res RAFT resources
* @param default_value Value to set the bits to (true or false)
*/
void reset(const raft::resources& res)
void reset(const raft::resources& res, bool default_value = true)
{
cudaMemsetAsync(bitset_.data(),
default_value_ ? 0xff : 0x00,
default_value ? 0xff : 0x00,
n_elements() * sizeof(bitset_t),
resource::get_cuda_stream(res));
}
/**
 * @brief Count the bits set to true and write the total into a device scalar.
 *
 * Every element of the bitset contributes its population count to a reduction. The final
 * element is masked beforehand so that bits located past bitset_len_ are not counted.
 *
 * @param[in] res RAFT resources
 * @param[out] count_gpu_scalar Device scalar receiving the count
 */
void count(const raft::resources& res, raft::device_scalar_view<index_t> count_gpu_scalar)
{
  const auto n_words   = n_elements();
  const auto last_word = n_words - 1;

  // Present the bitset storage as an (n_words x 1) matrix and the output scalar as a
  // length-1 vector, the shapes handed to the reduction below.
  auto words_as_matrix = raft::make_device_matrix_view<const bitset_t, index_t, col_major>(
    bitset_.data(), n_words, 1);
  auto total_out =
    raft::make_device_vector_view<index_t, index_t>(count_gpu_scalar.data_handle(), 1);

  // Number of meaningful bits in the final element; zero means that element is fully used.
  const bitset_t tail_bits = (bitset_len_ % bitset_element_size);
  bitset_t tail_mask       = ~bitset_t{0};
  if (tail_bits) { tail_mask = static_cast<bitset_t>((bitset_t{1} << tail_bits) - bitset_t{1}); }

  raft::linalg::coalesced_reduction(
    res,
    words_as_matrix,
    total_out,
    index_t{0},
    false,
    [tail_mask, last_word] __device__(bitset_t element, index_t index) {
      // Only the final element can carry bits beyond the logical length.
      const bitset_t counted =
        (index == last_word) ? static_cast<bitset_t>(element & tail_mask) : element;

      index_t n_set_bits = 0;
      if constexpr (bitset_element_size == 64) {  // Needed because __popc doesn't support 64bit
        n_set_bits = index_t(raft::detail::native_popc<uint64_t>(counted));
      } else {
        n_set_bits = index_t(__popc(counted));
      }
      return n_set_bits;
    });
}
/**
 * @brief Returns the number of bits set to true.
 *
 * Allocates a temporary device scalar, runs the device-side count into it, copies the
 * result back to the host and synchronizes the stream of `res` before returning.
 *
 * @param[in] res RAFT resources
 * @return index_t Number of bits set to true
 */
auto count(const raft::resources& res) -> index_t
{
  // Initialise with an index_t zero: the previous `0.0` literal was a double that had to be
  // implicitly converted to the integral index_t.
  auto count_gpu_scalar = raft::make_device_scalar<index_t>(res, index_t{0});
  count(res, count_gpu_scalar.view());
  index_t count_cpu = 0;
  raft::update_host(
    &count_cpu, count_gpu_scalar.data_handle(), 1, resource::get_cuda_stream(res));
  // The copy above is enqueued on the stream; wait for it before reading count_cpu.
  resource::sync_stream(res);
  return count_cpu;
}
/**
 * @brief Checks whether at least one bit of the bitset is set to true.
 *
 * @param res RAFT resources
 * @return bool True when count(res) is greater than zero
 */
bool any(const raft::resources& res)
{
  const index_t n_set = count(res);
  return n_set > 0;
}
/**
 * @brief Checks whether every bit of the bitset is set to true.
 *
 * @param res RAFT resources
 * @return bool True when count(res) equals the bitset length
 */
bool all(const raft::resources& res)
{
  const index_t n_set = count(res);
  return n_set == bitset_len_;
}
/**
 * @brief Checks whether no bit of the bitset is set to true.
 *
 * @param res RAFT resources
 * @return bool True when count(res) is zero
 */
bool none(const raft::resources& res)
{
  const index_t n_set = count(res);
  return n_set == 0;
}

private:
raft::device_uvector<bitset_t> bitset_;
index_t bitset_len_;
bool default_value_;
};

/** @} */
Expand Down
19 changes: 17 additions & 2 deletions cpp/test/core/bitset.cu
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

#include <raft/core/bitset.cuh>
#include <raft/core/device_mdarray.hpp>
#include <raft/linalg/init.cuh>
#include <raft/random/rng.cuh>

#include <gtest/gtest.h>
Expand All @@ -43,7 +44,7 @@ auto operator<<(std::ostream& os, const test_spec_bitset& ss) -> std::ostream&
template <typename bitset_t, typename index_t>
void add_cpu_bitset(std::vector<bitset_t>& bitset, const std::vector<index_t>& mask_idx)
{
static size_t constexpr const bitset_element_size = sizeof(bitset_t) * 8;
constexpr size_t bitset_element_size = sizeof(bitset_t) * 8;
for (size_t i = 0; i < mask_idx.size(); i++) {
auto idx = mask_idx[i];
bitset[idx / bitset_element_size] &= ~(bitset_t{1} << (idx % bitset_element_size));
Expand All @@ -64,7 +65,7 @@ void test_cpu_bitset(const std::vector<bitset_t>& bitset,
const std::vector<index_t>& queries,
std::vector<uint8_t>& result)
{
static size_t constexpr const bitset_element_size = sizeof(bitset_t) * 8;
constexpr size_t bitset_element_size = sizeof(bitset_t) * 8;
for (size_t i = 0; i < queries.size(); i++) {
result[i] = uint8_t((bitset[queries[i] / bitset_element_size] &
(bitset_t{1} << (queries[i] % bitset_element_size))) != 0);
Expand Down Expand Up @@ -145,11 +146,25 @@ class BitsetTest : public testing::TestWithParam<test_spec_bitset> {
ASSERT_TRUE(hostVecMatch(bitset_ref, bitset_result, raft::Compare<bitset_t>()));

// Flip the bitset and re-test
auto bitset_count = my_bitset.count(res);
my_bitset.flip(res);
ASSERT_EQ(my_bitset.count(res), spec.bitset_len - bitset_count);
update_host(bitset_result.data(), my_bitset.data_handle(), bitset_result.size(), stream);
flip_cpu_bitset(bitset_ref);
resource::sync_stream(res, stream);
ASSERT_TRUE(hostVecMatch(bitset_ref, bitset_result, raft::Compare<bitset_t>()));

my_bitset.reset(res, false);
ASSERT_EQ(my_bitset.any(res), false);
ASSERT_EQ(my_bitset.none(res), true);
raft::linalg::range(query_device.data_handle(), query_device.size(), stream);
my_bitset.set(res, raft::make_const_mdspan(query_device.view()), true);
bitset_count = my_bitset.count(res);
ASSERT_EQ(bitset_count, query_device.size());
ASSERT_EQ(my_bitset.any(res), true);
ASSERT_EQ(my_bitset.none(res), false);

ASSERT_TRUE(hostVecMatch(bitset_ref, bitset_result, raft::Compare<bitset_t>()));
}
};

Expand Down

0 comments on commit 8e7bb87

Please sign in to comment.