From 1a78fb835986d27b7ea7a140324dffa642442f6d Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Wed, 18 Oct 2023 00:27:09 +0200 Subject: [PATCH] Add bench for IVF filtering --- cpp/bench/prims/CMakeLists.txt | 2 + cpp/bench/prims/neighbors/knn.cuh | 123 +++++++++++++++++- .../knn/ivf_flat_filter_float_int64_t.cu | 24 ++++ .../knn/ivf_pq_filter_float_int64_t.cu | 24 ++++ cpp/include/raft/neighbors/sample_filter.cuh | 4 +- 5 files changed, 172 insertions(+), 5 deletions(-) create mode 100644 cpp/bench/prims/neighbors/knn/ivf_flat_filter_float_int64_t.cu create mode 100644 cpp/bench/prims/neighbors/knn/ivf_pq_filter_float_int64_t.cu diff --git a/cpp/bench/prims/CMakeLists.txt b/cpp/bench/prims/CMakeLists.txt index 5da2cd916b..fe58453d0d 100644 --- a/cpp/bench/prims/CMakeLists.txt +++ b/cpp/bench/prims/CMakeLists.txt @@ -147,10 +147,12 @@ if(BUILD_PRIMS_BENCH) bench/prims/neighbors/knn/brute_force_float_int64_t.cu bench/prims/neighbors/knn/brute_force_float_uint32_t.cu bench/prims/neighbors/knn/cagra_float_uint32_t.cu + bench/prims/neighbors/knn/ivf_flat_filter_float_int64_t.cu bench/prims/neighbors/knn/ivf_flat_float_int64_t.cu bench/prims/neighbors/knn/ivf_flat_int8_t_int64_t.cu bench/prims/neighbors/knn/ivf_flat_uint8_t_int64_t.cu bench/prims/neighbors/knn/ivf_pq_float_int64_t.cu + bench/prims/neighbors/knn/ivf_pq_filter_float_int64_t.cu bench/prims/neighbors/knn/ivf_pq_int8_t_int64_t.cu bench/prims/neighbors/knn/ivf_pq_uint8_t_int64_t.cu bench/prims/neighbors/refine_float_int64_t.cu diff --git a/cpp/bench/prims/neighbors/knn.cuh b/cpp/bench/prims/neighbors/knn.cuh index 31ac869b37..2f2ee510a9 100644 --- a/cpp/bench/prims/neighbors/knn.cuh +++ b/cpp/bench/prims/neighbors/knn.cuh @@ -21,9 +21,12 @@ #include +#include #include #include +#include #include +#include #include #include @@ -31,6 +34,8 @@ #include #include +#include + #include namespace raft::bench::spatial { @@ -44,11 +49,14 @@ struct params { size_t n_queries; /** Number of nearest neighbours to find for
every probe. */ size_t k; + /** Ratio of removed indices: the fraction of n_samples that the filter benchmarks (ivf_flat_filter_knn, ivf_pq_filter_knn) mark as removed; 0 makes them run the plain, unfiltered search. */ + double removed_ratio; }; inline auto operator<<(std::ostream& os, const params& p) -> std::ostream& { - os << p.n_samples << "#" << p.n_dims << "#" << p.n_queries << "#" << p.k; + os << p.n_samples << "#" << p.n_dims << "#" << p.n_queries << "#" << p.k << "#" + << p.removed_ratio; return os; } @@ -221,6 +229,108 @@ struct brute_force_knn { } }; +template +struct ivf_flat_filter_knn { /* IVF-Flat benchmark target: the constructor builds the index and the removed-id bitset, search() runs one query batch. */ + using dist_t = float; + + std::optional> index; + raft::neighbors::ivf_flat::index_params index_params; + raft::neighbors::ivf_flat::search_params search_params; + raft::core::bitset removed_indices_bitset_; + params ps; + + ivf_flat_filter_knn(const raft::device_resources& handle, const params& ps, const ValT* data) + : removed_indices_bitset_(handle, ps.n_samples), ps(ps) /* listed in member declaration order; ps.n_samples reads the constructor argument, not the member */ + { + index_params.n_lists = 4096; + index_params.metric = raft::distance::DistanceType::L2Expanded; + index.emplace(raft::neighbors::ivf_flat::build( + handle, index_params, data, IdxT(ps.n_samples), uint32_t(ps.n_dims))); /* next: removed ids are the consecutive values written by thrust::sequence (0, 1, 2, ...), removed_ratio * n_samples of them */ + auto removed_indices = + raft::make_device_vector(handle, ps.removed_ratio * ps.n_samples); + thrust::sequence( + resource::get_thrust_policy(handle), + thrust::device_pointer_cast(removed_indices.data_handle()), + thrust::device_pointer_cast(removed_indices.data_handle() + removed_indices.extent(0))); + removed_indices_bitset_.set(handle, removed_indices.view()); + } + + void search(const raft::device_resources& handle, + const ValT* search_items, + dist_t* out_dists, + IdxT* out_idxs) + { + search_params.n_probes = 20; + auto queries_view = + raft::make_device_matrix_view(search_items, ps.n_queries, ps.n_dims); + auto neighbors_view = raft::make_device_matrix_view(out_idxs, ps.n_queries, ps.k); + auto distance_view = raft::make_device_matrix_view(out_dists, ps.n_queries, ps.k); + auto filter = raft::neighbors::filtering::ivf_to_sample_filter( + index->inds_ptrs().data_handle(),
raft::neighbors::filtering::bitset_filter(removed_indices_bitset_.view())); + + if (ps.removed_ratio > 0) { /* removed_ratio == 0 takes the unfiltered search path below */ + raft::neighbors::ivf_flat::search_with_filtering( + handle, search_params, *index, queries_view, neighbors_view, distance_view, filter); + } else { + raft::neighbors::ivf_flat::search( + handle, search_params, *index, queries_view, neighbors_view, distance_view); + } + } +}; + +template +struct ivf_pq_filter_knn { /* IVF-PQ benchmark target: the constructor builds the index and the removed-id bitset, search() runs one query batch. */ + using dist_t = float; + + std::optional> index; + raft::neighbors::ivf_pq::index_params index_params; + raft::neighbors::ivf_pq::search_params search_params; + raft::core::bitset removed_indices_bitset_; + params ps; + + ivf_pq_filter_knn(const raft::device_resources& handle, const params& ps, const ValT* data) + : removed_indices_bitset_(handle, ps.n_samples), ps(ps) /* listed in member declaration order; ps.n_samples reads the constructor argument, not the member */ + { + index_params.n_lists = 4096; + index_params.metric = raft::distance::DistanceType::L2Expanded; + auto data_view = raft::make_device_matrix_view(data, ps.n_samples, ps.n_dims); + index.emplace(raft::neighbors::ivf_pq::build(handle, index_params, data_view)); /* next: removed ids are the consecutive values written by thrust::sequence (0, 1, 2, ...), removed_ratio * n_samples of them */ + auto removed_indices = + raft::make_device_vector(handle, ps.removed_ratio * ps.n_samples); + thrust::sequence( + resource::get_thrust_policy(handle), + thrust::device_pointer_cast(removed_indices.data_handle()), + thrust::device_pointer_cast(removed_indices.data_handle() + removed_indices.extent(0))); + removed_indices_bitset_.set(handle, removed_indices.view()); + } + + void search(const raft::device_resources& handle, + const ValT* search_items, + dist_t* out_dists, + IdxT* out_idxs) + { + search_params.n_probes = 20; + auto queries_view = + raft::make_device_matrix_view(search_items, ps.n_queries, ps.n_dims); + auto neighbors_view = + raft::make_device_matrix_view(out_idxs, ps.n_queries, ps.k); + auto distance_view = + raft::make_device_matrix_view(out_dists, ps.n_queries, ps.k); + auto filter = raft::neighbors::filtering::ivf_to_sample_filter( + index->inds_ptrs().data_handle(),
raft::neighbors::filtering::bitset_filter(removed_indices_bitset_.view())); + + if (ps.removed_ratio > 0) { /* removed_ratio == 0 takes the unfiltered search path below */ + raft::neighbors::ivf_pq::search_with_filtering( + handle, search_params, *index, queries_view, neighbors_view, distance_view, filter); + } else { + raft::neighbors::ivf_pq::search( + handle, search_params, *index, queries_view, neighbors_view, distance_view); + } + } +}; + template struct knn : public fixture { explicit knn(const params& p, const TransferStrategy& strategy, const Scope& scope) @@ -378,8 +488,15 @@ struct knn : public fixture { }; inline const std::vector kInputs{ - {2000000, 128, 1000, 32}, {10000000, 128, 1000, 32}, {10000, 8192, 1000, 32}}; - + {2000000, 128, 1000, 32, 0}, {10000000, 128, 1000, 32, 0}, {10000, 8192, 1000, 32, 0}}; + +inline const std::vector kInputsFilter = + raft::util::itertools::product({size_t(10000000)}, // n_samples + {size_t(128)}, // n_dims + {size_t(1000)}, // n_queries + {size_t(255)}, // k + {0.0, 0.02, 0.04, 0.08, 0.16, 0.32, 0.64} // removed_ratio + ); inline const std::vector kAllStrategies{ TransferStrategy::NO_COPY, TransferStrategy::MAP_PINNED, TransferStrategy::MANAGED}; inline const std::vector kNoCopyOnly{TransferStrategy::NO_COPY}; diff --git a/cpp/bench/prims/neighbors/knn/ivf_flat_filter_float_int64_t.cu b/cpp/bench/prims/neighbors/knn/ivf_flat_filter_float_int64_t.cu new file mode 100644 index 0000000000..bf5118ceae --- /dev/null +++ b/cpp/bench/prims/neighbors/knn/ivf_flat_filter_float_int64_t.cu @@ -0,0 +1,24 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#undef RAFT_EXPLICIT_INSTANTIATE_ONLY // Enable instantiation of search with filter +#include "../knn.cuh" + +namespace raft::bench::spatial { + +KNN_REGISTER(float, int64_t, ivf_flat_filter_knn, kInputsFilter, kNoCopyOnly, kScopeFull); + +} // namespace raft::bench::spatial diff --git a/cpp/bench/prims/neighbors/knn/ivf_pq_filter_float_int64_t.cu b/cpp/bench/prims/neighbors/knn/ivf_pq_filter_float_int64_t.cu new file mode 100644 index 0000000000..9534515cbb --- /dev/null +++ b/cpp/bench/prims/neighbors/knn/ivf_pq_filter_float_int64_t.cu @@ -0,0 +1,24 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +#undef RAFT_EXPLICIT_INSTANTIATE_ONLY // Enable instantiation of search with filter +#include "../knn.cuh" + +namespace raft::bench::spatial { + +KNN_REGISTER(float, int64_t, ivf_pq_filter_knn, kInputsFilter, kNoCopyOnly, kScopeFull); + +} // namespace raft::bench::spatial diff --git a/cpp/include/raft/neighbors/sample_filter.cuh b/cpp/include/raft/neighbors/sample_filter.cuh index 63384e50bd..f79394adf9 100644 --- a/cpp/include/raft/neighbors/sample_filter.cuh +++ b/cpp/include/raft/neighbors/sample_filter.cuh @@ -56,10 +56,10 @@ struct bitset_filter { */ template struct ivf_to_sample_filter { - index_t** const inds_ptrs_; + const index_t* const* inds_ptrs_; const filter_t next_filter_; - ivf_to_sample_filter(index_t** const inds_ptrs, const filter_t next_filter) + ivf_to_sample_filter(const index_t* const* inds_ptrs, const filter_t next_filter) : inds_ptrs_{inds_ptrs}, next_filter_{next_filter} { }