// Review notes for this patch to device_resources_manager::get_device_resources_(int device_id).
//
// What the patch does:
//  * Removes the shared `resources_` member, which was indexed by the requesting thread's id
//    modulo the number of managed resources, together with the thread_local
//    `initialized_devices` list and every std::find_if lookup over `resources_`.
//  * Adds a thread_local `thread_resources` vector whose length is the value reported by
//    cudaGetDeviceCount; RAFT_EXPECTS rejects a count of zero ("No CUDA devices found").
//  * When `thread_resources[device_id]` is empty, the function takes the manager lock via
//    get_lock(), sets `params_finalized_ = true`, finds or appends the matching entry in
//    `per_device_components_`, and emplaces a device_resources built from that entry's stream,
//    pool, workspace memory resource and workspace allocation limit.
//  * Otherwise no lock is taken; either way the function returns
//    `thread_resources[device_id].value()`.
//
// NOTE(review): `thread_resources[device_id]` is indexed without a visible range check. A
//   negative device_id, or one >= the device count, would index out of bounds -- confirm that
//   callers validate device_id, or add a RAFT_EXPECTS guard before the first access.
// NOTE(review): template arguments look stripped from this excerpt (`std::vector>([]() {`,
//   `std::vector per_device_components_`). The `.emplace(...)` / `.value()` calls suggest the
//   element type is presumably std::optional of device_resources -- confirm against the header.
// NOTE(review): line breaks of the original unified diff have been collapsed, so the hunk
//   line counts (`@@ -254,12 +254,6 @@`, `@@ -271,72 +265,44 @@`) cannot be checked from here.
diff --git a/cpp/include/raft/core/device_resources_manager.hpp b/cpp/include/raft/core/device_resources_manager.hpp index ee4b151362..c3482b0c04 100644 --- a/cpp/include/raft/core/device_resources_manager.hpp +++ b/cpp/include/raft/core/device_resources_manager.hpp @@ -254,12 +254,6 @@ struct device_resources_manager { // Container for underlying device resources to be re-used across host // threads for each device std::vector per_device_components_; - // Container for device_resources objects shared among threads. The index - // of the outer vector is the thread id of the thread requesting resources - // modulo the total number of resources managed by this object. The inner - // vector contains all resources associated with that id across devices - // in any order. - std::vector> resources_{}; // Return a lock for accessing shared data [[nodiscard]] auto get_lock() const { return std::unique_lock{manager_mutex_}; } @@ -271,72 +265,44 @@ struct device_resources_manager { // all host threads. auto const& get_device_resources_(int device_id) { - // Each thread maintains an independent list of devices it has - // accessed. If it has not marked a device as initialized, it - // acquires a lock to initialize it exactly once. 
This means that each - thread will lock once for a particular device and not proceed until - some thread has actually generated the corresponding device - components - thread_local auto initialized_devices = std::vector{}; - auto res_iter = decltype(std::end(resources_[0])){}; - if (std::find(std::begin(initialized_devices), std::end(initialized_devices), device_id) == - std::end(initialized_devices)) { + thread_local auto thread_resources = std::vector>([]() { + auto result = 0; + RAFT_CUDA_TRY(cudaGetDeviceCount(&result)); + RAFT_EXPECTS(result != 0, "No CUDA devices found"); + return result; + }()); + if (!thread_resources[device_id]) { // Only lock if we have not previously accessed this device on this // thread auto lock = get_lock(); - initialized_devices.push_back(device_id); // If we are building components, do not allow any further changes to // resource parameters. params_finalized_ = true; - if (resources_.empty()) { - // We will potentially need as many device_resources objects as there are combinations of - // streams and pools on a given device. - resources_.resize(std::max(params_.stream_count.value_or(1), std::size_t{1}) * - std::max(params_.pool_count, std::size_t{1})); - } - - auto res_idx = get_thread_id() % resources_.size(); - // Check to see if we have constructed device_resources for the - // requested device at the index assigned to this thread - res_iter = std::find_if(std::begin(resources_[res_idx]), - std::end(resources_[res_idx]), - [device_id](auto&& res) { return res.get_device() == device_id; }); + // Even if we have not yet built device_resources for the current + // device, we may have already built the underlying components, since + // multiple device_resources may point to the same components. 
+ auto component_iter = std::find_if( + std::begin(per_device_components_), + std::end(per_device_components_), + [device_id](auto&& components) { return components.get_device_id() == device_id; }); - if (res_iter == std::end(resources_[res_idx])) { - // Even if we have not yet built device_resources for the current - // device, we may have already built the underlying components, since - // multiple device_resources may point to the same components. - auto component_iter = std::find_if( - std::begin(per_device_components_), - std::end(per_device_components_), - [device_id](auto&& components) { return components.get_device_id() == device_id; }); - if (component_iter == std::end(per_device_components_)) { - // Build components for this device if we have not yet done so on - // another thread - per_device_components_.emplace_back(device_id, params_); - component_iter = std::prev(std::end(per_device_components_)); - } - auto scoped_device = device_setter(device_id); - // Build the device_resources object for this thread out of shared - // components - resources_[res_idx].emplace_back(component_iter->get_stream(), - component_iter->get_pool(), - component_iter->get_workspace_memory_resource(), - component_iter->get_workspace_allocation_limit()); - res_iter = std::prev(std::end(resources_[res_idx])); + if (component_iter == std::end(per_device_components_)) { + // Build components for this device if we have not yet done so on + // another thread + per_device_components_.emplace_back(device_id, params_); + component_iter = std::prev(std::end(per_device_components_)); } - } else { - auto res_idx = get_thread_id() % resources_.size(); - // If we have previously accessed this device on this thread, we do not - // need to lock. We know that this thread already initialized the - // resources it requires for this device if no other thread had already done so, so we simply - // retrieve the previously-generated resources. 
- res_iter = std::find_if(std::begin(resources_[res_idx]), - std::end(resources_[res_idx]), - [device_id](auto&& res) { return res.get_device() == device_id; }); + auto scoped_device = device_setter(device_id); + // Build the device_resources object for this thread out of shared + // components + thread_resources[device_id].emplace(component_iter->get_stream(), + component_iter->get_pool(), + component_iter->get_workspace_memory_resource(), + component_iter->get_workspace_allocation_limit()); } - return *res_iter; + + return thread_resources[device_id].value(); } // Thread-safe setter for the number of streams