From 5458e76c332dd80e025704344307e1b96cc911b6 Mon Sep 17 00:00:00 2001
From: Alex Barghi <105237337+alexbarghi-nv@users.noreply.github.com>
Date: Tue, 30 Jul 2024 22:04:45 -0400
Subject: [PATCH] [BUG] Fix Failing WholeGraph Tests (#4560)

This PR properly uses PyTorch DDP to initialize a process group and
test the WholeGraph feature store. Previously it was relying on an API
in WholeGraph that no longer appears to work.

Authors:
  - Alex Barghi (https://github.com/alexbarghi-nv)

Approvers:
  - Rick Ratzel (https://github.com/rlratzel)

URL: https://github.com/rapidsai/cugraph/pull/4560
---
 .../test_gnn_feat_storage_wholegraph.py      | 42 ++++++++++---------
 1 file changed, 22 insertions(+), 20 deletions(-)

diff --git a/python/cugraph/cugraph/tests/data_store/test_gnn_feat_storage_wholegraph.py b/python/cugraph/cugraph/tests/data_store/test_gnn_feat_storage_wholegraph.py
index 0a272e445fa..30336490312 100644
--- a/python/cugraph/cugraph/tests/data_store/test_gnn_feat_storage_wholegraph.py
+++ b/python/cugraph/cugraph/tests/data_store/test_gnn_feat_storage_wholegraph.py
@@ -13,6 +13,7 @@
 
 import pytest
 import numpy as np
+import os
 
 from cugraph.gnn import FeatureStore
 
@@ -21,18 +22,23 @@
 pylibwholegraph = import_optional("pylibwholegraph")
 wmb = import_optional("pylibwholegraph.binding.wholememory_binding")
 torch = import_optional("torch")
+wgth = import_optional("pylibwholegraph.torch")
 
 
-def runtest(world_rank: int, world_size: int):
-    from pylibwholegraph.torch.initialize import init_torch_env_and_create_wm_comm
+def runtest(rank: int, world_size: int):
+    torch.cuda.set_device(rank)
 
-    wm_comm, _ = init_torch_env_and_create_wm_comm(
-        world_rank,
+    os.environ["MASTER_ADDR"] = "localhost"
+    os.environ["MASTER_PORT"] = "12355"
+    torch.distributed.init_process_group("nccl", rank=rank, world_size=world_size)
+
+    pylibwholegraph.torch.initialize.init(
+        rank,
         world_size,
-        world_rank,
+        rank,
         world_size,
     )
-    wm_comm = wm_comm.wmb_comm
+    wm_comm = wgth.get_global_communicator()
 
     generator = np.random.default_rng(62)
     arr = (
@@ -52,7 +58,7 @@ def runtest(world_rank: int, world_size: int):
     expected = arr[indices_to_fetch]
     np.testing.assert_array_equal(output_fs.cpu().numpy(), expected)
 
-    wmb.finalize()
+    pylibwholegraph.torch.initialize.finalize()
 
 
 @pytest.mark.sg
@@ -60,15 +66,14 @@ def runtest(world_rank: int, world_size: int):
 @pytest.mark.skipif(
     isinstance(pylibwholegraph, MissingModule), reason="wholegraph not available"
 )
-@pytest.mark.skip(reason="broken")
 def test_feature_storage_wholegraph_backend():
-    from pylibwholegraph.utils.multiprocess import multiprocess_run
+    world_size = torch.cuda.device_count()
+    print("gpu count:", world_size)
+    assert world_size > 0
 
-    gpu_count = wmb.fork_get_gpu_count()
-    print("gpu count:", gpu_count)
-    assert gpu_count > 0
+    print("ignoring gpu count and running on 1 GPU only")
 
-    multiprocess_run(1, runtest)
+    torch.multiprocessing.spawn(runtest, args=(1,), nprocs=1)
 
 
 @pytest.mark.mg
@@ -76,12 +81,9 @@ def test_feature_storage_wholegraph_backend():
 @pytest.mark.skipif(
     isinstance(pylibwholegraph, MissingModule), reason="wholegraph not available"
 )
-@pytest.mark.skip(reason="broken")
 def test_feature_storage_wholegraph_backend_mg():
-    from pylibwholegraph.utils.multiprocess import multiprocess_run
-
-    gpu_count = wmb.fork_get_gpu_count()
-    print("gpu count:", gpu_count)
-    assert gpu_count > 0
+    world_size = torch.cuda.device_count()
+    print("gpu count:", world_size)
+    assert world_size > 0
 
-    multiprocess_run(gpu_count, runtest)
+    torch.multiprocessing.spawn(runtest, args=(world_size,), nprocs=world_size)
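
Note (not part of the patch): a minimal standalone sketch of the per-rank
setup/teardown pattern the revised test follows. The run_on_rank name, the
MASTER_ADDR/MASTER_PORT values, and the spawn driver are illustrative
assumptions; the FeatureStore checks themselves are elided.

    import os

    import torch
    from pylibwholegraph.torch import get_global_communicator
    from pylibwholegraph.torch.initialize import init as wg_init, finalize as wg_finalize


    def run_on_rank(rank: int, world_size: int):
        # Pin this process to its GPU before any NCCL / WholeGraph calls.
        torch.cuda.set_device(rank)

        # Rendezvous settings for DDP; any free port on localhost works.
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = "12355"
        torch.distributed.init_process_group("nccl", rank=rank, world_size=world_size)

        # WholeGraph attaches to the already-initialized process group
        # (arguments: world_rank, world_size, local_rank, local_size).
        wg_init(rank, world_size, rank, world_size)
        comm = get_global_communicator()

        # ... exercise the WholeGraph-backed FeatureStore with `comm` here ...

        wg_finalize()


    if __name__ == "__main__":
        n_gpus = torch.cuda.device_count()
        # spawn() supplies the rank as the first positional argument of run_on_rank.
        torch.multiprocessing.spawn(run_on_rank, args=(n_gpus,), nprocs=n_gpus)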