From 292406c390ff16ea2bdd2ba0e0df1bbae470c0d2 Mon Sep 17 00:00:00 2001 From: achirkin Date: Mon, 11 Mar 2024 18:21:36 +0100 Subject: [PATCH] Simplify unique_ptr arithmetics --- cpp/include/raft/neighbors/cagra_types.hpp | 12 ++--- cpp/include/raft/neighbors/dataset.hpp | 20 ++------ .../neighbors/detail/dataset_serialize.hpp | 46 +++++++++++-------- 3 files changed, 36 insertions(+), 42 deletions(-) diff --git a/cpp/include/raft/neighbors/cagra_types.hpp b/cpp/include/raft/neighbors/cagra_types.hpp index 868f214320..9b86ca29f2 100644 --- a/cpp/include/raft/neighbors/cagra_types.hpp +++ b/cpp/include/raft/neighbors/cagra_types.hpp @@ -261,7 +261,7 @@ struct index : ann::index { mdspan, row_major, graph_accessor> knn_graph) : ann::index(), metric_(metric), - dataset_(upcast_dataset_ptr(construct_aligned_dataset(res, dataset, 16))), + dataset_(construct_aligned_dataset(res, dataset, 16)), graph_(make_device_matrix(res, 0, 0)) { RAFT_EXPECTS(dataset.extent(0) == knn_graph.extent(0), @@ -280,14 +280,14 @@ struct index : ann::index { void update_dataset(raft::resources const& res, raft::device_matrix_view dataset) { - upcast_dataset_ptr(construct_aligned_dataset(res, dataset, 16)).swap(dataset_); + dataset_ = construct_aligned_dataset(res, dataset, 16); } /** Set the dataset reference explicitly to a device matrix view with padding. */ void update_dataset(raft::resources const& res, raft::device_matrix_view dataset) { - upcast_dataset_ptr(construct_aligned_dataset(res, dataset, 16)).swap(dataset_); + dataset_ = construct_aligned_dataset(res, dataset, 16); } /** @@ -298,7 +298,7 @@ struct index : ann::index { void update_dataset(raft::resources const& res, raft::host_matrix_view dataset) { - upcast_dataset_ptr(construct_aligned_dataset(res, dataset, 16)).swap(dataset_); + dataset_ = construct_aligned_dataset(res, dataset, 16); } /** Replace the dataset with a new dataset. */ @@ -306,14 +306,14 @@ struct index : ann::index { auto update_dataset(raft::resources const& res, DatasetT&& dataset) -> std::enable_if_t, DatasetT>> { - upcast_dataset_ptr(std::make_unique(std::move(dataset))).swap(dataset_); + dataset_ = std::make_unique(std::move(dataset)); } template auto update_dataset(raft::resources const& res, std::unique_ptr&& dataset) -> std::enable_if_t, DatasetT>> { - upcast_dataset_ptr(std::move(dataset)).swap(dataset_); + dataset_ = std::move(dataset); } /** diff --git a/cpp/include/raft/neighbors/dataset.hpp b/cpp/include/raft/neighbors/dataset.hpp index dd346e6d9d..23ea6054bd 100644 --- a/cpp/include/raft/neighbors/dataset.hpp +++ b/cpp/include/raft/neighbors/dataset.hpp @@ -118,7 +118,6 @@ auto construct_strided_dataset(const raft::resources& res, using value_type = typename SrcT::value_type; using index_type = typename SrcT::index_type; using layout_type = typename SrcT::layout_type; - using out_type = strided_dataset; static_assert(extents_type::rank() == 2, "The input must be a matrix."); static_assert(std::is_same_v || std::is_same_v> || @@ -133,9 +132,9 @@ auto construct_strided_dataset(const raft::resources& res, if (device_accessible && row_major && stride_matches) { // Everything matches: make a non-owning dataset - return std::unique_ptr{new non_owning_dataset{ + return std::make_unique>( make_device_strided_matrix_view( - src.data_handle(), src.extent(0), src.extent(1), required_stride)}}; + src.data_handle(), src.extent(0), src.extent(1), required_stride)); } // Something is wrong: have to make a copy and produce an owning dataset auto out_layout = @@ -161,7 +160,7 @@ auto construct_strided_dataset(const raft::resources& res, cudaMemcpyDefault, resource::get_cuda_stream(res))); - return std::unique_ptr{new out_owning_type{std::move(out_array), out_layout}}; + return std::make_unique(std::move(out_array), out_layout); } template @@ -169,21 +168,10 @@ auto construct_aligned_dataset(const raft::resources& res, const SrcT& src, uint -> std::unique_ptr> { using value_type = typename SrcT::value_type; - using index_type = typename SrcT::index_type; - using out_type = strided_dataset; constexpr size_t kSize = sizeof(value_type); uint32_t required_stride = raft::round_up_safe(src.extent(1) * kSize, std::lcm(align_bytes, kSize)) / kSize; - return std::unique_ptr{construct_strided_dataset(res, src, required_stride).release()}; -} - -template -auto upcast_dataset_ptr(std::unique_ptr&& src) - -> std::unique_ptr> -{ - using out_type = dataset; - static_assert(std::is_base_of_v, "The source must be a child of `dataset`"); - return std::unique_ptr{src.release()}; + return construct_strided_dataset(res, src, required_stride); } /** Parameters for VPQ compression. */ diff --git a/cpp/include/raft/neighbors/detail/dataset_serialize.hpp b/cpp/include/raft/neighbors/detail/dataset_serialize.hpp index 2864d260c1..529569865b 100644 --- a/cpp/include/raft/neighbors/detail/dataset_serialize.hpp +++ b/cpp/include/raft/neighbors/detail/dataset_serialize.hpp @@ -124,7 +124,7 @@ void deserialize(raft::resources const& res, std::unique_ptr>& out) { auto suggested_dim = deserialize_scalar(res, is); - return std::make_unique>(suggested_dim).swap(out); + out = std::make_unique>(suggested_dim); } template @@ -132,17 +132,18 @@ void deserialize(raft::resources const& res, std::istream& is, std::unique_ptr>& out) { - using out_mdarray_type = device_mdarray, layout_stride>; - using out_layout_type = typename out_mdarray_type::layout_type; - using out_container_policy_type = typename out_mdarray_type::container_policy_type; - using out_owning_type = owning_dataset; - auto n_rows = deserialize_scalar(res, is); auto dim = deserialize_scalar(res, is); auto stride = deserialize_scalar(res, is); auto out_extents = make_extents(n_rows, dim); auto out_layout = make_strided_layout(out_extents, std::array{stride, 1}); - auto out_array = out_mdarray_type{res, out_layout, out_container_policy_type{}}; + auto out_array = make_device_matrix(res, n_rows, stride); + + using out_mdarray_type = decltype(out_array); + using out_layout_type = typename out_mdarray_type::layout_type; + using out_container_policy_type = typename out_mdarray_type::container_policy_type; + using out_owning_type = owning_dataset; + auto host_arrray = make_host_mdarray(out_extents); deserialize_mdspan(res, is, host_arrray.view()); RAFT_CUDA_TRY(cudaMemcpy2DAsync(out_array.data_handle(), @@ -153,9 +154,8 @@ void deserialize(raft::resources const& res, n_rows, cudaMemcpyDefault, resource::get_cuda_stream(res))); - return std::unique_ptr>{ - new out_owning_type{std::move(out_array), out_layout}} - .swap(out); + + out = std::make_unique(std::move(out_array), out_layout); } template @@ -178,9 +178,8 @@ void deserialize(raft::resources const& res, deserialize_mdspan(res, is, pq_code_book.view()); deserialize_mdspan(res, is, data.view()); - return std::unique_ptr>{ - new vpq_dataset{std::move(vq_code_book), std::move(pq_code_book), std::move(data)}} - .swap(out); + out = std::make_unique>( + std::move(vq_code_book), std::move(pq_code_book), std::move(data)); } template @@ -190,29 +189,34 @@ void deserialize(raft::resources const& res, std::istream& is, std::unique_ptr> p; deserialize(res, is, p); - return upcast_dataset_ptr(std::move(p)).swap(out); + out = std::move(p); + return; } case kSerializeStridedDataset: switch (deserialize_scalar(res, is)) { case CUDA_R_32F: { std::unique_ptr> p; deserialize(res, is, p); - return upcast_dataset_ptr(std::move(p)).swap(out); + out = std::move(p); + return; } case CUDA_R_16F: { std::unique_ptr> p; deserialize(res, is, p); - return upcast_dataset_ptr(std::move(p)).swap(out); + out = std::move(p); + return; } case CUDA_R_8I: { std::unique_ptr> p; deserialize(res, is, p); - return upcast_dataset_ptr(std::move(p)).swap(out); + out = std::move(p); + return; } case CUDA_R_8U: { std::unique_ptr> p; deserialize(res, is, p); - return upcast_dataset_ptr(std::move(p)).swap(out); + out = std::move(p); + return; } default: break; } @@ -221,12 +225,14 @@ void deserialize(raft::resources const& res, std::istream& is, std::unique_ptr> p; deserialize(res, is, p); - return upcast_dataset_ptr(std::move(p)).swap(out); + out = std::move(p); + return; } case CUDA_R_16F: { std::unique_ptr> p; deserialize(res, is, p); - return upcast_dataset_ptr(std::move(p)).swap(out); + out = std::move(p); + return; } default: break; }