Skip to content

Commit

Permalink
Merge branch 'branch-24.02' into cagra_hnswlib_pylibraft
Browse files Browse the repository at this point in the history
  • Loading branch information
cjnolet authored Jan 6, 2024
2 parents 397176e + 6762fe5 commit 114d176
Show file tree
Hide file tree
Showing 50 changed files with 3,801 additions and 261 deletions.
16 changes: 3 additions & 13 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -178,8 +178,8 @@ include(cmake/modules/ConfigureCUDA.cmake)
rapids_cpm_init()

if(NOT BUILD_CPU_ONLY)
# thrust before rmm/cuco so we get the right version of thrust/cub
include(cmake/thirdparty/get_thrust.cmake)
# CCCL before rmm/cuco so we get the right version of CCCL
include(cmake/thirdparty/get_cccl.cmake)
include(cmake/thirdparty/get_rmm.cmake)
include(cmake/thirdparty/get_cutlass.cmake)

Expand Down Expand Up @@ -212,7 +212,7 @@ target_include_directories(

if(NOT BUILD_CPU_ONLY)
# Keep RAFT as lightweight as possible. Only CUDA libs and rmm should be used in global target.
target_link_libraries(raft INTERFACE rmm::rmm cuco::cuco nvidia::cutlass::cutlass raft::Thrust)
target_link_libraries(raft INTERFACE rmm::rmm cuco::cuco nvidia::cutlass::cutlass CCCL::CCCL)
endif()

target_compile_features(raft INTERFACE cxx_std_17 $<BUILD_INTERFACE:cuda_std_17>)
Expand Down Expand Up @@ -635,16 +635,6 @@ Imported Targets:

set(code_string ${nvtx_export_string})

string(
APPEND
code_string
[=[
if(NOT TARGET raft::Thrust)
thrust_create_target(raft::Thrust FROM_OPTIONS)
endif()
]=]
)

string(
APPEND
code_string
Expand Down
5 changes: 2 additions & 3 deletions cpp/bench/ann/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,6 @@
option(RAFT_ANN_BENCH_USE_FAISS_GPU_FLAT "Include faiss' brute-force knn algorithm in benchmark" ON)
option(RAFT_ANN_BENCH_USE_FAISS_GPU_IVF_FLAT "Include faiss' ivf flat algorithm in benchmark" ON)
option(RAFT_ANN_BENCH_USE_FAISS_GPU_IVF_PQ "Include faiss' ivf pq algorithm in benchmark" ON)
option(RAFT_ANN_BENCH_USE_FAISS_CPU_FLAT
"Include faiss' cpu brute-force knn algorithm in benchmark" ON
)
option(RAFT_ANN_BENCH_USE_FAISS_CPU_FLAT "Include faiss' cpu brute-force algorithm in benchmark" ON)

option(RAFT_ANN_BENCH_USE_FAISS_CPU_IVF_FLAT "Include faiss' cpu ivf flat algorithm in benchmark"
Expand All @@ -30,6 +27,7 @@ option(RAFT_ANN_BENCH_USE_FAISS_CPU_IVF_PQ "Include faiss' cpu ivf pq algorithm
# Per-algorithm switches for the ANN benchmark wrappers; every one defaults to ON.
option(RAFT_ANN_BENCH_USE_RAFT_IVF_FLAT "Include raft's ivf flat algorithm in benchmark" ON)
option(RAFT_ANN_BENCH_USE_RAFT_IVF_PQ "Include raft's ivf pq algorithm in benchmark" ON)
option(RAFT_ANN_BENCH_USE_RAFT_CAGRA "Include raft's CAGRA in benchmark" ON)
option(RAFT_ANN_BENCH_USE_RAFT_BRUTE_FORCE "Include raft's brute force knn in benchmark" ON)
# Help text is kept distinct from RAFT_ANN_BENCH_USE_RAFT_CAGRA so the two entries can be told apart
# in ccmake / cmake-gui.
option(RAFT_ANN_BENCH_USE_RAFT_CAGRA_HNSWLIB "Include raft's CAGRA with hnswlib in benchmark" ON)
option(RAFT_ANN_BENCH_USE_HNSWLIB "Include hnsw algorithm in benchmark" ON)
option(RAFT_ANN_BENCH_USE_GGNN "Include ggnn algorithm in benchmark" ON)
Expand All @@ -55,6 +53,7 @@ if(BUILD_CPU_ONLY)
set(RAFT_ANN_BENCH_USE_RAFT_IVF_FLAT OFF)
set(RAFT_ANN_BENCH_USE_RAFT_IVF_PQ OFF)
set(RAFT_ANN_BENCH_USE_RAFT_CAGRA OFF)
set(RAFT_ANN_BENCH_USE_RAFT_BRUTE_FORCE OFF)
set(RAFT_ANN_BENCH_USE_RAFT_CAGRA_HNSWLIB OFF)
set(RAFT_ANN_BENCH_USE_GGNN OFF)
else()
Expand Down
2 changes: 1 addition & 1 deletion cpp/bench/ann/src/faiss/faiss_cpu_benchmark.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ std::unique_ptr<typename raft::bench::ann::ANN<T>::AnnSearchParam> create_search
parse_search_param<T>(conf, *param);
return param;
} else if (algo == "faiss_cpu_flat") {
auto param = std::make_unique<typename raft::bench::ann::ANN<T>::AnnSearchParam>();
auto param = std::make_unique<typename raft::bench::ann::FaissCpu<T>::SearchParam>();
return param;
}
// else
Expand Down
2 changes: 1 addition & 1 deletion cpp/bench/ann/src/faiss/faiss_gpu_benchmark.cu
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ std::unique_ptr<typename raft::bench::ann::ANN<T>::AnnSearchParam> create_search
parse_search_param<T>(conf, *param);
return param;
} else if (algo == "faiss_gpu_flat") {
auto param = std::make_unique<typename raft::bench::ann::ANN<T>::AnnSearchParam>();
auto param = std::make_unique<typename raft::bench::ann::FaissGpu<T>::SearchParam>();
return param;
}
// else
Expand Down
2 changes: 1 addition & 1 deletion cpp/bench/ann/src/raft/raft_ann_bench_param_parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
#include <nlohmann/json.hpp>

#undef WARP_SIZE
#ifdef RAFT_ANN_BENCH_USE_RAFT_BFKNN
#ifdef RAFT_ANN_BENCH_USE_RAFT_BRUTE_FORCE
#include "raft_wrapper.h"
#endif
#ifdef RAFT_ANN_BENCH_USE_RAFT_IVF_FLAT
Expand Down
9 changes: 6 additions & 3 deletions cpp/bench/ann/src/raft/raft_benchmark.cu
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include <cmath>
#include <memory>
#include <raft/core/logger.hpp>
#include <rmm/mr/device/per_device_resource.hpp>
#include <rmm/mr/device/pool_memory_resource.hpp>
#include <stdexcept>
#include <string>
Expand All @@ -47,8 +48,10 @@ std::unique_ptr<raft::bench::ann::ANN<T>> create_algo(const std::string& algo,
std::unique_ptr<raft::bench::ann::ANN<T>> ann;

if constexpr (std::is_same_v<T, float>) {
#ifdef RAFT_ANN_BENCH_USE_RAFT_BFKNN
if (algo == "raft_bfknn") { ann = std::make_unique<raft::bench::ann::RaftGpu<T>>(metric, dim); }
#ifdef RAFT_ANN_BENCH_USE_RAFT_BRUTE_FORCE
if (algo == "raft_brute_force") {
ann = std::make_unique<raft::bench::ann::RaftGpu<T>>(metric, dim);
}
#endif
}

Expand Down Expand Up @@ -85,7 +88,7 @@ template <typename T>
std::unique_ptr<typename raft::bench::ann::ANN<T>::AnnSearchParam> create_search_param(
const std::string& algo, const nlohmann::json& conf)
{
#ifdef RAFT_ANN_BENCH_USE_RAFT_BFKNN
#ifdef RAFT_ANN_BENCH_USE_RAFT_BRUTE_FORCE
if (algo == "raft_brute_force") {
auto param = std::make_unique<typename raft::bench::ann::ANN<T>::AnnSearchParam>();
return param;
Expand Down
84 changes: 46 additions & 38 deletions cpp/bench/ann/src/raft/raft_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,33 +17,33 @@

#include <cassert>
#include <memory>
#include <raft/core/device_resources.hpp>
#include <raft/distance/detail/distance.cuh>
#include <raft/distance/distance_types.hpp>
#include <raft/spatial/knn/detail/fused_l2_knn.cuh>
#include <raft/neighbors/brute_force.cuh>
#include <raft/neighbors/brute_force_serialize.cuh>
#include <stdexcept>
#include <string>
#include <type_traits>

#include "../common/ann_types.hpp"
#include "raft_ann_bench_utils.h"

namespace raft_temp {

inline raft::distance::DistanceType parse_metric_type(raft::bench::ann::Metric metric)
{
if (metric == raft::bench::ann::Metric::kInnerProduct) {
return raft::distance::DistanceType::InnerProduct;
} else if (metric == raft::bench::ann::Metric::kEuclidean) {
return raft::distance::DistanceType::L2Expanded;
} else {
throw std::runtime_error("raft supports only metric type of inner product and L2");
switch (metric) {
case raft::bench::ann::Metric::kInnerProduct: return raft::distance::DistanceType::InnerProduct;
case raft::bench::ann::Metric::kEuclidean: return raft::distance::DistanceType::L2Expanded;
default: throw std::runtime_error("raft supports only metric type of inner product and L2");
}
}

} // namespace raft_temp

namespace raft::bench::ann {

// brute force fused L2 KNN - RAFT
// brute force KNN - RAFT
template <typename T>
class RaftGpu : public ANN<T> {
public:
Expand Down Expand Up @@ -74,9 +74,13 @@ class RaftGpu : public ANN<T> {
}
void set_search_dataset(const T* dataset, size_t nrow) override;
void save(const std::string& file) const override;
void load(const std::string&) override { return; };
void load(const std::string&) override;
std::unique_ptr<ANN<T>> copy() override;

protected:
// handle_ must be declared first so that it is destroyed last, after all memory allocated in the
// pool has been released.
configured_raft_resources handle_{};
std::shared_ptr<raft::neighbors::brute_force::index<T>> index_;
raft::distance::DistanceType metric_type_;
int device_;
const T* dataset_;
Expand All @@ -87,16 +91,19 @@ template <typename T>
RaftGpu<T>::RaftGpu(Metric metric, int dim)
: ANN<T>(metric, dim), metric_type_(raft_temp::parse_metric_type(metric))
{
static_assert(std::is_same_v<T, float>, "raft support only float type");
assert(metric_type_ == raft::distance::DistanceType::L2Expanded);
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double>,
"raft bfknn only supports float/double");
RAFT_CUDA_TRY(cudaGetDevice(&device_));
}

template <typename T>
void RaftGpu<T>::build(const T*, size_t, cudaStream_t)
void RaftGpu<T>::build(const T* dataset, size_t nrow, cudaStream_t stream)
{
// as this is brute force algo so no index building required
return;
auto dataset_view = raft::make_host_matrix_view<const T, int64_t>(dataset, nrow, this->dim_);
index_ = std::make_shared<raft::neighbors::brute_force::index<T>>(
std::move(raft::neighbors::brute_force::build(handle_, dataset_view)));

handle_.stream_wait(stream);
}

template <typename T>
Expand All @@ -115,15 +122,14 @@ void RaftGpu<T>::set_search_dataset(const T* dataset, size_t nrow)
template <typename T>
void RaftGpu<T>::save(const std::string& file) const
{
// create a empty index file as no index to store.
std::fstream fp;
fp.open(file.c_str(), std::ios::out);
if (!fp) {
printf("Error in creating file!!!\n");
;
return;
}
fp.close();
raft::neighbors::brute_force::serialize<T>(handle_, file, *index_);
}

/**
 * Replace index_ with the result of brute_force::deserialize<T>(handle_, file).
 *
 * @param file path handed unchanged to the deserializer
 */
template <typename T>
void RaftGpu<T>::load(const std::string& file)
{
  namespace bf     = raft::neighbors::brute_force;
  using index_type = bf::index<T>;

  // Move the deserialized index into shared ownership; any previously held index_ is dropped.
  index_ = std::make_shared<index_type>(std::move(bf::deserialize<T>(handle_, file)));
}

template <typename T>
Expand All @@ -134,20 +140,22 @@ void RaftGpu<T>::search(const T* queries,
float* distances,
cudaStream_t stream) const
{
// TODO: Integrate new `raft::brute_force::index` (from
// https://github.com/rapidsai/raft/pull/1817)
raft::spatial::knn::detail::fusedL2Knn(this->dim_,
reinterpret_cast<int64_t*>(neighbors),
distances,
dataset_,
queries,
nrow_,
static_cast<size_t>(batch_size),
k,
true,
true,
stream,
metric_type_);
auto queries_view =
raft::make_device_matrix_view<const T, int64_t>(queries, batch_size, this->dim_);

auto neighbors_view = raft::make_device_matrix_view<size_t, int64_t>(neighbors, batch_size, k);
auto distances_view = raft::make_device_matrix_view<float, int64_t>(distances, batch_size, k);

raft::neighbors::brute_force::search<T, size_t>(
handle_, *index_, queries_view, neighbors_view, distances_view);

handle_.stream_wait(stream);
}

/**
 * Clone this wrapper.
 *
 * @return a new RaftGpu<T>, copy-constructed from *this, owned through the ANN<T> base pointer
 */
template <typename T>
std::unique_ptr<ANN<T>> RaftGpu<T>::copy()
{
  using self_type = RaftGpu<T>;
  // Relies on RaftGpu's copy constructor; the unique_ptr converts to the base type on return.
  return std::make_unique<self_type>(*this);
}

} // namespace raft::bench::ann
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,10 @@
# the License.
# =============================================================================

# Use CPM to find or clone thrust
function(find_and_configure_thrust)
include(${rapids-cmake-dir}/cpm/thrust.cmake)

rapids_cpm_thrust( NAMESPACE raft
BUILD_EXPORT_SET raft-exports
INSTALL_EXPORT_SET raft-exports)
# Find or clone CCCL via CPM (through the rapids-cmake helper) and attach it to the
# raft-exports build and install export sets.
function(find_and_configure_cccl)
  include(${rapids-cmake-dir}/cpm/cccl.cmake)
  rapids_cpm_cccl(
    BUILD_EXPORT_SET raft-exports
    INSTALL_EXPORT_SET raft-exports
  )
endfunction()

find_and_configure_thrust()
find_and_configure_cccl()
8 changes: 4 additions & 4 deletions cpp/include/raft/core/cublas_macros.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION.
* Copyright (c) 2022-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -33,7 +33,7 @@
namespace raft {

/**
* @ingroup error_handling
* @addtogroup error_handling
* @{
*/

Expand Down Expand Up @@ -76,7 +76,7 @@ inline const char* cublas_error_to_string(cublasStatus_t err)
#undef _CUBLAS_ERR_TO_STR

/**
* @ingroup assertion
* @addtogroup assertion
* @{
*/

Expand Down Expand Up @@ -135,4 +135,4 @@ inline const char* cublas_error_to_string(cublasStatus_t err)
#define CUBLAS_CHECK_NO_THROW(call) RAFT_CUBLAS_TRY_NO_THROW(call)
#endif

#endif
#endif
Loading

0 comments on commit 114d176

Please sign in to comment.