Skip to content

Commit

Permalink
mg testing
Browse files: browse the repository at this point in the history
  • Loading branch information
alexbarghi-nv committed Aug 26, 2024
1 parent 55d8ad1 commit 5e1b986
Showing 1 changed file with 80 additions and 0 deletions.
80 changes: 80 additions & 0 deletions python/cugraph/cugraph/tests/sampling/test_dist_sampler_mg.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
import cupy
import cudf

from typing import Any

from cugraph.datasets import karate
from cugraph.gnn import (
UniformNeighborSampler,
Expand All @@ -27,6 +29,7 @@
cugraph_comms_init,
cugraph_comms_shutdown,
)
from cugraph.gnn.data_loading.bulk_sampler_io import create_df_from_disjoint_arrays
from pylibcugraph import MGGraph, ResourceHandle, GraphProperties

from cugraph.utilities.utils import (
Expand Down Expand Up @@ -235,3 +238,80 @@ def test_dist_sampler_uneven(scratch_dir, batch_size, fanout, seeds_per_call):
assert original_el.dst.iloc[edge_id.iloc[i]] == dst.iloc[i]

shutil.rmtree(samples_path)


def run_test_dist_sampler_buffered_in_memory(
    rank: int,
    world_size: int,
    uid: Any,
    samples_path: str,
    seeds_per_call: int,
    compression: str,
):
    """Per-rank worker comparing writer-backed and writer-less sampling.

    Builds two ``UniformNeighborSampler`` instances over the same graph that
    differ only in their ``writer`` argument (a ``DistSampleWriter`` pointed
    at ``samples_path`` versus ``None``), samples the same seeds with each,
    and asserts that the two result lists agree entry by entry.

    Parameters
    ----------
    rank
        Rank of this worker; also passed as ``device`` to
        ``cugraph_comms_init``.
    world_size
        Total number of workers taking part.
    uid
        Unique id forwarded to ``cugraph_comms_init``.
    samples_path
        Directory handed to ``DistSampleWriter``.
    seeds_per_call
        Forwarded as ``local_seeds_per_call`` to both samplers.
    compression
        Forwarded as ``compression`` to both samplers.
    """
    init_pytorch(rank, world_size)
    cugraph_comms_init(rank, world_size, uid, device=rank)

    try:
        G = karate_mg_graph(rank, world_size)

        num_seeds = 8
        # Seed ids are drawn from [0, 34) -- presumably the vertex id range
        # of the karate graph; confirm against the dataset.
        seeds = cupy.random.randint(0, 34, num_seeds, dtype="int64")

        unbuffered_sampler = UniformNeighborSampler(
            G,
            writer=DistSampleWriter(samples_path),
            local_seeds_per_call=seeds_per_call,
            compression=compression,
        )

        buffered_sampler = UniformNeighborSampler(
            G,
            writer=None,
            local_seeds_per_call=seeds_per_call,
            compression=compression,
        )

        unbuffered_results = unbuffered_sampler.sample_from_nodes(
            seeds,
            batch_size=4,
        )

        # Convert the first element of every result into a DataFrame so the
        # two runs can be compared column by column below.
        unbuffered_results = [
            (create_df_from_disjoint_arrays(r[0]), r[1], r[2])
            for r in unbuffered_results
        ]

        buffered_results = buffered_sampler.sample_from_nodes(seeds, batch_size=4)
        buffered_results = [
            (create_df_from_disjoint_arrays(r[0]), r[1], r[2])
            for r in buffered_results
        ]

        assert len(buffered_results) == len(unbuffered_results)

        # Lengths were just asserted equal, so zip() visits every pair.
        for (br, bs, be), (ur, us, ue) in zip(buffered_results, unbuffered_results):
            assert bs == us
            assert be == ue

            # Nulls are dropped on both sides before the element-wise compare.
            for col in ur.columns:
                assert (br[col].dropna() == ur[col].dropna()).all()
    finally:
        # Pair the cugraph_comms_init() above with a shutdown, including when
        # an assertion or sampling call fails.
        cugraph_comms_shutdown()


@pytest.mark.mg
@pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available")
@pytest.mark.parametrize("seeds_per_call", [4, 5, 10])
@pytest.mark.parametrize("compression", ["COO", "CSR"])
def test_dist_sampler_buffered_in_memory(scratch_dir, seeds_per_call, compression):
    """Run the buffered-vs-unbuffered sampler check on every CUDA device.

    Creates a fresh samples directory under ``scratch_dir``, spawns
    ``torch.cuda.device_count()`` worker processes running
    ``run_test_dist_sampler_buffered_in_memory``, and removes the directory
    afterwards.

    Parameters
    ----------
    scratch_dir
        Base directory under which the samples directory is created.
    seeds_per_call
        Parametrized value forwarded to each worker.
    compression
        Parametrized value (``"COO"`` or ``"CSR"``) forwarded to each worker.
    """
    uid = cugraph_comms_create_unique_id()

    samples_path = os.path.join(scratch_dir, "test_bulk_sampler_buffered_in_memory_mg")
    create_directory_with_overwrite(samples_path)

    try:
        world_size = torch.cuda.device_count()
        # spawn() supplies the process index as the first positional argument,
        # so ``args`` starts at ``world_size``.
        torch.multiprocessing.spawn(
            run_test_dist_sampler_buffered_in_memory,
            args=(world_size, uid, samples_path, seeds_per_call, compression),
            nprocs=world_size,
        )
    finally:
        # Remove the samples directory even if a worker failed.
        shutil.rmtree(samples_path)

0 comments on commit 5e1b986

Please sign in to comment.