From 1433d17d3d5eda22dd79dc06b7dac9467d2b74aa Mon Sep 17 00:00:00 2001 From: Alexandria Barghi Date: Thu, 5 Sep 2024 14:41:46 -0700 Subject: [PATCH] tests --- .../cugraph_pyg/loader/__init__.py | 3 + .../cugraph_pyg/loader/link_loader.py | 13 +- .../loader/link_neighbor_loader.py | 243 ++++++++++++++++++ .../cugraph_pyg/loader/node_loader.py | 2 +- .../cugraph_pyg/sampler/sampler.py | 1 + .../tests/loader/test_neighbor_loader.py | 40 +++ .../tests/loader/test_neighbor_loader_mg.py | 74 +++++- .../cugraph/gnn/data_loading/dist_sampler.py | 6 +- 8 files changed, 373 insertions(+), 9 deletions(-) create mode 100644 python/cugraph-pyg/cugraph_pyg/loader/link_neighbor_loader.py diff --git a/python/cugraph-pyg/cugraph_pyg/loader/__init__.py b/python/cugraph-pyg/cugraph_pyg/loader/__init__.py index cad66aaa183..c804b3d1f97 100644 --- a/python/cugraph-pyg/cugraph_pyg/loader/__init__.py +++ b/python/cugraph-pyg/cugraph_pyg/loader/__init__.py @@ -16,6 +16,9 @@ from cugraph_pyg.loader.node_loader import NodeLoader from cugraph_pyg.loader.neighbor_loader import NeighborLoader +from cugraph_pyg.loader.link_loader import LinkLoader +from cugraph_pyg.loader.link_neighbor_loader import LinkNeighborLoader + from cugraph_pyg.loader.dask_node_loader import DaskNeighborLoader from cugraph_pyg.loader.dask_node_loader import BulkSampleLoader diff --git a/python/cugraph-pyg/cugraph_pyg/loader/link_loader.py b/python/cugraph-pyg/cugraph_pyg/loader/link_loader.py index 21e92a817df..4a2832a4bfc 100644 --- a/python/cugraph-pyg/cugraph_pyg/loader/link_loader.py +++ b/python/cugraph-pyg/cugraph_pyg/loader/link_loader.py @@ -117,12 +117,12 @@ def __init__( edge_label_index, ) = torch_geometric.loader.utils.get_edge_label_index( data, - edge_label_index, + (None, edge_label_index), ) self.__input_data = torch_geometric.sampler.EdgeSamplerInput( input_id=torch.arange( - edge_label_index.shape[-1], dtype=torch.int64, device="cuda" + edge_label_index[0].numel(), dtype=torch.int64, device="cuda" ) if input_id is None else input_id, @@ -136,6 +136,7 @@ def __init__( self.__data = data self.__link_sampler = link_sampler + self.__neg_sampling = neg_sampling self.__batch_size = batch_size self.__shuffle = shuffle @@ -151,7 +152,7 @@ def __iter__(self): d = perm.numel() % self.__batch_size perm = perm[:-d] - input_data = torch_geometric.loader.node_loader.EdgeSamplerInput( + input_data = torch_geometric.sampler.EdgeSamplerInput( input_id=self.__input_data.input_id[perm], row=self.__input_data.row[perm], col=self.__input_data.col[perm], @@ -165,5 +166,9 @@ def __iter__(self): ) return cugraph_pyg.sampler.SampleIterator( - self.__data, self.__link_sampler.sample_from_edges(input_data) + self.__data, + self.__link_sampler.sample_from_edges( + input_data, + neg_sampling=self.__neg_sampling, + ), ) diff --git a/python/cugraph-pyg/cugraph_pyg/loader/link_neighbor_loader.py b/python/cugraph-pyg/cugraph_pyg/loader/link_neighbor_loader.py new file mode 100644 index 00000000000..080565368c4 --- /dev/null +++ b/python/cugraph-pyg/cugraph_pyg/loader/link_neighbor_loader.py @@ -0,0 +1,243 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings + +from typing import Union, Tuple, Optional, Callable, List, Dict + +import cugraph_pyg +from cugraph_pyg.loader import LinkLoader +from cugraph_pyg.sampler import BaseSampler + +from cugraph.gnn import NeighborSampler, DistSampleWriter +from cugraph.utilities.utils import import_optional + +torch_geometric = import_optional("torch_geometric") + + +class LinkNeighborLoader(LinkLoader): + """ + Duck-typed version of torch_geometric.loader.LinkNeighborLoader + + Link loader that implements the neighbor sampling + algorithm used in GraphSAGE. + """ + + def __init__( + self, + data: Union[ + "torch_geometric.data.Data", + "torch_geometric.data.HeteroData", + Tuple[ + "torch_geometric.data.FeatureStore", "torch_geometric.data.GraphStore" + ], + ], + num_neighbors: Union[ + List[int], Dict["torch_geometric.typing.EdgeType", List[int]] + ], + edge_label_index: "torch_geometric.typing.InputEdges" = None, + edge_label: "torch_geometric.typing.OptTensor" = None, + edge_label_time: "torch_geometric.typing.OptTensor" = None, + replace: bool = False, + subgraph_type: Union[ + "torch_geometric.typing.SubgraphType", str + ] = "directional", + disjoint: bool = False, + temporal_strategy: str = "uniform", + neg_sampling: Optional["torch_geometric.sampler.NegativeSampling"] = None, + neg_sampling_ratio: Optional[Union[int, float]] = None, + time_attr: Optional[str] = None, + weight_attr: Optional[str] = None, + transform: Optional[Callable] = None, + transform_sampler_output: Optional[Callable] = None, + is_sorted: bool = False, + filter_per_worker: Optional[bool] = None, + neighbor_sampler: Optional["torch_geometric.sampler.NeighborSampler"] = None, + directed: bool = True, # Deprecated. + batch_size: int = 16, # Refers to number of edges per batch. + directory: Optional[str] = None, + batches_per_partition=256, + format: str = "parquet", + compression: Optional[str] = None, + local_seeds_per_call: Optional[int] = None, + **kwargs, + ): + """ + data: Data, HeteroData, or Tuple[FeatureStore, GraphStore] + See torch_geometric.loader.LinkNeighborLoader. + num_neighbors: List[int] or Dict[EdgeType, List[int]] + Fanout values. + See torch_geometric.loader.LinkNeighborLoader. + edge_label_index: InputEdges + Input edges for sampling. + See torch_geometric.loader.LinkNeighborLoader. + edge_label: OptTensor + Labels for input edges. + See torch_geometric.loader.LinkNeighborLoader. + edge_label_time: OptTensor + Time attribute for input edges. + See torch_geometric.loader.LinkNeighborLoader. + replace: bool (optional, default=False) + Whether to sample with replacement. + See torch_geometric.loader.LinkNeighborLoader. + subgraph_type: Union[SubgraphType, str] (optional, default='directional') + The type of subgraph to return. + Currently only 'directional' is supported. + See torch_geometric.loader.LinkNeighborLoader. + disjoint: bool (optional, default=False) + Whether to perform disjoint sampling. + Currently unsupported. + See torch_geometric.loader.LinkNeighborLoader. + temporal_strategy: str (optional, default='uniform') + Currently only 'uniform' is suppported. + See torch_geometric.loader.LinkNeighborLoader. + time_attr: str (optional, default=None) + Used for temporal sampling. + See torch_geometric.loader.LinkNeighborLoader. + weight_attr: str (optional, default=None) + Used for biased sampling. + See torch_geometric.loader.LinkNeighborLoader. + transform: Callable (optional, default=None) + See torch_geometric.loader.LinkNeighborLoader. + transform_sampler_output: Callable (optional, default=None) + See torch_geometric.loader.LinkNeighborLoader. + is_sorted: bool (optional, default=False) + Ignored by cuGraph. + See torch_geometric.loader.LinkNeighborLoader. + filter_per_worker: bool (optional, default=False) + Currently ignored by cuGraph, but this may + change once in-memory sampling is implemented. + See torch_geometric.loader.LinkNeighborLoader. + neighbor_sampler: torch_geometric.sampler.NeighborSampler + (optional, default=None) + Not supported by cuGraph. + See torch_geometric.loader.LinkNeighborLoader. + directed: bool (optional, default=True) + Deprecated. + See torch_geometric.loader.LinkNeighborLoader. + batch_size: int (optional, default=16) + The number of input nodes per output minibatch. + See torch.utils.dataloader. + directory: str (optional, default=None) + The directory where samples will be temporarily stored, + if spilling samples to disk. If None, this loader + will perform buffered in-memory sampling. + If writing to disk, setting this argument + to a tempfile.TemporaryDirectory with a context + manager is a good option but depending on the filesystem, + you may want to choose an alternative location with fast I/O + intead. + See cugraph.gnn.DistSampleWriter. + batches_per_partition: int (optional, default=256) + The number of batches per partition if writing samples to + disk. Manually tuning this parameter is not recommended + but reducing it may help conserve GPU memory. + See cugraph.gnn.DistSampleWriter. + format: str (optional, default='parquet') + If writing samples to disk, they will be written in this + file format. + See cugraph.gnn.DistSampleWriter. + compression: str (optional, default=None) + The compression type to use if writing samples to disk. + If not provided, it is automatically chosen. + local_seeds_per_call: int (optional, default=None) + The number of seeds to process within a single sampling call. + Manually tuning this parameter is not recommended but reducing + it may conserve GPU memory. The total number of seeds processed + per sampling call is equal to the sum of this parameter across + all workers. If not provided, it will be automatically + calculated. + See cugraph.gnn.DistSampler. + **kwargs + Other keyword arguments passed to the superclass. + """ + + subgraph_type = torch_geometric.sampler.base.SubgraphType(subgraph_type) + + if not directed: + subgraph_type = torch_geometric.sampler.base.SubgraphType.induced + warnings.warn( + "The 'directed' argument is deprecated. " + "Use subgraph_type='induced' instead." + ) + if subgraph_type != torch_geometric.sampler.base.SubgraphType.directional: + raise ValueError("Only directional subgraphs are currently supported") + if disjoint: + raise ValueError("Disjoint sampling is currently unsupported") + if temporal_strategy != "uniform": + warnings.warn("Only the uniform temporal strategy is currently supported") + if neighbor_sampler is not None: + raise ValueError("Passing a neighbor sampler is currently unsupported") + if time_attr is not None: + raise ValueError("Temporal sampling is currently unsupported") + if is_sorted: + warnings.warn("The 'is_sorted' argument is ignored by cuGraph.") + if not isinstance(data, (list, tuple)) or not isinstance( + data[1], cugraph_pyg.data.GraphStore + ): + # Will eventually automatically convert these objects to cuGraph objects. + raise NotImplementedError("Currently can't accept non-cugraph graphs") + + if compression is None: + compression = "CSR" + elif compression not in ["CSR", "COO"]: + raise ValueError("Invalid value for compression (expected 'CSR' or 'COO')") + + writer = ( + None + if directory is None + else DistSampleWriter( + directory=directory, + batches_per_partition=batches_per_partition, + format=format, + ) + ) + + feature_store, graph_store = data + + if weight_attr is not None: + graph_store._set_weight_attr((feature_store, weight_attr)) + + sampler = BaseSampler( + NeighborSampler( + graph_store._graph, + writer, + retain_original_seeds=True, + fanout=num_neighbors, + prior_sources_behavior="exclude", + deduplicate_sources=True, + compression=compression, + compress_per_hop=False, + with_replacement=replace, + local_seeds_per_call=local_seeds_per_call, + biased=(weight_attr is not None), + ), + (feature_store, graph_store), + batch_size=batch_size, + ) + # TODO add heterogeneous support and pass graph_store._vertex_offsets + + super().__init__( + (feature_store, graph_store), + sampler, + edge_label_index=edge_label_index, + edge_label=edge_label, + edge_label_time=edge_label_time, + neg_sampling=neg_sampling, + neg_sampling_ratio=neg_sampling_ratio, + transform=transform, + transform_sampler_output=transform_sampler_output, + filter_per_worker=filter_per_worker, + batch_size=batch_size, + **kwargs, + ) diff --git a/python/cugraph-pyg/cugraph_pyg/loader/node_loader.py b/python/cugraph-pyg/cugraph_pyg/loader/node_loader.py index fe7d2eaeeef..4b236f75885 100644 --- a/python/cugraph-pyg/cugraph_pyg/loader/node_loader.py +++ b/python/cugraph-pyg/cugraph_pyg/loader/node_loader.py @@ -137,7 +137,7 @@ def __iter__(self): d = perm.numel() % self.__batch_size perm = perm[:-d] - input_data = torch_geometric.loader.node_loader.NodeSamplerInput( + input_data = torch_geometric.sampler.NodeSamplerInput( input_id=self.__input_data.input_id[perm], node=self.__input_data.node[perm], time=None diff --git a/python/cugraph-pyg/cugraph_pyg/sampler/sampler.py b/python/cugraph-pyg/cugraph_pyg/sampler/sampler.py index e9714bd0316..58077d9a77a 100644 --- a/python/cugraph-pyg/cugraph_pyg/sampler/sampler.py +++ b/python/cugraph-pyg/cugraph_pyg/sampler/sampler.py @@ -468,6 +468,7 @@ def sample_from_edges( reader = self.__sampler.sample_from_edges( torch.stack([index.row, index.col]), # reverse of usual convention input_id=index.input_id, + batch_size=self.__batch_size, **kwargs, ) diff --git a/python/cugraph-pyg/cugraph_pyg/tests/loader/test_neighbor_loader.py b/python/cugraph-pyg/cugraph_pyg/tests/loader/test_neighbor_loader.py index c4ad941de7a..ec7ca9a2318 100644 --- a/python/cugraph-pyg/cugraph_pyg/tests/loader/test_neighbor_loader.py +++ b/python/cugraph-pyg/cugraph_pyg/tests/loader/test_neighbor_loader.py @@ -16,6 +16,7 @@ from cugraph.datasets import karate from cugraph.utilities.utils import import_optional, MissingModule +import cugraph_pyg from cugraph_pyg.data import TensorDictFeatureStore, GraphStore from cugraph_pyg.loader import NeighborLoader @@ -86,3 +87,42 @@ def test_neighbor_loader_biased(): assert out.edge_index.shape[1] == 2 assert (out.edge_index.cpu() == torch.tensor([[3, 4], [1, 2]])).all() + + +@pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available") +@pytest.mark.sg +@pytest.mark.parametrize("num_nodes", [10, 25]) +@pytest.mark.parametrize("num_edges", [64, 128]) +@pytest.mark.parametrize("batch_size", [2, 4]) +@pytest.mark.parametrize("select_edges", [16, 32]) +@pytest.mark.parametrize("depth", [1, 3]) +@pytest.mark.parametrize("num_neighbors", [1, 4]) +def test_link_neighbor_loader_basic( + num_nodes, num_edges, batch_size, select_edges, num_neighbors, depth +): + graph_store = GraphStore() + feature_store = TensorDictFeatureStore() + + eix = torch.randperm(num_edges)[:select_edges] + graph_store[("n", "e", "n"), "coo"] = torch.stack( + [ + torch.randint(0, num_nodes, (num_edges,)), + torch.randint(0, num_nodes, (num_edges,)), + ] + ) + + elx = graph_store[("n", "e", "n"), "coo"][:, eix] + loader = cugraph_pyg.loader.LinkNeighborLoader( + (feature_store, graph_store), + num_neighbors=[num_neighbors] * depth, + edge_label_index=elx, + batch_size=batch_size, + shuffle=False, + ) + + elx = torch.tensor_split(elx, eix.numel() // batch_size, dim=1) + for i, batch in enumerate(loader): + assert ( + batch.input_id.cpu() == torch.arange(i * batch_size, (i + 1) * batch_size) + ).all() + assert (elx[i] == batch.n_id[batch.edge_label_index.cpu()]).all() diff --git a/python/cugraph-pyg/cugraph_pyg/tests/loader/test_neighbor_loader_mg.py b/python/cugraph-pyg/cugraph_pyg/tests/loader/test_neighbor_loader_mg.py index b8089bb901d..e995f9378cd 100644 --- a/python/cugraph-pyg/cugraph_pyg/tests/loader/test_neighbor_loader_mg.py +++ b/python/cugraph-pyg/cugraph_pyg/tests/loader/test_neighbor_loader_mg.py @@ -19,7 +19,7 @@ from cugraph.utilities.utils import import_optional, MissingModule from cugraph_pyg.data import TensorDictFeatureStore, GraphStore -from cugraph_pyg.loader import NeighborLoader +from cugraph_pyg.loader import NeighborLoader, LinkNeighborLoader from cugraph.gnn import ( cugraph_comms_init, @@ -179,3 +179,75 @@ def test_neighbor_loader_biased_mg(): ), nprocs=world_size, ) + + +def run_test_link_neighbor_loader_basic_mg( + rank, + uid, + world_size, + num_nodes: int, + num_edges: int, + select_edges: int, + batch_size: int, + num_neighbors: int, + depth: int, +): + init_pytorch_worker(rank, world_size, uid) + + graph_store = GraphStore(is_multi_gpu=True) + feature_store = TensorDictFeatureStore() + + eix = torch.randperm(num_edges)[:select_edges] + graph_store[("n", "e", "n"), "coo"] = torch.stack( + [ + torch.randint(0, num_nodes, (num_edges,)), + torch.randint(0, num_nodes, (num_edges,)), + ] + ) + + elx = graph_store[("n", "e", "n"), "coo"][:, eix] + loader = LinkNeighborLoader( + (feature_store, graph_store), + num_neighbors=[num_neighbors] * depth, + edge_label_index=elx, + batch_size=batch_size, + shuffle=False, + ) + + elx = torch.tensor_split(elx, eix.numel() // batch_size, dim=1) + for i, batch in enumerate(loader): + assert ( + batch.input_id.cpu() == torch.arange(i * batch_size, (i + 1) * batch_size) + ).all() + assert (elx[i] == batch.n_id[batch.edge_label_index.cpu()]).all() + + cugraph_comms_shutdown() + + +@pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available") +@pytest.mark.mg +@pytest.mark.parametrize("select_edges", [64, 128]) +@pytest.mark.parametrize("batch_size", [1, 2, 4]) +@pytest.mark.parametrize("depth", [1, 3]) +def test_link_neighbor_loader_basic_mg(select_edges, batch_size, depth): + num_nodes = 25 + num_edges = 128 + num_neighbors = 2 + + uid = cugraph_comms_create_unique_id() + world_size = torch.cuda.device_count() + + torch.multiprocessing.spawn( + run_test_link_neighbor_loader_basic_mg, + args=( + uid, + world_size, + num_nodes, + num_edges, + select_edges, + batch_size, + num_neighbors, + depth, + ), + nprocs=world_size, + ) diff --git a/python/cugraph/cugraph/gnn/data_loading/dist_sampler.py b/python/cugraph/cugraph/gnn/data_loading/dist_sampler.py index af2b1fc181b..e121591d318 100644 --- a/python/cugraph/cugraph/gnn/data_loading/dist_sampler.py +++ b/python/cugraph/cugraph/gnn/data_loading/dist_sampler.py @@ -446,8 +446,8 @@ def __sample_from_edges_func( ).cumsum(-1) current_seeds, leftover_seeds = ( - current_seeds[:, num_whole_batches], - current_seeds[:, num_whole_batches:], + current_seeds[:, : (batch_size * num_whole_batches)], + current_seeds[:, (batch_size * num_whole_batches) :], ) # For input edges, we need to translate this into unique vertices @@ -499,7 +499,7 @@ def __sample_from_edges_func( device="cuda", dtype=torch.int32, ) - for i, a in enumerate(u) + for i, (a, _) in enumerate(u) ] ) del u