From 61ce3d0e875e6e436fee66bf8bbeb88b84d24c17 Mon Sep 17 00:00:00 2001 From: Seunghwa Kang Date: Mon, 4 Dec 2023 16:23:16 -0800 Subject: [PATCH] fix HITS convergence error --- cpp/src/link_analysis/hits_impl.cuh | 3 ++- cpp/tests/link_analysis/hits_test.cpp | 28 ++++++++++++++---------- cpp/tests/link_analysis/mg_hits_test.cpp | 18 +++++++-------- 3 files changed, 26 insertions(+), 23 deletions(-) diff --git a/cpp/src/link_analysis/hits_impl.cuh b/cpp/src/link_analysis/hits_impl.cuh index 674046745b1..5cdf1b9dc6a 100644 --- a/cpp/src/link_analysis/hits_impl.cuh +++ b/cpp/src/link_analysis/hits_impl.cuh @@ -80,6 +80,7 @@ std::tuple hits(raft::handle_t const& handle, if (num_vertices == 0) { return std::make_tuple(diff_sum, final_iteration_count); } CUGRAPH_EXPECTS(epsilon >= 0.0, "Invalid input argument: epsilon should be non-negative."); + auto tolerance = static_cast(graph_view.number_of_vertices()) * epsilon; // Check validity of initial guess if supplied if (has_initial_hubs_guess && do_expensive_check) { @@ -171,7 +172,7 @@ std::tuple hits(raft::handle_t const& handle, std::swap(prev_hubs, curr_hubs); iter++; - if (diff_sum < epsilon) { + if (diff_sum < tolerance) { break; } else if (iter >= max_iterations) { CUGRAPH_FAIL("HITS failed to converge."); diff --git a/cpp/tests/link_analysis/hits_test.cpp b/cpp/tests/link_analysis/hits_test.cpp index d0e77769034..6796761e212 100644 --- a/cpp/tests/link_analysis/hits_test.cpp +++ b/cpp/tests/link_analysis/hits_test.cpp @@ -52,9 +52,11 @@ std::tuple, std::vector, double, size_t> hits_re size_t max_iterations, std::optional starting_hub_values, bool normalized, - double tolerance) + double epsilon) { CUGRAPH_EXPECTS(num_vertices > 1, "number of vertices expected to be non-zero"); + auto tolerance = static_cast(num_vertices) * epsilon; + std::vector prev_hubs(num_vertices, result_t{1.0} / num_vertices); std::vector prev_authorities(num_vertices, result_t{1.0} / num_vertices); std::vector curr_hubs(num_vertices); @@ -127,8 +129,8 @@ std::tuple, std::vector, double, size_t> hits_re } struct Hits_Usecase { - bool check_correctness{true}; bool check_initial_input{false}; + bool check_correctness{true}; }; template @@ -175,8 +177,8 @@ class Tests_Hits : public ::testing::TestWithParam d_hubs(graph_view.local_vertex_partition_range_size(), handle.get_stream()); @@ -201,7 +203,7 @@ class Tests_Hits : public ::testing::TestWithParam h_cugraph_hits{}; if (renumber) { @@ -246,8 +248,7 @@ class Tests_Hits : public ::testing::TestWithParam(graph_view.number_of_vertices())) * - threshold_ratio; // skip comparison for low hits vertices (lowly ranked vertices) + 1e-6; // skip comparison for low hits vertices (lowly ranked vertices) auto nearly_equal = [threshold_ratio, threshold_magnitude](auto lhs, auto rhs) { return std::abs(lhs - rhs) <= std::max(std::max(lhs, rhs) * threshold_ratio, threshold_magnitude); @@ -294,14 +295,17 @@ INSTANTIATE_TEST_SUITE_P( Tests_Hits_File, ::testing::Combine( // enable correctness checks - ::testing::Values(Hits_Usecase{true, false}, Hits_Usecase{true, true}), + ::testing::Values(Hits_Usecase{false, true}, Hits_Usecase{true, true}), ::testing::Values(cugraph::test::File_Usecase("test/datasets/karate.mtx"), + cugraph::test::File_Usecase("test/datasets/web-Google.mtx"), + cugraph::test::File_Usecase("test/datasets/ljournal-2008.mtx"), + cugraph::test::File_Usecase("test/datasets/webbase-1M.mtx"), cugraph::test::File_Usecase("test/datasets/dolphins.mtx")))); INSTANTIATE_TEST_SUITE_P(rmat_small_test, Tests_Hits_Rmat, // enable correctness checks - ::testing::Combine(::testing::Values(Hits_Usecase{true, false}, + ::testing::Combine(::testing::Values(Hits_Usecase{false, true}, Hits_Usecase{true, true}), ::testing::Values(cugraph::test::Rmat_Usecase( 10, 16, 0.57, 0.19, 0.19, 0, false, false)))); @@ -315,7 +319,7 @@ INSTANTIATE_TEST_SUITE_P( Tests_Hits_File, ::testing::Combine( // disable correctness checks - ::testing::Values(Hits_Usecase{false, false}, Hits_Usecase{false, true}), + ::testing::Values(Hits_Usecase{false, false}, Hits_Usecase{true, false}), ::testing::Values(cugraph::test::File_Usecase("test/datasets/karate.mtx")))); INSTANTIATE_TEST_SUITE_P( @@ -327,7 +331,7 @@ INSTANTIATE_TEST_SUITE_P( Tests_Hits_Rmat, // disable correctness checks for large graphs ::testing::Combine( - ::testing::Values(Hits_Usecase{false, false}, Hits_Usecase{false, true}), + ::testing::Values(Hits_Usecase{false, false}, Hits_Usecase{true, false}), ::testing::Values(cugraph::test::Rmat_Usecase(10, 16, 0.57, 0.19, 0.19, 0, false, false)))); CUGRAPH_TEST_PROGRAM_MAIN() diff --git a/cpp/tests/link_analysis/mg_hits_test.cpp b/cpp/tests/link_analysis/mg_hits_test.cpp index cf95d03681d..5c89bafd08e 100644 --- a/cpp/tests/link_analysis/mg_hits_test.cpp +++ b/cpp/tests/link_analysis/mg_hits_test.cpp @@ -33,8 +33,8 @@ #include struct Hits_Usecase { - bool check_correctness{true}; bool check_initial_input{false}; + bool check_correctness{true}; }; template @@ -81,7 +81,7 @@ class Tests_MGHits : public ::testing::TestWithParam d_mg_hubs(mg_graph_view.local_vertex_partition_range_size(), handle_->get_stream()); @@ -110,7 +110,7 @@ class Tests_MGHits : public ::testing::TestWithParam(mg_graph_view.number_of_vertices())) * - threshold_ratio; // skip comparison for low Hits verties (lowly ranked - // vertices) + 1e-6; // skip comparison for low Hits verties (lowly ranked vertices) auto nearly_equal = [threshold_ratio, threshold_magnitude](auto lhs, auto rhs) { return std::abs(lhs - rhs) < std::max(std::max(lhs, rhs) * threshold_ratio, threshold_magnitude); @@ -274,7 +272,7 @@ INSTANTIATE_TEST_SUITE_P( Tests_MGHits_File, ::testing::Combine( // enable correctness checks - ::testing::Values(Hits_Usecase{true, false}, Hits_Usecase{true, true}), + ::testing::Values(Hits_Usecase{false, true}, Hits_Usecase{true, true}), ::testing::Values(cugraph::test::File_Usecase("test/datasets/karate.mtx"), cugraph::test::File_Usecase("test/datasets/web-Google.mtx"), cugraph::test::File_Usecase("test/datasets/ljournal-2008.mtx"), @@ -285,7 +283,7 @@ INSTANTIATE_TEST_SUITE_P( Tests_MGHits_Rmat, ::testing::Combine( // enable correctness checks - ::testing::Values(Hits_Usecase{true, false}, Hits_Usecase{true, true}), + ::testing::Values(Hits_Usecase{false, true}, Hits_Usecase{true, true}), ::testing::Values(cugraph::test::Rmat_Usecase(10, 16, 0.57, 0.19, 0.19, 0, false, false)))); INSTANTIATE_TEST_SUITE_P( @@ -297,7 +295,7 @@ INSTANTIATE_TEST_SUITE_P( Tests_MGHits_Rmat, ::testing::Combine( // disable correctness checks for large graphs - ::testing::Values(Hits_Usecase{false, false}), + ::testing::Values(Hits_Usecase{false, false}, Hits_Usecase{true, false}), ::testing::Values(cugraph::test::Rmat_Usecase(20, 32, 0.57, 0.19, 0.19, 0, false, false)))); CUGRAPH_MG_TEST_PROGRAM_MAIN()