Skip to content

Commit

Permalink
Add tests for IVF-PQ filtering
Browse files Browse the repository at this point in the history
  • Loading branch information
lowener committed Oct 17, 2023
1 parent 57d3b12 commit ae62ee3
Show file tree
Hide file tree
Showing 8 changed files with 254 additions and 5 deletions.
4 changes: 3 additions & 1 deletion cpp/test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -379,14 +379,16 @@ if(BUILD_TESTS)
NAME
NEIGHBORS_ANN_IVF_TEST
PATH
test/neighbors/ann_ivf_flat/test_filter_float_int64_t.cu
test/neighbors/ann_ivf_flat/test_float_int64_t.cu
test/neighbors/ann_ivf_flat/test_int8_t_int64_t.cu
test/neighbors/ann_ivf_flat/test_uint8_t_int64_t.cu
test/neighbors/ann_ivf_pq/test_float_int64_t.cu
test/neighbors/ann_ivf_pq/test_float_uint32_t.cu
test/neighbors/ann_ivf_pq/test_float_int64_t.cu
test/neighbors/ann_ivf_pq/test_int8_t_int64_t.cu
test/neighbors/ann_ivf_pq/test_uint8_t_int64_t.cu
test/neighbors/ann_ivf_pq/test_filter_float_int64_t.cu
test/neighbors/ann_ivf_pq/test_filter_int8_t_int64_t.cu
LIB
EXPLICIT_INSTANTIATE_ONLY
GPUS
Expand Down
2 changes: 0 additions & 2 deletions cpp/test/neighbors/ann_ivf_flat.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@
*/
#pragma once

#undef RAFT_EXPLICIT_INSTANTIATE_ONLY // Enable instantiation of search with filter

#include "../test_utils.cuh"
#include "ann_utils.cuh"
#include <raft/core/device_mdarray.hpp>
Expand Down
29 changes: 29 additions & 0 deletions cpp/test/neighbors/ann_ivf_flat/test_filter_float_int64_t.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <gtest/gtest.h>

#undef RAFT_EXPLICIT_INSTANTIATE_ONLY // Enable instantiation of search with filter
#include "../ann_ivf_flat.cuh"

namespace raft::neighbors::ivf_flat {

// Fixture alias for the filtered-search variant: float data, float distances, 64-bit indices.
using AnnIVFFlatFilterTestF = AnnIVFFlatTest<float, float, std::int64_t>;

// Runs only the filtering check of the shared fixture for every parameter set.
TEST_P(AnnIVFFlatFilterTestF, AnnIVFFlatFilter)
{
  this->testFilter();
}

// Feed the fixture with the `inputs` parameter list.
INSTANTIATE_TEST_CASE_P(AnnIVFFlatTest, AnnIVFFlatFilterTestF, ::testing::ValuesIn(inputs));

}  // namespace raft::neighbors::ivf_flat
1 change: 0 additions & 1 deletion cpp/test/neighbors/ann_ivf_flat/test_float_int64_t.cu
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ TEST_P(AnnIVFFlatTestF, AnnIVFFlat)
{
this->testIVFFlat();
this->testPacker();
this->testFilter();
}

INSTANTIATE_TEST_CASE_P(AnnIVFFlatTest, AnnIVFFlatTestF, ::testing::ValuesIn(inputs));
Expand Down
165 changes: 165 additions & 0 deletions cpp/test/neighbors/ann_ivf_pq.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include <raft/neighbors/ivf_pq.cuh>
#include <raft/neighbors/ivf_pq_helpers.cuh>
#include <raft/neighbors/ivf_pq_serialize.cuh>
#include <raft/neighbors/sample_filter.cuh>
#include <raft/random/rng.cuh>

#include <rmm/cuda_stream_view.hpp>
Expand All @@ -39,6 +40,7 @@
#include <gtest/gtest.h>

#include <cub/cub.cuh>
#include <thrust/sequence.h>

#include <algorithm>
#include <cstddef>
Expand All @@ -48,6 +50,10 @@

namespace raft::neighbors::ivf_pq {

// Shared constant for the IVF-PQ filtering tests in this header.
struct test_ivf_sample_filter {
// Number of leading database rows treated as removed: the filter test builds its
// list of removed ids as a sequence of this length, and its brute-force reference
// skips this many rows at the start of the database.
static constexpr unsigned offset = 3000;
};

struct ivf_pq_inputs {
uint32_t num_db_vecs = 4096;
uint32_t num_queries = 1024;
Expand Down Expand Up @@ -499,6 +505,165 @@ class ivf_pq_test : public ::testing::TestWithParam<ivf_pq_inputs> {
std::vector<EvalT> distances_ref; // NOLINT
};

/**
 * Test fixture for IVF-PQ search combined with a sample filter.
 *
 * The fixture generates a random database and query set, computes a brute-force
 * reference over the database rows starting at `test_ivf_sample_filter::offset`
 * (the leading rows are the ones handed to the filter as "removed"), and then
 * compares a filtered IVF-PQ search against that reference.
 *
 * @tparam EvalT type of the distances being compared
 * @tparam DataT element type of the database and the queries
 * @tparam IdxT  type of the neighbor indices
 */
template <typename EvalT, typename DataT, typename IdxT>
class ivf_pq_filter_test : public ::testing::TestWithParam<ivf_pq_inputs> {
 public:
  ivf_pq_filter_test()
    : stream_(resource::get_cuda_stream(handle_)),
      ps(::testing::TestWithParam<ivf_pq_inputs>::GetParam()),
      database(0, stream_),
      search_queries(0, stream_)
  {
  }

  /** Fill `database` and `search_queries` with reproducible random values (fixed seed). */
  void gen_data()
  {
    database.resize(size_t{ps.num_db_vecs} * size_t{ps.dim}, stream_);
    search_queries.resize(size_t{ps.num_queries} * size_t{ps.dim}, stream_);

    // Use the (64-bit) buffer sizes as generation lengths so that the element count
    // handed to the generator can never disagree with the allocation above because
    // of a 32-bit overflow in `num_* * dim`.
    raft::random::RngState r(1234ULL);
    if constexpr (std::is_same<DataT, float>{}) {
      raft::random::uniform(handle_, r, database.data(), database.size(), DataT(0.1), DataT(2.0));
      raft::random::uniform(
        handle_, r, search_queries.data(), search_queries.size(), DataT(0.1), DataT(2.0));
    } else {
      raft::random::uniformInt(handle_, r, database.data(), database.size(), DataT(1), DataT(20));
      raft::random::uniformInt(
        handle_, r, search_queries.data(), search_queries.size(), DataT(1), DataT(20));
    }
    resource::sync_stream(handle_);
  }

  /**
   * Compute the brute-force reference on the part of the database that is not removed,
   * and store the resulting distances / indices in `distances_ref` / `indices_ref`.
   */
  void calc_ref()
  {
    const uint32_t n_removed = test_ivf_sample_filter::offset;
    // Without this guard `num_db_vecs - offset` below wraps around and the reference
    // search reads far past the end of the database buffer.
    ASSERT_GT(ps.num_db_vecs, n_removed)
      << "the filter test needs more database vectors than it removes; " << ps;

    size_t queries_size = size_t{ps.num_queries} * size_t{ps.k};
    rmm::device_uvector<EvalT> distances_naive_dev(queries_size, stream_);
    rmm::device_uvector<IdxT> indices_naive_dev(queries_size, stream_);
    naive_knn<EvalT, DataT, IdxT>(handle_,
                                  distances_naive_dev.data(),
                                  indices_naive_dev.data(),
                                  search_queries.data(),
                                  // element offset computed in 64 bits
                                  database.data() + size_t{n_removed} * size_t{ps.dim},
                                  ps.num_queries,
                                  ps.num_db_vecs - n_removed,
                                  ps.dim,
                                  ps.k,
                                  ps.index_params.metric);
    // The reference was computed on a shifted view of the database; shift the indices
    // back so they refer to rows of the full database.
    raft::linalg::addScalar(indices_naive_dev.data(),
                            indices_naive_dev.data(),
                            IdxT(test_ivf_sample_filter::offset),
                            queries_size,
                            stream_);
    distances_ref.resize(queries_size);
    update_host(distances_ref.data(), distances_naive_dev.data(), queries_size, stream_);
    indices_ref.resize(queries_size);
    update_host(indices_ref.data(), indices_naive_dev.data(), queries_size, stream_);
    resource::sync_stream(handle_);
  }

  /** Build an IVF-PQ index over the whole database, adding the data at build time. */
  auto build_only()
  {
    auto ipams              = ps.index_params;
    ipams.add_data_on_build = true;

    auto index_view =
      raft::make_device_matrix_view<DataT, IdxT>(database.data(), ps.num_db_vecs, ps.dim);
    return ivf_pq::build<DataT, IdxT>(handle_, ipams, index_view);
  }

  /**
   * Build an index with `build_index`, run a filtered search and compare it with the
   * reference computed in `calc_ref()`.
   *
   * @param build_index callable returning the `index<IdxT>` to search
   */
  template <typename BuildIndex>
  void run(BuildIndex build_index)
  {
    index<IdxT> index = build_index();

    double compression_ratio =
      static_cast<double>(ps.dim * 8) / static_cast<double>(index.pq_dim() * index.pq_bits());
    // Widen before multiplying, consistently with calc_ref().
    size_t queries_size = size_t{ps.num_queries} * size_t{ps.k};
    std::vector<IdxT> indices_ivf_pq(queries_size);
    std::vector<EvalT> distances_ivf_pq(queries_size);

    rmm::device_uvector<EvalT> distances_ivf_pq_dev(queries_size, stream_);
    rmm::device_uvector<IdxT> indices_ivf_pq_dev(queries_size, stream_);

    auto query_view =
      raft::make_device_matrix_view<DataT, uint32_t>(search_queries.data(), ps.num_queries, ps.dim);
    auto inds_view = raft::make_device_matrix_view<IdxT, uint32_t>(
      indices_ivf_pq_dev.data(), ps.num_queries, ps.k);
    auto dists_view = raft::make_device_matrix_view<EvalT, uint32_t>(
      distances_ivf_pq_dev.data(), ps.num_queries, ps.k);

    // Create Bitset filter: the removed ids are the sequence 0, 1, ..., offset - 1.
    auto removed_indices =
      raft::make_device_vector<IdxT, int64_t>(handle_, test_ivf_sample_filter::offset);
    thrust::sequence(
      resource::get_thrust_policy(handle_),
      thrust::device_pointer_cast(removed_indices.data_handle()),
      thrust::device_pointer_cast(removed_indices.data_handle() + test_ivf_sample_filter::offset));
    resource::sync_stream(handle_);

    // NOTE(review): presumably the bitset starts with every sample enabled and the
    // listed ids are cleared -- confirm against raft::core::bitset.
    raft::core::bitset<std::uint32_t, IdxT> removed_indices_bitset(
      handle_, removed_indices.view(), ps.num_db_vecs);
    ivf_pq::search_with_filtering<DataT, IdxT>(
      handle_,
      ps.search_params,
      index,
      query_view,
      inds_view,
      dists_view,
      raft::neighbors::filtering::ivf_to_sample_filter(
        index.inds_ptrs().data_handle(),
        raft::neighbors::filtering::bitset_filter(removed_indices_bitset.view())));

    update_host(distances_ivf_pq.data(), distances_ivf_pq_dev.data(), queries_size, stream_);
    update_host(indices_ivf_pq.data(), indices_ivf_pq_dev.data(), queries_size, stream_);
    resource::sync_stream(handle_);

    // A very conservative lower bound on recall
    double min_recall =
      static_cast<double>(ps.search_params.n_probes) / static_cast<double>(ps.index_params.n_lists);
    // Using a heuristic to lower the required recall due to code-packing errors
    min_recall =
      std::min(std::erfc(0.05 * compression_ratio / std::max(min_recall, 0.5)), min_recall);
    // Use explicit per-test min recall value if provided.
    min_recall = ps.min_recall.value_or(min_recall);

    ASSERT_TRUE(eval_neighbours(indices_ref,
                                indices_ivf_pq,
                                distances_ref,
                                distances_ivf_pq,
                                ps.num_queries,
                                ps.k,
                                0.0001 * compression_ratio,
                                min_recall))
      << ps;
  }

  void SetUp() override  // NOLINT
  {
    gen_data();
    calc_ref();
  }

  void TearDown() override  // NOLINT
  {
    // The returned error code is deliberately discarded; the call is only made for
    // its effect on the CUDA runtime's last-error state.
    cudaGetLastError();
    resource::sync_stream(handle_);
    database.resize(0, stream_);
    search_queries.resize(0, stream_);
  }

 private:
  raft::resources handle_;
  rmm::cuda_stream_view stream_;
  ivf_pq_inputs ps;                           // NOLINT
  rmm::device_uvector<DataT> database;        // NOLINT
  rmm::device_uvector<DataT> search_queries;  // NOLINT
  std::vector<IdxT> indices_ref;              // NOLINT
  std::vector<EvalT> distances_ref;           // NOLINT
};

/* Test cases */
using test_cases_t = std::vector<ivf_pq_inputs>;

Expand Down
26 changes: 26 additions & 0 deletions cpp/test/neighbors/ann_ivf_pq/test_filter_float_int64_t.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#undef RAFT_EXPLICIT_INSTANTIATE_ONLY // Enable instantiation of search with filter
#include "../ann_ivf_pq.cuh"

namespace raft::neighbors::ivf_pq {

// Filter-test fixture instantiated for float distances, float data and 64-bit indices.
using f32_f32_i64_filter = ivf_pq_filter_test<float, float, int64_t>;

// TEST_BUILD_SEARCH / INSTANTIATE are macros provided by the included ann_ivf_pq.cuh;
// NOTE(review): presumably they declare the build+search test and register the
// parameter sets for it -- confirm against the macro definitions.
TEST_BUILD_SEARCH(f32_f32_i64_filter)
INSTANTIATE(f32_f32_i64_filter, defaults() + big_dims_moderate_lut());
} // namespace raft::neighbors::ivf_pq
27 changes: 27 additions & 0 deletions cpp/test/neighbors/ann_ivf_pq/test_filter_int8_t_int64_t.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#undef RAFT_EXPLICIT_INSTANTIATE_ONLY // Enable instantiation of search with filter
#include "../ann_ivf_pq.cuh"

namespace raft::neighbors::ivf_pq {

// Filter-test fixture instantiated for float distances, int8 data and 64-bit indices.
using f32_i08_i64_filter = ivf_pq_filter_test<float, int8_t, int64_t>;

// TEST_BUILD_SEARCH / INSTANTIATE are macros provided by the included ann_ivf_pq.cuh;
// NOTE(review): presumably they declare the build+search test and register the
// parameter sets for it -- confirm against the macro definitions.
TEST_BUILD_SEARCH(f32_i08_i64_filter)
INSTANTIATE(f32_i08_i64_filter, big_dims());

} // namespace raft::neighbors::ivf_pq
5 changes: 4 additions & 1 deletion cpp/test/neighbors/ann_ivf_pq/test_float_uint32_t.cu
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,13 @@

namespace raft::neighbors::ivf_pq {

// Regular and filtered IVF-PQ fixtures for float distances, float data and 32-bit
// unsigned indices. Each alias is declared exactly once.
using f32_f32_u32        = ivf_pq_test<float, float, uint32_t>;
using f32_f32_u32_filter = ivf_pq_filter_test<float, float, uint32_t>;

// Unfiltered search, with and without an index serialization round trip.
TEST_BUILD_SEARCH(f32_f32_u32)
TEST_BUILD_SERIALIZE_SEARCH(f32_f32_u32)
INSTANTIATE(f32_f32_u32, defaults() + var_n_probes() + var_k() + special_cases());

// Filtered search on the default parameter sets only.
TEST_BUILD_SEARCH(f32_f32_u32_filter)
INSTANTIATE(f32_f32_u32_filter, defaults());
}  // namespace raft::neighbors::ivf_pq

0 comments on commit ae62ee3

Please sign in to comment.