Skip to content

Commit

Permalink
Fix Grace-specific issues in CAGRA
Browse files Browse the repository at this point in the history
  • Loading branch information
achirkin committed Dec 11, 2024
1 parent e1a5708 commit ce56d93
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 2 deletions.
11 changes: 11 additions & 0 deletions cpp/src/neighbors/detail/ann_utils.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,17 @@ struct batch_load_iterator {

/** A single batch of data residing in device memory. */
struct batch {
~batch() noexcept
{
/*
If there's no copy, there's no allocation owned by the batch.
If there's no allocation, there's no guarantee that the device pointer is stream-ordered.
If there's no stream order guarantee, we must synchronize with the stream before the batch is
destroyed to make sure all GPU operations in that stream finish earlier.
*/
if (!does_copy()) { RAFT_CUDA_TRY_NO_THROW(cudaStreamSynchronize(stream_)); }
}

/** Logical width of a single row in a batch, in elements of type `T`. */
[[nodiscard]] auto row_width() const -> size_type { return row_width_; }
/** Logical offset of the batch, in rows (`row_width()`) */
Expand Down
16 changes: 14 additions & 2 deletions cpp/src/neighbors/detail/cagra/utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ class device_matrix_view_from_host {
public:
device_matrix_view_from_host(raft::resources const& res,
raft::host_matrix_view<T, IdxT> host_view)
: host_view_(host_view)
: res_(res), host_view_(host_view)
{
cudaPointerAttributes attr;
RAFT_CUDA_TRY(cudaPointerGetAttributes(&attr, host_view.data_handle()));
Expand All @@ -199,6 +199,17 @@ class device_matrix_view_from_host {
}
}

~device_matrix_view_from_host() noexcept
{
/*
If there's no copy, there's no allocation owned by this struct.
If there's no allocation, there's no guarantee that the device pointer is stream-ordered.
If there's no stream order guarantee, we must synchronize with the stream before the struct is
destroyed to make sure all GPU operations in that stream finish earlier.
*/
if (!allocated_memory()) { raft::resource::sync_stream(res_); }
}

raft::device_matrix_view<T, IdxT> view()
{
return raft::make_device_matrix_view<T, IdxT>(
Expand All @@ -207,9 +218,10 @@ class device_matrix_view_from_host {

T* data_handle() { return device_ptr; }

bool allocated_memory() const { return device_mem_.has_value(); }
[[nodiscard]] bool allocated_memory() const { return device_mem_.has_value(); }

private:
const raft::resources& res_;
std::optional<raft::device_matrix<T, IdxT>> device_mem_;
raft::host_matrix_view<T, IdxT> host_view_;
T* device_ptr;
Expand Down

0 comments on commit ce56d93

Please sign in to comment.