Skip to content

Commit

Permalink
Replace raft::random calls to not use deprecated API (rapidsai#1867)
Browse files Browse the repository at this point in the history
Authors:
  - Micka (https://github.com/lowener)
  - Corey J. Nolet (https://github.com/cjnolet)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)

URL: rapidsai#1867
  • Loading branch information
lowener authored and divyegala committed Oct 6, 2023
1 parent ffab8f6 commit 3f02ea1
Show file tree
Hide file tree
Showing 22 changed files with 108 additions and 87 deletions.
4 changes: 2 additions & 2 deletions cpp/bench/prims/cluster/kmeans_balanced.cu
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,10 @@ struct KMeansBalanced : public fixture {
constexpr T kRangeMin = std::is_integral_v<T> ? std::numeric_limits<T>::min() : T(-1);
if constexpr (std::is_integral_v<T>) {
raft::random::uniformInt(
rng, X.data_handle(), params.data.rows * params.data.cols, kRangeMin, kRangeMax, stream);
handle, rng, X.data_handle(), params.data.rows * params.data.cols, kRangeMin, kRangeMax);
} else {
raft::random::uniform(
rng, X.data_handle(), params.data.rows * params.data.cols, kRangeMin, kRangeMax, stream);
handle, rng, X.data_handle(), params.data.rows * params.data.cols, kRangeMin, kRangeMax);
}
resource::sync_stream(handle, stream);
}
Expand Down
6 changes: 3 additions & 3 deletions cpp/bench/prims/distance/kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,9 @@ struct GramMatrix : public fixture {
A.resize(params.m * params.k, stream);
B.resize(params.k * params.n, stream);
C.resize(params.m * params.n, stream);
raft::random::Rng r(123456ULL);
r.uniform(A.data(), params.m * params.k, T(-1.0), T(1.0), stream);
r.uniform(B.data(), params.k * params.n, T(-1.0), T(1.0), stream);
raft::random::RngState rng(123456ULL);
raft::random::uniform(handle, rng, A.data(), params.m * params.k, T(-1.0), T(1.0));
raft::random::uniform(handle, rng, B.data(), params.k * params.n, T(-1.0), T(1.0));
}

~GramMatrix()
Expand Down
2 changes: 1 addition & 1 deletion cpp/bench/prims/linalg/norm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ struct rowNorm : public fixture {
rowNorm(const norm_input<IdxT>& p) : params(p), in(p.rows * p.cols, stream), dots(p.rows, stream)
{
raft::random::RngState rng{1234};
raft::random::uniform(rng, in.data(), p.rows * p.cols, (T)-10.0, (T)10.0, stream);
raft::random::uniform(handle, rng, in.data(), p.rows * p.cols, (T)-10.0, (T)10.0);
}

void run_benchmark(::benchmark::State& state) override
Expand Down
2 changes: 1 addition & 1 deletion cpp/bench/prims/linalg/normalize.cu
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ struct rowNormalize : public fixture {
: params(p), in(p.rows * p.cols, stream), out(p.rows * p.cols, stream)
{
raft::random::RngState rng{1234};
raft::random::uniform(rng, in.data(), p.rows * p.cols, (T)-10.0, (T)10.0, stream);
raft::random::uniform(handle, rng, in.data(), p.rows * p.cols, (T)-10.0, (T)10.0);
}

void run_benchmark(::benchmark::State& state) override
Expand Down
2 changes: 1 addition & 1 deletion cpp/bench/prims/linalg/reduce_cols_by_key.cu
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ struct reduce_cols_by_key : public fixture {
: params(p), in(p.rows * p.cols, stream), out(p.rows * p.keys, stream), keys(p.cols, stream)
{
raft::random::RngState rng{42};
raft::random::uniformInt(rng, keys.data(), p.cols, (KeyT)0, (KeyT)p.keys, stream);
raft::random::uniformInt(handle, rng, keys.data(), p.cols, (KeyT)0, (KeyT)p.keys);
}

void run_benchmark(::benchmark::State& state) override
Expand Down
2 changes: 1 addition & 1 deletion cpp/bench/prims/linalg/reduce_rows_by_key.cu
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ struct reduce_rows_by_key : public fixture {
workspace(p.rows, stream)
{
raft::random::RngState rng{42};
raft::random::uniformInt(rng, keys.data(), p.rows, (KeyT)0, (KeyT)p.keys, stream);
raft::random::uniformInt(handle, rng, keys.data(), p.rows, (KeyT)0, (KeyT)p.keys);
}

void run_benchmark(::benchmark::State& state) override
Expand Down
2 changes: 1 addition & 1 deletion cpp/bench/prims/matrix/argmin.cu
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ struct Argmin : public fixture {

raft::random::RngState rng{1234};
raft::random::uniform(
rng, matrix.data_handle(), params.rows * params.cols, T(-1), T(1), stream);
handle, rng, matrix.data_handle(), params.rows * params.cols, T(-1), T(1));
resource::sync_stream(handle, stream);
}

Expand Down
4 changes: 2 additions & 2 deletions cpp/bench/prims/matrix/gather.cu
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,11 @@ struct Gather : public fixture {

raft::random::RngState rng{1234};
raft::random::uniform(
rng, matrix.data_handle(), params.rows * params.cols, T(-1), T(1), stream);
handle, rng, matrix.data_handle(), params.rows * params.cols, T(-1), T(1));
raft::random::uniformInt(
handle, rng, map.data_handle(), params.map_length, (MapT)0, (MapT)params.rows);
if constexpr (Conditional) {
raft::random::uniform(rng, stencil.data_handle(), params.map_length, T(-1), T(1), stream);
raft::random::uniform(handle, rng, stencil.data_handle(), params.map_length, T(-1), T(1));
}
resource::sync_stream(handle, stream);
}
Expand Down
10 changes: 5 additions & 5 deletions cpp/bench/prims/neighbors/cagra_bench.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -62,20 +62,20 @@ struct CagraBench : public fixture {
constexpr T kRangeMin = std::is_integral_v<T> ? std::numeric_limits<T>::min() : T(-1);
if constexpr (std::is_integral_v<T>) {
raft::random::uniformInt(
state, dataset_.data_handle(), dataset_.size(), kRangeMin, kRangeMax, stream);
handle, state, dataset_.data_handle(), dataset_.size(), kRangeMin, kRangeMax);
raft::random::uniformInt(
state, queries_.data_handle(), queries_.size(), kRangeMin, kRangeMax, stream);
handle, state, queries_.data_handle(), queries_.size(), kRangeMin, kRangeMax);
} else {
raft::random::uniform(
state, dataset_.data_handle(), dataset_.size(), kRangeMin, kRangeMax, stream);
handle, state, dataset_.data_handle(), dataset_.size(), kRangeMin, kRangeMax);
raft::random::uniform(
state, queries_.data_handle(), queries_.size(), kRangeMin, kRangeMax, stream);
handle, state, queries_.data_handle(), queries_.size(), kRangeMin, kRangeMax);
}

// Generate random knn graph

raft::random::uniformInt<IdxT>(
state, knn_graph_.data_handle(), knn_graph_.size(), 0, ps.n_samples - 1, stream);
handle, state, knn_graph_.data_handle(), knn_graph_.size(), 0, ps.n_samples - 1);

auto metric = raft::distance::DistanceType::L2Expanded;

Expand Down
4 changes: 2 additions & 2 deletions cpp/bench/prims/neighbors/knn.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -260,9 +260,9 @@ struct knn : public fixture {
constexpr T kRangeMax = std::is_integral_v<T> ? std::numeric_limits<T>::max() : T(1);
constexpr T kRangeMin = std::is_integral_v<T> ? std::numeric_limits<T>::min() : T(-1);
if constexpr (std::is_integral_v<T>) {
raft::random::uniformInt(state, vec.data(), n, kRangeMin, kRangeMax, stream);
raft::random::uniformInt(handle, state, vec.data(), n, kRangeMin, kRangeMax);
} else {
raft::random::uniform(state, vec.data(), n, kRangeMin, kRangeMax, stream);
raft::random::uniform(handle, state, vec.data(), n, kRangeMin, kRangeMax);
}
}

Expand Down
4 changes: 2 additions & 2 deletions cpp/include/raft/neighbors/detail/ivf_pq_build.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ inline void make_rotation_matrix(raft::resources const& handle,
uint32_t n_rows,
uint32_t n_cols,
float* rotation_matrix,
raft::random::Rng rng = raft::random::Rng(7ULL))
raft::random::RngState rng = raft::random::RngState(7ULL))
{
common::nvtx::range<common::nvtx::domain::raft> fun_scope(
"ivf_pq::make_rotation_matrix(%u * %u)", n_rows, n_cols);
Expand All @@ -134,7 +134,7 @@ inline void make_rotation_matrix(raft::resources const& handle,
if (force_random_rotation || !inplace) {
rmm::device_uvector<float> buf(inplace ? 0 : n * n, stream);
float* mat = inplace ? rotation_matrix : buf.data();
rng.normal(mat, n * n, 0.0f, 1.0f, stream);
raft::random::normal(handle, rng, mat, n * n, 0.0f, 1.0f);
linalg::detail::qrGetQ_inplace(handle, mat, n, n, stream);
if (!inplace) {
RAFT_CUDA_TRY(cudaMemcpy2DAsync(rotation_matrix,
Expand Down
14 changes: 9 additions & 5 deletions cpp/internal/raft_internal/neighbors/refine_helper.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -61,16 +61,20 @@ class RefineHelper {
refined_distances_host(handle),
refined_indices_host(handle)
{
raft::random::Rng r(1234ULL);
raft::random::RngState rng(1234ULL);

dataset = raft::make_device_matrix<DataT, IdxT>(handle_, p.n_rows, p.dim);
queries = raft::make_device_matrix<DataT, IdxT>(handle_, p.n_queries, p.dim);
if constexpr (std::is_same<DataT, float>{}) {
r.uniform(dataset.data_handle(), dataset.size(), DataT(-10.0), DataT(10.0), stream_);
r.uniform(queries.data_handle(), queries.size(), DataT(-10.0), DataT(10.0), stream_);
raft::random::uniform(
handle, rng, dataset.data_handle(), dataset.size(), DataT(-10.0), DataT(10.0));
raft::random::uniform(
handle, rng, queries.data_handle(), queries.size(), DataT(-10.0), DataT(10.0));
} else {
r.uniformInt(dataset.data_handle(), dataset.size(), DataT(1), DataT(20), stream_);
r.uniformInt(queries.data_handle(), queries.size(), DataT(1), DataT(20), stream_);
raft::random::uniformInt(
handle, rng, dataset.data_handle(), dataset.size(), DataT(1), DataT(20));
raft::random::uniformInt(
handle, rng, queries.data_handle(), queries.size(), DataT(1), DataT(20));
}

refined_distances = raft::make_device_matrix<DistanceT, IdxT>(handle_, p.n_queries, p.k);
Expand Down
6 changes: 3 additions & 3 deletions cpp/test/distance/gram.cu
Original file line number Diff line number Diff line change
Expand Up @@ -99,9 +99,9 @@ class GramMatrixTest : public ::testing::TestWithParam<GramMatrixInputs> {
gram_host.resize(gram.size());
std::fill(gram_host.begin(), gram_host.end(), 0);

raft::random::Rng r(42137ULL);
r.uniform(x1.data(), x1.size(), math_t(0), math_t(1), stream);
r.uniform(x2.data(), x2.size(), math_t(0), math_t(1), stream);
raft::random::RngState rng(42137ULL);
raft::random::uniform(handle, rng, x1.data(), x1.size(), math_t(0), math_t(1));
raft::random::uniform(handle, rng, x2.data(), x2.size(), math_t(0), math_t(1));
}

~GramMatrixTest() override {}
Expand Down
2 changes: 1 addition & 1 deletion cpp/test/linalg/reduce.cu
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ class ReduceTest : public ::testing::TestWithParam<ReduceInputs<InType, OutType,
raft::random::RngState r(params.seed);
IdxType rows = params.rows, cols = params.cols;
IdxType len = rows * cols;
gen_uniform(data.data(), r, len, stream);
gen_uniform(handle, data.data(), r, len);

MainLambda main_op;
ReduceLambda reduce_op;
Expand Down
42 changes: 25 additions & 17 deletions cpp/test/neighbors/ann_cagra.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -125,18 +125,19 @@ __global__ void GenerateRoundingErrorFreeDataset_kernel(float* const ptr,
ptr[tid] = u32 / resolution;
}

void GenerateRoundingErrorFreeDataset(float* const ptr,
void GenerateRoundingErrorFreeDataset(const raft::resources& handle,
float* const ptr,
const uint32_t n_row,
const uint32_t dim,
raft::random::Rng& rng,
cudaStream_t cuda_stream)
raft::random::RngState& rng)
{
auto cuda_stream = resource::get_cuda_stream(handle);
const uint32_t size = n_row * dim;
const uint32_t block_size = 256;
const uint32_t grid_size = (size + block_size - 1) / block_size;

const uint32_t resolution = 1u << static_cast<unsigned>(std::floor((24 - std::log2(dim)) / 2));
rng.uniformInt(reinterpret_cast<uint32_t*>(ptr), size, 0u, resolution - 1, cuda_stream);
raft::random::uniformInt(handle, rng, reinterpret_cast<uint32_t*>(ptr), size, 0u, resolution - 1);

GenerateRoundingErrorFreeDataset_kernel<<<grid_size, block_size, 0, cuda_stream>>>(
ptr, size, resolution);
Expand Down Expand Up @@ -293,13 +294,16 @@ class AnnCagraTest : public ::testing::TestWithParam<AnnCagraInputs> {
{
database.resize(((size_t)ps.n_rows) * ps.dim, stream_);
search_queries.resize(ps.n_queries * ps.dim, stream_);
raft::random::Rng r(1234ULL);
raft::random::RngState r(1234ULL);
if constexpr (std::is_same<DataT, float>{}) {
r.normal(database.data(), ps.n_rows * ps.dim, DataT(0.1), DataT(2.0), stream_);
r.normal(search_queries.data(), ps.n_queries * ps.dim, DataT(0.1), DataT(2.0), stream_);
raft::random::normal(handle_, r, database.data(), ps.n_rows * ps.dim, DataT(0.1), DataT(2.0));
raft::random::normal(
handle_, r, search_queries.data(), ps.n_queries * ps.dim, DataT(0.1), DataT(2.0));
} else {
r.uniformInt(database.data(), ps.n_rows * ps.dim, DataT(1), DataT(20), stream_);
r.uniformInt(search_queries.data(), ps.n_queries * ps.dim, DataT(1), DataT(20), stream_);
raft::random::uniformInt(
handle_, r, database.data(), ps.n_rows * ps.dim, DataT(1), DataT(20));
raft::random::uniformInt(
handle_, r, search_queries.data(), ps.n_queries * ps.dim, DataT(1), DataT(20));
}
resource::sync_stream(handle_);
}
Expand Down Expand Up @@ -379,11 +383,12 @@ class AnnCagraSortTest : public ::testing::TestWithParam<AnnCagraInputs> {
void SetUp() override
{
database.resize(((size_t)ps.n_rows) * ps.dim, handle_.get_stream());
raft::random::Rng r(1234ULL);
raft::random::RngState r(1234ULL);
if constexpr (std::is_same<DataT, float>{}) {
GenerateRoundingErrorFreeDataset(database.data(), ps.n_rows, ps.dim, r, handle_.get_stream());
GenerateRoundingErrorFreeDataset(handle_, database.data(), ps.n_rows, ps.dim, r);
} else {
r.uniformInt(database.data(), ps.n_rows * ps.dim, DataT(1), DataT(20), handle_.get_stream());
raft::random::uniformInt(
handle_, r, database.data(), ps.n_rows * ps.dim, DataT(1), DataT(20));
}
handle_.sync_stream();
}
Expand Down Expand Up @@ -643,13 +648,16 @@ class AnnCagraFilterTest : public ::testing::TestWithParam<AnnCagraInputs> {
{
database.resize(((size_t)ps.n_rows) * ps.dim, stream_);
search_queries.resize(ps.n_queries * ps.dim, stream_);
raft::random::Rng r(1234ULL);
raft::random::RngState r(1234ULL);
if constexpr (std::is_same<DataT, float>{}) {
r.normal(database.data(), ps.n_rows * ps.dim, DataT(0.1), DataT(2.0), stream_);
r.normal(search_queries.data(), ps.n_queries * ps.dim, DataT(0.1), DataT(2.0), stream_);
raft::random::normal(handle_, r, database.data(), ps.n_rows * ps.dim, DataT(0.1), DataT(2.0));
raft::random::normal(
handle_, r, search_queries.data(), ps.n_queries * ps.dim, DataT(0.1), DataT(2.0));
} else {
r.uniformInt(database.data(), ps.n_rows * ps.dim, DataT(1), DataT(20), stream_);
r.uniformInt(search_queries.data(), ps.n_queries * ps.dim, DataT(1), DataT(20), stream_);
raft::random::uniformInt(
handle_, r, database.data(), ps.n_rows * ps.dim, DataT(1), DataT(20));
raft::random::uniformInt(
handle_, r, search_queries.data(), ps.n_queries * ps.dim, DataT(1), DataT(20));
}
resource::sync_stream(handle_);
}
Expand Down
14 changes: 9 additions & 5 deletions cpp/test/neighbors/ann_ivf_flat.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -411,13 +411,17 @@ class AnnIVFFlatTest : public ::testing::TestWithParam<AnnIvfFlatInputs<IdxT>> {
database.resize(ps.num_db_vecs * ps.dim, stream_);
search_queries.resize(ps.num_queries * ps.dim, stream_);

raft::random::Rng r(1234ULL);
raft::random::RngState r(1234ULL);
if constexpr (std::is_same<DataT, float>{}) {
r.uniform(database.data(), ps.num_db_vecs * ps.dim, DataT(0.1), DataT(2.0), stream_);
r.uniform(search_queries.data(), ps.num_queries * ps.dim, DataT(0.1), DataT(2.0), stream_);
raft::random::uniform(
handle_, r, database.data(), ps.num_db_vecs * ps.dim, DataT(0.1), DataT(2.0));
raft::random::uniform(
handle_, r, search_queries.data(), ps.num_queries * ps.dim, DataT(0.1), DataT(2.0));
} else {
r.uniformInt(database.data(), ps.num_db_vecs * ps.dim, DataT(1), DataT(20), stream_);
r.uniformInt(search_queries.data(), ps.num_queries * ps.dim, DataT(1), DataT(20), stream_);
raft::random::uniformInt(
handle_, r, database.data(), ps.num_db_vecs * ps.dim, DataT(1), DataT(20));
raft::random::uniformInt(
handle_, r, search_queries.data(), ps.num_queries * ps.dim, DataT(1), DataT(20));
}
resource::sync_stream(handle_);
}
Expand Down
14 changes: 9 additions & 5 deletions cpp/test/neighbors/ann_ivf_pq.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -170,13 +170,17 @@ class ivf_pq_test : public ::testing::TestWithParam<ivf_pq_inputs> {
database.resize(size_t{ps.num_db_vecs} * size_t{ps.dim}, stream_);
search_queries.resize(size_t{ps.num_queries} * size_t{ps.dim}, stream_);

raft::random::Rng r(1234ULL);
raft::random::RngState r(1234ULL);
if constexpr (std::is_same<DataT, float>{}) {
r.uniform(database.data(), ps.num_db_vecs * ps.dim, DataT(0.1), DataT(2.0), stream_);
r.uniform(search_queries.data(), ps.num_queries * ps.dim, DataT(0.1), DataT(2.0), stream_);
raft::random::uniform(
handle_, r, database.data(), ps.num_db_vecs * ps.dim, DataT(0.1), DataT(2.0));
raft::random::uniform(
handle_, r, search_queries.data(), ps.num_queries * ps.dim, DataT(0.1), DataT(2.0));
} else {
r.uniformInt(database.data(), ps.num_db_vecs * ps.dim, DataT(1), DataT(20), stream_);
r.uniformInt(search_queries.data(), ps.num_queries * ps.dim, DataT(1), DataT(20), stream_);
raft::random::uniformInt(
handle_, r, database.data(), ps.num_db_vecs * ps.dim, DataT(1), DataT(20));
raft::random::uniformInt(
handle_, r, search_queries.data(), ps.num_queries * ps.dim, DataT(1), DataT(20));
}
resource::sync_stream(handle_);
}
Expand Down
7 changes: 4 additions & 3 deletions cpp/test/neighbors/ann_nn_descent.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -121,11 +121,12 @@ class AnnNNDescentTest : public ::testing::TestWithParam<AnnNNDescentInputs> {
void SetUp() override
{
database.resize(((size_t)ps.n_rows) * ps.dim, stream_);
raft::random::Rng r(1234ULL);
raft::random::RngState r(1234ULL);
if constexpr (std::is_same<DataT, float>{}) {
r.normal(database.data(), ps.n_rows * ps.dim, DataT(0.1), DataT(2.0), stream_);
raft::random::normal(handle_, r, database.data(), ps.n_rows * ps.dim, DataT(0.1), DataT(2.0));
} else {
r.uniformInt(database.data(), ps.n_rows * ps.dim, DataT(1), DataT(20), stream_);
raft::random::uniformInt(
handle_, r, database.data(), ps.n_rows * ps.dim, DataT(1), DataT(20));
}
resource::sync_stream(handle_);
}
Expand Down
4 changes: 2 additions & 2 deletions cpp/test/random/rmat_rectangular_generator.cu
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ class RmatGenTest : public ::testing::TestWithParam<RmatInputs> {
max_scale{std::max(params.r_scale, params.c_scale)}
{
theta.resize(4 * max_scale, stream);
uniform<float>(state, theta.data(), theta.size(), 0.0f, 1.0f, stream);
uniform<float>(handle, state, theta.data(), theta.size(), 0.0f, 1.0f);
normalize<float, float>(theta.data(),
theta.data(),
max_scale,
Expand Down Expand Up @@ -271,7 +271,7 @@ class RmatGenMdspanTest : public ::testing::TestWithParam<RmatInputs> {
max_scale{std::max(params.r_scale, params.c_scale)}
{
theta.resize(4 * max_scale, stream);
uniform<float>(state, theta.data(), theta.size(), 0.0f, 1.0f, stream);
uniform<float>(handle, state, theta.data(), theta.size(), 0.0f, 1.0f);
normalize<float, float>(theta.data(),
theta.data(),
max_scale,
Expand Down
12 changes: 5 additions & 7 deletions cpp/test/sparse/gram.cu
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ class GramMatrixTest : public ::testing::TestWithParam<GramMatrixInputs> {
protected:
GramMatrixTest()
: params(GetParam()),
stream(0),
stream(resource::get_cuda_stream(handle)),
x1(0, stream),
x2(0, stream),
x1_csr_indptr(0, stream),
Expand All @@ -137,8 +137,6 @@ class GramMatrixTest : public ::testing::TestWithParam<GramMatrixInputs> {
gram(0, stream),
gram_host(0)
{
RAFT_CUDA_TRY(cudaStreamCreate(&stream));

if (params.ld1 == 0) { params.ld1 = params.is_row_major ? params.n_cols : params.n1; }
if (params.ld2 == 0) { params.ld2 = params.is_row_major ? params.n_cols : params.n2; }
if (params.ld_out == 0) { params.ld_out = params.is_row_major ? params.n2 : params.n1; }
Expand All @@ -154,14 +152,14 @@ class GramMatrixTest : public ::testing::TestWithParam<GramMatrixInputs> {
gram_host.resize(gram.size());
std::fill(gram_host.begin(), gram_host.end(), 0);

raft::random::Rng r(42137ULL);
r.uniform(x1.data(), x1.size(), math_t(0), math_t(1), stream);
r.uniform(x2.data(), x2.size(), math_t(0), math_t(1), stream);
raft::random::RngState r(42137ULL);
raft::random::uniform(handle, r, x1.data(), x1.size(), math_t(0), math_t(1));
raft::random::uniform(handle, r, x2.data(), x2.size(), math_t(0), math_t(1));

RAFT_CUDA_TRY(cudaStreamSynchronize(stream));
}

~GramMatrixTest() override { RAFT_CUDA_TRY_NO_THROW(cudaStreamDestroy(stream)); }
~GramMatrixTest() override {}

int prepareCsr(math_t* dense, int n_rows, int ld, int* indptr, int* indices, math_t* data)
{
Expand Down
Loading

0 comments on commit 3f02ea1

Please sign in to comment.