Skip to content

Commit

Permalink
[FEA] support of prefiltered brute force (#2294)
Browse files Browse the repository at this point in the history
- This PR is one part of the feature tracked in #1969.
- Adds the `search_with_filtering` API for brute force.
Authors:
  - James Rong (https://github.com/rhdong)

```shell
***WARNING*** CPU scaling is enabled, the benchmark real time measurements may be noisy and will incur extra overhead.
-----------------------------------------------------------------------------------------------------
Benchmark                                                           Time             CPU   Iterations
-----------------------------------------------------------------------------------------------------
KNN/float/int64_t/brute_force_filter_knn/0/0/0/manual_time       33.1 ms         69.9 ms           21 1000000#128#1000#255#0#InnerProduct#NO_COPY#SEARCH
KNN/float/int64_t/brute_force_filter_knn/1/0/0/manual_time       38.0 ms         74.8 ms           18 1000000#128#1000#255#0#L2Expanded#NO_COPY#SEARCH
KNN/float/int64_t/brute_force_filter_knn/2/0/0/manual_time       41.7 ms         78.5 ms           17 1000000#128#1000#255#0.8#InnerProduct#NO_COPY#SEARCH
KNN/float/int64_t/brute_force_filter_knn/3/0/0/manual_time       57.5 ms         94.3 ms           12 1000000#128#1000#255#0.8#L2Expanded#NO_COPY#SEARCH
KNN/float/int64_t/brute_force_filter_knn/4/0/0/manual_time       19.7 ms         56.4 ms           35 1000000#128#1000#255#0.9#InnerProduct#NO_COPY#SEARCH
KNN/float/int64_t/brute_force_filter_knn/5/0/0/manual_time       26.1 ms         62.8 ms           27 1000000#128#1000#255#0.9#L2Expanded#NO_COPY#SEARCH
```

Authors:
  - rhdong (https://github.com/rhdong)
  - Artem M. Chirkin (https://github.com/achirkin)
  - Corey J. Nolet (https://github.com/cjnolet)

Approvers:
  - Robert Maynard (https://github.com/robertmaynard)
  - Corey J. Nolet (https://github.com/cjnolet)
  - Divye Gala (https://github.com/divyegala)

URL: #2294
  • Loading branch information
rhdong authored May 24, 2024
1 parent 5c6cd92 commit 5f0dfed
Show file tree
Hide file tree
Showing 18 changed files with 388 additions and 365 deletions.
116 changes: 17 additions & 99 deletions cpp/include/raft/core/bitmap.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -16,112 +16,30 @@

#pragma once

#include <raft/core/bitmap.hpp>
#include <raft/core/bitset.cuh>
#include <raft/core/detail/mdspan_util.cuh>
#include <raft/core/device_container_policy.hpp>
#include <raft/core/device_mdarray.hpp>
#include <raft/core/resources.hpp>

namespace raft::core {
/**
* @defgroup bitmap Bitmap
* @{
*/
/**
* @brief View of a RAFT Bitmap.
*
* A lightweight structure that represents and manipulates a two-dimensional bitmap matrix view
* in row-major order: element (row, col) is addressed as bit `row * cols + col` of the
* underlying bitset_view. Each matrix element is represented as a single bit.
*
* @tparam bitmap_t Underlying type of the bitmap array. Default is uint32_t.
* @tparam index_t Indexing type used. Default is uint32_t.
*/
template <typename bitmap_t = uint32_t, typename index_t = uint32_t>
struct bitmap_view : public bitset_view<bitmap_t, index_t> {
static_assert((std::is_same<bitmap_t, uint32_t>::value ||
std::is_same<bitmap_t, uint64_t>::value),
"The bitmap_t must be uint32_t or uint64_t.");
/**
* @brief Create a bitmap view from a device raw pointer.
*
* The underlying bitset_view is created with a length of `rows * cols` bits.
*
* @param bitmap_ptr Device raw pointer
* @param rows Number of rows in the matrix.
* @param cols Number of columns in the matrix.
*/
_RAFT_HOST_DEVICE bitmap_view(bitmap_t* bitmap_ptr, index_t rows, index_t cols)
: bitset_view<bitmap_t, index_t>(bitmap_ptr, rows * cols), rows_(rows), cols_(cols)
{
}

/**
* @brief Create a bitmap view from a device vector view of the bitset.
*
* The underlying bitset_view is created with a length of `rows * cols` bits.
*
* @param bitmap_span Device vector view of the bitmap
* @param rows Number of rows in the matrix.
* @param cols Number of columns in the matrix.
*/
_RAFT_HOST_DEVICE bitmap_view(raft::device_vector_view<bitmap_t, index_t> bitmap_span,
index_t rows,
index_t cols)
: bitset_view<bitmap_t, index_t>(bitmap_span, rows * cols), rows_(rows), cols_(cols)
{
}
// NOTE(review): this include sits inside the struct body; it looks like a line from a
// different revision of the file interleaved by the diff view — confirm.
#include <type_traits>

private:
// Hide the constructors of bitset_view.
// These (pointer, length) / (span, length) overloads only forward to bitset_view and do not
// initialize rows_ or cols_.
_RAFT_HOST_DEVICE bitmap_view(bitmap_t* bitmap_ptr, index_t bitmap_len)
: bitset_view<bitmap_t, index_t>(bitmap_ptr, bitmap_len)
{
}

_RAFT_HOST_DEVICE bitmap_view(raft::device_vector_view<bitmap_t, index_t> bitmap_span,
index_t bitmap_len)
: bitset_view<bitmap_t, index_t>(bitmap_span, bitmap_len)
{
}

public:
/**
* @brief Device function to test if a given row and col are set in the bitmap.
*
* @param row Row index of the bit to test
* @param col Col index of the bit to test
* @return bool True if the bit at (row, col) has not been unset in the bitset
*/
inline _RAFT_DEVICE auto test(const index_t row, const index_t col) const -> bool
{
// Row-major flattening of (row, col) into a single bit index.
// NOTE(review): the unqualified one-argument `test(...)` is looked up in this class first,
// where only the two-argument overload is declared; the base-class overload may need
// explicit qualification — confirm.
return test(row * cols_ + col);
}

/**
* @brief Device function to set a given row and col to new_value in the bitset.
*
* @param row Row index of the bit to set
* @param col Col index of the bit to set
* @param new_value Value to set the bit to (true or false)
*/
inline _RAFT_DEVICE void set(const index_t row, const index_t col, bool new_value) const
{
// Row-major flattening of (row, col) into a single bit index.
// NOTE(review): `&new_value` passes the address of the flag rather than the flag itself,
// and the unqualified `set(...)` resolves against this class's three-argument overload
// first — confirm this is intended.
set(row * cols_ + col, &new_value);
}

/**
* @brief Get the total number of rows
* @return index_t The total number of rows
*/
inline _RAFT_HOST_DEVICE index_t get_n_rows() const { return rows_; }

/**
* @brief Get the total number of columns
* @return index_t The total number of columns
*/
inline _RAFT_HOST_DEVICE index_t get_n_cols() const { return cols_; }
// NOTE(review): a namespace opening inside the struct body; presumably another line
// interleaved from a different revision by the diff view — confirm.
namespace raft::core {

private:
// Matrix dimensions supplied to the public constructors.
index_t rows_;
index_t cols_;
};
/**
 * @brief Test whether the bit at (row, col) is set in the bitmap.
 *
 * The matrix is addressed in row-major order, so the element maps to bit
 * `row * cols_ + col` of the underlying bitset_view.
 *
 * @param row Row index of the bit to test
 * @param col Col index of the bit to test
 * @return bool Result of the base-class single-index test for the flattened position
 */
template <typename bitmap_t, typename index_t>
_RAFT_HOST_DEVICE inline bool bitmap_view<bitmap_t, index_t>::test(const index_t row,
                                                                   const index_t col) const
{
  // The two-argument overload declared in bitmap_view hides the inherited single-index
  // test(), and a dependent base class is not searched by unqualified lookup, so the
  // base-class member has to be named explicitly for the call to resolve.
  return bitset_view<bitmap_t, index_t>::test(row * cols_ + col);
}

/**
 * @brief Set the bit at (row, col) to new_value.
 *
 * The matrix is addressed in row-major order, so the element maps to bit
 * `row * cols_ + col` of the underlying bitset_view.
 *
 * @param row Row index of the bit to set
 * @param col Col index of the bit to set
 * @param new_value Value to set the bit to (true or false)
 */
template <typename bitmap_t, typename index_t>
_RAFT_HOST_DEVICE void bitmap_view<bitmap_t, index_t>::set(const index_t row,
                                                           const index_t col,
                                                           bool new_value) const
{
  // Name the base-class member explicitly: the three-argument overload declared in
  // bitmap_view hides the inherited set(), and a dependent base class is not searched by
  // unqualified lookup.
  // Forward the flag by value. Passing its address would hand over a non-null pointer,
  // which converts to `true` regardless of the requested value.
  // NOTE(review): assumes the base overload is set(index, bool) — confirm against bitset.hpp.
  bitset_view<bitmap_t, index_t>::set(row * cols_ + col, new_value);
}

/** @} */
} // end namespace raft::core
123 changes: 123 additions & 0 deletions cpp/include/raft/core/bitmap.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <raft/core/bitset.hpp>
#include <raft/core/detail/mdspan_util.cuh>
#include <raft/core/device_container_policy.hpp>
#include <raft/core/device_mdarray.hpp>
#include <raft/core/resources.hpp>

#include <type_traits>

namespace raft::core {
/**
* @defgroup bitmap Bitmap
* @{
*/
/**
* @brief View of a RAFT Bitmap.
*
* A lightweight structure that represents and manipulates a two-dimensional bitmap matrix view
* in row-major order. Each matrix element is represented as a single bit of the underlying
* bitset_view, which is created with a length of `rows * cols` bits.
*
* @tparam bitmap_t Underlying type of the bitmap array. Default is uint32_t.
* @tparam index_t Indexing type used. Default is uint32_t.
*/
template <typename bitmap_t = uint32_t, typename index_t = uint32_t>
struct bitmap_view : public bitset_view<bitmap_t, index_t> {
// const-qualification is stripped before the check, so `const uint32_t` and
// `const uint64_t` are accepted as well.
static_assert((std::is_same<typename std::remove_const<bitmap_t>::type, uint32_t>::value ||
std::is_same<typename std::remove_const<bitmap_t>::type, uint64_t>::value),
"The bitmap_t must be uint32_t or uint64_t.");
/**
* @brief Create a bitmap view from a device raw pointer.
*
* @param bitmap_ptr Device raw pointer
* @param rows Number of rows in the matrix.
* @param cols Number of columns in the matrix.
*/
_RAFT_HOST_DEVICE bitmap_view(bitmap_t* bitmap_ptr, index_t rows, index_t cols)
: bitset_view<bitmap_t, index_t>(bitmap_ptr, rows * cols), rows_(rows), cols_(cols)
{
}

/**
* @brief Create a bitmap view from a device vector view of the bitset.
*
* @param bitmap_span Device vector view of the bitmap
* @param rows Number of rows in the matrix.
* @param cols Number of columns in the matrix.
*/
_RAFT_HOST_DEVICE bitmap_view(raft::device_vector_view<bitmap_t, index_t> bitmap_span,
index_t rows,
index_t cols)
: bitset_view<bitmap_t, index_t>(bitmap_span, rows * cols), rows_(rows), cols_(cols)
{
}

private:
// Hide the constructors of bitset_view.
// These (pointer, length) / (span, length) overloads only forward to bitset_view and do not
// initialize rows_ or cols_.
_RAFT_HOST_DEVICE bitmap_view(bitmap_t* bitmap_ptr, index_t bitmap_len)
: bitset_view<bitmap_t, index_t>(bitmap_ptr, bitmap_len)
{
}

_RAFT_HOST_DEVICE bitmap_view(raft::device_vector_view<bitmap_t, index_t> bitmap_span,
index_t bitmap_len)
: bitset_view<bitmap_t, index_t>(bitmap_span, bitmap_len)
{
}

public:
/**
* @brief Test if a given row and col are set in the bitmap.
*
* Only declared here; the definition is provided in raft/core/bitmap.cuh.
*
* @param row Row index of the bit to test
* @param col Col index of the bit to test
* @return bool True if the bit at (row, col) has not been unset in the bitset
*/
inline _RAFT_HOST_DEVICE bool test(const index_t row, const index_t col) const;

/**
* @brief Set a given row and col to new_value in the bitset.
*
* Only declared here; the definition is provided in raft/core/bitmap.cuh.
*
* @param row Row index of the bit to set
* @param col Col index of the bit to set
* @param new_value Value to set the bit to (true or false)
*/
inline _RAFT_HOST_DEVICE void set(const index_t row, const index_t col, bool new_value) const;

/**
* @brief Get the total number of rows
* @return index_t The total number of rows
*/
inline _RAFT_HOST_DEVICE index_t get_n_rows() const { return rows_; }

/**
* @brief Get the total number of columns
* @return index_t The total number of columns
*/
inline _RAFT_HOST_DEVICE index_t get_n_cols() const { return cols_; }

private:
// Matrix dimensions supplied to the public constructors.
index_t rows_;
index_t cols_;
};

/** @} */
} // end namespace raft::core
42 changes: 10 additions & 32 deletions cpp/include/raft/core/bitset.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
#pragma once

#include <raft/core/bitset.hpp>
#include <raft/core/detail/mdspan_util.cuh> // native_popc
#include <raft/core/detail/popc.cuh>
#include <raft/core/device_container_policy.hpp>
#include <raft/core/device_mdarray.hpp>
#include <raft/core/resource/thrust_policy.hpp>
Expand Down Expand Up @@ -60,6 +60,12 @@ _RAFT_HOST_DEVICE void bitset_view<bitset_t, index_t>::set(const index_t sample_
}
}

/**
 * @brief Number of storage elements covered by the view.
 *
 * @return index_t bitset_len_ divided by bitset_element_size, rounded up.
 */
template <typename bitset_t, typename index_t>
_RAFT_HOST_DEVICE inline index_t bitset_view<bitset_t, index_t>::n_elements() const
{
  // A partially used trailing element still occupies a full slot, hence the ceiling division.
  const index_t element_count = raft::ceildiv(bitset_len_, bitset_element_size);
  return element_count;
}

template <typename bitset_t, typename index_t>
bitset<bitset_t, index_t>::bitset(const raft::resources& res,
raft::device_vector_view<const index_t, index_t> mask_index,
Expand Down Expand Up @@ -161,37 +167,9 @@ template <typename bitset_t, typename index_t>
void bitset<bitset_t, index_t>::count(const raft::resources& res,
                                      raft::device_scalar_view<index_t> count_gpu_scalar)
{
  // Counts the bits set to 1 in this bitset and writes the total to count_gpu_scalar.
  //
  // The whole computation is delegated to raft::detail::popc, which receives the backing
  // storage as a read-only vector of n_elements() items together with the logical length
  // bitset_len_. Running a separate reduction here as well would only compute the same
  // total a second time into the same output.
  auto values =
    raft::make_device_vector_view<const bitset_t, index_t>(bitset_.data(), n_elements());
  raft::detail::popc(res, values, bitset_len_, count_gpu_scalar);
}

} // end namespace raft::core
75 changes: 75 additions & 0 deletions cpp/include/raft/core/detail/popc.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include <raft/core/detail/mdspan_util.cuh>
#include <raft/core/device_mdarray.hpp>
#include <raft/core/resources.hpp>
#include <raft/linalg/coalesced_reduction.cuh>

namespace raft::detail {

/**
 * @brief Count the number of bits that are set to 1 in a vector.
 *
 * Every element of `values` contributes its population count. The last element is masked
 * first, so that only its lowest `max_len % (sizeof(value_t) * 8)` bits are counted (all of
 * its bits when that remainder is zero).
 *
 * @tparam value_t the value type of the vector.
 * @tparam index_t the index type of vector and scalar.
 *
 * @param[in] res raft handle for managing expensive resources
 * @param[in] values Device vector whose set bits are counted.
 * @param[in] max_len Maximum number of bits to count.
 * @param[out] counter Number of bits that are set to 1.
 */
template <typename value_t, typename index_t>
void popc(const raft::resources& res,
          device_vector_view<value_t, index_t> values,
          index_t max_len,
          raft::device_scalar_view<index_t> counter)
{
  static constexpr index_t len_per_item = sizeof(value_t) * 8;

  auto n_items = values.size();

  // Mask applied to the last item only: keeps its lowest `max_len % len_per_item` bits, or
  // every bit when max_len is a multiple of the item width.
  value_t n_tail_bits = (max_len % len_per_item);
  value_t last_item_mask =
    n_tail_bits ? (value_t)((value_t{1} << n_tail_bits) - value_t{1}) : ~value_t{0};

  // The reduction is handed the input as an (n_items x 1) column-major matrix view and the
  // output as a vector view of length 1.
  auto items_as_column = raft::make_device_matrix_view<value_t, index_t, col_major>(
    values.data_handle(), n_items, 1);
  auto total = raft::make_device_vector_view<index_t, index_t>(counter.data_handle(), 1);

  raft::linalg::coalesced_reduction(
    res,
    items_as_column,
    total,
    index_t{0},
    false,
    [last_item_mask, n_items] __device__(value_t value, index_t index) {
      const bool is_last = (index == n_items - 1);
      index_t n_set      = 0;
      if constexpr (len_per_item == 64) {
        n_set = index_t(raft::detail::popc(is_last ? (value & last_item_mask) : value));
      } else {  // Needed because popc is not overloaded for 16 and 8 bit elements
        const uint32_t widened = uint32_t{value};
        n_set = index_t(raft::detail::popc(is_last ? (widened & last_item_mask) : widened));
      }
      return n_set;
    });
}

} // end namespace raft::detail
Loading

0 comments on commit 5f0dfed

Please sign in to comment.