Make torch xla available on GPU #29334
Conversation
The main changes:
This PR is related to huggingface/accelerate#2176 and huggingface/accelerate#2467. @will-cromar Could you please check if this PR has any impact on the TPU environment? Thanks.
The CI only failed in tests/test_modeling_utils.py::ModelUtilsTest::test_use_safetensors. I ran this test on master and it also hangs. From pystack, it looks like this issue is related to an SSL read. Stack log: https://gist.github.com/yitongh/34dc9c9f3de79d208533964bd63bb6f5
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
@muellerzr, would you be available to take a look at this PR when you have a moment? Alternatively, if you're not, perhaps you could suggest someone else who might be suited to review it? Many thanks.
Overall this is fine by me, we have very similar logic in Accelerate if I'm not mistaken. Thanks!
Let's make sure we can fix those failing tests, though. Can you try rebasing from main?
@muellerzr I have rebased from main. I reran the failing tests on my machine, and both main and this PR passed.
They look to be timeout issues; I'm rerunning the tests now. However, if they still fail, note that they were not failing before this 😅
Tests pass on our CI, so this looks to be fine.
cc @amyeroberts for final review
Thanks for working on this!
Mostly just small comments about the deprecation handling.
Main concern is that previously `check_device` was `True` by default. Therefore, replacing `is_torch_tpu_available()` with `is_torch_xla_available()` isn't an equivalent call.
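To make the concern concrete, here is a minimal sketch of the two behaviours (illustrative only; the helper names and bodies are approximations, not the library's exact code):

```python
def old_style_tpu_check(check_device=True):
    """Approximation of the old `is_torch_tpu_available` behaviour (assumed)."""
    try:
        import torch_xla.core.xla_model as xm
    except ImportError:
        return False
    if check_device:
        # The old default also verified that an XLA device is actually reachable.
        try:
            xm.xla_device()
        except RuntimeError:
            return False
    return True


def new_style_import_check():
    """A plain import check only answers whether torch_xla is installed."""
    try:
        import torch_xla  # noqa: F401
    except ImportError:
        return False
    return True
```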
@@ -497,13 +513,33 @@ def is_torch_tpu_available(check_device=True):
    except RuntimeError:
        return False
    return True
The final return should remain:

Suggested change:
-return True
+return True
+return False
Done
@@ -188,7 +188,7 @@
    is_torch_sdpa_available,
    is_torch_tensorrt_fx_available,
    is_torch_tf32_available,
-   is_torch_tpu_available,
+   is_torch_xla_available,
We need to leave this as importable whilst it's still going through the deprecation cycle
Suggested change:
-    is_torch_xla_available,
+    is_torch_tpu_available,
+    is_torch_xla_available,
Done
@@ -135,7 +135,7 @@ def _get_lr_scheduler(self, num_training_steps):
    def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
        if isinstance(self.train_dataset, torch.utils.data.IterableDataset):
            return None
-       elif is_torch_tpu_available():
+       elif is_torch_xla_available():
This isn't equivalent: previously we were checking for a device, but by default that isn't happening anymore.
This is to align with the corresponding PR in the accelerate library. If users do not wish to use torch_xla in an environment where torch_xla is installed, they can configure this with USE_TORCH_XLA, which is also the purpose of this PR.
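For illustration, a hedged sketch of how a user could opt out before importing transformers (the variable name comes from this PR; treat the exact accepted values as an assumption):

```python
import os

# Disable torch_xla even though it is installed, so native PyTorch CUDA is used.
# "false" follows the wording in this PR; check the merged code for the exact values honoured.
os.environ["USE_TORCH_XLA"] = "false"

import transformers  # noqa: E402  # import only after the flag is set
```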
@@ -2404,7 +2404,7 @@ def _maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, igno
            logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4)
            if grad_norm is not None:
-               logs["grad_norm"] = grad_norm
+               logs["grad_norm"] = grad_norm.item() if torch.is_tensor(grad_norm) else grad_norm
This change doesn't seem to have anything to do with the goal of this PR.
This modification is because tensor evaluation (`grad_norm.item()`) will cause XLA to execute the entire computation graph prematurely, resulting in decreased performance. The `grad_norm.item()` operation should be performed after the XLA `mark_step`.
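A minimal sketch of the pattern being described (illustrative only, not the Trainer code): calling `.item()` on a lazy XLA tensor forces the pending graph to compile and run immediately, so the host-side conversion is deferred until after `xm.mark_step()`.

```python
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
param = torch.randn(1024, 1024, device=device, requires_grad=True)
loss = (param ** 2).sum()
loss.backward()

# Still a lazy XLA tensor; nothing has been executed yet.
grad_norm = torch.nn.utils.clip_grad_norm_([param], max_norm=1.0)

# Calling grad_norm.item() at this point would force XLA to compile and run a
# partial graph prematurely, hurting performance.
xm.mark_step()  # let XLA execute the whole step's graph once

grad_norm_value = grad_norm.item()  # the graph has already run; this is just a host transfer
```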
@@ -2016,7 +2016,7 @@ def _inner_training_loop(
        if hasattr(grad_norm, "item"):
            grad_norm = grad_norm.item()
    else:
-       grad_norm = _grad_norm.item() if _grad_norm is not None else None
+       grad_norm = _grad_norm
same here
@@ -1090,8 +1090,8 @@
    "is_torch_available",
    "is_torch_neuroncore_available",
    "is_torch_npu_available",
-   "is_torch_tpu_available",
    "is_torchvision_available",
We need to keep this whilst it's still being deprecated.
"is_torchvision_available", | |
"is_torch_tpu_available", | |
"is_torchvision_available", |
Done
@@ -5894,7 +5894,7 @@
    is_torch_available,
    is_torch_neuroncore_available,
    is_torch_npu_available,
-   is_torch_tpu_available,
+   is_torch_xla_available,
Suggested change:
-    is_torch_xla_available,
+    is_torch_tpu_available,
+    is_torch_xla_available,
Done
@@ -484,6 +494,12 @@ def is_g2p_en_available():
@lru_cache()
def is_torch_tpu_available(check_device=True):
    "Checks if `torch_xla` is installed and potentially if a TPU is in the environment"
+   warnings.warn(
+       "`is_torch_tpu_available` is deprecated and will be removed in 4.39.0. "
This will be the next release, so it would need to be removed now! As it's a public object, it should go through at least two deprecation cycles.
"`is_torch_tpu_available` is deprecated and will be removed in 4.39.0. " | |
"`is_torch_tpu_available` is deprecated and will be removed in 4.41.0. " |
Done
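For reference, an illustrative sketch of the deprecation handling discussed in this thread; the body is an approximation rather than the PR's exact code:

```python
import warnings
from functools import lru_cache


@lru_cache()
def is_torch_tpu_available(check_device=True):
    # Deprecated public helper: keep the old behaviour for at least two release
    # cycles while pointing callers at the new `is_torch_xla_available` helper.
    warnings.warn(
        "`is_torch_tpu_available` is deprecated and will be removed in 4.41.0. "
        "Please use `is_torch_xla_available` instead.",
        FutureWarning,
    )
    try:
        import torch_xla.core.xla_model as xm
    except ImportError:
        return False
    if check_device:
        # Old default: also confirm an XLA device is actually reachable.
        try:
            xm.xla_device()
        except RuntimeError:
            return False
    return True
```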
Changes look good to me - thanks for iterating and explaining the design choices!
* add USE_TORCH_XLA env
* rename torch_tpu to torch_xla
* better is_torch_xla_available; fix some fsdp and performance issues
* fix format
* fix bug when pjrt_device is cpu
* fix bug
* fix the deprecation handling

---------

Co-authored-by: anw90 <[email protected]>
Co-authored-by: wangang.wa <[email protected]>
What does this PR do?
Make torch xla available on GPU. Currently, torch xla can be used in a GPU environment, but there are some conflicts between XLA and native PyTorch CUDA when using an environment with torch xla installed. This PR introduces the environment variable USE_TORCH_XLA to address this issue. When USE_TORCH_XLA is set to false, native PyTorch CUDA can be used seamlessly, even if torch xla is installed.
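A rough sketch of what such an environment-variable gate can look like (the function and variable names follow the PR description, but the body and the set of accepted values are assumptions, not the merged implementation):

```python
import importlib.util
import os
from functools import lru_cache

# Which values count as "enabled" is an assumption made for this sketch.
_TRUE_VALUES = {"1", "true", "yes", "on"}


@lru_cache()
def is_torch_xla_available():
    # Respect the opt-out first: with USE_TORCH_XLA=false, behave as if
    # torch_xla were not installed, so the native PyTorch CUDA paths are taken.
    if os.environ.get("USE_TORCH_XLA", "1").lower() not in _TRUE_VALUES:
        return False
    # Otherwise, fall back to a plain installation check.
    return importlib.util.find_spec("torch_xla") is not None
```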
Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a Github issue or the forum? Please add a link to it if that's the case.
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
@muellerzr and @pacman100