From c09db10bb1dfe01b681e7968ca3e2b791805142e Mon Sep 17 00:00:00 2001 From: Alex Barghi <105237337+alexbarghi-nv@users.noreply.github.com> Date: Thu, 11 Jan 2024 19:21:55 -0500 Subject: [PATCH 1/4] Sampling Performance Testing (#3584) Adds performance benchmarking scripts for testing MNMG cuGraph GNN workflows. This branch is the head branch for the cuGraph benchmarking effort. All work supporting the benchmarks should be merged into this branch. It will be merged into branch-24.02 once all features are ready. Includes patches to cuGraph-PyG required for the latest DLFW container. To-Do: - [x] Refactor for branch-24.02 - [x] Add WholeGraph training portion Deferred to future PR (see https://github.com/alexbarghi-nv/cugraph/pull/6) - [x] Add WholeGraph generators Included in above - [x] Support DGL Deferred to future PR - [x] Use appropriate docker containers Deferred, waiting on DLFW release Closes #3839 Authors: - Alex Barghi (https://github.com/alexbarghi-nv) - Seunghwa Kang (https://github.com/seunghwak) Approvers: - Vibhu Jawa (https://github.com/VibhuJawa) - Rick Ratzel (https://github.com/rlratzel) - Chuck Hastings (https://github.com/ChuckHastings) URL: https://github.com/rapidsai/cugraph/pull/3584 --- .../standalone/bulk_sampling/.gitignore | 1 + .../standalone/bulk_sampling/README.md | 70 ++- .../bulk_sampling/bench_cugraph_training.py | 251 ++++++++++ .../standalone/bulk_sampling/bulk_sampling.sh | 50 -- .../bulk_sampling/cugraph_bulk_sampling.py | 224 +++++---- .../bulk_sampling/datasets/__init__.py | 15 + .../bulk_sampling/datasets/dataset.py | 55 +++ .../bulk_sampling/datasets/ogbn_papers100M.py | 345 ++++++++++++++ .../bulk_sampling/models/__init__.py | 12 + .../bulk_sampling/models/pyg/__init__.py | 15 + .../models/pyg/models_cugraph_pyg.py | 78 ++++ .../bulk_sampling/models/pyg/models_pyg.py | 58 +++ .../standalone/bulk_sampling/run_sampling.sh | 111 +++++ .../standalone/bulk_sampling/run_train_job.sh | 84 ++++ .../bulk_sampling/trainers/__init__.py | 15 + .../bulk_sampling/trainers/pyg/__init__.py | 15 + .../trainers/pyg/trainers_cugraph_pyg.py | 184 ++++++++ .../trainers/pyg/trainers_pyg.py | 430 ++++++++++++++++++ .../bulk_sampling/trainers/trainer.py | 54 +++ cpp/src/community/flatten_dendrogram.hpp | 2 +- mg_utils/wait_for_workers.py | 124 +++++ .../cugraph_pyg/loader/cugraph_node_loader.py | 29 +- .../cugraph_pyg/sampler/cugraph_sampler.py | 3 +- .../cugraph_pyg/tests/test_cugraph_store.py | 14 +- .../cugraph/cugraph/experimental/__init__.py | 2 +- 25 files changed, 2053 insertions(+), 188 deletions(-) create mode 100644 benchmarks/cugraph/standalone/bulk_sampling/.gitignore create mode 100644 benchmarks/cugraph/standalone/bulk_sampling/bench_cugraph_training.py delete mode 100755 benchmarks/cugraph/standalone/bulk_sampling/bulk_sampling.sh create mode 100644 benchmarks/cugraph/standalone/bulk_sampling/datasets/__init__.py create mode 100644 benchmarks/cugraph/standalone/bulk_sampling/datasets/dataset.py create mode 100644 benchmarks/cugraph/standalone/bulk_sampling/datasets/ogbn_papers100M.py create mode 100644 benchmarks/cugraph/standalone/bulk_sampling/models/__init__.py create mode 100644 benchmarks/cugraph/standalone/bulk_sampling/models/pyg/__init__.py create mode 100644 benchmarks/cugraph/standalone/bulk_sampling/models/pyg/models_cugraph_pyg.py create mode 100644 benchmarks/cugraph/standalone/bulk_sampling/models/pyg/models_pyg.py create mode 100644 benchmarks/cugraph/standalone/bulk_sampling/run_sampling.sh create mode 100755 
benchmarks/cugraph/standalone/bulk_sampling/run_train_job.sh create mode 100644 benchmarks/cugraph/standalone/bulk_sampling/trainers/__init__.py create mode 100644 benchmarks/cugraph/standalone/bulk_sampling/trainers/pyg/__init__.py create mode 100644 benchmarks/cugraph/standalone/bulk_sampling/trainers/pyg/trainers_cugraph_pyg.py create mode 100644 benchmarks/cugraph/standalone/bulk_sampling/trainers/pyg/trainers_pyg.py create mode 100644 benchmarks/cugraph/standalone/bulk_sampling/trainers/trainer.py create mode 100644 mg_utils/wait_for_workers.py diff --git a/benchmarks/cugraph/standalone/bulk_sampling/.gitignore b/benchmarks/cugraph/standalone/bulk_sampling/.gitignore new file mode 100644 index 00000000000..19cbd00ebe0 --- /dev/null +++ b/benchmarks/cugraph/standalone/bulk_sampling/.gitignore @@ -0,0 +1 @@ +mg_utils/ diff --git a/benchmarks/cugraph/standalone/bulk_sampling/README.md b/benchmarks/cugraph/standalone/bulk_sampling/README.md index f48eea5c556..bb01133c52f 100644 --- a/benchmarks/cugraph/standalone/bulk_sampling/README.md +++ b/benchmarks/cugraph/standalone/bulk_sampling/README.md @@ -1,11 +1,13 @@ -# cuGraph Bulk Sampling +# cuGraph Sampling Benchmarks -## Overview +## cuGraph Bulk Sampling + +### Overview The `cugraph_bulk_sampling.py` script runs the bulk sampler for a variety of datasets, including both generated (rmat) datasets and disk (ogbn_papers100M, etc.) datasets. It can also load replicas of these datasets to create a larger benchmark (i.e. ogbn_papers100M x2). -## Arguments +### Arguments The script takes a variety of arguments to control sampling behavior. Required: --output_root @@ -51,14 +53,8 @@ Optional: Seed for random number generation. Defaults to '62' - --persist - Whether to aggressively use persist() in dask to make the ETL steps (NOT PART OF SAMPLING) faster. - Will probably make this script finish sooner at the expense of memory usage, but won't affect - sampling time. - Changing this is not recommended unless you know what you are doing. - Defaults to False. -## Input Format +### Input Format The script expects its input data in the following format: ``` @@ -103,7 +99,7 @@ the parquet files. It must have the following format: } ``` -## Output Meta +### Output Meta The script, in addition to the samples, will also output a file named `output_meta.json`. This file contains various statistics about the sampling run, including the runtime, as well as information about the dataset and system that the samples were produced from. @@ -111,6 +107,56 @@ as well as information about the dataset and system that the samples were produc This metadata file can be used to gather the results from the sampling and training stages together. -## Other Notes +### Other Notes For rmat datasets, you will need to generate your own bogus features in the training stage. Since that is trivial, that is not done in this sampling script. + +## cuGraph MNMG Training + +### Overview +The script `run_train_job.sh` runs with the `sbatch` command to launch a series of slurm jobs. +First, for a given number of epochs, the script will produce samples for a given graph. +Then, the training process starts where samples are loaded and training iterations are +processed. + +### Important Notes +Downloading the dataset files before running the slurm jobs is highly recommended. Even though +the script will attempt to download the files if they are not available, this can often +lead to a timeout which will crash the scripts. 
This applies whether you are training
+with native PyG or with cuGraph-PyG. You can download the data as follows:
+
+```
+from ogb.nodeproppred import NodePropPredDataset
+dataset = NodePropPredDataset('ogbn-papers100M', root='/home/username/datasets')
+```
+
+For datasets other than ogbn-papers100M, follow the same process, changing only the dataset name.
+The dataset will be preprocessed automatically when you run training. On a slow system, you can also
+run the preprocessing ahead of time by running the training script on a single worker, which avoids
+the timeout that would otherwise crash the script.
+
+The multi-GPU utilities are in `mg_utils` at the top level of the cuGraph repository. Either copy
+them into this directory or symlink to them before running the scripts.
+
+### Arguments
+You will need to modify the bash scripts to run appropriately for your environment and
+desired training workflow. The standard sbatch arguments (job name, queue, etc.) are at the top of
+the script and will need to be modified for your SLURM cluster.
+
+Next is the argument for the container image (required), followed by the directories where the data
+and outputs are stored. The directories default to subdirectories of the current working directory,
+but if a high-throughput storage system is available, using that storage for the samples and
+datasets is highly recommended.
+
+Next are standard GNN training arguments such as `FANOUT`, `BATCH_SIZE`, etc. You can also set
+the number of training epochs here. These are followed by the `REPLICATION_FACTOR` argument, which
+can be used to replicate the dataset for scale testing purposes.
+
+The final two arguments are `FRAMEWORK`, which can be either "cuGraphPyG" or "PyG", and `GPUS_PER_NODE`,
+which must be set to the correct value even if it is also provided by a SLURM argument. If `GPUS_PER_NODE`
+is not set to the correct number of GPUs, the script will hang until it times out. Differing GPU counts
+per node are currently unsupported by this script, although they should be possible in practice.
+
+### Output
+The results of training are written to the logs directory as an `output.txt` file for each worker.
+These files are overwritten on each run. Accuracy is only reported on rank 0.
\ No newline at end of file
diff --git a/benchmarks/cugraph/standalone/bulk_sampling/bench_cugraph_training.py b/benchmarks/cugraph/standalone/bulk_sampling/bench_cugraph_training.py
new file mode 100644
index 00000000000..c9e347b261d
--- /dev/null
+++ b/benchmarks/cugraph/standalone/bulk_sampling/bench_cugraph_training.py
@@ -0,0 +1,251 @@
+# Copyright (c) 2023-2024, NVIDIA CORPORATION.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
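+
+# Benchmark driver for MNMG GNN training. Each rank launched by torchrun reads
+# LOCAL_RANK/RANK from the environment, initializes an RMM pool and the NCCL
+# process group, loads the ogbn_papers100M dataset, and runs either the native
+# PyG trainer or the cuGraph-PyG trainer depending on --framework, writing its
+# collected stats to --output_file per rank.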
+ +import os + +os.environ["RAPIDS_NO_INITIALIZE"] = "1" +os.environ["CUDF_SPILL"] = "1" +os.environ["LIBCUDF_CUFILE_POLICY"] = "KVIKIO" +os.environ["KVIKIO_NTHREADS"] = "8" + +import argparse +import json +import warnings + +import torch +import numpy as np +import pandas + +import torch.distributed as dist + +from datasets import OGBNPapers100MDataset + +from cugraph.testing.mg_utils import enable_spilling + + +def init_pytorch_worker(rank: int, use_rmm_torch_allocator: bool = False) -> None: + import cupy + import rmm + from pynvml.smi import nvidia_smi + + smi = nvidia_smi.getInstance() + pool_size = 16e9 # FIXME calculate this + + rmm.reinitialize( + devices=[rank], + pool_allocator=True, + initial_pool_size=pool_size, + ) + + if use_rmm_torch_allocator: + warnings.warn( + "Using the rmm pytorch allocator is currently unsupported." + " The default allocator will be used instead." + ) + # FIXME somehow get the pytorch allocator to work + # from rmm.allocators.torch import rmm_torch_allocator + # torch.cuda.memory.change_current_allocator(rmm_torch_allocator) + + from rmm.allocators.cupy import rmm_cupy_allocator + + cupy.cuda.set_allocator(rmm_cupy_allocator) + + cupy.cuda.Device(rank).use() + torch.cuda.set_device(rank) + + # Pytorch training worker initialization + torch.distributed.init_process_group(backend="nccl") + + +def parse_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--gpus_per_node", + type=int, + default=8, + help="# GPUs per node", + required=False, + ) + + parser.add_argument( + "--num_epochs", + type=int, + default=1, + help="Number of training epochs", + required=False, + ) + + parser.add_argument( + "--batch_size", + type=int, + default=512, + help="Batch size", + required=False, + ) + + parser.add_argument( + "--fanout", + type=str, + default="10_10_10", + help="Fanout", + required=False, + ) + + parser.add_argument( + "--sample_dir", + type=str, + help="Directory with stored bulk samples (required for cuGraph run)", + required=False, + ) + + parser.add_argument( + "--output_file", + type=str, + help="File to store results", + required=True, + ) + + parser.add_argument( + "--framework", + type=str, + help="The framework to test (PyG, cuGraphPyG)", + required=True, + ) + + parser.add_argument( + "--model", + type=str, + default="GraphSAGE", + help="The model to use (currently only GraphSAGE supported)", + required=False, + ) + + parser.add_argument( + "--replication_factor", + type=int, + default=1, + help="The replication factor for the dataset", + required=False, + ) + + parser.add_argument( + "--dataset_dir", + type=str, + help="The directory where datasets are stored", + required=True, + ) + + parser.add_argument( + "--train_split", + type=float, + help="The percentage of the labeled data to use for training. 
The remainder is used for testing/validation.", + default=0.8, + required=False, + ) + + parser.add_argument( + "--val_split", + type=float, + help="The percentage of the testing/validation data to allocate for validation.", + default=0.5, + required=False, + ) + + return parser.parse_args() + + +def main(args): + import logging + + logging.basicConfig( + level=logging.INFO, + ) + logger = logging.getLogger("bench_cugraph_training") + logger.setLevel(logging.INFO) + + local_rank = int(os.environ["LOCAL_RANK"]) + global_rank = int(os.environ["RANK"]) + + init_pytorch_worker( + local_rank, use_rmm_torch_allocator=(args.framework == "cuGraph") + ) + enable_spilling() + print(f"worker initialized") + dist.barrier() + + world_size = int(os.environ["SLURM_JOB_NUM_NODES"]) * args.gpus_per_node + + dataset = OGBNPapers100MDataset( + replication_factor=args.replication_factor, + dataset_dir=args.dataset_dir, + train_split=args.train_split, + val_split=args.val_split, + load_edge_index=(args.framework == "PyG"), + ) + + if global_rank == 0: + dataset.download() + dist.barrier() + + fanout = [int(f) for f in args.fanout.split("_")] + + if args.framework == "PyG": + from trainers.pyg import PyGNativeTrainer + + trainer = PyGNativeTrainer( + model=args.model, + dataset=dataset, + device=local_rank, + rank=global_rank, + world_size=world_size, + num_epochs=args.num_epochs, + shuffle=True, + replace=False, + num_neighbors=fanout, + batch_size=args.batch_size, + ) + elif args.framework == "cuGraphPyG": + sample_dir = os.path.join( + args.sample_dir, + f"ogbn_papers100M[{args.replication_factor}]_b{args.batch_size}_f{fanout}", + ) + from trainers.pyg import PyGCuGraphTrainer + + trainer = PyGCuGraphTrainer( + model=args.model, + dataset=dataset, + sample_dir=sample_dir, + device=local_rank, + rank=global_rank, + world_size=world_size, + num_epochs=args.num_epochs, + shuffle=True, + replace=False, + num_neighbors=fanout, + batch_size=args.batch_size, + ) + else: + raise ValueError("unsupported framework") + + logger.info(f"Trainer ready on rank {global_rank}") + stats = trainer.train() + logger.info(stats) + + with open(f"{args.output_file}[{global_rank}]", "w") as f: + json.dump(stats, f) + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/benchmarks/cugraph/standalone/bulk_sampling/bulk_sampling.sh b/benchmarks/cugraph/standalone/bulk_sampling/bulk_sampling.sh deleted file mode 100755 index e62cb3cda29..00000000000 --- a/benchmarks/cugraph/standalone/bulk_sampling/bulk_sampling.sh +++ /dev/null @@ -1,50 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -export RAPIDS_NO_INITIALIZE="1" -export CUDF_SPILL="1" -export LIBCUDF_CUFILE_POLICY=OFF - - -dataset_name=$1 -dataset_root=$2 -output_root=$3 -batch_sizes=$4 -fanouts=$5 -reverse_edges=$6 - -rm -rf $output_root -mkdir -p $output_root - -# Change to 2 in Selene -gpu_per_replica=4 -#--add_edge_ids \ - -# Expand to 1, 4, 8 in Selene -for i in 1,2,3,4: -do - for replication in 2; - do - dataset_name_with_replication="${dataset_name}[${replication}]" - dask_worker_devices=$(seq -s, 0 $((gpu_per_replica*replication-1))) - echo "Sampling dataset = $dataset_name_with_replication on devices = $dask_worker_devices" - python3 cugraph_bulk_sampling.py --datasets $dataset_name_with_replication \ - --dataset_root $dataset_root \ - --batch_sizes $batch_sizes \ - --output_root $output_root \ - --dask_worker_devices $dask_worker_devices \ - --fanouts $fanouts \ - --batch_sizes $batch_sizes \ - --reverse_edges - done -done \ No newline at end of file diff --git a/benchmarks/cugraph/standalone/bulk_sampling/cugraph_bulk_sampling.py b/benchmarks/cugraph/standalone/bulk_sampling/cugraph_bulk_sampling.py index 9de6c3a2b01..e3a5bba3162 100644 --- a/benchmarks/cugraph/standalone/bulk_sampling/cugraph_bulk_sampling.py +++ b/benchmarks/cugraph/standalone/bulk_sampling/cugraph_bulk_sampling.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. +# Copyright (c) 2023-2024, NVIDIA CORPORATION. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -97,19 +97,15 @@ def symmetrize_ddf(dask_dataframe): return new_ddf -def renumber_ddf(dask_df, persist=False): +def renumber_ddf(dask_df): vertices = ( dask_cudf.concat([dask_df["src"], dask_df["dst"]]) .unique() .reset_index(drop=True) ) - if persist: - vertices = vertices.persist() vertices.name = "v" vertices = vertices.reset_index().set_index("v").rename(columns={"index": "m"}) - if persist: - vertices = vertices.persist() src = dask_df.merge(vertices, left_on="src", right_on="v", how="left").m.rename( "src" @@ -170,7 +166,7 @@ def _replicate_df( if replication_factor > 1: for r in range(1, replication_factor): - df_replicated = original_df + df_replicated = original_df.copy() for col, offset in col_item_counts.items(): df_replicated[col] += offset * r @@ -189,46 +185,75 @@ def sample_graph( seeds_per_call=400000, batches_per_partition=100, fanout=[5, 5, 5], + num_epochs=1, + train_perc=0.8, + val_perc=0.5, sampling_kwargs={}, ): cupy.random.seed(seed) - - sampler = BulkSampler( - batch_size=batch_size, - output_path=output_path, - graph=G, - fanout_vals=fanout, - with_replacement=False, - random_state=seed, - seeds_per_call=seeds_per_call, - batches_per_partition=batches_per_partition, - log_level=logging.INFO, - **sampling_kwargs, + train_df, test_df = label_df.random_split( + [train_perc, 1 - train_perc], random_state=seed, shuffle=True + ) + val_df, test_df = label_df.random_split( + [val_perc, 1 - val_perc], random_state=seed, shuffle=True ) - n_workers = len(default_client().scheduler_info()["workers"]) + total_time = 0.0 + for epoch in range(num_epochs): + steps = [("train", train_df), ("test", test_df)] + if epoch == num_epochs - 1: + steps.append(("val", val_df)) - meta = cudf.DataFrame( - {"node": cudf.Series(dtype="int64"), "batch": cudf.Series(dtype="int32")} - ) + for step, batch_df in steps: + batch_df = batch_df.sample(frac=1.0, random_state=seed) - batch_df = label_df.map_partitions( - _make_batch_ids, batch_size, 
n_workers, meta=meta - ) - # batch_df = batch_df.sort_values(by='node') + if step == "val": + output_sample_path = os.path.join(output_path, "val", "samples") + else: + output_sample_path = os.path.join( + output_path, f"epoch={epoch}", f"{step}", "samples" + ) + os.makedirs(output_sample_path) + + sampler = BulkSampler( + batch_size=batch_size, + output_path=output_sample_path, + graph=G, + fanout_vals=fanout, + with_replacement=False, + random_state=seed, + seeds_per_call=seeds_per_call, + batches_per_partition=batches_per_partition, + log_level=logging.INFO, + **sampling_kwargs, + ) - # should always persist the batch dataframe or performance may be suboptimal - batch_df = batch_df.persist() + n_workers = len(default_client().scheduler_info()["workers"]) - del label_df - print("created batches") + meta = cudf.DataFrame( + { + "node": cudf.Series(dtype="int64"), + "batch": cudf.Series(dtype="int32"), + } + ) + + batch_df = batch_df.map_partitions( + _make_batch_ids, batch_size, n_workers, meta=meta + ) + + # should always persist the batch dataframe or performance may be suboptimal + batch_df = batch_df.persist() + + print("created batches") - start_time = perf_counter() - sampler.add_batches(batch_df, start_col_name="node", batch_col_name="batch") - sampler.flush() - end_time = perf_counter() - print("flushed all batches") - return end_time - start_time + start_time = perf_counter() + sampler.add_batches(batch_df, start_col_name="node", batch_col_name="batch") + sampler.flush() + end_time = perf_counter() + print("flushed all batches") + total_time += end_time - start_time + + return total_time def assign_offsets_pyg(node_counts: Dict[str, int], replication_factor: int = 1): @@ -253,7 +278,6 @@ def generate_rmat_dataset( labeled_percentage=0.01, num_labels=256, reverse_edges=False, - persist=False, add_edge_types=False, ): """ @@ -282,12 +306,8 @@ def generate_rmat_dataset( dask_edgelist_df = dask_edgelist_df.reset_index(drop=True) dask_edgelist_df = renumber_ddf(dask_edgelist_df).persist() - if persist: - dask_edgelist_df = dask_edgelist_df.persist() dask_edgelist_df = symmetrize_ddf(dask_edgelist_df).persist() - if persist: - dask_edgelist_df = dask_edgelist_df.persist() if add_edge_types: dask_edgelist_df["etp"] = cupy.int32( @@ -329,7 +349,6 @@ def load_disk_dataset( dataset_dir=".", reverse_edges=True, replication_factor=1, - persist=False, add_edge_types=False, ): from pathlib import Path @@ -363,8 +382,6 @@ def load_disk_dataset( ] edge_index_dict[can_edge_type] = edge_index_dict[can_edge_type] - if persist: - edge_index_dict = edge_index_dict.persist() if replication_factor > 1: edge_index_dict[can_edge_type] = edge_index_dict[ @@ -384,11 +401,6 @@ def load_disk_dataset( ), ) - if persist: - edge_index_dict[can_edge_type] = edge_index_dict[ - can_edge_type - ].persist() - gc.collect() if reverse_edges: @@ -396,9 +408,6 @@ def load_disk_dataset( columns={"src": "dst", "dst": "src"} ) - if persist: - edge_index_dict[can_edge_type] = edge_index_dict[can_edge_type].persist() - # Assign numeric edge type ids based on lexicographic order edge_offsets = {} edge_count = 0 @@ -410,9 +419,6 @@ def load_disk_dataset( all_edges_df = dask_cudf.concat(list(edge_index_dict.values())) - if persist: - all_edges_df = all_edges_df.persist() - del edge_index_dict gc.collect() @@ -440,15 +446,9 @@ def load_disk_dataset( meta=cudf.DataFrame({"node": cudf.Series(dtype="int64")}), ) - if persist: - node_labels[node_type] = node_labels[node_type].persist() - gc.collect() - node_labels_df = 
dask_cudf.concat(list(node_labels.values())) - - if persist: - node_labels_df = node_labels_df.persist() + node_labels_df = dask_cudf.concat(list(node_labels.values())).reset_index(drop=True) del node_labels gc.collect() @@ -475,8 +475,8 @@ def benchmark_cugraph_bulk_sampling( replication_factor=1, num_labels=256, labeled_percentage=0.001, - persist=False, add_edge_types=False, + num_epochs=1, ): """ Entry point for the benchmark. @@ -506,14 +506,17 @@ def benchmark_cugraph_bulk_sampling( labeled_percentage: float The percentage of the data that is labeled (only for rmat datasets) Defaults to 0.001 to match papers100M - persist: bool - Whether to aggressively persist data in dask in attempt to speed up ETL. - Defaults to False. add_edge_types: bool Whether to add edge types to the edgelist. Defaults to False. + sampling_target_framework: str + The framework to sample for. + num_epochs: int + The number of epochs to sample for. """ - print(dataset) + + logger = logging.getLogger("__main__") + logger.info(str(dataset)) if dataset[0:4] == "rmat": ( dask_edgelist_df, @@ -527,7 +530,6 @@ def benchmark_cugraph_bulk_sampling( seed=seed, labeled_percentage=labeled_percentage, num_labels=num_labels, - persist=persist, add_edge_types=add_edge_types, ) @@ -543,28 +545,25 @@ def benchmark_cugraph_bulk_sampling( dataset_dir=dataset_dir, reverse_edges=reverse_edges, replication_factor=replication_factor, - persist=persist, add_edge_types=add_edge_types, ) num_input_edges = len(dask_edgelist_df) - print(f"Number of input edges = {num_input_edges:,}") + logger.info(f"Number of input edges = {num_input_edges:,}") G = construct_graph(dask_edgelist_df) del dask_edgelist_df - print("constructed graph") + logger.info("constructed graph") input_memory = G.edgelist.edgelist_df.memory_usage().sum().compute() - print(f"input memory: {input_memory}") + logger.info(f"input memory: {input_memory}") output_subdir = os.path.join( - output_path, f"{dataset}[{replication_factor}]_b{batch_size}_f{fanout}" + output_path, + f"{dataset}[{replication_factor}]_b{batch_size}_f{fanout}", ) os.makedirs(output_subdir) - output_sample_path = os.path.join(output_subdir, "samples") - os.makedirs(output_sample_path) - if sampling_target_framework == "cugraph_dgl_csr": sampling_kwargs = { "deduplicate_sources": True, @@ -587,11 +586,12 @@ def benchmark_cugraph_bulk_sampling( "include_hop_column": True, } - batches_per_partition = 400_000 // batch_size + batches_per_partition = 600_000 // batch_size execution_time, allocation_counts = sample_graph( G=G, label_df=dask_label_df, - output_path=output_sample_path, + output_path=output_subdir, + num_epochs=num_epochs, seed=seed, batch_size=batch_size, seeds_per_call=seeds_per_call, @@ -620,8 +620,8 @@ def benchmark_cugraph_bulk_sampling( with open(os.path.join(output_subdir, "output_meta.json"), "w") as f: json.dump(output_meta, f, indent="\t") - print("allocation counts b:") - print(allocation_counts.values()) + logger.info("allocation counts b:") + logger.info(allocation_counts.values()) ( input_to_peak_ratio, @@ -631,8 +631,8 @@ def benchmark_cugraph_bulk_sampling( ) = get_memory_statistics( allocation_counts=allocation_counts, input_memory=input_memory ) - print(f"Number of edges in final graph = {G.number_of_edges():,}") - print("-" * 80) + logger.info(f"Number of edges in final graph = {G.number_of_edges():,}") + logger.info("-" * 80) return ( num_input_edges, input_to_peak_ratio, @@ -693,12 +693,20 @@ def get_args(): required=True, ) + parser.add_argument( + "--num_epochs", + 
type=int, + help="Number of epochs to run for", + required=False, + default=1, + ) + parser.add_argument( "--fanouts", type=str, - help="Comma separated list of fanouts (i.e. 10_25,5_5_5)", + help='Comma separated list of fanouts (i.e. "10_25,5_5_5")', required=False, - default="10_25", + default="10_10_10", ) parser.add_argument( @@ -743,28 +751,14 @@ def get_args(): "--random_seed", type=int, help="Random seed", required=False, default=62 ) - parser.add_argument( - "--persist", - action="store_true", - help="Will add additional persist() calls to speed up ETL. Does not affect sampling runtime.", - required=False, - default=False, - ) - - parser.add_argument( - "--add_edge_types", - action="store_true", - help="Adds edge types to the edgelist. Required for PyG if not providing edge ids.", - required=False, - default=False, - ) - return parser.parse_args() # call __main__ function if __name__ == "__main__": logging.basicConfig() + logger = logging.getLogger("__main__") + logger.setLevel(logging.INFO) args = get_args() if args.sampling_target_framework not in ["cugraph_dgl_csr", None]: @@ -781,29 +775,28 @@ def get_args(): seeds_per_call_opts = [int(s) for s in args.seeds_per_call_opts.split(",")] dask_worker_devices = [int(d) for d in args.dask_worker_devices.split(",")] - client, cluster = start_dask_client( - dask_worker_devices=dask_worker_devices, - jit_unspill=False, - rmm_pool_size=28e9, - rmm_async=True, - ) + logger.info("starting dask client") + client, cluster = start_dask_client() enable_spilling() stats_ls = [] client.run(enable_spilling) + logger.info("dask client started") for dataset in datasets: - if re.match(r"([A-z]|[0-9])+\[[0-9]+\]", dataset): - replication_factor = int(dataset[-2]) - dataset = dataset[:-3] + m = re.match(r"(\w+)\[([0-9]+)\]", dataset) + if m: + replication_factor = int(m.groups()[1]) + dataset = m.groups()[0] else: replication_factor = 1 for fanout in fanouts: for batch_size in batch_sizes: for seeds_per_call in seeds_per_call_opts: - print(f"dataset: {dataset}") - print(f"batch size: {batch_size}") - print(f"fanout: {fanout}") - print(f"seeds_per_call: {seeds_per_call}") + logger.info(f"dataset: {dataset}") + logger.info(f"batch size: {batch_size}") + logger.info(f"fanout: {fanout}") + logger.info(f"seeds_per_call: {seeds_per_call}") + logger.info(f"num epochs: {args.num_epochs}") try: stats_d = {} @@ -816,6 +809,7 @@ def get_args(): ) = benchmark_cugraph_bulk_sampling( dataset=dataset, output_path=args.output_root, + num_epochs=args.num_epochs, seed=args.random_seed, batch_size=batch_size, seeds_per_call=seeds_per_call, @@ -824,8 +818,6 @@ def get_args(): dataset_dir=args.dataset_root, reverse_edges=args.reverse_edges, replication_factor=replication_factor, - persist=args.persist, - add_edge_types=args.add_edge_types, ) stats_d["dataset"] = dataset stats_d["num_input_edges"] = num_input_edges diff --git a/benchmarks/cugraph/standalone/bulk_sampling/datasets/__init__.py b/benchmarks/cugraph/standalone/bulk_sampling/datasets/__init__.py new file mode 100644 index 00000000000..0f4b516cd80 --- /dev/null +++ b/benchmarks/cugraph/standalone/bulk_sampling/datasets/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) 2023-2024, NVIDIA CORPORATION. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .dataset import Dataset +from .ogbn_papers100M import OGBNPapers100MDataset diff --git a/benchmarks/cugraph/standalone/bulk_sampling/datasets/dataset.py b/benchmarks/cugraph/standalone/bulk_sampling/datasets/dataset.py new file mode 100644 index 00000000000..f914f69fa4e --- /dev/null +++ b/benchmarks/cugraph/standalone/bulk_sampling/datasets/dataset.py @@ -0,0 +1,55 @@ +# Copyright (c) 2023-2024, NVIDIA CORPORATION. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from typing import Dict, Tuple + + +class Dataset: + @property + def edge_index_dict(self) -> Dict[Tuple[str, str, str], Dict[str, torch.Tensor]]: + raise NotImplementedError() + + @property + def x_dict(self) -> Dict[str, torch.Tensor]: + raise NotImplementedError() + + @property + def y_dict(self) -> Dict[str, torch.Tensor]: + raise NotImplementedError() + + @property + def train_dict(self) -> Dict[str, torch.Tensor]: + raise NotImplementedError() + + @property + def test_dict(self) -> Dict[str, torch.Tensor]: + raise NotImplementedError() + + @property + def val_dict(self) -> Dict[str, torch.Tensor]: + raise NotImplementedError() + + @property + def num_input_features(self) -> int: + raise NotImplementedError() + + @property + def num_labels(self) -> int: + raise NotImplementedError() + + def num_nodes(self, node_type: str) -> int: + raise NotImplementedError() + + def num_edges(self, edge_type: Tuple[str, str, str]) -> int: + raise NotImplementedError() diff --git a/benchmarks/cugraph/standalone/bulk_sampling/datasets/ogbn_papers100M.py b/benchmarks/cugraph/standalone/bulk_sampling/datasets/ogbn_papers100M.py new file mode 100644 index 00000000000..a50e40f6d55 --- /dev/null +++ b/benchmarks/cugraph/standalone/bulk_sampling/datasets/ogbn_papers100M.py @@ -0,0 +1,345 @@ +# Copyright (c) 2023-2024, NVIDIA CORPORATION. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
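+
+# Disk-backed wrapper around ogbn-papers100M. download() materializes node
+# features (.npy, pre-replicated when replication_factor > 1), the edge index
+# (parquet and .npy), and labels (parquet) under dataset_dir. The property
+# accessors read those files back, memory-mapping the large feature and
+# edge-index arrays and replicating edges/labels on the fly, rather than
+# keeping the full dataset resident in memory.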
+ +from .dataset import Dataset +from typing import Dict, Tuple, Union + +import pandas +import torch +import numpy as np + +from sklearn.model_selection import train_test_split + +import gc +import os +import json + + +class OGBNPapers100MDataset(Dataset): + def __init__( + self, + *, + replication_factor=1, + dataset_dir=".", + train_split=0.8, + val_split=0.5, + load_edge_index=True, + ): + self.__replication_factor = replication_factor + self.__disk_x = None + self.__y = None + self.__edge_index = None + self.__dataset_dir = dataset_dir + self.__train_split = train_split + self.__val_split = val_split + self.__load_edge_index = load_edge_index + + def download(self): + import logging + + logger = logging.getLogger("OGBNPapers100MDataset") + logger.info("Processing dataset...") + + dataset_path = os.path.join(self.__dataset_dir, "ogbn_papers100M") + + meta_json_path = os.path.join(dataset_path, "meta.json") + if not os.path.exists(meta_json_path): + j = { + "num_nodes": {"paper": 111059956}, + "num_edges": {"paper__cites__paper": 1615685872}, + } + with open(meta_json_path, "w") as file: + json.dump(j, file) + + dataset = None + if not os.path.exists(dataset_path): + from ogb.nodeproppred import NodePropPredDataset + + dataset = NodePropPredDataset( + name="ogbn-papers100M", root=self.__dataset_dir + ) + + features_path = os.path.join(dataset_path, "npy", "paper") + os.makedirs(features_path, exist_ok=True) + + logger.info("Processing node features...") + if self.__replication_factor == 1: + replication_path = os.path.join(features_path, "node_feat.npy") + else: + replication_path = os.path.join( + features_path, f"node_feat_{self.__replication_factor}x.npy" + ) + if not os.path.exists(replication_path): + if dataset is None: + from ogb.nodeproppred import NodePropPredDataset + + dataset = NodePropPredDataset( + name="ogbn-papers100M", root=self.__dataset_dir + ) + + node_feat = dataset[0][0]["node_feat"] + if self.__replication_factor != 1: + node_feat_replicated = np.concat( + [node_feat] * self.__replication_factor + ) + node_feat = node_feat_replicated + np.save(replication_path, node_feat) + + logger.info("Processing edge index...") + edge_index_parquet_path = os.path.join( + dataset_path, "parquet", "paper__cites__paper" + ) + os.makedirs(edge_index_parquet_path, exist_ok=True) + + edge_index_parquet_file_path = os.path.join( + edge_index_parquet_path, "edge_index.parquet" + ) + if not os.path.exists(edge_index_parquet_file_path): + if dataset is None: + from ogb.nodeproppred import NodePropPredDataset + + dataset = NodePropPredDataset( + name="ogbn-papers100M", root=self.__dataset_dir + ) + + edge_index = dataset[0][0]["edge_index"] + eidf = pandas.DataFrame({"src": edge_index[0], "dst": edge_index[1]}) + eidf.to_parquet(edge_index_parquet_file_path) + + edge_index_npy_path = os.path.join(dataset_path, "npy", "paper__cites__paper") + os.makedirs(edge_index_npy_path, exist_ok=True) + + edge_index_npy_file_path = os.path.join(edge_index_npy_path, "edge_index.npy") + if not os.path.exists(edge_index_npy_file_path): + if dataset is None: + from ogb.nodeproppred import NodePropPredDataset + + dataset = NodePropPredDataset( + name="ogbn-papers100M", root=self.__dataset_dir + ) + + edge_index = dataset[0][0]["edge_index"] + np.save(edge_index_npy_file_path, edge_index) + + logger.info("Processing labels...") + node_label_path = os.path.join(dataset_path, "parquet", "paper") + os.makedirs(node_label_path, exist_ok=True) + + node_label_file_path = os.path.join(node_label_path, 
"node_label.parquet") + if not os.path.exists(node_label_file_path): + if dataset is None: + from ogb.nodeproppred import NodePropPredDataset + + dataset = NodePropPredDataset( + name="ogbn-papers100M", root=self.__dataset_dir + ) + + ldf = pandas.Series(dataset[0][1].T[0]) + ldf = ( + ldf[ldf >= 0] + .reset_index() + .rename(columns={"index": "node", 0: "label"}) + ) + ldf.to_parquet(node_label_file_path) + + @property + def edge_index_dict( + self, + ) -> Dict[Tuple[str, str, str], Union[Dict[str, torch.Tensor], int]]: + import logging + + logger = logging.getLogger("OGBNPapers100MDataset") + + if self.__edge_index is None: + if self.__load_edge_index: + npy_path = os.path.join( + self.__dataset_dir, + "ogbn_papers100M", + "npy", + "paper__cites__paper", + "edge_index.npy", + ) + + logger.info(f"loading edge index from {npy_path}") + ei = np.load(npy_path, mmap_mode="r") + ei = torch.as_tensor(ei) + ei = { + "src": ei[1], + "dst": ei[0], + } + + logger.info("sorting edge index...") + ei["dst"], ix = torch.sort(ei["dst"]) + ei["src"] = ei["src"][ix] + del ix + gc.collect() + + logger.info("processing replications...") + orig_num_nodes = self.num_nodes("paper") // self.__replication_factor + if self.__replication_factor > 1: + orig_src = ei["src"].clone().detach() + orig_dst = ei["dst"].clone().detach() + for r in range(1, self.__replication_factor): + ei["src"] = torch.concat( + [ + ei["src"], + orig_src + int(r * orig_num_nodes), + ] + ) + + ei["dst"] = torch.concat( + [ + ei["dst"], + orig_dst + int(r * orig_num_nodes), + ] + ) + + del orig_src + del orig_dst + + ei["src"] = ei["src"].contiguous() + ei["dst"] = ei["dst"].contiguous() + gc.collect() + + logger.info(f"# edges: {len(ei['src'])}") + self.__edge_index = {("paper", "cites", "paper"): ei} + else: + self.__edge_index = { + ("paper", "cites", "paper"): self.num_edges( + ("paper", "cites", "paper") + ) + } + + return self.__edge_index + + @property + def x_dict(self) -> Dict[str, torch.Tensor]: + node_type_path = os.path.join( + self.__dataset_dir, "ogbn_papers100M", "npy", "paper" + ) + + if self.__disk_x is None: + if self.__replication_factor == 1: + full_path = os.path.join(node_type_path, "node_feat.npy") + else: + full_path = os.path.join( + node_type_path, f"node_feat_{self.__replication_factor}x.npy" + ) + + self.__disk_x = {"paper": np.load(full_path, mmap_mode="r")} + + return self.__disk_x + + @property + def y_dict(self) -> Dict[str, torch.Tensor]: + if self.__y is None: + self.__get_labels() + + return self.__y + + @property + def train_dict(self) -> Dict[str, torch.Tensor]: + if self.__train is None: + self.__get_labels() + return self.__train + + @property + def test_dict(self) -> Dict[str, torch.Tensor]: + if self.__test is None: + self.__get_labels() + return self.__test + + @property + def val_dict(self) -> Dict[str, torch.Tensor]: + if self.__val is None: + self.__get_labels() + return self.__val + + @property + def num_input_features(self) -> int: + return int(self.x_dict["paper"].shape[1]) + + @property + def num_labels(self) -> int: + return int(self.y_dict["paper"].max()) + 1 + + def num_nodes(self, node_type: str) -> int: + if node_type != "paper": + raise ValueError(f"Invalid node type {node_type}") + + return 111_059_956 * self.__replication_factor + + def num_edges(self, edge_type: Tuple[str, str, str]) -> int: + if edge_type != ("paper", "cites", "paper"): + raise ValueError(f"Invalid edge type {edge_type}") + + return 1_615_685_872 * self.__replication_factor + + def __get_labels(self): + label_path = 
os.path.join( + self.__dataset_dir, + "ogbn_papers100M", + "parquet", + "paper", + "node_label.parquet", + ) + + node_label = pandas.read_parquet(label_path) + + if self.__replication_factor > 1: + orig_num_nodes = self.num_nodes("paper") // self.__replication_factor + dfr = pandas.DataFrame( + { + "node": pandas.concat( + [ + node_label.node + (r * orig_num_nodes) + for r in range(1, self.__replication_factor) + ] + ), + "label": pandas.concat( + [node_label.label for r in range(1, self.__replication_factor)] + ), + } + ) + node_label = pandas.concat([node_label, dfr]).reset_index(drop=True) + + num_nodes = self.num_nodes("paper") + node_label_tensor = torch.full( + (num_nodes,), -1, dtype=torch.float32, device="cpu" + ) + node_label_tensor[ + torch.as_tensor(node_label.node.values, device="cpu") + ] = torch.as_tensor(node_label.label.values, device="cpu") + + self.__y = {"paper": node_label_tensor.contiguous()} + + train_ix, test_val_ix = train_test_split( + torch.as_tensor(node_label.node.values), + train_size=self.__train_split, + random_state=num_nodes, + ) + test_ix, val_ix = train_test_split( + test_val_ix, test_size=self.__val_split, random_state=num_nodes + ) + + train_tensor = torch.full((num_nodes,), 0, dtype=torch.bool, device="cpu") + train_tensor[train_ix] = 1 + self.__train = {"paper": train_tensor} + + test_tensor = torch.full((num_nodes,), 0, dtype=torch.bool, device="cpu") + test_tensor[test_ix] = 1 + self.__test = {"paper": test_tensor} + + val_tensor = torch.full((num_nodes,), 0, dtype=torch.bool, device="cpu") + val_tensor[val_ix] = 1 + self.__val = {"paper": val_tensor} diff --git a/benchmarks/cugraph/standalone/bulk_sampling/models/__init__.py b/benchmarks/cugraph/standalone/bulk_sampling/models/__init__.py new file mode 100644 index 00000000000..c2002fd3fb9 --- /dev/null +++ b/benchmarks/cugraph/standalone/bulk_sampling/models/__init__.py @@ -0,0 +1,12 @@ +# Copyright (c) 2023-2024, NVIDIA CORPORATION. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/benchmarks/cugraph/standalone/bulk_sampling/models/pyg/__init__.py b/benchmarks/cugraph/standalone/bulk_sampling/models/pyg/__init__.py new file mode 100644 index 00000000000..337cb0fa243 --- /dev/null +++ b/benchmarks/cugraph/standalone/bulk_sampling/models/pyg/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) 2023-2024, NVIDIA CORPORATION. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
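+
+# The two GraphSAGE variants used by the benchmark: CuGraphSAGE is built on
+# cugraph_pyg's SAGEConv and converts each sampled batch to CSC form, while
+# GraphSAGE is the plain torch_geometric implementation used as the baseline.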
+ +from .models_cugraph_pyg import CuGraphSAGE +from .models_pyg import GraphSAGE diff --git a/benchmarks/cugraph/standalone/bulk_sampling/models/pyg/models_cugraph_pyg.py b/benchmarks/cugraph/standalone/bulk_sampling/models/pyg/models_cugraph_pyg.py new file mode 100644 index 00000000000..1de791bf588 --- /dev/null +++ b/benchmarks/cugraph/standalone/bulk_sampling/models/pyg/models_cugraph_pyg.py @@ -0,0 +1,78 @@ +# Copyright (c) 2023-2024, NVIDIA CORPORATION. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from cugraph_pyg.nn.conv import SAGEConv as CuGraphSAGEConv + +try: + from torch_geometric.utils.trim_to_layer import TrimToLayer +except ModuleNotFoundError: + from torch_geometric.utils._trim_to_layer import TrimToLayer + +import torch.nn as nn +import torch.nn.functional as F + + +def extend_tensor(t: torch.Tensor, l: int): + return torch.concat([t, torch.zeros(l - len(t), dtype=t.dtype, device=t.device)]) + + +class CuGraphSAGE(nn.Module): + def __init__(self, in_channels, hidden_channels, out_channels, num_layers): + super().__init__() + + self.convs = torch.nn.ModuleList() + self.convs.append(CuGraphSAGEConv(in_channels, hidden_channels, aggr="mean")) + for _ in range(num_layers - 2): + conv = CuGraphSAGEConv(hidden_channels, hidden_channels, aggr="mean") + self.convs.append(conv) + + self.convs.append(CuGraphSAGEConv(hidden_channels, out_channels, aggr="mean")) + + self._trim = TrimToLayer() + + def forward(self, x, edge, num_sampled_nodes, num_sampled_edges): + if isinstance(edge, torch.Tensor): + edge = list( + CuGraphSAGEConv.to_csc( + edge.cuda(), (x.shape[0], num_sampled_nodes.sum()) + ) + ) + else: + edge = edge.csr() + edge = [edge[1], edge[0], x.shape[0]] + + x = x.cuda().to(torch.float32) + + for i, conv in enumerate(self.convs): + if i > 0: + new_num_edges = edge[1][-2] + edge[0] = edge[0].narrow( + dim=0, + start=0, + length=new_num_edges, + ) + edge[1] = edge[1].narrow( + dim=0, start=0, length=edge[1].size(0) - num_sampled_nodes[-i - 1] + ) + edge[2] = x.shape[0] + + x = conv(x, edge) + + x = F.relu(x) + x = F.dropout(x, p=0.5) + + x = x.narrow(dim=0, start=0, length=num_sampled_nodes[0]) + + return x diff --git a/benchmarks/cugraph/standalone/bulk_sampling/models/pyg/models_pyg.py b/benchmarks/cugraph/standalone/bulk_sampling/models/pyg/models_pyg.py new file mode 100644 index 00000000000..37f98d5362d --- /dev/null +++ b/benchmarks/cugraph/standalone/bulk_sampling/models/pyg/models_pyg.py @@ -0,0 +1,58 @@ +# Copyright (c) 2023-2024, NVIDIA CORPORATION. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from torch_geometric.nn import SAGEConv + +try: + from torch_geometric.utils.trim_to_layer import TrimToLayer +except ModuleNotFoundError: + from torch_geometric.utils._trim_to_layer import TrimToLayer + +import torch.nn as nn +import torch.nn.functional as F + + +class GraphSAGE(nn.Module): + def __init__(self, in_channels, hidden_channels, out_channels, num_layers): + super().__init__() + + self.convs = torch.nn.ModuleList() + self.convs.append(SAGEConv(in_channels, hidden_channels, aggr="mean")) + for _ in range(num_layers - 2): + conv = SAGEConv(hidden_channels, hidden_channels, aggr="mean") + self.convs.append(conv) + + self.convs.append(SAGEConv(hidden_channels, out_channels, aggr="mean")) + + self._trim = TrimToLayer() + + def forward(self, x, edge, num_sampled_nodes, num_sampled_edges): + edge = edge.cuda() + x = x.cuda().to(torch.float32) + + for i, conv in enumerate(self.convs): + x, edge, _ = self._trim( + i, num_sampled_nodes, num_sampled_edges, x, edge, None + ) + + s = x.shape[0] + x = conv(x, edge, size=(s, s)) + x = F.relu(x) + x = F.dropout(x, p=0.5) + + x = x.narrow(dim=0, start=0, length=x.shape[0] - num_sampled_nodes[1]) + + # assert x.shape[0] == num_sampled_nodes[0] + return x diff --git a/benchmarks/cugraph/standalone/bulk_sampling/run_sampling.sh b/benchmarks/cugraph/standalone/bulk_sampling/run_sampling.sh new file mode 100644 index 00000000000..41792c0b63a --- /dev/null +++ b/benchmarks/cugraph/standalone/bulk_sampling/run_sampling.sh @@ -0,0 +1,111 @@ +#!/bin/bash +# Copyright (c) 2024, NVIDIA CORPORATION. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +conda init +source ~/.bashrc +conda activate rapids + +BATCH_SIZE=$1 +FANOUT=$2 +REPLICATION_FACTOR=$3 +SCRIPTS_DIR=$4 +NUM_EPOCHS=$5 + +SAMPLES_DIR=/samples +DATASET_DIR=/datasets +LOGS_DIR=/logs + +MG_UTILS_DIR=${SCRIPTS_DIR}/mg_utils +SCHEDULER_FILE=${MG_UTILS_DIR}/dask_scheduler.json + +export WORKER_RMM_POOL_SIZE=28G +export UCX_MAX_RNDV_RAILS=1 +export RAPIDS_NO_INITIALIZE=1 +export CUDF_SPILL=1 +export LIBCUDF_CUFILE_POLICY="OFF" +export GPUS_PER_NODE=8 + +export SCHEDULER_FILE=$SCHEDULER_FILE +export LOGS_DIR=$LOGS_DIR + +function handleTimeout { + seconds=$1 + eval "timeout --signal=2 --kill-after=60 $*" + LAST_EXITCODE=$? 
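+    # timeout(1) exit codes: 124 means the command exceeded the limit and was
+    # sent SIGINT (--signal=2); 137 (128 + 9) means it was then killed with
+    # SIGKILL after the --kill-after grace period.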
+ if (( $LAST_EXITCODE == 124 )); then + logger "ERROR: command timed out after ${seconds} seconds" + elif (( $LAST_EXITCODE == 137 )); then + logger "ERROR: command timed out after ${seconds} seconds, and had to be killed with signal 9" + fi + ERRORCODE=$((ERRORCODE | ${LAST_EXITCODE})) +} + +DASK_STARTUP_ERRORCODE=0 +if [[ $SLURM_NODEID == 0 ]]; then + ${MG_UTILS_DIR}/run-dask-process.sh scheduler workers & +else + ${MG_UTILS_DIR}/run-dask-process.sh workers & +fi + +echo "properly waiting for workers to connect" +NUM_GPUS=$(python -c "import os; print(int(os.environ['SLURM_JOB_NUM_NODES'])*int(os.environ['GPUS_PER_NODE']))") +handleTimeout 120 python ${MG_UTILS_DIR}/wait_for_workers.py \ + --num-expected-workers ${NUM_GPUS} \ + --scheduler-file-path ${SCHEDULER_FILE} + + +DASK_STARTUP_ERRORCODE=$LAST_EXITCODE + +echo $SLURM_NODEID +if [[ $SLURM_NODEID == 0 ]]; then + echo "Launching Python Script" + python ${SCRIPTS_DIR}/cugraph_bulk_sampling.py \ + --output_root ${SAMPLES_DIR} \ + --dataset_root ${DATASET_DIR} \ + --datasets "ogbn_papers100M["$REPLICATION_FACTOR"]" \ + --fanouts $FANOUT \ + --batch_sizes $BATCH_SIZE \ + --seeds_per_call_opts "524288" \ + --num_epochs $NUM_EPOCHS \ + --random_seed 42 + + echo "DONE" > ${SAMPLES_DIR}/status.txt +fi + +while [ ! -f "${SAMPLES_DIR}"/status.txt ] +do + sleep 1 +done + +sleep 3 + +# At this stage there should be no running processes except /usr/lpp/mmfs/bin/mmsysmon.py +dask_processes=$(pgrep -la dask) +python_processes=$(pgrep -la python) +echo "$dask_processes" +echo "$python_processes" + +if [[ ${#python_processes[@]} -gt 1 || $dask_processes ]]; then + logger "The client was not shutdown properly, killing dask/python processes for Node $SLURM_NODEID" + # This can be caused by a job timeout + pkill python + pkill dask + pgrep -la python + pgrep -la dask +fi +sleep 2 + +if [[ $SLURM_NODEID == 0 ]]; then + rm ${SAMPLES_DIR}/status.txt +fi \ No newline at end of file diff --git a/benchmarks/cugraph/standalone/bulk_sampling/run_train_job.sh b/benchmarks/cugraph/standalone/bulk_sampling/run_train_job.sh new file mode 100755 index 00000000000..977745a9593 --- /dev/null +++ b/benchmarks/cugraph/standalone/bulk_sampling/run_train_job.sh @@ -0,0 +1,84 @@ +#!/bin/bash +# Copyright (c) 2024, NVIDIA CORPORATION. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
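+
+# sbatch entry point for the benchmark: adjust the #SBATCH directives and the
+# variables below for your cluster, then submit with `sbatch run_train_job.sh`.
+# For FRAMEWORK=cuGraphPyG it first runs run_sampling.sh inside the container
+# to produce bulk samples, then launches training on every node via
+# srun + torchrun on bench_cugraph_training.py.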
+ +#SBATCH -A datascience_rapids_cugraphgnn +#SBATCH -p luna +#SBATCH -J datascience_rapids_cugraphgnn-papers:bulkSamplingPyG +#SBATCH -N 1 +#SBATCH -t 00:25:00 + +CONTAINER_IMAGE=${CONTAINER_IMAGE:="please_specify_container"} +SCRIPTS_DIR=$(pwd) +LOGS_DIR=${LOGS_DIR:=$(pwd)"/logs"} +SAMPLES_DIR=${SAMPLES_DIR:=$(pwd)/samples} +DATASETS_DIR=${DATASETS_DIR:=$(pwd)/datasets} + +mkdir -p $LOGS_DIR +mkdir -p $SAMPLES_DIR +mkdir -p $DATASETS_DIR + +BATCH_SIZE=512 +FANOUT="10_10_10" +NUM_EPOCHS=1 +REPLICATION_FACTOR=1 + +# options: PyG or cuGraphPyG +FRAMEWORK="cuGraphPyG" +GPUS_PER_NODE=8 + +nodes=( $( scontrol show hostnames $SLURM_JOB_NODELIST ) ) +nodes_array=($nodes) +head_node=${nodes_array[0]} +head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address) + +echo Node IP: $head_node_ip + +nnodes=$SLURM_JOB_NUM_NODES +echo Num Nodes: $nnodes + +gpus_per_node=$GPUS_PER_NODE +echo Num GPUs Per Node: $gpus_per_node + +set -e + +# First run without cuGraph to get data + +if [[ "$FRAMEWORK" == "cuGraphPyG" ]]; then + # Generate samples + srun \ + --container-image $CONTAINER_IMAGE \ + --container-mounts=${LOGS_DIR}":/logs",${SAMPLES_DIR}":/samples",${SCRIPTS_DIR}":/scripts",${DATASETS_DIR}":/datasets" \ + bash /scripts/run_sampling.sh $BATCH_SIZE $FANOUT $REPLICATION_FACTOR "/scripts" $NUM_EPOCHS +fi + +# Train +srun \ + --container-image $CONTAINER_IMAGE \ + --container-mounts=${LOGS_DIR}":/logs",${SAMPLES_DIR}":/samples",${SCRIPTS_DIR}":/scripts",${DATASETS_DIR}":/datasets" \ + torchrun \ + --nnodes $nnodes \ + --nproc-per-node $gpus_per_node \ + --rdzv-id $RANDOM \ + --rdzv-backend c10d \ + --rdzv-endpoint $head_node_ip:29500 \ + /scripts/bench_cugraph_training.py \ + --output_file "/logs/output.txt" \ + --framework $FRAMEWORK \ + --dataset_dir "/datasets" \ + --sample_dir "/samples" \ + --batch_size $BATCH_SIZE \ + --fanout $FANOUT \ + --replication_factor $REPLICATION_FACTOR \ + --num_epochs $NUM_EPOCHS + diff --git a/benchmarks/cugraph/standalone/bulk_sampling/trainers/__init__.py b/benchmarks/cugraph/standalone/bulk_sampling/trainers/__init__.py new file mode 100644 index 00000000000..5f8f4c2b868 --- /dev/null +++ b/benchmarks/cugraph/standalone/bulk_sampling/trainers/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) 2023-2024, NVIDIA CORPORATION. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .trainer import Trainer +from .trainer import extend_tensor diff --git a/benchmarks/cugraph/standalone/bulk_sampling/trainers/pyg/__init__.py b/benchmarks/cugraph/standalone/bulk_sampling/trainers/pyg/__init__.py new file mode 100644 index 00000000000..def6110b8e5 --- /dev/null +++ b/benchmarks/cugraph/standalone/bulk_sampling/trainers/pyg/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) 2023-2024, NVIDIA CORPORATION. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .trainers_cugraph_pyg import PyGCuGraphTrainer +from .trainers_pyg import PyGNativeTrainer diff --git a/benchmarks/cugraph/standalone/bulk_sampling/trainers/pyg/trainers_cugraph_pyg.py b/benchmarks/cugraph/standalone/bulk_sampling/trainers/pyg/trainers_cugraph_pyg.py new file mode 100644 index 00000000000..71151e9ba59 --- /dev/null +++ b/benchmarks/cugraph/standalone/bulk_sampling/trainers/pyg/trainers_cugraph_pyg.py @@ -0,0 +1,184 @@ +# Copyright (c) 2023-2024, NVIDIA CORPORATION. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .trainers_pyg import PyGTrainer +from models.pyg import CuGraphSAGE + +import torch +import numpy as np + +from torch.nn.parallel import DistributedDataParallel as ddp + +from cugraph.gnn import FeatureStore +from cugraph_pyg.data import CuGraphStore +from cugraph_pyg.loader import BulkSampleLoader + +import os + + +class PyGCuGraphTrainer(PyGTrainer): + def __init__( + self, + dataset, + model="GraphSAGE", + device=0, + rank=0, + world_size=1, + num_epochs=1, + sample_dir=".", + **kwargs, + ): + self.__data = None + self.__device = device + self.__rank = rank + self.__world_size = world_size + self.__num_epochs = num_epochs + self.__dataset = dataset + self.__sample_dir = sample_dir + self.__loader_kwargs = kwargs + self.__model = self.get_model(model) + self.__optimizer = None + + @property + def rank(self): + return self.__rank + + @property + def model(self): + return self.__model + + @property + def dataset(self): + return self.__dataset + + @property + def optimizer(self): + if self.__optimizer is None: + self.__optimizer = torch.optim.Adam( + self.model.parameters(), lr=0.01, weight_decay=0.0005 + ) + return self.__optimizer + + @property + def num_epochs(self) -> int: + return self.__num_epochs + + def get_loader(self, epoch: int = 0, stage="train") -> int: + import logging + + logger = logging.getLogger("PyGCuGraphTrainer") + + logger.info(f"getting loader for epoch {epoch}, {stage} stage") + + # TODO support online sampling + if stage == "val": + path = os.path.join(self.__sample_dir, "val", "samples") + else: + path = os.path.join(self.__sample_dir, f"epoch={epoch}", stage, "samples") + + loader = BulkSampleLoader( + self.data, + self.data, + None, # FIXME get input nodes properly + directory=path, + input_files=self.get_input_files(path, epoch=epoch, stage=stage), + **self.__loader_kwargs, + ) + + logger.info(f"got loader successfully on rank {self.rank}") + return loader + + @property + def data(self): + import logging + + logger = logging.getLogger("PyGCuGraphTrainer") + logger.info("getting data") + + if self.__data is None: + # 
FIXME wholegraph + fs = FeatureStore(backend="torch") + num_nodes_dict = {} + + for node_type, x in self.__dataset.x_dict.items(): + logger.debug(f"getting x for {node_type}") + fs.add_data(x, node_type, "x") + num_nodes_dict[node_type] = self.__dataset.num_nodes(node_type) + + for node_type, y in self.__dataset.y_dict.items(): + logger.debug(f"getting y for {node_type}") + fs.add_data(y, node_type, "y") + + for node_type, train in self.__dataset.train_dict.items(): + logger.debug(f"getting train for {node_type}") + fs.add_data(train, node_type, "train") + + for node_type, test in self.__dataset.test_dict.items(): + logger.debug(f"getting test for {node_type}") + fs.add_data(test, node_type, "test") + + for node_type, val in self.__dataset.val_dict.items(): + logger.debug(f"getting val for {node_type}") + fs.add_data(val, node_type, "val") + + # TODO support online sampling if the edge index is provided + num_edges_dict = self.__dataset.edge_index_dict + if not isinstance(list(num_edges_dict.values())[0], int): + num_edges_dict = {k: len(v) for k, v in num_edges_dict} + + self.__data = CuGraphStore( + fs, + num_edges_dict, + num_nodes_dict, + ) + + logger.info(f"got data successfully on rank {self.rank}") + + return self.__data + + def get_model(self, name="GraphSAGE"): + if name != "GraphSAGE": + raise ValueError("only GraphSAGE is currently supported") + + num_input_features = self.__dataset.num_input_features + num_output_features = self.__dataset.num_labels + num_layers = len(self.__loader_kwargs["num_neighbors"]) + + with torch.cuda.device(self.__device): + model = ( + CuGraphSAGE( + in_channels=num_input_features, + hidden_channels=64, + out_channels=num_output_features, + num_layers=num_layers, + ) + .to(torch.float32) + .to(self.__device) + ) + + model = ddp(model, device_ids=[self.__device]) + print("done creating model") + + return model + + def get_input_files(self, path, epoch=0, stage="train"): + file_list = np.array(os.listdir(path)) + file_list.sort() + + if stage == "train": + splits = np.array_split(file_list, self.__world_size) + np.random.seed(epoch) + np.random.shuffle(splits) + return splits[self.rank] + else: + return file_list diff --git a/benchmarks/cugraph/standalone/bulk_sampling/trainers/pyg/trainers_pyg.py b/benchmarks/cugraph/standalone/bulk_sampling/trainers/pyg/trainers_pyg.py new file mode 100644 index 00000000000..bddd6ae2644 --- /dev/null +++ b/benchmarks/cugraph/standalone/bulk_sampling/trainers/pyg/trainers_pyg.py @@ -0,0 +1,430 @@ +# Copyright (c) 2023-2024, NVIDIA CORPORATION. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +from trainers import Trainer +from trainers import extend_tensor +from datasets import OGBNPapers100MDataset +from models.pyg import GraphSAGE + +import torch +import numpy as np + +import torch.distributed as td +from torch.nn.parallel import DistributedDataParallel as ddp +import torch.nn.functional as F + +from torch_geometric.utils.sparse import index2ptr +from torch_geometric.data import HeteroData +from torch_geometric.loader import NeighborLoader + +import gc +import os +import time + + +def pyg_num_workers(world_size): + num_workers = None + if hasattr(os, "sched_getaffinity"): + try: + num_workers = len(os.sched_getaffinity(0)) / (2 * world_size) + except Exception: + pass + if num_workers is None: + num_workers = os.cpu_count() / (2 * world_size) + return int(num_workers) + + +class PyGTrainer(Trainer): + def train(self): + import logging + + logger = logging.getLogger("PyGTrainer") + logger.info("Entered train loop") + + total_loss = 0.0 + num_batches = 0 + + time_forward = 0.0 + time_backward = 0.0 + time_loader = 0.0 + time_feature_transfer = 0.0 + start_time = time.perf_counter() + end_time_backward = start_time + + for epoch in range(self.num_epochs): + with td.algorithms.join.Join( + [self.model], divide_by_initial_world_size=False + ): + self.model.train() + for iter_i, data in enumerate( + self.get_loader(epoch=epoch, stage="train") + ): + loader_time_iter = time.perf_counter() - end_time_backward + time_loader += loader_time_iter + + time_feature_transfer_start = time.perf_counter() + + num_sampled_nodes = sum( + [ + torch.as_tensor(n) + for n in data.num_sampled_nodes_dict.values() + ] + ) + num_sampled_edges = sum( + [ + torch.as_tensor(e) + for e in data.num_sampled_edges_dict.values() + ] + ) + + # FIXME find a way to get around this and not have to call extend_tensor + num_layers = len(self.model.module.convs) + num_sampled_nodes = extend_tensor(num_sampled_nodes, num_layers + 1) + num_sampled_edges = extend_tensor(num_sampled_edges, num_layers) + + data = data.to_homogeneous().cuda() + time_feature_transfer_end = time.perf_counter() + time_feature_transfer += ( + time_feature_transfer_end - time_feature_transfer_start + ) + + num_batches += 1 + if iter_i % 20 == 1: + time_forward_iter = time_forward / num_batches + time_backward_iter = time_backward / num_batches + + total_time_iter = ( + time.perf_counter() - start_time + ) / num_batches + logger.info(f"epoch {epoch}, iteration {iter_i}") + logger.info(f"num sampled nodes: {num_sampled_nodes}") + logger.info(f"num sampled edges: {num_sampled_edges}") + logger.info(f"time forward: {time_forward_iter}") + logger.info(f"time backward: {time_backward_iter}") + logger.info(f"loader time: {loader_time_iter}") + logger.info( + f"feature transfer time: {time_feature_transfer / num_batches}" + ) + logger.info(f"total time: {total_time_iter}") + + y_true = data.y + x = data.x.to(torch.float32) + + start_time_forward = time.perf_counter() + edge_index = data.edge_index if "edge_index" in data else data.adj_t + + self.optimizer.zero_grad() + y_pred = self.model( + x, + edge_index, + num_sampled_nodes, + num_sampled_edges, + ) + + end_time_forward = time.perf_counter() + time_forward += end_time_forward - start_time_forward + + if y_pred.shape[0] > len(y_true): + raise ValueError( + f"illegal shape: {y_pred.shape}; {y_true.shape}" + ) + + y_true = y_true[: y_pred.shape[0]] + + y_true = F.one_hot( + y_true.to(torch.int64), num_classes=self.dataset.num_labels + ).to(torch.float32) + + if y_true.shape != y_pred.shape: + 
raise ValueError( + f"y_true shape was {y_true.shape} " + f"but y_pred shape was {y_pred.shape} " + f"in iteration {iter_i} " + f"on rank {y_pred.device.index}" + ) + + start_time_backward = time.perf_counter() + loss = F.cross_entropy(y_pred, y_true) + + self.optimizer.zero_grad() + loss.backward() + self.optimizer.step() + total_loss += loss.item() + end_time_backward = time.perf_counter() + time_backward += end_time_backward - start_time_backward + + end_time = time.perf_counter() + + # test + from torchmetrics import Accuracy + + acc = Accuracy( + task="multiclass", num_classes=self.dataset.num_labels + ).cuda() + + with td.algorithms.join.Join( + [self.model], divide_by_initial_world_size=False + ): + self.model.eval() + if self.rank == 0: + acc_sum = 0.0 + with torch.no_grad(): + for i, batch in enumerate( + self.get_loader(epoch=epoch, stage="test") + ): + num_sampled_nodes = sum( + [ + torch.as_tensor(n) + for n in batch.num_sampled_nodes_dict.values() + ] + ) + num_sampled_edges = sum( + [ + torch.as_tensor(e) + for e in batch.num_sampled_edges_dict.values() + ] + ) + batch_size = num_sampled_nodes[0] + + batch = batch.to_homogeneous().cuda() + + batch.y = batch.y.to(torch.long) + out = self.model.module( + batch.x, + batch.edge_index, + num_sampled_nodes, + num_sampled_edges, + ) + acc_sum += acc( + out[:batch_size].softmax(dim=-1), batch.y[:batch_size] + ) + print( + f"Accuracy: {acc_sum/(i) * 100.0:.4f}%", + ) + + td.barrier() + + with td.algorithms.join.Join([self.model], divide_by_initial_world_size=False): + self.model.eval() + if self.rank == 0: + acc_sum = 0.0 + with torch.no_grad(): + for i, batch in enumerate( + self.get_loader(epoch=epoch, stage="val") + ): + num_sampled_nodes = sum( + [ + torch.as_tensor(n) + for n in batch.num_sampled_nodes_dict.values() + ] + ) + num_sampled_edges = sum( + [ + torch.as_tensor(e) + for e in batch.num_sampled_edges_dict.values() + ] + ) + batch_size = num_sampled_nodes[0] + + batch = batch.to_homogeneous().cuda() + + batch.y = batch.y.to(torch.long) + out = self.model.module( + batch.x, + batch.edge_index, + num_sampled_nodes, + num_sampled_edges, + ) + acc_sum += acc( + out[:batch_size].softmax(dim=-1), batch.y[:batch_size] + ) + print( + f"Validation Accuracy: {acc_sum/(i) * 100.0:.4f}%", + ) + + stats = { + "Accuracy": float(acc_sum / (i) * 100.0) if self.rank == 0 else 0.0, + "# Batches": num_batches, + "Loader Time": time_loader, + "Feature Transfer Time": time_feature_transfer, + "Forward Time": time_forward, + "Backward Time": time_backward, + } + return stats + + +class PyGNativeTrainer(PyGTrainer): + def __init__( + self, + dataset, + model="GraphSAGE", + device=0, + rank=0, + world_size=1, + num_epochs=1, + **kwargs, + ): + self.__dataset = dataset + self.__device = device + self.__data = None + self.__rank = rank + self.__num_epochs = num_epochs + self.__world_size = world_size + self.__loader_kwargs = kwargs + self.__model = self.get_model(model) + self.__optimizer = None + + @property + def rank(self): + return self.__rank + + @property + def model(self): + return self.__model + + @property + def dataset(self): + return self.__dataset + + @property + def data(self): + import logging + + logger = logging.getLogger("PyGNativeTrainer") + logger.info("getting data") + + if self.__data is None: + self.__data = HeteroData() + + for node_type, x in self.__dataset.x_dict.items(): + logger.debug(f"getting x for {node_type}") + self.__data[node_type].x = x + self.__data[node_type]["num_nodes"] = self.__dataset.num_nodes( + node_type 
+ ) + + for node_type, y in self.__dataset.y_dict.items(): + logger.debug(f"getting y for {node_type}") + self.__data[node_type]["y"] = y + + for node_type, train in self.__dataset.train_dict.items(): + logger.debug(f"getting train for {node_type}") + self.__data[node_type]["train"] = train + + for node_type, test in self.__dataset.test_dict.items(): + logger.debug(f"getting test for {node_type}") + self.__data[node_type]["test"] = test + + for node_type, val in self.__dataset.val_dict.items(): + logger.debug(f"getting val for {node_type}") + self.__data[node_type]["val"] = val + + for can_edge_type, ei in self.__dataset.edge_index_dict.items(): + logger.info("converting to csc...") + ei["dst"] = index2ptr( + ei["dst"], self.__dataset.num_nodes(can_edge_type[2]) + ) + + logger.info("updating data structure...") + self.__data.put_edge_index( + layout="csc", + edge_index=list(ei.values()), + edge_type=can_edge_type, + size=( + self.__dataset.num_nodes(can_edge_type[0]), + self.__dataset.num_nodes(can_edge_type[2]), + ), + is_sorted=True, + ) + gc.collect() + + return self.__data + + @property + def optimizer(self): + if self.__optimizer is None: + self.__optimizer = torch.optim.Adam( + self.model.parameters(), lr=0.01, weight_decay=0.0005 + ) + return self.__optimizer + + @property + def num_epochs(self) -> int: + return self.__num_epochs + + def get_loader(self, epoch: int = 0, stage="train"): + import logging + + logger = logging.getLogger("PyGNativeTrainer") + logger.info(f"Getting loader for epoch {epoch}") + + if stage == "train": + mask_dict = self.__dataset.train_dict + elif stage == "test": + mask_dict = self.__dataset.test_dict + elif stage == "val": + mask_dict = self.__dataset.val_dict + else: + raise ValueError(f"Invalid stage {stage}") + + input_nodes_dict = { + node_type: np.array_split(np.arange(len(mask))[mask], self.__world_size)[ + self.__rank + ] + for node_type, mask in mask_dict.items() + } + + input_nodes = list(input_nodes_dict.items()) + if len(input_nodes) > 1: + raise ValueError("Multiple input node types currently unsupported") + else: + input_nodes = tuple(input_nodes[0]) + + # get loader + loader = NeighborLoader( + self.data, + input_nodes=input_nodes, + is_sorted=True, + disjoint=False, + num_workers=pyg_num_workers(self.__world_size), # FIXME change this + persistent_workers=True, + **self.__loader_kwargs, # batch size, num neighbors, replace, shuffle, etc. + ) + + logger.info("done creating loader") + return loader + + def get_model(self, name="GraphSAGE"): + if name != "GraphSAGE": + raise ValueError("only GraphSAGE is currently supported") + + num_input_features = self.__dataset.num_input_features + num_output_features = self.__dataset.num_labels + num_layers = len(self.__loader_kwargs["num_neighbors"]) + + with torch.cuda.device(self.__device): + model = ( + GraphSAGE( + in_channels=num_input_features, + hidden_channels=64, + out_channels=num_output_features, + num_layers=num_layers, + ) + .to(torch.float32) + .to(self.__device) + ) + model = ddp(model, device_ids=[self.__device]) + print("done creating model") + + return model diff --git a/benchmarks/cugraph/standalone/bulk_sampling/trainers/trainer.py b/benchmarks/cugraph/standalone/bulk_sampling/trainers/trainer.py new file mode 100644 index 00000000000..321edbea96e --- /dev/null +++ b/benchmarks/cugraph/standalone/bulk_sampling/trainers/trainer.py @@ -0,0 +1,54 @@ +# Copyright (c) 2023-2024, NVIDIA CORPORATION. 
+# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from typing import Union, List + + +def extend_tensor(t: Union[List[int], torch.Tensor], l: int): + t = torch.as_tensor(t) + + return torch.concat([t, torch.zeros(l - len(t), dtype=t.dtype, device=t.device)]) + + +class Trainer: + @property + def rank(self): + raise NotImplementedError() + + @property + def model(self): + raise NotImplementedError() + + @property + def dataset(self): + raise NotImplementedError() + + @property + def data(self): + raise NotImplementedError() + + @property + def optimizer(self): + raise NotImplementedError() + + @property + def num_epochs(self) -> int: + raise NotImplementedError() + + def get_loader(self, epoch: int = 0, stage="train"): + raise NotImplementedError() + + def train(self): + raise NotImplementedError() diff --git a/cpp/src/community/flatten_dendrogram.hpp b/cpp/src/community/flatten_dendrogram.hpp index c0186983904..a4299f17d52 100644 --- a/cpp/src/community/flatten_dendrogram.hpp +++ b/cpp/src/community/flatten_dendrogram.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * Copyright (c) 2021-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/mg_utils/wait_for_workers.py b/mg_utils/wait_for_workers.py new file mode 100644 index 00000000000..fa75c90d4ad --- /dev/null +++ b/mg_utils/wait_for_workers.py @@ -0,0 +1,124 @@ +# Copyright (c) 2023-2024, NVIDIA CORPORATION. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import sys +import time +import yaml + +from dask.distributed import Client + + +def initialize_dask_cuda(communication_type): + communication_type = communication_type.lower() + if "ucx" in communication_type: + os.environ["UCX_MAX_RNDV_RAILS"] = "1" + + if communication_type == "ucx-ib": + os.environ["UCX_MEMTYPE_REG_WHOLE_ALLOC_TYPES"]="cuda" + os.environ["DASK_RMM__POOL_SIZE"]="0.5GB" + os.environ["DASK_DISTRIBUTED__COMM__UCX__CREATE_CUDA_CONTEXT"]="True" + + +def wait_for_workers( + num_expected_workers, scheduler_file_path, communication_type, timeout_after=0 +): + """ + Waits until num_expected_workers workers are available based on + the workers managed by scheduler_file_path, then returns 0. If + timeout_after is specified, will return 1 if num_expected_workers + workers are not available before the timeout. 
+ """ + # FIXME: use scheduler file path from global environment if none + # supplied in configuration yaml + + print("wait_for_workers.py - initializing client...", end="") + sys.stdout.flush() + initialize_dask_cuda(communication_type) + print("done.") + sys.stdout.flush() + + ready = False + start_time = time.time() + while not ready: + if timeout_after and ((time.time() - start_time) >= timeout_after): + print( + f"wait_for_workers.py timed out after {timeout_after} seconds before finding {num_expected_workers} workers." + ) + sys.stdout.flush() + break + with Client(scheduler_file=scheduler_file_path) as client: + num_workers = len(client.scheduler_info()["workers"]) + if num_workers < num_expected_workers: + print( + f"wait_for_workers.py expected {num_expected_workers} but got {num_workers}, waiting..." + ) + sys.stdout.flush() + time.sleep(5) + else: + print(f"wait_for_workers.py got {num_workers} workers, done.") + sys.stdout.flush() + ready = True + + if ready is False: + return 1 + return 0 + + +if __name__ == "__main__": + import argparse + + ap = argparse.ArgumentParser() + ap.add_argument( + "--num-expected-workers", + type=int, + required=False, + help="Number of workers to wait for. If not specified, " + "uses the NUM_WORKERS env var if set, otherwise defaults " + "to 16.", + ) + ap.add_argument( + "--scheduler-file-path", + type=str, + required=True, + help="Path to shared scheduler file to read.", + ) + ap.add_argument( + "--communication-type", + type=str, + default="tcp", + required=False, + help="Initiliaze dask_cuda based on the cluster communication type." + "Supported values are tcp(default), ucx, ucxib, ucx-ib.", + ) + ap.add_argument( + "--timeout-after", + type=int, + default=0, + required=False, + help="Number of seconds to wait for workers. " + "Default is 0 which means wait forever.", + ) + args = ap.parse_args() + + if args.num_expected_workers is None: + args.num_expected_workers = os.environ.get("NUM_WORKERS", 16) + + exitcode = wait_for_workers( + num_expected_workers=args.num_expected_workers, + scheduler_file_path=args.scheduler_file_path, + communication_type=args.communication_type, + timeout_after=args.timeout_after, + ) + + sys.exit(exitcode) diff --git a/python/cugraph-pyg/cugraph_pyg/loader/cugraph_node_loader.py b/python/cugraph-pyg/cugraph_pyg/loader/cugraph_node_loader.py index 8a1db4edf29..bcfaf579820 100644 --- a/python/cugraph-pyg/cugraph_pyg/loader/cugraph_node_loader.py +++ b/python/cugraph-pyg/cugraph_pyg/loader/cugraph_node_loader.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. +# Copyright (c) 2023-2024, NVIDIA CORPORATION. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at @@ -151,9 +151,25 @@ def __init__( self.__input_files = iter(input_files) return - input_type, input_nodes = torch_geometric.loader.utils.get_input_nodes( - (feature_store, graph_store), input_nodes + # To accommodate DLFW/PyG 2.5 + get_input_nodes = torch_geometric.loader.utils.get_input_nodes + get_input_nodes_kwargs = {} + if "input_id" in get_input_nodes.__annotations__: + get_input_nodes_kwargs["input_id"] = None + input_node_info = get_input_nodes( + (feature_store, graph_store), input_nodes, **get_input_nodes_kwargs ) + + # PyG 2.4 + if len(input_node_info) == 2: + input_type, input_nodes = input_node_info + # PyG 2.5 + elif len(input_node_info) == 3: + input_type, input_nodes, input_id = input_node_info + # Invalid + else: + raise ValueError("Invalid output from get_input_nodes") + if input_type is not None: input_nodes = graph_store._get_sample_from_vertex_groups( {input_type: input_nodes} @@ -439,7 +455,12 @@ def __next__(self): start_time_feature = perf_counter() # Create a PyG HeteroData object, loading the required features if self.__coo: - out = torch_geometric.loader.utils.filter_custom_store( + pyg_filter_fn = ( + torch_geometric.loader.utils.filter_custom_hetero_store + if hasattr(torch_geometric.loader.utils, "filter_custom_hetero_store") + else torch_geometric.loader.utils.filter_custom_store + ) + out = pyg_filter_fn( self.__feature_store, self.__graph_store, sampler_output.node, diff --git a/python/cugraph-pyg/cugraph_pyg/sampler/cugraph_sampler.py b/python/cugraph-pyg/cugraph_pyg/sampler/cugraph_sampler.py index 300ca9beb5a..65cb63d25e0 100644 --- a/python/cugraph-pyg/cugraph_pyg/sampler/cugraph_sampler.py +++ b/python/cugraph-pyg/cugraph_pyg/sampler/cugraph_sampler.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2023, NVIDIA CORPORATION. +# Copyright (c) 2022-2024, NVIDIA CORPORATION. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -216,7 +216,6 @@ def _sampler_output_from_sampling_results_homogeneous_csr( if renumber_map is None: raise ValueError("Renumbered input is expected for homogeneous graphs") - node_type = graph_store.node_types[0] edge_type = graph_store.edge_types[0] diff --git a/python/cugraph-pyg/cugraph_pyg/tests/test_cugraph_store.py b/python/cugraph-pyg/cugraph_pyg/tests/test_cugraph_store.py index b39ebad8254..c99fd447aa0 100644 --- a/python/cugraph-pyg/cugraph_pyg/tests/test_cugraph_store.py +++ b/python/cugraph-pyg/cugraph_pyg/tests/test_cugraph_store.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2023, NVIDIA CORPORATION. +# Copyright (c) 2022-2024, NVIDIA CORPORATION. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at @@ -365,10 +365,20 @@ def test_get_input_nodes(karate_gnn): F, G, N = karate_gnn cugraph_store = CuGraphStore(F, G, N) - node_type, input_nodes = torch_geometric.loader.utils.get_input_nodes( + input_node_info = torch_geometric.loader.utils.get_input_nodes( (cugraph_store, cugraph_store), "type0" ) + # PyG 2.4 + if len(input_node_info) == 2: + node_type, input_nodes = input_node_info + # PyG 2.5 + elif len(input_node_info) == 3: + node_type, input_nodes, input_id = input_node_info + # Invalid + else: + raise ValueError("Invalid output from get_input_nodes") + assert node_type == "type0" assert input_nodes.tolist() == torch.arange(17, dtype=torch.int32).tolist() diff --git a/python/cugraph/cugraph/experimental/__init__.py b/python/cugraph/cugraph/experimental/__init__.py index d809e28c92e..7e8fd666972 100644 --- a/python/cugraph/cugraph/experimental/__init__.py +++ b/python/cugraph/cugraph/experimental/__init__.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2023, NVIDIA CORPORATION. +# Copyright (c) 2022-2024, NVIDIA CORPORATION. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at From 24d02a5b1c112a55eb40119a86ae58edd47c3172 Mon Sep 17 00:00:00 2001 From: Chuck Hastings <45364586+ChuckHastings@users.noreply.github.com> Date: Fri, 12 Jan 2024 12:06:46 -0500 Subject: [PATCH 2/4] Fix OOB error, BFS C API should validate that the source vertex is a valid vertex (#4077) * Added error check to be sure that the source vertex is a valid vertex in the graph. * Updated `nx_cugraph.Graph` class to create PLC graphs using `vertices_array` in order to include isolated vertices. This is now needed since the error check added in this PR prevents NetworkX tests from passing if isolated vertices are treated as invalid, so this change prevents that. * This resolves the problem that required the test workarounds done [here](https://github.com/rapidsai/cugraph/pull/4029#discussion_r1433332010) in [4029](https://github.com/rapidsai/cugraph/pull/4029), so those workarounds have been removed in this PR. 
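As an illustrative sketch of the `vertices_array` behavior described above (not taken from this patch; the edge data is hypothetical and the `weight_array`/`GraphProperties` keywords shown are assumptions about the `pylibcugraph` constructor), a single-GPU PLC graph can be built so that an isolated vertex remains a valid input for algorithms such as BFS:

```python
import cupy as cp
import pylibcugraph as plc

# Hypothetical 4-vertex graph: edges 0->1 and 1->2, vertex 3 is isolated.
src = cp.asarray([0, 1], dtype=cp.int32)
dst = cp.asarray([1, 2], dtype=cp.int32)

# Passing vertices_array with every vertex id (including the isolated vertex 3)
# lets the graph represent vertices that never appear in src/dst, so the new
# invalid-source check in BFS accepts 3 instead of rejecting it.
graph = plc.SGGraph(
    resource_handle=plc.ResourceHandle(),
    graph_properties=plc.GraphProperties(is_symmetric=False, is_multigraph=False),
    src_or_offset_array=src,
    dst_or_index_array=dst,
    weight_array=None,
    store_transposed=False,
    renumber=False,
    do_expensive_check=False,
    vertices_array=cp.arange(4, dtype=cp.int32),
)
```
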
Closes #4067 Authors: - Chuck Hastings (https://github.com/ChuckHastings) - Rick Ratzel (https://github.com/rlratzel) Approvers: - Seunghwa Kang (https://github.com/seunghwak) - Ray Douglass (https://github.com/raydouglass) - Erik Welch (https://github.com/eriknw) URL: https://github.com/rapidsai/cugraph/pull/4077 --- ci/test_python.sh | 5 -- .../cugraph/detail/utility_wrappers.hpp | 17 +++++- cpp/src/c_api/abstract_functor.hpp | 12 +++-- cpp/src/c_api/bfs.cpp | 17 +++++- cpp/src/detail/utility_wrappers.cu | 19 ++++++- python/nx-cugraph/nx_cugraph/classes/graph.py | 17 +++++- python/nx-cugraph/nx_cugraph/interface.py | 13 +---- python/pylibcugraph/pylibcugraph/graphs.pyx | 53 ++++++++++--------- 8 files changed, 105 insertions(+), 48 deletions(-) diff --git a/ci/test_python.sh b/ci/test_python.sh index d8288758f3c..7eb5a08edc8 100755 --- a/ci/test_python.sh +++ b/ci/test_python.sh @@ -111,11 +111,6 @@ popd rapids-logger "pytest networkx using nx-cugraph backend" pushd python/nx-cugraph ./run_nx_tests.sh -# Individually run tests that are skipped above b/c they may run out of memory -PYTEST_NO_SKIP=True ./run_nx_tests.sh --cov-append -k "TestDAG and test_antichains" -PYTEST_NO_SKIP=True ./run_nx_tests.sh --cov-append -k "TestMultiDiGraph_DAGLCA and test_all_pairs_lca_pairs_without_lca" -PYTEST_NO_SKIP=True ./run_nx_tests.sh --cov-append -k "TestDAGLCA and test_all_pairs_lca_pairs_without_lca" -PYTEST_NO_SKIP=True ./run_nx_tests.sh --cov-append -k "TestEfficiency and test_using_ego_graph" # run_nx_tests.sh outputs coverage data, so check that total coverage is >0.0% # in case nx-cugraph failed to load but fallback mode allowed the run to pass. _coverage=$(coverage report|grep "^TOTAL") diff --git a/cpp/include/cugraph/detail/utility_wrappers.hpp b/cpp/include/cugraph/detail/utility_wrappers.hpp index faa0fbb841b..61ac1bd2804 100644 --- a/cpp/include/cugraph/detail/utility_wrappers.hpp +++ b/cpp/include/cugraph/detail/utility_wrappers.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * Copyright (c) 2021-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -174,5 +174,20 @@ bool is_equal(raft::handle_t const& handle, raft::device_span span1, raft::device_span span2); +/** + * @brief Count the number of times a value appears in a span + * + * @tparam data_t type of data in span + * @param handle RAFT handle object to encapsulate resources (e.g. CUDA stream, communicator, and + * handles to various CUDA libraries) to run graph algorithms. + * @param span The span of data to compare + * @param value The value to count + * @return The count of how many instances of that value occur + */ +template +size_t count_values(raft::handle_t const& handle, + raft::device_span span, + data_t value); + } // namespace detail } // namespace cugraph diff --git a/cpp/src/c_api/abstract_functor.hpp b/cpp/src/c_api/abstract_functor.hpp index 7bff5b37380..72b433aa9af 100644 --- a/cpp/src/c_api/abstract_functor.hpp +++ b/cpp/src/c_api/abstract_functor.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * Copyright (c) 2021-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -32,8 +32,14 @@ struct abstract_functor { void unsupported() { - error_code_ = CUGRAPH_UNSUPPORTED_TYPE_COMBINATION; - error_->error_message_ = "Type Dispatcher executing unsupported combination of types"; + mark_error(CUGRAPH_UNSUPPORTED_TYPE_COMBINATION, + "Type Dispatcher executing unsupported combination of types"); + } + + void mark_error(cugraph_error_code_t error_code, std::string const& error_message) + { + error_code_ = error_code; + error_->error_message_ = error_message; } }; diff --git a/cpp/src/c_api/bfs.cpp b/cpp/src/c_api/bfs.cpp index ae7667375d2..32841b2dd3c 100644 --- a/cpp/src/c_api/bfs.cpp +++ b/cpp/src/c_api/bfs.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * Copyright (c) 2021-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -113,6 +113,21 @@ struct bfs_functor : public abstract_functor { graph_view.local_vertex_partition_range_last(), do_expensive_check_); + size_t invalid_count = cugraph::detail::count_values( + handle_, + raft::device_span{sources.data(), sources.size()}, + cugraph::invalid_vertex_id::value); + + if constexpr (multi_gpu) { + invalid_count = cugraph::host_scalar_allreduce( + handle_.get_comms(), invalid_count, raft::comms::op_t::SUM, handle_.get_stream()); + } + + if (invalid_count != 0) { + mark_error(CUGRAPH_INVALID_INPUT, "Found invalid vertex in the input sources"); + return; + } + cugraph::bfs( handle_, graph_view, diff --git a/cpp/src/detail/utility_wrappers.cu b/cpp/src/detail/utility_wrappers.cu index 2d5bf6215b1..9100ecbd5e1 100644 --- a/cpp/src/detail/utility_wrappers.cu +++ b/cpp/src/detail/utility_wrappers.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * Copyright (c) 2021-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,11 +15,13 @@ */ #include #include +#include #include #include +#include #include #include #include @@ -227,5 +229,20 @@ template bool is_equal(raft::handle_t const& handle, raft::device_span span1, raft::device_span span2); +template +size_t count_values(raft::handle_t const& handle, + raft::device_span span, + data_t value) +{ + return thrust::count(handle.get_thrust_policy(), span.begin(), span.end(), value); +} + +template size_t count_values(raft::handle_t const& handle, + raft::device_span span, + int32_t value); +template size_t count_values(raft::handle_t const& handle, + raft::device_span span, + int64_t value); + } // namespace detail } // namespace cugraph diff --git a/python/nx-cugraph/nx_cugraph/classes/graph.py b/python/nx-cugraph/nx_cugraph/classes/graph.py index cdd3f744f24..cb6b4e7ae42 100644 --- a/python/nx-cugraph/nx_cugraph/classes/graph.py +++ b/python/nx-cugraph/nx_cugraph/classes/graph.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. +# Copyright (c) 2023-2024, NVIDIA CORPORATION. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at @@ -65,6 +65,7 @@ class Graph: key_to_id: dict[NodeKey, IndexValue] | None _id_to_key: list[NodeKey] | None _N: int + _node_ids: cp.ndarray[IndexValue] | None # holds plc.SGGraph.vertices_array data # Used by graph._get_plc_graph _plc_type_map: ClassVar[dict[np.dtype, np.dtype]] = { @@ -116,6 +117,7 @@ def from_coo( new_graph.key_to_id = None if key_to_id is None else dict(key_to_id) new_graph._id_to_key = None if id_to_key is None else list(id_to_key) new_graph._N = op.index(N) # Ensure N is integral + new_graph._node_ids = None new_graph.graph = new_graph.graph_attr_dict_factory() new_graph.graph.update(attr) size = new_graph.src_indices.size @@ -157,6 +159,16 @@ def from_coo( f"(got {new_graph.dst_indices.dtype.name})." ) new_graph.dst_indices = dst_indices + + # If the graph contains isolates, plc.SGGraph() must be passed a value + # for vertices_array that contains every vertex ID, since the + # src/dst_indices arrays will not contain IDs for isolates. Create this + # only if needed. Like src/dst_indices, the _node_ids array must be + # maintained for the lifetime of the plc.SGGraph + isolates = nxcg.algorithms.isolate._isolates(new_graph) + if len(isolates) > 0: + new_graph._node_ids = cp.arange(new_graph._N, dtype=index_dtype) + return new_graph @classmethod @@ -405,6 +417,7 @@ def clear(self) -> None: self.src_indices = cp.empty(0, self.src_indices.dtype) self.dst_indices = cp.empty(0, self.dst_indices.dtype) self._N = 0 + self._node_ids = None self.key_to_id = None self._id_to_key = None @@ -637,6 +650,7 @@ def _get_plc_graph( dst_indices = self.dst_indices if switch_indices: src_indices, dst_indices = dst_indices, src_indices + return plc.SGGraph( resource_handle=plc.ResourceHandle(), graph_properties=plc.GraphProperties( @@ -649,6 +663,7 @@ def _get_plc_graph( store_transposed=store_transposed, renumber=False, do_expensive_check=False, + vertices_array=self._node_ids, ) def _sort_edge_indices(self, primary="src"): diff --git a/python/nx-cugraph/nx_cugraph/interface.py b/python/nx-cugraph/nx_cugraph/interface.py index 3f6449f571a..34eb5969869 100644 --- a/python/nx-cugraph/nx_cugraph/interface.py +++ b/python/nx-cugraph/nx_cugraph/interface.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. +# Copyright (c) 2023-2024, NVIDIA CORPORATION. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at @@ -242,20 +242,9 @@ def key(testpath): ) too_slow = "Too slow to run" - maybe_oom = "out of memory in CI" skip = { key("test_tree_isomorphism.py:test_positive"): too_slow, key("test_tree_isomorphism.py:test_negative"): too_slow, - key("test_efficiency.py:TestEfficiency.test_using_ego_graph"): maybe_oom, - key("test_dag.py:TestDAG.test_antichains"): maybe_oom, - key( - "test_lowest_common_ancestors.py:" - "TestDAGLCA.test_all_pairs_lca_pairs_without_lca" - ): maybe_oom, - key( - "test_lowest_common_ancestors.py:" - "TestMultiDiGraph_DAGLCA.test_all_pairs_lca_pairs_without_lca" - ): maybe_oom, # These repeatedly call `bfs_layers`, which converts the graph every call key( "test_vf2pp.py:TestGraphISOVF2pp.test_custom_graph2_different_labels" diff --git a/python/pylibcugraph/pylibcugraph/graphs.pyx b/python/pylibcugraph/pylibcugraph/graphs.pyx index b3065fa0684..76ad7690840 100644 --- a/python/pylibcugraph/pylibcugraph/graphs.pyx +++ b/python/pylibcugraph/pylibcugraph/graphs.pyx @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2023, NVIDIA CORPORATION. +# Copyright (c) 2022-2024, NVIDIA CORPORATION. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -64,7 +64,7 @@ cdef class SGGraph(_GPUGraph): Object defining intended properties for the graph. src_or_offset_array : device array type - Device array containing either the vertex identifiers of the source of + Device array containing either the vertex identifiers of the source of each directed edge if represented in COO format or the offset if CSR format. In the case of a COO, the order of the array corresponds to the ordering of the dst_or_index_array, where the ith item in @@ -77,9 +77,14 @@ cdef class SGGraph(_GPUGraph): CSR format. In the case of a COO, The order of the array corresponds to the ordering of the src_offset_array, where the ith item in src_offset_array and the ith item in dst_index_array define the ith edge of the graph. - + vertices_array : device array type - Device array containing the isolated vertices of the graph. + Device array containing all vertices of the graph. This array is + optional, but must be used if the graph contains isolated vertices + which cannot be represented in the src_or_offset_array and + dst_index_array arrays. If specified, this array must contain every + vertex identifier, including vertex identifiers that are already + included in the src_or_offset_array and dst_index_array arrays. weight_array : device array type Device array containing the weight values of each directed edge. The @@ -99,25 +104,25 @@ cdef class SGGraph(_GPUGraph): do_expensive_check : bool If True, performs more extensive tests on the inputs to ensure validitity, at the expense of increased run time. - + edge_id_array : device array type Device array containing the edge ids of each directed edge. Must match the ordering of the src/dst arrays. Optional (may be null). If provided, edge_type_array must also be provided. - + edge_type_array : device array type Device array containing the edge types of each directed edge. Must match the ordering of the src/dst/edge_id arrays. Optional (may be null). If provided, edge_id_array must be provided. 
- + input_array_format: str, optional (default='COO') Input representation used to construct a graph COO: arrays represent src_array and dst_array CSR: arrays represent offset_array and index_array - + drop_self_loops : bool, optional (default='False') If true, drop any self loops that exist in the provided edge list. - + drop_multi_edges: bool, optional (default='False') If true, drop any multi edges that exist in the provided edge list @@ -178,7 +183,7 @@ cdef class SGGraph(_GPUGraph): cdef cugraph_type_erased_device_array_view_t* srcs_or_offsets_view_ptr = \ create_cugraph_type_erased_device_array_view_from_py_obj( src_or_offset_array - ) + ) cdef cugraph_type_erased_device_array_view_t* dsts_or_indices_view_ptr = \ create_cugraph_type_erased_device_array_view_from_py_obj( dst_or_index_array @@ -192,7 +197,7 @@ cdef class SGGraph(_GPUGraph): ) self.edge_id_view_ptr = create_cugraph_type_erased_device_array_view_from_py_obj( edge_id_array - ) + ) cdef cugraph_type_erased_device_array_view_t* edge_type_view_ptr = \ create_cugraph_type_erased_device_array_view_from_py_obj( edge_type_array @@ -237,7 +242,7 @@ cdef class SGGraph(_GPUGraph): assert_success(error_code, error_ptr, "cugraph_sg_graph_create_from_csr()") - + else: raise ValueError("invalid 'input_array_format'. Only " "'COO' and 'CSR' format are supported." @@ -282,7 +287,7 @@ cdef class MGGraph(_GPUGraph): each directed edge. The order of the array corresponds to the ordering of the src_array, where the ith item in src_array and the ith item in dst_array define the ith edge of the graph. - + vertices_array : device array type Device array containing the isolated vertices of the graph. @@ -295,12 +300,12 @@ cdef class MGGraph(_GPUGraph): store_transposed : bool Set to True if the graph should be transposed. This is required for some algorithms, such as pagerank. - + num_arrays : size_t Number of arrays. - + If provided, all list of device arrays should be of the same size. - + do_expensive_check : bool If True, performs more extensive tests on the inputs to ensure validitity, at the expense of increased run time. @@ -309,15 +314,15 @@ cdef class MGGraph(_GPUGraph): Device array containing the edge ids of each directed edge. Must match the ordering of the src/dst arrays. Optional (may be null). If provided, edge_type_array must also be provided. - + edge_type_array : device array type Device array containing the edge types of each directed edge. Must match the ordering of the src/dst/edge_id arrays. Optional (may be null). If provided, edge_id_array must be provided. - + drop_self_loops : bool, optional (default='False') If true, drop any self loops that exist in the provided edge list. 
- + drop_multi_edges: bool, optional (default='False') If true, drop any multi edges that exist in the provided edge list """ @@ -357,12 +362,12 @@ cdef class MGGraph(_GPUGraph): dst_array = [dst_array] if not any(dst_array): dst_array = dst_array * num_arrays - + if not isinstance(weight_array, list): weight_array = [weight_array] if not any(weight_array): weight_array = weight_array * num_arrays - + if not isinstance(edge_id_array, list): edge_id_array = [edge_id_array] if not any(edge_id_array): @@ -372,7 +377,7 @@ cdef class MGGraph(_GPUGraph): edge_type_array = [edge_type_array] if not any(edge_type_array): edge_type_array = edge_type_array * num_arrays - + if not isinstance(vertices_array, list): vertices_array = [vertices_array] if not any(vertices_array): @@ -394,7 +399,7 @@ cdef class MGGraph(_GPUGraph): if edge_id_array is not None and len(edge_id_array[i]) != len(src_array[i]): raise ValueError('Edge id array must be same length as edgelist') - + assert_CAI_type(edge_type_array[i], "edge_type_array", True) if edge_type_array[i] is not None and len(edge_type_array[i]) != len(src_array[i]): raise ValueError('Edge type array must be same length as edgelist') @@ -421,7 +426,7 @@ cdef class MGGraph(_GPUGraph): malloc( num_arrays * sizeof(cugraph_type_erased_device_array_view_t*)) vertices_view_ptr_ptr[i] = \ - create_cugraph_type_erased_device_array_view_from_py_obj(vertices_array[i]) + create_cugraph_type_erased_device_array_view_from_py_obj(vertices_array[i]) if weight_array[i] is not None: if i == 0: From 5d4ba388ef30bbc01b7fc1a1b61aaaa91fafb918 Mon Sep 17 00:00:00 2001 From: Naim <110031745+naimnv@users.noreply.github.com> Date: Fri, 12 Jan 2024 19:34:58 +0100 Subject: [PATCH 3/4] MNMG ECG (#4030) ECG based on Louvain. Authors: - Naim (https://github.com/naimnv) Approvers: - Seunghwa Kang (https://github.com/seunghwak) - Chuck Hastings (https://github.com/ChuckHastings) URL: https://github.com/rapidsai/cugraph/pull/4030 --- cpp/CMakeLists.txt | 5 +- cpp/include/cugraph/algorithms.hpp | 126 +++++++--- .../cugraph/detail/collect_comm_wrapper.hpp | 3 +- .../cugraph/detail/shuffle_wrappers.hpp | 25 +- cpp/src/c_api/louvain.cpp | 27 +- cpp/src/community/ecg_impl.cuh | 176 +++++++++++++ cpp/src/community/ecg_mg.cu | 92 +++++++ cpp/src/community/ecg_sg.cu | 92 +++++++ cpp/src/community/louvain_impl.cuh | 44 +++- cpp/src/community/louvain_mg.cu | 14 +- cpp/src/community/louvain_sg.cu | 14 +- cpp/src/detail/permute_range.cu | 187 ++++++++++++++ cpp/tests/CMakeLists.txt | 6 +- cpp/tests/community/louvain_test.cpp | 50 ++-- cpp/tests/community/mg_ecg_test.cpp | 233 ++++++++++++++++++ cpp/tests/community/mg_louvain_test.cpp | 19 +- 16 files changed, 1025 insertions(+), 88 deletions(-) create mode 100644 cpp/src/community/ecg_impl.cuh create mode 100644 cpp/src/community/ecg_mg.cu create mode 100644 cpp/src/community/ecg_sg.cu create mode 100644 cpp/src/detail/permute_range.cu create mode 100644 cpp/tests/community/mg_ecg_test.cpp diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index d8f359b5bcb..ecc2ebf06d3 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -1,5 +1,5 @@ #============================================================================= -# Copyright (c) 2018-2023, NVIDIA CORPORATION. +# Copyright (c) 2018-2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -187,6 +187,7 @@ endif() set(CUGRAPH_SOURCES src/detail/shuffle_vertices.cu + src/detail/permute_range.cu src/detail/shuffle_vertex_pairs.cu src/detail/collect_local_vertex_values.cu src/detail/groupby_and_count.cu @@ -218,6 +219,8 @@ set(CUGRAPH_SOURCES src/community/louvain_mg.cu src/community/leiden_sg.cu src/community/leiden_mg.cu + src/community/ecg_sg.cu + src/community/ecg_mg.cu src/community/legacy/louvain.cu src/community/legacy/ktruss.cu src/community/legacy/ecg.cu diff --git a/cpp/include/cugraph/algorithms.hpp b/cpp/include/cugraph/algorithms.hpp index 8501eedce5c..bb721468106 100644 --- a/cpp/include/cugraph/algorithms.hpp +++ b/cpp/include/cugraph/algorithms.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. + * Copyright (c) 2020-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -541,30 +541,37 @@ weight_t hungarian(raft::handle_t const& handle, * community hierarchies in large networks, J Stat Mech P10008 (2008), * http://arxiv.org/abs/0803.0476 * - * @throws cugraph::logic_error when an error occurs. - * - * @tparam graph_view_t Type of graph + * @throws cugraph::logic_error when an error occurs. * - * @param[in] handle Library handle (RAFT). If a communicator is set in the handle, - * @param[in] graph input graph object - * @param[out] clustering Pointer to device array where the clustering should be stored - * @param[in] max_level (optional) maximum number of levels to run (default 100) - * @param[in] threshold (optional) threshold for convergence at each level (default - * 1e-7) - * @param[in] resolution (optional) The value of the resolution parameter to use. - * Called gamma in the modularity formula, this changes the size - * of the communities. Higher resolutions lead to more smaller - * communities, lower resolutions lead to fewer larger - * communities. (default 1) + * @tparam vertex_t Type of vertex identifiers. Needs to be an integral type. + * @tparam edge_t Type of edge identifiers. Needs to be an integral type. + * @tparam weight_t Type of edge weights. Needs to be a floating point type. + * @tparam multi_gpu Flag indicating whether template instantiation should target single-GPU (false) * - * @return a pair containing: - * 1) number of levels of the returned clustering - * 2) modularity of the returned clustering + * @param[in] handle Library handle (RAFT). If a communicator is set in the handle, + * @param[in] rng_state The RngState instance holding pseudo-random number generator state. + * @param[in] graph_view Input graph view object. + * @param[in] edge_weight_view Optional view object holding edge weights for @p graph_view. + * If @pedge_weight_view.has_value() == false, edge weights + * are assumed to be 1.0. + @param[out] clustering Pointer to device array where the clustering should be stored + * @param[in] max_level (optional) maximum number of levels to run (default 100) + * @param[in] threshold (optional) threshold for convergence at each level (default 1e-7) + * @param[in] resolution (optional) The value of the resolution parameter to use. + * Called gamma in the modularity formula, this changes the size + * of the communities. Higher resolutions lead to more smaller + * communities, lower resolutions lead to fewer larger + * communities. 
(default 1) + * + * @return a pair containing: + * 1) number of levels of the returned clustering + * 2) modularity of the returned clustering * */ template std::pair louvain( raft::handle_t const& handle, + std::optional> rng_state, graph_view_t const& graph_view, std::optional> edge_weight_view, vertex_t* clustering, @@ -593,25 +600,33 @@ std::pair louvain( * * @throws cugraph::logic_error when an error occurs. * - * @tparam graph_view_t Type of graph - * - * @param[in] handle Library handle (RAFT) - * @param[in] graph_view Input graph view object - * @param[in] max_level (optional) maximum number of levels to run (default 100) - * @param[in] resolution (optional) The value of the resolution parameter to use. - * Called gamma in the modularity formula, this changes the size - * of the communities. Higher resolutions lead to more smaller - * communities, lower resolutions lead to fewer larger - * communities. (default 1) + * @tparam vertex_t Type of vertex identifiers. Needs to be an integral type. + * @tparam edge_t Type of edge identifiers. Needs to be an integral type. + * @tparam weight_t Type of edge weights. Needs to be a floating point type. + * @tparam multi_gpu Flag indicating whether template instantiation should target single-GPU (false) * - * @return a pair containing: - * 1) unique pointer to dendrogram - * 2) modularity of the returned clustering + * @param[in] handle Library handle (RAFT). If a communicator is set in the handle, + * @param[in] rng_state The RngState instance holding pseudo-random number generator state. + * @param[in] graph_view Input graph view object. + * @param[in] edge_weight_view Optional view object holding edge weights for @p graph_view. + * If @pedge_weight_view.has_value() == false, edge weights + * are assumed to be 1.0. + * @param[in] max_level (optional) maximum number of levels to run (default 100) + * @param[in] threshold (optional) threshold for convergence at each level (default 1e-7) + * @param[in] resolution (optional) The value of the resolution parameter to use. + * Called gamma in the modularity formula, this changes the size + * of the communities. Higher resolutions lead to more smaller + * communities, lower resolutions lead to fewer larger + * communities. (default 1) + * @return a pair containing: + * 1) unique pointer to dendrogram + * 2) modularity of the returned clustering * */ template std::pair>, weight_t> louvain( raft::handle_t const& handle, + std::optional> rng_state, graph_view_t const& graph_view, std::optional> edge_weight_view, size_t max_level = 100, @@ -779,6 +794,55 @@ void ecg(raft::handle_t const& handle, vertex_t ensemble_size, vertex_t* clustering); +/** + * @brief Computes the ecg clustering of the given graph. + * + * ECG runs truncated Louvain on an ensemble of permutations of the input graph, + * then uses the ensemble partitions to determine weights for the input graph. + * The final result is found by running full Louvain on the input graph using + * the determined weights. See https://arxiv.org/abs/1809.05578 for further + * information. + * + * @throws cugraph::logic_error when an error occurs. + * + * @tparam vertex_t Type of vertex identifiers. Needs to be an integral type. + * @tparam edge_t Type of edge identifiers. Needs to be an integral type. + * @tparam weight_t Type of edge weights. Needs to be a floating point type. + * @tparam multi_gpu Flag indicating whether template instantiation should target single-GPU (false) + * + * @param[in] handle Library handle (RAFT). 
If a communicator is set in the handle, + * @param[in] rng_state The RngState instance holding pseudo-random number generator state. + * @param[in] graph_view Input graph view object + * @param[in] edge_weight_view View object holding edge weights for @p graph_view. + * @param[in] min_weight Minimum edge weight to use in the final call of the clustering + * algorithm if an edge does not appear in any of the ensemble runs. + * @param[in] ensemble_size The ensemble size parameter + * @param[in] max_level (optional) maximum number of levels to run (default 100) + * @param[in] threshold (optional) threshold for convergence at each level (default 1e-7) + * @param[in] resolution (optional) The value of the resolution parameter to use. + * Called gamma in the modularity formula, this changes the size + * of the communities. Higher resolutions lead to more smaller + * communities, lower resolutions lead to fewer larger + * communities. (default 1) + * + * @return a tuple containing: + * 1) Device vector containing clustering result + * 2) number of levels of the returned clustering + * 3) modularity of the returned clustering + * + */ +template +std::tuple, size_t, weight_t> ecg( + raft::handle_t const& handle, + raft::random::RngState& rng_state, + graph_view_t const& graph_view, + std::optional> edge_weight_view, + weight_t min_weight, + size_t ensemble_size, + size_t max_level = 100, + weight_t threshold = weight_t{1e-7}, + weight_t resolution = weight_t{1}); + /** * @brief Generate edges in a minimum spanning forest of an undirected weighted graph. * diff --git a/cpp/include/cugraph/detail/collect_comm_wrapper.hpp b/cpp/include/cugraph/detail/collect_comm_wrapper.hpp index b791c593f41..4a2f5d7c44e 100644 --- a/cpp/include/cugraph/detail/collect_comm_wrapper.hpp +++ b/cpp/include/cugraph/detail/collect_comm_wrapper.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,6 +15,7 @@ */ #pragma once +#include #include #include diff --git a/cpp/include/cugraph/detail/shuffle_wrappers.hpp b/cpp/include/cugraph/detail/shuffle_wrappers.hpp index 55ea6a0e355..c77ecb7aa01 100644 --- a/cpp/include/cugraph/detail/shuffle_wrappers.hpp +++ b/cpp/include/cugraph/detail/shuffle_wrappers.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * Copyright (c) 2021-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,6 +16,7 @@ #pragma once #include +#include #include #include @@ -138,6 +139,28 @@ shuffle_ext_vertex_value_pairs_to_local_gpu_by_vertex_partitioning( rmm::device_uvector&& vertices, rmm::device_uvector&& values); +/** + * @brief Permute a range. + * + * @tparam vertex_t Type of vertex identifiers. Needs to be an integral type. + * + * @param[in] handle RAFT handle object to encapsulate resources (e.g. CUDA stream, communicator, + * and handles to various CUDA libraries) to run graph algorithms. + * @param[in] rng_state The RngState instance holding pseudo-random number generator state. + * @param[in] local_range_size Size of local range assigned to this process. + * @param[in] local_start Start of local range assigned to this process. + * + * @return permuted range. 
+ */ + +template +rmm::device_uvector permute_range(raft::handle_t const& handle, + raft::random::RngState& rng_state, + vertex_t local_start, + vertex_t local_range_size, + bool multi_gpu = false, + bool do_expensive_check = false); + /** * @brief Shuffle internal (i.e. renumbered) vertices to their local GPUs based on vertex * partitioning. diff --git a/cpp/src/c_api/louvain.cpp b/cpp/src/c_api/louvain.cpp index 0e48b29388a..a131ee6a3ad 100644 --- a/cpp/src/c_api/louvain.cpp +++ b/cpp/src/c_api/louvain.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * Copyright (c) 2022-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -95,18 +95,19 @@ struct louvain_functor : public cugraph::c_api::abstract_functor { // could add support in Louvain for std::nullopt as the edge weights behaving // as desired and only instantiating a real edge_property_view_t for the // coarsened graphs. - auto [level, modularity] = - cugraph::louvain(handle_, - graph_view, - (edge_weights != nullptr) - ? std::make_optional(edge_weights->view()) - : std::make_optional(cugraph::c_api::create_constant_edge_property( - handle_, graph_view, weight_t{1}) - .view()), - clusters.data(), - max_level_, - static_cast(threshold_), - static_cast(resolution_)); + auto [level, modularity] = cugraph::louvain( + handle_, + std::optional>{std::nullopt}, + graph_view, + (edge_weights != nullptr) + ? std::make_optional(edge_weights->view()) + : std::make_optional( + cugraph::c_api::create_constant_edge_property(handle_, graph_view, weight_t{1}) + .view()), + clusters.data(), + max_level_, + static_cast(threshold_), + static_cast(resolution_)); rmm::device_uvector vertices(graph_view.local_vertex_partition_range_size(), handle_.get_stream()); diff --git a/cpp/src/community/ecg_impl.cuh b/cpp/src/community/ecg_impl.cuh new file mode 100644 index 00000000000..f885952dfe6 --- /dev/null +++ b/cpp/src/community/ecg_impl.cuh @@ -0,0 +1,176 @@ +/* + * Copyright (c) 2023-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include +#include +#include + +#include +#include +#include +#include +#include + +#include +#include +#include + +namespace cugraph { + +namespace detail { + +template +std::tuple, size_t, weight_t> ecg( + raft::handle_t const& handle, + raft::random::RngState& rng_state, + graph_view_t const& graph_view, + std::optional> edge_weight_view, + weight_t min_weight, + size_t ensemble_size, + size_t max_level, + weight_t threshold, + weight_t resolution) +{ + using graph_view_t = cugraph::graph_view_t; + + CUGRAPH_EXPECTS(min_weight >= weight_t{0.0}, + "Invalid input arguments: min_weight must be positive"); + CUGRAPH_EXPECTS(ensemble_size >= 1, + "Invalid input arguments: ensemble_size must be a non-zero integer"); + CUGRAPH_EXPECTS( + threshold > 0.0 && threshold <= 1.0, + "Invalid input arguments: threshold must be a positive number in range (0.0, 1.0]"); + CUGRAPH_EXPECTS( + resolution > 0.0 && resolution <= 1.0, + "Invalid input arguments: resolution must be a positive number in range (0.0, 1.0]"); + + edge_src_property_t src_cluster_assignments(handle, graph_view); + edge_dst_property_t dst_cluster_assignments(handle, graph_view); + edge_property_t modified_edge_weights(handle, graph_view); + + cugraph::fill_edge_property(handle, graph_view, weight_t{0}, modified_edge_weights); + + weight_t modularity = -1.0; + rmm::device_uvector cluster_assignments(graph_view.local_vertex_partition_range_size(), + handle.get_stream()); + + for (size_t i = 0; i < ensemble_size; i++) { + std::tie(std::ignore, modularity) = cugraph::louvain( + handle, + std::make_optional(std::reference_wrapper(rng_state)), + graph_view, + edge_weight_view, + cluster_assignments.data(), + size_t{1}, + threshold, + resolution); + + cugraph::update_edge_src_property( + handle, graph_view, cluster_assignments.begin(), src_cluster_assignments); + cugraph::update_edge_dst_property( + handle, graph_view, cluster_assignments.begin(), dst_cluster_assignments); + + cugraph::transform_e( + handle, + graph_view, + src_cluster_assignments.view(), + dst_cluster_assignments.view(), + modified_edge_weights.view(), + [] __device__(auto, auto, auto src_property, auto dst_property, auto edge_property) { + return edge_property + (src_property == dst_property); + }, + modified_edge_weights.mutable_view()); + } + + cugraph::transform_e( + handle, + graph_view, + edge_src_dummy_property_t{}.view(), + edge_dst_dummy_property_t{}.view(), + view_concat(*edge_weight_view, modified_edge_weights.view()), + [min_weight, ensemble_size = static_cast(ensemble_size)] __device__( + auto, auto, thrust::nullopt_t, thrust::nullopt_t, auto edge_properties) { + auto e_weight = thrust::get<0>(edge_properties); + auto e_frequency = thrust::get<1>(edge_properties); + return min_weight + (e_weight - min_weight) * e_frequency / ensemble_size; + }, + modified_edge_weights.mutable_view()); + + std::tie(max_level, modularity) = + cugraph::louvain(handle, + std::make_optional(std::reference_wrapper(rng_state)), + graph_view, + std::make_optional(modified_edge_weights.view()), + cluster_assignments.data(), + max_level, + threshold, + resolution); + + // Compute final modularity using original edge weights + weight_t total_edge_weight = + cugraph::compute_total_edge_weight(handle, graph_view, *edge_weight_view); + + if constexpr (multi_gpu) { + cugraph::update_edge_src_property( + handle, graph_view, cluster_assignments.begin(), src_cluster_assignments); + cugraph::update_edge_dst_property( + handle, graph_view, cluster_assignments.begin(), 
dst_cluster_assignments); + } + + auto [cluster_keys, cluster_weights] = cugraph::detail::compute_cluster_keys_and_values( + handle, graph_view, edge_weight_view, cluster_assignments, src_cluster_assignments); + + modularity = detail::compute_modularity(handle, + graph_view, + edge_weight_view, + src_cluster_assignments, + dst_cluster_assignments, + cluster_assignments, + cluster_weights, + total_edge_weight, + resolution); + + return std::make_tuple(std::move(cluster_assignments), max_level, modularity); +} + +} // namespace detail + +template +std::tuple, size_t, weight_t> ecg( + raft::handle_t const& handle, + raft::random::RngState& rng_state, + graph_view_t const& graph_view, + std::optional> edge_weight_view, + weight_t min_weight, + size_t ensemble_size, + size_t max_level, + weight_t threshold, + weight_t resolution) +{ + return detail::ecg(handle, + rng_state, + graph_view, + edge_weight_view, + min_weight, + ensemble_size, + max_level, + threshold, + resolution); +} + +} // namespace cugraph diff --git a/cpp/src/community/ecg_mg.cu b/cpp/src/community/ecg_mg.cu new file mode 100644 index 00000000000..9c910c70739 --- /dev/null +++ b/cpp/src/community/ecg_mg.cu @@ -0,0 +1,92 @@ +/* + * Copyright (c) 2023-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
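// A minimal host-side sketch of the edge re-weighting rule applied by the transform_e
// call in ecg_impl.cuh above: each edge weight is pulled toward min_weight and scaled by
// how often the edge's endpoints were co-clustered across the Louvain ensemble,
// w' = min_weight + (w - min_weight) * frequency / ensemble_size.
// The weights, counts, and parameter values below are invented example data.
#include <cstddef>
#include <iostream>
#include <vector>

int main()
{
  double const min_weight         = 0.1;
  std::size_t const ensemble_size = 10;

  std::vector<double> edge_weight{1.0, 1.0, 1.0};       // original edge weights
  std::vector<std::size_t> co_cluster_count{10, 5, 0};  // co-clustering frequency per edge

  for (std::size_t e = 0; e < edge_weight.size(); ++e) {
    double const reweighted = min_weight + (edge_weight[e] - min_weight) *
                                             static_cast<double>(co_cluster_count[e]) /
                                             static_cast<double>(ensemble_size);
    // an edge whose endpoints always agree keeps its full weight; one that never agrees
    // is reduced to min_weight before the final Louvain pass
    std::cout << "edge " << e << ": " << edge_weight[e] << " -> " << reweighted << "\n";
  }
  return 0;
}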
+ */ + +#include + +namespace cugraph { +template std::tuple, size_t, float> ecg( + raft::handle_t const& handle, + raft::random::RngState& rng_state, + graph_view_t const& graph_view, + std::optional> edge_weight_view, + + float min_weight, + size_t ensemble_size, + size_t max_level, + float threshold, + float resolution); + +template std::tuple, size_t, float> ecg( + raft::handle_t const& handle, + raft::random::RngState& rng_state, + graph_view_t const& graph_view, + std::optional> edge_weight_view, + + float min_weight, + size_t ensemble_size, + size_t max_level, + float threshold, + float resolution); + +template std::tuple, size_t, float> ecg( + raft::handle_t const& handle, + raft::random::RngState& rng_state, + graph_view_t const& graph_view, + std::optional> edge_weight_view, + + float min_weight, + size_t ensemble_size, + size_t max_level, + float threshold, + float resolution); + +template std::tuple, size_t, double> ecg( + raft::handle_t const& handle, + raft::random::RngState& rng_state, + graph_view_t const& graph_view, + std::optional> edge_weight_view, + + double min_weight, + size_t ensemble_size, + size_t max_level, + double threshold, + double resolution); + +template std::tuple, size_t, double> ecg( + raft::handle_t const& handle, + raft::random::RngState& rng_state, + graph_view_t const& graph_view, + std::optional> edge_weight_view, + + double min_weight, + size_t ensemble_size, + size_t max_level, + double threshold, + double resolution); + +template std::tuple, size_t, double> ecg( + raft::handle_t const& handle, + raft::random::RngState& rng_state, + graph_view_t const& graph_view, + std::optional> edge_weight_view, + + double min_weight, + size_t ensemble_size, + size_t max_level, + double threshold, + double resolution); + +} // namespace cugraph diff --git a/cpp/src/community/ecg_sg.cu b/cpp/src/community/ecg_sg.cu new file mode 100644 index 00000000000..530fb035ed5 --- /dev/null +++ b/cpp/src/community/ecg_sg.cu @@ -0,0 +1,92 @@ +/* + * Copyright (c) 2023-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include + +namespace cugraph { +template std::tuple, size_t, float> ecg( + raft::handle_t const& handle, + raft::random::RngState& rng_state, + graph_view_t const& graph_view, + std::optional> edge_weight_view, + + float min_weight, + size_t ensemble_size, + size_t max_level, + float threshold, + float resolution); + +template std::tuple, size_t, float> ecg( + raft::handle_t const& handle, + raft::random::RngState& rng_state, + graph_view_t const& graph_view, + std::optional> edge_weight_view, + + float min_weight, + size_t ensemble_size, + size_t max_level, + float threshold, + float resolution); + +template std::tuple, size_t, float> ecg( + raft::handle_t const& handle, + raft::random::RngState& rng_state, + graph_view_t const& graph_view, + std::optional> edge_weight_view, + + float min_weight, + size_t ensemble_size, + size_t max_level, + float threshold, + float resolution); + +template std::tuple, size_t, double> ecg( + raft::handle_t const& handle, + raft::random::RngState& rng_state, + graph_view_t const& graph_view, + std::optional> edge_weight_view, + + double min_weight, + size_t ensemble_size, + size_t max_level, + double threshold, + double resolution); + +template std::tuple, size_t, double> ecg( + raft::handle_t const& handle, + raft::random::RngState& rng_state, + graph_view_t const& graph_view, + std::optional> edge_weight_view, + + double min_weight, + size_t ensemble_size, + size_t max_level, + double threshold, + double resolution); + +template std::tuple, size_t, double> ecg( + raft::handle_t const& handle, + raft::random::RngState& rng_state, + graph_view_t const& graph_view, + std::optional> edge_weight_view, + + double min_weight, + size_t ensemble_size, + size_t max_level, + double threshold, + double resolution); + +} // namespace cugraph diff --git a/cpp/src/community/louvain_impl.cuh b/cpp/src/community/louvain_impl.cuh index 7777921a091..4919dda5a75 100644 --- a/cpp/src/community/louvain_impl.cuh +++ b/cpp/src/community/louvain_impl.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. + * Copyright (c) 2020-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
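// The ecg_sg.cu and ecg_mg.cu files above contain nothing but explicit instantiations:
// the templated algorithm body lives in ecg_impl.cuh, and each translation unit lists the
// supported (vertex_t, edge_t, weight_t, multi_gpu) combinations so every combination is
// compiled exactly once. A minimal standard-C++ sketch of the same split; the file layout
// comments and the scale() function are invented for illustration.
#include <iostream>

// definition that would live in a *_impl.hpp-style header
template <typename T>
T scale(T value, T factor)
{
  return value * factor;
}

// explicit instantiations that would live in per-type translation units
template float scale<float>(float, float);
template double scale<double>(double, double);

int main()
{
  std::cout << scale(2.0f, 1.5f) << ' ' << scale(2.0, 1.5) << '\n';
  return 0;
}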
@@ -18,15 +18,18 @@ // #define TIMING +// FIXME: Only outstanding items preventing this becoming a .hpp file +#include + #include #include -#include -// FIXME: Only outstanding items preventing this becoming a .hpp file +#include #include #include #include #include +#include #include namespace cugraph { @@ -44,6 +47,7 @@ void check_clustering(graph_view_t const& gr template std::pair>, weight_t> louvain( raft::handle_t const& handle, + std::optional> rng_state, graph_view_t const& graph_view, std::optional> edge_weight_view, size_t max_level, @@ -82,11 +86,25 @@ std::pair>, weight_t> louvain( current_graph_view.local_vertex_partition_range_size(), handle.get_stream()); - detail::sequence_fill(handle.get_stream(), - dendrogram->current_level_begin(), - dendrogram->current_level_size(), - current_graph_view.local_vertex_partition_range_first()); - + if (rng_state) { + auto random_cluster_assignments = cugraph::detail::permute_range( + handle, + *rng_state, + current_graph_view.local_vertex_partition_range_first(), + current_graph_view.local_vertex_partition_range_size(), + multi_gpu); + + raft::copy(dendrogram->current_level_begin(), + random_cluster_assignments.begin(), + random_cluster_assignments.size(), + handle.get_stream()); + + } else { + detail::sequence_fill(handle.get_stream(), + dendrogram->current_level_begin(), + dendrogram->current_level_size(), + current_graph_view.local_vertex_partition_range_first()); + } // // Compute the vertex and cluster weights, these are different for each // graph in the hierarchical decomposition @@ -289,6 +307,7 @@ void flatten_dendrogram(raft::handle_t const& handle, template std::pair>, weight_t> louvain( raft::handle_t const& handle, + std::optional> rng_state, graph_view_t const& graph_view, std::optional> edge_weight_view, size_t max_level, @@ -298,7 +317,9 @@ std::pair>, weight_t> louvain( CUGRAPH_EXPECTS(!graph_view.has_edge_mask(), "unimplemented."); CUGRAPH_EXPECTS(edge_weight_view.has_value(), "Graph must be weighted"); - return detail::louvain(handle, graph_view, edge_weight_view, max_level, threshold, resolution); + + return detail::louvain( + handle, rng_state, graph_view, edge_weight_view, max_level, threshold, resolution); } template @@ -315,6 +336,7 @@ void flatten_dendrogram(raft::handle_t const& handle, template std::pair louvain( raft::handle_t const& handle, + std::optional> rng_state, graph_view_t const& graph_view, std::optional> edge_weight_view, vertex_t* clustering, @@ -330,8 +352,8 @@ std::pair louvain( std::unique_ptr> dendrogram; weight_t modularity; - std::tie(dendrogram, modularity) = - detail::louvain(handle, graph_view, edge_weight_view, max_level, threshold, resolution); + std::tie(dendrogram, modularity) = detail::louvain( + handle, rng_state, graph_view, edge_weight_view, max_level, threshold, resolution); detail::flatten_dendrogram(handle, graph_view, *dendrogram, clustering); diff --git a/cpp/src/community/louvain_mg.cu b/cpp/src/community/louvain_mg.cu index 0be32ed049f..51fb5e3d93d 100644 --- a/cpp/src/community/louvain_mg.cu +++ b/cpp/src/community/louvain_mg.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. + * Copyright (c) 2020-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
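// A minimal host-side sketch of what the new optional rng_state argument to louvain
// (added in the hunk above) changes: with std::nullopt the first dendrogram level is
// filled with the vertex IDs in order (sequence_fill), while with an RngState the same
// IDs are written in a random permutation (permute_range). The vertex count and seed
// below are invented example values.
#include <algorithm>
#include <iostream>
#include <numeric>
#include <random>
#include <vector>

int main()
{
  std::size_t const num_vertices = 8;
  std::vector<int> initial_clusters(num_vertices);

  // rng_state == std::nullopt: every vertex starts in its own cluster, in order
  std::iota(initial_clusters.begin(), initial_clusters.end(), 0);

  // rng_state provided: the same one-vertex-per-cluster labels, randomly reordered
  std::mt19937 rng{42};
  std::shuffle(initial_clusters.begin(), initial_clusters.end(), rng);

  for (int c : initial_clusters) { std::cout << c << ' '; }
  std::cout << '\n';
  return 0;
}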
@@ -22,6 +22,7 @@ namespace cugraph { template std::pair>, float> louvain( raft::handle_t const&, + std::optional>, graph_view_t const&, std::optional>, size_t, @@ -29,6 +30,7 @@ template std::pair>, float> louvain( float); template std::pair>, float> louvain( raft::handle_t const&, + std::optional>, graph_view_t const&, std::optional>, size_t, @@ -36,6 +38,7 @@ template std::pair>, float> louvain( float); template std::pair>, float> louvain( raft::handle_t const&, + std::optional>, graph_view_t const&, std::optional>, size_t, @@ -43,6 +46,7 @@ template std::pair>, float> louvain( float); template std::pair>, double> louvain( raft::handle_t const&, + std::optional>, graph_view_t const&, std::optional>, size_t, @@ -50,6 +54,7 @@ template std::pair>, double> louvain( double); template std::pair>, double> louvain( raft::handle_t const&, + std::optional>, graph_view_t const&, std::optional>, size_t, @@ -57,6 +62,7 @@ template std::pair>, double> louvain( double); template std::pair>, double> louvain( raft::handle_t const&, + std::optional>, graph_view_t const&, std::optional>, size_t, @@ -65,6 +71,7 @@ template std::pair>, double> louvain( template std::pair louvain( raft::handle_t const&, + std::optional>, graph_view_t const&, std::optional>, int32_t*, @@ -73,6 +80,7 @@ template std::pair louvain( float); template std::pair louvain( raft::handle_t const&, + std::optional>, graph_view_t const&, std::optional>, int32_t*, @@ -81,6 +89,7 @@ template std::pair louvain( double); template std::pair louvain( raft::handle_t const&, + std::optional>, graph_view_t const&, std::optional>, int32_t*, @@ -89,6 +98,7 @@ template std::pair louvain( float); template std::pair louvain( raft::handle_t const&, + std::optional>, graph_view_t const&, std::optional>, int32_t*, @@ -97,6 +107,7 @@ template std::pair louvain( double); template std::pair louvain( raft::handle_t const&, + std::optional>, graph_view_t const&, std::optional>, int64_t*, @@ -105,6 +116,7 @@ template std::pair louvain( float); template std::pair louvain( raft::handle_t const&, + std::optional>, graph_view_t const&, std::optional>, int64_t*, diff --git a/cpp/src/community/louvain_sg.cu b/cpp/src/community/louvain_sg.cu index 3fc0ffab928..557c219d424 100644 --- a/cpp/src/community/louvain_sg.cu +++ b/cpp/src/community/louvain_sg.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. + * Copyright (c) 2020-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -22,6 +22,7 @@ namespace cugraph { template std::pair>, float> louvain( raft::handle_t const&, + std::optional>, graph_view_t const&, std::optional>, size_t, @@ -29,6 +30,7 @@ template std::pair>, float> louvain( float); template std::pair>, float> louvain( raft::handle_t const&, + std::optional>, graph_view_t const&, std::optional>, size_t, @@ -36,6 +38,7 @@ template std::pair>, float> louvain( float); template std::pair>, float> louvain( raft::handle_t const&, + std::optional>, graph_view_t const&, std::optional>, size_t, @@ -43,6 +46,7 @@ template std::pair>, float> louvain( float); template std::pair>, double> louvain( raft::handle_t const&, + std::optional>, graph_view_t const&, std::optional>, size_t, @@ -50,6 +54,7 @@ template std::pair>, double> louvain( double); template std::pair>, double> louvain( raft::handle_t const&, + std::optional>, graph_view_t const&, std::optional>, size_t, @@ -57,6 +62,7 @@ template std::pair>, double> louvain( double); template std::pair>, double> louvain( raft::handle_t const&, + std::optional>, graph_view_t const&, std::optional>, size_t, @@ -65,6 +71,7 @@ template std::pair>, double> louvain( template std::pair louvain( raft::handle_t const&, + std::optional>, graph_view_t const&, std::optional>, int32_t*, @@ -73,6 +80,7 @@ template std::pair louvain( float); template std::pair louvain( raft::handle_t const&, + std::optional>, graph_view_t const&, std::optional>, int32_t*, @@ -81,6 +89,7 @@ template std::pair louvain( double); template std::pair louvain( raft::handle_t const&, + std::optional>, graph_view_t const&, std::optional>, int32_t*, @@ -89,6 +98,7 @@ template std::pair louvain( float); template std::pair louvain( raft::handle_t const&, + std::optional>, graph_view_t const&, std::optional>, int32_t*, @@ -97,6 +107,7 @@ template std::pair louvain( double); template std::pair louvain( raft::handle_t const&, + std::optional>, graph_view_t const&, std::optional>, int64_t*, @@ -105,6 +116,7 @@ template std::pair louvain( float); template std::pair louvain( raft::handle_t const&, + std::optional>, graph_view_t const&, std::optional>, int64_t*, diff --git a/cpp/src/detail/permute_range.cu b/cpp/src/detail/permute_range.cu new file mode 100644 index 00000000000..cc77f022616 --- /dev/null +++ b/cpp/src/detail/permute_range.cu @@ -0,0 +1,187 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include +#include +#include +#include +#include + +#include +#include +#include + +#include +#include + +#include +#include +#include + +namespace cugraph { + +namespace detail { + +template +rmm::device_uvector permute_range(raft::handle_t const& handle, + raft::random::RngState& rng_state, + vertex_t local_range_start, + vertex_t local_range_size, + bool multi_gpu, + bool do_expensive_check) +{ + if (do_expensive_check && multi_gpu) { + auto& comm = handle.get_comms(); + auto const comm_size = comm.get_size(); + auto const comm_rank = comm.get_rank(); + + auto global_start = + cugraph::host_scalar_bcast(handle.get_comms(), local_range_start, 0, handle.get_stream()); + auto sub_range_sizes = + cugraph::host_scalar_allgather(handle.get_comms(), local_range_size, handle.get_stream()); + std::exclusive_scan( + sub_range_sizes.begin(), sub_range_sizes.end(), sub_range_sizes.begin(), global_start); + CUGRAPH_EXPECTS( + sub_range_sizes[comm_rank] == local_range_start, + "Invalid input arguments: a rage must have contiguous and non-overlapping values"); + } + rmm::device_uvector permuted_integers(local_range_size, handle.get_stream()); + + // generate as many integers as #local_range_size on each GPU + detail::sequence_fill( + handle.get_stream(), permuted_integers.begin(), permuted_integers.size(), local_range_start); + + if (multi_gpu) { + // randomly distribute integers to all GPUs + auto& comm = handle.get_comms(); + auto const comm_size = comm.get_size(); + auto const comm_rank = comm.get_rank(); + + std::vector tx_value_counts(comm_size, 0); + + { + rmm::device_uvector d_target_ranks(permuted_integers.size(), handle.get_stream()); + + cugraph::detail::uniform_random_fill(handle.get_stream(), + d_target_ranks.data(), + d_target_ranks.size(), + vertex_t{0}, + vertex_t{comm_size}, + rng_state); + + thrust::sort_by_key(handle.get_thrust_policy(), + d_target_ranks.begin(), + d_target_ranks.end(), + permuted_integers.begin()); + + rmm::device_uvector d_reduced_ranks(comm_size, handle.get_stream()); + rmm::device_uvector d_reduced_counts(comm_size, handle.get_stream()); + + auto output_end = thrust::reduce_by_key(handle.get_thrust_policy(), + d_target_ranks.begin(), + d_target_ranks.end(), + thrust::make_constant_iterator(1), + d_reduced_ranks.begin(), + d_reduced_counts.begin(), + thrust::equal_to()); + + auto nr_output_pairs = + static_cast(thrust::distance(d_reduced_ranks.begin(), output_end.first)); + + std::vector h_reduced_ranks(comm_size); + std::vector h_reduced_counts(comm_size); + + raft::update_host( + h_reduced_ranks.data(), d_reduced_ranks.data(), nr_output_pairs, handle.get_stream()); + + raft::update_host( + h_reduced_counts.data(), d_reduced_counts.data(), nr_output_pairs, handle.get_stream()); + + for (int i = 0; i < static_cast(nr_output_pairs); i++) { + tx_value_counts[h_reduced_ranks[i]] = static_cast(h_reduced_counts[i]); + } + } + + std::tie(permuted_integers, std::ignore) = cugraph::shuffle_values( + handle.get_comms(), permuted_integers.begin(), tx_value_counts, handle.get_stream()); + } + + // permute locally + rmm::device_uvector fractional_random_numbers(permuted_integers.size(), + handle.get_stream()); + + cugraph::detail::uniform_random_fill(handle.get_stream(), + fractional_random_numbers.data(), + fractional_random_numbers.size(), + float{0.0}, + float{1.0}, + rng_state); + thrust::sort_by_key(handle.get_thrust_policy(), + fractional_random_numbers.begin(), + fractional_random_numbers.end(), + permuted_integers.begin()); + + if (multi_gpu) { + // take care 
of deficits and extras numbers + auto& comm = handle.get_comms(); + auto const comm_rank = comm.get_rank(); + + size_t nr_extras{0}; + size_t nr_deficits{0}; + if (permuted_integers.size() > static_cast(local_range_size)) { + nr_extras = permuted_integers.size() - static_cast(local_range_size); + } else { + nr_deficits = static_cast(local_range_size) - permuted_integers.size(); + } + + auto extra_cluster_ids = cugraph::detail::device_allgatherv( + handle, + comm, + raft::device_span(permuted_integers.data() + local_range_size, + nr_extras > 0 ? nr_extras : 0)); + + permuted_integers.resize(local_range_size, handle.get_stream()); + auto deficits = + cugraph::host_scalar_allgather(handle.get_comms(), nr_deficits, handle.get_stream()); + + std::exclusive_scan(deficits.begin(), deficits.end(), deficits.begin(), vertex_t{0}); + + raft::copy(permuted_integers.data() + local_range_size - nr_deficits, + extra_cluster_ids.begin() + deficits[comm_rank], + nr_deficits, + handle.get_stream()); + } + + assert(permuted_integers.size() == local_range_size); + return permuted_integers; +} + +template rmm::device_uvector permute_range(raft::handle_t const& handle, + raft::random::RngState& rng_state, + int32_t local_range_start, + int32_t local_range_size, + bool multi_gpu, + bool do_expensive_check); + +template rmm::device_uvector permute_range(raft::handle_t const& handle, + raft::random::RngState& rng_state, + int64_t local_range_start, + int64_t local_range_size, + bool multi_gpu, + bool do_expensive_check); + +} // namespace detail +} // namespace cugraph diff --git a/cpp/tests/CMakeLists.txt b/cpp/tests/CMakeLists.txt index e9c6dc446af..d9d2f677abc 100644 --- a/cpp/tests/CMakeLists.txt +++ b/cpp/tests/CMakeLists.txt @@ -1,5 +1,5 @@ #============================================================================= -# Copyright (c) 2019-2023, NVIDIA CORPORATION. +# Copyright (c) 2019-2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -521,6 +521,10 @@ if(BUILD_CUGRAPH_MG_TESTS) # - MG LEIDEN tests -------------------------------------------------------------------------- ConfigureTestMG(MG_LEIDEN_TEST community/mg_leiden_test.cpp) + ############################################################################################### + # - MG ECG tests -------------------------------------------------------------------------- + ConfigureTestMG(MG_ECG_TEST community/mg_ecg_test.cpp) + ############################################################################################### # - MG MIS tests ------------------------------------------------------------------------------ ConfigureTestMG(MG_MIS_TEST community/mg_mis_test.cu) diff --git a/cpp/tests/community/louvain_test.cpp b/cpp/tests/community/louvain_test.cpp index 8de9cdaf4a8..a39793994d1 100644 --- a/cpp/tests/community/louvain_test.cpp +++ b/cpp/tests/community/louvain_test.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2019-2024, NVIDIA CORPORATION. All rights reserved. * * NVIDIA CORPORATION and its licensors retain all intellectual property * and proprietary rights in and to this software, related documentation @@ -174,27 +174,39 @@ class Tests_Louvain weight_t modularity; if (resolution) { - std::tie(level, modularity) = - cugraph::louvain(handle, - graph_view, - edge_weight_view, - clustering_v.data(), - max_level ? *max_level : size_t{100}, - threshold ? 
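// A minimal host-side sketch of the local permutation step inside permute_range.cu above:
// fill a contiguous range of integers, draw one uniform random key per integer, then sort
// the integers by their keys (sequence_fill + uniform_random_fill + thrust::sort_by_key on
// the device). Range start, size, and seed are invented example values; the multi-GPU
// shuffle and the deficit/extra rebalancing are omitted here.
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <numeric>
#include <random>
#include <vector>

int main()
{
  std::int64_t const local_range_start = 100;
  std::size_t const local_range_size   = 10;

  std::vector<std::int64_t> integers(local_range_size);
  std::iota(integers.begin(), integers.end(), local_range_start);

  std::vector<float> keys(local_range_size);
  std::mt19937 rng{42};
  std::uniform_real_distribution<float> dist(0.0f, 1.0f);
  std::generate(keys.begin(), keys.end(), [&] { return dist(rng); });

  // sorting the integers by their random keys yields a random permutation of the range
  std::vector<std::size_t> order(local_range_size);
  std::iota(order.begin(), order.end(), 0);
  std::sort(order.begin(), order.end(),
            [&](std::size_t a, std::size_t b) { return keys[a] < keys[b]; });

  for (std::size_t i : order) { std::cout << integers[i] << ' '; }
  std::cout << '\n';
  return 0;
}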
static_cast(*threshold) : weight_t{1e-7}, - static_cast(*resolution)); + std::tie(level, modularity) = cugraph::louvain( + handle, + std::optional>{std::nullopt}, + graph_view, + edge_weight_view, + clustering_v.data(), + max_level ? *max_level : size_t{100}, + threshold ? static_cast(*threshold) : weight_t{1e-7}, + static_cast(*resolution)); } else if (threshold) { - std::tie(level, modularity) = cugraph::louvain(handle, - graph_view, - edge_weight_view, - clustering_v.data(), - max_level ? *max_level : size_t{100}, - static_cast(*threshold)); + std::tie(level, modularity) = cugraph::louvain( + handle, + std::optional>{std::nullopt}, + graph_view, + edge_weight_view, + clustering_v.data(), + max_level ? *max_level : size_t{100}, + static_cast(*threshold)); } else if (max_level) { - std::tie(level, modularity) = - cugraph::louvain(handle, graph_view, edge_weight_view, clustering_v.data(), *max_level); + std::tie(level, modularity) = cugraph::louvain( + handle, + std::optional>{std::nullopt}, + graph_view, + edge_weight_view, + clustering_v.data(), + *max_level); } else { - std::tie(level, modularity) = - cugraph::louvain(handle, graph_view, edge_weight_view, clustering_v.data()); + std::tie(level, modularity) = cugraph::louvain( + handle, + std::optional>{std::nullopt}, + graph_view, + edge_weight_view, + clustering_v.data()); } RAFT_CUDA_TRY(cudaDeviceSynchronize()); // for consistent performance measurement diff --git a/cpp/tests/community/mg_ecg_test.cpp b/cpp/tests/community/mg_ecg_test.cpp new file mode 100644 index 00000000000..81cee1370f0 --- /dev/null +++ b/cpp/tests/community/mg_ecg_test.cpp @@ -0,0 +1,233 @@ +/* + * Copyright (c) 2023-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include + +#include +#include +#include + +#include +#include +#include +#include + +#include +#include +#include + +#include +#include + +//////////////////////////////////////////////////////////////////////////////// +// Test param object. This defines the input and expected output for a test, and +// will be instantiated as the parameter to the tests defined below using +// INSTANTIATE_TEST_SUITE_P() +// +struct Ecg_Usecase { + double min_weight_{0.1}; + size_t ensemble_size_{10}; + size_t max_level_{100}; + double threshold_{1e-7}; + double resolution_{1.0}; + bool check_correctness_{true}; +}; + +//////////////////////////////////////////////////////////////////////////////// +// Parameterized test fixture, to be used with TEST_P(). This defines common +// setup and teardown steps as well as common utilities used by each E2E MG +// test. In this case, each test is identical except for the inputs and +// expected outputs, so the entire test is defined in the run_test() method. 
+// +template +class Tests_MGEcg : public ::testing::TestWithParam> { + public: + static void SetUpTestCase() { handle_ = cugraph::test::initialize_mg_handle(); } + + static void TearDownTestCase() { handle_.reset(); } + + // Run once for each test instance + virtual void SetUp() {} + virtual void TearDown() {} + + template + void run_current_test(std::tuple const& param) + { + auto [ecg_usecase, input_usecase] = param; + + HighResTimer hr_timer{}; + + if (cugraph::test::g_perf) { + RAFT_CUDA_TRY(cudaDeviceSynchronize()); // for consistent performance measurement + handle_->get_comms().barrier(); + hr_timer.start("MG Construct graph"); + } + + auto [mg_graph, mg_edge_weights, d_renumber_map_labels] = + cugraph::test::construct_graph( + *handle_, input_usecase, true, true); + + if (cugraph::test::g_perf) { + RAFT_CUDA_TRY(cudaDeviceSynchronize()); // for consistent performance measurement + handle_->get_comms().barrier(); + hr_timer.stop(); + hr_timer.display_and_clear(std::cout); + } + + auto mg_graph_view = mg_graph.view(); + auto mg_edge_weight_view = + mg_edge_weights ? std::make_optional((*mg_edge_weights).view()) : std::nullopt; + + if (cugraph::test::g_perf) { + RAFT_CUDA_TRY(cudaDeviceSynchronize()); // for consistent performance measurement + handle_->get_comms().barrier(); + hr_timer.start("MG ECG"); + } + + unsigned seed = std::chrono::system_clock::now().time_since_epoch().count(); + raft::random::RngState rng_state(seed); + + cugraph::ecg(*handle_, + rng_state, + mg_graph_view, + mg_edge_weight_view, + ecg_usecase.min_weight_, + ecg_usecase.ensemble_size_, + ecg_usecase.max_level_, + ecg_usecase.threshold_, + ecg_usecase.resolution_); + + if (cugraph::test::g_perf) { + RAFT_CUDA_TRY(cudaDeviceSynchronize()); // for consistent performance measurement + handle_->get_comms().barrier(); + hr_timer.stop(); + hr_timer.display_and_clear(std::cout); + } + // Louvain and detail::permute_range are both tested, here we only make + // sure that SG and MG ECG calls work expected. + + cugraph::graph_t sg_graph(*handle_); + std::optional< + cugraph::edge_property_t, weight_t>> + sg_edge_weights{std::nullopt}; + std::tie(sg_graph, sg_edge_weights, std::ignore) = cugraph::test::mg_graph_to_sg_graph( + *handle_, + mg_graph_view, + mg_edge_weight_view, + std::optional>{std::nullopt}, + false); // crate a SG graph with MG graph vertex IDs + + auto const comm_rank = handle_->get_comms().get_rank(); + if (comm_rank == 0) { + auto sg_graph_view = sg_graph.view(); + auto sg_edge_weight_view = + sg_edge_weights ? 
std::make_optional((*sg_edge_weights).view()) : std::nullopt; + + cugraph::ecg(*handle_, + rng_state, + sg_graph_view, + sg_edge_weight_view, + ecg_usecase.min_weight_, + ecg_usecase.ensemble_size_, + ecg_usecase.max_level_, + ecg_usecase.threshold_, + ecg_usecase.resolution_); + } + } + + private: + static std::unique_ptr handle_; +}; + +template +std::unique_ptr Tests_MGEcg::handle_ = nullptr; + +using Tests_MGEcg_File = Tests_MGEcg; +using Tests_MGEcg_Rmat = Tests_MGEcg; + +TEST_P(Tests_MGEcg_File, CheckInt32Int32Float) +{ + run_current_test( + override_File_Usecase_with_cmd_line_arguments(GetParam())); +} + +TEST_P(Tests_MGEcg_File, CheckInt64Int64Float) +{ + run_current_test( + override_File_Usecase_with_cmd_line_arguments(GetParam())); +} + +TEST_P(Tests_MGEcg_Rmat, CheckInt32Int32Float) +{ + run_current_test( + override_Rmat_Usecase_with_cmd_line_arguments(GetParam())); +} + +TEST_P(Tests_MGEcg_Rmat, CheckInt32Int64Float) +{ + run_current_test( + override_Rmat_Usecase_with_cmd_line_arguments(GetParam())); +} + +TEST_P(Tests_MGEcg_Rmat, CheckInt64Int64Float) +{ + run_current_test( + override_Rmat_Usecase_with_cmd_line_arguments(GetParam())); +} + +INSTANTIATE_TEST_SUITE_P( + file_tests, + Tests_MGEcg_File, + ::testing::Combine( + // enable correctness checks for small graphs + ::testing::Values(Ecg_Usecase{0.1, 10, 100, 1e-7, 1.0, true}), + ::testing::Values(cugraph::test::File_Usecase("test/datasets/karate.mtx")))); + +INSTANTIATE_TEST_SUITE_P( + rmat_small_tests, + Tests_MGEcg_Rmat, + ::testing::Combine( + ::testing::Values(Ecg_Usecase{0.1, 10, 100, 1e-7, 1.0, true}), + ::testing::Values(cugraph::test::Rmat_Usecase(10, 16, 0.57, 0.19, 0.19, 0, true, false)))); + +INSTANTIATE_TEST_SUITE_P( + file_benchmark_test, /* note that the test filename can be overridden in benchmarking (with + --gtest_filter to select only the file_benchmark_test with a specific + vertex & edge type combination) by command line arguments and do not + include more than one File_Usecase that differ only in filename + (to avoid running same benchmarks more than once) */ + Tests_MGEcg_File, + ::testing::Combine( + // disable correctness checks for large graphs + ::testing::Values(Ecg_Usecase{0.1, 10, 100, 1e-7, 1.0, true}), + ::testing::Values(cugraph::test::File_Usecase("test/datasets/karate.mtx")))); + +INSTANTIATE_TEST_SUITE_P( + rmat_benchmark_test, /* note that scale & edge factor can be overridden in benchmarking (with + --gtest_filter to select only the rmat_benchmark_test with a specific + vertex & edge type combination) by command line arguments and do not + include more than one Rmat_Usecase that differ only in scale or edge + factor (to avoid running same benchmarks more than once) */ + Tests_MGEcg_Rmat, + ::testing::Combine( + // disable correctness checks for large graphs + ::testing::Values(Ecg_Usecase{0.1, 10, 100, 1e-7, 1.0, true}), + ::testing::Values(cugraph::test::Rmat_Usecase(12, 32, 0.57, 0.19, 0.19, 0, true, false)))); + +CUGRAPH_MG_TEST_PROGRAM_MAIN() diff --git a/cpp/tests/community/mg_louvain_test.cpp b/cpp/tests/community/mg_louvain_test.cpp index 41339e32d77..011426606fd 100644 --- a/cpp/tests/community/mg_louvain_test.cpp +++ b/cpp/tests/community/mg_louvain_test.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * Copyright (c) 2021-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -126,13 +126,15 @@ class Tests_MGLouvain rmm::device_uvector d_sg_cluster_v(sg_graph_view.number_of_vertices(), handle_->get_stream()); - std::tie(std::ignore, sg_modularity) = cugraph::louvain(handle, - sg_graph_view, - sg_edge_weight_view, - d_sg_cluster_v.data(), - size_t{1}, - threshold, - resolution); + std::tie(std::ignore, sg_modularity) = cugraph::louvain( + handle, + std::optional>{std::nullopt}, + sg_graph_view, + sg_edge_weight_view, + d_sg_cluster_v.data(), + size_t{1}, + threshold, + resolution); EXPECT_TRUE(cugraph::test::check_invertible( handle, @@ -191,6 +193,7 @@ class Tests_MGLouvain auto [dendrogram, mg_modularity] = cugraph::louvain( *handle_, + std::optional>{std::nullopt}, mg_graph_view, mg_edge_weight_view, louvain_usecase.max_level_, From aa66a3253466f6a96c876b63eb79941a9825d3ee Mon Sep 17 00:00:00 2001 From: Kyle Edwards Date: Fri, 12 Jan 2024 16:54:33 -0500 Subject: [PATCH 4/4] Remove usages of rapids-env-update (#4090) Reference: https://github.com/rapidsai/ops/issues/2766 Replace rapids-env-update with rapids-configure-conda-channels, rapids-configure-sccache, and rapids-date-string. Authors: - Kyle Edwards (https://github.com/KyleFromNVIDIA) - Vyas Ramasubramani (https://github.com/vyasr) Approvers: - AJ Schmidt (https://github.com/ajschmidt8) URL: https://github.com/rapidsai/cugraph/pull/4090 --- ci/build_cpp.sh | 8 ++++++-- ci/build_python.sh | 8 ++++++-- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/ci/build_cpp.sh b/ci/build_cpp.sh index d0d13f99448..132231e4a64 100755 --- a/ci/build_cpp.sh +++ b/ci/build_cpp.sh @@ -1,9 +1,13 @@ #!/bin/bash -# Copyright (c) 2022-2023, NVIDIA CORPORATION. +# Copyright (c) 2022-2024, NVIDIA CORPORATION. set -euo pipefail -source rapids-env-update +rapids-configure-conda-channels + +source rapids-configure-sccache + +source rapids-date-string export CMAKE_GENERATOR=Ninja diff --git a/ci/build_python.sh b/ci/build_python.sh index 31982648884..a99e5ce63e8 100755 --- a/ci/build_python.sh +++ b/ci/build_python.sh @@ -1,9 +1,13 @@ #!/bin/bash -# Copyright (c) 2022-2023, NVIDIA CORPORATION. +# Copyright (c) 2022-2024, NVIDIA CORPORATION. set -euo pipefail -source rapids-env-update +rapids-configure-conda-channels + +source rapids-configure-sccache + +source rapids-date-string export CMAKE_GENERATOR=Ninja