Make torch xla available on GPU #2176

Merged: 18 commits, Feb 14, 2024
2 changes: 1 addition & 1 deletion docs/source/concept_guides/performance.md
@@ -45,7 +45,7 @@ Why is this important? Under the hood this will set **5** different seed setting
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
# ^^ safe to call this function even if cuda is not available
if is_tpu_available():
if is_torch_xla_available():
xm.set_rng_state(seed)
```
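
For reference, the helper documented here that performs these calls is `set_seed`; a minimal usage sketch (the seed value is illustrative):

```python
from accelerate.utils import set_seed

# Seeds Python's `random`, NumPy, torch, torch.cuda and, when torch_xla is
# available, the XLA RNG, matching the excerpt above.
set_seed(42)
```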

2 changes: 1 addition & 1 deletion docs/source/package_reference/utilities.md
@@ -150,7 +150,7 @@ These functionalities check the state of the current working environment includi

[[autodoc]] utils.is_torch_version

[[autodoc]] utils.is_tpu_available
[[autodoc]] utils.is_torch_xla_available

[[autodoc]] utils.is_xpu_available

2 changes: 1 addition & 1 deletion docs/source/quicktour.md
@@ -258,7 +258,7 @@ To introduce special behavior in your script for TPUs you can check the `distrib
```python docstyle-ignore
from accelerate import DistributedType

if accelerator.distributed_type == DistributedType.TPU:
if accelerator.distributed_type == DistributedType.XLA:
# do something of static shape
else:
# go crazy and be dynamic
4 changes: 2 additions & 2 deletions examples/by_feature/checkpointing.py
@@ -85,7 +85,7 @@ def tokenize_function(examples):

def collate_fn(examples):
# On TPU it's best to pad everything to the same length or training will be very slow.
max_length = 128 if accelerator.distributed_type == DistributedType.TPU else None
max_length = 128 if accelerator.distributed_type == DistributedType.XLA else None
# When using mixed precision we want round multiples of 8/16
if accelerator.mixed_precision == "fp8":
pad_to_multiple_of = 16
@@ -153,7 +153,7 @@ def training_function(config, args):

# If the batch size is too big we use gradient accumulation
gradient_accumulation_steps = 1
if batch_size > MAX_GPU_BATCH_SIZE and accelerator.distributed_type != DistributedType.TPU:
if batch_size > MAX_GPU_BATCH_SIZE and accelerator.distributed_type != DistributedType.XLA:
gradient_accumulation_steps = batch_size // MAX_GPU_BATCH_SIZE
batch_size = MAX_GPU_BATCH_SIZE
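
To show the pattern behind the `collate_fn` hunk above in one place (it recurs in the other example scripts below), here is a rough sketch of such a collate function. It assumes a `transformers` tokenizer bound to `tokenizer` and an `accelerator` in scope, and is not a verbatim copy of the file:

```python
def collate_fn(examples):
    # On XLA devices, dynamic input shapes trigger recompilation, so pad every batch
    # to a fixed length; elsewhere, pad only to the longest item in the batch.
    max_length = 128 if accelerator.distributed_type == DistributedType.XLA else None
    # With mixed precision, round sequence lengths up to a multiple of 8 (fp16/bf16) or 16 (fp8).
    if accelerator.mixed_precision == "fp8":
        pad_to_multiple_of = 16
    elif accelerator.mixed_precision != "no":
        pad_to_multiple_of = 8
    else:
        pad_to_multiple_of = None
    return tokenizer.pad(
        examples,
        padding="max_length" if max_length is not None else "longest",
        max_length=max_length,
        pad_to_multiple_of=pad_to_multiple_of,
        return_tensors="pt",
    )
```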

4 changes: 2 additions & 2 deletions examples/by_feature/cross_validation.py
@@ -105,7 +105,7 @@ def tokenize_function(examples):

def collate_fn(examples):
# On TPU it's best to pad everything to the same length or training will be very slow.
max_length = 128 if accelerator.distributed_type == DistributedType.TPU else None
max_length = 128 if accelerator.distributed_type == DistributedType.XLA else None
# When using mixed precision we want round multiples of 8/16
if accelerator.mixed_precision == "fp8":
pad_to_multiple_of = 16
@@ -156,7 +156,7 @@ def training_function(config, args):

# If the batch size is too big we use gradient accumulation
gradient_accumulation_steps = 1
if batch_size > MAX_GPU_BATCH_SIZE and accelerator.distributed_type != DistributedType.TPU:
if batch_size > MAX_GPU_BATCH_SIZE and accelerator.distributed_type != DistributedType.XLA:
gradient_accumulation_steps = batch_size // MAX_GPU_BATCH_SIZE
batch_size = MAX_GPU_BATCH_SIZE

2 changes: 1 addition & 1 deletion examples/by_feature/deepspeed_with_config_support.py
@@ -511,7 +511,7 @@ def group_texts(examples):
optimizer = optimizer_cls(optimizer_grouped_parameters, lr=args.learning_rate)

# On TPU, the tie weights in our model have been disconnected, so we need to restore the ties.
if accelerator.distributed_type == DistributedType.TPU:
if accelerator.distributed_type == DistributedType.XLA:
model.tie_weights()

# Scheduler and math around the number of training steps.
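
As a hedged illustration of the tie-weights comment a few lines up (ordering and names are assumptions, not quotes from this script): once the model sits on the XLA device, the tying between the input embeddings and the output head can be lost, and it is restored with the `tie_weights()` helper that `transformers` models provide.

```python
model = model.to(accelerator.device)  # XLA device placement re-creates the parameters
if accelerator.distributed_type == DistributedType.XLA:
    model.tie_weights()  # re-tie input embeddings and LM head
```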
4 changes: 2 additions & 2 deletions examples/by_feature/early_stopping.py
@@ -80,7 +80,7 @@ def tokenize_function(examples):

def collate_fn(examples):
# On TPU it's best to pad everything to the same length or training will be very slow.
max_length = 128 if accelerator.distributed_type == DistributedType.TPU else None
max_length = 128 if accelerator.distributed_type == DistributedType.XLA else None
# When using mixed precision we want round multiples of 8/16
if accelerator.mixed_precision == "fp8":
pad_to_multiple_of = 16
@@ -150,7 +150,7 @@ def training_function(config, args):

# If the batch size is too big we use gradient accumulation
gradient_accumulation_steps = 1
if batch_size > MAX_GPU_BATCH_SIZE and accelerator.distributed_type != DistributedType.TPU:
if batch_size > MAX_GPU_BATCH_SIZE and accelerator.distributed_type != DistributedType.XLA:
gradient_accumulation_steps = batch_size // MAX_GPU_BATCH_SIZE
batch_size = MAX_GPU_BATCH_SIZE

4 changes: 2 additions & 2 deletions examples/by_feature/fsdp_with_peak_mem_tracking.py
@@ -208,13 +208,13 @@ def tokenize_function(examples):

# If the batch size is too big we use gradient accumulation
gradient_accumulation_steps = 1
if batch_size > MAX_GPU_BATCH_SIZE and accelerator.distributed_type != DistributedType.TPU:
if batch_size > MAX_GPU_BATCH_SIZE and accelerator.distributed_type != DistributedType.XLA:
gradient_accumulation_steps = batch_size // MAX_GPU_BATCH_SIZE
batch_size = MAX_GPU_BATCH_SIZE

def collate_fn(examples):
# On TPU it's best to pad everything to the same length or training will be very slow.
max_length = 128 if accelerator.distributed_type == DistributedType.TPU else None
max_length = 128 if accelerator.distributed_type == DistributedType.XLA else None
# When using mixed precision we want round multiples of 8/16
if accelerator.mixed_precision == "fp8":
pad_to_multiple_of = 16
4 changes: 2 additions & 2 deletions examples/by_feature/gradient_accumulation.py
@@ -80,7 +80,7 @@ def tokenize_function(examples):

def collate_fn(examples):
# On TPU it's best to pad everything to the same length or training will be very slow.
max_length = 128 if accelerator.distributed_type == DistributedType.TPU else None
max_length = 128 if accelerator.distributed_type == DistributedType.XLA else None
# When using mixed precision we want round multiples of 8/16
if accelerator.mixed_precision == "fp8":
pad_to_multiple_of = 16
@@ -125,7 +125,7 @@ def training_function(config, args):
accelerator = Accelerator(
cpu=args.cpu, mixed_precision=args.mixed_precision, gradient_accumulation_steps=gradient_accumulation_steps
)
if accelerator.distributed_type == DistributedType.TPU and gradient_accumulation_steps > 1:
if accelerator.distributed_type == DistributedType.XLA and gradient_accumulation_steps > 1:
raise NotImplementedError(
"Gradient accumulation on TPUs is currently not supported. Pass `gradient_accumulation_steps=1`"
)
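
From the caller's side this guard simply means keeping the default; a short hedged sketch, not code from this file:

```python
from accelerate import Accelerator

# On an XLA backend, keep the default of 1; the example above raises
# NotImplementedError for anything larger.
accelerator = Accelerator(gradient_accumulation_steps=1)
```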
2 changes: 1 addition & 1 deletion examples/by_feature/local_sgd.py
@@ -83,7 +83,7 @@ def tokenize_function(examples):

def collate_fn(examples):
# On TPU it's best to pad everything to the same length or training will be very slow.
max_length = 128 if accelerator.distributed_type == DistributedType.TPU else None
max_length = 128 if accelerator.distributed_type == DistributedType.XLA else None
# When using mixed precision we want round multiples of 8/16
if accelerator.mixed_precision == "fp8":
pad_to_multiple_of = 16
2 changes: 1 addition & 1 deletion examples/by_feature/megatron_lm_gpt_pretraining.py
@@ -505,7 +505,7 @@ def group_texts(examples):
)

# On TPU, the tie weights in our model have been disconnected, so we need to restore the ties.
if accelerator.distributed_type == DistributedType.TPU:
if accelerator.distributed_type == DistributedType.XLA:
model.tie_weights()

# We need to recalculate our total training steps as the size of the training dataloader may have changed.
2 changes: 1 addition & 1 deletion examples/by_feature/memory.py
@@ -86,7 +86,7 @@ def tokenize_function(examples):

def collate_fn(examples):
# On TPU it's best to pad everything to the same length or training will be very slow.
max_length = 128 if accelerator.distributed_type == DistributedType.TPU else None
max_length = 128 if accelerator.distributed_type == DistributedType.XLA else None
# When using mixed precision we want round multiples of 8/16
if accelerator.mixed_precision == "fp8":
pad_to_multiple_of = 16
4 changes: 2 additions & 2 deletions examples/by_feature/multi_process_metrics.py
@@ -87,7 +87,7 @@ def tokenize_function(examples):

def collate_fn(examples):
# On TPU it's best to pad everything to the same length or training will be very slow.
max_length = 128 if accelerator.distributed_type == DistributedType.TPU else None
max_length = 128 if accelerator.distributed_type == DistributedType.XLA else None
# When using mixed precision we want round multiples of 8/16
if accelerator.mixed_precision == "fp8":
pad_to_multiple_of = 16
@@ -138,7 +138,7 @@ def training_function(config, args):

# If the batch size is too big we use gradient accumulation
gradient_accumulation_steps = 1
if batch_size > MAX_GPU_BATCH_SIZE and accelerator.distributed_type != DistributedType.TPU:
if batch_size > MAX_GPU_BATCH_SIZE and accelerator.distributed_type != DistributedType.XLA:
gradient_accumulation_steps = batch_size // MAX_GPU_BATCH_SIZE
batch_size = MAX_GPU_BATCH_SIZE

4 changes: 2 additions & 2 deletions examples/by_feature/tracking.py
@@ -85,7 +85,7 @@ def tokenize_function(examples):

def collate_fn(examples):
# On TPU it's best to pad everything to the same length or training will be very slow.
max_length = 128 if accelerator.distributed_type == DistributedType.TPU else None
max_length = 128 if accelerator.distributed_type == DistributedType.XLA else None
# When using mixed precision we want round multiples of 8/16
if accelerator.mixed_precision == "fp8":
pad_to_multiple_of = 16
@@ -148,7 +148,7 @@ def training_function(config, args):

# If the batch size is too big we use gradient accumulation
gradient_accumulation_steps = 1
if batch_size > MAX_GPU_BATCH_SIZE and accelerator.distributed_type != DistributedType.TPU:
if batch_size > MAX_GPU_BATCH_SIZE and accelerator.distributed_type != DistributedType.XLA:
gradient_accumulation_steps = batch_size // MAX_GPU_BATCH_SIZE
batch_size = MAX_GPU_BATCH_SIZE

4 changes: 2 additions & 2 deletions examples/complete_nlp_example.py
@@ -102,13 +102,13 @@ def tokenize_function(examples):

# If the batch size is too big we use gradient accumulation
gradient_accumulation_steps = 1
if batch_size > MAX_GPU_BATCH_SIZE and accelerator.distributed_type != DistributedType.TPU:
if batch_size > MAX_GPU_BATCH_SIZE and accelerator.distributed_type != DistributedType.XLA:
gradient_accumulation_steps = batch_size // MAX_GPU_BATCH_SIZE
batch_size = MAX_GPU_BATCH_SIZE

def collate_fn(examples):
# On TPU it's best to pad everything to the same length or training will be very slow.
max_length = 128 if accelerator.distributed_type == DistributedType.TPU else None
max_length = 128 if accelerator.distributed_type == DistributedType.XLA else None
# When using mixed precision we want round multiples of 8/16
if accelerator.mixed_precision == "fp8":
pad_to_multiple_of = 16
6 changes: 3 additions & 3 deletions examples/nlp_example.py
@@ -77,8 +77,8 @@ def tokenize_function(examples):
tokenized_datasets = tokenized_datasets.rename_column("label", "labels")

def collate_fn(examples):
# On TPU it's best to pad everything to the same length or training will be very slow.
max_length = 128 if accelerator.distributed_type == DistributedType.TPU else None
# For torch_xla, it's best to pad everything to the same length or training will be very slow.
max_length = 128 if accelerator.distributed_type == DistributedType.XLA else None
# When using mixed precision we want round multiples of 8/16
if accelerator.mixed_precision == "fp8":
pad_to_multiple_of = 16
@@ -123,7 +123,7 @@ def training_function(config, args):

# If the batch size is too big we use gradient accumulation
gradient_accumulation_steps = 1
if batch_size > MAX_GPU_BATCH_SIZE and accelerator.distributed_type != DistributedType.TPU:
if batch_size > MAX_GPU_BATCH_SIZE and accelerator.distributed_type != DistributedType.XLA:
gradient_accumulation_steps = batch_size // MAX_GPU_BATCH_SIZE
batch_size = MAX_GPU_BATCH_SIZE

44 changes: 29 additions & 15 deletions src/accelerate/accelerator.py
@@ -82,7 +82,7 @@
is_msamp_available,
is_npu_available,
is_torch_version,
is_tpu_available,
is_torch_xla_available,
is_xpu_available,
load_fsdp_model,
load_fsdp_optimizer,
@@ -133,7 +133,8 @@
from torch.distributed.algorithms.join import Join


if is_tpu_available(check_device=False):
if is_torch_xla_available():
import torch_xla.amp as xamp
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp

@@ -397,7 +398,7 @@ def __init__(
if (
(mixed_precision != "bf16")
and getattr(self.state, "downcast_bfloat", False)
and (self.state.distributed_type != DistributedType.TPU)
and (self.state.distributed_type != DistributedType.XLA)
):
raise ValueError("Can only use `downcast_bf16` when using `mixed_precision='bf16'` and on a TPU")

@@ -414,7 +415,7 @@ def __init__(
self.gradient_state = GradientState(
gradient_accumulation_plugin=gradient_accumulation_plugin,
)
if self.state.distributed_type == DistributedType.TPU:
if self.state.distributed_type == DistributedType.XLA:
if self.gradient_state.num_steps != 1:
raise ValueError(
"Gradient accumulation is not supported on TPU. Please set `gradient_accumulation_steps` to 1 and don't pass in a `GradientAccumulationPlugin` object."
@@ -436,13 +437,17 @@ def __init__(
and self.distributed_type not in (DistributedType.DEEPSPEED, DistributedType.MEGATRON_LM)
):
self.native_amp = True
if self.device.type not in ("xpu", "cuda", "mps", "npu"):
if self.device.type not in ("xpu", "cuda", "mps", "npu", "xla") or is_torch_xla_available(
check_is_tpu=True
):
raise ValueError(f"fp16 mixed precision requires a GPU (not {self.device.type!r}).")
kwargs = self.scaler_handler.to_kwargs() if self.scaler_handler is not None else {}
if self.distributed_type == DistributedType.FSDP:
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler

self.scaler = ShardedGradScaler(**kwargs)
elif is_torch_xla_available(check_is_gpu=True):
self.scaler = xamp.GradScaler(**kwargs)
elif is_npu_available():
self.scaler = torch.npu.amp.GradScaler(**kwargs)
else:
@@ -456,7 +461,7 @@ def __init__(
self.native_amp = True
else:
self.native_amp = is_bf16_available(True)
if mixed_precision == "bf16" and not self.native_amp and not is_tpu_available():
if mixed_precision == "bf16" and not self.native_amp and not is_torch_xla_available(check_is_gpu=True):
raise ValueError("bf16 mixed precision requires PyTorch >= 1.10 and a supported device.")

# Start of internal step tracking
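
To make the new fp16 path above concrete, here is a hedged usage sketch; it assumes a GPU-backed torch_xla build is installed and active, and it is not code taken from the PR:

```python
from accelerate import Accelerator

# With torch_xla on an XLA:GPU device, fp16 is now allowed: the branch above
# builds a torch_xla.amp.GradScaler instead of raising.
accelerator = Accelerator(mixed_precision="fp16")

# On an XLA:TPU device the same call still raises
# "fp16 mixed precision requires a GPU", per the check_is_tpu=True check above.
```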
@@ -1193,7 +1198,7 @@ def prepare(self, *args, device_placement=None):
# On TPUs, putting the model on the XLA device will create new parameters, so the corresponding optimizer will
# have parameters disconnected from the model (so no training :-( ).
# If the model and optimizer have parameters on different devices we raise an error.
if self.distributed_type == DistributedType.TPU:
if self.distributed_type == DistributedType.XLA:
model_device, optimizer_device = self._get_devices()
if model_device is not None and optimizer_device is not None and model_device != optimizer_device:
raise ValueError(
@@ -1205,7 +1210,7 @@ def prepare(self, *args, device_placement=None):
)

# If we're dealing with device placement, this deals with that by...
tpu_should_fix_optimizer = self.device_placement and self.distributed_type == DistributedType.TPU
tpu_should_fix_optimizer = self.device_placement and self.distributed_type == DistributedType.XLA
if tpu_should_fix_optimizer or (self.mixed_precision == "fp8" and self.fp8_recipe_handler.backend == "TE"):
# 1. grabbing old model parameters
old_named_params = self._get_named_parameters(*args)
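
The ordering this check enforces on XLA, sketched under the assumption of a standard `torch.optim` optimizer and an `accelerator`/`model` already in scope (not code from this PR):

```python
import torch

model = model.to(accelerator.device)  # placing the model on the XLA device creates new parameters
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-5)  # so build the optimizer afterwards
# Prepare both together; the check above verifies their parameters live on the same device.
model, optimizer = accelerator.prepare(model, optimizer)
```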
@@ -1406,7 +1411,7 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e
elif self.distributed_type == DistributedType.MULTI_CPU:
kwargs = self.ddp_handler.to_kwargs() if self.ddp_handler is not None else {}
model = torch.nn.parallel.DistributedDataParallel(model, **kwargs)
elif self.distributed_type == DistributedType.TPU and self.state.fork_launched:
elif self.distributed_type == DistributedType.XLA and self.state.fork_launched:
model = xmp.MpModelWrapper(model).to(self.device)
# torch.compile should be called last and only if the model isn't already compiled.
if self.state.dynamo_plugin.backend != DynamoBackend.NO and not is_compiled_module(model):
@@ -1843,7 +1848,7 @@ def prepare_data_loader(
self._dataloaders.append(data_loader)
return data_loader
if device_placement is None:
device_placement = self.device_placement if self.distributed_type != DistributedType.TPU else False
device_placement = self.device_placement if self.distributed_type != DistributedType.XLA else False
prepared_data_loader = prepare_data_loader(
data_loader,
self.device,
@@ -2056,10 +2061,6 @@ def unscale_gradients(self, optimizer=None):
for opt in optimizer:
while isinstance(opt, AcceleratedOptimizer):
opt = opt.optimizer
# Reduce gradients first for XLA
if self.distributed_type == DistributedType.TPU:
gradients = xm._fetch_gradients(opt)
self.reduce(gradients, scale=1.0 / self.num_processes)
self.scaler.unscale_(opt)

def clip_grad_norm_(self, parameters, max_norm, norm_type=2):
@@ -2097,6 +2098,19 @@ def clip_grad_norm_(self, parameters, max_norm, norm_type=2):
# `accelerator.backward(loss)` is doing that automatically. Therefore, its implementation is not needed
# We cannot return the gradient norm because DeepSpeed does it.
return None
elif self.distributed_type == DistributedType.XLA:
# Reduce gradients first for XLA
for acc_opt in self._optimizers:
if not acc_opt.gradient_state.is_xla_gradients_synced:
opt = acc_opt
while isinstance(opt, AcceleratedOptimizer):
opt = opt.optimizer
gradients = xm._fetch_gradients(opt)
# Use xm.all_reduce to perform an in-place all-reduce. Recursively all-reducing each tensor
# one by one in self.reduce is not in-place.
xm.all_reduce("sum", gradients, scale=1.0 / self.num_processes)
# Set is_xla_gradients_synced to True to avoid all-reduce twice in the AcceleratedOptimizer step.
acc_opt.gradient_state.is_xla_gradients_synced = True
self.unscale_gradients()
return torch.nn.utils.clip_grad_norm_(parameters, max_norm, norm_type=norm_type)
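
A hedged training-loop excerpt showing where this new XLA branch runs (the model, optimizer, and dataloader names are illustrative):

```python
for batch in train_dataloader:
    outputs = model(**batch)
    accelerator.backward(outputs.loss)
    # On DistributedType.XLA this now averages the gradients across processes exactly
    # once (marking them as synced) before unscaling and clipping them.
    accelerator.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()
    optimizer.zero_grad()
```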

@@ -2714,7 +2728,7 @@ def _inner(folder):
os.makedirs(output_dir, exist_ok=True)
logger.info(f"Saving current state to {output_dir}")

if self.distributed_type == DistributedType.TPU:
if self.distributed_type == DistributedType.XLA:
# Finish running the previous step before checkpointing
xm.mark_step()
