Make torch xla available on GPU #2176
Conversation
The main changes:
Hi, can anyone help take a look? Thanks.
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.
Thanks for this work! However, there are some heavy design choices here that add friction to the Accelerate API. I've added recommendations and suggestions for improvement and code readability.
(Sorry for the delayed review; holidays, plus fully grasping the code here took some thought.)
@@ -45,7 +45,7 @@ Why is this important? Under the hood this will set **5** different seed setting
     torch.manual_seed(seed)
     torch.cuda.manual_seed_all(seed)
     # ^^ safe to call this function even if cuda is not available
-    if is_tpu_available():
+    if is_troch_xla_available():
Suggested change:
- if is_troch_xla_available():
+ if is_torch_xla_available():
done
src/accelerate/utils/imports.py
Outdated
ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"}
# Try to run Torch native job in an environment with TorchXLA installed by setting this value to 0.
USE_TORCH_XLA = os.environ.get("USE_TORCH_XLA", "1").upper()
Just use an already existing util here please:
from accelerate.utils import parse_flag_from_env
USE_TORCH_XLA = parse_flag_from_env("USE_TORCH_XLA", "1")
done
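For reference, `parse_flag_from_env` parses the variable into a truthy/falsy value rather than returning the raw string, which is what makes the later `if not USE_TORCH_XLA:` simplification possible. A small usage sketch (the print is illustrative only, not code from this PR):

from accelerate.utils import parse_flag_from_env

# reads USE_TORCH_XLA from the environment and coerces "1"/"true"/"yes"-style values to a flag
USE_TORCH_XLA = parse_flag_from_env("USE_TORCH_XLA", "1")  # as in the suggestion above

if not USE_TORCH_XLA:
    print("torch_xla import will be skipped")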
src/accelerate/utils/imports.py
Outdated
 try:
-    import torch_xla.core.xla_model as xm  # noqa: F401
-    _tpu_available = True
+    if USE_TORCH_XLA in ENV_VARS_TRUE_VALUES:
+        import torch_xla.core.xla_model as xm  # noqa: F401
+        _torch_xla_available = True
+    else:
+        _torch_xla_available = False
 except ImportError:
-    _tpu_available = False
+    _torch_xla_available = False
Let's rewrite this whole block:
_torch_xla_available = False
if USE_TORCH_XLA:
try:
import torch_xla.core.xla_model as xm # noqa: F401
_torch_xla_available = True
except ImportError:
pass
done
return _tpu_available

if _torch_xla_available:
    xla_device = xm.xla_device()
No, we cannot do that; the original device check must be there. Please bring back the original logic.
A try-catch block has been added here. Do we still need to retain the check_device argument? xm.xla_device() cannot be called outside of xm.spawn, as described in this PR, but this issue did not occur on GPU with torch XLA master.
I don't believe we do, as we can just pass in check_is_tpu
src/accelerate/utils/imports.py
Outdated
Check if `torch_xla` is available and real hardware in `hardware_types`. To train a native pytorch job in an
environment with torch xla installed, set the USE_TORCH_XLA to false.
"""
if USE_TORCH_XLA not in ENV_VARS_TRUE_VALUES:
Suggested change:
- if USE_TORCH_XLA not in ENV_VARS_TRUE_VALUES:
+ if not USE_TORCH_XLA:
done
src/accelerate/utils/imports.py
Outdated
"Checks if `torch_xla` is installed and potentially if a TPU is in the environment" | ||
# Due to bugs on the amp series GPUs, we disable torch-xla on them | ||
if is_cuda_available(): | ||
def is_torch_xla_available(hardware_types=("TPU", "GPU")): |
This is suddenly a very difficult API to use. Let's just use two boolean values for `tpu` and `gpu`, please. These can default to `False`, and then we check using `any()` at the end.
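A rough sketch of the boolean-flag shape being suggested (the name `check_is_tpu` comes up later in this thread; `check_is_gpu`, the hardware probe, and the body are illustrative assumptions, not the merged implementation):

def is_torch_xla_available(check_is_tpu=False, check_is_gpu=False):
    # `_torch_xla_available` and `xm` are assumed to come from the guarded import block above
    if not _torch_xla_available:
        return False
    if not (check_is_tpu or check_is_gpu):
        return True
    device_hw = xm.xla_device_hw(xm.xla_device())  # e.g. "TPU", "GPU", "CPU"
    checks = []
    if check_is_tpu:
        checks.append(device_hw == "TPU")
    if check_is_gpu:
        checks.append(device_hw == "GPU")
    return any(checks)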
done
src/accelerate/utils/modeling.py
Outdated
@@ -1453,18 +1453,24 @@ def get_mixed_precision_context_manager(native_amp: bool = False, autocast_kwarg
    else:
        autocast_kwargs = autocast_kwargs.to_kwargs()
    if native_amp:
        device_type = (
            "cuda"
            if (state.distributed_type == DistributedType.TPU and is_torch_xla_available(tuple(["GPU"])))
No. TPU == XLA. We should have `state` return `DistributedType.MULTI_GPU` if we want to support CUDA.
src/accelerate/state.py
Outdated
@@ -152,12 +152,16 @@ def __init__(self, cpu: bool = False, **kwargs):
     if self.device is None:
         self.device = torch.device("cuda", self.local_process_index)
         torch.cuda.set_device(self.device)
-    elif is_tpu_available() and not cpu:
+    elif is_torch_xla_available() and not cpu:
         self.distributed_type = DistributedType.TPU
As mentioned earlier in my review, if this is the case then we should check whether it's CUDA or not and specify `DistributedType.MULTI_GPU`. I'm open to having it be `DistributedType.XLA` now, but we should have a deprecation warning if users try to use `DistributedType.TPU`.
Okay, DistributedType.XLA has been added to replace DistributedType.TPU.
Thank you for your review and suggestions @muellerzr. They were incredibly helpful, and the suggested changes have been committed. Please feel free to let me know if there's anything else we should address.
This is looking quite good!
I left a few comments. It would be great if you could run `accelerate test` / run through the accelerate tests in an XLA env (both TPU and GPU env ideally) so we can verify things work fine. I'll run on GPUs to make sure nothing has broken there (non-XLA-compiled ones).
src/accelerate/utils/dataclasses.py
Outdated
def __get__(self, instance, owner):
    warnings.warn(
        f"The `{self.field_name}` of `{owner}` is deprecated and will be removed in v0.27.0. "
As this has been in here for a long, long time, let's deprecate it in v1.0.0 please.
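For illustration, the descriptor shown above boils down to a pattern like the following self-contained sketch (class and field names here are generic stand-ins, not accelerate's actual objects):

import warnings

class DeprecatedFieldDescriptor:
    """Warn when a deprecated class attribute is read, then return its replacement."""

    def __init__(self, field_name, replacement):
        self.field_name = field_name
        self.replacement = replacement

    def __get__(self, instance, owner):
        warnings.warn(
            f"The `{self.field_name}` of `{owner.__name__}` is deprecated; use `{self.replacement}` instead.",
            FutureWarning,
        )
        return getattr(owner, self.replacement)

class DistributedKind:
    XLA = "XLA"
    TPU = DeprecatedFieldDescriptor("TPU", "XLA")  # old name still resolves, but warns

print(DistributedKind.TPU)  # emits a FutureWarning, then prints "XLA"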
src/accelerate/utils/imports.py
Outdated
@@ -92,6 +99,11 @@ def is_cuda_available():
@lru_cache
def is_tpu_available(check_device=True):
    "Checks if `torch_xla` is installed and potentially if a TPU is in the environment"
    warnings.warn(
        "The `is_tpu_available` is deprecated and will be removed in v0.27.0. "
"The `is_tpu_available` is deprecated and will be removed in v0.27.0. " | |
"`is_tpu_available` is deprecated and will be removed in v0.27.0. " |
Note v0.27.0 is fine here
Hi @muellerzr, sorry for the late reply. I spent some time resolving the UT issue to ensure that all UTs passed. I ran the UTs (I don't have a TPU at the moment; I'll run the TPU UTs as soon as I acquire a TPU environment). All three UTs above were executed. It appears that this test has also failed on the main branch; before this CI was merged, the test was fine. Here is the failure error:
Running that test standalone should not (and does not) fail, only as part of the grouped CI, which I'm looking into. Are you initializing CUDA somewhere now suddenly? To verify this, in Jupyter on a GPU-enabled system do the following:

import torch
from accelerate import Accelerator, notebook_launcher

torch.cuda.is_initialized()

This should be False.
Thanks for your reply, muellerzr. Following your suggestion, I ran your code, and torch.cuda.is_initialized() returned False in Jupyter. I also switched to another worker but encountered the same issue. Perhaps I need to delve deeper into this problem.
Most likely, yes. My best recommendation for doing so is running the contents of that notebook test inside Jupyter to see if that works. I'll rerun it here locally in a moment to verify, but the test should work. What is your output of
Output of
Ah, that could be on our end with bitsandbytes. If you uninstall it, does the test pass?
No, even after uninstalling bitsandbytes, the issue persists.
@anw90 that does indeed mean it's an issue with the XLA integration. In the meantime I would recommend modifying the
Thanks, @muellerzr. I'm currently experimenting with:

docker run --gpus all --net host --ipc host --shm-size 10G -it --rm --cap-add=SYS_PTRACE pytorch/pytorch:2.1.0-cuda11.8-cudnn8-runtime bash -c 'apt-get update; \
    apt-get install git -y ; \
    git clone http://github.com/huggingface/accelerate; \
    cd accelerate; \
    pip install pytest; \
    pip install -e .;\
    RUN_SLOW=1 python -u -m pytest -s -v tests/test_cli.py::AccelerateLauncherTester::test_notebook_launcher;'

Following the incorrect prompt (
We cannot do spawn; it must be fork :) But yes, I'll take a look. It passed yesterday on main for me, though.
I ran
Thanks for your review. All code conflicts have been resolved, and all unit tests have passed except for test_notebook_launcher. BTW, after following this blog, I managed to get a worker with 8 TPU-v2s and ran the unit tests on it, but the UTs fail a lot with the accelerate code cloned from the http://github.com/huggingface/accelerate main branch. The commands are as follows:
The failed UTs are as follows:
Perhaps we could address this in a separate PR.
Do you think we should make similar adjustments to the
Let's address this in a separate PR, please.
The code conflict has been resolved and all UTs, including the
@anw90 can you rebase from main one more time (it's the failing test currently)? Then I'll run this on our runners to make sure everything looks solid before final reviews from myself and a few others (considering how big this PR is).
Thanks @muellerzr. All conflicts resolved.
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
Hi @anw90, terribly sorry this is taking so long. The other failures are okay. Can you run
Sorry for the late reply due to the holidays. The code is formatted, and all unit tests have passed. I hope it can be merged today. Let me know if there's anything else I can help with.
What does this PR do?
Make torch xla available on GPU. Users can run both native Torch training jobs and TorchXLA jobs within the same environment without needing to uninstall or reinstall TorchXLA.
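As a rough illustration of the intended workflow (the flag is the USE_TORCH_XLA variable introduced in this PR; it has to be set before accelerate is first imported, and everything else here is a sketch rather than a prescribed recipe):

import os

# "0" asks accelerate to behave like a plain torch install even though torch_xla is present
os.environ["USE_TORCH_XLA"] = "0"

from accelerate import Accelerator

accelerator = Accelerator()
print(accelerator.device)  # expected to be a cuda/cpu device rather than an xla one

Leaving the flag at its default of "1" in the same environment would pick up the XLA device path instead.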
Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case. https://discuss.huggingface.co/t/should-we-optimize-the-logic-for-enabling-torchxla-in-a-gpu-environment/60008
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
@muellerzr