Skip to content

Commit

Permalink
Benchmark brute force knn (#2063)
Browse files Browse the repository at this point in the history
Add our bfknn code to the raft-ann-bench project

Authors:
  - Ben Frederickson (https://github.com/benfred)
  - Corey J. Nolet (https://github.com/cjnolet)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #2063
  • Loading branch information
benfred authored Dec 20, 2023
1 parent 2962169 commit 7e098b2
Show file tree
Hide file tree
Showing 10 changed files with 74 additions and 49 deletions.
5 changes: 2 additions & 3 deletions cpp/bench/ann/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,6 @@
option(RAFT_ANN_BENCH_USE_FAISS_GPU_FLAT "Include faiss' brute-force knn algorithm in benchmark" ON)
option(RAFT_ANN_BENCH_USE_FAISS_GPU_IVF_FLAT "Include faiss' ivf flat algorithm in benchmark" ON)
option(RAFT_ANN_BENCH_USE_FAISS_GPU_IVF_PQ "Include faiss' ivf pq algorithm in benchmark" ON)
option(RAFT_ANN_BENCH_USE_FAISS_CPU_FLAT
"Include faiss' cpu brute-force knn algorithm in benchmark" ON
)
option(RAFT_ANN_BENCH_USE_FAISS_CPU_FLAT "Include faiss' cpu brute-force algorithm in benchmark" ON)

option(RAFT_ANN_BENCH_USE_FAISS_CPU_IVF_FLAT "Include faiss' cpu ivf flat algorithm in benchmark"
Expand All @@ -30,6 +27,7 @@ option(RAFT_ANN_BENCH_USE_FAISS_CPU_IVF_PQ "Include faiss' cpu ivf pq algorithm
option(RAFT_ANN_BENCH_USE_RAFT_IVF_FLAT "Include raft's ivf flat algorithm in benchmark" ON)
option(RAFT_ANN_BENCH_USE_RAFT_IVF_PQ "Include raft's ivf pq algorithm in benchmark" ON)
option(RAFT_ANN_BENCH_USE_RAFT_CAGRA "Include raft's CAGRA in benchmark" ON)
option(RAFT_ANN_BENCH_USE_RAFT_BRUTE_FORCE "Include raft's brute force knn in benchmark" ON)
option(RAFT_ANN_BENCH_USE_RAFT_CAGRA_HNSWLIB "Include raft's CAGRA in benchmark" ON)
option(RAFT_ANN_BENCH_USE_HNSWLIB "Include hnsw algorithm in benchmark" ON)
option(RAFT_ANN_BENCH_USE_GGNN "Include ggnn algorithm in benchmark" ON)
Expand All @@ -55,6 +53,7 @@ if(BUILD_CPU_ONLY)
set(RAFT_ANN_BENCH_USE_RAFT_IVF_FLAT OFF)
set(RAFT_ANN_BENCH_USE_RAFT_IVF_PQ OFF)
set(RAFT_ANN_BENCH_USE_RAFT_CAGRA OFF)
set(RAFT_ANN_BENCH_USE_RAFT_BRUTE_FORCE OFF)
set(RAFT_ANN_BENCH_USE_RAFT_CAGRA_HNSWLIB OFF)
set(RAFT_ANN_BENCH_USE_GGNN OFF)
else()
Expand Down
2 changes: 1 addition & 1 deletion cpp/bench/ann/src/faiss/faiss_cpu_benchmark.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ std::unique_ptr<typename raft::bench::ann::ANN<T>::AnnSearchParam> create_search
parse_search_param<T>(conf, *param);
return param;
} else if (algo == "faiss_cpu_flat") {
auto param = std::make_unique<typename raft::bench::ann::ANN<T>::AnnSearchParam>();
auto param = std::make_unique<typename raft::bench::ann::FaissCpu<T>::SearchParam>();
return param;
}
// else
Expand Down
2 changes: 1 addition & 1 deletion cpp/bench/ann/src/faiss/faiss_gpu_benchmark.cu
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ std::unique_ptr<typename raft::bench::ann::ANN<T>::AnnSearchParam> create_search
parse_search_param<T>(conf, *param);
return param;
} else if (algo == "faiss_gpu_flat") {
auto param = std::make_unique<typename raft::bench::ann::ANN<T>::AnnSearchParam>();
auto param = std::make_unique<typename raft::bench::ann::FaissGpu<T>::SearchParam>();
return param;
}
// else
Expand Down
2 changes: 1 addition & 1 deletion cpp/bench/ann/src/raft/raft_ann_bench_param_parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
#include <nlohmann/json.hpp>

#undef WARP_SIZE
#ifdef RAFT_ANN_BENCH_USE_RAFT_BFKNN
#ifdef RAFT_ANN_BENCH_USE_RAFT_BRUTE_FORCE
#include "raft_wrapper.h"
#endif
#ifdef RAFT_ANN_BENCH_USE_RAFT_IVF_FLAT
Expand Down
9 changes: 6 additions & 3 deletions cpp/bench/ann/src/raft/raft_benchmark.cu
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include <cmath>
#include <memory>
#include <raft/core/logger.hpp>
#include <rmm/mr/device/per_device_resource.hpp>
#include <rmm/mr/device/pool_memory_resource.hpp>
#include <stdexcept>
#include <string>
Expand All @@ -47,8 +48,10 @@ std::unique_ptr<raft::bench::ann::ANN<T>> create_algo(const std::string& algo,
std::unique_ptr<raft::bench::ann::ANN<T>> ann;

if constexpr (std::is_same_v<T, float>) {
#ifdef RAFT_ANN_BENCH_USE_RAFT_BFKNN
if (algo == "raft_bfknn") { ann = std::make_unique<raft::bench::ann::RaftGpu<T>>(metric, dim); }
#ifdef RAFT_ANN_BENCH_USE_RAFT_BRUTE_FORCE
if (algo == "raft_brute_force") {
ann = std::make_unique<raft::bench::ann::RaftGpu<T>>(metric, dim);
}
#endif
}

Expand Down Expand Up @@ -85,7 +88,7 @@ template <typename T>
std::unique_ptr<typename raft::bench::ann::ANN<T>::AnnSearchParam> create_search_param(
const std::string& algo, const nlohmann::json& conf)
{
#ifdef RAFT_ANN_BENCH_USE_RAFT_BFKNN
#ifdef RAFT_ANN_BENCH_USE_RAFT_BRUTE_FORCE
if (algo == "raft_brute_force") {
auto param = std::make_unique<typename raft::bench::ann::ANN<T>::AnnSearchParam>();
return param;
Expand Down
84 changes: 46 additions & 38 deletions cpp/bench/ann/src/raft/raft_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,33 +17,33 @@

#include <cassert>
#include <memory>
#include <raft/core/device_resources.hpp>
#include <raft/distance/detail/distance.cuh>
#include <raft/distance/distance_types.hpp>
#include <raft/spatial/knn/detail/fused_l2_knn.cuh>
#include <raft/neighbors/brute_force.cuh>
#include <raft/neighbors/brute_force_serialize.cuh>
#include <stdexcept>
#include <string>
#include <type_traits>

#include "../common/ann_types.hpp"
#include "raft_ann_bench_utils.h"

namespace raft_temp {

inline raft::distance::DistanceType parse_metric_type(raft::bench::ann::Metric metric)
{
if (metric == raft::bench::ann::Metric::kInnerProduct) {
return raft::distance::DistanceType::InnerProduct;
} else if (metric == raft::bench::ann::Metric::kEuclidean) {
return raft::distance::DistanceType::L2Expanded;
} else {
throw std::runtime_error("raft supports only metric type of inner product and L2");
switch (metric) {
case raft::bench::ann::Metric::kInnerProduct: return raft::distance::DistanceType::InnerProduct;
case raft::bench::ann::Metric::kEuclidean: return raft::distance::DistanceType::L2Expanded;
default: throw std::runtime_error("raft supports only metric type of inner product and L2");
}
}

} // namespace raft_temp

namespace raft::bench::ann {

// brute force fused L2 KNN - RAFT
// brute force KNN - RAFT
template <typename T>
class RaftGpu : public ANN<T> {
public:
Expand Down Expand Up @@ -74,9 +74,13 @@ class RaftGpu : public ANN<T> {
}
void set_search_dataset(const T* dataset, size_t nrow) override;
void save(const std::string& file) const override;
void load(const std::string&) override { return; };
void load(const std::string&) override;
std::unique_ptr<ANN<T>> copy() override;

protected:
// handle_ must go first to make sure it dies last and all memory allocated in pool
configured_raft_resources handle_{};
std::shared_ptr<raft::neighbors::brute_force::index<T>> index_;
raft::distance::DistanceType metric_type_;
int device_;
const T* dataset_;
Expand All @@ -87,16 +91,19 @@ template <typename T>
RaftGpu<T>::RaftGpu(Metric metric, int dim)
: ANN<T>(metric, dim), metric_type_(raft_temp::parse_metric_type(metric))
{
static_assert(std::is_same_v<T, float>, "raft support only float type");
assert(metric_type_ == raft::distance::DistanceType::L2Expanded);
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double>,
"raft bfknn only supports float/double");
RAFT_CUDA_TRY(cudaGetDevice(&device_));
}

template <typename T>
void RaftGpu<T>::build(const T*, size_t, cudaStream_t)
void RaftGpu<T>::build(const T* dataset, size_t nrow, cudaStream_t stream)
{
// as this is brute force algo so no index building required
return;
auto dataset_view = raft::make_host_matrix_view<const T, int64_t>(dataset, nrow, this->dim_);
index_ = std::make_shared<raft::neighbors::brute_force::index<T>>(
std::move(raft::neighbors::brute_force::build(handle_, dataset_view)));

handle_.stream_wait(stream);
}

template <typename T>
Expand All @@ -115,15 +122,14 @@ void RaftGpu<T>::set_search_dataset(const T* dataset, size_t nrow)
template <typename T>
void RaftGpu<T>::save(const std::string& file) const
{
// create a empty index file as no index to store.
std::fstream fp;
fp.open(file.c_str(), std::ios::out);
if (!fp) {
printf("Error in creating file!!!\n");
;
return;
}
fp.close();
raft::neighbors::brute_force::serialize<T>(handle_, file, *index_);
}

template <typename T>
void RaftGpu<T>::load(const std::string& file)
{
index_ = std::make_shared<raft::neighbors::brute_force::index<T>>(
std::move(raft::neighbors::brute_force::deserialize<T>(handle_, file)));
}

template <typename T>
Expand All @@ -134,20 +140,22 @@ void RaftGpu<T>::search(const T* queries,
float* distances,
cudaStream_t stream) const
{
// TODO: Integrate new `raft::brute_force::index` (from
// https://github.com/rapidsai/raft/pull/1817)
raft::spatial::knn::detail::fusedL2Knn(this->dim_,
reinterpret_cast<int64_t*>(neighbors),
distances,
dataset_,
queries,
nrow_,
static_cast<size_t>(batch_size),
k,
true,
true,
stream,
metric_type_);
auto queries_view =
raft::make_device_matrix_view<const T, int64_t>(queries, batch_size, this->dim_);

auto neighbors_view = raft::make_device_matrix_view<size_t, int64_t>(neighbors, batch_size, k);
auto distances_view = raft::make_device_matrix_view<float, int64_t>(distances, batch_size, k);

raft::neighbors::brute_force::search<T, size_t>(
handle_, *index_, queries_view, neighbors_view, distances_view);

handle_.stream_wait(stream);
}

template <typename T>
std::unique_ptr<ANN<T>> RaftGpu<T>::copy()
{
return std::make_unique<RaftGpu<T>>(*this); // use copy constructor
}

} // namespace raft::bench::ann
4 changes: 2 additions & 2 deletions python/raft-ann-bench/src/raft-ann-bench/run/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,8 +498,8 @@ def add_algo_group(group_list):
)
if executable not in executables_to_run:
executables_to_run[executable] = {"index": []}
build_params = algos_conf[algo]["groups"][group]["build"]
search_params = algos_conf[algo]["groups"][group]["search"]
build_params = algos_conf[algo]["groups"][group]["build"] or {}
search_params = algos_conf[algo]["groups"][group]["search"] or {}

param_names = []
param_lists = []
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
name: faiss_cpu_flat
groups:
base:
build:
search:
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
name: faiss_gpu_flat
groups:
base:
build:
search:
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
name: raft_brute_force
groups:
base:
build:
search:

0 comments on commit 7e098b2

Please sign in to comment.