Skip to content

Commit

Permalink
cputest
Browse files Browse the repository at this point in the history
  • Loading branch information
snarayan21 committed Aug 6, 2024
1 parent 7b2ae26 commit 3ee41e7
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions tests/metrics/test_nlp_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import pytest
import torch
from packaging import version
from torch.nn.functional import cross_entropy

from composer.metrics.nlp import (
Expand Down Expand Up @@ -82,8 +83,11 @@ def test_cross_entropy(
tensor_device (str): which device the input tensors to the metric are on
"""

if device == 'cpu' and tensor_device == 'gpu':
pytest.skip('Skipping test that would try to use GPU tensors when only CPU is available.')
if device == 'cpu':
if tensor_device == 'gpu':
pytest.skip('Skipping test that would try to use GPU tensors when only CPU is available.')
if version.parse(torch.__version__) < version.parse('2.3.0'):
pytest.skip('Skipping test that would try to use gloo + nccl backend on torch < 2.3.0.')

batch_size = int(batch_size)
generated_preds = torch.randn((batch_size, sequence_length, num_classes))
Expand Down

0 comments on commit 3ee41e7

Please sign in to comment.