
Make torch xla available on GPU #29334

Merged: 7 commits merged into huggingface:main on Mar 11, 2024

Conversation

@yitongh (Contributor) commented Feb 28, 2024

What does this PR do?

Make torch xla available on GPU. torch xla can already run in a GPU environment, but when it is installed there are conflicts between XLA and native PyTorch CUDA. This PR introduces the environment variable USE_TORCH_XLA to address this: when USE_TORCH_XLA is set to false, native PyTorch CUDA can be used seamlessly even if torch xla is installed.
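
For example, a job that should run on native PyTorch CUDA in an environment where torch_xla is installed would flip the switch before importing transformers. A minimal usage sketch (the exact set of values parsed as "false" is an assumption):

```python
import os

# Disable torch_xla before importing transformers, so that availability
# checks fall back to native PyTorch CUDA even though torch_xla is installed.
os.environ["USE_TORCH_XLA"] = "0"

import transformers  # noqa: E402  (imported after setting the switch on purpose)
```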

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline, Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@muellerzr and @pacman100

@yitongh (Contributor, Author) commented Feb 28, 2024

The main changes:

  1. Change is_torch_tpu_available to is_torch_xla_available
  2. Change require_torch_tpu to require_torch_xla
  3. Add USE_TORCH_XLA to enable or disable torch_xla (see the sketch after this list)
  4. Fix amp check
  5. Move grad_norm.item() into _maybe_log_save_evaluate to prevent a performance degradation in XLA
  6. Copy the xla_fsdp_config to avoid modifying the original config
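
A minimal sketch of how changes 1 and 3 fit together (illustrative only; the env-var parsing and any extra parameters on the merged function may differ):

```python
import importlib.util
import os
from functools import lru_cache

# Read the switch once at import time; torch_xla stays enabled by default.
_USE_TORCH_XLA = os.environ.get("USE_TORCH_XLA", "1").upper() in {"1", "ON", "YES", "TRUE"}

@lru_cache()
def is_torch_xla_available():
    # Unlike the old is_torch_tpu_available(check_device=True), this does not
    # try to construct an XLA device; it only checks that torch_xla is
    # importable and that USE_TORCH_XLA has not disabled it.
    return _USE_TORCH_XLA and importlib.util.find_spec("torch_xla") is not None
```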

@yitongh (Contributor, Author) commented Feb 28, 2024

This PR is related to huggingface/accelerate#2176 and huggingface/accelerate#2467.

@will-cromar Could you please check if this PR has any impact on the TPU environment? Thanks.

@yitongh (Contributor, Author) commented Feb 28, 2024

The CI only failed on tests/test_modeling_utils.py::ModelUtilsTest::test_use_safetensors. I ran this test on master and it also hangs there. From pystack, the issue looks related to an SSL read. Stack log: https://gist.github.com/yitongh/34dc9c9f3de79d208533964bd63bb6f5

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@yitongh (Contributor, Author) commented Mar 5, 2024

@muellerzr, would you be available to take a look at this PR when you have a moment? If not, perhaps you could suggest someone else who might be suited to review it? Many thanks.

@ArthurZucker ArthurZucker requested a review from muellerzr March 7, 2024 11:34
@muellerzr (Contributor) left a comment

Overall this is fine by me; we have very similar logic in Accelerate, if I'm not mistaken. Thanks!

@muellerzr (Contributor)

Let's make sure we can fix those failing tests though; can you try rebasing from main?

@yitongh (Contributor, Author) commented Mar 8, 2024

@muellerzr I have rebased from main. I reran the failing tests on my machine: both main and this PR pass test_run_ner_no_trainer and test_run_squad_no_trainer but fail test_run_glue_no_trainer, so the failure doesn't look related to this PR.

pytest -s -v examples/pytorch/test_accelerate_examples.py::ExamplesTestsNoTrainer::test_run_ner_no_trainer examples/pytorch/test_accelerate_examples.py::ExamplesTestsNoTrainer::test_run_squad_no_trainer examples/pytorch/test_accelerate_examples.py::ExamplesTestsNoTrainer::test_run_glue_no_trainer
======================================================================================================== test session starts =========================================================================================================
platform linux -- Python 3.10.12, pytest-8.0.0, pluggy-1.4.0 -- /usr/bin/python3.10
cachedir: .pytest_cache
hypothesis profile 'default' -> database=DirectoryBasedExampleDatabase(PosixPath('/root/hyt/github/transformers/.hypothesis/examples'))
rootdir: /root/hyt/github/transformers
configfile: pyproject.toml
plugins: hypothesis-6.98.17, xdist-3.5.0, subtests-0.12.1, anyio-4.3.0, timeout-2.3.1
collected 3 items

examples/pytorch/test_accelerate_examples.py::ExamplesTestsNoTrainer::test_run_ner_no_trainer PASSED
examples/pytorch/test_accelerate_examples.py::ExamplesTestsNoTrainer::test_run_squad_no_trainer PASSED
examples/pytorch/test_accelerate_examples.py::ExamplesTestsNoTrainer::test_run_glue_no_trainer FAILED

============================================================================================================== FAILURES ==============================================================================================================
__________________________________________________________________________________________ ExamplesTestsNoTrainer.test_run_glue_no_trainer ___________________________________________________________________________________________

self = <test_accelerate_examples.ExamplesTestsNoTrainer testMethod=test_run_glue_no_trainer>

    @mock.patch.dict(os.environ, {"WANDB_MODE": "offline", "DVCLIVE_TEST": "true"})
    def test_run_glue_no_trainer(self):
        tmp_dir = self.get_auto_remove_tmp_dir()
        testargs = f"""
            {self.examples_dir}/pytorch/text-classification/run_glue_no_trainer.py
            --model_name_or_path distilbert/distilbert-base-uncased
            --output_dir {tmp_dir}
            --train_file ./tests/fixtures/tests_samples/MRPC/train.csv
            --validation_file ./tests/fixtures/tests_samples/MRPC/dev.csv
            --per_device_train_batch_size=2
            --per_device_eval_batch_size=1
            --learning_rate=1e-4
            --seed=42
            --num_warmup_steps=2
            --checkpointing_steps epoch
            --with_tracking
        """.split()

        run_command(self._launch_args + testargs)
        result = get_results(tmp_dir)
>       self.assertGreaterEqual(result["eval_accuracy"], 0.75)
E       AssertionError: 0.6666666666666666 not greater than or equal to 0.75

examples/pytorch/test_accelerate_examples.py:98: AssertionError
========================================================================================================== warnings summary ==========================================================================================================
../../../../usr/local/lib/python3.10/dist-packages/_pytest/config/__init__.py:1394
  /usr/local/lib/python3.10/dist-packages/_pytest/config/__init__.py:1394: PytestConfigWarning: Unknown config option: doctest_glob

    self._warn_or_fail_if_strict(f"Unknown config option: {key}\n")

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
====================================================================================================== short test summary info =======================================================================================================
FAILED examples/pytorch/test_accelerate_examples.py::ExamplesTestsNoTrainer::test_run_glue_no_trainer - AssertionError: 0.6666666666666666 not greater than or equal to 0.75
========================================================================================= 1 failed, 2 passed, 1 warning in 143.07s (0:02:23) =========================================================================================

@muellerzr (Contributor)

They look to be timeout issues; I'm rerunning the tests now. However, if they still fail, note that they were not failing before this 😅

@muellerzr (Contributor)

Tests pass on our CI, so this looks to be fine.

@muellerzr (Contributor) left a comment

cc @amyeroberts for final review

@muellerzr muellerzr requested a review from amyeroberts March 8, 2024 16:00
@amyeroberts (Collaborator) left a comment

Thanks for working on this!

Mostly just small comments about the deprecation handling.

Main concern is that previously check_device was True by default. Therefore, replacing is_torch_tpu_available() with is_torch_xla_available() isn't an equivalent call.
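
For context, the pre-PR helper looked roughly like this (a sketch reconstructed from the diff excerpts quoted below, not verbatim source):

```python
import importlib.util
from functools import lru_cache

@lru_cache()
def is_torch_tpu_available(check_device=True):
    "Checks if `torch_xla` is installed and potentially if a TPU is in the environment"
    if importlib.util.find_spec("torch_xla") is None:
        return False
    if check_device:
        # With the old default, "available" meant an XLA device could actually
        # be created, not merely that the torch_xla package was installed.
        try:
            import torch_xla.core.xla_model as xm

            _ = xm.xla_device()
        except RuntimeError:
            return False
    return True
```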

@@ -497,13 +513,33 @@ def is_torch_tpu_available(check_device=True):
             except RuntimeError:
                 return False
         return True
@amyeroberts (Collaborator):

The final return should remain

Suggested change
-return True
+return True
+return False

@yitongh (Contributor, Author):

Done

@@ -188,7 +188,7 @@
     is_torch_sdpa_available,
     is_torch_tensorrt_fx_available,
     is_torch_tf32_available,
-    is_torch_tpu_available,
+    is_torch_xla_available,
@amyeroberts (Collaborator):

We need to leave this as importable whilst it's still going through the deprecation cycle

Suggested change
-is_torch_xla_available,
+is_torch_tpu_available,
+is_torch_xla_available,

@yitongh (Contributor, Author):

Done

@@ -135,7 +135,7 @@ def _get_lr_scheduler(self, num_training_steps):
     def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
         if isinstance(self.train_dataset, torch.utils.data.IterableDataset):
             return None
-        elif is_torch_tpu_available():
+        elif is_torch_xla_available():
@amyeroberts (Collaborator):

This isn't equivalent: previously we were checking for a device, but by default that isn't happening anymore.

@yitongh (Contributor, Author):

This is to align with the corresponding PR in the accelerate library. If users do not wish to use torch_xla in an environment where it is installed, they can disable it via USE_TORCH_XLA; that is exactly the purpose of this PR.

@@ -2404,7 +2404,7 @@ def _maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, igno

             logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4)
             if grad_norm is not None:
-                logs["grad_norm"] = grad_norm
+                logs["grad_norm"] = grad_norm.item() if torch.is_tensor(grad_norm) else grad_norm
@amyeroberts (Collaborator):

This change doesn't seem to have anything to do with the goal of this PR.

@yitongh (Contributor, Author):

This modification is needed because evaluating the tensor (grad_norm.item()) forces XLA to execute the entire computation graph prematurely, degrading performance. The grad_norm.item() call should happen after the XLA mark_step (see the sketch below).
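
A minimal sketch of that lazy-execution behavior (assuming a working torch_xla install; xm.mark_step is torch_xla's step-boundary API):

```python
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
w = torch.randn(4, 4, device=device, requires_grad=True)
loss = (w * w).sum()
loss.backward()

# clip_grad_norm_ returns the total norm; on an XLA device it is a lazy tensor.
grad_norm = torch.nn.utils.clip_grad_norm_([w], max_norm=1.0)

# Calling grad_norm.item() at this point would force compilation and
# execution of the partial graph built so far, splitting the training
# step into two XLA programs and hurting performance.
xm.mark_step()           # cut the graph at the intended step boundary first
print(grad_norm.item())  # the host sync now happens at a natural barrier
```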

@@ -2016,7 +2016,7 @@ def _inner_training_loop(
                         if hasattr(grad_norm, "item"):
                             grad_norm = grad_norm.item()
                     else:
-                        grad_norm = _grad_norm.item() if _grad_norm is not None else None
+                        grad_norm = _grad_norm
@amyeroberts (Collaborator):

same here

@@ -1090,8 +1090,8 @@
     "is_torch_available",
     "is_torch_neuroncore_available",
     "is_torch_npu_available",
-    "is_torch_tpu_available",
     "is_torchvision_available",
@amyeroberts (Collaborator):

We need to keep this whilst it's still being deprecated.

Suggested change
-"is_torchvision_available",
+"is_torch_tpu_available",
+"is_torchvision_available",

@yitongh (Contributor, Author):

Done

@@ -5894,7 +5894,7 @@
     is_torch_available,
     is_torch_neuroncore_available,
     is_torch_npu_available,
-    is_torch_tpu_available,
+    is_torch_xla_available,
@amyeroberts (Collaborator):

Suggested change
-is_torch_xla_available,
+is_torch_tpu_available,
+is_torch_xla_available,

@yitongh (Contributor, Author):

Done

@@ -484,6 +494,12 @@ def is_g2p_en_available():
 @lru_cache()
 def is_torch_tpu_available(check_device=True):
     "Checks if `torch_xla` is installed and potentially if a TPU is in the environment"
+    warnings.warn(
+        "`is_torch_tpu_available` is deprecated and will be removed in 4.39.0. "
@amyeroberts (Collaborator):

4.39.0 will be the next release, so the function would need to be removed right away! As it's a public object, it should go through at least two release cycles.

Suggested change
-"`is_torch_tpu_available` is deprecated and will be removed in 4.39.0. "
+"`is_torch_tpu_available` is deprecated and will be removed in 4.41.0. "

@yitongh (Contributor, Author):

Done
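
For reference, a minimal sketch of the deprecation-shim pattern settled on here (illustrative; the real helper also preserves the old check_device behavior rather than simply delegating):

```python
import importlib.util
import warnings
from functools import lru_cache

def is_torch_xla_available():
    # Simplified stand-in for the new check sketched earlier in the thread.
    return importlib.util.find_spec("torch_xla") is not None

@lru_cache()
def is_torch_tpu_available(check_device=True):
    "Checks if `torch_xla` is installed and potentially if a TPU is in the environment"
    # Warn but keep working for at least two release cycles, per the review.
    warnings.warn(
        "`is_torch_tpu_available` is deprecated and will be removed in 4.41.0. "
        "Please use `is_torch_xla_available` instead.",
        FutureWarning,
    )
    return is_torch_xla_available()
```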

@amyeroberts (Collaborator) left a comment

Changes look good to me - thanks for iterating and explaining the design choices!

@amyeroberts amyeroberts merged commit 873d9bb into huggingface:main Mar 11, 2024
20 checks passed
itazap pushed a commit that referenced this pull request May 14, 2024
* add USE_TORCH_XLA env

* rename torch_tpu to torch_xla

* better is_torch_xla_available; fix some fsdp and performance issues

* fix format

* fix bug when pjrt_device is cpu

* fix bug

* fix the deprecation handling

---------

Co-authored-by: anw90 <[email protected]>
Co-authored-by: wangang.wa <[email protected]>