From 3ee41e7390f8ae2108944f24ac2d737f5ee4d716 Mon Sep 17 00:00:00 2001 From: Saaketh Date: Mon, 5 Aug 2024 20:52:03 -0400 Subject: [PATCH] cputest --- tests/metrics/test_nlp_metrics.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/metrics/test_nlp_metrics.py b/tests/metrics/test_nlp_metrics.py index 9b198003d3..0f6d989102 100644 --- a/tests/metrics/test_nlp_metrics.py +++ b/tests/metrics/test_nlp_metrics.py @@ -6,6 +6,7 @@ import pytest import torch +from packaging import version from torch.nn.functional import cross_entropy from composer.metrics.nlp import ( @@ -82,8 +83,11 @@ def test_cross_entropy( tensor_device (str): which device the input tensors to the metric are on """ - if device == 'cpu' and tensor_device == 'gpu': - pytest.skip('Skipping test that would try to use GPU tensors when only CPU is available.') + if device == 'cpu': + if tensor_device == 'gpu': + pytest.skip('Skipping test that would try to use GPU tensors when only CPU is available.') + if version.parse(torch.__version__) < version.parse('2.3.0'): + pytest.skip('Skipping test that would try to use gloo + nccl backend on torch < 2.3.0.') batch_size = int(batch_size) generated_preds = torch.randn((batch_size, sequence_length, num_classes))