diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 1040fcb7b91..27cb4d96814 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -796,7 +796,7 @@ add_dependencies(cudf jitify_preprocess_run) # Specify the target module library dependencies target_link_libraries( cudf - PUBLIC CCCL::CCCL rmm::rmm $ + PUBLIC CCCL::CCCL rmm::rmm $ CUDA::cuda_driver PRIVATE $ cuco::cuco ZLIB::ZLIB nvcomp::nvcomp kvikio::kvikio $ nanoarrow ) diff --git a/cpp/src/utilities/stream_pool.cpp b/cpp/src/utilities/stream_pool.cpp index 9d3a7ce5a4e..97a61b66594 100644 --- a/cpp/src/utilities/stream_pool.cpp +++ b/cpp/src/utilities/stream_pool.cpp @@ -125,17 +125,49 @@ rmm::cuda_device_id get_current_cuda_device() return rmm::cuda_device_id{device_id}; } +class primary_context_checker { + public: + void initialize() + { + int device_id{}; + cudaGetDevice(&device_id); // TODO: Handle runtime API error code + cuDeviceGet(&device_handle_, device_id); // TODO: Handle driver API error code + } + + bool is_primary_context_active() + { + int active_state{}; + // TODO: Handle driver API error code + cuDevicePrimaryCtxGetState( + device_handle_, nullptr /* do not query context flags */, &active_state); + return static_cast(active_state); + } + + private: + CUdevice device_handle_{}; +}; + /** * @brief RAII struct to wrap a cuda event and ensure its proper destruction. */ struct cuda_event { - cuda_event() { CUDF_CUDA_TRY(cudaEventCreateWithFlags(&e_, cudaEventDisableTiming)); } - virtual ~cuda_event() { CUDF_ASSERT_CUDA_SUCCESS(cudaEventDestroy(e_)); } + cuda_event() + { + CUDF_CUDA_TRY(cudaEventCreateWithFlags(&e_, cudaEventDisableTiming)); + ctx_checker_.initialize(); + } + virtual ~cuda_event() + { + if (ctx_checker_.is_primary_context_active()) { + CUDF_ASSERT_CUDA_SUCCESS(cudaEventDestroy(e_)); + } + } operator cudaEvent_t() { return e_; } private: cudaEvent_t e_; + primary_context_checker ctx_checker_; }; /**