From 4167ca6527c14132ffe332593df2b79b34c95037 Mon Sep 17 00:00:00 2001
From: Alexandria Barghi
Date: Wed, 2 Oct 2024 11:41:09 -0700
Subject: [PATCH 1/6] improve tests, fix de-offset issue, finish mg tests

---
 python/cugraph-dgl/cugraph_dgl/graph.py       |  63 ++++++-
 .../cugraph_dgl/tests/test_graph.py           |  99 +++++++++-
 .../cugraph_dgl/tests/test_graph_mg.py        | 170 ++++++++++++++++++
 python/cugraph-dgl/cugraph_dgl/view.py        |  43 +++++
 4 files changed, 363 insertions(+), 12 deletions(-)

diff --git a/python/cugraph-dgl/cugraph_dgl/graph.py b/python/cugraph-dgl/cugraph_dgl/graph.py
index 138e645..8978fb8 100644
--- a/python/cugraph-dgl/cugraph_dgl/graph.py
+++ b/python/cugraph-dgl/cugraph_dgl/graph.py
@@ -106,6 +106,7 @@ def __init__(
 
         self.__graph = None
         self.__vertex_offsets = None
+        self.__edge_lookup_table = None
         self.__handle = None
         self.__is_multi_gpu = is_multi_gpu
 
@@ -127,6 +128,11 @@ def __init__(
     def is_multi_gpu(self):
         return self.__is_multi_gpu
 
+    def _clear_graph(self):
+        self.__graph = None
+        self.__edge_lookup_table = None
+        self.__vertex_offsets = None
+
     def to_canonical_etype(
         self, etype: Union[str, Tuple[str, str, str]]
     ) -> Tuple[str, str, str]:
@@ -146,6 +152,20 @@ def to_canonical_etype(
 
         raise ValueError("Unknown relation type " + etype)
 
+    def _to_numeric_etype(self, etype: Union[str, Tuple[str, str, str]]) -> int:
+        if etype is None:
+            if len(self.canonical_etypes) > 1:
+                raise ValueError("Edge type is required for heterogeneous graphs.")
+            return 0
+
+        etype = self.to_canonical_etype(etype)
+        return {
+            k: i
+            for i, k in enumerate(
+                sorted(self.__edge_indices.keys(leaves_only=True, include_nested=True))
+            )
+        }[etype]
+
     def add_nodes(
         self,
         global_num_nodes: int,
@@ -217,8 +237,7 @@ def add_nodes(
                 _cast_to_torch_tensor(feature_tensor), **self.__wg_kwargs
             )
 
-        self.__graph = None
-        self.__vertex_offsets = None
+        self._clear_graph()
 
     def __check_node_ids(self, ntype: str, ids: TensorType):
         """
@@ -309,8 +328,7 @@ def add_edges(
 
         self.__num_edges_dict[dgl_can_edge_type] = int(num_edges)
 
-        self.__graph = None
-        self.__vertex_offsets = None
+        self._clear_graph()
 
     def num_nodes(self, ntype: Optional[str] = None) -> int:
         """
@@ -537,7 +555,7 @@ def _graph(
             self.__graph["direction"] != direction
             or self.__graph["prob_attr"] != prob_attr
         ):
-            self.__graph = None
+            self._clear_graph()
 
         if self.__graph is None:
             src_col, dst_col = ("src", "dst") if direction == "out" else ("dst", "src")
@@ -620,9 +638,6 @@ def _get_n_emb(
             )
 
         try:
-            print(
-                u,
-            )
             return self.__ndata_storage[ntype, emb_name].fetch(
                 _cast_to_torch_tensor(u), "cuda"
            )
@@ -895,6 +910,38 @@ def all_edges(
         else:
             raise ValueError(f"Invalid form {form}")
 
+    @property
+    def _edge_lookup_table(self):
+        if self.__edge_lookup_table is None:
+            self.__edge_lookup_table = pylibcugraph.EdgeIdLookupTable(
+                self._resource_handle,
+                self._graph("out") if self.__graph is None else self.__graph["graph"],
+            )
+
+        return self.__edge_lookup_table
+
+    def find_edges(
+        self, eid: "torch.Tensor", etype: Union[str, Tuple[str, str, str]] = None
+    ) -> Tuple["torch.Tensor", "torch.Tensor"]:
+        """
+        Looks up and returns the appropriate src/dst pairs given a sequence of edge
+        ids and an edge type.
+ """ + + # Have to properly de-offset the vertices based on edge type + etype = self.to_canonical_etype(etype) + num_edge_type = self._to_numeric_etype(etype) + out = self._edge_lookup_table.find(cupy.asarray(eid), num_edge_type) + + src_name = "sources" if self.__graph["direction"] == "out" else "destinations" + dst_name = "destinations" if self.__graph["direction"] == "out" else "sources" + offsets = self._vertex_offsets + + return ( + torch.as_tensor(out[src_name], device="cuda") - offsets[etype[0]], + torch.as_tensor(out[dst_name], device="cuda") - offsets[etype[2]], + ) + @property def ndata(self) -> HeteroNodeDataView: """ diff --git a/python/cugraph-dgl/cugraph_dgl/tests/test_graph.py b/python/cugraph-dgl/cugraph_dgl/tests/test_graph.py index a60db97..51ffeb1 100644 --- a/python/cugraph-dgl/cugraph_dgl/tests/test_graph.py +++ b/python/cugraph-dgl/cugraph_dgl/tests/test_graph.py @@ -53,10 +53,11 @@ def test_graph_make_homogeneous_graph(direction): graph.nodes() == torch.arange(num_nodes, dtype=torch.int64, device="cuda") ).all() - assert graph.nodes[None]["x"] is not None - assert (graph.nodes[None]["x"] == torch.as_tensor(node_x, device="cuda")).all() + emb = graph.nodes[None]["x"] + assert emb is not None + assert (emb() == torch.as_tensor(node_x, device="cuda")).all() assert ( - graph.nodes[None]["num"] + graph.nodes[None]["num"]() == torch.arange(num_nodes, dtype=torch.int64, device="cuda") ).all() @@ -64,7 +65,7 @@ def test_graph_make_homogeneous_graph(direction): graph.edges("eid", device="cuda") == torch.arange(len(df), dtype=torch.int64, device="cuda") ).all() - assert (graph.edges[None]["weight"] == torch.as_tensor(wgt, device="cuda")).all() + assert (graph.edges[None]["weight"]() == torch.as_tensor(wgt, device="cuda")).all() plc_expected_graph = pylibcugraph.SGGraph( pylibcugraph.ResourceHandle(), @@ -215,3 +216,93 @@ def test_graph_make_heterogeneous_graph(direction): assert ( dsts[eid] == int(sampling_output[dst_col][i]) - expected_offsets[etype][1] ) + + +@pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available") +@pytest.mark.skipif(isinstance(dgl, MissingModule), reason="dgl not available") +@pytest.mark.parametrize("direction", ["out", "in"]) +def test_graph_find(direction): + df = karate.get_edgelist() + df.src = df.src.astype("int64") + df.dst = df.dst.astype("int64") + + graph = cugraph_dgl.Graph() + total_num_nodes = max(df.src.max(), df.dst.max()) + 1 + + num_nodes_group_1 = total_num_nodes // 2 + num_nodes_group_2 = total_num_nodes - num_nodes_group_1 + + node_x_1 = np.random.random((num_nodes_group_1,)) + node_x_2 = np.random.random((num_nodes_group_2,)) + + graph.add_nodes(num_nodes_group_1, {"x": node_x_1}, "type1") + graph.add_nodes(num_nodes_group_2, {"x": node_x_2}, "type2") + + edges_11 = df[(df.src < num_nodes_group_1) & (df.dst < num_nodes_group_1)] + edges_12 = df[(df.src < num_nodes_group_1) & (df.dst >= num_nodes_group_1)] + edges_21 = df[(df.src >= num_nodes_group_1) & (df.dst < num_nodes_group_1)] + edges_22 = df[(df.src >= num_nodes_group_1) & (df.dst >= num_nodes_group_1)] + + edges_12.dst -= num_nodes_group_1 + edges_21.src -= num_nodes_group_1 + edges_22.dst -= num_nodes_group_1 + edges_22.src -= num_nodes_group_1 + + graph.add_edges(edges_11.src, edges_11.dst, etype=("type1", "e1", "type1")) + graph.add_edges(edges_12.src, edges_12.dst, etype=("type1", "e2", "type2")) + graph.add_edges(edges_21.src, edges_21.dst, etype=("type2", "e3", "type1")) + graph.add_edges(edges_22.src, edges_22.dst, etype=("type2", "e4", 
"type2")) + + # force direction generation to make sure in case is tested + graph._graph(direction) + + assert not graph.is_homogeneous + assert not graph.is_multi_gpu + + srcs, dsts = graph.find_edges( + torch.as_tensor([0, len(edges_11) - 1, 999], dtype=torch.int64), + ("type1", "e1", "type1"), + ) + assert ( + srcs[[0, 1]] == torch.tensor([1, 6], device="cuda", dtype=torch.int64) + ).all() + assert ( + dsts[[0, 1]] == torch.tensor([0, 16], device="cuda", dtype=torch.int64) + ).all() + assert srcs[2] < 0 and dsts[2] < 0 + + srcs, dsts = graph.find_edges( + torch.as_tensor([0, len(edges_12) - 1, 999], dtype=torch.int64), + ("type1", "e2", "type2"), + ) + assert ( + srcs[[0, 1]] == torch.tensor([0, 15], device="cuda", dtype=torch.int64) + ).all() + assert ( + dsts[[0, 1]] == torch.tensor([0, 16], device="cuda", dtype=torch.int64) + ).all() + assert srcs[2] < 0 and dsts[2] < 0 + + srcs, dsts = graph.find_edges( + torch.as_tensor([0, len(edges_21) - 1, 999], dtype=torch.int64), + ("type2", "e3", "type1"), + ) + assert ( + srcs[[0, 1]] == torch.tensor([0, 16], device="cuda", dtype=torch.int64) + ).all() + assert ( + dsts[[0, 1]] == torch.tensor([0, 15], device="cuda", dtype=torch.int64) + ).all() + assert srcs[2] < 0 and dsts[2] < 0 + + srcs, dsts = graph.find_edges( + torch.as_tensor([0, len(edges_22) - 1, 999], dtype=torch.int64), + ("type2", "e4", "type2"), + ) + assert ( + srcs[[0, 1]] == torch.tensor([15, 15], device="cuda", dtype=torch.int64) + ).all() + assert ( + dsts[[0, 1]] == torch.tensor([1, 16], device="cuda", dtype=torch.int64) + ).all() + assert srcs[2] < 0 and dsts[2] < 0 diff --git a/python/cugraph-dgl/cugraph_dgl/tests/test_graph_mg.py b/python/cugraph-dgl/cugraph_dgl/tests/test_graph_mg.py index eedda66..489ad90 100644 --- a/python/cugraph-dgl/cugraph_dgl/tests/test_graph_mg.py +++ b/python/cugraph-dgl/cugraph_dgl/tests/test_graph_mg.py @@ -308,3 +308,173 @@ def test_graph_make_heterogeneous_graph_mg(direction): ), nprocs=world_size, ) + + +def run_test_graph_find_simple_mg(rank, world_size, uid, direction): + init_pytorch_worker(rank, world_size, uid) + df = karate.get_edgelist() + + total_num_nodes = max(df.src.max(), df.dst.max()) + 1 + + num_nodes_group_1 = total_num_nodes // 2 + num_nodes_group_2 = total_num_nodes - num_nodes_group_1 + + node_x_1 = np.array_split(np.random.random((num_nodes_group_1,)), world_size)[rank] + node_x_2 = np.array_split(np.random.random((num_nodes_group_2,)), world_size)[rank] + + graph = cugraph_dgl.Graph(is_multi_gpu=True) + graph.add_nodes(num_nodes_group_1, {"x": node_x_1}, "type1") + graph.add_nodes(num_nodes_group_2, {"x": node_x_2}, "type2") + + edges_11 = df[(df.src < num_nodes_group_1) & (df.dst < num_nodes_group_1)] + edges_12 = df[(df.src < num_nodes_group_1) & (df.dst >= num_nodes_group_1)] + edges_21 = df[(df.src >= num_nodes_group_1) & (df.dst < num_nodes_group_1)] + edges_22 = df[(df.src >= num_nodes_group_1) & (df.dst >= num_nodes_group_1)] + + edges_12.dst -= num_nodes_group_1 + edges_21.src -= num_nodes_group_1 + edges_22.dst -= num_nodes_group_1 + edges_22.src -= num_nodes_group_1 + + edges_11_local = edges_11.iloc[ + np.array_split(np.arange(len(edges_11)), world_size)[rank] + ] + edges_12_local = edges_12.iloc[ + np.array_split(np.arange(len(edges_12)), world_size)[rank] + ] + edges_21_local = edges_21.iloc[ + np.array_split(np.arange(len(edges_21)), world_size)[rank] + ] + edges_22_local = edges_22.iloc[ + np.array_split(np.arange(len(edges_22)), world_size)[rank] + ] + + graph.add_edges( + edges_11_local.src, 
edges_11_local.dst, etype=("type1", "e1", "type1") + ) + graph.add_edges( + edges_12_local.src, edges_12_local.dst, etype=("type1", "e2", "type2") + ) + graph.add_edges( + edges_21_local.src, edges_21_local.dst, etype=("type2", "e3", "type1") + ) + graph.add_edges( + edges_22_local.src, edges_22_local.dst, etype=("type2", "e4", "type2") + ) + + # force direction generation to make sure in case is tested + graph._graph(direction) + + assert not graph.is_homogeneous + assert graph.is_multi_gpu + + if len(edges_11) > 0: + srcs, dsts = graph.find_edges( + torch.as_tensor([0, len(edges_11) - 1, 999], dtype=torch.int64), + ("type1", "e1", "type1"), + ) + assert ( + srcs[[0, 1]] + == torch.tensor( + [edges_11.src.iloc[0], edges_11.src.iloc[-1]], + device="cuda", + dtype=torch.int64, + ) + ).all() + assert ( + dsts[[0, 1]] + == torch.tensor( + [edges_11.dst.iloc[0], edges_11.dst.iloc[-1]], + device="cuda", + dtype=torch.int64, + ) + ).all() + assert srcs[2] < 0 and dsts[2] < 0 + if len(edges_12) > 0: + srcs, dsts = graph.find_edges( + torch.as_tensor([0, len(edges_12) - 1, 999], dtype=torch.int64), + ("type1", "e2", "type2"), + ) + assert ( + srcs[[0, 1]] + == torch.tensor( + [edges_12.src.iloc[0], edges_12.src.iloc[-1]], + device="cuda", + dtype=torch.int64, + ) + ).all() + assert ( + dsts[[0, 1]] + == torch.tensor( + [edges_12.dst.iloc[0], edges_12.dst.iloc[-1]], + device="cuda", + dtype=torch.int64, + ) + ).all() + assert srcs[2] < 0 and dsts[2] < 0 + if len(edges_21) > 0: + srcs, dsts = graph.find_edges( + torch.as_tensor([0, len(edges_21) - 1, 999], dtype=torch.int64), + ("type2", "e3", "type1"), + ) + assert ( + srcs[[0, 1]] + == torch.tensor( + [edges_21.src.iloc[0], edges_21.src.iloc[-1]], + device="cuda", + dtype=torch.int64, + ) + ).all() + assert ( + dsts[[0, 1]] + == torch.tensor( + [edges_21.dst.iloc[0], edges_21.dst.iloc[-1]], + device="cuda", + dtype=torch.int64, + ) + ).all() + assert srcs[2] < 0 and dsts[2] < 0 + if len(edges_22) > 0: + srcs, dsts = graph.find_edges( + torch.as_tensor([0, len(edges_22) - 1, 999], dtype=torch.int64), + ("type2", "e4", "type2"), + ) + assert ( + srcs[[0, 1]] + == torch.tensor( + [edges_22.src.iloc[0], edges_22.src.iloc[-1]], + device="cuda", + dtype=torch.int64, + ) + ).all() + assert ( + dsts[[0, 1]] + == torch.tensor( + [edges_22.dst.iloc[0], edges_22.dst.iloc[-1]], + device="cuda", + dtype=torch.int64, + ) + ).all() + assert srcs[2] < 0 and dsts[2] < 0 + + +@pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available") +@pytest.mark.skipif(isinstance(dgl, MissingModule), reason="dgl not available") +@pytest.mark.parametrize("direction", ["out", "in"]) +def test_graph_find_mg(direction): + df = karate.get_edgelist() + df.src = df.src.astype("int64") + df.dst = df.dst.astype("int64") + + uid = cugraph_comms_create_unique_id() + world_size = torch.cuda.device_count() + + torch.multiprocessing.spawn( + run_test_graph_find_simple_mg, + args=( + world_size, + uid, + direction, + ), + nprocs=world_size, + ) diff --git a/python/cugraph-dgl/cugraph_dgl/view.py b/python/cugraph-dgl/cugraph_dgl/view.py index dbc53e7..7c4d95f 100644 --- a/python/cugraph-dgl/cugraph_dgl/view.py +++ b/python/cugraph-dgl/cugraph_dgl/view.py @@ -11,6 +11,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import warnings
 from collections import defaultdict
 from collections.abc import MutableMapping
 
@@ -20,11 +21,53 @@
 
 import cugraph_dgl
 from cugraph_dgl.typing import TensorType
+from cugraph_dgl.utils.cugraph_conversion_utils import _cast_to_torch_tensor
 
 torch = import_optional("torch")
 dgl = import_optional("dgl")
 
 
+class EmbeddingView:
+    def __init__(self, storage: "dgl.storages.base.FeatureStorage", ld: int):
+        self.__ld = ld
+        self.__storage = storage
+
+    def __getitem__(self, u: TensorType) -> "torch.Tensor":
+        u = _cast_to_torch_tensor(u)
+        try:
+            return self.__storage.fetch(
+                u,
+                "cuda",
+            )
+        except RuntimeError as ex:
+            warnings.warn(
+                "Got error accessing data, trying again with index on device: "
+                + str(ex)
+            )
+            return self.__storage.fetch(
+                u.cuda(),
+                "cuda",
+            )
+
+    def __call__(self):
+        warnings.warn(
+            "Getting an entire embedding tensor is not recommended "
+            "as it wastes memory. Consider indexing to get only the "
+            "required elements of the embedding tensor."
+        )
+        return self[torch.arange(self.__ld, dtype=torch.int64)]
+
+    @property
+    def shape(self) -> "torch.Size":
+        try:
+            f = self.__storage.fetch(torch.tensor([0]), "cpu")
+        except RuntimeError:
+            f = self.__storage.fetch(torch.tensor([0], device="cuda"), "cuda")
+        sz = [s for s in f.shape]
+        sz[0] = self.__ld
+        return torch.Size(tuple(sz))
+
+
 class HeteroEdgeDataView(MutableMapping):
     """
     Duck-typed version of DGL's HeteroEdgeDataView.

From f645023bc02f8137facf9838aca8703de43f4883 Mon Sep 17 00:00:00 2001
From: Alexandria Barghi
Date: Fri, 18 Oct 2024 09:15:58 -0700
Subject: [PATCH 2/6] update for changes in other PR

---
 python/cugraph-dgl/cugraph_dgl/graph.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/python/cugraph-dgl/cugraph_dgl/graph.py b/python/cugraph-dgl/cugraph_dgl/graph.py
index 8978fb8..913c1e7 100644
--- a/python/cugraph-dgl/cugraph_dgl/graph.py
+++ b/python/cugraph-dgl/cugraph_dgl/graph.py
@@ -931,7 +931,9 @@ def find_edges(
         # Have to properly de-offset the vertices based on edge type
         etype = self.to_canonical_etype(etype)
         num_edge_type = self._to_numeric_etype(etype)
-        out = self._edge_lookup_table.find(cupy.asarray(eid), num_edge_type)
+        out = self._edge_lookup_table.lookup_vertex_ids(
+            cupy.asarray(eid), num_edge_type
+        )
 
         src_name = "sources" if self.__graph["direction"] == "out" else "destinations"
         dst_name = "destinations" if self.__graph["direction"] == "out" else "sources"
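
The EmbeddingView class added to view.py in PATCH 1/6 makes node and edge data
lookups like graph.nodes[...][name] return a lazy view rather than a
materialized tensor. A minimal sketch of the resulting access pattern,
following the updated single-GPU tests; the ten-node toy graph and its
construction here are illustrative assumptions, not part of the patch:

import numpy as np
import torch

import cugraph_dgl

# Toy homogeneous graph with a node feature, mirroring the test setup
# (assumed construction; the tests use the karate edgelist instead).
graph = cugraph_dgl.Graph()
graph.add_nodes(10, {"x": np.random.random((10,))})
graph.add_edges(torch.arange(9), torch.arange(1, 10))

emb = graph.nodes[None]["x"]      # EmbeddingView; nothing is fetched yet
rows = emb[torch.tensor([0, 3])]  # fetches only the requested rows to GPU
full = emb()                      # materializes everything, with a warning
print(emb.shape)                  # probes one row to infer the full shape
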
From 550e8c377eeb372188606e3126d1b12e9056d8f6 Mon Sep 17 00:00:00 2001
From: Alexandria Barghi
Date: Fri, 18 Oct 2024 10:28:13 -0700
Subject: [PATCH 3/6] implement global_uniform_negative_sampling, fix bugs

---
 python/cugraph-dgl/cugraph_dgl/graph.py       | 108 ++++++++++++++++++
 .../cugraph_pyg/sampler/sampler.py            |   5 +
 .../cugraph_pyg/sampler/sampler_utils.py      |   2 +-
 3 files changed, 114 insertions(+), 1 deletion(-)

diff --git a/python/cugraph-dgl/cugraph_dgl/graph.py b/python/cugraph-dgl/cugraph_dgl/graph.py
index 913c1e7..2add7d5 100644
--- a/python/cugraph-dgl/cugraph_dgl/graph.py
+++ b/python/cugraph-dgl/cugraph_dgl/graph.py
@@ -944,6 +944,114 @@ def find_edges(
             torch.as_tensor(out[dst_name], device="cuda") - offsets[etype[2]],
         )
 
+    def global_uniform_negative_sampling(
+        self,
+        num_samples: int,
+        exclude_self_loops: bool = True,
+        replace: bool = False,
+        etype: Optional[Union[str, Tuple[str, str, str]]] = None,
+        redundancy: Optional[float] = None,
+    ):
+        """
+        Performs negative sampling, which constructs a set of source and destination
+        pairs that do not exist in this graph.
+
+        Parameters
+        ----------
+        num_samples: int
+            Target number of negative edges to generate. May generate fewer
+            depending on whether the existing set of edges allows for it.
+        exclude_self_loops: bool
+            Whether to drop edges where the source and destination are the same.
+            Defaults to True.
+        replace: bool
+            Whether to sample with replacement. Sampling with replacement is not
+            supported by the cuGraph-DGL generator. Defaults to False.
+        etype: str or tuple[str, str, str] (Optional)
+            The edge type to generate negative edges for. Optional if there is
+            only one edge type in the graph.
+        redundancy: float (Optional)
+            Not supported by the cuGraph-DGL generator.
+        """
+
+        if redundancy:
+            warnings.warn("The 'redundancy' parameter is ignored by cuGraph-DGL.")
+        if replace:
+            raise NotImplementedError(
+                "Negative sampling with replacement is not supported by cuGraph-DGL."
+            )
+
+        if len(self.ntypes) == 1:
+            vertices = torch.arange(self.num_nodes())
+        else:
+            can_edge_type = self.to_canonical_etype(etype)
+            # Limit sampled vertices to those of the given edge type.
+            vertices = torch.concat(
+                [
+                    torch.arange(
+                        self._vertex_offsets[can_edge_type[0]],
+                        self._vertex_offsets[can_edge_type[0]]
+                        + self.num_nodes(can_edge_type[0]),
+                        dtype=torch.int64,
+                        device="cuda",
+                    ),
+                    torch.arange(
+                        self._vertex_offsets[can_edge_type[2]],
+                        self._vertex_offsets[can_edge_type[2]]
+                        + self.num_nodes(can_edge_type[2]),
+                        dtype=torch.int64,
+                        device="cuda",
+                    ),
+                ]
+            )
+
+        if self.is_multi_gpu:
+            rank = torch.distributed.get_rank()
+            world_size = torch.distributed.get_world_size()
+
+            num_samples_global = torch.tensor([num_samples], device="cuda")
+            torch.distributed.all_reduce(
+                num_samples_global, op=torch.distributed.ReduceOp.SUM
+            )
+            num_samples_global = int(num_samples_global)
+
+            vertices = torch.tensor_split(vertices, world_size)[rank]
+        else:
+            num_samples_global = num_samples
+
+        graph = (
+            self.__graph
+            if self.__graph["direction"] == "out"
+            else self._graph("out", self.__graph["prob_attr"])
+        )
+        bias = cupy.ones(len(vertices), dtype="float32")
+
+        result_dict = pylibcugraph.negative_sampling(
+            self._resource_handle,
+            graph,
+            num_samples_global,
+            vertices=cupy.asarray(vertices),
+            src_bias=bias,
+            dst_bias=bias,
+            remove_duplicates=True,
+            remove_false_negatives=True,
+            exact_number_of_samples=True,
+            do_expensive_check=False,
+        )
+
+        # TODO remove this workaround once the C API is updated to take a local number
+        # of negatives (rapidsai/cugraph#4672)
+        src_neg = torch.as_tensor(result_dict["sources"], device="cuda")[:num_samples]
+        dst_neg = torch.as_tensor(result_dict["destinations"], device="cuda")[
+            :num_samples
+        ]
+
+        if exclude_self_loops:
+            f = src_neg != dst_neg
+            return src_neg[f], dst_neg[f]
+        else:
+            return src_neg, dst_neg
+
     @property
     def ndata(self) -> HeteroNodeDataView:
         """
diff --git a/python/cugraph-pyg/cugraph_pyg/sampler/sampler.py b/python/cugraph-pyg/cugraph_pyg/sampler/sampler.py
index bc3d4fd..b4a2dd0 100644
--- a/python/cugraph-pyg/cugraph_pyg/sampler/sampler.py
+++ b/python/cugraph-pyg/cugraph_pyg/sampler/sampler.py
@@ -484,6 +484,11 @@ def sample_from_edges(
     ]:
         src = index.row
         dst = index.col
+
+        if index.input_type is not None:
+            src += self.__graph_store._vertex_offsets[index.input_type[0]]
+            dst += self.__graph_store._vertex_offsets[index.input_type[1]]
+
         input_id = index.input_id
         neg_batch_size = 0
         if neg_sampling:
diff --git a/python/cugraph-pyg/cugraph_pyg/sampler/sampler_utils.py b/python/cugraph-pyg/cugraph_pyg/sampler/sampler_utils.py
index 
b3d56ef..ee523a8 100644 --- a/python/cugraph-pyg/cugraph_pyg/sampler/sampler_utils.py +++ b/python/cugraph-pyg/cugraph_pyg/sampler/sampler_utils.py @@ -462,7 +462,7 @@ def neg_sample( if graph_store.is_multi_gpu: num_neg_global = torch.tensor([num_neg], device="cuda") torch.distributed.all_reduce(num_neg_global, op=torch.distributed.ReduceOp.SUM) - num_neg = int(num_neg_global) + num_neg_global = int(num_neg_global) else: num_neg_global = num_neg From b4cd8deffb3f4125923c8db59902dd7117ea71d5 Mon Sep 17 00:00:00 2001 From: Alexandria Barghi Date: Fri, 18 Oct 2024 13:13:52 -0700 Subject: [PATCH 4/6] fixes, add/improve tests --- python/cugraph-dgl/cugraph_dgl/graph.py | 91 ++++++++---- .../cugraph-dgl/cugraph_dgl/tests/conftest.py | 68 +++++++++ .../cugraph_dgl/tests/test_graph.py | 139 +++++++++--------- 3 files changed, 198 insertions(+), 100 deletions(-) diff --git a/python/cugraph-dgl/cugraph_dgl/graph.py b/python/cugraph-dgl/cugraph_dgl/graph.py index 2add7d5..c065f14 100644 --- a/python/cugraph-dgl/cugraph_dgl/graph.py +++ b/python/cugraph-dgl/cugraph_dgl/graph.py @@ -983,27 +983,57 @@ def global_uniform_negative_sampling( if len(self.ntypes) == 1: vertices = torch.arange(self.num_nodes()) + src_vertex_offset = 0 + dst_vertex_offset = 0 + src_bias = cupy.ones(len(vertices), dtype="float32") + dst_bias = src_bias else: can_edge_type = self.to_canonical_etype(etype) + src_vertex_offset = self._vertex_offsets[can_edge_type[0]] + dst_vertex_offset = self._vertex_offsets[can_edge_type[2]] + # Limit sampled vertices to those of the given edge type. - vertices = torch.concat( - [ - torch.arange( - self._vertex_offsets[can_edge_type[0]], - self._vertex_offsets[can_edge_type[0]] - + self.num_nodes(can_edge_type[0]), - dtype=torch.int64, - device="cuda", - ), - torch.arange( - self._vertex_offsets[can_edge_type[2]], - self._vertex_offsets[can_edge_type[2]] - + self.num_nodes(can_edge_type[2]), - dtype=torch.int64, - device="cuda", - ), - ] - ) + if can_edge_type[0] == can_edge_type[2]: + vertices = torch.arange( + src_vertex_offset, + src_vertex_offset + self.num_nodes(can_edge_type[0]), + dtype=torch.int64, + device="cuda", + ) + src_bias = cupy.ones(self.num_nodes(can_edge_type[0]), dtype="float32") + dst_bias = src_bias + + else: + vertices = torch.concat( + [ + torch.arange( + src_vertex_offset, + src_vertex_offset + self.num_nodes(can_edge_type[0]), + dtype=torch.int64, + device="cuda", + ), + torch.arange( + dst_vertex_offset, + dst_vertex_offset + self.num_nodes(can_edge_type[2]), + dtype=torch.int64, + device="cuda", + ), + ] + ) + + src_bias = cupy.concatenate( + [ + cupy.ones(self.num_nodes(can_edge_type[0]), dtype="float32"), + cupy.zeros(self.num_nodes(can_edge_type[2]), dtype="float32"), + ] + ) + + dst_bias = cupy.concatenate( + [ + cupy.zeros(self.num_nodes(can_edge_type[0]), dtype="float32"), + cupy.ones(self.num_nodes(can_edge_type[2]), dtype="float32"), + ] + ) if self.is_multi_gpu: rank = torch.distributed.get_rank() @@ -1020,19 +1050,20 @@ def global_uniform_negative_sampling( num_samples_global = num_samples graph = ( - self.__graph - if self.__graph["direction"] == "out" - else self._graph("out", self.__graph["prob_attr"]) + self.__graph["graph"] + if self.__graph is not None and self.__graph["direction"] == "out" + else self._graph( + "out", None if self.__graph is None else self.__graph["prob_attr"] + ) ) - bias = cupy.ones(len(vertices), dtype="float32") result_dict = pylibcugraph.negative_sampling( self._resource_handle, graph, num_samples_global, 
vertices=cupy.asarray(vertices), - src_bias=bias, - dst_bias=bias, + src_bias=src_bias, + dst_bias=dst_bias, remove_duplicates=True, remove_false_negatives=True, exact_number_of_samples=True, @@ -1041,10 +1072,14 @@ def global_uniform_negative_sampling( # TODO remove this workaround once the C API is updated to take a local number # of negatives (rapidsai/cugraph#4672) - src_neg = torch.as_tensor(result_dict["sources"], device="cuda")[:num_samples] - dst_neg = torch.as_tensor(result_dict["destinations"], device="cuda")[ - :num_samples - ] + src_neg = ( + torch.as_tensor(result_dict["sources"], device="cuda")[:num_samples] + - src_vertex_offset + ) + dst_neg = ( + torch.as_tensor(result_dict["destinations"], device="cuda")[:num_samples] + - dst_vertex_offset + ) if exclude_self_loops: f = src_neg != dst_neg diff --git a/python/cugraph-dgl/cugraph_dgl/tests/conftest.py b/python/cugraph-dgl/cugraph_dgl/tests/conftest.py index ee1183f..0f9f890 100644 --- a/python/cugraph-dgl/cugraph_dgl/tests/conftest.py +++ b/python/cugraph-dgl/cugraph_dgl/tests/conftest.py @@ -15,12 +15,17 @@ import dgl import torch +import numpy as np + +import cugraph_dgl from cugraph.testing.mg_utils import ( start_dask_client, stop_dask_client, ) +from cugraph.datasets import karate + @pytest.fixture(scope="module") def dask_client(): @@ -66,3 +71,66 @@ def dgl_graph_1(): src = torch.tensor([0, 1, 0, 2, 3, 0, 4, 0, 5, 0, 6, 7, 0, 8, 9]) dst = torch.tensor([1, 9, 2, 9, 9, 4, 9, 5, 9, 6, 9, 9, 8, 9, 0]) return dgl.graph((src, dst)) + + +def create_karate_bipartite(multi_gpu: bool = False): + df = karate.get_edgelist() + df.src = df.src.astype("int64") + df.dst = df.dst.astype("int64") + + graph = cugraph_dgl.Graph(is_multi_gpu=multi_gpu) + total_num_nodes = max(df.src.max(), df.dst.max()) + 1 + + num_nodes_group_1 = total_num_nodes // 2 + num_nodes_group_2 = total_num_nodes - num_nodes_group_1 + + node_x_1 = np.random.random((num_nodes_group_1,)) + node_x_2 = np.random.random((num_nodes_group_2,)) + + graph.add_nodes(num_nodes_group_1, {"x": node_x_1}, "type1") + graph.add_nodes(num_nodes_group_2, {"x": node_x_2}, "type2") + + edges = {} + edges["type1", "e1", "type1"] = df[ + (df.src < num_nodes_group_1) & (df.dst < num_nodes_group_1) + ] + edges["type1", "e2", "type2"] = df[ + (df.src < num_nodes_group_1) & (df.dst >= num_nodes_group_1) + ] + edges["type2", "e3", "type1"] = df[ + (df.src >= num_nodes_group_1) & (df.dst < num_nodes_group_1) + ] + edges["type2", "e4", "type2"] = df[ + (df.src >= num_nodes_group_1) & (df.dst >= num_nodes_group_1) + ] + + edges["type1", "e2", "type2"].dst -= num_nodes_group_1 + edges["type2", "e3", "type1"].src -= num_nodes_group_1 + edges["type2", "e4", "type2"].dst -= num_nodes_group_1 + edges["type2", "e4", "type2"].src -= num_nodes_group_1 + + if multi_gpu: + rank = torch.distributed.get_rank() + world_size = torch.distributed.get_world_size() + + edges_local = { + etype: edf.iloc[np.array_split(np.arange(edf), world_size)[rank]] + for etype, edf in edges + } + else: + edges_local = edges + + for etype, edf in edges_local.items(): + graph.add_edges(edf.src, edf.dst, etype=etype) + + return graph, edges, (num_nodes_group_1, num_nodes_group_2) + + +@pytest.fixture +def karate_bipartite(): + return create_karate_bipartite(False) + + +@pytest.fixture +def karate_bipartite_mg(): + return create_karate_bipartite(True) diff --git a/python/cugraph-dgl/cugraph_dgl/tests/test_graph.py b/python/cugraph-dgl/cugraph_dgl/tests/test_graph.py index 51ffeb1..bacde26 100644 --- 
a/python/cugraph-dgl/cugraph_dgl/tests/test_graph.py +++ b/python/cugraph-dgl/cugraph_dgl/tests/test_graph.py @@ -99,44 +99,18 @@ def test_graph_make_homogeneous_graph(direction): @pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available") @pytest.mark.skipif(isinstance(dgl, MissingModule), reason="dgl not available") @pytest.mark.parametrize("direction", ["out", "in"]) -def test_graph_make_heterogeneous_graph(direction): - df = karate.get_edgelist() - df.src = df.src.astype("int64") - df.dst = df.dst.astype("int64") - - graph = cugraph_dgl.Graph() - total_num_nodes = max(df.src.max(), df.dst.max()) + 1 - - num_nodes_group_1 = total_num_nodes // 2 - num_nodes_group_2 = total_num_nodes - num_nodes_group_1 - - node_x_1 = np.random.random((num_nodes_group_1,)) - node_x_2 = np.random.random((num_nodes_group_2,)) - - graph.add_nodes(num_nodes_group_1, {"x": node_x_1}, "type1") - graph.add_nodes(num_nodes_group_2, {"x": node_x_2}, "type2") - - edges_11 = df[(df.src < num_nodes_group_1) & (df.dst < num_nodes_group_1)] - edges_12 = df[(df.src < num_nodes_group_1) & (df.dst >= num_nodes_group_1)] - edges_21 = df[(df.src >= num_nodes_group_1) & (df.dst < num_nodes_group_1)] - edges_22 = df[(df.src >= num_nodes_group_1) & (df.dst >= num_nodes_group_1)] - - edges_12.dst -= num_nodes_group_1 - edges_21.src -= num_nodes_group_1 - edges_22.dst -= num_nodes_group_1 - edges_22.src -= num_nodes_group_1 - - graph.add_edges(edges_11.src, edges_11.dst, etype=("type1", "e1", "type1")) - graph.add_edges(edges_12.src, edges_12.dst, etype=("type1", "e2", "type2")) - graph.add_edges(edges_21.src, edges_21.dst, etype=("type2", "e3", "type1")) - graph.add_edges(edges_22.src, edges_22.dst, etype=("type2", "e4", "type2")) +def test_graph_make_heterogeneous_graph(direction, karate_bipartite): + graph, edges, (num_nodes_group_1, num_nodes_group_2) = karate_bipartite assert not graph.is_homogeneous assert not graph.is_multi_gpu # Verify graph.nodes() assert ( - graph.nodes() == torch.arange(total_num_nodes, dtype=torch.int64, device="cuda") + graph.nodes() + == torch.arange( + num_nodes_group_1 + num_nodes_group_2, dtype=torch.int64, device="cuda" + ) ).all() assert ( graph.nodes("type1") @@ -150,19 +124,27 @@ def test_graph_make_heterogeneous_graph(direction): # Verify graph.edges() assert ( graph.edges("eid", etype=("type1", "e1", "type1")) - == torch.arange(len(edges_11), dtype=torch.int64, device="cuda") + == torch.arange( + len(edges["type1", "e1", "type1"]), dtype=torch.int64, device="cuda" + ) ).all() assert ( graph.edges("eid", etype=("type1", "e2", "type2")) - == torch.arange(len(edges_12), dtype=torch.int64, device="cuda") + == torch.arange( + len(edges["type1", "e2", "type2"]), dtype=torch.int64, device="cuda" + ) ).all() assert ( graph.edges("eid", etype=("type2", "e3", "type1")) - == torch.arange(len(edges_21), dtype=torch.int64, device="cuda") + == torch.arange( + len(edges["type2", "e3", "type1"]), dtype=torch.int64, device="cuda" + ) ).all() assert ( graph.edges("eid", etype=("type2", "e4", "type2")) - == torch.arange(len(edges_22), dtype=torch.int64, device="cuda") + == torch.arange( + len(edges["type2", "e4", "type2"]), dtype=torch.int64, device="cuda" + ) ).all() # Use sampling call to check graph creation @@ -172,7 +154,7 @@ def test_graph_make_heterogeneous_graph(direction): sampling_output = pylibcugraph.uniform_neighbor_sample( pylibcugraph.ResourceHandle(), plc_graph, - start_list=cupy.arange(total_num_nodes, dtype="int64"), + start_list=cupy.arange(num_nodes_group_1 + 
num_nodes_group_2, dtype="int64"), h_fan_out=np.array([1, 1], dtype="int32"), with_replacement=False, do_expensive_check=True, @@ -221,37 +203,8 @@ def test_graph_make_heterogeneous_graph(direction): @pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available") @pytest.mark.skipif(isinstance(dgl, MissingModule), reason="dgl not available") @pytest.mark.parametrize("direction", ["out", "in"]) -def test_graph_find(direction): - df = karate.get_edgelist() - df.src = df.src.astype("int64") - df.dst = df.dst.astype("int64") - - graph = cugraph_dgl.Graph() - total_num_nodes = max(df.src.max(), df.dst.max()) + 1 - - num_nodes_group_1 = total_num_nodes // 2 - num_nodes_group_2 = total_num_nodes - num_nodes_group_1 - - node_x_1 = np.random.random((num_nodes_group_1,)) - node_x_2 = np.random.random((num_nodes_group_2,)) - - graph.add_nodes(num_nodes_group_1, {"x": node_x_1}, "type1") - graph.add_nodes(num_nodes_group_2, {"x": node_x_2}, "type2") - - edges_11 = df[(df.src < num_nodes_group_1) & (df.dst < num_nodes_group_1)] - edges_12 = df[(df.src < num_nodes_group_1) & (df.dst >= num_nodes_group_1)] - edges_21 = df[(df.src >= num_nodes_group_1) & (df.dst < num_nodes_group_1)] - edges_22 = df[(df.src >= num_nodes_group_1) & (df.dst >= num_nodes_group_1)] - - edges_12.dst -= num_nodes_group_1 - edges_21.src -= num_nodes_group_1 - edges_22.dst -= num_nodes_group_1 - edges_22.src -= num_nodes_group_1 - - graph.add_edges(edges_11.src, edges_11.dst, etype=("type1", "e1", "type1")) - graph.add_edges(edges_12.src, edges_12.dst, etype=("type1", "e2", "type2")) - graph.add_edges(edges_21.src, edges_21.dst, etype=("type2", "e3", "type1")) - graph.add_edges(edges_22.src, edges_22.dst, etype=("type2", "e4", "type2")) +def test_graph_find(direction, karate_bipartite): + graph, edges, _ = karate_bipartite # force direction generation to make sure in case is tested graph._graph(direction) @@ -260,7 +213,9 @@ def test_graph_find(direction): assert not graph.is_multi_gpu srcs, dsts = graph.find_edges( - torch.as_tensor([0, len(edges_11) - 1, 999], dtype=torch.int64), + torch.as_tensor( + [0, len(edges["type1", "e1", "type1"]) - 1, 999], dtype=torch.int64 + ), ("type1", "e1", "type1"), ) assert ( @@ -272,7 +227,9 @@ def test_graph_find(direction): assert srcs[2] < 0 and dsts[2] < 0 srcs, dsts = graph.find_edges( - torch.as_tensor([0, len(edges_12) - 1, 999], dtype=torch.int64), + torch.as_tensor( + [0, len(edges["type1", "e2", "type2"]) - 1, 999], dtype=torch.int64 + ), ("type1", "e2", "type2"), ) assert ( @@ -284,7 +241,9 @@ def test_graph_find(direction): assert srcs[2] < 0 and dsts[2] < 0 srcs, dsts = graph.find_edges( - torch.as_tensor([0, len(edges_21) - 1, 999], dtype=torch.int64), + torch.as_tensor( + [0, len(edges["type2", "e3", "type1"]) - 1, 999], dtype=torch.int64 + ), ("type2", "e3", "type1"), ) assert ( @@ -296,7 +255,9 @@ def test_graph_find(direction): assert srcs[2] < 0 and dsts[2] < 0 srcs, dsts = graph.find_edges( - torch.as_tensor([0, len(edges_22) - 1, 999], dtype=torch.int64), + torch.as_tensor( + [0, len(edges["type2", "e4", "type2"]) - 1, 999], dtype=torch.int64 + ), ("type2", "e4", "type2"), ) assert ( @@ -306,3 +267,37 @@ def test_graph_find(direction): dsts[[0, 1]] == torch.tensor([1, 16], device="cuda", dtype=torch.int64) ).all() assert srcs[2] < 0 and dsts[2] < 0 + + +@pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available") +@pytest.mark.skipif(isinstance(dgl, MissingModule), reason="dgl not available") 
+@pytest.mark.parametrize("exclude_self_loops", [True, False]) +@pytest.mark.parametrize("num_samples", [2, 11]) +def test_graph_uniform_negative_sample( + karate_bipartite, exclude_self_loops, num_samples +): + graph, edges, _ = karate_bipartite + + for etype in graph.canonical_etypes: + src_neg, dst_neg = graph.global_uniform_negative_sampling( + num_samples, + exclude_self_loops=exclude_self_loops, + etype=etype, + ) + + assert len(src_neg) == len(dst_neg) + assert len(src_neg) <= num_samples + + assert (src_neg >= 0).all() + assert (dst_neg >= 0).all() + + assert (src_neg < graph.num_nodes(etype[0])).all() + assert (dst_neg < graph.num_nodes(etype[2])).all() + + if exclude_self_loops: + assert (src_neg == dst_neg).sum() == 0 + + for i in range(len(src_neg)): + s = int(src_neg[i]) + d = int(dst_neg[i]) + assert ((edges[etype].src == s) & (edges[etype].dst == d)).sum() == 0 From 90e0956a98332377815b861368b3b406a5a11158 Mon Sep 17 00:00:00 2001 From: Alexandria Barghi Date: Fri, 18 Oct 2024 14:07:53 -0700 Subject: [PATCH 5/6] add mg negative sampling tests --- python/cugraph-dgl/cugraph_dgl/graph.py | 7 + .../cugraph-dgl/cugraph_dgl/tests/conftest.py | 16 +- .../cugraph_dgl/tests/test_graph_mg.py | 184 +++++++++++------- 3 files changed, 132 insertions(+), 75 deletions(-) diff --git a/python/cugraph-dgl/cugraph_dgl/graph.py b/python/cugraph-dgl/cugraph_dgl/graph.py index c065f14..92555a5 100644 --- a/python/cugraph-dgl/cugraph_dgl/graph.py +++ b/python/cugraph-dgl/cugraph_dgl/graph.py @@ -1046,6 +1046,13 @@ def global_uniform_negative_sampling( num_samples_global = int(num_samples_global) vertices = torch.tensor_split(vertices, world_size)[rank] + + src_bias = cupy.array_split(src_bias, world_size)[rank] + dst_bias = ( + src_bias + if can_edge_type[0] == can_edge_type[2] + else cupy.array_split(dst_bias, world_size)[rank] + ) else: num_samples_global = num_samples diff --git a/python/cugraph-dgl/cugraph_dgl/tests/conftest.py b/python/cugraph-dgl/cugraph_dgl/tests/conftest.py index 0f9f890..204539b 100644 --- a/python/cugraph-dgl/cugraph_dgl/tests/conftest.py +++ b/python/cugraph-dgl/cugraph_dgl/tests/conftest.py @@ -87,6 +87,13 @@ def create_karate_bipartite(multi_gpu: bool = False): node_x_1 = np.random.random((num_nodes_group_1,)) node_x_2 = np.random.random((num_nodes_group_2,)) + if multi_gpu: + rank = torch.distributed.get_rank() + world_size = torch.distributed.get_world_size() + + node_x_1 = np.array_split(node_x_1, world_size)[rank] + node_x_2 = np.array_split(node_x_2, world_size)[rank] + graph.add_nodes(num_nodes_group_1, {"x": node_x_1}, "type1") graph.add_nodes(num_nodes_group_2, {"x": node_x_2}, "type2") @@ -114,8 +121,8 @@ def create_karate_bipartite(multi_gpu: bool = False): world_size = torch.distributed.get_world_size() edges_local = { - etype: edf.iloc[np.array_split(np.arange(edf), world_size)[rank]] - for etype, edf in edges + etype: edf.iloc[np.array_split(np.arange(len(edf)), world_size)[rank]] + for etype, edf in edges.items() } else: edges_local = edges @@ -129,8 +136,3 @@ def create_karate_bipartite(multi_gpu: bool = False): @pytest.fixture def karate_bipartite(): return create_karate_bipartite(False) - - -@pytest.fixture -def karate_bipartite_mg(): - return create_karate_bipartite(True) diff --git a/python/cugraph-dgl/cugraph_dgl/tests/test_graph_mg.py b/python/cugraph-dgl/cugraph_dgl/tests/test_graph_mg.py index 489ad90..2b30f36 100644 --- a/python/cugraph-dgl/cugraph_dgl/tests/test_graph_mg.py +++ b/python/cugraph-dgl/cugraph_dgl/tests/test_graph_mg.py @@ -31,6 
+31,7 @@ ) from .utils import init_pytorch_worker +from .conftest import create_karate_bipartite pylibwholegraph = import_optional("pylibwholegraph") torch = import_optional("torch") @@ -122,8 +123,10 @@ def run_test_graph_make_homogeneous_graph_mg(rank, uid, world_size, direction): assert (d_out_actual == d_out_exp).all() cugraph_comms_shutdown() + torch.distributed.destroy_process_group() +@pytest.mark.skip(reason="bleh") @pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available") @pytest.mark.skipif( isinstance(pylibwholegraph, MissingModule), reason="wholegraph not available" @@ -287,8 +290,10 @@ def run_test_graph_make_heterogeneous_graph_mg(rank, uid, world_size, direction) assert len(f) > 0 # may be multiple, some could be on other GPU cugraph_comms_shutdown() + torch.distributed.destroy_process_group() +@pytest.mark.skip(reason="bleh") @pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available") @pytest.mark.skipif( isinstance(pylibwholegraph, MissingModule), reason="wholegraph not available" @@ -312,55 +317,8 @@ def test_graph_make_heterogeneous_graph_mg(direction): def run_test_graph_find_simple_mg(rank, world_size, uid, direction): init_pytorch_worker(rank, world_size, uid) - df = karate.get_edgelist() - - total_num_nodes = max(df.src.max(), df.dst.max()) + 1 - - num_nodes_group_1 = total_num_nodes // 2 - num_nodes_group_2 = total_num_nodes - num_nodes_group_1 - - node_x_1 = np.array_split(np.random.random((num_nodes_group_1,)), world_size)[rank] - node_x_2 = np.array_split(np.random.random((num_nodes_group_2,)), world_size)[rank] - - graph = cugraph_dgl.Graph(is_multi_gpu=True) - graph.add_nodes(num_nodes_group_1, {"x": node_x_1}, "type1") - graph.add_nodes(num_nodes_group_2, {"x": node_x_2}, "type2") - - edges_11 = df[(df.src < num_nodes_group_1) & (df.dst < num_nodes_group_1)] - edges_12 = df[(df.src < num_nodes_group_1) & (df.dst >= num_nodes_group_1)] - edges_21 = df[(df.src >= num_nodes_group_1) & (df.dst < num_nodes_group_1)] - edges_22 = df[(df.src >= num_nodes_group_1) & (df.dst >= num_nodes_group_1)] - edges_12.dst -= num_nodes_group_1 - edges_21.src -= num_nodes_group_1 - edges_22.dst -= num_nodes_group_1 - edges_22.src -= num_nodes_group_1 - - edges_11_local = edges_11.iloc[ - np.array_split(np.arange(len(edges_11)), world_size)[rank] - ] - edges_12_local = edges_12.iloc[ - np.array_split(np.arange(len(edges_12)), world_size)[rank] - ] - edges_21_local = edges_21.iloc[ - np.array_split(np.arange(len(edges_21)), world_size)[rank] - ] - edges_22_local = edges_22.iloc[ - np.array_split(np.arange(len(edges_22)), world_size)[rank] - ] - - graph.add_edges( - edges_11_local.src, edges_11_local.dst, etype=("type1", "e1", "type1") - ) - graph.add_edges( - edges_12_local.src, edges_12_local.dst, etype=("type1", "e2", "type2") - ) - graph.add_edges( - edges_21_local.src, edges_21_local.dst, etype=("type2", "e3", "type1") - ) - graph.add_edges( - edges_22_local.src, edges_22_local.dst, etype=("type2", "e4", "type2") - ) + graph, edges, _ = create_karate_bipartite(multi_gpu=True) # force direction generation to make sure in case is tested graph._graph(direction) @@ -368,15 +326,20 @@ def run_test_graph_find_simple_mg(rank, world_size, uid, direction): assert not graph.is_homogeneous assert graph.is_multi_gpu - if len(edges_11) > 0: + if len(edges[("type1", "e1", "type1")]) > 0: srcs, dsts = graph.find_edges( - torch.as_tensor([0, len(edges_11) - 1, 999], dtype=torch.int64), + torch.as_tensor( + [0, len(edges[("type1", "e1", 
"type1")]) - 1, 999], dtype=torch.int64 + ), ("type1", "e1", "type1"), ) assert ( srcs[[0, 1]] == torch.tensor( - [edges_11.src.iloc[0], edges_11.src.iloc[-1]], + [ + edges[("type1", "e1", "type1")].src.iloc[0], + edges[("type1", "e1", "type1")].src.iloc[-1], + ], device="cuda", dtype=torch.int64, ) @@ -384,21 +347,29 @@ def run_test_graph_find_simple_mg(rank, world_size, uid, direction): assert ( dsts[[0, 1]] == torch.tensor( - [edges_11.dst.iloc[0], edges_11.dst.iloc[-1]], + [ + edges[("type1", "e1", "type1")].dst.iloc[0], + edges[("type1", "e1", "type1")].dst.iloc[-1], + ], device="cuda", dtype=torch.int64, ) ).all() assert srcs[2] < 0 and dsts[2] < 0 - if len(edges_12) > 0: + if len(edges[("type1", "e2", "type2")]) > 0: srcs, dsts = graph.find_edges( - torch.as_tensor([0, len(edges_12) - 1, 999], dtype=torch.int64), + torch.as_tensor( + [0, len(edges[("type1", "e2", "type2")]) - 1, 999], dtype=torch.int64 + ), ("type1", "e2", "type2"), ) assert ( srcs[[0, 1]] == torch.tensor( - [edges_12.src.iloc[0], edges_12.src.iloc[-1]], + [ + edges[("type1", "e2", "type2")].src.iloc[0], + edges[("type1", "e2", "type2")].src.iloc[-1], + ], device="cuda", dtype=torch.int64, ) @@ -406,21 +377,29 @@ def run_test_graph_find_simple_mg(rank, world_size, uid, direction): assert ( dsts[[0, 1]] == torch.tensor( - [edges_12.dst.iloc[0], edges_12.dst.iloc[-1]], + [ + edges[("type1", "e2", "type2")].dst.iloc[0], + edges[("type1", "e2", "type2")].dst.iloc[-1], + ], device="cuda", dtype=torch.int64, ) ).all() assert srcs[2] < 0 and dsts[2] < 0 - if len(edges_21) > 0: + if len(edges[("type2", "e3", "type1")]) > 0: srcs, dsts = graph.find_edges( - torch.as_tensor([0, len(edges_21) - 1, 999], dtype=torch.int64), + torch.as_tensor( + [0, len(edges[("type2", "e3", "type1")]) - 1, 999], dtype=torch.int64 + ), ("type2", "e3", "type1"), ) assert ( srcs[[0, 1]] == torch.tensor( - [edges_21.src.iloc[0], edges_21.src.iloc[-1]], + [ + edges[("type2", "e3", "type1")].src.iloc[0], + edges[("type2", "e3", "type1")].src.iloc[-1], + ], device="cuda", dtype=torch.int64, ) @@ -428,21 +407,29 @@ def run_test_graph_find_simple_mg(rank, world_size, uid, direction): assert ( dsts[[0, 1]] == torch.tensor( - [edges_21.dst.iloc[0], edges_21.dst.iloc[-1]], + [ + edges[("type2", "e3", "type1")].dst.iloc[0], + edges[("type2", "e3", "type1")].dst.iloc[-1], + ], device="cuda", dtype=torch.int64, ) ).all() assert srcs[2] < 0 and dsts[2] < 0 - if len(edges_22) > 0: + if len(edges[("type2", "e4", "type2")]) > 0: srcs, dsts = graph.find_edges( - torch.as_tensor([0, len(edges_22) - 1, 999], dtype=torch.int64), + torch.as_tensor( + [0, len(edges[("type2", "e4", "type2")]) - 1, 999], dtype=torch.int64 + ), ("type2", "e4", "type2"), ) assert ( srcs[[0, 1]] == torch.tensor( - [edges_22.src.iloc[0], edges_22.src.iloc[-1]], + [ + edges[("type2", "e4", "type2")].src.iloc[0], + edges[("type2", "e4", "type2")].src.iloc[-1], + ], device="cuda", dtype=torch.int64, ) @@ -450,22 +437,25 @@ def run_test_graph_find_simple_mg(rank, world_size, uid, direction): assert ( dsts[[0, 1]] == torch.tensor( - [edges_22.dst.iloc[0], edges_22.dst.iloc[-1]], + [ + edges[("type2", "e4", "type2")].dst.iloc[0], + edges[("type2", "e4", "type2")].dst.iloc[-1], + ], device="cuda", dtype=torch.int64, ) ).all() assert srcs[2] < 0 and dsts[2] < 0 + cugraph_comms_shutdown() + torch.distributed.destroy_process_group() + +@pytest.mark.skip(reason="bleh") @pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available") @pytest.mark.skipif(isinstance(dgl, MissingModule), 
reason="dgl not available") @pytest.mark.parametrize("direction", ["out", "in"]) def test_graph_find_mg(direction): - df = karate.get_edgelist() - df.src = df.src.astype("int64") - df.dst = df.dst.astype("int64") - uid = cugraph_comms_create_unique_id() world_size = torch.cuda.device_count() @@ -478,3 +468,61 @@ def test_graph_find_mg(direction): ), nprocs=world_size, ) + + +def run_test_uniform_negative_sample_mg( + rank, world_size, uid, exclude_self_loops, num_samples_per_worker +): + init_pytorch_worker(rank, world_size, uid) + + graph, edges, _ = create_karate_bipartite(multi_gpu=True) + + assert not graph.is_homogeneous + assert graph.is_multi_gpu + + for etype in graph.canonical_etypes: + src_neg, dst_neg = graph.global_uniform_negative_sampling( + num_samples_per_worker, + exclude_self_loops=exclude_self_loops, + etype=etype, + ) + + assert len(src_neg) == len(dst_neg) + assert len(src_neg) <= num_samples_per_worker + + assert (src_neg >= 0).all() + assert (dst_neg >= 0).all() + + assert (src_neg < graph.num_nodes(etype[0])).all() + assert (dst_neg < graph.num_nodes(etype[2])).all() + + if exclude_self_loops: + assert (src_neg == dst_neg).sum() == 0 + + for i in range(len(src_neg)): + s = int(src_neg[i]) + d = int(dst_neg[i]) + assert ((edges[etype].src == s) & (edges[etype].dst == d)).sum() == 0 + + cugraph_comms_shutdown() + torch.distributed.destroy_process_group() + + +@pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available") +@pytest.mark.skipif(isinstance(dgl, MissingModule), reason="dgl not available") +@pytest.mark.parametrize("exclude_self_loops", [True, False]) +@pytest.mark.parametrize("num_samples_per_worker", [1, 2, 5]) +def test_graph_uniform_negative_sample_mg(exclude_self_loops, num_samples_per_worker): + uid = cugraph_comms_create_unique_id() + world_size = torch.cuda.device_count() + + torch.multiprocessing.spawn( + run_test_uniform_negative_sample_mg, + args=( + world_size, + uid, + exclude_self_loops, + num_samples_per_worker, + ), + nprocs=world_size, + ) From 70a84a1795fca4871a64f88f2f51412f556c64e1 Mon Sep 17 00:00:00 2001 From: Alexandria Barghi Date: Fri, 18 Oct 2024 14:08:34 -0700 Subject: [PATCH 6/6] remove skips --- python/cugraph-dgl/cugraph_dgl/tests/test_graph_mg.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/python/cugraph-dgl/cugraph_dgl/tests/test_graph_mg.py b/python/cugraph-dgl/cugraph_dgl/tests/test_graph_mg.py index 2b30f36..0a8757b 100644 --- a/python/cugraph-dgl/cugraph_dgl/tests/test_graph_mg.py +++ b/python/cugraph-dgl/cugraph_dgl/tests/test_graph_mg.py @@ -126,7 +126,6 @@ def run_test_graph_make_homogeneous_graph_mg(rank, uid, world_size, direction): torch.distributed.destroy_process_group() -@pytest.mark.skip(reason="bleh") @pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available") @pytest.mark.skipif( isinstance(pylibwholegraph, MissingModule), reason="wholegraph not available" @@ -293,7 +292,6 @@ def run_test_graph_make_heterogeneous_graph_mg(rank, uid, world_size, direction) torch.distributed.destroy_process_group() -@pytest.mark.skip(reason="bleh") @pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available") @pytest.mark.skipif( isinstance(pylibwholegraph, MissingModule), reason="wholegraph not available" @@ -451,7 +449,6 @@ def run_test_graph_find_simple_mg(rank, world_size, uid, direction): torch.distributed.destroy_process_group() -@pytest.mark.skip(reason="bleh") @pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available") 
@pytest.mark.skipif(isinstance(dgl, MissingModule), reason="dgl not available") @pytest.mark.parametrize("direction", ["out", "in"])
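
End to end, the find_edges path added in PATCH 1/6 (and renamed to
lookup_vertex_ids in PATCH 2/6) behaves as in the sketch below, which condenses
test_graph_find; create_karate_bipartite is the conftest.py helper introduced
in PATCH 4/6, and the import path here assumes the same package layout the MG
tests use for their relative import:

import torch

from cugraph_dgl.tests.conftest import create_karate_bipartite

graph, edges, _ = create_karate_bipartite(multi_gpu=False)

# Edge ids are local to the canonical edge type; 999 is out of range here.
eids = torch.tensor([0, 1, 999], dtype=torch.int64)
srcs, dsts = graph.find_edges(eids, ("type1", "e1", "type1"))

# Hits come back de-offset into per-type vertex ids ...
assert (srcs[:2] >= 0).all() and (dsts[:2] >= 0).all()
# ... while ids with no matching edge are returned as negative sentinels.
assert srcs[2] < 0 and dsts[2] < 0
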
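Similarly, a condensed sketch of the global_uniform_negative_sampling flow
exercised by test_graph_uniform_negative_sample (single GPU; the sample count
and edge type are arbitrary choices, not fixed by the patch):

import torch

from cugraph_dgl.tests.conftest import create_karate_bipartite

graph, edges, _ = create_karate_bipartite(multi_gpu=False)

src_neg, dst_neg = graph.global_uniform_negative_sampling(
    10, exclude_self_loops=True, etype=("type1", "e2", "type2")
)

# At most the requested number of pairs is returned, already de-offset
# into per-type vertex ids (the PATCH 4/6 fix).
assert len(src_neg) == len(dst_neg)
assert len(src_neg) <= 10
assert (src_neg < graph.num_nodes("type1")).all()
assert (dst_neg < graph.num_nodes("type2")).all()

# None of the generated pairs exists in the positive edge list.
e = edges[("type1", "e2", "type2")]
for s, d in zip(src_neg.tolist(), dst_neg.tolist()):
    assert ((e.src == s) & (e.dst == d)).sum() == 0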