From 71a19a2495b74ef726d7d95a0c953e16b3c86d67 Mon Sep 17 00:00:00 2001 From: Mark Harris <783069+harrism@users.noreply.github.com> Date: Wed, 24 Apr 2024 11:54:01 +1000 Subject: [PATCH] Convert device_memory_resource* to device_async_resource_ref (#2269) Closes #2261 For reviewers: Many of changes are simple textual replace of `rmm::mr::device_memory_resource *` with `rmm::device_async_resource_ref`. However there are several places where RAFT used a default value of `nullptr` for `device_memory_resource*` parameters. This is incompatible with a `resource_ref`, which is a lightweight non-owning reference class, not a pointer. In most places, I was able to either remove the default parameter value, or use `rmm::mr::get_current_device_resource()`. In the case of ivf_pq, I removed the deprecated versions of `search` that took an `mr` parameter. I removed the unused old src/util/memory_pool.cpp and its headers. Authors: - Mark Harris (https://github.com/harrism) Approvers: - Artem M. Chirkin (https://github.com/achirkin) - Dante Gama Dessavre (https://github.com/dantegd) - Corey J. Nolet (https://github.com/cjnolet) URL: https://github.com/rapidsai/raft/pull/2269 --- cpp/CMakeLists.txt | 1 - cpp/bench/ann/src/raft/raft_ann_bench_utils.h | 5 +- cpp/bench/ann/src/raft/raft_cagra_hnswlib.cu | 12 +-- cpp/bench/ann/src/raft/raft_cagra_wrapper.h | 4 +- .../ann/src/raft/raft_ivf_flat_wrapper.h | 11 ++- cpp/bench/ann/src/raft/raft_ivf_pq_wrapper.h | 3 - cpp/bench/prims/common/benchmark.hpp | 1 + cpp/bench/prims/matrix/gather.cu | 1 + cpp/bench/prims/neighbors/knn.cuh | 15 +++- cpp/bench/prims/random/subsample.cu | 1 + .../raft/cluster/detail/kmeans_balanced.cuh | 44 +++++----- cpp/include/raft/cluster/kmeans_balanced.cuh | 3 +- .../raft/core/device_container_policy.hpp | 19 ++--- cpp/include/raft/core/device_mdarray.hpp | 4 +- cpp/include/raft/core/device_resources.hpp | 1 + .../raft/distance/detail/masked_nn.cuh | 3 +- .../raft/matrix/detail/select_k-ext.cuh | 3 - .../raft/matrix/detail/select_radix.cuh | 10 +-- .../raft/matrix/detail/select_warpsort.cuh | 5 +- .../neighbors/detail/cagra/cagra_build.cuh | 4 +- .../raft/neighbors/detail/cagra/utils.hpp | 5 +- .../neighbors/detail/ivf_flat_search-ext.cuh | 8 +- .../neighbors/detail/ivf_flat_search-inl.cuh | 9 +- .../raft/neighbors/detail/ivf_pq_build.cuh | 13 +-- .../raft/neighbors/detail/ivf_pq_search.cuh | 5 +- .../raft/neighbors/detail/knn_brute_force.cuh | 1 - cpp/include/raft/neighbors/ivf_flat-ext.cuh | 10 +-- cpp/include/raft/neighbors/ivf_flat-inl.cuh | 8 +- cpp/include/raft/neighbors/ivf_pq-ext.cuh | 32 +------ cpp/include/raft/neighbors/ivf_pq-inl.cuh | 72 +--------------- .../random/detail/multi_variable_gaussian.cuh | 18 ++-- .../raft/random/multi_variable_gaussian.cuh | 12 ++- .../sparse/matrix/detail/select_k-ext.cuh | 3 - .../raft/spatial/knn/detail/ann_quantized.cuh | 11 ++- .../raft/spatial/knn/detail/ann_utils.cuh | 5 +- cpp/include/raft/util/cudart_utils.hpp | 1 - cpp/include/raft/util/memory_pool-ext.hpp | 28 ------ cpp/include/raft/util/memory_pool-inl.hpp | 85 ------------------- cpp/include/raft/util/memory_pool.hpp | 23 ----- .../neighbors/ivf_pq_search_test-ext.cuh | 5 +- .../raft_internal/neighbors/naive_knn.cuh | 2 - cpp/src/neighbors/detail/ivf_flat_search.cu | 6 +- cpp/src/neighbors/ivf_flat_00_generate.py | 14 +-- .../ivf_flat_search_float_int64_t.cu | 6 +- .../ivf_flat_search_int8_t_int64_t.cu | 6 +- .../ivf_flat_search_uint8_t_int64_t.cu | 6 +- .../neighbors/ivfpq_search_float_int64_t.cu | 7 +- .../neighbors/ivfpq_search_half_int64_t.cu | 5 +- .../neighbors/ivfpq_search_int8_t_int64_t.cu | 7 +- .../neighbors/ivfpq_search_uint8_t_int64_t.cu | 7 +- cpp/test/CMakeLists.txt | 1 - cpp/test/core/device_resources_manager.cpp | 8 +- cpp/test/ext_headers/00_generate.py | 5 +- .../ext_headers/raft_util_memory_pool.cpp | 27 ------ cpp/test/matrix/select_k.cuh | 1 - .../ivf_pq_search_float_uint32_t.cu | 3 +- cpp/test/neighbors/ann_utils.cuh | 4 - cpp/test/random/multi_variable_gaussian.cu | 5 +- cpp/test/util/device_atomics.cu | 1 + 59 files changed, 187 insertions(+), 438 deletions(-) delete mode 100644 cpp/include/raft/util/memory_pool-ext.hpp delete mode 100644 cpp/include/raft/util/memory_pool-inl.hpp delete mode 100644 cpp/include/raft/util/memory_pool.hpp delete mode 100644 cpp/test/ext_headers/raft_util_memory_pool.cpp diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 25475fc6f2..eaab637338 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -565,7 +565,6 @@ if(RAFT_COMPILE_LIBRARY) src/spatial/knn/detail/fused_l2_knn_int32_t_float.cu src/spatial/knn/detail/fused_l2_knn_int64_t_float.cu src/spatial/knn/detail/fused_l2_knn_uint32_t_float.cu - src/util/memory_pool.cpp ) set_target_properties( raft_objs diff --git a/cpp/bench/ann/src/raft/raft_ann_bench_utils.h b/cpp/bench/ann/src/raft/raft_ann_bench_utils.h index 40c1ecfa5e..72a2c0bb05 100644 --- a/cpp/bench/ann/src/raft/raft_ann_bench_utils.h +++ b/cpp/bench/ann/src/raft/raft_ann_bench_utils.h @@ -26,6 +26,7 @@ #include #include +#include #include #include @@ -130,8 +131,8 @@ class configured_raft_resources { { } - configured_raft_resources(configured_raft_resources&&) = default; - configured_raft_resources& operator=(configured_raft_resources&&) = default; + configured_raft_resources(configured_raft_resources&&) = delete; + configured_raft_resources& operator=(configured_raft_resources&&) = delete; ~configured_raft_resources() = default; configured_raft_resources(const configured_raft_resources& res) : configured_raft_resources{res.shared_res_} diff --git a/cpp/bench/ann/src/raft/raft_cagra_hnswlib.cu b/cpp/bench/ann/src/raft/raft_cagra_hnswlib.cu index 709b08db76..d9ef1d74a3 100644 --- a/cpp/bench/ann/src/raft/raft_cagra_hnswlib.cu +++ b/cpp/bench/ann/src/raft/raft_cagra_hnswlib.cu @@ -20,6 +20,7 @@ #include #include +#include #define JSON_DIAGNOSTICS 1 #include @@ -89,10 +90,11 @@ int main(int argc, char** argv) // and is initially sized to half of free device memory. rmm::mr::pool_memory_resource pool_mr{ &cuda_mr, rmm::percent_of_free_device_memory(50)}; - rmm::mr::set_current_device_resource( - &pool_mr); // Updates the current device resource pointer to `pool_mr` - rmm::mr::device_memory_resource* mr = - rmm::mr::get_current_device_resource(); // Points to `pool_mr` - return raft::bench::ann::run_main(argc, argv); + // Updates the current device resource pointer to `pool_mr` + auto old_mr = rmm::mr::set_current_device_resource(&pool_mr); + auto ret = raft::bench::ann::run_main(argc, argv); + // Restores the current device resource pointer to its previous value + rmm::mr::set_current_device_resource(old_mr); + return ret; } #endif diff --git a/cpp/bench/ann/src/raft/raft_cagra_wrapper.h b/cpp/bench/ann/src/raft/raft_cagra_wrapper.h index 70fd22001e..46da8c52e6 100644 --- a/cpp/bench/ann/src/raft/raft_cagra_wrapper.h +++ b/cpp/bench/ann/src/raft/raft_cagra_wrapper.h @@ -36,7 +36,7 @@ #include #include -#include +#include #include #include @@ -138,7 +138,7 @@ class RaftCagra : public ANN, public AnnGPU { std::shared_ptr> dataset_; std::shared_ptr> input_dataset_v_; - inline rmm::mr::device_memory_resource* get_mr(AllocatorType mem_type) + inline rmm::device_async_resource_ref get_mr(AllocatorType mem_type) { switch (mem_type) { case (AllocatorType::HostPinned): return &mr_pinned_; diff --git a/cpp/bench/ann/src/raft/raft_ivf_flat_wrapper.h b/cpp/bench/ann/src/raft/raft_ivf_flat_wrapper.h index 7f2996d77a..48d2b9de80 100644 --- a/cpp/bench/ann/src/raft/raft_ivf_flat_wrapper.h +++ b/cpp/bench/ann/src/raft/raft_ivf_flat_wrapper.h @@ -134,7 +134,14 @@ void RaftIvfFlatGpu::search( const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const { static_assert(sizeof(size_t) == sizeof(IdxT), "IdxT is incompatible with size_t"); - raft::neighbors::ivf_flat::search( - handle_, search_params_, *index_, queries, batch_size, k, (IdxT*)neighbors, distances); + raft::neighbors::ivf_flat::search(handle_, + search_params_, + *index_, + queries, + batch_size, + k, + (IdxT*)neighbors, + distances, + resource::get_workspace_resource(handle_)); } } // namespace raft::bench::ann diff --git a/cpp/bench/ann/src/raft/raft_ivf_pq_wrapper.h b/cpp/bench/ann/src/raft/raft_ivf_pq_wrapper.h index 5d8b682264..1d73bd2e51 100644 --- a/cpp/bench/ann/src/raft/raft_ivf_pq_wrapper.h +++ b/cpp/bench/ann/src/raft/raft_ivf_pq_wrapper.h @@ -32,9 +32,6 @@ #include #include -#include -#include - #include namespace raft::bench::ann { diff --git a/cpp/bench/prims/common/benchmark.hpp b/cpp/bench/prims/common/benchmark.hpp index 4ecad6df3d..3ce43cc1e7 100644 --- a/cpp/bench/prims/common/benchmark.hpp +++ b/cpp/bench/prims/common/benchmark.hpp @@ -28,6 +28,7 @@ #include #include #include +#include #include #include diff --git a/cpp/bench/prims/matrix/gather.cu b/cpp/bench/prims/matrix/gather.cu index 078f9e6198..876e47525c 100644 --- a/cpp/bench/prims/matrix/gather.cu +++ b/cpp/bench/prims/matrix/gather.cu @@ -24,6 +24,7 @@ #include #include +#include #include namespace raft::bench::matrix { diff --git a/cpp/bench/prims/neighbors/knn.cuh b/cpp/bench/prims/neighbors/knn.cuh index aea7168142..6499078623 100644 --- a/cpp/bench/prims/neighbors/knn.cuh +++ b/cpp/bench/prims/neighbors/knn.cuh @@ -27,10 +27,12 @@ #include #include +#include #include #include #include #include +#include #include @@ -101,7 +103,7 @@ struct device_resource { if (managed_) { delete res_; } } - [[nodiscard]] auto get() const -> rmm::mr::device_memory_resource* { return res_; } + [[nodiscard]] auto get() const -> rmm::device_async_resource_ref { return res_; } private: const bool managed_; @@ -158,8 +160,15 @@ struct ivf_flat_knn { IdxT* out_idxs) { search_params.n_probes = 20; - raft::neighbors::ivf_flat::search( - handle, search_params, *index, search_items, ps.n_queries, ps.k, out_idxs, out_dists); + raft::neighbors::ivf_flat::search(handle, + search_params, + *index, + search_items, + ps.n_queries, + ps.k, + out_idxs, + out_dists, + resource::get_workspace_resource(handle)); } }; diff --git a/cpp/bench/prims/random/subsample.cu b/cpp/bench/prims/random/subsample.cu index 4c8ca2bf31..70a9c65e0d 100644 --- a/cpp/bench/prims/random/subsample.cu +++ b/cpp/bench/prims/random/subsample.cu @@ -27,6 +27,7 @@ #include #include +#include #include #include diff --git a/cpp/include/raft/cluster/detail/kmeans_balanced.cuh b/cpp/include/raft/cluster/detail/kmeans_balanced.cuh index 6d3f430e88..0a5a3ba5aa 100644 --- a/cpp/include/raft/cluster/detail/kmeans_balanced.cuh +++ b/cpp/include/raft/cluster/detail/kmeans_balanced.cuh @@ -43,15 +43,14 @@ #include #include -#include -#include #include -#include +#include #include #include #include +#include #include #include @@ -91,7 +90,7 @@ inline std::enable_if_t> predict_core( const MathT* dataset_norm, IdxT n_rows, LabelT* labels, - rmm::mr::device_memory_resource* mr) + rmm::device_async_resource_ref mr) { auto stream = resource::get_cuda_stream(handle); switch (params.metric) { @@ -263,10 +262,9 @@ void calc_centers_and_sizes(const raft::resources& handle, const LabelT* labels, bool reset_counters, MappingOpT mapping_op, - rmm::mr::device_memory_resource* mr = nullptr) + rmm::device_async_resource_ref mr) { auto stream = resource::get_cuda_stream(handle); - if (mr == nullptr) { mr = resource::get_workspace_resource(handle); } if (!reset_counters) { raft::linalg::matrixVectorOp( @@ -322,12 +320,12 @@ void compute_norm(const raft::resources& handle, IdxT dim, IdxT n_rows, MappingOpT mapping_op, - rmm::mr::device_memory_resource* mr = nullptr) + std::optional mr = std::nullopt) { common::nvtx::range fun_scope("compute_norm"); auto stream = resource::get_cuda_stream(handle); - if (mr == nullptr) { mr = resource::get_workspace_resource(handle); } - rmm::device_uvector mapped_dataset(0, stream, mr); + rmm::device_uvector mapped_dataset( + 0, stream, mr.value_or(resource::get_workspace_resource(handle))); const MathT* dataset_ptr = nullptr; @@ -338,7 +336,7 @@ void compute_norm(const raft::resources& handle, linalg::unaryOp(mapped_dataset.data(), dataset, n_rows * dim, mapping_op, stream); - dataset_ptr = (const MathT*)mapped_dataset.data(); + dataset_ptr = static_cast(mapped_dataset.data()); } raft::linalg::rowNorm( @@ -376,22 +374,22 @@ void predict(const raft::resources& handle, IdxT n_rows, LabelT* labels, MappingOpT mapping_op, - rmm::mr::device_memory_resource* mr = nullptr, - const MathT* dataset_norm = nullptr) + std::optional mr = std::nullopt, + const MathT* dataset_norm = nullptr) { auto stream = resource::get_cuda_stream(handle); common::nvtx::range fun_scope( "predict(%zu, %u)", static_cast(n_rows), n_clusters); - if (mr == nullptr) { mr = resource::get_workspace_resource(handle); } + auto mem_res = mr.value_or(resource::get_workspace_resource(handle)); auto [max_minibatch_size, _mem_per_row] = calc_minibatch_size(n_clusters, n_rows, dim, params.metric, std::is_same_v); rmm::device_uvector cur_dataset( - std::is_same_v ? 0 : max_minibatch_size * dim, stream, mr); + std::is_same_v ? 0 : max_minibatch_size * dim, stream, mem_res); bool need_compute_norm = dataset_norm == nullptr && (params.metric == raft::distance::DistanceType::L2Expanded || params.metric == raft::distance::DistanceType::L2SqrtExpanded); rmm::device_uvector cur_dataset_norm( - need_compute_norm ? max_minibatch_size : 0, stream, mr); + need_compute_norm ? max_minibatch_size : 0, stream, mem_res); const MathT* dataset_norm_ptr = nullptr; auto cur_dataset_ptr = cur_dataset.data(); for (IdxT offset = 0; offset < n_rows; offset += max_minibatch_size) { @@ -407,7 +405,7 @@ void predict(const raft::resources& handle, // Compute the norm now if it hasn't been pre-computed. if (need_compute_norm) { compute_norm( - handle, cur_dataset_norm.data(), cur_dataset_ptr, dim, minibatch_size, mapping_op, mr); + handle, cur_dataset_norm.data(), cur_dataset_ptr, dim, minibatch_size, mapping_op, mem_res); dataset_norm_ptr = cur_dataset_norm.data(); } else if (dataset_norm != nullptr) { dataset_norm_ptr = dataset_norm + offset; @@ -422,7 +420,7 @@ void predict(const raft::resources& handle, dataset_norm_ptr, minibatch_size, labels + offset, - mr); + mem_res); } } @@ -530,7 +528,7 @@ auto adjust_centers(MathT* centers, MathT threshold, MappingOpT mapping_op, rmm::cuda_stream_view stream, - rmm::mr::device_memory_resource* device_memory) -> bool + rmm::device_async_resource_ref device_memory) -> bool { common::nvtx::range fun_scope( "adjust_centers(%zu, %u)", static_cast(n_rows), n_clusters); @@ -628,7 +626,7 @@ void balancing_em_iters(const raft::resources& handle, uint32_t balancing_pullback, MathT balancing_threshold, MappingOpT mapping_op, - rmm::mr::device_memory_resource* device_memory) + rmm::device_async_resource_ref device_memory) { auto stream = resource::get_cuda_stream(handle); uint32_t balancing_counter = balancing_pullback; @@ -711,7 +709,7 @@ void build_clusters(const raft::resources& handle, LabelT* cluster_labels, CounterT* cluster_sizes, MappingOpT mapping_op, - rmm::mr::device_memory_resource* device_memory, + rmm::device_async_resource_ref device_memory, const MathT* dataset_norm = nullptr) { auto stream = resource::get_cuda_stream(handle); @@ -853,8 +851,8 @@ auto build_fine_clusters(const raft::resources& handle, IdxT fine_clusters_nums_max, MathT* cluster_centers, MappingOpT mapping_op, - rmm::mr::device_memory_resource* managed_memory, - rmm::mr::device_memory_resource* device_memory) -> IdxT + rmm::device_async_resource_ref managed_memory, + rmm::device_async_resource_ref device_memory) -> IdxT { auto stream = resource::get_cuda_stream(handle); rmm::device_uvector mc_trainset_ids_buf(mesocluster_size_max, stream, managed_memory); @@ -971,7 +969,7 @@ void build_hierarchical(const raft::resources& handle, // TODO: Remove the explicit managed memory- we shouldn't be creating this on the user's behalf. rmm::mr::managed_memory_resource managed_memory; - rmm::mr::device_memory_resource* device_memory = resource::get_workspace_resource(handle); + rmm::device_async_resource_ref device_memory = resource::get_workspace_resource(handle); auto [max_minibatch_size, mem_per_row] = calc_minibatch_size(n_clusters, n_rows, dim, params.metric, std::is_same_v); diff --git a/cpp/include/raft/cluster/kmeans_balanced.cuh b/cpp/include/raft/cluster/kmeans_balanced.cuh index 8cd7730814..a1a182608b 100644 --- a/cpp/include/raft/cluster/kmeans_balanced.cuh +++ b/cpp/include/raft/cluster/kmeans_balanced.cuh @@ -358,7 +358,8 @@ void calc_centers_and_sizes(const raft::resources& handle, X.extent(0), labels.data_handle(), reset_counters, - mapping_op); + mapping_op, + resource::get_workspace_resource(handle)); } } // namespace helpers diff --git a/cpp/include/raft/core/device_container_policy.hpp b/cpp/include/raft/core/device_container_policy.hpp index 8c6eff582b..18d8b77364 100644 --- a/cpp/include/raft/core/device_container_policy.hpp +++ b/cpp/include/raft/core/device_container_policy.hpp @@ -31,7 +31,8 @@ #include #include -#include +#include +#include #include @@ -117,7 +118,7 @@ class device_uvector { */ explicit device_uvector(std::size_t size, rmm::cuda_stream_view stream, - rmm::mr::device_memory_resource* mr) + rmm::device_async_resource_ref mr) : data_{size, stream, mr} { } @@ -164,19 +165,11 @@ class device_uvector_policy { public: auto create(raft::resources const& res, size_t n) -> container_type { - if (mr_ == nullptr) { - // NB: not using the workspace resource by default! - // The workspace resource is for short-lived temporary allocations. - return container_type(n, resource::get_cuda_stream(res)); - } else { - return container_type(n, resource::get_cuda_stream(res), mr_); - } + return container_type(n, resource::get_cuda_stream(res), mr_); } constexpr device_uvector_policy() = default; - constexpr explicit device_uvector_policy(rmm::mr::device_memory_resource* mr) noexcept : mr_(mr) - { - } + explicit device_uvector_policy(rmm::device_async_resource_ref mr) noexcept : mr_(mr) {} [[nodiscard]] constexpr auto access(container_type& c, size_t n) const noexcept -> reference { @@ -192,7 +185,7 @@ class device_uvector_policy { [[nodiscard]] auto make_accessor_policy() const noexcept { return const_accessor_policy{}; } private: - rmm::mr::device_memory_resource* mr_{nullptr}; + rmm::device_async_resource_ref mr_{rmm::mr::get_current_device_resource()}; }; } // namespace raft diff --git a/cpp/include/raft/core/device_mdarray.hpp b/cpp/include/raft/core/device_mdarray.hpp index 855642cd76..a34f6e2e02 100644 --- a/cpp/include/raft/core/device_mdarray.hpp +++ b/cpp/include/raft/core/device_mdarray.hpp @@ -21,6 +21,8 @@ #include #include +#include + #include namespace raft { @@ -107,7 +109,7 @@ template auto make_device_mdarray(raft::resources const& handle, - rmm::mr::device_memory_resource* mr, + rmm::device_async_resource_ref mr, extents exts) { using mdarray_t = device_mdarray; diff --git a/cpp/include/raft/core/device_resources.hpp b/cpp/include/raft/core/device_resources.hpp index 366e387fdd..496c65d91f 100644 --- a/cpp/include/raft/core/device_resources.hpp +++ b/cpp/include/raft/core/device_resources.hpp @@ -37,6 +37,7 @@ #include #include +#include #include diff --git a/cpp/include/raft/distance/detail/masked_nn.cuh b/cpp/include/raft/distance/detail/masked_nn.cuh index 3e3699766f..951e030cbd 100644 --- a/cpp/include/raft/distance/detail/masked_nn.cuh +++ b/cpp/include/raft/distance/detail/masked_nn.cuh @@ -256,9 +256,8 @@ void masked_l2_nn_impl(raft::resources const& handle, static_assert(P::Mblk == 64, "masked_l2_nn_impl only supports a policy with 64 rows per block."); // Get stream and workspace memory resource - rmm::mr::device_memory_resource* ws_mr = - dynamic_cast(resource::get_workspace_resource(handle)); auto stream = resource::get_cuda_stream(handle); + auto ws_mr = resource::get_workspace_resource(handle); // Acquire temporary buffers and initialize to zero: // 1) Adjacency matrix bitfield diff --git a/cpp/include/raft/matrix/detail/select_k-ext.cuh b/cpp/include/raft/matrix/detail/select_k-ext.cuh index 506cbffcb9..6db1a5acac 100644 --- a/cpp/include/raft/matrix/detail/select_k-ext.cuh +++ b/cpp/include/raft/matrix/detail/select_k-ext.cuh @@ -20,9 +20,6 @@ #include #include // RAFT_EXPLICIT -#include // rmm:cuda_stream_view -#include // rmm::mr::device_memory_resource - #include // __half #include // uint32_t diff --git a/cpp/include/raft/matrix/detail/select_radix.cuh b/cpp/include/raft/matrix/detail/select_radix.cuh index 83d4845c31..9480c8e202 100644 --- a/cpp/include/raft/matrix/detail/select_radix.cuh +++ b/cpp/include/raft/matrix/detail/select_radix.cuh @@ -29,9 +29,9 @@ #include #include +#include #include -#include -#include +#include #include #include @@ -894,14 +894,12 @@ void radix_topk(const T* in, unsigned grid_dim, int sm_cnt, rmm::cuda_stream_view stream, - rmm::mr::device_memory_resource* mr) + rmm::device_async_resource_ref mr = rmm::mr::get_current_device_resource()) { // TODO: is it possible to relax this restriction? static_assert(calc_num_passes() > 1); constexpr int num_buckets = calc_num_buckets(); - if (mr == nullptr) { mr = rmm::mr::get_current_device_resource(); } - auto kernel = radix_kernel; const size_t max_chunk_size = calc_chunk_size(batch_size, len, sm_cnt, kernel, false); @@ -1179,7 +1177,7 @@ void radix_topk_one_block(const T* in, const IdxT* len_i, int sm_cnt, rmm::cuda_stream_view stream, - rmm::mr::device_memory_resource* mr) + rmm::device_async_resource_ref mr) { static_assert(calc_num_passes() > 1); diff --git a/cpp/include/raft/matrix/detail/select_warpsort.cuh b/cpp/include/raft/matrix/detail/select_warpsort.cuh index 2cb32585d5..7da659291c 100644 --- a/cpp/include/raft/matrix/detail/select_warpsort.cuh +++ b/cpp/include/raft/matrix/detail/select_warpsort.cuh @@ -27,8 +27,9 @@ #include #include +#include #include -#include +#include #include #include @@ -1032,7 +1033,7 @@ void select_k_(int num_of_block, IdxT* out_idx, bool select_min, rmm::cuda_stream_view stream, - rmm::mr::device_memory_resource* mr) + rmm::device_async_resource_ref mr) { rmm::device_uvector tmp_val(num_of_block * k * batch_size, stream, mr); rmm::device_uvector tmp_idx(num_of_block * k * batch_size, stream, mr); diff --git a/cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh b/cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh index d91e45257e..d63f865c39 100644 --- a/cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh @@ -34,6 +34,8 @@ #include #include +#include + #include #include #include @@ -124,7 +126,7 @@ void build_knn_graph(raft::resources const& res, bool first = true; const auto start_clock = std::chrono::system_clock::now(); - rmm::mr::device_memory_resource* device_memory = raft::resource::get_workspace_resource(res); + rmm::device_async_resource_ref device_memory = raft::resource::get_workspace_resource(res); raft::spatial::knn::detail::utils::batch_load_iterator vec_batches( dataset.data_handle(), diff --git a/cpp/include/raft/neighbors/detail/cagra/utils.hpp b/cpp/include/raft/neighbors/detail/cagra/utils.hpp index 265cbfdceb..ece95a7cb7 100644 --- a/cpp/include/raft/neighbors/detail/cagra/utils.hpp +++ b/cpp/include/raft/neighbors/detail/cagra/utils.hpp @@ -20,7 +20,7 @@ #include #include -#include +#include #include #include @@ -261,9 +261,8 @@ template void copy_with_padding(raft::resources const& res, raft::device_matrix& dst, mdspan, row_major, data_accessor> src, - rmm::mr::device_memory_resource* mr = nullptr) + rmm::device_async_resource_ref mr = rmm::mr::get_current_device_resource()) { - if (!mr) { mr = rmm::mr::get_current_device_resource(); } size_t padded_dim = round_up_safe(src.extent(1) * sizeof(T), 16) / sizeof(T); if ((dst.extent(0) != src.extent(0)) || (static_cast(dst.extent(1)) != padded_dim)) { diff --git a/cpp/include/raft/neighbors/detail/ivf_flat_search-ext.cuh b/cpp/include/raft/neighbors/detail/ivf_flat_search-ext.cuh index 350b82ede7..c14b0e810f 100644 --- a/cpp/include/raft/neighbors/detail/ivf_flat_search-ext.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_flat_search-ext.cuh @@ -20,6 +20,8 @@ #include // none_ivf_sample_filter #include // RAFT_EXPLICIT +#include + #include #include // uintX_t @@ -37,8 +39,8 @@ void search(raft::resources const& handle, uint32_t k, IdxT* neighbors, float* distances, - rmm::mr::device_memory_resource* mr = nullptr, - IvfSampleFilterT sample_filter = IvfSampleFilterT()) RAFT_EXPLICIT; + rmm::device_async_resource_ref mr, + IvfSampleFilterT sample_filter = IvfSampleFilterT()) RAFT_EXPLICIT; } // namespace raft::neighbors::ivf_flat::detail @@ -54,7 +56,7 @@ void search(raft::resources const& handle, uint32_t k, \ IdxT* neighbors, \ float* distances, \ - rmm::mr::device_memory_resource* mr, \ + rmm::device_async_resource_ref mr, \ IvfSampleFilterT sample_filter) instantiate_raft_neighbors_ivf_flat_detail_search( diff --git a/cpp/include/raft/neighbors/detail/ivf_flat_search-inl.cuh b/cpp/include/raft/neighbors/detail/ivf_flat_search-inl.cuh index 441fb76b2f..388dd60f14 100644 --- a/cpp/include/raft/neighbors/detail/ivf_flat_search-inl.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_flat_search-inl.cuh @@ -30,7 +30,7 @@ #include // none_ivf_sample_filter #include // utils::mapping -#include // rmm::device_memory_resource +#include namespace raft::neighbors::ivf_flat::detail { @@ -48,7 +48,7 @@ void search_impl(raft::resources const& handle, bool select_min, IdxT* neighbors, AccT* distances, - rmm::mr::device_memory_resource* search_mr, + rmm::device_async_resource_ref search_mr, IvfSampleFilterT sample_filter) { auto stream = resource::get_cuda_stream(handle); @@ -276,13 +276,12 @@ inline void search(raft::resources const& handle, uint32_t k, IdxT* neighbors, float* distances, - rmm::mr::device_memory_resource* mr = nullptr, - IvfSampleFilterT sample_filter = IvfSampleFilterT()) + rmm::device_async_resource_ref mr = rmm::mr::get_current_device_resource(), + IvfSampleFilterT sample_filter = IvfSampleFilterT()) { common::nvtx::range fun_scope( "ivf_flat::search(k = %u, n_queries = %u, dim = %zu)", k, n_queries, index.dim()); - if (mr == nullptr) { mr = rmm::mr::get_current_device_resource(); } RAFT_EXPECTS(params.n_probes > 0, "n_probes (number of clusters to probe in the search) must be positive."); auto n_probes = std::min(params.n_probes, index.n_lists()); diff --git a/cpp/include/raft/neighbors/detail/ivf_pq_build.cuh b/cpp/include/raft/neighbors/detail/ivf_pq_build.cuh index 8e3f7dbaf3..24574642ef 100644 --- a/cpp/include/raft/neighbors/detail/ivf_pq_build.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_pq_build.cuh @@ -49,6 +49,7 @@ #include #include #include +#include #include #include @@ -171,7 +172,7 @@ void select_residuals(raft::resources const& handle, const float* center, // [dim] const T* dataset, // [.., dim] const IdxT* row_ids, // [n_rows] - rmm::mr::device_memory_resource* device_memory + rmm::device_async_resource_ref device_memory ) { @@ -225,7 +226,7 @@ void flat_compute_residuals( device_matrix_view centers, // [n_lists, dim_ext] const T* dataset, // [n_rows, dim] std::variant labels, // [n_rows] - rmm::mr::device_memory_resource* device_memory) + rmm::device_async_resource_ref device_memory) { auto stream = resource::get_cuda_stream(handle); auto dim = rotation_matrix.extent(1); @@ -397,7 +398,7 @@ void train_per_subset(raft::resources const& handle, const float* trainset, // [n_rows, dim] const uint32_t* labels, // [n_rows] uint32_t kmeans_n_iters, - rmm::mr::device_memory_resource* managed_memory) + rmm::device_async_resource_ref managed_memory) { auto stream = resource::get_cuda_stream(handle); auto device_memory = resource::get_workspace_resource(handle); @@ -475,7 +476,7 @@ void train_per_cluster(raft::resources const& handle, const float* trainset, // [n_rows, dim] const uint32_t* labels, // [n_rows] uint32_t kmeans_n_iters, - rmm::mr::device_memory_resource* managed_memory) + rmm::device_async_resource_ref managed_memory) { auto stream = resource::get_cuda_stream(handle); auto device_memory = resource::get_workspace_resource(handle); @@ -1325,7 +1326,7 @@ void process_and_fill_codes(raft::resources const& handle, std::variant src_offset_or_indices, const uint32_t* new_labels, IdxT n_rows, - rmm::mr::device_memory_resource* mr) + rmm::device_async_resource_ref mr) { auto new_vectors_residual = make_device_mdarray(handle, mr, make_extents(n_rows, index.rot_dim())); @@ -1516,7 +1517,7 @@ void extend(raft::resources const& handle, std::is_same_v, "Unsupported data type"); - rmm::mr::device_memory_resource* device_memory = raft::resource::get_workspace_resource(handle); + rmm::device_async_resource_ref device_memory = raft::resource::get_workspace_resource(handle); // The spec defines how the clusters look like auto spec = list_spec{ diff --git a/cpp/include/raft/neighbors/detail/ivf_pq_search.cuh b/cpp/include/raft/neighbors/detail/ivf_pq_search.cuh index 4c5da38092..87e6d0a774 100644 --- a/cpp/include/raft/neighbors/detail/ivf_pq_search.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_pq_search.cuh @@ -45,8 +45,7 @@ #include #include -#include -#include +#include #include #include @@ -76,7 +75,7 @@ void select_clusters(raft::resources const& handle, raft::distance::DistanceType metric, const T* queries, // [n_queries, dim] const float* cluster_centers, // [n_lists, dim_ext] - rmm::mr::device_memory_resource* mr) + rmm::device_async_resource_ref mr) { common::nvtx::range fun_scope( "ivf_pq::search::select_clusters(n_probes = %u, n_queries = %u, n_lists = %u, dim = %u)", diff --git a/cpp/include/raft/neighbors/detail/knn_brute_force.cuh b/cpp/include/raft/neighbors/detail/knn_brute_force.cuh index adcb639301..daa2798b00 100644 --- a/cpp/include/raft/neighbors/detail/knn_brute_force.cuh +++ b/cpp/include/raft/neighbors/detail/knn_brute_force.cuh @@ -38,7 +38,6 @@ #include #include -#include #include #include diff --git a/cpp/include/raft/neighbors/ivf_flat-ext.cuh b/cpp/include/raft/neighbors/ivf_flat-ext.cuh index a1783dfcfd..12ab0dc3a6 100644 --- a/cpp/include/raft/neighbors/ivf_flat-ext.cuh +++ b/cpp/include/raft/neighbors/ivf_flat-ext.cuh @@ -22,7 +22,7 @@ #include // raft::neighbors::ivf_flat::index #include // RAFT_EXPLICIT -#include // rmm::mr::device_memory_resource +#include #include // int64_t @@ -109,8 +109,8 @@ void search_with_filtering(raft::resources const& handle, uint32_t k, IdxT* neighbors, float* distances, - rmm::mr::device_memory_resource* mr = nullptr, - IvfSampleFilterT sample_filter = IvfSampleFilterT()) RAFT_EXPLICIT; + rmm::device_async_resource_ref mr, + IvfSampleFilterT sample_filter = IvfSampleFilterT()) RAFT_EXPLICIT; template void search(raft::resources const& handle, @@ -121,7 +121,7 @@ void search(raft::resources const& handle, uint32_t k, IdxT* neighbors, float* distances, - rmm::mr::device_memory_resource* mr = nullptr) RAFT_EXPLICIT; + rmm::device_async_resource_ref mr) RAFT_EXPLICIT; template void search_with_filtering(raft::resources const& handle, @@ -240,7 +240,7 @@ instantiate_raft_neighbors_ivf_flat_extend(uint8_t, int64_t); uint32_t k, \ IdxT* neighbors, \ float* distances, \ - rmm::mr::device_memory_resource* mr); \ + rmm::device_async_resource_ref mr); \ \ extern template void raft::neighbors::ivf_flat::search( \ raft::resources const& handle, \ diff --git a/cpp/include/raft/neighbors/ivf_flat-inl.cuh b/cpp/include/raft/neighbors/ivf_flat-inl.cuh index ed1d320795..ea7cff7060 100644 --- a/cpp/include/raft/neighbors/ivf_flat-inl.cuh +++ b/cpp/include/raft/neighbors/ivf_flat-inl.cuh @@ -24,7 +24,7 @@ #include #include -#include +#include namespace raft::neighbors::ivf_flat { @@ -462,8 +462,8 @@ void search_with_filtering(raft::resources const& handle, uint32_t k, IdxT* neighbors, float* distances, - rmm::mr::device_memory_resource* mr = nullptr, - IvfSampleFilterT sample_filter = IvfSampleFilterT()) + rmm::device_async_resource_ref mr, + IvfSampleFilterT sample_filter = IvfSampleFilterT()) { raft::neighbors::ivf_flat::detail::search( handle, params, index, queries, n_queries, k, neighbors, distances, mr, sample_filter); @@ -520,7 +520,7 @@ void search(raft::resources const& handle, uint32_t k, IdxT* neighbors, float* distances, - rmm::mr::device_memory_resource* mr = nullptr) + rmm::device_async_resource_ref mr) { raft::neighbors::ivf_flat::detail::search(handle, params, diff --git a/cpp/include/raft/neighbors/ivf_pq-ext.cuh b/cpp/include/raft/neighbors/ivf_pq-ext.cuh index 160a2753a5..620f4a244f 100644 --- a/cpp/include/raft/neighbors/ivf_pq-ext.cuh +++ b/cpp/include/raft/neighbors/ivf_pq-ext.cuh @@ -21,8 +21,6 @@ #include // raft::neighbors::ivf_pq::index #include // RAFT_EXPLICIT -#include // rmm::mr::device_memory_resource - #include // int64_t #ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY @@ -105,33 +103,6 @@ void search(raft::resources const& handle, IdxT* neighbors, float* distances) RAFT_EXPLICIT; -template -[[deprecated( - "Drop the `mr` argument and use `raft::resource::set_workspace_resource` instead")]] void -search_with_filtering(raft::resources const& handle, - const raft::neighbors::ivf_pq::search_params& params, - const index& idx, - const T* queries, - uint32_t n_queries, - uint32_t k, - IdxT* neighbors, - float* distances, - rmm::mr::device_memory_resource* mr, - IvfSampleFilterT sample_filter = IvfSampleFilterT{}) RAFT_EXPLICIT; - -template -[[deprecated( - "Drop the `mr` argument and use `raft::resource::set_workspace_resource` instead")]] void -search(raft::resources const& handle, - const raft::neighbors::ivf_pq::search_params& params, - const index& idx, - const T* queries, - uint32_t n_queries, - uint32_t k, - IdxT* neighbors, - float* distances, - rmm::mr::device_memory_resource* mr) RAFT_EXPLICIT; - } // namespace raft::neighbors::ivf_pq #endif // RAFT_EXPLICIT_INSTANTIATE_ONLY @@ -209,8 +180,7 @@ instantiate_raft_neighbors_ivf_pq_extend(uint8_t, int64_t); uint32_t n_queries, \ uint32_t k, \ IdxT* neighbors, \ - float* distances, \ - rmm::mr::device_memory_resource* mr); \ + float* distances); \ \ extern template void raft::neighbors::ivf_pq::search( \ raft::resources const& handle, \ diff --git a/cpp/include/raft/neighbors/ivf_pq-inl.cuh b/cpp/include/raft/neighbors/ivf_pq-inl.cuh index a893153e1a..77c4bb8553 100644 --- a/cpp/include/raft/neighbors/ivf_pq-inl.cuh +++ b/cpp/include/raft/neighbors/ivf_pq-inl.cuh @@ -24,8 +24,6 @@ #include #include -#include - #include // shared_ptr namespace raft::neighbors::ivf_pq { @@ -403,38 +401,6 @@ void search_with_filtering(raft::resources const& handle, detail::search(handle, params, idx, queries, n_queries, k, neighbors, distances, sample_filter); } -/** - * This function is deprecated and will be removed in a future. - * Please drop the `mr` argument and use `raft::resource::set_workspace_resource` instead. - */ -template -[[deprecated( - "Drop the `mr` argument and use `raft::resource::set_workspace_resource` instead")]] void -search_with_filtering(raft::resources const& handle, - const search_params& params, - const index& idx, - const T* queries, - uint32_t n_queries, - uint32_t k, - IdxT* neighbors, - float* distances, - rmm::mr::device_memory_resource* mr, - IvfSampleFilterT sample_filter = IvfSampleFilterT{}) -{ - if (mr != nullptr) { - // Shallow copy of the resource with the automatic lifespan: - // change the workspace resource temporarily - raft::resources res_local(handle); - resource::set_workspace_resource( - res_local, std::shared_ptr{mr, void_op{}}); - return search_with_filtering( - res_local, params, idx, queries, n_queries, k, neighbors, distances, sample_filter); - } else { - return search_with_filtering( - handle, params, idx, queries, n_queries, k, neighbors, distances, sample_filter); - } -} - /** * @brief Search ANN using the constructed index. * @@ -446,16 +412,13 @@ search_with_filtering(raft::resources const& handle, * eliminate entirely allocations happening within `search`: * @code{.cpp} * ... - * // Create a pooling memory resource with a pre-defined initial size. - * rmm::mr::pool_memory_resource mr( - * rmm::mr::get_current_device_resource(), 1024 * 1024); * // use default search parameters * ivf_pq::search_params search_params; * // Use the same allocator across multiple searches to reduce the number of * // cuda memory allocations - * ivf_pq::search(handle, search_params, index, queries1, N1, K, out_inds1, out_dists1, &mr); - * ivf_pq::search(handle, search_params, index, queries2, N2, K, out_inds2, out_dists2, &mr); - * ivf_pq::search(handle, search_params, index, queries3, N3, K, out_inds3, out_dists3, &mr); + * ivf_pq::search(handle, search_params, index, queries1, N1, K, out_inds1, out_dists1); + * ivf_pq::search(handle, search_params, index, queries2, N2, K, out_inds2, out_dists2); + * ivf_pq::search(handle, search_params, index, queries3, N3, K, out_inds3, out_dists3); * ... * @endcode * The exact size of the temporary buffer depends on multiple factors and is an implementation @@ -496,33 +459,4 @@ void search(raft::resources const& handle, raft::neighbors::filtering::none_ivf_sample_filter{}); } -/** - * This function is deprecated and will be removed in a future. - * Please drop the `mr` argument and use `raft::resource::set_workspace_resource` instead. - */ -template -[[deprecated( - "Drop the `mr` argument and use `raft::resource::set_workspace_resource` instead")]] void -search(raft::resources const& handle, - const search_params& params, - const index& idx, - const T* queries, - uint32_t n_queries, - uint32_t k, - IdxT* neighbors, - float* distances, - rmm::mr::device_memory_resource* mr) -{ - return search_with_filtering(handle, - params, - idx, - queries, - n_queries, - k, - neighbors, - distances, - mr, - raft::neighbors::filtering::none_ivf_sample_filter{}); -} - } // namespace raft::neighbors::ivf_pq diff --git a/cpp/include/raft/random/detail/multi_variable_gaussian.cuh b/cpp/include/raft/random/detail/multi_variable_gaussian.cuh index e88cbbdeea..c33bb8c348 100644 --- a/cpp/include/raft/random/detail/multi_variable_gaussian.cuh +++ b/cpp/include/raft/random/detail/multi_variable_gaussian.cuh @@ -31,10 +31,10 @@ #include #include - -#include +#include #include +#include #include #include #include @@ -278,7 +278,7 @@ class multi_variable_gaussian_setup_token; template multi_variable_gaussian_setup_token build_multi_variable_gaussian_token_impl( raft::resources const& handle, - rmm::mr::device_memory_resource& mem_resource, + rmm::device_async_resource_ref mem_resource, const int dim, const multi_variable_gaussian_decomposition_method method); @@ -294,7 +294,7 @@ class multi_variable_gaussian_setup_token { template friend multi_variable_gaussian_setup_token build_multi_variable_gaussian_token_impl( raft::resources const& handle, - rmm::mr::device_memory_resource& mem_resource, + rmm::device_async_resource_ref mem_resource, const int dim, const multi_variable_gaussian_decomposition_method method); @@ -321,7 +321,7 @@ class multi_variable_gaussian_setup_token { // Constructor, only for use by friend functions. // Hiding this will let us change the implementation in the future. multi_variable_gaussian_setup_token(raft::resources const& handle, - rmm::mr::device_memory_resource& mem_resource, + rmm::device_async_resource_ref mem_resource, const int dim, const multi_variable_gaussian_decomposition_method method) : impl_(std::make_unique>( @@ -378,14 +378,14 @@ class multi_variable_gaussian_setup_token { private: std::unique_ptr> impl_; raft::resources const& handle_; - rmm::mr::device_memory_resource& mem_resource_; + rmm::device_async_resource_ref mem_resource_; int dim_ = 0; auto allocate_workspace() const { const auto num_elements = impl_->get_workspace_size(); return rmm::device_uvector{ - num_elements, resource::get_cuda_stream(handle_), &mem_resource_}; + num_elements, resource::get_cuda_stream(handle_), mem_resource_}; } int dim() const { return dim_; } @@ -394,7 +394,7 @@ class multi_variable_gaussian_setup_token { template multi_variable_gaussian_setup_token build_multi_variable_gaussian_token_impl( raft::resources const& handle, - rmm::mr::device_memory_resource& mem_resource, + rmm::device_async_resource_ref mem_resource, const int dim, const multi_variable_gaussian_decomposition_method method) { @@ -414,7 +414,7 @@ void compute_multi_variable_gaussian_impl( template void compute_multi_variable_gaussian_impl( raft::resources const& handle, - rmm::mr::device_memory_resource& mem_resource, + rmm::device_async_resource_ref mem_resource, std::optional> x, raft::device_matrix_view P, raft::device_matrix_view X, diff --git a/cpp/include/raft/random/multi_variable_gaussian.cuh b/cpp/include/raft/random/multi_variable_gaussian.cuh index ab3f433422..4b37e1ff65 100644 --- a/cpp/include/raft/random/multi_variable_gaussian.cuh +++ b/cpp/include/raft/random/multi_variable_gaussian.cuh @@ -24,6 +24,8 @@ #include #include +#include + namespace raft::random { /** @@ -33,7 +35,7 @@ namespace raft::random { template void multi_variable_gaussian(raft::resources const& handle, - rmm::mr::device_memory_resource& mem_resource, + rmm::device_async_resource_ref mem_resource, std::optional> x, raft::device_matrix_view P, raft::device_matrix_view X, @@ -49,12 +51,8 @@ void multi_variable_gaussian(raft::resources const& handle, raft::device_matrix_view X, const multi_variable_gaussian_decomposition_method method) { - rmm::mr::device_memory_resource* mem_resource_ptr = rmm::mr::get_current_device_resource(); - RAFT_EXPECTS(mem_resource_ptr != nullptr, - "compute_multi_variable_gaussian: " - "rmm::mr::get_current_device_resource() returned null; " - "please report this bug to the RAPIDS RAFT developers."); - detail::compute_multi_variable_gaussian_impl(handle, *mem_resource_ptr, x, P, X, method); + detail::compute_multi_variable_gaussian_impl( + handle, rmm::mr::get_current_device_resource(), x, P, X, method); } /** @} */ diff --git a/cpp/include/raft/sparse/matrix/detail/select_k-ext.cuh b/cpp/include/raft/sparse/matrix/detail/select_k-ext.cuh index 08bdfa6f30..922356b040 100644 --- a/cpp/include/raft/sparse/matrix/detail/select_k-ext.cuh +++ b/cpp/include/raft/sparse/matrix/detail/select_k-ext.cuh @@ -21,9 +21,6 @@ #include #include // RAFT_EXPLICIT -#include // rmm:cuda_stream_view -#include // rmm::mr::device_memory_resource - #include // __half #include // uint32_t diff --git a/cpp/include/raft/spatial/knn/detail/ann_quantized.cuh b/cpp/include/raft/spatial/knn/detail/ann_quantized.cuh index 041ab225f9..351bcd5531 100644 --- a/cpp/include/raft/spatial/knn/detail/ann_quantized.cuh +++ b/cpp/include/raft/spatial/knn/detail/ann_quantized.cuh @@ -108,8 +108,15 @@ void approx_knn_search(raft::resources const& handle, if (index->ivf_flat()) { ivf_flat::search_params params; params.n_probes = index->nprobe; - ivf_flat::search( - handle, params, *(index->ivf_flat()), query_array, n, k, indices, distances); + ivf_flat::search(handle, + params, + *(index->ivf_flat()), + query_array, + n, + k, + indices, + distances, + resource::get_workspace_resource(handle)); } else if (index->ivf_pq) { neighbors::ivf_pq::search_params params; params.n_probes = index->nprobe; diff --git a/cpp/include/raft/spatial/knn/detail/ann_utils.cuh b/cpp/include/raft/spatial/knn/detail/ann_utils.cuh index d862e586e3..920249172f 100644 --- a/cpp/include/raft/spatial/knn/detail/ann_utils.cuh +++ b/cpp/include/raft/spatial/knn/detail/ann_utils.cuh @@ -25,6 +25,7 @@ #include #include #include +#include #include @@ -416,7 +417,7 @@ struct batch_load_iterator { size_type row_width, size_type batch_size, rmm::cuda_stream_view stream, - rmm::mr::device_memory_resource* mr) + rmm::device_async_resource_ref mr) : stream_(stream), buf_(0, stream, mr), source_(source), @@ -502,7 +503,7 @@ struct batch_load_iterator { size_type row_width, size_type batch_size, rmm::cuda_stream_view stream, - rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()) + rmm::device_async_resource_ref mr = rmm::mr::get_current_device_resource()) : cur_batch_(new batch(source, n_rows, row_width, batch_size, stream, mr)), cur_pos_(0) { } diff --git a/cpp/include/raft/util/cudart_utils.hpp b/cpp/include/raft/util/cudart_utils.hpp index e5ce15e8a3..2b334d1bbf 100644 --- a/cpp/include/raft/util/cudart_utils.hpp +++ b/cpp/include/raft/util/cudart_utils.hpp @@ -18,7 +18,6 @@ #include #include -#include #include diff --git a/cpp/include/raft/util/memory_pool-ext.hpp b/cpp/include/raft/util/memory_pool-ext.hpp deleted file mode 100644 index 030a9c681e..0000000000 --- a/cpp/include/raft/util/memory_pool-ext.hpp +++ /dev/null @@ -1,28 +0,0 @@ -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once -#include // rmm::mr::device_memory_resource - -#include // size_t -#include // std::unique_ptr - -namespace raft { - -std::unique_ptr get_pool_memory_resource( - rmm::mr::device_memory_resource*& mr, size_t initial_size); - -} // namespace raft diff --git a/cpp/include/raft/util/memory_pool-inl.hpp b/cpp/include/raft/util/memory_pool-inl.hpp deleted file mode 100644 index bd7e0186b3..0000000000 --- a/cpp/include/raft/util/memory_pool-inl.hpp +++ /dev/null @@ -1,85 +0,0 @@ -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include // RAFT_INLINE_CONDITIONAL - -#include -#include -#include -#include - -#include -#include - -namespace raft { - -/** - * @defgroup memory_pool Memory Pool - * @{ - */ -/** - * @brief Get a pointer to a pooled memory resource within the scope of the lifetime of the returned - * unique pointer. - * - * This function is useful in the code where multiple repeated allocations/deallocations are - * expected. - * Use case example: - * @code{.cpp} - * void my_func(..., size_t n, rmm::mr::device_memory_resource* mr = nullptr) { - * auto pool_guard = raft::get_pool_memory_resource(mr, 2 * n * sizeof(float)); - * if (pool_guard){ - * RAFT_LOG_INFO("Created a pool"); - * } else { - * RAFT_LOG_INFO("Using the current default or explicitly passed device memory resource"); - * } - * rmm::device_uvector x(n, stream, mr); - * rmm::device_uvector y(n, stream, mr); - * ... - * } - * @endcode - * Here, the new memory resource would be created within the function scope if the passed `mr` is - * null and the default resource is not a pool. After the call, `mr` contains a valid memory - * resource in any case. - * - * @param[inout] mr if not null do nothing; otherwise get the current device resource and wrap it - * into a `pool_memory_resource` if necessary and return the pointer to the result. - * @param initial_size if a new memory pool is created, this would be its initial size (rounded up - * to 256 bytes). - * - * @return if a new memory pool is created, it returns a unique_ptr to it; - * this managed pointer controls the lifetime of the created memory resource. - */ -RAFT_INLINE_CONDITIONAL std::unique_ptr get_pool_memory_resource( - rmm::mr::device_memory_resource*& mr, size_t initial_size) -{ - using pool_res_t = rmm::mr::pool_memory_resource; - std::unique_ptr pool_res{nullptr}; - if (mr) return pool_res; - mr = rmm::mr::get_current_device_resource(); - if (!dynamic_cast(mr) && - !dynamic_cast*>(mr) && - !dynamic_cast*>(mr)) { - pool_res = std::make_unique( - mr, rmm::align_down(initial_size, rmm::CUDA_ALLOCATION_ALIGNMENT)); - mr = pool_res.get(); - } - return pool_res; -} - -/** @} */ -} // namespace raft diff --git a/cpp/include/raft/util/memory_pool.hpp b/cpp/include/raft/util/memory_pool.hpp deleted file mode 100644 index c9d25ecb1f..0000000000 --- a/cpp/include/raft/util/memory_pool.hpp +++ /dev/null @@ -1,23 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include "memory_pool-ext.hpp" - -#if !defined(RAFT_COMPILED) -#include "memory_pool-inl.hpp" -#endif // RAFT_COMPILED diff --git a/cpp/internal/raft_internal/neighbors/ivf_pq_search_test-ext.cuh b/cpp/internal/raft_internal/neighbors/ivf_pq_search_test-ext.cuh index 7a65e2d2f8..1e6f4f9976 100644 --- a/cpp/internal/raft_internal/neighbors/ivf_pq_search_test-ext.cuh +++ b/cpp/internal/raft_internal/neighbors/ivf_pq_search_test-ext.cuh @@ -25,6 +25,8 @@ #include +#include + #include // int64_t #define instantiate_raft_neighbors_ivf_pq_search(T, IdxT) \ @@ -44,8 +46,7 @@ uint32_t n_queries, \ uint32_t k, \ IdxT* neighbors, \ - float* distances, \ - rmm::mr::device_memory_resource* mr); \ + float* distances); \ \ extern template void raft::neighbors::ivf_pq::search( \ raft::resources const& handle, \ diff --git a/cpp/internal/raft_internal/neighbors/naive_knn.cuh b/cpp/internal/raft_internal/neighbors/naive_knn.cuh index 79206c7a43..c14a8e3e9f 100644 --- a/cpp/internal/raft_internal/neighbors/naive_knn.cuh +++ b/cpp/internal/raft_internal/neighbors/naive_knn.cuh @@ -23,9 +23,7 @@ #include #include -#include #include -#include namespace raft::neighbors { diff --git a/cpp/src/neighbors/detail/ivf_flat_search.cu b/cpp/src/neighbors/detail/ivf_flat_search.cu index 9d39607750..336bea19b6 100644 --- a/cpp/src/neighbors/detail/ivf_flat_search.cu +++ b/cpp/src/neighbors/detail/ivf_flat_search.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * Copyright (c) 2022-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,6 +17,8 @@ #include #include +#include + #define instantiate_raft_neighbors_ivf_flat_detail_search(T, IdxT, IvfSampleFilterT) \ template void raft::neighbors::ivf_flat::detail::search( \ raft::resources const& handle, \ @@ -27,7 +29,7 @@ uint32_t k, \ IdxT* neighbors, \ float* distances, \ - rmm::mr::device_memory_resource* mr, \ + rmm::device_async_resource_ref mr, \ IvfSampleFilterT sample_filter) instantiate_raft_neighbors_ivf_flat_detail_search( diff --git a/cpp/src/neighbors/ivf_flat_00_generate.py b/cpp/src/neighbors/ivf_flat_00_generate.py index d987a4e17d..7b55cad4de 100644 --- a/cpp/src/neighbors/ivf_flat_00_generate.py +++ b/cpp/src/neighbors/ivf_flat_00_generate.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. +# Copyright (c) 2023-2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,7 +13,7 @@ # limitations under the License. header = """/* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -127,8 +127,8 @@ search_macro = """ #define instantiate_raft_neighbors_ivf_flat_search(T, IdxT) \\ - template void raft::neighbors::ivf_flat::search( \\ - raft::resources const& handle, \\ + template void raft::neighbors::ivf_flat::search( \\ + raft::resources const& handle, \\ const raft::neighbors::ivf_flat::search_params& params, \\ const raft::neighbors::ivf_flat::index& index, \\ const T* queries, \\ @@ -136,10 +136,10 @@ uint32_t k, \\ IdxT* neighbors, \\ float* distances, \\ - rmm::mr::device_memory_resource* mr ); \\ + rmm::device_async_resource_ref mr); \\ \\ - template void raft::neighbors::ivf_flat::search( \\ - raft::resources const& handle, \\ + template void raft::neighbors::ivf_flat::search( \\ + raft::resources const& handle, \\ const raft::neighbors::ivf_flat::search_params& params, \\ const raft::neighbors::ivf_flat::index& index, \\ raft::device_matrix_view queries, \\ diff --git a/cpp/src/neighbors/ivf_flat_search_float_int64_t.cu b/cpp/src/neighbors/ivf_flat_search_float_int64_t.cu index 03dcfee817..e5cfe14e3f 100644 --- a/cpp/src/neighbors/ivf_flat_search_float_int64_t.cu +++ b/cpp/src/neighbors/ivf_flat_search_float_int64_t.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -25,6 +25,8 @@ #include +#include + #define instantiate_raft_neighbors_ivf_flat_search(T, IdxT) \ template void raft::neighbors::ivf_flat::search( \ raft::resources const& handle, \ @@ -35,7 +37,7 @@ uint32_t k, \ IdxT* neighbors, \ float* distances, \ - rmm::mr::device_memory_resource* mr); \ + rmm::device_async_resource_ref mr); \ \ template void raft::neighbors::ivf_flat::search( \ raft::resources const& handle, \ diff --git a/cpp/src/neighbors/ivf_flat_search_int8_t_int64_t.cu b/cpp/src/neighbors/ivf_flat_search_int8_t_int64_t.cu index 7646081183..35792a78a8 100644 --- a/cpp/src/neighbors/ivf_flat_search_int8_t_int64_t.cu +++ b/cpp/src/neighbors/ivf_flat_search_int8_t_int64_t.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -25,6 +25,8 @@ #include +#include + #define instantiate_raft_neighbors_ivf_flat_search(T, IdxT) \ template void raft::neighbors::ivf_flat::search( \ raft::resources const& handle, \ @@ -35,7 +37,7 @@ uint32_t k, \ IdxT* neighbors, \ float* distances, \ - rmm::mr::device_memory_resource* mr); \ + rmm::device_async_resource_ref mr); \ \ template void raft::neighbors::ivf_flat::search( \ raft::resources const& handle, \ diff --git a/cpp/src/neighbors/ivf_flat_search_uint8_t_int64_t.cu b/cpp/src/neighbors/ivf_flat_search_uint8_t_int64_t.cu index 5d2effd385..663e52cb99 100644 --- a/cpp/src/neighbors/ivf_flat_search_uint8_t_int64_t.cu +++ b/cpp/src/neighbors/ivf_flat_search_uint8_t_int64_t.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -25,6 +25,8 @@ #include +#include + #define instantiate_raft_neighbors_ivf_flat_search(T, IdxT) \ template void raft::neighbors::ivf_flat::search( \ raft::resources const& handle, \ @@ -35,7 +37,7 @@ uint32_t k, \ IdxT* neighbors, \ float* distances, \ - rmm::mr::device_memory_resource* mr); \ + rmm::device_async_resource_ref mr); \ \ template void raft::neighbors::ivf_flat::search( \ raft::resources const& handle, \ diff --git a/cpp/src/neighbors/ivfpq_search_float_int64_t.cu b/cpp/src/neighbors/ivfpq_search_float_int64_t.cu index e56c107735..2d15167099 100644 --- a/cpp/src/neighbors/ivfpq_search_float_int64_t.cu +++ b/cpp/src/neighbors/ivfpq_search_float_int64_t.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * Copyright (c) 2022-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,6 +17,8 @@ #include #include // raft::neighbors::ivf_pq::index +#include + #define instantiate_raft_neighbors_ivf_pq_search(T, IdxT) \ template void raft::neighbors::ivf_pq::search( \ raft::resources const& handle, \ @@ -34,8 +36,7 @@ uint32_t n_queries, \ uint32_t k, \ IdxT* neighbors, \ - float* distances, \ - rmm::mr::device_memory_resource* mr) + float* distances) instantiate_raft_neighbors_ivf_pq_search(float, int64_t); diff --git a/cpp/src/neighbors/ivfpq_search_half_int64_t.cu b/cpp/src/neighbors/ivfpq_search_half_int64_t.cu index c9f2e6fdd5..c9a380e21f 100644 --- a/cpp/src/neighbors/ivfpq_search_half_int64_t.cu +++ b/cpp/src/neighbors/ivfpq_search_half_int64_t.cu @@ -17,6 +17,8 @@ #include #include // raft::neighbors::ivf_pq::index +#include + #include #define instantiate_raft_neighbors_ivf_pq_search(T, IdxT) \ @@ -36,8 +38,7 @@ uint32_t n_queries, \ uint32_t k, \ IdxT* neighbors, \ - float* distances, \ - rmm::mr::device_memory_resource* mr) + float* distances) instantiate_raft_neighbors_ivf_pq_search(half, int64_t); diff --git a/cpp/src/neighbors/ivfpq_search_int8_t_int64_t.cu b/cpp/src/neighbors/ivfpq_search_int8_t_int64_t.cu index 1efe4f7fb2..e85c98d8dd 100644 --- a/cpp/src/neighbors/ivfpq_search_int8_t_int64_t.cu +++ b/cpp/src/neighbors/ivfpq_search_int8_t_int64_t.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * Copyright (c) 2022-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,6 +17,8 @@ #include #include // raft::neighbors::ivf_pq::index +#include + #define instantiate_raft_neighbors_ivf_pq_search(T, IdxT) \ template void raft::neighbors::ivf_pq::search( \ raft::resources const& handle, \ @@ -34,8 +36,7 @@ uint32_t n_queries, \ uint32_t k, \ IdxT* neighbors, \ - float* distances, \ - rmm::mr::device_memory_resource* mr) + float* distances) instantiate_raft_neighbors_ivf_pq_search(int8_t, int64_t); diff --git a/cpp/src/neighbors/ivfpq_search_uint8_t_int64_t.cu b/cpp/src/neighbors/ivfpq_search_uint8_t_int64_t.cu index e746391443..42653254e9 100644 --- a/cpp/src/neighbors/ivfpq_search_uint8_t_int64_t.cu +++ b/cpp/src/neighbors/ivfpq_search_uint8_t_int64_t.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * Copyright (c) 2022-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,6 +17,8 @@ #include #include // raft::neighbors::ivf_pq::index +#include + #define instantiate_raft_neighbors_ivf_pq_search(T, IdxT) \ template void raft::neighbors::ivf_pq::search( \ raft::resources const& handle, \ @@ -34,8 +36,7 @@ uint32_t n_queries, \ uint32_t k, \ IdxT* neighbors, \ - float* distances, \ - rmm::mr::device_memory_resource* mr) + float* distances) instantiate_raft_neighbors_ivf_pq_search(uint8_t, int64_t); diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index 4d17aacffd..752dffdc16 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -188,7 +188,6 @@ if(BUILD_TESTS) test/ext_headers/raft_spatial_knn_detail_fused_l2_knn.cu test/ext_headers/raft_distance_fused_l2_nn.cu test/ext_headers/raft_neighbors_ivf_pq.cu - test/ext_headers/raft_util_memory_pool.cpp test/ext_headers/raft_neighbors_ivf_flat.cu test/ext_headers/raft_core_logger.cpp test/ext_headers/raft_neighbors_refine.cu diff --git a/cpp/test/core/device_resources_manager.cpp b/cpp/test/core/device_resources_manager.cpp index b9b8996a09..c63d5896e5 100644 --- a/cpp/test/core/device_resources_manager.cpp +++ b/cpp/test/core/device_resources_manager.cpp @@ -115,16 +115,10 @@ TEST(DeviceResourcesManager, ObeysSetters) auto* mr = dynamic_cast*>( rmm::mr::get_current_device_resource()); - rmm::device_async_resource_ref workspace_mr = - dynamic_cast*>( - res.get_workspace_resource()) - ->get_upstream_resource(); + if (upstream_mrs[i % devices.size()] != nullptr) { // Expect that the current memory resource is a pool memory resource as requested EXPECT_NE(mr, nullptr); - - // We cannot easily check the type of a resource_ref - (void)workspace_mr; } { diff --git a/cpp/test/ext_headers/00_generate.py b/cpp/test/ext_headers/00_generate.py index 682cadbe89..d9c766979b 100644 --- a/cpp/test/ext_headers/00_generate.py +++ b/cpp/test/ext_headers/00_generate.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. +# Copyright (c) 2023-2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,7 +14,7 @@ copyright_notice = """ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -49,7 +49,6 @@ "raft/spatial/knn/detail/fused_l2_knn-ext.cuh", "raft/distance/fused_l2_nn-ext.cuh", "raft/neighbors/ivf_pq-ext.cuh", - "raft/util/memory_pool-ext.hpp", "raft/neighbors/ivf_flat-ext.cuh", "raft/core/logger-ext.hpp", "raft/neighbors/refine-ext.cuh", diff --git a/cpp/test/ext_headers/raft_util_memory_pool.cpp b/cpp/test/ext_headers/raft_util_memory_pool.cpp deleted file mode 100644 index 11a024b958..0000000000 --- a/cpp/test/ext_headers/raft_util_memory_pool.cpp +++ /dev/null @@ -1,27 +0,0 @@ - -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/* - * NOTE: this file is generated by 00_generate.py - * - * Make changes there and run in this directory: - * - * > python 00_generate.py - * - */ - -#include diff --git a/cpp/test/matrix/select_k.cuh b/cpp/test/matrix/select_k.cuh index 7f9b7b3fc3..f22f4f5fa7 100644 --- a/cpp/test/matrix/select_k.cuh +++ b/cpp/test/matrix/select_k.cuh @@ -25,7 +25,6 @@ #include #include -#include #include diff --git a/cpp/test/neighbors/ann_ivf_pq/ivf_pq_search_float_uint32_t.cu b/cpp/test/neighbors/ann_ivf_pq/ivf_pq_search_float_uint32_t.cu index 942d0fcc44..00baa59f58 100644 --- a/cpp/test/neighbors/ann_ivf_pq/ivf_pq_search_float_uint32_t.cu +++ b/cpp/test/neighbors/ann_ivf_pq/ivf_pq_search_float_uint32_t.cu @@ -37,8 +37,7 @@ uint32_t n_queries, \ uint32_t k, \ IdxT* neighbors, \ - float* distances, \ - rmm::mr::device_memory_resource* mr) + float* distances) instantiate_raft_neighbors_ivf_pq_search(float, uint32_t); diff --git a/cpp/test/neighbors/ann_utils.cuh b/cpp/test/neighbors/ann_utils.cuh index 3e0bead665..2139e97428 100644 --- a/cpp/test/neighbors/ann_utils.cuh +++ b/cpp/test/neighbors/ann_utils.cuh @@ -28,10 +28,6 @@ #include -#include -#include -#include - #include #include diff --git a/cpp/test/random/multi_variable_gaussian.cu b/cpp/test/random/multi_variable_gaussian.cu index 62bad8e543..bed9515a53 100644 --- a/cpp/test/random/multi_variable_gaussian.cu +++ b/cpp/test/random/multi_variable_gaussian.cu @@ -25,6 +25,7 @@ #include #include +#include #include @@ -287,10 +288,8 @@ class MVGMdspanTest : public ::testing::TestWithParam> { raft::device_matrix_view P_view(P_d.data(), dim, dim); raft::device_matrix_view X_view(X_d.data(), dim, nPoints); - rmm::mr::device_memory_resource* mem_resource_ptr = rmm::mr::get_current_device_resource(); - ASSERT_TRUE(mem_resource_ptr != nullptr); raft::random::multi_variable_gaussian( - handle, *mem_resource_ptr, x_view, P_view, X_view, method); + handle, rmm::mr::get_current_device_resource(), x_view, P_view, X_view, method); // saving the mean of the randoms in Rand_mean //@todo can be swapped with a API that calculates mean diff --git a/cpp/test/util/device_atomics.cu b/cpp/test/util/device_atomics.cu index c5bb0ad3b6..086d1f4152 100644 --- a/cpp/test/util/device_atomics.cu +++ b/cpp/test/util/device_atomics.cu @@ -14,6 +14,7 @@ * limitations under the License. */ +#include #include #include