Skip to content

Commit

Permalink
Add bench for IVF filtering
Browse files Browse the repository at this point in the history
  • Loading branch information
lowener committed Oct 17, 2023
1 parent ae62ee3 commit 1a78fb8
Show file tree
Hide file tree
Showing 5 changed files with 172 additions and 5 deletions.
2 changes: 2 additions & 0 deletions cpp/bench/prims/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -147,10 +147,12 @@ if(BUILD_PRIMS_BENCH)
bench/prims/neighbors/knn/brute_force_float_int64_t.cu
bench/prims/neighbors/knn/brute_force_float_uint32_t.cu
bench/prims/neighbors/knn/cagra_float_uint32_t.cu
bench/prims/neighbors/knn/ivf_flat_filter_float_int64_t.cu
bench/prims/neighbors/knn/ivf_flat_float_int64_t.cu
bench/prims/neighbors/knn/ivf_flat_int8_t_int64_t.cu
bench/prims/neighbors/knn/ivf_flat_uint8_t_int64_t.cu
bench/prims/neighbors/knn/ivf_pq_float_int64_t.cu
bench/prims/neighbors/knn/ivf_pq_filter_float_int64_t.cu
bench/prims/neighbors/knn/ivf_pq_int8_t_int64_t.cu
bench/prims/neighbors/knn/ivf_pq_uint8_t_int64_t.cu
bench/prims/neighbors/refine_float_int64_t.cu
Expand Down
123 changes: 120 additions & 3 deletions cpp/bench/prims/neighbors/knn.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,21 @@

#include <raft/random/rng.cuh>

#include <raft/core/bitset.cuh>
#include <raft/neighbors/ivf_flat.cuh>
#include <raft/neighbors/ivf_pq.cuh>
#include <raft/neighbors/sample_filter.cuh>
#include <raft/spatial/knn/knn.cuh>
#include <raft/util/itertools.hpp>

#include <rmm/mr/device/managed_memory_resource.hpp>
#include <rmm/mr/device/per_device_resource.hpp>

#include <rmm/mr/host/new_delete_resource.hpp>
#include <rmm/mr/host/pinned_memory_resource.hpp>

#include <thrust/sequence.h>

#include <optional>

namespace raft::bench::spatial {
Expand All @@ -44,11 +49,14 @@ struct params {
size_t n_queries;
/** Number of nearest neighbours to find for every query. */
size_t k;
/** Fraction of the n_samples dataset indices that are removed (filtered out); 0 disables filtering. */
double removed_ratio;
};

/**
 * Write a `params` value to a stream as `n_samples#n_dims#n_queries#k#removed_ratio`.
 *
 * Each field is written exactly once, separated by '#'.
 *
 * @param os output stream to write to
 * @param p  benchmark parameters to print
 * @return   `os`, so that calls can be chained
 */
inline auto operator<<(std::ostream& os, const params& p) -> std::ostream&
{
  os << p.n_samples << "#" << p.n_dims << "#" << p.n_queries << "#" << p.k << "#"
     << p.removed_ratio;
  return os;
}

Expand Down Expand Up @@ -221,6 +229,108 @@ struct brute_force_knn {
}
};

/**
 * kNN benchmark implementation backed by an IVF-Flat index, with optional sample filtering.
 *
 * The constructor builds the index over `data` and passes the first
 * `removed_ratio * n_samples` sample indices to the bitset's `set()`. `search()` runs a
 * filtered search through that bitset when `removed_ratio > 0`, and a plain IVF-Flat search
 * otherwise.
 *
 * @tparam ValT element type of the dataset and of the queries
 * @tparam IdxT type of the returned neighbor indices
 */
template <typename ValT, typename IdxT>
struct ivf_flat_filter_knn {
  using dist_t = float;

  // Built once in the constructor; optional because the index is created via emplace().
  std::optional<const raft::neighbors::ivf_flat::index<ValT, IdxT>> index;
  raft::neighbors::ivf_flat::index_params index_params;
  raft::neighbors::ivf_flat::search_params search_params;
  // One bit per dataset sample; consumed by bitset_filter in search().
  raft::core::bitset<std::uint32_t, IdxT> removed_indices_bitset_;
  params ps;

  /**
   * Build the index and populate the bitset of removed indices.
   *
   * @param handle raft resources used for the build and the device allocations
   * @param ps     benchmark parameters (sizes and removed ratio)
   * @param data   dataset pointer, interpreted as ps.n_samples x ps.n_dims values
   */
  ivf_flat_filter_knn(const raft::device_resources& handle, const params& ps, const ValT* data)
    // Listed in declaration order, which is the order the members are actually initialized in.
    // Inside this constructor `ps` names the parameter, not the member.
    : removed_indices_bitset_(handle, ps.n_samples), ps(ps)
  {
    index_params.n_lists = 4096;
    index_params.metric  = raft::distance::DistanceType::L2Expanded;
    index.emplace(raft::neighbors::ivf_flat::build(
      handle, index_params, data, IdxT(ps.n_samples), uint32_t(ps.n_dims)));

    // Truncate the fractional count explicitly instead of relying on an implicit
    // floating-point to integer conversion at the call site.
    const auto n_removed = static_cast<int64_t>(ps.removed_ratio * ps.n_samples);
    auto removed_indices = raft::make_device_vector<IdxT, int64_t>(handle, n_removed);
    // Fill with 0, 1, ..., n_removed - 1.
    thrust::sequence(
      resource::get_thrust_policy(handle),
      thrust::device_pointer_cast(removed_indices.data_handle()),
      thrust::device_pointer_cast(removed_indices.data_handle() + removed_indices.extent(0)));
    // NOTE(review): presumably set() clears the bits of the listed indices so that
    // bitset_filter rejects them -- confirm against raft::core::bitset.
    removed_indices_bitset_.set(handle, removed_indices.view());
  }

  /**
   * Search the index for the ps.k nearest neighbors of each of the ps.n_queries queries.
   *
   * @param handle       raft resources used for the search
   * @param search_items query pointer, viewed as ps.n_queries x ps.n_dims
   * @param out_dists    output distances, viewed as ps.n_queries x ps.k
   * @param out_idxs     output neighbor indices, viewed as ps.n_queries x ps.k
   */
  void search(const raft::device_resources& handle,
              const ValT* search_items,
              dist_t* out_dists,
              IdxT* out_idxs)
  {
    search_params.n_probes = 20;
    auto queries_view =
      raft::make_device_matrix_view<const ValT, IdxT>(search_items, ps.n_queries, ps.n_dims);
    auto neighbors_view = raft::make_device_matrix_view<IdxT, IdxT>(out_idxs, ps.n_queries, ps.k);
    auto distance_view = raft::make_device_matrix_view<dist_t, IdxT>(out_dists, ps.n_queries, ps.k);

    if (ps.removed_ratio > 0) {
      // The filter is only needed on this path, so it is built here rather than unconditionally.
      auto filter = raft::neighbors::filtering::ivf_to_sample_filter(
        index->inds_ptrs().data_handle(),
        raft::neighbors::filtering::bitset_filter(removed_indices_bitset_.view()));
      raft::neighbors::ivf_flat::search_with_filtering(
        handle, search_params, *index, queries_view, neighbors_view, distance_view, filter);
    } else {
      raft::neighbors::ivf_flat::search(
        handle, search_params, *index, queries_view, neighbors_view, distance_view);
    }
  }
};

/**
 * kNN benchmark implementation backed by an IVF-PQ index, with optional sample filtering.
 *
 * The constructor builds the index over `data` and passes the first
 * `removed_ratio * n_samples` sample indices to the bitset's `set()`. `search()` runs a
 * filtered search through that bitset when `removed_ratio > 0`, and a plain IVF-PQ search
 * otherwise.
 *
 * @tparam ValT element type of the dataset and of the queries
 * @tparam IdxT type of the returned neighbor indices
 */
template <typename ValT, typename IdxT>
struct ivf_pq_filter_knn {
  using dist_t = float;

  // Built once in the constructor; optional because the index is created via emplace().
  std::optional<const raft::neighbors::ivf_pq::index<IdxT>> index;
  raft::neighbors::ivf_pq::index_params index_params;
  raft::neighbors::ivf_pq::search_params search_params;
  // One bit per dataset sample; consumed by bitset_filter in search().
  raft::core::bitset<std::uint32_t, IdxT> removed_indices_bitset_;
  params ps;

  /**
   * Build the index and populate the bitset of removed indices.
   *
   * @param handle raft resources used for the build and the device allocations
   * @param ps     benchmark parameters (sizes and removed ratio)
   * @param data   dataset pointer, viewed as ps.n_samples x ps.n_dims
   */
  ivf_pq_filter_knn(const raft::device_resources& handle, const params& ps, const ValT* data)
    // Listed in declaration order, which is the order the members are actually initialized in.
    // Inside this constructor `ps` names the parameter, not the member.
    : removed_indices_bitset_(handle, ps.n_samples), ps(ps)
  {
    index_params.n_lists = 4096;
    index_params.metric  = raft::distance::DistanceType::L2Expanded;
    auto data_view = raft::make_device_matrix_view<const ValT, IdxT>(data, ps.n_samples, ps.n_dims);
    index.emplace(raft::neighbors::ivf_pq::build(handle, index_params, data_view));

    // Truncate the fractional count explicitly instead of relying on an implicit
    // floating-point to integer conversion at the call site.
    const auto n_removed = static_cast<int64_t>(ps.removed_ratio * ps.n_samples);
    auto removed_indices = raft::make_device_vector<IdxT, int64_t>(handle, n_removed);
    // Fill with 0, 1, ..., n_removed - 1.
    thrust::sequence(
      resource::get_thrust_policy(handle),
      thrust::device_pointer_cast(removed_indices.data_handle()),
      thrust::device_pointer_cast(removed_indices.data_handle() + removed_indices.extent(0)));
    // NOTE(review): presumably set() clears the bits of the listed indices so that
    // bitset_filter rejects them -- confirm against raft::core::bitset.
    removed_indices_bitset_.set(handle, removed_indices.view());
  }

  /**
   * Search the index for the ps.k nearest neighbors of each of the ps.n_queries queries.
   *
   * @param handle       raft resources used for the search
   * @param search_items query pointer, viewed as ps.n_queries x ps.n_dims
   * @param out_dists    output distances, viewed as ps.n_queries x ps.k
   * @param out_idxs     output neighbor indices, viewed as ps.n_queries x ps.k
   */
  void search(const raft::device_resources& handle,
              const ValT* search_items,
              dist_t* out_dists,
              IdxT* out_idxs)
  {
    search_params.n_probes = 20;
    // The views on this path use uint32_t as their extent/index type.
    auto queries_view =
      raft::make_device_matrix_view<const ValT, uint32_t>(search_items, ps.n_queries, ps.n_dims);
    auto neighbors_view =
      raft::make_device_matrix_view<IdxT, uint32_t>(out_idxs, ps.n_queries, ps.k);
    auto distance_view =
      raft::make_device_matrix_view<dist_t, uint32_t>(out_dists, ps.n_queries, ps.k);

    if (ps.removed_ratio > 0) {
      // The filter is only needed on this path, so it is built here rather than unconditionally.
      auto filter = raft::neighbors::filtering::ivf_to_sample_filter(
        index->inds_ptrs().data_handle(),
        raft::neighbors::filtering::bitset_filter(removed_indices_bitset_.view()));
      raft::neighbors::ivf_pq::search_with_filtering(
        handle, search_params, *index, queries_view, neighbors_view, distance_view, filter);
    } else {
      raft::neighbors::ivf_pq::search(
        handle, search_params, *index, queries_view, neighbors_view, distance_view);
    }
  }
};

template <typename ValT, typename IdxT, typename ImplT>
struct knn : public fixture {
explicit knn(const params& p, const TransferStrategy& strategy, const Scope& scope)
Expand Down Expand Up @@ -378,8 +488,15 @@ struct knn : public fixture {
};

/**
 * Default benchmark inputs, one entry per configuration:
 * {n_samples, n_dims, n_queries, k, removed_ratio}.
 * removed_ratio is 0 for all of them.
 */
inline const std::vector<params> kInputs{
  {2000000, 128, 1000, 32, 0}, {10000000, 128, 1000, 32, 0}, {10000, 8192, 1000, 32, 0}};

/**
 * Inputs for the filtered-search benchmarks, generated by raft::util::itertools::product
 * from the value lists below (presumably their cartesian product, i.e. one entry per
 * removed_ratio with the other fields fixed -- confirm against itertools.hpp).
 *
 * Declared `inline` like the other header-level constants in this file, so that every
 * translation unit including this header shares a single definition.
 */
inline const std::vector<params> kInputsFilter =
  raft::util::itertools::product<params>({size_t(10000000)},                         // n_samples
                                         {size_t(128)},                              // n_dim
                                         {size_t(1000)},                             // n_queries
                                         {size_t(255)},                              // k
                                         {0.0, 0.02, 0.04, 0.08, 0.16, 0.32, 0.64}  // removed_ratio
  );
/** Transfer strategies to benchmark: NO_COPY, MAP_PINNED and MANAGED. */
inline const std::vector<TransferStrategy> kAllStrategies{
TransferStrategy::NO_COPY, TransferStrategy::MAP_PINNED, TransferStrategy::MANAGED};
/** Only the NO_COPY transfer strategy. */
inline const std::vector<TransferStrategy> kNoCopyOnly{TransferStrategy::NO_COPY};
Expand Down
24 changes: 24 additions & 0 deletions cpp/bench/prims/neighbors/knn/ivf_flat_filter_float_int64_t.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#undef RAFT_EXPLICIT_INSTANTIATE_ONLY // Enable instantiation of search with filter
#include "../knn.cuh"

namespace raft::bench::spatial {

// Register the ivf_flat_filter_knn benchmark for float data and int64_t indices, using the
// kInputsFilter parameter sets, the kNoCopyOnly transfer strategies and the kScopeFull scope.
KNN_REGISTER(float, int64_t, ivf_flat_filter_knn, kInputsFilter, kNoCopyOnly, kScopeFull);

} // namespace raft::bench::spatial
24 changes: 24 additions & 0 deletions cpp/bench/prims/neighbors/knn/ivf_pq_filter_float_int64_t.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#undef RAFT_EXPLICIT_INSTANTIATE_ONLY // Enable instantiation of search with filter
#include "../knn.cuh"

namespace raft::bench::spatial {

// Register the ivf_pq_filter_knn benchmark for float data and int64_t indices, using the
// kInputsFilter parameter sets, the kNoCopyOnly transfer strategies and the kScopeFull scope.
KNN_REGISTER(float, int64_t, ivf_pq_filter_knn, kInputsFilter, kNoCopyOnly, kScopeFull);

} // namespace raft::bench::spatial
4 changes: 2 additions & 2 deletions cpp/include/raft/neighbors/sample_filter.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,10 @@ struct bitset_filter {
*/
template <typename index_t, typename filter_t>
struct ivf_to_sample_filter {
index_t** const inds_ptrs_;
const index_t* const* inds_ptrs_;
const filter_t next_filter_;

ivf_to_sample_filter(index_t** const inds_ptrs, const filter_t next_filter)
ivf_to_sample_filter(const index_t* const* inds_ptrs, const filter_t next_filter)
: inds_ptrs_{inds_ptrs}, next_filter_{next_filter}
{
}
Expand Down

0 comments on commit 1a78fb8

Please sign in to comment.