Merge branch 'branch-23.12' into faiss-ivf

cjnolet authored Nov 7, 2023
2 parents 6a5443a + c7aa826 commit bca8f40
Showing 42 changed files with 1,256 additions and 143 deletions.
4 changes: 3 additions & 1 deletion build.sh
@@ -78,7 +78,7 @@ INSTALL_TARGET=install
 BUILD_REPORT_METRICS=""
 BUILD_REPORT_INCL_CACHE_STATS=OFF
 
-TEST_TARGETS="CLUSTER_TEST;CORE_TEST;DISTANCE_TEST;LABEL_TEST;LINALG_TEST;MATRIX_TEST;NEIGHBORS_TEST;NEIGHBORS_ANN_CAGRA_TEST;NEIGHBORS_ANN_NN_DESCENT_TEST;RANDOM_TEST;SOLVERS_TEST;SPARSE_TEST;SPARSE_DIST_TEST;SPARSE_NEIGHBORS_TEST;STATS_TEST;UTILS_TEST"
+TEST_TARGETS="CLUSTER_TEST;CORE_TEST;DISTANCE_TEST;LABEL_TEST;LINALG_TEST;MATRIX_TEST;NEIGHBORS_TEST;NEIGHBORS_ANN_CAGRA_TEST;NEIGHBORS_ANN_NN_DESCENT_TEST;NEIGHBORS_ANN_IVF_TEST;RANDOM_TEST;SOLVERS_TEST;SPARSE_TEST;SPARSE_DIST_TEST;SPARSE_NEIGHBORS_TEST;STATS_TEST;UTILS_TEST"
 BENCH_TARGETS="CLUSTER_BENCH;CORE_BENCH;NEIGHBORS_BENCH;DISTANCE_BENCH;LINALG_BENCH;MATRIX_BENCH;SPARSE_BENCH;RANDOM_BENCH"
 
 CACHE_ARGS=""
@@ -324,6 +324,8 @@ if hasArg tests || (( ${NUMARGS} == 0 )); then
        $CMAKE_TARGET == *"DISTANCE_TEST"* || \
        $CMAKE_TARGET == *"MATRIX_TEST"* || \
        $CMAKE_TARGET == *"NEIGHBORS_ANN_CAGRA_TEST"* || \
+       $CMAKE_TARGET == *"NEIGHBORS_ANN_IVF_TEST"* || \
+       $CMAKE_TARGET == *"NEIGHBORS_ANN_NN_DESCENT_TEST"* || \
        $CMAKE_TARGET == *"NEIGHBORS_TEST"* || \
        $CMAKE_TARGET == *"SPARSE_DIST_TEST" || \
        $CMAKE_TARGET == *"SPARSE_NEIGHBORS_TEST"* || \
20 changes: 16 additions & 4 deletions cpp/bench/ann/src/common/benchmark.hpp
@@ -23,6 +23,7 @@
 #include <benchmark/benchmark.h>
 
 #include <algorithm>
+#include <atomic>
 #include <chrono>
 #include <cmath>
 #include <condition_variable>
@@ -39,6 +40,7 @@ namespace raft::bench::ann {
 
 std::mutex init_mutex;
 std::condition_variable cond_var;
+std::atomic_int processed_threads{0};
 
 static inline std::unique_ptr<AnnBase> current_algo{nullptr};
 static inline std::shared_ptr<AlgoProperty> current_algo_props{nullptr};
@@ -198,7 +200,8 @@ void bench_search(::benchmark::State& state,
   * Make sure the first thread loads the algo and dataset
   */
  if (state.thread_index() == 0) {
-   std::lock_guard lk(init_mutex);
+   std::unique_lock lk(init_mutex);
+   cond_var.wait(lk, [] { return processed_threads.load(std::memory_order_acquire) == 0; });
    // algo is static to cache it between close search runs to save time on index loading
    static std::string index_file = "";
    if (index.file != index_file) {
@@ -247,11 +250,14 @@ void bench_search(::benchmark::State& state,
    }
 
    query_set = dataset->query_set(current_algo_props->query_memory_type);
+   processed_threads.store(state.threads(), std::memory_order_acq_rel);
    cond_var.notify_all();
  } else {
-   // All other threads will wait for the first thread to initialize the algo.
    std::unique_lock lk(init_mutex);
-   cond_var.wait(lk, [] { return current_algo_props.get() != nullptr; });
+   // All other threads will wait for the first thread to initialize the algo.
+   cond_var.wait(lk, [&state] {
+     return processed_threads.load(std::memory_order_acquire) == state.threads();
+   });
    // gbench ensures that all threads are synchronized at the start of the benchmark loop.
    // We are accessing shared variables (like current_algo, current_algo_props) before the
    // benchmark loop, therefore the synchronization here is necessary.
@@ -292,6 +298,7 @@ void bench_search(::benchmark::State& state,
 
      // advance to the next batch
      batch_offset = (batch_offset + n_queries) % query_set_size;
+
      queries_processed += n_queries;
    }
  }
@@ -312,6 +319,10 @@ void bench_search(::benchmark::State& state,
 
  if (state.skipped()) { return; }
 
+ // assume thread has finished processing successfully at this point
+ // last thread to finish processing notifies all
+ if (processed_threads-- == 0) { cond_var.notify_all(); }
+
  // Use the last thread as a sanity check that all the threads are working.
  if (state.thread_index() == state.threads() - 1) {
    // evaluate recall
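Taken together, the hunks above form one handshake: thread 0 waits for the counter to drain to zero before re-initializing shared state, publishes the thread count once setup is done, and every thread decrements the counter on its way out so the next run can start clean. Below is a minimal, self-contained sketch of that pattern, assuming plain std::thread in place of gbench's thread management and using fetch_sub to detect the last finisher; it is an illustration, not the benchmark's actual code:

#include <atomic>
#include <condition_variable>
#include <cstdio>
#include <mutex>
#include <thread>
#include <vector>

std::mutex init_mutex;
std::condition_variable cond_var;
std::atomic<int> processed_threads{0};

void run(int thread_index, int n_threads)
{
  if (thread_index == 0) {
    std::unique_lock lk(init_mutex);
    // Wait until the previous run has fully drained before touching shared state.
    cond_var.wait(lk, [] { return processed_threads.load(std::memory_order_acquire) == 0; });
    std::printf("thread 0: one-time setup (load index, pick query set)\n");
    processed_threads.store(n_threads, std::memory_order_release);
    cond_var.notify_all();
  } else {
    std::unique_lock lk(init_mutex);
    // Everyone else blocks until thread 0 has published the shared state.
    cond_var.wait(lk, [n_threads] {
      return processed_threads.load(std::memory_order_acquire) == n_threads;
    });
  }

  std::printf("thread %d: benchmark loop\n", thread_index);

  // The last thread to finish brings the counter back to zero and wakes
  // thread 0 of the next run.
  if (processed_threads.fetch_sub(1, std::memory_order_acq_rel) == 1) { cond_var.notify_all(); }
}

int main()
{
  const int n_threads = 4;
  std::vector<std::thread> threads;
  for (int i = 0; i < n_threads; ++i) { threads.emplace_back(run, i, n_threads); }
  for (auto& t : threads) { t.join(); }
}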
@@ -410,7 +421,6 @@ void register_search(std::shared_ptr<const Dataset<T>> dataset,
      auto* b = ::benchmark::RegisterBenchmark(
                  index.name + suf, bench_search<T>, index, i, dataset, metric_objective)
                  ->Unit(benchmark::kMillisecond)
-                 ->ThreadRange(threads[0], threads[1])
      /**
       * The following are important for getting accuracy QPS measurements on both CPU
       * and GPU. These make sure that
@@ -420,6 +430,8 @@
       */
                  ->MeasureProcessCPUTime()
                  ->UseRealTime();
+
+     if (metric_objective == Objective::THROUGHPUT) { b->ThreadRange(threads[0], threads[1]); }
    }
  }
 }
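The last two hunks make the thread fan-out conditional: MeasureProcessCPUTime and UseRealTime always apply, but ThreadRange is attached only for throughput runs, so latency runs stay single-threaded. A standalone google-benchmark sketch of that registration pattern follows (stub benchmark body and an assumed 1..8 thread range, not the harness's actual values):

#include <benchmark/benchmark.h>

enum class Objective { LATENCY, THROUGHPUT };

// Stub standing in for bench_search: each iteration does trivial work.
static void bench_search_stub(::benchmark::State& state)
{
  for (auto _ : state) {
    ::benchmark::DoNotOptimize(state.thread_index());
  }
}

int main(int argc, char** argv)
{
  ::benchmark::Initialize(&argc, argv);

  const Objective metric_objective = Objective::THROUGHPUT;  // assumed for the demo
  auto* b = ::benchmark::RegisterBenchmark("search/stub", bench_search_stub)
              ->Unit(benchmark::kMillisecond)
              ->MeasureProcessCPUTime()  // aggregate CPU time across all threads
              ->UseRealTime();           // report wall-clock rather than per-thread CPU time
  // Latency mode keeps a single thread; throughput mode sweeps the thread range.
  if (metric_objective == Objective::THROUGHPUT) { b->ThreadRange(1, 8); }

  ::benchmark::RunSpecifiedBenchmarks();
  return 0;
}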
13 changes: 5 additions & 8 deletions cpp/bench/ann/src/hnswlib/hnswlib_wrapper.h
@@ -147,7 +147,6 @@ void HnswLib<T>::build(const T* dataset, size_t nrow, cudaStream_t)
      char buf[20];
      std::time_t now = std::time(nullptr);
      std::strftime(buf, sizeof(buf), "%Y-%m-%d %H:%M:%S", std::localtime(&now));
-
      printf("%s building %zu / %zu\n", buf, i, items_per_thread);
      fflush(stdout);
    }
@@ -163,13 +162,11 @@ void HnswLib<T>::set_search_param(const AnnSearchParam& param_)
  auto param = dynamic_cast<const SearchParam&>(param_);
  appr_alg_->ef_ = param.ef;
  metric_objective_ = param.metric_objective;
+ num_threads_ = param.num_threads;
 
- bool use_pool = (metric_objective_ == Objective::LATENCY && param.num_threads > 1) &&
-                 (!thread_pool_ || num_threads_ != param.num_threads);
- if (use_pool) {
-   num_threads_ = param.num_threads;
-   thread_pool_ = std::make_unique<FixedThreadPool>(num_threads_);
- }
+ // Create a pool if multiple query threads have been set and the pool hasn't been created already
+ bool create_pool = (metric_objective_ == Objective::LATENCY && num_threads_ > 1 && !thread_pool_);
+ if (create_pool) { thread_pool_ = std::make_unique<FixedThreadPool>(num_threads_); }
 }
 
 template <typename T>
@@ -180,7 +177,7 @@ void HnswLib<T>::search(
    // hnsw can only handle a single vector at a time.
    get_search_knn_results_(query + i * dim_, k, indices + i * k, distances + i * k);
  };
- if (metric_objective_ == Objective::LATENCY) {
+ if (metric_objective_ == Objective::LATENCY && num_threads_ > 1) {
    thread_pool_->submit(f, batch_size);
  } else {
    for (int i = 0; i < batch_size; i++) {
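Together these hunks make the wrapper record num_threads_ up front, create the pool only once for multi-threaded latency runs, and fall back to a serial loop otherwise. A rough standalone sketch of the dispatch decision follows, with std::async standing in for the repo's FixedThreadPool; the function names and batch contents are invented for the illustration:

#include <cstdio>
#include <future>
#include <vector>

enum class Objective { LATENCY, THROUGHPUT };

// Stand-in for one hnswlib query (hypothetical; hnsw handles one vector at a time).
void search_one(int query_id) { std::printf("query %d\n", query_id); }

void search_batch(Objective metric_objective, int num_threads, int batch_size)
{
  if (metric_objective == Objective::LATENCY && num_threads > 1) {
    // Latency mode with more than one thread: spread the batch across workers
    // so a single call returns as fast as possible.
    std::vector<std::future<void>> tasks;
    for (int i = 0; i < batch_size; i++) {
      tasks.emplace_back(std::async(std::launch::async, search_one, i));
    }
    for (auto& t : tasks) { t.get(); }
  } else {
    // Throughput mode (or a single thread): stay serial here and let the
    // benchmark run many such calls concurrently instead.
    for (int i = 0; i < batch_size; i++) {
      search_one(i);
    }
  }
}

int main() { search_batch(Objective::LATENCY, 4, 8); }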
77 changes: 67 additions & 10 deletions cpp/bench/ann/src/raft/raft_benchmark.cu
@@ -19,6 +19,7 @@
 #include <algorithm>
 #include <cmath>
 #include <memory>
+#include <raft/core/logger.hpp>
 #include <rmm/mr/device/pool_memory_resource.hpp>
 #include <stdexcept>
 #include <string>
@@ -35,8 +36,10 @@ extern template class raft::bench::ann::RaftIvfFlatGpu<float, int64_t>;
 extern template class raft::bench::ann::RaftIvfFlatGpu<uint8_t, int64_t>;
 extern template class raft::bench::ann::RaftIvfFlatGpu<int8_t, int64_t>;
 #endif
-#ifdef RAFT_ANN_BENCH_USE_RAFT_IVF_PQ
+#if defined(RAFT_ANN_BENCH_USE_RAFT_IVF_PQ) || defined(RAFT_ANN_BENCH_USE_RAFT_CAGRA)
 #include "raft_ivf_pq_wrapper.h"
+#endif
+#ifdef RAFT_ANN_BENCH_USE_RAFT_IVF_PQ
 extern template class raft::bench::ann::RaftIvfPQ<float, int64_t>;
 extern template class raft::bench::ann::RaftIvfPQ<uint8_t, int64_t>;
 extern template class raft::bench::ann::RaftIvfPQ<int8_t, int64_t>;
@@ -70,12 +73,12 @@ void parse_search_param(const nlohmann::json& conf,
 }
 #endif
 
-#ifdef RAFT_ANN_BENCH_USE_RAFT_IVF_PQ
+#if defined(RAFT_ANN_BENCH_USE_RAFT_IVF_PQ) || defined(RAFT_ANN_BENCH_USE_RAFT_CAGRA)
 template <typename T, typename IdxT>
 void parse_build_param(const nlohmann::json& conf,
                        typename raft::bench::ann::RaftIvfPQ<T, IdxT>::BuildParam& param)
 {
-  param.n_lists = conf.at("nlist");
+  if (conf.contains("nlist")) { param.n_lists = conf.at("nlist"); }
   if (conf.contains("niter")) { param.kmeans_n_iters = conf.at("niter"); }
   if (conf.contains("ratio")) { param.kmeans_trainset_fraction = 1.0 / (double)conf.at("ratio"); }
   if (conf.contains("pq_bits")) { param.pq_bits = conf.at("pq_bits"); }
@@ -97,7 +100,7 @@ template <typename T, typename IdxT>
 void parse_search_param(const nlohmann::json& conf,
                        typename raft::bench::ann::RaftIvfPQ<T, IdxT>::SearchParam& param)
 {
-  param.pq_param.n_probes = conf.at("nprobe");
+  if (conf.contains("nprobe")) { param.pq_param.n_probes = conf.at("nprobe"); }
   if (conf.contains("internalDistanceDtype")) {
     std::string type = conf.at("internalDistanceDtype");
     if (type == "float") {
@@ -137,25 +140,79 @@ void parse_search_param(const nlohmann::json& conf,
 #endif
 
 #ifdef RAFT_ANN_BENCH_USE_RAFT_CAGRA
+template <typename T, typename IdxT>
+void parse_build_param(const nlohmann::json& conf,
+                       raft::neighbors::experimental::nn_descent::index_params& param)
+{
+  if (conf.contains("graph_degree")) { param.graph_degree = conf.at("graph_degree"); }
+  if (conf.contains("intermediate_graph_degree")) {
+    param.intermediate_graph_degree = conf.at("intermediate_graph_degree");
+  }
+  // we allow niter shorthand for max_iterations
+  if (conf.contains("niter")) { param.max_iterations = conf.at("niter"); }
+  if (conf.contains("max_iterations")) { param.max_iterations = conf.at("max_iterations"); }
+  if (conf.contains("termination_threshold")) {
+    param.termination_threshold = conf.at("termination_threshold");
+  }
+}
+
+nlohmann::json collect_conf_with_prefix(const nlohmann::json& conf,
+                                        const std::string& prefix,
+                                        bool remove_prefix = true)
+{
+  nlohmann::json out;
+  for (auto& i : conf.items()) {
+    if (i.key().compare(0, prefix.size(), prefix) == 0) {
+      auto new_key = remove_prefix ? i.key().substr(prefix.size()) : i.key();
+      out[new_key] = i.value();
+    }
+  }
+  return out;
+}
+
 template <typename T, typename IdxT>
 void parse_build_param(const nlohmann::json& conf,
                        typename raft::bench::ann::RaftCagra<T, IdxT>::BuildParam& param)
 {
   if (conf.contains("graph_degree")) {
-    param.graph_degree = conf.at("graph_degree");
-    param.intermediate_graph_degree = param.graph_degree * 2;
+    param.cagra_params.graph_degree = conf.at("graph_degree");
+    param.cagra_params.intermediate_graph_degree = param.cagra_params.graph_degree * 2;
   }
   if (conf.contains("intermediate_graph_degree")) {
-    param.intermediate_graph_degree = conf.at("intermediate_graph_degree");
+    param.cagra_params.intermediate_graph_degree = conf.at("intermediate_graph_degree");
   }
   if (conf.contains("graph_build_algo")) {
     if (conf.at("graph_build_algo") == "IVF_PQ") {
-      param.build_algo = raft::neighbors::cagra::graph_build_algo::IVF_PQ;
+      param.cagra_params.build_algo = raft::neighbors::cagra::graph_build_algo::IVF_PQ;
     } else if (conf.at("graph_build_algo") == "NN_DESCENT") {
-      param.build_algo = raft::neighbors::cagra::graph_build_algo::NN_DESCENT;
+      param.cagra_params.build_algo = raft::neighbors::cagra::graph_build_algo::NN_DESCENT;
     }
   }
+  nlohmann::json ivf_pq_build_conf = collect_conf_with_prefix(conf, "ivf_pq_build_");
+  if (!ivf_pq_build_conf.empty()) {
+    raft::neighbors::ivf_pq::index_params bparam;
+    parse_build_param<T, IdxT>(ivf_pq_build_conf, bparam);
+    param.ivf_pq_build_params = bparam;
+  }
+  nlohmann::json ivf_pq_search_conf = collect_conf_with_prefix(conf, "ivf_pq_search_");
+  if (!ivf_pq_search_conf.empty()) {
+    typename raft::bench::ann::RaftIvfPQ<T, IdxT>::SearchParam sparam;
+    parse_search_param<T, IdxT>(ivf_pq_search_conf, sparam);
+    param.ivf_pq_search_params = sparam.pq_param;
+    param.ivf_pq_refine_rate = sparam.refine_ratio;
+  }
+  nlohmann::json nn_descent_conf = collect_conf_with_prefix(conf, "nn_descent_");
+  if (!nn_descent_conf.empty()) {
+    raft::neighbors::experimental::nn_descent::index_params nn_param;
+    nn_param.intermediate_graph_degree = 1.5 * param.cagra_params.intermediate_graph_degree;
+    parse_build_param<T, IdxT>(nn_descent_conf, nn_param);
+    if (nn_param.graph_degree != param.cagra_params.intermediate_graph_degree) {
+      RAFT_LOG_WARN(
+        "nn_descent_graph_degree has to be equal to CAGRA intermediate_graph_degree, overriding");
+      nn_param.graph_degree = param.cagra_params.intermediate_graph_degree;
+    }
+    param.nn_descent_params = nn_param;
+  }
+  if (conf.contains("nn_descent_niter")) { param.nn_descent_niter = conf.at("nn_descent_niter"); }
 }
 
 template <typename T, typename IdxT>
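The new collect_conf_with_prefix helper is what lets a single flat JSON config carry nested sub-configs: every key starting with ivf_pq_build_, ivf_pq_search_, or nn_descent_ is peeled off into its own object and handed to the matching parser. A small standalone demonstration follows; the helper body is copied from the hunk above, while the config values are invented:

#include <iostream>
#include <string>
#include <nlohmann/json.hpp>

nlohmann::json collect_conf_with_prefix(const nlohmann::json& conf,
                                        const std::string& prefix,
                                        bool remove_prefix = true)
{
  nlohmann::json out;
  for (auto& i : conf.items()) {
    if (i.key().compare(0, prefix.size(), prefix) == 0) {
      auto new_key = remove_prefix ? i.key().substr(prefix.size()) : i.key();
      out[new_key] = i.value();
    }
  }
  return out;
}

int main()
{
  nlohmann::json conf = {{"graph_degree", 64},
                         {"ivf_pq_build_nlist", 1024},
                         {"ivf_pq_build_niter", 20},
                         {"ivf_pq_search_nprobe", 30}};
  std::cout << collect_conf_with_prefix(conf, "ivf_pq_build_") << "\n";   // {"niter":20,"nlist":1024}
  std::cout << collect_conf_with_prefix(conf, "ivf_pq_search_") << "\n";  // {"nprobe":30}
}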
40 changes: 27 additions & 13 deletions cpp/bench/ann/src/raft/raft_cagra_wrapper.h
@@ -19,6 +19,7 @@
 #include <fstream>
 #include <iostream>
 #include <memory>
+#include <optional>
 #include <raft/core/device_mdspan.hpp>
 #include <raft/core/device_resources.hpp>
 #include <raft/core/logger.hpp>
@@ -28,6 +29,9 @@
 #include <raft/neighbors/cagra.cuh>
 #include <raft/neighbors/cagra_serialize.cuh>
 #include <raft/neighbors/cagra_types.hpp>
+#include <raft/neighbors/detail/cagra/cagra_build.cuh>
+#include <raft/neighbors/ivf_pq_types.hpp>
+#include <raft/neighbors/nn_descent_types.hpp>
 #include <raft/util/cudart_utils.hpp>
 #include <rmm/device_uvector.hpp>
 #include <stdexcept>
@@ -50,12 +54,20 @@ class RaftCagra : public ANN<T> {
    auto needs_dataset() const -> bool override { return true; }
  };
 
- using BuildParam = raft::neighbors::cagra::index_params;
+ struct BuildParam {
+   raft::neighbors::cagra::index_params cagra_params;
+   std::optional<raft::neighbors::experimental::nn_descent::index_params> nn_descent_params =
+     std::nullopt;
+   std::optional<float> ivf_pq_refine_rate = std::nullopt;
+   std::optional<raft::neighbors::ivf_pq::index_params> ivf_pq_build_params = std::nullopt;
+   std::optional<raft::neighbors::ivf_pq::search_params> ivf_pq_search_params = std::nullopt;
+ };
 
  RaftCagra(Metric metric, int dim, const BuildParam& param, int concurrent_searches = 1)
    : ANN<T>(metric, dim), index_params_(param), dimension_(dim), handle_(cudaStreamPerThread)
  {
-   index_params_.metric = parse_metric_type(metric);
+   index_params_.cagra_params.metric = parse_metric_type(metric);
+   index_params_.ivf_pq_build_params->metric = parse_metric_type(metric);
    RAFT_CUDA_TRY(cudaGetDevice(&device_));
  }
 
@@ -99,17 +111,19 @@ class RaftCagra : public ANN<T> {
 template <typename T, typename IdxT>
 void RaftCagra<T, IdxT>::build(const T* dataset, size_t nrow, cudaStream_t)
 {
- if (raft::get_device_for_address(dataset) == -1) {
-   auto dataset_view =
-     raft::make_host_matrix_view<const T, int64_t>(dataset, IdxT(nrow), dimension_);
-   index_.emplace(raft::neighbors::cagra::build(handle_, index_params_, dataset_view));
-   return;
- } else {
-   auto dataset_view =
-     raft::make_device_matrix_view<const T, int64_t>(dataset, IdxT(nrow), dimension_);
-   index_.emplace(raft::neighbors::cagra::build(handle_, index_params_, dataset_view));
-   return;
- }
+ auto dataset_view =
+   raft::make_host_matrix_view<const T, int64_t>(dataset, IdxT(nrow), dimension_);
+
+ auto& params = index_params_.cagra_params;
+
+ index_.emplace(raft::neighbors::cagra::detail::build(handle_,
+                                                      params,
+                                                      dataset_view,
+                                                      index_params_.nn_descent_params,
+                                                      index_params_.ivf_pq_refine_rate,
+                                                      index_params_.ivf_pq_build_params,
+                                                      index_params_.ivf_pq_search_params));
+ return;
 }
 
 template <typename T, typename IdxT>
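BuildParam is no longer a bare cagra::index_params but a bundle whose tuning sub-configs are std::optional, engaged only when the JSON supplied them. A trimmed-down sketch of the pattern follows, using hypothetical stand-in types; note the guarded dereference, since operator-> on a disengaged optional is undefined behavior:

#include <cstdio>
#include <optional>

// Hypothetical, trimmed-down stand-ins for the real parameter structs.
struct CagraParams { int graph_degree = 64; int metric = 0; };
struct IvfPqParams { int n_lists = 1024; int metric = 0; };

struct BuildParam {
  CagraParams cagra_params;
  std::optional<IvfPqParams> ivf_pq_build_params = std::nullopt;  // engaged only if configured
};

void apply_metric(BuildParam& p, int metric)
{
  p.cagra_params.metric = metric;
  // Only touch the sub-config if it was actually provided.
  if (p.ivf_pq_build_params) { p.ivf_pq_build_params->metric = metric; }
}

int main()
{
  BuildParam p;
  apply_metric(p, 1);  // no IVF-PQ sub-config: only the CAGRA params change

  p.ivf_pq_build_params = IvfPqParams{};  // now the sub-config participates
  apply_metric(p, 1);
  std::printf("n_lists=%d metric=%d\n", p.ivf_pq_build_params->n_lists, p.ivf_pq_build_params->metric);
}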
2 changes: 2 additions & 0 deletions cpp/bench/prims/CMakeLists.txt
@@ -147,10 +147,12 @@ if(BUILD_PRIMS_BENCH)
     bench/prims/neighbors/knn/brute_force_float_int64_t.cu
     bench/prims/neighbors/knn/brute_force_float_uint32_t.cu
     bench/prims/neighbors/knn/cagra_float_uint32_t.cu
+    bench/prims/neighbors/knn/ivf_flat_filter_float_int64_t.cu
     bench/prims/neighbors/knn/ivf_flat_float_int64_t.cu
     bench/prims/neighbors/knn/ivf_flat_int8_t_int64_t.cu
     bench/prims/neighbors/knn/ivf_flat_uint8_t_int64_t.cu
     bench/prims/neighbors/knn/ivf_pq_float_int64_t.cu
+    bench/prims/neighbors/knn/ivf_pq_filter_float_int64_t.cu
     bench/prims/neighbors/knn/ivf_pq_int8_t_int64_t.cu
     bench/prims/neighbors/knn/ivf_pq_uint8_t_int64_t.cu
     bench/prims/neighbors/refine_float_int64_t.cu