Make torch xla available on GPU #29334

Merged: 7 commits, Mar 11, 2024
2 changes: 1 addition & 1 deletion docs/source/de/testing.md
@@ -452,7 +452,7 @@ Dekorateure werden verwendet, um die Anforderungen von Tests in Bezug auf CPU/GP
- `require_torch_multi_gpu` - wie `require_torch` und zusätzlich mindestens 2 GPUs erforderlich
- `require_torch_non_multi_gpu` - wie `require_torch` plus benötigt 0 oder 1 GPUs
- `require_torch_up_to_2_gpus` - wie `require_torch` plus erfordert 0 oder 1 oder 2 GPUs
- `require_torch_tpu` - wie `require_torch` plus erfordert mindestens 1 TPU
- `require_torch_xla` - wie `require_torch` plus erfordert mindestens 1 TPU

Lassen Sie uns die GPU-Anforderungen in der folgenden Tabelle darstellen:

2 changes: 1 addition & 1 deletion docs/source/en/testing.md
@@ -451,7 +451,7 @@ decorators are used to set the requirements of tests CPU/GPU/TPU-wise:
- `require_torch_multi_gpu` - as `require_torch` plus requires at least 2 GPUs
- `require_torch_non_multi_gpu` - as `require_torch` plus requires 0 or 1 GPUs
- `require_torch_up_to_2_gpus` - as `require_torch` plus requires 0 or 1 or 2 GPUs
- `require_torch_tpu` - as `require_torch` plus requires at least 1 TPU
- `require_torch_xla` - as `require_torch` plus requires at least 1 TPU

Let's depict the GPU requirements in the following table:

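For illustration only (not part of the diff): a test gated by the renamed decorator would look roughly like the sketch below. The test body, tiny checkpoint, and assertions are hypothetical; only `require_torch_xla` and `TestCasePlus` come from `transformers.testing_utils`.

```python
import torch

from transformers import AutoModel
from transformers.testing_utils import TestCasePlus, require_torch_xla


@require_torch_xla
class XlaSmokeTest(TestCasePlus):
    def test_forward_on_xla_device(self):
        # Safe to import here: the decorator skips the test when torch_xla is unavailable.
        import torch_xla.core.xla_model as xm

        device = xm.xla_device()
        model = AutoModel.from_pretrained("hf-internal-testing/tiny-random-bert").to(device)
        input_ids = torch.ones((1, 4), dtype=torch.long, device=device)
        outputs = model(input_ids)
        self.assertEqual(outputs.last_hidden_state.shape[0], 1)
```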
2 changes: 1 addition & 1 deletion docs/source/ja/testing.md
@@ -424,7 +424,7 @@ CUDA_VISIBLE_DEVICES="1" pytest tests/utils/test_logging.py
- `require_torch_multi_gpu` - `require_torch` に加えて、少なくとも2つのGPUが必要です。
- `require_torch_non_multi_gpu` - `require_torch` に加えて、0または1つのGPUが必要です。
- `require_torch_up_to_2_gpus` - `require_torch` に加えて、0、1、または2つのGPUが必要です。
- `require_torch_tpu` - `require_torch` に加えて、少なくとも1つのTPUが必要です。
- `require_torch_xla` - `require_torch` に加えて、少なくとも1つのTPUが必要です。

以下の表にGPUの要件を示します:

2 changes: 1 addition & 1 deletion docs/source/ko/testing.md
@@ -452,7 +452,7 @@ CUDA_VISIBLE_DEVICES="1" pytest tests/utils/test_logging.py
- `require_torch_multi_gpu` - `require_torch`에 추가로 적어도 2개의 GPU가 필요합니다.
- `require_torch_non_multi_gpu` - `require_torch`에 추가로 0개 또는 1개의 GPU가 필요합니다.
- `require_torch_up_to_2_gpus` - `require_torch`에 추가로 0개, 1개 또는 2개의 GPU가 필요합니다.
- `require_torch_tpu` - `require_torch`에 추가로 적어도 1개의 TPU가 필요합니다.
- `require_torch_xla` - `require_torch`에 추가로 적어도 1개의 TPU가 필요합니다.

GPU 요구 사항을 표로 정리하면 아래와 같습니다:

4 changes: 2 additions & 2 deletions examples/legacy/seq2seq/seq2seq_trainer.py
@@ -32,7 +32,7 @@
)
from transformers.trainer_pt_utils import get_tpu_sampler
from transformers.training_args import ParallelMode
from transformers.utils import is_torch_tpu_available
from transformers.utils import is_torch_xla_available


logger = logging.get_logger(__name__)
@@ -135,7 +135,7 @@ def _get_lr_scheduler(self, num_training_steps):
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
if isinstance(self.train_dataset, torch.utils.data.IterableDataset):
return None
elif is_torch_tpu_available():
elif is_torch_xla_available():
Collaborator: This isn't equivalent; previously we were checking for a device, but by default that isn't happening anymore.

Contributor Author: This is to align with a PR in the accelerate library. If users do not wish to use torch_xla in an environment where torch_xla is installed, they can configure that with the USE_TORCH_XLA environment variable, which is also the purpose of this PR.

return get_tpu_sampler(self.train_dataset)
else:
if self.args.sortish_sampler:
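To make the thread above concrete, here is a minimal sketch of an environment-variable gate of this kind. The helper name is hypothetical, and the exact semantics of `USE_TORCH_XLA` in transformers/accelerate are inferred from the author's comment rather than taken from the implementation; the key point is that, unlike the old `is_torch_tpu_available()`, nothing here probes for a device.

```python
import importlib.util
import os

ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"}


def xla_available_sketch() -> bool:
    # Opt-out switch: users with torch_xla installed can still disable the XLA path.
    if os.environ.get("USE_TORCH_XLA", "1").upper() not in ENV_VARS_TRUE_VALUES:
        return False
    # Availability now means "torch_xla is importable", not "an XLA device responded".
    return importlib.util.find_spec("torch_xla") is not None
```

Under this reading, opting out is just a matter of exporting `USE_TORCH_XLA=0` before launching training.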
6 changes: 3 additions & 3 deletions examples/pytorch/language-modeling/run_clm.py
@@ -46,7 +46,7 @@
Trainer,
TrainingArguments,
default_data_collator,
is_torch_tpu_available,
is_torch_xla_available,
set_seed,
)
from transformers.testing_utils import CaptureLogger
@@ -602,9 +602,9 @@ def compute_metrics(eval_preds):
tokenizer=tokenizer,
# Data collator will default to DataCollatorWithPadding, so we change it.
data_collator=default_data_collator,
compute_metrics=compute_metrics if training_args.do_eval and not is_torch_tpu_available() else None,
compute_metrics=compute_metrics if training_args.do_eval and not is_torch_xla_available() else None,
preprocess_logits_for_metrics=preprocess_logits_for_metrics
if training_args.do_eval and not is_torch_tpu_available()
if training_args.do_eval and not is_torch_xla_available()
else None,
)

6 changes: 3 additions & 3 deletions examples/pytorch/language-modeling/run_mlm.py
@@ -45,7 +45,7 @@
HfArgumentParser,
Trainer,
TrainingArguments,
is_torch_tpu_available,
is_torch_xla_available,
set_seed,
)
from transformers.trainer_utils import get_last_checkpoint
@@ -620,9 +620,9 @@ def compute_metrics(eval_preds):
eval_dataset=eval_dataset if training_args.do_eval else None,
tokenizer=tokenizer,
data_collator=data_collator,
compute_metrics=compute_metrics if training_args.do_eval and not is_torch_tpu_available() else None,
compute_metrics=compute_metrics if training_args.do_eval and not is_torch_xla_available() else None,
preprocess_logits_for_metrics=preprocess_logits_for_metrics
if training_args.do_eval and not is_torch_tpu_available()
if training_args.do_eval and not is_torch_xla_available()
else None,
)

4 changes: 2 additions & 2 deletions examples/pytorch/old_test_xla_examples.py
@@ -21,7 +21,7 @@
from time import time
from unittest.mock import patch

from transformers.testing_utils import TestCasePlus, require_torch_tpu
from transformers.testing_utils import TestCasePlus, require_torch_xla


logging.basicConfig(level=logging.DEBUG)
@@ -44,7 +44,7 @@ def get_results(output_dir):
logger.addHandler(stream_handler)


@require_torch_tpu
@require_torch_xla
class TorchXLAExamplesTests(TestCasePlus):
def test_run_glue(self):
import xla_spawn
4 changes: 2 additions & 2 deletions examples/pytorch/question-answering/trainer_qa.py
@@ -18,11 +18,11 @@
import math
import time

from transformers import Trainer, is_torch_tpu_available
from transformers import Trainer, is_torch_xla_available
from transformers.trainer_utils import PredictionOutput, speed_metrics


if is_torch_tpu_available(check_device=False):
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met

4 changes: 2 additions & 2 deletions examples/pytorch/question-answering/trainer_seq2seq_qa.py
@@ -21,11 +21,11 @@

from torch.utils.data import Dataset

from transformers import Seq2SeqTrainer, is_torch_tpu_available
from transformers import Seq2SeqTrainer, is_torch_xla_available
from transformers.trainer_utils import PredictionOutput, speed_metrics


if is_torch_tpu_available(check_device=False):
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met

(additional changed file, path not shown in this view)
@@ -24,13 +24,13 @@
import torch
from torch.utils.data import DataLoader

from transformers import Trainer, is_torch_tpu_available
from transformers import Trainer, is_torch_xla_available
from transformers.trainer_utils import PredictionOutput


logger = logging.getLogger(__name__)

if is_torch_tpu_available(check_device=False):
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met

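Because the guarded import above pulls in `torch_xla.debug.metrics`, a common follow-up is dumping the accumulated XLA metrics after an evaluation pass. A rough sketch of that usage (the helper function is hypothetical; `met.metrics_report()` and `xm.master_print()` are standard torch_xla calls):

```python
from transformers import is_torch_xla_available

if is_torch_xla_available():
    import torch_xla.core.xla_model as xm
    import torch_xla.debug.metrics as met

    def log_xla_debug_info():
        # Compile counts, transfer times, etc., printed on the master ordinal only
        # so multi-process runs don't interleave their reports.
        xm.master_print(met.metrics_report())
```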
2 changes: 2 additions & 0 deletions src/transformers/__init__.py
@@ -1092,6 +1092,7 @@
"is_torch_npu_available",
"is_torch_tpu_available",
"is_torchvision_available",
Collaborator: We need to keep it whilst it's still being deprecated.

Suggested change:
-    "is_torchvision_available",
+    "is_torch_tpu_available",
+    "is_torchvision_available",

Contributor Author: Done

"is_torch_xla_available",
"is_torch_xpu_available",
"is_vision_available",
"logging",
@@ -5895,6 +5896,7 @@
is_torch_neuroncore_available,
is_torch_npu_available,
is_torch_tpu_available,
is_torch_xla_available,
Collaborator:

Suggested change:
-    is_torch_xla_available,
+    is_torch_tpu_available,
+    is_torch_xla_available,

Contributor Author: Done

is_torch_xpu_available,
is_torchvision_available,
is_vision_available,
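The review above asks for `is_torch_tpu_available` to stay exported while it is deprecated. A common shape for such a shim is sketched here; the exact warning text and signature used by transformers are assumptions, not the library's actual code:

```python
import warnings

from transformers import is_torch_xla_available


def is_torch_tpu_available_sketch(check_device: bool = True) -> bool:
    """Hypothetical backward-compatible alias for the deprecated helper."""
    # check_device is kept only for signature compatibility; unused in this sketch.
    warnings.warn(
        "`is_torch_tpu_available` is deprecated; use `is_torch_xla_available` instead.",
        FutureWarning,
    )
    return is_torch_xla_available()
```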
8 changes: 4 additions & 4 deletions src/transformers/benchmark/benchmark_args.py
@@ -20,7 +20,7 @@
from ..utils import (
cached_property,
is_torch_available,
is_torch_tpu_available,
is_torch_xla_available,
is_torch_xpu_available,
logging,
requires_backends,
@@ -31,7 +31,7 @@
if is_torch_available():
import torch

if is_torch_tpu_available(check_device=False):
if is_torch_xla_available():
import torch_xla.core.xla_model as xm


@@ -88,7 +88,7 @@ def _setup_devices(self) -> Tuple["torch.device", int]:
if not self.cuda:
device = torch.device("cpu")
n_gpu = 0
elif is_torch_tpu_available():
elif is_torch_xla_available():
device = xm.xla_device()
n_gpu = 0
elif is_torch_xpu_available():
@@ -101,7 +101,7 @@ def _setup_devices(self) -> Tuple["torch.device", int]:

@property
def is_tpu(self):
return is_torch_tpu_available() and self.tpu
return is_torch_xla_available() and self.tpu

@property
def device_idx(self) -> int:
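The device-selection change above is what allows the XLA branch to run on a GPU host as well as on TPUs. With a CUDA build of torch_xla, the backend is typically chosen via the PJRT_DEVICE environment variable (a torch_xla convention, not something this PR introduces); a minimal sketch under that assumption:

```python
import os

import torch
from transformers import is_torch_xla_available

# Assumption: a CUDA-enabled torch_xla build is installed on this machine.
os.environ.setdefault("PJRT_DEVICE", "CUDA")  # use "TPU" on TPU hosts

if is_torch_xla_available():
    import torch_xla.core.xla_model as xm

    device = xm.xla_device()  # XLA device backed by the local GPU in this setup
else:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

x = torch.ones(2, 2, device=device)  # tensors are placed the same way on either path
```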
2 changes: 1 addition & 1 deletion src/transformers/file_utils.py
@@ -121,7 +121,7 @@
is_torch_fx_proxy,
is_torch_mps_available,
is_torch_tf32_available,
is_torch_tpu_available,
is_torch_xla_available,
is_torchaudio_available,
is_training_run_on_sagemaker,
is_vision_available,
4 changes: 2 additions & 2 deletions src/transformers/integrations/integration_utils.py
@@ -72,7 +72,7 @@
from ..trainer_callback import ProgressCallback, TrainerCallback # noqa: E402
from ..trainer_utils import PREFIX_CHECKPOINT_DIR, BestRun, IntervalStrategy # noqa: E402
from ..training_args import ParallelMode # noqa: E402
from ..utils import ENV_VARS_TRUE_VALUES, is_torch_tpu_available # noqa: E402
from ..utils import ENV_VARS_TRUE_VALUES, is_torch_xla_available # noqa: E402


# Integration functions:
@@ -752,7 +752,7 @@ def setup(self, args, state, model, **kwargs):

# keep track of model topology and gradients, unsupported on TPU
_watch_model = os.getenv("WANDB_WATCH", "false")
if not is_torch_tpu_available() and _watch_model in ("all", "parameters", "gradients"):
if not is_torch_xla_available() and _watch_model in ("all", "parameters", "gradients"):
self._wandb.watch(model, log=_watch_model, log_freq=max(100, state.logging_steps))
self._wandb.run._label(code="transformers_trainer")

4 changes: 2 additions & 2 deletions src/transformers/integrations/tpu.py
@@ -14,11 +14,11 @@

from torch.utils.data import DataLoader

from ..utils import is_torch_tpu_available
from ..utils import is_torch_xla_available


def tpu_spmd_dataloader(dataloader: DataLoader):
if is_torch_tpu_available():
if is_torch_xla_available():
import torch_xla.distributed.parallel_loader as pl

assert isinstance(
8 changes: 4 additions & 4 deletions src/transformers/modeling_utils.py
@@ -84,7 +84,7 @@
is_remote_url,
is_safetensors_available,
is_torch_sdpa_available,
is_torch_tpu_available,
is_torch_xla_available,
logging,
replace_return_docstrings,
strtobool,
@@ -246,10 +246,10 @@ def get_parameter_dtype(parameter: Union[nn.Module, GenerationMixin, "ModuleUtil
# Adding fix for https://github.com/pytorch/xla/issues/4152
# Fixes issue where the model code passes a value that is out of range for XLA_USE_BF16=1
# and XLA_DOWNCAST_BF16=1 so the conversion would cast it to -inf
# NOTE: `is_torch_tpu_available()` is checked last as it induces a graph break in torch dynamo
if XLA_USE_BF16 in ENV_VARS_TRUE_VALUES and is_torch_tpu_available():
# NOTE: `is_torch_xla_available()` is checked last as it induces a graph break in torch dynamo
if XLA_USE_BF16 in ENV_VARS_TRUE_VALUES and is_torch_xla_available():
return torch.bfloat16
if XLA_DOWNCAST_BF16 in ENV_VARS_TRUE_VALUES and is_torch_tpu_available():
if XLA_DOWNCAST_BF16 in ENV_VARS_TRUE_VALUES and is_torch_xla_available():
if t.dtype == torch.float:
return torch.bfloat16
if t.dtype == torch.double:
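For reference, the environment-variable rule in the hunk above can be restated as a standalone sketch. The `torch.double` branch is cut off in this view, so its mapping to `torch.float32` below is an assumption; the rest mirrors the visible lines.

```python
import os

import torch

ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"}


def xla_effective_dtype(dtype: torch.dtype) -> torch.dtype:
    """Sketch of the XLA_USE_BF16 / XLA_DOWNCAST_BF16 mapping; not the library's helper."""
    if os.environ.get("XLA_USE_BF16", "0").upper() in ENV_VARS_TRUE_VALUES:
        return torch.bfloat16      # everything is reported as bf16
    if os.environ.get("XLA_DOWNCAST_BF16", "0").upper() in ENV_VARS_TRUE_VALUES:
        if dtype == torch.float:
            return torch.bfloat16  # fp32 is downcast to bf16
        if dtype == torch.double:
            return torch.float32   # assumed: fp64 is downcast to fp32
    return dtype
```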
4 changes: 2 additions & 2 deletions src/transformers/pytorch_utils.py
@@ -19,7 +19,7 @@
from safetensors.torch import storage_ptr, storage_size
from torch import nn

from .utils import is_torch_tpu_available, logging
from .utils import is_torch_xla_available, logging


ALL_LAYERNORM_LAYERS = [nn.LayerNorm]
@@ -282,7 +282,7 @@ def id_tensor_storage(tensor: torch.Tensor) -> Tuple[torch.device, int, int]:
guaranteed to be unique and constant for this tensor's storage during its lifetime. Two tensor storages with
non-overlapping lifetimes may have the same id.
"""
if tensor.device.type == "xla" and is_torch_tpu_available():
if tensor.device.type == "xla" and is_torch_xla_available():
# NOTE: xla tensors dont have storage
# use some other unique id to distinguish.
# this is a XLA tensor, it must be created using torch_xla's
8 changes: 4 additions & 4 deletions src/transformers/testing_utils.py
@@ -114,7 +114,7 @@
is_torch_sdpa_available,
is_torch_tensorrt_fx_available,
is_torch_tf32_available,
is_torch_tpu_available,
is_torch_xla_available,
is_torch_xpu_available,
is_torchaudio_available,
is_torchdynamo_available,
@@ -725,11 +725,11 @@ def require_torch_up_to_2_accelerators(test_case):
(test_case)


def require_torch_tpu(test_case):
def require_torch_xla(test_case):
"""
Decorator marking a test that requires a TPU (in PyTorch).
Decorator marking a test that requires TorchXLA (in PyTorch).
"""
return unittest.skipUnless(is_torch_tpu_available(check_device=False), "test requires PyTorch TPU")(test_case)
return unittest.skipUnless(is_torch_xla_available(), "test requires TorchXLA")(test_case)


def require_torch_neuroncore(test_case):