Skip to content

Commit

Permalink
simplify test, add comment, change function name
Browse files Browse the repository at this point in the history
  • Loading branch information
alexbarghi-nv committed Oct 18, 2024
1 parent 9c4c15b commit 340a488
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 35 deletions.
2 changes: 1 addition & 1 deletion python/pylibcugraph/pylibcugraph/edge_id_lookup_table.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ cdef class EdgeIdLookupTable:
if self.lookup_container_c_ptr is not NULL:
cugraph_lookup_container_free(self.lookup_container_c_ptr)

def find(
def lookup_vertex_ids(
self,
edge_ids,
int edge_type
Expand Down
48 changes: 14 additions & 34 deletions python/pylibcugraph/pylibcugraph/tests/test_lookup_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# limitations under the License.

import cupy
import numpy as np

from pylibcugraph import (
SGGraph,
Expand All @@ -34,43 +33,24 @@


def test_lookup_table():

# Vertex id array
vtcs = cupy.arange(6, dtype="int64")
vtps = np.array([0, 0, 1, 1, 2, 2])

e_lookup = {
(0, 0): [0, 0],
(0, 1): [1, 0],
(0, 2): [2, 0],
(1, 0): [3, 0],
(1, 1): [4, 0],
(1, 2): [5, 0],
(2, 0): [6, 0],
(2, 1): [7, 0],
(2, 2): [8, 0],
}

srcs = np.array([0, 1, 5, 4, 3, 2, 2, 0, 5, 4, 4, 5])
dsts = np.array([1, 5, 0, 3, 2, 1, 3, 3, 2, 3, 1, 4])
wgts = cupy.ones((len(srcs),), dtype="float32")

eids = []
etps = []
for i in range(len(srcs)):
key = (int(vtps[srcs[i]]), int(vtps[dsts[i]]))
etps.append(e_lookup[key][0])
eids.append(e_lookup[key][1])
# Edge ids are unique per edge type and start from 0
# Each edge type has the same src/dst vertex type here,
# just as it would in a GNN application.
srcs = cupy.array([0, 1, 5, 4, 3, 2, 2, 0, 5, 4, 4, 5])
dsts = cupy.array([1, 5, 0, 3, 2, 1, 3, 3, 2, 3, 1, 4])
etps = cupy.array([0, 2, 6, 7, 4, 3, 4, 1, 7, 7, 6, 8], dtype="int32")
eids = cupy.array([0, 0, 0, 0, 0, 0, 1, 0, 1, 2, 1, 0])

e_lookup[key][1] += 1

eids = cupy.array(eids)
etps = cupy.array(etps, dtype="int32")
wgts = cupy.ones((len(srcs),), dtype="float32")

graph = SGGraph(
resource_handle=ResourceHandle(),
graph_properties=GraphProperties(is_symmetric=False, is_multigraph=True),
src_or_offset_array=cupy.array(srcs),
dst_or_index_array=cupy.array(dsts),
src_or_offset_array=srcs,
dst_or_index_array=dsts,
vertices_array=vtcs,
weight_array=wgts,
edge_id_array=eids,
Expand All @@ -84,15 +64,15 @@ def test_lookup_table():

assert table is not None

found_edges = table.find(cupy.array([0, 1, 2, 3, 4]), 7)
found_edges = table.lookup_vertex_ids(cupy.array([0, 1, 2, 3, 4]), 7)
assert (found_edges["sources"] == cupy.array([4, 5, 4, -1, -1])).all()
assert (found_edges["destinations"] == cupy.array([3, 2, 3, -1, -1])).all()

found_edges = table.find(cupy.array([0]), 5)
found_edges = table.lookup_vertex_ids(cupy.array([0]), 5)
assert (found_edges["sources"] == cupy.array([-1])).all()
assert (found_edges["destinations"] == cupy.array([-1])).all()

found_edges = table.find(cupy.array([3, 1, 0, 5]), 6)
found_edges = table.lookup_vertex_ids(cupy.array([3, 1, 0, 5]), 6)
assert (found_edges["sources"] == cupy.array([-1, 4, 5, -1])).all()
assert (found_edges["destinations"] == cupy.array([-1, 1, 0, -1])).all()

Expand Down

0 comments on commit 340a488

Please sign in to comment.