Skip to content

Commit

Permalink
Set RNG seeds in NN Descent to diagnose flaky tests (#1931)
Browse files Browse the repository at this point in the history
Authors:
  - Divye Gala (https://github.com/divyegala)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #1931
  • Loading branch information
divyegala authored Oct 27, 2023
1 parent e10fc9f commit 0d199f9
Show file tree
Hide file tree
Showing 5 changed files with 7 additions and 5 deletions.
4 changes: 3 additions & 1 deletion cpp/include/raft/neighbors/detail/nn_descent.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include <limits>
#include <queue>

#include <random>
#include <rmm/device_uvector.hpp>

#include <thrust/execution_policy.h>
Expand Down Expand Up @@ -1025,7 +1026,8 @@ void GnndGraph<Index_t>::init_random_graph()
// segment_x stores neighbors which id % num_segments == x
std::vector<Index_t> rand_seq(nrow / num_segments);
std::iota(rand_seq.begin(), rand_seq.end(), 0);
std::random_shuffle(rand_seq.begin(), rand_seq.end());
auto gen = std::default_random_engine{seg_idx};
std::shuffle(rand_seq.begin(), rand_seq.end(), gen);

#pragma omp parallel for
for (size_t i = 0; i < nrow; i++) {
Expand Down
2 changes: 1 addition & 1 deletion cpp/test/neighbors/ann_nn_descent.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ class AnnNNDescentTest : public ::testing::TestWithParam<AnnNNDescentInputs> {
index_params.metric = ps.metric;
index_params.graph_degree = ps.graph_degree;
index_params.intermediate_graph_degree = 2 * ps.graph_degree;
index_params.max_iterations = 50;
index_params.max_iterations = 100;

auto database_view = raft::make_device_matrix_view<const DataT, int64_t>(
(const DataT*)database.data(), ps.n_rows, ps.dim);
Expand Down
2 changes: 1 addition & 1 deletion cpp/test/neighbors/ann_nn_descent/test_float_uint32_t.cu
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
namespace raft::neighbors::experimental::nn_descent {

typedef AnnNNDescentTest<float, float, std::uint32_t> AnnNNDescentTestF_U32;
TEST_P(AnnNNDescentTestF_U32, AnnCagra) { this->testNNDescent(); }
TEST_P(AnnNNDescentTestF_U32, AnnNNDescent) { this->testNNDescent(); }

INSTANTIATE_TEST_CASE_P(AnnNNDescentTest, AnnNNDescentTestF_U32, ::testing::ValuesIn(inputs));

Expand Down
2 changes: 1 addition & 1 deletion cpp/test/neighbors/ann_nn_descent/test_int8_t_uint32_t.cu
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
namespace raft::neighbors::experimental::nn_descent {

typedef AnnNNDescentTest<float, int8_t, std::uint32_t> AnnNNDescentTestI8_U32;
TEST_P(AnnNNDescentTestI8_U32, AnnCagra) { this->testNNDescent(); }
TEST_P(AnnNNDescentTestI8_U32, AnnNNDescent) { this->testNNDescent(); }

INSTANTIATE_TEST_CASE_P(AnnNNDescentTest, AnnNNDescentTestI8_U32, ::testing::ValuesIn(inputs));

Expand Down
2 changes: 1 addition & 1 deletion cpp/test/neighbors/ann_nn_descent/test_uint8_t_uint32_t.cu
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
namespace raft::neighbors::experimental::nn_descent {

typedef AnnNNDescentTest<float, uint8_t, std::uint32_t> AnnNNDescentTestUI8_U32;
TEST_P(AnnNNDescentTestUI8_U32, AnnCagra) { this->testNNDescent(); }
TEST_P(AnnNNDescentTestUI8_U32, AnnNNDescent) { this->testNNDescent(); }

INSTANTIATE_TEST_CASE_P(AnnNNDescentTest, AnnNNDescentTestUI8_U32, ::testing::ValuesIn(inputs));

Expand Down

0 comments on commit 0d199f9

Please sign in to comment.