diff --git a/docs/source/concept_guides/performance.md b/docs/source/concept_guides/performance.md
index 89926ef57c8..8b112005365 100644
--- a/docs/source/concept_guides/performance.md
+++ b/docs/source/concept_guides/performance.md
@@ -45,7 +45,7 @@ Why is this important? Under the hood this will set **5** different seed setting
 torch.manual_seed(seed)
 torch.cuda.manual_seed_all(seed)
 # ^^ safe to call this function even if cuda is not available
-if is_troch_xla_available():
+if is_torch_xla_available():
     xm.set_rng_state(seed)
 ```
diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py
index c5b0cb2c050..5ea9ec2fcfb 100755
--- a/src/accelerate/accelerator.py
+++ b/src/accelerate/accelerator.py
@@ -425,14 +425,16 @@ def __init__(
             and self.distributed_type not in (DistributedType.DEEPSPEED, DistributedType.MEGATRON_LM)
         ):
             self.native_amp = True
-            if self.device.type not in ("xpu", "cuda", "mps", "npu", "xla") or is_torch_xla_available(tuple(["TPU"])):
+            if self.device.type not in ("xpu", "cuda", "mps", "npu", "xla") or is_torch_xla_available(
+                check_is_tpu=True
+            ):
                 raise ValueError(err.format(mode="fp16", requirement="a GPU"))
             kwargs = self.scaler_handler.to_kwargs() if self.scaler_handler is not None else {}
             if self.distributed_type == DistributedType.FSDP:
                 from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
 
                 self.scaler = ShardedGradScaler(**kwargs)
-            elif is_torch_xla_available(tuple(["GPU"])):
+            elif is_torch_xla_available(check_is_gpu=True):
                 self.scaler = xamp.GradScaler(**kwargs)
             elif is_npu_available():
                 self.scaler = torch.npu.amp.GradScaler(**kwargs)
@@ -447,7 +449,7 @@ def __init__(
                 self.native_amp = True
             else:
                 self.native_amp = is_bf16_available(True)
-            if mixed_precision == "bf16" and not self.native_amp and not is_torch_xla_available(tuple(["GPU"])):
+            if mixed_precision == "bf16" and not self.native_amp and not is_torch_xla_available(check_is_gpu=True):
                 raise ValueError(err.format(mode="bf16", requirement="PyTorch >= 1.10 and a supported device."))
 
         # Start of internal step tracking
diff --git a/src/accelerate/commands/launch.py b/src/accelerate/commands/launch.py
index 81254b3af55..80088e89dd3 100644
--- a/src/accelerate/commands/launch.py
+++ b/src/accelerate/commands/launch.py
@@ -904,7 +904,7 @@ def _validate_launch_command(args):
        if (
            args.mixed_precision == "bf16"
            and not native_amp
-           and not (args.tpu and is_torch_xla_available(tuple(["TPU"])))
+           and not (args.tpu and is_torch_xla_available(check_is_tpu=True))
        ):
            raise ValueError(err.format(mode="bf16", requirement="PyTorch >= 1.10 and a supported device."))
diff --git a/src/accelerate/state.py b/src/accelerate/state.py
index ffd8a55ed81..debfb027bdc 100644
--- a/src/accelerate/state.py
+++ b/src/accelerate/state.py
@@ -158,7 +158,7 @@ def __init__(self, cpu: bool = False, **kwargs):
             xm.set_replication(self.device, [self.device])
             self.num_processes = xm.xrt_world_size()
             self.process_index = xm.get_ordinal()
-            if is_torch_xla_available(tuple("TPU")):
+            if is_torch_xla_available(check_is_tpu=True):
                 self.local_process_index = xm.get_local_ordinal()
             else:
                 self.local_process_index = int(os.environ.get("LOCAL_RANK", -1))
@@ -760,7 +760,7 @@ def __init__(
             )
         # deepspeed handles mixed_precision using deepspeed_config
         self._mixed_precision = "no" if self.distributed_type == DistributedType.DEEPSPEED else mixed_precision
-        if self.distributed_type == DistributedType.TPU and is_torch_xla_available(tuple(["TPU"])):
+        if self.distributed_type == DistributedType.TPU and is_torch_xla_available(check_is_tpu=True):
             if mixed_precision == "bf16":
                 if os.environ.get("ACCELERATE_DOWNCAST_BF16"):
                     os.environ["XLA_USE_BF16"] = str(0)
diff --git a/src/accelerate/utils/imports.py b/src/accelerate/utils/imports.py
index 829c7ed03f2..618216bc69c 100644
--- a/src/accelerate/utils/imports.py
+++ b/src/accelerate/utils/imports.py
@@ -26,20 +26,20 @@ from .versions import compare_versions, is_torch_version
 
 
-ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"}
-
 # Try to run Torch native job in an environment with TorchXLA installed by setting this value to 0.
-USE_TORCH_XLA = os.environ.get("USE_TORCH_XLA", "1").upper()
+USE_TORCH_XLA = parse_flag_from_env("USE_TORCH_XLA", default=True)
 
-try:
-    if USE_TORCH_XLA in ENV_VARS_TRUE_VALUES:
+_torch_xla_available = False
+if USE_TORCH_XLA:
+    try:
         import torch_xla.core.xla_model as xm  # noqa: F401
 
         _torch_xla_available = True
-    else:
-        _torch_xla_available = False
-except ImportError:
-    _torch_xla_available = False
+    except ImportError:
+        pass
 
+# Keep it for is_tpu_available. It will be removed along with is_tpu_available.
+_tpu_available = _torch_xla_available
+
 # Cache this result has it's a C FFI call which can be pretty time-consuming
 _torch_distributed_available = torch.distributed.is_available()
@@ -97,19 +97,50 @@ def is_cuda_available():
 
 
 @lru_cache
-def is_torch_xla_available(hardware_types=("TPU", "GPU")):
+def is_tpu_available(check_device=True):
+    "Checks if `torch_xla` is installed and potentially if a TPU is in the environment"
+    # Due to bugs on the amp series GPUs, we disable torch-xla on them
+    warnings.warn(
+        "The `is_tpu_available` is deprecated and will be removed in v0.27.0. "
+        "Please use the `is_torch_xla_available` instead.",
+        FutureWarning,
+    )
+    if is_cuda_available():
+        return False
+    if check_device:
+        if _tpu_available:
+            try:
+                # Will raise a RuntimeError if no XLA configuration is found
+                _ = xm.xla_device()
+                return True
+            except RuntimeError:
+                return False
+    return _tpu_available
+
+
+@lru_cache
+def is_torch_xla_available(check_is_tpu=False, check_is_gpu=False):
     """
-    Check if `torch_xla` is available and real hardware in `hardware_types`. To train a native pytorch job in an
-    environment with torch xla installed, set the USE_TORCH_XLA to false.
+    Check if `torch_xla` is available. To train a native pytorch job in an environment with torch xla installed, set
+    the USE_TORCH_XLA to false.
     """
-    if USE_TORCH_XLA not in ENV_VARS_TRUE_VALUES:
+    assert not (check_is_tpu and check_is_gpu), "The check_is_tpu and check_is_gpu cannot both be true."
+
+    if not USE_TORCH_XLA:
         return False
 
-    if _torch_xla_available:
+    try:
         xla_device = xm.xla_device()
-        return xm.xla_device_hw(xla_device) in hardware_types
-
-    return False
+        hardware_type = xm.xla_device_hw(xla_device)
+        return any(
+            [
+                check_is_tpu and hardware_type == "TPU",
+                check_is_gpu and hardware_type == "GPU",
+                not (check_is_tpu or check_is_gpu),
+            ]
+        )
+    except RuntimeError:
+        return False
 
 
 def is_deepspeed_available():
@@ -118,7 +149,7 @@ def is_bf16_available(ignore_tpu=False):
     "Checks if bf16 is supported, optionally ignoring the TPU"
-    if is_torch_xla_available(tuple(["TPU"])):
+    if is_torch_xla_available(check_is_tpu=True):
         return not ignore_tpu
     if torch.cuda.is_available():
         return torch.cuda.is_bf16_supported()
diff --git a/src/accelerate/utils/launch.py b/src/accelerate/utils/launch.py
index e1f515009c9..ed9c249243f 100644
--- a/src/accelerate/utils/launch.py
+++ b/src/accelerate/utils/launch.py
@@ -334,7 +334,7 @@ def prepare_tpu(
     """
     Prepares and returns an environment with the correct TPU environment variables.
     """
-    if args.mixed_precision == "bf16" and is_torch_xla_available(tuple(["TPU"])):
+    if args.mixed_precision == "bf16" and is_torch_xla_available(check_is_tpu=True):
         if args.downcast_bf16:
             current_env["XLA_DOWNCAST_BF16"] = "1"
         else:
diff --git a/src/accelerate/utils/modeling.py b/src/accelerate/utils/modeling.py
index 3279d23ba06..a8419d314ba 100644
--- a/src/accelerate/utils/modeling.py
+++ b/src/accelerate/utils/modeling.py
@@ -1455,7 +1455,7 @@ def get_mixed_precision_context_manager(native_amp: bool = False, autocast_kwarg
     if native_amp:
        device_type = (
            "cuda"
-           if (state.distributed_type == DistributedType.TPU and is_torch_xla_available(tuple(["GPU"])))
+           if (state.distributed_type == DistributedType.TPU and is_torch_xla_available(check_is_gpu=True))
           else state.device.type
       )
       if state.mixed_precision == "fp16":
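For context, here is a minimal usage sketch of the reworked check introduced by this patch. It is not part of the diff, and it assumes the patch is applied and that `is_torch_xla_available` is re-exported from `accelerate.utils`:

```python
# Minimal sketch, not part of the patch. Assumes `is_torch_xla_available`
# is importable from `accelerate.utils` with this change applied.
from accelerate.utils import is_torch_xla_available

# No flags: True when torch_xla is importable, USE_TORCH_XLA is not set to 0,
# and xm.xla_device() can actually be initialized.
if is_torch_xla_available():
    print("torch_xla backend is usable")

# The old `hardware_types` tuple argument is replaced by explicit keyword flags.
if is_torch_xla_available(check_is_tpu=True):
    print("XLA device is a TPU")
elif is_torch_xla_available(check_is_gpu=True):
    print("XLA device is a GPU")

# Passing both flags trips the assert added in this patch:
# is_torch_xla_available(check_is_tpu=True, check_is_gpu=True)  # AssertionError
```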