Skip to content

Commit

Permalink
Merge branch 'branch-23.12' into bf_batch_query
Browse files Browse the repository at this point in the history
  • Loading branch information
benfred authored Nov 14, 2023
2 parents 5a859bc + 77bc461 commit eef2c4a
Show file tree
Hide file tree
Showing 8 changed files with 349 additions and 401 deletions.
10 changes: 9 additions & 1 deletion ci/release/update-version.sh
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,11 @@ sed_runner 's/release = .*/release = '"'${NEXT_FULL_TAG}'"'/g' docs/source/conf.
DEPENDENCIES=(
dask-cuda
pylibraft
pylibraft-cu11
pylibraft-cu12
rmm
rmm-cu11
rmm-cu12
rapids-dask-dependency
# ucx-py is handled separately below
)
Expand Down Expand Up @@ -84,9 +88,13 @@ sed_runner "s/RAPIDS_VERSION_NUMBER=\".*/RAPIDS_VERSION_NUMBER=\"${NEXT_SHORT_TA
sed_runner "/^PROJECT_NUMBER/ s|\".*\"|\"${NEXT_SHORT_TAG}\"|g" cpp/doxygen/Doxyfile

sed_runner "/^set(RAFT_VERSION/ s|\".*\"|\"${NEXT_SHORT_TAG}\"|g" docs/source/build.md
sed_runner "/GIT_TAG.*branch-/ s|branch-.*|branch-${NEXT_SHORT_TAG}|g" docs/source/build.md
sed_runner "s|branch-[0-9][0-9].[0-9][0-9]|branch-${NEXT_SHORT_TAG}|g" docs/source/build.md
sed_runner "/rapidsai\/raft/ s|branch-[0-9][0-9].[0-9][0-9]|branch-${NEXT_SHORT_TAG}|g" docs/source/developer_guide.md

sed_runner "s|:[0-9][0-9].[0-9][0-9]|:${NEXT_SHORT_TAG}|g" docs/source/raft_ann_benchmarks.md

sed_runner "s|branch-[0-9][0-9].[0-9][0-9]|branch-${NEXT_SHORT_TAG}|g" README.md

# .devcontainer files
find .devcontainer/ -type f -name devcontainer.json -print0 | while IFS= read -r -d '' filename; do
sed_runner "s@rapidsai/devcontainers:[0-9.]*@rapidsai/devcontainers:${NEXT_SHORT_TAG}@g" "${filename}"
Expand Down
2 changes: 1 addition & 1 deletion cpp/bench/ann/src/common/benchmark.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -287,8 +287,8 @@ void bench_search(::benchmark::State& state,
std::shared_ptr<buf<std::size_t>> neighbors =
std::make_shared<buf<std::size_t>>(algo_property.query_memory_type, k * query_set_size);

auto start = std::chrono::high_resolution_clock::now();
cuda_timer gpu_timer;
auto start = std::chrono::high_resolution_clock::now();
{
nvtx_case nvtx{state.name()};

Expand Down
84 changes: 22 additions & 62 deletions cpp/include/raft/matrix/detail/select_k-inl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ namespace raft::matrix::detail {

// this is a subset of algorithms, chosen by running the algorithm_selection
// notebook in cpp/scripts/heuristics/select_k
enum class Algo { kRadix11bits, kWarpDistributedShm, kFaissBlockSelect };
enum class Algo { kRadix11bits, kWarpDistributedShm, kWarpImmediate, kRadix11bitsExtraPass };

/**
* Predict the fastest select_k algorithm based on the number of rows/cols/k
Expand All @@ -50,73 +50,29 @@ enum class Algo { kRadix11bits, kWarpDistributedShm, kFaissBlockSelect };
*/
inline Algo choose_select_k_algorithm(size_t rows, size_t cols, int k)
{
if (k > 134) {
if (k > 256) {
if (k > 809) {
return Algo::kRadix11bits;
} else {
if (rows > 124) {
if (cols > 63488) {
return Algo::kFaissBlockSelect;
} else {
return Algo::kRadix11bits;
}
} else {
return Algo::kRadix11bits;
}
}
} else {
if (cols > 678736) {
return Algo::kWarpDistributedShm;
if (k > 256) {
if (cols > 16862) {
if (rows > 1020) {
return Algo::kRadix11bitsExtraPass;
} else {
return Algo::kRadix11bits;
}
} else {
return Algo::kRadix11bitsExtraPass;
}
} else {
if (cols > 13776) {
if (rows > 335) {
if (k > 1) {
if (rows > 546) {
return Algo::kWarpDistributedShm;
} else {
if (k > 17) {
return Algo::kWarpDistributedShm;
} else {
return Algo::kFaissBlockSelect;
}
}
} else {
return Algo::kFaissBlockSelect;
}
if (k > 2) {
if (cols > 22061) {
return Algo::kWarpDistributedShm;
} else {
if (k > 44) {
if (cols > 1031051) {
return Algo::kWarpDistributedShm;
} else {
if (rows > 22) {
return Algo::kWarpDistributedShm;
} else {
return Algo::kRadix11bits;
}
}
} else {
return Algo::kWarpDistributedShm;
}
}
} else {
if (k > 1) {
if (rows > 188) {
if (rows > 198) {
return Algo::kWarpDistributedShm;
} else {
if (k > 72) {
return Algo::kRadix11bits;
} else {
return Algo::kWarpDistributedShm;
}
return Algo::kWarpImmediate;
}
} else {
return Algo::kFaissBlockSelect;
}
} else {
return Algo::kWarpImmediate;
}
}
}
Expand Down Expand Up @@ -294,6 +250,8 @@ void select_k(raft::resources const& handle,

switch (algo) {
case Algo::kRadix11bits:
case Algo::kRadix11bitsExtraPass: {
bool fused_last_filter = algo == Algo::kRadix11bits;
detail::select::radix::select_k<T, IdxT, 11, 512>(in_val,
in_idx,
batch_size,
Expand All @@ -302,7 +260,7 @@ void select_k(raft::resources const& handle,
out_val,
out_idx,
select_min,
true, // fused_last_filter
fused_last_filter,
stream,
mr);

Expand All @@ -324,13 +282,15 @@ void select_k(raft::resources const& handle,
handle, raft::make_const_mdspan(offsets.view()), keys, vals, select_min);
}
return;
}
case Algo::kWarpDistributedShm:
return detail::select::warpsort::
select_k_impl<T, IdxT, detail::select::warpsort::warp_sort_distributed_ext>(
in_val, in_idx, batch_size, len, k, out_val, out_idx, select_min, stream, mr);
case Algo::kFaissBlockSelect:
return neighbors::detail::select_k(
in_val, in_idx, batch_size, len, out_val, out_idx, select_min, k, stream);
case Algo::kWarpImmediate:
return detail::select::warpsort::
select_k_impl<T, IdxT, detail::select::warpsort::warp_sort_immediate>(
in_val, in_idx, batch_size, len, k, out_val, out_idx, select_min, stream, mr);
default: RAFT_FAIL("K-selection Algorithm not supported.");
}
}
Expand Down
Loading

0 comments on commit eef2c4a

Please sign in to comment.