
Commit

update according to the comments
anw90 committed Nov 29, 2023
1 parent 4348dcb commit 9c12351
Showing 7 changed files with 59 additions and 26 deletions.
2 changes: 1 addition & 1 deletion docs/source/concept_guides/performance.md
@@ -45,7 +45,7 @@ Why is this important? Under the hood this will set **5** different seed setting
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # ^^ safe to call this function even if cuda is not available
-    if is_troch_xla_available():
+    if is_torch_xla_available():
        xm.set_rng_state(seed)
```

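For context, the helper this documentation snippet describes is used as below. This is a usage sketch, not part of the diff:

```python
from accelerate.utils import set_seed

# One call seeds python's `random`, NumPy, torch, CUDA and, when torch_xla is
# importable, the XLA RNG via xm.set_rng_state as shown in the snippet above.
set_seed(42)
```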
8 changes: 5 additions & 3 deletions src/accelerate/accelerator.py
@@ -425,14 +425,16 @@ def __init__(
            and self.distributed_type not in (DistributedType.DEEPSPEED, DistributedType.MEGATRON_LM)
        ):
            self.native_amp = True
-            if self.device.type not in ("xpu", "cuda", "mps", "npu", "xla") or is_torch_xla_available(tuple(["TPU"])):
+            if self.device.type not in ("xpu", "cuda", "mps", "npu", "xla") or is_torch_xla_available(
+                check_is_tpu=True
+            ):
                raise ValueError(err.format(mode="fp16", requirement="a GPU"))
            kwargs = self.scaler_handler.to_kwargs() if self.scaler_handler is not None else {}
            if self.distributed_type == DistributedType.FSDP:
                from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler

                self.scaler = ShardedGradScaler(**kwargs)
-            elif is_torch_xla_available(tuple(["GPU"])):
+            elif is_torch_xla_available(check_is_gpu=True):
                self.scaler = xamp.GradScaler(**kwargs)
            elif is_npu_available():
                self.scaler = torch.npu.amp.GradScaler(**kwargs)
@@ -447,7 +449,7 @@ def __init__(
                self.native_amp = True
            else:
                self.native_amp = is_bf16_available(True)
-            if mixed_precision == "bf16" and not self.native_amp and not is_torch_xla_available(tuple(["GPU"])):
+            if mixed_precision == "bf16" and not self.native_amp and not is_torch_xla_available(check_is_gpu=True):
                raise ValueError(err.format(mode="bf16", requirement="PyTorch >= 1.10 and a supported device."))

        # Start of internal step tracking
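For reference, a minimal sketch of the code path these branches serve: requesting fp16 through the `Accelerator` constructor. The example is not part of the commit, and the model and optimizer are placeholders:

```python
import torch
from accelerate import Accelerator

model = torch.nn.Linear(8, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

# "fp16" takes the native_amp branch patched above; which GradScaler gets built
# (sharded for FSDP, torch_xla's on an XLA GPU, NPU's, or the CUDA default)
# depends on the detected runtime, while an unsupported device raises the
# ValueError shown in the diff.
accelerator = Accelerator(mixed_precision="fp16")
model, optimizer = accelerator.prepare(model, optimizer)
```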
2 changes: 1 addition & 1 deletion src/accelerate/commands/launch.py
@@ -904,7 +904,7 @@ def _validate_launch_command(args):
        if (
            args.mixed_precision == "bf16"
            and not native_amp
-            and not (args.tpu and is_torch_xla_available(tuple(["TPU"])))
+            and not (args.tpu and is_torch_xla_available(check_is_tpu=True))
        ):
            raise ValueError(err.format(mode="bf16", requirement="PyTorch >= 1.10 and a supported device."))

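In plain terms, the guard now reads: bf16 launches are allowed only with native AMP support or on a TPU that torch_xla can actually see. A hypothetical paraphrase (this helper does not exist in the codebase; the booleans stand in for the parsed launch arguments):

```python
def bf16_launch_is_valid(native_amp: bool, tpu_requested: bool, xla_sees_tpu: bool) -> bool:
    # Mirrors the condition above: reject bf16 unless native AMP is available
    # or the job targets a TPU confirmed by is_torch_xla_available(check_is_tpu=True).
    return native_amp or (tpu_requested and xla_sees_tpu)
```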
4 changes: 2 additions & 2 deletions src/accelerate/state.py
@@ -158,7 +158,7 @@ def __init__(self, cpu: bool = False, **kwargs):
            xm.set_replication(self.device, [self.device])
            self.num_processes = xm.xrt_world_size()
            self.process_index = xm.get_ordinal()
-            if is_torch_xla_available(tuple("TPU")):
+            if is_torch_xla_available(check_is_tpu=True):
                self.local_process_index = xm.get_local_ordinal()
            else:
                self.local_process_index = int(os.environ.get("LOCAL_RANK", -1))
@@ -760,7 +760,7 @@ def __init__(
            )
        # deepspeed handles mixed_precision using deepspeed_config
        self._mixed_precision = "no" if self.distributed_type == DistributedType.DEEPSPEED else mixed_precision
-        if self.distributed_type == DistributedType.TPU and is_torch_xla_available(tuple(["TPU"])):
+        if self.distributed_type == DistributedType.TPU and is_torch_xla_available(check_is_tpu=True):
            if mixed_precision == "bf16":
                if os.environ.get("ACCELERATE_DOWNCAST_BF16"):
                    os.environ["XLA_USE_BF16"] = str(0)
65 changes: 48 additions & 17 deletions src/accelerate/utils/imports.py
@@ -26,20 +26,20 @@
from .versions import compare_versions, is_torch_version


-ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"}
# Try to run Torch native job in an environment with TorchXLA installed by setting this value to 0.
-USE_TORCH_XLA = os.environ.get("USE_TORCH_XLA", "1").upper()
+USE_TORCH_XLA = parse_flag_from_env("USE_TORCH_XLA", default=True)

-try:
-    if USE_TORCH_XLA in ENV_VARS_TRUE_VALUES:
+_torch_xla_available = False
+if USE_TORCH_XLA:
+    try:
        import torch_xla.core.xla_model as xm  # noqa: F401

        _torch_xla_available = True
-    else:
-        _torch_xla_available = False
-except ImportError:
-    _torch_xla_available = False
+    except ImportError:
+        pass

+# Keep it for is_tpu_available. It will be removed along with is_tpu_available.
+_tpu_available = _torch_xla_available

# Cache this result has it's a C FFI call which can be pretty time-consuming
_torch_distributed_available = torch.distributed.is_available()
@@ -97,19 +97,50 @@ def is_cuda_available():


@lru_cache
-def is_torch_xla_available(hardware_types=("TPU", "GPU")):
+def is_tpu_available(check_device=True):
+    "Checks if `torch_xla` is installed and potentially if a TPU is in the environment"
+    # Due to bugs on the amp series GPUs, we disable torch-xla on them
+    warnings.warn(
+        "The `is_tpu_available` is deprecated and will be removed in v0.27.0. "
+        "Please use the `is_torch_xla_available` instead.",
+        FutureWarning,
+    )
+    if is_cuda_available():
+        return False
+    if check_device:
+        if _tpu_available:
+            try:
+                # Will raise a RuntimeError if no XLA configuration is found
+                _ = xm.xla_device()
+                return True
+            except RuntimeError:
+                return False
+    return _tpu_available
+
+
+@lru_cache
+def is_torch_xla_available(check_is_tpu=False, check_is_gpu=False):
    """
-    Check if `torch_xla` is available and real hardware in `hardware_types`. To train a native pytorch job in an
-    environment with torch xla installed, set the USE_TORCH_XLA to false.
+    Check if `torch_xla` is available. To train a native pytorch job in an environment with torch xla installed, set
+    the USE_TORCH_XLA to false.
    """
-    if USE_TORCH_XLA not in ENV_VARS_TRUE_VALUES:
+    assert not (check_is_tpu and check_is_gpu), "The check_is_tpu and check_is_gpu cannot both be true."
+
+    if not USE_TORCH_XLA:
        return False

    if _torch_xla_available:
        try:
            xla_device = xm.xla_device()
-            return xm.xla_device_hw(xla_device) in hardware_types
-
-            return False
+            hardware_type = xm.xla_device_hw(xla_device)
+            return any(
+                [
+                    check_is_tpu and hardware_type == "TPU",
+                    check_is_gpu and hardware_type == "GPU",
+                    not (check_is_tpu or check_is_gpu),
+                ]
+            )
        except RuntimeError:
            return False


def is_deepspeed_available():
@@ -118,7 +149,7 @@ def is_deepspeed_available():

def is_bf16_available(ignore_tpu=False):
    "Checks if bf16 is supported, optionally ignoring the TPU"
-    if is_torch_xla_available(tuple(["TPU"])):
+    if is_torch_xla_available(check_is_tpu=True):
        return not ignore_tpu
    if torch.cuda.is_available():
        return torch.cuda.is_bf16_supported()
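Taken together, a usage sketch of the reworked helpers (not part of the diff; it assumes `is_torch_xla_available` stays re-exported from `accelerate.utils`):

```python
from accelerate.utils import is_torch_xla_available

# No arguments: is torch_xla importable and usable at all?
xla_ok = is_torch_xla_available()

# The boolean keywords replace the old hardware_types tuple, e.g.
# is_torch_xla_available(tuple(["TPU"])) becomes:
on_tpu = is_torch_xla_available(check_is_tpu=True)
on_xla_gpu = is_torch_xla_available(check_is_gpu=True)

print(f"xla: {xla_ok}, tpu: {on_tpu}, xla gpu: {on_xla_gpu}")
```

Per the module comment and docstring above, setting `USE_TORCH_XLA=0` in the environment makes these checks return `False`, so a native PyTorch job can run even with torch_xla installed.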
2 changes: 1 addition & 1 deletion src/accelerate/utils/launch.py
@@ -334,7 +334,7 @@ def prepare_tpu(
    """
    Prepares and returns an environment with the correct TPU environment variables.
    """
-    if args.mixed_precision == "bf16" and is_torch_xla_available(tuple(["TPU"])):
+    if args.mixed_precision == "bf16" and is_torch_xla_available(check_is_tpu=True):
        if args.downcast_bf16:
            current_env["XLA_DOWNCAST_BF16"] = "1"
        else:
2 changes: 1 addition & 1 deletion src/accelerate/utils/modeling.py
@@ -1455,7 +1455,7 @@ def get_mixed_precision_context_manager(native_amp: bool = False, autocast_kwarg
    if native_amp:
        device_type = (
            "cuda"
-            if (state.distributed_type == DistributedType.TPU and is_torch_xla_available(tuple(["GPU"])))
+            if (state.distributed_type == DistributedType.TPU and is_torch_xla_available(check_is_gpu=True))
            else state.device.type
        )
        if state.mixed_precision == "fp16":
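A rough sketch of what the selected `device_type` feeds into: the surrounding function builds a `torch.autocast` context, so forcing `"cuda"` when the XLA device is really a GPU keeps autocast on the CUDA path. Illustrative only, with the dtype assumed:

```python
import torch

# When DistributedType.TPU is reported but the XLA device is actually a GPU,
# the branch above picks device_type="cuda" for the autocast context.
autocast_ctx = torch.autocast(device_type="cuda", dtype=torch.float16)
with autocast_ctx:
    # the wrapped forward pass would run here in mixed precision
    pass
```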
