Skip to content

Commit

Permalink
[Fix] make all tests pass on XPU (#2427)
Browse files Browse the repository at this point in the history
* fix tests

* style
  • Loading branch information
faaany authored Feb 9, 2024
1 parent 9c1d5ba commit f75c624
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 2 deletions.
2 changes: 2 additions & 0 deletions src/accelerate/test_utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
require_multi_gpu,
require_multi_xpu,
require_non_cpu,
require_non_xpu,
require_npu,
require_pippy,
require_single_device,
require_single_gpu,
Expand Down
14 changes: 14 additions & 0 deletions src/accelerate/test_utils/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,20 @@ def require_xpu(test_case):
return unittest.skipUnless(is_xpu_available(), "test requires a XPU")(test_case)


def require_non_xpu(test_case):
"""
Decorator marking a test that should be skipped for XPU.
"""
return unittest.skipUnless(torch_device != "xpu", "test requires a non-XPU")(test_case)


def require_npu(test_case):
"""
Decorator marking a test that requires NPU. These tests are skipped when there are no NPU available.
"""
return unittest.skipUnless(is_npu_available(), "test require a NPU")(test_case)


def require_mps(test_case):
"""
Decorator marking a test that requires MPS backend. These tests are skipped when torch doesn't support `mps`
Expand Down
9 changes: 8 additions & 1 deletion tests/test_kwargs_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,13 @@

from accelerate import Accelerator, DistributedDataParallelKwargs, GradScalerKwargs
from accelerate.state import AcceleratorState
from accelerate.test_utils import device_count, execute_subprocess_async, require_multi_device, require_non_cpu
from accelerate.test_utils import (
device_count,
execute_subprocess_async,
require_multi_device,
require_non_cpu,
require_non_xpu,
)
from accelerate.utils import AutocastKwargs, KwargsHandler, TorchDynamoPlugin, clear_environment


Expand All @@ -41,6 +47,7 @@ def test_kwargs_handler(self):
self.assertDictEqual(MockClass(a=2, c=2.25).to_kwargs(), {"a": 2, "c": 2.25})

@require_non_cpu
@require_non_xpu
def test_grad_scaler_kwargs(self):
# If no defaults are changed, `to_kwargs` returns an empty dict.
scaler_handler = GradScalerKwargs(init_scale=1024, growth_factor=2)
Expand Down
3 changes: 2 additions & 1 deletion tests/test_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from accelerate import Accelerator
from accelerate.state import AcceleratorState
from accelerate.test_utils import require_cpu, require_non_cpu
from accelerate.test_utils import require_cpu, require_non_cpu, require_non_xpu


@require_cpu
Expand All @@ -37,6 +37,7 @@ def test_accelerated_optimizer_pickling(self):


@require_non_cpu
@require_non_xpu
class OptimizerTester(unittest.TestCase):
def test_accelerated_optimizer_step_was_skipped(self):
model = torch.nn.Linear(5, 5)
Expand Down

0 comments on commit f75c624

Please sign in to comment.