From fccdd0d3a5d69b9aa49a59729b1d0bfe46b07435 Mon Sep 17 00:00:00 2001 From: anw90 Date: Fri, 10 Nov 2023 16:16:46 +0800 Subject: [PATCH 1/7] add USE_TORCH_XLA env --- src/transformers/utils/import_utils.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index db2278fc5f585c..fd5266f3f4a3b2 100644 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -62,6 +62,9 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[ USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper() USE_JAX = os.environ.get("USE_FLAX", "AUTO").upper() +# Try to run a native pytorch job in an environment with TorchXLA installed by setting this value to 0. +USE_TORCH_XLA = os.environ.get("USE_TORCH_XLA", "1").upper() + FORCE_TF_AVAILABLE = os.environ.get("FORCE_TF_AVAILABLE", "AUTO").upper() # `transformers` requires `torch>=1.11` but this variable is exposed publicly, and we can't simply remove it. @@ -484,8 +487,11 @@ def is_g2p_en_available(): @lru_cache() def is_torch_tpu_available(check_device=True): "Checks if `torch_xla` is installed and potentially if a TPU is in the environment" - if not _torch_available: + if not _torch_available or USE_TORCH_XLA not in ENV_VARS_TRUE_VALUES: return False + import torch_xla.core.xla_model as xm + device = xm.xla_device() + xm.set_replication(device, [device]) if importlib.util.find_spec("torch_xla") is not None: if check_device: # We need to check if `xla_device` can be found, will raise a RuntimeError if not From 5504241140f2cd52b59e6e2d169d847d579ee963 Mon Sep 17 00:00:00 2001 From: "wangang.wa" Date: Mon, 8 Jan 2024 21:31:31 +0800 Subject: [PATCH 2/7] rename torch_tpu to torch_xla --- docs/source/de/testing.md | 2 +- docs/source/en/testing.md | 2 +- docs/source/ja/testing.md | 2 +- docs/source/ko/testing.md | 2 +- examples/legacy/seq2seq/seq2seq_trainer.py | 4 +- examples/pytorch/language-modeling/run_clm.py | 6 +-- examples/pytorch/language-modeling/run_mlm.py | 6 +-- examples/pytorch/old_test_xla_examples.py | 4 +- .../pytorch/question-answering/trainer_qa.py | 4 +- .../question-answering/trainer_seq2seq_qa.py | 4 +- .../quantization-qdqbert/trainer_quant_qa.py | 4 +- src/transformers/__init__.py | 4 +- src/transformers/benchmark/benchmark_args.py | 8 ++-- src/transformers/file_utils.py | 2 +- .../integrations/integration_utils.py | 4 +- src/transformers/modeling_utils.py | 8 ++-- src/transformers/pytorch_utils.py | 4 +- src/transformers/testing_utils.py | 8 ++-- src/transformers/trainer.py | 34 ++++++++--------- src/transformers/trainer_pt_utils.py | 6 +-- src/transformers/trainer_utils.py | 6 +-- src/transformers/training_args.py | 16 ++++---- src/transformers/utils/__init__.py | 2 +- src/transformers/utils/import_utils.py | 38 ++++++++++++++++--- 24 files changed, 103 insertions(+), 77 deletions(-) diff --git a/docs/source/de/testing.md b/docs/source/de/testing.md index 25c1143e381de8..1d68c11c3ba07a 100644 --- a/docs/source/de/testing.md +++ b/docs/source/de/testing.md @@ -452,7 +452,7 @@ Dekorateure werden verwendet, um die Anforderungen von Tests in Bezug auf CPU/GP - `require_torch_multi_gpu` - wie `require_torch` und zusätzlich mindestens 2 GPUs erforderlich - `require_torch_non_multi_gpu` - wie `require_torch` plus benötigt 0 oder 1 GPUs - `require_torch_up_to_2_gpus` - wie `require_torch` plus erfordert 0 oder 1 oder 2 GPUs -- `require_torch_tpu` - wie `require_torch` plus erfordert mindestens 1 TPU +- `require_torch_xla` - wie `require_torch` plus erfordert mindestens 1 TPU Lassen Sie uns die GPU-Anforderungen in der folgenden Tabelle darstellen: diff --git a/docs/source/en/testing.md b/docs/source/en/testing.md index fda2fc0cb34352..aadb3667948df7 100644 --- a/docs/source/en/testing.md +++ b/docs/source/en/testing.md @@ -451,7 +451,7 @@ decorators are used to set the requirements of tests CPU/GPU/TPU-wise: - `require_torch_multi_gpu` - as `require_torch` plus requires at least 2 GPUs - `require_torch_non_multi_gpu` - as `require_torch` plus requires 0 or 1 GPUs - `require_torch_up_to_2_gpus` - as `require_torch` plus requires 0 or 1 or 2 GPUs -- `require_torch_tpu` - as `require_torch` plus requires at least 1 TPU +- `require_torch_xla` - as `require_torch` plus requires at least 1 TPU Let's depict the GPU requirements in the following table: diff --git a/docs/source/ja/testing.md b/docs/source/ja/testing.md index a7b357acd66e7e..00a51f13811b2f 100644 --- a/docs/source/ja/testing.md +++ b/docs/source/ja/testing.md @@ -424,7 +424,7 @@ CUDA_VISIBLE_DEVICES="1" pytest tests/utils/test_logging.py - `require_torch_multi_gpu` - `require_torch` に加えて、少なくとも2つのGPUが必要です。 - `require_torch_non_multi_gpu` - `require_torch` に加えて、0または1つのGPUが必要です。 - `require_torch_up_to_2_gpus` - `require_torch` に加えて、0、1、または2つのGPUが必要です。 -- `require_torch_tpu` - `require_torch` に加えて、少なくとも1つのTPUが必要です。 +- `require_torch_xla` - `require_torch` に加えて、少なくとも1つのTPUが必要です。 以下の表にGPUの要件を示します: diff --git a/docs/source/ko/testing.md b/docs/source/ko/testing.md index aad22c00feea4d..390a1c19baac6f 100644 --- a/docs/source/ko/testing.md +++ b/docs/source/ko/testing.md @@ -452,7 +452,7 @@ CUDA_VISIBLE_DEVICES="1" pytest tests/utils/test_logging.py - `require_torch_multi_gpu` - `require_torch`에 추가로 적어도 2개의 GPU가 필요합니다. - `require_torch_non_multi_gpu` - `require_torch`에 추가로 0개 또는 1개의 GPU가 필요합니다. - `require_torch_up_to_2_gpus` - `require_torch`에 추가로 0개, 1개 또는 2개의 GPU가 필요합니다. -- `require_torch_tpu` - `require_torch`에 추가로 적어도 1개의 TPU가 필요합니다. +- `require_torch_xla` - `require_torch`에 추가로 적어도 1개의 TPU가 필요합니다. GPU 요구 사항을 표로 정리하면 아래와 같습니디ㅏ: diff --git a/examples/legacy/seq2seq/seq2seq_trainer.py b/examples/legacy/seq2seq/seq2seq_trainer.py index bb219fd2bcb94d..0c981a201dd4b1 100644 --- a/examples/legacy/seq2seq/seq2seq_trainer.py +++ b/examples/legacy/seq2seq/seq2seq_trainer.py @@ -32,7 +32,7 @@ ) from transformers.trainer_pt_utils import get_tpu_sampler from transformers.training_args import ParallelMode -from transformers.utils import is_torch_tpu_available +from transformers.utils import is_torch_xla_available logger = logging.get_logger(__name__) @@ -135,7 +135,7 @@ def _get_lr_scheduler(self, num_training_steps): def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]: if isinstance(self.train_dataset, torch.utils.data.IterableDataset): return None - elif is_torch_tpu_available(): + elif is_torch_xla_available(): return get_tpu_sampler(self.train_dataset) else: if self.args.sortish_sampler: diff --git a/examples/pytorch/language-modeling/run_clm.py b/examples/pytorch/language-modeling/run_clm.py index a7ffb9c1f8d019..bdf83684c3e4e6 100755 --- a/examples/pytorch/language-modeling/run_clm.py +++ b/examples/pytorch/language-modeling/run_clm.py @@ -46,7 +46,7 @@ Trainer, TrainingArguments, default_data_collator, - is_torch_tpu_available, + is_torch_xla_available, set_seed, ) from transformers.testing_utils import CaptureLogger @@ -602,9 +602,9 @@ def compute_metrics(eval_preds): tokenizer=tokenizer, # Data collator will default to DataCollatorWithPadding, so we change it. data_collator=default_data_collator, - compute_metrics=compute_metrics if training_args.do_eval and not is_torch_tpu_available() else None, + compute_metrics=compute_metrics if training_args.do_eval and not is_torch_xla_available() else None, preprocess_logits_for_metrics=preprocess_logits_for_metrics - if training_args.do_eval and not is_torch_tpu_available() + if training_args.do_eval and not is_torch_xla_available() else None, ) diff --git a/examples/pytorch/language-modeling/run_mlm.py b/examples/pytorch/language-modeling/run_mlm.py index b2b8419ae44dc5..a86f51203b7415 100755 --- a/examples/pytorch/language-modeling/run_mlm.py +++ b/examples/pytorch/language-modeling/run_mlm.py @@ -45,7 +45,7 @@ HfArgumentParser, Trainer, TrainingArguments, - is_torch_tpu_available, + is_torch_xla_available, set_seed, ) from transformers.trainer_utils import get_last_checkpoint @@ -620,9 +620,9 @@ def compute_metrics(eval_preds): eval_dataset=eval_dataset if training_args.do_eval else None, tokenizer=tokenizer, data_collator=data_collator, - compute_metrics=compute_metrics if training_args.do_eval and not is_torch_tpu_available() else None, + compute_metrics=compute_metrics if training_args.do_eval and not is_torch_xla_available() else None, preprocess_logits_for_metrics=preprocess_logits_for_metrics - if training_args.do_eval and not is_torch_tpu_available() + if training_args.do_eval and not is_torch_xla_available() else None, ) diff --git a/examples/pytorch/old_test_xla_examples.py b/examples/pytorch/old_test_xla_examples.py index 2f24035d72377b..c13d8b3115130c 100644 --- a/examples/pytorch/old_test_xla_examples.py +++ b/examples/pytorch/old_test_xla_examples.py @@ -21,7 +21,7 @@ from time import time from unittest.mock import patch -from transformers.testing_utils import TestCasePlus, require_torch_tpu +from transformers.testing_utils import TestCasePlus, require_torch_xla logging.basicConfig(level=logging.DEBUG) @@ -44,7 +44,7 @@ def get_results(output_dir): logger.addHandler(stream_handler) -@require_torch_tpu +@require_torch_xla class TorchXLAExamplesTests(TestCasePlus): def test_run_glue(self): import xla_spawn diff --git a/examples/pytorch/question-answering/trainer_qa.py b/examples/pytorch/question-answering/trainer_qa.py index a486405b62877e..0e82e6b8163644 100644 --- a/examples/pytorch/question-answering/trainer_qa.py +++ b/examples/pytorch/question-answering/trainer_qa.py @@ -18,11 +18,11 @@ import math import time -from transformers import Trainer, is_torch_tpu_available +from transformers import Trainer, is_torch_xla_available from transformers.trainer_utils import PredictionOutput, speed_metrics -if is_torch_tpu_available(check_device=False): +if is_torch_xla_available(): import torch_xla.core.xla_model as xm import torch_xla.debug.metrics as met diff --git a/examples/pytorch/question-answering/trainer_seq2seq_qa.py b/examples/pytorch/question-answering/trainer_seq2seq_qa.py index bdf82bda9f3678..dea184e9085b70 100644 --- a/examples/pytorch/question-answering/trainer_seq2seq_qa.py +++ b/examples/pytorch/question-answering/trainer_seq2seq_qa.py @@ -21,11 +21,11 @@ from torch.utils.data import Dataset -from transformers import Seq2SeqTrainer, is_torch_tpu_available +from transformers import Seq2SeqTrainer, is_torch_xla_available from transformers.trainer_utils import PredictionOutput, speed_metrics -if is_torch_tpu_available(check_device=False): +if is_torch_xla_available(): import torch_xla.core.xla_model as xm import torch_xla.debug.metrics as met diff --git a/examples/research_projects/quantization-qdqbert/trainer_quant_qa.py b/examples/research_projects/quantization-qdqbert/trainer_quant_qa.py index 9b8c53b272b11b..a56d875354ddb0 100644 --- a/examples/research_projects/quantization-qdqbert/trainer_quant_qa.py +++ b/examples/research_projects/quantization-qdqbert/trainer_quant_qa.py @@ -24,13 +24,13 @@ import torch from torch.utils.data import DataLoader -from transformers import Trainer, is_torch_tpu_available +from transformers import Trainer, is_torch_xla_available from transformers.trainer_utils import PredictionOutput logger = logging.getLogger(__name__) -if is_torch_tpu_available(check_device=False): +if is_torch_xla_available(): import torch_xla.core.xla_model as xm import torch_xla.debug.metrics as met diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index da650cc58ff99b..7e7d4fd2874094 100644 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -1090,8 +1090,8 @@ "is_torch_available", "is_torch_neuroncore_available", "is_torch_npu_available", - "is_torch_tpu_available", "is_torchvision_available", + "is_torch_xla_available", "is_torch_xpu_available", "is_vision_available", "logging", @@ -5894,7 +5894,7 @@ is_torch_available, is_torch_neuroncore_available, is_torch_npu_available, - is_torch_tpu_available, + is_torch_xla_available, is_torch_xpu_available, is_torchvision_available, is_vision_available, diff --git a/src/transformers/benchmark/benchmark_args.py b/src/transformers/benchmark/benchmark_args.py index c20683e416843b..396207300b84f1 100644 --- a/src/transformers/benchmark/benchmark_args.py +++ b/src/transformers/benchmark/benchmark_args.py @@ -20,7 +20,7 @@ from ..utils import ( cached_property, is_torch_available, - is_torch_tpu_available, + is_torch_xla_available, is_torch_xpu_available, logging, requires_backends, @@ -31,7 +31,7 @@ if is_torch_available(): import torch -if is_torch_tpu_available(check_device=False): +if is_torch_xla_available(): import torch_xla.core.xla_model as xm @@ -88,7 +88,7 @@ def _setup_devices(self) -> Tuple["torch.device", int]: if not self.cuda: device = torch.device("cpu") n_gpu = 0 - elif is_torch_tpu_available(): + elif is_torch_xla_available(): device = xm.xla_device() n_gpu = 0 elif is_torch_xpu_available(): @@ -101,7 +101,7 @@ def _setup_devices(self) -> Tuple["torch.device", int]: @property def is_tpu(self): - return is_torch_tpu_available() and self.tpu + return is_torch_xla_available() and self.tpu @property def device_idx(self) -> int: diff --git a/src/transformers/file_utils.py b/src/transformers/file_utils.py index 7596e4cd231f0c..2d9477727ea4e1 100644 --- a/src/transformers/file_utils.py +++ b/src/transformers/file_utils.py @@ -121,7 +121,7 @@ is_torch_fx_proxy, is_torch_mps_available, is_torch_tf32_available, - is_torch_tpu_available, + is_torch_xla_available, is_torchaudio_available, is_training_run_on_sagemaker, is_vision_available, diff --git a/src/transformers/integrations/integration_utils.py b/src/transformers/integrations/integration_utils.py index 65642039da7395..9da0607e428d0d 100644 --- a/src/transformers/integrations/integration_utils.py +++ b/src/transformers/integrations/integration_utils.py @@ -72,7 +72,7 @@ from ..trainer_callback import ProgressCallback, TrainerCallback # noqa: E402 from ..trainer_utils import PREFIX_CHECKPOINT_DIR, BestRun, IntervalStrategy # noqa: E402 from ..training_args import ParallelMode # noqa: E402 -from ..utils import ENV_VARS_TRUE_VALUES, is_torch_tpu_available # noqa: E402 +from ..utils import ENV_VARS_TRUE_VALUES, is_torch_xla_available # noqa: E402 # Integration functions: @@ -752,7 +752,7 @@ def setup(self, args, state, model, **kwargs): # keep track of model topology and gradients, unsupported on TPU _watch_model = os.getenv("WANDB_WATCH", "false") - if not is_torch_tpu_available() and _watch_model in ("all", "parameters", "gradients"): + if not is_torch_xla_available() and _watch_model in ("all", "parameters", "gradients"): self._wandb.watch(model, log=_watch_model, log_freq=max(100, state.logging_steps)) self._wandb.run._label(code="transformers_trainer") diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 505c9cb45950cb..ca5ee26e279800 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -84,7 +84,7 @@ is_remote_url, is_safetensors_available, is_torch_sdpa_available, - is_torch_tpu_available, + is_torch_xla_available, logging, replace_return_docstrings, strtobool, @@ -246,10 +246,10 @@ def get_parameter_dtype(parameter: Union[nn.Module, GenerationMixin, "ModuleUtil # Adding fix for https://github.com/pytorch/xla/issues/4152 # Fixes issue where the model code passes a value that is out of range for XLA_USE_BF16=1 # and XLA_DOWNCAST_BF16=1 so the conversion would cast it to -inf - # NOTE: `is_torch_tpu_available()` is checked last as it induces a graph break in torch dynamo - if XLA_USE_BF16 in ENV_VARS_TRUE_VALUES and is_torch_tpu_available(): + # NOTE: `is_torch_xla_available()` is checked last as it induces a graph break in torch dynamo + if XLA_USE_BF16 in ENV_VARS_TRUE_VALUES and is_torch_xla_available(): return torch.bfloat16 - if XLA_DOWNCAST_BF16 in ENV_VARS_TRUE_VALUES and is_torch_tpu_available(): + if XLA_DOWNCAST_BF16 in ENV_VARS_TRUE_VALUES and is_torch_xla_available(): if t.dtype == torch.float: return torch.bfloat16 if t.dtype == torch.double: diff --git a/src/transformers/pytorch_utils.py b/src/transformers/pytorch_utils.py index 993da84d33f837..cab0b0d4aec72b 100644 --- a/src/transformers/pytorch_utils.py +++ b/src/transformers/pytorch_utils.py @@ -19,7 +19,7 @@ from safetensors.torch import storage_ptr, storage_size from torch import nn -from .utils import is_torch_tpu_available, logging +from .utils import is_torch_xla_available, logging ALL_LAYERNORM_LAYERS = [nn.LayerNorm] @@ -282,7 +282,7 @@ def id_tensor_storage(tensor: torch.Tensor) -> Tuple[torch.device, int, int]: guaranteed to be unique and constant for this tensor's storage during its lifetime. Two tensor storages with non-overlapping lifetimes may have the same id. """ - if tensor.device.type == "xla" and is_torch_tpu_available(): + if tensor.device.type == "xla" and is_torch_xla_available(): # NOTE: xla tensors dont have storage # use some other unique id to distinguish. # this is a XLA tensor, it must be created using torch_xla's diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index adcadfc379251e..0e08f054dcf933 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -114,7 +114,7 @@ is_torch_sdpa_available, is_torch_tensorrt_fx_available, is_torch_tf32_available, - is_torch_tpu_available, + is_torch_xla_available, is_torch_xpu_available, is_torchaudio_available, is_torchdynamo_available, @@ -725,11 +725,11 @@ def require_torch_up_to_2_accelerators(test_case): (test_case) -def require_torch_tpu(test_case): +def require_torch_xla(test_case): """ - Decorator marking a test that requires a TPU (in PyTorch). + Decorator marking a test that requires TorchXLA (in PyTorch). """ - return unittest.skipUnless(is_torch_tpu_available(check_device=False), "test requires PyTorch TPU")(test_case) + return unittest.skipUnless(is_torch_xla_available(), "test requires TorchXLA")(test_case) def require_torch_neuroncore(test_case): diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 574363421234b3..cd547b44463f20 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -149,7 +149,7 @@ is_torch_compile_available, is_torch_neuroncore_available, is_torch_npu_available, - is_torch_tpu_available, + is_torch_xla_available, logging, strtobool, ) @@ -170,7 +170,7 @@ if is_datasets_available(): import datasets -if is_torch_tpu_available(check_device=False): +if is_torch_xla_available(): import torch_xla.core.xla_model as xm import torch_xla.debug.metrics as met import torch_xla.distributed.spmd as xs @@ -508,7 +508,7 @@ def __init__( "Passing a `model_init` is incompatible with providing the `optimizers` argument. " "You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method." ) - if is_torch_tpu_available() and self.optimizer is not None: + if is_torch_xla_available() and self.optimizer is not None: for param in self.model.parameters(): model_device = param.device break @@ -856,7 +856,7 @@ def get_train_dataloader(self) -> DataLoader: def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.Sampler]: # Deprecated code if self.args.use_legacy_prediction_loop: - if is_torch_tpu_available(): + if is_torch_xla_available(): return SequentialDistributedSampler( eval_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal() ) @@ -1964,7 +1964,7 @@ def _inner_training_loop( if ( args.logging_nan_inf_filter - and not is_torch_tpu_available() + and not is_torch_xla_available() and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step)) ): # if loss is nan or inf simply add the average of previous logged losses @@ -2054,7 +2054,7 @@ def _inner_training_loop( self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval) if DebugOption.TPU_METRICS_DEBUG in self.args.debug: - if is_torch_tpu_available(): + if is_torch_xla_available(): # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.) xm.master_print(met.metrics_report()) else: @@ -2072,7 +2072,7 @@ def _inner_training_loop( logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n") if args.load_best_model_at_end and self.state.best_model_checkpoint is not None: # Wait for everyone to get here so we are sure the model has been saved by process 0. - if is_torch_tpu_available(): + if is_torch_xla_available(): xm.rendezvous("load_best_model_at_end") elif args.parallel_mode == ParallelMode.DISTRIBUTED: dist.barrier() @@ -2391,7 +2391,7 @@ def _issue_warnings_after_load(self, load_result): def _maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval): if self.control.should_log and self.state.global_step > self._globalstep_last_logged: - if is_torch_tpu_available(): + if is_torch_xla_available(): xm.mark_step() logs: Dict[str, float] = {} @@ -2467,7 +2467,7 @@ def _load_rng_state(self, checkpoint): f"Didn't manage to set back the RNG states of the GPU because of the following error:\n {e}" "\nThis won't yield the same results as if the training had not been interrupted." ) - if is_torch_tpu_available(): + if is_torch_xla_available(): xm.set_rng_state(checkpoint_rng_state["xla"]) if is_torch_npu_available(): if self.args.parallel_mode == ParallelMode.DISTRIBUTED: @@ -2545,7 +2545,7 @@ def _save_rng_state(self, output_dir): else: rng_states["cuda"] = torch.cuda.random.get_rng_state() - if is_torch_tpu_available(): + if is_torch_xla_available(): rng_states["xla"] = xm.get_rng_state() if is_torch_npu_available(): @@ -2564,7 +2564,7 @@ def _save_rng_state(self, output_dir): torch.save(rng_states, os.path.join(output_dir, f"rng_state_{self.args.process_index}.pth")) def _save_optimizer_and_scheduler(self, output_dir): - if is_torch_tpu_available(): + if is_torch_xla_available(): xm.rendezvous("saving_optimizer_states") xm.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME)) with warnings.catch_warnings(record=True) as caught_warnings: @@ -2609,7 +2609,7 @@ def _save_optimizer_and_scheduler(self, output_dir): if ( self.args.should_save and (not self.is_deepspeed_enabled or is_deepspeed_custom_scheduler) - and not is_torch_tpu_available() + and not is_torch_xla_available() ): with warnings.catch_warnings(record=True) as caught_warnings: torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME)) @@ -2646,7 +2646,7 @@ def _load_optimizer_and_scheduler(self, checkpoint): ) if checkpoint_file_exists and os.path.isfile(os.path.join(checkpoint, SCHEDULER_NAME)): # Load in optimizer and scheduler states - if is_torch_tpu_available(): + if is_torch_xla_available(): # On TPU we have to take some extra precautions to properly load the states on the right device. optimizer_state = torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location="cpu") with warnings.catch_warnings(record=True) as caught_warnings: @@ -2953,7 +2953,7 @@ def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = Fa if output_dir is None: output_dir = self.args.output_dir - if is_torch_tpu_available(): + if is_torch_xla_available(): self._save_tpu(output_dir) elif is_sagemaker_mp_enabled(): # Calling the state_dict needs to be done on the wrapped model and on all processes. @@ -3394,7 +3394,7 @@ def evaluation_loop( main_input_name = getattr(self.model, "main_input_name", "input_ids") inputs_decode = self._prepare_input(inputs[main_input_name]) if args.include_inputs_for_metrics else None - if is_torch_tpu_available(): + if is_torch_xla_available(): xm.mark_step() # Update containers on host @@ -3518,7 +3518,7 @@ def _nested_gather(self, tensors, name=None): """ if tensors is None: return - if is_torch_tpu_available(): + if is_torch_xla_available(): if name is None: name = "nested_gather" tensors = nested_xla_mesh_reduce(tensors, name) @@ -4034,7 +4034,7 @@ def _gather_and_numpify(self, tensors, name): """ if tensors is None: return - if is_torch_tpu_available(): + if is_torch_xla_available(): tensors = nested_xla_mesh_reduce(tensors, name) elif is_sagemaker_mp_enabled(): tensors = smp_gather(tensors) diff --git a/src/transformers/trainer_pt_utils.py b/src/transformers/trainer_pt_utils.py index dce0eeaf818604..34d2c8416b5936 100644 --- a/src/transformers/trainer_pt_utils.py +++ b/src/transformers/trainer_pt_utils.py @@ -39,13 +39,13 @@ from .integrations.deepspeed import is_deepspeed_zero3_enabled from .tokenization_utils_base import BatchEncoding -from .utils import is_sagemaker_mp_enabled, is_torch_tpu_available, is_training_run_on_sagemaker, logging +from .utils import is_sagemaker_mp_enabled, is_torch_xla_available, is_training_run_on_sagemaker, logging if is_training_run_on_sagemaker(): logging.add_handler(StreamHandler(sys.stdout)) -if is_torch_tpu_available(check_device=False): +if is_torch_xla_available(): import torch_xla.core.xla_model as xm # this is used to suppress an undesired warning emitted by pytorch versions 1.4.2-1.7.0 @@ -179,7 +179,7 @@ def nested_detach(tensors): def nested_xla_mesh_reduce(tensors, name): - if is_torch_tpu_available(): + if is_torch_xla_available(): import torch_xla.core.xla_model as xm if isinstance(tensors, (list, tuple)): diff --git a/src/transformers/trainer_utils.py b/src/transformers/trainer_utils.py index 803f6fe840e7d0..5d528317e54fe0 100644 --- a/src/transformers/trainer_utils.py +++ b/src/transformers/trainer_utils.py @@ -37,7 +37,7 @@ is_torch_cuda_available, is_torch_mps_available, is_torch_npu_available, - is_torch_tpu_available, + is_torch_xla_available, is_torch_xpu_available, requires_backends, ) @@ -340,7 +340,7 @@ def is_main_process(local_rank): Whether or not the current process is the local process, based on `xm.get_ordinal()` (for TPUs) first, then on `local_rank`. """ - if is_torch_tpu_available(check_device=True): + if is_torch_xla_available(): import torch_xla.core.xla_model as xm return xm.get_ordinal() == 0 @@ -351,7 +351,7 @@ def total_processes_number(local_rank): """ Return the number of processes launched in parallel. Works with `torch.distributed` and TPUs. """ - if is_torch_tpu_available(check_device=True): + if is_torch_xla_available(): import torch_xla.core.xla_model as xm return xm.xrt_world_size() diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index 5baa3e1b51f366..70ef0cab18754d 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -49,7 +49,7 @@ is_torch_neuroncore_available, is_torch_npu_available, is_torch_tf32_available, - is_torch_tpu_available, + is_torch_xla_available, is_torch_xpu_available, logging, requires_backends, @@ -74,7 +74,7 @@ from .trainer_pt_utils import AcceleratorConfig -if is_torch_tpu_available(check_device=False): +if is_torch_xla_available(): import torch_xla.core.xla_model as xm if is_torch_neuroncore_available(check_device=False): @@ -130,7 +130,7 @@ def get_xla_device_type(device: "torch.device") -> Optional[str]: """ Returns the xla device type (CPU|GPU|TPU) or None if the device is a non-xla device. """ - if is_torch_tpu_available(): + if is_torch_xla_available(): return xm.xla_real_devices([device])[0].split(":")[0] return None @@ -1475,7 +1475,7 @@ def __post_init__(self): self.half_precision_backend = self.fp16_backend if self.bf16 or self.bf16_full_eval: - if self.use_cpu and not is_torch_bf16_cpu_available() and not is_torch_tpu_available(): + if self.use_cpu and not is_torch_bf16_cpu_available() and not is_torch_xla_available(): # cpu raise ValueError("Your setup doesn't support bf16/(cpu, tpu, neuroncore). You need torch>=1.10") elif not self.use_cpu: @@ -1948,7 +1948,7 @@ def _setup_devices(self) -> "torch.device": "torch.distributed process group is initialized, but parallel_mode != ParallelMode.DISTRIBUTED. " "In order to use Torch DDP, launch your script with `python -m torch.distributed.launch" ) - if is_torch_tpu_available(): + if is_torch_xla_available(): device = self.distributed_state.device self._n_gpu = 0 elif is_sagemaker_dp_enabled() or is_sagemaker_mp_enabled(): @@ -2029,7 +2029,7 @@ def parallel_mode(self): - `ParallelMode.TPU`: several TPU cores. """ requires_backends(self, ["torch"]) - if is_torch_tpu_available(): + if is_torch_xla_available(): return ParallelMode.TPU elif is_sagemaker_mp_enabled(): return ParallelMode.SAGEMAKER_MODEL_PARALLEL @@ -2180,7 +2180,7 @@ def main_process_first(self, local=True, desc="work"): # tell all replicas to wait logger.debug(f"{self.process_index}: waiting for the {main_process_desc} to perform {desc}") - if is_torch_tpu_available(): + if is_torch_xla_available(): xm.rendezvous(desc) else: dist.barrier() @@ -2189,7 +2189,7 @@ def main_process_first(self, local=True, desc="work"): if is_main_process: # the wait is over logger.debug(f"{self.process_index}: {main_process_desc} completed {desc}, releasing all replicas") - if is_torch_tpu_available(): + if is_torch_xla_available(): xm.rendezvous(desc) else: dist.barrier() diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py index 2fe931b3f38faf..333c3463aa2e6a 100644 --- a/src/transformers/utils/__init__.py +++ b/src/transformers/utils/__init__.py @@ -188,7 +188,7 @@ is_torch_sdpa_available, is_torch_tensorrt_fx_available, is_torch_tf32_available, - is_torch_tpu_available, + is_torch_xla_available, is_torch_xpu_available, is_torchaudio_available, is_torchdistx_available, diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index fd5266f3f4a3b2..db02ce57f1756b 100644 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -487,11 +487,14 @@ def is_g2p_en_available(): @lru_cache() def is_torch_tpu_available(check_device=True): "Checks if `torch_xla` is installed and potentially if a TPU is in the environment" - if not _torch_available or USE_TORCH_XLA not in ENV_VARS_TRUE_VALUES: + warnings.warn( + "`is_torch_tpu_available` is deprecated and will be removed in 4.39.0. " + "Please use the `is_torch_xla_available` instead.", + FutureWarning, + ) + + if not _torch_available: return False - import torch_xla.core.xla_model as xm - device = xm.xla_device() - xm.set_replication(device, [device]) if importlib.util.find_spec("torch_xla") is not None: if check_device: # We need to check if `xla_device` can be found, will raise a RuntimeError if not @@ -503,13 +506,36 @@ def is_torch_tpu_available(check_device=True): except RuntimeError: return False return True - return False + + +@lru_cache +def is_torch_xla_available(check_is_tpu=False, check_is_gpu=False): + """ + Check if `torch_xla` is available. To train a native pytorch job in an environment with torch xla installed, set + the USE_TORCH_XLA to false. + """ + assert not (check_is_tpu and check_is_gpu), "The check_is_tpu and check_is_gpu cannot both be true." + + try: + import torch_xla.core.xla_model as xm + + xla_device = xm.xla_device() + hardware_type = xm.xla_device_hw(xla_device) + return any( + [ + check_is_tpu and hardware_type == "TPU", + check_is_gpu and hardware_type == "GPU", + not (check_is_tpu or check_is_gpu), + ] + ) + except (ImportError, RuntimeError): + return False @lru_cache() def is_torch_neuroncore_available(check_device=True): if importlib.util.find_spec("torch_neuronx") is not None: - return is_torch_tpu_available(check_device) + return is_torch_xla_available(check_device) return False From 82d0d769c00ed97757c2830abc64a2d7c2d6e587 Mon Sep 17 00:00:00 2001 From: Yitong Huang Date: Tue, 27 Feb 2024 14:35:21 +0800 Subject: [PATCH 3/7] better is_torch_xla_available; fix some fsdp and performance issues --- src/transformers/integrations/tpu.py | 4 ++-- src/transformers/trainer.py | 6 +++--- src/transformers/training_args.py | 7 ++++--- src/transformers/utils/import_utils.py | 29 ++++++++++++++------------ 4 files changed, 25 insertions(+), 21 deletions(-) diff --git a/src/transformers/integrations/tpu.py b/src/transformers/integrations/tpu.py index f2943dcf12df3e..29262789dc9855 100644 --- a/src/transformers/integrations/tpu.py +++ b/src/transformers/integrations/tpu.py @@ -14,11 +14,11 @@ from torch.utils.data import DataLoader -from ..utils import is_torch_tpu_available +from ..utils import is_torch_xla_available def tpu_spmd_dataloader(dataloader: DataLoader): - if is_torch_tpu_available(): + if is_torch_xla_available(): import torch_xla.distributed.parallel_loader as pl assert isinstance( diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index cd547b44463f20..c5a5b9ad1ac43d 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -2016,7 +2016,7 @@ def _inner_training_loop( if hasattr(grad_norm, "item"): grad_norm = grad_norm.item() else: - grad_norm = _grad_norm.item() if _grad_norm is not None else None + grad_norm = _grad_norm # Optimizer step self.optimizer.step() @@ -2039,7 +2039,7 @@ def _inner_training_loop( # PyTorch/XLA relies on the data loader to insert the mark_step for # each step. Since we are breaking the loop early, we need to manually # insert the mark_step here. - if is_torch_tpu_available(): + if is_torch_xla_available(): xm.mark_step() break if step < 0: @@ -2404,7 +2404,7 @@ def _maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, igno logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4) if grad_norm is not None: - logs["grad_norm"] = grad_norm + logs["grad_norm"] = grad_norm.item() if torch.is_tensor(grad_norm) else grad_norm logs["learning_rate"] = self._get_learning_rate() self._total_loss_scalar += tr_loss_scalar diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index 70ef0cab18754d..8deadaa6b59279 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -1530,7 +1530,7 @@ def __post_init__(self): and (self.device.type != "cuda") and (self.device.type != "npu") and (self.device.type != "xpu") - and (get_xla_device_type(self.device) != "GPU") + and (get_xla_device_type(self.device) not in ["GPU", "CUDA"]) and (self.fp16 or self.fp16_full_eval) ): raise ValueError( @@ -1544,7 +1544,7 @@ def __post_init__(self): and (self.device.type != "cuda") and (self.device.type != "npu") and (self.device.type != "xpu") - and (get_xla_device_type(self.device) != "GPU") + and (get_xla_device_type(self.device) not in ["GPU", "CUDA"]) and (get_xla_device_type(self.device) != "TPU") and (self.device.type != "cpu") and (self.bf16 or self.bf16_full_eval) @@ -1694,7 +1694,8 @@ def __post_init__(self): if self.fsdp_config["xla"]: if len(self.fsdp) > 0: # store XLA fsdp configuration parameters into a dictionary - self.xla_fsdp_config = self.fsdp_config.get("xla_fsdp_settings", {}) + # Copy the config to avoid modifying the original config (which may be used for JSON serialization) + self.xla_fsdp_config = self.fsdp_config.get("xla_fsdp_settings", {}).copy() # apply appropriate string to torch.dtype conversions for parameters if "compute_dtype" in self.xla_fsdp_config: self.xla_fsdp_config["compute_dtype"] = getattr(torch, self.xla_fsdp_config["compute_dtype"]) diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index db02ce57f1756b..6436f7d8548c36 100644 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -252,6 +252,13 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[ ) +_torch_xla_available = False +if USE_TORCH_XLA in ENV_VARS_TRUE_VALUES: + _torch_xla_available, _torch_xla_version = _is_package_available("torch_xla", return_version=True) + if _torch_xla_available: + logger.info(f"Torch XLA version {_torch_xla_version} available.") + + def is_kenlm_available(): return _kenlm_available @@ -516,21 +523,17 @@ def is_torch_xla_available(check_is_tpu=False, check_is_gpu=False): """ assert not (check_is_tpu and check_is_gpu), "The check_is_tpu and check_is_gpu cannot both be true." - try: - import torch_xla.core.xla_model as xm - - xla_device = xm.xla_device() - hardware_type = xm.xla_device_hw(xla_device) - return any( - [ - check_is_tpu and hardware_type == "TPU", - check_is_gpu and hardware_type == "GPU", - not (check_is_tpu or check_is_gpu), - ] - ) - except (ImportError, RuntimeError): + if not _torch_xla_available: return False + import torch_xla + if check_is_gpu: + return torch_xla.runtime.device_type() in ["GPU", "CUDA"] + elif check_is_tpu: + return torch_xla.runtime.device_type() == "TPU" + + return True + @lru_cache() def is_torch_neuroncore_available(check_device=True): From ce73774097f01f6e12b9de7b0d8747f777cfe564 Mon Sep 17 00:00:00 2001 From: Yitong Huang Date: Tue, 27 Feb 2024 14:44:09 +0800 Subject: [PATCH 4/7] fix format --- src/transformers/utils/import_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index 6436f7d8548c36..59ab371f3ee339 100644 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -527,6 +527,7 @@ def is_torch_xla_available(check_is_tpu=False, check_is_gpu=False): return False import torch_xla + if check_is_gpu: return torch_xla.runtime.device_type() in ["GPU", "CUDA"] elif check_is_tpu: From 0d218e237cd5117350c4f2ac69e0625097027c5b Mon Sep 17 00:00:00 2001 From: Yitong Huang Date: Tue, 27 Feb 2024 19:36:08 +0800 Subject: [PATCH 5/7] fix bug when pjrt_device is cpu --- src/transformers/training_args.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index 8deadaa6b59279..e73a60dabb470f 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -131,6 +131,8 @@ def get_xla_device_type(device: "torch.device") -> Optional[str]: Returns the xla device type (CPU|GPU|TPU) or None if the device is a non-xla device. """ if is_torch_xla_available(): + if device.type == "cpu": + return "CPU" return xm.xla_real_devices([device])[0].split(":")[0] return None From f761c88548af5ebfacd3290e7532ad20b25238a7 Mon Sep 17 00:00:00 2001 From: Yitong Huang Date: Wed, 28 Feb 2024 10:43:39 +0800 Subject: [PATCH 6/7] fix bug --- src/transformers/utils/import_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index 59ab371f3ee339..e9261b01cb3b57 100644 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -539,7 +539,7 @@ def is_torch_xla_available(check_is_tpu=False, check_is_gpu=False): @lru_cache() def is_torch_neuroncore_available(check_device=True): if importlib.util.find_spec("torch_neuronx") is not None: - return is_torch_xla_available(check_device) + return is_torch_xla_available() return False From 1d6223e1d29bc13cf6d1e8cb81e17d8a05247203 Mon Sep 17 00:00:00 2001 From: Yitong Huang Date: Mon, 11 Mar 2024 11:02:57 +0800 Subject: [PATCH 7/7] fix the deprecation handling --- src/transformers/__init__.py | 2 ++ src/transformers/utils/__init__.py | 1 + src/transformers/utils/import_utils.py | 3 ++- 3 files changed, 5 insertions(+), 1 deletion(-) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 7e7d4fd2874094..606d44bbaffafd 100644 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -1090,6 +1090,7 @@ "is_torch_available", "is_torch_neuroncore_available", "is_torch_npu_available", + "is_torch_tpu_available", "is_torchvision_available", "is_torch_xla_available", "is_torch_xpu_available", @@ -5894,6 +5895,7 @@ is_torch_available, is_torch_neuroncore_available, is_torch_npu_available, + is_torch_tpu_available, is_torch_xla_available, is_torch_xpu_available, is_torchvision_available, diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py index 333c3463aa2e6a..cb54c09819f4d7 100644 --- a/src/transformers/utils/__init__.py +++ b/src/transformers/utils/__init__.py @@ -188,6 +188,7 @@ is_torch_sdpa_available, is_torch_tensorrt_fx_available, is_torch_tf32_available, + is_torch_tpu_available, is_torch_xla_available, is_torch_xpu_available, is_torchaudio_available, diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index e9261b01cb3b57..ceed806e465cb7 100644 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -495,7 +495,7 @@ def is_g2p_en_available(): def is_torch_tpu_available(check_device=True): "Checks if `torch_xla` is installed and potentially if a TPU is in the environment" warnings.warn( - "`is_torch_tpu_available` is deprecated and will be removed in 4.39.0. " + "`is_torch_tpu_available` is deprecated and will be removed in 4.41.0. " "Please use the `is_torch_xla_available` instead.", FutureWarning, ) @@ -513,6 +513,7 @@ def is_torch_tpu_available(check_device=True): except RuntimeError: return False return True + return False @lru_cache