[BUG] Fix Failing WholeGraph Tests (#4560)
This PR properly uses PyTorch DDP to initialize a process group and test the WholeGraph feature store. Previously, the test relied on an API in WholeGraph that no longer appears to work.

Authors:
  - Alex Barghi (https://github.com/alexbarghi-nv)

Approvers:
  - Rick Ratzel (https://github.com/rlratzel)

URL: #4560
alexbarghi-nv authored Jul 31, 2024
1 parent aa0347c commit 5458e76
Showing 1 changed file with 22 additions and 20 deletions.
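For context, the pattern this change applies is roughly sketched below: each spawned worker pins its GPU, joins an NCCL process group through torch.distributed, then initializes WholeGraph on top of that group and retrieves the global communicator. This is a minimal sketch rather than the test file itself; the worker name, the localhost/12355 rendezvous values, and the placeholder comment about building a FeatureStore are illustrative assumptions, while the init, communicator, and spawn calls mirror the diff below.

import os

import torch
import pylibwholegraph.torch as wgth
import pylibwholegraph.torch.initialize


def worker(rank: int, world_size: int):
    # Pin this worker to its GPU and join the NCCL process group.
    torch.cuda.set_device(rank)
    os.environ["MASTER_ADDR"] = "localhost"  # placeholder rendezvous address
    os.environ["MASTER_PORT"] = "12355"      # placeholder port
    torch.distributed.init_process_group("nccl", rank=rank, world_size=world_size)

    # Initialize WholeGraph on top of the existing process group and
    # fetch the global communicator it creates.
    pylibwholegraph.torch.initialize.init(rank, world_size, rank, world_size)
    wm_comm = wgth.get_global_communicator()  # communicator backing the feature store

    # ... create a cugraph.gnn.FeatureStore backed by WholeGraph and
    # exercise it here (this is what the test body does) ...

    pylibwholegraph.torch.initialize.finalize()


if __name__ == "__main__":
    n_gpus = torch.cuda.device_count()
    torch.multiprocessing.spawn(worker, args=(n_gpus,), nprocs=n_gpus)

With the process group set up this way, the tests no longer depend on wmb.fork_get_gpu_count(), so the @pytest.mark.skip(reason="broken") markers can be dropped, as the diff below shows.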
@@ -13,6 +13,7 @@
 
 import pytest
 import numpy as np
+import os
 
 from cugraph.gnn import FeatureStore
 
@@ -21,18 +22,23 @@
 pylibwholegraph = import_optional("pylibwholegraph")
 wmb = import_optional("pylibwholegraph.binding.wholememory_binding")
 torch = import_optional("torch")
+wgth = import_optional("pylibwholegraph.torch")
 
 
-def runtest(world_rank: int, world_size: int):
-    from pylibwholegraph.torch.initialize import init_torch_env_and_create_wm_comm
+def runtest(rank: int, world_size: int):
+    torch.cuda.set_device(rank)
 
-    wm_comm, _ = init_torch_env_and_create_wm_comm(
-        world_rank,
+    os.environ["MASTER_ADDR"] = "localhost"
+    os.environ["MASTER_PORT"] = "12355"
+    torch.distributed.init_process_group("nccl", rank=rank, world_size=world_size)
+
+    pylibwholegraph.torch.initialize.init(
+        rank,
         world_size,
-        world_rank,
+        rank,
         world_size,
     )
-    wm_comm = wm_comm.wmb_comm
+    wm_comm = wgth.get_global_communicator()
 
     generator = np.random.default_rng(62)
     arr = (
@@ -52,36 +58,32 @@ def runtest(world_rank: int, world_size: int):
     expected = arr[indices_to_fetch]
     np.testing.assert_array_equal(output_fs.cpu().numpy(), expected)
 
-    wmb.finalize()
+    pylibwholegraph.torch.initialize.finalize()
 
 
 @pytest.mark.sg
 @pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available")
 @pytest.mark.skipif(
     isinstance(pylibwholegraph, MissingModule), reason="wholegraph not available"
 )
-@pytest.mark.skip(reason="broken")
 def test_feature_storage_wholegraph_backend():
-    from pylibwholegraph.utils.multiprocess import multiprocess_run
+    world_size = torch.cuda.device_count()
+    print("gpu count:", world_size)
+    assert world_size > 0
 
-    gpu_count = wmb.fork_get_gpu_count()
-    print("gpu count:", gpu_count)
-    assert gpu_count > 0
+    print("ignoring gpu count and running on 1 GPU only")
 
-    multiprocess_run(1, runtest)
+    torch.multiprocessing.spawn(runtest, args=(1,), nprocs=1)
 
 
 @pytest.mark.mg
 @pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available")
 @pytest.mark.skipif(
     isinstance(pylibwholegraph, MissingModule), reason="wholegraph not available"
 )
-@pytest.mark.skip(reason="broken")
 def test_feature_storage_wholegraph_backend_mg():
-    from pylibwholegraph.utils.multiprocess import multiprocess_run
-
-    gpu_count = wmb.fork_get_gpu_count()
-    print("gpu count:", gpu_count)
-    assert gpu_count > 0
+    world_size = torch.cuda.device_count()
+    print("gpu count:", world_size)
+    assert world_size > 0
 
-    multiprocess_run(gpu_count, runtest)
+    torch.multiprocessing.spawn(runtest, args=(world_size,), nprocs=world_size)
