Sampling Performance Testing (#3584)
Adds performance benchmarking scripts for testing multi-node, multi-GPU (MNMG) cuGraph GNN workflows.
This branch is the head branch for the cuGraph benchmarking effort.  All work supporting the benchmarks should be merged into this branch.  It will be merged into branch-24.02 once all features are ready.

Includes patches to cuGraph-PyG required for the latest DLFW container.

To-Do:

- [x] Refactor for branch-24.02
- [x] <s>Add WholeGraph training portion</s> Deferred to future PR (see alexbarghi-nv#6)
- [x] <s>Add WholeGraph generators</s> Included in above
- [x] <s>Support DGL</s> Deferred to future PR
- [x] <s>Use appropriate docker containers</s> Deferred, waiting on DLFW release

Closes #3839

Authors:
  - Alex Barghi (https://github.com/alexbarghi-nv)
  - Seunghwa Kang (https://github.com/seunghwak)

Approvers:
  - Vibhu Jawa (https://github.com/VibhuJawa)
  - Rick Ratzel (https://github.com/rlratzel)
  - Chuck Hastings (https://github.com/ChuckHastings)

URL: #3584
alexbarghi-nv authored Jan 12, 2024
1 parent 88c3884 commit c09db10
Showing 25 changed files with 2,053 additions and 188 deletions.
1 change: 1 addition & 0 deletions benchmarks/cugraph/standalone/bulk_sampling/.gitignore
@@ -0,0 +1 @@
mg_utils/
70 changes: 58 additions & 12 deletions benchmarks/cugraph/standalone/bulk_sampling/README.md
@@ -1,11 +1,13 @@
- # cuGraph Bulk Sampling
+ # cuGraph Sampling Benchmarks

- ## Overview
+ ## cuGraph Bulk Sampling
+
+ ### Overview
The `cugraph_bulk_sampling.py` script runs the bulk sampler for a variety of datasets, including
both generated (rmat) and disk-based (ogbn_papers100M, etc.) datasets. It can also load
replicas of these datasets to create a larger benchmark (e.g. ogbn_papers100M x2).
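
One common way to build such a replica is to offset each copy's vertex IDs by the original vertex count so the copies form one larger graph. A minimal sketch of that idea (the `src`/`dst` column names and the offset scheme are assumptions; the script's actual implementation may differ):

```python
# Hedged sketch: an N-times-larger edge list built from offset copies.
import cudf

def replicate_edgelist(edges: cudf.DataFrame, num_nodes: int, factor: int) -> cudf.DataFrame:
    replicas = [
        edges.assign(
            src=edges["src"] + r * num_nodes,  # shift vertex IDs per replica
            dst=edges["dst"] + r * num_nodes,
        )
        for r in range(factor)
    ]
    return cudf.concat(replicas, ignore_index=True)
```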

- ## Arguments
+ ### Arguments
The script takes a variety of arguments to control sampling behavior.
Required:
--output_root
@@ -51,14 +53,8 @@ Optional:
Seed for random number generation.
Defaults to '62'

- --persist
- Whether to aggressively use persist() in dask to make the ETL steps (NOT PART OF SAMPLING) faster.
- Will probably make this script finish sooner at the expense of memory usage, but won't affect
- sampling time.
- Changing this is not recommended unless you know what you are doing.
- Defaults to False.

- ## Input Format
+ ### Input Format
The script expects its input data in the following format:
```
<top level directory>
@@ -103,14 +99,64 @@ the parquet files. It must have the following format:
}
```

- ## Output Meta
+ ### Output Meta
In addition to the samples, the script will output a file named `output_meta.json`.
This file contains various statistics about the sampling run, including the runtime,
as well as information about the dataset and system that the samples were produced from.

This metadata file can be used to gather the results from the sampling and training stages
together.
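
For example, a minimal sketch of loading this metadata for that purpose (the path layout is an assumption, and the exact keys are not listed here, so this just iterates over whatever is present):

```python
# Hedged sketch: read the sampling run's output_meta.json.
import json
from pathlib import Path

def load_sampling_meta(output_root: str) -> dict:
    with open(Path(output_root) / "output_meta.json") as f:
        return json.load(f)

meta = load_sampling_meta("/path/to/samples")  # hypothetical output root
for key, value in meta.items():
    print(f"{key}: {value}")
```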

- ## Other Notes
+ ### Other Notes
For rmat datasets, you will need to generate your own bogus features in the training stage.
Since generating them is trivial, this sampling script does not do it.
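
A minimal sketch of what that might look like in the training stage (the feature dimension and dtype are assumptions):

```python
# Hedged sketch: random ("bogus") node features for a generated rmat graph.
import torch

num_nodes = 1_000_000  # vertex count of the rmat graph (placeholder)
feature_dim = 128      # assumed feature dimension

features = torch.rand((num_nodes, feature_dim), dtype=torch.float32)
```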

## cuGraph MNMG Training

### Overview
The script `run_train_job.sh` is run with the `sbatch` command to launch a series of SLURM jobs.
First, the script produces samples for a given graph for a given number of epochs. Then the
training process starts: the samples are loaded and training iterations are run.

### Important Notes
Downloading the dataset files before running the SLURM jobs is highly recommended. Even though
the script will attempt to download the files if they are not available, this can often lead
to a timeout that crashes the script. This applies regardless of whether you are training with
native PyG or cuGraph-PyG. You can download the data as follows:

```python
from ogb.nodeproppred import NodePropPredDataset
dataset = NodePropPredDataset('ogbn-papers100M', root='/home/username/datasets')
```

For datasets other than ogbn-papers100M, follow the same process, changing only the dataset name.
The dataset will be preprocessed automatically when you run training. If you have a slow system,
you can instead run preprocessing ahead of time by running the training script on a single worker,
which avoids the timeout that would otherwise crash the script.

The multi-GPU utilities are in `mg_utils` at the top level of the cuGraph repository. You should
either copy them into this directory or symlink to them before running the scripts.

### Arguments
You will need to modify the bash scripts to run appropriately for your environment and
desired training workflow. The standard sbatch arguments, such as job name and queue, are at the
top of the script; these will need to be modified for your SLURM cluster.

Next are the arguments for the container image (required) and the directories where the data
and outputs are stored. The directories default to subdirectories of the current working
directory, but if a high-throughput storage system is available, using that storage for the
samples and datasets is highly recommended.

Next are standard GNN training arguments such as `FANOUT`, `BATCH_SIZE`, etc. You can also set
the number of training epochs here. These are followed by the `REPLICATION_FACTOR` argument,
which can be used to create replicas of the dataset for scale testing purposes.
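
For reference, the training script turns the `FANOUT` string into a per-hop neighbor-count list by splitting on underscores:

```python
# How bench_cugraph_training.py interprets the fanout string:
# "10_10_10" -> [10, 10, 10] (neighbors sampled at each hop).
fanout = [int(f) for f in "10_10_10".split("_")]
```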

The final two arguments are `FRAMEWORK`, which can be either "cuGraphPyG" or "PyG", and
`GPUS_PER_NODE`, which must be set to the correct value even if it is also provided by a SLURM
argument. If `GPUS_PER_NODE` does not match the actual number of GPUs, the script will hang
until the job times out; the sketch below shows why. Mismatched GPU counts across nodes are
currently unsupported by this script, though they should be possible in practice.
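
The hang occurs because the training script computes its expected world size from `GPUS_PER_NODE` rather than detecting it, and the NCCL collectives then wait for that many ranks to join:

```python
# From bench_cugraph_training.py: world size is computed, not detected.
# If fewer ranks than this ever join the process group, dist.barrier()
# blocks until the SLURM job times out.
import os

gpus_per_node = 8  # must match the GPUs actually present on each node
world_size = int(os.environ["SLURM_JOB_NUM_NODES"]) * gpus_per_node
```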

### Output
The results of training will be written to the logs directory, with an `output.txt` file for
each worker. These files are overwritten upon each run. Accuracy is only reported on rank 0.
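
In addition to these logs, `bench_cugraph_training.py` writes per-rank JSON statistics by appending `[<rank>]` to its `--output_file` path. A minimal sketch for gathering them afterwards (the example path and world size are placeholders):

```python
# Hedged sketch: collect the per-rank JSON stats files written as
# "<output_file>[<rank>]" by bench_cugraph_training.py.
import json

def gather_rank_stats(output_file: str, world_size: int) -> list:
    stats = []
    for rank in range(world_size):
        with open(f"{output_file}[{rank}]") as f:
            stats.append(json.load(f))
    return stats

all_stats = gather_rank_stats("/path/to/results.json", world_size=16)
```
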
251 changes: 251 additions & 0 deletions benchmarks/cugraph/standalone/bulk_sampling/bench_cugraph_training.py
@@ -0,0 +1,251 @@
# Copyright (c) 2023-2024, NVIDIA CORPORATION.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

os.environ["RAPIDS_NO_INITIALIZE"] = "1"
os.environ["CUDF_SPILL"] = "1"
os.environ["LIBCUDF_CUFILE_POLICY"] = "KVIKIO"
os.environ["KVIKIO_NTHREADS"] = "8"

import argparse
import json
import warnings

import torch
import numpy as np
import pandas

import torch.distributed as dist

from datasets import OGBNPapers100MDataset

from cugraph.testing.mg_utils import enable_spilling


def init_pytorch_worker(rank: int, use_rmm_torch_allocator: bool = False) -> None:
    import cupy
    import rmm
    from pynvml.smi import nvidia_smi

    smi = nvidia_smi.getInstance()
    pool_size = 16e9  # FIXME calculate this

    rmm.reinitialize(
        devices=[rank],
        pool_allocator=True,
        initial_pool_size=pool_size,
    )

    if use_rmm_torch_allocator:
        warnings.warn(
            "Using the rmm pytorch allocator is currently unsupported."
            " The default allocator will be used instead."
        )
        # FIXME somehow get the pytorch allocator to work
        # from rmm.allocators.torch import rmm_torch_allocator
        # torch.cuda.memory.change_current_allocator(rmm_torch_allocator)

    from rmm.allocators.cupy import rmm_cupy_allocator

    cupy.cuda.set_allocator(rmm_cupy_allocator)

    cupy.cuda.Device(rank).use()
    torch.cuda.set_device(rank)

    # Pytorch training worker initialization
    torch.distributed.init_process_group(backend="nccl")


def parse_args():
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--gpus_per_node",
        type=int,
        default=8,
        help="# GPUs per node",
        required=False,
    )

    parser.add_argument(
        "--num_epochs",
        type=int,
        default=1,
        help="Number of training epochs",
        required=False,
    )

    parser.add_argument(
        "--batch_size",
        type=int,
        default=512,
        help="Batch size",
        required=False,
    )

    parser.add_argument(
        "--fanout",
        type=str,
        default="10_10_10",
        help="Fanout",
        required=False,
    )

    parser.add_argument(
        "--sample_dir",
        type=str,
        help="Directory with stored bulk samples (required for cuGraph run)",
        required=False,
    )

    parser.add_argument(
        "--output_file",
        type=str,
        help="File to store results",
        required=True,
    )

    parser.add_argument(
        "--framework",
        type=str,
        help="The framework to test (PyG, cuGraphPyG)",
        required=True,
    )

    parser.add_argument(
        "--model",
        type=str,
        default="GraphSAGE",
        help="The model to use (currently only GraphSAGE supported)",
        required=False,
    )

    parser.add_argument(
        "--replication_factor",
        type=int,
        default=1,
        help="The replication factor for the dataset",
        required=False,
    )

    parser.add_argument(
        "--dataset_dir",
        type=str,
        help="The directory where datasets are stored",
        required=True,
    )

    parser.add_argument(
        "--train_split",
        type=float,
        help="The percentage of the labeled data to use for training. The remainder is used for testing/validation.",
        default=0.8,
        required=False,
    )

    parser.add_argument(
        "--val_split",
        type=float,
        help="The percentage of the testing/validation data to allocate for validation.",
        default=0.5,
        required=False,
    )

    return parser.parse_args()


def main(args):
    import logging

    logging.basicConfig(
        level=logging.INFO,
    )
    logger = logging.getLogger("bench_cugraph_training")
    logger.setLevel(logging.INFO)

    local_rank = int(os.environ["LOCAL_RANK"])
    global_rank = int(os.environ["RANK"])

    # NOTE: neither supported framework value ("PyG", "cuGraphPyG") equals
    # "cuGraph", so the rmm torch allocator is effectively never requested.
    init_pytorch_worker(
        local_rank, use_rmm_torch_allocator=(args.framework == "cuGraph")
    )
    enable_spilling()
    print("Worker initialized")
    dist.barrier()

    world_size = int(os.environ["SLURM_JOB_NUM_NODES"]) * args.gpus_per_node

    dataset = OGBNPapers100MDataset(
        replication_factor=args.replication_factor,
        dataset_dir=args.dataset_dir,
        train_split=args.train_split,
        val_split=args.val_split,
        load_edge_index=(args.framework == "PyG"),
    )

    # Only rank 0 downloads the dataset; the other ranks wait at the barrier.
    if global_rank == 0:
        dataset.download()
    dist.barrier()

    fanout = [int(f) for f in args.fanout.split("_")]

    if args.framework == "PyG":
        from trainers.pyg import PyGNativeTrainer

        trainer = PyGNativeTrainer(
            model=args.model,
            dataset=dataset,
            device=local_rank,
            rank=global_rank,
            world_size=world_size,
            num_epochs=args.num_epochs,
            shuffle=True,
            replace=False,
            num_neighbors=fanout,
            batch_size=args.batch_size,
        )
    elif args.framework == "cuGraphPyG":
        sample_dir = os.path.join(
            args.sample_dir,
            f"ogbn_papers100M[{args.replication_factor}]_b{args.batch_size}_f{fanout}",
        )
        from trainers.pyg import PyGCuGraphTrainer

        trainer = PyGCuGraphTrainer(
            model=args.model,
            dataset=dataset,
            sample_dir=sample_dir,
            device=local_rank,
            rank=global_rank,
            world_size=world_size,
            num_epochs=args.num_epochs,
            shuffle=True,
            replace=False,
            num_neighbors=fanout,
            batch_size=args.batch_size,
        )
    else:
        raise ValueError("unsupported framework")

    logger.info(f"Trainer ready on rank {global_rank}")
    stats = trainer.train()
    logger.info(stats)

    # Each rank writes its own stats file, suffixed with "[<rank>]".
    with open(f"{args.output_file}[{global_rank}]", "w") as f:
        json.dump(stats, f)


if __name__ == "__main__":
    args = parse_args()
    main(args)