Skip to content

Commit

Permalink
catch all types of error when importing torch_sparse
Browse files Browse the repository at this point in the history
  • Loading branch information
tingyu66 committed Nov 1, 2023
1 parent b1fc14d commit 9823649
Showing 1 changed file with 10 additions and 7 deletions.
17 changes: 10 additions & 7 deletions python/cugraph-pyg/cugraph_pyg/tests/test_cugraph_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,16 @@

torch = import_optional("torch")
torch_geometric = import_optional("torch_geometric")
torch_sparse = import_optional("torch_sparse")

trim_to_layer = import_optional("torch_geometric.utils.trim_to_layer")

try:
import torch_sparse # noqa: F401

HAS_TORCH_SPARSE = True
except: # noqa: E722
HAS_TORCH_SPARSE = False


@pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available")
def test_cugraph_loader_basic(karate_gnn):
Expand Down Expand Up @@ -201,9 +208,7 @@ def test_cugraph_loader_from_disk_subset():


@pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available")
@pytest.mark.skipif(
isinstance(torch_sparse, MissingModule), reason="torch-sparse not available"
)
@pytest.mark.skipif(not HAS_TORCH_SPARSE, reason="torch-sparse not available")
def test_cugraph_loader_from_disk_subset_csr():
m = [2, 9, 99, 82, 11, 13]
n = torch.arange(1, 1 + len(m), dtype=torch.int32)
Expand Down Expand Up @@ -336,9 +341,7 @@ def test_cugraph_loader_e2e_coo():


@pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available")
@pytest.mark.skipif(
isinstance(torch_sparse, MissingModule), reason="torch-sparse not available"
)
@pytest.mark.skipif(not HAS_TORCH_SPARSE, reason="torch-sparse not available")
@pytest.mark.parametrize("framework", ["pyg", "cugraph-ops"])
def test_cugraph_loader_e2e_csc(framework):
m = [2, 9, 99, 82, 9, 3, 18, 1, 12]
Expand Down

0 comments on commit 9823649

Please sign in to comment.