diff --git a/cpp/include/cugraph/utilities/dataframe_buffer.hpp b/cpp/include/cugraph/utilities/dataframe_buffer.hpp index d52160abd19..ab4c4eff6b5 100644 --- a/cpp/include/cugraph/utilities/dataframe_buffer.hpp +++ b/cpp/include/cugraph/utilities/dataframe_buffer.hpp @@ -68,24 +68,6 @@ auto get_dataframe_buffer_cend_tuple_impl(std::index_sequence, TupleType& } // namespace detail -template -struct dataframe_element { - using type = void; -}; - -template -struct dataframe_element...>> { - using type = thrust::tuple; -}; - -template -struct dataframe_element> { - using type = T; -}; - -template -using dataframe_element_t = typename dataframe_element::type; - template ::value>* = nullptr> auto allocate_dataframe_buffer(size_t buffer_size, rmm::cuda_stream_view stream_view) { @@ -224,4 +206,30 @@ auto get_dataframe_buffer_cend(BufferType& buffer) std::make_index_sequence::value>(), buffer); } +template +struct dataframe_buffer_value_type { + using type = void; +}; + +template +struct dataframe_buffer_value_type> { + using type = T; +}; + +template +struct dataframe_buffer_value_type...>> { + using type = thrust::tuple; +}; + +template +using dataframe_buffer_value_type_t = typename dataframe_buffer_value_type::type; + +template +struct dataframe_buffer_type { + using type = decltype(allocate_dataframe_buffer(size_t{0}, rmm::cuda_stream_view{})); +}; + +template +using dataframe_buffer_type_t = typename dataframe_buffer_type::type; + } // namespace cugraph diff --git a/cpp/tests/CMakeLists.txt b/cpp/tests/CMakeLists.txt index 6550e3cf6c6..5908607e8a8 100644 --- a/cpp/tests/CMakeLists.txt +++ b/cpp/tests/CMakeLists.txt @@ -468,7 +468,7 @@ ConfigureTest(UNIFORM_NEIGHBOR_SAMPLING_TEST sampling/uniform_neighbor_sampling. ################################################################################################### # - BIASED NBR SAMPLING tests --------------------------------------------------------------------- -ConfigureTest(BIASED_NEIGHBOR_SAMPLING_TEST sampling/biased_neighbor_sampling.cu) +ConfigureTest(BIASED_NEIGHBOR_SAMPLING_TEST sampling/biased_neighbor_sampling.cpp) ################################################################################################### # - SAMPLING_POST_PROCESSING tests ---------------------------------------------------------------- diff --git a/cpp/tests/centrality/katz_centrality_test.cpp b/cpp/tests/centrality/katz_centrality_test.cpp index 7c8a22221c0..190007c38e5 100644 --- a/cpp/tests/centrality/katz_centrality_test.cpp +++ b/cpp/tests/centrality/katz_centrality_test.cpp @@ -214,7 +214,8 @@ class Tests_KatzCentrality rmm::device_uvector d_unrenumbered_katz_centralities(size_t{0}, handle.get_stream()); std::tie(std::ignore, d_unrenumbered_katz_centralities) = - cugraph::test::sort_by_key(handle, *d_renumber_map_labels, d_katz_centralities); + cugraph::test::sort_by_key( + handle, *d_renumber_map_labels, d_katz_centralities); h_cugraph_katz_centralities = cugraph::test::to_host(handle, d_unrenumbered_katz_centralities); } else { diff --git a/cpp/tests/community/k_truss_test.cpp b/cpp/tests/community/k_truss_test.cpp index c8010422e42..424d52f2067 100644 --- a/cpp/tests/community/k_truss_test.cpp +++ b/cpp/tests/community/k_truss_test.cpp @@ -224,10 +224,11 @@ class Tests_KTruss : public ::testing::TestWithParam( + handle, d_cugraph_srcs, d_cugraph_dsts, *d_cugraph_wgts); } else { std::tie(d_sorted_cugraph_srcs, d_sorted_cugraph_dsts) = - cugraph::test::sort(handle, d_cugraph_srcs, d_cugraph_dsts); + cugraph::test::sort(handle, d_cugraph_srcs, d_cugraph_dsts); } auto h_cugraph_srcs = cugraph::test::to_host(handle, d_sorted_cugraph_srcs); diff --git a/cpp/tests/components/weakly_connected_components_test.cpp b/cpp/tests/components/weakly_connected_components_test.cpp index 2dd82316b00..7b909c6f594 100644 --- a/cpp/tests/components/weakly_connected_components_test.cpp +++ b/cpp/tests/components/weakly_connected_components_test.cpp @@ -170,7 +170,8 @@ class Tests_WeaklyConnectedComponent if (renumber) { rmm::device_uvector d_unrenumbered_components(size_t{0}, handle.get_stream()); std::tie(std::ignore, d_unrenumbered_components) = - cugraph::test::sort_by_key(handle, *d_renumber_map_labels, d_components); + cugraph::test::sort_by_key( + handle, *d_renumber_map_labels, d_components); h_cugraph_components = cugraph::test::to_host(handle, d_unrenumbered_components); } else { h_cugraph_components = cugraph::test::to_host(handle, d_components); diff --git a/cpp/tests/cores/core_number_test.cpp b/cpp/tests/cores/core_number_test.cpp index fb6f26278af..ca0174202c2 100644 --- a/cpp/tests/cores/core_number_test.cpp +++ b/cpp/tests/cores/core_number_test.cpp @@ -300,7 +300,8 @@ class Tests_CoreNumber if (renumber) { rmm::device_uvector d_unrenumbered_core_numbers(size_t{0}, handle.get_stream()); std::tie(std::ignore, d_unrenumbered_core_numbers) = - cugraph::test::sort_by_key(handle, *d_renumber_map_labels, d_core_numbers); + cugraph::test::sort_by_key( + handle, *d_renumber_map_labels, d_core_numbers); h_cugraph_core_numbers = cugraph::test::to_host(handle, d_unrenumbered_core_numbers); } else { h_cugraph_core_numbers = cugraph::test::to_host(handle, d_core_numbers); diff --git a/cpp/tests/link_analysis/hits_test.cpp b/cpp/tests/link_analysis/hits_test.cpp index f1b2a0ef0df..31ed5537a6b 100644 --- a/cpp/tests/link_analysis/hits_test.cpp +++ b/cpp/tests/link_analysis/hits_test.cpp @@ -255,7 +255,8 @@ class Tests_Hits : public ::testing::TestWithParam d_unrenumbered_initial_random_hubs(0, handle.get_stream()); std::tie(std::ignore, d_unrenumbered_initial_random_hubs) = - cugraph::test::sort_by_key(handle, *d_renumber_map_labels, *d_initial_random_hubs); + cugraph::test::sort_by_key( + handle, *d_renumber_map_labels, *d_initial_random_hubs); h_initial_random_hubs = cugraph::test::to_host(handle, d_unrenumbered_initial_random_hubs); } else { @@ -277,7 +278,7 @@ class Tests_Hits : public ::testing::TestWithParam d_unrenumbered_hubs(size_t{0}, handle.get_stream()); std::tie(std::ignore, d_unrenumbered_hubs) = - cugraph::test::sort_by_key(handle, *d_renumber_map_labels, d_hubs); + cugraph::test::sort_by_key(handle, *d_renumber_map_labels, d_hubs); h_cugraph_hits = cugraph::test::to_host(handle, d_unrenumbered_hubs); } else { h_cugraph_hits = cugraph::test::to_host(handle, d_hubs); diff --git a/cpp/tests/link_analysis/pagerank_test.cpp b/cpp/tests/link_analysis/pagerank_test.cpp index 9219832ac63..196476d6756 100644 --- a/cpp/tests/link_analysis/pagerank_test.cpp +++ b/cpp/tests/link_analysis/pagerank_test.cpp @@ -282,9 +282,9 @@ class Tests_PageRank vertex_t{0}, graph_view.number_of_vertices()); std::tie(d_unrenumbered_personalization_vertices, d_unrenumbered_personalization_values) = - cugraph::test::sort_by_key(handle, - d_unrenumbered_personalization_vertices, - d_unrenumbered_personalization_values); + cugraph::test::sort_by_key(handle, + d_unrenumbered_personalization_vertices, + d_unrenumbered_personalization_values); h_unrenumbered_personalization_vertices = cugraph::test::to_host(handle, d_unrenumbered_personalization_vertices); @@ -327,7 +327,8 @@ class Tests_PageRank if (renumber) { rmm::device_uvector d_unrenumbered_pageranks(size_t{0}, handle.get_stream()); std::tie(std::ignore, d_unrenumbered_pageranks) = - cugraph::test::sort_by_key(handle, *d_renumber_map_labels, d_pageranks); + cugraph::test::sort_by_key( + handle, *d_renumber_map_labels, d_pageranks); h_cugraph_pageranks = cugraph::test::to_host(handle, d_unrenumbered_pageranks); } else { h_cugraph_pageranks = cugraph::test::to_host(handle, d_pageranks); diff --git a/cpp/tests/sampling/biased_neighbor_sampling.cu b/cpp/tests/sampling/biased_neighbor_sampling.cpp similarity index 100% rename from cpp/tests/sampling/biased_neighbor_sampling.cu rename to cpp/tests/sampling/biased_neighbor_sampling.cpp diff --git a/cpp/tests/traversal/bfs_test.cpp b/cpp/tests/traversal/bfs_test.cpp index fda80f1c191..8d3cdb3d24b 100644 --- a/cpp/tests/traversal/bfs_test.cpp +++ b/cpp/tests/traversal/bfs_test.cpp @@ -206,10 +206,12 @@ class Tests_BFS : public ::testing::TestWithParam d_unrenumbered_distances(size_t{0}, handle.get_stream()); std::tie(std::ignore, d_unrenumbered_distances) = - cugraph::test::sort_by_key(handle, *d_renumber_map_labels, d_distances); + cugraph::test::sort_by_key( + handle, *d_renumber_map_labels, d_distances); rmm::device_uvector d_unrenumbered_predecessors(size_t{0}, handle.get_stream()); std::tie(std::ignore, d_unrenumbered_predecessors) = - cugraph::test::sort_by_key(handle, *d_renumber_map_labels, d_predecessors); + cugraph::test::sort_by_key( + handle, *d_renumber_map_labels, d_predecessors); h_cugraph_distances = cugraph::test::to_host(handle, d_unrenumbered_distances); h_cugraph_predecessors = cugraph::test::to_host(handle, d_unrenumbered_predecessors); } else { diff --git a/cpp/tests/traversal/sssp_test.cpp b/cpp/tests/traversal/sssp_test.cpp index ee236e72cdc..3eff1a8e106 100644 --- a/cpp/tests/traversal/sssp_test.cpp +++ b/cpp/tests/traversal/sssp_test.cpp @@ -206,10 +206,12 @@ class Tests_SSSP : public ::testing::TestWithParam d_unrenumbered_distances(size_t{0}, handle.get_stream()); std::tie(std::ignore, d_unrenumbered_distances) = - cugraph::test::sort_by_key(handle, *d_renumber_map_labels, d_distances); + cugraph::test::sort_by_key( + handle, *d_renumber_map_labels, d_distances); rmm::device_uvector d_unrenumbered_predecessors(size_t{0}, handle.get_stream()); std::tie(std::ignore, d_unrenumbered_predecessors) = - cugraph::test::sort_by_key(handle, *d_renumber_map_labels, d_predecessors); + cugraph::test::sort_by_key( + handle, *d_renumber_map_labels, d_predecessors); h_cugraph_distances = cugraph::test::to_host(handle, d_unrenumbered_distances); h_cugraph_predecessors = cugraph::test::to_host(handle, d_unrenumbered_predecessors); diff --git a/cpp/tests/utilities/conversion_utilities_impl.cuh b/cpp/tests/utilities/conversion_utilities_impl.cuh index 6eb7357eedd..b930d08d7d8 100644 --- a/cpp/tests/utilities/conversion_utilities_impl.cuh +++ b/cpp/tests/utilities/conversion_utilities_impl.cuh @@ -415,7 +415,8 @@ mg_vertex_property_values_to_sg_vertex_property_values( static_cast((*sg_renumber_map).size())); } - std::tie(sg_vertices, sg_values) = cugraph::test::sort_by_key(handle, sg_vertices, sg_values); + std::tie(sg_vertices, sg_values) = cugraph::test::sort_by_key( + handle, std::move(sg_vertices), std::move(sg_values)); if (mg_vertices) { return std::make_tuple(std::move(sg_vertices), std::move(sg_values)); diff --git a/cpp/tests/utilities/thrust_wrapper.cu b/cpp/tests/utilities/thrust_wrapper.cu index 93bb8a04e87..2ecfee6d252 100644 --- a/cpp/tests/utilities/thrust_wrapper.cu +++ b/cpp/tests/utilities/thrust_wrapper.cu @@ -16,10 +16,6 @@ #include "utilities/thrust_wrapper.hpp" -#include - -#include - #include #include @@ -31,6 +27,7 @@ #include #include #include +#include #include #include #include @@ -38,12 +35,12 @@ namespace cugraph { namespace test { -template -value_buffer_type sort(raft::handle_t const& handle, value_buffer_type const& values) +template +cugraph::dataframe_buffer_type_t sort( + raft::handle_t const& handle, cugraph::dataframe_buffer_type_t const& values) { - auto sorted_values = - cugraph::allocate_dataframe_buffer>( - values.size(), handle.get_stream()); + auto sorted_values = cugraph::allocate_dataframe_buffer( + cugraph::size_dataframe_buffer(values), handle.get_stream()); thrust::copy(handle.get_thrust_policy(), cugraph::get_dataframe_buffer_begin(values), @@ -57,76 +54,85 @@ value_buffer_type sort(raft::handle_t const& handle, value_buffer_type const& va return sorted_values; } -template -std::tuple sort(raft::handle_t const& handle, - value_buffer_type const& first, - value_buffer_type const& second) +template rmm::device_uvector sort(raft::handle_t const& handle, + rmm::device_uvector const& values); + +template rmm::device_uvector sort(raft::handle_t const& handle, + rmm::device_uvector const& values); + +template +cugraph::dataframe_buffer_type_t sort(raft::handle_t const& handle, + cugraph::dataframe_buffer_type_t&& values) +{ + auto sorted_values = std::move(values); + + thrust::sort(handle.get_thrust_policy(), + cugraph::get_dataframe_buffer_begin(sorted_values), + cugraph::get_dataframe_buffer_end(sorted_values)); + + return sorted_values; +} + +template rmm::device_uvector sort(raft::handle_t const& handle, + rmm::device_uvector&& values); + +template rmm::device_uvector sort(raft::handle_t const& handle, + rmm::device_uvector&& values); + +template +std::tuple, cugraph::dataframe_buffer_type_t> +sort(raft::handle_t const& handle, + cugraph::dataframe_buffer_type_t const& first, + cugraph::dataframe_buffer_type_t const& second) { auto sorted_first = - cugraph::allocate_dataframe_buffer>( - first.size(), handle.get_stream()); + cugraph::allocate_dataframe_buffer(size_dataframe_buffer(first), handle.get_stream()); auto sorted_second = - cugraph::allocate_dataframe_buffer>( - first.size(), handle.get_stream()); - - auto execution_policy = handle.get_thrust_policy(); - thrust::copy(execution_policy, - cugraph::get_dataframe_buffer_begin(first), - cugraph::get_dataframe_buffer_end(first), - cugraph::get_dataframe_buffer_begin(sorted_first)); - thrust::copy(execution_policy, - cugraph::get_dataframe_buffer_begin(second), - cugraph::get_dataframe_buffer_end(second), - cugraph::get_dataframe_buffer_begin(sorted_second)); + cugraph::allocate_dataframe_buffer(size_dataframe_buffer(first), handle.get_stream()); + + auto input_first = thrust::make_zip_iterator(cugraph::get_dataframe_buffer_begin(first), + cugraph::get_dataframe_buffer_begin(second)); + auto output_first = thrust::make_zip_iterator(cugraph::get_dataframe_buffer_begin(sorted_first), + cugraph::get_dataframe_buffer_begin(sorted_second)); + thrust::copy(handle.get_thrust_policy(), + input_first, + input_first + size_dataframe_buffer(first), + output_first); thrust::sort( - execution_policy, - thrust::make_zip_iterator(cugraph::get_dataframe_buffer_begin(sorted_first), - cugraph::get_dataframe_buffer_begin(sorted_second)), - thrust::make_zip_iterator(cugraph::get_dataframe_buffer_begin(sorted_first) + first.size(), - cugraph::get_dataframe_buffer_begin(sorted_second) + first.size())); + handle.get_thrust_policy(), output_first, output_first + size_dataframe_buffer(sorted_first)); return std::make_tuple(std::move(sorted_first), std::move(sorted_second)); } -template rmm::device_uvector sort(raft::handle_t const& handle, - rmm::device_uvector const& values); - -template rmm::device_uvector sort(raft::handle_t const& handle, - rmm::device_uvector const& values); - -template std::tuple, rmm::device_uvector> sort( +template std::tuple, rmm::device_uvector> sort( raft::handle_t const& handle, rmm::device_uvector const& first, rmm::device_uvector const& second); -template std::tuple, rmm::device_uvector> sort( +template std::tuple, rmm::device_uvector> sort( raft::handle_t const& handle, rmm::device_uvector const& first, rmm::device_uvector const& second); -template -std::tuple sort_by_key(raft::handle_t const& handle, - key_buffer_type const& keys, - value_buffer_type const& values) +template +std::tuple, cugraph::dataframe_buffer_type_t> +sort_by_key(raft::handle_t const& handle, + cugraph::dataframe_buffer_type_t const& keys, + cugraph::dataframe_buffer_type_t const& values) { auto sorted_keys = - cugraph::allocate_dataframe_buffer>( - keys.size(), handle.get_stream()); + cugraph::allocate_dataframe_buffer(size_dataframe_buffer(keys), handle.get_stream()); auto sorted_values = - cugraph::allocate_dataframe_buffer>( - keys.size(), handle.get_stream()); - - auto execution_policy = handle.get_thrust_policy(); - thrust::copy(execution_policy, - cugraph::get_dataframe_buffer_begin(keys), - cugraph::get_dataframe_buffer_end(keys), - cugraph::get_dataframe_buffer_begin(sorted_keys)); - thrust::copy(execution_policy, - cugraph::get_dataframe_buffer_begin(values), - cugraph::get_dataframe_buffer_end(values), - cugraph::get_dataframe_buffer_begin(sorted_values)); + cugraph::allocate_dataframe_buffer(size_dataframe_buffer(keys), handle.get_stream()); - thrust::sort_by_key(execution_policy, + auto input_first = thrust::make_zip_iterator(cugraph::get_dataframe_buffer_begin(keys), + cugraph::get_dataframe_buffer_begin(values)); + thrust::copy(handle.get_thrust_policy(), + input_first, + input_first + size_dataframe_buffer(keys), + thrust::make_zip_iterator(cugraph::get_dataframe_buffer_begin(sorted_keys), + cugraph::get_dataframe_buffer_begin(sorted_values))); + thrust::sort_by_key(handle.get_thrust_policy(), cugraph::get_dataframe_buffer_begin(sorted_keys), cugraph::get_dataframe_buffer_end(sorted_keys), cugraph::get_dataframe_buffer_begin(sorted_values)); @@ -134,93 +140,179 @@ std::tuple sort_by_key(raft::handle_t const& return std::make_tuple(std::move(sorted_keys), std::move(sorted_values)); } -template std::tuple, rmm::device_uvector> sort_by_key( - raft::handle_t const& handle, - rmm::device_uvector const& keys, - rmm::device_uvector const& values); +template std::tuple, rmm::device_uvector> +sort_by_key(raft::handle_t const& handle, + rmm::device_uvector const& keys, + rmm::device_uvector const& values); + +template std::tuple, rmm::device_uvector> +sort_by_key(raft::handle_t const& handle, + rmm::device_uvector const& keys, + rmm::device_uvector const& values); + +template std::tuple, rmm::device_uvector> +sort_by_key(raft::handle_t const& handle, + rmm::device_uvector const& keys, + rmm::device_uvector const& values); + +template std::tuple, rmm::device_uvector> +sort_by_key(raft::handle_t const& handle, + rmm::device_uvector const& keys, + rmm::device_uvector const& values); + +template std::tuple, rmm::device_uvector> +sort_by_key(raft::handle_t const& handle, + rmm::device_uvector const& keys, + rmm::device_uvector const& values); + +template std::tuple, rmm::device_uvector> +sort_by_key(raft::handle_t const& handle, + rmm::device_uvector const& keys, + rmm::device_uvector const& values); + +template std::tuple, rmm::device_uvector> +sort_by_key(raft::handle_t const& handle, + rmm::device_uvector const& keys, + rmm::device_uvector const& values); + +template std::tuple, rmm::device_uvector> +sort_by_key(raft::handle_t const& handle, + rmm::device_uvector const& keys, + rmm::device_uvector const& values); + +template std::tuple, rmm::device_uvector> +sort_by_key(raft::handle_t const& handle, + rmm::device_uvector const& keys, + rmm::device_uvector const& values); + +template std::tuple, rmm::device_uvector> +sort_by_key(raft::handle_t const& handle, + rmm::device_uvector const& keys, + rmm::device_uvector const& values); -template std::tuple, rmm::device_uvector> sort_by_key( - raft::handle_t const& handle, - rmm::device_uvector const& keys, - rmm::device_uvector const& values); - -template std::tuple, rmm::device_uvector> sort_by_key( - raft::handle_t const& handle, - rmm::device_uvector const& keys, - rmm::device_uvector const& values); - -template std::tuple, rmm::device_uvector> sort_by_key( - raft::handle_t const& handle, - rmm::device_uvector const& keys, - rmm::device_uvector const& values); - -template std::tuple, rmm::device_uvector> sort_by_key( - raft::handle_t const& handle, - rmm::device_uvector const& keys, - rmm::device_uvector const& values); - -template std::tuple, rmm::device_uvector> sort_by_key( +template std::tuple, + std::tuple, rmm::device_uvector>> +sort_by_key>( raft::handle_t const& handle, rmm::device_uvector const& keys, - rmm::device_uvector const& values); + std::tuple, rmm::device_uvector> const& values); -template std::tuple, rmm::device_uvector> sort_by_key( +template std::tuple, + std::tuple, rmm::device_uvector>> +sort_by_key>( raft::handle_t const& handle, rmm::device_uvector const& keys, - rmm::device_uvector const& values); + std::tuple, rmm::device_uvector> const& values); -template std::tuple, rmm::device_uvector> sort_by_key( - raft::handle_t const& handle, - rmm::device_uvector const& keys, - rmm::device_uvector const& values); +template +std::tuple, cugraph::dataframe_buffer_type_t> +sort_by_key(raft::handle_t const& handle, + cugraph::dataframe_buffer_type_t&& keys, + cugraph::dataframe_buffer_type_t&& values) +{ + auto sorted_keys = std::move(keys); + auto sorted_values = std::move(values); -template std::tuple, rmm::device_uvector> sort_by_key( - raft::handle_t const& handle, - rmm::device_uvector const& keys, - rmm::device_uvector const& values); + thrust::sort_by_key(handle.get_thrust_policy(), + cugraph::get_dataframe_buffer_begin(sorted_keys), + cugraph::get_dataframe_buffer_end(sorted_keys), + cugraph::get_dataframe_buffer_begin(sorted_values)); -template std::tuple, rmm::device_uvector> sort_by_key( - raft::handle_t const& handle, - rmm::device_uvector const& keys, - rmm::device_uvector const& values); + return std::make_tuple(std::move(sorted_keys), std::move(sorted_values)); +} -template -std::tuple sort_by_key( - raft::handle_t const& handle, - key_buffer_type const& first, - key_buffer_type const& second, - value_buffer_type const& values) +template std::tuple, rmm::device_uvector> +sort_by_key(raft::handle_t const& handle, + rmm::device_uvector&& keys, + rmm::device_uvector&& values); + +template std::tuple, rmm::device_uvector> +sort_by_key(raft::handle_t const& handle, + rmm::device_uvector&& keys, + rmm::device_uvector&& values); + +template std::tuple, rmm::device_uvector> +sort_by_key(raft::handle_t const& handle, + rmm::device_uvector&& keys, + rmm::device_uvector&& values); + +template std::tuple, rmm::device_uvector> +sort_by_key(raft::handle_t const& handle, + rmm::device_uvector&& keys, + rmm::device_uvector&& values); + +template std::tuple, rmm::device_uvector> +sort_by_key(raft::handle_t const& handle, + rmm::device_uvector&& keys, + rmm::device_uvector&& values); + +template std::tuple, rmm::device_uvector> +sort_by_key(raft::handle_t const& handle, + rmm::device_uvector&& keys, + rmm::device_uvector&& values); + +template std::tuple, rmm::device_uvector> +sort_by_key(raft::handle_t const& handle, + rmm::device_uvector&& keys, + rmm::device_uvector&& values); + +template std::tuple, rmm::device_uvector> +sort_by_key(raft::handle_t const& handle, + rmm::device_uvector&& keys, + rmm::device_uvector&& values); + +template std::tuple, rmm::device_uvector> +sort_by_key(raft::handle_t const& handle, + rmm::device_uvector&& keys, + rmm::device_uvector&& values); + +template std::tuple, rmm::device_uvector> +sort_by_key(raft::handle_t const& handle, + rmm::device_uvector&& keys, + rmm::device_uvector&& values); + +template std::tuple, rmm::device_uvector> +sort_by_key(raft::handle_t const& handle, + rmm::device_uvector&& keys, + rmm::device_uvector&& values); + +template std::tuple, rmm::device_uvector> +sort_by_key(raft::handle_t const& handle, + rmm::device_uvector&& keys, + rmm::device_uvector&& values); + +template +std::tuple, + cugraph::dataframe_buffer_type_t, + cugraph::dataframe_buffer_type_t> +sort_by_key(raft::handle_t const& handle, + cugraph::dataframe_buffer_type_t const& first, + cugraph::dataframe_buffer_type_t const& second, + cugraph::dataframe_buffer_type_t const& values) { - auto sorted_first = - cugraph::allocate_dataframe_buffer>( - first.size(), handle.get_stream()); - auto sorted_second = - cugraph::allocate_dataframe_buffer>( - first.size(), handle.get_stream()); - auto sorted_values = - cugraph::allocate_dataframe_buffer>( - first.size(), handle.get_stream()); - - auto execution_policy = handle.get_thrust_policy(); - thrust::copy(execution_policy, - cugraph::get_dataframe_buffer_begin(first), - cugraph::get_dataframe_buffer_end(first), - cugraph::get_dataframe_buffer_begin(sorted_first)); - thrust::copy(execution_policy, - cugraph::get_dataframe_buffer_begin(second), - cugraph::get_dataframe_buffer_end(second), - cugraph::get_dataframe_buffer_begin(sorted_second)); - thrust::copy(execution_policy, - cugraph::get_dataframe_buffer_begin(values), - cugraph::get_dataframe_buffer_end(values), - cugraph::get_dataframe_buffer_begin(sorted_values)); - thrust::sort_by_key( - execution_policy, + auto sorted_first = cugraph::allocate_dataframe_buffer( + cugraph::size_dataframe_buffer(first), handle.get_stream()); + auto sorted_second = cugraph::allocate_dataframe_buffer( + cugraph::size_dataframe_buffer(first), handle.get_stream()); + auto sorted_values = cugraph::allocate_dataframe_buffer( + cugraph::size_dataframe_buffer(first), handle.get_stream()); + + auto input_first = thrust::make_zip_iterator(cugraph::get_dataframe_buffer_begin(first), + cugraph::get_dataframe_buffer_begin(second), + cugraph::get_dataframe_buffer_begin(values)); + thrust::copy(handle.get_thrust_policy(), + input_first, + input_first + size_dataframe_buffer(first), + thrust::make_zip_iterator(cugraph::get_dataframe_buffer_begin(sorted_first), + cugraph::get_dataframe_buffer_begin(sorted_second), + cugraph::get_dataframe_buffer_begin(sorted_values))); + auto sorted_key_first = thrust::make_zip_iterator(cugraph::get_dataframe_buffer_begin(sorted_first), - cugraph::get_dataframe_buffer_begin(sorted_second)), - thrust::make_zip_iterator(cugraph::get_dataframe_buffer_begin(sorted_first) + first.size(), - cugraph::get_dataframe_buffer_begin(sorted_second) + first.size()), - cugraph::get_dataframe_buffer_begin(sorted_values)); + cugraph::get_dataframe_buffer_begin(sorted_second)); + thrust::sort_by_key(handle.get_thrust_policy(), + sorted_key_first, + sorted_key_first + cugraph::size_dataframe_buffer(sorted_first), + cugraph::get_dataframe_buffer_begin(sorted_values)); return std::make_tuple( std::move(sorted_first), std::move(sorted_second), std::move(sorted_values)); @@ -228,43 +320,84 @@ std::tuple sort_by_key( template std:: tuple, rmm::device_uvector, rmm::device_uvector> - sort_by_key(raft::handle_t const& handle, - rmm::device_uvector const& first, - rmm::device_uvector const& second, - rmm::device_uvector const& values); + sort_by_key(raft::handle_t const& handle, + rmm::device_uvector const& first, + rmm::device_uvector const& second, + rmm::device_uvector const& values); template std:: tuple, rmm::device_uvector, rmm::device_uvector> - sort_by_key(raft::handle_t const& handle, - rmm::device_uvector const& first, - rmm::device_uvector const& second, - rmm::device_uvector const& values); + sort_by_key(raft::handle_t const& handle, + rmm::device_uvector const& first, + rmm::device_uvector const& second, + rmm::device_uvector const& values); template std:: tuple, rmm::device_uvector, rmm::device_uvector> - sort_by_key(raft::handle_t const& handle, - rmm::device_uvector const& first, - rmm::device_uvector const& second, - rmm::device_uvector const& values); + sort_by_key(raft::handle_t const& handle, + rmm::device_uvector const& first, + rmm::device_uvector const& second, + rmm::device_uvector const& values); template std:: tuple, rmm::device_uvector, rmm::device_uvector> - sort_by_key(raft::handle_t const& handle, - rmm::device_uvector const& first, - rmm::device_uvector const& second, - rmm::device_uvector const& values); + sort_by_key(raft::handle_t const& handle, + rmm::device_uvector const& first, + rmm::device_uvector const& second, + rmm::device_uvector const& values); + +template +cugraph::dataframe_buffer_type_t unique(raft::handle_t const& handle, + cugraph::dataframe_buffer_type_t&& values) +{ + auto last = thrust::unique(handle.get_thrust_policy(), + cugraph::get_dataframe_buffer_begin(values), + cugraph::get_dataframe_buffer_end(values)); + cugraph::resize_dataframe_buffer( + values, + thrust::distance(cugraph::get_dataframe_buffer_begin(values), last), + handle.get_stream()); + cugraph::shrink_to_fit_dataframe_buffer(values, handle.get_stream()); + + return std::move(values); +} -template std::tuple, - std::tuple, rmm::device_uvector>> -sort_by_key(raft::handle_t const& handle, - rmm::device_uvector const& keys, - std::tuple, rmm::device_uvector> const& values); +template rmm::device_uvector unique(raft::handle_t const& handle, + rmm::device_uvector&& values); -template std::tuple, - std::tuple, rmm::device_uvector>> -sort_by_key(raft::handle_t const& handle, - rmm::device_uvector const& keys, - std::tuple, rmm::device_uvector> const& values); +template rmm::device_uvector unique(raft::handle_t const& handle, + rmm::device_uvector&& values); + +template +cugraph::dataframe_buffer_type_t sequence(raft::handle_t const& handle, + size_t length, + size_t repeat_count, + value_t init) +{ + auto values = cugraph::allocate_dataframe_buffer(length, handle.get_stream()); + if (repeat_count == 1) { + thrust::sequence(handle.get_thrust_policy(), values.begin(), values.end(), init); + } else { + thrust::tabulate(handle.get_thrust_policy(), + values.begin(), + values.end(), + [repeat_count, init] __device__(size_t i) { + return init + static_cast(i / repeat_count); + }); + } + + return values; +} + +template rmm::device_uvector sequence(raft::handle_t const& handle, + size_t length, + size_t repeat_count, + int32_t init); + +template rmm::device_uvector sequence(raft::handle_t const& handle, + size_t length, + size_t repeat_count, + int64_t init); template vertex_t max_element(raft::handle_t const& handle, raft::device_span vertices) diff --git a/cpp/tests/utilities/thrust_wrapper.hpp b/cpp/tests/utilities/thrust_wrapper.hpp index c4b87126f50..6454b190a3f 100644 --- a/cpp/tests/utilities/thrust_wrapper.hpp +++ b/cpp/tests/utilities/thrust_wrapper.hpp @@ -15,6 +15,8 @@ */ #pragma once +#include + #include #include @@ -26,25 +28,50 @@ namespace cugraph { namespace test { -template -value_buffer_type sort(raft::handle_t const& handle, value_buffer_type const& values); - -template -std::tuple sort(raft::handle_t const& handle, - value_buffer_type const& first, - value_buffer_type const& second); - -template -std::tuple sort_by_key(raft::handle_t const& handle, - key_buffer_type const& keys, - value_buffer_type const& values); - -template -std::tuple sort_by_key( - raft::handle_t const& handle, - key_buffer_type const& first, - key_buffer_type const& second, - value_buffer_type const& values); +template +cugraph::dataframe_buffer_type_t sort( + raft::handle_t const& handle, cugraph::dataframe_buffer_type_t const& values); + +template +cugraph::dataframe_buffer_type_t sort(raft::handle_t const& handle, + cugraph::dataframe_buffer_type_t&& values); + +template +std::tuple, cugraph::dataframe_buffer_type_t> +sort(raft::handle_t const& handle, + cugraph::dataframe_buffer_type_t const& first, + cugraph::dataframe_buffer_type_t const& second); + +template +std::tuple, cugraph::dataframe_buffer_type_t> +sort_by_key(raft::handle_t const& handle, + cugraph::dataframe_buffer_type_t const& keys, + cugraph::dataframe_buffer_type_t const& values); + +template +std::tuple, cugraph::dataframe_buffer_type_t> +sort_by_key(raft::handle_t const& handle, + cugraph::dataframe_buffer_type_t&& keys, + cugraph::dataframe_buffer_type_t&& values); + +template +std::tuple, + cugraph::dataframe_buffer_type_t, + cugraph::dataframe_buffer_type_t> +sort_by_key(raft::handle_t const& handle, + cugraph::dataframe_buffer_type_t const& first, + cugraph::dataframe_buffer_type_t const& second, + cugraph::dataframe_buffer_type_t const& values); + +template +cugraph::dataframe_buffer_type_t unique( + raft::handle_t const& handle, cugraph::dataframe_buffer_type_t&& values); + +template +cugraph::dataframe_buffer_type_t sequence(raft::handle_t const& handle, + size_t length, + size_t repeat_count, + value_t init); template vertex_t max_element(raft::handle_t const& handle, raft::device_span vertices);