Skip to content

Commit

Permalink
Improve implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
kingcrimsontianyu committed Sep 10, 2024
1 parent d167d6e commit 4985426
Showing 1 changed file with 18 additions and 20 deletions.
38 changes: 18 additions & 20 deletions cpp/src/utilities/stream_pool.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,40 +125,39 @@ rmm::cuda_device_id get_current_cuda_device()
return rmm::cuda_device_id{device_id};
}

class primary_context_checker {
class cuda_context_checker {
public:
void initialize()
static bool is_primary_context_active()
{
int device_id{};
cudaGetDevice(&device_id); // TODO: Handle runtime API error code
cuDeviceGet(&device_handle_, device_id); // TODO: Handle driver API error code
}
// TODO: Handle driver API error code
CUcontext current_ctx{};
CUdevice device_handle{};

cuCtxGetCurrent(&current_ctx);
if (current_ctx == nullptr) {
device_handle = 0;
} else {
cuCtxGetDevice(&device_handle);
}

bool is_primary_context_active()
{
// Whether the current context, if it exists, is a primary or user-created context,
// here we use the primary context to determine if cudaDeviceReset() has been called.
int active_state{};
// TODO: Handle driver API error code
cuDevicePrimaryCtxGetState(
device_handle_, nullptr /* do not query context flags */, &active_state);
device_handle, nullptr /* do not query context flags */, &active_state);

return static_cast<bool>(active_state);
}

private:
CUdevice device_handle_{};
};

/**
* @brief RAII struct to wrap a cuda event and ensure its proper destruction.
*/
struct cuda_event {
cuda_event()
{
CUDF_CUDA_TRY(cudaEventCreateWithFlags(&e_, cudaEventDisableTiming));
ctx_checker_.initialize();
}
cuda_event() { CUDF_CUDA_TRY(cudaEventCreateWithFlags(&e_, cudaEventDisableTiming)); }
virtual ~cuda_event()
{
if (ctx_checker_.is_primary_context_active()) {
if (cuda_context_checker::is_primary_context_active()) {
CUDF_ASSERT_CUDA_SUCCESS(cudaEventDestroy(e_));
}
}
Expand All @@ -167,7 +166,6 @@ struct cuda_event {

private:
cudaEvent_t e_;
primary_context_checker ctx_checker_;
};

/**
Expand Down

0 comments on commit 4985426

Please sign in to comment.