Skip to content

Commit

Permalink
Use the memory workspaces everywhere across ANN methods
Browse files Browse the repository at this point in the history
  • Loading branch information
achirkin committed Feb 23, 2024
1 parent 5bf0a76 commit 26ae6fc
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 10 deletions.
11 changes: 9 additions & 2 deletions cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -100,11 +100,18 @@ void build_knn_graph(raft::resources const& res,
uint32_t gpu_top_k = node_degree * refine_rate.value_or(2.0f);
gpu_top_k = std::min<IdxT>(std::max(gpu_top_k, top_k), dataset.extent(0));
const auto num_queries = dataset.extent(0);
rmm::mr::device_memory_resource* workspace_mr = raft::resource::get_workspace_resource(res);
// Heuristic: how much of the workspace we can spare for the queries.
// The rest is going to be used by ivf_pq::search
const auto workspace_queries_bytes = raft::resource::get_workspace_free_bytes(res) / 5;
const auto max_batch_size =
auto max_batch_size =
std::min<size_t>(workspace_queries_bytes / sizeof(DataT) / dataset.extent(1), 4096);
// Heuristic: if the workspace is too small for a decent batch size, switch to use the large
// resource with a default batch size.
if (max_batch_size < 128) {
max_batch_size = 1024;
workspace_mr = raft::resource::get_large_workspace_resource(res);
}
RAFT_LOG_DEBUG(
"IVF-PQ search node_degree: %d, top_k: %d, gpu_top_k: %d, max_batch_size:: %d, n_probes: %u",
node_degree,
Expand Down Expand Up @@ -138,7 +145,7 @@ void build_knn_graph(raft::resources const& res,
dataset.extent(1),
max_batch_size,
resource::get_cuda_stream(res),
raft::resource::get_workspace_resource(res));
workspace_mr);

size_t next_report_offset = 0;
size_t d_report_offset = dataset.extent(0) / 100; // Report progress in 1% steps.
Expand Down
13 changes: 9 additions & 4 deletions cpp/include/raft/neighbors/detail/ivf_flat_build.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include <raft/core/nvtx.hpp>
#include <raft/core/operators.hpp>
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/core/resource/device_memory_resource.hpp>
#include <raft/core/resources.hpp>
#include <raft/linalg/add.cuh>
#include <raft/linalg/map.cuh>
Expand Down Expand Up @@ -179,7 +180,8 @@ void extend(raft::resources const& handle,
RAFT_EXPECTS(new_indices != nullptr || index->size() == 0,
"You must pass data indices when the index is non-empty.");

auto new_labels = raft::make_device_vector<LabelT, IdxT>(handle, n_rows);
auto new_labels = raft::make_device_mdarray<LabelT>(
handle, resource::get_large_workspace_resource(handle), raft::make_extents<IdxT>(n_rows));
raft::cluster::kmeans_balanced_params kmeans_params;
kmeans_params.metric = index->metric();
auto orig_centroids_view =
Expand Down Expand Up @@ -210,7 +212,8 @@ void extend(raft::resources const& handle,
}

auto* list_sizes_ptr = index->list_sizes().data_handle();
auto old_list_sizes_dev = raft::make_device_vector<uint32_t, IdxT>(handle, n_lists);
auto old_list_sizes_dev = raft::make_device_mdarray<uint32_t>(
handle, resource::get_workspace_resource(handle), raft::make_extents<IdxT>(n_lists));
copy(old_list_sizes_dev.data_handle(), list_sizes_ptr, n_lists, stream);

// Calculate the centers and sizes on the new data, starting from the original values
Expand Down Expand Up @@ -364,7 +367,8 @@ inline auto build(raft::resources const& handle,
auto trainset_ratio = std::max<size_t>(
1, n_rows / std::max<size_t>(params.kmeans_trainset_fraction * n_rows, index.n_lists()));
auto n_rows_train = n_rows / trainset_ratio;
rmm::device_uvector<T> trainset(n_rows_train * index.dim(), stream);
rmm::device_uvector<T> trainset(
n_rows_train * index.dim(), stream, raft::resource::get_large_workspace_resource(handle));
// TODO: a proper sampling
RAFT_CUDA_TRY(cudaMemcpy2DAsync(trainset.data(),
sizeof(T) * index.dim(),
Expand Down Expand Up @@ -424,7 +428,8 @@ inline void fill_refinement_index(raft::resources const& handle,
common::nvtx::range<common::nvtx::domain::raft> fun_scope(
"ivf_flat::fill_refinement_index(%zu, %u)", size_t(n_queries));

rmm::device_uvector<LabelT> new_labels(n_queries * n_candidates, stream);
rmm::device_uvector<LabelT> new_labels(
n_queries * n_candidates, stream, raft::resource::get_workspace_resource(handle));
auto new_labels_view =
raft::make_device_vector_view<LabelT, IdxT>(new_labels.data(), n_queries * n_candidates);
linalg::map_offset(
Expand Down
12 changes: 8 additions & 4 deletions cpp/include/raft/neighbors/detail/ivf_flat_search-inl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

#include <raft/core/logger.hpp> // RAFT_LOG_TRACE
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/core/resource/device_memory_resource.hpp> // workspace resource
#include <raft/core/resources.hpp> // raft::resources
#include <raft/distance/distance_types.hpp> // is_min_close, DistanceType
#include <raft/linalg/gemm.cuh> // raft::linalg::gemm
Expand All @@ -28,7 +29,7 @@
#include <raft/neighbors/ivf_flat_types.hpp> // raft::neighbors::ivf_flat::index
#include <raft/neighbors/sample_filter_types.hpp> // none_ivf_sample_filter
#include <raft/spatial/knn/detail/ann_utils.cuh> // utils::mapping
#include <rmm/mr/device/per_device_resource.hpp> // rmm::device_memory_resource
#include <rmm/mr/device/device_memory_resource.hpp> // rmm::device_memory_resource

namespace raft::neighbors::ivf_flat::detail {

Expand Down Expand Up @@ -220,17 +221,20 @@ inline void search(raft::resources const& handle,
common::nvtx::range<common::nvtx::domain::raft> fun_scope(
"ivf_flat::search(k = %u, n_queries = %u, dim = %zu)", k, n_queries, index.dim());

if (mr == nullptr) { mr = rmm::mr::get_current_device_resource(); }
RAFT_EXPECTS(params.n_probes > 0,
"n_probes (number of clusters to probe in the search) must be positive.");
auto n_probes = std::min<uint32_t>(params.n_probes, index.n_lists());

// a batch size heuristic: try to keep the workspace within the specified size
constexpr uint32_t kExpectedWsSize = 1024 * 1024 * 1024;
uint64_t expected_ws_size = 1024 * 1024 * 1024ull;
if (mr == nullptr) {
mr = resource::get_workspace_resource(handle);
expected_ws_size = resource::get_workspace_free_bytes(handle);
}
const uint32_t max_queries =
std::min<uint32_t>(n_queries,
raft::div_rounding_up_safe<uint64_t>(
kExpectedWsSize, 16ull * uint64_t{n_probes} * k + 4ull * index.dim()));
expected_ws_size, 16ull * uint64_t{n_probes} * k + 4ull * index.dim()));

for (uint32_t offset_q = 0; offset_q < n_queries; offset_q += max_queries) {
uint32_t queries_batch = min(max_queries, n_queries - offset_q);
Expand Down

0 comments on commit 26ae6fc

Please sign in to comment.