Skip to content

Commit

Permalink
wrap up tests
Browse files Browse the repository at this point in the history
  • Loading branch information
alexbarghi-nv committed May 14, 2024
1 parent 887a7fe commit 4a33bde
Show file tree
Hide file tree
Showing 5 changed files with 225 additions and 7 deletions.
18 changes: 12 additions & 6 deletions python/cugraph-pyg/cugraph_pyg/data/graph_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,8 +147,8 @@ def _graph(self) -> Union[pylibcugraph.SGGraph, pylibcugraph.MGGraph]:
self.__graph = pylibcugraph.MGGraph(
self._resource_handle,
graph_properties,
[cupy.asarray(edgelist_dict["src"])],
[cupy.asarray(edgelist_dict["dst"])],
[cupy.asarray(edgelist_dict["src"]).astype("int64")],
[cupy.asarray(edgelist_dict["dst"]).astype("int64")],
vertices_array=[vertices_array],
edge_id_array=[cupy.asarray(edgelist_dict["eid"])],
edge_type_array=[cupy.asarray(edgelist_dict["etp"])],
Expand All @@ -157,8 +157,8 @@ def _graph(self) -> Union[pylibcugraph.SGGraph, pylibcugraph.MGGraph]:
self.__graph = pylibcugraph.SGGraph(
self._resource_handle,
graph_properties,
cupy.asarray(edgelist_dict["src"]),
cupy.asarray(edgelist_dict["dst"]),
cupy.asarray(edgelist_dict["src"]).astype("int64"),
cupy.asarray(edgelist_dict["dst"]).astype("int64"),
vertices_array=cupy.arange(
sum(self._num_vertices().values()), dtype="int64"
),
Expand All @@ -184,14 +184,20 @@ def _num_vertices(self) -> Dict[str, int]:
)
else:
if edge_attr.edge_type[0] not in num_vertices:
num_vertices[edge_attr.edge_type[0]] = (
num_vertices[edge_attr.edge_type[0]] = int(
self.__edge_indices[edge_attr.edge_type][0].max() + 1
)
if edge_attr.edge_type[2] not in num_vertices:
num_vertices[edge_attr.edge_type[1]] = (
num_vertices[edge_attr.edge_type[1]] = int(
self.__edge_indices[edge_attr.edge_type][1].max() + 1
)

if self.is_multi_gpu:
vtypes = num_vertices.keys()
for vtype in vtypes:
sz = torch.tensor(num_vertices[vtype], device="cuda")
torch.distributed.all_reduce(sz, op=torch.distributed.ReduceOp.MAX)
num_vertices[vtype] = int(sz)
return num_vertices

@property
Expand Down
45 changes: 45 additions & 0 deletions python/cugraph-pyg/cugraph_pyg/tests/data/test_graph_store_mg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# Copyright (c) 2024, NVIDIA CORPORATION.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest

from cugraph.datasets import karate
from cugraph.utilities.utils import import_optional, MissingModule

from cugraph_pyg.data import GraphStore

torch = import_optional("torch")


@pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available")
@pytest.mark.mg
def test_graph_store_basic_api_mg():
df = karate.get_edgelist()
src = torch.as_tensor(df["src"], device="cuda")
dst = torch.as_tensor(df["dst"], device="cuda")

ei = torch.stack([dst, src])

graph_store = GraphStore(is_multi_gpu=True)
graph_store.put_edge_index(ei, ("person", "knows", "person"), "coo")

rei = graph_store.get_edge_index(("person", "knows", "person"), "coo")

assert (ei == rei).all()

edge_attrs = graph_store.get_all_edge_attrs()
assert len(edge_attrs) == 1

graph_store.remove_edge_index(("person", "knows", "person"), "coo")
edge_attrs = graph_store.get_all_edge_attrs()
assert len(edge_attrs) == 0
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# Copyright (c) 2024, NVIDIA CORPORATION.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest

from cugraph.datasets import karate
from cugraph.utilities.utils import import_optional, MissingModule

from cugraph_pyg.data import TensorDictFeatureStore, GraphStore
from cugraph_pyg.loader import NeighborLoader

torch = import_optional("torch")
torch_geometric = import_optional("torch_geometric")


@pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available")
@pytest.mark.sg
def test_neighbor_loader():
"""
Basic e2e test that covers loading and sampling.
"""

df = karate.get_edgelist()
src = torch.as_tensor(df["src"], device="cuda")
dst = torch.as_tensor(df["dst"], device="cuda")

ei = torch.stack([dst, src])

graph_store = GraphStore()
graph_store.put_edge_index(ei, ("person", "knows", "person"), "coo")

feature_store = TensorDictFeatureStore()
feature_store["person", "feat"] = torch.randint(128, (34, 16))

loader = NeighborLoader(
(feature_store, graph_store),
[5, 5],
input_nodes=torch.arange(34),
directory=".",
)

for batch in loader:
assert isinstance(batch, torch_geometric.data.Data)
assert (feature_store["person", "feat"][batch.n_id] == batch.feat).all()
111 changes: 111 additions & 0 deletions python/cugraph-pyg/cugraph_pyg/tests/loader/test_neighbor_loader_mg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
# Copyright (c) 2024, NVIDIA CORPORATION.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest

import os

from cugraph.datasets import karate
from cugraph.utilities.utils import import_optional, MissingModule

from cugraph_pyg.data import TensorDictFeatureStore, GraphStore
from cugraph_pyg.loader import NeighborLoader

from cugraph.gnn import (
cugraph_comms_init,
cugraph_comms_shutdown,
cugraph_comms_create_unique_id,
)

torch = import_optional("torch")
torch_geometric = import_optional("torch_geometric")


def init_pytorch_worker(rank, world_size, cugraph_id):
import rmm

rmm.reinitialize(
devices=rank,
)

import cupy

cupy.cuda.Device(rank).use()
from rmm.allocators.cupy import rmm_cupy_allocator

cupy.cuda.set_allocator(rmm_cupy_allocator)

from cugraph.testing.mg_utils import enable_spilling

enable_spilling()

torch.cuda.set_device(rank)

os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "12355"
torch.distributed.init_process_group("nccl", rank=rank, world_size=world_size)

cugraph_comms_init(rank=rank, world_size=world_size, uid=cugraph_id, device=rank)


def run_test_neighbor_loader_mg(rank, uid, world_size, specify_size):
"""
Basic e2e test that covers loading and sampling.
"""
init_pytorch_worker(rank, world_size, uid)

df = karate.get_edgelist()
src = torch.as_tensor(df["src"], device="cuda")
dst = torch.as_tensor(df["dst"], device="cuda")

ei = torch.stack([dst, src])
ei = torch.tensor_split(ei.clone(), world_size, axis=1)[rank]

sz = (34, 34) if specify_size else None
graph_store = GraphStore(is_multi_gpu=True)
graph_store.put_edge_index(ei, ("person", "knows", "person"), "coo", False, sz)

feature_store = TensorDictFeatureStore()
feature_store["person", "feat"] = torch.randint(128, (34, 16))

ix_train = torch.tensor_split(torch.arange(34), world_size, axis=0)[rank]

loader = NeighborLoader(
(feature_store, graph_store),
[5, 5],
input_nodes=ix_train,
)

for batch in loader:
assert isinstance(batch, torch_geometric.data.Data)
assert (feature_store["person", "feat"][batch.n_id] == batch.feat).all()

cugraph_comms_shutdown()


@pytest.mark.parametrize("specify_size", [True, False])
@pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available")
@pytest.mark.mg
def test_neighbor_loader_mg(specify_size):
uid = cugraph_comms_create_unique_id()
world_size = torch.cuda.device_count()

torch.multiprocessing.spawn(
run_test_neighbor_loader_mg,
args=(
uid,
world_size,
specify_size,
),
nprocs=world_size,
)
4 changes: 3 additions & 1 deletion python/cugraph/cugraph/gnn/data_loading/dist_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,9 @@ def __init__(
ex = re.compile(r"batch\=([0-9]+)\.([0-9]+)\-([0-9]+)\.([0-9]+)\.parquet")
filematch = [ex.match(f) for f in files]
filematch = [f for f in filematch if f]
filematch = [f for f in filematch if int(f[1]) == rank]

if rank is not None:
filematch = [f for f in filematch if int(f[1]) == rank]

batch_count = sum([int(f[4]) - int(f[2]) + 1 for f in filematch])
filematch = sorted(filematch, key=lambda f: int(f[2]), reverse=True)
Expand Down

0 comments on commit 4a33bde

Please sign in to comment.