Skip to content

Commit

Permalink
Cache IVF-PQ and select-warpsort kernel launch parameters to reduce latency (#1786)
Browse files Browse the repository at this point in the history

This PR aims to reduce latency in IVF-PQ and related functions, especially for small work sizes and in the "throughput" benchmark mode.

 - Add kernel config caching to the `ivf_pq::search::compute_similarity` kernel
 - Add kernel config caching to `select::warpsort`
 - Fix the memory_resource usage in `matrix::select_k`: make sure all temporary allocations use raft's workspace memory resource.

Authors:
  - Artem M. Chirkin (https://github.com/achirkin)
  - Corey J. Nolet (https://github.com/cjnolet)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)
  - Tamas Bela Feher (https://github.com/tfeher)

URL: #1786
  • Loading branch information
achirkin authored Feb 6, 2024
1 parent 9f6af2f commit d7cbcf9
Show file tree
Hide file tree
Showing 15 changed files with 305 additions and 235 deletions.
28 changes: 13 additions & 15 deletions cpp/include/raft/matrix/detail/select_k-ext.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -38,25 +38,23 @@ void select_k(raft::resources const& handle,
T* out_val,
IdxT* out_idx,
bool select_min,
rmm::mr::device_memory_resource* mr = nullptr,
bool sorted = false,
SelectAlgo algo = SelectAlgo::kAuto) RAFT_EXPLICIT;
bool sorted = false,
SelectAlgo algo = SelectAlgo::kAuto) RAFT_EXPLICIT;
} // namespace raft::matrix::detail

#endif // RAFT_EXPLICIT_INSTANTIATE_ONLY

#define instantiate_raft_matrix_detail_select_k(T, IdxT) \
extern template void raft::matrix::detail::select_k(raft::resources const& handle, \
const T* in_val, \
const IdxT* in_idx, \
size_t batch_size, \
size_t len, \
int k, \
T* out_val, \
IdxT* out_idx, \
bool select_min, \
rmm::mr::device_memory_resource* mr, \
bool sorted, \
#define instantiate_raft_matrix_detail_select_k(T, IdxT) \
extern template void raft::matrix::detail::select_k(raft::resources const& handle, \
const T* in_val, \
const IdxT* in_idx, \
size_t batch_size, \
size_t len, \
int k, \
T* out_val, \
IdxT* out_idx, \
bool select_min, \
bool sorted, \
raft::matrix::SelectAlgo algo)
instantiate_raft_matrix_detail_select_k(__half, uint32_t);
instantiate_raft_matrix_detail_select_k(__half, int64_t);
Expand Down
79 changes: 36 additions & 43 deletions cpp/include/raft/matrix/detail/select_k-inl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,12 @@
#include <raft/core/device_mdarray.hpp>
#include <raft/core/device_mdspan.hpp>
#include <raft/core/nvtx.hpp>
#include <raft/matrix/init.cuh>
#include <raft/core/operators.hpp>
#include <raft/core/resource/device_memory_resource.hpp>
#include <raft/linalg/map.cuh>
#include <raft/matrix/select_k_types.hpp>

#include <raft/core/resource/thrust_policy.hpp>
#include <rmm/cuda_stream_view.hpp>
#include <rmm/mr/device/device_memory_resource.hpp>
#include <thrust/scan.h>
#include <cub/cub.cuh>

namespace raft::matrix::detail {

Expand Down Expand Up @@ -95,15 +94,17 @@ void segmented_sort_by_key(raft::resources const& handle,
const ValT* offsets,
bool asc)
{
auto stream = raft::resource::get_cuda_stream(handle);
auto out_inds = raft::make_device_vector<ValT, ValT>(handle, n_elements);
auto out_dists = raft::make_device_vector<KeyT, ValT>(handle, n_elements);
auto stream = resource::get_cuda_stream(handle);
auto mr = resource::get_workspace_resource(handle);
auto out_inds =
raft::make_device_mdarray<ValT, ValT>(handle, mr, raft::make_extents<ValT>(n_elements));
auto out_dists =
raft::make_device_mdarray<KeyT, ValT>(handle, mr, raft::make_extents<ValT>(n_elements));

// Determine temporary device storage requirements
auto d_temp_storage = raft::make_device_vector<char, int>(handle, 0);
size_t temp_storage_bytes = 0;
if (asc) {
cub::DeviceSegmentedRadixSort::SortPairs((void*)d_temp_storage.data_handle(),
cub::DeviceSegmentedRadixSort::SortPairs(nullptr,
temp_storage_bytes,
keys,
out_dists.data_handle(),
Expand All @@ -117,7 +118,7 @@ void segmented_sort_by_key(raft::resources const& handle,
sizeof(ValT) * 8,
stream);
} else {
cub::DeviceSegmentedRadixSort::SortPairsDescending((void*)d_temp_storage.data_handle(),
cub::DeviceSegmentedRadixSort::SortPairsDescending(nullptr,
temp_storage_bytes,
keys,
out_dists.data_handle(),
Expand All @@ -132,7 +133,8 @@ void segmented_sort_by_key(raft::resources const& handle,
stream);
}

d_temp_storage = raft::make_device_vector<char, int>(handle, temp_storage_bytes);
auto d_temp_storage = raft::make_device_mdarray<char, size_t>(
handle, mr, raft::make_extents<size_t>(temp_storage_bytes));

if (asc) {
// Run sorting operation
Expand Down Expand Up @@ -201,6 +203,7 @@ void segmented_sort_by_key(raft::resources const& handle,
* @tparam IdxT
* the index type (what is being selected together with the keys).
*
* @param[in] handle container of reusable resources
* @param[in] in_val
* contiguous device array of inputs of size (len * batch_size);
* these are compared and selected.
Expand All @@ -222,9 +225,10 @@ void segmented_sort_by_key(raft::resources const& handle,
* the payload selected together with `out_val`.
* @param select_min
* whether to select k smallest (true) or largest (false) keys.
* @param stream
* @param mr an optional memory resource to use across the calls (you can provide a large enough
* memory pool here to avoid memory allocations within the call).
* @param[in] sorted
* whether to make sure selected pairs are sorted by value
* @param[in] algo
* the selection algorithm to use
*/
template <typename T, typename IdxT>
void select_k(raft::resources const& handle,
Expand All @@ -236,59 +240,48 @@ void select_k(raft::resources const& handle,
T* out_val,
IdxT* out_idx,
bool select_min,
rmm::mr::device_memory_resource* mr = nullptr,
bool sorted = false,
SelectAlgo algo = SelectAlgo::kAuto)
bool sorted = false,
SelectAlgo algo = SelectAlgo::kAuto)
{
common::nvtx::range<common::nvtx::domain::raft> fun_scope(
"matrix::select_k(batch_size = %zu, len = %zu, k = %d)", batch_size, len, k);

if (mr == nullptr) { mr = rmm::mr::get_current_device_resource(); }

if (algo == SelectAlgo::kAuto) { algo = choose_select_k_algorithm(batch_size, len, k); }

auto stream = raft::resource::get_cuda_stream(handle);
switch (algo) {
case SelectAlgo::kRadix8bits:
case SelectAlgo::kRadix11bits:
case SelectAlgo::kRadix11bitsExtraPass: {
if (algo == SelectAlgo::kRadix8bits) {
detail::select::radix::select_k<T, IdxT, 8, 512>(in_val,
detail::select::radix::select_k<T, IdxT, 8, 512>(handle,
in_val,
in_idx,
batch_size,
len,
k,
out_val,
out_idx,
select_min,
true, // fused_last_filter
stream,
mr);
true // fused_last_filter
);

} else {
bool fused_last_filter = algo == SelectAlgo::kRadix11bits;
detail::select::radix::select_k<T, IdxT, 11, 512>(in_val,
detail::select::radix::select_k<T, IdxT, 11, 512>(handle,
in_val,
in_idx,
batch_size,
len,
k,
out_val,
out_idx,
select_min,
fused_last_filter,
stream,
mr);
fused_last_filter);
}
if (sorted) {
auto offsets = raft::make_device_vector<IdxT, IdxT>(handle, (IdxT)(batch_size + 1));

raft::matrix::fill(handle, offsets.view(), (IdxT)k);

thrust::exclusive_scan(raft::resource::get_thrust_policy(handle),
offsets.data_handle(),
offsets.data_handle() + offsets.size(),
offsets.data_handle(),
0);
auto offsets = make_device_mdarray<IdxT, IdxT>(
handle, resource::get_workspace_resource(handle), make_extents<IdxT>(batch_size + 1));
raft::linalg::map_offset(handle, offsets.view(), mul_const_op<IdxT>(k));

auto keys = raft::make_device_vector_view<T, IdxT>(out_val, (IdxT)(batch_size * k));
auto vals = raft::make_device_vector_view<IdxT, IdxT>(out_idx, (IdxT)(batch_size * k));
Expand All @@ -301,22 +294,22 @@ void select_k(raft::resources const& handle,
case SelectAlgo::kWarpDistributed:
return detail::select::warpsort::
select_k_impl<T, IdxT, detail::select::warpsort::warp_sort_distributed>(
in_val, in_idx, batch_size, len, k, out_val, out_idx, select_min, stream, mr);
handle, in_val, in_idx, batch_size, len, k, out_val, out_idx, select_min);
case SelectAlgo::kWarpDistributedShm:
return detail::select::warpsort::
select_k_impl<T, IdxT, detail::select::warpsort::warp_sort_distributed_ext>(
in_val, in_idx, batch_size, len, k, out_val, out_idx, select_min, stream, mr);
handle, in_val, in_idx, batch_size, len, k, out_val, out_idx, select_min);
case SelectAlgo::kWarpAuto:
return detail::select::warpsort::select_k<T, IdxT>(
in_val, in_idx, batch_size, len, k, out_val, out_idx, select_min, stream, mr);
handle, in_val, in_idx, batch_size, len, k, out_val, out_idx, select_min);
case SelectAlgo::kWarpImmediate:
return detail::select::warpsort::
select_k_impl<T, IdxT, detail::select::warpsort::warp_sort_immediate>(
in_val, in_idx, batch_size, len, k, out_val, out_idx, select_min, stream, mr);
handle, in_val, in_idx, batch_size, len, k, out_val, out_idx, select_min);
case SelectAlgo::kWarpFiltered:
return detail::select::warpsort::
select_k_impl<T, IdxT, detail::select::warpsort::warp_sort_filtered>(
in_val, in_idx, batch_size, len, k, out_val, out_idx, select_min, stream, mr);
handle, in_val, in_idx, batch_size, len, k, out_val, out_idx, select_min);
default: RAFT_FAIL("K-selection Algorithm not supported.");
}
}
Expand Down
29 changes: 11 additions & 18 deletions cpp/include/raft/matrix/detail/select_radix.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@
#include <raft/core/detail/macros.hpp>
#include <raft/core/logger.hpp>
#include <raft/core/operators.hpp>
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/core/resource/device_memory_resource.hpp>
#include <raft/core/resource/device_properties.hpp>
#include <raft/linalg/map.cuh>
#include <raft/util/cudart_utils.hpp>
#include <raft/util/device_atomics.cuh>
Expand Down Expand Up @@ -1157,6 +1160,7 @@ void radix_topk_one_block(const T* in,
* @tparam BlockSize
* Number of threads in a kernel thread block.
*
* @param[in] res container of reusable resources
* @param[in] in
* contiguous device array of inputs of size (len * batch_size);
* these are compared and selected.
Expand Down Expand Up @@ -1184,23 +1188,21 @@ void radix_topk_one_block(const T* in,
* blocks is called. The later case is preferable when leading bits of input data are almost the
* same. That is, when the value range of input data is narrow. In such case, there could be a
* large number of inputs for the last filter, hence using multiple thread blocks is beneficial.
* @param stream
* @param mr an optional memory resource to use across the calls (you can provide a large enough
* memory pool here to avoid memory allocations within the call).
*/
template <typename T, typename IdxT, int BitsPerPass, int BlockSize>
void select_k(const T* in,
void select_k(raft::resources const& res,
const T* in,
const IdxT* in_idx,
int batch_size,
IdxT len,
IdxT k,
T* out,
IdxT* out_idx,
bool select_min,
bool fused_last_filter,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr = nullptr)
bool fused_last_filter)
{
auto stream = resource::get_cuda_stream(res);
auto mr = resource::get_workspace_resource(res);
if (k == len) {
RAFT_CUDA_TRY(
cudaMemcpyAsync(out, in, sizeof(T) * batch_size * len, cudaMemcpyDeviceToDevice, stream));
Expand All @@ -1210,21 +1212,12 @@ void select_k(const T* in,
} else {
auto out_idx_view =
raft::make_device_vector_view(out_idx, static_cast<size_t>(len) * batch_size);
raft::resources handle;
resource::set_cuda_stream(handle, stream);
raft::linalg::map_offset(handle, out_idx_view, raft::mod_const_op<IdxT>(len));
raft::linalg::map_offset(res, out_idx_view, raft::mod_const_op<IdxT>(len));
}
return;
}

// TODO: use device_resources::get_device_properties() instead; should change it when we refactor
// resource management
int sm_cnt;
{
int dev;
RAFT_CUDA_TRY(cudaGetDevice(&dev));
RAFT_CUDA_TRY(cudaDeviceGetAttribute(&sm_cnt, cudaDevAttrMultiProcessorCount, dev));
}
int sm_cnt = resource::get_device_properties(res).multiProcessorCount;

constexpr int items_per_thread = 32;

Expand Down
Loading

0 comments on commit d7cbcf9

Please sign in to comment.