From 856288a2b4c4d9a74b5cbf4d0f5f2a64978072ba Mon Sep 17 00:00:00 2001 From: "Artem M. Chirkin" <9253178+achirkin@users.noreply.github.com> Date: Thu, 11 Jan 2024 20:26:52 +0100 Subject: [PATCH] Add IVF-PQ example into the template project (#2091) A simple example with search and refinement. Authors: - Artem M. Chirkin (https://github.com/achirkin) Approvers: - Corey J. Nolet (https://github.com/cjnolet) URL: https://github.com/rapidsai/raft/pull/2091 --- cpp/template/CMakeLists.txt | 5 +- cpp/template/src/common.cuh | 1 + cpp/template/src/ivf_pq_example.cu | 116 +++++++++++++++++++++++++++++ 3 files changed, 121 insertions(+), 1 deletion(-) create mode 100644 cpp/template/src/ivf_pq_example.cu diff --git a/cpp/template/CMakeLists.txt b/cpp/template/CMakeLists.txt index 538eac07ef..40a3795ed1 100644 --- a/cpp/template/CMakeLists.txt +++ b/cpp/template/CMakeLists.txt @@ -1,5 +1,5 @@ # ============================================================================= -# Copyright (c) 2023, NVIDIA CORPORATION. +# Copyright (c) 2023-2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at @@ -39,3 +39,6 @@ target_link_libraries(CAGRA_EXAMPLE PRIVATE raft::raft raft::compiled) add_executable(IVF_FLAT_EXAMPLE src/ivf_flat_example.cu) target_link_libraries(IVF_FLAT_EXAMPLE PRIVATE raft::raft raft::compiled) + +add_executable(IVF_PQ_EXAMPLE src/ivf_pq_example.cu) +target_link_libraries(IVF_PQ_EXAMPLE PRIVATE raft::raft raft::compiled) diff --git a/cpp/template/src/common.cuh b/cpp/template/src/common.cuh index c2cb15bcf3..193abc747d 100644 --- a/cpp/template/src/common.cuh +++ b/cpp/template/src/common.cuh @@ -17,6 +17,7 @@ #include #include #include +#include #include #include #include diff --git a/cpp/template/src/ivf_pq_example.cu b/cpp/template/src/ivf_pq_example.cu new file mode 100644 index 0000000000..4bc0ba4348 --- /dev/null +++ b/cpp/template/src/ivf_pq_example.cu @@ -0,0 +1,116 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "common.cuh" + +#include +#include +#include +#include + +#include +#include + +#include + +void ivf_pq_build_search(raft::device_resources const& dev_resources, + raft::device_matrix_view dataset, + raft::device_matrix_view queries) +{ + using namespace raft::neighbors; // NOLINT + + ivf_pq::index_params index_params; + index_params.n_lists = 1024; + index_params.kmeans_trainset_fraction = 0.1; + index_params.metric = raft::distance::DistanceType::L2Expanded; + index_params.pq_bits = 8; + index_params.pq_dim = 2; + + std::cout << "Building IVF-PQ index" << std::endl; + auto index = ivf_pq::build(dev_resources, index_params, dataset); + + std::cout << "Number of clusters " << index.n_lists() << ", number of vectors added to index " + << index.size() << std::endl; + + // Set search parameters. + ivf_pq::search_params search_params; + search_params.n_probes = 50; + // Set the internal search precision to 16-bit floats; + // usually, this improves the performance at a slight cost to the recall. + search_params.internal_distance_dtype = CUDA_R_16F; + search_params.lut_dtype = CUDA_R_16F; + + // Create output arrays. + int64_t topk = 10; + int64_t n_queries = queries.extent(0); + auto neighbors = raft::make_device_matrix(dev_resources, n_queries, topk); + auto distances = raft::make_device_matrix(dev_resources, n_queries, topk); + + // Search K nearest neighbors for each of the queries. + ivf_pq::search( + dev_resources, search_params, index, queries, neighbors.view(), distances.view()); + + // Re-ranking operation: refine the initial search results by computing exact distances + int64_t topk_refined = 7; + auto neighbors_refined = + raft::make_device_matrix(dev_resources, n_queries, topk_refined); + auto distances_refined = raft::make_device_matrix(dev_resources, n_queries, topk_refined); + + // Note, refinement requires the original dataset and the queries. + // Don't forget to specify the same distance metric as used by the index. + raft::neighbors::refine(dev_resources, + dataset, + queries, + raft::make_const_mdspan(neighbors.view()), + neighbors_refined.view(), + distances_refined.view(), + index.metric()); + + // Show both the original and the refined results + std::cout << std::endl << "Original results:" << std::endl; + print_results(dev_resources, neighbors.view(), distances.view()); + std::cout << std::endl << "Refined results:" << std::endl; + print_results(dev_resources, neighbors_refined.view(), distances_refined.view()); +} + +int main() +{ + raft::device_resources dev_resources; + + // Set pool memory resource with 1 GiB initial pool size. All allocations use the same pool. + rmm::mr::pool_memory_resource pool_mr( + rmm::mr::get_current_device_resource(), 1024 * 1024 * 1024ull); + rmm::mr::set_current_device_resource(&pool_mr); + + // Alternatively, one could define a pool allocator for temporary arrays (used within RAFT + // algorithms). In that case only the internal arrays would use the pool, any other allocation + // uses the default RMM memory resource. Here is how to change the workspace memory resource to + // a pool with 2 GiB upper limit. + // raft::resource::set_workspace_to_pool_resource(dev_resources, 2 * 1024 * 1024 * 1024ull); + + // Create input arrays. + int64_t n_samples = 10000; + int64_t n_dim = 3; + int64_t n_queries = 10; + auto dataset = raft::make_device_matrix(dev_resources, n_samples, n_dim); + auto queries = raft::make_device_matrix(dev_resources, n_queries, n_dim); + generate_dataset(dev_resources, dataset.view(), queries.view()); + + // Simple build and search example. + ivf_pq_build_search(dev_resources, + raft::make_const_mdspan(dataset.view()), + raft::make_const_mdspan(queries.view())); +}