[Discussion] Scaling workspace resources
Add another workspace memory resource that does not have an explicit memory limit.
It should be used for large allocations; a user can set it to a host-memory-backed resource, such as managed memory, for better scaling and to avoid OOM errors.
achirkin committed Feb 22, 2024
1 parent 9fb05a2 commit 4e5d842
Showing 6 changed files with 135 additions and 31 deletions.
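To make the intent concrete, here is a minimal sketch (not part of the commit) of how a user could point the new large-workspace resource at managed memory; `configure_large_workspace` is a hypothetical helper name, but `set_large_workspace_resource` is the API added below:

```cpp
#include <raft/core/resource/device_memory_resource.hpp>
#include <raft/core/resources.hpp>

#include <rmm/mr/device/managed_memory_resource.hpp>

#include <memory>

// Route large temporary allocations through CUDA managed (unified) memory,
// which can spill to host memory instead of failing with a device OOM.
void configure_large_workspace(raft::resources& res)
{
  raft::resource::set_large_workspace_resource(
    res, std::make_shared<rmm::mr::managed_memory_resource>());
}
```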
20 changes: 17 additions & 3 deletions cpp/bench/ann/src/raft/raft_ann_bench_utils.h
@@ -27,6 +27,7 @@
#include <rmm/cuda_stream_view.hpp>
#include <rmm/device_uvector.hpp>
#include <rmm/mr/device/failure_callback_resource_adaptor.hpp>
+#include <rmm/mr/device/managed_memory_resource.hpp>
#include <rmm/mr/device/pool_memory_resource.hpp>

#include <memory>
@@ -70,13 +71,14 @@ inline auto rmm_oom_callback(std::size_t bytes, void*) -> bool
*/
class shared_raft_resources {
public:
-  using pool_mr_type = rmm::mr::pool_memory_resource<rmm::mr::device_memory_resource>;
-  using mr_type      = rmm::mr::failure_callback_resource_adaptor<pool_mr_type>;
+  using pool_mr_type  = rmm::mr::pool_memory_resource<rmm::mr::device_memory_resource>;
+  using mr_type       = rmm::mr::failure_callback_resource_adaptor<pool_mr_type>;
+  using large_mr_type = rmm::mr::managed_memory_resource;

shared_raft_resources()
try : orig_resource_{rmm::mr::get_current_device_resource()},
pool_resource_(orig_resource_, 1024 * 1024 * 1024ull),
-        resource_(&pool_resource_, rmm_oom_callback, nullptr) {
+        resource_(&pool_resource_, rmm_oom_callback, nullptr), large_mr_() {
rmm::mr::set_current_device_resource(&resource_);
} catch (const std::exception& e) {
auto cuda_status = cudaGetLastError();
@@ -99,10 +101,16 @@ class shared_raft_resources {

~shared_raft_resources() noexcept { rmm::mr::set_current_device_resource(orig_resource_); }

+  auto get_large_memory_resource() noexcept
+  {
+    return static_cast<rmm::mr::device_memory_resource*>(&large_mr_);
+  }

private:
rmm::mr::device_memory_resource* orig_resource_;
pool_mr_type pool_resource_;
mr_type resource_;
+  large_mr_type large_mr_;
};

/**
@@ -123,6 +131,12 @@ class configured_raft_resources {
explicit configured_raft_resources(const std::shared_ptr<shared_raft_resources>& shared_res)
: shared_res_{shared_res}, res_{rmm::cuda_stream_view(get_stream_from_global_pool())}
{
+    // set the large workspace resource to the raft handle, but without the deleter
+    // (this resource is managed by the shared_res).
+    raft::resource::set_large_workspace_resource(
+      res_,
+      std::shared_ptr<rmm::mr::device_memory_resource>(shared_res_->get_large_memory_resource(),
+                                                       raft::void_op{}));
}

/** Default constructor creates all resources anew. */
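The constructor above hands the shared resource to each handle through a `shared_ptr` with a no-op deleter, so ownership stays with `shared_raft_resources`; a sketch of that idiom in isolation (`non_owning_view` is a hypothetical helper, assuming `raft::void_op` from `<raft/core/operators.hpp>`):

```cpp
#include <raft/core/operators.hpp>  // raft::void_op

#include <rmm/mr/device/device_memory_resource.hpp>

#include <memory>

// Wrap a raw resource pointer without taking ownership: the no-op deleter
// means destroying the shared_ptr never frees the underlying resource.
inline auto non_owning_view(rmm::mr::device_memory_resource* mr)
  -> std::shared_ptr<rmm::mr::device_memory_resource>
{
  return {mr, raft::void_op{}};
}
```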
50 changes: 49 additions & 1 deletion cpp/include/raft/core/resource/device_memory_resource.hpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2022-2023, NVIDIA CORPORATION.
+ * Copyright (c) 2022-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -35,6 +35,16 @@ namespace raft::resource {
* @{
*/

+class memory_resource : public resource {
+ public:
+  explicit memory_resource(std::shared_ptr<rmm::mr::device_memory_resource> mr) : mr_(mr) {}
+  ~memory_resource() override = default;
+  auto get_resource() -> void* override { return mr_.get(); }
+
+ private:
+  std::shared_ptr<rmm::mr::device_memory_resource> mr_;
+};

class limiting_memory_resource : public resource {
public:
limiting_memory_resource(std::shared_ptr<rmm::mr::device_memory_resource> mr,
@@ -66,6 +76,29 @@ class limiting_memory_resource : public resource {
}
};

+/**
+ * Factory that knows how to construct a specific raft::resource to populate
+ * the resources instance.
+ */
+class large_workspace_resource_factory : public resource_factory {
+ public:
+  explicit large_workspace_resource_factory(
+    std::shared_ptr<rmm::mr::device_memory_resource> mr = {nullptr})
+    : mr_{mr ? mr
+             : std::shared_ptr<rmm::mr::device_memory_resource>{
+                 rmm::mr::get_current_device_resource(), void_op{}}}
+  {
+  }
+  auto get_resource_type() -> resource_type override
+  {
+    return resource_type::LARGE_MEMORY_RESOURCE;
+  }
+  auto make_resource() -> resource* override { return new memory_resource(mr_); }
+
+ private:
+  std::shared_ptr<rmm::mr::device_memory_resource> mr_;
+};

/**
* Factory that knows how to construct a specific raft::resource to populate
* the resources instance.
@@ -241,6 +274,21 @@ inline void set_workspace_to_global_resource(
workspace_resource_factory::default_plain_resource(), allocation_limit, std::nullopt));
};

+inline auto get_large_workspace_resource(resources const& res) -> rmm::mr::device_memory_resource*
+{
+  if (!res.has_resource_factory(resource_type::LARGE_MEMORY_RESOURCE)) {
+    res.add_resource_factory(std::make_shared<large_workspace_resource_factory>());
+  }
+  return res.get_resource<rmm::mr::device_memory_resource>(resource_type::LARGE_MEMORY_RESOURCE);
+};
+
+inline void set_large_workspace_resource(resources const& res,
+                                         std::shared_ptr<rmm::mr::device_memory_resource> mr = {
+                                           nullptr})
+{
+  res.add_resource_factory(std::make_shared<large_workspace_resource_factory>(mr));
+};

/** @} */

} // namespace raft::resource
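Taken together, the factory and accessors above give the large workspace resource a safe default; a sketch of how callers see the two workspace resources (assuming `get_workspace_resource` is the pre-existing limited-workspace getter, as used later in this commit):

```cpp
#include <raft/core/resource/device_memory_resource.hpp>
#include <raft/core/resources.hpp>

void inspect_workspaces(raft::resources const& res)
{
  // Existing workspace: wrapped in a limiting adaptor with an allocation cap.
  auto* small_tmp_mr = raft::resource::get_workspace_resource(res);

  // New large workspace: no cap; if never set, the factory above lazily wraps
  // rmm::mr::get_current_device_resource() in a non-owning shared_ptr.
  auto* large_tmp_mr = raft::resource::get_large_workspace_resource(res);

  (void)small_tmp_mr;
  (void)large_tmp_mr;
}
```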
3 changes: 2 additions & 1 deletion cpp/include/raft/core/resource/resource_types.hpp
@@ -42,9 +42,10 @@ enum resource_type {
STREAM_VIEW, // view of a cuda stream or a placeholder in
// CUDA-free builds
THRUST_POLICY, // thrust execution policy
-  WORKSPACE_RESOURCE,  // rmm device memory resource
+  WORKSPACE_RESOURCE,  // rmm device memory resource for small temporary allocations
CUBLASLT_HANDLE, // cublasLt handle
CUSTOM, // runtime-shared default-constructible resource
+  LARGE_MEMORY_RESOURCE,  // rmm device memory resource for somewhat large temporary allocations

LAST_KEY // reserved for the last key
};
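Each enum value keys one slot in the `resources` container; a sketch of how the getter in `device_memory_resource.hpp` resolves the new key through the container's `get_resource` template (`lookup_large_mr` is a hypothetical illustration):

```cpp
#include <raft/core/resource/resource_types.hpp>
#include <raft/core/resources.hpp>

#include <rmm/mr/device/device_memory_resource.hpp>

void lookup_large_mr(raft::resources const& res)
{
  // The factory registered under LARGE_MEMORY_RESOURCE constructs the
  // underlying resource lazily on first lookup.
  auto* mr = res.get_resource<rmm::mr::device_memory_resource>(
    raft::resource::resource_type::LARGE_MEMORY_RESOURCE);
  (void)mr;
}
```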
26 changes: 17 additions & 9 deletions cpp/include/raft/neighbors/detail/cagra/graph_core.cuh
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2023, NVIDIA CORPORATION.
+ * Copyright (c) 2023-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -244,15 +244,19 @@ void sort_knn_graph(raft::resources const& res,
const uint32_t input_graph_degree = knn_graph.extent(1);
IdxT* const input_graph_ptr = knn_graph.data_handle();

-  auto d_input_graph = raft::make_device_matrix<IdxT, int64_t>(res, graph_size, input_graph_degree);
+  auto large_tmp_mr = resource::get_large_workspace_resource(res);
+
+  auto d_input_graph = raft::make_device_mdarray<IdxT>(
+    res, large_tmp_mr, raft::make_extents<int64_t>(graph_size, input_graph_degree));

//
// Sorting kNN graph
//
const double time_sort_start = cur_time();
RAFT_LOG_DEBUG("# Sorting kNN Graph on GPUs ");

-  auto d_dataset = raft::make_device_matrix<DataT, int64_t>(res, dataset_size, dataset_dim);
+  auto d_dataset = raft::make_device_mdarray<DataT>(
+    res, large_tmp_mr, raft::make_extents<int64_t>(dataset_size, dataset_dim));
raft::copy(d_dataset.data_handle(),
dataset_ptr,
dataset_size * dataset_dim,
@@ -323,6 +327,7 @@ void optimize(raft::resources const& res,
{
RAFT_LOG_DEBUG(
"# Pruning kNN graph (size=%lu, degree=%lu)\n", knn_graph.extent(0), knn_graph.extent(1));
+  auto large_tmp_mr = resource::get_large_workspace_resource(res);

RAFT_EXPECTS(knn_graph.extent(0) == new_graph.extent(0),
"Each input array is expected to have the same number of rows");
@@ -338,15 +343,16 @@
//
// Prune kNN graph
//
-  auto d_detour_count =
-    raft::make_device_matrix<uint8_t, int64_t>(res, graph_size, input_graph_degree);
+  auto d_detour_count = raft::make_device_mdarray<uint8_t>(
+    res, large_tmp_mr, raft::make_extents<int64_t>(graph_size, input_graph_degree));

RAFT_CUDA_TRY(cudaMemsetAsync(d_detour_count.data_handle(),
0xff,
graph_size * input_graph_degree * sizeof(uint8_t),
resource::get_cuda_stream(res)));

-  auto d_num_no_detour_edges = raft::make_device_vector<uint32_t, int64_t>(res, graph_size);
+  auto d_num_no_detour_edges = raft::make_device_mdarray<uint32_t>(
+    res, large_tmp_mr, raft::make_extents<int64_t>(graph_size));
RAFT_CUDA_TRY(cudaMemsetAsync(d_num_no_detour_edges.data_handle(),
0x00,
graph_size * sizeof(uint32_t),
@@ -466,14 +472,16 @@ void optimize(raft::resources const& res,
graph_size * output_graph_degree * sizeof(IdxT),
resource::get_cuda_stream(res)));

-  auto d_rev_graph_count = raft::make_device_vector<uint32_t, int64_t>(res, graph_size);
+  auto d_rev_graph_count = raft::make_device_mdarray<uint32_t>(
+    res, large_tmp_mr, raft::make_extents<int64_t>(graph_size));
RAFT_CUDA_TRY(cudaMemsetAsync(d_rev_graph_count.data_handle(),
0x00,
graph_size * sizeof(uint32_t),
resource::get_cuda_stream(res)));

-  auto dest_nodes   = raft::make_host_vector<IdxT, int64_t>(graph_size);
-  auto d_dest_nodes = raft::make_device_vector<IdxT, int64_t>(res, graph_size);
+  auto dest_nodes = raft::make_host_vector<IdxT, int64_t>(graph_size);
+  auto d_dest_nodes =
+    raft::make_device_mdarray<IdxT>(res, large_tmp_mr, raft::make_extents<int64_t>(graph_size));

for (uint64_t k = 0; k < output_graph_degree; k++) {
#pragma omp parallel for
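The recurring change in this file swaps `make_device_matrix`/`make_device_vector`, which allocate from the default resource, for `make_device_mdarray` over an explicit memory resource; a condensed sketch of the pattern (function and buffer names illustrative):

```cpp
#include <raft/core/device_mdarray.hpp>
#include <raft/core/resource/device_memory_resource.hpp>
#include <raft/core/resources.hpp>

#include <cstdint>

template <typename IdxT>
void allocate_graph_temporaries(raft::resources const& res,
                                int64_t graph_size,
                                int64_t graph_degree)
{
  // Large, short-lived buffers draw from the (possibly host-backed) resource.
  auto* large_tmp_mr = raft::resource::get_large_workspace_resource(res);
  auto d_graph       = raft::make_device_mdarray<IdxT>(
    res, large_tmp_mr, raft::make_extents<int64_t>(graph_size, graph_degree));
  (void)d_graph;
}
```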
9 changes: 6 additions & 3 deletions cpp/include/raft/neighbors/detail/cagra/utils.hpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2023, NVIDIA CORPORATION.
+ * Copyright (c) 2023-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -176,8 +176,11 @@ class device_matrix_view_from_host {
device_ptr = reinterpret_cast<T*>(attr.devicePointer);
if (device_ptr == NULL) {
// allocate memory and copy over
-      device_mem_.emplace(
-        raft::make_device_matrix<T, IdxT>(res, host_view.extent(0), host_view.extent(1)));
+      // NB: We use the temporary "large" workspace resource here; this structure is supposed to
+      // live on stack and not returned to a user.
+      // The user may opt to set this resource to managed memory to allow large allocations.
+      device_mem_.emplace(make_device_mdarray<T, IdxT>(
+        res, resource::get_large_workspace_resource(res), host_view.extents()));
raft::copy(device_mem_->data_handle(),
host_view.data_handle(),
host_view.extent(0) * host_view.extent(1),
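For context, the helper above stages a copy only when the host pointer is not already device-accessible; a simplified sketch of that check (`resolve_device_pointer` is a hypothetical free function, assuming `RAFT_CUDA_TRY` from raft's CUDA utilities):

```cpp
#include <raft/util/cudart_utils.hpp>  // RAFT_CUDA_TRY (location is an assumption)

#include <cuda_runtime_api.h>

template <typename T>
auto resolve_device_pointer(const T* host_ptr) -> T*
{
  // For plain pageable host memory, devicePointer stays null and the caller
  // must stage an explicit copy into a device buffer.
  cudaPointerAttributes attr;
  RAFT_CUDA_TRY(cudaPointerGetAttributes(&attr, host_ptr));
  return reinterpret_cast<T*>(attr.devicePointer);
}
```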
58 changes: 44 additions & 14 deletions cpp/include/raft/neighbors/detail/ivf_pq_build.cuh
@@ -16,7 +16,6 @@

#pragma once

-#include <raft/core/resource/cuda_stream.hpp>
#include <raft/spatial/knn/detail/ann_utils.cuh>

#include <raft/neighbors/detail/ivf_pq_codepacking.cuh>
@@ -28,6 +27,8 @@
#include <raft/core/logger.hpp>
#include <raft/core/nvtx.hpp>
#include <raft/core/operators.hpp>
+#include <raft/core/resource/cuda_stream.hpp>
+#include <raft/core/resource/device_memory_resource.hpp>
#include <raft/core/resources.hpp>
#include <raft/distance/distance_types.hpp>
#include <raft/linalg/add.cuh>
@@ -46,9 +47,9 @@
#include <raft/util/pow2_utils.cuh>
#include <raft/util/vectorized.cuh>

-#include <raft/core/resource/device_memory_resource.hpp>
#include <rmm/cuda_stream_view.hpp>
#include <rmm/device_uvector.hpp>
#include <rmm/mr/device/device_memory_resource.hpp>
#include <rmm/mr/device/managed_memory_resource.hpp>

#include <thrust/extrema.h>
@@ -1569,6 +1570,8 @@ void extend(raft::resources const& handle,
"Unsupported data type");

rmm::mr::device_memory_resource* device_memory = raft::resource::get_workspace_resource(handle);
+  rmm::mr::device_memory_resource* large_memory =
+    raft::resource::get_large_workspace_resource(handle);

// The spec defines how the clusters look like
auto spec = list_spec<uint32_t, IdxT>{
@@ -1586,9 +1589,20 @@
size_t free_mem, total_mem;
RAFT_CUDA_TRY(cudaMemGetInfo(&free_mem, &total_mem));

+  // We try to use the workspace memory by default here.
+  // If the workspace limit is too small, we change the resource for batch data to the
+  // `large_workspace_resource`, which does not have the explicit allocation limit. The user may
+  // opt to populate the `large_workspace_resource` memory resource with managed memory for easier
+  // scaling.
+  rmm::mr::device_memory_resource* labels_mr  = device_memory;
+  rmm::mr::device_memory_resource* batches_mr = device_memory;
+  if (n_rows * (index->dim() * sizeof(T) + index->pq_dim() + sizeof(IdxT) + sizeof(uint32_t)) >
+      free_mem) {
+    labels_mr = large_memory;
+  }
// Allocate a buffer for the new labels (classifying the new data)
-  rmm::device_uvector<uint32_t> new_data_labels(n_rows, stream, device_memory);
-  free_mem -= sizeof(uint32_t) * n_rows;
+  rmm::device_uvector<uint32_t> new_data_labels(n_rows, stream, labels_mr);
+  if (labels_mr == device_memory) { free_mem -= sizeof(uint32_t) * n_rows; }

// Calculate the batch size for the input data if it's not accessible directly from the device
constexpr size_t kReasonableMaxBatchSize = 65536;
@@ -1617,13 +1631,19 @@
while (size_factor * max_batch_size > free_mem && max_batch_size > 128) {
max_batch_size >>= 1;
}
-    // If we're keeping the batches in device memory, update the available mem tracker.
-    free_mem -= size_factor * max_batch_size;
+    if (size_factor * max_batch_size > free_mem) {
+      // if that still doesn't fit, resort to the UVM
+      batches_mr     = large_memory;
+      max_batch_size = kReasonableMaxBatchSize;
+    } else {
+      // If we're keeping the batches in device memory, update the available mem tracker.
+      free_mem -= size_factor * max_batch_size;
+    }
}

// Predict the cluster labels for the new data, in batches if necessary
utils::batch_load_iterator<T> vec_batches(
-    new_vectors, n_rows, index->dim(), max_batch_size, stream, device_memory);
+    new_vectors, n_rows, index->dim(), max_batch_size, stream, batches_mr);
// Release the placeholder memory, because we don't intend to allocate any more long-living
// temporary buffers before we allocate the index data.
// This memory could potentially speed up UVM accesses, if any.
@@ -1696,7 +1716,7 @@ void extend(raft::resources const& handle,
// By this point, the index state is updated and valid except it doesn't contain the new data
// Fill the extended index with the new data (possibly, in batches)
utils::batch_load_iterator<IdxT> idx_batches(
-    new_indices, n_rows, 1, max_batch_size, stream, device_memory);
+    new_indices, n_rows, 1, max_batch_size, stream, batches_mr);
for (const auto& vec_batch : vec_batches) {
const auto& idx_batch = *idx_batches++;
process_and_fill_codes(handle,
@@ -1707,7 +1727,7 @@
: std::variant<IdxT, const IdxT*>(IdxT(idx_batch.offset())),
new_data_labels.data() + vec_batch.offset(),
IdxT(vec_batch.size()),
-                           device_memory);
+                           batches_mr);
}
}

@@ -1760,11 +1780,21 @@ auto build(raft::resources const& handle,
size_t n_rows_train = n_rows / trainset_ratio;

auto* device_memory = resource::get_workspace_resource(handle);
-  rmm::mr::managed_memory_resource managed_memory_upstream;
+  rmm::mr::managed_memory_resource managed_memory;

+  // If the trainset is small enough to comfortably fit into device memory, put it there.
+  // Otherwise, use the managed memory.
+  constexpr size_t kTolerableRatio = 4;
+  rmm::mr::device_memory_resource* big_memory_resource =
+    resource::get_large_workspace_resource(handle);
+  if (sizeof(float) * n_rows_train * index.dim() * kTolerableRatio <
+      resource::get_workspace_free_bytes(handle)) {
+    big_memory_resource = device_memory;
+  }

// Besides just sampling, we transform the input dataset into floats to make it easier
// to use gemm operations from cublas.
-  rmm::device_uvector<float> trainset(n_rows_train * index.dim(), stream, device_memory);
+  rmm::device_uvector<float> trainset(n_rows_train * index.dim(), stream, big_memory_resource);
// TODO: a proper sampling
if constexpr (std::is_same_v<T, float>) {
RAFT_CUDA_TRY(cudaMemcpy2DAsync(trainset.data(),
@@ -1833,7 +1863,7 @@ auto build(raft::resources const& handle,
handle, kmeans_params, trainset_const_view, centers_view, utils::mapping<float>{});

// Trainset labels are needed for training PQ codebooks
-  rmm::device_uvector<uint32_t> labels(n_rows_train, stream, device_memory);
+  rmm::device_uvector<uint32_t> labels(n_rows_train, stream, big_memory_resource);
auto centers_const_view = raft::make_device_matrix_view<const float, IdxT>(
cluster_centers, index.n_lists(), index.dim());
auto labels_view = raft::make_device_vector_view<uint32_t, IdxT>(labels.data(), n_rows_train);
@@ -1862,7 +1892,7 @@ auto build(raft::resources const& handle,
trainset.data(),
labels.data(),
params.kmeans_n_iters,
-                     &managed_memory_upstream);
+                     &managed_memory);
break;
case codebook_gen::PER_CLUSTER:
train_per_cluster(handle,
Expand All @@ -1871,7 +1901,7 @@ auto build(raft::resources const& handle,
trainset.data(),
labels.data(),
params.kmeans_n_iters,
-                     &managed_memory_upstream);
+                     &managed_memory);
break;
default: RAFT_FAIL("Unreachable code");
}
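Both `extend` and `build` now follow the same fit-or-fallback policy; a condensed sketch (`choose_mr` is a hypothetical helper; the real code above additionally tracks remaining free memory across allocations):

```cpp
#include <rmm/mr/device/device_memory_resource.hpp>

#include <cuda_runtime_api.h>

#include <cstddef>

// Use the limited workspace resource when the buffer fits in free device
// memory; otherwise fall back to the unlimited large-workspace resource,
// which the user may have pointed at managed memory.
inline auto choose_mr(std::size_t required_bytes,
                      rmm::mr::device_memory_resource* workspace_mr,
                      rmm::mr::device_memory_resource* large_mr)
  -> rmm::mr::device_memory_resource*
{
  std::size_t free_mem{}, total_mem{};
  cudaMemGetInfo(&free_mem, &total_mem);
  return required_bytes > free_mem ? large_mr : workspace_mr;
}
```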
