Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace all internal usage of get_upstream with get_upstream_resource #1491

Merged
merged 5 commits into from
Mar 14, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions include/rmm/mr/device/aligned_resource_adaptor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,11 @@ class aligned_resource_adaptor final : public device_memory_resource {
/**
* @briefreturn{Upstream* to the upstream memory resource}
*/
[[nodiscard]] Upstream* get_upstream() const noexcept { return upstream_; }
[[deprecated("Use get_upstream_resource instead")]] [[nodiscard]] Upstream* get_upstream()
const noexcept
{
return upstream_;
}

/**
* @brief The default alignment used by the adaptor.
Expand Down Expand Up @@ -168,7 +172,8 @@ class aligned_resource_adaptor final : public device_memory_resource {
{
if (this == &other) { return true; }
auto cast = dynamic_cast<aligned_resource_adaptor<Upstream> const*>(&other);
return cast != nullptr && upstream_->is_equal(*cast->get_upstream()) &&
if (cast == nullptr) { return false; }
return get_upstream_resource() == cast->get_upstream_resource() &&
alignment_ == cast->alignment_ && alignment_threshold_ == cast->alignment_threshold_;
}

Expand Down
17 changes: 10 additions & 7 deletions include/rmm/mr/device/binning_memory_resource.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,11 @@ class binning_memory_resource final : public device_memory_resource {
/**
* @briefreturn{Upstream* to the upstream memory resource}
*/
[[nodiscard]] Upstream* get_upstream() const noexcept { return upstream_mr_; }
[[deprecated("Use get_upstream_resource instead")]] [[nodiscard]] Upstream* get_upstream()
const noexcept
{
return upstream_mr_;
}

/**
* @brief Add a bin allocator to this resource
Expand Down Expand Up @@ -151,11 +155,11 @@ class binning_memory_resource final : public device_memory_resource {
* @param bytes Requested allocation size in bytes
* @return rmm::mr::device_memory_resource& memory_resource that can allocate the requested size.
miscco marked this conversation as resolved.
Show resolved Hide resolved
*/
device_memory_resource* get_resource(std::size_t bytes)
rmm::device_async_resource_ref get_resource_ref(std::size_t bytes)
{
auto iter = resource_bins_.lower_bound(bytes);
return (iter != resource_bins_.cend()) ? iter->second
: static_cast<device_memory_resource*>(get_upstream());
return (iter != resource_bins_.cend()) ? rmm::device_async_resource_ref{iter->second}
: get_upstream_resource();
}

/**
Expand All @@ -170,7 +174,7 @@ class binning_memory_resource final : public device_memory_resource {
void* do_allocate(std::size_t bytes, cuda_stream_view stream) override
{
if (bytes <= 0) { return nullptr; }
return get_resource(bytes)->allocate(bytes, stream);
return get_resource_ref(bytes).allocate_async(bytes, stream);
}

/**
Expand All @@ -183,8 +187,7 @@ class binning_memory_resource final : public device_memory_resource {
*/
void do_deallocate(void* ptr, std::size_t bytes, cuda_stream_view stream) override
{
auto res = get_resource(bytes);
if (res != nullptr) { res->deallocate(ptr, bytes, stream); }
get_resource_ref(bytes).deallocate_async(ptr, bytes, stream);
}

Upstream* upstream_mr_; // The upstream memory_resource from which to allocate blocks.
Expand Down
10 changes: 7 additions & 3 deletions include/rmm/mr/device/failure_callback_resource_adaptor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,11 @@ class failure_callback_resource_adaptor final : public device_memory_resource {
/**
* @briefreturn{Upstream* to the upstream memory resource}
*/
[[nodiscard]] Upstream* get_upstream() const noexcept { return upstream_; }
[[deprecated("Use get_upstream_resource instead")]] [[nodiscard]] Upstream* get_upstream()
const noexcept
{
return upstream_;
}

private:
/**
Expand Down Expand Up @@ -183,8 +187,8 @@ class failure_callback_resource_adaptor final : public device_memory_resource {
{
if (this == &other) { return true; }
auto cast = dynamic_cast<failure_callback_resource_adaptor<Upstream> const*>(&other);
return cast != nullptr ? upstream_->is_equal(*cast->get_upstream())
: upstream_->is_equal(other);
if (cast == nullptr) { return upstream_->is_equal(other); }
return get_upstream_resource() == cast->get_upstream_resource();
}

Upstream* upstream_; // the upstream resource used for satisfying allocation requests
Expand Down
10 changes: 7 additions & 3 deletions include/rmm/mr/device/fixed_size_memory_resource.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,11 @@ class fixed_size_memory_resource
/**
* @briefreturn{Upstream* to the upstream memory resource}
*/
[[nodiscard]] Upstream* get_upstream() const noexcept { return upstream_mr_; }
[[deprecated("Use get_upstream_resource instead")]] [[nodiscard]] Upstream* get_upstream()
const noexcept
{
return upstream_mr_;
}

/**
* @brief Get the size of blocks allocated by this memory resource.
Expand Down Expand Up @@ -156,7 +160,7 @@ class fixed_size_memory_resource
*/
free_list blocks_from_upstream(cuda_stream_view stream)
{
void* ptr = get_upstream()->allocate(upstream_chunk_size_, stream);
void* ptr = get_upstream_resource().allocate_async(upstream_chunk_size_, stream);
block_type block{ptr};
upstream_blocks_.push_back(block);

Expand Down Expand Up @@ -211,7 +215,7 @@ class fixed_size_memory_resource
lock_guard lock(this->get_mutex());

for (auto block : upstream_blocks_) {
get_upstream()->deallocate(block.pointer(), upstream_chunk_size_);
get_upstream_resource().deallocate(block.pointer(), upstream_chunk_size_);
}
upstream_blocks_.clear();
}
Expand Down
10 changes: 7 additions & 3 deletions include/rmm/mr/device/limiting_resource_adaptor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,11 @@ class limiting_resource_adaptor final : public device_memory_resource {
/**
* @briefreturn{Upstream* to the upstream memory resource}
*/
[[nodiscard]] Upstream* get_upstream() const noexcept { return upstream_; }
[[deprecated("Use get_upstream_resource instead")]] [[nodiscard]] Upstream* get_upstream()
const noexcept
{
return upstream_;
}

/**
* @brief Query the number of bytes that have been allocated. Note that
Expand Down Expand Up @@ -162,8 +166,8 @@ class limiting_resource_adaptor final : public device_memory_resource {
{
if (this == &other) { return true; }
auto const* cast = dynamic_cast<limiting_resource_adaptor<Upstream> const*>(&other);
if (cast != nullptr) { return upstream_->is_equal(*cast->get_upstream()); }
return upstream_->is_equal(other);
if (cast == nullptr) { return upstream_->is_equal(other); }
return get_upstream_resource() == cast->get_upstream_resource();
}

// maximum bytes this allocator is allowed to allocate.
Expand Down
10 changes: 7 additions & 3 deletions include/rmm/mr/device/logging_resource_adaptor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,11 @@ class logging_resource_adaptor final : public device_memory_resource {
/**
* @briefreturn{Upstream* to the upstream memory resource}
*/
[[nodiscard]] Upstream* get_upstream() const noexcept { return upstream_; }
[[deprecated("Use get_upstream_resource instead")]] [[nodiscard]] Upstream* get_upstream()
const noexcept
{
return upstream_;
}

/**
* @brief Flush logger contents.
Expand Down Expand Up @@ -277,8 +281,8 @@ class logging_resource_adaptor final : public device_memory_resource {
{
if (this == &other) { return true; }
auto const* cast = dynamic_cast<logging_resource_adaptor<Upstream> const*>(&other);
if (cast != nullptr) { return upstream_->is_equal(*cast->get_upstream()); }
return upstream_->is_equal(other);
if (cast == nullptr) { return upstream_->is_equal(other); }
return get_upstream_resource() == cast->get_upstream_resource();
}

// make_logging_adaptor needs access to private get_default_filename
Expand Down
10 changes: 7 additions & 3 deletions include/rmm/mr/device/pool_memory_resource.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,11 @@ class pool_memory_resource final
/**
* @briefreturn{Upstream* to the upstream memory resource}
*/
[[nodiscard]] Upstream* get_upstream() const noexcept { return upstream_mr_; }
[[deprecated("Use get_upstream_resource instead")]] [[nodiscard]] Upstream* get_upstream()
const noexcept
{
return upstream_mr_;
}

/**
* @brief Computes the size of the current pool
Expand Down Expand Up @@ -503,7 +507,7 @@ class pool_memory_resource final
if (size == 0) { return {}; }

try {
void* ptr = get_upstream()->allocate_async(size, stream);
void* ptr = get_upstream_resource().allocate_async(size, stream);
return std::optional<block_type>{
*upstream_blocks_.emplace(static_cast<char*>(ptr), size, true).first};
} catch (std::exception const& e) {
Expand Down Expand Up @@ -570,7 +574,7 @@ class pool_memory_resource final
lock_guard lock(this->get_mutex());

for (auto block : upstream_blocks_) {
get_upstream()->deallocate(block.pointer(), block.size());
get_upstream_resource().deallocate(block.pointer(), block.size());
}
upstream_blocks_.clear();
#ifdef RMM_POOL_TRACK_ALLOCATIONS
Expand Down
10 changes: 7 additions & 3 deletions include/rmm/mr/device/statistics_resource_adaptor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,11 @@ class statistics_resource_adaptor final : public device_memory_resource {
/**
* @briefreturn{Upstream* to the upstream memory resource}
*/
[[nodiscard]] Upstream* get_upstream() const noexcept { return upstream_; }
[[deprecated("Use get_upstream_resource instead")]] [[nodiscard]] Upstream* get_upstream()
const noexcept
{
return upstream_;
}

/**
* @brief Returns a `counter` struct for this adaptor containing the current,
Expand Down Expand Up @@ -209,8 +213,8 @@ class statistics_resource_adaptor final : public device_memory_resource {
{
if (this == &other) { return true; }
auto cast = dynamic_cast<statistics_resource_adaptor<Upstream> const*>(&other);
return cast != nullptr ? upstream_->is_equal(*cast->get_upstream())
: upstream_->is_equal(other);
if (cast == nullptr) { return upstream_->is_equal(other); }
return get_upstream_resource() == cast->get_upstream_resource();
}

counter bytes_; // peak, current and total allocated bytes
Expand Down
14 changes: 8 additions & 6 deletions include/rmm/mr/device/thread_safe_resource_adaptor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,11 @@ class thread_safe_resource_adaptor final : public device_memory_resource {
/**
* @briefreturn{Upstream* to the upstream memory resource}
*/
[[nodiscard]] Upstream* get_upstream() const noexcept { return upstream_; }
[[deprecated("Use get_upstream_resource instead")]] [[nodiscard]] Upstream* get_upstream()
const noexcept
{
return upstream_;
}

private:
/**
Expand Down Expand Up @@ -119,11 +123,9 @@ class thread_safe_resource_adaptor final : public device_memory_resource {
bool do_is_equal(device_memory_resource const& other) const noexcept override
{
if (this == &other) { return true; }
auto thread_safe_other = dynamic_cast<thread_safe_resource_adaptor<Upstream> const*>(&other);
if (thread_safe_other != nullptr) {
return upstream_->is_equal(*thread_safe_other->get_upstream());
}
return upstream_->is_equal(other);
auto cast = dynamic_cast<thread_safe_resource_adaptor<Upstream> const*>(&other);
if (cast == nullptr) { return upstream_->is_equal(other); }
return get_upstream_resource() == cast->get_upstream_resource();
}

std::mutex mutable mtx; // mutex for thread safe access to upstream
Expand Down
10 changes: 7 additions & 3 deletions include/rmm/mr/device/tracking_resource_adaptor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,11 @@ class tracking_resource_adaptor final : public device_memory_resource {
/**
* @briefreturn{Upstream* to the upstream memory resource}
*/
[[nodiscard]] Upstream* get_upstream() const noexcept { return upstream_; }
[[deprecated("Use get_upstream_resource instead")]] [[nodiscard]] Upstream* get_upstream()
const noexcept
{
return upstream_;
}

/**
* @brief Get the outstanding allocations map
Expand Down Expand Up @@ -264,8 +268,8 @@ class tracking_resource_adaptor final : public device_memory_resource {
{
if (this == &other) { return true; }
auto cast = dynamic_cast<tracking_resource_adaptor<Upstream> const*>(&other);
return cast != nullptr ? upstream_->is_equal(*cast->get_upstream())
: upstream_->is_equal(other);
if (cast == nullptr) { return upstream_->is_equal(other); }
return get_upstream_resource() == cast->get_upstream_resource();
}

bool capture_stacks_; // whether or not to capture call stacks
Expand Down
10 changes: 7 additions & 3 deletions tests/device_check_resource_adaptor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,11 @@ class device_check_resource_adaptor final : public rmm::mr::device_memory_resour
/**
* @briefreturn{device_memory_resource* to the upstream memory resource}
*/
[[nodiscard]] device_memory_resource* get_upstream() const noexcept { return upstream_; }
[[deprecated("Use get_upstream_resource instead")]] [[nodiscard]] device_memory_resource*
get_upstream() const noexcept
{
return upstream_;
}

private:
[[nodiscard]] bool check_device_id() const { return device_id == rmm::get_current_cuda_device(); }
Expand All @@ -64,8 +68,8 @@ class device_check_resource_adaptor final : public rmm::mr::device_memory_resour
{
if (this == &other) { return true; }
auto const* cast = dynamic_cast<device_check_resource_adaptor const*>(&other);
if (cast != nullptr) { return upstream_->is_equal(*cast->get_upstream()); }
return upstream_->is_equal(other);
if (cast == nullptr) { return upstream_->is_equal(other); }
return get_upstream_resource() == cast->get_upstream_resource();
}

rmm::cuda_device_id device_id;
Expand Down
9 changes: 0 additions & 9 deletions tests/mr/device/adaptor_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -135,15 +135,6 @@ TYPED_TEST(AdaptorTest, Equality)
}
}

TYPED_TEST(AdaptorTest, GetUpstream)
{
if constexpr (std::is_same_v<TypeParam, owning_wrapper>) {
EXPECT_TRUE(this->mr->wrapped().get_upstream()->is_equal(this->cuda));
} else {
EXPECT_TRUE(this->mr->get_upstream()->is_equal(this->cuda));
}
}

TYPED_TEST(AdaptorTest, GetUpstreamResource)
{
rmm::device_async_resource_ref expected{this->cuda};
Expand Down
5 changes: 3 additions & 2 deletions tests/mr/device/statistics_mr_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,8 @@ TEST(StatisticsTest, PeakAllocations)

TEST(StatisticsTest, MultiTracking)
{
statistics_adaptor mr{rmm::mr::get_current_device_resource()};
auto* orig_device_resource = rmm::mr::get_current_device_resource();
statistics_adaptor mr{orig_device_resource};
rmm::mr::set_current_device_resource(&mr);

std::vector<std::shared_ptr<rmm::device_buffer>> allocations;
Expand Down Expand Up @@ -171,7 +172,7 @@ TEST(StatisticsTest, MultiTracking)
EXPECT_EQ(inner_mr.get_allocations_counter().peak, 5);

// Reset the current device resource
rmm::mr::set_current_device_resource(mr.get_upstream());
rmm::mr::set_current_device_resource(orig_device_resource);
}

TEST(StatisticsTest, NegativeInnerTracking)
Expand Down
5 changes: 3 additions & 2 deletions tests/mr/device/tracking_mr_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,8 @@ TEST(TrackingTest, AllocationsLeftWithoutStacks)

TEST(TrackingTest, MultiTracking)
{
tracking_adaptor mr{rmm::mr::get_current_device_resource(), true};
auto* orig_device_resource = rmm::mr::get_current_device_resource();
tracking_adaptor mr{orig_device_resource, true};
rmm::mr::set_current_device_resource(&mr);

std::vector<std::shared_ptr<rmm::device_buffer>> allocations;
Expand Down Expand Up @@ -140,7 +141,7 @@ TEST(TrackingTest, MultiTracking)
EXPECT_EQ(inner_mr.get_allocated_bytes(), 0);

// Reset the current device resource
rmm::mr::set_current_device_resource(mr.get_upstream());
rmm::mr::set_current_device_resource(orig_device_resource);
}

TEST(TrackingTest, NegativeInnerTracking)
Expand Down
Loading