Skip to content

Commit

Permalink
Improve implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
kingcrimsontianyu committed Sep 11, 2024
1 parent b6737b8 commit 00b0549
Showing 1 changed file with 6 additions and 15 deletions.
21 changes: 6 additions & 15 deletions cpp/src/utilities/stream_pool.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
* limitations under the License.
*/

#include "driver_types.h"

#include <cudf/detail/utilities/logger.hpp>
#include <cudf/detail/utilities/stream_pool.hpp>
#include <cudf/utilities/default_stream.hpp>
Expand Down Expand Up @@ -125,18 +127,6 @@ rmm::cuda_device_id get_current_cuda_device()
return rmm::cuda_device_id{device_id};
}

/**
* @brief RAII struct to wrap a cuda event and ensure its proper destruction.
*/
struct cuda_event {
cuda_event() { CUDF_CUDA_TRY(cudaEventCreateWithFlags(&e_, cudaEventDisableTiming)); }

operator cudaEvent_t() { return e_; }

private:
cudaEvent_t e_;
};

/**
* @brief Returns a cudaEvent_t for the current thread.
*
Expand All @@ -146,12 +136,13 @@ struct cuda_event {
*/
cudaEvent_t event_for_thread()
{
thread_local std::vector<std::unique_ptr<cuda_event>> thread_events(get_num_cuda_devices());
thread_local std::vector<cudaEvent_t> thread_events(get_num_cuda_devices());
auto const device_id = get_current_cuda_device();
if (not thread_events[device_id.value()]) {
thread_events[device_id.value()] = std::make_unique<cuda_event>();
CUDF_CUDA_TRY(
cudaEventCreateWithFlags(&thread_events[device_id.value()], cudaEventDisableTiming));
}
return *thread_events[device_id.value()];
return thread_events[device_id.value()];
}

/**
Expand Down

0 comments on commit 00b0549

Please sign in to comment.