diff --git a/cpp/include/raft/neighbors/detail/nn_descent.cuh b/cpp/include/raft/neighbors/detail/nn_descent.cuh index c794b1617c..6e0636c37a 100644 --- a/cpp/include/raft/neighbors/detail/nn_descent.cuh +++ b/cpp/include/raft/neighbors/detail/nn_descent.cuh @@ -24,6 +24,7 @@ #include #include +#include #include #include @@ -1025,7 +1026,8 @@ void GnndGraph::init_random_graph() // segment_x stores neighbors which id % num_segments == x std::vector rand_seq(nrow / num_segments); std::iota(rand_seq.begin(), rand_seq.end(), 0); - std::random_shuffle(rand_seq.begin(), rand_seq.end()); + auto gen = std::default_random_engine{seg_idx}; + std::shuffle(rand_seq.begin(), rand_seq.end(), gen); #pragma omp parallel for for (size_t i = 0; i < nrow; i++) { diff --git a/cpp/test/neighbors/ann_nn_descent.cuh b/cpp/test/neighbors/ann_nn_descent.cuh index 0590fe52e8..c06823993e 100644 --- a/cpp/test/neighbors/ann_nn_descent.cuh +++ b/cpp/test/neighbors/ann_nn_descent.cuh @@ -90,7 +90,7 @@ class AnnNNDescentTest : public ::testing::TestWithParam { index_params.metric = ps.metric; index_params.graph_degree = ps.graph_degree; index_params.intermediate_graph_degree = 2 * ps.graph_degree; - index_params.max_iterations = 50; + index_params.max_iterations = 100; auto database_view = raft::make_device_matrix_view( (const DataT*)database.data(), ps.n_rows, ps.dim); diff --git a/cpp/test/neighbors/ann_nn_descent/test_float_uint32_t.cu b/cpp/test/neighbors/ann_nn_descent/test_float_uint32_t.cu index 13bff6ac90..6aa503ca4d 100644 --- a/cpp/test/neighbors/ann_nn_descent/test_float_uint32_t.cu +++ b/cpp/test/neighbors/ann_nn_descent/test_float_uint32_t.cu @@ -21,7 +21,7 @@ namespace raft::neighbors::experimental::nn_descent { typedef AnnNNDescentTest AnnNNDescentTestF_U32; -TEST_P(AnnNNDescentTestF_U32, AnnCagra) { this->testNNDescent(); } +TEST_P(AnnNNDescentTestF_U32, AnnNNDescent) { this->testNNDescent(); } INSTANTIATE_TEST_CASE_P(AnnNNDescentTest, AnnNNDescentTestF_U32, ::testing::ValuesIn(inputs)); diff --git a/cpp/test/neighbors/ann_nn_descent/test_int8_t_uint32_t.cu b/cpp/test/neighbors/ann_nn_descent/test_int8_t_uint32_t.cu index 5895303e09..863f7edcc0 100644 --- a/cpp/test/neighbors/ann_nn_descent/test_int8_t_uint32_t.cu +++ b/cpp/test/neighbors/ann_nn_descent/test_int8_t_uint32_t.cu @@ -21,7 +21,7 @@ namespace raft::neighbors::experimental::nn_descent { typedef AnnNNDescentTest AnnNNDescentTestI8_U32; -TEST_P(AnnNNDescentTestI8_U32, AnnCagra) { this->testNNDescent(); } +TEST_P(AnnNNDescentTestI8_U32, AnnNNDescent) { this->testNNDescent(); } INSTANTIATE_TEST_CASE_P(AnnNNDescentTest, AnnNNDescentTestI8_U32, ::testing::ValuesIn(inputs)); diff --git a/cpp/test/neighbors/ann_nn_descent/test_uint8_t_uint32_t.cu b/cpp/test/neighbors/ann_nn_descent/test_uint8_t_uint32_t.cu index a034e84074..1a1b38fc19 100644 --- a/cpp/test/neighbors/ann_nn_descent/test_uint8_t_uint32_t.cu +++ b/cpp/test/neighbors/ann_nn_descent/test_uint8_t_uint32_t.cu @@ -21,7 +21,7 @@ namespace raft::neighbors::experimental::nn_descent { typedef AnnNNDescentTest AnnNNDescentTestUI8_U32; -TEST_P(AnnNNDescentTestUI8_U32, AnnCagra) { this->testNNDescent(); } +TEST_P(AnnNNDescentTestUI8_U32, AnnNNDescent) { this->testNNDescent(); } INSTANTIATE_TEST_CASE_P(AnnNNDescentTest, AnnNNDescentTestUI8_U32, ::testing::ValuesIn(inputs));