From dc9224336af6cf6313e60d197dab47323e8e509d Mon Sep 17 00:00:00 2001
From: Yeounoh Chung
Date: Mon, 18 Sep 2023 12:02:15 -0700
Subject: [PATCH] Rename _cuda_bfloat16 to _xla_bfloat16 since it is set when
 xla backend is used for bfloat16.

---
 torch_xla/amp/autocast_mode.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/torch_xla/amp/autocast_mode.py b/torch_xla/amp/autocast_mode.py
index 08327effb6ef..27e878567d65 100644
--- a/torch_xla/amp/autocast_mode.py
+++ b/torch_xla/amp/autocast_mode.py
@@ -27,7 +27,7 @@ def __init__(self,
     self._xla_device = xm.xla_device_hw(device)
     if self._xla_device == 'GPU':
       backend = 'cuda'
-      self._cuda_bfloat16 = False
+      self._xla_bfloat16 = False  # True if xla backend with bfloat16 dtype.
       if dtype is None:
         dtype = torch.float16
       elif dtype == torch.bfloat16 and not torch.cuda.is_available():
@@ -35,7 +35,7 @@ def __init__(self,
         # XLA:GPU with bfloat16 should run on `xla` backend
         # unless torch.autocast is compiled with cuda.
         backend = 'xla'
-        self._cuda_bfloat16 = True
+        self._xla_bfloat16 = True
       else:
         # This has been the default behavior for unsupported bfloat16 dtype
         dtype = torch.float16
@@ -74,7 +74,7 @@ def __enter__(self):
       self.prev = torch.is_autocast_xla_enabled()  # type: ignore[attr-defined]
       self.prev_dtype = torch.get_autocast_xla_dtype(
       )  # type: ignore[attr-defined]
-      if self._cuda_bfloat16:
+      if self._xla_bfloat16:
         torch.set_autocast_enabled(self._enabled)
         torch.set_autocast_gpu_dtype(self._dtype)
       else:
@@ -85,7 +85,7 @@ def __enter__(self):
 
   def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any):  # type: ignore[override]
     if self._xla_device == 'GPU':
-      if self._cuda_bfloat16:
+      if self._xla_bfloat16:
         torch.set_autocast_enabled(self.prev)
         torch.set_autocast_gpu_dtype(self.prev_dtype)
       else:
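
Usage note (not part of the patch): a minimal sketch of the code path the renamed flag guards, assuming a torch_xla installation with an XLA:GPU device attached and a torch build without CUDA, so that requesting bfloat16 routes autocast to the `xla` backend. The model and input tensor below are placeholders for illustration only.

import torch
import torch_xla.core.xla_model as xm
from torch_xla.amp import autocast

device = xm.xla_device()  # assumes an XLA:GPU device is available
model = torch.nn.Linear(4, 2).to(device)  # placeholder model
data = torch.randn(8, 4, device=device)  # placeholder input

# With dtype=torch.bfloat16 and no CUDA-enabled torch build, __init__ selects
# the `xla` backend and sets self._xla_bfloat16 = True (the renamed flag), so
# __enter__/__exit__ toggle the XLA autocast state rather than the CUDA one.
with autocast(device, dtype=torch.bfloat16):
  out = model(data)

xm.mark_step()  # flush the lazily-recorded XLA graph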