[RFE] Add support for distributed CPU-backend mode #11182
The main thing JAX is missing to make this work is an implementation of the various collective operations in XLA that works across processes. Two possibilities are …. Another possibility might be to plug in something like ….
(I am the author of ….) Thanks for the explanation, I see the issue now. The reason I'm asking is mainly because of `pjit` and distributing along different axes of a mesh with …. The other reason is that I'd like to support multiple GPUs/CPUs per node (which is exactly what …).
Yes, in essence you are asking to replace the collectives emitted by XLA internally with calls to …. Right now, XLA:CPU emits calls to functions in a small helper library that implements collectives: …. Someone would need to teach XLA:CPU either to call different runtime library functions, or to change those runtime library functions to call MPI (etc.). That would most likely need a bit of refactoring so that multiple collective implementations can be plugged in.
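To make the refactoring idea concrete, here is a rough Python analogue of what a pluggable collectives layer could look like. The class and method names below are purely illustrative and do not correspond to any actual XLA API; the real interface lives in XLA:CPU's C++ runtime.

    from abc import ABC, abstractmethod
    import numpy as np

    class CpuCollectives(ABC):
        """Illustrative stand-in for a pluggable CPU collectives backend."""

        @abstractmethod
        def all_reduce(self, buf: np.ndarray) -> np.ndarray:
            """Sum-reduce `buf` across all participating processes."""

    class SingleProcessCollectives(CpuCollectives):
        """Default backend: with a single process there is nothing to reduce."""

        def all_reduce(self, buf: np.ndarray) -> np.ndarray:
            return buf

    # A cross-process backend (MPI, gloo, ...) would subclass CpuCollectives and
    # perform the same reduction over the wire instead of locally.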
Hmm, yes, that would be amazing, but I can imagine it would be quite a bit of work, and I'm unsure whether Google is interested in that. Though surely academic groups working with HPC would be interested and would benefit from it.
Getting my hands dirty in XLA itself is a bit beyond the amount of time I have available right now, unfortunately. Is there anything I can do to help you in the process, or to convince you that this is a useful path? [/begin off topic]
@hawkinsp did anything change on your end about this recently, or are you still not really planning on supporting this?
@hawkinsp, are there any updates on this front? I would like to start using JAX on our CPU-based cluster.
Hello, I may be interested in taking this on by implementing …. How does this sound, @PhilipVinc @hawkinsp?
I will start with all-reduce and broadcast, in that order.
@jon-chuang well, that sounds excellent, if you want to contribute that! I would agree: start with all-reduce, which is by itself enough for data-parallel training.
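To illustrate why all-reduce alone already covers data-parallel training, here is a minimal JAX sketch; the model, data shapes, and learning rate are placeholders. Each device computes gradients on its own shard, and a single `pmean` (which lowers to an all-reduce) averages them.

    import functools
    import jax
    import jax.numpy as jnp

    def loss(params, x, y):
        return jnp.mean((x @ params - y) ** 2)

    # Every argument carries a leading axis of size jax.local_device_count().
    @functools.partial(jax.pmap, axis_name="devices")
    def train_step(params, x, y):
        grads = jax.grad(loss)(params, x, y)
        # The only collective required: all-reduce (mean) of gradients across devices.
        grads = jax.lax.pmean(grads, axis_name="devices")
        return params - 0.01 * grads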
Here is the MVP target: …
@PhilipVinc could you advise on the degree to which we should be able to perform a collective operation across both CPU and GPU (e.g. GPU+CPU offloading)? In this case: …

The way we could implement it, e.g. for all-reduce: …
I think that …. See also: alpa-projects/alpa#694. Note that I don't think that even …. EDIT: gloo implements local reduction in CPU memory - see e.g. …
@jon-chuang thank you for looking into this. It's something that would greatly benefit many people, including me. As for your question, if I understand it correctly, you want to know which reduction operations must be implemented? You should not worry about CPU-GPU reductions, and can always assume that the devices executing your distributed operation are homogeneous. There might be plans to allow for hybrid computations (@hawkinsp will know for sure), but I'd leave that out of scope for a first implementation. -- I think the only operation you need to implement is CPU-CPU reductions, possibly using MPI. GPU-GPU reductions are already implemented (using NCCL, I think).
Actually, I got a hint that the new PJRT runtime can handle a mixed CPU<->GPU workload. Could you confirm, @hawkinsp?
@jon-chuang there are explorations in that direction, but nothing concrete at this time. It might also be done primarily at a layer above PJRT even if it happens. I would not look into hybrid computations for an MVP.
Also, somewhat related: how are you going to implement this? Forcing users to recompile the full jaxlib on HPC machines with finicky compilers is going to be a recipe for problems.
As far as plugins go, correct me if I'm wrong, but XLA already has support for such dynamically-loaded runtime libraries, as they want to support user-side custom call lowering/dispatch. I did consider baking this into JAX/XLA, but the plugin route seems neater, and it could come bundled with JAX/XLA if deemed the most reasonable default.
I added a secret option: …. Please note that (a) there is no support for encryption: your data will travel unencrypted over the wire. This may or may not be acceptable depending on the application.
I should add: it wouldn't be terribly hard to plug in MPI collectives here as well, if one wanted to do so. @PhilipVinc (One essentially implements https://github.com/openxla/xla/blob/main/xla/service/cpu/collectives_interface.h.)
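For a sense of scale, the reduction itself is only a few lines over MPI. The sketch below uses mpi4py on NumPy buffers to show the operation such a plugin would perform; it is not the XLA plumbing itself, and the script name in the comment is just an example.

    import numpy as np
    from mpi4py import MPI  # assumes MPI and mpi4py are installed

    comm = MPI.COMM_WORLD

    def mpi_all_reduce_sum(buf: np.ndarray) -> np.ndarray:
        """Element-wise sum of `buf` across every MPI rank; all ranks get the result."""
        out = np.empty_like(buf)
        comm.Allreduce(buf, out, op=MPI.SUM)
        return out

    # Run under e.g. `mpirun -n 4 python reduce_demo.py`: each rank contributes its
    # buffer and receives the global sum.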
@inailuig has recently (successfully) experimented with plugging MPI inside of a CPU device... He can say more than I can. However, what we would love is a way to write a sort of plug-in that can easily make use of MPI with JAX-native sharding, and in a way that is relatively stable and not a hack...
A thing we might be able to do is to …
@hawkinsp This is great news.
Imported from GitHub PR openxla/xla#7849. MPI collectives as proposed in jax-ml/jax#11182. I only implemented the inter-process communication, and this does not yet support more than one thread per process. Adding support for multiple threads/devices per process in the future seems quite a bit more involved if one wanted to do it properly.

For MPI I am building and linking against https://github.com/eschnett/MPItrampoline, which dlopens the (wrapped) MPI library at runtime. To wrap and load the desired MPI library one needs to compile https://github.com/eschnett/MPIwrapper and set `MPITRAMPOLINE_LIB=/path/to/libmpiwrapper.so`. @hawkinsp

Copybara import of the project:
-- b74bbb909d902bd30523f943a7c15f2c754cf98a by Clemens Giuliani <[email protected]>: add mpi collectives
-- 23508eb46848464f6711dd8f3f91830ea1adb16d by Clemens Giuliani <[email protected]>: add explicit Init and Finalize methods and export them to python
-- bbe5840b8eb56a306a66ed03d701fd8976e01491 by Clemens Giuliani <[email protected]>: add comment
-- 38d156282ecc89509f4b21d80db1a37cb290437a by Clemens Giuliani <[email protected]>: fix windows build
-- 201f7238f166197ede5cf5d4d70e117a91eddcd7 by Clemens Giuliani <[email protected]>: fmt
-- 2784869df650c1c123c346401db2f67cb153b03e by Clemens Giuliani <[email protected]>: bump xla_extension_version

Merging this change closes #7849

FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#7849 from inailuig:mpi_collectives 2784869df650c1c123c346401db2f67cb153b03e
PiperOrigin-RevId: 620001290
As of jax 0.4.27 released yesterday (and jaxlib 0.4.26) there is now (finally) support for cross-process communication using MPI. It can be used like this:

Download and compile MPIwrapper:

    git clone https://github.com/eschnett/MPIwrapper.git
    cd MPIwrapper
    mkdir build
    cd build
    cmake ../
    make

and initialize jax like this:

    import os
    os.environ['MPITRAMPOLINE_LIB'] = "/path/to/libmpiwrapper.so"

    import jax
    jax.config.update('jax_cpu_collectives_implementation', 'mpi')
    jax.distributed.initialize()
    # ...

The …
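With the usual MPI workflow, one would then start one Python process per JAX process under an MPI launcher. The launcher invocation, process count, and script name below are examples only; printing the process and device counts is just a quick way to check that all processes found each other.

    # Example launch (names and counts are placeholders):
    #   mpirun -n 2 python check_setup.py
    import os
    os.environ['MPITRAMPOLINE_LIB'] = "/path/to/libmpiwrapper.so"

    import jax
    jax.config.update('jax_cpu_collectives_implementation', 'mpi')
    jax.distributed.initialize()

    # Every process should agree on the global process and device counts.
    print(jax.process_index(), jax.process_count(), jax.device_count())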
Unless I am mistaken, it is only possible to use the distributed backend (initialised with `jax.distributed.initialize`) with the GPU and TPU backends. However, I believe that TensorFlow, and thus XLA, should also support the CPU backend. Would it be possible to support it in JAX as well, so that it can be used with `pjit` & co?
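For context, the coordinator-based initialization that the issue asks to extend to the CPU backend looks roughly like this on GPU/TPU; the address, process count, and process id are placeholders for whatever the cluster or launcher provides.

    import jax

    # One call per process; values shown are placeholders for a two-process job.
    jax.distributed.initialize(
        coordinator_address="10.0.0.1:1234",  # host:port of process 0
        num_processes=2,
        process_id=0,  # 0 on the first process, 1 on the second
    )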