diff --git a/python/cugraph-pyg/cugraph_pyg/tests/test_cugraph_loader.py b/python/cugraph-pyg/cugraph_pyg/tests/test_cugraph_loader.py index 836b30c9df7..2a79bf203e2 100644 --- a/python/cugraph-pyg/cugraph_pyg/tests/test_cugraph_loader.py +++ b/python/cugraph-pyg/cugraph_pyg/tests/test_cugraph_loader.py @@ -29,9 +29,16 @@ torch = import_optional("torch") torch_geometric = import_optional("torch_geometric") -torch_sparse = import_optional("torch_sparse") + trim_to_layer = import_optional("torch_geometric.utils.trim_to_layer") +try: + import torch_sparse # noqa: F401 + + HAS_TORCH_SPARSE = True +except: # noqa: E722 + HAS_TORCH_SPARSE = False + @pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available") def test_cugraph_loader_basic(karate_gnn): @@ -201,9 +208,7 @@ def test_cugraph_loader_from_disk_subset(): @pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available") -@pytest.mark.skipif( - isinstance(torch_sparse, MissingModule), reason="torch-sparse not available" -) +@pytest.mark.skipif(not HAS_TORCH_SPARSE, reason="torch-sparse not available") def test_cugraph_loader_from_disk_subset_csr(): m = [2, 9, 99, 82, 11, 13] n = torch.arange(1, 1 + len(m), dtype=torch.int32) @@ -336,9 +341,7 @@ def test_cugraph_loader_e2e_coo(): @pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available") -@pytest.mark.skipif( - isinstance(torch_sparse, MissingModule), reason="torch-sparse not available" -) +@pytest.mark.skipif(not HAS_TORCH_SPARSE, reason="torch-sparse not available") @pytest.mark.parametrize("framework", ["pyg", "cugraph-ops"]) def test_cugraph_loader_e2e_csc(framework): m = [2, 9, 99, 82, 9, 3, 18, 1, 12]