Skip to content

Commit

Permalink
Update select-k heuristic (#1985)
Browse files Browse the repository at this point in the history
With #1878 merged, the performance of the radix select algorithms is much improved, and we no longer need to incorporate the faiss block select algorithm: faiss block select goes from being the 3rd-ranked selection algorithm to the 5th.

This regenerates the heuristic function with the latest benchmark times and removes faiss block select in favour of kWarpImmediate and kRadix11bitsExtraPass.

Authors:
  - Ben Frederickson (https://github.com/benfred)

Approvers:
  - Tamas Bela Feher (https://github.com/tfeher)

URL: #1985
  • Loading branch information
benfred authored Nov 14, 2023
1 parent e86e75f commit 77bc461
Show file tree
Hide file tree
Showing 4 changed files with 328 additions and 394 deletions.
84 changes: 22 additions & 62 deletions cpp/include/raft/matrix/detail/select_k-inl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ namespace raft::matrix::detail {

// this is a subset of algorithms, chosen by running the algorithm_selection
// notebook in cpp/scripts/heuristics/select_k
enum class Algo { kRadix11bits, kWarpDistributedShm, kFaissBlockSelect };
enum class Algo { kRadix11bits, kWarpDistributedShm, kWarpImmediate, kRadix11bitsExtraPass };

/**
* Predict the fastest select_k algorithm based on the number of rows/cols/k
Expand All @@ -50,73 +50,29 @@ enum class Algo { kRadix11bits, kWarpDistributedShm, kFaissBlockSelect };
*/
inline Algo choose_select_k_algorithm(size_t rows, size_t cols, int k)
{
if (k > 134) {
if (k > 256) {
if (k > 809) {
return Algo::kRadix11bits;
} else {
if (rows > 124) {
if (cols > 63488) {
return Algo::kFaissBlockSelect;
} else {
return Algo::kRadix11bits;
}
} else {
return Algo::kRadix11bits;
}
}
} else {
if (cols > 678736) {
return Algo::kWarpDistributedShm;
if (k > 256) {
if (cols > 16862) {
if (rows > 1020) {
return Algo::kRadix11bitsExtraPass;
} else {
return Algo::kRadix11bits;
}
} else {
return Algo::kRadix11bitsExtraPass;
}
} else {
if (cols > 13776) {
if (rows > 335) {
if (k > 1) {
if (rows > 546) {
return Algo::kWarpDistributedShm;
} else {
if (k > 17) {
return Algo::kWarpDistributedShm;
} else {
return Algo::kFaissBlockSelect;
}
}
} else {
return Algo::kFaissBlockSelect;
}
if (k > 2) {
if (cols > 22061) {
return Algo::kWarpDistributedShm;
} else {
if (k > 44) {
if (cols > 1031051) {
return Algo::kWarpDistributedShm;
} else {
if (rows > 22) {
return Algo::kWarpDistributedShm;
} else {
return Algo::kRadix11bits;
}
}
} else {
return Algo::kWarpDistributedShm;
}
}
} else {
if (k > 1) {
if (rows > 188) {
if (rows > 198) {
return Algo::kWarpDistributedShm;
} else {
if (k > 72) {
return Algo::kRadix11bits;
} else {
return Algo::kWarpDistributedShm;
}
return Algo::kWarpImmediate;
}
} else {
return Algo::kFaissBlockSelect;
}
} else {
return Algo::kWarpImmediate;
}
}
}
Expand Down Expand Up @@ -294,6 +250,8 @@ void select_k(raft::resources const& handle,

switch (algo) {
case Algo::kRadix11bits:
case Algo::kRadix11bitsExtraPass: {
bool fused_last_filter = algo == Algo::kRadix11bits;
detail::select::radix::select_k<T, IdxT, 11, 512>(in_val,
in_idx,
batch_size,
Expand All @@ -302,7 +260,7 @@ void select_k(raft::resources const& handle,
out_val,
out_idx,
select_min,
true, // fused_last_filter
fused_last_filter,
stream,
mr);

Expand All @@ -324,13 +282,15 @@ void select_k(raft::resources const& handle,
handle, raft::make_const_mdspan(offsets.view()), keys, vals, select_min);
}
return;
}
case Algo::kWarpDistributedShm:
return detail::select::warpsort::
select_k_impl<T, IdxT, detail::select::warpsort::warp_sort_distributed_ext>(
in_val, in_idx, batch_size, len, k, out_val, out_idx, select_min, stream, mr);
case Algo::kFaissBlockSelect:
return neighbors::detail::select_k(
in_val, in_idx, batch_size, len, out_val, out_idx, select_min, k, stream);
case Algo::kWarpImmediate:
return detail::select::warpsort::
select_k_impl<T, IdxT, detail::select::warpsort::warp_sort_immediate>(
in_val, in_idx, batch_size, len, k, out_val, out_idx, select_min, stream, mr);
default: RAFT_FAIL("K-selection Algorithm not supported.");
}
}
Expand Down
Loading

0 comments on commit 77bc461

Please sign in to comment.