From b18e33b77c69dbe242b242c81309ccad7df36b7d Mon Sep 17 00:00:00 2001
From: Alexandria Barghi
Date: Thu, 8 Aug 2024 11:44:31 -0700
Subject: [PATCH 1/3] global shuffle

---
 .../cugraph_pyg/loader/loader_utils.py        |  80 ++++++++++++
 .../cugraph_pyg/loader/node_loader.py         |  53 ++++++++-
 .../tests/loader/test_loader_utils_mg.py      | 107 ++++++++++++++++++
 3 files changed, 236 insertions(+), 4 deletions(-)
 create mode 100644 python/cugraph-pyg/cugraph_pyg/loader/loader_utils.py
 create mode 100644 python/cugraph-pyg/cugraph_pyg/tests/loader/test_loader_utils_mg.py

diff --git a/python/cugraph-pyg/cugraph_pyg/loader/loader_utils.py b/python/cugraph-pyg/cugraph_pyg/loader/loader_utils.py
new file mode 100644
index 00000000000..d014cd2862f
--- /dev/null
+++ b/python/cugraph-pyg/cugraph_pyg/loader/loader_utils.py
@@ -0,0 +1,80 @@
+# Copyright (c) 2024, NVIDIA CORPORATION.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from cugraph.utilities.utils import import_optional
+from typing import List
+
+torch = import_optional("torch")
+
+
+def scatter(
+    t: "torch.Tensor", scatter_perm: List["torch.Tensor"], rank: int, world_size: int
+):
+    """
+    Redistributes the elements of a local tensor across all ranks,
+    sending t[scatter_perm[r]] to rank r.
+
+    t: torch.Tensor
+        The local tensor being scattered.
+    scatter_perm: List[torch.Tensor]
+        The indices to send to each rank.
+    rank: int
+        The global rank of this worker.
+    world_size: int
+        The total number of workers.
+
+    Returns
+    -------
+    torch.Tensor
+        The values received from each rank, concatenated in rank order.
+    """
+
+    scatter_len = torch.tensor(
+        [s.numel() for s in scatter_perm], device="cuda", dtype=torch.int64
+    )
+
+    scatter_len_all = [
+        torch.empty((world_size,), device="cuda", dtype=torch.int64)
+        for _ in range(world_size)
+    ]
+    torch.distributed.all_gather(scatter_len_all, scatter_len)
+
+    t = t.cuda()
+    local_tensors = [
+        torch.empty((scatter_len_all[r][rank],), device="cuda", dtype=t.dtype)
+        for r in range(world_size)
+    ]
+
+    # Issue the sends and receives in ring order so that the isend posted
+    # to a peer in round r is matched by that peer's irecv in round r.
+    qx = []
+    for r in range(world_size):
+        send_rank = (rank + r) % world_size
+        send_op = torch.distributed.P2POp(
+            torch.distributed.isend,
+            t[scatter_perm[send_rank]],
+            send_rank,
+        )
+
+        recv_rank = (rank - r) % world_size
+        recv_op = torch.distributed.P2POp(
+            torch.distributed.irecv,
+            local_tensors[recv_rank],
+            recv_rank,
+        )
+        qx += torch.distributed.batch_isend_irecv([send_op, recv_op])
+
+    for x in qx:
+        x.wait()
+
+    return torch.concat(local_tensors)
diff --git a/python/cugraph-pyg/cugraph_pyg/loader/node_loader.py b/python/cugraph-pyg/cugraph_pyg/loader/node_loader.py
index 49923783d6b..5536c7654de 100644
--- a/python/cugraph-pyg/cugraph_pyg/loader/node_loader.py
+++ b/python/cugraph-pyg/cugraph_pyg/loader/node_loader.py
@@ -17,6 +17,7 @@
 from typing import Union, Tuple, Callable, Optional
 
 from cugraph.utilities.utils import import_optional
+from .loader_utils import scatter
 
 torch_geometric = import_optional("torch_geometric")
 torch = import_optional("torch")
@@ -50,6 +51,7 @@ def __init__(
         batch_size: int = 1,
         shuffle: bool = False,
         drop_last: bool = False,
+        global_shuffle: bool = True,
         **kwargs,
     ):
         """
@@ -74,7 +76,17 @@ def __init__(
             always return a Data or HeteroData object.
         input_id: OptTensor
             See torch_geometric.loader.NodeLoader.
-
+        batch_size: int
+            The size of each batch.
+        shuffle: bool
+            Whether to shuffle data into random batches.
+        drop_last: bool
+            Whether to drop remaining inputs that can't form a full
+            batch.
+        global_shuffle: bool
+            (cuGraph-PyG only) Whether to shuffle inputs across all
+            workers rather than only locally. Disabling this avoids
+            communication overhead but may hurt accuracy.
         """
         if not isinstance(data, (list, tuple)) or not isinstance(
            data[1], cugraph_pyg.data.GraphStore
@@ -125,7 +137,39 @@ def __init__(
         self.__batch_size = batch_size
         self.__shuffle = shuffle
         self.__drop_last = drop_last
 
-    def __iter__(self):
+    def __get_input(self):
+        _, graph_store = self.__data
+        if graph_store.is_multi_gpu and self.__shuffle and self.__global_shuffle:
+            rank = torch.distributed.get_rank()
+            world_size = torch.distributed.get_world_size()
+            scatter_perm = torch.tensor_split(
+                torch.randperm(
+                    self.__input_data.node.numel(), device="cpu", dtype=torch.int64
+                ),
+                world_size,
+            )
+
+            new_node = scatter(self.__input_data.node, scatter_perm, rank, world_size)
+            local_perm = torch.randperm(new_node.numel())
+            if self.__drop_last:
+                d = local_perm.numel() % self.__batch_size
+                local_perm = local_perm[: local_perm.numel() - d]  # no-op if d == 0
+
+            return torch_geometric.loader.node_loader.NodeSamplerInput(
+                input_id=None
+                if self.__input_data.input_id is None
+                else scatter(
+                    self.__input_data.input_id, scatter_perm, rank, world_size
+                )[local_perm],
+                time=None
+                if self.__input_data.time is None
+                else scatter(self.__input_data.time, scatter_perm, rank, world_size)[
+                    local_perm
+                ],
+                node=new_node[local_perm],
+                input_type=self.__input_data.input_type,
+            )
+
         if self.__shuffle:
             perm = torch.randperm(self.__input_data.node.numel())
@@ -135,7 +179,7 @@ def __init__(
             d = perm.numel() % self.__batch_size
             perm = perm[:-d]
 
-        input_data = torch_geometric.loader.node_loader.NodeSamplerInput(
+        return torch_geometric.loader.node_loader.NodeSamplerInput(
             input_id=None
             if self.__input_data.input_id is None
             else self.__input_data.input_id[perm],
@@ -146,6 +190,7 @@ def __init__(
             input_type=self.__input_data.input_type,
         )
 
+    def __iter__(self):
         return cugraph_pyg.sampler.SampleIterator(
-            self.__data, self.__node_sampler.sample_from_nodes(input_data)
+            self.__data, self.__node_sampler.sample_from_nodes(self.__get_input())
         )
diff --git a/python/cugraph-pyg/cugraph_pyg/tests/loader/test_loader_utils_mg.py b/python/cugraph-pyg/cugraph_pyg/tests/loader/test_loader_utils_mg.py
new file mode 100644
index 00000000000..523c5a989a9
--- /dev/null
+++ b/python/cugraph-pyg/cugraph_pyg/tests/loader/test_loader_utils_mg.py
@@ -0,0 +1,107 @@
+# Copyright (c) 2024, NVIDIA CORPORATION.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+
+import pytest
+
+from cugraph.utilities.utils import import_optional, MissingModule
+
+
+from cugraph.gnn import (
+    cugraph_comms_init,
+    cugraph_comms_shutdown,
+    cugraph_comms_create_unique_id,
+)
+
+from cugraph_pyg.loader.loader_utils import scatter
+
+torch = import_optional("torch")
+torch_geometric = import_optional("torch_geometric")
+
+
+def init_pytorch_worker(rank, world_size, cugraph_id):
+    import rmm
+
+    rmm.reinitialize(
+        devices=rank,
+    )
+
+    import cupy
+
+    cupy.cuda.Device(rank).use()
+    from rmm.allocators.cupy import rmm_cupy_allocator
+
+    cupy.cuda.set_allocator(rmm_cupy_allocator)
+
+    from cugraph.testing.mg_utils import enable_spilling
+
+    enable_spilling()
+
+    torch.cuda.set_device(rank)
+
+    os.environ["MASTER_ADDR"] = "localhost"
+    os.environ["MASTER_PORT"] = "12355"
+    torch.distributed.init_process_group("nccl", rank=rank, world_size=world_size)
+
+    cugraph_comms_init(rank=rank, world_size=world_size, uid=cugraph_id, device=rank)
+
+
+def run_test_loader_utils_scatter(rank, world_size, uid):
+    init_pytorch_worker(rank, world_size, uid)
+
+    num_values_rank = (1 + rank) * 9
+    local_values = torch.arange(0, num_values_rank) + 9 * (
+        rank + ((rank * (rank - 1)) // 2)
+    )
+
+    scatter_perm = torch.tensor_split(torch.arange(local_values.numel()), world_size)
+
+    new_values = scatter(local_values, scatter_perm, rank, world_size)
+    print(
+        rank,
+        local_values,
+        new_values,
+        flush=True,
+    )
+
+    offset = 0
+    for send_rank in range(world_size):
+        num_values_send_rank = (1 + send_rank) * 9
+
+        expected_values = torch.tensor_split(
+            torch.arange(0, num_values_send_rank)
+            + 9 * (send_rank + ((send_rank * (send_rank - 1)) // 2)),
+            world_size,
+        )[rank]
+
+        ix_sent = torch.arange(expected_values.numel())
+        values_rec = new_values[ix_sent + offset].cpu()
+        offset += values_rec.numel()
+
+        assert (values_rec == expected_values).all()
+
+    cugraph_comms_shutdown()
+
+
+@pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available")
+@pytest.mark.mg
+def test_loader_utils_scatter():
+    uid = cugraph_comms_create_unique_id()
+    world_size = torch.cuda.device_count()
+
+    torch.multiprocessing.spawn(
+        run_test_loader_utils_scatter,
+        args=(world_size, uid),
+        nprocs=world_size,
+    )

From a4ac6875970bb6f4240540efb045436e3acd0cc8 Mon Sep 17 00:00:00 2001
From: Alexandria Barghi
Date: Thu, 8 Aug 2024 11:53:00 -0700
Subject: [PATCH 2/3] properly set option

---
 python/cugraph-pyg/cugraph_pyg/loader/node_loader.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/python/cugraph-pyg/cugraph_pyg/loader/node_loader.py b/python/cugraph-pyg/cugraph_pyg/loader/node_loader.py
index 5536c7654de..6b82c68ff44 100644
--- a/python/cugraph-pyg/cugraph_pyg/loader/node_loader.py
+++ b/python/cugraph-pyg/cugraph_pyg/loader/node_loader.py
@@ -136,6 +136,7 @@ def __init__(
         self.__batch_size = batch_size
         self.__shuffle = shuffle
         self.__drop_last = drop_last
+        self.__global_shuffle = global_shuffle
 
     def __get_input(self):
         _, graph_store = self.__data

From 63065f54dc3723133f4f8ca734a246246caf1800 Mon Sep 17 00:00:00 2001
From: Alexandria Barghi
Date: Thu, 8 Aug 2024 14:07:09 -0700
Subject: [PATCH 3/3] setting global_shuffle

---
 python/cugraph-pyg/cugraph_pyg/examples/gcn_dist_mnmg.py | 3 +++
 python/cugraph-pyg/cugraph_pyg/examples/gcn_dist_snmg.py | 3 +++
 2 files changed, 6 insertions(+)

diff --git a/python/cugraph-pyg/cugraph_pyg/examples/gcn_dist_mnmg.py b/python/cugraph-pyg/cugraph_pyg/examples/gcn_dist_mnmg.py
index 7002d7ebded..66831e7c042 100644
--- a/python/cugraph-pyg/cugraph_pyg/examples/gcn_dist_mnmg.py
+++ b/python/cugraph-pyg/cugraph_pyg/examples/gcn_dist_mnmg.py
@@ -204,6 +204,7 @@ def run_train(
         directory=train_path,
         shuffle=True,
         drop_last=True,
+        global_shuffle=True,
         **kwargs,
     )
 
@@ -217,6 +218,7 @@ def run_train(
         shuffle=True,
         drop_last=True,
         local_seeds_per_call=80000,
+        global_shuffle=False,
         **kwargs,
     )
 
@@ -229,6 +231,7 @@ def run_train(
         directory=valid_path,
         shuffle=True,
         drop_last=True,
+        global_shuffle=False,
         **kwargs,
     )
diff --git a/python/cugraph-pyg/cugraph_pyg/examples/gcn_dist_snmg.py b/python/cugraph-pyg/cugraph_pyg/examples/gcn_dist_snmg.py
index b1bb0240e71..95a88c20e0d 100644
--- a/python/cugraph-pyg/cugraph_pyg/examples/gcn_dist_snmg.py
+++ b/python/cugraph-pyg/cugraph_pyg/examples/gcn_dist_snmg.py
@@ -127,6 +127,7 @@ def run_train(
         directory=train_path,
         shuffle=True,
         drop_last=True,
+        global_shuffle=True,
         **kwargs,
     )
 
@@ -140,6 +141,7 @@ def run_train(
         shuffle=True,
         drop_last=True,
         local_seeds_per_call=80000,
+        global_shuffle=False,
         **kwargs,
     )
 
@@ -152,6 +154,7 @@ def run_train(
         directory=valid_path,
         shuffle=True,
         drop_last=True,
+        global_shuffle=False,
         **kwargs,
     )
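
Usage sketch (not part of the patches above): the new `global_shuffle` flag flows through `NodeLoader.__init__`, so loaders built on it, such as `cugraph_pyg.loader.NeighborLoader`, accept it directly alongside `shuffle` and `drop_last`. The snippet below is a minimal sketch of the intended call; `feature_store`, `graph_store`, and `ix_train` are placeholder names for objects assumed to be constructed elsewhere (see the patched examples).

    from cugraph_pyg.loader import NeighborLoader

    # Sketch only: feature_store, graph_store, and ix_train are assumed
    # to exist, as in gcn_dist_snmg.py above.
    train_loader = NeighborLoader(
        (feature_store, graph_store),
        num_neighbors=[10, 10],
        input_nodes=ix_train,
        batch_size=1024,
        shuffle=True,
        drop_last=True,
        global_shuffle=True,  # redistribute seeds across all ranks each epoch
    )

As in the examples above, `global_shuffle=True` is used only for the training loader; the test and validation loaders disable it, since their batches gain nothing from a global re-mix and the scatter costs an extra round of communication.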