Skip to content

Commit

Permalink
Move bitset to core and match std::bitset
Browse files Browse the repository at this point in the history
  • Loading branch information
lowener committed Sep 20, 2023
1 parent 08c43d0 commit 9115bbb
Show file tree
Hide file tree
Showing 9 changed files with 121 additions and 124 deletions.
2 changes: 1 addition & 1 deletion build.sh
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ BUILD_REPORT_METRICS=""
BUILD_REPORT_INCL_CACHE_STATS=OFF

TEST_TARGETS="CLUSTER_TEST;CORE_TEST;DISTANCE_TEST;LABEL_TEST;LINALG_TEST;MATRIX_TEST;NEIGHBORS_TEST;NEIGHBORS_ANN_CAGRA_TEST;RANDOM_TEST;SOLVERS_TEST;SPARSE_TEST;SPARSE_DIST_TEST;SPARSE_NEIGHBORS_TEST;STATS_TEST;UTILS_TEST"
BENCH_TARGETS="CLUSTER_BENCH;NEIGHBORS_BENCH;DISTANCE_BENCH;LINALG_BENCH;MATRIX_BENCH;SPARSE_BENCH;RANDOM_BENCH;UTIL_BENCH"
BENCH_TARGETS="CLUSTER_BENCH;CORE_BENCH;NEIGHBORS_BENCH;DISTANCE_BENCH;LINALG_BENCH;MATRIX_BENCH;SPARSE_BENCH;RANDOM_BENCH"

CACHE_ARGS=""
NVTX=ON
Expand Down
2 changes: 1 addition & 1 deletion cpp/bench/prims/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ if(BUILD_PRIMS_BENCH)
NAME CLUSTER_BENCH PATH bench/prims/cluster/kmeans_balanced.cu bench/prims/cluster/kmeans.cu
bench/prims/main.cpp OPTIONAL LIB EXPLICIT_INSTANTIATE_ONLY
)
ConfigureBench(NAME CORE_BENCH PATH bench/prims/core/bitset.cu bench/prims/main.cpp)

ConfigureBench(
NAME TUNE_DISTANCE PATH bench/prims/distance/tune_pairwise/kernel.cu
Expand Down Expand Up @@ -156,5 +157,4 @@ if(BUILD_PRIMS_BENCH)
EXPLICIT_INSTANTIATE_ONLY
)

ConfigureBench(NAME UTIL_BENCH PATH bench/prims/util/bitset.cu bench/prims/main.cpp)
endif()
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@
*/

#include <common/benchmark.hpp>
#include <raft/core/bitset.cuh>
#include <raft/core/device_mdspan.hpp>
#include <raft/util/bitset.cuh>
#include <rmm/device_uvector.hpp>

namespace raft::bench::util {
namespace raft::bench::core {

struct bitset_inputs {
uint32_t bitset_len;
Expand All @@ -42,10 +42,9 @@ struct bitset_bench : public fixture {
void run_benchmark(::benchmark::State& state) override
{
loop_on_state(state, [this]() {
auto my_bitset = raft::util::bitset<bitset_t, index_t>(
auto my_bitset = raft::core::bitset<bitset_t, index_t>(
this->res, raft::make_const_mdspan(mask.view()), params.bitset_len);
raft::util::bitset_test(
res, my_bitset.view(), raft::make_const_mdspan(queries.view()), outputs.view());
my_bitset.test(res, raft::make_const_mdspan(queries.view()), outputs.view());
});
}

Expand All @@ -72,4 +71,4 @@ RAFT_BENCH_REGISTER(Uint16_64, "", bitset_input_vecs);
RAFT_BENCH_REGISTER(Uint32_32, "", bitset_input_vecs);
RAFT_BENCH_REGISTER(Uint32_64, "", bitset_input_vecs);

} // namespace raft::bench::util
} // namespace raft::bench::core
176 changes: 86 additions & 90 deletions cpp/include/raft/util/bitset.cuh → cpp/include/raft/core/bitset.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
#include <raft/util/device_atomics.cuh>
#include <thrust/for_each.h>

namespace raft::util {
namespace raft::core {
/**
* @defgroup bitset Bitset
* @{
Expand Down Expand Up @@ -69,6 +69,7 @@ struct bitset_view {
const bool is_bit_set = (bit_element & (bitset_t{1} << bit_index)) != 0;
return is_bit_set;
}

/**
* @brief Get the device pointer to the bitset.
*/
Expand All @@ -82,7 +83,7 @@ struct bitset_view {
/**
* @brief Get the number of elements used by the bitset representation.
*/
inline auto n_elements() const -> index_t
inline _RAFT_HOST_DEVICE auto n_elements() const -> index_t
{
return raft::ceildiv(bitset_len_, bitset_element_size);
}
Expand Down Expand Up @@ -136,7 +137,7 @@ struct bitset {
default_value ? 0xff : 0x00,
n_elements() * sizeof(bitset_t),
resource::get_cuda_stream(res));
bitset_set(res, view(), mask_index, !default_value);
set(res, mask_index, !default_value);
}

/**
Expand Down Expand Up @@ -168,13 +169,13 @@ struct bitset {
*
* @return bitset_view<bitset_t, index_t>
*/
inline auto view() -> raft::util::bitset_view<bitset_t, index_t>
inline auto view() -> raft::core::bitset_view<bitset_t, index_t>
{
return bitset_view<bitset_t, index_t>(view_mdspan(), bitset_len_);
return bitset_view<bitset_t, index_t>(to_mdspan(), bitset_len_);
}
[[nodiscard]] inline auto view() const -> raft::util::bitset_view<const bitset_t, index_t>
[[nodiscard]] inline auto view() const -> raft::core::bitset_view<const bitset_t, index_t>
{
return bitset_view<const bitset_t, index_t>(view_mdspan(), bitset_len_);
return bitset_view<const bitset_t, index_t>(to_mdspan(), bitset_len_);
}

/**
Expand All @@ -196,11 +197,11 @@ struct bitset {
}

/** @brief Get an mdspan view of the current bitset */
inline auto view_mdspan() -> raft::device_vector_view<bitset_t, index_t>
inline auto to_mdspan() -> raft::device_vector_view<bitset_t, index_t>
{
return raft::make_device_vector_view<bitset_t, index_t>(bitset_.data(), n_elements());
}
[[nodiscard]] inline auto view_mdspan() const -> raft::device_vector_view<const bitset_t, index_t>
[[nodiscard]] inline auto to_mdspan() const -> raft::device_vector_view<const bitset_t, index_t>
{
return raft::make_device_vector_view<const bitset_t, index_t>(bitset_.data(), n_elements());
}
Expand All @@ -222,91 +223,86 @@ struct bitset {
}
}

/**
* @brief Test a list of indices in a bitset.
*
* @tparam output_t Output type of the test. Default is bool.
* @param res RAFT resources
* @param queries List of indices to test
* @param output List of outputs
*/
template <typename output_t = bool>
void test(const raft::resources& res,
raft::device_vector_view<const index_t, index_t> queries,
raft::device_vector_view<output_t, index_t> output) const
{
RAFT_EXPECTS(output.extent(0) == queries.extent(0), "Output and queries must be same size");
auto bitset_view = view();
raft::linalg::map(
res,
output,
[bitset_view] __device__(index_t query) { return output_t(bitset_view.test(query)); },
queries);
}
/**
* @brief Set a list of indices in a bitset to set_value.
*
* @param res RAFT resources
* @param mask_index indices to remove from the bitset
* @param set_value Value to set the bits to (true or false)
*/
void set(const raft::resources& res,
raft::device_vector_view<const index_t, index_t> mask_index,
bool set_value = false)
{
auto* bitset_ptr = this->data_handle();
thrust::for_each_n(resource::get_thrust_policy(res),
mask_index.data_handle(),
mask_index.extent(0),
[bitset_ptr, set_value] __device__(const index_t sample_index) {
const index_t bit_element = sample_index / bitset_element_size;
const index_t bit_index = sample_index % bitset_element_size;
const bitset_t bitmask = bitset_t{1} << bit_index;
if (set_value) {
atomicOr(bitset_ptr + bit_element, bitmask);
} else {
const bitset_t bitmask2 = ~bitmask;
atomicAnd(bitset_ptr + bit_element, bitmask2);
}
});
}
/**
* @brief Flip all the bits in a bitset.
*
* @param res RAFT resources
*/
void flip(const raft::resources& res)
{
auto bitset_span = this->to_mdspan();
raft::linalg::map(
res,
bitset_span,
[] __device__(bitset_t element) { return bitset_t(~element); },
raft::make_const_mdspan(bitset_span));
}
/**
* @brief Reset the bits in a bitset.
*
* @param res RAFT resources
*/
void reset(const raft::resources& res)
{
cudaMemsetAsync(bitset_.data(),
default_value_ ? 0xff : 0x00,
n_elements() * sizeof(bitset_t),
resource::get_cuda_stream(res));
}

private:
raft::device_uvector<bitset_t> bitset_;
index_t bitset_len_;
bool default_value_;
};

/**
* @brief Set a list of indices in a bitset to set_value.
*
* @tparam bitset_t Underlying type of the bitset array
* @tparam index_t Indexing type used.
* @param res RAFT resources
* @param bitset_view_ View of the bitset
* @param mask_index indices to remove from the bitset
* @param set_value Value to set the bits to (true or false)
*/
template <typename bitset_t, typename index_t>
void bitset_set(const raft::resources& res,
raft::util::bitset_view<bitset_t, index_t> bitset_view_,
raft::device_vector_view<const index_t, index_t> mask_index,
bool set_value = false)
{
auto* bitset_ptr = bitset_view_.data_handle();
constexpr auto bitset_element_size =
raft::util::bitset_view<bitset_t, index_t>::bitset_element_size;
thrust::for_each_n(
resource::get_thrust_policy(res),
mask_index.data_handle(),
mask_index.extent(0),
[bitset_ptr, set_value, bitset_element_size] __device__(const index_t sample_index) {
const index_t bit_element = sample_index / bitset_element_size;
const index_t bit_index = sample_index % bitset_element_size;
const bitset_t bitmask = bitset_t{1} << bit_index;
if (set_value) {
atomicOr(bitset_ptr + bit_element, bitmask);
} else {
const bitset_t bitmask2 = ~bitmask;
atomicAnd(bitset_ptr + bit_element, bitmask2);
}
});
}

/**
* @brief Test a list of indices in a bitset.
*
* @tparam bitset_t Underlying type of the bitset array
* @tparam index_t Indexing type
* @tparam output_t Output type of the test. Default is bool.
* @param res RAFT resources
* @param bitset_view_ View of the bitset
* @param queries List of indices to test
* @param output List of outputs
*/
template <typename bitset_t, typename index_t, typename output_t = bool>
void bitset_test(const raft::resources& res,
const raft::util::bitset_view<bitset_t, index_t> bitset_view_,
raft::device_vector_view<const index_t, index_t> queries,
raft::device_vector_view<output_t, index_t> output)
{
RAFT_EXPECTS(output.extent(0) == queries.extent(0), "Output and queries must be same size");
raft::linalg::map(
res,
output,
[=] __device__(index_t query) { return output_t(bitset_view_.test(query)); },
queries);
}

/**
* @brief Flip all the bit in a bitset.
*
* @tparam bitset_t Underlying type of the bitset array
* @tparam index_t Indexing type
* @param res RAFT resources
* @param bitset_view_ View of the bitset
*/
template <typename bitset_t, typename index_t>
void bitset_flip(const raft::resources& res,
raft::util::bitset_view<bitset_t, index_t> bitset_view_)
{
auto bitset_span = bitset_view_.to_mdspan();
raft::linalg::map(
res,
bitset_span,
[] __device__(bitset_t element) { return bitset_t(~element); },
raft::make_const_mdspan(bitset_span));
}
/** @} */
} // end namespace raft::util
} // end namespace raft::core
2 changes: 1 addition & 1 deletion cpp/test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ if(BUILD_TESTS)
NAME
CORE_TEST
PATH
test/core/bitset.cu
test/core/device_resources_manager.cpp
test/core/device_setter.cpp
test/core/logger.cpp
Expand Down Expand Up @@ -423,7 +424,6 @@ if(BUILD_TESTS)
PATH
test/core/seive.cu
test/util/bitonic_sort.cu
test/util/bitset.cu
test/util/cudart_utils.cpp
test/util/device_atomics.cu
test/util/integer_utils.cpp
Expand Down
22 changes: 10 additions & 12 deletions cpp/test/util/bitset.cu → cpp/test/core/bitset.cu
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,16 @@

#include "../test_utils.cuh"

#include <raft/core/bitset.cuh>
#include <raft/core/device_mdarray.hpp>
#include <raft/random/rng.cuh>
#include <raft/util/bitset.cuh>

#include <gtest/gtest.h>

#include <algorithm>
#include <numeric>

namespace raft::util {
namespace raft::core {

struct test_spec_bitset {
uint64_t bitset_len;
Expand Down Expand Up @@ -109,10 +109,9 @@ class BitsetTest : public testing::TestWithParam<test_spec_bitset> {
resource::sync_stream(res, stream);

// calculate the results
auto test_bitset = raft::util::bitset<bitset_t, index_t>(
auto my_bitset = raft::core::bitset<bitset_t, index_t>(
res, raft::make_const_mdspan(mask_device.view()), index_t(spec.bitset_len));
update_host(
bitset_result.data(), test_bitset.view().data_handle(), bitset_result.size(), stream);
update_host(bitset_result.data(), my_bitset.data_handle(), bitset_result.size(), stream);

// calculate the reference
create_cpu_bitset(bitset_ref, mask_cpu);
Expand All @@ -128,8 +127,7 @@ class BitsetTest : public testing::TestWithParam<test_spec_bitset> {
// Create queries and verify the test results
raft::random::uniformInt(res, rng, query_device.view(), index_t(0), index_t(spec.bitset_len));
update_host(query_cpu.data(), query_device.data_handle(), query_device.extent(0), stream);
raft::util::bitset_test(
res, test_bitset.view(), raft::make_const_mdspan(query_device.view()), result_device.view());
my_bitset.test(res, raft::make_const_mdspan(query_device.view()), result_device.view());
update_host(result_cpu.data(), result_device.data_handle(), result_device.extent(0), stream);
test_cpu_bitset(bitset_ref, query_cpu, result_ref);
resource::sync_stream(res, stream);
Expand All @@ -139,16 +137,16 @@ class BitsetTest : public testing::TestWithParam<test_spec_bitset> {
raft::random::uniformInt(res, rng, mask_device.view(), index_t(0), index_t(spec.bitset_len));
update_host(mask_cpu.data(), mask_device.data_handle(), mask_device.extent(0), stream);
resource::sync_stream(res, stream);
raft::util::bitset_set<bitset_t, index_t>(res, test_bitset.view(), mask_device.view());
update_host(bitset_result.data(), test_bitset.data_handle(), bitset_result.size(), stream);
my_bitset.set(res, mask_device.view());
update_host(bitset_result.data(), my_bitset.data_handle(), bitset_result.size(), stream);

add_cpu_bitset(bitset_ref, mask_cpu);
resource::sync_stream(res, stream);
ASSERT_TRUE(hostVecMatch(bitset_ref, bitset_result, raft::Compare<bitset_t>()));

// Flip the bitset and re-test
raft::util::bitset_flip<bitset_t, index_t>(res, test_bitset.view());
update_host(bitset_result.data(), test_bitset.data_handle(), bitset_result.size(), stream);
my_bitset.flip(res);
update_host(bitset_result.data(), my_bitset.data_handle(), bitset_result.size(), stream);
flip_cpu_bitset(bitset_ref);
resource::sync_stream(res, stream);
ASSERT_TRUE(hostVecMatch(bitset_ref, bitset_result, raft::Compare<bitset_t>()));
Expand Down Expand Up @@ -187,4 +185,4 @@ using Uint64_64 = BitsetTest<uint64_t, uint64_t>;
TEST_P(Uint64_64, Run) { run(); }
INSTANTIATE_TEST_CASE_P(BitsetTest, Uint64_64, inputs_bitset);

} // namespace raft::util
} // namespace raft::core
3 changes: 2 additions & 1 deletion docs/source/cpp_api/core.rst
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,5 @@ expose in public APIs.
core_nvtx.rst
core_interruptible.rst
core_operators.rst
core_math.rst
core_math.rst
core_bitset.rst
Loading

0 comments on commit 9115bbb

Please sign in to comment.