diff --git a/cpp/bench/ann/CMakeLists.txt b/cpp/bench/ann/CMakeLists.txt index c144d1399e..16b0f7e1ac 100644 --- a/cpp/bench/ann/CMakeLists.txt +++ b/cpp/bench/ann/CMakeLists.txt @@ -18,9 +18,6 @@ option(RAFT_ANN_BENCH_USE_FAISS_GPU_FLAT "Include faiss' brute-force knn algorithm in benchmark" ON) option(RAFT_ANN_BENCH_USE_FAISS_GPU_IVF_FLAT "Include faiss' ivf flat algorithm in benchmark" ON) option(RAFT_ANN_BENCH_USE_FAISS_GPU_IVF_PQ "Include faiss' ivf pq algorithm in benchmark" ON) -option(RAFT_ANN_BENCH_USE_FAISS_CPU_FLAT - "Include faiss' cpu brute-force knn algorithm in benchmark" ON -) option(RAFT_ANN_BENCH_USE_FAISS_CPU_FLAT "Include faiss' cpu brute-force algorithm in benchmark" ON) option(RAFT_ANN_BENCH_USE_FAISS_CPU_IVF_FLAT "Include faiss' cpu ivf flat algorithm in benchmark" @@ -30,6 +27,7 @@ option(RAFT_ANN_BENCH_USE_FAISS_CPU_IVF_PQ "Include faiss' cpu ivf pq algorithm option(RAFT_ANN_BENCH_USE_RAFT_IVF_FLAT "Include raft's ivf flat algorithm in benchmark" ON) option(RAFT_ANN_BENCH_USE_RAFT_IVF_PQ "Include raft's ivf pq algorithm in benchmark" ON) option(RAFT_ANN_BENCH_USE_RAFT_CAGRA "Include raft's CAGRA in benchmark" ON) +option(RAFT_ANN_BENCH_USE_RAFT_BRUTE_FORCE "Include raft's brute force knn in benchmark" ON) option(RAFT_ANN_BENCH_USE_RAFT_CAGRA_HNSWLIB "Include raft's CAGRA in benchmark" ON) option(RAFT_ANN_BENCH_USE_HNSWLIB "Include hnsw algorithm in benchmark" ON) option(RAFT_ANN_BENCH_USE_GGNN "Include ggnn algorithm in benchmark" ON) @@ -55,6 +53,7 @@ if(BUILD_CPU_ONLY) set(RAFT_ANN_BENCH_USE_RAFT_IVF_FLAT OFF) set(RAFT_ANN_BENCH_USE_RAFT_IVF_PQ OFF) set(RAFT_ANN_BENCH_USE_RAFT_CAGRA OFF) + set(RAFT_ANN_BENCH_USE_RAFT_BRUTE_FORCE OFF) set(RAFT_ANN_BENCH_USE_RAFT_CAGRA_HNSWLIB OFF) set(RAFT_ANN_BENCH_USE_GGNN OFF) else() diff --git a/cpp/bench/ann/src/faiss/faiss_cpu_benchmark.cpp b/cpp/bench/ann/src/faiss/faiss_cpu_benchmark.cpp index 97d1bbf307..e3e25a99a2 100644 --- a/cpp/bench/ann/src/faiss/faiss_cpu_benchmark.cpp +++ b/cpp/bench/ann/src/faiss/faiss_cpu_benchmark.cpp @@ -143,7 +143,7 @@ std::unique_ptr::AnnSearchParam> create_search parse_search_param(conf, *param); return param; } else if (algo == "faiss_cpu_flat") { - auto param = std::make_unique::AnnSearchParam>(); + auto param = std::make_unique::SearchParam>(); return param; } // else diff --git a/cpp/bench/ann/src/faiss/faiss_gpu_benchmark.cu b/cpp/bench/ann/src/faiss/faiss_gpu_benchmark.cu index 8b04ba1980..a9388531cc 100644 --- a/cpp/bench/ann/src/faiss/faiss_gpu_benchmark.cu +++ b/cpp/bench/ann/src/faiss/faiss_gpu_benchmark.cu @@ -143,7 +143,7 @@ std::unique_ptr::AnnSearchParam> create_search parse_search_param(conf, *param); return param; } else if (algo == "faiss_gpu_flat") { - auto param = std::make_unique::AnnSearchParam>(); + auto param = std::make_unique::SearchParam>(); return param; } // else diff --git a/cpp/bench/ann/src/raft/raft_ann_bench_param_parser.h b/cpp/bench/ann/src/raft/raft_ann_bench_param_parser.h index 1eb0e53cc5..2a021a8a12 100644 --- a/cpp/bench/ann/src/raft/raft_ann_bench_param_parser.h +++ b/cpp/bench/ann/src/raft/raft_ann_bench_param_parser.h @@ -20,7 +20,7 @@ #include #undef WARP_SIZE -#ifdef RAFT_ANN_BENCH_USE_RAFT_BFKNN +#ifdef RAFT_ANN_BENCH_USE_RAFT_BRUTE_FORCE #include "raft_wrapper.h" #endif #ifdef RAFT_ANN_BENCH_USE_RAFT_IVF_FLAT diff --git a/cpp/bench/ann/src/raft/raft_benchmark.cu b/cpp/bench/ann/src/raft/raft_benchmark.cu index b776a9fafb..cfc30bef7d 100644 --- a/cpp/bench/ann/src/raft/raft_benchmark.cu +++ b/cpp/bench/ann/src/raft/raft_benchmark.cu @@ -22,6 +22,7 @@ #include #include #include +#include #include #include #include @@ -47,8 +48,10 @@ std::unique_ptr> create_algo(const std::string& algo, std::unique_ptr> ann; if constexpr (std::is_same_v) { -#ifdef RAFT_ANN_BENCH_USE_RAFT_BFKNN - if (algo == "raft_bfknn") { ann = std::make_unique>(metric, dim); } +#ifdef RAFT_ANN_BENCH_USE_RAFT_BRUTE_FORCE + if (algo == "raft_brute_force") { + ann = std::make_unique>(metric, dim); + } #endif } @@ -85,7 +88,7 @@ template std::unique_ptr::AnnSearchParam> create_search_param( const std::string& algo, const nlohmann::json& conf) { -#ifdef RAFT_ANN_BENCH_USE_RAFT_BFKNN +#ifdef RAFT_ANN_BENCH_USE_RAFT_BRUTE_FORCE if (algo == "raft_brute_force") { auto param = std::make_unique::AnnSearchParam>(); return param; diff --git a/cpp/bench/ann/src/raft/raft_wrapper.h b/cpp/bench/ann/src/raft/raft_wrapper.h index 499bdf29a1..eae615cba1 100644 --- a/cpp/bench/ann/src/raft/raft_wrapper.h +++ b/cpp/bench/ann/src/raft/raft_wrapper.h @@ -17,33 +17,33 @@ #include #include +#include #include #include -#include +#include +#include #include #include #include #include "../common/ann_types.hpp" +#include "raft_ann_bench_utils.h" namespace raft_temp { inline raft::distance::DistanceType parse_metric_type(raft::bench::ann::Metric metric) { - if (metric == raft::bench::ann::Metric::kInnerProduct) { - return raft::distance::DistanceType::InnerProduct; - } else if (metric == raft::bench::ann::Metric::kEuclidean) { - return raft::distance::DistanceType::L2Expanded; - } else { - throw std::runtime_error("raft supports only metric type of inner product and L2"); + switch (metric) { + case raft::bench::ann::Metric::kInnerProduct: return raft::distance::DistanceType::InnerProduct; + case raft::bench::ann::Metric::kEuclidean: return raft::distance::DistanceType::L2Expanded; + default: throw std::runtime_error("raft supports only metric type of inner product and L2"); } } - } // namespace raft_temp namespace raft::bench::ann { -// brute force fused L2 KNN - RAFT +// brute force KNN - RAFT template class RaftGpu : public ANN { public: @@ -74,9 +74,13 @@ class RaftGpu : public ANN { } void set_search_dataset(const T* dataset, size_t nrow) override; void save(const std::string& file) const override; - void load(const std::string&) override { return; }; + void load(const std::string&) override; + std::unique_ptr> copy() override; protected: + // handle_ must go first to make sure it dies last and all memory allocated in pool + configured_raft_resources handle_{}; + std::shared_ptr> index_; raft::distance::DistanceType metric_type_; int device_; const T* dataset_; @@ -87,16 +91,19 @@ template RaftGpu::RaftGpu(Metric metric, int dim) : ANN(metric, dim), metric_type_(raft_temp::parse_metric_type(metric)) { - static_assert(std::is_same_v, "raft support only float type"); - assert(metric_type_ == raft::distance::DistanceType::L2Expanded); + static_assert(std::is_same_v || std::is_same_v, + "raft bfknn only supports float/double"); RAFT_CUDA_TRY(cudaGetDevice(&device_)); } template -void RaftGpu::build(const T*, size_t, cudaStream_t) +void RaftGpu::build(const T* dataset, size_t nrow, cudaStream_t stream) { - // as this is brute force algo so no index building required - return; + auto dataset_view = raft::make_host_matrix_view(dataset, nrow, this->dim_); + index_ = std::make_shared>( + std::move(raft::neighbors::brute_force::build(handle_, dataset_view))); + + handle_.stream_wait(stream); } template @@ -115,15 +122,14 @@ void RaftGpu::set_search_dataset(const T* dataset, size_t nrow) template void RaftGpu::save(const std::string& file) const { - // create a empty index file as no index to store. - std::fstream fp; - fp.open(file.c_str(), std::ios::out); - if (!fp) { - printf("Error in creating file!!!\n"); - ; - return; - } - fp.close(); + raft::neighbors::brute_force::serialize(handle_, file, *index_); +} + +template +void RaftGpu::load(const std::string& file) +{ + index_ = std::make_shared>( + std::move(raft::neighbors::brute_force::deserialize(handle_, file))); } template @@ -134,20 +140,22 @@ void RaftGpu::search(const T* queries, float* distances, cudaStream_t stream) const { - // TODO: Integrate new `raft::brute_force::index` (from - // https://github.com/rapidsai/raft/pull/1817) - raft::spatial::knn::detail::fusedL2Knn(this->dim_, - reinterpret_cast(neighbors), - distances, - dataset_, - queries, - nrow_, - static_cast(batch_size), - k, - true, - true, - stream, - metric_type_); + auto queries_view = + raft::make_device_matrix_view(queries, batch_size, this->dim_); + + auto neighbors_view = raft::make_device_matrix_view(neighbors, batch_size, k); + auto distances_view = raft::make_device_matrix_view(distances, batch_size, k); + + raft::neighbors::brute_force::search( + handle_, *index_, queries_view, neighbors_view, distances_view); + + handle_.stream_wait(stream); +} + +template +std::unique_ptr> RaftGpu::copy() +{ + return std::make_unique>(*this); // use copy constructor } } // namespace raft::bench::ann diff --git a/python/raft-ann-bench/src/raft-ann-bench/run/__main__.py b/python/raft-ann-bench/src/raft-ann-bench/run/__main__.py index 9841b47b98..a1f97d67d5 100644 --- a/python/raft-ann-bench/src/raft-ann-bench/run/__main__.py +++ b/python/raft-ann-bench/src/raft-ann-bench/run/__main__.py @@ -498,8 +498,8 @@ def add_algo_group(group_list): ) if executable not in executables_to_run: executables_to_run[executable] = {"index": []} - build_params = algos_conf[algo]["groups"][group]["build"] - search_params = algos_conf[algo]["groups"][group]["search"] + build_params = algos_conf[algo]["groups"][group]["build"] or {} + search_params = algos_conf[algo]["groups"][group]["search"] or {} param_names = [] param_lists = [] diff --git a/python/raft-ann-bench/src/raft-ann-bench/run/conf/algos/faiss_cpu_flat.yaml b/python/raft-ann-bench/src/raft-ann-bench/run/conf/algos/faiss_cpu_flat.yaml new file mode 100644 index 0000000000..25eaf03d40 --- /dev/null +++ b/python/raft-ann-bench/src/raft-ann-bench/run/conf/algos/faiss_cpu_flat.yaml @@ -0,0 +1,5 @@ +name: faiss_cpu_flat +groups: + base: + build: + search: diff --git a/python/raft-ann-bench/src/raft-ann-bench/run/conf/algos/faiss_gpu_flat.yaml b/python/raft-ann-bench/src/raft-ann-bench/run/conf/algos/faiss_gpu_flat.yaml new file mode 100644 index 0000000000..a722e1b91c --- /dev/null +++ b/python/raft-ann-bench/src/raft-ann-bench/run/conf/algos/faiss_gpu_flat.yaml @@ -0,0 +1,5 @@ +name: faiss_gpu_flat +groups: + base: + build: + search: diff --git a/python/raft-ann-bench/src/raft-ann-bench/run/conf/algos/raft_brute_force.yaml b/python/raft-ann-bench/src/raft-ann-bench/run/conf/algos/raft_brute_force.yaml new file mode 100644 index 0000000000..da99841f9b --- /dev/null +++ b/python/raft-ann-bench/src/raft-ann-bench/run/conf/algos/raft_brute_force.yaml @@ -0,0 +1,5 @@ +name: raft_brute_force +groups: + base: + build: + search: