From 9115bbb31fc6c0902bb9d5c8c15e360d41da4c97 Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Wed, 20 Sep 2023 16:31:57 +0200 Subject: [PATCH] Move bitset to core and match std::bitset --- build.sh | 2 +- cpp/bench/prims/CMakeLists.txt | 2 +- cpp/bench/prims/{util => core}/bitset.cu | 11 +- cpp/include/raft/{util => core}/bitset.cuh | 176 ++++++++++----------- cpp/test/CMakeLists.txt | 2 +- cpp/test/{util => core}/bitset.cu | 22 ++- docs/source/cpp_api/core.rst | 3 +- docs/source/cpp_api/core_bitset.rst | 15 ++ docs/source/cpp_api/utils.rst | 12 -- 9 files changed, 121 insertions(+), 124 deletions(-) rename cpp/bench/prims/{util => core}/bitset.cu (89%) rename cpp/include/raft/{util => core}/bitset.cuh (68%) rename cpp/test/{util => core}/bitset.cu (90%) create mode 100644 docs/source/cpp_api/core_bitset.rst diff --git a/build.sh b/build.sh index 1fa1abbee5..5543faaebe 100755 --- a/build.sh +++ b/build.sh @@ -79,7 +79,7 @@ BUILD_REPORT_METRICS="" BUILD_REPORT_INCL_CACHE_STATS=OFF TEST_TARGETS="CLUSTER_TEST;CORE_TEST;DISTANCE_TEST;LABEL_TEST;LINALG_TEST;MATRIX_TEST;NEIGHBORS_TEST;NEIGHBORS_ANN_CAGRA_TEST;RANDOM_TEST;SOLVERS_TEST;SPARSE_TEST;SPARSE_DIST_TEST;SPARSE_NEIGHBORS_TEST;STATS_TEST;UTILS_TEST" -BENCH_TARGETS="CLUSTER_BENCH;NEIGHBORS_BENCH;DISTANCE_BENCH;LINALG_BENCH;MATRIX_BENCH;SPARSE_BENCH;RANDOM_BENCH;UTIL_BENCH" +BENCH_TARGETS="CLUSTER_BENCH;CORE_BENCH;NEIGHBORS_BENCH;DISTANCE_BENCH;LINALG_BENCH;MATRIX_BENCH;SPARSE_BENCH;RANDOM_BENCH" CACHE_ARGS="" NVTX=ON diff --git a/cpp/bench/prims/CMakeLists.txt b/cpp/bench/prims/CMakeLists.txt index 1690d6f320..ca4b0f099d 100644 --- a/cpp/bench/prims/CMakeLists.txt +++ b/cpp/bench/prims/CMakeLists.txt @@ -77,6 +77,7 @@ if(BUILD_PRIMS_BENCH) NAME CLUSTER_BENCH PATH bench/prims/cluster/kmeans_balanced.cu bench/prims/cluster/kmeans.cu bench/prims/main.cpp OPTIONAL LIB EXPLICIT_INSTANTIATE_ONLY ) + ConfigureBench(NAME CORE_BENCH PATH bench/prims/core/bitset.cu bench/prims/main.cpp) ConfigureBench( NAME TUNE_DISTANCE PATH bench/prims/distance/tune_pairwise/kernel.cu @@ -156,5 +157,4 @@ if(BUILD_PRIMS_BENCH) EXPLICIT_INSTANTIATE_ONLY ) - ConfigureBench(NAME UTIL_BENCH PATH bench/prims/util/bitset.cu bench/prims/main.cpp) endif() diff --git a/cpp/bench/prims/util/bitset.cu b/cpp/bench/prims/core/bitset.cu similarity index 89% rename from cpp/bench/prims/util/bitset.cu rename to cpp/bench/prims/core/bitset.cu index c7cba797f4..5f44aa9af5 100644 --- a/cpp/bench/prims/util/bitset.cu +++ b/cpp/bench/prims/core/bitset.cu @@ -15,11 +15,11 @@ */ #include +#include #include -#include #include -namespace raft::bench::util { +namespace raft::bench::core { struct bitset_inputs { uint32_t bitset_len; @@ -42,10 +42,9 @@ struct bitset_bench : public fixture { void run_benchmark(::benchmark::State& state) override { loop_on_state(state, [this]() { - auto my_bitset = raft::util::bitset( + auto my_bitset = raft::core::bitset( this->res, raft::make_const_mdspan(mask.view()), params.bitset_len); - raft::util::bitset_test( - res, my_bitset.view(), raft::make_const_mdspan(queries.view()), outputs.view()); + my_bitset.test(res, raft::make_const_mdspan(queries.view()), outputs.view()); }); } @@ -72,4 +71,4 @@ RAFT_BENCH_REGISTER(Uint16_64, "", bitset_input_vecs); RAFT_BENCH_REGISTER(Uint32_32, "", bitset_input_vecs); RAFT_BENCH_REGISTER(Uint32_64, "", bitset_input_vecs); -} // namespace raft::bench::util +} // namespace raft::bench::core diff --git a/cpp/include/raft/util/bitset.cuh b/cpp/include/raft/core/bitset.cuh similarity index 68% rename from cpp/include/raft/util/bitset.cuh rename to cpp/include/raft/core/bitset.cuh index af5ef79588..6747c5fab0 100644 --- a/cpp/include/raft/util/bitset.cuh +++ b/cpp/include/raft/core/bitset.cuh @@ -23,7 +23,7 @@ #include #include -namespace raft::util { +namespace raft::core { /** * @defgroup bitset Bitset * @{ @@ -69,6 +69,7 @@ struct bitset_view { const bool is_bit_set = (bit_element & (bitset_t{1} << bit_index)) != 0; return is_bit_set; } + /** * @brief Get the device pointer to the bitset. */ @@ -82,7 +83,7 @@ struct bitset_view { /** * @brief Get the number of elements used by the bitset representation. */ - inline auto n_elements() const -> index_t + inline _RAFT_HOST_DEVICE auto n_elements() const -> index_t { return raft::ceildiv(bitset_len_, bitset_element_size); } @@ -136,7 +137,7 @@ struct bitset { default_value ? 0xff : 0x00, n_elements() * sizeof(bitset_t), resource::get_cuda_stream(res)); - bitset_set(res, view(), mask_index, !default_value); + set(res, mask_index, !default_value); } /** @@ -168,13 +169,13 @@ struct bitset { * * @return bitset_view */ - inline auto view() -> raft::util::bitset_view + inline auto view() -> raft::core::bitset_view { - return bitset_view(view_mdspan(), bitset_len_); + return bitset_view(to_mdspan(), bitset_len_); } - [[nodiscard]] inline auto view() const -> raft::util::bitset_view + [[nodiscard]] inline auto view() const -> raft::core::bitset_view { - return bitset_view(view_mdspan(), bitset_len_); + return bitset_view(to_mdspan(), bitset_len_); } /** @@ -196,11 +197,11 @@ struct bitset { } /** @brief Get an mdspan view of the current bitset */ - inline auto view_mdspan() -> raft::device_vector_view + inline auto to_mdspan() -> raft::device_vector_view { return raft::make_device_vector_view(bitset_.data(), n_elements()); } - [[nodiscard]] inline auto view_mdspan() const -> raft::device_vector_view + [[nodiscard]] inline auto to_mdspan() const -> raft::device_vector_view { return raft::make_device_vector_view(bitset_.data(), n_elements()); } @@ -222,91 +223,86 @@ struct bitset { } } + /** + * @brief Test a list of indices in a bitset. + * + * @tparam output_t Output type of the test. Default is bool. + * @param res RAFT resources + * @param queries List of indices to test + * @param output List of outputs + */ + template + void test(const raft::resources& res, + raft::device_vector_view queries, + raft::device_vector_view output) const + { + RAFT_EXPECTS(output.extent(0) == queries.extent(0), "Output and queries must be same size"); + auto bitset_view = view(); + raft::linalg::map( + res, + output, + [bitset_view] __device__(index_t query) { return output_t(bitset_view.test(query)); }, + queries); + } + /** + * @brief Set a list of indices in a bitset to set_value. + * + * @param res RAFT resources + * @param mask_index indices to remove from the bitset + * @param set_value Value to set the bits to (true or false) + */ + void set(const raft::resources& res, + raft::device_vector_view mask_index, + bool set_value = false) + { + auto* bitset_ptr = this->data_handle(); + thrust::for_each_n(resource::get_thrust_policy(res), + mask_index.data_handle(), + mask_index.extent(0), + [bitset_ptr, set_value] __device__(const index_t sample_index) { + const index_t bit_element = sample_index / bitset_element_size; + const index_t bit_index = sample_index % bitset_element_size; + const bitset_t bitmask = bitset_t{1} << bit_index; + if (set_value) { + atomicOr(bitset_ptr + bit_element, bitmask); + } else { + const bitset_t bitmask2 = ~bitmask; + atomicAnd(bitset_ptr + bit_element, bitmask2); + } + }); + } + /** + * @brief Flip all the bits in a bitset. + * + * @param res RAFT resources + */ + void flip(const raft::resources& res) + { + auto bitset_span = this->to_mdspan(); + raft::linalg::map( + res, + bitset_span, + [] __device__(bitset_t element) { return bitset_t(~element); }, + raft::make_const_mdspan(bitset_span)); + } + /** + * @brief Reset the bits in a bitset. + * + * @param res RAFT resources + */ + void reset(const raft::resources& res) + { + cudaMemsetAsync(bitset_.data(), + default_value_ ? 0xff : 0x00, + n_elements() * sizeof(bitset_t), + resource::get_cuda_stream(res)); + } + private: raft::device_uvector bitset_; index_t bitset_len_; bool default_value_; }; -/** - * @brief Set a list of indices in a bitset to set_value. - * - * @tparam bitset_t Underlying type of the bitset array - * @tparam index_t Indexing type used. - * @param res RAFT resources - * @param bitset_view_ View of the bitset - * @param mask_index indices to remove from the bitset - * @param set_value Value to set the bits to (true or false) - */ -template -void bitset_set(const raft::resources& res, - raft::util::bitset_view bitset_view_, - raft::device_vector_view mask_index, - bool set_value = false) -{ - auto* bitset_ptr = bitset_view_.data_handle(); - constexpr auto bitset_element_size = - raft::util::bitset_view::bitset_element_size; - thrust::for_each_n( - resource::get_thrust_policy(res), - mask_index.data_handle(), - mask_index.extent(0), - [bitset_ptr, set_value, bitset_element_size] __device__(const index_t sample_index) { - const index_t bit_element = sample_index / bitset_element_size; - const index_t bit_index = sample_index % bitset_element_size; - const bitset_t bitmask = bitset_t{1} << bit_index; - if (set_value) { - atomicOr(bitset_ptr + bit_element, bitmask); - } else { - const bitset_t bitmask2 = ~bitmask; - atomicAnd(bitset_ptr + bit_element, bitmask2); - } - }); -} - -/** - * @brief Test a list of indices in a bitset. - * - * @tparam bitset_t Underlying type of the bitset array - * @tparam index_t Indexing type - * @tparam output_t Output type of the test. Default is bool. - * @param res RAFT resources - * @param bitset_view_ View of the bitset - * @param queries List of indices to test - * @param output List of outputs - */ -template -void bitset_test(const raft::resources& res, - const raft::util::bitset_view bitset_view_, - raft::device_vector_view queries, - raft::device_vector_view output) -{ - RAFT_EXPECTS(output.extent(0) == queries.extent(0), "Output and queries must be same size"); - raft::linalg::map( - res, - output, - [=] __device__(index_t query) { return output_t(bitset_view_.test(query)); }, - queries); -} - -/** - * @brief Flip all the bit in a bitset. - * - * @tparam bitset_t Underlying type of the bitset array - * @tparam index_t Indexing type - * @param res RAFT resources - * @param bitset_view_ View of the bitset - */ -template -void bitset_flip(const raft::resources& res, - raft::util::bitset_view bitset_view_) -{ - auto bitset_span = bitset_view_.to_mdspan(); - raft::linalg::map( - res, - bitset_span, - [] __device__(bitset_t element) { return bitset_t(~element); }, - raft::make_const_mdspan(bitset_span)); -} /** @} */ -} // end namespace raft::util +} // end namespace raft::core diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index 57a45c557c..a9b387008f 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -105,6 +105,7 @@ if(BUILD_TESTS) NAME CORE_TEST PATH + test/core/bitset.cu test/core/device_resources_manager.cpp test/core/device_setter.cpp test/core/logger.cpp @@ -423,7 +424,6 @@ if(BUILD_TESTS) PATH test/core/seive.cu test/util/bitonic_sort.cu - test/util/bitset.cu test/util/cudart_utils.cpp test/util/device_atomics.cu test/util/integer_utils.cpp diff --git a/cpp/test/util/bitset.cu b/cpp/test/core/bitset.cu similarity index 90% rename from cpp/test/util/bitset.cu rename to cpp/test/core/bitset.cu index 4793dde2f1..215de98aaf 100644 --- a/cpp/test/util/bitset.cu +++ b/cpp/test/core/bitset.cu @@ -16,16 +16,16 @@ #include "../test_utils.cuh" +#include #include #include -#include #include #include #include -namespace raft::util { +namespace raft::core { struct test_spec_bitset { uint64_t bitset_len; @@ -109,10 +109,9 @@ class BitsetTest : public testing::TestWithParam { resource::sync_stream(res, stream); // calculate the results - auto test_bitset = raft::util::bitset( + auto my_bitset = raft::core::bitset( res, raft::make_const_mdspan(mask_device.view()), index_t(spec.bitset_len)); - update_host( - bitset_result.data(), test_bitset.view().data_handle(), bitset_result.size(), stream); + update_host(bitset_result.data(), my_bitset.data_handle(), bitset_result.size(), stream); // calculate the reference create_cpu_bitset(bitset_ref, mask_cpu); @@ -128,8 +127,7 @@ class BitsetTest : public testing::TestWithParam { // Create queries and verify the test results raft::random::uniformInt(res, rng, query_device.view(), index_t(0), index_t(spec.bitset_len)); update_host(query_cpu.data(), query_device.data_handle(), query_device.extent(0), stream); - raft::util::bitset_test( - res, test_bitset.view(), raft::make_const_mdspan(query_device.view()), result_device.view()); + my_bitset.test(res, raft::make_const_mdspan(query_device.view()), result_device.view()); update_host(result_cpu.data(), result_device.data_handle(), result_device.extent(0), stream); test_cpu_bitset(bitset_ref, query_cpu, result_ref); resource::sync_stream(res, stream); @@ -139,16 +137,16 @@ class BitsetTest : public testing::TestWithParam { raft::random::uniformInt(res, rng, mask_device.view(), index_t(0), index_t(spec.bitset_len)); update_host(mask_cpu.data(), mask_device.data_handle(), mask_device.extent(0), stream); resource::sync_stream(res, stream); - raft::util::bitset_set(res, test_bitset.view(), mask_device.view()); - update_host(bitset_result.data(), test_bitset.data_handle(), bitset_result.size(), stream); + my_bitset.set(res, mask_device.view()); + update_host(bitset_result.data(), my_bitset.data_handle(), bitset_result.size(), stream); add_cpu_bitset(bitset_ref, mask_cpu); resource::sync_stream(res, stream); ASSERT_TRUE(hostVecMatch(bitset_ref, bitset_result, raft::Compare())); // Flip the bitset and re-test - raft::util::bitset_flip(res, test_bitset.view()); - update_host(bitset_result.data(), test_bitset.data_handle(), bitset_result.size(), stream); + my_bitset.flip(res); + update_host(bitset_result.data(), my_bitset.data_handle(), bitset_result.size(), stream); flip_cpu_bitset(bitset_ref); resource::sync_stream(res, stream); ASSERT_TRUE(hostVecMatch(bitset_ref, bitset_result, raft::Compare())); @@ -187,4 +185,4 @@ using Uint64_64 = BitsetTest; TEST_P(Uint64_64, Run) { run(); } INSTANTIATE_TEST_CASE_P(BitsetTest, Uint64_64, inputs_bitset); -} // namespace raft::util +} // namespace raft::core diff --git a/docs/source/cpp_api/core.rst b/docs/source/cpp_api/core.rst index 7e69f92948..39e57fd69a 100644 --- a/docs/source/cpp_api/core.rst +++ b/docs/source/cpp_api/core.rst @@ -20,4 +20,5 @@ expose in public APIs. core_nvtx.rst core_interruptible.rst core_operators.rst - core_math.rst \ No newline at end of file + core_math.rst + core_bitset.rst \ No newline at end of file diff --git a/docs/source/cpp_api/core_bitset.rst b/docs/source/cpp_api/core_bitset.rst new file mode 100644 index 0000000000..af1cff6d37 --- /dev/null +++ b/docs/source/cpp_api/core_bitset.rst @@ -0,0 +1,15 @@ +Bitset +====== + +.. role:: py(code) + :language: c++ + :class: highlight + +``#include `` + +namespace *raft::core* + +.. doxygengroup:: bitset + :project: RAFT + :members: + :content-only: \ No newline at end of file diff --git a/docs/source/cpp_api/utils.rst b/docs/source/cpp_api/utils.rst index ccdb9919ac..4471093c8b 100644 --- a/docs/source/cpp_api/utils.rst +++ b/docs/source/cpp_api/utils.rst @@ -8,18 +8,6 @@ This page provides C++ API references for the publicly-exposed utility functions :language: c++ :class: highlight -Bitset ------- - -``#include `` - -namespace *raft::utils* - -.. doxygengroup:: bitset - :project: RAFT - :members: - :content-only: - Memory Pool -----------