From 88e9a555d46398be9445720e5a10cb142cabb136 Mon Sep 17 00:00:00 2001 From: Tamas Bela Feher Date: Tue, 28 Nov 2023 03:55:05 +0100 Subject: [PATCH] Enable host dataset for IVF-Flat (#1635) Enable host input data for IVF-Flat build. This is done by batch-wise processing the dataset during extend, similarly how IVF-PQ does it. Authors: - Tamas Bela Feher (https://github.com/tfeher) Approvers: - Corey J. Nolet (https://github.com/cjnolet) URL: https://github.com/rapidsai/raft/pull/1635 --- .../ann/src/raft/raft_ivf_flat_wrapper.h | 2 +- .../raft/neighbors/detail/ivf_flat_build.cuh | 106 ++++++++++++------ cpp/include/raft/neighbors/ivf_flat-ext.cuh | 52 ++++++++- cpp/include/raft/neighbors/ivf_flat-inl.cuh | 75 ++++++++++++- cpp/src/neighbors/ivf_flat_00_generate.py | 69 ++++++++---- .../neighbors/ivf_flat_build_float_int64_t.cu | 12 ++ .../ivf_flat_build_int8_t_int64_t.cu | 12 ++ .../ivf_flat_build_uint8_t_int64_t.cu | 12 ++ .../ivf_flat_extend_float_int64_t.cu | 13 +++ .../ivf_flat_extend_int8_t_int64_t.cu | 13 +++ .../ivf_flat_extend_uint8_t_int64_t.cu | 13 +++ cpp/test/neighbors/ann_ivf_flat.cuh | 91 +++++++++++---- 12 files changed, 385 insertions(+), 85 deletions(-) diff --git a/cpp/bench/ann/src/raft/raft_ivf_flat_wrapper.h b/cpp/bench/ann/src/raft/raft_ivf_flat_wrapper.h index 24b3c69bb6..13ea20d483 100644 --- a/cpp/bench/ann/src/raft/raft_ivf_flat_wrapper.h +++ b/cpp/bench/ann/src/raft/raft_ivf_flat_wrapper.h @@ -78,7 +78,7 @@ class RaftIvfFlatGpu : public ANN { AlgoProperty get_preference() const override { AlgoProperty property; - property.dataset_memory_type = MemoryType::Device; + property.dataset_memory_type = MemoryType::HostMmap; property.query_memory_type = MemoryType::Device; return property; } diff --git a/cpp/include/raft/neighbors/detail/ivf_flat_build.cuh b/cpp/include/raft/neighbors/detail/ivf_flat_build.cuh index a9a6ac025f..a35cb9e1f1 100644 --- a/cpp/include/raft/neighbors/detail/ivf_flat_build.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_flat_build.cuh @@ -120,7 +120,8 @@ RAFT_KERNEL build_index_kernel(const LabelT* labels, uint32_t* list_sizes_ptr, IdxT n_rows, uint32_t dim, - uint32_t veclen) + uint32_t veclen, + IdxT batch_offset = 0) { const IdxT i = IdxT(blockDim.x) * IdxT(blockIdx.x) + threadIdx.x; if (i >= n_rows) { return; } @@ -131,7 +132,7 @@ RAFT_KERNEL build_index_kernel(const LabelT* labels, auto* list_data = list_data_ptrs[list_id]; // Record the source vector id in the index - list_index[inlist_id] = source_ixs == nullptr ? i : source_ixs[i]; + list_index[inlist_id] = source_ixs == nullptr ? i + batch_offset : source_ixs[i]; // The data is written in interleaved groups of `index::kGroupSize` vectors using interleaved_group = Pow2; @@ -180,16 +181,33 @@ void extend(raft::resources const& handle, auto new_labels = raft::make_device_vector(handle, n_rows); raft::cluster::kmeans_balanced_params kmeans_params; - kmeans_params.metric = index->metric(); - auto new_vectors_view = raft::make_device_matrix_view(new_vectors, n_rows, dim); + kmeans_params.metric = index->metric(); auto orig_centroids_view = raft::make_device_matrix_view(index->centers().data_handle(), n_lists, dim); - raft::cluster::kmeans_balanced::predict(handle, - kmeans_params, - new_vectors_view, - orig_centroids_view, - new_labels.view(), - utils::mapping{}); + // Calculate the batch size for the input data if it's not accessible directly from the device + constexpr size_t kReasonableMaxBatchSize = 65536; + size_t max_batch_size = std::min(n_rows, kReasonableMaxBatchSize); + + // Predict the cluster labels for the new data, in batches if necessary + utils::batch_load_iterator vec_batches(new_vectors, + n_rows, + index->dim(), + max_batch_size, + stream, + resource::get_workspace_resource(handle)); + + for (const auto& batch : vec_batches) { + auto batch_data_view = + raft::make_device_matrix_view(batch.data(), batch.size(), index->dim()); + auto batch_labels_view = raft::make_device_vector_view( + new_labels.data_handle() + batch.offset(), batch.size()); + raft::cluster::kmeans_balanced::predict(handle, + kmeans_params, + batch_data_view, + orig_centroids_view, + batch_labels_view, + utils::mapping{}); + } auto* list_sizes_ptr = index->list_sizes().data_handle(); auto old_list_sizes_dev = raft::make_device_vector(handle, n_lists); @@ -202,14 +220,19 @@ void extend(raft::resources const& handle, auto list_sizes_view = raft::make_device_vector_view, IdxT>( list_sizes_ptr, n_lists); - auto const_labels_view = make_const_mdspan(new_labels.view()); - raft::cluster::kmeans_balanced::helpers::calc_centers_and_sizes(handle, - new_vectors_view, - const_labels_view, - centroids_view, - list_sizes_view, - false, - utils::mapping{}); + for (const auto& batch : vec_batches) { + auto batch_data_view = + raft::make_device_matrix_view(batch.data(), batch.size(), index->dim()); + auto batch_labels_view = raft::make_device_vector_view( + new_labels.data_handle() + batch.offset(), batch.size()); + raft::cluster::kmeans_balanced::helpers::calc_centers_and_sizes(handle, + batch_data_view, + batch_labels_view, + centroids_view, + list_sizes_view, + false, + utils::mapping{}); + } } else { raft::stats::histogram(raft::stats::HistTypeAuto, reinterpret_cast(list_sizes_ptr), @@ -244,20 +267,39 @@ void extend(raft::resources const& handle, // we'll rebuild the `list_sizes_ptr` in the following kernel, using it as an atomic counter. raft::copy(list_sizes_ptr, old_list_sizes_dev.data_handle(), n_lists, stream); - // Kernel to insert the new vectors - const dim3 block_dim(256); - const dim3 grid_dim(raft::ceildiv(n_rows, block_dim.x)); - build_index_kernel<<>>(new_labels.data_handle(), - new_vectors, - new_indices, - index->data_ptrs().data_handle(), - index->inds_ptrs().data_handle(), - list_sizes_ptr, - n_rows, - dim, - index->veclen()); - RAFT_CUDA_TRY(cudaPeekAtLastError()); - + utils::batch_load_iterator vec_indices( + new_indices, n_rows, 1, max_batch_size, stream, resource::get_workspace_resource(handle)); + utils::batch_load_iterator idx_batch = vec_indices.begin(); + size_t next_report_offset = 0; + size_t d_report_offset = n_rows * 5 / 100; + for (const auto& batch : vec_batches) { + auto batch_data_view = + raft::make_device_matrix_view(batch.data(), batch.size(), index->dim()); + // Kernel to insert the new vectors + const dim3 block_dim(256); + const dim3 grid_dim(raft::ceildiv(batch.size(), block_dim.x)); + build_index_kernel + <<>>(new_labels.data_handle() + batch.offset(), + batch_data_view.data_handle(), + idx_batch->data(), + index->data_ptrs().data_handle(), + index->inds_ptrs().data_handle(), + list_sizes_ptr, + batch.size(), + dim, + index->veclen(), + batch.offset()); + RAFT_CUDA_TRY(cudaPeekAtLastError()); + + if (batch.offset() > next_report_offset) { + float progress = batch.offset() * 100.0f / n_rows; + RAFT_LOG_DEBUG("ivf_flat::extend added vectors %zu, %6.1f%% complete", + static_cast(batch.offset()), + progress); + next_report_offset += d_report_offset; + } + ++idx_batch; + } // Precompute the centers vector norms for L2Expanded distance if (!index->center_norms().has_value()) { index->allocate_center_norms(handle); diff --git a/cpp/include/raft/neighbors/ivf_flat-ext.cuh b/cpp/include/raft/neighbors/ivf_flat-ext.cuh index 8dbe7587ff..063105cf46 100644 --- a/cpp/include/raft/neighbors/ivf_flat-ext.cuh +++ b/cpp/include/raft/neighbors/ivf_flat-ext.cuh @@ -48,6 +48,18 @@ void build(raft::resources const& handle, raft::device_matrix_view dataset, raft::neighbors::ivf_flat::index& idx) RAFT_EXPLICIT; +template +auto build(raft::resources const& handle, + const index_params& params, + raft::host_matrix_view dataset) + -> index RAFT_EXPLICIT; + +template +void build(raft::resources const& handle, + const index_params& params, + raft::host_matrix_view dataset, + raft::neighbors::ivf_flat::index& idx) RAFT_EXPLICIT; + template auto extend(raft::resources const& handle, const index& orig_index, @@ -74,6 +86,19 @@ void extend(raft::resources const& handle, std::optional> new_indices, index* index) RAFT_EXPLICIT; +template +auto extend(raft::resources const& handle, + raft::host_matrix_view new_vectors, + std::optional> new_indices, + const raft::neighbors::ivf_flat::index& orig_index) + -> raft::neighbors::ivf_flat::index RAFT_EXPLICIT; + +template +void extend(raft::resources const& handle, + raft::host_matrix_view new_vectors, + std::optional> new_indices, + index* index) RAFT_EXPLICIT; + template void search_with_filtering(raft::resources const& handle, const search_params& params, @@ -137,6 +162,18 @@ void search(raft::resources const& handle, raft::resources const& handle, \ const raft::neighbors::ivf_flat::index_params& params, \ raft::device_matrix_view dataset, \ + raft::neighbors::ivf_flat::index& idx); \ + \ + extern template auto raft::neighbors::ivf_flat::build( \ + raft::resources const& handle, \ + const raft::neighbors::ivf_flat::index_params& params, \ + raft::host_matrix_view dataset) \ + ->raft::neighbors::ivf_flat::index; \ + \ + extern template void raft::neighbors::ivf_flat::build( \ + raft::resources const& handle, \ + const raft::neighbors::ivf_flat::index_params& params, \ + raft::host_matrix_view dataset, \ raft::neighbors::ivf_flat::index& idx); instantiate_raft_neighbors_ivf_flat_build(float, int64_t); @@ -171,7 +208,20 @@ instantiate_raft_neighbors_ivf_flat_build(uint8_t, int64_t); raft::resources const& handle, \ raft::device_matrix_view new_vectors, \ std::optional> new_indices, \ - raft::neighbors::ivf_flat::index* index); + raft::neighbors::ivf_flat::index* index); \ + \ + extern template void raft::neighbors::ivf_flat::extend( \ + raft::resources const& handle, \ + raft::host_matrix_view new_vectors, \ + std::optional> new_indices, \ + raft::neighbors::ivf_flat::index* index); \ + \ + extern template auto raft::neighbors::ivf_flat::extend( \ + const raft::resources& handle, \ + raft::host_matrix_view new_vectors, \ + std::optional> new_indices, \ + const raft::neighbors::ivf_flat::index& idx) \ + ->raft::neighbors::ivf_flat::index; instantiate_raft_neighbors_ivf_flat_extend(float, int64_t); instantiate_raft_neighbors_ivf_flat_extend(int8_t, int64_t); diff --git a/cpp/include/raft/neighbors/ivf_flat-inl.cuh b/cpp/include/raft/neighbors/ivf_flat-inl.cuh index 692fb08810..b540de7f14 100644 --- a/cpp/include/raft/neighbors/ivf_flat-inl.cuh +++ b/cpp/include/raft/neighbors/ivf_flat-inl.cuh @@ -55,7 +55,7 @@ namespace raft::neighbors::ivf_flat { * * @param[in] handle * @param[in] params configure the index building - * @param[in] dataset a device pointer to a row-major matrix [n_rows, dim] + * @param[in] dataset a host or device pointer to a row-major matrix [n_rows, dim] * @param[in] n_rows the number of samples * @param[in] dim the dimensionality of the data * @@ -102,7 +102,7 @@ auto build(raft::resources const& handle, * * @param[in] handle * @param[in] params configure the index building - * @param[in] dataset a device pointer to a row-major matrix [n_rows, dim] + * @param[in] dataset a device matrix [n_rows, dim] * * @return the constructed ivf-flat index */ @@ -118,6 +118,20 @@ auto build(raft::resources const& handle, static_cast(dataset.extent(1))); } +/** + * @brief Build the index from a dataset in host memory. + */ +template +auto build(raft::resources const& handle, + const index_params& params, + raft::host_matrix_view dataset) -> index +{ + return raft::neighbors::ivf_flat::detail::build(handle, + params, + dataset.data_handle(), + static_cast(dataset.extent(0)), + static_cast(dataset.extent(1))); +} /** * @brief Build the index from the dataset for efficient search. * @@ -162,6 +176,21 @@ void build(raft::resources const& handle, static_cast(dataset.extent(1))); } +/** + * @brief Build the index from a dataset in host memory. + */ +template +void build(raft::resources const& handle, + const index_params& params, + raft::host_matrix_view dataset, + raft::neighbors::ivf_flat::index& idx) +{ + idx = raft::neighbors::ivf_flat::detail::build(handle, + params, + dataset.data_handle(), + static_cast(dataset.extent(0)), + static_cast(dataset.extent(1))); +} /** @} */ /** @@ -188,8 +217,8 @@ void build(raft::resources const& handle, * * @param[in] handle * @param[in] orig_index original index - * @param[in] new_vectors a device pointer to a row-major matrix [n_rows, index.dim()] - * @param[in] new_indices a device pointer to a vector of indices [n_rows]. + * @param[in] new_vectors a device/host pointer to a row-major matrix [n_rows, index.dim()] + * @param[in] new_indices a device/host pointer to a vector of indices [n_rows]. * If the original index is empty (`orig_index.size() == 0`), you can pass `nullptr` * here to imply a continuous range `[0...n_rows)`. * @param[in] n_rows number of rows in `new_vectors` @@ -257,6 +286,23 @@ auto extend(raft::resources const& handle, new_vectors.extent(0)); } +/** + * @brief Extend the index with additional vectors. + * + * This overloads takes input data in host memory. + */ +template +auto extend(raft::resources const& handle, + raft::host_matrix_view new_vectors, + std::optional> new_indices, + const index& orig_index) -> index +{ + return extend(handle, + orig_index, + new_vectors.data_handle(), + new_indices.has_value() ? new_indices.value().data_handle() : nullptr, + new_vectors.extent(0)); +} /** @} */ /** @@ -279,8 +325,8 @@ auto extend(raft::resources const& handle, * * @param handle * @param[inout] index - * @param[in] new_vectors a device pointer to a row-major matrix [n_rows, index.dim()] - * @param[in] new_indices a device pointer to a vector of indices [n_rows]. + * @param[in] new_vectors a device/host pointer to a row-major matrix [n_rows, index.dim()] + * @param[in] new_indices a device/host pointer to a vector of indices [n_rows]. * If the original index is empty (`orig_index.size() == 0`), you can pass `nullptr` * here to imply a continuous range `[0...n_rows)`. * @param[in] n_rows the number of samples @@ -339,6 +385,23 @@ void extend(raft::resources const& handle, static_cast(new_vectors.extent(0))); } +/** + * @brief Extend the index with additional vectors. + * + * This overloads takes input data in host memory. + */ +template +void extend(raft::resources const& handle, + raft::host_matrix_view new_vectors, + std::optional> new_indices, + index* index) +{ + extend(handle, + index, + new_vectors.data_handle(), + new_indices.has_value() ? new_indices.value().data_handle() : nullptr, + static_cast(new_vectors.extent(0))); +} /** @} */ /** diff --git a/cpp/src/neighbors/ivf_flat_00_generate.py b/cpp/src/neighbors/ivf_flat_00_generate.py index b02606a23e..d987a4e17d 100644 --- a/cpp/src/neighbors/ivf_flat_00_generate.py +++ b/cpp/src/neighbors/ivf_flat_00_generate.py @@ -41,63 +41,88 @@ """ types = dict( - float_int64_t= ("float", "int64_t"), + float_int64_t=("float", "int64_t"), int8_t_int64_t=("int8_t", "int64_t"), uint8_t_int64_t=("uint8_t", "int64_t"), ) build_macro = """ #define instantiate_raft_neighbors_ivf_flat_build(T, IdxT) \\ - template auto raft::neighbors::ivf_flat::build( \\ - raft::resources const& handle, \\ + template auto raft::neighbors::ivf_flat::build( \\ + raft::resources const& handle, \\ const raft::neighbors::ivf_flat::index_params& params, \\ const T* dataset, \\ IdxT n_rows, \\ uint32_t dim) \\ ->raft::neighbors::ivf_flat::index; \\ \\ - template auto raft::neighbors::ivf_flat::build( \\ - raft::resources const& handle, \\ + template auto raft::neighbors::ivf_flat::build( \\ + raft::resources const& handle, \\ const raft::neighbors::ivf_flat::index_params& params, \\ raft::device_matrix_view dataset) \\ ->raft::neighbors::ivf_flat::index; \\ \\ - template void raft::neighbors::ivf_flat::build( \\ - raft::resources const& handle, \\ + template void raft::neighbors::ivf_flat::build( \\ + raft::resources const& handle, \\ const raft::neighbors::ivf_flat::index_params& params, \\ raft::device_matrix_view dataset, \\ + raft::neighbors::ivf_flat::index& idx); \\ + \\ + template auto raft::neighbors::ivf_flat::build( \\ + raft::resources const& handle, \\ + const raft::neighbors::ivf_flat::index_params& params, \\ + raft::host_matrix_view dataset) \\ + ->raft::neighbors::ivf_flat::index; \\ + \\ + template void raft::neighbors::ivf_flat::build( \\ + raft::resources const& handle, \\ + const raft::neighbors::ivf_flat::index_params& params, \\ + raft::host_matrix_view dataset, \\ raft::neighbors::ivf_flat::index& idx); """ extend_macro = """ #define instantiate_raft_neighbors_ivf_flat_extend(T, IdxT) \\ - template auto raft::neighbors::ivf_flat::extend( \\ - raft::resources const& handle, \\ + template auto raft::neighbors::ivf_flat::extend( \\ + raft::resources const& handle, \\ const raft::neighbors::ivf_flat::index& orig_index, \\ const T* new_vectors, \\ const IdxT* new_indices, \\ IdxT n_rows) \\ ->raft::neighbors::ivf_flat::index; \\ \\ - template auto raft::neighbors::ivf_flat::extend( \\ - raft::resources const& handle, \\ + template auto raft::neighbors::ivf_flat::extend( \\ + raft::resources const& handle, \\ raft::device_matrix_view new_vectors, \\ std::optional> new_indices, \\ const raft::neighbors::ivf_flat::index& orig_index) \\ ->raft::neighbors::ivf_flat::index; \\ \\ - template void raft::neighbors::ivf_flat::extend( \\ - raft::resources const& handle, \\ + template void raft::neighbors::ivf_flat::extend( \\ + raft::resources const& handle, \\ raft::neighbors::ivf_flat::index* index, \\ const T* new_vectors, \\ const IdxT* new_indices, \\ IdxT n_rows); \\ \\ - template void raft::neighbors::ivf_flat::extend( \\ - raft::resources const& handle, \\ + template void raft::neighbors::ivf_flat::extend( \\ + raft::resources const& handle, \\ raft::device_matrix_view new_vectors, \\ std::optional> new_indices, \\ - raft::neighbors::ivf_flat::index* index); + raft::neighbors::ivf_flat::index* index); \\ + \\ + template auto raft::neighbors::ivf_flat::extend( \\ + const raft::resources& handle, \\ + raft::host_matrix_view new_vectors, \\ + std::optional> new_indices, \\ + const raft::neighbors::ivf_flat::index& idx) \\ + -> raft::neighbors::ivf_flat::index; \\ + \\ + template void raft::neighbors::ivf_flat::extend( \\ + raft::resources const& handle, \\ + raft::host_matrix_view new_vectors, \\ + std::optional> new_indices, \\ + raft::neighbors::ivf_flat::index* index); """ search_macro = """ @@ -125,13 +150,16 @@ macros = dict( build=dict( definition=build_macro, - name="instantiate_raft_neighbors_ivf_flat_build"), + name="instantiate_raft_neighbors_ivf_flat_build", + ), extend=dict( definition=extend_macro, - name="instantiate_raft_neighbors_ivf_flat_extend"), + name="instantiate_raft_neighbors_ivf_flat_extend", + ), search=dict( definition=search_macro, - name="instantiate_raft_neighbors_ivf_flat_search"), + name="instantiate_raft_neighbors_ivf_flat_search", + ), ) for type_path, (T, IdxT) in types.items(): @@ -139,8 +167,7 @@ path = f"ivf_flat_{macro_path}_{type_path}.cu" with open(path, "w") as f: f.write(header) - f.write(macro['definition']) - + f.write(macro["definition"]) f.write(f"{macro['name']}({T}, {IdxT});\n\n") f.write(f"#undef {macro['name']}\n") diff --git a/cpp/src/neighbors/ivf_flat_build_float_int64_t.cu b/cpp/src/neighbors/ivf_flat_build_float_int64_t.cu index 2ae795db56..cf3cb6b1b2 100644 --- a/cpp/src/neighbors/ivf_flat_build_float_int64_t.cu +++ b/cpp/src/neighbors/ivf_flat_build_float_int64_t.cu @@ -44,6 +44,18 @@ raft::resources const& handle, \ const raft::neighbors::ivf_flat::index_params& params, \ raft::device_matrix_view dataset, \ + raft::neighbors::ivf_flat::index& idx); \ + \ + template auto raft::neighbors::ivf_flat::build( \ + raft::resources const& handle, \ + const raft::neighbors::ivf_flat::index_params& params, \ + raft::host_matrix_view dataset) \ + ->raft::neighbors::ivf_flat::index; \ + \ + template void raft::neighbors::ivf_flat::build( \ + raft::resources const& handle, \ + const raft::neighbors::ivf_flat::index_params& params, \ + raft::host_matrix_view dataset, \ raft::neighbors::ivf_flat::index& idx); instantiate_raft_neighbors_ivf_flat_build(float, int64_t); diff --git a/cpp/src/neighbors/ivf_flat_build_int8_t_int64_t.cu b/cpp/src/neighbors/ivf_flat_build_int8_t_int64_t.cu index deb31bf441..e1cf64907e 100644 --- a/cpp/src/neighbors/ivf_flat_build_int8_t_int64_t.cu +++ b/cpp/src/neighbors/ivf_flat_build_int8_t_int64_t.cu @@ -44,6 +44,18 @@ raft::resources const& handle, \ const raft::neighbors::ivf_flat::index_params& params, \ raft::device_matrix_view dataset, \ + raft::neighbors::ivf_flat::index& idx); \ + \ + template auto raft::neighbors::ivf_flat::build( \ + raft::resources const& handle, \ + const raft::neighbors::ivf_flat::index_params& params, \ + raft::host_matrix_view dataset) \ + ->raft::neighbors::ivf_flat::index; \ + \ + template void raft::neighbors::ivf_flat::build( \ + raft::resources const& handle, \ + const raft::neighbors::ivf_flat::index_params& params, \ + raft::host_matrix_view dataset, \ raft::neighbors::ivf_flat::index& idx); instantiate_raft_neighbors_ivf_flat_build(int8_t, int64_t); diff --git a/cpp/src/neighbors/ivf_flat_build_uint8_t_int64_t.cu b/cpp/src/neighbors/ivf_flat_build_uint8_t_int64_t.cu index 402fdbab97..26d1647954 100644 --- a/cpp/src/neighbors/ivf_flat_build_uint8_t_int64_t.cu +++ b/cpp/src/neighbors/ivf_flat_build_uint8_t_int64_t.cu @@ -44,6 +44,18 @@ raft::resources const& handle, \ const raft::neighbors::ivf_flat::index_params& params, \ raft::device_matrix_view dataset, \ + raft::neighbors::ivf_flat::index& idx); \ + \ + template auto raft::neighbors::ivf_flat::build( \ + raft::resources const& handle, \ + const raft::neighbors::ivf_flat::index_params& params, \ + raft::host_matrix_view dataset) \ + ->raft::neighbors::ivf_flat::index; \ + \ + template void raft::neighbors::ivf_flat::build( \ + raft::resources const& handle, \ + const raft::neighbors::ivf_flat::index_params& params, \ + raft::host_matrix_view dataset, \ raft::neighbors::ivf_flat::index& idx); instantiate_raft_neighbors_ivf_flat_build(uint8_t, int64_t); diff --git a/cpp/src/neighbors/ivf_flat_extend_float_int64_t.cu b/cpp/src/neighbors/ivf_flat_extend_float_int64_t.cu index 9e7701f773..16472c6692 100644 --- a/cpp/src/neighbors/ivf_flat_extend_float_int64_t.cu +++ b/cpp/src/neighbors/ivf_flat_extend_float_int64_t.cu @@ -52,6 +52,19 @@ raft::resources const& handle, \ raft::device_matrix_view new_vectors, \ std::optional> new_indices, \ + raft::neighbors::ivf_flat::index* index); \ + \ + template auto raft::neighbors::ivf_flat::extend( \ + const raft::resources& handle, \ + raft::host_matrix_view new_vectors, \ + std::optional> new_indices, \ + const raft::neighbors::ivf_flat::index& idx) \ + ->raft::neighbors::ivf_flat::index; \ + \ + template void raft::neighbors::ivf_flat::extend( \ + raft::resources const& handle, \ + raft::host_matrix_view new_vectors, \ + std::optional> new_indices, \ raft::neighbors::ivf_flat::index* index); instantiate_raft_neighbors_ivf_flat_extend(float, int64_t); diff --git a/cpp/src/neighbors/ivf_flat_extend_int8_t_int64_t.cu b/cpp/src/neighbors/ivf_flat_extend_int8_t_int64_t.cu index 5d3d23c3ab..d98b5225c3 100644 --- a/cpp/src/neighbors/ivf_flat_extend_int8_t_int64_t.cu +++ b/cpp/src/neighbors/ivf_flat_extend_int8_t_int64_t.cu @@ -52,6 +52,19 @@ raft::resources const& handle, \ raft::device_matrix_view new_vectors, \ std::optional> new_indices, \ + raft::neighbors::ivf_flat::index* index); \ + \ + template auto raft::neighbors::ivf_flat::extend( \ + const raft::resources& handle, \ + raft::host_matrix_view new_vectors, \ + std::optional> new_indices, \ + const raft::neighbors::ivf_flat::index& idx) \ + ->raft::neighbors::ivf_flat::index; \ + \ + template void raft::neighbors::ivf_flat::extend( \ + raft::resources const& handle, \ + raft::host_matrix_view new_vectors, \ + std::optional> new_indices, \ raft::neighbors::ivf_flat::index* index); instantiate_raft_neighbors_ivf_flat_extend(int8_t, int64_t); diff --git a/cpp/src/neighbors/ivf_flat_extend_uint8_t_int64_t.cu b/cpp/src/neighbors/ivf_flat_extend_uint8_t_int64_t.cu index 3150a676eb..520c3be536 100644 --- a/cpp/src/neighbors/ivf_flat_extend_uint8_t_int64_t.cu +++ b/cpp/src/neighbors/ivf_flat_extend_uint8_t_int64_t.cu @@ -52,6 +52,19 @@ raft::resources const& handle, \ raft::device_matrix_view new_vectors, \ std::optional> new_indices, \ + raft::neighbors::ivf_flat::index* index); \ + \ + template auto raft::neighbors::ivf_flat::extend( \ + const raft::resources& handle, \ + raft::host_matrix_view new_vectors, \ + std::optional> new_indices, \ + const raft::neighbors::ivf_flat::index& idx) \ + ->raft::neighbors::ivf_flat::index; \ + \ + template void raft::neighbors::ivf_flat::extend( \ + raft::resources const& handle, \ + raft::host_matrix_view new_vectors, \ + std::optional> new_indices, \ raft::neighbors::ivf_flat::index* index); instantiate_raft_neighbors_ivf_flat_extend(uint8_t, int64_t); diff --git a/cpp/test/neighbors/ann_ivf_flat.cuh b/cpp/test/neighbors/ann_ivf_flat.cuh index a9fd696f1f..39439d392d 100644 --- a/cpp/test/neighbors/ann_ivf_flat.cuh +++ b/cpp/test/neighbors/ann_ivf_flat.cuh @@ -72,6 +72,7 @@ struct AnnIvfFlatInputs { IdxT nlist; raft::distance::DistanceType metric; bool adaptive_centers; + bool host_dataset; }; template @@ -79,7 +80,7 @@ template { os << "{ " << p.num_queries << ", " << p.num_db_vecs << ", " << p.dim << ", " << p.k << ", " << p.nprobe << ", " << p.nlist << ", " << static_cast(p.metric) << ", " - << p.adaptive_centers << '}' << std::endl; + << p.adaptive_centers << ", " << p.host_dataset << '}' << std::endl; return os; } @@ -178,36 +179,69 @@ class AnnIVFFlatTest : public ::testing::TestWithParam> { index_params.kmeans_trainset_fraction = 0.5; index_params.metric_arg = 0; - auto database_view = raft::make_device_matrix_view( - (const DataT*)database.data(), ps.num_db_vecs, ps.dim); - - auto idx = ivf_flat::build(handle_, index_params, database_view); + ivf_flat::index idx(handle_, index_params, ps.dim); + ivf_flat::index index_2(handle_, index_params, ps.dim); + + if (!ps.host_dataset) { + auto database_view = raft::make_device_matrix_view( + (const DataT*)database.data(), ps.num_db_vecs, ps.dim); + idx = ivf_flat::build(handle_, index_params, database_view); + rmm::device_uvector vector_indices(ps.num_db_vecs, stream_); + thrust::sequence(resource::get_thrust_policy(handle_), + thrust::device_pointer_cast(vector_indices.data()), + thrust::device_pointer_cast(vector_indices.data() + ps.num_db_vecs)); + resource::sync_stream(handle_); - rmm::device_uvector vector_indices(ps.num_db_vecs, stream_); - thrust::sequence(resource::get_thrust_policy(handle_), - thrust::device_pointer_cast(vector_indices.data()), - thrust::device_pointer_cast(vector_indices.data() + ps.num_db_vecs)); - resource::sync_stream(handle_); + IdxT half_of_data = ps.num_db_vecs / 2; - IdxT half_of_data = ps.num_db_vecs / 2; + auto half_of_data_view = raft::make_device_matrix_view( + (const DataT*)database.data(), half_of_data, ps.dim); - auto half_of_data_view = raft::make_device_matrix_view( - (const DataT*)database.data(), half_of_data, ps.dim); + const std::optional> no_opt = std::nullopt; + index_2 = ivf_flat::extend(handle_, half_of_data_view, no_opt, idx); - const std::optional> no_opt = std::nullopt; - index index_2 = ivf_flat::extend(handle_, half_of_data_view, no_opt, idx); + auto new_half_of_data_view = raft::make_device_matrix_view( + database.data() + half_of_data * ps.dim, IdxT(ps.num_db_vecs) - half_of_data, ps.dim); - auto new_half_of_data_view = raft::make_device_matrix_view( - database.data() + half_of_data * ps.dim, IdxT(ps.num_db_vecs) - half_of_data, ps.dim); + auto new_half_of_data_indices_view = raft::make_device_vector_view( + vector_indices.data() + half_of_data, IdxT(ps.num_db_vecs) - half_of_data); - auto new_half_of_data_indices_view = raft::make_device_vector_view( - vector_indices.data() + half_of_data, IdxT(ps.num_db_vecs) - half_of_data); + ivf_flat::extend(handle_, + new_half_of_data_view, + std::make_optional>( + new_half_of_data_indices_view), + &index_2); - ivf_flat::extend(handle_, - new_half_of_data_view, - std::make_optional>( - new_half_of_data_indices_view), - &index_2); + } else { + auto host_database = raft::make_host_matrix(ps.num_db_vecs, ps.dim); + raft::copy( + host_database.data_handle(), database.data(), ps.num_db_vecs * ps.dim, stream_); + idx = + ivf_flat::build(handle_, index_params, raft::make_const_mdspan(host_database.view())); + + auto vector_indices = raft::make_host_vector(handle_, ps.num_db_vecs); + std::iota(vector_indices.data_handle(), vector_indices.data_handle() + ps.num_db_vecs, 0); + + IdxT half_of_data = ps.num_db_vecs / 2; + + auto half_of_data_view = raft::make_host_matrix_view( + (const DataT*)host_database.data_handle(), half_of_data, ps.dim); + + const std::optional> no_opt = std::nullopt; + index_2 = ivf_flat::extend(handle_, half_of_data_view, no_opt, idx); + + auto new_half_of_data_view = raft::make_host_matrix_view( + host_database.data_handle() + half_of_data * ps.dim, + IdxT(ps.num_db_vecs) - half_of_data, + ps.dim); + auto new_half_of_data_indices_view = raft::make_host_vector_view( + vector_indices.data_handle() + half_of_data, IdxT(ps.num_db_vecs) - half_of_data); + ivf_flat::extend(handle_, + new_half_of_data_view, + std::make_optional>( + new_half_of_data_indices_view), + &index_2); + } auto search_queries_view = raft::make_device_matrix_view( search_queries.data(), ps.num_queries, ps.dim); @@ -574,6 +608,15 @@ const std::vector> inputs = { {1000, 100000, 16, 10, 20, 1024, raft::distance::DistanceType::L2Expanded, true}, {10000, 131072, 8, 10, 20, 1024, raft::distance::DistanceType::L2Expanded, false}, + // host input data + {1000, 10000, 16, 10, 40, 1024, raft::distance::DistanceType::L2Expanded, false, true}, + {1000, 10000, 16, 10, 50, 1024, raft::distance::DistanceType::L2Expanded, false, true}, + {1000, 10000, 16, 10, 70, 1024, raft::distance::DistanceType::L2Expanded, false, true}, + {100, 10000, 16, 10, 20, 512, raft::distance::DistanceType::L2Expanded, false, true}, + {20, 100000, 16, 10, 20, 1024, raft::distance::DistanceType::L2Expanded, false, true}, + {1000, 100000, 16, 10, 20, 1024, raft::distance::DistanceType::L2Expanded, false, true}, + {10000, 131072, 8, 10, 20, 1024, raft::distance::DistanceType::L2Expanded, false, true}, + {1000, 10000, 16, 10, 40, 1024, raft::distance::DistanceType::InnerProduct, true}, {1000, 10000, 16, 10, 50, 1024, raft::distance::DistanceType::InnerProduct, true}, {1000, 10000, 16, 10, 70, 1024, raft::distance::DistanceType::InnerProduct, false},