Skip to content

Commit

Permalink
Merge branch 'pyg-neg-sampling' of https://github.com/alexbarghi-nv/c…
Browse files Browse the repository at this point in the history
…ugraph into pyg-neg-sampling
  • Loading branch information
alexbarghi-nv committed Sep 20, 2024
2 parents ecf4230 + 8f0264f commit 9b6d759
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 0 deletions.
9 changes: 9 additions & 0 deletions python/nx-cugraph/nx_cugraph/classes/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -689,6 +689,14 @@ def _get_plc_graph(
src_indices = src_indices.astype(index_dtype)
dst_indices = dst_indices.astype(index_dtype)

# This sets drop_multi_edges=True for non-multigraph input, which means
# the data in self.src_indices and self.dst_indices may not be
# identical to that contained in the returned plc.SGGraph (the returned
# SGGraph may have fewer edges since duplicates are dropped). Ideally
# self.src_indices and self.dst_indices would be updated to have
# duplicate edges removed for non-multigraph instances, but that
# requires additional code which would be redundant and likely not as
# performant as the code in PLC.
return plc.SGGraph(
resource_handle=plc.ResourceHandle(),
graph_properties=plc.GraphProperties(
Expand All @@ -702,6 +710,7 @@ def _get_plc_graph(
renumber=False,
do_expensive_check=False,
vertices_array=self._node_ids,
drop_multi_edges=not self.is_multigraph(),
)

def _sort_edge_indices(self, primary="src"):
Expand Down
36 changes: 36 additions & 0 deletions python/nx-cugraph/nx_cugraph/tests/test_pagerank.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Copyright (c) 2024, NVIDIA CORPORATION.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import networkx as nx
import pandas as pd
from pytest import approx


def test_pagerank_multigraph():
    """
    Compare cugraph-backend pagerank against default NetworkX pagerank for
    a Graph and a MultiGraph built with from_pandas_edgelist().

    The edge list below repeats the (1, 2) edge six times, so the same
    input is exercised both as a Graph and as a MultiGraph. Each backend
    result must match the reference computed by default NetworkX on the
    same kind of graph.
    """
    edgelist = pd.DataFrame(
        {
            "source": [0, 1, 1, 1, 1, 1, 1, 2],
            "target": [1, 2, 2, 2, 2, 2, 2, 3],
        }
    )

    # Reference values computed without the cugraph backend.
    nx_graph = nx.from_pandas_edgelist(edgelist)
    expected_pr_graph = nx.pagerank(nx_graph)
    nx_multigraph = nx.from_pandas_edgelist(edgelist, create_using=nx.MultiGraph)
    expected_pr_multigraph = nx.pagerank(nx_multigraph)

    # Same graphs and algorithm, dispatched to the cugraph backend.
    cugraph_graph = nx.from_pandas_edgelist(edgelist, backend="cugraph")
    actual_pr_graph = nx.pagerank(cugraph_graph, backend="cugraph")

    cugraph_multigraph = nx.from_pandas_edgelist(
        edgelist, create_using=nx.MultiGraph, backend="cugraph"
    )
    actual_pr_multigraph = nx.pagerank(cugraph_multigraph, backend="cugraph")

    assert actual_pr_graph == approx(expected_pr_graph)
    assert actual_pr_multigraph == approx(expected_pr_multigraph)

0 comments on commit 9b6d759

Please sign in to comment.