From 3f02ea17b218727a03211dcb1a5005d36b7ffb79 Mon Sep 17 00:00:00 2001 From: Micka Date: Fri, 6 Oct 2023 17:05:49 +0200 Subject: [PATCH] Replace `raft::random` calls to not use deprecated API (#1867) Authors: - Micka (https://github.com/lowener) - Corey J. Nolet (https://github.com/cjnolet) Approvers: - Corey J. Nolet (https://github.com/cjnolet) URL: https://github.com/rapidsai/raft/pull/1867 --- cpp/bench/prims/cluster/kmeans_balanced.cu | 4 +- cpp/bench/prims/distance/kernels.cu | 6 +-- cpp/bench/prims/linalg/norm.cu | 2 +- cpp/bench/prims/linalg/normalize.cu | 2 +- cpp/bench/prims/linalg/reduce_cols_by_key.cu | 2 +- cpp/bench/prims/linalg/reduce_rows_by_key.cu | 2 +- cpp/bench/prims/matrix/argmin.cu | 2 +- cpp/bench/prims/matrix/gather.cu | 4 +- cpp/bench/prims/neighbors/cagra_bench.cuh | 10 ++--- cpp/bench/prims/neighbors/knn.cuh | 4 +- .../raft/neighbors/detail/ivf_pq_build.cuh | 4 +- .../raft_internal/neighbors/refine_helper.cuh | 14 ++++--- cpp/test/distance/gram.cu | 6 +-- cpp/test/linalg/reduce.cu | 2 +- cpp/test/neighbors/ann_cagra.cuh | 42 +++++++++++-------- cpp/test/neighbors/ann_ivf_flat.cuh | 14 ++++--- cpp/test/neighbors/ann_ivf_pq.cuh | 14 ++++--- cpp/test/neighbors/ann_nn_descent.cuh | 7 ++-- cpp/test/random/rmat_rectangular_generator.cu | 4 +- cpp/test/sparse/gram.cu | 12 +++--- cpp/test/test_utils.cuh | 25 +++++------ cpp/test/util/bitonic_sort.cu | 13 +++--- 22 files changed, 108 insertions(+), 87 deletions(-) diff --git a/cpp/bench/prims/cluster/kmeans_balanced.cu b/cpp/bench/prims/cluster/kmeans_balanced.cu index effe2a55a4..129578c303 100644 --- a/cpp/bench/prims/cluster/kmeans_balanced.cu +++ b/cpp/bench/prims/cluster/kmeans_balanced.cu @@ -50,10 +50,10 @@ struct KMeansBalanced : public fixture { constexpr T kRangeMin = std::is_integral_v ? std::numeric_limits::min() : T(-1); if constexpr (std::is_integral_v) { raft::random::uniformInt( - rng, X.data_handle(), params.data.rows * params.data.cols, kRangeMin, kRangeMax, stream); + handle, rng, X.data_handle(), params.data.rows * params.data.cols, kRangeMin, kRangeMax); } else { raft::random::uniform( - rng, X.data_handle(), params.data.rows * params.data.cols, kRangeMin, kRangeMax, stream); + handle, rng, X.data_handle(), params.data.rows * params.data.cols, kRangeMin, kRangeMax); } resource::sync_stream(handle, stream); } diff --git a/cpp/bench/prims/distance/kernels.cu b/cpp/bench/prims/distance/kernels.cu index 7d916e6ce0..3f74759665 100644 --- a/cpp/bench/prims/distance/kernels.cu +++ b/cpp/bench/prims/distance/kernels.cu @@ -46,9 +46,9 @@ struct GramMatrix : public fixture { A.resize(params.m * params.k, stream); B.resize(params.k * params.n, stream); C.resize(params.m * params.n, stream); - raft::random::Rng r(123456ULL); - r.uniform(A.data(), params.m * params.k, T(-1.0), T(1.0), stream); - r.uniform(B.data(), params.k * params.n, T(-1.0), T(1.0), stream); + raft::random::RngState rng(123456ULL); + raft::random::uniform(handle, rng, A.data(), params.m * params.k, T(-1.0), T(1.0)); + raft::random::uniform(handle, rng, B.data(), params.k * params.n, T(-1.0), T(1.0)); } ~GramMatrix() diff --git a/cpp/bench/prims/linalg/norm.cu b/cpp/bench/prims/linalg/norm.cu index f83953f8e4..1db23e4ca4 100644 --- a/cpp/bench/prims/linalg/norm.cu +++ b/cpp/bench/prims/linalg/norm.cu @@ -42,7 +42,7 @@ struct rowNorm : public fixture { rowNorm(const norm_input& p) : params(p), in(p.rows * p.cols, stream), dots(p.rows, stream) { raft::random::RngState rng{1234}; - raft::random::uniform(rng, in.data(), p.rows * p.cols, (T)-10.0, (T)10.0, stream); + raft::random::uniform(handle, rng, in.data(), p.rows * p.cols, (T)-10.0, (T)10.0); } void run_benchmark(::benchmark::State& state) override diff --git a/cpp/bench/prims/linalg/normalize.cu b/cpp/bench/prims/linalg/normalize.cu index ad9052a008..91319e774c 100644 --- a/cpp/bench/prims/linalg/normalize.cu +++ b/cpp/bench/prims/linalg/normalize.cu @@ -41,7 +41,7 @@ struct rowNormalize : public fixture { : params(p), in(p.rows * p.cols, stream), out(p.rows * p.cols, stream) { raft::random::RngState rng{1234}; - raft::random::uniform(rng, in.data(), p.rows * p.cols, (T)-10.0, (T)10.0, stream); + raft::random::uniform(handle, rng, in.data(), p.rows * p.cols, (T)-10.0, (T)10.0); } void run_benchmark(::benchmark::State& state) override diff --git a/cpp/bench/prims/linalg/reduce_cols_by_key.cu b/cpp/bench/prims/linalg/reduce_cols_by_key.cu index ac0c612ee4..1b584e80c8 100644 --- a/cpp/bench/prims/linalg/reduce_cols_by_key.cu +++ b/cpp/bench/prims/linalg/reduce_cols_by_key.cu @@ -42,7 +42,7 @@ struct reduce_cols_by_key : public fixture { : params(p), in(p.rows * p.cols, stream), out(p.rows * p.keys, stream), keys(p.cols, stream) { raft::random::RngState rng{42}; - raft::random::uniformInt(rng, keys.data(), p.cols, (KeyT)0, (KeyT)p.keys, stream); + raft::random::uniformInt(handle, rng, keys.data(), p.cols, (KeyT)0, (KeyT)p.keys); } void run_benchmark(::benchmark::State& state) override diff --git a/cpp/bench/prims/linalg/reduce_rows_by_key.cu b/cpp/bench/prims/linalg/reduce_rows_by_key.cu index aa9c9a1f62..b68cefc274 100644 --- a/cpp/bench/prims/linalg/reduce_rows_by_key.cu +++ b/cpp/bench/prims/linalg/reduce_rows_by_key.cu @@ -37,7 +37,7 @@ struct reduce_rows_by_key : public fixture { workspace(p.rows, stream) { raft::random::RngState rng{42}; - raft::random::uniformInt(rng, keys.data(), p.rows, (KeyT)0, (KeyT)p.keys, stream); + raft::random::uniformInt(handle, rng, keys.data(), p.rows, (KeyT)0, (KeyT)p.keys); } void run_benchmark(::benchmark::State& state) override diff --git a/cpp/bench/prims/matrix/argmin.cu b/cpp/bench/prims/matrix/argmin.cu index a8f667257a..afee81aa00 100644 --- a/cpp/bench/prims/matrix/argmin.cu +++ b/cpp/bench/prims/matrix/argmin.cu @@ -40,7 +40,7 @@ struct Argmin : public fixture { raft::random::RngState rng{1234}; raft::random::uniform( - rng, matrix.data_handle(), params.rows * params.cols, T(-1), T(1), stream); + handle, rng, matrix.data_handle(), params.rows * params.cols, T(-1), T(1)); resource::sync_stream(handle, stream); } diff --git a/cpp/bench/prims/matrix/gather.cu b/cpp/bench/prims/matrix/gather.cu index ca6a2830bd..00a145ffa9 100644 --- a/cpp/bench/prims/matrix/gather.cu +++ b/cpp/bench/prims/matrix/gather.cu @@ -52,11 +52,11 @@ struct Gather : public fixture { raft::random::RngState rng{1234}; raft::random::uniform( - rng, matrix.data_handle(), params.rows * params.cols, T(-1), T(1), stream); + handle, rng, matrix.data_handle(), params.rows * params.cols, T(-1), T(1)); raft::random::uniformInt( handle, rng, map.data_handle(), params.map_length, (MapT)0, (MapT)params.rows); if constexpr (Conditional) { - raft::random::uniform(rng, stencil.data_handle(), params.map_length, T(-1), T(1), stream); + raft::random::uniform(handle, rng, stencil.data_handle(), params.map_length, T(-1), T(1)); } resource::sync_stream(handle, stream); } diff --git a/cpp/bench/prims/neighbors/cagra_bench.cuh b/cpp/bench/prims/neighbors/cagra_bench.cuh index 63f6c14686..07e93a3473 100644 --- a/cpp/bench/prims/neighbors/cagra_bench.cuh +++ b/cpp/bench/prims/neighbors/cagra_bench.cuh @@ -62,20 +62,20 @@ struct CagraBench : public fixture { constexpr T kRangeMin = std::is_integral_v ? std::numeric_limits::min() : T(-1); if constexpr (std::is_integral_v) { raft::random::uniformInt( - state, dataset_.data_handle(), dataset_.size(), kRangeMin, kRangeMax, stream); + handle, state, dataset_.data_handle(), dataset_.size(), kRangeMin, kRangeMax); raft::random::uniformInt( - state, queries_.data_handle(), queries_.size(), kRangeMin, kRangeMax, stream); + handle, state, queries_.data_handle(), queries_.size(), kRangeMin, kRangeMax); } else { raft::random::uniform( - state, dataset_.data_handle(), dataset_.size(), kRangeMin, kRangeMax, stream); + handle, state, dataset_.data_handle(), dataset_.size(), kRangeMin, kRangeMax); raft::random::uniform( - state, queries_.data_handle(), queries_.size(), kRangeMin, kRangeMax, stream); + handle, state, queries_.data_handle(), queries_.size(), kRangeMin, kRangeMax); } // Generate random knn graph raft::random::uniformInt( - state, knn_graph_.data_handle(), knn_graph_.size(), 0, ps.n_samples - 1, stream); + handle, state, knn_graph_.data_handle(), knn_graph_.size(), 0, ps.n_samples - 1); auto metric = raft::distance::DistanceType::L2Expanded; diff --git a/cpp/bench/prims/neighbors/knn.cuh b/cpp/bench/prims/neighbors/knn.cuh index e580b20fdc..31ac869b37 100644 --- a/cpp/bench/prims/neighbors/knn.cuh +++ b/cpp/bench/prims/neighbors/knn.cuh @@ -260,9 +260,9 @@ struct knn : public fixture { constexpr T kRangeMax = std::is_integral_v ? std::numeric_limits::max() : T(1); constexpr T kRangeMin = std::is_integral_v ? std::numeric_limits::min() : T(-1); if constexpr (std::is_integral_v) { - raft::random::uniformInt(state, vec.data(), n, kRangeMin, kRangeMax, stream); + raft::random::uniformInt(handle, state, vec.data(), n, kRangeMin, kRangeMax); } else { - raft::random::uniform(state, vec.data(), n, kRangeMin, kRangeMax, stream); + raft::random::uniform(handle, state, vec.data(), n, kRangeMin, kRangeMax); } } diff --git a/cpp/include/raft/neighbors/detail/ivf_pq_build.cuh b/cpp/include/raft/neighbors/detail/ivf_pq_build.cuh index 47c10de200..975ae9ec00 100644 --- a/cpp/include/raft/neighbors/detail/ivf_pq_build.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_pq_build.cuh @@ -124,7 +124,7 @@ inline void make_rotation_matrix(raft::resources const& handle, uint32_t n_rows, uint32_t n_cols, float* rotation_matrix, - raft::random::Rng rng = raft::random::Rng(7ULL)) + raft::random::RngState rng = raft::random::RngState(7ULL)) { common::nvtx::range fun_scope( "ivf_pq::make_rotation_matrix(%u * %u)", n_rows, n_cols); @@ -134,7 +134,7 @@ inline void make_rotation_matrix(raft::resources const& handle, if (force_random_rotation || !inplace) { rmm::device_uvector buf(inplace ? 0 : n * n, stream); float* mat = inplace ? rotation_matrix : buf.data(); - rng.normal(mat, n * n, 0.0f, 1.0f, stream); + raft::random::normal(handle, rng, mat, n * n, 0.0f, 1.0f); linalg::detail::qrGetQ_inplace(handle, mat, n, n, stream); if (!inplace) { RAFT_CUDA_TRY(cudaMemcpy2DAsync(rotation_matrix, diff --git a/cpp/internal/raft_internal/neighbors/refine_helper.cuh b/cpp/internal/raft_internal/neighbors/refine_helper.cuh index ee06d90851..4a06116877 100644 --- a/cpp/internal/raft_internal/neighbors/refine_helper.cuh +++ b/cpp/internal/raft_internal/neighbors/refine_helper.cuh @@ -61,16 +61,20 @@ class RefineHelper { refined_distances_host(handle), refined_indices_host(handle) { - raft::random::Rng r(1234ULL); + raft::random::RngState rng(1234ULL); dataset = raft::make_device_matrix(handle_, p.n_rows, p.dim); queries = raft::make_device_matrix(handle_, p.n_queries, p.dim); if constexpr (std::is_same{}) { - r.uniform(dataset.data_handle(), dataset.size(), DataT(-10.0), DataT(10.0), stream_); - r.uniform(queries.data_handle(), queries.size(), DataT(-10.0), DataT(10.0), stream_); + raft::random::uniform( + handle, rng, dataset.data_handle(), dataset.size(), DataT(-10.0), DataT(10.0)); + raft::random::uniform( + handle, rng, queries.data_handle(), queries.size(), DataT(-10.0), DataT(10.0)); } else { - r.uniformInt(dataset.data_handle(), dataset.size(), DataT(1), DataT(20), stream_); - r.uniformInt(queries.data_handle(), queries.size(), DataT(1), DataT(20), stream_); + raft::random::uniformInt( + handle, rng, dataset.data_handle(), dataset.size(), DataT(1), DataT(20)); + raft::random::uniformInt( + handle, rng, queries.data_handle(), queries.size(), DataT(1), DataT(20)); } refined_distances = raft::make_device_matrix(handle_, p.n_queries, p.k); diff --git a/cpp/test/distance/gram.cu b/cpp/test/distance/gram.cu index d5fecd93c6..a9dbd8328f 100644 --- a/cpp/test/distance/gram.cu +++ b/cpp/test/distance/gram.cu @@ -99,9 +99,9 @@ class GramMatrixTest : public ::testing::TestWithParam { gram_host.resize(gram.size()); std::fill(gram_host.begin(), gram_host.end(), 0); - raft::random::Rng r(42137ULL); - r.uniform(x1.data(), x1.size(), math_t(0), math_t(1), stream); - r.uniform(x2.data(), x2.size(), math_t(0), math_t(1), stream); + raft::random::RngState rng(42137ULL); + raft::random::uniform(handle, rng, x1.data(), x1.size(), math_t(0), math_t(1)); + raft::random::uniform(handle, rng, x2.data(), x2.size(), math_t(0), math_t(1)); } ~GramMatrixTest() override {} diff --git a/cpp/test/linalg/reduce.cu b/cpp/test/linalg/reduce.cu index fd1b4e7b45..8578fe9637 100644 --- a/cpp/test/linalg/reduce.cu +++ b/cpp/test/linalg/reduce.cu @@ -124,7 +124,7 @@ class ReduceTest : public ::testing::TestWithParam(std::floor((24 - std::log2(dim)) / 2)); - rng.uniformInt(reinterpret_cast(ptr), size, 0u, resolution - 1, cuda_stream); + raft::random::uniformInt(handle, rng, reinterpret_cast(ptr), size, 0u, resolution - 1); GenerateRoundingErrorFreeDataset_kernel<<>>( ptr, size, resolution); @@ -293,13 +294,16 @@ class AnnCagraTest : public ::testing::TestWithParam { { database.resize(((size_t)ps.n_rows) * ps.dim, stream_); search_queries.resize(ps.n_queries * ps.dim, stream_); - raft::random::Rng r(1234ULL); + raft::random::RngState r(1234ULL); if constexpr (std::is_same{}) { - r.normal(database.data(), ps.n_rows * ps.dim, DataT(0.1), DataT(2.0), stream_); - r.normal(search_queries.data(), ps.n_queries * ps.dim, DataT(0.1), DataT(2.0), stream_); + raft::random::normal(handle_, r, database.data(), ps.n_rows * ps.dim, DataT(0.1), DataT(2.0)); + raft::random::normal( + handle_, r, search_queries.data(), ps.n_queries * ps.dim, DataT(0.1), DataT(2.0)); } else { - r.uniformInt(database.data(), ps.n_rows * ps.dim, DataT(1), DataT(20), stream_); - r.uniformInt(search_queries.data(), ps.n_queries * ps.dim, DataT(1), DataT(20), stream_); + raft::random::uniformInt( + handle_, r, database.data(), ps.n_rows * ps.dim, DataT(1), DataT(20)); + raft::random::uniformInt( + handle_, r, search_queries.data(), ps.n_queries * ps.dim, DataT(1), DataT(20)); } resource::sync_stream(handle_); } @@ -379,11 +383,12 @@ class AnnCagraSortTest : public ::testing::TestWithParam { void SetUp() override { database.resize(((size_t)ps.n_rows) * ps.dim, handle_.get_stream()); - raft::random::Rng r(1234ULL); + raft::random::RngState r(1234ULL); if constexpr (std::is_same{}) { - GenerateRoundingErrorFreeDataset(database.data(), ps.n_rows, ps.dim, r, handle_.get_stream()); + GenerateRoundingErrorFreeDataset(handle_, database.data(), ps.n_rows, ps.dim, r); } else { - r.uniformInt(database.data(), ps.n_rows * ps.dim, DataT(1), DataT(20), handle_.get_stream()); + raft::random::uniformInt( + handle_, r, database.data(), ps.n_rows * ps.dim, DataT(1), DataT(20)); } handle_.sync_stream(); } @@ -643,13 +648,16 @@ class AnnCagraFilterTest : public ::testing::TestWithParam { { database.resize(((size_t)ps.n_rows) * ps.dim, stream_); search_queries.resize(ps.n_queries * ps.dim, stream_); - raft::random::Rng r(1234ULL); + raft::random::RngState r(1234ULL); if constexpr (std::is_same{}) { - r.normal(database.data(), ps.n_rows * ps.dim, DataT(0.1), DataT(2.0), stream_); - r.normal(search_queries.data(), ps.n_queries * ps.dim, DataT(0.1), DataT(2.0), stream_); + raft::random::normal(handle_, r, database.data(), ps.n_rows * ps.dim, DataT(0.1), DataT(2.0)); + raft::random::normal( + handle_, r, search_queries.data(), ps.n_queries * ps.dim, DataT(0.1), DataT(2.0)); } else { - r.uniformInt(database.data(), ps.n_rows * ps.dim, DataT(1), DataT(20), stream_); - r.uniformInt(search_queries.data(), ps.n_queries * ps.dim, DataT(1), DataT(20), stream_); + raft::random::uniformInt( + handle_, r, database.data(), ps.n_rows * ps.dim, DataT(1), DataT(20)); + raft::random::uniformInt( + handle_, r, search_queries.data(), ps.n_queries * ps.dim, DataT(1), DataT(20)); } resource::sync_stream(handle_); } diff --git a/cpp/test/neighbors/ann_ivf_flat.cuh b/cpp/test/neighbors/ann_ivf_flat.cuh index 71d48cdeb7..7b1d32ca83 100644 --- a/cpp/test/neighbors/ann_ivf_flat.cuh +++ b/cpp/test/neighbors/ann_ivf_flat.cuh @@ -411,13 +411,17 @@ class AnnIVFFlatTest : public ::testing::TestWithParam> { database.resize(ps.num_db_vecs * ps.dim, stream_); search_queries.resize(ps.num_queries * ps.dim, stream_); - raft::random::Rng r(1234ULL); + raft::random::RngState r(1234ULL); if constexpr (std::is_same{}) { - r.uniform(database.data(), ps.num_db_vecs * ps.dim, DataT(0.1), DataT(2.0), stream_); - r.uniform(search_queries.data(), ps.num_queries * ps.dim, DataT(0.1), DataT(2.0), stream_); + raft::random::uniform( + handle_, r, database.data(), ps.num_db_vecs * ps.dim, DataT(0.1), DataT(2.0)); + raft::random::uniform( + handle_, r, search_queries.data(), ps.num_queries * ps.dim, DataT(0.1), DataT(2.0)); } else { - r.uniformInt(database.data(), ps.num_db_vecs * ps.dim, DataT(1), DataT(20), stream_); - r.uniformInt(search_queries.data(), ps.num_queries * ps.dim, DataT(1), DataT(20), stream_); + raft::random::uniformInt( + handle_, r, database.data(), ps.num_db_vecs * ps.dim, DataT(1), DataT(20)); + raft::random::uniformInt( + handle_, r, search_queries.data(), ps.num_queries * ps.dim, DataT(1), DataT(20)); } resource::sync_stream(handle_); } diff --git a/cpp/test/neighbors/ann_ivf_pq.cuh b/cpp/test/neighbors/ann_ivf_pq.cuh index e03d09ae50..d1f5ee5b03 100644 --- a/cpp/test/neighbors/ann_ivf_pq.cuh +++ b/cpp/test/neighbors/ann_ivf_pq.cuh @@ -170,13 +170,17 @@ class ivf_pq_test : public ::testing::TestWithParam { database.resize(size_t{ps.num_db_vecs} * size_t{ps.dim}, stream_); search_queries.resize(size_t{ps.num_queries} * size_t{ps.dim}, stream_); - raft::random::Rng r(1234ULL); + raft::random::RngState r(1234ULL); if constexpr (std::is_same{}) { - r.uniform(database.data(), ps.num_db_vecs * ps.dim, DataT(0.1), DataT(2.0), stream_); - r.uniform(search_queries.data(), ps.num_queries * ps.dim, DataT(0.1), DataT(2.0), stream_); + raft::random::uniform( + handle_, r, database.data(), ps.num_db_vecs * ps.dim, DataT(0.1), DataT(2.0)); + raft::random::uniform( + handle_, r, search_queries.data(), ps.num_queries * ps.dim, DataT(0.1), DataT(2.0)); } else { - r.uniformInt(database.data(), ps.num_db_vecs * ps.dim, DataT(1), DataT(20), stream_); - r.uniformInt(search_queries.data(), ps.num_queries * ps.dim, DataT(1), DataT(20), stream_); + raft::random::uniformInt( + handle_, r, database.data(), ps.num_db_vecs * ps.dim, DataT(1), DataT(20)); + raft::random::uniformInt( + handle_, r, search_queries.data(), ps.num_queries * ps.dim, DataT(1), DataT(20)); } resource::sync_stream(handle_); } diff --git a/cpp/test/neighbors/ann_nn_descent.cuh b/cpp/test/neighbors/ann_nn_descent.cuh index 948323cf6e..d62b863437 100644 --- a/cpp/test/neighbors/ann_nn_descent.cuh +++ b/cpp/test/neighbors/ann_nn_descent.cuh @@ -121,11 +121,12 @@ class AnnNNDescentTest : public ::testing::TestWithParam { void SetUp() override { database.resize(((size_t)ps.n_rows) * ps.dim, stream_); - raft::random::Rng r(1234ULL); + raft::random::RngState r(1234ULL); if constexpr (std::is_same{}) { - r.normal(database.data(), ps.n_rows * ps.dim, DataT(0.1), DataT(2.0), stream_); + raft::random::normal(handle_, r, database.data(), ps.n_rows * ps.dim, DataT(0.1), DataT(2.0)); } else { - r.uniformInt(database.data(), ps.n_rows * ps.dim, DataT(1), DataT(20), stream_); + raft::random::uniformInt( + handle_, r, database.data(), ps.n_rows * ps.dim, DataT(1), DataT(20)); } resource::sync_stream(handle_); } diff --git a/cpp/test/random/rmat_rectangular_generator.cu b/cpp/test/random/rmat_rectangular_generator.cu index 1af3d2be31..77af44f133 100644 --- a/cpp/test/random/rmat_rectangular_generator.cu +++ b/cpp/test/random/rmat_rectangular_generator.cu @@ -178,7 +178,7 @@ class RmatGenTest : public ::testing::TestWithParam { max_scale{std::max(params.r_scale, params.c_scale)} { theta.resize(4 * max_scale, stream); - uniform(state, theta.data(), theta.size(), 0.0f, 1.0f, stream); + uniform(handle, state, theta.data(), theta.size(), 0.0f, 1.0f); normalize(theta.data(), theta.data(), max_scale, @@ -271,7 +271,7 @@ class RmatGenMdspanTest : public ::testing::TestWithParam { max_scale{std::max(params.r_scale, params.c_scale)} { theta.resize(4 * max_scale, stream); - uniform(state, theta.data(), theta.size(), 0.0f, 1.0f, stream); + uniform(handle, state, theta.data(), theta.size(), 0.0f, 1.0f); normalize(theta.data(), theta.data(), max_scale, diff --git a/cpp/test/sparse/gram.cu b/cpp/test/sparse/gram.cu index 7b4736a08c..ca43aa83b9 100644 --- a/cpp/test/sparse/gram.cu +++ b/cpp/test/sparse/gram.cu @@ -125,7 +125,7 @@ class GramMatrixTest : public ::testing::TestWithParam { protected: GramMatrixTest() : params(GetParam()), - stream(0), + stream(resource::get_cuda_stream(handle)), x1(0, stream), x2(0, stream), x1_csr_indptr(0, stream), @@ -137,8 +137,6 @@ class GramMatrixTest : public ::testing::TestWithParam { gram(0, stream), gram_host(0) { - RAFT_CUDA_TRY(cudaStreamCreate(&stream)); - if (params.ld1 == 0) { params.ld1 = params.is_row_major ? params.n_cols : params.n1; } if (params.ld2 == 0) { params.ld2 = params.is_row_major ? params.n_cols : params.n2; } if (params.ld_out == 0) { params.ld_out = params.is_row_major ? params.n2 : params.n1; } @@ -154,14 +152,14 @@ class GramMatrixTest : public ::testing::TestWithParam { gram_host.resize(gram.size()); std::fill(gram_host.begin(), gram_host.end(), 0); - raft::random::Rng r(42137ULL); - r.uniform(x1.data(), x1.size(), math_t(0), math_t(1), stream); - r.uniform(x2.data(), x2.size(), math_t(0), math_t(1), stream); + raft::random::RngState r(42137ULL); + raft::random::uniform(handle, r, x1.data(), x1.size(), math_t(0), math_t(1)); + raft::random::uniform(handle, r, x2.data(), x2.size(), math_t(0), math_t(1)); RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); } - ~GramMatrixTest() override { RAFT_CUDA_TRY_NO_THROW(cudaStreamDestroy(stream)); } + ~GramMatrixTest() override {} int prepareCsr(math_t* dense, int n_rows, int ld, int* indptr, int* indices, math_t* data) { diff --git a/cpp/test/test_utils.cuh b/cpp/test/test_utils.cuh index 5704eefae3..1afa7acc83 100644 --- a/cpp/test/test_utils.cuh +++ b/cpp/test/test_utils.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2022, NVIDIA CORPORATION. + * Copyright (c) 2018-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -228,38 +228,39 @@ testing::AssertionResult diagonalMatch( } template -typename std::enable_if_t> gen_uniform(T* out, +typename std::enable_if_t> gen_uniform(const raft::resources& handle, + T* out, raft::random::RngState& rng, IdxT len, - cudaStream_t stream, T range_min = T(-1), T range_max = T(1)) { - raft::random::uniform(rng, out, len, range_min, range_max, stream); + raft::random::uniform(handle, rng, out, len, range_min, range_max); } template -typename std::enable_if_t> gen_uniform(T* out, +typename std::enable_if_t> gen_uniform(const raft::resources& handle, + T* out, raft::random::RngState& rng, IdxT len, - cudaStream_t stream, T range_min = T(0), T range_max = T(100)) { - raft::random::uniformInt(rng, out, len, range_min, range_max, stream); + raft::random::uniformInt(handle, rng, out, len, range_min, range_max); } template -void gen_uniform(raft::KeyValuePair* out, +void gen_uniform(const raft::resources& handle, + raft::KeyValuePair* out, raft::random::RngState& rng, - IdxT len, - cudaStream_t stream) + IdxT len) { + auto stream = resource::get_cuda_stream(handle); rmm::device_uvector keys(len, stream); rmm::device_uvector values(len, stream); - gen_uniform(keys.data(), rng, len, stream); - gen_uniform(values.data(), rng, len, stream); + gen_uniform(handle, keys.data(), rng, len); + gen_uniform(handle, values.data(), rng, len); const T1* d_keys = keys.data(); const T2* d_values = values.data(); diff --git a/cpp/test/util/bitonic_sort.cu b/cpp/test/util/bitonic_sort.cu index 2cf5420334..f928480b54 100644 --- a/cpp/test/util/bitonic_sort.cu +++ b/cpp/test/util/bitonic_sort.cu @@ -109,6 +109,7 @@ class BitonicTest : public testing::TestWithParam { // NOLINT std::vector in; // NOLINT std::vector out; // NOLINT std::vector ref; // NOLINT + raft::resources handle_; void segmented_sort(std::vector& vec, int k, bool ascending) // NOLINT { @@ -128,14 +129,14 @@ class BitonicTest : public testing::TestWithParam { // NOLINT } } - void fill_random(rmm::device_uvector& arr, rmm::cuda_stream_view stream) + void fill_random(rmm::device_uvector& arr) { - raft::random::Rng rng(42); + raft::random::RngState rng(42); if constexpr (std::is_floating_point_v) { - return rng.normal(arr.data(), arr.size(), T(10), T(100), stream); + return raft::random::normal(handle_, rng, arr.data(), arr.size(), T(10), T(100)); } if constexpr (std::is_integral_v) { - return rng.normalInt(arr.data(), arr.size(), T(10), T(100), stream); + return raft::random::normalInt(handle_, rng, arr.data(), arr.size(), T(10), T(100)); } } @@ -146,11 +147,11 @@ class BitonicTest : public testing::TestWithParam { // NOLINT out(spec.len()), ref(spec.len()) { - auto stream = rmm::cuda_stream_default; + auto stream = resource::get_cuda_stream(handle_); // generate input rmm::device_uvector arr_d(spec.len(), stream); - fill_random(arr_d, stream); + fill_random(arr_d); update_host(in.data(), arr_d.data(), arr_d.size(), stream); // calculate the results