Skip to content

Commit

Permalink
add c++/python test, bugfixes
Browse files Browse the repository at this point in the history
  • Loading branch information
mfoerste4 committed Feb 19, 2024
1 parent 5282501 commit 5b136f8
Show file tree
Hide file tree
Showing 5 changed files with 79 additions and 39 deletions.
65 changes: 41 additions & 24 deletions cpp/include/raft/neighbors/detail/ivf_common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,10 @@ namespace raft::neighbors::ivf::detail {
template <typename IdxT>
constexpr static IdxT kOutOfBoundsRecord = std::numeric_limits<IdxT>::max();

template <typename T, typename IdxT>
template <typename T, typename IdxT, bool Ascending = true>
struct dummy_block_sort_t {
using queue_t = matrix::detail::select::warpsort::warp_sort_distributed<WarpSize, true, T, IdxT>;
using queue_t =
matrix::detail::select::warpsort::warp_sort_distributed<WarpSize, Ascending, T, IdxT>;
template <typename... Args>
__device__ dummy_block_sort_t(int k, Args...){};
};
Expand Down Expand Up @@ -215,36 +216,52 @@ void postprocess_distances(ScoreOutT* out, // [n_queries, topk]
float scaling_factor,
rmm::cuda_stream_view stream)
{
size_t len = size_t(n_queries) * size_t(topk);
constexpr bool needs_cast = !std::is_same<ScoreInT, ScoreOutT>::value;
size_t len = size_t(n_queries) * size_t(topk);
switch (metric) {
case distance::DistanceType::L2Unexpanded:
case distance::DistanceType::L2Expanded: {
linalg::unaryOp(
out,
in,
len,
raft::compose_op(raft::mul_const_op<ScoreOutT>{scaling_factor * scaling_factor},
raft::cast_op<ScoreOutT>{}),
stream);
if (scaling_factor != 0) {
linalg::unaryOp(
out,
in,
len,
raft::compose_op(raft::mul_const_op<ScoreOutT>{scaling_factor * scaling_factor},
raft::cast_op<ScoreOutT>{}),
stream);
} else if (needs_cast) {
linalg::unaryOp(out, in, len, raft::cast_op<ScoreOutT>{}, stream);
}
} break;
case distance::DistanceType::L2SqrtUnexpanded:
case distance::DistanceType::L2SqrtExpanded: {
linalg::unaryOp(out,
in,
len,
raft::compose_op{raft::mul_const_op<ScoreOutT>{scaling_factor},
raft::sqrt_op{},
raft::cast_op<ScoreOutT>{}},
stream);
if (scaling_factor != 0) {
linalg::unaryOp(out,
in,
len,
raft::compose_op{raft::mul_const_op<ScoreOutT>{scaling_factor},
raft::sqrt_op{},
raft::cast_op<ScoreOutT>{}},
stream);
} else if (needs_cast) {
linalg::unaryOp(
out, in, len, raft::compose_op{raft::sqrt_op{}, raft::cast_op<ScoreOutT>{}}, stream);
} else {
linalg::unaryOp(out, in, len, raft::sqrt_op{}, stream);
}
} break;
case distance::DistanceType::InnerProduct: {
linalg::unaryOp(
out,
in,
len,
raft::compose_op(raft::mul_const_op<ScoreOutT>{-scaling_factor * scaling_factor},
raft::cast_op<ScoreOutT>{}),
stream);
if (scaling_factor != 0) {
linalg::unaryOp(
out,
in,
len,
raft::compose_op(raft::mul_const_op<ScoreOutT>{-scaling_factor * scaling_factor},
raft::cast_op<ScoreOutT>{}),
stream);
} else if (needs_cast) {
linalg::unaryOp(out, in, len, raft::cast_op<ScoreOutT>{}, stream);
}
} break;
default: RAFT_FAIL("Unexpected metric.");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -642,8 +642,9 @@ struct flat_block_sort {
};

template <typename T, bool Ascending, typename IdxT>
struct flat_block_sort<0, Ascending, T, IdxT> : ivf::detail::dummy_block_sort_t<T, IdxT> {
using type = ivf::detail::dummy_block_sort_t<T, IdxT>;
struct flat_block_sort<0, Ascending, T, IdxT>
: ivf::detail::dummy_block_sort_t<T, IdxT, Ascending> {
using type = ivf::detail::dummy_block_sort_t<T, IdxT, Ascending>;
};

template <int Capacity, bool Ascending, typename T, typename IdxT>
Expand Down Expand Up @@ -944,7 +945,8 @@ void launch_kernel(Lambda lambda,
neighbors += grid_dim_y * grid_dim_x * k;
distances += grid_dim_y * grid_dim_x * k;
} else {
neighbors += grid_dim_y * max_samples;
distances += grid_dim_y * max_samples;
chunk_indices += grid_dim_y * n_probes;
}
coarse_index += grid_dim_y * n_probes;
}
Expand Down
18 changes: 7 additions & 11 deletions cpp/include/raft/neighbors/detail/ivf_flat_search-inl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ void search_impl(raft::resources const& handle,
refined_distances_dev.data(),
refined_indices_dev.data(),
n_queries,
manage_local_topk ? k * grid_dim_x : max_samples,
manage_local_topk ? (k * grid_dim_x) : max_samples,
k,
distances,
neighbors,
Expand All @@ -236,7 +236,7 @@ void search_impl(raft::resources const& handle,
if (!manage_local_topk) {
// post process distances && neighbor IDs
ivf::detail::postprocess_distances(
distances, distances, index.metric(), n_queries, k, 1.0, stream);
distances, distances, index.metric(), n_queries, k, 0, stream);
ivf::detail::postprocess_neighbors(neighbors,
neighbors,
index.inds_ptrs().data_handle(),
Expand Down Expand Up @@ -276,7 +276,8 @@ inline void search(raft::resources const& handle,

uint32_t max_samples = 0;
if (!manage_local_topk) {
IdxT ms = Pow2<128 / sizeof(float)>::roundUp(index.accum_sorted_sizes()(n_probes));
IdxT ms =
Pow2<128 / sizeof(float)>::roundUp(std::max<IdxT>(index.accum_sorted_sizes()(n_probes), k));
RAFT_EXPECTS(ms <= IdxT(std::numeric_limits<uint32_t>::max()),
"The maximum sample size is too big.");
max_samples = ms;
Expand All @@ -286,14 +287,9 @@ inline void search(raft::resources const& handle,
constexpr uint64_t kExpectedWsSize = 1024 * 1024 * 1024;
uint64_t max_ws_size = std::min(resource::get_workspace_free_bytes(handle), kExpectedWsSize);

uint64_t ws_size_per_query =
// fixed
4ull * (2 * n_probes + index.n_lists() + index.dim() + 1)
// fused
+ manage_local_topk
? ((sizeof(IdxT) + 4) * n_probes * k)
// non-fused
: (4ull * (max_samples + n_probes + 1));
uint64_t ws_size_per_query = 4ull * (2 * n_probes + index.n_lists() + index.dim() + 1) +
(manage_local_topk ? ((sizeof(IdxT) + 4) * n_probes * k)
: (4ull * (max_samples + n_probes + 1)));

const uint32_t max_queries =
std::min<uint32_t>(n_queries, raft::div_rounding_up_safe(max_ws_size, ws_size_per_query));
Expand Down
17 changes: 17 additions & 0 deletions cpp/test/neighbors/ann_ivf_flat.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -608,6 +608,23 @@ const std::vector<AnnIvfFlatInputs<int64_t>> inputs = {
{1000, 100000, 16, 10, 20, 1024, raft::distance::DistanceType::L2Expanded, true},
{10000, 131072, 8, 10, 20, 1024, raft::distance::DistanceType::L2Expanded, false},

// various combinations with k>raft::matrix::detail::select::warpsort::kMaxCapacity
{1000, 10000, 16, 1024, 40, 1024, raft::distance::DistanceType::L2SqrtExpanded, true},
{1000, 10000, 2053, 512, 50, 1024, raft::distance::DistanceType::L2SqrtExpanded, false},
{1000, 10000, 2049, 2048, 70, 1024, raft::distance::DistanceType::L2SqrtExpanded, false},
{1000, 10000, 16, 4000, 100, 2048, raft::distance::DistanceType::L2SqrtExpanded, false},
{10, 10000, 16, 4000, 100, 2048, raft::distance::DistanceType::L2SqrtExpanded, false},
{10, 10000, 16, 4000, 120, 2048, raft::distance::DistanceType::L2SqrtExpanded, true},
{20, 100000, 16, 257, 20, 1024, raft::distance::DistanceType::L2SqrtExpanded, true},
{1000, 100000, 16, 259, 20, 1024, raft::distance::DistanceType::L2Expanded, true, true},
{10000, 131072, 8, 280, 20, 1024, raft::distance::DistanceType::InnerProduct, false},
{100000, 1024, 32, 257, 64, 64, raft::distance::DistanceType::L2Expanded, false},
{100000, 1024, 32, 257, 64, 64, raft::distance::DistanceType::L2SqrtExpanded, false},
{100000, 1024, 32, 257, 64, 64, raft::distance::DistanceType::InnerProduct, false},
{100000, 1024, 16, 300, 20, 60, raft::distance::DistanceType::L2Expanded, false},
{100000, 1024, 16, 500, 20, 60, raft::distance::DistanceType::L2SqrtExpanded, false},
{100000, 1024, 16, 700, 20, 60, raft::distance::DistanceType::InnerProduct, false},

// host input data
{1000, 10000, 16, 10, 40, 1024, raft::distance::DistanceType::L2Expanded, false, true},
{1000, 10000, 16, 10, 50, 1024, raft::distance::DistanceType::L2Expanded, false, true},
Expand Down
10 changes: 9 additions & 1 deletion python/pylibraft/pylibraft/test/test_ivf_flat.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2023, NVIDIA CORPORATION.
# Copyright (c) 2023-2024, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -297,6 +297,14 @@ def test_ivf_flat_params(params):
"k": 129,
"n_probes": 100,
},
{
"k": 257,
"n_probes": 100,
},
{
"k": 4096,
"n_probes": 100,
},
],
)
def test_ivf_pq_search_params(params):
Expand Down

0 comments on commit 5b136f8

Please sign in to comment.