From 5c7cb2bd789326676f63cf958a3595fb1593d44a Mon Sep 17 00:00:00 2001 From: Alex Barghi <105237337+alexbarghi-nv@users.noreply.github.com> Date: Mon, 15 Apr 2024 18:27:07 -0400 Subject: [PATCH] [FEA] cuGraph GNN NCCL-only Setup and Distributed Sampling (#4278) * Adds the ability to run `pylibcugraph` without UCX/dask within PyTorch DDP. * Adds the new distributed sampler which uses the new nccl+ddp path to perform bulk sampling. Closes #4200 Closes #4201 Closes #4246 Closes #3851 Authors: - Alex Barghi (https://github.com/alexbarghi-nv) Approvers: - Seunghwa Kang (https://github.com/seunghwak) - Rick Ratzel (https://github.com/rlratzel) - Chuck Hastings (https://github.com/ChuckHastings) - Jake Awe (https://github.com/AyodeAwe) - Joseph Nke (https://github.com/jnke2016) URL: https://github.com/rapidsai/cugraph/pull/4278 --- ci/run_cugraph_pyg_pytests.sh | 2 +- ci/test_wheel_cugraph-pyg.sh | 2 +- .../examples/cugraph_dist_sampling_mg.py | 112 ++++ .../examples/cugraph_dist_sampling_sg.py | 80 +++ .../cugraph_pyg/examples/pylibcugraph_mg.py | 100 +++ .../cugraph_pyg/examples/pylibcugraph_sg.py | 66 ++ python/cugraph/cugraph/gnn/__init__.py | 13 +- python/cugraph/cugraph/gnn/comms/__init__.py | 19 + .../cugraph/gnn/comms/cugraph_nccl_comms.py | 92 +++ .../cugraph/gnn/data_loading/__init__.py | 7 +- .../gnn/data_loading/bulk_sampler_io.py | 9 +- .../cugraph/gnn/data_loading/dist_sampler.py | 577 ++++++++++++++++++ .../tests/sampling/test_dist_sampler.py | 94 +++ .../tests/sampling/test_dist_sampler_mg.py | 230 +++++++ 14 files changed, 1398 insertions(+), 5 deletions(-) create mode 100644 python/cugraph-pyg/cugraph_pyg/examples/cugraph_dist_sampling_mg.py create mode 100644 python/cugraph-pyg/cugraph_pyg/examples/cugraph_dist_sampling_sg.py create mode 100644 python/cugraph-pyg/cugraph_pyg/examples/pylibcugraph_mg.py create mode 100644 python/cugraph-pyg/cugraph_pyg/examples/pylibcugraph_sg.py create mode 100644 python/cugraph/cugraph/gnn/comms/__init__.py create mode 100644 python/cugraph/cugraph/gnn/comms/cugraph_nccl_comms.py create mode 100644 python/cugraph/cugraph/gnn/data_loading/dist_sampler.py create mode 100644 python/cugraph/cugraph/tests/sampling/test_dist_sampler.py create mode 100644 python/cugraph/cugraph/tests/sampling/test_dist_sampler_mg.py diff --git a/ci/run_cugraph_pyg_pytests.sh b/ci/run_cugraph_pyg_pytests.sh index 0acc8aa462a..88642e6ceb6 100755 --- a/ci/run_cugraph_pyg_pytests.sh +++ b/ci/run_cugraph_pyg_pytests.sh @@ -11,5 +11,5 @@ pytest --cache-clear --ignore=tests/mg "$@" . # Test examples for e in "$(pwd)"/examples/*.py; do rapids-logger "running example $e" - python $e + (yes || true) | python $e done diff --git a/ci/test_wheel_cugraph-pyg.sh b/ci/test_wheel_cugraph-pyg.sh index ad615f0b3ff..e98bf4ab56b 100755 --- a/ci/test_wheel_cugraph-pyg.sh +++ b/ci/test_wheel_cugraph-pyg.sh @@ -52,6 +52,6 @@ python -m pytest \ # Test examples for e in "$(pwd)"/examples/*.py; do rapids-logger "running example $e" - python $e + (yes || true) | python $e done popd diff --git a/python/cugraph-pyg/cugraph_pyg/examples/cugraph_dist_sampling_mg.py b/python/cugraph-pyg/cugraph_pyg/examples/cugraph_dist_sampling_mg.py new file mode 100644 index 00000000000..29a6cc2b464 --- /dev/null +++ b/python/cugraph-pyg/cugraph_pyg/examples/cugraph_dist_sampling_mg.py @@ -0,0 +1,112 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This example shows how to use cuGraph nccl-only comms, pylibcuGraph, +# and PyTorch DDP to run a multi-GPU sampling workflow. Most users of the +# GNN packages will not interact with cuGraph directly. This example +# is intented for users who want to extend cuGraph within a DDP workflow. + +import os +import re +import tempfile + +import numpy as np +import torch +import torch.multiprocessing as tmp +import torch.distributed as dist + +import cudf + +from cugraph.gnn import ( + cugraph_comms_init, + cugraph_comms_shutdown, + cugraph_comms_create_unique_id, + cugraph_comms_get_raft_handle, + DistSampleWriter, + UniformNeighborSampler, +) + +from pylibcugraph import MGGraph, ResourceHandle, GraphProperties + +from ogb.nodeproppred import NodePropPredDataset + + +def init_pytorch(rank, world_size): + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = "12355" + dist.init_process_group("nccl", rank=rank, world_size=world_size) + + +def sample(rank: int, world_size: int, uid, edgelist, directory): + init_pytorch(rank, world_size) + + device = rank + cugraph_comms_init(rank, world_size, uid, device) + + print(f"rank {rank} initialized cugraph") + + src = cudf.Series(np.array_split(edgelist[0], world_size)[rank]) + dst = cudf.Series(np.array_split(edgelist[1], world_size)[rank]) + + seeds_per_rank = 50 + seeds = cudf.Series(np.arange(rank * seeds_per_rank, (rank + 1) * seeds_per_rank)) + handle = ResourceHandle(cugraph_comms_get_raft_handle().getHandle()) + + print("constructing graph") + G = MGGraph( + handle, + GraphProperties(is_multigraph=True, is_symmetric=False), + [src], + [dst], + ) + print("graph constructed") + + sample_writer = DistSampleWriter(directory=directory, batches_per_partition=2) + sampler = UniformNeighborSampler( + G, + sample_writer, + fanout=[5, 5], + ) + + sampler.sample_from_nodes(seeds, batch_size=16, random_state=62) + + dist.barrier() + cugraph_comms_shutdown() + print(f"rank {rank} shut down cugraph") + + +def main(): + world_size = torch.cuda.device_count() + uid = cugraph_comms_create_unique_id() + + dataset = NodePropPredDataset("ogbn-products") + el = dataset[0][0]["edge_index"].astype("int64") + + with tempfile.TemporaryDirectory() as directory: + tmp.spawn( + sample, + args=(world_size, uid, el, "."), + nprocs=world_size, + ) + + print("Printing samples...") + for file in os.listdir(directory): + m = re.match(r"batch=([0-9]+)\.([0-9]+)\-([0-9]+)\.([0-9]+)\.parquet", file) + rank, start, _, end = int(m[1]), int(m[2]), int(m[3]), int(m[4]) + print(f"File: {file} (batches {start} to {end} for rank {rank})") + print(cudf.read_parquet(os.path.join(directory, file))) + print("\n") + + +if __name__ == "__main__": + main() diff --git a/python/cugraph-pyg/cugraph_pyg/examples/cugraph_dist_sampling_sg.py b/python/cugraph-pyg/cugraph_pyg/examples/cugraph_dist_sampling_sg.py new file mode 100644 index 00000000000..8366ff44233 --- /dev/null +++ b/python/cugraph-pyg/cugraph_pyg/examples/cugraph_dist_sampling_sg.py @@ -0,0 +1,80 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This example shows how to use cuGraph nccl-only comms, pylibcuGraph, +# and PyTorch to run a single-GPU sampling workflow. Most users of the +# GNN packages will not interact with cuGraph directly. This example +# is intented for users who want to extend cuGraph within a PyTorch workflow. + +import os +import re +import tempfile + +import numpy as np + +import cudf + +from cugraph.gnn import ( + DistSampleWriter, + UniformNeighborSampler, +) + +from pylibcugraph import SGGraph, ResourceHandle, GraphProperties + +from ogb.nodeproppred import NodePropPredDataset + + +def sample(edgelist, directory): + src = cudf.Series(edgelist[0]) + dst = cudf.Series(edgelist[1]) + + seeds_per_rank = 50 + seeds = cudf.Series(np.arange(0, seeds_per_rank)) + + print("constructing graph") + G = SGGraph( + ResourceHandle(), + GraphProperties(is_multigraph=True, is_symmetric=False), + src, + dst, + ) + print("graph constructed") + + sample_writer = DistSampleWriter(directory=directory, batches_per_partition=2) + sampler = UniformNeighborSampler( + G, + sample_writer, + fanout=[5, 5], + ) + + sampler.sample_from_nodes(seeds, batch_size=16, random_state=62) + + +def main(): + dataset = NodePropPredDataset("ogbn-products") + el = dataset[0][0]["edge_index"].astype("int64") + + with tempfile.TemporaryDirectory() as directory: + sample(el, directory) + + print("Printing samples...") + for file in os.listdir(directory): + m = re.match(r"batch=([0-9]+)\.([0-9]+)\-([0-9]+)\.([0-9]+)\.parquet", file) + rank, start, _, end = int(m[1]), int(m[2]), int(m[3]), int(m[4]) + print(f"File: {file} (batches {start} to {end} for rank {rank})") + print(cudf.read_parquet(os.path.join(directory, file))) + print("\n") + + +if __name__ == "__main__": + main() diff --git a/python/cugraph-pyg/cugraph_pyg/examples/pylibcugraph_mg.py b/python/cugraph-pyg/cugraph_pyg/examples/pylibcugraph_mg.py new file mode 100644 index 00000000000..832c5ec74f0 --- /dev/null +++ b/python/cugraph-pyg/cugraph_pyg/examples/pylibcugraph_mg.py @@ -0,0 +1,100 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This example shows how to use cuGraph nccl-only comms, pylibcuGraph, +# and PyTorch DDP to run a multi-GPU workflow. Most users of the +# GNN packages will not interact with cuGraph directly. This example +# is intented for users who want to extend cuGraph within a DDP workflow. + +import os + +import pandas +import numpy as np +import torch +import torch.multiprocessing as tmp +import torch.distributed as dist + +import cudf + +from cugraph.gnn import ( + cugraph_comms_init, + cugraph_comms_shutdown, + cugraph_comms_create_unique_id, + cugraph_comms_get_raft_handle, +) + +from pylibcugraph import MGGraph, ResourceHandle, GraphProperties, degrees + +from ogb.nodeproppred import NodePropPredDataset + + +def init_pytorch(rank, world_size): + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = "12355" + dist.init_process_group("nccl", rank=rank, world_size=world_size) + + +def calc_degree(rank: int, world_size: int, uid, edgelist): + init_pytorch(rank, world_size) + + device = rank + cugraph_comms_init(rank, world_size, uid, device) + + print(f"rank {rank} initialized cugraph") + + src = cudf.Series(np.array_split(edgelist[0], world_size)[rank]) + dst = cudf.Series(np.array_split(edgelist[1], world_size)[rank]) + + seeds = cudf.Series(np.arange(rank * 50, (rank + 1) * 50)) + handle = ResourceHandle(cugraph_comms_get_raft_handle().getHandle()) + + print("constructing graph") + G = MGGraph( + handle, + GraphProperties(is_multigraph=True, is_symmetric=False), + [src], + [dst], + ) + print("graph constructed") + + print("calculating degrees") + vertices, in_deg, out_deg = degrees(handle, G, seeds, do_expensive_check=False) + print("degrees calculated") + + print("constructing dataframe") + df = pandas.DataFrame( + {"v": vertices.get(), "in": in_deg.get(), "out": out_deg.get()} + ) + print(df) + + dist.barrier() + cugraph_comms_shutdown() + print(f"rank {rank} shut down cugraph") + + +def main(): + world_size = torch.cuda.device_count() + uid = cugraph_comms_create_unique_id() + + dataset = NodePropPredDataset("ogbn-products") + el = dataset[0][0]["edge_index"].astype("int64") + + tmp.spawn( + calc_degree, + args=(world_size, uid, el), + nprocs=world_size, + ) + + +if __name__ == "__main__": + main() diff --git a/python/cugraph-pyg/cugraph_pyg/examples/pylibcugraph_sg.py b/python/cugraph-pyg/cugraph_pyg/examples/pylibcugraph_sg.py new file mode 100644 index 00000000000..2f273ee581e --- /dev/null +++ b/python/cugraph-pyg/cugraph_pyg/examples/pylibcugraph_sg.py @@ -0,0 +1,66 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This example shows how to use cuGraph and pylibcuGraph to run a +# single-GPU workflow. Most users of the GNN packages will not interact +# with cuGraph directly. This example is intented for users who want +# to extend cuGraph within a PyTorch workflow. + +import pandas +import numpy as np + +import cudf + +from pylibcugraph import SGGraph, ResourceHandle, GraphProperties, degrees + +from ogb.nodeproppred import NodePropPredDataset + + +def calc_degree(edgelist): + src = cudf.Series(edgelist[0]) + dst = cudf.Series(edgelist[1]) + + seeds = cudf.Series(np.arange(256)) + + print("constructing graph") + G = SGGraph( + ResourceHandle(), + GraphProperties(is_multigraph=True, is_symmetric=False), + src, + dst, + ) + print("graph constructed") + + print("calculating degrees") + vertices, in_deg, out_deg = degrees( + ResourceHandle(), G, seeds, do_expensive_check=False + ) + print("degrees calculated") + + print("constructing dataframe") + df = pandas.DataFrame( + {"v": vertices.get(), "in": in_deg.get(), "out": out_deg.get()} + ) + print(df) + + print("done") + + +def main(): + dataset = NodePropPredDataset("ogbn-products") + el = dataset[0][0]["edge_index"].astype("int64") + calc_degree(el) + + +if __name__ == "__main__": + main() diff --git a/python/cugraph/cugraph/gnn/__init__.py b/python/cugraph/cugraph/gnn/__init__.py index f8a3035440b..1f4d98f0230 100644 --- a/python/cugraph/cugraph/gnn/__init__.py +++ b/python/cugraph/cugraph/gnn/__init__.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2023, NVIDIA CORPORATION. +# Copyright (c) 2022-2024, NVIDIA CORPORATION. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -13,3 +13,14 @@ from .feature_storage.feat_storage import FeatureStore from .data_loading.bulk_sampler import BulkSampler +from .data_loading.dist_sampler import ( + DistSampler, + DistSampleWriter, + UniformNeighborSampler, +) +from .comms.cugraph_nccl_comms import ( + cugraph_comms_init, + cugraph_comms_shutdown, + cugraph_comms_create_unique_id, + cugraph_comms_get_raft_handle, +) diff --git a/python/cugraph/cugraph/gnn/comms/__init__.py b/python/cugraph/cugraph/gnn/comms/__init__.py new file mode 100644 index 00000000000..b842dd0927d --- /dev/null +++ b/python/cugraph/cugraph/gnn/comms/__init__.py @@ -0,0 +1,19 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .cugraph_nccl_comms import ( + cugraph_comms_init, + cugraph_comms_shutdown, + cugraph_comms_create_unique_id, + cugraph_comms_get_raft_handle, +) diff --git a/python/cugraph/cugraph/gnn/comms/cugraph_nccl_comms.py b/python/cugraph/cugraph/gnn/comms/cugraph_nccl_comms.py new file mode 100644 index 00000000000..fc27789621f --- /dev/null +++ b/python/cugraph/cugraph/gnn/comms/cugraph_nccl_comms.py @@ -0,0 +1,92 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math + +from raft_dask.common.nccl import nccl +from raft_dask.common.comms_utils import inject_comms_on_handle_coll_only + +from pylibraft.common.handle import Handle +from rmm._cuda.gpu import getDevice, setDevice + +from cugraph.dask.comms.comms_wrapper import init_subcomms + +__nccl_comms = None +__raft_handle = None + + +def nccl_init(rank: int, world_size: int, uid: int): + try: + ni = nccl() + ni.init(world_size, uid, rank) + return ni + except Exception as ex: + raise RuntimeError(f"A nccl error occurred: {ex}") + + +def make_raft_handle( + rank, world_size, nccl_comms, n_streams_per_handle=0, verbose=False +): + handle = Handle(n_streams=n_streams_per_handle) + inject_comms_on_handle_coll_only(handle, nccl_comms, world_size, rank, verbose) + + return handle + + +def __get_2D_div(ngpus): + prows = int(math.sqrt(ngpus)) + while ngpus % prows != 0: + prows = prows - 1 + return prows, int(ngpus / prows) + + +def cugraph_comms_init(rank, world_size, uid, device=0): + global __nccl_comms, __raft_handle + if __nccl_comms is not None or __raft_handle is not None: + raise RuntimeError("cuGraph has already been initialized!") + + # TODO add options for rmm initialization + + global __old_device + __old_device = getDevice() + setDevice(device) + + nccl_comms = nccl_init(rank, world_size, uid) + # FIXME should we use n_streams_per_handle=1 here? + raft_handle = make_raft_handle(rank, world_size, nccl_comms, verbose=True) + + pcols, _ = __get_2D_div(world_size) + init_subcomms(raft_handle, pcols) + + __nccl_comms = nccl_comms + __raft_handle = raft_handle + + +def cugraph_comms_shutdown(): + global __raft_handle, __nccl_comms, __old_device + + __nccl_comms.destroy() + setDevice(__old_device) + + del __raft_handle + del __nccl_comms + del __old_device + + +def cugraph_comms_create_unique_id(): + return nccl.get_unique_id() + + +def cugraph_comms_get_raft_handle(): + global __raft_handle + return __raft_handle diff --git a/python/cugraph/cugraph/gnn/data_loading/__init__.py b/python/cugraph/cugraph/gnn/data_loading/__init__.py index 4b725fba75a..a50f6085e9a 100644 --- a/python/cugraph/cugraph/gnn/data_loading/__init__.py +++ b/python/cugraph/cugraph/gnn/data_loading/__init__.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. +# Copyright (c) 2023-2024, NVIDIA CORPORATION. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -12,3 +12,8 @@ # limitations under the License. from cugraph.gnn.data_loading.bulk_sampler import BulkSampler +from cugraph.gnn.data_loading.dist_sampler import ( + DistSampler, + DistSampleWriter, + UniformNeighborSampler, +) diff --git a/python/cugraph/cugraph/gnn/data_loading/bulk_sampler_io.py b/python/cugraph/cugraph/gnn/data_loading/bulk_sampler_io.py index 194df7d2f75..6abbd82647b 100644 --- a/python/cugraph/cugraph/gnn/data_loading/bulk_sampler_io.py +++ b/python/cugraph/cugraph/gnn/data_loading/bulk_sampler_io.py @@ -19,7 +19,7 @@ from pandas import isna -from typing import Union, Optional, List +from typing import Union, Optional, List, Dict def create_df_from_disjoint_series(series_list: List[cudf.Series]): @@ -32,6 +32,13 @@ def create_df_from_disjoint_series(series_list: List[cudf.Series]): return df +def create_df_from_disjoint_arrays(array_dict: Dict[str, cupy.array]): + for k in list(array_dict.keys()): + array_dict[k] = cudf.Series(array_dict[k], name=k) + + return create_df_from_disjoint_series(list(array_dict.values())) + + def _write_samples_to_parquet_csr( results: cudf.DataFrame, offsets: cudf.DataFrame, diff --git a/python/cugraph/cugraph/gnn/data_loading/dist_sampler.py b/python/cugraph/cugraph/gnn/data_loading/dist_sampler.py new file mode 100644 index 00000000000..e57e195a4b8 --- /dev/null +++ b/python/cugraph/cugraph/gnn/data_loading/dist_sampler.py @@ -0,0 +1,577 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import warnings +from math import ceil + +import pylibcugraph +import numpy as np +import cupy +import cudf + +from typing import Union, List, Dict, Tuple +from cugraph.utilities import import_optional +from cugraph.gnn.comms import cugraph_comms_get_raft_handle + +from cugraph.gnn.data_loading.bulk_sampler_io import create_df_from_disjoint_arrays + +# PyTorch is NOT optional but this is required for container builds. +torch = import_optional("torch") + +TensorType = Union["torch.Tensor", cupy.ndarray, cudf.Series] + + +class DistSampleWriter: + def __init__( + self, + directory: str, + *, + batches_per_partition: int = 256, + format: str = "parquet", + ): + """ + Parameters + ---------- + directory: str (required) + The directory where samples will be written. This + writer can only write to disk. + batches_per_partition: int (optional, default=256) + The number of batches to write in a single file. + format: str (optional, default='parquet') + The file format of the output files containing the + sampled minibatches. Currently, only parquet format + is supported. + """ + if format != "parquet": + raise ValueError("Invalid format (currently supported: 'parquet')") + + self.__format = format + self.__directory = directory + self.__batches_per_partition = batches_per_partition + + @property + def _format(self): + return self.__format + + @property + def _directory(self): + return self.__directory + + @property + def _batches_per_partition(self): + return self.__batches_per_partition + + def __write_minibatches_coo(self, minibatch_dict): + has_edge_ids = minibatch_dict["edge_id"] is not None + has_edge_types = minibatch_dict["edge_type"] is not None + has_weights = minibatch_dict["weight"] is not None + + if minibatch_dict["renumber_map"] is None: + raise ValueError( + "Distributed sampling without renumbering is not supported" + ) + + # Quit if there are no batches to write. + if len(minibatch_dict["batch_id"]) == 0: + return + + fanout_length = (len(minibatch_dict["label_hop_offsets"]) - 1) // len( + minibatch_dict["batch_id"] + ) + rank_batch_offset = minibatch_dict["batch_id"][0] + + for p in range( + 0, int(ceil(len(minibatch_dict["batch_id"]) / self.__batches_per_partition)) + ): + partition_start = p * (self.__batches_per_partition) + partition_end = (p + 1) * (self.__batches_per_partition) + + label_hop_offsets_array_p = minibatch_dict["label_hop_offsets"][ + partition_start * fanout_length : partition_end * fanout_length + 1 + ] + + batch_id_array_p = minibatch_dict["batch_id"][partition_start:partition_end] + start_batch_id = batch_id_array_p[0] - rank_batch_offset + + start_ix, end_ix = label_hop_offsets_array_p[[0, -1]] + majors_array_p = minibatch_dict["majors"][start_ix:end_ix] + minors_array_p = minibatch_dict["minors"][start_ix:end_ix] + edge_id_array_p = ( + minibatch_dict["edge_id"][start_ix:end_ix] + if has_edge_ids + else cupy.array([], dtype="int64") + ) + edge_type_array_p = ( + minibatch_dict["edge_type"][start_ix:end_ix] + if has_edge_types + else cupy.array([], dtype="int32") + ) + weight_array_p = ( + minibatch_dict["weight"][start_ix:end_ix] + if has_weights + else cupy.array([], dtype="float32") + ) + + # create the renumber map offsets + renumber_map_offsets_array_p = minibatch_dict["renumber_map_offsets"][ + partition_start : partition_end + 1 + ] + + renumber_map_start_ix, renumber_map_end_ix = renumber_map_offsets_array_p[ + [0, -1] + ] + + renumber_map_array_p = minibatch_dict["renumber_map"][ + renumber_map_start_ix:renumber_map_end_ix + ] + + results_dataframe_p = create_df_from_disjoint_arrays( + { + "majors": majors_array_p, + "minors": minors_array_p, + "map": renumber_map_array_p, + "label_hop_offsets": label_hop_offsets_array_p, + "weight": weight_array_p, + "edge_id": edge_id_array_p, + "edge_type": edge_type_array_p, + "renumber_map_offsets": renumber_map_offsets_array_p, + } + ) + + end_batch_id = start_batch_id + len(batch_id_array_p) - 1 + rank = minibatch_dict["rank"] if "rank" in minibatch_dict else 0 + + full_output_path = os.path.join( + self.__directory, + f"batch={rank:05d}.{start_batch_id:08d}-" + f"{rank:05d}.{end_batch_id:08d}.parquet", + ) + + results_dataframe_p.to_parquet( + full_output_path, + compression=None, + index=False, + force_nullable_schema=True, + ) + + def __write_minibatches_csr(self, minibatch_dict): + raise NotImplementedError( + "CSR format currently not supported for distributed sampling" + ) + + def write_minibatches(self, minibatch_dict): + if (minibatch_dict["majors"] is not None) and ( + minibatch_dict["minors"] is not None + ): + self.__write_minibatches_coo(minibatch_dict) + elif (minibatch_dict["major_offsets"] is not None) and ( + minibatch_dict["minors"] is not None + ): + self.__write_minibatches_csr(minibatch_dict) + else: + raise ValueError("invalid columns") + + +class DistSampler: + def __init__( + self, + graph: Union[pylibcugraph.SGGraph, pylibcugraph.MGGraph], + writer: DistSampleWriter, + local_seeds_per_call: int = 32768, + retain_original_seeds: bool = False, # TODO See #4329, needs C API + ): + """ + Parameters + ---------- + graph: SGGraph or MGGraph (required) + The pylibcugraph graph object that will be sampled. + writer: DistSampleWriter (required) + The writer responsible for writing samples to disk + or, in the future, device or host memory. + local_seeds_per_call: int (optional, default=32768) + The number of seeds on this rank this sampler will + process in a single sampling call. Batches will + get split into multiple sampling calls based on + this parameter. This parameter must + be the same across all ranks. The total number + of seeds processed per sampling call is this + parameter times the world size. + retain_original_seeds: bool (optional, default=False) + Whether to retain the original seeds even if they + do not appear in the output minibatch. This will + affect the output renumber map and CSR/CSC graph + if applicable. + """ + self.__graph = graph + self.__writer = writer + self.__local_seeds_per_call = local_seeds_per_call + self.__handle = None + self.__retain_original_seeds = retain_original_seeds + + def sample_batches( + self, + seeds: TensorType, + batch_ids: TensorType, + random_state: int = 0, + assume_equal_input_size: bool = False, + ) -> Dict[str, TensorType]: + """ + For a single call group of seeds and associated batch ids, performs + sampling. + + Parameters + ---------- + seeds: TensorType + Input seeds for a single call group (node ids). + batch_ids: TensorType + The batch id for each seed. + random_state: int + The random seed to use for sampling. + assume_equal_input_size: bool + If True, will assume all ranks have the same number of inputs, + and will skip the synchronization/gather steps to check for + and handle uneven inputs. + + Returns + ------- + A dictionary containing the sampling outputs (majors, minors, map, etc.) + """ + raise NotImplementedError("Must be implemented by subclass") + + def get_label_list_and_output_rank( + self, local_label_list: TensorType, assume_equal_input_size: bool = False + ): + """ + Computes the label list and output rank mapping for + the list of labels (batch ids). + Subclasses may override this as needed depending on their + memory and compute constraints. + + Parameters + ---------- + local_label_list: TensorType + The list of unique labels on this rank. + assume_equal_input_size: bool + If True, assumes that all ranks have the same number of inputs (batches) + and skips some synchronization/gathering accordingly. + + Returns + ------- + label_list: TensorType + The global label list containing all labels used across ranks. + label_to_output_comm_rank: TensorType + The global mapping of labels to ranks. + """ + world_size = torch.distributed.get_world_size() + + if assume_equal_input_size: + num_batches = len(local_label_list) * world_size + label_list = torch.empty((num_batches,), dtype=torch.int32, device="cuda") + w = torch.distributed.all_gather_into_tensor( + label_list, local_label_list, async_op=True + ) + + label_to_output_comm_rank = torch.concat( + [ + torch.full( + (len(local_label_list),), r, dtype=torch.int32, device="cuda" + ) + for r in range(world_size) + ] + ) + else: + num_batches = torch.tensor( + [len(local_label_list)], device="cuda", dtype=torch.int64 + ) + num_batches_all_ranks = torch.empty( + (world_size,), device="cuda", dtype=torch.int64 + ) + torch.distributed.all_gather_into_tensor(num_batches_all_ranks, num_batches) + + label_list = [ + torch.empty((n,), dtype=torch.int32, device="cuda") + for n in num_batches_all_ranks + ] + w = torch.distributed.all_gather( + label_list, local_label_list, async_op=True + ) + + label_to_output_comm_rank = torch.concat( + [ + torch.full((num_batches_r,), r, device="cuda", dtype=torch.int32) + for r, num_batches_r in enumerate(num_batches_all_ranks) + ] + ) + + w.wait() + if isinstance(label_list, list): + label_list = torch.concat(label_list) + return label_list, label_to_output_comm_rank + + def get_start_batch_offset( + self, local_num_batches: int, assume_equal_input_size: bool = False + ) -> Tuple[int, bool]: + """ + Gets the starting batch offset to ensure each rank's set of batch ids is + disjoint. + + Parameters + ---------- + local_num_batches: int + The number of batches for this rank. + assume_equal_input_size: bool + Whether to assume all ranks have the same number of batches. + + Returns + ------- + Tuple[int, bool] + The starting batch offset (int) + and whether the input sizes on each rank are equal (bool). + + """ + input_size_is_equal = True + if self.is_multi_gpu: + rank = torch.distributed.get_rank() + world_size = torch.distributed.get_world_size() + + if assume_equal_input_size: + t = torch.full( + (world_size,), local_num_batches, dtype=torch.int64, device="cuda" + ) + else: + t = torch.empty((world_size,), dtype=torch.int64, device="cuda") + local_size = torch.tensor( + [local_num_batches], dtype=torch.int64, device="cuda" + ) + + torch.distributed.all_gather_into_tensor(t, local_size) + if (t != local_size).any(): + input_size_is_equal = False + if rank == 0: + warnings.warn( + "Not all ranks received the same number of batches. " + "This might cause your training loop to hang " + "due to uneven inputs." + ) + + return (0 if rank == 0 else t.cumsum(dim=0)[rank - 1], input_size_is_equal) + else: + return 0, input_size_is_equal + + def sample_from_nodes( + self, + nodes: TensorType, + *, + batch_size: int = 16, + random_state: int = 62, + assume_equal_input_size: bool = False, + ): + """ + Performs node-based sampling. Accepts a list of seed nodes, and batch size. + Splits the seed list into batches, then divides the batches into call groups + based on the number of seeds per call this sampler was set to use. + Then calls sample_batches for each call group and writes the result using + the writer associated with this sampler. + + Parameters + ---------- + nodes: TensorType + Input seeds (node ids). + batch_size: int + The size of each batch. + random_state: int + The random seed to use for sampling. + """ + nodes = torch.as_tensor(nodes, device="cuda") + + batches_per_call = self._local_seeds_per_call // batch_size + actual_seeds_per_call = batches_per_call * batch_size + + # Split the input seeds into call groups. Each call group + # corresponds to one sampling call. A call group contains + # many batches. + num_seeds = len(nodes) + nodes_call_groups = torch.split(nodes, actual_seeds_per_call) + + local_num_batches = int(ceil(num_seeds / batch_size)) + batch_id_start, input_size_is_equal = self.get_start_batch_offset( + local_num_batches, assume_equal_input_size=assume_equal_input_size + ) + + # Need to add empties to the list of call groups to handle the case + # where not all nodes have the same number of call groups. This + # prevents a hang since we need all ranks to make the same number + # of calls. + if not input_size_is_equal: + num_call_groups = torch.tensor( + [len(nodes_call_groups)], device="cuda", dtype=torch.int32 + ) + torch.distributed.all_reduce( + num_call_groups, op=torch.distributed.ReduceOp.MAX + ) + nodes_call_groups = list(nodes_call_groups) + ( + [torch.tensor([], dtype=nodes.dtype, device="cuda")] + * (int(num_call_groups) - len(nodes_call_groups)) + ) + + # Make a call to sample_batches for each call group + for i, current_seeds in enumerate(nodes_call_groups): + current_batches = torch.arange( + batch_id_start + i * batches_per_call, + batch_id_start + (i + 1) * batches_per_call, + device="cuda", + dtype=torch.int32, + ) + + current_batches = current_batches.repeat_interleave(batch_size)[ + : len(current_seeds) + ] + + # Handle the case where not all ranks have the same number of call groups, + # in which case there will be some empty groups that get submitted on the + # ranks with fewer call groups. + label_start, label_end = ( + current_batches[[0, -1]] if len(current_batches) > 0 else (0, -1) + ) + + minibatch_dict = self.sample_batches( + seeds=current_seeds, + batch_ids=current_batches, + random_state=random_state, + assume_equal_input_size=input_size_is_equal, + ) + self.__writer.write_minibatches(minibatch_dict) + + @property + def is_multi_gpu(self): + return isinstance(self.__graph, pylibcugraph.MGGraph) + + @property + def _local_seeds_per_call(self): + return self.__local_seeds_per_call + + @property + def _graph(self): + return self.__graph + + @property + def _resource_handle(self): + if self.__handle is None: + if self.is_multi_gpu: + self.__handle = pylibcugraph.ResourceHandle( + cugraph_comms_get_raft_handle().getHandle() + ) + else: + self.__handle = pylibcugraph.ResourceHandle() + return self.__handle + + @property + def _retain_original_seeds(self): + return self.__retain_original_seeds + + +class UniformNeighborSampler(DistSampler): + def __init__( + self, + graph: Union[pylibcugraph.SGGraph, pylibcugraph.MGGraph], + writer: DistSampleWriter, + *, + local_seeds_per_call: int = 32768, + retain_original_seeds: bool = False, + fanout: List[int] = [-1], + prior_sources_behavior: str = "exclude", + deduplicate_sources: bool = True, + compression: str = "COO", + compress_per_hop: bool = False, + with_replacement: bool = False, + ): + super().__init__( + graph, + writer, + local_seeds_per_call=local_seeds_per_call, + retain_original_seeds=retain_original_seeds, + ) + self.__fanout = fanout + self.__prior_sources_behavior = prior_sources_behavior + self.__deduplicate_sources = deduplicate_sources + self.__compress_per_hop = compress_per_hop + self.__compression = compression + self.__with_replacement = with_replacement + + def sample_batches( + self, + seeds: TensorType, + batch_ids: TensorType, + random_state: int = 0, + assume_equal_input_size: bool = False, + ) -> Dict[str, TensorType]: + if self.is_multi_gpu: + rank = torch.distributed.get_rank() + + batch_ids = batch_ids.to(device="cuda", dtype=torch.int32) + local_label_list = torch.unique(batch_ids) + + label_list, label_to_output_comm_rank = self.get_label_list_and_output_rank( + local_label_list, assume_equal_input_size=assume_equal_input_size + ) + + # TODO add calculation of seed vertex label offsets + if self._retain_original_seeds: + warnings.warn( + "The 'retain_original_seeds` parameter is currently ignored " + "since seed retention is not implemented yet." + ) + + sampling_results_dict = pylibcugraph.uniform_neighbor_sample( + self._resource_handle, + self._graph, + start_list=cupy.asarray(seeds), + batch_id_list=cupy.asarray(batch_ids), + label_list=cupy.asarray(label_list), + label_to_output_comm_rank=cupy.asarray(label_to_output_comm_rank), + h_fan_out=np.array(self.__fanout, dtype="int32"), + with_replacement=self.__with_replacement, + do_expensive_check=False, + with_edge_properties=True, + random_state=random_state + rank, + prior_sources_behavior=self.__prior_sources_behavior, + deduplicate_sources=self.__deduplicate_sources, + return_hops=True, + renumber=True, + compression=self.__compression, + compress_per_hop=self.__compress_per_hop, + return_dict=True, + ) + sampling_results_dict["rank"] = rank + else: + sampling_results_dict = pylibcugraph.uniform_neighbor_sample( + self._resource_handle, + self._graph, + start_list=cupy.asarray(seeds), + batch_id_list=cupy.asarray(batch_ids), + h_fan_out=np.array(self.__fanout, dtype="int32"), + with_replacement=self.__with_replacement, + do_expensive_check=False, + with_edge_properties=True, + random_state=random_state, + prior_sources_behavior=self.__prior_sources_behavior, + deduplicate_sources=self.__deduplicate_sources, + return_hops=True, + renumber=True, + compression=self.__compression, + compress_per_hop=self.__compress_per_hop, + return_dict=True, + ) + + return sampling_results_dict diff --git a/python/cugraph/cugraph/tests/sampling/test_dist_sampler.py b/python/cugraph/cugraph/tests/sampling/test_dist_sampler.py new file mode 100644 index 00000000000..02676774a02 --- /dev/null +++ b/python/cugraph/cugraph/tests/sampling/test_dist_sampler.py @@ -0,0 +1,94 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import os +import shutil + +import cupy +import cudf + +from cugraph.datasets import karate +from cugraph.gnn import UniformNeighborSampler, DistSampleWriter + +from pylibcugraph import SGGraph, ResourceHandle, GraphProperties + +from cugraph.utilities.utils import ( + create_directory_with_overwrite, + import_optional, + MissingModule, +) + + +torch = import_optional("torch") + + +@pytest.fixture +def karate_graph(): + el = karate.get_edgelist().reset_index().rename(columns={"index": "eid"}) + G = SGGraph( + ResourceHandle(), + GraphProperties(is_multigraph=True, is_symmetric=False), + el.src.astype("int64"), + el.dst.astype("int64"), + edge_id_array=el.eid, + ) + + return G + + +@pytest.mark.sg +@pytest.mark.parametrize("equal_input_size", [True, False]) +@pytest.mark.parametrize("fanout", [[2, 2], [4, 4], [4, 2, 1]]) +@pytest.mark.parametrize("batch_size", [1, 2, 4]) +@pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available") +def test_dist_sampler_simple( + scratch_dir, karate_graph, batch_size, fanout, equal_input_size +): + G = karate_graph + + samples_path = os.path.join(scratch_dir, "test_bulk_sampler_simple") + create_directory_with_overwrite(samples_path) + + writer = DistSampleWriter(samples_path) + + sampler = UniformNeighborSampler(G, writer, fanout=fanout) + + seeds = cupy.array([0, 5, 10, 15], dtype="int64") + + sampler.sample_from_nodes( + seeds, batch_size=batch_size, assume_equal_input_size=equal_input_size + ) + + recovered_samples = cudf.read_parquet(samples_path) + original_el = karate.get_edgelist() + + for b in range(len(seeds) // batch_size): + el_start = recovered_samples.label_hop_offsets.iloc[b * len(fanout)] + el_end = recovered_samples.label_hop_offsets.iloc[(b + 1) * len(fanout)] + src = recovered_samples.majors.iloc[el_start:el_end] + dst = recovered_samples.minors.iloc[el_start:el_end] + edge_id = recovered_samples.edge_id.iloc[el_start:el_end] + + map_start = recovered_samples.renumber_map_offsets[b] + map_end = recovered_samples.renumber_map_offsets[b + 1] + renumber_map = recovered_samples["map"].iloc[map_start:map_end] + + src = renumber_map.iloc[src.values] + dst = renumber_map.iloc[dst.values] + + for i in range(len(edge_id)): + assert original_el.src.iloc[edge_id.iloc[i]] == src.iloc[i] + assert original_el.dst.iloc[edge_id.iloc[i]] == dst.iloc[i] + + shutil.rmtree(samples_path) diff --git a/python/cugraph/cugraph/tests/sampling/test_dist_sampler_mg.py b/python/cugraph/cugraph/tests/sampling/test_dist_sampler_mg.py new file mode 100644 index 00000000000..bf65e46c516 --- /dev/null +++ b/python/cugraph/cugraph/tests/sampling/test_dist_sampler_mg.py @@ -0,0 +1,230 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import os +import shutil + +import cupy +import cudf + +from cugraph.datasets import karate +from cugraph.gnn import ( + UniformNeighborSampler, + DistSampleWriter, + cugraph_comms_create_unique_id, + cugraph_comms_get_raft_handle, + cugraph_comms_init, + cugraph_comms_shutdown, +) +from pylibcugraph import MGGraph, ResourceHandle, GraphProperties + +from cugraph.utilities.utils import ( + create_directory_with_overwrite, + import_optional, + MissingModule, +) + +torch = import_optional("torch") + + +def karate_mg_graph(rank, world_size): + el = karate.get_edgelist().reset_index().rename(columns={"index": "eid"}) + split = cupy.array_split(cupy.arange(len(el)), world_size)[rank] + el = el.iloc[split] + + G = MGGraph( + ResourceHandle(cugraph_comms_get_raft_handle().getHandle()), + GraphProperties(is_multigraph=True, is_symmetric=False), + [el.src.astype("int64")], + [el.dst.astype("int64")], + edge_id_array=[el.eid], + ) + + return G + + +def init_pytorch(rank, world_size): + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = "12355" + torch.distributed.init_process_group("nccl", rank=rank, world_size=world_size) + + +def run_test_dist_sampler_simple( + rank, + world_size, + uid, + samples_path, + batch_size, + seeds_per_rank, + fanout, + equal_input_size, + seeds_per_call, +): + init_pytorch(rank, world_size) + cugraph_comms_init(rank, world_size, uid, device=rank) + + G = karate_mg_graph(rank, world_size) + + writer = DistSampleWriter(samples_path) + + sampler = UniformNeighborSampler( + G, writer, fanout=fanout, local_seeds_per_call=seeds_per_call + ) + + seeds = cupy.random.randint(0, 34, seeds_per_rank, dtype="int64") + + from time import perf_counter + + start_time = perf_counter() + sampler.sample_from_nodes( + seeds, batch_size=batch_size, assume_equal_input_size=equal_input_size + ) + end_time = perf_counter() + + print("time:", end_time - start_time) + + cugraph_comms_shutdown() + + +@pytest.mark.mg +@pytest.mark.parametrize("equal_input_size", [True, False]) +@pytest.mark.parametrize("fanout", [[4, 4], [4, 2, 1]]) +@pytest.mark.parametrize("batch_size", [1, 4]) +@pytest.mark.parametrize("seeds_per_rank", [8, 1]) +@pytest.mark.parametrize("seeds_per_call", [4, 8]) +@pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not installed") +def test_dist_sampler_simple( + scratch_dir, batch_size, seeds_per_rank, fanout, equal_input_size, seeds_per_call +): + uid = cugraph_comms_create_unique_id() + + samples_path = os.path.join(scratch_dir, "test_bulk_sampler_mg_simple") + create_directory_with_overwrite(samples_path) + + world_size = torch.cuda.device_count() + torch.multiprocessing.spawn( + run_test_dist_sampler_simple, + args=( + world_size, + uid, + samples_path, + batch_size, + seeds_per_rank, + fanout, + equal_input_size, + seeds_per_call, + ), + nprocs=world_size, + ) + + for file in os.listdir(samples_path): + recovered_samples = cudf.read_parquet(os.path.join(samples_path, file)) + original_el = karate.get_edgelist() + + for b in range(len(recovered_samples.renumber_map_offsets.dropna()) - 1): + el_start = int(recovered_samples.label_hop_offsets.iloc[b * len(fanout)]) + el_end = int( + recovered_samples.label_hop_offsets.iloc[(b + 1) * len(fanout)] + ) + src = recovered_samples.majors.iloc[el_start:el_end] + dst = recovered_samples.minors.iloc[el_start:el_end] + edge_id = recovered_samples.edge_id.iloc[el_start:el_end] + + map_start = recovered_samples.renumber_map_offsets[b] + map_end = recovered_samples.renumber_map_offsets[b + 1] + renumber_map = recovered_samples["map"].iloc[map_start:map_end] + + src = renumber_map.iloc[src.values] + dst = renumber_map.iloc[dst.values] + + for i in range(len(edge_id)): + assert original_el.src.iloc[edge_id.iloc[i]] == src.iloc[i] + assert original_el.dst.iloc[edge_id.iloc[i]] == dst.iloc[i] + + shutil.rmtree(samples_path) + + +def run_test_dist_sampler_uneven( + rank, world_size, uid, samples_path, batch_size, fanout, seeds_per_call +): + init_pytorch(rank, world_size) + cugraph_comms_init(rank, world_size, uid, device=rank) + + G = karate_mg_graph(rank, world_size) + + writer = DistSampleWriter(samples_path) + + sampler = UniformNeighborSampler( + G, writer, fanout=fanout, local_seeds_per_call=seeds_per_call + ) + + num_seeds = 8 + rank + seeds = cupy.random.randint(0, 34, num_seeds, dtype="int64") + + from time import perf_counter + + start_time = perf_counter() + sampler.sample_from_nodes( + seeds, batch_size=batch_size, assume_equal_input_size=False + ) + end_time = perf_counter() + + print("time:", end_time - start_time) + + cugraph_comms_shutdown() + + +@pytest.mark.mg +@pytest.mark.parametrize("fanout", [[4, 4], [4, 2, 1]]) +@pytest.mark.parametrize("batch_size", [1, 4]) +@pytest.mark.parametrize("seeds_per_call", [4, 8, 16]) +@pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not installed") +def test_dist_sampler_uneven(scratch_dir, batch_size, fanout, seeds_per_call): + uid = cugraph_comms_create_unique_id() + + samples_path = os.path.join(scratch_dir, "test_bulk_sampler_mg_uneven") + create_directory_with_overwrite(samples_path) + + world_size = torch.cuda.device_count() + torch.multiprocessing.spawn( + run_test_dist_sampler_uneven, + args=(world_size, uid, samples_path, batch_size, fanout, seeds_per_call), + nprocs=world_size, + ) + + for file in os.listdir(samples_path): + recovered_samples = cudf.read_parquet(os.path.join(samples_path, file)) + original_el = karate.get_edgelist() + + for b in range(len(recovered_samples.renumber_map_offsets.dropna()) - 1): + el_start = int(recovered_samples.label_hop_offsets.iloc[b * len(fanout)]) + el_end = int( + recovered_samples.label_hop_offsets.iloc[(b + 1) * len(fanout)] + ) + src = recovered_samples.majors.iloc[el_start:el_end] + dst = recovered_samples.minors.iloc[el_start:el_end] + edge_id = recovered_samples.edge_id.iloc[el_start:el_end] + + map_start = recovered_samples.renumber_map_offsets[b] + map_end = recovered_samples.renumber_map_offsets[b + 1] + renumber_map = recovered_samples["map"].iloc[map_start:map_end] + + src = renumber_map.iloc[src.values] + dst = renumber_map.iloc[dst.values] + + for i in range(len(edge_id)): + assert original_el.src.iloc[edge_id.iloc[i]] == src.iloc[i] + assert original_el.dst.iloc[edge_id.iloc[i]] == dst.iloc[i] + + shutil.rmtree(samples_path)