diff --git a/cpp/include/raft/neighbors/brute_force-inl.cuh b/cpp/include/raft/neighbors/brute_force-inl.cuh index 6b86f2463f..906371bd01 100644 --- a/cpp/include/raft/neighbors/brute_force-inl.cuh +++ b/cpp/include/raft/neighbors/brute_force-inl.cuh @@ -284,6 +284,44 @@ void fused_l2_knn(raft::resources const& handle, /** * @brief Build the index from the dataset for efficient search. * + * This function builds a brute force index for the given dataset. This lets you re-use + * precalculated norms for the dataset, leading to a speedup over calling + * raft::neighbors::brute_force::knn repeatedly. + * + * Example usage: + * @code{.cpp} + * #include + * #include + * #include + * + * // create a random dataset + * int n_rows = 10000; + * int n_cols = 10000; + * + * raft::device_resources res; + * auto dataset = raft::make_device_matrix(res, n_rows, n_cols); + * auto labels = raft::make_device_vector(res, n_rows); + * + * raft::random::make_blobs(res, dataset.view(), labels.view()); + * + * // create a brute_force knn index from the dataset + * auto index = raft::neighbors::brute_force::build(res, + * raft::make_const_mdspan(dataset.view())); + * + * // Use the constructed index to search for the nearest 128 neighbors + * int k = 128; + * auto search = raft::make_const_mdspan(dataset.view()); + * + * auto indices= raft::make_device_matrix(res, search.extent(0), k); + * auto distances = raft::make_device_matrix(res, search.extent(0), k); + * + * raft::neighbors::brute_force::search(res, + * index, + * search, + * indices.view(), + * distances.view()); + * @endcode + * * @tparam T data element type * * @param[in] res @@ -330,6 +368,8 @@ index build(raft::resources const& res, /** * @brief Brute Force search using the constructed index. * + * See raft::neighbors::brute_force::build for a usage example + * * @tparam T data element type * @tparam IdxT type of the indices * diff --git a/cpp/include/raft/neighbors/brute_force.cuh b/cpp/include/raft/neighbors/brute_force.cuh index 4ba9159556..331ea55540 100644 --- a/cpp/include/raft/neighbors/brute_force.cuh +++ b/cpp/include/raft/neighbors/brute_force.cuh @@ -44,10 +44,10 @@ namespace raft::neighbors::brute_force { * int n_cols = 10000; * raft::device_resources res; - * auto dataset = raft::make_device_matrix(res, n_rows, n_cols); - * auto labels = raft::make_device_vector(res, n_rows); + * auto dataset = raft::make_device_matrix(res, n_rows, n_cols); + * auto labels = raft::make_device_vector(res, n_rows); - * raft::make_blobs(res, dataset.view(), labels.view()); + * raft::random::make_blobs(res, dataset.view(), labels.view()); * * // create a brute_force knn index from the dataset * auto index = raft::neighbors::brute_force::build(res, diff --git a/docs/source/vector_search_tutorial.md b/docs/source/vector_search_tutorial.md index 126ac534c6..8ff25143e3 100644 --- a/docs/source/vector_search_tutorial.md +++ b/docs/source/vector_search_tutorial.md @@ -89,10 +89,10 @@ raft::device_resources res; int n_rows = 10000; int n_cols = 10000; -auto dataset = raft::make_device_matrix(res, n_rows, n_cols); -auto labels = raft::make_device_vector(res, n_rows); +auto dataset = raft::make_device_matrix(res, n_rows, n_cols); +auto labels = raft::make_device_vector(res, n_rows); -raft::make_blobs(res, dataset.view(), labels.view()); +raft::random::make_blobs(res, dataset.view(), labels.view()); ``` That's it. We've now generated a random 10kx10k matrix with points that cleanly separate into Gaussian clusters, along with a vector of cluster labels for each of the data points. Notice the `cuh` extension in the header file include for `make_blobs`. This signifies to us that this file contains CUDA device functions like kernel code so the CUDA compiler, `nvcc` is needed in order to compile any code that uses it. Generally, any source files that include headers with a `cuh` extension use the `.cu` extension instead of `.cpp`. The rule here is that `cpp` source files contain code which can be compiled with a C++ compiler like `g++` while `cu` files require the CUDA compiler. @@ -125,14 +125,14 @@ auto search = raft::make_const_mdspan(dataset.view()); // Indices and Distances are of dimensions (n, k) // where n is number of rows in the search matrix -auto reference_indices = raft::make_device_matrix(search.extent(0), k); // stores index of neighbors -auto reference_distances = raft::make_device_matrix(search.extent(0), k); // stores distance to neighbors +auto reference_indices = raft::make_device_matrix(res, search.extent(0), k); // stores index of neighbors +auto reference_distances = raft::make_device_matrix(res, search.extent(0), k); // stores distance to neighbors raft::neighbors::brute_force::search(res, bfknn_index, search, - raft::make_const_mdspan(indices.view()), - raft::make_const_mdspan(distances.view())); + reference_indices.view(), + reference_distances.view()); ``` We have established several things here by building a flat index. Now we know the exact 64 neighbors of all points in the matrix, and this algorithm can be generally useful in several ways: @@ -152,9 +152,9 @@ Next we'll train an ANN index. We'll use our graph-based CAGRA algorithm for thi raft::device_resources res; // use default index parameters -cagra::index_params index_params; +raft::neighbors::cagra::index_params index_params; -auto index = cagra::build(res, index_params, dataset); +auto index = raft::neighbors::cagra::build(res, index_params, raft::make_const_mdspan(dataset.view())); ``` ### Query the CAGRA index @@ -167,10 +167,10 @@ auto indices = raft::make_device_matrix(res, n_rows, k); auto distances = raft::make_device_matrix(res, n_rows, k); // use default search parameters -cagra::search_params search_params; +raft::neighbors::cagra::search_params search_params; // search K nearest neighbors -cagra::search( +raft::neighbors::cagra::search( res, search_params, index, search, indices.view(), distances.view()); ``` @@ -197,8 +197,8 @@ raft::stats::neighborhood_recall(res, raft::make_const_mdspan(indices.view()), raft::make_const_mdspan(reference_indices.view()), recall_value.view(), - raft::make_const_mdspan(distances), - raft::make_const_mdspan(reference_distances)); + raft::make_const_mdspan(distances.view()), + raft::make_const_mdspan(reference_distances.view())); res.sync_stream(); ``` @@ -340,4 +340,4 @@ The below example specifies the total number of bytes that RAFT can use for temp std::shared_ptr managed_resource; raft::device_resource res(managed_resource, std::make_optional(3 * 1024^3)); -``` \ No newline at end of file +```