
[RFE] Add support for distributed CPU-backend mode #11182

Open
PhilipVinc opened this issue Jun 21, 2022 · 22 comments

@PhilipVinc
Contributor

PhilipVinc commented Jun 21, 2022

Unless I am mistaken, it is only possible to use the distributed runtime (initialised with jax.distributed.initialize) with the GPU and TPU backends.

However, I believe that TensorFlow, and thus XLA, also supports the CPU backend.
Would it be possible to support it in JAX as well, so that it can be used with pjit & co.?
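
For concreteness, a sketch of the desired usage; the coordinator address and process numbers below are placeholders:

import jax

# Desired: the usual multi-process setup, but with CPU devices.
jax.distributed.initialize(
    coordinator_address='10.0.0.1:1234',  # placeholder
    num_processes=2,
    process_id=0,  # unique per process
)

# With cross-process CPU collectives, jax.devices() would list the CPU
# devices of every process, and pjit/psum could shard across them.
print(jax.devices())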

@PhilipVinc PhilipVinc added the enhancement New feature or request label Jun 21, 2022
@hawkinsp
Collaborator

The main thing that JAX is missing to make this work is an implementation of the various collective operations in XLA that works across processes.

One possibility is MPI, in which case the third-party package mpi4jax may be of interest.

Another possibility might be to plug something like gloo into XLA to implement its collectives: https://github.com/facebookincubator/gloo
This would probably not be hard to do. Currently the collectives implemented in XLA:CPU are naive reference implementations.

@PhilipVinc
Contributor Author

(I am the author of mpi4jax so I know that one pretty well 😄 )

Thanks for the explanation, I see the issue now.

The main reason I'm asking is that pjit, and sharding along different axes of a mesh with pjit, is very interesting to me and I would like to play with it in our packages, but MPI limits us to something like data-parallelism.
Unfortunately I have many CPU-based users, and that's why I need CPU support.

The other reason is that I'd like to support multiple GPUs/CPUs per node (which is exactly what pjit/GlobalArray does), but that would mean supporting MPI primitives within pmap, which I'm not sure how to do. Right now I have to force users to launch one JAX process per GPU, but that's particularly annoying in some HPC setups.

@hawkinsp
Collaborator

Yes, in essence you are asking to replace the collectives emitted by XLA internally with calls to mpi or something similar. This is a bit different to mpi4jax where you added a separate set of collective ops unknown to XLA on the side.

Right now, XLA:CPU emits calls to functions in a small helper library that implement collectives:
https://cs.opensource.google/tensorflow/tensorflow/+/master:tensorflow/compiler/xla/service/cpu/cpu_runtime.h;drc=6eeb889576593a803bce51871b11fb2b27f8f2b3;l=174

Someone would need to teach XLA:CPU how to either call different runtime library functions, or to change those runtime library functions to call MPI (etc.). That would most likely need a bit of refactoring so that multiple collective implementations can be plugged in.

@zhangqiaorjc zhangqiaorjc added the contributions welcome The JAX team has not prioritized work on this. Community contributions are welcome. label Jun 21, 2022
@PhilipVinc
Contributor Author

Yes, in essence you are asking to replace the collectives emitted by XLA internally with calls to mpi or something similar.
This is a bit different to mpi4jax where you added a separate set of collective ops unknown to XLA on the side.

Hmm, yes, that would be amazing, but I can imagine it would be quite a bit of work, and I'm unsure whether Google is interested in that. Though surely academic groups working with HPC would be interested and would benefit from it.

Someone would need to teach XLA:CPU how to either call different runtime library functions, or to change those runtime library functions to call MPI (etc.) That would most likely need a bit of refactoring so multiple collective implementations can be plugged in

Getting my hands dirty in XLA itself is a bit beyond the amount of time I have available right now, unfortunately. Is there anything I can do to help you in the process, or to convince you that this is a useful path?

[begin off topic]
However, something that would make me temporarily, slightly happier would be a way to insert MPI custom calls into pmap-ped functions. Right now we define C functions that respect the XLA calling convention and then specify how to lower those on CPU and GPU, but I have no idea how to support pmap in this context. If you have any pointers I'd be happy to take them.
[end off topic]

@PhilipVinc
Contributor Author

@hawkinsp has anything changed on your end about this recently, or are you still not planning to support this?

@alelovato

@hawkinsp, are there any updates on this front? I would like to start using JAX on our CPU-based cluster.

@jon-chuang
Contributor

jon-chuang commented Apr 3, 2023

Hello, I may be interested in taking this on by implementing gloo-based collective ops, replacing the naive implementations.

How does this sound @PhilipVinc @hawkinsp ?

@jon-chuang
Contributor

I will start with all-reduce and broadcast, in that order.

@hawkinsp
Collaborator

hawkinsp commented Apr 3, 2023

@jon-chuang well that sounds excellent, if you wanted to contribute that! I would agree: start with all-reduce, which is by itself enough for data-parallel training.

@jon-chuang
Contributor

jon-chuang commented Apr 4, 2023

Here is the MVP target:

  1. Perform psum across separate processes running locally with the XLA CPU runtime; e2e test (specifically, multiprocess_cpu_test.py, similar to the GPU test). A sketch of what such a test might look like is below.
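
A minimal sketch of such a test, assuming the usual multi-process initialization API; the coordinator address and the environment variables supplying the process count and id are placeholders for whatever the test harness provides:

import os

import jax
import jax.numpy as jnp

jax.distributed.initialize(
    coordinator_address='127.0.0.1:1234',            # placeholder
    num_processes=int(os.environ['NUM_PROCESSES']),  # hypothetical harness vars
    process_id=int(os.environ['PROCESS_ID']),
)

# One CPU device per process by default; psum over axis 'i' reduces
# across every device of every participating process.
out = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(
    jnp.ones(jax.local_device_count())
)
assert (out == jax.device_count()).all()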

@jon-chuang
Contributor

jon-chuang commented Apr 4, 2023

@PhilipVinc could you advise on the degree to which we should be able to perform collective operations across both CPU and GPU (e.g. GPU+CPU offloading)?

In this case:

  1. CPU<->GPU: MPI (incl. cross-process, same device)
  2. CPU<->CPU: MPI
  3. GPU<->GPU: NCCL

The way we could implement it, e.g. for all-reduce, is to:

  1. do a local all-reduce first (CPU<->GPU + GPU<->GPU(same host)),
  2. then use either GPU<->GPU or CPU<->CPU (cross-host).

I think that GPU<->GPU should have better performance via NCCL?

See also: alpa-projects/alpa#694

Note that I don't think even torch.distributed allows for a hybrid cluster?

EDIT: gloo implements local reduction in CPU memory - see e.g. cuda_allreduce_ring, cuda_allreduce_ring_chunked. The latter leverages NCCL for same-host multi-GPU reduce/scatter.

@PhilipVinc
Contributor Author

@jon-chuang thank you for looking into this. It's something that would greatly benefit many people, including me...

If I understand your question correctly, you want to know which reduction operations must be implemented?
In XLA, as of today, there exist only CPU-based reductions (so CPU to CPU) or GPU-based reductions (so GPU to GPU).
That's because an XLA-compiled executable can only run on one platform.

So you should not worry about CPU<->GPU reductions and can always assume that the devices executing your distributed operation are homogeneous.
At least, that assumption has worked very well for mpi4jax so far.

There might be plans to allow for hybrid computations (@hawkinsp will know for sure) but I'd leave that out of scope for a first implementation.

--

I think the only operation you need to implement is CPU<->CPU reduction, possibly using MPI. GPU<->GPU reductions are already implemented (using NCCL, I think).

@jon-chuang
Contributor

Actually, I got a hint that the new PJRT runtime can handle mixed CPU<->GPU workloads. Could you confirm, @hawkinsp?

@hawkinsp
Collaborator

hawkinsp commented Apr 4, 2023

@jon-chuang there are explorations in that direction, but nothing concrete at this time. It might also be done primarily at a layer above PJRT even if it happens.

I would not look into hybrid computations for an MVP.

@PhilipVinc
Contributor Author

Also, somewhat relevant: how are you going to implement this?
Ideally this could be a plugin to XLA (do those even exist?) that depends on MPI, not so far off from the current compile step of mpi4jax.

Forcing users to recompile the full jaxlib on HPC machines with finicky compilers is a recipe for problems.
Maybe that's fine for an MVP, but in the long run it will be hard to switch users to it.

@jon-chuang
Contributor

jon-chuang commented Apr 4, 2023

Maybe that's fine for an MVP, but in the long run it will be hard to switch users to it.

As far as plugins go (CMIIW), XLA already has support for such dynamically loaded runtime libraries, since they want to support user-side custom-call lowering/dispatch.

I did consider baking this into JAX/XLA, but the plugin route seems neater, and it could come bundled with JAX/XLA if deemed the most reasonable default.

@hawkinsp hawkinsp added CPU Issues related to the CPU compiler/runtime and removed contributions welcome The JAX team has not prioritized work on this. Community contributions are welcome. labels Nov 12, 2023
@hawkinsp hawkinsp self-assigned this Nov 12, 2023
@hawkinsp
Collaborator

hawkinsp commented Dec 11, 2023

I added a secret option, jax_cpu_enable_gloo_collectives, in 384e29e. This enables cross-process CPU collectives (needs jaxlib from head, built with XLA from head).

Please note that:

a) there is no support for encryption: your data will travel unencrypted over the wire. This may or may not be acceptable, depending on the application.
b) the collectives are currently synchronous, so they won't be that fast yet.
c) the collectives are only lightly tested so far.
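
A minimal sketch of enabling it, assuming the option is exposed through jax.config like other jax_* flags:

import jax

jax.config.update('jax_cpu_enable_gloo_collectives', True)
jax.distributed.initialize()  # coordinator/process arguments as usual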

@hawkinsp
Collaborator

hawkinsp commented Dec 11, 2023

I should add: it wouldn't be terribly hard to plug in MPI collectives here as well, if one wanted to do so. @PhilipVinc

(Essentially, one implements https://github.com/openxla/xla/blob/main/xla/service/cpu/collectives_interface.h.)

@PhilipVinc
Contributor Author

PhilipVinc commented Dec 11, 2023

@inailuig has recently (successfully) experimented with plugging MPI into a CPU device... He can say more than I can.

However, what we would love is a way to write a sort of plug-in that can easily make use of MPI with JAX's native sharding, in a way that is relatively stable and not a hack...

@hawkinsp
Collaborator

One thing we might be able to do is dlopen() an MPI implementation and implement the collectives on top of it. If done that way (via dlopen()) it's potentially something we could upstream. (I wouldn't want to require MPI at build time.)

@inailuig
Contributor

inailuig commented Dec 12, 2023

@hawkinsp This is great news.
Over the summer I had a go at implementing an MPI plugin, inserting the MPI calls directly into the existing PJRT CPU client; see here.
It works, but it is really only at the proof-of-concept stage (all I needed was a global allreduce).
The new interface mentioned above should now make it a lot easier to implement the collectives in a pluggable way.

copybara-service bot pushed a commit to openxla/xla that referenced this issue Mar 29, 2024
Imported from GitHub PR #7849

MPI collectives as proposed in jax-ml/jax#11182.

I only implemented the inter-process communication, and this does not yet support more than one thread per process. Adding support for multiple threads/devices per process in the future seems quite a bit more involved, if one wanted to do it properly.

For MPI I am building and linking against https://github.com/eschnett/MPItrampoline, which dlopens the (wrapped) MPI library at runtime. To wrap and load the desired MPI library, one needs to compile https://github.com/eschnett/MPIwrapper and set `MPITRAMPOLINE_LIB=/path/to/libmpiwrapper.so`.

@hawkinsp
Copybara import of the project:

--
b74bbb9 by Clemens Giuliani <[email protected]>:

add mpi collectives

--
23508eb by Clemens Giuliani <[email protected]>:

add explicit Init and Finalize methods and export them to python

--
bbe5840 by Clemens Giuliani <[email protected]>:

add comment

--
38d1562 by Clemens Giuliani <[email protected]>:

fix windows build

--
201f723 by Clemens Giuliani <[email protected]>:

fmt

--
2784869 by Clemens Giuliani <[email protected]>:

bump xla_extension_version

Merging this change closes #7849

COPYBARA_INTEGRATE_REVIEW=#7849 from inailuig:mpi_collectives 2784869
PiperOrigin-RevId: 620302264
@inailuig
Contributor

inailuig commented May 8, 2024

As of jax 0.4.27, released yesterday (with jaxlib 0.4.26), there is now (finally) support for cross-process communication using MPI. It can be used like this:

Download and compile MPIwrapper:

git clone https://github.com/eschnett/MPIwrapper.git
cd MPIwrapper
mkdir build
cd build
cmake ../
make

and initialize JAX like this:

import os
os.environ['MPITRAMPOLINE_LIB'] = "/path/to/libmpiwrapper.so"

import jax
jax.config.update('jax_cpu_collectives_implementation', 'mpi')
jax.distributed.initialize()

# ...

The libmpiwrapper.so can be found in the build folder created above.
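
For example, a minimal cross-process psum built on this setup (the wrapper path is a placeholder), launched with something like mpirun -n 4 python mpi_psum.py:

import os
os.environ['MPITRAMPOLINE_LIB'] = '/path/to/libmpiwrapper.so'  # placeholder

import jax
import jax.numpy as jnp

jax.config.update('jax_cpu_collectives_implementation', 'mpi')
jax.distributed.initialize()  # process count/rank are taken from the MPI environment

# Each process contributes its index; every process receives the global
# sum 0 + 1 + ... + (n-1).
out = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(
    jnp.array([float(jax.process_index())])
)
print(jax.process_index(), out)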
