From 00b054986615496b27ec46d68052091d7123eae8 Mon Sep 17 00:00:00 2001 From: Tianyu Liu Date: Wed, 11 Sep 2024 13:48:54 -0400 Subject: [PATCH] Improve implementation --- cpp/src/utilities/stream_pool.cpp | 21 ++++++--------------- 1 file changed, 6 insertions(+), 15 deletions(-) diff --git a/cpp/src/utilities/stream_pool.cpp b/cpp/src/utilities/stream_pool.cpp index b53d6ab96b7..b9f3890a42b 100644 --- a/cpp/src/utilities/stream_pool.cpp +++ b/cpp/src/utilities/stream_pool.cpp @@ -14,6 +14,8 @@ * limitations under the License. */ +#include "driver_types.h" + #include #include #include @@ -125,18 +127,6 @@ rmm::cuda_device_id get_current_cuda_device() return rmm::cuda_device_id{device_id}; } -/** - * @brief RAII struct to wrap a cuda event and ensure its proper destruction. - */ -struct cuda_event { - cuda_event() { CUDF_CUDA_TRY(cudaEventCreateWithFlags(&e_, cudaEventDisableTiming)); } - - operator cudaEvent_t() { return e_; } - - private: - cudaEvent_t e_; -}; - /** * @brief Returns a cudaEvent_t for the current thread. * @@ -146,12 +136,13 @@ struct cuda_event { */ cudaEvent_t event_for_thread() { - thread_local std::vector> thread_events(get_num_cuda_devices()); + thread_local std::vector thread_events(get_num_cuda_devices()); auto const device_id = get_current_cuda_device(); if (not thread_events[device_id.value()]) { - thread_events[device_id.value()] = std::make_unique(); + CUDF_CUDA_TRY( + cudaEventCreateWithFlags(&thread_events[device_id.value()], cudaEventDisableTiming)); } - return *thread_events[device_id.value()]; + return thread_events[device_id.value()]; } /**