diff --git a/cpp/include/cugraph/edge_partition_device_view.cuh b/cpp/include/cugraph/edge_partition_device_view.cuh index d34d639f4d9..213f9b9497a 100644 --- a/cpp/include/cugraph/edge_partition_device_view.cuh +++ b/cpp/include/cugraph/edge_partition_device_view.cuh @@ -109,6 +109,13 @@ class edge_partition_device_view_base_t { __host__ __device__ edge_t const* offsets() const { return offsets_.data(); } __host__ __device__ vertex_t const* indices() const { return indices_.data(); } + __device__ vertex_t major_idx_from_local_edge_idx_nocheck(edge_t local_edge_idx) const noexcept + { + return static_cast(thrust::distance( + offsets_.begin() + 1, + thrust::upper_bound(thrust::seq, offsets_.begin() + 1, offsets_.end(), local_edge_idx))); + } + // major_idx == major offset if CSR/CSC, major_offset != major_idx if DCSR/DCSC __device__ thrust::tuple local_edges( vertex_t major_idx) const noexcept @@ -291,8 +298,19 @@ class edge_partition_device_view_t= (*major_hypersparse_first_ - major_range_first_) + ? (*dcs_nzd_vertices_)[major_idx - (*major_hypersparse_first_ - major_range_first_)] + : major_from_major_offset_nocheck(major_idx); + } else { // major_idx == major_offset + return major_from_major_offset_nocheck(major_idx); + } + } + // major_hypersparse_idx: index within the hypersparse segment - __host__ __device__ thrust::optional major_hypersparse_idx_from_major_nocheck( + __device__ thrust::optional major_hypersparse_idx_from_major_nocheck( vertex_t major) const noexcept { if (dcs_nzd_vertices_) { @@ -303,7 +321,7 @@ class edge_partition_device_view_t major_from_major_hypersparse_idx_nocheck( + __device__ thrust::optional major_from_major_hypersparse_idx_nocheck( vertex_t major_hypersparse_idx) const noexcept { return dcs_nzd_vertices_ @@ -442,8 +460,13 @@ class edge_partition_device_view_t major_hypersparse_idx_from_major_nocheck( + __device__ thrust::optional major_hypersparse_idx_from_major_nocheck( vertex_t major) const noexcept { assert(false); @@ -451,7 +474,7 @@ class edge_partition_device_view_t major_from_major_hypersparse_idx_nocheck( + __device__ thrust::optional major_from_major_hypersparse_idx_nocheck( vertex_t major_hypersparse_idx) const noexcept { assert(false); diff --git a/cpp/include/cugraph/edge_partition_edge_property_device_view.cuh b/cpp/include/cugraph/edge_partition_edge_property_device_view.cuh index f71fc167d12..18091567e38 100644 --- a/cpp/include/cugraph/edge_partition_edge_property_device_view.cuh +++ b/cpp/include/cugraph/edge_partition_edge_property_device_view.cuh @@ -33,18 +33,21 @@ template ::value_type> class edge_partition_edge_property_device_view_t { public: + using edge_type = edge_t; + using value_type = value_t; + static constexpr bool is_packed_bool = cugraph::is_packed_bool(); + static constexpr bool has_packed_bool_element = + cugraph::has_packed_bool_element(); + static_assert( std::is_same_v::value_type, value_t> || - cugraph::has_packed_bool_element()); + has_packed_bool_element); static_assert(cugraph::is_arithmetic_or_thrust_tuple_of_arithmetic::value); - using edge_type = edge_t; - using value_type = value_t; - edge_partition_edge_property_device_view_t() = default; edge_partition_edge_property_device_view_t( - edge_property_view_t const& view, size_t partition_idx) + edge_property_view_t const& view, size_t partition_idx) : value_first_(view.value_firsts()[partition_idx]) { value_first_ = view.value_firsts()[partition_idx]; @@ -54,8 +57,8 @@ class edge_partition_edge_property_device_view_t { __device__ value_t get(edge_t offset) const { - if constexpr (cugraph::has_packed_bool_element()) { - static_assert(std::is_arithmetic_v, "unimplemented for thrust::tuple types."); + if constexpr (has_packed_bool_element) { + static_assert(is_packed_bool, "unimplemented for thrust::tuple types."); auto mask = cugraph::packed_bool_mask(offset); return static_cast(*(value_first_ + cugraph::packed_bool_offset(offset)) & mask); } else { @@ -69,8 +72,8 @@ class edge_partition_edge_property_device_view_t { void> set(edge_t offset, value_t val) const { - if constexpr (cugraph::has_packed_bool_element()) { - static_assert(std::is_arithmetic_v, "unimplemented for thrust::tuple types."); + if constexpr (has_packed_bool_element) { + static_assert(is_packed_bool, "unimplemented for thrust::tuple types."); auto mask = cugraph::packed_bool_mask(offset); if (val) { atomicOr(value_first_ + cugraph::packed_bool_offset(offset), mask); @@ -88,8 +91,8 @@ class edge_partition_edge_property_device_view_t { value_t> atomic_and(edge_t offset, value_t val) const { - if constexpr (cugraph::has_packed_bool_element()) { - static_assert(std::is_arithmetic_v, "unimplemented for thrust::tuple types."); + if constexpr (has_packed_bool_element) { + static_assert(is_packed_bool, "unimplemented for thrust::tuple types."); auto mask = cugraph::packed_bool_mask(offset); auto old = atomicAnd(value_first_ + cugraph::packed_bool_offset(offset), val ? uint32_t{0xffffffff} : ~mask); @@ -105,8 +108,8 @@ class edge_partition_edge_property_device_view_t { value_t> atomic_or(edge_t offset, value_t val) const { - if constexpr (cugraph::has_packed_bool_element()) { - static_assert(std::is_arithmetic_v, "unimplemented for thrust::tuple types."); + if constexpr (has_packed_bool_element) { + static_assert(is_packed_bool, "unimplemented for thrust::tuple types."); auto mask = cugraph::packed_bool_mask(offset); auto old = atomicOr(value_first_ + cugraph::packed_bool_offset(offset), val ? mask : uint32_t{0}); @@ -132,8 +135,8 @@ class edge_partition_edge_property_device_view_t { value_t> elementwise_atomic_cas(edge_t offset, value_t compare, value_t val) const { - if constexpr (cugraph::has_packed_bool_element()) { - static_assert(std::is_arithmetic_v, "unimplemented for thrust::tuple types."); + if constexpr (has_packed_bool_element) { + static_assert(is_packed_bool, "unimplemented for thrust::tuple types."); auto mask = cugraph::packed_bool_mask(offset); auto old = val ? atomicOr(value_first_ + cugraph::packed_bool_offset(offset), mask) : atomicAnd(value_first_ + cugraph::packed_bool_offset(offset), ~mask); @@ -170,8 +173,10 @@ class edge_partition_edge_property_device_view_t { template class edge_partition_edge_dummy_property_device_view_t { public: - using edge_type = edge_t; - using value_type = thrust::nullopt_t; + using edge_type = edge_t; + using value_type = thrust::nullopt_t; + static constexpr bool is_packed_bool = false; + static constexpr bool has_packed_bool_element = false; edge_partition_edge_dummy_property_device_view_t() = default; diff --git a/cpp/include/cugraph/edge_partition_endpoint_property_device_view.cuh b/cpp/include/cugraph/edge_partition_endpoint_property_device_view.cuh index 1ff279fbdca..7578c646175 100644 --- a/cpp/include/cugraph/edge_partition_endpoint_property_device_view.cuh +++ b/cpp/include/cugraph/edge_partition_endpoint_property_device_view.cuh @@ -39,12 +39,15 @@ template ::value_type> class edge_partition_endpoint_property_device_view_t { public: + using vertex_type = vertex_t; + using value_type = value_t; + static constexpr bool is_packed_bool = cugraph::is_packed_bool(); + static constexpr bool has_packed_bool_element = + cugraph::has_packed_bool_element(); + static_assert( std::is_same_v::value_type, value_t> || - cugraph::has_packed_bool_element()); - - using vertex_type = vertex_t; - using value_type = value_t; + has_packed_bool_element); edge_partition_endpoint_property_device_view_t() = default; @@ -77,8 +80,8 @@ class edge_partition_endpoint_property_device_view_t { __device__ value_t get(vertex_t offset) const { auto val_offset = value_offset(offset); - if constexpr (cugraph::has_packed_bool_element()) { - static_assert(std::is_arithmetic_v, "unimplemented for thrust::tuple types."); + if constexpr (has_packed_bool_element) { + static_assert(is_packed_bool, "unimplemented for thrust::tuple types."); auto mask = cugraph::packed_bool_mask(val_offset); return static_cast(*(value_first_ + cugraph::packed_bool_offset(val_offset)) & mask); } else { @@ -93,8 +96,8 @@ class edge_partition_endpoint_property_device_view_t { atomic_and(vertex_t offset, value_t val) const { auto val_offset = value_offset(offset); - if constexpr (cugraph::has_packed_bool_element()) { - static_assert(std::is_arithmetic_v, "unimplemented for thrust::tuple types."); + if constexpr (has_packed_bool_element) { + static_assert(is_packed_bool, "unimplemented for thrust::tuple types."); auto mask = cugraph::packed_bool_mask(val_offset); auto old = atomicAnd(value_first_ + cugraph::packed_bool_offset(val_offset), val ? cugraph::packed_bool_full_mask() : ~mask); @@ -111,8 +114,8 @@ class edge_partition_endpoint_property_device_view_t { atomic_or(vertex_t offset, value_t val) const { auto val_offset = value_offset(offset); - if constexpr (cugraph::has_packed_bool_element()) { - static_assert(std::is_arithmetic_v, "unimplemented for thrust::tuple types."); + if constexpr (has_packed_bool_element) { + static_assert(is_packed_bool, "unimplemented for thrust::tuple types."); auto mask = cugraph::packed_bool_mask(val_offset); auto old = atomicOr(value_first_ + cugraph::packed_bool_offset(val_offset), val ? mask : cugraph::packed_bool_empty_mask()); @@ -140,8 +143,8 @@ class edge_partition_endpoint_property_device_view_t { elementwise_atomic_cas(vertex_t offset, value_t compare, value_t val) const { auto val_offset = value_offset(offset); - if constexpr (cugraph::has_packed_bool_element()) { - static_assert(std::is_arithmetic_v, "unimplemented for thrust::tuple types."); + if constexpr (has_packed_bool_element) { + static_assert(is_packed_bool, "unimplemented for thrust::tuple types."); auto mask = cugraph::packed_bool_mask(val_offset); auto old = val ? atomicOr(value_first_ + cugraph::packed_bool_offset(val_offset), mask) : atomicAnd(value_first_ + cugraph::packed_bool_offset(val_offset), ~mask); @@ -203,8 +206,10 @@ class edge_partition_endpoint_property_device_view_t { template class edge_partition_endpoint_dummy_property_device_view_t { public: - using vertex_type = vertex_t; - using value_type = thrust::nullopt_t; + using vertex_type = vertex_t; + using value_type = thrust::nullopt_t; + static constexpr bool is_packed_bool = false; + static constexpr bool has_packed_bool_element = false; edge_partition_endpoint_dummy_property_device_view_t() = default; diff --git a/cpp/include/cugraph/edge_property.hpp b/cpp/include/cugraph/edge_property.hpp index 8904006a2a2..d46d4e52fd4 100644 --- a/cpp/include/cugraph/edge_property.hpp +++ b/cpp/include/cugraph/edge_property.hpp @@ -72,9 +72,11 @@ class edge_property_t { public: static_assert(cugraph::is_arithmetic_or_thrust_tuple_of_arithmetic::value); - using edge_type = typename GraphViewType::edge_type; - using value_type = T; - using buffer_type = decltype(allocate_dataframe_buffer(size_t{0}, rmm::cuda_stream_view{})); + using edge_type = typename GraphViewType::edge_type; + using value_type = T; + using buffer_type = + decltype(allocate_dataframe_buffer, uint32_t, T>>( + size_t{0}, rmm::cuda_stream_view{})); edge_property_t(raft::handle_t const& handle) {} diff --git a/cpp/include/cugraph/utilities/packed_bool_utils.hpp b/cpp/include/cugraph/utilities/packed_bool_utils.hpp index 9557b11e8e0..0be5711d90c 100644 --- a/cpp/include/cugraph/utilities/packed_bool_utils.hpp +++ b/cpp/include/cugraph/utilities/packed_bool_utils.hpp @@ -47,6 +47,13 @@ has_packed_bool_element(std::index_sequence) } // namespace detail +template +constexpr bool is_packed_bool() +{ + return std::is_same_v::value_type, uint32_t> && + std::is_same_v; +} + // sizeof(uint32_t) * 8 packed Boolean values are stored using one uint32_t template constexpr bool has_packed_bool_element() diff --git a/cpp/src/prims/transform_e.cuh b/cpp/src/prims/transform_e.cuh index 7950df58a3e..edacdc8a970 100644 --- a/cpp/src/prims/transform_e.cuh +++ b/cpp/src/prims/transform_e.cuh @@ -20,6 +20,7 @@ #include #include #include +#include #include #include @@ -35,6 +36,70 @@ namespace cugraph { +namespace detail { + +int32_t constexpr transform_e_kernel_block_size = 512; + +template +__global__ void transform_e_packed_bool( + edge_partition_device_view_t edge_partition, + EdgePartitionSrcValueInputWrapper edge_partition_src_value_input, + EdgePartitionDstValueInputWrapper edge_partition_dst_value_input, + EdgePartitionEdgeValueInputWrapper edge_partition_e_value_input, + EdgePartitionEdgeValueOutputWrapper edge_partition_e_value_output, + EdgeOp e_op) +{ + static_assert(EdgePartitionEdgeValueOutputWrapper::is_packed_bool); + static_assert(raft::warp_size() == packed_bools_per_word()); + + using edge_t = typename GraphViewType::edge_type; + + auto const tid = threadIdx.x + blockIdx.x * blockDim.x; + static_assert(transform_e_kernel_block_size % raft::warp_size() == 0); + auto const lane_id = tid % raft::warp_size(); + auto idx = static_cast(packed_bool_offset(tid)); + + auto num_edges = edge_partition.number_of_edges(); + while (idx < static_cast(packed_bool_size(num_edges))) { + auto local_edge_idx = + idx * static_cast(packed_bools_per_word()) + static_cast(lane_id); + uint32_t mask{0}; + int predicate{0}; + if (local_edge_idx < num_edges) { + auto major_idx = edge_partition.major_idx_from_local_edge_idx_nocheck(local_edge_idx); + auto major = edge_partition.major_from_major_idx_nocheck(major_idx); + auto major_offset = edge_partition.major_offset_from_major_nocheck(major); + auto minor = *(edge_partition.indices() + local_edge_idx); + auto minor_offset = edge_partition.minor_offset_from_minor_nocheck(minor); + + auto src = GraphViewType::is_storage_transposed ? minor : major; + auto dst = GraphViewType::is_storage_transposed ? major : minor; + auto src_offset = GraphViewType::is_storage_transposed ? minor_offset : major_offset; + auto dst_offset = GraphViewType::is_storage_transposed ? major_offset : minor_offset; + predicate = e_op(src, + dst, + edge_partition_src_value_input.get(src_offset), + edge_partition_dst_value_input.get(dst_offset), + edge_partition_e_value_input.get(local_edge_idx)) + ? int{1} + : int{0}; + } + mask = __ballot_sync(uint32_t{0xffffffff}, predicate); + if (lane_id == 0) { *(edge_partition_e_value_output.value_first() + idx) = mask; } + + idx += static_cast(gridDim.x * (blockDim.x / raft::warp_size())); + } +} + +} // namespace detail + /** * @brief Iterate over the entire set of edges and update edge property values. * @@ -84,9 +149,103 @@ void transform_e(raft::handle_t const& handle, EdgeValueOutputWrapper edge_value_output, bool do_expensive_check = false) { - // CUGRAPH_EXPECTS(!graph_view.has_edge_mask(), "unimplemented."); + using vertex_t = typename GraphViewType::vertex_type; + using edge_t = typename GraphViewType::edge_type; - CUGRAPH_FAIL("unimplemented."); + using edge_partition_src_input_device_view_t = std::conditional_t< + std::is_same_v, + detail::edge_partition_endpoint_dummy_property_device_view_t, + detail::edge_partition_endpoint_property_device_view_t< + vertex_t, + typename EdgeSrcValueInputWrapper::value_iterator, + typename EdgeSrcValueInputWrapper::value_type>>; + using edge_partition_dst_input_device_view_t = std::conditional_t< + std::is_same_v, + detail::edge_partition_endpoint_dummy_property_device_view_t, + detail::edge_partition_endpoint_property_device_view_t< + vertex_t, + typename EdgeDstValueInputWrapper::value_iterator, + typename EdgeDstValueInputWrapper::value_type>>; + using edge_partition_e_input_device_view_t = std::conditional_t< + std::is_same_v, + detail::edge_partition_edge_dummy_property_device_view_t, + detail::edge_partition_edge_property_device_view_t< + edge_t, + typename EdgeValueInputWrapper::value_iterator, + typename EdgeValueInputWrapper::value_type>>; + using edge_partition_e_output_device_view_t = detail::edge_partition_edge_property_device_view_t< + edge_t, + typename EdgeValueOutputWrapper::value_iterator, + typename EdgeValueOutputWrapper::value_type>; + + CUGRAPH_EXPECTS(!graph_view.has_edge_mask(), "unimplemented."); + + for (size_t i = 0; i < graph_view.number_of_local_edge_partitions(); ++i) { + auto edge_partition = + edge_partition_device_view_t( + graph_view.local_edge_partition_view(i)); + + edge_partition_src_input_device_view_t edge_partition_src_value_input{}; + edge_partition_dst_input_device_view_t edge_partition_dst_value_input{}; + if constexpr (GraphViewType::is_storage_transposed) { + edge_partition_src_value_input = edge_partition_src_input_device_view_t(edge_src_value_input); + edge_partition_dst_value_input = + edge_partition_dst_input_device_view_t(edge_dst_value_input, i); + } else { + edge_partition_src_value_input = + edge_partition_src_input_device_view_t(edge_src_value_input, i); + edge_partition_dst_value_input = edge_partition_dst_input_device_view_t(edge_dst_value_input); + } + auto edge_partition_e_value_input = edge_partition_e_input_device_view_t(edge_value_input, i); + auto edge_partition_e_value_output = + edge_partition_e_output_device_view_t(edge_value_output, i); + + auto num_edges = edge_partition.number_of_edges(); + if constexpr (edge_partition_e_output_device_view_t::has_packed_bool_element) { + static_assert(edge_partition_e_output_device_view_t::is_packed_bool, + "unimplemented for thrust::tuple types."); + if (edge_partition.number_of_edges() > edge_t{0}) { + raft::grid_1d_thread_t update_grid(num_edges, + detail::transform_e_kernel_block_size, + handle.get_device_properties().maxGridSize[0]); + detail::transform_e_packed_bool + <<>>( + edge_partition, + edge_partition_src_value_input, + edge_partition_dst_value_input, + edge_partition_e_value_input, + edge_partition_e_value_output, + e_op); + } + } else { + thrust::transform( + handle.get_thrust_policy(), + thrust::make_counting_iterator(edge_t{0}), + thrust::make_counting_iterator(num_edges), + edge_partition_e_value_output.value_first(), + [e_op, + edge_partition, + edge_partition_src_value_input, + edge_partition_dst_value_input, + edge_partition_e_value_input] __device__(edge_t i) { + auto major_idx = edge_partition.major_idx_from_local_edge_idx_nocheck(i); + auto major = edge_partition.major_from_major_idx_nocheck(major_idx); + auto major_offset = edge_partition.major_offset_from_major_nocheck(major); + auto minor = *(edge_partition.indices() + i); + auto minor_offset = edge_partition.minor_offset_from_minor_nocheck(minor); + + auto src = GraphViewType::is_storage_transposed ? minor : major; + auto dst = GraphViewType::is_storage_transposed ? major : minor; + auto src_offset = GraphViewType::is_storage_transposed ? minor_offset : major_offset; + auto dst_offset = GraphViewType::is_storage_transposed ? major_offset : minor_offset; + return e_op(src, + dst, + edge_partition_src_value_input.get(src_offset), + edge_partition_dst_value_input.get(dst_offset), + edge_partition_e_value_input.get(i)); + }); + } + } } /** @@ -177,7 +336,7 @@ void transform_e(raft::handle_t const& handle, typename EdgeValueOutputWrapper::value_iterator, typename EdgeValueOutputWrapper::value_type>; - // CUGRAPH_EXPECTS(!graph_view.has_edge_mask(), "unimplemented."); + CUGRAPH_EXPECTS(!graph_view.has_edge_mask(), "unimplemented."); auto major_first = GraphViewType::is_storage_transposed ? edge_list.dst_begin() : edge_list.src_begin(); diff --git a/cpp/tests/prims/mg_transform_e.cu b/cpp/tests/prims/mg_transform_e.cu index 127eddd43c7..24deaad810a 100644 --- a/cpp/tests/prims/mg_transform_e.cu +++ b/cpp/tests/prims/mg_transform_e.cu @@ -51,6 +51,7 @@ #include struct Prims_Usecase { + bool use_edgelist{false}; bool check_correctness{true}; }; @@ -113,8 +114,9 @@ class Tests_MGTransformE auto mg_dst_prop = cugraph::test::generate::dst_property( *handle_, mg_graph_view, mg_vertex_prop); - cugraph::edge_bucket_t edge_list(*handle_); - { + cugraph::edge_bucket_t edge_list( + *handle_); + if (prims_usecase.use_edgelist) { rmm::device_uvector srcs(0, handle_->get_stream()); rmm::device_uvector dsts(0, handle_->get_stream()); std::tie(srcs, dsts, std::ignore, std::ignore) = cugraph::decompress_to_edgelist( @@ -154,24 +156,41 @@ class Tests_MGTransformE if (cugraph::test::g_perf) { RAFT_CUDA_TRY(cudaDeviceSynchronize()); // for consistent performance measurement handle_->get_comms().barrier(); - hr_timer.start("MG transform_reduce_e"); + hr_timer.start("MG transform_e"); } - cugraph::transform_e( - *handle_, - mg_graph_view, - edge_list, - mg_src_prop.view(), - mg_dst_prop.view(), - cugraph::edge_dummy_property_t{}.view(), - [] __device__(auto src, auto dst, auto src_property, auto dst_property, thrust::nullopt_t) { - if (src_property < dst_property) { - return src_property; - } else { - return dst_property; - } - }, - edge_value_output.mutable_view()); + if (prims_usecase.use_edgelist) { + cugraph::transform_e( + *handle_, + mg_graph_view, + edge_list, + mg_src_prop.view(), + mg_dst_prop.view(), + cugraph::edge_dummy_property_t{}.view(), + [] __device__(auto src, auto dst, auto src_property, auto dst_property, thrust::nullopt_t) { + if (src_property < dst_property) { + return src_property; + } else { + return dst_property; + } + }, + edge_value_output.mutable_view()); + } else { + cugraph::transform_e( + *handle_, + mg_graph_view, + mg_src_prop.view(), + mg_dst_prop.view(), + cugraph::edge_dummy_property_t{}.view(), + [] __device__(auto src, auto dst, auto src_property, auto dst_property, thrust::nullopt_t) { + if (src_property < dst_property) { + return src_property; + } else { + return dst_property; + } + }, + edge_value_output.mutable_view()); + } if (cugraph::test::g_perf) { RAFT_CUDA_TRY(cudaDeviceSynchronize()); // for consistent performance measurement @@ -183,24 +202,42 @@ class Tests_MGTransformE // 3. validate MG results if (prims_usecase.check_correctness) { - auto num_invalids = cugraph::count_if_e( - *handle_, - mg_graph_view, - mg_src_prop.view(), - mg_dst_prop.view(), - edge_value_output.view(), - [property_initial_value] __device__( - auto src, auto dst, auto src_property, auto dst_property, auto edge_property) { - if (((src + dst) % 2) == 0) { + size_t num_invalids{}; + if (prims_usecase.use_edgelist) { + num_invalids = cugraph::count_if_e( + *handle_, + mg_graph_view, + mg_src_prop.view(), + mg_dst_prop.view(), + edge_value_output.view(), + [property_initial_value] __device__( + auto src, auto dst, auto src_property, auto dst_property, auto edge_property) { + if (((src + dst) % 2) == 0) { + if (src_property < dst_property) { + return edge_property != src_property; + } else { + return edge_property != dst_property; + } + } else { + return edge_property != property_initial_value; + } + }); + } else { + num_invalids = cugraph::count_if_e( + *handle_, + mg_graph_view, + mg_src_prop.view(), + mg_dst_prop.view(), + edge_value_output.view(), + [property_initial_value] __device__( + auto src, auto dst, auto src_property, auto dst_property, auto edge_property) { if (src_property < dst_property) { return edge_property != src_property; } else { return edge_property != dst_property; } - } else { - return edge_property != property_initial_value; - } - }); + }); + } ASSERT_TRUE(num_invalids == 0); } @@ -278,13 +315,13 @@ TEST_P(Tests_MGTransformE_Rmat, CheckInt64Int64FloatTupleIntFloatTransposeTrue) cugraph::test::override_Rmat_Usecase_with_cmd_line_arguments(std::get<1>(param))); } -TEST_P(Tests_MGTransformE_File, CheckInt32Int32FloatTransposeFalse) +TEST_P(Tests_MGTransformE_File, CheckInt32Int32FloatIntTransposeFalse) { auto param = GetParam(); run_current_test(std::get<0>(param), std::get<1>(param)); } -TEST_P(Tests_MGTransformE_Rmat, CheckInt32Int32FloatTransposeFalse) +TEST_P(Tests_MGTransformE_Rmat, CheckInt32Int32FloatIntTransposeFalse) { auto param = GetParam(); run_current_test( @@ -292,7 +329,7 @@ TEST_P(Tests_MGTransformE_Rmat, CheckInt32Int32FloatTransposeFalse) cugraph::test::override_Rmat_Usecase_with_cmd_line_arguments(std::get<1>(param))); } -TEST_P(Tests_MGTransformE_Rmat, CheckInt32Int64FloatTransposeFalse) +TEST_P(Tests_MGTransformE_Rmat, CheckInt32Int64FloatIntTransposeFalse) { auto param = GetParam(); run_current_test( @@ -300,7 +337,7 @@ TEST_P(Tests_MGTransformE_Rmat, CheckInt32Int64FloatTransposeFalse) cugraph::test::override_Rmat_Usecase_with_cmd_line_arguments(std::get<1>(param))); } -TEST_P(Tests_MGTransformE_Rmat, CheckInt64Int64FloatTransposeFalse) +TEST_P(Tests_MGTransformE_Rmat, CheckInt64Int64FloatIntTransposeFalse) { auto param = GetParam(); run_current_test( @@ -308,13 +345,13 @@ TEST_P(Tests_MGTransformE_Rmat, CheckInt64Int64FloatTransposeFalse) cugraph::test::override_Rmat_Usecase_with_cmd_line_arguments(std::get<1>(param))); } -TEST_P(Tests_MGTransformE_File, CheckInt32Int32FloatTransposeTrue) +TEST_P(Tests_MGTransformE_File, CheckInt32Int32FloatIntTransposeTrue) { auto param = GetParam(); run_current_test(std::get<0>(param), std::get<1>(param)); } -TEST_P(Tests_MGTransformE_Rmat, CheckInt32Int32FloatTransposeTrue) +TEST_P(Tests_MGTransformE_Rmat, CheckInt32Int32FloatIntTransposeTrue) { auto param = GetParam(); run_current_test( @@ -322,7 +359,7 @@ TEST_P(Tests_MGTransformE_Rmat, CheckInt32Int32FloatTransposeTrue) cugraph::test::override_Rmat_Usecase_with_cmd_line_arguments(std::get<1>(param))); } -TEST_P(Tests_MGTransformE_Rmat, CheckInt32Int64FloatTransposeTrue) +TEST_P(Tests_MGTransformE_Rmat, CheckInt32Int64FloatIntTransposeTrue) { auto param = GetParam(); run_current_test( @@ -330,7 +367,7 @@ TEST_P(Tests_MGTransformE_Rmat, CheckInt32Int64FloatTransposeTrue) cugraph::test::override_Rmat_Usecase_with_cmd_line_arguments(std::get<1>(param))); } -TEST_P(Tests_MGTransformE_Rmat, CheckInt64Int64FloatTransposeTrue) +TEST_P(Tests_MGTransformE_Rmat, CheckInt64Int64FloatIntTransposeTrue) { auto param = GetParam(); run_current_test( @@ -338,11 +375,71 @@ TEST_P(Tests_MGTransformE_Rmat, CheckInt64Int64FloatTransposeTrue) cugraph::test::override_Rmat_Usecase_with_cmd_line_arguments(std::get<1>(param))); } +TEST_P(Tests_MGTransformE_File, CheckInt32Int32FloatBoolTransposeFalse) +{ + auto param = GetParam(); + run_current_test(std::get<0>(param), std::get<1>(param)); +} + +TEST_P(Tests_MGTransformE_Rmat, CheckInt32Int32FloatBoolTransposeFalse) +{ + auto param = GetParam(); + run_current_test( + std::get<0>(param), + cugraph::test::override_Rmat_Usecase_with_cmd_line_arguments(std::get<1>(param))); +} + +TEST_P(Tests_MGTransformE_Rmat, CheckInt32Int64FloatBoolTransposeFalse) +{ + auto param = GetParam(); + run_current_test( + std::get<0>(param), + cugraph::test::override_Rmat_Usecase_with_cmd_line_arguments(std::get<1>(param))); +} + +TEST_P(Tests_MGTransformE_Rmat, CheckInt64Int64FloatBoolTransposeFalse) +{ + auto param = GetParam(); + run_current_test( + std::get<0>(param), + cugraph::test::override_Rmat_Usecase_with_cmd_line_arguments(std::get<1>(param))); +} + +TEST_P(Tests_MGTransformE_File, CheckInt32Int32FloatBoolTransposeTrue) +{ + auto param = GetParam(); + run_current_test(std::get<0>(param), std::get<1>(param)); +} + +TEST_P(Tests_MGTransformE_Rmat, CheckInt32Int32FloatBoolTransposeTrue) +{ + auto param = GetParam(); + run_current_test( + std::get<0>(param), + cugraph::test::override_Rmat_Usecase_with_cmd_line_arguments(std::get<1>(param))); +} + +TEST_P(Tests_MGTransformE_Rmat, CheckInt32Int64FloatBoolTransposeTrue) +{ + auto param = GetParam(); + run_current_test( + std::get<0>(param), + cugraph::test::override_Rmat_Usecase_with_cmd_line_arguments(std::get<1>(param))); +} + +TEST_P(Tests_MGTransformE_Rmat, CheckInt64Int64FloatBoolTransposeTrue) +{ + auto param = GetParam(); + run_current_test( + std::get<0>(param), + cugraph::test::override_Rmat_Usecase_with_cmd_line_arguments(std::get<1>(param))); +} + INSTANTIATE_TEST_SUITE_P( file_test, Tests_MGTransformE_File, ::testing::Combine( - ::testing::Values(Prims_Usecase{true}), + ::testing::Values(Prims_Usecase{false, true}, Prims_Usecase{true, true}), ::testing::Values(cugraph::test::File_Usecase("test/datasets/karate.mtx"), cugraph::test::File_Usecase("test/datasets/web-Google.mtx"), cugraph::test::File_Usecase("test/datasets/ljournal-2008.mtx"), @@ -350,7 +447,8 @@ INSTANTIATE_TEST_SUITE_P( INSTANTIATE_TEST_SUITE_P(rmat_small_test, Tests_MGTransformE_Rmat, - ::testing::Combine(::testing::Values(Prims_Usecase{true}), + ::testing::Combine(::testing::Values(Prims_Usecase{false, true}, + Prims_Usecase{true, true}), ::testing::Values(cugraph::test::Rmat_Usecase( 10, 16, 0.57, 0.19, 0.19, 0, false, false)))); @@ -362,7 +460,7 @@ INSTANTIATE_TEST_SUITE_P( factor (to avoid running same benchmarks more than once) */ Tests_MGTransformE_Rmat, ::testing::Combine( - ::testing::Values(Prims_Usecase{false}), + ::testing::Values(Prims_Usecase{false, false}, Prims_Usecase{true, false}), ::testing::Values(cugraph::test::Rmat_Usecase(20, 32, 0.57, 0.19, 0.19, 0, false, false)))); CUGRAPH_MG_TEST_PROGRAM_MAIN()