Skip to content

Commit

Permalink
add test using karate dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
tingyu66 committed Sep 28, 2023
1 parent 564ddb4 commit 9e73617
Showing 1 changed file with 28 additions and 0 deletions.
28 changes: 28 additions & 0 deletions python/cugraph-dgl/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
create_homogeneous_sampled_graphs_from_dataframe,
_get_source_destination_range,
_create_homogeneous_cugraph_dgl_nn_sparse_graph,
create_homogeneous_sampled_graphs_from_dataframe_csc,
)
from cugraph.utilities.utils import import_optional

Expand Down Expand Up @@ -50,6 +51,23 @@ def get_dummy_sampled_df():
return df


def get_dummy_sampled_df_csc():
df_dict = dict(
minors=np.array(
[1, 1, 2, 1, 0, 3, 1, 3, 2, 3, 2, 4, 0, 1, 1, 0, 3, 2], dtype=np.int32
),
major_offsets=np.arange(19, dtype=np.int64),
map=np.array(
[26, 29, 33, 22, 23, 32, 18, 29, 33, 33, 8, 30, 32], dtype=np.int32
),
renumber_map_offsets=np.array([0, 4, 9, 13], dtype=np.int64),
label_hop_offsets=np.array([0, 1, 3, 6, 7, 9, 13, 14, 16, 18], dtype=np.int64),
)

# convert values to Series so that NaNs are padded automatically
return cudf.DataFrame({k: cudf.Series(v) for k, v in df_dict.items()})


def test_get_renumber_map():

sampled_df = get_dummy_sampled_df()
Expand Down Expand Up @@ -176,3 +194,13 @@ def test__create_homogeneous_cugraph_dgl_nn_sparse_graph():
assert sparse_graph.num_src_nodes() == 2
assert sparse_graph.num_dst_nodes() == seednodes_range + 1
assert isinstance(sparse_graph, cugraph_dgl.nn.SparseGraph)


def test_create_homogeneous_sampled_graphs_from_dataframe_csc():
df = get_dummy_sampled_df_csc()
batches = create_homogeneous_sampled_graphs_from_dataframe_csc(df)

assert len(batches) == 3
assert torch.equal(batches[0][0], torch.IntTensor([26, 29, 33, 22]).cuda())
assert torch.equal(batches[1][0], torch.IntTensor([23, 32, 18, 29, 33]).cuda())
assert torch.equal(batches[2][0], torch.IntTensor([33, 8, 30, 32]).cuda())

0 comments on commit 9e73617

Please sign in to comment.