From 36c190ac8378527ac1e7445eb3773068f9179dea Mon Sep 17 00:00:00 2001 From: Alex Barghi <105237337+alexbarghi-nv@users.noreply.github.com> Date: Thu, 26 Sep 2024 15:42:32 -0400 Subject: [PATCH] [FEA] Biased Sampling in cuGraph-DGL (#4595) Adds support for biased sampling to cuGraph-DGL. Resolves rapidsai/cugraph-gnn#25 Merge after #4583, #4586, #4607 Authors: - Alex Barghi (https://github.com/alexbarghi-nv) - Ralph Liu (https://github.com/nv-rliu) - Seunghwa Kang (https://github.com/seunghwak) Approvers: - Ray Douglass (https://github.com/raydouglass) - Tingyu Wang (https://github.com/tingyu66) - Rick Ratzel (https://github.com/rlratzel) URL: https://github.com/rapidsai/cugraph/pull/4595 --- .../dataloading/neighbor_sampler.py | 19 ++- python/cugraph-dgl/cugraph_dgl/graph.py | 148 ++++++++++-------- .../tests/dataloading/test_dataloader.py | 55 ++++++- .../tests/dataloading/test_dataloader_mg.py | 81 +++++++++- .../cugraph/gnn/data_loading/dist_sampler.py | 2 +- .../pylibcugraph/biased_neighbor_sample.pyx | 2 +- .../pylibcugraph/uniform_neighbor_sample.pyx | 2 +- 7 files changed, 224 insertions(+), 85 deletions(-) diff --git a/python/cugraph-dgl/cugraph_dgl/dataloading/neighbor_sampler.py b/python/cugraph-dgl/cugraph_dgl/dataloading/neighbor_sampler.py index 87d111adcba..4ec513cbf9b 100644 --- a/python/cugraph-dgl/cugraph_dgl/dataloading/neighbor_sampler.py +++ b/python/cugraph-dgl/cugraph_dgl/dataloading/neighbor_sampler.py @@ -18,7 +18,7 @@ from typing import Sequence, Optional, Union, List, Tuple, Iterator -from cugraph.gnn import UniformNeighborSampler, DistSampleWriter +from cugraph.gnn import UniformNeighborSampler, BiasedNeighborSampler, DistSampleWriter from cugraph.utilities.utils import import_optional import cugraph_dgl @@ -93,7 +93,6 @@ def __init__( If provided, the probability of each neighbor being sampled is proportional to the edge feature with the given name. Mutually exclusive with mask. - Currently unsupported. mask: str Optional. If proivided, only neighbors where the edge mask @@ -133,10 +132,6 @@ def __init__( raise NotImplementedError( "Edge masking is currently unsupported by cuGraph-DGL" ) - if prob: - raise NotImplementedError( - "Edge masking is currently unsupported by cuGraph-DGL" - ) if prefetch_edge_feats: warnings.warn("'prefetch_edge_feats' is ignored by cuGraph-DGL") if prefetch_node_feats: @@ -146,6 +141,8 @@ def __init__( if fused: warnings.warn("'fused' is ignored by cuGraph-DGL") + self.__prob_attr = prob + self.fanouts = fanouts_per_layer reverse_fanouts = fanouts_per_layer.copy() reverse_fanouts.reverse() @@ -180,8 +177,14 @@ def sample( format=kwargs.pop("format", "parquet"), ) - ds = UniformNeighborSampler( - g._graph(self.edge_dir), + sampling_clx = ( + UniformNeighborSampler + if self.__prob_attr is None + else BiasedNeighborSampler + ) + + ds = sampling_clx( + g._graph(self.edge_dir, prob_attr=self.__prob_attr), writer, compression="CSR", fanout=self._reversed_fanout_vals, diff --git a/python/cugraph-dgl/cugraph_dgl/graph.py b/python/cugraph-dgl/cugraph_dgl/graph.py index 011ab736d00..138e645838a 100644 --- a/python/cugraph-dgl/cugraph_dgl/graph.py +++ b/python/cugraph-dgl/cugraph_dgl/graph.py @@ -312,7 +312,7 @@ def add_edges( self.__graph = None self.__vertex_offsets = None - def num_nodes(self, ntype: str = None) -> int: + def num_nodes(self, ntype: Optional[str] = None) -> int: """ Returns the number of nodes of ntype, or if ntype is not provided, the total number of nodes in the graph. @@ -322,7 +322,7 @@ def num_nodes(self, ntype: str = None) -> int: return self.__num_nodes_dict[ntype] - def number_of_nodes(self, ntype: str = None) -> int: + def number_of_nodes(self, ntype: Optional[str] = None) -> int: """ Alias for num_nodes. """ @@ -381,7 +381,7 @@ def _vertex_offsets(self) -> Dict[str, int]: return dict(self.__vertex_offsets) - def __get_edgelist(self) -> Dict[str, "torch.Tensor"]: + def __get_edgelist(self, prob_attr=None) -> Dict[str, "torch.Tensor"]: """ This function always returns src/dst labels with respect to the out direction. @@ -431,63 +431,71 @@ def __get_edgelist(self) -> Dict[str, "torch.Tensor"]: ) ) + num_edges_t = torch.tensor( + [self.__edge_indices[et].shape[1] for et in sorted_keys], device="cuda" + ) + if self.is_multi_gpu: rank = torch.distributed.get_rank() world_size = torch.distributed.get_world_size() - num_edges_t = torch.tensor( - [self.__edge_indices[et].shape[1] for et in sorted_keys], device="cuda" - ) num_edges_all_t = torch.empty( world_size, num_edges_t.numel(), dtype=torch.int64, device="cuda" ) torch.distributed.all_gather_into_tensor(num_edges_all_t, num_edges_t) - if rank > 0: - start_offsets = num_edges_all_t[:rank].T.sum(axis=1) - edge_id_array = torch.concat( + start_offsets = num_edges_all_t[:rank].T.sum(axis=1) + + else: + rank = 0 + start_offsets = torch.zeros( + (len(sorted_keys),), dtype=torch.int64, device="cuda" + ) + num_edges_all_t = num_edges_t.reshape((1, num_edges_t.numel())) + + # Use pinned memory here for fast access to CPU/WG storage + edge_id_array_per_type = [ + torch.arange( + start_offsets[i], + start_offsets[i] + num_edges_all_t[rank][i], + dtype=torch.int64, + device="cpu", + ).pin_memory() + for i in range(len(sorted_keys)) + ] + + # Retrieve the weights from the appropriate feature(s) + # DGL implicitly requires all edge types use the same + # feature name. + if prob_attr is None: + weights = None + else: + if len(sorted_keys) > 1: + weights = torch.concat( [ - torch.arange( - start_offsets[i], - start_offsets[i] + num_edges_all_t[rank][i], - dtype=torch.int64, - device="cuda", - ) - for i in range(len(sorted_keys)) + self.edata[prob_attr][sorted_keys[i]][ix] + for i, ix in enumerate(edge_id_array_per_type) ] ) else: - edge_id_array = torch.concat( - [ - torch.arange( - self.__edge_indices[et].shape[1], - dtype=torch.int64, - device="cuda", - ) - for et in sorted_keys - ] - ) + weights = self.edata[prob_attr][edge_id_array_per_type[0]] - else: - # single GPU - edge_id_array = torch.concat( - [ - torch.arange( - self.__edge_indices[et].shape[1], - dtype=torch.int64, - device="cuda", - ) - for et in sorted_keys - ] - ) + # Safe to move this to cuda because the consumer will always + # move it to cuda if it isn't already there. + edge_id_array = torch.concat(edge_id_array_per_type).cuda() - return { + edgelist_dict = { "src": edge_index[0], "dst": edge_index[1], "etp": edge_type_array, "eid": edge_id_array, } + if weights is not None: + edgelist_dict["wgt"] = weights + + return edgelist_dict + @property def is_homogeneous(self): return len(self.__num_edges_dict) <= 1 and len(self.__num_nodes_dict) <= 1 @@ -508,7 +516,9 @@ def _resource_handle(self): return self.__handle def _graph( - self, direction: str + self, + direction: str, + prob_attr: Optional[str] = None, ) -> Union[pylibcugraph.SGGraph, pylibcugraph.MGGraph]: """ Gets the pylibcugraph Graph object with edges pointing in the given direction @@ -522,12 +532,16 @@ def _graph( is_multigraph=True, is_symmetric=False ) - if self.__graph is not None and self.__graph[1] != direction: - self.__graph = None + if self.__graph is not None: + if ( + self.__graph["direction"] != direction + or self.__graph["prob_attr"] != prob_attr + ): + self.__graph = None if self.__graph is None: src_col, dst_col = ("src", "dst") if direction == "out" else ("dst", "src") - edgelist_dict = self.__get_edgelist() + edgelist_dict = self.__get_edgelist(prob_attr=prob_attr) if self.is_multi_gpu: rank = torch.distributed.get_rank() @@ -536,33 +550,35 @@ def _graph( vertices_array = cupy.arange(self.num_nodes(), dtype="int64") vertices_array = cupy.array_split(vertices_array, world_size)[rank] - self.__graph = ( - pylibcugraph.MGGraph( - self._resource_handle, - graph_properties, - [cupy.asarray(edgelist_dict[src_col]).astype("int64")], - [cupy.asarray(edgelist_dict[dst_col]).astype("int64")], - vertices_array=[vertices_array], - edge_id_array=[cupy.asarray(edgelist_dict["eid"])], - edge_type_array=[cupy.asarray(edgelist_dict["etp"])], - ), - direction, + graph = pylibcugraph.MGGraph( + self._resource_handle, + graph_properties, + [cupy.asarray(edgelist_dict[src_col]).astype("int64")], + [cupy.asarray(edgelist_dict[dst_col]).astype("int64")], + vertices_array=[vertices_array], + edge_id_array=[cupy.asarray(edgelist_dict["eid"])], + edge_type_array=[cupy.asarray(edgelist_dict["etp"])], + weight_array=[cupy.asarray(edgelist_dict["wgt"])] + if "wgt" in edgelist_dict + else None, ) else: - self.__graph = ( - pylibcugraph.SGGraph( - self._resource_handle, - graph_properties, - cupy.asarray(edgelist_dict[src_col]).astype("int64"), - cupy.asarray(edgelist_dict[dst_col]).astype("int64"), - vertices_array=cupy.arange(self.num_nodes(), dtype="int64"), - edge_id_array=cupy.asarray(edgelist_dict["eid"]), - edge_type_array=cupy.asarray(edgelist_dict["etp"]), - ), - direction, + graph = pylibcugraph.SGGraph( + self._resource_handle, + graph_properties, + cupy.asarray(edgelist_dict[src_col]).astype("int64"), + cupy.asarray(edgelist_dict[dst_col]).astype("int64"), + vertices_array=cupy.arange(self.num_nodes(), dtype="int64"), + edge_id_array=cupy.asarray(edgelist_dict["eid"]), + edge_type_array=cupy.asarray(edgelist_dict["etp"]), + weight_array=cupy.asarray(edgelist_dict["wgt"]) + if "wgt" in edgelist_dict + else None, ) - return self.__graph[0] + self.__graph = {"graph": graph, "direction": direction, "prob_attr": prob_attr} + + return self.__graph["graph"] def _has_n_emb(self, ntype: str, emb_name: str) -> bool: return (ntype, emb_name) in self.__ndata_storage diff --git a/python/cugraph-dgl/cugraph_dgl/tests/dataloading/test_dataloader.py b/python/cugraph-dgl/cugraph_dgl/tests/dataloading/test_dataloader.py index ef47875463d..419ec7790a9 100644 --- a/python/cugraph-dgl/cugraph_dgl/tests/dataloading/test_dataloader.py +++ b/python/cugraph-dgl/cugraph_dgl/tests/dataloading/test_dataloader.py @@ -11,6 +11,7 @@ # See the License for the specific language governing permissions and # limitations under the License. + import cugraph_dgl.dataloading import pytest @@ -48,9 +49,12 @@ def test_dataloader_basic_homogeneous(): assert len(out_t) <= 2 -def sample_dgl_graphs(g, train_nid, fanouts, batch_size=1): +def sample_dgl_graphs(g, train_nid, fanouts, batch_size=1, prob_attr=None): # Single fanout to match cugraph - sampler = dgl.dataloading.NeighborSampler(fanouts) + sampler = dgl.dataloading.NeighborSampler( + fanouts, + prob=prob_attr, + ) dataloader = dgl.dataloading.DataLoader( g, train_nid, @@ -71,8 +75,13 @@ def sample_dgl_graphs(g, train_nid, fanouts, batch_size=1): return dgl_output -def sample_cugraph_dgl_graphs(cugraph_g, train_nid, fanouts, batch_size=1): - sampler = cugraph_dgl.dataloading.NeighborSampler(fanouts) +def sample_cugraph_dgl_graphs( + cugraph_g, train_nid, fanouts, batch_size=1, prob_attr=None +): + sampler = cugraph_dgl.dataloading.NeighborSampler( + fanouts, + prob=prob_attr, + ) dataloader = cugraph_dgl.dataloading.FutureDataLoader( cugraph_g, @@ -126,3 +135,41 @@ def test_same_homogeneousgraph_results(ix, batch_size): dgl_output[0]["blocks"][0].num_edges() == cugraph_output[0]["blocks"][0].num_edges() ) + + +@pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available") +@pytest.mark.skipif(isinstance(dgl, MissingModule), reason="dgl not available") +def test_dataloader_biased_homogeneous(): + src = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8]) + dst = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1]) + wgt = torch.tensor([1, 1, 2, 0, 0, 0, 2, 1], dtype=torch.float32) + + train_nid = torch.tensor([0, 1]) + # Create a heterograph with 3 node types and 3 edges types. + dgl_g = dgl.graph((src, dst)) + dgl_g.edata["wgt"] = wgt + + cugraph_g = cugraph_dgl.Graph(is_multi_gpu=False) + cugraph_g.add_nodes(9) + cugraph_g.add_edges(u=src, v=dst, data={"wgt": wgt}) + + dgl_output = sample_dgl_graphs(dgl_g, train_nid, [4], batch_size=2, prob_attr="wgt") + cugraph_output = sample_cugraph_dgl_graphs( + cugraph_g, train_nid, [4], batch_size=2, prob_attr="wgt" + ) + + cugraph_output_nodes = cugraph_output[0]["output_nodes"].cpu().numpy() + dgl_output_nodes = dgl_output[0]["output_nodes"].cpu().numpy() + + np.testing.assert_array_equal( + np.sort(cugraph_output_nodes), np.sort(dgl_output_nodes) + ) + assert ( + dgl_output[0]["blocks"][0].num_dst_nodes() + == cugraph_output[0]["blocks"][0].num_dst_nodes() + ) + assert ( + dgl_output[0]["blocks"][0].num_edges() + == cugraph_output[0]["blocks"][0].num_edges() + ) + assert 5 == cugraph_output[0]["blocks"][0].num_edges() diff --git a/python/cugraph-dgl/cugraph_dgl/tests/dataloading/test_dataloader_mg.py b/python/cugraph-dgl/cugraph_dgl/tests/dataloading/test_dataloader_mg.py index b32233f16a6..061f4fa2077 100644 --- a/python/cugraph-dgl/cugraph_dgl/tests/dataloading/test_dataloader_mg.py +++ b/python/cugraph-dgl/cugraph_dgl/tests/dataloading/test_dataloader_mg.py @@ -82,9 +82,18 @@ def test_dataloader_basic_homogeneous(): ) -def sample_dgl_graphs(g, train_nid, fanouts, batch_size=1): +def sample_dgl_graphs( + g, + train_nid, + fanouts, + batch_size=1, + prob_attr=None, +): # Single fanout to match cugraph - sampler = dgl.dataloading.NeighborSampler(fanouts) + sampler = dgl.dataloading.NeighborSampler( + fanouts, + prob=prob_attr, + ) dataloader = dgl.dataloading.DataLoader( g, train_nid, @@ -105,8 +114,17 @@ def sample_dgl_graphs(g, train_nid, fanouts, batch_size=1): return dgl_output -def sample_cugraph_dgl_graphs(cugraph_g, train_nid, fanouts, batch_size=1): - sampler = cugraph_dgl.dataloading.NeighborSampler(fanouts) +def sample_cugraph_dgl_graphs( + cugraph_g, + train_nid, + fanouts, + batch_size=1, + prob_attr=None, +): + sampler = cugraph_dgl.dataloading.NeighborSampler( + fanouts, + prob=prob_attr, + ) dataloader = cugraph_dgl.dataloading.FutureDataLoader( cugraph_g, @@ -179,3 +197,58 @@ def test_same_homogeneousgraph_results_mg(ix, batch_size): args=(world_size, uid, ix, batch_size), nprocs=world_size, ) + + +def run_test_dataloader_biased_homogeneous(rank, world_size, uid): + init_pytorch_worker(rank, world_size, uid, True) + + src = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8]) + (rank * 9) + dst = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1]) + (rank * 9) + wgt = torch.tensor( + [0.1, 0.1, 0.2, 0, 0, 0, 0.2, 0.1] * world_size, dtype=torch.float32 + ) + + train_nid = torch.tensor([0, 1]) + (rank * 9) + # Create a heterograph with 3 node types and 3 edge types. + dgl_g = dgl.graph((src, dst)) + dgl_g.edata["wgt"] = wgt[:8] + + cugraph_g = cugraph_dgl.Graph(is_multi_gpu=True) + cugraph_g.add_nodes(9 * world_size) + cugraph_g.add_edges(u=src, v=dst, data={"wgt": wgt}) + + dgl_output = sample_dgl_graphs(dgl_g, train_nid, [4], batch_size=2, prob_attr="wgt") + cugraph_output = sample_cugraph_dgl_graphs( + cugraph_g, train_nid, [4], batch_size=2, prob_attr="wgt" + ) + + cugraph_output_nodes = cugraph_output[0]["output_nodes"].cpu().numpy() + dgl_output_nodes = dgl_output[0]["output_nodes"].cpu().numpy() + + np.testing.assert_array_equal( + np.sort(cugraph_output_nodes), np.sort(dgl_output_nodes) + ) + assert ( + dgl_output[0]["blocks"][0].num_dst_nodes() + == cugraph_output[0]["blocks"][0].num_dst_nodes() + ) + assert ( + dgl_output[0]["blocks"][0].num_edges() + == cugraph_output[0]["blocks"][0].num_edges() + ) + + assert 5 == cugraph_output[0]["blocks"][0].num_edges() + + +@pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available") +@pytest.mark.skipif(isinstance(dgl, MissingModule), reason="dgl not available") +def test_dataloader_biased_homogeneous_mg(): + uid = cugraph_comms_create_unique_id() + # Limit the number of GPUs this test is run with + world_size = torch.cuda.device_count() + + torch.multiprocessing.spawn( + run_test_dataloader_biased_homogeneous, + args=(world_size, uid), + nprocs=world_size, + ) diff --git a/python/cugraph/cugraph/gnn/data_loading/dist_sampler.py b/python/cugraph/cugraph/gnn/data_loading/dist_sampler.py index a49139961fd..52ffd8fadfd 100644 --- a/python/cugraph/cugraph/gnn/data_loading/dist_sampler.py +++ b/python/cugraph/cugraph/gnn/data_loading/dist_sampler.py @@ -776,7 +776,7 @@ def sample_batches( label_to_output_comm_rank=cupy.asarray(label_to_output_comm_rank), h_fan_out=np.array(self.__fanout, dtype="int32"), with_replacement=self.__with_replacement, - do_expensive_check=True, + do_expensive_check=False, with_edge_properties=True, random_state=random_state + rank, prior_sources_behavior=self.__prior_sources_behavior, diff --git a/python/pylibcugraph/pylibcugraph/biased_neighbor_sample.pyx b/python/pylibcugraph/pylibcugraph/biased_neighbor_sample.pyx index 77f1f04c394..2dd138d5d06 100644 --- a/python/pylibcugraph/pylibcugraph/biased_neighbor_sample.pyx +++ b/python/pylibcugraph/pylibcugraph/biased_neighbor_sample.pyx @@ -118,7 +118,7 @@ def biased_neighbor_sample(ResourceHandle resource_handle, Device array containing the list of starting vertices for sampling. h_fan_out: numpy array type - Host array containing the brancing out (fan-out) degrees per + Host array containing the branching out (fan-out) degrees per starting vertex for each hop level. with_replacement: bool diff --git a/python/pylibcugraph/pylibcugraph/uniform_neighbor_sample.pyx b/python/pylibcugraph/pylibcugraph/uniform_neighbor_sample.pyx index c25c9119985..f3e2336d8f6 100644 --- a/python/pylibcugraph/pylibcugraph/uniform_neighbor_sample.pyx +++ b/python/pylibcugraph/pylibcugraph/uniform_neighbor_sample.pyx @@ -117,7 +117,7 @@ def uniform_neighbor_sample(ResourceHandle resource_handle, Device array containing the list of starting vertices for sampling. h_fan_out: numpy array type - Host array containing the brancing out (fan-out) degrees per + Host array containing the branching out (fan-out) degrees per starting vertex for each hop level. with_replacement: bool