From ae62ee3e5edc313d5613486bcac062ef9244bec6 Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Tue, 17 Oct 2023 16:59:02 +0200 Subject: [PATCH] Add tests for IVF-PQ filtering --- cpp/test/CMakeLists.txt | 4 +- cpp/test/neighbors/ann_ivf_flat.cuh | 2 - .../ann_ivf_flat/test_filter_float_int64_t.cu | 29 +++ .../ann_ivf_flat/test_float_int64_t.cu | 1 - cpp/test/neighbors/ann_ivf_pq.cuh | 165 ++++++++++++++++++ .../ann_ivf_pq/test_filter_float_int64_t.cu | 26 +++ .../ann_ivf_pq/test_filter_int8_t_int64_t.cu | 27 +++ .../ann_ivf_pq/test_float_uint32_t.cu | 5 +- 8 files changed, 254 insertions(+), 5 deletions(-) create mode 100644 cpp/test/neighbors/ann_ivf_flat/test_filter_float_int64_t.cu create mode 100644 cpp/test/neighbors/ann_ivf_pq/test_filter_float_int64_t.cu create mode 100644 cpp/test/neighbors/ann_ivf_pq/test_filter_int8_t_int64_t.cu diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index 9b9b882d1d..6c03da8d7f 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -379,14 +379,16 @@ if(BUILD_TESTS) NAME NEIGHBORS_ANN_IVF_TEST PATH + test/neighbors/ann_ivf_flat/test_filter_float_int64_t.cu test/neighbors/ann_ivf_flat/test_float_int64_t.cu test/neighbors/ann_ivf_flat/test_int8_t_int64_t.cu test/neighbors/ann_ivf_flat/test_uint8_t_int64_t.cu - test/neighbors/ann_ivf_pq/test_float_int64_t.cu test/neighbors/ann_ivf_pq/test_float_uint32_t.cu test/neighbors/ann_ivf_pq/test_float_int64_t.cu test/neighbors/ann_ivf_pq/test_int8_t_int64_t.cu test/neighbors/ann_ivf_pq/test_uint8_t_int64_t.cu + test/neighbors/ann_ivf_pq/test_filter_float_int64_t.cu + test/neighbors/ann_ivf_pq/test_filter_int8_t_int64_t.cu LIB EXPLICIT_INSTANTIATE_ONLY GPUS diff --git a/cpp/test/neighbors/ann_ivf_flat.cuh b/cpp/test/neighbors/ann_ivf_flat.cuh index 9e22faaecb..a23c890cf4 100644 --- a/cpp/test/neighbors/ann_ivf_flat.cuh +++ b/cpp/test/neighbors/ann_ivf_flat.cuh @@ -15,8 +15,6 @@ */ #pragma once -#undef RAFT_EXPLICIT_INSTANTIATE_ONLY // Enable instantiation of search with filter - #include "../test_utils.cuh" #include "ann_utils.cuh" #include diff --git a/cpp/test/neighbors/ann_ivf_flat/test_filter_float_int64_t.cu b/cpp/test/neighbors/ann_ivf_flat/test_filter_float_int64_t.cu new file mode 100644 index 0000000000..0e1036e566 --- /dev/null +++ b/cpp/test/neighbors/ann_ivf_flat/test_filter_float_int64_t.cu @@ -0,0 +1,29 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#undef RAFT_EXPLICIT_INSTANTIATE_ONLY // Enable instantiation of search with filter +#include "../ann_ivf_flat.cuh" + +namespace raft::neighbors::ivf_flat { + +typedef AnnIVFFlatTest AnnIVFFlatFilterTestF; +TEST_P(AnnIVFFlatFilterTestF, AnnIVFFlatFilter) { this->testFilter(); } + +INSTANTIATE_TEST_CASE_P(AnnIVFFlatTest, AnnIVFFlatFilterTestF, ::testing::ValuesIn(inputs)); + +} // namespace raft::neighbors::ivf_flat diff --git a/cpp/test/neighbors/ann_ivf_flat/test_float_int64_t.cu b/cpp/test/neighbors/ann_ivf_flat/test_float_int64_t.cu index d22c3837a3..3bfea283e5 100644 --- a/cpp/test/neighbors/ann_ivf_flat/test_float_int64_t.cu +++ b/cpp/test/neighbors/ann_ivf_flat/test_float_int64_t.cu @@ -25,7 +25,6 @@ TEST_P(AnnIVFFlatTestF, AnnIVFFlat) { this->testIVFFlat(); this->testPacker(); - this->testFilter(); } INSTANTIATE_TEST_CASE_P(AnnIVFFlatTest, AnnIVFFlatTestF, ::testing::ValuesIn(inputs)); diff --git a/cpp/test/neighbors/ann_ivf_pq.cuh b/cpp/test/neighbors/ann_ivf_pq.cuh index d1f5ee5b03..9aaa4d9b93 100644 --- a/cpp/test/neighbors/ann_ivf_pq.cuh +++ b/cpp/test/neighbors/ann_ivf_pq.cuh @@ -29,6 +29,7 @@ #include #include #include +#include #include #include @@ -39,6 +40,7 @@ #include #include +#include #include #include @@ -48,6 +50,10 @@ namespace raft::neighbors::ivf_pq { +struct test_ivf_sample_filter { + static constexpr unsigned offset = 3000; +}; + struct ivf_pq_inputs { uint32_t num_db_vecs = 4096; uint32_t num_queries = 1024; @@ -499,6 +505,165 @@ class ivf_pq_test : public ::testing::TestWithParam { std::vector distances_ref; // NOLINT }; +template +class ivf_pq_filter_test : public ::testing::TestWithParam { + public: + ivf_pq_filter_test() + : stream_(resource::get_cuda_stream(handle_)), + ps(::testing::TestWithParam::GetParam()), + database(0, stream_), + search_queries(0, stream_) + { + } + + void gen_data() + { + database.resize(size_t{ps.num_db_vecs} * size_t{ps.dim}, stream_); + search_queries.resize(size_t{ps.num_queries} * size_t{ps.dim}, stream_); + + raft::random::RngState r(1234ULL); + if constexpr (std::is_same{}) { + raft::random::uniform( + handle_, r, database.data(), ps.num_db_vecs * ps.dim, DataT(0.1), DataT(2.0)); + raft::random::uniform( + handle_, r, search_queries.data(), ps.num_queries * ps.dim, DataT(0.1), DataT(2.0)); + } else { + raft::random::uniformInt( + handle_, r, database.data(), ps.num_db_vecs * ps.dim, DataT(1), DataT(20)); + raft::random::uniformInt( + handle_, r, search_queries.data(), ps.num_queries * ps.dim, DataT(1), DataT(20)); + } + resource::sync_stream(handle_); + } + + void calc_ref() + { + size_t queries_size = size_t{ps.num_queries} * size_t{ps.k}; + rmm::device_uvector distances_naive_dev(queries_size, stream_); + rmm::device_uvector indices_naive_dev(queries_size, stream_); + naive_knn(handle_, + distances_naive_dev.data(), + indices_naive_dev.data(), + search_queries.data(), + database.data() + test_ivf_sample_filter::offset * ps.dim, + ps.num_queries, + ps.num_db_vecs - test_ivf_sample_filter::offset, + ps.dim, + ps.k, + ps.index_params.metric); + raft::linalg::addScalar(indices_naive_dev.data(), + indices_naive_dev.data(), + IdxT(test_ivf_sample_filter::offset), + queries_size, + stream_); + distances_ref.resize(queries_size); + update_host(distances_ref.data(), distances_naive_dev.data(), queries_size, stream_); + indices_ref.resize(queries_size); + update_host(indices_ref.data(), indices_naive_dev.data(), queries_size, stream_); + resource::sync_stream(handle_); + } + + auto build_only() + { + auto ipams = ps.index_params; + ipams.add_data_on_build = true; + + auto index_view = + raft::make_device_matrix_view(database.data(), ps.num_db_vecs, ps.dim); + return ivf_pq::build(handle_, ipams, index_view); + } + + template + void run(BuildIndex build_index) + { + index index = build_index(); + + double compression_ratio = + static_cast(ps.dim * 8) / static_cast(index.pq_dim() * index.pq_bits()); + size_t queries_size = ps.num_queries * ps.k; + std::vector indices_ivf_pq(queries_size); + std::vector distances_ivf_pq(queries_size); + + rmm::device_uvector distances_ivf_pq_dev(queries_size, stream_); + rmm::device_uvector indices_ivf_pq_dev(queries_size, stream_); + + auto query_view = + raft::make_device_matrix_view(search_queries.data(), ps.num_queries, ps.dim); + auto inds_view = raft::make_device_matrix_view( + indices_ivf_pq_dev.data(), ps.num_queries, ps.k); + auto dists_view = raft::make_device_matrix_view( + distances_ivf_pq_dev.data(), ps.num_queries, ps.k); + + // Create Bitset filter + auto removed_indices = + raft::make_device_vector(handle_, test_ivf_sample_filter::offset); + thrust::sequence( + resource::get_thrust_policy(handle_), + thrust::device_pointer_cast(removed_indices.data_handle()), + thrust::device_pointer_cast(removed_indices.data_handle() + test_ivf_sample_filter::offset)); + resource::sync_stream(handle_); + + raft::core::bitset removed_indices_bitset( + handle_, removed_indices.view(), ps.num_db_vecs); + ivf_pq::search_with_filtering( + handle_, + ps.search_params, + index, + query_view, + inds_view, + dists_view, + raft::neighbors::filtering::ivf_to_sample_filter( + index.inds_ptrs().data_handle(), + raft::neighbors::filtering::bitset_filter(removed_indices_bitset.view()))); + + update_host(distances_ivf_pq.data(), distances_ivf_pq_dev.data(), queries_size, stream_); + update_host(indices_ivf_pq.data(), indices_ivf_pq_dev.data(), queries_size, stream_); + resource::sync_stream(handle_); + + // A very conservative lower bound on recall + double min_recall = + static_cast(ps.search_params.n_probes) / static_cast(ps.index_params.n_lists); + // Using a heuristic to lower the required recall due to code-packing errors + min_recall = + std::min(std::erfc(0.05 * compression_ratio / std::max(min_recall, 0.5)), min_recall); + // Use explicit per-test min recall value if provided. + min_recall = ps.min_recall.value_or(min_recall); + + ASSERT_TRUE(eval_neighbours(indices_ref, + indices_ivf_pq, + distances_ref, + distances_ivf_pq, + ps.num_queries, + ps.k, + 0.0001 * compression_ratio, + min_recall)) + << ps; + } + + void SetUp() override // NOLINT + { + gen_data(); + calc_ref(); + } + + void TearDown() override // NOLINT + { + cudaGetLastError(); + resource::sync_stream(handle_); + database.resize(0, stream_); + search_queries.resize(0, stream_); + } + + private: + raft::resources handle_; + rmm::cuda_stream_view stream_; + ivf_pq_inputs ps; // NOLINT + rmm::device_uvector database; // NOLINT + rmm::device_uvector search_queries; // NOLINT + std::vector indices_ref; // NOLINT + std::vector distances_ref; // NOLINT +}; + /* Test cases */ using test_cases_t = std::vector; diff --git a/cpp/test/neighbors/ann_ivf_pq/test_filter_float_int64_t.cu b/cpp/test/neighbors/ann_ivf_pq/test_filter_float_int64_t.cu new file mode 100644 index 0000000000..17f72fb08a --- /dev/null +++ b/cpp/test/neighbors/ann_ivf_pq/test_filter_float_int64_t.cu @@ -0,0 +1,26 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#undef RAFT_EXPLICIT_INSTANTIATE_ONLY // Enable instantiation of search with filter +#include "../ann_ivf_pq.cuh" + +namespace raft::neighbors::ivf_pq { + +using f32_f32_i64_filter = ivf_pq_filter_test; + +TEST_BUILD_SEARCH(f32_f32_i64_filter) +INSTANTIATE(f32_f32_i64_filter, defaults() + big_dims_moderate_lut()); +} // namespace raft::neighbors::ivf_pq diff --git a/cpp/test/neighbors/ann_ivf_pq/test_filter_int8_t_int64_t.cu b/cpp/test/neighbors/ann_ivf_pq/test_filter_int8_t_int64_t.cu new file mode 100644 index 0000000000..537dbb4979 --- /dev/null +++ b/cpp/test/neighbors/ann_ivf_pq/test_filter_int8_t_int64_t.cu @@ -0,0 +1,27 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#undef RAFT_EXPLICIT_INSTANTIATE_ONLY // Enable instantiation of search with filter +#include "../ann_ivf_pq.cuh" + +namespace raft::neighbors::ivf_pq { + +using f32_i08_i64_filter = ivf_pq_filter_test; + +TEST_BUILD_SEARCH(f32_i08_i64_filter) +INSTANTIATE(f32_i08_i64_filter, big_dims()); + +} // namespace raft::neighbors::ivf_pq diff --git a/cpp/test/neighbors/ann_ivf_pq/test_float_uint32_t.cu b/cpp/test/neighbors/ann_ivf_pq/test_float_uint32_t.cu index 3d362a5261..5405ddc4a3 100644 --- a/cpp/test/neighbors/ann_ivf_pq/test_float_uint32_t.cu +++ b/cpp/test/neighbors/ann_ivf_pq/test_float_uint32_t.cu @@ -25,10 +25,13 @@ namespace raft::neighbors::ivf_pq { -using f32_f32_u32 = ivf_pq_test; +using f32_f32_u32 = ivf_pq_test; +using f32_f32_u32_filter = ivf_pq_filter_test; TEST_BUILD_SEARCH(f32_f32_u32) TEST_BUILD_SERIALIZE_SEARCH(f32_f32_u32) INSTANTIATE(f32_f32_u32, defaults() + var_n_probes() + var_k() + special_cases()); +TEST_BUILD_SEARCH(f32_f32_u32_filter) +INSTANTIATE(f32_f32_u32_filter, defaults()); } // namespace raft::neighbors::ivf_pq