Skip to content

Commit

Permalink
More wrapper use.
Browse files Browse the repository at this point in the history
  • Loading branch information
harrism committed Aug 29, 2024
1 parent 2f11ca4 commit 20cfc85
Show file tree
Hide file tree
Showing 8 changed files with 16 additions and 15 deletions.
4 changes: 2 additions & 2 deletions cpp/bench/ann/src/faiss/faiss_gpu_benchmark.cu
Original file line number Diff line number Diff line change
Expand Up @@ -183,10 +183,10 @@ int main(int argc, char** argv)
rmm::mr::pool_memory_resource<rmm::mr::cuda_memory_resource> pool_mr{
&cuda_mr, rmm::percent_of_free_device_memory(50)};
// Updates the current device resource pointer to `pool_mr`
auto old_mr = rmm::mr::set_current_device_resource(&pool_mr);
auto old_mr = raft::resource::set_current_device_resource(&pool_mr);
auto ret = raft::bench::ann::run_main(argc, argv);
// Restores the current device resource pointer to its previous value
rmm::mr::set_current_device_resource(old_mr);
raft::resource::set_current_device_resource(old_mr);
return ret;
}
#endif
6 changes: 3 additions & 3 deletions cpp/bench/ann/src/raft/raft_ann_bench_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,10 @@ class shared_raft_resources {
using large_mr_type = rmm::mr::managed_memory_resource;

shared_raft_resources()
try : orig_resource_{raft::resource::get_current_device_resource_ref()},
try : orig_resource_{raft::resource::get_current_device_resource()},
pool_resource_(orig_resource_, 1024 * 1024 * 1024ull),
resource_(&pool_resource_, rmm_oom_callback, nullptr), large_mr_() {
rmm::mr::set_current_device_resource(&resource_);
raft::resource::set_current_device_resource(&resource_);
} catch (const std::exception& e) {
auto cuda_status = cudaGetLastError();
size_t free = 0;
Expand All @@ -103,7 +103,7 @@ class shared_raft_resources {
shared_raft_resources(const shared_raft_resources& res) = delete;
shared_raft_resources& operator=(const shared_raft_resources& other) = delete;

~shared_raft_resources() noexcept { rmm::mr::set_current_device_resource(orig_resource_); }
~shared_raft_resources() noexcept { raft::resource::set_current_device_resource(orig_resource_); }

auto get_large_memory_resource() noexcept
{
Expand Down
4 changes: 2 additions & 2 deletions cpp/bench/ann/src/raft/raft_cagra_hnswlib.cu
Original file line number Diff line number Diff line change
Expand Up @@ -91,10 +91,10 @@ int main(int argc, char** argv)
rmm::mr::pool_memory_resource<rmm::mr::cuda_memory_resource> pool_mr{
&cuda_mr, rmm::percent_of_free_device_memory(50)};
// Updates the current device resource pointer to `pool_mr`
auto old_mr = rmm::mr::set_current_device_resource(&pool_mr);
auto old_mr = raft::resource::set_current_device_resource(&pool_mr);
auto ret = raft::bench::ann::run_main(argc, argv);
// Restores the current device resource pointer to its previous value
rmm::mr::set_current_device_resource(old_mr);
raft::resource::set_current_device_resource(old_mr);
return ret;
}
#endif
7 changes: 4 additions & 3 deletions cpp/bench/prims/common/benchmark.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include <raft/core/device_resources.hpp>
#include <raft/core/interruptible.hpp>
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/core/resource/device_memory_resource.hpp>
#include <raft/random/make_blobs.cuh>
#include <raft/util/cudart_utils.hpp>

Expand Down Expand Up @@ -53,17 +54,17 @@ struct using_pool_memory_res {
: orig_res_(raft::resource::get_current_device_resource_ref()),
pool_res_(&cuda_res_, initial_size, max_size)
{
rmm::mr::set_current_device_resource(&pool_res_);
raft::resource::set_current_device_resource(&pool_res_);
}

using_pool_memory_res()
: orig_res_(raft::resource::get_current_device_resource_ref()),
pool_res_(&cuda_res_, rmm::percent_of_free_device_memory(50))
{
rmm::mr::set_current_device_resource(&pool_res_);
raft::resource::set_current_device_resource(&pool_res_);
}

~using_pool_memory_res() { rmm::mr::set_current_device_resource(orig_res_); }
~using_pool_memory_res() { raft::resource::set_current_device_resource(orig_res_); }
};

/**
Expand Down
4 changes: 2 additions & 2 deletions cpp/bench/prims/matrix/gather.cu
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,10 @@ struct Gather : public fixture {
stencil(this->handle),
matrix_h(this->handle)
{
rmm::mr::set_current_device_resource(&pool_mr);
raft::resource::set_current_device_resource(&pool_mr);
}

~Gather() { rmm::mr::set_current_device_resource(old_mr); }
~Gather() { raft::resource::set_current_device_resource(old_mr); }

void allocate_data(const ::benchmark::State& state) override
{
Expand Down
2 changes: 1 addition & 1 deletion cpp/bench/prims/neighbors/refine.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ class RefineAnn : public fixture {
auto old_mr = raft::resource::get_current_device_resource_ref();
rmm::mr::pool_memory_resource<rmm::mr::device_memory_resource> pool_mr(
old_mr, rmm::percent_of_free_device_memory(50));
rmm::mr::set_current_device_resource(&pool_mr);
raft::resource::set_current_device_resource(&pool_mr);

if (data.p.host_data) {
loop_on_state(state, [this]() {
Expand Down
2 changes: 1 addition & 1 deletion cpp/bench/prims/random/subsample.cu
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ struct sample : public fixture {
in(make_device_vector<T, int64_t>(res, p.n_samples)),
out(make_device_vector<T, int64_t>(res, p.n_train))
{
rmm::mr::set_current_device_resource(&pool_mr);
raft::resource::set_current_device_resource(&pool_mr);
raft::random::RngState r(123456ULL);
}

Expand Down
2 changes: 1 addition & 1 deletion cpp/include/raft/core/device_resources_manager.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ struct device_resources_manager {
upstream,
params.init_mem_pool_size.value_or(rmm::percent_of_free_device_memory(50)),
params.max_mem_pool_size);
rmm::mr::set_current_device_resource(result.get());
raft::resource::set_current_device_resource(result.get());
} else {
RAFT_LOG_WARN(
"Pool allocation requested, but other memory resource has already been set and "
Expand Down

0 comments on commit 20cfc85

Please sign in to comment.