Skip to content

Commit

Permalink
update thrust_wrapper to include more wrapper functions
Browse files Browse the repository at this point in the history
  • Loading branch information
seunghwak committed May 28, 2024
1 parent fcdb1b4 commit f8a5321
Show file tree
Hide file tree
Showing 14 changed files with 393 additions and 214 deletions.
44 changes: 26 additions & 18 deletions cpp/include/cugraph/utilities/dataframe_buffer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,24 +68,6 @@ auto get_dataframe_buffer_cend_tuple_impl(std::index_sequence<I...>, TupleType&

} // namespace detail

template <typename T>
struct dataframe_element {
using type = void;
};

template <typename... Ts>
struct dataframe_element<std::tuple<rmm::device_uvector<Ts>...>> {
using type = thrust::tuple<Ts...>;
};

template <typename T>
struct dataframe_element<rmm::device_uvector<T>> {
using type = T;
};

template <typename DataframeType>
using dataframe_element_t = typename dataframe_element<DataframeType>::type;

template <typename T, typename std::enable_if_t<std::is_arithmetic<T>::value>* = nullptr>
auto allocate_dataframe_buffer(size_t buffer_size, rmm::cuda_stream_view stream_view)
{
Expand Down Expand Up @@ -224,4 +206,30 @@ auto get_dataframe_buffer_cend(BufferType& buffer)
std::make_index_sequence<std::tuple_size<BufferType>::value>(), buffer);
}

template <typename T>
struct dataframe_buffer_value_type {
using type = void;
};

template <typename T>
struct dataframe_buffer_value_type<rmm::device_uvector<T>> {
using type = T;
};

template <typename... Ts>
struct dataframe_buffer_value_type<std::tuple<rmm::device_uvector<Ts>...>> {
using type = thrust::tuple<Ts...>;
};

template <typename BufferType>
using dataframe_buffer_value_type_t = typename dataframe_buffer_value_type<BufferType>::type;

template <typename T>
struct dataframe_buffer_type {
using type = decltype(allocate_dataframe_buffer<T>(size_t{0}, rmm::cuda_stream_view{}));
};

template <typename T>
using dataframe_buffer_type_t = typename dataframe_buffer_type<T>::type;

} // namespace cugraph
2 changes: 1 addition & 1 deletion cpp/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -468,7 +468,7 @@ ConfigureTest(UNIFORM_NEIGHBOR_SAMPLING_TEST sampling/uniform_neighbor_sampling.

###################################################################################################
# - BIASED NBR SAMPLING tests ---------------------------------------------------------------------
ConfigureTest(BIASED_NEIGHBOR_SAMPLING_TEST sampling/biased_neighbor_sampling.cu)
ConfigureTest(BIASED_NEIGHBOR_SAMPLING_TEST sampling/biased_neighbor_sampling.cpp)

###################################################################################################
# - SAMPLING_POST_PROCESSING tests ----------------------------------------------------------------
Expand Down
3 changes: 2 additions & 1 deletion cpp/tests/centrality/katz_centrality_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,8 @@ class Tests_KatzCentrality
rmm::device_uvector<result_t> d_unrenumbered_katz_centralities(size_t{0},
handle.get_stream());
std::tie(std::ignore, d_unrenumbered_katz_centralities) =
cugraph::test::sort_by_key(handle, *d_renumber_map_labels, d_katz_centralities);
cugraph::test::sort_by_key<vertex_t, result_t>(
handle, *d_renumber_map_labels, d_katz_centralities);
h_cugraph_katz_centralities =
cugraph::test::to_host(handle, d_unrenumbered_katz_centralities);
} else {
Expand Down
5 changes: 3 additions & 2 deletions cpp/tests/community/k_truss_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -224,10 +224,11 @@ class Tests_KTruss : public ::testing::TestWithParam<std::tuple<KTruss_Usecase,

if (edge_weight) {
std::tie(d_sorted_cugraph_srcs, d_sorted_cugraph_dsts, d_sorted_cugraph_wgts) =
cugraph::test::sort_by_key(handle, d_cugraph_srcs, d_cugraph_dsts, *d_cugraph_wgts);
cugraph::test::sort_by_key<vertex_t, weight_t>(
handle, d_cugraph_srcs, d_cugraph_dsts, *d_cugraph_wgts);
} else {
std::tie(d_sorted_cugraph_srcs, d_sorted_cugraph_dsts) =
cugraph::test::sort(handle, d_cugraph_srcs, d_cugraph_dsts);
cugraph::test::sort<vertex_t>(handle, d_cugraph_srcs, d_cugraph_dsts);
}

auto h_cugraph_srcs = cugraph::test::to_host(handle, d_sorted_cugraph_srcs);
Expand Down
3 changes: 2 additions & 1 deletion cpp/tests/components/weakly_connected_components_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,8 @@ class Tests_WeaklyConnectedComponent
if (renumber) {
rmm::device_uvector<vertex_t> d_unrenumbered_components(size_t{0}, handle.get_stream());
std::tie(std::ignore, d_unrenumbered_components) =
cugraph::test::sort_by_key(handle, *d_renumber_map_labels, d_components);
cugraph::test::sort_by_key<vertex_t, vertex_t>(
handle, *d_renumber_map_labels, d_components);
h_cugraph_components = cugraph::test::to_host(handle, d_unrenumbered_components);
} else {
h_cugraph_components = cugraph::test::to_host(handle, d_components);
Expand Down
3 changes: 2 additions & 1 deletion cpp/tests/cores/core_number_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,8 @@ class Tests_CoreNumber
if (renumber) {
rmm::device_uvector<edge_t> d_unrenumbered_core_numbers(size_t{0}, handle.get_stream());
std::tie(std::ignore, d_unrenumbered_core_numbers) =
cugraph::test::sort_by_key(handle, *d_renumber_map_labels, d_core_numbers);
cugraph::test::sort_by_key<vertex_t, edge_t>(
handle, *d_renumber_map_labels, d_core_numbers);
h_cugraph_core_numbers = cugraph::test::to_host(handle, d_unrenumbered_core_numbers);
} else {
h_cugraph_core_numbers = cugraph::test::to_host(handle, d_core_numbers);
Expand Down
5 changes: 3 additions & 2 deletions cpp/tests/link_analysis/hits_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,8 @@ class Tests_Hits : public ::testing::TestWithParam<std::tuple<Hits_Usecase, inpu
if (renumber) {
rmm::device_uvector<weight_t> d_unrenumbered_initial_random_hubs(0, handle.get_stream());
std::tie(std::ignore, d_unrenumbered_initial_random_hubs) =
cugraph::test::sort_by_key(handle, *d_renumber_map_labels, *d_initial_random_hubs);
cugraph::test::sort_by_key<vertex_t, weight_t>(
handle, *d_renumber_map_labels, *d_initial_random_hubs);
h_initial_random_hubs =
cugraph::test::to_host(handle, d_unrenumbered_initial_random_hubs);
} else {
Expand All @@ -277,7 +278,7 @@ class Tests_Hits : public ::testing::TestWithParam<std::tuple<Hits_Usecase, inpu
if (renumber) {
rmm::device_uvector<weight_t> d_unrenumbered_hubs(size_t{0}, handle.get_stream());
std::tie(std::ignore, d_unrenumbered_hubs) =
cugraph::test::sort_by_key(handle, *d_renumber_map_labels, d_hubs);
cugraph::test::sort_by_key<vertex_t, weight_t>(handle, *d_renumber_map_labels, d_hubs);
h_cugraph_hits = cugraph::test::to_host(handle, d_unrenumbered_hubs);
} else {
h_cugraph_hits = cugraph::test::to_host(handle, d_hubs);
Expand Down
9 changes: 5 additions & 4 deletions cpp/tests/link_analysis/pagerank_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -282,9 +282,9 @@ class Tests_PageRank
vertex_t{0},
graph_view.number_of_vertices());
std::tie(d_unrenumbered_personalization_vertices, d_unrenumbered_personalization_values) =
cugraph::test::sort_by_key(handle,
d_unrenumbered_personalization_vertices,
d_unrenumbered_personalization_values);
cugraph::test::sort_by_key<vertex_t, result_t>(handle,
d_unrenumbered_personalization_vertices,
d_unrenumbered_personalization_values);

h_unrenumbered_personalization_vertices =
cugraph::test::to_host(handle, d_unrenumbered_personalization_vertices);
Expand Down Expand Up @@ -327,7 +327,8 @@ class Tests_PageRank
if (renumber) {
rmm::device_uvector<result_t> d_unrenumbered_pageranks(size_t{0}, handle.get_stream());
std::tie(std::ignore, d_unrenumbered_pageranks) =
cugraph::test::sort_by_key(handle, *d_renumber_map_labels, d_pageranks);
cugraph::test::sort_by_key<vertex_t, result_t>(
handle, *d_renumber_map_labels, d_pageranks);
h_cugraph_pageranks = cugraph::test::to_host(handle, d_unrenumbered_pageranks);
} else {
h_cugraph_pageranks = cugraph::test::to_host(handle, d_pageranks);
Expand Down
6 changes: 4 additions & 2 deletions cpp/tests/traversal/bfs_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -206,10 +206,12 @@ class Tests_BFS : public ::testing::TestWithParam<std::tuple<BFS_Usecase, input_

rmm::device_uvector<vertex_t> d_unrenumbered_distances(size_t{0}, handle.get_stream());
std::tie(std::ignore, d_unrenumbered_distances) =
cugraph::test::sort_by_key(handle, *d_renumber_map_labels, d_distances);
cugraph::test::sort_by_key<vertex_t, vertex_t>(
handle, *d_renumber_map_labels, d_distances);
rmm::device_uvector<vertex_t> d_unrenumbered_predecessors(size_t{0}, handle.get_stream());
std::tie(std::ignore, d_unrenumbered_predecessors) =
cugraph::test::sort_by_key(handle, *d_renumber_map_labels, d_predecessors);
cugraph::test::sort_by_key<vertex_t, vertex_t>(
handle, *d_renumber_map_labels, d_predecessors);
h_cugraph_distances = cugraph::test::to_host(handle, d_unrenumbered_distances);
h_cugraph_predecessors = cugraph::test::to_host(handle, d_unrenumbered_predecessors);
} else {
Expand Down
6 changes: 4 additions & 2 deletions cpp/tests/traversal/sssp_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -206,10 +206,12 @@ class Tests_SSSP : public ::testing::TestWithParam<std::tuple<SSSP_Usecase, inpu

rmm::device_uvector<weight_t> d_unrenumbered_distances(size_t{0}, handle.get_stream());
std::tie(std::ignore, d_unrenumbered_distances) =
cugraph::test::sort_by_key(handle, *d_renumber_map_labels, d_distances);
cugraph::test::sort_by_key<vertex_t, weight_t>(
handle, *d_renumber_map_labels, d_distances);
rmm::device_uvector<vertex_t> d_unrenumbered_predecessors(size_t{0}, handle.get_stream());
std::tie(std::ignore, d_unrenumbered_predecessors) =
cugraph::test::sort_by_key(handle, *d_renumber_map_labels, d_predecessors);
cugraph::test::sort_by_key<vertex_t, vertex_t>(
handle, *d_renumber_map_labels, d_predecessors);

h_cugraph_distances = cugraph::test::to_host(handle, d_unrenumbered_distances);
h_cugraph_predecessors = cugraph::test::to_host(handle, d_unrenumbered_predecessors);
Expand Down
3 changes: 2 additions & 1 deletion cpp/tests/utilities/conversion_utilities_impl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -415,7 +415,8 @@ mg_vertex_property_values_to_sg_vertex_property_values(
static_cast<vertex_t>((*sg_renumber_map).size()));
}

std::tie(sg_vertices, sg_values) = cugraph::test::sort_by_key(handle, sg_vertices, sg_values);
std::tie(sg_vertices, sg_values) = cugraph::test::sort_by_key<vertex_t, value_t>(
handle, std::move(sg_vertices), std::move(sg_values));

if (mg_vertices) {
return std::make_tuple(std::move(sg_vertices), std::move(sg_values));
Expand Down
Loading

0 comments on commit f8a5321

Please sign in to comment.