Enable xla:gpu autocast for bfloat16 if not restricted (#5570)
* Enable autocast for XLA:GPU

* linter fix

* XLA autocast test for GPU and TPU

* linter fix

* Ensure that xla autocast is properly enabled for GPU and does not crash when torch cuda is not available.

* linter fix

* Add tests

* Support bf16

* linter fix

* exclude unsupported test cases

* increase GPU test timeout to 300
yeounoh authored Sep 15, 2023
1 parent 6dd13a1 commit 64bbf15
Showing 5 changed files with 134 additions and 13 deletions.
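
Context for the change (an illustrative sketch, not part of the commit's diff): with bfloat16 requested on XLA:GPU, torch_xla.amp.autocast now routes through the `xla` backend when xr.is_bf16_supported() is true and torch is not built with CUDA, instead of always falling back to the `cuda` backend with float16. A minimal usage sketch, assuming an XLA:GPU device, mirroring the new BasicAutocastAPITest below:

import torch
import torch_xla.core.xla_model as xm
from torch_xla import runtime as xr
from torch_xla.amp import autocast

device = xm.xla_device()
t1 = torch.ones([2, 3], device=device, dtype=torch.float32)
t2 = torch.ones([3, 2], device=device, dtype=torch.float32)

with autocast(device, dtype=torch.bfloat16):
  t3 = torch.matmul(t1, t2)

# bfloat16 when the device supports it, otherwise the float16 fallback.
print(t3.dtype)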
2 changes: 1 addition & 1 deletion .github/workflows/build_and_test.yml
@@ -45,7 +45,7 @@ jobs:
     with:
       docker-image: ${{ needs.build.outputs.docker-image }}
       runner: linux.8xlarge.nvidia.gpu
-      timeout-minutes: 240
+      timeout-minutes: 300
       disable-xrt: 1
     secrets:
       gcloud-service-key: ${{ secrets.GCLOUD_SERVICE_KEY }}
20 changes: 20 additions & 0 deletions test/spmd/test_xla_spmd_python_api_interaction.py
@@ -6,6 +6,7 @@
 import torch_xla
 import torch_xla.core.xla_model as xm
 from torch_xla import runtime as xr
+from torch_xla.amp import autocast

 import test_xla_sharding_base

@@ -112,6 +113,25 @@ def test_runtime_spmd_api(self):
     os.environ["XLA_USE_SPMD"] = "1"


+class BasicAutocastAPITest(test_xla_sharding_base.XlaShardingTest):
+
+  @classmethod
+  def setUpClass(cls):
+    xr.use_spmd()
+    super().setUpClass()
+
+  @unittest.skipIf(xr.device_type() not in ['GPU', 'TPU'],
+                   f"TPU/GPU autocast test.")
+  def test_xla_autocast_api(self):
+    device = xm.xla_device()
+    t1 = torch.ones([2, 3], device=device, dtype=torch.float32)
+    t2 = torch.ones([3, 2], device=device, dtype=torch.float32)
+    with autocast(device, dtype=torch.bfloat16):
+      t3 = torch.matmul(t1, t2)
+    expected_dtype = torch.bfloat16 if xr.is_bf16_supported() else torch.float16
+    self.assertTrue(t3.dtype == expected_dtype)
+
+
 if __name__ == '__main__':
   test = unittest.main()
   sys.exit(0 if test.result.wasSuccessful() else 1)
58 changes: 56 additions & 2 deletions test/test_autocast.py
@@ -7,7 +7,6 @@
 sys.argv = [sys.argv[0]] + leftovers

 import torch
-import torch_xla
 import torch_xla.core.xla_model as xm
 import collections
 import unittest
@@ -152,6 +151,48 @@ def __init__(self, dev):
     self.methods_bf16 = [("__matmul__", mat0_bf16 + mat1_fp32)]


+class AutocastCudaTestExtraLists(object):
+
+  def __init__(self, dev):
+    super().__init__()
+    n = 8
+    dimsets = ((n, n, n), (n, n, n, n), (n, n, n, n, n))
+    conv_args_fp32 = [(torch.randn(dimset, dtype=torch.float32, device=dev),
+                       torch.randn(dimset, dtype=torch.float32, device=dev))
+                      for dimset in dimsets]
+
+    mat0_fp32 = (torch.randn((n, n), dtype=torch.float32, device=dev),)
+    mat1_fp32 = (torch.randn((n, n), dtype=torch.float32, device=dev),)
+    mat2_fp32 = (torch.randn((n, n), dtype=torch.float32, device=dev),)
+    mat3_fp32 = (torch.randn((n, n), dtype=torch.float32, device=dev),)
+
+    pointwise0_fp32 = (torch.randn(n, dtype=torch.float32, device=dev),)
+
+    element0_fp32 = (torch.randn(1, dtype=torch.float32, device=dev),)
+
+    # This is currently not part of AutocastTestLists and excludes `relu`, `addbmm`
+    self.torch_bf16 = [
+        ("conv1d", conv_args_fp32[0]),
+        ("conv2d", conv_args_fp32[1]),
+        ("conv3d", conv_args_fp32[2]),
+        ("bmm", (torch.randn((n, n, n), device=dev, dtype=torch.float32),
+                 torch.randn((n, n, n), device=dev, dtype=torch.float32))),
+        ("mm", mat0_fp32 + mat1_fp32),
+        ("matmul", mat0_fp32 + mat1_fp32),
+        ("baddbmm", (torch.randn((n, n, n), device=dev, dtype=torch.float32),
+                     torch.randn((n, n, n), device=dev, dtype=torch.float32),
+                     torch.randn((n, n, n), device=dev, dtype=torch.float32))),
+        ("addmm", mat1_fp32 + mat2_fp32 + mat3_fp32),
+        ("conv_tbc", (torch.randn((10, 7, 3), device=dev, dtype=torch.float32),
+                      torch.randn((5, 3, 5), device=dev, dtype=torch.float32),
+                      torch.randn(5, device=dev, dtype=torch.float32), 0)),
+        ("conv_transpose1d", conv_args_fp32[0]),
+        ("conv_transpose2d", conv_args_fp32[1]),
+        ("conv_transpose3d", conv_args_fp32[2]),
+        ("prelu", pointwise0_fp32 + element0_fp32),
+    ]
+
+
 class AutocastCudaTestUnsupportedLists(object):

   def __init__(self):
@@ -301,8 +342,10 @@ class TestAutocastCuda(TestAutocastBase):

   def setUp(self):
     super(TestAutocastCuda, self).setUp()
-    self.is_autocast_enabled = torch.is_autocast_enabled
+    self.is_autocast_enabled = torch.is_autocast_xla_enabled
     self.autocast_lists = AutocastTestLists(torch.device(xm.xla_device()))
+    self.autocast_lists_extra = AutocastCudaTestExtraLists(
+        torch.device(xm.xla_device()))
     self.autocast_unsupported_lists = AutocastCudaTestUnsupportedLists()

   def test_autocast_nn_fp16(self):
@@ -334,6 +377,17 @@ def test_autocast_torch_fp32(self):
       self._run_autocast_outofplace(
           op, args, torch.float32, add_kwargs=maybe_kwargs)

+  def test_autocast_torch_bf16(self):
+    bf16_test_list = [
+        tp for tp in getattr(self.autocast_lists_extra, 'torch_bf16')
+    ]
+    for op_with_args in bf16_test_list:
+      op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args)
+      # Expects float16, following the torch GPU autocast policy:
+      # https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/autocast_mode.cpp
+      self._run_autocast_outofplace(
+          op, args, torch.float16, add_kwargs=maybe_kwargs)
+
   def test_autocast_torch_need_autocast_promote(self):
     for op, args in self.get_autocast_list('torch_need_autocast_promote'):
       self._run_autocast_outofplace(op, args, torch.float32)
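
A minimal sketch of the behavior the new test_autocast_torch_bf16 exercises (illustrative, not part of the commit), assuming an XLA:GPU device: without an explicit dtype, XLA:GPU autocast defaults to float16, following the torch CUDA autocast policy, so ops from the torch_bf16 list are checked against torch.float16:

import torch
import torch_xla.core.xla_model as xm
from torch_xla.amp import autocast

device = xm.xla_device()
a = torch.randn(8, 8, device=device, dtype=torch.float32)
b = torch.randn(8, 8, device=device, dtype=torch.float32)

# No dtype argument: the XLA:GPU default is float16.
with autocast(device):
  out = torch.mm(a, b)

print(out.dtype)  # expected: torch.float16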
57 changes: 47 additions & 10 deletions torch_xla/amp/autocast_mode.py
@@ -1,48 +1,85 @@
 import torch
 import torch_xla.core.xla_model as xm
+from torch_xla import runtime as xr
 from typing import Any
 import warnings


 class autocast(torch.amp.autocast_mode.autocast):
   r"""
-  See :class:`torch.autocast`.
-  ``torch_xla.amp.autocast(device, **kwargs)`` is equivalent to
-  ``torch.autocast("xla", **kwargs)`` for TPUs
-  ``torch.autocast("cuda", **kwargs)`` for GPUs.
-  """
+  `torch.autocast` for XLA backend devices. See :class:`torch.autocast`.
+  ``torch_xla.amp.autocast(device, **kwargs)`` is equivalent to
+  ``torch.autocast("xla", **kwargs)`` for XLA:GPU and XLA:TPU for dtype torch.bfloat16,
+  ``torch.autocast("cuda", **kwargs)`` for XLA:GPU and other dtypes.
+  """

   def __init__(self,
                device,
                enabled: bool = True,
                dtype: torch.dtype = None,
                cache_enabled: bool = True):
-    if xm.xla_device_hw(device) == 'GPU':
+    # `torch_xla.amp.autocast` is intended for XLA backend, with AutocastXLA dispatch key.
+    assert 'xla' in device.__str__(
+    ), "torch_xla.autocast is available for XLA:TPU, XLA:GPU"
+
+    self._enabled = enabled
+    self._xla_device = xm.xla_device_hw(device)
+    if self._xla_device == 'GPU':
+      backend = 'cuda'
       if dtype is None:
         dtype = torch.float16
+      elif dtype == torch.bfloat16:
+        if xr.is_bf16_supported() and not torch.cuda.is_available():
+          # XLA:GPU with bfloat16 should run on `xla` backend
+          # unless torch.autocast is compiled with cuda.
+          backend = 'xla'
+        else:
+          # This has been the default behavior for unsupported bfloat16 dtype
+          dtype = torch.float16
+          error_message = "In XLA:GPU autocast, but bfloat16 is not supported on this HW.\n"
+          error_message += ("Using the default cuda autocast dtype float16.")
+      self._dtype = dtype
       super().__init__(
-          "cuda", enabled=enabled, dtype=dtype, cache_enabled=cache_enabled)
-    elif xm.xla_device_hw(device) == 'TPU':
+          backend,
+          enabled=enabled,
+          dtype=self._dtype,
+          cache_enabled=cache_enabled)
+    elif self._xla_device == 'TPU':
       if dtype is None:
         dtype = torch.bfloat16
       if dtype != torch.bfloat16:
-        error_message = "In TPU autocast, but the target dtype is not supported. Disabling autocast.\n"
+        error_message = "In XLA:TPU autocast, but the target dtype is not supported. Disabling autocast.\n"
         error_message += (
             "TPU Autocast only supports dtype of torch.bfloat16 currently.")
         warnings.warn(error_message)
         enabled = False
+      self._dtype = dtype
       super().__init__(
-          "xla", enabled=enabled, dtype=dtype, cache_enabled=cache_enabled)
+          "xla",
+          enabled=enabled,
+          dtype=self._dtype,
+          cache_enabled=cache_enabled)
     else:
       print(
           'Warning: AMP only supported for XLA:TPU and XLA:GPU. Ignoring autocast.'
       )

+  def __enter__(self):
+    # This ensures that xla autocast is enabled even for XLA:GPU, which calls
+    # `torch.amp.autocast_mode.autocast` with `cuda` backend.
+    if self._xla_device == 'GPU':
+      self.prev = torch.is_autocast_xla_enabled()  # type: ignore[attr-defined]
+      self.prev_dtype = torch.get_autocast_xla_dtype(
+      )  # type: ignore[attr-defined]
+      torch.set_autocast_xla_enabled(self._enabled)
+      torch.set_autocast_xla_dtype(self._dtype)
+    return super().__enter__()
+
+  def __exit__(self, exc_type: Any, exc_val: Any,
+               exc_tb: Any):  # type: ignore[override]
+    if self._xla_device == 'GPU':
+      torch.set_autocast_xla_enabled(self.prev)
+      torch.set_autocast_xla_dtype(self.prev_dtype)
+    return super().__exit__(exc_type, exc_val, exc_tb)
+
   def __call__(self, func):
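
A sketch of the state handling that the new __enter__/__exit__ methods add (illustrative, not part of the commit), assuming an XLA:GPU device where bfloat16 is supported: the context manager mirrors its settings into the XLA autocast flags on entry and restores the previous values on exit, even when the parent class was entered with the `cuda` backend:

import torch
import torch_xla.core.xla_model as xm
from torch_xla.amp import autocast

device = xm.xla_device()
assert not torch.is_autocast_xla_enabled()

with autocast(device, dtype=torch.bfloat16):
  # __enter__ saved the previous XLA autocast state and enabled it here.
  assert torch.is_autocast_xla_enabled()
  assert torch.get_autocast_xla_dtype() == torch.bfloat16

# __exit__ restored the saved state.
assert not torch.is_autocast_xla_enabled()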
10 changes: 10 additions & 0 deletions torch_xla/runtime.py
@@ -84,6 +84,16 @@ def wrapper(*args, **kwargs):
   return wrapper


+def is_bf16_supported():
+  """Returns whether torch.bfloat16 is supported on this environment.
+  """
+  try:
+    torch.tensor([1.], dtype=torch.bfloat16, device=xm.xla_device())
+    return True
+  except Exception as e:
+    return False
+
+
 @requires_pjrt
 def xla_device(n: Optional[int] = None,
                devkind: Optional[str] = None) -> torch.device:
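
The new xr.is_bf16_supported() helper probes bfloat16 support by materializing a tiny tensor on the XLA device. An illustrative usage sketch (not part of the commit), assuming an XLA:GPU device, showing how a caller can pick the autocast dtype the same way the new SPMD test does:

import torch
import torch_xla.core.xla_model as xm
from torch_xla import runtime as xr
from torch_xla.amp import autocast

device = xm.xla_device()
# Fall back to float16 where the device cannot use bfloat16.
amp_dtype = torch.bfloat16 if xr.is_bf16_supported() else torch.float16

with autocast(device, dtype=amp_dtype):
  out = torch.matmul(
      torch.randn(4, 4, device=device), torch.randn(4, 4, device=device))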
