From ec4236a9905aee1ad79a11265b6f8240bf9657c3 Mon Sep 17 00:00:00 2001 From: "Artem M. Chirkin" <9253178+achirkin@users.noreply.github.com> Date: Thu, 7 Sep 2023 19:07:03 +0200 Subject: [PATCH] ann-bench: miscellaneous improvements (#1808) 1. IVF-PQ: slightly improve stream ordering 2. IVF-PQ: build param 'codebook_kind' - as per `ivf_pq_types`. 3. FAISS IVF models: build param 'ratio' with the same meaning as in IVF-PQ - the clustering algorithm uses `1/ratio` of the given dataset for training. Authors: - Artem M. Chirkin (https://github.com/achirkin) Approvers: - Corey J. Nolet (https://github.com/cjnolet) URL: https://github.com/rapidsai/raft/pull/1808 --- cpp/bench/ann/src/faiss/faiss_benchmark.cu | 16 +++-- cpp/bench/ann/src/faiss/faiss_wrapper.h | 75 +++++++++++++------- cpp/bench/ann/src/raft/raft_benchmark.cu | 16 +++-- cpp/bench/ann/src/raft/raft_ivf_pq_wrapper.h | 41 ++++++----- docs/source/ann_benchmarks_param_tuning.md | 12 ++-- 5 files changed, 105 insertions(+), 55 deletions(-) diff --git a/cpp/bench/ann/src/faiss/faiss_benchmark.cu b/cpp/bench/ann/src/faiss/faiss_benchmark.cu index 231154ccfd..56885cce5c 100644 --- a/cpp/bench/ann/src/faiss/faiss_benchmark.cu +++ b/cpp/bench/ann/src/faiss/faiss_benchmark.cu @@ -30,19 +30,27 @@ namespace raft::bench::ann { +template +void parse_base_build_param(const nlohmann::json& conf, + typename raft::bench::ann::FaissGpu::BuildParam& param) +{ + param.nlist = conf.at("nlist"); + if (conf.contains("ratio")) { param.ratio = conf.at("ratio"); } +} + template void parse_build_param(const nlohmann::json& conf, typename raft::bench::ann::FaissGpuIVFFlat::BuildParam& param) { - param.nlist = conf.at("nlist"); + parse_base_build_param(conf, param); } template void parse_build_param(const nlohmann::json& conf, typename raft::bench::ann::FaissGpuIVFPQ::BuildParam& param) { - param.nlist = conf.at("nlist"); - param.M = conf.at("M"); + parse_base_build_param(conf, param); + param.M = conf.at("M"); if (conf.contains("usePrecomputed")) { param.usePrecomputed = conf.at("usePrecomputed"); } else { @@ -59,7 +67,7 @@ template void parse_build_param(const nlohmann::json& conf, typename raft::bench::ann::FaissGpuIVFSQ::BuildParam& param) { - param.nlist = conf.at("nlist"); + parse_base_build_param(conf, param); param.quantizer_type = conf.at("quantizer_type"); } diff --git a/cpp/bench/ann/src/faiss/faiss_wrapper.h b/cpp/bench/ann/src/faiss/faiss_wrapper.h index ec80e6cbfd..672c685b1f 100644 --- a/cpp/bench/ann/src/faiss/faiss_wrapper.h +++ b/cpp/bench/ann/src/faiss/faiss_wrapper.h @@ -18,6 +18,7 @@ #include "../common/ann_types.hpp" +#include #include #include @@ -85,7 +86,23 @@ class FaissGpu : public ANN { float refine_ratio = 1.0; }; - FaissGpu(Metric metric, int dim, int nlist); + struct BuildParam { + int nlist = 1; + int ratio = 2; + }; + + FaissGpu(Metric metric, int dim, const BuildParam& param) + : ANN(metric, dim), + metric_type_(parse_metric_type(metric)), + nlist_{param.nlist}, + training_sample_fraction_{1.0 / double(param.ratio)} + { + static_assert(std::is_same_v, "faiss support only float type"); + RAFT_CUDA_TRY(cudaGetDevice(&device_)); + RAFT_CUDA_TRY(cudaEventCreate(&sync_, cudaEventDisableTiming)); + faiss_default_stream_ = gpu_resource_.getDefaultStream(device_); + } + virtual ~FaissGpu() noexcept { RAFT_CUDA_TRY_NO_THROW(cudaEventDestroy(sync_)); } void build(const T* dataset, size_t nrow, cudaStream_t stream = 0) final; @@ -131,23 +148,35 @@ class FaissGpu : public ANN { int device_; cudaEvent_t sync_{nullptr}; cudaStream_t faiss_default_stream_{nullptr}; + double training_sample_fraction_; }; -template -FaissGpu::FaissGpu(Metric metric, int dim, int nlist) - : ANN(metric, dim), metric_type_(parse_metric_type(metric)), nlist_(nlist) -{ - static_assert(std::is_same_v, "faiss support only float type"); - RAFT_CUDA_TRY(cudaGetDevice(&device_)); - RAFT_CUDA_TRY(cudaEventCreate(&sync_, cudaEventDisableTiming)); - faiss_default_stream_ = gpu_resource_.getDefaultStream(device_); -} - template void FaissGpu::build(const T* dataset, size_t nrow, cudaStream_t stream) { OmpSingleThreadScope omp_single_thread; - + auto index_ivf = dynamic_cast(index_.get()); + if (index_ivf != nullptr) { + // set the min/max training size for clustering to use the whole provided training set. + double trainset_size = training_sample_fraction_ * static_cast(nrow); + double points_per_centroid = trainset_size / static_cast(nlist_); + int max_ppc = std::ceil(points_per_centroid); + int min_ppc = std::floor(points_per_centroid); + if (min_ppc < index_ivf->cp.min_points_per_centroid) { + RAFT_LOG_WARN( + "The suggested training set size %zu (data size %zu, training sample ratio %f) yields %d " + "points per cluster (n_lists = %d). This is smaller than the FAISS default " + "min_points_per_centroid = %d.", + static_cast(trainset_size), + nrow, + training_sample_fraction_, + min_ppc, + nlist_, + index_ivf->cp.min_points_per_centroid); + } + index_ivf->cp.max_points_per_centroid = max_ppc; + index_ivf->cp.min_points_per_centroid = min_ppc; + } index_->train(nrow, dataset); // faiss::gpu::GpuIndexFlat::train() will do nothing assert(index_->is_trained); index_->add(nrow, dataset); @@ -208,12 +237,9 @@ void FaissGpu::load_(const std::string& file) template class FaissGpuIVFFlat : public FaissGpu { public: - struct BuildParam { - int nlist; - }; + using typename FaissGpu::BuildParam; - FaissGpuIVFFlat(Metric metric, int dim, const BuildParam& param) - : FaissGpu(metric, dim, param.nlist) + FaissGpuIVFFlat(Metric metric, int dim, const BuildParam& param) : FaissGpu(metric, dim, param) { faiss::gpu::GpuIndexIVFFlatConfig config; config.device = this->device_; @@ -234,15 +260,13 @@ class FaissGpuIVFFlat : public FaissGpu { template class FaissGpuIVFPQ : public FaissGpu { public: - struct BuildParam { - int nlist; + struct BuildParam : public FaissGpu::BuildParam { int M; bool useFloat16; bool usePrecomputed; }; - FaissGpuIVFPQ(Metric metric, int dim, const BuildParam& param) - : FaissGpu(metric, dim, param.nlist) + FaissGpuIVFPQ(Metric metric, int dim, const BuildParam& param) : FaissGpu(metric, dim, param) { faiss::gpu::GpuIndexIVFPQConfig config; config.useFloat16LookupTables = param.useFloat16; @@ -271,13 +295,11 @@ class FaissGpuIVFPQ : public FaissGpu { template class FaissGpuIVFSQ : public FaissGpu { public: - struct BuildParam { - int nlist; + struct BuildParam : public FaissGpu::BuildParam { std::string quantizer_type; }; - FaissGpuIVFSQ(Metric metric, int dim, const BuildParam& param) - : FaissGpu(metric, dim, param.nlist) + FaissGpuIVFSQ(Metric metric, int dim, const BuildParam& param) : FaissGpu(metric, dim, param) { faiss::ScalarQuantizer::QuantizerType qtype; if (param.quantizer_type == "fp16") { @@ -310,7 +332,8 @@ class FaissGpuIVFSQ : public FaissGpu { template class FaissGpuFlat : public FaissGpu { public: - FaissGpuFlat(Metric metric, int dim) : FaissGpu(metric, dim, 0) + FaissGpuFlat(Metric metric, int dim) + : FaissGpu(metric, dim, typename FaissGpu::BuildParam{}) { faiss::gpu::GpuIndexFlatConfig config; config.device = this->device_; diff --git a/cpp/bench/ann/src/raft/raft_benchmark.cu b/cpp/bench/ann/src/raft/raft_benchmark.cu index aa25d1532f..7ba381ab0a 100644 --- a/cpp/bench/ann/src/raft/raft_benchmark.cu +++ b/cpp/bench/ann/src/raft/raft_benchmark.cu @@ -58,10 +58,7 @@ void parse_build_param(const nlohmann::json& conf, { param.n_lists = conf.at("nlist"); if (conf.contains("niter")) { param.kmeans_n_iters = conf.at("niter"); } - if (conf.contains("ratio")) { - param.kmeans_trainset_fraction = 1.0 / (double)conf.at("ratio"); - std::cout << "kmeans_trainset_fraction " << param.kmeans_trainset_fraction; - } + if (conf.contains("ratio")) { param.kmeans_trainset_fraction = 1.0 / (double)conf.at("ratio"); } } template @@ -82,6 +79,17 @@ void parse_build_param(const nlohmann::json& conf, if (conf.contains("ratio")) { param.kmeans_trainset_fraction = 1.0 / (double)conf.at("ratio"); } if (conf.contains("pq_bits")) { param.pq_bits = conf.at("pq_bits"); } if (conf.contains("pq_dim")) { param.pq_dim = conf.at("pq_dim"); } + if (conf.contains("codebook_kind")) { + std::string kind = conf.at("codebook_kind"); + if (kind == "cluster") { + param.codebook_kind = raft::neighbors::ivf_pq::codebook_gen::PER_CLUSTER; + } else if (kind == "subspace") { + param.codebook_kind = raft::neighbors::ivf_pq::codebook_gen::PER_SUBSPACE; + } else { + throw std::runtime_error("codebook_kind: '" + kind + + "', should be either 'cluster' or 'subspace'"); + } + } } template diff --git a/cpp/bench/ann/src/raft/raft_ivf_pq_wrapper.h b/cpp/bench/ann/src/raft/raft_ivf_pq_wrapper.h index 1554c1f016..8f1e43a706 100644 --- a/cpp/bench/ann/src/raft/raft_ivf_pq_wrapper.h +++ b/cpp/bench/ann/src/raft/raft_ivf_pq_wrapper.h @@ -63,9 +63,14 @@ class RaftIvfPQ : public ANN { rmm::mr::set_current_device_resource(&mr_); index_params_.metric = parse_metric_type(metric); RAFT_CUDA_TRY(cudaGetDevice(&device_)); + RAFT_CUDA_TRY(cudaEventCreate(&sync_, cudaEventDisableTiming)); } - ~RaftIvfPQ() noexcept { rmm::mr::set_current_device_resource(mr_.get_upstream()); } + ~RaftIvfPQ() noexcept + { + RAFT_CUDA_TRY_NO_THROW(cudaEventDestroy(sync_)); + rmm::mr::set_current_device_resource(mr_.get_upstream()); + } void build(const T* dataset, size_t nrow, cudaStream_t stream) final; @@ -96,6 +101,7 @@ class RaftIvfPQ : public ANN { // `mr_` must go first to make sure it dies last rmm::mr::pool_memory_resource mr_; raft::device_resources handle_; + cudaEvent_t sync_{nullptr}; BuildParam index_params_; raft::neighbors::ivf_pq::search_params search_params_; std::optional> index_; @@ -103,6 +109,12 @@ class RaftIvfPQ : public ANN { int dimension_; float refine_ratio_ = 1.0; raft::device_matrix_view dataset_; + + void stream_wait(cudaStream_t stream) const + { + RAFT_CUDA_TRY(cudaEventRecord(sync_, resource::get_cuda_stream(handle_))); + RAFT_CUDA_TRY(cudaStreamWaitEvent(stream, sync_)); + } }; template @@ -121,12 +133,12 @@ void RaftIvfPQ::load(const std::string& file) } template -void RaftIvfPQ::build(const T* dataset, size_t nrow, cudaStream_t) +void RaftIvfPQ::build(const T* dataset, size_t nrow, cudaStream_t stream) { auto dataset_v = raft::make_device_matrix_view(dataset, IdxT(nrow), dim_); index_.emplace(raft::runtime::neighbors::ivf_pq::build(handle_, index_params_, dataset_v)); - return; + stream_wait(stream); } template @@ -176,16 +188,14 @@ void RaftIvfPQ::search(const T* queries, neighbors_v, distances_v, index_->metric()); + stream_wait(stream); // RAFT stream -> bench stream } else { auto queries_host = raft::make_host_matrix(batch_size, index_->dim()); auto candidates_host = raft::make_host_matrix(batch_size, k0); auto neighbors_host = raft::make_host_matrix(batch_size, k); auto distances_host = raft::make_host_matrix(batch_size, k); - raft::copy(queries_host.data_handle(), - queries, - queries_host.size(), - resource::get_cuda_stream(handle_)); + raft::copy(queries_host.data_handle(), queries, queries_host.size(), stream); raft::copy(candidates_host.data_handle(), candidates.data_handle(), candidates_host.size(), @@ -194,6 +204,10 @@ void RaftIvfPQ::search(const T* queries, auto dataset_v = raft::make_host_matrix_view( dataset_.data_handle(), dataset_.extent(0), dataset_.extent(1)); + // wait for the queries to copy to host in 'stream` and for IVF-PQ::search to finish + RAFT_CUDA_TRY(cudaEventRecord(sync_, resource::get_cuda_stream(handle_))); + RAFT_CUDA_TRY(cudaEventRecord(sync_, stream)); + RAFT_CUDA_TRY(cudaEventSynchronize(sync_)); raft::runtime::neighbors::refine(handle_, dataset_v, queries_host.view(), @@ -202,14 +216,8 @@ void RaftIvfPQ::search(const T* queries, distances_host.view(), index_->metric()); - raft::copy(neighbors, - (size_t*)neighbors_host.data_handle(), - neighbors_host.size(), - resource::get_cuda_stream(handle_)); - raft::copy(distances, - distances_host.data_handle(), - distances_host.size(), - resource::get_cuda_stream(handle_)); + raft::copy(neighbors, (size_t*)neighbors_host.data_handle(), neighbors_host.size(), stream); + raft::copy(distances, distances_host.data_handle(), distances_host.size(), stream); } } else { auto queries_v = @@ -219,8 +227,7 @@ void RaftIvfPQ::search(const T* queries, raft::runtime::neighbors::ivf_pq::search( handle_, search_params_, *index_, queries_v, neighbors_v, distances_v); + stream_wait(stream); // RAFT stream -> bench stream } - resource::sync_stream(handle_); - return; } } // namespace raft::bench::ann diff --git a/docs/source/ann_benchmarks_param_tuning.md b/docs/source/ann_benchmarks_param_tuning.md index 020c2d5ad9..ca8ffa5e18 100644 --- a/docs/source/ann_benchmarks_param_tuning.md +++ b/docs/source/ann_benchmarks_param_tuning.md @@ -1,6 +1,6 @@ # ANN Benchmarks Parameter Tuning Guide -This guide outlines the various parameter settings that can be specified in [RAFT ANN Benchmark](raft_ann_benchmarks.md) json configuration files and explains the impact they have on corresponding algorithms to help inform their settings for benchmarking across desired levels of recall. +This guide outlines the various parameter settings that can be specified in [RAFT ANN Benchmark](raft_ann_benchmarks.md) json configuration files and explains the impact they have on corresponding algorithms to help inform their settings for benchmarking across desired levels of recall. ## RAFT Indexes @@ -15,8 +15,8 @@ IVF-flat is a simple algorithm which won't save any space, but it provides compe |-----------|------------------|----------|---------------------|---------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| | `nlists` | `build_param` | Y | Positive Integer >0 | | Number of clusters to partition the vectors into. Larger values will put less points into each cluster but this will impact index build time as more clusters need to be trained. | | `niter` | `build_param` | N | Positive Integer >0 | 20 | Number of clusters to partition the vectors into. Larger values will put less points into each cluster but this will impact index build time as more clusters need to be trained. | -| `ratio` | `build_param` | N | Positive Float >0 | 0.5 | Fraction of the number of training points which should be used to train the clusters. | -| `nprobe` | `search_params` | Y | Positive Integer >0 | | The closest number of clusters to search for each query vector. Larger values will improve recall but will search more points in the index. | +| `ratio` | `build_param` | N | Positive Integer >0 | 2 | `1/ratio` is the number of training points which should be used to train the clusters. | +| `nprobe` | `search_params` | Y | Positive Integer >0 | | The closest number of clusters to search for each query vector. Larger values will improve recall but will search more points in the index. | ### `raft_ivf_pq` @@ -27,8 +27,10 @@ IVF-pq is an inverted-file index, which partitions the vectors into a series of |-------------------------|----------------|---|------------------------------|---------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| | `nlists` | `build_param` | Y | Positive Integer >0 | | Number of clusters to partition the vectors into. Larger values will put less points into each cluster but this will impact index build time as more clusters need to be trained. | | `niter` | `build_param` | N | Positive Integer >0 | 20 | Number of k-means iterations to use when training the clusters. | +| `ratio` | `build_param` | N | Positive Integer >0 | 2 | `1/ratio` is the number of training points which should be used to train the clusters. | | `pq_dim` | `build_param` | N | Positive Integer. Multiple of 8. | 0 | Dimensionality of the vector after product quantization. When 0, a heuristic is used to select this value. `pq_dim` * `pq_bits` must be a multiple of 8. | | `pq_bits` | `build_param` | N | Positive Integer. [4-8] | 8 | Bit length of the vector element after quantization. | +| `codebook_kind` | `build_param` | N | ["cluster", "subspace"] | "subspace" | Type of codebook. See the [API docs](https://docs.rapids.ai/api/raft/nightly/cpp_api/neighbors_ivf_pq/#_CPPv412codebook_gen) for more detail | | `nprobe` | `search_params` | Y | Positive Integer >0 | | The closest number of clusters to search for each query vector. Larger values will improve recall but will search more points in the index. | | `internalDistanceDtype` | `search_params` | N | [`float`, `half`] | `half` | The precision to use for the distance computations. Lower precision can increase performance at the cost of accuracy. | | `smemLutDtype` | `search_params` | N | [`float`, `half`, `fp8`] | `half` | The precision to use for the lookup table in shared memory. Lower precision can increase performance at the cost of accuracy. | @@ -58,7 +60,8 @@ IVF-flat is a simple algorithm which won't save any space, but it provides compe | Parameter | Type | Required | Data Type | Default | Description | |-----------|----------------|----------|---------------------|---------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| -| `nlists` | `build_param` | Y | Positive Integer >0 | | Number of clusters to partition the vectors into. Larger values will put less points into each cluster but this will impact index build time as more clusters need to be trained. | +| `nlists` | `build_param` | Y | Positive Integer >0 | | Number of clusters to partition the vectors into. Larger values will put less points into each cluster but this will impact index build time as more clusters need to be trained. | +| `ratio` | `build_param` | N | Positive Integer >0 | 2 | `1/ratio` is the number of training points which should be used to train the clusters. | | `nprobe` | `search_params` | Y | Positive Integer >0 | | The closest number of clusters to search for each query vector. Larger values will improve recall but will search more points in the index. | ### `faiss_gpu_ivf_pq` @@ -68,6 +71,7 @@ IVF-pq is an inverted-file index, which partitions the vectors into a series of | Parameter | Type | Required | Data Type | Default | Description | |------------------|----------------|----------|----------------------------------|---------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| | `nlists` | `build_param` | Y | Positive Integer >0 | | Number of clusters to partition the vectors into. Larger values will put less points into each cluster but this will impact index build time as more clusters need to be trained. | +| `ratio` | `build_param` | N | Positive Integer >0 | 2 | `1/ratio` is the number of training points which should be used to train the clusters. | | `M` | `build_param` | Y | Positive Integer Power of 2 [8-64] | | Number of chunks or subquantizers for each vector. | | `usePrecomputed` | `build_param` | N | Boolean. Default=`false` | `false` | Use pre-computed lookup tables to speed up search at the cost of increased memory usage. | | `useFloat16` | `build_param` | N | Boolean. Default=`false` | `false` | Use half-precision floats for clustering step. |