diff --git a/cpp/src/sampling/detail/shuffle_and_organize_output_mg_v32_e64.cu b/cpp/src/sampling/detail/shuffle_and_organize_output_mg_v32_e64.cu new file mode 100644 index 0000000000..ef760844b5 --- /dev/null +++ b/cpp/src/sampling/detail/shuffle_and_organize_output_mg_v32_e64.cu @@ -0,0 +1,61 @@ +/* + * Copyright (c) 2022-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "sampling/detail/shuffle_and_organize_output_impl.cuh" + +namespace cugraph { +namespace detail { + +template std::tuple, + rmm::device_uvector, + std::optional>, + std::optional>, + std::optional>, + std::optional>, + std::optional>, + std::optional>> +shuffle_and_organize_output( + raft::handle_t const& handle, + rmm::device_uvector&& majors, + rmm::device_uvector&& minors, + std::optional>&& weights, + std::optional>&& edge_ids, + std::optional>&& edge_types, + std::optional>&& hops, + std::optional>&& labels, + std::optional> label_to_output_comm_rank); + +template std::tuple, + rmm::device_uvector, + std::optional>, + std::optional>, + std::optional>, + std::optional>, + std::optional>, + std::optional>> +shuffle_and_organize_output( + raft::handle_t const& handle, + rmm::device_uvector&& majors, + rmm::device_uvector&& minors, + std::optional>&& weights, + std::optional>&& edge_ids, + std::optional>&& edge_types, + std::optional>&& hops, + std::optional>&& labels, + std::optional> label_to_output_comm_rank); + +} // namespace detail +} // namespace cugraph diff --git a/cpp/tests/sampling/heterogeneous_biased_neighbor_sampling.cpp b/cpp/tests/sampling/heterogeneous_biased_neighbor_sampling.cpp index c8bfb65459..3c3739ed66 100644 --- a/cpp/tests/sampling/heterogeneous_biased_neighbor_sampling.cpp +++ b/cpp/tests/sampling/heterogeneous_biased_neighbor_sampling.cpp @@ -75,8 +75,6 @@ class Tests_Heterogeneous_Biased_Neighbor_Sampling auto edge_weight_view = edge_weights ? std::make_optional((*edge_weights).view()) : std::nullopt; - std::optional> edge_mask{std::nullopt}; - constexpr float select_probability{0.05}; // FIXME: Update the tests to initialize RngState and use it instead @@ -251,12 +249,6 @@ TEST_P(Tests_Heterogeneous_Biased_Neighbor_Sampling_File, CheckInt32Int32Float) override_File_Usecase_with_cmd_line_arguments(GetParam())); } -TEST_P(Tests_Heterogeneous_Biased_Neighbor_Sampling_File, CheckInt32Int64Float) -{ - run_current_test( - override_File_Usecase_with_cmd_line_arguments(GetParam())); -} - TEST_P(Tests_Heterogeneous_Biased_Neighbor_Sampling_File, CheckInt64Int64Float) { run_current_test( @@ -269,12 +261,6 @@ TEST_P(Tests_Heterogeneous_Biased_Neighbor_Sampling_Rmat, CheckInt32Int32Float) override_Rmat_Usecase_with_cmd_line_arguments(GetParam())); } -TEST_P(Tests_Heterogeneous_Biased_Neighbor_Sampling_Rmat, CheckInt32Int64Float) -{ - run_current_test( - override_Rmat_Usecase_with_cmd_line_arguments(GetParam())); -} - TEST_P(Tests_Heterogeneous_Biased_Neighbor_Sampling_Rmat, CheckInt64Int64Float) { run_current_test( diff --git a/cpp/tests/sampling/heterogeneous_uniform_neighbor_sampling.cpp b/cpp/tests/sampling/heterogeneous_uniform_neighbor_sampling.cpp index 15be79a623..83f99e68e9 100644 --- a/cpp/tests/sampling/heterogeneous_uniform_neighbor_sampling.cpp +++ b/cpp/tests/sampling/heterogeneous_uniform_neighbor_sampling.cpp @@ -116,11 +116,11 @@ class Tests_Heterogeneous_Uniform_Neighbor_Sampling // Generate the edge types - std::optional> edge_types{ + std::optional> edge_types{ std::nullopt}; if (heterogeneous_uniform_neighbor_sampling_usecase.num_edge_types > 1) { - edge_types = cugraph::test::generate::edge_property( + edge_types = cugraph::test::generate::edge_property( handle, graph_view, heterogeneous_uniform_neighbor_sampling_usecase.num_edge_types); } @@ -249,12 +249,6 @@ TEST_P(Tests_Heterogeneous_Uniform_Neighbor_Sampling_File, CheckInt32Int32Float) override_File_Usecase_with_cmd_line_arguments(GetParam())); } -TEST_P(Tests_Heterogeneous_Uniform_Neighbor_Sampling_File, CheckInt32Int64Float) -{ - run_current_test( - override_File_Usecase_with_cmd_line_arguments(GetParam())); -} - TEST_P(Tests_Heterogeneous_Uniform_Neighbor_Sampling_File, CheckInt64Int64Float) { run_current_test( @@ -267,12 +261,6 @@ TEST_P(Tests_Heterogeneous_Uniform_Neighbor_Sampling_Rmat, CheckInt32Int32Float) override_Rmat_Usecase_with_cmd_line_arguments(GetParam())); } -TEST_P(Tests_Heterogeneous_Uniform_Neighbor_Sampling_Rmat, CheckInt32Int64Float) -{ - run_current_test( - override_Rmat_Usecase_with_cmd_line_arguments(GetParam())); -} - TEST_P(Tests_Heterogeneous_Uniform_Neighbor_Sampling_Rmat, CheckInt64Int64Float) { run_current_test( diff --git a/cpp/tests/sampling/homogeneous_biased_neighbor_sampling.cpp b/cpp/tests/sampling/homogeneous_biased_neighbor_sampling.cpp index d9f3e90f12..d01206c5ce 100644 --- a/cpp/tests/sampling/homogeneous_biased_neighbor_sampling.cpp +++ b/cpp/tests/sampling/homogeneous_biased_neighbor_sampling.cpp @@ -233,12 +233,6 @@ TEST_P(Tests_Homogeneous_Biased_Neighbor_Sampling_File, CheckInt32Int32Float) override_File_Usecase_with_cmd_line_arguments(GetParam())); } -TEST_P(Tests_Homogeneous_Biased_Neighbor_Sampling_File, CheckInt32Int64Float) -{ - run_current_test( - override_File_Usecase_with_cmd_line_arguments(GetParam())); -} - TEST_P(Tests_Homogeneous_Biased_Neighbor_Sampling_File, CheckInt64Int64Float) { run_current_test( @@ -251,12 +245,6 @@ TEST_P(Tests_Homogeneous_Biased_Neighbor_Sampling_Rmat, CheckInt32Int32Float) override_Rmat_Usecase_with_cmd_line_arguments(GetParam())); } -TEST_P(Tests_Homogeneous_Biased_Neighbor_Sampling_Rmat, CheckInt32Int64Float) -{ - run_current_test( - override_Rmat_Usecase_with_cmd_line_arguments(GetParam())); -} - TEST_P(Tests_Homogeneous_Biased_Neighbor_Sampling_Rmat, CheckInt64Int64Float) { run_current_test( diff --git a/cpp/tests/sampling/homogeneous_uniform_neighbor_sampling.cpp b/cpp/tests/sampling/homogeneous_uniform_neighbor_sampling.cpp index 8c50ab5691..508914e249 100644 --- a/cpp/tests/sampling/homogeneous_uniform_neighbor_sampling.cpp +++ b/cpp/tests/sampling/homogeneous_uniform_neighbor_sampling.cpp @@ -80,6 +80,11 @@ class Tests_Homogeneous_Uniform_Neighbor_Sampling graph_view.attach_edge_mask((*edge_mask).view()); } + // FIXME: Read a tuple of two edge mask and mask out if edge mask is set in either 1 (OR) and create + // a new one. + // No graph view can have two mask and perform OR in itself, and need to OR the mask + // manually by itself. + constexpr float select_probability{0.05}; // FIXME: Update the tests to initialize RngState and use it instead @@ -231,12 +236,6 @@ TEST_P(Tests_Homogeneous_Uniform_Neighbor_Sampling_File, CheckInt32Int32Float) override_File_Usecase_with_cmd_line_arguments(GetParam())); } -TEST_P(Tests_Homogeneous_Uniform_Neighbor_Sampling_File, CheckInt32Int64Float) -{ - run_current_test( - override_File_Usecase_with_cmd_line_arguments(GetParam())); -} - TEST_P(Tests_Homogeneous_Uniform_Neighbor_Sampling_File, CheckInt64Int64Float) { run_current_test( @@ -249,12 +248,6 @@ TEST_P(Tests_Homogeneous_Uniform_Neighbor_Sampling_Rmat, CheckInt32Int32Float) override_Rmat_Usecase_with_cmd_line_arguments(GetParam())); } -TEST_P(Tests_Homogeneous_Uniform_Neighbor_Sampling_Rmat, CheckInt32Int64Float) -{ - run_current_test( - override_Rmat_Usecase_with_cmd_line_arguments(GetParam())); -} - TEST_P(Tests_Homogeneous_Uniform_Neighbor_Sampling_Rmat, CheckInt64Int64Float) { run_current_test( diff --git a/cpp/tests/sampling/mg_heterogeneous_biased_neighbor_sampling.cpp b/cpp/tests/sampling/mg_heterogeneous_biased_neighbor_sampling.cpp index c2e3435506..688d848011 100644 --- a/cpp/tests/sampling/mg_heterogeneous_biased_neighbor_sampling.cpp +++ b/cpp/tests/sampling/mg_heterogeneous_biased_neighbor_sampling.cpp @@ -338,12 +338,6 @@ TEST_P(Tests_MGHeterogeneous_Biased_Neighbor_Sampling_Rmat, CheckInt32Int32Float override_Rmat_Usecase_with_cmd_line_arguments(GetParam())); } -TEST_P(Tests_MGHeterogeneous_Biased_Neighbor_Sampling_Rmat, CheckInt32Int64Float) -{ - run_current_test( - override_Rmat_Usecase_with_cmd_line_arguments(GetParam())); -} - TEST_P(Tests_MGHeterogeneous_Biased_Neighbor_Sampling_Rmat, CheckInt64Int64Float) { run_current_test( diff --git a/cpp/tests/sampling/mg_heterogeneous_uniform_neighbor_sampling.cpp b/cpp/tests/sampling/mg_heterogeneous_uniform_neighbor_sampling.cpp index 6282f57f1b..18e75d1b1b 100644 --- a/cpp/tests/sampling/mg_heterogeneous_uniform_neighbor_sampling.cpp +++ b/cpp/tests/sampling/mg_heterogeneous_uniform_neighbor_sampling.cpp @@ -339,12 +339,6 @@ TEST_P(Tests_MGHeterogeneous_Uniform_Neighbor_Sampling_Rmat, CheckInt32Int32Floa override_Rmat_Usecase_with_cmd_line_arguments(GetParam())); } -TEST_P(Tests_MGHeterogeneous_Uniform_Neighbor_Sampling_Rmat, CheckInt32Int64Float) -{ - run_current_test( - override_Rmat_Usecase_with_cmd_line_arguments(GetParam())); -} - TEST_P(Tests_MGHeterogeneous_Uniform_Neighbor_Sampling_Rmat, CheckInt64Int64Float) { run_current_test( diff --git a/cpp/tests/sampling/mg_homogeneous_biased_neighbor_sampling.cpp b/cpp/tests/sampling/mg_homogeneous_biased_neighbor_sampling.cpp index 4ad5f0a205..ec13493806 100644 --- a/cpp/tests/sampling/mg_homogeneous_biased_neighbor_sampling.cpp +++ b/cpp/tests/sampling/mg_homogeneous_biased_neighbor_sampling.cpp @@ -323,12 +323,6 @@ TEST_P(Tests_MGHomogeneous_Biased_Neighbor_Sampling_Rmat, CheckInt32Int32Float) override_Rmat_Usecase_with_cmd_line_arguments(GetParam())); } -TEST_P(Tests_MGHomogeneous_Biased_Neighbor_Sampling_Rmat, CheckInt32Int64Float) -{ - run_current_test( - override_Rmat_Usecase_with_cmd_line_arguments(GetParam())); -} - TEST_P(Tests_MGHomogeneous_Biased_Neighbor_Sampling_Rmat, CheckInt64Int64Float) { run_current_test( diff --git a/cpp/tests/sampling/mg_homogeneous_uniform_neighbor_sampling.cpp b/cpp/tests/sampling/mg_homogeneous_uniform_neighbor_sampling.cpp index 68ca7d4dd8..e33000044a 100644 --- a/cpp/tests/sampling/mg_homogeneous_uniform_neighbor_sampling.cpp +++ b/cpp/tests/sampling/mg_homogeneous_uniform_neighbor_sampling.cpp @@ -323,12 +323,6 @@ TEST_P(Tests_MGHomogeneous_Uniform_Neighbor_Sampling_Rmat, CheckInt32Int32Float) override_Rmat_Usecase_with_cmd_line_arguments(GetParam())); } -TEST_P(Tests_MGHomogeneous_Uniform_Neighbor_Sampling_Rmat, CheckInt32Int64Float) -{ - run_current_test( - override_Rmat_Usecase_with_cmd_line_arguments(GetParam())); -} - TEST_P(Tests_MGHomogeneous_Uniform_Neighbor_Sampling_Rmat, CheckInt64Int64Float) { run_current_test(