Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace usages of raw get_upstream with get_upstream_resource() #2207

Merged
Merged 12 commits into the base branch on
Mar 21, 2024
16 changes: 8 additions & 8 deletions cpp/test/core/device_resources_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include <rmm/mr/device/limiting_resource_adaptor.hpp>
#include <rmm/mr/device/per_device_resource.hpp>
#include <rmm/mr/device/pool_memory_resource.hpp>
#include <rmm/resource_ref.hpp>

#include <cuda_runtime_api.h>

Expand Down Expand Up @@ -114,17 +115,16 @@ TEST(DeviceResourcesManager, ObeysSetters)

auto* mr = dynamic_cast<rmm::mr::pool_memory_resource<rmm::mr::cuda_memory_resource>*>(
rmm::mr::get_current_device_resource());
auto* workspace_mr =
dynamic_cast<rmm::mr::pool_memory_resource<rmm::mr::cuda_memory_resource>*>(
dynamic_cast<rmm::mr::limiting_resource_adaptor<rmm::mr::device_memory_resource>*>(
res.get_workspace_resource())
->get_upstream());
rmm::device_async_resource_ref workspace_mr =
dynamic_cast<rmm::mr::limiting_resource_adaptor<rmm::mr::device_memory_resource>*>(
res.get_workspace_resource())
->get_upstream_resource();
if (upstream_mrs[i % devices.size()] != nullptr) {
// Expect that the current memory resource is a pool memory resource as requested
EXPECT_NE(mr, nullptr);
// Expect that the upstream workspace memory resource is a pool memory
// resource as requested
EXPECT_NE(workspace_mr, nullptr);

// We cannot easily check the type of a resource_ref
(void)workspace_mr;
}

{
Expand Down
8 changes: 5 additions & 3 deletions cpp/test/core/handle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include <rmm/device_buffer.hpp>
#include <rmm/mr/device/device_memory_resource.hpp>
#include <rmm/mr/device/pool_memory_resource.hpp>
#include <rmm/resource_ref.hpp>

#include <cuda_runtime.h>

Expand Down Expand Up @@ -281,7 +282,8 @@ TEST(Raft, WorkspaceResource)
raft::handle_t handle;

// The returned resource is always a limiting adaptor
auto* orig_mr = resource::get_workspace_resource(handle)->get_upstream();
rmm::device_async_resource_ref orig_mr{
resource::get_workspace_resource(handle)->get_upstream_resource()};

// Let's create a pooled resource
auto pool_mr = std::shared_ptr<rmm::mr::device_memory_resource>{new rmm::mr::pool_memory_resource(
Expand All @@ -295,8 +297,8 @@ TEST(Raft, WorkspaceResource)
auto new_mr = resource::get_workspace_resource(handle);

// By this point, the orig_mr likely points to a non-existent resource; don't dereference!
ASSERT_NE(orig_mr, new_mr);
ASSERT_EQ(pool_mr.get(), new_mr->get_upstream());
ASSERT_NE(orig_mr, rmm::device_async_resource_ref{new_mr});
ASSERT_EQ(rmm::device_async_resource_ref{pool_mr.get()}, new_mr->get_upstream_resource());
// We can safely reset pool_mr, because the shared_ptr to the pool memory stays in the resource
pool_mr.reset();

Expand Down
Loading