Skip to content

Commit

Permalink
improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
viclafargue committed Nov 26, 2024
1 parent f73ae7c commit 33a4cd1
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 3 deletions.
7 changes: 5 additions & 2 deletions cpp/include/raft/core/device_resources_snmg.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ class device_resources_snmg : public device_resources {
/**
* @brief Construct a SNMG resources instance with all available GPUs
*/
device_resources_snmg() : device_resources(), root_rank_(0),
device_resources_snmg() : device_resources(), root_rank_(0)
{
cudaGetDevice(&main_gpu_id_);

Expand Down Expand Up @@ -164,12 +164,13 @@ class device_resources_snmg : public device_resources {
rmm::percent_of_free_device_memory(percent_of_free_memory); // check limit for each device
raft::resource::set_workspace_to_pool_resource(get_device_resources(rank), limit);
}
cudaSetDevice(this->main_gpu_id_);
}

bool has_resource_factory(resource::resource_type resource_type) const override
{
cudaSetDevice(this->main_gpu_id_);
raft::resource::has_resource_factory(resource_type);
return raft::resources::has_resource_factory(resource_type);
}

/** Destroys all held resources */
Expand All @@ -180,6 +181,7 @@ class device_resources_snmg : public device_resources {
RAFT_CUDA_TRY(cudaSetDevice(get_device_id(rank)));
RAFT_NCCL_TRY(ncclCommDestroy(get_nccl_comm(rank)));
}
cudaSetDevice(this->main_gpu_id_);
}

private:
Expand All @@ -197,6 +199,7 @@ class device_resources_snmg : public device_resources {
// ideally add the ncclComm_t to the device_resources object with
// raft::comms::build_comms_nccl_only
}
cudaSetDevice(this->main_gpu_id_);
}

int root_rank_;
Expand Down
3 changes: 2 additions & 1 deletion cpp/include/raft/core/resources.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,14 +72,15 @@ class resources {
resources(const resources& res) : factories_(res.factories_), resources_(res.resources_) {}
resources(resources&&) = delete;
resources& operator=(resources&&) = delete;
virtual ~resources() {}

/**
* @brief Returns true if a resource_factory has been registered for the
* given resource_type, false otherwise.
* @param resource_type resource type to check
* @return true if resource_factory is registered for the given resource_type
*/
bool has_resource_factory(resource::resource_type resource_type) const
virtual bool has_resource_factory(resource::resource_type resource_type) const
{
std::lock_guard<std::mutex> _(mutex_);
return factories_.at(resource_type).first != resource::resource_type::LAST_KEY;
Expand Down

0 comments on commit 33a4cd1

Please sign in to comment.