Skip to content

Commit

Permalink
ANN_BENCH: common AnnBase::index_type (#2315)
Browse files Browse the repository at this point in the history
Replace the `size_t` type used for the output neighbor indices in `ANN::search` with a common `AnnBase::index_type`.
This PR stops short of changing the behavior of the benchmarks, since it keeps `using index_type = size_t`.

The introduction of the new type has a couple of benefits:
  - Makes the usage of `index_type` clearer in the code, distinguishing it from the extents type, which is usually `size_t` as well.
  - Makes it possible to quickly change the alias to `uint32_t` during development and experiments. This is needed to avoid calling extra `linalg::map` on the produced results when the algorithm output is not compatible with `size_t`.


As a small extra change, I've factored out the refinement code shared by IVF-PQ and CAGRA-Q into a separate `refine_helper` function.

Authors:
  - Artem M. Chirkin (https://github.com/achirkin)

Approvers:
  - Tamas Bela Feher (https://github.com/tfeher)

URL: #2315
  • Loading branch information
achirkin authored May 15, 2024
1 parent 6cc7134 commit eb1333d
Show file tree
Hide file tree
Showing 12 changed files with 246 additions and 164 deletions.
9 changes: 7 additions & 2 deletions cpp/bench/ann/src/common/ann_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ struct AlgoProperty {

class AnnBase {
public:
using index_type = size_t;

inline AnnBase(Metric metric, int dim) : metric_(metric), dim_(dim) {}
virtual ~AnnBase() noexcept = default;

Expand Down Expand Up @@ -127,8 +129,11 @@ class ANN : public AnnBase {
virtual void set_search_param(const AnnSearchParam& param) = 0;
// TODO: this assumes that an algorithm can always return k results.
// This is not always possible.
virtual void search(
const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const = 0;
virtual void search(const T* queries,
int batch_size,
int k,
AnnBase::index_type* neighbors,
float* distances) const = 0;

virtual void save(const std::string& file) const = 0;
virtual void load(const std::string& file) = 0;
Expand Down
2 changes: 1 addition & 1 deletion cpp/bench/ann/src/common/benchmark.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ void bench_search(::benchmark::State& state,
/**
* Each thread will manage its own outputs
*/
using index_type = size_t;
using index_type = AnnBase::index_type;
constexpr size_t kAlignResultBuf = 64;
size_t result_elem_count = k * query_set_size;
result_elem_count =
Expand Down
9 changes: 6 additions & 3 deletions cpp/bench/ann/src/faiss/faiss_cpu_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,11 @@ class FaissCpu : public ANN<T> {

// TODO: if the number of results is less than k, the remaining elements of 'neighbors'
// will be filled with (size_t)-1
void search(
const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const final;
void search(const T* queries,
int batch_size,
int k,
AnnBase::index_type* neighbors,
float* distances) const final;

AlgoProperty get_preference() const override
{
Expand Down Expand Up @@ -169,7 +172,7 @@ void FaissCpu<T>::set_search_param(const AnnSearchParam& param)

template <typename T>
void FaissCpu<T>::search(
const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const
const T* queries, int batch_size, int k, AnnBase::index_type* neighbors, float* distances) const
{
static_assert(sizeof(size_t) == sizeof(faiss::idx_t),
"sizes of size_t and faiss::idx_t are different");
Expand Down
9 changes: 6 additions & 3 deletions cpp/bench/ann/src/faiss/faiss_gpu_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,11 @@ class FaissGpu : public ANN<T>, public AnnGPU {

// TODO: if the number of results is less than k, the remaining elements of 'neighbors'
// will be filled with (size_t)-1
void search(
const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const final;
void search(const T* queries,
int batch_size,
int k,
AnnBase::index_type* neighbors,
float* distances) const final;

[[nodiscard]] auto get_sync_stream() const noexcept -> cudaStream_t override
{
Expand Down Expand Up @@ -196,7 +199,7 @@ void FaissGpu<T>::build(const T* dataset, size_t nrow)

template <typename T>
void FaissGpu<T>::search(
const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const
const T* queries, int batch_size, int k, AnnBase::index_type* neighbors, float* distances) const
{
static_assert(sizeof(size_t) == sizeof(faiss::idx_t),
"sizes of size_t and faiss::idx_t are different");
Expand Down
16 changes: 11 additions & 5 deletions cpp/bench/ann/src/ggnn/ggnn_wrapper.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,11 @@ class Ggnn : public ANN<T>, public AnnGPU {
void build(const T* dataset, size_t nrow) override { impl_->build(dataset, nrow); }

void set_search_param(const AnnSearchParam& param) override { impl_->set_search_param(param); }
void search(
const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const override
void search(const T* queries,
int batch_size,
int k,
AnnBase::index_type* neighbors,
float* distances) const override
{
impl_->search(queries, batch_size, k, neighbors, distances);
}
Expand Down Expand Up @@ -123,8 +126,11 @@ class GgnnImpl : public ANN<T>, public AnnGPU {
void build(const T* dataset, size_t nrow) override;

void set_search_param(const AnnSearchParam& param) override;
void search(
const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const override;
void search(const T* queries,
int batch_size,
int k,
AnnBase::index_type* neighbors,
float* distances) const override;
[[nodiscard]] auto get_sync_stream() const noexcept -> cudaStream_t override { return stream_; }

void save(const std::string& file) const override;
Expand Down Expand Up @@ -243,7 +249,7 @@ void GgnnImpl<T, measure, D, KBuild, KQuery, S>::set_search_param(const AnnSearc

template <typename T, DistanceMeasure measure, int D, int KBuild, int KQuery, int S>
void GgnnImpl<T, measure, D, KBuild, KQuery, S>::search(
const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const
const T* queries, int batch_size, int k, AnnBase::index_type* neighbors, float* distances) const
{
static_assert(sizeof(size_t) == sizeof(int64_t), "sizes of size_t and GGNN's KeyT are different");
if (k != KQuery) {
Expand Down
16 changes: 11 additions & 5 deletions cpp/bench/ann/src/hnswlib/hnswlib_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,11 @@ class HnswLib : public ANN<T> {
void build(const T* dataset, size_t nrow) override;

void set_search_param(const AnnSearchParam& param) override;
void search(
const T* query, int batch_size, int k, size_t* indices, float* distances) const override;
void search(const T* query,
int batch_size,
int k,
AnnBase::index_type* indices,
float* distances) const override;

void save(const std::string& path_to_index) const override;
void load(const std::string& path_to_index) override;
Expand All @@ -97,7 +100,10 @@ class HnswLib : public ANN<T> {
void set_base_layer_only() { appr_alg_->base_layer_only = true; }

private:
void get_search_knn_results_(const T* query, int k, size_t* indices, float* distances) const;
void get_search_knn_results_(const T* query,
int k,
AnnBase::index_type* indices,
float* distances) const;

std::shared_ptr<hnswlib::HierarchicalNSW<typename hnsw_dist_t<T>::type>> appr_alg_;
std::shared_ptr<hnswlib::SpaceInterface<typename hnsw_dist_t<T>::type>> space_;
Expand Down Expand Up @@ -176,7 +182,7 @@ void HnswLib<T>::set_search_param(const AnnSearchParam& param_)

template <typename T>
void HnswLib<T>::search(
const T* query, int batch_size, int k, size_t* indices, float* distances) const
const T* query, int batch_size, int k, AnnBase::index_type* indices, float* distances) const
{
auto f = [&](int i) {
// hnsw can only handle a single vector at a time.
Expand Down Expand Up @@ -217,7 +223,7 @@ void HnswLib<T>::load(const std::string& path_to_index)
template <typename T>
void HnswLib<T>::get_search_knn_results_(const T* query,
int k,
size_t* indices,
AnnBase::index_type* indices,
float* distances) const
{
auto result = appr_alg_->searchKnn(query, k);
Expand Down
72 changes: 72 additions & 0 deletions cpp/bench/ann/src/raft/raft_ann_bench_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,12 @@

#include <raft/core/device_mdspan.hpp>
#include <raft/core/device_resources.hpp>
#include <raft/core/host_mdarray.hpp>
#include <raft/core/host_mdspan.hpp>
#include <raft/core/logger.hpp>
#include <raft/core/operators.hpp>
#include <raft/distance/distance_types.hpp>
#include <raft/neighbors/refine.cuh>
#include <raft/util/cudart_utils.hpp>

#include <rmm/cuda_stream_view.hpp>
Expand Down Expand Up @@ -166,4 +169,73 @@ inline configured_raft_resources::configured_raft_resources(configured_raft_reso
inline configured_raft_resources& configured_raft_resources::operator=(
configured_raft_resources&&) = default;

/**
 * A helper to refine the neighbors when the data is on device or on host.
 *
 * The branch is selected by where `dataset` lives, as reported by
 * `raft::get_device_for_address`:
 *   - result >= 0: all inputs and outputs are wrapped into device matrix views and
 *     `raft::neighbors::refine` is called on them directly;
 *   - otherwise: the dataset is wrapped into a host matrix view, `queries` and `candidates`
 *     are copied into temporary host matrices, `raft::neighbors::refine` is called on the
 *     host views, and the results are copied into `neighbors` / `distances`.
 *
 * @tparam DatasetT    matrix-like type providing `value_type`, `data_handle()` and `extent(i)`
 * @tparam QueriesT    matrix-like type with the same `value_type` as `DatasetT`
 * @tparam CandidatesT matrix-like type whose `value_type` is `AnnBase::index_type`
 *
 * @param[in]  res        raft resources; its CUDA stream is used for the copies in the host branch
 * @param[in]  dataset    the dataset to refine against, shape [dataset.extent(0), dataset.extent(1)]
 * @param[in]  queries    the query vectors, shape [batch_size, dim]
 * @param[in]  candidates the candidate neighbor indices, shape [batch_size, k0]
 * @param[in]  k          the number of neighbors to output per query
 * @param[out] neighbors  the refined neighbor indices, [batch_size, k] elements
 * @param[out] distances  the refined distances, [batch_size, k] elements
 * @param[in]  metric     the distance type forwarded to `raft::neighbors::refine`
 *
 * NOTE(review): in the host branch the final copies into `neighbors` / `distances` are enqueued
 * on the stream of `res` and are not followed by a sync here; presumably the caller synchronizes
 * the stream before reading the results - confirm.
 */
template <typename DatasetT, typename QueriesT, typename CandidatesT>
void refine_helper(const raft::resources& res,
                   DatasetT dataset,
                   QueriesT queries,
                   CandidatesT candidates,
                   int k,
                   AnnBase::index_type* neighbors,
                   float* distances,
                   raft::distance::DistanceType metric)
{
  using data_type    = typename DatasetT::value_type;
  using index_type   = AnnBase::index_type;
  using extents_type = index_type;  // device-side refine requires this

  // The queries must hold the same element type as the dataset,
  // and the candidates must hold the common benchmark index type.
  static_assert(std::is_same_v<data_type, typename QueriesT::value_type>);
  static_assert(std::is_same_v<index_type, typename CandidatesT::value_type>);

  extents_type batch_size = queries.extent(0);
  extents_type dim        = queries.extent(1);
  extents_type k0         = candidates.extent(1);

  if (raft::get_device_for_address(dataset.data_handle()) >= 0) {
    // The dataset is on a device: view everything as device matrices and refine in place.
    auto dataset_device = raft::make_device_matrix_view<const data_type, extents_type>(
      dataset.data_handle(), dataset.extent(0), dataset.extent(1));
    auto queries_device = raft::make_device_matrix_view<const data_type, extents_type>(
      queries.data_handle(), batch_size, dim);
    auto candidates_device = raft::make_device_matrix_view<const index_type, extents_type>(
      candidates.data_handle(), batch_size, k0);
    auto neighbors_device =
      raft::make_device_matrix_view<index_type, extents_type>(neighbors, batch_size, k);
    auto distances_device =
      raft::make_device_matrix_view<float, extents_type>(distances, batch_size, k);

    raft::neighbors::refine<index_type, data_type, float, extents_type>(res,
                                                                        dataset_device,
                                                                        queries_device,
                                                                        candidates_device,
                                                                        neighbors_device,
                                                                        distances_device,
                                                                        metric);
  } else {
    // The dataset is on the host: stage the other inputs and the outputs in host matrices.
    auto dataset_host = raft::make_host_matrix_view<const data_type, extents_type>(
      dataset.data_handle(), dataset.extent(0), dataset.extent(1));
    auto queries_host    = raft::make_host_matrix<data_type, extents_type>(batch_size, dim);
    auto candidates_host = raft::make_host_matrix<index_type, extents_type>(batch_size, k0);
    auto neighbors_host  = raft::make_host_matrix<index_type, extents_type>(batch_size, k);
    auto distances_host  = raft::make_host_matrix<float, extents_type>(batch_size, k);

    auto stream = raft::resource::get_cuda_stream(res);
    raft::copy(queries_host.data_handle(), queries.data_handle(), queries_host.size(), stream);
    raft::copy(
      candidates_host.data_handle(), candidates.data_handle(), candidates_host.size(), stream);

    raft::resource::sync_stream(res);  // wait for the queries and candidates
    raft::neighbors::refine<index_type, data_type, float, extents_type>(res,
                                                                        dataset_host,
                                                                        queries_host.view(),
                                                                        candidates_host.view(),
                                                                        neighbors_host.view(),
                                                                        distances_host.view(),
                                                                        metric);

    // Write the refined results back to the caller-provided output buffers.
    // NOTE(review): the temporary host matrices are destroyed when this scope ends, right after
    // these copies are enqueued; this assumes raft::copy has consumed the source buffers by the
    // time it returns - confirm.
    raft::copy(neighbors, neighbors_host.data_handle(), neighbors_host.size(), stream);
    raft::copy(distances, distances_host.data_handle(), distances_host.size(), stream);
  }
}

} // namespace raft::bench::ann
11 changes: 6 additions & 5 deletions cpp/bench/ann/src/raft/raft_cagra_hnswlib_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,11 @@ class RaftCagraHnswlib : public ANN<T>, public AnnGPU {

void set_search_param(const AnnSearchParam& param) override;

// TODO: if the number of results is less than k, the remaining elements of 'neighbors'
// will be filled with (size_t)-1
void search(
const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const override;
void search(const T* queries,
int batch_size,
int k,
AnnBase::index_type* neighbors,
float* distances) const override;

[[nodiscard]] auto get_sync_stream() const noexcept -> cudaStream_t override
{
Expand Down Expand Up @@ -99,7 +100,7 @@ void RaftCagraHnswlib<T, IdxT>::load(const std::string& file)

template <typename T, typename IdxT>
void RaftCagraHnswlib<T, IdxT>::search(
const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const
const T* queries, int batch_size, int k, AnnBase::index_type* neighbors, float* distances) const
{
hnswlib_search_.search(queries, batch_size, k, neighbors, distances);
}
Expand Down
Loading

0 comments on commit eb1333d

Please sign in to comment.