diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 0d71026e17..517e6d3f49 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -424,6 +424,7 @@ if(RAFT_COMPILE_LIBRARY) src/raft_runtime/neighbors/cagra_build.cu src/raft_runtime/neighbors/cagra_search.cu src/raft_runtime/neighbors/cagra_serialize.cu + src/raft_runtime/neighbors/eps_neighborhood.cu src/raft_runtime/neighbors/ivf_flat_build.cu src/raft_runtime/neighbors/ivf_flat_search.cu src/raft_runtime/neighbors/ivf_flat_serialize.cu @@ -443,6 +444,7 @@ if(RAFT_COMPILE_LIBRARY) src/raft_runtime/random/rmat_rectangular_generator_int64_float.cu src/raft_runtime/random/rmat_rectangular_generator_int_double.cu src/raft_runtime/random/rmat_rectangular_generator_int_float.cu + src/spatial/knn/detail/ball_cover/registers_eps_pass_euclidean.cu src/spatial/knn/detail/ball_cover/registers_pass_one_2d_dist.cu src/spatial/knn/detail/ball_cover/registers_pass_one_2d_euclidean.cu src/spatial/knn/detail/ball_cover/registers_pass_one_2d_haversine.cu diff --git a/cpp/bench/ann/src/common/benchmark.hpp b/cpp/bench/ann/src/common/benchmark.hpp index e61de6745e..53f31d6232 100644 --- a/cpp/bench/ann/src/common/benchmark.hpp +++ b/cpp/bench/ann/src/common/benchmark.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -287,11 +287,11 @@ void bench_search(::benchmark::State& state, std::make_shared>(current_algo_props->query_memory_type, k * query_set_size); cuda_timer gpu_timer; - auto start = std::chrono::high_resolution_clock::now(); { nvtx_case nvtx{state.name()}; - auto algo = dynamic_cast*>(current_algo.get())->copy(); + auto algo = dynamic_cast*>(current_algo.get())->copy(); + auto start = std::chrono::high_resolution_clock::now(); for (auto _ : state) { [[maybe_unused]] auto ntx_lap = nvtx.lap(); [[maybe_unused]] auto gpu_lap = gpu_timer.lap(); @@ -314,17 +314,15 @@ void bench_search(::benchmark::State& state, queries_processed += n_queries; } + auto end = std::chrono::high_resolution_clock::now(); + auto duration = std::chrono::duration_cast>(end - start).count(); + if (state.thread_index() == 0) { state.counters.insert({{"end_to_end", duration}}); } + state.counters.insert({"Latency", {duration, benchmark::Counter::kAvgIterations}}); } - auto end = std::chrono::high_resolution_clock::now(); - auto duration = std::chrono::duration_cast>(end - start).count(); - if (state.thread_index() == 0) { state.counters.insert({{"end_to_end", duration}}); } - state.counters.insert( - {"Latency", {duration / double(state.iterations()), benchmark::Counter::kAvgThreads}}); state.SetItemsProcessed(queries_processed); if (cudart.found()) { - double gpu_time_per_iteration = gpu_timer.total_time() / (double)state.iterations(); - state.counters.insert({"GPU", {gpu_time_per_iteration, benchmark::Counter::kAvgThreads}}); + state.counters.insert({"GPU", {gpu_timer.total_time(), benchmark::Counter::kAvgIterations}}); } // This will be the total number of queries across all threads diff --git a/cpp/bench/ann/src/common/cuda_huge_page_resource.hpp b/cpp/bench/ann/src/common/cuda_huge_page_resource.hpp index 9132db7c04..f7088c7271 100644 --- a/cpp/bench/ann/src/common/cuda_huge_page_resource.hpp +++ b/cpp/bench/ann/src/common/cuda_huge_page_resource.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -49,13 +49,6 @@ class cuda_huge_page_resource final : public rmm::mr::device_memory_resource { */ [[nodiscard]] bool supports_streams() const noexcept override { return false; } - /** - * @brief Query whether the resource supports the get_mem_info API. - * - * @return true - */ - [[nodiscard]] bool supports_get_mem_info() const noexcept override { return true; } - private: /** * @brief Allocates memory of size at least `bytes` using cudaMalloc. @@ -112,21 +105,5 @@ class cuda_huge_page_resource final : public rmm::mr::device_memory_resource { { return dynamic_cast(&other) != nullptr; } - - /** - * @brief Get free and available memory for memory resource - * - * @throws `rmm::cuda_error` if unable to retrieve memory info. - * - * @return std::pair contaiing free_size and total_size of memory - */ - [[nodiscard]] std::pair do_get_mem_info( - rmm::cuda_stream_view) const override - { - std::size_t free_size{}; - std::size_t total_size{}; - RMM_CUDA_TRY(cudaMemGetInfo(&free_size, &total_size)); - return std::make_pair(free_size, total_size); - } }; -} // namespace raft::mr \ No newline at end of file +} // namespace raft::mr diff --git a/cpp/bench/ann/src/common/cuda_pinned_resource.hpp b/cpp/bench/ann/src/common/cuda_pinned_resource.hpp index 28ca691f86..ab207a36fe 100644 --- a/cpp/bench/ann/src/common/cuda_pinned_resource.hpp +++ b/cpp/bench/ann/src/common/cuda_pinned_resource.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -53,13 +53,6 @@ class cuda_pinned_resource final : public rmm::mr::device_memory_resource { */ [[nodiscard]] bool supports_streams() const noexcept override { return false; } - /** - * @brief Query whether the resource supports the get_mem_info API. - * - * @return true - */ - [[nodiscard]] bool supports_get_mem_info() const noexcept override { return true; } - private: /** * @brief Allocates memory of size at least `bytes` using cudaMalloc. @@ -110,21 +103,5 @@ class cuda_pinned_resource final : public rmm::mr::device_memory_resource { { return dynamic_cast(&other) != nullptr; } - - /** - * @brief Get free and available memory for memory resource - * - * @throws `rmm::cuda_error` if unable to retrieve memory info. - * - * @return std::pair contaiing free_size and total_size of memory - */ - [[nodiscard]] std::pair do_get_mem_info( - rmm::cuda_stream_view) const override - { - std::size_t free_size{}; - std::size_t total_size{}; - RMM_CUDA_TRY(cudaMemGetInfo(&free_size, &total_size)); - return std::make_pair(free_size, total_size); - } }; -} // namespace raft::mr \ No newline at end of file +} // namespace raft::mr diff --git a/cpp/include/raft/core/math.hpp b/cpp/include/raft/core/math.hpp index 56a8d78926..809b2948e7 100644 --- a/cpp/include/raft/core/math.hpp +++ b/cpp/include/raft/core/math.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -49,12 +49,42 @@ RAFT_INLINE_FUNCTION auto abs(T x) template constexpr RAFT_INLINE_FUNCTION auto abs(T x) -> std::enable_if_t && !std::is_same_v && +#if defined(_RAFT_HAS_CUDA) + !std::is_same_v<__half, T> && !std::is_same_v && +#endif !std::is_same_v && !std::is_same_v && !std::is_same_v, T> { return x < T{0} ? -x : x; } +#if defined(_RAFT_HAS_CUDA) +template +RAFT_DEVICE_INLINE_FUNCTION typename std::enable_if_t, __half> abs(T x) +{ +#if (__CUDA_ARCH__ >= 530) + return ::__habs(x); +#else + // Fail during template instantiation if the compute capability doesn't support this operation + static_assert(sizeof(T) != sizeof(T), "__half is only supported on __CUDA_ARCH__ >= 530"); + return T{}; +#endif +} + +template +RAFT_DEVICE_INLINE_FUNCTION typename std::enable_if_t, nv_bfloat16> +abs(T x) +{ +#if (__CUDA_ARCH__ >= 800) + return ::__habs(x); +#else + // Fail during template instantiation if the compute capability doesn't support this operation + static_assert(sizeof(T) != sizeof(T), "nv_bfloat16 is only supported on __CUDA_ARCH__ >= 800"); + return T{}; +#endif +} +#endif +/** @} */ /** Inverse cosine */ template diff --git a/cpp/include/raft/neighbors/ball_cover-ext.cuh b/cpp/include/raft/neighbors/ball_cover-ext.cuh index bc5fe934ab..3d0b3c7858 100644 --- a/cpp/include/raft/neighbors/ball_cover-ext.cuh +++ b/cpp/include/raft/neighbors/ball_cover-ext.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * Copyright (c) 2021-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -67,6 +67,25 @@ void knn_query(raft::resources const& handle, bool perform_post_filtering = true, float weight = 1.0) RAFT_EXPLICIT; +template +void eps_nn(raft::resources const& handle, + const BallCoverIndex& index, + raft::device_matrix_view adj, + raft::device_vector_view vd, + raft::device_matrix_view query, + value_t eps) RAFT_EXPLICIT; + +template +void eps_nn(raft::resources const& handle, + const BallCoverIndex& index, + raft::device_vector_view adj_ia, + raft::device_vector_view adj_ja, + raft::device_vector_view vd, + raft::device_matrix_view query, + value_t eps, + std::optional> max_k = std::nullopt) + RAFT_EXPLICIT; + } // namespace raft::neighbors::ball_cover #endif // RAFT_EXPLICIT_INSTANTIATE_ONLY @@ -87,6 +106,24 @@ void knn_query(raft::resources const& handle, bool perform_post_filtering, \ float weight); \ \ + extern template void raft::neighbors::ball_cover::eps_nn( \ + raft::resources const& handle, \ + const raft::neighbors::ball_cover::BallCoverIndex& index, \ + raft::device_matrix_view adj, \ + raft::device_vector_view vd, \ + raft::device_matrix_view query, \ + value_t eps); \ + \ + extern template void raft::neighbors::ball_cover::eps_nn( \ + raft::resources const& handle, \ + const raft::neighbors::ball_cover::BallCoverIndex& index, \ + raft::device_vector_view adj_ia, \ + raft::device_vector_view adj_ja, \ + raft::device_vector_view vd, \ + raft::device_matrix_view query, \ + value_t eps, \ + std::optional> max_k); \ + \ extern template void \ raft::neighbors::ball_cover::all_knn_query( \ raft::resources const& handle, \ @@ -119,6 +156,6 @@ void knn_query(raft::resources const& handle, bool perform_post_filtering, \ float weight); -instantiate_raft_neighbors_ball_cover(int64_t, float, uint32_t, uint32_t); +instantiate_raft_neighbors_ball_cover(int64_t, float, int64_t, int64_t); #undef instantiate_raft_neighbors_ball_cover diff --git a/cpp/include/raft/neighbors/ball_cover-inl.cuh b/cpp/include/raft/neighbors/ball_cover-inl.cuh index d35c1dc614..cdf7c30e89 100644 --- a/cpp/include/raft/neighbors/ball_cover-inl.cuh +++ b/cpp/include/raft/neighbors/ball_cover-inl.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * Copyright (c) 2021-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -63,7 +63,6 @@ template & index) { - ASSERT(index.n <= 3, "only 2d and 3d vectors are supported in current implementation"); if (index.metric == raft::distance::DistanceType::Haversine) { raft::spatial::knn::detail::rbc_build_index( handle, index, spatial::knn::detail::HaversineFunc()); @@ -255,9 +254,9 @@ void all_knn_query(raft::resources const& handle, * looking in the closest landmark. * @param[in] n_query_pts number of query points */ -template +template void knn_query(raft::resources const& handle, - const BallCoverIndex& index, + const BallCoverIndex& index, int_t k, const value_t* query, int_t n_query_pts, @@ -295,6 +294,106 @@ void knn_query(raft::resources const& handle, } } +/** + * @brief Computes epsilon neighborhood for the L2 distance metric using rbc + * + * @tparam value_t IO and math type + * @tparam idx_t Index type + * + * @param[in] handle raft handle for resource management + * @param[in] index ball cover index which has been built + * @param[out] adj adjacency matrix [row-major] [on device] [dim = m x n] + * @param[out] vd vertex degree array [on device] [len = m + 1] + * `vd + m` stores the total number of edges in the adjacency + * matrix. Pass a nullptr if you don't need this info. + * @param[in] query first matrix [row-major] [on device] [dim = m x k] + * @param[in] eps defines epsilon neighborhood radius + */ +template +void eps_nn(raft::resources const& handle, + const BallCoverIndex& index, + raft::device_matrix_view adj, + raft::device_vector_view vd, + raft::device_matrix_view query, + value_t eps) +{ + ASSERT(index.n == query.extent(1), "vector dimension needs to be the same for index and queries"); + ASSERT(index.metric == raft::distance::DistanceType::L2SqrtExpanded || + index.metric == raft::distance::DistanceType::L2SqrtUnexpanded, + "Metric not supported"); + ASSERT(index.is_index_trained(), "index must be previously trained"); + + // run query + raft::spatial::knn::detail::rbc_eps_nn_query( + handle, + index, + eps, + query.data_handle(), + query.extent(0), + adj.data_handle(), + vd.data_handle(), + spatial::knn::detail::EuclideanFunc()); +} + +/** + * @brief Computes epsilon neighborhood for the L2 distance metric using rbc + * + * @tparam value_t IO and math type + * @tparam idx_t Index type + * + * @param[in] handle raft handle for resource management + * @param[in] index ball cover index which has been built + * @param[out] adj_ia adjacency matrix CSR row offsets + * @param[out] adj_ja adjacency matrix CSR column indices, needs to be nullptr + * in first pass with max_k nullopt + * @param[out] vd vertex degree array [on device] [len = m + 1] + * `vd + m` stores the total number of edges in the adjacency + * matrix. Pass a nullptr if you don't need this info. + * @param[in] query first matrix [row-major] [on device] [dim = m x k] + * @param[in] eps defines epsilon neighborhood radius + * @param[inout] max_k if nullopt (default), the user needs to make 2 subsequent calls: + * The first call computes row offsets in adj_ia, where adj_ia[m] + * contains the minimum required size for adj_ja. + * The second call fills in adj_ja based on adj_ia. + * If max_k != nullopt the algorithm only fills up neighbors up to a + * maximum number of max_k for each row in a single pass. Note + * that it is not guarantueed to return the nearest neighbors. + * Upon return max_k is overwritten with the actual max_k found during + * computation. + */ +template +void eps_nn(raft::resources const& handle, + const BallCoverIndex& index, + raft::device_vector_view adj_ia, + raft::device_vector_view adj_ja, + raft::device_vector_view vd, + raft::device_matrix_view query, + value_t eps, + std::optional> max_k = std::nullopt) +{ + ASSERT(index.n == query.extent(1), "vector dimension needs to be the same for index and queries"); + ASSERT(index.metric == raft::distance::DistanceType::L2SqrtExpanded || + index.metric == raft::distance::DistanceType::L2SqrtUnexpanded, + "Metric not supported"); + ASSERT(index.is_index_trained(), "index must be previously trained"); + + int_t* max_k_ptr = nullptr; + if (max_k.has_value()) { max_k_ptr = max_k.value().data_handle(); } + + // run query + raft::spatial::knn::detail::rbc_eps_nn_query( + handle, + index, + eps, + max_k_ptr, + query.data_handle(), + query.extent(0), + adj_ia.data_handle(), + adj_ja.data_handle(), + vd.data_handle(), + spatial::knn::detail::EuclideanFunc()); +} + /** * @ingroup random_ball_cover * @{ @@ -377,7 +476,7 @@ void knn_query(raft::resources const& handle, index, k, query.data_handle(), - query.extent(0), + (int_t)query.extent(0), inds.data_handle(), dists.data_handle(), perform_post_filtering, diff --git a/cpp/include/raft/neighbors/ball_cover.cuh b/cpp/include/raft/neighbors/ball_cover.cuh index 41c5d0310c..20c88f3318 100644 --- a/cpp/include/raft/neighbors/ball_cover.cuh +++ b/cpp/include/raft/neighbors/ball_cover.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * Copyright (c) 2021-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,7 +14,6 @@ * limitations under the License. */ #pragma once - #ifndef RAFT_EXPLICIT_INSTANTIATE_ONLY #include "ball_cover-inl.cuh" #endif diff --git a/cpp/include/raft/neighbors/ball_cover_types.hpp b/cpp/include/raft/neighbors/ball_cover_types.hpp index 0a6ad8c407..dc96f0d45b 100644 --- a/cpp/include/raft/neighbors/ball_cover_types.hpp +++ b/cpp/include/raft/neighbors/ball_cover_types.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * Copyright (c) 2021-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,6 +19,7 @@ #include #include #include +#include #include #include #include @@ -41,8 +42,8 @@ namespace raft::neighbors::ball_cover { */ template + typename value_int = std::int64_t, + typename matrix_idx = std::int64_t> class BallCoverIndex { public: explicit BallCoverIndex(raft::resources const& handle_, diff --git a/cpp/include/raft/neighbors/brute_force-ext.cuh b/cpp/include/raft/neighbors/brute_force-ext.cuh index ddce6d8fda..e8c25f355b 100644 --- a/cpp/include/raft/neighbors/brute_force-ext.cuh +++ b/cpp/include/raft/neighbors/brute_force-ext.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. + * Copyright (c) 2020-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/cpp/include/raft/neighbors/detail/knn_brute_force.cuh b/cpp/include/raft/neighbors/detail/knn_brute_force.cuh index dff9aceb8d..adcb639301 100644 --- a/cpp/include/raft/neighbors/detail/knn_brute_force.cuh +++ b/cpp/include/raft/neighbors/detail/knn_brute_force.cuh @@ -443,13 +443,13 @@ void brute_force_knn_impl( if (metric == raft::distance::DistanceType::L2SqrtExpanded || metric == raft::distance::DistanceType::L2SqrtUnexpanded || metric == raft::distance::DistanceType::LpUnexpanded) { - float p = 0.5; // standard l2 + value_t p = 0.5; // standard l2 if (metric == raft::distance::DistanceType::LpUnexpanded) p = 1.0 / metricArg; - raft::linalg::unaryOp( + raft::linalg::unaryOp( res_D, res_D, n * k, - [p] __device__(float input) { return powf(fabsf(input), p); }, + [p] __device__(value_t input) { return powf(fabsf(input), p); }, stream); } } else { diff --git a/cpp/include/raft/neighbors/detail/knn_merge_parts.cuh b/cpp/include/raft/neighbors/detail/knn_merge_parts.cuh index 3d03d6db4f..c8ff03741c 100644 --- a/cpp/include/raft/neighbors/detail/knn_merge_parts.cuh +++ b/cpp/include/raft/neighbors/detail/knn_merge_parts.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -111,7 +111,7 @@ inline void knn_merge_parts_impl(const value_t* inK, { auto grid = dim3(n_samples); - constexpr int n_threads = (warp_q <= 1024) ? 128 : 64; + constexpr int n_threads = (warp_q < 1024) ? 128 : 64; auto block = dim3(n_threads); auto kInit = std::numeric_limits::max(); diff --git a/cpp/include/raft/sparse/linalg/spmm.hpp b/cpp/include/raft/sparse/linalg/spmm.hpp index c2fdd64574..03c97fdb9d 100644 --- a/cpp/include/raft/sparse/linalg/spmm.hpp +++ b/cpp/include/raft/sparse/linalg/spmm.hpp @@ -42,7 +42,7 @@ namespace linalg { * @param[in] x input raft::device_csr_matrix_view * @param[in] y input raft::device_matrix_view * @param[in] beta scalar - * @param[out] z output raft::device_matrix_view + * @param[inout] z input-output raft::device_matrix_view */ template ( + z.data_handle(), z.extent(0), z.extent(1), is_row_major ? z.stride(0) : z.stride(1)); + auto descr_x = detail::create_descriptor(x); auto descr_y = detail::create_descriptor(y); - auto descr_z = detail::create_descriptor(z); + auto descr_z = detail::create_descriptor(z_tmp_view); detail::spmm(handle, trans_x, trans_y, is_row_major, alpha, descr_x, descr_y, beta, descr_z); @@ -76,4 +79,4 @@ void spmm(raft::resources const& handle, } // end namespace sparse } // end namespace raft -#endif +#endif \ No newline at end of file diff --git a/cpp/include/raft/spatial/knn/detail/ball_cover.cuh b/cpp/include/raft/spatial/knn/detail/ball_cover.cuh index 4fe60e304b..879f54fd81 100644 --- a/cpp/include/raft/spatial/knn/detail/ball_cover.cuh +++ b/cpp/include/raft/spatial/knn/detail/ball_cover.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * Copyright (c) 2021-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -32,6 +32,7 @@ #include +#include #include #include #include @@ -64,9 +65,12 @@ namespace detail { * @param handle * @param index */ -template +template void sample_landmarks(raft::resources const& handle, - BallCoverIndex& index) + BallCoverIndex& index) { rmm::device_uvector R_1nn_cols2(index.n_landmarks, resource::get_cuda_stream(handle)); rmm::device_uvector R_1nn_ones(index.m, resource::get_cuda_stream(handle)); @@ -100,8 +104,6 @@ void sample_landmarks(raft::resources const& handle, (value_idx)index.n_landmarks, (value_idx)index.m); - // index.get_X() returns the wrong indextype (uint32_t where we need value_idx), so need to - // create new device_matrix_view here auto x = index.get_X(); auto r = index.get_R(); @@ -122,12 +124,15 @@ void sample_landmarks(raft::resources const& handle, * @param k * @param index */ -template +template void construct_landmark_1nn(raft::resources const& handle, const value_idx* R_knn_inds_ptr, const value_t* R_knn_dists_ptr, value_int k, - BallCoverIndex& index) + BallCoverIndex& index) { rmm::device_uvector R_1nn_inds(index.m, resource::get_cuda_stream(handle)); @@ -177,9 +182,12 @@ void construct_landmark_1nn(raft::resources const& handle, * @param R_knn_inds * @param R_knn_dists */ -template +template void k_closest_landmarks(raft::resources const& handle, - const BallCoverIndex& index, + const BallCoverIndex& index, const value_t* query_pts, value_int n_query_pts, value_int k, @@ -205,9 +213,12 @@ void k_closest_landmarks(raft::resources const& handle, * @param handle * @param index */ -template +template void compute_landmark_radii(raft::resources const& handle, - BallCoverIndex& index) + BallCoverIndex& index) { auto entries = thrust::make_counting_iterator(0); @@ -235,13 +246,14 @@ void compute_landmark_radii(raft::resources const& handle, */ template void perform_rbc_query(raft::resources const& handle, - const BallCoverIndex& index, + const BallCoverIndex& index, const value_t* query, value_int n_query_pts, - std::uint32_t k, + value_int k, const value_idx* R_knn_inds, const value_t* R_knn_dists, dist_func dfunc, @@ -264,66 +276,128 @@ void perform_rbc_query(raft::resources const& handle, if (index.n == 2) { // Compute nearest k for each neighborhood in each closest R - rbc_low_dim_pass_one(handle, - index, - query, - n_query_pts, - k, - R_knn_inds, - R_knn_dists, - dfunc, - inds, - dists, - weight, - dists_counter); + rbc_low_dim_pass_one(handle, + index, + query, + n_query_pts, + k, + R_knn_inds, + R_knn_dists, + dfunc, + inds, + dists, + weight, + dists_counter); if (perform_post_filtering) { - rbc_low_dim_pass_two(handle, - index, - query, - n_query_pts, - k, - R_knn_inds, - R_knn_dists, - dfunc, - inds, - dists, - weight, - post_dists_counter); + rbc_low_dim_pass_two(handle, + index, + query, + n_query_pts, + k, + R_knn_inds, + R_knn_dists, + dfunc, + inds, + dists, + weight, + post_dists_counter); } } else if (index.n == 3) { // Compute nearest k for each neighborhood in each closest R - rbc_low_dim_pass_one(handle, - index, - query, - n_query_pts, - k, - R_knn_inds, - R_knn_dists, - dfunc, - inds, - dists, - weight, - dists_counter); + rbc_low_dim_pass_one(handle, + index, + query, + n_query_pts, + k, + R_knn_inds, + R_knn_dists, + dfunc, + inds, + dists, + weight, + dists_counter); if (perform_post_filtering) { - rbc_low_dim_pass_two(handle, - index, - query, - n_query_pts, - k, - R_knn_inds, - R_knn_dists, - dfunc, - inds, - dists, - weight, - post_dists_counter); + rbc_low_dim_pass_two(handle, + index, + query, + n_query_pts, + k, + R_knn_inds, + R_knn_dists, + dfunc, + inds, + dists, + weight, + post_dists_counter); } } } +/** + * Perform eps-select + * + * a. Map 1 row to each warp/block + * b. Add closest k R points to heap + * c. Iterate through batches of R, having each thread in the warp load a set + * of distances y from R (only if d(q, r) < 3 * distance to closest r) and + * marking the distance to be computed between x, y only + * if knn[k].distance >= d(x_i, R_k) + d(R_k, y) + */ +template +void perform_rbc_eps_nn_query( + raft::resources const& handle, + const BallCoverIndex& index, + const value_t* query, + value_int n_query_pts, + value_t eps, + const value_t* landmark_dists, + dist_func dfunc, + bool* adj, + value_idx* vd) +{ + // initialize output + RAFT_CUDA_TRY(cudaMemsetAsync( + adj, 0, index.m * n_query_pts * sizeof(bool), resource::get_cuda_stream(handle))); + + resource::sync_stream(handle); + + rbc_eps_pass( + handle, index, query, n_query_pts, eps, landmark_dists, dfunc, adj, vd); + + resource::sync_stream(handle); +} + +template +void perform_rbc_eps_nn_query( + raft::resources const& handle, + const BallCoverIndex& index, + const value_t* query, + value_int n_query_pts, + value_t eps, + value_int* max_k, + const value_t* landmark_dists, + dist_func dfunc, + value_idx* adj_ia, + value_idx* adj_ja, + value_idx* vd) +{ + rbc_eps_pass( + handle, index, query, n_query_pts, eps, max_k, landmark_dists, dfunc, adj_ia, adj_ja, vd); + + resource::sync_stream(handle); +} + /** * Similar to a ball tree, the random ball cover algorithm * uses the triangle inequality to prune distance computations @@ -337,13 +411,13 @@ void perform_rbc_query(raft::resources const& handle, */ template void rbc_build_index(raft::resources const& handle, - BallCoverIndex& index, + BallCoverIndex& index, distance_func dfunc) { - ASSERT(index.n <= 3, "only 2d and 3d vectors are supported in current implementation"); ASSERT(!index.is_index_trained(), "index cannot be previously trained"); rmm::device_uvector R_knn_inds(index.m, resource::get_cuda_stream(handle)); @@ -396,10 +470,11 @@ void rbc_build_index(raft::resources const& handle, */ template void rbc_all_knn_query(raft::resources const& handle, - BallCoverIndex& index, + BallCoverIndex& index, value_int k, value_idx* inds, value_t* dists, @@ -469,10 +544,11 @@ void rbc_all_knn_query(raft::resources const& handle, */ template void rbc_knn_query(raft::resources const& handle, - const BallCoverIndex& index, + const BallCoverIndex& index, value_int k, const value_t* query, value_int n_query_pts, @@ -539,6 +615,106 @@ void rbc_knn_query(raft::resources const& handle, perform_post_filtering); } +template +void compute_landmark_dists(raft::resources const& handle, + const BallCoverIndex& index, + const value_t* query_pts, + value_int n_query_pts, + value_t* R_dists) +{ + // compute distances for all queries against all landmarks + // index.get_R() -- landmark points in row order (index.n_landmarks x index.k) + // query_pts -- query points in row order (n_query_pts x index.k) + RAFT_EXPECTS(std::max(index.n_landmarks, n_query_pts) * index.n < + static_cast(std::numeric_limits::max()), + "Too large input for pairwise_distance with `int` index."); + RAFT_EXPECTS(n_query_pts * static_cast(index.n_landmarks) < + static_cast(std::numeric_limits::max()), + "Too large input for pairwise_distance with `int` index."); + raft::distance::pairwise_distance(handle, + query_pts, + index.get_R().data_handle(), + R_dists, + n_query_pts, + index.n_landmarks, + index.n, + index.get_metric()); +} + +/** + * Performs a knn query against an index. This assumes the index has + * already been built. + * Modified version that takes an eps as threshold and outputs to a dense adj matrix (row-major) + * we are assuming that there are sufficiently many landmarks + */ +template +void rbc_eps_nn_query(raft::resources const& handle, + const BallCoverIndex& index, + const value_t eps, + const value_t* query, + value_int n_query_pts, + bool* adj, + value_idx* vd, + distance_func dfunc) +{ + ASSERT(index.is_index_trained(), "index must be previously trained"); + + auto R_dists = + raft::make_device_matrix(handle, index.n_landmarks, n_query_pts); + + // find all landmarks that might have points in range + compute_landmark_dists(handle, index, query, n_query_pts, R_dists.data_handle()); + + // query all points and write to adj + perform_rbc_eps_nn_query( + handle, index, query, n_query_pts, eps, R_dists.data_handle(), dfunc, adj, vd); +} + +template +void rbc_eps_nn_query(raft::resources const& handle, + const BallCoverIndex& index, + const value_t eps, + value_int* max_k, + const value_t* query, + value_int n_query_pts, + value_idx* adj_ia, + value_idx* adj_ja, + value_idx* vd, + distance_func dfunc) +{ + ASSERT(index.is_index_trained(), "index must be previously trained"); + + auto R_dists = + raft::make_device_matrix(handle, index.n_landmarks, n_query_pts); + + // find all landmarks that might have points in range + compute_landmark_dists(handle, index, query, n_query_pts, R_dists.data_handle()); + + // query all points and write to adj + perform_rbc_eps_nn_query(handle, + index, + query, + n_query_pts, + eps, + max_k, + R_dists.data_handle(), + dfunc, + adj_ia, + adj_ja, + vd); +} + }; // namespace detail }; // namespace knn }; // namespace spatial diff --git a/cpp/include/raft/spatial/knn/detail/ball_cover/registers-ext.cuh b/cpp/include/raft/spatial/knn/detail/ball_cover/registers-ext.cuh index 70c5cec23f..2ed6ee3284 100644 --- a/cpp/include/raft/spatial/knn/detail/ball_cover/registers-ext.cuh +++ b/cpp/include/raft/spatial/knn/detail/ball_cover/registers-ext.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * Copyright (c) 2021-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -27,11 +27,12 @@ namespace raft::spatial::knn::detail { template void rbc_low_dim_pass_one(raft::resources const& handle, - const BallCoverIndex& index, + const BallCoverIndex& index, const value_t* query, const value_int n_query_rows, value_int k, @@ -45,11 +46,12 @@ void rbc_low_dim_pass_one(raft::resources const& handle, template void rbc_low_dim_pass_two(raft::resources const& handle, - const BallCoverIndex& index, + const BallCoverIndex& index, const value_t* query, const value_int n_query_rows, value_int k, @@ -61,69 +63,133 @@ void rbc_low_dim_pass_two(raft::resources const& handle, float weight, value_int* post_dists_counter) RAFT_EXPLICIT; +template +void rbc_eps_pass(raft::resources const& handle, + const BallCoverIndex& index, + const value_t* query, + const value_int n_query_rows, + value_t eps, + const value_t* R_dists, + dist_func& dfunc, + bool* adj, + value_idx* vd) RAFT_EXPLICIT; + +template +void rbc_eps_pass(raft::resources const& handle, + const BallCoverIndex& index, + const value_t* query, + const value_int n_query_rows, + value_t eps, + value_int* max_k, + const value_t* R_dists, + dist_func& dfunc, + value_idx* adj_ia, + value_idx* adj_ja, + value_idx* vd) RAFT_EXPLICIT; + }; // namespace raft::spatial::knn::detail #endif // RAFT_EXPLICIT_INSTANTIATE_ONLY -#define instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one( \ - Mvalue_idx, Mvalue_t, Mvalue_int, Mdims, Mdist_func) \ - extern template void \ - raft::spatial::knn::detail::rbc_low_dim_pass_one( \ - raft::resources const& handle, \ - const BallCoverIndex& index, \ - const Mvalue_t* query, \ - const Mvalue_int n_query_rows, \ - Mvalue_int k, \ - const Mvalue_idx* R_knn_inds, \ - const Mvalue_t* R_knn_dists, \ - Mdist_func& dfunc, \ - Mvalue_idx* inds, \ - Mvalue_t* dists, \ - float weight, \ - Mvalue_int* dists_counter) - -#define instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two( \ - Mvalue_idx, Mvalue_t, Mvalue_int, Mdims, Mdist_func) \ - extern template void \ - raft::spatial::knn::detail::rbc_low_dim_pass_two( \ - raft::resources const& handle, \ - const BallCoverIndex& index, \ - const Mvalue_t* query, \ - const Mvalue_int n_query_rows, \ - Mvalue_int k, \ - const Mvalue_idx* R_knn_inds, \ - const Mvalue_t* R_knn_dists, \ - Mdist_func& dfunc, \ - Mvalue_idx* inds, \ - Mvalue_t* dists, \ - float weight, \ - Mvalue_int* dists_counter) +#define instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one( \ + Mvalue_idx, Mvalue_t, Mvalue_int, Mmatrix_idx, Mdims, Mdist_func) \ + extern template void raft::spatial::knn::detail:: \ + rbc_low_dim_pass_one( \ + raft::resources const& handle, \ + const BallCoverIndex& index, \ + const Mvalue_t* query, \ + const Mvalue_int n_query_rows, \ + Mvalue_int k, \ + const Mvalue_idx* R_knn_inds, \ + const Mvalue_t* R_knn_dists, \ + Mdist_func& dfunc, \ + Mvalue_idx* inds, \ + Mvalue_t* dists, \ + float weight, \ + Mvalue_int* dists_counter) + +#define instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two( \ + Mvalue_idx, Mvalue_t, Mvalue_int, Mmatrix_idx, Mdims, Mdist_func) \ + extern template void raft::spatial::knn::detail:: \ + rbc_low_dim_pass_two( \ + raft::resources const& handle, \ + const BallCoverIndex& index, \ + const Mvalue_t* query, \ + const Mvalue_int n_query_rows, \ + Mvalue_int k, \ + const Mvalue_idx* R_knn_inds, \ + const Mvalue_t* R_knn_dists, \ + Mdist_func& dfunc, \ + Mvalue_idx* inds, \ + Mvalue_t* dists, \ + float weight, \ + Mvalue_int* dists_counter) + +#define instantiate_raft_spatial_knn_detail_rbc_eps_pass( \ + Mvalue_idx, Mvalue_t, Mvalue_int, Mmatrix_idx, Mdist_func) \ + extern template void \ + raft::spatial::knn::detail::rbc_eps_pass( \ + raft::resources const& handle, \ + const BallCoverIndex& index, \ + const Mvalue_t* query, \ + const Mvalue_int n_query_rows, \ + Mvalue_t eps, \ + const Mvalue_t* R_dists, \ + Mdist_func& dfunc, \ + bool* adj, \ + Mvalue_idx* vd); \ + \ + extern template void \ + raft::spatial::knn::detail::rbc_eps_pass( \ + raft::resources const& handle, \ + const BallCoverIndex& index, \ + const Mvalue_t* query, \ + const Mvalue_int n_query_rows, \ + Mvalue_t eps, \ + Mvalue_int* max_k, \ + const Mvalue_t* R_dists, \ + Mdist_func& dfunc, \ + Mvalue_idx* adj_ia, \ + Mvalue_idx* adj_ja, \ + Mvalue_idx* vd); instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one( - std::int64_t, float, std::uint32_t, 2, raft::spatial::knn::detail::HaversineFunc); + std::int64_t, float, std::int64_t, std::int64_t, 2, raft::spatial::knn::detail::HaversineFunc); instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one( - std::int64_t, float, std::uint32_t, 3, raft::spatial::knn::detail::HaversineFunc); + std::int64_t, float, std::int64_t, std::int64_t, 3, raft::spatial::knn::detail::HaversineFunc); instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one( - std::int64_t, float, std::uint32_t, 2, raft::spatial::knn::detail::EuclideanFunc); + std::int64_t, float, std::int64_t, std::int64_t, 2, raft::spatial::knn::detail::EuclideanFunc); instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one( - std::int64_t, float, std::uint32_t, 3, raft::spatial::knn::detail::EuclideanFunc); + std::int64_t, float, std::int64_t, std::int64_t, 3, raft::spatial::knn::detail::EuclideanFunc); instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one( - std::int64_t, float, std::uint32_t, 2, raft::spatial::knn::detail::DistFunc); + std::int64_t, float, std::int64_t, std::int64_t, 2, raft::spatial::knn::detail::DistFunc); instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one( - std::int64_t, float, std::uint32_t, 3, raft::spatial::knn::detail::DistFunc); + std::int64_t, float, std::int64_t, std::int64_t, 3, raft::spatial::knn::detail::DistFunc); instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two( - std::int64_t, float, std::uint32_t, 2, raft::spatial::knn::detail::HaversineFunc); + std::int64_t, float, std::int64_t, std::int64_t, 2, raft::spatial::knn::detail::HaversineFunc); instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two( - std::int64_t, float, std::uint32_t, 3, raft::spatial::knn::detail::HaversineFunc); + std::int64_t, float, std::int64_t, std::int64_t, 3, raft::spatial::knn::detail::HaversineFunc); instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two( - std::int64_t, float, std::uint32_t, 2, raft::spatial::knn::detail::EuclideanFunc); + std::int64_t, float, std::int64_t, std::int64_t, 2, raft::spatial::knn::detail::EuclideanFunc); instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two( - std::int64_t, float, std::uint32_t, 3, raft::spatial::knn::detail::EuclideanFunc); + std::int64_t, float, std::int64_t, std::int64_t, 3, raft::spatial::knn::detail::EuclideanFunc); instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two( - std::int64_t, float, std::uint32_t, 2, raft::spatial::knn::detail::DistFunc); + std::int64_t, float, std::int64_t, std::int64_t, 2, raft::spatial::knn::detail::DistFunc); instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two( - std::int64_t, float, std::uint32_t, 3, raft::spatial::knn::detail::DistFunc); + std::int64_t, float, std::int64_t, std::int64_t, 3, raft::spatial::knn::detail::DistFunc); + +instantiate_raft_spatial_knn_detail_rbc_eps_pass( + std::int64_t, float, std::int64_t, std::int64_t, raft::spatial::knn::detail::EuclideanFunc); #undef instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two #undef instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one +#undef instantiate_raft_spatial_knn_detail_rbc_eps_pass diff --git a/cpp/include/raft/spatial/knn/detail/ball_cover/registers-inl.cuh b/cpp/include/raft/spatial/knn/detail/ball_cover/registers-inl.cuh index 9e75f3c9c8..8b4e8f287e 100644 --- a/cpp/include/raft/spatial/knn/detail/ball_cover/registers-inl.cuh +++ b/cpp/include/raft/spatial/knn/detail/ball_cover/registers-inl.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * Copyright (c) 2021-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -30,7 +30,9 @@ #include #include +#include #include +#include namespace raft { namespace spatial { @@ -454,13 +456,259 @@ RAFT_KERNEL block_rbc_kernel_registers(const value_t* X_index, } } -template +RAFT_KERNEL block_rbc_kernel_eps_dense(const value_t* X_index, + const value_t* X, + const value_int n_cols, + const value_t* R_dists, + const value_int m, + const value_t eps, + const value_int n_landmarks, + const value_idx* R_indptr, + const value_idx* R_1nn_cols, + const value_t* R_1nn_dists, + const value_t* R_radius, + distance_func dfunc, + bool* adj, + value_idx* vd) +{ + __shared__ int column_count_smem; + + // initialize + if (vd != nullptr) { + if (threadIdx.x == 0) { column_count_smem = 0; } + __syncthreads(); + } + + const value_t* x_ptr = X + (n_cols * blockIdx.x); + + for (value_int cur_k = 0; cur_k < n_landmarks; ++cur_k) { + // TODO: this might also be worth computing in-place here + value_t cur_R_dist = R_dists[blockIdx.x * n_landmarks + cur_k]; + + // prune all R's that can't be within eps + if (cur_R_dist - R_radius[cur_k] > eps) continue; + + // The whole warp should iterate through the elements in the current R + value_idx R_start_offset = R_indptr[cur_k]; + value_idx R_stop_offset = R_indptr[cur_k + 1]; + + value_idx R_size = R_stop_offset - R_start_offset; + + value_int limit = Pow2::roundDown(R_size); + value_int i = threadIdx.x; + for (; i < limit; i += tpb) { + // Index and distance of current candidate's nearest landmark + value_idx cur_candidate_ind = R_1nn_cols[R_start_offset + i]; + value_t cur_candidate_dist = R_1nn_dists[R_start_offset + i]; + + const value_t* y_ptr = X_index + (n_cols * cur_candidate_ind); + if (dfunc(x_ptr, y_ptr, n_cols) <= eps) { + adj[blockIdx.x * m + cur_candidate_ind] = true; + if (vd != nullptr) atomicAdd(&column_count_smem, 1); + } + } + + if (i < R_size) { + value_idx cur_candidate_ind = R_1nn_cols[R_start_offset + i]; + value_t cur_candidate_dist = R_1nn_dists[R_start_offset + i]; + + const value_t* y_ptr = X_index + (n_cols * cur_candidate_ind); + if (dfunc(x_ptr, y_ptr, n_cols) <= eps) { + adj[blockIdx.x * m + cur_candidate_ind] = true; + if (vd != nullptr) atomicAdd(&column_count_smem, 1); + } + } + } + + if (vd != nullptr) { + __syncthreads(); + if (threadIdx.x == 0) { vd[blockIdx.x] = column_count_smem; } + } +} + +template +RAFT_KERNEL block_rbc_kernel_eps_csr_pass(const value_t* X_index, + const value_t* X, + const value_int n_cols, + const value_t* R_dists, + const value_int m, + const value_t eps, + const value_int n_landmarks, + const value_idx* R_indptr, + const value_idx* R_1nn_cols, + const value_t* R_1nn_dists, + const value_t* R_radius, + distance_func dfunc, + value_idx* adj_ia, + value_idx* adj_ja) +{ + const value_t* x_ptr = X + (n_cols * blockIdx.x); + + __shared__ unsigned long long int column_index_smem; + + bool pass2 = adj_ja != nullptr; + + // initialize + if (threadIdx.x == 0) { column_index_smem = pass2 ? adj_ia[blockIdx.x] : 0; } + + __syncthreads(); + + for (value_int cur_k = 0; cur_k < n_landmarks; ++cur_k) { + // TODO: this might also be worth computing in-place here + value_t cur_R_dist = R_dists[blockIdx.x * n_landmarks + cur_k]; + + // prune all R's that can't be within eps + if (cur_R_dist - R_radius[cur_k] > eps) continue; + + // The whole warp should iterate through the elements in the current R + value_idx R_start_offset = R_indptr[cur_k]; + value_idx R_stop_offset = R_indptr[cur_k + 1]; + + value_idx R_size = R_stop_offset - R_start_offset; + + value_int limit = Pow2::roundDown(R_size); + value_int i = threadIdx.x; + for (; i < limit; i += tpb) { + // Index and distance of current candidate's nearest landmark + value_idx cur_candidate_ind = R_1nn_cols[R_start_offset + i]; + value_t cur_candidate_dist = R_1nn_dists[R_start_offset + i]; + + const value_t* y_ptr = X_index + (n_cols * cur_candidate_ind); + if (dfunc(x_ptr, y_ptr, n_cols) <= eps) { + auto row_pos = atomicAdd(&column_index_smem, 1); + if (pass2) adj_ja[row_pos] = cur_candidate_ind; + } + } + + if (i < R_size) { + value_idx cur_candidate_ind = R_1nn_cols[R_start_offset + i]; + value_t cur_candidate_dist = R_1nn_dists[R_start_offset + i]; + + const value_t* y_ptr = X_index + (n_cols * cur_candidate_ind); + if (dfunc(x_ptr, y_ptr, n_cols) <= eps) { + auto row_pos = atomicAdd(&column_index_smem, 1); + if (pass2) adj_ja[row_pos] = cur_candidate_ind; + } + } + } + + __syncthreads(); + if (threadIdx.x == 0 && !pass2) { adj_ia[blockIdx.x] = (value_idx)column_index_smem; } +} + +template +RAFT_KERNEL block_rbc_kernel_eps_max_k(const value_t* X_index, + const value_t* X, + const value_int n_cols, + const value_t* R_dists, + const value_int m, + const value_t eps, + const value_int n_landmarks, + const value_idx* R_indptr, + const value_idx* R_1nn_cols, + const value_t* R_1nn_dists, + const value_t* R_radius, + distance_func dfunc, + value_idx* vd, + const value_int max_k, + value_idx* tmp) +{ + const value_t* x_ptr = X + (n_cols * blockIdx.x); + + __shared__ int column_count_smem; + + // initialize + if (threadIdx.x == 0) { column_count_smem = 0; } + + __syncthreads(); + + // we store all column indices in dense tmp store [blockDim.x * max_k] + value_int offset = blockIdx.x * max_k; + + for (value_int cur_k = 0; cur_k < n_landmarks; ++cur_k) { + // TODO: this might also be worth computing in-place here + value_t cur_R_dist = R_dists[blockIdx.x * n_landmarks + cur_k]; + + // prune all R's that can't be within eps + if (cur_R_dist - R_radius[cur_k] > eps) continue; + + // The whole warp should iterate through the elements in the current R + value_idx R_start_offset = R_indptr[cur_k]; + value_idx R_stop_offset = R_indptr[cur_k + 1]; + + value_idx R_size = R_stop_offset - R_start_offset; + + value_int limit = Pow2::roundDown(R_size); + value_int i = threadIdx.x; + for (; i < limit; i += tpb) { + // Index and distance of current candidate's nearest landmark + value_idx cur_candidate_ind = R_1nn_cols[R_start_offset + i]; + value_t cur_candidate_dist = R_1nn_dists[R_start_offset + i]; + + const value_t* y_ptr = X_index + (n_cols * cur_candidate_ind); + if (dfunc(x_ptr, y_ptr, n_cols) <= eps) { + int row_pos = atomicAdd(&column_count_smem, 1); + if (row_pos < max_k) tmp[row_pos + offset] = cur_candidate_ind; + } + } + + if (i < R_size) { + value_idx cur_candidate_ind = R_1nn_cols[R_start_offset + i]; + value_t cur_candidate_dist = R_1nn_dists[R_start_offset + i]; + + const value_t* y_ptr = X_index + (n_cols * cur_candidate_ind); + if (dfunc(x_ptr, y_ptr, n_cols) <= eps) { + int row_pos = atomicAdd(&column_count_smem, 1); + if (row_pos < max_k) tmp[row_pos + offset] = cur_candidate_ind; + } + } + } + + __syncthreads(); + if (threadIdx.x == 0) { vd[blockIdx.x] = column_count_smem; } +} + +template +RAFT_KERNEL block_rbc_kernel_eps_max_k_copy(const value_int max_k, + const value_idx* adj_ia, + const value_idx* tmp, + value_idx* adj_ja) +{ + value_int offset = blockIdx.x * max_k; + + value_int row_idx = blockIdx.x; + value_idx col_start_idx = adj_ia[row_idx]; + value_idx num_cols = adj_ia[row_idx + 1] - col_start_idx; + + value_int limit = Pow2::roundDown(num_cols); + value_int i = threadIdx.x; + for (; i < limit; i += tpb) { + adj_ja[col_start_idx + i] = tmp[offset + i]; + } + if (i < num_cols) { adj_ja[col_start_idx + i] = tmp[offset + i]; } +} + +template void rbc_low_dim_pass_one(raft::resources const& handle, - const BallCoverIndex& index, + const BallCoverIndex& index, const value_t* query, const value_int n_query_rows, value_int k, @@ -594,11 +842,12 @@ void rbc_low_dim_pass_one(raft::resources const& handle, template void rbc_low_dim_pass_two(raft::resources const& handle, - const BallCoverIndex& index, + const BallCoverIndex& index, const value_t* query, const value_int n_query_rows, value_int k, @@ -788,6 +1037,179 @@ void rbc_low_dim_pass_two(raft::resources const& handle, post_dists_counter); } +template +void rbc_eps_pass(raft::resources const& handle, + const BallCoverIndex& index, + const value_t* query, + const value_int n_query_rows, + value_t eps, + const value_t* R_dists, + dist_func& dfunc, + bool* adj, + value_idx* vd) +{ + block_rbc_kernel_eps_dense + <<>>( + index.get_X().data_handle(), + query, + index.n, + R_dists, + index.m, + eps, + index.n_landmarks, + index.get_R_indptr().data_handle(), + index.get_R_1nn_cols().data_handle(), + index.get_R_1nn_dists().data_handle(), + index.get_R_radius().data_handle(), + dfunc, + adj, + vd); + + if (vd != nullptr) { + value_idx sum = thrust::reduce(resource::get_thrust_policy(handle), vd, vd + n_query_rows); + // copy sum to last element + RAFT_CUDA_TRY(cudaMemcpyAsync(vd + n_query_rows, + &sum, + sizeof(value_idx), + cudaMemcpyHostToDevice, + resource::get_cuda_stream(handle))); + } + + resource::sync_stream(handle); +} + +template +void rbc_eps_pass(raft::resources const& handle, + const BallCoverIndex& index, + const value_t* query, + const value_int n_query_rows, + value_t eps, + value_int* max_k, + const value_t* R_dists, + dist_func& dfunc, + value_idx* adj_ia, + value_idx* adj_ja, + value_idx* vd) +{ + // if max_k == nullptr we are either pass 1 or pass 2 + if (max_k == nullptr) { + if (adj_ja == nullptr) { + // pass 1 -> only compute adj_ia / vd + value_idx* vd_ptr = (vd != nullptr) ? vd : adj_ia; + block_rbc_kernel_eps_csr_pass + <<>>( + index.get_X().data_handle(), + query, + index.n, + R_dists, + index.m, + eps, + index.n_landmarks, + index.get_R_indptr().data_handle(), + index.get_R_1nn_cols().data_handle(), + index.get_R_1nn_dists().data_handle(), + index.get_R_radius().data_handle(), + dfunc, + vd_ptr, + nullptr); + + thrust::exclusive_scan(resource::get_thrust_policy(handle), + vd_ptr, + vd_ptr + n_query_rows + 1, + adj_ia, + (value_idx)0); + + } else { + // pass 2 -> fill in adj_ja + block_rbc_kernel_eps_csr_pass + <<>>( + index.get_X().data_handle(), + query, + index.n, + R_dists, + index.m, + eps, + index.n_landmarks, + index.get_R_indptr().data_handle(), + index.get_R_1nn_cols().data_handle(), + index.get_R_1nn_dists().data_handle(), + index.get_R_radius().data_handle(), + dfunc, + adj_ia, + adj_ja); + } + } else { + value_int max_k_in = *max_k; + value_idx* vd_ptr = (vd != nullptr) ? vd : adj_ia; + + rmm::device_uvector tmp(n_query_rows * max_k_in, resource::get_cuda_stream(handle)); + + block_rbc_kernel_eps_max_k + <<>>( + index.get_X().data_handle(), + query, + index.n, + R_dists, + index.m, + eps, + index.n_landmarks, + index.get_R_indptr().data_handle(), + index.get_R_1nn_cols().data_handle(), + index.get_R_1nn_dists().data_handle(), + index.get_R_radius().data_handle(), + dfunc, + vd_ptr, + max_k_in, + tmp.data()); + + value_int actual_max = thrust::reduce(resource::get_thrust_policy(handle), + vd_ptr, + vd_ptr + n_query_rows, + (value_idx)0, + thrust::maximum()); + + if (actual_max > max_k_in) { + // ceil vd to max_k + thrust::transform(resource::get_thrust_policy(handle), + vd_ptr, + vd_ptr + n_query_rows, + vd_ptr, + [max_k_in] __device__(value_idx vd_count) { + return vd_count > max_k_in ? max_k_in : vd_count; + }); + } + + thrust::exclusive_scan( + resource::get_thrust_policy(handle), vd_ptr, vd_ptr + n_query_rows + 1, adj_ia, (value_idx)0); + + block_rbc_kernel_eps_max_k_copy + <<>>( + max_k_in, adj_ia, tmp.data(), adj_ja); + + // return 'new' max-k + *max_k = actual_max; + } + + if (vd != nullptr && (max_k != nullptr || adj_ja == nullptr)) { + // copy sum to last element + RAFT_CUDA_TRY(cudaMemcpyAsync(vd + n_query_rows, + adj_ia + n_query_rows, + sizeof(value_idx), + cudaMemcpyDeviceToDevice, + resource::get_cuda_stream(handle))); + } + + resource::sync_stream(handle); +} + }; // namespace detail }; // namespace knn }; // namespace spatial diff --git a/cpp/include/raft/spatial/knn/detail/epsilon_neighborhood.cuh b/cpp/include/raft/spatial/knn/detail/epsilon_neighborhood.cuh index cb0ca6cc68..7a5a217959 100644 --- a/cpp/include/raft/spatial/knn/detail/epsilon_neighborhood.cuh +++ b/cpp/include/raft/spatial/knn/detail/epsilon_neighborhood.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. + * Copyright (c) 2020-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -95,13 +95,10 @@ struct EpsUnexpL2SqNeighborhood : public BaseClass { IdxT startx = blockIdx.x * P::Mblk + this->accrowid; IdxT starty = blockIdx.y * P::Nblk + this->acccolid; auto lid = raft::laneId(); - IdxT sums[P::AccColsPerTh]; -#pragma unroll - for (int j = 0; j < P::AccColsPerTh; ++j) { - sums[j] = 0; - } + IdxT sums[P::AccRowsPerTh]; #pragma unroll for (int i = 0; i < P::AccRowsPerTh; ++i) { + sums[i] = 0; auto xid = startx + i * P::AccThRows; #pragma unroll for (int j = 0; j < P::AccColsPerTh; ++j) { @@ -110,7 +107,7 @@ struct EpsUnexpL2SqNeighborhood : public BaseClass { ///@todo: fix uncoalesced writes using shared mem if (xid < this->m && yid < this->n) { adj[xid * this->n + yid] = is_neigh; - sums[j] += is_neigh; + sums[i] += is_neigh; } } } @@ -137,19 +134,21 @@ struct EpsUnexpL2SqNeighborhood : public BaseClass { } } - DI void updateVertexDegree(IdxT (&sums)[P::AccColsPerTh]) + DI void updateVertexDegree(IdxT (&sums)[P::AccRowsPerTh]) { __syncthreads(); // so that we can safely reuse smem - int gid = threadIdx.x / P::AccThCols; - int lid = threadIdx.x % P::AccThCols; - auto cidx = IdxT(blockIdx.y) * P::Nblk + lid; + int gid = this->accrowid; + int lid = this->acccolid; + auto cidx = IdxT(blockIdx.x) * P::Mblk + gid; IdxT totalSum = 0; // update the individual vertex degrees #pragma unroll - for (int i = 0; i < P::AccColsPerTh; ++i) { - sums[i] = batchedBlockReduce(sums[i], smem); - auto cid = cidx + i * P::AccThCols; - if (gid == 0 && cid < this->n) { + for (int i = 0; i < P::AccRowsPerTh; ++i) { + // P::AccThCols neighboring threads need to reduce + // -> we have P::Nblk/P::AccThCols individual reductions + auto cid = cidx + i * P::AccThRows; + sums[i] = raft::logicalWarpReduce(sums[i], raft::add_op()); + if (lid == 0 && cid < this->m) { atomicUpdate(cid, sums[i]); totalSum += sums[i]; } @@ -157,7 +156,7 @@ struct EpsUnexpL2SqNeighborhood : public BaseClass { } // update the total edge count totalSum = raft::blockReduce(totalSum, smem); - if (threadIdx.x == 0) { atomicUpdate(this->n, totalSum); } + if (threadIdx.x == 0) { atomicUpdate(this->m, totalSum); } } DI void atomicUpdate(IdxT addrId, IdxT val) @@ -226,6 +225,8 @@ void epsUnexpL2SqNeighborhood(bool* adj, DataT eps, cudaStream_t stream) { + if (vd != nullptr) { RAFT_CUDA_TRY(cudaMemsetAsync(vd, 0, (m + 1) * sizeof(IdxT), stream)); } + size_t bytes = sizeof(DataT) * k; if (16 % sizeof(DataT) == 0 && bytes % 16 == 0) { epsUnexpL2SqNeighImpl(adj, vd, x, y, m, n, k, eps, stream); @@ -238,4 +239,4 @@ void epsUnexpL2SqNeighborhood(bool* adj, } // namespace detail } // namespace knn } // namespace spatial -} // namespace raft \ No newline at end of file +} // namespace raft diff --git a/cpp/include/raft/spatial/knn/detail/fused_l2_knn-ext.cuh b/cpp/include/raft/spatial/knn/detail/fused_l2_knn-ext.cuh index 1a48e1adde..0eca119450 100644 --- a/cpp/include/raft/spatial/knn/detail/fused_l2_knn-ext.cuh +++ b/cpp/include/raft/spatial/knn/detail/fused_l2_knn-ext.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * Copyright (c) 2021-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/cpp/include/raft/spatial/knn/detail/fused_l2_knn-inl.cuh b/cpp/include/raft/spatial/knn/detail/fused_l2_knn-inl.cuh index 30ebab43b6..0c9f0059f9 100644 --- a/cpp/include/raft/spatial/knn/detail/fused_l2_knn-inl.cuh +++ b/cpp/include/raft/spatial/knn/detail/fused_l2_knn-inl.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * Copyright (c) 2021-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -969,12 +969,16 @@ void fusedL2Knn(size_t D, size_t worksize = 0, tempWorksize = 0; rmm::device_uvector workspace(worksize, stream); value_idx lda = D, ldb = D, ldd = n_index_rows; - + // switch (metric) { case raft::distance::DistanceType::L2SqrtExpanded: case raft::distance::DistanceType::L2Expanded: - tempWorksize = raft::distance::detail:: - getWorkspaceSize( + tempWorksize = + raft::distance::detail::getWorkspaceSize( query, index, n_query_rows, n_index_rows, D); worksize = tempWorksize; workspace.resize(worksize, stream); diff --git a/cpp/include/raft/spatial/knn/detail/haversine_distance.cuh b/cpp/include/raft/spatial/knn/detail/haversine_distance.cuh index 5b8cc36368..5fb912843d 100644 --- a/cpp/include/raft/spatial/knn/detail/haversine_distance.cuh +++ b/cpp/include/raft/spatial/knn/detail/haversine_distance.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. + * Copyright (c) 2020-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -133,8 +133,10 @@ void haversine_knn(value_idx* out_inds, int k, cudaStream_t stream) { - haversine_knn_kernel<<>>( - out_inds, out_dists, index, query, n_index_rows, k); + // ensure kernel does not breach shared memory limits + constexpr int kWarpQ = sizeof(value_t) > 4 ? 512 : 1024; + haversine_knn_kernel + <<>>(out_inds, out_dists, index, query, n_index_rows, k); } } // namespace detail diff --git a/cpp/include/raft/util/cuda_utils.cuh b/cpp/include/raft/util/cuda_utils.cuh index e718ca3545..bf46e069e4 100644 --- a/cpp/include/raft/util/cuda_utils.cuh +++ b/cpp/include/raft/util/cuda_utils.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2023, NVIDIA CORPORATION. + * Copyright (c) 2018-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,15 +16,12 @@ #pragma once +#include +#include #include #include #include -#if defined(_RAFT_HAS_CUDA) -#include -#include -#endif - #include #include #include @@ -278,17 +275,53 @@ template <> * @{ */ template -inline __device__ T myInf(); -template <> -inline __device__ float myInf() +RAFT_DEVICE_INLINE_FUNCTION typename std::enable_if_t, float> myInf() { return CUDART_INF_F; } -template <> -inline __device__ double myInf() +template +RAFT_DEVICE_INLINE_FUNCTION typename std::enable_if_t, double> myInf() { return CUDART_INF; } +// Half/Bfloat constants only defined after CUDA 12.2 +#if __CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 2) +template +RAFT_DEVICE_INLINE_FUNCTION typename std::enable_if_t, __half> myInf() +{ +#if (__CUDA_ARCH__ >= 530) + return __ushort_as_half((unsigned short)0x7C00U); +#else + // Fail during template instantiation if the compute capability doesn't support this operation + static_assert(sizeof(T) != sizeof(T), "__half is only supported on __CUDA_ARCH__ >= 530"); + return T{}; +#endif +} +template +RAFT_DEVICE_INLINE_FUNCTION typename std::enable_if_t, nv_bfloat16> +myInf() +{ +#if (__CUDA_ARCH__ >= 800) + return __ushort_as_bfloat16((unsigned short)0x7F80U); +#else + // Fail during template instantiation if the compute capability doesn't support this operation + static_assert(sizeof(T) != sizeof(T), "nv_bfloat16 is only supported on __CUDA_ARCH__ >= 800"); + return T{}; +#endif +} +#else +template +RAFT_DEVICE_INLINE_FUNCTION typename std::enable_if_t, __half> myInf() +{ + return CUDART_INF_FP16; +} +template +RAFT_DEVICE_INLINE_FUNCTION typename std::enable_if_t, nv_bfloat16> +myInf() +{ + return CUDART_INF_BF16; +} +#endif /** @} */ /** diff --git a/cpp/include/raft_runtime/neighbors/eps_neighborhood.hpp b/cpp/include/raft_runtime/neighbors/eps_neighborhood.hpp new file mode 100644 index 0000000000..ee1ca846f6 --- /dev/null +++ b/cpp/include/raft_runtime/neighbors/eps_neighborhood.hpp @@ -0,0 +1,68 @@ +/* + * Copyright (c) 2023-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include + +namespace raft::runtime::neighbors::epsilon_neighborhood { + +#define RAFT_INST_BFEPSN(IDX_T, DATA_T, MATRIX_IDX_T, INDEX_LAYOUT, SEARCH_LAYOUT) \ + void eps_neighbors(raft::resources const& handle, \ + raft::device_matrix_view index, \ + raft::device_matrix_view search, \ + raft::device_matrix_view adj, \ + raft::device_vector_view vd, \ + DATA_T eps); + +RAFT_INST_BFEPSN(int64_t, float, int64_t, raft::row_major, raft::row_major); + +#undef RAFT_INST_BFEPSN + +#define RAFT_INST_RBCEPSN(IDX_T, DATA_T, INT_T, MATRIX_IDX_T, INDEX_LAYOUT, SEARCH_LAYOUT) \ + void eps_neighbors_rbc( \ + raft::resources const& handle, \ + raft::device_matrix_view index, \ + raft::device_matrix_view search, \ + raft::device_matrix_view adj, \ + raft::device_vector_view vd, \ + DATA_T eps); \ + void build_rbc_index( \ + raft::resources const& handle, \ + raft::neighbors::ball_cover::BallCoverIndex& rbc_index); \ + void eps_neighbors_rbc_pass1( \ + raft::resources const& handle, \ + raft::neighbors::ball_cover::BallCoverIndex rbc_index, \ + raft::device_matrix_view search, \ + raft::device_vector_view adj_ia, \ + raft::device_vector_view vd, \ + DATA_T eps); \ + void eps_neighbors_rbc_pass2( \ + raft::resources const& handle, \ + raft::neighbors::ball_cover::BallCoverIndex rbc_index, \ + raft::device_matrix_view search, \ + raft::device_vector_view adj_ia, \ + raft::device_vector_view adj_ja, \ + raft::device_vector_view vd, \ + DATA_T eps); + +RAFT_INST_RBCEPSN(int64_t, float, int64_t, int64_t, raft::row_major, raft::row_major); + +#undef RAFT_INST_RBCEPSN + +} // namespace raft::runtime::neighbors::epsilon_neighborhood diff --git a/cpp/src/neighbors/ball_cover.cu b/cpp/src/neighbors/ball_cover.cu index 3b129e168b..0a59060c8e 100644 --- a/cpp/src/neighbors/ball_cover.cu +++ b/cpp/src/neighbors/ball_cover.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * Copyright (c) 2021-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -22,6 +22,24 @@ raft::resources const& handle, \ raft::neighbors::ball_cover::BallCoverIndex& index); \ \ + template void raft::neighbors::ball_cover::eps_nn( \ + raft::resources const& handle, \ + const raft::neighbors::ball_cover::BallCoverIndex& index, \ + raft::device_matrix_view adj, \ + raft::device_vector_view vd, \ + raft::device_matrix_view query, \ + value_t eps); \ + \ + template void raft::neighbors::ball_cover::eps_nn( \ + raft::resources const& handle, \ + const raft::neighbors::ball_cover::BallCoverIndex& index, \ + raft::device_vector_view ia, \ + raft::device_vector_view ja, \ + raft::device_vector_view vd, \ + raft::device_matrix_view query, \ + value_t eps, \ + std::optional> max_k); \ + \ template void raft::neighbors::ball_cover::all_knn_query( \ raft::resources const& handle, \ raft::neighbors::ball_cover::BallCoverIndex& index, \ @@ -40,9 +58,9 @@ bool perform_post_filtering, \ float weight); \ \ - template void raft::neighbors::ball_cover::knn_query( \ + template void raft::neighbors::ball_cover::knn_query( \ raft::resources const& handle, \ - const raft::neighbors::ball_cover::BallCoverIndex& index, \ + const raft::neighbors::ball_cover::BallCoverIndex& index, \ int_t k, \ const value_t* query, \ int_t n_query_pts, \ @@ -61,6 +79,6 @@ bool perform_post_filtering, \ float weight); -instantiate_raft_neighbors_ball_cover(int64_t, float, uint32_t, uint32_t); +instantiate_raft_neighbors_ball_cover(int64_t, float, int64_t, int64_t); #undef instantiate_raft_neighbors_ball_cover diff --git a/cpp/src/neighbors/brute_force_00_generate.py b/cpp/src/neighbors/brute_force_00_generate.py index 9adc5fef90..8ed05dc4c2 100644 --- a/cpp/src/neighbors/brute_force_00_generate.py +++ b/cpp/src/neighbors/brute_force_00_generate.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. +# Copyright (c) 2023-2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,7 +14,7 @@ header = """ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/cpp/src/raft_runtime/neighbors/eps_neighborhood.cu b/cpp/src/raft_runtime/neighbors/eps_neighborhood.cu new file mode 100644 index 0000000000..23cb6fd790 --- /dev/null +++ b/cpp/src/raft_runtime/neighbors/eps_neighborhood.cu @@ -0,0 +1,101 @@ +/* + * Copyright (c) 2023-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include + +#include + +#include + +namespace raft::runtime::neighbors::epsilon_neighborhood { + +#define RAFT_INST_BFEPSN(IDX_T, DATA_T, MATRIX_IDX_T, INDEX_LAYOUT, SEARCH_LAYOUT) \ + void eps_neighbors(raft::resources const& handle, \ + raft::device_matrix_view index, \ + raft::device_matrix_view search, \ + raft::device_matrix_view adj, \ + raft::device_vector_view vd, \ + DATA_T eps) \ + { \ + raft::neighbors::epsilon_neighborhood::eps_neighbors_l2sq( \ + handle, search, index, adj, vd, eps* eps); \ + } + +RAFT_INST_BFEPSN(int64_t, float, int64_t, raft::row_major, raft::row_major); + +#undef RAFT_INST_BFEPSN + +#define RAFT_INST_RBCEPSN(IDX_T, DATA_T, INT_T, MATRIX_IDX_T, INDEX_LAYOUT, SEARCH_LAYOUT) \ + void eps_neighbors_rbc( \ + raft::resources const& handle, \ + raft::device_matrix_view index, \ + raft::device_matrix_view search, \ + raft::device_matrix_view adj, \ + raft::device_vector_view vd, \ + DATA_T eps) \ + { \ + raft::neighbors::ball_cover::BallCoverIndex rbc_index( \ + handle, \ + index.data_handle(), \ + index.extent(0), \ + index.extent(1), \ + raft::distance::DistanceType::L2SqrtUnexpanded); \ + raft::neighbors::ball_cover::build_index(handle, rbc_index); \ + raft::neighbors::ball_cover::eps_nn(handle, rbc_index, adj, vd, search, eps); \ + } \ + void build_rbc_index( \ + raft::resources const& handle, \ + raft::neighbors::ball_cover::BallCoverIndex& rbc_index) \ + { \ + raft::neighbors::ball_cover::build_index(handle, rbc_index); \ + } \ + void eps_neighbors_rbc_pass1( \ + raft::resources const& handle, \ + raft::neighbors::ball_cover::BallCoverIndex rbc_index, \ + raft::device_matrix_view search, \ + raft::device_vector_view adj_ia, \ + raft::device_vector_view vd, \ + DATA_T eps) \ + { \ + raft::neighbors::ball_cover::eps_nn( \ + handle, \ + rbc_index, \ + adj_ia, \ + raft::make_device_vector_view(nullptr, 0), \ + vd, \ + search, \ + eps); \ + } \ + void eps_neighbors_rbc_pass2( \ + raft::resources const& handle, \ + raft::neighbors::ball_cover::BallCoverIndex rbc_index, \ + raft::device_matrix_view search, \ + raft::device_vector_view adj_ia, \ + raft::device_vector_view adj_ja, \ + raft::device_vector_view vd, \ + DATA_T eps) \ + { \ + raft::neighbors::ball_cover::eps_nn(handle, rbc_index, adj_ia, adj_ja, vd, search, eps); \ + } + +RAFT_INST_RBCEPSN(int64_t, float, int64_t, int64_t, raft::row_major, raft::row_major); + +#undef RAFT_INST_RBCEPSN + +} // namespace raft::runtime::neighbors::epsilon_neighborhood diff --git a/cpp/src/spatial/knn/detail/ball_cover/registers.cu b/cpp/src/spatial/knn/detail/ball_cover/registers.cu index 493a602362..31595272b6 100644 --- a/cpp/src/spatial/knn/detail/ball_cover/registers.cu +++ b/cpp/src/spatial/knn/detail/ball_cover/registers.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * Copyright (c) 2021-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,45 +16,79 @@ #include -#define instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one( \ - Mvalue_idx, Mvalue_t, Mvalue_int, Mdims) \ - template void \ - raft::spatial::knn::detail::rbc_low_dim_pass_one( \ - raft::resources const& handle, \ - const BallCoverIndex& index, \ - const Mvalue_t* query, \ - const Mvalue_int n_query_rows, \ - Mvalue_int k, \ - const Mvalue_idx* R_knn_inds, \ - const Mvalue_t* R_knn_dists, \ - raft::spatial::knn::detail::DistFunc& dfunc, \ - Mvalue_idx* inds, \ - Mvalue_t* dists, \ - float weight, \ - Mvalue_int* dists_counter) - -#define instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two( \ - Mvalue_idx, Mvalue_t, Mvalue_int, Mdims) \ - template void \ - raft::spatial::knn::detail::rbc_low_dim_pass_two( \ - raft::resources const& handle, \ - const BallCoverIndex& index, \ - const Mvalue_t* query, \ - const Mvalue_int n_query_rows, \ - Mvalue_int k, \ - const Mvalue_idx* R_knn_inds, \ - const Mvalue_t* R_knn_dists, \ - raft::spatial::knn::detail::DistFunc& dfunc, \ - Mvalue_idx* inds, \ - Mvalue_t* dists, \ - float weight, \ - Mvalue_int* dists_counter) - -instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one(std::int64_t, float, std::uint32_t, 2); -instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one(std::int64_t, float, std::uint32_t, 3); - -instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two(std::int64_t, float, std::uint32_t, 2); -instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two(std::int64_t, float, std::uint32_t, 3); +#define instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one( \ + Mvalue_idx, Mvalue_t, Mvalue_int, Mmatrix_idx, Mdims) \ + template void raft::spatial::knn::detail:: \ + rbc_low_dim_pass_one( \ + raft::resources const& handle, \ + const BallCoverIndex& index, \ + const Mvalue_t* query, \ + const Mvalue_int n_query_rows, \ + Mvalue_int k, \ + const Mvalue_idx* R_knn_inds, \ + const Mvalue_t* R_knn_dists, \ + raft::spatial::knn::detail::DistFunc& dfunc, \ + Mvalue_idx* inds, \ + Mvalue_t* dists, \ + float weight, \ + Mvalue_int* dists_counter) +#define instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two( \ + Mvalue_idx, Mvalue_t, Mvalue_int, Mmatrix_idx, Mdims) \ + template void raft::spatial::knn::detail:: \ + rbc_low_dim_pass_two( \ + raft::resources const& handle, \ + const BallCoverIndex& index, \ + const Mvalue_t* query, \ + const Mvalue_int n_query_rows, \ + Mvalue_int k, \ + const Mvalue_idx* R_knn_inds, \ + const Mvalue_t* R_knn_dists, \ + raft::spatial::knn::detail::DistFunc& dfunc, \ + Mvalue_idx* inds, \ + Mvalue_t* dists, \ + float weight, \ + Mvalue_int* dists_counter) + +#define instantiate_raft_spatial_knn_detail_rbc_eps_pass( \ + Mvalue_idx, Mvalue_t, Mvalue_int, Mmatrix_idx) \ + template void \ + raft::spatial::knn::detail::rbc_eps_pass( \ + raft::resources const& handle, \ + const BallCoverIndex& index, \ + const Mvalue_t* query, \ + const Mvalue_int n_query_rows, \ + Mvalue_t eps, \ + const Mvalue_t* R_dists, \ + raft::spatial::knn::detail::DistFunc& dfunc, \ + bool* adj, \ + Mvalue_idx* vd); \ + \ + template void \ + raft::spatial::knn::detail::rbc_eps_pass( \ + raft::resources const& handle, \ + const BallCoverIndex& index, \ + const Mvalue_t* query, \ + const Mvalue_int n_query_rows, \ + Mvalue_t eps, \ + const Mvalue_t* R_dists, \ + raft::spatial::knn::detail::DistFunc& dfunc, \ + Mvalue_idx* ia, \ + Mvalue_idx* ja, \ + Mvalue_idx* vd) + +instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one( + std::int64_t, float, std::int64_t, std::int64_t, 2); +instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one( + std::int64_t, float, std::int64_t, std::int64_t, 3); + +instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two( + std::int64_t, float, std::int64_t, std::int64_t, 2); +instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two( + std::int64_t, float, std::int64_t, std::int64_t, 3); + +instantiate_raft_spatial_knn_detail_rbc_eps_pass(std::int64_t, float, std::int64_t, std::int64_t); + +#undef instantiate_raft_spatial_knn_detail_rbc_eps_pass #undef instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two #undef instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one diff --git a/cpp/src/spatial/knn/detail/ball_cover/registers_00_generate.py b/cpp/src/spatial/knn/detail/ball_cover/registers_00_generate.py index d7b6e618fd..dff2e015a4 100644 --- a/cpp/src/spatial/knn/detail/ball_cover/registers_00_generate.py +++ b/cpp/src/spatial/knn/detail/ball_cover/registers_00_generate.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. +# Copyright (c) 2023-2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,7 +13,7 @@ # limitations under the License. header = """/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * Copyright (c) 2021-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -45,11 +45,11 @@ macro_pass_one = """ #define instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one( \\ - Mvalue_idx, Mvalue_t, Mvalue_int, Mdims, Mdist_func) \\ - template void \\ - raft::spatial::knn::detail::rbc_low_dim_pass_one( \\ - raft::resources const& handle, \\ - const BallCoverIndex& index, \\ + Mvalue_idx, Mvalue_t, Mvalue_int, Mmatrix_idx, Mdims, Mdist_func) \\ + template void \\ + raft::spatial::knn::detail::rbc_low_dim_pass_one( \\ + raft::resources const& handle, \\ + const BallCoverIndex& index, \\ const Mvalue_t* query, \\ const Mvalue_int n_query_rows, \\ Mvalue_int k, \\ @@ -65,11 +65,11 @@ macro_pass_two = """ #define instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two( \\ - Mvalue_idx, Mvalue_t, Mvalue_int, Mdims, Mdist_func) \\ - template void \\ - raft::spatial::knn::detail::rbc_low_dim_pass_two( \\ - raft::resources const& handle, \\ - const BallCoverIndex& index, \\ + Mvalue_idx, Mvalue_t, Mvalue_int, Mmatrix_idx, Mdims, Mdist_func) \\ + template void \\ + raft::spatial::knn::detail::rbc_low_dim_pass_two( \\ + raft::resources const& handle, \\ + const BallCoverIndex& index, \\ const Mvalue_t* query, \\ const Mvalue_int n_query_rows, \\ Mvalue_int k, \\ @@ -83,20 +83,58 @@ """ +macro_pass_eps = """ +#define instantiate_raft_spatial_knn_detail_rbc_eps_pass( \\ + Mvalue_idx, Mvalue_t, Mvalue_int, Mmatrix_idx, Mdist_func) \\ + template void \\ + raft::spatial::knn::detail::rbc_eps_pass( \\ + raft::resources const& handle, \\ + const BallCoverIndex& index, \\ + const Mvalue_t* query, \\ + const Mvalue_int n_query_rows, \\ + Mvalue_t eps, \\ + const Mvalue_t* R_dists, \\ + Mdist_func& dfunc, \\ + bool* adj, \\ + Mvalue_idx* vd); \\ + \\ + template void \\ + raft::spatial::knn::detail::rbc_eps_pass( \\ + raft::resources const& handle, \\ + const BallCoverIndex& index, \\ + const Mvalue_t* query, \\ + const Mvalue_int n_query_rows, \\ + Mvalue_t eps, \\ + Mvalue_int* max_k, \\ + const Mvalue_t* R_dists, \\ + Mdist_func& dfunc, \\ + Mvalue_idx* adj_ia, \\ + Mvalue_idx* adj_ja, \\ + Mvalue_idx* vd) + +""" + + distances = dict( haversine="raft::spatial::knn::detail::HaversineFunc", euclidean="raft::spatial::knn::detail::EuclideanFunc", dist="raft::spatial::knn::detail::DistFunc", ) +types = dict( + int64_float=("std::int64_t", "float"), + #int64_double=("std::int64_t", "double"), +) + for k, v in distances.items(): for dim in [2, 3]: path = f"registers_pass_one_{dim}d_{k}.cu" with open(path, "w") as f: f.write(header) f.write(macro_pass_one) - f.write(f"instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one(\n") - f.write(f" std::int64_t, float, std::uint32_t, {dim}, {v});\n") + for type_path, (int_t, data_t) in types.items(): + f.write(f"instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one(\n") + f.write(f" {int_t}, {data_t}, {int_t}, {int_t}, {dim}, {v});\n") f.write("#undef instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one\n") print(f"src/spatial/knn/detail/ball_cover/{path}") @@ -106,7 +144,19 @@ with open(path, "w") as f: f.write(header) f.write(macro_pass_two) - f.write(f"instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two(\n") - f.write(f" std::int64_t, float, std::uint32_t, {dim}, {v});\n") + for type_path, (int_t, data_t) in types.items(): + f.write(f"instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two(\n") + f.write(f" {int_t}, {data_t}, {int_t}, {int_t}, {dim}, {v});\n") f.write("#undef instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two\n") print(f"src/spatial/knn/detail/ball_cover/{path}") + +path="registers_eps_pass_euclidean.cu" +with open(path, "w") as f: + f.write(header) + f.write(macro_pass_eps) + for type_path, (int_t, data_t) in types.items(): + f.write(f"instantiate_raft_spatial_knn_detail_rbc_eps_pass(\n") + f.write(f" {int_t}, {data_t}, {int_t}, {int_t}, {distances['euclidean']});\n") + f.write("#undef instantiate_raft_spatial_knn_detail_rbc_eps_pass\n") + print(f"src/spatial/knn/detail/ball_cover/{path}") + diff --git a/cpp/src/spatial/knn/detail/ball_cover/registers_eps_pass_euclidean.cu b/cpp/src/spatial/knn/detail/ball_cover/registers_eps_pass_euclidean.cu new file mode 100644 index 0000000000..0d09f88b65 --- /dev/null +++ b/cpp/src/spatial/knn/detail/ball_cover/registers_eps_pass_euclidean.cu @@ -0,0 +1,59 @@ +/* + * Copyright (c) 2021-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by registers_00_generate.py + * + * Make changes there and run in this directory: + * + * > python registers_00_generate.py + * + */ + +#include // int64_t +#include + +#define instantiate_raft_spatial_knn_detail_rbc_eps_pass( \ + Mvalue_idx, Mvalue_t, Mvalue_int, Mmatrix_idx, Mdist_func) \ + template void \ + raft::spatial::knn::detail::rbc_eps_pass( \ + raft::resources const& handle, \ + const BallCoverIndex& index, \ + const Mvalue_t* query, \ + const Mvalue_int n_query_rows, \ + Mvalue_t eps, \ + const Mvalue_t* R_dists, \ + Mdist_func& dfunc, \ + bool* adj, \ + Mvalue_idx* vd); \ + \ + template void \ + raft::spatial::knn::detail::rbc_eps_pass( \ + raft::resources const& handle, \ + const BallCoverIndex& index, \ + const Mvalue_t* query, \ + const Mvalue_int n_query_rows, \ + Mvalue_t eps, \ + Mvalue_int* max_k, \ + const Mvalue_t* R_dists, \ + Mdist_func& dfunc, \ + Mvalue_idx* adj_ia, \ + Mvalue_idx* adj_ja, \ + Mvalue_idx* vd) + +instantiate_raft_spatial_knn_detail_rbc_eps_pass( + std::int64_t, float, std::int64_t, std::int64_t, raft::spatial::knn::detail::EuclideanFunc); +#undef instantiate_raft_spatial_knn_detail_rbc_eps_pass diff --git a/cpp/src/spatial/knn/detail/ball_cover/registers_pass_one_2d_dist.cu b/cpp/src/spatial/knn/detail/ball_cover/registers_pass_one_2d_dist.cu index bb9ec284cc..3681acf245 100644 --- a/cpp/src/spatial/knn/detail/ball_cover/registers_pass_one_2d_dist.cu +++ b/cpp/src/spatial/knn/detail/ball_cover/registers_pass_one_2d_dist.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * Copyright (c) 2021-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -26,23 +26,23 @@ #include // int64_t #include -#define instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one( \ - Mvalue_idx, Mvalue_t, Mvalue_int, Mdims, Mdist_func) \ - template void \ - raft::spatial::knn::detail::rbc_low_dim_pass_one( \ - raft::resources const& handle, \ - const BallCoverIndex& index, \ - const Mvalue_t* query, \ - const Mvalue_int n_query_rows, \ - Mvalue_int k, \ - const Mvalue_idx* R_knn_inds, \ - const Mvalue_t* R_knn_dists, \ - Mdist_func& dfunc, \ - Mvalue_idx* inds, \ - Mvalue_t* dists, \ - float weight, \ - Mvalue_int* dists_counter) +#define instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one( \ + Mvalue_idx, Mvalue_t, Mvalue_int, Mmatrix_idx, Mdims, Mdist_func) \ + template void raft::spatial::knn::detail:: \ + rbc_low_dim_pass_one( \ + raft::resources const& handle, \ + const BallCoverIndex& index, \ + const Mvalue_t* query, \ + const Mvalue_int n_query_rows, \ + Mvalue_int k, \ + const Mvalue_idx* R_knn_inds, \ + const Mvalue_t* R_knn_dists, \ + Mdist_func& dfunc, \ + Mvalue_idx* inds, \ + Mvalue_t* dists, \ + float weight, \ + Mvalue_int* dists_counter) instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one( - std::int64_t, float, std::uint32_t, 2, raft::spatial::knn::detail::DistFunc); + std::int64_t, float, std::int64_t, std::int64_t, 2, raft::spatial::knn::detail::DistFunc); #undef instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one diff --git a/cpp/src/spatial/knn/detail/ball_cover/registers_pass_one_2d_euclidean.cu b/cpp/src/spatial/knn/detail/ball_cover/registers_pass_one_2d_euclidean.cu index 2b06d0a1cd..3fa20779b7 100644 --- a/cpp/src/spatial/knn/detail/ball_cover/registers_pass_one_2d_euclidean.cu +++ b/cpp/src/spatial/knn/detail/ball_cover/registers_pass_one_2d_euclidean.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * Copyright (c) 2021-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -26,23 +26,23 @@ #include // int64_t #include -#define instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one( \ - Mvalue_idx, Mvalue_t, Mvalue_int, Mdims, Mdist_func) \ - template void \ - raft::spatial::knn::detail::rbc_low_dim_pass_one( \ - raft::resources const& handle, \ - const BallCoverIndex& index, \ - const Mvalue_t* query, \ - const Mvalue_int n_query_rows, \ - Mvalue_int k, \ - const Mvalue_idx* R_knn_inds, \ - const Mvalue_t* R_knn_dists, \ - Mdist_func& dfunc, \ - Mvalue_idx* inds, \ - Mvalue_t* dists, \ - float weight, \ - Mvalue_int* dists_counter) +#define instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one( \ + Mvalue_idx, Mvalue_t, Mvalue_int, Mmatrix_idx, Mdims, Mdist_func) \ + template void raft::spatial::knn::detail:: \ + rbc_low_dim_pass_one( \ + raft::resources const& handle, \ + const BallCoverIndex& index, \ + const Mvalue_t* query, \ + const Mvalue_int n_query_rows, \ + Mvalue_int k, \ + const Mvalue_idx* R_knn_inds, \ + const Mvalue_t* R_knn_dists, \ + Mdist_func& dfunc, \ + Mvalue_idx* inds, \ + Mvalue_t* dists, \ + float weight, \ + Mvalue_int* dists_counter) instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one( - std::int64_t, float, std::uint32_t, 2, raft::spatial::knn::detail::EuclideanFunc); + std::int64_t, float, std::int64_t, std::int64_t, 2, raft::spatial::knn::detail::EuclideanFunc); #undef instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one diff --git a/cpp/src/spatial/knn/detail/ball_cover/registers_pass_one_2d_haversine.cu b/cpp/src/spatial/knn/detail/ball_cover/registers_pass_one_2d_haversine.cu index 6f4e4061ac..7abc89cc11 100644 --- a/cpp/src/spatial/knn/detail/ball_cover/registers_pass_one_2d_haversine.cu +++ b/cpp/src/spatial/knn/detail/ball_cover/registers_pass_one_2d_haversine.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * Copyright (c) 2021-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -26,23 +26,23 @@ #include // int64_t #include -#define instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one( \ - Mvalue_idx, Mvalue_t, Mvalue_int, Mdims, Mdist_func) \ - template void \ - raft::spatial::knn::detail::rbc_low_dim_pass_one( \ - raft::resources const& handle, \ - const BallCoverIndex& index, \ - const Mvalue_t* query, \ - const Mvalue_int n_query_rows, \ - Mvalue_int k, \ - const Mvalue_idx* R_knn_inds, \ - const Mvalue_t* R_knn_dists, \ - Mdist_func& dfunc, \ - Mvalue_idx* inds, \ - Mvalue_t* dists, \ - float weight, \ - Mvalue_int* dists_counter) +#define instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one( \ + Mvalue_idx, Mvalue_t, Mvalue_int, Mmatrix_idx, Mdims, Mdist_func) \ + template void raft::spatial::knn::detail:: \ + rbc_low_dim_pass_one( \ + raft::resources const& handle, \ + const BallCoverIndex& index, \ + const Mvalue_t* query, \ + const Mvalue_int n_query_rows, \ + Mvalue_int k, \ + const Mvalue_idx* R_knn_inds, \ + const Mvalue_t* R_knn_dists, \ + Mdist_func& dfunc, \ + Mvalue_idx* inds, \ + Mvalue_t* dists, \ + float weight, \ + Mvalue_int* dists_counter) instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one( - std::int64_t, float, std::uint32_t, 2, raft::spatial::knn::detail::HaversineFunc); + std::int64_t, float, std::int64_t, std::int64_t, 2, raft::spatial::knn::detail::HaversineFunc); #undef instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one diff --git a/cpp/src/spatial/knn/detail/ball_cover/registers_pass_one_3d_dist.cu b/cpp/src/spatial/knn/detail/ball_cover/registers_pass_one_3d_dist.cu index aa407eeb20..6251a86867 100644 --- a/cpp/src/spatial/knn/detail/ball_cover/registers_pass_one_3d_dist.cu +++ b/cpp/src/spatial/knn/detail/ball_cover/registers_pass_one_3d_dist.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * Copyright (c) 2021-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -26,23 +26,23 @@ #include // int64_t #include -#define instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one( \ - Mvalue_idx, Mvalue_t, Mvalue_int, Mdims, Mdist_func) \ - template void \ - raft::spatial::knn::detail::rbc_low_dim_pass_one( \ - raft::resources const& handle, \ - const BallCoverIndex& index, \ - const Mvalue_t* query, \ - const Mvalue_int n_query_rows, \ - Mvalue_int k, \ - const Mvalue_idx* R_knn_inds, \ - const Mvalue_t* R_knn_dists, \ - Mdist_func& dfunc, \ - Mvalue_idx* inds, \ - Mvalue_t* dists, \ - float weight, \ - Mvalue_int* dists_counter) +#define instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one( \ + Mvalue_idx, Mvalue_t, Mvalue_int, Mmatrix_idx, Mdims, Mdist_func) \ + template void raft::spatial::knn::detail:: \ + rbc_low_dim_pass_one( \ + raft::resources const& handle, \ + const BallCoverIndex& index, \ + const Mvalue_t* query, \ + const Mvalue_int n_query_rows, \ + Mvalue_int k, \ + const Mvalue_idx* R_knn_inds, \ + const Mvalue_t* R_knn_dists, \ + Mdist_func& dfunc, \ + Mvalue_idx* inds, \ + Mvalue_t* dists, \ + float weight, \ + Mvalue_int* dists_counter) instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one( - std::int64_t, float, std::uint32_t, 3, raft::spatial::knn::detail::DistFunc); + std::int64_t, float, std::int64_t, std::int64_t, 3, raft::spatial::knn::detail::DistFunc); #undef instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one diff --git a/cpp/src/spatial/knn/detail/ball_cover/registers_pass_one_3d_euclidean.cu b/cpp/src/spatial/knn/detail/ball_cover/registers_pass_one_3d_euclidean.cu index 7918fb79cb..07b97ac718 100644 --- a/cpp/src/spatial/knn/detail/ball_cover/registers_pass_one_3d_euclidean.cu +++ b/cpp/src/spatial/knn/detail/ball_cover/registers_pass_one_3d_euclidean.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * Copyright (c) 2021-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -26,23 +26,23 @@ #include // int64_t #include -#define instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one( \ - Mvalue_idx, Mvalue_t, Mvalue_int, Mdims, Mdist_func) \ - template void \ - raft::spatial::knn::detail::rbc_low_dim_pass_one( \ - raft::resources const& handle, \ - const BallCoverIndex& index, \ - const Mvalue_t* query, \ - const Mvalue_int n_query_rows, \ - Mvalue_int k, \ - const Mvalue_idx* R_knn_inds, \ - const Mvalue_t* R_knn_dists, \ - Mdist_func& dfunc, \ - Mvalue_idx* inds, \ - Mvalue_t* dists, \ - float weight, \ - Mvalue_int* dists_counter) +#define instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one( \ + Mvalue_idx, Mvalue_t, Mvalue_int, Mmatrix_idx, Mdims, Mdist_func) \ + template void raft::spatial::knn::detail:: \ + rbc_low_dim_pass_one( \ + raft::resources const& handle, \ + const BallCoverIndex& index, \ + const Mvalue_t* query, \ + const Mvalue_int n_query_rows, \ + Mvalue_int k, \ + const Mvalue_idx* R_knn_inds, \ + const Mvalue_t* R_knn_dists, \ + Mdist_func& dfunc, \ + Mvalue_idx* inds, \ + Mvalue_t* dists, \ + float weight, \ + Mvalue_int* dists_counter) instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one( - std::int64_t, float, std::uint32_t, 3, raft::spatial::knn::detail::EuclideanFunc); + std::int64_t, float, std::int64_t, std::int64_t, 3, raft::spatial::knn::detail::EuclideanFunc); #undef instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one diff --git a/cpp/src/spatial/knn/detail/ball_cover/registers_pass_one_3d_haversine.cu b/cpp/src/spatial/knn/detail/ball_cover/registers_pass_one_3d_haversine.cu index f8f29a107c..4fc18184b0 100644 --- a/cpp/src/spatial/knn/detail/ball_cover/registers_pass_one_3d_haversine.cu +++ b/cpp/src/spatial/knn/detail/ball_cover/registers_pass_one_3d_haversine.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * Copyright (c) 2021-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -26,23 +26,23 @@ #include // int64_t #include -#define instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one( \ - Mvalue_idx, Mvalue_t, Mvalue_int, Mdims, Mdist_func) \ - template void \ - raft::spatial::knn::detail::rbc_low_dim_pass_one( \ - raft::resources const& handle, \ - const BallCoverIndex& index, \ - const Mvalue_t* query, \ - const Mvalue_int n_query_rows, \ - Mvalue_int k, \ - const Mvalue_idx* R_knn_inds, \ - const Mvalue_t* R_knn_dists, \ - Mdist_func& dfunc, \ - Mvalue_idx* inds, \ - Mvalue_t* dists, \ - float weight, \ - Mvalue_int* dists_counter) +#define instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one( \ + Mvalue_idx, Mvalue_t, Mvalue_int, Mmatrix_idx, Mdims, Mdist_func) \ + template void raft::spatial::knn::detail:: \ + rbc_low_dim_pass_one( \ + raft::resources const& handle, \ + const BallCoverIndex& index, \ + const Mvalue_t* query, \ + const Mvalue_int n_query_rows, \ + Mvalue_int k, \ + const Mvalue_idx* R_knn_inds, \ + const Mvalue_t* R_knn_dists, \ + Mdist_func& dfunc, \ + Mvalue_idx* inds, \ + Mvalue_t* dists, \ + float weight, \ + Mvalue_int* dists_counter) instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one( - std::int64_t, float, std::uint32_t, 3, raft::spatial::knn::detail::HaversineFunc); + std::int64_t, float, std::int64_t, std::int64_t, 3, raft::spatial::knn::detail::HaversineFunc); #undef instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one diff --git a/cpp/src/spatial/knn/detail/ball_cover/registers_pass_two_2d_dist.cu b/cpp/src/spatial/knn/detail/ball_cover/registers_pass_two_2d_dist.cu index 1facd24510..882496c7d9 100644 --- a/cpp/src/spatial/knn/detail/ball_cover/registers_pass_two_2d_dist.cu +++ b/cpp/src/spatial/knn/detail/ball_cover/registers_pass_two_2d_dist.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * Copyright (c) 2021-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -26,23 +26,23 @@ #include // int64_t #include -#define instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two( \ - Mvalue_idx, Mvalue_t, Mvalue_int, Mdims, Mdist_func) \ - template void \ - raft::spatial::knn::detail::rbc_low_dim_pass_two( \ - raft::resources const& handle, \ - const BallCoverIndex& index, \ - const Mvalue_t* query, \ - const Mvalue_int n_query_rows, \ - Mvalue_int k, \ - const Mvalue_idx* R_knn_inds, \ - const Mvalue_t* R_knn_dists, \ - Mdist_func& dfunc, \ - Mvalue_idx* inds, \ - Mvalue_t* dists, \ - float weight, \ - Mvalue_int* dists_counter) +#define instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two( \ + Mvalue_idx, Mvalue_t, Mvalue_int, Mmatrix_idx, Mdims, Mdist_func) \ + template void raft::spatial::knn::detail:: \ + rbc_low_dim_pass_two( \ + raft::resources const& handle, \ + const BallCoverIndex& index, \ + const Mvalue_t* query, \ + const Mvalue_int n_query_rows, \ + Mvalue_int k, \ + const Mvalue_idx* R_knn_inds, \ + const Mvalue_t* R_knn_dists, \ + Mdist_func& dfunc, \ + Mvalue_idx* inds, \ + Mvalue_t* dists, \ + float weight, \ + Mvalue_int* dists_counter) instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two( - std::int64_t, float, std::uint32_t, 2, raft::spatial::knn::detail::DistFunc); + std::int64_t, float, std::int64_t, std::int64_t, 2, raft::spatial::knn::detail::DistFunc); #undef instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two diff --git a/cpp/src/spatial/knn/detail/ball_cover/registers_pass_two_2d_euclidean.cu b/cpp/src/spatial/knn/detail/ball_cover/registers_pass_two_2d_euclidean.cu index 6e681e2e9b..0a736d7e13 100644 --- a/cpp/src/spatial/knn/detail/ball_cover/registers_pass_two_2d_euclidean.cu +++ b/cpp/src/spatial/knn/detail/ball_cover/registers_pass_two_2d_euclidean.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * Copyright (c) 2021-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -26,23 +26,23 @@ #include // int64_t #include -#define instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two( \ - Mvalue_idx, Mvalue_t, Mvalue_int, Mdims, Mdist_func) \ - template void \ - raft::spatial::knn::detail::rbc_low_dim_pass_two( \ - raft::resources const& handle, \ - const BallCoverIndex& index, \ - const Mvalue_t* query, \ - const Mvalue_int n_query_rows, \ - Mvalue_int k, \ - const Mvalue_idx* R_knn_inds, \ - const Mvalue_t* R_knn_dists, \ - Mdist_func& dfunc, \ - Mvalue_idx* inds, \ - Mvalue_t* dists, \ - float weight, \ - Mvalue_int* dists_counter) +#define instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two( \ + Mvalue_idx, Mvalue_t, Mvalue_int, Mmatrix_idx, Mdims, Mdist_func) \ + template void raft::spatial::knn::detail:: \ + rbc_low_dim_pass_two( \ + raft::resources const& handle, \ + const BallCoverIndex& index, \ + const Mvalue_t* query, \ + const Mvalue_int n_query_rows, \ + Mvalue_int k, \ + const Mvalue_idx* R_knn_inds, \ + const Mvalue_t* R_knn_dists, \ + Mdist_func& dfunc, \ + Mvalue_idx* inds, \ + Mvalue_t* dists, \ + float weight, \ + Mvalue_int* dists_counter) instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two( - std::int64_t, float, std::uint32_t, 2, raft::spatial::knn::detail::EuclideanFunc); + std::int64_t, float, std::int64_t, std::int64_t, 2, raft::spatial::knn::detail::EuclideanFunc); #undef instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two diff --git a/cpp/src/spatial/knn/detail/ball_cover/registers_pass_two_2d_haversine.cu b/cpp/src/spatial/knn/detail/ball_cover/registers_pass_two_2d_haversine.cu index b4a038ffd7..23aff93966 100644 --- a/cpp/src/spatial/knn/detail/ball_cover/registers_pass_two_2d_haversine.cu +++ b/cpp/src/spatial/knn/detail/ball_cover/registers_pass_two_2d_haversine.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * Copyright (c) 2021-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -26,23 +26,23 @@ #include // int64_t #include -#define instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two( \ - Mvalue_idx, Mvalue_t, Mvalue_int, Mdims, Mdist_func) \ - template void \ - raft::spatial::knn::detail::rbc_low_dim_pass_two( \ - raft::resources const& handle, \ - const BallCoverIndex& index, \ - const Mvalue_t* query, \ - const Mvalue_int n_query_rows, \ - Mvalue_int k, \ - const Mvalue_idx* R_knn_inds, \ - const Mvalue_t* R_knn_dists, \ - Mdist_func& dfunc, \ - Mvalue_idx* inds, \ - Mvalue_t* dists, \ - float weight, \ - Mvalue_int* dists_counter) +#define instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two( \ + Mvalue_idx, Mvalue_t, Mvalue_int, Mmatrix_idx, Mdims, Mdist_func) \ + template void raft::spatial::knn::detail:: \ + rbc_low_dim_pass_two( \ + raft::resources const& handle, \ + const BallCoverIndex& index, \ + const Mvalue_t* query, \ + const Mvalue_int n_query_rows, \ + Mvalue_int k, \ + const Mvalue_idx* R_knn_inds, \ + const Mvalue_t* R_knn_dists, \ + Mdist_func& dfunc, \ + Mvalue_idx* inds, \ + Mvalue_t* dists, \ + float weight, \ + Mvalue_int* dists_counter) instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two( - std::int64_t, float, std::uint32_t, 2, raft::spatial::knn::detail::HaversineFunc); + std::int64_t, float, std::int64_t, std::int64_t, 2, raft::spatial::knn::detail::HaversineFunc); #undef instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two diff --git a/cpp/src/spatial/knn/detail/ball_cover/registers_pass_two_3d_dist.cu b/cpp/src/spatial/knn/detail/ball_cover/registers_pass_two_3d_dist.cu index bcb27568c1..d3ec2b4c65 100644 --- a/cpp/src/spatial/knn/detail/ball_cover/registers_pass_two_3d_dist.cu +++ b/cpp/src/spatial/knn/detail/ball_cover/registers_pass_two_3d_dist.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * Copyright (c) 2021-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -26,23 +26,23 @@ #include // int64_t #include -#define instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two( \ - Mvalue_idx, Mvalue_t, Mvalue_int, Mdims, Mdist_func) \ - template void \ - raft::spatial::knn::detail::rbc_low_dim_pass_two( \ - raft::resources const& handle, \ - const BallCoverIndex& index, \ - const Mvalue_t* query, \ - const Mvalue_int n_query_rows, \ - Mvalue_int k, \ - const Mvalue_idx* R_knn_inds, \ - const Mvalue_t* R_knn_dists, \ - Mdist_func& dfunc, \ - Mvalue_idx* inds, \ - Mvalue_t* dists, \ - float weight, \ - Mvalue_int* dists_counter) +#define instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two( \ + Mvalue_idx, Mvalue_t, Mvalue_int, Mmatrix_idx, Mdims, Mdist_func) \ + template void raft::spatial::knn::detail:: \ + rbc_low_dim_pass_two( \ + raft::resources const& handle, \ + const BallCoverIndex& index, \ + const Mvalue_t* query, \ + const Mvalue_int n_query_rows, \ + Mvalue_int k, \ + const Mvalue_idx* R_knn_inds, \ + const Mvalue_t* R_knn_dists, \ + Mdist_func& dfunc, \ + Mvalue_idx* inds, \ + Mvalue_t* dists, \ + float weight, \ + Mvalue_int* dists_counter) instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two( - std::int64_t, float, std::uint32_t, 3, raft::spatial::knn::detail::DistFunc); + std::int64_t, float, std::int64_t, std::int64_t, 3, raft::spatial::knn::detail::DistFunc); #undef instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two diff --git a/cpp/src/spatial/knn/detail/ball_cover/registers_pass_two_3d_euclidean.cu b/cpp/src/spatial/knn/detail/ball_cover/registers_pass_two_3d_euclidean.cu index e40d837862..dd9f0e4658 100644 --- a/cpp/src/spatial/knn/detail/ball_cover/registers_pass_two_3d_euclidean.cu +++ b/cpp/src/spatial/knn/detail/ball_cover/registers_pass_two_3d_euclidean.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * Copyright (c) 2021-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -26,23 +26,23 @@ #include // int64_t #include -#define instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two( \ - Mvalue_idx, Mvalue_t, Mvalue_int, Mdims, Mdist_func) \ - template void \ - raft::spatial::knn::detail::rbc_low_dim_pass_two( \ - raft::resources const& handle, \ - const BallCoverIndex& index, \ - const Mvalue_t* query, \ - const Mvalue_int n_query_rows, \ - Mvalue_int k, \ - const Mvalue_idx* R_knn_inds, \ - const Mvalue_t* R_knn_dists, \ - Mdist_func& dfunc, \ - Mvalue_idx* inds, \ - Mvalue_t* dists, \ - float weight, \ - Mvalue_int* dists_counter) +#define instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two( \ + Mvalue_idx, Mvalue_t, Mvalue_int, Mmatrix_idx, Mdims, Mdist_func) \ + template void raft::spatial::knn::detail:: \ + rbc_low_dim_pass_two( \ + raft::resources const& handle, \ + const BallCoverIndex& index, \ + const Mvalue_t* query, \ + const Mvalue_int n_query_rows, \ + Mvalue_int k, \ + const Mvalue_idx* R_knn_inds, \ + const Mvalue_t* R_knn_dists, \ + Mdist_func& dfunc, \ + Mvalue_idx* inds, \ + Mvalue_t* dists, \ + float weight, \ + Mvalue_int* dists_counter) instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two( - std::int64_t, float, std::uint32_t, 3, raft::spatial::knn::detail::EuclideanFunc); + std::int64_t, float, std::int64_t, std::int64_t, 3, raft::spatial::knn::detail::EuclideanFunc); #undef instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two diff --git a/cpp/src/spatial/knn/detail/ball_cover/registers_pass_two_3d_haversine.cu b/cpp/src/spatial/knn/detail/ball_cover/registers_pass_two_3d_haversine.cu index 8a362bcf16..849bbf0f96 100644 --- a/cpp/src/spatial/knn/detail/ball_cover/registers_pass_two_3d_haversine.cu +++ b/cpp/src/spatial/knn/detail/ball_cover/registers_pass_two_3d_haversine.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * Copyright (c) 2021-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -26,23 +26,23 @@ #include // int64_t #include -#define instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two( \ - Mvalue_idx, Mvalue_t, Mvalue_int, Mdims, Mdist_func) \ - template void \ - raft::spatial::knn::detail::rbc_low_dim_pass_two( \ - raft::resources const& handle, \ - const BallCoverIndex& index, \ - const Mvalue_t* query, \ - const Mvalue_int n_query_rows, \ - Mvalue_int k, \ - const Mvalue_idx* R_knn_inds, \ - const Mvalue_t* R_knn_dists, \ - Mdist_func& dfunc, \ - Mvalue_idx* inds, \ - Mvalue_t* dists, \ - float weight, \ - Mvalue_int* dists_counter) +#define instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two( \ + Mvalue_idx, Mvalue_t, Mvalue_int, Mmatrix_idx, Mdims, Mdist_func) \ + template void raft::spatial::knn::detail:: \ + rbc_low_dim_pass_two( \ + raft::resources const& handle, \ + const BallCoverIndex& index, \ + const Mvalue_t* query, \ + const Mvalue_int n_query_rows, \ + Mvalue_int k, \ + const Mvalue_idx* R_knn_inds, \ + const Mvalue_t* R_knn_dists, \ + Mdist_func& dfunc, \ + Mvalue_idx* inds, \ + Mvalue_t* dists, \ + float weight, \ + Mvalue_int* dists_counter) instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two( - std::int64_t, float, std::uint32_t, 3, raft::spatial::knn::detail::HaversineFunc); + std::int64_t, float, std::int64_t, std::int64_t, 3, raft::spatial::knn::detail::HaversineFunc); #undef instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two diff --git a/cpp/test/neighbors/ball_cover.cu b/cpp/test/neighbors/ball_cover.cu index fc711fc668..62fbdd6edb 100644 --- a/cpp/test/neighbors/ball_cover.cu +++ b/cpp/test/neighbors/ball_cover.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * Copyright (c) 2021-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -126,7 +126,7 @@ struct ToRadians { __device__ __host__ float operator()(float a) { return a * (CUDART_PI_F / 180.0); } }; -template +template struct BallCoverInputs { value_int k; value_int n_rows; @@ -136,7 +136,7 @@ struct BallCoverInputs { raft::distance::DistanceType metric; }; -template +template class BallCoverKNNQueryTest : public ::testing::TestWithParam> { protected: void basicTest() @@ -151,26 +151,26 @@ class BallCoverKNNQueryTest : public ::testing::TestWithParam X(params.n_rows * params.n_cols, resource::get_cuda_stream(handle)); - rmm::device_uvector Y(params.n_rows, resource::get_cuda_stream(handle)); + rmm::device_uvector Y(params.n_rows, resource::get_cuda_stream(handle)); // Make sure the train and query sets are completely disjoint rmm::device_uvector X2(params.n_query * params.n_cols, resource::get_cuda_stream(handle)); - rmm::device_uvector Y2(params.n_query, resource::get_cuda_stream(handle)); - - raft::random::make_blobs(X.data(), - Y.data(), - params.n_rows, - params.n_cols, - n_centers, - resource::get_cuda_stream(handle)); - - raft::random::make_blobs(X2.data(), - Y2.data(), - params.n_query, - params.n_cols, - n_centers, - resource::get_cuda_stream(handle)); + rmm::device_uvector Y2(params.n_query, resource::get_cuda_stream(handle)); + + raft::random::make_blobs(X.data(), + Y.data(), + params.n_rows, + params.n_cols, + n_centers, + resource::get_cuda_stream(handle)); + + raft::random::make_blobs(X2.data(), + Y2.data(), + params.n_query, + params.n_cols, + n_centers, + resource::get_cuda_stream(handle)); rmm::device_uvector d_ref_I(params.n_query * k, resource::get_cuda_stream(handle)); rmm::device_uvector d_ref_D(params.n_query * k, resource::get_cuda_stream(handle)); @@ -215,7 +215,8 @@ class BallCoverKNNQueryTest : public ::testing::TestWithParam index(handle, X_view, metric); build_index(handle, index); - knn_query(handle, index, X2_view, d_pred_I_view, d_pred_D_view, k, true); + knn_query( + handle, index, X2_view, d_pred_I_view, d_pred_D_view, k, true); resource::sync_stream(handle); // What we really want are for the distances to match exactly. The @@ -249,7 +250,7 @@ class BallCoverKNNQueryTest : public ::testing::TestWithParam params; }; -template +template class BallCoverAllKNNTest : public ::testing::TestWithParam> { protected: void basicTest() @@ -264,14 +265,14 @@ class BallCoverAllKNNTest : public ::testing::TestWithParam X(params.n_rows * params.n_cols, resource::get_cuda_stream(handle)); - rmm::device_uvector Y(params.n_rows, resource::get_cuda_stream(handle)); + rmm::device_uvector Y(params.n_rows, resource::get_cuda_stream(handle)); - raft::random::make_blobs(X.data(), - Y.data(), - params.n_rows, - params.n_cols, - n_centers, - resource::get_cuda_stream(handle)); + raft::random::make_blobs(X.data(), + Y.data(), + params.n_rows, + params.n_cols, + n_centers, + resource::get_cuda_stream(handle)); rmm::device_uvector d_ref_I(params.n_rows * k, resource::get_cuda_stream(handle)); rmm::device_uvector d_ref_D(params.n_rows * k, resource::get_cuda_stream(handle)); @@ -308,7 +309,8 @@ class BallCoverAllKNNTest : public ::testing::TestWithParam index(handle, X_view, metric); - all_knn_query(handle, index, d_pred_I_view, d_pred_D_view, k, true); + all_knn_query( + handle, index, d_pred_I_view, d_pred_D_view, k, true); resource::sync_stream(handle); // What we really want are for the distances to match exactly. The @@ -348,7 +350,7 @@ class BallCoverAllKNNTest : public ::testing::TestWithParam BallCoverAllKNNTestF; typedef BallCoverKNNQueryTest BallCoverKNNQueryTestF; -const std::vector> ballcover_inputs = { +const std::vector> ballcover_inputs = { {11, 5000, 2, 1.0, 10000, raft::distance::DistanceType::Haversine}, {25, 10000, 2, 1.0, 5000, raft::distance::DistanceType::Haversine}, {2, 10000, 2, 1.0, 5000, raft::distance::DistanceType::L2SqrtUnexpanded}, diff --git a/cpp/test/neighbors/epsilon_neighborhood.cu b/cpp/test/neighbors/epsilon_neighborhood.cu index 1601037edb..8b35e3ca70 100644 --- a/cpp/test/neighbors/epsilon_neighborhood.cu +++ b/cpp/test/neighbors/epsilon_neighborhood.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. + * Copyright (c) 2020-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,8 +18,11 @@ #include #include #include +#include #include +#include #include +#include #include #include #include @@ -82,8 +85,11 @@ class EpsNeighTest : public ::testing::TestWithParam> { IdxT batchSize; }; // class EpsNeighTest -const std::vector> inputsfi = { +const std::vector> inputsfi = { + {100, 16, 5, 2, 2.f}, + {1500, 16, 5, 3, 2.f}, {15000, 16, 5, 1, 2.f}, + {15000, 3, 5, 1, 2.f}, {14000, 16, 5, 1, 2.f}, {15000, 17, 5, 1, 2.f}, {14000, 17, 5, 1, 2.f}, @@ -91,31 +97,317 @@ const std::vector> inputsfi = { {14000, 18, 5, 1, 2.f}, {15000, 32, 5, 1, 2.f}, {14000, 32, 5, 1, 2.f}, + {14000, 32, 5, 10, 2.f}, {20000, 10000, 10, 1, 2.f}, {20000, 10000, 10, 2, 2.f}, }; -typedef EpsNeighTest EpsNeighTestFI; -TEST_P(EpsNeighTestFI, Result) + +typedef EpsNeighTest EpsNeighTestFI; + +TEST_P(EpsNeighTestFI, ResultBruteForce) { for (int i = 0; i < param.n_batches; ++i) { RAFT_CUDA_TRY(cudaMemsetAsync(adj.data(), 0, sizeof(bool) * param.n_row * batchSize, stream)); - RAFT_CUDA_TRY(cudaMemsetAsync(vd.data(), 0, sizeof(int) * (batchSize + 1), stream)); + RAFT_CUDA_TRY(cudaMemsetAsync(vd.data(), 0, sizeof(int64_t) * (batchSize + 1), stream)); - auto adj_view = make_device_matrix_view(adj.data(), param.n_row, batchSize); - auto vd_view = make_device_vector_view(vd.data(), batchSize + 1); - auto x_view = make_device_matrix_view(data.data(), param.n_row, param.n_col); - auto y_view = make_device_matrix_view( + auto adj_view = make_device_matrix_view(adj.data(), batchSize, param.n_row); + auto vd_view = make_device_vector_view(vd.data(), batchSize + 1); + auto x_view = make_device_matrix_view( data.data() + (i * batchSize * param.n_col), batchSize, param.n_col); + auto y_view = make_device_matrix_view(data.data(), param.n_row, param.n_col); - eps_neighbors_l2sq( + eps_neighbors_l2sq( handle, x_view, y_view, adj_view, vd_view, param.eps * param.eps); ASSERT_TRUE(raft::devArrMatch( - param.n_row / param.n_centers, vd.data(), batchSize, raft::Compare(), stream)); + param.n_row / param.n_centers, vd.data(), batchSize, raft::Compare(), stream)); } } + INSTANTIATE_TEST_CASE_P(EpsNeighTests, EpsNeighTestFI, ::testing::ValuesIn(inputsfi)); +// rbc examples take fewer points as correctness checks are very costly +const std::vector> inputsfi_rbc = { + {100, 16, 5, 2, 2.f}, + {1500, 16, 5, 3, 2.f}, + {1500, 16, 5, 1, 2.f}, + {1500, 3, 5, 1, 2.f}, + {1400, 16, 5, 1, 2.f}, + {1500, 17, 5, 1, 2.f}, + {1400, 17, 5, 1, 2.f}, + {1500, 18, 5, 1, 2.f}, + {1400, 18, 5, 1, 2.f}, + {1500, 32, 5, 1, 2.f}, + {1400, 32, 5, 1, 2.f}, + {1400, 32, 5, 10, 2.f}, + {2000, 1000, 10, 1, 2.f}, + {2000, 1000, 10, 2, 2.f}, +}; + +typedef EpsNeighTest EpsNeighRbcTestFI; + +TEST_P(EpsNeighRbcTestFI, DenseRbc) +{ + auto adj_baseline = raft::make_device_matrix(handle, batchSize, param.n_row); + + raft::neighbors::ball_cover::BallCoverIndex rbc_index( + handle, data.data(), param.n_row, param.n_col, raft::distance::DistanceType::L2SqrtUnexpanded); + raft::neighbors::ball_cover::build_index(handle, rbc_index); + + for (int i = 0; i < param.n_batches; ++i) { + // invalidate + RAFT_CUDA_TRY(cudaMemsetAsync(adj.data(), 1, sizeof(bool) * param.n_row * batchSize, stream)); + RAFT_CUDA_TRY(cudaMemsetAsync(vd.data(), 1, sizeof(int64_t) * (batchSize + 1), stream)); + RAFT_CUDA_TRY(cudaMemsetAsync( + adj_baseline.data_handle(), 1, sizeof(bool) * param.n_row * batchSize, stream)); + + float* query = data.data() + (i * batchSize * param.n_col); + + raft::neighbors::ball_cover::eps_nn( + handle, + rbc_index, + make_device_matrix_view(adj.data(), batchSize, param.n_row), + make_device_vector_view(vd.data(), batchSize + 1), + make_device_matrix_view(query, batchSize, param.n_col), + param.eps * param.eps); + + ASSERT_TRUE(raft::devArrMatch( + param.n_row / param.n_centers, vd.data(), batchSize, raft::Compare(), stream)); + + // compute baseline via brute force + compare + epsUnexpL2SqNeighborhood(adj_baseline.data_handle(), + nullptr, + query, + data.data(), + batchSize, + param.n_row, + param.n_col, + param.eps * param.eps, + stream); + + ASSERT_TRUE(raft::devArrMatch(adj_baseline.data_handle(), + adj.data(), + batchSize, + param.n_row, + raft::Compare(), + stream)); + } +} + +template +testing::AssertionResult assertCsrEqualUnordered( + T* ia_exp, T* ja_exp, T* ia_act, T* ja_act, size_t rows, size_t cols, cudaStream_t stream) +{ + std::unique_ptr ia_exp_h(new T[rows + 1]); + std::unique_ptr ia_act_h(new T[rows + 1]); + raft::update_host(ia_exp_h.get(), ia_exp, rows + 1, stream); + raft::update_host(ia_act_h.get(), ia_act, rows + 1, stream); + RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + + size_t nnz = ia_exp_h.get()[rows]; + std::unique_ptr ja_exp_h(new T[nnz]); + std::unique_ptr ja_act_h(new T[nnz]); + raft::update_host(ja_exp_h.get(), ja_exp, nnz, stream); + raft::update_host(ja_act_h.get(), ja_act, nnz, stream); + RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + + for (size_t i(0); i < rows; ++i) { + auto row_start = ia_exp_h.get()[i]; + auto row_end = ia_exp_h.get()[i + 1]; + + // sort ja's + std::sort(ja_exp_h.get() + row_start, ja_exp_h.get() + row_end); + std::sort(ja_act_h.get() + row_start, ja_act_h.get() + row_end); + + for (size_t idx(row_start); idx < (size_t)row_end; ++idx) { + auto exp = ja_exp_h.get()[idx]; + auto act = ja_act_h.get()[idx]; + if (exp != act) { + return testing::AssertionFailure() + << "actual=" << act << " != expected=" << exp << " @" << i << "," << idx; + } + } + } + return testing::AssertionSuccess(); +} + +TEST_P(EpsNeighRbcTestFI, SparseRbc) +{ + auto adj_ia = raft::make_device_vector(handle, batchSize + 1); + auto adj_ja = raft::make_device_vector(handle, param.n_row * batchSize); + auto vd_baseline = raft::make_device_vector(handle, batchSize + 1); + auto adj_ia_baseline = raft::make_device_vector(handle, batchSize + 1); + auto adj_ja_baseline = raft::make_device_vector(handle, param.n_row * batchSize); + + raft::neighbors::ball_cover::BallCoverIndex rbc_index( + handle, data.data(), param.n_row, param.n_col, raft::distance::DistanceType::L2SqrtUnexpanded); + raft::neighbors::ball_cover::build_index(handle, rbc_index); + + for (int i = 0; i < param.n_batches; ++i) { + // reset full array -- that way we can compare the full size + RAFT_CUDA_TRY( + cudaMemsetAsync(adj_ja.data_handle(), 0, sizeof(int64_t) * param.n_row * batchSize, stream)); + RAFT_CUDA_TRY(cudaMemsetAsync( + adj_ja_baseline.data_handle(), 0, sizeof(int64_t) * param.n_row * batchSize, stream)); + + float* query = data.data() + (i * batchSize * param.n_col); + + // compute dense baseline and convert adj to csr + { + raft::neighbors::ball_cover::eps_nn( + handle, + rbc_index, + make_device_matrix_view(adj.data(), batchSize, param.n_row), + make_device_vector_view(vd_baseline.data_handle(), batchSize + 1), + make_device_matrix_view(query, batchSize, param.n_col), + param.eps * param.eps); + thrust::exclusive_scan(resource::get_thrust_policy(handle), + vd_baseline.data_handle(), + vd_baseline.data_handle() + batchSize + 1, + adj_ia_baseline.data_handle()); + raft::sparse::convert::adj_to_csr(handle, + adj.data(), + adj_ia_baseline.data_handle(), + batchSize, + param.n_row, + labels.data(), + adj_ja_baseline.data_handle()); + } + + // exact computation with 2 passes + { + raft::neighbors::ball_cover::eps_nn( + handle, + rbc_index, + make_device_vector_view(adj_ia.data_handle(), batchSize + 1), + make_device_vector_view(nullptr, 0), + make_device_vector_view(vd.data(), batchSize + 1), + make_device_matrix_view(query, batchSize, param.n_col), + param.eps * param.eps); + raft::neighbors::ball_cover::eps_nn( + handle, + rbc_index, + make_device_vector_view(adj_ia.data_handle(), batchSize + 1), + make_device_vector_view(adj_ja.data_handle(), batchSize * param.n_row), + make_device_vector_view(vd.data(), batchSize + 1), + make_device_matrix_view(query, batchSize, param.n_col), + param.eps * param.eps); + ASSERT_TRUE(raft::devArrMatch(adj_ia_baseline.data_handle(), + adj_ia.data_handle(), + batchSize + 1, + raft::Compare(), + stream)); + ASSERT_TRUE(assertCsrEqualUnordered(adj_ia_baseline.data_handle(), + adj_ja_baseline.data_handle(), + adj_ia.data_handle(), + adj_ja.data_handle(), + batchSize, + param.n_row, + stream)); + } + } +} + +TEST_P(EpsNeighRbcTestFI, SparseRbcMaxK) +{ + auto adj_ia = raft::make_device_vector(handle, batchSize + 1); + auto adj_ja = raft::make_device_vector(handle, param.n_row * batchSize); + auto vd_baseline = raft::make_device_vector(handle, batchSize + 1); + auto adj_ia_baseline = raft::make_device_vector(handle, batchSize + 1); + auto adj_ja_baseline = raft::make_device_vector(handle, param.n_row * batchSize); + + raft::neighbors::ball_cover::BallCoverIndex rbc_index( + handle, data.data(), param.n_row, param.n_col, raft::distance::DistanceType::L2SqrtUnexpanded); + raft::neighbors::ball_cover::build_index(handle, rbc_index); + + int64_t expected_max_k = param.n_row / param.n_centers; + + for (int i = 0; i < param.n_batches; ++i) { + // reset full array -- that way we can compare the full size + RAFT_CUDA_TRY( + cudaMemsetAsync(adj_ja.data_handle(), 0, sizeof(int64_t) * param.n_row * batchSize, stream)); + RAFT_CUDA_TRY(cudaMemsetAsync( + adj_ja_baseline.data_handle(), 0, sizeof(int64_t) * param.n_row * batchSize, stream)); + + float* query = data.data() + (i * batchSize * param.n_col); + + // compute dense baseline and convert adj to csr + { + raft::neighbors::ball_cover::eps_nn( + handle, + rbc_index, + make_device_matrix_view(adj.data(), batchSize, param.n_row), + make_device_vector_view(vd_baseline.data_handle(), batchSize + 1), + make_device_matrix_view(query, batchSize, param.n_col), + param.eps * param.eps); + thrust::exclusive_scan(resource::get_thrust_policy(handle), + vd_baseline.data_handle(), + vd_baseline.data_handle() + batchSize + 1, + adj_ia_baseline.data_handle()); + raft::sparse::convert::adj_to_csr(handle, + adj.data(), + adj_ia_baseline.data_handle(), + batchSize, + param.n_row, + labels.data(), + adj_ja_baseline.data_handle()); + } + + // exact computation with 1 pass + { + int64_t max_k = expected_max_k; + raft::neighbors::ball_cover::eps_nn( + handle, + rbc_index, + make_device_vector_view(adj_ia.data_handle(), batchSize + 1), + make_device_vector_view(adj_ja.data_handle(), batchSize * param.n_row), + make_device_vector_view(vd.data(), batchSize + 1), + make_device_matrix_view(query, batchSize, param.n_col), + param.eps * param.eps, + make_host_scalar_view(&max_k)); + ASSERT_TRUE(raft::devArrMatch(adj_ia_baseline.data_handle(), + adj_ia.data_handle(), + batchSize + 1, + raft::Compare(), + stream)); + ASSERT_TRUE(assertCsrEqualUnordered(adj_ia_baseline.data_handle(), + adj_ja_baseline.data_handle(), + adj_ia.data_handle(), + adj_ja.data_handle(), + batchSize, + param.n_row, + stream)); + ASSERT_TRUE(raft::devArrMatch( + vd_baseline.data_handle(), vd.data(), batchSize + 1, raft::Compare(), stream)); + ASSERT_TRUE(max_k == expected_max_k); + } + + // k-limited computation with 1 pass + { + int64_t max_k = expected_max_k / 2; + raft::neighbors::ball_cover::eps_nn( + handle, + rbc_index, + make_device_vector_view(adj_ia.data_handle(), batchSize + 1), + make_device_vector_view(adj_ja.data_handle(), batchSize * param.n_row), + make_device_vector_view(vd.data(), batchSize + 1), + make_device_matrix_view(query, batchSize, param.n_col), + param.eps * param.eps, + make_host_scalar_view(&max_k)); + ASSERT_TRUE(max_k == expected_max_k); + ASSERT_TRUE(raft::devArrMatch( + expected_max_k / 2, vd.data(), batchSize, raft::Compare(), stream)); + ASSERT_TRUE(raft::devArrMatch(expected_max_k / 2 * batchSize, + vd.data() + batchSize, + 1, + raft::Compare(), + stream)); + } + } +} + +INSTANTIATE_TEST_CASE_P(EpsNeighTests, EpsNeighRbcTestFI, ::testing::ValuesIn(inputsfi_rbc)); + }; // namespace knn }; // namespace spatial }; // namespace raft diff --git a/python/pylibraft/pylibraft/common/mdspan.pxd b/python/pylibraft/pylibraft/common/mdspan.pxd index 17dd2d8bfd..2c488ef427 100644 --- a/python/pylibraft/pylibraft/common/mdspan.pxd +++ b/python/pylibraft/pylibraft/common/mdspan.pxd @@ -1,5 +1,5 @@ # -# Copyright (c) 2022-2023, NVIDIA CORPORATION. +# Copyright (c) 2022-2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -20,6 +20,7 @@ # cython: language_level = 3 from libc.stdint cimport int8_t, int64_t, uint8_t, uint32_t +from libcpp cimport bool from libcpp.string cimport string from pylibraft.common.cpp.mdspan cimport ( @@ -40,6 +41,9 @@ ctypedef const uint8_t const_uint8_t cdef device_matrix_view[float, int64_t, row_major] get_dmv_float( array, check_shape) except * +cdef device_matrix_view[bool, int64_t, row_major] get_dmv_bool( + array, check_shape) except * + cdef device_matrix_view[uint8_t, int64_t, row_major] get_dmv_uint8( array, check_shape) except * diff --git a/python/pylibraft/pylibraft/common/mdspan.pyx b/python/pylibraft/pylibraft/common/mdspan.pyx index 7442a6bb89..9a994e2ec9 100644 --- a/python/pylibraft/pylibraft/common/mdspan.pyx +++ b/python/pylibraft/pylibraft/common/mdspan.pyx @@ -1,5 +1,5 @@ # -# Copyright (c) 2023, NVIDIA CORPORATION. +# Copyright (c) 2023-2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -26,6 +26,7 @@ from cpython.object cimport PyObject from cython.operator cimport dereference as deref from libc.stddef cimport size_t from libc.stdint cimport int8_t, int32_t, int64_t, uint8_t, uint32_t, uintptr_t +from libcpp cimport bool from pylibraft.common.cpp.mdspan cimport ( col_major, @@ -160,6 +161,18 @@ cdef device_matrix_view[float, int64_t, row_major] \ return make_device_matrix_view[float, int64_t, row_major]( cai.data, shape[0], shape[1]) + +cdef device_matrix_view[bool, int64_t, row_major] \ + get_dmv_bool(cai, check_shape) except *: + if cai.dtype != np.bool_: + raise TypeError("dtype %s not supported" % cai.dtype) + if check_shape and len(cai.shape) != 2: + raise ValueError("Expected a 2D array, got %d D" % len(cai.shape)) + shape = (cai.shape[0], cai.shape[1] if len(cai.shape) == 2 else 1) + return make_device_matrix_view[bool, int64_t, row_major]( + cai.data, shape[0], shape[1]) + + cdef device_matrix_view[uint8_t, int64_t, row_major] \ get_dmv_uint8(cai, check_shape) except *: if cai.dtype != np.uint8: diff --git a/python/pylibraft/pylibraft/neighbors/CMakeLists.txt b/python/pylibraft/pylibraft/neighbors/CMakeLists.txt index 45cd9f74e6..e64032408a 100644 --- a/python/pylibraft/pylibraft/neighbors/CMakeLists.txt +++ b/python/pylibraft/pylibraft/neighbors/CMakeLists.txt @@ -1,5 +1,5 @@ # ============================================================================= -# Copyright (c) 2022-2023, NVIDIA CORPORATION. +# Copyright (c) 2022-2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at @@ -13,7 +13,7 @@ # ============================================================================= # Set the list of Cython files to build -set(cython_sources common.pyx refine.pyx brute_force.pyx) +set(cython_sources common.pyx refine.pyx brute_force.pyx rbc.pyx) set(linked_libraries raft::raft raft::compiled) # Build all of the Cython targets diff --git a/python/pylibraft/pylibraft/neighbors/__init__.py b/python/pylibraft/pylibraft/neighbors/__init__.py index 325ea5842e..972058aaee 100644 --- a/python/pylibraft/pylibraft/neighbors/__init__.py +++ b/python/pylibraft/pylibraft/neighbors/__init__.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2023, NVIDIA CORPORATION. +# Copyright (c) 2022-2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,8 +13,16 @@ # limitations under the License. # -from pylibraft.neighbors import brute_force, cagra, ivf_flat, ivf_pq +from pylibraft.neighbors import brute_force, cagra, ivf_flat, ivf_pq, rbc from .refine import refine -__all__ = ["common", "refine", "brute_force", "ivf_flat", "ivf_pq", "cagra"] +__all__ = [ + "common", + "refine", + "brute_force", + "ivf_flat", + "ivf_pq", + "rbc", + "cagra", +] diff --git a/python/pylibraft/pylibraft/neighbors/brute_force.pyx b/python/pylibraft/pylibraft/neighbors/brute_force.pyx index 4aa47b8a18..19d20fb75d 100644 --- a/python/pylibraft/pylibraft/neighbors/brute_force.pyx +++ b/python/pylibraft/pylibraft/neighbors/brute_force.pyx @@ -1,5 +1,5 @@ # -# Copyright (c) 2022-2023, NVIDIA CORPORATION. +# Copyright (c) 2022-2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -37,7 +37,7 @@ from libc.stdint cimport int64_t, uintptr_t from pylibraft.common.cpp.optional cimport optional from pylibraft.common.handle cimport device_resources -from pylibraft.common.mdspan cimport get_dmv_float, get_dmv_int64 +from pylibraft.common.mdspan cimport get_dmv_bool, get_dmv_float, get_dmv_int64 from pylibraft.common.handle import auto_sync_handle from pylibraft.common.interruptible import cuda_interruptible @@ -51,12 +51,17 @@ from pylibraft.neighbors.common import _check_input_array from pylibraft.common.cpp.mdspan cimport ( device_matrix_view, + device_vector_view, host_matrix_view, make_device_matrix_view, + make_device_vector_view, make_host_matrix_view, row_major, ) -from pylibraft.neighbors.cpp.brute_force cimport knn as c_knn +from pylibraft.neighbors.cpp.brute_force cimport ( + eps_neighbors as c_eps_neighbors, + knn as c_knn, +) def _get_array_params(array_interface, check_dtype=None): @@ -177,3 +182,88 @@ def knn(dataset, queries, k=None, indices=None, distances=None, raise TypeError("dtype %s not supported" % dataset_cai.dtype) return (distances, indices) + + +@auto_sync_handle +@auto_convert_output +def eps_neighbors(dataset, queries, eps, handle=None): + """ + Perform an epsilon neighborhood search using the L2-norm. + + Parameters + ---------- + dataset : array interface compliant matrix, row-major layout, + shape (n_samples, dim). Supported dtype [float] + queries : array interface compliant matrix, row-major layout, + shape (n_queries, dim) Supported dtype [float] + eps : threshold + {handle_docstring} + + Returns + ------- + adj: array interface compliant object containing bool adjacency mask + shape (n_queries, n_samples) + + vd: array interface compliant object containing row sums of adj + shape (n_queries + 1). vd[n_queries] contains the total sum + + Examples + -------- + >>> import cupy as cp + >>> from pylibraft.common import DeviceResources + >>> from pylibraft.neighbors.brute_force import eps_neighbors + >>> handle = DeviceResources() + >>> n_samples = 50000 + >>> n_features = 50 + >>> n_queries = 1000 + >>> dataset = cp.random.random_sample((n_samples, n_features), + ... dtype=cp.float32) + >>> queries = cp.random.random_sample((n_queries, n_features), + ... dtype=cp.float32) + >>> eps = 0.1 + >>> adj, vd = eps_neighbors(dataset, queries, eps, handle=handle) + >>> adj = cp.asarray(adj) + >>> vd = cp.asarray(vd) + >>> # pylibraft functions are often asynchronous so the + >>> # handle needs to be explicitly synchronized + >>> handle.sync() + """ + + if handle is None: + handle = DeviceResources() + + dataset_cai = cai_wrapper(dataset) + queries_cai = cai_wrapper(queries) + + # we require c-contiguous (rowmajor) inputs here + _check_input_array(dataset_cai, [np.dtype("float32")]) + _check_input_array(queries_cai, [np.dtype("float32")], + exp_cols=dataset_cai.shape[1]) + + n_queries = queries_cai.shape[0] + n_samples = dataset_cai.shape[0] + + adj = device_ndarray.empty((n_queries, n_samples), dtype='bool') + vd = device_ndarray.empty((n_queries + 1, ), dtype='int64') + adj_cai = cai_wrapper(adj) + vd_cai = cai_wrapper(vd) + + cdef device_resources* handle_ = \ + handle.getHandle() + + vd_vector_view = make_device_vector_view( + vd_cai.data, vd_cai.shape[0]) + + if dataset_cai.dtype == np.float32: + with cuda_interruptible(): + c_eps_neighbors( + deref(handle_), + get_dmv_float(dataset_cai, check_shape=True), + get_dmv_float(queries_cai, check_shape=True), + get_dmv_bool(adj_cai, check_shape=True), + vd_vector_view, + eps) + else: + raise TypeError("dtype %s not supported" % dataset_cai.dtype) + + return (adj, vd) diff --git a/python/pylibraft/pylibraft/neighbors/cpp/brute_force.pxd b/python/pylibraft/pylibraft/neighbors/cpp/brute_force.pxd index de5e0af267..5f6a83a9dc 100644 --- a/python/pylibraft/pylibraft/neighbors/cpp/brute_force.pxd +++ b/python/pylibraft/pylibraft/neighbors/cpp/brute_force.pxd @@ -1,5 +1,5 @@ # -# Copyright (c) 2023, NVIDIA CORPORATION. +# Copyright (c) 2023-2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -32,8 +32,10 @@ from rmm._lib.memory_resource cimport device_memory_resource from pylibraft.common.cpp.mdspan cimport ( device_matrix_view, + device_vector_view, host_matrix_view, make_device_matrix_view, + make_device_vector_view, make_host_matrix_view, row_major, ) @@ -53,3 +55,14 @@ cdef extern from "raft_runtime/neighbors/brute_force.hpp" \ DistanceType metric, optional[float] metric_arg, optional[int64_t] global_id_offset) except + + +cdef extern from "raft_runtime/neighbors/eps_neighborhood.hpp" \ + namespace "raft::runtime::neighbors::epsilon_neighborhood" nogil: + + cdef void eps_neighbors( + const device_resources & handle, + device_matrix_view[float, int64_t, row_major] index, + device_matrix_view[float, int64_t, row_major] search, + device_matrix_view[bool, int64_t, row_major] adj, + device_vector_view[int64_t, int64_t] vd, + float eps) except + diff --git a/python/pylibraft/pylibraft/neighbors/cpp/rbc.pxd b/python/pylibraft/pylibraft/neighbors/cpp/rbc.pxd new file mode 100644 index 0000000000..531c0dc2c1 --- /dev/null +++ b/python/pylibraft/pylibraft/neighbors/cpp/rbc.pxd @@ -0,0 +1,84 @@ +# +# Copyright (c) 2023-2024, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# cython: profile=False +# distutils: language = c++ +# cython: embedsignature = True +# cython: language_level = 3 + +import numpy as np + +import pylibraft.common.handle + +from cython.operator cimport dereference as deref +from libc.stdint cimport int8_t, int64_t, uint8_t, uint64_t, uintptr_t +from libcpp cimport bool, nullptr +from libcpp.string cimport string +from libcpp.vector cimport vector + +from rmm._lib.memory_resource cimport device_memory_resource + +from pylibraft.common.cpp.mdspan cimport ( + device_matrix_view, + device_vector_view, + host_matrix_view, + make_device_matrix_view, + make_host_matrix_view, + row_major, +) +from pylibraft.common.handle cimport device_resources +from pylibraft.distance.distance_type cimport DistanceType + + +cdef extern from "raft/neighbors/ball_cover_types.hpp" \ + namespace "raft::neighbors::ball_cover" nogil: + + cdef cppclass BallCoverIndex[IdxT, T, IntT, MatIdxT]: + BallCoverIndex(const device_resources& handle, + device_matrix_view[T, MatIdxT, row_major] dataset, + DistanceType metric) + + +cdef extern from "raft_runtime/neighbors/eps_neighborhood.hpp" \ + namespace "raft::runtime::neighbors::epsilon_neighborhood" nogil: + + cdef void eps_neighbors_rbc( + const device_resources & handle, + device_matrix_view[float, int64_t, row_major] index, + device_matrix_view[float, int64_t, row_major] search, + device_matrix_view[bool, int64_t, row_major] adj, + device_vector_view[int64_t, int64_t] vd, + float eps) except + + + cdef void build_rbc_index( + const device_resources & handle, + BallCoverIndex[int64_t, float, int64_t, int64_t] rbc_index) except + + + cdef void eps_neighbors_rbc_pass1( + const device_resources & handle, + BallCoverIndex[int64_t, float, int64_t, int64_t] rbc_index, + device_matrix_view[float, int64_t, row_major] search, + device_vector_view[int64_t, int64_t] adj_ia, + device_vector_view[int64_t, int64_t] vd, + float eps) except + + + cdef void eps_neighbors_rbc_pass2( + const device_resources & handle, + BallCoverIndex[int64_t, float, int64_t, int64_t] rbc_index, + device_matrix_view[float, int64_t, row_major] search, + device_vector_view[int64_t, int64_t] adj_ia, + device_vector_view[int64_t, int64_t] adj_ja, + device_vector_view[int64_t, int64_t] vd, + float eps) except + diff --git a/python/pylibraft/pylibraft/neighbors/rbc.pyx b/python/pylibraft/pylibraft/neighbors/rbc.pyx new file mode 100644 index 0000000000..a703dc1745 --- /dev/null +++ b/python/pylibraft/pylibraft/neighbors/rbc.pyx @@ -0,0 +1,241 @@ +# +# Copyright (c) 2023-2024, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# cython: profile=False +# distutils: language = c++ +# cython: embedsignature = True +# cython: language_level = 3 + +import numpy as np + +from cython.operator cimport dereference as deref +from libcpp cimport bool, nullptr +from libcpp.vector cimport vector + +from pylibraft.common import ( + DeviceResources, + auto_convert_output, + cai_wrapper, + device_ndarray, +) + +from libc.stdint cimport int64_t, uintptr_t + +from pylibraft.common.cpp.optional cimport optional +from pylibraft.common.handle cimport device_resources +from pylibraft.common.mdspan cimport get_dmv_bool, get_dmv_float, get_dmv_int64 + +from pylibraft.common.handle import auto_sync_handle +from pylibraft.common.interruptible import cuda_interruptible +from pylibraft.neighbors.common import _check_input_array, _get_metric + +from pylibraft.common.cpp.mdspan cimport ( + device_matrix_view, + device_vector_view, + host_matrix_view, + make_device_matrix_view, + make_device_vector_view, + make_host_matrix_view, + row_major, +) +from pylibraft.neighbors.cpp.rbc cimport ( + BallCoverIndex as c_BallCoverIndex, + build_rbc_index as c_build_rbc_index, + eps_neighbors_rbc as c_eps_neighbors_rbc, + eps_neighbors_rbc_pass1 as c_eps_neighbors_rbc_pass1, + eps_neighbors_rbc_pass2 as c_eps_neighbors_rbc_pass2, +) + + +cdef class RbcIndex: + cdef readonly bool trained + cdef str data_type + + def __cinit__(self): + self.trained = False + self.data_type = None + + +cdef class RbcIndexFloat(RbcIndex): + cdef c_BallCoverIndex[int64_t, float, int64_t, int64_t]* index + + def __cinit__(self, dataset, handle): + cdef device_resources* handle_ = \ + handle.getHandle() + self.index = new c_BallCoverIndex[int64_t, float, int64_t, int64_t]( + deref(handle_), + get_dmv_float(dataset, check_shape=True), + _get_metric("euclidean")) + + +@auto_sync_handle +@auto_convert_output +def build_rbc_index(dataset, handle=None): + """ + Builds a random ball cover index from dataset using the L2-norm. + + Parameters + ---------- + dataset : array interface compliant matrix, row-major layout, + shape (n_samples, dim). Supported dtype [float] + {handle_docstring} + + Returns + ------- + index : Index + + Examples + -------- + see 'eps_neighbors_sparse' + + """ + if handle is None: + handle = DeviceResources() + + dataset_cai = cai_wrapper(dataset) + + # we require c-contiguous (rowmajor) inputs here + _check_input_array(dataset_cai, [np.dtype("float32")]) + + cdef device_resources* handle_ = \ + handle.getHandle() + + cdef RbcIndexFloat rbc_index_float + + if dataset_cai.dtype == np.float32: + rbc_index_float = RbcIndexFloat(dataset=dataset_cai, handle=handle) + rbc_index_float.data_type = "float32" + with cuda_interruptible(): + c_build_rbc_index( + deref(handle_), + deref(rbc_index_float.index)) + rbc_index_float.trained = True + return rbc_index_float + else: + raise TypeError("dtype %s not supported" % dataset_cai.dtype) + + +@auto_sync_handle +@auto_convert_output +def eps_neighbors(RbcIndex rbc_index, queries, eps, handle=None): + """ + Perform an epsilon neighborhood search with random ball cover (rbc) + using the L2-norm. + + Parameters + ---------- + rbc_index : RbcIndex created via 'build_rbc_index'. + Supported dtype [float] + queries : array interface compliant matrix, row-major layout, + shape (n_queries, dim) Supported dtype [float] + eps : threshold + {handle_docstring} + + Returns + ------- + adj_ia: array interface compliant object containing row indices for + adj_ja + + adj_ja: array interface compliant object containing adjacency mask + column indices + + vd: array interface compliant object containing row sums of adj + shape (n_queries + 1). vd[n_queries] contains the total sum + + Examples + -------- + >>> import cupy as cp + >>> from pylibraft.common import DeviceResources + >>> from pylibraft.neighbors.rbc import eps_neighbors + >>> from pylibraft.neighbors.rbc import build_rbc_index + >>> n_samples = 50000 + >>> n_features = 50 + >>> n_queries = 1000 + >>> dataset = cp.random.random_sample((n_samples, n_features), + ... dtype=cp.float32) + >>> queries = cp.random.random_sample((n_queries, n_features), + ... dtype=cp.float32) + >>> eps = 0.1 + >>> handle = DeviceResources() + >>> rbc_index = build_rbc_index(dataset) + >>> adj_ia, adj_ja, vd = eps_neighbors(rbc_index, queries, eps) + >>> adj_ia = cp.asarray(adj_ia) + >>> adj_ja = cp.asarray(adj_ja) + >>> vd = cp.asarray(vd) + >>> # pylibraft functions are often asynchronous so the + >>> # handle needs to be explicitly synchronized + >>> handle.sync() + """ + if not rbc_index.trained: + raise ValueError("Index need to be built before calling extend.") + + if handle is None: + handle = DeviceResources() + + queries_cai = cai_wrapper(queries) + + _check_input_array(queries_cai, [np.dtype(rbc_index.data_type)]) + + n_queries = queries_cai.shape[0] + + adj_ia = device_ndarray.empty((n_queries + 1, ), dtype='int64') + vd = device_ndarray.empty((n_queries + 1, ), dtype='int64') + adj_ia_cai = cai_wrapper(adj_ia) + vd_cai = cai_wrapper(vd) + + cdef device_resources* handle_ = \ + handle.getHandle() + + vd_vector_view = make_device_vector_view( + vd_cai.data, vd_cai.shape[0]) + adj_ia_vector_view = make_device_vector_view( + adj_ia_cai.data, adj_ia_cai.shape[0]) + + cdef RbcIndexFloat rbc_index_float + + if queries_cai.dtype == np.float32: + rbc_index_float = rbc_index + with cuda_interruptible(): + c_eps_neighbors_rbc_pass1( + deref(handle_), + deref(rbc_index_float.index), + get_dmv_float(queries_cai, check_shape=True), + adj_ia_vector_view, + vd_vector_view, + eps) + else: + raise TypeError("dtype %s not supported" % queries_cai.dtype) + + handle.sync() + n_nnz = adj_ia.copy_to_host()[n_queries] + adj_ja = device_ndarray.empty((n_nnz, ), dtype='int64') + adj_ja_cai = cai_wrapper(adj_ja) + adj_ja_vector_view = make_device_vector_view( + adj_ja_cai.data, adj_ja_cai.shape[0]) + + if queries_cai.dtype == np.float32: + with cuda_interruptible(): + c_eps_neighbors_rbc_pass2( + deref(handle_), + deref(rbc_index_float.index), + get_dmv_float(queries_cai, check_shape=True), + adj_ia_vector_view, + adj_ja_vector_view, + vd_vector_view, + eps) + else: + raise TypeError("dtype %s not supported" % queries_cai.dtype) + + return (adj_ia, adj_ja, vd) diff --git a/python/pylibraft/pylibraft/test/test_eps_neighborhood.py b/python/pylibraft/pylibraft/test/test_eps_neighborhood.py new file mode 100644 index 0000000000..f2643de904 --- /dev/null +++ b/python/pylibraft/pylibraft/test/test_eps_neighborhood.py @@ -0,0 +1,102 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import numpy as np +import pytest +from scipy.sparse import csr_array + +from pylibraft.common import DeviceResources, Stream +from pylibraft.neighbors.brute_force import eps_neighbors as eps_neighbors_bf +from pylibraft.neighbors.rbc import ( + build_rbc_index, + eps_neighbors as eps_neighbors_rbc, +) + + +def test_bf_eps_neighbors_check_col_major_inputs(): + # make sure that we get an exception if passed col-major inputs, + # instead of returning incorrect results + cp = pytest.importorskip("cupy") + n_index_rows, n_query_rows, n_cols = 128, 16, 32 + eps = 0.02 + index = cp.random.random_sample((n_index_rows, n_cols), dtype="float32") + queries = cp.random.random_sample((n_query_rows, n_cols), dtype="float32") + + with pytest.raises(ValueError): + eps_neighbors_bf(cp.asarray(index, order="F"), queries, eps) + + with pytest.raises(ValueError): + eps_neighbors_bf(index, cp.asarray(queries, order="F"), eps) + + # shouldn't throw an exception with c-contiguous inputs + eps_neighbors_bf(index, queries, eps) + + +def test_rbc_eps_neighbors_check_col_major_inputs(): + # make sure that we get an exception if passed col-major inputs, + # instead of returning incorrect results + cp = pytest.importorskip("cupy") + n_index_rows, n_query_rows, n_cols = 128, 16, 32 + eps = 0.02 + index = cp.random.random_sample((n_index_rows, n_cols), dtype="float32") + queries = cp.random.random_sample((n_query_rows, n_cols), dtype="float32") + + with pytest.raises(ValueError): + build_rbc_index(cp.asarray(index, order="F")) + + rbc_index = build_rbc_index(index) + + with pytest.raises(ValueError): + eps_neighbors_rbc(rbc_index, cp.asarray(queries, order="F"), eps) + + eps_neighbors_rbc(rbc_index, queries, eps) + + +@pytest.mark.parametrize("n_index_rows", [32, 100, 1000]) +@pytest.mark.parametrize("n_query_rows", [32, 100, 1000]) +@pytest.mark.parametrize("n_cols", [2, 3, 40, 100]) +def test_eps_neighbors(n_index_rows, n_query_rows, n_cols): + s2 = Stream() + handle = DeviceResources(stream=s2) + + cp = pytest.importorskip("cupy") + eps = 0.02 + index = cp.random.random_sample((n_index_rows, n_cols), dtype="float32") + queries = cp.random.random_sample((n_query_rows, n_cols), dtype="float32") + + # brute force + adj_bf, vd_bf = eps_neighbors_bf(index, queries, eps, handle=handle) + adj_bf = cp.asarray(adj_bf) + vd_bf = cp.asarray(vd_bf) + + rbc_index = build_rbc_index(index, handle=handle) + adj_rbc_ia, adj_rbc_ja, vd_rbc = eps_neighbors_rbc( + rbc_index, queries, eps, handle=handle + ) + adj_rbc_ia = cp.asarray(adj_rbc_ia) + adj_rbc_ja = cp.asarray(adj_rbc_ja) + vd_rbc = cp.asarray(vd_rbc) + + np.testing.assert_array_equal(vd_bf.get(), vd_rbc.get()) + + adj_rbc = csr_array( + ( + np.ones(adj_rbc_ia.get()[n_query_rows]), + adj_rbc_ja.get(), + adj_rbc_ia.get(), + ), + shape=(n_query_rows, n_index_rows), + ).toarray() + np.testing.assert_array_equal(adj_bf.get(), adj_rbc)