Skip to content

Commit

Permalink
[BUG] Set Max PyTorch Version, Skip 11.4 Tests Using WholeGraph (#4808)
Browse files Browse the repository at this point in the history
Set Max PyTorch Version, Skip 11.4 Tests Using WholeGraph

Authors:
  - Alex Barghi (https://github.com/alexbarghi-nv)

Approvers:
  - Bradley Dice (https://github.com/bdice)
  - Rick Ratzel (https://github.com/rlratzel)

URL: #4808
  • Loading branch information
alexbarghi-nv authored Dec 6, 2024
1 parent 5956d4d commit 58075dd
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 3 deletions.
2 changes: 1 addition & 1 deletion conda/environments/all_cuda-118_arch-x86_64.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ dependencies:
- pytest-cov
- pytest-xdist
- python-louvain
- pytorch>=2.3
- pytorch>=2.3,<2.5a0
- raft-dask==24.12.*,>=0.0.0a0
- rapids-build-backend>=0.3.1,<0.4.0.dev0
- rapids-dask-dependency==24.12.*,>=0.0.0a0
Expand Down
2 changes: 1 addition & 1 deletion conda/environments/all_cuda-125_arch-x86_64.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ dependencies:
- pytest-cov
- pytest-xdist
- python-louvain
- pytorch>=2.3
- pytorch>=2.3,<2.5a0
- raft-dask==24.12.*,>=0.0.0a0
- rapids-build-backend>=0.3.1,<0.4.0.dev0
- rapids-dask-dependency==24.12.*,>=0.0.0a0
Expand Down
2 changes: 1 addition & 1 deletion dependencies.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -507,7 +507,7 @@ dependencies:
common:
- output_types: [conda]
packages:
- &pytorch_conda pytorch>=2.3
- &pytorch_conda pytorch>=2.3,<2.5a0
- torchdata
- pydantic
- ogb
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
import numpy as np
import os

import numba.cuda

from cugraph.gnn import FeatureStore

from cugraph.utilities.utils import import_optional, MissingModule
Expand All @@ -25,6 +27,11 @@
wgth = import_optional("pylibwholegraph.torch")


def get_cudart_version():
major, minor = numba.cuda.runtime.get_version()
return major * 1000 + minor * 10


def runtest(rank: int, world_size: int):
torch.cuda.set_device(rank)

Expand Down Expand Up @@ -66,6 +73,9 @@ def runtest(rank: int, world_size: int):
@pytest.mark.skipif(
isinstance(pylibwholegraph, MissingModule), reason="wholegraph not available"
)
@pytest.mark.skipif(
get_cudart_version() < 11080, reason="not compatible with CUDA < 11.8"
)
def test_feature_storage_wholegraph_backend():
world_size = torch.cuda.device_count()
print("gpu count:", world_size)
Expand All @@ -81,6 +91,9 @@ def test_feature_storage_wholegraph_backend():
@pytest.mark.skipif(
isinstance(pylibwholegraph, MissingModule), reason="wholegraph not available"
)
@pytest.mark.skipif(
get_cudart_version() < 11080, reason="not compatible with CUDA < 11.8"
)
def test_feature_storage_wholegraph_backend_mg():
world_size = torch.cuda.device_count()
print("gpu count:", world_size)
Expand Down

0 comments on commit 58075dd

Please sign in to comment.