Skip to content

Commit

Permalink
Make torch xla available on GPU (#29334)
Browse files Browse the repository at this point in the history
* add USE_TORCH_XLA env

* rename torch_tpu to torch_xla

* better is_torch_xla_available; fix some fsdp and performance issues

* fix format

* fix bug when pjrt_device is cpu

* fix bug

* fix the deprecation handling

---------

Co-authored-by: anw90 <[email protected]>
Co-authored-by: wangang.wa <[email protected]>
  • Loading branch information
3 people authored Mar 11, 2024
1 parent 9a3f4d4 commit 873d9bb
Show file tree
Hide file tree
Showing 25 changed files with 120 additions and 77 deletions.
2 changes: 1 addition & 1 deletion docs/source/de/testing.md
Original file line number Diff line number Diff line change
Expand Up @@ -452,7 +452,7 @@ Dekorateure werden verwendet, um die Anforderungen von Tests in Bezug auf CPU/GP
- `require_torch_multi_gpu` - wie `require_torch` und zusätzlich mindestens 2 GPUs erforderlich
- `require_torch_non_multi_gpu` - wie `require_torch` plus benötigt 0 oder 1 GPUs
- `require_torch_up_to_2_gpus` - wie `require_torch` plus erfordert 0 oder 1 oder 2 GPUs
- `require_torch_tpu` - wie `require_torch` plus erfordert mindestens 1 TPU
- `require_torch_xla` - wie `require_torch` plus erfordert mindestens 1 TPU

Lassen Sie uns die GPU-Anforderungen in der folgenden Tabelle darstellen:

Expand Down
2 changes: 1 addition & 1 deletion docs/source/en/testing.md
Original file line number Diff line number Diff line change
Expand Up @@ -451,7 +451,7 @@ decorators are used to set the requirements of tests CPU/GPU/TPU-wise:
- `require_torch_multi_gpu` - as `require_torch` plus requires at least 2 GPUs
- `require_torch_non_multi_gpu` - as `require_torch` plus requires 0 or 1 GPUs
- `require_torch_up_to_2_gpus` - as `require_torch` plus requires 0 or 1 or 2 GPUs
- `require_torch_tpu` - as `require_torch` plus requires at least 1 TPU
- `require_torch_xla` - as `require_torch` plus requires at least 1 TPU

Let's depict the GPU requirements in the following table:

Expand Down
2 changes: 1 addition & 1 deletion docs/source/ja/testing.md
Original file line number Diff line number Diff line change
Expand Up @@ -424,7 +424,7 @@ CUDA_VISIBLE_DEVICES="1" pytest tests/utils/test_logging.py
- `require_torch_multi_gpu` - `require_torch` に加えて、少なくとも2つのGPUが必要です。
- `require_torch_non_multi_gpu` - `require_torch` に加えて、0または1つのGPUが必要です。
- `require_torch_up_to_2_gpus` - `require_torch` に加えて、0、1、または2つのGPUが必要です。
- `require_torch_tpu` - `require_torch` に加えて、少なくとも1つのTPUが必要です。
- `require_torch_xla` - `require_torch` に加えて、少なくとも1つのTPUが必要です。

以下の表にGPUの要件を示します:

Expand Down
2 changes: 1 addition & 1 deletion docs/source/ko/testing.md
Original file line number Diff line number Diff line change
Expand Up @@ -452,7 +452,7 @@ CUDA_VISIBLE_DEVICES="1" pytest tests/utils/test_logging.py
- `require_torch_multi_gpu` - `require_torch`에 추가로 적어도 2개의 GPU가 필요합니다.
- `require_torch_non_multi_gpu` - `require_torch`에 추가로 0개 또는 1개의 GPU가 필요합니다.
- `require_torch_up_to_2_gpus` - `require_torch`에 추가로 0개, 1개 또는 2개의 GPU가 필요합니다.
- `require_torch_tpu` - `require_torch`에 추가로 적어도 1개의 TPU가 필요합니다.
- `require_torch_xla` - `require_torch`에 추가로 적어도 1개의 TPU가 필요합니다.

GPU 요구 사항을 표로 정리하면 아래와 같습니디ㅏ:

Expand Down
4 changes: 2 additions & 2 deletions examples/legacy/seq2seq/seq2seq_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
)
from transformers.trainer_pt_utils import get_tpu_sampler
from transformers.training_args import ParallelMode
from transformers.utils import is_torch_tpu_available
from transformers.utils import is_torch_xla_available


logger = logging.get_logger(__name__)
Expand Down Expand Up @@ -135,7 +135,7 @@ def _get_lr_scheduler(self, num_training_steps):
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
if isinstance(self.train_dataset, torch.utils.data.IterableDataset):
return None
elif is_torch_tpu_available():
elif is_torch_xla_available():
return get_tpu_sampler(self.train_dataset)
else:
if self.args.sortish_sampler:
Expand Down
6 changes: 3 additions & 3 deletions examples/pytorch/language-modeling/run_clm.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
Trainer,
TrainingArguments,
default_data_collator,
is_torch_tpu_available,
is_torch_xla_available,
set_seed,
)
from transformers.testing_utils import CaptureLogger
Expand Down Expand Up @@ -602,9 +602,9 @@ def compute_metrics(eval_preds):
tokenizer=tokenizer,
# Data collator will default to DataCollatorWithPadding, so we change it.
data_collator=default_data_collator,
compute_metrics=compute_metrics if training_args.do_eval and not is_torch_tpu_available() else None,
compute_metrics=compute_metrics if training_args.do_eval and not is_torch_xla_available() else None,
preprocess_logits_for_metrics=preprocess_logits_for_metrics
if training_args.do_eval and not is_torch_tpu_available()
if training_args.do_eval and not is_torch_xla_available()
else None,
)

Expand Down
6 changes: 3 additions & 3 deletions examples/pytorch/language-modeling/run_mlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
HfArgumentParser,
Trainer,
TrainingArguments,
is_torch_tpu_available,
is_torch_xla_available,
set_seed,
)
from transformers.trainer_utils import get_last_checkpoint
Expand Down Expand Up @@ -620,9 +620,9 @@ def compute_metrics(eval_preds):
eval_dataset=eval_dataset if training_args.do_eval else None,
tokenizer=tokenizer,
data_collator=data_collator,
compute_metrics=compute_metrics if training_args.do_eval and not is_torch_tpu_available() else None,
compute_metrics=compute_metrics if training_args.do_eval and not is_torch_xla_available() else None,
preprocess_logits_for_metrics=preprocess_logits_for_metrics
if training_args.do_eval and not is_torch_tpu_available()
if training_args.do_eval and not is_torch_xla_available()
else None,
)

Expand Down
4 changes: 2 additions & 2 deletions examples/pytorch/old_test_xla_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from time import time
from unittest.mock import patch

from transformers.testing_utils import TestCasePlus, require_torch_tpu
from transformers.testing_utils import TestCasePlus, require_torch_xla


logging.basicConfig(level=logging.DEBUG)
Expand All @@ -44,7 +44,7 @@ def get_results(output_dir):
logger.addHandler(stream_handler)


@require_torch_tpu
@require_torch_xla
class TorchXLAExamplesTests(TestCasePlus):
def test_run_glue(self):
import xla_spawn
Expand Down
4 changes: 2 additions & 2 deletions examples/pytorch/question-answering/trainer_qa.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,11 @@
import math
import time

from transformers import Trainer, is_torch_tpu_available
from transformers import Trainer, is_torch_xla_available
from transformers.trainer_utils import PredictionOutput, speed_metrics


if is_torch_tpu_available(check_device=False):
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met

Expand Down
4 changes: 2 additions & 2 deletions examples/pytorch/question-answering/trainer_seq2seq_qa.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@

from torch.utils.data import Dataset

from transformers import Seq2SeqTrainer, is_torch_tpu_available
from transformers import Seq2SeqTrainer, is_torch_xla_available
from transformers.trainer_utils import PredictionOutput, speed_metrics


if is_torch_tpu_available(check_device=False):
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,13 @@
import torch
from torch.utils.data import DataLoader

from transformers import Trainer, is_torch_tpu_available
from transformers import Trainer, is_torch_xla_available
from transformers.trainer_utils import PredictionOutput


logger = logging.getLogger(__name__)

if is_torch_tpu_available(check_device=False):
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met

Expand Down
2 changes: 2 additions & 0 deletions src/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1093,6 +1093,7 @@
"is_torch_npu_available",
"is_torch_tpu_available",
"is_torchvision_available",
"is_torch_xla_available",
"is_torch_xpu_available",
"is_vision_available",
"logging",
Expand Down Expand Up @@ -5897,6 +5898,7 @@
is_torch_neuroncore_available,
is_torch_npu_available,
is_torch_tpu_available,
is_torch_xla_available,
is_torch_xpu_available,
is_torchvision_available,
is_vision_available,
Expand Down
8 changes: 4 additions & 4 deletions src/transformers/benchmark/benchmark_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from ..utils import (
cached_property,
is_torch_available,
is_torch_tpu_available,
is_torch_xla_available,
is_torch_xpu_available,
logging,
requires_backends,
Expand All @@ -31,7 +31,7 @@
if is_torch_available():
import torch

if is_torch_tpu_available(check_device=False):
if is_torch_xla_available():
import torch_xla.core.xla_model as xm


Expand Down Expand Up @@ -88,7 +88,7 @@ def _setup_devices(self) -> Tuple["torch.device", int]:
if not self.cuda:
device = torch.device("cpu")
n_gpu = 0
elif is_torch_tpu_available():
elif is_torch_xla_available():
device = xm.xla_device()
n_gpu = 0
elif is_torch_xpu_available():
Expand All @@ -101,7 +101,7 @@ def _setup_devices(self) -> Tuple["torch.device", int]:

@property
def is_tpu(self):
return is_torch_tpu_available() and self.tpu
return is_torch_xla_available() and self.tpu

@property
def device_idx(self) -> int:
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@
is_torch_fx_proxy,
is_torch_mps_available,
is_torch_tf32_available,
is_torch_tpu_available,
is_torch_xla_available,
is_torchaudio_available,
is_training_run_on_sagemaker,
is_vision_available,
Expand Down
4 changes: 2 additions & 2 deletions src/transformers/integrations/integration_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@
from ..trainer_callback import ProgressCallback, TrainerCallback # noqa: E402
from ..trainer_utils import PREFIX_CHECKPOINT_DIR, BestRun, IntervalStrategy # noqa: E402
from ..training_args import ParallelMode # noqa: E402
from ..utils import ENV_VARS_TRUE_VALUES, is_torch_tpu_available # noqa: E402
from ..utils import ENV_VARS_TRUE_VALUES, is_torch_xla_available # noqa: E402


# Integration functions:
Expand Down Expand Up @@ -752,7 +752,7 @@ def setup(self, args, state, model, **kwargs):

# keep track of model topology and gradients, unsupported on TPU
_watch_model = os.getenv("WANDB_WATCH", "false")
if not is_torch_tpu_available() and _watch_model in ("all", "parameters", "gradients"):
if not is_torch_xla_available() and _watch_model in ("all", "parameters", "gradients"):
self._wandb.watch(model, log=_watch_model, log_freq=max(100, state.logging_steps))
self._wandb.run._label(code="transformers_trainer")

Expand Down
4 changes: 2 additions & 2 deletions src/transformers/integrations/tpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@

from torch.utils.data import DataLoader

from ..utils import is_torch_tpu_available
from ..utils import is_torch_xla_available


def tpu_spmd_dataloader(dataloader: DataLoader):
if is_torch_tpu_available():
if is_torch_xla_available():
import torch_xla.distributed.parallel_loader as pl

assert isinstance(
Expand Down
8 changes: 4 additions & 4 deletions src/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@
is_remote_url,
is_safetensors_available,
is_torch_sdpa_available,
is_torch_tpu_available,
is_torch_xla_available,
logging,
replace_return_docstrings,
strtobool,
Expand Down Expand Up @@ -246,10 +246,10 @@ def get_parameter_dtype(parameter: Union[nn.Module, GenerationMixin, "ModuleUtil
# Adding fix for https://github.com/pytorch/xla/issues/4152
# Fixes issue where the model code passes a value that is out of range for XLA_USE_BF16=1
# and XLA_DOWNCAST_BF16=1 so the conversion would cast it to -inf
# NOTE: `is_torch_tpu_available()` is checked last as it induces a graph break in torch dynamo
if XLA_USE_BF16 in ENV_VARS_TRUE_VALUES and is_torch_tpu_available():
# NOTE: `is_torch_xla_available()` is checked last as it induces a graph break in torch dynamo
if XLA_USE_BF16 in ENV_VARS_TRUE_VALUES and is_torch_xla_available():
return torch.bfloat16
if XLA_DOWNCAST_BF16 in ENV_VARS_TRUE_VALUES and is_torch_tpu_available():
if XLA_DOWNCAST_BF16 in ENV_VARS_TRUE_VALUES and is_torch_xla_available():
if t.dtype == torch.float:
return torch.bfloat16
if t.dtype == torch.double:
Expand Down
4 changes: 2 additions & 2 deletions src/transformers/pytorch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from safetensors.torch import storage_ptr, storage_size
from torch import nn

from .utils import is_torch_tpu_available, logging
from .utils import is_torch_xla_available, logging


ALL_LAYERNORM_LAYERS = [nn.LayerNorm]
Expand Down Expand Up @@ -282,7 +282,7 @@ def id_tensor_storage(tensor: torch.Tensor) -> Tuple[torch.device, int, int]:
guaranteed to be unique and constant for this tensor's storage during its lifetime. Two tensor storages with
non-overlapping lifetimes may have the same id.
"""
if tensor.device.type == "xla" and is_torch_tpu_available():
if tensor.device.type == "xla" and is_torch_xla_available():
# NOTE: xla tensors dont have storage
# use some other unique id to distinguish.
# this is a XLA tensor, it must be created using torch_xla's
Expand Down
8 changes: 4 additions & 4 deletions src/transformers/testing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@
is_torch_sdpa_available,
is_torch_tensorrt_fx_available,
is_torch_tf32_available,
is_torch_tpu_available,
is_torch_xla_available,
is_torch_xpu_available,
is_torchaudio_available,
is_torchdynamo_available,
Expand Down Expand Up @@ -733,11 +733,11 @@ def require_torch_up_to_2_accelerators(test_case):
(test_case)


def require_torch_tpu(test_case):
def require_torch_xla(test_case):
"""
Decorator marking a test that requires a TPU (in PyTorch).
Decorator marking a test that requires TorchXLA (in PyTorch).
"""
return unittest.skipUnless(is_torch_tpu_available(check_device=False), "test requires PyTorch TPU")(test_case)
return unittest.skipUnless(is_torch_xla_available(), "test requires TorchXLA")(test_case)


def require_torch_neuroncore(test_case):
Expand Down
Loading

0 comments on commit 873d9bb

Please sign in to comment.