Skip to content

Commit

Permalink
[FEA] Support vector deletion in ANN IVF (#1831)
Browse files Browse the repository at this point in the history
PR based on the new Bitset feature (#1803) to support vector deletion in ANN.
Closes #1177. 
Closes #1620.
This PR adds `ivf_to_sample_filter` that acts as an intermediate filter to use an IVF index with a bitset filter.

Authors:
  - Micka (https://github.com/lowener)
  - Corey J. Nolet (https://github.com/cjnolet)

Approvers:
  - William Hicks (https://github.com/wphicks)

URL: #1831
  • Loading branch information
lowener authored Nov 6, 2023
1 parent bafd2a8 commit 28eb0b3
Show file tree
Hide file tree
Showing 29 changed files with 649 additions and 45 deletions.
4 changes: 3 additions & 1 deletion build.sh
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ INSTALL_TARGET=install
BUILD_REPORT_METRICS=""
BUILD_REPORT_INCL_CACHE_STATS=OFF

TEST_TARGETS="CLUSTER_TEST;CORE_TEST;DISTANCE_TEST;LABEL_TEST;LINALG_TEST;MATRIX_TEST;NEIGHBORS_TEST;NEIGHBORS_ANN_CAGRA_TEST;NEIGHBORS_ANN_NN_DESCENT_TEST;RANDOM_TEST;SOLVERS_TEST;SPARSE_TEST;SPARSE_DIST_TEST;SPARSE_NEIGHBORS_TEST;STATS_TEST;UTILS_TEST"
TEST_TARGETS="CLUSTER_TEST;CORE_TEST;DISTANCE_TEST;LABEL_TEST;LINALG_TEST;MATRIX_TEST;NEIGHBORS_TEST;NEIGHBORS_ANN_CAGRA_TEST;NEIGHBORS_ANN_NN_DESCENT_TEST;NEIGHBORS_ANN_IVF_TEST;RANDOM_TEST;SOLVERS_TEST;SPARSE_TEST;SPARSE_DIST_TEST;SPARSE_NEIGHBORS_TEST;STATS_TEST;UTILS_TEST"
BENCH_TARGETS="CLUSTER_BENCH;CORE_BENCH;NEIGHBORS_BENCH;DISTANCE_BENCH;LINALG_BENCH;MATRIX_BENCH;SPARSE_BENCH;RANDOM_BENCH"

CACHE_ARGS=""
Expand Down Expand Up @@ -324,6 +324,8 @@ if hasArg tests || (( ${NUMARGS} == 0 )); then
$CMAKE_TARGET == *"DISTANCE_TEST"* || \
$CMAKE_TARGET == *"MATRIX_TEST"* || \
$CMAKE_TARGET == *"NEIGHBORS_ANN_CAGRA_TEST"* || \
$CMAKE_TARGET == *"NEIGHBORS_ANN_IVF_TEST"* || \
$CMAKE_TARGET == *"NEIGHBORS_ANN_NN_DESCENT_TEST"* || \
$CMAKE_TARGET == *"NEIGHBORS_TEST"* || \
$CMAKE_TARGET == *"SPARSE_DIST_TEST" || \
$CMAKE_TARGET == *"SPARSE_NEIGHBORS_TEST"* || \
Expand Down
2 changes: 2 additions & 0 deletions cpp/bench/prims/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -147,10 +147,12 @@ if(BUILD_PRIMS_BENCH)
bench/prims/neighbors/knn/brute_force_float_int64_t.cu
bench/prims/neighbors/knn/brute_force_float_uint32_t.cu
bench/prims/neighbors/knn/cagra_float_uint32_t.cu
bench/prims/neighbors/knn/ivf_flat_filter_float_int64_t.cu
bench/prims/neighbors/knn/ivf_flat_float_int64_t.cu
bench/prims/neighbors/knn/ivf_flat_int8_t_int64_t.cu
bench/prims/neighbors/knn/ivf_flat_uint8_t_int64_t.cu
bench/prims/neighbors/knn/ivf_pq_float_int64_t.cu
bench/prims/neighbors/knn/ivf_pq_filter_float_int64_t.cu
bench/prims/neighbors/knn/ivf_pq_int8_t_int64_t.cu
bench/prims/neighbors/knn/ivf_pq_uint8_t_int64_t.cu
bench/prims/neighbors/refine_float_int64_t.cu
Expand Down
119 changes: 116 additions & 3 deletions cpp/bench/prims/neighbors/knn.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,21 @@

#include <raft/random/rng.cuh>

#include <raft/core/bitset.cuh>
#include <raft/neighbors/ivf_flat.cuh>
#include <raft/neighbors/ivf_pq.cuh>
#include <raft/neighbors/sample_filter.cuh>
#include <raft/spatial/knn/knn.cuh>
#include <raft/util/itertools.hpp>

#include <rmm/mr/device/managed_memory_resource.hpp>
#include <rmm/mr/device/per_device_resource.hpp>

#include <rmm/mr/host/new_delete_resource.hpp>
#include <rmm/mr/host/pinned_memory_resource.hpp>

#include <thrust/sequence.h>

#include <optional>

namespace raft::bench::spatial {
Expand All @@ -44,11 +49,14 @@ struct params {
size_t n_queries;
/** Number of nearest neighbours to find for every probe. */
size_t k;
/** Ratio of removed indices. */
double removed_ratio;
};

inline auto operator<<(std::ostream& os, const params& p) -> std::ostream&
{
os << p.n_samples << "#" << p.n_dims << "#" << p.n_queries << "#" << p.k;
os << p.n_samples << "#" << p.n_dims << "#" << p.n_queries << "#" << p.k << "#"
<< p.removed_ratio;
return os;
}

Expand Down Expand Up @@ -221,6 +229,104 @@ struct brute_force_knn {
}
};

template <typename ValT, typename IdxT>
struct ivf_flat_filter_knn {
using dist_t = float;

std::optional<const raft::neighbors::ivf_flat::index<ValT, IdxT>> index;
raft::neighbors::ivf_flat::index_params index_params;
raft::neighbors::ivf_flat::search_params search_params;
raft::core::bitset<std::uint32_t, IdxT> removed_indices_bitset_;
params ps;

ivf_flat_filter_knn(const raft::device_resources& handle, const params& ps, const ValT* data)
: ps(ps), removed_indices_bitset_(handle, ps.n_samples)
{
index_params.n_lists = 4096;
index_params.metric = raft::distance::DistanceType::L2Expanded;
index.emplace(raft::neighbors::ivf_flat::build(
handle, index_params, data, IdxT(ps.n_samples), uint32_t(ps.n_dims)));
auto removed_indices =
raft::make_device_vector<IdxT, int64_t>(handle, ps.removed_ratio * ps.n_samples);
thrust::sequence(
resource::get_thrust_policy(handle),
thrust::device_pointer_cast(removed_indices.data_handle()),
thrust::device_pointer_cast(removed_indices.data_handle() + removed_indices.extent(0)));
removed_indices_bitset_.set(handle, removed_indices.view());
}

void search(const raft::device_resources& handle,
const ValT* search_items,
dist_t* out_dists,
IdxT* out_idxs)
{
search_params.n_probes = 20;
auto queries_view =
raft::make_device_matrix_view<const ValT, IdxT>(search_items, ps.n_queries, ps.n_dims);
auto neighbors_view = raft::make_device_matrix_view<IdxT, IdxT>(out_idxs, ps.n_queries, ps.k);
auto distance_view = raft::make_device_matrix_view<dist_t, IdxT>(out_dists, ps.n_queries, ps.k);
auto filter = raft::neighbors::filtering::bitset_filter(removed_indices_bitset_.view());

if (ps.removed_ratio > 0) {
raft::neighbors::ivf_flat::search_with_filtering(
handle, search_params, *index, queries_view, neighbors_view, distance_view, filter);
} else {
raft::neighbors::ivf_flat::search(
handle, search_params, *index, queries_view, neighbors_view, distance_view);
}
}
};

template <typename ValT, typename IdxT>
struct ivf_pq_filter_knn {
using dist_t = float;

std::optional<const raft::neighbors::ivf_pq::index<IdxT>> index;
raft::neighbors::ivf_pq::index_params index_params;
raft::neighbors::ivf_pq::search_params search_params;
raft::core::bitset<std::uint32_t, IdxT> removed_indices_bitset_;
params ps;

ivf_pq_filter_knn(const raft::device_resources& handle, const params& ps, const ValT* data)
: ps(ps), removed_indices_bitset_(handle, ps.n_samples)
{
index_params.n_lists = 4096;
index_params.metric = raft::distance::DistanceType::L2Expanded;
auto data_view = raft::make_device_matrix_view<const ValT, IdxT>(data, ps.n_samples, ps.n_dims);
index.emplace(raft::neighbors::ivf_pq::build(handle, index_params, data_view));
auto removed_indices =
raft::make_device_vector<IdxT, int64_t>(handle, ps.removed_ratio * ps.n_samples);
thrust::sequence(
resource::get_thrust_policy(handle),
thrust::device_pointer_cast(removed_indices.data_handle()),
thrust::device_pointer_cast(removed_indices.data_handle() + removed_indices.extent(0)));
removed_indices_bitset_.set(handle, removed_indices.view());
}

void search(const raft::device_resources& handle,
const ValT* search_items,
dist_t* out_dists,
IdxT* out_idxs)
{
search_params.n_probes = 20;
auto queries_view =
raft::make_device_matrix_view<const ValT, uint32_t>(search_items, ps.n_queries, ps.n_dims);
auto neighbors_view =
raft::make_device_matrix_view<IdxT, uint32_t>(out_idxs, ps.n_queries, ps.k);
auto distance_view =
raft::make_device_matrix_view<dist_t, uint32_t>(out_dists, ps.n_queries, ps.k);
auto filter = raft::neighbors::filtering::bitset_filter(removed_indices_bitset_.view());

if (ps.removed_ratio > 0) {
raft::neighbors::ivf_pq::search_with_filtering(
handle, search_params, *index, queries_view, neighbors_view, distance_view, filter);
} else {
raft::neighbors::ivf_pq::search(
handle, search_params, *index, queries_view, neighbors_view, distance_view);
}
}
};

template <typename ValT, typename IdxT, typename ImplT>
struct knn : public fixture {
explicit knn(const params& p, const TransferStrategy& strategy, const Scope& scope)
Expand Down Expand Up @@ -378,8 +484,15 @@ struct knn : public fixture {
};

inline const std::vector<params> kInputs{
{2000000, 128, 1000, 32}, {10000000, 128, 1000, 32}, {10000, 8192, 1000, 32}};

{2000000, 128, 1000, 32, 0}, {10000000, 128, 1000, 32, 0}, {10000, 8192, 1000, 32, 0}};

const std::vector<params> kInputsFilter =
raft::util::itertools::product<params>({size_t(10000000)}, // n_samples
{size_t(128)}, // n_dim
{size_t(1000)}, // n_queries
{size_t(255)}, // k
{0.0, 0.02, 0.04, 0.08, 0.16, 0.32, 0.64} // removed_ratio
);
inline const std::vector<TransferStrategy> kAllStrategies{
TransferStrategy::NO_COPY, TransferStrategy::MAP_PINNED, TransferStrategy::MANAGED};
inline const std::vector<TransferStrategy> kNoCopyOnly{TransferStrategy::NO_COPY};
Expand Down
24 changes: 24 additions & 0 deletions cpp/bench/prims/neighbors/knn/ivf_flat_filter_float_int64_t.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#undef RAFT_EXPLICIT_INSTANTIATE_ONLY // Enable instantiation of search with filter
#include "../knn.cuh"

namespace raft::bench::spatial {

KNN_REGISTER(float, int64_t, ivf_flat_filter_knn, kInputsFilter, kNoCopyOnly, kScopeFull);

} // namespace raft::bench::spatial
24 changes: 24 additions & 0 deletions cpp/bench/prims/neighbors/knn/ivf_pq_filter_float_int64_t.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#undef RAFT_EXPLICIT_INSTANTIATE_ONLY // Enable instantiation of search with filter
#include "../knn.cuh"

namespace raft::bench::spatial {

KNN_REGISTER(float, int64_t, ivf_pq_filter_knn, kInputsFilter, kNoCopyOnly, kScopeFull);

} // namespace raft::bench::spatial
1 change: 1 addition & 0 deletions cpp/include/raft/neighbors/detail/ivf_flat_build.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,7 @@ inline auto build(raft::resources const& handle,
static_assert(std::is_same_v<T, float> || std::is_same_v<T, uint8_t> || std::is_same_v<T, int8_t>,
"unsupported data type");
RAFT_EXPECTS(n_rows > 0 && dim > 0, "empty dataset");
RAFT_EXPECTS(n_rows >= params.n_lists, "number of rows can't be less than n_lists");

index<T, IdxT> index(handle, params, dim);
utils::memzero(index.list_sizes().data_handle(), index.list_sizes().size(), stream);
Expand Down
41 changes: 23 additions & 18 deletions cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-inl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include <raft/distance/distance_types.hpp>
#include <raft/matrix/detail/select_warpsort.cuh>
#include <raft/neighbors/ivf_flat_types.hpp>
#include <raft/neighbors/sample_filter_types.hpp>
#include <raft/spatial/knn/detail/ann_utils.cuh>
#include <raft/util/cuda_rt_essentials.hpp> // RAFT_CUDA_TRY
#include <raft/util/device_loads_stores.cuh>
Expand Down Expand Up @@ -737,10 +738,11 @@ RAFT_KERNEL __launch_bounds__(kThreadsPerBlock)

// This is the vector a given lane/thread handles
const uint32_t vec_id = group_id * WarpSize + lane_id;
const bool valid = vec_id < list_length;
const bool valid =
vec_id < list_length && sample_filter(queries_offset + blockIdx.y, list_id, vec_id);

// Process first shm_assisted_dim dimensions (always using shared memory)
if (valid && sample_filter(queries_offset + blockIdx.y, probe_id, vec_id)) {
if (valid) {
loadAndComputeDist<kUnroll, decltype(compute_dist), Veclen, T, AccT> lc(dist,
compute_dist);
for (int pos = 0; pos < shm_assisted_dim;
Expand Down Expand Up @@ -1096,22 +1098,25 @@ void ivfflat_interleaved_scan(const index<T, IdxT>& index,
rmm::cuda_stream_view stream)
{
const int capacity = bound_by_power_of_two(k);
select_interleaved_scan_kernel<T, AccT, IdxT, IvfSampleFilterT>::run(capacity,
index.veclen(),
select_min,
metric,
index,
queries,
coarse_query_results,
n_queries,
queries_offset,
n_probes,
k,
sample_filter,
neighbors,
distances,
grid_dim_x,
stream);

auto filter_adapter = raft::neighbors::filtering::ivf_to_sample_filter(
index.inds_ptrs().data_handle(), sample_filter);
select_interleaved_scan_kernel<T, AccT, IdxT, decltype(filter_adapter)>::run(capacity,
index.veclen(),
select_min,
metric,
index,
queries,
coarse_query_results,
n_queries,
queries_offset,
n_probes,
k,
filter_adapter,
neighbors,
distances,
grid_dim_x,
stream);
}

} // namespace raft::neighbors::ivf_flat::detail
1 change: 1 addition & 0 deletions cpp/include/raft/neighbors/detail/ivf_pq_build.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -1521,6 +1521,7 @@ auto build(raft::resources const& handle,
"Unsupported data type");

RAFT_EXPECTS(n_rows > 0 && dim > 0, "empty dataset");
RAFT_EXPECTS(n_rows >= params.n_lists, "number of rows can't be less than n_lists");

auto stream = resource::get_cuda_stream(handle);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -180,25 +180,38 @@ auto compute_similarity_select(const cudaDeviceProp& dev_props,
instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select(
half,
raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA false>,
raft::neighbors::filtering::none_ivf_sample_filter);
raft::neighbors::filtering::ivf_to_sample_filter<
int64_t COMMA raft::neighbors::filtering::none_ivf_sample_filter>);
instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select(
half,
raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA true>,
raft::neighbors::filtering::none_ivf_sample_filter);
raft::neighbors::filtering::ivf_to_sample_filter<
int64_t COMMA raft::neighbors::filtering::none_ivf_sample_filter>);
instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select(
half, half, raft::neighbors::filtering::none_ivf_sample_filter);
half,
half,
raft::neighbors::filtering::ivf_to_sample_filter<
int64_t COMMA raft::neighbors::filtering::none_ivf_sample_filter>);
instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select(
float, half, raft::neighbors::filtering::none_ivf_sample_filter);
float,
half,
raft::neighbors::filtering::ivf_to_sample_filter<
int64_t COMMA raft::neighbors::filtering::none_ivf_sample_filter>);
instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select(
float, float, raft::neighbors::filtering::none_ivf_sample_filter);
float,
float,
raft::neighbors::filtering::ivf_to_sample_filter<
int64_t COMMA raft::neighbors::filtering::none_ivf_sample_filter>);
instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select(
float,
raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA false>,
raft::neighbors::filtering::none_ivf_sample_filter);
raft::neighbors::filtering::ivf_to_sample_filter<
int64_t COMMA raft::neighbors::filtering::none_ivf_sample_filter>);
instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select(
float,
raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA true>,
raft::neighbors::filtering::none_ivf_sample_filter);
raft::neighbors::filtering::ivf_to_sample_filter<
int64_t COMMA raft::neighbors::filtering::none_ivf_sample_filter>);

#undef COMMA

Expand Down
6 changes: 4 additions & 2 deletions cpp/include/raft/neighbors/detail/ivf_pq_search.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -794,7 +794,9 @@ inline void search(raft::resources const& handle,
rmm::device_uvector<float> rot_queries(max_queries * index.rot_dim(), stream, mr);
rmm::device_uvector<uint32_t> clusters_to_probe(max_queries * n_probes, stream, mr);

auto search_instance = ivfpq_search<IdxT, IvfSampleFilterT>::fun(params, index.metric());
auto filter_adapter = raft::neighbors::filtering::ivf_to_sample_filter(
index.inds_ptrs().data_handle(), sample_filter);
auto search_instance = ivfpq_search<IdxT, decltype(filter_adapter)>::fun(params, index.metric());

for (uint32_t offset_q = 0; offset_q < n_queries; offset_q += max_queries) {
uint32_t queries_batch = min(max_queries, n_queries - offset_q);
Expand Down Expand Up @@ -850,7 +852,7 @@ inline void search(raft::resources const& handle,
distances + uint64_t(k) * (offset_q + offset_b),
utils::config<T>::kDivisor / utils::config<float>::kDivisor,
params.preferred_shmem_carveout,
sample_filter);
filter_adapter);
}
}
}
Expand Down
Loading

0 comments on commit 28eb0b3

Please sign in to comment.