tests: revert change of require_torch_multi_gpu to be device agnostic
Commit 11c27dd modified `require_torch_multi_gpu()` to be device agnostic
instead of CUDA specific. This broke tests that are rightfully
CUDA specific, such as:

* `tests/trainer/test_trainer_distributed.py::TestTrainerDistributed`

In the current Transformers test architecture, `require_torch_multi_accelerator()`
should be used to mark multi-device tests that are agnostic to the device.

This change addresses the issue introduced by 11c27dd and reverts the
modification of `require_torch_multi_gpu()`.

Fixes: 11c27dd ("Enable BNB multi-backend support (huggingface#31098)")
Signed-off-by: Dmitry Rogozhkin <[email protected]>
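
As a usage sketch (the test classes below are illustrative, not taken from the repository): a test that genuinely depends on CUDA keeps `@require_torch_multi_gpu`, while a device-agnostic multi-device test uses `@require_torch_multi_accelerator`:

    import unittest

    from transformers.testing_utils import (
        require_torch_multi_accelerator,
        require_torch_multi_gpu,
    )


    @require_torch_multi_gpu  # skipped unless more than one CUDA GPU is visible
    class CudaSpecificDistributedTest(unittest.TestCase):
        def test_nccl_collectives(self):
            ...  # exercises CUDA/NCCL-specific behavior


    @require_torch_multi_accelerator  # skipped unless more than one accelerator is visible
    class DeviceAgnosticMultiDeviceTest(unittest.TestCase):
        def test_model_parallel_forward(self):
            ...  # runs on any supported multi-device backend (CUDA, XPU, ...)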
dvrogozh committed Jan 16, 2025
1 parent 99e0ab6 commit bef1b1a
Showing 3 changed files with 9 additions and 20 deletions.
19 changes: 4 additions & 15 deletions src/transformers/testing_utils.py
@@ -237,17 +237,6 @@ def parse_int_from_env(key, default=None):
 _run_third_party_device_tests = parse_flag_from_env("RUN_THIRD_PARTY_DEVICE_TESTS", default=False)
 
 
-def get_device_count():
-    import torch
-
-    if is_torch_xpu_available():
-        num_devices = torch.xpu.device_count()
-    else:
-        num_devices = torch.cuda.device_count()
-
-    return num_devices
-
-
 def is_pt_tf_cross_test(test_case):
     """
     Decorator marking a test as a test that control interactions between PyTorch and TensorFlow.
@@ -770,17 +759,17 @@ def require_spacy(test_case):
 
 def require_torch_multi_gpu(test_case):
     """
-    Decorator marking a test that requires a multi-GPU setup (in PyTorch). These tests are skipped on a machine without
-    multiple GPUs.
+    Decorator marking a test that requires a multi-GPU CUDA setup (in PyTorch). These tests are skipped on a machine without
+    multiple CUDA GPUs.
 
     To run *only* the multi_gpu tests, assuming all test names contain multi_gpu: $ pytest -sv ./tests -k "multi_gpu"
     """
     if not is_torch_available():
         return unittest.skip(reason="test requires PyTorch")(test_case)
 
-    device_count = get_device_count()
+    import torch
 
-    return unittest.skipUnless(device_count > 1, "test requires multiple GPUs")(test_case)
+    return unittest.skipUnless(torch.cuda.device_count() > 1, "test requires multiple CUDA GPUs")(test_case)
 
 
 def require_torch_multi_accelerator(test_case):
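
For contrast with the reverted decorator above, a device-agnostic check along the lines of `require_torch_multi_accelerator()` could look like the following sketch (an illustration assuming CUDA and XPU backends, not the library's exact implementation):

    import unittest


    def require_multi_accelerator_sketch(test_case):
        """Skip the test unless more than one accelerator of any supported kind is visible."""
        import torch

        if torch.cuda.is_available():
            device_count = torch.cuda.device_count()
        elif hasattr(torch, "xpu") and torch.xpu.is_available():
            device_count = torch.xpu.device_count()
        else:
            device_count = 0

        return unittest.skipUnless(device_count > 1, "test requires multiple accelerators")(test_case)
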
4 changes: 2 additions & 2 deletions tests/quantization/bnb/test_4bit.py
@@ -38,7 +38,7 @@
     require_bitsandbytes,
     require_torch,
     require_torch_gpu_if_bnb_not_multi_backend_enabled,
-    require_torch_multi_gpu,
+    require_torch_multi_accelerator,
     slow,
     torch_device,
 )
@@ -514,7 +514,7 @@ def test_pipeline(self):
         self.assertIn(pipeline_output[0]["generated_text"], self.EXPECTED_OUTPUTS)
 
 
-@require_torch_multi_gpu
+@require_torch_multi_accelerator
 @apply_skip_if_not_implemented
 class Bnb4bitTestMultiGpu(Base4bitTest):
     def setUp(self):
6 changes: 3 additions & 3 deletions tests/quantization/bnb/test_mixed_int8.py
@@ -39,7 +39,7 @@
     require_bitsandbytes,
     require_torch,
     require_torch_gpu_if_bnb_not_multi_backend_enabled,
-    require_torch_multi_gpu,
+    require_torch_multi_accelerator,
     slow,
     torch_device,
 )
@@ -669,7 +669,7 @@ def test_pipeline(self):
         self.assertIn(pipeline_output[0]["generated_text"], self.EXPECTED_OUTPUTS)
 
 
-@require_torch_multi_gpu
+@require_torch_multi_accelerator
 @apply_skip_if_not_implemented
 class MixedInt8TestMultiGpu(BaseMixedInt8Test):
     def setUp(self):
@@ -698,7 +698,7 @@ def test_multi_gpu_loading(self):
         self.assertIn(self.tokenizer.decode(output_parallel[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)
 
 
-@require_torch_multi_gpu
+@require_torch_multi_accelerator
 @apply_skip_if_not_implemented
 class MixedInt8TestCpuGpu(BaseMixedInt8Test):
     def setUp(self):
