Commit

Make device_resources accessed from device_resources_manager thread-safe (#2030)

Update device_resources_manager to reuse only the memory manager, stream, and stream pools across threads. Create a unique resources object per device for each thread, since the resources object is not thread-safe.
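
As an illustration of the usage pattern this enables, here is a hedged sketch of several host threads sharing the manager. The static entry points set_streams_per_device() and get_device_resources() are assumed from the public API declared elsewhere in this header (they do not appear in this diff), and the surrounding function and thread count are purely illustrative.

#include <raft/core/device_resources_manager.hpp>

#include <cstddef>
#include <thread>
#include <vector>

void run_on_worker_threads(std::size_t n_threads)
{
  // Assumed API: configure the shared per-device stream pool once, before the
  // first call to get_device_resources() on any thread.
  raft::device_resources_manager::set_streams_per_device(4);

  auto workers = std::vector<std::thread>{};
  for (std::size_t i = 0; i < n_threads; ++i) {
    workers.emplace_back([]() {
      // Each host thread gets its own raft::device_resources object for the
      // current device; only the underlying streams, stream pools, and memory
      // resources are shared across threads.
      auto const& res = raft::device_resources_manager::get_device_resources();
      // ... enqueue work using res, e.g. on res.get_stream() ...
      (void)res;
    });
  }
  for (auto&& w : workers) {
    w.join();
  }
}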

Authors:
  - William Hicks (https://github.com/wphicks)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #2030
wphicks authored Dec 5, 2023
1 parent 8e1c62c commit 42e9f15
Showing 1 changed file with 28 additions and 62 deletions.
90 changes: 28 additions & 62 deletions cpp/include/raft/core/device_resources_manager.hpp
@@ -254,12 +254,6 @@ struct device_resources_manager {
   // Container for underlying device resources to be re-used across host
   // threads for each device
   std::vector<resource_components> per_device_components_;
-  // Container for device_resources objects shared among threads. The index
-  // of the outer vector is the thread id of the thread requesting resources
-  // modulo the total number of resources managed by this object. The inner
-  // vector contains all resources associated with that id across devices
-  // in any order.
-  std::vector<std::vector<raft::device_resources>> resources_{};
 
   // Return a lock for accessing shared data
   [[nodiscard]] auto get_lock() const { return std::unique_lock{manager_mutex_}; }
@@ -271,72 +265,44 @@ struct device_resources_manager {
   // all host threads.
   auto const& get_device_resources_(int device_id)
   {
-    // Each thread maintains an independent list of devices it has
-    // accessed. If it has not marked a device as initialized, it
-    // acquires a lock to initialize it exactly once. This means that each
-    // thread will lock once for a particular device and not proceed until
-    // some thread has actually generated the corresponding device
-    // components
-    thread_local auto initialized_devices = std::vector<int>{};
-    auto res_iter = decltype(std::end(resources_[0])){};
-    if (std::find(std::begin(initialized_devices), std::end(initialized_devices), device_id) ==
-        std::end(initialized_devices)) {
+    thread_local auto thread_resources = std::vector<std::optional<raft::device_resources>>([]() {
+      auto result = 0;
+      RAFT_CUDA_TRY(cudaGetDeviceCount(&result));
+      RAFT_EXPECTS(result != 0, "No CUDA devices found");
+      return result;
+    }());
+    if (!thread_resources[device_id]) {
       // Only lock if we have not previously accessed this device on this
       // thread
       auto lock = get_lock();
-      initialized_devices.push_back(device_id);
       // If we are building components, do not allow any further changes to
       // resource parameters.
       params_finalized_ = true;
 
-      if (resources_.empty()) {
-        // We will potentially need as many device_resources objects as there are combinations of
-        // streams and pools on a given device.
-        resources_.resize(std::max(params_.stream_count.value_or(1), std::size_t{1}) *
-                          std::max(params_.pool_count, std::size_t{1}));
-      }
-
-      auto res_idx = get_thread_id() % resources_.size();
-      // Check to see if we have constructed device_resources for the
-      // requested device at the index assigned to this thread
-      res_iter = std::find_if(std::begin(resources_[res_idx]),
-                              std::end(resources_[res_idx]),
-                              [device_id](auto&& res) { return res.get_device() == device_id; });
-
-      if (res_iter == std::end(resources_[res_idx])) {
-        // Even if we have not yet built device_resources for the current
-        // device, we may have already built the underlying components, since
-        // multiple device_resources may point to the same components.
-        auto component_iter = std::find_if(
-          std::begin(per_device_components_),
-          std::end(per_device_components_),
-          [device_id](auto&& components) { return components.get_device_id() == device_id; });
-        if (component_iter == std::end(per_device_components_)) {
-          // Build components for this device if we have not yet done so on
-          // another thread
-          per_device_components_.emplace_back(device_id, params_);
-          component_iter = std::prev(std::end(per_device_components_));
-        }
-        auto scoped_device = device_setter(device_id);
-        // Build the device_resources object for this thread out of shared
-        // components
-        resources_[res_idx].emplace_back(component_iter->get_stream(),
-                                         component_iter->get_pool(),
-                                         component_iter->get_workspace_memory_resource(),
-                                         component_iter->get_workspace_allocation_limit());
-        res_iter = std::prev(std::end(resources_[res_idx]));
-      }
-    } else {
-      auto res_idx = get_thread_id() % resources_.size();
-      // If we have previously accessed this device on this thread, we do not
-      // need to lock. We know that this thread already initialized the
-      // resources it requires for this device if no other thread had already done so, so we simply
-      // retrieve the previously-generated resources.
-      res_iter = std::find_if(std::begin(resources_[res_idx]),
-                              std::end(resources_[res_idx]),
-                              [device_id](auto&& res) { return res.get_device() == device_id; });
+      // Even if we have not yet built device_resources for the current
+      // device, we may have already built the underlying components, since
+      // multiple device_resources may point to the same components.
+      auto component_iter = std::find_if(
+        std::begin(per_device_components_),
+        std::end(per_device_components_),
+        [device_id](auto&& components) { return components.get_device_id() == device_id; });
+
+      if (component_iter == std::end(per_device_components_)) {
+        // Build components for this device if we have not yet done so on
+        // another thread
+        per_device_components_.emplace_back(device_id, params_);
+        component_iter = std::prev(std::end(per_device_components_));
+      }
+      auto scoped_device = device_setter(device_id);
+      // Build the device_resources object for this thread out of shared
+      // components
+      thread_resources[device_id].emplace(component_iter->get_stream(),
+                                          component_iter->get_pool(),
+                                          component_iter->get_workspace_memory_resource(),
+                                          component_iter->get_workspace_allocation_limit());
     }
-    return *res_iter;
+
+    return thread_resources[device_id].value();
   }
 
   // Thread-safe setter for the number of streams
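
For readers skimming the diff, here is a minimal, self-contained sketch of the pattern the new code adopts: components that are safe to share are built once under a lock, while each thread keeps a thread_local, lazily-filled cache of per-device handles built from those shared components. All names here (resources_manager, shared_component, per_thread_handle, max_devices_) are illustrative stand-ins rather than RAFT APIs, and the fixed max_devices_ stands in for the cudaGetDeviceCount() query used in the real code.

#include <cstddef>
#include <mutex>
#include <optional>
#include <vector>

struct shared_component {  // stands in for the shared streams, stream pools, and memory resources
  int device_id;
};

struct per_thread_handle {  // stands in for a thread's own raft::device_resources
  shared_component const* components;
};

class resources_manager {
 public:
  // Assumes a single manager instance (as in the RAFT singleton), since the
  // cache below is a function-local thread_local.
  per_thread_handle const& get(int device_id)
  {
    // One slot per device, private to the calling thread.
    thread_local auto cache = std::vector<std::optional<per_thread_handle>>(max_devices_);
    if (!cache[device_id]) {
      // Lock only on the first access to this device from this thread.
      auto lock = std::scoped_lock{mutex_};
      if (!shared_[device_id]) {
        // Build the shared components exactly once across all threads.
        shared_[device_id].emplace(shared_component{device_id});
      }
      // Build this thread's handle from the shared components; later calls
      // from this thread return it without taking the lock.
      cache[device_id].emplace(per_thread_handle{&*shared_[device_id]});
    }
    return *cache[device_id];
  }

 private:
  static constexpr std::size_t max_devices_ = 16;  // stand-in for cudaGetDeviceCount()
  std::mutex mutex_;
  std::vector<std::optional<shared_component>> shared_ =
    std::vector<std::optional<shared_component>>(max_devices_);
};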
