diff --git a/cpp/src/utilities/stream_pool.cpp b/cpp/src/utilities/stream_pool.cpp index 97a61b66594..aaefed78d57 100644 --- a/cpp/src/utilities/stream_pool.cpp +++ b/cpp/src/utilities/stream_pool.cpp @@ -125,40 +125,39 @@ rmm::cuda_device_id get_current_cuda_device() return rmm::cuda_device_id{device_id}; } -class primary_context_checker { +class cuda_context_checker { public: - void initialize() + static bool is_primary_context_active() { - int device_id{}; - cudaGetDevice(&device_id); // TODO: Handle runtime API error code - cuDeviceGet(&device_handle_, device_id); // TODO: Handle driver API error code - } + // TODO: Handle driver API error code + CUcontext current_ctx{}; + CUdevice device_handle{}; + + cuCtxGetCurrent(¤t_ctx); + if (current_ctx == nullptr) { + device_handle = 0; + } else { + cuCtxGetDevice(&device_handle); + } - bool is_primary_context_active() - { + // Whether the current context, if it exists, is a primary or user-created context, + // here we use the primary context to determine if cudaDeviceReset() has been called. int active_state{}; - // TODO: Handle driver API error code cuDevicePrimaryCtxGetState( - device_handle_, nullptr /* do not query context flags */, &active_state); + device_handle, nullptr /* do not query context flags */, &active_state); + return static_cast(active_state); } - - private: - CUdevice device_handle_{}; }; /** * @brief RAII struct to wrap a cuda event and ensure its proper destruction. */ struct cuda_event { - cuda_event() - { - CUDF_CUDA_TRY(cudaEventCreateWithFlags(&e_, cudaEventDisableTiming)); - ctx_checker_.initialize(); - } + cuda_event() { CUDF_CUDA_TRY(cudaEventCreateWithFlags(&e_, cudaEventDisableTiming)); } virtual ~cuda_event() { - if (ctx_checker_.is_primary_context_active()) { + if (cuda_context_checker::is_primary_context_active()) { CUDF_ASSERT_CUDA_SUCCESS(cudaEventDestroy(e_)); } } @@ -167,7 +166,6 @@ struct cuda_event { private: cudaEvent_t e_; - primary_context_checker ctx_checker_; }; /**