diff --git a/cpp/include/raft/core/device_resources_snmg.hpp b/cpp/include/raft/core/device_resources_snmg.hpp index fc454e994d..da5c39a2f7 100644 --- a/cpp/include/raft/core/device_resources_snmg.hpp +++ b/cpp/include/raft/core/device_resources_snmg.hpp @@ -60,7 +60,7 @@ class device_resources_snmg : public device_resources { /** * @brief Construct a SNMG resources instance with all available GPUs */ - device_resources_snmg() : device_resources(), root_rank_(0), + device_resources_snmg() : device_resources(), root_rank_(0) { cudaGetDevice(&main_gpu_id_); @@ -164,12 +164,13 @@ class device_resources_snmg : public device_resources { rmm::percent_of_free_device_memory(percent_of_free_memory); // check limit for each device raft::resource::set_workspace_to_pool_resource(get_device_resources(rank), limit); } + cudaSetDevice(this->main_gpu_id_); } bool has_resource_factory(resource::resource_type resource_type) const override { cudaSetDevice(this->main_gpu_id_); - raft::resource::has_resource_factory(resource_type); + return raft::resources::has_resource_factory(resource_type); } /** Destroys all held-up resources */ @@ -180,6 +181,7 @@ class device_resources_snmg : public device_resources { RAFT_CUDA_TRY(cudaSetDevice(get_device_id(rank))); RAFT_NCCL_TRY(ncclCommDestroy(get_nccl_comm(rank))); } + cudaSetDevice(this->main_gpu_id_); } private: @@ -197,6 +199,7 @@ class device_resources_snmg : public device_resources { // ideally add the ncclComm_t to the device_resources object with // raft::comms::build_comms_nccl_only } + cudaSetDevice(this->main_gpu_id_); } int root_rank_; diff --git a/cpp/include/raft/core/resources.hpp b/cpp/include/raft/core/resources.hpp index b0827d8e11..44525edb23 100644 --- a/cpp/include/raft/core/resources.hpp +++ b/cpp/include/raft/core/resources.hpp @@ -72,6 +72,7 @@ class resources { resources(const resources& res) : factories_(res.factories_), resources_(res.resources_) {} resources(resources&&) = delete; resources& operator=(resources&&) = delete; + virtual ~resources() {} /** * @brief Returns true if a resource_factory has been registered for the @@ -79,7 +80,7 @@ class resources { * @param resource_type resource type to check * @return true if resource_factory is registered for the given resource_type */ - bool has_resource_factory(resource::resource_type resource_type) const + virtual bool has_resource_factory(resource::resource_type resource_type) const { std::lock_guard _(mutex_); return factories_.at(resource_type).first != resource::resource_type::LAST_KEY;