diff --git a/cpp/tests/c_api/sg_random_walks_test.c b/cpp/tests/c_api/sg_random_walks_test.c index 14108d91c04..a4a77b5775a 100644 --- a/cpp/tests/c_api/sg_random_walks_test.c +++ b/cpp/tests/c_api/sg_random_walks_test.c @@ -192,9 +192,6 @@ int generic_biased_random_walks_test(vertex_t* h_src, ret_code = cugraph_biased_random_walks(handle, graph, d_start_view, max_depth, &result, &ret_error); -#if 1 - TEST_ASSERT(test_ret_value, ret_code != CUGRAPH_SUCCESS, "biased_random_walks should have failed") -#else TEST_ASSERT(test_ret_value, ret_code == CUGRAPH_SUCCESS, cugraph_error_message(ret_error)); TEST_ASSERT(test_ret_value, ret_code == CUGRAPH_SUCCESS, "biased_random_walks failed."); @@ -208,10 +205,10 @@ int generic_biased_random_walks_test(vertex_t* h_src, size_t wgts_size = cugraph_type_erased_device_array_view_size(wgts); vertex_t h_result_verts[verts_size]; - vertex_t h_result_wgts[wgts_size]; + weight_t h_result_wgts[wgts_size]; - ret_code = - cugraph_type_erased_device_array_view_copy_to_host(handle, (byte_t*)h_verts, verts, &ret_error); + ret_code = cugraph_type_erased_device_array_view_copy_to_host( + handle, (byte_t*)h_result_verts, verts, &ret_error); TEST_ASSERT(test_ret_value, ret_code == CUGRAPH_SUCCESS, "copy_to_host failed."); ret_code = cugraph_type_erased_device_array_view_copy_to_host( @@ -231,23 +228,35 @@ int generic_biased_random_walks_test(vertex_t* h_src, M[h_src[i]][h_dst[i]] = h_wgt[i]; TEST_ASSERT(test_ret_value, - cugraph_random_walk_result_get_max_path_length() == max_depth, + cugraph_random_walk_result_get_max_path_length(result) == max_depth, "path length does not match"); for (int i = 0; (i < num_starts) && (test_ret_value == 0); ++i) { - TEST_ASSERT(test_ret_value, - M[h_start[i]][h_result_verts[i * (max_depth + 1)]] == h_result_wgts[i * max_depth], - "biased_random_walks got edge that doesn't exist"); - for (size_t j = 1; j < cugraph_random_walk_result_get_max_path_length(); ++j) - TEST_ASSERT( - test_ret_value, - M[h_start[i * (max_depth + 1) + j - 1]][h_result_verts[i * (max_depth + 1) + j]] == - h_result_wgts[i * max_depth + j - 1], - "biased_random_walks got edge that doesn't exist"); + TEST_ASSERT( + test_ret_value, h_start[i] == h_result_verts[i * (max_depth + 1)], "start of path not found"); + for (size_t j = 0; j < max_depth; ++j) { + int src_index = i * (max_depth + 1) + j; + int dst_index = src_index + 1; + if (h_result_verts[dst_index] < 0) { + if (h_result_verts[src_index] >= 0) { + int departing_count = 0; + for (int k = 0; k < num_vertices; ++k) { + if (M[h_result_verts[src_index]][k] >= 0) departing_count++; + } + TEST_ASSERT(test_ret_value, + departing_count == 0, + "biased_random_walks found no edge when an edge exists"); + } + } else { + TEST_ASSERT(test_ret_value, + M[h_result_verts[src_index]][h_result_verts[dst_index]] == + h_result_wgts[i * max_depth + j], + "biased_random_walks got edge that doesn't exist"); + } + } } cugraph_random_walk_result_free(result); -#endif cugraph_sg_graph_free(graph); cugraph_free_resource_handle(handle); @@ -302,10 +311,6 @@ int generic_node2vec_random_walks_test(vertex_t* h_src, ret_code = cugraph_node2vec_random_walks( handle, graph, d_start_view, max_depth, p, q, &result, &ret_error); -#if 1 - TEST_ASSERT( - test_ret_value, ret_code != CUGRAPH_SUCCESS, "node2vec_random_walks should have failed") -#else TEST_ASSERT(test_ret_value, ret_code == CUGRAPH_SUCCESS, cugraph_error_message(ret_error)); TEST_ASSERT(test_ret_value, ret_code == CUGRAPH_SUCCESS, "node2vec_random_walks failed."); @@ -319,10 +324,10 @@ int generic_node2vec_random_walks_test(vertex_t* h_src, size_t wgts_size = cugraph_type_erased_device_array_view_size(wgts); vertex_t h_result_verts[verts_size]; - vertex_t h_result_wgts[wgts_size]; + weight_t h_result_wgts[wgts_size]; - ret_code = - cugraph_type_erased_device_array_view_copy_to_host(handle, (byte_t*)h_verts, verts, &ret_error); + ret_code = cugraph_type_erased_device_array_view_copy_to_host( + handle, (byte_t*)h_result_verts, verts, &ret_error); TEST_ASSERT(test_ret_value, ret_code == CUGRAPH_SUCCESS, "copy_to_host failed."); ret_code = cugraph_type_erased_device_array_view_copy_to_host( @@ -342,23 +347,35 @@ int generic_node2vec_random_walks_test(vertex_t* h_src, M[h_src[i]][h_dst[i]] = h_wgt[i]; TEST_ASSERT(test_ret_value, - cugraph_random_walk_result_get_max_path_length() == max_depth, + cugraph_random_walk_result_get_max_path_length(result) == max_depth, "path length does not match"); for (int i = 0; (i < num_starts) && (test_ret_value == 0); ++i) { - TEST_ASSERT(test_ret_value, - M[h_start[i]][h_result_verts[i * (max_depth + 1)]] == h_result_wgts[i * max_depth], - "node2vec_random_walks got edge that doesn't exist"); - for (size_t j = 1; j < max_depth; ++j) - TEST_ASSERT( - test_ret_value, - M[h_start[i * (max_depth + 1) + j - 1]][h_result_verts[i * (max_depth + 1) + j]] == - h_result_wgts[i * max_depth + j - 1], - "node2vec_random_walks got edge that doesn't exist"); + TEST_ASSERT( + test_ret_value, h_start[i] == h_result_verts[i * (max_depth + 1)], "start of path not found"); + for (size_t j = 0; j < max_depth; ++j) { + int src_index = i * (max_depth + 1) + j; + int dst_index = src_index + 1; + if (h_result_verts[dst_index] < 0) { + if (h_result_verts[src_index] >= 0) { + int departing_count = 0; + for (int k = 0; k < num_vertices; ++k) { + if (M[h_result_verts[src_index]][k] >= 0) departing_count++; + } + TEST_ASSERT(test_ret_value, + departing_count == 0, + "node2vec_random_walks found no edge when an edge exists"); + } + } else { + TEST_ASSERT(test_ret_value, + M[h_result_verts[src_index]][h_result_verts[dst_index]] == + h_result_wgts[i * max_depth + j], + "node2vec_random_walks got edge that doesn't exist"); + } + } } cugraph_random_walk_result_free(result); -#endif cugraph_sg_graph_free(graph); cugraph_free_resource_handle(handle); @@ -390,7 +407,7 @@ int test_biased_random_walks() vertex_t src[] = {0, 1, 1, 2, 2, 2, 3, 4}; vertex_t dst[] = {1, 3, 4, 0, 1, 3, 5, 5}; - weight_t wgt[] = {0, 1, 2, 3, 4, 5, 6, 7}; + weight_t wgt[] = {1, 2, 3, 4, 5, 6, 7, 8}; vertex_t start[] = {2, 2}; return generic_biased_random_walks_test(