diff --git a/cpp/bench/prims/core/bitset.cu b/cpp/bench/prims/core/bitset.cu index 5f44aa9af5..85e24a3d37 100644 --- a/cpp/bench/prims/core/bitset.cu +++ b/cpp/bench/prims/core/bitset.cu @@ -44,7 +44,7 @@ struct bitset_bench : public fixture { loop_on_state(state, [this]() { auto my_bitset = raft::core::bitset( this->res, raft::make_const_mdspan(mask.view()), params.bitset_len); - my_bitset.test(res, raft::make_const_mdspan(queries.view()), outputs.view()); + my_bitset.test(raft::make_const_mdspan(queries.view()), outputs.view()); }); } diff --git a/cpp/bench/prims/neighbors/cagra_bench.cuh b/cpp/bench/prims/neighbors/cagra_bench.cuh index 63f6c14686..0748177dff 100644 --- a/cpp/bench/prims/neighbors/cagra_bench.cuh +++ b/cpp/bench/prims/neighbors/cagra_bench.cuh @@ -85,7 +85,7 @@ struct CagraBench : public fixture { resource::get_thrust_policy(handle), thrust::device_pointer_cast(removed_indices.data_handle()), thrust::device_pointer_cast(removed_indices.data_handle() + removed_indices.extent(0))); - removed_indices_bitset_.set(handle, removed_indices.view()); + removed_indices_bitset_.set(removed_indices.view()); index_.emplace(raft::neighbors::cagra::index( handle, metric, make_const_mdspan(dataset_.view()), make_const_mdspan(knn_graph_.view()))); } diff --git a/cpp/include/raft/core/bitset.cuh b/cpp/include/raft/core/bitset.cuh index d75957817a..bfb3364c07 100644 --- a/cpp/include/raft/core/bitset.cuh +++ b/cpp/include/raft/core/bitset.cuh @@ -161,13 +161,11 @@ struct bitset { bool default_value = true) : bitset_{std::size_t(raft::ceildiv(bitset_len, bitset_element_size)), raft::resource::get_cuda_stream(res)}, - bitset_len_{bitset_len} + bitset_len_{bitset_len}, + res_{res} { - cudaMemsetAsync(bitset_.data(), - default_value ? 0xff : 0x00, - n_elements() * sizeof(bitset_t), - resource::get_cuda_stream(res)); - set(res, mask_index, !default_value); + reset(default_value); + set(mask_index, !default_value); } /** @@ -180,12 +178,10 @@ struct bitset { bitset(const raft::resources& res, index_t bitset_len, bool default_value = true) : bitset_{std::size_t(raft::ceildiv(bitset_len, bitset_element_size)), resource::get_cuda_stream(res)}, - bitset_len_{bitset_len} + bitset_len_{bitset_len}, + res_{res} { - cudaMemsetAsync(bitset_.data(), - default_value ? 0xff : 0x00, - n_elements() * sizeof(bitset_t), - resource::get_cuda_stream(res)); + reset(default_value); } // Disable copy constructor bitset(const bitset&) = delete; @@ -237,7 +233,7 @@ struct bitset { /** @brief Resize the bitset. If the requested size is larger, new memory is allocated and set to * the default value. */ - void resize(const raft::resources& res, index_t new_bitset_len, bool default_value = true) + void resize(index_t new_bitset_len, bool default_value = true) { auto old_size = raft::ceildiv(bitset_len_, bitset_element_size); auto new_size = raft::ceildiv(new_bitset_len, bitset_element_size); @@ -245,10 +241,10 @@ struct bitset { bitset_len_ = new_bitset_len; if (old_size < new_size) { // If the new size is larger, set the new bits to the default value - cudaMemsetAsync(bitset_.data() + old_size, - default_value ? 0xff : 0x00, - (new_size - old_size) * sizeof(bitset_t), - resource::get_cuda_stream(res)); + RAFT_CUDA_TRY(cudaMemsetAsync(bitset_.data() + old_size, + default_value ? 0xff : 0x00, + (new_size - old_size) * sizeof(bitset_t), + resource::get_cuda_stream(res_))); } } @@ -261,14 +257,13 @@ struct bitset { * @param output List of outputs */ template - void test(const raft::resources& res, - raft::device_vector_view queries, + void test(raft::device_vector_view queries, raft::device_vector_view output) const { RAFT_EXPECTS(output.extent(0) == queries.extent(0), "Output and queries must be same size"); auto bitset_view = view(); raft::linalg::map( - res, + res_, output, [bitset_view] __device__(index_t query) { return output_t(bitset_view.test(query)); }, queries); @@ -280,12 +275,10 @@ struct bitset { * @param mask_index indices to remove from the bitset * @param set_value Value to set the bits to (true or false) */ - void set(const raft::resources& res, - raft::device_vector_view mask_index, - bool set_value = false) + void set(raft::device_vector_view mask_index, bool set_value = false) { auto this_bitset_view = view(); - thrust::for_each_n(resource::get_thrust_policy(res), + thrust::for_each_n(resource::get_thrust_policy(res_), mask_index.data_handle(), mask_index.extent(0), [this_bitset_view, set_value] __device__(const index_t sample_index) { @@ -294,14 +287,12 @@ struct bitset { } /** * @brief Flip all the bits in a bitset. - * - * @param res RAFT resources */ - void flip(const raft::resources& res) + void flip() { auto bitset_span = this->to_mdspan(); raft::linalg::map( - res, + res_, bitset_span, [] __device__(bitset_t element) { return bitset_t(~element); }, raft::make_const_mdspan(bitset_span)); @@ -309,23 +300,21 @@ struct bitset { /** * @brief Reset the bits in a bitset. * - * @param res RAFT resources * @param default_value Value to set the bits to (true or false) */ - void reset(const raft::resources& res, bool default_value = true) + void reset(bool default_value = true) { - cudaMemsetAsync(bitset_.data(), - default_value ? 0xff : 0x00, - n_elements() * sizeof(bitset_t), - resource::get_cuda_stream(res)); + RAFT_CUDA_TRY(cudaMemsetAsync(bitset_.data(), + default_value ? 0xff : 0x00, + n_elements() * sizeof(bitset_t), + resource::get_cuda_stream(res_))); } /** * @brief Returns the number of bits set to true in count_gpu_scalar. * - * @param[in] res RAFT resources * @param[out] count_gpu_scalar Device scalar to store the count */ - void count(const raft::resources& res, raft::device_scalar_view count_gpu_scalar) + void count(raft::device_scalar_view count_gpu_scalar) { auto n_elements_ = n_elements(); auto count_gpu = @@ -337,66 +326,106 @@ struct bitset { bitset_t last_element_mask = n_last_element ? (bitset_t)((bitset_t{1} << n_last_element) - bitset_t{1}) : ~bitset_t{0}; raft::linalg::coalesced_reduction( - res, + res_, bitset_matrix_view, count_gpu, index_t{0}, false, [last_element_mask, n_elements_] __device__(bitset_t element, index_t index) { - index_t res = 0; + index_t result = 0; if constexpr (bitset_element_size == 64) { // Needed because __popc doesn't support 64bit if (index == n_elements_ - 1) - res = index_t(raft::detail::native_popc(element & last_element_mask)); + result = index_t(raft::detail::native_popc(element & last_element_mask)); else - res = index_t(raft::detail::native_popc(element)); + result = index_t(raft::detail::native_popc(element)); } else { if (index == n_elements_ - 1) - res = index_t(__popc(element & last_element_mask)); + result = index_t(__popc(element & last_element_mask)); else - res = index_t(__popc(element)); + result = index_t(__popc(element)); } - return res; + return result; }); } /** * @brief Returns the number of bits set to true. * - * @param[in] res RAFT resources * @return index_t Number of bits set to true */ - auto count(const raft::resources& res) -> index_t + auto count() -> index_t { - auto count_gpu_scalar = raft::make_device_scalar(res, 0.0); - count(res, count_gpu_scalar.view()); + auto count_gpu_scalar = raft::make_device_scalar(res_, 0.0); + count(count_gpu_scalar.view()); index_t count_cpu = 0; raft::update_host( - &count_cpu, count_gpu_scalar.data_handle(), 1, resource::get_cuda_stream(res)); - resource::sync_stream(res); + &count_cpu, count_gpu_scalar.data_handle(), 1, resource::get_cuda_stream(res_)); + resource::sync_stream(res_); return count_cpu; } /** * @brief Checks if any of the bits are set to true in the bitset. - * - * @param res RAFT resources */ - bool any(const raft::resources& res) { return count(res) > 0; } + bool any() { return count() > 0; } /** * @brief Checks if all of the bits are set to true in the bitset. - * - * @param res RAFT resources */ - bool all(const raft::resources& res) { return count(res) == bitset_len_; } + bool all() { return count() == bitset_len_; } /** * @brief Checks if none of the bits are set to true in the bitset. - * - * @param res RAFT resources */ - bool none(const raft::resources& res) { return count(res) == 0; } + bool none() { return count() == 0; } + + bitset& operator|=(const bitset& other) + { + RAFT_EXPECTS(size() == other.size(), "Sizes must be equal"); + auto this_span = to_mdspan(); + auto other_span = other.to_mdspan(); + raft::linalg::map( + res_, + this_span, + [] __device__(bitset_t this_element, bitset_t other_element) { + return this_element | other_element; + }, + raft::make_const_mdspan(this_span), + other_span); + return *this; + } + bitset& operator&=(const bitset& other) + { + RAFT_EXPECTS(size() == other.size(), "Sizes must be equal"); + auto this_span = to_mdspan(); + auto other_span = other.to_mdspan(); + raft::linalg::map( + res_, + this_span, + [] __device__(bitset_t this_element, bitset_t other_element) { + return this_element & other_element; + }, + raft::make_const_mdspan(this_span), + other_span); + return *this; + } + bitset& operator^=(const bitset& other) + { + RAFT_EXPECTS(size() == other.size(), "Sizes must be equal"); + auto this_span = to_mdspan(); + auto other_span = other.to_mdspan(); + raft::linalg::map( + res_, + this_span, + [] __device__(bitset_t this_element, bitset_t other_element) { + return this_element ^ other_element; + }, + raft::make_const_mdspan(this_span), + other_span); + return *this; + } private: raft::device_uvector bitset_; index_t bitset_len_; + const raft::resources& res_; }; /** @} */ diff --git a/cpp/test/core/bitset.cu b/cpp/test/core/bitset.cu index 9d12f04891..edda1884b3 100644 --- a/cpp/test/core/bitset.cu +++ b/cpp/test/core/bitset.cu @@ -128,7 +128,7 @@ class BitsetTest : public testing::TestWithParam { // Create queries and verify the test results raft::random::uniformInt(res, rng, query_device.view(), index_t(0), index_t(spec.bitset_len)); update_host(query_cpu.data(), query_device.data_handle(), query_device.extent(0), stream); - my_bitset.test(res, raft::make_const_mdspan(query_device.view()), result_device.view()); + my_bitset.test(raft::make_const_mdspan(query_device.view()), result_device.view()); update_host(result_cpu.data(), result_device.data_handle(), result_device.extent(0), stream); test_cpu_bitset(bitset_ref, query_cpu, result_ref); resource::sync_stream(res, stream); @@ -138,7 +138,7 @@ class BitsetTest : public testing::TestWithParam { raft::random::uniformInt(res, rng, mask_device.view(), index_t(0), index_t(spec.bitset_len)); update_host(mask_cpu.data(), mask_device.data_handle(), mask_device.extent(0), stream); resource::sync_stream(res, stream); - my_bitset.set(res, mask_device.view()); + my_bitset.set(mask_device.view()); update_host(bitset_result.data(), my_bitset.data_handle(), bitset_result.size(), stream); add_cpu_bitset(bitset_ref, mask_cpu); @@ -146,25 +146,40 @@ class BitsetTest : public testing::TestWithParam { ASSERT_TRUE(hostVecMatch(bitset_ref, bitset_result, raft::Compare())); // Flip the bitset and re-test - auto bitset_count = my_bitset.count(res); - my_bitset.flip(res); - ASSERT_EQ(my_bitset.count(res), spec.bitset_len - bitset_count); + auto bitset_count = my_bitset.count(); + my_bitset.flip(); + ASSERT_EQ(my_bitset.count(), spec.bitset_len - bitset_count); update_host(bitset_result.data(), my_bitset.data_handle(), bitset_result.size(), stream); flip_cpu_bitset(bitset_ref); resource::sync_stream(res, stream); ASSERT_TRUE(hostVecMatch(bitset_ref, bitset_result, raft::Compare())); - my_bitset.reset(res, false); - ASSERT_EQ(my_bitset.any(res), false); - ASSERT_EQ(my_bitset.none(res), true); + // Test count() operations + my_bitset.reset(false); + ASSERT_EQ(my_bitset.any(), false); + ASSERT_EQ(my_bitset.none(), true); raft::linalg::range(query_device.data_handle(), query_device.size(), stream); - my_bitset.set(res, raft::make_const_mdspan(query_device.view()), true); - bitset_count = my_bitset.count(res); + my_bitset.set(raft::make_const_mdspan(query_device.view()), true); + bitset_count = my_bitset.count(); ASSERT_EQ(bitset_count, query_device.size()); - ASSERT_EQ(my_bitset.any(res), true); - ASSERT_EQ(my_bitset.none(res), false); - - ASSERT_TRUE(hostVecMatch(bitset_ref, bitset_result, raft::Compare())); + ASSERT_EQ(my_bitset.any(), true); + ASSERT_EQ(my_bitset.none(), false); + + // Test operators + auto my_bitset_2 = raft::core::bitset( + res, raft::make_const_mdspan(mask_device.view()), index_t(spec.bitset_len), false); + auto my_bitset_3 = raft::core::bitset( + res, raft::make_const_mdspan(mask_device.view()), index_t(spec.bitset_len), false); + my_bitset_2 ^= my_bitset; + ASSERT_FALSE(devArrMatch(my_bitset_2.data_handle(), + my_bitset_3.data_handle(), + my_bitset.n_elements(), + raft::Compare())); + my_bitset_2 ^= my_bitset; + ASSERT_TRUE(devArrMatch(my_bitset_2.data_handle(), + my_bitset_3.data_handle(), + my_bitset.n_elements(), + raft::Compare())); } };