diff --git a/.github/workflows/build_and_test.yml b/.github/workflows/build_and_test.yml
index c28e4a1544a..a24bad4aa51 100644
--- a/.github/workflows/build_and_test.yml
+++ b/.github/workflows/build_and_test.yml
@@ -45,7 +45,7 @@ jobs:
     with:
       docker-image: ${{ needs.build.outputs.docker-image }}
       runner: linux.8xlarge.nvidia.gpu
-      timeout-minutes: 240
+      timeout-minutes: 300
       disable-xrt: 1
     secrets:
       gcloud-service-key: ${{ secrets.GCLOUD_SERVICE_KEY }}
diff --git a/test/spmd/test_xla_spmd_python_api_interaction.py b/test/spmd/test_xla_spmd_python_api_interaction.py
index 1d061f0d400..8ea4db3e051 100644
--- a/test/spmd/test_xla_spmd_python_api_interaction.py
+++ b/test/spmd/test_xla_spmd_python_api_interaction.py
@@ -6,6 +6,7 @@
 import torch_xla
 import torch_xla.core.xla_model as xm
 from torch_xla import runtime as xr
+from torch_xla.amp import autocast
 
 import test_xla_sharding_base
 
@@ -112,6 +113,25 @@ def test_runtime_spmd_api(self):
       os.environ["XLA_USE_SPMD"] = "1"
 
 
+class BasicAutocastAPITest(test_xla_sharding_base.XlaShardingTest):
+
+  @classmethod
+  def setUpClass(cls):
+    xr.use_spmd()
+    super().setUpClass()
+
+  @unittest.skipIf(xr.device_type() not in ['GPU', 'TPU'],
+                   "TPU/GPU autocast test.")
+  def test_xla_autocast_api(self):
+    device = xm.xla_device()
+    t1 = torch.ones([2, 3], device=device, dtype=torch.float32)
+    t2 = torch.ones([3, 2], device=device, dtype=torch.float32)
+    with autocast(device, dtype=torch.bfloat16):
+      t3 = torch.matmul(t1, t2)
+    expected_dtype = torch.bfloat16 if xr.is_bf16_supported() else torch.float16
+    self.assertEqual(t3.dtype, expected_dtype)
+
+
 if __name__ == '__main__':
   test = unittest.main()
   sys.exit(0 if test.result.wasSuccessful() else 1)
diff --git a/test/test_autocast.py b/test/test_autocast.py
index 7e1a01e346d..ec8246ec22c 100644
--- a/test/test_autocast.py
+++ b/test/test_autocast.py
@@ -7,7 +7,6 @@
 sys.argv = [sys.argv[0]] + leftovers
 
 import torch
-import torch_xla
 import torch_xla.core.xla_model as xm
 import collections
 import unittest
@@ -152,6 +151,48 @@ def __init__(self, dev):
     self.methods_bf16 = [("__matmul__", mat0_bf16 + mat1_fp32)]
 
 
+class AutocastCudaTestExtraLists(object):
+
+  def __init__(self, dev):
+    super().__init__()
+    n = 8
+    dimsets = ((n, n, n), (n, n, n, n), (n, n, n, n, n))
+    conv_args_fp32 = [(torch.randn(dimset, dtype=torch.float32, device=dev),
+                       torch.randn(dimset, dtype=torch.float32, device=dev))
+                      for dimset in dimsets]
+
+    mat0_fp32 = (torch.randn((n, n), dtype=torch.float32, device=dev),)
+    mat1_fp32 = (torch.randn((n, n), dtype=torch.float32, device=dev),)
+    mat2_fp32 = (torch.randn((n, n), dtype=torch.float32, device=dev),)
+    mat3_fp32 = (torch.randn((n, n), dtype=torch.float32, device=dev),)
+
+    pointwise0_fp32 = (torch.randn(n, dtype=torch.float32, device=dev),)
+
+    element0_fp32 = (torch.randn(1, dtype=torch.float32, device=dev),)
+
+    # These ops are not yet part of AutocastTestLists; `relu` and `addbmm` are excluded.
+    self.torch_bf16 = [
+        ("conv1d", conv_args_fp32[0]),
+        ("conv2d", conv_args_fp32[1]),
+        ("conv3d", conv_args_fp32[2]),
+        ("bmm", (torch.randn((n, n, n), device=dev, dtype=torch.float32),
+                 torch.randn((n, n, n), device=dev, dtype=torch.float32))),
+        ("mm", mat0_fp32 + mat1_fp32),
+        ("matmul", mat0_fp32 + mat1_fp32),
+        ("baddbmm", (torch.randn((n, n, n), device=dev, dtype=torch.float32),
+                     torch.randn((n, n, n), device=dev, dtype=torch.float32),
+                     torch.randn((n, n, n), device=dev, dtype=torch.float32))),
+        ("addmm", mat1_fp32 + mat2_fp32 + mat3_fp32),
+        ("conv_tbc", (torch.randn((10, 7, 3), device=dev, dtype=torch.float32),
+                      torch.randn((5, 3, 5), device=dev, dtype=torch.float32),
+                      torch.randn(5, device=dev, dtype=torch.float32), 0)),
+        ("conv_transpose1d", conv_args_fp32[0]),
+        ("conv_transpose2d", conv_args_fp32[1]),
+        ("conv_transpose3d", conv_args_fp32[2]),
+        ("prelu", pointwise0_fp32 + element0_fp32),
+    ]
+
+
 class AutocastCudaTestUnsupportedLists(object):
 
   def __init__(self):
@@ -301,8 +342,10 @@ class TestAutocastCuda(TestAutocastBase):
 
   def setUp(self):
     super(TestAutocastCuda, self).setUp()
-    self.is_autocast_enabled = torch.is_autocast_enabled
+    self.is_autocast_enabled = torch.is_autocast_xla_enabled
    self.autocast_lists = AutocastTestLists(torch.device(xm.xla_device()))
+    self.autocast_lists_extra = AutocastCudaTestExtraLists(
+        torch.device(xm.xla_device()))
     self.autocast_unsupported_lists = AutocastCudaTestUnsupportedLists()
 
   def test_autocast_nn_fp16(self):
@@ -334,6 +377,17 @@ def test_autocast_torch_fp32(self):
       self._run_autocast_outofplace(
           op, args, torch.float32, add_kwargs=maybe_kwargs)
 
+  def test_autocast_torch_bf16(self):
+    bf16_test_list = [
+        tp for tp in getattr(self.autocast_lists_extra, 'torch_bf16')
+    ]
+    for op_with_args in bf16_test_list:
+      op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args)
+      # A float16 output is expected here, following the torch CUDA autocast policy:
+      # https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/autocast_mode.cpp
+      self._run_autocast_outofplace(
+          op, args, torch.float16, add_kwargs=maybe_kwargs)
+
   def test_autocast_torch_need_autocast_promote(self):
     for op, args in self.get_autocast_list('torch_need_autocast_promote'):
       self._run_autocast_outofplace(op, args, torch.float32)
diff --git a/torch_xla/amp/autocast_mode.py b/torch_xla/amp/autocast_mode.py
index 4394a537fd0..975376e5057 100644
--- a/torch_xla/amp/autocast_mode.py
+++ b/torch_xla/amp/autocast_mode.py
@@ -1,48 +1,85 @@
 import torch
 import torch_xla.core.xla_model as xm
+from torch_xla import runtime as xr
 from typing import Any
 import warnings
 
 
 class autocast(torch.amp.autocast_mode.autocast):
   r"""
-  See :class:`torch.autocast`.
-  ``torch_xla.amp.autocast(device, **kwargs)`` is equivalent to
-  ``torch.autocast("xla", **kwargs)`` for TPUs
-  ``torch.autocast("cuda", **kwargs)`` for GPUs.
-  """
+  `torch.autocast` for XLA backend devices. See :class:`torch.autocast`.
+  ``torch_xla.amp.autocast(device, **kwargs)`` is equivalent to
+  ``torch.autocast("xla", **kwargs)`` for XLA:TPU, and for XLA:GPU with dtype torch.bfloat16;
+  for XLA:GPU with other dtypes it is equivalent to ``torch.autocast("cuda", **kwargs)``.
+  """
 
   def __init__(self,
                device,
                enabled: bool = True,
                dtype: torch.dtype = None,
                cache_enabled: bool = True):
-    if xm.xla_device_hw(device) == 'GPU':
+    # `torch_xla.amp.autocast` is intended for the XLA backend, using the AutocastXLA dispatch key.
+    assert 'xla' in str(device), (
+        "torch_xla.amp.autocast is only available for XLA:TPU and XLA:GPU devices")
+
+    self._enabled = enabled
+    self._xla_device = xm.xla_device_hw(device)
+    if self._xla_device == 'GPU':
+      backend = 'cuda'
       if dtype is None:
         dtype = torch.float16
+      elif dtype == torch.bfloat16:
+        if xr.is_bf16_supported() and not torch.cuda.is_available():
+          # XLA:GPU with bfloat16 should use the `xla` autocast backend, unless
+          # PyTorch itself was built with CUDA (torch.cuda.is_available()).
+          backend = 'xla'
+        else:
+          # Keep the historical default for unsupported bfloat16: cuda autocast with float16.
+          dtype = torch.float16
+          error_message = "In XLA:GPU autocast, but bfloat16 is not supported on this hardware.\n"
+          error_message += ("Using the default cuda autocast dtype float16.")
+          warnings.warn(error_message)
+      self._dtype = dtype
       super().__init__(
-          "cuda", enabled=enabled, dtype=dtype, cache_enabled=cache_enabled)
-    elif xm.xla_device_hw(device) == 'TPU':
+          backend,
+          enabled=enabled,
+          dtype=self._dtype,
+          cache_enabled=cache_enabled)
+    elif self._xla_device == 'TPU':
       if dtype is None:
         dtype = torch.bfloat16
       if dtype != torch.bfloat16:
-        error_message = "In TPU autocast, but the target dtype is not supported. Disabling autocast.\n"
+        error_message = "In XLA:TPU autocast, but the target dtype is not supported. Disabling autocast.\n"
         error_message += (
             "TPU Autocast only supports dtype of torch.bfloat16 currently.")
         warnings.warn(error_message)
         enabled = False
+      self._dtype = dtype
       super().__init__(
-          "xla", enabled=enabled, dtype=dtype, cache_enabled=cache_enabled)
+          "xla",
+          enabled=enabled,
+          dtype=self._dtype,
+          cache_enabled=cache_enabled)
     else:
       print(
           'Warning: AMP only supported for XLA:TPU and XLA:GPU. Ignoring autocast.'
       )
 
   def __enter__(self):
+    # Make sure the XLA autocast flags are also set for XLA:GPU, which enters
+    # `torch.amp.autocast_mode.autocast` with the `cuda` backend.
+    if self._xla_device == 'GPU':
+      self.prev = torch.is_autocast_xla_enabled()  # type: ignore[attr-defined]
+      self.prev_dtype = torch.get_autocast_xla_dtype(
+      )  # type: ignore[attr-defined]
+      torch.set_autocast_xla_enabled(self._enabled)
+      torch.set_autocast_xla_dtype(self._dtype)
     return super().__enter__()
 
   def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any):  # type: ignore[override]
+    if self._xla_device == 'GPU':
+      torch.set_autocast_xla_enabled(self.prev)
+      torch.set_autocast_xla_dtype(self.prev_dtype)
     return super().__exit__(exc_type, exc_val, exc_tb)
 
   def __call__(self, func):
diff --git a/torch_xla/runtime.py b/torch_xla/runtime.py
index ee05ddc14b5..a50e92b55a5 100644
--- a/torch_xla/runtime.py
+++ b/torch_xla/runtime.py
@@ -84,6 +84,16 @@ def wrapper(*args, **kwargs):
   return wrapper
 
 
+def is_bf16_supported():
+  """Returns whether torch.bfloat16 is supported in this environment.
+  """
+  try:
+    torch.tensor([1.], dtype=torch.bfloat16, device=xm.xla_device())
+    return True
+  except Exception:
+    return False
+
+
 @requires_pjrt
 def xla_device(n: Optional[int] = None,
                devkind: Optional[str] = None) -> torch.device:
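
Usage sketch (illustrative, not part of the patch above): how the new XLA:GPU/XLA:TPU bfloat16 autocast path and `xr.is_bf16_supported()` are expected to be used together, mirroring `test_xla_autocast_api`. It assumes `torch_xla` is installed and an XLA:TPU or XLA:GPU device is available; the tensor shapes and variable names are illustrative only.

import torch
import torch_xla.core.xla_model as xm
from torch_xla import runtime as xr
from torch_xla.amp import autocast

device = xm.xla_device()
a = torch.randn(4, 4, device=device, dtype=torch.float32)
b = torch.randn(4, 4, device=device, dtype=torch.float32)

# Requesting bfloat16 on XLA:GPU routes autocast through the `xla` backend when
# xr.is_bf16_supported() is True; otherwise it falls back to the default cuda
# autocast dtype, float16. XLA:TPU autocast supports only torch.bfloat16.
with autocast(device, dtype=torch.bfloat16):
  c = torch.matmul(a, b)

expected = torch.bfloat16 if xr.is_bf16_supported() else torch.float16
assert c.dtype == expected  # dtype is trace-time metadata; no xm.mark_step() needed to inspect it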