From 996c7a46526c24c2f7d046cb587f72c5f5da4ff1 Mon Sep 17 00:00:00 2001 From: Ralph Liu <137829296+nv-rliu@users.noreply.github.com> Date: Fri, 30 Aug 2024 20:45:08 -0400 Subject: [PATCH 1/2] Add `nx-cugraph` Benchmarking Scripts (#4616) Closes https://github.com/rapidsai/graph_dl/issues/596 This PR adds scripts written by @rlratzel to generate benchmarking numbers for `nx-cugraph` Authors: - Ralph Liu (https://github.com/nv-rliu) - Rick Ratzel (https://github.com/rlratzel) Approvers: - Don Acosta (https://github.com/acostadon) - Rick Ratzel (https://github.com/rlratzel) URL: https://github.com/rapidsai/cugraph/pull/4616 --- benchmarks/nx-cugraph/pytest-based/README.md | 45 +++ .../nx-cugraph/pytest-based/bench_algos.py | 31 +- .../create_results_summary_page.py | 291 ++++++++++++++++++ .../pytest-based/get_graph_bench_dataset.py | 35 +++ .../pytest-based/run-main-benchmarks.sh | 58 ++++ 5 files changed, 458 insertions(+), 2 deletions(-) create mode 100644 benchmarks/nx-cugraph/pytest-based/README.md create mode 100644 benchmarks/nx-cugraph/pytest-based/create_results_summary_page.py create mode 100644 benchmarks/nx-cugraph/pytest-based/get_graph_bench_dataset.py create mode 100755 benchmarks/nx-cugraph/pytest-based/run-main-benchmarks.sh diff --git a/benchmarks/nx-cugraph/pytest-based/README.md b/benchmarks/nx-cugraph/pytest-based/README.md new file mode 100644 index 00000000000..4ea0f127a51 --- /dev/null +++ b/benchmarks/nx-cugraph/pytest-based/README.md @@ -0,0 +1,45 @@ +## `nx-cugraph` Benchmarks + +### Overview + +This directory contains a set of scripts designed to benchmark NetworkX with the `nx-cugraph` backend and deliver a report that summarizes the speed-up and runtime deltas over default NetworkX. 
+ +Our current benchmarks provide the following datasets: + +| Dataset | Nodes | Edges | Directed | +| -------- | ------- | ------- | ------- | +| netscience | 1,461 | 5,484 | Yes | +| email-Eu-core | 1,005 | 25,571 | Yes | +| cit-Patents | 3,774,768 | 16,518,948 | Yes | +| hollywood | 1,139,905 | 57,515,616 | No | +| soc-LiveJournal1 | 4,847,571 | 68,993,773 | Yes | + + + +### Scripts + +#### 1. `run-main-benchmarks.sh` +This script allows users to run selected algorithms across multiple datasets and backends. All results are stored inside a sub-directory (`logs/`) and output files are named based on the combination of parameters for that benchmark. + +NOTE: If running with all algorithms, datasets, and backends, this script may take a few hours to finish running. + +**Usage:** + ```bash + bash run-main-benchmarks.sh # edit this script directly + ``` + +#### 2. `get_graph_bench_dataset.py` +This script downloads the specified dataset using `cugraph.datasets`. + +**Usage:** + ```bash + python get_graph_bench_dataset.py [dataset] + ``` + +#### 3. `create_results_summary_page.py` +This script is designed to be run after `run-main-benchmarks.sh` in order to generate an HTML page displaying a results table comparing default NetworkX to nx-cugraph. The script also provides information about the current system. 
+ +**Usage:** + ```bash + python create_results_summary_page.py > report.html + ``` diff --git a/benchmarks/nx-cugraph/pytest-based/bench_algos.py b/benchmarks/nx-cugraph/pytest-based/bench_algos.py index d40b5130827..f88d93c3f17 100644 --- a/benchmarks/nx-cugraph/pytest-based/bench_algos.py +++ b/benchmarks/nx-cugraph/pytest-based/bench_algos.py @@ -271,9 +271,8 @@ def bench_from_networkx(benchmark, graph_obj): # normalized_param_values = [True, False] -# k_param_values = [10, 100] normalized_param_values = [True] -k_param_values = [10] +k_param_values = [10, 100, 1000] @pytest.mark.parametrize( @@ -282,6 +281,10 @@ def bench_from_networkx(benchmark, graph_obj): @pytest.mark.parametrize("k", k_param_values, ids=lambda k: f"{k=}") def bench_betweenness_centrality(benchmark, graph_obj, backend_wrapper, normalized, k): G = get_graph_obj_for_benchmark(graph_obj, backend_wrapper) + + if k > G.number_of_nodes(): + pytest.skip(reason=f"{k=} > {G.number_of_nodes()=}") + result = benchmark.pedantic( target=backend_wrapper(nx.betweenness_centrality), args=(G,), @@ -305,6 +308,10 @@ def bench_edge_betweenness_centrality( benchmark, graph_obj, backend_wrapper, normalized, k ): G = get_graph_obj_for_benchmark(graph_obj, backend_wrapper) + + if k > G.number_of_nodes(): + pytest.skip(reason=f"{k=} > {G.number_of_nodes()=}") + result = benchmark.pedantic( target=backend_wrapper(nx.edge_betweenness_centrality), args=(G,), @@ -473,6 +480,26 @@ def bench_pagerank_personalized(benchmark, graph_obj, backend_wrapper): assert type(result) is dict +def bench_shortest_path(benchmark, graph_obj, backend_wrapper): + """ + This passes in the source node with the highest degree, but no target. 
+ """ + G = get_graph_obj_for_benchmark(graph_obj, backend_wrapper) + node = get_highest_degree_node(graph_obj) + + result = benchmark.pedantic( + target=backend_wrapper(nx.shortest_path), + args=(G,), + kwargs=dict( + source=node, + ), + rounds=rounds, + iterations=iterations, + warmup_rounds=warmup_rounds, + ) + assert type(result) is dict + + def bench_single_source_shortest_path_length(benchmark, graph_obj, backend_wrapper): G = get_graph_obj_for_benchmark(graph_obj, backend_wrapper) node = get_highest_degree_node(graph_obj) diff --git a/benchmarks/nx-cugraph/pytest-based/create_results_summary_page.py b/benchmarks/nx-cugraph/pytest-based/create_results_summary_page.py new file mode 100644 index 00000000000..f1cc4b06ccc --- /dev/null +++ b/benchmarks/nx-cugraph/pytest-based/create_results_summary_page.py @@ -0,0 +1,291 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +import re +import pathlib +import json +import platform +import psutil +import socket +import subprocess + + +def get_formatted_time_value(time): + res = "" + if time < 1: + if time < 0.001: + units = "us" + time *= 1e6 + else: + units = "ms" + time *= 1e3 + else: + units = "s" + return f"{time:.3f}{units}" + + +def get_all_benchmark_info(): + benchmarks = {} + # Populate benchmarks dir from .json files + for json_file in logs_dir.glob("*.json"): + try: + data = json.loads(open(json_file).read()) + except json.decoder.JSONDecodeError: + continue + + for benchmark_run in data["benchmarks"]: + # example name: "bench_triangles[ds=netscience-backend=cugraph-preconverted]" + name = benchmark_run["name"] + + algo_name = name.split("[")[0] + if algo_name.startswith("bench_"): + algo_name = algo_name[6:] + # special case for betweenness_centrality + match = k_patt.match(name) + if match is not None: + algo_name += f", k={match.group(1)}" + + match = dataset_patt.match(name) + if match is None: + raise RuntimeError( + f"benchmark name {name} in file {json_file} has an unexpected format" + ) + dataset = match.group(1) + if dataset.endswith("-backend"): + dataset = dataset[:-8] + + match = backend_patt.match(name) + if match is None: + raise RuntimeError( + f"benchmark name {name} in file {json_file} has an unexpected format" + ) + backend = match.group(1) + if backend == "None": + backend = "networkx" + + runtime = benchmark_run["stats"]["mean"] + benchmarks.setdefault(algo_name, {}).setdefault(backend, {})[ + dataset + ] = runtime + return benchmarks + + +def compute_perf_vals(cugraph_runtime, networkx_runtime): + speedup_string = f"{networkx_runtime / cugraph_runtime:.3f}X" + delta = networkx_runtime - cugraph_runtime + if abs(delta) < 1: + if abs(delta) < 0.001: + units = "us" + delta *= 1e6 + else: + units = "ms" + delta *= 1e3 + else: + units = "s" + delta_string = f"{delta:.3f}{units}" + + return (speedup_string, delta_string) + + +def get_mem_info(): + return 
round(psutil.virtual_memory().total / (1024**3), 2) + + +def get_cuda_version(): + output = subprocess.check_output("nvidia-smi", shell=True).decode() + try: + return next( + line.split("CUDA Version: ")[1].split()[0] + for line in output.splitlines() + if "CUDA Version" in line + ) + except subprocess.CalledProcessError: + return "Failed to get CUDA version." + + +def get_first_gpu_info(): + try: + gpu_info = ( + subprocess.check_output( + "nvidia-smi --query-gpu=name,memory.total,memory.free,memory.used --format=csv,noheader", + shell=True, + ) + .decode() + .strip() + ) + if gpu_info: + gpus = gpu_info.split("\n") + num_gpus = len(gpus) + first_gpu = gpus[0] # Get the information for the first GPU + gpu_name, mem_total, _, _ = first_gpu.split(",") + return f"{num_gpus} x {gpu_name.strip()} ({round(int(mem_total.strip().split()[0]) / (1024), 2)} GB)" + else: + print("No GPU found or unable to query GPU details.") + except subprocess.CalledProcessError: + print("Failed to execute nvidia-smi. No GPU information available.") + + +def get_system_info(): + print('
') + print(f"

Hostname: {socket.gethostname()}

") + print( + f'

Operating System: {platform.system()} {platform.release()}

' + ) + print(f'

Kernel Version : {platform.version()}

') + with open("/proc/cpuinfo") as f: + print( + f'

CPU: {next(line.strip().split(": ")[1] for line in f if "model name" in line)} ({psutil.cpu_count(logical=False)} cores)

' + ) + print(f'

Memory: {get_mem_info()} GB

') + print(f"

GPU: {get_first_gpu_info()}

") + print(f"

CUDA Version: {get_cuda_version()}

") + + +if __name__ == "__main__": + logs_dir = pathlib.Path("logs") + + dataset_patt = re.compile(".*ds=([\w-]+).*") + backend_patt = re.compile(".*backend=(\w+).*") + k_patt = re.compile(".*k=(10*).*") + + # Organize all benchmark runs by the following hierarchy: algo -> backend -> dataset + benchmarks = get_all_benchmark_info() + + # dump HTML table + ordered_datasets = [ + "netscience", + "email_Eu_core", + "cit-patents", + "hollywood", + "soc-livejournal1", + ] + # dataset, # Node, # Edge, Directed info + dataset_meta = { + "netscience": ["1,461", "5,484", "Yes"], + "email_Eu_core": ["1,005", "25,571", "Yes"], + "cit-patents": ["3,774,768", "16,518,948", "Yes"], + "hollywood": ["1,139,905", "57,515,616", "No"], + "soc-livejournal1": ["4,847,571", "68,993,773", "Yes"], + } + + print( + """ + + + + + + + + """ + ) + for ds in ordered_datasets: + print( + f" " + ) + print( + """ + + + """ + ) + for algo_name in sorted(benchmarks): + algo_runs = benchmarks[algo_name] + print(" ") + print(f" ") + # Proceed only if any results are present for both cugraph and NX + if "cugraph" in algo_runs and "networkx" in algo_runs: + cugraph_algo_runs = algo_runs["cugraph"] + networkx_algo_runs = algo_runs["networkx"] + datasets_in_both = set(cugraph_algo_runs).intersection(networkx_algo_runs) + + # populate the table with speedup results for each dataset in the order + # specified in ordered_datasets. If results for a run using a dataset + # are not present for both cugraph and NX, output an empty cell. 
+ for dataset in ordered_datasets: + if dataset in datasets_in_both: + cugraph_runtime = cugraph_algo_runs[dataset] + networkx_runtime = networkx_algo_runs[dataset] + (speedup, runtime_delta) = compute_perf_vals( + cugraph_runtime=cugraph_runtime, + networkx_runtime=networkx_runtime, + ) + nx_formatted = get_formatted_time_value(networkx_runtime) + cg_formatted = get_formatted_time_value(cugraph_runtime) + print( + f" " + ) + else: + print(f" ") + + # If a comparison between cugraph and NX cannot be made, output empty cells + # for each dataset + else: + for _ in range(len(ordered_datasets)): + print(" ") + print(" ") + print( + """ + \n
Dataset
Nodes
Edges
Directed
{ds}
{dataset_meta[ds][0]}
{dataset_meta[ds][1]}
{dataset_meta[ds][2]}
{algo_name}{nx_formatted} / {cg_formatted}
{speedup}
{runtime_delta}
+ \n
\n""") diff --git a/benchmarks/nx-cugraph/pytest-based/get_graph_bench_dataset.py b/benchmarks/nx-cugraph/pytest-based/get_graph_bench_dataset.py new file mode 100644 index 00000000000..5a0a15da8ee --- /dev/null +++ b/benchmarks/nx-cugraph/pytest-based/get_graph_bench_dataset.py @@ -0,0 +1,35 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Checks if a particular dataset has been downloaded inside the datasets dir +(RAPIDS_DATASET_ROOT_DIR). If not, the file will be downloaded using the +datasets API. + +Positional Arguments: + 1) dataset name (e.g. 'email_Eu_core', 'cit-patents') + available datasets can be found here: `python/cugraph/cugraph/datasets/__init__.py` +""" + +import sys + +import cugraph.datasets as cgds + + +if __name__ == "__main__": + # download and store dataset (csv) by using the Datasets API + dataset = sys.argv[1].replace("-", "_") + dataset_obj = getattr(cgds, dataset) + + if not dataset_obj.get_path().exists(): + dataset_obj.get_edgelist(download=True) diff --git a/benchmarks/nx-cugraph/pytest-based/run-main-benchmarks.sh b/benchmarks/nx-cugraph/pytest-based/run-main-benchmarks.sh new file mode 100755 index 00000000000..1a81fe4b80a --- /dev/null +++ b/benchmarks/nx-cugraph/pytest-based/run-main-benchmarks.sh @@ -0,0 +1,58 @@ +#!/bin/bash +# Copyright (c) 2024, NVIDIA CORPORATION. 
+# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# location to store datasets used for benchmarking +export RAPIDS_DATASET_ROOT_DIR=/datasets/cugraph +mkdir -p logs + +# list of algos, datasets, and back-ends to use in combinations +algos=" + pagerank + betweenness_centrality + louvain + shortest_path + weakly_connected_components + triangles + bfs_predecessors +" +datasets=" + netscience + email_Eu_core + cit_patents + hollywood + soc-livejournal +" +# None backend is default networkx +# cugraph-preconverted backend is nx-cugraph +backends=" + None + cugraph-preconverted +" + +for algo in $algos; do + for dataset in $datasets; do + python get_graph_bench_dataset.py $dataset + for backend in $backends; do + name="${backend}__${algo}__${dataset}" + echo "Running: $backend, $dataset, bench_$algo" + # command to reproduce test + # echo "RUNNING: \"pytest -sv -k \"$backend and $dataset and bench_$algo and not 1000\" --benchmark-json=\"logs/${name}.json\" bench_algos.py" + pytest -sv \ + -k "$backend and $dataset and bench_$algo and not 1000" \ + --benchmark-json="logs/${name}.json" \ + bench_algos.py 2>&1 | tee "logs/${name}.out" + done + done +done From 338e5e0944eeb7533288a274624720b730a21a81 Mon Sep 17 00:00:00 2001 From: Seunghwa Kang <45857425+seunghwak@users.noreply.github.com> Date: Tue, 3 Sep 2024 09:07:14 -0700 Subject: [PATCH 2/2] Heterogeneous renumbering implementation (#4602) This PR implements heterogeneous renumbering for GNN. 
In addition, * Update the existing (homogeneous) sampling post processing function test file extension from .cu to .cpp. * Remove the unused `renumber_sampled_edgelist` function (breaking because this function is removed) * Add a `stride_fill` utility function (thrust wrapper) * Add test utility functions to generate edge types & IDs. Closes #4412 Authors: - Seunghwa Kang (https://github.com/seunghwak) Approvers: - Chuck Hastings (https://github.com/ChuckHastings) - Alex Barghi (https://github.com/alexbarghi-nv) URL: https://github.com/rapidsai/cugraph/pull/4602 --- cpp/CMakeLists.txt | 2 - .../cugraph/detail/utility_wrappers.hpp | 22 + cpp/include/cugraph/graph_functions.hpp | 57 - cpp/include/cugraph/sampling_functions.hpp | 84 +- cpp/src/detail/utility_wrappers_32.cu | 17 + cpp/src/detail/utility_wrappers_64.cu | 12 + cpp/src/detail/utility_wrappers_impl.cuh | 16 + .../renumber_sampled_edgelist_impl.cuh | 719 ------- .../renumber_sampled_edgelist_sg_v32_e32.cu | 37 - .../renumber_sampled_edgelist_sg_v64_e64.cu | 37 - .../sampling_post_processing_impl.cuh | 1638 ++++++++++++++-- .../sampling_post_processing_sg_v32_e32.cu | 56 + .../sampling_post_processing_sg_v32_e64.cu | 56 + .../sampling_post_processing_sg_v64_e64.cu | 56 + cpp/tests/CMakeLists.txt | 11 +- .../sampling_post_processing_validate.cu | 1738 +++++++++++++++++ .../sampling_post_processing_validate.hpp | 101 + ...ing_heterogeneous_post_processing_test.cpp | 828 ++++++++ ...t.cu => sampling_post_processing_test.cpp} | 1151 ++++------- .../property_generator_utilities.hpp | 23 + .../property_generator_utilities_impl.cuh | 98 + cpp/tests/utilities/thrust_wrapper.cu | 69 + cpp/tests/utilities/thrust_wrapper.hpp | 14 + 23 files changed, 4971 insertions(+), 1871 deletions(-) delete mode 100644 cpp/src/sampling/renumber_sampled_edgelist_impl.cuh delete mode 100644 cpp/src/sampling/renumber_sampled_edgelist_sg_v32_e32.cu delete mode 100644 cpp/src/sampling/renumber_sampled_edgelist_sg_v64_e64.cu create mode 
100644 cpp/tests/sampling/detail/sampling_post_processing_validate.cu create mode 100644 cpp/tests/sampling/detail/sampling_post_processing_validate.hpp create mode 100644 cpp/tests/sampling/sampling_heterogeneous_post_processing_test.cpp rename cpp/tests/sampling/{sampling_post_processing_test.cu => sampling_post_processing_test.cpp} (52%) diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 26b710247f6..b8eaba9d575 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -338,8 +338,6 @@ set(CUGRAPH_SOURCES src/sampling/negative_sampling_mg_v32_e64.cu src/sampling/negative_sampling_mg_v32_e32.cu src/sampling/negative_sampling_mg_v64_e64.cu - src/sampling/renumber_sampled_edgelist_sg_v64_e64.cu - src/sampling/renumber_sampled_edgelist_sg_v32_e32.cu src/sampling/sampling_post_processing_sg_v64_e64.cu src/sampling/sampling_post_processing_sg_v32_e32.cu src/sampling/sampling_post_processing_sg_v32_e64.cu diff --git a/cpp/include/cugraph/detail/utility_wrappers.hpp b/cpp/include/cugraph/detail/utility_wrappers.hpp index 61ac1bd2804..3d99b85556b 100644 --- a/cpp/include/cugraph/detail/utility_wrappers.hpp +++ b/cpp/include/cugraph/detail/utility_wrappers.hpp @@ -87,6 +87,28 @@ void sequence_fill(rmm::cuda_stream_view const& stream_view, size_t size, value_t start_value); +/** + * @brief Fill a buffer with a sequence of values with the input stride + * + * Fills the buffer with the sequence with the input stride: + * {start_value, start_value+stride, start_value+stride*2, ..., start_value+stride*(size-1)} + * + * @tparam value_t type of the value to operate on + * + * @param[in] stream_view stream view + * @param[out] d_value device array to fill + * @param[in] size number of elements in array + * @param[in] start_value starting value for sequence + * @param[in] stride input stride + * + */ +template +void stride_fill(rmm::cuda_stream_view const& stream_view, + value_t* d_value, + size_t size, + value_t start_value, + value_t stride); + /** * @brief Compute the 
maximum vertex id of an edge list * diff --git a/cpp/include/cugraph/graph_functions.hpp b/cpp/include/cugraph/graph_functions.hpp index 7f6543ccab8..866ab16ee97 100644 --- a/cpp/include/cugraph/graph_functions.hpp +++ b/cpp/include/cugraph/graph_functions.hpp @@ -988,63 +988,6 @@ rmm::device_uvector select_random_vertices( bool sort_vertices, bool do_expensive_check = false); -/** - * @brief renumber sampling output - * - * @deprecated This API will be deprecated and will be replaced by the - * renumber_and_compress_sampled_edgelist and renumber_and_sort_sampled_edgelist functions in - * sampling_functions.hpp. - * - * This function renumbers sampling function (e.g. uniform_neighbor_sample) outputs satisfying the - * following requirements. - * - * 1. If @p edgelist_hops is valid, we can consider (vertex ID, flag=src, hop) triplets for each - * vertex ID in @p edgelist_srcs and (vertex ID, flag=dst, hop) triplets for each vertex ID in @p - * edgelist_dsts. From these triplets, we can find the minimum (hop, flag) pairs for every unique - * vertex ID (hop is the primary key and flag is the secondary key, flag=src is considered smaller - * than flag=dst if hop numbers are same). Vertex IDs with smaller (hop, flag) pairs precede vertex - * IDs with larger (hop, flag) pairs in renumbering. Ordering can be arbitrary among the vertices - * with the same (hop, flag) pairs. - * 2. If @p edgelist_hops is invalid, unique vertex IDs in @p edgelist_srcs precede vertex IDs that - * appear only in @p edgelist_dsts. - * 3. If label_offsets.has_value() is ture, edge lists for different labels will be renumbered - * separately. - * - * This function is single-GPU only (we are not aware of any practical multi-GPU use cases). - * - * @tparam vertex_t Type of vertex identifiers. Needs to be an integral type. - * @tparam label_t Type of labels. Needs to be an integral type. - * @param handle RAFT handle object to encapsulate resources (e.g. 
CUDA stream, communicator, and - * handles to various CUDA libraries) to run graph algorithms. - * @param edgelist_srcs A vector storing original edgelist source vertices. - * @param edgelist_dsts A vector storing original edgelist destination vertices (size = @p - * edgelist_srcs.size()). - * @param edgelist_hops An optional pointer to the array storing hops for each edge list (source, - * destination) pairs (size = @p edgelist_srcs.size() if valid). - * @param label_offsets An optional tuple of unique labels and the input edge list (@p - * edgelist_srcs, @p edgelist_hops, and @p edgelist_dsts) offsets for the labels (siez = # unique - * labels + 1). - * @param do_expensive_check A flag to run expensive checks for input arguments (if set to `true`). - * @return Tuple of vectors storing renumbered edge sources (size = @p edgelist_srcs.size()) , - * renumbered edge destinations (size = @p edgelist_dsts.size()), renumber_map to query original - * verties (size = # unique vertices or aggregate # unique vertices for every label), and - * renumber_map offsets (size = std::get<0>(*label_offsets).size() + 1, valid only if @p - * label_offsets.has_value() is true). - */ -template -std::tuple, - rmm::device_uvector, - rmm::device_uvector, - std::optional>> -renumber_sampled_edgelist( - raft::handle_t const& handle, - rmm::device_uvector&& edgelist_srcs, - rmm::device_uvector&& edgelist_dsts, - std::optional> edgelist_hops, - std::optional, raft::device_span>> - label_offsets, - bool do_expensive_check = false); - /** * @brief Remove self loops from an edge list * diff --git a/cpp/include/cugraph/sampling_functions.hpp b/cpp/include/cugraph/sampling_functions.hpp index 4e5596d06e0..783cd3a7e2b 100644 --- a/cpp/include/cugraph/sampling_functions.hpp +++ b/cpp/include/cugraph/sampling_functions.hpp @@ -476,12 +476,12 @@ renumber_and_sort_sampled_edgelist( * 1. 
If @p edgelist_hops is valid, we can consider (vertex ID, hop, flag=major) triplets for each * vertex ID in edge majors (@p edgelist_srcs if @p src_is_major is true, @p edgelist_dsts if false) * and (vertex ID, hop, flag=minor) triplets for each vertex ID in edge minors. From these triplets, - * we can find the minimum (hop, flag) pairs for every unique vertex ID (hop is the primary key and + * we can find the minimum (hop, flag) pair for every unique vertex ID (hop is the primary key and * flag is the secondary key, flag=major is considered smaller than flag=minor if hop numbers are * same). Vertex IDs with smaller (hop, flag) pairs precede vertex IDs with larger (hop, flag) pairs * in renumbering (if their vertex types are same, vertices with different types are renumbered * separately). Ordering can be arbitrary among the vertices with the same (vertex type, hop, flag) - * triplets. If @p seed_vertices.has-value() is true, we assume (hop=0, flag=major) for every vertex + * triplets. If @p seed_vertices.has_value() is true, we assume (hop=0, flag=major) for every vertex * in @p *seed_vertices in renumbering (this is relevant when there are seed vertices with no * neighbors). * 2. If @p edgelist_hops is invalid, unique vertex IDs in edge majors precede vertex IDs that @@ -495,11 +495,15 @@ renumber_and_sort_sampled_edgelist( * Edge IDs are renumbered fulfilling the following requirements (This is relevant only when @p * edgelist_edge_ids.has_value() is true). * - * 1. If @p edgelist_edge_types.has_value() is true, unique (edge type, edge ID) pairs are - * renumbered to consecutive integers starting from 0 for each edge type. If @p - * edgelist_edge_types.has_value() is true, unique edge IDs are renumbered to consecutive inetgers - * starting from 0. - * 2. If edgelist_label_offsets.has_value() is true, edge lists for different labels will be + * 1. If @p edgelist_hops is valid, we can consider (edge ID, hop) pairs. 
From these pairs, we can + * find the minimum hop value for every unique edge ID. Edge IDs with smaller hop values precede + * edge IDs with larger hop values in renumbering (if their edge types are same, edges with + * different edge types are renumbered separately). Ordering can be arbitrary among the edge IDs + * with the same (edge type, hop) pairs. + * 2. If @p edgelist_hops.has_value() is false, unique edge IDs (for each edge type if @p + * edgelist_edge_types.has_value() is true) are mapped to consecutive integers starting from 0. The + * ordering can be arbitrary. + * 3. If edgelist_label_offsets.has_value() is true, edge lists for different labels will be * renumbered separately. * * The renumbered edges are sorted based on the following rules. @@ -510,6 +514,11 @@ renumber_and_sort_sampled_edgelist( * true. * 2. Edges in each label are sorted independently if @p edgelist_label_offsets.has_value() is true. * + * This function assumes that there is a single edge source vertex type and a single edge + * destination vertex type for each edge. If @p edgelist_edge_types.has_value() is false (i.e. there + * is only one edge type), there should be only one edge source vertex type and only one edge + * destination vertex type; the source & destination vertex types may or may not coincide. + * * This function is single-GPU only (we are not aware of any practical multi-GPU use cases). * * @tparam vertex_t Type of vertex identifiers. Needs to be an integral type. @@ -530,19 +539,16 @@ renumber_and_sort_sampled_edgelist( * edgelist_srcs.size() if valid). * @param edgelist_hops An optional vector storing edge list hop numbers (size = @p * edgelist_srcs.size() if valid). @p edgelist_hops should be valid if @p num_hops >= 2. - * @param edgelist_label_offsets An optional pointer to the array storing label offsets to the input - * edges (size = @p num_labels + 1). @p edgelist_label_offsets should be valid if @p num_labels - * >= 2. 
* @param seed_vertices An optional pointer to the array storing seed vertices in hop 0. * @param seed_vertex_label_offsets An optional pointer to the array storing label offsets to the * seed vertices (size = @p num_labels + 1). @p seed_vertex_label_offsets should be valid if @p * num_labels >= 2 and @p seed_vertices is valid and invalid otherwise. - * ext_vertices A pointer to the array storing external vertex IDs for the local internal vertices. - * The local internal vertex range can be obatined bgy invoking a graph_view_t object's - * local_vertex_partition_range() function. ext_vertex_type offsets A pointer to the array storing - * vertex type offsets for the entire external vertex ID range (array size = @p num_vertex_types + - * 1). For example, if the array stores [0, 100, 200], external vertex IDs [0, 100) has vertex type - * 0 and external vertex IDs [100, 200) has vertex type 1. + * @param edgelist_label_offsets An optional pointer to the array storing label offsets to the input + * edges (size = @p num_labels + 1). @p edgelist_label_offsets should be valid if @p num_labels + * >= 2. + * @param vertex_type_offsets A pointer to the array storing vertex type offsets for the entire + * vertex ID range (array size = @p num_vertex_types + 1). For example, if the array stores [0, 100, + * 200], vertex IDs [0, 100) has vertex type 0 and vertex IDs [100, 200) has vertex type 1. * @param num_labels Number of labels. Labels are considered if @p num_labels >=2 and ignored if @p * num_labels = 1. * @param num_hops Number of hops. Hop numbers are considered if @p num_hops >=2 and ignored if @p @@ -552,31 +558,36 @@ renumber_and_sort_sampled_edgelist( * @param src_is_major A flag to determine whether to use the source or destination as the * major key in renumbering and sorting. * @param do_expensive_check A flag to run expensive checks for input arguments (if set to `true`). 
- * @return Tuple of vectors storing edge sources, edge destinations, optional edge weights (valid - * only if @p edgelist_weights.has_value() is true), optional edge IDs (valid only if @p - * edgelist_edge_ids.has_value() is true), optional edge types (valid only if @p - * edgelist_edge_types.has_value() is true), optional (label, hop) offset values to the renumbered - * and sorted edges (size = @p num_labels * @p num_hops + 1, valid only when @p - * edgelist_hops.has_value() or @p edgelist_label_offsetes.has_value() is true), renumber_map to - * query original vertices (size = # unique or aggregate # unique vertices for each label), and - * label offsets to the renumber map (size = @p num_labels + 1, valid only if @p - * edgelist_label_offsets.has_value() is true). + * @return Tuple of vectors storing renumbered edge sources, renumbered edge destinations, optional + * edge weights (valid only if @p edgelist_weights.has_value() is true), optional renumbered edge + * IDs (valid only if @p edgelist_edge_ids.has_value() is true), optional (label, edge type, hop) + * offset values to the renumbered and sorted edges (size = @p num_labels * @p num_edge_types * @p + * num_hops + 1, valid only when @p edgelist_edge_types.has_value(), @p edgelist_hops.has_value(), + * or @p edgelist_label_offsets.has_value() is true), renumber_map to query original vertices (size + * = # unique or aggregate # unique vertices for each label), (label, vertex type) offsets to the + * vertex renumber map (size = @p num_labels * @p num_vertex_types + 1), optional renumber_map to + * query original edge IDs (size = # unique (edge_type, edge ID) pairs, valid only if @p + * edgelist_edge_ids.has_value() is true), and optional (label, edge type) offsets to the edge ID + * renumber map (size = @p num_labels * @p num_edge_types + 1, valid only if @p + * edgelist_edge_ids.has_value() is true). 
We do not explicitly return edge source & destination + * vertex types as we assume that source & destination vertex type are implicitly determined for a + * given edge type. */ template std::tuple< - rmm::device_uvector, // srcs - rmm::device_uvector, // dsts - std::optional>, // weights - std::optional>, // edge IDs - std::optional>, // edge types - std::optional>, // (label, edge type, hop) offsets to the edges - rmm::device_uvector, // vertex renumber map - std::optional>, // (label, type) offsets to the vertex renumber map + rmm::device_uvector, // srcs + rmm::device_uvector, // dsts + std::optional>, // weights + std::optional>, // edge IDs + std::optional>, // (label, edge type, hop) offsets to the edges + rmm::device_uvector, // vertex renumber map + rmm::device_uvector, // (label, vertex type) offsets to the vertex renumber map std::optional>, // edge ID renumber map - std::optional>> // (label, type) offsets to the edge ID renumber map + std::optional< + rmm::device_uvector>> // (label, edge type) offsets to the edge ID renumber map heterogeneous_renumber_and_sort_sampled_edgelist( raft::handle_t const& handle, rmm::device_uvector&& edgelist_srcs, @@ -585,11 +596,10 @@ heterogeneous_renumber_and_sort_sampled_edgelist( std::optional>&& edgelist_edge_ids, std::optional>&& edgelist_edge_types, std::optional>&& edgelist_hops, - std::optional> edgelist_label_offsets, std::optional> seed_vertices, std::optional> seed_vertex_label_offsets, - raft::device_span ext_vertices, - raft::device_span ext_vertex_type_offsets, + std::optional> edgelist_label_offsets, + raft::device_span vertex_type_offsets, size_t num_labels, size_t num_hops, size_t num_vertex_types, diff --git a/cpp/src/detail/utility_wrappers_32.cu b/cpp/src/detail/utility_wrappers_32.cu index 72dee4a19a5..de407f12493 100644 --- a/cpp/src/detail/utility_wrappers_32.cu +++ b/cpp/src/detail/utility_wrappers_32.cu @@ -68,6 +68,23 @@ template void sequence_fill(rmm::cuda_stream_view const& stream_view, 
size_t size, int32_t start_value); +template void sequence_fill(rmm::cuda_stream_view const& stream_view, + uint32_t* d_value, + size_t size, + uint32_t start_value); + +template void stride_fill(rmm::cuda_stream_view const& stream_view, + int32_t* d_value, + size_t size, + int32_t start_value, + int32_t stride); + +template void stride_fill(rmm::cuda_stream_view const& stream_view, + uint32_t* d_value, + size_t size, + uint32_t start_value, + uint32_t stride); + template int32_t compute_maximum_vertex_id(rmm::cuda_stream_view const& stream_view, int32_t const* d_edgelist_srcs, int32_t const* d_edgelist_dsts, diff --git a/cpp/src/detail/utility_wrappers_64.cu b/cpp/src/detail/utility_wrappers_64.cu index e7254d97c4d..2c136d5902b 100644 --- a/cpp/src/detail/utility_wrappers_64.cu +++ b/cpp/src/detail/utility_wrappers_64.cu @@ -71,6 +71,18 @@ template void sequence_fill(rmm::cuda_stream_view const& stream_view, size_t size, uint64_t start_value); +template void stride_fill(rmm::cuda_stream_view const& stream_view, + int64_t* d_value, + size_t size, + int64_t start_value, + int64_t stride); + +template void stride_fill(rmm::cuda_stream_view const& stream_view, + uint64_t* d_value, + size_t size, + uint64_t start_value, + uint64_t stride); + template int64_t compute_maximum_vertex_id(rmm::cuda_stream_view const& stream_view, int64_t const* d_edgelist_srcs, int64_t const* d_edgelist_dsts, diff --git a/cpp/src/detail/utility_wrappers_impl.cuh b/cpp/src/detail/utility_wrappers_impl.cuh index ce8549db9f8..074d7044261 100644 --- a/cpp/src/detail/utility_wrappers_impl.cuh +++ b/cpp/src/detail/utility_wrappers_impl.cuh @@ -72,6 +72,22 @@ void sequence_fill(rmm::cuda_stream_view const& stream_view, thrust::sequence(rmm::exec_policy(stream_view), d_value, d_value + size, start_value); } +template +void stride_fill(rmm::cuda_stream_view const& stream_view, + value_t* d_value, + size_t size, + value_t start_value, + value_t stride) +{ + 
thrust::transform(rmm::exec_policy(stream_view), + thrust::make_counting_iterator(size_t{0}), + thrust::make_counting_iterator(size), + d_value, + cuda::proclaim_return_type([start_value, stride] __device__(size_t i) { + return static_cast(start_value + stride * i); + })); +} + template vertex_t compute_maximum_vertex_id(rmm::cuda_stream_view const& stream_view, vertex_t const* d_edgelist_srcs, diff --git a/cpp/src/sampling/renumber_sampled_edgelist_impl.cuh b/cpp/src/sampling/renumber_sampled_edgelist_impl.cuh deleted file mode 100644 index f5bc3ef6d2e..00000000000 --- a/cpp/src/sampling/renumber_sampled_edgelist_impl.cuh +++ /dev/null @@ -1,719 +0,0 @@ -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#pragma once - -#include "prims/kv_store.cuh" - -#include -#include -#include - -#include - -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include - -// FIXME: deprecated, to be deleted -namespace cugraph { - -namespace { - -// output sorted by (primary key:label_index, secondary key:vertex) -template -std::tuple> /* label indices */, - rmm::device_uvector /* vertices */, - std::optional> /* minimum hops for the vertices */, - std::optional> /* label offsets for the output */> -compute_min_hop_for_unique_label_vertex_pairs( - raft::handle_t const& handle, - raft::device_span vertices, - std::optional> hops, - std::optional> label_indices, - std::optional> label_offsets) -{ - auto approx_edges_to_sort_per_iteration = - static_cast(handle.get_device_properties().multiProcessorCount) * - (1 << 20) /* tuning parameter */; // for segmented sort - - if (label_indices) { - auto num_labels = (*label_offsets).size() - 1; - - rmm::device_uvector tmp_label_indices((*label_indices).size(), - handle.get_stream()); - thrust::copy(handle.get_thrust_policy(), - (*label_indices).begin(), - (*label_indices).end(), - tmp_label_indices.begin()); - - rmm::device_uvector tmp_vertices(0, handle.get_stream()); - std::optional> tmp_hops{std::nullopt}; - - if (hops) { - tmp_vertices.resize(vertices.size(), handle.get_stream()); - thrust::copy( - handle.get_thrust_policy(), vertices.begin(), vertices.end(), tmp_vertices.begin()); - tmp_hops = rmm::device_uvector((*hops).size(), handle.get_stream()); - thrust::copy(handle.get_thrust_policy(), (*hops).begin(), (*hops).end(), (*tmp_hops).begin()); - - auto triplet_first = thrust::make_zip_iterator( - tmp_label_indices.begin(), tmp_vertices.begin(), (*tmp_hops).begin()); - thrust::sort( - handle.get_thrust_policy(), triplet_first, triplet_first + tmp_label_indices.size()); - auto key_first = 
thrust::make_zip_iterator(tmp_label_indices.begin(), tmp_vertices.begin()); - auto num_uniques = static_cast( - thrust::distance(key_first, - thrust::get<0>(thrust::unique_by_key(handle.get_thrust_policy(), - key_first, - key_first + tmp_label_indices.size(), - (*tmp_hops).begin())))); - tmp_label_indices.resize(num_uniques, handle.get_stream()); - tmp_vertices.resize(num_uniques, handle.get_stream()); - (*tmp_hops).resize(num_uniques, handle.get_stream()); - tmp_label_indices.shrink_to_fit(handle.get_stream()); - tmp_vertices.shrink_to_fit(handle.get_stream()); - (*tmp_hops).shrink_to_fit(handle.get_stream()); - } else { - rmm::device_uvector segment_sorted_vertices(vertices.size(), handle.get_stream()); - - rmm::device_uvector d_tmp_storage(0, handle.get_stream()); - - auto [h_label_offsets, h_edge_offsets] = detail::compute_offset_aligned_element_chunks( - handle, *label_offsets, vertices.size(), approx_edges_to_sort_per_iteration); - auto num_chunks = h_label_offsets.size() - 1; - - for (size_t i = 0; i < num_chunks; ++i) { - size_t tmp_storage_bytes{0}; - - auto offset_first = - thrust::make_transform_iterator((*label_offsets).data() + h_label_offsets[i], - detail::shift_left_t{h_edge_offsets[i]}); - cub::DeviceSegmentedSort::SortKeys(static_cast(nullptr), - tmp_storage_bytes, - vertices.begin() + h_edge_offsets[i], - segment_sorted_vertices.begin() + h_edge_offsets[i], - h_edge_offsets[i + 1] - h_edge_offsets[i], - h_label_offsets[i + 1] - h_label_offsets[i], - offset_first, - offset_first + 1, - handle.get_stream()); - - if (tmp_storage_bytes > d_tmp_storage.size()) { - d_tmp_storage = rmm::device_uvector(tmp_storage_bytes, handle.get_stream()); - } - - cub::DeviceSegmentedSort::SortKeys(d_tmp_storage.data(), - tmp_storage_bytes, - vertices.begin() + h_edge_offsets[i], - segment_sorted_vertices.begin() + h_edge_offsets[i], - h_edge_offsets[i + 1] - h_edge_offsets[i], - h_label_offsets[i + 1] - h_label_offsets[i], - offset_first, - offset_first + 1, - 
handle.get_stream()); - } - d_tmp_storage.resize(0, handle.get_stream()); - d_tmp_storage.shrink_to_fit(handle.get_stream()); - - auto pair_first = - thrust::make_zip_iterator(tmp_label_indices.begin(), segment_sorted_vertices.begin()); - auto num_uniques = static_cast(thrust::distance( - pair_first, - thrust::unique( - handle.get_thrust_policy(), pair_first, pair_first + tmp_label_indices.size()))); - tmp_label_indices.resize(num_uniques, handle.get_stream()); - segment_sorted_vertices.resize(num_uniques, handle.get_stream()); - tmp_label_indices.shrink_to_fit(handle.get_stream()); - segment_sorted_vertices.shrink_to_fit(handle.get_stream()); - - tmp_vertices = std::move(segment_sorted_vertices); - } - - rmm::device_uvector tmp_label_offsets(num_labels + 1, handle.get_stream()); - tmp_label_offsets.set_element_to_zero_async(0, handle.get_stream()); - thrust::upper_bound(handle.get_thrust_policy(), - tmp_label_indices.begin(), - tmp_label_indices.end(), - thrust::make_counting_iterator(size_t{0}), - thrust::make_counting_iterator(num_labels), - tmp_label_offsets.begin() + 1); - - return std::make_tuple(std::move(tmp_label_indices), - std::move(tmp_vertices), - std::move(tmp_hops), - std::move(tmp_label_offsets)); - } else { - rmm::device_uvector tmp_vertices(vertices.size(), handle.get_stream()); - thrust::copy( - handle.get_thrust_policy(), vertices.begin(), vertices.end(), tmp_vertices.begin()); - - if (hops) { - rmm::device_uvector tmp_hops((*hops).size(), handle.get_stream()); - thrust::copy(handle.get_thrust_policy(), (*hops).begin(), (*hops).end(), tmp_hops.begin()); - - auto pair_first = thrust::make_zip_iterator( - tmp_vertices.begin(), tmp_hops.begin()); // vertex is a primary key, hop is a secondary key - thrust::sort(handle.get_thrust_policy(), pair_first, pair_first + tmp_vertices.size()); - tmp_vertices.resize( - thrust::distance(tmp_vertices.begin(), - thrust::get<0>(thrust::unique_by_key(handle.get_thrust_policy(), - tmp_vertices.begin(), - 
tmp_vertices.end(), - tmp_hops.begin()))), - handle.get_stream()); - tmp_hops.resize(tmp_vertices.size(), handle.get_stream()); - - return std::make_tuple( - std::nullopt, std::move(tmp_vertices), std::move(tmp_hops), std::nullopt); - } else { - thrust::sort(handle.get_thrust_policy(), tmp_vertices.begin(), tmp_vertices.end()); - tmp_vertices.resize( - thrust::distance( - tmp_vertices.begin(), - thrust::unique(handle.get_thrust_policy(), tmp_vertices.begin(), tmp_vertices.end())), - handle.get_stream()); - tmp_vertices.shrink_to_fit(handle.get_stream()); - - return std::make_tuple(std::nullopt, std::move(tmp_vertices), std::nullopt, std::nullopt); - } - } -} - -template -std::tuple, std::optional>> -compute_renumber_map(raft::handle_t const& handle, - raft::device_span edgelist_srcs, - raft::device_span edgelist_dsts, - std::optional> edgelist_hops, - std::optional> label_offsets) -{ - auto approx_edges_to_sort_per_iteration = - static_cast(handle.get_device_properties().multiProcessorCount) * - (1 << 20) /* tuning parameter */; // for segmented sort - - std::optional> edgelist_label_indices{std::nullopt}; - if (label_offsets) { - edgelist_label_indices = - detail::expand_sparse_offsets(*label_offsets, label_index_t{0}, handle.get_stream()); - } - - auto [unique_label_src_pair_label_indices, - unique_label_src_pair_vertices, - unique_label_src_pair_hops, - unique_label_src_pair_label_offsets] = - compute_min_hop_for_unique_label_vertex_pairs( - handle, - edgelist_srcs, - edgelist_hops, - edgelist_label_indices ? std::make_optional>( - (*edgelist_label_indices).data(), (*edgelist_label_indices).size()) - : std::nullopt, - label_offsets); - - auto [unique_label_dst_pair_label_indices, - unique_label_dst_pair_vertices, - unique_label_dst_pair_hops, - unique_label_dst_pair_label_offsets] = - compute_min_hop_for_unique_label_vertex_pairs( - handle, - edgelist_dsts, - edgelist_hops, - edgelist_label_indices ? 
std::make_optional>( - (*edgelist_label_indices).data(), (*edgelist_label_indices).size()) - : std::nullopt, - label_offsets); - - edgelist_label_indices = std::nullopt; - - if (label_offsets) { - auto num_labels = (*label_offsets).size() - 1; - - rmm::device_uvector renumber_map(0, handle.get_stream()); - rmm::device_uvector renumber_map_label_indices(0, handle.get_stream()); - - renumber_map.reserve( - (*unique_label_src_pair_label_indices).size() + (*unique_label_dst_pair_label_indices).size(), - handle.get_stream()); - renumber_map_label_indices.reserve(renumber_map.capacity(), handle.get_stream()); - - auto num_chunks = (edgelist_srcs.size() + (approx_edges_to_sort_per_iteration - 1)) / - approx_edges_to_sort_per_iteration; - auto chunk_size = (num_chunks > 0) ? ((num_labels + (num_chunks - 1)) / num_chunks) : 0; - - size_t copy_offset{0}; - for (size_t i = 0; i < num_chunks; ++i) { - auto src_start_offset = - (*unique_label_src_pair_label_offsets).element(chunk_size * i, handle.get_stream()); - auto src_end_offset = - (*unique_label_src_pair_label_offsets) - .element(std::min(chunk_size * (i + 1), num_labels), handle.get_stream()); - auto dst_start_offset = - (*unique_label_dst_pair_label_offsets).element(chunk_size * i, handle.get_stream()); - auto dst_end_offset = - (*unique_label_dst_pair_label_offsets) - .element(std::min(chunk_size * (i + 1), num_labels), handle.get_stream()); - - rmm::device_uvector merged_label_indices( - (src_end_offset - src_start_offset) + (dst_end_offset - dst_start_offset), - handle.get_stream()); - rmm::device_uvector merged_vertices(merged_label_indices.size(), - handle.get_stream()); - rmm::device_uvector merged_flags(merged_label_indices.size(), handle.get_stream()); - - if (edgelist_hops) { - rmm::device_uvector merged_hops(merged_label_indices.size(), handle.get_stream()); - auto src_quad_first = - thrust::make_zip_iterator((*unique_label_src_pair_label_indices).begin(), - unique_label_src_pair_vertices.begin(), - 
(*unique_label_src_pair_hops).begin(), - thrust::make_constant_iterator(int8_t{0})); - auto dst_quad_first = - thrust::make_zip_iterator((*unique_label_dst_pair_label_indices).begin(), - unique_label_dst_pair_vertices.begin(), - (*unique_label_dst_pair_hops).begin(), - thrust::make_constant_iterator(int8_t{1})); - thrust::merge(handle.get_thrust_policy(), - src_quad_first + src_start_offset, - src_quad_first + src_end_offset, - dst_quad_first + dst_start_offset, - dst_quad_first + dst_end_offset, - thrust::make_zip_iterator(merged_label_indices.begin(), - merged_vertices.begin(), - merged_hops.begin(), - merged_flags.begin())); - - auto unique_key_first = - thrust::make_zip_iterator(merged_label_indices.begin(), merged_vertices.begin()); - merged_label_indices.resize( - thrust::distance( - unique_key_first, - thrust::get<0>(thrust::unique_by_key( - handle.get_thrust_policy(), - unique_key_first, - unique_key_first + merged_label_indices.size(), - thrust::make_zip_iterator(merged_hops.begin(), merged_flags.begin())))), - handle.get_stream()); - merged_vertices.resize(merged_label_indices.size(), handle.get_stream()); - merged_hops.resize(merged_label_indices.size(), handle.get_stream()); - merged_flags.resize(merged_label_indices.size(), handle.get_stream()); - auto sort_key_first = thrust::make_zip_iterator( - merged_label_indices.begin(), merged_hops.begin(), merged_flags.begin()); - thrust::sort_by_key(handle.get_thrust_policy(), - sort_key_first, - sort_key_first + merged_label_indices.size(), - merged_vertices.begin()); - } else { - auto src_triplet_first = - thrust::make_zip_iterator((*unique_label_src_pair_label_indices).begin(), - unique_label_src_pair_vertices.begin(), - thrust::make_constant_iterator(int8_t{0})); - auto dst_triplet_first = - thrust::make_zip_iterator((*unique_label_dst_pair_label_indices).begin(), - unique_label_dst_pair_vertices.begin(), - thrust::make_constant_iterator(int8_t{1})); - thrust::merge( - handle.get_thrust_policy(), - 
src_triplet_first + src_start_offset, - src_triplet_first + src_end_offset, - dst_triplet_first + dst_start_offset, - dst_triplet_first + dst_end_offset, - thrust::make_zip_iterator( - merged_label_indices.begin(), merged_vertices.begin(), merged_flags.begin())); - - auto unique_key_first = - thrust::make_zip_iterator(merged_label_indices.begin(), merged_vertices.begin()); - merged_label_indices.resize( - thrust::distance( - unique_key_first, - thrust::get<0>(thrust::unique_by_key(handle.get_thrust_policy(), - unique_key_first, - unique_key_first + merged_label_indices.size(), - merged_flags.begin()))), - handle.get_stream()); - merged_vertices.resize(merged_label_indices.size(), handle.get_stream()); - merged_flags.resize(merged_label_indices.size(), handle.get_stream()); - auto sort_key_first = - thrust::make_zip_iterator(merged_label_indices.begin(), merged_flags.begin()); - thrust::sort_by_key(handle.get_thrust_policy(), - sort_key_first, - sort_key_first + merged_label_indices.size(), - merged_vertices.begin()); - } - - renumber_map.resize(copy_offset + merged_vertices.size(), handle.get_stream()); - thrust::copy(handle.get_thrust_policy(), - merged_vertices.begin(), - merged_vertices.end(), - renumber_map.begin() + copy_offset); - renumber_map_label_indices.resize(copy_offset + merged_label_indices.size(), - handle.get_stream()); - thrust::copy(handle.get_thrust_policy(), - merged_label_indices.begin(), - merged_label_indices.end(), - renumber_map_label_indices.begin() + copy_offset); - - copy_offset += merged_vertices.size(); - } - - renumber_map.shrink_to_fit(handle.get_stream()); - renumber_map_label_indices.shrink_to_fit(handle.get_stream()); - - return std::make_tuple(std::move(renumber_map), std::move(renumber_map_label_indices)); - } else { - if (edgelist_hops) { - rmm::device_uvector merged_vertices( - unique_label_src_pair_vertices.size() + unique_label_dst_pair_vertices.size(), - handle.get_stream()); - rmm::device_uvector 
merged_hops(merged_vertices.size(), handle.get_stream()); - rmm::device_uvector merged_flags(merged_vertices.size(), handle.get_stream()); - auto src_triplet_first = thrust::make_zip_iterator(unique_label_src_pair_vertices.begin(), - (*unique_label_src_pair_hops).begin(), - thrust::make_constant_iterator(int8_t{0})); - auto dst_triplet_first = thrust::make_zip_iterator(unique_label_dst_pair_vertices.begin(), - (*unique_label_dst_pair_hops).begin(), - thrust::make_constant_iterator(int8_t{1})); - thrust::merge(handle.get_thrust_policy(), - src_triplet_first, - src_triplet_first + unique_label_src_pair_vertices.size(), - dst_triplet_first, - dst_triplet_first + unique_label_dst_pair_vertices.size(), - thrust::make_zip_iterator( - merged_vertices.begin(), merged_hops.begin(), merged_flags.begin())); - - unique_label_src_pair_vertices.resize(0, handle.get_stream()); - unique_label_src_pair_vertices.shrink_to_fit(handle.get_stream()); - unique_label_src_pair_hops = std::nullopt; - unique_label_dst_pair_vertices.resize(0, handle.get_stream()); - unique_label_dst_pair_vertices.shrink_to_fit(handle.get_stream()); - unique_label_dst_pair_hops = std::nullopt; - - merged_vertices.resize( - thrust::distance(merged_vertices.begin(), - thrust::get<0>(thrust::unique_by_key( - handle.get_thrust_policy(), - merged_vertices.begin(), - merged_vertices.end(), - thrust::make_zip_iterator(merged_hops.begin(), merged_flags.begin())))), - handle.get_stream()); - merged_hops.resize(merged_vertices.size(), handle.get_stream()); - merged_flags.resize(merged_vertices.size(), handle.get_stream()); - - auto sort_key_first = thrust::make_zip_iterator(merged_hops.begin(), merged_flags.begin()); - thrust::sort_by_key(handle.get_thrust_policy(), - sort_key_first, - sort_key_first + merged_hops.size(), - merged_vertices.begin()); - - return std::make_tuple(std::move(merged_vertices), std::nullopt); - } else { - rmm::device_uvector output_vertices(unique_label_dst_pair_vertices.size(), - 
handle.get_stream()); - auto output_last = thrust::set_difference(handle.get_thrust_policy(), - unique_label_dst_pair_vertices.begin(), - unique_label_dst_pair_vertices.end(), - unique_label_src_pair_vertices.begin(), - unique_label_src_pair_vertices.end(), - output_vertices.begin()); - - auto num_unique_srcs = unique_label_src_pair_vertices.size(); - auto renumber_map = std::move(unique_label_src_pair_vertices); - renumber_map.resize( - renumber_map.size() + thrust::distance(output_vertices.begin(), output_last), - handle.get_stream()); - thrust::copy(handle.get_thrust_policy(), - output_vertices.begin(), - output_last, - renumber_map.begin() + num_unique_srcs); - - return std::make_tuple(std::move(renumber_map), std::nullopt); - } - } -} - -} // namespace - -template -std::tuple, - rmm::device_uvector, - rmm::device_uvector, - std::optional>> -renumber_sampled_edgelist( - raft::handle_t const& handle, - rmm::device_uvector&& edgelist_srcs, - rmm::device_uvector&& edgelist_dsts, - std::optional> edgelist_hops, - std::optional, raft::device_span>> - label_offsets, - bool do_expensive_check) -{ - using label_index_t = uint32_t; - - // 1. 
check input arguments - - CUGRAPH_EXPECTS(!label_offsets || (std::get<0>(*label_offsets).size() <= - std::numeric_limits::max()), - "Invalid input arguments: current implementation assumes that the number of " - "unique labels is no larger than std::numeric_limits::max()."); - - CUGRAPH_EXPECTS( - edgelist_srcs.size() == edgelist_dsts.size(), - "Invalid input arguments: edgelist_srcs.size() and edgelist_dsts.size() should coincide."); - CUGRAPH_EXPECTS(!edgelist_hops.has_value() || (edgelist_srcs.size() == (*edgelist_hops).size()), - "Invalid input arguments: if edgelist_hops is valid, (*edgelist_hops).size() and " - "edgelist_srcs.size() should coincide."); - CUGRAPH_EXPECTS(!label_offsets.has_value() || - (std::get<1>(*label_offsets).size() == std::get<0>(*label_offsets).size() + 1), - "Invalid input arguments: if label_offsets is valid, " - "std::get<1>(label_offsets).size() (size of the offset array) should be " - "std::get<0>(label_offsets).size() (number of unique labels) + 1."); - - if (do_expensive_check) { - if (label_offsets) { - CUGRAPH_EXPECTS(thrust::is_sorted(handle.get_thrust_policy(), - std::get<1>(*label_offsets).begin(), - std::get<1>(*label_offsets).end()), - "Invalid input arguments: if label_offsets is valid, " - "std::get<1>(*label_offsets) should be sorted."); - size_t back_element{}; - raft::update_host( - &back_element, - std::get<1>(*label_offsets).data() + (std::get<1>(*label_offsets).size() - 1), - size_t{1}, - handle.get_stream()); - handle.get_stream(); - CUGRAPH_EXPECTS(back_element == edgelist_srcs.size(), - "Invalid input arguments: if label_offsets is valid, the last element of " - "std::get<1>(*label_offsets) and edgelist_srcs.size() should coincide."); - } - } - - // 2. 
compute renumber_map - - auto [renumber_map, renumber_map_label_indices] = compute_renumber_map( - handle, - raft::device_span(edgelist_srcs.data(), edgelist_srcs.size()), - raft::device_span(edgelist_dsts.data(), edgelist_dsts.size()), - edgelist_hops, - label_offsets ? std::make_optional>(std::get<1>(*label_offsets)) - : std::nullopt); - - // 3. compute renumber map offsets for each label - - std::optional> renumber_map_label_offsets{}; - if (label_offsets) { - auto num_unique_labels = thrust::count_if( - handle.get_thrust_policy(), - thrust::make_counting_iterator(size_t{0}), - thrust::make_counting_iterator((*renumber_map_label_indices).size()), - detail::is_first_in_run_t{(*renumber_map_label_indices).data()}); - rmm::device_uvector unique_label_indices(num_unique_labels, handle.get_stream()); - rmm::device_uvector vertex_counts(num_unique_labels, handle.get_stream()); - thrust::reduce_by_key(handle.get_thrust_policy(), - (*renumber_map_label_indices).begin(), - (*renumber_map_label_indices).end(), - thrust::make_constant_iterator(size_t{1}), - unique_label_indices.begin(), - vertex_counts.begin()); - - renumber_map_label_offsets = - rmm::device_uvector(std::get<0>(*label_offsets).size() + 1, handle.get_stream()); - thrust::fill(handle.get_thrust_policy(), - (*renumber_map_label_offsets).begin(), - (*renumber_map_label_offsets).end(), - size_t{0}); - thrust::scatter(handle.get_thrust_policy(), - vertex_counts.begin(), - vertex_counts.end(), - unique_label_indices.begin(), - (*renumber_map_label_offsets).begin() + 1); - - thrust::inclusive_scan(handle.get_thrust_policy(), - (*renumber_map_label_offsets).begin(), - (*renumber_map_label_offsets).end(), - (*renumber_map_label_offsets).begin()); - } - - // 4. 
renumber input edges - - if (label_offsets) { - rmm::device_uvector new_vertices(renumber_map.size(), handle.get_stream()); - thrust::tabulate(handle.get_thrust_policy(), - new_vertices.begin(), - new_vertices.end(), - [label_indices = raft::device_span( - (*renumber_map_label_indices).data(), (*renumber_map_label_indices).size()), - renumber_map_label_offsets = raft::device_span( - (*renumber_map_label_offsets).data(), - (*renumber_map_label_offsets).size())] __device__(size_t i) { - auto label_index = label_indices[i]; - auto label_start_offset = renumber_map_label_offsets[label_index]; - return static_cast(i - label_start_offset); - }); - - (*renumber_map_label_indices).resize(0, handle.get_stream()); - (*renumber_map_label_indices).shrink_to_fit(handle.get_stream()); - - auto num_labels = std::get<0>(*label_offsets).size(); - - rmm::device_uvector segment_sorted_renumber_map(renumber_map.size(), - handle.get_stream()); - rmm::device_uvector segment_sorted_new_vertices(new_vertices.size(), - handle.get_stream()); - - rmm::device_uvector d_tmp_storage(0, handle.get_stream()); - - auto approx_edges_to_sort_per_iteration = - static_cast(handle.get_device_properties().multiProcessorCount) * - (1 << 20) /* tuning parameter */; // for segmented sort - - auto [h_label_offsets, h_edge_offsets] = detail::compute_offset_aligned_element_chunks( - handle, - raft::device_span{(*renumber_map_label_offsets).data(), - (*renumber_map_label_offsets).size()}, - renumber_map.size(), - approx_edges_to_sort_per_iteration); - auto num_chunks = h_label_offsets.size() - 1; - - for (size_t i = 0; i < num_chunks; ++i) { - size_t tmp_storage_bytes{0}; - - auto offset_first = - thrust::make_transform_iterator((*renumber_map_label_offsets).data() + h_label_offsets[i], - detail::shift_left_t{h_edge_offsets[i]}); - cub::DeviceSegmentedSort::SortPairs(static_cast(nullptr), - tmp_storage_bytes, - renumber_map.begin() + h_edge_offsets[i], - segment_sorted_renumber_map.begin() + h_edge_offsets[i], 
- new_vertices.begin() + h_edge_offsets[i], - segment_sorted_new_vertices.begin() + h_edge_offsets[i], - h_edge_offsets[i + 1] - h_edge_offsets[i], - h_label_offsets[i + 1] - h_label_offsets[i], - offset_first, - offset_first + 1, - handle.get_stream()); - - if (tmp_storage_bytes > d_tmp_storage.size()) { - d_tmp_storage = rmm::device_uvector(tmp_storage_bytes, handle.get_stream()); - } - - cub::DeviceSegmentedSort::SortPairs(d_tmp_storage.data(), - tmp_storage_bytes, - renumber_map.begin() + h_edge_offsets[i], - segment_sorted_renumber_map.begin() + h_edge_offsets[i], - new_vertices.begin() + h_edge_offsets[i], - segment_sorted_new_vertices.begin() + h_edge_offsets[i], - h_edge_offsets[i + 1] - h_edge_offsets[i], - h_label_offsets[i + 1] - h_label_offsets[i], - offset_first, - offset_first + 1, - handle.get_stream()); - } - new_vertices.resize(0, handle.get_stream()); - d_tmp_storage.resize(0, handle.get_stream()); - new_vertices.shrink_to_fit(handle.get_stream()); - d_tmp_storage.shrink_to_fit(handle.get_stream()); - - auto edgelist_label_indices = detail::expand_sparse_offsets( - std::get<1>(*label_offsets), label_index_t{0}, handle.get_stream()); - - auto pair_first = - thrust::make_zip_iterator(edgelist_srcs.begin(), edgelist_label_indices.begin()); - thrust::transform( - handle.get_thrust_policy(), - pair_first, - pair_first + edgelist_srcs.size(), - edgelist_srcs.begin(), - [renumber_map_label_offsets = raft::device_span( - (*renumber_map_label_offsets).data(), (*renumber_map_label_offsets).size()), - old_vertices = raft::device_span(segment_sorted_renumber_map.data(), - segment_sorted_renumber_map.size()), - new_vertices = raft::device_span( - segment_sorted_new_vertices.data(), - segment_sorted_new_vertices.size())] __device__(auto pair) { - auto old_vertex = thrust::get<0>(pair); - auto label_index = thrust::get<1>(pair); - auto label_start_offset = renumber_map_label_offsets[label_index]; - auto label_end_offset = renumber_map_label_offsets[label_index + 
1]; - auto it = thrust::lower_bound(thrust::seq, - old_vertices.begin() + label_start_offset, - old_vertices.begin() + label_end_offset, - old_vertex); - assert(*it == old_vertex); - return *(new_vertices.begin() + thrust::distance(old_vertices.begin(), it)); - }); - - pair_first = thrust::make_zip_iterator(edgelist_dsts.begin(), edgelist_label_indices.begin()); - thrust::transform( - handle.get_thrust_policy(), - pair_first, - pair_first + edgelist_dsts.size(), - edgelist_dsts.begin(), - [renumber_map_label_offsets = raft::device_span( - (*renumber_map_label_offsets).data(), (*renumber_map_label_offsets).size()), - old_vertices = raft::device_span(segment_sorted_renumber_map.data(), - segment_sorted_renumber_map.size()), - new_vertices = raft::device_span( - segment_sorted_new_vertices.data(), - segment_sorted_new_vertices.size())] __device__(auto pair) { - auto old_vertex = thrust::get<0>(pair); - auto label_index = thrust::get<1>(pair); - auto label_start_offset = renumber_map_label_offsets[label_index]; - auto label_end_offset = renumber_map_label_offsets[label_index + 1]; - auto it = thrust::lower_bound(thrust::seq, - old_vertices.begin() + label_start_offset, - old_vertices.begin() + label_end_offset, - old_vertex); - assert(*it == old_vertex); - return new_vertices[thrust::distance(old_vertices.begin(), it)]; - }); - } else { - kv_store_t kv_store(renumber_map.begin(), - renumber_map.end(), - thrust::make_counting_iterator(vertex_t{0}), - std::numeric_limits::max(), - std::numeric_limits::max(), - handle.get_stream()); - auto kv_store_view = kv_store.view(); - - kv_store_view.find( - edgelist_srcs.begin(), edgelist_srcs.end(), edgelist_srcs.begin(), handle.get_stream()); - kv_store_view.find( - edgelist_dsts.begin(), edgelist_dsts.end(), edgelist_dsts.begin(), handle.get_stream()); - } - - return std::make_tuple(std::move(edgelist_srcs), - std::move(edgelist_dsts), - std::move(renumber_map), - std::move(renumber_map_label_offsets)); -} - -} // namespace 
cugraph diff --git a/cpp/src/sampling/renumber_sampled_edgelist_sg_v32_e32.cu b/cpp/src/sampling/renumber_sampled_edgelist_sg_v32_e32.cu deleted file mode 100644 index dee28c593ad..00000000000 --- a/cpp/src/sampling/renumber_sampled_edgelist_sg_v32_e32.cu +++ /dev/null @@ -1,37 +0,0 @@ -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "renumber_sampled_edgelist_impl.cuh" - -#include - -// FIXME: deprecated, to be deleted -namespace cugraph { - -template std::tuple, - rmm::device_uvector, - rmm::device_uvector, - std::optional>> -renumber_sampled_edgelist( - raft::handle_t const& handle, - rmm::device_uvector&& edgelist_srcs, - rmm::device_uvector&& edgelist_dsts, - std::optional> edgelist_hops, - std::optional, raft::device_span>> - label_offsets, - bool do_expensive_check); - -} // namespace cugraph diff --git a/cpp/src/sampling/renumber_sampled_edgelist_sg_v64_e64.cu b/cpp/src/sampling/renumber_sampled_edgelist_sg_v64_e64.cu deleted file mode 100644 index 99293c68f0c..00000000000 --- a/cpp/src/sampling/renumber_sampled_edgelist_sg_v64_e64.cu +++ /dev/null @@ -1,37 +0,0 @@ -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "renumber_sampled_edgelist_impl.cuh" - -#include - -// FIXME: deprecated, to be deleted -namespace cugraph { - -template std::tuple, - rmm::device_uvector, - rmm::device_uvector, - std::optional>> -renumber_sampled_edgelist( - raft::handle_t const& handle, - rmm::device_uvector&& edgelist_srcs, - rmm::device_uvector&& edgelist_dsts, - std::optional> edgelist_hops, - std::optional, raft::device_span>> - label_offsets, - bool do_expensive_check); - -} // namespace cugraph diff --git a/cpp/src/sampling/sampling_post_processing_impl.cuh b/cpp/src/sampling/sampling_post_processing_impl.cuh index b0b3bb5f4f2..4624e6d4a5e 100644 --- a/cpp/src/sampling/sampling_post_processing_impl.cuh +++ b/cpp/src/sampling/sampling_post_processing_impl.cuh @@ -49,9 +49,10 @@ namespace cugraph { namespace { -template +template struct edge_order_t { thrust::optional> edgelist_label_offsets{thrust::nullopt}; + thrust::optional> edgelist_edge_types{thrust::nullopt}; thrust::optional> edgelist_hops{thrust::nullopt}; raft::device_span edgelist_majors{}; raft::device_span edgelist_minors{}; @@ -72,6 +73,12 @@ struct edge_order_t { if (l_label != r_label) { return l_label < r_label; } } + if (edgelist_edge_types) { + auto l_type = (*edgelist_edge_types)[l_idx]; + auto r_type = (*edgelist_edge_types)[r_idx]; + if (l_type != r_type) { return l_type < r_type; } + } + if (edgelist_hops) { auto l_hop = (*edgelist_hops)[l_idx]; auto r_hop = (*edgelist_hops)[r_idx]; @@ -151,6 +158,7 @@ struct optionally_compute_label_index_t { template @@ -164,8 +172,11 @@ void 
check_input_edges(raft::handle_t const& handle, std::optional> seed_vertices, std::optional> seed_vertex_label_offsets, std::optional> edgelist_label_offsets, + std::optional> vertex_type_offsets, size_t num_labels, size_t num_hops, + size_t num_vertex_types, + std::optional num_edge_types, bool do_expensive_check) { CUGRAPH_EXPECTS( @@ -193,6 +204,7 @@ void check_input_edges(raft::handle_t const& handle, "(size of the offset array) should be num_labels + 1."); if (edgelist_majors.size() > 0) { + static_assert(std::is_same_v); CUGRAPH_EXPECTS((num_labels >= 1) && (num_labels <= std::numeric_limits::max()), "Invalid input arguments: num_labels should be a positive integer and the " "current implementation assumes that the number of unique labels is no larger " @@ -209,13 +221,16 @@ void check_input_edges(raft::handle_t const& handle, CUGRAPH_EXPECTS( (num_hops == 1) || edgelist_hops.has_value(), "Invalid input arguments: edgelist_hops.has_value() should be true if num_hops >= 2."); - } else { - CUGRAPH_EXPECTS( - "num_labels == 0", - "Invalid input arguments: num_labels should be 0 if the input edge list is empty."); + + static_assert(std::is_same_v); CUGRAPH_EXPECTS( - "num_hops == 0", - "Invalid input arguments: num_hops should be 0 if the input edge list is empty."); + (num_vertex_types >= 1) && (num_vertex_types <= std::numeric_limits::max()), + "Invalid input arguments: num_vertex_types should be a positive integer and the " + "current implementation assumes that the number of vertex types is no larger " + "than std::numeric_limits::max()."); + CUGRAPH_EXPECTS((num_vertex_types == 1) || vertex_type_offsets.has_value(), + "Invalid input arguments: vertex_type_offsets.has_value() should be true if " + "num_vertex_types >= 2."); } CUGRAPH_EXPECTS((!seed_vertices.has_value() && !seed_vertex_label_offsets.has_value()) || @@ -257,6 +272,174 @@ void check_input_edges(raft::handle_t const& handle, "*edgelist_label_offsets and edgelist_(srcs|dsts).size() should 
coincide."); } + if (edgelist_edge_types && num_edge_types) { + CUGRAPH_EXPECTS( + thrust::count_if(handle.get_thrust_policy(), + (*edgelist_edge_types).begin(), + (*edgelist_edge_types).end(), + [num_edge_types = static_cast(*num_edge_types)] __device__( + edge_type_t edge_type) { return edge_type >= num_edge_types; }) == 0, + "Invalid input arguments: edgelist_edge_type is valid but contains out-of-range edge type " + "values."); + if constexpr (std::is_signed_v) { + CUGRAPH_EXPECTS(thrust::count_if(handle.get_thrust_policy(), + (*edgelist_edge_types).begin(), + (*edgelist_edge_types).end(), + [] __device__(edge_type_t edge_type) { + return edge_type < edge_type_t{0}; + }) == 0, + "Invalid input arguments: edgelist_edge_type is valid but contains " + "negative edge type values."); + } + } + + if (vertex_type_offsets) { + CUGRAPH_EXPECTS( + thrust::is_sorted( + handle.get_thrust_policy(), (*vertex_type_offsets).begin(), (*vertex_type_offsets).end()), + "Invalid input arguments: if vertex_type_offsets is valid, " + "*vertex_type_offsets should be sorted."); + vertex_t front_element{}; + raft::update_host( + &front_element, (*vertex_type_offsets).data(), size_t{1}, handle.get_stream()); + vertex_t back_element{}; + raft::update_host(&back_element, + (*vertex_type_offsets).data() + num_vertex_types, + size_t{1}, + handle.get_stream()); + handle.sync_stream(); + CUGRAPH_EXPECTS( + front_element == vertex_t{0}, + "Invalid input arguments: if vertex_type_offsets is valid, the first element of " + "*vertex_type_offsets should be 0."); + vertex_t max_v = std::max(thrust::reduce(handle.get_thrust_policy(), + edgelist_majors.begin(), + edgelist_majors.end(), + vertex_t{0}, + thrust::maximum{}), + thrust::reduce(handle.get_thrust_policy(), + edgelist_minors.begin(), + edgelist_minors.end(), + vertex_t{0}, + thrust::maximum{})); + CUGRAPH_EXPECTS( + back_element > max_v, + "Invalid input arguments: if vertex_type_offsets is valid, the last element of " + "*vertex_type_offsets 
should be larger than the maximum vertex ID in edgelist_majors & " + "edgelist_minors."); + + rmm::device_uvector tmp_majors(edgelist_majors.size(), handle.get_stream()); + rmm::device_uvector tmp_minors(edgelist_minors.size(), handle.get_stream()); + thrust::copy(handle.get_thrust_policy(), + edgelist_majors.begin(), + edgelist_majors.end(), + tmp_majors.begin()); + thrust::copy(handle.get_thrust_policy(), + edgelist_minors.begin(), + edgelist_minors.end(), + tmp_minors.begin()); + if (edgelist_edge_types) { + rmm::device_uvector tmp_edge_types((*edgelist_edge_types).size(), + handle.get_stream()); + thrust::copy(handle.get_thrust_policy(), + (*edgelist_edge_types).begin(), + (*edgelist_edge_types).end(), + tmp_edge_types.begin()); + auto triplet_first = + thrust::make_zip_iterator(tmp_edge_types.begin(), tmp_majors.begin(), tmp_minors.begin()); + thrust::sort(handle.get_thrust_policy(), triplet_first, triplet_first + tmp_majors.size()); + CUGRAPH_EXPECTS( + thrust::count_if( + handle.get_thrust_policy(), + thrust::make_counting_iterator(size_t{0}), + thrust::make_counting_iterator(tmp_majors.size()), + [vertex_type_offsets = *vertex_type_offsets, triplet_first] __device__(size_t i) { + if (i > 0) { + auto prev = *(triplet_first + i - 1); + auto cur = *(triplet_first + i); + if (thrust::get<0>(prev) == thrust::get<0>(cur)) { // same edge type + auto prev_major_v_type = + thrust::distance(vertex_type_offsets.begin() + 1, + thrust::upper_bound(thrust::seq, + vertex_type_offsets.begin() + 1, + vertex_type_offsets.end(), + thrust::get<1>(prev))); + auto cur_major_v_type = + thrust::distance(vertex_type_offsets.begin() + 1, + thrust::upper_bound(thrust::seq, + vertex_type_offsets.begin() + 1, + vertex_type_offsets.end(), + thrust::get<1>(cur))); + if (prev_major_v_type != cur_major_v_type) { return true; } + auto prev_minor_v_type = + thrust::distance(vertex_type_offsets.begin() + 1, + thrust::upper_bound(thrust::seq, + vertex_type_offsets.begin() + 1, + 
vertex_type_offsets.end(), + thrust::get<2>(prev))); + auto cur_minor_v_type = + thrust::distance(vertex_type_offsets.begin() + 1, + thrust::upper_bound(thrust::seq, + vertex_type_offsets.begin() + 1, + vertex_type_offsets.end(), + thrust::get<2>(cur))); + if (prev_minor_v_type != cur_minor_v_type) { return true; } + } + } + return false; + }) == 0, + "Invalid input arguments: if vertex_type_offsets and edgelist_edge_types are valid, the " + "entire set of input edge source vertices for each edge type should have an identical " + "vertex type, and the entire set of input edge destination vertices for each type should " + "have an identical vertex type."); + } else { + auto pair_first = thrust::make_zip_iterator(tmp_majors.begin(), tmp_minors.begin()); + thrust::sort(handle.get_thrust_policy(), pair_first, pair_first + tmp_majors.size()); + CUGRAPH_EXPECTS( + thrust::count_if( + handle.get_thrust_policy(), + thrust::make_counting_iterator(size_t{0}), + thrust::make_counting_iterator(tmp_majors.size()), + [vertex_type_offsets = *vertex_type_offsets, pair_first] __device__(size_t i) { + if (i > 0) { + auto prev = *(pair_first + i - 1); + auto cur = *(pair_first + i); + auto prev_src_v_type = + thrust::distance(vertex_type_offsets.begin() + 1, + thrust::upper_bound(thrust::seq, + vertex_type_offsets.begin() + 1, + vertex_type_offsets.end(), + thrust::get<0>(prev))); + auto cur_src_v_type = + thrust::distance(vertex_type_offsets.begin() + 1, + thrust::upper_bound(thrust::seq, + vertex_type_offsets.begin() + 1, + vertex_type_offsets.end(), + thrust::get<0>(cur))); + if (prev_src_v_type != cur_src_v_type) { return true; } + auto prev_dst_v_type = + thrust::distance(vertex_type_offsets.begin() + 1, + thrust::upper_bound(thrust::seq, + vertex_type_offsets.begin() + 1, + vertex_type_offsets.end(), + thrust::get<1>(prev))); + auto cur_dst_v_type = + thrust::distance(vertex_type_offsets.begin() + 1, + thrust::upper_bound(thrust::seq, + vertex_type_offsets.begin() + 1, + 
vertex_type_offsets.end(), + thrust::get<1>(cur))); + if (prev_dst_v_type != cur_dst_v_type) { return true; } + } + return false; + }) == 0, + "Invalid input arguments: if vertex_type_offsets is valid (but " + "edgelist_edge_types is invalid), the entire set of input edge source " + "vertices should have an identical vertex type, and the entire set of " + "input edge destination vertices should have an identical vertex type."); + } + } + if (seed_vertices) { for (size_t i = 0; i < num_labels; ++i) { rmm::device_uvector this_label_seed_vertices(0, handle.get_stream()); @@ -356,7 +539,7 @@ compute_min_hop_for_unique_label_vertex_pairs( std::optional> seed_vertex_label_offsets, std::optional> edgelist_label_offsets) { - auto approx_edges_to_sort_per_iteration = + auto approx_items_to_sort_per_iteration = static_cast(handle.get_device_properties().multiProcessorCount) * (1 << 18) /* tuning parameter */; // for segmented sort @@ -369,7 +552,7 @@ compute_min_hop_for_unique_label_vertex_pairs( detail::compute_offset_aligned_element_chunks(handle, *edgelist_label_offsets, edgelist_vertices.size(), - approx_edges_to_sort_per_iteration); + approx_items_to_sort_per_iteration); auto num_chunks = h_label_offsets.size() - 1; if (edgelist_hops) { @@ -406,28 +589,28 @@ compute_min_hop_for_unique_label_vertex_pairs( } tmp_indices.resize( - thrust::distance( - tmp_indices.begin(), - thrust::unique(handle.get_thrust_policy(), - tmp_indices.begin(), - tmp_indices.end(), - [edgelist_label_offsets = *edgelist_label_offsets, - edgelist_vertices, - edgelist_hops = *edgelist_hops] __device__(size_t l_idx, size_t r_idx) { - auto l_it = thrust::upper_bound(thrust::seq, - edgelist_label_offsets.begin() + 1, - edgelist_label_offsets.end(), - l_idx); - auto r_it = thrust::upper_bound(thrust::seq, - edgelist_label_offsets.begin() + 1, - edgelist_label_offsets.end(), - r_idx); - if (l_it != r_it) { return false; } - - auto l_vertex = edgelist_vertices[l_idx]; - auto r_vertex = 
edgelist_vertices[r_idx]; - return l_vertex == r_vertex; - })), + thrust::distance(tmp_indices.begin(), + thrust::unique(handle.get_thrust_policy(), + tmp_indices.begin(), + tmp_indices.end(), + [edgelist_label_offsets = *edgelist_label_offsets, + edgelist_vertices] __device__(size_t l_idx, size_t r_idx) { + auto l_it = + thrust::upper_bound(thrust::seq, + edgelist_label_offsets.begin() + 1, + edgelist_label_offsets.end(), + l_idx); + auto r_it = + thrust::upper_bound(thrust::seq, + edgelist_label_offsets.begin() + 1, + edgelist_label_offsets.end(), + r_idx); + if (l_it != r_it) { return false; } + + auto l_vertex = edgelist_vertices[l_idx]; + auto r_vertex = edgelist_vertices[r_idx]; + return l_vertex == r_vertex; + })), handle.get_stream()); tmp_label_indices.resize(tmp_indices.size(), handle.get_stream()); @@ -859,17 +1042,23 @@ compute_min_hop_for_unique_label_vertex_pairs( } } -template -std::tuple, std::optional>> -compute_renumber_map(raft::handle_t const& handle, - raft::device_span edgelist_majors, - raft::device_span edgelist_minors, - std::optional> edgelist_hops, - std::optional> seed_vertices, - std::optional> seed_vertex_label_offsets, - std::optional> edgelist_label_offsets) +// returns renumber map & optional (label, type) offsets +// indices are non-descedning) +template +std::tuple, std::optional>> +compute_vertex_renumber_map( + raft::handle_t const& handle, + raft::device_span edgelist_majors, + raft::device_span edgelist_minors, + std::optional> edgelist_hops, + std::optional> seed_vertices, + std::optional> seed_vertex_label_offsets, + std::optional> edgelist_label_offsets, + std::optional> vertex_type_offsets, + size_t num_labels, + size_t num_vertex_types) { - auto approx_edges_to_sort_per_iteration = + auto approx_items_to_sort_per_iteration = static_cast(handle.get_device_properties().multiProcessorCount) * (1 << 20) /* tuning parameter */; // for segmented sort @@ -892,10 +1081,9 @@ compute_renumber_map(raft::handle_t const& handle, 
compute_min_hop_for_unique_label_vertex_pairs( handle, edgelist_minors, edgelist_hops, std::nullopt, std::nullopt, edgelist_label_offsets); + rmm::device_uvector renumber_map(0, handle.get_stream()); + std::optional> renumber_map_label_type_offsets{std::nullopt}; if (edgelist_label_offsets) { - auto num_labels = (*edgelist_label_offsets).size() - 1; - - rmm::device_uvector renumber_map(0, handle.get_stream()); rmm::device_uvector renumber_map_label_indices(0, handle.get_stream()); renumber_map.reserve((*unique_label_major_pair_label_indices).size() + @@ -903,8 +1091,8 @@ compute_renumber_map(raft::handle_t const& handle, handle.get_stream()); renumber_map_label_indices.reserve(renumber_map.capacity(), handle.get_stream()); - auto num_chunks = (edgelist_majors.size() + (approx_edges_to_sort_per_iteration - 1)) / - approx_edges_to_sort_per_iteration; + auto num_chunks = (edgelist_majors.size() + (approx_items_to_sort_per_iteration - 1)) / + approx_items_to_sort_per_iteration; auto chunk_size = (num_chunks > 0) ? 
((num_labels + (num_chunks - 1)) / num_chunks) : 0; size_t copy_offset{0}; @@ -963,12 +1151,37 @@ compute_renumber_map(raft::handle_t const& handle, merged_vertices.resize(merged_label_indices.size(), handle.get_stream()); merged_hops.resize(merged_label_indices.size(), handle.get_stream()); merged_flags.resize(merged_label_indices.size(), handle.get_stream()); - auto sort_key_first = thrust::make_zip_iterator( - merged_label_indices.begin(), merged_hops.begin(), merged_flags.begin()); - thrust::sort_by_key(handle.get_thrust_policy(), - sort_key_first, - sort_key_first + merged_label_indices.size(), - merged_vertices.begin()); + if (vertex_type_offsets) { + auto quadraplet_first = thrust::make_zip_iterator(merged_label_indices.begin(), + merged_vertices.begin(), + merged_hops.begin(), + merged_flags.begin()); + thrust::sort( + handle.get_thrust_policy(), + quadraplet_first, + quadraplet_first + merged_vertices.size(), + [offsets = *vertex_type_offsets] __device__(auto lhs, auto rhs) { + auto lhs_v_type = thrust::distance( + offsets.begin() + 1, + thrust::upper_bound( + thrust::seq, offsets.begin() + 1, offsets.end(), thrust::get<1>(lhs))); + auto rhs_v_type = thrust::distance( + offsets.begin() + 1, + thrust::upper_bound( + thrust::seq, offsets.begin() + 1, offsets.end(), thrust::get<1>(rhs))); + return thrust::make_tuple( + thrust::get<0>(lhs), lhs_v_type, thrust::get<2>(lhs), thrust::get<3>(lhs)) < + thrust::make_tuple( + thrust::get<0>(rhs), rhs_v_type, thrust::get<2>(rhs), thrust::get<3>(rhs)); + }); + } else { + auto sort_key_first = thrust::make_zip_iterator( + merged_label_indices.begin(), merged_hops.begin(), merged_flags.begin()); + thrust::sort_by_key(handle.get_thrust_policy(), + sort_key_first, + sort_key_first + merged_label_indices.size(), + merged_vertices.begin()); + } } else { auto major_triplet_first = thrust::make_zip_iterator((*unique_label_major_pair_label_indices).begin(), @@ -999,12 +1212,33 @@ compute_renumber_map(raft::handle_t const& 
handle, handle.get_stream()); merged_vertices.resize(merged_label_indices.size(), handle.get_stream()); merged_flags.resize(merged_label_indices.size(), handle.get_stream()); - auto sort_key_first = - thrust::make_zip_iterator(merged_label_indices.begin(), merged_flags.begin()); - thrust::sort_by_key(handle.get_thrust_policy(), - sort_key_first, - sort_key_first + merged_label_indices.size(), - merged_vertices.begin()); + if (vertex_type_offsets) { + auto triplet_first = thrust::make_zip_iterator( + merged_label_indices.begin(), merged_vertices.begin(), merged_flags.begin()); + thrust::sort( + handle.get_thrust_policy(), + triplet_first, + triplet_first + merged_vertices.size(), + [offsets = *vertex_type_offsets] __device__(auto lhs, auto rhs) { + auto lhs_v_type = thrust::distance( + offsets.begin() + 1, + thrust::upper_bound( + thrust::seq, offsets.begin() + 1, offsets.end(), thrust::get<1>(lhs))); + auto rhs_v_type = thrust::distance( + offsets.begin() + 1, + thrust::upper_bound( + thrust::seq, offsets.begin() + 1, offsets.end(), thrust::get<1>(rhs))); + return thrust::make_tuple(thrust::get<0>(lhs), lhs_v_type, thrust::get<2>(lhs)) < + thrust::make_tuple(thrust::get<0>(rhs), rhs_v_type, thrust::get<2>(rhs)); + }); + } else { + auto sort_key_first = + thrust::make_zip_iterator(merged_label_indices.begin(), merged_flags.begin()); + thrust::sort_by_key(handle.get_thrust_policy(), + sort_key_first, + sort_key_first + merged_label_indices.size(), + merged_vertices.begin()); + } } renumber_map.resize(copy_offset + merged_vertices.size(), handle.get_stream()); @@ -1025,7 +1259,41 @@ compute_renumber_map(raft::handle_t const& handle, renumber_map.shrink_to_fit(handle.get_stream()); renumber_map_label_indices.shrink_to_fit(handle.get_stream()); - return std::make_tuple(std::move(renumber_map), std::move(renumber_map_label_indices)); + renumber_map_label_type_offsets = + rmm::device_uvector(num_labels * num_vertex_types + 1, handle.get_stream()); + 
(*renumber_map_label_type_offsets).set_element_to_zero_async(0, handle.get_stream()); + if (vertex_type_offsets) { + auto label_type_pair_first = thrust::make_zip_iterator( + renumber_map_label_indices.begin(), + thrust::make_transform_iterator( + renumber_map.begin(), + cuda::proclaim_return_type( + [offsets = *vertex_type_offsets] __device__(auto v) { + return static_cast(thrust::distance( + offsets.begin() + 1, + thrust::upper_bound(thrust::seq, offsets.begin() + 1, offsets.end(), v))); + }))); + auto value_first = thrust::make_transform_iterator( + thrust::make_counting_iterator(size_t{0}), + cuda::proclaim_return_type>( + [num_vertex_types] __device__(size_t i) { + return thrust::make_tuple(static_cast(i / num_vertex_types), + static_cast(i % num_vertex_types)); + })); + thrust::upper_bound(handle.get_thrust_policy(), + label_type_pair_first, + label_type_pair_first + renumber_map.size(), + value_first, + value_first + (num_labels * num_vertex_types), + (*renumber_map_label_type_offsets).begin() + 1); + } else { + thrust::upper_bound(handle.get_thrust_policy(), + renumber_map_label_indices.begin(), + renumber_map_label_indices.end(), + thrust::make_counting_iterator(label_index_t{0}), + thrust::make_counting_iterator(static_cast(num_labels)), + (*renumber_map_label_type_offsets).begin() + 1); + } } else { if (edgelist_hops) { rmm::device_uvector merged_vertices( @@ -1067,13 +1335,34 @@ compute_renumber_map(raft::handle_t const& handle, merged_hops.resize(merged_vertices.size(), handle.get_stream()); merged_flags.resize(merged_vertices.size(), handle.get_stream()); - auto sort_key_first = thrust::make_zip_iterator(merged_hops.begin(), merged_flags.begin()); - thrust::sort_by_key(handle.get_thrust_policy(), - sort_key_first, - sort_key_first + merged_hops.size(), - merged_vertices.begin()); + if (vertex_type_offsets) { + auto triplet_first = thrust::make_zip_iterator( + merged_vertices.begin(), merged_hops.begin(), merged_flags.begin()); + thrust::sort( + 
handle.get_thrust_policy(), + triplet_first, + triplet_first + merged_vertices.size(), + [offsets = *vertex_type_offsets] __device__(auto lhs, auto rhs) { + auto lhs_v_type = thrust::distance( + offsets.begin() + 1, + thrust::upper_bound( + thrust::seq, offsets.begin() + 1, offsets.end(), thrust::get<0>(lhs))); + auto rhs_v_type = thrust::distance( + offsets.begin() + 1, + thrust::upper_bound( + thrust::seq, offsets.begin() + 1, offsets.end(), thrust::get<0>(rhs))); + return thrust::make_tuple(lhs_v_type, thrust::get<1>(lhs), thrust::get<2>(lhs)) < + thrust::make_tuple(rhs_v_type, thrust::get<1>(rhs), thrust::get<2>(rhs)); + }); + } else { + auto sort_key_first = thrust::make_zip_iterator(merged_hops.begin(), merged_flags.begin()); + thrust::sort_by_key(handle.get_thrust_policy(), + sort_key_first, + sort_key_first + merged_hops.size(), + merged_vertices.begin()); + } - return std::make_tuple(std::move(merged_vertices), std::nullopt); + renumber_map = std::move(merged_vertices); } else { rmm::device_uvector output_vertices(unique_label_minor_pair_vertices.size(), handle.get_stream()); @@ -1085,7 +1374,7 @@ compute_renumber_map(raft::handle_t const& handle, output_vertices.begin()); auto num_unique_majors = unique_label_major_pair_vertices.size(); - auto renumber_map = std::move(unique_label_major_pair_vertices); + renumber_map = std::move(unique_label_major_pair_vertices); renumber_map.resize( renumber_map.size() + thrust::distance(output_vertices.begin(), output_last), handle.get_stream()); @@ -1094,9 +1383,370 @@ compute_renumber_map(raft::handle_t const& handle, output_last, renumber_map.begin() + num_unique_majors); - return std::make_tuple(std::move(renumber_map), std::nullopt); + if (vertex_type_offsets) { + thrust::stable_sort( + handle.get_thrust_policy(), + renumber_map.begin(), + renumber_map.end(), + [offsets = *vertex_type_offsets] __device__(auto lhs, auto rhs) { + auto lhs_v_type = thrust::distance( + offsets.begin() + 1, + thrust::upper_bound( + 
thrust::seq, offsets.begin() + 1, offsets.end(), thrust::get<0>(lhs))); + auto rhs_v_type = thrust::distance( + offsets.begin() + 1, + thrust::upper_bound( + thrust::seq, offsets.begin() + 1, offsets.end(), thrust::get<0>(rhs))); + return lhs_v_type < rhs_v_type; + }); + } + } + + if (vertex_type_offsets) { + renumber_map_label_type_offsets = + rmm::device_uvector(num_vertex_types + 1, handle.get_stream()); + (*renumber_map_label_type_offsets).set_element_to_zero_async(0, handle.get_stream()); + auto type_first = thrust::make_transform_iterator( + renumber_map.begin(), + cuda::proclaim_return_type( + [offsets = *vertex_type_offsets] __device__(auto v) { + return static_cast(thrust::distance( + offsets.begin() + 1, + thrust::upper_bound(thrust::seq, offsets.begin() + 1, offsets.end(), v))); + })); + thrust::upper_bound( + handle.get_thrust_policy(), + type_first, + type_first + renumber_map.size(), + thrust::make_counting_iterator(vertex_type_t{0}), + thrust::make_counting_iterator(static_cast(num_vertex_types)), + (*renumber_map_label_type_offsets).begin() + 1); + } + } + + return std::make_tuple(std::move(renumber_map), std::move(renumber_map_label_type_offsets)); +} + +// returns renumber map & optional (label, type) offsets +template +std::tuple, std::optional>> +compute_edge_id_renumber_map( + raft::handle_t const& handle, + raft::device_span edgelist_edge_ids, + std::optional> edgelist_edge_types, + std::optional> edgelist_hops, + std::optional> edgelist_label_offsets, + size_t num_labels, + size_t num_edge_types) +{ + rmm::device_uvector renumber_map(0, handle.get_stream()); + std::optional> renumber_map_label_type_offsets{std::nullopt}; + if (edgelist_label_offsets) { + auto approx_items_to_sort_per_iteration = + static_cast(handle.get_device_properties().multiProcessorCount) * + (1 << 20) /* tuning parameter */; // for segmented sort + + auto [h_label_offsets, h_edge_offsets] = + detail::compute_offset_aligned_element_chunks(handle, + 
*edgelist_label_offsets, + edgelist_edge_ids.size(), + approx_items_to_sort_per_iteration); + auto num_chunks = h_label_offsets.size() - 1; + + rmm::device_uvector tmp_indices(edgelist_edge_ids.size(), handle.get_stream()); + thrust::sequence(handle.get_thrust_policy(), tmp_indices.begin(), tmp_indices.end(), size_t{0}); + + // cub::DeviceSegmentedSort currently does not suuport thrust::tuple type keys, sorting in + // chunks still helps in limiting the binary search range and improving memory locality + for (size_t i = 0; i < num_chunks; ++i) { + // sort by (label, (type), id, (hop)) + + thrust::sort( + handle.get_thrust_policy(), + tmp_indices.begin() + h_edge_offsets[i], + tmp_indices.begin() + h_edge_offsets[i + 1], + [edgelist_label_offsets = + raft::device_span((*edgelist_label_offsets).data() + h_label_offsets[i], + (h_label_offsets[i + 1] - h_label_offsets[i]) + 1), + edgelist_edge_types = detail::to_thrust_optional(edgelist_edge_types), + edgelist_edge_ids, + edgelist_hops = detail::to_thrust_optional(edgelist_hops)] __device__(size_t l_idx, + size_t r_idx) { + auto l_it = thrust::upper_bound( + thrust::seq, edgelist_label_offsets.begin() + 1, edgelist_label_offsets.end(), l_idx); + auto r_it = thrust::upper_bound( + thrust::seq, edgelist_label_offsets.begin() + 1, edgelist_label_offsets.end(), r_idx); + if (l_it != r_it) { return l_it < r_it; } + + if (edgelist_edge_types) { + auto l_type = (*edgelist_edge_types)[l_idx]; + auto r_type = (*edgelist_edge_types)[r_idx]; + if (l_type != r_type) { return l_type < r_type; } + } + + auto l_id = edgelist_edge_ids[l_idx]; + auto r_id = edgelist_edge_ids[r_idx]; + if (l_id != r_id) { return l_id < r_id; } + + if (edgelist_hops) { + auto l_hop = (*edgelist_hops)[l_idx]; + auto r_hop = (*edgelist_hops)[r_idx]; + return l_hop < r_hop; + } + + return false; + }); + + // find unique (label, (type), id, (min_hop)) tuples + + auto last = thrust::unique( + handle.get_thrust_policy(), + tmp_indices.begin() + 
h_edge_offsets[i], + tmp_indices.begin() + h_edge_offsets[i + 1], + [edgelist_label_offsets = *edgelist_label_offsets, + edgelist_edge_types = detail::to_thrust_optional(edgelist_edge_types), + edgelist_edge_ids] __device__(size_t l_idx, size_t r_idx) { + auto l_it = thrust::upper_bound( + thrust::seq, edgelist_label_offsets.begin() + 1, edgelist_label_offsets.end(), l_idx); + auto r_it = thrust::upper_bound( + thrust::seq, edgelist_label_offsets.begin() + 1, edgelist_label_offsets.end(), r_idx); + if (l_it != r_it) { return false; } + + if (edgelist_edge_types) { + auto l_type = (*edgelist_edge_types)[l_idx]; + auto r_type = (*edgelist_edge_types)[r_idx]; + if (l_type != r_type) { return false; } + } + + auto l_id = edgelist_edge_ids[l_idx]; + auto r_id = edgelist_edge_ids[r_idx]; + return l_id == r_id; + }); + + // sort by (label, (type), (min_hop), id) + + if (edgelist_hops) { + thrust::sort( + handle.get_thrust_policy(), + tmp_indices.begin() + h_edge_offsets[i], + last, + [edgelist_label_offsets = + raft::device_span((*edgelist_label_offsets).data() + h_label_offsets[i], + (h_label_offsets[i + 1] - h_label_offsets[i]) + 1), + edgelist_edge_types = detail::to_thrust_optional(edgelist_edge_types), + edgelist_edge_ids, + edgelist_hops = detail::to_thrust_optional(edgelist_hops)] __device__(size_t l_idx, + size_t r_idx) { + auto l_it = thrust::upper_bound( + thrust::seq, edgelist_label_offsets.begin() + 1, edgelist_label_offsets.end(), l_idx); + auto r_it = thrust::upper_bound( + thrust::seq, edgelist_label_offsets.begin() + 1, edgelist_label_offsets.end(), r_idx); + if (l_it != r_it) { return l_it < r_it; } + + if (edgelist_edge_types) { + auto l_type = (*edgelist_edge_types)[l_idx]; + auto r_type = (*edgelist_edge_types)[r_idx]; + if (l_type != r_type) { return l_type < r_type; } + } + + if (edgelist_hops) { + auto l_hop = (*edgelist_hops)[l_idx]; + auto r_hop = (*edgelist_hops)[r_idx]; + return l_hop < r_hop; + } + + auto l_id = edgelist_edge_ids[l_idx]; + auto 
r_id = edgelist_edge_ids[r_idx]; + if (l_id != r_id) { return l_id < r_id; } + + return false; + }); + } + + // mark invalid indices + + thrust::fill(handle.get_thrust_policy(), + last, + tmp_indices.begin() + h_edge_offsets[i + 1], + std::numeric_limits::max()); + } + + tmp_indices.resize(thrust::distance(tmp_indices.begin(), + thrust::remove(handle.get_thrust_policy(), + tmp_indices.begin(), + tmp_indices.end(), + std::numeric_limits::max())), + handle.get_stream()); + + renumber_map = rmm::device_uvector(tmp_indices.size(), handle.get_stream()); + thrust::gather(handle.get_thrust_policy(), + tmp_indices.begin(), + tmp_indices.end(), + edgelist_edge_ids.begin(), + renumber_map.begin()); + + renumber_map_label_type_offsets = + rmm::device_uvector(num_labels * num_edge_types + 1, handle.get_stream()); + (*renumber_map_label_type_offsets).set_element_to_zero_async(0, handle.get_stream()); + if (edgelist_edge_types) { + auto label_type_pair_first = thrust::make_transform_iterator( + tmp_indices.begin(), + cuda::proclaim_return_type>( + [edgelist_label_offsets = *edgelist_label_offsets, + edgelist_edge_types = *edgelist_edge_types] __device__(size_t i) { + auto label_idx = thrust::distance( + edgelist_label_offsets.begin() + 1, + thrust::upper_bound( + thrust::seq, edgelist_label_offsets.begin() + 1, edgelist_label_offsets.end(), i)); + return thrust::make_tuple(static_cast(label_idx), + edgelist_edge_types[i]); + })); + auto value_first = thrust::make_transform_iterator( + thrust::make_counting_iterator(size_t{0}), + cuda::proclaim_return_type>( + [num_edge_types] __device__(size_t i) { + return thrust::make_tuple(static_cast(i / num_edge_types), + static_cast(i % num_edge_types)); + })); + thrust::upper_bound(handle.get_thrust_policy(), + label_type_pair_first, + label_type_pair_first + renumber_map.size(), + value_first, + value_first + (num_labels * num_edge_types), + (*renumber_map_label_type_offsets).begin() + 1); + } else { + auto label_first = 
thrust::make_transform_iterator( + tmp_indices.begin(), + cuda::proclaim_return_type( + [edgelist_label_offsets = *edgelist_label_offsets] __device__(size_t i) { + auto label_idx = thrust::distance( + edgelist_label_offsets.begin() + 1, + thrust::upper_bound( + thrust::seq, edgelist_label_offsets.begin() + 1, edgelist_label_offsets.end(), i)); + return static_cast(label_idx); + })); + auto value_first = thrust::make_counting_iterator(label_index_t{0}); + thrust::upper_bound(handle.get_thrust_policy(), + label_first, + label_first + renumber_map.size(), + value_first, + value_first + num_labels, + (*renumber_map_label_type_offsets).begin() + 1); + } + } else { + // copy + + std::optional> tmp_types{std::nullopt}; + if (edgelist_edge_types) { + tmp_types = + rmm::device_uvector((*edgelist_edge_types).size(), handle.get_stream()); + thrust::copy(handle.get_thrust_policy(), + (*edgelist_edge_types).begin(), + (*edgelist_edge_types).end(), + (*tmp_types).begin()); + } + rmm::device_uvector tmp_ids(edgelist_edge_ids.size(), handle.get_stream()); + thrust::copy(handle.get_thrust_policy(), + edgelist_edge_ids.begin(), + edgelist_edge_ids.end(), + tmp_ids.begin()); + std::optional> tmp_hops{std::nullopt}; + if (edgelist_hops) { + tmp_hops = rmm::device_uvector((*edgelist_hops).size(), handle.get_stream()); + thrust::copy(handle.get_thrust_policy(), + (*edgelist_hops).begin(), + (*edgelist_hops).end(), + (*tmp_hops).begin()); + } + + // sort by ((type), id, (hop)) + + if (tmp_types) { + if (tmp_hops) { + auto triplet_first = + thrust::make_zip_iterator((*tmp_types).begin(), tmp_ids.begin(), (*tmp_hops).begin()); + thrust::sort(handle.get_thrust_policy(), triplet_first, triplet_first + tmp_ids.size()); + } else { + auto pair_first = thrust::make_zip_iterator((*tmp_types).begin(), tmp_ids.begin()); + thrust::sort(handle.get_thrust_policy(), pair_first, pair_first + tmp_ids.size()); + } + } else { + if (tmp_hops) { + auto pair_first = thrust::make_zip_iterator(tmp_ids.begin(), 
(*tmp_hops).begin()); + thrust::sort(handle.get_thrust_policy(), pair_first, pair_first + tmp_ids.size()); + } else { + thrust::sort(handle.get_thrust_policy(), tmp_ids.begin(), tmp_ids.end()); + } + } + + // find unique ((type), id, (min_hop)) tuples + + if (tmp_types) { + auto pair_first = thrust::make_zip_iterator((*tmp_types).begin(), tmp_ids.begin()); + if (tmp_hops) { + tmp_ids.resize( + thrust::distance(pair_first, + thrust::get<0>(thrust::unique_by_key(handle.get_thrust_policy(), + pair_first, + pair_first + tmp_ids.size(), + (*tmp_hops).begin()))), + handle.get_stream()); + (*tmp_hops).resize(tmp_ids.size(), handle.get_stream()); + } else { + tmp_ids.resize( + thrust::distance( + pair_first, + thrust::unique(handle.get_thrust_policy(), pair_first, pair_first + tmp_ids.size())), + handle.get_stream()); + } + (*tmp_types).resize(tmp_ids.size(), handle.get_stream()); + } else { + if (tmp_hops) { + tmp_ids.resize( + thrust::distance( + tmp_ids.begin(), + thrust::get<0>(thrust::unique_by_key( + handle.get_thrust_policy(), tmp_ids.begin(), tmp_ids.end(), (*tmp_hops).begin()))), + handle.get_stream()); + (*tmp_hops).resize(tmp_ids.size(), handle.get_stream()); + } else { + tmp_ids.resize( + thrust::distance( + tmp_ids.begin(), + thrust::unique(handle.get_thrust_policy(), tmp_ids.begin(), tmp_ids.end())), + handle.get_stream()); + } + } + + // sort by ((type), (min_hop), id) + + if (tmp_hops) { + if (tmp_types) { + auto triplet_first = + thrust::make_zip_iterator((*tmp_types).begin(), (*tmp_hops).begin(), tmp_ids.begin()); + thrust::sort(handle.get_thrust_policy(), triplet_first, triplet_first + tmp_ids.size()); + } else { + auto pair_first = thrust::make_zip_iterator((*tmp_hops).begin(), tmp_ids.begin()); + thrust::sort(handle.get_thrust_policy(), pair_first, pair_first + tmp_ids.size()); + } + } + + renumber_map = std::move(tmp_ids); + + if (tmp_types) { + renumber_map_label_type_offsets = + rmm::device_uvector(num_edge_types + 1, handle.get_stream()); + 
(*renumber_map_label_type_offsets).set_element_to_zero_async(0, handle.get_stream()); + thrust::upper_bound(handle.get_thrust_policy(), + (*tmp_types).begin(), + (*tmp_types).end(), + thrust::make_counting_iterator(edge_type_t{0}), + thrust::make_counting_iterator(static_cast(num_edge_types)), + (*renumber_map_label_type_offsets).begin() + 1); } } + + return std::make_tuple(std::move(renumber_map), std::move(renumber_map_label_type_offsets)); } // this function does not reorder edges (the i'th returned edge is the renumbered output of the @@ -1117,74 +1767,45 @@ renumber_sampled_edgelist(raft::handle_t const& handle, size_t num_labels, bool do_expensive_check) { - // 1. compute renumber_map + using vertex_type_t = uint32_t; // dummy - auto [renumber_map, renumber_map_label_indices] = compute_renumber_map( - handle, - raft::device_span(edgelist_majors.data(), edgelist_majors.size()), - raft::device_span(edgelist_minors.data(), edgelist_minors.size()), - edgelist_hops, - seed_vertices ? std::make_optional>((*seed_vertices).data(), - (*seed_vertices).size()) - : std::nullopt, - seed_vertex_label_offsets, - edgelist_label_offsets); - - // 2. compute renumber map offsets for each label + // 1. 
compute renumber_map - std::optional> renumber_map_label_offsets{}; - if (edgelist_label_offsets) { - auto num_unique_labels = thrust::count_if( - handle.get_thrust_policy(), - thrust::make_counting_iterator(size_t{0}), - thrust::make_counting_iterator((*renumber_map_label_indices).size()), - detail::is_first_in_run_t{(*renumber_map_label_indices).data()}); - rmm::device_uvector unique_label_indices(num_unique_labels, handle.get_stream()); - rmm::device_uvector vertex_counts(num_unique_labels, handle.get_stream()); - thrust::reduce_by_key(handle.get_thrust_policy(), - (*renumber_map_label_indices).begin(), - (*renumber_map_label_indices).end(), - thrust::make_constant_iterator(size_t{1}), - unique_label_indices.begin(), - vertex_counts.begin()); - - renumber_map_label_offsets = rmm::device_uvector(num_labels + 1, handle.get_stream()); - thrust::fill(handle.get_thrust_policy(), - (*renumber_map_label_offsets).begin(), - (*renumber_map_label_offsets).end(), - size_t{0}); - thrust::scatter(handle.get_thrust_policy(), - vertex_counts.begin(), - vertex_counts.end(), - unique_label_indices.begin(), - (*renumber_map_label_offsets).begin() + 1); - - thrust::inclusive_scan(handle.get_thrust_policy(), - (*renumber_map_label_offsets).begin(), - (*renumber_map_label_offsets).end(), - (*renumber_map_label_offsets).begin()); - } + auto [renumber_map, renumber_map_label_offsets] = + compute_vertex_renumber_map( + handle, + raft::device_span(edgelist_majors.data(), edgelist_majors.size()), + raft::device_span(edgelist_minors.data(), edgelist_minors.size()), + edgelist_hops, + seed_vertices ? std::make_optional>((*seed_vertices).data(), + (*seed_vertices).size()) + : std::nullopt, + seed_vertex_label_offsets, + edgelist_label_offsets, + std::nullopt, + num_labels, + size_t{1}); - // 3. renumber input edges + // 2. 
renumber input edges if (edgelist_label_offsets) { rmm::device_uvector new_vertices(renumber_map.size(), handle.get_stream()); thrust::tabulate(handle.get_thrust_policy(), new_vertices.begin(), new_vertices.end(), - [label_indices = raft::device_span( - (*renumber_map_label_indices).data(), (*renumber_map_label_indices).size()), - renumber_map_label_offsets = raft::device_span( + [renumber_map_label_offsets = raft::device_span( (*renumber_map_label_offsets).data(), (*renumber_map_label_offsets).size())] __device__(size_t i) { - auto label_index = label_indices[i]; + auto label_index = static_cast(thrust::distance( + renumber_map_label_offsets.begin() + 1, + thrust::upper_bound(thrust::seq, + renumber_map_label_offsets.begin() + 1, + renumber_map_label_offsets.end(), + i))); auto label_start_offset = renumber_map_label_offsets[label_index]; return static_cast(i - label_start_offset); }); - (*renumber_map_label_indices).resize(0, handle.get_stream()); - (*renumber_map_label_indices).shrink_to_fit(handle.get_stream()); - rmm::device_uvector segment_sorted_renumber_map(renumber_map.size(), handle.get_stream()); rmm::device_uvector segment_sorted_new_vertices(new_vertices.size(), @@ -1192,7 +1813,7 @@ renumber_sampled_edgelist(raft::handle_t const& handle, rmm::device_uvector d_tmp_storage(0, handle.get_stream()); - auto approx_edges_to_sort_per_iteration = + auto approx_items_to_sort_per_iteration = static_cast(handle.get_device_properties().multiProcessorCount) * (1 << 20) /* tuning parameter */; // for segmented sort @@ -1201,7 +1822,7 @@ renumber_sampled_edgelist(raft::handle_t const& handle, raft::device_span{(*renumber_map_label_offsets).data(), (*renumber_map_label_offsets).size()}, renumber_map.size(), - approx_edges_to_sort_per_iteration); + approx_items_to_sort_per_iteration); auto num_chunks = h_label_offsets.size() - 1; for (size_t i = 0; i < num_chunks; ++i) { @@ -1369,6 +1990,455 @@ renumber_sampled_edgelist(raft::handle_t const& handle, 
std::move(renumber_map_label_offsets)); } +// this function does not reorder edges (the i'th returned edge is the renumbered output of the +// i'th input edge) +template +std::tuple< + rmm::device_uvector, // edgelist_majors + rmm::device_uvector, // edgelist minors + std::optional>, // edgelist edge IDs + std::optional>, // seed_vertices, + rmm::device_uvector, // vertex renumber_map + rmm::device_uvector, // vertex renumber_map (label, vertex type) offsets + std::optional>, // edge ID renumber map + std::optional>> // edge ID renumber map (label, edge type) offsets +heterogeneous_renumber_sampled_edgelist( + raft::handle_t const& handle, + rmm::device_uvector&& edgelist_majors, + rmm::device_uvector&& edgelist_minors, + std::optional>&& edgelist_edge_ids, + std::optional>&& edgelist_edge_types, + std::optional> edgelist_hops, + std::optional>&& seed_vertices, + std::optional> seed_vertex_label_offsets, + std::optional> edgelist_label_offsets, + raft::device_span vertex_type_offsets, + size_t num_labels, + size_t num_vertex_types, + size_t num_edge_types, + bool do_expensive_check) +{ + // 1. compute vertex renumber map + + auto [vertex_renumber_map, vertex_renumber_map_label_type_offsets] = + compute_vertex_renumber_map( + handle, + raft::device_span(edgelist_majors.data(), edgelist_majors.size()), + raft::device_span(edgelist_minors.data(), edgelist_minors.size()), + edgelist_hops, + seed_vertices ? std::make_optional>((*seed_vertices).data(), + (*seed_vertices).size()) + : std::nullopt, + seed_vertex_label_offsets, + edgelist_label_offsets, + std::make_optional(vertex_type_offsets), + num_labels, + num_vertex_types); + assert(vertex_renumber_map_label_type_offsets.has_value()); + + // 2. 
compute edge renumber map + + std::optional> edge_id_renumber_map{std::nullopt}; + std::optional> edge_id_renumber_map_label_type_offsets{std::nullopt}; + if (edgelist_edge_ids) { + std::tie(edge_id_renumber_map, edge_id_renumber_map_label_type_offsets) = + compute_edge_id_renumber_map( + handle, + raft::device_span((*edgelist_edge_ids).data(), + (*edgelist_edge_ids).size()), + edgelist_edge_types, + edgelist_hops, + edgelist_label_offsets, + num_labels, + num_edge_types); + } + + auto approx_items_to_sort_per_iteration = + static_cast(handle.get_device_properties().multiProcessorCount) * + (1 << 20) /* tuning parameter */; // for segmented sort + + // 3. renumber input edge source/destination vertices + + { + rmm::device_uvector new_vertices(vertex_renumber_map.size(), handle.get_stream()); + thrust::tabulate(handle.get_thrust_policy(), + new_vertices.begin(), + new_vertices.end(), + [renumber_map_label_type_offsets = raft::device_span( + (*vertex_renumber_map_label_type_offsets).data(), + (*vertex_renumber_map_label_type_offsets).size())] __device__(size_t i) { + auto idx = static_cast(thrust::distance( + renumber_map_label_type_offsets.begin() + 1, + thrust::upper_bound(thrust::seq, + renumber_map_label_type_offsets.begin() + 1, + renumber_map_label_type_offsets.end(), + i))); + auto start_offset = renumber_map_label_type_offsets[idx]; + return static_cast(i - start_offset); + }); + + rmm::device_uvector segment_sorted_vertex_renumber_map(vertex_renumber_map.size(), + handle.get_stream()); + rmm::device_uvector segment_sorted_new_vertices(new_vertices.size(), + handle.get_stream()); + + rmm::device_uvector d_tmp_storage(0, handle.get_stream()); + + auto [h_label_offsets, h_edge_offsets] = detail::compute_offset_aligned_element_chunks( + handle, + raft::device_span{(*vertex_renumber_map_label_type_offsets).data(), + (*vertex_renumber_map_label_type_offsets).size()}, + vertex_renumber_map.size(), + approx_items_to_sort_per_iteration); + auto num_chunks = 
h_label_offsets.size() - 1; + + for (size_t i = 0; i < num_chunks; ++i) { + size_t tmp_storage_bytes{0}; + + auto offset_first = thrust::make_transform_iterator( + (*vertex_renumber_map_label_type_offsets).data() + h_label_offsets[i], + detail::shift_left_t{h_edge_offsets[i]}); + cub::DeviceSegmentedSort::SortPairs( + static_cast(nullptr), + tmp_storage_bytes, + vertex_renumber_map.begin() + h_edge_offsets[i], + segment_sorted_vertex_renumber_map.begin() + h_edge_offsets[i], + new_vertices.begin() + h_edge_offsets[i], + segment_sorted_new_vertices.begin() + h_edge_offsets[i], + h_edge_offsets[i + 1] - h_edge_offsets[i], + h_label_offsets[i + 1] - h_label_offsets[i], + offset_first, + offset_first + 1, + handle.get_stream()); + + if (tmp_storage_bytes > d_tmp_storage.size()) { + d_tmp_storage = rmm::device_uvector(tmp_storage_bytes, handle.get_stream()); + } + + cub::DeviceSegmentedSort::SortPairs( + d_tmp_storage.data(), + tmp_storage_bytes, + vertex_renumber_map.begin() + h_edge_offsets[i], + segment_sorted_vertex_renumber_map.begin() + h_edge_offsets[i], + new_vertices.begin() + h_edge_offsets[i], + segment_sorted_new_vertices.begin() + h_edge_offsets[i], + h_edge_offsets[i + 1] - h_edge_offsets[i], + h_label_offsets[i + 1] - h_label_offsets[i], + offset_first, + offset_first + 1, + handle.get_stream()); + } + + new_vertices.resize(0, handle.get_stream()); + new_vertices.shrink_to_fit(handle.get_stream()); + + auto pair_first = + thrust::make_zip_iterator(edgelist_majors.begin(), thrust::make_counting_iterator(size_t{0})); + thrust::transform( + handle.get_thrust_policy(), + pair_first, + pair_first + edgelist_majors.size(), + edgelist_majors.begin(), + [edgelist_label_offsets = detail::to_thrust_optional(edgelist_label_offsets), + vertex_type_offsets, + renumber_map_label_type_offsets = + raft::device_span((*vertex_renumber_map_label_type_offsets).data(), + (*vertex_renumber_map_label_type_offsets).size()), + old_vertices = 
raft::device_span(segment_sorted_vertex_renumber_map.data(), + segment_sorted_vertex_renumber_map.size()), + new_vertices = raft::device_span(segment_sorted_new_vertices.data(), + segment_sorted_new_vertices.size()), + num_vertex_types] __device__(auto pair) { + auto old_vertex = thrust::get<0>(pair); + label_index_t label_idx{0}; + if (edgelist_label_offsets) { + label_idx = static_cast( + thrust::distance((*edgelist_label_offsets).begin() + 1, + thrust::upper_bound(thrust::seq, + (*edgelist_label_offsets).begin() + 1, + (*edgelist_label_offsets).end(), + thrust::get<1>(pair)))); + } + auto v_type = static_cast(thrust::distance( + vertex_type_offsets.begin() + 1, + thrust::upper_bound( + thrust::seq, vertex_type_offsets.begin() + 1, vertex_type_offsets.end(), old_vertex))); + auto start_offset = renumber_map_label_type_offsets[label_idx * num_vertex_types + v_type]; + auto end_offset = + renumber_map_label_type_offsets[label_idx * num_vertex_types + v_type + 1]; + auto it = thrust::lower_bound(thrust::seq, + old_vertices.begin() + start_offset, + old_vertices.begin() + end_offset, + old_vertex); + assert(*it == old_vertex); + return *(new_vertices.begin() + thrust::distance(old_vertices.begin(), it)); + }); + + pair_first = + thrust::make_zip_iterator(edgelist_minors.begin(), thrust::make_counting_iterator(size_t{0})); + thrust::transform( + handle.get_thrust_policy(), + pair_first, + pair_first + edgelist_minors.size(), + edgelist_minors.begin(), + [edgelist_label_offsets = detail::to_thrust_optional(edgelist_label_offsets), + vertex_type_offsets, + renumber_map_label_type_offsets = + raft::device_span((*vertex_renumber_map_label_type_offsets).data(), + (*vertex_renumber_map_label_type_offsets).size()), + old_vertices = raft::device_span(segment_sorted_vertex_renumber_map.data(), + segment_sorted_vertex_renumber_map.size()), + new_vertices = raft::device_span(segment_sorted_new_vertices.data(), + segment_sorted_new_vertices.size()), + num_vertex_types] 
__device__(auto pair) { + auto old_vertex = thrust::get<0>(pair); + label_index_t label_idx{0}; + if (edgelist_label_offsets) { + label_idx = static_cast( + thrust::distance((*edgelist_label_offsets).begin() + 1, + thrust::upper_bound(thrust::seq, + (*edgelist_label_offsets).begin() + 1, + (*edgelist_label_offsets).end(), + thrust::get<1>(pair)))); + } + auto v_type = static_cast(thrust::distance( + vertex_type_offsets.begin() + 1, + thrust::upper_bound( + thrust::seq, vertex_type_offsets.begin() + 1, vertex_type_offsets.end(), old_vertex))); + auto start_offset = renumber_map_label_type_offsets[label_idx * num_vertex_types + v_type]; + auto end_offset = + renumber_map_label_type_offsets[label_idx * num_vertex_types + v_type + 1]; + auto it = thrust::lower_bound(thrust::seq, + old_vertices.begin() + start_offset, + old_vertices.begin() + end_offset, + old_vertex); + assert(*it == old_vertex); + return *(new_vertices.begin() + thrust::distance(old_vertices.begin(), it)); + }); + + if (seed_vertices) { + pair_first = thrust::make_zip_iterator((*seed_vertices).begin(), + thrust::make_counting_iterator(size_t{0})); + thrust::transform( + handle.get_thrust_policy(), + pair_first, + pair_first + (*seed_vertices).size(), + (*seed_vertices).begin(), + [seed_vertex_label_offsets = detail::to_thrust_optional(seed_vertex_label_offsets), + vertex_type_offsets, + renumber_map_label_type_offsets = + raft::device_span((*vertex_renumber_map_label_type_offsets).data(), + (*vertex_renumber_map_label_type_offsets).size()), + old_vertices = raft::device_span( + segment_sorted_vertex_renumber_map.data(), segment_sorted_vertex_renumber_map.size()), + new_vertices = raft::device_span(segment_sorted_new_vertices.data(), + segment_sorted_new_vertices.size()), + num_vertex_types] __device__(auto pair) { + auto old_vertex = thrust::get<0>(pair); + label_index_t label_idx{0}; + if (seed_vertex_label_offsets) { + label_idx = static_cast( + thrust::distance((*seed_vertex_label_offsets).begin() 
+ 1, + thrust::upper_bound(thrust::seq, + (*seed_vertex_label_offsets).begin() + 1, + (*seed_vertex_label_offsets).end(), + thrust::get<1>(pair)))); + } + auto v_type = static_cast( + thrust::distance(vertex_type_offsets.begin() + 1, + thrust::upper_bound(thrust::seq, + vertex_type_offsets.begin() + 1, + vertex_type_offsets.end(), + old_vertex))); + auto start_offset = + renumber_map_label_type_offsets[label_idx * num_vertex_types + v_type]; + auto end_offset = + renumber_map_label_type_offsets[label_idx * num_vertex_types + v_type + 1]; + auto it = thrust::lower_bound(thrust::seq, + old_vertices.begin() + start_offset, + old_vertices.begin() + end_offset, + old_vertex); + assert(*it == old_vertex); + return new_vertices[thrust::distance(old_vertices.begin(), it)]; + }); + } + } + + // 4. renumber input edge IDs + + if (edgelist_edge_ids) { + rmm::device_uvector new_edge_ids((*edge_id_renumber_map).size(), + handle.get_stream()); + if (edge_id_renumber_map_label_type_offsets) { + thrust::tabulate(handle.get_thrust_policy(), + new_edge_ids.begin(), + new_edge_ids.end(), + [renumber_map_label_type_offsets = raft::device_span( + (*edge_id_renumber_map_label_type_offsets).data(), + (*edge_id_renumber_map_label_type_offsets).size())] __device__(size_t i) { + auto idx = static_cast(thrust::distance( + renumber_map_label_type_offsets.begin() + 1, + thrust::upper_bound(thrust::seq, + renumber_map_label_type_offsets.begin() + 1, + renumber_map_label_type_offsets.end(), + i))); + auto start_offset = renumber_map_label_type_offsets[idx]; + return static_cast(i - start_offset); + }); + } else { + thrust::sequence( + handle.get_thrust_policy(), new_edge_ids.begin(), new_edge_ids.end(), edge_id_t{0}); + } + + rmm::device_uvector segment_sorted_edge_id_renumber_map( + (*edge_id_renumber_map).size(), handle.get_stream()); + rmm::device_uvector segment_sorted_new_edge_ids(new_edge_ids.size(), + handle.get_stream()); + + if (edge_id_renumber_map_label_type_offsets) { + 
rmm::device_uvector d_tmp_storage(0, handle.get_stream()); + + auto [h_label_offsets, h_edge_offsets] = detail::compute_offset_aligned_element_chunks( + handle, + raft::device_span{(*edge_id_renumber_map_label_type_offsets).data(), + (*edge_id_renumber_map_label_type_offsets).size()}, + (*edge_id_renumber_map).size(), + approx_items_to_sort_per_iteration); + auto num_chunks = h_label_offsets.size() - 1; + + for (size_t i = 0; i < num_chunks; ++i) { + size_t tmp_storage_bytes{0}; + + auto offset_first = thrust::make_transform_iterator( + (*edge_id_renumber_map_label_type_offsets).data() + h_label_offsets[i], + detail::shift_left_t{h_edge_offsets[i]}); + cub::DeviceSegmentedSort::SortPairs( + static_cast(nullptr), + tmp_storage_bytes, + (*edge_id_renumber_map).begin() + h_edge_offsets[i], + segment_sorted_edge_id_renumber_map.begin() + h_edge_offsets[i], + new_edge_ids.begin() + h_edge_offsets[i], + segment_sorted_new_edge_ids.begin() + h_edge_offsets[i], + h_edge_offsets[i + 1] - h_edge_offsets[i], + h_label_offsets[i + 1] - h_label_offsets[i], + offset_first, + offset_first + 1, + handle.get_stream()); + + if (tmp_storage_bytes > d_tmp_storage.size()) { + d_tmp_storage = rmm::device_uvector(tmp_storage_bytes, handle.get_stream()); + } + + cub::DeviceSegmentedSort::SortPairs( + d_tmp_storage.data(), + tmp_storage_bytes, + (*edge_id_renumber_map).begin() + h_edge_offsets[i], + segment_sorted_edge_id_renumber_map.begin() + h_edge_offsets[i], + new_edge_ids.begin() + h_edge_offsets[i], + segment_sorted_new_edge_ids.begin() + h_edge_offsets[i], + h_edge_offsets[i + 1] - h_edge_offsets[i], + h_label_offsets[i + 1] - h_label_offsets[i], + offset_first, + offset_first + 1, + handle.get_stream()); + } + + new_edge_ids.resize(0, handle.get_stream()); + new_edge_ids.shrink_to_fit(handle.get_stream()); + } else { + thrust::copy(handle.get_thrust_policy(), + (*edge_id_renumber_map).begin(), + (*edge_id_renumber_map).end(), + segment_sorted_edge_id_renumber_map.begin()); + 
segment_sorted_new_edge_ids = std::move(new_edge_ids); + thrust::sort_by_key(handle.get_thrust_policy(), + segment_sorted_edge_id_renumber_map.begin(), + segment_sorted_edge_id_renumber_map.end(), + segment_sorted_new_edge_ids.begin()); + } + + if (edge_id_renumber_map_label_type_offsets) { + auto pair_first = thrust::make_zip_iterator((*edgelist_edge_ids).begin(), + thrust::make_counting_iterator(size_t{0})); + thrust::transform( + handle.get_thrust_policy(), + pair_first, + pair_first + (*edgelist_edge_ids).size(), + (*edgelist_edge_ids).begin(), + cuda::proclaim_return_type( + [edgelist_label_offsets = detail::to_thrust_optional(edgelist_label_offsets), + edge_types = edgelist_edge_types + ? thrust::make_optional>( + (*edgelist_edge_types).data(), (*edgelist_edge_types).size()) + : thrust::nullopt, + renumber_map = + raft::device_span(segment_sorted_edge_id_renumber_map.data(), + segment_sorted_edge_id_renumber_map.size()), + new_edge_ids = raft::device_span(segment_sorted_new_edge_ids.data(), + segment_sorted_new_edge_ids.size()), + renumber_map_label_type_offsets = + raft::device_span((*edge_id_renumber_map_label_type_offsets).data(), + (*edge_id_renumber_map_label_type_offsets).size()), + num_edge_types] __device__(auto pair) { + auto old_edge_id = thrust::get<0>(pair); + auto edge_idx = thrust::get<1>(pair); + size_t label_idx{0}; + if (edgelist_label_offsets) { + label_idx = static_cast( + thrust::distance((*edgelist_label_offsets).begin() + 1, + thrust::upper_bound(thrust::seq, + (*edgelist_label_offsets).begin() + 1, + (*edgelist_label_offsets).end(), + edge_idx))); + } + edge_type_t edge_type{0}; + if (edge_types) { edge_type = (*edge_types)[edge_idx]; } + auto renumber_map_start_offset = + renumber_map_label_type_offsets[label_idx * num_edge_types + edge_type]; + auto renumber_map_end_offset = + renumber_map_label_type_offsets[label_idx * num_edge_types + edge_type + 1]; + auto it = thrust::lower_bound(thrust::seq, + renumber_map.begin() + 
renumber_map_start_offset, + renumber_map.begin() + renumber_map_end_offset, + old_edge_id); + assert(*it == old_edge_id); + return *(new_edge_ids.begin() + thrust::distance(renumber_map.begin(), it)); + })); + } else { + thrust::transform( + handle.get_thrust_policy(), + (*edgelist_edge_ids).begin(), + (*edgelist_edge_ids).end(), + (*edgelist_edge_ids).begin(), + cuda::proclaim_return_type( + [renumber_map = + raft::device_span(segment_sorted_edge_id_renumber_map.data(), + segment_sorted_edge_id_renumber_map.size()), + new_edge_ids = raft::device_span( + segment_sorted_new_edge_ids.data(), + segment_sorted_new_edge_ids.size())] __device__(edge_id_t old_edge_id) { + auto it = thrust::lower_bound( + thrust::seq, renumber_map.begin(), renumber_map.end(), old_edge_id); + assert(*it == old_edge_id); + return *(new_edge_ids.begin() + thrust::distance(renumber_map.begin(), it)); + })); + } + } + + return std::make_tuple(std::move(edgelist_majors), + std::move(edgelist_minors), + std::move(edgelist_edge_ids), + std::move(seed_vertices), + std::move(vertex_renumber_map), + std::move(*vertex_renumber_map_label_type_offsets), + std::move(edge_id_renumber_map), + std::move(edge_id_renumber_map_label_type_offsets)); +} + template void permute_array(raft::handle_t const& handle, IndexIterator index_first, @@ -1390,7 +2460,9 @@ void permute_array(raft::handle_t const& handle, value_first); } -// key: ((label), (hop), major, minor) +// key: +// ((label), (edge type), (hop), major, minor) if use_edge_type_as_sort_key is true +// ((label), (hop), major, minor) if use_edge_type_as_sort_key is false template std::tuple, rmm::device_uvector, @@ -1405,7 +2477,8 @@ sort_sampled_edge_tuples(raft::handle_t const& handle, std::optional>&& edgelist_edge_ids, std::optional>&& edgelist_edge_types, std::optional>&& edgelist_hops, - std::optional> edgelist_label_offsets) + std::optional> edgelist_label_offsets, + bool use_edge_type_as_sort_key) { std::vector h_label_offsets{}; std::vector 
h_edge_offsets{}; @@ -1427,11 +2500,15 @@ sort_sampled_edge_tuples(raft::handle_t const& handle, rmm::device_uvector indices(h_edge_offsets[i + 1] - h_edge_offsets[i], handle.get_stream()); thrust::sequence(handle.get_thrust_policy(), indices.begin(), indices.end(), size_t{0}); - edge_order_t edge_order_comp{ + edge_order_t edge_order_comp{ edgelist_label_offsets ? thrust::make_optional>( (*edgelist_label_offsets).data() + h_label_offsets[i], (h_label_offsets[i + 1] - h_label_offsets[i]) + 1) : thrust::nullopt, + edgelist_edge_types && use_edge_type_as_sort_key + ? thrust::make_optional>( + (*edgelist_edge_types).data() + h_edge_offsets[i], indices.size()) + : thrust::nullopt, edgelist_hops ? thrust::make_optional>( (*edgelist_hops).data() + h_edge_offsets[i], indices.size()) : thrust::nullopt, @@ -1510,25 +2587,29 @@ renumber_and_compress_sampled_edgelist( bool do_expensive_check) { using label_index_t = uint32_t; + using vertex_type_t = uint32_t; // dummy auto edgelist_majors = src_is_major ? std::move(edgelist_srcs) : std::move(edgelist_dsts); auto edgelist_minors = src_is_major ? std::move(edgelist_dsts) : std::move(edgelist_srcs); // 1. 
check input arguments - check_input_edges(handle, - edgelist_majors, - edgelist_minors, - edgelist_weights, - edgelist_edge_ids, - edgelist_edge_types, - edgelist_hops, - seed_vertices, - seed_vertex_label_offsets, - edgelist_label_offsets, - num_labels, - num_hops, - do_expensive_check); + check_input_edges(handle, + edgelist_majors, + edgelist_minors, + edgelist_weights, + edgelist_edge_ids, + edgelist_edge_types, + edgelist_hops, + seed_vertices, + seed_vertex_label_offsets, + edgelist_label_offsets, + std::nullopt, + num_labels, + num_hops, + size_t{1}, + std::optional{std::nullopt}, + do_expensive_check); CUGRAPH_EXPECTS( !doubly_compress || !compress_per_hop, @@ -1582,7 +2663,8 @@ renumber_and_compress_sampled_edgelist( std::move(edgelist_edge_ids), std::move(edgelist_edge_types), std::move(edgelist_hops), - edgelist_label_offsets); + edgelist_label_offsets, + false); if (renumbered_seed_vertices) { if (seed_vertex_label_offsets) { @@ -2144,25 +3226,29 @@ renumber_and_sort_sampled_edgelist( bool do_expensive_check) { using label_index_t = uint32_t; + using vertex_type_t = uint32_t; // dummy auto edgelist_majors = src_is_major ? std::move(edgelist_srcs) : std::move(edgelist_dsts); auto edgelist_minors = src_is_major ? std::move(edgelist_dsts) : std::move(edgelist_srcs); // 1. check input arguments - check_input_edges(handle, - edgelist_majors, - edgelist_minors, - edgelist_weights, - edgelist_edge_ids, - edgelist_edge_types, - edgelist_hops, - seed_vertices, - seed_vertex_label_offsets, - edgelist_label_offsets, - num_labels, - num_hops, - do_expensive_check); + check_input_edges(handle, + edgelist_majors, + edgelist_minors, + edgelist_weights, + edgelist_edge_ids, + edgelist_edge_types, + edgelist_hops, + seed_vertices, + seed_vertex_label_offsets, + edgelist_label_offsets, + std::nullopt, + num_labels, + num_hops, + size_t{1}, + std::optional{std::nullopt}, + do_expensive_check); // 2. 
renumber @@ -2206,7 +3292,8 @@ renumber_and_sort_sampled_edgelist( std::move(edgelist_edge_ids), std::move(edgelist_edge_types), std::move(edgelist_hops), - edgelist_label_offsets); + edgelist_label_offsets, + false); // 4. compute edgelist_label_hop_offsets @@ -2274,6 +3361,218 @@ renumber_and_sort_sampled_edgelist( std::move(renumber_map_label_offsets)); } +template +std::tuple, // srcs + rmm::device_uvector, // dsts + std::optional>, // weights + std::optional>, // edge IDs + std::optional>, // (label, edge type, hop) offsets to the + // edges + rmm::device_uvector, // vertex renumber map + rmm::device_uvector, // (label, vertex type) offsets to the vertex renumber map + std::optional>, // edge ID renumber map + std::optional< + rmm::device_uvector>> // (label, edge type) offsets to the vertex renumber map +heterogeneous_renumber_and_sort_sampled_edgelist( + raft::handle_t const& handle, + rmm::device_uvector&& edgelist_srcs, + rmm::device_uvector&& edgelist_dsts, + std::optional>&& edgelist_weights, + std::optional>&& edgelist_edge_ids, + std::optional>&& edgelist_edge_types, + std::optional>&& edgelist_hops, + std::optional> seed_vertices, + std::optional> seed_vertex_label_offsets, + std::optional> edgelist_label_offsets, + raft::device_span vertex_type_offsets, + size_t num_labels, + size_t num_hops, + size_t num_vertex_types, + size_t num_edge_types, + bool src_is_major, + bool do_expensive_check) +{ + using label_index_t = uint32_t; + using vertex_type_t = uint32_t; + + auto edgelist_majors = src_is_major ? std::move(edgelist_srcs) : std::move(edgelist_dsts); + auto edgelist_minors = src_is_major ? std::move(edgelist_dsts) : std::move(edgelist_srcs); + + // 1. 
check input arguments + + check_input_edges(handle, + edgelist_majors, + edgelist_minors, + edgelist_weights, + edgelist_edge_ids, + edgelist_edge_types, + edgelist_hops, + seed_vertices, + seed_vertex_label_offsets, + edgelist_label_offsets, + vertex_type_offsets, + num_labels, + num_hops, + num_vertex_types, + std::optional{num_edge_types}, + do_expensive_check); + + // 2. renumber + + std::optional> renumbered_seed_vertices{std::nullopt}; + if (seed_vertices) { + renumbered_seed_vertices = + rmm::device_uvector((*seed_vertices).size(), handle.get_stream()); + thrust::copy(handle.get_thrust_policy(), + (*seed_vertices).begin(), + (*seed_vertices).end(), + (*renumbered_seed_vertices).begin()); + } + rmm::device_uvector vertex_renumber_map(0, handle.get_stream()); + rmm::device_uvector vertex_renumber_map_label_type_offsets(0, handle.get_stream()); + std::optional> edge_id_renumber_map{std::nullopt}; + std::optional> edge_id_renumber_map_label_type_offsets{std::nullopt}; + std::tie(edgelist_majors, + edgelist_minors, + edgelist_edge_ids, + std::ignore, + vertex_renumber_map, + vertex_renumber_map_label_type_offsets, + edge_id_renumber_map, + edge_id_renumber_map_label_type_offsets) = + heterogeneous_renumber_sampled_edgelist( + handle, + std::move(edgelist_majors), + std::move(edgelist_minors), + std::move(edgelist_edge_ids), + edgelist_edge_types ? std::make_optional(raft::device_span( + (*edgelist_edge_types).data(), (*edgelist_edge_types).size())) + : std::nullopt, + edgelist_hops ? std::make_optional(raft::device_span((*edgelist_hops).data(), + (*edgelist_hops).size())) + : std::nullopt, + std::move(renumbered_seed_vertices), + seed_vertex_label_offsets, + edgelist_label_offsets, + vertex_type_offsets, + num_labels, + num_vertex_types, + num_edge_types, + do_expensive_check); + + // 3. 
sort by ((label), (edge type), (hop), major, minor) + + std::tie(edgelist_majors, + edgelist_minors, + edgelist_weights, + edgelist_edge_ids, + edgelist_edge_types, + edgelist_hops) = sort_sampled_edge_tuples(handle, + std::move(edgelist_majors), + std::move(edgelist_minors), + std::move(edgelist_weights), + std::move(edgelist_edge_ids), + std::move(edgelist_edge_types), + std::move(edgelist_hops), + edgelist_label_offsets, + true); + + // 4. compute edgelist (label, edge type, hop) offsets + + std::optional> edgelist_label_type_hop_offsets{std::nullopt}; + if (edgelist_label_offsets || edgelist_edge_types || edgelist_hops) { + edgelist_label_type_hop_offsets = + rmm::device_uvector(num_labels * num_edge_types * num_hops + 1, handle.get_stream()); + thrust::fill(handle.get_thrust_policy(), + (*edgelist_label_type_hop_offsets).begin(), + (*edgelist_label_type_hop_offsets).end(), + size_t{0}); + thrust::transform( + handle.get_thrust_policy(), + thrust::make_counting_iterator(size_t{0}), + thrust::make_counting_iterator(num_labels * num_edge_types * num_hops), + (*edgelist_label_type_hop_offsets).begin(), + cuda::proclaim_return_type( + [edgelist_label_offsets = detail::to_thrust_optional(edgelist_label_offsets), + edgelist_edge_types = edgelist_edge_types + ? thrust::make_optional>( + (*edgelist_edge_types).data(), (*edgelist_edge_types).size()) + : thrust::nullopt, + edgelist_hops = edgelist_hops ? 
thrust::make_optional>( + (*edgelist_hops).data(), (*edgelist_hops).size()) + : thrust::nullopt, + num_edge_types, + num_hops, + num_edges = edgelist_majors.size()] __device__(size_t i) { + size_t start_offset{0}; + auto end_offset = num_edges; + + if (edgelist_label_offsets) { + auto l_idx = static_cast(i / (num_edge_types * num_hops)); + start_offset = (*edgelist_label_offsets)[l_idx]; + end_offset = (*edgelist_label_offsets)[l_idx + 1]; + } + + if (edgelist_edge_types) { + auto t = static_cast((i % (num_edge_types * num_hops)) / num_hops); + auto lower_it = thrust::lower_bound(thrust::seq, + (*edgelist_edge_types).begin() + start_offset, + (*edgelist_edge_types).begin() + end_offset, + t); + auto upper_it = thrust::upper_bound(thrust::seq, + (*edgelist_edge_types).begin() + start_offset, + (*edgelist_edge_types).begin() + end_offset, + t); + start_offset = + static_cast(thrust::distance((*edgelist_edge_types).begin(), lower_it)); + end_offset = + static_cast(thrust::distance((*edgelist_edge_types).begin(), upper_it)); + } + + if (edgelist_hops) { + auto h = static_cast(i % num_hops); + auto lower_it = thrust::lower_bound(thrust::seq, + (*edgelist_hops).begin() + start_offset, + (*edgelist_hops).begin() + end_offset, + h); + auto upper_it = thrust::upper_bound(thrust::seq, + (*edgelist_hops).begin() + start_offset, + (*edgelist_hops).begin() + end_offset, + h); + start_offset = + static_cast(thrust::distance((*edgelist_hops).begin(), lower_it)); + end_offset = static_cast(thrust::distance((*edgelist_hops).begin(), upper_it)); + } + + return end_offset - start_offset; + })); + thrust::exclusive_scan(handle.get_thrust_policy(), + (*edgelist_label_type_hop_offsets).begin(), + (*edgelist_label_type_hop_offsets).end(), + (*edgelist_label_type_hop_offsets).begin()); + } + + edgelist_edge_types = std::nullopt; + edgelist_hops = std::nullopt; + + return std::make_tuple(std::move(src_is_major ? edgelist_majors : edgelist_minors), + std::move(src_is_major ? 
edgelist_minors : edgelist_majors), + std::move(edgelist_weights), + std::move(edgelist_edge_ids), + std::move(edgelist_label_type_hop_offsets), + std::move(vertex_renumber_map), + std::move(vertex_renumber_map_label_type_offsets), + std::move(edge_id_renumber_map), + std::move(edge_id_renumber_map_label_type_offsets)); +} + template (handle, - edgelist_majors, - edgelist_minors, - edgelist_weights, - edgelist_edge_ids, - edgelist_edge_types, - edgelist_hops, - std::nullopt, - std::nullopt, - edgelist_label_offsets, - num_labels, - num_hops, - do_expensive_check); + check_input_edges(handle, + edgelist_majors, + edgelist_minors, + edgelist_weights, + edgelist_edge_ids, + edgelist_edge_types, + edgelist_hops, + std::nullopt, + std::nullopt, + edgelist_label_offsets, + std::nullopt, + num_labels, + num_hops, + size_t{1}, + std::optional{std::nullopt}, + do_expensive_check); // 2. sort by ((l), (h), major, minor) @@ -2332,7 +3635,8 @@ sort_sampled_edgelist(raft::handle_t const& handle, std::move(edgelist_edge_ids), std::move(edgelist_edge_types), std::move(edgelist_hops), - edgelist_label_offsets); + edgelist_label_offsets, + false); // 3. 
compute edgelist_label_hop_offsets diff --git a/cpp/src/sampling/sampling_post_processing_sg_v32_e32.cu b/cpp/src/sampling/sampling_post_processing_sg_v32_e32.cu index 6b8d8a07d92..ff1add6a02a 100644 --- a/cpp/src/sampling/sampling_post_processing_sg_v32_e32.cu +++ b/cpp/src/sampling/sampling_post_processing_sg_v32_e32.cu @@ -122,6 +122,62 @@ renumber_and_sort_sampled_edgelist( bool src_is_major, bool do_expensive_check); +template std::tuple, + rmm::device_uvector, + std::optional>, + std::optional>, + std::optional>, + rmm::device_uvector, + rmm::device_uvector, + std::optional>, + std::optional>> +heterogeneous_renumber_and_sort_sampled_edgelist( + raft::handle_t const& handle, + rmm::device_uvector&& edgelist_srcs, + rmm::device_uvector&& edgelist_dsts, + std::optional>&& edgelist_weights, + std::optional>&& edgelist_edge_ids, + std::optional>&& edgelist_edge_types, + std::optional>&& edgelist_hops, + std::optional> seed_vertices, + std::optional> seed_vertex_label_offsets, + std::optional> edgelist_label_offsets, + raft::device_span vertex_type_offsets, + size_t num_labels, + size_t num_hops, + size_t num_vertex_types, + size_t num_edge_types, + bool src_is_major, + bool do_expensive_check); + +template std::tuple, + rmm::device_uvector, + std::optional>, + std::optional>, + std::optional>, + rmm::device_uvector, + rmm::device_uvector, + std::optional>, + std::optional>> +heterogeneous_renumber_and_sort_sampled_edgelist( + raft::handle_t const& handle, + rmm::device_uvector&& edgelist_srcs, + rmm::device_uvector&& edgelist_dsts, + std::optional>&& edgelist_weights, + std::optional>&& edgelist_edge_ids, + std::optional>&& edgelist_edge_types, + std::optional>&& edgelist_hops, + std::optional> seed_vertices, + std::optional> seed_vertex_label_offsets, + std::optional> edgelist_label_offsets, + raft::device_span vertex_type_offsets, + size_t num_labels, + size_t num_hops, + size_t num_vertex_types, + size_t num_edge_types, + bool src_is_major, + bool 
do_expensive_check); + template std::tuple, rmm::device_uvector, std::optional>, diff --git a/cpp/src/sampling/sampling_post_processing_sg_v32_e64.cu b/cpp/src/sampling/sampling_post_processing_sg_v32_e64.cu index a4b083efd7c..7001dcfdaf3 100644 --- a/cpp/src/sampling/sampling_post_processing_sg_v32_e64.cu +++ b/cpp/src/sampling/sampling_post_processing_sg_v32_e64.cu @@ -122,6 +122,62 @@ renumber_and_sort_sampled_edgelist( bool src_is_major, bool do_expensive_check); +template std::tuple, + rmm::device_uvector, + std::optional>, + std::optional>, + std::optional>, + rmm::device_uvector, + rmm::device_uvector, + std::optional>, + std::optional>> +heterogeneous_renumber_and_sort_sampled_edgelist( + raft::handle_t const& handle, + rmm::device_uvector&& edgelist_srcs, + rmm::device_uvector&& edgelist_dsts, + std::optional>&& edgelist_weights, + std::optional>&& edgelist_edge_ids, + std::optional>&& edgelist_edge_types, + std::optional>&& edgelist_hops, + std::optional> seed_vertices, + std::optional> seed_vertex_label_offsets, + std::optional> edgelist_label_offsets, + raft::device_span vertex_type_offsets, + size_t num_labels, + size_t num_hops, + size_t num_vertex_types, + size_t num_edge_types, + bool src_is_major, + bool do_expensive_check); + +template std::tuple, + rmm::device_uvector, + std::optional>, + std::optional>, + std::optional>, + rmm::device_uvector, + rmm::device_uvector, + std::optional>, + std::optional>> +heterogeneous_renumber_and_sort_sampled_edgelist( + raft::handle_t const& handle, + rmm::device_uvector&& edgelist_srcs, + rmm::device_uvector&& edgelist_dsts, + std::optional>&& edgelist_weights, + std::optional>&& edgelist_edge_ids, + std::optional>&& edgelist_edge_types, + std::optional>&& edgelist_hops, + std::optional> seed_vertices, + std::optional> seed_vertex_label_offsets, + std::optional> edgelist_label_offsets, + raft::device_span vertex_type_offsets, + size_t num_labels, + size_t num_hops, + size_t num_vertex_types, + size_t 
num_edge_types, + bool src_is_major, + bool do_expensive_check); + template std::tuple, rmm::device_uvector, std::optional>, diff --git a/cpp/src/sampling/sampling_post_processing_sg_v64_e64.cu b/cpp/src/sampling/sampling_post_processing_sg_v64_e64.cu index a62ca2a0777..3b2b8144420 100644 --- a/cpp/src/sampling/sampling_post_processing_sg_v64_e64.cu +++ b/cpp/src/sampling/sampling_post_processing_sg_v64_e64.cu @@ -122,6 +122,62 @@ renumber_and_sort_sampled_edgelist( bool src_is_major, bool do_expensive_check); +template std::tuple, + rmm::device_uvector, + std::optional>, + std::optional>, + std::optional>, + rmm::device_uvector, + rmm::device_uvector, + std::optional>, + std::optional>> +heterogeneous_renumber_and_sort_sampled_edgelist( + raft::handle_t const& handle, + rmm::device_uvector&& edgelist_srcs, + rmm::device_uvector&& edgelist_dsts, + std::optional>&& edgelist_weights, + std::optional>&& edgelist_edge_ids, + std::optional>&& edgelist_edge_types, + std::optional>&& edgelist_hops, + std::optional> seed_vertices, + std::optional> seed_vertex_label_offsets, + std::optional> edgelist_label_offsets, + raft::device_span vertex_type_offsets, + size_t num_labels, + size_t num_hops, + size_t num_vertex_types, + size_t num_edge_types, + bool src_is_major, + bool do_expensive_check); + +template std::tuple, + rmm::device_uvector, + std::optional>, + std::optional>, + std::optional>, + rmm::device_uvector, + rmm::device_uvector, + std::optional>, + std::optional>> +heterogeneous_renumber_and_sort_sampled_edgelist( + raft::handle_t const& handle, + rmm::device_uvector&& edgelist_srcs, + rmm::device_uvector&& edgelist_dsts, + std::optional>&& edgelist_weights, + std::optional>&& edgelist_edge_ids, + std::optional>&& edgelist_edge_types, + std::optional>&& edgelist_hops, + std::optional> seed_vertices, + std::optional> seed_vertex_label_offsets, + std::optional> edgelist_label_offsets, + raft::device_span vertex_type_offsets, + size_t num_labels, + size_t num_hops, + 
size_t num_vertex_types, + size_t num_edge_types, + bool src_is_major, + bool do_expensive_check); + template std::tuple, rmm::device_uvector, std::optional>, diff --git a/cpp/tests/CMakeLists.txt b/cpp/tests/CMakeLists.txt index 3c3a3650491..09b1431e33b 100644 --- a/cpp/tests/CMakeLists.txt +++ b/cpp/tests/CMakeLists.txt @@ -47,6 +47,7 @@ add_library(cugraphtestutil STATIC structure/induced_subgraph_validate.cu sampling/random_walks_check_sg.cu sampling/detail/nbr_sampling_validate.cu + sampling/detail/sampling_post_processing_validate.cu ../../thirdparty/mmio/mmio.c) target_compile_options(cugraphtestutil @@ -486,7 +487,12 @@ ConfigureTest(BIASED_NEIGHBOR_SAMPLING_TEST sampling/biased_neighbor_sampling.cp ################################################################################################### # - SAMPLING_POST_PROCESSING tests ---------------------------------------------------------------- -ConfigureTest(SAMPLING_POST_PROCESSING_TEST sampling/sampling_post_processing_test.cu) +ConfigureTest(SAMPLING_POST_PROCESSING_TEST sampling/sampling_post_processing_test.cpp) + +################################################################################################### +# - SAMPLING_HETEROGENEOUS_POST_PROCESSING tests -------------------------------------------------- +ConfigureTest(SAMPLING_HETEROGENEOUS_POST_PROCESSING_TEST + sampling/sampling_heterogeneous_post_processing_test.cpp) ################################################################################################### # - NEGATIVE SAMPLING tests -------------------------------------------------------------------- @@ -581,7 +587,8 @@ if(BUILD_CUGRAPH_MG_TESTS) ############################################################################################### # - MG BETWEENNESS CENTRALITY tests ----------------------------------------------------------- ConfigureTestMG(MG_BETWEENNESS_CENTRALITY_TEST centrality/mg_betweenness_centrality_test.cpp) - 
ConfigureTestMG(MG_EDGE_BETWEENNESS_CENTRALITY_TEST centrality/mg_edge_betweenness_centrality_test.cpp) + ConfigureTestMG(MG_EDGE_BETWEENNESS_CENTRALITY_TEST + centrality/mg_edge_betweenness_centrality_test.cpp) ############################################################################################### # - MG BFS tests ------------------------------------------------------------------------------ diff --git a/cpp/tests/sampling/detail/sampling_post_processing_validate.cu b/cpp/tests/sampling/detail/sampling_post_processing_validate.cu new file mode 100644 index 00000000000..a0babc3b921 --- /dev/null +++ b/cpp/tests/sampling/detail/sampling_post_processing_validate.cu @@ -0,0 +1,1738 @@ +/* + * Copyright (c) 2022-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include +#include +#include + +#include +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +template +bool check_offsets(raft::handle_t const& handle, + raft::device_span offsets, + index_t num_segments, + index_t num_elements) +{ + if (offsets.size() != num_segments + 1) { return false; } + + if (!thrust::is_sorted(handle.get_thrust_policy(), offsets.begin(), offsets.end())) { + return false; + } + + index_t front_element{}; + index_t back_element{}; + raft::update_host(&front_element, offsets.data(), index_t{1}, handle.get_stream()); + raft::update_host( + &back_element, offsets.data() + offsets.size() - 1, index_t{1}, handle.get_stream()); + handle.sync_stream(); + + if (front_element != index_t{0}) { return false; } + + if (back_element != num_elements) { return false; } + + return true; +} + +template bool check_offsets(raft::handle_t const& handle, + raft::device_span offsets, + size_t num_segments, + size_t num_elements); + +template +bool check_edgelist_is_sorted(raft::handle_t const& handle, + raft::device_span edgelist_majors, + raft::device_span edgelist_minors) +{ + auto edge_first = thrust::make_zip_iterator(edgelist_majors.begin(), edgelist_minors.begin()); + return thrust::is_sorted( + handle.get_thrust_policy(), edge_first, edge_first + edgelist_majors.size()); +} + +template bool check_edgelist_is_sorted(raft::handle_t const& handle, + raft::device_span edgelist_majors, + raft::device_span edgelist_minors); + +template bool check_edgelist_is_sorted(raft::handle_t const& handle, + raft::device_span edgelist_majors, + raft::device_span edgelist_minors); + +// unrenumber the renumbered edge list and check whether the original & unrenumbered edge lists are +// identical +template +bool compare_edgelist(raft::handle_t const& handle, + raft::device_span org_edgelist_srcs, + raft::device_span org_edgelist_dsts, + std::optional> org_edgelist_weights, + 
std::optional> org_edgelist_label_offsets, + raft::device_span renumbered_edgelist_srcs, + raft::device_span renumbered_edgelist_dsts, + std::optional> renumbered_edgelist_weights, + std::optional> renumber_map, + std::optional> renumber_map_label_offsets, + size_t num_labels) +{ + if (org_edgelist_srcs.size() != renumbered_edgelist_srcs.size()) { return false; } + + for (size_t i = 0; i < num_labels; ++i) { + size_t label_start_offset{0}; + size_t label_end_offset = org_edgelist_srcs.size(); + if (org_edgelist_label_offsets) { + raft::update_host(&label_start_offset, + (*org_edgelist_label_offsets).data() + i, + size_t{1}, + handle.get_stream()); + raft::update_host(&label_end_offset, + (*org_edgelist_label_offsets).data() + i + 1, + size_t{1}, + handle.get_stream()); + handle.sync_stream(); + } + + if (label_start_offset == label_end_offset) { continue; } + + rmm::device_uvector this_label_sorted_org_edgelist_srcs( + label_end_offset - label_start_offset, handle.get_stream()); + thrust::copy(handle.get_thrust_policy(), + org_edgelist_srcs.begin() + label_start_offset, + org_edgelist_srcs.begin() + label_end_offset, + this_label_sorted_org_edgelist_srcs.begin()); + rmm::device_uvector this_label_sorted_org_edgelist_dsts(org_edgelist_dsts.size(), + handle.get_stream()); + thrust::copy(handle.get_thrust_policy(), + org_edgelist_dsts.begin() + label_start_offset, + org_edgelist_dsts.begin() + label_end_offset, + this_label_sorted_org_edgelist_dsts.begin()); + auto this_label_sorted_org_edgelist_weights = + org_edgelist_weights ? 
std::make_optional>( + label_end_offset - label_start_offset, handle.get_stream()) + : std::nullopt; + if (this_label_sorted_org_edgelist_weights) { + thrust::copy(handle.get_thrust_policy(), + (*org_edgelist_weights).begin() + label_start_offset, + (*org_edgelist_weights).begin() + label_end_offset, + (*this_label_sorted_org_edgelist_weights).begin()); + } + + if (this_label_sorted_org_edgelist_weights) { + auto sorted_org_edge_first = + thrust::make_zip_iterator(this_label_sorted_org_edgelist_srcs.begin(), + this_label_sorted_org_edgelist_dsts.begin(), + (*this_label_sorted_org_edgelist_weights).begin()); + thrust::sort(handle.get_thrust_policy(), + sorted_org_edge_first, + sorted_org_edge_first + this_label_sorted_org_edgelist_srcs.size()); + } else { + auto sorted_org_edge_first = thrust::make_zip_iterator( + this_label_sorted_org_edgelist_srcs.begin(), this_label_sorted_org_edgelist_dsts.begin()); + thrust::sort(handle.get_thrust_policy(), + sorted_org_edge_first, + sorted_org_edge_first + this_label_sorted_org_edgelist_srcs.size()); + } + + rmm::device_uvector this_label_sorted_unrenumbered_edgelist_srcs( + label_end_offset - label_start_offset, handle.get_stream()); + thrust::copy(handle.get_thrust_policy(), + renumbered_edgelist_srcs.begin() + label_start_offset, + renumbered_edgelist_srcs.begin() + label_end_offset, + this_label_sorted_unrenumbered_edgelist_srcs.begin()); + rmm::device_uvector this_label_sorted_unrenumbered_edgelist_dsts( + label_end_offset - label_start_offset, handle.get_stream()); + thrust::copy(handle.get_thrust_policy(), + renumbered_edgelist_dsts.begin() + label_start_offset, + renumbered_edgelist_dsts.begin() + label_end_offset, + this_label_sorted_unrenumbered_edgelist_dsts.begin()); + auto this_label_sorted_unrenumbered_edgelist_weights = + renumbered_edgelist_weights ? 
std::make_optional>( + label_end_offset - label_start_offset, handle.get_stream()) + : std::nullopt; + if (this_label_sorted_unrenumbered_edgelist_weights) { + thrust::copy(handle.get_thrust_policy(), + (*renumbered_edgelist_weights).begin() + label_start_offset, + (*renumbered_edgelist_weights).begin() + label_end_offset, + (*this_label_sorted_unrenumbered_edgelist_weights).begin()); + } + + if (renumber_map) { + size_t renumber_map_label_start_offset{0}; + size_t renumber_map_label_end_offset = (*renumber_map).size(); + if (renumber_map_label_offsets) { + raft::update_host(&renumber_map_label_start_offset, + (*renumber_map_label_offsets).data() + i, + size_t{1}, + handle.get_stream()); + raft::update_host(&renumber_map_label_end_offset, + (*renumber_map_label_offsets).data() + i + 1, + size_t{1}, + handle.get_stream()); + handle.sync_stream(); + } + cugraph::unrenumber_int_vertices( + handle, + this_label_sorted_unrenumbered_edgelist_srcs.data(), + this_label_sorted_unrenumbered_edgelist_srcs.size(), + (*renumber_map).data() + renumber_map_label_start_offset, + std::vector{ + static_cast(renumber_map_label_end_offset - renumber_map_label_start_offset)}); + cugraph::unrenumber_int_vertices( + handle, + this_label_sorted_unrenumbered_edgelist_dsts.data(), + this_label_sorted_unrenumbered_edgelist_dsts.size(), + (*renumber_map).data() + renumber_map_label_start_offset, + std::vector{ + static_cast(renumber_map_label_end_offset - renumber_map_label_start_offset)}); + } + + if (this_label_sorted_unrenumbered_edgelist_weights) { + auto sorted_unrenumbered_edge_first = + thrust::make_zip_iterator(this_label_sorted_unrenumbered_edgelist_srcs.begin(), + this_label_sorted_unrenumbered_edgelist_dsts.begin(), + (*this_label_sorted_unrenumbered_edgelist_weights).begin()); + thrust::sort( + handle.get_thrust_policy(), + sorted_unrenumbered_edge_first, + sorted_unrenumbered_edge_first + this_label_sorted_unrenumbered_edgelist_srcs.size()); + + auto sorted_org_edge_first = + 
thrust::make_zip_iterator(this_label_sorted_org_edgelist_srcs.begin(), + this_label_sorted_org_edgelist_dsts.begin(), + (*this_label_sorted_org_edgelist_weights).begin()); + if (!thrust::equal(handle.get_thrust_policy(), + sorted_org_edge_first, + sorted_org_edge_first + this_label_sorted_org_edgelist_srcs.size(), + sorted_unrenumbered_edge_first)) { + return false; + } + } else { + auto sorted_unrenumbered_edge_first = + thrust::make_zip_iterator(this_label_sorted_unrenumbered_edgelist_srcs.begin(), + this_label_sorted_unrenumbered_edgelist_dsts.begin()); + thrust::sort( + handle.get_thrust_policy(), + sorted_unrenumbered_edge_first, + sorted_unrenumbered_edge_first + this_label_sorted_unrenumbered_edgelist_srcs.size()); + + auto sorted_org_edge_first = thrust::make_zip_iterator( + this_label_sorted_org_edgelist_srcs.begin(), this_label_sorted_org_edgelist_dsts.begin()); + if (!thrust::equal(handle.get_thrust_policy(), + sorted_org_edge_first, + sorted_org_edge_first + this_label_sorted_org_edgelist_srcs.size(), + sorted_unrenumbered_edge_first)) { + return false; + } + } + } + + return true; +} + +template bool compare_edgelist( + raft::handle_t const& handle, + raft::device_span org_edgelist_srcs, + raft::device_span org_edgelist_dsts, + std::optional> org_edgelist_weights, + std::optional> org_edgelist_label_offsets, + raft::device_span renumbered_edgelist_srcs, + raft::device_span renumbered_edgelist_dsts, + std::optional> renumbered_edgelist_weights, + std::optional> renumber_map, + std::optional> renumber_map_label_offsets, + size_t num_labels); + +template bool compare_edgelist( + raft::handle_t const& handle, + raft::device_span org_edgelist_srcs, + raft::device_span org_edgelist_dsts, + std::optional> org_edgelist_weights, + std::optional> org_edgelist_label_offsets, + raft::device_span renumbered_edgelist_srcs, + raft::device_span renumbered_edgelist_dsts, + std::optional> renumbered_edgelist_weights, + std::optional> renumber_map, + std::optional> 
renumber_map_label_offsets, + size_t num_labels); + +template bool compare_edgelist( + raft::handle_t const& handle, + raft::device_span org_edgelist_srcs, + raft::device_span org_edgelist_dsts, + std::optional> org_edgelist_weights, + std::optional> org_edgelist_label_offsets, + raft::device_span renumbered_edgelist_srcs, + raft::device_span renumbered_edgelist_dsts, + std::optional> renumbered_edgelist_weights, + std::optional> renumber_map, + std::optional> renumber_map_label_offsets, + size_t num_labels); + +template bool compare_edgelist( + raft::handle_t const& handle, + raft::device_span org_edgelist_srcs, + raft::device_span org_edgelist_dsts, + std::optional> org_edgelist_weights, + std::optional> org_edgelist_label_offsets, + raft::device_span renumbered_edgelist_srcs, + raft::device_span renumbered_edgelist_dsts, + std::optional> renumbered_edgelist_weights, + std::optional> renumber_map, + std::optional> renumber_map_label_offsets, + size_t num_labels); + +// unrenumber the renumbered edge list and check whether the original & unrenumbered edge lists +// are identical +template +bool compare_heterogeneous_edgelist( + raft::handle_t const& handle, + raft::device_span org_edgelist_srcs, + raft::device_span org_edgelist_dsts, + std::optional> org_edgelist_weights, + std::optional> org_edgelist_edge_ids, + std::optional> org_edgelist_edge_types, + std::optional> org_edgelist_hops, + std::optional> org_edgelist_label_offsets, + raft::device_span renumbered_edgelist_srcs, + raft::device_span renumbered_edgelist_dsts, + std::optional> renumbered_edgelist_weights, + std::optional> renumbered_edgelist_edge_ids, + std::optional> renumbered_edgelist_label_edge_type_hop_offsets, + raft::device_span vertex_renumber_map, + raft::device_span vertex_renumber_map_label_type_offsets, + std::optional> edge_id_renumber_map, + std::optional> edge_id_renumber_map_label_type_offsets, + raft::device_span vertex_type_offsets, + size_t num_labels, + size_t num_vertex_types, + 
size_t num_edge_types, + size_t num_hops) +{ + if (org_edgelist_srcs.size() != renumbered_edgelist_srcs.size()) { return false; } + + for (size_t i = 0; i < num_labels; ++i) { + size_t label_start_offset{0}; + size_t label_end_offset = org_edgelist_srcs.size(); + if (org_edgelist_label_offsets) { + raft::update_host(&label_start_offset, + (*org_edgelist_label_offsets).data() + i, + size_t{1}, + handle.get_stream()); + raft::update_host(&label_end_offset, + (*org_edgelist_label_offsets).data() + i + 1, + size_t{1}, + handle.get_stream()); + handle.sync_stream(); + } + + if (label_start_offset == label_end_offset) { continue; } + + if (renumbered_edgelist_label_edge_type_hop_offsets) { + size_t renumbered_label_start_offset{0}; + size_t renumbered_label_end_offset{0}; + raft::update_host( + &renumbered_label_start_offset, + (*renumbered_edgelist_label_edge_type_hop_offsets).data() + i * num_edge_types * num_hops, + size_t{1}, + handle.get_stream()); + raft::update_host(&renumbered_label_end_offset, + (*renumbered_edgelist_label_edge_type_hop_offsets).data() + + (i + 1) * num_edge_types * num_hops, + size_t{1}, + handle.get_stream()); + handle.sync_stream(); + if (renumbered_label_start_offset != label_start_offset) { return false; } + if (renumbered_label_end_offset != label_end_offset) { return false; } + } + + // sort org edgelist by ((edge_type), (hop), src, dst, (weight), (edge ID)) + + rmm::device_uvector this_label_org_sorted_indices(label_end_offset - label_start_offset, + handle.get_stream()); + thrust::sequence(handle.get_thrust_policy(), + this_label_org_sorted_indices.begin(), + this_label_org_sorted_indices.end(), + size_t{0}); + + thrust::sort( + handle.get_thrust_policy(), + this_label_org_sorted_indices.begin(), + this_label_org_sorted_indices.end(), + [edge_types = org_edgelist_edge_types + ? 
thrust::make_optional>( + (*org_edgelist_edge_types).data() + label_start_offset, + label_end_offset - label_start_offset) + : thrust::nullopt, + hops = org_edgelist_hops ? thrust::make_optional>( + (*org_edgelist_hops).data() + label_start_offset, + label_end_offset - label_start_offset) + : thrust::nullopt, + srcs = raft::device_span(org_edgelist_srcs.data() + label_start_offset, + label_end_offset - label_start_offset), + dsts = raft::device_span(org_edgelist_dsts.data() + label_start_offset, + label_end_offset - label_start_offset), + weights = org_edgelist_weights ? thrust::make_optional>( + (*org_edgelist_weights).data() + label_start_offset, + label_end_offset - label_start_offset) + : thrust::nullopt, + edge_ids = org_edgelist_edge_ids ? thrust::make_optional>( + (*org_edgelist_edge_ids).data() + label_start_offset, + label_end_offset - label_start_offset) + : thrust::nullopt] __device__(size_t l_idx, size_t r_idx) { + edge_type_t l_edge_type{0}; + edge_type_t r_edge_type{0}; + if (edge_types) { + l_edge_type = (*edge_types)[l_idx]; + r_edge_type = (*edge_types)[r_idx]; + } + + int32_t l_hop{0}; + int32_t r_hop{0}; + if (hops) { + l_hop = (*hops)[l_idx]; + r_hop = (*hops)[r_idx]; + } + + vertex_t l_src = srcs[l_idx]; + vertex_t r_src = srcs[r_idx]; + + vertex_t l_dst = dsts[l_idx]; + vertex_t r_dst = dsts[r_idx]; + + weight_t l_weight{0.0}; + weight_t r_weight{0.0}; + if (weights) { + l_weight = (*weights)[l_idx]; + r_weight = (*weights)[r_idx]; + } + + edge_id_t l_edge_id{0}; + edge_id_t r_edge_id{0}; + if (edge_ids) { + l_edge_id = (*edge_ids)[l_idx]; + r_edge_id = (*edge_ids)[r_idx]; + } + + return thrust::make_tuple(l_edge_type, l_hop, l_src, l_dst, l_weight, l_edge_id) < + thrust::make_tuple(r_edge_type, r_hop, r_src, r_dst, r_weight, r_edge_id); + }); + + for (size_t j = 0; j < num_edge_types; ++j) { + auto edge_type_start_offset = label_start_offset; + auto edge_type_end_offset = label_end_offset; + if 
(renumbered_edgelist_label_edge_type_hop_offsets) { + raft::update_host(&edge_type_start_offset, + (*renumbered_edgelist_label_edge_type_hop_offsets).data() + + i * num_edge_types * num_hops + j * num_hops, + size_t{1}, + handle.get_stream()); + raft::update_host(&edge_type_end_offset, + (*renumbered_edgelist_label_edge_type_hop_offsets).data() + + i * num_edge_types * num_hops + (j + 1) * num_hops, + size_t{1}, + handle.get_stream()); + handle.sync_stream(); + } + + if (edge_type_start_offset == edge_type_end_offset) { continue; } + + if (org_edgelist_edge_types) { + if (static_cast(thrust::count_if( + handle.get_thrust_policy(), + this_label_org_sorted_indices.begin() + (edge_type_start_offset - label_start_offset), + this_label_org_sorted_indices.begin() + (edge_type_end_offset - label_start_offset), + [edge_types = raft::device_span( + (*org_edgelist_edge_types).data() + label_start_offset, + label_end_offset - label_start_offset), + edge_type = static_cast(j)] __device__(auto i) { + return edge_types[i] == edge_type; + })) != edge_type_end_offset - edge_type_start_offset) { + return false; + } + } + + if (org_edgelist_hops) { + for (size_t k = 0; k < num_hops; ++k) { + auto hop_start_offset = edge_type_start_offset; + auto hop_end_offset = edge_type_end_offset; + if (renumbered_edgelist_label_edge_type_hop_offsets) { + raft::update_host(&hop_start_offset, + (*renumbered_edgelist_label_edge_type_hop_offsets).data() + + i * num_edge_types * num_hops + j * num_hops + k, + size_t{1}, + handle.get_stream()); + raft::update_host(&hop_end_offset, + (*renumbered_edgelist_label_edge_type_hop_offsets).data() + + i * num_edge_types * num_hops + j * num_hops + k + 1, + size_t{1}, + handle.get_stream()); + handle.sync_stream(); + } + + if (hop_start_offset == hop_end_offset) { continue; } + + if (static_cast(thrust::count_if( + handle.get_thrust_policy(), + this_label_org_sorted_indices.begin() + (hop_start_offset - label_start_offset), + 
this_label_org_sorted_indices.begin() + (hop_end_offset - label_start_offset), + [hops = raft::device_span( + (*org_edgelist_hops).data() + label_start_offset, + label_end_offset - label_start_offset), + hop = static_cast(k)] __device__(auto i) { return hops[i] == hop; })) != + hop_end_offset - hop_start_offset) { + return false; + } + } + } + + // unrenumber source vertex IDs + + rmm::device_uvector this_edge_type_unrenumbered_edgelist_srcs( + edge_type_end_offset - edge_type_start_offset, handle.get_stream()); + thrust::copy(handle.get_thrust_policy(), + renumbered_edgelist_srcs.begin() + edge_type_start_offset, + renumbered_edgelist_srcs.begin() + edge_type_end_offset, + this_edge_type_unrenumbered_edgelist_srcs.begin()); + { + vertex_t org_src{}; + raft::update_host(&org_src, + org_edgelist_srcs.data() + label_start_offset + + this_label_org_sorted_indices.element( + edge_type_start_offset - label_start_offset, handle.get_stream()), + size_t{1}, + handle.get_stream()); + handle.sync_stream(); + auto vertex_type = thrust::distance(vertex_type_offsets.begin() + 1, + thrust::upper_bound(handle.get_thrust_policy(), + vertex_type_offsets.begin() + 1, + vertex_type_offsets.end(), + org_src)); + size_t renumber_map_label_start_offset{}; + size_t renumber_map_label_end_offset{}; + raft::update_host( + &renumber_map_label_start_offset, + vertex_renumber_map_label_type_offsets.data() + i * num_vertex_types + vertex_type, + size_t{1}, + handle.get_stream()); + raft::update_host( + &renumber_map_label_end_offset, + vertex_renumber_map_label_type_offsets.data() + i * num_vertex_types + vertex_type + 1, + size_t{1}, + handle.get_stream()); + handle.sync_stream(); + auto renumber_map = raft::device_span( + vertex_renumber_map.data() + renumber_map_label_start_offset, + renumber_map_label_end_offset - renumber_map_label_start_offset); + cugraph::unrenumber_int_vertices( + handle, + this_edge_type_unrenumbered_edgelist_srcs.data(), + edge_type_end_offset - 
edge_type_start_offset, + renumber_map.data(), + std::vector{static_cast(renumber_map.size())}); + } + + // unrenumber destination vertex IDs + + rmm::device_uvector this_edge_type_unrenumbered_edgelist_dsts( + edge_type_end_offset - edge_type_start_offset, handle.get_stream()); + thrust::copy(handle.get_thrust_policy(), + renumbered_edgelist_dsts.begin() + edge_type_start_offset, + renumbered_edgelist_dsts.begin() + edge_type_end_offset, + this_edge_type_unrenumbered_edgelist_dsts.begin()); + { + vertex_t org_dst{}; + raft::update_host(&org_dst, + org_edgelist_dsts.data() + label_start_offset + + this_label_org_sorted_indices.element( + edge_type_start_offset - label_start_offset, handle.get_stream()), + size_t{1}, + handle.get_stream()); + handle.sync_stream(); + auto vertex_type = thrust::distance(vertex_type_offsets.begin() + 1, + thrust::upper_bound(handle.get_thrust_policy(), + vertex_type_offsets.begin() + 1, + vertex_type_offsets.end(), + org_dst)); + size_t renumber_map_label_start_offset{0}; + size_t renumber_map_label_end_offset{}; + raft::update_host( + &renumber_map_label_start_offset, + vertex_renumber_map_label_type_offsets.data() + i * num_vertex_types + vertex_type, + size_t{1}, + handle.get_stream()); + raft::update_host( + &renumber_map_label_end_offset, + vertex_renumber_map_label_type_offsets.data() + i * num_vertex_types + vertex_type + 1, + size_t{1}, + handle.get_stream()); + handle.sync_stream(); + auto renumber_map = raft::device_span( + vertex_renumber_map.data() + renumber_map_label_start_offset, + renumber_map_label_end_offset - renumber_map_label_start_offset); + cugraph::unrenumber_int_vertices( + handle, + this_edge_type_unrenumbered_edgelist_dsts.data(), + edge_type_end_offset - edge_type_start_offset, + renumber_map.data(), + std::vector{static_cast(renumber_map.size())}); + } + + // unrenumber edge IDs + + std::optional> unrenumbered_edgelist_edge_ids{std::nullopt}; + if (renumbered_edgelist_edge_ids) { + 
unrenumbered_edgelist_edge_ids = rmm::device_uvector( + edge_type_end_offset - edge_type_start_offset, handle.get_stream()); + size_t renumber_map_type_start_offset{0}; + size_t renumber_map_type_end_offset = (*edge_id_renumber_map).size(); + if (edge_id_renumber_map_label_type_offsets) { + raft::update_host(&renumber_map_type_start_offset, + (*edge_id_renumber_map_label_type_offsets).data() + i * num_edge_types + + static_cast(j), + size_t{1}, + handle.get_stream()); + raft::update_host(&renumber_map_type_end_offset, + (*edge_id_renumber_map_label_type_offsets).data() + i * num_edge_types + + static_cast(j) + 1, + size_t{1}, + handle.get_stream()); + handle.sync_stream(); + } + auto renumber_map = raft::device_span( + (*edge_id_renumber_map).data() + renumber_map_type_start_offset, + renumber_map_type_end_offset - renumber_map_type_start_offset); + thrust::gather(handle.get_thrust_policy(), + (*renumbered_edgelist_edge_ids).begin() + edge_type_start_offset, + (*renumbered_edgelist_edge_ids).begin() + edge_type_end_offset, + renumber_map.begin(), + (*unrenumbered_edgelist_edge_ids).begin()); + } + + // sort sorted & renumbered edgelist by (src, dst, (weight), (edge ID)) + + rmm::device_uvector this_edge_type_unrenumbered_sorted_indices( + edge_type_end_offset - edge_type_start_offset, handle.get_stream()); + thrust::sequence(handle.get_thrust_policy(), + this_edge_type_unrenumbered_sorted_indices.begin(), + this_edge_type_unrenumbered_sorted_indices.end(), + size_t{0}); + + for (size_t k = 0; k < num_hops; ++k) { + auto hop_start_offset = edge_type_start_offset; + auto hop_end_offset = edge_type_end_offset; + if (renumbered_edgelist_label_edge_type_hop_offsets) { + raft::update_host(&hop_start_offset, + (*renumbered_edgelist_label_edge_type_hop_offsets).data() + + i * num_edge_types * num_hops + j * num_hops + k, + size_t{1}, + handle.get_stream()); + raft::update_host(&hop_end_offset, + (*renumbered_edgelist_label_edge_type_hop_offsets).data() + + i * 
num_edge_types * num_hops + j * num_hops + k + 1, + size_t{1}, + handle.get_stream()); + handle.sync_stream(); + } + + if (hop_start_offset == hop_end_offset) { continue; } + + thrust::sort( + handle.get_thrust_policy(), + this_edge_type_unrenumbered_sorted_indices.begin() + + (hop_start_offset - edge_type_start_offset), + this_edge_type_unrenumbered_sorted_indices.begin() + + (hop_end_offset - edge_type_start_offset), + [srcs = + raft::device_span(this_edge_type_unrenumbered_edgelist_srcs.data(), + this_edge_type_unrenumbered_edgelist_srcs.size()), + dsts = + raft::device_span(this_edge_type_unrenumbered_edgelist_dsts.data(), + this_edge_type_unrenumbered_edgelist_dsts.size()), + weights = renumbered_edgelist_weights + ? thrust::make_optional>( + (*renumbered_edgelist_weights).data() + edge_type_start_offset, + edge_type_end_offset - edge_type_start_offset) + : thrust::nullopt, + edge_ids = renumbered_edgelist_edge_ids + ? thrust::make_optional>( + (*renumbered_edgelist_edge_ids).data() + edge_type_start_offset, + edge_type_end_offset - edge_type_start_offset) + : thrust::nullopt] __device__(size_t l_idx, size_t r_idx) { + vertex_t l_src = srcs[l_idx]; + vertex_t r_src = srcs[r_idx]; + + vertex_t l_dst = dsts[l_idx]; + vertex_t r_dst = dsts[r_idx]; + + weight_t l_weight{0.0}; + weight_t r_weight{0.0}; + if (weights) { + l_weight = (*weights)[l_idx]; + r_weight = (*weights)[r_idx]; + } + + edge_id_t l_edge_id{0}; + edge_id_t r_edge_id{0}; + if (edge_ids) { + l_edge_id = (*edge_ids)[l_idx]; + r_edge_id = (*edge_ids)[r_idx]; + } + + return thrust::make_tuple(l_src, l_dst, l_weight, l_edge_id) < + thrust::make_tuple(r_src, r_dst, r_weight, r_edge_id); + }); + } + + // compare + + if (!thrust::equal( + handle.get_thrust_policy(), + this_label_org_sorted_indices.begin() + (edge_type_start_offset - label_start_offset), + this_label_org_sorted_indices.begin() + (edge_type_end_offset - label_start_offset), + this_edge_type_unrenumbered_sorted_indices.begin(), + [org_srcs = 
+ raft::device_span(org_edgelist_srcs.data() + label_start_offset, + label_end_offset - label_start_offset), + org_dsts = + raft::device_span(org_edgelist_dsts.data() + label_start_offset, + label_end_offset - label_start_offset), + org_weights = org_edgelist_weights + ? thrust::make_optional>( + (*org_edgelist_weights).data() + label_start_offset, + label_end_offset - label_start_offset) + : thrust::nullopt, + org_edge_ids = org_edgelist_edge_ids + ? thrust::make_optional>( + (*org_edgelist_edge_ids).data() + label_start_offset, + label_end_offset - label_start_offset) + : thrust::nullopt, + unrenumbered_srcs = + raft::device_span(this_edge_type_unrenumbered_edgelist_srcs.data(), + this_edge_type_unrenumbered_edgelist_srcs.size()), + unrenumbered_dsts = + raft::device_span(this_edge_type_unrenumbered_edgelist_dsts.data(), + this_edge_type_unrenumbered_edgelist_dsts.size()), + unrenumbered_weights = + renumbered_edgelist_weights + ? thrust::make_optional>( + (*renumbered_edgelist_weights).data() + edge_type_start_offset, + edge_type_end_offset - edge_type_start_offset) + : thrust::nullopt, + unrenumbered_edge_ids = + unrenumbered_edgelist_edge_ids + ? 
thrust::make_optional>( + (*unrenumbered_edgelist_edge_ids).data(), + (*unrenumbered_edgelist_edge_ids).size()) + : thrust:: + nullopt] __device__(size_t org_idx /* from label_start_offset */, + size_t + unrenumbered_idx /* from edge_type_start_offset */) { + auto org_src = org_srcs[org_idx]; + auto unrenumbered_src = unrenumbered_srcs[unrenumbered_idx]; + if (org_src != unrenumbered_src) { return false; } + + auto org_dst = org_dsts[org_idx]; + auto unrenumbered_dst = unrenumbered_dsts[unrenumbered_idx]; + if (org_dst != unrenumbered_dst) { return false; } + + weight_t org_weight{0.0}; + if (org_weights) { org_weight = (*org_weights)[org_idx]; } + weight_t unrenumbered_weight{0.0}; + if (unrenumbered_weights) { + unrenumbered_weight = (*unrenumbered_weights)[unrenumbered_idx]; + } + if (org_weight != unrenumbered_weight) { return false; } + + edge_id_t org_edge_id{0}; + if (org_edge_ids) { org_edge_id = (*org_edge_ids)[org_idx]; } + edge_id_t unrenumbered_edge_id{0}; + if (unrenumbered_edge_ids) { + unrenumbered_edge_id = (*unrenumbered_edge_ids)[unrenumbered_idx]; + } + + return org_edge_id == unrenumbered_edge_id; + })) { + return false; + } + } + } + + return true; +} + +template bool compare_heterogeneous_edgelist( + raft::handle_t const& handle, + raft::device_span org_edgelist_srcs, + raft::device_span org_edgelist_dsts, + std::optional> org_edgelist_weights, + std::optional> org_edgelist_edge_ids, + std::optional> org_edgelist_edge_types, + std::optional> org_edgelist_hops, + std::optional> org_edgelist_label_offsets, + raft::device_span renumbered_edgelist_srcs, + raft::device_span renumbered_edgelist_dsts, + std::optional> renumbered_edgelist_weights, + std::optional> renumbered_edgelist_edge_ids, + std::optional> renumbered_edgelist_label_edge_type_hop_offsets, + raft::device_span vertex_renumber_map, + raft::device_span vertex_renumber_map_label_type_offsets, + std::optional> edge_id_renumber_map, + std::optional> 
edge_id_renumber_map_label_type_offsets, + raft::device_span vertex_type_offsets, + size_t num_labels, + size_t num_vertex_types, + size_t num_edge_types, + size_t num_hops); + +template bool compare_heterogeneous_edgelist( + raft::handle_t const& handle, + raft::device_span org_edgelist_srcs, + raft::device_span org_edgelist_dsts, + std::optional> org_edgelist_weights, + std::optional> org_edgelist_edge_ids, + std::optional> org_edgelist_edge_types, + std::optional> org_edgelist_hops, + std::optional> org_edgelist_label_offsets, + raft::device_span renumbered_edgelist_srcs, + raft::device_span renumbered_edgelist_dsts, + std::optional> renumbered_edgelist_weights, + std::optional> renumbered_edgelist_edge_ids, + std::optional> renumbered_edgelist_label_edge_type_hop_offsets, + raft::device_span vertex_renumber_map, + raft::device_span vertex_renumber_map_label_type_offsets, + std::optional> edge_id_renumber_map, + std::optional> edge_id_renumber_map_label_type_offsets, + raft::device_span vertex_type_offsets, + size_t num_labels, + size_t num_vertex_types, + size_t num_edge_types, + size_t num_hops); + +template bool compare_heterogeneous_edgelist( + raft::handle_t const& handle, + raft::device_span org_edgelist_srcs, + raft::device_span org_edgelist_dsts, + std::optional> org_edgelist_weights, + std::optional> org_edgelist_edge_ids, + std::optional> org_edgelist_edge_types, + std::optional> org_edgelist_hops, + std::optional> org_edgelist_label_offsets, + raft::device_span renumbered_edgelist_srcs, + raft::device_span renumbered_edgelist_dsts, + std::optional> renumbered_edgelist_weights, + std::optional> renumbered_edgelist_edge_ids, + std::optional> renumbered_edgelist_label_edge_type_hop_offsets, + raft::device_span vertex_renumber_map, + raft::device_span vertex_renumber_map_label_type_offsets, + std::optional> edge_id_renumber_map, + std::optional> edge_id_renumber_map_label_type_offsets, + raft::device_span vertex_type_offsets, + size_t num_labels, + size_t 
num_vertex_types, + size_t num_edge_types, + size_t num_hops); + +template bool compare_heterogeneous_edgelist( + raft::handle_t const& handle, + raft::device_span org_edgelist_srcs, + raft::device_span org_edgelist_dsts, + std::optional> org_edgelist_weights, + std::optional> org_edgelist_edge_ids, + std::optional> org_edgelist_edge_types, + std::optional> org_edgelist_hops, + std::optional> org_edgelist_label_offsets, + raft::device_span renumbered_edgelist_srcs, + raft::device_span renumbered_edgelist_dsts, + std::optional> renumbered_edgelist_weights, + std::optional> renumbered_edgelist_edge_ids, + std::optional> renumbered_edgelist_label_edge_type_hop_offsets, + raft::device_span vertex_renumber_map, + raft::device_span vertex_renumber_map_label_type_offsets, + std::optional> edge_id_renumber_map, + std::optional> edge_id_renumber_map_label_type_offsets, + raft::device_span vertex_type_offsets, + size_t num_labels, + size_t num_vertex_types, + size_t num_edge_types, + size_t num_hops); + +template bool compare_heterogeneous_edgelist( + raft::handle_t const& handle, + raft::device_span org_edgelist_srcs, + raft::device_span org_edgelist_dsts, + std::optional> org_edgelist_weights, + std::optional> org_edgelist_edge_ids, + std::optional> org_edgelist_edge_types, + std::optional> org_edgelist_hops, + std::optional> org_edgelist_label_offsets, + raft::device_span renumbered_edgelist_srcs, + raft::device_span renumbered_edgelist_dsts, + std::optional> renumbered_edgelist_weights, + std::optional> renumbered_edgelist_edge_ids, + std::optional> renumbered_edgelist_label_edge_type_hop_offsets, + raft::device_span vertex_renumber_map, + raft::device_span vertex_renumber_map_label_type_offsets, + std::optional> edge_id_renumber_map, + std::optional> edge_id_renumber_map_label_type_offsets, + raft::device_span vertex_type_offsets, + size_t num_labels, + size_t num_vertex_types, + size_t num_edge_types, + size_t num_hops); + +template bool compare_heterogeneous_edgelist( 
+ raft::handle_t const& handle, + raft::device_span org_edgelist_srcs, + raft::device_span org_edgelist_dsts, + std::optional> org_edgelist_weights, + std::optional> org_edgelist_edge_ids, + std::optional> org_edgelist_edge_types, + std::optional> org_edgelist_hops, + std::optional> org_edgelist_label_offsets, + raft::device_span renumbered_edgelist_srcs, + raft::device_span renumbered_edgelist_dsts, + std::optional> renumbered_edgelist_weights, + std::optional> renumbered_edgelist_edge_ids, + std::optional> renumbered_edgelist_label_edge_type_hop_offsets, + raft::device_span vertex_renumber_map, + raft::device_span vertex_renumber_map_label_type_offsets, + std::optional> edge_id_renumber_map, + std::optional> edge_id_renumber_map_label_type_offsets, + raft::device_span vertex_type_offsets, + size_t num_labels, + size_t num_vertex_types, + size_t num_edge_types, + size_t num_hops);
+
+// Validate a vertex renumber map against the (original) edge list it was generated from.
+//
+// The check is performed independently for every label i in [0, num_labels) and, within a label,
+// for every vertex type j in [0, num_vertex_types). For each (i, j) the matching renumber_map
+// segment is examined; the position of an entry inside its segment is taken as the renumbered ID
+// and the stored value as the original vertex ID.
+//
+// handle                          Supplies the CUDA stream and the thrust execution policy.
+// starting_vertices               Optional; if set, these vertices are appended to the majors
+//                                 (with hop 0 when hops are provided).
+// starting_vertex_label_offsets   Optional per-label offsets into starting_vertices; if unset the
+//                                 whole starting_vertices range is used for every label.
+// org_edgelist_srcs/dsts          Original edge list end points.
+// org_edgelist_hops               Optional hop number of every edge.
+// org_edgelist_label_offsets      Optional per-label offsets into the edge list; if unset the whole
+//                                 edge list is used for every label.
+// renumber_map                    Renumber map to validate.
+// renumber_map_label_type_offsets Optional segment offsets, read at i * num_vertex_types + j (and
+//                                 + 1); if unset the whole renumber_map is used.
+// vertex_type_offsets             Optional; if set, the type of vertex v is the number of entries
+//                                 in vertex_type_offsets[1..] that are <= v, and only vertices of
+//                                 type j are compared against segment (i, j).
+// src_is_major                    Selects whether sources or destinations play the major role.
+//
+// Returns false as soon as a segment has the wrong number of entries or violates the ordering
+// invariant described at the top of the function body; returns true otherwise.
+template <typename vertex_t>
+bool check_vertex_renumber_map_invariants(
+  raft::handle_t const& handle,
+  std::optional<raft::device_span<vertex_t const>> starting_vertices,
+  std::optional<raft::device_span<size_t const>> starting_vertex_label_offsets,
+  raft::device_span<vertex_t const> org_edgelist_srcs,
+  raft::device_span<vertex_t const> org_edgelist_dsts,
+  std::optional<raft::device_span<int32_t const>> org_edgelist_hops,
+  std::optional<raft::device_span<size_t const>> org_edgelist_label_offsets,
+  raft::device_span<vertex_t const> renumber_map,
+  std::optional<raft::device_span<size_t const>> renumber_map_label_type_offsets,
+  std::optional<raft::device_span<vertex_t const>> vertex_type_offsets,
+  size_t num_labels,
+  size_t num_vertex_types,
+  bool src_is_major)
+{
+  // Check the invariants in renumber_map
+  // Say we found the minimum (primary key:hop, secondary key:flag) pairs for every unique vertex,
+  // where flag is 0 for majors and 1 for minors. Then, vertices with smaller (hop, flag)
+  // pairs should be renumbered to smaller numbers than vertices with larger (hop, flag) pairs.
+  auto org_edgelist_majors = src_is_major ? org_edgelist_srcs : org_edgelist_dsts;
+  auto org_edgelist_minors = src_is_major ? org_edgelist_dsts : org_edgelist_srcs;
+
+  for (size_t i = 0; i < num_labels; ++i) {
+    size_t label_start_offset{0};
+    auto label_end_offset = org_edgelist_majors.size();
+    if (org_edgelist_label_offsets) {
+      raft::update_host(&label_start_offset,
+                        (*org_edgelist_label_offsets).data() + i,
+                        size_t{1},
+                        handle.get_stream());
+      raft::update_host(&label_end_offset,
+                        (*org_edgelist_label_offsets).data() + i + 1,
+                        size_t{1},
+                        handle.get_stream());
+      handle.sync_stream();  // the two host values are read right below
+    }
+
+    if (label_start_offset == label_end_offset) { continue; }
+
+    // compute (unique major, min_hop) pairs
+
+    rmm::device_uvector<vertex_t> this_label_unique_majors(label_end_offset - label_start_offset,
+                                                           handle.get_stream());
+    thrust::copy(handle.get_thrust_policy(),
+                 org_edgelist_majors.begin() + label_start_offset,
+                 org_edgelist_majors.begin() + label_end_offset,
+                 this_label_unique_majors.begin());
+    if (starting_vertices) {
+      size_t starting_vertex_label_start_offset{0};
+      auto starting_vertex_label_end_offset = (*starting_vertices).size();
+      if (starting_vertex_label_offsets) {
+        raft::update_host(&starting_vertex_label_start_offset,
+                          (*starting_vertex_label_offsets).data() + i,
+                          size_t{1},
+                          handle.get_stream());
+        raft::update_host(&starting_vertex_label_end_offset,
+                          (*starting_vertex_label_offsets).data() + i + 1,
+                          size_t{1},
+                          handle.get_stream());
+        handle.sync_stream();
+      }
+
+      // append this label's starting vertices after the edge list majors
+      auto old_size = this_label_unique_majors.size();
+      this_label_unique_majors.resize(
+        old_size + starting_vertex_label_end_offset - starting_vertex_label_start_offset,
+        handle.get_stream());
+      thrust::copy(handle.get_thrust_policy(),
+                   (*starting_vertices).begin() + starting_vertex_label_start_offset,
+                   (*starting_vertices).begin() + starting_vertex_label_end_offset,
+                   this_label_unique_majors.begin() + old_size);
+    }
+
+    std::optional<rmm::device_uvector<int32_t>> this_label_unique_major_hops =
+      org_edgelist_hops ? std::make_optional<rmm::device_uvector<int32_t>>(
+                            label_end_offset - label_start_offset, handle.get_stream())
+                        : std::nullopt;
+    if (org_edgelist_hops) {
+      thrust::copy(handle.get_thrust_policy(),
+                   (*org_edgelist_hops).begin() + label_start_offset,
+                   (*org_edgelist_hops).begin() + label_end_offset,
+                   (*this_label_unique_major_hops).begin());
+      if (starting_vertices) {
+        // the appended starting vertices get hop 0
+        auto old_size = (*this_label_unique_major_hops).size();
+        (*this_label_unique_major_hops)
+          .resize(this_label_unique_majors.size(), handle.get_stream());
+        thrust::fill(handle.get_thrust_policy(),
+                     (*this_label_unique_major_hops).begin() + old_size,
+                     (*this_label_unique_major_hops).end(),
+                     int32_t{0});
+      }
+
+      // sort by (major, hop); unique_by_key then keeps the smallest hop of every major
+      auto pair_first = thrust::make_zip_iterator(this_label_unique_majors.begin(),
+                                                  (*this_label_unique_major_hops).begin());
+      thrust::sort(
+        handle.get_thrust_policy(), pair_first, pair_first + this_label_unique_majors.size());
+      this_label_unique_majors.resize(
+        thrust::distance(this_label_unique_majors.begin(),
+                         thrust::get<0>(thrust::unique_by_key(
+                           handle.get_thrust_policy(),
+                           this_label_unique_majors.begin(),
+                           this_label_unique_majors.end(),
+                           (*this_label_unique_major_hops).begin()))),
+        handle.get_stream());
+      (*this_label_unique_major_hops).resize(this_label_unique_majors.size(), handle.get_stream());
+    } else {
+      thrust::sort(handle.get_thrust_policy(),
+                   this_label_unique_majors.begin(),
+                   this_label_unique_majors.end());
+      this_label_unique_majors.resize(
+        thrust::distance(this_label_unique_majors.begin(),
+                         thrust::unique(handle.get_thrust_policy(),
+                                        this_label_unique_majors.begin(),
+                                        this_label_unique_majors.end())),
+        handle.get_stream());
+    }
+
+    // compute (unique minor, min_hop) pairs
+
+    rmm::device_uvector<vertex_t> this_label_unique_minors(label_end_offset - label_start_offset,
+                                                           handle.get_stream());
+    thrust::copy(handle.get_thrust_policy(),
+                 org_edgelist_minors.begin() + label_start_offset,
+                 org_edgelist_minors.begin() + label_end_offset,
+                 this_label_unique_minors.begin());
+    std::optional<rmm::device_uvector<int32_t>> this_label_unique_minor_hops =
+      org_edgelist_hops ? std::make_optional<rmm::device_uvector<int32_t>>(
+                            label_end_offset - label_start_offset, handle.get_stream())
+                        : std::nullopt;
+    if (org_edgelist_hops) {
+      thrust::copy(handle.get_thrust_policy(),
+                   (*org_edgelist_hops).begin() + label_start_offset,
+                   (*org_edgelist_hops).begin() + label_end_offset,
+                   (*this_label_unique_minor_hops).begin());
+
+      auto pair_first = thrust::make_zip_iterator(this_label_unique_minors.begin(),
+                                                  (*this_label_unique_minor_hops).begin());
+      thrust::sort(
+        handle.get_thrust_policy(), pair_first, pair_first + this_label_unique_minors.size());
+      this_label_unique_minors.resize(
+        thrust::distance(this_label_unique_minors.begin(),
+                         thrust::get<0>(thrust::unique_by_key(
+                           handle.get_thrust_policy(),
+                           this_label_unique_minors.begin(),
+                           this_label_unique_minors.end(),
+                           (*this_label_unique_minor_hops).begin()))),
+        handle.get_stream());
+      (*this_label_unique_minor_hops).resize(this_label_unique_minors.size(), handle.get_stream());
+    } else {
+      thrust::sort(handle.get_thrust_policy(),
+                   this_label_unique_minors.begin(),
+                   this_label_unique_minors.end());
+      this_label_unique_minors.resize(
+        thrust::distance(this_label_unique_minors.begin(),
+                         thrust::unique(handle.get_thrust_policy(),
+                                        this_label_unique_minors.begin(),
+                                        this_label_unique_minors.end())),
+        handle.get_stream());
+    }
+
+    for (size_t j = 0; j < num_vertex_types; ++j) {
+      size_t renumber_map_type_start_offset{0};
+      auto renumber_map_type_end_offset = renumber_map.size();
+      if (renumber_map_label_type_offsets) {
+        raft::update_host(&renumber_map_type_start_offset,
+                          (*renumber_map_label_type_offsets).data() + i * num_vertex_types + j,
+                          size_t{1},
+                          handle.get_stream());
+        raft::update_host(&renumber_map_type_end_offset,
+                          (*renumber_map_label_type_offsets).data() + i * num_vertex_types + j + 1,
+                          size_t{1},
+                          handle.get_stream());
+        handle.sync_stream();
+      }
+
+      // build a sorted (original vertex ID -> renumbered vertex ID) lookup table for this segment;
+      // the renumbered ID of an entry is its position inside the segment
+      rmm::device_uvector<vertex_t> this_type_sorted_org_vertices(
+        renumber_map_type_end_offset - renumber_map_type_start_offset, handle.get_stream());
+      rmm::device_uvector<vertex_t> this_type_matching_renumbered_vertices(
+        this_type_sorted_org_vertices.size(), handle.get_stream());
+      thrust::copy(handle.get_thrust_policy(),
+                   renumber_map.begin() + renumber_map_type_start_offset,
+                   renumber_map.begin() + renumber_map_type_end_offset,
+                   this_type_sorted_org_vertices.begin());
+      thrust::sequence(handle.get_thrust_policy(),
+                       this_type_matching_renumbered_vertices.begin(),
+                       this_type_matching_renumbered_vertices.end(),
+                       vertex_t{0});
+      thrust::sort_by_key(handle.get_thrust_policy(),
+                          this_type_sorted_org_vertices.begin(),
+                          this_type_sorted_org_vertices.end(),
+                          this_type_matching_renumbered_vertices.begin());
+
+      // maps an original vertex ID to its renumbered ID using the lookup table above
+      auto to_renumbered_vertex = cuda::proclaim_return_type<vertex_t>(
+        [sorted_org_vertices = raft::device_span<vertex_t const>(
+           this_type_sorted_org_vertices.data(), this_type_sorted_org_vertices.size()),
+         matching_renumbered_vertices = raft::device_span<vertex_t const>(
+           this_type_matching_renumbered_vertices.data(),
+           this_type_matching_renumbered_vertices.size())] __device__(vertex_t v) -> vertex_t {
+          auto it = thrust::lower_bound(
+            thrust::seq, sorted_org_vertices.begin(), sorted_org_vertices.end(), v);
+          return matching_renumbered_vertices[thrust::distance(sorted_org_vertices.begin(), it)];
+        });
+
+      rmm::device_uvector<vertex_t> this_type_unique_majors(this_label_unique_majors.size(),
+                                                            handle.get_stream());
+      auto this_type_unique_major_hops =
+        this_label_unique_major_hops
+          ? std::make_optional<rmm::device_uvector<int32_t>>((*this_label_unique_major_hops).size(),
+                                                             handle.get_stream())
+          : std::nullopt;
+      rmm::device_uvector<vertex_t> this_type_unique_minors(this_label_unique_minors.size(),
+                                                            handle.get_stream());
+      auto this_type_unique_minor_hops =
+        this_label_unique_minor_hops
+          ? std::make_optional<rmm::device_uvector<int32_t>>((*this_label_unique_minor_hops).size(),
+                                                             handle.get_stream())
+          : std::nullopt;
+
+      if (org_edgelist_hops) {
+        if (vertex_type_offsets) {
+          // keep only the (vertex, hop) pairs whose vertex has type j
+          auto input_pair_first = thrust::make_zip_iterator(
+            this_label_unique_majors.begin(), (*this_label_unique_major_hops).begin());
+          auto output_pair_first = thrust::make_zip_iterator(
+            this_type_unique_majors.begin(), (*this_type_unique_major_hops).begin());
+          this_type_unique_majors.resize(
+            thrust::distance(
+              output_pair_first,
+              thrust::copy_if(handle.get_thrust_policy(),
+                              input_pair_first,
+                              input_pair_first + this_label_unique_majors.size(),
+                              output_pair_first,
+                              [vertex_type_offsets = *vertex_type_offsets,
+                               vertex_type         = j] __device__(auto pair) {
+                                return static_cast<size_t>(thrust::distance(
+                                         vertex_type_offsets.begin() + 1,
+                                         thrust::upper_bound(thrust::seq,
+                                                             vertex_type_offsets.begin() + 1,
+                                                             vertex_type_offsets.end(),
+                                                             thrust::get<0>(pair)))) == vertex_type;
+                              })),
+            handle.get_stream());
+          (*this_type_unique_major_hops)
+            .resize(this_type_unique_majors.size(), handle.get_stream());
+
+          input_pair_first  = thrust::make_zip_iterator(this_label_unique_minors.begin(),
+                                                       (*this_label_unique_minor_hops).begin());
+          output_pair_first = thrust::make_zip_iterator(this_type_unique_minors.begin(),
+                                                        (*this_type_unique_minor_hops).begin());
+          this_type_unique_minors.resize(
+            thrust::distance(
+              output_pair_first,
+              thrust::copy_if(handle.get_thrust_policy(),
+                              input_pair_first,
+                              input_pair_first + this_label_unique_minors.size(),
+                              output_pair_first,
+                              [vertex_type_offsets = *vertex_type_offsets,
+                               vertex_type         = j] __device__(auto pair) {
+                                return static_cast<size_t>(thrust::distance(
+                                         vertex_type_offsets.begin() + 1,
+                                         thrust::upper_bound(thrust::seq,
+                                                             vertex_type_offsets.begin() + 1,
+                                                             vertex_type_offsets.end(),
+                                                             thrust::get<0>(pair)))) == vertex_type;
+                              })),
+            handle.get_stream());
+          (*this_type_unique_minor_hops)
+            .resize(this_type_unique_minors.size(), handle.get_stream());
+        } else {
+          auto input_pair_first = thrust::make_zip_iterator(
+            this_label_unique_majors.begin(), (*this_label_unique_major_hops).begin());
+          thrust::copy(handle.get_thrust_policy(),
+                       input_pair_first,
+                       input_pair_first + this_label_unique_majors.size(),
+                       thrust::make_zip_iterator(this_type_unique_majors.begin(),
+                                                 (*this_type_unique_major_hops).begin()));
+          input_pair_first = thrust::make_zip_iterator(this_label_unique_minors.begin(),
+                                                       (*this_label_unique_minor_hops).begin());
+          thrust::copy(handle.get_thrust_policy(),
+                       input_pair_first,
+                       input_pair_first + this_label_unique_minors.size(),
+                       thrust::make_zip_iterator(this_type_unique_minors.begin(),
+                                                 (*this_type_unique_minor_hops).begin()));
+        }
+
+        // NOTE(review): unlike the no-hop path below, this skips the renumber map size check when
+        // no vertex of this type exists; confirm that a non-empty segment is acceptable here.
+        if (this_type_unique_majors.size() + this_type_unique_minors.size() == 0) { continue; }
+
+        rmm::device_uvector<vertex_t> merged_vertices(
+          this_type_unique_majors.size() + this_type_unique_minors.size(), handle.get_stream());
+        rmm::device_uvector<int32_t> merged_hops(merged_vertices.size(), handle.get_stream());
+        rmm::device_uvector<int8_t> merged_flags(merged_vertices.size(), handle.get_stream());
+
+        // merge (vertex, hop, flag) triplets; flag is 0 for majors and 1 for minors, so for a
+        // vertex present in both lists the smallest (hop, flag) comes first and is the one kept by
+        // unique_by_key
+        auto major_triplet_first =
+          thrust::make_zip_iterator(this_type_unique_majors.begin(),
+                                    (*this_type_unique_major_hops).begin(),
+                                    thrust::make_constant_iterator(int8_t{0}));
+        auto minor_triplet_first =
+          thrust::make_zip_iterator(this_type_unique_minors.begin(),
+                                    (*this_type_unique_minor_hops).begin(),
+                                    thrust::make_constant_iterator(int8_t{1}));
+        thrust::merge(handle.get_thrust_policy(),
+                      major_triplet_first,
+                      major_triplet_first + this_type_unique_majors.size(),
+                      minor_triplet_first,
+                      minor_triplet_first + this_type_unique_minors.size(),
+                      thrust::make_zip_iterator(
+                        merged_vertices.begin(), merged_hops.begin(), merged_flags.begin()));
+        merged_vertices.resize(
+          thrust::distance(
+            merged_vertices.begin(),
+            thrust::get<0>(thrust::unique_by_key(
+              handle.get_thrust_policy(),
+              merged_vertices.begin(),
+              merged_vertices.end(),
+              thrust::make_zip_iterator(merged_hops.begin(), merged_flags.begin())))),
+          handle.get_stream());
+        merged_hops.resize(merged_vertices.size(), handle.get_stream());
+        merged_flags.resize(merged_vertices.size(), handle.get_stream());
+
+        if ((renumber_map_type_end_offset - renumber_map_type_start_offset) !=
+            merged_vertices.size()) {  // renumber map size == # unique vertices
+          return false;
+        }
+
+        // group the vertices by their (hop, flag) key
+        auto sort_key_first = thrust::make_zip_iterator(merged_hops.begin(), merged_flags.begin());
+        thrust::sort_by_key(handle.get_thrust_policy(),
+                            sort_key_first,
+                            sort_key_first + merged_hops.size(),
+                            merged_vertices.begin());
+
+        auto num_unique_keys = thrust::count_if(
+          handle.get_thrust_policy(),
+          thrust::make_counting_iterator(size_t{0}),
+          thrust::make_counting_iterator(merged_hops.size()),
+          cugraph::detail::is_first_in_run_t<decltype(sort_key_first)>{sort_key_first});
+        rmm::device_uvector<vertex_t> min_vertices(num_unique_keys, handle.get_stream());
+        rmm::device_uvector<vertex_t> max_vertices(num_unique_keys, handle.get_stream());
+
+        auto renumbered_merged_vertex_first =
+          thrust::make_transform_iterator(merged_vertices.begin(), to_renumbered_vertex);
+
+        // smallest and largest renumbered ID inside every (hop, flag) group
+        thrust::reduce_by_key(handle.get_thrust_policy(),
+                              sort_key_first,
+                              sort_key_first + merged_hops.size(),
+                              renumbered_merged_vertex_first,
+                              thrust::make_discard_iterator(),
+                              min_vertices.begin(),
+                              thrust::equal_to<thrust::tuple<int32_t, int8_t>>{},
+                              thrust::minimum<vertex_t>{});
+        thrust::reduce_by_key(handle.get_thrust_policy(),
+                              sort_key_first,
+                              sort_key_first + merged_hops.size(),
+                              renumbered_merged_vertex_first,
+                              thrust::make_discard_iterator(),
+                              max_vertices.begin(),
+                              thrust::equal_to<thrust::tuple<int32_t, int8_t>>{},
+                              thrust::maximum<vertex_t>{});
+
+        // every group must start above the largest renumbered ID of the preceding group
+        auto num_violations =
+          thrust::count_if(handle.get_thrust_policy(),
+                           thrust::make_counting_iterator(size_t{1}),
+                           thrust::make_counting_iterator(min_vertices.size()),
+                           [min_vertices = raft::device_span<vertex_t const>(min_vertices.data(),
+                                                                             min_vertices.size()),
+                            max_vertices = raft::device_span<vertex_t const>(
+                              max_vertices.data(), max_vertices.size())] __device__(size_t i) {
+                             return min_vertices[i] <= max_vertices[i - 1];
+                           });
+
+        if (num_violations != 0) { return false; }
+      } else {
+        // no hop information: none of the optional hop buffers is engaged on this path, so they
+        // must not be dereferenced here
+        if (vertex_type_offsets) {
+          this_type_unique_majors.resize(
+            thrust::distance(
+              this_type_unique_majors.begin(),
+              thrust::copy_if(
+                handle.get_thrust_policy(),
+                this_label_unique_majors.begin(),
+                this_label_unique_majors.end(),
+                this_type_unique_majors.begin(),
+                [vertex_type_offsets = *vertex_type_offsets, vertex_type = j] __device__(auto v) {
+                  return static_cast<size_t>(
+                           thrust::distance(vertex_type_offsets.begin() + 1,
+                                            thrust::upper_bound(thrust::seq,
+                                                                vertex_type_offsets.begin() + 1,
+                                                                vertex_type_offsets.end(),
+                                                                v))) == vertex_type;
+                })),
+            handle.get_stream());
+
+          this_type_unique_minors.resize(
+            thrust::distance(
+              this_type_unique_minors.begin(),
+              thrust::copy_if(
+                handle.get_thrust_policy(),
+                this_label_unique_minors.begin(),
+                this_label_unique_minors.end(),
+                this_type_unique_minors.begin(),
+                [vertex_type_offsets = *vertex_type_offsets, vertex_type = j] __device__(auto v) {
+                  return static_cast<size_t>(
+                           thrust::distance(vertex_type_offsets.begin() + 1,
+                                            thrust::upper_bound(thrust::seq,
+                                                                vertex_type_offsets.begin() + 1,
+                                                                vertex_type_offsets.end(),
+                                                                v))) == vertex_type;
+                })),
+            handle.get_stream());
+        } else {
+          thrust::copy(handle.get_thrust_policy(),
+                       this_label_unique_majors.begin(),
+                       this_label_unique_majors.end(),
+                       this_type_unique_majors.begin());
+          thrust::copy(handle.get_thrust_policy(),
+                       this_label_unique_minors.begin(),
+                       this_label_unique_minors.end(),
+                       this_type_unique_minors.begin());
+        }
+
+        // drop the minors that also appear as majors so every vertex is counted once
+        this_type_unique_minors.resize(
+          thrust::distance(
+            this_type_unique_minors.begin(),
+            thrust::remove_if(handle.get_thrust_policy(),
+                              this_type_unique_minors.begin(),
+                              this_type_unique_minors.end(),
+                              [sorted_unique_majors = raft::device_span<vertex_t const>(
+                                 this_type_unique_majors.data(),
+                                 this_type_unique_majors.size())] __device__(auto minor) {
+                                return thrust::binary_search(thrust::seq,
+                                                             sorted_unique_majors.begin(),
+                                                             sorted_unique_majors.end(),
+                                                             minor);
+                              })),
+          handle.get_stream());
+
+        if ((renumber_map_type_end_offset - renumber_map_type_start_offset) !=
+            (this_type_unique_majors.size() +
+             this_type_unique_minors.size())) {  // renumber map size == # unique vertices
+          return false;
+        }
+
+        // every major must be renumbered below every (minor-only) vertex
+        auto max_major_renumbered_vertex =
+          thrust::transform_reduce(handle.get_thrust_policy(),
+                                   this_type_unique_majors.begin(),
+                                   this_type_unique_majors.end(),
+                                   to_renumbered_vertex,
+                                   std::numeric_limits<vertex_t>::lowest(),
+                                   thrust::maximum<vertex_t>{});
+
+        auto min_minor_renumbered_vertex =
+          thrust::transform_reduce(handle.get_thrust_policy(),
+                                   this_type_unique_minors.begin(),
+                                   this_type_unique_minors.end(),
+                                   to_renumbered_vertex,
+                                   std::numeric_limits<vertex_t>::max(),
+                                   thrust::minimum<vertex_t>{});
+
+        if (max_major_renumbered_vertex >= min_minor_renumbered_vertex) { return false; }
+      }
+    }
+  }
+
+  return true;
+}
+
+template bool check_vertex_renumber_map_invariants( + raft::handle_t const& handle, + std::optional> starting_vertices, + std::optional> starting_vertex_label_offsets, + raft::device_span org_edgelist_srcs, + raft::device_span org_edgelist_dsts, + std::optional> org_edgelist_hops, + std::optional> org_edgelist_label_offsets, + raft::device_span renumber_map, + std::optional> renumber_map_label_type_offsets, + std::optional> vertex_type_offsets, + size_t num_labels, + size_t num_vertex_types, + bool src_is_major); + +template bool check_vertex_renumber_map_invariants( + raft::handle_t const& handle, + std::optional> starting_vertices, + std::optional> starting_vertex_label_offsets, + raft::device_span org_edgelist_srcs, + raft::device_span org_edgelist_dsts, + std::optional> org_edgelist_hops, + std::optional> org_edgelist_label_offsets, + raft::device_span renumber_map, + std::optional> renumber_map_label_type_offsets, +
std::optional> vertex_type_offsets, + size_t num_labels, + size_t num_vertex_types, + bool src_is_major); + +template +bool check_edge_id_renumber_map_invariants( + raft::handle_t const& handle, + raft::device_span org_edgelist_edge_ids, + std::optional> org_edgelist_edge_types, + std::optional> org_edgelist_hops, + std::optional> org_edgelist_label_offsets, + raft::device_span renumber_map, + std::optional> renumber_map_label_type_offsets, + size_t num_labels, + size_t num_edge_types) +{ + // Check the invariants in the edge ID renumber_map + // For every (label, edge type) pair, the renumber map size should match the # unique edge IDs. + // If hops are given, say we found the minimum hop for every unique edge ID. Then, edge IDs + // with smaller hops should be renumbered to smaller numbers than edge IDs with larger hops. + + for (size_t i = 0; i < num_labels; ++i) { + size_t label_start_offset{0}; + auto label_end_offset = org_edgelist_edge_ids.size(); + if (org_edgelist_label_offsets) { + raft::update_host(&label_start_offset, + (*org_edgelist_label_offsets).data() + i, + size_t{1}, + handle.get_stream()); + raft::update_host(&label_end_offset, + (*org_edgelist_label_offsets).data() + i + 1, + size_t{1}, + handle.get_stream()); + handle.sync_stream(); + } + + if (label_start_offset == label_end_offset) { continue; } + + // compute unique key (edge type, edge ID), value (min. hop) pairs + + std::optional> this_label_unique_key_edge_types = + org_edgelist_edge_types ? 
std::make_optional>( + label_end_offset - label_start_offset, handle.get_stream()) + : std::nullopt; + if (org_edgelist_edge_types) { + thrust::copy(handle.get_thrust_policy(), + (*org_edgelist_edge_types).begin() + label_start_offset, + (*org_edgelist_edge_types).begin() + label_end_offset, + (*this_label_unique_key_edge_types).begin()); + } + + rmm::device_uvector this_label_unique_key_edge_ids( + label_end_offset - label_start_offset, handle.get_stream()); + thrust::copy(handle.get_thrust_policy(), + org_edgelist_edge_ids.begin() + label_start_offset, + org_edgelist_edge_ids.begin() + label_end_offset, + this_label_unique_key_edge_ids.begin()); + + std::optional> this_label_unique_key_hops = + org_edgelist_hops ? std::make_optional>( + label_end_offset - label_start_offset, handle.get_stream()) + : std::nullopt; + if (org_edgelist_hops) { + thrust::copy(handle.get_thrust_policy(), + (*org_edgelist_hops).begin() + label_start_offset, + (*org_edgelist_hops).begin() + label_end_offset, + (*this_label_unique_key_hops).begin()); + } + + if (org_edgelist_edge_types) { + if (org_edgelist_hops) { + auto triplet_first = thrust::make_zip_iterator((*this_label_unique_key_edge_types).begin(), + this_label_unique_key_edge_ids.begin(), + (*this_label_unique_key_hops).begin()); + thrust::sort(handle.get_thrust_policy(), + triplet_first, + triplet_first + this_label_unique_key_edge_ids.size()); + auto key_first = thrust::make_zip_iterator((*this_label_unique_key_edge_types).begin(), + this_label_unique_key_edge_ids.begin()); + this_label_unique_key_edge_ids.resize( + thrust::distance( + key_first, + thrust::get<0>(thrust::unique_by_key(handle.get_thrust_policy(), + key_first, + key_first + this_label_unique_key_edge_ids.size(), + (*this_label_unique_key_hops).begin()))), + handle.get_stream()); + (*this_label_unique_key_edge_types) + .resize(this_label_unique_key_edge_ids.size(), handle.get_stream()); + (*this_label_unique_key_hops) + 
.resize(this_label_unique_key_edge_ids.size(), handle.get_stream()); + } else { + auto pair_first = thrust::make_zip_iterator((*this_label_unique_key_edge_types).begin(), + this_label_unique_key_edge_ids.begin()); + thrust::sort(handle.get_thrust_policy(), + pair_first, + pair_first + this_label_unique_key_edge_ids.size()); + this_label_unique_key_edge_ids.resize( + thrust::distance(pair_first, + thrust::unique(handle.get_thrust_policy(), + pair_first, + pair_first + this_label_unique_key_edge_ids.size())), + handle.get_stream()); + (*this_label_unique_key_edge_types) + .resize(this_label_unique_key_edge_ids.size(), handle.get_stream()); + } + } else { + if (org_edgelist_hops) { + auto pair_first = thrust::make_zip_iterator(this_label_unique_key_edge_ids.begin(), + (*this_label_unique_key_hops).begin()); + thrust::sort(handle.get_thrust_policy(), + pair_first, + pair_first + this_label_unique_key_edge_ids.size()); + this_label_unique_key_edge_ids.resize( + thrust::distance( + this_label_unique_key_edge_ids.begin(), + thrust::get<0>(thrust::unique_by_key(handle.get_thrust_policy(), + this_label_unique_key_edge_ids.begin(), + this_label_unique_key_edge_ids.end(), + (*this_label_unique_key_hops).begin()))), + handle.get_stream()); + (*this_label_unique_key_hops) + .resize(this_label_unique_key_edge_ids.size(), handle.get_stream()); + } else { + thrust::sort(handle.get_thrust_policy(), + this_label_unique_key_edge_ids.begin(), + this_label_unique_key_edge_ids.end()); + this_label_unique_key_edge_ids.resize( + thrust::distance(this_label_unique_key_edge_ids.begin(), + thrust::unique(handle.get_thrust_policy(), + this_label_unique_key_edge_ids.begin(), + this_label_unique_key_edge_ids.end())), + handle.get_stream()); + } + } + + for (size_t j = 0; j < num_edge_types; ++j) { /* check the renumber map segment of (label i, edge type j) */ + size_t renumber_map_type_start_offset{0}; + auto renumber_map_type_end_offset = renumber_map.size(); + if (renumber_map_label_type_offsets) { + raft::update_host(&renumber_map_type_start_offset, + 
(*renumber_map_label_type_offsets).data() + i * num_edge_types + j, + size_t{1}, + handle.get_stream()); + raft::update_host(&renumber_map_type_end_offset, + (*renumber_map_label_type_offsets).data() + i * num_edge_types + j + 1, + size_t{1}, + handle.get_stream()); + handle.sync_stream(); + } + + rmm::device_uvector this_type_sorted_org_edge_ids( + renumber_map_type_end_offset - renumber_map_type_start_offset, handle.get_stream()); + rmm::device_uvector this_type_matching_renumbered_edge_ids( + this_type_sorted_org_edge_ids.size(), handle.get_stream()); + thrust::copy(handle.get_thrust_policy(), + renumber_map.begin() + renumber_map_type_start_offset, + renumber_map.begin() + renumber_map_type_end_offset, + this_type_sorted_org_edge_ids.begin()); + thrust::sequence(handle.get_thrust_policy(), + this_type_matching_renumbered_edge_ids.begin(), + this_type_matching_renumbered_edge_ids.end(), + edge_id_t{0}); + thrust::sort_by_key(handle.get_thrust_policy(), + this_type_sorted_org_edge_ids.begin(), + this_type_sorted_org_edge_ids.end(), + this_type_matching_renumbered_edge_ids.begin()); /* sorted original edge IDs paired with their positions in the segment (= renumbered edge IDs) */ + + size_t type_start_offset{0}; + auto type_end_offset = this_label_unique_key_edge_ids.size(); + if (this_label_unique_key_edge_types) { + type_start_offset = static_cast( + thrust::distance((*this_label_unique_key_edge_types).begin(), + thrust::lower_bound(handle.get_thrust_policy(), + (*this_label_unique_key_edge_types).begin(), + (*this_label_unique_key_edge_types).end(), + static_cast(j)))); + type_end_offset = static_cast( + thrust::distance((*this_label_unique_key_edge_types).begin(), + thrust::upper_bound(handle.get_thrust_policy(), + (*this_label_unique_key_edge_types).begin(), + (*this_label_unique_key_edge_types).end(), + static_cast(j)))); + } + + if ((renumber_map_type_end_offset - renumber_map_type_start_offset) != + (type_end_offset - type_start_offset)) { // renumber map size == # unique edge IDs + return false; + } + + if (org_edgelist_hops) { + if (type_start_offset 
== type_end_offset) { continue; } + + auto sort_key_first = (*this_label_unique_key_hops).begin(); + thrust::sort_by_key(handle.get_thrust_policy(), + sort_key_first + type_start_offset, + sort_key_first + type_end_offset, + this_label_unique_key_edge_ids.begin() + type_start_offset); + + auto num_unique_keys = + thrust::count_if(handle.get_thrust_policy(), + thrust::make_counting_iterator(size_t{0}), + thrust::make_counting_iterator(type_end_offset - type_start_offset), + cugraph::detail::is_first_in_run_t{ + sort_key_first + type_start_offset}); + rmm::device_uvector min_edge_ids(num_unique_keys, handle.get_stream()); + rmm::device_uvector max_edge_ids(num_unique_keys, handle.get_stream()); + + auto renumbered_edge_id_first = thrust::make_transform_iterator( + this_label_unique_key_edge_ids.begin(), + cuda::proclaim_return_type( + [this_type_sorted_org_edge_ids = raft::device_span( + this_type_sorted_org_edge_ids.data(), this_type_sorted_org_edge_ids.size()), + this_type_matching_renumbered_edge_ids = raft::device_span( + this_type_matching_renumbered_edge_ids.data(), + this_type_matching_renumbered_edge_ids.size())] __device__(edge_id_t id) { + auto it = thrust::lower_bound(thrust::seq, + this_type_sorted_org_edge_ids.begin(), + this_type_sorted_org_edge_ids.end(), + id); + return this_type_matching_renumbered_edge_ids[thrust::distance( + this_type_sorted_org_edge_ids.begin(), it)]; + })); + + thrust::reduce_by_key(handle.get_thrust_policy(), + sort_key_first + type_start_offset, + sort_key_first + type_end_offset, + renumbered_edge_id_first + type_start_offset, + thrust::make_discard_iterator(), + min_edge_ids.begin(), + thrust::equal_to{}, + thrust::minimum{}); /* smallest renumbered edge ID per hop */ + thrust::reduce_by_key(handle.get_thrust_policy(), + sort_key_first + type_start_offset, + sort_key_first + type_end_offset, + renumbered_edge_id_first + type_start_offset, + thrust::make_discard_iterator(), + max_edge_ids.begin(), + thrust::equal_to{}, + thrust::maximum{}); /* largest renumbered edge ID per hop */ + + auto num_violations = + 
thrust::count_if(handle.get_thrust_policy(), + thrust::make_counting_iterator(size_t{1}), + thrust::make_counting_iterator(min_edge_ids.size()), + [min_edge_ids = raft::device_span(min_edge_ids.data(), + min_edge_ids.size()), + max_edge_ids = raft::device_span( + max_edge_ids.data(), max_edge_ids.size())] __device__(size_t i) { + return min_edge_ids[i] <= max_edge_ids[i - 1]; /* a larger hop must only map to larger renumbered edge IDs */ + }); + + if (num_violations != 0) { return false; } + } + } + } + + return true; +} + +template bool check_edge_id_renumber_map_invariants( + raft::handle_t const& handle, + raft::device_span org_edgelist_edge_ids, + std::optional> org_edgelist_edge_types, + std::optional> org_edgelist_hops, + std::optional> org_edgelist_label_offsets, + raft::device_span renumber_map, + std::optional> renumber_map_label_type_offsets, + size_t num_labels, + size_t num_edge_types); + +template bool check_edge_id_renumber_map_invariants( + raft::handle_t const& handle, + raft::device_span org_edgelist_edge_ids, + std::optional> org_edgelist_edge_types, + std::optional> org_edgelist_hops, + std::optional> org_edgelist_label_offsets, + raft::device_span renumber_map, + std::optional> renumber_map_label_type_offsets, + size_t num_labels, + size_t num_edge_types); diff --git a/cpp/tests/sampling/detail/sampling_post_processing_validate.hpp b/cpp/tests/sampling/detail/sampling_post_processing_validate.hpp new file mode 100644 index 00000000000..986265b368f --- /dev/null +++ b/cpp/tests/sampling/detail/sampling_post_processing_validate.hpp @@ -0,0 +1,101 @@ +/* + * Copyright (c) 2022-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ // NOTE(review): this header has no include guard (#pragma once); consider adding one + +#include +#include + +#include + +template +bool check_offsets(raft::handle_t const& handle, + raft::device_span offsets, + index_t num_segments, + index_t num_elements); + +template +bool check_edgelist_is_sorted(raft::handle_t const& handle, + raft::device_span edgelist_majors, + raft::device_span edgelist_minors); + +// unrenumber the renumbered edge list and check whether the original & unrenumbered edge lists are + // identical +template +bool compare_edgelist(raft::handle_t const& handle, + raft::device_span org_edgelist_srcs, + raft::device_span org_edgelist_dsts, + std::optional> org_edgelist_weights, + std::optional> org_edgelist_label_offsets, + raft::device_span renumbered_edgelist_srcs, + raft::device_span renumbered_edgelist_dsts, + std::optional> renumbered_edgelist_weights, + std::optional> renumber_map, + std::optional> renumber_map_label_offsets, + size_t num_labels); + +// unrenumber the renumbered edge list and check whether the original & unrenumbered edge lists + // are identical +template +bool compare_heterogeneous_edgelist( + raft::handle_t const& handle, + raft::device_span org_edgelist_srcs, + raft::device_span org_edgelist_dsts, + std::optional> org_edgelist_weights, + std::optional> org_edgelist_edge_ids, + std::optional> org_edgelist_edge_types, + std::optional> org_edgelist_hops, + std::optional> org_edgelist_label_offsets, + raft::device_span renumbered_edgelist_srcs, + raft::device_span renumbered_edgelist_dsts, + std::optional> renumbered_edgelist_weights, + std::optional> renumbered_edgelist_edge_ids, + 
std::optional> renumbered_edgelist_label_edge_type_hop_offsets, + raft::device_span vertex_renumber_map, + raft::device_span vertex_renumber_map_label_type_offsets, + std::optional> edge_id_renumber_map, + std::optional> edge_id_renumber_map_label_type_offsets, + raft::device_span vertex_type_offsets, + size_t num_labels, + size_t num_vertex_types, + size_t num_edge_types, + size_t num_hops); + +template +bool check_vertex_renumber_map_invariants( + raft::handle_t const& handle, + std::optional> starting_vertices, + std::optional> starting_vertex_label_offsets, + raft::device_span org_edgelist_srcs, + raft::device_span org_edgelist_dsts, + std::optional> org_edgelist_hops, + std::optional> org_edgelist_label_offsets, + raft::device_span renumber_map, + std::optional> renumber_map_label_type_offsets, + std::optional> vertex_type_offsets, + size_t num_labels, + size_t num_vertex_types, + bool src_is_major); + +template /* checks, per (label, edge type), that the renumber map size equals the # unique edge IDs and (if hops are given) that edge IDs with smaller min. hops get smaller renumbered IDs */ +bool check_edge_id_renumber_map_invariants( + raft::handle_t const& handle, + raft::device_span org_edgelist_edge_ids, + std::optional> org_edgelist_edge_types, + std::optional> org_edgelist_hops, + std::optional> org_edgelist_label_offsets, + raft::device_span renumber_map, + std::optional> renumber_map_label_type_offsets, + size_t num_labels, + size_t num_edge_types); diff --git a/cpp/tests/sampling/sampling_heterogeneous_post_processing_test.cpp b/cpp/tests/sampling/sampling_heterogeneous_post_processing_test.cpp new file mode 100644 index 00000000000..2b2049dc8db --- /dev/null +++ b/cpp/tests/sampling/sampling_heterogeneous_post_processing_test.cpp @@ -0,0 +1,828 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "detail/sampling_post_processing_validate.hpp" +#include "utilities/base_fixture.hpp" +#include "utilities/conversion_utilities.hpp" +#include "utilities/property_generator_utilities.hpp" + +#include +#include +#include + +#include + +#include + +#include + +struct SamplingHeterogeneousPostProcessing_Usecase { // parameters of a single test case + size_t num_labels{}; // # seed vertex labels (labels are generated only if > 1) + size_t num_seeds_per_label{}; // # seed (starting) vertices generated per label + size_t num_vertex_types{}; // # vertex types; # edge types is set to num_vertex_types^2 + std::vector fanouts{{-1}}; // per-hop fanouts passed to the sampler (size = # hops) + bool sample_with_replacement{false}; // forwarded to the sampler + + bool src_is_major{true}; // if false, sampled srcs & dsts are swapped before post processing + bool renumber_with_seeds{false}; // if true, seed vertices are passed to the renumbering function + bool check_correctness{true}; // if false, skip the validation of the outputs +}; + +template // test fixture; the parameter is a (post processing usecase, graph input usecase) tuple +class Tests_SamplingHeterogeneousPostProcessing + : public ::testing::TestWithParam< + std::tuple> { + public: + Tests_SamplingHeterogeneousPostProcessing() {} + + static void SetUpTestCase() {} + static void TearDownTestCase() {} + + virtual void SetUp() {} + virtual void TearDown() {} + + template + void run_current_test(std::tuple const& param) + { // constructs a graph, samples it, renumbers & sorts the sampled edges, then validates the outputs + using label_t = int32_t; + using weight_t = float; + using edge_id_t = edge_t; + using edge_type_t = int32_t; + + bool constexpr store_transposed = false; + bool constexpr renumber = true; + bool constexpr test_weighted = true; + + auto [sampling_heterogeneous_post_processing_usecase, input_usecase] = param; + + raft::handle_t handle{}; + HighResTimer hr_timer{}; + + // 1. 
create a graph + + if (cugraph::test::g_perf) { + RAFT_CUDA_TRY(cudaDeviceSynchronize()); // for consistent performance measurement + hr_timer.start("Construct graph"); + } + + auto [graph, edge_weights, d_renumber_map_labels] = + cugraph::test::construct_graph( + handle, input_usecase, test_weighted, renumber); + + if (cugraph::test::g_perf) { + RAFT_CUDA_TRY(cudaDeviceSynchronize()); // for consistent performance measurement + hr_timer.stop(); + hr_timer.display_and_clear(std::cout); + } + + auto graph_view = graph.view(); + auto edge_weight_view = + edge_weights ? std::make_optional((*edge_weights).view()) : std::nullopt; + + // 2. vertex type offsets (first: 0, last: # vertices, in-between: sorted randomly selected vertices) + + raft::random::RngState rng_state(0); + + rmm::device_uvector vertex_type_offsets( + sampling_heterogeneous_post_processing_usecase.num_vertex_types + 1, handle.get_stream()); + { + auto num_vertices = graph_view.number_of_vertices(); + vertex_type_offsets.set_element_to_zero_async(0, handle.get_stream()); + vertex_type_offsets.set_element_async( + vertex_type_offsets.size() - 1, num_vertices, handle.get_stream()); + auto tmp = cugraph::select_random_vertices( + handle, + graph_view, + std::nullopt, + rng_state, + sampling_heterogeneous_post_processing_usecase.num_vertex_types - 1, + false /* with_replacement */, + true /* sort_vertices */); + raft::copy(vertex_type_offsets.data() + 1, tmp.data(), tmp.size(), handle.get_stream()); + } + + // 3. seed vertices (& labels) + + rmm::device_uvector starting_vertices( + sampling_heterogeneous_post_processing_usecase.num_labels * + sampling_heterogeneous_post_processing_usecase.num_seeds_per_label, + handle.get_stream()); + cugraph::detail::uniform_random_fill(handle.get_stream(), + starting_vertices.data(), + starting_vertices.size(), + vertex_t{0}, + graph_view.number_of_vertices(), + rng_state); + auto starting_vertex_labels = (sampling_heterogeneous_post_processing_usecase.num_labels > 1) + ? 
std::make_optional>( + starting_vertices.size(), handle.get_stream()) + : std::nullopt; + auto starting_vertex_label_offsets = + (sampling_heterogeneous_post_processing_usecase.num_labels > 1) + ? std::make_optional>( + sampling_heterogeneous_post_processing_usecase.num_labels + 1, handle.get_stream()) + : std::nullopt; + if (starting_vertex_labels) { + auto num_seeds_per_label = sampling_heterogeneous_post_processing_usecase.num_seeds_per_label; + for (size_t i = 0; i < sampling_heterogeneous_post_processing_usecase.num_labels; ++i) { + cugraph::detail::scalar_fill(handle.get_stream(), + (*starting_vertex_labels).data() + i * num_seeds_per_label, + num_seeds_per_label, + static_cast(i)); + } + cugraph::detail::stride_fill(handle.get_stream(), + (*starting_vertex_label_offsets).data(), + (*starting_vertex_label_offsets).size(), + size_t{0}, + num_seeds_per_label); + } + + // 4. generate edge IDs and types + + auto num_edge_types = + sampling_heterogeneous_post_processing_usecase.num_vertex_types * + sampling_heterogeneous_post_processing_usecase + .num_vertex_types; // necessary to enforce that edge type dictates edge source vertex type + // and edge destination vertex type. + + std::optional> edge_types{ + std::nullopt}; + if (num_edge_types > 1) { + edge_types = + cugraph::test::generate::edge_property_by_src_dst_types( + handle, + graph_view, + raft::device_span(vertex_type_offsets.data(), vertex_type_offsets.size()), + num_edge_types); + } + + cugraph::edge_property_t edge_ids(handle); + if (edge_types) { + static_assert(std::is_same_v); + edge_ids = + cugraph::test::generate::unique_edge_property_per_type( + handle, graph_view, (*edge_types).view(), static_cast(num_edge_types)); + } else { + edge_ids = cugraph::test::generate::unique_edge_property( + handle, graph_view); + } + + // 5. 
sampling + + rmm::device_uvector org_edgelist_srcs(0, handle.get_stream()); + rmm::device_uvector org_edgelist_dsts(0, handle.get_stream()); + std::optional> org_edgelist_weights{std::nullopt}; + std::optional> org_edgelist_edge_ids{std::nullopt}; + std::optional> org_edgelist_edge_types{std::nullopt}; + std::optional> org_edgelist_hops{std::nullopt}; + std::optional> org_labels{std::nullopt}; + std::optional> org_edgelist_label_offsets{std::nullopt}; + std::tie(org_edgelist_srcs, + org_edgelist_dsts, + org_edgelist_weights, + org_edgelist_edge_ids, + org_edgelist_edge_types, + org_edgelist_hops, + org_labels, + org_edgelist_label_offsets) = cugraph::uniform_neighbor_sample( + handle, + graph_view, + edge_weight_view, + std::optional>{edge_ids.view()}, + edge_types + ? std::optional>{(*edge_types) + .view()} + : std::nullopt, + raft::device_span(starting_vertices.data(), starting_vertices.size()), + starting_vertex_labels ? std::make_optional>( + (*starting_vertex_labels).data(), (*starting_vertex_labels).size()) + : std::nullopt, + std::nullopt, + raft::host_span(sampling_heterogeneous_post_processing_usecase.fanouts.data(), + sampling_heterogeneous_post_processing_usecase.fanouts.size()), + rng_state, + sampling_heterogeneous_post_processing_usecase.fanouts.size() > 1, + sampling_heterogeneous_post_processing_usecase.sample_with_replacement, + cugraph::prior_sources_behavior_t::EXCLUDE, + false); + + if (!sampling_heterogeneous_post_processing_usecase.src_is_major) { + std::swap(org_edgelist_srcs, org_edgelist_dsts); + } + + // 6. post processing: renumber & sort + + { + rmm::device_uvector renumbered_and_sorted_edgelist_srcs(org_edgelist_srcs.size(), + handle.get_stream()); + rmm::device_uvector renumbered_and_sorted_edgelist_dsts(org_edgelist_dsts.size(), + handle.get_stream()); + auto renumbered_and_sorted_edgelist_weights = + org_edgelist_weights ? 
std::make_optional>( + (*org_edgelist_weights).size(), handle.get_stream()) + : std::nullopt; + auto renumbered_and_sorted_edgelist_edge_ids = + org_edgelist_edge_ids ? std::make_optional>( + (*org_edgelist_edge_ids).size(), handle.get_stream()) + : std::nullopt; + auto renumbered_and_sorted_edgelist_edge_types = + org_edgelist_edge_types ? std::make_optional>( + (*org_edgelist_edge_types).size(), handle.get_stream()) + : std::nullopt; + auto renumbered_and_sorted_edgelist_hops = + org_edgelist_hops ? std::make_optional(rmm::device_uvector( + (*org_edgelist_hops).size(), handle.get_stream())) + : std::nullopt; + + raft::copy(renumbered_and_sorted_edgelist_srcs.data(), + org_edgelist_srcs.data(), + org_edgelist_srcs.size(), + handle.get_stream()); + raft::copy(renumbered_and_sorted_edgelist_dsts.data(), + org_edgelist_dsts.data(), + org_edgelist_dsts.size(), + handle.get_stream()); + if (renumbered_and_sorted_edgelist_weights) { + raft::copy((*renumbered_and_sorted_edgelist_weights).data(), + (*org_edgelist_weights).data(), + (*org_edgelist_weights).size(), + handle.get_stream()); + } + if (renumbered_and_sorted_edgelist_edge_ids) { + raft::copy((*renumbered_and_sorted_edgelist_edge_ids).data(), + (*org_edgelist_edge_ids).data(), + (*org_edgelist_edge_ids).size(), + handle.get_stream()); + } + if (renumbered_and_sorted_edgelist_edge_types) { + raft::copy((*renumbered_and_sorted_edgelist_edge_types).data(), + (*org_edgelist_edge_types).data(), + (*org_edgelist_edge_types).size(), + handle.get_stream()); + } + if (renumbered_and_sorted_edgelist_hops) { + raft::copy((*renumbered_and_sorted_edgelist_hops).data(), + (*org_edgelist_hops).data(), + (*org_edgelist_hops).size(), + handle.get_stream()); + } + + std::optional> + renumbered_and_sorted_edgelist_label_type_hop_offsets{std::nullopt}; + rmm::device_uvector renumbered_and_sorted_vertex_renumber_map(0, + handle.get_stream()); + rmm::device_uvector renumbered_and_sorted_vertex_renumber_map_label_type_offsets( + 0, 
handle.get_stream()); + std::optional> renumbered_and_sorted_edge_id_renumber_map{ + std::nullopt}; + std::optional> + renumbered_and_sorted_edge_id_renumber_map_label_type_offsets{std::nullopt}; + + if (cugraph::test::g_perf) { + RAFT_CUDA_TRY(cudaDeviceSynchronize()); // for consistent performance measurement + hr_timer.start("Renumber and sort sampled edgelist"); + } + + std::tie(renumbered_and_sorted_edgelist_srcs, + renumbered_and_sorted_edgelist_dsts, + renumbered_and_sorted_edgelist_weights, + renumbered_and_sorted_edgelist_edge_ids, + renumbered_and_sorted_edgelist_label_type_hop_offsets, + renumbered_and_sorted_vertex_renumber_map, + renumbered_and_sorted_vertex_renumber_map_label_type_offsets, + renumbered_and_sorted_edge_id_renumber_map, + renumbered_and_sorted_edge_id_renumber_map_label_type_offsets) = + cugraph::heterogeneous_renumber_and_sort_sampled_edgelist( + handle, + std::move(renumbered_and_sorted_edgelist_srcs), + std::move(renumbered_and_sorted_edgelist_dsts), + std::move(renumbered_and_sorted_edgelist_weights), + std::move(renumbered_and_sorted_edgelist_edge_ids), + std::move(renumbered_and_sorted_edgelist_edge_types), + std::move(renumbered_and_sorted_edgelist_hops), + sampling_heterogeneous_post_processing_usecase.renumber_with_seeds + ? std::make_optional>(starting_vertices.data(), + starting_vertices.size()) + : std::nullopt, + (sampling_heterogeneous_post_processing_usecase.renumber_with_seeds && + starting_vertex_label_offsets) + ? std::make_optional>( + (*starting_vertex_label_offsets).data(), (*starting_vertex_label_offsets).size()) + : std::nullopt, + org_edgelist_label_offsets + ? 
std::make_optional(raft::device_span( + (*org_edgelist_label_offsets).data(), (*org_edgelist_label_offsets).size())) + : std::nullopt, + raft::device_span(vertex_type_offsets.data(), vertex_type_offsets.size()), + sampling_heterogeneous_post_processing_usecase.num_labels, + sampling_heterogeneous_post_processing_usecase.fanouts.size(), + sampling_heterogeneous_post_processing_usecase.num_vertex_types, + num_edge_types, + sampling_heterogeneous_post_processing_usecase.src_is_major); + + if (cugraph::test::g_perf) { + RAFT_CUDA_TRY(cudaDeviceSynchronize()); // for consistent performance measurement + hr_timer.stop(); + hr_timer.display_and_clear(std::cout); + } + + if (sampling_heterogeneous_post_processing_usecase.check_correctness) { + if (renumbered_and_sorted_edgelist_label_type_hop_offsets) { + ASSERT_TRUE(check_offsets( + handle, + raft::device_span( + (*renumbered_and_sorted_edgelist_label_type_hop_offsets).data(), + (*renumbered_and_sorted_edgelist_label_type_hop_offsets).size()), + sampling_heterogeneous_post_processing_usecase.num_labels * num_edge_types * + sampling_heterogeneous_post_processing_usecase.fanouts.size(), + renumbered_and_sorted_edgelist_srcs.size())) + << "Renumbered and sorted edge (label, edge type, hop) offset array is invalid."; + } + + ASSERT_TRUE( + check_offsets(handle, + raft::device_span( + renumbered_and_sorted_vertex_renumber_map_label_type_offsets.data(), + renumbered_and_sorted_vertex_renumber_map_label_type_offsets.size()), + sampling_heterogeneous_post_processing_usecase.num_labels * + sampling_heterogeneous_post_processing_usecase.num_vertex_types, + renumbered_and_sorted_vertex_renumber_map.size())) + << "Renumbered and sorted vertex renumber map (label, vertex type) offset array is " + "invalid."; + + if (renumbered_and_sorted_edge_id_renumber_map_label_type_offsets) { + ASSERT_TRUE(check_offsets( + handle, + raft::device_span( + (*renumbered_and_sorted_edge_id_renumber_map_label_type_offsets).data(), + 
(*renumbered_and_sorted_edge_id_renumber_map_label_type_offsets).size()), + sampling_heterogeneous_post_processing_usecase.num_labels * num_edge_types, + (*renumbered_and_sorted_edge_id_renumber_map).size())) + << "Renumbered and sorted edge renumber map (label, edge type) offset array is " + "invalid."; + } + + // check whether the edges are properly sorted + + auto renumbered_and_sorted_edgelist_majors = + sampling_heterogeneous_post_processing_usecase.src_is_major + ? raft::device_span(renumbered_and_sorted_edgelist_srcs.data(), + renumbered_and_sorted_edgelist_srcs.size()) + : raft::device_span(renumbered_and_sorted_edgelist_dsts.data(), + renumbered_and_sorted_edgelist_dsts.size()); + auto renumbered_and_sorted_edgelist_minors = + sampling_heterogeneous_post_processing_usecase.src_is_major + ? raft::device_span(renumbered_and_sorted_edgelist_dsts.data(), + renumbered_and_sorted_edgelist_dsts.size()) + : raft::device_span(renumbered_and_sorted_edgelist_srcs.data(), + renumbered_and_sorted_edgelist_srcs.size()); + + if (renumbered_and_sorted_edgelist_label_type_hop_offsets) { + for (size_t i = 0; + i < sampling_heterogeneous_post_processing_usecase.num_labels * num_edge_types * + sampling_heterogeneous_post_processing_usecase.fanouts.size(); + ++i) { + auto hop_start_offset = (*renumbered_and_sorted_edgelist_label_type_hop_offsets) + .element(i, handle.get_stream()); + auto hop_end_offset = (*renumbered_and_sorted_edgelist_label_type_hop_offsets) + .element(i + 1, handle.get_stream()); + ASSERT_TRUE(check_edgelist_is_sorted( + handle, + raft::device_span( + renumbered_and_sorted_edgelist_majors.data() + hop_start_offset, + hop_end_offset - hop_start_offset), + raft::device_span( + renumbered_and_sorted_edgelist_minors.data() + hop_start_offset, + hop_end_offset - hop_start_offset))) + << "Renumbered and sorted edge list is not properly sorted."; + } + } else { + ASSERT_TRUE(check_edgelist_is_sorted( + handle, + 
raft::device_span(renumbered_and_sorted_edgelist_majors.data(), + renumbered_and_sorted_edgelist_majors.size()), + raft::device_span(renumbered_and_sorted_edgelist_minors.data(), + renumbered_and_sorted_edgelist_minors.size()))) + << "Renumbered and sorted edge list is not properly sorted."; + } + + // check whether renumbering recovers the original edge list + + ASSERT_TRUE(compare_heterogeneous_edgelist( + handle, + raft::device_span(org_edgelist_srcs.data(), org_edgelist_srcs.size()), + raft::device_span(org_edgelist_dsts.data(), org_edgelist_dsts.size()), + org_edgelist_weights ? std::make_optional>( + (*org_edgelist_weights).data(), (*org_edgelist_weights).size()) + : std::nullopt, + org_edgelist_edge_ids + ? std::make_optional>( + (*org_edgelist_edge_ids).data(), (*org_edgelist_edge_ids).size()) + : std::nullopt, + org_edgelist_edge_types + ? std::make_optional>( + (*org_edgelist_edge_types).data(), (*org_edgelist_edge_types).size()) + : std::nullopt, + org_edgelist_hops ? std::make_optional>( + (*org_edgelist_hops).data(), (*org_edgelist_hops).size()) + : std::nullopt, + org_edgelist_label_offsets + ? std::make_optional>( + (*org_edgelist_label_offsets).data(), (*org_edgelist_label_offsets).size()) + : std::nullopt, + raft::device_span(renumbered_and_sorted_edgelist_srcs.data(), + renumbered_and_sorted_edgelist_srcs.size()), + raft::device_span(renumbered_and_sorted_edgelist_dsts.data(), + renumbered_and_sorted_edgelist_dsts.size()), + renumbered_and_sorted_edgelist_weights + ? std::make_optional>( + (*renumbered_and_sorted_edgelist_weights).data(), + (*renumbered_and_sorted_edgelist_weights).size()) + : std::nullopt, + renumbered_and_sorted_edgelist_edge_ids + ? std::make_optional>( + (*renumbered_and_sorted_edgelist_edge_ids).data(), + (*renumbered_and_sorted_edgelist_edge_ids).size()) + : std::nullopt, + renumbered_and_sorted_edgelist_label_type_hop_offsets + ? 
std::make_optional>( + (*renumbered_and_sorted_edgelist_label_type_hop_offsets).data(), + (*renumbered_and_sorted_edgelist_label_type_hop_offsets).size()) + : std::nullopt, + raft::device_span(renumbered_and_sorted_vertex_renumber_map.data(), + renumbered_and_sorted_vertex_renumber_map.size()), + raft::device_span( + renumbered_and_sorted_vertex_renumber_map_label_type_offsets.data(), + renumbered_and_sorted_vertex_renumber_map_label_type_offsets.size()), + renumbered_and_sorted_edge_id_renumber_map + ? std::make_optional>( + (*renumbered_and_sorted_edge_id_renumber_map).data(), + (*renumbered_and_sorted_edge_id_renumber_map).size()) + : std::nullopt, + renumbered_and_sorted_edge_id_renumber_map_label_type_offsets + ? std::make_optional>( + (*renumbered_and_sorted_edge_id_renumber_map_label_type_offsets).data(), + (*renumbered_and_sorted_edge_id_renumber_map_label_type_offsets).size()) + : std::nullopt, + raft::device_span(vertex_type_offsets.data(), vertex_type_offsets.size()), + sampling_heterogeneous_post_processing_usecase.num_labels, + sampling_heterogeneous_post_processing_usecase.num_vertex_types, + num_edge_types, + sampling_heterogeneous_post_processing_usecase.fanouts.size())) + << "Unrenumbering the renumbered and sorted edge list does not recover the original " + "edgelist."; + + // Check the invariants in vertex renumber_map + + ASSERT_TRUE(check_vertex_renumber_map_invariants( + handle, + sampling_heterogeneous_post_processing_usecase.renumber_with_seeds + ? std::make_optional>(starting_vertices.data(), + starting_vertices.size()) + : std::nullopt, + (sampling_heterogeneous_post_processing_usecase.renumber_with_seeds && + starting_vertex_label_offsets) + ? std::make_optional>( + (*starting_vertex_label_offsets).data(), (*starting_vertex_label_offsets).size()) + : std::nullopt, + raft::device_span(org_edgelist_srcs.data(), org_edgelist_srcs.size()), + raft::device_span(org_edgelist_dsts.data(), org_edgelist_dsts.size()), + org_edgelist_hops ? 
std::make_optional>( + (*org_edgelist_hops).data(), (*org_edgelist_hops).size()) + : std::nullopt, + org_edgelist_label_offsets + ? std::make_optional>( + (*org_edgelist_label_offsets).data(), (*org_edgelist_label_offsets).size()) + : std::nullopt, + raft::device_span(renumbered_and_sorted_vertex_renumber_map.data(), + renumbered_and_sorted_vertex_renumber_map.size()), + std::make_optional>( + renumbered_and_sorted_vertex_renumber_map_label_type_offsets.data(), + renumbered_and_sorted_vertex_renumber_map_label_type_offsets.size()), + raft::device_span(vertex_type_offsets.data(), vertex_type_offsets.size()), + sampling_heterogeneous_post_processing_usecase.num_labels, + sampling_heterogeneous_post_processing_usecase.num_vertex_types, + sampling_heterogeneous_post_processing_usecase.src_is_major)) + << "Renumbered and sorted output vertex renumber map violates invariants."; + + // Check the invariants in edge renumber_map + + if (org_edgelist_edge_ids) { + ASSERT_TRUE(check_edge_id_renumber_map_invariants( + handle, + raft::device_span((*org_edgelist_edge_ids).data(), + (*org_edgelist_edge_ids).size()), + org_edgelist_edge_types + ? std::make_optional>( + (*org_edgelist_edge_types).data(), (*org_edgelist_edge_types).size()) + : std::nullopt, + org_edgelist_hops ? std::make_optional>( + (*org_edgelist_hops).data(), (*org_edgelist_hops).size()) + : std::nullopt, + org_edgelist_label_offsets + ? std::make_optional>( + (*org_edgelist_label_offsets).data(), (*org_edgelist_label_offsets).size()) + : std::nullopt, + raft::device_span( + (*renumbered_and_sorted_edge_id_renumber_map).data(), + (*renumbered_and_sorted_edge_id_renumber_map).size()), + renumbered_and_sorted_edge_id_renumber_map_label_type_offsets + ? 
std::make_optional>( + (*renumbered_and_sorted_edge_id_renumber_map_label_type_offsets).data(), + (*renumbered_and_sorted_edge_id_renumber_map_label_type_offsets).size()) + : std::nullopt, + sampling_heterogeneous_post_processing_usecase.num_labels, + num_edge_types)) + << "Renumbered and sorted output edge ID renumber map violates invariants."; + } + } + } + } +}; + +using Tests_SamplingHeterogeneousPostProcessing_File = + Tests_SamplingHeterogeneousPostProcessing; +using Tests_SamplingHeterogeneousPostProcessing_Rmat = + Tests_SamplingHeterogeneousPostProcessing; + +TEST_P(Tests_SamplingHeterogeneousPostProcessing_File, CheckInt32Int32) +{ + run_current_test(override_File_Usecase_with_cmd_line_arguments(GetParam())); +} + +TEST_P(Tests_SamplingHeterogeneousPostProcessing_Rmat, CheckInt32Int32) +{ + run_current_test(override_Rmat_Usecase_with_cmd_line_arguments(GetParam())); +} + +TEST_P(Tests_SamplingHeterogeneousPostProcessing_Rmat, CheckInt32Int64) +{ + run_current_test(override_Rmat_Usecase_with_cmd_line_arguments(GetParam())); +} + +TEST_P(Tests_SamplingHeterogeneousPostProcessing_Rmat, CheckInt64Int64) +{ + run_current_test(override_Rmat_Usecase_with_cmd_line_arguments(GetParam())); +} + +INSTANTIATE_TEST_SUITE_P( + file_test, + Tests_SamplingHeterogeneousPostProcessing_File, + ::testing::Combine( + // enable correctness checks + ::testing::Values( + SamplingHeterogeneousPostProcessing_Usecase{1, 16, 1, {10}, false, false, false}, + SamplingHeterogeneousPostProcessing_Usecase{1, 16, 4, {10}, false, false, false}, + SamplingHeterogeneousPostProcessing_Usecase{1, 16, 1, {10}, false, false, true}, + SamplingHeterogeneousPostProcessing_Usecase{1, 16, 4, {10}, false, false, true}, + SamplingHeterogeneousPostProcessing_Usecase{1, 16, 1, {10}, false, true, false}, + SamplingHeterogeneousPostProcessing_Usecase{1, 16, 4, {10}, false, true, false}, + SamplingHeterogeneousPostProcessing_Usecase{1, 16, 1, {10}, false, true, true}, + 
SamplingHeterogeneousPostProcessing_Usecase{1, 16, 4, {10}, false, true, true}, + SamplingHeterogeneousPostProcessing_Usecase{1, 16, 1, {10}, true, false, false}, + SamplingHeterogeneousPostProcessing_Usecase{1, 16, 4, {10}, true, false, false}, + SamplingHeterogeneousPostProcessing_Usecase{1, 16, 1, {10}, true, false, true}, + SamplingHeterogeneousPostProcessing_Usecase{1, 16, 4, {10}, true, false, true}, + SamplingHeterogeneousPostProcessing_Usecase{1, 16, 1, {10}, true, true, false}, + SamplingHeterogeneousPostProcessing_Usecase{1, 16, 4, {10}, true, true, false}, + SamplingHeterogeneousPostProcessing_Usecase{1, 16, 1, {10}, true, true, true}, + SamplingHeterogeneousPostProcessing_Usecase{1, 16, 4, {10}, true, true, true}, + SamplingHeterogeneousPostProcessing_Usecase{1, 16, 1, {5, 10, 15}, false, false, false}, + SamplingHeterogeneousPostProcessing_Usecase{1, 16, 4, {5, 10, 25}, false, false, false}, + SamplingHeterogeneousPostProcessing_Usecase{1, 16, 1, {5, 10, 25}, false, false, true}, + SamplingHeterogeneousPostProcessing_Usecase{1, 16, 4, {5, 10, 25}, false, false, true}, + SamplingHeterogeneousPostProcessing_Usecase{1, 16, 1, {5, 10, 25}, false, true, false}, + SamplingHeterogeneousPostProcessing_Usecase{1, 16, 4, {5, 10, 25}, false, true, false}, + SamplingHeterogeneousPostProcessing_Usecase{1, 16, 1, {5, 10, 25}, false, true, true}, + SamplingHeterogeneousPostProcessing_Usecase{1, 16, 4, {5, 10, 25}, false, true, true}, + SamplingHeterogeneousPostProcessing_Usecase{1, 16, 1, {5, 10, 25}, true, false, false}, + SamplingHeterogeneousPostProcessing_Usecase{1, 16, 4, {5, 10, 25}, true, false, false}, + SamplingHeterogeneousPostProcessing_Usecase{1, 16, 1, {5, 10, 25}, true, false, true}, + SamplingHeterogeneousPostProcessing_Usecase{1, 16, 4, {5, 10, 25}, true, false, true}, + SamplingHeterogeneousPostProcessing_Usecase{1, 16, 1, {5, 10, 25}, true, true, false}, + SamplingHeterogeneousPostProcessing_Usecase{1, 16, 4, {5, 10, 25}, true, true, false}, + 
SamplingHeterogeneousPostProcessing_Usecase{1, 16, 1, {5, 10, 25}, true, true, true}, + SamplingHeterogeneousPostProcessing_Usecase{1, 16, 4, {5, 10, 25}, true, true, true}, + SamplingHeterogeneousPostProcessing_Usecase{32, 16, 1, {10}, false, false, false}, + SamplingHeterogeneousPostProcessing_Usecase{32, 16, 4, {10}, false, false, false}, + SamplingHeterogeneousPostProcessing_Usecase{32, 16, 1, {10}, false, false, true}, + SamplingHeterogeneousPostProcessing_Usecase{32, 16, 4, {10}, false, false, true}, + SamplingHeterogeneousPostProcessing_Usecase{32, 16, 1, {10}, false, true, false}, + SamplingHeterogeneousPostProcessing_Usecase{32, 16, 4, {10}, false, true, false}, + SamplingHeterogeneousPostProcessing_Usecase{32, 16, 1, {10}, false, true, true}, + SamplingHeterogeneousPostProcessing_Usecase{32, 16, 4, {10}, false, true, true}, + SamplingHeterogeneousPostProcessing_Usecase{32, 16, 1, {10}, true, false, false}, + SamplingHeterogeneousPostProcessing_Usecase{32, 16, 4, {10}, true, false, false}, + SamplingHeterogeneousPostProcessing_Usecase{32, 16, 1, {10}, true, false, true}, + SamplingHeterogeneousPostProcessing_Usecase{32, 16, 4, {10}, true, false, true}, + SamplingHeterogeneousPostProcessing_Usecase{32, 16, 1, {10}, true, true, false}, + SamplingHeterogeneousPostProcessing_Usecase{32, 16, 4, {10}, true, true, false}, + SamplingHeterogeneousPostProcessing_Usecase{32, 16, 1, {10}, true, true, true}, + SamplingHeterogeneousPostProcessing_Usecase{32, 16, 4, {10}, true, true, true}, + SamplingHeterogeneousPostProcessing_Usecase{32, 16, 1, {5, 10, 25}, false, false, false}, + SamplingHeterogeneousPostProcessing_Usecase{32, 16, 4, {5, 10, 25}, false, false, false}, + SamplingHeterogeneousPostProcessing_Usecase{32, 16, 1, {5, 10, 25}, false, false, true}, + SamplingHeterogeneousPostProcessing_Usecase{32, 16, 4, {5, 10, 25}, false, false, true}, + SamplingHeterogeneousPostProcessing_Usecase{32, 16, 1, {5, 10, 25}, false, true, false}, + 
SamplingHeterogeneousPostProcessing_Usecase{32, 16, 4, {5, 10, 25}, false, true, false}, + SamplingHeterogeneousPostProcessing_Usecase{32, 16, 1, {5, 10, 25}, false, true, true}, + SamplingHeterogeneousPostProcessing_Usecase{32, 16, 4, {5, 10, 25}, false, true, true}, + SamplingHeterogeneousPostProcessing_Usecase{32, 16, 1, {5, 10, 25}, true, false, false}, + SamplingHeterogeneousPostProcessing_Usecase{32, 16, 4, {5, 10, 25}, true, false, false}, + SamplingHeterogeneousPostProcessing_Usecase{32, 16, 1, {5, 10, 25}, true, false, true}, + SamplingHeterogeneousPostProcessing_Usecase{32, 16, 4, {5, 10, 25}, true, false, true}, + SamplingHeterogeneousPostProcessing_Usecase{32, 16, 1, {5, 10, 25}, true, true, false}, + SamplingHeterogeneousPostProcessing_Usecase{32, 16, 4, {5, 10, 25}, true, true, false}, + SamplingHeterogeneousPostProcessing_Usecase{32, 16, 1, {5, 10, 25}, true, true, true}, + SamplingHeterogeneousPostProcessing_Usecase{32, 16, 4, {5, 10, 25}, true, true, true}), + ::testing::Values(cugraph::test::File_Usecase("test/datasets/karate.mtx"), + cugraph::test::File_Usecase("test/datasets/dolphins.mtx")))); + +INSTANTIATE_TEST_SUITE_P( + rmat_small_test, + Tests_SamplingHeterogeneousPostProcessing_Rmat, + ::testing::Combine( + // enable correctness checks + ::testing::Values( + SamplingHeterogeneousPostProcessing_Usecase{1, 16, 1, {10}, false, false, false}, + SamplingHeterogeneousPostProcessing_Usecase{1, 16, 4, {10}, false, false, false}, + SamplingHeterogeneousPostProcessing_Usecase{1, 16, 1, {10}, false, false, true}, + SamplingHeterogeneousPostProcessing_Usecase{1, 16, 4, {10}, false, false, true}, + SamplingHeterogeneousPostProcessing_Usecase{1, 16, 1, {10}, false, true, false}, + SamplingHeterogeneousPostProcessing_Usecase{1, 16, 4, {10}, false, true, false}, + SamplingHeterogeneousPostProcessing_Usecase{1, 16, 1, {10}, false, true, true}, + SamplingHeterogeneousPostProcessing_Usecase{1, 16, 4, {10}, false, true, true}, + 
SamplingHeterogeneousPostProcessing_Usecase{1, 16, 1, {10}, true, false, false}, + SamplingHeterogeneousPostProcessing_Usecase{1, 16, 4, {10}, true, false, false}, + SamplingHeterogeneousPostProcessing_Usecase{1, 16, 1, {10}, true, false, true}, + SamplingHeterogeneousPostProcessing_Usecase{1, 16, 4, {10}, true, false, true}, + SamplingHeterogeneousPostProcessing_Usecase{1, 16, 1, {10}, true, true, false}, + SamplingHeterogeneousPostProcessing_Usecase{1, 16, 4, {10}, true, true, false}, + SamplingHeterogeneousPostProcessing_Usecase{1, 16, 1, {10}, true, true, true}, + SamplingHeterogeneousPostProcessing_Usecase{1, 16, 4, {10}, true, true, true}, + SamplingHeterogeneousPostProcessing_Usecase{1, 16, 1, {5, 10, 25}, false, false, false}, + SamplingHeterogeneousPostProcessing_Usecase{1, 16, 4, {5, 10, 25}, false, false, false}, + SamplingHeterogeneousPostProcessing_Usecase{1, 16, 1, {5, 10, 25}, false, false, true}, + SamplingHeterogeneousPostProcessing_Usecase{1, 16, 4, {5, 10, 25}, false, false, true}, + SamplingHeterogeneousPostProcessing_Usecase{1, 16, 1, {5, 10, 25}, false, true, false}, + SamplingHeterogeneousPostProcessing_Usecase{1, 16, 4, {5, 10, 25}, false, true, false}, + SamplingHeterogeneousPostProcessing_Usecase{1, 16, 1, {5, 10, 25}, false, true, true}, + SamplingHeterogeneousPostProcessing_Usecase{1, 16, 4, {5, 10, 25}, false, true, true}, + SamplingHeterogeneousPostProcessing_Usecase{1, 16, 1, {5, 10, 25}, true, false, false}, + SamplingHeterogeneousPostProcessing_Usecase{1, 16, 4, {5, 10, 25}, true, false, false}, + SamplingHeterogeneousPostProcessing_Usecase{1, 16, 1, {5, 10, 25}, true, false, true}, + SamplingHeterogeneousPostProcessing_Usecase{1, 16, 4, {5, 10, 25}, true, false, true}, + SamplingHeterogeneousPostProcessing_Usecase{1, 16, 1, {5, 10, 25}, true, true, false}, + SamplingHeterogeneousPostProcessing_Usecase{1, 16, 4, {5, 10, 25}, true, true, false}, + SamplingHeterogeneousPostProcessing_Usecase{1, 16, 1, {5, 10, 25}, true, true, true}, + 
SamplingHeterogeneousPostProcessing_Usecase{1, 16, 4, {5, 10, 25}, true, true, true}, + SamplingHeterogeneousPostProcessing_Usecase{32, 16, 1, {10}, false, false, false}, + SamplingHeterogeneousPostProcessing_Usecase{32, 16, 4, {10}, false, false, false}, + SamplingHeterogeneousPostProcessing_Usecase{32, 16, 1, {10}, false, false, true}, + SamplingHeterogeneousPostProcessing_Usecase{32, 16, 4, {10}, false, false, true}, + SamplingHeterogeneousPostProcessing_Usecase{32, 16, 1, {10}, false, true, false}, + SamplingHeterogeneousPostProcessing_Usecase{32, 16, 4, {10}, false, true, false}, + SamplingHeterogeneousPostProcessing_Usecase{32, 16, 1, {10}, false, true, true}, + SamplingHeterogeneousPostProcessing_Usecase{32, 16, 4, {10}, false, true, true}, + SamplingHeterogeneousPostProcessing_Usecase{32, 16, 1, {10}, true, false, false}, + SamplingHeterogeneousPostProcessing_Usecase{32, 16, 4, {10}, true, false, false}, + SamplingHeterogeneousPostProcessing_Usecase{32, 16, 1, {10}, true, false, true}, + SamplingHeterogeneousPostProcessing_Usecase{32, 16, 4, {10}, true, false, true}, + SamplingHeterogeneousPostProcessing_Usecase{32, 16, 1, {10}, true, true, false}, + SamplingHeterogeneousPostProcessing_Usecase{32, 16, 4, {10}, true, true, false}, + SamplingHeterogeneousPostProcessing_Usecase{32, 16, 1, {10}, true, true, true}, + SamplingHeterogeneousPostProcessing_Usecase{32, 16, 4, {10}, true, true, true}, + SamplingHeterogeneousPostProcessing_Usecase{32, 16, 1, {5, 10, 25}, false, false, false}, + SamplingHeterogeneousPostProcessing_Usecase{32, 16, 4, {5, 10, 25}, false, false, false}, + SamplingHeterogeneousPostProcessing_Usecase{32, 16, 1, {5, 10, 25}, false, false, true}, + SamplingHeterogeneousPostProcessing_Usecase{32, 16, 4, {5, 10, 25}, false, false, true}, + SamplingHeterogeneousPostProcessing_Usecase{32, 16, 1, {5, 10, 25}, false, true, false}, + SamplingHeterogeneousPostProcessing_Usecase{32, 16, 4, {5, 10, 25}, false, true, false}, + 
SamplingHeterogeneousPostProcessing_Usecase{32, 16, 1, {5, 10, 25}, false, true, true}, + SamplingHeterogeneousPostProcessing_Usecase{32, 16, 4, {5, 10, 25}, false, true, true}, + SamplingHeterogeneousPostProcessing_Usecase{32, 16, 1, {5, 10, 25}, true, false, false}, + SamplingHeterogeneousPostProcessing_Usecase{32, 16, 4, {5, 10, 25}, true, false, false}, + SamplingHeterogeneousPostProcessing_Usecase{32, 16, 1, {5, 10, 25}, true, false, true}, + SamplingHeterogeneousPostProcessing_Usecase{32, 16, 4, {5, 10, 25}, true, false, true}, + SamplingHeterogeneousPostProcessing_Usecase{32, 16, 1, {5, 10, 25}, true, true, false}, + SamplingHeterogeneousPostProcessing_Usecase{32, 16, 4, {5, 10, 25}, true, true, false}, + SamplingHeterogeneousPostProcessing_Usecase{32, 16, 1, {5, 10, 25}, true, true, true}, + SamplingHeterogeneousPostProcessing_Usecase{32, 16, 4, {5, 10, 25}, true, true, true}), + ::testing::Values(cugraph::test::Rmat_Usecase(10, 16, 0.57, 0.19, 0.19, 0, false, false)))); + +INSTANTIATE_TEST_SUITE_P( + rmat_benchmark_test, + Tests_SamplingHeterogeneousPostProcessing_Rmat, + ::testing::Combine( + // disable correctness checks for large graphs + ::testing::Values( + SamplingHeterogeneousPostProcessing_Usecase{1, 64, 1, {10}, false, false, false, false}, + SamplingHeterogeneousPostProcessing_Usecase{1, 64, 16, {10}, false, false, false, false}, + SamplingHeterogeneousPostProcessing_Usecase{1, 64, 1, {10}, false, false, true, false}, + SamplingHeterogeneousPostProcessing_Usecase{1, 64, 16, {10}, false, false, true, false}, + SamplingHeterogeneousPostProcessing_Usecase{1, 64, 1, {10}, false, true, false, false}, + SamplingHeterogeneousPostProcessing_Usecase{1, 64, 16, {10}, false, true, false, false}, + SamplingHeterogeneousPostProcessing_Usecase{1, 64, 1, {10}, false, true, true, false}, + SamplingHeterogeneousPostProcessing_Usecase{1, 64, 16, {10}, false, true, true, false}, + SamplingHeterogeneousPostProcessing_Usecase{1, 64, 1, {10}, true, false, false, false}, +
SamplingHeterogeneousPostProcessing_Usecase{1, 64, 16, {10}, true, false, false, false}, + SamplingHeterogeneousPostProcessing_Usecase{1, 64, 1, {10}, true, false, true, false}, + SamplingHeterogeneousPostProcessing_Usecase{1, 64, 16, {10}, true, false, true, false}, + SamplingHeterogeneousPostProcessing_Usecase{1, 64, 1, {10}, true, true, false, false}, + SamplingHeterogeneousPostProcessing_Usecase{1, 64, 16, {10}, true, true, false, false}, + SamplingHeterogeneousPostProcessing_Usecase{1, 64, 1, {10}, true, true, true, false}, + SamplingHeterogeneousPostProcessing_Usecase{1, 64, 16, {10}, true, true, true, false}, + SamplingHeterogeneousPostProcessing_Usecase{ + 1, 64, 1, {5, 10, 15}, false, false, false, false}, + SamplingHeterogeneousPostProcessing_Usecase{ + 1, 64, 16, {5, 10, 15}, false, false, false, false}, + SamplingHeterogeneousPostProcessing_Usecase{1, 64, 1, {5, 10, 15}, false, false, true, false}, + SamplingHeterogeneousPostProcessing_Usecase{ + 1, 64, 16, {5, 10, 15}, false, false, true, false}, + SamplingHeterogeneousPostProcessing_Usecase{1, 64, 1, {5, 10, 15}, false, true, false, false}, + SamplingHeterogeneousPostProcessing_Usecase{ + 1, 64, 16, {5, 10, 15}, false, true, false, false}, + SamplingHeterogeneousPostProcessing_Usecase{1, 64, 1, {5, 10, 15}, false, true, true, false}, + SamplingHeterogeneousPostProcessing_Usecase{1, 64, 16, {5, 10, 15}, false, true, true, false}, + SamplingHeterogeneousPostProcessing_Usecase{1, 64, 1, {5, 10, 15}, true, false, false, false}, + SamplingHeterogeneousPostProcessing_Usecase{ + 1, 64, 16, {5, 10, 15}, true, false, false, false}, + SamplingHeterogeneousPostProcessing_Usecase{1, 64, 1, {5, 10, 15}, true, false, true, false}, + SamplingHeterogeneousPostProcessing_Usecase{1, 64, 16, {5, 10, 15}, true, false, true, false}, + SamplingHeterogeneousPostProcessing_Usecase{1, 64, 1, {5, 10, 15}, true, true, false, false}, + SamplingHeterogeneousPostProcessing_Usecase{1, 64, 16, {5, 10, 15}, true, true, false, false}, 
+ SamplingHeterogeneousPostProcessing_Usecase{1, 64, 1, {5, 10, 15}, true, true, true, false}, + SamplingHeterogeneousPostProcessing_Usecase{1, 64, 16, {5, 10, 15}, true, true, true, false}, + SamplingHeterogeneousPostProcessing_Usecase{128, 64, 1, {10}, false, false, false, false}, + SamplingHeterogeneousPostProcessing_Usecase{128, 64, 16, {10}, false, false, false, false}, + SamplingHeterogeneousPostProcessing_Usecase{128, 64, 1, {10}, false, false, true, false}, + SamplingHeterogeneousPostProcessing_Usecase{128, 64, 16, {10}, false, false, true, false}, + SamplingHeterogeneousPostProcessing_Usecase{128, 64, 1, {10}, false, true, false, false}, + SamplingHeterogeneousPostProcessing_Usecase{128, 64, 16, {10}, false, true, false, false}, + SamplingHeterogeneousPostProcessing_Usecase{128, 64, 1, {10}, false, true, true, false}, + SamplingHeterogeneousPostProcessing_Usecase{128, 64, 16, {10}, false, true, true, false}, + SamplingHeterogeneousPostProcessing_Usecase{128, 64, 1, {10}, true, false, false, false}, + SamplingHeterogeneousPostProcessing_Usecase{128, 64, 16, {10}, true, false, false, false}, + SamplingHeterogeneousPostProcessing_Usecase{128, 64, 1, {10}, true, false, true, false}, + SamplingHeterogeneousPostProcessing_Usecase{128, 64, 16, {10}, true, false, true, false}, + SamplingHeterogeneousPostProcessing_Usecase{128, 64, 1, {10}, true, true, false, false}, + SamplingHeterogeneousPostProcessing_Usecase{128, 64, 16, {10}, true, true, false, false}, + SamplingHeterogeneousPostProcessing_Usecase{128, 64, 1, {10}, true, true, true, false}, + SamplingHeterogeneousPostProcessing_Usecase{128, 64, 16, {10}, true, true, true, false}, + SamplingHeterogeneousPostProcessing_Usecase{ + 128, 64, 1, {5, 10, 15}, false, false, false, false}, + SamplingHeterogeneousPostProcessing_Usecase{ + 128, 64, 16, {5, 10, 15}, false, false, false, false}, + SamplingHeterogeneousPostProcessing_Usecase{ + 128, 64, 1, {5, 10, 15}, false, false, true, false}, + 
SamplingHeterogeneousPostProcessing_Usecase{ + 128, 64, 16, {5, 10, 15}, false, false, true, false}, + SamplingHeterogeneousPostProcessing_Usecase{ + 128, 64, 1, {5, 10, 15}, false, true, false, false}, + SamplingHeterogeneousPostProcessing_Usecase{ + 128, 64, 16, {5, 10, 15}, false, true, false, false}, + SamplingHeterogeneousPostProcessing_Usecase{ + 128, 64, 1, {5, 10, 15}, false, true, true, false}, + SamplingHeterogeneousPostProcessing_Usecase{ + 128, 64, 16, {5, 10, 15}, false, true, true, false}, + SamplingHeterogeneousPostProcessing_Usecase{ + 128, 64, 1, {5, 10, 15}, true, false, false, false}, + SamplingHeterogeneousPostProcessing_Usecase{ + 128, 64, 16, {5, 10, 15}, true, false, false, false}, + SamplingHeterogeneousPostProcessing_Usecase{ + 128, 64, 1, {5, 10, 15}, true, false, true, false}, + SamplingHeterogeneousPostProcessing_Usecase{ + 128, 64, 16, {5, 10, 15}, true, false, true, false}, + SamplingHeterogeneousPostProcessing_Usecase{ + 128, 64, 1, {5, 10, 15}, true, true, false, false}, + SamplingHeterogeneousPostProcessing_Usecase{ + 128, 64, 16, {5, 10, 15}, true, true, false, false}, + SamplingHeterogeneousPostProcessing_Usecase{128, 64, 1, {5, 10, 15}, true, true, true, false}, + SamplingHeterogeneousPostProcessing_Usecase{ + 128, 64, 16, {5, 10, 15}, true, true, true, false}), + ::testing::Values(cugraph::test::Rmat_Usecase(20, 32, 0.57, 0.19, 0.19, 0, false, false)))); + +CUGRAPH_TEST_PROGRAM_MAIN() diff --git a/cpp/tests/sampling/sampling_post_processing_test.cu b/cpp/tests/sampling/sampling_post_processing_test.cpp similarity index 52% rename from cpp/tests/sampling/sampling_post_processing_test.cu rename to cpp/tests/sampling/sampling_post_processing_test.cpp index ecec1d0ed89..b262794d26d 100644 --- a/cpp/tests/sampling/sampling_post_processing_test.cu +++ b/cpp/tests/sampling/sampling_post_processing_test.cpp @@ -14,30 +14,17 @@ * limitations under the License. 
*/ +#include "detail/sampling_post_processing_validate.hpp" #include "utilities/base_fixture.hpp" -#include -#include #include #include -#include #include #include #include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - #include struct SamplingPostProcessing_Usecase { @@ -53,385 +40,6 @@ struct SamplingPostProcessing_Usecase { bool check_correctness{true}; }; -template -bool compare_edgelist(raft::handle_t const& handle, - raft::device_span org_edgelist_srcs, - raft::device_span org_edgelist_dsts, - std::optional> org_edgelist_weights, - raft::device_span renumbered_edgelist_srcs, - raft::device_span renumbered_edgelist_dsts, - std::optional> renumbered_edgelist_weights, - std::optional> renumber_map) -{ - if (org_edgelist_srcs.size() != renumbered_edgelist_srcs.size()) { return false; } - - rmm::device_uvector sorted_org_edgelist_srcs(org_edgelist_srcs.size(), - handle.get_stream()); - thrust::copy(handle.get_thrust_policy(), - org_edgelist_srcs.begin(), - org_edgelist_srcs.end(), - sorted_org_edgelist_srcs.begin()); - rmm::device_uvector sorted_org_edgelist_dsts(org_edgelist_dsts.size(), - handle.get_stream()); - thrust::copy(handle.get_thrust_policy(), - org_edgelist_dsts.begin(), - org_edgelist_dsts.end(), - sorted_org_edgelist_dsts.begin()); - auto sorted_org_edgelist_weights = org_edgelist_weights - ? 
std::make_optional>( - (*org_edgelist_weights).size(), handle.get_stream()) - : std::nullopt; - if (sorted_org_edgelist_weights) { - thrust::copy(handle.get_thrust_policy(), - (*org_edgelist_weights).begin(), - (*org_edgelist_weights).end(), - (*sorted_org_edgelist_weights).begin()); - } - - if (sorted_org_edgelist_weights) { - auto sorted_org_edge_first = thrust::make_zip_iterator(sorted_org_edgelist_srcs.begin(), - sorted_org_edgelist_dsts.begin(), - (*sorted_org_edgelist_weights).begin()); - thrust::sort(handle.get_thrust_policy(), - sorted_org_edge_first, - sorted_org_edge_first + sorted_org_edgelist_srcs.size()); - } else { - auto sorted_org_edge_first = - thrust::make_zip_iterator(sorted_org_edgelist_srcs.begin(), sorted_org_edgelist_dsts.begin()); - thrust::sort(handle.get_thrust_policy(), - sorted_org_edge_first, - sorted_org_edge_first + sorted_org_edgelist_srcs.size()); - } - - rmm::device_uvector sorted_unrenumbered_edgelist_srcs(renumbered_edgelist_srcs.size(), - handle.get_stream()); - thrust::copy(handle.get_thrust_policy(), - renumbered_edgelist_srcs.begin(), - renumbered_edgelist_srcs.end(), - sorted_unrenumbered_edgelist_srcs.begin()); - rmm::device_uvector sorted_unrenumbered_edgelist_dsts(renumbered_edgelist_dsts.size(), - handle.get_stream()); - thrust::copy(handle.get_thrust_policy(), - renumbered_edgelist_dsts.begin(), - renumbered_edgelist_dsts.end(), - sorted_unrenumbered_edgelist_dsts.begin()); - auto sorted_unrenumbered_edgelist_weights = - renumbered_edgelist_weights ? 
std::make_optional>( - (*renumbered_edgelist_weights).size(), handle.get_stream()) - : std::nullopt; - if (sorted_unrenumbered_edgelist_weights) { - thrust::copy(handle.get_thrust_policy(), - (*renumbered_edgelist_weights).begin(), - (*renumbered_edgelist_weights).end(), - (*sorted_unrenumbered_edgelist_weights).begin()); - } - - if (renumber_map) { - cugraph::unrenumber_int_vertices( - handle, - sorted_unrenumbered_edgelist_srcs.data(), - sorted_unrenumbered_edgelist_srcs.size(), - (*renumber_map).data(), - std::vector{static_cast((*renumber_map).size())}); - cugraph::unrenumber_int_vertices( - handle, - sorted_unrenumbered_edgelist_dsts.data(), - sorted_unrenumbered_edgelist_dsts.size(), - (*renumber_map).data(), - std::vector{static_cast((*renumber_map).size())}); - } - - if (sorted_unrenumbered_edgelist_weights) { - auto sorted_unrenumbered_edge_first = - thrust::make_zip_iterator(sorted_unrenumbered_edgelist_srcs.begin(), - sorted_unrenumbered_edgelist_dsts.begin(), - (*sorted_unrenumbered_edgelist_weights).begin()); - thrust::sort(handle.get_thrust_policy(), - sorted_unrenumbered_edge_first, - sorted_unrenumbered_edge_first + sorted_unrenumbered_edgelist_srcs.size()); - - auto sorted_org_edge_first = thrust::make_zip_iterator(sorted_org_edgelist_srcs.begin(), - sorted_org_edgelist_dsts.begin(), - (*sorted_org_edgelist_weights).begin()); - return thrust::equal(handle.get_thrust_policy(), - sorted_org_edge_first, - sorted_org_edge_first + sorted_org_edgelist_srcs.size(), - sorted_unrenumbered_edge_first); - } else { - auto sorted_unrenumbered_edge_first = thrust::make_zip_iterator( - sorted_unrenumbered_edgelist_srcs.begin(), sorted_unrenumbered_edgelist_dsts.begin()); - thrust::sort(handle.get_thrust_policy(), - sorted_unrenumbered_edge_first, - sorted_unrenumbered_edge_first + sorted_unrenumbered_edgelist_srcs.size()); - - auto sorted_org_edge_first = - thrust::make_zip_iterator(sorted_org_edgelist_srcs.begin(), sorted_org_edgelist_dsts.begin()); - return 
thrust::equal(handle.get_thrust_policy(), - sorted_org_edge_first, - sorted_org_edge_first + sorted_org_edgelist_srcs.size(), - sorted_unrenumbered_edge_first); - } -} - -template -bool check_renumber_map_invariants( - raft::handle_t const& handle, - std::optional> starting_vertices, - raft::device_span org_edgelist_srcs, - raft::device_span org_edgelist_dsts, - std::optional> org_edgelist_hops, - raft::device_span renumber_map, - bool src_is_major) -{ - // Check the invariants in renumber_map - // Say we found the minimum (primary key:hop, secondary key:flag) pairs for every unique vertices, - // where flag is 0 for sources and 1 for destinations. Then, vertices with smaller (hop, flag) - // pairs should be renumbered to smaller numbers than vertices with larger (hop, flag) pairs. - auto org_edgelist_majors = src_is_major ? org_edgelist_srcs : org_edgelist_dsts; - auto org_edgelist_minors = src_is_major ? org_edgelist_dsts : org_edgelist_srcs; - - rmm::device_uvector unique_majors(org_edgelist_majors.size(), handle.get_stream()); - thrust::copy(handle.get_thrust_policy(), - org_edgelist_majors.begin(), - org_edgelist_majors.end(), - unique_majors.begin()); - if (starting_vertices) { - auto old_size = unique_majors.size(); - unique_majors.resize(old_size + (*starting_vertices).size(), handle.get_stream()); - thrust::copy(handle.get_thrust_policy(), - (*starting_vertices).begin(), - (*starting_vertices).end(), - unique_majors.begin() + old_size); - } - - std::optional> unique_major_hops = - org_edgelist_hops ? 
std::make_optional>( - (*org_edgelist_hops).size(), handle.get_stream()) - : std::nullopt; - if (org_edgelist_hops) { - thrust::copy(handle.get_thrust_policy(), - (*org_edgelist_hops).begin(), - (*org_edgelist_hops).end(), - (*unique_major_hops).begin()); - if (starting_vertices) { - auto old_size = (*unique_major_hops).size(); - (*unique_major_hops).resize(old_size + (*starting_vertices).size(), handle.get_stream()); - thrust::fill(handle.get_thrust_policy(), - (*unique_major_hops).begin() + old_size, - (*unique_major_hops).end(), - int32_t{0}); - } - - auto pair_first = - thrust::make_zip_iterator(unique_majors.begin(), (*unique_major_hops).begin()); - thrust::sort(handle.get_thrust_policy(), pair_first, pair_first + unique_majors.size()); - unique_majors.resize( - thrust::distance(unique_majors.begin(), - thrust::get<0>(thrust::unique_by_key(handle.get_thrust_policy(), - unique_majors.begin(), - unique_majors.end(), - (*unique_major_hops).begin()))), - handle.get_stream()); - (*unique_major_hops).resize(unique_majors.size(), handle.get_stream()); - } else { - thrust::sort(handle.get_thrust_policy(), unique_majors.begin(), unique_majors.end()); - unique_majors.resize( - thrust::distance( - unique_majors.begin(), - thrust::unique(handle.get_thrust_policy(), unique_majors.begin(), unique_majors.end())), - handle.get_stream()); - } - - rmm::device_uvector unique_minors(org_edgelist_minors.size(), handle.get_stream()); - thrust::copy(handle.get_thrust_policy(), - org_edgelist_minors.begin(), - org_edgelist_minors.end(), - unique_minors.begin()); - std::optional> unique_minor_hops = - org_edgelist_hops ? 
std::make_optional>( - (*org_edgelist_hops).size(), handle.get_stream()) - : std::nullopt; - if (org_edgelist_hops) { - thrust::copy(handle.get_thrust_policy(), - (*org_edgelist_hops).begin(), - (*org_edgelist_hops).end(), - (*unique_minor_hops).begin()); - - auto pair_first = - thrust::make_zip_iterator(unique_minors.begin(), (*unique_minor_hops).begin()); - thrust::sort(handle.get_thrust_policy(), pair_first, pair_first + unique_minors.size()); - unique_minors.resize( - thrust::distance(unique_minors.begin(), - thrust::get<0>(thrust::unique_by_key(handle.get_thrust_policy(), - unique_minors.begin(), - unique_minors.end(), - (*unique_minor_hops).begin()))), - handle.get_stream()); - (*unique_minor_hops).resize(unique_minors.size(), handle.get_stream()); - } else { - thrust::sort(handle.get_thrust_policy(), unique_minors.begin(), unique_minors.end()); - unique_minors.resize( - thrust::distance( - unique_minors.begin(), - thrust::unique(handle.get_thrust_policy(), unique_minors.begin(), unique_minors.end())), - handle.get_stream()); - } - - rmm::device_uvector sorted_org_vertices(renumber_map.size(), handle.get_stream()); - rmm::device_uvector matching_renumbered_vertices(sorted_org_vertices.size(), - handle.get_stream()); - thrust::copy(handle.get_thrust_policy(), - renumber_map.begin(), - renumber_map.end(), - sorted_org_vertices.begin()); - thrust::sequence(handle.get_thrust_policy(), - matching_renumbered_vertices.begin(), - matching_renumbered_vertices.end(), - vertex_t{0}); - thrust::sort_by_key(handle.get_thrust_policy(), - sorted_org_vertices.begin(), - sorted_org_vertices.end(), - matching_renumbered_vertices.begin()); - - if (org_edgelist_hops) { - rmm::device_uvector merged_vertices(unique_majors.size() + unique_minors.size(), - handle.get_stream()); - rmm::device_uvector merged_hops(merged_vertices.size(), handle.get_stream()); - rmm::device_uvector merged_flags(merged_vertices.size(), handle.get_stream()); - - auto major_triplet_first = 
thrust::make_zip_iterator(unique_majors.begin(), - (*unique_major_hops).begin(), - thrust::make_constant_iterator(int8_t{0})); - auto minor_triplet_first = thrust::make_zip_iterator(unique_minors.begin(), - (*unique_minor_hops).begin(), - thrust::make_constant_iterator(int8_t{1})); - thrust::merge(handle.get_thrust_policy(), - major_triplet_first, - major_triplet_first + unique_majors.size(), - minor_triplet_first, - minor_triplet_first + unique_minors.size(), - thrust::make_zip_iterator( - merged_vertices.begin(), merged_hops.begin(), merged_flags.begin())); - merged_vertices.resize( - thrust::distance(merged_vertices.begin(), - thrust::get<0>(thrust::unique_by_key( - handle.get_thrust_policy(), - merged_vertices.begin(), - merged_vertices.end(), - thrust::make_zip_iterator(merged_hops.begin(), merged_flags.begin())))), - handle.get_stream()); - merged_hops.resize(merged_vertices.size(), handle.get_stream()); - merged_flags.resize(merged_vertices.size(), handle.get_stream()); - - auto sort_key_first = thrust::make_zip_iterator(merged_hops.begin(), merged_flags.begin()); - thrust::sort_by_key(handle.get_thrust_policy(), - sort_key_first, - sort_key_first + merged_hops.size(), - merged_vertices.begin()); - - auto num_unique_keys = thrust::count_if( - handle.get_thrust_policy(), - thrust::make_counting_iterator(size_t{0}), - thrust::make_counting_iterator(merged_hops.size()), - cugraph::detail::is_first_in_run_t{sort_key_first}); - rmm::device_uvector min_vertices(num_unique_keys, handle.get_stream()); - rmm::device_uvector max_vertices(num_unique_keys, handle.get_stream()); - - auto renumbered_merged_vertex_first = thrust::make_transform_iterator( - merged_vertices.begin(), - cuda::proclaim_return_type( - [sorted_org_vertices = raft::device_span(sorted_org_vertices.data(), - sorted_org_vertices.size()), - matching_renumbered_vertices = raft::device_span( - matching_renumbered_vertices.data(), - matching_renumbered_vertices.size())] __device__(vertex_t major) { - 
auto it = thrust::lower_bound( - thrust::seq, sorted_org_vertices.begin(), sorted_org_vertices.end(), major); - return matching_renumbered_vertices[thrust::distance(sorted_org_vertices.begin(), it)]; - })); - - thrust::reduce_by_key(handle.get_thrust_policy(), - sort_key_first, - sort_key_first + merged_hops.size(), - renumbered_merged_vertex_first, - thrust::make_discard_iterator(), - min_vertices.begin(), - thrust::equal_to>{}, - thrust::minimum{}); - thrust::reduce_by_key(handle.get_thrust_policy(), - sort_key_first, - sort_key_first + merged_hops.size(), - renumbered_merged_vertex_first, - thrust::make_discard_iterator(), - max_vertices.begin(), - thrust::equal_to>{}, - thrust::maximum{}); - - auto num_violations = thrust::count_if( - handle.get_thrust_policy(), - thrust::make_counting_iterator(size_t{1}), - thrust::make_counting_iterator(min_vertices.size()), - [min_vertices = raft::device_span(min_vertices.data(), min_vertices.size()), - max_vertices = raft::device_span(max_vertices.data(), - max_vertices.size())] __device__(size_t i) { - return min_vertices[i] <= max_vertices[i - 1]; - }); - - return (num_violations == 0); - } else { - unique_minors.resize( - thrust::distance( - unique_minors.begin(), - thrust::remove_if(handle.get_thrust_policy(), - unique_minors.begin(), - unique_minors.end(), - [sorted_unique_majors = raft::device_span( - unique_majors.data(), unique_majors.size())] __device__(auto minor) { - return thrust::binary_search(thrust::seq, - sorted_unique_majors.begin(), - sorted_unique_majors.end(), - minor); - })), - handle.get_stream()); - - auto max_major_renumbered_vertex = thrust::transform_reduce( - handle.get_thrust_policy(), - unique_majors.begin(), - unique_majors.end(), - cuda::proclaim_return_type( - [sorted_org_vertices = raft::device_span(sorted_org_vertices.data(), - sorted_org_vertices.size()), - matching_renumbered_vertices = raft::device_span( - matching_renumbered_vertices.data(), - matching_renumbered_vertices.size())] 
__device__(vertex_t major) -> vertex_t { - auto it = thrust::lower_bound( - thrust::seq, sorted_org_vertices.begin(), sorted_org_vertices.end(), major); - return matching_renumbered_vertices[thrust::distance(sorted_org_vertices.begin(), it)]; - }), - std::numeric_limits::lowest(), - thrust::maximum{}); - - auto min_minor_renumbered_vertex = thrust::transform_reduce( - handle.get_thrust_policy(), - unique_minors.begin(), - unique_minors.end(), - cuda::proclaim_return_type( - [sorted_org_vertices = raft::device_span(sorted_org_vertices.data(), - sorted_org_vertices.size()), - matching_renumbered_vertices = raft::device_span( - matching_renumbered_vertices.data(), - matching_renumbered_vertices.size())] __device__(vertex_t minor) -> vertex_t { - auto it = thrust::lower_bound( - thrust::seq, sorted_org_vertices.begin(), sorted_org_vertices.end(), minor); - return matching_renumbered_vertices[thrust::distance(sorted_org_vertices.begin(), it)]; - }), - std::numeric_limits::max(), - thrust::minimum{}); - - return (max_major_renumbered_vertex < min_minor_renumbered_vertex); - } -} - template class Tests_SamplingPostProcessing : public ::testing::TestWithParam> { @@ -450,7 +58,7 @@ class Tests_SamplingPostProcessing { using label_t = int32_t; using weight_t = float; - using edge_id_t = vertex_t; + using edge_id_t = edge_t; using edge_type_t = int32_t; bool constexpr store_transposed = false; @@ -462,6 +70,8 @@ class Tests_SamplingPostProcessing raft::handle_t handle{}; HighResTimer hr_timer{}; + // 1. create a graph + if (cugraph::test::g_perf) { RAFT_CUDA_TRY(cudaDeviceSynchronize()); // for consistent performance measurement hr_timer.start("Construct graph"); @@ -481,6 +91,8 @@ class Tests_SamplingPostProcessing auto edge_weight_view = edge_weights ? std::make_optional((*edge_weights).view()) : std::nullopt; + // 2. 
seed vertices (& labels) + raft::random::RngState rng_state(0); rmm::device_uvector starting_vertices( @@ -503,20 +115,22 @@ class Tests_SamplingPostProcessing sampling_post_processing_usecase.num_labels + 1, handle.get_stream()) : std::nullopt; if (starting_vertex_labels) { - thrust::tabulate( - handle.get_thrust_policy(), - (*starting_vertex_labels).begin(), - (*starting_vertex_labels).end(), - [num_seeds_per_label = sampling_post_processing_usecase.num_seeds_per_label] __device__( - size_t i) { return static_cast(i / num_seeds_per_label); }); - thrust::tabulate( - handle.get_thrust_policy(), - (*starting_vertex_label_offsets).begin(), - (*starting_vertex_label_offsets).end(), - [num_seeds_per_label = sampling_post_processing_usecase.num_seeds_per_label] __device__( - size_t i) { return num_seeds_per_label * i; }); + auto num_seeds_per_label = sampling_post_processing_usecase.num_seeds_per_label; + for (size_t i = 0; i < sampling_post_processing_usecase.num_labels; ++i) { + cugraph::detail::scalar_fill(handle.get_stream(), + (*starting_vertex_labels).data() + i * num_seeds_per_label, + num_seeds_per_label, + static_cast(i)); + } + cugraph::detail::stride_fill(handle.get_stream(), + (*starting_vertex_label_offsets).data(), + (*starting_vertex_label_offsets).size(), + size_t{0}, + num_seeds_per_label); } + // 3. sampling + rmm::device_uvector org_edgelist_srcs(0, handle.get_stream()); rmm::device_uvector org_edgelist_dsts(0, handle.get_stream()); std::optional> org_edgelist_weights{std::nullopt}; @@ -562,6 +176,8 @@ class Tests_SamplingPostProcessing std::swap(org_edgelist_srcs, org_edgelist_dsts); } + // 4. 
post processing: renumber & sort + { rmm::device_uvector renumbered_and_sorted_edgelist_srcs(org_edgelist_srcs.size(), handle.get_stream()); @@ -652,178 +268,138 @@ class Tests_SamplingPostProcessing if (sampling_post_processing_usecase.check_correctness) { if (renumbered_and_sorted_edgelist_label_hop_offsets) { - ASSERT_TRUE((*renumbered_and_sorted_edgelist_label_hop_offsets).size() == - sampling_post_processing_usecase.num_labels * - sampling_post_processing_usecase.fanouts.size() + - 1) - << "Renumbered and sorted edge list (label,hop) offset array size should coincide with " - "the number of labels * the number of hops + 1."; - - ASSERT_TRUE(thrust::is_sorted(handle.get_thrust_policy(), - (*renumbered_and_sorted_edgelist_label_hop_offsets).begin(), - (*renumbered_and_sorted_edgelist_label_hop_offsets).end())) - << "Renumbered and sorted edge list (label,hop) offset array values should be " - "non-decreasing."; - - ASSERT_TRUE( - (*renumbered_and_sorted_edgelist_label_hop_offsets).back_element(handle.get_stream()) == - renumbered_and_sorted_edgelist_srcs.size()) - << "Renumbered and sorted edge list (label,hop) offset array's last element should " - "coincide with the number of edges."; + ASSERT_TRUE(check_offsets(handle, + raft::device_span( + (*renumbered_and_sorted_edgelist_label_hop_offsets).data(), + (*renumbered_and_sorted_edgelist_label_hop_offsets).size()), + sampling_post_processing_usecase.num_labels * + sampling_post_processing_usecase.fanouts.size(), + renumbered_and_sorted_edgelist_srcs.size())) + << "Renumbered and sorted edge (label, hop) offset array is invalid."; } if (renumbered_and_sorted_renumber_map_label_offsets) { - ASSERT_TRUE((*renumbered_and_sorted_renumber_map_label_offsets).size() == - sampling_post_processing_usecase.num_labels + 1) - << "Renumbered and sorted offset (label, hop) offset array size should coincide with " - "the number of labels + 1."; - - ASSERT_TRUE(thrust::is_sorted(handle.get_thrust_policy(), - 
(*renumbered_and_sorted_renumber_map_label_offsets).begin(), - (*renumbered_and_sorted_renumber_map_label_offsets).end())) - << "Renumbered and sorted renumber map label offset array values should be " - "non-decreasing."; - - ASSERT_TRUE( - (*renumbered_and_sorted_renumber_map_label_offsets).back_element(handle.get_stream()) == - renumbered_and_sorted_renumber_map.size()) - << "Renumbered and sorted renumber map label offset array's last value should coincide " - "with the renumber map size."; + ASSERT_TRUE(check_offsets(handle, + raft::device_span( + (*renumbered_and_sorted_renumber_map_label_offsets).data(), + (*renumbered_and_sorted_renumber_map_label_offsets).size()), + sampling_post_processing_usecase.num_labels, + renumbered_and_sorted_renumber_map.size())) + << "Renumbered and sorted renumber map label offset array is invalid."; } - for (size_t i = 0; i < sampling_post_processing_usecase.num_labels; ++i) { - size_t starting_vertex_start_offset = - starting_vertex_label_offsets - ? (*starting_vertex_label_offsets).element(i, handle.get_stream()) - : size_t{0}; - size_t starting_vertex_end_offset = - starting_vertex_label_offsets - ? (*starting_vertex_label_offsets).element(i + 1, handle.get_stream()) - : starting_vertices.size(); + // check whether the edges are properly sorted + + auto renumbered_and_sorted_edgelist_majors = + sampling_post_processing_usecase.src_is_major + ? raft::device_span(renumbered_and_sorted_edgelist_srcs.data(), + renumbered_and_sorted_edgelist_srcs.size()) + : raft::device_span(renumbered_and_sorted_edgelist_dsts.data(), + renumbered_and_sorted_edgelist_dsts.size()); + auto renumbered_and_sorted_edgelist_minors = + sampling_post_processing_usecase.src_is_major + ? 
raft::device_span(renumbered_and_sorted_edgelist_dsts.data(), + renumbered_and_sorted_edgelist_dsts.size()) + : raft::device_span(renumbered_and_sorted_edgelist_srcs.data(), + renumbered_and_sorted_edgelist_srcs.size()); - size_t edgelist_start_offset = - org_edgelist_label_offsets - ? (*org_edgelist_label_offsets).element(i, handle.get_stream()) - : size_t{0}; - size_t edgelist_end_offset = - org_edgelist_label_offsets - ? (*org_edgelist_label_offsets).element(i + 1, handle.get_stream()) - : org_edgelist_srcs.size(); - if (edgelist_start_offset == edgelist_end_offset) continue; - - auto this_label_starting_vertices = raft::device_span( - starting_vertices.data() + starting_vertex_start_offset, - starting_vertex_end_offset - starting_vertex_start_offset); - - auto this_label_org_edgelist_srcs = - raft::device_span(org_edgelist_srcs.data() + edgelist_start_offset, - edgelist_end_offset - edgelist_start_offset); - auto this_label_org_edgelist_dsts = - raft::device_span(org_edgelist_dsts.data() + edgelist_start_offset, - edgelist_end_offset - edgelist_start_offset); - auto this_label_org_edgelist_hops = - org_edgelist_hops ? std::make_optional>( - (*org_edgelist_hops).data() + edgelist_start_offset, - edgelist_end_offset - edgelist_start_offset) - : std::nullopt; - auto this_label_org_edgelist_weights = - org_edgelist_weights ? std::make_optional>( - (*org_edgelist_weights).data() + edgelist_start_offset, - edgelist_end_offset - edgelist_start_offset) - : std::nullopt; - - auto this_label_output_edgelist_srcs = raft::device_span( - renumbered_and_sorted_edgelist_srcs.data() + edgelist_start_offset, - edgelist_end_offset - edgelist_start_offset); - auto this_label_output_edgelist_dsts = raft::device_span( - renumbered_and_sorted_edgelist_dsts.data() + edgelist_start_offset, - edgelist_end_offset - edgelist_start_offset); - auto this_label_output_edgelist_weights = - renumbered_and_sorted_edgelist_weights - ? 
std::make_optional>( - (*renumbered_and_sorted_edgelist_weights).data() + edgelist_start_offset, - edgelist_end_offset - edgelist_start_offset) - : std::nullopt; - - size_t renumber_map_start_offset = - renumbered_and_sorted_renumber_map_label_offsets - ? (*renumbered_and_sorted_renumber_map_label_offsets).element(i, handle.get_stream()) - : size_t{0}; - size_t renumber_map_end_offset = renumbered_and_sorted_renumber_map_label_offsets - ? (*renumbered_and_sorted_renumber_map_label_offsets) - .element(i + 1, handle.get_stream()) - : renumbered_and_sorted_renumber_map.size(); - auto this_label_output_renumber_map = raft::device_span( - renumbered_and_sorted_renumber_map.data() + renumber_map_start_offset, - renumber_map_end_offset - renumber_map_start_offset); - - // check whether the edges are properly sorted - - auto this_label_output_edgelist_majors = sampling_post_processing_usecase.src_is_major - ? this_label_output_edgelist_srcs - : this_label_output_edgelist_dsts; - auto this_label_output_edgelist_minors = sampling_post_processing_usecase.src_is_major - ? 
this_label_output_edgelist_dsts - : this_label_output_edgelist_srcs; - - if (this_label_org_edgelist_hops) { - auto num_hops = sampling_post_processing_usecase.fanouts.size(); - auto edge_first = thrust::make_zip_iterator(this_label_output_edgelist_majors.begin(), - this_label_output_edgelist_minors.begin()); - for (size_t j = 0; j < num_hops; ++j) { - auto hop_start_offset = (*renumbered_and_sorted_edgelist_label_hop_offsets) - .element(i * num_hops + j, handle.get_stream()) - - (*renumbered_and_sorted_edgelist_label_hop_offsets) - .element(i * num_hops, handle.get_stream()); - auto hop_end_offset = (*renumbered_and_sorted_edgelist_label_hop_offsets) - .element(i * num_hops + j + 1, handle.get_stream()) - - (*renumbered_and_sorted_edgelist_label_hop_offsets) - .element(i * num_hops, handle.get_stream()); - ASSERT_TRUE(thrust::is_sorted(handle.get_thrust_policy(), - edge_first + hop_start_offset, - edge_first + hop_end_offset)) - << "Renumbered and sorted output edges are not properly sorted."; - } - } else { - auto edge_first = thrust::make_zip_iterator(this_label_output_edgelist_majors.begin(), - this_label_output_edgelist_minors.begin()); - ASSERT_TRUE(thrust::is_sorted(handle.get_thrust_policy(), - edge_first, - edge_first + this_label_output_edgelist_majors.size())) - << "Renumbered and sorted output edges are not properly sorted."; + if (renumbered_and_sorted_edgelist_label_hop_offsets) { + for (size_t i = 0; i < sampling_post_processing_usecase.num_labels * + sampling_post_processing_usecase.fanouts.size(); + ++i) { + auto hop_start_offset = + (*renumbered_and_sorted_edgelist_label_hop_offsets).element(i, handle.get_stream()); + auto hop_end_offset = (*renumbered_and_sorted_edgelist_label_hop_offsets) + .element(i + 1, handle.get_stream()); + ASSERT_TRUE(check_edgelist_is_sorted( + handle, + raft::device_span( + renumbered_and_sorted_edgelist_majors.data() + hop_start_offset, + hop_end_offset - hop_start_offset), + raft::device_span( + 
renumbered_and_sorted_edgelist_minors.data() + hop_start_offset, + hop_end_offset - hop_start_offset))) + << "Renumbered and sorted edge list is not properly sorted."; } + } else { + ASSERT_TRUE(check_edgelist_is_sorted( + handle, + raft::device_span(renumbered_and_sorted_edgelist_majors.data(), + renumbered_and_sorted_edgelist_majors.size()), + raft::device_span(renumbered_and_sorted_edgelist_minors.data(), + renumbered_and_sorted_edgelist_minors.size()))) + << "Renumbered and sorted edge list is not properly sorted."; + } - // check whether renumbering recovers the original edge list - - ASSERT_TRUE(compare_edgelist(handle, - this_label_org_edgelist_srcs, - this_label_org_edgelist_dsts, - this_label_org_edgelist_weights, - this_label_output_edgelist_srcs, - this_label_output_edgelist_dsts, - this_label_output_edgelist_weights, - std::make_optional(this_label_output_renumber_map))) - << "Unrenumbering the renumbered and sorted edge list does not recover the original " - "edgelist."; + ASSERT_TRUE(compare_edgelist( + handle, + raft::device_span(org_edgelist_srcs.data(), org_edgelist_srcs.size()), + raft::device_span(org_edgelist_dsts.data(), org_edgelist_dsts.size()), + org_edgelist_weights ? std::make_optional>( + (*org_edgelist_weights).data(), (*org_edgelist_weights).size()) + : std::nullopt, + org_edgelist_label_offsets + ? std::make_optional>( + (*org_edgelist_label_offsets).data(), (*org_edgelist_label_offsets).size()) + : std::nullopt, + raft::device_span(renumbered_and_sorted_edgelist_srcs.data(), + renumbered_and_sorted_edgelist_srcs.size()), + raft::device_span(renumbered_and_sorted_edgelist_dsts.data(), + renumbered_and_sorted_edgelist_dsts.size()), + renumbered_and_sorted_edgelist_weights + ? 
std::make_optional>( + (*renumbered_and_sorted_edgelist_weights).data(), + (*renumbered_and_sorted_edgelist_weights).size()) + : std::nullopt, + std::make_optional>( + renumbered_and_sorted_renumber_map.data(), renumbered_and_sorted_renumber_map.size()), + renumbered_and_sorted_renumber_map_label_offsets + ? std::make_optional>( + (*renumbered_and_sorted_renumber_map_label_offsets).data(), + (*renumbered_and_sorted_renumber_map_label_offsets).size()) + : std::nullopt, + sampling_post_processing_usecase.num_labels)) + << "Unrenumbering the renumbered and sorted edge list does not recover the original " + "edgelist."; - // Check the invariants in renumber_map + // Check the invariants in renumber_map - ASSERT_TRUE(check_renumber_map_invariants( - handle, - sampling_post_processing_usecase.renumber_with_seeds - ? std::make_optional>( - this_label_starting_vertices.data(), this_label_starting_vertices.size()) - : std::nullopt, - this_label_org_edgelist_srcs, - this_label_org_edgelist_dsts, - this_label_org_edgelist_hops, - this_label_output_renumber_map, - sampling_post_processing_usecase.src_is_major)) - << "Renumbered and sorted output renumber map violates invariants."; - } + ASSERT_TRUE(check_vertex_renumber_map_invariants( + handle, + sampling_post_processing_usecase.renumber_with_seeds + ? std::make_optional>(starting_vertices.data(), + starting_vertices.size()) + : std::nullopt, + (sampling_post_processing_usecase.renumber_with_seeds && starting_vertex_label_offsets) + ? std::make_optional>( + (*starting_vertex_label_offsets).data(), (*starting_vertex_label_offsets).size()) + : std::nullopt, + raft::device_span(org_edgelist_srcs.data(), org_edgelist_srcs.size()), + raft::device_span(org_edgelist_dsts.data(), org_edgelist_dsts.size()), + org_edgelist_hops ? std::make_optional>( + (*org_edgelist_hops).data(), (*org_edgelist_hops).size()) + : std::nullopt, + org_edgelist_label_offsets + ? 
std::make_optional>( + (*org_edgelist_label_offsets).data(), (*org_edgelist_label_offsets).size()) + : std::nullopt, + raft::device_span(renumbered_and_sorted_renumber_map.data(), + renumbered_and_sorted_renumber_map.size()), + renumbered_and_sorted_renumber_map_label_offsets + ? std::make_optional>( + (*renumbered_and_sorted_renumber_map_label_offsets).data(), + (*renumbered_and_sorted_renumber_map_label_offsets).size()) + : std::nullopt, + std::nullopt, + sampling_post_processing_usecase.num_labels, + 1, + sampling_post_processing_usecase.src_is_major)) + << "Renumbered and sorted output renumber map violates invariants."; } } + // 5. post processing: renumber & compress + { rmm::device_uvector renumbered_and_compressed_edgelist_srcs( org_edgelist_srcs.size(), handle.get_stream()); @@ -921,126 +497,52 @@ class Tests_SamplingPostProcessing } if (sampling_post_processing_usecase.check_correctness) { - if (renumbered_and_compressed_nzd_vertices) { - ASSERT_TRUE(renumbered_and_compressed_offsets.size() == - (*renumbered_and_compressed_nzd_vertices).size() + 1) - << "Renumbered and compressed offset array size should coincide with the number of " - "non-zero-degree vertices + 1."; - } - - ASSERT_TRUE(thrust::is_sorted(handle.get_thrust_policy(), - renumbered_and_compressed_offsets.begin(), - renumbered_and_compressed_offsets.end())) - << "Renumbered and compressed offset array values should be non-decreasing."; - - ASSERT_TRUE(renumbered_and_compressed_offsets.back_element(handle.get_stream()) == - renumbered_and_compressed_edgelist_minors.size()) - << "Renumbered and compressed offset array's last value should coincide with the number " - "of " - "edges."; + ASSERT_TRUE(check_offsets( + handle, + raft::device_span(renumbered_and_compressed_offsets.data(), + renumbered_and_compressed_offsets.size()), + renumbered_and_compressed_nzd_vertices ? 
(*renumbered_and_compressed_nzd_vertices).size() + : renumbered_and_compressed_offsets.size() - 1, + renumbered_and_compressed_edgelist_minors.size())) + << "Renumbered and compressed offset array is invalid"; if (renumbered_and_compressed_offset_label_hop_offsets) { - ASSERT_TRUE((*renumbered_and_compressed_offset_label_hop_offsets).size() == - sampling_post_processing_usecase.num_labels * - sampling_post_processing_usecase.fanouts.size() + - 1) - << "Renumbered and compressed offset (label,hop) offset array size should coincide " - "with " - "the number of labels * the number of hops + 1."; - - ASSERT_TRUE( - thrust::is_sorted(handle.get_thrust_policy(), - (*renumbered_and_compressed_offset_label_hop_offsets).begin(), - (*renumbered_and_compressed_offset_label_hop_offsets).end())) - << "Renumbered and compressed offset (label,hop) offset array values should be " - "non-decreasing."; - - ASSERT_TRUE((*renumbered_and_compressed_offset_label_hop_offsets) - .back_element(handle.get_stream()) == - renumbered_and_compressed_offsets.size() - 1) - << "Renumbered and compressed offset (label,hop) offset array's last value should " - "coincide with the offset array size - 1."; + ASSERT_TRUE(check_offsets(handle, + raft::device_span( + (*renumbered_and_compressed_offset_label_hop_offsets).data(), + (*renumbered_and_compressed_offset_label_hop_offsets).size()), + sampling_post_processing_usecase.num_labels * + sampling_post_processing_usecase.fanouts.size(), + renumbered_and_compressed_offsets.size() - 1)) + << "Renumbered and compressed offset (label, hop) offset array is invalid"; } if (renumbered_and_compressed_renumber_map_label_offsets) { - ASSERT_TRUE((*renumbered_and_compressed_renumber_map_label_offsets).size() == - sampling_post_processing_usecase.num_labels + 1) - << "Renumbered and compressed offset (label, hop) offset array size should coincide " - "with " - "the number of labels + 1."; - ASSERT_TRUE( - thrust::is_sorted(handle.get_thrust_policy(), - 
(*renumbered_and_compressed_renumber_map_label_offsets).begin(), - (*renumbered_and_compressed_renumber_map_label_offsets).end())) - << "Renumbered and compressed renumber map label offset array values should be " - "non-decreasing."; - - ASSERT_TRUE((*renumbered_and_compressed_renumber_map_label_offsets) - .back_element(handle.get_stream()) == - renumbered_and_compressed_renumber_map.size()) - << "Renumbered and compressed renumber map label offset array's last value should " - "coincide with the renumber map size."; + check_offsets(handle, + raft::device_span( + (*renumbered_and_compressed_renumber_map_label_offsets).data(), + (*renumbered_and_compressed_renumber_map_label_offsets).size()), + sampling_post_processing_usecase.num_labels, + renumbered_and_compressed_renumber_map.size())) + << "Renumbered and compressed renumber map label offset array is invalid"; } - for (size_t i = 0; i < sampling_post_processing_usecase.num_labels; ++i) { - size_t starting_vertex_start_offset = - starting_vertex_label_offsets - ? (*starting_vertex_label_offsets).element(i, handle.get_stream()) - : size_t{0}; - size_t starting_vertex_end_offset = - starting_vertex_label_offsets - ? (*starting_vertex_label_offsets).element(i + 1, handle.get_stream()) - : starting_vertices.size(); - - size_t edgelist_start_offset = - org_edgelist_label_offsets - ? (*org_edgelist_label_offsets).element(i, handle.get_stream()) - : size_t{0}; - size_t edgelist_end_offset = - org_edgelist_label_offsets - ? 
(*org_edgelist_label_offsets).element(i + 1, handle.get_stream()) - : org_edgelist_srcs.size(); - if (edgelist_start_offset == edgelist_end_offset) continue; - - auto this_label_starting_vertices = raft::device_span( - starting_vertices.data() + starting_vertex_start_offset, - starting_vertex_end_offset - starting_vertex_start_offset); - - auto this_label_org_edgelist_srcs = - raft::device_span(org_edgelist_srcs.data() + edgelist_start_offset, - edgelist_end_offset - edgelist_start_offset); - auto this_label_org_edgelist_dsts = - raft::device_span(org_edgelist_dsts.data() + edgelist_start_offset, - edgelist_end_offset - edgelist_start_offset); - auto this_label_org_edgelist_hops = - org_edgelist_hops ? std::make_optional>( - (*org_edgelist_hops).data() + edgelist_start_offset, - edgelist_end_offset - edgelist_start_offset) - : std::nullopt; - auto this_label_org_edgelist_weights = - org_edgelist_weights ? std::make_optional>( - (*org_edgelist_weights).data() + edgelist_start_offset, - edgelist_end_offset - edgelist_start_offset) - : std::nullopt; - - rmm::device_uvector this_label_output_edgelist_srcs(0, handle.get_stream()); - rmm::device_uvector this_label_output_edgelist_dsts(0, handle.get_stream()); - auto this_label_output_edgelist_weights = - renumbered_and_compressed_edgelist_weights - ? 
std::make_optional>(0, handle.get_stream()) - : std::nullopt; - this_label_output_edgelist_srcs.reserve(edgelist_end_offset - edgelist_start_offset, - handle.get_stream()); - this_label_output_edgelist_dsts.reserve(edgelist_end_offset - edgelist_start_offset, - handle.get_stream()); - if (this_label_output_edgelist_weights) { - (*this_label_output_edgelist_weights) - .reserve(edgelist_end_offset - edgelist_start_offset, handle.get_stream()); - } - - // decompress + // check whether renumbering recovers the original edge list + + rmm::device_uvector output_edgelist_srcs(0, handle.get_stream()); + rmm::device_uvector output_edgelist_dsts(0, handle.get_stream()); + auto output_edgelist_weights = + renumbered_and_compressed_edgelist_weights + ? std::make_optional>(0, handle.get_stream()) + : std::nullopt; + output_edgelist_srcs.reserve(org_edgelist_srcs.size(), handle.get_stream()); + output_edgelist_dsts.reserve(org_edgelist_srcs.capacity(), handle.get_stream()); + if (output_edgelist_weights) { + (*output_edgelist_weights).reserve(org_edgelist_srcs.capacity(), handle.get_stream()); + } + for (size_t i = 0; i < sampling_post_processing_usecase.num_labels; ++i) { auto num_hops = sampling_post_processing_usecase.fanouts.size(); for (size_t j = 0; j < num_hops; ++j) { auto offset_start_offset = renumbered_and_compressed_offset_label_hop_offsets @@ -1069,108 +571,123 @@ class Tests_SamplingPostProcessing h_offsets.data(), d_offsets.data(), h_offsets.size(), handle.get_stream()); handle.sync_stream(); - auto old_size = this_label_output_edgelist_srcs.size(); - this_label_output_edgelist_srcs.resize(old_size + (h_offsets.back() - h_offsets[0]), - handle.get_stream()); - this_label_output_edgelist_dsts.resize(this_label_output_edgelist_srcs.size(), - handle.get_stream()); - if (this_label_output_edgelist_weights) { - (*this_label_output_edgelist_weights) - .resize(this_label_output_edgelist_srcs.size(), handle.get_stream()); + auto old_size = output_edgelist_srcs.size(); + 
output_edgelist_srcs.resize(old_size + (h_offsets.back() - h_offsets[0]), + handle.get_stream()); + output_edgelist_dsts.resize(output_edgelist_srcs.size(), handle.get_stream()); + if (output_edgelist_weights) { + (*output_edgelist_weights).resize(output_edgelist_srcs.size(), handle.get_stream()); + } + if (renumbered_and_compressed_nzd_vertices) { + cugraph::test::expand_hypersparse_offsets( + handle, + raft::device_span(d_offsets.data(), d_offsets.size()), + raft::device_span( + (*renumbered_and_compressed_nzd_vertices).data() + offset_start_offset, + (offset_end_offset - offset_start_offset) - 1), + raft::device_span( + (sampling_post_processing_usecase.src_is_major ? output_edgelist_srcs.data() + : output_edgelist_dsts.data()) + + old_size, + h_offsets.back() - h_offsets[0]), + h_offsets[0]); + } else { + cugraph::test::expand_sparse_offsets( + handle, + raft::device_span(d_offsets.data(), d_offsets.size()), + raft::device_span( + (sampling_post_processing_usecase.src_is_major ? output_edgelist_srcs.data() + : output_edgelist_dsts.data()) + + old_size, + h_offsets.back() - h_offsets[0]), + h_offsets[0], + base_v); } - thrust::transform( - handle.get_thrust_policy(), - thrust::make_counting_iterator(h_offsets[0]), - thrust::make_counting_iterator(h_offsets.back()), - (sampling_post_processing_usecase.src_is_major - ? this_label_output_edgelist_srcs.begin() - : this_label_output_edgelist_dsts.begin()) + + raft::copy( + (sampling_post_processing_usecase.src_is_major ? output_edgelist_dsts.begin() + : output_edgelist_srcs.begin()) + old_size, - cuda::proclaim_return_type( - [offsets = raft::device_span(d_offsets.data(), d_offsets.size()), - nzd_vertices = - renumbered_and_compressed_nzd_vertices - ? 
thrust::make_optional>( - (*renumbered_and_compressed_nzd_vertices).data() + offset_start_offset, - (offset_end_offset - offset_start_offset) - 1) - : thrust::nullopt, - base_v] __device__(size_t i) { - auto idx = static_cast(thrust::distance( - offsets.begin() + 1, - thrust::upper_bound(thrust::seq, offsets.begin() + 1, offsets.end(), i))); - if (nzd_vertices) { - return (*nzd_vertices)[idx]; - } else { - return base_v + static_cast(idx); - } - })); - thrust::copy(handle.get_thrust_policy(), - renumbered_and_compressed_edgelist_minors.begin() + h_offsets[0], - renumbered_and_compressed_edgelist_minors.begin() + h_offsets.back(), - (sampling_post_processing_usecase.src_is_major - ? this_label_output_edgelist_dsts.begin() - : this_label_output_edgelist_srcs.begin()) + - old_size); - if (this_label_output_edgelist_weights) { - thrust::copy(handle.get_thrust_policy(), - (*renumbered_and_compressed_edgelist_weights).begin() + h_offsets[0], - (*renumbered_and_compressed_edgelist_weights).begin() + h_offsets.back(), - (*this_label_output_edgelist_weights).begin() + old_size); + renumbered_and_compressed_edgelist_minors.begin() + h_offsets[0], + h_offsets.back() - h_offsets[0], + handle.get_stream()); + if (output_edgelist_weights) { + raft::copy((*output_edgelist_weights).begin() + old_size, + (*renumbered_and_compressed_edgelist_weights).begin() + h_offsets[0], + h_offsets.back() - h_offsets[0], + handle.get_stream()); } } - - size_t renumber_map_start_offset = - renumbered_and_compressed_renumber_map_label_offsets - ? (*renumbered_and_compressed_renumber_map_label_offsets) - .element(i, handle.get_stream()) - : size_t{0}; - size_t renumber_map_end_offset = - renumbered_and_compressed_renumber_map_label_offsets - ? 
(*renumbered_and_compressed_renumber_map_label_offsets) - .element(i + 1, handle.get_stream()) - : renumbered_and_compressed_renumber_map.size(); - auto this_label_output_renumber_map = raft::device_span( - renumbered_and_compressed_renumber_map.data() + renumber_map_start_offset, - renumber_map_end_offset - renumber_map_start_offset); - - // check whether renumbering recovers the original edge list - - ASSERT_TRUE(compare_edgelist( - handle, - this_label_org_edgelist_srcs, - this_label_org_edgelist_dsts, - this_label_org_edgelist_weights, - raft::device_span(this_label_output_edgelist_srcs.data(), - this_label_output_edgelist_srcs.size()), - raft::device_span(this_label_output_edgelist_dsts.data(), - this_label_output_edgelist_dsts.size()), - this_label_output_edgelist_weights - ? std::make_optional>( - (*this_label_output_edgelist_weights).data(), - (*this_label_output_edgelist_weights).size()) - : std::nullopt, - std::make_optional(this_label_output_renumber_map))) - << "Unrenumbering the renumbered and sorted edge list does not recover the original " - "edgelist."; - - // Check the invariants in renumber_map - - ASSERT_TRUE(check_renumber_map_invariants( - handle, - sampling_post_processing_usecase.renumber_with_seeds - ? std::make_optional>( - this_label_starting_vertices.data(), this_label_starting_vertices.size()) - : std::nullopt, - this_label_org_edgelist_srcs, - this_label_org_edgelist_dsts, - this_label_org_edgelist_hops, - this_label_output_renumber_map, - sampling_post_processing_usecase.src_is_major)) - << "Renumbered and sorted output renumber map violates invariants."; } + + ASSERT_TRUE(compare_edgelist( + handle, + raft::device_span(org_edgelist_srcs.data(), org_edgelist_srcs.size()), + raft::device_span(org_edgelist_dsts.data(), org_edgelist_dsts.size()), + org_edgelist_weights ? std::make_optional>( + (*org_edgelist_weights).data(), (*org_edgelist_weights).size()) + : std::nullopt, + org_edgelist_label_offsets + ? 
std::make_optional(raft::device_span( + (*org_edgelist_label_offsets).data(), (*org_edgelist_label_offsets).size())) + : std::nullopt, + raft::device_span(output_edgelist_srcs.data(), + output_edgelist_srcs.size()), + raft::device_span(output_edgelist_dsts.data(), + output_edgelist_dsts.size()), + output_edgelist_weights + ? std::make_optional>( + (*output_edgelist_weights).data(), (*output_edgelist_weights).size()) + : std::nullopt, + std::make_optional>( + renumbered_and_compressed_renumber_map.data(), + renumbered_and_compressed_renumber_map.size()), + renumbered_and_compressed_renumber_map_label_offsets + ? std::make_optional>( + (*renumbered_and_compressed_renumber_map_label_offsets).data(), + (*renumbered_and_compressed_renumber_map_label_offsets).size()) + : std::nullopt, + sampling_post_processing_usecase.num_labels)) + << "Unrenumbering the renumbered and sorted edge list does not recover the original " + "edgelist."; + + // Check the invariants in renumber_map + + ASSERT_TRUE(check_vertex_renumber_map_invariants( + handle, + sampling_post_processing_usecase.renumber_with_seeds + ? std::make_optional>(starting_vertices.data(), + starting_vertices.size()) + : std::nullopt, + (sampling_post_processing_usecase.renumber_with_seeds && starting_vertex_label_offsets) + ? std::make_optional>( + (*starting_vertex_label_offsets).data(), (*starting_vertex_label_offsets).size()) + : std::nullopt, + raft::device_span(org_edgelist_srcs.data(), org_edgelist_srcs.size()), + raft::device_span(org_edgelist_dsts.data(), org_edgelist_dsts.size()), + org_edgelist_hops ? std::make_optional>( + (*org_edgelist_hops).data(), (*org_edgelist_hops).size()) + : std::nullopt, + org_edgelist_label_offsets + ? 
std::make_optional(raft::device_span( + (*org_edgelist_label_offsets).data(), (*org_edgelist_label_offsets).size())) + : std::nullopt, + raft::device_span(renumbered_and_compressed_renumber_map.data(), + renumbered_and_compressed_renumber_map.size()), + renumbered_and_compressed_renumber_map_label_offsets + ? std::make_optional>( + (*renumbered_and_compressed_renumber_map_label_offsets).data(), + (*renumbered_and_compressed_renumber_map_label_offsets).size()) + : std::nullopt, + std::nullopt, + sampling_post_processing_usecase.num_labels, + 1, + sampling_post_processing_usecase.src_is_major)) + << "Renumbered and sorted output renumber map violates invariants."; } } + // 6. post processing: sort only + { rmm::device_uvector sorted_edgelist_srcs(org_edgelist_srcs.size(), handle.get_stream()); @@ -1245,25 +762,42 @@ class Tests_SamplingPostProcessing if (sampling_post_processing_usecase.check_correctness) { if (sorted_edgelist_label_hop_offsets) { - ASSERT_TRUE((*sorted_edgelist_label_hop_offsets).size() == - sampling_post_processing_usecase.num_labels * - sampling_post_processing_usecase.fanouts.size() + - 1) - << "Sorted edge list (label,hop) offset array size should coincide with " - "the number of labels * the number of hops + 1."; - - ASSERT_TRUE(thrust::is_sorted(handle.get_thrust_policy(), - (*sorted_edgelist_label_hop_offsets).begin(), - (*sorted_edgelist_label_hop_offsets).end())) - << "Sorted edge list (label,hop) offset array values should be " - "non-decreasing."; - - ASSERT_TRUE((*sorted_edgelist_label_hop_offsets).back_element(handle.get_stream()) == - sorted_edgelist_srcs.size()) - << "Sorted edge list (label,hop) offset array's last element should coincide with the " - "number of edges."; + ASSERT_TRUE(check_offsets( + handle, + raft::device_span((*sorted_edgelist_label_hop_offsets).data(), + (*sorted_edgelist_label_hop_offsets).size()), + sampling_post_processing_usecase.num_labels * + sampling_post_processing_usecase.fanouts.size(), + 
sorted_edgelist_srcs.size())) + << "Sorted edge list (label, hop) offset array is invalid."; } + // check whether renumbering recovers the original edge list + + ASSERT_TRUE(compare_edgelist( + handle, + raft::device_span(org_edgelist_srcs.data(), org_edgelist_srcs.size()), + raft::device_span(org_edgelist_dsts.data(), org_edgelist_dsts.size()), + org_edgelist_weights ? std::make_optional>( + (*org_edgelist_weights).data(), (*org_edgelist_weights).size()) + : std::nullopt, + org_edgelist_label_offsets + ? std::make_optional(raft::device_span( + (*org_edgelist_label_offsets).data(), (*org_edgelist_label_offsets).size())) + : std::nullopt, + raft::device_span(sorted_edgelist_srcs.data(), + sorted_edgelist_srcs.size()), + raft::device_span(sorted_edgelist_dsts.data(), + sorted_edgelist_dsts.size()), + sorted_edgelist_weights + ? std::make_optional>( + (*sorted_edgelist_weights).data(), (*sorted_edgelist_weights).size()) + : std::nullopt, + std::optional>{std::nullopt}, + std::optional>{std::nullopt}, + sampling_post_processing_usecase.num_labels)) + << "Sorted edge list does not coincide with the original edgelist."; + for (size_t i = 0; i < sampling_post_processing_usecase.num_labels; ++i) { size_t edgelist_start_offset = org_edgelist_label_offsets @@ -1314,9 +848,7 @@ class Tests_SamplingPostProcessing : this_label_output_edgelist_srcs; if (this_label_org_edgelist_hops) { - auto num_hops = sampling_post_processing_usecase.fanouts.size(); - auto edge_first = thrust::make_zip_iterator(this_label_output_edgelist_majors.begin(), - this_label_output_edgelist_minors.begin()); + auto num_hops = sampling_post_processing_usecase.fanouts.size(); for (size_t j = 0; j < num_hops; ++j) { auto hop_start_offset = (*sorted_edgelist_label_hop_offsets) @@ -1326,32 +858,25 @@ class Tests_SamplingPostProcessing (*sorted_edgelist_label_hop_offsets) .element(i * num_hops + j + 1, handle.get_stream()) - (*sorted_edgelist_label_hop_offsets).element(i * num_hops, handle.get_stream()); - 
ASSERT_TRUE(thrust::is_sorted(handle.get_thrust_policy(), - edge_first + hop_start_offset, - edge_first + hop_end_offset)) - << "Renumbered and sorted output edges are not properly sorted."; + ASSERT_TRUE(check_edgelist_is_sorted( + handle, + raft::device_span( + this_label_output_edgelist_majors.data() + hop_start_offset, + hop_end_offset - hop_start_offset), + raft::device_span( + this_label_output_edgelist_minors.data() + hop_start_offset, + hop_end_offset - hop_start_offset))) + << "Sorted edge list is not properly sorted."; } } else { - auto edge_first = thrust::make_zip_iterator(this_label_output_edgelist_majors.begin(), - this_label_output_edgelist_minors.begin()); - ASSERT_TRUE(thrust::is_sorted(handle.get_thrust_policy(), - edge_first, - edge_first + this_label_output_edgelist_majors.size())) - << "Renumbered and sorted output edges are not properly sorted."; + ASSERT_TRUE(check_edgelist_is_sorted( + handle, + raft::device_span(this_label_output_edgelist_majors.data(), + this_label_output_edgelist_majors.size()), + raft::device_span(this_label_output_edgelist_minors.data(), + this_label_output_edgelist_minors.size()))) + << "Sorted edge list is not properly sorted."; } - - // check whether renumbering recovers the original edge list - - ASSERT_TRUE( - compare_edgelist(handle, - this_label_org_edgelist_srcs, - this_label_org_edgelist_dsts, - this_label_org_edgelist_weights, - this_label_output_edgelist_srcs, - this_label_output_edgelist_dsts, - this_label_output_edgelist_weights, - std::optional>{std::nullopt})) - << "Sorted edge list does not coincide with the original edgelist."; } } } diff --git a/cpp/tests/utilities/property_generator_utilities.hpp b/cpp/tests/utilities/property_generator_utilities.hpp index 6bd22da1f75..f907501cc7c 100644 --- a/cpp/tests/utilities/property_generator_utilities.hpp +++ b/cpp/tests/utilities/property_generator_utilities.hpp @@ -34,6 +34,7 @@ template struct generate { private: using vertex_type = typename 
GraphViewType::vertex_type; + using edge_type_t = int32_t; using property_buffer_type = std::decay_t( size_t{0}, rmm::cuda_stream_view{}))>; @@ -62,6 +63,28 @@ struct generate { static cugraph::edge_property_t edge_property( raft::handle_t const& handle, GraphViewType const& graph_view, int32_t hash_bin_count); + + static cugraph::edge_property_t edge_property_by_src_dst_types( + raft::handle_t const& handle, + GraphViewType const& graph_view, + raft::device_span vertex_type_offsets, + int32_t hash_bin_count); + + // generate unique edge property values (in [0, # edges in the graph) if property_t is an integer + // type, this function requires std::numeric_limits::max() to be no smaller than the + // number of edges in the input graph). + static cugraph::edge_property_t unique_edge_property( + raft::handle_t const& handle, GraphViewType const& graph_view); + + // generate unique (edge property value, edge type) pairs (if property_t is an integral type, edge + // property values for each type are consecutive integers starting from 0, this function requires + // std::numeric_limits::max() to be no smaller than the number of edges in the input + // graph). 
+ static cugraph::edge_property_t unique_edge_property_per_type( + raft::handle_t const& handle, + GraphViewType const& graph_view, + cugraph::edge_property_view_t edge_type_view, + int32_t num_edge_types); }; } // namespace test diff --git a/cpp/tests/utilities/property_generator_utilities_impl.cuh b/cpp/tests/utilities/property_generator_utilities_impl.cuh index a46009f95e3..61a861b6670 100644 --- a/cpp/tests/utilities/property_generator_utilities_impl.cuh +++ b/cpp/tests/utilities/property_generator_utilities_impl.cuh @@ -26,6 +26,7 @@ #include +#include #include #include @@ -127,5 +128,102 @@ generate::edge_property(raft::handle_t const& handle, return output_property; } +template +cugraph::edge_property_t +generate::edge_property_by_src_dst_types( + raft::handle_t const& handle, + GraphViewType const& graph_view, + raft::device_span vertex_type_offsets, + int32_t hash_bin_count) +{ + auto output_property = cugraph::edge_property_t(handle, graph_view); + + cugraph::transform_e( + handle, + graph_view, + cugraph::edge_src_dummy_property_t{}.view(), + cugraph::edge_dst_dummy_property_t{}.view(), + cugraph::edge_dummy_property_t{}.view(), + [vertex_type_offsets, hash_bin_count] __device__(auto src, auto dst, auto, auto, auto) { + auto src_v_type = thrust::distance( + vertex_type_offsets.begin() + 1, + thrust::upper_bound( + thrust::seq, vertex_type_offsets.begin() + 1, vertex_type_offsets.end(), src)); + auto dst_v_type = thrust::distance( + vertex_type_offsets.begin() + 1, + thrust::upper_bound( + thrust::seq, vertex_type_offsets.begin() + 1, vertex_type_offsets.end(), dst)); + auto num_v_types = vertex_type_offsets.size() - 1; + return detail::make_property_value((src_v_type * num_v_types + dst_v_type) % + hash_bin_count); + }, + output_property.mutable_view()); + + return output_property; +} + +template +cugraph::edge_property_t +generate::unique_edge_property(raft::handle_t const& handle, + GraphViewType const& graph_view) +{ + auto output_property = 
cugraph::edge_property_t(handle, graph_view); + if constexpr (std::is_integral_v && !std::is_same_v) { + CUGRAPH_EXPECTS( + graph_view.compute_number_of_edges(handle) <= std::numeric_limits::max(), + "std::numeric_limits::max() is smaller than the number of edges."); + rmm::device_scalar counter(property_t{0}, handle.get_stream()); + cugraph::transform_e( + handle, + graph_view, + cugraph::edge_src_dummy_property_t{}.view(), + cugraph::edge_dst_dummy_property_t{}.view(), + cugraph::edge_dummy_property_t{}.view(), + [counter = counter.data()] __device__(auto, auto, auto, auto, auto) { + cuda::atomic_ref atomic_counter(*counter); + return atomic_counter.fetch_add(property_t{1}, cuda::std::memory_order_relaxed); + }, + output_property.mutable_view()); + if constexpr (GraphViewType::is_multi_gpu) { CUGRAPH_FAIL("unimplemented."); } + } else { + CUGRAPH_FAIL("unimplemented."); + } + return output_property; +} + +template +cugraph::edge_property_t +generate::unique_edge_property_per_type( + raft::handle_t const& handle, + GraphViewType const& graph_view, + cugraph::edge_property_view_t edge_type_view, + int32_t num_edge_types) +{ + auto output_property = cugraph::edge_property_t(handle, graph_view); + if constexpr (std::is_integral_v && !std::is_same_v) { + CUGRAPH_EXPECTS( + graph_view.compute_number_of_edges(handle) <= std::numeric_limits::max(), + "std::numeric_limits::max() is smaller than the number of edges."); + rmm::device_uvector counters(num_edge_types, handle.get_stream()); + thrust::fill(handle.get_thrust_policy(), counters.begin(), counters.end(), property_t{0}); + cugraph::transform_e( + handle, + graph_view, + cugraph::edge_src_dummy_property_t{}.view(), + cugraph::edge_dst_dummy_property_t{}.view(), + edge_type_view, + [counters = raft::device_span(counters.data(), counters.size())] __device__( + auto, auto, auto, auto, int32_t edge_type) { + cuda::atomic_ref atomic_counter(counters[edge_type]); + return atomic_counter.fetch_add(property_t{1}, 
cuda::std::memory_order_relaxed); + }, + output_property.mutable_view()); + if constexpr (GraphViewType::is_multi_gpu) { CUGRAPH_FAIL("unimplemented."); } + } else { + CUGRAPH_FAIL("unimplemented."); + } + return output_property; +} + } // namespace test } // namespace cugraph diff --git a/cpp/tests/utilities/thrust_wrapper.cu b/cpp/tests/utilities/thrust_wrapper.cu index 8d26ac1f2fe..ef1c4f831eb 100644 --- a/cpp/tests/utilities/thrust_wrapper.cu +++ b/cpp/tests/utilities/thrust_wrapper.cu @@ -16,11 +16,15 @@ #include "utilities/thrust_wrapper.hpp" +#include +#include + #include #include #include #include +#include #include #include #include @@ -477,5 +481,70 @@ template void populate_vertex_ids(raft::handle_t const& handle, rmm::device_uvector& d_vertices_v, int64_t vertex_id_offset); +template +void expand_sparse_offsets(raft::handle_t const& handle, + raft::device_span offsets, + raft::device_span indices, + offset_t base_offset, + idx_t base_idx) +{ + rmm::device_uvector tmp_offsets(offsets.size(), handle.get_stream()); + thrust::transform(handle.get_thrust_policy(), + offsets.begin(), + offsets.end(), + tmp_offsets.begin(), + cugraph::detail::shift_left_t{base_offset}); + auto tmp = cugraph::detail::expand_sparse_offsets( + raft::device_span(tmp_offsets.data(), tmp_offsets.size()), + base_idx, + handle.get_stream()); + thrust::copy(handle.get_thrust_policy(), tmp.begin(), tmp.end(), indices.begin()); +} + +template void expand_sparse_offsets(raft::handle_t const& handle, + raft::device_span offsets, + raft::device_span indices, + size_t base_offset, + int32_t base_idx); + +template void expand_sparse_offsets(raft::handle_t const& handle, + raft::device_span offsets, + raft::device_span indices, + size_t base_offset, + int64_t base_idx); + +template +void expand_hypersparse_offsets(raft::handle_t const& handle, + raft::device_span offsets, + raft::device_span nzd_indices, + raft::device_span indices, + offset_t base_offset) +{ + rmm::device_uvector 
tmp_offsets(offsets.size(), handle.get_stream()); + thrust::transform(handle.get_thrust_policy(), + offsets.begin(), + offsets.end(), + tmp_offsets.begin(), + cugraph::detail::shift_left_t{base_offset}); + auto tmp = cugraph::detail::expand_sparse_offsets( + raft::device_span(tmp_offsets.data(), tmp_offsets.size()), + idx_t{0}, + handle.get_stream()); + thrust::gather( + handle.get_thrust_policy(), tmp.begin(), tmp.end(), nzd_indices.begin(), indices.begin()); +} + +template void expand_hypersparse_offsets(raft::handle_t const& handle, + raft::device_span offsets, + raft::device_span nzd_indices, + raft::device_span indices, + size_t base_offset); + +template void expand_hypersparse_offsets(raft::handle_t const& handle, + raft::device_span offsets, + raft::device_span nzd_indices, + raft::device_span indices, + size_t base_offset); + } // namespace test } // namespace cugraph diff --git a/cpp/tests/utilities/thrust_wrapper.hpp b/cpp/tests/utilities/thrust_wrapper.hpp index cd8bc33308f..afdff33d80a 100644 --- a/cpp/tests/utilities/thrust_wrapper.hpp +++ b/cpp/tests/utilities/thrust_wrapper.hpp @@ -93,5 +93,19 @@ void populate_vertex_ids(raft::handle_t const& handle, rmm::device_uvector& d_vertices_v /* [INOUT] */, vertex_t vertex_id_offset); +template +void expand_sparse_offsets(raft::handle_t const& handle, + raft::device_span offsets, + raft::device_span indices, + offset_t base_offset, + idx_t base_idx); + +template +void expand_hypersparse_offsets(raft::handle_t const& handle, + raft::device_span offsets, + raft::device_span nzd_indices, + raft::device_span indices, + offset_t base_offset); + } // namespace test } // namespace cugraph