diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index 5a883b64ed..2435c477ca 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -22,7 +22,7 @@ on: default: nightly concurrency: - group: ${{ github.workflow }}-${{ github.ref }} + group: ${{ github.workflow }}-${{ github.ref }}-${{ github.event_name }} cancel-in-progress: true jobs: diff --git a/conda/recipes/pylibraft/meta.yaml b/conda/recipes/pylibraft/meta.yaml index 454cac0d77..b8a088d0f3 100644 --- a/conda/recipes/pylibraft/meta.yaml +++ b/conda/recipes/pylibraft/meta.yaml @@ -48,7 +48,6 @@ requirements: - cython >=3.0.0 - libraft {{ version }} - libraft-headers {{ version }} - - numpy >=1.21 - python x.x - rmm ={{ minor_version }} - scikit-build >=0.13.1 @@ -60,6 +59,7 @@ requirements: {% endif %} - libraft {{ version }} - libraft-headers {{ version }} + - numpy >=1.21 - python x.x - rmm ={{ minor_version }} diff --git a/cpp/bench/ann/CMakeLists.txt b/cpp/bench/ann/CMakeLists.txt index eb44e58cb5..5919de07e7 100644 --- a/cpp/bench/ann/CMakeLists.txt +++ b/cpp/bench/ann/CMakeLists.txt @@ -90,7 +90,7 @@ if(RAFT_ANN_BENCH_USE_RAFT_IVF_PQ OR RAFT_ANN_BENCH_USE_RAFT_BRUTE_FORCE OR RAFT_ANN_BENCH_USE_RAFT_IVF_FLAT OR RAFT_ANN_BENCH_USE_RAFT_CAGRA - OR RAFT_ANN_BENCH_USE_RAFT_CAGRA_HNSWLIB + OR RAFT_ANN_BENCH_USE_RAFT_CAGRA_HNSWLIB ) set(RAFT_ANN_BENCH_USE_RAFT ON) endif() @@ -263,7 +263,8 @@ if(RAFT_ANN_BENCH_USE_RAFT_CAGRA_HNSWLIB) ${CMAKE_CURRENT_BINARY_DIR}/_deps/hnswlib-src/hnswlib LINKS raft::compiled - CXXFLAGS "${HNSW_CXX_FLAGS}" + CXXFLAGS + "${HNSW_CXX_FLAGS}" ) endif() diff --git a/cpp/include/raft/distance/detail/distance_ops/l2_exp.cuh b/cpp/include/raft/distance/detail/distance_ops/l2_exp.cuh index 5b4048c1c3..a218c85a0a 100644 --- a/cpp/include/raft/distance/detail/distance_ops/l2_exp.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/l2_exp.cuh @@ -44,7 +44,7 @@ struct l2_exp_cutlass_op { __device__ l2_exp_cutlass_op() noexcept : sqrt(false) {} __device__ l2_exp_cutlass_op(bool isSqrt) noexcept : sqrt(isSqrt) {} - __device__ AccT operator()(DataT& aNorm, const DataT& bNorm, DataT& accVal) const noexcept + inline __device__ AccT operator()(DataT aNorm, DataT bNorm, DataT accVal) const noexcept { AccT outVal = aNorm + bNorm - DataT(2.0) * accVal; diff --git a/cpp/include/raft/neighbors/brute_force-ext.cuh b/cpp/include/raft/neighbors/brute_force-ext.cuh index b8c00616da..4c1f7ea21e 100644 --- a/cpp/include/raft/neighbors/brute_force-ext.cuh +++ b/cpp/include/raft/neighbors/brute_force-ext.cuh @@ -50,7 +50,7 @@ void search(raft::resources const& res, const index& idx, raft::device_matrix_view queries, raft::device_matrix_view neighbors, - raft::device_matrix_view distances) RAFT_EXPLICIT; + raft::device_matrix_view distances) RAFT_EXPLICIT; template & idx, raft::device_matrix_view queries, raft::device_matrix_view neighbors, - raft::device_matrix_view distances) + raft::device_matrix_view distances) { - RAFT_EXPECTS(neighbors.extent(1) == distances.extent(1), "Value of k must match for outputs"); - RAFT_EXPECTS(idx.dataset().extent(1) == queries.extent(1), - "Number of columns in queries must match brute force index"); - - auto k = neighbors.extent(1); - auto d = idx.dataset().extent(1); - - std::vector dataset = {const_cast(idx.dataset().data_handle())}; - std::vector sizes = {idx.dataset().extent(0)}; - std::vector norms; - if (idx.has_norms()) { norms.push_back(const_cast(idx.norms().data_handle())); } - - detail::brute_force_knn_impl(res, - dataset, - sizes, - d, - 
const_cast(queries.data_handle()), - queries.extent(0), - neighbors.data_handle(), - distances.data_handle(), - k, - true, - true, - nullptr, - idx.metric(), - idx.metric_arg(), - raft::identity_op(), - norms.size() ? &norms : nullptr); + raft::neighbors::detail::brute_force_search(res, idx, queries, neighbors, distances); } /** @} */ // end group brute_force_knn } // namespace raft::neighbors::brute_force diff --git a/cpp/include/raft/neighbors/brute_force.cuh b/cpp/include/raft/neighbors/brute_force.cuh index 6cebf4b52a..4ba9159556 100644 --- a/cpp/include/raft/neighbors/brute_force.cuh +++ b/cpp/include/raft/neighbors/brute_force.cuh @@ -14,6 +14,7 @@ * limitations under the License. */ #pragma once +#include #ifndef RAFT_EXPLICIT_INSTANTIATE_ONLY #include "brute_force-inl.cuh" @@ -22,3 +23,70 @@ #ifdef RAFT_COMPILED #include "brute_force-ext.cuh" #endif + +#include + +namespace raft::neighbors::brute_force { +/** + * @brief Make a brute force query over batches of k + * + * This lets you query for batches of k. For example, you can get + * the first 100 neighbors, then the next 100 neighbors etc. + * + * Example usage: + * @code{.cpp} + * #include + * #include + * #include + + * // create a random dataset + * int n_rows = 10000; + * int n_cols = 10000; + + * raft::device_resources res; + * auto dataset = raft::make_device_matrix(res, n_rows, n_cols); + * auto labels = raft::make_device_vector(res, n_rows); + + * raft::make_blobs(res, dataset.view(), labels.view()); + * + * // create a brute_force knn index from the dataset + * auto index = raft::neighbors::brute_force::build(res, + * raft::make_const_mdspan(dataset.view())); + * + * // search the index in batches of 128 nearest neighbors + * auto search = raft::make_const_mdspan(dataset.view()); + * auto query = make_batch_k_query(res, index, search, 128); + * for (auto & batch: *query) { + * // batch.indices() and batch.distances() contain the information on the current batch + * } + * + * // we can also support variable sized batches - loaded up a different number + * // of neighbors at each iteration through the ::advance method + * int64_t batch_size = 128; + * query = make_batch_k_query(res, index, search, batch_size); + * for (auto it = query->begin(); it != query->end(); it.advance(batch_size)) { + * // batch.indices() and batch.distances() contain the information on the current batch + * + * batch_size += 16; // load up an extra 16 items in the next batch + * } + * @endcode + * + * @tparam T data element type + * @tparam IdxT type of the indices in the source dataset + * @param[in] res + * @param[in] index The index to query + * @param[in] query A device matrix view to query for [n_queries, index->dim()] + * @param[in] batch_size The size of each batch + */ + +template +std::shared_ptr> make_batch_k_query( + const raft::resources& res, + const raft::neighbors::brute_force::index& index, + raft::device_matrix_view query, + int64_t batch_size) +{ + return std::shared_ptr>( + new detail::gpu_batch_k_query(res, index, query, batch_size)); +} +} // namespace raft::neighbors::brute_force diff --git a/cpp/include/raft/neighbors/brute_force_types.hpp b/cpp/include/raft/neighbors/brute_force_types.hpp index f7030503f1..039599845e 100644 --- a/cpp/include/raft/neighbors/brute_force_types.hpp +++ b/cpp/include/raft/neighbors/brute_force_types.hpp @@ -25,6 +25,7 @@ #include #include #include +#include #include @@ -69,7 +70,7 @@ struct index : ann::index { return norms_view_.value(); } - /** Whether ot not this index has dataset norms */ + 
/** Whether or not this index has dataset norms */ [[nodiscard]] inline bool has_norms() const noexcept { return norms_view_.has_value(); } [[nodiscard]] inline T metric_arg() const noexcept { return metric_arg_; } @@ -160,6 +161,122 @@ struct index : ann::index { T metric_arg_; }; +/** + * @brief Interface for performing queries over values of k + * + * This interface lets you iterate over batches of k from a brute_force::index. + * This lets you do things like retrieve the first 100 neighbors for a query, + * apply post processing to remove any unwanted items and then if needed get the + * next 100 closest neighbors for the query. + * + * This query interface exposes C++ iterators through the ::begin and ::end, and + * is compatible with range based for loops. + * + * Note that this class is an abstract class without any cuda dependencies, meaning + * that it doesn't require a cuda compiler to use - but also means it can't be directly + * instantiated. See the raft::neighbors::brute_force::make_batch_k_query + * function for usage examples. + * + * @tparam T data element type + * @tparam IdxT type of the indices in the source dataset + */ +template +class batch_k_query { + public: + batch_k_query(const raft::resources& res, + int64_t index_size, + int64_t query_size, + int64_t batch_size) + : res(res), index_size(index_size), query_size(query_size), batch_size(batch_size) + { + } + virtual ~batch_k_query() {} + + using value_type = raft::neighbors::batch; + + class iterator { + public: + using value_type = raft::neighbors::batch; + using reference = const value_type&; + using pointer = const value_type*; + + iterator(const batch_k_query* query, int64_t offset = 0) + : current(query->res, 0, 0), batches(query->res, 0, 0), query(query), offset(offset) + { + query->load_batch(offset, query->batch_size, &batches); + query->slice_batch(batches, offset, query->batch_size, ¤t); + } + + reference operator*() const { return current; } + + pointer operator->() const { return ¤t; } + + iterator& operator++() + { + advance(query->batch_size); + return *this; + } + + iterator operator++(int) + { + iterator previous(*this); + operator++(); + return previous; + } + + /** + * @brief Advance the iterator, using a custom size for the next batch + * + * Using operator++ means that we will load up the same batch_size for each + * batch. This method allows us to get around this restriction, and load up + * arbitrary batch sizes on each iteration. + * See raft::neighbors::brute_force::make_batch_k_query for a usage example. 
+ * + * @param[in] next_batch_size: size of the next batch to load up + */ + void advance(int64_t next_batch_size) + { + offset = std::min(offset + current.batch_size(), query->index_size); + if (offset + next_batch_size > batches.batch_size()) { + query->load_batch(offset, next_batch_size, &batches); + } + query->slice_batch(batches, offset, next_batch_size, ¤t); + } + + friend bool operator==(const iterator& lhs, const iterator& rhs) + { + return (lhs.query == rhs.query) && (lhs.offset == rhs.offset); + }; + friend bool operator!=(const iterator& lhs, const iterator& rhs) { return !(lhs == rhs); }; + + protected: + // the current batch of data + value_type current; + + // the currently loaded group of data (containing multiple batches of data that we can iterate + // through) + value_type batches; + + const batch_k_query* query; + int64_t offset, current_batch_size; + }; + + iterator begin() const { return iterator(this); } + iterator end() const { return iterator(this, index_size); } + + protected: + // these two methods need cuda code, and are implemented in the subclass + virtual void load_batch(int64_t offset, + int64_t next_batch_size, + batch* output) const = 0; + virtual void slice_batch(const value_type& input, + int64_t offset, + int64_t batch_size, + value_type* output) const = 0; + + const raft::resources& res; + int64_t index_size, query_size, batch_size; +}; /** @} */ } // namespace raft::neighbors::brute_force diff --git a/cpp/include/raft/neighbors/detail/div_utils.hpp b/cpp/include/raft/neighbors/detail/div_utils.hpp new file mode 100644 index 0000000000..0455d0ec9b --- /dev/null +++ b/cpp/include/raft/neighbors/detail/div_utils.hpp @@ -0,0 +1,66 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifdef _RAFT_HAS_CUDA +#include +#else +#include +#endif + +/** + * @brief A simple wrapper for raft::Pow2 which uses Pow2 utils only when available and regular + * integer division otherwise. This is done to allow a common interface for division arithmetic for + * non CUDA headers. + * + * @tparam Value_ a compile-time value representable as a power-of-two. 
+ */ +namespace raft::neighbors::detail { +template +struct div_utils { + typedef decltype(Value_) Type; + static constexpr Type Value = Value_; + + template + static constexpr _RAFT_HOST_DEVICE inline auto roundDown(T x) + { +#if defined(_RAFT_HAS_CUDA) + return Pow2::roundDown(x); +#else + return raft::round_down_safe(x, Value_); +#endif + } + + template + static constexpr _RAFT_HOST_DEVICE inline auto mod(T x) + { +#if defined(_RAFT_HAS_CUDA) + return Pow2::mod(x); +#else + return x % Value_; +#endif + } + + template + static constexpr _RAFT_HOST_DEVICE inline auto div(T x) + { +#if defined(_RAFT_HAS_CUDA) + return Pow2::div(x); +#else + return x / Value_; +#endif + } +}; +} // namespace raft::neighbors::detail \ No newline at end of file diff --git a/cpp/include/raft/neighbors/detail/ivf_pq_build.cuh b/cpp/include/raft/neighbors/detail/ivf_pq_build.cuh index 33ed51ad05..e57133fc23 100644 --- a/cpp/include/raft/neighbors/detail/ivf_pq_build.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_pq_build.cuh @@ -313,6 +313,59 @@ auto calculate_offsets_and_indices(IdxT n_rows, return max_cluster_size; } +template +void set_centers(raft::resources const& handle, index* index, const float* cluster_centers) +{ + auto stream = resource::get_cuda_stream(handle); + auto* device_memory = resource::get_workspace_resource(handle); + + // combine cluster_centers and their norms + RAFT_CUDA_TRY(cudaMemcpy2DAsync(index->centers().data_handle(), + sizeof(float) * index->dim_ext(), + cluster_centers, + sizeof(float) * index->dim(), + sizeof(float) * index->dim(), + index->n_lists(), + cudaMemcpyDefault, + stream)); + + rmm::device_uvector center_norms(index->n_lists(), stream, device_memory); + raft::linalg::rowNorm(center_norms.data(), + cluster_centers, + index->dim(), + index->n_lists(), + raft::linalg::L2Norm, + true, + stream); + RAFT_CUDA_TRY(cudaMemcpy2DAsync(index->centers().data_handle() + index->dim(), + sizeof(float) * index->dim_ext(), + center_norms.data(), + sizeof(float), + sizeof(float), + index->n_lists(), + cudaMemcpyDefault, + stream)); + + // Rotate cluster_centers + float alpha = 1.0; + float beta = 0.0; + linalg::gemm(handle, + true, + false, + index->rot_dim(), + index->n_lists(), + index->dim(), + &alpha, + index->rotation_matrix().data_handle(), + index->dim(), + cluster_centers, + index->dim(), + &beta, + index->centers_rot().data_handle(), + index->rot_dim(), + resource::get_cuda_stream(handle)); +} + template void transpose_pq_centers(const resources& handle, index& index, @@ -613,6 +666,100 @@ void unpack_list_data(raft::resources const& res, resource::get_cuda_stream(res)); } +/** + * A consumer for the `run_on_vector` that just flattens PQ codes + * into a tightly packed matrix. That is, the codes are not expanded to one code-per-byte. + */ +template +struct unpack_contiguous { + uint8_t* codes; + uint32_t code_size; + + /** + * Create a callable to be passed to `run_on_vector`. + * + * @param[in] codes flat compressed PQ codes + */ + __host__ __device__ inline unpack_contiguous(uint8_t* codes, uint32_t pq_dim) + : codes{codes}, code_size{raft::ceildiv(pq_dim * PqBits, 8)} + { + } + + /** Write j-th component (code) of the i-th vector into the output array. 
*/ + __host__ __device__ inline void operator()(uint8_t code, uint32_t i, uint32_t j) + { + bitfield_view_t code_view{codes + i * code_size}; + code_view[j] = code; + } +}; + +template +__launch_bounds__(BlockSize) RAFT_KERNEL unpack_contiguous_list_data_kernel( + uint8_t* out_codes, + device_mdspan::list_extents, row_major> in_list_data, + uint32_t n_rows, + uint32_t pq_dim, + std::variant offset_or_indices) +{ + run_on_list( + in_list_data, offset_or_indices, n_rows, pq_dim, unpack_contiguous(out_codes, pq_dim)); +} + +/** + * Unpack flat PQ codes from an existing list by the given offset. + * + * @param[out] codes flat compressed PQ codes [n_rows, ceildiv(pq_dim * pq_bits, 8)] + * @param[in] list_data the packed ivf::list data. + * @param[in] offset_or_indices how many records in the list to skip or the exact indices. + * @param[in] pq_bits codebook size (1 << pq_bits) + * @param[in] stream + */ +inline void unpack_contiguous_list_data( + uint8_t* codes, + device_mdspan::list_extents, row_major> list_data, + uint32_t n_rows, + uint32_t pq_dim, + std::variant offset_or_indices, + uint32_t pq_bits, + rmm::cuda_stream_view stream) +{ + if (n_rows == 0) { return; } + + constexpr uint32_t kBlockSize = 256; + dim3 blocks(div_rounding_up_safe(n_rows, kBlockSize), 1, 1); + dim3 threads(kBlockSize, 1, 1); + auto kernel = [pq_bits]() { + switch (pq_bits) { + case 4: return unpack_contiguous_list_data_kernel; + case 5: return unpack_contiguous_list_data_kernel; + case 6: return unpack_contiguous_list_data_kernel; + case 7: return unpack_contiguous_list_data_kernel; + case 8: return unpack_contiguous_list_data_kernel; + default: RAFT_FAIL("Invalid pq_bits (%u), the value must be within [4, 8]", pq_bits); + } + }(); + kernel<<>>(codes, list_data, n_rows, pq_dim, offset_or_indices); + RAFT_CUDA_TRY(cudaPeekAtLastError()); +} + +/** Unpack the list data; see the public interface for the api and usage. */ +template +void unpack_contiguous_list_data(raft::resources const& res, + const index& index, + uint8_t* out_codes, + uint32_t n_rows, + uint32_t label, + std::variant offset_or_indices) +{ + unpack_contiguous_list_data(out_codes, + index.lists()[label]->data.view(), + n_rows, + index.pq_dim(), + offset_or_indices, + index.pq_bits(), + resource::get_cuda_stream(res)); +} + /** A consumer for the `run_on_list` and `run_on_vector` that approximates the original input data. */ struct reconstruct_vectors { @@ -850,6 +997,101 @@ void pack_list_data(raft::resources const& res, resource::get_cuda_stream(res)); } +/** + * A producer for the `write_vector` reads tightly packed flat codes. That is, + * the codes are not expanded to one code-per-byte. + */ +template +struct pack_contiguous { + const uint8_t* codes; + uint32_t code_size; + + /** + * Create a callable to be passed to `write_vector`. + * + * @param[in] codes flat compressed PQ codes + */ + __host__ __device__ inline pack_contiguous(const uint8_t* codes, uint32_t pq_dim) + : codes{codes}, code_size{raft::ceildiv(pq_dim * PqBits, 8)} + { + } + + /** Read j-th component (code) of the i-th vector from the source. 
*/ + __host__ __device__ inline auto operator()(uint32_t i, uint32_t j) -> uint8_t + { + bitfield_view_t code_view{const_cast(codes + i * code_size)}; + return uint8_t(code_view[j]); + } +}; + +template +__launch_bounds__(BlockSize) RAFT_KERNEL pack_contiguous_list_data_kernel( + device_mdspan::list_extents, row_major> list_data, + const uint8_t* codes, + uint32_t n_rows, + uint32_t pq_dim, + std::variant offset_or_indices) +{ + write_list( + list_data, offset_or_indices, n_rows, pq_dim, pack_contiguous(codes, pq_dim)); +} + +/** + * Write flat PQ codes into an existing list by the given offset. + * + * NB: no memory allocation happens here; the list must fit the data (offset + n_rows). + * + * @param[out] list_data the packed ivf::list data. + * @param[in] codes flat compressed PQ codes [n_rows, ceildiv(pq_dim * pq_bits, 8)] + * @param[in] offset_or_indices how many records in the list to skip or the exact indices. + * @param[in] pq_bits codebook size (1 << pq_bits) + * @param[in] stream + */ +inline void pack_contiguous_list_data( + device_mdspan::list_extents, row_major> list_data, + const uint8_t* codes, + uint32_t n_rows, + uint32_t pq_dim, + std::variant offset_or_indices, + uint32_t pq_bits, + rmm::cuda_stream_view stream) +{ + if (n_rows == 0) { return; } + + constexpr uint32_t kBlockSize = 256; + dim3 blocks(div_rounding_up_safe(n_rows, kBlockSize), 1, 1); + dim3 threads(kBlockSize, 1, 1); + auto kernel = [pq_bits]() { + switch (pq_bits) { + case 4: return pack_contiguous_list_data_kernel; + case 5: return pack_contiguous_list_data_kernel; + case 6: return pack_contiguous_list_data_kernel; + case 7: return pack_contiguous_list_data_kernel; + case 8: return pack_contiguous_list_data_kernel; + default: RAFT_FAIL("Invalid pq_bits (%u), the value must be within [4, 8]", pq_bits); + } + }(); + kernel<<>>(list_data, codes, n_rows, pq_dim, offset_or_indices); + RAFT_CUDA_TRY(cudaPeekAtLastError()); +} + +template +void pack_contiguous_list_data(raft::resources const& res, + index* index, + const uint8_t* new_codes, + uint32_t n_rows, + uint32_t label, + std::variant offset_or_indices) +{ + pack_contiguous_list_data(index->lists()[label]->data.view(), + new_codes, + n_rows, + index->pq_dim(), + offset_or_indices, + index->pq_bits(), + resource::get_cuda_stream(res)); +} + /** * * A producer for the `write_list` and `write_vector` that encodes level-1 input vector residuals @@ -1634,35 +1876,6 @@ auto build(raft::resources const& handle, labels_view, utils::mapping()); - { - // combine cluster_centers and their norms - RAFT_CUDA_TRY(cudaMemcpy2DAsync(index.centers().data_handle(), - sizeof(float) * index.dim_ext(), - cluster_centers, - sizeof(float) * index.dim(), - sizeof(float) * index.dim(), - index.n_lists(), - cudaMemcpyDefault, - stream)); - - rmm::device_uvector center_norms(index.n_lists(), stream, device_memory); - raft::linalg::rowNorm(center_norms.data(), - cluster_centers, - index.dim(), - index.n_lists(), - raft::linalg::L2Norm, - true, - stream); - RAFT_CUDA_TRY(cudaMemcpy2DAsync(index.centers().data_handle() + index.dim(), - sizeof(float) * index.dim_ext(), - center_norms.data(), - sizeof(float), - sizeof(float), - index.n_lists(), - cudaMemcpyDefault, - stream)); - } - // Make rotation matrix make_rotation_matrix(handle, params.force_random_rotation, @@ -1670,24 +1883,7 @@ auto build(raft::resources const& handle, index.dim(), index.rotation_matrix().data_handle()); - // Rotate cluster_centers - float alpha = 1.0; - float beta = 0.0; - linalg::gemm(handle, - true, - false, 
- index.rot_dim(), - index.n_lists(), - index.dim(), - &alpha, - index.rotation_matrix().data_handle(), - index.dim(), - cluster_centers, - index.dim(), - &beta, - index.centers_rot().data_handle(), - index.rot_dim(), - stream); + set_centers(handle, &index, cluster_centers); // Train PQ codebooks switch (index.codebook_kind()) { diff --git a/cpp/include/raft/neighbors/detail/knn_brute_force.cuh b/cpp/include/raft/neighbors/detail/knn_brute_force.cuh index 5da4e77874..27ef00e385 100644 --- a/cpp/include/raft/neighbors/detail/knn_brute_force.cuh +++ b/cpp/include/raft/neighbors/detail/knn_brute_force.cuh @@ -36,6 +36,7 @@ #include #include #include +#include #include #include #include @@ -65,11 +66,12 @@ void tiled_brute_force_knn(const raft::resources& handle, ElementType* distances, // size (m, k) IndexType* indices, // size (m, k) raft::distance::DistanceType metric, - float metric_arg = 2.0, - size_t max_row_tile_size = 0, - size_t max_col_tile_size = 0, - DistanceEpilogue distance_epilogue = raft::identity_op(), - const ElementType* precomputed_index_norms = nullptr) + float metric_arg = 2.0, + size_t max_row_tile_size = 0, + size_t max_col_tile_size = 0, + DistanceEpilogue distance_epilogue = raft::identity_op(), + const ElementType* precomputed_index_norms = nullptr, + const ElementType* precomputed_search_norms = nullptr) { // Figure out the number of rows/cols to tile for size_t tile_rows = 0; @@ -98,18 +100,20 @@ void tiled_brute_force_knn(const raft::resources& handle, if (metric == raft::distance::DistanceType::L2Expanded || metric == raft::distance::DistanceType::L2SqrtExpanded || metric == raft::distance::DistanceType::CosineExpanded) { - search_norms.resize(m, stream); + if (!precomputed_search_norms) { search_norms.resize(m, stream); } if (!precomputed_index_norms) { index_norms.resize(n, stream); } // cosine needs the l2norm, where as l2 distances needs the squared norm if (metric == raft::distance::DistanceType::CosineExpanded) { - raft::linalg::rowNorm(search_norms.data(), - search, - d, - m, - raft::linalg::NormType::L2Norm, - true, - stream, - raft::sqrt_op{}); + if (!precomputed_search_norms) { + raft::linalg::rowNorm(search_norms.data(), + search, + d, + m, + raft::linalg::NormType::L2Norm, + true, + stream, + raft::sqrt_op{}); + } if (!precomputed_index_norms) { raft::linalg::rowNorm(index_norms.data(), index, @@ -121,9 +125,10 @@ void tiled_brute_force_knn(const raft::resources& handle, raft::sqrt_op{}); } } else { - raft::linalg::rowNorm( - search_norms.data(), search, d, m, raft::linalg::NormType::L2Norm, true, stream); - + if (!precomputed_search_norms) { + raft::linalg::rowNorm( + search_norms.data(), search, d, m, raft::linalg::NormType::L2Norm, true, stream); + } if (!precomputed_index_norms) { raft::linalg::rowNorm( index_norms.data(), index, d, n, raft::linalg::NormType::L2Norm, true, stream); @@ -184,7 +189,7 @@ void tiled_brute_force_knn(const raft::resources& handle, metric_arg); if (metric == raft::distance::DistanceType::L2Expanded || metric == raft::distance::DistanceType::L2SqrtExpanded) { - auto row_norms = search_norms.data(); + auto row_norms = precomputed_search_norms ? precomputed_search_norms : search_norms.data(); auto col_norms = precomputed_index_norms ? 
precomputed_index_norms : index_norms.data(); auto dist = temp_distances.data(); bool sqrt = metric == raft::distance::DistanceType::L2SqrtExpanded; @@ -201,7 +206,7 @@ void tiled_brute_force_knn(const raft::resources& handle, return distance_epilogue(val, row, col); }); } else if (metric == raft::distance::DistanceType::CosineExpanded) { - auto row_norms = search_norms.data(); + auto row_norms = precomputed_search_norms ? precomputed_search_norms : search_norms.data(); auto col_norms = precomputed_index_norms ? precomputed_index_norms : index_norms.data(); auto dist = temp_distances.data(); @@ -333,7 +338,8 @@ void brute_force_knn_impl( raft::distance::DistanceType metric = raft::distance::DistanceType::L2Expanded, float metricArg = 0, DistanceEpilogue distance_epilogue = raft::identity_op(), - std::vector* input_norms = nullptr) + std::vector* input_norms = nullptr, + const value_t* search_norms = nullptr) { auto userStream = resource::get_cuda_stream(handle); @@ -376,7 +382,7 @@ void brute_force_knn_impl( } // currently we don't support col_major inside tiled_brute_force_knn, because - // of limitattions of the pairwise_distance API: + // of limitations of the pairwise_distance API: // 1) paiwise_distance takes a single 'isRowMajor' parameter - and we have // multiple options here (like rowMajorQuery/rowMajorIndex) // 2) because of tiling, we need to be able to set a custom stride in the PW @@ -428,7 +434,8 @@ void brute_force_knn_impl( rowMajorQuery, stream, metric, - input_norms ? (*input_norms)[i] : nullptr); + input_norms ? (*input_norms)[i] : nullptr, + search_norms); // Perform necessary post-processing if (metric == raft::distance::DistanceType::L2SqrtExpanded || @@ -478,7 +485,8 @@ void brute_force_knn_impl( 0, 0, distance_epilogue, - input_norms ? (*input_norms)[i] : nullptr); + input_norms ? (*input_norms)[i] : nullptr, + search_norms); break; } } @@ -500,4 +508,43 @@ void brute_force_knn_impl( if (translations == nullptr) delete id_ranges; }; +template +void brute_force_search( + raft::resources const& res, + const raft::neighbors::brute_force::index& idx, + raft::device_matrix_view queries, + raft::device_matrix_view neighbors, + raft::device_matrix_view distances, + std::optional> query_norms = std::nullopt) +{ + RAFT_EXPECTS(neighbors.extent(1) == distances.extent(1), "Value of k must match for outputs"); + RAFT_EXPECTS(idx.dataset().extent(1) == queries.extent(1), + "Number of columns in queries must match brute force index"); + + auto k = neighbors.extent(1); + auto d = idx.dataset().extent(1); + + std::vector dataset = {const_cast(idx.dataset().data_handle())}; + std::vector sizes = {idx.dataset().extent(0)}; + std::vector norms; + if (idx.has_norms()) { norms.push_back(const_cast(idx.norms().data_handle())); } + + brute_force_knn_impl(res, + dataset, + sizes, + d, + const_cast(queries.data_handle()), + queries.extent(0), + neighbors.data_handle(), + distances.data_handle(), + k, + true, + true, + nullptr, + idx.metric(), + idx.metric_arg(), + raft::identity_op(), + norms.size() ? &norms : nullptr, + query_norms ? query_norms->data_handle() : nullptr); +} } // namespace raft::neighbors::detail diff --git a/cpp/include/raft/neighbors/detail/knn_brute_force_batch_k_query.cuh b/cpp/include/raft/neighbors/detail/knn_brute_force_batch_k_query.cuh new file mode 100644 index 0000000000..384eacae79 --- /dev/null +++ b/cpp/include/raft/neighbors/detail/knn_brute_force_batch_k_query.cuh @@ -0,0 +1,98 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include +#include +#include + +namespace raft::neighbors::brute_force::detail { +template +class gpu_batch_k_query : public batch_k_query { + public: + gpu_batch_k_query(const raft::resources& res, + const raft::neighbors::brute_force::index& index, + raft::device_matrix_view query, + int64_t batch_size) + : batch_k_query(res, index.size(), query.extent(0), batch_size), + index(index), + query(query) + { + auto metric = index.metric(); + + // precompute query norms, and re-use across batches + if (metric == raft::distance::DistanceType::L2Expanded || + metric == raft::distance::DistanceType::L2SqrtExpanded || + metric == raft::distance::DistanceType::CosineExpanded) { + query_norms = make_device_vector(res, query.extent(0)); + + if (metric == raft::distance::DistanceType::CosineExpanded) { + raft::linalg::norm(res, + query, + query_norms->view(), + raft::linalg::NormType::L2Norm, + raft::linalg::Apply::ALONG_ROWS, + raft::sqrt_op{}); + } else { + raft::linalg::norm(res, + query, + query_norms->view(), + raft::linalg::NormType::L2Norm, + raft::linalg::Apply::ALONG_ROWS); + } + } + } + + protected: + void load_batch(int64_t offset, int64_t next_batch_size, batch* output) const override + { + if (offset >= index.size()) { return; } + + // we're aiming to load multiple batches here - since we don't know the max iteration + // grow the size we're loading exponentially + int64_t batch_size = std::min(std::max(offset * 2, next_batch_size * 2), this->index_size); + output->resize(this->res, this->query_size, batch_size); + + std::optional> query_norms_view; + if (query_norms) { query_norms_view = query_norms->view(); } + + raft::neighbors::detail::brute_force_search( + this->res, index, query, output->indices(), output->distances(), query_norms_view); + }; + + void slice_batch(const batch& input, + int64_t offset, + int64_t batch_size, + batch* output) const override + { + auto num_queries = input.indices().extent(0); + batch_size = std::min(batch_size, index.size() - offset); + + output->resize(this->res, num_queries, batch_size); + + if (!num_queries || !batch_size) { return; } + + matrix::slice_coordinates coords{0, offset, num_queries, offset + batch_size}; + matrix::slice(this->res, input.indices(), output->indices(), coords); + matrix::slice(this->res, input.distances(), output->distances(), coords); + } + + const raft::neighbors::brute_force::index& index; + raft::device_matrix_view query; + std::optional> query_norms; +}; +} // namespace raft::neighbors::brute_force::detail diff --git a/cpp/include/raft/neighbors/ivf_flat_codepacker.hpp b/cpp/include/raft/neighbors/ivf_flat_codepacker.hpp index 4594332fdf..5379788ab4 100644 --- a/cpp/include/raft/neighbors/ivf_flat_codepacker.hpp +++ b/cpp/include/raft/neighbors/ivf_flat_codepacker.hpp @@ -19,36 +19,11 @@ #include #include #include +#include #include -#ifdef _RAFT_HAS_CUDA -#include -#else -#include -#endif - namespace 
raft::neighbors::ivf_flat::codepacker { -template -_RAFT_HOST_DEVICE inline auto roundDown(T x) -{ -#if defined(_RAFT_HAS_CUDA) - return Pow2::roundDown(x); -#else - return raft::round_down_safe(x, kIndexGroupSize); -#endif -} - -template -_RAFT_HOST_DEVICE inline auto mod(T x) -{ -#if defined(_RAFT_HAS_CUDA) - return Pow2::mod(x); -#else - return x % kIndexGroupSize; -#endif -} - /** * Write one flat code into a block by the given offset. The offset indicates the id of the record * in the list. This function interleaves the code and is intended to later copy the interleaved @@ -68,12 +43,12 @@ _RAFT_HOST_DEVICE void pack_1( const T* flat_code, T* block, uint32_t dim, uint32_t veclen, uint32_t offset) { // The data is written in interleaved groups of `index::kGroupSize` vectors - // using interleaved_group = Pow2; + using interleaved_group = neighbors::detail::div_utils; // Interleave dimensions of the source vector while recording it. // NB: such `veclen` is selected, that `dim % veclen == 0` - auto group_offset = roundDown(offset); - auto ingroup_id = mod(offset) * veclen; + auto group_offset = interleaved_group::roundDown(offset); + auto ingroup_id = interleaved_group::mod(offset) * veclen; for (uint32_t l = 0; l < dim; l += veclen) { for (uint32_t j = 0; j < veclen; j++) { @@ -100,11 +75,11 @@ _RAFT_HOST_DEVICE void unpack_1( const T* block, T* flat_code, uint32_t dim, uint32_t veclen, uint32_t offset) { // The data is written in interleaved groups of `index::kGroupSize` vectors - // using interleaved_group = Pow2; + using interleaved_group = neighbors::detail::div_utils; // NB: such `veclen` is selected, that `dim % veclen == 0` - auto group_offset = roundDown(offset); - auto ingroup_id = mod(offset) * veclen; + auto group_offset = interleaved_group::roundDown(offset); + auto ingroup_id = interleaved_group::mod(offset) * veclen; for (uint32_t l = 0; l < dim; l += veclen) { for (uint32_t j = 0; j < veclen; j++) { diff --git a/cpp/include/raft/neighbors/ivf_flat_helpers.cuh b/cpp/include/raft/neighbors/ivf_flat_helpers.cuh index 096e8051c3..7a05c9991c 100644 --- a/cpp/include/raft/neighbors/ivf_flat_helpers.cuh +++ b/cpp/include/raft/neighbors/ivf_flat_helpers.cuh @@ -22,7 +22,10 @@ #include #include +#include + namespace raft::neighbors::ivf_flat::helpers { +using namespace raft::spatial::knn::detail; // NOLINT /** * @defgroup ivf_flat_helpers Helper functions for manipulationg IVF Flat Index * @{ @@ -106,5 +109,37 @@ void unpack( res, list_data, veclen, offset, codes); } } // namespace codepacker + +/** + * @brief Public helper API to reset the data and indices ptrs, and the list sizes. Useful for + * externally modifying the index without going through the build stage. The data and indices of the + * IVF lists will be lost. 
+ * + * Usage example: + * @code{.cpp} + * raft::resources res; + * using namespace raft::neighbors; + * // use default index parameters + * ivf_flat::index_params index_params; + * // initialize an empty index + * ivf_flat::index index(res, index_params, D); + * // reset the index's state and list sizes + * ivf_flat::helpers::reset_index(res, &index); + * @endcode + * + * @tparam IdxT + * + * @param[in] res raft resource + * @param[inout] index pointer to IVF-PQ index + */ +template +void reset_index(const raft::resources& res, index* index) +{ + auto stream = resource::get_cuda_stream(res); + + utils::memzero(index->list_sizes().data_handle(), index->list_sizes().size(), stream); + utils::memzero(index->data_ptrs().data_handle(), index->data_ptrs().size(), stream); + utils::memzero(index->inds_ptrs().data_handle(), index->inds_ptrs().size(), stream); +} /** @} */ } // namespace raft::neighbors::ivf_flat::helpers diff --git a/cpp/include/raft/neighbors/ivf_pq_helpers.cuh b/cpp/include/raft/neighbors/ivf_pq_helpers.cuh index f00107f629..fec31f1c61 100644 --- a/cpp/include/raft/neighbors/ivf_pq_helpers.cuh +++ b/cpp/include/raft/neighbors/ivf_pq_helpers.cuh @@ -16,6 +16,7 @@ #pragma once +#include #include #include #include @@ -23,7 +24,10 @@ #include #include +#include + namespace raft::neighbors::ivf_pq::helpers { +using namespace raft::spatial::knn::detail; // NOLINT /** * @defgroup ivf_pq_helpers Helper functions for manipulationg IVF PQ Index * @{ @@ -71,6 +75,53 @@ inline void unpack( codes, list_data, offset, pq_bits, resource::get_cuda_stream(res)); } +/** + * @brief Unpack `n_rows` consecutive records of a single list (cluster) in the compressed index + * starting at given `offset`. The output codes of a single vector are contiguous, not expanded to + * one code per byte, which means the output has ceildiv(pq_dim * pq_bits, 8) bytes per PQ encoded + * vector. + * + * Usage example: + * @code{.cpp} + * raft::resources res; + * auto list_data = index.lists()[label]->data.view(); + * // allocate the buffer for the output + * uint32_t n_rows = 4; + * auto codes = raft::make_device_matrix( + * res, n_rows, raft::ceildiv(index.pq_dim() * index.pq_bits(), 8)); + * uint32_t offset = 0; + * // unpack n_rows elements from the list + * ivf_pq::helpers::codepacker::unpack_contiguous( + * res, list_data, index.pq_bits(), offset, n_rows, index.pq_dim(), codes.data_handle()); + * @endcode + * + * @tparam IdxT type of the indices in the source dataset + * + * @param[in] res raft resource + * @param[in] list_data block to read from + * @param[in] pq_bits bit length of encoded vector elements + * @param[in] offset + * How many records in the list to skip. + * @param[in] n_rows How many records to unpack + * @param[in] pq_dim The dimensionality of the PQ compressed records + * @param[out] codes + * the destination buffer [n_rows, ceildiv(pq_dim * pq_bits, 8)]. + * The length `n_rows` defines how many records to unpack, + * it must be smaller than the list size. + */ +inline void unpack_contiguous( + raft::resources const& res, + device_mdspan::list_extents, row_major> list_data, + uint32_t pq_bits, + uint32_t offset, + uint32_t n_rows, + uint32_t pq_dim, + uint8_t* codes) +{ + ivf_pq::detail::unpack_contiguous_list_data( + codes, list_data, n_rows, pq_dim, offset, pq_bits, resource::get_cuda_stream(res)); +} + /** * Write flat PQ codes into an existing list by the given offset. 
* @@ -87,7 +138,7 @@ inline void unpack( * res, make_const_mdspan(codes.view()), index.pq_bits(), 42, list_data); * @endcode * - * @param[in] res + * @param[in] res raft resource * @param[in] codes flat PQ codes, one code per byte [n_vec, pq_dim] * @param[in] pq_bits bit length of encoded vector elements * @param[in] offset how many records to skip before writing the data into the list @@ -102,6 +153,47 @@ inline void pack( { ivf_pq::detail::pack_list_data(list_data, codes, offset, pq_bits, resource::get_cuda_stream(res)); } + +/** + * Write flat PQ codes into an existing list by the given offset. The input codes of a single vector + * are contiguous (not expanded to one code per byte). + * + * NB: no memory allocation happens here; the list must fit the data (offset + n_rows records). + * + * Usage example: + * @code{.cpp} + * raft::resources res; + * auto list_data = index.lists()[label]->data.view(); + * // allocate the buffer for the input codes + * auto codes = raft::make_device_matrix( + * res, n_rows, raft::ceildiv(index.pq_dim() * index.pq_bits(), 8)); + * ... prepare compressed vectors to pack into the list in codes ... + * // write codes into the list starting from the 42nd position. If the current size of the list + * // is greater than 42, this will overwrite the codes starting at this offset. + * ivf_pq::helpers::codepacker::pack_contiguous( + * res, codes.data_handle(), n_rows, index.pq_dim(), index.pq_bits(), 42, list_data); + * @endcode + * + * @param[in] res raft resource + * @param[in] codes flat PQ codes, [n_vec, ceildiv(pq_dim * pq_bits, 8)] + * @param[in] n_rows number of records + * @param[in] pq_dim + * @param[in] pq_bits bit length of encoded vector elements + * @param[in] offset how many records to skip before writing the data into the list + * @param[in] list_data block to write into + */ +inline void pack_contiguous( + raft::resources const& res, + const uint8_t* codes, + uint32_t n_rows, + uint32_t pq_dim, + uint32_t pq_bits, + uint32_t offset, + device_mdspan::list_extents, row_major> list_data) +{ + ivf_pq::detail::pack_contiguous_list_data( + list_data, codes, n_rows, pq_dim, offset, pq_bits, resource::get_cuda_stream(res)); +} } // namespace codepacker /** @@ -122,7 +214,7 @@ inline void pack( * ivf_pq::helpers::pack_list_data(res, &index, codes_to_pack, label, 42); * @endcode * - * @param[in] res + * @param[in] res raft resource * @param[inout] index IVF-PQ index. * @param[in] codes flat PQ codes, one code per byte [n_rows, pq_dim] * @param[in] label The id of the list (cluster) into which we write. @@ -138,6 +230,56 @@ void pack_list_data(raft::resources const& res, ivf_pq::detail::pack_list_data(res, index, codes, label, offset); } +/** + * Write flat PQ codes into an existing list by the given offset. Use this when the input + * vectors are PQ encoded and not expanded to one code per byte. + * + * The list is identified by its label. + * + * NB: no memory allocation happens here; the list into which the vectors are packed must fit offset + * + n_rows rows. + * + * Usage example: + * @code{.cpp} + * using namespace raft::neighbors; + * raft::resources res; + * // use default index parameters + * ivf_pq::index_params index_params; + * // create and fill the index from a [N, D] dataset + * auto index = ivf_pq::build(res, index_params, dataset, N, D); + * // allocate the buffer for n_rows input codes. Each vector occupies + * // raft::ceildiv(index.pq_dim() * index.pq_bits(), 8) bytes because + * // codes are compressed and without gaps. 
+ * auto codes = raft::make_device_matrix( + * res, n_rows, raft::ceildiv(index.pq_dim() * index.pq_bits(), 8)); + * ... prepare the compressed vectors to pack into the list in codes ... + * // the first n_rows codes in the fourth IVF list are to be overwritten. + * uint32_t label = 3; + * // write codes into the list starting from the 0th position + * ivf_pq::helpers::pack_contiguous_list_data( + * res, &index, codes.data_handle(), n_rows, label, 0); + * @endcode + * + * @tparam IdxT + * + * @param[in] res raft resource + * @param[inout] index pointer to IVF-PQ index + * @param[in] codes flat contiguous PQ codes [n_rows, ceildiv(pq_dim * pq_bits, 8)] + * @param[in] n_rows how many records to pack + * @param[in] label The id of the list (cluster) into which we write. + * @param[in] offset how many records to skip before writing the data into the list + */ +template +void pack_contiguous_list_data(raft::resources const& res, + index* index, + uint8_t* codes, + uint32_t n_rows, + uint32_t label, + uint32_t offset) +{ + ivf_pq::detail::pack_contiguous_list_data(res, index, codes, n_rows, label, offset); +} + /** * @brief Unpack `n_take` consecutive records of a single list (cluster) in the compressed index * starting at given `offset`, one code per byte (independently of pq_bits). @@ -200,8 +342,8 @@ void unpack_list_data(raft::resources const& res, * * @tparam IdxT type of the indices in the source dataset * - * @param[in] res - * @param[in] index + * @param[in] res raft resource + * @param[in] index IVF-PQ index (passed by reference) * @param[in] in_cluster_indices * The offsets of the selected indices within the cluster. * @param[out] out_codes @@ -221,6 +363,53 @@ void unpack_list_data(raft::resources const& res, return ivf_pq::detail::unpack_list_data(res, index, out_codes, label, in_cluster_indices); } +/** + * @brief Unpack `n_rows` consecutive PQ encoded vectors of a single list (cluster) in the + * compressed index starting at given `offset`, not expanded to one code per byte. Each code in the + * output buffer occupies ceildiv(index.pq_dim() * index.pq_bits(), 8) bytes. + * + * Usage example: + * @code{.cpp} + * raft::resources res; + * // We will unpack the whole fourth cluster + * uint32_t label = 3; + * // Get the list size + * uint32_t list_size = 0; + * raft::update_host(&list_size, index.list_sizes().data_handle() + label, 1, + * raft::resource::get_cuda_stream(res)); raft::resource::sync_stream(res); + * // allocate the buffer for the output + * auto codes = raft::make_device_matrix(res, list_size, raft::ceildiv(index.pq_dim() * + * index.pq_bits(), 8)); + * // unpack the whole list + * ivf_pq::helpers::unpack_list_data(res, index, codes.data_handle(), list_size, label, 0); + * @endcode + * + * @tparam IdxT type of the indices in the source dataset + * + * @param[in] res raft resource + * @param[in] index IVF-PQ index (passed by reference) + * @param[out] out_codes + * the destination buffer [n_rows, ceildiv(index.pq_dim() * index.pq_bits(), 8)]. + * The length `n_rows` defines how many records to unpack, + * offset + n_rows must be smaller than or equal to the list size. + * @param[in] n_rows how many codes to unpack + * @param[in] label + * The id of the list (cluster) to decode. + * @param[in] offset + * How many records in the list to skip. 
+ */ +template +void unpack_contiguous_list_data(raft::resources const& res, + const index& index, + uint8_t* out_codes, + uint32_t n_rows, + uint32_t label, + uint32_t offset) +{ + return ivf_pq::detail::unpack_contiguous_list_data( + res, index, out_codes, n_rows, label, offset); +} + /** * @brief Decode `n_take` consecutive records of a single list (cluster) in the compressed index * starting at given `offset`. @@ -232,7 +421,7 @@ void unpack_list_data(raft::resources const& res, * // Get the list size * uint32_t list_size = 0; * raft::copy(&list_size, index.list_sizes().data_handle() + label, 1, - * resource::get_cuda_stream(res)); resource::sync_stream(res); + * resource::get_cuda_stream(res)); resource::sync_stream(res); * // allocate the buffer for the output * auto decoded_vectors = raft::make_device_matrix(res, list_size, index.dim()); * // decode the whole list @@ -397,6 +586,7 @@ void extend_list(raft::resources const& res, * @endcode * * @tparam IdxT + * * @param[in] res * @param[inout] index * @param[in] label the id of the target list (cluster). @@ -407,5 +597,197 @@ void erase_list(raft::resources const& res, index* index, uint32_t label) ivf_pq::detail::erase_list(res, index, label); } +/** + * @brief Public helper API to reset the data and indices ptrs, and the list sizes. Useful for + * externally modifying the index without going through the build stage. The data and indices of the + * IVF lists will be lost. + * + * Usage example: + * @code{.cpp} + * raft::resources res; + * using namespace raft::neighbors; + * // use default index parameters + * ivf_pq::index_params index_params; + * // initialize an empty index + * ivf_pq::index index(res, index_params, D); + * // reset the index's state and list sizes + * ivf_pq::helpers::reset_index(res, &index); + * @endcode + * + * @tparam IdxT + * + * @param[in] res raft resource + * @param[inout] index pointer to IVF-PQ index + */ +template +void reset_index(const raft::resources& res, index* index) +{ + auto stream = resource::get_cuda_stream(res); + + utils::memzero( + index->accum_sorted_sizes().data_handle(), index->accum_sorted_sizes().size(), stream); + utils::memzero(index->list_sizes().data_handle(), index->list_sizes().size(), stream); + utils::memzero(index->data_ptrs().data_handle(), index->data_ptrs().size(), stream); + utils::memzero(index->inds_ptrs().data_handle(), index->inds_ptrs().size(), stream); +} + +/** + * @brief Public helper API exposing the computation of the index's rotation matrix. + * NB: This is to be used only when the rotation matrix is not already computed through + * raft::neighbors::ivf_pq::build. + * + * Usage example: + * @code{.cpp} + * raft::resources res; + * // use default index parameters + * ivf_pq::index_params index_params; + * // force random rotation + * index_params.force_random_rotation = true; + * // initialize an empty index + * raft::neighbors::ivf_pq::index index(res, index_params, D); + * // reset the index + * reset_index(res, &index); + * // compute the rotation matrix with random_rotation + * raft::neighbors::ivf_pq::helpers::make_rotation_matrix( + * res, &index, index_params.force_random_rotation); + * @endcode + * + * @tparam IdxT + * + * @param[in] res raft resource + * @param[inout] index pointer to IVF-PQ index + * @param[in] force_random_rotation whether to apply a random rotation matrix on the input data. See + * raft::neighbors::ivf_pq::index_params for more details. 
+ */ +template +void make_rotation_matrix(raft::resources const& res, + index* index, + bool force_random_rotation) +{ + raft::neighbors::ivf_pq::detail::make_rotation_matrix(res, + force_random_rotation, + index->rot_dim(), + index->dim(), + index->rotation_matrix().data_handle()); +} + +/** + * @brief Public helper API for externally modifying the index's IVF centroids. + * NB: The index must be reset before this. Use raft::neighbors::ivf_pq::extend to construct IVF + lists according to new centroids. + * + * Usage example: + * @code{.cpp} + * raft::resources res; + * // allocate the buffer for the input centers + * auto cluster_centers = raft::make_device_matrix(res, index.n_lists(), + index.dim()); + * ... prepare ivf centroids in cluster_centers ... + * // reset the index + * reset_index(res, &index); + * // recompute the state of the index + * raft::neighbors::ivf_pq::helpers::recompute_internal_state(res, index); + * // Write the IVF centroids + * raft::neighbors::ivf_pq::helpers::set_centers( + res, + &index, + cluster_centers); + * @endcode + * + * @tparam IdxT + * + * @param[in] res raft resource + * @param[inout] index pointer to IVF-PQ index + * @param[in] cluster_centers new cluster centers [index.n_lists(), index.dim()] + */ +template +void set_centers(raft::resources const& res, + index* index, + device_matrix_view cluster_centers) +{ + RAFT_EXPECTS(cluster_centers.extent(0) == index->n_lists(), + "Number of rows in the new centers must be equal to the number of IVF lists"); + RAFT_EXPECTS(cluster_centers.extent(1) == index->dim(), + "Number of columns in the new cluster centers and index dim are different"); + RAFT_EXPECTS(index->size() == 0, "Index must be empty"); + ivf_pq::detail::set_centers(res, index, cluster_centers.data_handle()); +} + +/** + * @brief Helper exposing the re-computation of list sizes and related arrays if IVF lists have been + * modified. + * + * Usage example: + * @code{.cpp} + * using namespace raft::neighbors; + * raft::resources res; + * // use default index parameters + * ivf_pq::index_params index_params; + * // initialize an empty index + * ivf_pq::index index(res, index_params, D); + * ivf_pq::helpers::reset_index(res, &index); + * // resize the first IVF list to hold 5 records + * auto spec = list_spec{ + * index->pq_bits(), index->pq_dim(), index->conservative_memory_allocation()}; + * uint32_t new_size = 5; + * ivf::resize_list(res, list, spec, new_size, 0); + * raft::update_device(index.list_sizes(), &new_size, 1, stream); + * // recompute the internal state of the index + * ivf_pq::recompute_internal_state(res, &index); + * @endcode + * + * @tparam IdxT + * + * @param[in] res raft resource + * @param[inout] index pointer to IVF-PQ index + */ +template +void recompute_internal_state(const raft::resources& res, index* index) +{ + auto& list = index->lists()[0]; + ivf_pq::detail::recompute_internal_state(res, *index); +} + +/** + * @brief Public helper API for fetching a trained index's IVF centroids into a buffer that may be + * allocated on either host or device. 
+ * + * Usage example: + * @code{.cpp} + * raft::resources res; + * // allocate the buffer for the output centers + * auto cluster_centers = raft::make_device_matrix( + * res, index.n_lists(), index.dim()); + * // Extract the IVF centroids into the buffer + * raft::neighbors::ivf_pq::helpers::extract_centers(res, index, cluster_centers.data_handle()); + * @endcode + * + * @tparam IdxT + * + * @param[in] res raft resource + * @param[in] index IVF-PQ index (passed by reference) + * @param[out] cluster_centers IVF cluster centers [index.n_lists(), index.dim] + */ +template +void extract_centers(raft::resources const& res, + const index& index, + raft::device_matrix_view cluster_centers) +{ + RAFT_EXPECTS(cluster_centers.extent(0) == index.n_lists(), + "Number of rows in the output buffer for cluster centers must be equal to the " + "number of IVF lists"); + RAFT_EXPECTS( + cluster_centers.extent(1) == index.dim(), + "Number of columns in the output buffer for cluster centers and index dim are different"); + auto stream = resource::get_cuda_stream(res); + RAFT_CUDA_TRY(cudaMemcpy2DAsync(cluster_centers.data_handle(), + sizeof(float) * index.dim(), + index.centers().data_handle(), + sizeof(float) * index.dim_ext(), + sizeof(float) * index.dim(), + index.n_lists(), + cudaMemcpyDefault, + stream)); +} /** @} */ } // namespace raft::neighbors::ivf_pq::helpers diff --git a/cpp/include/raft/neighbors/ivf_pq_types.hpp b/cpp/include/raft/neighbors/ivf_pq_types.hpp index 24df77b35a..45ab18c84f 100644 --- a/cpp/include/raft/neighbors/ivf_pq_types.hpp +++ b/cpp/include/raft/neighbors/ivf_pq_types.hpp @@ -487,6 +487,30 @@ struct index : ann::index { return centers_rot_.view(); } + /** fetch size of a particular IVF list in bytes using the list extents. + * Usage example: + * @code{.cpp} + * raft::resources res; + * // use default index params + * ivf_pq::index_params index_params; + * // extend the IVF lists while building the index + * index_params.add_data_on_build = true; + * // create and fill the index from a [N, D] dataset + * auto index = raft::neighbors::ivf_pq::build(res, index_params, dataset, N, D); + * // Fetch the size of the fourth list + * uint32_t size = index.get_list_size_in_bytes(3); + * @endcode + * + * @param[in] label list ID + */ + inline auto get_list_size_in_bytes(uint32_t label) -> uint32_t + { + RAFT_EXPECTS(label < this->n_lists(), + "Expected label to be less than number of lists in the index"); + auto list_data = this->lists()[label]->data; + return list_data.size(); + } + private: raft::distance::DistanceType metric_; codebook_gen codebook_kind_; diff --git a/cpp/include/raft/neighbors/neighbors_types.hpp b/cpp/include/raft/neighbors/neighbors_types.hpp new file mode 100644 index 0000000000..d503779741 --- /dev/null +++ b/cpp/include/raft/neighbors/neighbors_types.hpp @@ -0,0 +1,63 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include +#include +#include + +namespace raft::neighbors { + +/** A single batch of nearest neighbors in device memory */ +template +class batch { + public: + /** Create a new empty batch of data */ + batch(raft::resources const& res, int64_t rows, int64_t cols) + : indices_(make_device_matrix(res, rows, cols)), + distances_(make_device_matrix(res, rows, cols)) + { + } + + void resize(raft::resources const& res, int64_t rows, int64_t cols) + { + indices_ = make_device_matrix(res, rows, cols); + distances_ = make_device_matrix(res, rows, cols); + } + + /** Returns the indices for the batch */ + device_matrix_view indices() const + { + return raft::make_const_mdspan(indices_.view()); + } + device_matrix_view indices() { return indices_.view(); } + + /** Returns the distances for the batch */ + device_matrix_view distances() const + { + return raft::make_const_mdspan(distances_.view()); + } + device_matrix_view distances() { return distances_.view(); } + + /** Returns the size of the batch */ + int64_t batch_size() const { return indices().extent(1); } + + protected: + raft::device_matrix indices_; + raft::device_matrix distances_; +}; +} // namespace raft::neighbors diff --git a/cpp/src/neighbors/brute_force_knn_index_float.cu b/cpp/src/neighbors/brute_force_knn_index_float.cu index f2fda93a97..d4f902c087 100644 --- a/cpp/src/neighbors/brute_force_knn_index_float.cu +++ b/cpp/src/neighbors/brute_force_knn_index_float.cu @@ -36,4 +36,4 @@ template raft::neighbors::brute_force::index raft::neighbors::brute_force raft::resources const& res, raft::device_matrix_view dataset, raft::distance::DistanceType metric, - float metric_arg); \ No newline at end of file + float metric_arg); diff --git a/cpp/test/neighbors/ann_ivf_pq.cuh b/cpp/test/neighbors/ann_ivf_pq.cuh index bdb83ecfdc..eb30b60eca 100644 --- a/cpp/test/neighbors/ann_ivf_pq.cuh +++ b/cpp/test/neighbors/ann_ivf_pq.cuh @@ -380,6 +380,83 @@ class ivf_pq_test : public ::testing::TestWithParam { list_data_size, Compare{})); } + void check_packing_contiguous(index* index, uint32_t label) + { + auto old_list = index->lists()[label]; + auto n_rows = old_list->size.load(); + + if (n_rows == 0) { return; } + + auto codes = make_device_matrix(handle_, n_rows, index->pq_dim()); + auto indices = make_device_vector(handle_, n_rows); + copy(indices.data_handle(), old_list->indices.data_handle(), n_rows, stream_); + + uint32_t code_size = ceildiv(index->pq_dim() * index->pq_bits(), 8); + + auto codes_compressed = make_device_matrix(handle_, n_rows, code_size); + + ivf_pq::helpers::unpack_contiguous_list_data( + handle_, *index, codes_compressed.data_handle(), n_rows, label, 0); + ivf_pq::helpers::erase_list(handle_, index, label); + ivf_pq::detail::extend_list_prepare(handle_, index, make_const_mdspan(indices.view()), label); + ivf_pq::helpers::pack_contiguous_list_data( + handle_, index, codes_compressed.data_handle(), n_rows, label, 0); + ivf_pq::helpers::recompute_internal_state(handle_, index); + + auto& new_list = index->lists()[label]; + ASSERT_NE(old_list.get(), new_list.get()) + << "The old list should have been shared and retained after ivf_pq index has erased the " + "corresponding cluster."; + auto list_data_size = (n_rows / ivf_pq::kIndexGroupSize) * new_list->data.extent(1) * + new_list->data.extent(2) * new_list->data.extent(3); + + ASSERT_TRUE(old_list->data.size() >= list_data_size); + ASSERT_TRUE(new_list->data.size() >= list_data_size); + ASSERT_TRUE(devArrMatch(old_list->data.data_handle(), + 
+                            new_list->data.data_handle(),
+                            list_data_size,
+                            Compare{}));
+
+    // Pack a few vectors back to the list.
+    uint32_t row_offset = 9;
+    uint32_t n_vec = 3;
+    ASSERT_TRUE(row_offset + n_vec < n_rows);
+    size_t offset = row_offset * code_size;
+    auto codes_to_pack = make_device_matrix_view(
+      codes_compressed.data_handle() + offset, n_vec, index->pq_dim());
+    ivf_pq::helpers::pack_contiguous_list_data(
+      handle_, index, codes_to_pack.data_handle(), n_vec, label, row_offset);
+    ASSERT_TRUE(devArrMatch(old_list->data.data_handle(),
+                            new_list->data.data_handle(),
+                            list_data_size,
+                            Compare{}));
+
+    // Another test with the API that takes list_data directly
+    auto list_data = index->lists()[label]->data.view();
+    uint32_t n_take = 4;
+    ASSERT_TRUE(row_offset + n_take < n_rows);
+    auto codes2 = raft::make_device_matrix(handle_, n_take, code_size);
+    ivf_pq::helpers::codepacker::unpack_contiguous(handle_,
+                                                   list_data,
+                                                   index->pq_bits(),
+                                                   row_offset,
+                                                   n_take,
+                                                   index->pq_dim(),
+                                                   codes2.data_handle());
+
+    // Write it back
+    ivf_pq::helpers::codepacker::pack_contiguous(handle_,
+                                                 codes2.data_handle(),
+                                                 n_vec,
+                                                 index->pq_dim(),
+                                                 index->pq_bits(),
+                                                 row_offset,
+                                                 list_data);
+    ASSERT_TRUE(devArrMatch(old_list->data.data_handle(),
+                            new_list->data.data_handle(),
+                            list_data_size,
+                            Compare{}));
+  }
 
   template 
   void run(BuildIndex build_index)
@@ -398,6 +475,7 @@ class ivf_pq_test : public ::testing::TestWithParam {
       case 1: {
         // Dump and re-write codes for one label
         check_packing(&index, label);
+        check_packing_contiguous(&index, label);
       } break;
       default: {
         // check a small subset of data in a randomly chosen cluster to see if the data
@@ -962,6 +1040,32 @@ inline auto special_cases() -> test_cases_t
     x.search_params.n_probes = 100;
   });
 
+  ADD_CASE({
+    x.num_db_vecs = 4335;
+    x.dim = 4;
+    x.num_queries = 100000;
+    x.k = 12;
+    x.index_params.metric = distance::DistanceType::L2Expanded;
+    x.index_params.codebook_kind = ivf_pq::codebook_gen::PER_SUBSPACE;
+    x.index_params.pq_dim = 2;
+    x.index_params.pq_bits = 8;
+    x.index_params.n_lists = 69;
+    x.search_params.n_probes = 69;
+  });
+
+  ADD_CASE({
+    x.num_db_vecs = 4335;
+    x.dim = 4;
+    x.num_queries = 100000;
+    x.k = 12;
+    x.index_params.metric = distance::DistanceType::L2Expanded;
+    x.index_params.codebook_kind = ivf_pq::codebook_gen::PER_CLUSTER;
+    x.index_params.pq_dim = 2;
+    x.index_params.pq_bits = 8;
+    x.index_params.n_lists = 69;
+    x.search_params.n_probes = 69;
+  });
+
   return xs;
 }
diff --git a/cpp/test/neighbors/tiled_knn.cu b/cpp/test/neighbors/tiled_knn.cu
index ebde8e6d35..a84c9749d7 100644
--- a/cpp/test/neighbors/tiled_knn.cu
+++ b/cpp/test/neighbors/tiled_knn.cu
@@ -38,6 +38,7 @@
 #include 
 
 namespace raft::neighbors::brute_force {
+
 struct TiledKNNInputs {
   int num_queries;
   int num_db_vecs;
@@ -190,11 +191,13 @@ class TiledKNNTest : public ::testing::TestWithParam {
                       metric,
                       metric_arg);
 
+    auto query_view = raft::make_device_matrix_view(
+      search_queries.data(), params_.num_queries, params_.dim);
+
     raft::neighbors::brute_force::search(
       handle_,
       idx,
-      raft::make_device_matrix_view(
-        search_queries.data(), params_.num_queries, params_.dim),
+      query_view,
       raft::make_device_matrix_view(
         raft_indices_.data(), params_.num_queries, params_.k),
       raft::make_device_matrix_view(
@@ -209,6 +212,73 @@ class TiledKNNTest : public ::testing::TestWithParam {
       float(0.001),
       stream_,
       true));
+    // also test out the batch api.
+    // First get new reference results (all k, up to a certain max size)
+    auto all_size = std::min(params_.num_db_vecs, 1024);
+    auto all_indices = raft::make_device_matrix(handle_, num_queries, all_size);
+    auto all_distances = raft::make_device_matrix(handle_, num_queries, all_size);
+    raft::neighbors::brute_force::search(
+      handle_, idx, query_view, all_indices.view(), all_distances.view());
+
+    int64_t offset = 0;
+    auto query = make_batch_k_query(handle_, idx, query_view, k_);
+    for (auto batch : *query) {
+      auto batch_size = batch.batch_size();
+      auto indices = raft::make_device_matrix(handle_, num_queries, batch_size);
+      auto distances = raft::make_device_matrix(handle_, num_queries, batch_size);
+
+      matrix::slice_coordinates coords{0, offset, num_queries, offset + batch_size};
+
+      matrix::slice(handle_, raft::make_const_mdspan(all_indices.view()), indices.view(), coords);
+      matrix::slice(
+        handle_, raft::make_const_mdspan(all_distances.view()), distances.view(), coords);
+
+      ASSERT_TRUE(raft::spatial::knn::devArrMatchKnnPair(indices.data_handle(),
+                                                         batch.indices().data_handle(),
+                                                         distances.data_handle(),
+                                                         batch.distances().data_handle(),
+                                                         num_queries,
+                                                         batch_size,
+                                                         float(0.001),
+                                                         stream_,
+                                                         true));
+
+      offset += batch_size;
+      if (offset + batch_size > all_size) break;
+    }
+
+    // also test out with variable batch sizes
+    offset = 0;
+    int64_t batch_size = k_;
+    query = make_batch_k_query(handle_, idx, query_view, batch_size);
+    for (auto it = query->begin(); it != query->end(); it.advance(batch_size)) {
+      // batch_size could be less than requested (in the case of the final batch); handle that here
+      ASSERT_TRUE(it->indices().extent(1) <= batch_size);
+      batch_size = it->indices().extent(1);
+
+      auto indices = raft::make_device_matrix(handle_, num_queries, batch_size);
+      auto distances = raft::make_device_matrix(handle_, num_queries, batch_size);
+
+      matrix::slice_coordinates coords{0, offset, num_queries, offset + batch_size};
+      matrix::slice(handle_, raft::make_const_mdspan(all_indices.view()), indices.view(), coords);
+      matrix::slice(
+        handle_, raft::make_const_mdspan(all_distances.view()), distances.view(), coords);
+
+      ASSERT_TRUE(raft::spatial::knn::devArrMatchKnnPair(indices.data_handle(),
+                                                         it->indices().data_handle(),
+                                                         distances.data_handle(),
+                                                         it->distances().data_handle(),
+                                                         num_queries,
+                                                         batch_size,
+                                                         float(0.001),
+                                                         stream_,
+                                                         true));
+
+      offset += batch_size;
+      if (offset + batch_size > all_size) break;
+
+      batch_size += 2;
+    }
   }
 }
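As a quick reference for the contiguous codepacker helpers exercised in check_packing_contiguous above, the following is a minimal sketch of a round-trip over one IVF list outside the test harness. It is illustrative rather than part of the change: the int64_t IdxT, the header paths, and the free-standing roundtrip_list_codes function are assumptions, while the helper names, the argument order, and the code_size arithmetic mirror the calls in the test above.

#include <cstdint>

#include <raft/core/device_mdarray.hpp>
#include <raft/core/device_resources.hpp>
#include <raft/neighbors/ivf_pq_helpers.cuh>

// Round-trip the packed PQ codes of a single IVF list: unpack the whole list
// into a flat [n_rows, code_size] buffer, then write the same bytes back.
void roundtrip_list_codes(raft::device_resources const& res,
                          raft::neighbors::ivf_pq::index<int64_t>& index,
                          uint32_t label)
{
  namespace helpers = raft::neighbors::ivf_pq::helpers;

  // Number of records currently stored in this list.
  uint32_t n_rows = index.lists()[label]->size.load();
  if (n_rows == 0) { return; }

  // Each record occupies ceil(pq_dim * pq_bits / 8) bytes in the contiguous layout.
  uint32_t code_size = (index.pq_dim() * index.pq_bits() + 7) / 8;

  // Flat buffer holding the packed codes for the whole list.
  auto codes = raft::make_device_matrix<uint8_t, uint32_t>(res, n_rows, code_size);

  // Dump the list starting at row offset 0 ...
  helpers::unpack_contiguous_list_data(res, index, codes.data_handle(), n_rows, label, 0);

  // ... and pack the same bytes back in place (pack takes the index by pointer).
  helpers::pack_contiguous_list_data(res, &index, codes.data_handle(), n_rows, label, 0);
}

When a list is rebuilt rather than overwritten in place, the test above additionally wraps the pack call with erase_list, detail::extend_list_prepare and recompute_internal_state; the sketch skips those steps because it only rewrites existing rows.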