Skip to content

Commit

Permalink
Simplify unique_ptr arithmetics
Browse files Browse the repository at this point in the history
  • Loading branch information
achirkin committed Mar 11, 2024
1 parent dd1cc99 commit 292406c
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 42 deletions.
12 changes: 6 additions & 6 deletions cpp/include/raft/neighbors/cagra_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ struct index : ann::index {
mdspan<const IdxT, matrix_extent<int64_t>, row_major, graph_accessor> knn_graph)
: ann::index(),
metric_(metric),
dataset_(upcast_dataset_ptr(construct_aligned_dataset(res, dataset, 16))),
dataset_(construct_aligned_dataset(res, dataset, 16)),
graph_(make_device_matrix<IdxT, int64_t>(res, 0, 0))
{
RAFT_EXPECTS(dataset.extent(0) == knn_graph.extent(0),
Expand All @@ -280,14 +280,14 @@ struct index : ann::index {
void update_dataset(raft::resources const& res,
raft::device_matrix_view<const T, int64_t, row_major> dataset)
{
upcast_dataset_ptr(construct_aligned_dataset(res, dataset, 16)).swap(dataset_);
dataset_ = construct_aligned_dataset(res, dataset, 16);
}

/** Set the dataset reference explicitly to a device matrix view with padding. */
void update_dataset(raft::resources const& res,
raft::device_matrix_view<const T, int64_t, layout_stride> dataset)
{
upcast_dataset_ptr(construct_aligned_dataset(res, dataset, 16)).swap(dataset_);
dataset_ = construct_aligned_dataset(res, dataset, 16);
}

/**
Expand All @@ -298,22 +298,22 @@ struct index : ann::index {
void update_dataset(raft::resources const& res,
raft::host_matrix_view<const T, int64_t, row_major> dataset)
{
upcast_dataset_ptr(construct_aligned_dataset(res, dataset, 16)).swap(dataset_);
dataset_ = construct_aligned_dataset(res, dataset, 16);
}

/** Replace the dataset with a new dataset. */
template <typename DatasetT>
auto update_dataset(raft::resources const& res, DatasetT&& dataset)
-> std::enable_if_t<std::is_base_of_v<neighbors::dataset<int64_t>, DatasetT>>
{
upcast_dataset_ptr(std::make_unique<DatasetT>(std::move(dataset))).swap(dataset_);
dataset_ = std::make_unique<DatasetT>(std::move(dataset));
}

template <typename DatasetT>
auto update_dataset(raft::resources const& res, std::unique_ptr<DatasetT>&& dataset)
-> std::enable_if_t<std::is_base_of_v<neighbors::dataset<int64_t>, DatasetT>>
{
upcast_dataset_ptr(std::move(dataset)).swap(dataset_);
dataset_ = std::move(dataset);
}

/**
Expand Down
20 changes: 4 additions & 16 deletions cpp/include/raft/neighbors/dataset.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,6 @@ auto construct_strided_dataset(const raft::resources& res,
using value_type = typename SrcT::value_type;
using index_type = typename SrcT::index_type;
using layout_type = typename SrcT::layout_type;
using out_type = strided_dataset<value_type, index_type>;
static_assert(extents_type::rank() == 2, "The input must be a matrix.");
static_assert(std::is_same_v<layout_type, layout_right> ||
std::is_same_v<layout_type, layout_right_padded<value_type>> ||
Expand All @@ -133,9 +132,9 @@ auto construct_strided_dataset(const raft::resources& res,

if (device_accessible && row_major && stride_matches) {
// Everything matches: make a non-owning dataset
return std::unique_ptr<out_type>{new non_owning_dataset<value_type, index_type>{
return std::make_unique<non_owning_dataset<value_type, index_type>>(
make_device_strided_matrix_view<const value_type, index_type>(
src.data_handle(), src.extent(0), src.extent(1), required_stride)}};
src.data_handle(), src.extent(0), src.extent(1), required_stride));
}
// Something is wrong: have to make a copy and produce an owning dataset
auto out_layout =
Expand All @@ -161,29 +160,18 @@ auto construct_strided_dataset(const raft::resources& res,
cudaMemcpyDefault,
resource::get_cuda_stream(res)));

return std::unique_ptr<out_type>{new out_owning_type{std::move(out_array), out_layout}};
return std::make_unique<out_owning_type>(std::move(out_array), out_layout);
}

template <typename SrcT>
auto construct_aligned_dataset(const raft::resources& res, const SrcT& src, uint32_t align_bytes)
-> std::unique_ptr<strided_dataset<typename SrcT::value_type, typename SrcT::index_type>>
{
using value_type = typename SrcT::value_type;
using index_type = typename SrcT::index_type;
using out_type = strided_dataset<value_type, index_type>;
constexpr size_t kSize = sizeof(value_type);
uint32_t required_stride =
raft::round_up_safe<size_t>(src.extent(1) * kSize, std::lcm(align_bytes, kSize)) / kSize;
return std::unique_ptr<out_type>{construct_strided_dataset(res, src, required_stride).release()};
}

template <typename DatasetT>
auto upcast_dataset_ptr(std::unique_ptr<DatasetT>&& src)
-> std::unique_ptr<dataset<typename DatasetT::index_type>>
{
using out_type = dataset<typename DatasetT::index_type>;
static_assert(std::is_base_of_v<out_type, DatasetT>, "The source must be a child of `dataset`");
return std::unique_ptr<out_type>{src.release()};
return construct_strided_dataset(res, src, required_stride);
}

/** Parameters for VPQ compression. */
Expand Down
46 changes: 26 additions & 20 deletions cpp/include/raft/neighbors/detail/dataset_serialize.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -124,25 +124,26 @@ void deserialize(raft::resources const& res,
std::unique_ptr<empty_dataset<IdxT>>& out)
{
auto suggested_dim = deserialize_scalar<uint32_t>(res, is);
return std::make_unique<empty_dataset<IdxT>>(suggested_dim).swap(out);
out = std::make_unique<empty_dataset<IdxT>>(suggested_dim);
}

template <typename DataT, typename IdxT>
void deserialize(raft::resources const& res,
std::istream& is,
std::unique_ptr<strided_dataset<DataT, IdxT>>& out)
{
using out_mdarray_type = device_mdarray<DataT, matrix_extent<IdxT>, layout_stride>;
using out_layout_type = typename out_mdarray_type::layout_type;
using out_container_policy_type = typename out_mdarray_type::container_policy_type;
using out_owning_type = owning_dataset<DataT, IdxT, out_layout_type, out_container_policy_type>;

auto n_rows = deserialize_scalar<IdxT>(res, is);
auto dim = deserialize_scalar<uint32_t>(res, is);
auto stride = deserialize_scalar<uint32_t>(res, is);
auto out_extents = make_extents<IdxT>(n_rows, dim);
auto out_layout = make_strided_layout(out_extents, std::array<IdxT, 2>{stride, 1});
auto out_array = out_mdarray_type{res, out_layout, out_container_policy_type{}};
auto out_array = make_device_matrix<DataT, IdxT>(res, n_rows, stride);

using out_mdarray_type = decltype(out_array);
using out_layout_type = typename out_mdarray_type::layout_type;
using out_container_policy_type = typename out_mdarray_type::container_policy_type;
using out_owning_type = owning_dataset<DataT, IdxT, out_layout_type, out_container_policy_type>;

auto host_arrray = make_host_mdarray<DataT, IdxT>(out_extents);
deserialize_mdspan(res, is, host_arrray.view());
RAFT_CUDA_TRY(cudaMemcpy2DAsync(out_array.data_handle(),
Expand All @@ -153,9 +154,8 @@ void deserialize(raft::resources const& res,
n_rows,
cudaMemcpyDefault,
resource::get_cuda_stream(res)));
return std::unique_ptr<strided_dataset<DataT, IdxT>>{
new out_owning_type{std::move(out_array), out_layout}}
.swap(out);

out = std::make_unique<out_owning_type>(std::move(out_array), out_layout);
}

template <typename MathT, typename IdxT>
Expand All @@ -178,9 +178,8 @@ void deserialize(raft::resources const& res,
deserialize_mdspan(res, is, pq_code_book.view());
deserialize_mdspan(res, is, data.view());

return std::unique_ptr<vpq_dataset<MathT, IdxT>>{
new vpq_dataset{std::move(vq_code_book), std::move(pq_code_book), std::move(data)}}
.swap(out);
out = std::make_unique<vpq_dataset<MathT, IdxT>>(
std::move(vq_code_book), std::move(pq_code_book), std::move(data));
}

template <typename IdxT>
Expand All @@ -190,29 +189,34 @@ void deserialize(raft::resources const& res, std::istream& is, std::unique_ptr<d
case kSerializeEmptyDataset: {
std::unique_ptr<empty_dataset<IdxT>> p;
deserialize(res, is, p);
return upcast_dataset_ptr(std::move(p)).swap(out);
out = std::move(p);
return;
}
case kSerializeStridedDataset:
switch (deserialize_scalar<cudaDataType_t>(res, is)) {
case CUDA_R_32F: {
std::unique_ptr<strided_dataset<float, IdxT>> p;
deserialize(res, is, p);
return upcast_dataset_ptr(std::move(p)).swap(out);
out = std::move(p);
return;
}
case CUDA_R_16F: {
std::unique_ptr<strided_dataset<half, IdxT>> p;
deserialize(res, is, p);
return upcast_dataset_ptr(std::move(p)).swap(out);
out = std::move(p);
return;
}
case CUDA_R_8I: {
std::unique_ptr<strided_dataset<int8_t, IdxT>> p;
deserialize(res, is, p);
return upcast_dataset_ptr(std::move(p)).swap(out);
out = std::move(p);
return;
}
case CUDA_R_8U: {
std::unique_ptr<strided_dataset<uint8_t, IdxT>> p;
deserialize(res, is, p);
return upcast_dataset_ptr(std::move(p)).swap(out);
out = std::move(p);
return;
}
default: break;
}
Expand All @@ -221,12 +225,14 @@ void deserialize(raft::resources const& res, std::istream& is, std::unique_ptr<d
case CUDA_R_32F: {
std::unique_ptr<vpq_dataset<float, IdxT>> p;
deserialize(res, is, p);
return upcast_dataset_ptr(std::move(p)).swap(out);
out = std::move(p);
return;
}
case CUDA_R_16F: {
std::unique_ptr<vpq_dataset<half, IdxT>> p;
deserialize(res, is, p);
return upcast_dataset_ptr(std::move(p)).swap(out);
out = std::move(p);
return;
}
default: break;
}
Expand Down

0 comments on commit 292406c

Please sign in to comment.