diff --git a/build.sh b/build.sh index 071820ba93..32582738c3 100755 --- a/build.sh +++ b/build.sh @@ -78,7 +78,7 @@ INSTALL_TARGET=install BUILD_REPORT_METRICS="" BUILD_REPORT_INCL_CACHE_STATS=OFF -TEST_TARGETS="CLUSTER_TEST;CORE_TEST;DISTANCE_TEST;LABEL_TEST;LINALG_TEST;MATRIX_TEST;NEIGHBORS_TEST;NEIGHBORS_ANN_CAGRA_TEST;RANDOM_TEST;SOLVERS_TEST;SPARSE_TEST;SPARSE_DIST_TEST;SPARSE_NEIGHBORS_TEST;STATS_TEST;UTILS_TEST" +TEST_TARGETS="CLUSTER_TEST;CORE_TEST;DISTANCE_TEST;LABEL_TEST;LINALG_TEST;MATRIX_TEST;NEIGHBORS_TEST;NEIGHBORS_ANN_CAGRA_TEST;NEIGHBORS_ANN_NN_DESCENT_TEST;RANDOM_TEST;SOLVERS_TEST;SPARSE_TEST;SPARSE_DIST_TEST;SPARSE_NEIGHBORS_TEST;STATS_TEST;UTILS_TEST" BENCH_TARGETS="CLUSTER_BENCH;NEIGHBORS_BENCH;DISTANCE_BENCH;LINALG_BENCH;MATRIX_BENCH;SPARSE_BENCH;RANDOM_BENCH" CACHE_ARGS="" diff --git a/ci/build_cpp.sh b/ci/build_cpp.sh index d2d2d08b99..a41f81152d 100755 --- a/ci/build_cpp.sh +++ b/ci/build_cpp.sh @@ -1,5 +1,5 @@ #!/bin/bash -# Copyright (c) 2022, NVIDIA CORPORATION. +# Copyright (c) 2022-2023, NVIDIA CORPORATION. set -euo pipefail diff --git a/cpp/bench/ann/src/raft/raft_benchmark.cu b/cpp/bench/ann/src/raft/raft_benchmark.cu index 7ba381ab0a..a9ff6c2922 100644 --- a/cpp/bench/ann/src/raft/raft_benchmark.cu +++ b/cpp/bench/ann/src/raft/raft_benchmark.cu @@ -147,6 +147,13 @@ void parse_build_param(const nlohmann::json& conf, if (conf.contains("intermediate_graph_degree")) { param.intermediate_graph_degree = conf.at("intermediate_graph_degree"); } + if (conf.contains("graph_build_algo")) { + if (conf.at("graph_build_algo") == "IVF_PQ") { + param.build_algo = raft::neighbors::cagra::graph_build_algo::IVF_PQ; + } else if (conf.at("graph_build_algo") == "NN_DESCENT") { + param.build_algo = raft::neighbors::cagra::graph_build_algo::NN_DESCENT; + } + } } template diff --git a/cpp/include/raft/neighbors/cagra.cuh b/cpp/include/raft/neighbors/cagra.cuh index 1bd7010c83..f96dd34e05 100644 --- a/cpp/include/raft/neighbors/cagra.cuh +++ b/cpp/include/raft/neighbors/cagra.cuh @@ -35,12 +35,11 @@ namespace raft::neighbors::cagra { */ /** - * @brief Build a kNN graph. + * @brief Build a kNN graph using IVF-PQ. * * The kNN graph is the first building block for CAGRA index. - * This function uses the IVF-PQ method to build a kNN graph. * - * The output is a dense matrix that stores the neighbor indices for each pont in the dataset. + * The output is a dense matrix that stores the neighbor indices for each point in the dataset. * Each point has the same number of neighbors. * * See [cagra::build](#cagra::build) for an alternative method. @@ -52,9 +51,9 @@ namespace raft::neighbors::cagra { * @code{.cpp} * using namespace raft::neighbors; * // use default index parameters - * cagra::index_params build_params; - * cagra::search_params search_params - * auto knn_graph = raft::make_host_matrix(dataset.extent(0), 128); + * ivf_pq::index_params build_params; + * ivf_pq::search_params search_params + * auto knn_graph = raft::make_host_matrix(dataset.extent(0), 128); * // create knn graph * cagra::build_knn_graph(res, dataset, knn_graph.view(), 2, build_params, search_params); * auto optimized_gaph = raft::make_host_matrix(dataset.extent(0), 64); @@ -70,7 +69,7 @@ namespace raft::neighbors::cagra { * @param[in] res raft resources * @param[in] dataset a matrix view (host or device) to a row-major matrix [n_rows, dim] * @param[out] knn_graph a host matrix view to store the output knn graph [n_rows, graph_degree] - * @param[in] refine_rate refinement rate for ivf-pq search + * @param[in] refine_rate (optional) refinement rate for ivf-pq search * @param[in] build_params (optional) ivf_pq index building parameters for knn graph * @param[in] search_params (optional) ivf_pq search parameters */ @@ -95,6 +94,58 @@ void build_knn_graph(raft::resources const& res, res, dataset_internal, knn_graph_internal, refine_rate, build_params, search_params); } +/** + * @brief Build a kNN graph using NN-descent. + * + * The kNN graph is the first building block for CAGRA index. + * + * The output is a dense matrix that stores the neighbor indices for each point in the dataset. + * Each point has the same number of neighbors. + * + * See [cagra::build](#cagra::build) for an alternative method. + * + * The following distance metrics are supported: + * - L2Expanded + * + * Usage example: + * @code{.cpp} + * using namespace raft::neighbors; + * using namespace raft::neighbors::experimental; + * // use default index parameters + * nn_descent::index_params build_params; + * build_params.graph_degree = 128; + * auto knn_graph = raft::make_host_matrix(dataset.extent(0), 128); + * // create knn graph + * cagra::build_knn_graph(res, dataset, knn_graph.view(), build_params); + * auto optimized_gaph = raft::make_host_matrix(dataset.extent(0), 64); + * cagra::optimize(res, dataset, nn_descent_index.graph.view(), optimized_graph.view()); + * // Construct an index from dataset and optimized knn_graph + * auto index = cagra::index(res, build_params.metric(), dataset, + * optimized_graph.view()); + * @endcode + * + * @tparam DataT data element type + * @tparam IdxT type of the dataset vector indices + * @tparam accessor host or device accessor_type for the dataset + * @param[in] res raft::resources is an object mangaging resources + * @param[in] dataset input raft::host/device_matrix_view that can be located in + * in host or device memory + * @param[out] knn_graph a host matrix view to store the output knn graph [n_rows, graph_degree] + * @param[in] build_params an instance of experimental::nn_descent::index_params that are parameters + * to run the nn-descent algorithm + */ +template , memory_type::device>> +void build_knn_graph(raft::resources const& res, + mdspan, row_major, accessor> dataset, + raft::host_matrix_view knn_graph, + experimental::nn_descent::index_params build_params) +{ + detail::build_knn_graph(res, dataset, knn_graph, build_params); +} + /** * @brief Sort a KNN graph index. * Preprocessing step for `cagra::optimize`: If a KNN graph is not built using @@ -259,7 +310,16 @@ index build(raft::resources const& res, std::optional> knn_graph( raft::make_host_matrix(dataset.extent(0), intermediate_degree)); - build_knn_graph(res, dataset, knn_graph->view()); + if (params.build_algo == graph_build_algo::IVF_PQ) { + build_knn_graph(res, dataset, knn_graph->view()); + + } else { + // Use nn-descent to build CAGRA knn graph + auto nn_descent_params = experimental::nn_descent::index_params(); + nn_descent_params.graph_degree = intermediate_degree; + nn_descent_params.intermediate_graph_degree = 1.5 * intermediate_degree; + build_knn_graph(res, dataset, knn_graph->view(), nn_descent_params); + } auto cagra_graph = raft::make_host_matrix(dataset.extent(0), graph_degree); diff --git a/cpp/include/raft/neighbors/cagra_types.hpp b/cpp/include/raft/neighbors/cagra_types.hpp index 4728178194..5061d6082d 100644 --- a/cpp/include/raft/neighbors/cagra_types.hpp +++ b/cpp/include/raft/neighbors/cagra_types.hpp @@ -40,11 +40,24 @@ namespace raft::neighbors::cagra { * @{ */ +/** + * @brief ANN algorithm used by CAGRA to build knn graph + * + */ +enum class graph_build_algo { + /* Use IVF-PQ to build all-neighbors knn graph */ + IVF_PQ, + /* Experimental, use NN-Descent to build all-neighbors knn graph */ + NN_DESCENT +}; + struct index_params : ann::index_params { /** Degree of input graph for pruning. */ size_t intermediate_graph_degree = 128; /** Degree of output graph. */ size_t graph_degree = 64; + /** ANN algorithm to build knn graph. */ + graph_build_algo build_algo = graph_build_algo::IVF_PQ; }; enum class search_algo { @@ -362,6 +375,7 @@ struct index : ann::index { // TODO: Remove deprecated experimental namespace in 23.12 release namespace raft::neighbors::experimental::cagra { +using raft::neighbors::cagra::graph_build_algo; using raft::neighbors::cagra::hash_mode; using raft::neighbors::cagra::index; using raft::neighbors::cagra::index_params; diff --git a/cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh b/cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh index 80e964df57..40024a3deb 100644 --- a/cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh @@ -35,6 +35,7 @@ #include #include #include +#include #include namespace raft::neighbors::cagra::detail { @@ -240,4 +241,27 @@ void build_knn_graph(raft::resources const& res, if (!first) RAFT_LOG_DEBUG("# Finished building kNN graph"); } +template +void build_knn_graph(raft::resources const& res, + mdspan, row_major, accessor> dataset, + raft::host_matrix_view knn_graph, + experimental::nn_descent::index_params build_params) +{ + auto nn_descent_idx = experimental::nn_descent::index(res, knn_graph); + experimental::nn_descent::build(res, build_params, dataset, nn_descent_idx); + + using internal_IdxT = typename std::make_unsigned::type; + using g_accessor = typename decltype(nn_descent_idx.graph())::accessor_type; + using g_accessor_internal = + host_device_accessor, g_accessor::mem_type>; + + auto knn_graph_internal = + mdspan, row_major, g_accessor_internal>( + reinterpret_cast(nn_descent_idx.graph().data_handle()), + nn_descent_idx.graph().extent(0), + nn_descent_idx.graph().extent(1)); + + graph::sort_knn_graph(res, dataset, knn_graph_internal); +} + } // namespace raft::neighbors::cagra::detail diff --git a/cpp/include/raft/neighbors/detail/cagra/graph_core.cuh b/cpp/include/raft/neighbors/detail/cagra/graph_core.cuh index 18d451be60..8845e37973 100644 --- a/cpp/include/raft/neighbors/detail/cagra/graph_core.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/graph_core.cuh @@ -244,7 +244,7 @@ void sort_knn_graph(raft::resources const& res, const uint32_t input_graph_degree = knn_graph.extent(1); IdxT* const input_graph_ptr = knn_graph.data_handle(); - auto d_input_graph = raft::make_device_matrix(res, graph_size, input_graph_degree); + auto d_input_graph = raft::make_device_matrix(res, graph_size, input_graph_degree); // // Sorting kNN graph diff --git a/cpp/include/raft/neighbors/detail/nn_descent.cuh b/cpp/include/raft/neighbors/detail/nn_descent.cuh new file mode 100644 index 0000000000..3e4d0409bd --- /dev/null +++ b/cpp/include/raft/neighbors/detail/nn_descent.cuh @@ -0,0 +1,1453 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include + +#include +#include +#include + +#include + +#include +#include +#include +#include +#include + +#include "../nn_descent_types.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include // raft::util::arch::SM_* +#include +#include +#include +#include + +namespace raft::neighbors::experimental::nn_descent::detail { + +using pinned_memory_resource = thrust::universal_host_pinned_memory_resource; +template +using pinned_memory_allocator = thrust::mr::stateless_resource_allocator; + +using DistData_t = float; +constexpr int DEGREE_ON_DEVICE{32}; +constexpr int SEGMENT_SIZE{32}; +constexpr int counter_interval{100}; +template +struct InternalID_t; + +// InternalID_t uses 1 bit for marking (new or old). +template <> +class InternalID_t { + private: + using Index_t = int; + Index_t id_{std::numeric_limits::max()}; + + public: + __host__ __device__ bool is_new() const { return id_ >= 0; } + __host__ __device__ Index_t& id_with_flag() { return id_; } + __host__ __device__ Index_t id() const + { + if (is_new()) return id_; + return -id_ - 1; + } + __host__ __device__ void mark_old() + { + if (id_ >= 0) id_ = -id_ - 1; + } + __host__ __device__ bool operator==(const InternalID_t& other) const + { + return id() == other.id(); + } +}; + +template +struct ResultItem; + +template <> +class ResultItem { + private: + using Index_t = int; + Index_t id_; + DistData_t dist_; + + public: + __host__ __device__ ResultItem() + : id_(std::numeric_limits::max()), dist_(std::numeric_limits::max()){}; + __host__ __device__ ResultItem(const Index_t id_with_flag, const DistData_t dist) + : id_(id_with_flag), dist_(dist){}; + __host__ __device__ bool is_new() const { return id_ >= 0; } + __host__ __device__ Index_t& id_with_flag() { return id_; } + __host__ __device__ Index_t id() const + { + if (is_new()) return id_; + return -id_ - 1; + } + __host__ __device__ DistData_t& dist() { return dist_; } + + __host__ __device__ void mark_old() + { + if (id_ >= 0) id_ = -id_ - 1; + } + + __host__ __device__ bool operator<(const ResultItem& other) const + { + if (dist_ == other.dist_) return id() < other.id(); + return dist_ < other.dist_; + } + __host__ __device__ bool operator==(const ResultItem& other) const + { + return id() == other.id(); + } + __host__ __device__ bool operator>=(const ResultItem& other) const + { + return !(*this < other); + } + __host__ __device__ bool operator<=(const ResultItem& other) const + { + return (*this == other) || (*this < other); + } + __host__ __device__ bool operator>(const ResultItem& other) const + { + return !(*this <= other); + } + __host__ __device__ bool operator!=(const ResultItem& other) const + { + return !(*this == other); + } +}; + +using align32 = raft::Pow2<32>; + +template +int get_batch_size(const int it_now, const T nrow, const int batch_size) +{ + int it_total = ceildiv(nrow, batch_size); + return (it_now == it_total - 1) ? nrow - it_now * batch_size : batch_size; +} + +// for avoiding bank conflict +template +constexpr __host__ __device__ __forceinline__ int skew_dim(int ndim) +{ + // all "4"s are for alignment + if constexpr (std::is_same::value) { + ndim = ceildiv(ndim, 4) * 4; + return ndim + (ndim % 32 == 0) * 4; + } +} + +template +__device__ __forceinline__ ResultItem xor_swap(ResultItem x, int mask, int dir) +{ + ResultItem y; + y.dist() = __shfl_xor_sync(raft::warp_full_mask(), x.dist(), mask, raft::warp_size()); + y.id_with_flag() = + __shfl_xor_sync(raft::warp_full_mask(), x.id_with_flag(), mask, raft::warp_size()); + return x < y == dir ? y : x; +} + +__device__ __forceinline__ int xor_swap(int x, int mask, int dir) +{ + int y = __shfl_xor_sync(raft::warp_full_mask(), x, mask, raft::warp_size()); + return x < y == dir ? y : x; +} + +// TODO: Move to RAFT utils https://github.com/rapidsai/raft/issues/1827 +__device__ __forceinline__ uint bfe(uint lane_id, uint pos) +{ + uint res; + asm("bfe.u32 %0,%1,%2,%3;" : "=r"(res) : "r"(lane_id), "r"(pos), "r"(1)); + return res; +} + +template +__device__ __forceinline__ void warp_bitonic_sort(T* element_ptr, const int lane_id) +{ + static_assert(raft::warp_size() == 32); + auto& element = *element_ptr; + element = xor_swap(element, 0x01, bfe(lane_id, 1) ^ bfe(lane_id, 0)); + element = xor_swap(element, 0x02, bfe(lane_id, 2) ^ bfe(lane_id, 1)); + element = xor_swap(element, 0x01, bfe(lane_id, 2) ^ bfe(lane_id, 0)); + element = xor_swap(element, 0x04, bfe(lane_id, 3) ^ bfe(lane_id, 2)); + element = xor_swap(element, 0x02, bfe(lane_id, 3) ^ bfe(lane_id, 1)); + element = xor_swap(element, 0x01, bfe(lane_id, 3) ^ bfe(lane_id, 0)); + element = xor_swap(element, 0x08, bfe(lane_id, 4) ^ bfe(lane_id, 3)); + element = xor_swap(element, 0x04, bfe(lane_id, 4) ^ bfe(lane_id, 2)); + element = xor_swap(element, 0x02, bfe(lane_id, 4) ^ bfe(lane_id, 1)); + element = xor_swap(element, 0x01, bfe(lane_id, 4) ^ bfe(lane_id, 0)); + element = xor_swap(element, 0x10, bfe(lane_id, 4)); + element = xor_swap(element, 0x08, bfe(lane_id, 3)); + element = xor_swap(element, 0x04, bfe(lane_id, 2)); + element = xor_swap(element, 0x02, bfe(lane_id, 1)); + element = xor_swap(element, 0x01, bfe(lane_id, 0)); + return; +} + +struct BuildConfig { + size_t max_dataset_size; + size_t dataset_dim; + size_t node_degree{64}; + size_t internal_node_degree{0}; + // If internal_node_degree == 0, the value of node_degree will be assigned to it + size_t max_iterations{50}; + float termination_threshold{0.0001}; +}; + +template +class BloomFilter { + public: + BloomFilter(size_t nrow, size_t num_sets_per_list, size_t num_hashs) + : nrow_(nrow), + num_sets_per_list_(num_sets_per_list), + num_hashs_(num_hashs), + bitsets_(nrow * num_bits_per_set_ * num_sets_per_list) + { + } + + void add(size_t list_id, Index_t key) + { + if (is_cleared) { is_cleared = false; } + uint32_t hash = hash_0(key); + size_t global_set_idx = list_id * num_bits_per_set_ * num_sets_per_list_ + + key % num_sets_per_list_ * num_bits_per_set_; + bitsets_[global_set_idx + hash % num_bits_per_set_] = 1; + for (size_t i = 1; i < num_hashs_; i++) { + hash = hash + hash_1(key); + bitsets_[global_set_idx + hash % num_bits_per_set_] = 1; + } + } + + bool check(size_t list_id, Index_t key) + { + bool is_present = true; + uint32_t hash = hash_0(key); + size_t global_set_idx = list_id * num_bits_per_set_ * num_sets_per_list_ + + key % num_sets_per_list_ * num_bits_per_set_; + is_present &= bitsets_[global_set_idx + hash % num_bits_per_set_]; + + if (!is_present) return false; + for (size_t i = 1; i < num_hashs_; i++) { + hash = hash + hash_1(key); + is_present &= bitsets_[global_set_idx + hash % num_bits_per_set_]; + if (!is_present) return false; + } + return true; + } + + void clear() + { + if (is_cleared) return; +#pragma omp parallel for + for (size_t i = 0; i < nrow_ * num_bits_per_set_ * num_sets_per_list_; i++) { + bitsets_[i] = 0; + } + is_cleared = true; + } + + private: + uint32_t hash_0(uint32_t value) + { + value *= 1103515245; + value += 12345; + value ^= value << 13; + value ^= value >> 17; + value ^= value << 5; + return value; + } + + uint32_t hash_1(uint32_t value) + { + value *= 1664525; + value += 1013904223; + value ^= value << 13; + value ^= value >> 17; + value ^= value << 5; + return value; + } + + static constexpr int num_bits_per_set_ = 512; + bool is_cleared{true}; + std::vector bitsets_; + size_t nrow_; + size_t num_sets_per_list_; + size_t num_hashs_; +}; + +template +struct GnndGraph { + static constexpr int segment_size = 32; + InternalID_t* h_graph; + + size_t nrow; + size_t node_degree; + int num_samples; + int num_segments; + + raft::host_matrix h_dists; + + thrust::host_vector> h_graph_new; + thrust::host_vector> h_list_sizes_new; + + thrust::host_vector> h_graph_old; + thrust::host_vector> h_list_sizes_old; + BloomFilter bloom_filter; + + GnndGraph(const GnndGraph&) = delete; + GnndGraph& operator=(const GnndGraph&) = delete; + GnndGraph(const size_t nrow, + const size_t node_degree, + const size_t internal_node_degree, + const size_t num_samples); + void init_random_graph(); + // TODO: Create a generic bloom filter utility https://github.com/rapidsai/raft/issues/1827 + // Use Bloom filter to sample "new" neighbors for local joining + void sample_graph_new(InternalID_t* new_neighbors, const size_t width); + void sample_graph(bool sample_new); + void update_graph(const InternalID_t* new_neighbors, + const DistData_t* new_dists, + const size_t width, + std::atomic& update_counter); + void sort_lists(); + void clear(); + ~GnndGraph(); +}; + +template +class GNND { + public: + GNND(raft::resources const& res, const BuildConfig& build_config); + GNND(const GNND&) = delete; + GNND& operator=(const GNND&) = delete; + + void build(Data_t* data, const Index_t nrow, Index_t* output_graph); + ~GNND() = default; + using ID_t = InternalID_t; + + private: + void add_reverse_edges(Index_t* graph_ptr, + Index_t* h_rev_graph_ptr, + Index_t* d_rev_graph_ptr, + int2* list_sizes, + cudaStream_t stream = 0); + void local_join(cudaStream_t stream = 0); + + raft::resources const& res; + + BuildConfig build_config_; + GnndGraph graph_; + std::atomic update_counter_; + + Index_t nrow_; + const int ndim_; + + raft::device_matrix<__half, Index_t, raft::row_major> d_data_; + raft::device_vector l2_norms_; + + raft::device_matrix graph_buffer_; + raft::device_matrix dists_buffer_; + + // TODO: Investigate using RMM/RAFT types https://github.com/rapidsai/raft/issues/1827 + thrust::host_vector> graph_host_buffer_; + thrust::host_vector> dists_host_buffer_; + + raft::device_vector d_locks_; + + thrust::host_vector> h_rev_graph_new_; + thrust::host_vector> h_graph_old_; + thrust::host_vector> h_rev_graph_old_; + // int2.x is the number of forward edges, int2.y is the number of reverse edges + + raft::device_vector d_list_sizes_new_; + raft::device_vector d_list_sizes_old_; +}; + +constexpr int TILE_ROW_WIDTH = 64; +constexpr int TILE_COL_WIDTH = 128; + +constexpr int NUM_SAMPLES = 32; +// For now, the max. number of samples is 32, so the sample cache size is fixed +// to 64 (32 * 2). +constexpr int MAX_NUM_BI_SAMPLES = 64; +constexpr int SKEWED_MAX_NUM_BI_SAMPLES = skew_dim(MAX_NUM_BI_SAMPLES); +constexpr int BLOCK_SIZE = 512; +constexpr int WMMA_M = 16; +constexpr int WMMA_N = 16; +constexpr int WMMA_K = 16; + +template +__device__ __forceinline__ void load_vec(Data_t* vec_buffer, + const Data_t* d_vec, + const int load_dims, + const int padding_dims, + const int lane_id) +{ + if constexpr (std::is_same_v or std::is_same_v or + std::is_same_v) { + constexpr int num_load_elems_per_warp = raft::warp_size(); + for (int step = 0; step < ceildiv(padding_dims, num_load_elems_per_warp); step++) { + int idx = step * num_load_elems_per_warp + lane_id; + if (idx < load_dims) { + vec_buffer[idx] = d_vec[idx]; + } else if (idx < padding_dims) { + vec_buffer[idx] = 0.0f; + } + } + } + if constexpr (std::is_same_v) { + if ((size_t)d_vec % sizeof(float2) == 0 && (size_t)vec_buffer % sizeof(float2) == 0 && + load_dims % 4 == 0 && padding_dims % 4 == 0) { + constexpr int num_load_elems_per_warp = raft::warp_size() * 4; +#pragma unroll + for (int step = 0; step < ceildiv(padding_dims, num_load_elems_per_warp); step++) { + int idx_in_vec = step * num_load_elems_per_warp + lane_id * 4; + if (idx_in_vec + 4 <= load_dims) { + *(float2*)(vec_buffer + idx_in_vec) = *(float2*)(d_vec + idx_in_vec); + } else if (idx_in_vec + 4 <= padding_dims) { + *(float2*)(vec_buffer + idx_in_vec) = float2({0.0f, 0.0f}); + } + } + } else { + constexpr int num_load_elems_per_warp = raft::warp_size(); + for (int step = 0; step < ceildiv(padding_dims, num_load_elems_per_warp); step++) { + int idx = step * num_load_elems_per_warp + lane_id; + if (idx < load_dims) { + vec_buffer[idx] = d_vec[idx]; + } else if (idx < padding_dims) { + vec_buffer[idx] = 0.0f; + } + } + } + } +} + +// TODO: Replace with RAFT utilities https://github.com/rapidsai/raft/issues/1827 +/** Calculate L2 norm, and cast data to __half */ +template +__global__ void preprocess_data_kernel(const Data_t* input_data, + __half* output_data, + int dim, + DistData_t* l2_norms, + size_t list_offset = 0) +{ + extern __shared__ char buffer[]; + __shared__ float l2_norm; + Data_t* s_vec = (Data_t*)buffer; + size_t list_id = list_offset + blockIdx.x; + + load_vec(s_vec, input_data + blockIdx.x * dim, dim, dim, threadIdx.x % raft::warp_size()); + if (threadIdx.x == 0) { l2_norm = 0; } + __syncthreads(); + int lane_id = threadIdx.x % raft::warp_size(); + for (int step = 0; step < ceildiv(dim, raft::warp_size()); step++) { + int idx = step * raft::warp_size() + lane_id; + float part_dist = 0; + if (idx < dim) { + part_dist = s_vec[idx]; + part_dist = part_dist * part_dist; + } + __syncwarp(); + for (int offset = raft::warp_size() >> 1; offset >= 1; offset >>= 1) { + part_dist += __shfl_down_sync(raft::warp_full_mask(), part_dist, offset); + } + if (lane_id == 0) { l2_norm += part_dist; } + __syncwarp(); + } + + for (int step = 0; step < ceildiv(dim, raft::warp_size()); step++) { + int idx = step * raft::warp_size() + threadIdx.x; + if (idx < dim) { + if (l2_norms == nullptr) { + output_data[list_id * dim + idx] = + (float)input_data[(size_t)blockIdx.x * dim + idx] / sqrt(l2_norm); + } else { + output_data[list_id * dim + idx] = input_data[(size_t)blockIdx.x * dim + idx]; + if (idx == 0) { l2_norms[list_id] = l2_norm; } + } + } + } +} + +template +__global__ void add_rev_edges_kernel(const Index_t* graph, + Index_t* rev_graph, + int num_samples, + int2* list_sizes) +{ + size_t list_id = blockIdx.x; + int2 list_size = list_sizes[list_id]; + + for (int idx = threadIdx.x; idx < list_size.x; idx += blockDim.x) { + // each node has same number (num_samples) of forward and reverse edges + size_t rev_list_id = graph[list_id * num_samples + idx]; + // there are already num_samples forward edges + int idx_in_rev_list = atomicAdd(&list_sizes[rev_list_id].y, 1); + if (idx_in_rev_list >= num_samples) { + atomicExch(&list_sizes[rev_list_id].y, num_samples); + } else { + rev_graph[rev_list_id * num_samples + idx_in_rev_list] = list_id; + } + } +} + +template > +__device__ void insert_to_global_graph(ResultItem elem, + size_t list_id, + ID_t* graph, + DistData_t* dists, + int node_degree, + int* locks) +{ + int tx = threadIdx.x; + int lane_id = tx % raft::warp_size(); + size_t global_idx_base = list_id * node_degree; + if (elem.id() == list_id) return; + + const int num_segments = ceildiv(node_degree, raft::warp_size()); + + int loop_flag = 0; + do { + int segment_id = elem.id() % num_segments; + if (lane_id == 0) { + loop_flag = atomicCAS(&locks[list_id * num_segments + segment_id], 0, 1) == 0; + } + + loop_flag = __shfl_sync(raft::warp_full_mask(), loop_flag, 0); + + if (loop_flag == 1) { + ResultItem knn_list_frag; + int local_idx = segment_id * raft::warp_size() + lane_id; + size_t global_idx = global_idx_base + local_idx; + if (local_idx < node_degree) { + knn_list_frag.id_with_flag() = graph[global_idx].id_with_flag(); + knn_list_frag.dist() = dists[global_idx]; + } + + int pos_to_insert = -1; + ResultItem prev_elem; + + prev_elem.id_with_flag() = + __shfl_up_sync(raft::warp_full_mask(), knn_list_frag.id_with_flag(), 1); + prev_elem.dist() = __shfl_up_sync(raft::warp_full_mask(), knn_list_frag.dist(), 1); + + if (lane_id == 0) { + prev_elem = ResultItem{std::numeric_limits::min(), + std::numeric_limits::lowest()}; + } + if (elem > prev_elem && elem < knn_list_frag) { + pos_to_insert = segment_id * raft::warp_size() + lane_id; + } else if (elem == prev_elem || elem == knn_list_frag) { + pos_to_insert = -2; + } + uint mask = __ballot_sync(raft::warp_full_mask(), pos_to_insert >= 0); + if (mask) { + uint set_lane_id = __fns(mask, 0, 1); + pos_to_insert = __shfl_sync(raft::warp_full_mask(), pos_to_insert, set_lane_id); + } + + if (pos_to_insert >= 0) { + int local_idx = segment_id * raft::warp_size() + lane_id; + if (local_idx > pos_to_insert) { + local_idx++; + } else if (local_idx == pos_to_insert) { + graph[global_idx_base + local_idx].id_with_flag() = elem.id_with_flag(); + dists[global_idx_base + local_idx] = elem.dist(); + local_idx++; + } + size_t global_pos = global_idx_base + local_idx; + if (local_idx < (segment_id + 1) * raft::warp_size() && local_idx < node_degree) { + graph[global_pos].id_with_flag() = knn_list_frag.id_with_flag(); + dists[global_pos] = knn_list_frag.dist(); + } + } + __threadfence(); + if (loop_flag && lane_id == 0) { atomicExch(&locks[list_id * num_segments + segment_id], 0); } + } + } while (!loop_flag); +} + +template +__device__ ResultItem get_min_item(const Index_t id, + const int idx_in_list, + const Index_t* neighbs, + const DistData_t* distances, + const bool find_in_row = true) +{ + int lane_id = threadIdx.x % raft::warp_size(); + + static_assert(MAX_NUM_BI_SAMPLES == 64); + int idx[MAX_NUM_BI_SAMPLES / raft::warp_size()]; + float dist[MAX_NUM_BI_SAMPLES / raft::warp_size()] = {std::numeric_limits::max(), + std::numeric_limits::max()}; + idx[0] = lane_id; + idx[1] = raft::warp_size() + lane_id; + + if (neighbs[idx[0]] != id) { + dist[0] = find_in_row ? distances[idx_in_list * SKEWED_MAX_NUM_BI_SAMPLES + lane_id] + : distances[idx_in_list + lane_id * SKEWED_MAX_NUM_BI_SAMPLES]; + } + + if (neighbs[idx[1]] != id) { + dist[1] = + find_in_row + ? distances[idx_in_list * SKEWED_MAX_NUM_BI_SAMPLES + raft::warp_size() + lane_id] + : distances[idx_in_list + (raft::warp_size() + lane_id) * SKEWED_MAX_NUM_BI_SAMPLES]; + } + + if (dist[1] < dist[0]) { + dist[0] = dist[1]; + idx[0] = idx[1]; + } + __syncwarp(); + for (int offset = raft::warp_size() >> 1; offset >= 1; offset >>= 1) { + float other_idx = __shfl_down_sync(raft::warp_full_mask(), idx[0], offset); + float other_dist = __shfl_down_sync(raft::warp_full_mask(), dist[0], offset); + if (other_dist < dist[0]) { + dist[0] = other_dist; + idx[0] = other_idx; + } + } + + ResultItem result; + result.dist() = __shfl_sync(raft::warp_full_mask(), dist[0], 0); + result.id_with_flag() = neighbs[__shfl_sync(raft::warp_full_mask(), idx[0], 0)]; + return result; +} + +template +__device__ __forceinline__ void remove_duplicates( + T* list_a, int list_a_size, T* list_b, int list_b_size, int& unique_counter, int execute_warp_id) +{ + static_assert(raft::warp_size() == 32); + if (!(threadIdx.x >= execute_warp_id * raft::warp_size() && + threadIdx.x < execute_warp_id * raft::warp_size() + raft::warp_size())) { + return; + } + int lane_id = threadIdx.x % raft::warp_size(); + T elem = std::numeric_limits::max(); + if (lane_id < list_a_size) { elem = list_a[lane_id]; } + warp_bitonic_sort(&elem, lane_id); + + if (elem != std::numeric_limits::max()) { list_a[lane_id] = elem; } + + T elem_b = std::numeric_limits::max(); + + if (lane_id < list_b_size) { elem_b = list_b[lane_id]; } + __syncwarp(); + + int idx_l = 0; + int idx_r = list_a_size; + bool existed = false; + while (idx_l < idx_r) { + int idx = (idx_l + idx_r) / 2; + int elem = list_a[idx]; + if (elem == elem_b) { + existed = true; + break; + } + if (elem_b > elem) { + idx_l = idx + 1; + } else { + idx_r = idx; + } + } + if (!existed && elem_b != std::numeric_limits::max()) { + int idx = atomicAdd(&unique_counter, 1); + list_a[list_a_size + idx] = elem_b; + } +} + +// launch_bounds here denote BLOCK_SIZE = 512 and MIN_BLOCKS_PER_SM = 4 +// Per +// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#features-and-technical-specifications, +// MAX_RESIDENT_THREAD_PER_SM = BLOCK_SIZE * BLOCKS_PER_SM = 2048 +// For architectures 750 and 860, the values for MAX_RESIDENT_THREAD_PER_SM +// is 1024 and 1536 respectively, which means the bounds don't work anymore +template > +__global__ void +#ifdef __CUDA_ARCH__ +#if (__CUDA_ARCH__) == 750 || (__CUDA_ARCH__) == 860 +__launch_bounds__(BLOCK_SIZE) +#else +__launch_bounds__(BLOCK_SIZE, 4) +#endif +#endif + local_join_kernel(const Index_t* graph_new, + const Index_t* rev_graph_new, + const int2* sizes_new, + const Index_t* graph_old, + const Index_t* rev_graph_old, + const int2* sizes_old, + const int width, + const __half* data, + const int data_dim, + ID_t* graph, + DistData_t* dists, + int graph_width, + int* locks, + DistData_t* l2_norms) +{ +#if (__CUDA_ARCH__ >= 700) + using namespace nvcuda; + __shared__ int s_list[MAX_NUM_BI_SAMPLES * 2]; + + constexpr int APAD = 8; + constexpr int BPAD = 8; + __shared__ __half s_nv[MAX_NUM_BI_SAMPLES][TILE_COL_WIDTH + APAD]; // New vectors + __shared__ __half s_ov[MAX_NUM_BI_SAMPLES][TILE_COL_WIDTH + BPAD]; // Old vectors + static_assert(sizeof(float) * MAX_NUM_BI_SAMPLES * SKEWED_MAX_NUM_BI_SAMPLES <= + sizeof(__half) * MAX_NUM_BI_SAMPLES * (TILE_COL_WIDTH + BPAD)); + // s_distances: MAX_NUM_BI_SAMPLES x SKEWED_MAX_NUM_BI_SAMPLES, reuse the space of s_ov + float* s_distances = (float*)&s_ov[0][0]; + int* s_unique_counter = (int*)&s_ov[0][0]; + + if (threadIdx.x == 0) { + s_unique_counter[0] = 0; + s_unique_counter[1] = 0; + } + + Index_t* new_neighbors = s_list; + Index_t* old_neighbors = s_list + MAX_NUM_BI_SAMPLES; + + size_t list_id = blockIdx.x; + int2 list_new_size2 = sizes_new[list_id]; + int list_new_size = list_new_size2.x + list_new_size2.y; + int2 list_old_size2 = sizes_old[list_id]; + int list_old_size = list_old_size2.x + list_old_size2.y; + + if (!list_new_size) return; + int tx = threadIdx.x; + + if (tx < list_new_size2.x) { + new_neighbors[tx] = graph_new[list_id * width + tx]; + } else if (tx >= list_new_size2.x && tx < list_new_size) { + new_neighbors[tx] = rev_graph_new[list_id * width + tx - list_new_size2.x]; + } + + if (tx < list_old_size2.x) { + old_neighbors[tx] = graph_old[list_id * width + tx]; + } else if (tx >= list_old_size2.x && tx < list_old_size) { + old_neighbors[tx] = rev_graph_old[list_id * width + tx - list_old_size2.x]; + } + + __syncthreads(); + + remove_duplicates(new_neighbors, + list_new_size2.x, + new_neighbors + list_new_size2.x, + list_new_size2.y, + s_unique_counter[0], + 0); + + remove_duplicates(old_neighbors, + list_old_size2.x, + old_neighbors + list_old_size2.x, + list_old_size2.y, + s_unique_counter[1], + 1); + __syncthreads(); + list_new_size = list_new_size2.x + s_unique_counter[0]; + list_old_size = list_old_size2.x + s_unique_counter[1]; + + int warp_id = threadIdx.x / raft::warp_size(); + int lane_id = threadIdx.x % raft::warp_size(); + constexpr int num_warps = BLOCK_SIZE / raft::warp_size(); + + int warp_id_y = warp_id / 4; + int warp_id_x = warp_id % 4; + + wmma::fragment a_frag; + wmma::fragment b_frag; + wmma::fragment c_frag; + wmma::fill_fragment(c_frag, 0.0); + for (int step = 0; step < ceildiv(data_dim, TILE_COL_WIDTH); step++) { + int num_load_elems = (step == ceildiv(data_dim, TILE_COL_WIDTH) - 1) + ? data_dim - step * TILE_COL_WIDTH + : TILE_COL_WIDTH; +#pragma unroll + for (int i = 0; i < MAX_NUM_BI_SAMPLES / num_warps; i++) { + int idx = i * num_warps + warp_id; + if (idx < list_new_size) { + size_t neighbor_id = new_neighbors[idx]; + size_t idx_in_data = neighbor_id * data_dim; + load_vec(s_nv[idx], + data + idx_in_data + step * TILE_COL_WIDTH, + num_load_elems, + TILE_COL_WIDTH, + lane_id); + } + } + __syncthreads(); + + for (int i = 0; i < TILE_COL_WIDTH / WMMA_K; i++) { + wmma::load_matrix_sync(a_frag, s_nv[warp_id_y * WMMA_M] + i * WMMA_K, TILE_COL_WIDTH + APAD); + wmma::load_matrix_sync(b_frag, s_nv[warp_id_x * WMMA_N] + i * WMMA_K, TILE_COL_WIDTH + BPAD); + wmma::mma_sync(c_frag, a_frag, b_frag, c_frag); + __syncthreads(); + } + } + + wmma::store_matrix_sync( + s_distances + warp_id_y * WMMA_M * SKEWED_MAX_NUM_BI_SAMPLES + warp_id_x * WMMA_N, + c_frag, + SKEWED_MAX_NUM_BI_SAMPLES, + wmma::mem_row_major); + __syncthreads(); + + for (int i = threadIdx.x; i < MAX_NUM_BI_SAMPLES * SKEWED_MAX_NUM_BI_SAMPLES; i += blockDim.x) { + if (i % SKEWED_MAX_NUM_BI_SAMPLES < list_new_size && + i / SKEWED_MAX_NUM_BI_SAMPLES < list_new_size) { + if (l2_norms == nullptr) { + s_distances[i] = -s_distances[i]; + } else { + s_distances[i] = l2_norms[new_neighbors[i % SKEWED_MAX_NUM_BI_SAMPLES]] + + l2_norms[new_neighbors[i / SKEWED_MAX_NUM_BI_SAMPLES]] - + 2.0 * s_distances[i]; + } + } else { + s_distances[i] = std::numeric_limits::max(); + } + } + __syncthreads(); + + for (int step = 0; step < ceildiv(list_new_size, num_warps); step++) { + int idx_in_list = step * num_warps + tx / raft::warp_size(); + if (idx_in_list >= list_new_size) continue; + auto min_elem = get_min_item(s_list[idx_in_list], idx_in_list, new_neighbors, s_distances); + if (min_elem.id() < gridDim.x) { + insert_to_global_graph(min_elem, s_list[idx_in_list], graph, dists, graph_width, locks); + } + } + + if (!list_old_size) return; + + __syncthreads(); + + wmma::fill_fragment(c_frag, 0.0); + for (int step = 0; step < ceildiv(data_dim, TILE_COL_WIDTH); step++) { + int num_load_elems = (step == ceildiv(data_dim, TILE_COL_WIDTH) - 1) + ? data_dim - step * TILE_COL_WIDTH + : TILE_COL_WIDTH; + if (TILE_COL_WIDTH < data_dim) { +#pragma unroll + for (int i = 0; i < MAX_NUM_BI_SAMPLES / num_warps; i++) { + int idx = i * num_warps + warp_id; + if (idx < list_new_size) { + size_t neighbor_id = new_neighbors[idx]; + size_t idx_in_data = neighbor_id * data_dim; + load_vec(s_nv[idx], + data + idx_in_data + step * TILE_COL_WIDTH, + num_load_elems, + TILE_COL_WIDTH, + lane_id); + } + } + } +#pragma unroll + for (int i = 0; i < MAX_NUM_BI_SAMPLES / num_warps; i++) { + int idx = i * num_warps + warp_id; + if (idx < list_old_size) { + size_t neighbor_id = old_neighbors[idx]; + size_t idx_in_data = neighbor_id * data_dim; + load_vec(s_ov[idx], + data + idx_in_data + step * TILE_COL_WIDTH, + num_load_elems, + TILE_COL_WIDTH, + lane_id); + } + } + __syncthreads(); + + for (int i = 0; i < TILE_COL_WIDTH / WMMA_K; i++) { + wmma::load_matrix_sync(a_frag, s_nv[warp_id_y * WMMA_M] + i * WMMA_K, TILE_COL_WIDTH + APAD); + wmma::load_matrix_sync(b_frag, s_ov[warp_id_x * WMMA_N] + i * WMMA_K, TILE_COL_WIDTH + BPAD); + wmma::mma_sync(c_frag, a_frag, b_frag, c_frag); + __syncthreads(); + } + } + + wmma::store_matrix_sync( + s_distances + warp_id_y * WMMA_M * SKEWED_MAX_NUM_BI_SAMPLES + warp_id_x * WMMA_N, + c_frag, + SKEWED_MAX_NUM_BI_SAMPLES, + wmma::mem_row_major); + __syncthreads(); + + for (int i = threadIdx.x; i < MAX_NUM_BI_SAMPLES * SKEWED_MAX_NUM_BI_SAMPLES; i += blockDim.x) { + if (i % SKEWED_MAX_NUM_BI_SAMPLES < list_old_size && + i / SKEWED_MAX_NUM_BI_SAMPLES < list_new_size) { + if (l2_norms == nullptr) { + s_distances[i] = -s_distances[i]; + } else { + s_distances[i] = l2_norms[old_neighbors[i % SKEWED_MAX_NUM_BI_SAMPLES]] + + l2_norms[new_neighbors[i / SKEWED_MAX_NUM_BI_SAMPLES]] - + 2.0 * s_distances[i]; + } + } else { + s_distances[i] = std::numeric_limits::max(); + } + } + __syncthreads(); + + for (int step = 0; step < ceildiv(MAX_NUM_BI_SAMPLES * 2, num_warps); step++) { + int idx_in_list = step * num_warps + tx / raft::warp_size(); + if (idx_in_list >= list_new_size && idx_in_list < MAX_NUM_BI_SAMPLES) continue; + if (idx_in_list >= MAX_NUM_BI_SAMPLES + list_old_size && idx_in_list < MAX_NUM_BI_SAMPLES * 2) + continue; + ResultItem min_elem{std::numeric_limits::max(), + std::numeric_limits::max()}; + if (idx_in_list < MAX_NUM_BI_SAMPLES) { + auto temp_min_item = + get_min_item(s_list[idx_in_list], idx_in_list, old_neighbors, s_distances); + if (temp_min_item.dist() < min_elem.dist()) { min_elem = temp_min_item; } + } else { + auto temp_min_item = get_min_item( + s_list[idx_in_list], idx_in_list - MAX_NUM_BI_SAMPLES, new_neighbors, s_distances, false); + if (temp_min_item.dist() < min_elem.dist()) { min_elem = temp_min_item; } + } + + if (min_elem.id() < gridDim.x) { + insert_to_global_graph(min_elem, s_list[idx_in_list], graph, dists, graph_width, locks); + } + } +#endif +} + +namespace { +template +int insert_to_ordered_list(InternalID_t* list, + DistData_t* dist_list, + const int width, + const InternalID_t neighb_id, + const DistData_t dist) +{ + if (dist > dist_list[width - 1]) { return width; } + + int idx_insert = width; + bool position_found = false; + for (int i = 0; i < width; i++) { + if (list[i].id() == neighb_id.id()) { return width; } + if (!position_found && dist_list[i] > dist) { + idx_insert = i; + position_found = true; + } + } + if (idx_insert == width) return idx_insert; + + memmove(list + idx_insert + 1, list + idx_insert, sizeof(*list) * (width - idx_insert - 1)); + memmove(dist_list + idx_insert + 1, + dist_list + idx_insert, + sizeof(*dist_list) * (width - idx_insert - 1)); + + list[idx_insert] = neighb_id; + dist_list[idx_insert] = dist; + return idx_insert; +}; + +} // namespace + +template +GnndGraph::GnndGraph(const size_t nrow, + const size_t node_degree, + const size_t internal_node_degree, + const size_t num_samples) + : nrow(nrow), + node_degree(node_degree), + num_samples(num_samples), + bloom_filter(nrow, internal_node_degree / segment_size, 3), + h_dists{raft::make_host_matrix(nrow, node_degree)}, + h_graph_new{nrow * num_samples}, + h_list_sizes_new{nrow}, + h_graph_old{nrow * num_samples}, + h_list_sizes_old{nrow} +{ + // node_degree must be a multiple of segment_size; + assert(node_degree % segment_size == 0); + assert(internal_node_degree % segment_size == 0); + + num_segments = node_degree / segment_size; + // To save the CPU memory, graph should be allocated by external function + h_graph = nullptr; +} + +// This is the only operation on the CPU that cannot be overlapped. +// So it should be as fast as possible. +template +void GnndGraph::sample_graph_new(InternalID_t* new_neighbors, const size_t width) +{ +#pragma omp parallel for + for (size_t i = 0; i < nrow; i++) { + auto list_new = h_graph_new.data() + i * num_samples; + h_list_sizes_new[i].x = 0; + h_list_sizes_new[i].y = 0; + + for (size_t j = 0; j < width; j++) { + auto new_neighb_id = new_neighbors[i * width + j].id(); + if ((size_t)new_neighb_id >= nrow) break; + if (bloom_filter.check(i, new_neighb_id)) { continue; } + bloom_filter.add(i, new_neighb_id); + new_neighbors[i * width + j].mark_old(); + list_new[h_list_sizes_new[i].x++] = new_neighb_id; + if (h_list_sizes_new[i].x == num_samples) break; + } + } +} + +template +void GnndGraph::init_random_graph() +{ + for (size_t seg_idx = 0; seg_idx < static_cast(num_segments); seg_idx++) { + // random sequence (range: 0~nrow) + // segment_x stores neighbors which id % num_segments == x + std::vector rand_seq(nrow / num_segments); + std::iota(rand_seq.begin(), rand_seq.end(), 0); + std::random_shuffle(rand_seq.begin(), rand_seq.end()); + +#pragma omp parallel for + for (size_t i = 0; i < nrow; i++) { + size_t base_idx = i * node_degree + seg_idx * segment_size; + auto h_neighbor_list = h_graph + base_idx; + auto h_dist_list = h_dists.data_handle() + base_idx; + for (size_t j = 0; j < static_cast(segment_size); j++) { + size_t idx = base_idx + j; + Index_t id = rand_seq[idx % rand_seq.size()] * num_segments + seg_idx; + if ((size_t)id == i) { + id = rand_seq[(idx + segment_size) % rand_seq.size()] * num_segments + seg_idx; + } + h_neighbor_list[j].id_with_flag() = id; + h_dist_list[j] = std::numeric_limits::max(); + } + } + } +} + +template +void GnndGraph::sample_graph(bool sample_new) +{ +#pragma omp parallel for + for (size_t i = 0; i < nrow; i++) { + h_list_sizes_old[i].x = 0; + h_list_sizes_old[i].y = 0; + h_list_sizes_new[i].x = 0; + h_list_sizes_new[i].y = 0; + + auto list = h_graph + i * node_degree; + auto list_old = h_graph_old.data() + i * num_samples; + auto list_new = h_graph_new.data() + i * num_samples; + for (int j = 0; j < segment_size; j++) { + for (int k = 0; k < num_segments; k++) { + auto neighbor = list[k * segment_size + j]; + if ((size_t)neighbor.id() >= nrow) continue; + if (!neighbor.is_new()) { + if (h_list_sizes_old[i].x < num_samples) { + list_old[h_list_sizes_old[i].x++] = neighbor.id(); + } + } else if (sample_new) { + if (h_list_sizes_new[i].x < num_samples) { + list[k * segment_size + j].mark_old(); + list_new[h_list_sizes_new[i].x++] = neighbor.id(); + } + } + if (h_list_sizes_old[i].x == num_samples && h_list_sizes_new[i].x == num_samples) { break; } + } + if (h_list_sizes_old[i].x == num_samples && h_list_sizes_new[i].x == num_samples) { break; } + } + } +} + +template +void GnndGraph::update_graph(const InternalID_t* new_neighbors, + const DistData_t* new_dists, + const size_t width, + std::atomic& update_counter) +{ +#pragma omp parallel for + for (size_t i = 0; i < nrow; i++) { + for (size_t j = 0; j < width; j++) { + auto new_neighb_id = new_neighbors[i * width + j]; + auto new_dist = new_dists[i * width + j]; + if (new_dist == std::numeric_limits::max()) break; + if ((size_t)new_neighb_id.id() == i) continue; + int seg_idx = new_neighb_id.id() % num_segments; + auto list = h_graph + i * node_degree + seg_idx * segment_size; + auto dist_list = h_dists.data_handle() + i * node_degree + seg_idx * segment_size; + int insert_pos = + insert_to_ordered_list(list, dist_list, segment_size, new_neighb_id, new_dist); + if (i % counter_interval == 0 && insert_pos != segment_size) { update_counter++; } + } + } +} + +template +void GnndGraph::sort_lists() +{ +#pragma omp parallel for + for (size_t i = 0; i < nrow; i++) { + std::vector> new_list; + for (size_t j = 0; j < node_degree; j++) { + new_list.emplace_back(h_dists.data_handle()[i * node_degree + j], + h_graph[i * node_degree + j].id()); + } + std::sort(new_list.begin(), new_list.end()); + for (size_t j = 0; j < node_degree; j++) { + h_graph[i * node_degree + j].id_with_flag() = new_list[j].second; + h_dists.data_handle()[i * node_degree + j] = new_list[j].first; + } + } +} + +template +void GnndGraph::clear() +{ + bloom_filter.clear(); +} + +template +GnndGraph::~GnndGraph() +{ + assert(h_graph == nullptr); +} + +template +GNND::GNND(raft::resources const& res, const BuildConfig& build_config) + : res(res), + build_config_(build_config), + graph_(build_config.max_dataset_size, + align32::roundUp(build_config.node_degree), + align32::roundUp(build_config.internal_node_degree ? build_config.internal_node_degree + : build_config.node_degree), + NUM_SAMPLES), + nrow_(build_config.max_dataset_size), + ndim_(build_config.dataset_dim), + d_data_{raft::make_device_matrix<__half, Index_t, raft::row_major>( + res, nrow_, build_config.dataset_dim)}, + l2_norms_{raft::make_device_vector(res, nrow_)}, + graph_buffer_{ + raft::make_device_matrix(res, nrow_, DEGREE_ON_DEVICE)}, + dists_buffer_{ + raft::make_device_matrix(res, nrow_, DEGREE_ON_DEVICE)}, + graph_host_buffer_{static_cast(nrow_ * DEGREE_ON_DEVICE)}, + dists_host_buffer_{static_cast(nrow_ * DEGREE_ON_DEVICE)}, + d_locks_{raft::make_device_vector(res, nrow_)}, + h_rev_graph_new_{static_cast(nrow_ * NUM_SAMPLES)}, + h_graph_old_{static_cast(nrow_ * NUM_SAMPLES)}, + h_rev_graph_old_{static_cast(nrow_ * NUM_SAMPLES)}, + d_list_sizes_new_{raft::make_device_vector(res, nrow_)}, + d_list_sizes_old_{raft::make_device_vector(res, nrow_)} +{ + static_assert(NUM_SAMPLES <= 32); + + thrust::fill(thrust::device, + dists_buffer_.data_handle(), + dists_buffer_.data_handle() + dists_buffer_.size(), + std::numeric_limits::max()); + thrust::fill(thrust::device, + reinterpret_cast(graph_buffer_.data_handle()), + reinterpret_cast(graph_buffer_.data_handle()) + graph_buffer_.size(), + std::numeric_limits::max()); + thrust::fill(thrust::device, d_locks_.data_handle(), d_locks_.data_handle() + d_locks_.size(), 0); +}; + +template +void GNND::add_reverse_edges(Index_t* graph_ptr, + Index_t* h_rev_graph_ptr, + Index_t* d_rev_graph_ptr, + int2* list_sizes, + cudaStream_t stream) +{ + add_rev_edges_kernel<<>>( + graph_ptr, d_rev_graph_ptr, NUM_SAMPLES, list_sizes); + raft::copy( + h_rev_graph_ptr, d_rev_graph_ptr, nrow_ * NUM_SAMPLES, raft::resource::get_cuda_stream(res)); +} + +template +void GNND::local_join(cudaStream_t stream) +{ + thrust::fill(thrust::device.on(stream), + dists_buffer_.data_handle(), + dists_buffer_.data_handle() + dists_buffer_.size(), + std::numeric_limits::max()); + local_join_kernel<<>>( + thrust::raw_pointer_cast(graph_.h_graph_new.data()), + thrust::raw_pointer_cast(h_rev_graph_new_.data()), + d_list_sizes_new_.data_handle(), + thrust::raw_pointer_cast(h_graph_old_.data()), + thrust::raw_pointer_cast(h_rev_graph_old_.data()), + d_list_sizes_old_.data_handle(), + NUM_SAMPLES, + d_data_.data_handle(), + ndim_, + graph_buffer_.data_handle(), + dists_buffer_.data_handle(), + DEGREE_ON_DEVICE, + d_locks_.data_handle(), + l2_norms_.data_handle()); +} + +template +void GNND::build(Data_t* data, const Index_t nrow, Index_t* output_graph) +{ + using input_t = typename std::remove_const::type; + + cudaStream_t stream = raft::resource::get_cuda_stream(res); + nrow_ = nrow; + graph_.h_graph = (InternalID_t*)output_graph; + + cudaPointerAttributes data_ptr_attr; + RAFT_CUDA_TRY(cudaPointerGetAttributes(&data_ptr_attr, data)); + size_t batch_size = (data_ptr_attr.devicePointer == nullptr) ? 100000 : nrow_; + + raft::spatial::knn::detail::utils::batch_load_iterator vec_batches{ + data, static_cast(nrow_), build_config_.dataset_dim, batch_size, stream}; + for (auto const& batch : vec_batches) { + preprocess_data_kernel<<< + batch.size(), + raft::warp_size(), + sizeof(Data_t) * ceildiv(build_config_.dataset_dim, static_cast(raft::warp_size())) * + raft::warp_size(), + stream>>>(batch.data(), + d_data_.data_handle(), + build_config_.dataset_dim, + l2_norms_.data_handle(), + batch.offset()); + } + + thrust::fill(thrust::device.on(stream), + (Index_t*)graph_buffer_.data_handle(), + (Index_t*)graph_buffer_.data_handle() + graph_buffer_.size(), + std::numeric_limits::max()); + + graph_.clear(); + graph_.init_random_graph(); + graph_.sample_graph(true); + + auto update_and_sample = [&](bool update_graph) { + if (update_graph) { + update_counter_ = 0; + graph_.update_graph(thrust::raw_pointer_cast(graph_host_buffer_.data()), + thrust::raw_pointer_cast(dists_host_buffer_.data()), + DEGREE_ON_DEVICE, + update_counter_); + if (update_counter_ < build_config_.termination_threshold * nrow_ * + build_config_.dataset_dim / counter_interval) { + update_counter_ = -1; + } + } + graph_.sample_graph(false); + }; + + for (size_t it = 0; it < build_config_.max_iterations; it++) { + raft::copy(d_list_sizes_new_.data_handle(), + thrust::raw_pointer_cast(graph_.h_list_sizes_new.data()), + nrow_, + raft::resource::get_cuda_stream(res)); + raft::copy(thrust::raw_pointer_cast(h_graph_old_.data()), + thrust::raw_pointer_cast(graph_.h_graph_old.data()), + nrow_ * NUM_SAMPLES, + raft::resource::get_cuda_stream(res)); + raft::copy(d_list_sizes_old_.data_handle(), + thrust::raw_pointer_cast(graph_.h_list_sizes_old.data()), + nrow_, + raft::resource::get_cuda_stream(res)); + raft::resource::sync_stream(res); + + std::thread update_and_sample_thread(update_and_sample, it); + + std::cout << "# GNND iteraton: " << it + 1 << "/" << build_config_.max_iterations << "\r"; + std::fflush(stdout); + + // Reuse dists_buffer_ to save GPU memory. graph_buffer_ cannot be reused, because it + // contains some information for local_join. + static_assert(DEGREE_ON_DEVICE * sizeof(*(dists_buffer_.data_handle())) >= + NUM_SAMPLES * sizeof(*(graph_buffer_.data_handle()))); + add_reverse_edges(thrust::raw_pointer_cast(graph_.h_graph_new.data()), + thrust::raw_pointer_cast(h_rev_graph_new_.data()), + (Index_t*)dists_buffer_.data_handle(), + d_list_sizes_new_.data_handle(), + stream); + add_reverse_edges(thrust::raw_pointer_cast(h_graph_old_.data()), + thrust::raw_pointer_cast(h_rev_graph_old_.data()), + (Index_t*)dists_buffer_.data_handle(), + d_list_sizes_old_.data_handle(), + stream); + + // Tensor operations from `mma.h` are guarded with archicteture + // __CUDA_ARCH__ >= 700. Since RAFT supports compilation for ARCH 600, + // we need to ensure that `local_join_kernel` (which uses tensor) operations + // is not only not compiled, but also a runtime error is presented to the user + auto kernel = preprocess_data_kernel; + void* kernel_ptr = reinterpret_cast(kernel); + auto runtime_arch = raft::util::arch::kernel_virtual_arch(kernel_ptr); + auto wmma_range = + raft::util::arch::SM_range(raft::util::arch::SM_70(), raft::util::arch::SM_future()); + + if (wmma_range.contains(runtime_arch)) { + local_join(stream); + } else { + THROW("NN_DESCENT cannot be run for __CUDA_ARCH__ < 700"); + } + + update_and_sample_thread.join(); + + if (update_counter_ == -1) { break; } + raft::copy(thrust::raw_pointer_cast(graph_host_buffer_.data()), + graph_buffer_.data_handle(), + nrow_ * DEGREE_ON_DEVICE, + raft::resource::get_cuda_stream(res)); + raft::resource::sync_stream(res); + raft::copy(thrust::raw_pointer_cast(dists_host_buffer_.data()), + dists_buffer_.data_handle(), + nrow_ * DEGREE_ON_DEVICE, + raft::resource::get_cuda_stream(res)); + + graph_.sample_graph_new(thrust::raw_pointer_cast(graph_host_buffer_.data()), DEGREE_ON_DEVICE); + } + + graph_.update_graph(thrust::raw_pointer_cast(graph_host_buffer_.data()), + thrust::raw_pointer_cast(dists_host_buffer_.data()), + DEGREE_ON_DEVICE, + update_counter_); + raft::resource::sync_stream(res); + graph_.sort_lists(); + + // Reuse graph_.h_dists as the buffer for shrink the lists in graph + static_assert(sizeof(decltype(*(graph_.h_dists.data_handle()))) >= sizeof(Index_t)); + Index_t* graph_shrink_buffer = (Index_t*)graph_.h_dists.data_handle(); + +#pragma omp parallel for + for (size_t i = 0; i < (size_t)nrow_; i++) { + for (size_t j = 0; j < build_config_.node_degree; j++) { + size_t idx = i * graph_.node_degree + j; + Index_t id = graph_.h_graph[idx].id(); + if (id < nrow_) { + graph_shrink_buffer[i * build_config_.node_degree + j] = id; + } else { + graph_shrink_buffer[i * build_config_.node_degree + j] = + raft::neighbors::cagra::detail::device::xorshift64(idx) % nrow_; + } + } + } + graph_.h_graph = nullptr; + +#pragma omp parallel for + for (size_t i = 0; i < (size_t)nrow_; i++) { + for (size_t j = 0; j < build_config_.node_degree; j++) { + output_graph[i * build_config_.node_degree + j] = + graph_shrink_buffer[i * build_config_.node_degree + j]; + } + } +} + +template , memory_type::host>> +void build(raft::resources const& res, + const index_params& params, + mdspan, row_major, Accessor> dataset, + index& idx) +{ + RAFT_EXPECTS(dataset.extent(0) < std::numeric_limits::max() - 1, + "The dataset size for GNND should be less than %d", + std::numeric_limits::max() - 1); + size_t intermediate_degree = params.intermediate_graph_degree; + size_t graph_degree = params.graph_degree; + + if (intermediate_degree >= static_cast(dataset.extent(0))) { + RAFT_LOG_WARN( + "Intermediate graph degree cannot be larger than dataset size, reducing it to %lu", + dataset.extent(0)); + intermediate_degree = dataset.extent(0) - 1; + } + if (intermediate_degree < graph_degree) { + RAFT_LOG_WARN( + "Graph degree (%lu) cannot be larger than intermediate graph degree (%lu), reducing " + "graph_degree.", + graph_degree, + intermediate_degree); + graph_degree = intermediate_degree; + } + + // The elements in each knn-list are partitioned into different buckets, and we need more buckets + // to mitigate bucket collisions. `intermediate_degree` is OK to larger than + // extended_graph_degree. + size_t extended_graph_degree = + align32::roundUp(static_cast(graph_degree * (graph_degree <= 32 ? 1.0 : 1.3))); + size_t extended_intermediate_degree = align32::roundUp( + static_cast(intermediate_degree * (intermediate_degree <= 32 ? 1.0 : 1.3))); + + auto int_graph = raft::make_host_matrix( + dataset.extent(0), static_cast(extended_graph_degree)); + + BuildConfig build_config{.max_dataset_size = static_cast(dataset.extent(0)), + .dataset_dim = static_cast(dataset.extent(1)), + .node_degree = extended_graph_degree, + .internal_node_degree = extended_intermediate_degree, + .max_iterations = params.max_iterations, + .termination_threshold = params.termination_threshold}; + + GNND nnd(res, build_config); + nnd.build(dataset.data_handle(), dataset.extent(0), int_graph.data_handle()); + +#pragma omp parallel for + for (size_t i = 0; i < static_cast(dataset.extent(0)); i++) { + for (size_t j = 0; j < graph_degree; j++) { + auto graph = idx.graph().data_handle(); + graph[i * graph_degree + j] = int_graph.data_handle()[i * extended_graph_degree + j]; + } + } +} + +template , memory_type::host>> +index build(raft::resources const& res, + const index_params& params, + mdspan, row_major, Accessor> dataset) +{ + size_t intermediate_degree = params.intermediate_graph_degree; + size_t graph_degree = params.graph_degree; + + if (intermediate_degree < graph_degree) { + RAFT_LOG_WARN( + "Graph degree (%lu) cannot be larger than intermediate graph degree (%lu), reducing " + "graph_degree.", + graph_degree, + intermediate_degree); + graph_degree = intermediate_degree; + } + + index idx{res, dataset.extent(0), static_cast(graph_degree)}; + + build(res, params, dataset, idx); + + return idx; +} + +} // namespace raft::neighbors::experimental::nn_descent::detail diff --git a/cpp/include/raft/neighbors/nn_descent.cuh b/cpp/include/raft/neighbors/nn_descent.cuh new file mode 100644 index 0000000000..ceb5ae5643 --- /dev/null +++ b/cpp/include/raft/neighbors/nn_descent.cuh @@ -0,0 +1,181 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "detail/nn_descent.cuh" + +#include +#include + +namespace raft::neighbors::experimental::nn_descent { + +/** + * @defgroup nn-descent CUDA gradient descent nearest neighbor + * @{ + */ + +/** + * @brief Build nn-descent Index with dataset in device memory + * + * The following distance metrics are supported: + * - L2 + * + * Usage example: + * @code{.cpp} + * using namespace raft::neighbors::experimental; + * // use default index parameters + * nn_descent::index_params index_params; + * // create and fill the index from a [N, D] raft::device_matrix_view dataset + * auto index = cagra::build(res, index_params, dataset); + * // index.graph() provides a raft::host_matrix_view of an + * // all-neighbors knn graph of dimensions [N, k] of the input + * // dataset + * @endcode + * + * @tparam T data-type of the input dataset + * @tparam IdxT data-type for the output index + * @param[in] res raft::resources is an object mangaging resources + * @param[in] params an instance of nn_descent::index_params that are parameters + * to run the nn-descent algorithm + * @param[in] dataset raft::device_matrix_view input dataset expected to be located + * in device memory + * @return index index containing all-neighbors knn graph in host memory + */ +template +index build(raft::resources const& res, + index_params const& params, + raft::device_matrix_view dataset) +{ + return detail::build(res, params, dataset); +} + +/** + * @brief Build nn-descent Index with dataset in device memory + * + * The following distance metrics are supported: + * - L2 + * + * Usage example: + * @code{.cpp} + * using namespace raft::neighbors::experimental; + * // use default index parameters + * nn_descent::index_params index_params; + * // create and fill the index from a [N, D] raft::device_matrix_view dataset + * auto knn_graph = raft::make_host_matrix(N, D); + * auto index = nn_descent::index{res, knn_graph.view()}; + * cagra::build(res, index_params, dataset, index); + * // index.graph() provides a raft::host_matrix_view of an + * // all-neighbors knn graph of dimensions [N, k] of the input + * // dataset + * @endcode + * + * @tparam T data-type of the input dataset + * @tparam IdxT data-type for the output index + * @param res raft::resources is an object mangaging resources + * @param[in] params an instance of nn_descent::index_params that are parameters + * to run the nn-descent algorithm + * @param[in] dataset raft::device_matrix_view input dataset expected to be located + * in device memory + * @param[out] idx raft::neighbors::experimental::nn_descentindex containing all-neighbors knn graph + * in host memory + */ +template +void build(raft::resources const& res, + index_params const& params, + raft::device_matrix_view dataset, + index& idx) +{ + detail::build(res, params, dataset, idx); +} + +/** + * @brief Build nn-descent Index with dataset in host memory + * + * The following distance metrics are supported: + * - L2 + * + * Usage example: + * @code{.cpp} + * using namespace raft::neighbors::experimental; + * // use default index parameters + * nn_descent::index_params index_params; + * // create and fill the index from a [N, D] raft::host_matrix_view dataset + * auto index = cagra::build(res, index_params, dataset); + * // index.graph() provides a raft::host_matrix_view of an + * // all-neighbors knn graph of dimensions [N, k] of the input + * // dataset + * @endcode + * + * @tparam T data-type of the input dataset + * @tparam IdxT data-type for the output index + * @param res raft::resources is an object mangaging resources + * @param[in] params an instance of nn_descent::index_params that are parameters + * to run the nn-descent algorithm + * @param[in] dataset raft::host_matrix_view input dataset expected to be located + * in host memory + * @return index index containing all-neighbors knn graph in host memory + */ +template +index build(raft::resources const& res, + index_params const& params, + raft::host_matrix_view dataset) +{ + return detail::build(res, params, dataset); +} + +/** + * @brief Build nn-descent Index with dataset in host memory + * + * The following distance metrics are supported: + * - L2 + * + * Usage example: + * @code{.cpp} + * using namespace raft::neighbors::experimental; + * // use default index parameters + * nn_descent::index_params index_params; + * // create and fill the index from a [N, D] raft::host_matrix_view dataset + * auto knn_graph = raft::make_host_matrix(N, D); + * auto index = nn_descent::index{res, knn_graph.view()}; + * cagra::build(res, index_params, dataset, index); + * // index.graph() provides a raft::host_matrix_view of an + * // all-neighbors knn graph of dimensions [N, k] of the input + * // dataset + * @endcode + * + * @tparam T data-type of the input dataset + * @tparam IdxT data-type for the output index + * @param[in] res raft::resources is an object mangaging resources + * @param[in] params an instance of nn_descent::index_params that are parameters + * to run the nn-descent algorithm + * @param[in] dataset raft::host_matrix_view input dataset expected to be located + * in host memory + * @param[out] idx raft::neighbors::experimental::nn_descentindex containing all-neighbors knn graph + * in host memory + */ +template +void build(raft::resources const& res, + index_params const& params, + raft::host_matrix_view dataset, + index& idx) +{ + detail::build(res, params, dataset, idx); +} + +/** @} */ // end group nn-descent + +} // namespace raft::neighbors::experimental::nn_descent diff --git a/cpp/include/raft/neighbors/nn_descent_types.hpp b/cpp/include/raft/neighbors/nn_descent_types.hpp new file mode 100644 index 0000000000..64e464c618 --- /dev/null +++ b/cpp/include/raft/neighbors/nn_descent_types.hpp @@ -0,0 +1,147 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "ann_types.hpp" +#include + +#include +#include +#include +#include +#include + +namespace raft::neighbors::experimental::nn_descent { +/** + * @ingroup nn_descent + * @{ + */ + +/** + * @brief Parameters used to build an nn-descent index + * + * `graph_degree`: For an input dataset of dimensions (N, D), + * determines the final dimensions of the all-neighbors knn graph + * which turns out to be of dimensions (N, graph_degree) + * `intermediate_graph_degree`: Internally, nn-descent builds an + * all-neighbors knn graph of dimensions (N, intermediate_graph_degree) + * before selecting the final `graph_degree` neighbors. It's recommended + * that `intermediate_graph_degree` >= 1.5 * graph_degree + * `max_iterations`: The number of iterations that nn-descent will refine + * the graph for. More iterations produce a better quality graph at cost of performance + * `termination_threshold`: The delta at which nn-descent will terminate its iterations + * + */ +struct index_params : ann::index_params { + size_t graph_degree = 64; // Degree of output graph. + size_t intermediate_graph_degree = 128; // Degree of input graph for pruning. + size_t max_iterations = 20; // Number of nn-descent iterations. + float termination_threshold = 0.0001; // Termination threshold of nn-descent. +}; + +/** + * @brief nn-descent Build an nn-descent index + * The index contains an all-neighbors graph of the input dataset + * stored in host memory of dimensions (n_rows, n_cols) + * + * @tparam IdxT dtype to be used for constructing knn-graph + */ +template +struct index : ann::index { + public: + /** + * @brief Construct a new index object + * + * This constructor creates an nn-descent index which is a knn-graph in host memory. + * The type of the knn-graph is a dense raft::host_matrix and dimensions are + * (n_rows, n_cols). + * + * @param res raft::resources is an object mangaging resources + * @param n_rows number of rows in knn-graph + * @param n_cols number of cols in knn-graph + */ + index(raft::resources const& res, int64_t n_rows, int64_t n_cols) + : ann::index(), + res_{res}, + metric_{raft::distance::DistanceType::L2Expanded}, + graph_{raft::make_host_matrix(n_rows, n_cols)}, + graph_view_{graph_.view()} + { + } + + /** + * @brief Construct a new index object + * + * This constructor creates an nn-descent index using a user allocated host memory knn-graph. + * The type of the knn-graph is a dense raft::host_matrix and dimensions are + * (n_rows, n_cols). + * + * @param res raft::resources is an object mangaging resources + * @param graph_view raft::host_matrix_view for storing knn-graph + */ + index(raft::resources const& res, + raft::host_matrix_view graph_view) + : ann::index(), + res_{res}, + metric_{raft::distance::DistanceType::L2Expanded}, + graph_{raft::make_host_matrix(0, 0)}, + graph_view_{graph_view} + { + } + + /** Distance metric used for clustering. */ + [[nodiscard]] constexpr inline auto metric() const noexcept -> raft::distance::DistanceType + { + return metric_; + } + + // /** Total length of the index (number of vectors). */ + [[nodiscard]] constexpr inline auto size() const noexcept -> IdxT + { + return graph_view_.extent(0); + } + + /** Graph degree */ + [[nodiscard]] constexpr inline auto graph_degree() const noexcept -> uint32_t + { + return graph_view_.extent(1); + } + + /** neighborhood graph [size, graph-degree] */ + [[nodiscard]] inline auto graph() noexcept -> host_matrix_view + { + return graph_view_; + } + + // Don't allow copying the index for performance reasons (try avoiding copying data) + index(const index&) = delete; + index(index&&) = default; + auto operator=(const index&) -> index& = delete; + auto operator=(index&&) -> index& = default; + ~index() = default; + + private: + raft::resources const& res_; + raft::distance::DistanceType metric_; + raft::host_matrix graph_; // graph to return for non-int IdxT + raft::host_matrix_view + graph_view_; // view of graph for user provided matrix +}; + +/** @} */ + +} // namespace raft::neighbors::experimental::nn_descent diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index db4c59c807..71de21e64a 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -379,6 +379,21 @@ if(BUILD_TESTS) 100 ) + ConfigureTest( + NAME + NEIGHBORS_ANN_NN_DESCENT_TEST + PATH + test/neighbors/ann_nn_descent/test_float_uint32_t.cu + test/neighbors/ann_nn_descent/test_int8_t_uint32_t.cu + test/neighbors/ann_nn_descent/test_uint8_t_uint32_t.cu + LIB + EXPLICIT_INSTANTIATE_ONLY + GPUS + 1 + PERCENT + 100 + ) + ConfigureTest( NAME NEIGHBORS_SELECTION_TEST PATH test/neighbors/selection.cu LIB EXPLICIT_INSTANTIATE_ONLY GPUS 1 PERCENT 50 diff --git a/cpp/test/neighbors/ann_cagra.cuh b/cpp/test/neighbors/ann_cagra.cuh index 90f271e3ee..343afd04ec 100644 --- a/cpp/test/neighbors/ann_cagra.cuh +++ b/cpp/test/neighbors/ann_cagra.cuh @@ -147,6 +147,7 @@ struct AnnCagraInputs { int n_rows; int dim; int k; + graph_build_algo build_algo; search_algo algo; int max_queries; int team_size; @@ -161,12 +162,13 @@ struct AnnCagraInputs { inline ::std::ostream& operator<<(::std::ostream& os, const AnnCagraInputs& p) { - std::vector algo = {"single-cta", "multi_cta", "multi_kernel", "auto"}; + std::vector algo = {"single-cta", "multi_cta", "multi_kernel", "auto"}; + std::vector build_algo = {"IVF_PQ", "NN_DESCENT"}; os << "{n_queries=" << p.n_queries << ", dataset shape=" << p.n_rows << "x" << p.dim << ", k=" << p.k << ", " << algo.at((int)p.algo) << ", max_queries=" << p.max_queries << ", itopk_size=" << p.itopk_size << ", search_width=" << p.search_width - << ", metric=" << static_cast(p.metric) << (p.host_dataset ? ", host" : ", device") << '}' - << std::endl; + << ", metric=" << static_cast(p.metric) << (p.host_dataset ? ", host" : ", device") + << ", build_algo=" << build_algo.at((int)p.build_algo) << '}' << std::endl; return os; } @@ -216,6 +218,7 @@ class AnnCagraTest : public ::testing::TestWithParam { cagra::index_params index_params; index_params.metric = ps.metric; // Note: currently ony the cagra::index_params metric is // not used for knn_graph building. + index_params.build_algo = ps.build_algo; cagra::search_params search_params; search_params.algo = ps.algo; search_params.max_queries = ps.max_queries; @@ -340,11 +343,25 @@ class AnnCagraSortTest : public ::testing::TestWithParam { auto knn_graph = raft::make_host_matrix(ps.n_rows, index_params.intermediate_graph_degree); - if (ps.host_dataset) { - cagra::build_knn_graph(handle_, database_host_view, knn_graph.view()); + if (ps.build_algo == graph_build_algo::IVF_PQ) { + if (ps.host_dataset) { + cagra::build_knn_graph(handle_, database_host_view, knn_graph.view()); + } else { + cagra::build_knn_graph(handle_, database_view, knn_graph.view()); + } } else { - cagra::build_knn_graph(handle_, database_view, knn_graph.view()); - }; + auto nn_descent_idx_params = experimental::nn_descent::index_params{}; + nn_descent_idx_params.graph_degree = index_params.intermediate_graph_degree; + nn_descent_idx_params.intermediate_graph_degree = index_params.intermediate_graph_degree; + + if (ps.host_dataset) { + cagra::build_knn_graph( + handle_, database_host_view, knn_graph.view(), nn_descent_idx_params); + } else { + cagra::build_knn_graph( + handle_, database_host_view, knn_graph.view(), nn_descent_idx_params); + } + } handle_.sync_stream(); ASSERT_TRUE(CheckOrder(knn_graph.view(), database_host.view())); @@ -546,6 +563,7 @@ inline std::vector generate_inputs() {1000}, {1, 8, 17}, {1, 16}, // k + {graph_build_algo::IVF_PQ, graph_build_algo::NN_DESCENT}, {search_algo::SINGLE_CTA, search_algo::MULTI_CTA, search_algo::MULTI_KERNEL}, {0, 1, 10, 100}, // query size {0}, @@ -561,6 +579,7 @@ inline std::vector generate_inputs() {1000}, {1, 3, 5, 7, 8, 17, 64, 128, 137, 192, 256, 512, 619, 1024}, // dim {16}, // k + {graph_build_algo::IVF_PQ, graph_build_algo::NN_DESCENT}, {search_algo::AUTO}, {10}, {0}, @@ -571,68 +590,55 @@ inline std::vector generate_inputs() {true}, {0.995}); inputs.insert(inputs.end(), inputs2.begin(), inputs2.end()); - inputs2 = - raft::util::itertools::product({100}, - {1000}, - {64}, - {16}, - {search_algo::AUTO}, - {10}, - {0, 4, 8, 16, 32}, // team_size - {64}, - {1}, - {raft::distance::DistanceType::L2Expanded}, - {false}, - {false}, - {0.995}); - inputs.insert(inputs.end(), inputs2.begin(), inputs2.end()); - - inputs2 = - raft::util::itertools::product({100}, - {1000}, - {64}, - {16}, - {search_algo::AUTO}, - {10}, - {0}, // team_size - {32, 64, 128, 256, 512, 768}, - {1}, - {raft::distance::DistanceType::L2Expanded}, - {false}, - {true}, - {0.995}); + inputs2 = raft::util::itertools::product( + {100}, + {1000}, + {64}, + {16}, + {graph_build_algo::IVF_PQ, graph_build_algo::NN_DESCENT}, + {search_algo::AUTO}, + {10}, + {0, 4, 8, 16, 32}, // team_size + {64}, + {1}, + {raft::distance::DistanceType::L2Expanded}, + {false}, + {false}, + {0.995}); inputs.insert(inputs.end(), inputs2.begin(), inputs2.end()); - inputs2 = - raft::util::itertools::product({100}, - {10000, 20000}, - {32}, - {10}, - {search_algo::AUTO}, - {10}, - {0}, // team_size - {64}, - {1}, - {raft::distance::DistanceType::L2Expanded}, - {false, true}, - {false}, - {0.995}); + inputs2 = raft::util::itertools::product( + {100}, + {1000}, + {64}, + {16}, + {graph_build_algo::IVF_PQ, graph_build_algo::NN_DESCENT}, + {search_algo::AUTO}, + {10}, + {0}, // team_size + {32, 64, 128, 256, 512, 768}, + {1}, + {raft::distance::DistanceType::L2Expanded}, + {false}, + {true}, + {0.995}); inputs.insert(inputs.end(), inputs2.begin(), inputs2.end()); - inputs2 = - raft::util::itertools::product({100}, - {10000, 20000}, - {32}, - {10}, - {search_algo::AUTO}, - {10}, - {0}, // team_size - {64}, - {1}, - {raft::distance::DistanceType::L2Expanded}, - {false, true}, - {true}, - {0.995}); + inputs2 = raft::util::itertools::product( + {100}, + {10000, 20000}, + {32}, + {10}, + {graph_build_algo::IVF_PQ, graph_build_algo::NN_DESCENT}, + {search_algo::AUTO}, + {10}, + {0}, // team_size + {64}, + {1}, + {raft::distance::DistanceType::L2Expanded}, + {false, true}, + {false}, + {0.995}); inputs.insert(inputs.end(), inputs2.begin(), inputs2.end()); return inputs; diff --git a/cpp/test/neighbors/ann_nn_descent.cuh b/cpp/test/neighbors/ann_nn_descent.cuh new file mode 100644 index 0000000000..948323cf6e --- /dev/null +++ b/cpp/test/neighbors/ann_nn_descent.cuh @@ -0,0 +1,154 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "../test_utils.cuh" +#include "ann_utils.cuh" + +#include + +#include +#include +#include + +#include + +#include +#include +#include +#include + +namespace raft::neighbors::experimental::nn_descent { + +struct AnnNNDescentInputs { + int n_rows; + int dim; + int graph_degree; + raft::distance::DistanceType metric; + bool host_dataset; + double min_recall; +}; + +inline ::std::ostream& operator<<(::std::ostream& os, const AnnNNDescentInputs& p) +{ + os << "dataset shape=" << p.n_rows << "x" << p.dim << ", graph_degree=" << p.graph_degree + << ", metric=" << static_cast(p.metric) << (p.host_dataset ? ", host" : ", device") + << std::endl; + return os; +} + +template +class AnnNNDescentTest : public ::testing::TestWithParam { + public: + AnnNNDescentTest() + : stream_(resource::get_cuda_stream(handle_)), + ps(::testing::TestWithParam::GetParam()), + database(0, stream_) + { + } + + protected: + void testNNDescent() + { + size_t queries_size = ps.n_rows * ps.graph_degree; + std::vector indices_NNDescent(queries_size); + std::vector indices_naive(queries_size); + + { + rmm::device_uvector distances_naive_dev(queries_size, stream_); + rmm::device_uvector indices_naive_dev(queries_size, stream_); + naive_knn(handle_, + distances_naive_dev.data(), + indices_naive_dev.data(), + database.data(), + database.data(), + ps.n_rows, + ps.n_rows, + ps.dim, + ps.graph_degree, + ps.metric); + update_host(indices_naive.data(), indices_naive_dev.data(), queries_size, stream_); + resource::sync_stream(handle_); + } + + { + { + nn_descent::index_params index_params; + index_params.metric = ps.metric; + index_params.graph_degree = ps.graph_degree; + index_params.intermediate_graph_degree = 2 * ps.graph_degree; + + auto database_view = raft::make_device_matrix_view( + (const DataT*)database.data(), ps.n_rows, ps.dim); + + { + if (ps.host_dataset) { + auto database_host = raft::make_host_matrix(ps.n_rows, ps.dim); + raft::copy(database_host.data_handle(), database.data(), database.size(), stream_); + auto database_host_view = raft::make_host_matrix_view( + (const DataT*)database_host.data_handle(), ps.n_rows, ps.dim); + auto index = nn_descent::build(handle_, index_params, database_host_view); + update_host( + indices_NNDescent.data(), index.graph().data_handle(), queries_size, stream_); + } else { + auto index = nn_descent::build(handle_, index_params, database_view); + update_host( + indices_NNDescent.data(), index.graph().data_handle(), queries_size, stream_); + }; + } + resource::sync_stream(handle_); + } + + double min_recall = ps.min_recall; + EXPECT_TRUE(eval_recall( + indices_naive, indices_NNDescent, ps.n_rows, ps.graph_degree, 0.001, min_recall)); + } + } + + void SetUp() override + { + database.resize(((size_t)ps.n_rows) * ps.dim, stream_); + raft::random::Rng r(1234ULL); + if constexpr (std::is_same{}) { + r.normal(database.data(), ps.n_rows * ps.dim, DataT(0.1), DataT(2.0), stream_); + } else { + r.uniformInt(database.data(), ps.n_rows * ps.dim, DataT(1), DataT(20), stream_); + } + resource::sync_stream(handle_); + } + + void TearDown() override + { + resource::sync_stream(handle_); + database.resize(0, stream_); + } + + private: + raft::resources handle_; + rmm::cuda_stream_view stream_; + AnnNNDescentInputs ps; + rmm::device_uvector database; +}; + +const std::vector inputs = raft::util::itertools::product( + {1000, 2000}, // n_rows + {3, 5, 7, 8, 17, 64, 128, 137, 192, 256, 512, 619, 1024}, // dim + {32, 64}, // graph_degree + {raft::distance::DistanceType::L2Expanded}, + {false, true}, + {0.92}); + +} // namespace raft::neighbors::experimental::nn_descent diff --git a/cpp/test/neighbors/ann_nn_descent/test_float_uint32_t.cu b/cpp/test/neighbors/ann_nn_descent/test_float_uint32_t.cu new file mode 100644 index 0000000000..13bff6ac90 --- /dev/null +++ b/cpp/test/neighbors/ann_nn_descent/test_float_uint32_t.cu @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include "../ann_nn_descent.cuh" + +namespace raft::neighbors::experimental::nn_descent { + +typedef AnnNNDescentTest AnnNNDescentTestF_U32; +TEST_P(AnnNNDescentTestF_U32, AnnCagra) { this->testNNDescent(); } + +INSTANTIATE_TEST_CASE_P(AnnNNDescentTest, AnnNNDescentTestF_U32, ::testing::ValuesIn(inputs)); + +} // namespace raft::neighbors::experimental::nn_descent diff --git a/cpp/test/neighbors/ann_nn_descent/test_int8_t_uint32_t.cu b/cpp/test/neighbors/ann_nn_descent/test_int8_t_uint32_t.cu new file mode 100644 index 0000000000..5895303e09 --- /dev/null +++ b/cpp/test/neighbors/ann_nn_descent/test_int8_t_uint32_t.cu @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include "../ann_nn_descent.cuh" + +namespace raft::neighbors::experimental::nn_descent { + +typedef AnnNNDescentTest AnnNNDescentTestI8_U32; +TEST_P(AnnNNDescentTestI8_U32, AnnCagra) { this->testNNDescent(); } + +INSTANTIATE_TEST_CASE_P(AnnNNDescentTest, AnnNNDescentTestI8_U32, ::testing::ValuesIn(inputs)); + +} // namespace raft::neighbors::experimental::nn_descent diff --git a/cpp/test/neighbors/ann_nn_descent/test_uint8_t_uint32_t.cu b/cpp/test/neighbors/ann_nn_descent/test_uint8_t_uint32_t.cu new file mode 100644 index 0000000000..a034e84074 --- /dev/null +++ b/cpp/test/neighbors/ann_nn_descent/test_uint8_t_uint32_t.cu @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include "../ann_nn_descent.cuh" + +namespace raft::neighbors::experimental::nn_descent { + +typedef AnnNNDescentTest AnnNNDescentTestUI8_U32; +TEST_P(AnnNNDescentTestUI8_U32, AnnCagra) { this->testNNDescent(); } + +INSTANTIATE_TEST_CASE_P(AnnNNDescentTest, AnnNNDescentTestUI8_U32, ::testing::ValuesIn(inputs)); + +} // namespace raft::neighbors::experimental::nn_descent diff --git a/cpp/test/neighbors/ann_utils.cuh b/cpp/test/neighbors/ann_utils.cuh index 0e54e29c01..be60ec5b6d 100644 --- a/cpp/test/neighbors/ann_utils.cuh +++ b/cpp/test/neighbors/ann_utils.cuh @@ -123,6 +123,49 @@ struct idx_dist_pair { idx_dist_pair(IdxT x, DistT y, CompareDist op) : idx(x), dist(y), eq_compare(op) {} }; +template +auto eval_recall(const std::vector& expected_idx, + const std::vector& actual_idx, + size_t rows, + size_t cols, + double eps, + double min_recall) -> testing::AssertionResult +{ + size_t match_count = 0; + size_t total_count = static_cast(rows) * static_cast(cols); + for (size_t i = 0; i < rows; ++i) { + for (size_t k = 0; k < cols; ++k) { + size_t idx_k = i * cols + k; // row major assumption! + auto act_idx = actual_idx[idx_k]; + for (size_t j = 0; j < cols; ++j) { + size_t idx = i * cols + j; // row major assumption! + auto exp_idx = expected_idx[idx]; + if (act_idx == exp_idx) { + match_count++; + break; + } + } + } + } + double actual_recall = static_cast(match_count) / static_cast(total_count); + double error_margin = (actual_recall - min_recall) / std::max(1.0 - min_recall, eps); + RAFT_LOG_INFO("Recall = %f (%zu/%zu), the error is %2.1f%% %s the threshold (eps = %f).", + actual_recall, + match_count, + total_count, + std::abs(error_margin * 100.0), + error_margin < 0 ? "above" : "below", + eps); + if (actual_recall < min_recall - eps) { + return testing::AssertionFailure() + << "actual recall (" << actual_recall << ") is lower than the minimum expected recall (" + << min_recall << "); eps = " << eps << ". "; + } + return testing::AssertionSuccess(); +} + +/** same as eval_recall, but in case indices do not match, + * then check distances as well, and accept match if actual dist is equal to expected_dist */ template auto eval_neighbours(const std::vector& expected_idx, const std::vector& actual_idx, diff --git a/docs/source/ann_benchmarks_param_tuning.md b/docs/source/ann_benchmarks_param_tuning.md index dd6090c5e2..433df2ae2f 100644 --- a/docs/source/ann_benchmarks_param_tuning.md +++ b/docs/source/ann_benchmarks_param_tuning.md @@ -48,7 +48,8 @@ CAGRA uses a graph-based index, which creates an intermediate, approximate kNN g |-----------------------------|----------------|----------|----------------------------|---------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| | `graph_degree` | `build_param` | N | Positive Integer >0 | 64 | Degree of the final kNN graph index. | | `intermediate_graph_degree` | `build_param` | N | Positive Integer >0 | 128 | Degree of the intermediate kNN graph. | -| `dataset_memory_type` | `build_param` | N | ["device", "host", "mmap"] | "device" | What memory type should the dataset reside? | +| `graph_build_algo` | `build_param` | N | ["IVF_PQ", "NN_DESCENT"] | "IVF_PQ" | Algorithm to use for search | +| `dataset_memory_type` | `build_param` | N | ["device", "host", "mmap"] | "device" | What memory type should the dataset reside? | | `query_memory_type` | `search_params` | N | ["device", "host", "mmap"] | "device | What memory type should the queries reside? | | `itopk` | `search_wdith` | N | Positive Integer >0 | 64 | Number of intermediate search results retained during the search. Higher values improve search accuracy at the cost of speed. | | `search_width` | `search_param` | N | Positive Integer >0 | 1 | Number of graph nodes to select as the starting point for the search in each iteration. | diff --git a/python/pylibraft/pylibraft/neighbors/cagra/cagra.pyx b/python/pylibraft/pylibraft/neighbors/cagra/cagra.pyx index e0c59a5ed3..c11d933b27 100644 --- a/python/pylibraft/pylibraft/neighbors/cagra/cagra.pyx +++ b/python/pylibraft/pylibraft/neighbors/cagra/cagra.pyx @@ -104,11 +104,13 @@ cdef class IndexParams: graph_degree : int, default = 64 - add_data_on_build : bool, default = True - After training the coarse and fine quantizers, we will populate - the index with the dataset if add_data_on_build == True, otherwise - the index is left empty, and the extend method can be used - to add new vectors to the index. + build_algo: string denoting the graph building algorithm to use, + default = "ivf_pq" + Valid values for algo: ["ivf_pq", "nn_descent"], where + - ivf_pq will use the IVF-PQ algorithm for building the knn graph + - nn_descent (experimental) will use the NN-Descent algorithm for + building the knn graph. It is expected to be generally + faster than ivf_pq. """ cdef c_cagra.index_params params @@ -116,12 +118,15 @@ cdef class IndexParams: metric="sqeuclidean", intermediate_graph_degree=128, graph_degree=64, - add_data_on_build=True): + build_algo="ivf_pq"): self.params.metric = _get_metric(metric) self.params.metric_arg = 0 self.params.intermediate_graph_degree = intermediate_graph_degree self.params.graph_degree = graph_degree - self.params.add_data_on_build = add_data_on_build + if build_algo == "ivf_pq": + self.params.build_algo = c_cagra.graph_build_algo.IVF_PQ + elif build_algo == "nn_descent": + self.params.build_algo = c_cagra.graph_build_algo.NN_DESCENT @property def metric(self): @@ -135,10 +140,6 @@ cdef class IndexParams: def graph_degree(self): return self.params.graph_degree - @property - def add_data_on_build(self): - return self.params.add_data_on_build - cdef class Index: cdef readonly bool trained diff --git a/python/pylibraft/pylibraft/neighbors/cagra/cpp/c_cagra.pxd b/python/pylibraft/pylibraft/neighbors/cagra/cpp/c_cagra.pxd index 0c683bcd9b..7e22f274e9 100644 --- a/python/pylibraft/pylibraft/neighbors/cagra/cpp/c_cagra.pxd +++ b/python/pylibraft/pylibraft/neighbors/cagra/cpp/c_cagra.pxd @@ -51,9 +51,14 @@ from pylibraft.neighbors.ivf_pq.cpp.c_ivf_pq cimport ( cdef extern from "raft/neighbors/cagra_types.hpp" \ namespace "raft::neighbors::cagra" nogil: + ctypedef enum graph_build_algo: + IVF_PQ "raft::neighbors::cagra::graph_build_algo::IVF_PQ", + NN_DESCENT "raft::neighbors::cagra::graph_build_algo::NN_DESCENT" + cpdef cppclass index_params(ann_index_params): size_t intermediate_graph_degree size_t graph_degree + graph_build_algo build_algo ctypedef enum search_algo: SINGLE_CTA "raft::neighbors::cagra::search_algo::SINGLE_CTA", diff --git a/python/pylibraft/pylibraft/test/test_cagra.py b/python/pylibraft/pylibraft/test/test_cagra.py index 74e9f53b91..f74fc5ae62 100644 --- a/python/pylibraft/pylibraft/test/test_cagra.py +++ b/python/pylibraft/pylibraft/test/test_cagra.py @@ -52,6 +52,7 @@ def run_cagra_build_search_test( metric="euclidean", intermediate_graph_degree=128, graph_degree=64, + build_algo="ivf_pq", array_type="device", compare=True, inplace=True, @@ -67,6 +68,7 @@ def run_cagra_build_search_test( metric=metric, intermediate_graph_degree=intermediate_graph_degree, graph_degree=graph_degree, + build_algo=build_algo, ) if array_type == "device": @@ -139,13 +141,17 @@ def run_cagra_build_search_test( @pytest.mark.parametrize("inplace", [True, False]) @pytest.mark.parametrize("dtype", [np.float32, np.int8, np.uint8]) @pytest.mark.parametrize("array_type", ["device", "host"]) -def test_cagra_dataset_dtype_host_device(dtype, array_type, inplace): +@pytest.mark.parametrize("build_algo", ["ivf_pq", "nn_descent"]) +def test_cagra_dataset_dtype_host_device( + dtype, array_type, inplace, build_algo +): # Note that inner_product tests use normalized input which we cannot # represent in int8, therefore we test only sqeuclidean metric here. run_cagra_build_search_test( dtype=dtype, inplace=inplace, array_type=array_type, + build_algo=build_algo, ) @@ -158,6 +164,7 @@ def test_cagra_dataset_dtype_host_device(dtype, array_type, inplace): "add_data_on_build": True, "k": 1, "metric": "euclidean", + "build_algo": "ivf_pq", }, { "intermediate_graph_degree": 32, @@ -165,6 +172,7 @@ def test_cagra_dataset_dtype_host_device(dtype, array_type, inplace): "add_data_on_build": False, "k": 5, "metric": "sqeuclidean", + "build_algo": "ivf_pq", }, { "intermediate_graph_degree": 128, @@ -172,6 +180,7 @@ def test_cagra_dataset_dtype_host_device(dtype, array_type, inplace): "add_data_on_build": True, "k": 10, "metric": "inner_product", + "build_algo": "nn_descent", }, ], ) @@ -184,6 +193,7 @@ def test_cagra_index_params(params): graph_degree=params["graph_degree"], intermediate_graph_degree=params["intermediate_graph_degree"], compare=False, + build_algo=params["build_algo"], )