Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RAFT PR 2030: Device resource manager thread safety #2

Merged
merged 3 commits into from
Dec 5, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 28 additions & 62 deletions cpp/include/raft/core/device_resources_manager.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -254,12 +254,6 @@ struct device_resources_manager {
// Container for underlying device resources to be re-used across host
// threads for each device
std::vector<resource_components> per_device_components_;
// Container for device_resources objects shared among threads. The index
// of the outer vector is the thread id of the thread requesting resources
// modulo the total number of resources managed by this object. The inner
// vector contains all resources associated with that id across devices
// in any order.
std::vector<std::vector<raft::device_resources>> resources_{};

// Return a lock for accessing shared data
[[nodiscard]] auto get_lock() const { return std::unique_lock{manager_mutex_}; }
Expand All @@ -271,72 +265,44 @@ struct device_resources_manager {
// all host threads.
auto const& get_device_resources_(int device_id)
{
// Each thread maintains an independent list of devices it has
// accessed. If it has not marked a device as initialized, it
// acquires a lock to initialize it exactly once. This means that each
// thread will lock once for a particular device and not proceed until
// some thread has actually generated the corresponding device
// components
thread_local auto initialized_devices = std::vector<int>{};
auto res_iter = decltype(std::end(resources_[0])){};
if (std::find(std::begin(initialized_devices), std::end(initialized_devices), device_id) ==
std::end(initialized_devices)) {
thread_local auto thread_resources = std::vector<std::optional<raft::device_resources>>([]() {
auto result = 0;
RAFT_CUDA_TRY(cudaGetDeviceCount(&result));
RAFT_EXPECTS(result != 0, "No CUDA devices found");
return result;
}());
if (!thread_resources[device_id]) {
// Only lock if we have not previously accessed this device on this
// thread
auto lock = get_lock();
initialized_devices.push_back(device_id);
// If we are building components, do not allow any further changes to
// resource parameters.
params_finalized_ = true;

if (resources_.empty()) {
// We will potentially need as many device_resources objects as there are combinations of
// streams and pools on a given device.
resources_.resize(std::max(params_.stream_count.value_or(1), std::size_t{1}) *
std::max(params_.pool_count, std::size_t{1}));
}

auto res_idx = get_thread_id() % resources_.size();
// Check to see if we have constructed device_resources for the
// requested device at the index assigned to this thread
res_iter = std::find_if(std::begin(resources_[res_idx]),
std::end(resources_[res_idx]),
[device_id](auto&& res) { return res.get_device() == device_id; });
// Even if we have not yet built device_resources for the current
// device, we may have already built the underlying components, since
// multiple device_resources may point to the same components.
auto component_iter = std::find_if(
std::begin(per_device_components_),
std::end(per_device_components_),
[device_id](auto&& components) { return components.get_device_id() == device_id; });

if (res_iter == std::end(resources_[res_idx])) {
// Even if we have not yet built device_resources for the current
// device, we may have already built the underlying components, since
// multiple device_resources may point to the same components.
auto component_iter = std::find_if(
std::begin(per_device_components_),
std::end(per_device_components_),
[device_id](auto&& components) { return components.get_device_id() == device_id; });
if (component_iter == std::end(per_device_components_)) {
// Build components for this device if we have not yet done so on
// another thread
per_device_components_.emplace_back(device_id, params_);
component_iter = std::prev(std::end(per_device_components_));
}
auto scoped_device = device_setter(device_id);
// Build the device_resources object for this thread out of shared
// components
resources_[res_idx].emplace_back(component_iter->get_stream(),
component_iter->get_pool(),
component_iter->get_workspace_memory_resource(),
component_iter->get_workspace_allocation_limit());
res_iter = std::prev(std::end(resources_[res_idx]));
if (component_iter == std::end(per_device_components_)) {
// Build components for this device if we have not yet done so on
// another thread
per_device_components_.emplace_back(device_id, params_);
component_iter = std::prev(std::end(per_device_components_));
}
} else {
auto res_idx = get_thread_id() % resources_.size();
// If we have previously accessed this device on this thread, we do not
// need to lock. We know that this thread already initialized the
// resources it requires for this device if no other thread had already done so, so we simply
// retrieve the previously-generated resources.
res_iter = std::find_if(std::begin(resources_[res_idx]),
std::end(resources_[res_idx]),
[device_id](auto&& res) { return res.get_device() == device_id; });
auto scoped_device = device_setter(device_id);
// Build the device_resources object for this thread out of shared
// components
thread_resources[device_id].emplace(component_iter->get_stream(),
component_iter->get_pool(),
component_iter->get_workspace_memory_resource(),
component_iter->get_workspace_allocation_limit());
}
return *res_iter;

return thread_resources[device_id].value();
}

// Thread-safe setter for the number of streams
Expand Down
Loading