Cache IVF-PQ and select-warpsort kernel launch parameters to reduce latency #1786

Merged (50 commits) on Feb 6, 2024

Commits
2cc477b
Replace GEMM backend: cublas.gemm -> cublaslt.matmul
achirkin Aug 14, 2023
dc7a9a4
Replace broken (due to missing direct includes) direct uses of cublas…
achirkin Aug 14, 2023
34a9479
Merge branch 'branch-23.10' into fea-cublaslt-matmul
achirkin Aug 15, 2023
71c03c0
Fix docs
achirkin Aug 15, 2023
a2fb088
Replace cublasgemm where it makes sense
achirkin Aug 16, 2023
699de0c
Fix a typo
achirkin Aug 16, 2023
f994f19
Merge branch 'branch-23.10' into fea-cublaslt-matmul
achirkin Aug 16, 2023
f4d634a
Put the cache into the resource handle as a user-defined resource
achirkin Aug 21, 2023
2d1bf5c
Merge branch 'branch-23.10' into fea-cublaslt-matmul
achirkin Aug 22, 2023
e57eebf
Move matmul into a separate file
achirkin Aug 22, 2023
d44bf20
Complete the docs
achirkin Aug 22, 2023
facf81d
Merge branch 'branch-23.10' into fea-cublaslt-matmul
achirkin Aug 23, 2023
157d8ae
Merge branch 'branch-23.10' into fea-cublaslt-matmul
achirkin Aug 24, 2023
be68b61
Merge branch 'branch-23.10' into fea-cublaslt-matmul
achirkin Aug 24, 2023
f5ac41a
Merge branch 'branch-23.10' into fea-cublaslt-matmul
achirkin Aug 28, 2023
2d4dcb2
Merge branch 'branch-23.10' into fea-cublaslt-matmul
achirkin Aug 29, 2023
6f58669
Merge branch 'branch-23.10' into fea-cublaslt-matmul
achirkin Aug 30, 2023
a0e93fd
Merge branch 'branch-23.10' into fea-cublaslt-matmul
achirkin Aug 30, 2023
4c0d742
Merge branch 'branch-23.10' into fea-cublaslt-matmul
achirkin Aug 30, 2023
01c3634
Merge branch 'branch-23.10' into fea-cublaslt-matmul
achirkin Aug 30, 2023
abb3f00
Merge branch 'branch-23.10' into fea-cublaslt-matmul
achirkin Aug 30, 2023
e24b1c0
Merge branch 'branch-23.10' into fea-cublaslt-matmul
achirkin Aug 31, 2023
de29580
move matmul.hpp to cublaslt_wrappers.hpp
achirkin Aug 31, 2023
3835ed0
Merge branch 'branch-23.10' into fea-cublaslt-matmul
achirkin Aug 31, 2023
de60202
Merge branch 'branch-23.10' into fea-cublaslt-matmul
achirkin Sep 1, 2023
fe84fae
Merge branch 'branch-23.10' into fea-cublaslt-matmul
achirkin Sep 5, 2023
f47626a
Merge branch 'branch-23.10' into fea-cublaslt-matmul
cjnolet Sep 6, 2023
d7efc0c
Merge branch 'branch-23.10' into fea-cublaslt-matmul
achirkin Sep 7, 2023
dd7ee22
Merge branch 'branch-23.10' into fea-cublaslt-matmul
achirkin Sep 8, 2023
01e62b0
Merge branch 'branch-23.10' into fea-cublaslt-matmul
achirkin Sep 9, 2023
8fdf6cc
Merge branch 'branch-23.10' into fea-cublaslt-matmul
achirkin Sep 13, 2023
324f5c6
Merge branch 'branch-23.10' into fea-cublaslt-matmul
achirkin Sep 19, 2023
ba6883f
Merge branch 'branch-23.10' into fea-cublaslt-matmul
achirkin Sep 19, 2023
a56ea2c
Merge branch 'branch-23.10' into fea-cublaslt-matmul
achirkin Nov 20, 2023
cd4663a
Merge branch 'branch-23.12' into fea-cublaslt-matmul
achirkin Nov 20, 2023
090141a
Cache IVF-PQ and select-warpsort kernel launch parameters to reduce latency
achirkin Aug 30, 2023
3e6dfcd
Adapt the deprecated knn function to the changes
achirkin Aug 30, 2023
800ee80
Merge branch 'branch-24.02' into fea-cache-ivf-pq-params
achirkin Jan 25, 2024
82b34b9
Style check
achirkin Jan 25, 2024
d25778c
Merge branch 'branch-24.04' into fea-cache-ivf-pq-params
achirkin Jan 25, 2024
ce9d044
Remove unused code
achirkin Jan 25, 2024
976d597
Merge branch 'branch-24.04' into fea-cache-ivf-pq-params
achirkin Jan 29, 2024
6d29811
Make sure select_k always uses the workspace memory resource
achirkin Jan 29, 2024
542aa64
Revert an accidental copyright-only change
achirkin Jan 29, 2024
8f7e37e
Merge branch 'branch-24.04' into fea-cache-ivf-pq-params
achirkin Jan 29, 2024
d8e1380
Merge branch 'branch-24.04' into fea-cache-ivf-pq-params
achirkin Jan 31, 2024
035e709
Merge branch 'branch-24.04' into fea-cache-ivf-pq-params
achirkin Feb 1, 2024
336f599
Merge branch 'branch-24.04' into fea-cache-ivf-pq-params
achirkin Feb 1, 2024
272f87e
Merge branch 'branch-24.04' into fea-cache-ivf-pq-params
achirkin Feb 5, 2024
48e19e2
Merge branch 'branch-24.04' into fea-cache-ivf-pq-params
achirkin Feb 6, 2024
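As the commit list shows, the work landed in two stages: the GEMM backend first moved from cublas.gemm to cublaslt.matmul with a plan cache kept in the resource handle as a user-defined resource, and the same caching idea was then applied to the IVF-PQ and select-warpsort kernel launch parameters. The PR description is not shown on this page, so the sketch below is only an illustration of the general technique named in the title (a handle-owned cache keyed by problem shape); every name in it is invented, and it is not the PR's actual code.

```cpp
// Illustrative-only sketch of caching expensive-to-compute kernel launch
// parameters in a handle-owned map, keyed by the problem shape. This is an
// assumption about the technique named in the PR title, not the PR's code.
#include <cstdint>
#include <map>
#include <mutex>
#include <tuple>

struct LaunchParams {
  int grid_dim;
  int block_dim;
};

class LaunchParamsCache {
 public:
  template <typename ComputeFn>
  LaunchParams get_or_compute(uint64_t kernel_id, size_t len, int k, ComputeFn&& compute)
  {
    std::lock_guard<std::mutex> guard(lock_);
    auto key = std::make_tuple(kernel_id, len, k);
    auto it  = cache_.find(key);
    if (it == cache_.end()) {
      // First call with this shape: pay the occupancy/heuristic cost once.
      it = cache_.emplace(key, compute()).first;
    }
    return it->second;
  }

 private:
  std::mutex lock_;
  std::map<std::tuple<uint64_t, size_t, int>, LaunchParams> cache_;
};
```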
28 changes: 13 additions & 15 deletions cpp/include/raft/matrix/detail/select_k-ext.cuh
@@ -38,25 +38,23 @@ void select_k(raft::resources const& handle,
 T* out_val,
 IdxT* out_idx,
 bool select_min,
-rmm::mr::device_memory_resource* mr = nullptr,
-bool sorted = false,
-SelectAlgo algo = SelectAlgo::kAuto) RAFT_EXPLICIT;
+bool sorted = false,
+SelectAlgo algo = SelectAlgo::kAuto) RAFT_EXPLICIT;
 } // namespace raft::matrix::detail

 #endif // RAFT_EXPLICIT_INSTANTIATE_ONLY

-#define instantiate_raft_matrix_detail_select_k(T, IdxT) \
-extern template void raft::matrix::detail::select_k(raft::resources const& handle, \
-const T* in_val, \
-const IdxT* in_idx, \
-size_t batch_size, \
-size_t len, \
-int k, \
-T* out_val, \
-IdxT* out_idx, \
-bool select_min, \
-rmm::mr::device_memory_resource* mr, \
-bool sorted, \
+#define instantiate_raft_matrix_detail_select_k(T, IdxT) \
+extern template void raft::matrix::detail::select_k(raft::resources const& handle, \
+const T* in_val, \
+const IdxT* in_idx, \
+size_t batch_size, \
+size_t len, \
+int k, \
+T* out_val, \
+IdxT* out_idx, \
+bool select_min, \
+bool sorted, \
 raft::matrix::SelectAlgo algo)
 instantiate_raft_matrix_detail_select_k(__half, uint32_t);
 instantiate_raft_matrix_detail_select_k(__half, int64_t);
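The hunk above drops the rmm::mr::device_memory_resource* argument from select_k: temporaries now come from the handle's workspace memory resource (see commit 6d29811). Below is a hedged sketch of a caller under the new signature; it assumes raft::resource::set_workspace_to_pool_resource is available with this shape in the target release, and is illustrative rather than the PR's documented usage.

```cpp
// Hedged sketch of a caller after this change: the memory resource argument is
// gone, so per-call temporaries come from the handle's workspace resource.
#include <raft/core/device_resources.hpp>
#include <raft/core/resource/device_memory_resource.hpp>
#include <raft/matrix/detail/select_k-inl.cuh>

void topk_example(const float* in_val, size_t batch_size, size_t len, int k,
                  float* out_val, int64_t* out_idx)
{
  raft::device_resources handle;
  // Back the workspace with a pool once, instead of passing mr on every call.
  raft::resource::set_workspace_to_pool_resource(handle, 2 * 1024 * 1024 * 1024ull);
  raft::matrix::detail::select_k<float, int64_t>(handle,
                                                 in_val,
                                                 nullptr,  // in_idx: use row-local indices
                                                 batch_size,
                                                 len,
                                                 k,
                                                 out_val,
                                                 out_idx,
                                                 /*select_min=*/true);
}
```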
79 changes: 36 additions & 43 deletions cpp/include/raft/matrix/detail/select_k-inl.cuh
@@ -23,13 +23,12 @@
 #include <raft/core/device_mdarray.hpp>
 #include <raft/core/device_mdspan.hpp>
 #include <raft/core/nvtx.hpp>
-#include <raft/matrix/init.cuh>
+#include <raft/core/operators.hpp>
+#include <raft/core/resource/device_memory_resource.hpp>
+#include <raft/linalg/map.cuh>
 #include <raft/matrix/select_k_types.hpp>

-#include <raft/core/resource/thrust_policy.hpp>
 #include <rmm/cuda_stream_view.hpp>
-#include <rmm/mr/device/device_memory_resource.hpp>
-#include <thrust/scan.h>
 #include <cub/cub.cuh>

 namespace raft::matrix::detail {

@@ -95,15 +94,17 @@ void segmented_sort_by_key(raft::resources const& handle,
 const ValT* offsets,
 bool asc)
 {
-auto stream = raft::resource::get_cuda_stream(handle);
-auto out_inds = raft::make_device_vector<ValT, ValT>(handle, n_elements);
-auto out_dists = raft::make_device_vector<KeyT, ValT>(handle, n_elements);
+auto stream = resource::get_cuda_stream(handle);
+auto mr = resource::get_workspace_resource(handle);
+auto out_inds =
+raft::make_device_mdarray<ValT, ValT>(handle, mr, raft::make_extents<ValT>(n_elements));
+auto out_dists =
+raft::make_device_mdarray<KeyT, ValT>(handle, mr, raft::make_extents<ValT>(n_elements));

 // Determine temporary device storage requirements
-auto d_temp_storage = raft::make_device_vector<char, int>(handle, 0);
 size_t temp_storage_bytes = 0;
 if (asc) {
-cub::DeviceSegmentedRadixSort::SortPairs((void*)d_temp_storage.data_handle(),
+cub::DeviceSegmentedRadixSort::SortPairs(nullptr,
 temp_storage_bytes,
 keys,
 out_dists.data_handle(),
@@ -117,7 +118,7 @@
 sizeof(ValT) * 8,
 stream);
 } else {
-cub::DeviceSegmentedRadixSort::SortPairsDescending((void*)d_temp_storage.data_handle(),
+cub::DeviceSegmentedRadixSort::SortPairsDescending(nullptr,
 temp_storage_bytes,
 keys,
 out_dists.data_handle(),
@@ -132,7 +133,8 @@
 stream);
 }

-d_temp_storage = raft::make_device_vector<char, int>(handle, temp_storage_bytes);
+auto d_temp_storage = raft::make_device_mdarray<char, size_t>(
+handle, mr, raft::make_extents<size_t>(temp_storage_bytes));

 if (asc) {
 // Run sorting operation
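For reference, the calls above follow CUB's standard two-phase temp-storage protocol: a first call with a null workspace pointer only computes the required size (the change above exploits this, so the dummy zero-length vector is no longer needed), and a second identical call does the work. A minimal self-contained sketch of that pattern; the function and variable names are illustrative, not from the PR.

```cpp
// Hedged sketch of the two-phase CUB call pattern used above: passing nullptr
// as d_temp_storage makes CUB only compute the workspace size, so the real
// allocation can come from any memory resource (here plain cudaMallocAsync).
#include <cub/cub.cuh>
#include <cuda_runtime.h>

void sort_pairs_example(const float* keys_in, float* keys_out,
                        const int* vals_in, int* vals_out,
                        int num_items, int num_segments, const int* offsets,
                        cudaStream_t stream)
{
  size_t temp_bytes = 0;
  // Phase 1: size query only; no sorting happens when d_temp_storage == nullptr.
  cub::DeviceSegmentedRadixSort::SortPairs(nullptr, temp_bytes,
                                           keys_in, keys_out, vals_in, vals_out,
                                           num_items, num_segments,
                                           offsets, offsets + 1,
                                           0, sizeof(float) * 8, stream);
  void* temp = nullptr;
  cudaMallocAsync(&temp, temp_bytes, stream);
  // Phase 2: the actual sort, reusing the same argument list.
  cub::DeviceSegmentedRadixSort::SortPairs(temp, temp_bytes,
                                           keys_in, keys_out, vals_in, vals_out,
                                           num_items, num_segments,
                                           offsets, offsets + 1,
                                           0, sizeof(float) * 8, stream);
  cudaFreeAsync(temp, stream);
}
```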
@@ -201,6 +203,7 @@ void segmented_sort_by_key(
 * @tparam IdxT
 * the index type (what is being selected together with the keys).
 *
+* @param[in] handle container of reusable resources
 * @param[in] in_val
 * contiguous device array of inputs of size (len * batch_size);
 * these are compared and selected.
@@ -222,9 +225,10 @@
 * the payload selected together with `out_val`.
 * @param select_min
 * whether to select k smallest (true) or largest (false) keys.
-* @param stream
-* @param mr an optional memory resource to use across the calls (you can provide a large enough
-* memory pool here to avoid memory allocations within the call).
+* @param[in] sorted
+* whether to make sure selected pairs are sorted by value
+* @param[in] algo
+* the selection algorithm to use
 */
 template <typename T, typename IdxT>
 void select_k(raft::resources const& handle,
@@ -236,59 +240,48 @@ void select_k(raft::resources const& handle,
 T* out_val,
 IdxT* out_idx,
 bool select_min,
-rmm::mr::device_memory_resource* mr = nullptr,
-bool sorted = false,
-SelectAlgo algo = SelectAlgo::kAuto)
+bool sorted = false,
+SelectAlgo algo = SelectAlgo::kAuto)
 {
 common::nvtx::range<common::nvtx::domain::raft> fun_scope(
 "matrix::select_k(batch_size = %zu, len = %zu, k = %d)", batch_size, len, k);

-if (mr == nullptr) { mr = rmm::mr::get_current_device_resource(); }
-
 if (algo == SelectAlgo::kAuto) { algo = choose_select_k_algorithm(batch_size, len, k); }

-auto stream = raft::resource::get_cuda_stream(handle);
 switch (algo) {
 case SelectAlgo::kRadix8bits:
 case SelectAlgo::kRadix11bits:
 case SelectAlgo::kRadix11bitsExtraPass: {
 if (algo == SelectAlgo::kRadix8bits) {
-detail::select::radix::select_k<T, IdxT, 8, 512>(in_val,
+detail::select::radix::select_k<T, IdxT, 8, 512>(handle,
+in_val,
 in_idx,
 batch_size,
 len,
 k,
 out_val,
 out_idx,
 select_min,
-true, // fused_last_filter
-stream,
-mr);
+true // fused_last_filter
+);

 } else {
 bool fused_last_filter = algo == SelectAlgo::kRadix11bits;
-detail::select::radix::select_k<T, IdxT, 11, 512>(in_val,
+detail::select::radix::select_k<T, IdxT, 11, 512>(handle,
+in_val,
 in_idx,
 batch_size,
 len,
 k,
 out_val,
 out_idx,
 select_min,
-fused_last_filter,
-stream,
-mr);
+fused_last_filter);
 }
 if (sorted) {
-auto offsets = raft::make_device_vector<IdxT, IdxT>(handle, (IdxT)(batch_size + 1));
-
-raft::matrix::fill(handle, offsets.view(), (IdxT)k);
-
-thrust::exclusive_scan(raft::resource::get_thrust_policy(handle),
-offsets.data_handle(),
-offsets.data_handle() + offsets.size(),
-offsets.data_handle(),
-0);
+auto offsets = make_device_mdarray<IdxT, IdxT>(
+handle, resource::get_workspace_resource(handle), make_extents<IdxT>(batch_size + 1));
+raft::linalg::map_offset(handle, offsets.view(), mul_const_op<IdxT>(k));

 auto keys = raft::make_device_vector_view<T, IdxT>(out_val, (IdxT)(batch_size * k));
 auto vals = raft::make_device_vector_view<IdxT, IdxT>(out_idx, (IdxT)(batch_size * k));
@@ -301,22 +294,22 @@ void select_k(raft::resources const& handle,
 case SelectAlgo::kWarpDistributed:
 return detail::select::warpsort::
 select_k_impl<T, IdxT, detail::select::warpsort::warp_sort_distributed>(
-in_val, in_idx, batch_size, len, k, out_val, out_idx, select_min, stream, mr);
+handle, in_val, in_idx, batch_size, len, k, out_val, out_idx, select_min);
 case SelectAlgo::kWarpDistributedShm:
 return detail::select::warpsort::
 select_k_impl<T, IdxT, detail::select::warpsort::warp_sort_distributed_ext>(
-in_val, in_idx, batch_size, len, k, out_val, out_idx, select_min, stream, mr);
+handle, in_val, in_idx, batch_size, len, k, out_val, out_idx, select_min);
 case SelectAlgo::kWarpAuto:
 return detail::select::warpsort::select_k<T, IdxT>(
-in_val, in_idx, batch_size, len, k, out_val, out_idx, select_min, stream, mr);
+handle, in_val, in_idx, batch_size, len, k, out_val, out_idx, select_min);
 case SelectAlgo::kWarpImmediate:
 return detail::select::warpsort::
 select_k_impl<T, IdxT, detail::select::warpsort::warp_sort_immediate>(
-in_val, in_idx, batch_size, len, k, out_val, out_idx, select_min, stream, mr);
+handle, in_val, in_idx, batch_size, len, k, out_val, out_idx, select_min);
 case SelectAlgo::kWarpFiltered:
 return detail::select::warpsort::
 select_k_impl<T, IdxT, detail::select::warpsort::warp_sort_filtered>(
-in_val, in_idx, batch_size, len, k, out_val, out_idx, select_min, stream, mr);
+handle, in_val, in_idx, batch_size, len, k, out_val, out_idx, select_min);
 default: RAFT_FAIL("K-selection Algorithm not supported.");
 }
 }
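Note the `sorted` branch in the hunk above: raft::matrix::fill plus thrust::exclusive_scan is replaced by a single raft::linalg::map_offset with mul_const_op, which writes offsets[i] = i * k directly. A minimal CUDA sketch of the same idea; the kernel name is invented for illustration and is not part of the PR.

```cpp
// Hedged sketch: computing segment offsets directly as i * k instead of
// filling an array with k and running an exclusive scan. One elementwise
// kernel replaces a fill plus a scan, which is what the map_offset change
// above achieves.
__global__ void offsets_from_index(int64_t* offsets, int64_t n, int64_t k)
{
  int64_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) { offsets[i] = i * k; }  // == exclusive scan of a constant k
}

// Usage: offsets needs batch_size + 1 entries, one per segment boundary.
// offsets_from_index<<<(batch_size + 256) / 256, 256, 0, stream>>>(
//     offsets, batch_size + 1, k);
```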
29 changes: 11 additions & 18 deletions cpp/include/raft/matrix/detail/select_radix.cuh
@@ -19,6 +19,9 @@
 #include <raft/core/detail/macros.hpp>
 #include <raft/core/logger.hpp>
 #include <raft/core/operators.hpp>
+#include <raft/core/resource/cuda_stream.hpp>
+#include <raft/core/resource/device_memory_resource.hpp>
+#include <raft/core/resource/device_properties.hpp>
 #include <raft/linalg/map.cuh>
 #include <raft/util/cudart_utils.hpp>
 #include <raft/util/device_atomics.cuh>
@@ -1157,6 +1160,7 @@ void radix_topk_one_block(const T* in,
 * @tparam BlockSize
 * Number of threads in a kernel thread block.
 *
+* @param[in] res container of reusable resources
 * @param[in] in
 * contiguous device array of inputs of size (len * batch_size);
 * these are compared and selected.
@@ -1184,23 +1188,21 @@
 * blocks is called. The later case is preferable when leading bits of input data are almost the
 * same. That is, when the value range of input data is narrow. In such case, there could be a
 * large number of inputs for the last filter, hence using multiple thread blocks is beneficial.
-* @param stream
-* @param mr an optional memory resource to use across the calls (you can provide a large enough
-* memory pool here to avoid memory allocations within the call).
 */
 template <typename T, typename IdxT, int BitsPerPass, int BlockSize>
-void select_k(const T* in,
+void select_k(raft::resources const& res,
+const T* in,
 const IdxT* in_idx,
 int batch_size,
 IdxT len,
 IdxT k,
 T* out,
 IdxT* out_idx,
 bool select_min,
-bool fused_last_filter,
-rmm::cuda_stream_view stream,
-rmm::mr::device_memory_resource* mr = nullptr)
+bool fused_last_filter)
 {
+auto stream = resource::get_cuda_stream(res);
+auto mr = resource::get_workspace_resource(res);
 if (k == len) {
 RAFT_CUDA_TRY(
 cudaMemcpyAsync(out, in, sizeof(T) * batch_size * len, cudaMemcpyDeviceToDevice, stream));
@@ -1210,21 +1212,12 @@
 } else {
 auto out_idx_view =
 raft::make_device_vector_view(out_idx, static_cast<size_t>(len) * batch_size);
-raft::resources handle;
-resource::set_cuda_stream(handle, stream);
-raft::linalg::map_offset(handle, out_idx_view, raft::mod_const_op<IdxT>(len));
+raft::linalg::map_offset(res, out_idx_view, raft::mod_const_op<IdxT>(len));
 }
 return;
 }

-// TODO: use device_resources::get_device_properties() instead; should change it when we refactor
-// resource management
-int sm_cnt;
-{
-int dev;
-RAFT_CUDA_TRY(cudaGetDevice(&dev));
-RAFT_CUDA_TRY(cudaDeviceGetAttribute(&sm_cnt, cudaDevAttrMultiProcessorCount, dev));
-}
+int sm_cnt = resource::get_device_properties(res).multiProcessorCount;

 constexpr int items_per_thread = 32;
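The last hunk resolves the old TODO by reading the SM count from the cudaDeviceProp that resource::get_device_properties caches in the handle, instead of issuing a device query on every call. For reference, the two equivalent queries in plain CUDA; the function names here are illustrative only.

```cpp
// The per-call query the PR removes (one cudaDeviceGetAttribute per call)
// versus reading a cudaDeviceProp that was fetched once and cached.
#include <cuda_runtime.h>

int sm_count_uncached()
{
  int dev = 0, sm_cnt = 0;
  cudaGetDevice(&dev);
  cudaDeviceGetAttribute(&sm_cnt, cudaDevAttrMultiProcessorCount, dev);
  return sm_cnt;
}

int sm_count_from_props(const cudaDeviceProp& props)
{
  // cudaGetDeviceProperties is expensive; fetch it once, reuse it everywhere.
  return props.multiProcessorCount;
}
```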