Skip to content

Commit

Permalink
fix num_sampled_nodes/edges usage in tests
Browse files Browse the repository at this point in the history
  • Loading branch information
tingyu66 committed May 15, 2024
1 parent 814daaa commit 1c742c8
Showing 1 changed file with 16 additions and 16 deletions.
32 changes: 16 additions & 16 deletions python/cugraph-pyg/cugraph_pyg/tests/test_cugraph_sampler.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2022-2023, NVIDIA CORPORATION.
# Copyright (c) 2022-2024, NVIDIA CORPORATION.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
Expand Down Expand Up @@ -80,10 +80,10 @@ def test_neighbor_sample(basic_graph_1):

# check the hop dictionaries
assert len(out.num_sampled_nodes) == 1
assert out.num_sampled_nodes["vt1"].tolist() == [4, 1]
assert out.num_sampled_nodes["vt1"] == [4, 1]

assert len(out.num_sampled_edges) == 1
assert out.num_sampled_edges[("vt1", "pig", "vt1")].tolist() == [6]
assert out.num_sampled_edges[("vt1", "pig", "vt1")] == [6]


@pytest.mark.cugraph_ops
Expand Down Expand Up @@ -136,15 +136,15 @@ def test_neighbor_sample_multi_vertex(multi_edge_multi_vertex_graph_1):

# check the hop dictionaries
assert len(out.num_sampled_nodes) == 2
assert out.num_sampled_nodes["black"].tolist() == [2, 0]
assert out.num_sampled_nodes["brown"].tolist() == [3, 0]
assert out.num_sampled_nodes["black"] == [2, 0]
assert out.num_sampled_nodes["brown"] == [3, 0]

assert len(out.num_sampled_edges) == 5
assert out.num_sampled_edges[("brown", "horse", "brown")].tolist() == [2]
assert out.num_sampled_edges[("brown", "tortoise", "black")].tolist() == [3]
assert out.num_sampled_edges[("brown", "mongoose", "black")].tolist() == [2]
assert out.num_sampled_edges[("black", "cow", "brown")].tolist() == [2]
assert out.num_sampled_edges[("black", "snake", "black")].tolist() == [1]
assert out.num_sampled_edges[("brown", "horse", "brown")] == [2]
assert out.num_sampled_edges[("brown", "tortoise", "black")] == [3]
assert out.num_sampled_edges[("brown", "mongoose", "black")] == [2]
assert out.num_sampled_edges[("black", "cow", "brown")] == [2]
assert out.num_sampled_edges[("black", "snake", "black")] == [1]


@pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available")
Expand Down Expand Up @@ -183,14 +183,14 @@ def test_neighbor_sample_mock_sampling_results(abc_graph):
assert out.col[("B", "ba", "A")].tolist() == [1, 1]

assert len(out.num_sampled_nodes) == 3
assert out.num_sampled_nodes["A"].tolist() == [2, 0, 0, 0, 0]
assert out.num_sampled_nodes["B"].tolist() == [0, 2, 0, 0, 0]
assert out.num_sampled_nodes["C"].tolist() == [0, 0, 2, 0, 1]
assert out.num_sampled_nodes["A"] == [2, 0, 0, 0, 0]
assert out.num_sampled_nodes["B"] == [0, 2, 0, 0, 0]
assert out.num_sampled_nodes["C"] == [0, 0, 2, 0, 1]

assert len(out.num_sampled_edges) == 3
assert out.num_sampled_edges[("A", "ab", "B")].tolist() == [3, 0, 1, 0]
assert out.num_sampled_edges[("B", "ba", "A")].tolist() == [0, 1, 0, 1]
assert out.num_sampled_edges[("B", "bc", "C")].tolist() == [0, 2, 0, 2]
assert out.num_sampled_edges[("A", "ab", "B")] == [3, 0, 1, 0]
assert out.num_sampled_edges[("B", "ba", "A")] == [0, 1, 0, 1]
assert out.num_sampled_edges[("B", "bc", "C")] == [0, 2, 0, 2]


@pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available")
Expand Down

0 comments on commit 1c742c8

Please sign in to comment.