
Make torch xla available on GPU #2176

Merged (18 commits) on Feb 14, 2024

Conversation

anw90 (Contributor) commented Nov 21, 2023

What does this PR do?

Make torch xla available on GPU. Users can run both native Torch training jobs and TorchXLA jobs within the same environment, without needing to uninstall or reinstall TorchXLA.

Before submitting

Who can review?

@muellerzr

anw90 (Contributor Author) commented Nov 24, 2023

The main changes (a rough sketch of how the pieces fit together follows this list):

  1. Change is_tpu_available to is_torch_xla_available;
  2. Add a USE_TORCH_XLA environment variable to enable or disable torch_xla;
  3. Use torch-native autocast and the torch_xla GradScaler for AMP on GPU;
  4. Use xm.all_reduce instead of accelerator.reduce to perform an in-place all-reduce for data parallelism;
  5. Avoid running all-reduce twice by calling _set_sync_gradients(True).
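For illustration only, here is a rough, self-contained sketch of how changes 2 and 3 fit together; the helper name make_amp_helpers and the exact structure are illustrative, not the PR's actual code:

import os
import torch

# Gate torch_xla behind the USE_TORCH_XLA environment variable (enabled by default).
USE_TORCH_XLA = os.environ.get("USE_TORCH_XLA", "1").upper() in {"1", "ON", "YES", "TRUE"}

_torch_xla_available = False
if USE_TORCH_XLA:
    try:
        import torch_xla.core.xla_model as xm  # noqa: F401
        _torch_xla_available = True
    except ImportError:
        pass

def make_amp_helpers():
    # AMP on an XLA:GPU device: torch-native autocast targeting "cuda",
    # paired with the GradScaler shipped by torch_xla when it is usable.
    autocast_ctx = torch.autocast(device_type="cuda", dtype=torch.float16)
    if _torch_xla_available:
        from torch_xla.amp import GradScaler
        scaler = GradScaler()
    else:
        scaler = torch.cuda.amp.GradScaler()
    return autocast_ctx, scaler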

YongCHN commented Nov 27, 2023

Hi, can anyone help take a look? Thanks.

muellerzr self-requested a review on November 27, 2023 16:47
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

muellerzr (Collaborator) left a comment

Thanks for this work! There are some heavy design choices here that add friction to the Accelerate API, however. I've added recommendations and suggestions on improvements and code readability.

(Sorry for the delayed review: holidays, plus taking time to fully grasp the code here.)

@@ -45,7 +45,7 @@ Why is this important? Under the hood this will set **5** different seed setting
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
# ^^ safe to call this function even if cuda is not available
if is_tpu_available():
if is_troch_xla_available():
Collaborator

Suggested change
if is_troch_xla_available():
if is_torch_xla_available():

Contributor Author

done

Comment on lines 29 to 31
ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"}
# Try to run Torch native job in an environment with TorchXLA installed by setting this value to 0.
USE_TORCH_XLA = os.environ.get("USE_TORCH_XLA", "1").upper()
Collaborator

Just use an already existing util here please:

from accelerate.utils import parse_flag_from_env

USE_TORCH_XLA = parse_flag_from_env("USE_TORCH_XLA", "1")

Contributor Author

done

Comment on lines 33 to 41
try:
import torch_xla.core.xla_model as xm # noqa: F401
if USE_TORCH_XLA in ENV_VARS_TRUE_VALUES:
import torch_xla.core.xla_model as xm # noqa: F401

_tpu_available = True
_torch_xla_available = True
else:
_torch_xla_available = False
except ImportError:
_tpu_available = False
_torch_xla_available = False
Collaborator

Let's rewrite this whole block:

_torch_xla_available = False
if USE_TORCH_XLA:
    try:
        import torch_xla.core.xla_model as xm # noqa: F401
        _torch_xla_available = True
    except ImportError:
        pass

Contributor Author

done

return _tpu_available

if _torch_xla_available:
xla_device = xm.xla_device()
Collaborator

No, we cannot do that; the original check_device behavior must be there. Please bring back the original logic.

Contributor Author

A try-except block has been added here. Do we still need to retain the check_device argument? xm.xla_device() cannot be called outside of xm.spawn, as described in this PR, but this issue did not occur on GPU with the master branch of torch XLA.

Collaborator

I don't believe we do, as we can just pass in check_is_tpu

Check if `torch_xla` is available and real hardware in `hardware_types`. To train a native pytorch job in an
environment with torch xla installed, set the USE_TORCH_XLA to false.
"""
if USE_TORCH_XLA not in ENV_VARS_TRUE_VALUES:
Collaborator

Suggested change
if USE_TORCH_XLA not in ENV_VARS_TRUE_VALUES:
if not USE_TORCH_XLA:

Contributor Author

done

"Checks if `torch_xla` is installed and potentially if a TPU is in the environment"
# Due to bugs on the amp series GPUs, we disable torch-xla on them
if is_cuda_available():
def is_torch_xla_available(hardware_types=("TPU", "GPU")):
Collaborator

This is suddenly a very difficult API to use. Let's just use two boolean values for TPU and GPU, please. These can default to False, and then we check using any() at the end.
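For reference, the two-boolean shape being suggested could look roughly like this; the body is an illustrative sketch (xm.xla_device_hw reports the underlying hardware type), not the code that was merged:

try:
    import torch_xla.core.xla_model as xm
    _torch_xla_available = True
except ImportError:
    _torch_xla_available = False

def is_torch_xla_available(check_is_tpu=False, check_is_gpu=False):
    # Only one hardware check may be requested at a time.
    assert not (check_is_tpu and check_is_gpu), "cannot check for both TPU and GPU"
    if not _torch_xla_available:
        return False
    if check_is_gpu:
        return xm.xla_device_hw(xm.xla_device()) in ("GPU", "CUDA")
    if check_is_tpu:
        return xm.xla_device_hw(xm.xla_device()) == "TPU"
    return True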

Contributor Author

done

@@ -1453,18 +1453,24 @@ def get_mixed_precision_context_manager(native_amp: bool = False, autocast_kwarg
else:
autocast_kwargs = autocast_kwargs.to_kwargs()
if native_amp:
device_type = (
"cuda"
if (state.distributed_type == DistributedType.TPU and is_torch_xla_available(tuple(["GPU"])))
Collaborator

No. TPU == XLA. We should have state return DistributedType.MULTI_GPU if we want to support CUDA.

@@ -152,12 +152,16 @@ def __init__(self, cpu: bool = False, **kwargs):
if self.device is None:
self.device = torch.device("cuda", self.local_process_index)
torch.cuda.set_device(self.device)
elif is_tpu_available() and not cpu:
elif is_torch_xla_available() and not cpu:
self.distributed_type = DistributedType.TPU
Collaborator

As mentioned earlier in my review, if this is the case then we should check whether it's CUDA or not and specify DistributedType.MULTI_GPU. I'm open to having it be DistributedType.XLA now, but we should have a deprecation warning if users try to use DistributedType.TPU.

Contributor Author

Okay, DistributedType.XLA has been added to replace DistributedType.TPU.

src/accelerate/utils/imports.py (resolved)
anw90 force-pushed the xla_gpu branch 3 times, most recently from 45f44d6 to 72a4322, on December 4, 2023 11:41
anw90 (Contributor Author) commented Dec 4, 2023

Thank you for your review and suggestions @muellerzr . They were incredibly helpful, and the suggested changes have been committed. Please feel free to let me know if there's anything else we should address.

muellerzr (Collaborator) left a comment

This is looking quite good!

I left a few comments, and it would be great if you could run accelerate test / run through the accelerate tests on an XLA env (ideally both TPU and GPU envs) so we can verify things work fine. And I'll run on GPUs (non-XLA-compiled ones) to make sure nothing has broken there.

src/accelerate/optimizer.py (outdated, resolved)
src/accelerate/state.py (resolved)
src/accelerate/utils/dataclasses.py (resolved)

def __get__(self, instance, owner):
warnings.warn(
f"The `{self.field_name}` of `{owner}` is deprecated and will be removed in v0.27.0. "
Collaborator

As this has been in here for a long, long time, let's deprecate it in v1.0.0 please.
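For context, the deprecation pattern shown in the fragment above boils down to a descriptor along these lines; this is an illustrative sketch with placeholder names, not the exact implementation:

import warnings

class DeprecatedFieldDescriptor:
    def __init__(self, field_name, replaced_with):
        self.field_name = field_name
        self.replaced_with = replaced_with

    def __get__(self, instance, owner):
        # Warn on access to the old field and forward to its replacement,
        # e.g. DistributedType.TPU -> DistributedType.XLA.
        warnings.warn(
            f"The `{self.field_name}` of `{owner}` is deprecated and will be removed in v1.0.0. "
            f"Please use `{self.replaced_with}` instead.",
            FutureWarning,
        )
        return getattr(owner, self.replaced_with)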

@@ -92,6 +99,11 @@ def is_cuda_available():
@lru_cache
def is_tpu_available(check_device=True):
"Checks if `torch_xla` is installed and potentially if a TPU is in the environment"
warnings.warn(
"The `is_tpu_available` is deprecated and will be removed in v0.27.0. "
Collaborator

Suggested change
"The `is_tpu_available` is deprecated and will be removed in v0.27.0. "
"`is_tpu_available` is deprecated and will be removed in v0.27.0. "

Collaborator

Note v0.27.0 is fine here

return _tpu_available

if _torch_xla_available:
xla_device = xm.xla_device()
Collaborator

I don't believe we do, as we can just pass in check_is_tpu

anw90 (Contributor Author) commented Dec 13, 2023

Hi @muellerzr, sorry for the late reply. I spent some time resolving the UT issues to ensure that all UTs passed.

I ran the UTs (using make test with 2 GPUs) in three different environments:

  1. RUN_SLOW=1 make test: CUDA UT without torchxla installed.
  2. USE_TORCH_XLA=0 RUN_SLOW=1 make test: CUDA UT with torchxla installed, while setting USE_TORCH_XLA to 0 (to disable torchxla).
  3. RUN_SLOW=1 make test: TorchXLA UT with torchxla installed and USE_TORCH_XLA set to 1 (default value).

(I don't have a TPU at the moment. I'll run the TPU UTs as soon as I acquire the TPU environment.)

All the three UTs above were executed with RUN_SLOW=1, and all passed except for one: tests/test_cli.py::AccelerateLauncherTester::test_notebook_launcher.

It appears that this test also fails on the main branch, although it was fine before this CI.
I'm currently investigating the reason for its failure.

Here is the failure error:

RUN_SLOW=1 USE_TORCH_XLA=0 python -u -m pytest -s -v tests/test_cli.py::AccelerateLauncherTester::test_notebook_launcher

E           accelerate.test_utils.testing.SubprocessCallException: Command `python /root/wangang.wa/code/anw_accelerate/accelerate/src/accelerate/test_utils/scripts/test_notebook.py` failed with the following error:
E
E           Test basic notebook can be ran
E           Launching training on 2 GPUs.
E           Traceback (most recent call last):
E             File "/root/wangang.wa/code/anw_accelerate/accelerate/src/accelerate/launchers.py", line 200, in notebook_launcher
E               start_processes(launcher, args=args, nprocs=num_processes, start_method="fork")
E             File "/root/wangang.wa/code/pytorch/torch/multiprocessing/spawn.py", line 202, in start_processes
E               while not context.join():
E             File "/root/wangang.wa/code/pytorch/torch/multiprocessing/spawn.py", line 163, in join
E               raise ProcessRaisedException(msg, error_index, failed_process.pid)
E           torch.multiprocessing.spawn.ProcessRaisedException:
E
E           -- Process 1 terminated with the following error:
E           Traceback (most recent call last):
E             File "/root/wangang.wa/code/pytorch/torch/multiprocessing/spawn.py", line 74, in _wrap
E               fn(i, *args)
E             File "/root/wangang.wa/code/anw_accelerate/accelerate/src/accelerate/utils/launch.py", line 563, in __call__
E               self.launcher(*args)
E             File "/root/wangang.wa/code/anw_accelerate/accelerate/src/accelerate/test_utils/scripts/test_notebook.py", line 13, in basic_function
E               print(f"PartialState:\n{PartialState()}")
E             File "/root/wangang.wa/code/anw_accelerate/accelerate/src/accelerate/state.py", line 233, in __init__
E               torch.cuda.set_device(self.device)
E             File "/root/wangang.wa/code/pytorch/torch/cuda/__init__.py", line 404, in set_device
E               torch._C._cuda_setDevice(device)
E             File "/root/wangang.wa/code/pytorch/torch/cuda/__init__.py", line 284, in _lazy_init
E               raise RuntimeError(
E           RuntimeError: Cannot re-initialize CUDA in forked subprocess. To use CUDA with multiprocessing, you must use the 'spawn' start method

muellerzr (Collaborator) commented Dec 13, 2023

Running that test standalone should not (and does not) fail, only as part of the grouped CI which I'm looking into. Are you initializing CUDA somewhere now suddenly?

To verify this, in Jupyter on a GPU-enabled system do the following:

import torch
from accelerate import Accelerator, notebook_launcher

torch.cuda.is_initialized()

This should be False

anw90 (Contributor Author) commented Dec 13, 2023

Thanks for your reply, muellerzr. Following your suggestion, I tried running your code, and torch.cuda.is_initialized() returned False in Jupyter. I also switched to another worker but encountered the same issue. Perhaps I need to delve deeper into this problem.

muellerzr (Collaborator)

Most likely, yes. My best recommendation for doing so is running the contents of that notebook test inside Jupyter to see if that works. I'll rerun it here locally in a moment to verify, but the test should work.

What is your output of pip freeze?

anw90 (Contributor Author) commented Dec 13, 2023

Output of pip freeze:

absl-py @ file:///croot/absl-py_1686852429912/work
-e git+https://github.com/huggingface/accelerate@eafcea07f639a5476385854ea9bccdbba467db9d#egg=accelerate
aiohttp @ file:///croot/aiohttp_1670009560265/work
aiosignal @ file:///tmp/build/80754af9/aiosignal_1637843061372/work
anyio==4.1.0
argon2-cffi==23.1.0
argon2-cffi-bindings==21.2.0
arrow==1.3.0
asttokens @ file:///opt/conda/conda-bld/asttokens_1646925590279/work
async-lru==2.0.4
async-timeout @ file:///opt/conda/conda-bld/async-timeout_1664876359750/work
attrs==23.1.0
Babel==2.14.0
backcall @ file:///home/ktietz/src/ci/backcall_1611930011877/work
beautifulsoup4 @ file:///croot/beautifulsoup4-split_1681493039619/work
bitsandbytes==0.41.3.post2
black @ file:///croot/black_1680737249031/work
bleach==6.1.0
blinker==1.4
boltons @ file:///croot/boltons_1677628692245/work
cachetools @ file:///tmp/build/80754af9/cachetools_1619597386817/work
certifi @ file:///croot/certifi_1690232220950/work/certifi
cffi @ file:///croot/cffi_1670423208954/work
chardet @ file:///tmp/build/80754af9/chardet_1607706746162/work
charset-normalizer @ file:///tmp/build/80754af9/charset-normalizer_1630003229654/work
click @ file:///tmp/build/80754af9/click_1646038465422/work
cloud-tpu-client==0.10
comm==0.2.0
conda @ file:///croot/conda_1692724491024/work
conda-build @ file:///croot/conda-build_1692366767805/work
conda-content-trust @ file:///tmp/build/80754af9/conda-content-trust_1617045594566/work
conda-libmamba-solver @ file:///croot/conda-libmamba-solver_1691418897561/work/src
conda-package-handling @ file:///croot/conda-package-handling_1690999929514/work
conda_index @ file:///croot/conda-index_1672127320521/work
conda_package_streaming @ file:///croot/conda-package-streaming_1690987966409/work
coverage @ file:///croot/coverage_1680092710405/work
cryptography @ file:///croot/cryptography_1689373676338/work
datasets==2.15.0
debugpy==1.8.0
decorator @ file:///opt/conda/conda-bld/decorator_1643638310831/work
defusedxml==0.7.1
dill==0.3.7
einops==0.7.0
exceptiongroup @ file:///croot/exceptiongroup_1668714342571/work
executing @ file:///opt/conda/conda-bld/executing_1646925071911/work
fastjsonschema==2.19.0
filelock @ file:///croot/filelock_1672387128942/work
flash-attn @ file:///workspace/flash_attn-2.0.1-cp38-cp38-linux_x86_64.whl#sha256=64c3089ec53964041a46a87be2fa6fe6cb2b4adfd4b3d07349d303b6ca6f0087
fqdn==1.5.1
frozenlist @ file:///croot/frozenlist_1670004507010/work
fsspec==2023.10.0
glob2 @ file:///home/linux1/recipes/ci/glob2_1610991677669/work
google-api-core==1.34.0
google-api-python-client==1.8.0
google-auth @ file:///opt/conda/conda-bld/google-auth_1646735974934/work
google-auth-httplib2==0.1.1
google-auth-oauthlib @ file:///opt/conda/conda-bld/google-auth-oauthlib_1660687784486/work
googleapis-common-protos==1.62.0
grpcio @ file:///croot/grpc-suite_1685746860162/work
httplib2==0.22.0
huggingface-hub==0.19.4
hypothesis @ file:///croot/hypothesis_1690562126398/work
idna @ file:///croot/idna_1666125576474/work
importlib-metadata @ file:///croot/importlib-metadata_1678997070253/work
importlib-resources==6.1.1
iniconfig==2.0.0
ipykernel==6.27.1
ipython @ file:///croot/ipython_1691532092695/work
ipywidgets==8.1.1
isoduration==20.11.0
jedi @ file:///tmp/build/80754af9/jedi_1644315233700/work
Jinja2 @ file:///croot/jinja2_1666908132255/work
json5==0.9.14
jsonpatch @ file:///tmp/build/80754af9/jsonpatch_1615747632069/work
jsonpointer==2.1
jsonschema==4.20.0
jsonschema-specifications==2023.11.2
jupyter==1.0.0
jupyter-console==6.6.3
jupyter-events==0.9.0
jupyter-lsp==2.2.1
jupyter_client==8.6.0
jupyter_core==5.5.0
jupyter_server==2.12.1
jupyter_server_terminals==0.5.0
jupyterlab==4.0.9
jupyterlab-widgets==3.0.9
jupyterlab_pygments==0.3.0
jupyterlab_server==2.25.2
libarchive-c @ file:///tmp/build/80754af9/python-libarchive-c_1617780486945/work
libmambapy @ file:///croot/mamba-split_1685993156657/work/libmambapy
Markdown @ file:///croot/markdown_1671541909495/work
markdown-it-py @ file:///croot/markdown-it-py_1684279902645/work
MarkupSafe @ file:///opt/conda/conda-bld/markupsafe_1654597864307/work
matplotlib-inline @ file:///opt/conda/conda-bld/matplotlib-inline_1662014470464/work
mdurl @ file:///opt/conda/conda-bld/mdurl_1659716024347/work
mistune==3.0.2
mkl-fft==1.3.6
mkl-random @ file:///work/mkl/mkl_random_1682950433854/work
mkl-service==2.4.0
more-itertools @ file:///tmp/build/80754af9/more-itertools_1637733554872/work
mpmath==1.3.0
multidict @ file:///croot/multidict_1665674239670/work
multiprocess==0.70.15
mypy-extensions==0.4.3
nbclient==0.9.0
nbconvert==7.12.0
nbformat==5.9.2
nest-asyncio==1.5.8
networkx==3.1
ninja==1.11.1.1
notebook==7.0.6
notebook_shim==0.2.3
numpy @ file:///work/mkl/numpy_and_numpy_base_1682953417311/work
nvidia-cublas-cu12==12.1.3.1
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cuda-nvrtc-cu12==12.1.105
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cudnn-cu12==8.9.2.26
nvidia-cufft-cu12==11.0.2.54
nvidia-curand-cu12==10.3.2.106
nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu12==12.1.0.106
nvidia-nccl-cu12==2.18.1
nvidia-nvjitlink-cu12==12.3.101
nvidia-nvtx-cu12==12.1.105
oauth2client==4.1.3
oauthlib @ file:///croot/oauthlib_1679489621486/work
overrides==7.4.0
packaging @ file:///croot/packaging_1678965309396/work
pandas==2.0.3
pandocfilters==1.5.0
parameterized==0.9.0
parso @ file:///opt/conda/conda-bld/parso_1641458642106/work
pathspec @ file:///croot/pathspec_1674681560568/work
peft==0.7.0
pexpect @ file:///tmp/build/80754af9/pexpect_1605563209008/work
pickleshare @ file:///tmp/build/80754af9/pickleshare_1606932040724/work
Pillow==10.1.0
pkginfo @ file:///croot/pkginfo_1679431160147/work
pkgutil_resolve_name==1.3.10
platformdirs @ file:///croot/platformdirs_1692205439124/work
pluggy @ file:///tmp/build/80754af9/pluggy_1648042571233/work
prometheus-client==0.19.0
prompt-toolkit @ file:///croot/prompt-toolkit_1672387306916/work
protobuf==3.20.3
psutil @ file:///opt/conda/conda-bld/psutil_1656431268089/work
ptyprocess @ file:///tmp/build/80754af9/ptyprocess_1609355006118/work/dist/ptyprocess-0.7.0-py2.py3-none-any.whl
pure-eval @ file:///opt/conda/conda-bld/pure_eval_1646925070566/work
pyarrow==14.0.1
pyarrow-hotfix==0.6
pyasn1 @ file:///Users/ktietz/demo/mc3/conda-bld/pyasn1_1629708007385/work
pyasn1-modules==0.2.8
pycosat @ file:///croot/pycosat_1666805502580/work
pycparser @ file:///tmp/build/80754af9/pycparser_1636541352034/work
Pygments @ file:///croot/pygments_1684279966437/work
PyJWT @ file:///opt/conda/conda-bld/pyjwt_1657544592787/work
pyOpenSSL @ file:///croot/pyopenssl_1690223430423/work
pyparsing==3.1.1
PySocks @ file:///tmp/build/80754af9/pysocks_1605305779399/work
pytest==7.4.3
python-dateutil==2.8.2
python-json-logger==2.0.7
pytz @ file:///croot/pytz_1671697431263/work
PyYAML @ file:///croot/pyyaml_1670514731622/work
pyzmq==25.1.2
qtconsole==5.5.1
QtPy==2.4.1
referencing==0.32.0
regex==2023.10.3
requests @ file:///croot/requests_1690400202158/work
requests-oauthlib==1.3.0
rfc3339-validator==0.1.4
rfc3986-validator==0.1.1
rich @ file:///croot/rich_1684282154404/work
rpds-py==0.13.2
rsa @ file:///tmp/build/80754af9/rsa_1614366226499/work
ruamel.yaml @ file:///croot/ruamel.yaml_1666304550667/work
ruamel.yaml.clib @ file:///croot/ruamel.yaml.clib_1666302247304/work
safetensors==0.4.1
scipy==1.10.1
Send2Trash==1.8.2
sentencepiece==0.1.99
six @ file:///tmp/build/80754af9/six_1644875935023/work
sniffio==1.3.0
sortedcontainers @ file:///tmp/build/80754af9/sortedcontainers_1623949099177/work
soupsieve @ file:///croot/soupsieve_1680518478486/work
stack-data @ file:///opt/conda/conda-bld/stack_data_1646927590127/work
sympy==1.12
tensorboard @ file:///home/builder/mesters/opt/envs/tensorflow/conda-bld/tensorboard_1682445826165/work/tensorboard-2.12.1-py3-none-any.whl
tensorboard-data-server @ file:///croot/tensorboard-data-server_1681498183723/work/tensorboard_data_server-0.7.0-py3-none-manylinux2014_x86_64.whl
tensorboard-plugin-wit @ file:///home/builder/tkoch/workspace/tensorflow/tensorboard-plugin-wit_1658918494740/work/tensorboard_plugin_wit-1.8.1-py3-none-any.whl
terminado==0.18.0
tiktoken==0.5.2
tinycss2==1.2.1
tokenizers==0.13.3
tomli @ file:///opt/conda/conda-bld/tomli_1657175507142/work
toolz @ file:///croot/toolz_1667464077321/work
torch==2.1.1
torchacc @ file:///workspace/torchacc-2.0.0-py3-none-any.whl#sha256=8d35354ddc6b0172ce3d89508a4f4485a36864f95f72096f9596e5e1bb766072
torchdistx @ file:///workspace/torchdistx-0.3.0.dev0%2Bcu118-cp38-cp38-linux_x86_64.whl#sha256=0fd6a0c997b1efaa4a6af5422e55d1f722a0da8ec490e8f4e14fa3969e598052
torchvision @ file:///workspace/torchvision-0.18.0a0%2B7e9e784-cp38-cp38-linux_x86_64.whl#sha256=e40666cfe733b79f6dc56050360d9b6caf41e0c6f5b99797768bb42007c55bf4
tornado==6.4
tqdm @ file:///croot/tqdm_1679561862951/work
traitlets @ file:///croot/traitlets_1671143879854/work
transformers==4.30.2
transformers-stream-generator==0.0.4
triton==2.1.0
types-python-dateutil==2.8.19.14
typing_extensions @ file:///croot/typing_extensions_1690297465030/work
tzdata==2023.3
uri-template==1.3.0
uritemplate==3.0.1
urllib3==1.25.8
wcwidth @ file:///Users/ktietz/demo/mc3/conda-bld/wcwidth_1629357192024/work
webcolors==1.13
webencodings==0.5.1
websocket-client==1.7.0
Werkzeug @ file:///croot/werkzeug_1679489717957/work
widgetsnbextension==4.0.9
xxhash==3.4.1
yacs @ file:///tmp/build/80754af9/yacs_1634047592950/work
yarl @ file:///opt/conda/conda-bld/yarl_1661437085904/work
zipp @ file:///croot/zipp_1672387121353/work
zstandard @ file:///croot/zstandard_1677013143055/work

muellerzr (Collaborator)

Ah, that could be on our end with bitsandbytes. If you uninstall it, does the test pass?

anw90 (Contributor Author) commented Dec 13, 2023

No, even after uninstalling bitsandbytes, the issue persists.

muellerzr (Collaborator) commented Dec 13, 2023

@anw90 that does indeed mean it's an issue with the xla integration.

In the meantime, I would recommend modifying the notebook_launcher code to raise an error if a user tries to launch in Jupyter with it (specifically the CUDA backend), stating that it's currently unsupported.
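A hypothetical guard (not the actual notebook_launcher code) could look like this; the function name is a placeholder:

import torch

def _ensure_cuda_not_initialized():
    # Refuse to fork CUDA worker processes when CUDA has already been touched
    # in the notebook process, since forked children cannot re-initialize CUDA.
    if torch.cuda.is_available() and torch.cuda.is_initialized():
        raise RuntimeError(
            "CUDA has already been initialized in this process; launching from a "
            "notebook with the fork start method is currently unsupported in this "
            "state. Restart the kernel and avoid CUDA calls before launching."
        )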

anw90 (Contributor Author) commented Dec 14, 2023

Thanks, @muellerzr. I'm currently experimenting with the accelerate main branch in the official PyTorch 2.1 image (pytorch/pytorch:2.1.0-cuda11.8-cudnn8-runtime), but the issue persists.
Could you please take a look at this script and see if there are any issues?

docker run --gpus all --net host --ipc host --shm-size 10G  -it --rm --cap-add=SYS_PTRACE pytorch/pytorch:2.1.0-cuda11.8-cudnn8-runtime bash -c 'apt-get update; \
apt-get install git -y ; \
git clone http://github.com/huggingface/accelerate; \
cd accelerate; \
pip install pytest; \
pip install -e .;\
RUN_SLOW=1 python -u -m pytest -s -v tests/test_cli.py::AccelerateLauncherTester::test_notebook_launcher;'

Following the incorrect hint in the error message (To use CUDA with multiprocessing, you must use the 'spawn' start method), I changed the start_method from fork to spawn in the notebook_launcher function, and the test passed.
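For context, the underlying failure mode is easy to reproduce outside of accelerate; here is a minimal sketch (it assumes a CUDA machine and is not accelerate code):

import torch
import torch.multiprocessing as mp

def _worker(rank):
    # Any CUDA call in a forked child fails with "Cannot re-initialize CUDA in
    # forked subprocess" if the parent process already initialized CUDA.
    torch.cuda.set_device(rank)

if __name__ == "__main__":
    torch.cuda.init()  # parent touches CUDA before forking
    mp.start_processes(_worker, nprocs=2, start_method="fork")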

muellerzr (Collaborator) commented Dec 14, 2023

We cannot do spawn; it must be fork :)

But yes, I'll take a look. It passed yesterday on main for me, though.

muellerzr (Collaborator) commented Dec 14, 2023

I ran pytest -sv tests/test_cli.py::AccelerateLauncherTester::test_notebook_launcher

muellerzr (Collaborator)

@anw90 btw I think I know what may be wrong and why it's failing. We can merge this safely after resolving the conflicts. Thanks for your patience and for working with us. (Probably stemmed from the needed solution here: #2272)

anw90 (Contributor Author) commented Dec 22, 2023

Thanks for your review. All code conflicts have been resolved, and all unit tests have passed except for 'test_notebook_launcher'.

BTW, after following this blog, I managed to get a worker with a TPU v2-8 and ran the unit tests on it. However, many UTs fail with the accelerate code cloned from the main branch of http://github.com/huggingface/accelerate.

The commands are as follows:

# set up a fresh worker with TPUv2-8.
gcloud compute tpus tpu-vm create my-tpu \
--accelerator-type=v2-8 \
--version=tpu-ubuntu2204-base \
--zone=us-central1-b \
--project=PROJECT_ID

# ssh into the worker.
gcloud compute tpus tpu-vm ssh my-tpu --zone=us-central1-b

# Install torch, torch_xla, and torchvision
pip install torch~=2.1.0 torch_xla[tpu]~=2.1.0 torchvision -f https://storage.googleapis.com/libtpu-releases/index.html
export PJRT_DEVICE=TPU

# check tpu devices
python -c 'import torch_xla; print(torch_xla._XLAC._xla_get_all_devices())'  # ['TPU:0', 'TPU:1', 'TPU:2', 'TPU:3', 'TPU:4', 'TPU:5', 'TPU:6', 'TPU:7']

# run tests
git clone http://github.com/huggingface/accelerate
cd accelerate/
pip install pytest transformers parameterized
pip install -e .
make test

The failed UTs are as follows:

=========================================================================================== short test summary info ============================================================================================
FAILED tests/test_accelerator.py::AcceleratorTester::test_env_var_device - AssertionError: 'xla:0' != 'cuda:64'
FAILED tests/test_accelerator.py::AcceleratorTester::test_save_load_model_use_pytorch - AttributeError: 'MpDeviceLoaderWrapper' object has no attribute 'dataset'
FAILED tests/test_accelerator.py::AcceleratorTester::test_save_load_model_use_safetensors - RuntimeError: Attempted to access the data pointer on an invalid python storage.
FAILED tests/test_accelerator.py::AcceleratorTester::test_save_load_model_with_hooks_use_pytorch - AttributeError: 'MpDeviceLoaderWrapper' object has no attribute 'dataset'
FAILED tests/test_accelerator.py::AcceleratorTester::test_save_load_model_with_hooks_use_safetensors - RuntimeError: Attempted to access the data pointer on an invalid python storage.
FAILED tests/test_accelerator.py::AcceleratorTester::test_save_model_offload_use_safetensors - safetensors_rust.SafetensorError: Error while deserializing header: HeaderTooLarge
FAILED tests/test_accelerator.py::AcceleratorTester::test_save_model_use_safetensors - safetensors_rust.SafetensorError: Error while deserializing header: HeaderTooLarge
FAILED tests/test_cli.py::AccelerateLauncherTester::test_accelerate_test - RuntimeError: 'accelerate test' failed with returncode 1
FAILED tests/test_grad_sync.py::SyncScheduler::test_gradient_sync_cpu_multi - torch.multiprocessing.spawn.ProcessRaisedException:
FAILED tests/test_grad_sync.py::SyncScheduler::test_gradient_sync_cpu_noop - torch.multiprocessing.spawn.ProcessRaisedException:
FAILED tests/test_optimizer.py::OptimizerTester::test_accelerated_optimizer_pickling - ValueError: AcceleratorState has already been initialized and cannot be changed, restart your runtime completely and pass `cpu=True` to `Accelerator()`.
FAILED tests/test_scheduler.py::SchedulerTester::test_lambda_scheduler_not_step_with_optimizer_single_process - torch.multiprocessing.spawn.ProcessRaisedException:
FAILED tests/test_scheduler.py::SchedulerTester::test_lambda_scheduler_steps_with_optimizer_multiprocess - torch.multiprocessing.spawn.ProcessRaisedException:
FAILED tests/test_scheduler.py::SchedulerTester::test_lambda_scheduler_steps_with_optimizer_single_process - torch.multiprocessing.spawn.ProcessRaisedException:
FAILED tests/test_scheduler.py::SchedulerTester::test_one_cycle_scheduler_not_step_with_optimizer_single_process - torch.multiprocessing.spawn.ProcessRaisedException:
FAILED tests/test_scheduler.py::SchedulerTester::test_one_cycle_scheduler_steps_with_optimizer_multiprocess - torch.multiprocessing.spawn.ProcessRaisedException:
FAILED tests/test_scheduler.py::SchedulerTester::test_one_cycle_scheduler_steps_with_optimizer_single_process - torch.multiprocessing.spawn.ProcessRaisedException:
FAILED tests/test_state_checkpointing.py::CheckpointTest_use_safetensors::test_automatic_loading - RuntimeError: Found no NVIDIA driver on your system. Please check that you have an NVIDIA GPU and installed a driver from http://www.nvidia.com/Download/index.aspx
FAILED tests/test_state_checkpointing.py::CheckpointTest_use_safetensors::test_can_resume_training - ValueError: AcceleratorState has already been initialized and cannot be changed, restart your runtime completely and pass `cpu=True` to `Accelerator()`.
FAILED tests/test_state_checkpointing.py::CheckpointTest_use_safetensors::test_can_resume_training_checkpoints_relative_path - ValueError: AcceleratorState has already been initialized and cannot be changed, restart your runtime completely and pass `cpu=True` to `Accelerator()`.
FAILED tests/test_state_checkpointing.py::CheckpointTest_use_safetensors::test_can_resume_training_with_folder - ValueError: AcceleratorState has already been initialized and cannot be changed, restart your runtime completely and pass `cpu=True` to `Accelerator()`.
FAILED tests/test_state_checkpointing.py::CheckpointTest_use_safetensors::test_checkpoint_deletion - ValueError: AcceleratorState has already been initialized and cannot be changed, restart your runtime completely and pass `cpu=True` to `Accelerator()`.
FAILED tests/test_state_checkpointing.py::CheckpointTest_use_safetensors::test_invalid_registration - ValueError: AcceleratorState has already been initialized and cannot be changed, restart your runtime completely and pass `cpu=True` to `Accelerator()`.
FAILED tests/test_state_checkpointing.py::CheckpointTest_use_safetensors::test_with_save_limit - ValueError: AcceleratorState has already been initialized and cannot be changed, restart your runtime completely and pass `cpu=True` to `Accelerator()`.
FAILED tests/test_state_checkpointing.py::CheckpointTest_use_safetensors::test_with_scheduler - ValueError: AcceleratorState has already been initialized and cannot be changed, restart your runtime completely and pass `cpu=True` to `Accelerator()`.
FAILED tests/test_state_checkpointing.py::CheckpointTest_use_pytorch::test_automatic_loading - ValueError: AcceleratorState has already been initialized and cannot be changed, restart your runtime completely and pass `cpu=True` to `Accelerator()`.
FAILED tests/test_state_checkpointing.py::CheckpointTest_use_pytorch::test_can_resume_training - ValueError: AcceleratorState has already been initialized and cannot be changed, restart your runtime completely and pass `cpu=True` to `Accelerator()`.
FAILED tests/test_state_checkpointing.py::CheckpointTest_use_pytorch::test_can_resume_training_checkpoints_relative_path - ValueError: AcceleratorState has already been initialized and cannot be changed, restart your runtime completely and pass `cpu=True` to `Accelerator()`.
FAILED tests/test_state_checkpointing.py::CheckpointTest_use_pytorch::test_can_resume_training_with_folder - ValueError: AcceleratorState has already been initialized and cannot be changed, restart your runtime completely and pass `cpu=True` to `Accelerator()`.
FAILED tests/test_state_checkpointing.py::CheckpointTest_use_pytorch::test_checkpoint_deletion - ValueError: AcceleratorState has already been initialized and cannot be changed, restart your runtime completely and pass `cpu=True` to `Accelerator()`.
FAILED tests/test_state_checkpointing.py::CheckpointTest_use_pytorch::test_invalid_registration - ValueError: AcceleratorState has already been initialized and cannot be changed, restart your runtime completely and pass `cpu=True` to `Accelerator()`.
FAILED tests/test_state_checkpointing.py::CheckpointTest_use_pytorch::test_with_save_limit - ValueError: AcceleratorState has already been initialized and cannot be changed, restart your runtime completely and pass `cpu=True` to `Accelerator()`.
FAILED tests/test_state_checkpointing.py::CheckpointTest_use_pytorch::test_with_scheduler - ValueError: AcceleratorState has already been initialized and cannot be changed, restart your runtime completely and pass `cpu=True` to `Accelerator()`.
FAILED tests/test_tpu.py::MultiTPUTester::test_tpu - RuntimeError: '/usr/bin/python /root/accelerate/tests/xla_spawn.py --num_cores 8 /root/accelerate/src/accelerate/test_utils/scripts/test_script.py' failed with returncode 1
FAILED tests/test_tracking.py::CustomTrackerTestCase::test_init_trackers - ValueError: AcceleratorState has already been initialized and cannot be changed, restart your runtime completely and pass `cpu=True` to `Accelerator()`.
FAILED tests/test_tracking.py::CustomTrackerTestCase::test_log - ValueError: AcceleratorState has already been initialized and cannot be changed, restart your runtime completely and pass `cpu=True` to `Accelerator()`.
===================================================================== 36 failed, 98 passed, 138 skipped, 13 warnings in 142.35s (0:02:22) ======================================================================
make: *** [Makefile:27: test] Error 1

Perhaps we could address this in a separate PR.

anw90 (Contributor Author) commented Dec 22, 2023

Do you think we should make similar adjustments to the transformers library to enable TorchXLA on GPU? If it's necessary, I could submit a new PR for that later on.

muellerzr (Collaborator)

Let's address this in a separate PR please.

anw90 (Contributor Author) commented Dec 22, 2023

The code conflicts have been resolved, and all UTs, including test_notebook_launcher, have passed.

muellerzr (Collaborator)

@anw90 can you rebase from main one more time (it's the failing test currently), and then I'll run this on our runners to make sure everything looks solid before final reviews from myself and a few others (considering how big this PR is).

anw90 (Contributor Author) commented Jan 19, 2024

Thanks @muellerzr . All conflicts resolved.


This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

muellerzr (Collaborator)

Hi @anw90, terribly sorry this is taking so long. The other failures are okay; can you run make style; make quality one last time, and this will get merged today!

anw90 (Contributor Author) commented Feb 14, 2024

The other failures are okay, can you run make style; make quality one last time and this will get merged today!

Sorry for the late reply due to the holidays. The code is formatted, and all unit tests have passed. Hope it can be merged today. Let me know if there's anything else I can help with.
