From 1c742c8b649aba45abbbf46fbea45ed8185c4d29 Mon Sep 17 00:00:00 2001 From: Tingyu Wang Date: Wed, 15 May 2024 19:56:48 -0400 Subject: [PATCH] fix num_sampled_nodes/edges usage in tests --- .../cugraph_pyg/tests/test_cugraph_sampler.py | 32 +++++++++---------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/python/cugraph-pyg/cugraph_pyg/tests/test_cugraph_sampler.py b/python/cugraph-pyg/cugraph_pyg/tests/test_cugraph_sampler.py index e703d477b70..ed011a658a9 100644 --- a/python/cugraph-pyg/cugraph_pyg/tests/test_cugraph_sampler.py +++ b/python/cugraph-pyg/cugraph_pyg/tests/test_cugraph_sampler.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2023, NVIDIA CORPORATION. +# Copyright (c) 2022-2024, NVIDIA CORPORATION. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -80,10 +80,10 @@ def test_neighbor_sample(basic_graph_1): # check the hop dictionaries assert len(out.num_sampled_nodes) == 1 - assert out.num_sampled_nodes["vt1"].tolist() == [4, 1] + assert out.num_sampled_nodes["vt1"] == [4, 1] assert len(out.num_sampled_edges) == 1 - assert out.num_sampled_edges[("vt1", "pig", "vt1")].tolist() == [6] + assert out.num_sampled_edges[("vt1", "pig", "vt1")] == [6] @pytest.mark.cugraph_ops @@ -136,15 +136,15 @@ def test_neighbor_sample_multi_vertex(multi_edge_multi_vertex_graph_1): # check the hop dictionaries assert len(out.num_sampled_nodes) == 2 - assert out.num_sampled_nodes["black"].tolist() == [2, 0] - assert out.num_sampled_nodes["brown"].tolist() == [3, 0] + assert out.num_sampled_nodes["black"] == [2, 0] + assert out.num_sampled_nodes["brown"] == [3, 0] assert len(out.num_sampled_edges) == 5 - assert out.num_sampled_edges[("brown", "horse", "brown")].tolist() == [2] - assert out.num_sampled_edges[("brown", "tortoise", "black")].tolist() == [3] - assert out.num_sampled_edges[("brown", "mongoose", "black")].tolist() == [2] - assert out.num_sampled_edges[("black", "cow", "brown")].tolist() == [2] - assert out.num_sampled_edges[("black", "snake", "black")].tolist() == [1] + assert out.num_sampled_edges[("brown", "horse", "brown")] == [2] + assert out.num_sampled_edges[("brown", "tortoise", "black")] == [3] + assert out.num_sampled_edges[("brown", "mongoose", "black")] == [2] + assert out.num_sampled_edges[("black", "cow", "brown")] == [2] + assert out.num_sampled_edges[("black", "snake", "black")] == [1] @pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available") @@ -183,14 +183,14 @@ def test_neighbor_sample_mock_sampling_results(abc_graph): assert out.col[("B", "ba", "A")].tolist() == [1, 1] assert len(out.num_sampled_nodes) == 3 - assert out.num_sampled_nodes["A"].tolist() == [2, 0, 0, 0, 0] - assert out.num_sampled_nodes["B"].tolist() == [0, 2, 0, 0, 0] - assert out.num_sampled_nodes["C"].tolist() == [0, 0, 2, 0, 1] + assert out.num_sampled_nodes["A"] == [2, 0, 0, 0, 0] + assert out.num_sampled_nodes["B"] == [0, 2, 0, 0, 0] + assert out.num_sampled_nodes["C"] == [0, 0, 2, 0, 1] assert len(out.num_sampled_edges) == 3 - assert out.num_sampled_edges[("A", "ab", "B")].tolist() == [3, 0, 1, 0] - assert out.num_sampled_edges[("B", "ba", "A")].tolist() == [0, 1, 0, 1] - assert out.num_sampled_edges[("B", "bc", "C")].tolist() == [0, 2, 0, 2] + assert out.num_sampled_edges[("A", "ab", "B")] == [3, 0, 1, 0] + assert out.num_sampled_edges[("B", "ba", "A")] == [0, 1, 0, 1] + assert out.num_sampled_edges[("B", "bc", "C")] == [0, 2, 0, 2] @pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available")