Apply the same pattern of waiting by the bench stream rather than syncing the raft stream to CPU in all raft algos
achirkin committed Nov 22, 2023
1 parent 6273248 commit 19cb314
Showing 5 changed files with 82 additions and 73 deletions.
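In short: instead of blocking the host with a stream sync after every build/search call, each wrapper now records a CUDA event on its internal RAFT stream and makes the caller's benchmark stream wait on that event, so the dependency stays on the GPU. A minimal self-contained sketch of the pattern (names here are illustrative, not taken from the diff):

```cpp
#include <cuda_runtime.h>

// Order `bench_stream` after all work currently enqueued on `raft_stream`
// without blocking the CPU. `ev` is an event created once with
// cudaEventCreateWithFlags(&ev, cudaEventDisableTiming) and reused per call.
inline void wait_by_bench_stream(cudaEvent_t ev,
                                 cudaStream_t raft_stream,
                                 cudaStream_t bench_stream)
{
  cudaEventRecord(ev, raft_stream);       // mark the tail of the RAFT stream
  cudaStreamWaitEvent(bench_stream, ev);  // device-side wait; host continues
}
```

The previous `resource::sync_stream(handle_)` / event-synchronize calls stalled the host thread on every iteration; the `stream_wait` helper introduced below keeps the host free to enqueue the next iteration.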
cpp/bench/ann/CMakeLists.txt (5 changes: 1 addition & 4 deletions)
@@ -141,12 +141,9 @@ function(ConfigureAnnBench)
add_dependencies(${BENCH_NAME} ANN_BENCH)
else()
add_executable(${BENCH_NAME} ${ConfigureAnnBench_PATH})
# NVTX for the benchmark wrapper is independent from raft::nvtx and tracks the benchmark
# iterations. We have an extra header check here to keep CPU-only builds working
CHECK_INCLUDE_FILE_CXX(nvtx3/nvToolsExt.h NVTX3_HEADERS_FOUND)
target_compile_definitions(
${BENCH_NAME} PRIVATE ANN_BENCH_BUILD_MAIN
$<$<BOOL:${NVTX3_HEADERS_FOUND}>:ANN_BENCH_NVTX3_HEADERS_FOUND>
$<$<BOOL:${GPU_BUILD}>:ANN_BENCH_NVTX3_HEADERS_FOUND>
)
target_link_libraries(${BENCH_NAME} PRIVATE benchmark::benchmark)
endif()
cpp/bench/ann/src/raft/raft_ann_bench_utils.h (39 changes: 39 additions & 0 deletions)
@@ -41,4 +41,43 @@ inline raft::distance::DistanceType parse_metric_type(raft::bench::ann::Metric m
throw std::runtime_error("raft supports only metric type of inner product and L2");
}
}

class configured_raft_resources {
public:
configured_raft_resources()
: mr_{rmm::mr::get_current_device_resource(), 1024 * 1024 * 1024ull},
res_{cudaStreamPerThread},
sync_{nullptr}
{
rmm::mr::set_current_device_resource(&mr_);
RAFT_CUDA_TRY(cudaEventCreate(&sync_, cudaEventDisableTiming));
}

~configured_raft_resources() noexcept
{
RAFT_CUDA_TRY_NO_THROW(cudaEventDestroy(sync_));
if (rmm::mr::get_current_device_resource()->is_equal(mr_)) {
rmm::mr::set_current_device_resource(mr_.get_upstream());
}
}

operator raft::resources&() noexcept { return res_; }
operator const raft::resources&() const noexcept { return res_; }

/** Make the given stream wait on all work submitted to the resource. */
void stream_wait(cudaStream_t stream) const
{
RAFT_CUDA_TRY(cudaEventRecord(sync_, resource::get_cuda_stream(res_)));
RAFT_CUDA_TRY(cudaStreamWaitEvent(stream, sync_));
}

/** Get the internal sync event (which is otherwise used only in `stream_wait`). */
cudaEvent_t get_sync_event() const { return sync_; }

private:
rmm::mr::pool_memory_resource<rmm::mr::device_memory_resource> mr_;
raft::device_resources res_;
cudaEvent_t sync_;
};

} // namespace raft::bench::ann
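A minimal sketch of how a wrapper is meant to use this helper; `MyAlgo` and the commented-out RAFT call are placeholders, not part of the commit:

```cpp
#include <cuda_runtime.h>
// #include "raft_ann_bench_utils.h"  // provides configured_raft_resources

// Hypothetical wrapper illustrating the intended call pattern.
class MyAlgo {
 public:
  void search(cudaStream_t bench_stream) const
  {
    // ... enqueue RAFT work on resource::get_cuda_stream(handle_) ...
    handle_.stream_wait(bench_stream);  // RAFT stream -> bench stream
  }

 private:
  // Declared first so it is destroyed last, after every member that
  // allocated from its memory pool.
  configured_raft_resources handle_{};
};
```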
cpp/bench/ann/src/raft/raft_cagra_wrapper.h (44 changes: 21 additions & 23 deletions)
@@ -76,23 +76,17 @@ class RaftCagra : public ANN<T> {
: ANN<T>(metric, dim),
index_params_(param),
dimension_(dim),
mr_(rmm::mr::get_current_device_resource(), 1024 * 1024 * 1024ull),
handle_(cudaStreamPerThread),
need_dataset_update_(true),
dataset_(make_device_matrix<T, int64_t>(handle_, 0, 0)),
graph_(make_device_matrix<IdxT, int64_t>(handle_, 0, 0)),
input_dataset_v_(nullptr, 0, 0),
graph_mem_(AllocatorType::Device),
dataset_mem_(AllocatorType::Device)
{
rmm::mr::set_current_device_resource(&mr_);
index_params_.cagra_params.metric = parse_metric_type(metric);
index_params_.ivf_pq_build_params->metric = parse_metric_type(metric);
RAFT_CUDA_TRY(cudaGetDevice(&device_));
}

~RaftCagra() noexcept { rmm::mr::set_current_device_resource(mr_.get_upstream()); }

void build(const T* dataset, size_t nrow, cudaStream_t stream) final;

void set_search_param(const AnnSearchParam& param) override;
@@ -121,34 +115,33 @@ class RaftCagra : public ANN<T> {
void save_to_hnswlib(const std::string& file) const;

private:
inline rmm::mr::device_memory_resource* get_mr(AllocatorType mem_type)
{
switch (mem_type) {
case (AllocatorType::HostPinned): return &mr_pinned_;
case (AllocatorType::HostHugePage): return &mr_huge_page_;
default: return rmm::mr::get_current_device_resource();
}
}
// `mr_` must go first to make sure it dies last
rmm::mr::pool_memory_resource<rmm::mr::device_memory_resource> mr_;
// handle_ must go first to make sure it dies last, after all memory allocated in its pool has been freed
configured_raft_resources handle_{};
raft::mr::cuda_pinned_resource mr_pinned_;
raft::mr::cuda_huge_page_resource mr_huge_page_;
raft::device_resources handle_;
AllocatorType graph_mem_;
AllocatorType dataset_mem_;
BuildParam index_params_;
bool need_dataset_update_;
raft::neighbors::cagra::search_params search_params_;
std::optional<raft::neighbors::cagra::index<T, IdxT>> index_;
int device_;
int dimension_;
raft::device_matrix<IdxT, int64_t, row_major> graph_;
raft::device_matrix<T, int64_t, row_major> dataset_;
raft::device_matrix_view<const T, int64_t, row_major> input_dataset_v_;

inline rmm::mr::device_memory_resource* get_mr(AllocatorType mem_type)
{
switch (mem_type) {
case (AllocatorType::HostPinned): return &mr_pinned_;
case (AllocatorType::HostHugePage): return &mr_huge_page_;
default: return rmm::mr::get_current_device_resource();
}
}
};

template <typename T, typename IdxT>
void RaftCagra<T, IdxT>::build(const T* dataset, size_t nrow, cudaStream_t)
void RaftCagra<T, IdxT>::build(const T* dataset, size_t nrow, cudaStream_t stream)
{
auto dataset_view =
raft::make_host_matrix_view<const T, int64_t>(dataset, IdxT(nrow), dimension_);
@@ -162,7 +155,8 @@ void RaftCagra<T, IdxT>::build(const T* dataset, size_t nrow, cudaStream_t)
index_params_.ivf_pq_refine_rate,
index_params_.ivf_pq_build_params,
index_params_.ivf_pq_search_params));
return;

handle_.stream_wait(stream); // RAFT stream -> bench stream
}

inline std::string allocator_to_string(AllocatorType mem_type)
@@ -257,8 +251,12 @@ void RaftCagra<T, IdxT>::load(const std::string& file)
}

template <typename T, typename IdxT>
void RaftCagra<T, IdxT>::search(
const T* queries, int batch_size, int k, size_t* neighbors, float* distances, cudaStream_t) const
void RaftCagra<T, IdxT>::search(const T* queries,
int batch_size,
int k,
size_t* neighbors,
float* distances,
cudaStream_t stream) const
{
IdxT* neighbors_IdxT;
rmm::device_uvector<IdxT> neighbors_storage(0, resource::get_cuda_stream(handle_));
@@ -285,6 +283,6 @@ void RaftCagra<T, IdxT>::search(
raft::resource::get_cuda_stream(handle_));
}

handle_.sync_stream();
handle_.stream_wait(stream); // RAFT stream -> bench stream
}
} // namespace raft::bench::ann
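The `handle_` ordering comment above leans on a C++ guarantee: non-static data members are constructed in declaration order and destroyed in reverse order. A tiny standalone illustration with stand-in types (no RAFT involved):

```cpp
#include <cstdio>

struct Pool  { ~Pool()  { std::puts("pool destroyed last");   } };
struct Index { ~Index() { std::puts("index destroyed first"); } };

struct Wrapper {
  Pool pool_;    // declared first: constructed first, destroyed last
  Index index_;  // may hold memory from pool_, so it must die first
};

int main()
{
  Wrapper w;
  // on scope exit: "index destroyed first", then "pool destroyed last"
}
```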
cpp/bench/ann/src/raft/raft_ivf_flat_wrapper.h (28 changes: 12 additions & 16 deletions)
@@ -52,19 +52,13 @@ class RaftIvfFlatGpu : public ANN<T> {
using BuildParam = raft::neighbors::ivf_flat::index_params;

RaftIvfFlatGpu(Metric metric, int dim, const BuildParam& param)
: ANN<T>(metric, dim),
index_params_(param),
dimension_(dim),
mr_(rmm::mr::get_current_device_resource(), 1024 * 1024 * 1024ull)
: ANN<T>(metric, dim), index_params_(param), dimension_(dim)
{
rmm::mr::set_current_device_resource(&mr_);
index_params_.metric = parse_metric_type(metric);
index_params_.conservative_memory_allocation = true;
RAFT_CUDA_TRY(cudaGetDevice(&device_));
}

~RaftIvfFlatGpu() noexcept { rmm::mr::set_current_device_resource(mr_.get_upstream()); }

void build(const T* dataset, size_t nrow, cudaStream_t stream) final;

void set_search_param(const AnnSearchParam& param) override;
@@ -90,9 +84,8 @@ class RaftIvfFlatGpu : public ANN<T> {
void load(const std::string&) override;

private:
// `mr_` must go first to make sure it dies last
rmm::mr::pool_memory_resource<rmm::mr::device_memory_resource> mr_;
raft::device_resources handle_;
// handle_ must go first to make sure it dies last, after all memory allocated in its pool has been freed
configured_raft_resources handle_{};
BuildParam index_params_;
raft::neighbors::ivf_flat::search_params search_params_;
std::optional<raft::neighbors::ivf_flat::index<T, IdxT>> index_;
@@ -101,11 +94,11 @@
};

template <typename T, typename IdxT>
void RaftIvfFlatGpu<T, IdxT>::build(const T* dataset, size_t nrow, cudaStream_t)
void RaftIvfFlatGpu<T, IdxT>::build(const T* dataset, size_t nrow, cudaStream_t stream)
{
index_.emplace(
raft::neighbors::ivf_flat::build(handle_, index_params_, dataset, IdxT(nrow), dimension_));
return;
handle_.stream_wait(stream); // RAFT stream -> bench stream
}

template <typename T, typename IdxT>
@@ -131,13 +124,16 @@ void RaftIvfFlatGpu<T, IdxT>::load(const std::string& file)
}

template <typename T, typename IdxT>
void RaftIvfFlatGpu<T, IdxT>::search(
const T* queries, int batch_size, int k, size_t* neighbors, float* distances, cudaStream_t) const
void RaftIvfFlatGpu<T, IdxT>::search(const T* queries,
int batch_size,
int k,
size_t* neighbors,
float* distances,
cudaStream_t stream) const
{
static_assert(sizeof(size_t) == sizeof(IdxT), "IdxT is incompatible with size_t");
raft::neighbors::ivf_flat::search(
handle_, search_params_, *index_, queries, batch_size, k, (IdxT*)neighbors, distances);
resource::sync_stream(handle_);
return;
handle_.stream_wait(stream); // RAFT stream -> bench stream
}
} // namespace raft::bench::ann
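The payoff shows up on the benchmark side: when `search` ends with an on-device `stream_wait` instead of a host sync, a timing loop on the bench stream needs only one host synchronization at the end. A hypothetical driver loop under that assumption (not code from this repository):

```cpp
#include <cstdio>
#include <cuda_runtime.h>

// Hypothetical timing loop; assumes each algo->search(...) call ends with
// handle_.stream_wait(bench_stream), as in the wrappers in this commit.
void time_search(int iters, cudaStream_t bench_stream)
{
  cudaEvent_t start, stop;
  cudaEventCreate(&start);
  cudaEventCreate(&stop);

  cudaEventRecord(start, bench_stream);
  for (int i = 0; i < iters; ++i) {
    // algo->search(queries, batch, k, neighbors, distances, bench_stream);
    // the host never blocks inside this loop
  }
  cudaEventRecord(stop, bench_stream);
  cudaEventSynchronize(stop);  // the only host-side wait

  float ms = 0.0f;
  cudaEventElapsedTime(&ms, start, stop);
  std::printf("avg: %.3f ms/iter\n", ms / iters);

  cudaEventDestroy(start);
  cudaEventDestroy(stop);
}
```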
cpp/bench/ann/src/raft/raft_ivf_pq_wrapper.h (39 changes: 9 additions & 30 deletions)
@@ -54,21 +54,9 @@ class RaftIvfPQ : public ANN<T> {
using BuildParam = raft::neighbors::ivf_pq::index_params;

RaftIvfPQ(Metric metric, int dim, const BuildParam& param)
: ANN<T>(metric, dim),
index_params_(param),
dimension_(dim),
mr_(rmm::mr::get_current_device_resource(), 1024 * 1024 * 1024ull)
: ANN<T>(metric, dim), index_params_(param), dimension_(dim)
{
rmm::mr::set_current_device_resource(&mr_);
index_params_.metric = parse_metric_type(metric);
RAFT_CUDA_TRY(cudaGetDevice(&device_));
RAFT_CUDA_TRY(cudaEventCreate(&sync_, cudaEventDisableTiming));
}

~RaftIvfPQ() noexcept
{
rmm::mr::set_current_device_resource(mr_.get_upstream());
RAFT_CUDA_TRY_NO_THROW(cudaEventDestroy(sync_));
}

void build(const T* dataset, size_t nrow, cudaStream_t stream) final;
@@ -97,23 +85,14 @@ class RaftIvfPQ : public ANN<T> {
void load(const std::string&) override;

private:
// `mr_` must go first to make sure it dies last
rmm::mr::pool_memory_resource<rmm::mr::device_memory_resource> mr_;
raft::device_resources handle_;
cudaEvent_t sync_{nullptr};
// handle_ must go first to make sure it dies last, after all memory allocated in its pool has been freed
configured_raft_resources handle_{};
BuildParam index_params_;
raft::neighbors::ivf_pq::search_params search_params_;
std::optional<raft::neighbors::ivf_pq::index<IdxT>> index_;
int device_;
int dimension_;
float refine_ratio_ = 1.0;
raft::device_matrix_view<const T, IdxT> dataset_;

void stream_wait(cudaStream_t stream) const
{
RAFT_CUDA_TRY(cudaEventRecord(sync_, resource::get_cuda_stream(handle_)));
RAFT_CUDA_TRY(cudaStreamWaitEvent(stream, sync_));
}
};

template <typename T, typename IdxT>
@@ -137,7 +116,7 @@ void RaftIvfPQ<T, IdxT>::build(const T* dataset, size_t nrow, cudaStream_t strea
auto dataset_v = raft::make_device_matrix_view<const T, IdxT>(dataset, IdxT(nrow), dim_);

index_.emplace(raft::runtime::neighbors::ivf_pq::build(handle_, index_params_, dataset_v));
stream_wait(stream);
handle_.stream_wait(stream); // RAFT stream -> bench stream
}

template <typename T, typename IdxT>
@@ -186,7 +165,7 @@ void RaftIvfPQ<T, IdxT>::search(const T* queries,
neighbors_v,
distances_v,
index_->metric());
stream_wait(stream); // RAFT stream -> bench stream
handle_.stream_wait(stream); // RAFT stream -> bench stream
} else {
auto queries_host = raft::make_host_matrix<T, IdxT>(batch_size, index_->dim());
auto candidates_host = raft::make_host_matrix<IdxT, IdxT>(batch_size, k0);
@@ -203,9 +182,9 @@ void RaftIvfPQ<T, IdxT>::search(const T* queries,
dataset_.data_handle(), dataset_.extent(0), dataset_.extent(1));

// wait for the queries to copy to host in `stream` and for IVF-PQ::search to finish
RAFT_CUDA_TRY(cudaEventRecord(sync_, resource::get_cuda_stream(handle_)));
RAFT_CUDA_TRY(cudaEventRecord(sync_, stream));
RAFT_CUDA_TRY(cudaEventSynchronize(sync_));
RAFT_CUDA_TRY(cudaEventRecord(handle_.get_sync_event(), resource::get_cuda_stream(handle_)));
RAFT_CUDA_TRY(cudaEventRecord(handle_.get_sync_event(), stream));
RAFT_CUDA_TRY(cudaEventSynchronize(handle_.get_sync_event()));
raft::runtime::neighbors::refine(handle_,
dataset_v,
queries_host.view(),
@@ -225,7 +204,7 @@

raft::runtime::neighbors::ivf_pq::search(
handle_, search_params_, *index_, queries_v, neighbors_v, distances_v);
stream_wait(stream); // RAFT stream -> bench stream
handle_.stream_wait(stream); // RAFT stream -> bench stream
}
}
} // namespace raft::bench::ann
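The host-side refine branch above is the one place where a genuine CPU wait remains: `refine` reads `queries_host` and `candidates_host` on the host, so the host must block until the device-to-host copies have landed. Recording an event and calling `cudaEventSynchronize` blocks only up to that recorded point rather than draining the whole device. A minimal sketch of that stream-to-host handshake (illustrative buffers, not the diff's types):

```cpp
#include <cuda_runtime.h>
#include <vector>

// Copy device data to a host buffer on `stream`, then block the host just
// long enough for that copy to finish; other streams keep running.
void copy_for_host_use(const float* d_src, std::vector<float>& h_dst,
                       cudaEvent_t ev, cudaStream_t stream)
{
  cudaMemcpyAsync(h_dst.data(), d_src, h_dst.size() * sizeof(float),
                  cudaMemcpyDeviceToHost, stream);
  cudaEventRecord(ev, stream);
  cudaEventSynchronize(ev);  // host waits for this point only
  // h_dst is now safe to read on the CPU (e.g. as input to refine()).
}
```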
