Skip to content

Commit

Permalink
Merge remote-tracking branch 'upstream/branch-24.02' into subsampling…
Browse files Browse the repository at this point in the history
…-ivfpq-codebook
  • Loading branch information
abc99lr committed Jan 24, 2024
2 parents c2b2715 + 3ce00d3 commit 9395be8
Show file tree
Hide file tree
Showing 51 changed files with 2,582 additions and 576 deletions.
2 changes: 2 additions & 0 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -424,6 +424,7 @@ if(RAFT_COMPILE_LIBRARY)
src/raft_runtime/neighbors/cagra_build.cu
src/raft_runtime/neighbors/cagra_search.cu
src/raft_runtime/neighbors/cagra_serialize.cu
src/raft_runtime/neighbors/eps_neighborhood.cu
src/raft_runtime/neighbors/ivf_flat_build.cu
src/raft_runtime/neighbors/ivf_flat_search.cu
src/raft_runtime/neighbors/ivf_flat_serialize.cu
Expand All @@ -443,6 +444,7 @@ if(RAFT_COMPILE_LIBRARY)
src/raft_runtime/random/rmat_rectangular_generator_int64_float.cu
src/raft_runtime/random/rmat_rectangular_generator_int_double.cu
src/raft_runtime/random/rmat_rectangular_generator_int_float.cu
src/spatial/knn/detail/ball_cover/registers_eps_pass_euclidean.cu
src/spatial/knn/detail/ball_cover/registers_pass_one_2d_dist.cu
src/spatial/knn/detail/ball_cover/registers_pass_one_2d_euclidean.cu
src/spatial/knn/detail/ball_cover/registers_pass_one_2d_haversine.cu
Expand Down
18 changes: 8 additions & 10 deletions cpp/bench/ann/src/common/benchmark.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
* Copyright (c) 2023-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -287,11 +287,11 @@ void bench_search(::benchmark::State& state,
std::make_shared<buf<std::size_t>>(current_algo_props->query_memory_type, k * query_set_size);

cuda_timer gpu_timer;
auto start = std::chrono::high_resolution_clock::now();
{
nvtx_case nvtx{state.name()};

auto algo = dynamic_cast<ANN<T>*>(current_algo.get())->copy();
auto algo = dynamic_cast<ANN<T>*>(current_algo.get())->copy();
auto start = std::chrono::high_resolution_clock::now();
for (auto _ : state) {
[[maybe_unused]] auto ntx_lap = nvtx.lap();
[[maybe_unused]] auto gpu_lap = gpu_timer.lap();
Expand All @@ -314,17 +314,15 @@ void bench_search(::benchmark::State& state,

queries_processed += n_queries;
}
auto end = std::chrono::high_resolution_clock::now();
auto duration = std::chrono::duration_cast<std::chrono::duration<double>>(end - start).count();
if (state.thread_index() == 0) { state.counters.insert({{"end_to_end", duration}}); }
state.counters.insert({"Latency", {duration, benchmark::Counter::kAvgIterations}});
}
auto end = std::chrono::high_resolution_clock::now();
auto duration = std::chrono::duration_cast<std::chrono::duration<double>>(end - start).count();
if (state.thread_index() == 0) { state.counters.insert({{"end_to_end", duration}}); }
state.counters.insert(
{"Latency", {duration / double(state.iterations()), benchmark::Counter::kAvgThreads}});

state.SetItemsProcessed(queries_processed);
if (cudart.found()) {
double gpu_time_per_iteration = gpu_timer.total_time() / (double)state.iterations();
state.counters.insert({"GPU", {gpu_time_per_iteration, benchmark::Counter::kAvgThreads}});
state.counters.insert({"GPU", {gpu_timer.total_time(), benchmark::Counter::kAvgIterations}});
}

// This will be the total number of queries across all threads
Expand Down
27 changes: 2 additions & 25 deletions cpp/bench/ann/src/common/cuda_huge_page_resource.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
* Copyright (c) 2023-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -49,13 +49,6 @@ class cuda_huge_page_resource final : public rmm::mr::device_memory_resource {
*/
[[nodiscard]] bool supports_streams() const noexcept override { return false; }

/**
* @brief Query whether the resource supports the get_mem_info API.
*
* @return true
*/
[[nodiscard]] bool supports_get_mem_info() const noexcept override { return true; }

private:
/**
* @brief Allocates memory of size at least `bytes` using cudaMalloc.
Expand Down Expand Up @@ -112,21 +105,5 @@ class cuda_huge_page_resource final : public rmm::mr::device_memory_resource {
{
return dynamic_cast<cuda_huge_page_resource const*>(&other) != nullptr;
}

/**
* @brief Get free and available memory for memory resource
*
* @throws `rmm::cuda_error` if unable to retrieve memory info.
*
* @return std::pair containing free_size and total_size of memory
*/
[[nodiscard]] std::pair<std::size_t, std::size_t> do_get_mem_info(
rmm::cuda_stream_view) const override
{
std::size_t free_size{};
std::size_t total_size{};
RMM_CUDA_TRY(cudaMemGetInfo(&free_size, &total_size));
return std::make_pair(free_size, total_size);
}
};
} // namespace raft::mr
} // namespace raft::mr
27 changes: 2 additions & 25 deletions cpp/bench/ann/src/common/cuda_pinned_resource.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
* Copyright (c) 2023-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -53,13 +53,6 @@ class cuda_pinned_resource final : public rmm::mr::device_memory_resource {
*/
[[nodiscard]] bool supports_streams() const noexcept override { return false; }

/**
* @brief Query whether the resource supports the get_mem_info API.
*
* @return true
*/
[[nodiscard]] bool supports_get_mem_info() const noexcept override { return true; }

private:
/**
* @brief Allocates memory of size at least `bytes` using cudaMalloc.
Expand Down Expand Up @@ -110,21 +103,5 @@ class cuda_pinned_resource final : public rmm::mr::device_memory_resource {
{
return dynamic_cast<cuda_pinned_resource const*>(&other) != nullptr;
}

/**
* @brief Get free and available memory for memory resource
*
* @throws `rmm::cuda_error` if unable to retrieve memory info.
*
* @return std::pair containing free_size and total_size of memory
*/
[[nodiscard]] std::pair<std::size_t, std::size_t> do_get_mem_info(
rmm::cuda_stream_view) const override
{
std::size_t free_size{};
std::size_t total_size{};
RMM_CUDA_TRY(cudaMemGetInfo(&free_size, &total_size));
return std::make_pair(free_size, total_size);
}
};
} // namespace raft::mr
} // namespace raft::mr
32 changes: 31 additions & 1 deletion cpp/include/raft/core/math.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
* Copyright (c) 2023-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -49,12 +49,42 @@ RAFT_INLINE_FUNCTION auto abs(T x)
// Generic fallback for abs(): the enable_if_t condition removes this overload
// for float, double, int, long int and long long int (and, when _RAFT_HAS_CUDA
// is defined, for __half and nv_bfloat16 as well). Those types are expected to
// be served by other abs overloads.
template <typename T>
constexpr RAFT_INLINE_FUNCTION auto abs(T x)
  -> std::enable_if_t<!std::is_same_v<float, T> && !std::is_same_v<double, T> &&
#if defined(_RAFT_HAS_CUDA)
                        !std::is_same_v<__half, T> && !std::is_same_v<nv_bfloat16, T> &&
#endif
                        !std::is_same_v<int, T> && !std::is_same_v<long int, T> &&
                        !std::is_same_v<long long int, T>,
                      T>
{
  // Only strictly negative inputs are negated; everything else passes through.
  if (x < T{0}) { return -x; }
  return x;
}
#if defined(_RAFT_HAS_CUDA)
/**
 * @brief Absolute value overload for `__half`.
 *
 * The `enable_if_t` return type makes this overload viable only when `T` is
 * exactly `__half`. The surrounding `#if defined(_RAFT_HAS_CUDA)` means it is
 * declared only in CUDA-enabled builds.
 *
 * @tparam T must be `__half`
 * @param x input value
 * @return `::__habs(x)` when `__CUDA_ARCH__ >= 530`; on any other compilation
 *         pass instantiating this template is a compile-time error.
 */
template <typename T>
RAFT_DEVICE_INLINE_FUNCTION typename std::enable_if_t<std::is_same_v<T, __half>, __half> abs(T x)
{
#if (__CUDA_ARCH__ >= 530)
return ::__habs(x);
#else
// Fail during template instantiation if the compute capability doesn't support this operation
// (the condition is always false but depends on T, so it is only evaluated when the template
// is actually instantiated, not when it is merely parsed).
static_assert(sizeof(T) != sizeof(T), "__half is only supported on __CUDA_ARCH__ >= 530");
// Value-initialized result so this path still has a return statement.
return T{};
#endif
}

/**
 * @brief Absolute value overload for `nv_bfloat16`.
 *
 * The `enable_if_t` return type makes this overload viable only when `T` is
 * exactly `nv_bfloat16`. The surrounding `#if defined(_RAFT_HAS_CUDA)` means it
 * is declared only in CUDA-enabled builds.
 *
 * @tparam T must be `nv_bfloat16`
 * @param x input value
 * @return `::__habs(x)` when `__CUDA_ARCH__ >= 800`; on any other compilation
 *         pass instantiating this template is a compile-time error.
 */
template <typename T>
RAFT_DEVICE_INLINE_FUNCTION typename std::enable_if_t<std::is_same_v<T, nv_bfloat16>, nv_bfloat16>
abs(T x)
{
#if (__CUDA_ARCH__ >= 800)
return ::__habs(x);
#else
// Fail during template instantiation if the compute capability doesn't support this operation
// (the condition is always false but depends on T, so it is only evaluated when the template
// is actually instantiated, not when it is merely parsed).
static_assert(sizeof(T) != sizeof(T), "nv_bfloat16 is only supported on __CUDA_ARCH__ >= 800");
// Value-initialized result so this path still has a return statement.
return T{};
#endif
}
#endif
/** @} */

/** Inverse cosine */
template <typename T>
Expand Down
41 changes: 39 additions & 2 deletions cpp/include/raft/neighbors/ball_cover-ext.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2021-2023, NVIDIA CORPORATION.
* Copyright (c) 2021-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -67,6 +67,25 @@ void knn_query(raft::resources const& handle,
bool perform_post_filtering = true,
float weight = 1.0) RAFT_EXPLICIT;

/**
 * @brief Declaration of the `eps_nn` overload that takes a dense boolean matrix (`adj`).
 *
 * Declaration only (marked `RAFT_EXPLICIT`); no body is provided in this header section.
 * NOTE(review): judging by the name and parameters this looks like an epsilon-neighborhood
 * query against a ball cover index - confirm against the implementation.
 *
 * @tparam idx_t        element type of `vd` and first template argument of the index
 * @tparam value_t      element type of `query` and of `eps`
 * @tparam int_t        third template argument of `BallCoverIndex`
 * @tparam matrix_idx_t indexing type used by all the views
 *
 * @param handle raft resources handle
 * @param index  ball cover index to query (read-only reference)
 * @param adj    row-major device matrix of bool - presumably the output adjacency
 *               matrix (query rows x index points); TODO confirm shape and direction
 * @param vd     device vector of idx_t - presumably per-query neighbor counts
 *               ("vertex degrees"); TODO confirm length and contents
 * @param query  row-major device matrix of query points (const elements)
 * @param eps    presumably the neighborhood radius; TODO confirm metric and units
 */
template <typename idx_t, typename value_t, typename int_t, typename matrix_idx_t>
void eps_nn(raft::resources const& handle,
const BallCoverIndex<idx_t, value_t, int_t, matrix_idx_t>& index,
raft::device_matrix_view<bool, matrix_idx_t, row_major> adj,
raft::device_vector_view<idx_t, matrix_idx_t> vd,
raft::device_matrix_view<const value_t, matrix_idx_t, row_major> query,
value_t eps) RAFT_EXPLICIT;

/**
 * @brief Declaration of the `eps_nn` overload that takes two index vectors
 *        (`adj_ia`, `adj_ja`) instead of a dense boolean matrix.
 *
 * Declaration only (marked `RAFT_EXPLICIT`); no body is provided in this header section.
 * NOTE(review): `adj_ia` / `adj_ja` look like CSR row offsets / column indices of a sparse
 * adjacency result - confirm against the implementation.
 *
 * @tparam idx_t        element type of `adj_ia`, `adj_ja` and `vd`
 * @tparam value_t      element type of `query` and of `eps`
 * @tparam int_t        element type of `max_k` and third template argument of `BallCoverIndex`
 * @tparam matrix_idx_t indexing type used by all the views
 *
 * @param handle raft resources handle
 * @param index  ball cover index to query (read-only reference)
 * @param adj_ia device vector of idx_t - presumably CSR row offsets; TODO confirm
 * @param adj_ja device vector of idx_t - presumably CSR column indices; TODO confirm
 * @param vd     device vector of idx_t - presumably per-query neighbor counts; TODO confirm
 * @param query  row-major device matrix of query points (const elements)
 * @param eps    presumably the neighborhood radius; TODO confirm metric and units
 * @param max_k  optional host scalar view, defaults to `std::nullopt`; presumably a cap on
 *               (or report of) the neighbors per query - TODO confirm whether in, out or both
 */
template <typename idx_t, typename value_t, typename int_t, typename matrix_idx_t>
void eps_nn(raft::resources const& handle,
const BallCoverIndex<idx_t, value_t, int_t, matrix_idx_t>& index,
raft::device_vector_view<idx_t, matrix_idx_t> adj_ia,
raft::device_vector_view<idx_t, matrix_idx_t> adj_ja,
raft::device_vector_view<idx_t, matrix_idx_t> vd,
raft::device_matrix_view<const value_t, matrix_idx_t, row_major> query,
value_t eps,
std::optional<raft::host_scalar_view<int_t, matrix_idx_t>> max_k = std::nullopt)
RAFT_EXPLICIT;

} // namespace raft::neighbors::ball_cover

#endif // RAFT_EXPLICIT_INSTANTIATE_ONLY
Expand All @@ -87,6 +106,24 @@ void knn_query(raft::resources const& handle,
bool perform_post_filtering, \
float weight); \
\
extern template void raft::neighbors::ball_cover::eps_nn<idx_t, value_t, int_t, matrix_idx_t>( \
raft::resources const& handle, \
const raft::neighbors::ball_cover::BallCoverIndex<idx_t, value_t, int_t, matrix_idx_t>& index, \
raft::device_matrix_view<bool, matrix_idx_t, row_major> adj, \
raft::device_vector_view<idx_t, matrix_idx_t> vd, \
raft::device_matrix_view<const value_t, matrix_idx_t, row_major> query, \
value_t eps); \
\
extern template void raft::neighbors::ball_cover::eps_nn<idx_t, value_t, int_t, matrix_idx_t>( \
raft::resources const& handle, \
const raft::neighbors::ball_cover::BallCoverIndex<idx_t, value_t, int_t, matrix_idx_t>& index, \
raft::device_vector_view<idx_t, matrix_idx_t> adj_ia, \
raft::device_vector_view<idx_t, matrix_idx_t> adj_ja, \
raft::device_vector_view<idx_t, matrix_idx_t> vd, \
raft::device_matrix_view<const value_t, matrix_idx_t, row_major> query, \
value_t eps, \
std::optional<raft::host_scalar_view<int_t, matrix_idx_t>> max_k); \
\
extern template void \
raft::neighbors::ball_cover::all_knn_query<idx_t, value_t, int_t, matrix_idx_t>( \
raft::resources const& handle, \
Expand Down Expand Up @@ -119,6 +156,6 @@ void knn_query(raft::resources const& handle,
bool perform_post_filtering, \
float weight);

instantiate_raft_neighbors_ball_cover(int64_t, float, uint32_t, uint32_t);
instantiate_raft_neighbors_ball_cover(int64_t, float, int64_t, int64_t);

#undef instantiate_raft_neighbors_ball_cover
Loading

0 comments on commit 9395be8

Please sign in to comment.