From 536a3c2bacfdb46accdc705916f7a4fdfcd9641c Mon Sep 17 00:00:00 2001 From: viclafargue Date: Fri, 15 Nov 2024 15:26:12 +0100 Subject: [PATCH] Introducing raft::device_resources_snmg --- .../raft/core/device_resources_snmg.hpp | 148 ++++++++++++++++++ cpp/include/raft/core/nccl_clique.hpp | 148 ------------------ .../raft/core/resource/nccl_clique.hpp | 130 --------------- 3 files changed, 148 insertions(+), 278 deletions(-) create mode 100644 cpp/include/raft/core/device_resources_snmg.hpp delete mode 100644 cpp/include/raft/core/nccl_clique.hpp delete mode 100644 cpp/include/raft/core/resource/nccl_clique.hpp diff --git a/cpp/include/raft/core/device_resources_snmg.hpp b/cpp/include/raft/core/device_resources_snmg.hpp new file mode 100644 index 0000000000..dc21458315 --- /dev/null +++ b/cpp/include/raft/core/device_resources_snmg.hpp @@ -0,0 +1,148 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +#include +#include + +#include +#include + +/** + * @brief Error checking macro for NCCL runtime API functions. + * + * Invokes a NCCL runtime API function call, if the call does not return ncclSuccess, throws an + * exception detailing the NCCL error that occurred + */ +#define RAFT_NCCL_TRY(call) \ + do { \ + ncclResult_t const status = (call); \ + if (ncclSuccess != status) { \ + std::string msg{}; \ + SET_ERROR_MSG(msg, \ + "NCCL error encountered at: ", \ + "call='%s', Reason=%d:%s", \ + #call, \ + status, \ + ncclGetErrorString(status)); \ + throw raft::logic_error(msg); \ + } \ + } while (0); + +namespace raft { + +class device_resources_snmg : public resources { + public: + device_resources_snmg() : resources{}, root_rank_(0) + { + int num_ranks; + RAFT_CUDA_TRY(cudaGetDeviceCount(&num_ranks)); + device_ids_.resize(num_ranks); + std::iota(device_ids_.begin(), device_ids_.end(), 0); + nccl_comms_.resize(num_ranks); + initialize(); + } + + device_resources_snmg(const std::vector& device_ids) + : resources{}, root_rank_(0), device_ids_(device_ids), nccl_comms_(device_ids.size()) + { + initialize(); + } + + device_resources_snmg(const device_resources_snmg& clique) + : resources(clique), + root_rank_(clique.root_rank_), + device_ids_(clique.device_ids_), + nccl_comms_(clique.nccl_comms_), + device_resources_(clique.device_resources_) + { + } + + device_resources_snmg(device_resources_snmg&&) = delete; + device_resources_snmg& operator=(device_resources_snmg&&) = delete; + + inline int set_root_rank(int rank) { this->root_rank_ = rank; } + + inline int get_root_rank() const { return this->root_rank_; } + + inline int get_num_ranks() const { return this->device_ids_.size(); } + + inline int get_device_id(int rank) const { return this->device_ids_[rank]; } + + inline ncclComm_t get_nccl_comm(int rank) const { return this->nccl_comms_[rank]; } + + inline const raft::device_resources& get_device_resources(int rank) const + { + return this->device_resources_[rank]; + } + + inline const raft::device_resources& set_current_device_to_root_rank() const + { + int root_device_id = get_device_id(get_root_rank()); + RAFT_CUDA_TRY(cudaSetDevice(root_device_id)); + return get_device_resources(root_rank_); + } + + inline const raft::device_resources& set_current_device_to_rank(int rank) const + { + RAFT_CUDA_TRY(cudaSetDevice(get_device_id(rank))); + return get_device_resources(rank); + } + + void set_memory_pool(int percent_of_free_memory) const + { + for (int rank = 0; rank < get_num_ranks(); rank++) { + RAFT_CUDA_TRY(cudaSetDevice(get_device_id(rank))); + size_t limit = + rmm::percent_of_free_device_memory(percent_of_free_memory); // check limit for each device + raft::resource::set_workspace_to_pool_resource(get_device_resources(rank), limit); + } + } + + ~device_resources_snmg() + { +#pragma omp parallel for // necessary to avoid hangs + for (int rank = 0; rank < get_num_ranks(); rank++) { + RAFT_CUDA_TRY(cudaSetDevice(get_device_id(rank))); + RAFT_NCCL_TRY(ncclCommDestroy(get_nccl_comm(rank))); + } + } + + private: + void initialize() + { + RAFT_NCCL_TRY(ncclCommInitAll(nccl_comms_.data(), get_num_ranks(), device_ids_.data())); + + for (int rank = 0; rank < get_num_ranks(); rank++) { + RAFT_CUDA_TRY(cudaSetDevice(get_device_id(rank))); + device_resources_.emplace_back(); + + // ideally add the ncclComm_t to the device_resources object with + // raft::comms::build_comms_nccl_only + } + } + + int root_rank_; + std::vector device_ids_; + std::vector nccl_comms_; + std::vector device_resources_; + +}; // class device_resources_snmg + +} // namespace raft diff --git a/cpp/include/raft/core/nccl_clique.hpp b/cpp/include/raft/core/nccl_clique.hpp deleted file mode 100644 index 9aba4175b9..0000000000 --- a/cpp/include/raft/core/nccl_clique.hpp +++ /dev/null @@ -1,148 +0,0 @@ -/* - * Copyright (c) 2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include - -#include -#include - -#include - -/** - * @brief Error checking macro for NCCL runtime API functions. - * - * Invokes a NCCL runtime API function call, if the call does not return ncclSuccess, throws an - * exception detailing the NCCL error that occurred - */ -#define RAFT_NCCL_TRY(call) \ - do { \ - ncclResult_t const status = (call); \ - if (ncclSuccess != status) { \ - std::string msg{}; \ - SET_ERROR_MSG(msg, \ - "NCCL error encountered at: ", \ - "call='%s', Reason=%d:%s", \ - #call, \ - status, \ - ncclGetErrorString(status)); \ - throw raft::logic_error(msg); \ - } \ - } while (0); - -namespace raft::core { - -struct nccl_clique { - using pool_mr = rmm::mr::pool_memory_resource; - - /** - * Instantiates a NCCL clique with all available GPUs - * - * @param[in] percent_of_free_memory percentage of device memory to pre-allocate as memory pool - * - */ - nccl_clique(int percent_of_free_memory = 80) - : root_rank_(0), - percent_of_free_memory_(percent_of_free_memory), - per_device_pools_(0), - device_resources_(0) - { - cudaGetDeviceCount(&num_ranks_); - device_ids_.resize(num_ranks_); - std::iota(device_ids_.begin(), device_ids_.end(), 0); - nccl_comms_.resize(num_ranks_); - nccl_clique_init(); - } - - /** - * Instantiates a NCCL clique - * - * Usage example: - * @code{.cpp} - * int n_devices; - * cudaGetDeviceCount(&n_devices); - * std::vector device_ids(n_devices); - * std::iota(device_ids.begin(), device_ids.end(), 0); - * cuvs::neighbors::mg::nccl_clique& clique(device_ids); // first device is the root rank - * @endcode - * - * @param[in] device_ids list of device IDs to be used to initiate the clique - * @param[in] percent_of_free_memory percentage of device memory to pre-allocate as memory pool - * - */ - nccl_clique(const std::vector& device_ids, int percent_of_free_memory = 80) - : root_rank_(0), - num_ranks_(device_ids.size()), - percent_of_free_memory_(percent_of_free_memory), - device_ids_(device_ids), - nccl_comms_(device_ids.size()), - per_device_pools_(0), - device_resources_(0) - { - nccl_clique_init(); - } - - void nccl_clique_init() - { - RAFT_NCCL_TRY(ncclCommInitAll(nccl_comms_.data(), num_ranks_, device_ids_.data())); - - for (int rank = 0; rank < num_ranks_; rank++) { - RAFT_CUDA_TRY(cudaSetDevice(device_ids_[rank])); - - // create a pool memory resource for each device - auto old_mr = rmm::mr::get_current_device_resource(); - per_device_pools_.push_back(std::make_unique( - old_mr, rmm::percent_of_free_device_memory(percent_of_free_memory_))); - rmm::cuda_device_id id(device_ids_[rank]); - rmm::mr::set_per_device_resource(id, per_device_pools_.back().get()); - - // create a device resource handle for each device - device_resources_.emplace_back(); - } - - for (int rank = 0; rank < num_ranks_; rank++) { - RAFT_CUDA_TRY(cudaSetDevice(device_ids_[rank])); - raft::resource::sync_stream(device_resources_[rank]); - } - } - - const raft::device_resources& set_current_device_to_root_rank() const - { - int root_device_id = device_ids_[root_rank_]; - RAFT_CUDA_TRY(cudaSetDevice(root_device_id)); - return device_resources_[root_rank_]; - } - - ~nccl_clique() - { -#pragma omp parallel for // necessary to avoid hangs - for (int rank = 0; rank < num_ranks_; rank++) { - cudaSetDevice(device_ids_[rank]); - ncclCommDestroy(nccl_comms_[rank]); - rmm::cuda_device_id id(device_ids_[rank]); - rmm::mr::set_per_device_resource(id, nullptr); - } - } - - int root_rank_; - int num_ranks_; - int percent_of_free_memory_; - std::vector device_ids_; - std::vector nccl_comms_; - std::vector> per_device_pools_; - std::vector device_resources_; -}; - -} // namespace raft::core diff --git a/cpp/include/raft/core/resource/nccl_clique.hpp b/cpp/include/raft/core/resource/nccl_clique.hpp deleted file mode 100644 index 6cbed4abb5..0000000000 --- a/cpp/include/raft/core/resource/nccl_clique.hpp +++ /dev/null @@ -1,130 +0,0 @@ -/* - * Copyright (c) 2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include -#include -#include - -#include - -namespace raft::resource { - -class nccl_clique_resource : public resource { - public: - nccl_clique_resource(std::optional>& device_ids, int percent_of_free_memory) - { - if (device_ids.has_value()) { - clique_ = std::make_unique(*device_ids, percent_of_free_memory); - } else { - clique_ = std::make_unique(percent_of_free_memory); - } - } - - ~nccl_clique_resource() override {} - void* get_resource() override { return clique_.get(); } - - private: - std::unique_ptr clique_; -}; - -/** Factory that knows how to construct a specific raft::resource to populate the res_t. */ -class nccl_clique_resource_factory : public resource_factory { - public: - nccl_clique_resource_factory(const std::optional>& device_ids, - int percent_of_free_memory) - : device_ids(device_ids), percent_of_free_memory(percent_of_free_memory) - { - } - - resource_type get_resource_type() override { return resource_type::NCCL_CLIQUE; } - resource* make_resource() override - { - return new nccl_clique_resource(this->device_ids, this->percent_of_free_memory); - } - - std::optional> device_ids; - int percent_of_free_memory; -}; - -inline const raft::core::nccl_clique& build_nccl_clique( - resources const& res, - const std::optional>& device_ids, - int percent_of_free_memory) -{ - if (!res.has_resource_factory(resource_type::NCCL_CLIQUE)) { - res.add_resource_factory( - std::make_shared(device_ids, percent_of_free_memory)); - } else { - RAFT_LOG_WARN("Attempted re-initialize the NCCL clique on a RAFT resource."); - } - return *res.get_resource(resource_type::NCCL_CLIQUE); -} - -/** - * @defgroup nccl_clique_resource resource functions - * @{ - */ - -/** - * Initializes a NCCL clique and sets it into a raft resource instance - * - * @param[in] res the raft resources object - * @param[in] percent_of_free_memory percentage of device memory to pre-allocate as a memory pool on - * each GPU - * @return NCCL clique - */ -inline const raft::core::nccl_clique& initialize_nccl_clique(resources const& res, - int percent_of_free_memory = 80) -{ - return build_nccl_clique(res, std::nullopt, percent_of_free_memory); -}; - -/** - * Initializes a NCCL clique and sets it into a raft resource instance - * - * @param[in] res the raft resources object - * @param[in] device_ids selection of GPUs initialize the clique on - * @param[in] percent_of_free_memory percentage of device memory to pre-allocate as a memory pool on - * each GPU - * @return NCCL clique - */ -inline const raft::core::nccl_clique& initialize_nccl_clique( - resources const& res, std::optional> device_ids, int percent_of_free_memory = 80) -{ - return build_nccl_clique(res, device_ids, percent_of_free_memory); -}; - -/** - * Retrieves a NCCL clique from raft resource instance, initializes one with default parameters if - * absent - * - * @param[in] res the raft resources object - * @return NCCL clique - */ -inline const raft::core::nccl_clique& get_nccl_clique(resources const& res) -{ - if (!res.has_resource_factory(resource_type::NCCL_CLIQUE)) { - raft::resource::initialize_nccl_clique(res); - } - return *res.get_resource(resource_type::NCCL_CLIQUE); -}; - -/** - * @} - */ - -} // namespace raft::resource