Rename _cuda_bfloat16 to _xla_bfloat16 since it is set when xla backend is used for bfloat16.
yeounoh committed Sep 18, 2023
1 parent 50c1d2f commit dc92243
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions torch_xla/amp/autocast_mode.py
@@ -27,15 +27,15 @@ def __init__(self,
     self._xla_device = xm.xla_device_hw(device)
     if self._xla_device == 'GPU':
       backend = 'cuda'
-      self._cuda_bfloat16 = False
+      self._xla_bfloat16 = False  # True if xla backend with bfloat16 dtype.
       if dtype is None:
         dtype = torch.float16
       elif dtype == torch.bfloat16 and not torch.cuda.is_available():
         if xr.is_bf16_supported():
           # XLA:GPU with bfloat16 should run on `xla` backend
           # unless torch.autocast is compiled with cuda.
           backend = 'xla'
-          self._cuda_bfloat16 = True
+          self._xla_bfloat16 = True
         else:
           # This has been the default behavior for unsupported bfloat16 dtype
           dtype = torch.float16
@@ -74,7 +74,7 @@ def __enter__(self):
       self.prev = torch.is_autocast_xla_enabled()  # type: ignore[attr-defined]
       self.prev_dtype = torch.get_autocast_xla_dtype(
       )  # type: ignore[attr-defined]
-      if self._cuda_bfloat16:
+      if self._xla_bfloat16:
         torch.set_autocast_enabled(self._enabled)
         torch.set_autocast_gpu_dtype(self._dtype)
       else:
@@ -85,7 +85,7 @@ def __enter__(self):
   def __exit__(self, exc_type: Any, exc_val: Any,
                exc_tb: Any):  # type: ignore[override]
     if self._xla_device == 'GPU':
-      if self._cuda_bfloat16:
+      if self._xla_bfloat16:
         torch.set_autocast_enabled(self.prev)
         torch.set_autocast_gpu_dtype(self.prev_dtype)
       else:
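For context, a minimal usage sketch (not part of the commit) of where the renamed flag comes into play. It assumes torch and torch_xla are installed and an XLA device is available; `torch_xla.amp.autocast` is the class patched above, and the input shapes are only illustrative. Per the diff, when bfloat16 is requested on XLA:GPU, CUDA is unavailable, and XLA supports bf16, __init__ selects the `xla` backend and sets `_xla_bfloat16` (formerly `_cuda_bfloat16`), so __enter__/__exit__ also save and restore the generic autocast enabled flag and GPU autocast dtype.

# Minimal sketch, assuming torch/torch_xla are installed and an XLA device exists.
import torch
import torch_xla.core.xla_model as xm
from torch_xla.amp import autocast  # the class modified in this commit

device = xm.xla_device()
model = torch.nn.Linear(4, 4).to(device)
x = torch.randn(8, 4, device=device)

# Requesting bfloat16 on an XLA:GPU device (without a CUDA build) takes the
# `xla` backend path in __init__ and sets _xla_bfloat16 = True, which controls
# which autocast state __enter__/__exit__ toggle.
with autocast(device, dtype=torch.bfloat16):
  y = model(x)

print(y.dtype)  # expected torch.bfloat16 when the bf16 autocast path is active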
