Skip to content

Commit

Permalink
Fix NN Descent overflows (rapidsai#1875)
Browse files Browse the repository at this point in the history
NN-Descent was using `int` type for indexing in `mdarray`, however this was causing an overflow when the product of all extents was greater than `int`.

This PR also adds/fixes:

- Missing dependencies for `raft-ann-bench` development environment
- Exposes NN Descent iterations to use in CAGRA benchmarks

Authors:
  - Divye Gala (https://github.com/divyegala)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)
  - Ray Douglass (https://github.com/raydouglass)

URL: rapidsai#1875
  • Loading branch information
divyegala committed Oct 6, 2023
1 parent c272038 commit ffab8f6
Show file tree
Hide file tree
Showing 7 changed files with 29 additions and 18 deletions.
4 changes: 4 additions & 0 deletions conda/environments/bench_ann_cuda-118_arch-x86_64.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,14 @@ dependencies:
- libcusparse-dev=11.7.5.86
- libcusparse=11.7.5.86
- libfaiss>=1.7.1
- matplotlib
- nccl>=2.9.9
- ninja
- nlohmann_json>=3.11.2
- nvcc_linux-64=11.8
- pandas
- pyyaml
- rmm==23.12.*
- scikit-build>=0.13.1
- sysroot_linux-64==2.17
name: bench_ann_cuda-118_arch-x86_64
1 change: 1 addition & 0 deletions cpp/bench/ann/src/raft/raft_benchmark.cu
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ void parse_build_param(const nlohmann::json& conf,
param.build_algo = raft::neighbors::cagra::graph_build_algo::NN_DESCENT;
}
}
if (conf.contains("nn_descent_niter")) { param.nn_descent_niter = conf.at("nn_descent_niter"); }
}

template <typename T, typename IdxT>
Expand Down
1 change: 1 addition & 0 deletions cpp/include/raft/neighbors/cagra.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,7 @@ index<T, IdxT> build(raft::resources const& res,
auto nn_descent_params = experimental::nn_descent::index_params();
nn_descent_params.graph_degree = intermediate_degree;
nn_descent_params.intermediate_graph_degree = 1.5 * intermediate_degree;
nn_descent_params.max_iterations = params.nn_descent_niter;
build_knn_graph<T, IdxT>(res, dataset, knn_graph->view(), nn_descent_params);
}

Expand Down
2 changes: 2 additions & 0 deletions cpp/include/raft/neighbors/cagra_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ struct index_params : ann::index_params {
size_t graph_degree = 64;
/** ANN algorithm to build knn graph. */
graph_build_algo build_algo = graph_build_algo::IVF_PQ;
/** Number of Iterations to run if building with NN_DESCENT */
size_t nn_descent_niter = 20;
};

enum class search_algo {
Expand Down
36 changes: 18 additions & 18 deletions cpp/include/raft/neighbors/detail/nn_descent.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -362,28 +362,28 @@ class GNND {
GnndGraph<Index_t> graph_;
std::atomic<int64_t> update_counter_;

Index_t nrow_;
const int ndim_;
size_t nrow_;
size_t ndim_;

raft::device_matrix<__half, Index_t, raft::row_major> d_data_;
raft::device_vector<DistData_t, Index_t> l2_norms_;
raft::device_matrix<__half, size_t, raft::row_major> d_data_;
raft::device_vector<DistData_t, size_t> l2_norms_;

raft::device_matrix<ID_t, Index_t, raft::row_major> graph_buffer_;
raft::device_matrix<DistData_t, Index_t, raft::row_major> dists_buffer_;
raft::device_matrix<ID_t, size_t, raft::row_major> graph_buffer_;
raft::device_matrix<DistData_t, size_t, raft::row_major> dists_buffer_;

// TODO: Investigate using RMM/RAFT types https://github.com/rapidsai/raft/issues/1827
thrust::host_vector<ID_t, pinned_memory_allocator<ID_t>> graph_host_buffer_;
thrust::host_vector<DistData_t, pinned_memory_allocator<DistData_t>> dists_host_buffer_;

raft::device_vector<int, Index_t> d_locks_;
raft::device_vector<int, size_t> d_locks_;

thrust::host_vector<Index_t, pinned_memory_allocator<Index_t>> h_rev_graph_new_;
thrust::host_vector<Index_t, pinned_memory_allocator<Index_t>> h_graph_old_;
thrust::host_vector<Index_t, pinned_memory_allocator<Index_t>> h_rev_graph_old_;
// int2.x is the number of forward edges, int2.y is the number of reverse edges

raft::device_vector<int2, Index_t> d_list_sizes_new_;
raft::device_vector<int2, Index_t> d_list_sizes_old_;
raft::device_vector<int2, size_t> d_list_sizes_new_;
raft::device_vector<int2, size_t> d_list_sizes_old_;
};

constexpr int TILE_ROW_WIDTH = 64;
Expand Down Expand Up @@ -1143,21 +1143,21 @@ GNND<Data_t, Index_t>::GNND(raft::resources const& res, const BuildConfig& build
NUM_SAMPLES),
nrow_(build_config.max_dataset_size),
ndim_(build_config.dataset_dim),
d_data_{raft::make_device_matrix<__half, Index_t, raft::row_major>(
d_data_{raft::make_device_matrix<__half, size_t, raft::row_major>(
res, nrow_, build_config.dataset_dim)},
l2_norms_{raft::make_device_vector<DistData_t, Index_t>(res, nrow_)},
l2_norms_{raft::make_device_vector<DistData_t, size_t>(res, nrow_)},
graph_buffer_{
raft::make_device_matrix<ID_t, Index_t, raft::row_major>(res, nrow_, DEGREE_ON_DEVICE)},
raft::make_device_matrix<ID_t, size_t, raft::row_major>(res, nrow_, DEGREE_ON_DEVICE)},
dists_buffer_{
raft::make_device_matrix<DistData_t, Index_t, raft::row_major>(res, nrow_, DEGREE_ON_DEVICE)},
raft::make_device_matrix<DistData_t, size_t, raft::row_major>(res, nrow_, DEGREE_ON_DEVICE)},
graph_host_buffer_(nrow_ * DEGREE_ON_DEVICE),
dists_host_buffer_(nrow_ * DEGREE_ON_DEVICE),
d_locks_{raft::make_device_vector<int, Index_t>(res, nrow_)},
d_locks_{raft::make_device_vector<int, size_t>(res, nrow_)},
h_rev_graph_new_(nrow_ * NUM_SAMPLES),
h_graph_old_(nrow_ * NUM_SAMPLES),
h_rev_graph_old_(nrow_ * NUM_SAMPLES),
d_list_sizes_new_{raft::make_device_vector<int2, Index_t>(res, nrow_)},
d_list_sizes_old_{raft::make_device_vector<int2, Index_t>(res, nrow_)}
d_list_sizes_new_{raft::make_device_vector<int2, size_t>(res, nrow_)},
d_list_sizes_old_{raft::make_device_vector<int2, size_t>(res, nrow_)}
{
static_assert(NUM_SAMPLES <= 32);

Expand Down Expand Up @@ -1342,8 +1342,8 @@ void GNND<Data_t, Index_t>::build(Data_t* data, const Index_t nrow, Index_t* out
for (size_t i = 0; i < (size_t)nrow_; i++) {
for (size_t j = 0; j < build_config_.node_degree; j++) {
size_t idx = i * graph_.node_degree + j;
Index_t id = graph_.h_graph[idx].id();
if (id < nrow_) {
int id = graph_.h_graph[idx].id();
if (id < static_cast<int>(nrow_)) {
graph_shrink_buffer[i * build_config_.node_degree + j] = id;
} else {
graph_shrink_buffer[i * build_config_.node_degree + j] =
Expand Down
2 changes: 2 additions & 0 deletions dependencies.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ files:
- develop
- cudatoolkit
- nn_bench
- nn_bench_python
test_cpp:
output: none
includes:
Expand Down Expand Up @@ -228,6 +229,7 @@ dependencies:
- libfaiss>=1.7.1
- benchmark>=1.8.2
- faiss-proc=*=cuda
- *rmm_conda
nn_bench_python:
common:
- output_types: [conda]
Expand Down
1 change: 1 addition & 0 deletions docs/source/ann_benchmarks_param_tuning.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ CAGRA uses a graph-based index, which creates an intermediate, approximate kNN g
| `graph_degree` | `build_param` | N | Positive Integer >0 | 64 | Degree of the final kNN graph index. |
| `intermediate_graph_degree` | `build_param` | N | Positive Integer >0 | 128 | Degree of the intermediate kNN graph. |
| `graph_build_algo` | `build_param` | N | ["IVF_PQ", "NN_DESCENT"] | "IVF_PQ" | Algorithm to use for search |
| `nn_descent_niter` | `build_param` | N | Positive Integer>0 | 20 | Number of iterations if using NN_DESCENT. |
| `dataset_memory_type` | `build_param` | N | ["device", "host", "mmap"] | "device" | What memory type should the dataset reside? |
| `query_memory_type` | `search_params` | N | ["device", "host", "mmap"] | "device | What memory type should the queries reside? |
| `itopk` | `search_wdith` | N | Positive Integer >0 | 64 | Number of intermediate search results retained during the search. Higher values improve search accuracy at the cost of speed. |
Expand Down

0 comments on commit ffab8f6

Please sign in to comment.