Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add CAGRA-Q to ANN benchmarks #2233

Merged
merged 2 commits into from
Mar 21, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions cpp/bench/ann/src/raft/raft_ann_bench_param_parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ extern template class raft::bench::ann::RaftIvfPQ<int8_t, int64_t>;
#endif
#ifdef RAFT_ANN_BENCH_USE_RAFT_CAGRA
extern template class raft::bench::ann::RaftCagra<float, uint32_t>;
extern template class raft::bench::ann::RaftCagra<half, uint32_t>;
extern template class raft::bench::ann::RaftCagra<uint8_t, uint32_t>;
extern template class raft::bench::ann::RaftCagra<int8_t, uint32_t>;
#endif
Expand Down Expand Up @@ -149,6 +150,20 @@ void parse_build_param(const nlohmann::json& conf,
}
}

inline void parse_build_param(const nlohmann::json& conf, raft::neighbors::vpq_params& param)
{
if (conf.contains("pq_bits")) { param.pq_bits = conf.at("pq_bits"); }
if (conf.contains("pq_dim")) { param.pq_dim = conf.at("pq_dim"); }
if (conf.contains("vq_n_centers")) { param.vq_n_centers = conf.at("vq_n_centers"); }
if (conf.contains("kmeans_n_iters")) { param.kmeans_n_iters = conf.at("kmeans_n_iters"); }
if (conf.contains("vq_kmeans_trainset_fraction")) {
param.vq_kmeans_trainset_fraction = conf.at("vq_kmeans_trainset_fraction");
}
if (conf.contains("pq_kmeans_trainset_fraction")) {
param.pq_kmeans_trainset_fraction = conf.at("pq_kmeans_trainset_fraction");
}
}

nlohmann::json collect_conf_with_prefix(const nlohmann::json& conf,
const std::string& prefix,
bool remove_prefix = true)
Expand Down Expand Up @@ -204,6 +219,12 @@ void parse_build_param(const nlohmann::json& conf,
}
param.nn_descent_params = nn_param;
}
nlohmann::json comp_search_conf = collect_conf_with_prefix(conf, "compression_");
if (!comp_search_conf.empty()) {
raft::neighbors::vpq_params vpq_pams;
parse_build_param(ivf_pq_build_conf, vpq_pams);
param.cagra_params.compression.emplace(vpq_pams);
}
}

raft::bench::ann::AllocatorType parse_allocator(std::string mem_type)
Expand Down Expand Up @@ -248,5 +269,7 @@ void parse_search_param(const nlohmann::json& conf,
if (conf.contains("internal_dataset_memory_type")) {
param.dataset_mem = parse_allocator(conf.at("internal_dataset_memory_type"));
}
// Same ratio as in IVF-PQ
param.refine_ratio = conf.value("refine_ratio", 1.0f);
}
#endif
83 changes: 80 additions & 3 deletions cpp/bench/ann/src/raft/raft_cagra_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include <raft/neighbors/cagra.cuh>
#include <raft/neighbors/cagra_serialize.cuh>
#include <raft/neighbors/cagra_types.hpp>
#include <raft/neighbors/dataset.hpp>
#include <raft/neighbors/detail/cagra/cagra_build.cuh>
#include <raft/neighbors/ivf_pq_types.hpp>
#include <raft/neighbors/nn_descent_types.hpp>
Expand Down Expand Up @@ -56,6 +57,7 @@ class RaftCagra : public ANN<T>, public AnnGPU {

struct SearchParam : public AnnSearchParam {
raft::neighbors::experimental::cagra::search_params p;
float refine_ratio;
AllocatorType graph_mem = AllocatorType::Device;
AllocatorType dataset_mem = AllocatorType::Device;
auto needs_dataset() const -> bool override { return true; }
Expand Down Expand Up @@ -98,6 +100,8 @@ class RaftCagra : public ANN<T>, public AnnGPU {
// will be filled with (size_t)-1
void search(
const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const override;
void search_base(
const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const;

[[nodiscard]] auto get_sync_stream() const noexcept -> cudaStream_t override
{
Expand All @@ -124,6 +128,7 @@ class RaftCagra : public ANN<T>, public AnnGPU {
raft::mr::cuda_huge_page_resource mr_huge_page_;
AllocatorType graph_mem_;
AllocatorType dataset_mem_;
float refine_ratio_;
BuildParam index_params_;
bool need_dataset_update_;
raft::neighbors::cagra::search_params search_params_;
Expand Down Expand Up @@ -151,6 +156,9 @@ void RaftCagra<T, IdxT>::build(const T* dataset, size_t nrow)

auto& params = index_params_.cagra_params;

// Do include the compressed dataset for the CAGRA-Q
bool shall_include_dataset = params.compression.has_value();

index_ = std::make_shared<raft::neighbors::cagra::index<T, IdxT>>(
std::move(raft::neighbors::cagra::detail::build(handle_,
params,
Expand All @@ -159,7 +167,7 @@ void RaftCagra<T, IdxT>::build(const T* dataset, size_t nrow)
index_params_.ivf_pq_refine_rate,
index_params_.ivf_pq_build_params,
index_params_.ivf_pq_search_params,
false)));
shall_include_dataset)));
}

inline std::string allocator_to_string(AllocatorType mem_type)
Expand All @@ -179,6 +187,7 @@ void RaftCagra<T, IdxT>::set_search_param(const AnnSearchParam& param)
{
auto search_param = dynamic_cast<const SearchParam&>(param);
search_params_ = search_param.p;
refine_ratio_ = search_param.refine_ratio;
if (search_param.graph_mem != graph_mem_) {
// Move graph to correct memory space
graph_mem_ = search_param.graph_mem;
Expand Down Expand Up @@ -223,12 +232,16 @@ void RaftCagra<T, IdxT>::set_search_param(const AnnSearchParam& param)
template <typename T, typename IdxT>
void RaftCagra<T, IdxT>::set_search_dataset(const T* dataset, size_t nrow)
{
using ds_idx_type = decltype(index_->data().n_rows());
bool is_vpq =
dynamic_cast<const raft::neighbors::vpq_dataset<half, ds_idx_type>*>(&index_->data()) ||
dynamic_cast<const raft::neighbors::vpq_dataset<float, ds_idx_type>*>(&index_->data());
// It can happen that we are re-using a previous algo object which already has
// the dataset set. Check if we need update.
if (static_cast<size_t>(input_dataset_v_->extent(0)) != nrow ||
input_dataset_v_->data_handle() != dataset) {
*input_dataset_v_ = make_device_matrix_view<const T, int64_t>(dataset, nrow, this->dim_);
need_dataset_update_ = true;
need_dataset_update_ = !is_vpq; // ignore update if this is a VPQ dataset.
}
}

Expand Down Expand Up @@ -258,7 +271,7 @@ std::unique_ptr<ANN<T>> RaftCagra<T, IdxT>::copy()
}

template <typename T, typename IdxT>
void RaftCagra<T, IdxT>::search(
void RaftCagra<T, IdxT>::search_base(
const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const
{
IdxT* neighbors_IdxT;
Expand Down Expand Up @@ -286,4 +299,68 @@ void RaftCagra<T, IdxT>::search(
raft::resource::get_cuda_stream(handle_));
}
}

template <typename T, typename IdxT>
void RaftCagra<T, IdxT>::search(
const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const
{
auto k0 = static_cast<size_t>(refine_ratio_ * k);
const bool disable_refinement = k0 <= static_cast<size_t>(k);
const raft::resources& res = handle_;
auto stream = resource::get_cuda_stream(res);

if (disable_refinement) {
search_base(queries, batch_size, k, neighbors, distances);
} else {
auto candidate_ixs = raft::make_device_matrix<int64_t, int64_t>(res, batch_size, k0);
auto candidate_dists = raft::make_device_matrix<float, int64_t>(res, batch_size, k0);
search_base(queries,
batch_size,
k0,
reinterpret_cast<size_t*>(candidate_ixs.data_handle()),
candidate_dists.data_handle());

if (raft::get_device_for_address(input_dataset_v_->data_handle()) >= 0) {
auto queries_v =
raft::make_device_matrix_view<const T, int64_t>(queries, batch_size, dimension_);
auto neighours_v = raft::make_device_matrix_view<int64_t, int64_t>(
reinterpret_cast<int64_t*>(neighbors), batch_size, k);
auto distances_v = raft::make_device_matrix_view<float, int64_t>(distances, batch_size, k);
raft::neighbors::refine<int64_t, T, float, int64_t>(
res,
*input_dataset_v_,
queries_v,
raft::make_const_mdspan(candidate_ixs.view()),
neighours_v,
distances_v,
index_->metric());
} else {
auto dataset_host = raft::make_host_matrix_view<const T, int64_t>(
input_dataset_v_->data_handle(), input_dataset_v_->extent(0), input_dataset_v_->extent(1));
auto queries_host = raft::make_host_matrix<T, int64_t>(batch_size, dimension_);
auto candidates_host = raft::make_host_matrix<int64_t, int64_t>(batch_size, k0);
auto neighbors_host = raft::make_host_matrix<int64_t, int64_t>(batch_size, k);
auto distances_host = raft::make_host_matrix<float, int64_t>(batch_size, k);

raft::copy(queries_host.data_handle(), queries, queries_host.size(), stream);
raft::copy(
candidates_host.data_handle(), candidate_ixs.data_handle(), candidates_host.size(), stream);

raft::resource::sync_stream(res); // wait for the queries and candidates
raft::neighbors::refine<int64_t, T, float, int64_t>(res,
dataset_host,
queries_host.view(),
candidates_host.view(),
neighbors_host.view(),
distances_host.view(),
index_->metric());

raft::copy(neighbors,
reinterpret_cast<size_t*>(neighbors_host.data_handle()),
neighbors_host.size(),
stream);
raft::copy(distances, distances_host.data_handle(), distances_host.size(), stream);
}
}
}
} // namespace raft::bench::ann
Loading