Skip to content

Commit

Permalink
Merge branch 'branch-25.02' into diskann-wrapper
Browse files Browse the repository at this point in the history
  • Loading branch information
cjnolet authored Dec 12, 2024
2 parents 63e02ff + ef16a9e commit 48a6a9d
Show file tree
Hide file tree
Showing 5 changed files with 115 additions and 15 deletions.
78 changes: 75 additions & 3 deletions cpp/include/cuvs/neighbors/common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,77 @@ auto make_strided_dataset(const raft::resources& res, const SrcT& src, uint32_t
return std::make_unique<out_owning_type>(std::move(out_array), out_layout);
}

/**
* @brief Contstruct a strided matrix from any mdarray.
*
* This function constructs an owning device matrix and copies the data.
* When the data is copied, padding elements are filled with zeroes.
*
* @tparam DataT
* @tparam IdxT
* @tparam LayoutPolicy
* @tparam ContainerPolicy
*
* @param[in] res raft resources handle
* @param[in] src the source mdarray or mdspan
* @param[in] required_stride the leading dimension (in elements)
* @return owning current-device-accessible strided matrix
*/
template <typename DataT, typename IdxT, typename LayoutPolicy, typename ContainerPolicy>
auto make_strided_dataset(
const raft::resources& res,
raft::mdarray<DataT, raft::matrix_extent<IdxT>, LayoutPolicy, ContainerPolicy>&& src,
uint32_t required_stride) -> std::unique_ptr<strided_dataset<DataT, IdxT>>
{
using value_type = DataT;
using index_type = IdxT;
using layout_type = LayoutPolicy;
using container_policy_type = ContainerPolicy;
static_assert(std::is_same_v<layout_type, raft::layout_right> ||
std::is_same_v<layout_type, raft::layout_right_padded<value_type>> ||
std::is_same_v<layout_type, raft::layout_stride>,
"The input must be row-major");
RAFT_EXPECTS(src.extent(1) <= required_stride,
"The input row length must be not larger than the desired stride.");
const uint32_t src_stride = src.stride(0) > 0 ? src.stride(0) : src.extent(1);
const bool stride_matches = required_stride == src_stride;

auto out_layout =
raft::make_strided_layout(src.extents(), std::array<index_type, 2>{required_stride, 1});

using out_mdarray_type = raft::device_matrix<value_type, index_type>;
using out_layout_type = typename out_mdarray_type::layout_type;
using out_container_policy_type = typename out_mdarray_type::container_policy_type;
using out_owning_type =
owning_dataset<value_type, index_type, out_layout_type, out_container_policy_type>;

if constexpr (std::is_same_v<layout_type, out_layout_type> &&
std::is_same_v<container_policy_type, out_container_policy_type>) {
if (stride_matches) {
// Everything matches, we can own the mdarray
return std::make_unique<out_owning_type>(std::move(src), out_layout);
}
}
// Something is wrong: have to make a copy and produce an owning dataset
auto out_array =
raft::make_device_matrix<value_type, index_type>(res, src.extent(0), required_stride);

RAFT_CUDA_TRY(cudaMemsetAsync(out_array.data_handle(),
0,
out_array.size() * sizeof(value_type),
raft::resource::get_cuda_stream(res)));
RAFT_CUDA_TRY(cudaMemcpy2DAsync(out_array.data_handle(),
sizeof(value_type) * required_stride,
src.data_handle(),
sizeof(value_type) * src_stride,
sizeof(value_type) * src.extent(1),
src.extent(0),
cudaMemcpyDefault,
raft::resource::get_cuda_stream(res)));

return std::make_unique<out_owning_type>(std::move(out_array), out_layout);
}

/**
* @brief Contstruct a strided matrix from any mdarray or mdspan.
*
Expand All @@ -278,14 +349,15 @@ auto make_strided_dataset(const raft::resources& res, const SrcT& src, uint32_t
* @return maybe owning current-device-accessible strided matrix
*/
template <typename SrcT>
auto make_aligned_dataset(const raft::resources& res, const SrcT& src, uint32_t align_bytes = 16)
auto make_aligned_dataset(const raft::resources& res, SrcT src, uint32_t align_bytes = 16)
-> std::unique_ptr<strided_dataset<typename SrcT::value_type, typename SrcT::index_type>>
{
using value_type = typename SrcT::value_type;
using source_type = std::remove_cv_t<std::remove_reference_t<SrcT>>;
using value_type = typename source_type::value_type;
constexpr size_t kSize = sizeof(value_type);
uint32_t required_stride =
raft::round_up_safe<size_t>(src.extent(1) * kSize, std::lcm(align_bytes, kSize)) / kSize;
return make_strided_dataset(res, src, required_stride);
return make_strided_dataset(res, std::forward<SrcT>(src), required_stride);
}
/**
* @brief VPQ compressed dataset.
Expand Down
11 changes: 11 additions & 0 deletions cpp/src/neighbors/detail/ann_utils.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,17 @@ struct batch_load_iterator {

/** A single batch of data residing in device memory. */
struct batch {
~batch() noexcept
{
/*
If there's no copy, there's no allocation owned by the batch.
If there's no allocation, there's no guarantee that the device pointer is stream-ordered.
If there's no stream order guarantee, we must synchronize with the stream before the batch is
destroyed to make sure all GPU operations in that stream finish earlier.
*/
if (!does_copy()) { RAFT_CUDA_TRY_NO_THROW(cudaStreamSynchronize(stream_)); }
}

/** Logical width of a single row in a batch, in elements of type `T`. */
[[nodiscard]] auto row_width() const -> size_type { return row_width_; }
/** Logical offset of the batch, in rows (`row_width()`) */
Expand Down
16 changes: 14 additions & 2 deletions cpp/src/neighbors/detail/cagra/utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ class device_matrix_view_from_host {
public:
device_matrix_view_from_host(raft::resources const& res,
raft::host_matrix_view<T, IdxT> host_view)
: host_view_(host_view)
: res_(res), host_view_(host_view)
{
cudaPointerAttributes attr;
RAFT_CUDA_TRY(cudaPointerGetAttributes(&attr, host_view.data_handle()));
Expand All @@ -199,6 +199,17 @@ class device_matrix_view_from_host {
}
}

~device_matrix_view_from_host() noexcept
{
/*
If there's no copy, there's no allocation owned by this struct.
If there's no allocation, there's no guarantee that the device pointer is stream-ordered.
If there's no stream order guarantee, we must synchronize with the stream before the struct is
destroyed to make sure all GPU operations in that stream finish earlier.
*/
if (!allocated_memory()) { raft::resource::sync_stream(res_); }
}

raft::device_matrix_view<T, IdxT> view()
{
return raft::make_device_matrix_view<T, IdxT>(
Expand All @@ -207,9 +218,10 @@ class device_matrix_view_from_host {

T* data_handle() { return device_ptr; }

bool allocated_memory() const { return device_mem_.has_value(); }
[[nodiscard]] bool allocated_memory() const { return device_mem_.has_value(); }

private:
const raft::resources& res_;
std::optional<raft::device_matrix<T, IdxT>> device_mem_;
raft::host_matrix_view<T, IdxT> host_view_;
T* device_ptr;
Expand Down
2 changes: 1 addition & 1 deletion cpp/src/neighbors/detail/dataset_serialize.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ auto deserialize_strided(raft::resources const& res, std::istream& is)
auto stride = raft::deserialize_scalar<uint32_t>(res, is);
auto host_array = raft::make_host_matrix<DataT, IdxT>(n_rows, dim);
raft::deserialize_mdspan(res, is, host_array.view());
return make_strided_dataset(res, host_array, stride);
return make_strided_dataset(res, std::move(host_array), stride);
}

template <typename MathT, typename IdxT>
Expand Down
23 changes: 14 additions & 9 deletions cpp/test/neighbors/ann_cagra.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -389,12 +389,13 @@ class AnnCagraTest : public ::testing::TestWithParam<AnnCagraInputs> {
(const DataT*)database.data(), ps.n_rows, ps.dim);

{
std::optional<raft::host_matrix<DataT, int64_t>> database_host{std::nullopt};
cagra::index<DataT, IdxT> index(handle_, index_params.metric);
if (ps.host_dataset) {
auto database_host = raft::make_host_matrix<DataT, int64_t>(ps.n_rows, ps.dim);
raft::copy(database_host.data_handle(), database.data(), database.size(), stream_);
database_host = raft::make_host_matrix<DataT, int64_t>(ps.n_rows, ps.dim);
raft::copy(database_host->data_handle(), database.data(), database.size(), stream_);
auto database_host_view = raft::make_host_matrix_view<const DataT, int64_t>(
(const DataT*)database_host.data_handle(), ps.n_rows, ps.dim);
(const DataT*)database_host->data_handle(), ps.n_rows, ps.dim);

index = cagra::build(handle_, index_params, database_host_view);
} else {
Expand Down Expand Up @@ -567,13 +568,16 @@ class AnnCagraAddNodesTest : public ::testing::TestWithParam<AnnCagraInputs> {
auto initial_database_view = raft::make_device_matrix_view<const DataT, int64_t>(
(const DataT*)database.data(), initial_database_size, ps.dim);

std::optional<raft::host_matrix<DataT, int64_t>> database_host{std::nullopt};
cagra::index<DataT, IdxT> index(handle_);
if (ps.host_dataset) {
auto database_host = raft::make_host_matrix<DataT, int64_t>(ps.n_rows, ps.dim);
database_host = raft::make_host_matrix<DataT, int64_t>(ps.n_rows, ps.dim);
raft::copy(
database_host.data_handle(), database.data(), initial_database_view.size(), stream_);
database_host->data_handle(), database.data(), initial_database_view.size(), stream_);
auto database_host_view = raft::make_host_matrix_view<const DataT, int64_t>(
(const DataT*)database_host.data_handle(), initial_database_size, ps.dim);
(const DataT*)database_host->data_handle(), initial_database_size, ps.dim);
// NB: database_host must live no less than the index, because the index _may_be_
// non-onwning
index = cagra::build(handle_, index_params, database_host_view);
} else {
index = cagra::build(handle_, index_params, initial_database_view);
Expand Down Expand Up @@ -763,12 +767,13 @@ class AnnCagraFilterTest : public ::testing::TestWithParam<AnnCagraInputs> {
auto database_view = raft::make_device_matrix_view<const DataT, int64_t>(
(const DataT*)database.data(), ps.n_rows, ps.dim);

std::optional<raft::host_matrix<DataT, int64_t>> database_host{std::nullopt};
cagra::index<DataT, IdxT> index(handle_);
if (ps.host_dataset) {
auto database_host = raft::make_host_matrix<DataT, int64_t>(ps.n_rows, ps.dim);
raft::copy(database_host.data_handle(), database.data(), database.size(), stream_);
database_host = raft::make_host_matrix<DataT, int64_t>(ps.n_rows, ps.dim);
raft::copy(database_host->data_handle(), database.data(), database.size(), stream_);
auto database_host_view = raft::make_host_matrix_view<const DataT, int64_t>(
(const DataT*)database_host.data_handle(), ps.n_rows, ps.dim);
(const DataT*)database_host->data_handle(), ps.n_rows, ps.dim);
index = cagra::build(handle_, index_params, database_host_view);
} else {
index = cagra::build(handle_, index_params, database_view);
Expand Down

0 comments on commit 48a6a9d

Please sign in to comment.